Triton GSPO kernel #522

Open
jlamypoirier wants to merge 2 commits into jlp_gspo from jlp_gspo_triton

@jlamypoirier
Collaborator

Follow-up to #517. Depends on jlp_gspo for the PyTorch GSPO reference.

Summary

Triton kernel for the GSPO loss, exposed via a new LanguageModelGSPOLossConfig.use_triton field (matches GRPO's pattern).

Design

Forward and backward are split because the segment-level R_s and A_s depend on every contributing token, so no per-token gradient can be finalized until the whole segment has been reduced. Triton provides no synchronization across blocks within a single kernel launch, so that reduction runs in PyTorch between two Triton kernel launches:

  1. Forward (Triton): reuses the existing triton_cross_entropy_forward_from_labels_parallel_kernel to compute per-token max_logits, sum_exp_logits, predicted_logit. With TP, the existing parallel_sum_exp_logits + all-reduce dance merges the shards — same as the Triton GRPO TP path.
  2. PyTorch: derive new_log_probs, log_ratio per token; segment-aggregate via scatter_add; optional SDP all-reduce; compute R_s, A_s, token_weight = mask_t / token_count_s; compute the loss scalar.
  3. Backward (Triton): new triton_gspo_loss_backward_kernel. Identical chain rule to GRPO's backward (grad_logit_i = effective_grad * (softmax_i - delta_{i,target})) but with effective_grad built from the segment-broadcast R_s, A_s, and token_weight instead of per-token ratio, advantage, mask.
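
For readers who want a concrete picture of step 2, here is a minimal PyTorch sketch of the segment aggregation. The function name, argument names, and the choice to take A_s as the masked mean of per-token advantages are assumptions made for illustration; the PR's actual code may differ, and the loss scalar itself is left out because its exact form is not spelled out here.

```python
import torch


def gspo_segment_stats(new_log_probs, old_log_probs, advantages, mask, segment_ids, num_segments):
    """Illustrative only: per-segment R_s and A_s plus per-token token_weight.

    All inputs are 1-D over tokens; segment_ids is an int64 tensor mapping
    each token to its segment. Names are hypothetical, not the PR's API.
    """
    mask = mask.to(new_log_probs.dtype)
    log_ratio = (new_log_probs - old_log_probs) * mask

    # Segment-aggregate with scatter_add (out-of-place on a fresh zero buffer).
    zeros = new_log_probs.new_zeros(num_segments)
    token_count = zeros.scatter_add(0, segment_ids, mask)
    log_ratio_sum = zeros.scatter_add(0, segment_ids, log_ratio)
    advantage_sum = zeros.scatter_add(0, segment_ids, advantages * mask)
    # With sequence-data parallelism, the three sums above would be
    # all-reduced here before dividing (omitted in this single-process sketch).

    safe_count = token_count.clamp(min=1)  # avoid 0/0 for fully masked segments
    ratio = torch.exp(log_ratio_sum / safe_count)  # R_s (assumed: exp of mean log-ratio)
    advantage = advantage_sum / safe_count  # A_s (assumed: masked mean)
    token_weight = mask / safe_count[segment_ids]  # mask_t / token_count_s
    return ratio, advantage, token_weight
```

Because token_weight already carries the 1/token_count_s factor, any per-token quantity multiplied by it and summed over tokens gives a per-segment mean, which is what lets the backward kernel consume segment-broadcast values token by token.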

The kernel only fuses what's actually compute-bound — the softmax + chain-rule loops over the vocab dimension. The cheap per-segment work stays in PyTorch where readability wins.
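
The chain rule the backward kernel fuses can be checked in plain PyTorch. The sketch below treats effective_grad as an opaque per-token scalar (in the PR it is built from R_s, A_s, and token_weight) and assumes the sign convention effective_grad = -dLoss/d log p(target); the helper name is hypothetical.

```python
import torch


def grad_logits_from_effective_grad(logits, targets, effective_grad):
    """grad_logit_i = effective_grad * (softmax_i - delta_{i,target}), per token."""
    probs = torch.softmax(logits.float(), dim=-1)
    one_hot = torch.nn.functional.one_hot(targets, logits.size(-1)).to(probs.dtype)
    return effective_grad.unsqueeze(-1) * (probs - one_hot)


torch.manual_seed(0)
logits = torch.randn(4, 7, requires_grad=True)  # (tokens, vocab)
targets = torch.tensor([1, 0, 6, 3])
upstream = torch.randn(4)  # stand-in for dLoss/d log p(target), one value per token

# Autograd reference: differentiate sum_t upstream_t * log p_t(target_t).
log_probs = torch.log_softmax(logits, dim=-1).gather(-1, targets.unsqueeze(-1)).squeeze(-1)
(upstream * log_probs).sum().backward()

# Manual chain rule; the minus sign follows from
# d log p_target / d logit_i = delta_{i,target} - softmax_i.
manual = grad_logits_from_effective_grad(logits.detach(), targets, -upstream)
```

Here manual should match logits.grad up to floating-point tolerance. The only vocab-sized work is the softmax and the final multiply, which is the part the description says is worth fusing.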

Test plan

  • TRITON_INTERPRET=1 pytest tests/layers/test_lm_losses.py::test_gspo_loss — 20 cases pass against the reference and the PyTorch fused kernel
  • pytest tests/layers/test_lm_losses.py::test_grpo_loss tests/layers/test_lm_losses.py::test_grpo_metrics — 60 cases pass (sanity check that adding the GSPO triton path didn't disturb GRPO)
  • GPU verification (needs cluster run)

jlamypoirier and others added 2 commits May 19, 2026 17:10
Triton backward kernel mirrors GRPO's backward — same softmax chain rule
through (softmax_k - delta_{k,target}), with the per-token IS ratio replaced
by the segment-broadcast R_{s(t)} and the loss mask scaled by 1/token_count_s
(token_weight). The forward pass reuses the existing
triton_cross_entropy_forward_from_labels_parallel_kernel to produce
max/sum/predicted_logit per token (with TP support via the same
parallel_sum_exp_logits dance as GRPO); segment aggregation, loss, and the
SDP all-reduce live in PyTorch between the two Triton passes.
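
As a rough illustration of how per-token max/sum/predicted_logit combine into a log-probability when the vocab is sharded, the sketch below simulates tensor-parallel ranks with a Python list of shards. The function name is hypothetical, and the reduction order (max first, then sum of exponentials relative to the global max) is an assumption about the general technique, not a description of the existing kernel.

```python
import torch


def merge_sharded_log_prob(shard_logits, targets, shard_size):
    """Log-prob of each target token from vocab-sharded logits (single-process simulation)."""
    # Per-shard max, then a global max (an all-reduce MAX in a real TP run).
    local_max = torch.stack([s.max(dim=-1).values for s in shard_logits])
    global_max = local_max.max(dim=0).values

    # Per-shard sum of exp relative to the global max, then a global sum (all-reduce SUM).
    sum_exp = torch.stack(
        [torch.exp(s - global_max.unsqueeze(-1)).sum(dim=-1) for s in shard_logits]
    ).sum(dim=0)

    # Only the shard that owns the target index contributes the predicted logit.
    predicted = torch.zeros_like(global_max)
    for rank, s in enumerate(shard_logits):
        local = targets - rank * shard_size
        owns = (local >= 0) & (local < shard_size)
        picked = s.gather(-1, local.clamp(0, shard_size - 1).unsqueeze(-1)).squeeze(-1)
        predicted = predicted + torch.where(owns, picked, torch.zeros_like(picked))

    return predicted - global_max - torch.log(sum_exp)


torch.manual_seed(0)
full = torch.randn(5, 8)  # (tokens, vocab)
targets = torch.tensor([0, 3, 4, 7, 2])
shards = list(full.chunk(2, dim=-1))  # two simulated TP ranks, 4 vocab entries each
merged = merge_sharded_log_prob(shards, targets, shard_size=4)
reference = torch.log_softmax(full, dim=-1).gather(-1, targets.unsqueeze(-1)).squeeze(-1)
```

The merged values should agree with the unsharded log_softmax reference, which is the property the per-token new_log_probs in the PyTorch step rely on.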

Triton is opt-in via a new LanguageModelGSPOLossConfig.use_triton field
(mirrors GRPO config).

Test coverage: `test_gspo_loss` now also runs the Triton path when
available — 20 cases pass under TRITON_INTERPRET=1.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>