Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/UnitTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,11 @@ jobs:
python --version
pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets
- name: PyTest
env:
HF_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }}
run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536'
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --ignore=src/maxdiffusion/kernels/ --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --ignore=src/maxdiffusion/kernels/ -x --durations=0 -W ignore::DeprecationWarning -W ignore::UserWarning -W ignore::RuntimeWarning
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟢 Filtering out these warnings makes CI logs much cleaner and easier to navigate for developers focusing on test results.

# add_pull_ready
# if: github.ref != 'refs/heads/main'
# permissions:
Expand Down
2 changes: 1 addition & 1 deletion src/maxdiffusion/configs/ltx_video.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ ici_tensor_parallelism: 1
allow_split_physical_axes: False
learning_rate_schedule_steps: -1
max_train_steps: 500
pretrained_model_name_or_path: ''
pretrained_model_name_or_path: 'Lightricks/LTX-Video'
unet_checkpoint: ''
dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def create_key(seed=0):
def run(config):
rng = jax.random.PRNGKey(config.seed)

devices_array = max_utils.create_device_mesh(config)
mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)

prompts = config.prompt
negative_prompts = config.negative_prompt
controlnet_conditioning_scale = config.controlnet_conditioning_scale
Expand All @@ -48,13 +51,14 @@ def run(config):
image = np.concatenate([image, image, image], axis=2)
image = Image.fromarray(image)

controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
config.controlnet_model_name_or_path, from_pt=config.controlnet_from_pt, dtype=config.activations_dtype
)
with mesh:
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
config.controlnet_model_name_or_path, from_pt=config.controlnet_from_pt, dtype=config.activations_dtype
)

pipe, params = FlaxStableDiffusionXLControlNetPipeline.from_pretrained(
config.pretrained_model_name_or_path, controlnet=controlnet, revision=config.revision, dtype=config.activations_dtype
)
pipe, params = FlaxStableDiffusionXLControlNetPipeline.from_pretrained(
config.pretrained_model_name_or_path, controlnet=controlnet, revision=config.revision, dtype=config.activations_dtype
)

scheduler_state = params.pop("scheduler")
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
Expand All @@ -68,21 +72,23 @@ def run(config):
prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
processed_image = pipe.prepare_image_inputs([image] * num_samples)
p_params = replicate(params)
prompt_ids = shard(prompt_ids)
negative_prompt_ids = shard(negative_prompt_ids)
processed_image = shard(processed_image)

output = pipe(
prompt_ids=prompt_ids,
image=processed_image,
params=p_params,
prng_seed=rng,
num_inference_steps=config.num_inference_steps,
neg_prompt_ids=negative_prompt_ids,
controlnet_conditioning_scale=controlnet_conditioning_scale,
jit=True,
).images

with mesh:
p_params = replicate(params)
prompt_ids = shard(prompt_ids)
negative_prompt_ids = shard(negative_prompt_ids)
processed_image = shard(processed_image)

output = pipe(
prompt_ids=prompt_ids,
image=processed_image,
params=p_params,
prng_seed=rng,
num_inference_steps=config.num_inference_steps,
neg_prompt_ids=negative_prompt_ids,
controlnet_conditioning_scale=controlnet_conditioning_scale,
jit=True,
).images

