diff --git a/src/art/dev/get_model_config.py b/src/art/dev/get_model_config.py index 850008ae0..8715c8d1d 100644 --- a/src/art/dev/get_model_config.py +++ b/src/art/dev/get_model_config.py @@ -1,6 +1,11 @@ from ..megatron.model_support import default_target_modules_for_model from .engine import EngineArgs -from .model import InitArgs, InternalModelConfig, PeftArgs, TrainerArgs +from .model import ( + InitArgs, + InternalModelConfig, + PeftArgs, + TrainerArgs, +) from .validate import is_dedicated_mode diff --git a/src/art/megatron/service.py b/src/art/megatron/service.py index bd9f462f5..fad4c06f6 100644 --- a/src/art/megatron/service.py +++ b/src/art/megatron/service.py @@ -1,5 +1,6 @@ import asyncio from dataclasses import dataclass, field +import gc import importlib import os from pathlib import Path @@ -18,7 +19,6 @@ from ..local.checkpoints import get_last_checkpoint_dir from ..preprocessing.pack import DiskPackedTensors from ..preprocessing.tokenize import SFTBatch -from ..unsloth.train import gc_and_empty_cuda_cache from ..utils.convert_moe_lora import convert_checkpoint_if_needed from ..utils.get_model_step import get_step_from_dir from ..utils.lifecycle import ( @@ -57,6 +57,13 @@ safe_open = safetensors.safe_open +def gc_and_empty_cuda_cache(n: int = 3) -> None: + for _ in range(n): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + class _RuntimeRequestKwargs(TypedDict, total=False): headers: dict[str, str] diff --git a/src/art/model.py b/src/art/model.py index 902f337e0..f499fe1d3 100644 --- a/src/art/model.py +++ b/src/art/model.py @@ -384,12 +384,15 @@ def litellm_completion_params(self, step: int | None = None) -> dict: model_name = self.get_inference_name(step) if self.trainable: model_name = f"hosted_vllm/{model_name}" - return { + params = { "model": model_name, "base_url": self.inference_base_url, "api_key": self.inference_api_key, "temperature": 1, # Important for trainable models } + if extra_body := self._default_chat_completion_extra_body(): + params["extra_body"] = extra_body + return params # ------------------------------------------------------------------ # Inference name helpers diff --git a/src/art/preprocessing/response_masking.py b/src/art/preprocessing/response_masking.py new file mode 100644 index 000000000..8ee31efef --- /dev/null +++ b/src/art/preprocessing/response_masking.py @@ -0,0 +1,50 @@ +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + + +def token_ids_for_template_part( + tokenizer: PreTrainedTokenizerBase, + template_part: str, +) -> list[int]: + return list(tokenizer(template_part, add_special_tokens=False).input_ids) + + +def _find_subsequence( + values: list[int], + pattern: list[int], + *, + start: int = 0, +) -> int | None: + if not pattern: + return None + last_start = len(values) - len(pattern) + for index in range(start, last_start + 1): + if values[index : index + len(pattern)] == pattern: + return index + return None + + +def response_only_labels( + input_ids: list[int], + *, + instruction_ids: list[int], + response_ids: list[int], +) -> list[int]: + labels = [-100] * len(input_ids) + index = 0 + while index < len(input_ids): + response_start = _find_subsequence(input_ids, response_ids, start=index) + if response_start is None: + break + + trainable_start = response_start + len(response_ids) + next_instruction_start = _find_subsequence( + input_ids, + instruction_ids, + start=trainable_start, + ) + trainable_end = ( + len(input_ids) if next_instruction_start is None else next_instruction_start + ) + labels[trainable_start:trainable_end] = input_ids[trainable_start:trainable_end] + index = trainable_end + return labels diff --git a/src/art/preprocessing/tokenize.py b/src/art/preprocessing/tokenize.py index 5a55e2186..e3187d935 100644 --- a/src/art/preprocessing/tokenize.py +++ b/src/art/preprocessing/tokenize.py @@ -15,6 +15,11 @@ from ..trajectories import History, Trajectory, TrajectoryGroup, get_messages from ..types import MessagesAndChoices +from ..utils.chat_template import ( + default_chat_template_kwargs_for_tokenizer, + merge_chat_template_kwargs, +) +from .response_masking import response_only_labels, token_ids_for_template_part ChatTemplateTool = dict[Any, Any] | Callable[..., Any] ChatTemplateToolSchemaFormat = Literal["default", "vllm_openai"] @@ -24,14 +29,10 @@ def _chat_template_kwargs( tokenizer: PreTrainedTokenizerBase, chat_template_kwargs: dict[str, Any] | None, ) -> dict[str, Any]: - kwargs: dict[str, Any] = {} - if isinstance(tokenizer.chat_template, str): - if "enable_thinking" in tokenizer.chat_template: - kwargs["enable_thinking"] = False - if "preserve_thinking" in tokenizer.chat_template: - kwargs["preserve_thinking"] = True - kwargs.update(chat_template_kwargs or {}) - return kwargs + return merge_chat_template_kwargs( + default_chat_template_kwargs_for_tokenizer(tokenizer), + chat_template_kwargs, + ) def _normalize_tool_for_vllm_openai(tool: ChatTemplateTool) -> ChatTemplateTool: @@ -583,17 +584,8 @@ def tokenize_sft_batch( """ _validate_max_seq_length(max_seq_length) - import unsloth # noqa: F401 - Must be imported first to set UNSLOTH_IS_PRESENT env var - from unsloth_zoo.dataset_utils import train_on_responses_only - - train_on_responses_only_fn = train_on_responses_only( - trainer=None, - instruction_part=instruction_part, - response_part=response_part, - force_match=False, - tokenizer=tokenizer, - return_function=True, - ) + instruction_ids = token_ids_for_template_part(tokenizer, instruction_part) + response_ids = token_ids_for_template_part(tokenizer, response_part) # Tokenize all trajectories (no padding — each keeps its natural length) trajectory_tensors = [] num_tokens = 0 @@ -625,7 +617,11 @@ def tokenize_sft_batch( attention_mask = [1] * len(input_ids) - labels = train_on_responses_only_fn({"input_ids": [input_ids]})["labels"][0] + labels = response_only_labels( + input_ids, + instruction_ids=instruction_ids, + response_ids=response_ids, + ) trajectory_tensors.append( { diff --git a/src/art/tinker/server.py b/src/art/tinker/server.py index 9f9341895..9c4f21f9a 100644 --- a/src/art/tinker/server.py +++ b/src/art/tinker/server.py @@ -36,6 +36,7 @@ from art.tinker.prefix_cache import LRUTrieCache from art.tinker.renderers import get_renderer_name, is_qwen3_dot_family_model from art.types import Message, Tools +from art.utils.chat_template import default_chat_template_kwargs_for_tokenizer from mp_actors import close_proxy, move_to_child_process @@ -552,12 +553,7 @@ async def prompt_tokens( ) -> list[int]: normalized_messages = _normalize_qwen3_dot_messages(base_model, messages) tokenizer = self._get_renderer(base_model).tokenizer - chat_template_kwargs = {} - if isinstance(tokenizer.chat_template, str): - if "enable_thinking" in tokenizer.chat_template: - chat_template_kwargs["enable_thinking"] = False - if "preserve_thinking" in tokenizer.chat_template: - chat_template_kwargs["preserve_thinking"] = True + chat_template_kwargs = default_chat_template_kwargs_for_tokenizer(tokenizer) encoding = tokenizer.apply_chat_template( cast(Any, normalized_messages), tools=cast(Any, tools), diff --git a/src/art/utils/chat_template.py b/src/art/utils/chat_template.py new file mode 100644 index 000000000..89043cc19 --- /dev/null +++ b/src/art/utils/chat_template.py @@ -0,0 +1,32 @@ +from typing import Any + +THINKING_CHAT_TEMPLATE_KWARGS: dict[str, Any] = { + "enable_thinking": False, + "preserve_thinking": True, +} + + +def default_chat_template_kwargs_for_template( + chat_template: object, +) -> dict[str, Any]: + kwargs: dict[str, Any] = {} + if not isinstance(chat_template, str): + return kwargs + if "enable_thinking" in chat_template: + kwargs["enable_thinking"] = False + if "preserve_thinking" in chat_template: + kwargs["preserve_thinking"] = True + return kwargs + + +def default_chat_template_kwargs_for_tokenizer(tokenizer: object) -> dict[str, Any]: + return default_chat_template_kwargs_for_template( + getattr(tokenizer, "chat_template", None) + ) + + +def merge_chat_template_kwargs( + defaults: dict[str, Any] | None, + overrides: dict[str, Any] | None, +) -> dict[str, Any]: + return {**(defaults or {}), **(overrides or {})} diff --git a/tests/unit/test_preprocessing_tokenize.py b/tests/unit/test_preprocessing_tokenize.py index 4be51b026..587207b30 100644 --- a/tests/unit/test_preprocessing_tokenize.py +++ b/tests/unit/test_preprocessing_tokenize.py @@ -6,7 +6,7 @@ import pytest from transformers.tokenization_utils_base import BatchEncoding -from art.preprocessing.tokenize import tokenize_trajectory +from art.preprocessing.tokenize import tokenize_sft_batch, tokenize_trajectory from art.trajectories import History, Trajectory from art.types import MessagesAndChoices @@ -66,6 +66,13 @@ def encode(self, text: str, add_special_tokens: bool = False) -> list[int]: del add_special_tokens return [ord(char) for char in text] + def __call__(self, text: str, add_special_tokens: bool = False): + return type( + "TokenizedText", + (), + {"input_ids": self.encode(text, add_special_tokens=add_special_tokens)}, + )() + def decode(self, token_ids): if isinstance(token_ids, int): return chr(token_ids) @@ -199,6 +206,30 @@ def test_tokenize_trajectory_passes_chat_template_kwargs() -> None: ) +def test_tokenize_sft_batch_masks_response_tokens_without_unsloth_import() -> None: + tokenizer = _FakeTokenizer() + messages = cast( + MessagesAndChoices, + [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "OK"}, + ], + ) + + batch = tokenize_sft_batch( + trajectory_batch=[Trajectory(messages_and_choices=messages, reward=1.0)], + learning_rate=1e-5, + tokenizer=tokenizer, # type: ignore[arg-type] + instruction_part="", + response_part="", + ) + + labels = batch.trajectory_tensors[0]["labels"][0].tolist() + trainable_token_ids = [token_id for token_id in labels if token_id != -100] + assert tokenizer.decode(trainable_token_ids) == "OK" + assert batch.num_trainable_tokens == 2 + + def test_tokenize_trajectory_does_not_continue_real_completion_with_thinking() -> None: tokenizer = _ContinueFinalMessageRejectingTokenizer() choice = Choice.model_validate(