4 changes: 4 additions & 0 deletions docs/source/en/api/loaders/lora.md
@@ -132,6 +132,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi

[[autodoc]] loaders.lora_pipeline.ZImageLoraLoaderMixin

## CosmosLoraLoaderMixin

[[autodoc]] loaders.lora_pipeline.CosmosLoraLoaderMixin

## KandinskyLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.KandinskyLoraLoaderMixin
Member
Figures should be hosted somewhere else.
97 changes: 97 additions & 0 deletions examples/cosmos/README.md
@@ -0,0 +1,97 @@
# LoRA fine-tuning for Cosmos Predict 2.5

This example shows how to fine-tune [Cosmos Predict 2.5](https://huggingface.co/nvidia/Cosmos-Predict2.5-2B) using LoRA on a custom video dataset.

## Requirements

Install the library from source and the example-specific dependencies:

```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e ".[dev]"
cd examples/cosmos
pip install -r requirements.txt
```

## Data preparation

The training script expects a dataset directory with the following layout:

```
<dataset_dir>/
├── videos/ # .mp4 files
└── metas/ # one .txt prompt file per video (same stem)
├── 0.txt
├── 1.txt
└── ...
```
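A quick way to sanity-check this layout before training is to pair each video with its prompt file. The sketch below uses only the standard library; `check_dataset` is a hypothetical helper for illustration, not part of the training script:

```python
import os


def check_dataset(dataset_dir: str) -> list[tuple[str, str]]:
    """Return (video_path, prompt_path) pairs; warn about videos with no prompt."""
    video_dir = os.path.join(dataset_dir, "videos")
    meta_dir = os.path.join(dataset_dir, "metas")
    pairs = []
    for name in sorted(os.listdir(video_dir)):
        stem, ext = os.path.splitext(name)
        if ext.lower() != ".mp4":
            continue
        prompt_path = os.path.join(meta_dir, stem + ".txt")
        if not os.path.exists(prompt_path):
            print(f"WARNING: no prompt file for {name}")
            continue
        pairs.append((os.path.join(video_dir, name), prompt_path))
    return pairs
```

Running this before launching training surfaces mismatched stems early, instead of failing partway through an epoch.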

### GR1 dataset (quick start)

The `download_and_preprocess_datasets.sh` script downloads the GR1-100 training set and the EVAL-175 test set, then runs the preprocessing script to create the per-video prompt files.

```bash
bash download_and_preprocess_datasets.sh
```

This produces:
- `gr1_dataset/train/` — training videos + prompts
- `gr1_dataset/test/` — evaluation images + prompts

## Training

Launch LoRA training with `accelerate`:

```bash
export MODEL_NAME="nvidia/Cosmos-Predict2.5-2B"
export DATA_DIR="gr1_dataset/train"
export OUT_DIR="lora-output"

accelerate launch --mixed_precision="bf16" train_cosmos_predict25_lora.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--revision diffusers/base/post-trained \
--train_data_dir=$DATA_DIR \
--output_dir=$OUT_DIR \
--train_batch_size=1 \
--num_train_epochs=500 \
--checkpointing_epochs=100 \
--seed=0 \
--height 432 --width 768 \
--allow_tf32 \
--gradient_checkpointing \
--lora_rank 32 --lora_alpha 32 \
--report_to=wandb
```

Or use the provided shell script:

```bash
bash train_lora.sh
```

## Evaluation

Run inference with the trained LoRA adapter:

```bash
export DATA_DIR="gr1_dataset/test"
export LORA_DIR="lora-output"
export OUT_DIR="eval-output"

python eval_cosmos_predict25_lora.py \
--data_dir $DATA_DIR \
--output_dir $OUT_DIR \
--lora_dir $LORA_DIR \
--revision diffusers/base/post-trained \
--height 432 --width 768 \
--num_output_frames 93 \
--num_steps 36 \
--seed 0
```

Or use the provided shell script:

```bash
bash eval_lora.sh
```
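For reference, each conditioning image `<stem>.jpg`/`.png` in `DATA_DIR` that has a matching `<stem>.txt` prompt yields one video `<stem>.mp4` in `OUT_DIR`. A minimal sketch of that naming convention (`expected_outputs` is a hypothetical helper, shown only to make the mapping explicit):

```python
import os

IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def expected_outputs(data_dir: str, output_dir: str) -> list[str]:
    """List the <stem>.mp4 paths the eval script will write."""
    outs = []
    for name in sorted(os.listdir(data_dir)):
        stem, ext = os.path.splitext(name)
        has_prompt = os.path.exists(os.path.join(data_dir, stem + ".txt"))
        if ext.lower() in IMAGE_EXTENSIONS and has_prompt:
            outs.append(os.path.join(output_dir, stem + ".mp4"))
    return outs
```

Images without a prompt file are skipped by the eval script, so they produce no output video.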
Binary file added examples/cosmos/assets/figures/plot_IF.png
Binary file added examples/cosmos/assets/figures/plot_physics.png
Binary file added examples/cosmos/assets/figures/plot_sampson.png
Binary file not shown.
63 changes: 63 additions & 0 deletions examples/cosmos/create_prompts_for_gr1_dataset.py
@@ -0,0 +1,63 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

from tqdm import tqdm


"""example command
python create_prompts_for_gr1_dataset.py --dataset_path datasets/benchmark_train/gr1
"""


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Create text prompts for GR1 dataset")
parser.add_argument(
"--dataset_path", type=str, default="datasets/benchmark_train/gr1", help="Root path to the dataset"
)
parser.add_argument(
"--prompt_prefix", type=str, default="The robot arm is performing a task. ", help="Prefix of the prompt"
)
parser.add_argument(
"--meta_csv", type=str, default=None, help="Metadata csv file (defaults to <dataset_path>/metadata.csv)"
)
return parser.parse_args()


def main(args) -> None:
    meta_csv = args.meta_csv or os.path.join(args.dataset_path, "metadata.csv")
    with open(meta_csv) as fp:
        meta_lines = fp.readlines()[1:]  # skip the CSV header row
    meta_txt_dir = os.path.join(args.dataset_path, "metas")
    os.makedirs(meta_txt_dir, exist_ok=True)

    for meta_line in tqdm(meta_lines):
        video_filename, prompt = meta_line.split(",", 1)
        prompt = prompt.strip("\n")
        if prompt.startswith('"') and prompt.endswith('"'):
            # Remove the surrounding quotes added by the CSV writer
            prompt = prompt[1:-1]
        prompt = args.prompt_prefix + prompt
        meta_txt_filename = os.path.join(meta_txt_dir, os.path.basename(video_filename).replace(".mp4", ".txt"))
        with open(meta_txt_filename, "w") as fp:
            fp.write(prompt)

    print(f"Wrote {len(meta_lines)} prompt files to {meta_txt_dir}")


if __name__ == "__main__":
args = parse_args()
main(args)
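The per-line metadata parsing in `main` (split on the first comma only, strip the trailing newline, drop surrounding quotes, then prepend the prefix) can be exercised in isolation. `parse_meta_line` below is a hypothetical refactoring for illustration, not part of the script:

```python
def parse_meta_line(meta_line: str, prompt_prefix: str = "") -> tuple[str, str]:
    """Mirror the per-line parsing in main()."""
    video_filename, prompt = meta_line.split(",", 1)  # split once: prompts may contain commas
    prompt = prompt.strip("\n")
    if prompt.startswith('"') and prompt.endswith('"'):
        prompt = prompt[1:-1]  # drop the CSV quoting
    return video_filename, prompt_prefix + prompt


name, prompt = parse_meta_line(
    '0.mp4,"Pick up the cup, then place it."\n',
    "The robot arm is performing a task. ",
)
print(name)    # → 0.mp4
print(prompt)  # → The robot arm is performing a task. Pick up the cup, then place it.
```

Splitting on only the first comma is what keeps commas inside the prompt text intact.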
25 changes: 25 additions & 0 deletions examples/cosmos/download_and_preprocess_datasets.sh
@@ -0,0 +1,25 @@
#!/bin/bash
set -e

dataset_dir='gr1_dataset'
train_dir=$dataset_dir/train
test_dir=$dataset_dir/test

# Download and Preprocess Training Dataset
hf download nvidia/GR1-100 --repo-type dataset --local-dir datasets/benchmark_train/hf_gr1/ && \
mkdir -p datasets/benchmark_train/gr1/videos && \
mv datasets/benchmark_train/hf_gr1/gr1/*mp4 datasets/benchmark_train/gr1/videos && \
mv datasets/benchmark_train/hf_gr1/metadata.csv datasets/benchmark_train/gr1/

python create_prompts_for_gr1_dataset.py --dataset_path datasets/benchmark_train/gr1

# Download Eval Dataset
hf download nvidia/EVAL-175 --repo-type dataset --local-dir dream_gen_benchmark


# Move datasets into the final layout
mkdir -p $dataset_dir
mv datasets/benchmark_train/gr1 $train_dir
mv dream_gen_benchmark/gr1_object $test_dir
echo "Downloaded training data to $train_dir"
echo "Downloaded test data to $test_dir"

# Clean up staging directories
rm -rf datasets/ dream_gen_benchmark/
164 changes: 164 additions & 0 deletions examples/cosmos/eval_cosmos_predict25_lora.py
@@ -0,0 +1,164 @@
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import argparse
import os

import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from diffusers import Cosmos2_5_PredictBasePipeline
from diffusers.utils import export_to_video, load_image


IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


class ImageDataset(Dataset):
"""Dataset that loads images and their corresponding text prompts.

Expects a directory with:
<filename>.jpg / .jpeg / .png — the conditioning image
<filename>.txt — the prompt text
"""

def __init__(self, data_dir: str):
self.data_dir = data_dir
self.samples = []

for filename in sorted(os.listdir(data_dir)):
stem, ext = os.path.splitext(filename)
if ext.lower() not in IMAGE_EXTENSIONS:
continue
img_path = os.path.join(data_dir, filename)
txt_path = os.path.join(data_dir, stem + ".txt")
if not os.path.exists(txt_path):
print(f"WARNING: no prompt file found for {img_path}, skipping.")
continue
self.samples.append((img_path, txt_path, stem))

if len(self.samples) == 0:
raise ValueError(f"No valid image/prompt pairs found in {data_dir}")

def __len__(self):
return len(self.samples)

def __getitem__(self, idx):
img_path, txt_path, stem = self.samples[idx]
image = load_image(img_path)
with open(txt_path) as f:
prompt = f.read().strip()
return {
"image": image,
"prompt": prompt,
"stem": stem,
}


def collate_fn(batch):
"""Keep images as a list (PIL images can't be stacked into a tensor)."""
return {
"images": [item["image"] for item in batch],
"prompts": [item["prompt"] for item in batch],
"stems": [item["stem"] for item in batch],
}


def parse_args():
parser = argparse.ArgumentParser(description="Eval Cosmos Predict 2.5 with optional LoRA weights.")

parser.add_argument("--data_dir", type=str, required=True, help="Directory with image/prompt pairs.")
parser.add_argument("--output_dir", type=str, required=True, help="Directory to save generated outputs.")
parser.add_argument(
"--model_id", type=str, default="nvidia/Cosmos-Predict2.5-2B", help="HuggingFace model repository."
)
parser.add_argument(
"--revision",
type=str,
default="diffusers/base/post-trained",
choices=["diffusers/base/post-trained", "diffusers/base/pre-trained"],
)
parser.add_argument("--lora_dir", type=str, default=None, help="Path to LoRA weights directory.")
parser.add_argument("--num_output_frames", type=int, default=93, help="1 for image output, 93 for video output.")
parser.add_argument("--num_steps", type=int, default=36, help="Number of inference steps.")
parser.add_argument("--height", type=int, default=704, help="Output height in pixels (must be divisible by 16).")
parser.add_argument("--width", type=int, default=1280, help="Output width in pixels (must be divisible by 16).")
parser.add_argument("--seed", type=int, default=0, help="Random seed.")
parser.add_argument("--device", type=str, default="cuda", help="Device to use.")
parser.add_argument("--batch_size", type=int, default=1, help="Number of samples per batch.")
parser.add_argument("--num_workers", type=int, default=4, help="DataLoader worker processes.")
parser.add_argument(
"--negative_prompt",
type=str,
default=None,
help="Negative prompt. Defaults to the pipeline's built-in negative prompt.",
)
return parser.parse_args()


def main():
args = parse_args()
os.makedirs(args.output_dir, exist_ok=True)

dataset = ImageDataset(args.data_dir)
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
collate_fn=collate_fn,
)

print(f"Found {len(dataset)} examples.")

class MockSafetyChecker:
def to(self, *args, **kwargs):
return self

def check_text_safety(self, *args, **kwargs):
return True

def check_video_safety(self, video):
return video

pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(
args.model_id,
revision=args.revision,
device_map=args.device,
torch_dtype=torch.bfloat16,
safety_checker=MockSafetyChecker(),
)

if args.lora_dir is not None:
pipe.load_lora_weights(args.lora_dir)
pipe.fuse_lora(lora_scale=1.0)
print(f"Loaded LoRA weights from {args.lora_dir}")

progress = tqdm(total=len(dataset), desc="Generating")
for batch in dataloader:
images = batch["images"]
prompts = batch["prompts"]
stems = batch["stems"]

for image, prompt, stem in zip(images, prompts, stems):
frames = pipe(
image=image,
prompt=prompt,
negative_prompt=args.negative_prompt,
num_frames=args.num_output_frames,
num_inference_steps=args.num_steps,
height=args.height,
width=args.width,
            ).frames[0]  # the pipeline is called once per image, so take the single generated video

out_path = os.path.join(args.output_dir, f"{stem}.mp4")
export_to_video(frames, out_path, fps=16)

tqdm.write(f" Saved to: {out_path}")
progress.update(1)


if __name__ == "__main__":
main()
15 changes: 15 additions & 0 deletions examples/cosmos/eval_lora.sh
@@ -0,0 +1,15 @@
export DATA_DIR="gr1_dataset/test"
export LORA_DIR=YOUR_ADAPTER_DIR
export OUT_DIR=YOUR_EVAL_OUTPUT_DIR
revision="post-trained"

export TOKENIZERS_PARALLELISM=false
python eval_cosmos_predict25_lora.py \
--data_dir $DATA_DIR \
--output_dir $OUT_DIR \
--lora_dir $LORA_DIR \
--revision diffusers/base/$revision \
--height 432 --width 768 \
--num_output_frames 93 \
--num_steps 36 \
--seed 0