Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scripts/run_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ async def main(args: argparse.Namespace) -> None:
help="(WIP) Run ablation study: full benchmark + 7 single-agent-removed runs",
)
parser.add_argument(
"--provider", choices=["anthropic", "openai", "local"],
"--provider", choices=["anthropic", "openai", "local", "claude-code"],
help="LLM provider (default: from config)",
)
parser.add_argument("--model", help="Model name (e.g. qwen/qwen3.5-9b)")
Expand Down
2 changes: 1 addition & 1 deletion scripts/run_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ async def main(args: argparse.Namespace) -> None:
parser.add_argument("scenario", nargs="?", help="Scenario name or number prefix")
parser.add_argument("--list", action="store_true", help="List available scenarios")
parser.add_argument(
"--provider", choices=["anthropic", "openai", "local"],
"--provider", choices=["anthropic", "openai", "local", "claude-code"],
help="LLM provider (default: from config)",
)
parser.add_argument("--model", help="Model name (e.g. unsloth/nvidia-nemotron-3-nano-4b)")
Expand Down
5 changes: 4 additions & 1 deletion src/rle/agents/base_role.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,10 @@ def _call_provider(
ChatMessage(role=MessageRole.SYSTEM, content=system_prompt),
ChatMessage(role=MessageRole.USER, content=user_prompt),
]
if self._no_think and self.provider.provider_name != "anthropic":
if self._no_think and self.provider.provider_name not in (
"anthropic",
"claudecode",
):
messages.append(
ChatMessage(role=MessageRole.ASSISTANT, content="</think>"),
)
Expand Down
3 changes: 3 additions & 0 deletions src/rle/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
)
from pydantic_settings import BaseSettings

from rle.providers.claude_code import ClaudeCodeProvider

_HELIX_PRESETS: dict[str, HelixConfig] = {
"default": HelixConfig.default(),
"research_heavy": HelixConfig.research_heavy(),
Expand All @@ -23,6 +25,7 @@
"anthropic": AnthropicProvider,
"openai": OpenAIProvider,
"local": LocalProvider,
"claude-code": ClaudeCodeProvider,
}


Expand Down
5 changes: 5 additions & 0 deletions src/rle/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""RLE-local LLM providers (beyond those shipped with felix-agent-sdk)."""

from rle.providers.claude_code import ClaudeCodeProvider

__all__ = ["ClaudeCodeProvider"]
195 changes: 195 additions & 0 deletions src/rle/providers/claude_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
"""Claude subscription provider: routes completions through ``claude -p``.

Bills the user's Claude subscription (Pro/Max/Team) instead of an API key
by shelling out to the Claude Code CLI in headless print mode. Tools, MCP
servers, and project settings are all disabled so a call is as close to a
raw completion as the CLI allows.
"""

from __future__ import annotations

import json
import logging
import os
import shutil
import subprocess
import tempfile
from collections.abc import Iterator, Sequence
from typing import Any

from felix_agent_sdk.providers.base import BaseProvider
from felix_agent_sdk.providers.errors import ProviderError
from felix_agent_sdk.providers.types import (
ChatMessage,
CompletionResult,
MessageRole,
ProviderConfig,
StreamChunk,
)

logger = logging.getLogger(__name__)

# Env vars that would either crash the subprocess (nested-session guard) or
# silently flip billing from the subscription login to an API account.
_STRIPPED_ENV_VARS = ("CLAUDECODE", "ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN")


class ClaudeCodeProvider(BaseProvider):
"""Provider that completes via the Claude Code CLI (``claude -p``).

