feat(wan-vae): native .pth state-dict transform for Wan 2.2 TI2V-5B VAE#223
wenqingw-nv wants to merge 2 commits into
Follow-up from PR #155 review (tracked in #203). Ruilong asked whether
loading upstream's single-file ``Wan2.2_VAE.pth`` lets us drop the
diffusers ``state_dict_transform``.
Finding: the transform can't be dropped, but it shrinks a lot. The
``.pth`` uses Wan's native module layout (each down/up stage is a flat
``Sequential`` of residual blocks + a resample), while our ``WanVAE``
groups them (``resnets.{j}`` + ``downsampler`` / ``upsampler``). So the
``.pth`` still needs a remap -- a 4-rule one vs the ~50-rule diffusers
remap. The earlier audit's "missing params / meta tensors" was a
misdiagnosis: the params are all present under different key names;
without a transform ``load_state_dict(assign=True)`` simply left our
slots on meta.
- Add ``wan22_ti2v_5b_vae_pth_state_dict_transform`` (+ its 4-rule
remap). Verified 1:1 key/shape bijection with the WanVAE module tree
(196<->196, no meta leftovers, no unexpected keys).
- CPU tests: ``test_wan22_vae_pth_remap_is_full_bijection`` (builds the
TI2V-5B WanVAE on meta, proves full coverage) +
``test_wan22_vae_pth_remap_spot_checks_real_keys`` (guards the regex
against real ``.pth`` key strings).
- Correct two stale comments that claimed the ``.pth`` matched our
layout directly / needed no remap.
Kept as opt-in: diffusers stays the production default until a GPU
decode-parity smoke confirms the ``.pth`` weights decode identically.
Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
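For readers following along, the "196<->196, no meta leftovers, no unexpected keys" claim above amounts to a key/shape comparison between the remapped checkpoint and the model's own state dict. The following is a minimal sketch of such a check, not the PR's actual test: `check_bijection` is a hypothetical helper, the toy keys are illustrative, and shapes are plain tuples standing in for tensor shapes.

```python
def check_bijection(remapped_shapes, model_shapes):
    """Compare two {key: shape} maps and report every disagreement.

    remapped_shapes: shapes of the checkpoint tensors after key remapping.
    model_shapes:    shapes of the parameters/buffers the model expects.
    Hypothetical helper for illustration only.
    """
    ckpt_keys = set(remapped_shapes)
    model_keys = set(model_shapes)

    # Keys the model expects but the checkpoint never fills
    # (these are the slots that would stay on the meta device).
    missing = sorted(model_keys - ckpt_keys)
    # Keys the checkpoint provides but the model has no slot for.
    unexpected = sorted(ckpt_keys - model_keys)
    # Keys present on both sides whose shapes disagree.
    mismatched = sorted(
        k for k in ckpt_keys & model_keys
        if tuple(remapped_shapes[k]) != tuple(model_shapes[k])
    )
    return missing, unexpected, mismatched


# Toy example with made-up keys and shapes.
model = {"a.weight": (4, 4), "a.bias": (4,), "b.weight": (2, 4)}
ckpt = {"a.weight": (4, 4), "a.bias": (4,), "c.weight": (2, 4)}
missing, unexpected, mismatched = check_bijection(ckpt, model)
```

A full bijection corresponds to all three returned lists being empty; in the toy example `b.weight` is reported as missing and `c.weight` as unexpected.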
Flip `Wan22TI2V5BVAE{Encoder,Decoder}Config` to load upstream's
canonical single-file `Wan2.2_VAE.pth` via the 4-rule
`wan22_ti2v_5b_vae_pth_state_dict_transform`, instead of the diffusers
safetensors + ~50-rule remap.
Verified safe at the weight level (no GPU smoke needed): the native
`.pth` + 4-rule transform and the diffusers safetensors + ~50-rule
transform map to bit-identical WanVAE weights -- 196/196 tensors,
max |Δ| = 0.0, both fp32. So this is a pure checkpoint-source change.
The diffusers path + its remap stay in tree as a documented opt-in
fallback (still re-exported). Adds
`test_wan22_vae_pth_and_diffusers_weights_identical` (manual) to lock in
the equality.
Co-Authored-By: Claude Opus 4.8 (1M context) <[email protected]>
/ok to test 8057598
_WAN22_TI2V_5B_VAE_PTH_KEY_REMAP: dict[str, str] = {
    # Residual blocks: drop the inner ``downsamples`` / ``upsamples``
    # container name, keep the per-block index, regroup under ``resnets``.
    r"^encoder\.downsamples\.(\d+)\.downsamples\.(\d+)\.(residual|shortcut)\.(.*)$": (
        r"encoder.downsamples.\1.resnets.\2.\3.\4"
    ),
    r"^decoder\.upsamples\.(\d+)\.upsamples\.(\d+)\.(residual|shortcut)\.(.*)$": (
        r"decoder.upsamples.\1.resnets.\2.\3.\4"
    ),
    # Resample: the single resample per stage (it carries the
    # ``resample`` / ``time_conv`` leaves, never ``residual``) folds into
    # the stage's ``downsampler`` / ``upsampler``; its sub-block index is
    # dropped since there is exactly one.
    r"^encoder\.downsamples\.(\d+)\.downsamples\.\d+\.(resample|time_conv)\.(.*)$": (
        r"encoder.downsamples.\1.downsampler.\2.\3"
    ),
    r"^decoder\.upsamples\.(\d+)\.upsamples\.\d+\.(resample|time_conv)\.(.*)$": (
        r"decoder.upsamples.\1.upsampler.\2.\3"
    ),
}
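To make the effect of the four rules quoted above concrete, here is a minimal sketch of how such a regex remap could be applied to checkpoint keys. It is an assumption about usage, not the PR's actual transform: `remap_key` and `remap_state_dict` are hypothetical helpers, the first-match-wins policy is a guess, and the example keys are illustrative rather than copied from the real `.pth`.

```python
import re

# The four rules from the diff above, copied verbatim.
PTH_KEY_REMAP = {
    r"^encoder\.downsamples\.(\d+)\.downsamples\.(\d+)\.(residual|shortcut)\.(.*)$":
        r"encoder.downsamples.\1.resnets.\2.\3.\4",
    r"^decoder\.upsamples\.(\d+)\.upsamples\.(\d+)\.(residual|shortcut)\.(.*)$":
        r"decoder.upsamples.\1.resnets.\2.\3.\4",
    r"^encoder\.downsamples\.(\d+)\.downsamples\.\d+\.(resample|time_conv)\.(.*)$":
        r"encoder.downsamples.\1.downsampler.\2.\3",
    r"^decoder\.upsamples\.(\d+)\.upsamples\.\d+\.(resample|time_conv)\.(.*)$":
        r"decoder.upsamples.\1.upsampler.\2.\3",
}


def remap_key(key, rules=PTH_KEY_REMAP):
    """Rewrite one key with the first matching rule; pass it through otherwise."""
    for pattern, replacement in rules.items():
        new_key, count = re.subn(pattern, replacement, key)
        if count:
            return new_key
    return key  # e.g. middle / head / conv1 / conv2 keys


def remap_state_dict(state_dict, rules=PTH_KEY_REMAP):
    """Return a new dict with remapped keys, refusing silent collisions."""
    out = {}
    for key, value in state_dict.items():
        new_key = remap_key(key, rules)
        if new_key in out:
            # Two source keys landing on one target would drop a tensor.
            raise KeyError(f"remap collision on {new_key!r}")
        out[new_key] = value
    return out


# Illustrative keys (hypothetical, not taken from the real checkpoint).
residual = remap_key("encoder.downsamples.0.downsamples.1.residual.2.weight")
resample = remap_key("decoder.upsamples.2.upsamples.3.resample.1.weight")
untouched = remap_key("encoder.conv1.weight")
```

The collision check matters for the resample rules in particular: they drop the sub-block index, so a stage with more than one resample would otherwise overwrite a tensor without any error.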
I think we just need to name the components the same way they are named in the .pth, then it can load the .pth?
Done — the production VAE config now loads the native Wan2.2_VAE.pth directly, whose keys already match our WanVAE layout, so no remap runs.
Greptile Summary
This PR adds
Confidence Score: 4/5
Safe to merge once the stale Note is updated; the remap logic and bijection tests are correct.
The remap logic is sound and well-tested, but the function docstring immediately contradicts the code: it claims diffusers remains the production default on the same commit that switches both config dataclasses to the .pth path. A second minor concern is a fragile `len(raw) <= 4` guard in the manual test, but that test is not run in CI.
The stale Note in `wan22_ti2v_5b_vae_pth_state_dict_transform` (lines 1685-1688 of vae.py) is the only spot that needs a fix before merge.
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["Wan2.2_VAE.pth\n(upstream native layout)"] --> B["wan22_ti2v_5b_vae_pth_state_dict_transform"]
B --> C{"Key type\n(leaf name)"}
C -- "residual | shortcut" --> D["encoder.downsamples.i.resnets.j.*\ndecoder.upsamples.i.resnets.j.*"]
C -- "resample | time_conv" --> E["encoder.downsamples.i.downsampler.*\ndecoder.upsamples.i.upsampler.*"]
C -- "no match\n(middle / head / conv1 / conv2)" --> F["pass-through unchanged"]
D --> G["WanVAE state_dict\n(196 keys, full bijection)"]
E --> G
F --> G
G --> H["load_state_dict(strict=False, assign=True)\nno meta leftovers"]
Reviews (1): Last reviewed commit: "feat(wan-vae): default Wan 2.2 TI2V-5B V..."
Note:
    Decode-parity against the diffusers path is not yet verified on
    GPU; until it is, the production configs keep diffusers as the
    default.
Stale `Note` contradicts the production configs changed in this PR
The Note says "the production configs keep diffusers as the default", but lines 1712 and 1732 in this very PR change Wan22TI2V5BVAEEncoderConfig.checkpoint_path and Wan22TI2V5BVAEDecoderConfig.checkpoint_path to WAN22_TI2V_5B_VAE_PATH and wan22_ti2v_5b_vae_pth_state_dict_transform. A reader of this docstring will be told the opposite of what the dataclass defaults actually are.
from flashdreams.recipes.wan.autoencoder.vae import (
    wan22_ti2v_5b_vae_pth_state_dict_transform,
    wan22_ti2v_5b_vae_state_dict_transform,
Magic-number length guard may silently break on a deeper wrapper
The len(raw) <= 4 condition is meant to tell a top-level metadata dict from the real state dict, but the threshold is a hard-coded guess. If the upstream .pth ever adds even one extra metadata key (e.g. epoch, training step, or a second nested dict), len(raw) could exceed 4 while the actual weights are still under raw["state_dict"], causing the function to use the un-unwrapped metadata dict as from_pth. A single unconditional unwrap — check for "state_dict" once, then stop — is the conventional pattern and avoids this fragility.
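The single-unwrap pattern suggested above can be sketched as follows. This is a minimal illustration, assuming the checkpoint is either a flat key-to-tensor mapping or a wrapper dict holding the weights under a `"state_dict"` key; `unwrap_state_dict` is a hypothetical helper name and the toy payloads are made up.

```python
def unwrap_state_dict(raw):
    """Return the weights mapping from a loaded checkpoint object.

    Unwraps exactly once, and only when an explicit "state_dict" entry
    holding a mapping is present. The number of sibling metadata keys
    (epoch, step, ...) plays no role, so no length threshold is needed.
    Hypothetical helper for illustration only.
    """
    if isinstance(raw, dict) and isinstance(raw.get("state_dict"), dict):
        return raw["state_dict"]
    return raw


# Toy payloads: a wrapper carrying several metadata keys, and a flat dict.
wrapped = {
    "state_dict": {"encoder.conv1.weight": 0},
    "epoch": 3,
    "step": 1200,
    "optimizer": {},
    "config": {},
}
flat = {"encoder.conv1.weight": 0}

from_wrapped = unwrap_state_dict(wrapped)
from_flat = unwrap_state_dict(flat)
```

Because the decision depends only on the presence of the `"state_dict"` entry, the wrapper above is unwrapped correctly even though it has five top-level keys, which a `len(raw) <= 4` guard would have rejected.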
One follow-up item from #203 (HY-WorldPlay #155).
#155 review asked whether loading upstream's single-file `Wan2.2_VAE.pth` lets us drop the diffusers `state_dict_transform`.

Finding: it can't be dropped, but it shrinks a lot. The `.pth` uses Wan's native module layout (each down/up stage is a flat `Sequential` of residual blocks + a resample) while our `WanVAE` groups them (`resnets.{j}` + `downsampler` / `upsampler`). So it still needs a remap: 4 rules vs the ~50-rule diffusers remap. The earlier "missing params / meta tensors" was a misdiagnosis: all params are present under different key names; without a transform `load_state_dict(assign=True)` simply left our slots on meta.

- Add `wan22_ti2v_5b_vae_pth_state_dict_transform` (4-rule remap). Verified 196↔196 key/shape bijection with the WanVAE module tree (no meta leftovers, no unexpected keys).
- CPU tests: `test_wan22_vae_pth_remap_is_full_bijection` (builds the TI2V-5B WanVAE on meta, proves full coverage) + `test_wan22_vae_pth_remap_spot_checks_real_keys` (guards the regex against real `.pth` key strings).
- Correct stale comments that claimed the `.pth` matched our layout directly.

Opt-in: diffusers stays the production default; a follow-up flips the default + drops the ~50-rule remap once a GPU decode-parity smoke confirms identical output.
Part of #203.