diff --git a/scripts/run_benchmark.py b/scripts/run_benchmark.py index 83f67fb..9e1cf47 100644 --- a/scripts/run_benchmark.py +++ b/scripts/run_benchmark.py @@ -786,7 +786,7 @@ async def main(args: argparse.Namespace) -> None: help="(WIP) Run ablation study: full benchmark + 7 single-agent-removed runs", ) parser.add_argument( - "--provider", choices=["anthropic", "openai", "local"], + "--provider", choices=["anthropic", "openai", "local", "claude-code"], help="LLM provider (default: from config)", ) parser.add_argument("--model", help="Model name (e.g. qwen/qwen3.5-9b)") diff --git a/scripts/run_scenario.py b/scripts/run_scenario.py index 63efcc5..5cf62fd 100644 --- a/scripts/run_scenario.py +++ b/scripts/run_scenario.py @@ -370,7 +370,7 @@ async def main(args: argparse.Namespace) -> None: parser.add_argument("scenario", nargs="?", help="Scenario name or number prefix") parser.add_argument("--list", action="store_true", help="List available scenarios") parser.add_argument( - "--provider", choices=["anthropic", "openai", "local"], + "--provider", choices=["anthropic", "openai", "local", "claude-code"], help="LLM provider (default: from config)", ) parser.add_argument("--model", help="Model name (e.g. 
unsloth/nvidia-nemotron-3-nano-4b)") diff --git a/src/rle/agents/base_role.py b/src/rle/agents/base_role.py index 40eb579..b63bb10 100644 --- a/src/rle/agents/base_role.py +++ b/src/rle/agents/base_role.py @@ -223,7 +223,10 @@ def _call_provider( ChatMessage(role=MessageRole.SYSTEM, content=system_prompt), ChatMessage(role=MessageRole.USER, content=user_prompt), ] - if self._no_think and self.provider.provider_name != "anthropic": + if self._no_think and self.provider.provider_name not in ( + "anthropic", + "claudecode", + ): messages.append( ChatMessage(role=MessageRole.ASSISTANT, content=""), ) diff --git a/src/rle/config.py b/src/rle/config.py index 103894f..06e4fd9 100644 --- a/src/rle/config.py +++ b/src/rle/config.py @@ -13,6 +13,8 @@ ) from pydantic_settings import BaseSettings +from rle.providers.claude_code import ClaudeCodeProvider + _HELIX_PRESETS: dict[str, HelixConfig] = { "default": HelixConfig.default(), "research_heavy": HelixConfig.research_heavy(), @@ -23,6 +25,7 @@ "anthropic": AnthropicProvider, "openai": OpenAIProvider, "local": LocalProvider, + "claude-code": ClaudeCodeProvider, } diff --git a/src/rle/providers/__init__.py b/src/rle/providers/__init__.py new file mode 100644 index 0000000..362afb2 --- /dev/null +++ b/src/rle/providers/__init__.py @@ -0,0 +1,5 @@ +"""RLE-local LLM providers (beyond those shipped with felix-agent-sdk).""" + +from rle.providers.claude_code import ClaudeCodeProvider + +__all__ = ["ClaudeCodeProvider"] diff --git a/src/rle/providers/claude_code.py b/src/rle/providers/claude_code.py new file mode 100644 index 0000000..c5123a8 --- /dev/null +++ b/src/rle/providers/claude_code.py @@ -0,0 +1,195 @@ +"""Claude subscription provider: routes completions through ``claude -p``. + +Bills the user's Claude subscription (Pro/Max/Team) instead of an API key +by shelling out to the Claude Code CLI in headless print mode. 
logger = logging.getLogger(__name__)

# Env vars that would either crash the subprocess (nested-session guard) or
# silently flip billing from the subscription login to an API account.
_STRIPPED_ENV_VARS = ("CLAUDECODE", "ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN")


