Skip to content

Commit ef6d869

Browse files
committed
fix(frontend): auto-detect force_reasoning when chat template appends <think>
Signed-off-by: Naveen Marri <[email protected]>
1 parent e041ccf commit ef6d869

2 files changed

Lines changed: 213 additions & 15 deletions

File tree

components/src/dynamo/frontend/sglang_prepost.py

Lines changed: 165 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import json
77
import logging
8+
import re
89
from dataclasses import dataclass
910
from typing import Any, TypeAlias
1011

@@ -40,6 +41,106 @@ class SglangPreprocessResult:
4041
reasoning_parser: ReasoningParser | None
4142
guided_decoding: dict[str, Any] | None
4243
request: dict[str, Any]
44+
force_reasoning: bool = False
45+
46+
47+
# --- force_reasoning detection (mirrors sglang's template_manager) -------
48+
#
49+
# sglang's template_manager sets ``_force_reasoning`` once at startup by
50+
# scanning the chat template for ``<|im_start|>assistant\n<think>\n``
51+
# (the qwen3 pattern). We broaden that to also catch GLM-4.5/5 templates
52+
# which open a thinking block right before the generation prompt.
53+
#
54+
# A static, per-server boolean is plenty: per-request decoding of prompt
55+
# tails adds latency on the hot path with nothing to show for it. The
56+
# per-request knobs live downstream (``separate_reasoning``,
57+
# ``chat_template_kwargs.enable_thinking``), matching sglang's API.
58+
_FORCE_REASONING_PATTERNS = (
59+
# qwen3-family: <|im_start|>assistant\n<think>\n
60+
re.compile(r"<\|im_start\|>assistant\\n<think>\\n"),
61+
# GLM-4.5/5 and similar: <|assistant|> followed by an opening <think>
62+
# within the generation-prompt block. The template often has Jinja
63+
# expressions (including a '</think>' literal) between the two, so we
64+
# match the opening tag literally -- '<think>' never matches
65+
# '</think>' because the '/' breaks the literal prefix.
66+
re.compile(r"<\|assistant\|>[\s\S]{0,400}?<think>"),
67+
# generic fallback for non-delimiter-style templates
68+
re.compile(r"\bassistant\b[\s\S]{0,200}?<think>"),
69+
)
70+
71+
72+
def detect_force_reasoning_from_template(chat_template: str | None) -> bool:
73+
"""Return True if the chat template auto-opens a reasoning block.
74+
75+
Intended to be called once at processor startup with
76+
``tokenizer.chat_template`` and cached on the processor.
77+
"""
78+
if not chat_template or not isinstance(chat_template, str):
79+
return False
80+
for pat in _FORCE_REASONING_PATTERNS:
81+
if pat.search(chat_template):
82+
return True
83+
return False
84+
85+
86+
# Reasoning parsers that default to "thinking on" unless the client
87+
# explicitly opts out via chat_template_kwargs. Mirrors sglang's
88+
# serving_chat._get_reasoning_from_request table.
89+
_THINKING_BY_DEFAULT = {"qwen3", "glm45", "nemotron_3", "interns1", "kimi_k2"}
90+
_THINKING_OPT_IN = {"deepseek-v3", "gemma4"}
91+
92+
93+
def resolve_request_force_reasoning(
94+
request: dict[str, Any],
95+
reasoning_parser_name: str | None,
96+
template_default: bool,
97+
) -> bool:
98+
"""Resolve the effective force_reasoning flag for a single request.
99+
100+
Mirrors sglang.srt.entrypoints.openai.serving_chat._get_reasoning_from_request
101+
combined with template_manager.force_reasoning:
102+
103+
* opt-out families (``glm45``/``qwen3``/``kimi_k2``/...): on by
104+
default, ``chat_template_kwargs.enable_thinking=False`` (or
105+
``thinking=False`` for ``kimi_k2``) disables it.
106+
* opt-in families (``deepseek-v3``/``gemma4``): off by default,
107+
enabled by ``chat_template_kwargs.{thinking,enable_thinking}=True``.
108+
* anything else: follow the statically-detected template default.
109+
"""
110+
if not reasoning_parser_name:
111+
return False
112+
113+
kwargs = request.get("chat_template_kwargs") or {}
114+
115+
if reasoning_parser_name in _THINKING_BY_DEFAULT:
116+
flag_key = (
117+
"thinking" if reasoning_parser_name == "kimi_k2" else "enable_thinking"
118+
)
119+
return kwargs.get(flag_key) is not False
120+
121+
if reasoning_parser_name in _THINKING_OPT_IN:
122+
flag_key = (
123+
"thinking" if reasoning_parser_name == "deepseek-v3" else "enable_thinking"
124+
)
125+
return kwargs.get(flag_key) is True
126+
127+
return template_default
128+
129+
130+
def _client_wants_separate_reasoning(request: dict[str, Any]) -> bool:
    """Honor the client's ``separate_reasoning`` flag (default True).

    Intended to match sglang's ChatCompletionRequest.separate_reasoning: a
    client sending ``separate_reasoning=False`` asks for thinking text to
    land in ``delta.content`` instead of ``delta.reasoning_content``. Callers
    implement that by skipping reasoning-parser creation for the request.

    Args:
        request: Chat request dict; only ``separate_reasoning`` is read.

    Returns:
        ``True`` when the key is absent or explicitly ``None`` (JSON null),
        the value itself for booleans, ``False`` for the strings
        ``"0"``/``"false"``/``"no"``/``"off"`` (case-insensitive, surrounding
        whitespace ignored), and ``bool(value)`` for anything else.
    """
    value = request.get("separate_reasoning")
    if value is None:
        # An explicit null means "not specified", same as an absent key.
        # Without this guard ``bool(None)`` would silently turn reasoning
        # separation off for clients that serialize unset fields as null.
        return True
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        # Tolerate stray whitespace around string-encoded booleans.
        return value.strip().lower() not in ("0", "false", "no", "off")
    return bool(value)
