Skip to content

feat(wan22): native DiT checkpoint path — drops the diffusers DiT remap#224

Open
wenqingw-nv wants to merge 3 commits into
NVIDIA:mainfrom
wenqingw-nv:wenqing/hy-worldplay-dit-native-ckpt
Open

feat(wan22): native DiT checkpoint path — drops the diffusers DiT remap#224
wenqingw-nv wants to merge 3 commits into
NVIDIA:mainfrom
wenqingw-nv:wenqing/hy-worldplay-dit-native-ckpt

Conversation

@wenqingw-nv

Copy link
Copy Markdown
Collaborator

One follow-up item from #203 (HY-WorldPlay #155).

Same question as the VAE follow-up, for the DiT remap (~25 rules).

Finding: the remap can be dropped entirely. Upstream's native Wan-AI/Wan2.2-TI2V-5B DiT checkpoint (sharded safetensors + index) uses byte-for-byte the WanDiTNetwork key names (blocks.N.self_attn.q, text_embedding.0, head.head ...), because our network was ported from native Wan. Verified an 825↔825 name-identical bijection, so it loads with state_dict_transform=None. load_checkpoint already handles the .safetensors.index.json shard format.

  • Add WAN22_TI2V_5B_DIT_NATIVE_PATH (+ docstring documenting the zero-remap finding); note the native alternative on the diffusers remap dict.
  • test_native_dit_checkpoint_needs_no_remap (manual; fetches the ~250 KB index, no weights) guards the key-identity claim.

Opt-in: diffusers stays the production default; a follow-up flips the default + deletes _WAN22_TI2V_5B_DIT_KEY_REMAP once a GPU decode-parity smoke confirms identical output. (Native DiT is sharded fp32 ~20 GB vs the diffusers single file — worth confirming download/merge cost in that smoke.)

Part of #203.

Follow-up from PR #155 review (tracked in #203). Ruilong asked whether
the DiT key remap could be simplified by loading a different checkpoint
(as with the VAE ``.pth``).

Finding: yes -- entirely. Upstream's *native* ``Wan-AI/Wan2.2-TI2V-5B``
DiT checkpoint (sharded safetensors + index) uses byte-for-byte the
``WanDiTNetwork`` key names (``blocks.N.self_attn.q``,
``text_embedding.0``, ``head.head`` ...), because our network was
ported from native Wan. Verified an 825<->825 name-identical bijection,
so it loads with ``state_dict_transform=None`` -- no analogue of the
~25-rule ``_WAN22_TI2V_5B_DIT_KEY_REMAP`` is needed. ``load_checkpoint``
already handles the ``.safetensors.index.json`` shard format.

- Add ``WAN22_TI2V_5B_DIT_NATIVE_PATH`` (+ docstring documenting the
  zero-remap finding) and export it.
- Note the native alternative on the diffusers remap dict / module.
- ``test_native_dit_checkpoint_needs_no_remap`` (manual; fetches the
  ~250 KB index, no weights) guards the key-identity claim.

Kept opt-in: diffusers stays the production default until a GPU
decode-parity smoke confirms the two checkpoints decode identically.

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
@copy-pr-bot

copy-pr-bot Bot commented May 29, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

wenqingw-nv and others added 2 commits May 29, 2026 23:35
…aming

GPU-box verification of the native-DiT finding: the native
`Wan-AI/Wan2.2-TI2V-5B` DiT and the diffusers checkpoint (after the
~25-rule remap) carry bit-identical weights -- 825/825 fp32 tensors,
max |Δ| = 0.0. Adds `test_native_dit_matches_diffusers_weights` (manual)
to lock that in.

Correct the earlier framing: the remap can NOT be deleted. The HY
distilled checkpoint ships in diffusers-key format and routes through
`wan22_ti2v_5b_dit_state_dict_transform` (hy_worldplay/_checkpoint.py),
so the remap is load-bearing. The native path is a proven-equivalent,
simpler source for the *base* (un-distilled) Wan 2.2 pipeline only; kept
opt-in (diffusers default avoids the larger sharded fp32 download).

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
`WAN22_TI2V_5B_DIT_DIFFUSERS_PATH` pointed at a single-file
`transformer/diffusion_pytorch_model.safetensors` that does not exist —
the diffusers repo ships 5 shards + a `.safetensors.index.json`, so the
bare-filename URL returns 404 and any base-pipeline DiT load fails.
(Went unnoticed because GPU CI always loads via `--ckpt-path`/distilled,
never the base diffusers DiT.) Point the constant at the index;
`load_checkpoint` resolves shards from it. Found while running the
`--example-data` rollout to verify the pose follow-up.

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
@wenqingw-nv

