Triton GSPO kernel#522
Open
jlamypoirier wants to merge 2 commits into
Open
Conversation
Triton backward kernel mirrors GRPO's backward — same softmax chain rule
through (softmax_k - delta_{k,target}), with the per-token IS ratio replaced
by the segment-broadcast R_{s(t)} and the loss mask scaled by 1/token_count_s
(token_weight). The forward pass reuses the existing
triton_cross_entropy_forward_from_labels_parallel_kernel to produce
max/sum/predicted_logit per token (with TP support via the same
parallel_sum_exp_logits dance as GRPO); segment aggregation, loss, and the
SDP all-reduce live in PyTorch between the two Triton passes.
Triton is opt-in via a new LanguageModelGSPOLossConfig.use_triton field
(mirrors GRPO config).
Test coverage: `test_gspo_loss` now also runs the Triton path when
available — 20 cases pass under TRITON_INTERPRET=1.
Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Follow-up to #517. Depends on
jlp_gspofor the PyTorch GSPO reference.Summary
Triton kernel for the GSPO loss, exposed via a new
LanguageModelGSPOLossConfig.use_tritonfield (matches GRPO's pattern).Design
Forward and backward are split because segment-level
R_sandA_sneed every contributing token before any per-token computation can finish. Triton doesn't sync across blocks, so that work lands in PyTorch between two Triton kernel launches:triton_cross_entropy_forward_from_labels_parallel_kernelto compute per-tokenmax_logits,sum_exp_logits,predicted_logit. With TP, the existingparallel_sum_exp_logits+ all-reduce dance merges the shards — same as the Triton GRPO TP path.new_log_probs,log_ratioper token; segment-aggregate viascatter_add; optional SDP all-reduce; computeR_s,A_s,token_weight = mask_t / token_count_s; compute the loss scalar.triton_gspo_loss_backward_kernel. Identical chain rule to GRPO's backward (grad_logit_i = effective_grad * (softmax_i - delta_{i,target})) but witheffective_gradbuilt from the segment-broadcastR_s,A_s, andtoken_weightinstead of per-tokenratio,advantage,mask.The kernel only fuses what's actually compute-bound — the softmax + chain-rule loops over the vocab dimension. The cheap per-segment work stays in PyTorch where readability wins.
Test plan
TRITON_INTERPRET=1 pytest tests/layers/test_lm_losses.py::test_gspo_loss— 20 cases pass against the reference and the PyTorch fused kernelpytest tests/layers/test_lm_losses.py::test_grpo_loss tests/layers/test_lm_losses.py::test_grpo_metrics— 60 cases pass (sanity check that adding the GSPO triton path didn't disturb GRPO)