Skip to content

feat(wan-vae): native .pth state-dict transform for Wan 2.2 TI2V-5B VAE#223

Open
wenqingw-nv wants to merge 2 commits into
NVIDIA:mainfrom
wenqingw-nv:wenqing/hy-worldplay-vae-pth-transform
Open

feat(wan-vae): native .pth state-dict transform for Wan 2.2 TI2V-5B VAE#223
wenqingw-nv wants to merge 2 commits into
NVIDIA:mainfrom
wenqingw-nv:wenqing/hy-worldplay-vae-pth-transform

Conversation

@wenqingw-nv

Copy link
Copy Markdown
Collaborator

One follow-up item from #203 (HY-WorldPlay #155).

#155 review asked whether loading upstream's single-file Wan2.2_VAE.pth lets us drop the diffusers state_dict_transform.

Finding: it can't be dropped, but it shrinks a lot. The .pth uses Wan's native module layout (each down/up stage is a flat Sequential of residual blocks + a resample) while our WanVAE groups them (resnets.{j} + downsampler/upsampler). So it still needs a remap — 4 rules vs the ~50-rule diffusers remap. The earlier "missing params / meta tensors" was a misdiagnosis: all params are present under different key names; without a transform load_state_dict(assign=True) simply left our slots on meta.

  • Add wan22_ti2v_5b_vae_pth_state_dict_transform (4-rule remap). Verified 196↔196 key/shape bijection with the WanVAE module tree (no meta leftovers, no unexpected keys).
  • CPU tests: test_wan22_vae_pth_remap_is_full_bijection (builds the TI2V-5B WanVAE on meta, proves full coverage) + test_wan22_vae_pth_remap_spot_checks_real_keys (guards the regex against real .pth key strings).
  • Correct two stale comments that claimed the .pth matched our layout directly.

Opt-in: diffusers stays the production default; a follow-up flips the default + drops the ~50-rule remap once a GPU decode-parity smoke confirms identical output.

Part of #203.

Follow-up from PR #155 review (tracked in #203). Ruilong asked whether
loading upstream's single-file ``Wan2.2_VAE.pth`` lets us drop the
diffusers ``state_dict_transform``.

Finding: the transform can't be dropped, but it shrinks a lot. The
``.pth`` uses Wan's native module layout (each down/up stage is a flat
``Sequential`` of residual blocks + a resample), while our ``WanVAE``
groups them (``resnets.{j}`` + ``downsampler`` / ``upsampler``). So the
``.pth`` still needs a remap -- a 4-rule one vs the ~50-rule diffusers
remap. The earlier audit's "missing params / meta tensors" was a
misdiagnosis: the params are all present under different key names;
without a transform ``load_state_dict(assign=True)`` simply left our
slots on meta.

- Add ``wan22_ti2v_5b_vae_pth_state_dict_transform`` (+ its 4-rule
  remap). Verified 1:1 key/shape bijection with the WanVAE module tree
  (196<->196, no meta leftovers, no unexpected keys).
- CPU tests: ``test_wan22_vae_pth_remap_is_full_bijection`` (builds the
  TI2V-5B WanVAE on meta, proves full coverage) +
  ``test_wan22_vae_pth_remap_spot_checks_real_keys`` (guards the regex
  against real ``.pth`` key strings).
- Correct two stale comments that claimed the ``.pth`` matched our
  layout directly / needed no remap.

Kept as opt-in: diffusers stays the production default until a GPU
decode-parity smoke confirms the ``.pth`` weights decode identically.

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
@copy-pr-bot

copy-pr-bot Bot commented May 29, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Flip `Wan22TI2V5BVAE{Encoder,Decoder}Config` to load upstream's
canonical single-file `Wan2.2_VAE.pth` via the 4-rule
`wan22_ti2v_5b_vae_pth_state_dict_transform`, instead of the diffusers
safetensors + ~50-rule remap.

Verified safe at the weight level (no GPU smoke needed): the native
`.pth` + 4-rule transform and the diffusers safetensors + ~50-rule
transform map to bit-identical WanVAE weights -- 196/196 tensors,
max |Δ| = 0.0, both fp32. So this is a pure checkpoint-source change.

The diffusers path + its remap stay in tree as a documented opt-in
fallback (still re-exported). Adds `test_wan22_vae_pth_and_diffusers_
weights_identical` (manual) to lock in the equality.

Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
@wenqingw-nv

Copy link
Copy Markdown
Collaborator Author

/ok to test 8057598

@wenqingw-nv wenqingw-nv enabled auto-merge May 30, 2026 00:57
Comment on lines +1650 to +1670
_WAN22_TI2V_5B_VAE_PTH_KEY_REMAP: dict[str, str] = {
# Residual blocks: drop the inner ``downsamples`` / ``upsamples``
# container name, keep the per-block index, regroup under ``resnets``.
r"^encoder\.downsamples\.(\d+)\.downsamples\.(\d+)\.(residual|shortcut)\.(.*)$": (
r"encoder.downsamples.\1.resnets.\2.\3.\4"
),
r"^decoder\.upsamples\.(\d+)\.upsamples\.(\d+)\.(residual|shortcut)\.(.*)$": (
r"decoder.upsamples.\1.resnets.\2.\3.\4"
),
# Resample: the single resample per stage (it carries the
# ``resample`` / ``time_conv`` leaves, never ``residual``) folds into
# the stage's ``downsampler`` / ``upsampler``; its sub-block index is
# dropped since there is exactly one.
r"^encoder\.downsamples\.(\d+)\.downsamples\.\d+\.(resample|time_conv)\.(.*)$": (
r"encoder.downsamples.\1.downsampler.\2.\3"
),
r"^decoder\.upsamples\.(\d+)\.upsamples\.\d+\.(resample|time_conv)\.(.*)$": (
r"decoder.upsamples.\1.upsampler.\2.\3"
),
}

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we just need to name the component in the same way it is named in pth they it can load pth?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — the production VAE config now loads the native Wan2.2_VAE.pth directly, whose keys already match our WanVAE layout, so no remap runs.

