Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/art/dev/get_model_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from ..megatron.model_support import default_target_modules_for_model
from .engine import EngineArgs
from .model import InitArgs, InternalModelConfig, PeftArgs, TrainerArgs
from .model import (
InitArgs,
InternalModelConfig,
PeftArgs,
TrainerArgs,
)
from .validate import is_dedicated_mode


Expand Down
9 changes: 8 additions & 1 deletion src/art/megatron/service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from dataclasses import dataclass, field
import gc
import importlib
import os
from pathlib import Path
Expand All @@ -18,7 +19,6 @@
from ..local.checkpoints import get_last_checkpoint_dir
from ..preprocessing.pack import DiskPackedTensors
from ..preprocessing.tokenize import SFTBatch
from ..unsloth.train import gc_and_empty_cuda_cache
from ..utils.convert_moe_lora import convert_checkpoint_if_needed
from ..utils.get_model_step import get_step_from_dir
from ..utils.lifecycle import (
Expand Down Expand Up @@ -57,6 +57,13 @@
safe_open = safetensors.safe_open


def gc_and_empty_cuda_cache(n: int = 3) -> None:
    """Run garbage collection and release cached CUDA memory ``n`` times.

    Each pass calls ``gc.collect()`` and then, only when
    ``torch.cuda.is_available()`` is true, ``torch.cuda.empty_cache()``.

    Args:
        n: Number of collect/empty passes. Zero or a negative value runs
            no passes at all.
    """
    passes_left = n
    while passes_left > 0:
        gc.collect()
        # Availability is re-checked on every pass, as a CPU-only host must
        # never reach the CUDA call.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        passes_left -= 1


class _RuntimeRequestKwargs(TypedDict, total=False):
headers: dict[str, str]

Expand Down
5 changes: 4 additions & 1 deletion src/art/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,12 +384,15 @@ def litellm_completion_params(self, step: int | None = None) -> dict:
model_name = self.get_inference_name(step)
if self.trainable:
model_name = f"hosted_vllm/{model_name}"
return {
params = {
"model": model_name,
"base_url": self.inference_base_url,
"api_key": self.inference_api_key,
"temperature": 1, # Important for trainable models
}
if extra_body := self._default_chat_completion_extra_body():
params["extra_body"] = extra_body
return params

# ------------------------------------------------------------------
# Inference name helpers
Expand Down
50 changes: 50 additions & 0 deletions src/art/preprocessing/response_masking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from transformers.tokenization_utils_base import PreTrainedTokenizerBase


def token_ids_for_template_part(
    tokenizer: PreTrainedTokenizerBase,
    template_part: str,
) -> list[int]:
    """Tokenize a chat-template fragment into a plain list of token ids.

    Args:
        tokenizer: Callable tokenizer; it is invoked with
            ``add_special_tokens=False`` and its result must expose
            ``input_ids``.
        template_part: Template text to tokenize (for example the marker
            that introduces an instruction or a response).

    Returns:
        A new ``list`` holding the ids from the encoding's ``input_ids``.
    """
    encoded = tokenizer(template_part, add_special_tokens=False)
    return [*encoded.input_ids]


def _find_subsequence(
    values: list[int],
    pattern: list[int],
    *,
    start: int = 0,
) -> int | None:
    """Return the first index at or after ``start`` where ``pattern`` occurs.

    Args:
        values: Sequence of ids to search.
        pattern: Contiguous run of ids to look for.
        start: Lowest index at which a match may begin. Negative values are
            treated as ``0``.

    Returns:
        The non-negative index of the first match, or ``None`` when
        ``pattern`` is empty or does not occur at or after ``start``.
    """
    if not pattern:
        return None
    width = len(pattern)
    # Clamp the lower bound: a negative index would make the slice below wrap
    # around from the end of ``values`` and could report a negative "match".
    first = max(start, 0)
    for index in range(first, len(values) - width + 1):
        if values[index : index + width] == pattern:
            return index
    return None


def response_only_labels(
    input_ids: list[int],
    *,
    instruction_ids: list[int],
    response_ids: list[int],
) -> list[int]:
    """Build labels that keep only the tokens following each response marker.

    Every position starts out as ``-100``. For each occurrence of
    ``response_ids`` the tokens after the marker are copied from
    ``input_ids`` into the labels, up to (not including) the next occurrence
    of ``instruction_ids`` or, when there is none, the end of the sequence.
    The marker tokens themselves stay at ``-100``.

    Args:
        input_ids: Token ids of the full sequence.
        instruction_ids: Token ids of the marker that ends a trainable span.
        response_ids: Token ids of the marker that starts a trainable span.

    Returns:
        A new list with the same length as ``input_ids``.
    """
    total = len(input_ids)
    marker_width = len(response_ids)
    labels = [-100] * total
    cursor = 0
    while cursor < total:
        marker_at = _find_subsequence(input_ids, response_ids, start=cursor)
        if marker_at is None:
            # No further response marker: everything left stays masked.
            return labels

        span_start = marker_at + marker_width
        span_end = _find_subsequence(input_ids, instruction_ids, start=span_start)
        if span_end is None:
            span_end = total
        labels[span_start:span_end] = input_ids[span_start:span_end]
        # Resume scanning where this trainable span stopped.
        cursor = span_end
    return labels
36 changes: 16 additions & 20 deletions src/art/preprocessing/tokenize.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@

from ..trajectories import History, Trajectory, TrajectoryGroup, get_messages
from ..types import MessagesAndChoices
from ..utils.chat_template import (
default_chat_template_kwargs_for_tokenizer,
merge_chat_template_kwargs,
)
from .response_masking import response_only_labels, token_ids_for_template_part

ChatTemplateTool = dict[Any, Any] | Callable[..., Any]
ChatTemplateToolSchemaFormat = Literal["default", "vllm_openai"]
Expand All @@ -24,14 +29,10 @@ def _chat_template_kwargs(
tokenizer: PreTrainedTokenizerBase,
chat_template_kwargs: dict[str, Any] | None,
) -> dict[str, Any]:
kwargs: dict[str, Any] = {}
if isinstance(tokenizer.chat_template, str):
if "enable_thinking" in tokenizer.chat_template:
kwargs["enable_thinking"] = False
if "preserve_thinking" in tokenizer.chat_template:
kwargs["preserve_thinking"] = True
kwargs.update(chat_template_kwargs or {})
return kwargs
return merge_chat_template_kwargs(
default_chat_template_kwargs_for_tokenizer(tokenizer),
chat_template_kwargs,
)


def _normalize_tool_for_vllm_openai(tool: ChatTemplateTool) -> ChatTemplateTool:
Expand Down Expand Up @@ -583,17 +584,8 @@ def tokenize_sft_batch(
"""
_validate_max_seq_length(max_seq_length)

import unsloth # noqa: F401 - Must be imported first to set UNSLOTH_IS_PRESENT env var
from unsloth_zoo.dataset_utils import train_on_responses_only

train_on_responses_only_fn = train_on_responses_only(
trainer=None,
instruction_part=instruction_part,
response_part=response_part,
force_match=False,
tokenizer=tokenizer,
return_function=True,
)
instruction_ids = token_ids_for_template_part(tokenizer, instruction_part)
response_ids = token_ids_for_template_part(tokenizer, response_part)
# Tokenize all trajectories (no padding — each keeps its natural length)
trajectory_tensors = []
num_tokens = 0
Expand Down Expand Up @@ -625,7 +617,11 @@ def tokenize_sft_batch(

attention_mask = [1] * len(input_ids)

labels = train_on_responses_only_fn({"input_ids": [input_ids]})["labels"][0]
labels = response_only_labels(
input_ids,
instruction_ids=instruction_ids,
response_ids=response_ids,
)

trajectory_tensors.append(
{
Expand Down
8 changes: 2 additions & 6 deletions src/art/tinker/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from art.tinker.prefix_cache import LRUTrieCache
from art.tinker.renderers import get_renderer_name, is_qwen3_dot_family_model
from art.types import Message, Tools
from art.utils.chat_template import default_chat_template_kwargs_for_tokenizer
from mp_actors import close_proxy, move_to_child_process


Expand Down Expand Up @@ -552,12 +553,7 @@ async def prompt_tokens(
) -> list[int]:
normalized_messages = _normalize_qwen3_dot_messages(base_model, messages)
tokenizer = self._get_renderer(base_model).tokenizer
chat_template_kwargs = {}
if isinstance(tokenizer.chat_template, str):
if "enable_thinking" in tokenizer.chat_template:
chat_template_kwargs["enable_thinking"] = False
if "preserve_thinking" in tokenizer.chat_template:
chat_template_kwargs["preserve_thinking"] = True
chat_template_kwargs = default_chat_template_kwargs_for_tokenizer(tokenizer)
encoding = tokenizer.apply_chat_template(
cast(Any, normalized_messages),
tools=cast(Any, tools),
Expand Down
32 changes: 32 additions & 0 deletions src/art/utils/chat_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import Any

# Default value for each thinking-related chat-template kwarg. A kwarg is only
# applied when its name appears in the chat template text.
THINKING_CHAT_TEMPLATE_KWARGS: dict[str, Any] = {
    "enable_thinking": False,
    "preserve_thinking": True,
}


def default_chat_template_kwargs_for_template(
    chat_template: object,
) -> dict[str, Any]:
    """Return the default thinking kwargs that a chat template refers to.

    Args:
        chat_template: The chat template. Anything that is not a ``str``
            yields an empty result.

    Returns:
        A new dict holding the entries of ``THINKING_CHAT_TEMPLATE_KWARGS``
        whose key occurs as a substring of ``chat_template``, in the
        constant's own key order.
    """
    if not isinstance(chat_template, str):
        return {}
    # Derive the result from the shared constant so the defaults are defined
    # in exactly one place.
    return {
        name: value
        for name, value in THINKING_CHAT_TEMPLATE_KWARGS.items()
        if name in chat_template
    }


def default_chat_template_kwargs_for_tokenizer(tokenizer: object) -> dict[str, Any]:
    """Return the default chat-template kwargs for a tokenizer's template.

    Reads ``tokenizer.chat_template`` (falling back to ``None`` when the
    attribute is missing) and hands it to
    ``default_chat_template_kwargs_for_template``.

    Args:
        tokenizer: Any object; it does not need a ``chat_template``
            attribute.

    Returns:
        Whatever ``default_chat_template_kwargs_for_template`` returns for
        the template that was found.
    """
    template = getattr(tokenizer, "chat_template", None)
    return default_chat_template_kwargs_for_template(template)


def merge_chat_template_kwargs(
    defaults: dict[str, Any] | None,
    overrides: dict[str, Any] | None,
) -> dict[str, Any]:
    """Combine default and override kwargs into a new dict.

    Args:
        defaults: Base kwargs; ``None`` or empty contributes nothing.
        overrides: Kwargs that win on key collisions; ``None`` or empty
            contributes nothing.

    Returns:
        A fresh dict with the keys of ``defaults`` first, then any new keys
        from ``overrides``. Neither input is modified.
    """
    merged: dict[str, Any] = dict(defaults) if defaults else {}
    if overrides:
        merged.update(overrides)
    return merged
33 changes: 32 additions & 1 deletion tests/unit/test_preprocessing_tokenize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest
from transformers.tokenization_utils_base import BatchEncoding

from art.preprocessing.tokenize import tokenize_trajectory
from art.preprocessing.tokenize import tokenize_sft_batch, tokenize_trajectory
from art.trajectories import History, Trajectory
from art.types import MessagesAndChoices

Expand Down Expand Up @@ -66,6 +66,13 @@ def encode(self, text: str, add_special_tokens: bool = False) -> list[int]:
del add_special_tokens
return [ord(char) for char in text]

def __call__(self, text: str, add_special_tokens: bool = False):
    """Encode ``text`` and wrap the ids in a minimal encoding-like object.

    Returns an instance of a throwaway ``TokenizedText`` class whose only
    attribute is ``input_ids``, the result of ``self.encode``.
    """
    token_ids = self.encode(text, add_special_tokens=add_special_tokens)
    tokenized_text_cls = type("TokenizedText", (), {"input_ids": token_ids})
    return tokenized_text_cls()

def decode(self, token_ids):
if isinstance(token_ids, int):
return chr(token_ids)
Expand Down Expand Up @@ -199,6 +206,30 @@ def test_tokenize_trajectory_passes_chat_template_kwargs() -> None:
)


def test_tokenize_sft_batch_masks_response_tokens_without_unsloth_import() -> None:
    """Only the assistant reply should remain unmasked in the SFT labels.

    Tokenizes a single user/assistant exchange with ``_FakeTokenizer`` and
    checks that the label positions not equal to ``-100`` decode to ``"OK"``
    and that the batch reports two trainable tokens.
    """
    fake_tokenizer = _FakeTokenizer()
    conversation = cast(
        MessagesAndChoices,
        [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "OK"},
        ],
    )
    trajectory = Trajectory(messages_and_choices=conversation, reward=1.0)

    batch = tokenize_sft_batch(
        trajectory_batch=[trajectory],
        learning_rate=1e-5,
        tokenizer=fake_tokenizer,  # type: ignore[arg-type]
        instruction_part="<user>",
        response_part="<assistant>",
    )

    label_row = batch.trajectory_tensors[0]["labels"][0].tolist()
    # -100 marks a masked position; everything else is a trainable token id.
    unmasked_ids = [label for label in label_row if label != -100]
    assert fake_tokenizer.decode(unmasked_ids) == "OK"
    assert batch.num_trainable_tokens == 2


def test_tokenize_trajectory_does_not_continue_real_completion_with_thinking() -> None:
tokenizer = _ContinueFinalMessageRejectingTokenizer()
choice = Choice.model_validate(
Expand Down
Loading