
Add GSPO loss #517

Open
jlamypoirier wants to merge 2 commits into main from jlp_gspo


@jlamypoirier
Collaborator

Summary

  • Adds Group Sequence Policy Optimization (GSPO), which clips a per-segment geometric-mean importance-sampling (IS) ratio, as a sibling to GRPO. The surrogate is sequence-level but is computed as a per-token sum, so the softmax-chain backward and SDP partitioning work exactly as they do for GRPO.
  • Extracts shared abstract bases (LanguageModelPolicyGradientLossConfig / LanguageModelPolicyGradientLoss) for the common scaffolding (epsilon_low/high, new_logprobs metric, preprocessing flags) and refactors GRPO to inherit from them.
  • Lifts document_index_q/document_index_k from MixerKwargs up to BlockKwargs so the LM head can read them without depending on attention-namespaced keys.
  • Renames fast_llm/layers/language_model/loss/grpo.py to policy_gradient.py (now contains both losses + shared base).

Design notes

The GSPO kernel structurally parallels fused_grpo_loss_forward_backward:

  1. Same softmax → predicted_logits → new_log_probs setup.
  2. New mid-kernel block: scatter-add per-token log_ratio, advantages, and counts into per-segment buffers; optional SDP all-reduce of the three buffers; compute R_s = exp(lrn_sum / token_count) and A_s = adv_sum / token_count (detached).
  3. Broadcast R_s, A_s back to per-token; per-token loss weight is mask / token_count_s so each segment contributes once to the sum.
  4. Downstream loss and the softmax-chain backward are line-for-line identical to GRPO with probability_ratio = R_{s(t)} and per-token advantages replaced by A_{s(t)}.
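Steps 2 and 3 can be made concrete with a small dependency-free reference. This is only a sketch on plain Python lists, not the PR's PyTorch kernel: the function name `gspo_segment_terms`, the default epsilon values, and the sign convention (loss as the negated clipped surrogate) are assumptions made for illustration, and the SDP all-reduce is only marked by a comment.

```python
import math


def gspo_segment_terms(log_ratio, advantages, segment_ids, mask,
                       num_segments, epsilon_low=0.2, epsilon_high=0.2):
    """Hypothetical reference for the segment-aggregation block.

    log_ratio[t] is new_log_prob - old_log_prob for token t and
    segment_ids[t] is the segment s(t) that token belongs to.
    """
    lr_sum = [0.0] * num_segments
    adv_sum = [0.0] * num_segments
    count = [0.0] * num_segments

    # Step 2a: this loop plays the role of the scatter-add into
    # per-segment buffers (masked tokens contribute nothing).
    for lr, adv, seg, m in zip(log_ratio, advantages, segment_ids, mask):
        if m:
            lr_sum[seg] += lr
            adv_sum[seg] += adv
            count[seg] += 1.0

    # Step 2b: with SDP, the three buffers would be all-reduced here.

    # Step 2c: geometric-mean ratio R_s and mean advantage A_s.
    # Empty segments get neutral values; they are never read below.
    ratio = [math.exp(s / c) if c > 0 else 1.0 for s, c in zip(lr_sum, count)]
    seg_adv = [a / c if c > 0 else 0.0 for a, c in zip(adv_sum, count)]

    # Step 3: broadcast back to tokens and weight by mask / token_count_s.
    per_token_loss = []
    for seg, m in zip(segment_ids, mask):
        if not m:
            per_token_loss.append(0.0)
            continue
        r, a = ratio[seg], seg_adv[seg]
        clipped = min(max(r, 1.0 - epsilon_low), 1.0 + epsilon_high)
        # Assumed sign convention: loss is the negated clipped surrogate.
        per_token_loss.append(-min(r * a, clipped * a) / count[seg])
    return ratio, seg_adv, per_token_loss
```

Because every unmasked token in a segment carries the same term divided by that segment's token count, summing `per_token_loss` over a segment yields exactly one clipped surrogate term for it, which is the "each segment contributes once" property from step 3.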

The per-token decomposition gets SDP correctness "for free": each rank only sums contributions from its own tokens, so SUM-reducing at the LossDef level reproduces the canonical single-rank result with no /sdp_size correction.
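The claim can be checked with a toy simulation in which "ranks" are just slices of the token list. Everything here is a hypothetical sketch in plain Python rather than the PR's code: the helper names, the default epsilons, and the negated-surrogate sign convention are assumptions, and `all_reduce_sum` merely imitates a SUM all-reduce by adding the per-rank buffers elementwise.

```python
import math


def segment_sums(log_ratio, advantages, segment_ids, num_segments):
    """Per-segment [log-ratio sum, advantage sum, token count] for one rank."""
    sums = [[0.0, 0.0, 0.0] for _ in range(num_segments)]
    for lr, adv, seg in zip(log_ratio, advantages, segment_ids):
        sums[seg][0] += lr
        sums[seg][1] += adv
        sums[seg][2] += 1.0
    return sums


def all_reduce_sum(per_rank_sums):
    """Stand-in for a SUM all-reduce: add the buffers of all ranks."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*per_rank_sums)]


def local_loss(segment_ids, sums, epsilon_low=0.2, epsilon_high=0.2):
    """Loss summed over this rank's own tokens only, using global segment sums."""
    total = 0.0
    for seg in segment_ids:
        lr_sum, adv_sum, count = sums[seg]
        ratio = math.exp(lr_sum / count)
        adv = adv_sum / count
        clipped = min(max(ratio, 1.0 - epsilon_low), 1.0 + epsilon_high)
        # Per-token weight 1 / count; assumed negated-surrogate convention.
        total -= min(ratio * adv, clipped * adv) / count
    return total


def sharded_loss(log_ratio, advantages, segment_ids, num_segments, boundaries):
    """Split tokens at `boundaries`, reduce segment sums, SUM the local losses."""
    shards = [slice(a, b) for a, b in zip(boundaries[:-1], boundaries[1:])]
    per_rank = [
        segment_sums(log_ratio[s], advantages[s], segment_ids[s], num_segments)
        for s in shards
    ]
    global_sums = all_reduce_sum(per_rank)
    # Plain SUM over ranks, with no division by the number of ranks.
    return sum(local_loss(segment_ids[s], global_sums) for s in shards)
```

Calling `sharded_loss` with boundaries `[0, n]` gives the single-rank result, and calling it with a split that cuts through a segment (for example `[0, 3, n]`) should give the same value up to floating-point rounding, since each token's term depends only on the reduced segment sums.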

No Triton variant yet — comes in a follow-up.

Test plan

  • pytest tests/layers/test_lm_losses.py::test_gspo_loss — 20 cases (10 param sets × 2 batch shapes)
  • pytest tests/layers/test_lm_losses.py::test_grpo_loss — 20 cases pass after refactor
  • pytest tests/layers/test_lm_losses.py::test_grpo_metrics — 40 cases pass after rename
  • End-to-end RL training run (not exercised by this PR)

Group Sequence Policy Optimization: per-segment geometric-mean IS-ratio
clipping. Mirrors GRPO's structure via shared abstract bases
(LanguageModelPolicyGradientLossConfig / LanguageModelPolicyGradientLoss);
the kernel matches GRPO except for a segment-aggregation block that
produces per-segment R and A and broadcasts them back, so the softmax-chain
backward is identical to GRPO. SDP-aware via optional all-reduce of
segment sums; per-token weighting (mask / token_count_s) lets the SUM
reduction at LossDef level give the canonical result without further
correction. PyTorch kernel only; no Triton variant yet.

Also lifts document_index_q/k from MixerKwargs to BlockKwargs so the LM
head can read them without cross-namespace coupling, and renames
grpo.py -> policy_gradient.py.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
# Conflicts:
#	fast_llm/layers/attention/config.py