@greptile-apps

greptile-apps Bot commented Jun 5, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR adds wan22_ti2v_5b_vae_pth_state_dict_transform, a 4-rule regex remap for Wan's native Wan2.2_VAE.pth checkpoint, and immediately flips the production defaults in Wan22TI2V5BVAEEncoderConfig and Wan22TI2V5BVAEDecoderConfig from the diffusers path to the native .pth path. Stale comments that claimed the .pth matched the model layout directly are corrected.

  • The 4-rule remap disambiguates residual blocks (residual|shortcut leaves → resnets.{j}) from the per-stage resample (resample|time_conv leaves → downsampler|upsampler) using leaf names rather than index position, which is robust to any upstream index ordering.
  • Two CPU tests verify a full 196-key bijection on meta tensors and spot-check against real .pth key strings; a @pytest.mark.manual parity test downloads both checkpoints and asserts max |Δ| = 0, confirming bit-identical weights.
  • The Note inside the new function's docstring says "the production configs keep diffusers as the default," but this PR changes those defaults to the .pth path, making the Note factually wrong immediately upon merge.

Confidence Score: 4/5

Safe to merge once the stale Note is updated; the remap logic and bijection tests are correct.

The remap logic is sound and well-tested, but the function docstring immediately contradicts the code: it claims diffusers remains the production default on the same commit that switches both config dataclasses to the .pth path. A second minor concern is a fragile len(raw) <= 4 guard in the manual test, but that test is not run in CI.

The stale Note in wan22_ti2v_5b_vae_pth_state_dict_transform (lines 1685-1688 of vae.py) is the only spot that needs a fix before merge.

Important Files Changed

Filename Overview
flashdreams/flashdreams/recipes/wan/autoencoder/vae.py Adds 4-rule PTH key remap, corrects stale comments, and flips production config defaults from diffusers to the native .pth path; one docstring Note is now factually wrong about which checkpoint is the production default.
flashdreams/tests/test_vae.py Adds bijection test (meta-only), spot-check test, and a manual parity test; fragile len(raw) <= 4 heuristic in the manual test could silently use the wrong dict if the .pth gains additional metadata keys.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["Wan2.2_VAE.pth\n(upstream native layout)"] --> B["wan22_ti2v_5b_vae_pth_state_dict_transform"]
    B --> C{"Key type\n(leaf name)"}
    C -- "residual | shortcut" --> D["encoder.downsamples.i.resnets.j.*\ndecoder.upsamples.i.resnets.j.*"]
    C -- "resample | time_conv" --> E["encoder.downsamples.i.downsampler.*\ndecoder.upsamples.i.upsampler.*"]
    C -- "no match\n(middle / head / conv1 / conv2)" --> F["pass-through unchanged"]
    D --> G["WanVAE state_dict\n(196 keys, full bijection)"]
    E --> G
    F --> G
    G --> H["load_state_dict(strict=False, assign=True)\nno meta leftovers"]
Loading

Reviews (1): Last reviewed commit: "feat(wan-vae): default Wan 2.2 TI2V-5B V..." | Re-trigger Greptile

Comment on lines +1685 to +1688
Note:
Decode-parity against the diffusers path is not yet verified on
GPU; until it is, the production configs keep diffusers as the
default.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Stale Note contradicts the production configs changed in this PR

The Note says "the production configs keep diffusers as the default", but lines 1712 and 1732 in this very PR change Wan22TI2V5BVAEEncoderConfig.checkpoint_path and Wan22TI2V5BVAEDecoderConfig.checkpoint_path to WAN22_TI2V_5B_VAE_PATH and wan22_ti2v_5b_vae_pth_state_dict_transform. A reader of this docstring will be told the opposite of what the dataclass defaults actually are.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines +285 to +287
from flashdreams.recipes.wan.autoencoder.vae import (
wan22_ti2v_5b_vae_pth_state_dict_transform,
wan22_ti2v_5b_vae_state_dict_transform,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Magic-number length guard may silently break on a deeper wrapper

The len(raw) <= 4 condition is meant to tell a top-level metadata dict from the real state dict, but the threshold is a hard-coded guess. If the upstream .pth ever adds even one extra metadata key (e.g. epoch, training step, or a second nested dict), len(raw) could exceed 4 while the actual weights are still under raw["state_dict"], causing the function to use the un-unwrapped metadata dict as from_pth. A single unconditional unwrap — check for "state_dict" once, then stop — is the conventional pattern and avoids this fragility.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants