Skip to content

Canonicalize varlen cu_seqlens_k; share K/V buffer across micro-sequences#514

Open
jlamypoirier wants to merge 2 commits into
mainfrom
jlp_varlen-cu-seqlens-fix
Open

Canonicalize varlen cu_seqlens_k; share K/V buffer across micro-sequences#514
jlamypoirier wants to merge 2 commits into
mainfrom
jlp_varlen-cu-seqlens-fix

Conversation

@jlamypoirier
Copy link
Copy Markdown
Collaborator

@jlamypoirier jlamypoirier commented May 14, 2026

Summary

Two related cleanups on top of the SDPA causal-alignment fix already in main (#512):

  • Canonical cu_seqlens_k. Preprocessor now emits cu_seqlens_k[0] = 0 (was first_document_begin), the layout every public varlen attention API documents. document_index_k / position_index are recomputed against the narrowed K extent so they stay consistent. Downstream consumers (kernel inputs, backup mask cols) narrow K by first_document_begin to match.
  • Pre-allocated per-layer K/V buffer. One buffer per attention layer is reused across all micro-sequences of a sequence. Each forward writes the SDP-gather result into the next slice via gather_op(out=); backward accumulates per-micro-seq K/V grad into a shared grad buffer slice. The leading + trailing narrows and the per-step torch.cat / AttachGrad cross-micro-seq splice are absorbed into _query_key_value's custom autograd region — no eager cat, no detached-leaf graph edge per step.

Test plan

  • tests/data/test_preprocessing.py — 843/843 pass, updated to expect canonical layout.
  • tests/layers/test_attention.py — 57/57 pass (CPU + CUDA), including new first_document_begin regression case.
  • New regression test test_attention[first_document_begin-[4, 1, 10]] — injects a fake past K/V slot, drives attention with sequence_k_past > 0 and first_document_begin > 0, verifies output + parameter grads match a per-doc reference and that slot.grad_buffer[:past_length] is exactly zero (the specific guarantee of the narrow). Passes for backup, sdpa_dense, flash, sdpa_nested.

🤖 Generated with Claude Code

Base automatically changed from jlp_sdpa-attention to main May 19, 2026 18:11
The data preprocessor emitted `cu_seqlens_k[0] = first_document_begin` rather
than 0, violating the canonical varlen prefix-sum layout required by every
public varlen attention API. SDPA's EFFICIENT backward writes corrupted dK/dV
rows when fed this layout, propagating wrong gradients through the K/V
projection's reduce-scatter under sequence-data-parallel + micro-batch splits.

Three changes that compose:
- `LengthModelInputPreprocessor` now produces `cu_seqlens_k` starting at 0 and
  narrows `document_index_k` / `position_index` to the active K extent. The
  dropped leading-prefix length is exposed as a new `first_document_begin` int
  kwarg.
- Pre-allocate one K/V buffer per attention layer across all micro-sequences
  of a sequence. Each forward writes the SDP-gather result into the next slice
  via `gather_op(out=)`; backward accumulates each micro-seq's K/V grad into
  a shared grad buffer slice. The leading + trailing narrows and the per-step
  `torch.cat` / `AttachGrad` workaround for the cross-micro-seq splice are all
  absorbed into the `_query_key_value` custom autograd region.
- `_preprocess_for_backup_attention` builds the attention mask against the
  narrowed K cols so `sdpa_dense` and `backup` consume the same K extent as
  flash and `sdpa_nested`.

Update `tests/data/test_preprocessing.py` to expect the canonical layout.
`_test_first_document_begin` injects a fake past K/V slot with arbitrary leading
data, drives attention through a manually-built kwargs with `sequence_k_past`
and `first_document_begin` both set to a non-zero `past_length`, and verifies:
- forward output matches a per-doc reference computed on the active documents
  alone (the dropped prefix has no observable effect),
- parameter gradients match the reference,
- the K/V grad buffer at `[:past_length]` is exactly zero — the specific
  guarantee of the cu_seqlens_k canonicalization fix.

Runs backup + sdpa_dense on fp32, flash + sdpa_nested on bf16 (flash rejects
fp32). Plugged into the existing `test_attention` parametrization as a new
case with `name="first_document_begin"`, dispatched via name check.
@jlamypoirier jlamypoirier force-pushed the jlp_varlen-cu-seqlens-fix branch from 5761ff6 to e09bea7 Compare May 19, 2026 18:51
@jlamypoirier jlamypoirier changed the title Fix non-canonical cu_seqlens_k from preprocessor Canonicalize varlen cu_seqlens_k; share K/V buffer across micro-sequences May 19, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant