Add a (rolling) KV cache for Wan models to enable autoregressive rollouts ("SelfForcing").#13681
Draft
gueraf wants to merge 34 commits intohuggingface:mainfrom
Draft
Add a (rolling) KV cache for Wan models to enable autoregressive rollouts ("SelfForcing").#13681gueraf wants to merge 34 commits intohuggingface:mainfrom
gueraf wants to merge 34 commits intohuggingface:mainfrom
Conversation
Adds inline rolling KV cache classes to transformer_wan.py, enabling
efficient autoregressive chunk-wise video generation without recomputing
previous chunks' attention context.
New public API (importable from diffusers.models.transformers.transformer_wan):
- WanRollingKVBlockCache: per-block K/V storage with optional cross-attn cache
- WanRollingKVCache: container managing all blocks, write modes, and window trimming
Usage:
cache = WanRollingKVCache(num_blocks=len(transformer.blocks), window_size=8000)
transformer(..., attention_kwargs={"rolling_kv_cache": cache})
Key features:
- "append" mode: grows the temporal prefix each chunk
- "overwrite" mode: writes clean K/V at a specific absolute token offset,
allowing arbitrary temporal placement (e.g. injecting ground-truth frames)
- window_size: trims oldest tokens to keep memory bounded
- cache_cross_attention: reuses text encoder projections across chunks
- frame_offset parameter on WanRotaryPosEmbed and WanTransformer3DModel.forward
for correct temporal RoPE positioning during chunk-wise generation
Also adds examples/inference/autoregressive_video_generation.py demonstrating
chunk-wise generation with the Self-Forcing transformer.
Tests cover WanRollingKVBlockCache and WanRollingKVCache state, both helper functions, append mode (incremental chunk assertions), window trimming, overwrite mode, cross-attention caching, and frame offset behavior.
Split WanAttnProcessor.__call__ into _wan_self_attention and _wan_cross_attention module-level functions; __call__ is now a 3-line router. Also hoists apply_rotary_emb to module level. Moves rolling KV cache tests into test_models_transformer_wan.py (deleted test_wan_rolling_kv_cache.py). Tests now use arange-based deterministic inputs and assert on explicit Python lists.
…rename _tok - _get_kv_projections: new helper that projects only K and V; used in the cross-attention cached path instead of discarding Q from _get_qkv_projections - _wan_self_attention / _wan_cross_attention: annotate backend and parallel_config as AttentionBackendName | None and ParallelConfig | None (TYPE_CHECKING imports) - Export WanRollingKVCache and WanRollingKVBlockCache from diffusers top-level - Rename _tok to _arange_tokens in tests
- Drop `import unittest` - Replace `(unittest.TestCase)` bases with plain classes - Replace `self.assert*` calls with bare `assert` and `pytest.raises` - Rename `setUp` → `setup_method` (pytest convention for plain classes) - Use `torch.equal` instead of `not torch.allclose` for frame-offset differing-output test: the tiny model produces a real but sub-1e-5 difference that falls within allclose's default rtol Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
- Drop TestWanRollingKVBlockCache/TestWanRollingKVCache (initial-state and configure_write validation tests — not interesting behavior) - Drop TestCrossAttentionCache, TestFrameOffset as standalone classes - Collapse AppendMode/WindowSize/OverwriteMode/FrameOffset into one TestWanRollingKVCache with _chunk/_run/_len as instance methods - Move _arange_tokens inline as TestTrimToWindow._tok and TestSliceForOverwrite._bc (keeps helpers local to their users) - Shrink module-level config to a single compact dict 35 tests → 12 tests Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
Move _TINY_CONFIG, _NUM_BLOCKS, _TOKENS_PER_CHUNK and all three test classes (TestTrimToWindow, TestSliceForOverwrite, TestWanRollingKVCache) into one TestWanRollingKVCache. Config and constants are class attributes _CONFIG/_N/_T; trim/slice helpers share _tok with the forward tests. Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
- self.t → self.transformer - _N/_T → NUM_BLOCKS/TOKENS_PER_CHUNK (class constants, not cryptic) - _CONFIG expanded to one key per line; num_layers=NUM_BLOCKS makes the relationship explicit - _run(*args) → explicit (latents, timestep, encoder_hidden_states) params; cache is now required (every caller always passes one); dead None-branch removed - _len(cache=None) → _cached_len(cache) — no implicit self.cache default; all callers name the cache explicitly - _tok → _filled_block_cache (describes what it actually builds) - test_overwrite: local T = self.TOKENS_PER_CHUNK to avoid repetition Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
…ce unit tests Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
…eady -1) Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
…helpers Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
…sert_unchanged Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
…ll sites Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
…servation Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
…nk-loop test Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
…ffsets) The old configure_write(absolute_token_offset=N) silently truncated the cache to N+chunk_size, making mid-sequence overwrites a footgun. Restrict the API to two well-defined modes: - append: extend cache by chunk_size (transition to a new chunk) - overwrite_end: drop last chunk_size tokens before appending (replace last chunk in place — used for subsequent denoising steps within the same chunk) Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
…rop-last helper - setup_method only creates the transformer; each test creates its own cache with explicit window_size, making the (mode, window_size) configuration obvious from the test body - Add test_append_windowed_three_chunks to exercise the rolling case where the window holds multiple chunks and surviving chunks shift left on eviction - Simplify _wan_rolling_kv_drop_last_chunk to always slice (returning empty tensors when keep=0 instead of None), matching the previous overwrite path Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
All previous _assert_unchanged / _assert_changed (with from/to ints) and raw torch.equal / torch.allclose calls now go through one pair of helpers that take pre-sliced tensors. Asymmetric slices (e.g. snap[T:3T] vs snap_after[0:2T] in the windowed-rolling test) read consistently with symmetric ones. Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
The flag was never actually flipped in any caller — the only "user" was self-forcing-diffusers' write_rolling_kv_cache, which set should_update=True defensively before each chunk write. With the narrowed append/overwrite_end API both modes always write; there is no read-without-update use case. Removes the field, the guard inside _wan_self_attention, and the test that toggled the flag. Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
- Reorder TYPE_CHECKING imports alphabetically
- Reflow long argument lists to one-arg-per-line
- Reformat _import_structure list literal in models/__init__.py
- Replace dict() literal with {} in test config
- Drop unused WanRollingKVBlockCache import in tests
Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
…y bound to the self-attention path Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
…own fence Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
… diff Previously the extraction line lived right before the block loop, leaving the WanTransformer3DModel.forward diff in three separated hunks. With it moved next to the rope() call, the rope change and cache extraction render together as one cohesive hunk in the upstream diff. No behavioural change — the variable is still defined before the block loop reads it. Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
…better Previously the chunk of new content (WanRollingKVBlockCache, WanRollingKVCache, helper functions, _wan_self_attention/_wan_cross_attention) was inserted *before* WanAttnProcessor. GitHub's diff algorithm aligned the deleted lines of the old WanAttnProcessor body with the inserted new content, rendering "class WanAttnProcessor:" as if it had been replaced by "class WanRollingKVBlockCache:". Move the entire new chunk to right after WanAttention. WanAttnProcessor stays near its original line position with a small in-place change (two new parameters + a refactored body that calls the new helpers), so GitHub now shows it as a modification rather than a rewrite. Quote `WanRollingKVCache | None` in WanAttnProcessor.__call__ since the type is now defined later in the file. Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
Conventional Python ordering: helpers above their callers. Lazy annotations (from __future__ import annotations) keep WanAttention's forward references in helpers' signatures resolvable. Trade-off: GitHub's diff aligner may once again pair upstream's deleted WanAttnProcessor body with the new content, rendering it as a class rename. The ordering wins out over the diff aesthetics. Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
cache_cross_attention=True was never exercised by any caller — same kind of dead-code surface area we removed for should_update. Drops: - The cache_cross_attention flag on WanRollingKVCache. - The cached_cross_* fields on WanRollingKVBlockCache. - The use_cross_cache branches inside _wan_cross_attention (now matches upstream's plain cross-attention path again). - _get_kv_projections (only used by the now-removed branch). - The rolling_kv_cache / block_idx arguments on _wan_cross_attention. Also converts WanRollingKVBlockCache to a dataclass — single source of truth for fields and defaults; reset() just re-runs __init__. A TODO on WanRollingKVCache notes that the cross-attn projections could be reused across chunks since the text embeddings are constant; addable later. Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
The previous extraction into _wan_self_attention / _wan_cross_attention was needed when the cache also handled cross-attention. Now that cross-attn caching is gone, the remaining cache logic is a 20-line block guarded on encoder_hidden_states is None — small enough to live inline. Returns WanAttnProcessor to the diffusers convention (single monolithic __call__ that branches between self/cross via encoder_hidden_states), matching every other attn processor in the repo. The diff against upstream now reads as just two additions: - rolling_kv_cache + block_idx params on the signature - a self-attn-only cache block after the rotary embedding Drops _wan_self_attention, _wan_cross_attention, _apply_rotary_emb, and the now-unused TYPE_CHECKING imports for AttentionBackendName / ParallelConfig. Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do?
Motivation
This is a tightly scoped follow-up to #12773 and a first step toward #12600. The previous draft explored similar functionality but also included Krea-specific experiments and broader integration work.
As for practical use, we (https://odyssey.ml/) would like to rely on the Hugging Face Diffusers ecosystem to ship Self-Forcing-like models without having to ship many custom modules, ideally none.
Progresses #12600
Before submitting
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.