output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
output_images[0].save("generated_image.png")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
"""

import os
import functools
from absl import app
from typing import Sequence, Union, List
from datasets import load_dataset
import numpy as np
import jax
from flax import nnx
import jax.numpy as jnp
from jax.sharding import Mesh
from maxdiffusion import pyconfig, max_utils
Expand Down Expand Up @@ -110,8 +110,9 @@ def generate_dataset(config, pipeline):
vae_scale_factor_spatial = 2 ** len(pipeline.vae.temperal_downsample)
video_processor = VideoProcessor(vae_scale_factor=vae_scale_factor_spatial)

# jit vae fun.
p_vae_encode = jax.jit(functools.partial(vae_encode, vae=pipeline.vae, vae_cache=pipeline.vae_cache))
@nnx.jit
def p_vae_encode(video, rng, vae, vae_cache):
    """Jitted wrapper that forwards its arguments to ``vae_encode``.

    The VAE and its cache are taken as explicit arguments of the jitted
    function rather than being bound ahead of time.
    """
    latents = vae_encode(video=video, rng=rng, vae=vae, vae_cache=vae_cache)
    return latents

# Load dataset
ds = load_dataset(config.dataset_name, split="train")
Expand All @@ -126,7 +127,7 @@ def generate_dataset(config, pipeline):
videos = [video_processor.preprocess_video([video], height=config.height, width=config.width) for video in videos]
video = jnp.array(np.squeeze(np.array(videos), axis=1), dtype=config.weights_dtype)
with mesh:
latents = p_vae_encode(video=video, rng=new_rng)
latents = p_vae_encode(video=video, rng=new_rng, vae=pipeline.vae, vae_cache=pipeline.vae_cache)
encoder_hidden_states = text_encode(pipeline, text)
for latent, encoder_hidden_state in zip(latents, encoder_hidden_states):
writer.write(create_example(latent, encoder_hidden_state))
Expand Down
24 changes: 20 additions & 4 deletions src/maxdiffusion/generate_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,18 @@ def tokenize(prompt, pipeline):
return inputs


def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size):
def get_unet_inputs(pipeline, scheduler_params, states, config, rng, mesh, batch_size):
data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))

vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Good use of sharding constraints to ensure consistent data placement and avoid unnecessary communication or re-sharding during the inference loop.

prompt_ids = [config.prompt] * batch_size
prompt_ids = tokenize(prompt_ids, pipeline)
prompt_ids = jax.lax.with_sharding_constraint(prompt_ids, jax.sharding.NamedSharding(mesh, P("data", None, None)))
negative_prompt_ids = [config.negative_prompt] * batch_size
negative_prompt_ids = tokenize(negative_prompt_ids, pipeline)
negative_prompt_ids = jax.lax.with_sharding_constraint(
negative_prompt_ids, jax.sharding.NamedSharding(mesh, P("data", None, None))
)
guidance_scale = config.guidance_scale
guidance_rescale = config.guidance_rescale
num_inference_steps = config.num_inference_steps
Expand All @@ -133,6 +137,8 @@ def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size):
"text_encoder_2": states["text_encoder_2_state"].params,
}
prompt_embeds, pooled_embeds = get_embeddings(prompt_ids, pipeline, text_encoder_params)
prompt_embeds = jax.lax.with_sharding_constraint(prompt_embeds, jax.sharding.NamedSharding(mesh, P("data", None, None)))
pooled_embeds = jax.lax.with_sharding_constraint(pooled_embeds, jax.sharding.NamedSharding(mesh, P("data", None)))

batch_size = prompt_embeds.shape[0]
add_time_ids = get_add_time_ids(
Expand All @@ -148,6 +154,9 @@ def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size):

prompt_embeds = jnp.concatenate([negative_prompt_embeds, prompt_embeds], axis=0)
add_text_embeds = jnp.concatenate([negative_pooled_embeds, pooled_embeds], axis=0)
prompt_embeds = jax.lax.with_sharding_constraint(prompt_embeds, jax.sharding.NamedSharding(mesh, P("data", None, None)))
add_text_embeds = jax.lax.with_sharding_constraint(add_text_embeds, jax.sharding.NamedSharding(mesh, P("data", None)))

add_time_ids = jnp.concatenate([add_time_ids, add_time_ids], axis=0)

else:
Expand All @@ -166,8 +175,11 @@ def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size):

latents = jax.random.normal(rng, shape=latents_shape, dtype=jnp.float32)

if isinstance(scheduler_params, dict) and "scheduler" in scheduler_params:
scheduler_params = scheduler_params["scheduler"]

scheduler_state = pipeline.scheduler.set_timesteps(
params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
scheduler_params, num_inference_steps=num_inference_steps, shape=latents.shape
)

latents = latents * scheduler_state.init_noise_sigma
Expand Down Expand Up @@ -217,9 +229,11 @@ def run_inference(states, pipeline, params, config, rng, mesh, batch_size):
def run(config):
checkpoint_loader = GenerateSDXL(config)
mesh = checkpoint_loader.mesh
with mesh:
pipeline, params = checkpoint_loader.load_checkpoint()
# NOTE: load_checkpoint() is called outside the mesh context intentionally.
# If checkpoint loading requires mesh-aware sharding, move this back inside `with mesh:`.
pipeline, params = checkpoint_loader.load_checkpoint()

with mesh:
noise_scheduler, noise_scheduler_state = create_scheduler(pipeline.scheduler.config, config)

weights_init_fn = functools.partial(pipeline.unet.init_weights, rng=checkpoint_loader.rng)
Expand Down Expand Up @@ -303,11 +317,13 @@ def run(config):
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
p_run_inference(states).block_until_ready()
print("compile time: ", (time.time() - s))

s = time.time()
with ExitStack() as stack:
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
images = p_run_inference(states).block_until_ready()
print("inference time: ", (time.time() - s))

images = jax.experimental.multihost_utils.process_allgather(images, tiled=True)
numpy_images = np.array(images)
images = VaeImageProcessor.numpy_to_pil(numpy_images)
Expand Down
9 changes: 6 additions & 3 deletions src/maxdiffusion/tests/data_processing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

import os
import pytest
import functools
import jax
import jax.numpy as jnp
from flax import nnx
from flax.linen import partitioning as nn_partitioning
from jax.sharding import Mesh
from .. import pyconfig
Expand Down Expand Up @@ -81,11 +81,14 @@ def test_wan_vae_encode_normalization(self):
video = load_video(video_path)
videos = [video_processor.preprocess_video([video], height=config.height, width=config.width)]
videos = jnp.array(np.squeeze(np.array(videos), axis=1), dtype=config.weights_dtype)
p_vae_encode = jax.jit(functools.partial(vae_encode, vae=pipeline.vae, vae_cache=pipeline.vae_cache))

@nnx.jit
def p_vae_encode(video, rng, vae, vae_cache):
    """Jitted wrapper around ``vae_encode`` used by this test.

    The VAE and its cache are passed in as arguments of the jitted function
    instead of being pre-bound.
    """
    encoded = vae_encode(video, rng=rng, vae=vae, vae_cache=vae_cache)
    return encoded

rng = jax.random.key(config.seed)
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
latents = p_vae_encode(videos, rng=rng)
latents = p_vae_encode(videos, rng=rng, vae=pipeline.vae, vae_cache=pipeline.vae_cache)
# 1. Verify Channel Count (Wan 2.1 requires 16)
self.assertEqual(latents.shape[1], 16, f"Expected 16 channels, got {latents.shape[1]}")

Expand Down
8 changes: 2 additions & 6 deletions src/maxdiffusion/tests/generate_ltx2_smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,10 @@ def setUpClass(cls):
)
cls.config = pyconfig.config
checkpoint_loader = LTX2Checkpointer(config=cls.config)
# Load pipeline without upsampler for simplicity in smoke test
cls.pipeline, _, _ = checkpoint_loader.load_checkpoint(load_upsampler=False)

cls.prompt = [cls.config.prompt] * getattr(cls.config, "global_batch_size_to_train_on", 1)
cls.negative_prompt = [cls.config.negative_prompt] * getattr(cls.config, "global_batch_size_to_train_on", 1)
cls.prompt = [cls.config.prompt]
cls.negative_prompt = [cls.config.negative_prompt]

def test_ltx2_inference(self):
"""Test that LTX2 pipeline can run inference and produce output."""
Expand Down Expand Up @@ -90,9 +89,6 @@ def test_ltx2_inference(self):
# Check that we got frames
self.assertGreater(len(videos), 0)

# LTX2 might also produce audio, check if it's there if expected
# The config doesn't explicitly say if it's T2AV or just T2V, but the pipeline seems to handle audio.
# We can just log if audio is present.
if audios is not None:
print(f"Audio produced with shape: {audios[0].shape}")
self.assertGreater(len(audios), 0)
Expand Down
18 changes: 16 additions & 2 deletions src/maxdiffusion/tests/generate_sdxl_smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@
class Generate(unittest.TestCase):
"""Smoke test."""

def tearDown(self):
    """Run the base-class teardown, then free memory left by the test.

    After each test the Python garbage collector is run and JAX's caches
    are cleared.
    """
    super().tearDown()

    # Imports are kept local to the method, as in the original.
    import gc
    import jax

    gc.collect()
    jax.clear_caches()

@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
def test_hyper_sdxl_lora(self):
img_url = os.path.join(THIS_DIR, "images", "test_hyper_sdxl.png")
Expand All @@ -53,6 +61,7 @@ def test_hyper_sdxl_lora(self):
'diffusion_scheduler_config={"_class_name" : "FlaxDDIMScheduler", "timestep_spacing" : "trailing"}',
'lora_config={"lora_model_name_or_path" : ["ByteDance/Hyper-SD"], "weight_name" : ["Hyper-SDXL-2steps-lora.safetensors"], "adapter_name" : ["hyper-sdxl"], "scale": [0.7], "from_pt": ["true"]}',
f"jax_cache_dir={JAX_CACHE_DIR}",
"jit_initializers=False",
],
unittest=True,
)
Expand Down Expand Up @@ -84,6 +93,7 @@ def test_sdxl_config(self):
"run_name=sdxl-inference-test",
"split_head_dim=False",
f"jax_cache_dir={JAX_CACHE_DIR}",
"jit_initializers=False",
],
unittest=True,
)
Expand Down Expand Up @@ -116,6 +126,7 @@ def test_sdxl_from_gcs(self):
"run_name=sdxl-inference-test",
"split_head_dim=False",
f"jax_cache_dir={JAX_CACHE_DIR}",
"jit_initializers=False",
],
unittest=True,
)
Expand All @@ -139,14 +150,16 @@ def test_controlnet_sdxl(self):
"activations_dtype=bfloat16",
"weights_dtype=bfloat16",
f"jax_cache_dir={JAX_CACHE_DIR}",
"controlnet_image=" + os.path.join(THIS_DIR, "images", "cnet_test.png"),
"jit_initializers=False",
],
unittest=True,
)
images = generate_run_sdxl_controlnet(pyconfig.config)
test_image = np.array(images[0]).astype(np.uint8)
ssim_compare = ssim(base_image, test_image, multichannel=True, channel_axis=-1, data_range=255)
assert base_image.shape == test_image.shape
assert ssim_compare >= 0.70
assert ssim_compare >= 0.80

@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
def test_sdxl_lightning(self):
Expand All @@ -158,14 +171,15 @@ def test_sdxl_lightning(self):
os.path.join(THIS_DIR, "..", "configs", "base_xl_lightning.yml"),
"run_name=sdxl-lightning-test",
f"jax_cache_dir={JAX_CACHE_DIR}",
"jit_initializers=False",
],
unittest=True,
)
images = generate_run_xl(pyconfig.config)
test_image = np.array(images[0]).astype(np.uint8)
ssim_compare = ssim(base_image, test_image, multichannel=True, channel_axis=-1, data_range=255)
assert base_image.shape == test_image.shape
assert ssim_compare >= 0.70
assert ssim_compare >= 0.80


if __name__ == "__main__":
Expand Down
Loading
Loading