
Add synthetic interactive-drive world model#337

Open: pknowlesnv wants to merge 1 commit into main from dev/pknowles/restore_synthetic_vm

@pknowlesnv

Collaborator

Enable offline interactive-drive latency runs by swapping only the model weight sources: the DiT initializes random weights (checkpoint_path=None) and the VAE/TAE load locally generated default-initialized checkpoints, while torch.compile, CUDA graphs, and native fp8 acceleration stay exactly as a real run configures them. Combined with --synthetic-scene and config-derived synthetic embeddings, the full world-model render and tracing path runs without any checkpoint or scene downloads.

Add the --synthetic-model flag, the example_world_model_synthetic.yaml manifest, the offline asset fixture, CPU tests, and a GPU smoke test covering the path.

@copy-pr-bot

copy-pr-bot Bot commented Jun 15, 2026


This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@pknowlesnv

Collaborator Author

/ok to test 055fa38

@greptile-apps

greptile-apps Bot commented Jun 15, 2026

Contributor

Greptile Summary

This PR introduces a --synthetic-model flag (and synthetic_model manifest field) that swaps only weight sources — DiT uses checkpoint_path=None, VAE/TAE use locally generated default-initialized checkpoints — while leaving compile, CUDA graphs, and fp8 acceleration intact. Combined with --synthetic-scene, the full interactive-drive world-model render and tracing path can run completely offline for latency benchmarking.

  • synthetic_fixture.py materializes CPU-side safetensors checkpoints for WanVAE encoder, WanVAE or TAEHV decoder, and (optionally) a fp8 quantization-state file; the checkpoint-swap lock, atomic writes, and arch-fingerprinted cache filenames address all concerns raised in prior reviews.
  • _apply_synthetic_model_overrides in flashdreams_adapter.py patches checkpoint paths and zeros the text/image embeddings through derive_config, preserving synthetic_text_max_length for downstream embedding sizing instead of a mutable attribute assignment.
  • New tests cover CPU round-trips, config-patch invariants, and a GPU smoke test; the native_vae_fp8=True code path inside _build_native_vae_fp8_state has no test coverage and contains an undocumented encoder.vae attribute assumption that could raise an AttributeError at runtime with an fp8 manifest.

Confidence Score: 4/5

Safe to merge for the recommended workflow (synthetic YAML + disabled fp8); the fp8+synthetic combination reachable via --manifest perf.yaml --synthetic-model is untested and could fail at runtime.

The entire new code path for native_vae_fp8=True in _build_native_vae_fp8_state is exercised by no test. It accesses encoder.vae on the object returned by WanVAEEncoderConfig.setup() — an undocumented structural assumption — and would produce an AttributeError the first time someone runs --manifest example_world_model_perf.yaml --synthetic-model against an fp8 manifest. All other changes are well-tested and prior review concerns have been addressed.

integrations/omnidreams/omnidreams/interactive_drive/world_model/synthetic_fixture.py — specifically _build_native_vae_fp8_state and _load_lightvae_fp8_export_module

Important Files Changed

| Filename | Overview |
| --- | --- |
| integrations/omnidreams/omnidreams/interactive_drive/world_model/synthetic_fixture.py | New module building offline synthetic VAE/TAE checkpoints; addresses prior review concerns (atomic writes, thread-safe swap, arch-fingerprinted cache) but the _build_native_vae_fp8_state code path is entirely untested and relies on undocumented encoder.vae attribute access. |
| integrations/omnidreams/omnidreams/interactive_drive/world_model/flashdreams_adapter.py | Adds _apply_synthetic_model_overrides and _synthetic_embeddings_for_pipeline; previous thread concerns (assert guards, synthetic_text_max_length via derive_config) addressed; logic for synthetic cache init path looks correct. |
| integrations/omnidreams/omnidreams/interactive_drive/world_model/manifest.py | Adds synthetic_model: bool = False field to frozen dataclass with correct bool() coercion in loader. |
| integrations/omnidreams/omnidreams/pipeline.py | Adds synthetic_text_max_length: int |
| integrations/omnidreams/tests/interactive_drive/test_synthetic_fixture.py | Tests CPU round-trip for WanVAE encoder + TAEHV decoder; deliberately omits the native_vae_fp8=True path, leaving _build_native_vae_fp8_state without any coverage. |
| integrations/omnidreams/tests/interactive_drive/test_world_model_adapter.py | Adds comprehensive tests for synthetic config patching and cache initialization with fake pipeline; embedding shape assertions are correct per the compression ratio arithmetic. |
| integrations/omnidreams/tests/interactive_drive/test_app_smoke.py | Adds GPU smoke test using example_world_model_synthetic.yaml; correctly fixes interactive_drive → omnidreams.interactive_drive module invocation path. |
| integrations/omnidreams/omnidreams/interactive_drive/configs/example_world_model_synthetic.yaml | Mirrors perf YAML with synthetic_model: true and native_vae_encoder: disabled; comments guide users toward the --synthetic-model flag to avoid drift. |
| integrations/omnidreams/omnidreams/interactive_drive/cli.py | Adds --synthetic-model BooleanOptionalAction flag and correctly threads it into the manifest via replace(). |
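
As a rough illustration of the cli.py and manifest.py rows above, a boolean flag can be threaded into a frozen dataclass with `argparse.BooleanOptionalAction` and `dataclasses.replace`. This is only a sketch under stated assumptions: `Manifest` and `apply_cli_overrides` are hypothetical stand-ins, not the PR's `WorldModelManifest` or its CLI code, and treating an absent flag as "keep the YAML value" is an assumption about the intended behavior.

```python
import argparse
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Manifest:
    """Hypothetical stand-in for the PR's WorldModelManifest."""

    synthetic_model: bool = False


def apply_cli_overrides(manifest: Manifest, argv: list[str]) -> Manifest:
    """Return a copy of the manifest with the CLI flag applied, if given."""
    parser = argparse.ArgumentParser()
    # default=None distinguishes "flag not given" from an explicit
    # --no-synthetic-model, so a value loaded from YAML is not clobbered.
    parser.add_argument(
        "--synthetic-model",
        action=argparse.BooleanOptionalAction,
        default=None,
    )
    args = parser.parse_args(argv)
    if args.synthetic_model is None:
        return manifest
    # Frozen dataclasses cannot be mutated; replace() builds a new instance.
    return replace(manifest, synthetic_model=args.synthetic_model)
```

`BooleanOptionalAction` (Python 3.9+) registers both `--synthetic-model` and `--no-synthetic-model`, which is what lets a command line turn the mode off again when the YAML manifest enables it.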

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    CLI["CLI / YAML manifest\n--synthetic-model"] --> LoadManifest["load_world_model_manifest()\nWorldModelManifest.synthetic_model=True"]
    LoadManifest --> BuildPipelineCfg["_build_pipeline_config()"]
    BuildPipelineCfg --> ApplyOverrides["_apply_synthetic_model_overrides()"]

    ApplyOverrides --> BuildAssets["build_synthetic_world_model_assets()"]
    BuildAssets --> WanEncoder["_maybe_build_wan_encoder_checkpoint()\n-> synthetic_vae_encoder_XXXX.safetensors"]
    BuildAssets --> Decoder["_build_decoder_checkpoint()\nTAEHV or WanVAE\n-> synthetic_*_decoder_XXXX.safetensors"]
    BuildAssets --> FP8{native_vae_fp8?}
    FP8 -->|Yes| FP8State["_build_native_vae_fp8_state()\nencoder.vae <- UNTESTED PATH\n-> synthetic_lightvae_fp8_state_XXXX.pt"]
    FP8 -->|No| SkipFP8[skip]

    ApplyOverrides --> DeriveConfig["derive_config()\ncheckpoint_path = synthetic path\ntext_encoder = None\nimage_encoder = None\ndiffusion_model.transformer.checkpoint_path = None\nsynthetic_text_max_length = N"]

    DeriveConfig --> Warmup["warmup_model()\nbuilds pipeline eagerly"]
    Warmup --> PrepareScene["prepare_for_scene() -> no-op"]
    PrepareScene --> InitCache["_initialize_cache()\nsynthetic branch"]
    InitCache --> SynthEmb["_synthetic_embeddings_for_pipeline()\nzero text + image tensors"]
    SynthEmb --> CacheFromEmb["initialize_cache_from_embeddings()"]
    CacheFromEmb --> Generate["generate() x N chunks\n-> latency measurement"]

Reviews (6): Last reviewed commit: "Add synthetic interactive-drive world mo..."

Comment on lines +192 to +193
assert config.encoder is not None, "synthetic Omnidreams config requires an encoder"
assert config.decoder is not None, "synthetic Omnidreams config requires a decoder"

Contributor

P1 assert guards disabled under -O, giving a confusing downstream error

Both checks are removed entirely when Python is launched with -O (common for optimized deployments). If config.encoder or config.decoder is None, execution continues into build_synthetic_world_model_assets, where it passes None as encoder_cfg/decoder_cfg, eventually failing with a confusing AttributeError or isinstance mismatch instead of the clear message written here. Replace both assert statements with explicit if/raise ValueError guards.
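
A minimal sketch of the explicit guards this comment asks for, assuming only that the config exposes `encoder` and `decoder` attributes. The helper name `require_encoder_and_decoder` is hypothetical; the error messages are copied from the asserts above.

```python
from typing import Any


def require_encoder_and_decoder(config: Any) -> None:
    """Raise ValueError if the config lacks an encoder or a decoder.

    Unlike assert statements, these checks still run under ``python -O``.
    """
    if config.encoder is None:
        raise ValueError("synthetic Omnidreams config requires an encoder")
    if config.decoder is None:
        raise ValueError("synthetic Omnidreams config requires a decoder")
```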

Collaborator Author

Check whether all issues on this PR are now resolved.

Comment on lines +102 to +105
def _config_uses_lightvae(config: Any) -> bool:
    """Detect the lightvae variant without assuming ``checkpoint_path`` is set."""

    return "lightvae" in (getattr(config, "checkpoint_path", None) or "")

Contributor

P1 _config_uses_lightvae silently mis-detects architecture when checkpoint path lacks the substring

The docstring says "without assuming checkpoint_path is set," but the implementation falls back to "" when the attribute is absent or None, so in that case it always returns False. More critically, if a production config stores the encoder under a path that doesn't contain the literal substring "lightvae" (e.g., a URL like hf://nvidia/flashdreams-vae/light_encoder.safetensors), the function returns False, _new_wan_vae_from_encoder_config instantiates the standard-VAE architecture, and the emitted state-dict will have mismatched keys/shapes relative to what the lightvae loader expects. This surfaces as a load error only at warmup time. Consider using a dedicated boolean attribute on WanVAEEncoderConfig to identify the variant instead of a path substring match.
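
One way the suggested dedicated attribute could look, as a sketch only. `EncoderConfigSketch` and its `use_lightvae` field are hypothetical; the real `WanVAEEncoderConfig` may have different fields, and nothing in this thread confirms such a flag exists.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass(frozen=True)
class EncoderConfigSketch:
    """Hypothetical stand-in; not the real WanVAEEncoderConfig."""

    checkpoint_path: Optional[str] = None
    use_lightvae: bool = False  # hypothetical explicit variant flag


def config_uses_lightvae(config: Any) -> bool:
    """Read the variant from an explicit flag instead of the path string."""
    return bool(getattr(config, "use_lightvae", False))
```

With an explicit flag, the URL from the comment above is detected correctly even though it never contains the substring "lightvae", and a path that merely happens to contain it no longer flips the architecture.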

from dataclasses import dataclass
from pathlib import Path
from typing import Any
from unittest import mock

Contributor

P2 unittest.mock imported in production code

synthetic_fixture.py is a production module (not under tests/), yet it imports from unittest import mock and uses mock.patch.object to suppress load_checkpoint during WanVAE construction. This couples production code to the test stdlib, makes static analysis tools flag the import, and surprises readers. A cleaner alternative is to pass an optional checkpoint_loader callable into wan_vae_module.WanVAE (if the API supports it) or wrap the VAE construction in a small helper that locally overrides only the checkpoint loading call without reaching into unittest.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines +122 to +123
if path.exists():
    return path

Contributor

P2 Cached checkpoint not invalidated when config architecture changes

_maybe_build_wan_encoder_checkpoint, _build_decoder_checkpoint, and _build_native_vae_fp8_state all short-circuit with if path.exists(): return path. The cache key is the config_name slug; if the underlying recipe's architectural parameters (e.g., z_dim, base_dim) change between runs the stale file is silently reused. The mismatch only surfaces as a load error at warmup time. Consider embedding a hash of the architecture parameters in the filename, or adding a small sidecar manifest that records the parameters used to generate the checkpoint.
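
The filename-hash option could be sketched as follows. The helper names, the choice of SHA-256 over a JSON dump, and the exact filename layout are assumptions for illustration, not the scheme the PR ended up using.

```python
import hashlib
import json
from pathlib import Path
from typing import Any, Mapping


def arch_fingerprint(params: Mapping[str, Any], length: int = 8) -> str:
    """Short, stable hash of the architecture parameters.

    sort_keys makes the result independent of dict ordering; default=str
    is a fallback for values json cannot serialize directly.
    """
    canonical = json.dumps(params, sort_keys=True, default=str)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:length]


def cache_path(cache_dir: Path, config_name: str, params: Mapping[str, Any]) -> Path:
    """Cache filename that changes whenever the architecture changes."""
    tag = arch_fingerprint(params)
    return cache_dir / f"synthetic_vae_encoder_{config_name}_{tag}.safetensors"
```

Because the fingerprint is part of the filename, the existing `if path.exists(): return path` short-circuit stays valid: a change to `z_dim` or `base_dim` yields a different path, so the stale file is simply never looked up.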

Comment on lines +221 to +230
config = derive_config(
    config,
    text_encoder=None,
    image_encoder=None,
    encoder=encoder_patch,
    decoder=decoder_patch,
    diffusion_model=dict(transformer=dict(checkpoint_path=None)),
)
config.synthetic_text_max_length = text_max_length
return config

Contributor

P2 synthetic_text_max_length set via direct attribute mutation rather than derive_config

Directly assigning to the config object bypasses the derive_config pattern used everywhere else in this file, breaks if the config is ever made a frozen dataclass, and means the value is read back through a fragile getattr(..., 512) default that would silently give the wrong sequence length if the assignment were ever missed. Expose the value through derive_config like the other overrides in this function so it participates in the immutable-config chain.

Suggested change
- config = derive_config(
-     config,
-     text_encoder=None,
-     image_encoder=None,
-     encoder=encoder_patch,
-     decoder=decoder_patch,
-     diffusion_model=dict(transformer=dict(checkpoint_path=None)),
- )
- config.synthetic_text_max_length = text_max_length
- return config
+ config = derive_config(
+     config,
+     text_encoder=None,
+     image_encoder=None,
+     encoder=encoder_patch,
+     decoder=decoder_patch,
+     diffusion_model=dict(transformer=dict(checkpoint_path=None)),
+     synthetic_text_max_length=text_max_length,
+ )
+ return config


@pknowlesnv pknowlesnv force-pushed the dev/pknowles/restore_synthetic_vm branch from 055fa38 to f6d614c Compare June 15, 2026 21:09
Comment on lines +350 to +351
torch.save(state, path)
return path

Contributor

P1 torch.save writes directly to the final path, so a crash mid-write leaves a partial .pt file on disk. The next run hits if path.exists(): return path and silently hands a corrupt file to the production warmup loader, which fails with a confusing UnpicklingError (or similar) far from the actual write site. Every other cache-write in this file (_save_state_dict) uses the tmp → atomic-replace pattern; the fp8 state path should too.

Suggested change
- torch.save(state, path)
- return path
+ tmp_path = path.with_suffix(path.suffix + ".tmp")
+ torch.save(state, tmp_path)
+ tmp_path.replace(path)
+ return path

@pknowlesnv pknowlesnv force-pushed the dev/pknowles/restore_synthetic_vm branch from f6d614c to b5a8aea Compare June 15, 2026 22:51
Comment on lines +157 to +173
@contextlib.contextmanager
def _suppressed_checkpoint_load() -> Iterator[None]:
    """Build ``WanVAE`` without reading a checkpoint.

    The VAE is constructed only to capture its default-initialized architecture
    before we overwrite the weights, so the module-level ``load_checkpoint``
    (which would fail -- the synthetic file does not exist yet) is swapped for a
    no-op returning an empty state dict. A local set/restore keeps
    ``unittest.mock`` out of this production module.
    """

    original = wan_vae_module.load_checkpoint
    wan_vae_module.load_checkpoint = lambda *args, **kwargs: {}
    try:
        yield
    finally:
        wan_vae_module.load_checkpoint = original

Contributor

P1 Thread-unsafe module-level attribute swap in _suppressed_checkpoint_load

The set/restore of wan_vae_module.load_checkpoint has a classic TOCTOU race when two threads enter this context manager concurrently: Thread B can read original after Thread A has already written the lambda, so Thread B's finally block restores the mock instead of the real function. From that point on every real (non-synthetic) WanVAE construction anywhere in the process silently returns {} from load_checkpoint, producing a module with random weights and no diagnostic.

The current call sites are sequential (encoder then decoder inside build_synthetic_world_model_assets), so the race is latent rather than immediately triggered. However, the public API of build_synthetic_world_model_assets imposes no ordering constraint, and a test that builds multiple config fixtures in a ThreadPoolExecutor would expose it. Guard the swap with a module-level threading.Lock held for the whole duration of the context manager. If the same thread could ever re-enter the context manager, use a threading.RLock instead, since a plain Lock would deadlock on re-entry.
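
A sketch of the lock-guarded swap described in this comment. `fake_module` is a hypothetical stand-in for `wan_vae_module`, and `_SWAP_LOCK` is an illustrative name; the PR's actual lock may be structured differently.

```python
import contextlib
import threading
from types import SimpleNamespace
from typing import Any, Iterator

# Hypothetical stand-in for wan_vae_module with a "real" loader.
fake_module = SimpleNamespace(load_checkpoint=lambda path: {"weights": path})

# RLock so the same thread may nest the context manager without deadlocking.
_SWAP_LOCK = threading.RLock()


@contextlib.contextmanager
def suppressed_checkpoint_load(module: Any = fake_module) -> Iterator[None]:
    """Swap module.load_checkpoint for a no-op while the lock is held."""
    with _SWAP_LOCK:
        # Only one thread at a time can reach this point, so `original`
        # is always the real function (or, on nested re-entry by the same
        # thread, the no-op that the outer level will later undo).
        original = module.load_checkpoint
        module.load_checkpoint = lambda *args, **kwargs: {}
        try:
            yield
        finally:
            module.load_checkpoint = original
```

The lock is held across the `yield`, so concurrent synthetic builds are serialized. Note that this only protects callers that go through the context manager: a thread that calls `load_checkpoint` directly while another thread holds the swap would still see the no-op.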

@pknowlesnv pknowlesnv force-pushed the dev/pknowles/restore_synthetic_vm branch from b5a8aea to 42bc16a Compare June 15, 2026 23:11
@pknowlesnv

Collaborator Author

/ok to test 42bc16a

@pknowlesnv pknowlesnv force-pushed the dev/pknowles/restore_synthetic_vm branch from 42bc16a to 42a9066 Compare June 15, 2026 23:36
@pknowlesnv

Collaborator Author

/ok to test 42a9066
