Canonicalize varlen cu_seqlens_k; share K/V buffer across micro-sequences #514
Open
jlamypoirier wants to merge 2 commits into
The data preprocessor emitted `cu_seqlens_k[0] = first_document_begin` rather than 0, violating the canonical varlen prefix-sum layout required by every public varlen attention API. SDPA's EFFICIENT backward writes corrupted dK/dV rows when fed this layout, propagating wrong gradients through the K/V projection's reduce-scatter under sequence-data-parallel + micro-batch splits.

Three changes that compose:

- `LengthModelInputPreprocessor` now produces `cu_seqlens_k` starting at 0 and narrows `document_index_k` / `position_index` to the active K extent. The dropped leading-prefix length is exposed as a new `first_document_begin` int kwarg.
- Pre-allocate one K/V buffer per attention layer across all micro-sequences of a sequence. Each forward writes the SDP-gather result into the next slice via `gather_op(out=)`; backward accumulates each micro-seq's K/V grad into a shared grad buffer slice. The leading + trailing narrows and the per-step `torch.cat` / `AttachGrad` workaround for the cross-micro-seq splice are all absorbed into the `_query_key_value` custom autograd region.
- `_preprocess_for_backup_attention` builds the attention mask against the narrowed K cols so `sdpa_dense` and `backup` consume the same K extent as flash and `sdpa_nested`.

Update `tests/data/test_preprocessing.py` to expect the canonical layout.
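The layout change in the first bullet can be illustrated with a minimal pure-Python sketch. The helper names (`cu_seqlens_from_lengths`, `canonicalize`, `document_index`) are hypothetical and are not taken from the repository; plain lists stand in for tensors, and the per-position document index is only an assumption about what `document_index_k` looks like once narrowed.

```python
from itertools import accumulate


def cu_seqlens_from_lengths(lengths):
    """Canonical varlen prefix sum: [0, l0, l0 + l1, ...]."""
    return [0, *accumulate(lengths)]


def canonicalize(cu_seqlens_k):
    """Shift boundaries so they start at 0 and report the dropped prefix length.

    Hypothetical helper: it mirrors the idea of exposing the leading offset
    as a separate `first_document_begin` value instead of encoding it in
    `cu_seqlens_k[0]`.
    """
    first_document_begin = cu_seqlens_k[0]
    shifted = [boundary - first_document_begin for boundary in cu_seqlens_k]
    return shifted, first_document_begin


def document_index(cu_seqlens):
    """Per-position document id over the narrowed K extent (illustrative only)."""
    return [
        doc
        for doc, (begin, end) in enumerate(zip(cu_seqlens, cu_seqlens[1:]))
        for _ in range(end - begin)
    ]


# Non-canonical layout: the first active document begins at K position 3.
legacy = [3, 7, 12]
canonical, first_document_begin = canonicalize(legacy)
narrowed_document_index = document_index(canonical)
```

Under these assumptions, a consumer would skip the first `first_document_begin` K rows and pass the shifted boundaries to the kernel, so both describe the same active extent.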
`_test_first_document_begin` injects a fake past K/V slot with arbitrary leading data, drives attention through a manually-built kwargs with `sequence_k_past` and `first_document_begin` both set to a non-zero `past_length`, and verifies:

- forward output matches a per-doc reference computed on the active documents alone (the dropped prefix has no observable effect),
- parameter gradients match the reference,
- the K/V grad buffer at `[:past_length]` is exactly zero, the specific guarantee of the `cu_seqlens_k` canonicalization fix.

Runs backup + sdpa_dense on fp32, flash + sdpa_nested on bf16 (flash rejects fp32). Plugged into the existing `test_attention` parametrization as a new case with `name="first_document_begin"`, dispatched via name check.
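The zero-prefix guarantee checked by this test can be pictured with a toy model of the shared buffer. This is a sketch under stated assumptions, not the repository's implementation: `SharedKVBuffer` and its methods are hypothetical, plain Python lists stand in for tensors, and a constant gradient of 1.0 per attended row replaces the real attention backward.

```python
class SharedKVBuffer:
    """Toy stand-in for one pre-allocated K/V buffer shared by all micro-sequences."""

    def __init__(self, past_length, active_length):
        total = past_length + active_length
        self.past_length = past_length
        self.values = [0.0] * total
        self.grads = [0.0] * total
        self.filled = past_length  # the leading slot holds past data

    def write_next(self, chunk):
        """Forward: write one micro-sequence's K/V rows into the next slice."""
        begin, end = self.filled, self.filled + len(chunk)
        if end > len(self.values):
            raise ValueError("buffer overflow")
        self.values[begin:end] = chunk
        self.filled = end
        return begin, end

    def accumulate_grad(self, begin, grad_chunk):
        """Backward: add a micro-sequence's K/V grad into the shared grad slice."""
        for offset, grad in enumerate(grad_chunk):
            self.grads[begin + offset] += grad


past_length = 2
micro_sequences = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
buffer = SharedKVBuffer(past_length, active_length=6)
slices = [buffer.write_next(chunk) for chunk in micro_sequences]

# Backward runs in reverse order. Each micro-sequence attends to the active
# K extent only, i.e. rows [past_length:end], so the prefix gets no gradient.
for begin, end in reversed(slices):
    buffer.accumulate_grad(past_length, [1.0] * (end - past_length))
```

In this toy setup the rows written by the first micro-sequence receive gradient from both micro-sequences, while the rows before `past_length` are never touched, which is the property the regression case asserts on the real grad buffer.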
5761ff6 to e09bea7
Summary
Two related cleanups on top of the SDPA causal-alignment fix already in main (#512):
- `cu_seqlens_k[0] = 0` (was `first_document_begin`), the layout every public varlen attention API documents. `document_index_k` / `position_index` are recomputed against the narrowed K extent so they stay consistent. Downstream consumers (kernel inputs, backup mask cols) narrow K by `first_document_begin` to match.
- One K/V buffer is pre-allocated per attention layer and shared across all micro-sequences of a sequence. Each forward writes the SDP-gather result into the next slice via `gather_op(out=)`; backward accumulates per-micro-seq K/V grad into a shared grad buffer slice. The leading + trailing narrows and the per-step `torch.cat` / `AttachGrad` cross-micro-seq splice are absorbed into `_query_key_value`'s custom autograd region: no eager cat, no detached-leaf graph edge per step.

Test plan
- `tests/data/test_preprocessing.py`: 843/843 pass, updated to expect canonical layout.
- `tests/layers/test_attention.py`: 57/57 pass (CPU + CUDA), including new `first_document_begin` regression case.
- `test_attention[first_document_begin-[4, 1, 10]]`: injects a fake past K/V slot, drives attention with `sequence_k_past > 0` and `first_document_begin > 0`, verifies output + parameter grads match a per-doc reference and that `slot.grad_buffer[:past_length]` is exactly zero (the specific guarantee of the narrow). Passes for backup, sdpa_dense, flash, sdpa_nested.

🤖 Generated with Claude Code