43144

44145

45146
def convert_tools(tools: list[dict[str, Any]] | None) -> list[SglangTool] | None:
@@ -73,15 +174,48 @@ def _materialize_messages(messages: list[Any]) -> list[dict[str, Any]]:
73174
normalized.append(msg)
74175
else:
75176
normalized.append(dict(msg))
177+
_parse_tool_call_arguments(normalized)
76178
return normalized
77179

78180

181+
def _parse_tool_call_arguments(messages: list[dict[str, Any]]) -> None:
    """Decode JSON-string tool-call arguments into dicts, in place.

    OpenAI-style requests carry ``tool_calls[*].function.arguments`` as a
    JSON string, but some chat templates (e.g. GLM-4.5/5) iterate
    ``arguments.items()`` and fail on a string. Each string is replaced by
    the mapping it encodes so the template receives something iterable.

    Rules applied to every tool call that is a dict:

    * ``arguments`` is read from the nested ``function`` dict when present,
      otherwise from the tool call itself.
    * non-string values are left alone.
    * an empty string becomes ``{}``.
    * a string that is not valid JSON, or that decodes to anything other
      than a dict, is left untouched so downstream validation still sees it.

    NOTE(review): the nested dicts are mutated directly; if they are shared
    with the caller's original request, that request changes too -- confirm
    against ``_materialize_messages``.
    """
    for message in messages:
        if not isinstance(message, dict):
            continue
        calls = message.get("tool_calls")
        if not calls:
            continue
        for call in calls:
            if not isinstance(call, dict):
                continue
            function = call.get("function")
            holder = function if isinstance(function, dict) else call
            raw = holder.get("arguments")
            if not isinstance(raw, str):
                continue
            if raw == "":
                holder["arguments"] = {}
                continue
            try:
                decoded = json.loads(raw)
            except (ValueError, TypeError):
                # Malformed JSON stays as the original string.
                continue
            if isinstance(decoded, dict):
                holder["arguments"] = decoded
