fix(hy_worldplay): RoPE memory frames at their true positions (fixes choppy native output) + Re-evaluated perfs#318
Conversation
Greptile SummaryThis PR fixes two bugs introduced when the memory-prefill path was first added: memory frames were RoPE-rotated at their buffer-slot index (
Confidence Score: 5/5Safe to merge — all three fixes address confirmed bugs with clear root-cause explanations and are well-guarded by the existing test suite (94 passed / 4 skipped on full HY suite). The RoPE position fix is a one-line change from No files require special attention. The bench script Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[prefill_memory_kv_cache called\nper chunk] --> B[slice clean_latent_history\nat selected frame indices]
B --> C["selected_idx_t = tensor(selected, long)"]
C --> D["t_positions = selected_idx_t.to(float32)\n✅ true temporal positions\n❌ OLD: arange(K) — slot indices"]
D --> E["_build_memory_rope_freqs\ncache, t_positions"]
E --> F["rope._freq_components(t_positions)\nfreqs_t, freqs_h, freqs_w"]
F --> G["rope_freqs shape:\nK×tokens_per_frame, 1, 1, head_dim"]
G --> H{self.network is\nOptimizedModule?}
H -->|"hasattr(_orig_mod) == True"| I["network = network._orig_mod\n✅ unwrap to HyWorldPlayWanDiTNetwork"]
H -->|"no wrapper"| J
I --> J["assert isinstance\nHyWorldPlayWanDiTNetwork"]
J --> K["network.prefill_memory_kv_cache\ncond pass"]
K --> L{CFG enabled?}
L -->|Yes| M["network.prefill_memory_kv_cache\nuncond pass"]
L -->|No| N[done]
M --> N
Reviews (6): Last reviewed commit: "docs(hy_worldplay): refresh parity after..." | Re-trigger Greptile |
| def _capture(*, cache, t_positions): # noqa: ANN001 (test stub) | ||
| captured["t"] = t_positions.detach().clone() | ||
| raise _Stop() | ||
|
|
||
| transformer._build_memory_rope_freqs = _capture # type: ignore[assignment] | ||
|
|
||
| def positions_for(selected: list[int]) -> list[float]: | ||
| inp = types.SimpleNamespace( | ||
| memory_frame_indices=selected, | ||
| rollout_viewmats=None, | ||
| viewmats=None, | ||
| rollout_Ks=None, | ||
| Ks=None, | ||
| rollout_action=None, | ||
| action=None, | ||
| ) | ||
| with pytest.raises(_Stop): | ||
| transformer.prefill_memory_kv_cache( | ||
| cache=cache, # type: ignore[arg-type] | ||
| input=inp, # type: ignore[arg-type] | ||
| timestep=torch.zeros(1), | ||
| ) | ||
| return captured["t"].tolist() | ||
|
|
||
| # Positions are the selected frame indices themselves, not arange(K). | ||
| assert positions_for([1, 3]) == [1.0, 3.0] | ||
| assert positions_for([0, 2, 5]) == [0.0, 2.0, 5.0] | ||
| # Identity-stability: frame 3 keeps position 3 whether it's in slot 1 | ||
| # (set [1, 3]) or slot 0 (set [3, 5]) -- the property the fix restores. | ||
| assert positions_for([1, 3])[1] == positions_for([3, 5])[0] == 3.0 |
There was a problem hiding this comment.
Test verifies driver input, not RoPE output
The test monkey-patches _build_memory_rope_freqs with _capture before the function body runs, so it confirms that prefill_memory_kv_cache constructs the right t_positions argument — but it never calls rope._freq_components with those non-zero-aligned positions. If _freq_components assumed seq_t[0] == 0 or had any issue handling sparse/out-of-order temporal indices, this test would still pass. Consider adding a complementary integration-level test (even CPU/tiny-model) that calls the real _build_memory_rope_freqs with a non-trivial selected set (e.g., [2, 5]) and asserts the returned rope_freqs shape is [K * tokens_per_frame, 1, 1, head_dim] and that the result for frame 5 differs from that of frame 2 — confirming the frequencies are actually position-dependent at the output level.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
… native output) The reconstituted-context memory prefill modulated the selected memory frames at collapsed RoPE positions [0, K) — i.e. the buffer slot index — regardless of which frames were chosen. Because the memory set is re-selected each chunk, a given frame's slot (and thus its RoPE position) changed between chunks, so its key's positional phase shifted and the scene jolted on each window slide. That's the choppiness Ruilong flagged on PR #231 (native choppy, vendor smooth). Fix: build the memory RoPE freqs at each frame's true temporal position (its clean-latent-history index), which is the same absolute coordinate frame the main path uses via `shift_t` (offset = ar_idx * len_t). A frame's encoding is now stable across chunks regardless of its slot. Renames `_build_collapsed_rope_freqs` -> `_build_memory_rope_freqs` and updates the now-inaccurate "collapsed position" docs. Adds a CPU regression test asserting the prefill driver passes the selected frame indices (not arange(K)) and that a frame keeps its position across re-selection. Full HY CPU suite passes; ruff clean. Numerical parity vs the vendor prefill still to be confirmed on GPU. Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
4e8afb7 to
2408981
Compare
|
Is this verified to fix the choppy results in #231 ? |
The CUDAGraphWrapper captures pointers into the KV cache, but HY-WorldPlay re-runs prefill_memory_kv_cache every chunk: it resets and repopulates each PRoPE block's memory KV from a different FOV-selected frame set (select_mem_frames_wan), reallocating the underlying storage. A graph captured on one chunk then replays against another chunk's stale/freed memory-KV slots, decoding to a deterministic "shatter" of speckle corruption on the captured/replayed chunks (independent of seed and prompt). Set use_cuda_graph=False so the rollout keeps torch.compile (Inductor) -- still ~5x faster diffuse than vendor -- without the unsafe replay. Also unwrap torch.compile's OptimizedModule (_orig_mod) before the HyWorldPlayWanDiTNetwork isinstance assert in prefill_memory_kv_cache; with compile_network=True the wrapper otherwise trips the assert at the first memory-prefill step (AR 1). Parity vs vendor (8-chunk, matched prompt) improves 80.4 -> 44.9 / 255 as the corruption is removed. Test harness: keep HF_HOME local in bench.sh (avoids ~/.cache permission errors on shared machines) and add torchvision to the vendor heavy-deps in run.sh. Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Per-stage and per-input perf + parity (PR #231 format) for num_chunk=8, pose=w-31, prompt "a person walking", over HY-WorldPlay/assets/img/ {1,2,5,6,10} on GB300. Native (use_cuda_graph=False after the memory-prefill corruption fix) is 5.5x faster DiT diffuse and ~4x faster DiT+VAE per chunk than vendor; median parity 36.0/255 with all chunks rendering cleanly. Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Yes, it's fixed now. |
…rldplay-fix-memory-rope
Re-measured native-vs-vendor mean |Δ| with the VAE patchify channel-order fix on the native leg (native regenerated, vendor reused unchanged). The ~2px decode checkerboard removal drops median parity 36.0 -> 24.0 / 255 (per-image 36->28, 28->23, 36->26, 21->17, 47->24). Perf/speedup table is unchanged — the fix is a zero-cost einops axis swap. Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
d011203 to
02cead4
Compare
|
/ok to test 02cead4 |
Fixes the choppiness on PR #231.
Root cause
Memory prefill RoPE-rotated each selected frame at its buffer-slot index (
arange(K)), not its frame identity. Frames are re-selected every chunk, so a frame's slot — and thus its positional phase — jumped chunk-to-chunk, jolting the scene on each window slide.Fix
Rotate each memory frame at its true temporal position (clean-latent-history index), matching the main path's
shift_toffset. One-liner in_action.py:t_positions=torch.arange(K, ...)→selected_idx_t.to(torch.float32). Plus two fixes for the compiled path:use_cuda_graph=False(graph capture replayed stale per-chunk memory-KV pointers → speckle corruption) and unwrapOptimizedModule._orig_modbefore theisinstanceguard. CPU regression test added; full HY suite 94 passed / 4 skipped.Perf — 8-chunk native vs vendor (GB300)
num_chunk=8,pose=w-31,seed=0, 704×1280,"a person walking", post-warmup medians (chunks 5–7). Both legs: cuDNN SDPA +torch.compile.1.png2.png5.jpeg6.jpeg10.pngDiT (diffuse)= per-chunk median (≠ PR #231's per-forward basis; compare ratios).mean |Δ|= per-pixel uint8/255 diff of the two MP4s (cumulative bf16 drift over 8 chunks). All chunks render cleanly. Full report:integrations/hy_worldplay/tests/parity_check/bench_8chunk_walking.md.Native vs vendor video pairs
1.pnghy-worldplay-wan-i2v-5b.mp4
1-vendor.mp4
2.pnghy-worldplay-wan-i2v-5b.mp4
2-vendor.mp4
5.jpeghy-worldplay-wan-i2v-5b.mp4
5-vendor.mp4
10.pnghy-worldplay-wan-i2v-5b.mp4
10-vendor.mp4
Parity refreshed with the Wan 2.2 VAE patchify-order fix (#338) on the native leg: median
|Δ|36.0 → 24.0 / 255 (the removed ~2px decode checkerboard had inflated it). Perf/speedup table unchanged (the fix is a zero-cost einops axis swap). The embedded video pairs above are pre-fix and should be re-uploaded from the regenerated clips.