diff --git a/README.md b/README.md index d2b82507c..2cd46fbc9 100644 --- a/README.md +++ b/README.md @@ -127,6 +127,20 @@ ART's functionality is divided into a **client** and a **server**. The OpenAI-co This training loop runs until a specified number of inference and training iterations have completed. +## Verifiers Interoperability + +ART includes lightweight helpers for moving rollout data between ART and Verifiers-style tasks without adding Verifiers as a required dependency: + +```python +import art + +vf_state = art.trajectory_to_verifiers_state(trajectory) +trajectory = art.verifiers_state_to_trajectory(vf_state) +trajectory_group = art.verifiers_states_to_trajectory_group([vf_state]) +``` + +By default, Verifiers assistant turns are imported as plain transcript messages. If you have the original response payload and want those assistant turns to remain trainable ART choices, pass a `choice_factory` to `verifiers_state_to_trajectory`. + ## 🧩 Supported Models ART should work with most vLLM/HuggingFace-transformers compatible causal language models, or at least the ones supported by [Unsloth](https://docs.unsloth.ai/get-started/all-our-models). Gemma 3 does not appear to be supported for the time being. If any other model isn't working for you, please let us know on [Discord](https://discord.gg/zbBHRUpwf4) or open an issue on [GitHub](https://github.com/openpipe/art/issues)! diff --git a/src/art/__init__.py b/src/art/__init__.py index 6cdc18667..2bd1e27cb 100644 --- a/src/art/__init__.py +++ b/src/art/__init__.py @@ -73,6 +73,11 @@ TrainSFTConfig, ) from .utils import retry +from .verifiers import ( + trajectory_to_verifiers_state, + verifiers_state_to_trajectory, + verifiers_states_to_trajectory_group, +) from .yield_trajectory import capture_yielded_trajectory, yield_trajectory __all__ = [ @@ -98,6 +103,9 @@ "TrainResult", "Trajectory", "TrajectoryGroup", + "trajectory_to_verifiers_state", + "verifiers_state_to_trajectory", + "verifiers_states_to_trajectory_group", "capture_yielded_trajectory", "yield_trajectory", ] diff --git a/src/art/verifiers.py b/src/art/verifiers.py new file mode 100644 index 000000000..28ad149c2 --- /dev/null +++ b/src/art/verifiers.py @@ -0,0 +1,281 @@ +"""Interop helpers for moving rollout data between ART and Verifiers. + +The helpers keep Verifiers optional: callers pass and receive plain mapping +objects that match the public Verifiers state/trajectory-step shape. +""" + +from collections.abc import Callable, Iterable, Mapping +from copy import deepcopy +from typing import Any, cast + +from openai.types.chat.chat_completion import Choice + +from .trajectories import Trajectory, TrajectoryGroup +from .types import Message, MessageOrChoice, Messages, MessagesAndChoices + +VerifiersState = dict[str, Any] +VerifiersStep = dict[str, Any] +ChoiceFactory = Callable[[dict[str, Any], Mapping[str, Any]], MessageOrChoice] + + +def trajectory_to_verifiers_state( + trajectory: Trajectory, + *, + task: Mapping[str, Any] | None = None, + example_id: str | int | None = None, + answer: Any | None = None, +) -> VerifiersState: + """Convert one ART trajectory into a Verifiers-compatible rollout state. + + ART stores a single flat ``messages_and_choices`` transcript. Verifiers + stores rollout turns as steps containing the prompt that produced a model + response plus the completion for that turn. This function creates one + Verifiers step for each ART ``Choice`` and keeps the final ART reward and + metrics on the state and final step. + """ + + task_data = dict(task or {}) + messages = _messages_and_choices_to_messages(trajectory.messages_and_choices) + steps = _messages_and_choices_to_steps(trajectory.messages_and_choices) + prompt = steps[0]["prompt"] if steps else messages + completion = messages[len(prompt) :] if len(prompt) <= len(messages) else [] + + if steps: + steps[-1]["reward"] = float(trajectory.reward) + steps[-1]["extras"]["art_metrics"] = dict(trajectory.metrics) + steps[-1]["extras"]["art_metadata"] = dict(trajectory.metadata) + is_truncated = any(bool(step.get("is_truncated")) for step in steps) + + state: VerifiersState = { + "task": task_data, + "prompt": deepcopy(prompt), + "completion": deepcopy(completion), + "trajectory": steps, + "reward": float(trajectory.reward), + "metrics": dict(trajectory.metrics), + "is_completed": True, + "is_truncated": is_truncated, + "stop_condition": "art_trajectory_imported", + "error": None, + } + if trajectory.metadata: + state["metadata"] = dict(trajectory.metadata) + if trajectory.logs: + state["logs"] = list(trajectory.logs) + if example_id is not None: + state["example_id"] = example_id + if answer is not None: + state["answer"] = answer + if "info" in task_data: + state["info"] = deepcopy(task_data["info"]) + return state + + +def verifiers_state_to_trajectory( + state: Mapping[str, Any], + *, + choice_factory: ChoiceFactory | None = None, +) -> Trajectory: + """Convert a Verifiers rollout state into an ART trajectory. + + Verifiers states usually contain serializable assistant messages rather + than OpenAI ``Choice`` objects with logprobs. By default those assistant + turns remain plain, non-trainable messages. Pass ``choice_factory`` when a + caller has the original response payload and wants selected assistant + messages to become trainable ART choices. + """ + + messages_and_choices = _state_messages_and_choices( + state, choice_factory=choice_factory + ) + return Trajectory( + messages_and_choices=messages_and_choices, + reward=float(state.get("reward") or 0.0), + metrics=dict(cast(Mapping[str, Any], state.get("metrics") or {})), + metadata=_state_metadata(state), + logs=list(cast(Iterable[str], state.get("logs") or [])), + ) + + +def verifiers_states_to_trajectory_group( + states: Iterable[Mapping[str, Any]], + *, + choice_factory: ChoiceFactory | None = None, +) -> TrajectoryGroup: + """Convert Verifiers rollout states into an ART trajectory group.""" + + return TrajectoryGroup( + [ + verifiers_state_to_trajectory(state, choice_factory=choice_factory) + for state in states + ] + ) + + +def _messages_and_choices_to_steps( + messages_and_choices: MessagesAndChoices, +) -> list[VerifiersStep]: + steps: list[VerifiersStep] = [] + transcript: Messages = [] + for item in messages_and_choices: + if isinstance(item, Choice): + completion = [_choice_to_message(item)] + steps.append( + { + "prompt": deepcopy(transcript), + "completion": deepcopy(completion), + "response": _choice_response_payload(item), + "tokens": None, + "reward": None, + "advantage": None, + "is_truncated": item.finish_reason == "length", + "extras": {}, + } + ) + transcript.extend(completion) + else: + transcript.append(_normalize_message(item)) + return steps + + +def _messages_and_choices_to_messages( + messages_and_choices: MessagesAndChoices, +) -> Messages: + messages: Messages = [] + for item in messages_and_choices: + if isinstance(item, Choice): + messages.append(_choice_to_message(item)) + else: + messages.append(_normalize_message(item)) + return messages + + +def _state_messages_and_choices( + state: Mapping[str, Any], + *, + choice_factory: ChoiceFactory | None, +) -> MessagesAndChoices: + steps = state.get("trajectory") or [] + if isinstance(steps, list) and steps: + return _append_state_completion_tail( + _steps_to_messages_and_choices(steps, choice_factory=choice_factory), + state, + ) + return _append_completion_messages( + [_normalize_message(message) for message in state.get("prompt") or []], + state.get("completion") or [], + {}, + choice_factory=choice_factory, + ) + + +def _steps_to_messages_and_choices( + steps: list[Any], + *, + choice_factory: ChoiceFactory | None, +) -> MessagesAndChoices: + transcript: MessagesAndChoices = [] + for raw_step in steps: + if not isinstance(raw_step, Mapping): + continue + prompt = [ + _normalize_message(message) for message in raw_step.get("prompt") or [] + ] + if not transcript: + transcript.extend(prompt) + elif _transcript_has_prefix(transcript, prompt): + transcript.extend(prompt[len(transcript) :]) + else: + transcript.extend(prompt) + transcript = _append_completion_messages( + transcript, + raw_step.get("completion") or [], + raw_step, + choice_factory=choice_factory, + ) + return transcript + + +def _append_state_completion_tail( + transcript: MessagesAndChoices, + state: Mapping[str, Any], +) -> MessagesAndChoices: + prompt = [_normalize_message(message) for message in state.get("prompt") or []] + completion = [ + _normalize_message(message) for message in state.get("completion") or [] + ] + state_transcript = prompt + completion + if not state_transcript: + return transcript + if _transcript_has_prefix(transcript, state_transcript): + transcript.extend(state_transcript[len(transcript) :]) + return transcript + + +def _append_completion_messages( + transcript: MessagesAndChoices, + completion: Any, + step: Mapping[str, Any], + *, + choice_factory: ChoiceFactory | None, +) -> MessagesAndChoices: + for message in completion: + normalized = _normalize_message(message) + if choice_factory is not None and normalized.get("role") == "assistant": + transcript.append(choice_factory(normalized, step)) + else: + transcript.append(normalized) + return transcript + + +def _transcript_has_prefix( + transcript: MessagesAndChoices, + prompt: Messages, +) -> bool: + if len(prompt) < len(transcript): + return False + prefix = [_message_or_choice_to_dict(item) for item in transcript] + return prefix == [dict(message) for message in prompt[: len(transcript)]] + + +def _choice_to_message(choice: Choice) -> Message: + message = choice.message.model_dump(mode="json", exclude_none=True) + message["role"] = "assistant" + if message.get("content") is None: + message["content"] = "" + return cast(Message, message) + + +def _choice_response_payload(choice: Choice) -> dict[str, Any]: + return { + "choices": [choice.model_dump(mode="json", exclude_none=True)], + } + + +def _normalize_message(message: Any) -> Message: + if hasattr(message, "model_dump"): + data = message.model_dump(mode="json", exclude_none=True) + else: + data = dict(message) + if data.get("content") is None: + data["content"] = "" + return cast(Message, data) + + +def _message_or_choice_to_dict(item: MessageOrChoice) -> dict[str, Any]: + if isinstance(item, Choice): + return dict(_choice_to_message(item)) + return dict(_normalize_message(item)) + + +def _state_metadata( + state: Mapping[str, Any], +) -> dict[str, str | int | float | bool | None]: + metadata = dict(cast(Mapping[str, Any], state.get("metadata") or {})) + for key in ("example_id", "stop_condition"): + if key not in state: + continue + value = state.get(key) + if isinstance(value, str | int | float | bool) or value is None: + metadata.setdefault(key, value) + return metadata diff --git a/tests/unit/test_verifiers_interop.py b/tests/unit/test_verifiers_interop.py new file mode 100644 index 000000000..e4aa76e41 --- /dev/null +++ b/tests/unit/test_verifiers_interop.py @@ -0,0 +1,181 @@ +from openai.types.chat import ChatCompletionMessage +from openai.types.chat.chat_completion import Choice + +import art +from art.verifiers import ( + trajectory_to_verifiers_state, + verifiers_state_to_trajectory, + verifiers_states_to_trajectory_group, +) + + +def assistant_choice(content: str, *, finish_reason: str = "stop") -> Choice: + return Choice( + finish_reason=finish_reason, + index=0, + logprobs=None, + message=ChatCompletionMessage( + role="assistant", + content=content, + refusal=None, + ), + ) + + +def test_trajectory_to_verifiers_state_splits_choice_turns(): + trajectory = art.Trajectory( + messages_and_choices=[ + {"role": "system", "content": "Answer briefly."}, + {"role": "user", "content": "First?"}, + assistant_choice("one"), + {"role": "user", "content": "Second?"}, + assistant_choice("two", finish_reason="length"), + ], + reward=0.75, + metrics={"exact": 1.0}, + metadata={"scenario": "smoke"}, + logs=["converted"], + ) + + state = trajectory_to_verifiers_state( + trajectory, + task={"prompt": [{"role": "user", "content": "First?"}], "info": {"id": "t1"}}, + example_id="ex-1", + answer="two", + ) + + assert state["reward"] == 0.75 + assert state["metrics"] == {"exact": 1.0} + assert state["metadata"] == {"scenario": "smoke"} + assert state["logs"] == ["converted"] + assert state["example_id"] == "ex-1" + assert state["answer"] == "two" + assert state["info"] == {"id": "t1"} + + steps = state["trajectory"] + assert len(steps) == 2 + assert steps[0]["prompt"] == [ + {"role": "system", "content": "Answer briefly."}, + {"role": "user", "content": "First?"}, + ] + assert steps[0]["completion"] == [{"role": "assistant", "content": "one"}] + assert steps[0]["reward"] is None + assert steps[1]["prompt"] == [ + {"role": "system", "content": "Answer briefly."}, + {"role": "user", "content": "First?"}, + {"role": "assistant", "content": "one"}, + {"role": "user", "content": "Second?"}, + ] + assert steps[1]["completion"] == [{"role": "assistant", "content": "two"}] + assert steps[1]["is_truncated"] is True + assert steps[1]["reward"] == 0.75 + assert steps[1]["extras"]["art_metrics"] == {"exact": 1.0} + assert state["is_truncated"] is True + + +def test_verifiers_state_to_trajectory_reconstructs_transcript(): + state = { + "prompt": [{"role": "user", "content": "First?"}], + "completion": [{"role": "assistant", "content": "one"}], + "trajectory": [ + { + "prompt": [{"role": "user", "content": "First?"}], + "completion": [{"role": "assistant", "content": "one"}], + }, + { + "prompt": [ + {"role": "user", "content": "First?"}, + {"role": "assistant", "content": "one"}, + {"role": "user", "content": "Second?"}, + ], + "completion": [{"role": "assistant", "content": "two"}], + }, + ], + "reward": 1.0, + "metrics": {"answer": 1.0}, + "metadata": {"task_id": "t1"}, + "stop_condition": "done", + } + + trajectory = verifiers_state_to_trajectory(state) + + assert trajectory.reward == 1.0 + assert trajectory.metrics == {"answer": 1.0} + assert trajectory.metadata == {"task_id": "t1", "stop_condition": "done"} + assert trajectory.messages_and_choices == [ + {"role": "user", "content": "First?"}, + {"role": "assistant", "content": "one"}, + {"role": "user", "content": "Second?"}, + {"role": "assistant", "content": "two"}, + ] + + +def test_verifiers_state_to_trajectory_preserves_completion_tail_after_steps(): + state = { + "prompt": [{"role": "user", "content": "First?"}], + "completion": [ + {"role": "assistant", "content": "one"}, + {"role": "tool", "content": "tail result"}, + ], + "trajectory": [ + { + "prompt": [{"role": "user", "content": "First?"}], + "completion": [{"role": "assistant", "content": "one"}], + }, + ], + } + + trajectory = verifiers_state_to_trajectory(state) + + assert trajectory.messages_and_choices == [ + {"role": "user", "content": "First?"}, + {"role": "assistant", "content": "one"}, + {"role": "tool", "content": "tail result"}, + ] + + +def test_verifiers_state_to_trajectory_can_restore_trainable_choices(): + def choice_factory(message, step): + assert step["reward"] == 1.0 + return assistant_choice(message["content"]) + + trajectory = verifiers_state_to_trajectory( + { + "trajectory": [ + { + "prompt": [{"role": "user", "content": "Say hi"}], + "completion": [{"role": "assistant", "content": "hi"}], + "reward": 1.0, + } + ], + "reward": 1.0, + }, + choice_factory=choice_factory, + ) + + assert trajectory.messages_and_choices[0] == {"role": "user", "content": "Say hi"} + assert isinstance(trajectory.messages_and_choices[1], Choice) + assert trajectory.messages() == [ + {"role": "user", "content": "Say hi"}, + {"role": "assistant", "content": "hi"}, + ] + + +def test_verifiers_states_to_trajectory_group(): + group = verifiers_states_to_trajectory_group( + [ + { + "prompt": [{"role": "user", "content": "A"}], + "completion": [{"role": "assistant", "content": "B"}], + "reward": 0.25, + }, + { + "prompt": [{"role": "user", "content": "C"}], + "completion": [{"role": "assistant", "content": "D"}], + "reward": 0.5, + }, + ] + ) + + assert len(group.trajectories) == 2 + assert [trajectory.reward for trajectory in group.trajectories] == [0.25, 0.5]