210+
211+
79212
def create_parsers(
80213
request: dict[str, Any],
81214
*,
82215
tool_call_parser_name: str | None,
83216
reasoning_parser_name: str | None,
84217
sglang_tools: list[SglangTool] | None = None,
218+
force_reasoning: bool = False,
85219
) -> tuple[ToolCallParserType | None, ReasoningParser | None]:
86220
"""Create tool call and reasoning parsers for a request.
87221
@@ -115,6 +249,7 @@ def create_parsers(
115249
reasoning_parser = ReasoningParser(
116250
model_type=reasoning_parser_name,
117251
stream_reasoning=True,
252+
force_reasoning=force_reasoning,
118253
)
119254

120255
return tool_call_parser, reasoning_parser
@@ -193,19 +328,16 @@ def build_tool_call_guided_decoding(
193328

194329

195330
def _normalize_prompt_token_ids(prompt_token_ids: Any) -> list[int]:
    """Flatten ``apply_chat_template`` output to ``list[int]``.

    On transformers v5 the default ``TokenizersBackend`` returns a
    ``BatchEncoding`` from ``apply_chat_template(..., tokenize=True)``;
    unwrap to ``.input_ids`` (a flat list for a single conversation).

    Accepted shapes:

    * a plain sequence of ids,
    * an object exposing ``.input_ids``,
    * a dict with an ``"input_ids"`` entry,
    * an array-like with ``tolist()`` (converted to native Python values),
    * a batch holding exactly one conversation (``[[...]]``), which is
      unwrapped to its single row.

    Raises:
        TypeError: when no token ids can be extracted (``None``, text, or a
            dict without ``"input_ids"``). Iterating such a value would
            otherwise silently yield characters or dict keys.
    """
    ids = getattr(prompt_token_ids, "input_ids", prompt_token_ids)
    if isinstance(ids, dict):
        # A dict without "input_ids" carries no ids; do not fall back to
        # iterating its keys.
        ids = ids.get("input_ids")
    if ids is None or isinstance(ids, (str, bytes, dict)):
        raise TypeError(
            "cannot extract token ids from "
            f"{type(prompt_token_ids).__name__}"
        )
    if hasattr(ids, "tolist"):
        # Arrays/tensors: get native ints instead of array scalars.
        ids = ids.tolist()
    ids = list(ids)
    if len(ids) == 1 and isinstance(ids[0], (list, tuple)):
        # Batch of one conversation: return its only row.
        ids = list(ids[0])
    return ids
209341

210342

211343
def preprocess_chat_request(
@@ -215,13 +347,31 @@ def preprocess_chat_request(
215347
tool_call_parser_name: str | None,
216348
reasoning_parser_name: str | None,
217349
exclude_tools_when_tool_choice_none: bool = True,
350+
template_force_reasoning: bool = False,
218351
) -> SglangPreprocessResult:
219352
"""Preprocess a chat request using SGLang tokenizer and parser APIs.
220353
354+
``template_force_reasoning`` is the static per-server flag derived from
355+
the chat template (see :func:`detect_force_reasoning_from_template`);
356+
the effective per-request value combines it with client knobs
357+
(``separate_reasoning``, ``chat_template_kwargs.enable_thinking``).
358+
221359
Synchronous -- suitable for both main-process and worker-process execution.
222360
"""
223361
messages = _materialize_messages(request.get("messages", []))
224362

363+
# Per-request client escape hatch: skip reasoning parsing entirely when
364+
# the client sends ``separate_reasoning=False`` -- thinking text then
365+
# lands in ``delta.content`` instead of ``delta.reasoning_content``.
366+
effective_reasoning_parser_name = (
367+
reasoning_parser_name if _client_wants_separate_reasoning(request) else None
368+
)
369+
force_reasoning = resolve_request_force_reasoning(
370+
request,
371+
effective_reasoning_parser_name,
372+
template_force_reasoning,
373+
)
374+
225375
# Convert tools to SGLang format (done once, shared with parser creation)
226376
sglang_tools = convert_tools(request.get("tools"))
227377

@@ -266,8 +416,9 @@ def preprocess_chat_request(
266416
tool_call_parser, reasoning_parser = create_parsers(
267417
request,
268418
tool_call_parser_name=tool_call_parser_name,
269-
reasoning_parser_name=reasoning_parser_name,
419+
reasoning_parser_name=effective_reasoning_parser_name,
270420
sglang_tools=sglang_tools,
421+
force_reasoning=force_reasoning,
271422
)
272423
guided_decoding = build_tool_call_guided_decoding(
273424
request,
@@ -281,6 +432,7 @@ def preprocess_chat_request(
281432
reasoning_parser=reasoning_parser,
282433
guided_decoding=guided_decoding,
283434
request=request,
435+
force_reasoning=force_reasoning,
284436
)
285437

286438

components/src/dynamo/frontend/sglang_processor.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,11 @@
3333
from .sglang_prepost import (
3434
SglangStreamingPostProcessor,
3535
ToolCallParserType,
36+
_client_wants_separate_reasoning,
3637
_get_history_tool_calls_count,
3738
convert_tools,
3839
create_parsers,
40+
detect_force_reasoning_from_template,
3941
preprocess_chat_request,
4042
)
4143
from .utils import PreprocessError, extract_mm_urls, random_uuid, worker_warmup
@@ -104,6 +106,7 @@ def _map_finish_reason(raw: str | None) -> str | None:
104106
_w_tool_call_parser_name: str | None = None
105107
_w_reasoning_parser_name: str | None = None
106108
_w_exclude_tools_when_tool_choice_none: bool = True
109+
_w_template_force_reasoning: bool = False
107110

108111

109112
@dataclass
@@ -113,6 +116,12 @@ class SglangPreprocessWorkerResult:
113116
prompt_token_ids: list[int]
114117
dynamo_preproc: dict[str, Any]
115118
request: dict[str, Any]
119+
force_reasoning: bool = False
120+
# ``effective_reasoning_parser_name`` is None when the request opted out
121+
# via ``separate_reasoning=False``; the main process must skip creating
122+
# a reasoning parser in that case so the pool path matches the inline
123+
# path byte-for-byte.
124+
effective_reasoning_parser_name: str | None = None
116125

117126

118127
def _init_worker(
@@ -121,14 +130,16 @@ def _init_worker(
121130
reasoning_parser_name: str | None,
122131
exclude_tools_when_tool_choice_none: bool = True,
123132
trust_remote_code: bool = False,
133+
template_force_reasoning: bool = False,
124134
) -> None:
125135
"""Initialize a worker process with its own tokenizer."""
126136
global _w_tokenizer, _w_tool_call_parser_name, _w_reasoning_parser_name
127-
global _w_exclude_tools_when_tool_choice_none
137+
global _w_exclude_tools_when_tool_choice_none, _w_template_force_reasoning
128138
_w_tokenizer = get_tokenizer(model_path, trust_remote_code=trust_remote_code)
129139
_w_tool_call_parser_name = tool_call_parser_name
130140
_w_reasoning_parser_name = reasoning_parser_name
131141
_w_exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none
142+
_w_template_force_reasoning = template_force_reasoning
132143

133144

134145
def _preprocess_worker(
@@ -143,6 +154,7 @@ def _preprocess_worker(
143154
tool_call_parser_name=_w_tool_call_parser_name,
144155
reasoning_parser_name=_w_reasoning_parser_name,
145156
exclude_tools_when_tool_choice_none=_w_exclude_tools_when_tool_choice_none,
157+
template_force_reasoning=_w_template_force_reasoning,
146158
)
147159

148160
n = request.get("n", 1)
@@ -158,10 +170,16 @@ def _preprocess_worker(
158170
pre.tool_call_parser,
159171
)
160172

173+
effective_reasoning_parser_name = (
174+
_w_reasoning_parser_name if _client_wants_separate_reasoning(request) else None
175+
)
176+
161177
return SglangPreprocessWorkerResult(
162178
prompt_token_ids=pre.prompt_token_ids,
163179
dynamo_preproc=dynamo_preproc,
164180
request=request,
181+
force_reasoning=pre.force_reasoning,
182+
effective_reasoning_parser_name=effective_reasoning_parser_name,
165183
)
166184

167185

@@ -254,6 +272,20 @@ def __init__(
254272
stream_interval: int = 1,
255273
):
256274
self.tokenizer = tokenizer
275+
# Detect force_reasoning once from the chat template, matching
276+
# sglang's template_manager. Per-request overrides still apply
277+
# (see resolve_request_force_reasoning).
278+
self.template_force_reasoning = detect_force_reasoning_from_template(
279+
getattr(tokenizer, "chat_template", None)
280+
)
281+
if self.template_force_reasoning:
282+
logger.info(
283+
"Detected force-reasoning pattern in chat template; "
284+
"thinking tokens will route to delta.reasoning_content by "
285+
"default (clients can opt out via "
286+
"separate_reasoning=false or "
287+
"chat_template_kwargs.enable_thinking=false)."
288+
)
257289
self.router = router
258290
self.is_kv_router = isinstance(router, KvRouter)
259291
self.tool_call_parser_name = tool_call_parser_name
@@ -317,6 +349,7 @@ async def _generator_inner(
317349
tool_call_parser_name=self.tool_call_parser_name,
318350
reasoning_parser_name=self.reasoning_parser_name,
319351
exclude_tools_when_tool_choice_none=self.exclude_tools_when_tool_choice_none,
352+
template_force_reasoning=self.template_force_reasoning,
320353
)
321354

322355
if self.debug_perf:
@@ -405,10 +438,15 @@ async def _generator_inner_pool(
405438
return
406439

407440
# --- Phase 2: Recreate parsers in main process (not picklable) ---
441+
# The worker already decided effective_reasoning_parser_name based on
442+
# the request's separate_reasoning flag and computed force_reasoning;
443+
# we mirror those choices to keep pool- and inline-path outputs
444+
# identical.
408445
tool_call_parser, reasoning_parser = create_parsers(
409446
request,
410447
tool_call_parser_name=self.tool_call_parser_name,
411-
reasoning_parser_name=self.reasoning_parser_name,
448+
reasoning_parser_name=preproc_result.effective_reasoning_parser_name,
449+
force_reasoning=preproc_result.force_reasoning,
412450
)
413451

414452
post = SglangStreamingPostProcessor(
@@ -596,6 +634,13 @@ async def chat_engine_factory(
596634

597635
eos_token_id = getattr(tokenizer, "eos_token_id", None)
598636

637+
# Static reasoning-template scan (mirrors sglang's template_manager).
638+
# Shared with worker-pool processes via initargs so they compute the
639+
# same per-request force_reasoning flag as the main process.
640+
template_force_reasoning = detect_force_reasoning_from_template(
641+
getattr(tokenizer, "chat_template", None)
642+
)
643+
599644
tool_call_parser_name = (
600645
self.tool_call_parser_name
601646
or _runtime_config_parser_name(mdc, "tool_call_parser")
@@ -643,6 +688,7 @@ async def chat_engine_factory(
643688
reasoning_parser_name,
644689
self.config.exclude_tools_when_tool_choice_none,
645690
self.trust_remote_code,
691+
template_force_reasoning,
646692
),
647693
)
648694
futures = [

0 commit comments

Comments
 (0)