Copy link
Copy Markdown
Collaborator Author

/ok to test be70dbe

@wenqingw-nv wenqingw-nv enabled auto-merge May 30, 2026 00:56

@liruilong940607 liruilong940607 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have any test to verify the output is the same from two different way of loading the checkpoint?

Comment on lines +62 to +68
The ``transformer/`` subfolder ships 5 shards + a
``.safetensors.index.json``; there is **no** single-file
``diffusion_pytorch_model.safetensors`` (that bare-filename URL 404s --
it was the prior value of this constant, which broke any base-pipeline
load). ``load_checkpoint`` resolves the index directly. Loads via
:func:`wan22_ti2v_5b_dit_state_dict_transform` (the diffusers naming
differs from ours). This is the production default."""

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you do a pass in a separate PR later on to clean up all those chain of thoughts comment / doc strings from the agent? Not only in this file but also many places in the HY/wan22 integration.

Imagine you are a user reading the code and these doc strings, Many details here are not needed to be known for the user. Only be explainable enough for user to know how to use. The history of how we get to here doesn't need to be marked.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in a separate PR: #314. It trims the agent's chain-of-thought and historical narration from comments/docstrings across the HY-WorldPlay + wan22 integration, leaving tight user-facing docs and one-line correctness notes.

@greptile-apps

greptile-apps Bot commented Jun 5, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR fixes a broken diffusers DiT URL (was 404-ing on a missing single-file safetensors) and introduces WAN22_TI2V_5B_DIT_NATIVE_PATH as an opt-in zero-remap alternative that loads the upstream native Wan-AI/Wan2.2-TI2V-5B checkpoint directly, backed by a verified 825↔825 key-identity bijection. The diffusers path and its remap remain the production default.

  • config.py: corrects WAN22_TI2V_5B_DIT_DIFFUSERS_PATH to point at the sharded-index JSON (the bare .safetensors URL 404ed), adds the new WAN22_TI2V_5B_DIT_NATIVE_PATH constant with thorough documentation, and annotates _WAN22_TI2V_5B_DIT_KEY_REMAP explaining why it is still needed for the diffusers path and the HY distilled checkpoint.
  • test_dit_remap.py: two @pytest.mark.manual tests — one fetches only the ~250 KB index to assert key identity, the other downloads both full checkpoints (~40 GB) to assert bit-identical weights.

Confidence Score: 4/5

The config change is a straightforward bug-fix plus an additive constant; the new test file has one logic gap worth addressing before the manual test is relied upon.

The weight-parity test can silently pass the key-set equality check and then crash with a confusing ValueError if no shard files are globbed — an easy mistake to hit if HuggingFace renames the files or the download path changes. The config-side changes are clean: fixing a 404 URL and adding a well-documented opt-in constant with no production default change.

integrations/wan22/tests/test_dit_remap.py — specifically the max() call in test_native_dit_matches_diffusers_weights needs empty-result guards.

Important Files Changed

Filename Overview
integrations/wan22/wan22/config.py Fixes WAN22_TI2V_5B_DIT_DIFFUSERS_PATH (was pointing to a 404 bare safetensors URL) to the correct sharded-index URL, adds WAN22_TI2V_5B_DIT_NATIVE_PATH as an opt-in zero-remap alternative, and documents why _WAN22_TI2V_5B_DIT_KEY_REMAP must stay.
integrations/wan22/tests/test_dit_remap.py Adds two manual tests guarding the native-checkpoint key-identity and weight-parity claims. The weight-parity test can silently pass the set-equality check and then crash with a confusing ValueError when no shards are globbed; explicit non-empty guards are missing before the max() call.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[load_checkpoint call] --> B{Which path?}
    B -->|WAN22_TI2V_5B_DIT_DIFFUSERS_PATH production default| C[Fetch sharded index JSON Wan-AI/Wan2.2-TI2V-5B-Diffusers]
    B -->|WAN22_TI2V_5B_DIT_NATIVE_PATH opt-in| D[Fetch sharded index JSON Wan-AI/Wan2.2-TI2V-5B]
    C --> E[Download 5 diffusers shards bf16 ~5 GB]
    D --> F[Download 3 native shards fp32 ~20 GB]
    E --> G[wan22_ti2v_5b_dit_state_dict_transform _WAN22_TI2V_5B_DIT_KEY_REMAP applied]
    F --> H[state_dict_transform=None keys already match WanDiTNetwork]
    G --> I[WanDiTNetwork loaded]
    H --> I
    I --> J[Bit-identical weights verified max delta = 0]
Loading

Reviews (1): Last reviewed commit: "fix(wan22): diffusers DiT default 404s —..." | Re-trigger Greptile

Comment on lines +101 to +107
assert set(from_diffusers) == set(native), "remapped key sets differ"
worst = max(
(from_diffusers[k].float() - native[k].float()).abs().max().item()
for k in from_diffusers
)
assert worst == 0.0, (
f"DiT weights differ between checkpoints (max |delta| = {worst})"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Silent pass + ValueError when no shards are found

If glob.glob resolves zero files (e.g., the allow_patterns naming changes on HuggingFace or the download lands in an unexpected directory), both diff_raw and native stay as empty dicts. wan22_ti2v_5b_dit_state_dict_transform({}) returns {}, so set(from_diffusers) == set(native) passes silently (set() == set()), and then max((...) for k in from_diffusers) raises ValueError: max() arg is an empty sequence instead of a helpful assertion message. Add an explicit guard before the max() call, e.g. assert from_diffusers, "no diffusers shards loaded" and assert native, "no native shards loaded".

Comment on lines +53 to +59
missing = model_keys - native_keys # would load onto meta -> .to() raises
extra = native_keys - model_keys # unexpected keys
assert not missing, (
f"{len(missing)} model params absent from native ckpt: {sorted(missing)[:5]}"
)
assert not extra, f"{len(extra)} native keys not in the model: {sorted(extra)[:5]}"
assert native_keys == model_keys

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 The third assertion is logically redundant: not missing and not extra together are equivalent to native_keys == model_keys (both sets have zero symmetric difference), so this assert can never fire independently.

Suggested change
missing = model_keys - native_keys # would load onto meta -> .to() raises
extra = native_keys - model_keys # unexpected keys
assert not missing, (
f"{len(missing)} model params absent from native ckpt: {sorted(missing)[:5]}"
)
assert not extra, f"{len(extra)} native keys not in the model: {sorted(extra)[:5]}"
assert native_keys == model_keys
missing = model_keys - native_keys # would load onto meta -> .to() raises
extra = native_keys - model_keys # unexpected keys
assert not missing, (
f"{len(missing)} model params absent from native ckpt: {sorted(missing)[:5]}"
)
assert not extra, f"{len(extra)} native keys not in the model: {sorted(extra)[:5]}"

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

wenqingw-nv added a commit to wenqingw-nv/flashdreams-wq that referenced this pull request Jun 9, 2026
)

Address Ruilong's PR NVIDIA#224 review: clean up the agent's chain-of-thought
and historical narration from comments and docstrings across the HY-
WorldPlay and wan22 integration, leaving tight, user-facing docs plus
one-line correctness invariants. Comments/docstrings only -- no
executable code changes (AST-verified).

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
wenqingw-nv added a commit to wenqingw-nv/flashdreams-wq that referenced this pull request Jun 9, 2026
Answer Ruilong's PR NVIDIA#224 question ("any test verifying the two load
paths match?"). The TI2V-5B DiT can be loaded from the diffusers port
(remapped) or the native checkpoint (identity); this manual test loads
both, applies each transform, and asserts identical key sets and
max|delta| == 0 per tensor -- identical weights => identical output.
Marked `manual` since it downloads both checkpoints.

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
wenqingw-nv added a commit to wenqingw-nv/flashdreams-wq that referenced this pull request Jun 9, 2026
Address Ruilong's PR NVIDIA#227: inference should load the HY-WorldPlay
checkpoint directly, not the Wan 2.2 base. The static pipeline now
defaults its transformer checkpoint to the distilled WAN-5B weights
(HF `tencent/HY-WorldPlay`, via `hy_worldplay_distilled_state_dict_transform`)
instead of inheriting the base diffusers safetensors. `--ckpt-path`
becomes a local override rather than the only way to get HY weights.

Updates the smoke tests and README to the new default. The base-
checkpoint identity path still works via `--ckpt-path`. (Also carries
the PR NVIDIA#224 doc cleanup for these three modules.)

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
@wenqingw-nv

Copy link
Copy Markdown
Collaborator Author

Do we have any test to verify the output is the same from two different way of loading the checkpoint?

Yes — test_native_dit_matches_diffusers_weights in integrations/wan22/tests/test_dit_remap.py loads both paths (native + identity remap, and diffusers + remap) and asserts max |Δ| == 0 across all 825 DiT tensors. Identical weights ⇒ identical output.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants