feat(wan22): native DiT checkpoint path — drops the diffusers DiT remap#224
wenqingw-nv wants to merge 3 commits into
Follow-up from PR #155 review (tracked in #203). Ruilong asked whether the DiT key remap could be simplified by loading a different checkpoint (as with the VAE ``.pth``).

Finding: yes -- entirely. Upstream's *native* ``Wan-AI/Wan2.2-TI2V-5B`` DiT checkpoint (sharded safetensors + index) uses byte-for-byte the ``WanDiTNetwork`` key names (``blocks.N.self_attn.q``, ``text_embedding.0``, ``head.head`` ...), because our network was ported from native Wan. Verified an 825<->825 name-identical bijection, so it loads with ``state_dict_transform=None`` -- no analogue of the ~25-rule ``_WAN22_TI2V_5B_DIT_KEY_REMAP`` is needed. ``load_checkpoint`` already handles the ``.safetensors.index.json`` shard format.

- Add ``WAN22_TI2V_5B_DIT_NATIVE_PATH`` (+ docstring documenting the zero-remap finding) and export it.
- Note the native alternative on the diffusers remap dict / module.
- ``test_native_dit_checkpoint_needs_no_remap`` (manual; fetches the ~250 KB index, no weights) guards the key-identity claim.

Kept opt-in: diffusers stays the production default until a GPU decode-parity smoke confirms the two checkpoints decode identically.

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
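The key-identity check described in this commit can be sketched from the index file alone. This is a minimal illustration, not the test in the PR: it assumes the usual sharded-safetensors index layout (a top-level `weight_map` from tensor name to shard file), and the helper names, toy index, and shard filenames are all made up for the example.

```python
from __future__ import annotations

import json


def checkpoint_keys_from_index(index_text: str) -> set[str]:
    """Return the tensor names listed in a sharded-safetensors index file."""
    index = json.loads(index_text)
    # Assumed layout: "weight_map" maps each tensor name to its shard file.
    return set(index["weight_map"])


def key_mismatch(model_keys: set[str], ckpt_keys: set[str]) -> tuple[list[str], list[str]]:
    """Return (missing, extra): model keys absent from the checkpoint, and the reverse."""
    return sorted(model_keys - ckpt_keys), sorted(ckpt_keys - model_keys)


# Toy index standing in for the real ~250 KB file (shard names are hypothetical).
toy_index = json.dumps({
    "metadata": {"total_size": 0},
    "weight_map": {
        "blocks.0.self_attn.q.weight": "shard-00001-of-00002.safetensors",
        "text_embedding.0.weight": "shard-00001-of-00002.safetensors",
        "head.head.weight": "shard-00002-of-00002.safetensors",
    },
})
model_keys = {
    "blocks.0.self_attn.q.weight",
    "text_embedding.0.weight",
    "head.head.weight",
}

missing, extra = key_mismatch(model_keys, checkpoint_keys_from_index(toy_index))
print("missing:", missing, "extra:", extra)
```

Because only names are compared, a check like this needs no weight download, which is why the commit can keep it cheap.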
…aming

GPU-box verification of the native-DiT finding: the native `Wan-AI/Wan2.2-TI2V-5B` DiT and the diffusers checkpoint (after the ~25-rule remap) carry bit-identical weights -- 825/825 fp32 tensors, max |Δ| = 0.0. Adds `test_native_dit_matches_diffusers_weights` (manual) to lock that in.

Correct the earlier framing: the remap can NOT be deleted. The HY distilled checkpoint ships in diffusers-key format and routes through `wan22_ti2v_5b_dit_state_dict_transform` (hy_worldplay/_checkpoint.py), so the remap is load-bearing. The native path is a proven-equivalent, simpler source for the *base* (un-distilled) Wan 2.2 pipeline only; kept opt-in (diffusers default avoids the larger sharded fp32 download).

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
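For readers unfamiliar with what a rule-based key remap does, here is a rough sketch. The two rules and the diffusers-side key names are hypothetical, chosen only for illustration; they are not entries from `_WAN22_TI2V_5B_DIT_KEY_REMAP`, and `remap_state_dict` is not the project's `wan22_ti2v_5b_dit_state_dict_transform`.

```python
import re

# Hypothetical rules for illustration only; NOT the project's actual remap table.
TOY_REMAP = [
    (re.compile(r"^blocks\.(\d+)\.attn1\.to_q\."), r"blocks.\1.self_attn.q."),
    (re.compile(r"^proj_out\."), "head.head."),
]


def remap_key(key: str) -> str:
    """Apply the first matching rule; keys that match no rule pass through unchanged."""
    for pattern, replacement in TOY_REMAP:
        new_key, count = pattern.subn(replacement, key)
        if count:
            return new_key
    return key


def remap_state_dict(state_dict: dict) -> dict:
    """Rename every key, refusing to silently overwrite on a name collision."""
    out = {}
    for key, value in state_dict.items():
        new_key = remap_key(key)
        if new_key in out:
            raise KeyError(f"remap collision on {new_key!r}")
        out[new_key] = value
    return out


remapped = remap_state_dict({
    "blocks.3.attn1.to_q.weight": 1,
    "proj_out.bias": 2,
    "untouched.key": 3,
})
print(sorted(remapped))
```

A checkpoint whose keys already match the network needs no such pass, which is the sense in which the native path loads with `state_dict_transform=None`.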
`WAN22_TI2V_5B_DIT_DIFFUSERS_PATH` pointed at a single-file `transformer/diffusion_pytorch_model.safetensors` that does not exist: the diffusers repo ships 5 shards + a `.safetensors.index.json`, so the bare-filename URL returns 404 and any base-pipeline DiT load fails. (Went unnoticed because GPU CI always loads via `--ckpt-path`/distilled, never the base diffusers DiT.)

Point the constant at the index; `load_checkpoint` resolves shards from it.

Found while running the `--example-data` rollout to verify the pose follow-up.

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
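To see why pointing the constant at the index is enough, here is a sketch of how shard locations can be derived from an index URL. It assumes the usual `weight_map` layout and that shards sit next to the index; the URL, shard names, and helper functions are hypothetical, and this is not a description of how `load_checkpoint` is implemented.

```python
from __future__ import annotations

import json
import posixpath


def shard_files(index_text: str) -> list[str]:
    """Return the distinct shard filenames referenced by the index, sorted."""
    index = json.loads(index_text)
    return sorted(set(index["weight_map"].values()))


def shard_urls(index_url: str, index_text: str) -> list[str]:
    """Resolve each shard relative to the directory that holds the index."""
    base = posixpath.dirname(index_url)
    return [posixpath.join(base, name) for name in shard_files(index_text)]


# Toy index with made-up shard names; the real repo ships 5 shards.
toy_index = json.dumps({
    "weight_map": {
        "a.weight": "part-00001-of-00002.safetensors",
        "b.weight": "part-00001-of-00002.safetensors",
        "c.weight": "part-00002-of-00002.safetensors",
    },
})
index_url = "https://example.invalid/repo/transformer/model.safetensors.index.json"

for url in shard_urls(index_url, toy_index):
    print(url)
```

The index is the only file whose name must be known up front; every shard name comes from its contents.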
/ok to test be70dbe
liruilong940607 left a comment
Do we have any test to verify that the output is the same for the two different ways of loading the checkpoint?
    The ``transformer/`` subfolder ships 5 shards + a
    ``.safetensors.index.json``; there is **no** single-file
    ``diffusion_pytorch_model.safetensors`` (that bare-filename URL 404s --
    it was the prior value of this constant, which broke any base-pipeline
    load). ``load_checkpoint`` resolves the index directly. Loads via
    :func:`wan22_ti2v_5b_dit_state_dict_transform` (the diffusers naming
    differs from ours). This is the production default."""
Can you do a pass in a separate PR later to clean up all the chain-of-thought comments and docstrings from the agent? Not only in this file but also in many other places in the HY/wan22 integration.
Imagine you are a user reading the code and these docstrings. Many of the details here are not something the user needs to know. The docs only need to explain enough for the user to know how to use the code; the history of how we got here doesn't need to be recorded.
Done in a separate PR: #314. It trims the agent's chain-of-thought and historical narration from comments/docstrings across the HY-WorldPlay + wan22 integration, leaving tight user-facing docs and one-line correctness notes.
Greptile Summary
This PR fixes a broken diffusers DiT URL (was 404-ing on a missing single-file safetensors) and introduces

Confidence Score: 4/5
The config change is a straightforward bug-fix plus an additive constant; the new test file has one logic gap worth addressing before the manual test is relied upon.
The weight-parity test can silently pass the key-set equality check and then crash with a confusing ValueError if no shard files are globbed, an easy mistake to hit if HuggingFace renames the files or the download path changes. The config-side changes are clean: fixing a 404 URL and adding a well-documented opt-in constant with no production default change.
integrations/wan22/tests/test_dit_remap.py: specifically, the max() call in test_native_dit_matches_diffusers_weights needs empty-result guards.
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[load_checkpoint call] --> B{Which path?}
B -->|WAN22_TI2V_5B_DIT_DIFFUSERS_PATH production default| C[Fetch sharded index JSON Wan-AI/Wan2.2-TI2V-5B-Diffusers]
B -->|WAN22_TI2V_5B_DIT_NATIVE_PATH opt-in| D[Fetch sharded index JSON Wan-AI/Wan2.2-TI2V-5B]
C --> E[Download 5 diffusers shards bf16 ~5 GB]
D --> F[Download 3 native shards fp32 ~20 GB]
E --> G[wan22_ti2v_5b_dit_state_dict_transform _WAN22_TI2V_5B_DIT_KEY_REMAP applied]
F --> H[state_dict_transform=None keys already match WanDiTNetwork]
G --> I[WanDiTNetwork loaded]
H --> I
I --> J[Bit-identical weights verified max delta = 0]
Reviews (1): Last reviewed commit: "fix(wan22): diffusers DiT default 404s —..."
    assert set(from_diffusers) == set(native), "remapped key sets differ"
    worst = max(
        (from_diffusers[k].float() - native[k].float()).abs().max().item()
        for k in from_diffusers
    )
    assert worst == 0.0, (
        f"DiT weights differ between checkpoints (max |delta| = {worst})"
Silent pass + ValueError when no shards are found
If `glob.glob` resolves zero files (e.g., the `allow_patterns` naming changes on HuggingFace or the download lands in an unexpected directory), both `diff_raw` and `native` stay as empty dicts. `wan22_ti2v_5b_dit_state_dict_transform({})` returns `{}`, so `set(from_diffusers) == set(native)` passes silently (`set() == set()`), and then `max((...) for k in from_diffusers)` raises `ValueError: max() arg is an empty sequence` instead of a helpful assertion message. Add an explicit guard before the `max()` call, e.g. `assert from_diffusers, "no diffusers shards loaded"` and `assert native, "no native shards loaded"`.
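The suggested guard can be sketched in isolation. To stay runnable without PyTorch, this example uses plain lists of floats as stand-ins for tensors; `max_abs_delta` is a hypothetical helper, not a function from the PR, and the real test would compare tensors instead.

```python
from __future__ import annotations


def max_abs_delta(from_diffusers: dict[str, list[float]],
                  native: dict[str, list[float]]) -> float:
    """Largest element-wise |a - b| across all tensors, with empty-input guards."""
    # Fail loudly on empty inputs before the key-set check can pass vacuously.
    assert from_diffusers, "no diffusers shards loaded"
    assert native, "no native shards loaded"
    assert set(from_diffusers) == set(native), "remapped key sets differ"
    return max(
        (
            abs(x - y)
            for key in from_diffusers
            for x, y in zip(from_diffusers[key], native[key])
        ),
        default=0.0,  # covers the case where every tensor is empty
    )


same = max_abs_delta({"w": [1.0, 2.0]}, {"w": [1.0, 2.0]})
print(same)

try:
    max_abs_delta({}, {})
except AssertionError as err:
    print("guard fired:", err)
```

With the guards in place, an empty download surfaces as a named assertion rather than an exception raised from inside `max()`.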
    missing = model_keys - native_keys  # would load onto meta -> .to() raises
    extra = native_keys - model_keys  # unexpected keys
    assert not missing, (
        f"{len(missing)} model params absent from native ckpt: {sorted(missing)[:5]}"
    )
    assert not extra, f"{len(extra)} native keys not in the model: {sorted(extra)[:5]}"
    assert native_keys == model_keys
The third assertion is logically redundant: `not missing` and `not extra` together are equivalent to `native_keys == model_keys` (both sets have zero symmetric difference), so this assert can never fire independently.
Suggested change:
     missing = model_keys - native_keys  # would load onto meta -> .to() raises
     extra = native_keys - model_keys  # unexpected keys
     assert not missing, (
         f"{len(missing)} model params absent from native ckpt: {sorted(missing)[:5]}"
     )
     assert not extra, f"{len(extra)} native keys not in the model: {sorted(extra)[:5]}"
-    assert native_keys == model_keys
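The redundancy claim is easy to confirm by brute force on a small universe of keys. The sketch below is only an illustration of the set identity; the function names and the three-key universe are made up for the example.

```python
from itertools import combinations


def subsets(universe):
    """Yield every subset of the given collection as a set."""
    items = sorted(universe)
    for size in range(len(items) + 1):
        for combo in combinations(items, size):
            yield set(combo)


def equivalence_holds(universe) -> bool:
    """Check that (no missing and no extra) agrees with set equality for all pairs."""
    for model_keys in subsets(universe):
        for native_keys in subsets(universe):
            missing = model_keys - native_keys
            extra = native_keys - model_keys
            if (not missing and not extra) != (native_keys == model_keys):
                return False
    return True


print(equivalence_holds({"q", "k", "v"}))
```

Keeping the two difference-based asserts is still useful, since they report which keys differ, while the equality assert adds no new failure case.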
)

Address Ruilong's PR NVIDIA#224 review: clean up the agent's chain-of-thought and historical narration from comments and docstrings across the HY-WorldPlay and wan22 integration, leaving tight, user-facing docs plus one-line correctness invariants.

Comments/docstrings only -- no executable code changes (AST-verified).

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>

Answer Ruilong's PR NVIDIA#224 question ("any test verifying the two load paths match?"). The TI2V-5B DiT can be loaded from the diffusers port (remapped) or the native checkpoint (identity); this manual test loads both, applies each transform, and asserts identical key sets and max|delta| == 0 per tensor -- identical weights => identical output. Marked `manual` since it downloads both checkpoints.

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>

Address Ruilong's PR NVIDIA#227: inference should load the HY-WorldPlay checkpoint directly, not the Wan 2.2 base. The static pipeline now defaults its transformer checkpoint to the distilled WAN-5B weights (HF `tencent/HY-WorldPlay`, via `hy_worldplay_distilled_state_dict_transform`) instead of inheriting the base diffusers safetensors. `--ckpt-path` becomes a local override rather than the only way to get HY weights.

Updates the smoke tests and README to the new default. The base-checkpoint identity path still works via `--ckpt-path`. (Also carries the PR NVIDIA#224 doc cleanup for these three modules.)

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
Yes: `test_native_dit_matches_diffusers_weights` in `integrations/wan22/tests/test_dit_remap.py` loads both paths (native + identity remap, and diffusers + remap) and asserts max |Δ| == 0 across all 825 DiT tensors. Identical weights ⇒ identical output.
One follow-up item from #203 (HY-WorldPlay #155).
Same question as the VAE follow-up, for the DiT remap (~25 rules).
Finding: the remap can be dropped entirely. Upstream's native `Wan-AI/Wan2.2-TI2V-5B` DiT checkpoint (sharded safetensors + index) uses byte-for-byte the `WanDiTNetwork` key names (`blocks.N.self_attn.q`, `text_embedding.0`, `head.head` ...), because our network was ported from native Wan. Verified an 825↔825 name-identical bijection, so it loads with `state_dict_transform=None`. `load_checkpoint` already handles the `.safetensors.index.json` shard format.

- `WAN22_TI2V_5B_DIT_NATIVE_PATH` (+ docstring documenting the zero-remap finding); note the native alternative on the diffusers remap dict.
- `test_native_dit_checkpoint_needs_no_remap` (manual; fetches the ~250 KB index, no weights) guards the key-identity claim.

Opt-in: diffusers stays the production default; a follow-up flips the default + deletes `_WAN22_TI2V_5B_DIT_KEY_REMAP` once a GPU decode-parity smoke confirms identical output. (Native DiT is sharded fp32 ~20 GB vs the diffusers single file; worth confirming download/merge cost in that smoke.)

Part of #203.