Skip to content

Commit 09bdfbc

Browse files
committed
re-align implementation with sglang
1 parent 52248ee commit 09bdfbc

2 files changed

Lines changed: 212 additions & 47 deletions

File tree

components/src/dynamo/frontend/sglang_prepost.py

Lines changed: 161 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import json
77
import logging
8+
import re
89
from dataclasses import dataclass
910
from typing import Any, TypeAlias
1011

@@ -40,6 +41,102 @@ class SglangPreprocessResult:
4041
reasoning_parser: ReasoningParser | None
4142
guided_decoding: dict[str, Any] | None
4243
request: dict[str, Any]
44+
force_reasoning: bool = False
45+
46+
47+
# --- force_reasoning detection (mirrors sglang's template_manager) -------
48+
#
49+
# sglang's template_manager sets ``_force_reasoning`` once at startup by
50+
# scanning the chat template for ``<|im_start|>assistant\n<think>\n``
51+
# (the qwen3 pattern). We broaden that to also catch GLM-4.5/5 templates
52+
# which open a thinking block right before the generation prompt.
53+
#
54+
# A static, per-server boolean is plenty: per-request decoding of prompt
55+
# tails adds latency on the hot path with nothing to show for it. The
56+
# per-request knobs live downstream (``separate_reasoning``,
57+
# ``chat_template_kwargs.enable_thinking``), matching sglang's API.
58+
_FORCE_REASONING_PATTERNS = (
59+
# qwen3-family: <|im_start|>assistant\n<think>\n
60+
re.compile(r"<\|im_start\|>assistant\\n<think>\\n"),
61+
# GLM-4.5/5 and similar: <|assistant|> followed by an opening <think>
62+
# within the generation-prompt block. The template often has Jinja
63+
# expressions (including a '</think>' literal) between the two, so we
64+
# match the opening tag literally -- '<think>' never matches
65+
# '</think>' because the '/' breaks the literal prefix.
66+
re.compile(r"<\|assistant\|>[\s\S]{0,400}?<think>"),
67+
# generic fallback for non-delimiter-style templates
68+
re.compile(r"\bassistant\b[\s\S]{0,200}?<think>"),
69+
)
70+
71+
72+
def detect_force_reasoning_from_template(chat_template: str | None) -> bool:
73+
"""Return True if the chat template auto-opens a reasoning block.
74+
75+
Intended to be called once at processor startup with
76+
``tokenizer.chat_template`` and cached on the processor.
77+
"""
78+
if not chat_template or not isinstance(chat_template, str):
79+
return False
80+
for pat in _FORCE_REASONING_PATTERNS:
81+
if pat.search(chat_template):
82+
return True
83+
return False
84+
85+
86+
# Reasoning parsers that default to "thinking on" unless the client
87+
# explicitly opts out via chat_template_kwargs. Mirrors sglang's
88+
# serving_chat._get_reasoning_from_request table.
89+
_THINKING_BY_DEFAULT = {"qwen3", "glm45", "nemotron_3", "interns1", "kimi_k2"}
90+
_THINKING_OPT_IN = {"deepseek-v3", "gemma4"}
91+
92+
93+
def resolve_request_force_reasoning(
94+
request: dict[str, Any],
95+
reasoning_parser_name: str | None,
96+
template_default: bool,
97+
) -> bool:
98+
"""Resolve the effective force_reasoning flag for a single request.
99+
100+
Mirrors sglang.srt.entrypoints.openai.serving_chat._get_reasoning_from_request
101+
combined with template_manager.force_reasoning:
102+
103+
* opt-out families (``glm45``/``qwen3``/``kimi_k2``/...): on by
104+
default, ``chat_template_kwargs.enable_thinking=False`` (or
105+
``thinking=False`` for ``kimi_k2``) disables it.
106+
* opt-in families (``deepseek-v3``/``gemma4``): off by default,
107+
enabled by ``chat_template_kwargs.{thinking,enable_thinking}=True``.
108+
* anything else: follow the statically-detected template default.
109+
"""
110+
if not reasoning_parser_name:
111+
return False
112+
113+
kwargs = request.get("chat_template_kwargs") or {}
114+
115+
if reasoning_parser_name in _THINKING_BY_DEFAULT:
116+
flag_key = "thinking" if reasoning_parser_name == "kimi_k2" else "enable_thinking"
117+
return kwargs.get(flag_key) is not False
118+
119+
if reasoning_parser_name in _THINKING_OPT_IN:
120+
flag_key = "thinking" if reasoning_parser_name == "deepseek-v3" else "enable_thinking"
121+
return kwargs.get(flag_key) is True
122+
123+
return template_default
124+
125+
126+
def _client_wants_separate_reasoning(request: dict[str, Any]) -> bool:
127+
"""Honor the client's ``separate_reasoning`` flag (default True).
128+
129+
Matches sglang's ChatCompletionRequest.separate_reasoning: a client
130+
sending ``separate_reasoning=False`` asks for thinking text to land in
131+
``delta.content`` instead of ``delta.reasoning_content``. We implement
132+
that by skipping reasoning-parser creation entirely for the request.
133+
"""
134+
value = request.get("separate_reasoning", True)
135+
if isinstance(value, bool):
136+
return value
137+
if isinstance(value, str):
138+
return value.lower() not in ("0", "false", "no", "off")
139+
return bool(value)
43140

