Add GSPO loss #517
Open
jlamypoirier wants to merge 2 commits into
Group Sequence Policy Optimization: per-segment geometric-mean IS-ratio clipping.

Mirrors GRPO's structure via shared abstract bases (LanguageModelPolicyGradientLossConfig / LanguageModelPolicyGradientLoss); the kernel matches GRPO except for a segment-aggregation block that produces per-segment R and A and broadcasts them back, so the softmax-chain backward is identical to GRPO.

SDP-aware via optional all-reduce of segment sums; per-token weighting (mask / token_count_s) lets the SUM reduction at LossDef level give the canonical result without further correction.

PyTorch kernel only; no Triton variant yet.

Also lifts document_index_q/k from MixerKwargs to BlockKwargs so the LM head can read them without cross-namespace coupling, and renames grpo.py -> policy_gradient.py.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
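Read as formulas, the description above might look as follows. This is only a sketch: the symbols $m_t$ (token mask), $n_s$ (masked token count of segment $s$), $r_t$ (per-token log ratio), $a_t$ (per-token advantage) and $\ell_s$ (per-segment loss) are notation introduced here, and the clipping bounds $1-\epsilon_{\text{low}}$, $1+\epsilon_{\text{high}}$ and the overall sign are assumed from the usual clipped policy-gradient form rather than read from the code.

$$
n_s = \sum_{t \in s} m_t, \qquad
R_s = \exp\Big(\frac{1}{n_s}\sum_{t \in s} m_t\, r_t\Big), \qquad
A_s = \frac{1}{n_s}\sum_{t \in s} m_t\, a_t
$$

$$
\ell_s = -\min\big(R_s A_s,\ \operatorname{clip}(R_s,\ 1-\epsilon_{\text{low}},\ 1+\epsilon_{\text{high}})\, A_s\big)
$$

$$
\sum_t \frac{m_t}{n_{s(t)}}\,\ell_{s(t)}
= \sum_s \ell_s \sum_{t \in s} \frac{m_t}{n_s}
= \sum_s \ell_s
$$

$R_s$ is the geometric mean of the per-token importance ratios $e^{r_t}$ over the unmasked tokens of the segment. The last line is why the per-token weight mask / token_count_s makes a plain SUM over tokens count each segment exactly once, regardless of how the tokens are split across ranks.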
This was referenced May 19, 2026
# Conflicts: # fast_llm/layers/attention/config.py
Summary
- Introduces shared abstract bases (`LanguageModelPolicyGradientLossConfig` / `LanguageModelPolicyGradientLoss`) for the common scaffolding (epsilon_low/high, new_logprobs metric, preprocessing flags) and refactors GRPO to inherit from them.
- Lifts `document_index_q` / `document_index_k` from `MixerKwargs` up to `BlockKwargs` so the LM head can read them without depending on attention-namespaced keys.
- Renames `fast_llm/layers/language_model/loss/grpo.py` to `policy_gradient.py` (now contains both losses and the shared base).

Design notes
The GSPO kernel structurally parallels `fused_grpo_loss_forward_backward`:

- Sum `log_ratio`, `advantages`, and counts into per-segment buffers; optionally all-reduce the three buffers across SDP ranks; compute `R_s = exp(lrn_sum / token_count)` and `A_s = adv_sum / token_count` (detached).
- Broadcast `R_s`, `A_s` back to per-token; the per-token loss weight is `mask / token_count_s`, so each segment contributes once to the sum.
- The softmax-chain backward is the same as GRPO's, with `probability_ratio = R_{s(t)}` and per-token advantages replaced by `A_{s(t)}`.

The per-token decomposition gets SDP correctness "for free": each rank only sums contributions from its own tokens, so SUM-reducing at the `LossDef` level reproduces the canonical single-rank result with no `/sdp_size` correction.

No Triton variant yet; it comes in a follow-up.
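The SDP claim above can be checked with a small scalar reference. This is a minimal sketch in plain Python, not the PyTorch kernel from this PR: the function names `segment_sums` and `gspo_loss_from_sums` are hypothetical, the all-reduce is emulated by adding per-shard buffers, and the clip bounds `1 - eps_low` / `1 + eps_high` and the negative sign are assumed conventions.

```python
import math


def segment_sums(log_ratio, advantages, mask, segment_ids, num_segments):
    """Accumulate masked per-token values into per-segment buffers."""
    lr_sum = [0.0] * num_segments
    adv_sum = [0.0] * num_segments
    count = [0.0] * num_segments
    for lr, adv, m, s in zip(log_ratio, advantages, mask, segment_ids):
        lr_sum[s] += m * lr
        adv_sum[s] += m * adv
        count[s] += m
    return lr_sum, adv_sum, count


def gspo_loss_from_sums(mask, segment_ids, sums, eps_low, eps_high):
    """Sum of per-token losses, each weighted by mask / token_count_s."""
    lr_sum, adv_sum, count = sums
    total = 0.0
    for m, s in zip(mask, segment_ids):
        if m == 0:
            continue  # masked tokens contribute nothing
        ratio = math.exp(lr_sum[s] / count[s])  # R_s, geometric-mean ratio
        advantage = adv_sum[s] / count[s]  # A_s, mean advantage
        clipped = min(max(ratio, 1.0 - eps_low), 1.0 + eps_high)
        per_segment = -min(ratio * advantage, clipped * advantage)
        total += (m / count[s]) * per_segment
    return total


# Two segments of three tokens each; token 2 is masked out.
log_ratio = [0.10, 0.30, -0.20, 0.05, 0.40, 0.00]
advantages = [1.0, 1.0, 1.0, -0.5, -0.5, -0.5]
mask = [1, 1, 0, 1, 1, 1]
segment_ids = [0, 0, 0, 1, 1, 1]

# Single-rank reference.
full_sums = segment_sums(log_ratio, advantages, mask, segment_ids, 2)
full = gspo_loss_from_sums(mask, segment_ids, full_sums, 0.2, 0.2)

# Two emulated ranks; the split deliberately cuts through segment 1.
shards = [slice(0, 4), slice(4, 6)]
partial = [
    segment_sums(log_ratio[sl], advantages[sl], mask[sl], segment_ids[sl], 2)
    for sl in shards
]
# Emulated all-reduce: add the per-rank buffers element-wise.
reduced = tuple([sum(vals) for vals in zip(*bufs)] for bufs in zip(*partial))

# Each rank sums only its own tokens; the final SUM needs no rescaling.
sharded = sum(
    gspo_loss_from_sums(mask[sl], segment_ids[sl], reduced, 0.2, 0.2)
    for sl in shards
)
assert math.isclose(full, sharded)
```

Because the buffers are reduced before `R_s` and `A_s` are formed, every rank sees the same per-segment values, and the per-token weights of a segment still add up to one across ranks.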
Test plan
- `pytest tests/layers/test_lm_losses.py::test_gspo_loss`: 20 cases (10 param sets × 2 batch shapes)
- `pytest tests/layers/test_lm_losses.py::test_grpo_loss`: 20 cases pass after refactor
- `pytest tests/layers/test_lm_losses.py::test_grpo_metrics`: 40 cases pass after rename