Uses the machine's existing Claude Code subscription login, so usage
bills the Claude plan rather than the API. ``temperature``,
``max_tokens``, and ``stop_sequences`` are accepted but ignored — the
CLI does not expose sampling controls (Fable-class models reject them
anyway).
"""

def __init__(
self,
model: str = "claude-fable-5",
base_url: str | None = None,
**kwargs: Any,
) -> None:
config = ProviderConfig(model=model, api_key=None, base_url=base_url, **kwargs)
super().__init__(config)
self._cli_path: str | None = None
# Neutral cwd so the CLI never picks up a project's CLAUDE.md.
self._workdir = tempfile.mkdtemp(prefix="rle-claude-p-")

def _resolve_cli(self) -> str:
if self._cli_path is None:
cli = shutil.which("claude")
if cli is None:
raise ProviderError(
"Claude Code CLI not found on PATH. Install it and log in "
"with a Claude subscription to use --provider claude-code.",
provider=self.provider_name,
)
self._cli_path = cli
return self._cli_path

def _split_messages(self, messages: Sequence[ChatMessage]) -> tuple[str, str]:
"""Flatten messages into (system_prompt, user_prompt) for the CLI."""
system_parts: list[str] = []
user_parts: list[str] = []
for msg in messages:
if msg.role == MessageRole.SYSTEM:
system_parts.append(msg.content)
elif msg.role == MessageRole.USER:
user_parts.append(msg.content)
else:
logger.debug(
"ClaudeCodeProvider ignoring %s message (prefills unsupported)",
msg.role.value,
)
return "\n\n".join(system_parts), "\n\n".join(user_parts)

def complete(
self,
messages: Sequence[ChatMessage],
*,
temperature: float | None = None,
max_tokens: int | None = None,
stop_sequences: list[str] | None = None,
**kwargs: Any,
) -> CompletionResult:
cli = self._resolve_cli()
system_prompt, user_prompt = self._split_messages(messages)

cmd = [
cli,
"-p",
"--model", self.config.model,
"--output-format", "json",
"--tools", "",
"--setting-sources", "",
"--strict-mcp-config",
]
if system_prompt:
cmd.extend(["--system-prompt", system_prompt])

env = {k: v for k, v in os.environ.items() if k not in _STRIPPED_ENV_VARS}

try:
proc = subprocess.run(
cmd,
input=user_prompt,
capture_output=True,
text=True,
encoding="utf-8",
timeout=self.config.timeout,
cwd=self._workdir,
env=env,
)
except subprocess.TimeoutExpired as e:
raise ProviderError(
f"claude -p timed out after {self.config.timeout}s",
provider=self.provider_name,
) from e

if proc.returncode != 0:
detail = (proc.stderr or proc.stdout or "").strip()[-2000:]
raise ProviderError(
f"claude -p exited with code {proc.returncode}: {detail}",
provider=self.provider_name,
)

try:
data: dict[str, Any] = json.loads(proc.stdout)
except json.JSONDecodeError as e:
raise ProviderError(
f"claude -p returned non-JSON output: {proc.stdout[:500]!r}",
provider=self.provider_name,
) from e

if data.get("is_error"):
detail = str(data.get("result", ""))[:2000]
raise ProviderError(
f"claude -p reported an error: {detail}",
provider=self.provider_name,
)

usage_raw = data.get("usage", {})
prompt_tokens = int(usage_raw.get("input_tokens", 0))
completion_tokens = int(usage_raw.get("output_tokens", 0))
model_usage = data.get("modelUsage", {})
served_model = next(iter(model_usage), self.config.model)

return CompletionResult(
content=str(data.get("result", "")),
model=served_model,
usage={
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
finish_reason=str(data.get("stop_reason") or "stop"),
raw_response=data,
)

def stream(
self,
messages: Sequence[ChatMessage],
*,
temperature: float | None = None,
max_tokens: int | None = None,
stop_sequences: list[str] | None = None,
**kwargs: Any,
) -> Iterator[StreamChunk]:
"""Pseudo-stream: one chunk with the full completion, then a final chunk."""
result = self.complete(
messages,
temperature=temperature,
max_tokens=max_tokens,
stop_sequences=stop_sequences,
**kwargs,
)
yield StreamChunk(text=result.content)
yield StreamChunk(text="", is_final=True, usage=result.usage)

def count_tokens(self, messages: Sequence[ChatMessage]) -> int:
"""Character-based approximation (1 token ~ 4 characters)."""
return sum(len(m.content) for m in messages) // 4


__all__ = ["ClaudeCodeProvider"]
7 changes: 4 additions & 3 deletions tests/unit/test_base_role.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,11 @@ def test_prefill_appended_for_openai_provider(
assert messages[-1].role == MessageRole.ASSISTANT
assert messages[-1].content == "</think>"

def test_prefill_skipped_for_anthropic_provider(
self, mock_provider: MagicMock, helix: HelixGeometry,
@pytest.mark.parametrize("provider_name", ["anthropic", "claudecode"])
def test_prefill_skipped_for_prefill_rejecting_providers(
self, mock_provider: MagicMock, helix: HelixGeometry, provider_name: str,
) -> None:
mock_provider.provider_name = "anthropic"
mock_provider.provider_name = provider_name
agent = _DummyRoleAgent("d-01", mock_provider, helix, spawn_time=0.0)
agent.set_no_think(True)
agent._call_provider("sys", "user", 0.5, 100)
Expand Down
Loading
Loading