class ClaudeCodeProvider(BaseProvider):
    """Provider that completes via the Claude Code CLI (``claude -p``).

    Uses the machine's existing Claude Code subscription login, so usage
    bills the Claude plan rather than the API. ``temperature``,
    ``max_tokens``, and ``stop_sequences`` are accepted but ignored — the
    CLI does not expose sampling controls (Fable-class models reject them
    anyway).

    Every failure mode of the subprocess (missing binary, launch failure,
    timeout, non-zero exit, malformed or error envelope) is reported as
    ``ProviderError``.
    """

    def __init__(
        self,
        model: str = "claude-fable-5",
        base_url: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Build the provider config and a private scratch working directory.

        Args:
            model: Model name passed to the CLI via ``--model``.
            base_url: Forwarded to ``ProviderConfig`` unchanged.
            **kwargs: Extra ``ProviderConfig`` fields.
        """
        # Local import: only needed for the scratch-directory finalizer below.
        import weakref

        config = ProviderConfig(model=model, api_key=None, base_url=base_url, **kwargs)
        super().__init__(config)
        self._cli_path: str | None = None
        # Neutral cwd so the CLI never picks up a project's CLAUDE.md.
        self._workdir = tempfile.mkdtemp(prefix="rle-claude-p-")
        # Remove the scratch directory when this provider is garbage-collected
        # or at interpreter exit, so instances do not pile up temp directories.
        # The finalizer holds only the path string, never ``self``.
        self._workdir_finalizer = weakref.finalize(
            self, shutil.rmtree, self._workdir, ignore_errors=True,
        )

    def _resolve_cli(self) -> str:
        """Return the path of the ``claude`` executable, caching the lookup.

        Raises:
            ProviderError: If ``claude`` is not found on ``PATH``.
        """
        if self._cli_path is None:
            cli = shutil.which("claude")
            if cli is None:
                raise ProviderError(
                    "Claude Code CLI not found on PATH. Install it and log in "
                    "with a Claude subscription to use --provider claude-code.",
                    provider=self.provider_name,
                )
            self._cli_path = cli
        return self._cli_path

    def _split_messages(self, messages: Sequence[ChatMessage]) -> tuple[str, str]:
        """Flatten messages into (system_prompt, user_prompt) for the CLI.

        System and user messages are each joined with a blank line. Messages
        of any other role (e.g. assistant prefills) are dropped with a debug
        log entry.
        """
        system_parts: list[str] = []
        user_parts: list[str] = []
        for msg in messages:
            if msg.role == MessageRole.SYSTEM:
                system_parts.append(msg.content)
            elif msg.role == MessageRole.USER:
                user_parts.append(msg.content)
            else:
                logger.debug(
                    "ClaudeCodeProvider ignoring %s message (prefills unsupported)",
                    msg.role.value,
                )
        return "\n\n".join(system_parts), "\n\n".join(user_parts)

    @staticmethod
    def _token_count(usage: dict[str, Any], key: str) -> int:
        """Read an integer token count from *usage*; missing/null/garbage -> 0."""
        try:
            return int(usage.get(key) or 0)
        except (TypeError, ValueError):
            return 0

    def _parse_envelope(self, stdout: str) -> CompletionResult:
        """Convert the CLI's JSON result envelope into a ``CompletionResult``.

        Raises:
            ProviderError: If *stdout* is not JSON, is JSON but not an object,
                or the envelope has a truthy ``is_error`` field.
        """
        try:
            data = json.loads(stdout)
        except json.JSONDecodeError as e:
            raise ProviderError(
                f"claude -p returned non-JSON output: {stdout[:500]!r}",
                provider=self.provider_name,
            ) from e

        # Valid JSON that is not an object (list, string, null, ...) would
        # otherwise fail below with an AttributeError on ``.get``.
        if not isinstance(data, dict):
            raise ProviderError(
                "claude -p returned unexpected JSON (expected an object): "
                f"{stdout[:500]!r}",
                provider=self.provider_name,
            )

        if data.get("is_error"):
            detail = str(data.get("result", ""))[:2000]
            raise ProviderError(
                f"claude -p reported an error: {detail}",
                provider=self.provider_name,
            )

        # ``usage`` / ``modelUsage`` may be absent or null; treat both as empty.
        usage_raw = data.get("usage")
        if not isinstance(usage_raw, dict):
            usage_raw = {}
        prompt_tokens = self._token_count(usage_raw, "input_tokens")
        completion_tokens = self._token_count(usage_raw, "output_tokens")

        model_usage = data.get("modelUsage")
        if not isinstance(model_usage, dict):
            model_usage = {}
        # First key of ``modelUsage`` if any, else the configured model name.
        served_model = next(iter(model_usage), self.config.model)

        # A null ``result`` must become empty content, not the string "None".
        result_text = data.get("result")
        content = "" if result_text is None else str(result_text)

        return CompletionResult(
            content=content,
            model=served_model,
            usage={
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
            finish_reason=str(data.get("stop_reason") or "stop"),
            raw_response=data,
        )

    def complete(
        self,
        messages: Sequence[ChatMessage],
        *,
        temperature: float | None = None,
        max_tokens: int | None = None,
        stop_sequences: list[str] | None = None,
        **kwargs: Any,
    ) -> CompletionResult:
        """Run one ``claude -p`` invocation and return its completion.

        The user prompt is sent on stdin; the system prompt (if any) is passed
        with ``--system-prompt``. ``temperature``, ``max_tokens``,
        ``stop_sequences`` and ``**kwargs`` are accepted for interface
        compatibility and not forwarded to the CLI.

        Raises:
            ProviderError: If the CLI is missing, cannot be started, times
                out, exits non-zero, or returns a malformed/error envelope.
        """
        cli = self._resolve_cli()
        system_prompt, user_prompt = self._split_messages(messages)

        cmd = [
            cli,
            "-p",
            "--model", self.config.model,
            "--output-format", "json",
            "--tools", "",
            "--setting-sources", "",
            "--strict-mcp-config",
        ]
        if system_prompt:
            cmd.extend(["--system-prompt", system_prompt])

        env = {k: v for k, v in os.environ.items() if k not in _STRIPPED_ENV_VARS}

        try:
            proc = subprocess.run(
                cmd,
                input=user_prompt,
                capture_output=True,
                text=True,
                encoding="utf-8",
                # Undecodable bytes in CLI output must not crash the call.
                errors="replace",
                timeout=self.config.timeout,
                cwd=self._workdir,
                env=env,
                check=False,
            )
        except subprocess.TimeoutExpired as e:
            raise ProviderError(
                f"claude -p timed out after {self.config.timeout}s",
                provider=self.provider_name,
            ) from e
        except OSError as e:
            # Binary vanished, is not executable, cwd removed, argv too long...
            raise ProviderError(
                f"claude -p could not be started: {e}",
                provider=self.provider_name,
            ) from e

        if proc.returncode != 0:
            detail = (proc.stderr or proc.stdout or "").strip()[-2000:]
            raise ProviderError(
                f"claude -p exited with code {proc.returncode}: {detail}",
                provider=self.provider_name,
            )

        return self._parse_envelope(proc.stdout)

    def stream(
        self,
        messages: Sequence[ChatMessage],
        *,
        temperature: float | None = None,
        max_tokens: int | None = None,
        stop_sequences: list[str] | None = None,
        **kwargs: Any,
    ) -> Iterator[StreamChunk]:
        """Pseudo-stream: one chunk with the full completion, then a final chunk.

        Delegates to :meth:`complete`, so nothing is yielded until the whole
        CLI call has finished; the final chunk carries the usage dict.
        """
        result = self.complete(
            messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stop_sequences=stop_sequences,
            **kwargs,
        )
        yield StreamChunk(text=result.content)
        yield StreamChunk(text="", is_final=True, usage=result.usage)

    def count_tokens(self, messages: Sequence[ChatMessage]) -> int:
        """Character-based approximation (1 token ~ 4 characters)."""
        return sum(len(m.content) for m in messages) // 4


__all__ = ["ClaudeCodeProvider"]
# Default conversation used by most tests: one system + one user message.
MESSAGES = [
    ChatMessage(role=MessageRole.SYSTEM, content="You are a test agent."),
    ChatMessage(role=MessageRole.USER, content="Do the thing."),
]


def _cli_envelope(
    result: str = '{"actions": []}',
    input_tokens: int = 100,
    output_tokens: int = 20,
    is_error: bool = False,
) -> str:
    """Return the JSON text these tests feed back as fake ``claude -p`` stdout."""
    return json.dumps({
        "type": "result",
        "subtype": "success",
        "is_error": is_error,
        "result": result,
        "usage": {"input_tokens": input_tokens, "output_tokens": output_tokens},
        "modelUsage": {"claude-fable-5": {"inputTokens": input_tokens}},
        "stop_reason": None,
    })


def _mock_proc(stdout: str, returncode: int = 0, stderr: str = "") -> MagicMock:
    """Return a stand-in for the object ``subprocess.run`` returns."""
    proc = MagicMock()
    proc.returncode = returncode
    proc.stdout = stdout
    proc.stderr = stderr
    return proc


def _flag_value(cmd: list[str], flag: str) -> str:
    """Return the argv element that immediately follows *flag* in *cmd*."""
    return cmd[cmd.index(flag) + 1]


def _run_complete(
    provider: ClaudeCodeProvider,
    proc: MagicMock | None = None,
    side_effect: Any = None,
    messages: list[ChatMessage] | None = None,
) -> tuple[Any, MagicMock]:
    """Call provider.complete with mocked CLI resolution + subprocess.

    Args:
        provider: Provider under test.
        proc: Fake completed process; defaults to a successful envelope.
        side_effect: If given, ``subprocess.run`` raises/returns this instead.
        messages: Conversation to send; defaults to ``MESSAGES``.

    Returns:
        ``(completion_result, subprocess_run_mock)``.
    """
    with patch(
        "rle.providers.claude_code.shutil.which", return_value="claude",
    ), patch("rle.providers.claude_code.subprocess.run") as mock_run:
        if side_effect is not None:
            mock_run.side_effect = side_effect
        else:
            mock_run.return_value = proc or _mock_proc(_cli_envelope())
        result = provider.complete(MESSAGES if messages is None else messages)
    return result, mock_run


class TestComplete:
    """Behaviour of ``ClaudeCodeProvider.complete`` against a mocked CLI."""

    def test_parses_result_and_usage(self) -> None:
        provider = ClaudeCodeProvider()
        result, _ = _run_complete(
            provider, _mock_proc(_cli_envelope(result="hello", input_tokens=50, output_tokens=7)),
        )
        assert result.content == "hello"
        assert result.model == "claude-fable-5"
        assert result.usage["prompt_tokens"] == 50
        assert result.usage["completion_tokens"] == 7
        assert result.usage["total_tokens"] == 57
        assert result.finish_reason == "stop"

    def test_cli_invocation_shape(self) -> None:
        provider = ClaudeCodeProvider(model="claude-fable-5")
        _, mock_run = _run_complete(provider)

        cmd = mock_run.call_args[0][0]
        kwargs = mock_run.call_args[1]
        # Check each value sits directly after its flag, not merely somewhere
        # in argv, so a mis-paired flag/value cannot slip through.
        assert "-p" in cmd
        assert _flag_value(cmd, "--model") == "claude-fable-5"
        assert _flag_value(cmd, "--output-format") == "json"
        assert _flag_value(cmd, "--tools") == ""
        assert _flag_value(cmd, "--setting-sources") == ""
        assert "--strict-mcp-config" in cmd
        assert _flag_value(cmd, "--system-prompt") == "You are a test agent."
        assert kwargs["input"] == "Do the thing."

    def test_env_strips_session_and_billing_vars(
        self, monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        monkeypatch.setenv("CLAUDECODE", "1")
        monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-x")
        monkeypatch.setenv("ANTHROPIC_AUTH_TOKEN", "tok")
        monkeypatch.setenv("UNRELATED_VAR", "keep-me")

        provider = ClaudeCodeProvider()
        _, mock_run = _run_complete(provider)

        env = mock_run.call_args[1]["env"]
        assert "CLAUDECODE" not in env
        assert "ANTHROPIC_API_KEY" not in env
        assert "ANTHROPIC_AUTH_TOKEN" not in env
        assert env["UNRELATED_VAR"] == "keep-me"

    def test_nonzero_exit_raises(self) -> None:
        provider = ClaudeCodeProvider()
        with pytest.raises(ProviderError, match="exited with code 1"):
            _run_complete(provider, _mock_proc("", returncode=1, stderr="boom"))

    def test_is_error_raises(self) -> None:
        provider = ClaudeCodeProvider()
        with pytest.raises(ProviderError, match="reported an error"):
            _run_complete(provider, _mock_proc(_cli_envelope(is_error=True)))

    def test_non_json_output_raises(self) -> None:
        provider = ClaudeCodeProvider()
        with pytest.raises(ProviderError, match="non-JSON"):
            _run_complete(provider, _mock_proc("not json at all"))

    def test_timeout_raises(self) -> None:
        provider = ClaudeCodeProvider()
        with pytest.raises(ProviderError, match="timed out"):
            _run_complete(
                provider,
                side_effect=subprocess.TimeoutExpired(cmd="claude", timeout=120),
            )

    def test_missing_cli_raises(self) -> None:
        provider = ClaudeCodeProvider()
        with patch("rle.providers.claude_code.shutil.which", return_value=None):
            with pytest.raises(ProviderError, match="not found on PATH"):
                provider.complete(MESSAGES)

    def test_assistant_messages_ignored(self) -> None:
        provider = ClaudeCodeProvider()
        messages = [
            *MESSAGES,
            ChatMessage(role=MessageRole.ASSISTANT, content=""),
        ]
        _, mock_run = _run_complete(provider, messages=messages)
        # The empty assistant prefill must not leak into the stdin prompt.
        assert mock_run.call_args[1]["input"] == "Do the thing."


class TestStreamAndTokens:
    """Pseudo-streaming, the token heuristic and the provider name."""

    def test_stream_yields_content_then_final(self) -> None:
        provider = ClaudeCodeProvider()
        with patch(
            "rle.providers.claude_code.shutil.which", return_value="claude",
        ), patch("rle.providers.claude_code.subprocess.run") as mock_run:
            mock_run.return_value = _mock_proc(_cli_envelope(result="streamed"))
            chunks = list(provider.stream(MESSAGES))
        assert len(chunks) == 2
        assert chunks[0].text == "streamed"
        assert chunks[-1].text == ""
        assert chunks[-1].is_final is True
        assert chunks[-1].usage is not None

    def test_count_tokens_heuristic(self) -> None:
        provider = ClaudeCodeProvider()
        messages = [ChatMessage(role=MessageRole.USER, content="a" * 400)]
        assert provider.count_tokens(messages) == 100

    def test_provider_name(self) -> None:
        assert ClaudeCodeProvider().provider_name == "claudecode"