44141

45142
def convert_tools(tools: list[dict[str, Any]] | None) -> list[SglangTool] | None:
@@ -73,22 +170,39 @@ def _materialize_messages(messages: list[Any]) -> list[dict[str, Any]]:
73170
normalized.append(msg)
74171
else:
75172
normalized.append(dict(msg))
173+
_parse_tool_call_arguments(normalized)
76174
return normalized
77175

78176

79-
def detect_force_reasoning(tokenizer, prompt_token_ids: list[int]) -> bool:
80-
"""Check if the chat template's generation prompt ends with ``<think>``.
177+
def _parse_tool_call_arguments(messages: list[dict[str, Any]]) -> None:
178+
"""In-place: parse ``tool_calls[*].function.arguments`` JSON strings to dicts.
81179
82-
When the template appends ``<think>`` to the prompt, the model output
83-
starts inside a reasoning block without an explicit opening tag.
84-
The reasoning parser must be told to begin in reasoning mode
85-
(``force_reasoning=True``) so that it correctly separates reasoning
86-
content from normal content.
180+
OpenAI sends tool-call arguments as JSON strings, but some chat templates
181+
(e.g. GLM-4.5/5) iterate ``arguments.items()`` and raise a Jinja2
182+
``UndefinedError`` on a string. Decoding here keeps the on-wire request
183+
OpenAI-compatible while giving the template a mapping to iterate.
184+
Malformed JSON is left untouched so downstream validation still sees it.
87185
"""
88-
if not prompt_token_ids:
89-
return False
90-
tail = tokenizer.decode(prompt_token_ids[-10:], skip_special_tokens=False)
91-
return tail.rstrip().endswith("<think>")
186+
for m in messages:
187+
tool_calls = m.get("tool_calls") if isinstance(m, dict) else None
188+
if not tool_calls:
189+
continue
190+
for tc in tool_calls:
191+
if not isinstance(tc, dict):
192+
continue
193+
target = tc.get("function") if isinstance(tc.get("function"), dict) else tc
194+
args = target.get("arguments")
195+
if not isinstance(args, str):
196+
continue
197+
if not args:
198+
target["arguments"] = {}
199+
continue
200+
try:
201+
parsed = json.loads(args)
202+
except (ValueError, TypeError):
203+
continue
204+
if isinstance(parsed, dict):
205+
target["arguments"] = parsed
92206

93207

94208
def create_parsers(
@@ -128,13 +242,11 @@ def create_parsers(
128242

129243
reasoning_parser = None
130244
if reasoning_parser_name:
131-
kwargs: dict[str, Any] = {
132-
"model_type": reasoning_parser_name,
133-
"stream_reasoning": True,
134-
}
135-
if force_reasoning:
136-
kwargs["force_reasoning"] = True
137-
reasoning_parser = ReasoningParser(**kwargs)
245+
reasoning_parser = ReasoningParser(
246+
model_type=reasoning_parser_name,
247+
stream_reasoning=True,
248+
force_reasoning=force_reasoning,
249+
)
138250

139251
return tool_call_parser, reasoning_parser
140252

@@ -212,19 +324,16 @@ def build_tool_call_guided_decoding(
212324

213325

214326
def _normalize_prompt_token_ids(prompt_token_ids: Any) -> list[int]:
215-
if isinstance(prompt_token_ids, list):
216-
return prompt_token_ids
217-
218-
input_ids = getattr(prompt_token_ids, "input_ids", None)
219-
if input_ids is not None and not isinstance(input_ids, str):
220-
return list(input_ids)
221-
222-
if isinstance(prompt_token_ids, dict):
223-
dict_input_ids = prompt_token_ids.get("input_ids")
224-
if dict_input_ids is not None and not isinstance(dict_input_ids, str):
225-
return list(dict_input_ids)
327+
"""Flatten ``apply_chat_template`` output to ``list[int]``.
226328
227-
return list(prompt_token_ids)
329+
On transformers v5 the default ``TokenizersBackend`` returns a
330+
``BatchEncoding`` from ``apply_chat_template(..., tokenize=True)``;
331+
unwrap to ``.input_ids`` (a flat list for a single conversation).
332+
"""
333+
ids = getattr(prompt_token_ids, "input_ids", prompt_token_ids)
334+
if isinstance(ids, dict):
335+
ids = ids.get("input_ids", prompt_token_ids)
336+
return list(ids)
228337

229338

230339
def preprocess_chat_request(
@@ -234,13 +343,33 @@ def preprocess_chat_request(
234343
tool_call_parser_name: str | None,
235344
reasoning_parser_name: str | None,
236345
exclude_tools_when_tool_choice_none: bool = True,
346+
template_force_reasoning: bool = False,
237347
) -> SglangPreprocessResult:
238348
"""Preprocess a chat request using SGLang tokenizer and parser APIs.
239349
350+
``template_force_reasoning`` is the static per-server flag derived from
351+
the chat template (see :func:`detect_force_reasoning_from_template`);
352+
the effective per-request value combines it with client knobs
353+
(``separate_reasoning``, ``chat_template_kwargs.enable_thinking``).
354+
240355
Synchronous -- suitable for both main-process and worker-process execution.
241356
"""
242357
messages = _materialize_messages(request.get("messages", []))
243358

359+
# Per-request client escape hatch: skip reasoning parsing entirely when
360+
# the client sends ``separate_reasoning=False`` -- thinking text then
361+
# lands in ``delta.content`` instead of ``delta.reasoning_content``.
362+
effective_reasoning_parser_name = (
363+
reasoning_parser_name
364+
if _client_wants_separate_reasoning(request)
365+
else None
366+
)
367+
force_reasoning = resolve_request_force_reasoning(
368+
request,
369+
effective_reasoning_parser_name,
370+
template_force_reasoning,
371+
)
372+
244373
# Convert tools to SGLang format (done once, shared with parser creation)
245374
sglang_tools = convert_tools(request.get("tools"))
246375

@@ -262,7 +391,6 @@ def preprocess_chat_request(
262391
template_kwargs: dict[str, Any] = {
263392
"add_generation_prompt": True,
264393
"tokenize": True,
265-
"return_dict": False,
266394
}
267395
# Strip tools from template when tool_choice=none so the model doesn't
268396
# see them and generate raw XML tool calls in its response.
@@ -283,16 +411,10 @@ def preprocess_chat_request(
283411
tokenizer.apply_chat_template(messages, **template_kwargs)
284412
)
285413

286-
force_reasoning = (
287-
detect_force_reasoning(tokenizer, prompt_token_ids)
288-
if reasoning_parser_name
289-
else False
290-
)
291-
292414
tool_call_parser, reasoning_parser = create_parsers(
293415
request,
294416
tool_call_parser_name=tool_call_parser_name,
295-
reasoning_parser_name=reasoning_parser_name,
417+
reasoning_parser_name=effective_reasoning_parser_name,
296418
sglang_tools=sglang_tools,
297419
force_reasoning=force_reasoning,
298420
)
@@ -308,6 +430,7 @@ def preprocess_chat_request(
308430
reasoning_parser=reasoning_parser,
309431
guided_decoding=guided_decoding,
310432
request=request,
433+
force_reasoning=force_reasoning,
311434
)
312435

313436

0 commit comments

Comments
 (0)