fix(hy_worldplay): RoPE memory frames at their true positions (fixes choppy native output) + Re-evaluated perfs #318

Open
wenqingw-nv wants to merge 5 commits into main from dev/wenqingw-nv/hy-worldplay-fix-memory-rope

wenqingw-nv (Collaborator) commented Jun 9, 2026

Fixes the choppiness on PR #231.

Root cause

Memory prefill RoPE-rotated each selected frame at its buffer-slot index (arange(K)), not its frame identity. Frames are re-selected every chunk, so a frame's slot — and thus its positional phase — jumped chunk-to-chunk, jolting the scene on each window slide.

Fix

Rotate each memory frame at its true temporal position (its clean-latent-history index), matching the main path's `shift_t` offset. This is a one-line change in `_action.py`: `t_positions=torch.arange(K, ...)` → `t_positions=selected_idx_t.to(torch.float32)`. Two further fixes cover the compiled path: set `use_cuda_graph=False` (graph capture replayed stale per-chunk memory-KV pointers, producing speckle corruption) and unwrap `OptimizedModule._orig_mod` before the `isinstance` guard. A CPU regression test is added; the full HY suite reports 94 passed / 4 skipped.
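To make the root cause concrete, here is a minimal pure-Python sketch (no torch) contrasting the two position schemes. The helper names `slot_positions` and `identity_positions` are illustrative only and do not exist in `_action.py`; the frame sets mirror the ones used in the regression test.

```python
def slot_positions(selected: list[int]) -> dict[int, float]:
    """Old behavior: a frame's RoPE position is its buffer slot, i.e. arange(K)."""
    return {frame: float(slot) for slot, frame in enumerate(selected)}


def identity_positions(selected: list[int]) -> dict[int, float]:
    """Fixed behavior: a frame's RoPE position is its own history index."""
    return {frame: float(frame) for frame in selected}


# Frame 3 is selected in two consecutive chunks but lands in different slots.
chunk_a, chunk_b = [1, 3], [3, 5]

old = (slot_positions(chunk_a)[3], slot_positions(chunk_b)[3])
new = (identity_positions(chunk_a)[3], identity_positions(chunk_b)[3])

print("slot-based:", old)      # position of frame 3 changes between chunks
print("identity-based:", new)  # position of frame 3 is the same in both chunks
```

With slot-based positions the same frame is keyed at a different phase in each chunk, which is the jump the fix removes.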

Perf — 8-chunk native vs vendor (GB300)

num_chunk=8, pose=w-31, seed=0, 704×1280, "a person walking", post-warmup medians (chunks 5–7). Both legs: cuDNN SDPA + torch.compile.

| stage | native | vendor | speedup |
|---|---|---|---|
| DiT (diffuse) | 5085 ms | 27939 ms | 5.49× |
| VAE decode | 2712 ms | 3195 ms | 1.18× |
| DiT + VAE / chunk | 7797 ms | 31135 ms | 3.99× |

| image | DiT nat/ven (ms) | VAE nat/ven (ms) | ratio (DiT+VAE) | mean \|Δ\| |
|---|---|---|---|---|
| 1.png | 5121 / 27939 | 2712 / 3192 | 3.97× | 27.8 |
| 2.png | 4941 / 28021 | 2711 / 3196 | 4.08× | 22.6 |
| 5.jpeg | 4962 / 27909 | 2712 / 3195 | 4.05× | 25.7 |
| 6.jpeg | 5407 / 27932 | 2712 / 3187 | 3.83× | 16.6 |
| 10.png | 5085 / 28007 | 2712 / 3200 | 4.00× | 24.0 |
| median | 5085 / 27939 | 2712 / 3195 | 4.00× | 24.0 |

DiT (diffuse) = per-chunk median (≠ PR #231's per-forward basis; compare ratios). mean |Δ| = per-pixel uint8/255 diff of the two MP4s (cumulative bf16 drift over 8 chunks). All chunks render cleanly. Full report: integrations/hy_worldplay/tests/parity_check/bench_8chunk_walking.md.
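The ratio column and the median row can be re-derived from the per-image timings. The following is a small check that uses only the numbers reported above; the variable names are illustrative and this is not the bench script itself.

```python
from statistics import median

# (DiT native, DiT vendor, VAE native, VAE vendor) in ms, copied from the table.
rows = {
    "1.png": (5121, 27939, 2712, 3192),
    "2.png": (4941, 28021, 2711, 3196),
    "5.jpeg": (4962, 27909, 2712, 3195),
    "6.jpeg": (5407, 27932, 2712, 3187),
    "10.png": (5085, 28007, 2712, 3200),
}


def ratio(dit_nat: int, dit_ven: int, vae_nat: int, vae_ven: int) -> float:
    """Vendor (DiT + VAE) time divided by native (DiT + VAE) time."""
    return (dit_ven + vae_ven) / (dit_nat + vae_nat)


ratios = {name: round(ratio(*r), 2) for name, r in rows.items()}
median_ratio = median(ratios.values())
median_dit_native = median(r[0] for r in rows.values())
median_dit_vendor = median(r[1] for r in rows.values())

for name, value in ratios.items():
    print(f"{name}: {value:.2f}x")
print(f"median ratio: {median_ratio:.2f}x")
```

Note that the median row is taken per column, so its ratio is the median of the per-image ratios rather than a ratio of the median timings.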

Native vs vendor video pairs

| image | native | vendor |
|---|---|---|
| 1.png | hy-worldplay-wan-i2v-5b.mp4 | 1-vendor.mp4 |
| 2.png | hy-worldplay-wan-i2v-5b.mp4 | 2-vendor.mp4 |
| 5.jpeg | hy-worldplay-wan-i2v-5b.mp4 | 5-vendor.mp4 |
| 10.png | hy-worldplay-wan-i2v-5b.mp4 | 10-vendor.mp4 |

Parity refreshed with the Wan 2.2 VAE patchify-order fix (#338) on the native leg: median |Δ| 36.0 → 24.0 / 255 (the removed ~2px decode checkerboard had inflated it). Perf/speedup table unchanged (the fix is a zero-cost einops axis swap). The embedded video pairs above are pre-fix and should be re-uploaded from the regenerated clips.
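For readers unfamiliar with the parity number, mean |Δ| is described above as a per-pixel uint8 difference between the two clips, reported on the 0-255 scale. A minimal sketch of that metric follows, assuming both clips are already decoded into equally sized frames of uint8 values; MP4 decoding is out of scope here, and `mean_abs_delta` is a hypothetical name rather than a function from the bench.

```python
def mean_abs_delta(native: list[list[int]], vendor: list[list[int]]) -> float:
    """Mean absolute per-pixel difference (0-255 scale) over all frames."""
    if len(native) != len(vendor):
        raise ValueError("clips must have the same number of frames")
    total, count = 0, 0
    for frame_n, frame_v in zip(native, vendor):
        if len(frame_n) != len(frame_v):
            raise ValueError("frames must have the same number of pixels")
        total += sum(abs(a - b) for a, b in zip(frame_n, frame_v))
        count += len(frame_n)
    return total / count


# Two tiny 4-pixel "frames" per clip stand in for decoded video.
native_clip = [[0, 10, 200, 255], [5, 5, 5, 5]]
vendor_clip = [[0, 14, 190, 255], [5, 9, 5, 1]]

delta = mean_abs_delta(native_clip, vendor_clip)
print(f"mean |delta| = {delta} / 255")
```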

copy-pr-bot (Bot) commented Jun 9, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

greptile-apps (Bot, Contributor) commented Jun 9, 2026

Greptile Summary

This PR fixes two bugs introduced when the memory-prefill path was first added: memory frames were RoPE-rotated at their buffer-slot index (arange(K)) instead of their true clean-latent-history index, causing per-chunk positional phase jumps that made native output choppy; and the compiled network was never unwrapped from torch.compile's OptimizedModule before the isinstance assert, so the wrapper tripped the assert at the first memory-prefill step on any compiled run. A third fix disables CUDA-graph capture for the HY-WorldPlay pipeline because the graph bakes in memory-KV pointers that are reallocated every chunk, producing speckle corruption on replay.

  • RoPE positions: torch.arange(K) replaced with selected_idx_t.to(torch.float32) so each selected frame is encoded at its absolute temporal position regardless of which buffer slot it occupies.
  • OptimizedModule unwrap: hasattr(network, "_orig_mod") guard added before the isinstance assert so the prefill entry point resolves correctly when compile_network=True.
  • CUDA-graph fix: use_cuda_graph=False hardcoded in config.py with a detailed comment; compile_network (Inductor) is retained for the ~5.5× DiT speedup over vendor (about 4× for DiT + VAE per chunk).
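The unwrap in the second bullet can be sketched without torch. `torch.compile` returns an `OptimizedModule` that keeps the original module on its `_orig_mod` attribute; everything else below (`unwrap_compiled`, `Network`, `FakeOptimizedModule`) is a hypothetical stand-in for illustration, not code from this PR.

```python
def unwrap_compiled(module):
    """Return the wrapped module if `module` is a torch.compile wrapper.

    Plain modules have no `_orig_mod` attribute, so they are returned as-is.
    """
    return getattr(module, "_orig_mod", module)


class Network:  # stand-in for HyWorldPlayWanDiTNetwork
    pass


class FakeOptimizedModule:  # stand-in for torch's OptimizedModule wrapper
    def __init__(self, orig):
        self._orig_mod = orig


plain = Network()
compiled = FakeOptimizedModule(plain)

# The wrapper itself is not a Network, which is why a bare isinstance guard fails.
assert not isinstance(compiled, Network)
# After unwrapping, both the plain and the compiled path pass the guard.
assert isinstance(unwrap_compiled(plain), Network)
assert isinstance(unwrap_compiled(compiled), Network)
```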

Confidence Score: 5/5

Safe to merge — all three fixes address confirmed bugs with clear root-cause explanations and are well-guarded by the existing test suite (94 passed / 4 skipped on full HY suite).

The RoPE position fix is a one-line change from arange(K) to selected_idx_t that directly matches the coordinate frame the main path uses; the reasoning is solid and the driver-level test covers the key invariant. The _orig_mod unwrap and use_cuda_graph=False are both targeted, well-documented fixes for confirmed runtime failures. No correctness or data-integrity regressions were found.

No files require special attention. The bench script run.sh uses --no-deps for torchvision, which is a bench-time convenience rather than production code.

Important Files Changed

| Filename | Overview |
|---|---|
| integrations/hy_worldplay/hy_worldplay/_action.py | Core fix: arange(K) → selected_idx_t.to(float32) for memory RoPE positions, plus _orig_mod unwrap before isinstance and method rename from _build_collapsed_rope_freqs to _build_memory_rope_freqs. Logic is correct and well-commented. |
| integrations/hy_worldplay/hy_worldplay/_camera.py | Documentation-only changes: docstrings updated throughout to reflect that memory KV is stored at true temporal positions rather than collapsed [0, K) positions. No logic changes. |
| integrations/hy_worldplay/hy_worldplay/config.py | Hardcodes use_cuda_graph=False for the HY-WorldPlay pipeline config with a detailed explanatory comment. Correct fix for per-chunk memory-KV pointer invalidation under graph replay. |
| integrations/hy_worldplay/tests/test_prefill.py | New test test_prefill_rope_positions_track_frame_identity_not_slot validates that _build_memory_rope_freqs receives the selected frame indices (not arange(K)) and that frame 3's position is stable across different slot assignments. Driver-level validation only (real _freq_components not exercised). |
| integrations/hy_worldplay/tests/parity_check/bench.sh | Adds HF_HOME env-var defaulting to a local hf_cache/ directory to keep the bench hermetic on shared machines. Cosmetic improvement with no functional impact on the core changes. |
| integrations/hy_worldplay/tests/parity_check/run.sh | Adds torchvision==0.26.0 --no-deps to the vendor heavy-dep install step. The --no-deps flag skips torch/torchaudio compatibility verification, which could silently produce a mismatched torchvision binary. |
| integrations/hy_worldplay/tests/parity_check/bench_8chunk_walking.md | New benchmark report for the 8-chunk walking scenario on GB300. `mean |

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[prefill_memory_kv_cache called\nper chunk] --> B[slice clean_latent_history\nat selected frame indices]
    B --> C["selected_idx_t = tensor(selected, long)"]
    C --> D["t_positions = selected_idx_t.to(float32)\n✅ true temporal positions\n❌ OLD: arange(K) — slot indices"]
    D --> E["_build_memory_rope_freqs\ncache, t_positions"]
    E --> F["rope._freq_components(t_positions)\nfreqs_t, freqs_h, freqs_w"]
    F --> G["rope_freqs shape:\nK×tokens_per_frame, 1, 1, head_dim"]
    G --> H{self.network is\nOptimizedModule?}
    H -->|"hasattr(_orig_mod) == True"| I["network = network._orig_mod\n✅ unwrap to HyWorldPlayWanDiTNetwork"]
    H -->|"no wrapper"| J
    I --> J["assert isinstance\nHyWorldPlayWanDiTNetwork"]
    J --> K["network.prefill_memory_kv_cache\ncond pass"]
    K --> L{CFG enabled?}
    L -->|Yes| M["network.prefill_memory_kv_cache\nuncond pass"]
    L -->|No| N[done]
    M --> N
```

Reviews (6): Last reviewed commit: "docs(hy_worldplay): refresh parity after..."

Comment on lines +414 to +443
```python
def _capture(*, cache, t_positions):  # noqa: ANN001 (test stub)
    captured["t"] = t_positions.detach().clone()
    raise _Stop()

transformer._build_memory_rope_freqs = _capture  # type: ignore[assignment]

def positions_for(selected: list[int]) -> list[float]:
    inp = types.SimpleNamespace(
        memory_frame_indices=selected,
        rollout_viewmats=None,
        viewmats=None,
        rollout_Ks=None,
        Ks=None,
        rollout_action=None,
        action=None,
    )
    with pytest.raises(_Stop):
        transformer.prefill_memory_kv_cache(
            cache=cache,  # type: ignore[arg-type]
            input=inp,  # type: ignore[arg-type]
            timestep=torch.zeros(1),
        )
    return captured["t"].tolist()

# Positions are the selected frame indices themselves, not arange(K).
assert positions_for([1, 3]) == [1.0, 3.0]
assert positions_for([0, 2, 5]) == [0.0, 2.0, 5.0]
# Identity-stability: frame 3 keeps position 3 whether it's in slot 1
# (set [1, 3]) or slot 0 (set [3, 5]) -- the property the fix restores.
assert positions_for([1, 3])[1] == positions_for([3, 5])[0] == 3.0
```


P2 Test verifies driver input, not RoPE output

The test monkey-patches _build_memory_rope_freqs with _capture before the function body runs, so it confirms that prefill_memory_kv_cache constructs the right t_positions argument — but it never calls rope._freq_components with those non-zero-aligned positions. If _freq_components assumed seq_t[0] == 0 or had any issue handling sparse/out-of-order temporal indices, this test would still pass. Consider adding a complementary integration-level test (even CPU/tiny-model) that calls the real _build_memory_rope_freqs with a non-trivial selected set (e.g., [2, 5]) and asserts the returned rope_freqs shape is [K * tokens_per_frame, 1, 1, head_dim] and that the result for frame 5 differs from that of frame 2 — confirming the frequencies are actually position-dependent at the output level.
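One possible shape for such an output-level check is sketched below. It uses a generic 1-D RoPE angle formula in pure Python as a stand-in; it does not call the real `_build_memory_rope_freqs` or `rope._freq_components`, whose signatures and return layout are not shown in this thread, and `temporal_angles` / `angles_for` are hypothetical helpers.

```python
def temporal_angles(
    positions: list[float], dim: int = 8, theta: float = 10000.0
) -> list[list[float]]:
    """Generic 1-D RoPE angles: angle[p][k] = p * theta ** (-2k / dim)."""
    inv_freq = [theta ** (-2 * k / dim) for k in range(dim // 2)]
    return [[p * f for f in inv_freq] for p in positions]


def angles_for(selected: list[int]) -> dict[int, list[float]]:
    """Map each selected frame index to the angles computed at that index."""
    rows = temporal_angles([float(i) for i in selected])
    return dict(zip(selected, rows))


a = angles_for([2, 5])  # frame 5 sits in slot 1
b = angles_for([5, 7])  # frame 5 sits in slot 0

# One row per selected frame, dim // 2 angles per row.
assert len(a) == 2 and all(len(row) == 4 for row in a.values())
# The output is actually position-dependent.
assert a[2] != a[5]
# Frame 5 gets the same angles regardless of which slot it occupies.
assert a[5] == b[5]
```

A real test would replace `temporal_angles` with the production frequency builder and assert on the `[K * tokens_per_frame, 1, 1, head_dim]` shape mentioned above.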


… native output)

The reconstituted-context memory prefill modulated the selected memory
frames at collapsed RoPE positions [0, K) — i.e. the buffer slot index —
regardless of which frames were chosen. Because the memory set is
re-selected each chunk, a given frame's slot (and thus its RoPE position)
changed between chunks, so its key's positional phase shifted and the
scene jolted on each window slide. That's the choppiness Ruilong flagged
on PR #231 (native choppy, vendor smooth).

Fix: build the memory RoPE freqs at each frame's true temporal position
(its clean-latent-history index), which is the same absolute coordinate
frame the main path uses via `shift_t` (offset = ar_idx * len_t). A
frame's encoding is now stable across chunks regardless of its slot.
Renames `_build_collapsed_rope_freqs` -> `_build_memory_rope_freqs` and
updates the now-inaccurate "collapsed position" docs.
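A small sketch of why the history index and the main path's coordinate agree, under two loudly stated assumptions: each chunk contributes exactly `len_t` latent frames, and history is appended in generation order. `LEN_T = 4` and the helper `main_path_positions` are hypothetical; only the offset formula `ar_idx * len_t` comes from the message above.

```python
def main_path_positions(ar_idx: int, len_t: int) -> list[int]:
    """Temporal positions of the chunk being generated: shift_t + local index."""
    shift_t = ar_idx * len_t  # offset stated in the commit message
    return [shift_t + i for i in range(len_t)]


LEN_T = 4  # hypothetical number of latent frames per chunk

# Positions each frame had when generated, appended chunk by chunk (assumption).
history = [p for ar_idx in range(3) for p in main_path_positions(ar_idx, LEN_T)]

# Under these assumptions, a frame's history index equals its original position.
selected = [1, 6, 9]
memory_positions = [float(history[i]) for i in selected]

print(history)
print(memory_positions)
```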

Adds a CPU regression test asserting the prefill driver passes the
selected frame indices (not arange(K)) and that a frame keeps its
position across re-selection. Full HY CPU suite passes; ruff clean.
Numerical parity vs the vendor prefill still to be confirmed on GPU.

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
@wenqingw-nv wenqingw-nv force-pushed the dev/wenqingw-nv/hy-worldplay-fix-memory-rope branch from 4e8afb7 to 2408981 Compare June 9, 2026 21:32
liruilong940607 (Collaborator) commented Jun 10, 2026

Is this verified to fix the choppy results in #231?

jmccaffrey-nv and others added 2 commits June 12, 2026 17:02
The CUDAGraphWrapper captures pointers into the KV cache, but HY-WorldPlay
re-runs prefill_memory_kv_cache every chunk: it resets and repopulates each
PRoPE block's memory KV from a different FOV-selected frame set
(select_mem_frames_wan), reallocating the underlying storage. A graph
captured on one chunk then replays against another chunk's stale/freed
memory-KV slots, decoding to a deterministic "shatter" of speckle corruption
on the captured/replayed chunks (independent of seed and prompt). Set
use_cuda_graph=False so the rollout keeps torch.compile (Inductor) -- still
~5x faster diffuse than vendor -- without the unsafe replay.
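As a loose analogy for the failure mode described above, the toy below shows how a captured reference goes stale once the buffer is reallocated rather than updated in place. It is plain Python, not CUDA: `FakeGraph` and `MemoryKV` are hypothetical classes, and a real replay against freed device memory can read arbitrary data rather than the previous chunk's values.

```python
class FakeGraph:
    """Toy captured graph: remembers the buffer object seen at capture time,
    similar to how a CUDA graph bakes in device pointers."""

    def __init__(self, kv_buffer: list[int]):
        self._captured = kv_buffer

    def replay(self) -> int:
        return sum(self._captured)


class MemoryKV:
    def __init__(self) -> None:
        self.buffer: list[int] = []

    def prefill(self, frames: list[int]) -> None:
        # Allocates a new buffer each chunk instead of writing in place.
        self.buffer = list(frames)


kv = MemoryKV()
kv.prefill([1, 3])
graph = FakeGraph(kv.buffer)  # capture during the first chunk

kv.prefill([3, 5])            # next chunk reallocates the memory KV
stale = graph.replay()        # still reads the first chunk's buffer
eager = sum(kv.buffer)        # eager execution reads the current buffer

print(stale, eager)
```

Running without graph capture corresponds to the `eager` path here, which always reads the buffer that the latest prefill produced.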

Also unwrap torch.compile's OptimizedModule (_orig_mod) before the
HyWorldPlayWanDiTNetwork isinstance assert in prefill_memory_kv_cache; with
compile_network=True the wrapper otherwise trips the assert at the first
memory-prefill step (AR 1).

Parity vs vendor (8-chunk, matched prompt) improves 80.4 -> 44.9 / 255 as the
corruption is removed.

Test harness: keep HF_HOME local in bench.sh (avoids ~/.cache permission
errors on shared machines) and add torchvision to the vendor heavy-deps in
run.sh.

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Per-stage and per-input perf + parity (PR #231 format) for num_chunk=8,
pose=w-31, prompt "a person walking", over HY-WorldPlay/assets/img/
{1,2,5,6,10} on GB300. Native (use_cuda_graph=False after the memory-prefill
corruption fix) is 5.5x faster DiT diffuse and ~4x faster DiT+VAE per chunk
than vendor; median parity 36.0/255 with all chunks rendering cleanly.

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
@wenqingw-nv wenqingw-nv changed the title fix(hy_worldplay): RoPE memory frames at their true positions (fixes choppy native output) fix(hy_worldplay): RoPE memory frames at their true positions (fixes choppy native output) + Re-evaluated video bench perfs Jun 13, 2026
@wenqingw-nv wenqingw-nv changed the title fix(hy_worldplay): RoPE memory frames at their true positions (fixes choppy native output) + Re-evaluated video bench perfs fix(hy_worldplay): RoPE memory frames at their true positions (fixes choppy native output) + Re-evaluated perfs Jun 13, 2026
wenqingw-nv (Collaborator, Author) commented Jun 13, 2026

> Is this verified to fix the choppy results in #231?

Yes, it's fixed now.

Re-measured native-vs-vendor mean |Δ| with the VAE patchify channel-order
fix on the native leg (native regenerated, vendor reused unchanged). The
~2px decode checkerboard removal drops median parity 36.0 -> 24.0 / 255
(per-image 36->28, 28->23, 36->26, 21->17, 47->24). Perf/speedup table is
unchanged — the fix is a zero-cost einops axis swap.

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
@wenqingw-nv wenqingw-nv force-pushed the dev/wenqingw-nv/hy-worldplay-fix-memory-rope branch from d011203 to 02cead4 Compare June 15, 2026 05:15
wenqingw-nv (Collaborator, Author) commented

/ok to test 02cead4
