Add synthetic interactive-drive world model #337
/ok to test 055fa38
Greptile Summary
This PR introduces a

Confidence Score: 4/5
Safe to merge for the recommended workflow (synthetic YAML + disabled fp8); the fp8+synthetic combination reachable via --manifest perf.yaml --synthetic-model is untested and could fail at runtime.

The entire new code path for native_vae_fp8=True in _build_native_vae_fp8_state is exercised by no test. It accesses encoder.vae on the object returned by WanVAEEncoderConfig.setup(), which is an undocumented structural assumption, and would produce an AttributeError the first time someone runs --manifest example_world_model_perf.yaml --synthetic-model against an fp8 manifest. All other changes are well-tested and prior review concerns have been addressed.

integrations/omnidreams/omnidreams/interactive_drive/world_model/synthetic_fixture.py: specifically _build_native_vae_fp8_state and _load_lightvae_fp8_export_module
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
CLI["CLI / YAML manifest\n--synthetic-model"] --> LoadManifest["load_world_model_manifest()\nWorldModelManifest.synthetic_model=True"]
LoadManifest --> BuildPipelineCfg["_build_pipeline_config()"]
BuildPipelineCfg --> ApplyOverrides["_apply_synthetic_model_overrides()"]
ApplyOverrides --> BuildAssets["build_synthetic_world_model_assets()"]
BuildAssets --> WanEncoder["_maybe_build_wan_encoder_checkpoint()\n-> synthetic_vae_encoder_XXXX.safetensors"]
BuildAssets --> Decoder["_build_decoder_checkpoint()\nTAEHV or WanVAE\n-> synthetic_*_decoder_XXXX.safetensors"]
BuildAssets --> FP8{native_vae_fp8?}
FP8 -->|Yes| FP8State["_build_native_vae_fp8_state()\nencoder.vae <- UNTESTED PATH\n-> synthetic_lightvae_fp8_state_XXXX.pt"]
FP8 -->|No| SkipFP8[skip]
ApplyOverrides --> DeriveConfig["derive_config()\ncheckpoint_path = synthetic path\ntext_encoder = None\nimage_encoder = None\ndiffusion_model.transformer.checkpoint_path = None\nsynthetic_text_max_length = N"]
DeriveConfig --> Warmup["warmup_model()\nbuilds pipeline eagerly"]
Warmup --> PrepareScene["prepare_for_scene() -> no-op"]
PrepareScene --> InitCache["_initialize_cache()\nsynthetic branch"]
InitCache --> SynthEmb["_synthetic_embeddings_for_pipeline()\nzero text + image tensors"]
SynthEmb --> CacheFromEmb["initialize_cache_from_embeddings()"]
CacheFromEmb --> Generate["generate() x N chunks\n-> latency measurement"]
Reviews (6): Last reviewed commit: "Add synthetic interactive-drive world mo..."
    assert config.encoder is not None, "synthetic Omnidreams config requires an encoder"
    assert config.decoder is not None, "synthetic Omnidreams config requires a decoder"
assert guards disabled under -O, giving a confusing downstream error
Both checks are removed entirely when Python is launched with -O (common for optimized deployments). If config.encoder or config.decoder is None, execution continues into build_synthetic_world_model_assets, where it passes None as encoder_cfg/decoder_cfg, eventually failing with a confusing AttributeError or isinstance mismatch instead of the clear message written here. Replace both assert statements with explicit if/raise ValueError guards.
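A minimal sketch of the explicit guards described here, so the check still runs under `python -O`. The helper name `require_encoder_and_decoder` and the `SimpleNamespace` config are hypothetical stand-ins; only the two error messages are taken from the PR.

```python
from types import SimpleNamespace
from typing import Any


def require_encoder_and_decoder(config: Any) -> None:
    """Hypothetical helper: validate with checks that are not stripped by -O."""
    if getattr(config, "encoder", None) is None:
        raise ValueError("synthetic Omnidreams config requires an encoder")
    if getattr(config, "decoder", None) is None:
        raise ValueError("synthetic Omnidreams config requires a decoder")


# Stand-in config object, not the project's real config class.
incomplete = SimpleNamespace(encoder=object(), decoder=None)
try:
    require_encoder_and_decoder(incomplete)
    message = None
except ValueError as exc:
    message = str(exc)
```

Unlike `assert`, the `if`/`raise` form is ordinary control flow, so the clear message is raised at the validation site regardless of interpreter flags.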
check if all issues on this PR are resolved now
    def _config_uses_lightvae(config: Any) -> bool:
        """Detect the lightvae variant without assuming ``checkpoint_path`` is set."""

        return "lightvae" in (getattr(config, "checkpoint_path", None) or "")
_config_uses_lightvae silently mis-detects architecture when checkpoint path lacks the substring
The docstring says "without assuming checkpoint_path is set," but when the attribute is absent or None the implementation falls back to "" and therefore always returns False. More critically, if a production config stores the encoder under a path that doesn't contain the literal substring "lightvae" (e.g., a URL like hf://nvidia/flashdreams-vae/light_encoder.safetensors), the function returns False, _new_wan_vae_from_encoder_config instantiates the standard-VAE architecture, and the emitted state-dict will have mismatched keys/shapes relative to what the lightvae loader expects. This surfaces as a load error only at warmup time. Consider using a dedicated boolean attribute on WanVAEEncoderConfig to identify the variant instead of a path substring match.
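One way the dedicated-attribute idea could look, as a sketch only. `EncoderConfigSketch` and its `use_lightvae` field are hypothetical; the real WanVAEEncoderConfig may have different fields, and the path below is the example URL from this comment.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class EncoderConfigSketch:
    """Hypothetical stand-in for WanVAEEncoderConfig; field names are assumptions."""

    checkpoint_path: Optional[str] = None
    use_lightvae: bool = False  # assumed explicit variant flag


def config_uses_lightvae(config: EncoderConfigSketch) -> bool:
    # The variant comes from the explicit flag, never from the path text.
    return config.use_lightvae


cfg = EncoderConfigSketch(
    checkpoint_path="hf://nvidia/flashdreams-vae/light_encoder.safetensors",
    use_lightvae=True,
)
```

With an explicit flag, renaming or relocating the checkpoint cannot change which architecture is instantiated.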
    from dataclasses import dataclass
    from pathlib import Path
    from typing import Any
    from unittest import mock
unittest.mock imported in production code
synthetic_fixture.py is a production module (not under tests/), yet it imports from unittest import mock and uses mock.patch.object to suppress load_checkpoint during WanVAE construction. This couples production code to the test stdlib, makes static analysis tools flag the import, and surprises readers. A cleaner alternative is to pass an optional checkpoint_loader callable into wan_vae_module.WanVAE (if the API supports it) or wrap the VAE construction in a small helper that locally overrides only the checkpoint loading call without reaching into unittest.
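A rough sketch of the loader-injection alternative, under the assumption that the model class can accept a loader argument. `VAESketch`, `load_checkpoint_from_disk`, and the file name are hypothetical; the real WanVAE constructor may not support this.

```python
from typing import Any, Callable, Dict, Optional

StateDict = Dict[str, Any]


def load_checkpoint_from_disk(path: str) -> StateDict:
    """Stand-in for the real loader; it always fails here, as a missing file would."""
    raise FileNotFoundError(path)


class VAESketch:
    """Hypothetical model class; the real WanVAE may not accept a loader argument."""

    def __init__(
        self,
        checkpoint_path: str,
        checkpoint_loader: Optional[Callable[[str], StateDict]] = None,
    ) -> None:
        loader = checkpoint_loader or load_checkpoint_from_disk
        self.loaded_state = loader(checkpoint_path)


# The synthetic fixture passes a no-op loader instead of patching the module.
vae = VAESketch("synthetic_vae_encoder.safetensors", checkpoint_loader=lambda path: {})
```

Because the override is an argument rather than a module-level patch, it affects only the instance being built and needs nothing from unittest.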
    if path.exists():
        return path
Cached checkpoint not invalidated when config architecture changes
_maybe_build_wan_encoder_checkpoint, _build_decoder_checkpoint, and _build_native_vae_fp8_state all short-circuit with if path.exists(): return path. The cache key is the config_name slug; if the underlying recipe's architectural parameters (e.g., z_dim, base_dim) change between runs the stale file is silently reused. The mismatch only surfaces as a load error at warmup time. Consider embedding a hash of the architecture parameters in the filename, or adding a small sidecar manifest that records the parameters used to generate the checkpoint.
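The filename-hash option could be sketched as follows. `cache_path_for`, the file-name pattern, and the parameter values are illustrative assumptions, not the fixture's actual naming scheme.

```python
import hashlib
import json
from pathlib import Path
from typing import Any, Mapping


def cache_path_for(cache_dir: Path, config_name: str, arch_params: Mapping[str, Any]) -> Path:
    """Hypothetical helper: key the cached checkpoint on name plus architecture."""
    # Sorting keys makes the digest independent of dict insertion order.
    canonical = json.dumps(arch_params, sort_keys=True, default=str)
    digest = hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:12]
    return cache_dir / f"synthetic_vae_encoder_{config_name}_{digest}.safetensors"


# Illustrative values only; z_dim and base_dim are the parameters named above.
old = cache_path_for(Path("cache"), "demo", {"z_dim": 16, "base_dim": 96})
new = cache_path_for(Path("cache"), "demo", {"z_dim": 16, "base_dim": 128})
```

A changed parameter yields a different path, so `path.exists()` misses and the checkpoint is regenerated instead of a stale file being reused.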
    config = derive_config(
        config,
        text_encoder=None,
        image_encoder=None,
        encoder=encoder_patch,
        decoder=decoder_patch,
        diffusion_model=dict(transformer=dict(checkpoint_path=None)),
    )
    config.synthetic_text_max_length = text_max_length
    return config
synthetic_text_max_length set via direct attribute mutation rather than derive_config
Directly assigning to the config object bypasses the derive_config pattern used everywhere else in this file, breaks if the config is ever made a frozen dataclass, and means the value is read back through a fragile getattr(..., 512) default that would silently give the wrong sequence length if the assignment were ever missed. Expose the value through derive_config like the other overrides in this function so it participates in the immutable-config chain.
Suggested change:
    config = derive_config(
        config,
        text_encoder=None,
        image_encoder=None,
        encoder=encoder_patch,
        decoder=decoder_patch,
        diffusion_model=dict(transformer=dict(checkpoint_path=None)),
        synthetic_text_max_length=text_max_length,
    )
    return config
055fa38 to f6d614c
    torch.save(state, path)
    return path
torch.save writes directly to the final path, so a crash mid-write leaves a partial .pt file on disk. The next run hits if path.exists(): return path and silently hands a corrupt file to the production warmup loader, which fails with a confusing UnpicklingError (or similar) far from the actual write site. Every other cache-write in this file (_save_state_dict) uses the tmp → atomic-replace pattern; the fp8 state path should too.
Suggested change:
    tmp_path = path.with_suffix(path.suffix + ".tmp")
    torch.save(state, tmp_path)
    tmp_path.replace(path)
    return path
f6d614c to b5a8aea
    @contextlib.contextmanager
    def _suppressed_checkpoint_load() -> Iterator[None]:
        """Build ``WanVAE`` without reading a checkpoint.

        The VAE is constructed only to capture its default-initialized architecture
        before we overwrite the weights, so the module-level ``load_checkpoint``
        (which would fail -- the synthetic file does not exist yet) is swapped for a
        no-op returning an empty state dict. A local set/restore keeps
        ``unittest.mock`` out of this production module.
        """

        original = wan_vae_module.load_checkpoint
        wan_vae_module.load_checkpoint = lambda *args, **kwargs: {}
        try:
            yield
        finally:
            wan_vae_module.load_checkpoint = original
Thread-unsafe module-level attribute swap in _suppressed_checkpoint_load
The set/restore of wan_vae_module.load_checkpoint has a classic TOCTOU race when two threads enter this context manager concurrently: Thread B can read original after Thread A has already written the lambda, so Thread B's finally block restores the mock instead of the real function. From that point on every real (non-synthetic) WanVAE construction anywhere in the process silently returns {} from load_checkpoint, producing a module with random weights and no diagnostic.
The current call sites are sequential (encoder then decoder inside build_synthetic_world_model_assets), so the race is latent rather than immediately triggered. However, the public API of build_synthetic_world_model_assets imposes no ordering constraint, and a test that builds multiple config fixtures in a ThreadPoolExecutor would expose it. Guard the swap with a module-level threading.Lock or, since this only needs to block concurrent entries, a threading.RLock, which would also tolerate the same thread re-entering.
b5a8aea to 42bc16a
/ok to test 42bc16a
Enable offline interactive-drive latency runs by swapping only the model weight sources: the DiT initializes random weights (checkpoint_path=None) and the VAE/TAE load locally generated default-initialized checkpoints, while torch.compile, CUDA graphs, and native fp8 acceleration stay exactly as a real run configures them. Combined with --synthetic-scene and config-derived synthetic embeddings, the full world-model render and tracing path runs without any checkpoint or scene downloads. Add the --synthetic-model flag, the example_world_model_synthetic.yaml manifest, the offline asset fixture, CPU tests, and a GPU smoke test covering the path.
42bc16a to 42a9066
/ok to test 42a9066