Fix tests for Flux, WAN, SDXL and LTX-Video to resolve execution and environment issues #394
Status: Open. Perseus14 wants to merge 11 commits into main from tests_fix.
Commits (11, all by Perseus14):

- 8c37368  Fix tests for Flux, WAN, SDXL and LTX-Video to resolve execution and …
- f4529c9  Move nnx import to top of files
- 43e917e  Fix lint
- a5deade  Fix dimension mismatch in controlnet and add tearDown for GC in SDXL …
- 6149afe  Fix lint
- 49e580f  Fix lint
- 9c15390  Fix SDXL
- 38b5daa  Fix sdxl tests
- ccf57e2  Fix sdxl
- 55aae40  Fix SDXL
- 967db7a  Fix sdxl
Diff (SDXL generation script):

@@ -115,14 +115,18 @@ def tokenize(prompt, pipeline):
     return inputs


-def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size):
+def get_unet_inputs(pipeline, scheduler_params, states, config, rng, mesh, batch_size):
     data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))

     vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1)

     prompt_ids = [config.prompt] * batch_size
     prompt_ids = tokenize(prompt_ids, pipeline)
+    prompt_ids = jax.lax.with_sharding_constraint(prompt_ids, jax.sharding.NamedSharding(mesh, P("data", None, None)))
     negative_prompt_ids = [config.negative_prompt] * batch_size
     negative_prompt_ids = tokenize(negative_prompt_ids, pipeline)
+    negative_prompt_ids = jax.lax.with_sharding_constraint(
+        negative_prompt_ids, jax.sharding.NamedSharding(mesh, P("data", None, None))
+    )
     guidance_scale = config.guidance_scale
     guidance_rescale = config.guidance_rescale
     num_inference_steps = config.num_inference_steps

Review comment: 🟡 Good use of sharding constraints to ensure consistent data placement and avoid unnecessary communication or re-sharding during the inference loop.

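The constraint pattern the PR applies to prompt_ids and the embeddings can be exercised in isolation. A minimal sketch, assuming only that the mesh has a "data" axis as in the PR; the `embed` function and array shapes here are illustrative, not the real pipeline code:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh named "data" over all available devices.
# A single CPU device also works; the constraint is then a no-op.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

@jax.jit
def embed(x):
    # Pin the leading (batch) dimension to the "data" mesh axis,
    # mirroring what the diff does for prompt_ids and the embeddings.
    x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P("data", None)))
    return x * 2.0

out = embed(jnp.ones((4, 8)))
print(out.shape)  # (4, 8)
```

Inside a jitted function, the constraint tells the XLA partitioner where the array must live, so downstream ops do not insert re-sharding collectives.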
@@ -133,6 +137,8 @@ def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size):
         "text_encoder_2": states["text_encoder_2_state"].params,
     }
     prompt_embeds, pooled_embeds = get_embeddings(prompt_ids, pipeline, text_encoder_params)
+    prompt_embeds = jax.lax.with_sharding_constraint(prompt_embeds, jax.sharding.NamedSharding(mesh, P("data", None, None)))
+    pooled_embeds = jax.lax.with_sharding_constraint(pooled_embeds, jax.sharding.NamedSharding(mesh, P("data", None)))

     batch_size = prompt_embeds.shape[0]
     add_time_ids = get_add_time_ids(

@@ -148,6 +154,9 @@ def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size):

         prompt_embeds = jnp.concatenate([negative_prompt_embeds, prompt_embeds], axis=0)
         add_text_embeds = jnp.concatenate([negative_pooled_embeds, pooled_embeds], axis=0)
+        prompt_embeds = jax.lax.with_sharding_constraint(prompt_embeds, jax.sharding.NamedSharding(mesh, P("data", None, None)))
+        add_text_embeds = jax.lax.with_sharding_constraint(add_text_embeds, jax.sharding.NamedSharding(mesh, P("data", None)))

         add_time_ids = jnp.concatenate([add_time_ids, add_time_ids], axis=0)

     else:

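The concatenation in this hunk is the standard classifier-free-guidance batching trick: negative and positive embeddings are stacked along the batch axis so a single UNet call scores both branches. A standalone sketch with toy arrays (the shapes and the 7.5 guidance scale are illustrative, not the real SDXL values):

```python
import jax.numpy as jnp

# Toy stand-ins for the two text-encoder outputs.
negative_prompt_embeds = jnp.zeros((2, 77, 32))  # unconditional branch
prompt_embeds = jnp.ones((2, 77, 32))            # conditional branch

# Stack along the batch axis: one forward pass evaluates both branches.
stacked = jnp.concatenate([negative_prompt_embeds, prompt_embeds], axis=0)
print(stacked.shape)  # (4, 77, 32)

# After the model call, split the halves and blend with the guidance
# scale: uncond + scale * (cond - uncond).
uncond, cond = jnp.split(stacked, 2, axis=0)
guided = uncond + 7.5 * (cond - uncond)
```

Because the stacked array doubles the batch dimension, the PR re-pins it to the "data" mesh axis right after the concatenate, which is why the sharding constraints appear at exactly this point.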
@@ -166,8 +175,11 @@ def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size):

     latents = jax.random.normal(rng, shape=latents_shape, dtype=jnp.float32)

+    if isinstance(scheduler_params, dict) and "scheduler" in scheduler_params:
+        scheduler_params = scheduler_params["scheduler"]
+
     scheduler_state = pipeline.scheduler.set_timesteps(
-        params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
+        scheduler_params, num_inference_steps=num_inference_steps, shape=latents.shape
     )

     latents = latents * scheduler_state.init_noise_sigma

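The isinstance check above makes the function tolerant of both calling conventions: being handed the bare scheduler state or the full params dict that nests it under a "scheduler" key. The same defensive pattern in isolation (plain Python; the helper name and the "state" placeholder are illustrative):

```python
def unwrap_scheduler_params(scheduler_params):
    # Accept either the bare scheduler state or the full params dict
    # that nests it under a "scheduler" key, as in the PR's fix.
    if isinstance(scheduler_params, dict) and "scheduler" in scheduler_params:
        scheduler_params = scheduler_params["scheduler"]
    return scheduler_params

print(unwrap_scheduler_params({"scheduler": "state"}))  # state
print(unwrap_scheduler_params("state"))                 # state
```

This is what lets the renamed `scheduler_params` argument replace the old `params["scheduler"]` lookup without breaking callers that still pass the whole dict.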
@@ -217,9 +229,11 @@ def run_inference(states, pipeline, params, config, rng, mesh, batch_size):
 def run(config):
     checkpoint_loader = GenerateSDXL(config)
     mesh = checkpoint_loader.mesh
-    with mesh:
-        pipeline, params = checkpoint_loader.load_checkpoint()
+    # NOTE: load_checkpoint() is called outside the mesh context intentionally.
+    # If checkpoint loading requires mesh-aware sharding, move this back inside `with mesh:`.
+    pipeline, params = checkpoint_loader.load_checkpoint()

+    with mesh:
     noise_scheduler, noise_scheduler_state = create_scheduler(pipeline.scheduler.config, config)

     weights_init_fn = functools.partial(pipeline.unet.init_weights, rng=checkpoint_loader.rng)

@@ -303,11 +317,13 @@ def run(config):
         _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
         p_run_inference(states).block_until_ready()
     print("compile time: ", (time.time() - s))

     s = time.time()
+    with ExitStack() as stack:
+        _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
+        images = p_run_inference(states).block_until_ready()
     print("inference time: ", (time.time() - s))

     images = jax.experimental.multihost_utils.process_allgather(images, tiled=True)
     numpy_images = np.array(images)
     images = VaeImageProcessor.numpy_to_pil(numpy_images)
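The ExitStack idiom used around the timed run lets a variable-length list of interceptor context managers be entered at once and unwound together. A minimal reproduction of the pattern with stand-in context managers (contextmanager-based stand-ins, not the real nn.intercept_methods):

```python
from contextlib import ExitStack, contextmanager

entered = []

@contextmanager
def interceptor(name):
    # Stand-in for nn.intercept_methods(...): record enter/exit order.
    entered.append(f"enter {name}")
    try:
        yield
    finally:
        entered.append(f"exit {name}")

interceptors = [interceptor("lora_a"), interceptor("lora_b")]

with ExitStack() as stack:
    # Enter every interceptor; all are exited (in reverse order) when
    # the with-block ends, exactly as in the timed inference run.
    _ = [stack.enter_context(cm) for cm in interceptors]
    entered.append("run inference")

print(entered)
# ['enter lora_a', 'enter lora_b', 'run inference', 'exit lora_b', 'exit lora_a']
```

Wrapping the timed call this way is what the hunk above adds: the LoRA interceptors that were active during compilation are re-entered for the measured inference pass, so both runs execute under identical interception.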