
Add a (rolling) KV cache for Wan models to enable autoregressive rollouts ("Self Forcing"). #13681

Draft
gueraf wants to merge 34 commits into huggingface:main from gueraf:wan-rolling-kv-cache

Conversation


gueraf commented May 5, 2026

What does this PR do?

  • Implements a (rolling) KV cache for Wan models to enable autoregressive generation; a minimal usage sketch follows after this list.
  • Tries to mirror the KV cache pattern in transformer_flux2.py as much as possible.
  • Videos and byte-level equivalence against upstream Self Forcing are tested in https://github.com/gueraf/self-forcing-diffusers/releases/tag/inline-rolling-kv-20260504.
  • This initial PR does not yet implement sink-frame pinning, and lacks some model-level adjustments (Self Forcing has cross-attention QK norms and per-frame timestep modulation).
  • Adds tests for cache append/overwrite/window behavior.
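
A minimal usage sketch of the call pattern (see the commit messages below for the pieces it assembles). Only WanRollingKVCache, frame_offset, and the attention_kwargs plumbing come from this PR; the loop structure and the chunks/timestep/prompt_embeds variables are hypothetical placeholders.

    from diffusers.models.transformers.transformer_wan import WanRollingKVCache

    # The cache keeps per-block K/V; window_size bounds the retained tokens.
    cache = WanRollingKVCache(num_blocks=len(transformer.blocks), window_size=8000)

    frame_offset = 0
    for chunk_latents in chunks:  # autoregressive rollout, one chunk at a time
        noise_pred = transformer(
            hidden_states=chunk_latents,
            timestep=timestep,
            encoder_hidden_states=prompt_embeds,
            frame_offset=frame_offset,  # keeps temporal RoPE positions correct across chunks
            attention_kwargs={"rolling_kv_cache": cache},  # reuse K/V from earlier chunks
        ).sample
        frame_offset += chunk_latents.shape[2]  # latents are [B, C, F, H, W]; advance by F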

Motivation

This is a tightly scoped follow-up to #12773 and a first step toward #12600. The previous draft explored similar functionality but also included Krea-specific experiments and broader integration work.

As for practical use, we (https://odyssey.ml/) would like to rely on the Hugging Face Diffusers ecosystem to ship Self-Forcing-like models without having to bundle many custom modules, ideally none.

Progresses #12600

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

gueraf added 4 commits May 4, 2026 11:22
Adds inline rolling KV cache classes to transformer_wan.py, enabling
efficient autoregressive chunk-wise video generation without recomputing
previous chunks' attention context.

New public API (importable from diffusers.models.transformers.transformer_wan):
- WanRollingKVBlockCache: per-block K/V storage with optional cross-attn cache
- WanRollingKVCache: container managing all blocks, write modes, and window trimming

Usage:
    cache = WanRollingKVCache(num_blocks=len(transformer.blocks), window_size=8000)
    transformer(..., attention_kwargs={"rolling_kv_cache": cache})

Key features:
- "append" mode: grows the temporal prefix each chunk
- "overwrite" mode: writes clean K/V at a specific absolute token offset,
  allowing arbitrary temporal placement (e.g. injecting ground-truth frames)
- window_size: trims oldest tokens to keep memory bounded
- cache_cross_attention: reuses text encoder projections across chunks
- frame_offset parameter on WanRotaryPosEmbed and WanTransformer3DModel.forward
  for correct temporal RoPE positioning during chunk-wise generation

Also adds examples/inference/autoregressive_video_generation.py demonstrating
chunk-wise generation with the Self-Forcing transformer.
Tests cover WanRollingKVBlockCache and WanRollingKVCache state, both helper
functions, append mode (incremental chunk assertions), window trimming,
overwrite mode, cross-attention caching, and frame offset behavior.
Split WanAttnProcessor.__call__ into _wan_self_attention and
_wan_cross_attention module-level functions; __call__ is now a
3-line router. Also hoists apply_rotary_emb to module level.
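
A hedged sketch of that router shape (the helper names come from this commit; the exact signatures are assumptions):

    class WanAttnProcessor:
        def __call__(self, attn, hidden_states, encoder_hidden_states=None, **kwargs):
            # Route on encoder_hidden_states: None means self-attention, otherwise cross-attention.
            if encoder_hidden_states is None:
                return _wan_self_attention(attn, hidden_states, **kwargs)
            return _wan_cross_attention(attn, hidden_states, encoder_hidden_states, **kwargs)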

Moves rolling KV cache tests into test_models_transformer_wan.py
(deleted test_wan_rolling_kv_cache.py). Tests now use arange-based
deterministic inputs and assert on explicit Python lists.
…rename _tok

- _get_kv_projections: new helper that projects only K and V; used in the
  cross-attention cached path instead of discarding Q from _get_qkv_projections
- _wan_self_attention / _wan_cross_attention: annotate backend and parallel_config
  as AttentionBackendName | None and ParallelConfig | None (TYPE_CHECKING imports)
- Export WanRollingKVCache and WanRollingKVBlockCache from diffusers top-level
- Rename _tok to _arange_tokens in tests
github-actions bot added the models, tests, and size/L (PR with diff > 200 LOC) labels May 5, 2026
gueraf and others added 3 commits May 5, 2026 14:42
- Drop `import unittest`
- Replace `(unittest.TestCase)` bases with plain classes
- Replace `self.assert*` calls with bare `assert` and `pytest.raises`
- Rename `setUp` → `setup_method` (pytest convention for plain classes)
- Use `torch.equal` instead of `not torch.allclose` for frame-offset
  differing-output test: the tiny model produces a real but sub-1e-5
  difference that falls within allclose's default rtol

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
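
For reference, a self-contained toy example of the pytest conventions adopted here (the class and assertions are illustrative, not taken from the actual test file):

    import pytest


    class TestExample:  # plain class, no unittest.TestCase base
        def setup_method(self):  # pytest's hook for plain classes, replaces setUp
            self.value = 3

        def test_equality(self):
            assert self.value == 3  # bare assert instead of self.assertEqual

        def test_raises(self):
            with pytest.raises(ValueError):  # replaces self.assertRaises
                int("not a number")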
- Drop TestWanRollingKVBlockCache/TestWanRollingKVCache (initial-state
  and configure_write validation tests — not interesting behavior)
- Drop TestCrossAttentionCache, TestFrameOffset as standalone classes
- Collapse AppendMode/WindowSize/OverwriteMode/FrameOffset into one
  TestWanRollingKVCache with _chunk/_run/_len as instance methods
- Move _arange_tokens inline as TestTrimToWindow._tok and
  TestSliceForOverwrite._bc (keeps helpers local to their users)
- Shrink module-level config to a single compact dict

35 tests → 12 tests

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
Move _TINY_CONFIG, _NUM_BLOCKS, _TOKENS_PER_CHUNK and all three test
classes (TestTrimToWindow, TestSliceForOverwrite, TestWanRollingKVCache)
into one TestWanRollingKVCache. Config and constants are class attributes
_CONFIG/_N/_T; trim/slice helpers share _tok with the forward tests.

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
- self.t → self.transformer
- _N/_T → NUM_BLOCKS/TOKENS_PER_CHUNK (class constants, not cryptic)
- _CONFIG expanded to one key per line; num_layers=NUM_BLOCKS makes
  the relationship explicit
- _run(*args) → explicit (latents, timestep, encoder_hidden_states)
  params; cache is now required (every caller always passes one);
  dead None-branch removed
- _len(cache=None) → _cached_len(cache) — no implicit self.cache
  default; all callers name the cache explicitly
- _tok → _filled_block_cache (describes what it actually builds)
- test_overwrite: local T = self.TOKENS_PER_CHUNK to avoid repetition

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
…ffsets)

The old configure_write(absolute_token_offset=N) silently truncated the cache
to N+chunk_size, making mid-sequence overwrites a footgun. Restrict the API
to two well-defined modes:
- append: extend cache by chunk_size (transition to a new chunk)
- overwrite_end: drop last chunk_size tokens before appending (replace last
  chunk in place — used for subsequent denoising steps within the same chunk)

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
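
A runnable toy illustration of the bookkeeping these two modes imply, applied to a single block's key tensor (shapes and the write helper are illustrative; the real cache stores per-block K and V inside WanRollingKVCache):

    import torch

    CHUNK = 4  # tokens written per chunk (illustrative)

    def write(cached_k: torch.Tensor, new_k: torch.Tensor, mode: str) -> torch.Tensor:
        if mode == "overwrite_end":
            cached_k = cached_k[:, :-CHUNK]  # drop the last chunk before re-appending
        return torch.cat([cached_k, new_k], dim=1)  # both modes end by appending chunk_size tokens

    cached_k = torch.empty(1, 0, 2, 8)  # [batch, tokens, heads, head_dim]
    cached_k = write(cached_k, torch.randn(1, CHUNK, 2, 8), "append")         # 4 tokens: first chunk
    cached_k = write(cached_k, torch.randn(1, CHUNK, 2, 8), "overwrite_end")  # still 4: later denoising step
    cached_k = write(cached_k, torch.randn(1, CHUNK, 2, 8), "append")         # 8 tokens: next chunk
    assert cached_k.shape[1] == 2 * CHUNK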
…rop-last helper

- setup_method only creates the transformer; each test creates its own cache
  with explicit window_size, making the (mode, window_size) configuration
  obvious from the test body
- Add test_append_windowed_three_chunks to exercise the rolling case where the
  window holds multiple chunks and surviving chunks shift left on eviction
- Simplify _wan_rolling_kv_drop_last_chunk to always slice (returning empty
  tensors when keep=0 instead of None), matching the previous overwrite path

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
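
And a similarly toy illustration of the window trimming this test exercises (illustrative shapes; the real helper runs per block inside the cache):

    import torch

    WINDOW = 8  # maximum tokens kept (illustrative)

    k = torch.arange(12, dtype=torch.float32).view(1, 12, 1, 1)  # three 4-token chunks
    k = k[:, -WINDOW:]  # evict the oldest tokens; survivors shift left
    assert k.shape[1] == WINDOW and k[0, 0, 0, 0].item() == 4.0  # chunk 0 was evicted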
gueraf and others added 15 commits May 6, 2026 13:13
Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
All previous _assert_unchanged / _assert_changed (with from/to ints) and
raw torch.equal / torch.allclose calls now go through one pair of helpers
that take pre-sliced tensors. Asymmetric slices (e.g. snap[T:3T] vs
snap_after[0:2T] in the windowed-rolling test) read consistently with
symmetric ones.

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
The flag was never actually flipped in any caller — the only "user" was
self-forcing-diffusers' write_rolling_kv_cache, which set should_update=True
defensively before each chunk write. With the narrowed append/overwrite_end
API both modes always write; there is no read-without-update use case.

Removes the field, the guard inside _wan_self_attention, and the test that
toggled the flag.

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
- Reorder TYPE_CHECKING imports alphabetically
- Reflow long argument lists to one-arg-per-line
- Reformat _import_structure list literal in models/__init__.py
- Replace dict() literal with {} in test config
- Drop unused WanRollingKVBlockCache import in tests

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
…y bound to the self-attention path

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
… diff

Previously the extraction line lived right before the block loop, leaving
the WanTransformer3DModel.forward diff in three separated hunks. With it
moved next to the rope() call, the rope change and cache extraction render
together as one cohesive hunk in the upstream diff.

No behavioural change — the variable is still defined before the block
loop reads it.

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
…better

Previously the chunk of new content (WanRollingKVBlockCache, WanRollingKVCache,
helper functions, _wan_self_attention/_wan_cross_attention) was inserted
*before* WanAttnProcessor. GitHub's diff algorithm aligned the deleted lines
of the old WanAttnProcessor body with the inserted new content, rendering
"class WanAttnProcessor:" as if it had been replaced by "class WanRollingKVBlockCache:".

Move the entire new chunk to right after WanAttention. WanAttnProcessor stays
near its original line position with a small in-place change (two new
parameters + a refactored body that calls the new helpers), so GitHub now
shows it as a modification rather than a rewrite.

Quote `WanRollingKVCache | None` in WanAttnProcessor.__call__ since the type
is now defined later in the file.

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
Conventional Python ordering: helpers above their callers. Lazy annotations
(from __future__ import annotations) keep WanAttention's forward references
in helpers' signatures resolvable.

Trade-off: GitHub's diff aligner may once again pair upstream's deleted
WanAttnProcessor body with the new content, rendering it as a class rename.
The ordering wins out over the diff aesthetics.

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
cache_cross_attention=True was never exercised by any caller — same kind of
dead-code surface area we removed for should_update. Drops:
- The cache_cross_attention flag on WanRollingKVCache.
- The cached_cross_* fields on WanRollingKVBlockCache.
- The use_cross_cache branches inside _wan_cross_attention (now matches
  upstream's plain cross-attention path again).
- _get_kv_projections (only used by the now-removed branch).
- The rolling_kv_cache / block_idx arguments on _wan_cross_attention.

Also converts WanRollingKVBlockCache to a dataclass — single source of truth
for fields and defaults; reset() just re-runs __init__.

A TODO on WanRollingKVCache notes that the cross-attn projections could be
reused across chunks since the text embeddings are constant; addable later.

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
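
A hedged sketch of what the dataclass conversion looks like for the per-block cache (field names are assumptions inferred from the surrounding commit messages, not the actual attribute names):

    from dataclasses import dataclass
    from typing import Optional

    import torch


    @dataclass
    class WanRollingKVBlockCache:
        # Per-block self-attention K/V; None until the first chunk is written.
        cached_key: Optional[torch.Tensor] = None
        cached_value: Optional[torch.Tensor] = None

        def reset(self) -> None:
            # Single source of truth for fields and defaults: just re-run __init__.
            self.__init__()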
The previous extraction into _wan_self_attention / _wan_cross_attention was
needed when the cache also handled cross-attention. Now that cross-attn
caching is gone, the remaining cache logic is a 20-line block guarded on
encoder_hidden_states is None — small enough to live inline.

Returns WanAttnProcessor to the diffusers convention (single monolithic
__call__ that branches between self/cross via encoder_hidden_states),
matching every other attn processor in the repo. The diff against upstream
now reads as just two additions:
- rolling_kv_cache + block_idx params on the signature
- a self-attn-only cache block after the rotary embedding

Drops _wan_self_attention, _wan_cross_attention, _apply_rotary_emb, and
the now-unused TYPE_CHECKING imports for AttentionBackendName / ParallelConfig.

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>