RL training features (#502 minus GSPO) #520

Open
jlamypoirier wants to merge 27 commits into main from jlp_rl_features
@jlamypoirier
Collaborator

Follow-up to #517 (which took the GSPO core from #502 and reimplemented it cleanly from main).

What this PR is

The remainder of #502 with the GSPO-specific content removed. Built on the original gspo branch, with just three deletions applied on top.

What was removed

  • GSPO loss class, config, kernel, and dedicated test file: landed in #517 (Add GSPO loss) with a cleaner per-token decomposition.
  • LanguageModelKwargs.document_index and its data-pipeline plumbing: only needed for the GSPO kernel. #517 reads document_index_q from BlockKwargs instead.
  • normalize_by_documents flag on LanguageModelPolicyGradientLossConfig: this was the DeepSpeed-style (DS) /M loss with /M^2 gradient. #517 bakes /num_documents into GSPO unconditionally; the GRPO path here keeps the per-token-count normalization.
  • Kernel /sdp_size "fix": only existed in the GSPO kernel; deleted with it.

What's kept (unchanged from #502)

  • docs_per_step — schedule config field, Schedule._eff_* properties, Trainer._prefetch_to_doc_target, and the corresponding unit tests.
  • fp32_lm_head — head config field, forward upcast + manual weight-grad backward block in head.py.
  • grad_divisor parameter on fused_grpo_loss_forward_backward and triton_grpo_loss_forward_backward — allows the gradient to use a different divisor than the loss. Currently always defaults to divisor (callers no longer pass a different value), but the plumbing is in place.
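
To illustrate the grad_divisor plumbing described in the last bullet, here is a minimal pure-Python sketch, not the actual kernel. The function `toy_loss_forward_backward` and its arguments are hypothetical; only the idea that `grad_divisor` defaults to `divisor` is taken from the description above.

```python
def toy_loss_forward_backward(per_token_losses, per_token_grads, divisor, grad_divisor=None):
    """Toy stand-in for a fused forward/backward: returns (loss, grads).

    grad_divisor defaults to divisor, so existing callers see no change.
    """
    if grad_divisor is None:
        grad_divisor = divisor
    loss = sum(per_token_losses) / divisor
    grads = [g / grad_divisor for g in per_token_grads]
    return loss, grads


# Default: loss and gradient share one divisor.
loss, grads = toy_loss_forward_backward([2.0, 4.0], [1.0, -1.0], divisor=2.0)

# Explicit grad_divisor: the gradient is scaled independently of the loss.
loss_split, grads_split = toy_loss_forward_backward(
    [2.0, 4.0], [1.0, -1.0], divisor=2.0, grad_divisor=4.0
)
```

With the default, callers that never pass `grad_divisor` get the same loss and gradient scaling as before, which matches the "currently always defaults to divisor" note.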

Diff stats

7 files changed, 11 insertions(+), 808 deletions(-) — net delete-only.

Test plan

  • pytest tests/layers/test_lm_losses.py::test_grpo_loss tests/layers/test_lm_losses.py::test_grpo_metrics tests/layers/test_docs_per_step.py — 70 cases pass

Note: this branch is behind main by one commit (#508). Will rebase before merge.

bigximik and others added 26 commits April 27, 2026 12:56
Adds GRPO metrics parity with DeepSpeed: old_logprobs, ratio, ratio_sum,
ratio_sq_sum, kl_new_old, clamp_frac, advantage, max/min_advantage,
num_tokens, and optional per-token entropy.

New files:
- fast_llm/layers/language_model/loss/pg_metrics.py: reusable
  PolicyGradientMetrics dataclass + compute_policy_gradient_metrics()
  (callable by future PPO), with chunked vocab-parallel entropy support.
- tests/layers/test_grpo_metrics.py: 8 unit tests covering single-seq,
  packed multi-seq, masked tokens, clamp fraction, entropy correctness,
  mock SDP correctness, mock vocab-parallel entropy, normalization parity.

Config additions to LanguageModelGRPOLossConfig:
- compute_extra_metrics (default False): log all non-entropy metrics
- compute_entropy_metric (default False): additionally log per-token entropy
- entropy_chunk_size (default 4096): batch chunk size for entropy pass

Normalization matches existing new_logprobs_mean: sum(v*mask/label_counts)
then divided by num_documents_in_batch. MAX/MIN use LossDef ReductionType
and correct ReduceOp so they aggregate correctly across microbatches and
SDP/sequence-parallel ranks.
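
The normalization rule above can be sketched in plain Python. This is an illustration only: `normalize_metric` is a hypothetical helper, and it assumes `label_counts` holds, for each token, the number of labelled tokens in that token's own document, so the result is a mean of per-document means.

```python
def normalize_metric(values, mask, label_counts, num_documents_in_batch):
    """sum(v * mask / label_counts) / num_documents_in_batch, as described above."""
    total = sum(v * m / c for v, m, c in zip(values, mask, label_counts))
    return total / num_documents_in_batch


# Two packed documents: 2 labelled tokens in the first, 3 labelled tokens
# (plus one masked token) in the second.
values = [1.0, 3.0, 2.0, 2.0, 2.0, 100.0]
mask = [1, 1, 1, 1, 1, 0]
label_counts = [2, 2, 3, 3, 3, 3]  # assumed meaning: labelled tokens per document
metric = normalize_metric(values, mask, label_counts, num_documents_in_batch=2)
```

The masked token (value 100.0) contributes nothing, and each document's mean is 2.0, so the metric comes out at about 2.0 regardless of document length.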
Rename four metrics to match DeepSpeed's naming exactly so runs on both
backends produce comparable WandB keys:

  ratio        → ratio_new_old
  ratio_sum    → ratio_new_old_sum
  ratio_sq_sum → ratio_new_old_squared_sum
  clamp_frac   → clamp_log_ratio_new_old_indicator
Implements GSPO (geometric-mean sequence-level policy-gradient loss) as
an alternative to the existing per-token GRPO clipping. Controlled via
LanguageModelGRPOLossConfig.policy_loss = "gspo".

Key changes:
- data pipeline: expose per-token document_index when return_document_index=True
- LanguageModelKwargs.document_index: new kwarg constant
- LanguageModelLoss: store SDP dim for cross-rank segment aggregation
- grpo.py: fused_gspo_loss_forward_backward with all_reduce(SUM) across
  SDP ranks before computing segment-level R_s and A_s; gradient derivation
  exploits tok_count cancellation so every token in a segment gets the
  same gradient factor R_s * clip_indicator_s
- tests/layers/test_gspo_loss.py: 8 unit tests (single-segment, packed,
  ratio-1 equivalence, clipping, masking, SDP mock, gradient finite-diff,
  per-token metrics unchanged)
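
A rough single-process sketch of the segment-level quantities described above, under stated assumptions: the names `lrn_sum`, `adv_sum` and `tok_sum` follow the commit text, while the function `gspo_segment_terms`, the default epsilons, and the PPO-style clipped surrogate are illustrative guesses rather than the actual kernel.

```python
import math


def gspo_segment_terms(log_ratios, advantages, segment_ids, eps_low=0.2, eps_high=0.2):
    """Per-segment (R_s, A_s, loss) with a geometric-mean sequence ratio."""
    lrn_sum, adv_sum, tok_sum = {}, {}, {}
    for log_ratio, advantage, seg in zip(log_ratios, advantages, segment_ids):
        lrn_sum[seg] = lrn_sum.get(seg, 0.0) + log_ratio
        adv_sum[seg] = adv_sum.get(seg, 0.0) + advantage
        tok_sum[seg] = tok_sum.get(seg, 0) + 1
    # In the distributed case, the three sums would be all-reduced (SUM)
    # across SDP ranks at this point, before any per-segment math.
    terms = {}
    for seg, count in tok_sum.items():
        ratio = math.exp(lrn_sum[seg] / count)  # geometric mean of token ratios
        advantage = adv_sum[seg] / count
        clipped = min(max(ratio, 1.0 - eps_low), 1.0 + eps_high)
        loss = -min(ratio * advantage, clipped * advantage)
        terms[seg] = (ratio, advantage, loss)
    return terms


terms = gspo_segment_terms(
    log_ratios=[0.0, 0.0, math.log(2.0), math.log(2.0)],
    advantages=[1.0, 1.0, 1.0, 1.0],
    segment_ids=[0, 0, 1, 1],
)
```

Segment 0 has ratio 1 and is unclipped; segment 1 has ratio 2 and is clipped at 1 + eps_high, so its loss stops depending on the ratio.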
Adds ScheduleConfig.rollouts_per_step (default 0). When >0, TrainerConfig._from_dict
computes depth_first_micro_batches = rollouts_per_step // (batch_data_parallel ×
breadth_first_micro_batches) before sub-configs are created (and frozen).

Matches DeepSpeed gradient_accumulation_passes semantics for RL: with train_batch_size=1
each microbatch holds one rollout, so setting rollouts_per_step=1024 with data_parallel=8
gives depth_first_micro_batches=128 → exactly 1024 rollouts per optimizer step globally.

YAML usage:
  schedule:
    rollouts_per_step: 1024   # replaces manual depth_first_micro_batches
  model:
    distributed:
      data_parallel: 8        # used for the division
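
The division above, worked through as a small sketch. The helper name is hypothetical and the default of one breadth-first micro-batch is an assumption made for the example.

```python
def depth_first_micro_batches(rollouts_per_step, batch_data_parallel, breadth_first_micro_batches=1):
    """rollouts_per_step // (batch_data_parallel * breadth_first_micro_batches)."""
    return rollouts_per_step // (batch_data_parallel * breadth_first_micro_batches)


depth = depth_first_micro_batches(1024, batch_data_parallel=8)

# With one rollout per microbatch, rollouts per optimizer step is the
# per-rank microbatch count times the number of data-parallel ranks.
global_rollouts = depth * 8
```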
- Rename rollouts_per_step → docs_per_step in ScheduleConfig; depth_first
  is now determined at runtime rather than statically in _from_dict
- Add Schedule._depth_first_override and _eff_{depth_first,sequential,num_inputs}
  properties so per-step schedules share the same config object as the runner
- Add Trainer._prefetch_to_doc_target: fetches microbatches one at a time,
  all-reduces doc count per microbatch, stops when global total ≥ docs_per_step,
  then resets num_documents_in_batch to the step total on all inputs
- Add Trainer._get_or_build_schedule: builds and caches per-N Schedule with
  _depth_first_override=N//breadth_first_micro_batches
- Add normalize_by_documents flag to LanguageModelGRPOLossConfig; when True
  both GRPO and GSPO paths divide by num_documents_in_batch instead of
  num_labels_in_batch (matches DeepSpeed's per-rollout normalization)
- Add tests/layers/test_docs_per_step.py: 13 unit tests covering divisor
  scaling, normalize_by_documents layer routing, Schedule._eff_* properties,
  and _prefetch_to_doc_target accumulation logic
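
A simplified, single-process sketch of the accumulation logic described for Trainer._prefetch_to_doc_target. Everything here is hypothetical scaffolding: data-parallel ranks are modelled as plain iterators of per-microbatch document counts, and the all-reduce is replaced by a local sum.

```python
def prefetch_to_doc_target(per_rank_streams, docs_per_step):
    """Fetch one microbatch per rank at a time until the global document
    count reaches docs_per_step.

    Returns (microbatches fetched per rank, global document total).
    """
    total_docs = 0
    num_micro_batches = 0
    while total_docs < docs_per_step:
        # Stand-in for the all-reduce of this microbatch's document count.
        total_docs += sum(next(stream) for stream in per_rank_streams)
        num_micro_batches += 1
    return num_micro_batches, total_docs


streams = [iter([3, 2, 4, 1]), iter([1, 5, 2, 2])]
fetched, total = prefetch_to_doc_target(streams, docs_per_step=10)
```

The loop can overshoot the target (11 documents for a target of 10 here); the step total is what would then be written back as num_documents_in_batch on every input.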
Add temperature field to LanguageModelGRPOLossConfig. When set to match
the actor's sampling temperature (e.g. 0.7), new log-probs are computed
at the same temperature as the stored old log-probs, so the IS ratio
starts near 1.0 instead of ~1.08.

Implementation: _effective_logits_scale = logits_scale_factor / temperature,
substituted for logits_scale_factor at all three callsites in
_forward_backward (GRPO path, GSPO path, _register_pg_metrics). Default
temperature=1.0 preserves existing behaviour exactly.
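
A toy illustration of why the temperature must match, using plain Python instead of the real kernels. The logits and helper names are made up, and the numbers are not meant to reproduce the ~1.08 figure quoted above; only the formula effective scale = logits_scale_factor / temperature comes from the commit text.

```python
import math


def log_softmax(logits, scale=1.0):
    scaled = [scale * x for x in logits]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - log_z for x in scaled]


logits = [2.0, 1.0, 0.0]
token = 0
sampling_temperature = 0.7
logits_scale_factor = 1.0

# Old log-prob as stored by the actor, i.e. computed at the sampling temperature.
old_logprob = log_softmax(logits, 1.0 / sampling_temperature)[token]


def is_ratio(temperature):
    effective_logits_scale = logits_scale_factor / temperature
    new_logprob = log_softmax(logits, effective_logits_scale)[token]
    return math.exp(new_logprob - old_logprob)


ratio_default = is_ratio(1.0)                   # trainer ignores the sampling temperature
ratio_matched = is_ratio(sampling_temperature)  # trainer matches it
```

With identical weights, the matched ratio is exactly 1, while the default temperature gives a ratio that is off before any training has happened.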
Add fp32_lm_head to LanguageModelHeadConfig. When enabled, input hidden
states and output_weights are cast to float32 before the lm_head linear,
producing FP32 logits. This matches vLLM's bf16_last_layer_fp32
quantization (pipelinerl/vllm_quantization.py) and the DeepSpeed trainer's
apply_fp32_lm_head() patch, so new_logprobs and old_logprobs are computed
at the same numerical precision and the IS ratio starts near 1.0 at init.

The gradient flowing back through the linear is cast to the original
input dtype (bf16) before returning, keeping the transformer backward pass
in its native dtype.
…accumulation

Detaching the FP32 weight copy (requires_grad=False) prevents
output_parallel_linear_backward from trying to write to a non-existent
grad_buffer on the copy. Weight grad is then computed explicitly from
the FP32 matmul and accumulated into the original BF16 param's grad_buffer
via accumulate_gradient, restoring the correct FSDP gradient contract.
When normalize_by_documents=true, Fast-LLM's reported grad_norm was ~1024×
larger than DeepSpeed's for the equivalent loss, causing the default
gradient_norm_clipping=0.3 to over-clip by ~500× and leaving training ~10
reward points behind DS GSPO at the same step count. The lm_head_loss
metric was also off: 1024× smaller than DS's rl/loss in the previous
divisor=num_documents² formulation, then 2× too large from SDP doubling.

Root cause analysis
-------------------

DeepSpeed has TWO 1/batch_size factors with different sources:

  1. Loss reported (rl/loss) uses /batch_size via tokens_weights = 1/batch_size
     (pipelinerl/finetune/rl/__init__.py:246). The reported `rl/loss = -1.7`
     value is the raw policy_loss_total, divided once by batch_size.

  2. Gradient buffer uses an ADDITIONAL /(gas × world_size) factor that comes
     from `scale_wrt_gas=True` in engine.backward()
     (deepspeed/runtime/engine.py:1995-1996) and `tensor.div_(world_sz)` in
     reduce_scatter_coalesced (deepspeed/runtime/comm/coalesced_collectives.py:124).

For DS with samples_per_microbatch=1 (PipelineRL standard), gas × world_size
= batch_size, so DS's effective gradient buffer factor is 1/batch_size² while
the loss metric factor is 1/batch_size. Loss and gradient have asymmetric
scaling.

Fast-LLM's existing implementation used a single `divisor` for both loss and
gradient. Worse, the data_parallel × grad_scale factor in grad_output
(runner.py:318) cancels with FSDP's RS-AVG /world_size, structurally removing
DS's /(gas × world_size) factor from the gradient. So Fast-LLM's gradient
buffer ended up at 1/batch_size while DS's was at 1/batch_size², a
~batch_size = 1024× mismatch.
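
The scaling argument above, written out with exact fractions. The world_size value is hypothetical; the analysis only relies on gas × world_size equalling batch_size when samples_per_microbatch=1.

```python
from fractions import Fraction

batch_size = 1024
world_size = 32                   # hypothetical; only gas * world_size matters
gas = batch_size // world_size    # one sample per microbatch

ds_loss_factor = Fraction(1, batch_size)                          # tokens_weights
ds_grad_factor = ds_loss_factor * Fraction(1, gas * world_size)   # extra /(gas * world_size)
fast_llm_grad_factor = Fraction(1, batch_size)                    # single shared divisor

mismatch = fast_llm_grad_factor / ds_grad_factor
```

Here `mismatch` equals batch_size, which is the ~1024× grad_norm gap reported at the top of this commit message.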

Additionally, GSPO's SDP allreduce of lrn_sum/adv_sum/tok_sum makes both SDP
ranks compute IDENTICAL per-segment loss values. When LossDef.reduce sums
over the data_group (which includes SDP ranks), the loss metric is
double-counted by sdp_size. The gradient buffer is NOT double-counted —
each SDP rank contributes gradient from its own LOCAL tokens, with different
contributions for different tokens of the same segment.

Fixes
-----

1. Add a `grad_divisor` parameter to `fused_gspo_loss_forward_backward`,
   `fused_grpo_loss_forward_backward`, and `triton_grpo_loss_forward_backward`,
   defaulting to `divisor` (existing behavior). Allows the gradient to use a
   different divisor than the loss.

2. In `LanguageModelGRPOLoss._forward_backward`, when normalize_by_documents
   is True, set:
     loss divisor      = num_documents_in_batch     (matches DS rl/loss)
     gradient divisor  = num_documents_in_batch²    (matches DS grad_norm)
   This is independent of TP/PP/SDP/DP parallelism and microbatching schedule
   because batch_size is invariant under all of these.

3. In the GSPO path, divide the loss by sdp_size when sdp_group is active
   (`fused_gspo_loss_forward_backward`). This pre-cancels the SDP doubling
   that LossDef.reduce's SUM over data_group introduces. The gradient is
   unaffected — different SDP ranks naturally contribute gradient from
   different LOCAL token positions, no double-counting at any layer.
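
Fix 3 can be pictured with a tiny single-process sketch. The `reduce_sum` helper is a stand-in for the SUM over the data group, and all values are made up for illustration.

```python
def reduce_sum(values):
    """Stand-in for a SUM reduction over the data group."""
    return sum(values)


sdp_size = 2
# After the all-reduce of the segment sums, every SDP rank holds this same value.
segment_loss = -1.7

# Without the fix, each rank reports the full loss and the SUM counts it sdp_size times.
reported_before = reduce_sum([segment_loss] * sdp_size)

# With the fix, each rank pre-divides by sdp_size so the SUM recovers the true value.
reported_after = reduce_sum([segment_loss / sdp_size] * sdp_size)

# Gradients need no correction: each rank contributes only its own local tokens.
local_token_grads = [[0.1, 0.2], [0.3]]  # rank 0 tokens, rank 1 tokens
total_grad = reduce_sum(sum(rank) for rank in local_token_grads)
```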

Verification
------------

Tested on 7B math run with 4 nodes, GSPO, gradient_norm_clipping=0.3:

  Metric           | Before fix           | After fix                                         | DS GSPO reference
  ---------------- | -------------------- | ------------------------------------------------- | ---------------------
  step 1 grad_norm | 141                  | 0.135                                             | 0.145
  step 1 loss      | lm_head_loss = -13.7 | lm_head_loss ~ -1.7 (sign varies per data sample) | rl/loss = -1.7
  clip_coeff       | 0.002                | 1.000                                             | no clipping at step 1
  newlp at step 50 | trapped at -0.17     | -0.103                                            | -0.105

newlp trajectory tracks DS step-by-step: step 1 within 3%, step 50 within 2%.
Both systems show grad_norm spikes at the same training phase (steps 14-20)
during warmup ramp-up — DS step 16 grad_norm=6.365 vs Fast-LLM 6.093.
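
The clip_coeff values in the table are consistent with ordinary gradient-norm clipping. A sketch, assuming the usual min(1, max_norm / grad_norm) rule (the exact Fast-LLM formula, for example any epsilon term, is not shown in this message):

```python
def clip_coefficient(grad_norm, max_norm):
    """Standard norm clipping: scale gradients by min(1, max_norm / grad_norm)."""
    return min(1.0, max_norm / grad_norm)


before = clip_coefficient(141.0, 0.3)   # step 1 grad_norm before the fix
after = clip_coefficient(0.135, 0.3)    # step 1 grad_norm after the fix
```

Before the fix the gradient is scaled down by roughly 0.002 (the ~500× over-clip); after the fix the norm is below the threshold and no clipping happens.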

Files changed
-------------

- fast_llm/layers/language_model/loss/grpo.py:
  - LanguageModelGRPOLoss._forward_backward: split divisor and grad_divisor
    based on normalize_by_documents flag, with detailed comments referencing
    the corresponding lines in DeepSpeed and PipelineRL.
  - fused_gspo_loss_forward_backward: add grad_divisor parameter; divide loss
    by sdp_size when sdp_group is active.
  - fused_grpo_loss_forward_backward: add grad_divisor parameter.

- fast_llm/functional/triton/grpo_loss.py:
  - triton_grpo_loss_forward_backward: add grad_divisor parameter.
- Inline pg_metrics.py into grpo.py; rename to GRPOMetrics
- Drop entropy_chunk_size; reuse fused_softmax_base outputs for entropy
- Replace two bool flags with a single metrics: GRPOMetricsLevel enum
- Rename clamp_log_ratio_new_old_indicator -> clipped_ratio_fraction
- Raise on metrics enabled with pipeline_parallel > 1 (MAX/MIN reduce
  would be corrupted by the zero placeholder on empty pipeline ranks)
- Migrate tests into tests/layers/test_lm_losses.py, reusing the
  existing helpers and parametrization (single + distributed runner)
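
Why a zero placeholder breaks MAX/MIN reductions, as a minimal sketch with made-up values. The -inf sentinel line is shown only for contrast; this commit raises on pipeline_parallel > 1 instead.

```python
import math


def reduce_max(values):
    """Stand-in for a MAX reduction across ranks."""
    return max(values)


# All advantages on the loss-computing rank are negative in this example.
true_max_advantage = -0.5

# An empty pipeline rank contributes a zero placeholder, which wins the MAX.
with_placeholder = reduce_max([true_max_advantage, 0.0])

# -inf is the identity for max, so a sentinel would leave the result intact.
with_sentinel = reduce_max([true_max_advantage, -math.inf])
```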

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
- Drop stale "second softmax pass" overhead note from `metrics`
  description (entropy now reuses the base softmax outputs)
- De-mirror max/min in reference_grpo_metrics: use
  advantages[loss_mask].max()/.min() instead of the implementation's
  -inf/+inf sentinel pattern

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
- Align (logits, target, advantages, old_log_probabilities, ...) order
  across compute_grpo_metrics, fused_grpo_loss_forward_backward, and
  reference_grpo_metrics
- Replace **kwargs in LanguageModelGRPOLoss.__init__ with the explicit
  keyword-only signature mirroring LanguageModelLoss.__init__
- num_docs -> num_documents
- Drop the comment that restated the k3 KL formula
- Give compute_grpo_metrics the same defaults as the loss kernel
- Trim the metrics field description to category-level wording
- Always exercise varying label_counts in _test_grpo_metrics so per-token
  denominator broadcasting is covered
- reference_grpo_metrics returns GRPOMetrics; comparison loop iterates
  dataclasses.fields
- Drop name = self._name micro-rebinds; use self._name inline
- defs = super()...; defs.append(...); defs.extend(...) consistently
- Tighten _register_extra_metrics losses type to dict[str, list[Tensor]]
- Split compiled tuple-returning core from outer GRPOMetrics wrapper to
  avoid @torch.compile graph-breaks on dataclass construction
- One-line comment on the metrics gate explaining the softmax-skip

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
NamedTuple is a tuple subclass that dynamo handles natively, so the
previous wrapper/inner split (added to dodge a dataclass graph-break)
collapses into one @torch.compile function. Field order now lives
exactly once — on the class.
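
A small stdlib-only sketch of the property this relies on: a NamedTuple instance is an ordinary tuple, and the field order is declared once on the class. `ToyMetrics` and its fields are hypothetical, and torch.compile itself is not exercised here.

```python
from typing import NamedTuple


class ToyMetrics(NamedTuple):
    ratio: float
    kl: float
    num_tokens: int


def compute_toy_metrics(values):
    # Returning the NamedTuple directly: it is a plain tuple underneath,
    # so no separate tuple-returning inner function is needed.
    n = len(values)
    return ToyMetrics(ratio=sum(values) / n, kl=0.0, num_tokens=n)


metrics = compute_toy_metrics([1.0, 3.0])

# Consumers can iterate fields in declaration order without restating it.
as_dict = dict(zip(ToyMetrics._fields, metrics))
```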

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
- Entropy under vocab-parallel TP was wrong: the dot-product term
  (exp_logits * logits_norm).sum(-1) summed only the local vocab slice,
  so dividing by the global sum_exp_logits gave a per-rank fragment
  instead of the full E_p[logit_norm]. All-reduce the partial sum.
- Replace the verbose pipeline-parallel guard with Assert.custom; the
  field description already explains the constraint.
- Drop the cryptic `# k3` comment.
- Match _register_extra_metrics losses annotation to the base class
  (dict | None).
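
The entropy fix in the first bullet can be checked with a pure-Python sketch in which the vocabulary is split into shards, one per simulated rank. Function names are hypothetical and the all-reduces are replaced by local sums; the identity used is H = log(sum_exp) - sum(exp(z) * z) / sum_exp with z the max-normalized logits.

```python
import math


def entropy_vocab_parallel(logit_shards):
    """Entropy of softmax(logits) with the vocab split across ranks."""
    global_max = max(max(shard) for shard in logit_shards)   # all-reduce MAX
    norm_shards = [[x - global_max for x in shard] for shard in logit_shards]
    local_sum_exp = [sum(math.exp(z) for z in shard) for shard in norm_shards]
    local_dot = [sum(math.exp(z) * z for z in shard) for shard in norm_shards]
    sum_exp = sum(local_sum_exp)   # all-reduce SUM
    dot = sum(local_dot)           # all-reduce SUM: the step this fix adds
    return math.log(sum_exp) - dot / sum_exp


def entropy_reference(logits):
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return -sum(math.exp(x - log_z) * (x - log_z) for x in logits)


logits = [2.0, 1.0, 0.0, -1.0]
sharded = entropy_vocab_parallel([logits[:2], logits[2:]])
full = entropy_reference(logits)
```

Dropping the sum over `local_dot` and using a single rank's partial value instead reproduces the bug: the numerator covers only the local vocab slice while the denominator is global.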

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
# Conflicts:
#	fast_llm/layers/language_model/loss/config.py
#	fast_llm/layers/language_model/loss/grpo.py
# Conflicts:
#	fast_llm/layers/language_model/loss/config.py
#	fast_llm/layers/language_model/loss/grpo.py
- Drop unused self._preprocessing_config store in Trainer.setup.
- Replace torch.ones + index_add_ with torch.bincount for tok_sum
  in fused_gspo_loss_forward_backward.
- Drop load-bearing-sounding docs_per_step reference from the
  normalize_by_documents field description (no cross-config check
  exists to enforce it).

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Splits the policy-gradient loss config and class hierarchy:

- LanguageModelPolicyGradientLossConfig (abstract base): shared fields
  (epsilon_low/high, metrics, normalize_by_documents, temperature).
- LanguageModelGRPOLossConfig: registers `type: grpo` (keeps GRPO-only
  use_triton).
- LanguageModelGSPOLossConfig: registers `type: gspo`.
- LanguageModelPolicyGradientLoss (abstract base): shared
  __init__/_forward_backward/_register_extra_metrics/get_loss_definitions/
  get_preprocessing_config plumbing; abstract `_call_kernel`.
- LanguageModelGRPOLoss / LanguageModelGSPOLoss: each implements
  `_call_kernel` against its kernel; GSPO overrides
  `get_preprocessing_config` to add `return_document_index`.

Drops the stringly-typed `policy_loss: str` switch and the in-method
if/else dispatch, addressing review items #1 and #5 plus Note 2.

YAML migration: `type: grpo` + `policy_loss: gspo` → `type: gspo`.
No checked-in YAML configs use the old form.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Replaces the abstract `_call_kernel` + per-algorithm subclass pattern
with the assignment-at-init pattern used by `Normalization._forward`.

- Single LanguageModelPolicyGradientLoss class hosts both kernel calls
  as `_call_grpo_kernel` and `_call_gspo_kernel`.
- __init__ assigns `self._call_kernel` to the matching method based on
  isinstance(config, LanguageModelGSPOLossConfig).
- get_preprocessing_config dispatches inline on the same isinstance.
- Both LanguageModelGRPOLossConfig and LanguageModelGSPOLossConfig
  return the same loss class — the YAML-side type split (registered
  via @config_class(dynamic_type=...)) stays as in #1.

Drops ~30 lines net from grpo.py: removes the abstract `_call_kernel`
declaration and the two single-method subclasses.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Reverts the class merge from d2c051a in favor of the assignment-at-init
pattern used by Normalization._forward. Drops the per-call _call_kernel
wrapper that just shuffled args.

- LanguageModelPolicyGradientLoss now hosts only shared scaffolding:
  _compute_divisors (token vs document), _shared_kernel_kwargs (the
  9 kwargs both kernels accept), _finalize_loss (post-call register
  + extra metrics), and the per-token metrics machinery.
- LanguageModelGRPOLoss and LanguageModelGSPOLoss are restored. Each
  __init__ assigns self._forward to the actual kernel function:
    GRPO: triton_grpo_loss_forward_backward or fused_grpo_loss_forward_backward
    GSPO: fused_gspo_loss_forward_backward
- Each subclass's _forward_backward calls self._forward(...) directly
  with the kernel's real signature; no intermediate wrapper.
- Configs map type:grpo → LanguageModelGRPOLoss, type:gspo →
  LanguageModelGSPOLoss again.
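
The assignment-at-init pattern described above, reduced to a toy. All class and function names are placeholders; the real kernels and their signatures are not reproduced.

```python
def triton_kernel(x):
    return ("triton_grpo", x)


def fused_grpo_kernel(x):
    return ("fused_grpo", x)


def fused_gspo_kernel(x):
    return ("fused_gspo", x)


class ToyGRPOLoss:
    def __init__(self, use_triton=False):
        # Bind the kernel once at construction: no per-call dispatch or wrapper.
        self._forward = triton_kernel if use_triton else fused_grpo_kernel

    def forward_backward(self, x):
        # Call the kernel directly with its own signature.
        return self._forward(x)


class ToyGSPOLoss:
    def __init__(self):
        self._forward = fused_gspo_kernel

    def forward_backward(self, x):
        return self._forward(x)


results = [
    ToyGRPOLoss().forward_backward(1),
    ToyGRPOLoss(use_triton=True).forward_backward(1),
    ToyGSPOLoss().forward_backward(1),
]
```

Because the function is stored as an instance attribute, it is called without an implicit self, so each subclass can pass exactly the arguments its kernel expects.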

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
# Conflicts:
#	fast_llm/layers/language_model/config.py
#	fast_llm/layers/language_model/head.py
GSPO core landed via #517 (cleaner reimplementation from main). Drop the
GSPO loss class, config, kernel, dedicated test file, and the supporting
LanguageModelKwargs.document_index plumbing (#517 reads document_index_q
from BlockKwargs instead).

Also drop two GSPO-specific knobs that no longer apply once GSPO is
removed:

- normalize_by_documents on LanguageModelPolicyGradientLossConfig — was
  GRPO/GSPO's DS-style /M loss with /M^2 gradient. The GSPO loss in #517
  bakes /num_documents in unconditionally and the GRPO path here keeps
  the per-token-count normalization.
- The kernel's /sdp_size "fix" only existed in the GSPO kernel (global
  per-segment loss made identical on every SDP rank); deleted with the
  GSPO kernel.

The rest of #502 (docs_per_step, fp32_lm_head, grad_divisor parameter on
GRPO kernels) is preserved as-is for follow-up review.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
@jlamypoirier jlamypoirier mentioned this pull request May 19, 2026
4 tasks