diff --git a/examples/diffusion/README.md b/examples/diffusion/README.md new file mode 100644 index 0000000000..ba59237717 --- /dev/null +++ b/examples/diffusion/README.md @@ -0,0 +1,3 @@ +# Megatron Examples + +Recipes and configuration overrides for megatron training. diff --git a/examples/diffusion/override_configs/README.md b/examples/diffusion/override_configs/README.md new file mode 100644 index 0000000000..6731a25a6f --- /dev/null +++ b/examples/diffusion/override_configs/README.md @@ -0,0 +1,3 @@ +# Override Configs + +Parallelism configuration overrides for different CP/TP/SP sizes. diff --git a/examples/diffusion/override_configs/wan_pretrain_sample_data.yaml b/examples/diffusion/override_configs/wan_pretrain_sample_data.yaml new file mode 100644 index 0000000000..9648874e4c --- /dev/null +++ b/examples/diffusion/override_configs/wan_pretrain_sample_data.yaml @@ -0,0 +1,44 @@ +# WAN Pretrain Mock Data Test Configuration +# Converted from L2_Function_Tests_GPU_Wan_Mock_Data.sh + +model: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + context_parallel_size: 1 + crossattn_emb_size: 1536 + hidden_size: 1536 + ffn_hidden_size: 8960 + num_attention_heads: 12 + num_layers: 3 + qkv_format: thd + seq_length: 2048 + +train: + eval_iters: 0 + train_iters: 10 + global_batch_size: 2 + micro_batch_size: 1 + +optimizer: + lr: 5.0e-6 + min_lr: 5.0e-6 + +scheduler: + lr_decay_style: constant + lr_warmup_iters: 0 + +checkpoint: + save: ${oc.env:CHECKPOINT_DIR,null} + load: ${oc.env:CHECKPOINT_DIR,null} + load_optim: false + save_interval: 200 + +dataset: + path: ${oc.env:DATASET_PATH,null} + seq_length: 2048 + global_batch_size: 2 + micro_batch_size: 1 + packing_buffer_size: 50 + +logger: + log_interval: 1 diff --git a/examples/diffusion/recipes/README.md b/examples/diffusion/recipes/README.md new file mode 100644 index 0000000000..7ed591b1db --- /dev/null +++ b/examples/diffusion/recipes/README.md @@ -0,0 +1,3 @@ +# Recipe + +Training recipes for Wan2.1 pretraining, finetuning, and weight verification. diff --git a/examples/diffusion/recipes/__init__.py b/examples/diffusion/recipes/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/diffusion/recipes/flux/conf/flux_pretrain_override_example.yaml b/examples/diffusion/recipes/flux/conf/flux_pretrain_override_example.yaml new file mode 100644 index 0000000000..4d1c184c93 --- /dev/null +++ b/examples/diffusion/recipes/flux/conf/flux_pretrain_override_example.yaml @@ -0,0 +1,56 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Example FLUX pretrain configuration override file +# This file shows common overrides for FLUX pretraining + +# Model configuration +model: + # Parallelism settings + tensor_model_parallel_size: 4 + pipeline_model_parallel_size: 1 + context_parallel_size: 1 + sequence_parallel: false + + # FLUX architecture (FLUX-schnell defaults) + num_joint_layers: 19 + num_single_layers: 38 + hidden_size: 3072 + num_attention_heads: 24 + in_channels: 64 + context_dim: 4096 + + # For FLUX-dev, set guidance_embed: true + guidance_embed: false + guidance_scale: 3.5 + +# Training configuration +train: + train_iters: 10000 + eval_interval: 2000 + eval_iters: 32 + global_batch_size: 64 + micro_batch_size: 1 + +# Optimizer configuration +optimizer: + lr: 1.0e-4 + +# Checkpoint configuration +checkpoint: + save_interval: 2000 + +# Logger configuration +logger: + log_interval: 1 diff --git a/examples/diffusion/recipes/flux/conversion/convert_checkpoints.py b/examples/diffusion/recipes/flux/conversion/convert_checkpoints.py new file mode 100644 index 0000000000..f9a1d85643 --- /dev/null +++ b/examples/diffusion/recipes/flux/conversion/convert_checkpoints.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Megatron-HuggingFace Checkpoint Conversion Example + +This script demonstrates how to convert models between HuggingFace and Megatron formats +using the AutoBridge import_ckpt and export_ckpt methods. + +Usage examples: + # Download the HF checkpoint locally + huggingface-cli download black-forest-labs/FLUX.1-dev \ + --local-dir /root/.cache/huggingface/flux.1-dev \ + --local-dir-use-symlinks False + + # Import a HuggingFace model to Megatron format + python examples/diffusion/recipes/flux/conversion/convert_checkpoints.py import \ + --hf-model /root/.cache/huggingface/flux.1-dev\ + --megatron-path /workspace/checkpoints/megatron_checkpoints/flux.1-dev + + # Export a Megatron checkpoint to HuggingFace format + python examples/diffusion/recipes/flux/conversion/convert_checkpoints.py export \ + --hf-model /root/.cache/huggingface/flux.1-dev \ + --megatron-path /workspace/checkpoints/megatron_checkpoints/flux.1-dev/iter_0000000 \ + --hf-path /workspace/checkpoints/hf_checkpoints/flux.1-dev_hf + + NOTE: The converted checkpoint /workspace/checkpoints/hf_checkpoints/flux.1-dev_hf + only contains the DiT model transformer weights. You still need other components in + the diffusion pipeline (VAE, text encoders, etc) to run inference. To do so, you can + duplicate the original HF checkpoint directory /root/.cache/huggingface/flux.1-dev (which + contains VAE, text encoders, etc.), and replace ./transformer with + /workspace/checkpoints/hf_checkpoints/flux.1-dev_hf/transformer. + +""" + +import argparse +import os +import random +import sys +from pathlib import Path +from typing import Optional + +import torch + +from megatron.bridge.diffusion.conversion.flux.flux_bridge import FluxBridge +from megatron.bridge.diffusion.conversion.flux.flux_hf_pretrained import PreTrainedFlux +from megatron.bridge.training.model_load_save import ( + load_megatron_model, + save_megatron_model, + temporary_distributed_context, +) + + +def validate_path(path: str, must_exist: bool = False) -> Path: + """Validate and convert string path to Path object.""" + path_obj = Path(path) + if must_exist and not path_obj.exists(): + raise ValueError(f"Path does not exist: {path}") + return path_obj + + +def get_torch_dtype(dtype_str: str) -> torch.dtype: + """Convert string to torch dtype.""" + dtype_map = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + } + if dtype_str not in dtype_map: + raise ValueError(f"Unsupported dtype: {dtype_str}. Supported: {list(dtype_map.keys())}") + return dtype_map[dtype_str] + + +def import_hf_to_megatron( + hf_model: str, + megatron_path: str, + torch_dtype: Optional[str] = None, + device_map: Optional[str] = None, + trust_remote_code: bool = False, +) -> None: + """ + Import a HuggingFace model and save it as a Megatron checkpoint. + + Args: + hf_model: HuggingFace model ID or path to model directory + megatron_path: Directory path where the Megatron checkpoint will be saved + torch_dtype: Model precision ("float32", "float16", "bfloat16") + device_map: Device placement strategy ("auto", "cuda:0", etc.) + trust_remote_code: Allow custom model code execution + """ + print(f"🔄 Starting import: {hf_model} -> {megatron_path}") + + # Prepare kwargs + kwargs = {} + if torch_dtype: + kwargs["torch_dtype"] = get_torch_dtype(torch_dtype) + print(f" Using torch_dtype: {torch_dtype}") + + if device_map: + kwargs["device_map"] = device_map + print(f" Using device_map: {device_map}") + + if trust_remote_code: + kwargs["trust_remote_code"] = trust_remote_code + print(f" Trust remote code: {trust_remote_code}") + + # Import using the convenience method + print(f"đŸ“Ĩ Loading HuggingFace model from diffusers: {hf_model}") + # Minimal single-rank env to satisfy provider init if needed + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ.setdefault("MASTER_PORT", str(29500 + random.randint(0, 1000))) + os.environ.setdefault("RANK", "0") + os.environ.setdefault("WORLD_SIZE", "1") + os.environ.setdefault("LOCAL_RANK", "0") + + hf = PreTrainedFlux(hf_model) + bridge = FluxBridge() + provider = bridge.provider_bridge(hf) + provider.perform_initialization = False + megatron_models = provider.provide_distributed_model(wrap_with_ddp=False, use_cpu_initialization=True) + bridge.load_weights_hf_to_megatron(hf, megatron_models) + + # Save all parameters in transformer with their norms to a file + param_norms_file = "megatron_bridge_param_norms.txt" + with open(param_norms_file, "w") as f: + f.write("=" * 80 + "\n") + f.write("Transformer Parameters and Norms\n") + f.write("=" * 80 + "\n") + for name, param in megatron_models[0].named_parameters(): + if param.requires_grad: + norm = param.data.norm().item() + f.write(f"{name:80s} | shape: {str(list(param.shape)):20s} | norm: {norm:.6f}\n") + f.write("=" * 80 + "\n") + print(f"Parameter norms saved to: {param_norms_file}") + + save_megatron_model(megatron_models, megatron_path, hf_tokenizer_path=None) + + print(f"✅ Successfully imported model to: {megatron_path}") + + # Verify the checkpoint was created + checkpoint_path = Path(megatron_path) + if checkpoint_path.exists(): + print("📁 Checkpoint structure:") + for item in checkpoint_path.iterdir(): + if item.is_dir(): + print(f" 📂 {item.name}/") + else: + print(f" 📄 {item.name}") + + +def export_megatron_to_hf( + hf_model: str, + megatron_path: str, + hf_path: str, + show_progress: bool = True, + strict: bool = True, +) -> None: + """ + Export a Megatron checkpoint to HuggingFace format. + + Args: + megatron_path: Directory path where the Megatron checkpoint is stored + hf_path: Directory path where the HuggingFace model will be saved + show_progress: Display progress bar during weight export + """ + print(f"🔄 Starting export: {megatron_path} -> {hf_path}") + + # Validate megatron checkpoint exists + checkpoint_path = validate_path(megatron_path, must_exist=True) + print(f"📂 Found Megatron checkpoint: {checkpoint_path}") + + # Look for configuration files to determine the model type + config_files = list(checkpoint_path.glob("**/run_config.yaml")) + if not config_files: + # Look in iter_ subdirectories + iter_dirs = [d for d in checkpoint_path.iterdir() if d.is_dir() and d.name.startswith("iter_")] + if iter_dirs: + # Use the latest iteration + latest_iter = max(iter_dirs, key=lambda d: int(d.name.replace("iter_", ""))) + config_files = list(latest_iter.glob("run_config.yaml")) + + if not config_files: + raise FileNotFoundError( + f"Could not find run_config.yaml in {checkpoint_path}. Please ensure this is a valid Megatron checkpoint." + ) + + print(f"📋 Found configuration: {config_files[0]}") + + print("â„šī¸ Starting Flux Diffusers export...") + # Minimal single-process distributed context on CPU for loading Megatron ckpt + with temporary_distributed_context(backend="gloo"): + # Resolve latest iter_* directory (use the config file we found) + checkpoint_iter_dir = config_files[0].parent + # 1) Load Megatron model from checkpoint + megatron_models = load_megatron_model(str(checkpoint_iter_dir), use_cpu_init=True, skip_temp_dist_context=True) + if not isinstance(megatron_models, list): + megatron_models = [megatron_models] + + # 2) Prepare HF Flux wrapper for state/metadata and save artifacts + hf = PreTrainedFlux(hf_model) + Path(hf_path).mkdir(parents=True, exist_ok=True) + # Some diffusers configs are FrozenDict and don't support save_pretrained; skip quietly + try: + hf.save_artifacts(hf_path) + except Exception: + pass + + # 3) Stream-export weights Megatron -> HF safetensors via Flux bridge + bridge = FluxBridge() + generator = bridge.stream_weights_megatron_to_hf(megatron_models, hf, cpu=True, show_progress=show_progress) + # 4) Save streamed weights into hf_path + hf.state.source.save_generator(generator, hf_path) + + print(f"✅ Successfully exported model to: {hf_path}") + + # Verify the export was created + export_path = Path(hf_path) + if export_path.exists(): + print("📁 Export structure:") + for item in export_path.iterdir(): + if item.is_dir(): + print(f" 📂 {item.name}/") + else: + print(f" 📄 {item.name}") + + print("🔍 You can now load this model with:") + print(" from transformers import AutoModelForCausalLM") + print(f" model = AutoModelForCausalLM.from_pretrained('{hf_path}')") + + +def main(): + """Main function to handle command line arguments and execute conversions.""" + parser = argparse.ArgumentParser( + description="Convert models between HuggingFace and Megatron formats", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + + subparsers = parser.add_subparsers(dest="command", help="Conversion direction") + + # Import subcommand (HF -> Megatron) + import_parser = subparsers.add_parser("import", help="Import HuggingFace model to Megatron checkpoint format") + import_parser.add_argument("--hf-model", required=True, help="HuggingFace model ID or path to model directory") + import_parser.add_argument( + "--megatron-path", required=True, help="Directory path where the Megatron checkpoint will be saved" + ) + import_parser.add_argument("--torch-dtype", choices=["float32", "float16", "bfloat16"], help="Model precision") + import_parser.add_argument("--device-map", help='Device placement strategy (e.g., "auto", "cuda:0")') + import_parser.add_argument("--trust-remote-code", action="store_true", help="Allow custom model code execution") + + # Export subcommand (Megatron -> HF) + export_parser = subparsers.add_parser("export", help="Export Megatron checkpoint to HuggingFace format") + export_parser.add_argument("--hf-model", required=True, help="HuggingFace model ID or path to model directory") + export_parser.add_argument( + "--megatron-path", required=True, help="Directory path where the Megatron checkpoint is stored" + ) + export_parser.add_argument( + "--hf-path", required=True, help="Directory path where the HuggingFace model will be saved" + ) + export_parser.add_argument("--no-progress", action="store_true", help="Disable progress bar during export") + export_parser.add_argument( + "--not-strict", action="store_true", help="Allow source and target checkpoint to have different keys" + ) + + args = parser.parse_args() + + if not args.command: + parser.print_help() + return 1 + + if args.command == "import": + import_hf_to_megatron( + hf_model=args.hf_model, + megatron_path=args.megatron_path, + torch_dtype=args.torch_dtype, + device_map=args.device_map, + trust_remote_code=args.trust_remote_code, + ) + + elif args.command == "export": + export_megatron_to_hf( + hf_model=args.hf_model, + megatron_path=args.megatron_path, + hf_path=args.hf_path, + show_progress=not args.no_progress, + strict=not args.not_strict, + ) + else: + raise RuntimeError(f"Unknown command: {args.command}") + + if torch.distributed.is_initialized(): + torch.distributed.barrier() + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/examples/diffusion/recipes/flux/finetune_flux.py b/examples/diffusion/recipes/flux/finetune_flux.py new file mode 100644 index 0000000000..9c48f70e43 --- /dev/null +++ b/examples/diffusion/recipes/flux/finetune_flux.py @@ -0,0 +1,407 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +FLUX Fine-tuning Script with YAML and CLI Configuration Overrides. + +This script provides a flexible way to fine-tune FLUX models using Megatron-Bridge with support for +both YAML configuration files and command-line overrides using Hydra-style syntax. + +The script loads a pretrained checkpoint and continues training with your custom dataset. +Fine-tuning typically uses lower learning rates and fewer training iterations compared to pretraining. + +Forward Step Options: + - Automodel FlowMatchingPipeline (default): Unified flow matching implementation + - Original FluxForwardStep (--use-original-step): Classic implementation + +Examples: + Basic usage with checkpoint loading (uses automodel pipeline): + $ torchrun --nproc_per_node=8 finetune_flux.py \ + --load-checkpoint /path/to/pretrained/checkpoint --mock + + Using original FluxForwardStep: + $ torchrun --nproc_per_node=8 finetune_flux.py \ + --load-checkpoint /path/to/pretrained/checkpoint --mock --use-original-step + + Using a custom YAML config file: + $ torchrun --nproc_per_node=8 finetune_flux.py \ + --load-checkpoint /path/to/pretrained/checkpoint \ + --config-file my_custom_config.yaml + + Using CLI overrides only: + $ torchrun --nproc_per_node=8 finetune_flux.py \ + --load-checkpoint /path/to/pretrained/checkpoint \ + model.tensor_model_parallel_size=4 train.train_iters=5000 optimizer.lr=1e-5 + + Combining YAML and CLI overrides (CLI takes precedence): + $ torchrun --nproc_per_node=8 finetune_flux.py \ + --load-checkpoint /path/to/pretrained/checkpoint \ + --config-file conf/my_config.yaml \ + model.pipeline_dtype=torch.float16 \ + train.global_batch_size=512 + + Using automodel pipeline with custom parameters (automodel is default): + $ torchrun --nproc_per_node=8 finetune_flux.py \ + --load-checkpoint /path/to/pretrained/checkpoint --mock \ + --flow-shift=1.0 --use-loss-weighting + +Configuration Precedence: + 1. Base configuration from pretrain_config() recipe + 2. YAML overrides from --config-file (if provided) + 3. CLI overrides (highest precedence) + 4. Checkpoint loading path (from --load-checkpoint) + +Supported Override Syntax: + - Standard assignment: key=value + - Nested assignment: section.subsection.key=value + - Addition: +new_key=value + - Deletion: ~key_to_remove + - Type conversion: Automatic for basic types (int, float, bool, str) + - Complex types: torch.dtype, enums, etc. are supported +""" + +import argparse +import logging +import os +import sys +from pathlib import Path +from typing import Tuple + +from omegaconf import OmegaConf + +from megatron.bridge.diffusion.models.flux.flux_step_with_automodel import create_flux_forward_step +from megatron.bridge.diffusion.recipes.flux.flux import pretrain_config +from megatron.bridge.training.config import ConfigContainer +from megatron.bridge.training.pretrain import pretrain +from megatron.bridge.training.utils.omegaconf_utils import ( + apply_overrides, + create_omegaconf_dict_config, + parse_hydra_overrides, +) +from megatron.bridge.utils.common_utils import get_rank_safe + + +logger: logging.Logger = logging.getLogger(__name__) + + +# Define paths relative to this script's location +SCRIPT_DIR: Path = Path(__file__).parent.resolve() +DEFAULT_CONFIG_FILENAME: str = "flux_pretrain_override_example.yaml" +DEFAULT_CONFIG_FILE_PATH: Path = SCRIPT_DIR / "conf" / DEFAULT_CONFIG_FILENAME + + +def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]: + """Parse command line arguments, separating known script args from OmegaConf overrides.""" + parser = argparse.ArgumentParser( + description="Fine-tune FLUX model using Megatron-Bridge with YAML and CLI overrides", + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + "--load-checkpoint", + type=str, + required=True, + help="Path to the pretrained checkpoint directory to load for fine-tuning.", + ) + parser.add_argument("--mock", action="store_true", help="Whether to use mock data.") + parser.add_argument( + "--timestep-sampling", + choices=["logit_normal", "uniform", "mode"], + default="logit_normal", + help="Timestep sampling strategy for flow matching.", + ) + parser.add_argument( + "--logit-mean", + type=float, + default=0.0, + help="Mean for logit-normal timestep sampling.", + ) + parser.add_argument( + "--logit-std", + type=float, + default=1.0, + help="Std for logit-normal timestep sampling.", + ) + parser.add_argument( + "--mode-scale", + type=float, + default=1.29, + help="Scale for mode timestep sampling.", + ) + parser.add_argument( + "--guidance-scale", + type=float, + default=3.5, + help="Guidance scale for FLUX-dev models.", + ) + parser.add_argument( + "--scheduler-steps", + type=int, + default=1000, + help="Number of scheduler training steps.", + ) + parser.add_argument( + "--config-file", + type=str, + default=str(DEFAULT_CONFIG_FILE_PATH), + help="Path to the YAML OmegaConf override file. Default: conf/flux_pretrain_override_example.yaml", + ) + parser.add_argument("--debug", action="store_true", help="Enable debug logging") + + # Forward step implementation choice + parser.add_argument( + "--use-original-step", + action="store_true", + help="Use original FluxForwardStep instead of automodel FlowMatchingPipeline (default)", + ) + parser.add_argument( + "--flow-shift", + type=float, + default=1.0, + help="Flow shift parameter (for automodel pipeline)", + ) + parser.add_argument( + "--use-loss-weighting", + action="store_true", + help="Use loss weighting (for automodel pipeline)", + ) + + # Parse known args for the script, remaining will be treated as overrides + args, cli_dotlist_overrides = parser.parse_known_args() + return args, cli_dotlist_overrides + + +def main() -> None: + """ + Entry point for the FLUX fine-tuning script. + + This function orchestrates the complete configuration workflow: + 1. Loads the base configuration from pretrain_config() recipe + 2. Applies YAML overrides from --config-file (if exists) + 3. Applies CLI overrides using Hydra-style syntax + 4. Sets checkpoint loading path for fine-tuning + 5. Starts Megatron training with the final merged configuration + + Configuration merging preserves callable fields (like activation functions) + and handles type conversions automatically. + + Examples of CLI usage: + # Fine-tune with default config and custom learning rate (automodel pipeline is default) + torchrun --nproc_per_node=8 finetune_flux.py \ + --load-checkpoint /path/to/checkpoint --mock optimizer.lr=1e-5 + + # Use original FluxForwardStep instead of automodel pipeline + torchrun --nproc_per_node=8 finetune_flux.py \ + --load-checkpoint /path/to/checkpoint --mock --use-original-step + + # Custom config file with additional overrides + torchrun --nproc_per_node=8 finetune_flux.py \ + --load-checkpoint /path/to/checkpoint \ + --config-file my_config.yaml train.train_iters=5000 + + # Multiple overrides for distributed fine-tuning (uses automodel by default) + torchrun --nproc_per_node=8 finetune_flux.py \ + --load-checkpoint /path/to/checkpoint --mock \ + model.tensor_model_parallel_size=4 \ + model.pipeline_model_parallel_size=2 \ + train.global_batch_size=512 \ + optimizer.lr=5e-6 + + # Automodel pipeline with custom flow matching parameters + torchrun --nproc_per_node=8 finetune_flux.py \ + --load-checkpoint /path/to/checkpoint --mock \ + --flow-shift=1.0 --use-loss-weighting + """ + args, cli_overrides = parse_cli_args() + + logger.info("Megatron-Bridge FLUX Fine-tuning Script with YAML & CLI Overrides") + logger.info("------------------------------------------------------------------") + + # Validate and normalize checkpoint path + checkpoint_path = args.load_checkpoint + if not os.path.exists(checkpoint_path): + logger.error(f"Checkpoint path does not exist: {checkpoint_path}") + sys.exit(1) + + # Check if the path points to a specific iteration directory (iter_XXXXXXX) + # If so, extract the base directory and iteration number + import re + + checkpoint_dir_name = os.path.basename(checkpoint_path.rstrip("/")) + iter_match = re.match(r"iter_(\d+)", checkpoint_dir_name) + + if iter_match: + # User provided a specific iteration directory + iteration_num = int(iter_match.group(1)) + base_dir = os.path.dirname(checkpoint_path.rstrip("/")) + if not base_dir: + base_dir = "." + logger.info(f"Detected iteration directory: {checkpoint_dir_name}") + logger.info(f"Extracted base directory: {base_dir}, iteration: {iteration_num}") + checkpoint_path = base_dir + # Set ckpt_step to load from this specific iteration + ckpt_step_override = iteration_num + else: + # User provided base directory - will load latest checkpoint + ckpt_step_override = None + logger.info(f"Using checkpoint base directory: {checkpoint_path}") + + # Load base configuration from the recipe as a Python dataclass + cfg: ConfigContainer = pretrain_config(mock=args.mock) + logger.info("Loaded base configuration") + + # Set checkpoint configuration for fine-tuning + # If ckpt_step is specified, we need to set load (validation requirement) + # Otherwise, use pretrained_checkpoint (preferred for fine-tuning) + if ckpt_step_override is not None: + # When loading a specific iteration, set load (required by validation when ckpt_step is set) + cfg.checkpoint.load = checkpoint_path + cfg.checkpoint.ckpt_step = ckpt_step_override + cfg.checkpoint.pretrained_checkpoint = None + logger.info(f"Will load from specific iteration: {ckpt_step_override}") + logger.info(f"Checkpoint load path set to: {checkpoint_path}") + else: + # When loading latest checkpoint, use pretrained_checkpoint (preferred for fine-tuning) + cfg.checkpoint.pretrained_checkpoint = checkpoint_path + cfg.checkpoint.load = None # Clear load to ensure pretrained_checkpoint takes precedence + logger.info(f"Pretrained checkpoint path set to: {checkpoint_path}") + # Explicitly set finetune flag to True - this ensures: + # - Model weights are loaded + # - Iteration is reset to 0 + # - Optimizer state is NOT loaded (fresh optimizer) + # - RNG state is NOT loaded (fresh RNG) + cfg.checkpoint.finetune = True + logger.info("Fine-tuning mode enabled (checkpoint.finetune=True)") + + # Print configuration on rank 0 + if get_rank_safe() == 0: + cfg.print_yaml() + + # Convert the initial Python dataclass to an OmegaConf DictConfig for merging + merged_omega_conf, excluded_fields = create_omegaconf_dict_config(cfg) + + # Load and merge YAML overrides if a config file is provided + if args.config_file: + logger.debug(f"Loading YAML overrides from: {args.config_file}") + if not os.path.exists(args.config_file): + logger.error(f"Override YAML file not found: {args.config_file}") + sys.exit(1) + yaml_overrides_omega = OmegaConf.load(args.config_file) + merged_omega_conf = OmegaConf.merge(merged_omega_conf, yaml_overrides_omega) + logger.debug("YAML overrides merged successfully.") + + # Apply command-line overrides using Hydra-style parsing + if cli_overrides: + logger.debug(f"Applying Hydra-style command-line overrides: {cli_overrides}") + merged_omega_conf = parse_hydra_overrides(merged_omega_conf, cli_overrides) + logger.debug("Hydra-style command-line overrides applied successfully.") + + # Ensure checkpoint configuration is set correctly (override any YAML/CLI that might have changed it) + if ckpt_step_override is not None: + # When loading a specific iteration, set load (required by validation when ckpt_step is set) + merged_omega_conf.checkpoint.load = checkpoint_path + merged_omega_conf.checkpoint.ckpt_step = ckpt_step_override + merged_omega_conf.checkpoint.pretrained_checkpoint = None + else: + # When loading latest checkpoint, use pretrained_checkpoint (preferred for fine-tuning) + merged_omega_conf.checkpoint.pretrained_checkpoint = checkpoint_path + merged_omega_conf.checkpoint.load = None # Clear load to ensure pretrained_checkpoint takes precedence + merged_omega_conf.checkpoint.finetune = True + + # Apply the final merged OmegaConf configuration back to the original ConfigContainer + logger.debug("Applying final merged configuration back to Python ConfigContainer...") + final_overrides_as_dict = OmegaConf.to_container(merged_omega_conf, resolve=True) + # Apply overrides while preserving excluded fields + apply_overrides(cfg, final_overrides_as_dict, excluded_fields) + + # Ensure checkpoint configuration is set correctly in the final config + if ckpt_step_override is not None: + # When loading a specific iteration, set load (required by validation when ckpt_step is set) + cfg.checkpoint.load = checkpoint_path + cfg.checkpoint.ckpt_step = ckpt_step_override + cfg.checkpoint.pretrained_checkpoint = None + else: + # When loading latest checkpoint, use pretrained_checkpoint (preferred for fine-tuning) + cfg.checkpoint.pretrained_checkpoint = checkpoint_path + cfg.checkpoint.load = None # Clear load to ensure pretrained_checkpoint takes precedence + cfg.checkpoint.finetune = True + + # Create forward step (configurable: original or automodel pipeline) + # Default is automodel pipeline unless --use-original-step is specified + if not args.use_original_step: + # Use automodel FlowMatchingPipeline + flux_forward_step = create_flux_forward_step( + use_automodel_pipeline=True, + timestep_sampling=args.timestep_sampling, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + flow_shift=args.flow_shift, + scheduler_steps=args.scheduler_steps, + guidance_scale=args.guidance_scale, + use_loss_weighting=args.use_loss_weighting, + ) + if get_rank_safe() == 0: + logger.info("=" * 70) + logger.info("✅ Using AUTOMODEL FlowMatchingPipeline") + logger.info(f" Timestep Sampling: {args.timestep_sampling}") + logger.info(f" Flow Shift: {args.flow_shift}") + logger.info(f" Loss Weighting: {args.use_loss_weighting}") + logger.info("=" * 70) + else: + # Use original FluxForwardStep + flux_forward_step = create_flux_forward_step( + use_automodel_pipeline=False, + timestep_sampling=args.timestep_sampling, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + scheduler_steps=args.scheduler_steps, + guidance_scale=args.guidance_scale, + ) + if get_rank_safe() == 0: + logger.info("=" * 70) + logger.info("✅ Using ORIGINAL FluxForwardStep") + logger.info(f" Timestep Sampling: {args.timestep_sampling}") + logger.info("=" * 70) + + # Display final configuration + if get_rank_safe() == 0: + logger.info("--- Final Merged Configuration ---") + cfg.print_yaml() + logger.info("----------------------------------") + logger.info("Fine-tuning Configuration:") + if cfg.checkpoint.pretrained_checkpoint is not None: + logger.info(f" Pretrained Checkpoint Path: {cfg.checkpoint.pretrained_checkpoint}") + if cfg.checkpoint.load is not None: + logger.info(f" Checkpoint Load Path: {cfg.checkpoint.load}") + if cfg.checkpoint.ckpt_step is not None: + logger.info(f" Checkpoint Step: {cfg.checkpoint.ckpt_step}") + logger.info("FluxForwardStep config:") + logger.info(f" timestep_sampling: {args.timestep_sampling}") + logger.info(f" logit_mean: {args.logit_mean}") + logger.info(f" logit_std: {args.logit_std}") + logger.info(f" mode_scale: {args.mode_scale}") + logger.info(f" scheduler_steps: {args.scheduler_steps}") + logger.info(f" guidance_scale: {args.guidance_scale}") + if not args.use_original_step: + logger.info(f" flow_shift: {args.flow_shift}") + logger.info(f" use_loss_weighting: {args.use_loss_weighting}") + + # Start training (fine-tuning) + logger.debug("Starting fine-tuning...") + pretrain(config=cfg, forward_step_func=flux_forward_step) + + +if __name__ == "__main__": + main() diff --git a/examples/diffusion/recipes/flux/inference_flux.py b/examples/diffusion/recipes/flux/inference_flux.py new file mode 100644 index 0000000000..2acc7bf468 --- /dev/null +++ b/examples/diffusion/recipes/flux/inference_flux.py @@ -0,0 +1,84 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import torch +import torch.distributed as dist + + +def parse_args(): # noqa: D103 + parser = argparse.ArgumentParser(description="FLUX inference") + parser.add_argument("--flux_ckpt", type=str, required=True, help="Path to FLUX checkpoint") + parser.add_argument("--vae_ckpt", type=str, default=None, help="Path to VAE") + parser.add_argument("--t5_version", type=str, default="google/t5-v1_1-xxl") + parser.add_argument("--clip_version", type=str, default="openai/clip-vit-large-patch14") + parser.add_argument("--do_convert_from_hf", action="store_true", default=False) + parser.add_argument( + "--prompts", + type=str, + action="append", + help="Prompt(s) to generate images from. Can be specified multiple times for multiple prompts.", + required=True, + ) + parser.add_argument("--height", type=int, default=1024) + parser.add_argument("--width", type=int, default=1024) + parser.add_argument("--num_inference_steps", type=int, default=10) + parser.add_argument("--guidance_scale", type=float, default=0.0) + parser.add_argument("--output_path", type=str, default="/tmp/flux_output") + return parser.parse_args() + + +def main(): # noqa: D103 + args = parse_args() + + # Initialize megatron parallel state + rank = int(os.getenv("RANK", 0)) + world_size = int(os.getenv("WORLD_SIZE", 1)) + local_rank = int(os.getenv("LOCAL_RANK", 0)) + if world_size > 1: + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="nccl", init_method="env://", rank=rank, world_size=world_size) + + if dist.is_initialized(): + base_seed = [args.base_seed] if rank == 0 else [None] + dist.broadcast_object_list(base_seed, src=0) + args.base_seed = base_seed[0] + # Import FLUX + + from megatron.bridge.diffusion.models.flux import FluxInferencePipeline + + # Create pipeline + pipeline = FluxInferencePipeline( + flux_checkpoint_dir=args.flux_ckpt, + t5_checkpoint_dir=args.t5_version, + clip_checkpoint_dir=args.clip_version, + vae_checkpoint_dir=args.vae_ckpt, + ) + + # Generate + images = pipeline( + prompt=args.prompts, + height=args.height, + width=args.width, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + output_path=args.output_path, + ) + print(f"Generated {len(images)} images to {args.output_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/diffusion/recipes/flux/prepare_energon_dataset_flux.py b/examples/diffusion/recipes/flux/prepare_energon_dataset_flux.py new file mode 100644 index 0000000000..180d0d1bf5 --- /dev/null +++ b/examples/diffusion/recipes/flux/prepare_energon_dataset_flux.py @@ -0,0 +1,621 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import pickle +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import cv2 +import numpy as np +import torch +import torch.distributed as dist +import webdataset as wds +from tqdm import tqdm +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + + +def _map_interpolation(resize_mode: str) -> int: + """Map resize mode string to OpenCV interpolation constant.""" + interpolation_map = { + "bilinear": cv2.INTER_LINEAR, + "bicubic": cv2.INTER_CUBIC, + "nearest": cv2.INTER_NEAREST, + "area": cv2.INTER_AREA, + "lanczos": cv2.INTER_LANCZOS4, + } + if resize_mode not in interpolation_map: + raise ValueError(f"Invalid resize_mode '{resize_mode}'. Choose from: {list(interpolation_map.keys())}") + return interpolation_map[resize_mode] + + +def _calculate_resize_dimensions( + original_height: int, + original_width: int, + target_size: Optional[Tuple[int, int]], + maintain_aspect_ratio: bool, + center_crop: bool = False, +) -> Tuple[int, int]: + """Calculate target dimensions for resizing.""" + if target_size is None: + return original_height, original_width + + target_height, target_width = target_size + if not maintain_aspect_ratio: + return target_height, target_width + + original_aspect = original_width / max(1, original_height) + target_aspect = target_width / max(1, target_height) + + if center_crop: + # For center crop: resize so BOTH dimensions are >= target (resize on shorter edge) + # This ensures we can crop to exact size without padding + if original_aspect > target_aspect: + # Image is wider: match height, width will be larger + new_height = target_height + new_width = int(round(target_height * original_aspect)) + else: + # Image is taller: match width, height will be larger + new_width = target_width + new_height = int(round(target_width / max(1e-6, original_aspect))) + else: + # For no center crop: resize so image fits within target (resize on longer edge) + if original_aspect > target_aspect: + new_width = target_width + new_height = int(round(target_width / max(1e-6, original_aspect))) + else: + new_height = target_height + new_width = int(round(target_height * original_aspect)) + + return new_height, new_width + + +def _resize_image( + image: np.ndarray, + target_size: Optional[Tuple[int, int]], + resize_mode: str, + maintain_aspect_ratio: bool, + center_crop: bool, +) -> np.ndarray: + """Resize and optionally center crop an image.""" + if target_size is None: + return image + + original_height, original_width = image.shape[:2] + resize_height, resize_width = _calculate_resize_dimensions( + original_height, original_width, target_size, maintain_aspect_ratio, center_crop + ) + + interpolation = _map_interpolation(resize_mode) + resized_image = cv2.resize(image, (resize_width, resize_height), interpolation=interpolation) + + if maintain_aspect_ratio and center_crop: + target_height, target_width = target_size + if resize_height != target_height or resize_width != target_width: + y_start = max(0, (resize_height - target_height) // 2) + x_start = max(0, (resize_width - target_width) // 2) + y_end = min(resize_height, y_start + target_height) + x_end = min(resize_width, x_start + target_width) + resized_image = resized_image[y_start:y_end, x_start:x_end] + + if resized_image.shape[0] < target_height or resized_image.shape[1] < target_width: + pad_height = max(0, target_height - resized_image.shape[0]) + pad_width = max(0, target_width - resized_image.shape[1]) + resized_image = np.pad( + resized_image, ((0, pad_height), (0, pad_width), (0, 0)), mode="constant", constant_values=0 + ) + + return resized_image + + +def _load_image(image_path: str) -> np.ndarray: + """Load an image from file.""" + image = cv2.imread(image_path) + if image is None: + raise ValueError(f"Failed to load image from {image_path}") + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + return image + + +def _load_metadata(data_folder: Path, image_extensions: List[str] = None) -> List[Dict]: + """ + Load metadata from meta.json or scan directory for images. + + Expected meta.json format (JSON array): + [ + { + "file_name": "image1.jpg", + "caption": "A description of the image" + }, + ... + ] + + Or JSON Lines format (one JSON object per line): + {"file_name": "image1.jpg", "caption": "A description"} + {"file_name": "image2.jpg", "caption": "Another description"} + """ + if image_extensions is None: + image_extensions = [".jpg", ".jpeg", ".png", ".webp", ".bmp"] + + meta_path = data_folder / "meta.json" + if meta_path.exists(): + with open(meta_path, "r") as f: + content = f.read().strip() + + # Try to parse as JSON array first + try: + return json.loads(content) + except json.JSONDecodeError: + # If that fails, try parsing as JSON Lines (one JSON object per line) + items = [] + for line in content.split("\n"): + line = line.strip() + if line: + try: + items.append(json.loads(line)) + except json.JSONDecodeError as e: + print(f"Warning: Failed to parse line in meta.json: {line[:100]}... Error: {e}") + continue + if items: + return items + raise ValueError("Failed to parse meta.json as either JSON array or JSON Lines format") + + # Fallback: scan for image files with sidecar captions + items: List[Dict] = [] + for entry in sorted(data_folder.iterdir()): + if not entry.is_file(): + continue + if entry.suffix.lower() not in image_extensions: + continue + + image_name = entry.name + # Look for caption in .txt file + caption_file = entry.with_suffix(".txt") + caption = "" + if caption_file.exists(): + with open(caption_file, "r") as f: + caption = f.read().strip() + + items.append( + { + "file_name": image_name, + "caption": caption, + } + ) + + if not items: + raise FileNotFoundError(f"No meta.json and no image files found in {data_folder}") + return items + + +@torch.no_grad() +def _init_flux_vae( + model_id: str, + device: str, + enable_memory_optimization: bool, +): + """Initialize FLUX VAE from pretrained model.""" + try: + from diffusers import AutoencoderKL + except ImportError: + raise ImportError("Please install diffusers: pip install diffusers") + + # Use float32 for all devices to avoid dtype mismatch issues + # The FLUX VAE appears to have internal operations that require float32 + dtype = torch.float32 + + vae = AutoencoderKL.from_pretrained( + model_id, + subfolder="vae", + torch_dtype=dtype, + ) + # Ensure all parameters and buffers are on the correct device and dtype + vae = vae.to(device=device, dtype=dtype) + vae.eval() + + # Verify dtype consistency for all parameters + for name, param in vae.named_parameters(): + if param.dtype != dtype: + param.data = param.data.to(dtype) + + # Also convert buffers (like running means in batch norm) + for name, buffer in vae.named_buffers(): + if buffer.dtype not in [torch.int32, torch.int64, torch.long]: # Skip integer buffers + if buffer.dtype != dtype: + buffer.data = buffer.data.to(dtype) + + if enable_memory_optimization and hasattr(vae, "enable_slicing"): + vae.enable_slicing() + if enable_memory_optimization and hasattr(vae, "enable_tiling"): + vae.enable_tiling() + + return vae, dtype + + +@torch.no_grad() +def _init_text_encoders( + t5_model_id: str, + clip_model_id: str, + device: str, +): + """Initialize T5 and CLIP text encoders.""" + # Use float32 to avoid dtype mismatch with Apex fused layer norm + dtype = torch.float32 + + # T5 encoder + t5_tokenizer = T5TokenizerFast.from_pretrained(t5_model_id) + t5_encoder = T5EncoderModel.from_pretrained(t5_model_id, torch_dtype=dtype) + t5_encoder.to(device=device, dtype=dtype) + t5_encoder.eval() + + # CLIP encoder + clip_tokenizer = CLIPTokenizer.from_pretrained(clip_model_id) + clip_encoder = CLIPTextModel.from_pretrained(clip_model_id, torch_dtype=dtype) + clip_encoder.to(device=device, dtype=dtype) + clip_encoder.eval() + + return t5_tokenizer, t5_encoder, clip_tokenizer, clip_encoder, dtype + + +@torch.no_grad() +def _encode_text_flux( + t5_tokenizer: T5TokenizerFast, + t5_encoder: T5EncoderModel, + clip_tokenizer: CLIPTokenizer, + clip_encoder: CLIPTextModel, + device: str, + caption: str, + max_sequence_length: int = 512, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Encode text with both T5 and CLIP encoders. + + Returns: + Tuple of (t5_embeds [seq_len, hidden_dim], clip_pooled_embeds [hidden_dim]) + """ + caption = caption.strip() + + # T5 encoding + t5_inputs = t5_tokenizer( + caption, + max_length=max_sequence_length, + padding="max_length", + truncation=True, + return_tensors="pt", + return_attention_mask=True, + ) + t5_inputs = {k: v.to(device) for k, v in t5_inputs.items()} + t5_outputs = t5_encoder(input_ids=t5_inputs["input_ids"], attention_mask=t5_inputs["attention_mask"]) + t5_embeds = t5_outputs.last_hidden_state[0] # [seq_len, hidden_dim] + + # CLIP encoding + clip_inputs = clip_tokenizer( + caption, + max_length=77, + padding="max_length", + truncation=True, + return_tensors="pt", + ) + clip_inputs = {k: v.to(device) for k, v in clip_inputs.items()} + clip_outputs = clip_encoder(input_ids=clip_inputs["input_ids"]) + clip_pooled_embeds = clip_outputs.pooler_output[0] # [hidden_dim] + + return t5_embeds, clip_pooled_embeds + + +@torch.no_grad() +def _encode_image_latents( + vae, + device: str, + image: np.ndarray, + deterministic_latents: bool, +) -> torch.Tensor: + """ + Encode image to latents using FLUX VAE. + + Args: + vae: FLUX VAE model + device: Device to use + image: RGB numpy array [H, W, C] in range [0, 255] + deterministic_latents: If True, use mean; if False, sample + + Returns: + Latents tensor [C, H_latent, W_latent] + """ + # Normalize to [0, 1] then to [-1, 1] (standard for diffusion VAEs) + image = image.astype(np.float32) / 255.0 + image = image * 2.0 - 1.0 + + # Convert to tensor [1, C, H, W] + image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0) + # Match VAE's dtype to avoid dtype mismatch with internal weights + image_tensor = image_tensor.to(device=device, dtype=vae.dtype) + + # FLUX VAE expects inputs in [-1, 1] range + latent_dist = vae.encode(image_tensor) + + if deterministic_latents: + latents = latent_dist.latent_dist.mode() + else: + latents = latent_dist.latent_dist.sample() + + # Remove batch dimension: [1, C, H, W] -> [C, H, W] + latents = latents[0] + + return latents + + +def get_start_end_idx_for_this_rank(dataset_size: int, rank: int, world_size: int) -> Tuple[int, int]: + """Calculate start and end indices for distributed processing.""" + split_size = dataset_size // world_size + start_idx = rank * split_size + # The last rank takes the remainder + end_idx = start_idx + split_size if rank != world_size - 1 else dataset_size + return start_idx, end_idx + + +def _save_individual_sample( + output_dir: Path, + sample_key: str, + latents: torch.Tensor, + text_embeddings: Dict, + json_data: Dict, + processed_image: Optional[np.ndarray] = None, +) -> None: + """ + Save individual files for a sample. + + Args: + output_dir: Base output directory + sample_key: Unique key for this sample (e.g., "000001") + latents: Latent tensor [C, H, W] + text_embeddings: Dict with prompt_embeds and pooled_prompt_embeds + json_data: Metadata dict + processed_image: Optional processed image [H, W, C] in RGB format + """ + sample_dir = output_dir / "individual_samples" / sample_key + sample_dir.mkdir(parents=True, exist_ok=True) + + # Save latents + torch.save(latents, sample_dir / "latents.pt") + + # Save text embeddings + with open(sample_dir / "text_embeddings.pkl", "wb") as f: + pickle.dump(text_embeddings, f) + + # Save metadata + with open(sample_dir / "metadata.json", "w") as f: + json.dump(json_data, f, indent=2) + + # Optionally save processed image + if processed_image is not None: + # Convert RGB back to BGR for OpenCV + image_bgr = cv2.cvtColor(processed_image, cv2.COLOR_RGB2BGR) + cv2.imwrite(str(sample_dir / "processed_image.jpg"), image_bgr, [cv2.IMWRITE_JPEG_QUALITY, 95]) + + +def main(): # noqa: D103 + import argparse + + parser = argparse.ArgumentParser( + description="Prepare FLUX WebDataset shards with VAE latents, T5, and CLIP embeddings" + ) + parser.add_argument("--data_folder", type=str, required=True, help="Folder containing images and meta.json") + parser.add_argument("--output_dir", type=str, required=True, help="Directory to write webdataset shards") + parser.add_argument( + "--vae_model", + default="black-forest-labs/FLUX.1-schnell", + help="FLUX model ID for VAE (e.g., black-forest-labs/FLUX.1-schnell or FLUX.1-dev)", + ) + parser.add_argument( + "--t5_model", + default="google/t5-v1_1-xxl", + help="T5 model ID (e.g., google/t5-v1_1-xxl)", + ) + parser.add_argument( + "--clip_model", + default="openai/clip-vit-large-patch14", + help="CLIP model ID (e.g., openai/clip-vit-large-patch14)", + ) + parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use") + parser.add_argument( + "--stochastic", + action="store_true", + help="Use stochastic encoding (sampling) instead of deterministic mode", + ) + parser.add_argument("--no-memory-optimization", action="store_true", help="Disable VAE slicing/tiling") + parser.add_argument("--shard_maxcount", type=int, default=10000, help="Max samples per shard") + parser.add_argument("--max_sequence_length", type=int, default=512, help="Max sequence length for T5 encoding") + + # Resize arguments + parser.add_argument("--height", type=int, default=1024, help="Target height for images") + parser.add_argument("--width", type=int, default=1024, help="Target width for images") + parser.add_argument( + "--resize_mode", + default="bilinear", + choices=["bilinear", "bicubic", "nearest", "area", "lanczos"], + help="Interpolation mode for resizing", + ) + parser.add_argument("--no-aspect-ratio", action="store_true", help="Disable aspect ratio preservation") + parser.add_argument("--center-crop", action="store_true", help="Center crop to exact target size after resize") + + # Distributed processing + parser.add_argument("--distributed", action="store_true", help="Use distributed processing") + + # Individual file saving + parser.add_argument( + "--save_individual_files", + action="store_true", + help="Save individual files (latents, embeddings, metadata) in addition to webdataset tars", + ) + parser.add_argument( + "--save_processed_images", + action="store_true", + help="Also save processed images when --save_individual_files is enabled", + ) + + args = parser.parse_args() + + data_folder = Path(args.data_folder) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Setup distributed processing if requested + if args.distributed: + dist.init_process_group("nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + device = f"cuda:{local_rank}" + else: + rank = 0 + world_size = 1 + device = args.device + + # Output shard pattern + if world_size > 1: + shard_pattern = str(output_dir / f"rank{rank}-%06d.tar") + else: + shard_pattern = str(output_dir / "shard-%06d.tar") + + # Target size + target_size = (args.height, args.width) + + # Initialize models + print(f"Rank {rank}: Initializing models...") + vae, vae_dtype = _init_flux_vae( + model_id=args.vae_model, + device=device, + enable_memory_optimization=not args.no_memory_optimization, + ) + + t5_tokenizer, t5_encoder, clip_tokenizer, clip_encoder, text_dtype = _init_text_encoders( + t5_model_id=args.t5_model, + clip_model_id=args.clip_model, + device=device, + ) + + # Load metadata + metadata_list = _load_metadata(data_folder) + print(f"Total samples in dataset: {len(metadata_list)}") + + # Distribute work across ranks + start_idx, end_idx = get_start_end_idx_for_this_rank(len(metadata_list), rank, world_size) + print(f"Rank {rank} of {world_size} processing {end_idx - start_idx} samples, from {start_idx} to {end_idx}") + + if args.save_individual_files: + print(f"Individual files will be saved to: {output_dir / 'individual_samples'}") + if args.save_processed_images: + print("Processed images will also be saved") + + with wds.ShardWriter(shard_pattern, maxcount=args.shard_maxcount) as sink: + written = 0 + for index in tqdm(range(start_idx, end_idx), desc=f"Rank {rank}"): + meta = metadata_list[index] + image_name = meta["file_name"] + caption = meta.get("caption", "") + + image_path = str(data_folder / image_name) + + try: + # Load and resize image + image = _load_image(image_path) + image = _resize_image( + image=image, + target_size=target_size, + resize_mode=args.resize_mode, + maintain_aspect_ratio=not args.no_aspect_ratio, + center_crop=args.center_crop, + ) + + H, W = image.shape[:2] + + # Encode image to latents + latents = _encode_image_latents(vae, device, image, deterministic_latents=not args.stochastic) + + # Encode text with T5 and CLIP + t5_embeds, clip_pooled_embeds = _encode_text_flux( + t5_tokenizer, + t5_encoder, + clip_tokenizer, + clip_encoder, + device, + caption, + max_sequence_length=args.max_sequence_length, + ) + + # Move to CPU + latents_cpu = latents.detach().to(device="cpu") + t5_embeds_cpu = t5_embeds.detach().to(device="cpu") + clip_pooled_embeds_cpu = clip_pooled_embeds.detach().to(device="cpu") + + # Create text embeddings dict + text_embeddings = { + "prompt_embeds": t5_embeds_cpu, + "pooled_prompt_embeds": clip_pooled_embeds_cpu, + } + + # Build JSON metadata + json_data = { + "image_path": image_path, + "processed_height": int(H), + "processed_width": int(W), + "caption": caption, + "deterministic_latents": bool(not args.stochastic), + "memory_optimization": bool(not args.no_memory_optimization), + "model_version": "flux", + "vae_normalization": "[-1, 1]", + "resize_settings": { + "target_size": target_size, + "resize_mode": args.resize_mode, + "maintain_aspect_ratio": bool(not args.no_aspect_ratio), + "center_crop": bool(args.center_crop), + }, + } + + # Write to webdataset + sample = { + "__key__": f"{index:06}", + "pth": latents_cpu, + "pickle": pickle.dumps(text_embeddings), + "json": json_data, + } + sink.write(sample) + written += 1 + + # Optionally save individual files + if args.save_individual_files: + _save_individual_sample( + output_dir=output_dir, + sample_key=f"{index:06}", + latents=latents_cpu, + text_embeddings=text_embeddings, + json_data=json_data, + processed_image=image if args.save_processed_images else None, + ) + + except Exception as e: + print(f"Rank {rank}: Error processing {image_path}: {e}") + continue + + print(f"Rank {rank}: Done! Wrote {written} samples.") + + if args.distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/examples/diffusion/recipes/flux/pretrain_flux.py b/examples/diffusion/recipes/flux/pretrain_flux.py new file mode 100644 index 0000000000..127c4346c0 --- /dev/null +++ b/examples/diffusion/recipes/flux/pretrain_flux.py @@ -0,0 +1,297 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +FLUX Pretraining Script with YAML and CLI Configuration Overrides. + +This script provides a flexible way to pretrain FLUX models using Megatron-Bridge with support for +both YAML configuration files and command-line overrides using Hydra-style syntax. + +Forward Step Options: + - Automodel FlowMatchingPipeline (default): Unified flow matching implementation + - Original FluxForwardStep (--use-original-step): Classic implementation + +Examples: + Basic usage with default configuration (uses automodel pipeline): + $ torchrun --nproc_per_node=8 pretrain_flux.py --mock + + Using original FluxForwardStep: + $ torchrun --nproc_per_node=8 pretrain_flux.py --mock --use-original-step + + Using a custom YAML config file: + $ torchrun --nproc_per_node=8 pretrain_flux.py --config-file my_custom_config.yaml + + Using CLI overrides only: + $ torchrun --nproc_per_node=8 pretrain_flux.py model.tensor_model_parallel_size=4 train.train_iters=100000 + + Combining YAML and CLI overrides (CLI takes precedence): + $ torchrun --nproc_per_node=8 pretrain_flux.py --config-file conf/my_config.yaml \ + model.pipeline_dtype=torch.float16 \ + train.global_batch_size=512 + + Using automodel pipeline with custom parameters (automodel is default): + $ torchrun --nproc_per_node=8 pretrain_flux.py --mock \ + --flow-shift=1.0 --use-loss-weighting + +Configuration Precedence: + 1. Base configuration from pretrain_config() recipe + 2. YAML overrides from --config-file (if provided) + 3. CLI overrides (highest precedence) + +Supported Override Syntax: + - Standard assignment: key=value + - Nested assignment: section.subsection.key=value + - Addition: +new_key=value + - Deletion: ~key_to_remove + - Type conversion: Automatic for basic types (int, float, bool, str) + - Complex types: torch.dtype, enums, etc. are supported +""" + +import argparse +import logging +import os +import sys +from pathlib import Path +from typing import Tuple + +from omegaconf import OmegaConf + +from megatron.bridge.diffusion.models.flux.flux_step_with_automodel import create_flux_forward_step +from megatron.bridge.diffusion.recipes.flux.flux import pretrain_config +from megatron.bridge.training.config import ConfigContainer +from megatron.bridge.training.pretrain import pretrain +from megatron.bridge.training.utils.omegaconf_utils import ( + apply_overrides, + create_omegaconf_dict_config, + parse_hydra_overrides, +) +from megatron.bridge.utils.common_utils import get_rank_safe + + +logger: logging.Logger = logging.getLogger(__name__) + + +# Define paths relative to this script's location +SCRIPT_DIR: Path = Path(__file__).parent.resolve() +DEFAULT_CONFIG_FILENAME: str = "flux_pretrain_override_example.yaml" +DEFAULT_CONFIG_FILE_PATH: Path = SCRIPT_DIR / "conf" / DEFAULT_CONFIG_FILENAME + + +def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]: + """Parse command line arguments, separating known script args from OmegaConf overrides.""" + parser = argparse.ArgumentParser( + description="Pretrain FLUX model using Megatron-Bridge with YAML and CLI overrides", + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument("--mock", action="store_true", help="Whether to use mock data.") + parser.add_argument( + "--timestep-sampling", + choices=["logit_normal", "uniform", "mode"], + default="logit_normal", + help="Timestep sampling strategy for flow matching.", + ) + parser.add_argument( + "--logit-mean", + type=float, + default=0.0, + help="Mean for logit-normal timestep sampling.", + ) + parser.add_argument( + "--logit-std", + type=float, + default=1.0, + help="Std for logit-normal timestep sampling.", + ) + parser.add_argument( + "--mode-scale", + type=float, + default=1.29, + help="Scale for mode timestep sampling.", + ) + parser.add_argument( + "--guidance-scale", + type=float, + default=3.5, + help="Guidance scale for FLUX-dev models.", + ) + parser.add_argument( + "--scheduler-steps", + type=int, + default=1000, + help="Number of scheduler training steps.", + ) + parser.add_argument( + "--config-file", + type=str, + default=str(DEFAULT_CONFIG_FILE_PATH), + help="Path to the YAML OmegaConf override file. Default: conf/flux_pretrain_override_example.yaml", + ) + parser.add_argument("--debug", action="store_true", help="Enable debug logging") + + # Forward step implementation choice + parser.add_argument( + "--use-original-step", + action="store_true", + help="Use original FluxForwardStep instead of automodel FlowMatchingPipeline (default)", + ) + parser.add_argument( + "--flow-shift", + type=float, + default=1.0, + help="Flow shift parameter (for automodel pipeline)", + ) + parser.add_argument( + "--use-loss-weighting", + action="store_true", + help="Use loss weighting (for automodel pipeline)", + ) + + # Parse known args for the script, remaining will be treated as overrides + args, cli_dotlist_overrides = parser.parse_known_args() + return args, cli_dotlist_overrides + + +def main() -> None: + """ + Entry point for the FLUX pretraining script. + + This function orchestrates the complete configuration workflow: + 1. Loads the base configuration from pretrain_config() recipe + 2. Applies YAML overrides from --config-file (if exists) + 3. Applies CLI overrides using Hydra-style syntax + 4. Starts Megatron pretraining with the final merged configuration + + Configuration merging preserves callable fields (like activation functions) + and handles type conversions automatically. + + Examples of CLI usage: + # Use default config with custom learning rate (automodel pipeline is default) + torchrun --nproc_per_node=8 pretrain_flux.py --mock optimizer.lr=0.0002 + + # Use original FluxForwardStep instead of automodel pipeline + torchrun --nproc_per_node=8 pretrain_flux.py --mock --use-original-step + + # Custom config file with additional overrides + torchrun --nproc_per_node=8 pretrain_flux.py --config-file my_config.yaml train.train_iters=50000 + + # Multiple overrides for distributed training (uses automodel by default) + torchrun --nproc_per_node=8 pretrain_flux.py --mock \ + model.tensor_model_parallel_size=4 \ + model.pipeline_model_parallel_size=2 \ + train.global_batch_size=512 + + # Automodel pipeline with custom flow matching parameters + torchrun --nproc_per_node=8 pretrain_flux.py --mock \ + --flow-shift=1.0 --use-loss-weighting + """ + args, cli_overrides = parse_cli_args() + + logger.info("Megatron-Bridge FLUX Pretraining Script with YAML & CLI Overrides") + logger.info("------------------------------------------------------------------") + + # Load base configuration from the recipe as a Python dataclass + cfg: ConfigContainer = pretrain_config(mock=args.mock) + logger.info("Loaded base configuration") + + # Print configuration on rank 0 + if get_rank_safe() == 0: + cfg.print_yaml() + + # Convert the initial Python dataclass to an OmegaConf DictConfig for merging + merged_omega_conf, excluded_fields = create_omegaconf_dict_config(cfg) + + # Load and merge YAML overrides if a config file is provided + if args.config_file: + logger.debug(f"Loading YAML overrides from: {args.config_file}") + if not os.path.exists(args.config_file): + logger.error(f"Override YAML file not found: {args.config_file}") + sys.exit(1) + yaml_overrides_omega = OmegaConf.load(args.config_file) + merged_omega_conf = OmegaConf.merge(merged_omega_conf, yaml_overrides_omega) + logger.debug("YAML overrides merged successfully.") + + # Apply command-line overrides using Hydra-style parsing + if cli_overrides: + logger.debug(f"Applying Hydra-style command-line overrides: {cli_overrides}") + merged_omega_conf = parse_hydra_overrides(merged_omega_conf, cli_overrides) + logger.debug("Hydra-style command-line overrides applied successfully.") + + # Apply the final merged OmegaConf configuration back to the original ConfigContainer + logger.debug("Applying final merged configuration back to Python ConfigContainer...") + final_overrides_as_dict = OmegaConf.to_container(merged_omega_conf, resolve=True) + # Apply overrides while preserving excluded fields + apply_overrides(cfg, final_overrides_as_dict, excluded_fields) + + # Create forward step (configurable: original or automodel pipeline) + # Default is automodel pipeline unless --use-original-step is specified + if not args.use_original_step: + # Use automodel FlowMatchingPipeline + flux_forward_step = create_flux_forward_step( + use_automodel_pipeline=True, + timestep_sampling=args.timestep_sampling, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + flow_shift=args.flow_shift, + scheduler_steps=args.scheduler_steps, + guidance_scale=args.guidance_scale, + use_loss_weighting=args.use_loss_weighting, + ) + if get_rank_safe() == 0: + logger.info("=" * 70) + logger.info("✅ Using AUTOMODEL FlowMatchingPipeline") + logger.info(f" Timestep Sampling: {args.timestep_sampling}") + logger.info(f" Flow Shift: {args.flow_shift}") + logger.info(f" Loss Weighting: {args.use_loss_weighting}") + logger.info("=" * 70) + else: + # Use original FluxForwardStep + flux_forward_step = create_flux_forward_step( + use_automodel_pipeline=False, + timestep_sampling=args.timestep_sampling, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + scheduler_steps=args.scheduler_steps, + guidance_scale=args.guidance_scale, + ) + if get_rank_safe() == 0: + logger.info("=" * 70) + logger.info("✅ Using ORIGINAL FluxForwardStep") + logger.info(f" Timestep Sampling: {args.timestep_sampling}") + logger.info("=" * 70) + + # Display final configuration + if get_rank_safe() == 0: + logger.info("--- Final Merged Configuration ---") + cfg.print_yaml() + logger.info("----------------------------------") + logger.info("FluxForwardStep config:") + logger.info(f" timestep_sampling: {args.timestep_sampling}") + logger.info(f" logit_mean: {args.logit_mean}") + logger.info(f" logit_std: {args.logit_std}") + logger.info(f" mode_scale: {args.mode_scale}") + logger.info(f" scheduler_steps: {args.scheduler_steps}") + logger.info(f" guidance_scale: {args.guidance_scale}") + if not args.use_original_step: + logger.info(f" flow_shift: {args.flow_shift}") + logger.info(f" use_loss_weighting: {args.use_loss_weighting}") + + # Start training + logger.debug("Starting pretraining...") + pretrain(config=cfg, forward_step_func=flux_forward_step) + + +if __name__ == "__main__": + main() diff --git a/examples/diffusion/recipes/wan/README-perf-test.md b/examples/diffusion/recipes/wan/README-perf-test.md new file mode 100644 index 0000000000..e2d972cba5 --- /dev/null +++ b/examples/diffusion/recipes/wan/README-perf-test.md @@ -0,0 +1,177 @@ +## WAN Model Setup and Usage (for Perf Test) + +This guide provides concise steps to set up the environment and run WAN pretraining and inference. It pins repo commits and shows explicit commands for the 1.3B and 14B configurations. + +## Container Launch + +```bash +CONT="nvcr.io/nvidia/nemo:25.09.00" +MOUNT="/lustre/fsw/:/lustre/fsw/" + +srun -t 02:00:00 \ + --account \ + -N 1 \ + -J \ + -p batch \ + --exclusive \ + --container-image="${CONT}" \ + --container-mounts="${MOUNT}" \ + --pty bash +``` + +## Setup Inside the Container + +Setup DFM, Megatron-Bridge, Megatron-LM with specific commits, and other dependencies. + +```bash +cd /opt/ + +# DFM (pinned) +git clone --no-checkout https://github.com/NVIDIA-NeMo/DFM.git +git -C DFM checkout 174bb7b34de002ebbbcae1ba8e2b12363c7dee01 +export DFM_PATH=/opt/DFM + +# Megatron-Bridge (pinned) +rm -rf /opt/Megatron-Bridge +git clone --no-checkout https://github.com/huvunvidia/Megatron-Bridge.git +git -C Megatron-Bridge checkout 713ab548e4bfee307eb94a7bb3f57c17dbb31b50 + +# Megatron-LM (pinned) +rm -rf /opt/Megatron-LM +git clone --no-checkout https://github.com/NVIDIA/Megatron-LM.git +git -C Megatron-LM checkout ce8185cbbe04f38beb74360e878450f2e8525885 + +# Python path +export PYTHONPATH="${DFM_PATH}/.:/opt/Megatron-Bridge/.:/opt/Megatron-LM" + +# Python deps +python3 -m pip install --upgrade diffusers==0.35.1 +pip install easydict imageio imageio-ffmpeg +``` + +## Pretraining +Set data path and checkpoint directory: + +```bash +DATASET_PATH="/lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/datasets/shared_datasets/processed_arrietty_scene_automodel" +EXP_NAME=wan_debug_perf +CHECKPOINT_DIR="/lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/results/wan_finetune/${EXP_NAME}" + +export HF_TOKEN= +export WANDB_API_KEY= +cd ${DFM_PATH} +``` + + +### 1.3B configuration + +```bash +NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=8 examples/diffusion/recipes/wan/pretrain_wan.py \ + --training-mode pretrain \ + model.tensor_model_parallel_size=1 \ + model.pipeline_model_parallel_size=1 \ + model.context_parallel_size=4 \ + model.crossattn_emb_size=1536 \ + model.hidden_size=1536 \ + model.ffn_hidden_size=8960 \ + model.num_attention_heads=12 \ + model.num_layers=30 \ + model.qkv_format=thd \ + dataset.path="${DATASET_PATH}" \ + checkpoint.save="${CHECKPOINT_DIR}" \ + checkpoint.load="${CHECKPOINT_DIR}" \ + checkpoint.load_optim=false \ + checkpoint.save_interval=200 \ + optimizer.lr=5e-6 \ + optimizer.min_lr=5e-6 \ + train.eval_iters=0 \ + scheduler.lr_decay_style=constant \ + scheduler.lr_warmup_iters=0 \ + model.seq_length=2048 \ + dataset.seq_length=2048 \ + train.global_batch_size=2 \ + train.micro_batch_size=1 \ + dataset.global_batch_size=2 \ + dataset.micro_batch_size=1 \ + logger.log_interval=1 \ + logger.wandb_project="wan" \ + logger.wandb_exp_name="${EXP_NAME}" \ + logger.wandb_save_dir="${CHECKPOINT_DIR}" +``` + +### 14B configuration + +```bash +NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=8 examples/diffusion/recipes/wan/pretrain_wan.py \ + --training-mode pretrain \ + model.tensor_model_parallel_size=2 \ + model.pipeline_model_parallel_size=1 \ + model.context_parallel_size=4 \ + model.recompute_granularity=full \ + model.recompute_method=uniform \ + model.recompute_num_layers=1 \ + model.crossattn_emb_size=5120 \ + model.hidden_size=5120 \ + model.ffn_hidden_size=13824 \ + model.num_attention_heads=40 \ + model.num_layers=40 \ + model.qkv_format=thd \ + dataset.path="${DATASET_PATH}" \ + checkpoint.save="${CHECKPOINT_DIR}" \ + checkpoint.load="${CHECKPOINT_DIR}" \ + checkpoint.load_optim=false \ + checkpoint.save_interval=200 \ + optimizer.lr=5e-6 \ + optimizer.min_lr=5e-6 \ + train.eval_iters=0 \ + scheduler.lr_decay_style=constant \ + scheduler.lr_warmup_iters=0 \ + model.seq_length=2048 \ + dataset.seq_length=2048 \ + train.global_batch_size=2 \ + train.micro_batch_size=1 \ + dataset.global_batch_size=2 \ + dataset.micro_batch_size=1 \ + logger.log_interval=1 \ + logger.wandb_project="wan" \ + logger.wandb_exp_name="${EXP_NAME}" \ + logger.wandb_save_dir="${CHECKPOINT_DIR}" +``` + +### Using mock data (optional, for debugging) + +- Using `--mock` argument. +- Adjust `video_size` (F_latents, H_latents, W_latents) and `number_packed_samples` of `WanMockDataModuleConfig` in `wan.py`. Total `seq_len = F * H * W * number_packed_samples`. + +## Inference + +```bash +cd ${DFM_PATH} +export HF_TOKEN= + +T5_DIR="/lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/wan_checkpoints/t5" +VAE_DIR="/lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/wan_checkpoints/vae" +CKPT_DIR="/lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/datasets/shared_checkpoints/megatron_checkpoint_1.3B" + +NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 examples/diffusion/recipes/wan/inference_wan.py \ + --task t2v-1.3B \ + --sizes 480*832 \ + --checkpoint_dir "${CKPT_DIR}" \ + --checkpoint_step 0 \ + --t5_checkpoint_dir "${T5_DIR}" \ + --vae_checkpoint_dir "${VAE_DIR}" \ + --frame_nums 81 \ + --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ + --tensor_parallel_size 1 \ + --context_parallel_size 1 \ + --pipeline_parallel_size 1 \ + --sequence_parallel False \ + --base_seed 42 \ + --sample_steps 50 +``` + +## Notes + +- Replace placeholders (tokens, account, dataset/checkpoint paths) with your own. +- Keep the specified commit hashes for compatibility. +- `NVTE_FUSED_ATTN=1` enables fused attention where supported. diff --git a/examples/diffusion/recipes/wan/conf/gb200_perf_pretrain_mock.yaml b/examples/diffusion/recipes/wan/conf/gb200_perf_pretrain_mock.yaml new file mode 100644 index 0000000000..7b170d363f --- /dev/null +++ b/examples/diffusion/recipes/wan/conf/gb200_perf_pretrain_mock.yaml @@ -0,0 +1,33 @@ +model: + tensor_model_parallel_size: 1 + sequence_parallel: false + pipeline_model_parallel_size: 1 + context_parallel_size: 4 + crossattn_emb_size: 5120 + hidden_size: 5120 + ffn_hidden_size: 13824 + num_attention_heads: 40 + num_layers: 40 + qkv_format: thd + seq_length: 2048 # This is not used + +train: + global_batch_size: 64 + micro_batch_size: 1 + eval_iters: 0 + +scheduler: + lr_decay_style: constant + lr_warmup_iters: 0 + +optimizer: + lr: 5e-6 + min_lr: 5e-6 + +dataset: + seq_length: 2048 # This is not used + global_batch_size: 64 + micro_batch_size: 1 + +logger: + log_interval: 1 diff --git a/examples/diffusion/recipes/wan/conf/gb300_perf_pretrain_mock.yaml b/examples/diffusion/recipes/wan/conf/gb300_perf_pretrain_mock.yaml new file mode 100644 index 0000000000..a35e6238c4 --- /dev/null +++ b/examples/diffusion/recipes/wan/conf/gb300_perf_pretrain_mock.yaml @@ -0,0 +1,33 @@ +model: + tensor_model_parallel_size: 1 + sequence_parallel: false + pipeline_model_parallel_size: 1 + context_parallel_size: 2 + crossattn_emb_size: 5120 + hidden_size: 5120 + ffn_hidden_size: 13824 + num_attention_heads: 40 + num_layers: 40 + qkv_format: thd + seq_length: 2048 # This is not used + +train: + global_batch_size: 64 + micro_batch_size: 1 + eval_iters: 0 + +scheduler: + lr_decay_style: constant + lr_warmup_iters: 0 + +optimizer: + lr: 5e-6 + min_lr: 5e-6 + +dataset: + seq_length: 2048 # This is not used + global_batch_size: 64 + micro_batch_size: 1 + +logger: + log_interval: 1 diff --git a/examples/diffusion/recipes/wan/conf/h100_perf_pretrain_mock.yaml b/examples/diffusion/recipes/wan/conf/h100_perf_pretrain_mock.yaml new file mode 100644 index 0000000000..0013fb32bf --- /dev/null +++ b/examples/diffusion/recipes/wan/conf/h100_perf_pretrain_mock.yaml @@ -0,0 +1,37 @@ +model: + tensor_model_parallel_size: 2 + sequence_parallel: true + pipeline_model_parallel_size: 1 + context_parallel_size: 4 + recompute_granularity: full + recompute_method: block + recompute_num_layers: 8 + crossattn_emb_size: 5120 + hidden_size: 5120 + ffn_hidden_size: 13824 + num_attention_heads: 40 + num_layers: 40 + qkv_format: thd + seq_length: 2048 + +train: + global_batch_size: 128 + micro_batch_size: 1 + eval_iters: 0 + empty_unused_memory_level: 0 + +scheduler: + lr_decay_style: constant + lr_warmup_iters: 0 + +optimizer: + lr: 5e-6 + min_lr: 5e-6 + +dataset: + seq_length: 2048 # This is not used + global_batch_size: 128 + micro_batch_size: 1 + +logger: + log_interval: 1 diff --git a/examples/diffusion/recipes/wan/conf/wan_14B.yaml b/examples/diffusion/recipes/wan/conf/wan_14B.yaml new file mode 100644 index 0000000000..0a1a0149b8 --- /dev/null +++ b/examples/diffusion/recipes/wan/conf/wan_14B.yaml @@ -0,0 +1,42 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Example override file + +# To override a parameter, ensure the structure matches the ConfigContainer +# and its sub-configurations (e.g., model, train, etc.) +# Top-level ConfigContainer fields are dataclasses themselves + +model: + + crossattn_emb_size: 5120 + hidden_size: 5120 + ffn_hidden_size: 13824 + num_attention_heads: 40 + num_layers: 40 + tensor_model_parallel_size: 2 + pipeline_model_parallel_size: 1 + context_parallel_size: 4 + sequence_parallel: true + recompute_granularity: full + recompute_method: uniform + recompute_num_layers: 1 + +train: + global_batch_size: 1 + micro_batch_size: 1 + +dataset: + global_batch_size: 1 + micro_batch_size: 1 diff --git a/examples/diffusion/recipes/wan/conf/wan_1_3B.yaml b/examples/diffusion/recipes/wan/conf/wan_1_3B.yaml new file mode 100644 index 0000000000..89a15d4be9 --- /dev/null +++ b/examples/diffusion/recipes/wan/conf/wan_1_3B.yaml @@ -0,0 +1,38 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Example override file + +# To override a parameter, ensure the structure matches the ConfigContainer +# and its sub-configurations (e.g., model, train, etc.) +# Top-level ConfigContainer fields are dataclasses themselves + +model: + crossattn_emb_size: 1536 + hidden_size: 1536 + ffn_hidden_size: 8960 + num_attention_heads: 12 + num_layers: 30 + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + context_parallel_size: 8 + sequence_parallel: false + +train: + global_batch_size: 2 + micro_batch_size: 1 + +dataset: + global_batch_size: 2 + micro_batch_size: 1 diff --git a/examples/diffusion/recipes/wan/conf/wan_pretrain_override_example.yaml b/examples/diffusion/recipes/wan/conf/wan_pretrain_override_example.yaml new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/diffusion/recipes/wan/conversion/convert_checkpoints.py b/examples/diffusion/recipes/wan/conversion/convert_checkpoints.py new file mode 100644 index 0000000000..fc8dedc6e6 --- /dev/null +++ b/examples/diffusion/recipes/wan/conversion/convert_checkpoints.py @@ -0,0 +1,336 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Megatron-HuggingFace Checkpoint Conversion Example + +This script demonstrates how to convert models between HuggingFace and Megatron formats +using the AutoBridge import_ckpt and export_ckpt methods. + +Usage examples: + # Download the HF checkpoint locally + huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ + --local-dir /root/.cache/huggingface/wan2.1 \ + --local-dir-use-symlinks False + + # Import a HuggingFace model to Megatron format + python examples/diffusion/recipes/wan/conversion/convert_checkpoints.py import \ + --hf-model /root/.cache/huggingface/wan2.1 \ + --megatron-path /workspace/checkpoints/megatron_checkpoints/wan_1_3b + + # Export a Megatron checkpoint to HuggingFace format + python examples/diffusion/recipes/wan/conversion/convert_checkpoints.py export \ + --hf-model /root/.cache/huggingface/wan2.1 \ + --megatron-path /workspace/checkpoints/megatron_checkpoints/wan_1_3b/iter_0000000 \ + --hf-path /workspace/checkpoints/hf_checkpoints/wan_1_3b_hf + + NOTE: The converted checkpoint /workspace/checkpoints/hf_checkpoints/wan_1_3b_hf + only contains the DiT model transformer weights. You still need other components in + the diffusion pipeline (VAE, text encoders, etc) to run inference. To do so, you can + duplicate the original HF checkpoint directory /root/.cache/huggingface/wan2.1 (which + contains VAE, text encoders, etc.), and replace ./transformer with + /workspace/checkpoints/hf_checkpoints/wan_1_3b_hf/transformer. + +""" + +import argparse +import os +import random +import sys +from pathlib import Path +from typing import Optional + +import torch + +from megatron.bridge import AutoBridge +from megatron.bridge.diffusion.conversion.wan.wan_bridge import WanBridge +from megatron.bridge.diffusion.conversion.wan.wan_hf_pretrained import PreTrainedWAN +from megatron.bridge.training.model_load_save import ( + load_megatron_model, + save_megatron_model, + temporary_distributed_context, +) + + +def validate_path(path: str, must_exist: bool = False) -> Path: + """Validate and convert string path to Path object.""" + path_obj = Path(path) + if must_exist and not path_obj.exists(): + raise ValueError(f"Path does not exist: {path}") + return path_obj + + +def get_torch_dtype(dtype_str: str) -> torch.dtype: + """Convert string to torch dtype.""" + dtype_map = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + } + if dtype_str not in dtype_map: + raise ValueError(f"Unsupported dtype: {dtype_str}. Supported: {list(dtype_map.keys())}") + return dtype_map[dtype_str] + + +def import_hf_to_megatron( + hf_model: str, + megatron_path: str, + torch_dtype: Optional[str] = None, + device_map: Optional[str] = None, + trust_remote_code: bool = False, +) -> None: + """ + Import a HuggingFace model and save it as a Megatron checkpoint. + + Args: + hf_model: HuggingFace model ID or path to model directory + megatron_path: Directory path where the Megatron checkpoint will be saved + torch_dtype: Model precision ("float32", "float16", "bfloat16") + device_map: Device placement strategy ("auto", "cuda:0", etc.) + trust_remote_code: Allow custom model code execution + """ + print(f"🔄 Starting import: {hf_model} -> {megatron_path}") + + # Prepare kwargs + kwargs = {} + if torch_dtype: + kwargs["torch_dtype"] = get_torch_dtype(torch_dtype) + print(f" Using torch_dtype: {torch_dtype}") + + if device_map: + kwargs["device_map"] = device_map + print(f" Using device_map: {device_map}") + + if trust_remote_code: + kwargs["trust_remote_code"] = trust_remote_code + print(f" Trust remote code: {trust_remote_code}") + + # Import using the convenience method + print(f"đŸ“Ĩ Loading HuggingFace model: {hf_model}") + try: + AutoBridge.import_ckpt( + hf_model_id=hf_model, + megatron_path=megatron_path, + **kwargs, + ) + except ValueError as e: + # Fallback for Diffusers-based WAN repos that do not provide a transformers config + msg = str(e) + is_wan_repo = ("wan" in hf_model.lower()) or ("diffusers" in hf_model.lower()) + auto_config_failed = ("Unrecognized model" in msg) or ("Failed to load configuration" in msg) + if is_wan_repo or auto_config_failed: + print("â„šī¸ AutoConfig path failed; falling back to WAN Diffusers conversion.") + # Minimal single-rank env to satisfy provider init if needed + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ.setdefault("MASTER_PORT", str(29500 + random.randint(0, 1000))) + os.environ.setdefault("RANK", "0") + os.environ.setdefault("WORLD_SIZE", "1") + os.environ.setdefault("LOCAL_RANK", "0") + + hf = PreTrainedWAN(hf_model) + bridge = WanBridge() + provider = bridge.provider_bridge(hf) + provider.perform_initialization = False + megatron_models = provider.provide_distributed_model(wrap_with_ddp=False, use_cpu_initialization=True) + bridge.load_weights_hf_to_megatron(hf, megatron_models) + save_megatron_model(megatron_models, megatron_path, hf_tokenizer_path=None) + else: + raise + + print(f"✅ Successfully imported model to: {megatron_path}") + + # Verify the checkpoint was created + checkpoint_path = Path(megatron_path) + if checkpoint_path.exists(): + print("📁 Checkpoint structure:") + for item in checkpoint_path.iterdir(): + if item.is_dir(): + print(f" 📂 {item.name}/") + else: + print(f" 📄 {item.name}") + + +def export_megatron_to_hf( + hf_model: str, + megatron_path: str, + hf_path: str, + show_progress: bool = True, + strict: bool = True, +) -> None: + """ + Export a Megatron checkpoint to HuggingFace format. + + Args: + megatron_path: Directory path where the Megatron checkpoint is stored + hf_path: Directory path where the HuggingFace model will be saved + show_progress: Display progress bar during weight export + """ + print(f"🔄 Starting export: {megatron_path} -> {hf_path}") + + # Validate megatron checkpoint exists + checkpoint_path = validate_path(megatron_path, must_exist=True) + print(f"📂 Found Megatron checkpoint: {checkpoint_path}") + + # Look for configuration files to determine the model type + config_files = list(checkpoint_path.glob("**/run_config.yaml")) + if not config_files: + # Look in iter_ subdirectories + iter_dirs = [d for d in checkpoint_path.iterdir() if d.is_dir() and d.name.startswith("iter_")] + if iter_dirs: + # Use the latest iteration + latest_iter = max(iter_dirs, key=lambda d: int(d.name.replace("iter_", ""))) + config_files = list(latest_iter.glob("run_config.yaml")) + + if not config_files: + raise FileNotFoundError( + f"Could not find run_config.yaml in {checkpoint_path}. Please ensure this is a valid Megatron checkpoint." + ) + + print(f"📋 Found configuration: {config_files[0]}") + + # Try generic export first + try: + # For demonstration, we'll create a bridge from a known config + # This would typically be extracted from the checkpoint metadata + bridge = AutoBridge.from_hf_pretrained(hf_model, trust_remote_code=True) + + # Export using the convenience method + print("📤 Exporting to HuggingFace format...") + bridge.export_ckpt( + megatron_path=megatron_path, + hf_path=hf_path, + show_progress=show_progress, + ) + except ValueError as e: + # Fallback for Diffusers-based WAN repos that do not provide a transformers config + msg = str(e) + is_wan_repo = ("wan" in hf_model.lower()) or ("diffusers" in hf_model.lower()) + auto_config_failed = ("Unrecognized model" in msg) or ("Failed to load configuration" in msg) + if is_wan_repo or auto_config_failed: + print("â„šī¸ AutoConfig path failed; falling back to WAN Diffusers export.") + # Minimal single-process distributed context on CPU for loading Megatron ckpt + with temporary_distributed_context(backend="gloo"): + # Resolve latest iter_* directory (use the config file we found) + checkpoint_iter_dir = config_files[0].parent + # 1) Load Megatron model from checkpoint + megatron_models = load_megatron_model( + str(checkpoint_iter_dir), use_cpu_init=True, skip_temp_dist_context=True + ) + if not isinstance(megatron_models, list): + megatron_models = [megatron_models] + + # 2) Prepare HF WAN wrapper for state/metadata and save artifacts + hf = PreTrainedWAN(hf_model) + Path(hf_path).mkdir(parents=True, exist_ok=True) + # Some diffusers configs are FrozenDict and don't support save_pretrained; skip quietly + try: + hf.save_artifacts(hf_path) + except Exception: + pass + + # 3) Stream-export weights Megatron -> HF safetensors via WAN bridge + bridge = WanBridge() + generator = bridge.stream_weights_megatron_to_hf( + megatron_models, hf, cpu=True, show_progress=show_progress + ) + # 4) Save streamed weights into hf_path + hf.state.source.save_generator(generator, hf_path) + else: + raise + + print(f"✅ Successfully exported model to: {hf_path}") + + # Verify the export was created + export_path = Path(hf_path) + if export_path.exists(): + print("📁 Export structure:") + for item in export_path.iterdir(): + if item.is_dir(): + print(f" 📂 {item.name}/") + else: + print(f" 📄 {item.name}") + + print("🔍 You can now load this model with:") + print(" from transformers import AutoModelForCausalLM") + print(f" model = AutoModelForCausalLM.from_pretrained('{hf_path}')") + + +def main(): + """Main function to handle command line arguments and execute conversions.""" + parser = argparse.ArgumentParser( + description="Convert models between HuggingFace and Megatron formats", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + + subparsers = parser.add_subparsers(dest="command", help="Conversion direction") + + # Import subcommand (HF -> Megatron) + import_parser = subparsers.add_parser("import", help="Import HuggingFace model to Megatron checkpoint format") + import_parser.add_argument("--hf-model", required=True, help="HuggingFace model ID or path to model directory") + import_parser.add_argument( + "--megatron-path", required=True, help="Directory path where the Megatron checkpoint will be saved" + ) + import_parser.add_argument("--torch-dtype", choices=["float32", "float16", "bfloat16"], help="Model precision") + import_parser.add_argument("--device-map", help='Device placement strategy (e.g., "auto", "cuda:0")') + import_parser.add_argument("--trust-remote-code", action="store_true", help="Allow custom model code execution") + + # Export subcommand (Megatron -> HF) + export_parser = subparsers.add_parser("export", help="Export Megatron checkpoint to HuggingFace format") + export_parser.add_argument("--hf-model", required=True, help="HuggingFace model ID or path to model directory") + export_parser.add_argument( + "--megatron-path", required=True, help="Directory path where the Megatron checkpoint is stored" + ) + export_parser.add_argument( + "--hf-path", required=True, help="Directory path where the HuggingFace model will be saved" + ) + export_parser.add_argument("--no-progress", action="store_true", help="Disable progress bar during export") + export_parser.add_argument( + "--not-strict", action="store_true", help="Allow source and target checkpoint to have different keys" + ) + + args = parser.parse_args() + + if not args.command: + parser.print_help() + return 1 + + if args.command == "import": + import_hf_to_megatron( + hf_model=args.hf_model, + megatron_path=args.megatron_path, + torch_dtype=args.torch_dtype, + device_map=args.device_map, + trust_remote_code=args.trust_remote_code, + ) + + elif args.command == "export": + export_megatron_to_hf( + hf_model=args.hf_model, + megatron_path=args.megatron_path, + hf_path=args.hf_path, + show_progress=not args.no_progress, + strict=not args.not_strict, + ) + else: + raise RuntimeError(f"Unknown command: {args.command}") + + if torch.distributed.is_initialized(): + torch.distributed.barrier() + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/examples/diffusion/recipes/wan/inference_wan.py b/examples/diffusion/recipes/wan/inference_wan.py new file mode 100644 index 0000000000..e9f870de82 --- /dev/null +++ b/examples/diffusion/recipes/wan/inference_wan.py @@ -0,0 +1,294 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import os +import sys +import warnings +from datetime import datetime + +from easydict import EasyDict + + +warnings.filterwarnings("ignore") + +import random + +import torch +import torch.distributed as dist + +from megatron.bridge.diffusion.models.wan.flow_matching.flow_inference_pipeline import FlowInferencePipeline +from megatron.bridge.diffusion.models.wan.inference import SIZE_CONFIGS, SUPPORTED_SIZES +from megatron.bridge.diffusion.models.wan.inference.utils import cache_video, str2bool + + +EXAMPLE_PROMPT = { + "t2v-1.3B": { + "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + }, + "t2v-14B": { + "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + }, +} + + +def _validate_args(args): + # Basic check + assert args.task in SUPPORTED_SIZES, f"Unsupport task: {args.task}" + assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}" + + # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks. + if args.sample_steps is None: + args.sample_steps = 50 + + if args.sample_shift is None: + args.sample_shift = 5.0 + + # Frames default handled later; no single frame arg anymore + + args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(0, sys.maxsize) + # Size check: only validate provided --sizes; default handled later + if args.sizes is not None and len(args.sizes) > 0: + for s in args.sizes: + assert s in SUPPORTED_SIZES[args.task], ( + f"Unsupport size {s} for task {args.task}, supported sizes are: " + f"{', '.join(SUPPORTED_SIZES[args.task])}" + ) + + +def _parse_args(): + parser = argparse.ArgumentParser(description="Generate a image or video from a text prompt or image using Wan") + parser.add_argument( + "--task", type=str, default="t2v-14B", choices=list(SUPPORTED_SIZES.keys()), help="The task to run." + ) + parser.add_argument( + "--sizes", + type=str, + nargs="+", + default=None, + choices=list(SIZE_CONFIGS.keys()), + help="A list of sizes to generate multiple images or videos (WIDTH*HEIGHT). Example: --sizes 1280*720 1920*1080", + ) + parser.add_argument( + "--frame_nums", + type=int, + nargs="+", + default=None, + help="List of frame counts (each should be 4n+1). Broadcasts if single value.", + ) + parser.add_argument( + "--checkpoint_dir", + type=str, + default=None, + help="The path to the main WAN checkpoint directory.", + ) + parser.add_argument( + "--checkpoint_step", + type=int, + default=None, + help=( + "Optional training step to load, e.g. 1800 -> iter_0001800. " + "If not provided, the latest (largest) step in --checkpoint_dir is used.", + ), + ) + parser.add_argument( + "--t5_checkpoint_dir", type=str, default=None, help="Optional directory containing T5 checkpoint/tokenizer" + ) + parser.add_argument( + "--vae_checkpoint_dir", type=str, default=None, help="Optional directory containing VAE checkpoint" + ) + parser.add_argument( + "--offload_model", + type=str2bool, + default=None, + help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage.", + ) + parser.add_argument( + "--t5_cpu", + action="store_true", + default=False, + help="Whether to place T5 model on CPU.", + ) + parser.add_argument( + "--save_file", type=str, default=None, help="The file to save the generated image or video to." + ) + parser.add_argument( + "--prompts", + type=str, + nargs="+", + default=None, + help="A list of prompts to generate multiple images or videos. Example: --prompts 'a cat' 'a dog'", + ) + parser.add_argument("--base_seed", type=int, default=-1, help="The seed to use for generating the image or video.") + parser.add_argument("--sample_steps", type=int, default=None, help="The sampling steps.") + parser.add_argument( + "--sample_shift", type=float, default=None, help="Sampling shift factor for flow matching schedulers." + ) + parser.add_argument("--sample_guide_scale", type=float, default=5.0, help="Classifier free guidance scale.") + parser.add_argument("--tensor_parallel_size", type=int, default=1, help="Tensor parallel size.") + parser.add_argument("--context_parallel_size", type=int, default=1, help="Context parallel size.") + parser.add_argument("--pipeline_parallel_size", type=int, default=1, help="Pipeline parallel size.") + parser.add_argument("--sequence_parallel", type=str2bool, default=False, help="Sequence parallel.") + + args = parser.parse_args() + + _validate_args(args) + + return args + + +def _init_logging(rank): + # logging + if rank == 0: + # set format + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] %(levelname)s: %(message)s", + handlers=[logging.StreamHandler(stream=sys.stdout)], + ) + else: + logging.basicConfig(level=logging.ERROR) + + +def generate(args): # noqa: D103 + rank = int(os.getenv("RANK", 0)) + world_size = int(os.getenv("WORLD_SIZE", 1)) + local_rank = int(os.getenv("LOCAL_RANK", 0)) + device = local_rank + _init_logging(rank) + videos = [] + + if args.offload_model is None: + logging.info(f"offload_model is not specified, set to {args.offload_model}.") + if world_size > 1: + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="nccl", init_method="env://", rank=rank, world_size=world_size) + + inference_cfg = EasyDict( + { + # t5 + "t5_dtype": torch.bfloat16, + "text_len": 512, + # vae + "vae_stride": (4, 8, 8), + # transformer + "param_dtype": torch.bfloat16, + "patch_size": (1, 2, 2), + # others + "num_train_timesteps": 1000, + "sample_fps": 16, + "chinese_sample_neg_prompt": "č‰˛č°ƒč‰ŗä¸ŊīŧŒčŋ‡æ›īŧŒé™æ€īŧŒįģ†čŠ‚æ¨ĄįŗŠä¸æ¸…īŧŒå­—åš•īŧŒéŖŽæ ŧīŧŒäŊœå“īŧŒį”ģäŊœīŧŒį”ģéĸīŧŒé™æ­ĸīŧŒæ•´äŊ“å‘į°īŧŒæœ€åˇŽč´¨é‡īŧŒäŊŽč´¨é‡īŧŒJPEG压įŧпދᕙīŧŒä¸‘é™‹įš„īŧŒæŽ‹įŧēįš„īŧŒå¤šäŊ™įš„æ‰‹æŒ‡īŧŒį”ģ垗不åĨŊįš„æ‰‹éƒ¨īŧŒį”ģ垗不åĨŊįš„č„¸éƒ¨īŧŒį•¸åŊĸįš„īŧŒæ¯åŽšįš„īŧŒåŊĸæ€į•¸åŊĸįš„č‚ĸäŊ“īŧŒæ‰‹æŒ‡čžåˆīŧŒé™æ­ĸä¸åŠ¨įš„į”ģéĸīŧŒæ‚äšąįš„čƒŒæ™¯īŧŒä¸‰æĄč…ŋīŧŒčƒŒæ™¯äēē垈多īŧŒå€’į€čĩ°", + "english_sample_neg_prompt": "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards", + } + ) + + logging.info(f"Generation job args: {args}") + logging.info(f"Generation model config: {inference_cfg}") + + if dist.is_initialized(): + base_seed = [args.base_seed] if rank == 0 else [None] + dist.broadcast_object_list(base_seed, src=0) + args.base_seed = base_seed[0] + + if "t2v" in args.task: + # Resolve prompts list (default to example prompt) + if args.prompts is not None and len(args.prompts) > 0: + prompts = args.prompts + else: + prompts = [EXAMPLE_PROMPT[args.task]["prompt"]] + + # Resolve sizes list (default to first supported size for task) + if args.sizes is not None and len(args.sizes) > 0: + size_keys = args.sizes + else: + size_keys = [SUPPORTED_SIZES[args.task][0]] + + # Resolve frame counts list (default 81) + if args.frame_nums is not None and len(args.frame_nums) > 0: + frame_nums = args.frame_nums + else: + frame_nums = [81] + + # Enforce 1:1 pairing across lists + assert len(prompts) == len(size_keys) == len(frame_nums), ( + f"prompts ({len(prompts)}), sizes ({len(size_keys)}), and frame_nums ({len(frame_nums)}) " + f"must have the same length" + ) + + logging.info("Creating flow inference pipeline.") + pipeline = FlowInferencePipeline( + inference_cfg=inference_cfg, + checkpoint_dir=args.checkpoint_dir, + model_id="Wan-AI/Wan2.1-T2V-14B-Diffusers", + checkpoint_step=args.checkpoint_step, + t5_checkpoint_dir=args.t5_checkpoint_dir, + vae_checkpoint_dir=args.vae_checkpoint_dir, + device_id=device, + rank=rank, + t5_cpu=args.t5_cpu, + tensor_parallel_size=args.tensor_parallel_size, + context_parallel_size=args.context_parallel_size, + pipeline_parallel_size=args.pipeline_parallel_size, + sequence_parallel=args.sequence_parallel, + pipeline_dtype=torch.float32, + ) + + rank = dist.get_rank() + if rank == 0: + print("Running inference with tensor_parallel_size:", args.tensor_parallel_size) + print("Running inference with context_parallel_size:", args.context_parallel_size) + print("Running inference with pipeline_parallel_size:", args.pipeline_parallel_size) + print("Running inference with sequence_parallel:", args.sequence_parallel) + print("\n\n\n") + + logging.info("Generating videos ...") + videos = pipeline.generate( + prompts=prompts, + sizes=[SIZE_CONFIGS[size] for size in size_keys], + frame_nums=frame_nums, + shift=args.sample_shift, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model, + ) + + if rank == 0: + for i, video in enumerate(videos): + formatted_experiment_name = (args.save_file) if args.save_file is not None else "DefaultExp" + formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S") + formatted_prompt = prompts[i].replace(" ", "_").replace("/", "_")[:50] + suffix = ".mp4" + formatted_save_file = ( + f"{args.task}_{formatted_experiment_name}_videoindex{int(i)}_size{size_keys[i].replace('*', 'x') if sys.platform == 'win32' else size_keys[i]}_{formatted_prompt}_{formatted_time}" + + suffix + ) + + logging.info(f"Saving generated video to {formatted_save_file}") + cache_video( + tensor=video[None], + save_file=formatted_save_file, + fps=inference_cfg.sample_fps, + nrow=1, + normalize=True, + value_range=(-1, 1), + ) + logging.info("Finished.") + + +if __name__ == "__main__": + args = _parse_args() + generate(args) diff --git a/examples/diffusion/recipes/wan/pretrain_wan.py b/examples/diffusion/recipes/wan/pretrain_wan.py new file mode 100644 index 0000000000..49bad5e737 --- /dev/null +++ b/examples/diffusion/recipes/wan/pretrain_wan.py @@ -0,0 +1,205 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Wan Pretraining Script with YAML and CLI Configuration Overrides. + +This script provides a flexible way to pretrain Wan models using Megatron-Bridge with support for +both YAML configuration files and command-line overrides using Hydra-style syntax. + +Examples: + Basic usage with default configuration: + $ torchrun --nproc_per_node=8 pretrain_wan.py + + Using a custom YAML config file: + $ torchrun --nproc_per_node=8 pretrain_wan.py --config-file my_custom_config.yaml + + Using CLI overrides only: + $ torchrun --nproc_per_node=8 pretrain_wan.py model.tensor_model_parallel_size=4 train.train_iters=100000 + + Combining YAML and CLI overrides (CLI takes precedence): + $ torchrun --nproc_per_node=8 pretrain_wan.py --config-file conf/my_config.yaml \ + model.pipeline_dtype=torch.float16 \ + train.global_batch_size=512 + +Configuration Precedence: + 1. Base configuration from pretrain_config() recipe + 2. YAML overrides from --config-file (if provided) + 3. CLI overrides (highest precedence) + +Supported Override Syntax: + - Standard assignment: key=value + - Nested assignment: section.subsection.key=value + - Addition: +new_key=value + - Deletion: ~key_to_remove + - Type conversion: Automatic for basic types (int, float, bool, str) + - Complex types: torch.dtype, enums, etc. are supported +""" + +import argparse +import logging +import os +import sys +from pathlib import Path +from typing import Tuple + +from omegaconf import OmegaConf + +from megatron.bridge.diffusion.models.wan.wan_step import WanForwardStep +from megatron.bridge.diffusion.recipes.wan.wan import pretrain_config +from megatron.bridge.training.config import ConfigContainer +from megatron.bridge.training.pretrain import pretrain +from megatron.bridge.training.utils.omegaconf_utils import ( + apply_overrides, + create_omegaconf_dict_config, + parse_hydra_overrides, +) +from megatron.bridge.utils.common_utils import get_rank_safe + + +logger: logging.Logger = logging.getLogger(__name__) + + +# Define paths relative to this script's location +# Assumes this script (pretrain_wan.py) is in Megatron-Bridge/examples/recipes/wan/ +# and the config is in a 'conf' subdirectory. +SCRIPT_DIR: Path = Path(__file__).parent.resolve() +DEFAULT_CONFIG_FILENAME: str = "wan_pretrain_override_example.yaml" +DEFAULT_CONFIG_FILE_PATH: Path = SCRIPT_DIR / "conf" / DEFAULT_CONFIG_FILENAME + + +def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]: + """Parse command line arguments, separating known script args from OmegaConf overrides.""" + parser = argparse.ArgumentParser( + description="Pretrain Wan model using Megatron-Bridge with YAML and CLI overrides", + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument("--mock", action="store_true", help="Whether to use mock data.") + parser.add_argument( + "--training-mode", + choices=["pretrain", "finetune"], + default="finetune", + help="Set training mode, 'pretrain' or 'finetune'.", + ) + parser.add_argument( + "--config-file", + type=str, + default=str(DEFAULT_CONFIG_FILE_PATH), + help="Path to the YAML OmegaConf override file. Default: conf/wan_pretrain_override_example.yaml", + ) + parser.add_argument("--debug", action="store_true", help="Enable debug logging") + + # Parse known args for the script, remaining will be treated as overrides + args, cli_dotlist_overrides = parser.parse_known_args() + return args, cli_dotlist_overrides + + +def main() -> None: + """ + Entry point for the Wan pretraining script. + + This function orchestrates the complete configuration workflow: + 1. Loads the base configuration from pretrain_config() recipe + 2. Applies YAML overrides from --config-file (if exists) + 3. Applies CLI overrides using Hydra-style syntax + 4. Starts Megatron pretraining with the final merged configuration + + Configuration merging preserves callable fields (like activation functions) + and handles type conversions automatically. + + Examples of CLI usage: + # Use default config with custom learning rate + torchrun --nproc_per_node=8 pretrain_wan.py optimizer.lr=0.0002 + + # Custom config file with additional overrides + torchrun --nproc_per_node=8 pretrain_wan.py --config-file my_config.yaml train.train_iters=50000 + + # Multiple overrides for distributed training + torchrun --nproc_per_node=8 pretrain_wan.py \ + model.tensor_model_parallel_size=4 \ + model.pipeline_model_parallel_size=2 \ + train.global_batch_size=512 + """ + args, cli_overrides = parse_cli_args() + + logger.info("Megatron-Bridge Wan Pretraining Script with YAML & CLI Overrides") + logger.info("------------------------------------------------------------------") + + # Load base configuration from the recipe as a Python dataclass + cfg: ConfigContainer = pretrain_config(mock=args.mock) + logger.info("Loaded base configuration") + + # Print configuration on rank 0 + if get_rank_safe() == 0: + cfg.print_yaml() + + # Convert the initial Python dataclass to an OmegaConf DictConfig for merging + merged_omega_conf, excluded_fields = create_omegaconf_dict_config(cfg) + + # Load and merge YAML overrides if a config file is provided + if args.config_file: + logger.debug(f"Loading YAML overrides from: {args.config_file}") + if not os.path.exists(args.config_file): + logger.error(f"Override YAML file not found: {args.config_file}") + sys.exit(1) + yaml_overrides_omega = OmegaConf.load(args.config_file) + merged_omega_conf = OmegaConf.merge(merged_omega_conf, yaml_overrides_omega) + logger.debug("YAML overrides merged successfully.") + + # Apply command-line overrides using Hydra-style parsing + if cli_overrides: + logger.debug(f"Applying Hydra-style command-line overrides: {cli_overrides}") + merged_omega_conf = parse_hydra_overrides(merged_omega_conf, cli_overrides) + logger.debug("Hydra-style command-line overrides applied successfully.") + + # Apply the final merged OmegaConf configuration back to the original ConfigContainer + logger.debug("Applying final merged configuration back to Python ConfigContainer...") + final_overrides_as_dict = OmegaConf.to_container(merged_omega_conf, resolve=True) + # Apply overrides while preserving excluded fields + apply_overrides(cfg, final_overrides_as_dict, excluded_fields) + + # Config FlowPipeline based on training mode + if args.training_mode == "pretrain": + wan_forward_step = WanForwardStep( + timestep_sampling="logit_normal", + logit_std=1.5, + flow_shift=2.5, + mix_uniform_ratio=0.2, + sigma_min=0.0, + sigma_max=1.0, + ) + elif args.training_mode == "finetune": + wan_forward_step = WanForwardStep( + timestep_sampling="uniform", + logit_std=1.0, + flow_shift=3.0, + mix_uniform_ratio=0.1, + sigma_min=0.0, + sigma_max=1.0, + ) + + # Display final configuration + if get_rank_safe() == 0: + logger.info("--- Final Merged Configuration ---") + cfg.print_yaml() + logger.info("----------------------------------") + + # Start training + logger.debug("Starting pretraining...") + pretrain(config=cfg, forward_step_func=wan_forward_step) + + +if __name__ == "__main__": + main() diff --git a/src/megatron/bridge/diffusion/README.md b/src/megatron/bridge/diffusion/README.md new file mode 100644 index 0000000000..f6806ea2c4 --- /dev/null +++ b/src/megatron/bridge/diffusion/README.md @@ -0,0 +1,40 @@ +# Megatron-Bridge Diffusion + +Diffusion Foundation Models (DFM) integrated into Megatron-Bridge. This module provides +Megatron-based implementations of diffusion models including DiT, FLUX, and WAN. + +## Directory Structure + +``` +diffusion/ +├── models/ # Model implementations (architecture, layers, forward steps) +│ ├── common/ # Shared modules (attention, embeddings, normalization) +│ ├── dit/ # DiT model (with EDM pipeline) +│ ├── flux/ # FLUX model (MMDiT, flow matching) +│ └── wan/ # WAN model (video generation, flow matching, inference) +├── conversion/ # HF ↔ Megatron checkpoint conversion bridges +│ ├── flux/ # FLUX bridge and HF pretrained adapter +│ └── wan/ # WAN bridge and HF pretrained adapter +├── data/ # Data loading and task encoders +│ ├── common/ # Shared data modules (energon, diffusion samples, sequence packing) +│ ├── dit/ # DiT task encoder and mock data +│ ├── flux/ # FLUX task encoder and data modules +│ └── wan/ # WAN task encoder and data modules +├── recipes/ # Training recipe configurations +│ ├── dit/ # DiT pretraining recipe +│ ├── flux/ # FLUX pretraining recipe +│ └── wan/ # WAN pretraining recipe +├── common/ # Shared utilities (video saving, tokenizers, batch ops) +└── base/ # Base module placeholder +``` + +## Supported Models + +- **DiT**: Diffusion Transformer with EDM (Elucidating Diffusion Models) pipeline +- **FLUX**: State-of-the-art text-to-image model using MMDiT-style transformer blocks +- **WAN**: Video generation model with 3D rotary embeddings and flow matching + +## Examples + +Training examples and configuration files are in `examples/diffusion/`. + diff --git a/src/megatron/bridge/diffusion/__init__.py b/src/megatron/bridge/diffusion/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/megatron/bridge/diffusion/base/README.md b/src/megatron/bridge/diffusion/base/README.md new file mode 100644 index 0000000000..670b57726e --- /dev/null +++ b/src/megatron/bridge/diffusion/base/README.md @@ -0,0 +1,3 @@ +# Base + +Base classes and interfaces for Megatron models. diff --git a/src/megatron/bridge/diffusion/base/__init__.py b/src/megatron/bridge/diffusion/base/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/megatron/bridge/diffusion/common/__init__.py b/src/megatron/bridge/diffusion/common/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/megatron/bridge/diffusion/common/tokenizers/__init__.py b/src/megatron/bridge/diffusion/common/tokenizers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/megatron/bridge/diffusion/common/utils/__init__.py b/src/megatron/bridge/diffusion/common/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/megatron/bridge/diffusion/common/utils/batch_ops.py b/src/megatron/bridge/diffusion/common/utils/batch_ops.py new file mode 100644 index 0000000000..956dfbee36 --- /dev/null +++ b/src/megatron/bridge/diffusion/common/utils/batch_ops.py @@ -0,0 +1,104 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torch import Tensor + + +def common_broadcast(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]: + """ + Broadcasts two tensors to have the same shape by adding singleton dimensions where necessary. + + Args: + x (Tensor): The first input tensor. + y (Tensor): The second input tensor. + + Returns: + tuple[Tensor, Tensor]: A tuple containing the two tensors with broadcasted shapes. + + Raises: + AssertionError: If the dimensions of the tensors do not match at any axis within their common dimensions. + """ + ndims1 = x.ndim + ndims2 = y.ndim + + common_ndims = min(ndims1, ndims2) + for axis in range(common_ndims): + assert x.shape[axis] == y.shape[axis], "Dimensions not equal at axis {}".format(axis) + + if ndims1 < ndims2: + x = x.reshape(x.shape + (1,) * (ndims2 - ndims1)) + elif ndims2 < ndims1: + y = y.reshape(y.shape + (1,) * (ndims1 - ndims2)) + + return x, y + + +def batch_add(x: Tensor, y: Tensor) -> Tensor: + """ + Adds two tensors element-wise after broadcasting them to a common shape. + + Args: + x (Tensor): The first input tensor. + y (Tensor): The second input tensor. + + Returns: + Tensor: The element-wise sum of the input tensors after broadcasting. + """ + x, y = common_broadcast(x, y) + return x + y + + +def batch_mul(x: Tensor, y: Tensor) -> Tensor: + """ + Multiplies two tensors element-wise after broadcasting them to a common shape. + + Args: + x (Tensor): The first input tensor. + y (Tensor): The second input tensor. + + Returns: + Tensor: The element-wise product of the input tensors after broadcasting. + """ + x, y = common_broadcast(x, y) + return x * y + + +def batch_sub(x: Tensor, y: Tensor) -> Tensor: + """ + Subtracts two tensors element-wise after broadcasting them to a common shape. + + Args: + x (Tensor): The first input tensor. + y (Tensor): The second input tensor. + + Returns: + Tensor: The result of element-wise subtraction of the input tensors. + """ + x, y = common_broadcast(x, y) + return x - y + + +def batch_div(x: Tensor, y: Tensor) -> Tensor: + """ + Divides two tensors element-wise after broadcasting them to a common shape. + + Args: + x (Tensor): The first input tensor. + y (Tensor): The second input tensor. + + Returns: + Tensor: The result of element-wise division of `x` by `y` after broadcasting. + """ + x, y = common_broadcast(x, y) + return x / y diff --git a/src/megatron/bridge/diffusion/common/utils/dynamic_import.py b/src/megatron/bridge/diffusion/common/utils/dynamic_import.py new file mode 100644 index 0000000000..dfe4c6dd11 --- /dev/null +++ b/src/megatron/bridge/diffusion/common/utils/dynamic_import.py @@ -0,0 +1,47 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib + + +def dynamic_import(full_path): + """ + Dynamically import a class or function from a given full path. + + :param full_path: The full path to the class or function (e.g., "package.module.ClassName") + :return: The imported class or function + :raises ImportError: If the module or attribute cannot be imported + :raises AttributeError: If the attribute does not exist in the module + """ + try: + # Split the full path into module path and attribute name + module_path, attribute_name = full_path.rsplit(".", 1) + except ValueError as e: + raise ImportError( + f"Invalid full path '{full_path}'. It should contain both module and attribute names." + ) from e + + # Import the module + try: + module = importlib.import_module(module_path) + except ImportError as e: + raise ImportError(f"Cannot import module '{module_path}'.") from e + + # Retrieve the attribute from the module + try: + attribute = getattr(module, attribute_name) + except AttributeError as e: + raise AttributeError(f"Module '{module_path}' does not have an attribute '{attribute_name}'.") from e + + return attribute diff --git a/src/megatron/bridge/diffusion/common/utils/save_video.py b/src/megatron/bridge/diffusion/common/utils/save_video.py new file mode 100644 index 0000000000..48c4f96bb1 --- /dev/null +++ b/src/megatron/bridge/diffusion/common/utils/save_video.py @@ -0,0 +1,47 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import imageio +import numpy as np + + +def save_video( # noqa: D103 + grid: np.ndarray, + fps: int, + H: int, + W: int, + video_save_quality: int, + video_save_path: str, + caption: str = None, +): + ffmpeg_params = ["-s", f"{W}x{H}"] + + # Add caption as metadata if provided + if caption is not None: + ffmpeg_params.extend( + [ + "-metadata", + f"description={caption}", + ] + ) + + kwargs = { + "fps": fps, + "quality": video_save_quality, + "macro_block_size": 1, + "ffmpeg_params": ffmpeg_params, + "output_params": ["-f", "mp4"], + } + + imageio.mimsave(video_save_path, grid, "mp4", **kwargs) diff --git a/src/megatron/bridge/diffusion/common/utils/torch_split_tensor_for_cp.py b/src/megatron/bridge/diffusion/common/utils/torch_split_tensor_for_cp.py new file mode 100644 index 0000000000..6d72050bbe --- /dev/null +++ b/src/megatron/bridge/diffusion/common/utils/torch_split_tensor_for_cp.py @@ -0,0 +1,71 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +import transformer_engine_torch as tex +from torch import Tensor +from torch.distributed import ProcessGroup, all_gather, get_rank, get_world_size + + +def cat_outputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup, thd_cu_seqlens: Optional[Tensor] = None) -> Tensor: + """ + Concatenates tensors from multiple processes along a specified dimension. + + This function gathers tensors from all processes in the given process group + and concatenates them along the specified dimension. + + Args: + x (Tensor): The input tensor to be gathered and concatenated. + seq_dim (int): The dimension along which to concatenate the gathered tensors. + cp_group (ProcessGroup): The process group containing all the processes involved in the gathering. + thd_cu_seqlens (Tensor, optional): THD cumulative sequence lengths used during partitioning. Provide + this to restore the original token order after gathering. + + Returns: + Tensor: A tensor resulting from the concatenation of tensors from all processes. If `thd_cu_seqlens` + is provided, the tensor is reordered to match the original (pre-partition) sequence order. + + Raises: + RuntimeError: If the gathering of tensors fails. + """ + # Number of processes in the group + world_size = get_world_size(cp_group) + # List to hold tensors from each rank + gathered_tensors = [torch.zeros_like(x) for _ in range(world_size)] + + # Attempt to gather tensors from all ranks + all_gather(gathered_tensors, x, group=cp_group) + + # Concatenate tensors along the specified dimension + gathered = torch.cat(gathered_tensors, dim=seq_dim) + total_seq_len = int(thd_cu_seqlens[-1].item()) + # Rebuild the global index ordering used during THD partitioning. + cp_rank = get_rank(cp_group) + local_indices = tex.thd_get_partitioned_indices(thd_cu_seqlens, total_seq_len, world_size, cp_rank).to( + device=x.device, dtype=torch.long + ) + + # Gather indices from all ranks to compute the inverse permutation. + gathered_indices = [torch.empty_like(local_indices) for _ in range(world_size)] + all_gather(gathered_indices, local_indices, group=cp_group) + global_indices = torch.cat(gathered_indices, dim=0) + + if global_indices.numel() != gathered.size(seq_dim): + raise RuntimeError("Gathered indices size does not match gathered tensor along sequence dimension.") + + restore_order = torch.argsort(global_indices, dim=0) + gathered = gathered.index_select(seq_dim, restore_order.to(device=gathered.device)) + return gathered.contiguous() diff --git a/src/megatron/bridge/diffusion/conversion/__init__.py b/src/megatron/bridge/diffusion/conversion/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/megatron/bridge/diffusion/conversion/flux/__init__.py b/src/megatron/bridge/diffusion/conversion/flux/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/megatron/bridge/diffusion/conversion/flux/flux_bridge.py b/src/megatron/bridge/diffusion/conversion/flux/flux_bridge.py new file mode 100644 index 0000000000..b517883859 --- /dev/null +++ b/src/megatron/bridge/diffusion/conversion/flux/flux_bridge.py @@ -0,0 +1,237 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Mapping + +import torch +from diffusers import FluxTransformer2DModel + +from megatron.bridge.diffusion.conversion.flux.flux_hf_pretrained import PreTrainedFlux +from megatron.bridge.diffusion.models.flux.flux_model import Flux +from megatron.bridge.diffusion.models.flux.flux_provider import FluxProvider +from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.bridge.models.conversion.param_mapping import ( + AutoMapping, + QKVMapping, + RowParallelMapping, +) + + +@MegatronModelBridge.register_bridge(source=FluxTransformer2DModel, target=Flux) +class FluxBridge(MegatronModelBridge): + """ + Megatron Bridge for FLUX model. + + As a user you would not use this bridge directly, but through `AutoBridge`. + + Example: + >>> from megatron.bridge import AutoBridge + >>> bridge = AutoBridge.from_hf_pretrained("black-forest-labs/FLUX.1-dev") + >>> provider = bridge.to_megatron_provider() + """ + + def provider_bridge(self, hf_pretrained: PreTrainedFlux) -> FluxProvider: + hf_config = hf_pretrained.config + + provider = FluxProvider( + in_channels=hf_config.in_channels, + patch_size=hf_config.patch_size, + num_joint_layers=hf_config.num_layers, + num_single_layers=hf_config.num_single_layers, + num_attention_heads=hf_config.num_attention_heads, + # out_channels: None + # joint_attention_dim: 4096 + kv_channels=hf_config.attention_head_dim, + num_query_groups=hf_config.num_attention_heads, + vec_in_dim=hf_config.pooled_projection_dim, + guidance_embed=hf_config.guidance_embeds, + axes_dims_rope=hf_config.axes_dims_rope, + bf16=False, + params_dtype=torch.float32, + ) + self.hidden_size = provider.hidden_size + return provider + + def maybe_modify_loaded_hf_weight( + self, hf_param: str | dict[str, str], hf_state_dict: Mapping[str, torch.Tensor] + ) -> torch.Tensor: + """Load weights from HuggingFace state dict. + This function can be overridden by subclasses to preprocess the HF weights before conversion, such as renaming + certain parameters to avoid mapping conflicts, or dequantize the weights. + + Note that loading is done lazily before this function is called, so the weights are actually loaded in + this function when hf_state_dict.__getitem__ is called. + + Args: + hf_param: The parameter name or dictionary of parameter names to load. + hf_state_dict: The HuggingFace state dictionary. + + Returns: + The loaded weights. + """ + if isinstance(hf_param, str): + if hf_param.endswith("weight_1"): + hf_weights = hf_state_dict[hf_param.replace("weight_1", "weight")] + hf_weights = hf_weights[:, self.hidden_size :] + elif hf_param.endswith("weight_2"): + hf_weights = hf_state_dict[hf_param.replace("weight_2", "weight")] + hf_weights = hf_weights[:, : self.hidden_size] + else: + hf_weights = hf_state_dict[hf_param] + else: + hf_weights = {k: hf_state_dict[v] for k, v in hf_param.items()} + return hf_weights + + def mapping_registry(self) -> MegatronMappingRegistry: + """Return MegatronMappingRegistry containing parameter mappings from HF to Megatron format. + + Returns: + MegatronMappingRegistry: Registry of parameter mappings + """ + # Dictionary maps HF parameter names -> Megatron parameter names + # Supports wildcard (*) patterns for layer-specific parameters + param_mappings = { + "norm_out.linear.bias": "norm_out.adaLN_modulation.1.bias", + "norm_out.linear.weight": "norm_out.adaLN_modulation.1.weight", + "proj_out.bias": "proj_out.bias", + "proj_out.weight": "proj_out.weight", + "time_text_embed.guidance_embedder.linear_1.bias": "guidance_embedding.in_layer.bias", + "time_text_embed.guidance_embedder.linear_1.weight": "guidance_embedding.in_layer.weight", + "time_text_embed.guidance_embedder.linear_2.bias": "guidance_embedding.out_layer.bias", + "time_text_embed.guidance_embedder.linear_2.weight": "guidance_embedding.out_layer.weight", + "x_embedder.bias": "img_embed.bias", + "x_embedder.weight": "img_embed.weight", + "context_embedder.bias": "txt_embed.bias", + "context_embedder.weight": "txt_embed.weight", + "time_text_embed.timestep_embedder.linear_1.bias": "timestep_embedding.time_embedder.in_layer.bias", + "time_text_embed.timestep_embedder.linear_1.weight": "timestep_embedding.time_embedder.in_layer.weight", + "time_text_embed.timestep_embedder.linear_2.bias": "timestep_embedding.time_embedder.out_layer.bias", + "time_text_embed.timestep_embedder.linear_2.weight": "timestep_embedding.time_embedder.out_layer.weight", + "time_text_embed.text_embedder.linear_1.bias": "vector_embedding.in_layer.bias", + "time_text_embed.text_embedder.linear_1.weight": "vector_embedding.in_layer.weight", + "time_text_embed.text_embedder.linear_2.bias": "vector_embedding.out_layer.bias", + "time_text_embed.text_embedder.linear_2.weight": "vector_embedding.out_layer.weight", + "transformer_blocks.*.norm1.linear.weight": "double_blocks.*.adaln.linear.weight", + "transformer_blocks.*.norm1.linear.bias": "double_blocks.*.adaln.linear.bias", + "transformer_blocks.*.norm1_context.linear.weight": "double_blocks.*.adaln_context.linear.weight", + "transformer_blocks.*.norm1_context.linear.bias": "double_blocks.*.adaln_context.linear.bias", + "transformer_blocks.*.attn.norm_q.weight": "double_blocks.*.self_attention.q_layernorm.weight", + "transformer_blocks.*.attn.norm_k.weight": "double_blocks.*.self_attention.k_layernorm.weight", + "transformer_blocks.*.attn.norm_added_q.weight": "double_blocks.*.self_attention.added_q_layernorm.weight", + "transformer_blocks.*.attn.norm_added_k.weight": "double_blocks.*.self_attention.added_k_layernorm.weight", + "transformer_blocks.*.attn.to_out.0.weight": "double_blocks.*.self_attention.linear_proj.weight", + "transformer_blocks.*.attn.to_out.0.bias": "double_blocks.*.self_attention.linear_proj.bias", + "transformer_blocks.*.attn.to_add_out.weight": "double_blocks.*.self_attention.added_linear_proj.weight", + "transformer_blocks.*.attn.to_add_out.bias": "double_blocks.*.self_attention.added_linear_proj.bias", + "transformer_blocks.*.ff.net.0.proj.weight": "double_blocks.*.mlp.linear_fc1.weight", + "transformer_blocks.*.ff.net.0.proj.bias": "double_blocks.*.mlp.linear_fc1.bias", + "transformer_blocks.*.ff.net.2.weight": "double_blocks.*.mlp.linear_fc2.weight", + "transformer_blocks.*.ff.net.2.bias": "double_blocks.*.mlp.linear_fc2.bias", + "transformer_blocks.*.ff_context.net.0.proj.weight": "double_blocks.*.context_mlp.linear_fc1.weight", + "transformer_blocks.*.ff_context.net.0.proj.bias": "double_blocks.*.context_mlp.linear_fc1.bias", + "transformer_blocks.*.ff_context.net.2.weight": "double_blocks.*.context_mlp.linear_fc2.weight", + "transformer_blocks.*.ff_context.net.2.bias": "double_blocks.*.context_mlp.linear_fc2.bias", + "single_transformer_blocks.*.norm.linear.weight": "single_blocks.*.adaln.linear.weight", + "single_transformer_blocks.*.norm.linear.bias": "single_blocks.*.adaln.linear.bias", + "single_transformer_blocks.*.proj_mlp.weight": "single_blocks.*.mlp.linear_fc1.weight", + "single_transformer_blocks.*.proj_mlp.bias": "single_blocks.*.mlp.linear_fc1.bias", + "single_transformer_blocks.*.attn.norm_q.weight": "single_blocks.*.self_attention.q_layernorm.weight", + "single_transformer_blocks.*.attn.norm_k.weight": "single_blocks.*.self_attention.k_layernorm.weight", + } + + mapping_list = [] + # Convert each dictionary entry to AutoMapping(hf_param, megatron_param) + for hf_param, megatron_param in param_mappings.items(): + mapping_list.append(AutoMapping(hf_param=hf_param, megatron_param=megatron_param)) + + # Split proj_out into linear_fc2 and linear_proj + # The proj_out weight is split between MLP output and attention projection + # The proj_out bias is mapped to MLP output + mapping_list.append( + AutoMapping( + hf_param="single_transformer_blocks.*.proj_out.bias", + megatron_param="single_blocks.*.mlp.linear_fc2.bias", + ) + ) + mapping_list.append( + SplitRowParallelMapping( + hf_param="single_transformer_blocks.*.proj_out.weight_1", + megatron_param="single_blocks.*.mlp.linear_fc2.weight", + ) + ) + mapping_list.append( + SplitRowParallelMapping( + hf_param="single_transformer_blocks.*.proj_out.weight_2", + megatron_param="single_blocks.*.self_attention.linear_proj.weight", + ) + ) + + AutoMapping.register_module_type("Linear", "replicated") + + # Add special mappings that require parameter concatenation/transformation + mapping_list.extend( + [ + # Single blockQKV: Combine separate Q, K, V matrices into single QKV matrix + QKVMapping( + q="single_transformer_blocks.*.attn.to_q.weight", + k="single_transformer_blocks.*.attn.to_k.weight", + v="single_transformer_blocks.*.attn.to_v.weight", + megatron_param="single_blocks.*.self_attention.linear_qkv.weight", + ), + # Single block QKV bias: Combine separate Q, K, V bias into single QKV bias + QKVMapping( + q="single_transformer_blocks.*.attn.to_q.bias", + k="single_transformer_blocks.*.attn.to_k.bias", + v="single_transformer_blocks.*.attn.to_v.bias", + megatron_param="single_blocks.*.self_attention.linear_qkv.bias", + ), + # Double block Self-attention QKV weights + QKVMapping( + q="transformer_blocks.*.attn.to_q.weight", + k="transformer_blocks.*.attn.to_k.weight", + v="transformer_blocks.*.attn.to_v.weight", + megatron_param="double_blocks.*.self_attention.linear_qkv.weight", + ), + # Double block Self-attention QKV bias + QKVMapping( + q="transformer_blocks.*.attn.to_q.bias", + k="transformer_blocks.*.attn.to_k.bias", + v="transformer_blocks.*.attn.to_v.bias", + megatron_param="double_blocks.*.self_attention.linear_qkv.bias", + ), + # Double block Added (context) attention QKV weights + QKVMapping( + q="transformer_blocks.*.attn.add_q_proj.weight", + k="transformer_blocks.*.attn.add_k_proj.weight", + v="transformer_blocks.*.attn.add_v_proj.weight", + megatron_param="double_blocks.*.self_attention.added_linear_qkv.weight", + ), + # Double block Added (context) attention QKV bias + QKVMapping( + q="transformer_blocks.*.attn.add_q_proj.bias", + k="transformer_blocks.*.attn.add_k_proj.bias", + v="transformer_blocks.*.attn.add_v_proj.bias", + megatron_param="double_blocks.*.self_attention.added_linear_qkv.bias", + ), + ] + ) + + return MegatronMappingRegistry(*mapping_list) + + +class SplitRowParallelMapping(RowParallelMapping): # noqa: D101 + def __init__(self, megatron_param: str, hf_param: str): + super().__init__(megatron_param, hf_param) + self.allow_hf_name_mismatch = True diff --git a/src/megatron/bridge/diffusion/conversion/flux/flux_hf_pretrained.py b/src/megatron/bridge/diffusion/conversion/flux/flux_hf_pretrained.py new file mode 100644 index 0000000000..22fe4b0c1e --- /dev/null +++ b/src/megatron/bridge/diffusion/conversion/flux/flux_hf_pretrained.py @@ -0,0 +1,121 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import shutil +from pathlib import Path +from typing import Union + +from diffusers import FluxTransformer2DModel +from transformers import AutoConfig + +from megatron.bridge.models.hf_pretrained.base import PreTrainedBase +from megatron.bridge.models.hf_pretrained.state import SafeTensorsStateSource, StateDict, StateSource + + +class FluxSafeTensorsStateSource(SafeTensorsStateSource): + """ + FLUX-specific state source that writes exported HF shards under 'transformer/'. + """ + + def save_generator(self, generator, output_path, strict: bool = True): + # Ensure shards are written under transformer/ + target_dir = Path(output_path) / "transformer" + return super().save_generator(generator, target_dir, strict=strict) + + +class PreTrainedFlux(PreTrainedBase): + """ + Lightweight pretrained wrapper for Diffusers FLUX models. + + Provides access to FLUX config and state through the common PreTrainedBase API + so bridges can consume `.config` and `.state` uniformly. + + NOTE: Due to FLUX uses HF's Diffusers library, which has different checkpoint directory structure to HF's Transformer library, + we need a wrapper to load the model weights and config from the correct directory (e.g., ./transformer). + The diffusers's structure includes all components in the diffusion pipeline (VAE, text encoders, etc.). + The actual transformer weights are stored in the ./transformer directory. Hence, we adjust the input and output + path directory accordingly. We also need to override the save_artifacts method to save relevant correct configs + files to the corresponding directory. + """ + + def __init__(self, model_name_or_path: Union[str, Path], **kwargs): + self._model_name_or_path = str(model_name_or_path) + super().__init__(**kwargs) + + @property + def model_name_or_path(self) -> str: + return self._model_name_or_path + + # Model loading is optional for conversion; implemented for completeness + def _load_model(self) -> FluxTransformer2DModel: + return FluxTransformer2DModel.from_pretrained(self.model_name_or_path) + + # Config is required by the FLUX bridge + def _load_config(self) -> AutoConfig: + # FluxTransformer2DModel returns a config-like object with required fields + + print(f"Loading config from {self.model_name_or_path}") + + return FluxTransformer2DModel.from_pretrained(self.model_name_or_path, subfolder="transformer").config + + @property + def state(self) -> StateDict: + """ + FLUX-specific StateDict that reads safetensors from the fixed 'transformer/' subfolder. + """ + if getattr(self, "_state_dict_accessor", None) is None: + source: StateSource | None = None + if hasattr(self, "_model") and self._model is not None: + # If model is loaded, use its in-memory state_dict + source = self.model.state_dict() + else: + # Always load from 'transformer/' subfolder for FLUX + source = FluxSafeTensorsStateSource(Path(self.model_name_or_path) / "transformer") + self._state_dict_accessor = StateDict(source) + return self._state_dict_accessor + + def save_artifacts(self, save_directory: Union[str, Path]): + """ + Save FLUX artifacts (currently config) alongside exported weights. + Writes transformer/config.json into the destination. + """ + save_path = Path(save_directory) + save_path.mkdir(parents=True, exist_ok=True) + + # Ensure transformer subdir exists at destination + dest_transformer = save_path / "transformer" + dest_transformer.mkdir(parents=True, exist_ok=True) + + # 1) If source has a config.json under transformer/, copy it + src_config = Path(self.model_name_or_path) / "transformer" / "config.json" + src_index = Path(self.model_name_or_path) / "transformer" / "diffusion_pytorch_model.safetensors.index.json" + if src_config.exists(): + shutil.copyfile(src_config, dest_transformer / "config.json") + if src_index.exists(): + shutil.copyfile(src_index, dest_transformer / "diffusion_pytorch_model.safetensors.index.json") + return + + # 2) Otherwise, try to export config from the HF model instance + try: + model = FluxTransformer2DModel.from_pretrained(self.model_name_or_path, subfolder="transformer") + cfg = getattr(model, "config", None) + if cfg is not None: + # Prefer to_dict if available + cfg_dict = cfg.to_dict() if hasattr(cfg, "to_dict") else dict(cfg) + with open(dest_transformer / "config.json", "w") as f: + json.dump(cfg_dict, f, indent=2) + except Exception: + # Best-effort: if config cannot be produced, leave only weights + pass diff --git a/src/megatron/bridge/diffusion/conversion/wan/__init__.py b/src/megatron/bridge/diffusion/conversion/wan/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/megatron/bridge/diffusion/conversion/wan/wan_bridge.py b/src/megatron/bridge/diffusion/conversion/wan/wan_bridge.py new file mode 100644 index 0000000000..a42e510f2c --- /dev/null +++ b/src/megatron/bridge/diffusion/conversion/wan/wan_bridge.py @@ -0,0 +1,187 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from diffusers import WanTransformer3DModel + +from megatron.bridge.diffusion.conversion.wan.wan_hf_pretrained import PreTrainedWAN +from megatron.bridge.diffusion.models.wan.wan_model import WanModel +from megatron.bridge.diffusion.models.wan.wan_provider import WanModelProvider +from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.bridge.models.conversion.param_mapping import ( + AutoMapping, + KVMapping, + QKVMapping, + ReplicatedMapping, +) +from megatron.bridge.models.conversion.utils import get_module_and_param_from_name + + +@MegatronModelBridge.register_bridge(source=WanTransformer3DModel, target=WanModel) +class WanBridge(MegatronModelBridge): + """ + Megatron Bridge for WAN model. + + As a user you would not use this bridge directly, but through `AutoBridge`. + + Example: + >>> from megatron.bridge import AutoBridge + >>> bridge = AutoBridge.from_hf_pretrained("WAN-3D-1.3B-v1") + >>> provider = bridge.to_megatron_provider() + """ + + def provider_bridge(self, hf_pretrained: PreTrainedWAN) -> WanModelProvider: + hf_config = hf_pretrained.config + + cls = WanModelProvider + + provider = cls( + num_layers=hf_config.num_layers, + hidden_size=hf_config.num_attention_heads * hf_config.attention_head_dim, + kv_channels=hf_config.attention_head_dim, + num_query_groups=hf_config.num_attention_heads, + crossattn_emb_size=hf_config.num_attention_heads * hf_config.attention_head_dim, + ffn_hidden_size=hf_config.ffn_dim, + num_attention_heads=hf_config.num_attention_heads, + in_channels=hf_config.in_channels, + out_channels=hf_config.out_channels, + text_dim=hf_config.text_dim, + patch_spatial=hf_config.patch_size[1], + patch_temporal=hf_config.patch_size[0], + layernorm_epsilon=hf_config.eps, + hidden_dropout=0, + attention_dropout=0, + use_cpu_initialization=True, + freq_dim=hf_config.freq_dim, + bf16=False, + params_dtype=torch.float32, + ) + + return provider + + def mapping_registry(self) -> MegatronMappingRegistry: + """Return MegatronMappingRegistry containing parameter mappings from HF to Megatron format. + + Returns: + MegatronMappingRegistry: Registry of parameter mappings + """ + # Dictionary maps HF parameter names -> Megatron parameter names + # Supports wildcard (*) patterns for layer-specific parameters + param_mappings = { + "scale_shift_table": "head.modulation", + "patch_embedding.weight": "patch_embedding.weight", + "patch_embedding.bias": "patch_embedding.bias", + "condition_embedder.time_embedder.linear_1.weight": "time_embedder.linear_1.weight", + "condition_embedder.time_embedder.linear_1.bias": "time_embedder.linear_1.bias", + "condition_embedder.time_embedder.linear_2.weight": "time_embedder.linear_2.weight", + "condition_embedder.time_embedder.linear_2.bias": "time_embedder.linear_2.bias", + "condition_embedder.time_proj.weight": "time_proj.weight", + "condition_embedder.time_proj.bias": "time_proj.bias", + "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight", + "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias", + "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight", + "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias", + "blocks.*.scale_shift_table": "decoder.layers.*.adaLN.modulation", + "blocks.*.attn1.to_out.0.weight": "decoder.layers.*.full_self_attention.linear_proj.weight", + "blocks.*.attn1.to_out.0.bias": "decoder.layers.*.full_self_attention.linear_proj.bias", + "blocks.*.attn1.norm_q.weight": "decoder.layers.*.full_self_attention.q_layernorm.weight", + "blocks.*.attn1.norm_k.weight": "decoder.layers.*.full_self_attention.k_layernorm.weight", + "blocks.*.attn2.to_q.weight": "decoder.layers.*.cross_attention.linear_q.weight", + "blocks.*.attn2.to_q.bias": "decoder.layers.*.cross_attention.linear_q.bias", + "blocks.*.attn2.to_out.0.weight": "decoder.layers.*.cross_attention.linear_proj.weight", + "blocks.*.attn2.to_out.0.bias": "decoder.layers.*.cross_attention.linear_proj.bias", + "blocks.*.attn2.norm_q.weight": "decoder.layers.*.cross_attention.q_layernorm.weight", + "blocks.*.attn2.norm_k.weight": "decoder.layers.*.cross_attention.k_layernorm.weight", + "blocks.*.norm2.weight": "decoder.layers.*.norm3.weight", + "blocks.*.norm2.bias": "decoder.layers.*.norm3.bias", + "blocks.*.ffn.net.0.proj.weight": "decoder.layers.*.mlp.linear_fc1.weight", + "blocks.*.ffn.net.0.proj.bias": "decoder.layers.*.mlp.linear_fc1.bias", + "blocks.*.ffn.net.2.weight": "decoder.layers.*.mlp.linear_fc2.weight", + "blocks.*.ffn.net.2.bias": "decoder.layers.*.mlp.linear_fc2.bias", + "proj_out.weight": "head.head.weight", + "proj_out.bias": "head.head.bias", + } + + # Custom WAN mapping to safely handle replicated params whose owning module + # does not expose a top-level `.weight` (e.g., Head.modulation) + class _ReplicatedByParamNameMapping(ReplicatedMapping): + def hf_to_megatron(self, hf_weights, megatron_module): + normalized_param = self._normalize_expert_param_name(self.megatron_param) + _, target_param = get_module_and_param_from_name(megatron_module, normalized_param) + + target_device = target_param.device + target_dtype = target_param.dtype + + hf_weights = hf_weights.to(device=target_device, dtype=target_dtype) + if self.tp_size == 1: + return hf_weights + + if target_device.type == "cuda" and torch.cuda.is_available(): + if target_device.index != torch.cuda.current_device(): + hf_weights = hf_weights.to(torch.cuda.current_device()) + + if self.tp_rank > 0: + hf_weights = torch.empty_like(hf_weights) + + return self.broadcast_tensor_to_tp_ranks(hf_weights, src_rank=0) + + mapping_list = [] + # Convert each dictionary entry to AutoMapping(hf_param, megatron_param) + for hf_param, megatron_param in param_mappings.items(): + if hf_param in {"scale_shift_table", "blocks.*.scale_shift_table", "proj_out.weight", "proj_out.bias"}: + # Use WAN-specific replicated mapping that resolves the exact param + mapping_list.append(_ReplicatedByParamNameMapping(hf_param=hf_param, megatron_param=megatron_param)) + else: + mapping_list.append(AutoMapping(hf_param=hf_param, megatron_param=megatron_param)) + + # Adding custom module types for AutoMapping + AutoMapping.register_module_type("Linear", "replicated") + AutoMapping.register_module_type("Conv3d", "replicated") + AutoMapping.register_module_type("WanAdaLN", "replicated") + AutoMapping.register_module_type("Head", "replicated") + + # Add special mappings that require parameter concatenation/transformation + mapping_list.extend( + [ + # QKV: Combine separate Q, K, V matrices into single QKV matrix + QKVMapping( + q="blocks.*.attn1.to_q.weight", + k="blocks.*.attn1.to_k.weight", + v="blocks.*.attn1.to_v.weight", + megatron_param="decoder.layers.*.full_self_attention.linear_qkv.weight", + ), + # QKV bias: Combine separate Q, K, V bias into single QKV bias + QKVMapping( + q="blocks.*.attn1.to_q.bias", + k="blocks.*.attn1.to_k.bias", + v="blocks.*.attn1.to_v.bias", + megatron_param="decoder.layers.*.full_self_attention.linear_qkv.bias", + ), + # K, V: Combine separate K, V matrices into single KV matrix + KVMapping( + k="blocks.*.attn2.to_k.weight", + v="blocks.*.attn2.to_v.weight", + megatron_param="decoder.layers.*.cross_attention.linear_kv.weight", + ), + # K, V bias: Combine separate K, V bias into single KV bias + KVMapping( + k="blocks.*.attn2.to_k.bias", + v="blocks.*.attn2.to_v.bias", + megatron_param="decoder.layers.*.cross_attention.linear_kv.bias", + ), + ] + ) + + return MegatronMappingRegistry(*mapping_list) diff --git a/src/megatron/bridge/diffusion/conversion/wan/wan_hf_pretrained.py b/src/megatron/bridge/diffusion/conversion/wan/wan_hf_pretrained.py new file mode 100644 index 0000000000..7457467ee1 --- /dev/null +++ b/src/megatron/bridge/diffusion/conversion/wan/wan_hf_pretrained.py @@ -0,0 +1,121 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import shutil +from pathlib import Path +from typing import Union + +from diffusers import WanTransformer3DModel +from transformers import AutoConfig + +from megatron.bridge.models.hf_pretrained.base import PreTrainedBase +from megatron.bridge.models.hf_pretrained.state import SafeTensorsStateSource, StateDict, StateSource + + +class WanSafeTensorsStateSource(SafeTensorsStateSource): + """ + WAN-specific state source that writes exported HF shards under 'transformer/'. + """ + + def save_generator(self, generator, output_path, strict: bool = True): + # Ensure shards are written under transformer/ + target_dir = Path(output_path) / "transformer" + return super().save_generator(generator, target_dir, strict=strict) + + +class PreTrainedWAN(PreTrainedBase): + """ + Lightweight pretrained wrapper for Diffusers WAN models. + + Provides access to WAN config and state through the common PreTrainedBase API + so bridges can consume `.config` and `.state` uniformly. + + NOTE: Due to Wan uses HF's Diffusers library, which has different checkpoint directory structure to HF's Transformer library, + we need a wrapper to load the model weights and config from the correct directory (e.g., ./transformer). + The diffusers's structure includes all components in the diffusion pipeline (VAE, text encoders, etc.). + The actual transformer weights are stored in the ./transformer directory. Hence, we adjust the input and output + path directory accordingly. We also need to override the save_artifacts method to save relevant correct configs + files to the corresponding directory. + """ + + def __init__(self, model_name_or_path: Union[str, Path], **kwargs): + self._model_name_or_path = str(model_name_or_path) + super().__init__(**kwargs) + + @property + def model_name_or_path(self) -> str: + return self._model_name_or_path + + # Model loading is optional for conversion; implemented for completeness + def _load_model(self) -> WanTransformer3DModel: + return WanTransformer3DModel.from_pretrained(self.model_name_or_path) + + # Config is required by the WAN bridge + def _load_config(self) -> AutoConfig: + # WanTransformer3DModel returns a config-like object with required fields + + print(f"Loading config from {self.model_name_or_path}") + + return WanTransformer3DModel.from_pretrained(self.model_name_or_path, subfolder="transformer").config + + @property + def state(self) -> StateDict: + """ + WAN-specific StateDict that reads safetensors from the fixed 'transformer/' subfolder. + """ + if getattr(self, "_state_dict_accessor", None) is None: + source: StateSource | None = None + if hasattr(self, "_model") and self._model is not None: + # If model is loaded, use its in-memory state_dict + source = self.model.state_dict() + else: + # Always load from 'transformer/' subfolder for WAN + source = WanSafeTensorsStateSource(Path(self.model_name_or_path) / "transformer") + self._state_dict_accessor = StateDict(source) + return self._state_dict_accessor + + def save_artifacts(self, save_directory: Union[str, Path]): + """ + Save WAN artifacts (currently config) alongside exported weights. + Writes transformer/config.json into the destination. + """ + save_path = Path(save_directory) + save_path.mkdir(parents=True, exist_ok=True) + + # Ensure transformer subdir exists at destination + dest_transformer = save_path / "transformer" + dest_transformer.mkdir(parents=True, exist_ok=True) + + # 1) If source has a config.json under transformer/, copy it + src_config = Path(self.model_name_or_path) / "transformer" / "config.json" + src_index = Path(self.model_name_or_path) / "transformer" / "diffusion_pytorch_model.safetensors.index.json" + if src_config.exists(): + shutil.copyfile(src_config, dest_transformer / "config.json") + if src_index.exists(): + shutil.copyfile(src_index, dest_transformer / "diffusion_pytorch_model.safetensors.index.json") + return + + # 2) Otherwise, try to export config from the HF model instance + try: + model = WanTransformer3DModel.from_pretrained(self.model_name_or_path, subfolder="transformer") + cfg = getattr(model, "config", None) + if cfg is not None: + # Prefer to_dict if available + cfg_dict = cfg.to_dict() if hasattr(cfg, "to_dict") else dict(cfg) + with open(dest_transformer / "config.json", "w") as f: + json.dump(cfg_dict, f, indent=2) + except Exception: + # Best-effort: if config cannot be produced, leave only weights + pass diff --git a/src/megatron/bridge/diffusion/data/__init__.py b/src/megatron/bridge/diffusion/data/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/megatron/bridge/diffusion/data/common/__init__.py b/src/megatron/bridge/diffusion/data/common/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/megatron/bridge/diffusion/data/common/diffusion_energon_datamodule.py b/src/megatron/bridge/diffusion/data/common/diffusion_energon_datamodule.py new file mode 100644 index 0000000000..030116151e --- /dev/null +++ b/src/megatron/bridge/diffusion/data/common/diffusion_energon_datamodule.py @@ -0,0 +1,174 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + + +import logging +from dataclasses import dataclass +from typing import Any, Dict, Literal + +from megatron.energon import DefaultTaskEncoder, get_train_dataset +from torch import int_repr + +from megatron.bridge.data.energon.base_energon_datamodule import EnergonMultiModalDataModule +from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider + + +@dataclass(kw_only=True) +class DiffusionDataModuleConfig(DatasetProvider): # noqa: D101 + path: str + seq_length: int + micro_batch_size: int + packing_buffer_size: int + global_batch_size: int + num_workers: int_repr + task_encoder_seq_length: int = None + dataloader_type: str = "external" + use_train_split_for_val: bool = False + + def build_datasets(self, context: DatasetBuildContext): + return ( + iter(self.dataset.train_dataloader()), + iter(self.dataset.val_dataloader()), + iter(self.dataset.val_dataloader()), + ) + + +class DiffusionDataModule(EnergonMultiModalDataModule): + """ + A PyTorch Lightning DataModule for handling multimodal datasets with images and text. + + This data module is designed to work with multimodal datasets that involve both images and text. + It provides a seamless interface to load training and validation data, manage batching, and handle + the state of the data pipeline across training epochs. The module integrates with the Megatron-Energon + framework for efficient data handling in large-scale distributed training. + + Attributes: + path (str): Path to the energon dataset. + tokenizer (Tokenizer): The tokenizer used for processing text. + image_processor (ImageProcessor): The image processor used for preprocessing images. + seq_length (int): The maximum sequence length for tokenized text. + micro_batch_size (int): The batch size for training and validation. + num_workers (int): Number of workers for data loading. + pin_memory (bool): Whether to pin memory in the DataLoader. + multimodal_sample_config (MultiModalSampleConfig): Configuration object for multimodal samples. + task_encoder (MultiModalTaskEncoder): Encoder responsible for encoding and batching samples. + init_global_step (int): The initial global step for the trainer, used for resuming training. + data_sampler (SequentialMegatronSampler): Sampler responsible for generating sequential samples. + train_dataloader_object (Optional): The DataLoader object for training data. + val_dataloader_object (Optional): The DataLoader object for validation data. + """ + + def __init__( + self, + path: str, + seq_length: int = 2048, + micro_batch_size: int = 1, + global_batch_size: int = 8, + num_workers: int = 1, + pin_memory: bool = True, + packing_buffer_size: int = None, + task_encoder: DefaultTaskEncoder = None, + use_train_split_for_val: bool = False, + ) -> None: + """ + Initialize the SimpleMultiModalDataModule. + + Parameters: + path (str): Path to the dataset. + tokenizer (Tokenizer): The tokenizer used for processing text. + image_processor (ImageProcessor): The image processor used for preprocessing images. + seq_length (int, optional): The maximum sequence length for tokenized text. Defaults to 2048. + micro_batch_size (int, optional): The batch size for training and validation. Defaults to 1. + num_workers (int, optional): Number of workers for data loading. Defaults to 1. + pin_memory (bool, optional): Whether to pin memory in the DataLoader. Defaults to True. + """ + + super().__init__( + path=path, + tokenizer=None, + image_processor=None, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + num_workers=num_workers, + packing_buffer_size=packing_buffer_size, + pin_memory=pin_memory, + task_encoder=task_encoder, + ) + self.use_train_split_for_val = use_train_split_for_val + + def datasets_provider(self, worker_config, split: Literal["train", "val"] = "val"): + """ + Provide the dataset for training or validation. + + This method retrieves the dataset for the specified split (either 'train' or 'val') and configures + it according to the worker configuration. + + Parameters: + worker_config: Configuration for the data loader workers. + split (Literal['train', 'val'], optional): The data split to retrieve ('train' or 'val'). Defaults to 'val'. + + Returns: + Dataset: The dataset configured for the specified split. + """ + if split not in {"train", "val"}: + raise ValueError("Invalid value for split. Allowed values are 'train' or 'val'.") + if self.use_train_split_for_val: + split = "train" + _dataset = get_train_dataset( + self.path, + batch_size=self.micro_batch_size, + packing_buffer_size=self.packing_buffer_size, + task_encoder=self.task_encoder, + worker_config=worker_config, + max_samples_per_sequence=None, + shuffle_buffer_size=100, + split_part=split, + batch_drop_last=True, + virtual_epoch_length=1_000_000_000, # a hack to avoid energon end of epoch warning + ) + return _dataset + + def val_dataloader(self): + """ + Configure the validation DataLoader. + + This method configures the DataLoader for validation data. + + Parameters: + worker_config: Configuration for the data loader workers. + + Returns: + DataLoader: The DataLoader for validation data. + """ + if self.use_train_split_for_val: + return self.train_dataloader() + return super().val_dataloader() + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """ + Load the state of the data module from a checkpoint. + + This method is called when loading a checkpoint. It restores the state of the data module, + including the state of the dataloader and the number of consumed samples. + + Parameters: + state_dict (Dict[str, Any]): The state dictionary containing the saved state of the data module. + """ + try: + super().load_state_dict(state_dict) + except Exception as e: + logging.warning(f"datamodule.load_state_dict failed {e}") diff --git a/src/megatron/bridge/diffusion/data/common/diffusion_sample.py b/src/megatron/bridge/diffusion/data/common/diffusion_sample.py new file mode 100644 index 0000000000..702a392c6d --- /dev/null +++ b/src/megatron/bridge/diffusion/data/common/diffusion_sample.py @@ -0,0 +1,114 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Optional + +import torch +from megatron.energon import Sample + + +@dataclass +class DiffusionSample(Sample): + """ + Data class representing a sample for diffusion tasks. + + Attributes: + video (torch.Tensor): Video latents (C T H W). + t5_text_embeddings (torch.Tensor): Text embeddings (S D). + t5_text_mask (torch.Tensor): Mask for text embeddings. + loss_mask (torch.Tensor): Mask indicating valid positions for loss computation. + image_size (Optional[torch.Tensor]): Tensor containing image dimensions. + fps (Optional[torch.Tensor]): Frame rate of the video. + num_frames (Optional[torch.Tensor]): Number of frames in the video. + padding_mask (Optional[torch.Tensor]): Mask indicating padding positions. + seq_len_q (Optional[torch.Tensor]): Sequence length for query embeddings. + seq_len_q_padded (Optional[torch.Tensor]): Sequence length for query embeddings after padding. + seq_len_kv (Optional[torch.Tensor]): Sequence length for key/value embeddings. + pos_ids (Optional[torch.Tensor]): Positional IDs. + latent_shape (Optional[torch.Tensor]): Shape of the latent tensor. + video_metadata (Optional[dict]): Metadata of the video. + """ + + video: torch.Tensor # video latents (C T H W) + context_embeddings: torch.Tensor # (S D) + context_mask: torch.Tensor = None # 1 + image_size: Optional[torch.Tensor] = None + loss_mask: torch.Tensor = None + fps: Optional[torch.Tensor] = None + num_frames: Optional[torch.Tensor] = None + padding_mask: Optional[torch.Tensor] = None + seq_len_q: Optional[torch.Tensor] = None + seq_len_q_padded: Optional[torch.Tensor] = None + seq_len_kv: Optional[torch.Tensor] = None + seq_len_kv_padded: Optional[torch.Tensor] = None + pos_ids: Optional[torch.Tensor] = None + latent_shape: Optional[torch.Tensor] = None + video_metadata: Optional[dict] = None + + def to_dict(self) -> dict: + """Converts the sample to a dictionary.""" + return dict( + video=self.video, + context_embeddings=self.context_embeddings, + context_mask=self.context_mask, + loss_mask=self.loss_mask, + image_size=self.image_size, + fps=self.fps, + num_frames=self.num_frames, + padding_mask=self.padding_mask, + seq_len_q=self.seq_len_q, + seq_len_q_padded=self.seq_len_q_padded, + seq_len_kv=self.seq_len_kv, + seq_len_kv_padded=self.seq_len_kv_padded, + pos_ids=self.pos_ids, + latent_shape=self.latent_shape, + video_metadata=self.video_metadata, + ) + + def __add__(self, other: Any) -> int: + """Adds the sequence length of this sample with another sample or integer.""" + if isinstance(other, DiffusionSample): + # Use padded length if available (for CP), otherwise use unpadded + self_len = self.seq_len_q_padded.item() if self.seq_len_q_padded is not None else self.seq_len_q.item() + other_len = other.seq_len_q_padded.item() if other.seq_len_q_padded is not None else other.seq_len_q.item() + return self_len + other_len + elif isinstance(other, int): + # Use padded length if available (for CP), otherwise use unpadded + self_len = self.seq_len_q_padded.item() if self.seq_len_q_padded is not None else self.seq_len_q.item() + return self_len + other + raise NotImplementedError + + def __radd__(self, other: Any) -> int: + """Handles reverse addition for summing with integers.""" + # This is called if sum or other operations start with a non-DiffusionSample object. + # e.g., sum([DiffusionSample(1), DiffusionSample(2)]) -> the 0 + DiffusionSample(1) calls __radd__. + if isinstance(other, int): + # Use padded length if available (for CP), otherwise use unpadded + self_len = self.seq_len_q_padded.item() if self.seq_len_q_padded is not None else self.seq_len_q.item() + return self_len + other + raise NotImplementedError + + def __lt__(self, other: Any) -> bool: + """Compares this sample's sequence length with another sample or integer.""" + if isinstance(other, DiffusionSample): + # Use padded length if available (for CP), otherwise use unpadded + self_len = self.seq_len_q_padded.item() if self.seq_len_q_padded is not None else self.seq_len_q.item() + other_len = other.seq_len_q_padded.item() if other.seq_len_q_padded is not None else other.seq_len_q.item() + return self_len < other_len + elif isinstance(other, int): + # Use padded length if available (for CP), otherwise use unpadded + self_len = self.seq_len_q_padded.item() if self.seq_len_q_padded is not None else self.seq_len_q.item() + return self_len < other + raise NotImplementedError diff --git a/src/megatron/bridge/diffusion/data/common/diffusion_task_encoder_with_sp.py b/src/megatron/bridge/diffusion/data/common/diffusion_task_encoder_with_sp.py new file mode 100644 index 0000000000..6d1a2b0b09 --- /dev/null +++ b/src/megatron/bridge/diffusion/data/common/diffusion_task_encoder_with_sp.py @@ -0,0 +1,122 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +from abc import ABC, abstractmethod +from typing import List + +import torch +from megatron.energon import DefaultTaskEncoder +from megatron.energon.task_encoder.base import stateless +from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys + +from megatron.bridge.diffusion.data.common.diffusion_sample import DiffusionSample +from megatron.bridge.diffusion.data.common.sequence_packing_utils import first_fit_decreasing + + +def cook(sample: dict) -> dict: + """ + Processes a raw sample dictionary from energon dataset and returns a new dictionary with specific keys. + + Args: + sample (dict): The input dictionary containing the raw sample data. + + Returns: + dict: A new dictionary containing the processed sample data with the following keys: + - All keys from the result of `basic_sample_keys(sample)` + - 'json': The contains meta data like resolution, aspect ratio, fps, etc. + - 'pth': contains video latent tensor + - 'pickle': contains text embeddings + """ + return dict( + **basic_sample_keys(sample), + json=sample["json"], + pth=sample["pth"], + pickle=sample["pickle"], + ) + + +class DiffusionTaskEncoderWithSequencePacking(DefaultTaskEncoder, ABC): # noqa: D101 + cookers = [ + Cooker(cook), + ] + + def __init__( + self, + *args, + max_frames: int = None, + text_embedding_max_length: int = 512, + seq_length: int = None, + patch_spatial: int = 2, + patch_temporal: int = 1, + packing_buffer_size: int = None, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.max_frames = max_frames + self.text_embedding_max_length = text_embedding_max_length + self.seq_length = seq_length + self.patch_spatial = patch_spatial + self.patch_temporal = patch_temporal + self.packing_buffer_size = packing_buffer_size + + @abstractmethod + def encode_sample(self, sample: dict) -> dict: + raise NotImplementedError + + def select_samples_to_pack(self, samples: List[DiffusionSample]) -> List[List[DiffusionSample]]: + """ + Selects sequences to pack for mixed image-video training. + """ + results = first_fit_decreasing(samples, self.seq_length) + random.shuffle(results) + return results + + @stateless + def pack_selected_samples(self, samples: List[DiffusionSample]) -> DiffusionSample: + """Construct a new Diffusion sample by concatenating the sequences.""" + + def stack(attr): + if hasattr(samples[0], attr) and getattr(samples[0], attr) is not None: + return torch.stack([getattr(sample, attr) for sample in samples], dim=0) + else: + return None + + def cat(attr): + if hasattr(samples[0], attr) and getattr(samples[0], attr) is not None: + return torch.cat([getattr(sample, attr) for sample in samples], dim=0) + else: + return None + + return DiffusionSample( + __key__=",".join([s.__key__ for s in samples]), + __restore_key__=(), # Will be set by energon based on `samples` + __subflavor__=None, + __subflavors__=samples[0].__subflavors__, + video=cat("video"), + context_embeddings=cat("context_embeddings"), + context_mask=cat("context_mask"), + loss_mask=cat("loss_mask"), + seq_len_q=cat("seq_len_q"), + seq_len_q_padded=cat("seq_len_q_padded"), + seq_len_kv=cat("seq_len_kv"), + seq_len_kv_padded=cat("seq_len_kv_padded"), + pos_ids=cat("pos_ids"), + latent_shape=stack("latent_shape"), + video_metadata=[sample.video_metadata for sample in samples], + ) + + @stateless + def batch(self, samples: List[DiffusionSample]) -> dict: + raise NotImplementedError diff --git a/src/megatron/bridge/diffusion/data/common/sequence_packing_utils.py b/src/megatron/bridge/diffusion/data/common/sequence_packing_utils.py new file mode 100644 index 0000000000..612d083c29 --- /dev/null +++ b/src/megatron/bridge/diffusion/data/common/sequence_packing_utils.py @@ -0,0 +1,73 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List + + +def find_first_bin_that_fits(bins: List[List[int]], s: int, bin_size: int) -> int: + """ + Finds the first bin in a list of bins that has enough space to fit a sequence of size 's'. + + Args: + bins: A list of lists, where each inner list represents a bin and contains the current elements in that bin. + s: The size of the sequence to be placed in a bin. + bin_size: The maximum capacity of each bin. + + Returns: + The index of the first bin that can fit the sequence 's', or -1 if no such bin exists. + """ + for i, abin in enumerate(bins): + if sum(abin) + s <= bin_size: + return i + return -1 + + +def first_fit(seqlens: List[int], pack_size: int) -> List[List[int]]: + """ + Packs sequences of varying lengths into bins using the First-Fit algorithm. + + Args: + seqlens: A list of integers, representing the lengths of the sequences to be packed. + pack_size: The maximum capacity of each bin. + + Returns: + A list of lists, where each inner list represents a bin and contains the indices + of the sequences assigned to that bin. + """ + res = [] + for s in seqlens: + first_bin = find_first_bin_that_fits(res, s, pack_size) + if first_bin == -1: # open a new bin + res.append([s]) + else: + res[first_bin].append(s) + return res + + +def first_fit_decreasing(seqlens: List[int], pack_size: int) -> List[List[int]]: + """ + Packs sequences of varying lengths into bins using the First-Fit Decreasing algorithm. + + This is a variation of the First-Fit algorithm where the sequences are sorted by decreasing length before packing. + + Args: + seqlens: A list of integers, representing the lengths of the sequences to be packed. + pack_size: The maximum capacity of each bin. + + Returns: + A list of lists, similar to the output of the 'first_fit' function. + """ + sorted_seqlens = sorted(seqlens, reverse=True) + return first_fit(sorted_seqlens, pack_size) diff --git a/src/megatron/bridge/diffusion/data/flux/__init__.py b/src/megatron/bridge/diffusion/data/flux/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/megatron/bridge/diffusion/data/flux/flux_energon_datamodule.py b/src/megatron/bridge/diffusion/data/flux/flux_energon_datamodule.py new file mode 100644 index 0000000000..26605b48b6 --- /dev/null +++ b/src/megatron/bridge/diffusion/data/flux/flux_energon_datamodule.py @@ -0,0 +1,60 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +from dataclasses import dataclass + +from torch import int_repr + +from megatron.bridge.data.utils import DatasetBuildContext +from megatron.bridge.diffusion.data.common.diffusion_energon_datamodule import ( + DiffusionDataModule, + DiffusionDataModuleConfig, +) +from megatron.bridge.diffusion.data.flux.flux_taskencoder import FluxTaskEncoder + + +@dataclass(kw_only=True) +class FluxDataModuleConfig(DiffusionDataModuleConfig): # noqa: D101 + path: str + seq_length: int + packing_buffer_size: int + micro_batch_size: int + global_batch_size: int + num_workers: int_repr + dataloader_type: str = "external" + vae_scale_factor: int = 8 + latent_channels: int = 16 + + def __post_init__(self): + self.dataset = DiffusionDataModule( + path=self.path, + seq_length=self.seq_length, + packing_buffer_size=self.packing_buffer_size, + task_encoder=FluxTaskEncoder( + seq_length=self.seq_length, + packing_buffer_size=self.packing_buffer_size, + vae_scale_factor=self.vae_scale_factor, + latent_channels=self.latent_channels, + ), + micro_batch_size=self.micro_batch_size, + global_batch_size=self.global_batch_size, + num_workers=self.num_workers, + use_train_split_for_val=True, + ) + self.sequence_length = self.dataset.seq_length + + def build_datasets(self, context: DatasetBuildContext): + return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader() diff --git a/src/megatron/bridge/diffusion/data/flux/flux_mock_datamodule.py b/src/megatron/bridge/diffusion/data/flux/flux_mock_datamodule.py new file mode 100644 index 0000000000..9a7598c27c --- /dev/null +++ b/src/megatron/bridge/diffusion/data/flux/flux_mock_datamodule.py @@ -0,0 +1,256 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Mock data module for FLUX model training.""" + +from dataclasses import dataclass + +import torch +from torch.utils.data import DataLoader, Dataset + +from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider + + +class _MockT2IDataset(Dataset): + """ + A mock dataset class for text-to-image tasks, simulating data samples for training and testing. + + This dataset generates synthetic data for both image and text inputs, with options to use + pre-cached latent representations or raw data. The class is designed for use in testing and + prototyping machine learning models. + + Attributes: + image_H (int): Height of the generated images. + image_W (int): Width of the generated images. + length (int): Total number of samples in the dataset. + image_precached (bool): Whether to use pre-cached latent representations for images. + text_precached (bool): Whether to use pre-cached embeddings for text. + prompt_seq_len (int): Sequence length for text prompts. + pooled_prompt_dim (int): Dimensionality of pooled text embeddings. + context_dim (int): Dimensionality of the text embedding context. + vae_scale_factor (int): Scaling factor for the VAE latent representation. + vae_channels (int): Number of channels in the VAE latent representation. + """ + + def __init__( + self, + image_H: int = 1024, + image_W: int = 1024, + length: int = 100000, + image_precached: bool = True, + text_precached: bool = True, + prompt_seq_len: int = 512, + pooled_prompt_dim: int = 768, + context_dim: int = 4096, + vae_scale_factor: int = 8, + vae_channels: int = 16, + ): + super().__init__() + self.length = length + self.H = image_H + self.W = image_W + self.image_precached = image_precached + self.text_precached = text_precached + self.vae_channels = vae_channels + self.vae_scale_factor = vae_scale_factor + self.prompt_seq_len = prompt_seq_len + self.pooled_prompt_dim = pooled_prompt_dim + self.context_dim = context_dim + + if self.image_precached: + self.latent_shape = ( + vae_channels, + int(image_H // vae_scale_factor), + int(image_W // vae_scale_factor), + ) + if self.text_precached: + self.prompt_embeds_shape = (prompt_seq_len, context_dim) + self.pooled_prompt_embeds_shape = (pooled_prompt_dim,) + self.text_ids_shape = (prompt_seq_len, 3) + + def __getitem__(self, index): + """ + Retrieves a single sample from the dataset. + + The sample includes pre-cached latent representations for images and text. + + Args: + index (int): Index of the sample to retrieve. + + Returns: + dict: A dictionary containing the generated data sample with keys: + - 'latents': Pre-cached latent representation of the image [C, H, W]. + - 'prompt_embeds': Pre-cached text prompt embeddings [seq_len, context_dim]. + - 'pooled_prompt_embeds': Pooled text prompt embeddings [pooled_dim]. + - 'text_ids': Text position IDs [seq_len, 3]. + """ + item = {} + + if self.image_precached: + # Latents in [C, H, W] format - will be batched to [B, C, H, W] + item["latents"] = torch.randn(self.latent_shape, dtype=torch.bfloat16) + else: + # Raw images [3, H, W] + item["images"] = torch.randn(3, self.H, self.W, dtype=torch.bfloat16) + + if self.text_precached: + # T5 embeddings [seq_len, context_dim] + item["prompt_embeds"] = torch.randn(self.prompt_embeds_shape, dtype=torch.bfloat16) + # CLIP pooled embeddings [pooled_dim] + item["pooled_prompt_embeds"] = torch.randn(self.pooled_prompt_embeds_shape, dtype=torch.bfloat16) + # Text position IDs [seq_len, 3] + item["text_ids"] = torch.zeros(self.text_ids_shape, dtype=torch.bfloat16) + else: + item["txt"] = "This is a sample caption input" + + return item + + def __len__(self): + """Returns the total number of samples in the dataset.""" + return self.length + + +def _collate_fn(samples): + """ + Collate function to batch samples from _MockT2IDataset. + + Args: + samples: List of sample dictionaries from the dataset. + + Returns: + dict: Batched dictionary with stacked tensors. + """ + batch = {} + + # Stack latents: [B, C, H, W] + if "latents" in samples[0]: + batch["latents"] = torch.stack([s["latents"] for s in samples], dim=0) + elif "images" in samples[0]: + batch["images"] = torch.stack([s["images"] for s in samples], dim=0) + + # Stack text embeddings + if "prompt_embeds" in samples[0]: + # [B, seq_len, context_dim] + batch["prompt_embeds"] = torch.stack([s["prompt_embeds"] for s in samples], dim=0) + # [B, pooled_dim] + batch["pooled_prompt_embeds"] = torch.stack([s["pooled_prompt_embeds"] for s in samples], dim=0) + # [B, seq_len, 3] + batch["text_ids"] = torch.stack([s["text_ids"] for s in samples], dim=0) + elif "txt" in samples[0]: + batch["txt"] = [s["txt"] for s in samples] + + # Add loss mask (all ones) + if "latents" in batch: + batch_size = batch["latents"].shape[0] + latent_h = batch["latents"].shape[2] + latent_w = batch["latents"].shape[3] + # Loss mask covers all latent positions + batch["loss_mask"] = torch.ones(batch_size, latent_h * latent_w, dtype=torch.bfloat16) + + return batch + + +@dataclass(kw_only=True) +class FluxMockDataModuleConfig(DatasetProvider): + """ + Configuration for FLUX mock data module. + + This data module generates synthetic data for FLUX model training, + matching the expected input format of FluxForwardStep. + + Attributes: + path: Unused, kept for interface compatibility. + seq_length: Sequence length (unused for FLUX, kept for interface compatibility). + packing_buffer_size: Packing buffer size (unused for FLUX). + micro_batch_size: Micro batch size for training. + global_batch_size: Global batch size for training. + num_workers: Number of data loading workers. + dataloader_type: Type of dataloader ("external" for mock data). + image_H: Height of input images. + image_W: Width of input images. + vae_channels: Number of VAE latent channels. + vae_scale_factor: VAE spatial downsampling factor. + prompt_seq_len: Sequence length for T5 text embeddings. + context_dim: Dimensionality of T5 text embeddings. + pooled_prompt_dim: Dimensionality of CLIP pooled embeddings. + image_precached: Whether images are pre-encoded as VAE latents. + text_precached: Whether text is pre-encoded as embeddings. + num_train_samples: Number of training samples. + """ + + path: str = "" + seq_length: int = 1024 + packing_buffer_size: int = None + micro_batch_size: int = 1 + global_batch_size: int = 4 + num_workers: int = 8 + dataloader_type: str = "external" + + # Image dimensions + image_H: int = 1024 + image_W: int = 1024 + + # VAE settings + vae_channels: int = 16 + vae_scale_factor: int = 8 + + # Text embedding settings + prompt_seq_len: int = 512 + context_dim: int = 4096 + pooled_prompt_dim: int = 768 + + # Precaching settings (FLUX typically uses precached data) + image_precached: bool = True + text_precached: bool = True + + # Dataset size + num_train_samples: int = 10000 + + def __post_init__(self): + """Initialize the mock dataset and dataloader.""" + mock_ds = _MockT2IDataset( + image_H=self.image_H, + image_W=self.image_W, + length=self.num_train_samples, + image_precached=self.image_precached, + text_precached=self.text_precached, + prompt_seq_len=self.prompt_seq_len, + pooled_prompt_dim=self.pooled_prompt_dim, + context_dim=self.context_dim, + vae_scale_factor=self.vae_scale_factor, + vae_channels=self.vae_channels, + ) + + kwargs = {} + if self.num_workers > 0: + kwargs["prefetch_factor"] = 8 + kwargs["persistent_workers"] = True + + self._train_dl = DataLoader( + mock_ds, + batch_size=self.micro_batch_size, + num_workers=self.num_workers, + collate_fn=_collate_fn, + shuffle=True, + drop_last=True, + pin_memory=True, + **kwargs, + ) + self._train_dl_iter = iter(self._train_dl) + self.sequence_length = self.seq_length + + def build_datasets(self, _context: DatasetBuildContext): + """Build and return train/val/test dataloaders.""" + # Return iterator for external dataloader type + return self._train_dl_iter, self._train_dl_iter, self._train_dl_iter diff --git a/src/megatron/bridge/diffusion/data/flux/flux_taskencoder.py b/src/megatron/bridge/diffusion/data/flux/flux_taskencoder.py new file mode 100644 index 0000000000..43b4f47186 --- /dev/null +++ b/src/megatron/bridge/diffusion/data/flux/flux_taskencoder.py @@ -0,0 +1,297 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +from typing import List + +import torch +import torch.nn.functional as F +from megatron.core import parallel_state +from megatron.energon import SkipSample +from megatron.energon.task_encoder.base import stateless +from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys + +from megatron.bridge.diffusion.data.common.diffusion_sample import DiffusionSample +from megatron.bridge.diffusion.data.common.diffusion_task_encoder_with_sp import ( + DiffusionTaskEncoderWithSequencePacking, +) + + +def cook(sample: dict) -> dict: + """ + Processes a raw sample dictionary from energon dataset and returns a new dictionary with specific keys. + + Args: + sample (dict): The input dictionary containing the raw sample data. + + Returns: + dict: A new dictionary containing the processed sample data with the following keys: + - All keys from the result of `basic_sample_keys(sample)` + - 'json': The contains meta data like resolution, etc. + - 'pth': contains image latent tensor + - 'pickle': contains text embeddings (T5 and CLIP pooled) + """ + return dict( + **basic_sample_keys(sample), + json=sample["json"], + pth=sample["pth"], + pickle=sample["pickle"], + ) + + +class FluxTaskEncoder(DiffusionTaskEncoderWithSequencePacking): + """ + Task encoder for Flux dataset. + Attributes: + cookers (list): A list of Cooker objects used for processing. + vae_scale_factor (int): The VAE downsampling factor. Defaults to 8. + seq_length (int): The sequence length. Defaults to 1024. + latent_channels (int): Number of latent channels from VAE. Defaults to 16. + """ + + cookers = [ + Cooker(cook), + ] + + def __init__( + self, + *args, + vae_scale_factor: int = 8, + seq_length: int = 1024, + latent_channels: int = 16, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.vae_scale_factor = vae_scale_factor + self.seq_length = seq_length + self.latent_channels = latent_channels + + @stateless(restore_seeds=True) + def encode_sample(self, sample: dict) -> dict: + image_latent = sample["pth"] + text_embeddings = sample["pickle"] + image_metadata = sample["json"] + + # sanity quality check + if torch.isnan(image_latent).any() or torch.isinf(image_latent).any(): + raise SkipSample() + if torch.max(torch.abs(image_latent)) > 1e3: + raise SkipSample() + + # image_latent shape: [C, H, W] + # Keep latents unpacked - flux_step will pack them during forward pass + C, H, W = image_latent.shape + + # Extract T5 embeddings and CLIP pooled embeddings + # text_embeddings is expected to be a dict with keys: + # - 'prompt_embeds': T5 embeddings [text_seq_len, context_dim] + # - 'pooled_prompt_embeds': CLIP pooled embeddings [pooled_dim] + if isinstance(text_embeddings, dict): + prompt_embeds = text_embeddings.get("prompt_embeds", text_embeddings.get("t5_embeds")) + pooled_prompt_embeds = text_embeddings.get("pooled_prompt_embeds", text_embeddings.get("clip_embeds")) + + # Ensure pooled_prompt_embeds is not None + if pooled_prompt_embeds is None: + pooled_prompt_embeds = torch.zeros(768, dtype=torch.bfloat16) + else: + # If it's a single tensor, assume it's T5 embeddings + prompt_embeds = text_embeddings + pooled_prompt_embeds = torch.zeros(768, dtype=torch.bfloat16) # Default CLIP dim + + # pad text embeddings to fixed length + text_max_len = 512 + if prompt_embeds.shape[0] < text_max_len: + prompt_embeds = F.pad(prompt_embeds, (0, 0, 0, text_max_len - prompt_embeds.shape[0])) + else: + prompt_embeds = prompt_embeds[:text_max_len] + + # calculate sequence lengths + # For flux, seq_len_q is the number of patches after packing: (H/2)*(W/2) + seq_len_q = (H // 2) * (W // 2) + seq_len_kv = prompt_embeds.shape[0] # text_seq_len + + # loss mask - covers all latent positions + loss_mask = torch.ones(seq_len_q, dtype=torch.bfloat16) + + # CAVEAT: + # when using context parallelism, we need to pad batch sequence length to be divisible by [cp_rank*2] + # (because TransformerEngine's context parallelism requires "AssertionError: Sequence length per GPU needs to be divisible by 2!") + if parallel_state.get_context_parallel_world_size() > 1: + sharding_factor = parallel_state.get_context_parallel_world_size() * 2 + seq_len_q_padded = ((seq_len_q + sharding_factor - 1) // sharding_factor) * sharding_factor + seq_len_kv_padded = ((seq_len_kv + sharding_factor - 1) // sharding_factor) * sharding_factor + else: + seq_len_q_padded = seq_len_q + seq_len_kv_padded = seq_len_kv + + # padding + if seq_len_q < seq_len_q_padded: + # Note: For unpacked latents [C, H, W], we need to pad H and W dimensions + # But since we're padding sequence length, we pad the loss_mask only + # The latent padding will be handled during packing in flux_step + loss_mask = F.pad(loss_mask, (0, seq_len_q_padded - seq_len_q)) + if seq_len_kv < seq_len_kv_padded: + prompt_embeds = F.pad(prompt_embeds, (0, 0, 0, seq_len_kv_padded - seq_len_kv)) + + ### Note: shape of sample's values + # image_latent: [C, H, W] - unpacked format + # latent_shape: [H, W] + # prompt_embeds: [text_seq_len, text_embedding_dim] + # pooled_prompt_embeds: [pooled_dim] + # text_ids: [text_seq_len, 3] + + # Prepare text IDs for position encoding + text_ids = torch.zeros(prompt_embeds.shape[0], 3, dtype=torch.bfloat16) + text_ids[:, 0] = torch.arange(prompt_embeds.shape[0], dtype=torch.bfloat16) + + # Store pooled embeddings and text_ids in metadata + metadata = { + "pooled_prompt_embeds": pooled_prompt_embeds, + "text_ids": text_ids, + "original_metadata": image_metadata, + } + + return DiffusionSample( + __key__=sample["__key__"], + __restore_key__=sample["__restore_key__"], + __subflavor__=None, + __subflavors__=sample["__subflavors__"], + video=image_latent, # Store unpacked latents [C, H, W] + context_embeddings=prompt_embeds, + latent_shape=torch.tensor([H, W], dtype=torch.int32), + loss_mask=loss_mask, + seq_len_q=torch.tensor([seq_len_q], dtype=torch.int32), + seq_len_q_padded=torch.tensor([seq_len_q_padded], dtype=torch.int32), + seq_len_kv=torch.tensor([seq_len_kv], dtype=torch.int32), + seq_len_kv_padded=torch.tensor([seq_len_kv_padded], dtype=torch.int32), + pos_ids=torch.zeros(1, dtype=torch.bfloat16), # dummy pos_ids + video_metadata=metadata, + ) + + # NOTE: + # the method select_samples_to_pack() and pack_selected_samples() are inherited from the parent + # class DiffusionTaskEncoderWithSequencePacking + + @stateless + def batch(self, samples: List[DiffusionSample]) -> dict: + """Return dictionary with data for batch.""" + + # Helper function to extract metadata + def extract_metadata(sample): + # Handle case where video_metadata is a list (from packed samples) + metadata = sample.video_metadata + if isinstance(metadata, list): + metadata = metadata[0] if len(metadata) > 0 else {} + + if isinstance(metadata, dict) and "pooled_prompt_embeds" in metadata: + pooled = metadata["pooled_prompt_embeds"] + text_ids = metadata.get("text_ids", torch.zeros(512, 3, dtype=torch.bfloat16)) + orig_metadata = metadata.get("original_metadata", metadata) + + return (pooled, text_ids, orig_metadata) + else: + raise ValueError("Expected 'pooled_prompt_embeds' in metadata.") + + if self.packing_buffer_size is None: + # No packing - batch multiple samples + latents_list = [] + prompt_embeds_list = [] + pooled_embeds_list = [] + text_ids_list = [] + loss_mask_list = [] + seq_len_q_list = [] + seq_len_q_padded_list = [] + seq_len_kv_list = [] + seq_len_kv_padded_list = [] + latent_shape_list = [] + metadata_list = [] + + for sample in samples: + pooled, text_ids, metadata = extract_metadata(sample) + + latents_list.append(sample.video) + prompt_embeds_list.append(sample.context_embeddings) + pooled_embeds_list.append(pooled) + text_ids_list.append(text_ids) + loss_mask_list.append(sample.loss_mask if sample.loss_mask is not None else torch.ones(1)) + seq_len_q_list.append(sample.seq_len_q) + seq_len_q_padded_list.append(sample.seq_len_q_padded) + seq_len_kv_list.append(sample.seq_len_kv) + seq_len_kv_padded_list.append(sample.seq_len_kv_padded) + latent_shape_list.append(sample.latent_shape) + metadata_list.append(metadata) + + return dict( + latents=torch.stack(latents_list), + prompt_embeds=torch.stack(prompt_embeds_list), + pooled_prompt_embeds=torch.stack(pooled_embeds_list), + text_ids=torch.stack(text_ids_list), + loss_mask=torch.stack(loss_mask_list), + seq_len_q=torch.cat(seq_len_q_list), + seq_len_q_padded=torch.cat(seq_len_q_padded_list), + seq_len_kv=torch.cat(seq_len_kv_list), + seq_len_kv_padded=torch.cat(seq_len_kv_padded_list), + latent_shape=torch.stack(latent_shape_list), + image_metadata=metadata_list, + ) + + # Packing case - single packed sample + sample = samples[0] + pooled_prompt_embeds, text_ids, image_metadata = extract_metadata(sample) + + # Stack to create batch dimension + # sample.video has shape [C, H, W] -> unsqueeze to [1, C, H, W] for batch + latents = sample.video.unsqueeze(0) # [1, C, H, W] + + # Prompt embeds: [text_seq_len, D] -> [1, text_seq_len, D] for batch + prompt_embeds = sample.context_embeddings.unsqueeze(0) # [1, text_seq_len, D] + + # Pooled embeds: [pooled_dim] -> [1, pooled_dim] for batch + pooled_prompt_embeds = pooled_prompt_embeds.unsqueeze(0) # [1, pooled_dim] + + # Text IDs: [text_seq_len, 3] -> [1, text_seq_len, 3] for batch + text_ids = text_ids.unsqueeze(0) # [1, text_seq_len, 3] + + # Loss mask: [seq_len_q] -> [1, seq_len_q] for batch + loss_mask = sample.loss_mask.unsqueeze(0) if sample.loss_mask is not None else None + + batch = dict( + latents=latents, # [1, C, H, W] - unpacked format + prompt_embeds=prompt_embeds, # [1, text_seq_len, D] + pooled_prompt_embeds=pooled_prompt_embeds, # [1, pooled_dim] + text_ids=text_ids, # [1, text_seq_len, 3] + loss_mask=loss_mask, # [1, seq_len_q] + seq_len_q=sample.seq_len_q, + seq_len_q_padded=sample.seq_len_q_padded, + seq_len_kv=sample.seq_len_kv, + seq_len_kv_padded=sample.seq_len_kv_padded, + latent_shape=sample.latent_shape, # [H, W] + image_metadata=image_metadata, + ) + + ### Note: shape of batch's values (with packing_buffer_size, batch size is 1) + # latents: [1, C, H, W] - unpacked format + # prompt_embeds: [1, text_seq_len, D] + # pooled_prompt_embeds: [1, pooled_dim] + # text_ids: [1, text_seq_len, 3] + # loss_mask: [1, seq_len_q] where seq_len_q = (H/2)*(W/2) + # seq_len_q: [num_samples] + # seq_len_q_padded: [num_samples] + # seq_len_kv: [num_samples] + # seq_len_kv_padded: [num_samples] + # latent_shape: [num_samples, 2] + # image_metadata: [num_samples] + + return batch diff --git a/src/megatron/bridge/diffusion/data/wan/__init__.py b/src/megatron/bridge/diffusion/data/wan/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/megatron/bridge/diffusion/data/wan/wan_energon_datamodule.py b/src/megatron/bridge/diffusion/data/wan/wan_energon_datamodule.py new file mode 100644 index 0000000000..e98cd1cb00 --- /dev/null +++ b/src/megatron/bridge/diffusion/data/wan/wan_energon_datamodule.py @@ -0,0 +1,52 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +from dataclasses import dataclass + +from torch import int_repr + +from megatron.bridge.data.utils import DatasetBuildContext +from megatron.bridge.diffusion.data.common.diffusion_energon_datamodule import ( + DiffusionDataModule, + DiffusionDataModuleConfig, +) +from megatron.bridge.diffusion.data.wan.wan_taskencoder import WanTaskEncoder + + +@dataclass(kw_only=True) +class WanDataModuleConfig(DiffusionDataModuleConfig): # noqa: D101 + path: str + seq_length: int + packing_buffer_size: int + micro_batch_size: int + global_batch_size: int + num_workers: int_repr + dataloader_type: str = "external" + + def __post_init__(self): + self.dataset = DiffusionDataModule( + path=self.path, + seq_length=self.seq_length, + packing_buffer_size=self.packing_buffer_size, + task_encoder=WanTaskEncoder(seq_length=self.seq_length, packing_buffer_size=self.packing_buffer_size), + micro_batch_size=self.micro_batch_size, + global_batch_size=self.global_batch_size, + num_workers=self.num_workers, + ) + self.sequence_length = self.dataset.seq_length + + def build_datasets(self, context: DatasetBuildContext): + return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader() diff --git a/src/megatron/bridge/diffusion/data/wan/wan_mock_datamodule.py b/src/megatron/bridge/diffusion/data/wan/wan_mock_datamodule.py new file mode 100644 index 0000000000..0a55942387 --- /dev/null +++ b/src/megatron/bridge/diffusion/data/wan/wan_mock_datamodule.py @@ -0,0 +1,150 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +from dataclasses import dataclass + +import torch +from torch.utils.data import DataLoader, Dataset + +from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider +from megatron.bridge.diffusion.models.wan.utils import patchify + + +class _MockDataset(Dataset): + def __init__(self, length: int): + self.length = max(int(length), 1) + + def __len__(self) -> int: + return self.length + + def __getitem__(self, idx: int) -> dict: + return {} + + +def mock_batch( # noqa: D103 + F_latents: int, + H_latents: int, + W_latents: int, + patch_temporal: int, + patch_spatial: int, + number_packed_samples: int, + context_seq_len: int, + context_embeddings_dim: int, +) -> dict: + # set mock values for one video sample + video_latent = torch.randn(16, F_latents, H_latents, W_latents, dtype=torch.float32) + grid_size = torch.tensor( + [ + video_latent.shape[1] // patch_temporal, + video_latent.shape[2] // patch_spatial, + video_latent.shape[3] // patch_spatial, + ], + dtype=torch.int32, + ) + video_latent = patchify([video_latent], (patch_temporal, patch_spatial, patch_spatial))[0] + video_latent = torch.as_tensor(video_latent, dtype=torch.float32) + seq_len_q = video_latent.shape[0] + seq_len_q_padded = seq_len_q + loss_mask = torch.ones(seq_len_q, dtype=torch.bfloat16) + context_embeddings = torch.randn(context_seq_len, context_embeddings_dim, dtype=torch.float32) + seq_len_kv = context_embeddings.shape[0] + seq_len_kv_padded = seq_len_kv + video_metadata = {} + + # set mock values for packed video samples + video_latents_packed = [video_latent for _ in range(number_packed_samples)] + video_latents_packed = torch.cat(video_latents_packed, dim=0) + loss_masks_packed = [loss_mask for _ in range(number_packed_samples)] + loss_masks_packed = torch.cat(loss_masks_packed, dim=0) + seq_len_q_packed = torch.tensor([seq_len_q for _ in range(number_packed_samples)], dtype=torch.int32) + seq_len_q_padded_packed = torch.tensor([seq_len_q_padded for _ in range(number_packed_samples)], dtype=torch.int32) + seq_len_kv_packed = torch.tensor([seq_len_kv for _ in range(number_packed_samples)], dtype=torch.int32) + seq_len_kv_padded_packed = torch.tensor( + [seq_len_kv_padded for _ in range(number_packed_samples)], dtype=torch.int32 + ) + grid_sizes_packed = torch.stack([grid_size for _ in range(number_packed_samples)], dim=0) + context_embeddings_packed = [context_embeddings for _ in range(number_packed_samples)] + context_embeddings_packed = torch.cat(context_embeddings_packed, dim=0) + + ### Note: shape of sample's values + # video_latent: [num_patches, latents_channels * pF * pH * pW] + # grid_size: [F_patches, W_patches, H_patches] + # context_embeddings: [context_seq_len, text_embedding_dim] + + batch = dict( + video_latents=video_latents_packed.unsqueeze(1), + context_embeddings=context_embeddings_packed.unsqueeze(1), + loss_mask=loss_masks_packed.unsqueeze(1), + seq_len_q=seq_len_q_packed, + seq_len_q_padded=seq_len_q_padded_packed, + seq_len_kv=seq_len_kv_packed, + seq_len_kv_padded=seq_len_kv_padded_packed, + grid_sizes=grid_sizes_packed, + video_metadata=video_metadata, + ) + + return batch + + +@dataclass(kw_only=True) +class WanMockDataModuleConfig(DatasetProvider): # noqa: D101 + path: str = "" + seq_length: int + packing_buffer_size: int + micro_batch_size: int + global_batch_size: int + num_workers: int + dataloader_type: str = "external" + F_latents: int = 24 + H_latents: int = 104 + W_latents: int = 60 + patch_spatial: int = 2 + patch_temporal: int = 1 + number_packed_samples: int = 1 + context_seq_len: int = 512 + context_embeddings_dim: int = 4096 + + def __post_init__(self): + mock_ds = _MockDataset(length=1024) + kwargs = {} + if self.num_workers > 0: + kwargs["prefetch_factor"] = 8 + self._train_dl = DataLoader( + mock_ds, + batch_size=self.micro_batch_size, + num_workers=self.num_workers, + collate_fn=lambda samples: mock_batch( + F_latents=self.F_latents, + H_latents=self.H_latents, + W_latents=self.W_latents, + patch_temporal=self.patch_temporal, + patch_spatial=self.patch_spatial, + number_packed_samples=self.number_packed_samples, + context_seq_len=self.context_seq_len, + context_embeddings_dim=self.context_embeddings_dim, + ), + shuffle=False, + drop_last=False, + pin_memory=True, + **kwargs, + ) + self._train_dl = iter(self._train_dl) + self.sequence_length = self.seq_length + + def build_datasets(self, _context: DatasetBuildContext): + if hasattr(self, "dataset"): + return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader() + return self._train_dl, self._train_dl, self._train_dl diff --git a/src/megatron/bridge/diffusion/data/wan/wan_taskencoder.py b/src/megatron/bridge/diffusion/data/wan/wan_taskencoder.py new file mode 100644 index 0000000000..bf3e0b63ec --- /dev/null +++ b/src/megatron/bridge/diffusion/data/wan/wan_taskencoder.py @@ -0,0 +1,204 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +from typing import List + +import torch +import torch.nn.functional as F +from megatron.core import parallel_state +from megatron.energon import SkipSample +from megatron.energon.task_encoder.base import stateless +from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys + +from megatron.bridge.diffusion.data.common.diffusion_sample import DiffusionSample +from megatron.bridge.diffusion.data.common.diffusion_task_encoder_with_sp import ( + DiffusionTaskEncoderWithSequencePacking, +) +from megatron.bridge.diffusion.models.wan.utils import grid_sizes_calculation, patchify + + +def cook(sample: dict) -> dict: + """ + Processes a raw sample dictionary from energon dataset and returns a new dictionary with specific keys. + + Args: + sample (dict): The input dictionary containing the raw sample data. + + Returns: + dict: A new dictionary containing the processed sample data with the following keys: + - All keys from the result of `basic_sample_keys(sample)` + - 'json': The contains meta data like resolution, aspect ratio, fps, etc. + - 'pth': contains video latent tensor + - 'pickle': contains text embeddings + """ + return dict( + **basic_sample_keys(sample), + json=sample["json"], + pth=sample["pth"], + pickle=sample["pickle"], + ) + + +class WanTaskEncoder(DiffusionTaskEncoderWithSequencePacking): + """ + Task encoder for Wan dataset. + Attributes: + cookers (list): A list of Cooker objects used for processing. + patch_spatial (int): The spatial patch size. Defaults to 2. + patch_temporal (int): The temporal patch size. Defaults to 1. + seq_length (int): The sequence length. Defaults to 1024. + """ + + cookers = [ + Cooker(cook), + ] + + def __init__( + self, + *args, + max_frames: int = None, + patch_spatial: int = 2, + patch_temporal: int = 1, + seq_length: int = 1024, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.patch_spatial = patch_spatial + self.patch_temporal = patch_temporal + self.seq_length = seq_length + + @stateless(restore_seeds=True) + def encode_sample(self, sample: dict) -> dict: + video_latent = sample["pth"] + context_embeddings = sample["pickle"] + video_metadata = sample["json"] + + # sanity quality check + if torch.isnan(video_latent).any() or torch.isinf(video_latent).any(): + raise SkipSample() + if torch.max(torch.abs(video_latent)) > 1e3: + raise SkipSample() + + # calculate grid size + grid_size = grid_sizes_calculation( + input_shape=video_latent.shape[1:], + patch_size=(self.patch_temporal, self.patch_spatial, self.patch_spatial), + ) + + # patchify video_latent + video_latent = patchify([video_latent], (self.patch_temporal, self.patch_spatial, self.patch_spatial))[0] + + # process text embeddings + # pad here for text embeddings + context_max_len = 512 + context_embeddings = F.pad(context_embeddings, (0, 0, 0, context_max_len - context_embeddings.shape[0])) + + # calculate sequence length + seq_len_q = video_latent.shape[0] + seq_len_kv = context_embeddings.shape[0] + + # loss mask + loss_mask = torch.ones(seq_len_q, dtype=torch.bfloat16) + + # CAVEAT: + # when using context parallelism, we need to pad batch sequence length to be divisible by [cp_rank*2] + # (because TransformerEngine's context parallelism requires "AssertionError: Sequence length per GPU needs to be divisible by 2!") + if parallel_state.get_context_parallel_world_size() > 1: + sharding_factor = parallel_state.get_context_parallel_world_size() * 2 + seq_len_q_padded = ((seq_len_q + sharding_factor - 1) // sharding_factor) * sharding_factor + seq_len_kv_padded = ((seq_len_kv + sharding_factor - 1) // sharding_factor) * sharding_factor + else: + seq_len_q_padded = seq_len_q + seq_len_kv_padded = seq_len_kv + + # padding + if seq_len_q < seq_len_q_padded: + video_latent = F.pad(video_latent, (0, 0, 0, seq_len_q_padded - seq_len_q)) + loss_mask = F.pad(loss_mask, (0, seq_len_q_padded - seq_len_q)) + context_embeddings = F.pad(context_embeddings, (0, 0, 0, seq_len_kv_padded - seq_len_kv)) + + ### Note: shape of sample's values + # video_latent: [num_patches, latents_channels * pF * pH * pW] + # grid_size: [F_patches, W_patches, H_patches] + # context_embeddings: [context_seq_len, text_embedding_dim] + + return DiffusionSample( + __key__=sample["__key__"], + __restore_key__=sample["__restore_key__"], + __subflavor__=None, + __subflavors__=sample["__subflavors__"], + video=video_latent, + context_embeddings=context_embeddings, + latent_shape=torch.tensor(grid_size, dtype=torch.int32), + loss_mask=loss_mask, + seq_len_q=torch.tensor([seq_len_q], dtype=torch.int32), + seq_len_q_padded=torch.tensor([seq_len_q_padded], dtype=torch.int32), + seq_len_kv=torch.tensor([seq_len_kv], dtype=torch.int32), + seq_len_kv_padded=torch.tensor([seq_len_kv_padded], dtype=torch.int32), + pos_ids=torch.zeros(1, dtype=torch.bfloat16), # dummy pos_ids + video_metadata=video_metadata, + ) + + # NOTE: + # the method select_samples_to_pack() and pack_selected_samples() are inherited from the parent + # class DiffusionTaskEncoderWithSequencePacking + + @stateless + def batch(self, samples: List[DiffusionSample]) -> dict: + """Return dictionary with data for batch.""" + if self.packing_buffer_size is None: + # no packing + return super().batch(samples).to_dict() + + # packing + sample = samples[0] + + # # CAVEAT: + # # when using pipeline parallelism, we need to set batch sequence length to DataModule's seq_length because + # # because pipeline parallelism requires pre-specified sequence length to create buffer + # if parallel_state.get_pipeline_model_parallel_world_size() > 1: + # if sample.video.shape[0] > self.seq_length: + # raise ValueError( + # f"video sequence length {sample.video.shape[0]} is greater than DataModule's seq_length {self.seq_length}" + # ) + # else: + # # set max_video_seq_len to DataModule's seq_length + # padded_seq_len = self.seq_length + + batch = dict( + video_latents=sample.video.unsqueeze(1), + context_embeddings=sample.context_embeddings.unsqueeze(1), + loss_mask=sample.loss_mask.unsqueeze(1) if sample.loss_mask is not None else None, + seq_len_q=sample.seq_len_q, + seq_len_q_padded=sample.seq_len_q_padded, + seq_len_kv=sample.seq_len_kv, + seq_len_kv_padded=sample.seq_len_kv_padded, + grid_sizes=sample.latent_shape, + video_metadata=sample.video_metadata, + ) + + ### Note: shape of batch's values + # video_latents: [seq_len, 1, latents_channels * pF * pH * pW] + # context_embeddings: [seq_len, 1, text_embedding_dim] + # loss_mask: [seq_len, 1] + # seq_len_q: [num_samples] + # seq_len_q_padded: [num_samples] + # seq_len_kv: [num_samples] + # seq_len_kv_padded: [num_samples] + # grid_sizes: [num_samples, 3] + # video_metadata: [num_samples] + + return batch diff --git a/src/megatron/bridge/diffusion/models/README.md b/src/megatron/bridge/diffusion/models/README.md new file mode 100644 index 0000000000..caeaad742f --- /dev/null +++ b/src/megatron/bridge/diffusion/models/README.md @@ -0,0 +1,3 @@ +# Model + +Model implementations for DiT and dLLM architectures. diff --git a/src/megatron/bridge/diffusion/models/__init__.py b/src/megatron/bridge/diffusion/models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/megatron/bridge/diffusion/models/common/__init__.py b/src/megatron/bridge/diffusion/models/common/__init__.py new file mode 100644 index 0000000000..d3dc45b130 --- /dev/null +++ b/src/megatron/bridge/diffusion/models/common/__init__.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common modules for diffusion models in DFM.""" + +from megatron.bridge.diffusion.models.common.dit_attention import ( + DiTCrossAttention, + DiTCrossAttentionSubmodules, + DiTSelfAttention, +) +from megatron.bridge.diffusion.models.common.dit_embeddings import ( + FactorizedLearnable3DEmbedding, + ParallelTimestepEmbedding, + SinCosPosEmb3D, +) +from megatron.bridge.diffusion.models.common.normalization import RMSNorm + + +__all__ = [ + # Attention modules + "DiTCrossAttention", + "DiTCrossAttentionSubmodules", + "DiTSelfAttention", + # Embeddings + "FactorizedLearnable3DEmbedding", + "ParallelTimestepEmbedding", + "SinCosPosEmb3D", + # Normalization + "RMSNorm", +] diff --git a/src/megatron/bridge/diffusion/models/common/dit_attention.py b/src/megatron/bridge/diffusion/models/common/dit_attention.py new file mode 100644 index 0000000000..ef323d32c2 --- /dev/null +++ b/src/megatron/bridge/diffusion/models/common/dit_attention.py @@ -0,0 +1,302 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +import copy +from dataclasses import dataclass +from typing import Union + +import torch +from megatron.core import parallel_state, tensor_parallel +from megatron.core.extensions.transformer_engine import SplitAlongDim +from megatron.core.transformer.attention import ( + CrossAttention, + SelfAttention, + SelfAttentionSubmodules, +) +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig + + +@dataclass +class DiTCrossAttentionSubmodules: + """ + Configuration class for specifying the submodules of a cross-attention. + """ + + linear_q: Union[ModuleSpec, type] = None + linear_kv: Union[ModuleSpec, type] = None + core_attention: Union[ModuleSpec, type] = None + linear_proj: Union[ModuleSpec, type] = None + q_layernorm: Union[ModuleSpec, type] = None + k_layernorm: Union[ModuleSpec, type] = None + + +class DiTSelfAttention(SelfAttention): # noqa: D101 + def __init__( + self, + config: TransformerConfig, + submodules: SelfAttentionSubmodules, + layer_number: int, + attn_mask_type: AttnMaskType, + cp_comm_type: str = None, + pg_collection=None, + ): + super().__init__( + config, + submodules, + layer_number, + attn_mask_type, + cp_comm_type, + pg_collection, + ) + + self.layernorm_across_heads = getattr(self.config, "layernorm_across_heads", False) + + # override q_layernorm + if submodules.q_layernorm is not None: + if self.layernorm_across_heads: + q_layernorm_size = self.query_projection_size + else: + q_layernorm_size = self.hidden_size_per_attention_head + norm_config = copy.deepcopy(self.config) + norm_config.normalization = "RMSNorm" + self.q_layernorm = build_module( + submodules.q_layernorm, + eps=norm_config.layernorm_epsilon, + hidden_size=q_layernorm_size, + config=norm_config, + ) + else: + self.q_layernorm = None + + # override k_layernorm + if submodules.k_layernorm is not None: + if self.layernorm_across_heads: + k_layernorm_size = self.kv_projection_size + else: + k_layernorm_size = self.hidden_size_per_attention_head + norm_config = copy.deepcopy(self.config) + norm_config.normalization = "RMSNorm" + self.k_layernorm = build_module( + submodules.k_layernorm, + eps=norm_config.layernorm_epsilon, + hidden_size=k_layernorm_size, + config=norm_config, + ) + else: + self.k_layernorm = None + + def get_query_key_value_tensors(self, hidden_states, key_value_states=None, output_gate=None, split_qkv=True): + """ + Derives `query`, `key` and `value` tensors from `hidden_states`. + """ + # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)] + mixed_qkv, _ = self.linear_qkv(hidden_states) + + # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn] + new_tensor_shape = mixed_qkv.size()[:-1] + ( + self.num_query_groups_per_partition, + ( + (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2) + * self.hidden_size_per_attention_head + ), + ) + mixed_qkv = mixed_qkv.view(*new_tensor_shape) + + split_arg_list = [ + ( + self.num_attention_heads_per_partition + // self.num_query_groups_per_partition + * self.hidden_size_per_attention_head + ), + self.hidden_size_per_attention_head, + self.hidden_size_per_attention_head, + ] + + if SplitAlongDim is not None: + # [sq, b, ng, (np/ng + 2) * hn] + # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] + (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list) + else: + # [sq, b, ng, (np/ng + 2) * hn] + # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] + (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3) + + # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] + query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) + + # gather query and key heads across TP ranks if self.layernorm_across_heads is True + if self.layernorm_across_heads and parallel_state.get_tensor_model_parallel_world_size() > 1: + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = tensor_parallel.gather_from_tensor_model_parallel_region(query) + key = tensor_parallel.gather_from_tensor_model_parallel_region(key) + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + + if self.q_layernorm is not None: + if self.layernorm_across_heads: + q_flat = query.reshape(query.size(0), query.size(1), -1).contiguous() # [sq, b, np*hn] + q_flat = self.q_layernorm(q_flat) + query = q_flat.view( + query.size(0), query.size(1), -1, self.hidden_size_per_attention_head + ) # [sq, b, np, hn] + else: + query = self.q_layernorm(query.contiguous()) + + if self.k_layernorm is not None: + if self.layernorm_across_heads: + k_flat = key.reshape(key.size(0), key.size(1), -1).contiguous() + k_flat = self.k_layernorm(k_flat) + key = k_flat.view(key.size(0), key.size(1), -1, self.hidden_size_per_attention_head) + else: + key = self.k_layernorm(key.contiguous()) + + # scatter query and key heads across TP ranks if self.layernorm_across_heads is True + if self.layernorm_across_heads and parallel_state.get_tensor_model_parallel_world_size() > 1: + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = tensor_parallel.scatter_to_tensor_model_parallel_region(query) + key = tensor_parallel.scatter_to_tensor_model_parallel_region(key) + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = query.contiguous() # important becuase TE attention expects contiguous tensors + key = key.contiguous() # important becuase TE attention expects contiguous tensors + + if self.config.test_mode: + self.run_realtime_tests() + + return query, key, value + + +class DiTCrossAttention(CrossAttention): # noqa: D101 + def __init__( + self, + config: TransformerConfig, + submodules: DiTCrossAttentionSubmodules, + layer_number: int, + attn_mask_type: AttnMaskType, + cp_comm_type: str = None, + pg_collection=None, + ): + super().__init__( + config, + submodules, + layer_number, + attn_mask_type, + cp_comm_type, + pg_collection, + ) + + self.layernorm_across_heads = getattr(self.config, "layernorm_across_heads", False) + + # override q_layernorm + if submodules.q_layernorm is not None: + if self.layernorm_across_heads: + q_layernorm_size = self.query_projection_size + else: + q_layernorm_size = self.hidden_size_per_attention_head + norm_config = copy.deepcopy(self.config) + norm_config.normalization = "RMSNorm" + self.q_layernorm = build_module( + submodules.q_layernorm, + eps=norm_config.layernorm_epsilon, + hidden_size=q_layernorm_size, + config=norm_config, + ) + else: + self.q_layernorm = None + + # override k_layernorm + if submodules.k_layernorm is not None: + if self.layernorm_across_heads: + k_layernorm_size = self.kv_projection_size + else: + k_layernorm_size = self.hidden_size_per_attention_head + norm_config = copy.deepcopy(self.config) + norm_config.normalization = "RMSNorm" + self.k_layernorm = build_module( + submodules.k_layernorm, + eps=norm_config.layernorm_epsilon, + hidden_size=k_layernorm_size, + config=norm_config, + ) + else: + self.k_layernorm = None + + linear_kv_hidden_size = getattr(self.config, "crossattn_emb_size", self.config.hidden_size) + self.linear_kv = build_module( + submodules.linear_kv, + linear_kv_hidden_size, + 2 * self.kv_projection_size, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=self.config.add_bias_linear, + skip_bias_add=False, + is_expert=False, + ) + + def get_query_key_value_tensors(self, hidden_states, key_value_states, output_gate=None, split_qkv=True): + """ + Derives `query` tensor from `hidden_states`, and `key`/`value` tensors + from `key_value_states`. + """ + + query, key, value = super().get_query_key_value_tensors( + hidden_states, key_value_states, output_gate=output_gate, split_qkv=split_qkv + ) + + # gather query and key heads across TP ranks if self.layernorm_across_heads is True + if self.layernorm_across_heads and parallel_state.get_tensor_model_parallel_world_size() > 1: + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = tensor_parallel.gather_from_tensor_model_parallel_region(query) + key = tensor_parallel.gather_from_tensor_model_parallel_region(key) + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + + if self.q_layernorm is not None: + if self.layernorm_across_heads: + q_flat = query.reshape(query.size(0), query.size(1), -1).contiguous() # [sq, b, np*hn] + q_flat = self.q_layernorm(q_flat) + query = q_flat.view( + query.size(0), query.size(1), -1, self.hidden_size_per_attention_head + ) # [sq, b, np, hn] + else: + query = self.q_layernorm(query.contiguous()) + + if self.k_layernorm is not None: + if self.layernorm_across_heads: + k_flat = key.reshape(key.size(0), key.size(1), -1).contiguous() + k_flat = self.k_layernorm(k_flat) + key = k_flat.view(key.size(0), key.size(1), -1, self.hidden_size_per_attention_head) + else: + key = self.k_layernorm(key.contiguous()) + + # scatter query and key heads across TP ranks if self.layernorm_across_heads is True + if self.layernorm_across_heads and parallel_state.get_tensor_model_parallel_world_size() > 1: + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = tensor_parallel.scatter_to_tensor_model_parallel_region(query) + key = tensor_parallel.scatter_to_tensor_model_parallel_region(key) + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = query.contiguous() # important becuase TE attention expects contiguous tensors + key = key.contiguous() # important becuase TE attention expects contiguous tensors + + return query, key, value diff --git a/src/megatron/bridge/diffusion/models/common/dit_embeddings.py b/src/megatron/bridge/diffusion/models/common/dit_embeddings.py new file mode 100644 index 0000000000..04fdcec04d --- /dev/null +++ b/src/megatron/bridge/diffusion/models/common/dit_embeddings.py @@ -0,0 +1,164 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + + +import logging + +import torch +from diffusers.models.embeddings import TimestepEmbedding, get_3d_sincos_pos_embed +from einops import rearrange +from megatron.core import parallel_state +from megatron.core.transformer.module import MegatronModule + + +log = logging.getLogger(__name__) + + +# To be used from Common +class ParallelTimestepEmbedding(TimestepEmbedding): + """ + ParallelTimestepEmbedding is a subclass of TimestepEmbedding that initializes + the embedding layers with an optional random seed for syncronization. + + Args: + in_channels (int): Number of input channels. + time_embed_dim (int): Dimension of the time embedding. + seed (int, optional): Random seed for initializing the embedding layers. + If None, no specific seed is set. + + Attributes: + linear_1 (nn.Module): First linear layer for the embedding. + linear_2 (nn.Module): Second linear layer for the embedding. + + Methods: + __init__(in_channels, time_embed_dim, seed=None): Initializes the embedding layers. + """ + + def __init__(self, in_channels: int, time_embed_dim: int, seed=None): + super().__init__(in_channels=in_channels, time_embed_dim=time_embed_dim) + if seed is not None: + with torch.random.fork_rng(): + torch.manual_seed(seed) + self.linear_1.reset_parameters() + self.linear_2.reset_parameters() + + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + setattr(self.linear_1.weight, "pipeline_parallel", True) + setattr(self.linear_1.bias, "pipeline_parallel", True) + setattr(self.linear_2.weight, "pipeline_parallel", True) + setattr(self.linear_2.bias, "pipeline_parallel", True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Computes the positional embeddings for the input tensor. + + Args: + x (torch.Tensor): Input tensor of shape (B, T, H, W, C). + + Returns: + torch.Tensor: Positional embeddings of shape (B, T, H, W, C). + """ + return super().forward(x.to(torch.bfloat16, non_blocking=True)) + + +class SinCosPosEmb3D(MegatronModule): + """ + SinCosPosEmb3D is a 3D sine-cosine positional embedding module. + + Args: + model_channels (int): Number of channels in the model. + h (int): Length of the height dimension. + w (int): Length of the width dimension. + t (int): Length of the temporal dimension. + spatial_interpolation_scale (float, optional): Scale factor for spatial interpolation. Default is 1.0. + temporal_interpolation_scale (float, optional): Scale factor for temporal interpolation. Default is 1.0. + + Methods: + forward(pos_ids: torch.Tensor) -> torch.Tensor: + Computes the positional embeddings for the input tensor. + + Args: + pos_ids (torch.Tensor): Input tensor of shape (B S 3). + + Returns: + torch.Tensor: Positional embeddings of shape (B S D). + """ + + def __init__( + self, + config, + h: int, + w: int, + t: int, + spatial_interpolation_scale=1.0, + temporal_interpolation_scale=1.0, + ): + super().__init__(config=config) + self.h = h + self.w = w + self.t = t + # h w t + param = get_3d_sincos_pos_embed( + config.hidden_size, [h, w], t, spatial_interpolation_scale, temporal_interpolation_scale, output_type="pt" + ) + param = rearrange(param, "t hw c -> (t hw) c") + self.pos_embedding = torch.nn.Embedding(param.shape[0], config.hidden_size) + self.pos_embedding.weight = torch.nn.Parameter(torch.tensor(param), requires_grad=False) + + def forward(self, pos_ids: torch.Tensor): + # pos_ids: t h w + pos_id = pos_ids[..., 0] * self.h * self.w + pos_ids[..., 1] * self.w + pos_ids[..., 2] + return self.pos_embedding(pos_id) + + +class FactorizedLearnable3DEmbedding(MegatronModule): # noqa: D101 + def __init__( + self, + config, + t: int, + h: int, + w: int, + **kwargs, + ): + super().__init__(config=config) + self.emb_t = torch.nn.Embedding(t, config.hidden_size) + self.emb_h = torch.nn.Embedding(h, config.hidden_size) + self.emb_w = torch.nn.Embedding(w, config.hidden_size) + + if "seed" in kwargs.keys(): + seed = kwargs["seed"] + with torch.random.fork_rng(): + torch.manual_seed(seed) + if config.perform_initialization: + self.customize_init_param() + else: + self.reset_parameters() + else: + if config.perform_initialization: + self.customize_init_param() + + def customize_init_param(self): + self.config.init_method(self.emb_t.weight) + self.config.init_method(self.emb_h.weight) + self.config.init_method(self.emb_w.weight) + + def reset_parameters(self): + self.emb_t.reset_parameters() + self.emb_h.reset_parameters() + self.emb_w.reset_parameters() + + def forward(self, pos_ids: torch.Tensor): + return self.emb_t(pos_ids[..., 0]) + self.emb_h(pos_ids[..., 1]) + self.emb_w(pos_ids[..., 2]) diff --git a/src/megatron/bridge/diffusion/models/common/normalization.py b/src/megatron/bridge/diffusion/models/common/normalization.py new file mode 100644 index 0000000000..115972ddd0 --- /dev/null +++ b/src/megatron/bridge/diffusion/models/common/normalization.py @@ -0,0 +1,46 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common normalization modules for diffusion models.""" + +import torch +import torch.nn as nn + + +class RMSNorm(nn.Module): + """ + Root Mean Square Layer Normalization. + + A normalization technique that normalizes the input by its root mean square, + then scales by a learnable weight parameter. + + Args: + hidden_size: Size of the hidden dimension. + config: Transformer configuration (unused, for compatibility with megatron build_module). + eps: Small epsilon for numerical stability. + """ + + def __init__(self, hidden_size: int, config=None, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(hidden_size)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + # Compute normalization and weight scaling in float32 for numerical stability, + # then convert back to input dtype to preserve dtype throughout the model + output = self._norm(x.float()) * self.weight + return output.type_as(x) diff --git a/src/megatron/bridge/diffusion/models/flux/__init__.py b/src/megatron/bridge/diffusion/models/flux/__init__.py new file mode 100644 index 0000000000..c52268850d --- /dev/null +++ b/src/megatron/bridge/diffusion/models/flux/__init__.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +FLUX diffusion model implementation for DFM. + +This module provides the FLUX model architecture, which is a state-of-the-art +text-to-image diffusion model using MMDiT-style transformer blocks. + +Components: + - Flux: Main FLUX model class + - FluxProvider: Configuration and provider dataclass for FLUX models + - MMDiTLayer: Multi-modal DiT layer for double blocks + - FluxSingleTransformerBlock: Single transformer block for FLUX + - JointSelfAttention: Joint self-attention for MMDiT layers + - FluxSingleAttention: Self-attention for single blocks + - EmbedND: N-dimensional rotary position embedding + - MLPEmbedder: MLP embedding module + - TimeStepEmbedder: Timestep embedding module + - AdaLN: Adaptive Layer Normalization + - AdaLNContinuous: Continuous Adaptive Layer Normalization +""" + +from megatron.bridge.diffusion.models.common.normalization import RMSNorm +from megatron.bridge.diffusion.models.flux.flow_matching.flux_inference_pipeline import ( + ClipConfig, + FlowMatchEulerDiscreteScheduler, + FluxInferencePipeline, + T5Config, +) +from megatron.bridge.diffusion.models.flux.flux_attention import ( + FluxSingleAttention, + JointSelfAttention, + JointSelfAttentionSubmodules, +) +from megatron.bridge.diffusion.models.flux.flux_layer_spec import ( + AdaLN, + AdaLNContinuous, + FluxSingleTransformerBlock, + MMDiTLayer, + get_flux_double_transformer_engine_spec, + get_flux_single_transformer_engine_spec, +) +from megatron.bridge.diffusion.models.flux.flux_model import Flux +from megatron.bridge.diffusion.models.flux.flux_provider import FluxProvider +from megatron.bridge.diffusion.models.flux.layers import ( + EmbedND, + MLPEmbedder, + TimeStepEmbedder, + Timesteps, + rope, +) + + +__all__ = [ + # Main model + "Flux", + "FluxProvider", + # Transformer layers + "MMDiTLayer", + "FluxSingleTransformerBlock", + # Attention modules + "JointSelfAttention", + "JointSelfAttentionSubmodules", + "FluxSingleAttention", + # Normalization + "AdaLN", + "AdaLNContinuous", + "RMSNorm", + # Embeddings + "EmbedND", + "MLPEmbedder", + "TimeStepEmbedder", + "Timesteps", + "rope", + # Layer specs + "get_flux_double_transformer_engine_spec", + "get_flux_single_transformer_engine_spec", + # Inference pipeline + "FluxInferencePipeline", + "FlowMatchEulerDiscreteScheduler", + "T5Config", + "ClipConfig", + "AutoEncoderConfig", +] diff --git a/src/megatron/bridge/diffusion/models/flux/flow_matching/__init__.py b/src/megatron/bridge/diffusion/models/flux/flow_matching/__init__.py new file mode 100644 index 0000000000..abe9af832b --- /dev/null +++ b/src/megatron/bridge/diffusion/models/flux/flow_matching/__init__.py @@ -0,0 +1,38 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Flow matching components for FLUX model. + +This module contains the flow matching specific code, separated from the +model architecture to maintain consistency with other flow matching models +like WAN. +""" + +from megatron.bridge.diffusion.models.flux.flow_matching.flux_adapter import MegatronFluxAdapter +from megatron.bridge.diffusion.models.flux.flow_matching.flux_inference_pipeline import ( + ClipConfig, + FlowMatchEulerDiscreteScheduler, + FluxInferencePipeline, + T5Config, +) + + +__all__ = [ + "MegatronFluxAdapter", + "FluxInferencePipeline", + "FlowMatchEulerDiscreteScheduler", + "T5Config", + "ClipConfig", +] diff --git a/src/megatron/bridge/diffusion/models/flux/flow_matching/flux_adapter.py b/src/megatron/bridge/diffusion/models/flux/flow_matching/flux_adapter.py new file mode 100644 index 0000000000..4d3e636710 --- /dev/null +++ b/src/megatron/bridge/diffusion/models/flux/flow_matching/flux_adapter.py @@ -0,0 +1,222 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Megatron-specific adapter for FLUX models using the automodel FlowMatching pipeline. +""" + +import random +from typing import Any, Dict + +import torch +from dfm.src.automodel.flow_matching.adapters.base import FlowMatchingContext, ModelAdapter +from megatron.core.models.common.vision_module.vision_module import VisionModule + + +class MegatronFluxAdapter(ModelAdapter): + """ + Adapter for FLUX models in Megatron training framework. + + Key differences from standard FluxAdapter: + - Handles sequence-first tensor layout [S, B, ...] required by Megatron + - Integrates with pipeline parallelism + - Maps Megatron batch keys to expected format + - Handles guidance embedding for FLUX-dev models + """ + + def __init__(self, guidance_scale: float = 3.5): + """ + Initialize MegatronFluxAdapter. + + Args: + guidance_scale: Guidance scale for classifier-free guidance + """ + self.guidance_scale = guidance_scale + + def _pack_latents(self, latents: torch.Tensor) -> torch.Tensor: + """ + Pack latents from [B, C, H, W] to Flux format [B, (H//2)*(W//2), C*4]. + + Flux uses a 2x2 patch embedding, so latents are reshaped accordingly. + """ + b, c, h, w = latents.shape + latents = latents.view(b, c, h // 2, 2, w // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(b, (h // 2) * (w // 2), c * 4) + return latents + + def _unpack_latents(self, latents: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + Unpack latents from Flux format [B, num_patches, C*4] back to [B, C, H, W]. + + Args: + latents: Packed latents of shape [B, num_patches, channels] + height: Target latent height + width: Target latent width + """ + batch_size, num_patches, channels = latents.shape + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + latents = latents.reshape(batch_size, channels // 4, height, width) + return latents + + def _prepare_latent_image_ids( + self, + batch_size: int, + height: int, + width: int, + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor: + """ + Prepare positional IDs for image latents. + + Returns tensor of shape [B, (H//2)*(W//2), 3] containing (batch_idx, y, x). + """ + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = torch.arange(width // 2)[None, :] + + latent_image_ids = latent_image_ids.reshape(-1, 3) + latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1) + return latent_image_ids.to(device=device, dtype=dtype) + + def prepare_inputs(self, context: FlowMatchingContext) -> Dict[str, Any]: + """ + Prepare inputs for Megatron Flux model from FlowMatchingContext. + + Handles batch key mapping: + - Megatron uses: latents, prompt_embeds, pooled_prompt_embeds, text_ids + - Automodel expects: image_latents, text_embeddings, pooled_prompt_embeds + """ + batch = context.batch + device = context.device + dtype = context.dtype + + # Get model reference if passed in batch for guidance check + model = batch.get("_model") + + # Get latents - Megatron uses 'latents' key + noisy_latents = context.noisy_latents + if noisy_latents.ndim != 4: + raise ValueError(f"MegatronFluxAdapter expects 4D latents [B, C, H, W], got {noisy_latents.ndim}D") + + batch_size, channels, height, width = noisy_latents.shape + + # Get text embeddings - Megatron uses 'prompt_embeds' (T5) + if "prompt_embeds" in batch: + # Megatron stores as [S, B, D], need to transpose to [B, S, D] + text_embeddings = batch["prompt_embeds"] + if text_embeddings.shape[1] == batch_size: # Already [S, B, D] + text_embeddings = text_embeddings.transpose(0, 1).to(device, dtype=dtype) + else: + text_embeddings = text_embeddings.to(device, dtype=dtype) + else: + raise ValueError("Expected 'prompt_embeds' in batch for Megatron FLUX training") + + # Get pooled embeddings (CLIP) + if "pooled_prompt_embeds" in batch: + pooled_projections = batch["pooled_prompt_embeds"].to(device, dtype=dtype) + else: + pooled_projections = torch.zeros(batch_size, 768, device=device, dtype=dtype) + + if pooled_projections.ndim == 1: + pooled_projections = pooled_projections.unsqueeze(0) + + # Apply CFG dropout if needed + if random.random() < context.cfg_dropout_prob: + text_embeddings = torch.zeros_like(text_embeddings) + pooled_projections = torch.zeros_like(pooled_projections) + + # Pack latents for Flux transformer + packed_latents = self._pack_latents(noisy_latents) + + # Prepare positional IDs + img_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + # Text positional IDs + if "text_ids" in batch: + txt_ids = batch["text_ids"].to(device, dtype=dtype) + else: + text_seq_len = text_embeddings.shape[1] + txt_ids = torch.zeros(batch_size, text_seq_len, 3, device=device, dtype=dtype) + + # Timesteps - normalize to [0, 1] for FLUX + timesteps = context.timesteps.to(dtype) / 1000.0 + + # Guidance vector for FLUX-dev (only if model supports it) + # Exactly match original implementation pattern + guidance = None + if model is not None: + # Unwrap model wrappers (DDP, etc.) + unwrapped = model + while hasattr(unwrapped, "module"): + unwrapped = unwrapped.module + + # Check if model has guidance enabled (matches original flux_step.py logic) + if hasattr(unwrapped, "guidance_embed") and unwrapped.guidance_embed: + guidance = torch.full((batch_size,), self.guidance_scale, device=device, dtype=torch.float32) + + # Transpose to sequence-first for Megatron: [B, ...] -> [S, B, ...] + packed_latents = packed_latents.transpose(0, 1) + text_embeddings = text_embeddings.transpose(0, 1) + + inputs = { + "img": packed_latents, + "txt": text_embeddings, + "y": pooled_projections, + "timesteps": timesteps, + "img_ids": img_ids, + "txt_ids": txt_ids, + # Store original shape for unpacking + "_original_shape": (batch_size, channels, height, width), + } + + # Only add guidance if model supports it + if guidance is not None: + inputs["guidance"] = guidance + + return inputs + + def forward(self, model: VisionModule, inputs: Dict[str, Any]) -> torch.Tensor: + """ + Execute forward pass for Megatron Flux model. + + Returns unpacked prediction in [B, C, H, W] format. + """ + original_shape = inputs.pop("_original_shape") + batch_size, channels, height, width = original_shape + + # Megatron forward pass (guidance may be None if model doesn't support it) + model_pred = model( + img=inputs["img"], + txt=inputs["txt"], + y=inputs["y"], + timesteps=inputs["timesteps"], + img_ids=inputs["img_ids"], + txt_ids=inputs["txt_ids"], + guidance=inputs.get("guidance"), # Use .get() in case it's None + ) + + # Handle potential tuple output and transpose back from sequence-first + if isinstance(model_pred, tuple): + model_pred = model_pred[0] + + # Transpose from [S, B, D] to [B, S, D] + model_pred = model_pred.transpose(0, 1) + + # Unpack from Flux format back to [B, C, H, W] + model_pred = self._unpack_latents(model_pred, height, width) + + return model_pred diff --git a/src/megatron/bridge/diffusion/models/flux/flow_matching/flux_inference_pipeline.py b/src/megatron/bridge/diffusion/models/flux/flow_matching/flux_inference_pipeline.py new file mode 100644 index 0000000000..f207c94ae2 --- /dev/null +++ b/src/megatron/bridge/diffusion/models/flux/flow_matching/flux_inference_pipeline.py @@ -0,0 +1,693 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FLUX inference pipeline for text-to-image generation.""" + +import math +import os +from dataclasses import dataclass, field +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from PIL import Image +from torch import nn +from tqdm import tqdm + +from megatron.bridge.diffusion.models.flux.flux_provider import FluxProvider +from megatron.bridge.training.model_load_save import load_megatron_model as _load_megatron_model + + +@dataclass +class T5Config: + """T5 encoder configuration.""" + + version: Optional[str] = field(default_factory=lambda: "google/t5-v1_1-xxl") + max_length: Optional[int] = field(default_factory=lambda: 512) + load_config_only: bool = False + device: str = "cuda" + + +@dataclass +class ClipConfig: + """CLIP encoder configuration.""" + + version: Optional[str] = field(default_factory=lambda: "openai/clip-vit-large-patch14") + max_length: Optional[int] = field(default_factory=lambda: 77) + always_return_pooled: Optional[bool] = field(default_factory=lambda: True) + device: str = "cuda" + + +class FlowMatchEulerDiscreteScheduler: + """ + Euler scheduler. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + """ + + _compatibles = [] + order = 1 + + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + use_dynamic_shifting=False, + base_shift: Optional[float] = 0.5, + max_shift: Optional[float] = 1.15, + base_image_seq_len: Optional[int] = 256, + max_image_seq_len: Optional[int] = 4096, + ): + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + + sigmas = timesteps / num_train_timesteps + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.timesteps = sigmas * num_train_timesteps + + self._step_index = None + self._begin_index = None + + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + self.base_shift = base_shift + self.max_shift = max_shift + self.base_image_seq_len = base_image_seq_len + self.max_image_seq_len = max_image_seq_len + self.use_dynamic_shifting = use_dynamic_shifting + self.num_train_timesteps = num_train_timesteps + self.shift = shift + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def scale_noise( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Forward process in flow-matching + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype) + + if sample.device.type == "mps" and torch.is_floating_point(timestep): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32) + timestep = timestep.to(sample.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(sample.device) + timestep = timestep.to(sample.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timestep.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timestep.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(sample.shape): + sigma = sigma.unsqueeze(-1) + + sample = sigma * noise + (1.0 - sigma) * sample + + return sample + + def _sigma_to_t(self, sigma): + return sigma * self.num_train_timesteps + + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + + if self.use_dynamic_shifting and mu is None: + raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") + + if sigmas is None: + self.num_inference_steps = num_inference_steps + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps + ) + + sigmas = timesteps / self.num_train_timesteps + + if self.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + timesteps = sigmas * self.num_train_timesteps + + self.timesteps = timesteps.to(device=device) + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + self._step_index = None + self._begin_index = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + ) -> Tuple: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + + Returns: + A tuple is returned where the first element is the sample tensor. + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] + prev_sample = sample + (sigma_next - sigma) * model_output + + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + return (prev_sample,) + + def __len__(self): + return self.num_train_timesteps + + +class FluxInferencePipeline(nn.Module): + """ + FLUX inference pipeline for text-to-image generation. + + This pipeline orchestrates the full inference process including: + - Text encoding with T5 and CLIP + - Latent preparation and denoising + - VAE decoding to images + + Args: + params: FluxModelParams configuration. + flux: Optional pre-initialized Flux model. + scheduler_steps: Number of scheduler steps. + + Example: + >>> params = FluxModelParams() + >>> pipeline = FluxInferencePipeline(params) + >>> pipeline.load_from_pretrained("path/to/flux_ckpt") + >>> images = pipeline( + ... prompt=["A cat holding a sign that says hello world"], + ... height=1024, + ... width=1024, + ... num_inference_steps=20, + ... ) + """ + + def __init__( + self, + flux_checkpoint_dir: Optional[str] = None, + t5_checkpoint_dir: Optional[str] = None, + clip_checkpoint_dir: Optional[str] = None, + vae_checkpoint_dir: Optional[str] = None, + scheduler_steps: int = 1000, + ): + super().__init__() + + # Initialize transformer + self.transformer = self.setup_model_from_checkpoint(flux_checkpoint_dir) + self.device = "cuda:0" + + # VAE scale factor based on channel multipliers + # if params and params.vae_config: + # self.vae_scale_factor = 2 ** len(params.vae_config.ch_mult) + # else: + self.vae_scale_factor = 16 # Default for FLUX + + # Initialize scheduler + self.scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=scheduler_steps) + + # Placeholders for encoders (to be loaded separately) + self.load_text_encoders(t5_checkpoint_dir, clip_checkpoint_dir) + self.load_vae(vae_checkpoint_dir) + + def setup_model_from_checkpoint(self, checkpoint_dir): + provider = FluxProvider() + # provider.tensor_model_parallel_size = self.tensor_parallel_size + # provider.pipeline_model_parallel_size = self.pipeline_parallel_size + # provider.context_parallel_size = self.context_parallel_size + # provider.sequence_parallel = self.sequence_parallel + # provider.pipeline_dtype = self.pipeline_dtype + # Once all overrides are set, finalize the model provider to ensure the post initialization logic is run + provider.finalize() + provider.initialize_model_parallel(seed=0) + + ## Read from megatron checkpoint + model = _load_megatron_model( + checkpoint_dir, + # mp_overrides={ + # "tensor_model_parallel_size": self.tensor_parallel_size, + # "pipeline_model_parallel_size": self.pipeline_parallel_size, + # "context_parallel_size": self.context_parallel_size, + # "sequence_parallel": self.sequence_parallel, + # "pipeline_dtype": self.pipeline_dtype, + # }, + ) + if isinstance(model, list): + model = model[0] + if hasattr(model, "module"): + model = model.module + + return model + + def load_text_encoders(self, t5_version: str = None, clip_version: str = None): + """ + Load T5 and CLIP text encoders. + + Args: + t5_version: HuggingFace model ID or path for T5. + clip_version: HuggingFace model ID or path for CLIP. + """ + try: + from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer + + # Load T5 + t5_version = t5_version or "google/t5-v1_1-xxl" + print(f"Loading T5 encoder from {t5_version}...") + self.t5_tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl") + self.t5_encoder = T5EncoderModel.from_pretrained(t5_version).to(self.device).eval() + + # Load CLIP + clip_version = clip_version or "openai/clip-vit-large-patch14" + print(f"Loading CLIP encoder from {clip_version}...") + self.clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + self.clip_encoder = CLIPTextModel.from_pretrained(clip_version).to(self.device).eval() + + print("Text encoders loaded successfully") + except ImportError: + raise ImportError("Please install transformers: pip install transformers") + + def load_vae(self, vae_path: str): + """ + Load VAE from checkpoint. + + Args: + vae_path: Path to VAE checkpoint (ae.safetensors). + """ + try: + from diffusers import AutoencoderKL + + self.vae = AutoencoderKL.from_pretrained(vae_path).to(self.device).eval() + except ImportError: + raise ImportError("Please install diffusers: pip install diffusers") + + def encode_prompt( + self, + prompt: Union[str, List[str]], + max_sequence_length: int = 512, + num_images_per_prompt: int = 1, + device: str = "cuda", + dtype: torch.dtype = torch.float32, + ): + """ + Encode text prompts using T5 and CLIP. + + Returns: + Tuple of (prompt_embeds, pooled_prompt_embeds, text_ids). + """ + if isinstance(prompt, str): + prompt = [prompt] + batch_size = len(prompt) + + # T5 encoding + t5_inputs = self.t5_tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ).to(device) + + with torch.no_grad(): + prompt_embeds = self.t5_encoder(input_ids=t5_inputs.input_ids).last_hidden_state + + # CLIP encoding + clip_inputs = self.clip_tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_tensors="pt", + ).to(device) + + with torch.no_grad(): + clip_output = self.clip_encoder(input_ids=clip_inputs.input_ids) + pooled_prompt_embeds = clip_output.pooler_output + + # Repeat for multiple images per prompt + seq_len = prompt_embeds.shape[1] + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1).to(dtype=dtype) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1).to(dtype=dtype) + + # Create text IDs + text_ids = torch.zeros(batch_size * num_images_per_prompt, seq_len, 3, device=device, dtype=dtype) + + return prompt_embeds.transpose(0, 1), pooled_prompt_embeds, text_ids + + @staticmethod + def _prepare_latent_image_ids(batch_size: int, height: int, width: int, device: torch.device, dtype: torch.dtype): + """Prepare latent image IDs for position encoding.""" + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) + latent_image_ids = latent_image_ids.reshape(batch_size, (height // 2) * (width // 2), 3) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + """Pack latents for FLUX processing.""" + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + """Unpack latents for VAE decoding.""" + batch_size, num_patches, channels = latents.shape + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + latents = latents.reshape(batch_size, channels // 4, height * 2, width * 2) + return latents + + @staticmethod + def _calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.16): + """Calculate timestep shift based on sequence length.""" + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + return image_seq_len * m + b + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator=None): + """Prepare random latents for generation.""" + height = 2 * int(height) // self.vae_scale_factor + width = 2 * int(width) // self.vae_scale_factor + + shape = (batch_size, num_channels_latents, height, width) + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + return latents.transpose(0, 1), latent_image_ids + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: int = 1024, + width: int = 1024, + num_inference_steps: int = 10, + guidance_scale: float = 3.5, + num_images_per_prompt: int = 1, + generator: Optional[torch.Generator] = None, + max_sequence_length: int = 512, + output_type: str = "pil", + output_path: Optional[str] = None, + dtype: torch.dtype = torch.bfloat16, + ): + """ + Generate images from text prompts. + + Args: + prompt: Text prompt(s) for image generation. + height: Output image height. + width: Output image width. + num_inference_steps: Number of denoising steps. + guidance_scale: Classifier-free guidance scale. + num_images_per_prompt: Number of images per prompt. + generator: Random number generator for reproducibility. + max_sequence_length: Maximum sequence length for text encoding. + output_type: "pil" for PIL images, "latent" for latent tensors. + output_path: Path to save generated images. + dtype: Data type for inference. + + Returns: + List of PIL images or latent tensors. + """ + device = self.device + + if isinstance(prompt, str): + prompt = [prompt] + batch_size = len(prompt) + + # Encode prompts + if self.t5_encoder is None or self.clip_encoder is None: + raise RuntimeError("Text encoders not loaded. Call load_text_encoders() first.") + + prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt( + prompt, + max_sequence_length=max_sequence_length, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=dtype, + ) + + # Prepare latents + num_channels_latents = self.transformer.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + dtype, + device, + generator, + ) + + # Setup timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.shape[0] + mu = self._calculate_shift( + image_seq_len, + self.scheduler.base_image_seq_len, + self.scheduler.max_image_seq_len, + self.scheduler.base_shift, + self.scheduler.max_shift, + ) + self.scheduler.set_timesteps(sigmas=sigmas, device=device, mu=mu) + timesteps = self.scheduler.timesteps + + # Denoising loop + for t in tqdm(timesteps, desc="Denoising"): + timestep = t.expand(latents.shape[1]).to(device=device, dtype=dtype) + + if self.transformer.guidance_embed: + guidance = torch.full((latents.shape[1],), guidance_scale, device=device, dtype=dtype) + else: + guidance = None + + with torch.autocast(device_type="cuda", dtype=dtype): + pred = self.transformer( + img=latents, + txt=prompt_embeds, + y=pooled_prompt_embeds, + timesteps=timestep / 1000, + img_ids=latent_image_ids, + txt_ids=text_ids, + guidance=guidance, + ) + latents = self.scheduler.step(pred, t, latents)[0] + + if output_type == "latent": + return latents.transpose(0, 1) + + # Decode latents to images + if self.vae is None: + raise RuntimeError("VAE not loaded. Call load_vae() first.") + + latents = self._unpack_latents(latents.transpose(0, 1), height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + with torch.autocast(device_type="cuda", dtype=dtype): + images = self.vae.decode(latents, return_dict=False)[0] + + # Post-process + images = FluxInferencePipeline.denormalize(images) + images = FluxInferencePipeline.torch_to_numpy(images) + images = FluxInferencePipeline.numpy_to_pil(images) + + # Save if requested + if output_path: + os.makedirs(output_path, exist_ok=True) + assert len(images) == int(len(prompt) * num_images_per_prompt) + prompt = [p[:40] + f"_{idx}" for p in prompt for idx in range(num_images_per_prompt)] + for file_name, image in zip(prompt, images): + image.save(os.path.join(output_path, f"{file_name}.png")) + + return images + + @staticmethod + def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + @staticmethod + def torch_to_numpy(images): + """ + Convert a torch image or a batch of images to a numpy image. + """ + numpy_images = images.float().cpu().permute(0, 2, 3, 1).numpy() + return numpy_images + + @staticmethod + def denormalize(image): + # pylint: disable=C0116 + return (image / 2 + 0.5).clamp(0, 1) diff --git a/src/megatron/bridge/diffusion/models/flux/flux_attention.py b/src/megatron/bridge/diffusion/models/flux/flux_attention.py new file mode 100644 index 0000000000..72bd1f91a2 --- /dev/null +++ b/src/megatron/bridge/diffusion/models/flux/flux_attention.py @@ -0,0 +1,516 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FLUX attention modules for diffusion models.""" + +from dataclasses import dataclass +from typing import Union + +import torch +from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb +from megatron.core.transformer.attention import Attention, SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig + + +try: + from megatron.core.extensions.transformer_engine import SplitAlongDim +except ImportError: + try: + from megatron.core.transformer.custom_layers.transformer_engine import SplitAlongDim + except ImportError: + SplitAlongDim = None + + +@dataclass +class JointSelfAttentionSubmodules: + """ + Submodules for Joint Self-attention layer. + + Used for MMDIT-like transformer blocks in FLUX. + """ + + linear_qkv: Union[ModuleSpec, type] = None + added_linear_qkv: Union[ModuleSpec, type] = None + core_attention: Union[ModuleSpec, type] = None + linear_proj: Union[ModuleSpec, type] = None + q_layernorm: Union[ModuleSpec, type] = None + k_layernorm: Union[ModuleSpec, type] = None + added_q_layernorm: Union[ModuleSpec, type] = None + added_k_layernorm: Union[ModuleSpec, type] = None + + +class JointSelfAttention(Attention): + """ + Joint Self-attention layer class. + + Used for MMDIT-like transformer blocks in FLUX double blocks. + This attention layer processes both image hidden states and text encoder + hidden states jointly. + + Args: + config: Transformer configuration. + submodules: Joint self-attention submodules specification. + layer_number: Layer index in the transformer. + attn_mask_type: Type of attention mask to use. + context_pre_only: Whether to only use context for pre-processing. + """ + + def __init__( + self, + config: TransformerConfig, + submodules: JointSelfAttentionSubmodules, + layer_number: int, + attn_mask_type=AttnMaskType.padding, + context_pre_only: bool = False, + **kwargs, + ): + # Use RMSnorm for qk norm + config.normalization = "RMSNorm" + super().__init__( + config=config, + submodules=submodules, + layer_number=layer_number, + attn_mask_type=attn_mask_type, + attention_type="self", + **kwargs, + ) + + self.linear_qkv = build_module( + submodules.linear_qkv, + self.config.hidden_size, + self.query_projection_size + 2 * self.kv_projection_size, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=self.config.add_bias_linear or self.config.add_qkv_bias, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name="qkv", + ) + + if submodules.added_linear_qkv is not None: + self.added_linear_qkv = build_module( + submodules.added_linear_qkv, + self.config.hidden_size, + self.query_projection_size + 2 * self.kv_projection_size, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=self.config.add_qkv_bias, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name="qkv", + ) + + if not context_pre_only: + self.added_linear_proj = build_module( + submodules.linear_proj, + self.query_projection_size, + self.config.hidden_size, + config=self.config, + init_method=self.config.output_layer_init_method, + bias=self.config.add_bias_linear, + input_is_parallel=True, + skip_bias_add=True, + is_expert=False, + tp_comm_buffer_name="proj", + ) + + if submodules.q_layernorm is not None: + self.q_layernorm = build_module( + submodules.q_layernorm, + hidden_size=self.hidden_size_per_attention_head, + config=self.config, + eps=self.config.layernorm_epsilon, + ) + else: + self.q_layernorm = None + + if submodules.k_layernorm is not None: + self.k_layernorm = build_module( + submodules.k_layernorm, + hidden_size=self.hidden_size_per_attention_head, + config=self.config, + eps=self.config.layernorm_epsilon, + ) + else: + self.k_layernorm = None + + if submodules.added_q_layernorm is not None: + self.added_q_layernorm = build_module( + submodules.added_q_layernorm, + hidden_size=self.hidden_size_per_attention_head, + config=self.config, + eps=self.config.layernorm_epsilon, + ) + else: + self.added_q_layernorm = None + + if submodules.added_k_layernorm is not None: + self.added_k_layernorm = build_module( + submodules.added_k_layernorm, + hidden_size=self.hidden_size_per_attention_head, + config=self.config, + eps=self.config.layernorm_epsilon, + ) + else: + self.added_k_layernorm = None + + def _split_qkv(self, mixed_qkv): + """Split mixed QKV tensor into separate Q, K, V tensors.""" + # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn] + new_tensor_shape = mixed_qkv.size()[:-1] + ( + self.num_query_groups_per_partition, + ( + (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2) + * self.hidden_size_per_attention_head + ), + ) + mixed_qkv = mixed_qkv.view(*new_tensor_shape) + + split_arg_list = [ + ( + self.num_attention_heads_per_partition + // self.num_query_groups_per_partition + * self.hidden_size_per_attention_head + ), + self.hidden_size_per_attention_head, + self.hidden_size_per_attention_head, + ] + + if SplitAlongDim is not None: + # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] + (query, key, value) = SplitAlongDim( + mixed_qkv, + 3, + split_arg_list, + ) + else: + # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] + (query, key, value) = torch.split( + mixed_qkv, + split_arg_list, + dim=3, + ) + + # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] + query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) + return query, key, value + + def get_query_key_value_tensors(self, hidden_states, key_value_states=None): + """ + Derives `query`, `key` and `value` tensors from `hidden_states`. + """ + # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)] + mixed_qkv, _ = self.linear_qkv(hidden_states) + + query, key, value = self._split_qkv(mixed_qkv) + + if self.config.test_mode: + self.run_realtime_tests() + + if self.q_layernorm is not None: + query = self.q_layernorm(query) + + if self.k_layernorm is not None: + key = self.k_layernorm(key) + + return query, key, value + + def get_added_query_key_value_tensors(self, added_hidden_states, key_value_states=None): + """ + Derives `query`, `key` and `value` tensors from `added_hidden_states`. + """ + # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)] + mixed_qkv, _ = self.added_linear_qkv(added_hidden_states) + + query, key, value = self._split_qkv(mixed_qkv) + + if self.config.test_mode: + self.run_realtime_tests() + + if self.added_q_layernorm is not None: + query = self.added_q_layernorm(query) + + if self.added_k_layernorm is not None: + key = self.added_k_layernorm(key) + + return query, key, value + + def forward( + self, + hidden_states, + attention_mask, + key_value_states=None, + inference_params=None, + rotary_pos_emb=None, + packed_seq_params=None, + additional_hidden_states=None, + ): + """ + Forward pass for joint self-attention. + + Args: + hidden_states: Image hidden states [sq, b, h]. + attention_mask: Attention mask. + key_value_states: Optional key-value states. + inference_params: Inference parameters. + rotary_pos_emb: Rotary position embeddings. + packed_seq_params: Packed sequence parameters. + additional_hidden_states: Text encoder hidden states. + + Returns: + Tuple of (image_attention_output, encoder_attention_output). + """ + # hidden_states: [sq, b, h] + # For self attention we just duplicate the rotary_pos_emb if it isn't already + if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = (rotary_pos_emb,) * 2 + + # ===================== + # Query, Key, and Value + # ===================== + query, key, value = self.get_query_key_value_tensors(hidden_states) + added_query, added_key, added_value = self.get_added_query_key_value_tensors(additional_hidden_states) + + query = torch.cat([added_query, query], dim=0) + key = torch.cat([added_key, key], dim=0) + value = torch.cat([added_value, value], dim=0) + + # =================================================== + # Adjust key, value, and rotary_pos_emb for inference + # =================================================== + query, key, value, rotary_pos_emb, attn_mask_type, *_ = self._adjust_key_value_for_inference( + inference_params, query, key, value, rotary_pos_emb + ) + + if packed_seq_params is not None: + query = query.squeeze(1) + key = key.squeeze(1) + value = value.squeeze(1) + + # ================================================ + # Relative positional embedding (rotary embedding) + # ================================================ + if rotary_pos_emb is not None: + q_pos_emb, k_pos_emb = rotary_pos_emb + + if packed_seq_params is not None: + cu_seqlens_q = packed_seq_params.cu_seqlens_q + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv + else: + cu_seqlens_q = cu_seqlens_kv = None + query = apply_rotary_pos_emb( + query, + q_pos_emb, + config=self.config, + cu_seqlens=cu_seqlens_q, + ) + key = apply_rotary_pos_emb( + key, + k_pos_emb, + config=self.config, + cu_seqlens=cu_seqlens_kv, + ) + + # ================================== + # Core attention computation + # ================================== + if self.checkpoint_core_attention and self.training: + core_attn_out = self._checkpointed_attention_forward( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + packed_seq_params=packed_seq_params, + ) + else: + core_attn_out = self.core_attention( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + packed_seq_params=packed_seq_params, + ) + + if packed_seq_params is not None: + # reshape to same output shape as unpacked case + # (t, np, hn) -> (t, b=1, h=np*hn) + core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) + + # ================= + # Output. [sq, b, h] + # ================= + encoder_attention_output = core_attn_out[: additional_hidden_states.shape[0], :, :] + attention_output = core_attn_out[additional_hidden_states.shape[0] :, :, :] + + output, bias = self.linear_proj(attention_output) + encoder_output, encoder_bias = self.added_linear_proj(encoder_attention_output) + + output = output + bias + encoder_output = encoder_output + encoder_bias + + return output, encoder_output + + +class FluxSingleAttention(SelfAttention): + """ + Self-attention layer class for FLUX single transformer blocks. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + + Args: + config: Transformer configuration. + submodules: Self-attention submodules specification. + layer_number: Layer index in the transformer. + attn_mask_type: Type of attention mask to use. + cp_comm_type: Context parallel communication type. + """ + + def __init__( + self, + config: TransformerConfig, + submodules: SelfAttentionSubmodules, + layer_number: int, + attn_mask_type=AttnMaskType.padding, + cp_comm_type: str = None, + **kwargs, + ): + # Use RMSnorm for qk norm + config.normalization = "RMSNorm" + super().__init__( + config=config, + submodules=submodules, + layer_number=layer_number, + attn_mask_type=attn_mask_type, + cp_comm_type=cp_comm_type, + **kwargs, + ) + self.linear_proj = build_module( + submodules.linear_proj, + self.query_projection_size, + self.config.hidden_size, + config=self.config, + init_method=self.config.output_layer_init_method, + bias=False, + input_is_parallel=True, + skip_bias_add=True, + is_expert=False, + tp_comm_buffer_name="proj", + ) + + def forward( + self, + hidden_states, + attention_mask, + key_value_states=None, + inference_params=None, + rotary_pos_emb=None, + packed_seq_params=None, + ): + """ + Forward pass for FLUX single attention. + + Args: + hidden_states: Input hidden states [sq, b, h]. + attention_mask: Attention mask. + key_value_states: Optional key-value states. + inference_params: Inference parameters. + rotary_pos_emb: Rotary position embeddings. + packed_seq_params: Packed sequence parameters. + + Returns: + Attention output tensor. + """ + # hidden_states: [sq, b, h] + + # For self attention we just duplicate the rotary_pos_emb if it isn't already + if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = (rotary_pos_emb,) * 2 + + # ===================== + # Query, Key, and Value + # ===================== + query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) + + # =================================================== + # Adjust key, value, and rotary_pos_emb for inference + # =================================================== + query, key, value, rotary_pos_emb, attn_mask_type, *_ = self._adjust_key_value_for_inference( + inference_params, query, key, value, rotary_pos_emb + ) + + if packed_seq_params is not None: + query = query.squeeze(1) + key = key.squeeze(1) + value = value.squeeze(1) + + # ================================================ + # Relative positional embedding (rotary embedding) + # ================================================ + if rotary_pos_emb is not None: + q_pos_emb, k_pos_emb = rotary_pos_emb + + if packed_seq_params is not None: + cu_seqlens_q = packed_seq_params.cu_seqlens_q + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv + else: + cu_seqlens_q = cu_seqlens_kv = None + query = apply_rotary_pos_emb( + query, + q_pos_emb, + config=self.config, + cu_seqlens=cu_seqlens_q, + ) + key = apply_rotary_pos_emb( + key, + k_pos_emb, + config=self.config, + cu_seqlens=cu_seqlens_kv, + ) + + # ================================== + # Core attention computation + # ================================== + if self.checkpoint_core_attention and self.training: + core_attn_out = self._checkpointed_attention_forward( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + packed_seq_params=packed_seq_params, + ) + else: + core_attn_out = self.core_attention( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + packed_seq_params=packed_seq_params, + ) + + if packed_seq_params is not None: + # reshape to same output shape as unpacked case + # (t, np, hn) -> (t, b=1, h=np*hn) + core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) + + output, _ = self.linear_proj(core_attn_out) + return output diff --git a/src/megatron/bridge/diffusion/models/flux/flux_layer_spec.py b/src/megatron/bridge/diffusion/models/flux/flux_layer_spec.py new file mode 100644 index 0000000000..0cde58043b --- /dev/null +++ b/src/megatron/bridge/diffusion/models/flux/flux_layer_spec.py @@ -0,0 +1,474 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FLUX layer specifications and transformer blocks.""" + +import copy + +import torch +import torch.nn as nn +from megatron.core.jit import jit_fuser +from megatron.core.tensor_parallel.layers import ColumnParallelLinear +from megatron.core.transformer.attention import SelfAttentionSubmodules +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules + +from megatron.bridge.diffusion.models.common.normalization import RMSNorm +from megatron.bridge.diffusion.models.flux.flux_attention import ( + FluxSingleAttention, + JointSelfAttention, + JointSelfAttentionSubmodules, +) + + +try: + from megatron.core.transformer.custom_layers.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TENorm, + TERowParallelLinear, + ) +except ImportError: + TEColumnParallelLinear = None + TEDotProductAttention = None + TENorm = None + TERowParallelLinear = None + +try: + from megatron.core.transformer.cuda_graphs import CudaGraphManager +except ImportError: + CudaGraphManager = None + + +class AdaLN(MegatronModule): + """ + Adaptive Layer Normalization Module for DiT/FLUX models. + + Implements adaptive layer normalization that conditions on timestep embeddings. + + Args: + config: Transformer configuration. + n_adaln_chunks: Number of adaptive LN chunks for modulation outputs. + norm: Normalization type to use. + modulation_bias: Whether to use bias in modulation layers. + use_second_norm: Whether to use a second layer norm. + """ + + def __init__( + self, + config: TransformerConfig, + n_adaln_chunks: int = 9, + norm=nn.LayerNorm, + modulation_bias: bool = False, + use_second_norm: bool = False, + ): + super().__init__(config) + if norm == TENorm: + self.ln = norm(config, config.hidden_size, config.layernorm_epsilon) + else: + self.ln = norm(config.hidden_size, elementwise_affine=False, eps=self.config.layernorm_epsilon) + self.n_adaln_chunks = n_adaln_chunks + self.activation = nn.SiLU() + self.linear = ColumnParallelLinear( + config.hidden_size, + self.n_adaln_chunks * config.hidden_size, + config=config, + init_method=nn.init.normal_, + bias=modulation_bias, + gather_output=True, + ) + self.use_second_norm = use_second_norm + if self.use_second_norm: + self.ln2 = nn.LayerNorm(config.hidden_size, elementwise_affine=False, eps=1e-6) + nn.init.constant_(self.linear.weight, 0) + + setattr(self.linear.weight, "sequence_parallel", config.sequence_parallel) + + @jit_fuser + def forward(self, timestep_emb): + """Apply adaptive layer normalization modulation.""" + output = self.activation(timestep_emb) + output, bias = self.linear(output) + output = output + bias if bias is not None else output + return output.chunk(self.n_adaln_chunks, dim=-1) + + @jit_fuser + def modulate(self, x, shift, scale): + """Apply modulation with shift and scale.""" + return x * (1 + scale) + shift + + @jit_fuser + def scale_add(self, residual, x, gate): + """Add gated output to residual.""" + return residual + gate * x + + @jit_fuser + def modulated_layernorm(self, x, shift, scale, layernorm_idx=0): + """Apply layer norm followed by modulation.""" + if self.use_second_norm and layernorm_idx == 1: + layernorm = self.ln2 + else: + layernorm = self.ln + # Optional Input Layer norm + input_layernorm_output = layernorm(x).type_as(x) + + # DiT block specific + return self.modulate(input_layernorm_output, shift, scale) + + @jit_fuser + def scaled_modulated_layernorm(self, residual, x, gate, shift, scale, layernorm_idx=0): + """Apply scale, add, and modulated layer norm.""" + hidden_states = self.scale_add(residual, x, gate) + shifted_pre_mlp_layernorm_output = self.modulated_layernorm(hidden_states, shift, scale, layernorm_idx) + return hidden_states, shifted_pre_mlp_layernorm_output + + +class AdaLNContinuous(MegatronModule): + """ + A variant of AdaLN used for FLUX models. + + Continuous adaptive layer normalization that outputs scale and shift + directly from conditioning embeddings. + + Args: + config: Transformer configuration. + conditioning_embedding_dim: Dimension of the conditioning embedding. + modulation_bias: Whether to use bias in modulation layer. + norm_type: Type of normalization ("layer_norm" or "rms_norm"). + """ + + def __init__( + self, + config: TransformerConfig, + conditioning_embedding_dim: int, + modulation_bias: bool = True, + norm_type: str = "layer_norm", + ): + super().__init__(config) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(conditioning_embedding_dim, config.hidden_size * 2, bias=modulation_bias) + ) + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False, eps=1e-6, bias=modulation_bias) + elif norm_type == "rms_norm": + self.norm = RMSNorm(config.hidden_size, eps=1e-6) + else: + raise ValueError(f"Unknown normalization type {norm_type}") + + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + """Apply continuous adaptive layer normalization.""" + emb = self.adaLN_modulation(conditioning_embedding) + scale, shift = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale) + shift + return x + + +class MMDiTLayer(TransformerLayer): + """ + Multi-modal transformer layer for FLUX double blocks. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + + MMDiT layer implementation from [https://arxiv.org/pdf/2403.03206]. + + Args: + config: Transformer configuration. + submodules: Transformer layer submodules. + layer_number: Layer index. + context_pre_only: Whether to only use context for pre-processing. + """ + + def __init__( + self, + config: TransformerConfig, + submodules: TransformerLayerSubmodules, + layer_number: int = 1, + context_pre_only: bool = False, + ): + hidden_size = config.hidden_size + super().__init__(config=config, submodules=submodules, layer_number=layer_number) + + # Enable per-Transformer layer cuda graph + if CudaGraphManager is not None and config.enable_cuda_graph and config.cuda_graph_scope != "full_iteration": + self.cudagraph_manager = CudaGraphManager(config, share_cudagraph_io_buffers=False) + + self.adaln = AdaLN(config, modulation_bias=True, n_adaln_chunks=6, use_second_norm=True) + + self.context_pre_only = context_pre_only + context_norm_type = "ada_norm_continuous" if context_pre_only else "ada_norm_zero" + + if context_norm_type == "ada_norm_continuous": + self.adaln_context = AdaLNContinuous(config, hidden_size, modulation_bias=True, norm_type="layer_norm") + elif context_norm_type == "ada_norm_zero": + self.adaln_context = AdaLN(config, modulation_bias=True, n_adaln_chunks=6, use_second_norm=True) + else: + raise ValueError( + f"Unknown context_norm_type: {context_norm_type}, " + f"currently only support `ada_norm_continous`, `ada_norm_zero`" + ) + + # Override config for context MLP to disable CP. + # Disable TP Comm overlap as well. + cp_override_config = copy.deepcopy(config) + cp_override_config.context_parallel_size = 1 + cp_override_config.tp_comm_overlap = False + + if not context_pre_only: + self.context_mlp = build_module( + submodules.mlp, + config=cp_override_config, + ) + else: + self.context_mlp = None + + def forward( + self, + hidden_states, + encoder_hidden_states, + attention_mask=None, + context=None, + context_mask=None, + rotary_pos_emb=None, + inference_params=None, + packed_seq_params=None, + emb=None, + ): + """ + Forward pass for MMDiT layer. + + Args: + hidden_states: Image hidden states. + encoder_hidden_states: Text encoder hidden states. + attention_mask: Attention mask. + context: Context tensor (unused). + context_mask: Context mask (unused). + rotary_pos_emb: Rotary position embeddings. + inference_params: Inference parameters. + packed_seq_params: Packed sequence parameters. + emb: Timestep/conditioning embedding. + + Returns: + Tuple of (hidden_states, encoder_hidden_states). + """ + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaln(emb) + + norm_hidden_states = self.adaln.modulated_layernorm( + hidden_states, shift=shift_msa, scale=scale_msa, layernorm_idx=0 + ) + if self.context_pre_only: + norm_encoder_hidden_states = self.adaln_context(encoder_hidden_states, emb) + else: + c_shift_msa, c_scale_msa, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.adaln_context(emb) + norm_encoder_hidden_states = self.adaln_context.modulated_layernorm( + encoder_hidden_states, shift=c_shift_msa, scale=c_scale_msa, layernorm_idx=0 + ) + + attention_output, encoder_attention_output = self.self_attention( + norm_hidden_states, + attention_mask=attention_mask, + key_value_states=None, + additional_hidden_states=norm_encoder_hidden_states, + rotary_pos_emb=rotary_pos_emb, + ) + hidden_states = self.adaln.scale_add(hidden_states, x=attention_output, gate=gate_msa) + norm_hidden_states = self.adaln.modulated_layernorm( + hidden_states, shift=shift_mlp, scale=scale_mlp, layernorm_idx=1 + ) + + mlp_output, mlp_output_bias = self.mlp(norm_hidden_states) + hidden_states = self.adaln.scale_add(hidden_states, x=(mlp_output + mlp_output_bias), gate=gate_mlp) + + if self.context_pre_only: + encoder_hidden_states = None + else: + encoder_hidden_states = self.adaln_context.scale_add( + encoder_hidden_states, x=encoder_attention_output, gate=c_gate_msa + ) + norm_encoder_hidden_states = self.adaln_context.modulated_layernorm( + encoder_hidden_states, shift=c_shift_mlp, scale=c_scale_mlp, layernorm_idx=1 + ) + + context_mlp_output, context_mlp_output_bias = self.context_mlp(norm_encoder_hidden_states) + encoder_hidden_states = self.adaln.scale_add( + encoder_hidden_states, x=(context_mlp_output + context_mlp_output_bias), gate=c_gate_mlp + ) + + return hidden_states, encoder_hidden_states + + def __call__(self, *args, **kwargs): + if hasattr(self, "cudagraph_manager"): + return self.cudagraph_manager(self, args, kwargs) + return super(MegatronModule, self).__call__(*args, **kwargs) + + +class FluxSingleTransformerBlock(TransformerLayer): + """ + FLUX Single Transformer Block. + + Single transformer layer mathematically equivalent to original Flux single transformer. + This layer is re-implemented with megatron-core and altered in structure for better performance. + + Args: + config: Transformer configuration. + submodules: Transformer layer submodules. + layer_number: Layer index. + mlp_ratio: MLP hidden size ratio. + n_adaln_chunks: Number of adaptive LN chunks. + modulation_bias: Whether to use bias in modulation. + """ + + def __init__( + self, + config: TransformerConfig, + submodules: TransformerLayerSubmodules, + layer_number: int = 1, + mlp_ratio: int = 4, + n_adaln_chunks: int = 3, + modulation_bias: bool = True, + ): + super().__init__(config=config, submodules=submodules, layer_number=layer_number) + + # Enable per-Transformer layer cuda graph + if CudaGraphManager is not None and config.enable_cuda_graph and config.cuda_graph_scope != "full_iteration": + self.cudagraph_manager = CudaGraphManager(config, share_cudagraph_io_buffers=False) + self.adaln = AdaLN( + config=config, n_adaln_chunks=n_adaln_chunks, modulation_bias=modulation_bias, use_second_norm=False + ) + + def forward( + self, + hidden_states, + attention_mask=None, + context=None, + context_mask=None, + rotary_pos_emb=None, + inference_params=None, + packed_seq_params=None, + emb=None, + ): + """ + Forward pass for FLUX single transformer block. + + Args: + hidden_states: Input hidden states. + attention_mask: Attention mask. + context: Context tensor (unused). + context_mask: Context mask (unused). + rotary_pos_emb: Rotary position embeddings. + inference_params: Inference parameters. + packed_seq_params: Packed sequence parameters. + emb: Timestep/conditioning embedding. + + Returns: + Tuple of (hidden_states, None). + """ + residual = hidden_states + + shift, scale, gate = self.adaln(emb) + + norm_hidden_states = self.adaln.modulated_layernorm(hidden_states, shift=shift, scale=scale) + + mlp_hidden_states, mlp_bias = self.mlp(norm_hidden_states) + + attention_output = self.self_attention( + norm_hidden_states, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb + ) + + hidden_states = mlp_hidden_states + mlp_bias + attention_output + + hidden_states = self.adaln.scale_add(residual, x=hidden_states, gate=gate) + + return hidden_states, None + + def __call__(self, *args, **kwargs): + if hasattr(self, "cudagraph_manager"): + return self.cudagraph_manager(self, args, kwargs) + return super(MegatronModule, self).__call__(*args, **kwargs) + + +# ============================================================================ +# Layer Spec Functions +# ============================================================================ + + +def get_flux_double_transformer_engine_spec() -> ModuleSpec: + """ + Get the module specification for FLUX double transformer blocks. + + Returns: + ModuleSpec for MMDiTLayer with JointSelfAttention. + """ + return ModuleSpec( + module=MMDiTLayer, + submodules=TransformerLayerSubmodules( + self_attention=ModuleSpec( + module=JointSelfAttention, + params={"attn_mask_type": AttnMaskType.no_mask}, + submodules=JointSelfAttentionSubmodules( + q_layernorm=TENorm, + k_layernorm=TENorm, + added_q_layernorm=TENorm, + added_k_layernorm=TENorm, + linear_qkv=TEColumnParallelLinear, + added_linear_qkv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + ), + ), + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear, + linear_fc2=TERowParallelLinear, + ), + ), + ), + ) + + +def get_flux_single_transformer_engine_spec() -> ModuleSpec: + """ + Get the module specification for FLUX single transformer blocks. + + Returns: + ModuleSpec for FluxSingleTransformerBlock with FluxSingleAttention. + """ + return ModuleSpec( + module=FluxSingleTransformerBlock, + submodules=TransformerLayerSubmodules( + self_attention=ModuleSpec( + module=FluxSingleAttention, + params={"attn_mask_type": AttnMaskType.no_mask}, + submodules=SelfAttentionSubmodules( + linear_qkv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + q_layernorm=TENorm, + k_layernorm=TENorm, + linear_proj=TERowParallelLinear, + ), + ), + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear, + linear_fc2=TERowParallelLinear, + ), + ), + ), + ) diff --git a/src/megatron/bridge/diffusion/models/flux/flux_model.py b/src/megatron/bridge/diffusion/models/flux/flux_model.py new file mode 100644 index 0000000000..2e55fef1e6 --- /dev/null +++ b/src/megatron/bridge/diffusion/models/flux/flux_model.py @@ -0,0 +1,399 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FLUX diffusion model implementation with Megatron Core.""" + +from contextlib import nullcontext +from typing import TYPE_CHECKING + +import numpy as np +import torch +from megatron.core import parallel_state +from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding +from megatron.core.models.common.vision_module.vision_module import VisionModule +from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.utils import sharded_state_dict_default +from megatron.core.utils import make_sharded_tensor_for_checkpoint +from torch import nn + +from megatron.bridge.diffusion.models.flux.flux_layer_spec import ( + AdaLNContinuous, + FluxSingleTransformerBlock, + MMDiTLayer, + get_flux_double_transformer_engine_spec, + get_flux_single_transformer_engine_spec, +) +from megatron.bridge.diffusion.models.flux.layers import EmbedND, MLPEmbedder, TimeStepEmbedder + + +if TYPE_CHECKING: + pass + + +class Flux(VisionModule): + """ + FLUX diffusion model implementation with Megatron Core. + + FLUX is a state-of-the-art text-to-image diffusion model that uses + a combination of double (MMDiT-style) and single transformer blocks. + + Args: + config: FluxProvider containing model hyperparameters. + + Attributes: + out_channels: Number of output channels. + hidden_size: Hidden dimension size. + num_attention_heads: Number of attention heads. + patch_size: Patch size for image embedding. + in_channels: Number of input channels. + guidance_embed: Whether guidance embedding is used. + pos_embed: N-dimensional position embedding module. + img_embed: Image embedding linear layer. + txt_embed: Text embedding linear layer. + timestep_embedding: Timestep embedding module. + vector_embedding: Vector (CLIP pooled) embedding module. + guidance_embedding: Guidance embedding module (if guidance_embed=True). + double_blocks: List of MMDiT layers for double blocks. + single_blocks: List of single transformer blocks. + norm_out: Output normalization layer. + proj_out: Output projection layer. + """ + + def __init__( + self, + config: TransformerConfig, + pre_process: bool = True, + post_process: bool = True, + fp16_lm_cross_entropy: bool = False, + parallel_output: bool = True, + **kwargs, + ): + super(Flux, self).__init__(config=config) + + self.config: TransformerConfig = config + self.pre_process = pre_process + self.post_process = post_process + self.fp16_lm_cross_entropy = fp16_lm_cross_entropy + self.parallel_output = parallel_output + + # megatron core pipelining currently depends on model type + # TODO: remove this dependency ? + self.model_type = ModelType.encoder_or_decoder + + self.out_channels = config.in_channels + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.patch_size = config.patch_size + self.in_channels = config.in_channels + self.guidance_embed = config.guidance_embed + + # Position embedding for rotary embeddings + self.pos_embed = EmbedND(dim=self.hidden_size, theta=10000, axes_dim=config.axes_dims_rope) + + # Input embeddings + self.img_embed = nn.Linear(config.in_channels, self.hidden_size) + self.txt_embed = nn.Linear(config.context_dim, self.hidden_size) + + # Timestep and conditioning embeddings + self.timestep_embedding = TimeStepEmbedder(config.model_channels, self.hidden_size) + self.vector_embedding = MLPEmbedder(in_dim=config.vec_in_dim, hidden_dim=self.hidden_size) + + # Optional guidance embedding (for FLUX-dev) + if config.guidance_embed: + self.guidance_embedding = MLPEmbedder(in_dim=config.model_channels, hidden_dim=self.hidden_size) + + # Double blocks (MMDiT-style joint attention) + self.double_blocks = nn.ModuleList( + [ + MMDiTLayer( + config=config, + submodules=get_flux_double_transformer_engine_spec().submodules, + layer_number=i, + context_pre_only=False, + ) + for i in range(config.num_joint_layers) + ] + ) + + # Single blocks + self.single_blocks = nn.ModuleList( + [ + FluxSingleTransformerBlock( + config=config, + submodules=get_flux_single_transformer_engine_spec().submodules, + layer_number=i, + ) + for i in range(config.num_single_layers) + ] + ) + + # Output layers + self.norm_out = AdaLNContinuous(config=config, conditioning_embedding_dim=self.hidden_size) + self.proj_out = nn.Linear(self.hidden_size, self.patch_size * self.patch_size * self.out_channels, bias=True) + + def get_fp8_context(self): + """Get FP8 autocast context if FP8 is enabled.""" + if not self.config.fp8: + fp8_context = nullcontext() + else: + # Import TE dependencies only when training in fp8 + from transformer_engine.common.recipe import ( + DelayedScaling, + Float8BlockScaling, + Float8CurrentScaling, + Format, + MXFP8BlockScaling, + ) + from transformer_engine.pytorch import fp8_autocast + + if self.config.fp8 == "e4m3": + fp8_format = Format.E4M3 + elif self.config.fp8 == "hybrid": + fp8_format = Format.HYBRID + else: + raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.") + + # Defaults to delayed scaling for backward compatibility + if not self.config.fp8_recipe: + self.config.fp8_recipe = "delayed" + + if self.config.fp8_recipe == "delayed": + fp8_recipe = DelayedScaling( + margin=self.config.fp8_margin, + interval=self.config.fp8_interval, + fp8_format=fp8_format, + amax_compute_algo=self.config.fp8_amax_compute_algo, + amax_history_len=self.config.fp8_amax_history_len, + override_linear_precision=(False, False, not self.config.fp8_wgrad), + ) + elif self.config.fp8_recipe == "current": + fp8_recipe = Float8CurrentScaling(fp8_format=fp8_format) + elif self.config.fp8_recipe == "block": + fp8_recipe = Float8BlockScaling(fp8_format=fp8_format) + elif self.config.fp8_recipe == "mxfp8": + fp8_recipe = MXFP8BlockScaling(fp8_format=fp8_format) + else: + raise ValueError(f"Unsupported FP8 recipe: {self.config.fp8_recipe}") + + fp8_group = None + if parallel_state.model_parallel_is_initialized(): + fp8_group = parallel_state.get_amax_reduction_group(with_context_parallel=True) + fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group) + return fp8_context + + def forward( + self, + img: torch.Tensor, + txt: torch.Tensor = None, + y: torch.Tensor = None, + timesteps: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + controlnet_double_block_samples: torch.Tensor = None, + controlnet_single_block_samples: torch.Tensor = None, + ): + """ + Forward pass through the FLUX model. + + Args: + img: Image input tensor (latents) [B, S, C]. + txt: Text input tensor (text embeddings) [B, S, D]. + y: Vector input for embedding (CLIP pooled output) [B, D]. + timesteps: Timestep input tensor [B]. + img_ids: Image position IDs for rotary embedding [B, S, 3]. + txt_ids: Text position IDs for rotary embedding [B, S, 3]. + guidance: Guidance input for conditioning (FLUX-dev) [B]. + controlnet_double_block_samples: Optional controlnet samples for double blocks. + controlnet_single_block_samples: Optional controlnet samples for single blocks. + + Returns: + Output tensor of shape [B, S, out_channels]. + """ + # Embed image and text + hidden_states = self.img_embed(img) + encoder_hidden_states = self.txt_embed(txt) + + # Timestep embedding + timesteps = timesteps.to(img.dtype) * 1000 + vec_emb = self.timestep_embedding(timesteps) + + # Optional guidance embedding + if guidance is not None: + vec_emb = vec_emb + self.guidance_embedding(self.timestep_embedding.time_proj(guidance * 1000)) + + # Add vector (CLIP pooled) embedding + vec_emb = vec_emb + self.vector_embedding(y) + + # Compute rotary position embeddings + ids = torch.cat((txt_ids, img_ids), dim=1) + rotary_pos_emb = self.pos_embed(ids) + + # Process through double blocks (MMDiT) + for id_block, block in enumerate(self.double_blocks): + with self.get_fp8_context(): + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + rotary_pos_emb=rotary_pos_emb, + emb=vec_emb, + ) + + # Apply controlnet residuals if provided + if controlnet_double_block_samples is not None: + interval_control = len(self.double_blocks) / len(controlnet_double_block_samples) + interval_control = int(np.ceil(interval_control)) + hidden_states = hidden_states + controlnet_double_block_samples[id_block // interval_control] + + # Concatenate encoder and image hidden states for single blocks + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=0) + + # Process through single blocks + for id_block, block in enumerate(self.single_blocks): + with self.get_fp8_context(): + hidden_states, _ = block( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + emb=vec_emb, + ) + + # Apply controlnet residuals if provided + if controlnet_single_block_samples is not None: + interval_control = len(self.single_blocks) / len(controlnet_single_block_samples) + interval_control = int(np.ceil(interval_control)) + hidden_states = torch.cat( + [ + hidden_states[: encoder_hidden_states.shape[0]], + hidden_states[encoder_hidden_states.shape[0] :] + + controlnet_single_block_samples[id_block // interval_control], + ] + ) + + # Extract image hidden states (remove text portion) + hidden_states = hidden_states[encoder_hidden_states.shape[0] :, ...] + + # Output normalization and projection + hidden_states = self.norm_out(hidden_states, vec_emb) + output = self.proj_out(hidden_states) + + return output + + def sharded_state_dict(self, prefix="", sharded_offsets: tuple = (), metadata: dict = None) -> ShardedStateDict: + """ + Get sharded state dict for distributed checkpointing. + + Args: + prefix: Prefix for state dict keys. + sharded_offsets: Sharded offsets tuple. + metadata: Additional metadata. + + Returns: + ShardedStateDict for the model. + """ + sharded_state_dict = {} + + # Handle double blocks + layer_prefix = f"{prefix}double_blocks." + for layer in self.double_blocks: + offset = layer._get_layer_offset(self.config) + + global_layer_offset = layer.layer_number + state_dict_prefix = f"{layer_prefix}{global_layer_offset - offset}." + sharded_prefix = f"{layer_prefix}{global_layer_offset}." + sharded_pp_offset = [] + + layer_sharded_state_dict = layer.sharded_state_dict(state_dict_prefix, sharded_pp_offset, metadata) + replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix) + + sharded_state_dict.update(layer_sharded_state_dict) + + # Handle single blocks + layer_prefix = f"{prefix}single_blocks." + for layer in self.single_blocks: + offset = layer._get_layer_offset(self.config) + + global_layer_offset = layer.layer_number + state_dict_prefix = f"{layer_prefix}{global_layer_offset - offset}." + sharded_prefix = f"{layer_prefix}{global_layer_offset}." + sharded_pp_offset = [] + + layer_sharded_state_dict = layer.sharded_state_dict(state_dict_prefix, sharded_pp_offset, metadata) + replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix) + + sharded_state_dict.update(layer_sharded_state_dict) + + # Handle other modules + for name, module in self.named_children(): + if not (module is self.single_blocks or module is self.double_blocks): + sharded_state_dict.update( + sharded_state_dict_default(module, f"{prefix}{name}.", sharded_offsets, metadata) + ) + + # Set replica IDs for embedding and output layers + # These layers are replicated across tensor parallel ranks and need proper replica IDs + replica_modules = ["img_embed", "txt_embed", "timestep_embedding", "vector_embedding", "proj_out"] + if self.guidance_embed: + replica_modules.append("guidance_embedding") + + for module_name in replica_modules: + if hasattr(self, module_name): + module = getattr(self, module_name) + for param_name, param in module.named_parameters(): + weight_key = f"{prefix}{module_name}.{param_name}" + if weight_key in sharded_state_dict: + self._set_embedder_weights_replica_id(param, sharded_state_dict, weight_key) + + return sharded_state_dict + + def _set_embedder_weights_replica_id( + self, tensor: torch.Tensor, sharded_state_dict: ShardedStateDict, embedder_weight_key: str + ) -> None: + """Set replica IDs of the weights in embedding layers for sharded state dict. + + Args: + tensor: The parameter tensor to set replica ID for. + sharded_state_dict: State dict with the weight to tie. + embedder_weight_key: Key of the weight in the state dict. + + Returns: + None, acts in-place. + """ + tp_rank = parallel_state.get_tensor_model_parallel_rank() + vpp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank() + vpp_rank = vpp_rank if vpp_rank else 0 + vpp_world = parallel_state.get_virtual_pipeline_model_parallel_world_size() + vpp_world = vpp_world if vpp_world else 1 + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + + # Remove the existing entry and replace with properly configured sharded tensor + del sharded_state_dict[embedder_weight_key] + + replica_id = ( + tp_rank, + (vpp_rank + pp_rank * vpp_world), + parallel_state.get_data_parallel_rank(with_context_parallel=True), + ) + + sharded_state_dict[embedder_weight_key] = make_sharded_tensor_for_checkpoint( + tensor=tensor, + key=embedder_weight_key, + replica_id=replica_id, + allow_shape_mismatch=False, + ) + + def set_input_tensor(self, input_tensor): + """Set input tensor for pipeline parallelism.""" + pass diff --git a/src/megatron/bridge/diffusion/models/flux/flux_provider.py b/src/megatron/bridge/diffusion/models/flux/flux_provider.py new file mode 100644 index 0000000000..c1d8073964 --- /dev/null +++ b/src/megatron/bridge/diffusion/models/flux/flux_provider.py @@ -0,0 +1,134 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +from dataclasses import dataclass, field +from typing import Callable, List, Optional + +from megatron.core.models.common.vision_module.vision_module import VisionModule +from megatron.core.transformer.utils import openai_gelu + +from megatron.bridge.diffusion.models.flux.flux_model import Flux +from megatron.bridge.models.model_provider import ModelProviderMixin +from megatron.bridge.models.transformer_config import TransformerConfig + + +logger = logging.getLogger(__name__) + + +@dataclass +class FluxProvider(TransformerConfig, ModelProviderMixin[VisionModule]): + """ + FLUX model provider configuration. + + Extends TransformerConfig with FLUX-specific parameters and provides + model instantiation through the ModelProviderMixin interface. + + Attributes: + num_layers: Dummy setting (required by base class). + num_joint_layers: Number of double (joint) transformer blocks. + num_single_layers: Number of single transformer blocks. + hidden_size: Hidden dimension size. + num_attention_heads: Number of attention heads. + activation_func: Activation function to use. + add_qkv_bias: Whether to add bias to QKV projections. + in_channels: Number of input channels (latent channels). + context_dim: Text encoder context dimension. + model_channels: Model channel dimension for timestep embedding. + patch_size: Patch size for image embedding. + guidance_embed: Whether to use guidance embedding (for FLUX-dev). + vec_in_dim: Vector input dimension (CLIP pooled output dim). + rotary_interleaved: Whether to use interleaved rotary embeddings. + apply_rope_fusion: Whether to apply RoPE fusion. + guidance_scale: Classifier-free guidance scale. + ckpt_path: Path to checkpoint for loading weights. + load_dist_ckpt: Whether to load distributed checkpoint. + do_convert_from_hf: Whether to convert from HuggingFace format. + save_converted_model_to: Path to save converted model. + """ + + # Base class requirements + num_layers: int = 1 # Dummy setting + hidden_size: int = 3072 + ffn_hidden_size: int = 12288 + num_attention_heads: int = 24 + layernorm_epsilon: float = 1e-06 + hidden_dropout: float = 0 + attention_dropout: float = 0 + + # FLUX-specific layer configuration + num_joint_layers: int = 19 + num_single_layers: int = 38 + + # Model architecture + activation_func: Callable = openai_gelu + add_qkv_bias: bool = True + in_channels: int = 64 + context_dim: int = 4096 + model_channels: int = 256 + axes_dims_rope: List[int] = field(default_factory=lambda: [16, 56, 56]) + patch_size: int = 1 + guidance_embed: bool = False + vec_in_dim: int = 768 + + # Rotary embedding settings + rotary_interleaved: bool = True + apply_rope_fusion: bool = False + + # Initialization and performance settings + use_cpu_initialization: bool = True + gradient_accumulation_fusion: bool = False + enable_cuda_graph: bool = False + cuda_graph_scope: Optional[str] = None # full, full_iteration + use_te_rng_tracker: bool = False + cuda_graph_warmup_steps: int = 2 + + # Inference settings + guidance_scale: float = 3.5 + + # Checkpoint loading settings + ckpt_path: Optional[str] = None + load_dist_ckpt: bool = False + do_convert_from_hf: bool = False + save_converted_model_to: Optional[str] = None + + # these attributes are unused for images/videos, we just set because bridge training requires for LLMs + seq_length: int = 1024 + share_embeddings_and_output_weights: bool = False + vocab_size: int = 25256 * 8 + make_vocab_size_divisible_by: int = 128 + + def provide(self, pre_process=None, post_process=None, vp_stage=None) -> Flux: + """ + Create and return a Flux model with this configuration. + + Args: + pre_process: Whether this is the first pipeline stage (unused for Flux). + post_process: Whether this is the last pipeline stage (unused for Flux). + vp_stage: Virtual pipeline stage (unused for Flux). + + Returns: + Configured Flux model instance. + """ + vp_size = self.virtual_pipeline_model_parallel_size + if vp_size: + p_size = self.pipeline_model_parallel_size + total_layers = self.num_joint_layers + self.num_single_layers + assert (total_layers // p_size) % vp_size == 0, ( + "Make sure the number of model chunks is the same across all pipeline stages." + ) + + model = Flux(config=self) + return model diff --git a/src/megatron/bridge/diffusion/models/flux/flux_step.py b/src/megatron/bridge/diffusion/models/flux/flux_step.py new file mode 100644 index 0000000000..f9de829bb7 --- /dev/null +++ b/src/megatron/bridge/diffusion/models/flux/flux_step.py @@ -0,0 +1,432 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import math +from functools import lru_cache, partial +from typing import Iterable + +import torch +from megatron.core import parallel_state +from megatron.core.models.common.vision_module.vision_module import VisionModule +from megatron.core.utils import get_model_config + +from megatron.bridge.diffusion.models.flux.flow_matching.flux_inference_pipeline import FlowMatchEulerDiscreteScheduler +from megatron.bridge.training.losses import masked_next_token_loss +from megatron.bridge.training.state import GlobalState + + +logger = logging.getLogger(__name__) + + +def flux_data_step(dataloader_iter, store_in_state=False): + """Process batch data for FLUX model. + + Args: + dataloader_iter: Iterator over the dataloader. + store_in_state: If True, store the batch in GlobalState for callbacks. + + Returns: + Processed batch dictionary with tensors moved to CUDA. + """ + batch = next(dataloader_iter) + if isinstance(batch, tuple) and len(batch) == 3: + _batch = batch[0] + else: + _batch = batch + + _batch = {k: v.to(device="cuda", non_blocking=True) if torch.is_tensor(v) else v for k, v in _batch.items()} + + if "loss_mask" not in _batch or _batch["loss_mask"] is None: + _batch["loss_mask"] = torch.ones(1, device="cuda") + + # Store batch in state for callbacks (e.g., validation image generation) + if store_in_state: + try: + from megatron.bridge.training.pretrain import get_current_state + + state = get_current_state() + state._last_validation_batch = _batch + except: + pass # If state access fails, silently continue + + return _batch + + +class FluxForwardStep: + """Forward step for FLUX diffusion model training. + + This class handles the forward pass during training, including: + - Timestep sampling using flow matching + - Noise injection with latent packing + - Model prediction + - Loss computation + + Args: + timestep_sampling: Method for sampling timesteps ("logit_normal", "uniform", "mode"). + logit_mean: Mean for logit-normal sampling. + logit_std: Standard deviation for logit-normal sampling. + mode_scale: Scale for mode sampling. + scheduler_steps: Number of scheduler training steps. + guidance_scale: Guidance scale for FLUX-dev models. + """ + + def __init__( + self, + timestep_sampling: str = "logit_normal", + logit_mean: float = 0.0, + logit_std: float = 1.0, + mode_scale: float = 1.29, + scheduler_steps: int = 1000, + guidance_scale: float = 3.5, + ): + self.timestep_sampling = timestep_sampling + self.logit_mean = logit_mean + self.logit_std = logit_std + self.mode_scale = mode_scale + self.scheduler_steps = scheduler_steps + self.guidance_scale = guidance_scale + self.autocast_dtype = torch.bfloat16 + # Initialize scheduler for timestep/sigma computations + self.scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=scheduler_steps) + + def __call__( + self, state: GlobalState, data_iterator: Iterable, model: VisionModule + ) -> tuple[torch.Tensor, partial]: + """Forward training step. + + Args: + state: Global state for the run. + data_iterator: Input data iterator. + model: The FLUX model. + + Returns: + Tuple containing the output tensor and the loss function. + """ + timers = state.timers + straggler_timer = state.straggler_timer + + config = get_model_config(model) + + timers("batch-generator", log_level=2).start() + + with straggler_timer(bdata=True): + batch = flux_data_step(data_iterator) + # Store batch for validation callbacks (only during evaluation) + if not torch.is_grad_enabled(): + state._last_batch = batch + timers("batch-generator").stop() + + check_for_nan_in_loss = state.cfg.rerun_state_machine.check_for_nan_in_loss + check_for_spiky_loss = state.cfg.rerun_state_machine.check_for_spiky_loss + + # Run diffusion training step + with straggler_timer: + if parallel_state.is_pipeline_last_stage(): + output_tensor, loss, loss_mask = self._training_step(model, batch, config) + batch["loss_mask"] = loss_mask + else: + output_tensor = self._training_step(model, batch, config) + + loss = output_tensor + if "loss_mask" not in batch or batch["loss_mask"] is None: + loss_mask = torch.ones_like(loss) + else: + loss_mask = batch["loss_mask"] + + loss_function = self._create_loss_function(loss_mask, check_for_nan_in_loss, check_for_spiky_loss) + + return output_tensor, loss_function + + def _training_step( + self, model: VisionModule, batch: dict, config + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | torch.Tensor: + """Perform single training step with flow matching. + + Args: + model: The FLUX model. + batch: Data batch containing latents and text embeddings. + config: Model configuration. + + Returns: + On last pipeline stage: tuple of (output_tensor, loss, loss_mask). + On other stages: hidden_states tensor. + """ + # Get latents from batch - expected in [B, C, H, W] format + if "latents" in batch: + latents = batch["latents"] + else: + raise ValueError("Expected 'latents' in batch. VAE encoding should be done in data preprocessing.") + + # Prepare image latents with flow matching noise + ( + latents, + noise, + packed_noisy_model_input, + latent_image_ids, + guidance_vec, + timesteps, + ) = self.prepare_image_latent(latents, model) + + # Get text embeddings (precached) + if "prompt_embeds" in batch: + prompt_embeds = batch["prompt_embeds"].transpose(0, 1) + pooled_prompt_embeds = batch["pooled_prompt_embeds"] + text_ids = batch["text_ids"] + else: + raise ValueError("Expected precached text embeddings in batch.") + + # Forward pass + with torch.amp.autocast( + "cuda", enabled=self.autocast_dtype in (torch.half, torch.bfloat16), dtype=self.autocast_dtype + ): + noise_pred = model( + img=packed_noisy_model_input, + txt=prompt_embeds, + y=pooled_prompt_embeds, + timesteps=timesteps / 1000, + img_ids=latent_image_ids, + txt_ids=text_ids, + guidance=guidance_vec, + ) + + # Unpack predictions for loss computation + noise_pred = self._unpack_latents( + noise_pred.transpose(0, 1), + latents.shape[2], + latents.shape[3], + ).transpose(0, 1) + + # Flow matching target: v = noise - latents (velocity formulation) + target = noise - latents + + # MSE loss + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + output_tensor = torch.mean(loss, dim=-1) + + # Create loss mask (all ones for now) + loss_mask = torch.ones_like(output_tensor) + + return output_tensor, loss, loss_mask + + # else: + # hidden_states = model( + # img=packed_noisy_model_input, + # txt=prompt_embeds, + # y=pooled_prompt_embeds, + # timesteps=timesteps / 1000, + # img_ids=latent_image_ids, + # txt_ids=text_ids, + # guidance=guidance_vec, + # ) + # return hidden_states + + def prepare_image_latent(self, latents: torch.Tensor, model: VisionModule): + """Prepare image latents with flow matching noise. + + Args: + latents: Input latent tensor [B, C, H, W]. + model: The FLUX model (for guidance_embed config). + + Returns: + Tuple of (latents, noise, packed_noisy_input, latent_image_ids, guidance, timesteps). + """ + latent_image_ids = self._prepare_latent_image_ids( + latents.shape[0], + latents.shape[2], + latents.shape[3], + latents.device, + latents.dtype, + ) + + noise = torch.randn_like(latents, device=latents.device, dtype=latents.dtype) + batch_size = latents.shape[0] + u = self.compute_density_for_timestep_sampling( + self.timestep_sampling, + batch_size, + ) + indices = (u * self.scheduler.num_train_timesteps).long() + timesteps = self.scheduler.timesteps[indices].to(device=latents.device) + + sigmas = self.scheduler.sigmas.to(device=latents.device, dtype=latents.dtype) + scheduler_timesteps = self.scheduler.timesteps.to(device=latents.device) + step_indices = [(scheduler_timesteps == t).nonzero().item() for t in timesteps] + timesteps = timesteps.to(dtype=latents.dtype) + sigma = sigmas[step_indices].flatten() + + while len(sigma.shape) < latents.ndim: + sigma = sigma.unsqueeze(-1) + + noisy_model_input = (1.0 - sigma) * latents + sigma * noise + packed_noisy_model_input = self._pack_latents( + noisy_model_input, + batch_size=latents.shape[0], + num_channels_latents=latents.shape[1], + height=latents.shape[2], + width=latents.shape[3], + ) + + # Guidance embedding (for FLUX-dev) + if hasattr(model, "guidance_embed") and model.guidance_embed: + guidance_vec = torch.full( + (noisy_model_input.shape[0],), + self.guidance_scale, + device=latents.device, + dtype=latents.dtype, + ) + else: + guidance_vec = None + + return ( + latents.transpose(0, 1), + noise.transpose(0, 1), + packed_noisy_model_input.transpose(0, 1), + latent_image_ids, + guidance_vec, + timesteps, + ) + + def compute_density_for_timestep_sampling( + self, + weighting_scheme: str, + batch_size: int, + logit_mean: float = None, + logit_std: float = None, + mode_scale: float = None, + ) -> torch.Tensor: + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + + Args: + weighting_scheme: Sampling scheme ("logit_normal", "mode", or "uniform"). + batch_size: Number of samples in batch. + logit_mean: Mean for logit-normal sampling. + logit_std: Standard deviation for logit-normal sampling. + mode_scale: Scale for mode sampling. + + Returns: + Tensor of sampled u values in [0, 1]. + """ + # Use instance defaults if not provided + logit_mean = logit_mean if logit_mean is not None else self.logit_mean + logit_std = logit_std if logit_std is not None else self.logit_std + mode_scale = mode_scale if mode_scale is not None else self.mode_scale + + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$) + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + @lru_cache + def _prepare_latent_image_ids( + self, batch_size: int, height: int, width: int, device: torch.device, dtype: torch.dtype + ) -> torch.Tensor: + """Prepare latent image IDs for positional encoding. + + Args: + batch_size: Number of samples. + height: Latent height. + width: Latent width. + device: Target device. + dtype: Target dtype. + + Returns: + Tensor of shape [B, (H/2)*(W/2), 3] with position IDs. + """ + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) + latent_image_ids = latent_image_ids.reshape( + batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype, non_blocking=True) + + def _pack_latents( + self, latents: torch.Tensor, batch_size: int, num_channels_latents: int, height: int, width: int + ) -> torch.Tensor: + """Pack latents for FLUX processing. + + Rearranges [B, C, H, W] -> [B, (H/2)*(W/2), C*4]. + + Args: + latents: Input tensor [B, C, H, W]. + batch_size: Batch size. + num_channels_latents: Number of latent channels. + height: Latent height. + width: Latent width. + + Returns: + Packed tensor [B, num_patches, C*4]. + """ + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + return latents + + def _unpack_latents(self, latents: torch.Tensor, height: int, width: int) -> torch.Tensor: + """Unpack latents from FLUX format. + + Rearranges [B, num_patches, C*4] -> [B, C, H, W]. + + Args: + latents: Packed tensor [B, num_patches, C*4]. + height: Target height. + width: Target width. + + Returns: + Unpacked tensor [B, C, H, W]. + """ + batch_size, num_patches, channels = latents.shape + + # Adjust h and w for patching + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // 4, height, width) + + return latents + + def _create_loss_function( + self, loss_mask: torch.Tensor, check_for_nan_in_loss: bool, check_for_spiky_loss: bool + ) -> partial: + """Create a partial loss function with the specified configuration. + + Args: + loss_mask: Used to mask out some portions of the loss. + check_for_nan_in_loss: Whether to check for NaN values in the loss. + check_for_spiky_loss: Whether to check for spiky loss values. + + Returns: + A partial function that can be called with output_tensor to compute the loss. + """ + return partial( + masked_next_token_loss, + loss_mask, + check_for_nan_in_loss=check_for_nan_in_loss, + check_for_spiky_loss=check_for_spiky_loss, + ) diff --git a/src/megatron/bridge/diffusion/models/flux/flux_step_with_automodel.py b/src/megatron/bridge/diffusion/models/flux/flux_step_with_automodel.py new file mode 100644 index 0000000000..3e85a4a2cf --- /dev/null +++ b/src/megatron/bridge/diffusion/models/flux/flux_step_with_automodel.py @@ -0,0 +1,333 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +FLUX Forward Step with Automodel Pipeline Integration. + +This is a prototype showing how to integrate the automodel FlowMatchingPipeline +into Megatron's training flow, reusing the well-tested flow matching logic. +""" + +import logging +from functools import partial +from typing import Iterable + +import torch + +# Import automodel pipeline components +from dfm.src.automodel.flow_matching.flow_matching_pipeline import FlowMatchingPipeline +from megatron.core import parallel_state +from megatron.core.models.common.vision_module.vision_module import VisionModule +from megatron.core.utils import get_model_config + +# Import MegatronFluxAdapter from flow_matching module +from megatron.bridge.diffusion.models.flux.flow_matching.flux_adapter import MegatronFluxAdapter +from megatron.bridge.training.losses import masked_next_token_loss +from megatron.bridge.training.state import GlobalState + + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Megatron Forward Step with Automodel Pipeline +# ============================================================================= + + +def flux_data_step(dataloader_iter, store_in_state=False): + """Process batch data for FLUX model. + + Args: + dataloader_iter: Iterator over the dataloader. + store_in_state: If True, store the batch in GlobalState for callbacks. + + Returns: + Processed batch dictionary with tensors moved to CUDA. + """ + batch = next(dataloader_iter) + if isinstance(batch, tuple) and len(batch) == 3: + _batch = batch[0] + else: + _batch = batch + + _batch = {k: v.to(device="cuda", non_blocking=True) if torch.is_tensor(v) else v for k, v in _batch.items()} + + if "loss_mask" not in _batch or _batch["loss_mask"] is None: + _batch["loss_mask"] = torch.ones(1, device="cuda") + + # Store batch in state for callbacks (e.g., validation image generation) + if store_in_state: + try: + from megatron.bridge.training.pretrain import get_current_state + + state = get_current_state() + state._last_validation_batch = _batch + except: + pass # If state access fails, silently continue + + return _batch + + +class FluxForwardStepWithAutomodel: + """ + Forward step for FLUX using the automodel FlowMatchingPipeline. + + This class demonstrates how to integrate the well-tested automodel pipeline + into Megatron's training flow, gaining benefits like: + - Unified flow matching implementation + - Better logging and debugging + - Consistent timestep sampling across models + - Easier maintenance + + Args: + timestep_sampling: Method for sampling timesteps ("logit_normal", "uniform", "mode"). + logit_mean: Mean for logit-normal sampling. + logit_std: Standard deviation for logit-normal sampling. + flow_shift: Shift parameter for timestep transformation (default: 1.0 for FLUX). + scheduler_steps: Number of scheduler training steps. + guidance_scale: Guidance scale for FLUX-dev models. + use_loss_weighting: Whether to apply flow-based loss weighting. + """ + + def __init__( + self, + timestep_sampling: str = "logit_normal", + logit_mean: float = 0.0, + logit_std: float = 1.0, + flow_shift: float = 1.0, # FLUX uses shift=1.0 typically + scheduler_steps: int = 1000, + guidance_scale: float = 3.5, + use_loss_weighting: bool = False, # FLUX typically doesn't use loss weighting + ): + self.autocast_dtype = torch.bfloat16 + + # Create the automodel pipeline with Megatron adapter + adapter = MegatronFluxAdapter(guidance_scale=guidance_scale) + + self.pipeline = FlowMatchingPipeline( + model_adapter=adapter, + num_train_timesteps=scheduler_steps, + timestep_sampling=timestep_sampling, + flow_shift=flow_shift, + logit_mean=logit_mean, + logit_std=logit_std, + sigma_min=0.0, + sigma_max=1.0, + use_loss_weighting=use_loss_weighting, + cfg_dropout_prob=0.0, # No CFG dropout in Megatron training + log_interval=100, + summary_log_interval=10, + ) + + logger.info( + f"FluxForwardStepWithAutomodel initialized with:\n" + f" - Timestep sampling: {timestep_sampling}\n" + f" - Flow shift: {flow_shift}\n" + f" - Guidance scale: {guidance_scale}\n" + f" - Loss weighting: {use_loss_weighting}" + ) + + def __call__( + self, state: GlobalState, data_iterator: Iterable, model: VisionModule + ) -> tuple[torch.Tensor, partial]: + """ + Forward training step using automodel pipeline. + + Args: + state: Global state for the run. + data_iterator: Input data iterator. + model: The FLUX model. + + Returns: + Tuple containing the output tensor and the loss function. + """ + timers = state.timers + straggler_timer = state.straggler_timer + + config = get_model_config(model) # noqa: F841 + + timers("batch-generator", log_level=2).start() + + with straggler_timer(bdata=True): + batch = flux_data_step(data_iterator) + # Store batch for validation callbacks (only during evaluation) + if not torch.is_grad_enabled(): + state._last_batch = batch + timers("batch-generator").stop() + + check_for_nan_in_loss = state.cfg.rerun_state_machine.check_for_nan_in_loss + check_for_spiky_loss = state.cfg.rerun_state_machine.check_for_spiky_loss + + # Prepare batch for automodel pipeline + # Map Megatron keys to automodel expected keys + pipeline_batch = self._prepare_batch_for_pipeline(batch) + + # Run the pipeline step + with straggler_timer: + if parallel_state.is_pipeline_last_stage(): + output_tensor, loss, loss_mask = self._training_step_with_pipeline(model, pipeline_batch) + # loss_mask is already created correctly in _training_step_with_pipeline + batch["loss_mask"] = loss_mask + else: + # For non-final pipeline stages, we still need to run the model + # but loss computation happens only on the last stage + output_tensor = self._training_step_with_pipeline(model, pipeline_batch) + loss_mask = None + + # Use the loss_mask from training step (already has correct shape) + if loss_mask is None: + # This should only happen for non-final pipeline stages + loss_mask = torch.ones(1, device="cuda") + + loss_function = self._create_loss_function(loss_mask, check_for_nan_in_loss, check_for_spiky_loss) + + return output_tensor, loss_function + + def _prepare_batch_for_pipeline(self, batch: dict) -> dict: + """ + Prepare Megatron batch for automodel pipeline. + + Maps Megatron batch keys to automodel expected format: + - latents -> image_latents (for consistency) + - Keeps prompt_embeds, pooled_prompt_embeds, text_ids as-is + """ + pipeline_batch = { + "image_latents": batch["latents"], # Map to automodel expected key + "prompt_embeds": batch.get("prompt_embeds"), + "pooled_prompt_embeds": batch.get("pooled_prompt_embeds"), + "text_ids": batch.get("text_ids"), + "data_type": "image", # FLUX is for image generation + } + + # Copy any additional keys + for key in batch: + if key not in pipeline_batch and key != "latents": + pipeline_batch[key] = batch[key] + + return pipeline_batch + + def _training_step_with_pipeline( + self, model: VisionModule, batch: dict + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | torch.Tensor: + """ + Perform single training step using automodel pipeline. + + Args: + model: The FLUX model. + batch: Data batch prepared for pipeline. + + Returns: + On last pipeline stage: tuple of (output_tensor, loss, loss_mask). + On other stages: output tensor. + """ + device = torch.device("cuda") + dtype = self.autocast_dtype + + # Pass model in batch so adapter can check for guidance support + batch["_model"] = model + + with torch.amp.autocast("cuda", enabled=dtype in (torch.half, torch.bfloat16), dtype=dtype): + # Run the automodel pipeline step (global_step defaults to 0) + weighted_loss, average_weighted_loss, loss_mask, metrics = self.pipeline.step( + model=model, + batch=batch, + device=device, + dtype=dtype, + ) + + # Clean up temporary model reference + batch.pop("_model", None) + + if parallel_state.is_pipeline_last_stage(): + # Match original implementation's reduction pattern + # Original does: loss = mse(..., reduction="none"), then output_tensor = mean(loss, dim=-1) + # This keeps most dimensions and only reduces the last one + # But automodel returns full loss, so we reduce to match expected shape + + # For FLUX with images: weighted_loss is [B, C, H, W] + # Original pattern: mean over spatial dimensions -> [B, C] or similar + # But Megatron expects a 1D tensor per sample, so reduce to [B] + output_tensor = torch.mean(weighted_loss, dim=list(range(1, weighted_loss.ndim))) + + # Always create a fresh loss_mask matching output_tensor shape + # Ignore any loss_mask from batch as it may have incompatible shape + loss_mask = torch.ones_like(output_tensor) + + return output_tensor, average_weighted_loss, loss_mask + else: + # For intermediate stages, return the tensor for pipeline communication + return weighted_loss + + def _create_loss_function( + self, loss_mask: torch.Tensor, check_for_nan_in_loss: bool, check_for_spiky_loss: bool + ) -> partial: + """ + Create a partial loss function with the specified configuration. + + Args: + loss_mask: Used to mask out some portions of the loss. + check_for_nan_in_loss: Whether to check for NaN values in the loss. + check_for_spiky_loss: Whether to check for spiky loss values. + + Returns: + A partial function that can be called with output_tensor to compute the loss. + """ + return partial( + masked_next_token_loss, + loss_mask, + check_for_nan_in_loss=check_for_nan_in_loss, + check_for_spiky_loss=check_for_spiky_loss, + ) + + +# ============================================================================= +# Convenience Factory +# ============================================================================= + + +def create_flux_forward_step( + use_automodel_pipeline: bool = True, + **kwargs, +): + """ + Factory function to create either the automodel-based or original forward step. + + Args: + use_automodel_pipeline: If True, use FluxForwardStepWithAutomodel. + If False, use original FluxForwardStep. + **kwargs: Arguments passed to the forward step constructor. + + Returns: + Forward step instance. + + Example: + # Use automodel pipeline + forward_step = create_flux_forward_step( + use_automodel_pipeline=True, + timestep_sampling="logit_normal", + flow_shift=1.0, + ) + + # Use original implementation + forward_step = create_flux_forward_step( + use_automodel_pipeline=False, + timestep_sampling="logit_normal", + ) + """ + if use_automodel_pipeline: + return FluxForwardStepWithAutomodel(**kwargs) + else: + from megatron.bridge.diffusion.models.flux.flux_step import FluxForwardStep + + return FluxForwardStep(**kwargs) diff --git a/src/megatron/bridge/diffusion/models/flux/layers.py b/src/megatron/bridge/diffusion/models/flux/layers.py new file mode 100644 index 0000000000..0e45321939 --- /dev/null +++ b/src/megatron/bridge/diffusion/models/flux/layers.py @@ -0,0 +1,156 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FLUX embedding layers for diffusion models.""" + +from typing import List + +import torch +from diffusers.models.embeddings import Timesteps +from torch import Tensor, nn + + +def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: + """ + Compute rotary position embeddings. + + Different from the original ROPE used for flux. + Megatron attention takes the outer product and calculates sin/cos inside, + so we only need to get the freqs here in the shape of [seq, ..., dim]. + + Args: + pos: Position tensor. + dim: Embedding dimension (must be even). + theta: Base frequency. + + Returns: + Rotary position embeddings of shape [..., dim//2]. + """ + assert dim % 2 == 0, "The dimension must be even." + + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + + out = torch.einsum("...n,d->...nd", pos, omega) + + return out.float() + + +class EmbedND(nn.Module): + """ + N-Dimensional Rotary Position Embedding generator. + + Generate Rope matrix with preset axes dimensions. + + Args: + dim: Total embedding dimension. + theta: Base frequency for rotary embeddings. + axes_dim: List of dimensions for each axis. + """ + + def __init__(self, dim: int, theta: int, axes_dim: List[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + """ + Compute N-dimensional rotary position embeddings. + + Args: + ids: Position IDs tensor of shape [batch, seq, n_axes]. + + Returns: + Rotary embeddings tensor. + """ + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-1, + ) + emb = emb.unsqueeze(1).permute(2, 0, 1, 3) + return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1) + + +class MLPEmbedder(nn.Module): + """ + MLP embedder with two projection layers and SiLU activation. + + Args: + in_dim: Input dimension. + hidden_dim: Hidden/output dimension. + """ + + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass through the MLP embedder.""" + return self.out_layer(self.silu(self.in_layer(x))) + + +class TimeStepEmbedder(nn.Module): + """ + A neural network module that embeds timesteps for use in diffusion models. + + It projects the input timesteps to a higher-dimensional space and then embeds + them using an MLP (Multilayer Perceptron). The projection and embedding provide + a learned representation of the timestep that can be used in further computations. + + Args: + embedding_dim: The dimensionality of the timestep embedding space. + hidden_dim: The dimensionality of the hidden layer in the MLPEmbedder. + flip_sin_to_cos: Whether to flip the sine and cosine components. + downscale_freq_shift: A scaling factor for the frequency shift. + scale: A scaling factor applied to the timestep projections. + max_period: The maximum period for the sine and cosine functions. + """ + + def __init__( + self, + embedding_dim: int, + hidden_dim: int, + flip_sin_to_cos: bool = True, + downscale_freq_shift: float = 0, + scale: float = 1, + max_period: int = 10000, + ): + super().__init__() + + self.time_proj = Timesteps( + num_channels=embedding_dim, + flip_sin_to_cos=flip_sin_to_cos, + downscale_freq_shift=downscale_freq_shift, + scale=scale, + ) + self.time_embedder = MLPEmbedder(in_dim=embedding_dim, hidden_dim=hidden_dim) + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + """ + Compute timestep embeddings. + + Args: + timesteps: Input timestep tensor. + + Returns: + Embedded timesteps tensor. + """ + timesteps_proj = self.time_proj(timesteps) + timesteps_emb = self.time_embedder(timesteps_proj) + + return timesteps_emb diff --git a/src/megatron/bridge/diffusion/models/wan/__init__.py b/src/megatron/bridge/diffusion/models/wan/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/megatron/bridge/diffusion/models/wan/flow_matching/__init__.py b/src/megatron/bridge/diffusion/models/wan/flow_matching/__init__.py new file mode 100644 index 0000000000..d9155f923f --- /dev/null +++ b/src/megatron/bridge/diffusion/models/wan/flow_matching/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/megatron/bridge/diffusion/models/wan/flow_matching/flow_inference_pipeline.py b/src/megatron/bridge/diffusion/models/wan/flow_matching/flow_inference_pipeline.py new file mode 100644 index 0000000000..048d7f9753 --- /dev/null +++ b/src/megatron/bridge/diffusion/models/wan/flow_matching/flow_inference_pipeline.py @@ -0,0 +1,557 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import logging +import math +import os +import random +import re +import sys +from contextlib import contextmanager +from typing import Tuple + +import torch +import torch.cuda.amp as amp +import torch.distributed as dist +from diffusers import AutoencoderKLWan +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler +from megatron.core import parallel_state +from megatron.core.inference.communication_utils import ( + broadcast_from_last_pipeline_stage, + recv_from_prev_pipeline_rank_, + send_to_next_pipeline_rank, +) +from megatron.core.packed_seq_params import PackedSeqParams +from torch.nn import functional as F +from tqdm import tqdm +from transformers import AutoTokenizer, UMT5EncoderModel + +from megatron.bridge.diffusion.models.wan.utils import grid_sizes_calculation, patchify, unpatchify +from megatron.bridge.diffusion.models.wan.wan_provider import WanModelProvider +from megatron.bridge.training.model_load_save import load_megatron_model as _load_megatron_model + + +@torch.no_grad() +def _encode_text( + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + device: str, + caption: str, +) -> torch.Tensor: + caption = caption.strip() + inputs = tokenizer( + caption, + max_length=512, + padding="max_length", + truncation=True, + return_tensors="pt", + return_attention_mask=True, + ) + inputs = {k: v.to(device) for k, v in inputs.items()} + outputs = text_encoder(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]).last_hidden_state + # Trim to the true (unpadded) sequence length using the attention mask + true_len = int(inputs["attention_mask"].sum(dim=-1).item()) + outputs = outputs[0, :true_len, :] + return outputs + + +class FlowInferencePipeline: # noqa: D101 + def __init__( + self, + inference_cfg, + model_id="Wan-AI/Wan2.1-T2V-14B-Diffusers", + checkpoint_dir=None, + checkpoint_step=None, + t5_checkpoint_dir=None, + vae_checkpoint_dir=None, + device_id=0, + rank=0, + t5_cpu=False, + tensor_parallel_size=1, + context_parallel_size=1, + pipeline_parallel_size=1, + sequence_parallel=False, + pipeline_dtype=torch.float32, + ): + r""" + Initializes the FlowInferencePipeline with the given parameters. + + Args: + inference_cfg (dict): + Object containing inference configuration. + checkpoint_dir (`str`): + Path to directory containing model checkpoints + t5_checkpoint_dir (`str`, *optional*, defaults to None): + Optional directory containing T5 checkpoint and tokenizer; falls back to `checkpoint_dir` if None. + vae_checkpoint_dir (`str`, *optional*, defaults to None): + Optional directory containing VAE checkpoint; falls back to `checkpoint_dir` if None. + device_id (`int`, *optional*, defaults to 0): + Id of target GPU device + rank (`int`, *optional*, defaults to 0): + Process rank for distributed training + t5_cpu (`bool`, *optional*, defaults to False): + Whether to place T5 model on CPU. Only works without t5_fsdp. + """ + self.device = torch.device(f"cuda:{device_id}") + self.inference_cfg = inference_cfg + self.model_id = model_id + self.rank = rank + self.t5_cpu = t5_cpu + self.tensor_parallel_size = tensor_parallel_size + self.context_parallel_size = context_parallel_size + self.pipeline_parallel_size = pipeline_parallel_size + self.sequence_parallel = sequence_parallel + self.pipeline_dtype = pipeline_dtype + self.num_train_timesteps = inference_cfg.num_train_timesteps + self.param_dtype = inference_cfg.param_dtype + self.text_len = inference_cfg.text_len + + self.text_encoder = UMT5EncoderModel.from_pretrained( + model_id, + subfolder="text_encoder", + torch_dtype=inference_cfg.t5_dtype, + ) + self.tokenizer = AutoTokenizer.from_pretrained( + model_id, + subfolder="tokenizer", + ) + + self.vae_stride = inference_cfg.vae_stride + self.patch_size = inference_cfg.patch_size + self.vae = AutoencoderKLWan.from_pretrained( + model_id, + subfolder="vae", + torch_dtype=inference_cfg.param_dtype, + ) + self.vae.to(self.device) + + wan_checkpoint_dir = self._select_checkpoint_dir(checkpoint_dir, checkpoint_step) + self.model = self.setup_model_from_checkpoint(wan_checkpoint_dir) + + # if we use context parallelism, we need to set qkv_format to "thd" for context parallelism + self.model.config.qkv_format = "thd" # "sbhd" + + # set self.sp_size=1 for later use, just to respect the original Wan inference code + self.sp_size = 1 + + if dist.is_initialized(): + dist.barrier() + self.model.to(self.device) + + self.sample_neg_prompt = inference_cfg.english_sample_neg_prompt + + def setup_model_from_checkpoint(self, checkpoint_dir): + provider = WanModelProvider() + provider.tensor_model_parallel_size = self.tensor_parallel_size + provider.pipeline_model_parallel_size = self.pipeline_parallel_size + provider.context_parallel_size = self.context_parallel_size + provider.sequence_parallel = self.sequence_parallel + provider.pipeline_dtype = self.pipeline_dtype + # Once all overrides are set, finalize the model provider to ensure the post initialization logic is run + provider.finalize() + provider.initialize_model_parallel(seed=0) + + ## Read from megatron checkpoint + model = _load_megatron_model( + checkpoint_dir, + mp_overrides={ + "tensor_model_parallel_size": self.tensor_parallel_size, + "pipeline_model_parallel_size": self.pipeline_parallel_size, + "context_parallel_size": self.context_parallel_size, + "sequence_parallel": self.sequence_parallel, + "pipeline_dtype": self.pipeline_dtype, + }, + ) + if isinstance(model, list): + model = model[0] + if hasattr(model, "module"): + model = model.module + + return model + + def _select_checkpoint_dir(self, base_dir: str, checkpoint_step) -> str: + """ + Resolve checkpoint directory: + - If checkpoint_step is provided, use base_dir/iter_{step:07d} + - Otherwise, pick the largest iter_######## subdirectory under base_dir + """ + if checkpoint_step is not None: + path = os.path.join(base_dir, f"iter_{int(checkpoint_step):07d}") + if os.path.isdir(path): + logging.info(f"Using specified checkpoint: {path}") + return path + raise FileNotFoundError(f"Specified checkpoint step {checkpoint_step} not found at {path}") + + if not os.path.isdir(base_dir): + raise FileNotFoundError(f"Checkpoint base directory does not exist: {base_dir}") + + pattern = re.compile(r"^iter_(\d+)$") + try: + _, latest_path = max( + ( + (int(pattern.match(e.name).group(1)), e.path) + for e in os.scandir(base_dir) + if e.is_dir() and pattern.match(e.name) + ), + key=lambda x: x[0], + ) + except ValueError: + raise FileNotFoundError( + f"No checkpoints found under {base_dir}. Expected subdirectories named like 'iter_0001800'." + ) + + logging.info(f"Auto-selected latest checkpoint: {latest_path}") + return latest_path + + def forward_pp_step( + self, + latent_model_input: torch.Tensor, + grid_sizes: list[Tuple[int, int, int]], + max_video_seq_len: int, + timestep: torch.Tensor, + arg_c: dict, + ) -> torch.Tensor: + """ + Forward pass supporting pipeline parallelism. + """ + + pp_world_size = parallel_state.get_pipeline_model_parallel_world_size() + # PP=1: no pipeline parallelism (avoid touching PP groups which may be uninitialized in unit tests) + if pp_world_size == 1: + noise_pred_pp = self.model(latent_model_input, grid_sizes=grid_sizes, t=timestep, **arg_c) + return noise_pred_pp + # For PP>1, safe to query stage information + is_pp_first = parallel_state.is_pipeline_first_stage(ignore_virtual=True) + is_pp_last = parallel_state.is_pipeline_last_stage(ignore_virtual=True) + + # PP>1: pipeline parallelism + hidden_size = self.model.config.hidden_size + batch_size = latent_model_input.shape[1] + # noise prediction shape for communication between first and last pipeline stages + noise_pred_pp_shape = list(latent_model_input.shape) + + if is_pp_first: + # First stage: compute multimodal + first PP slice, send activations, then receive sampled token + hidden_states = self.model(latent_model_input, grid_sizes=grid_sizes, t=timestep, **arg_c) + send_to_next_pipeline_rank(hidden_states) + + noise_pred_pp = broadcast_from_last_pipeline_stage(noise_pred_pp_shape, dtype=torch.float32) + return noise_pred_pp + + if is_pp_last: + # Last stage: recv activations, run final slice + output, sample, broadcast + recv_buffer = torch.empty( + (max_video_seq_len, batch_size, hidden_size), + dtype=next(self.model.parameters()).dtype, + device=latent_model_input[0].device, + ) + recv_from_prev_pipeline_rank_(recv_buffer) + recv_buffer = recv_buffer.to(torch.bfloat16) + self.model.set_input_tensor(recv_buffer) + noise_pred_pp = self.model(latent_model_input, grid_sizes=grid_sizes, t=timestep, **arg_c) + + noise_pred_pp = broadcast_from_last_pipeline_stage( + noise_pred_pp_shape, dtype=noise_pred_pp.dtype, tensor=noise_pred_pp.contiguous() + ) + return noise_pred_pp + + # Intermediate stages: recv -> run local slice -> send -> receive broadcast token + recv_buffer = torch.empty( + (max_video_seq_len, batch_size, hidden_size), + dtype=next(self.model.parameters()).dtype, + device=latent_model_input[0].device, + ) + recv_from_prev_pipeline_rank_(recv_buffer) + recv_buffer = recv_buffer.to(torch.bfloat16) + self.model.set_input_tensor(recv_buffer) + hidden_states = self.model(latent_model_input, grid_sizes=grid_sizes, t=timestep, **arg_c) + send_to_next_pipeline_rank(hidden_states) + + noise_pred_pp = broadcast_from_last_pipeline_stage(noise_pred_pp_shape, dtype=torch.float32) + return noise_pred_pp + + def generate( + self, + prompts, + sizes, + frame_nums, + shift=5.0, + sampling_steps=50, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True, + ): + r""" + Generates video frames from text prompt using diffusion process. + + Args: + prompts (`list[str]`): + Text prompt for content generation + sizes (list[tuple[int, int]]): + Controls video resolution, (width,height). + frame_nums (`list[int]`): + How many frames to sample from a video. The number should be 4n+1 + shift (`float`, *optional*, defaults to 5.0): + Noise schedule shift parameter. Affects temporal dynamics + sampling_steps (`int`, *optional*, defaults to 40): + Number of diffusion sampling steps. Higher values improve quality but slow generation + guide_scale (`float`, *optional*, defaults 5.0): + Classifier-free guidance scale. Controls prompt adherence vs. creativity + n_prompt (`str`, *optional*, defaults to ""): + Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` + seed (`int`, *optional*, defaults to -1): + Random seed for noise generation. If -1, use random seed. + offload_model (`bool`, *optional*, defaults to True): + If True, offloads models to CPU during generation to save VRAM + + Returns: + torch.Tensor: + Generated video frames tensor. Dimensions: (C, N H, W) where: + - C: Color channels (3 for RGB) + - N: Number of frames (81) + - H: Frame height (from size) + - W: Frame width from size) + """ + + # preprocess + target_shapes = [] + for size, frame_num in zip(sizes, frame_nums): + target_shapes.append( + ( + self.vae.config.z_dim, + (frame_num - 1) // self.vae_stride[0] + 1, + size[1] // self.vae_stride[1], + size[0] // self.vae_stride[2], + ) + ) + max_video_seq_len = 0 + seq_lens = [] + for target_shape in target_shapes: + seq_len = ( + math.ceil( + (target_shape[2] * target_shape[3]) + / (self.patch_size[1] * self.patch_size[2]) + * target_shape[1] + / self.sp_size + ) + * self.sp_size + ) + seq_lens.append(seq_len) + max_video_seq_len = max(seq_lens) + + if n_prompt == "": + n_prompt = self.sample_neg_prompt + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) + + ## process context + # we implement similar to Wan's diffuser setup + # (https://github.com/huggingface/diffusers/blob/0f252be0ed42006c125ef4429156cb13ae6c1d60/src/diffusers/pipelines/wan/pipeline_wan.py#L157) + # in which we pad the text to 512, pass through text encoder, and truncate to the actual tokens, then pad with 0s to 512. + context_max_len = self.text_len + context_lens = [] + contexts = [] + contexts_null = [] + for prompt in prompts: + if not self.t5_cpu: + self.text_encoder.to(self.device) + context = _encode_text(self.tokenizer, self.text_encoder, self.device, prompt) + context_null = _encode_text(self.tokenizer, self.text_encoder, self.device, n_prompt) + if offload_model: + self.text_encoder.cpu() + else: + context = self.text_encoder([prompt], torch.device("cpu"))[0].to(self.device) + context_null = self.text_encoder([n_prompt], torch.device("cpu"))[0].to(self.device) + context_lens.append(context_max_len) # all samples have the same context_max_len + contexts.append(context) + contexts_null.append(context_null) + + # pad to context_max_len tokens, and stack to a tensor of shape [s, b, hidden] + contexts = [F.pad(context, (0, 0, 0, context_max_len - context.shape[0])) for context in contexts] + contexts_null = [ + F.pad(context_null, (0, 0, 0, context_max_len - context_null.shape[0])) for context_null in contexts_null + ] + contexts = torch.stack(contexts, dim=1) + contexts_null = torch.stack(contexts_null, dim=1) + + ## setup noise + noises = [] + for target_shape in target_shapes: + noises.append( + torch.randn( + target_shape[0], + target_shape[1], + target_shape[2], + target_shape[3], + dtype=torch.float32, + device=self.device, + generator=seed_g, + ) + ) + + # calculate grid_sizes + grid_sizes = [ + grid_sizes_calculation( + input_shape=u.shape[1:], + patch_size=self.model.patch_size, + ) + for u in noises + ] + grid_sizes = torch.tensor(grid_sizes, dtype=torch.long) + + @contextmanager + def noop_no_sync(): + yield + + no_sync = getattr(self.model, "no_sync", noop_no_sync) + + # evaluation mode + with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): + # Instantiate per-sample schedulers so each sample maintains its own state + batch_size_for_schedulers = len(noises) + schedulers = [] + for _ in range(batch_size_for_schedulers): + base_sched = FlowMatchEulerDiscreteScheduler.from_pretrained(self.model_id, subfolder="scheduler") + s = UniPCMultistepScheduler.from_config(base_sched.config, flow_shift=shift) + s.set_timesteps(sampling_steps, device=self.device) + + schedulers.append(s) + timesteps = schedulers[0].timesteps + + # sample videos + latents = noises + + cu_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)]) + cu_q = cu_q.to(torch.int32).to(self.device) + cu_kv_self = cu_q + cu_kv_cross = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(context_lens), dim=0)]) + cu_kv_cross = cu_kv_cross.to(torch.int32).to(self.device) + packed_seq_params = { + "self_attention": PackedSeqParams( + cu_seqlens_q=cu_q, + cu_seqlens_q_padded=cu_q, + cu_seqlens_kv=cu_kv_self, + cu_seqlens_kv_padded=cu_kv_self, + qkv_format=self.model.config.qkv_format, + ), + "cross_attention": PackedSeqParams( + cu_seqlens_q=cu_q, + cu_seqlens_q_padded=cu_q, + cu_seqlens_kv=cu_kv_cross, + qkv_format=self.model.config.qkv_format, + ), + } + + arg_c = {"context": contexts, "max_seq_len": max_video_seq_len, "packed_seq_params": packed_seq_params} + arg_null = { + "context": contexts_null, + "max_seq_len": max_video_seq_len, + "packed_seq_params": packed_seq_params, + } + + for _, t in enumerate(tqdm(timesteps)): + batch_size = len(latents) + + # patchify latents + unpatchified_latents = latents + latents = patchify(latents, self.patch_size) + # pad to have same length + for i in range(batch_size): + latents[i] = F.pad(latents[i], (0, 0, 0, max_video_seq_len - latents[i].shape[0])) + latents = torch.stack(latents, dim=1) + + latent_model_input = latents + timestep = [t] * batch_size + timestep = torch.stack(timestep) + + self.model.to(self.device) + noise_pred_cond = self.forward_pp_step( + latent_model_input, + grid_sizes=grid_sizes, + max_video_seq_len=max_video_seq_len, + timestep=timestep, + arg_c=arg_c, + ) + + noise_pred_uncond = self.forward_pp_step( + latent_model_input, + grid_sizes=grid_sizes, + max_video_seq_len=max_video_seq_len, + timestep=timestep, + arg_c=arg_null, + ) + + # run unpatchify + unpatchified_noise_pred_cond = noise_pred_cond + unpatchified_noise_pred_cond = unpatchified_noise_pred_cond.transpose(0, 1) # bring sbhd -> bshd + # when unpatchifying, the code will truncate the padded videos into the original video shape, based on the grid_sizes. + unpatchified_noise_pred_cond = unpatchify( + unpatchified_noise_pred_cond, grid_sizes, self.vae.config.z_dim, self.patch_size + ) + unpatchified_noise_pred_uncond = noise_pred_uncond + unpatchified_noise_pred_uncond = unpatchified_noise_pred_uncond.transpose(0, 1) # bring sbhd -> bshd + # when unpatchifying, the code will truncate the padded videos into the original video shape, based on the grid_sizes. + unpatchified_noise_pred_uncond = unpatchify( + unpatchified_noise_pred_uncond, grid_sizes, self.vae.config.z_dim, self.patch_size + ) + + noise_preds = [] + for i in range(batch_size): + noise_pred = unpatchified_noise_pred_uncond[i] + guide_scale * ( + unpatchified_noise_pred_cond[i] - unpatchified_noise_pred_uncond[i] + ) + noise_preds.append(noise_pred) + + # step and update latents + latents = [] + for i in range(batch_size): + temp_x0 = schedulers[i].step( + noise_preds[i].unsqueeze(0), t, unpatchified_latents[i].unsqueeze(0), return_dict=False + )[0] + latents.append(temp_x0.squeeze(0)) + + x0 = latents + if offload_model: + self.model.cpu() + torch.cuda.empty_cache() + if self.rank == 0: + # Diffusers' VAE decoding + latents = torch.stack(x0, dim=0) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view( + 1, self.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + videos = self.vae.decode(latents).sample + else: + videos = None + + del noises, latents + del schedulers + if offload_model: + gc.collect() + torch.cuda.synchronize() + if dist.is_initialized(): + dist.barrier() + + return videos if self.rank == 0 else None diff --git a/src/megatron/bridge/diffusion/models/wan/flow_matching/flow_matching_pipeline_wan.py b/src/megatron/bridge/diffusion/models/wan/flow_matching/flow_matching_pipeline_wan.py new file mode 100644 index 0000000000..5b18b6480a --- /dev/null +++ b/src/megatron/bridge/diffusion/models/wan/flow_matching/flow_matching_pipeline_wan.py @@ -0,0 +1,161 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Tuple + +import torch +import torch.nn as nn +from dfm.src.automodel.flow_matching.adapters.base import FlowMatchingContext, ModelAdapter +from dfm.src.automodel.flow_matching.flow_matching_pipeline import FlowMatchingPipeline +from megatron.core import parallel_state + +from megatron.bridge.diffusion.models.wan.utils import thd_split_inputs_cp + + +class WanAdapter(ModelAdapter): + """ + Model adapter for Wan model (Megatron version). + + Handles mapping of standard FlowMatchingContext to Wan specific inputs. + """ + + def prepare_inputs(self, context: FlowMatchingContext) -> Dict[str, Any]: + grid_sizes = context.batch["grid_sizes"] + noisy_latents = context.noisy_latents + video_latents = context.video_latents # noqa: F841 + loss_mask = context.batch["loss_mask"] # noqa: F841 + context_embeddings = context.batch["context_embeddings"] + timesteps = context.timesteps + packed_seq_params = context.batch["packed_seq_params"] + + # tranpose back to have shape "sbhd" + # (before we reshaped to "bshd" to be compatible with flow matching pipeline) + noisy_latents = noisy_latents.transpose(0, 1) + + # ======================================================================== + # Cast model inputs to bf16 + # ======================================================================== + + noisy_latents = noisy_latents.to(torch.bfloat16) + context_embeddings = context_embeddings.to(torch.bfloat16) + + # NOTE: investigate the affect of bf16 timesteps on embedding precision + # CRITICAL: Keep timesteps in fp32 for embedding precision + # timesteps = timesteps.float() # NOT bf16! + timesteps = timesteps.to(torch.bfloat16) + + # ======================================================================== + # Split accross context parallelism + # ======================================================================== + + if parallel_state.get_context_parallel_world_size() > 1: + noisy_latents = thd_split_inputs_cp( + noisy_latents, + packed_seq_params["self_attention"].cu_seqlens_q_padded, + parallel_state.get_context_parallel_group(), + ) + # TODO (pmannan): Disable CP for CrossAttention as KV context is small. + # We don't need to split context embeddings across context parallelism + # if we disable context parallelism for cross-attention + context_embeddings = thd_split_inputs_cp( + context_embeddings, + packed_seq_params["cross_attention"].cu_seqlens_kv_padded, + parallel_state.get_context_parallel_group(), + ) + else: + noisy_latents = noisy_latents + context_embeddings = context_embeddings + + return { + "noisy_latents": noisy_latents, + "grid_sizes": grid_sizes, + "timesteps": timesteps, + "context_embeddings": context_embeddings, + "packed_seq_params": packed_seq_params, + } + + def forward(self, model: nn.Module, inputs: Dict[str, Any]) -> torch.Tensor: + """ + Execute forward pass for Wan model. + + Args: + model: Wan model + inputs: Dictionary from prepare_inputs() + + Returns: + Model prediction tensor + """ + + model_pred = model( + x=inputs["noisy_latents"], + grid_sizes=inputs["grid_sizes"], + t=inputs["timesteps"], + context=inputs["context_embeddings"], + packed_seq_params=inputs["packed_seq_params"], + ) + return self.post_process_prediction(model_pred) + + +class WanFlowMatchingPipeline(FlowMatchingPipeline): + """ + Wan-specific Flow Matching pipeline handling Context Parallelism and Custom Noise. + + This pipeline extends the standard FlowMatchingPipeline to support: + 1. Wan-specific noise generation (patching + padding) + 2. Context Parallelism (CP) splitting of inputs + 3. Masked loss computation + """ + + def determine_task_type(self, data_type: str) -> str: + """Determine task type based on data type and randomization.""" + return "t2v" + + def compute_loss( + self, + model_pred: torch.Tensor, + target: torch.Tensor, + sigma: torch.Tensor, + batch: Dict[str, Any], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + loss_mask = batch["loss_mask"] + packed_seq_params = batch["packed_seq_params"] + + # tranpose back to have shape "sbhd" + # (before we reshaped to "bshd" to be compatible with flow matching pipeline) + target = target.transpose(0, 1) + + # ======================================================================== + # Split accross context parallelism + # ======================================================================== + + if parallel_state.get_context_parallel_world_size() > 1: + target = thd_split_inputs_cp( + target, + packed_seq_params["self_attention"].cu_seqlens_q_padded, + parallel_state.get_context_parallel_group(), + ) + split_loss_mask = thd_split_inputs_cp( + loss_mask, + packed_seq_params["self_attention"].cu_seqlens_q_padded, + parallel_state.get_context_parallel_group(), + ) + else: + target = target + split_loss_mask = loss_mask + + batch["loss_mask"] = split_loss_mask + weighted_loss, average_weighted_loss, unweighted_loss, average_unweighted_loss, loss_weight, loss_mask = ( + super().compute_loss(model_pred, target, sigma, batch) + ) + return weighted_loss, average_weighted_loss, unweighted_loss, average_unweighted_loss, loss_weight, loss_mask diff --git a/src/megatron/bridge/diffusion/models/wan/flow_matching/time_shift_utils.py b/src/megatron/bridge/diffusion/models/wan/flow_matching/time_shift_utils.py new file mode 100644 index 0000000000..a221610dec --- /dev/null +++ b/src/megatron/bridge/diffusion/models/wan/flow_matching/time_shift_utils.py @@ -0,0 +1,116 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import numpy as np +import torch + + +def time_shift( + t: torch.Tensor, + image_seq_len: int, + shift_type: str = "constant", + base_shift: float = 0.5, + max_shift: float = 1.15, + constant: float = 3.0, +): + """ + Convert timesteps to sigmas with sequence-length-aware shifting. + + Args: + t: timesteps in range [0, 1] + image_seq_len: number of tokens (frames * height * width / patch_size^2) + shift_type: "linear", "sqrt", or "constant" + base_shift: base shift for linear mode + max_shift: max shift for linear mode + constant: shift value for constant mode (default 3.0 matches Pika) + + Returns: + sigma values for noise scheduling + """ + if shift_type == "linear": + # Linear interpolation based on sequence length + mu = base_shift + (max_shift - base_shift) * (image_seq_len / 4096) + return math.exp(mu) / (math.exp(mu) + (1 / t - 1)) + + elif shift_type == "sqrt": + # Square root scaling (Flux-style) + # Assuming 128x128 latent space (1024x1024 image) gives mu=3 + mu = np.maximum(1.0, np.sqrt(image_seq_len / (128.0 * 128.0)) * 3.0) + return mu / (mu + (1 / t - 1)) + + elif shift_type == "constant": + # Constant shift (Pika default) + return constant / (constant + (1 / t - 1)) + + else: + # No shift, return original t + return t + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, + batch_size: int, + logit_mean: float = 0.0, + logit_std: float = 1.0, + mode_scale: float = 1.29, +): + """ + Sample timesteps from different distributions for better training coverage. + + Args: + weighting_scheme: "uniform", "logit_normal", or "mode" + batch_size: number of samples to generate + logit_mean: mean for logit-normal distribution + logit_std: std for logit-normal distribution + mode_scale: scale for mode-based sampling + + Returns: + Tensor of shape (batch_size,) with values in [0, 1] + """ + if weighting_scheme == "logit_normal": + # SD3-style logit-normal sampling + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + + elif weighting_scheme == "mode": + # Mode-based sampling (concentrates around certain timesteps) + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + + else: + # Uniform sampling (default) + u = torch.rand(size=(batch_size,), device="cpu") + + return u + + +def get_flow_match_loss_weight(sigma: torch.Tensor, shift: float = 3.0): + """ + Compute loss weights for flow matching based on sigma values. + + Higher sigma (more noise) typically gets higher weight. + + Args: + sigma: sigma values in range [0, 1] + shift: weight scaling factor + + Returns: + Loss weights with same shape as sigma + """ + # Flow matching weight: weight = 1 + shift * sigma + # This gives more weight to noisier timesteps + weight = 1.0 + shift * sigma + return weight diff --git a/src/megatron/bridge/diffusion/models/wan/inference/__init__.py b/src/megatron/bridge/diffusion/models/wan/inference/__init__.py new file mode 100644 index 0000000000..1395bc9375 --- /dev/null +++ b/src/megatron/bridge/diffusion/models/wan/inference/__init__.py @@ -0,0 +1,26 @@ +import os + + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +SIZE_CONFIGS = { + "416*240": (416, 240), + "720*1280": (720, 1280), + "1280*720": (1280, 720), + "480*832": (480, 832), + "832*480": (832, 480), + "1024*1024": (1024, 1024), +} + +MAX_AREA_CONFIGS = { + "720*1280": 720 * 1280, + "1280*720": 1280 * 720, + "480*832": 480 * 832, + "832*480": 832 * 480, +} + +SUPPORTED_SIZES = { + "t2v-14B": ("720*1280", "1280*720", "480*832", "832*480"), + "t2v-1.3B": ("416*240", "480*832", "832*480"), +} diff --git a/src/megatron/bridge/diffusion/models/wan/inference/utils.py b/src/megatron/bridge/diffusion/models/wan/inference/utils.py new file mode 100644 index 0000000000..324331f634 --- /dev/null +++ b/src/megatron/bridge/diffusion/models/wan/inference/utils.py @@ -0,0 +1,110 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import binascii +import os +import os.path as osp + +import imageio +import torch +import torchvision + + +__all__ = ["cache_video", "cache_image", "str2bool"] + + +def rand_name(length=8, suffix=""): + name = binascii.b2a_hex(os.urandom(length)).decode("utf-8") + if suffix: + if not suffix.startswith("."): + suffix = "." + suffix + name += suffix + return name + + +def cache_video(tensor, save_file=None, fps=30, suffix=".mp4", nrow=8, normalize=True, value_range=(-1, 1), retry=5): # noqa: D103 + # cache file + cache_file = osp.join("/tmp", rand_name(suffix=suffix)) if save_file is None else save_file + + # save to cache + error = None + for _ in range(retry): + try: + # preprocess + tensor = tensor.clamp(min(value_range), max(value_range)) + tensor = torch.stack( + [ + torchvision.utils.make_grid(u, nrow=nrow, normalize=normalize, value_range=value_range) + for u in tensor.unbind(2) + ], + dim=1, + ).permute(1, 2, 3, 0) + tensor = (tensor * 255).type(torch.uint8).cpu() + + # write video + writer = imageio.get_writer(cache_file, fps=fps, codec="libx264", quality=8) + for frame in tensor.numpy(): + writer.append_data(frame) + writer.close() + return cache_file + except Exception as e: + error = e + continue + else: + print(f"cache_video failed, error: {error}", flush=True) + return None + + +def cache_image(tensor, save_file, nrow=8, normalize=True, value_range=(-1, 1), retry=5): # noqa: D103 + # cache file + suffix = osp.splitext(save_file)[1] + if suffix.lower() not in [".jpg", ".jpeg", ".png", ".tiff", ".gif", ".webp"]: + suffix = ".png" + + # save to cache + for _ in range(retry): + try: + tensor = tensor.clamp(min(value_range), max(value_range)) + torchvision.utils.save_image(tensor, save_file, nrow=nrow, normalize=normalize, value_range=value_range) + return save_file + except Exception: + continue + + +def str2bool(v): + """ + Convert a string to a boolean. + + Supported true values: 'yes', 'true', 't', 'y', '1' + Supported false values: 'no', 'false', 'f', 'n', '0' + + Args: + v (str): String to convert. + + Returns: + bool: Converted boolean value. + + Raises: + argparse.ArgumentTypeError: If the value cannot be converted to boolean. + """ + if isinstance(v, bool): + return v + v_lower = v.lower() + if v_lower in ("yes", "true", "t", "y", "1"): + return True + elif v_lower in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected (True/False)") diff --git a/src/megatron/bridge/diffusion/models/wan/rope_utils.py b/src/megatron/bridge/diffusion/models/wan/rope_utils.py new file mode 100644 index 0000000000..42bbc585e2 --- /dev/null +++ b/src/megatron/bridge/diffusion/models/wan/rope_utils.py @@ -0,0 +1,87 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + + +class Wan3DRopeEmbeddings(torch.nn.Module): + """ + Wan 3D RoPE embeddings implementation. + Implements Wan's 3D RoPE embeddings for Mcore Attention based on Wan's implementation at https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/model.py. + """ + + def __init__(self, dim_head, max_position_len): + super().__init__() + self.freqs = torch.cat( + [ + self.rope_params(max_position_len, dim_head - 4 * (dim_head // 6)), + self.rope_params(max_position_len, 2 * (dim_head // 6)), + self.rope_params(max_position_len, 2 * (dim_head // 6)), + ], + dim=1, + ) + if torch.cuda.is_available(): + self.freqs = self.freqs.cuda() + + def rope_params(self, max_position_len, dim_head, theta=10000): + assert dim_head % 2 == 0 + freqs = torch.outer( + torch.arange(max_position_len), 1.0 / torch.pow(theta, torch.arange(0, dim_head, 2).div(dim_head)) + ) + return freqs + + def forward(self, n_head, dim_head, cu_seqlens_q_padded, grid_sizes, device): + _, c = n_head, dim_head // 2 + + # split freqs + freqs = self.freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + freqs_real = [] + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + freqs_real_i = torch.cat( + [ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), + ], + dim=-1, + ).reshape(seq_len, 1, 1, -1) # <-- add 1,1 for batch/head broadcasting + + # Double dimension from c -> 2c with rotating angles as (x0, x0, x1, x1, ...), for interleaving RoPE + freqs_real_i = freqs_real_i.unsqueeze(-1).expand(-1, -1, -1, -1, 2).reshape(seq_len, 1, 1, dim_head) + + freqs_real.append(freqs_real_i) + + # Pad freqs_real_i to (padded_seq_len, 1, 1, dim_head) with 0s + for i, freqs_real_i in enumerate(freqs_real): + seq_len_q_padded = cu_seqlens_q_padded[i + 1] - cu_seqlens_q_padded[i] + if freqs_real_i.shape[0] < seq_len_q_padded: + pad_shape = (seq_len_q_padded - freqs_real_i.shape[0], 1, 1, dim_head) + freqs_real_i = torch.cat( + [freqs_real_i, torch.zeros(pad_shape, dtype=freqs_real_i.dtype, device=freqs_real_i.device)], dim=0 + ) + freqs_real[i] = freqs_real_i + + # Each freqs_real[i] is (seq_len, 1, 1, dim_head) + # We concatenate them along dim=0 to get (concatenated_seq_len, 1, 1, dim_head) + freqs_real = torch.cat(freqs_real, dim=0) + + # Note: + # when running context_parallel, which must use "thd" for qkv_format, + # we don't need to scatter the freqs to the context parallel region, + # because mcore rope_utils will automatically retrieve the correct freqs for each context parallel region + + return freqs_real diff --git a/src/megatron/bridge/diffusion/models/wan/utils.py b/src/megatron/bridge/diffusion/models/wan/utils.py new file mode 100644 index 0000000000..1de1fe3435 --- /dev/null +++ b/src/megatron/bridge/diffusion/models/wan/utils.py @@ -0,0 +1,135 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Tuple + +import torch +import torch.distributed as dist +import transformer_engine_torch as tex + + +def grid_sizes_calculation( + input_shape: Tuple[int, int, int], # (F_latents, H_latents, W_latents) + patch_size: Tuple[int, int, int], # (pF, pH, pW) +) -> Tuple[int, int, int]: + """ + Compute the (f,h,w) output spatial/temporal dimensions of a Conv3d patch embedder. + """ + + F_latents, H_latents, W_latents = input_shape + pF, pH, pW = patch_size + F_patches = F_latents // pF + H_patches = H_latents // pH + W_patches = W_latents // pW + + return [F_patches, H_patches, W_patches] + + +def patchify(x, patch_size): + """ + Convert a list of reconstructed video tensor into patch embeddings. + This method is the inverse of `unpatchify`. + + Args: + x (list[torch.Tensor]): list of tensors, each with shape [c, F_patches * pF, H_patches * pH, W_patches * pW] + patch_size (tuple): (pF, pH, pW) + + Returns: + torch.Tensor: shape [ (F_patches * H_patches * W_patches), (c * pF * pH * pW)], + """ + out = [] + for u in x: + c, F_pF, H_pH, W_pW = u.shape + pF, pH, pW = patch_size + assert F_pF % pF == 0 and H_pH % pH == 0 and W_pW % pW == 0, ( + "Spatial dimensions must be divisible by patch size." + ) + + F_patches, H_patches, W_patches = F_pF // pF, H_pH // pH, W_pW // pW + + # split spatial dims into (grid, patch) and reorder to match original patch layout: + # start: (c, F_patches * pF, H_patches * pH, W_patches * pW) + # reshape -> (c, F_patches, pF, H_patches, pH, W_patches, pW) + # permute -> (F_patches, H_patches, W_patches, pF, pH, pW, c) + t = u.reshape(c, F_patches, pF, H_patches, pH, W_patches, pW) + t = t.permute(1, 3, 5, 2, 4, 6, 0) + + num_patches = F_patches * H_patches * W_patches + out.append(t.reshape(num_patches, c * (pF * pH * pW))) + return out + + +def unpatchify( + x: list[torch.Tensor], grid_sizes: list[Tuple[int, int, int]], out_dim: int, patch_size: Tuple[int, int, int] +) -> list[torch.Tensor]: + """ + Reconstruct video tensors from patch embeddings into a list of videotensors. + + Args: + x (list[torch.Tensor]): + list of tensors, each with shape [seq_len, c * pF * pH * pW] + grid_sizes (list[Tuple[int, int, int]]): + list of tensors, each with original spatial-temporal grid dimensions before patching, + (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + list[torch.Tensor]: list of tensors, each with shape [c, F_latents, H_latents, W_latents] + """ + + c = out_dim + out = [] + for u, v in zip(x, grid_sizes): + u = u[: math.prod(v)].view(*v, *patch_size, c) + u = torch.einsum("fhwpqrc->cfphqwr", u) + u = u.reshape(c, *[i * j for i, j in zip(v, patch_size)]) + out.append(u) + return out + + +def thd_split_inputs_cp( + x: torch.Tensor, cu_seqlens_q_padded: torch.Tensor, cp_group: dist.ProcessGroup +) -> torch.Tensor: + """ + Split a THD-packed tensor across CP ranks for inputs shaped [S, B, ...]. + + Args: + x: [S, B, ...] tensor (sequence first). + cu_seqlens_q_padded: 1D int32 THD cu_seqlens (padded) used for packing. + cp_group: context-parallel process group. + + Returns: + x_local: [S_local, B, ...] shard for this CP rank. + """ + # Move to [B, S, ...] to use THD partitioning along S + x_bs = x.transpose(0, 1).contiguous() # [B, S, ...] + + total_S = x_bs.size(1) + cp_size = dist.get_world_size(cp_group) + cp_rank = dist.get_rank(cp_group) + + # Compute this rank's THD partition indices (same API as during gather) + idx = tex.thd_get_partitioned_indices( + cu_seqlens_q_padded, # int32 offsets + total_S, + cp_size, + cp_rank, + ).to(device=x_bs.device, dtype=torch.long) # [S_local] + + # Take the shard along sequence dim + x_local_bs = x_bs.index_select(dim=1, index=idx).contiguous() # [B, S_local, ...] + + # Return to [S, B, ...] + x_local = x_local_bs.transpose(0, 1).contiguous() # [S_local, B, ...] + return x_local diff --git a/src/megatron/bridge/diffusion/models/wan/wan_layer_spec.py b/src/megatron/bridge/diffusion/models/wan/wan_layer_spec.py new file mode 100644 index 0000000000..e3a569e0a4 --- /dev/null +++ b/src/megatron/bridge/diffusion/models/wan/wan_layer_spec.py @@ -0,0 +1,305 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +import copy +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import torch.nn as nn +from megatron.core.extensions.transformer_engine import TENorm +from megatron.core.jit import jit_fuser +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer.attention import SelfAttentionSubmodules +from megatron.core.transformer.custom_layers.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TERowParallelLinear, +) +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules +from megatron.core.utils import make_viewless_tensor + +from megatron.bridge.diffusion.models.common.dit_attention import ( + DiTCrossAttention, + DiTCrossAttentionSubmodules, + DiTSelfAttention, +) + + +@dataclass +class WanWithAdaLNSubmodules(TransformerLayerSubmodules): # noqa: D101 + temporal_self_attention: Union[ModuleSpec, type] = IdentityOp + full_self_attention: Union[ModuleSpec, type] = IdentityOp + norm1: Union[ModuleSpec, type] = None + norm3: Union[ModuleSpec, type] = None + norm2: Union[ModuleSpec, type] = None + + +class WanAdaLN(MegatronModule): + """ + Adaptive Layer Normalization Module for DiT. + """ + + def __init__(self, config: TransformerConfig): + super().__init__(config) + # modulation + self.modulation = nn.Parameter(torch.randn(1, 6, config.hidden_size) / config.hidden_size**0.5) + + setattr(self.modulation, "sequence_parallel", config.sequence_parallel) + + @jit_fuser + def forward(self, timestep_emb): + e = (self.modulation + timestep_emb).transpose(0, 1) + e = e.chunk(6, dim=0) + return e + + @jit_fuser + def normalize_modulate(self, norm, hidden_states, shift, scale): + return self.modulate(norm(hidden_states), shift, scale) + + @jit_fuser + def modulate(self, x, shift, scale): + return x * (1 + scale) + shift + + @jit_fuser + def scale_add(self, residual, x, gate): + return residual + gate * x + + +class WanLayerWithAdaLN(TransformerLayer): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + + DiT with Adapative Layer Normalization. + """ + + def __init__( + self, + config: TransformerConfig, + submodules: TransformerLayerSubmodules, + layer_number: int = 1, + hidden_dropout: float = None, + pg_collection: Optional[ProcessGroupCollection] = None, + vp_stage: Optional[int] = None, + ): + def _replace_no_cp_submodules(submodules): + modified_submods = copy.deepcopy(submodules) + modified_submods.cross_attention = IdentityOp + return modified_submods + + # Replace any submodules that will have CP disabled and build them manually later after TransformerLayer init. + # modified_submods = _replace_no_cp_submodules(submodules) + super().__init__( + config=config, submodules=submodules, layer_number=layer_number, hidden_dropout=hidden_dropout + ) + + # TODO (pmannan): Override Cross Attention to disable CP. + # Disable TP Comm overlap as well. Not disabling will attempt re-use of buffer size same as + # Q and lead to incorrect tensor shapes. + # if submodules.cross_attention != IdentityOp: + # cp_override_config = copy.deepcopy(config) + # cp_override_config.context_parallel_size = 1 + # cp_override_config.tp_comm_overlap = False + # self.cross_attention = build_module( + # submodules.cross_attention, + # config=cp_override_config, + # layer_number=layer_number, + # ) + # else: + # self.cross_attention = None + + self.full_self_attention = build_module( + submodules.full_self_attention, + config=self.config, + layer_number=layer_number, + ) + + self.adaLN = WanAdaLN(config=self.config) + self.norm1 = build_module( + submodules.norm1, + normalized_shape=config.hidden_size, + eps=config.layernorm_epsilon, + elementwise_affine=False, + ) + self.norm3 = build_module( + submodules.norm3, + normalized_shape=config.hidden_size, + eps=config.layernorm_epsilon, + elementwise_affine=True, + ) + self.norm2 = build_module( + submodules.norm2, + normalized_shape=config.hidden_size, + eps=config.layernorm_epsilon, + elementwise_affine=False, + ) + + # set attributes "average_gradients_across_tp_domain" for nn.Parameter objects + # this is used for gradient averaging across TP domain with sequence parallelism + self._mark_trainable_params_for_tp_grad_avg([self.norm3, self.adaLN]) + + def _mark_trainable_params_for_tp_grad_avg(self, modules: Optional[list] = None) -> None: + """Mark selected modules' trainable parameters to average gradients across TP domain.""" + target_modules = modules if modules is not None else [self] + for module in target_modules: + for _name, param in module.named_parameters(recurse=True): + if isinstance(param, nn.Parameter) and param.requires_grad: + setattr(param, "average_gradients_across_tp_domain", True) + + @jit_fuser + def add_residual(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: + return x + residual + + def forward( + self, + hidden_states, + attention_mask=None, + context=None, + context_mask=None, + rotary_pos_emb=None, + rotary_pos_cos=None, + rotary_pos_sin=None, + attention_bias=None, + inference_params=None, + packed_seq_params=None, + sequence_len_offset=None, + inference_context=None, + rotary_pos_cos_sin=None, + **kwargs, + ): + # the timestep embedding is stored in attention_mask argument + timestep_emb = attention_mask + rope_emb = rotary_pos_emb + + shift_full, scale_full, gate_full, shift_mlp, scale_mlp, gate_mlp = self.adaLN(timestep_emb) + + # ******************************************** full self attention ******************************************* + + # adaLN with scale + shift + gate + pre_full_attn_layernorm_output_ada = self.adaLN.normalize_modulate( + self.norm1, + hidden_states, + shift=shift_full, + scale=scale_full, + ) + + attention_output, bias = self.full_self_attention( + pre_full_attn_layernorm_output_ada, + attention_mask=None, + rotary_pos_emb=rope_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + packed_seq_params=packed_seq_params["self_attention"], + ) + if bias is not None: + attention_output = attention_output + bias + + hidden_states = self.adaLN.scale_add(residual=hidden_states, x=attention_output, gate=gate_full) + + # ******************************************** cross attention ****************************************************** + + # TODO (pmannan): Disable CP for CrossAttention as KV context is small. + # But needs better support for packed sequences and padding to ensure correct calculations + # packed_seq_params['cross_attention'].cu_seqlens_q = torch.tensor( + # [0, hidden_states.shape[0]], + # device=packed_seq_params['cross_attention'].cu_seqlens_kv.device, + # dtype=torch.int32) + attention_output, bias = self.cross_attention( + self.norm3(hidden_states), + attention_mask=context_mask, + key_value_states=context, + packed_seq_params=packed_seq_params["cross_attention"], + ) + if bias is not None: + attention_output = attention_output + bias + + hidden_states = self.add_residual(hidden_states, attention_output) + + # ******************************************** mlp ****************************************************** + + pre_mlp_layernorm_output_ada = self.adaLN.normalize_modulate( + self.norm2, + hidden_states, + shift=shift_mlp, + scale=scale_mlp, + ) + + mlp_output, bias = self.mlp(pre_mlp_layernorm_output_ada) + if bias is not None: + mlp_output = mlp_output + bias + + hidden_states = self.adaLN.scale_add(residual=hidden_states, x=mlp_output, gate=gate_mlp) + + # TODO: Jit compiled function creates 'view' tensor. This tensor + # potentially gets saved in the MPU checkpoint function context, + # which rejects view tensors. While making a viewless tensor here + # won't result in memory savings (like the data loader, or + # p2p_communication), it serves to document the origin of this + # 'view' tensor. ??? + output = make_viewless_tensor(inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True) + # output = hidden_states + + return output, context + + +def get_wan_block_with_transformer_engine_spec() -> ModuleSpec: # noqa: D103 + params = {"attn_mask_type": AttnMaskType.padding} + return ModuleSpec( + module=WanLayerWithAdaLN, + submodules=WanWithAdaLNSubmodules( + norm1=nn.LayerNorm, + norm3=nn.LayerNorm, + norm2=nn.LayerNorm, + full_self_attention=ModuleSpec( + module=DiTSelfAttention, + params=params, + submodules=SelfAttentionSubmodules( + linear_qkv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=TENorm, + k_layernorm=TENorm, + ), + ), + cross_attention=ModuleSpec( + module=DiTCrossAttention, + params=params, + submodules=DiTCrossAttentionSubmodules( + linear_q=TEColumnParallelLinear, + linear_kv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=TENorm, + k_layernorm=TENorm, + ), + ), + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear, + linear_fc2=TERowParallelLinear, + ), + ), + ), + ) diff --git a/src/megatron/bridge/diffusion/models/wan/wan_model.py b/src/megatron/bridge/diffusion/models/wan/wan_model.py new file mode 100644 index 0000000000..197002cafd --- /dev/null +++ b/src/megatron/bridge/diffusion/models/wan/wan_model.py @@ -0,0 +1,352 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +import math +from typing import Dict, Optional, Tuple + +import torch +import torch.nn as nn +from diffusers.models.embeddings import Timesteps +from megatron.core import parallel_state, tensor_parallel +from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core.models.common.vision_module.vision_module import VisionModule +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.transformer_block import TransformerBlock +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import make_sharded_tensor_for_checkpoint +from torch import Tensor + +from megatron.bridge.diffusion.models.common.dit_embeddings import ParallelTimestepEmbedding +from megatron.bridge.diffusion.models.wan.wan_layer_spec import ( + get_wan_block_with_transformer_engine_spec as WanLayerWithAdaLNspec, +) + +from .rope_utils import Wan3DRopeEmbeddings + + +def sinusoidal_embedding_1d(dim, position): # noqa: D103 + # preprocess + assert dim % 2 == 0 + half = dim // 2 + position = position + + # calculation + sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x + + +class Head(nn.Module): # noqa: D101 + def __init__(self, dim, out_dim, patch_size, eps=1e-6): + super().__init__() + self.dim = dim + self.out_dim = out_dim + self.patch_size = patch_size + self.eps = eps + + # layers + out_dim = math.prod(patch_size) * out_dim + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.head = nn.Linear(dim, out_dim) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, e): + r""" + Args: + x(Tensor): Shape [B, L1, C] + e(Tensor): Shape [B, C] + """ + e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) + x = self.head(self.norm(x) * (1 + e[1]) + e[0]) + return x + + +class WanModel(VisionModule): + """ + WanModel is a VisionModule that implements a Wan model. + Attributes: + config (TransformerConfig): Configuration for the transformer. + pre_process (bool): Whether to apply pre-processing steps. + post_process (bool): Whether to apply post-processing steps. + fp16_lm_cross_entropy (bool): Whether to use fp16 for cross-entropy loss. + parallel_output (bool): Whether to use parallel output. + transformer_decoder_layer_spec (WanLayerWithAdaLNspec): Specification for the transformer decoder layer. + model_type (ModelType): Type of the model. + """ + + def __init__( + self, + config: TransformerConfig, + pre_process: bool = True, + post_process: bool = True, + fp16_lm_cross_entropy: bool = False, + parallel_output: bool = True, + transformer_decoder_layer_spec=WanLayerWithAdaLNspec, + **kwargs, + ): + super(WanModel, self).__init__(config=config) + + self.config: TransformerConfig = config + + self.transformer_decoder_layer_spec = transformer_decoder_layer_spec() + self.pre_process = pre_process + self.post_process = post_process + self.fp16_lm_cross_entropy = fp16_lm_cross_entropy + self.parallel_output = parallel_output + + # megatron core pipelining currently depends on model type + # TODO: remove this dependency ? + self.model_type = ModelType.encoder_or_decoder + + self.num_heads = self.config.num_attention_heads + self.freq_dim = self.config.freq_dim + self.in_channels = self.config.in_channels + self.out_channels = self.config.out_channels + self.patch_spatial = self.config.patch_spatial + self.patch_temporal = self.config.patch_temporal + self.patch_size = (self.patch_temporal, self.patch_spatial, self.patch_spatial) + + # these attributes are unused for images/videos, we just set because bridge training requires for LLMs + self.share_embeddings_and_output_weights = False + + ###################################### + ########## Wan architecture ########## + + # embeddings + if self.pre_process: + self.patch_embedding = nn.Conv3d( + self.in_channels, self.config.hidden_size, kernel_size=self.patch_size, stride=self.patch_size + ) + + self.text_embedding = nn.Sequential( + nn.Linear(self.config.text_dim, self.config.crossattn_emb_size), + nn.GELU(approximate="tanh"), + nn.Linear(self.config.crossattn_emb_size, self.config.crossattn_emb_size), + ) + + # As in diffuser's Wan implementation + self.timesteps_proj = Timesteps(num_channels=self.freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedder = ParallelTimestepEmbedding( + in_channels=self.freq_dim, time_embed_dim=self.config.hidden_size + ) + self.time_proj_act_fn = nn.SiLU() + self.time_proj = nn.Linear(self.config.hidden_size, self.config.hidden_size * 6) + + self.rope_embeddings = Wan3DRopeEmbeddings( + dim_head=self.config.hidden_size // self.num_heads, max_position_len=1024 + ) + + # decoder blocks + self.decoder = TransformerBlock( + config=self.config, + spec=self.transformer_decoder_layer_spec, + pre_process=self.pre_process, + post_process=self.post_process, + post_layer_norm=False, + ) + + # output head + if self.post_process: + self.head = Head(self.config.hidden_size, self.out_channels, self.patch_size, eps=1e-6) + + # set attributes "average_gradients_across_tp_domain" for nn.Parameter objects + # this is used for gradient averaging across TP domain with sequence parallelism + self._mark_trainable_params_for_tp_grad_avg( + [ + self.patch_embedding, + self.text_embedding, + self.time_embedder, + self.time_proj, + self.head, + ] + ) + + def forward( + self, + x: Tensor, + grid_sizes: list[Tuple[int, int, int]], + t: Tensor, + context: Tensor, + packed_seq_params: PackedSeqParams = None, + **kwargs, + ) -> Tensor: + """Forward pass. + + Args: + x List[Tensor]: list of vae encoded data (in_channel, f, h, w) + grid_sizes List[Tuple[int, int, int]]: list of grid sizes (f, h, w) + t Tensor: timesteps + context List[Tensor]: list of context (text_len, hidden_size) + packed_seq_params PackedSeqParams: packed sequence parameters + + Returns: + Tensor: output tensor (still patchified) of shape [seq_len, batch_size, hidden_size] + """ + ################################# + ########## Wan forward ########## + + # ============= embedders ============= + + # run input embedding + if self.pre_process: + # x.shape [s, b, c * pF * pH * pW] + seq_len, batch_size, _ = x.shape + c = self.out_channels + pF, pH, pW = self.patch_size + x = x.reshape(seq_len * batch_size, pF, pH, pW, c) # output: x.shape [s * b, pF, pH, pW, c] + x = x.permute(0, 4, 1, 2, 3) # output: x.shape [s * b, c, pF, pH, pW] + x = self.patch_embedding(x) # output: x.shape [s * b, hidden_size, 1, 1, 1] + x = x.flatten(1) # output: x.shape [s * b, hidden_size] + x = x.reshape(seq_len, batch_size, -1) # output: x.shape [s, b, hidden_size] + + # split sequence for sequence_parallel + # TODO: for PP, do we move scatter_to_sequence_parallel_region here or after "x = self.decoder.input_tensor" ??? + if self.config.sequence_parallel: + x = tensor_parallel.scatter_to_sequence_parallel_region( + x + ) # output: x.shape [s * b // tp_size, hidden_size] + + else: + # intermediate stage of pipeline + x = self.decoder.input_tensor + + # time embeddings + e = self.time_embedder(self.timesteps_proj(t).to(x.dtype)) + e0 = self.time_proj(self.time_proj_act_fn(e)).unflatten(1, (6, self.config.hidden_size)) + + # context embeddings + context = self.text_embedding(context) # shape [text_len, b, hidden_size] + + # ============= decoder ============= + # calculate rotary pos emb + n_head, dim_head = self.num_heads, self.config.hidden_size // self.num_heads + cu_seqlens_q_padded = packed_seq_params["self_attention"].cu_seqlens_q_padded + rotary_pos_emb = self.rope_embeddings( + n_head, dim_head, cu_seqlens_q_padded, grid_sizes, t.device + ) # output: rotary_pos_emb.shape [s, b, 1, dim_head] + + # run decoder + x = self.decoder( + hidden_states=x, + attention_mask=e0, + context=context, + context_mask=None, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=None, + rotary_pos_sin=None, + packed_seq_params=packed_seq_params, + ) + + # return if not post_process + if not self.post_process: + return x + + # head + x = x.transpose(0, 1) # head expects shape [b, s, hidden_size] + x = self.head(x, e) # output: x.shape [b, s, c * pF * pH * pW] + x = x.transpose(0, 1) # reshape back to shape [s, b, c * pF * pH * pW] + + # gather outputs for sequence_parallel + # Note: in GPT models, because the vocab projection matrix is ColumnParallelLinear, the sequence is + # automatically gathered in ColumnParallelLinear forward pass. + # However, in Wan models, we need to gather the outputs manually. + if self.config.sequence_parallel: + x = tensor_parallel.gather_from_sequence_parallel_region(x) + return x # output: x.shape [s, b, c * pF * pH * pW] + + def set_input_tensor(self, input_tensor: Tensor) -> None: + """Sets input tensor to the model. + + See megatron.model.transformer.set_input_tensor() + + Args: + input_tensor (Tensor): Sets the input tensor for the model. + """ + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + + assert len(input_tensor) == 1, "input_tensor should only be length 1 for gpt/bert" + self.decoder.set_input_tensor(input_tensor[0]) + + def sharded_state_dict( + self, prefix: str = "module.", sharded_offsets: tuple = (), metadata: Optional[Dict] = None + ) -> ShardedStateDict: + """Sharded state dict implementation for GPTModel backward-compatibility (removing extra state). + + Args: + prefix (str): Module name prefix. + sharded_offsets (tuple): PP related offsets, expected to be empty at this module level. + metadata (Optional[Dict]): metadata controlling sharded state dict creation. + + Returns: + ShardedStateDict: sharded state dict for the GPTModel + """ + sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) + + # Ensure replica ids for non-transformer embedder weights include pipeline dimension + for module in ["text_embedding", "time_embedding", "time_projection"]: + if hasattr(self, module): + for param_name, param in getattr(self, module).named_parameters(): + weight_key = f"{prefix}{module}.{param_name}" + if weight_key in sharded_state_dict: + self._set_embedder_weights_replica_id(param, sharded_state_dict, weight_key) + + return sharded_state_dict + + def _mark_trainable_params_for_tp_grad_avg(self, modules: Optional[list] = None) -> None: + """Mark selected modules' trainable parameters to average gradients across TP domain.""" + target_modules = modules if modules is not None else [self] + for module in target_modules: + for _name, param in module.named_parameters(recurse=True): + if isinstance(param, nn.Parameter) and param.requires_grad: + setattr(param, "average_gradients_across_tp_domain", True) + + def _set_embedder_weights_replica_id( + self, tensor: Tensor, sharded_state_dict: ShardedStateDict, embedder_weight_key: str + ) -> None: + """set replica ids of the weights in t_embedder for sharded state dict. + + Args: + sharded_state_dict (ShardedStateDict): state dict with the weight to tie + weight_key (str): key of the weight in the state dict. + This entry will be replaced with a tied version + + Returns: None, acts in-place + """ + tp_rank = parallel_state.get_tensor_model_parallel_rank() + vpp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank() + vpp_rank = vpp_rank if vpp_rank else 0 + vpp_world = parallel_state.get_virtual_pipeline_model_parallel_world_size() + vpp_world = vpp_world if vpp_world else 1 + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + del sharded_state_dict[embedder_weight_key] + replica_id = ( + tp_rank, + (vpp_rank + pp_rank * vpp_world), + parallel_state.get_data_parallel_rank(with_context_parallel=True), + ) + + sharded_state_dict[embedder_weight_key] = make_sharded_tensor_for_checkpoint( + tensor=tensor, + key=embedder_weight_key, + replica_id=replica_id, + allow_shape_mismatch=False, + ) diff --git a/src/megatron/bridge/diffusion/models/wan/wan_provider.py b/src/megatron/bridge/diffusion/models/wan/wan_provider.py new file mode 100644 index 0000000000..2135c0e0c3 --- /dev/null +++ b/src/megatron/bridge/diffusion/models/wan/wan_provider.py @@ -0,0 +1,90 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +from dataclasses import dataclass +from typing import Callable + +import torch +import torch.nn.functional as F +from megatron.core import parallel_state +from megatron.core.models.common.vision_module.vision_module import VisionModule + +from megatron.bridge.diffusion.models.wan.wan_model import WanModel +from megatron.bridge.models.model_provider import ModelProviderMixin +from megatron.bridge.models.transformer_config import TransformerConfig + + +logger = logging.getLogger(__name__) + + +@dataclass +class WanModelProvider(TransformerConfig, ModelProviderMixin[VisionModule]): # noqa: D101 + crossattn_emb_size: int = 1536 # cross attention emebedding size after linear projection + add_bias_linear: bool = True + gated_linear_unit: bool = False + + num_layers: int = 30 + hidden_size: int = 1536 + ffn_hidden_size: int = 8960 + num_attention_heads: int = 12 + layernorm_epsilon: float = 1e-6 + normalization: str = "RMSNorm" + layernorm_zero_centered_gamma: bool = False + layernorm_across_heads: bool = True + add_qkv_bias: bool = True + rotary_interleaved: bool = True + activation_func: Callable = F.gelu + hidden_dropout: float = 0 + attention_dropout: float = 0 + fp16_lm_cross_entropy: bool = False + parallel_output: bool = True + bf16: bool = False + params_dtype: torch.dtype = torch.float32 + qkv_format: str = "thd" # "sbhd". NOTE: if we use context parallelism, we need to use "thd" + apply_rope_fusion: bool = True + bias_activation_fusion: bool = True + # these attributes are unused for images/videos, we just set because bridge training requires for LLMs + seq_length: int = 1024 + share_embeddings_and_output_weights: bool = False + vocab_size: int = 25256 * 8 + make_vocab_size_divisible_by: int = 128 + + # images/videos attributes + in_channels: int = 16 + out_channels: int = 16 + patch_spatial: int = 2 + patch_temporal: int = 1 + freq_dim: int = 256 + text_len: int = 512 + text_dim: int = 4096 + + def provide(self, pre_process=None, post_process=None, vp_stage=None) -> WanModel: + vp_size = self.virtual_pipeline_model_parallel_size + if vp_size: + p_size = self.pipeline_model_parallel_size + assert (self.num_layers // p_size) % vp_size == 0, ( + "Make sure the number of model chunks is the same across all pipeline stages." + ) + + model = WanModel + + return model( + self, + pre_process=parallel_state.is_pipeline_first_stage(), + post_process=parallel_state.is_pipeline_last_stage(), + fp16_lm_cross_entropy=self.fp16_lm_cross_entropy, + parallel_output=self.parallel_output, + ) diff --git a/src/megatron/bridge/diffusion/models/wan/wan_step.py b/src/megatron/bridge/diffusion/models/wan/wan_step.py new file mode 100644 index 0000000000..4299c74ce1 --- /dev/null +++ b/src/megatron/bridge/diffusion/models/wan/wan_step.py @@ -0,0 +1,169 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +from functools import partial +from typing import Iterable + +import torch +from megatron.core.models.common.vision_module.vision_module import VisionModule +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.utils import get_model_config + +from megatron.bridge.diffusion.models.wan.flow_matching.flow_matching_pipeline_wan import ( + WanAdapter, + WanFlowMatchingPipeline, +) +from megatron.bridge.training.losses import masked_next_token_loss +from megatron.bridge.training.state import GlobalState + + +logger = logging.getLogger(__name__) + + +def wan_data_step(qkv_format, dataloader_iter): # noqa: D103 + batch = next(dataloader_iter) + batch = {k: v.to(device="cuda", non_blocking=True) if torch.is_tensor(v) else v for k, v in batch.items()} + # Construct packed sequence parameters + if ("seq_len_q" in batch) and ("seq_len_kv" in batch): + zero = torch.zeros(1, dtype=torch.int32, device="cuda") + + cu_seqlens = batch["seq_len_q"].cumsum(dim=0).to(torch.int32) + cu_seqlens = torch.cat((zero, cu_seqlens)) + + cu_seqlens_padded = batch["seq_len_q_padded"].cumsum(dim=0).to(torch.int32) + cu_seqlens_padded = torch.cat((zero, cu_seqlens_padded)) + + cu_seqlens_kv = batch["seq_len_kv"].cumsum(dim=0).to(torch.int32) + cu_seqlens_kv = torch.cat((zero, cu_seqlens_kv)) + + cu_seqlens_kv_padded = batch["seq_len_kv_padded"].cumsum(dim=0).to(torch.int32) + cu_seqlens_kv_padded = torch.cat((zero, cu_seqlens_kv_padded)) + + batch["packed_seq_params"] = { + "self_attention": PackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv=cu_seqlens, + cu_seqlens_kv_padded=cu_seqlens_padded, + qkv_format=qkv_format, + ), + "cross_attention": PackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + qkv_format=qkv_format, + ), + } + + # tranpose from "sbhd" to "bshd" to be compatible with flow matching pipeline + batch["video_latents"] = batch["video_latents"].transpose(0, 1) + + return batch + + +class WanForwardStep: # noqa: D101 + def __init__( + self, + use_sigma_noise: bool = True, + timestep_sampling: str = "uniform", + logit_mean: float = 0.0, + logit_std: float = 1.0, + flow_shift: float = 3.0, + mix_uniform_ratio: float = 0.1, + sigma_min: float = 0.0, # Default: no clamping (pretrain) + sigma_max: float = 1.0, # Default: no clamping (pretrain) + ): + self.diffusion_pipeline = WanFlowMatchingPipeline( + model_adapter=WanAdapter(), + timestep_sampling=timestep_sampling, + logit_mean=logit_mean, + logit_std=logit_std, + flow_shift=flow_shift, + mix_uniform_ratio=mix_uniform_ratio, + sigma_min=sigma_min, + sigma_max=sigma_max, + ) + self.use_sigma_noise = use_sigma_noise + self.timestep_sampling = timestep_sampling + self.logit_mean = logit_mean + self.logit_std = logit_std + self.flow_shift = flow_shift + self.mix_uniform_ratio = mix_uniform_ratio + self.sigma_min = sigma_min + self.sigma_max = sigma_max + + def __call__( + self, state: GlobalState, data_iterator: Iterable, model: VisionModule + ) -> tuple[torch.Tensor, partial]: + """ + Forward training step. + """ + timers = state.timers + straggler_timer = state.straggler_timer + + config = get_model_config(model) + + timers("batch-generator", log_level=2).start() + + qkv_format = getattr(config, "qkv_format", "sbhd") + with straggler_timer(bdata=True): + batch = wan_data_step(qkv_format, data_iterator) + timers("batch-generator").stop() + + check_for_nan_in_loss = state.cfg.rerun_state_machine.check_for_nan_in_loss + check_for_spiky_loss = state.cfg.rerun_state_machine.check_for_spiky_loss + + # run diffusion training step + with straggler_timer: + weighted_loss, average_weighted_loss, loss_mask, metrics = self.diffusion_pipeline.step( + model, + batch, + ) + output_tensor = torch.mean(weighted_loss, dim=-1) + batch["loss_mask"] = loss_mask + + # TODO: do we need to gather output with sequence or context parallelism here + # especially when we have pipeline parallelism + + loss = output_tensor + if "loss_mask" not in batch or batch["loss_mask"] is None: + loss_mask = torch.ones_like(loss) + loss_mask = batch["loss_mask"] + + loss_function = self._create_loss_function(loss_mask, check_for_nan_in_loss, check_for_spiky_loss) + + return output_tensor, loss_function + + def _create_loss_function( + self, loss_mask: torch.Tensor, check_for_nan_in_loss: bool, check_for_spiky_loss: bool + ) -> partial: + """Create a partial loss function with the specified configuration. + + Args: + loss_mask: Used to mask out some portions of the loss + check_for_nan_in_loss: Whether to check for NaN values in the loss + check_for_spiky_loss: Whether to check for spiky loss values + + Returns: + A partial function that can be called with output_tensor to compute the loss + """ + return partial( + masked_next_token_loss, + loss_mask, + check_for_nan_in_loss=check_for_nan_in_loss, + check_for_spiky_loss=check_for_spiky_loss, + ) diff --git a/src/megatron/bridge/diffusion/recipes/__init__.py b/src/megatron/bridge/diffusion/recipes/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/megatron/bridge/diffusion/recipes/flux/__init__.py b/src/megatron/bridge/diffusion/recipes/flux/__init__.py new file mode 100644 index 0000000000..c477ea64db --- /dev/null +++ b/src/megatron/bridge/diffusion/recipes/flux/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.bridge.diffusion.recipes.flux.flux import model_config, pretrain_config + + +__all__ = ["model_config", "pretrain_config"] diff --git a/src/megatron/bridge/diffusion/recipes/flux/flux.py b/src/megatron/bridge/diffusion/recipes/flux/flux.py new file mode 100644 index 0000000000..3ec23efc31 --- /dev/null +++ b/src/megatron/bridge/diffusion/recipes/flux/flux.py @@ -0,0 +1,288 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import List, Optional, Union + +import torch +from megatron.core.distributed import DistributedDataParallelConfig + +from megatron.bridge.diffusion.data.flux.flux_mock_datamodule import FluxMockDataModuleConfig +from megatron.bridge.diffusion.models.flux.flux_provider import FluxProvider +from megatron.bridge.recipes.utils.optimizer_utils import distributed_fused_adam_with_cosine_annealing +from megatron.bridge.recipes.utils.tokenizer_utils import DEFAULT_NULL_TOKENIZER_VOCAB_SIZE +from megatron.bridge.training.comm_overlap import CommOverlapConfig +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + LoggerConfig, + RNGConfig, + TokenizerConfig, + TrainingConfig, +) +from megatron.bridge.training.mixed_precision import MixedPrecisionConfig, get_mixed_precision_config + + +def model_config( + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + pipeline_parallelism_dtype: Optional[torch.dtype] = torch.bfloat16, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + seq_length: int = 1024, + # FLUX-specific parameters + num_joint_layers: int = 19, + num_single_layers: int = 38, + hidden_size: int = 3072, + num_attention_heads: int = 24, + in_channels: int = 64, + context_dim: int = 4096, + guidance_embed: bool = False, + guidance_scale: float = 3.5, +) -> FluxProvider: + """ + Configure the FLUX model. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_dtype (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + seq_length (int): Sequence length for the model. + num_joint_layers (int): Number of double (joint) transformer blocks. + num_single_layers (int): Number of single transformer blocks. + hidden_size (int): Hidden dimension size. + num_attention_heads (int): Number of attention heads. + in_channels (int): Number of input channels (latent channels). + context_dim (int): Text encoder context dimension. + guidance_embed (bool): Whether to use guidance embedding (for FLUX-dev). + guidance_scale (float): Classifier-free guidance scale. + + Returns: + FluxProvider: Configuration for the FLUX model. + """ + return FluxProvider( + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_dtype, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + seq_length=seq_length, + # FLUX-specific + num_joint_layers=num_joint_layers, + num_single_layers=num_single_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + in_channels=in_channels, + context_dim=context_dim, + guidance_embed=guidance_embed, + guidance_scale=guidance_scale, + ) + + +def pretrain_config( + dir: Optional[str] = None, + name: str = "default", + # Dataset configuration + data_paths: Optional[List[str]] = None, + mock: bool = False, + # Model configuration + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + pipeline_parallelism_dtype: Optional[torch.dtype] = torch.bfloat16, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + use_megatron_fsdp: bool = False, + # FLUX model configuration + num_joint_layers: int = 19, + num_single_layers: int = 38, + hidden_size: int = 3072, + num_attention_heads: int = 24, + in_channels: int = 64, + context_dim: int = 4096, + guidance_embed: bool = False, + guidance_scale: float = 3.5, + # Image configuration + image_H: int = 1024, + image_W: int = 1024, + vae_channels: int = 16, + vae_scale_factor: int = 8, + prompt_seq_len: int = 512, + pooled_prompt_dim: int = 768, + # Training hyperparameters + train_iters: int = 10000, + global_batch_size: int = 4, + micro_batch_size: int = 1, + lr: float = 1e-4, + lr_warmup_iters: int = 1000, + # Precision recipe + precision_config: Optional[Union[MixedPrecisionConfig, str]] = "bf16_mixed", + comm_overlap_config: Optional[CommOverlapConfig] = None, +) -> ConfigContainer: + """ + Create a pre-training configuration for FLUX model. + + Args: + dir (Optional[str]): Base directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + data_paths (Optional[List[str]]): List of paths to dataset files. If None, mock data will be used. + mock (bool): Whether to use mock data. If True, ignores data_paths. + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_dtype (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + use_megatron_fsdp (bool): Whether to use Megatron FSDP. + num_joint_layers (int): Number of double (joint) transformer blocks. + num_single_layers (int): Number of single transformer blocks. + hidden_size (int): Hidden dimension size. + num_attention_heads (int): Number of attention heads. + in_channels (int): Number of input channels (latent channels). + context_dim (int): Text encoder context dimension. + guidance_embed (bool): Whether to use guidance embedding (for FLUX-dev). + guidance_scale (float): Classifier-free guidance scale. + image_H (int): Image height. + image_W (int): Image width. + vae_channels (int): Number of VAE latent channels. + vae_scale_factor (int): VAE downsampling factor. + prompt_seq_len (int): Sequence length for text prompts (T5). + pooled_prompt_dim (int): Dimensionality of pooled text embeddings (CLIP). + train_iters (int): Total number of training iterations. + global_batch_size (int): Global batch size for training. + micro_batch_size (int): Micro batch size for training. + lr (float): Learning rate. + lr_warmup_iters (int): Number of warmup iterations for the learning rate. + precision_config (Optional[Union[MixedPrecisionConfig, str]]): Precision configuration. + comm_overlap_config (Optional[CommOverlapConfig]): Communication overlap configuration. + + Returns: + ConfigContainer: Configuration for pre-training. + """ + base_output_dir = dir if dir is not None else os.path.join(os.getcwd(), "nemo_experiments") + run_output_dir = os.path.join(base_output_dir, name) + checkpoint_dir = os.path.join(run_output_dir, "checkpoints") + tensorboard_dir = os.path.join(run_output_dir, "tb_logs") + + model_cfg = model_config( + tensor_parallelism=tensor_parallelism, + pipeline_parallelism=pipeline_parallelism, + pipeline_parallelism_dtype=pipeline_parallelism_dtype, + virtual_pipeline_parallelism=virtual_pipeline_parallelism, + context_parallelism=context_parallelism, + sequence_parallelism=sequence_parallelism, + seq_length=1024, + num_joint_layers=num_joint_layers, + num_single_layers=num_single_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + in_channels=in_channels, + context_dim=context_dim, + guidance_embed=guidance_embed, + guidance_scale=guidance_scale, + ) + + opt_config, scheduler = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=lr_warmup_iters, + lr_decay_iters=train_iters, + max_lr=lr, + ) + opt_config.use_precision_aware_optimizer = False + + if isinstance(precision_config, str): + precision_config = get_mixed_precision_config(precision_config) + + precision_config.grad_reduce_in_fp32 = False + + if mock: + dataset = FluxMockDataModuleConfig( + path=None, + seq_length=1024, + image_H=image_H, + image_W=image_W, + vae_channels=vae_channels, + vae_scale_factor=vae_scale_factor, + prompt_seq_len=prompt_seq_len, + context_dim=context_dim, + pooled_prompt_dim=pooled_prompt_dim, + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + num_workers=16, + packing_buffer_size=None, + ) + else: + # Real dataset configuration using Energon WebDataset + from megatron.bridge.diffusion.data.flux.flux_energon_datamodule import FluxDataModuleConfig + + dataset = FluxDataModuleConfig( + path=data_paths, # Path to WebDataset shards directory + seq_length=1024, + vae_scale_factor=vae_scale_factor, + latent_channels=vae_channels, + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + num_workers=16, + task_encoder_seq_length=None, + packing_buffer_size=None, # Disable Sequence Packing for now + ) + + # Config Container + cfg = ConfigContainer( + model=model_cfg, + train=TrainingConfig( + train_iters=train_iters, + eval_interval=2000, + eval_iters=32, + global_batch_size=global_batch_size, + micro_batch_size=micro_batch_size, + manual_gc=True, + manual_gc_interval=100, + manual_gc_eval=100, + ), + optimizer=opt_config, + scheduler=scheduler, + ddp=DistributedDataParallelConfig( + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=False, + overlap_param_gather=False, + average_in_collective=True, + use_distributed_optimizer=True, + use_megatron_fsdp=use_megatron_fsdp, + ), + dataset=dataset, + logger=LoggerConfig( + log_interval=10, + tensorboard_dir=tensorboard_dir, + log_timers_to_tensorboard=True, + ), + tokenizer=TokenizerConfig(tokenizer_type="NullTokenizer", vocab_size=DEFAULT_NULL_TOKENIZER_VOCAB_SIZE), + checkpoint=CheckpointConfig( + save_interval=2000, + save=checkpoint_dir, + load=checkpoint_dir, + ckpt_format="torch_dist", + fully_parallel_save=True, + ), + rng=RNGConfig(seed=1234), + comm_overlap=comm_overlap_config, + mixed_precision=precision_config, + ) + + return cfg diff --git a/src/megatron/bridge/diffusion/recipes/wan/__init__.py b/src/megatron/bridge/diffusion/recipes/wan/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/megatron/bridge/diffusion/recipes/wan/wan.py b/src/megatron/bridge/diffusion/recipes/wan/wan.py new file mode 100644 index 0000000000..fc6df8b086 --- /dev/null +++ b/src/megatron/bridge/diffusion/recipes/wan/wan.py @@ -0,0 +1,230 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import List, Optional, Union + +import torch +from megatron.core.distributed import DistributedDataParallelConfig + +from megatron.bridge.diffusion.data.wan.wan_energon_datamodule import WanDataModuleConfig +from megatron.bridge.diffusion.data.wan.wan_mock_datamodule import WanMockDataModuleConfig +from megatron.bridge.diffusion.models.wan.wan_provider import WanModelProvider +from megatron.bridge.recipes.utils.optimizer_utils import distributed_fused_adam_with_cosine_annealing +from megatron.bridge.recipes.utils.tokenizer_utils import DEFAULT_NULL_TOKENIZER_VOCAB_SIZE +from megatron.bridge.training.comm_overlap import CommOverlapConfig +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + LoggerConfig, + RNGConfig, + TokenizerConfig, + TrainingConfig, +) +from megatron.bridge.training.mixed_precision import MixedPrecisionConfig, get_mixed_precision_config + + +def model_config( + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + pipeline_parallelism_dtype: Optional[torch.dtype] = torch.bfloat16, + virtual_pipeline_parallelism: Optional[int] = 1, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + seq_length: int = 1024, +) -> WanModelProvider: + """ + Configure the Wan model. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_dtype (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + seq_length (int): Sequence length for the model. + Returns: + WanModelProvider: Configuration for the Wan model. + """ + return WanModelProvider( + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_dtype, + virtual_pipeline_model_parallel_size=None, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + seq_length=seq_length, + ) + + +def pretrain_config( + dir: Optional[str] = None, + name: str = "default", + # Dataset configuration + data_paths: Optional[List[str]] = None, + data_args_path: Optional[str] = None, + train_data_path: Optional[List[str]] = None, + valid_data_path: Optional[List[str]] = None, + test_data_path: Optional[List[str]] = None, + per_split_data_args_path: Optional[str] = None, + mock: bool = False, + # Model configuration + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + pipeline_parallelism_dtype: Optional[torch.dtype] = torch.bfloat16, + virtual_pipeline_parallelism: Optional[int] = 1, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + use_megatron_fsdp: bool = False, + # Training hyperparameters + train_iters: int = 10000, + global_batch_size: int = 4, + micro_batch_size: int = 1, + lr: float = 0.9e-4, + lr_warmup_iters: int = 2000, + # Precision recipe + precision_config: Optional[Union[MixedPrecisionConfig, str]] = "bf16_mixed", + comm_overlap_config: Optional[CommOverlapConfig] = None, +) -> ConfigContainer: + """ + Create a pre-training configuration for GPT3 175B model. + + The default configuration is expected to run on 64 nodes with 8 GPUs each. + + Args: + dir (Optional[str]): Base directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + data_paths (Optional[List[str]]): List of paths to dataset files. If None, mock data will be used. + data_args_path (Optional[str]): Path to file containing data arguments. + train_data_path (Optional[List[str]]): List of training data paths. + valid_data_path (Optional[List[str]]): List of validation data paths. + test_data_path (Optional[List[str]]): List of test data paths. + per_split_data_args_path (Optional[str]): Path to JSON file with per-split data configuration. + mock (bool): Whether to use mock data. If True, ignores data_paths. + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_dtype (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism to be passed to model_config. + sequence_parallelism (bool): Whether to use sequence parallelism. + train_iters (int): Total number of training iterations. + global_batch_size (int): Global batch size for training. + micro_batch_size (int): Micro batch size for training. + seq_length (int): Sequence length for training data. + lr (float): Learning rate. + min_lr (float): Minimum learning rate for cosine decay. + lr_warmup_iters (int): Number of warmup iterations for the learning rate. + precision_config (Optional[Union[MixedPrecisionConfig, str]]): Precision configuration for the model. + comm_overlap_config (Optional[CommOverlapConfig]): Communication overlap configuration for the model. + + Returns: + ConfigContainer: Configuration for pre-training. + """ + base_output_dir = dir if dir is not None else os.path.join(os.getcwd(), "nemo_experiments") + run_output_dir = os.path.join(base_output_dir, name) + checkpoint_dir = os.path.join(run_output_dir, "checkpoints") + tensorboard_dir = os.path.join(run_output_dir, "tb_logs") + + model_cfg = model_config( + tensor_parallelism=tensor_parallelism, + pipeline_parallelism=pipeline_parallelism, + pipeline_parallelism_dtype=pipeline_parallelism_dtype, + virtual_pipeline_parallelism=virtual_pipeline_parallelism, + context_parallelism=context_parallelism, + sequence_parallelism=sequence_parallelism, + seq_length=1024, + ) + + opt_config, scheduler = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=lr_warmup_iters, + lr_decay_iters=train_iters, + max_lr=lr, + ) + opt_config.use_precision_aware_optimizer = False + + if isinstance(precision_config, str): + precision_config = get_mixed_precision_config(precision_config) + + precision_config.grad_reduce_in_fp32 = False + + if mock: + dataset = WanMockDataModuleConfig( + path=None, + seq_length=1024, # we don't need to use this value, just add because Bridge training requires for LLMs + F_latents=24, + H_latents=104, + W_latents=60, + context_seq_len=512, + context_embeddings_dim=4096, + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + num_workers=16, + packing_buffer_size=None, + ) + else: + dataset = WanDataModuleConfig( + path=None, + seq_length=1024, # we don't need to use this value, just add because Bridge training requires for LLMs + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + num_workers=10, + task_encoder_seq_length=None, + packing_buffer_size=150, + ) + + # Config Container + cfg = ConfigContainer( + model=model_cfg, + train=TrainingConfig( + train_iters=train_iters, + eval_interval=2000, + eval_iters=32, + global_batch_size=global_batch_size, + micro_batch_size=micro_batch_size, + manual_gc=True, + manual_gc_interval=100, + manual_gc_eval=100, + ), + optimizer=opt_config, + scheduler=scheduler, + ddp=DistributedDataParallelConfig( + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=False, + overlap_param_gather=False, + average_in_collective=True, + use_distributed_optimizer=True, + use_megatron_fsdp=use_megatron_fsdp, # need use_distributed_optimizer=True + ), + dataset=dataset, + logger=LoggerConfig( + log_interval=10, + tensorboard_dir=tensorboard_dir, + log_timers_to_tensorboard=True, + ), + tokenizer=TokenizerConfig(tokenizer_type="NullTokenizer", vocab_size=DEFAULT_NULL_TOKENIZER_VOCAB_SIZE), + checkpoint=CheckpointConfig( + save_interval=2000, + save=checkpoint_dir, + load=checkpoint_dir, + ckpt_format="torch_dist", + fully_parallel_save=True, + ), + rng=RNGConfig(seed=1234), + comm_overlap=comm_overlap_config, + mixed_precision=precision_config, + ) + + return cfg diff --git a/tests/functional_tests/L2_Mcore_Mock_Tests_GPU.sh b/tests/functional_tests/L2_Mcore_Mock_Tests_GPU.sh new file mode 100644 index 0000000000..023be90f96 --- /dev/null +++ b/tests/functional_tests/L2_Mcore_Mock_Tests_GPU.sh @@ -0,0 +1,14 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +CUDA_VISIBLE_DEVICES="0,1" pytest tests/functional_tests/diffusion/recipes -m "not pleasefixme" --with_downloads -v diff --git a/tests/functional_tests/diffusion/__init__.py b/tests/functional_tests/diffusion/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/functional_tests/diffusion/recipes/__init__.py b/tests/functional_tests/diffusion/recipes/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/functional_tests/diffusion/recipes/test_flux_pretrain.py b/tests/functional_tests/diffusion/recipes/test_flux_pretrain.py new file mode 100644 index 0000000000..9ee7a73602 --- /dev/null +++ b/tests/functional_tests/diffusion/recipes/test_flux_pretrain.py @@ -0,0 +1,101 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functional smoke tests for Mcore FLUX pretrain mock runs.""" + +import os +import subprocess + +import pytest + + +class TestMcoreFluxPretrain: + """Test class for Mcore FLUX pretrain functional tests.""" + + @pytest.mark.run_only_on("GPU") + def test_flux_pretrain_mock(self, tmp_path): + """ + Functional test for FLUX pretrain recipe with mock data. + + This test verifies that the FLUX pretrain recipe can run successfully + in mock mode with minimal configuration, ensuring: + 1. The distributed training can start without errors + 2. Model initialization works correctly + 3. Forward/backward passes complete successfully + 4. The training loop executes without crashes + """ + # Set up temporary directories for dataset and checkpoints + dataset_path = os.path.join(tmp_path, "mock_dataset") + checkpoint_dir = os.path.join(tmp_path, "checkpoints") + os.makedirs(dataset_path, exist_ok=True) + os.makedirs(checkpoint_dir, exist_ok=True) + + # Build the command for the mock run + cmd = [ + "python", + "-m", + "torch.distributed.run", + "--nproc_per_node=2", + "examples/diffusion/recipes/flux/pretrain_flux.py", + "--mock", + "--timestep-sampling", + "logit_normal", + "--scheduler-steps", + "1000", + "model.tensor_model_parallel_size=1", + "model.pipeline_model_parallel_size=1", + "model.context_parallel_size=1", + "model.num_joint_layers=1", + "model.num_single_layers=2", + "model.hidden_size=1024", + "model.num_attention_heads=8", + "model.ffn_hidden_size=4096", + "model.in_channels=64", + "model.context_dim=4096", + "model.guidance_embed=false", + f"checkpoint.save={checkpoint_dir}", + f"checkpoint.load={checkpoint_dir}", + "checkpoint.save_interval=200", + "optimizer.lr=1e-4", + "train.eval_iters=0", + "train.train_iters=10", + "train.global_batch_size=2", + "train.micro_batch_size=1", + "logger.log_interval=1", + ] + + # Run the command with a timeout + result = None + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=1800, # 30 minute timeout + check=True, + ) + + # Basic verification that the run completed + assert result.returncode == 0, f"Command failed with return code {result.returncode}" + + except subprocess.TimeoutExpired: + pytest.fail("FLUX pretrain mock run exceeded timeout of 1800 seconds (30 minutes)") + except subprocess.CalledProcessError as e: + result = e + pytest.fail(f"FLUX pretrain mock run failed with return code {e.returncode}") + finally: + # Always print output for debugging + if result is not None: + print("STDOUT:", result.stdout) + print("STDERR:", result.stderr) diff --git a/tests/functional_tests/diffusion/recipes/test_wan_pretrain.py b/tests/functional_tests/diffusion/recipes/test_wan_pretrain.py new file mode 100644 index 0000000000..559700c3f3 --- /dev/null +++ b/tests/functional_tests/diffusion/recipes/test_wan_pretrain.py @@ -0,0 +1,106 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functional smoke tests for Mcore WAN pretrain mock runs.""" + +import os +import subprocess + +import pytest + + +class TestMcoreWanPretrain: + """Test class for Mcore WAN pretrain functional tests.""" + + @pytest.mark.run_only_on("GPU") + def test_wan_pretrain_mock(self, tmp_path): + """ + Functional test for WAN pretrain recipe with mock data. + + This test verifies that the WAN pretrain recipe can run successfully + in mock mode with minimal configuration, ensuring: + 1. The distributed training can start without errors + 2. Model initialization works correctly + 3. Forward/backward passes complete successfully + 4. The training loop executes without crashes + """ + # Set up temporary directories for dataset and checkpoints + dataset_path = os.path.join(tmp_path, "mock_dataset") + checkpoint_dir = os.path.join(tmp_path, "checkpoints") + os.makedirs(dataset_path, exist_ok=True) + os.makedirs(checkpoint_dir, exist_ok=True) + + # Build the command for the mock run + cmd = [ + "python", + "-m", + "torch.distributed.run", + "--nproc_per_node=2", + "examples/diffusion/recipes/wan/pretrain_wan.py", + "--training-mode", + "pretrain", + "model.tensor_model_parallel_size=1", + "model.pipeline_model_parallel_size=1", + "model.context_parallel_size=1", + "model.crossattn_emb_size=1536", + "model.hidden_size=1536", + "model.ffn_hidden_size=8960", + "model.num_attention_heads=12", + "model.num_layers=3", + "model.qkv_format=thd", + f"dataset.path={dataset_path}", + f"checkpoint.save={checkpoint_dir}", + f"checkpoint.load={checkpoint_dir}", + "checkpoint.load_optim=false", + "checkpoint.save_interval=200", + "optimizer.lr=5e-6", + "optimizer.min_lr=5e-6", + "train.eval_iters=0", + "train.train_iters=10", + "scheduler.lr_decay_style=constant", + "scheduler.lr_warmup_iters=0", + "model.seq_length=2048", + "dataset.seq_length=2048", + "train.global_batch_size=2", + "train.micro_batch_size=1", + "dataset.global_batch_size=2", + "dataset.micro_batch_size=1", + "logger.log_interval=1", + "--mock", + ] + + # Run the command with a timeout + result = None + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=1800, # 30 minute timeout + check=True, + ) + + # Basic verification that the run completed + assert result.returncode == 0, f"Command failed with return code {result.returncode}" + + except subprocess.TimeoutExpired: + pytest.fail("WAN pretrain mock run exceeded timeout of 1800 seconds (30 minutes)") + except subprocess.CalledProcessError as e: + result = e + pytest.fail(f"WAN pretrain mock run failed with return code {e.returncode}") + finally: + # Always print output for debugging + if result is not None: + print("STDOUT:", result.stdout) + print("STDERR:", result.stderr) diff --git a/tests/unit_tests/diffusion/__init__.py b/tests/unit_tests/diffusion/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit_tests/diffusion/conftest.py b/tests/unit_tests/diffusion/conftest.py new file mode 100644 index 0000000000..cd20fc6e58 --- /dev/null +++ b/tests/unit_tests/diffusion/conftest.py @@ -0,0 +1,83 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest +import torch + + +def pytest_addoption(parser): + """ + Additional command-line arguments passed to pytest. + """ + parser.addoption( + "--with_downloads", + action="store_true", + help="pass this argument to active tests which download models from the cloud.", + ) + + +@pytest.fixture(autouse=True) +def downloads_weights(request): + """Fixture to validate if the with_downloads flag is passed if necessary""" + if request.node.get_closest_marker("with_downloads"): + if not request.config.getoption("--with_downloads"): + pytest.skip( + "To run this test, pass --with_downloads option. It will download (and cache) models from cloud." + ) + + +@pytest.fixture(autouse=True) +def reset_env_vars(): + """Reset environment variables""" + # Store the original environment variables before the test + original_env = dict(os.environ) + + # Run the test + yield + + # After the test, restore the original environment + os.environ.clear() + os.environ.update(original_env) + + +@pytest.fixture(autouse=True) +def check_gpu_requirements(request): + """Fixture to skip tests that require GPU when CUDA is not available""" + marker = request.node.get_closest_marker("run_only_on") + if marker and "gpu" in [arg.lower() for arg in marker.args]: + if not torch.cuda.is_available(): + pytest.skip("Test requires GPU but CUDA is not available") + + +def pytest_configure(config): + """ + Initial configuration of conftest. + + Note: DFM uses the following pattern for CPU/GPU test separation: + Tests don't use markers - GPU visibility is controlled by CUDA_VISIBLE_DEVICES. + """ + config.addinivalue_line( + "markers", + "with_downloads: runs the test using data present in tests/.data", + ) + config.addinivalue_line( + "markers", + "pleasefixme: marks test as needing fixes (will be skipped in CI)", + ) + config.addinivalue_line( + "markers", + "run_only_on: marks test to run only on specific hardware (CPU/GPU)", + ) diff --git a/tests/unit_tests/diffusion/data/common/__init__.py b/tests/unit_tests/diffusion/data/common/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit_tests/diffusion/data/common/test_diffusion_data_module.py b/tests/unit_tests/diffusion/data/common/test_diffusion_data_module.py new file mode 100644 index 0000000000..bd396bf88c --- /dev/null +++ b/tests/unit_tests/diffusion/data/common/test_diffusion_data_module.py @@ -0,0 +1,44 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.bridge.diffusion.data.common.diffusion_energon_datamodule import ( + DiffusionDataModuleConfig, +) + + +def test_diffusion_data_module_config_initialization(): + """Test DiffusionDataModuleConfig initialization and default values.""" + + config = DiffusionDataModuleConfig( + path="/path/to/dataset", + seq_length=2048, + micro_batch_size=4, + task_encoder_seq_length=512, + packing_buffer_size=100, + global_batch_size=32, + num_workers=8, + ) + + # Verify default values + assert config.dataloader_type == "external", "Expected default dataloader_type to be 'external'" + assert config.use_train_split_for_val is False, "Expected default use_train_split_for_val to be False" + + # Verify required parameters are set correctly + assert config.path == "/path/to/dataset" + assert config.seq_length == 2048 + assert config.micro_batch_size == 4 + assert config.task_encoder_seq_length == 512 + assert config.packing_buffer_size == 100 + assert config.global_batch_size == 32 + assert config.num_workers == 8 diff --git a/tests/unit_tests/diffusion/data/common/test_diffusion_sample.py b/tests/unit_tests/diffusion/data/common/test_diffusion_sample.py new file mode 100644 index 0000000000..2847c48298 --- /dev/null +++ b/tests/unit_tests/diffusion/data/common/test_diffusion_sample.py @@ -0,0 +1,136 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from megatron.bridge.diffusion.data.common.diffusion_sample import DiffusionSample + + +def test_add(): + """Test __add__ method for DiffusionSample.""" + # Create two DiffusionSample instances with different seq_len_q + sample1 = DiffusionSample( + __key__="sample1", + __restore_key__=(), + __subflavor__=None, + __subflavors__=["default"], + video=torch.randn(3, 8, 16, 16), + context_embeddings=torch.randn(10, 512), + seq_len_q=torch.tensor(100), + ) + sample2 = DiffusionSample( + __key__="sample2", + __restore_key__=(), + __subflavor__=None, + __subflavors__=["default"], + video=torch.randn(3, 8, 16, 16), + context_embeddings=torch.randn(10, 512), + seq_len_q=torch.tensor(200), + ) + + # Test adding two DiffusionSample instances + result = sample1 + sample2 + assert result == 300, f"Expected 300, got {result}" + + # Test adding DiffusionSample with an integer + result = sample1 + 50 + assert result == 150, f"Expected 150, got {result}" + + +def test_radd(): + """Test __radd__ method for DiffusionSample.""" + # Create a DiffusionSample instance + sample = DiffusionSample( + __key__="sample", + __restore_key__=(), + __subflavor__=None, + __subflavors__=["default"], + video=torch.randn(3, 8, 16, 16), + context_embeddings=torch.randn(10, 512), + seq_len_q=torch.tensor(100), + ) + + # Test reverse addition with an integer + result = 50 + sample + assert result == 150, f"Expected 150, got {result}" + + # Test sum() function which uses __radd__ (starting with 0) + samples = [ + DiffusionSample( + __key__="sample1", + __restore_key__=(), + __subflavor__=None, + __subflavors__=["default"], + video=torch.randn(3, 8, 16, 16), + context_embeddings=torch.randn(10, 512), + seq_len_q=torch.tensor(10), + ), + DiffusionSample( + __key__="sample2", + __restore_key__=(), + __subflavor__=None, + __subflavors__=["default"], + video=torch.randn(3, 8, 16, 16), + context_embeddings=torch.randn(10, 512), + seq_len_q=torch.tensor(20), + ), + DiffusionSample( + __key__="sample3", + __restore_key__=(), + __subflavor__=None, + __subflavors__=["default"], + video=torch.randn(3, 8, 16, 16), + context_embeddings=torch.randn(10, 512), + seq_len_q=torch.tensor(30), + ), + ] + result = sum(samples) + assert result == 60, f"Expected 60, got {result}" + + +def test_lt(): + """Test __lt__ method for DiffusionSample.""" + # Create two DiffusionSample instances with different seq_len_q + sample1 = DiffusionSample( + __key__="sample1", + __restore_key__=(), + __subflavor__=None, + __subflavors__=["default"], + video=torch.randn(3, 8, 16, 16), + context_embeddings=torch.randn(10, 512), + seq_len_q=torch.tensor(100), + ) + sample2 = DiffusionSample( + __key__="sample2", + __restore_key__=(), + __subflavor__=None, + __subflavors__=["default"], + video=torch.randn(3, 8, 16, 16), + context_embeddings=torch.randn(10, 512), + seq_len_q=torch.tensor(200), + ) + + # Test comparing two DiffusionSample instances + assert sample1 < sample2, "Expected sample1 < sample2" + assert not (sample2 < sample1), "Expected not (sample2 < sample1)" + + # Test comparing DiffusionSample with an integer + assert sample1 < 150, "Expected sample1 < 150" + assert not (sample1 < 50), "Expected not (sample1 < 50)" + + # Test sorting a list of DiffusionSample instances + samples = [sample2, sample1] + sorted_samples = sorted(samples) + assert sorted_samples[0].seq_len_q.item() == 100, "Expected first element to have seq_len_q=100" + assert sorted_samples[1].seq_len_q.item() == 200, "Expected second element to have seq_len_q=200" diff --git a/tests/unit_tests/diffusion/data/common/test_diffusion_task_encoder.py b/tests/unit_tests/diffusion/data/common/test_diffusion_task_encoder.py new file mode 100644 index 0000000000..e4830e2d2d --- /dev/null +++ b/tests/unit_tests/diffusion/data/common/test_diffusion_task_encoder.py @@ -0,0 +1,184 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import torch + +from megatron.bridge.diffusion.data.common.diffusion_sample import DiffusionSample +from megatron.bridge.diffusion.data.common.diffusion_task_encoder_with_sp import ( + DiffusionTaskEncoderWithSequencePacking, +) + + +class ConcreteDiffusionTaskEncoder(DiffusionTaskEncoderWithSequencePacking): + """Concrete implementation for testing.""" + + def encode_sample(self, sample: dict) -> dict: + """Simple implementation for testing purposes.""" + return sample + + def batch(self, samples: List[DiffusionSample]) -> dict: + """Simple batch implementation that returns first sample as dict.""" + if len(samples) == 1: + sample = samples[0] + return dict( + video=sample.video.unsqueeze(0), + context_embeddings=sample.context_embeddings.unsqueeze(0), + context_mask=sample.context_mask.unsqueeze(0) if sample.context_mask is not None else None, + loss_mask=sample.loss_mask.unsqueeze(0) if sample.loss_mask is not None else None, + seq_len_q=sample.seq_len_q, + seq_len_q_padded=sample.seq_len_q_padded, + seq_len_kv=sample.seq_len_kv, + seq_len_kv_padded=sample.seq_len_kv_padded, + pos_ids=sample.pos_ids.unsqueeze(0) if sample.pos_ids is not None else None, + latent_shape=sample.latent_shape, + video_metadata=sample.video_metadata, + ) + else: + # For multiple samples, just return a simple dict + return {"samples": samples} + + +def create_diffusion_sample(key: str, seq_len: int, video_shape=(16, 8), embedding_dim=128) -> DiffusionSample: + """Helper function to create a DiffusionSample for testing.""" + return DiffusionSample( + __key__=key, + __restore_key__=(), + __subflavor__=None, + __subflavors__=["default"], + video=torch.randn(seq_len, video_shape[0]), + context_embeddings=torch.randn(10, embedding_dim), + context_mask=torch.ones(10), + loss_mask=torch.ones(seq_len), + seq_len_q=torch.tensor([seq_len], dtype=torch.int32), + seq_len_q_padded=torch.tensor([seq_len], dtype=torch.int32), + seq_len_kv=torch.tensor([10], dtype=torch.int32), + seq_len_kv_padded=torch.tensor([10], dtype=torch.int32), + pos_ids=torch.arange(seq_len).unsqueeze(1), + latent_shape=torch.tensor([4, 2, 4, 4], dtype=torch.int32), + video_metadata={"fps": 30, "resolution": "512x512"}, + ) + + +def test_select_samples_to_pack(): + """Test select_samples_to_pack method.""" + # Create encoder with seq_length=20 + encoder = ConcreteDiffusionTaskEncoder(seq_length=20) + + # Create samples with different sequence lengths + samples = [ + create_diffusion_sample("sample_1", seq_len=8), + create_diffusion_sample("sample_2", seq_len=12), + create_diffusion_sample("sample_3", seq_len=5), + create_diffusion_sample("sample_4", seq_len=7), + create_diffusion_sample("sample_5", seq_len=3), + ] + + # Call select_samples_to_pack + result = encoder.select_samples_to_pack(samples) + + # Verify result is a list of lists + assert isinstance(result, list), "Result should be a list" + assert all(isinstance(group, list) for group in result), "All elements should be lists" + + # Verify all samples are included + all_samples = [sample for group in result for sample in group] + assert len(all_samples) == len(samples), "All samples should be included" + + # Verify no bin exceeds seq_length + for group in result: + total_seq_len = sum(sample.seq_len_q.item() for sample in group) + assert total_seq_len <= encoder.seq_length, ( + f"Bin with total {total_seq_len} exceeds seq_length {encoder.seq_length}" + ) + + # Verify that bins are non-empty + assert all(len(group) > 0 for group in result), "No bin should be empty" + + print(f"✓ Successfully packed {len(samples)} samples into {len(result)} bins") + print(f" Bin sizes: {[sum(s.seq_len_q.item() for s in group) for group in result]}") + + +def test_pack_selected_samples(): + """Test pack_selected_samples method.""" + encoder = ConcreteDiffusionTaskEncoder(seq_length=100) + + # Create multiple samples to pack + sample_1_length = 10 + sample_2_length = 15 + sample_3_length = 8 + sample_1 = create_diffusion_sample("sample_1", seq_len=sample_1_length) + sample_2 = create_diffusion_sample("sample_2", seq_len=sample_2_length) + sample_3 = create_diffusion_sample("sample_3", seq_len=sample_3_length) + + samples_to_pack = [sample_1, sample_2, sample_3] + + # Pack the samples + packed_sample = encoder.pack_selected_samples(samples_to_pack) + + # Verify the packed sample is a DiffusionSample + assert isinstance(packed_sample, DiffusionSample), "Result should be a DiffusionSample" + + # Verify __key__ is concatenated + expected_key = "sample_1,sample_2,sample_3" + assert packed_sample.__key__ == expected_key, f"Key should be '{expected_key}'" + + # Verify video is concatenated along dim 0 + expected_video_len = 10 + 15 + 8 + assert packed_sample.video.shape[0] == expected_video_len, f"Video should have length {expected_video_len}" + + # Verify context_embeddings is concatenated + expected_context_len = 10 * 3 # 3 samples with 10 embeddings each + assert packed_sample.context_embeddings.shape[0] == expected_context_len, ( + f"Context embeddings should have length {expected_context_len}" + ) + + # Verify context_mask is concatenated + assert packed_sample.context_mask.shape[0] == expected_context_len, ( + f"Context mask should have length {expected_context_len}" + ) + + # Verify loss_mask is concatenated + assert packed_sample.loss_mask.shape[0] == expected_video_len, f"Loss mask should have length {expected_video_len}" + + # Verify seq_len_q is concatenated + assert packed_sample.seq_len_q.shape[0] == 3, "seq_len_q should have 3 elements" + assert torch.equal( + packed_sample.seq_len_q, torch.tensor([sample_1_length, sample_2_length, sample_3_length], dtype=torch.int32) + ), "seq_len_q values incorrect" + + assert packed_sample.seq_len_q_padded.shape[0] == 3, "seq_len_q_padded should have 3 elements" + assert torch.equal( + packed_sample.seq_len_q_padded, + torch.tensor([sample_1_length, sample_2_length, sample_3_length], dtype=torch.int32), + ), "seq_len_q_padded values incorrect" + + assert packed_sample.seq_len_kv.shape[0] == 3, "seq_len_kv should have 3 elements" + assert torch.equal(packed_sample.seq_len_kv, torch.tensor([10, 10, 10], dtype=torch.int32)), ( + "seq_len_kv values incorrect" + ) + + assert packed_sample.seq_len_kv_padded.shape[0] == 3, "seq_len_kv_padded should have 3 elements" + assert torch.equal(packed_sample.seq_len_kv_padded, torch.tensor([10, 10, 10], dtype=torch.int32)), ( + "seq_len_kv_padded values incorrect" + ) + + assert packed_sample.latent_shape.shape[0] == 3, "latent_shape should have 3 rows" + assert isinstance(packed_sample.video_metadata, list), "video_metadata should be a list" + assert len(packed_sample.video_metadata) == 3, "video_metadata should have 3 elements" + + print(f"✓ Successfully packed {len(samples_to_pack)} samples") + print(f" Packed video shape: {packed_sample.video.shape}") + print(f" Packed context embeddings shape: {packed_sample.context_embeddings.shape}") diff --git a/tests/unit_tests/diffusion/data/common/test_sequence_packing_utils.py b/tests/unit_tests/diffusion/data/common/test_sequence_packing_utils.py new file mode 100644 index 0000000000..23e77bddbf --- /dev/null +++ b/tests/unit_tests/diffusion/data/common/test_sequence_packing_utils.py @@ -0,0 +1,95 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.bridge.diffusion.data.common.sequence_packing_utils import ( + find_first_bin_that_fits, + first_fit, + first_fit_decreasing, +) + + +def test_find_first_bin_that_fits(): + """Test find_first_bin_that_fits function.""" + # Test case: Find a bin that fits + bins = [[5, 3], [10], [2, 2, 2]] + s = 2 + bin_size = 10 + result = find_first_bin_that_fits(bins, s, bin_size) + assert result == 0, "Should return index 0 as first bin (5+3+2=10) fits" + + # Test case: No bin fits + bins = [[8, 2], [9, 1], [10]] + s = 5 + bin_size = 10 + result = find_first_bin_that_fits(bins, s, bin_size) + assert result == -1, "Should return -1 as no bin can accommodate size 5" + + # Test case: Empty bins list + bins = [] + s = 5 + bin_size = 10 + result = find_first_bin_that_fits(bins, s, bin_size) + assert result == -1, "Should return -1 for empty bins list" + + # Test case: First bin doesn't fit, but second does + bins = [[9], [5], [3]] + s = 4 + bin_size = 10 + result = find_first_bin_that_fits(bins, s, bin_size) + assert result == 1, "Should return index 1 as second bin (5+4=9) fits" + + +def test_first_fit(): + """Test first_fit bin packing algorithm.""" + # Test case: Simple packing scenario + seqlens = [5, 3, 2, 7, 4] + pack_size = 10 + result = first_fit(seqlens, pack_size) + + # Verify all sequences are packed + all_items = [item for bin in result for item in bin] + assert sum(all_items) == sum(seqlens), "Sum of all packed items should equal sum of input" + + # Verify no bin exceeds pack_size + for bin in result: + assert sum(bin) <= pack_size, f"Bin {bin} exceeds pack_size {pack_size}" + + # Verify expected packing: [5, 3, 2], [7], [4] (first-fit order) + assert len(result) == 3, "Should create 3 bins" + assert result[0] == [5, 3, 2], "First bin should contain [5, 3, 2]" + assert result[1] == [7], "Second bin should contain [7]" + assert result[2] == [4], "Third bin should contain [4]" + + +def test_first_fit_decreasing(): + """Test first_fit_decreasing bin packing algorithm.""" + # Test case: Same sequences as first_fit but sorted in decreasing order + seqlens = [5, 3, 2, 7, 4] + pack_size = 10 + result = first_fit_decreasing(seqlens, pack_size) + + # Verify all sequences are packed + all_items = [item for bin in result for item in bin] + assert sum(all_items) == sum(seqlens), "Sum of all packed items should equal sum of input" + + # Verify no bin exceeds pack_size + for bin in result: + assert sum(bin) <= pack_size, f"Bin {bin} exceeds pack_size {pack_size}" + + # Verify expected packing: sorted [7, 5, 4, 3, 2] -> [7, 3], [5, 4, 2] (more efficient) + assert len(result) <= 3, "Should create at most 3 bins" + # First-fit-decreasing should pack: [7, 3], [5, 4], [2] + assert result[0] == [7, 3], "First bin should contain [7, 3]" + assert result[1] == [5, 4], "Second bin should contain [5, 4]" + assert result[2] == [2], "Third bin should contain [2]" diff --git a/tests/unit_tests/diffusion/data/flux/__init__.py b/tests/unit_tests/diffusion/data/flux/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit_tests/diffusion/data/flux/test_flux_energon_datamodule.py b/tests/unit_tests/diffusion/data/flux/test_flux_energon_datamodule.py new file mode 100644 index 0000000000..419e02392c --- /dev/null +++ b/tests/unit_tests/diffusion/data/flux/test_flux_energon_datamodule.py @@ -0,0 +1,105 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.bridge.diffusion.data.flux import flux_energon_datamodule as flux_dm_mod +from megatron.bridge.diffusion.data.flux.flux_taskencoder import FluxTaskEncoder + + +class _FakeDiffusionDataModule: + def __init__( + self, + *, + path: str, + seq_length: int, + packing_buffer_size: int, + task_encoder, + micro_batch_size: int, + global_batch_size: int, + num_workers: int, + use_train_split_for_val: bool = True, + ): + self.path = path + self.seq_length = seq_length + self.packing_buffer_size = packing_buffer_size + self.task_encoder = task_encoder + self.micro_batch_size = micro_batch_size + self.global_batch_size = global_batch_size + self.num_workers = num_workers + self.use_train_split_for_val = use_train_split_for_val + + # mimic API used by FluxDataModuleConfig.build_datasets + def train_dataloader(self): + return "train" + + +def test_flux_datamodule_config_initialization(monkeypatch): + # Patch the symbol used inside flux_energon_datamodule module + monkeypatch.setattr(flux_dm_mod, "DiffusionDataModule", _FakeDiffusionDataModule) + + cfg = flux_dm_mod.FluxDataModuleConfig( + path="", + seq_length=1024, + packing_buffer_size=4, + micro_batch_size=2, + global_batch_size=8, + num_workers=0, + vae_scale_factor=8, + latent_channels=16, + task_encoder_seq_length=1024, + ) + + # __post_init__ should construct a dataset with FluxTaskEncoder and propagate seq_length + assert isinstance(cfg.dataset, _FakeDiffusionDataModule) + assert cfg.sequence_length == cfg.dataset.seq_length == 1024 + assert isinstance(cfg.dataset.task_encoder, FluxTaskEncoder) + assert cfg.dataset.task_encoder.seq_length == 1024 + assert cfg.dataset.task_encoder.packing_buffer_size == 4 + assert cfg.dataset.task_encoder.vae_scale_factor == 8 + assert cfg.dataset.task_encoder.latent_channels == 16 + assert cfg.dataset.use_train_split_for_val is True + + # build_datasets should return train loader thrice + train, val, test = cfg.build_datasets(context=None) + assert train == "train" and val == "train" and test == "train" + + +def test_flux_datamodule_config_with_custom_parameters(monkeypatch): + """Test FluxDataModuleConfig with custom VAE and latent parameters.""" + monkeypatch.setattr(flux_dm_mod, "DiffusionDataModule", _FakeDiffusionDataModule) + + cfg = flux_dm_mod.FluxDataModuleConfig( + path="/path/to/dataset", + seq_length=2048, + packing_buffer_size=8, + micro_batch_size=4, + global_batch_size=16, + num_workers=8, + vae_scale_factor=16, + latent_channels=32, + task_encoder_seq_length=2048, + ) + + # Verify all parameters are correctly propagated + assert cfg.dataset.path == "/path/to/dataset" + assert cfg.dataset.seq_length == 2048 + assert cfg.dataset.packing_buffer_size == 8 + assert cfg.dataset.micro_batch_size == 4 + assert cfg.dataset.global_batch_size == 16 + assert cfg.dataset.num_workers == 8 + + # Verify task encoder parameters + assert cfg.dataset.task_encoder.vae_scale_factor == 16 + assert cfg.dataset.task_encoder.latent_channels == 32 + assert cfg.dataset.task_encoder.seq_length == 2048 + assert cfg.dataset.task_encoder.packing_buffer_size == 8 diff --git a/tests/unit_tests/diffusion/data/flux/test_flux_mock_datamodule.py b/tests/unit_tests/diffusion/data/flux/test_flux_mock_datamodule.py new file mode 100644 index 0000000000..98f7fad4c0 --- /dev/null +++ b/tests/unit_tests/diffusion/data/flux/test_flux_mock_datamodule.py @@ -0,0 +1,131 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from megatron.bridge.diffusion.data.flux.flux_mock_datamodule import FluxMockDataModuleConfig + + +def test_flux_mock_datamodule_build_and_batch_shapes(): + cfg = FluxMockDataModuleConfig( + path="", + seq_length=1024, + packing_buffer_size=None, + micro_batch_size=2, + global_batch_size=8, + num_workers=0, + # Use small shapes for a light-weight test run + image_H=256, + image_W=256, + vae_channels=16, + vae_scale_factor=8, + prompt_seq_len=128, + context_dim=512, + pooled_prompt_dim=256, + image_precached=True, + text_precached=True, + num_train_samples=100, + ) + train_dl, val_dl, test_dl = cfg.build_datasets(_context=None) + assert train_dl is val_dl and val_dl is test_dl + + batch = next(iter(train_dl)) + expected_keys = { + "latents", + "prompt_embeds", + "pooled_prompt_embeds", + "text_ids", + "loss_mask", + } + assert expected_keys.issubset(set(batch.keys())) + + # Basic sanity checks on shapes/dtypes + batch_size = cfg.micro_batch_size + latent_h = cfg.image_H // cfg.vae_scale_factor + latent_w = cfg.image_W // cfg.vae_scale_factor + + # Check latents shape: [B, C, H, W] + assert batch["latents"].shape == (batch_size, cfg.vae_channels, latent_h, latent_w) + assert batch["latents"].dtype == torch.bfloat16 + + # Check prompt_embeds shape: [B, seq_len, context_dim] + assert batch["prompt_embeds"].shape == (batch_size, cfg.prompt_seq_len, cfg.context_dim) + assert batch["prompt_embeds"].dtype == torch.bfloat16 + + # Check pooled_prompt_embeds shape: [B, pooled_dim] + assert batch["pooled_prompt_embeds"].shape == (batch_size, cfg.pooled_prompt_dim) + assert batch["pooled_prompt_embeds"].dtype == torch.bfloat16 + + # Check text_ids shape: [B, seq_len, 3] + assert batch["text_ids"].shape == (batch_size, cfg.prompt_seq_len, 3) + assert batch["text_ids"].dtype == torch.bfloat16 + + # Check loss_mask shape: [B, num_patches] where num_patches = (H/2) * (W/2) for FLUX + num_patches = latent_h * latent_w + assert batch["loss_mask"].shape == (batch_size, num_patches) + assert batch["loss_mask"].dtype == torch.bfloat16 + + +def test_flux_mock_datamodule_without_precaching(): + """Test the mock datamodule with non-precached data.""" + cfg = FluxMockDataModuleConfig( + path="", + seq_length=1024, + micro_batch_size=1, + global_batch_size=4, + num_workers=0, + image_H=128, + image_W=128, + image_precached=False, + text_precached=False, + num_train_samples=50, + ) + train_dl, _, _ = cfg.build_datasets(_context=None) + + batch = next(iter(train_dl)) + + # When not precached, should have raw images and text + assert "images" in batch or "latents" in batch + assert "txt" in batch or "prompt_embeds" in batch + + # If images are not precached, they should be raw RGB + if "images" in batch: + assert batch["images"].shape[1] == 3 # RGB channels + assert batch["images"].dtype == torch.bfloat16 + + +def test_flux_mock_datamodule_different_image_sizes(): + """Test the mock datamodule with different image dimensions.""" + cfg = FluxMockDataModuleConfig( + path="", + seq_length=2048, + micro_batch_size=1, + global_batch_size=2, + num_workers=0, + image_H=512, + image_W=1024, + vae_channels=16, + vae_scale_factor=8, + num_train_samples=20, + ) + train_dl, _, _ = cfg.build_datasets(_context=None) + + batch = next(iter(train_dl)) + + latent_h = 512 // 8 # 64 + latent_w = 1024 // 8 # 128 + + assert batch["latents"].shape == (1, 16, latent_h, latent_w) + # Loss mask should cover all latent positions + assert batch["loss_mask"].shape == (1, latent_h * latent_w) diff --git a/tests/unit_tests/diffusion/data/flux/test_flux_taskencoder.py b/tests/unit_tests/diffusion/data/flux/test_flux_taskencoder.py new file mode 100644 index 0000000000..e9603f1b14 --- /dev/null +++ b/tests/unit_tests/diffusion/data/flux/test_flux_taskencoder.py @@ -0,0 +1,338 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from megatron.bridge.diffusion.data.flux.flux_taskencoder import FluxTaskEncoder, cook, parallel_state + + +def test_cook_extracts_expected_fields(): + sample = { + "__key__": "k", + "__restore_key__": "rk", + "__subflavors__": [], + "json": {"meta": 1, "resolution": "1024x1024"}, + "pth": torch.randn(16, 128, 128), # [C, H, W] image latents + "pickle": { + "prompt_embeds": torch.randn(512, 4096), # T5 embeddings + "pooled_prompt_embeds": torch.randn(768), # CLIP pooled + }, + "unused": 123, + } + out = cook(sample) + assert "json" in out and out["json"] is sample["json"] + assert "pth" in out and torch.equal(out["pth"], sample["pth"]) + assert "pickle" in out and out["pickle"] is sample["pickle"] + # ensure basic keys from the sample are preserved by cook via basic_sample_keys() + assert out["__key__"] == sample["__key__"] + assert out["__restore_key__"] == sample["__restore_key__"] + assert out["__subflavors__"] == sample["__subflavors__"] + + +def test_encode_sample_no_context_parallel(monkeypatch): + # Ensure CP world size is 1 to avoid extra padding branch + monkeypatch.setattr(parallel_state, "get_context_parallel_world_size", lambda: 1, raising=False) + # Ensure seeded wrapper has an active worker config + from megatron.energon.task_encoder.base import WorkerConfig + + class _FakeWorkerCfg: + def worker_seed(self): + return 123 + + active_worker_sample_index = 0 + + monkeypatch.setattr(WorkerConfig, "active_worker_config", _FakeWorkerCfg(), raising=False) + + # Construct a minimal, consistent sample + C = 16 # latent channels + H_latents, W_latents = 128, 128 + vae_scale_factor = 8 + + # Image latent shape: [C, H, W] + image_latent = torch.randn(C, H_latents, W_latents) + + # Text embeddings + text_seq_len, context_dim = 256, 4096 + prompt_embeds = torch.randn(text_seq_len, context_dim) + pooled_prompt_embeds = torch.randn(768) + + text_embeddings = { + "prompt_embeds": prompt_embeds, + "pooled_prompt_embeds": pooled_prompt_embeds, + } + + sample = { + "__key__": "k", + "__restore_key__": "rk", + "__subflavors__": [], + "json": {"meta": 1}, + "pth": image_latent, + "pickle": text_embeddings, + } + + enc = FluxTaskEncoder( + seq_length=1024, + vae_scale_factor=vae_scale_factor, + latent_channels=C, + packing_buffer_size=None, + ) + out = enc.encode_sample(sample) + + # Check latent shape storage + assert out.latent_shape.dtype == torch.int32 + assert torch.equal(out.latent_shape, torch.tensor([H_latents, W_latents], dtype=torch.int32)) + + # Check video (image latent) shape - should be unpacked [C, H, W] + assert out.video.shape == (C, H_latents, W_latents) + + # Loss mask and seq lengths + # For FLUX, seq_len_q is (H/2)*(W/2) after packing + seq_len_q = (H_latents // 2) * (W_latents // 2) + assert out.loss_mask.dtype == torch.bfloat16 + assert out.loss_mask.shape[0] == seq_len_q + assert torch.equal(out.seq_len_q, torch.tensor([seq_len_q], dtype=torch.int32)) + + # Context embeddings are padded to fixed 512 inside encode_sample + assert torch.equal(out.seq_len_kv, torch.tensor([512], dtype=torch.int32)) + assert torch.equal(out.seq_len_q_padded, out.seq_len_q) + assert torch.equal(out.seq_len_kv_padded, out.seq_len_kv) + + # Check context embeddings shape + assert out.context_embeddings.shape[0] == 512 # padded length + + # Metadata passthrough + assert isinstance(out.video_metadata, dict) + assert "pooled_prompt_embeds" in out.video_metadata + assert "text_ids" in out.video_metadata + assert out.__key__ == sample["__key__"] + assert out.__restore_key__ == sample["__restore_key__"] + assert out.__subflavors__ == sample["__subflavors__"] + + +def test_encode_sample_with_context_parallel(monkeypatch): + """Test encoding with context parallelism enabled to check padding.""" + # Set CP world size to 2 to trigger padding logic + monkeypatch.setattr(parallel_state, "get_context_parallel_world_size", lambda: 2, raising=False) + from megatron.energon.task_encoder.base import WorkerConfig + + class _FakeWorkerCfg: + def worker_seed(self): + return 456 + + active_worker_sample_index = 0 + + monkeypatch.setattr(WorkerConfig, "active_worker_config", _FakeWorkerCfg(), raising=False) + + C = 16 + H_latents, W_latents = 64, 64 # Smaller size + + image_latent = torch.randn(C, H_latents, W_latents) + text_embeddings = { + "prompt_embeds": torch.randn(200, 4096), + "pooled_prompt_embeds": torch.randn(768), + } + + sample = { + "__key__": "test", + "__restore_key__": "test_restore", + "__subflavors__": [], + "json": {}, + "pth": image_latent, + "pickle": text_embeddings, + } + + enc = FluxTaskEncoder(seq_length=2048, packing_buffer_size=None) + out = enc.encode_sample(sample) + + # With CP world size 2, sharding factor is 2*2=4 + # seq_len_q_padded should be divisible by 4 + seq_len_q = (H_latents // 2) * (W_latents // 2) + assert out.seq_len_q_padded.item() % 4 == 0 + assert out.seq_len_q_padded.item() >= seq_len_q + + # seq_len_kv_padded should also be divisible by 4 + assert out.seq_len_kv_padded.item() % 4 == 0 + assert out.seq_len_kv_padded.item() >= 512 # original padded length + + +def test_batch_without_packing(monkeypatch): + """Test batching multiple samples without packing.""" + monkeypatch.setattr(parallel_state, "get_context_parallel_world_size", lambda: 1, raising=False) + from megatron.energon.task_encoder.base import WorkerConfig + + class _FakeWorkerCfg: + def worker_seed(self): + return 789 + + active_worker_sample_index = 0 + + monkeypatch.setattr(WorkerConfig, "active_worker_config", _FakeWorkerCfg(), raising=False) + + C = 16 + H, W = 64, 64 + + # Create multiple samples + samples_data = [] + for i in range(2): + image_latent = torch.randn(C, H, W) + text_embeddings = { + "prompt_embeds": torch.randn(256, 4096), + "pooled_prompt_embeds": torch.randn(768), + } + sample = { + "__key__": f"sample_{i}", + "__restore_key__": f"restore_{i}", + "__subflavors__": [], + "json": {"id": i}, + "pth": image_latent, + "pickle": text_embeddings, + } + samples_data.append(sample) + + enc = FluxTaskEncoder(seq_length=2048, packing_buffer_size=None) + encoded_samples = [enc.encode_sample(s) for s in samples_data] + batch = enc.batch(encoded_samples) + + assert isinstance(batch, dict) + expected_keys = [ + "latents", + "prompt_embeds", + "pooled_prompt_embeds", + "text_ids", + "loss_mask", + "seq_len_q", + "seq_len_q_padded", + "seq_len_kv", + "seq_len_kv_padded", + "latent_shape", + "image_metadata", + ] + for k in expected_keys: + assert k in batch + + # Check batch dimensions + assert batch["latents"].shape[0] == 2 # batch size + assert batch["latents"].shape[1:] == (C, H, W) + assert batch["prompt_embeds"].shape[0] == 2 + assert batch["pooled_prompt_embeds"].shape[0] == 2 + assert batch["text_ids"].shape[0] == 2 + assert batch["loss_mask"].shape[0] == 2 + + +def test_batch_with_packing_buffer_size(monkeypatch): + """Test batching with packing buffer size.""" + # Force CP world size 1 + monkeypatch.setattr(parallel_state, "get_context_parallel_world_size", lambda: 1, raising=False) + from megatron.energon.task_encoder.base import WorkerConfig + + class _FakeWorkerCfg: + def worker_seed(self): + return 999 + + active_worker_sample_index = 0 + + monkeypatch.setattr(WorkerConfig, "active_worker_config", _FakeWorkerCfg(), raising=False) + + C = 16 + H, W = 64, 64 + + image_latent = torch.randn(C, H, W) + text_embeddings = { + "prompt_embeds": torch.randn(300, 4096), + "pooled_prompt_embeds": torch.randn(768), + } + + sample = { + "__key__": "packed", + "__restore_key__": "packed_restore", + "__subflavors__": [], + "json": {"meta": "data"}, + "pth": image_latent, + "pickle": text_embeddings, + } + + enc = FluxTaskEncoder( + seq_length=2048, + vae_scale_factor=8, + packing_buffer_size=3, + ) + diff_sample = enc.encode_sample(sample) + batch = enc.batch([diff_sample]) + + assert isinstance(batch, dict) + for k in [ + "latents", + "prompt_embeds", + "pooled_prompt_embeds", + "text_ids", + "loss_mask", + "seq_len_q", + "seq_len_q_padded", + "seq_len_kv", + "seq_len_kv_padded", + "latent_shape", + "image_metadata", + ]: + assert k in batch + + # With packing, batch size is 1 but has batch dimension [1, ...] + assert batch["latents"].shape[0] == 1 + assert batch["latents"].shape[1:] == (C, H, W) + assert batch["prompt_embeds"].shape[0] == 1 + assert batch["pooled_prompt_embeds"].shape[0] == 1 + assert batch["text_ids"].shape[0] == 1 + if batch["loss_mask"] is not None: + assert batch["loss_mask"].shape[0] == 1 + + +def test_encode_sample_with_alternative_text_format(monkeypatch): + """Test encoding with alternative text embedding keys (t5_embeds, clip_embeds).""" + monkeypatch.setattr(parallel_state, "get_context_parallel_world_size", lambda: 1, raising=False) + from megatron.energon.task_encoder.base import WorkerConfig + + class _FakeWorkerCfg: + def worker_seed(self): + return 111 + + active_worker_sample_index = 0 + + monkeypatch.setattr(WorkerConfig, "active_worker_config", _FakeWorkerCfg(), raising=False) + + C = 16 + H, W = 128, 128 + + image_latent = torch.randn(C, H, W) + # Use alternative keys + text_embeddings = { + "t5_embeds": torch.randn(400, 4096), + "clip_embeds": torch.randn(768), + } + + sample = { + "__key__": "alt_format", + "__restore_key__": "alt_restore", + "__subflavors__": [], + "json": {}, + "pth": image_latent, + "pickle": text_embeddings, + } + + enc = FluxTaskEncoder(seq_length=1024, packing_buffer_size=None) + out = enc.encode_sample(sample) + + # Should successfully encode even with alternative keys + assert out.video.shape == (C, H, W) + assert out.context_embeddings.shape[0] == 512 # padded + assert "pooled_prompt_embeds" in out.video_metadata + assert "text_ids" in out.video_metadata diff --git a/tests/unit_tests/diffusion/data/wan/__init__.py b/tests/unit_tests/diffusion/data/wan/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit_tests/diffusion/data/wan/test_wan_energon_datamodule.py b/tests/unit_tests/diffusion/data/wan/test_wan_energon_datamodule.py new file mode 100644 index 0000000000..344cb3a0fd --- /dev/null +++ b/tests/unit_tests/diffusion/data/wan/test_wan_energon_datamodule.py @@ -0,0 +1,67 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.bridge.diffusion.data.wan import wan_energon_datamodule as wan_dm_mod +from megatron.bridge.diffusion.data.wan.wan_taskencoder import WanTaskEncoder + + +class _FakeDiffusionDataModule: + def __init__( + self, + *, + path: str, + seq_length: int, + packing_buffer_size: int, + task_encoder, + micro_batch_size: int, + global_batch_size: int, + num_workers: int, + ): + self.path = path + self.seq_length = seq_length + self.packing_buffer_size = packing_buffer_size + self.task_encoder = task_encoder + self.micro_batch_size = micro_batch_size + self.global_batch_size = global_batch_size + self.num_workers = num_workers + + # mimic API used by WanDataModuleConfig.build_datasets + def train_dataloader(self): + return "train" + + +def test_wan_datamodule_config_initialization(monkeypatch): + # Patch the symbol used inside wan_energon_datamodule module + monkeypatch.setattr(wan_dm_mod, "DiffusionDataModule", _FakeDiffusionDataModule) + + cfg = wan_dm_mod.WanDataModuleConfig( + path="", + seq_length=128, + task_encoder_seq_length=128, + packing_buffer_size=4, + micro_batch_size=2, + global_batch_size=8, + num_workers=0, + ) + + # __post_init__ should construct a dataset with WanTaskEncoder and propagate seq_length + assert isinstance(cfg.dataset, _FakeDiffusionDataModule) + assert cfg.sequence_length == cfg.dataset.seq_length == 128 + assert isinstance(cfg.dataset.task_encoder, WanTaskEncoder) + assert cfg.dataset.task_encoder.seq_length == 128 + assert cfg.dataset.task_encoder.packing_buffer_size == 4 + + # build_datasets should return train loader thrice + train, val, test = cfg.build_datasets(context=None) + assert train == "train" and val == "train" and test == "train" diff --git a/tests/unit_tests/diffusion/data/wan/test_wan_mock_datamodule.py b/tests/unit_tests/diffusion/data/wan/test_wan_mock_datamodule.py new file mode 100644 index 0000000000..f18cddf132 --- /dev/null +++ b/tests/unit_tests/diffusion/data/wan/test_wan_mock_datamodule.py @@ -0,0 +1,63 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from megatron.bridge.diffusion.data.wan.wan_mock_datamodule import WanMockDataModuleConfig + + +def test_wan_mock_datamodule_build_and_batch_shapes(): + cfg = WanMockDataModuleConfig( + path="", + seq_length=128, + packing_buffer_size=2, + micro_batch_size=2, + global_batch_size=8, + num_workers=0, + # Use small shapes for a light-weight test run + F_latents=4, + H_latents=8, + W_latents=6, + patch_spatial=2, + patch_temporal=1, + number_packed_samples=2, + context_seq_len=16, + context_embeddings_dim=64, + ) + train_dl, val_dl, test_dl = cfg.build_datasets(_context=None) + assert train_dl is val_dl and val_dl is test_dl + + batch = next(iter(train_dl)) + expected_keys = { + "video_latents", + "context_embeddings", + "loss_mask", + "seq_len_q", + "seq_len_q_padded", + "seq_len_kv", + "seq_len_kv_padded", + "grid_sizes", + "video_metadata", + } + assert expected_keys.issubset(set(batch.keys())) + + # Basic sanity checks on shapes/dtypes + assert batch["video_latents"].dim() == 3 and batch["video_latents"].shape[1] == 1 + assert batch["context_embeddings"].dim() == 3 and batch["context_embeddings"].shape[1] == 1 + assert batch["loss_mask"].dim() == 2 and batch["loss_mask"].shape[1] == 1 + assert batch["seq_len_q"].dtype == torch.int32 + assert batch["seq_len_q_padded"].dtype == torch.int32 + assert batch["seq_len_kv"].dtype == torch.int32 + assert batch["seq_len_kv_padded"].dtype == torch.int32 + assert batch["grid_sizes"].dtype == torch.int32 diff --git a/tests/unit_tests/diffusion/data/wan/test_wan_taskencoder.py b/tests/unit_tests/diffusion/data/wan/test_wan_taskencoder.py new file mode 100644 index 0000000000..81b47ef382 --- /dev/null +++ b/tests/unit_tests/diffusion/data/wan/test_wan_taskencoder.py @@ -0,0 +1,154 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from megatron.bridge.diffusion.data.wan.wan_taskencoder import WanTaskEncoder, cook, parallel_state + + +def test_cook_extracts_expected_fields(): + sample = { + "__key__": "k", + "__restore_key__": "rk", + "__subflavors__": [], + "json": {"meta": 1}, + "pth": torch.randn(1, 2, 2, 2), + "pickle": torch.randn(3, 4), + "unused": 123, + } + out = cook(sample) + assert "json" in out and out["json"] is sample["json"] + assert "pth" in out and torch.equal(out["pth"], sample["pth"]) + assert "pickle" in out and torch.equal(out["pickle"], sample["pickle"]) + # ensure basic keys from the sample are preserved by cook via basic_sample_keys() + assert out["__key__"] == sample["__key__"] + assert out["__restore_key__"] == sample["__restore_key__"] + assert out["__subflavors__"] == sample["__subflavors__"] + + +def test_encode_sample_no_context_parallel(monkeypatch): + # Ensure CP world size is 1 to avoid extra padding branch + monkeypatch.setattr(parallel_state, "get_context_parallel_world_size", lambda: 1, raising=False) + # Ensure seeded wrapper has an active worker config + from megatron.energon.task_encoder.base import WorkerConfig + + class _FakeWorkerCfg: + def worker_seed(self): + return 123 + + active_worker_sample_index = 0 + + monkeypatch.setattr(WorkerConfig, "active_worker_config", _FakeWorkerCfg(), raising=False) + + # Construct a minimal, consistent sample + c = 8 + F_latents, H_latents, W_latents = 4, 8, 6 + patch_temporal, patch_spatial = 1, 2 + # video latent before patchify has shape [c, F_latents, H_latents, W_latents] + # where grid sizes (patch counts) are (F_latents // pF, H_latents // pH, W_latents // pW) + video_latent = torch.randn(c, F_latents, H_latents, W_latents) + context_len, context_dim = 256, 64 + context_embeddings = torch.randn(context_len, context_dim) + sample = { + "__key__": "k", + "__restore_key__": "rk", + "__subflavors__": [], + "json": {"meta": 1}, + "pth": video_latent, + "pickle": context_embeddings, + } + + enc = WanTaskEncoder( + seq_length=1024, patch_temporal=patch_temporal, patch_spatial=patch_spatial, packing_buffer_size=None + ) + out = enc.encode_sample(sample) + + # Grid / patches + F_patches = F_latents // patch_temporal + H_patches = H_latents // patch_spatial + W_patches = W_latents // patch_spatial + num_patches = F_patches * H_patches * W_patches + patch_vec_dim = c * patch_temporal * patch_spatial * patch_spatial + + assert out.video.shape == (num_patches, patch_vec_dim) + assert out.latent_shape.dtype == torch.int32 + assert torch.equal(out.latent_shape, torch.tensor([F_patches, H_patches, W_patches], dtype=torch.int32)) + + # Loss mask and seq lengths + assert out.loss_mask.dtype == torch.bfloat16 + assert out.loss_mask.shape[0] == num_patches + assert torch.equal(out.seq_len_q, torch.tensor([num_patches], dtype=torch.int32)) + # context embeddings are padded to fixed 512 inside encode_sample + assert torch.equal(out.seq_len_kv, torch.tensor([512], dtype=torch.int32)) + assert torch.equal(out.seq_len_q_padded, out.seq_len_q) + assert torch.equal(out.seq_len_kv_padded, out.seq_len_kv) + + # Metadata passthrough + assert out.video_metadata == sample["json"] + assert out.__key__ == sample["__key__"] + assert out.__restore_key__ == sample["__restore_key__"] + assert out.__subflavors__ == sample["__subflavors__"] + + +def test_batch_with_packing_buffer_size(monkeypatch): + # Force CP world size 1 + monkeypatch.setattr(parallel_state, "get_context_parallel_world_size", lambda: 1, raising=False) + # Ensure seeded wrapper has an active worker config + from megatron.energon.task_encoder.base import WorkerConfig + + class _FakeWorkerCfg: + def worker_seed(self): + return 456 + + active_worker_sample_index = 0 + + monkeypatch.setattr(WorkerConfig, "active_worker_config", _FakeWorkerCfg(), raising=False) + + c = 4 + F_latents, H_latents, W_latents = 2, 4, 4 + patch_temporal, patch_spatial = 1, 2 + video_latent = torch.randn(c, F_latents * patch_temporal, H_latents * patch_spatial, W_latents * patch_spatial) + sample = { + "__key__": "k", + "__restore_key__": "rk", + "__subflavors__": [], + "json": {"meta": 1}, + "pth": video_latent, + "pickle": torch.randn(32, 128), + } + + enc = WanTaskEncoder( + seq_length=256, patch_temporal=patch_temporal, patch_spatial=patch_spatial, packing_buffer_size=3 + ) + diff_sample = enc.encode_sample(sample) + batch = enc.batch([diff_sample]) + + assert isinstance(batch, dict) + for k in [ + "video_latents", + "context_embeddings", + "loss_mask", + "seq_len_q", + "seq_len_q_padded", + "seq_len_kv", + "seq_len_kv_padded", + "grid_sizes", + "video_metadata", + ]: + assert k in batch + + # video_latents: [S, 1, ...], where S equals sample.video length when CP world size is 1 + assert batch["video_latents"].shape[1] == 1 + assert batch["context_embeddings"].shape[1] == 1 + assert batch["loss_mask"].shape[1] == 1 diff --git a/tests/unit_tests/diffusion/model/common/__init__.py b/tests/unit_tests/diffusion/model/common/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit_tests/diffusion/model/common/test_normalization.py b/tests/unit_tests/diffusion/model/common/test_normalization.py new file mode 100644 index 0000000000..0b23a0db15 --- /dev/null +++ b/tests/unit_tests/diffusion/model/common/test_normalization.py @@ -0,0 +1,301 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from megatron.bridge.diffusion.models.common.normalization import RMSNorm + + +def test_rmsnorm_initialization(): + """Test RMSNorm initialization with different hidden sizes.""" + hidden_size = 768 + norm = RMSNorm(hidden_size) + + # Check weight parameter exists and has correct shape + assert hasattr(norm, "weight") + assert norm.weight.shape == (hidden_size,) + assert isinstance(norm.weight, torch.nn.Parameter) + + # Check weight is initialized to ones + assert torch.allclose(norm.weight, torch.ones(hidden_size)) + + # Check epsilon value + assert norm.eps == 1e-6 + + +def test_rmsnorm_initialization_with_custom_eps(): + """Test RMSNorm initialization with custom epsilon.""" + hidden_size = 512 + custom_eps = 1e-8 + norm = RMSNorm(hidden_size, eps=custom_eps) + + assert norm.eps == custom_eps + + +def test_rmsnorm_initialization_with_config(): + """Test RMSNorm initialization with config parameter (for compatibility).""" + hidden_size = 1024 + mock_config = {"dummy": "config"} + + # Should not raise an error even with config parameter + norm = RMSNorm(hidden_size, config=mock_config) + assert norm.weight.shape == (hidden_size,) + + +def test_rmsnorm_forward_2d_input(): + """Test RMSNorm forward pass with 2D input [batch, hidden].""" + batch_size = 4 + hidden_size = 256 + + norm = RMSNorm(hidden_size) + x = torch.randn(batch_size, hidden_size) + + output = norm(x) + + # Check output shape + assert output.shape == (batch_size, hidden_size) + + # Check output dtype matches input + assert output.dtype == x.dtype + + +def test_rmsnorm_forward_3d_input(): + """Test RMSNorm forward pass with 3D input [batch, seq_len, hidden].""" + batch_size = 2 + seq_len = 128 + hidden_size = 512 + + norm = RMSNorm(hidden_size) + x = torch.randn(batch_size, seq_len, hidden_size) + + output = norm(x) + + # Check output shape + assert output.shape == (batch_size, seq_len, hidden_size) + + # Check output dtype matches input + assert output.dtype == x.dtype + + +def test_rmsnorm_numerical_correctness(): + """Test that RMSNorm produces numerically correct results.""" + hidden_size = 64 + norm = RMSNorm(hidden_size, eps=1e-6) + + # Create a simple input + x = torch.randn(2, hidden_size) + + # Manually compute expected RMS normalization + rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + norm.eps) + expected = x / rms + + # Get actual output (without weight scaling for this test) + with torch.no_grad(): + norm.weight.fill_(1.0) # Set weights to 1 to isolate normalization + + output = norm(x) + + # Compare (allow small numerical differences) + assert torch.allclose(output, expected, rtol=1e-4, atol=1e-6) + + +def test_rmsnorm_weight_scaling(): + """Test that RMSNorm correctly applies weight scaling.""" + hidden_size = 32 + norm = RMSNorm(hidden_size) + + # Set custom weights + with torch.no_grad(): + norm.weight.fill_(2.0) + + x = torch.randn(4, hidden_size) + + # Get normalized output + rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + norm.eps) + normalized = x / rms + expected = normalized * 2.0 + + output = norm(x) + + assert torch.allclose(output, expected, rtol=1e-4, atol=1e-6) + + +def test_rmsnorm_different_dtypes(): + """Test RMSNorm with different input dtypes.""" + hidden_size = 128 + norm = RMSNorm(hidden_size) + + # Test with float32 + x_float32 = torch.randn(2, hidden_size, dtype=torch.float32) + output_float32 = norm(x_float32) + assert output_float32.dtype == torch.float32 + + # Test with float16 + x_float16 = torch.randn(2, hidden_size, dtype=torch.float16) + output_float16 = norm(x_float16) + assert output_float16.dtype == torch.float16 + + # Test with bfloat16 + x_bfloat16 = torch.randn(2, hidden_size, dtype=torch.bfloat16) + output_bfloat16 = norm(x_bfloat16) + assert output_bfloat16.dtype == torch.bfloat16 + + +def test_rmsnorm_preserves_dtype(): + """Test that RMSNorm preserves input dtype even though internal computation is in float32.""" + hidden_size = 256 + norm = RMSNorm(hidden_size) + + # Test with bfloat16 (common in training) + x = torch.randn(3, 10, hidden_size, dtype=torch.bfloat16) + output = norm(x) + + # Output should have same dtype as input + assert output.dtype == torch.bfloat16 + + # But should have been normalized correctly (internal computation in float) + # Verify by checking the RMS is approximately 1 + with torch.no_grad(): + norm.weight.fill_(1.0) + output_normalized = norm(x.float()) + rms = torch.sqrt(torch.mean(output_normalized**2, dim=-1)) + # RMS should be close to 1 after normalization + assert torch.allclose(rms, torch.ones_like(rms), rtol=0.1) + + +def test_rmsnorm_zero_input(): + """Test RMSNorm behavior with zero input (edge case).""" + hidden_size = 64 + norm = RMSNorm(hidden_size, eps=1e-6) + + # Create zero input + x = torch.zeros(2, hidden_size) + + # Should not crash and should produce zero output (scaled by weights) + output = norm(x) + + # With zero input and epsilon, the norm is sqrt(0 + eps) + # So output should be 0 / sqrt(eps) * weight = 0 + assert torch.allclose(output, torch.zeros_like(output)) + + +def test_rmsnorm_gradient_flow(): + """Test that gradients flow properly through RMSNorm.""" + hidden_size = 128 + norm = RMSNorm(hidden_size) + + x = torch.randn(4, hidden_size, requires_grad=True) + output = norm(x) + + # Compute loss and backprop + loss = output.sum() + loss.backward() + + # Check that gradients exist for both input and weight + assert x.grad is not None + assert norm.weight.grad is not None + + # Check that gradients are not all zeros + assert not torch.allclose(x.grad, torch.zeros_like(x.grad)) + assert not torch.allclose(norm.weight.grad, torch.zeros_like(norm.weight.grad)) + + +def test_rmsnorm_batch_independence(): + """Test that normalization is independent across batch dimension.""" + hidden_size = 64 + norm = RMSNorm(hidden_size) + + # Create batched input + x1 = torch.randn(1, hidden_size) + x2 = torch.randn(1, hidden_size) + x_batched = torch.cat([x1, x2], dim=0) + + # Process individually and batched + with torch.no_grad(): + out1 = norm(x1) + out2 = norm(x2) + out_batched = norm(x_batched) + + # Results should be identical + assert torch.allclose(out_batched[0], out1[0], rtol=1e-5) + assert torch.allclose(out_batched[1], out2[0], rtol=1e-5) + + +def test_rmsnorm_sequence_independence(): + """Test that normalization is independent across sequence dimension.""" + hidden_size = 64 + seq_len = 10 + norm = RMSNorm(hidden_size) + + # Create 3D input [batch, seq, hidden] + x = torch.randn(2, seq_len, hidden_size) + + with torch.no_grad(): + # Process full sequence + output_full = norm(x) + + # Process each position separately + for i in range(seq_len): + output_single = norm(x[:, i : i + 1, :]) + assert torch.allclose(output_full[:, i : i + 1, :], output_single, rtol=1e-5) + + +def test_rmsnorm_epsilon_effect(): + """Test that epsilon parameter affects numerical stability.""" + hidden_size = 64 + + # Create very small input values + x = torch.randn(2, hidden_size) * 1e-8 + + # Test with different epsilon values + norm_large_eps = RMSNorm(hidden_size, eps=1e-3) + norm_small_eps = RMSNorm(hidden_size, eps=1e-10) + + with torch.no_grad(): + norm_large_eps.weight.fill_(1.0) + norm_small_eps.weight.fill_(1.0) + + output_large_eps = norm_large_eps(x) + output_small_eps = norm_small_eps(x) + + # Outputs should be different due to epsilon + assert not torch.allclose(output_large_eps, output_small_eps, rtol=1e-2) + + +def test_rmsnorm_large_values(): + """Test RMSNorm with large input values.""" + hidden_size = 128 + norm = RMSNorm(hidden_size) + + # Create input with large values + x = torch.randn(2, hidden_size) * 1000 + + # Should not produce NaN or Inf + output = norm(x) + + assert not torch.isnan(output).any() + assert not torch.isinf(output).any() + + +def test_rmsnorm_different_hidden_sizes(): + """Test RMSNorm with various hidden sizes.""" + hidden_sizes = [64, 128, 256, 512, 768, 1024, 2048, 4096] + + for hidden_size in hidden_sizes: + norm = RMSNorm(hidden_size) + x = torch.randn(2, 10, hidden_size) + output = norm(x) + + assert output.shape == x.shape + assert not torch.isnan(output).any() diff --git a/tests/unit_tests/diffusion/model/flux/__init__.py b/tests/unit_tests/diffusion/model/flux/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit_tests/diffusion/model/flux/conversion/__init__.py b/tests/unit_tests/diffusion/model/flux/conversion/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit_tests/diffusion/model/flux/conversion/test_flux_bridge.py b/tests/unit_tests/diffusion/model/flux/conversion/test_flux_bridge.py new file mode 100644 index 0000000000..3618f24520 --- /dev/null +++ b/tests/unit_tests/diffusion/model/flux/conversion/test_flux_bridge.py @@ -0,0 +1,203 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import types + +import pytest +import torch + +from megatron.bridge.diffusion.models.flux.conversion import flux_bridge as flux_bridge_module + + +pytestmark = [pytest.mark.unit] + + +def _make_cfg( + *, + in_channels=64, + patch_size=1, + num_layers=19, + num_single_layers=38, + num_attention_heads=24, + attention_head_dim=128, + pooled_projection_dim=768, + guidance_embeds=False, + axes_dims_rope=(16, 56, 56), +): + cfg = types.SimpleNamespace() + cfg.in_channels = in_channels + cfg.patch_size = patch_size + cfg.num_layers = num_layers + cfg.num_single_layers = num_single_layers + cfg.num_attention_heads = num_attention_heads + cfg.attention_head_dim = attention_head_dim + cfg.pooled_projection_dim = pooled_projection_dim + cfg.guidance_embeds = guidance_embeds + cfg.axes_dims_rope = axes_dims_rope + return cfg + + +def test_provider_bridge_constructs_provider_with_expected_fields(): + class DummyHF: + def __init__(self, cfg): + self.config = cfg + + cfg = _make_cfg() + bridge = flux_bridge_module.FluxBridge() + provider = bridge.provider_bridge(DummyHF(cfg)) + + # Basic sanity: returned type and a few key fields + assert provider is not None + assert provider.num_joint_layers == cfg.num_layers + assert provider.num_single_layers == cfg.num_single_layers + assert provider.num_attention_heads == cfg.num_attention_heads + + # kv_channels equals per-head dim + assert getattr(provider, "kv_channels") == cfg.attention_head_dim + + # num_query_groups equals num_attention_heads + assert provider.num_query_groups == cfg.num_attention_heads + + # passthrough fields + assert provider.in_channels == cfg.in_channels + assert provider.patch_size == cfg.patch_size + assert provider.vec_in_dim == cfg.pooled_projection_dim + assert provider.guidance_embed == cfg.guidance_embeds + assert provider.axes_dims_rope == cfg.axes_dims_rope + + # bf16 and params_dtype set by bridge + assert provider.bf16 is False + assert provider.params_dtype == torch.float32 + + # hidden_size stored on bridge instance + assert bridge.hidden_size == provider.hidden_size + + +def test_mapping_registry_registers_module_types_and_builds_mappings(monkeypatch): + calls_register_module_type = [] + + def fake_register_module_type(name, parallelism): + calls_register_module_type.append((name, parallelism)) + + constructed_registry_args = {} + + class FakeRegistry: + def __init__(self, *mappings): + constructed_registry_args["mappings"] = mappings + + monkeypatch.setattr( + flux_bridge_module.AutoMapping, "register_module_type", staticmethod(fake_register_module_type) + ) + monkeypatch.setattr(flux_bridge_module, "MegatronMappingRegistry", FakeRegistry) + + registry = flux_bridge_module.FluxBridge().mapping_registry() + + # Verify module type registrations + assert ("Linear", "replicated") in calls_register_module_type + + # We replaced the real registry with FakeRegistry; the function should return that instance + assert isinstance(registry, FakeRegistry) + mappings = constructed_registry_args["mappings"] + + # Ensure we have a reasonable number of mappings + assert len(mappings) >= 20 + + # Expect at least one AutoMapping, one QKVMapping, one SplitRowParallelMapping + has_auto = any(m.__class__.__name__ == "AutoMapping" for m in mappings) + has_qkv = any(m.__class__.__name__ == "QKVMapping" for m in mappings) + has_split_row = any(m.__class__.__name__ == "SplitRowParallelMapping" for m in mappings) + assert has_auto and has_qkv and has_split_row + + +def test_maybe_modify_loaded_hf_weight_with_weight_1(): + """Test that weight_1 suffix correctly slices the second half of weight tensor""" + bridge = flux_bridge_module.FluxBridge() + bridge.hidden_size = 100 + + # Create a dummy weight tensor + dummy_weight = torch.randn(200, 300) + hf_state_dict = {"single_transformer_blocks.0.proj_out.weight": dummy_weight} + + # Test weight_1 suffix (should get second half) + result = bridge.maybe_modify_loaded_hf_weight("single_transformer_blocks.0.proj_out.weight_1", hf_state_dict) + + expected = dummy_weight[:, 100:] + assert torch.equal(result, expected) + assert result.shape == (200, 200) + + +def test_maybe_modify_loaded_hf_weight_with_weight_2(): + """Test that weight_2 suffix correctly slices the first half of weight tensor""" + bridge = flux_bridge_module.FluxBridge() + bridge.hidden_size = 100 + + # Create a dummy weight tensor + dummy_weight = torch.randn(200, 300) + hf_state_dict = {"single_transformer_blocks.0.proj_out.weight": dummy_weight} + + # Test weight_2 suffix (should get first half) + result = bridge.maybe_modify_loaded_hf_weight("single_transformer_blocks.0.proj_out.weight_2", hf_state_dict) + + expected = dummy_weight[:, :100] + assert torch.equal(result, expected) + assert result.shape == (200, 100) + + +def test_maybe_modify_loaded_hf_weight_normal_param(): + """Test that normal parameter names are passed through unchanged""" + bridge = flux_bridge_module.FluxBridge() + bridge.hidden_size = 100 + + # Create a dummy weight tensor + dummy_weight = torch.randn(200, 300) + hf_state_dict = {"norm_out.linear.weight": dummy_weight} + + # Test normal parameter (no modification) + result = bridge.maybe_modify_loaded_hf_weight("norm_out.linear.weight", hf_state_dict) + + assert torch.equal(result, dummy_weight) + + +def test_maybe_modify_loaded_hf_weight_with_dict_param(): + """Test that dictionary parameters are handled correctly""" + bridge = flux_bridge_module.FluxBridge() + + # Create dummy weight tensors + weight1 = torch.randn(100, 200) + weight2 = torch.randn(100, 200) + hf_state_dict = { + "transformer_blocks.0.attn.to_q.weight": weight1, + "transformer_blocks.0.attn.to_k.weight": weight2, + } + + # Test dictionary parameter + param_dict = { + "q": "transformer_blocks.0.attn.to_q.weight", + "k": "transformer_blocks.0.attn.to_k.weight", + } + result = bridge.maybe_modify_loaded_hf_weight(param_dict, hf_state_dict) + + assert isinstance(result, dict) + assert torch.equal(result["q"], weight1) + assert torch.equal(result["k"], weight2) + + +def test_split_row_parallel_mapping_has_allow_hf_name_mismatch(): + """Test that SplitRowParallelMapping has allow_hf_name_mismatch set to True""" + mapping = flux_bridge_module.SplitRowParallelMapping( + megatron_param="single_blocks.*.mlp.linear_fc2.weight", + hf_param="single_transformer_blocks.*.proj_out.weight_1", + ) + + assert mapping.allow_hf_name_mismatch is True diff --git a/tests/unit_tests/diffusion/model/flux/conversion/test_flux_hf_pretrained.py b/tests/unit_tests/diffusion/model/flux/conversion/test_flux_hf_pretrained.py new file mode 100644 index 0000000000..8c8360925a --- /dev/null +++ b/tests/unit_tests/diffusion/model/flux/conversion/test_flux_hf_pretrained.py @@ -0,0 +1,279 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import pytest + +from megatron.bridge.diffusion.models.flux.conversion import flux_hf_pretrained as flux_hf_module + + +pytestmark = [pytest.mark.unit] + + +def test_load_config_uses_transformer_subfolder(monkeypatch, tmp_path): + calls = [] + + class FakeModel: + def __init__(self, cfg): + self.config = cfg + + class FakeFlux: + @classmethod + def from_pretrained(cls, path, subfolder=None): + calls.append((str(path), subfolder)) + return FakeModel(cfg={"ok": True}) + + monkeypatch.setattr(flux_hf_module, "FluxTransformer2DModel", FakeFlux) + + src = tmp_path / "hf" + src.mkdir(parents=True, exist_ok=True) + hf = flux_hf_module.PreTrainedFlux(str(src)) + + # Accessing .config should trigger _load_config + cfg = hf.config + assert cfg == {"ok": True} + # Ensure we called with transformer subfolder + assert calls and calls[-1][1] == "transformer" + + +def test_state_uses_transformer_subfolder_and_caches(monkeypatch, tmp_path): + captured = {"source_path": None, "constructed": 0} + + class FakeSource: + def __init__(self, path): + captured["source_path"] = str(path) + + class FakeStateDict: + def __init__(self, source): + self.source = source + captured["constructed"] += 1 + + monkeypatch.setattr(flux_hf_module, "FluxSafeTensorsStateSource", FakeSource) + monkeypatch.setattr(flux_hf_module, "StateDict", FakeStateDict) + + src = tmp_path / "hf_model" + (src / "transformer").mkdir(parents=True, exist_ok=True) + hf = flux_hf_module.PreTrainedFlux(str(src)) + + s1 = hf.state + s2 = hf.state # Cached + assert s1 is s2 + # Correct subfolder used + assert captured["source_path"] == str(src / "transformer") + # StateDict constructed only once due to caching + assert captured["constructed"] == 1 + + +def test_save_artifacts_copies_existing_files(tmp_path): + # Prepare source with transformer/config.json and index + src = tmp_path / "src" + tdir = src / "transformer" + tdir.mkdir(parents=True, exist_ok=True) + config_src = tdir / "config.json" + index_src = tdir / "diffusion_pytorch_model.safetensors.index.json" + config_data = {"a": 1} + index_data = {"weight_map": {}} + config_src.write_text(json.dumps(config_data)) + index_src.write_text(json.dumps(index_data)) + + # Destination directory + dest = tmp_path / "dest" + + hf = flux_hf_module.PreTrainedFlux(str(src)) + hf.save_artifacts(str(dest)) + + # Validate files copied + dest_tdir = dest / "transformer" + assert dest_tdir.is_dir() + assert json.loads((dest_tdir / "config.json").read_text()) == config_data + assert json.loads((dest_tdir / "diffusion_pytorch_model.safetensors.index.json").read_text()) == index_data + + +def test_save_artifacts_exports_config_when_missing(monkeypatch, tmp_path): + class FakeCfg: + def to_dict(self): + return {"from_model": True} + + class FakeModel: + def __init__(self): + self.config = FakeCfg() + + class FakeFlux: + @classmethod + def from_pretrained(cls, path, subfolder=None): + # Ensure it targets the transformer subfolder + assert subfolder == "transformer" + return FakeModel() + + monkeypatch.setattr(flux_hf_module, "FluxTransformer2DModel", FakeFlux) + + src = tmp_path / "empty_src" + src.mkdir(parents=True, exist_ok=True) + + dest = tmp_path / "out" + hf = flux_hf_module.PreTrainedFlux(str(src)) + hf.save_artifacts(dest) + + # Should create transformer/config.json with exported contents + dest_cfg = dest / "transformer" / "config.json" + assert dest_cfg.is_file() + assert json.loads(dest_cfg.read_text()) == {"from_model": True} + + +def test_save_artifacts_handles_export_failure(monkeypatch, tmp_path): + class FailingFlux: + @classmethod + def from_pretrained(cls, path, subfolder=None): + raise RuntimeError("fail") + + monkeypatch.setattr(flux_hf_module, "FluxTransformer2DModel", FailingFlux) + + src = tmp_path / "src2" + src.mkdir(parents=True, exist_ok=True) + dest = tmp_path / "dest2" + + hf = flux_hf_module.PreTrainedFlux(str(src)) + # Should not raise + hf.save_artifacts(dest) + + # Transformer folder created but no config.json written + dest_tdir = dest / "transformer" + assert dest_tdir.is_dir() + assert not (dest_tdir / "config.json").exists() + + +def test_flux_safetensors_state_source_save_generator(monkeypatch, tmp_path): + """Test that FluxSafeTensorsStateSource.save_generator writes to transformer/ subfolder""" + parent_save_called = [] + + class FakeParentClass: + def save_generator(self, generator, output_path, strict=True): + parent_save_called.append((str(output_path), strict)) + return "success" + + # Temporarily replace SafeTensorsStateSource for testing + monkeypatch.setattr(flux_hf_module, "SafeTensorsStateSource", FakeParentClass) + + # Need to recreate the class with the new base + class TestFluxSource(FakeParentClass): + def save_generator(self, generator, output_path, strict=True): + from pathlib import Path + + target_dir = Path(output_path) / "transformer" + return super().save_generator(generator, target_dir, strict=strict) + + output_path = tmp_path / "output" + source = TestFluxSource() + result = source.save_generator(None, output_path, strict=False) + + # Verify parent's save_generator was called with transformer/ subfolder + assert len(parent_save_called) == 1 + assert parent_save_called[0][0] == str(output_path / "transformer") + assert parent_save_called[0][1] is False + assert result == "success" + + +def test_pretrained_flux_model_name_or_path_property(tmp_path): + """Test that model_name_or_path property returns the correct path""" + src = tmp_path / "model" + src.mkdir(parents=True, exist_ok=True) + + hf = flux_hf_module.PreTrainedFlux(str(src)) + assert hf.model_name_or_path == str(src) + + +def test_load_model_calls_from_pretrained(monkeypatch, tmp_path): + """Test that _load_model calls FluxTransformer2DModel.from_pretrained""" + calls = [] + + class FakeFlux: + @classmethod + def from_pretrained(cls, path, **kwargs): + calls.append(str(path)) + return FakeFlux() + + monkeypatch.setattr(flux_hf_module, "FluxTransformer2DModel", FakeFlux) + + src = tmp_path / "model" + src.mkdir(parents=True, exist_ok=True) + + hf = flux_hf_module.PreTrainedFlux(str(src)) + model = hf._load_model() + + assert len(calls) == 1 + assert calls[0] == str(src) + assert isinstance(model, FakeFlux) + + +def test_state_uses_model_state_dict_when_model_loaded(monkeypatch, tmp_path): + """Test that state property uses model's state_dict when model is already loaded""" + model_state = {"weight1": "value1"} + + class FakeModel: + def state_dict(self): + return model_state + + class FakeStateDict: + def __init__(self, source): + self.source = source + + monkeypatch.setattr(flux_hf_module, "StateDict", FakeStateDict) + + src = tmp_path / "model" + src.mkdir(parents=True, exist_ok=True) + + hf = flux_hf_module.PreTrainedFlux(str(src)) + # Manually set _model to simulate loaded model + hf._model = FakeModel() + + state = hf.state + # Should use model's state_dict, not file source + assert state.source == model_state + + +def test_save_artifacts_creates_transformer_directory(tmp_path): + """Test that save_artifacts creates transformer directory structure""" + src = tmp_path / "src" + src.mkdir(parents=True, exist_ok=True) + + dest = tmp_path / "dest" + + hf = flux_hf_module.PreTrainedFlux(str(src)) + hf.save_artifacts(dest) + + # Verify transformer directory was created + assert (dest / "transformer").is_dir() + + +def test_save_artifacts_only_copies_index_if_config_exists(tmp_path): + """Test that index file is only copied when config.json exists""" + src = tmp_path / "src" + tdir = src / "transformer" + tdir.mkdir(parents=True, exist_ok=True) + + # Only create index file, not config.json + index_src = tdir / "diffusion_pytorch_model.safetensors.index.json" + index_data = {"weight_map": {}} + index_src.write_text(json.dumps(index_data)) + + dest = tmp_path / "dest" + + hf = flux_hf_module.PreTrainedFlux(str(src)) + hf.save_artifacts(dest) + + # Index should not be copied since config.json doesn't exist + dest_tdir = dest / "transformer" + assert dest_tdir.is_dir() + assert not (dest_tdir / "diffusion_pytorch_model.safetensors.index.json").exists() diff --git a/tests/unit_tests/diffusion/model/flux/test_flux_layer_spec.py b/tests/unit_tests/diffusion/model/flux/test_flux_layer_spec.py new file mode 100644 index 0000000000..2ce03c5a43 --- /dev/null +++ b/tests/unit_tests/diffusion/model/flux/test_flux_layer_spec.py @@ -0,0 +1,226 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from megatron.core.transformer.transformer_config import TransformerConfig + +from megatron.bridge.diffusion.models.flux.flux_layer_spec import ( + AdaLNContinuous, + get_flux_double_transformer_engine_spec, + get_flux_single_transformer_engine_spec, +) + + +pytestmark = [pytest.mark.unit] + + +def test_adaln_continuous_initialization(): + """Test AdaLNContinuous module initialization.""" + hidden_size = 512 + conditioning_dim = 768 + + config = TransformerConfig( + num_layers=1, + hidden_size=hidden_size, + num_attention_heads=8, + layernorm_epsilon=1e-6, + sequence_parallel=False, + ) + + # Test with layer norm + adaln_ln = AdaLNContinuous(config=config, conditioning_embedding_dim=conditioning_dim, norm_type="layer_norm") + assert hasattr(adaln_ln, "adaLN_modulation") + assert hasattr(adaln_ln, "norm") + + # Test with RMS norm + adaln_rms = AdaLNContinuous(config=config, conditioning_embedding_dim=conditioning_dim, norm_type="rms_norm") + assert hasattr(adaln_rms, "adaLN_modulation") + assert hasattr(adaln_rms, "norm") + + +def test_adaln_continuous_invalid_norm_type(): + """Test AdaLNContinuous raises error for invalid norm type.""" + hidden_size = 512 + conditioning_dim = 768 + + config = TransformerConfig( + num_layers=1, + hidden_size=hidden_size, + num_attention_heads=8, + layernorm_epsilon=1e-6, + sequence_parallel=False, + ) + + with pytest.raises(ValueError, match="Unknown normalization type"): + AdaLNContinuous(config=config, conditioning_embedding_dim=conditioning_dim, norm_type="invalid_norm") + + +def test_adaln_continuous_forward(): + """Test AdaLNContinuous forward pass.""" + hidden_size = 512 + conditioning_dim = 768 + seq_len = 8 + batch_size = 2 + + config = TransformerConfig( + num_layers=1, + hidden_size=hidden_size, + num_attention_heads=8, + layernorm_epsilon=1e-6, + sequence_parallel=False, + ) + + adaln_continuous = AdaLNContinuous( + config=config, conditioning_embedding_dim=conditioning_dim, norm_type="layer_norm" + ) + + x = torch.randn(seq_len, batch_size, hidden_size) + conditioning_emb = torch.randn(batch_size, conditioning_dim) + + output = adaln_continuous(x, conditioning_emb) + assert output.shape == x.shape + assert not torch.isnan(output).any() + + +def test_get_flux_double_transformer_engine_spec(): + """Test get_flux_double_transformer_engine_spec returns valid ModuleSpec.""" + spec = get_flux_double_transformer_engine_spec() + + # Basic structure checks + assert hasattr(spec, "module") + assert hasattr(spec, "submodules") + + # Check module type + from megatron.bridge.diffusion.models.flux.flux_layer_spec import MMDiTLayer + + assert spec.module == MMDiTLayer + + # Check submodules exist + sub = spec.submodules + assert hasattr(sub, "self_attention") + assert hasattr(sub, "mlp") + + # Check self_attention submodules + attn_spec = sub.self_attention + assert hasattr(attn_spec, "module") + assert hasattr(attn_spec, "submodules") + assert hasattr(attn_spec, "params") + + # Verify attention mask type is no_mask + from megatron.core.transformer.enums import AttnMaskType + + assert attn_spec.params.get("attn_mask_type") == AttnMaskType.no_mask + + # Check attention submodules + attn_sub = attn_spec.submodules + for attr in [ + "q_layernorm", + "k_layernorm", + "added_q_layernorm", + "added_k_layernorm", + "linear_qkv", + "added_linear_qkv", + "core_attention", + "linear_proj", + ]: + assert hasattr(attn_sub, attr), f"Missing attention submodule: {attr}" + + # Check MLP submodules + mlp_spec = sub.mlp + assert hasattr(mlp_spec, "submodules") + mlp_sub = mlp_spec.submodules + assert hasattr(mlp_sub, "linear_fc1") + assert hasattr(mlp_sub, "linear_fc2") + + +def test_get_flux_single_transformer_engine_spec(): + """Test get_flux_single_transformer_engine_spec returns valid ModuleSpec.""" + spec = get_flux_single_transformer_engine_spec() + + # Basic structure checks + assert hasattr(spec, "module") + assert hasattr(spec, "submodules") + + # Check module type + from megatron.bridge.diffusion.models.flux.flux_layer_spec import FluxSingleTransformerBlock + + assert spec.module == FluxSingleTransformerBlock + + # Check submodules exist + sub = spec.submodules + assert hasattr(sub, "self_attention") + assert hasattr(sub, "mlp") + + # Check self_attention submodules + attn_spec = sub.self_attention + assert hasattr(attn_spec, "module") + assert hasattr(attn_spec, "submodules") + assert hasattr(attn_spec, "params") + + # Verify attention mask type is no_mask + from megatron.core.transformer.enums import AttnMaskType + + assert attn_spec.params.get("attn_mask_type") == AttnMaskType.no_mask + + # Check attention submodules + attn_sub = attn_spec.submodules + for attr in ["linear_qkv", "core_attention", "q_layernorm", "k_layernorm", "linear_proj"]: + assert hasattr(attn_sub, attr), f"Missing attention submodule: {attr}" + + # Check MLP submodules + mlp_spec = sub.mlp + assert hasattr(mlp_spec, "submodules") + mlp_sub = mlp_spec.submodules + assert hasattr(mlp_sub, "linear_fc1") + assert hasattr(mlp_sub, "linear_fc2") + + +def test_flux_double_and_single_specs_are_different(): + """Test that double and single transformer specs have different modules.""" + double_spec = get_flux_double_transformer_engine_spec() + single_spec = get_flux_single_transformer_engine_spec() + + # Should have different module types + assert double_spec.module != single_spec.module + + # Should have different attention modules + assert double_spec.submodules.self_attention.module != single_spec.submodules.self_attention.module + + +def test_adaln_continuous_with_rms_norm_forward(): + """Test AdaLNContinuous forward pass with RMS norm.""" + hidden_size = 512 + conditioning_dim = 768 + seq_len = 8 + batch_size = 2 + + config = TransformerConfig( + num_layers=1, + hidden_size=hidden_size, + num_attention_heads=8, + layernorm_epsilon=1e-6, + sequence_parallel=False, + ) + + adaln_continuous = AdaLNContinuous( + config=config, conditioning_embedding_dim=conditioning_dim, norm_type="rms_norm" + ) + + x = torch.randn(seq_len, batch_size, hidden_size) + conditioning_emb = torch.randn(batch_size, conditioning_dim) + + output = adaln_continuous(x, conditioning_emb) + assert output.shape == x.shape + assert not torch.isnan(output).any() diff --git a/tests/unit_tests/diffusion/model/flux/test_flux_layers.py b/tests/unit_tests/diffusion/model/flux/test_flux_layers.py new file mode 100644 index 0000000000..fffbb85067 --- /dev/null +++ b/tests/unit_tests/diffusion/model/flux/test_flux_layers.py @@ -0,0 +1,382 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest +import torch + +from megatron.bridge.diffusion.models.flux.layers import ( + EmbedND, + MLPEmbedder, + TimeStepEmbedder, + rope, +) + + +pytestmark = [pytest.mark.unit] + + +class TestRopeFunction: + """Test the rope function for rotary position embeddings.""" + + def test_rope_basic_shape(self): + """Test rope function returns correct shape.""" + pos = torch.tensor([[1.0, 2.0, 3.0]]) + dim = 64 + theta = 10000 + + result = rope(pos, dim, theta) + + # Output shape should be [..., dim//2] + assert result.shape == (1, 3, dim // 2) + + def test_rope_requires_even_dimension(self): + """Test that rope function requires even dimension.""" + pos = torch.tensor([[1.0, 2.0]]) + dim = 63 # Odd dimension + theta = 10000 + + with pytest.raises(AssertionError, match="The dimension must be even"): + rope(pos, dim, theta) + + def test_rope_output_is_float(self): + """Test that rope output is float32.""" + pos = torch.tensor([[1.0, 2.0]], dtype=torch.float64) + dim = 64 + theta = 10000 + + result = rope(pos, dim, theta) + + assert result.dtype == torch.float32 + + def test_rope_with_different_thetas(self): + """Test rope function with different theta values.""" + pos = torch.tensor([[1.0]]) + dim = 32 + + result_10k = rope(pos, dim, 10000) + result_5k = rope(pos, dim, 5000) + + # Different theta values should produce different results + assert not torch.allclose(result_10k, result_5k) + + def test_rope_with_zero_positions(self): + """Test rope function with zero positions.""" + pos = torch.zeros(2, 3) + dim = 64 + theta = 10000 + + result = rope(pos, dim, theta) + + # Zero positions should produce zero embeddings + assert torch.allclose(result, torch.zeros_like(result), atol=1e-6) + + def test_rope_batched_input(self): + """Test rope function with batched input.""" + batch_size = 4 + seq_len = 16 + pos = torch.randn(batch_size, seq_len) + dim = 64 + theta = 10000 + + result = rope(pos, dim, theta) + + assert result.shape == (batch_size, seq_len, dim // 2) + assert not torch.isnan(result).any() + + +class TestEmbedND: + """Test the EmbedND class for N-dimensional rotary embeddings.""" + + def test_embednd_initialization(self): + """Test EmbedND initialization.""" + dim = 128 + theta = 10000 + axes_dim = [16, 56, 56] + + embed = EmbedND(dim, theta, axes_dim) + + assert embed.dim == dim + assert embed.theta == theta + assert embed.axes_dim == axes_dim + + def test_embednd_forward_shape(self): + """Test EmbedND forward pass output shape.""" + dim = 128 + theta = 10000 + axes_dim = [16, 56, 56] + batch_size = 2 + seq_len = 100 + n_axes = 3 + + embed = EmbedND(dim, theta, axes_dim) + ids = torch.randn(batch_size, seq_len, n_axes) + + output = embed(ids) + + # Output should be reshaped and stacked + assert output.ndim == 4 + assert not torch.isnan(output).any() + + def test_embednd_with_different_axes_dims(self): + """Test EmbedND with different axes dimensions.""" + dim = 256 + theta = 10000 + axes_dim = [8, 32, 32] + + embed = EmbedND(dim, theta, axes_dim) + ids = torch.randn(1, 50, 3) + + output = embed(ids) + + assert output.ndim == 4 + assert not torch.isnan(output).any() + + def test_embednd_output_is_finite(self): + """Test that EmbedND output contains no inf values.""" + embed = EmbedND(128, 10000, [16, 56, 56]) + ids = torch.randn(2, 100, 3) + + output = embed(ids) + + assert torch.isfinite(output).all() + + +class TestMLPEmbedder: + """Test the MLPEmbedder class.""" + + def test_mlpembedder_initialization(self): + """Test MLPEmbedder initialization.""" + in_dim = 256 + hidden_dim = 512 + + embedder = MLPEmbedder(in_dim, hidden_dim) + + assert embedder.in_layer.in_features == in_dim + assert embedder.in_layer.out_features == hidden_dim + assert embedder.out_layer.in_features == hidden_dim + assert embedder.out_layer.out_features == hidden_dim + assert isinstance(embedder.silu, torch.nn.SiLU) + + def test_mlpembedder_forward_shape(self): + """Test MLPEmbedder forward pass output shape.""" + in_dim = 256 + hidden_dim = 512 + batch_size = 4 + + embedder = MLPEmbedder(in_dim, hidden_dim) + x = torch.randn(batch_size, in_dim) + + output = embedder(x) + + assert output.shape == (batch_size, hidden_dim) + + def test_mlpembedder_forward_with_different_input_shapes(self): + """Test MLPEmbedder with different input shapes.""" + embedder = MLPEmbedder(128, 256) + + # 2D input + x_2d = torch.randn(4, 128) + out_2d = embedder(x_2d) + assert out_2d.shape == (4, 256) + + # 3D input + x_3d = torch.randn(2, 8, 128) + out_3d = embedder(x_3d) + assert out_3d.shape == (2, 8, 256) + + def test_mlpembedder_output_is_finite(self): + """Test that MLPEmbedder output is finite.""" + embedder = MLPEmbedder(256, 512) + x = torch.randn(10, 256) + + output = embedder(x) + + assert torch.isfinite(output).all() + + def test_mlpembedder_has_bias(self): + """Test that MLPEmbedder layers have bias.""" + embedder = MLPEmbedder(256, 512) + + assert embedder.in_layer.bias is not None + assert embedder.out_layer.bias is not None + + +class TestTimeStepEmbedder: + """Test the TimeStepEmbedder class.""" + + def test_timestepembedder_initialization(self): + """Test TimeStepEmbedder initialization.""" + embedding_dim = 256 + hidden_dim = 512 + + embedder = TimeStepEmbedder(embedding_dim, hidden_dim) + + assert isinstance(embedder.time_embedder, MLPEmbedder) + assert embedder.time_proj.num_channels == embedding_dim + + def test_timestepembedder_forward_shape(self): + """Test TimeStepEmbedder forward pass output shape.""" + embedding_dim = 256 + hidden_dim = 512 + batch_size = 4 + + embedder = TimeStepEmbedder(embedding_dim, hidden_dim) + timesteps = torch.tensor([0.0, 100.0, 500.0, 999.0]) + + output = embedder(timesteps) + + assert output.shape == (batch_size, hidden_dim) + + def test_timestepembedder_with_custom_params(self): + """Test TimeStepEmbedder with custom parameters.""" + embedder = TimeStepEmbedder( + embedding_dim=128, + hidden_dim=256, + flip_sin_to_cos=False, + downscale_freq_shift=1.0, + scale=2.0, + max_period=5000, + ) + + timesteps = torch.tensor([100.0, 200.0]) + output = embedder(timesteps) + + assert output.shape == (2, 256) + assert torch.isfinite(output).all() + + def test_timestepembedder_output_is_finite(self): + """Test that TimeStepEmbedder output is finite.""" + embedder = TimeStepEmbedder(256, 512) + timesteps = torch.randn(10).abs() * 1000 + + output = embedder(timesteps) + + assert torch.isfinite(output).all() + + def test_timestepembedder_is_nn_module(self): + """Test that TimeStepEmbedder is an nn.Module.""" + embedder = TimeStepEmbedder(256, 512) + + assert isinstance(embedder, torch.nn.Module) + + def test_timestepembedder_components_connected(self): + """Test that TimeStepEmbedder components are properly connected.""" + embedder = TimeStepEmbedder(256, 512) + timesteps = torch.tensor([100.0]) + + # Forward through time_proj + proj_output = embedder.time_proj(timesteps) + assert proj_output.shape == (1, 256) + + # Forward through time_embedder + emb_output = embedder.time_embedder(proj_output) + assert emb_output.shape == (1, 512) + + # Full forward pass + full_output = embedder(timesteps) + assert torch.allclose(full_output, emb_output, atol=1e-6) + + +class TestLayersIntegration: + """Integration tests for layers module.""" + + def test_rope_to_embednd_pipeline(self): + """Test using rope in EmbedND.""" + embed = EmbedND(dim=128, theta=10000, axes_dim=[16, 56, 56]) + ids = torch.randn(2, 50, 3) + + # This internally uses rope function + output = embed(ids) + + assert output.ndim == 4 + assert torch.isfinite(output).all() + + def test_timestep_embedding_pipeline(self): + """Test complete timestep embedding pipeline.""" + # Create embedder + embedder = TimeStepEmbedder(embedding_dim=256, hidden_dim=3072) + + # Generate timesteps + timesteps = torch.linspace(0, 1000, 10) + + # Get embeddings + embeddings = embedder(timesteps) + + assert embeddings.shape == (10, 3072) + assert torch.isfinite(embeddings).all() + # Different timesteps should produce different embeddings + assert not torch.allclose(embeddings[0], embeddings[-1]) + + def test_mlp_embedder_in_timestep_embedder(self): + """Test that MLPEmbedder works correctly within TimeStepEmbedder.""" + embedder = TimeStepEmbedder(128, 256) + + timesteps = torch.tensor([0.0, 500.0, 1000.0]) + output = embedder(timesteps) + + # Should pass through both time_proj and time_embedder + assert output.shape == (3, 256) + + +class TestLayersEdgeCases: + """Test edge cases for layers module.""" + + def test_rope_with_large_positions(self): + """Test rope function with large position values.""" + pos = torch.tensor([[1000.0, 2000.0, 3000.0]]) + dim = 64 + theta = 10000 + + result = rope(pos, dim, theta) + + assert torch.isfinite(result).all() + + def test_embednd_with_small_batch(self): + """Test EmbedND with batch size 1.""" + embed = EmbedND(64, 10000, [8, 16, 16]) + ids = torch.randn(1, 10, 3) + + output = embed(ids) + + assert output.shape[1] == 1 # Batch dimension + + def test_mlpembedder_with_zero_input(self): + """Test MLPEmbedder with zero input.""" + embedder = MLPEmbedder(256, 512) + x = torch.zeros(4, 256) + + output = embedder(x) + + assert output.shape == (4, 512) + assert torch.isfinite(output).all() + + def test_timestepembedder_with_very_small_timesteps(self): + """Test TimeStepEmbedder with very small fractional timesteps.""" + embedder = TimeStepEmbedder(256, 512) + timesteps = torch.tensor([0.001, 0.01, 0.1]) + + output = embedder(timesteps) + + assert output.shape == (3, 512) + assert torch.isfinite(output).all() + + def test_rope_different_dimensions(self): + """Test rope with various even dimensions.""" + pos = torch.tensor([[1.0, 2.0]]) + + for dim in [32, 64, 128, 256]: + result = rope(pos, dim, 10000) + assert result.shape[-1] == dim // 2 + assert torch.isfinite(result).all() diff --git a/tests/unit_tests/diffusion/model/flux/test_flux_pipeline.py b/tests/unit_tests/diffusion/model/flux/test_flux_pipeline.py new file mode 100644 index 0000000000..5588c0cf8f --- /dev/null +++ b/tests/unit_tests/diffusion/model/flux/test_flux_pipeline.py @@ -0,0 +1,505 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +import torch + +from megatron.bridge.diffusion.models.flux.flow_matching.flux_inference_pipeline import ( + FlowMatchEulerDiscreteScheduler, + FluxInferencePipeline, +) + + +pytestmark = [pytest.mark.unit] + + +class TestFlowMatchEulerDiscreteScheduler: + """Test class for FlowMatchEulerDiscreteScheduler.""" + + def test_scheduler_initialization_default(self): + """Test scheduler initialization with default parameters.""" + scheduler = FlowMatchEulerDiscreteScheduler() + + assert scheduler.num_train_timesteps == 1000 + assert scheduler.shift == 1.0 + assert scheduler.use_dynamic_shifting is False + assert scheduler.timesteps is not None + assert scheduler.sigmas is not None + assert scheduler._step_index is None + assert scheduler._begin_index is None + + def test_scheduler_initialization_custom(self): + """Test scheduler initialization with custom parameters.""" + num_timesteps = 500 + shift = 2.0 + scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=num_timesteps, shift=shift) + + assert scheduler.num_train_timesteps == num_timesteps + assert scheduler.shift == shift + + def test_scheduler_initialization_with_dynamic_shifting(self): + """Test scheduler initialization with dynamic shifting enabled.""" + scheduler = FlowMatchEulerDiscreteScheduler( + use_dynamic_shifting=True, base_shift=0.5, max_shift=1.15, base_image_seq_len=256, max_image_seq_len=4096 + ) + + assert scheduler.use_dynamic_shifting is True + assert scheduler.base_shift == 0.5 + assert scheduler.max_shift == 1.15 + assert scheduler.base_image_seq_len == 256 + assert scheduler.max_image_seq_len == 4096 + + def test_scheduler_timesteps_shape(self): + """Test that timesteps have correct shape.""" + num_timesteps = 1000 + scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=num_timesteps) + + assert scheduler.timesteps.shape[0] == num_timesteps + assert scheduler.sigmas.shape[0] == num_timesteps + + def test_scheduler_sigma_range(self): + """Test that sigmas are in valid range.""" + scheduler = FlowMatchEulerDiscreteScheduler() + + assert (scheduler.sigmas >= 0).all() + assert (scheduler.sigmas <= 1).all() + assert scheduler.sigma_min >= 0 + assert scheduler.sigma_max <= 1 + assert scheduler.sigma_min <= scheduler.sigma_max + + def test_set_begin_index(self): + """Test set_begin_index method.""" + scheduler = FlowMatchEulerDiscreteScheduler() + + scheduler.set_begin_index(10) + assert scheduler.begin_index == 10 + + scheduler.set_begin_index(0) + assert scheduler.begin_index == 0 + + def test_step_index_property(self): + """Test step_index property.""" + scheduler = FlowMatchEulerDiscreteScheduler() + + assert scheduler.step_index is None + + # After setting internally + scheduler._step_index = 5 + assert scheduler.step_index == 5 + + def test_set_timesteps_basic(self): + """Test set_timesteps with basic parameters.""" + scheduler = FlowMatchEulerDiscreteScheduler() + num_inference_steps = 50 + + scheduler.set_timesteps(num_inference_steps=num_inference_steps, device="cpu") + + assert scheduler.num_inference_steps == num_inference_steps + assert len(scheduler.timesteps) == num_inference_steps + assert scheduler.timesteps.device.type == "cpu" + + def test_set_timesteps_with_custom_sigmas(self): + """Test set_timesteps with custom sigmas.""" + import numpy as np + + scheduler = FlowMatchEulerDiscreteScheduler() + custom_sigmas = np.array([1.0, 0.75, 0.5, 0.25, 0.0]) + + scheduler.set_timesteps(sigmas=custom_sigmas, device="cpu") + + assert len(scheduler.timesteps) == len(custom_sigmas) + + def test_set_timesteps_with_dynamic_shifting(self): + """Test set_timesteps with dynamic shifting.""" + scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=True) + num_inference_steps = 20 + mu = 0.5 + + scheduler.set_timesteps(num_inference_steps=num_inference_steps, device="cpu", mu=mu) + + assert len(scheduler.timesteps) == num_inference_steps + + def test_set_timesteps_dynamic_shifting_requires_mu(self): + """Test that dynamic shifting requires mu parameter.""" + scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=True) + + with pytest.raises(ValueError, match="you have a pass a value for `mu`"): + scheduler.set_timesteps(num_inference_steps=10, device="cpu") + + def test_scale_noise(self): + """Test scale_noise method.""" + scheduler = FlowMatchEulerDiscreteScheduler() + scheduler.set_timesteps(num_inference_steps=10, device="cpu") + batch_size = 2 + channels = 4 + height = width = 8 + + sample = torch.randn(batch_size, channels, height, width) + noise = torch.randn_like(sample) + # timestep must be a 1-d tensor (batch of timesteps) + timestep = scheduler.timesteps[0:1].repeat(batch_size) + + noisy_sample = scheduler.scale_noise(sample, timestep, noise) + + assert noisy_sample.shape == sample.shape + assert not torch.isnan(noisy_sample).any() + assert torch.isfinite(noisy_sample).all() + + def test_index_for_timestep(self): + """Test index_for_timestep method.""" + scheduler = FlowMatchEulerDiscreteScheduler() + + # Get a timestep from the schedule + timestep = scheduler.timesteps[10] + index = scheduler.index_for_timestep(timestep) + + # Index should be valid + assert 0 <= index < len(scheduler.timesteps) + + def test_step_basic(self): + """Test step method basic functionality.""" + scheduler = FlowMatchEulerDiscreteScheduler() + scheduler.set_timesteps(num_inference_steps=10, device="cpu") + + batch_size = 2 + channels = 4 + height = width = 8 + + sample = torch.randn(batch_size, channels, height, width) + model_output = torch.randn_like(sample) + timestep = scheduler.timesteps[0] + + prev_sample = scheduler.step(model_output, timestep, sample)[0] + + assert prev_sample.shape == sample.shape + assert not torch.isnan(prev_sample).any() + + def test_step_increments_step_index(self): + """Test that step method increments step_index.""" + scheduler = FlowMatchEulerDiscreteScheduler() + scheduler.set_timesteps(num_inference_steps=10, device="cpu") + + sample = torch.randn(2, 4, 8, 8) + model_output = torch.randn_like(sample) + + assert scheduler.step_index is None + + scheduler.step(model_output, scheduler.timesteps[0], sample) + assert scheduler.step_index == 1 + + scheduler.step(model_output, scheduler.timesteps[1], sample) + assert scheduler.step_index == 2 + + def test_step_rejects_integer_timesteps(self): + """Test that step method rejects integer timesteps.""" + scheduler = FlowMatchEulerDiscreteScheduler() + scheduler.set_timesteps(num_inference_steps=10, device="cpu") + + sample = torch.randn(2, 4, 8, 8) + model_output = torch.randn_like(sample) + + with pytest.raises(ValueError, match="Passing integer indices"): + scheduler.step(model_output, 5, sample) + + def test_scheduler_length(self): + """Test __len__ method.""" + num_timesteps = 500 + scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=num_timesteps) + + assert len(scheduler) == num_timesteps + + def test_time_shift_method(self): + """Test time_shift method.""" + scheduler = FlowMatchEulerDiscreteScheduler() + mu = 0.5 + sigma = 1.0 + t = torch.tensor([0.1, 0.5, 0.9]) + + shifted = scheduler.time_shift(mu, sigma, t) + + assert shifted.shape == t.shape + assert not torch.isnan(shifted).any() + + +class TestFluxInferencePipelineStaticMethods: + """Test static methods of FluxInferencePipeline.""" + + def test_prepare_latent_image_ids(self): + """Test _prepare_latent_image_ids static method.""" + batch_size = 2 + height = 64 + width = 64 + device = torch.device("cpu") + dtype = torch.float32 + + latent_ids = FluxInferencePipeline._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + expected_seq_len = (height // 2) * (width // 2) + assert latent_ids.shape == (batch_size, expected_seq_len, 3) + assert latent_ids.device == device + assert latent_ids.dtype == dtype + + def test_pack_latents(self): + """Test _pack_latents static method.""" + batch_size = 2 + num_channels = 16 + height = 64 + width = 64 + + latents = torch.randn(batch_size, num_channels, height, width) + packed = FluxInferencePipeline._pack_latents(latents, batch_size, num_channels, height, width) + + expected_seq_len = (height // 2) * (width // 2) + expected_channels = num_channels * 4 + assert packed.shape == (batch_size, expected_seq_len, expected_channels) + + def test_unpack_latents(self): + """Test _unpack_latents static method.""" + batch_size = 2 + height = 512 + width = 512 + vae_scale_factor = 16 + channels = 64 + # num_patches = (height // vae_scale_factor) * (width // vae_scale_factor) + num_patches = (height // vae_scale_factor) * (width // vae_scale_factor) # 32 * 32 = 1024 + + packed_latents = torch.randn(batch_size, num_patches, channels) + unpacked = FluxInferencePipeline._unpack_latents(packed_latents, height, width, vae_scale_factor) + + expected_height = (height // vae_scale_factor) * 2 + expected_width = (width // vae_scale_factor) * 2 + expected_channels = channels // 4 + assert unpacked.shape == (batch_size, expected_channels, expected_height, expected_width) + + def test_calculate_shift(self): + """Test _calculate_shift static method.""" + # Test with default parameters + image_seq_len = 256 + shift = FluxInferencePipeline._calculate_shift(image_seq_len) + + assert isinstance(shift, (int, float)) + assert shift >= 0.5 # Should be at least base_shift + + # Test with larger sequence length + large_seq_len = 4096 + large_shift = FluxInferencePipeline._calculate_shift(large_seq_len) + + # Larger sequence length should give larger shift + assert large_shift >= shift + + def test_numpy_to_pil(self): + """Test numpy_to_pil static method.""" + # Test single image + single_image = np.random.rand(256, 256, 3) + pil_images = FluxInferencePipeline.numpy_to_pil(single_image) + + assert len(pil_images) == 1 + from PIL import Image + + assert isinstance(pil_images[0], Image.Image) + + # Test batch of images + batch_images = np.random.rand(4, 256, 256, 3) + pil_batch = FluxInferencePipeline.numpy_to_pil(batch_images) + + assert len(pil_batch) == 4 + assert all(isinstance(img, Image.Image) for img in pil_batch) + + def test_torch_to_numpy(self): + """Test torch_to_numpy static method.""" + batch_size = 4 + channels = 3 + height = width = 256 + + torch_images = torch.randn(batch_size, channels, height, width) + numpy_images = FluxInferencePipeline.torch_to_numpy(torch_images) + + assert numpy_images.shape == (batch_size, height, width, channels) + assert isinstance(numpy_images, np.ndarray) + + def test_denormalize(self): + """Test denormalize static method.""" + # Create images in range [-1, 1] + images = torch.randn(4, 3, 256, 256) * 2 - 1 + denorm = FluxInferencePipeline.denormalize(images) + + # Should be clamped to [0, 1] + assert (denorm >= 0).all() + assert (denorm <= 1).all() + + +class TestFluxInferencePipelineHelperMethods: + """Test helper methods of FluxInferencePipeline.""" + + def test_prepare_latents_shape(self, monkeypatch): + """Test prepare_latents method output shape.""" + + # Mock dependencies to avoid loading models + def mock_setup(self, checkpoint_dir): + class MockTransformer: + in_channels = 64 + + return MockTransformer() + + def mock_load_text(self, t5, clip): + pass + + def mock_load_vae(self, vae): + pass + + monkeypatch.setattr(FluxInferencePipeline, "setup_model_from_checkpoint", mock_setup) + monkeypatch.setattr(FluxInferencePipeline, "load_text_encoders", mock_load_text) + monkeypatch.setattr(FluxInferencePipeline, "load_vae", mock_load_vae) + + pipeline = FluxInferencePipeline() + + batch_size = 2 + num_channels = 16 + height = 512 + width = 512 + dtype = torch.float32 + device = torch.device("cpu") + + latents, latent_ids = pipeline.prepare_latents( + batch_size, num_channels, height, width, dtype, device, generator=None + ) + + # Check shapes + assert latents.ndim == 3 + assert latent_ids.shape[0] == batch_size + assert latent_ids.shape[2] == 3 # 3D position IDs + + +class TestFluxInferencePipelineIntegration: + """Integration tests for FluxInferencePipeline (without actual model loading).""" + + def test_pipeline_initialization_attributes(self, monkeypatch): + """Test that pipeline sets up basic attributes.""" + + def mock_setup(self, checkpoint_dir): + class MockTransformer: + in_channels = 64 + guidance_embed = False + + return MockTransformer() + + def mock_load_text(self, t5, clip): + self.t5_encoder = None + self.clip_encoder = None + + def mock_load_vae(self, vae): + self.vae = None + + monkeypatch.setattr(FluxInferencePipeline, "setup_model_from_checkpoint", mock_setup) + monkeypatch.setattr(FluxInferencePipeline, "load_text_encoders", mock_load_text) + monkeypatch.setattr(FluxInferencePipeline, "load_vae", mock_load_vae) + + pipeline = FluxInferencePipeline(scheduler_steps=500) + + assert pipeline.device == "cuda:0" + assert pipeline.vae_scale_factor == 16 + assert hasattr(pipeline, "scheduler") + assert isinstance(pipeline.scheduler, FlowMatchEulerDiscreteScheduler) + assert pipeline.scheduler.num_train_timesteps == 500 + + def test_pack_unpack_latents_roundtrip(self): + """Test that pack and unpack latents are inverse operations.""" + batch_size = 2 + num_channels = 16 + height = 64 + width = 64 + vae_scale_factor = 16 + # Original image dimensions: 2 * original_h // vae_scale_factor = height + # So: original_h = height * vae_scale_factor // 2 + original_height = height * vae_scale_factor // 2 + original_width = width * vae_scale_factor // 2 + + # Original latents + original = torch.randn(batch_size, num_channels, height, width) + + # Pack + packed = FluxInferencePipeline._pack_latents(original, batch_size, num_channels, height, width) + + # Unpack - should restore original dimensions + unpacked = FluxInferencePipeline._unpack_latents(packed, original_height, original_width, vae_scale_factor) + + assert unpacked.shape[0] == batch_size + assert unpacked.shape[1] == num_channels + assert unpacked.shape == original.shape + + def test_scheduler_step_sequence(self): + """Test a sequence of scheduler steps.""" + scheduler = FlowMatchEulerDiscreteScheduler() + scheduler.set_timesteps(num_inference_steps=5, device="cpu") + + sample = torch.randn(1, 4, 32, 32) + + for i, timestep in enumerate(scheduler.timesteps[:-1]): # Exclude last to avoid index error + model_output = torch.randn_like(sample) + sample = scheduler.step(model_output, timestep, sample)[0] + + assert not torch.isnan(sample).any() + assert torch.isfinite(sample).all() + + +class TestFluxInferencePipelineEdgeCases: + """Test edge cases and error handling.""" + + def test_prepare_latent_image_ids_small_dimensions(self): + """Test _prepare_latent_image_ids with small dimensions.""" + latent_ids = FluxInferencePipeline._prepare_latent_image_ids(1, 4, 4, torch.device("cpu"), torch.float32) + + assert latent_ids.shape == (1, 4, 3) + + def test_calculate_shift_boundary_values(self): + """Test _calculate_shift with boundary sequence lengths.""" + base_seq_len = 256 + max_seq_len = 4096 + + # Test at base + shift_base = FluxInferencePipeline._calculate_shift(base_seq_len) + assert shift_base >= 0 + + # Test at max + shift_max = FluxInferencePipeline._calculate_shift(max_seq_len) + assert shift_max >= shift_base + + # Test below base + shift_below = FluxInferencePipeline._calculate_shift(128) + assert shift_below >= 0 + + # Test above max + shift_above = FluxInferencePipeline._calculate_shift(8192) + assert shift_above >= 0 + + def test_denormalize_extreme_values(self): + """Test denormalize with extreme values.""" + # Very negative values + extreme_neg = torch.full((2, 3, 64, 64), -10.0) + denorm_neg = FluxInferencePipeline.denormalize(extreme_neg) + assert (denorm_neg == 0).all() + + # Very positive values + extreme_pos = torch.full((2, 3, 64, 64), 10.0) + denorm_pos = FluxInferencePipeline.denormalize(extreme_pos) + assert (denorm_pos == 1).all() + + def test_scheduler_sigma_ordering(self): + """Test that scheduler sigmas are in descending order.""" + scheduler = FlowMatchEulerDiscreteScheduler() + + # Sigmas should generally decrease (though not strictly due to shifting) + # Just check first and last + assert scheduler.sigmas[0] >= scheduler.sigmas[-1] diff --git a/tests/unit_tests/diffusion/model/flux/test_flux_provider.py b/tests/unit_tests/diffusion/model/flux/test_flux_provider.py new file mode 100644 index 0000000000..c38f04f878 --- /dev/null +++ b/tests/unit_tests/diffusion/model/flux/test_flux_provider.py @@ -0,0 +1,295 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from megatron.core import parallel_state + +from megatron.bridge.diffusion.models.flux.flux_provider import FluxProvider + + +pytestmark = [pytest.mark.unit] + + +def _mock_parallel_state(monkeypatch): + """Mock parallel_state functions to avoid initialization requirements.""" + monkeypatch.setattr(parallel_state, "is_pipeline_first_stage", lambda: True, raising=False) + monkeypatch.setattr(parallel_state, "is_pipeline_last_stage", lambda: True, raising=False) + monkeypatch.setattr(parallel_state, "get_tensor_model_parallel_world_size", lambda: 1, raising=False) + monkeypatch.setattr(parallel_state, "get_pipeline_model_parallel_world_size", lambda: 1, raising=False) + monkeypatch.setattr( + parallel_state, "get_data_parallel_world_size", lambda with_context_parallel=False: 1, raising=False + ) + monkeypatch.setattr(parallel_state, "get_context_parallel_world_size", lambda: 1, raising=False) + monkeypatch.setattr(parallel_state, "get_tensor_model_parallel_group", lambda **kwargs: None, raising=False) + monkeypatch.setattr(parallel_state, "get_data_parallel_group", lambda **kwargs: None, raising=False) + monkeypatch.setattr(parallel_state, "get_context_parallel_group", lambda **kwargs: None, raising=False) + monkeypatch.setattr(parallel_state, "get_tensor_and_data_parallel_group", lambda **kwargs: None, raising=False) + + +def test_flux_provider_initialization_defaults(): + """Test FluxProvider initialization with default values.""" + provider = FluxProvider() + + # Base class requirements + assert provider.num_layers == 1 # Dummy setting + assert provider.hidden_size == 3072 + assert provider.ffn_hidden_size == 12288 + assert provider.num_attention_heads == 24 + assert provider.layernorm_epsilon == 1e-06 + assert provider.hidden_dropout == 0 + assert provider.attention_dropout == 0 + + # FLUX-specific layer configuration + assert provider.num_joint_layers == 19 + assert provider.num_single_layers == 38 + + # Model architecture + assert provider.add_qkv_bias is True + assert provider.in_channels == 64 + assert provider.context_dim == 4096 + assert provider.model_channels == 256 + assert provider.axes_dims_rope == [16, 56, 56] + assert provider.patch_size == 1 + assert provider.guidance_embed is False + assert provider.vec_in_dim == 768 + + +def test_flux_provider_initialization_custom(): + """Test FluxProvider initialization with custom values.""" + provider = FluxProvider( + hidden_size=2048, + num_attention_heads=16, + kv_channels=128, # 2048 // 16 + num_query_groups=16, # Same as num_attention_heads (no GQA) + num_joint_layers=10, + num_single_layers=20, + in_channels=32, + guidance_embed=True, + ) + + assert provider.hidden_size == 2048 + assert provider.num_attention_heads == 16 + assert provider.num_joint_layers == 10 + assert provider.num_single_layers == 20 + assert provider.in_channels == 32 + assert provider.guidance_embed is True + + +def test_flux_provider_rotary_embedding_settings(): + """Test FluxProvider rotary embedding settings.""" + provider = FluxProvider() + + assert provider.rotary_interleaved is True + assert provider.apply_rope_fusion is False + + +def test_flux_provider_initialization_settings(): + """Test FluxProvider initialization and performance settings.""" + provider = FluxProvider() + + assert provider.use_cpu_initialization is True + assert provider.gradient_accumulation_fusion is False + assert provider.enable_cuda_graph is False + assert provider.cuda_graph_scope is None + assert provider.use_te_rng_tracker is False + assert provider.cuda_graph_warmup_steps == 2 + + +def test_flux_provider_inference_settings(): + """Test FluxProvider inference settings.""" + provider = FluxProvider() + + assert provider.guidance_scale == 3.5 + + +def test_flux_provider_checkpoint_settings(): + """Test FluxProvider checkpoint loading settings.""" + provider = FluxProvider() + + assert provider.ckpt_path is None + assert provider.load_dist_ckpt is False + assert provider.do_convert_from_hf is False + assert provider.save_converted_model_to is None + + +def test_flux_provider_llm_compatibility_attributes(): + """Test FluxProvider has attributes for LLM compatibility.""" + provider = FluxProvider() + + # These attributes are unused for images/videos but required by bridge training for LLMs + assert provider.seq_length == 1024 + assert provider.share_embeddings_and_output_weights is False + assert provider.vocab_size == 25256 * 8 + assert provider.make_vocab_size_divisible_by == 128 + + +def test_flux_provider_virtual_pipeline_validation(): + """Test that FluxProvider validates virtual pipeline configuration.""" + provider = FluxProvider( + num_joint_layers=12, + num_single_layers=24, + virtual_pipeline_model_parallel_size=2, + pipeline_model_parallel_size=3, + ) + + total_layers = provider.num_joint_layers + provider.num_single_layers + p_size = provider.pipeline_model_parallel_size + vp_size = provider.virtual_pipeline_model_parallel_size + + # Should satisfy: (total_layers // p_size) % vp_size == 0 + # (36 // 3) % 2 == 12 % 2 == 0 ✓ + assert (total_layers // p_size) % vp_size == 0 + + +def test_flux_provider_virtual_pipeline_validation_fails(monkeypatch): + """Test that FluxProvider raises assertion error for invalid virtual pipeline configuration.""" + _mock_parallel_state(monkeypatch) + + provider = FluxProvider( + num_joint_layers=10, + num_single_layers=20, + virtual_pipeline_model_parallel_size=3, + pipeline_model_parallel_size=2, + ) + + # Should fail: (30 // 2) % 3 == 15 % 3 == 0 ✓ (actually this passes) + # Let's create a failing case: (10 // 2) % 3 == 5 % 3 == 2 ≠ 0 + provider.num_joint_layers = 5 + provider.num_single_layers = 5 + + with pytest.raises(AssertionError, match="Make sure the number of model chunks is the same"): + provider.provide() + + +def test_flux_provider_axes_dims_rope_field(): + """Test that axes_dims_rope field factory works correctly.""" + provider1 = FluxProvider() + provider2 = FluxProvider() + + # Should have default values + assert provider1.axes_dims_rope == [16, 56, 56] + assert provider2.axes_dims_rope == [16, 56, 56] + + # Should be independent instances (not sharing same list) + provider1.axes_dims_rope[0] = 32 + assert provider2.axes_dims_rope[0] == 16 # Should not be affected + + +def test_flux_provider_custom_axes_dims_rope(): + """Test FluxProvider with custom axes_dims_rope.""" + custom_axes = [8, 32, 32] + provider = FluxProvider(axes_dims_rope=custom_axes) + + assert provider.axes_dims_rope == custom_axes + + +def test_flux_provider_activation_func_default(): + """Test that FluxProvider has default activation function.""" + provider = FluxProvider() + + from megatron.core.transformer.utils import openai_gelu + + assert provider.activation_func == openai_gelu + + +def test_flux_provider_custom_checkpoint_settings(): + """Test FluxProvider with custom checkpoint settings.""" + provider = FluxProvider( + ckpt_path="/path/to/checkpoint", + load_dist_ckpt=True, + do_convert_from_hf=True, + save_converted_model_to="/path/to/save", + ) + + assert provider.ckpt_path == "/path/to/checkpoint" + assert provider.load_dist_ckpt is True + assert provider.do_convert_from_hf is True + assert provider.save_converted_model_to == "/path/to/save" + + +def test_flux_provider_cuda_graph_settings(): + """Test FluxProvider CUDA graph settings.""" + provider = FluxProvider(enable_cuda_graph=True, cuda_graph_scope="full", cuda_graph_warmup_steps=5) + + assert provider.enable_cuda_graph is True + assert provider.cuda_graph_scope == "full" + assert provider.cuda_graph_warmup_steps == 5 + + +def test_flux_provider_is_transformer_config(): + """Test that FluxProvider is a TransformerConfig.""" + from megatron.bridge.models.transformer_config import TransformerConfig + + provider = FluxProvider() + + assert isinstance(provider, TransformerConfig) + + +def test_flux_provider_is_model_provider_mixin(): + """Test that FluxProvider is a ModelProviderMixin.""" + from megatron.bridge.models.model_provider import ModelProviderMixin + + provider = FluxProvider() + + assert isinstance(provider, ModelProviderMixin) + + +def test_flux_provider_has_provide_method(): + """Test that FluxProvider has provide method.""" + provider = FluxProvider() + + assert hasattr(provider, "provide") + assert callable(provider.provide) + + +def test_flux_provider_dtype_settings(): + """Test FluxProvider data type settings.""" + provider = FluxProvider(bf16=True, params_dtype=torch.bfloat16) + + assert provider.bf16 is True + assert provider.params_dtype == torch.bfloat16 + + +def test_flux_provider_parallel_settings(): + """Test FluxProvider parallel configuration settings.""" + provider = FluxProvider(tensor_model_parallel_size=2, pipeline_model_parallel_size=4, sequence_parallel=True) + + assert provider.tensor_model_parallel_size == 2 + assert provider.pipeline_model_parallel_size == 4 + assert provider.sequence_parallel is True + + +def test_flux_provider_num_layers_is_dummy(): + """Test that num_layers is a dummy value and not used for layer count.""" + provider = FluxProvider() + + # num_layers is set to 1 (dummy) but actual layers are controlled by: + assert provider.num_layers == 1 + assert provider.num_joint_layers == 19 # Actual double block count + assert provider.num_single_layers == 38 # Actual single block count + + +def test_flux_provider_default_guidance_scale(): + """Test that guidance_scale has correct default value.""" + provider = FluxProvider() + + assert provider.guidance_scale == 3.5 + + +def test_flux_provider_custom_guidance_scale(): + """Test FluxProvider with custom guidance_scale.""" + provider = FluxProvider(guidance_scale=7.5) + + assert provider.guidance_scale == 7.5 diff --git a/tests/unit_tests/diffusion/model/flux/test_flux_step.py b/tests/unit_tests/diffusion/model/flux/test_flux_step.py new file mode 100644 index 0000000000..4deefa8356 --- /dev/null +++ b/tests/unit_tests/diffusion/model/flux/test_flux_step.py @@ -0,0 +1,438 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from unittest.mock import MagicMock + +import pytest +import torch + +from megatron.bridge.diffusion.models.flux.flux_step import FluxForwardStep, flux_data_step + + +pytestmark = [pytest.mark.unit] + + +@pytest.mark.run_only_on("GPU") +class TestFluxDataStep: + """Test flux_data_step function.""" + + def test_flux_data_step_basic(self): + """Test basic flux_data_step functionality.""" + # Create mock iterator + batch = {"latents": torch.randn(2, 16, 64, 64), "prompt_embeds": torch.randn(2, 512, 4096)} + dataloader_iter = iter([batch]) + + result = flux_data_step(dataloader_iter) + + assert "latents" in result + assert "prompt_embeds" in result + assert "loss_mask" in result + assert result["loss_mask"].device.type == "cuda" + + def test_flux_data_step_with_tuple_input(self): + """Test flux_data_step with tuple input from dataloader.""" + batch = {"latents": torch.randn(2, 16, 64, 64)} + dataloader_iter = iter([(batch, None, None)]) + + result = flux_data_step(dataloader_iter) + + assert "latents" in result + assert "loss_mask" in result + + def test_flux_data_step_preserves_loss_mask(self): + """Test that existing loss_mask is preserved.""" + custom_loss_mask = torch.ones(2) + batch = {"latents": torch.randn(2, 16, 64, 64), "loss_mask": custom_loss_mask} + dataloader_iter = iter([batch]) + + result = flux_data_step(dataloader_iter) + + assert torch.equal(result["loss_mask"].cpu(), custom_loss_mask) + + def test_flux_data_step_creates_default_loss_mask(self): + """Test that default loss_mask is created when missing.""" + batch = {"latents": torch.randn(2, 16, 64, 64)} + dataloader_iter = iter([batch]) + + result = flux_data_step(dataloader_iter) + + assert "loss_mask" in result + assert result["loss_mask"].shape == (1,) + assert torch.all(result["loss_mask"] == 1.0) + + def test_flux_data_step_moves_tensors_to_cuda(self): + """Test that tensors are moved to CUDA.""" + batch = { + "latents": torch.randn(2, 16, 64, 64), + "prompt_embeds": torch.randn(2, 512, 4096), + "non_tensor": "text", + } + dataloader_iter = iter([batch]) + + result = flux_data_step(dataloader_iter) + + assert result["latents"].device.type == "cuda" + assert result["prompt_embeds"].device.type == "cuda" + assert result["non_tensor"] == "text" # Non-tensors unchanged + + +class TestFluxForwardStepInitialization: + """Test FluxForwardStep initialization.""" + + def test_initialization_defaults(self): + """Test FluxForwardStep initialization with default values.""" + step = FluxForwardStep() + + assert step.timestep_sampling == "logit_normal" + assert step.logit_mean == 0.0 + assert step.logit_std == 1.0 + assert step.mode_scale == 1.29 + assert step.scheduler_steps == 1000 + assert step.guidance_scale == 3.5 + assert step.autocast_dtype == torch.bfloat16 + assert hasattr(step, "scheduler") + + def test_initialization_custom(self): + """Test FluxForwardStep initialization with custom values.""" + step = FluxForwardStep( + timestep_sampling="uniform", + logit_mean=1.0, + logit_std=2.0, + mode_scale=1.5, + scheduler_steps=500, + guidance_scale=7.5, + ) + + assert step.timestep_sampling == "uniform" + assert step.logit_mean == 1.0 + assert step.logit_std == 2.0 + assert step.mode_scale == 1.5 + assert step.scheduler_steps == 500 + assert step.guidance_scale == 7.5 + + +class TestFluxForwardStepTimestepSampling: + """Test timestep sampling methods.""" + + def test_compute_density_logit_normal(self): + """Test logit-normal timestep sampling.""" + step = FluxForwardStep(timestep_sampling="logit_normal", logit_mean=0.0, logit_std=1.0) + batch_size = 10 + + u = step.compute_density_for_timestep_sampling("logit_normal", batch_size) + + assert u.shape == (batch_size,) + assert (u >= 0).all() + assert (u <= 1).all() + + def test_compute_density_mode(self): + """Test mode-based timestep sampling.""" + step = FluxForwardStep(timestep_sampling="mode", mode_scale=1.29) + batch_size = 10 + + u = step.compute_density_for_timestep_sampling("mode", batch_size) + + assert u.shape == (batch_size,) + assert (u >= 0).all() + assert (u <= 1).all() + + def test_compute_density_uniform(self): + """Test uniform timestep sampling.""" + step = FluxForwardStep(timestep_sampling="uniform") + batch_size = 10 + + u = step.compute_density_for_timestep_sampling("uniform", batch_size) + + assert u.shape == (batch_size,) + assert (u >= 0).all() + assert (u <= 1).all() + + def test_compute_density_uses_instance_defaults(self): + """Test that compute_density uses instance defaults when not provided.""" + step = FluxForwardStep(logit_mean=0.5, logit_std=0.8, mode_scale=1.5) + + # Should use instance defaults + u = step.compute_density_for_timestep_sampling("logit_normal", batch_size=5) + + assert u.shape == (5,) + + def test_compute_density_override_defaults(self): + """Test that compute_density can override instance defaults.""" + step = FluxForwardStep(logit_mean=0.0, logit_std=1.0) + + # Override with custom values + u = step.compute_density_for_timestep_sampling("logit_normal", batch_size=5, logit_mean=1.0, logit_std=0.5) + + assert u.shape == (5,) + + +class TestFluxForwardStepLatentOperations: + """Test latent packing/unpacking operations.""" + + def test_pack_latents(self): + """Test _pack_latents method.""" + step = FluxForwardStep() + batch_size = 2 + num_channels = 16 + height = 64 + width = 64 + + latents = torch.randn(batch_size, num_channels, height, width) + packed = step._pack_latents(latents, batch_size, num_channels, height, width) + + expected_seq_len = (height // 2) * (width // 2) + expected_channels = num_channels * 4 + assert packed.shape == (batch_size, expected_seq_len, expected_channels) + + def test_unpack_latents(self): + """Test _unpack_latents method.""" + step = FluxForwardStep() + batch_size = 2 + num_patches = 1024 # (64 // 2) * (64 // 2) + channels = 64 # 16 * 4 + height = 64 + width = 64 + + packed_latents = torch.randn(batch_size, num_patches, channels) + unpacked = step._unpack_latents(packed_latents, height, width) + + expected_channels = channels // 4 + assert unpacked.shape == (batch_size, expected_channels, height, width) + + def test_pack_unpack_roundtrip(self): + """Test that pack and unpack are consistent.""" + step = FluxForwardStep() + batch_size = 2 + num_channels = 16 + height = 64 + width = 64 + + original = torch.randn(batch_size, num_channels, height, width) + packed = step._pack_latents(original, batch_size, num_channels, height, width) + unpacked = step._unpack_latents(packed, height, width) + + assert unpacked.shape == original.shape + # Note: Due to the reshape operations, values should be approximately equal + # but the exact comparison might not hold due to floating point operations + + def test_prepare_latent_image_ids(self): + """Test _prepare_latent_image_ids method.""" + step = FluxForwardStep() + batch_size = 2 + height = 64 + width = 64 + device = torch.device("cpu") + dtype = torch.float32 + + # First call creates the IDs + ids = step._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + expected_seq_len = (height // 2) * (width // 2) + assert ids.shape == (batch_size, expected_seq_len, 3) + assert ids.device == device + assert ids.dtype == dtype + + # Second call should use cache + ids2 = step._prepare_latent_image_ids(batch_size, height, width, device, dtype) + assert ids2.shape == ids.shape + + def test_prepare_latent_image_ids_caching(self): + """Test that _prepare_latent_image_ids uses LRU cache.""" + step = FluxForwardStep() + + # Cache should work with same parameters + ids1 = step._prepare_latent_image_ids(2, 64, 64, torch.device("cpu"), torch.float32) + ids2 = step._prepare_latent_image_ids(2, 64, 64, torch.device("cpu"), torch.float32) + + # Should be the same object from cache + assert ids1.data_ptr() == ids2.data_ptr() + + +@pytest.mark.run_only_on("GPU") +class TestFluxForwardStepPrepareImageLatent: + """Test prepare_image_latent method.""" + + def test_prepare_image_latent_basic(self): + """Test prepare_image_latent with basic input.""" + step = FluxForwardStep() + batch_size = 2 + channels = 16 + height = 64 + width = 64 + + latents = torch.randn(batch_size, channels, height, width, device="cuda") + + # Mock model + mock_model = MagicMock() + mock_model.guidance_embed = False + + result = step.prepare_image_latent(latents, mock_model) + + # Unpack result tuple + ret_latents, noise, packed_noisy_input, latent_ids, guidance_vec, timesteps = result + + # Check shapes (transposed from [B, ...] to [seq, B, ...] format) + assert ret_latents.shape[1] == batch_size + assert noise.shape[1] == batch_size + assert packed_noisy_input.shape[1] == batch_size + assert latent_ids.shape[0] == batch_size + assert guidance_vec is None + assert timesteps.shape[0] == batch_size + + def test_prepare_image_latent_with_guidance(self): + """Test prepare_image_latent with guidance embedding.""" + step = FluxForwardStep(guidance_scale=7.5) + batch_size = 2 + channels = 16 + height = 64 + width = 64 + + latents = torch.randn(batch_size, channels, height, width, device="cuda") + + # Mock model with guidance + mock_model = MagicMock() + mock_model.guidance_embed = True + + result = step.prepare_image_latent(latents, mock_model) + ret_latents, noise, packed_noisy_input, latent_ids, guidance_vec, timesteps = result + + assert guidance_vec is not None + assert guidance_vec.shape == (batch_size,) + assert torch.all(guidance_vec == 7.5) + + +class TestFluxForwardStepLossFunction: + """Test loss function creation.""" + + def test_create_loss_function(self): + """Test _create_loss_function method.""" + step = FluxForwardStep() + loss_mask = torch.ones(4, dtype=torch.float32) + + loss_fn = step._create_loss_function(loss_mask, check_for_nan_in_loss=True, check_for_spiky_loss=False) + + assert isinstance(loss_fn, partial) + assert callable(loss_fn) + + def test_create_loss_function_parameters(self): + """Test that loss function parameters are correctly set.""" + step = FluxForwardStep() + loss_mask = torch.ones(2, dtype=torch.float32) + + loss_fn = step._create_loss_function(loss_mask, check_for_nan_in_loss=False, check_for_spiky_loss=True) + + # Verify it's a partial with expected arguments + assert loss_fn.func.__name__ == "masked_next_token_loss" + assert loss_fn.keywords["check_for_nan_in_loss"] is False + assert loss_fn.keywords["check_for_spiky_loss"] is True + + +class TestFluxForwardStepIntegration: + """Integration tests for FluxForwardStep.""" + + def test_timestep_sampling_methods_produce_valid_values(self): + """Test that all timestep sampling methods produce valid u values.""" + batch_size = 100 + + for method in ["logit_normal", "mode", "uniform"]: + step = FluxForwardStep(timestep_sampling=method) + u = step.compute_density_for_timestep_sampling(method, batch_size) + + assert u.shape == (batch_size,) + assert (u >= 0).all(), f"{method} produced u < 0" + assert (u <= 1).all(), f"{method} produced u > 1" + assert not torch.isnan(u).any(), f"{method} produced NaN values" + + def test_latent_operations_preserve_batch_dimension(self): + """Test that latent operations preserve batch dimension.""" + step = FluxForwardStep() + + for batch_size in [1, 2, 4]: + latents = torch.randn(batch_size, 16, 64, 64) + packed = step._pack_latents(latents, batch_size, 16, 64, 64) + unpacked = step._unpack_latents(packed, 64, 64) + + assert packed.shape[0] == batch_size + assert unpacked.shape[0] == batch_size + + +class TestFluxForwardStepEdgeCases: + """Test edge cases and error handling.""" + + def test_pack_latents_small_dimensions(self): + """Test _pack_latents with small dimensions.""" + step = FluxForwardStep() + latents = torch.randn(1, 4, 4, 4) + + packed = step._pack_latents(latents, 1, 4, 4, 4) + + assert packed.shape == (1, 4, 16) # (4/2) * (4/2) = 4, 4*4 = 16 + + def test_unpack_latents_small_dimensions(self): + """Test _unpack_latents with small dimensions.""" + step = FluxForwardStep() + packed = torch.randn(1, 4, 16) + + unpacked = step._unpack_latents(packed, 4, 4) + + assert unpacked.shape == (1, 4, 4, 4) + + def test_compute_density_mode_with_extreme_scale(self): + """Test mode sampling with extreme scale values.""" + step = FluxForwardStep() + + # Test with very small scale + u_small = step.compute_density_for_timestep_sampling("mode", 10, mode_scale=0.01) + assert (u_small >= 0).all() and (u_small <= 1).all() + + # Test with larger scale + u_large = step.compute_density_for_timestep_sampling("mode", 10, mode_scale=2.0) + assert (u_large >= 0).all() and (u_large <= 1).all() + + def test_prepare_latent_image_ids_different_sizes(self): + """Test _prepare_latent_image_ids with different image sizes.""" + step = FluxForwardStep() + + for height, width in [(32, 32), (64, 64), (128, 128)]: + ids = step._prepare_latent_image_ids(2, height, width, torch.device("cpu"), torch.float32) + + expected_seq_len = (height // 2) * (width // 2) + assert ids.shape == (2, expected_seq_len, 3) + + +class TestFluxForwardStepScheduler: + """Test scheduler integration.""" + + def test_scheduler_initialized_with_correct_steps(self): + """Test that scheduler is initialized with correct number of steps.""" + scheduler_steps = 500 + step = FluxForwardStep(scheduler_steps=scheduler_steps) + + assert step.scheduler.num_train_timesteps == scheduler_steps + assert len(step.scheduler.timesteps) == scheduler_steps + + def test_scheduler_timesteps_in_valid_range(self): + """Test that scheduler timesteps are in valid range.""" + step = FluxForwardStep() + + assert (step.scheduler.timesteps >= 0).all() + assert (step.scheduler.timesteps <= step.scheduler.num_train_timesteps).all() + + def test_scheduler_sigmas_in_valid_range(self): + """Test that scheduler sigmas are in valid range.""" + step = FluxForwardStep() + + assert (step.scheduler.sigmas >= 0).all() + assert (step.scheduler.sigmas <= 1).all() diff --git a/tests/unit_tests/diffusion/model/wan/__init__.py b/tests/unit_tests/diffusion/model/wan/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit_tests/diffusion/model/wan/conversion/__init__.py b/tests/unit_tests/diffusion/model/wan/conversion/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit_tests/diffusion/model/wan/conversion/test_wan_bridge.py b/tests/unit_tests/diffusion/model/wan/conversion/test_wan_bridge.py new file mode 100644 index 0000000000..400e10f270 --- /dev/null +++ b/tests/unit_tests/diffusion/model/wan/conversion/test_wan_bridge.py @@ -0,0 +1,116 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import types + +import pytest + +from megatron.bridge.diffusion.models.wan.conversion import wan_bridge as wan_bridge_module + + +pytestmark = [pytest.mark.unit] + + +def _make_cfg( + *, + num_layers=4, + num_attention_heads=8, + attention_head_dim=64, + ffn_dim=1024, + in_channels=16, + out_channels=16, + text_dim=4096, + patch_size=(1, 2), + freq_dim=256, + eps=1e-6, +): + cfg = types.SimpleNamespace() + cfg.num_layers = num_layers + cfg.num_attention_heads = num_attention_heads + cfg.attention_head_dim = attention_head_dim + cfg.ffn_dim = ffn_dim + cfg.in_channels = in_channels + cfg.out_channels = out_channels + cfg.text_dim = text_dim + cfg.patch_size = patch_size + cfg.freq_dim = freq_dim + cfg.eps = eps + return cfg + + +def test_provider_bridge_constructs_provider_with_expected_fields(): + class DummyHF: + def __init__(self, cfg): + self.config = cfg + + cfg = _make_cfg() + bridge = wan_bridge_module.WanBridge() + provider = bridge.provider_bridge(DummyHF(cfg)) + + # Basic sanity: returned type and a few key fields + assert provider is not None + assert provider.num_layers == cfg.num_layers + # hidden_size and crossattn_emb_size computed from heads and head_dim + expected_hsize = cfg.num_attention_heads * cfg.attention_head_dim + assert provider.hidden_size == expected_hsize + assert provider.crossattn_emb_size == expected_hsize + # kv_channels equals per-head dim + assert getattr(provider, "kv_channels") == cfg.attention_head_dim + # patch sizes split into temporal/spatial + assert provider.patch_temporal == cfg.patch_size[0] + assert provider.patch_spatial == cfg.patch_size[1] + # passthrough fields + assert provider.in_channels == cfg.in_channels + assert provider.out_channels == cfg.out_channels + assert provider.text_dim == cfg.text_dim + assert provider.freq_dim == cfg.freq_dim + assert provider.layernorm_epsilon == cfg.eps + # defaults enforced by bridge + assert provider.hidden_dropout == 0 + assert provider.attention_dropout == 0 + + +def test_mapping_registry_registers_module_types_and_builds_mappings(monkeypatch): + calls_register_module_type = [] + + def fake_register_module_type(name, parallelism): + calls_register_module_type.append((name, parallelism)) + + constructed_registry_args = {} + + class FakeRegistry: + def __init__(self, *mappings): + constructed_registry_args["mappings"] = mappings + + monkeypatch.setattr(wan_bridge_module.AutoMapping, "register_module_type", staticmethod(fake_register_module_type)) + monkeypatch.setattr(wan_bridge_module, "MegatronMappingRegistry", FakeRegistry) + + registry = wan_bridge_module.WanBridge().mapping_registry() + + # Verify module type registrations + assert ("Linear", "replicated") in calls_register_module_type + assert ("Conv3d", "replicated") in calls_register_module_type + assert ("WanAdaLN", "replicated") in calls_register_module_type + assert ("Head", "replicated") in calls_register_module_type + + # We replaced the real registry with FakeRegistry; the function should return that instance + assert isinstance(registry, FakeRegistry) + mappings = constructed_registry_args["mappings"] + # Ensure we have a reasonable number of mappings and a mix of kinds + assert len(mappings) >= 10 + # Expect at least one AutoMapping, one KVMapping, one QKVMapping + has_auto = any(m.__class__.__name__ == "AutoMapping" for m in mappings) + has_kv = any(m.__class__.__name__ == "KVMapping" for m in mappings) + has_qkv = any(m.__class__.__name__ == "QKVMapping" for m in mappings) + assert has_auto and has_kv and has_qkv diff --git a/tests/unit_tests/diffusion/model/wan/conversion/test_wan_hf_pretrained.py b/tests/unit_tests/diffusion/model/wan/conversion/test_wan_hf_pretrained.py new file mode 100644 index 0000000000..1b28ecdef7 --- /dev/null +++ b/tests/unit_tests/diffusion/model/wan/conversion/test_wan_hf_pretrained.py @@ -0,0 +1,154 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import pytest + +from megatron.bridge.diffusion.models.wan.conversion import wan_hf_pretrained as wan_hf_module + + +pytestmark = [pytest.mark.unit] + + +def test_load_config_uses_transformer_subfolder(monkeypatch, tmp_path): + calls = [] + + class FakeModel: + def __init__(self, cfg): + self.config = cfg + + class FakeWAN: + @classmethod + def from_pretrained(cls, path, subfolder=None): + calls.append((str(path), subfolder)) + return FakeModel(cfg={"ok": True}) + + monkeypatch.setattr(wan_hf_module, "WanTransformer3DModel", FakeWAN) + + src = tmp_path / "hf" + src.mkdir(parents=True, exist_ok=True) + hf = wan_hf_module.PreTrainedWAN(str(src)) + + # Accessing .config should trigger _load_config + cfg = hf.config + assert cfg == {"ok": True} + # Ensure we called with transformer subfolder + assert calls and calls[-1][1] == "transformer" + + +def test_state_uses_transformer_subfolder_and_caches(monkeypatch, tmp_path): + captured = {"source_path": None, "constructed": 0} + + class FakeSource: + def __init__(self, path): + captured["source_path"] = str(path) + + class FakeStateDict: + def __init__(self, source): + self.source = source + captured["constructed"] += 1 + + monkeypatch.setattr(wan_hf_module, "WanSafeTensorsStateSource", FakeSource) + monkeypatch.setattr(wan_hf_module, "StateDict", FakeStateDict) + + src = tmp_path / "hf_model" + (src / "transformer").mkdir(parents=True, exist_ok=True) + hf = wan_hf_module.PreTrainedWAN(str(src)) + + s1 = hf.state + s2 = hf.state # Cached + assert s1 is s2 + # Correct subfolder used + assert captured["source_path"] == str(src / "transformer") + # StateDict constructed only once due to caching + assert captured["constructed"] == 1 + + +def test_save_artifacts_copies_existing_files(tmp_path): + # Prepare source with transformer/config.json and index + src = tmp_path / "src" + tdir = src / "transformer" + tdir.mkdir(parents=True, exist_ok=True) + config_src = tdir / "config.json" + index_src = tdir / "diffusion_pytorch_model.safetensors.index.json" + config_data = {"a": 1} + index_data = {"weight_map": {}} + config_src.write_text(json.dumps(config_data)) + index_src.write_text(json.dumps(index_data)) + + # Destination directory + dest = tmp_path / "dest" + + hf = wan_hf_module.PreTrainedWAN(str(src)) + hf.save_artifacts(str(dest)) + + # Validate files copied + dest_tdir = dest / "transformer" + assert dest_tdir.is_dir() + assert json.loads((dest_tdir / "config.json").read_text()) == config_data + assert json.loads((dest_tdir / "diffusion_pytorch_model.safetensors.index.json").read_text()) == index_data + + +def test_save_artifacts_exports_config_when_missing(monkeypatch, tmp_path): + class FakeCfg: + def to_dict(self): + return {"from_model": True} + + class FakeModel: + def __init__(self): + self.config = FakeCfg() + + class FakeWAN: + @classmethod + def from_pretrained(cls, path, subfolder=None): + # Ensure it targets the transformer subfolder + assert subfolder == "transformer" + return FakeModel() + + monkeypatch.setattr(wan_hf_module, "WanTransformer3DModel", FakeWAN) + + src = tmp_path / "empty_src" + src.mkdir(parents=True, exist_ok=True) + + dest = tmp_path / "out" + hf = wan_hf_module.PreTrainedWAN(str(src)) + hf.save_artifacts(dest) + + # Should create transformer/config.json with exported contents + dest_cfg = dest / "transformer" / "config.json" + assert dest_cfg.is_file() + assert json.loads(dest_cfg.read_text()) == {"from_model": True} + + +def test_save_artifacts_handles_export_failure(monkeypatch, tmp_path): + class FailingWAN: + @classmethod + def from_pretrained(cls, path, subfolder=None): + raise RuntimeError("fail") + + monkeypatch.setattr(wan_hf_module, "WanTransformer3DModel", FailingWAN) + + src = tmp_path / "src2" + src.mkdir(parents=True, exist_ok=True) + dest = tmp_path / "dest2" + + hf = wan_hf_module.PreTrainedWAN(str(src)) + # Should not raise + hf.save_artifacts(dest) + + # Transformer folder created but no config.json written + dest_tdir = dest / "transformer" + assert dest_tdir.is_dir() + assert not (dest_tdir / "config.json").exists() diff --git a/tests/unit_tests/diffusion/model/wan/flow_matching/__init__.py b/tests/unit_tests/diffusion/model/wan/flow_matching/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit_tests/diffusion/model/wan/flow_matching/test_flow_inference_pipeline.py b/tests/unit_tests/diffusion/model/wan/flow_matching/test_flow_inference_pipeline.py new file mode 100644 index 0000000000..1ae5bd83d5 --- /dev/null +++ b/tests/unit_tests/diffusion/model/wan/flow_matching/test_flow_inference_pipeline.py @@ -0,0 +1,90 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest +import torch + +from megatron.bridge.diffusion.models.wan.flow_matching.flow_inference_pipeline import FlowInferencePipeline + + +def test_select_checkpoint_dir_latest(tmp_path): + base = tmp_path / "ckpts" + os.makedirs(base / "iter_0000100") + os.makedirs(base / "iter_0000200") + + # Minimal inference config object + class _Cfg: + num_train_timesteps = 1000 + param_dtype = torch.float32 + text_len = 512 + t5_dtype = torch.float32 + vae_stride = (1, 1, 1) + patch_size = (1, 1, 1) + + # Instantiate object without running heavy init by patching __init__ to a no-op + pip = object.__new__(FlowInferencePipeline) + + pip.inference_cfg = _Cfg() + + latest = FlowInferencePipeline._select_checkpoint_dir(pip, str(base), checkpoint_step=None) + assert latest.endswith("iter_0000200") + + specific = FlowInferencePipeline._select_checkpoint_dir(pip, str(base), checkpoint_step=100) + assert specific.endswith("iter_0000100") + + with pytest.raises(FileNotFoundError): + FlowInferencePipeline._select_checkpoint_dir(pip, str(base), checkpoint_step=999) + + +def test_forward_pp_step_no_pp(monkeypatch): + # Build a minimal instance skipping heavy init + pip = object.__new__(FlowInferencePipeline) + + class _Model: + class _Cfg: + hidden_size = 16 + qkv_format = "sbhd" + + config = _Cfg() + + def __call__(self, x, grid_sizes, t, **kwargs): + return x # echo input + + def set_input_tensor(self, x): + pass + + pip.model = _Model() + + # Patch parallel state to no-PP path + from megatron.core import parallel_state + + monkeypatch.setattr(parallel_state, "get_pipeline_model_parallel_world_size", lambda: 1, raising=False) + + S, B, H = 8, 1, pip.model.config.hidden_size + latent_model_input = torch.randn(S, B, H, dtype=torch.float32) + grid_sizes = [(2, 2, 2)] + timestep = torch.tensor([10.0], dtype=torch.float32) + arg_c = {} + + out = FlowInferencePipeline.forward_pp_step( + pip, + latent_model_input=latent_model_input, + grid_sizes=grid_sizes, + max_video_seq_len=S, + timestep=timestep, + arg_c=arg_c, + ) + assert out.shape == latent_model_input.shape diff --git a/tests/unit_tests/diffusion/model/wan/flow_matching/test_flow_matching_pipeline_wan.py b/tests/unit_tests/diffusion/model/wan/flow_matching/test_flow_matching_pipeline_wan.py new file mode 100644 index 0000000000..efe9b08516 --- /dev/null +++ b/tests/unit_tests/diffusion/model/wan/flow_matching/test_flow_matching_pipeline_wan.py @@ -0,0 +1,211 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock, patch + +import pytest +import torch +from dfm.src.automodel.flow_matching.adapters.base import FlowMatchingContext + +from megatron.bridge.diffusion.models.wan.flow_matching.flow_matching_pipeline_wan import ( + WanAdapter, + WanFlowMatchingPipeline, +) + + +class TestWanAdapter: + @pytest.fixture + def adapter(self): + return WanAdapter() + + @pytest.fixture + def context(self): + # Create a mock context with necessary attributes + ctx = MagicMock(spec=FlowMatchingContext) + + # Setup inputs + batch_size = 2 + seq_len = 8 + hidden_dim = 16 + + # Input latents are typically (B, S, H) before adapter + ctx.noisy_latents = torch.randn(batch_size, seq_len, hidden_dim) + ctx.video_latents = torch.randn(batch_size, seq_len, hidden_dim) + ctx.timesteps = torch.tensor([0.5, 0.5]) + + ctx.batch = { + "grid_sizes": [(4, 4, 4)] * batch_size, + "loss_mask": torch.ones(batch_size, seq_len), + "context_embeddings": torch.randn(batch_size, seq_len, hidden_dim), # B, S, H + "packed_seq_params": { + "self_attention": MagicMock(cu_seqlens_q_padded=None), + "cross_attention": MagicMock(cu_seqlens_kv_padded=None), + }, + } + return ctx + + def test_prepare_inputs_no_cp(self, adapter, context): + with patch( + "megatron.bridge.diffusion.models.wan.flow_matching.flow_matching_pipeline_wan.parallel_state" + ) as mock_ps: + mock_ps.get_context_parallel_world_size.return_value = 1 + + inputs = adapter.prepare_inputs(context) + + # Check keys + assert "noisy_latents" in inputs + assert "grid_sizes" in inputs + assert "timesteps" in inputs + assert "context_embeddings" in inputs + assert "packed_seq_params" in inputs + + # Check shapes and types + # noisy_latents should be transposed to (S, B, H) from (B, S, H) and cast to bf16 + # Input was (2, 8, 16), so expected is (8, 2, 16) + assert inputs["noisy_latents"].shape == (8, 2, 16) + assert inputs["noisy_latents"].dtype == torch.bfloat16 + + # context_embeddings should be (B, S, H) (2, 8, 16) and cast to bf16 + assert inputs["context_embeddings"].shape == (2, 8, 16) + assert inputs["context_embeddings"].dtype == torch.bfloat16 + + # Timesteps should be bf16 + assert inputs["timesteps"].dtype == torch.bfloat16 + + def test_prepare_inputs_with_cp(self, adapter, context): + with ( + patch( + "megatron.bridge.diffusion.models.wan.flow_matching.flow_matching_pipeline_wan.parallel_state" + ) as mock_ps, + patch( + "megatron.bridge.diffusion.models.wan.flow_matching.flow_matching_pipeline_wan.thd_split_inputs_cp" + ) as mock_split, + ): + mock_ps.get_context_parallel_world_size.return_value = 2 + mock_ps.get_context_parallel_group.return_value = "fake_group" + + # Mock split to return a dummy value so we know it was processed + # We return the input as is, but we check the call arguments + mock_split.side_effect = lambda x, *args: x + + inputs = adapter.prepare_inputs(context) + + # Verify thd_split_inputs_cp was called for noisy_latents and context_embeddings + assert mock_split.call_count == 2 + + # Verify args for the calls + # We can't easily distinguish order without checking args, but we know both should be called. + + # Check types are correct (bf16) + assert inputs["noisy_latents"].dtype == torch.bfloat16 + assert inputs["context_embeddings"].dtype == torch.bfloat16 + + def test_forward(self, adapter): + model = MagicMock() + model.return_value = torch.randn(8, 2, 16) # S, B, H + + inputs = { + "noisy_latents": torch.randn(8, 2, 16), + "grid_sizes": [], + "timesteps": torch.tensor([1.0]), + "context_embeddings": torch.randn(2, 8, 16), + "packed_seq_params": {}, + } + + # Mock post_process_prediction inherited from ModelAdapter + adapter.post_process_prediction = MagicMock(side_effect=lambda x: x) + + out = adapter.forward(model, inputs) + + model.assert_called_once_with( + x=inputs["noisy_latents"], + grid_sizes=inputs["grid_sizes"], + t=inputs["timesteps"], + context=inputs["context_embeddings"], + packed_seq_params=inputs["packed_seq_params"], + ) + assert out is not None + + +class TestWanFlowMatchingPipeline: + @pytest.fixture + def pipeline(self): + # Use object.__new__ to avoid __init__ if it's heavy + pip = object.__new__(WanFlowMatchingPipeline) + return pip + + def test_determine_task_type(self, pipeline): + assert pipeline.determine_task_type("any") == "t2v" + + def test_compute_loss_no_cp(self, pipeline): + model_pred = torch.randn(8, 2, 16) # S, B, H + target = torch.randn(2, 8, 16) # B, S, H + sigma = torch.randn(2) + + batch = { + "loss_mask": torch.ones(2, 8), + "packed_seq_params": {"self_attention": MagicMock(cu_seqlens_q_padded=None)}, + } + + with ( + patch( + "dfm.src.automodel.flow_matching.flow_matching_pipeline.FlowMatchingPipeline.compute_loss" + ) as mock_super_loss, + patch( + "megatron.bridge.diffusion.models.wan.flow_matching.flow_matching_pipeline_wan.parallel_state" + ) as mock_ps, + ): + mock_ps.get_context_parallel_world_size.return_value = 1 + mock_super_loss.return_value = (1, 2, 3, 4, 5, batch["loss_mask"]) + + pipeline.compute_loss(model_pred, target, sigma, batch) + + # target should be transposed to (S, B, H) before passing to super + # Input target was (2, 8, 16), so expected is (8, 2, 16) + args, _ = mock_super_loss.call_args + passed_target = args[1] + assert passed_target.shape == (8, 2, 16) + + def test_compute_loss_with_cp(self, pipeline): + model_pred = torch.randn(8, 2, 16) + target = torch.randn(2, 8, 16) + sigma = torch.randn(2) + + batch = { + "loss_mask": torch.ones(2, 8), + "packed_seq_params": {"self_attention": MagicMock(cu_seqlens_q_padded="dummy_seq_len")}, + } + + with ( + patch( + "dfm.src.automodel.flow_matching.flow_matching_pipeline.FlowMatchingPipeline.compute_loss" + ) as mock_super_loss, + patch( + "megatron.bridge.diffusion.models.wan.flow_matching.flow_matching_pipeline_wan.parallel_state" + ) as mock_ps, + patch( + "megatron.bridge.diffusion.models.wan.flow_matching.flow_matching_pipeline_wan.thd_split_inputs_cp" + ) as mock_split, + ): + mock_ps.get_context_parallel_world_size.return_value = 2 + mock_ps.get_context_parallel_group.return_value = "fake_group" + mock_super_loss.return_value = (1, 2, 3, 4, 5, batch["loss_mask"]) + + mock_split.side_effect = lambda x, *args: x # Identity for simplicity + + pipeline.compute_loss(model_pred, target, sigma, batch) + + # Check thd_split_inputs_cp calls + # Should be called for target and split_loss_mask + assert mock_split.call_count == 2 diff --git a/tests/unit_tests/diffusion/model/wan/flow_matching/test_time_shift_utils.py b/tests/unit_tests/diffusion/model/wan/flow_matching/test_time_shift_utils.py new file mode 100644 index 0000000000..c42809eaf9 --- /dev/null +++ b/tests/unit_tests/diffusion/model/wan/flow_matching/test_time_shift_utils.py @@ -0,0 +1,66 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from megatron.bridge.diffusion.models.wan.flow_matching.time_shift_utils import ( + compute_density_for_timestep_sampling, + get_flow_match_loss_weight, + time_shift, +) + + +def test_time_shift_constant_linear_sqrt_bounds_and_monotonic(): + t_small = torch.tensor(0.1, dtype=torch.float32) + t_large = torch.tensor(0.9, dtype=torch.float32) + seq_len = 512 + + # constant + s_small = time_shift(t_small, image_seq_len=seq_len, shift_type="constant", constant=3.0) + s_large = time_shift(t_large, image_seq_len=seq_len, shift_type="constant", constant=3.0) + assert 0.0 <= s_small.item() <= 1.0 + assert 0.0 <= s_large.item() <= 1.0 + assert s_large > s_small + + # linear + s_small = time_shift(t_small, image_seq_len=seq_len, shift_type="linear", base_shift=0.5, max_shift=1.15) + s_large = time_shift(t_large, image_seq_len=seq_len, shift_type="linear", base_shift=0.5, max_shift=1.15) + assert 0.0 <= s_small.item() <= 1.0 + assert 0.0 <= s_large.item() <= 1.0 + assert s_large > s_small + + # sqrt + s_small = time_shift(t_small, image_seq_len=seq_len, shift_type="sqrt") + s_large = time_shift(t_large, image_seq_len=seq_len, shift_type="sqrt") + assert 0.0 <= s_small.item() <= 1.0 + assert 0.0 <= s_large.item() <= 1.0 + assert s_large > s_small + + +def test_compute_density_for_timestep_sampling_modes_and_ranges(): + batch_size = 16 + for mode in ["uniform", "logit_normal", "mode"]: + u = compute_density_for_timestep_sampling(mode, batch_size=batch_size, logit_mean=0.0, logit_std=1.0) + assert u.shape == (batch_size,) + assert torch.all((0.0 <= u) & (u <= 1.0)) + + +def test_get_flow_match_loss_weight_simple_cases(): + sigma = torch.zeros(5, dtype=torch.float32) + w = get_flow_match_loss_weight(sigma, shift=3.0) + assert torch.allclose(w, torch.ones_like(w)) + + sigma = torch.ones(5, dtype=torch.float32) + w = get_flow_match_loss_weight(sigma, shift=2.0) + assert torch.allclose(w, torch.full_like(sigma, 3.0)) diff --git a/tests/unit_tests/diffusion/model/wan/inference/__init__.py b/tests/unit_tests/diffusion/model/wan/inference/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit_tests/diffusion/model/wan/inference/test_inference_init.py b/tests/unit_tests/diffusion/model/wan/inference/test_inference_init.py new file mode 100644 index 0000000000..aa4100120c --- /dev/null +++ b/tests/unit_tests/diffusion/model/wan/inference/test_inference_init.py @@ -0,0 +1,40 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.bridge.diffusion.models.wan.inference import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES + + +def test_size_configs_structure_and_values(): + assert isinstance(SIZE_CONFIGS, dict) + for key, val in SIZE_CONFIGS.items(): + assert isinstance(key, str) + assert isinstance(val, tuple) and len(val) == 2 + w, h = val + assert isinstance(w, int) and isinstance(h, int) + assert w > 0 and h > 0 + + +def test_max_area_configs_consistency(): + for size_key, area in MAX_AREA_CONFIGS.items(): + w, h = SIZE_CONFIGS[size_key] + assert area == w * h + + +def test_supported_sizes_lists(): + assert "t2v-14B" in SUPPORTED_SIZES + assert "t2v-1.3B" in SUPPORTED_SIZES + for model_key, sizes in SUPPORTED_SIZES.items(): + assert isinstance(sizes, tuple) + for s in sizes: + assert s in SIZE_CONFIGS diff --git a/tests/unit_tests/diffusion/model/wan/inference/test_inference_utils.py b/tests/unit_tests/diffusion/model/wan/inference/test_inference_utils.py new file mode 100644 index 0000000000..8b960e2990 --- /dev/null +++ b/tests/unit_tests/diffusion/model/wan/inference/test_inference_utils.py @@ -0,0 +1,84 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import tempfile + +import torch + +from megatron.bridge.diffusion.models.wan.inference import utils as inf_utils + + +def test_str2bool_variants_and_errors(): + true_vals = ["yes", "true", "t", "y", "1", "TRUE", "Yes"] + false_vals = ["no", "false", "f", "n", "0", "FALSE", "No"] + for v in true_vals: + assert inf_utils.str2bool(v) is True + for v in false_vals: + assert inf_utils.str2bool(v) is False + assert inf_utils.str2bool(True) is True + assert inf_utils.str2bool(False) is False + try: + inf_utils.str2bool("maybe") + except argparse.ArgumentTypeError: + pass + else: + assert False, "Expected argparse.ArgumentTypeError for invalid boolean string" + + +def test_cache_image_writes_file(tmp_path): + # Small 3x8x8 image + img = torch.rand(3, 8, 8) + out_path = tmp_path / "test.png" + saved = inf_utils.cache_image(img, str(out_path), nrow=1, normalize=False, value_range=(0.0, 1.0), retry=1) + assert saved == str(out_path) + assert os.path.exists(out_path) + assert os.path.getsize(out_path) > 0 + + +def test_cache_video_uses_writer_and_returns_path(monkeypatch): + # Stub imageio.get_writer to avoid codec dependency + calls = {"frames": 0, "path": None} + + class _DummyWriter: + def __init__(self, path, fps=None, codec=None, quality=None): + calls["path"] = path + + def append_data(self, frame): + calls["frames"] += 1 + + def close(self): + pass + + monkeypatch.setattr( + inf_utils.imageio, "get_writer", lambda path, fps, codec, quality: _DummyWriter(path, fps, codec, quality) + ) + + # Stub make_grid to return a fixed CHW tensor regardless of input + def _fake_make_grid(x, nrow, normalize, value_range): + return torch.rand(3, 4, 5) + + monkeypatch.setattr(inf_utils.torchvision.utils, "make_grid", _fake_make_grid) + + # Build a tensor whose unbind(2) yields 2 slices so we expect 2 frames written + vid = torch.rand(3, 3, 2, 2) # shape chosen to exercise unbind(2) + with tempfile.TemporaryDirectory() as td: + out_file = os.path.join(td, "out.mp4") + result = inf_utils.cache_video( + vid, save_file=out_file, fps=5, suffix=".mp4", nrow=1, normalize=False, value_range=(0.0, 1.0), retry=1 + ) + assert result == out_file + assert calls["path"] == out_file + assert calls["frames"] == vid.shape[2] # frames equal to number of unbinds on dim=2 diff --git a/tests/unit_tests/diffusion/model/wan/test_rope_utils.py b/tests/unit_tests/diffusion/model/wan/test_rope_utils.py new file mode 100644 index 0000000000..f69cab4479 --- /dev/null +++ b/tests/unit_tests/diffusion/model/wan/test_rope_utils.py @@ -0,0 +1,54 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from megatron.bridge.diffusion.models.wan.rope_utils import Wan3DRopeEmbeddings + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA not available", +) +def test_wan3d_rope_embeddings_shapes_and_padding(): + # Small, CPU-friendly config + n_head = 2 + dim_head = 8 # must be divisible with the internal splits + max_position_len = 16 + rope = Wan3DRopeEmbeddings(dim_head=dim_head, max_position_len=max_position_len) + + # Two samples with different (f, h, w) + grid_sizes = torch.tensor([[2, 3, 2], [4, 1, 1]], dtype=torch.int32) + seq_lens = [(2 * 3 * 2), (4 * 1 * 1)] + padded_lens = [seq_lens[0] + 2, seq_lens[1]] # pad first sample + + cu_seqlens_q_padded = torch.tensor([0, padded_lens[0], padded_lens[0] + padded_lens[1]], dtype=torch.int32) + + out = rope( + n_head=n_head, + dim_head=dim_head, + cu_seqlens_q_padded=cu_seqlens_q_padded, + grid_sizes=grid_sizes, + device=torch.device("cpu"), + ) + + # Total concatenated length equals sum of padded lens + assert out.shape == (sum(padded_lens), 1, 1, dim_head) + + # Check that padding region for the first sample is zero + first_seq_len = seq_lens[0] + first_padded_len = padded_lens[0] + tail = out[first_seq_len:first_padded_len] + assert torch.all(tail == 0), "Padded region should be zeros" diff --git a/tests/unit_tests/diffusion/model/wan/test_utils.py b/tests/unit_tests/diffusion/model/wan/test_utils.py new file mode 100644 index 0000000000..e46770cb08 --- /dev/null +++ b/tests/unit_tests/diffusion/model/wan/test_utils.py @@ -0,0 +1,48 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from megatron.bridge.diffusion.models.wan.utils import grid_sizes_calculation, patchify, unpatchify + + +def test_grid_sizes_calculation_basic(): + input_shape = (4, 8, 6) + patch_size = (1, 2, 3) + f, h, w = grid_sizes_calculation(input_shape, patch_size) + assert (f, h, w) == (4, 4, 2) + + +def test_patchify_unpatchify_roundtrip(): + # Video latent: [c, F_patches * pF, H_patches * pH, W_patches * pW] + c = 3 + F_patches, H_patches, W_patches = 2, 2, 3 + patch_size = (1, 2, 2) + F_latents = F_patches * patch_size[0] + H_latents = H_patches * patch_size[1] + W_latents = W_patches * patch_size[2] + + x = [torch.randn(c, F_latents, H_latents, W_latents)] + + patches = patchify(x, patch_size) + assert isinstance(patches, list) and len(patches) == 1 + seq_len, dim = patches[0].shape + assert seq_len == F_patches * H_patches * W_patches + assert dim == c * (patch_size[0] * patch_size[1] * patch_size[2]) + + # Unpatchify and compare + y = unpatchify(patches, [[F_patches, H_patches, W_patches]], out_dim=c, patch_size=patch_size) + assert isinstance(y, list) and len(y) == 1 + assert y[0].shape == x[0].shape + torch.testing.assert_close(y[0], x[0], rtol=1e-5, atol=1e-5) diff --git a/tests/unit_tests/diffusion/model/wan/test_wan_layer_spec.py b/tests/unit_tests/diffusion/model/wan/test_wan_layer_spec.py new file mode 100644 index 0000000000..cbd463fc3f --- /dev/null +++ b/tests/unit_tests/diffusion/model/wan/test_wan_layer_spec.py @@ -0,0 +1,26 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.bridge.diffusion.models.wan.wan_layer_spec import get_wan_block_with_transformer_engine_spec + + +def test_get_wan_block_with_transformer_engine_spec_basic(): + spec = get_wan_block_with_transformer_engine_spec() + # Basic structure checks + assert hasattr(spec, "module") + assert hasattr(spec, "submodules") + sub = spec.submodules + # Expected submodule fields exist + for name in ["norm1", "norm2", "norm3", "full_self_attention", "cross_attention", "mlp"]: + assert hasattr(sub, name), f"Missing submodule {name}" diff --git a/tests/unit_tests/diffusion/model/wan/test_wan_model_misc.py b/tests/unit_tests/diffusion/model/wan/test_wan_model_misc.py new file mode 100644 index 0000000000..e1b79d70a7 --- /dev/null +++ b/tests/unit_tests/diffusion/model/wan/test_wan_model_misc.py @@ -0,0 +1,25 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from megatron.bridge.diffusion.models.wan.wan_model import sinusoidal_embedding_1d + + +def test_sinusoidal_embedding_1d_shape_and_dtype(): + dim = 16 + pos = torch.arange(10, dtype=torch.float32) + emb = sinusoidal_embedding_1d(dim, pos) + assert emb.shape == (pos.shape[0], dim) + assert emb.dtype == torch.float32 diff --git a/tests/unit_tests/diffusion/model/wan/test_wan_provider.py b/tests/unit_tests/diffusion/model/wan/test_wan_provider.py new file mode 100644 index 0000000000..878a3c9711 --- /dev/null +++ b/tests/unit_tests/diffusion/model/wan/test_wan_provider.py @@ -0,0 +1,84 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from megatron.core import parallel_state + +import megatron.bridge.diffusion.models.wan.wan_model as wan_model_module +from megatron.bridge.diffusion.models.wan.wan_model import WanModel +from megatron.bridge.diffusion.models.wan.wan_provider import WanModelProvider + + +def test_wan_model_provider_provide_returns_model(monkeypatch): + # Force pipeline stage booleans to avoid dependency on initialized model parallel + monkeypatch.setattr(parallel_state, "is_pipeline_first_stage", lambda: True, raising=False) + monkeypatch.setattr(parallel_state, "is_pipeline_last_stage", lambda: True, raising=False) + # Avoid querying uninitialized PP groups + monkeypatch.setattr(parallel_state, "get_pipeline_model_parallel_world_size", lambda: 1, raising=False) + + # Bypass Megatron's ProcessGroupCollection usage inside TransformerBlock during construction. + # CI does not initialize distributed groups; a dummy block suffices for construction checks. + class DummyTransformerBlock(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + self.input_tensor = None + + def set_input_tensor(self, input_tensor): + self.input_tensor = input_tensor + + def forward(self, hidden_states, **kwargs): + return hidden_states + + monkeypatch.setattr(wan_model_module, "TransformerBlock", DummyTransformerBlock, raising=False) + + provider = WanModelProvider( + num_layers=2, # keep small + hidden_size=64, + ffn_hidden_size=128, + num_attention_heads=4, + layernorm_epsilon=1e-6, + normalization="RMSNorm", + layernorm_zero_centered_gamma=False, + layernorm_across_heads=True, + add_qkv_bias=True, + rotary_interleaved=True, + hidden_dropout=0.0, + attention_dropout=0.0, + fp16_lm_cross_entropy=False, + parallel_output=True, + bf16=False, + params_dtype=torch.float32, + qkv_format="sbhd", + seq_length=128, + share_embeddings_and_output_weights=False, + vocab_size=32000, + make_vocab_size_divisible_by=128, + in_channels=4, + out_channels=4, + patch_spatial=2, + patch_temporal=1, + freq_dim=16, + text_len=32, + text_dim=64, + ) + # Ensure config supplies fields expected by core attention + provider.kv_channels = provider.hidden_size // provider.num_attention_heads + provider.num_query_groups = provider.num_attention_heads + model = provider.provide() + assert isinstance(model, WanModel) + # Sanity check key config properties were plumbed + assert model.config.hidden_size == 64 + assert model.config.num_attention_heads == 4 + assert model.config.text_dim == 64 diff --git a/tests/unit_tests/diffusion/model/wan/test_wan_step.py b/tests/unit_tests/diffusion/model/wan/test_wan_step.py new file mode 100644 index 0000000000..c1ad9662af --- /dev/null +++ b/tests/unit_tests/diffusion/model/wan/test_wan_step.py @@ -0,0 +1,66 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from megatron.bridge.diffusion.models.wan.wan_step import WanForwardStep, wan_data_step + + +class _DummyIter: + def __init__(self, batch): + # mimic attribute used inside wan_data_step + self.iterable = [batch] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="wan_data_step moves tensors to CUDA") +def test_wan_data_step_builds_packed_seq_params_cuda_guarded(): + # Construct minimal batch with required seq_len fields + # S=8, B=2 + batch = { + "seq_len_q": torch.tensor([3, 5], dtype=torch.int32), + "seq_len_q_padded": torch.tensor([4, 6], dtype=torch.int32), + "seq_len_kv": torch.tensor([2, 7], dtype=torch.int32), + "seq_len_kv_padded": torch.tensor([2, 8], dtype=torch.int32), + # include a tensor field to exercise device transfer + # shape: [S, B, H, D] + "video_latents": torch.randn(8, 2, 4, 16, dtype=torch.float32), + } + it = iter(_DummyIter(batch).iterable) + qkv_format = "sbhd" + out = wan_data_step(qkv_format, it) + + assert "packed_seq_params" in out + for k in ["self_attention", "cross_attention"]: + assert k in out["packed_seq_params"] + p = out["packed_seq_params"][k] + assert hasattr(p, "cu_seqlens_q") + assert hasattr(p, "cu_seqlens_q_padded") + assert hasattr(p, "cu_seqlens_kv") + assert hasattr(p, "cu_seqlens_kv_padded") + # spot-check CUDA device after move + assert out["video_latents"].is_cuda + # Verify transpose from (S, B, H, D) -> (B, S, H, D) + assert out["video_latents"].shape == (2, 8, 4, 16) + + +def test_wan_forward_step_loss_partial_creation(): + step = WanForwardStep() + mask = torch.ones(4, dtype=torch.float32) + loss_fn = step._create_loss_function(mask, check_for_nan_in_loss=False, check_for_spiky_loss=False) + # Just validate it's callable and is a functools.partial + import functools + + assert isinstance(loss_fn, functools.partial) + assert callable(loss_fn) diff --git a/tests/unit_tests/diffusion/recipes/flux/__init__.py b/tests/unit_tests/diffusion/recipes/flux/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit_tests/diffusion/recipes/flux/test_flux_recipe.py b/tests/unit_tests/diffusion/recipes/flux/test_flux_recipe.py new file mode 100644 index 0000000000..a2ea329415 --- /dev/null +++ b/tests/unit_tests/diffusion/recipes/flux/test_flux_recipe.py @@ -0,0 +1,156 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile + +import pytest + +from megatron.bridge.diffusion.data.flux.flux_mock_datamodule import FluxMockDataModuleConfig +from megatron.bridge.diffusion.models.flux.flux_provider import FluxProvider +from megatron.bridge.diffusion.recipes.flux.flux import model_config, pretrain_config +from megatron.bridge.training.config import ConfigContainer + + +pytestmark = [pytest.mark.unit] + + +class TestModelConfig: + """Tests for model_config function.""" + + def test_model_config_returns_flux_provider_with_defaults(self): + """Test that model_config returns a FluxProvider with correct defaults.""" + config = model_config() + + assert isinstance(config, FluxProvider) + + # Parallelism defaults + assert config.tensor_model_parallel_size == 1 + assert config.pipeline_model_parallel_size == 1 + assert config.sequence_parallel is False + + # FLUX-specific defaults + assert config.num_joint_layers == 19 + assert config.num_single_layers == 38 + assert config.hidden_size == 3072 + assert config.num_attention_heads == 24 + + def test_model_config_custom_parameters(self): + """Test model_config with custom parameters.""" + config = model_config( + tensor_parallelism=2, + pipeline_parallelism=4, + num_joint_layers=10, + num_single_layers=20, + hidden_size=2048, + guidance_embed=True, + ) + + assert config.tensor_model_parallel_size == 2 + assert config.pipeline_model_parallel_size == 4 + assert config.num_joint_layers == 10 + assert config.num_single_layers == 20 + assert config.hidden_size == 2048 + assert config.guidance_embed is True + + +class TestPretrainConfig: + """Tests for pretrain_config function.""" + + def test_pretrain_config_returns_complete_config(self): + """Test that pretrain_config returns a ConfigContainer with all required components.""" + with tempfile.TemporaryDirectory() as tmpdir: + config = pretrain_config(dir=tmpdir, mock=True) + + assert isinstance(config, ConfigContainer) + assert isinstance(config.model, FluxProvider) + assert isinstance(config.dataset, FluxMockDataModuleConfig) + + # Check all required components exist + assert hasattr(config, "train") + assert hasattr(config, "optimizer") + assert hasattr(config, "scheduler") + assert hasattr(config, "ddp") + assert hasattr(config, "logger") + assert hasattr(config, "checkpoint") + + def test_pretrain_config_directory_structure(self): + """Test that pretrain_config creates correct directory structure.""" + with tempfile.TemporaryDirectory() as tmpdir: + config = pretrain_config(dir=tmpdir, name="test_run", mock=True) + + assert "test_run" in config.checkpoint.save + assert "test_run" in config.logger.tensorboard_dir + assert config.checkpoint.save.endswith("checkpoints") + + def test_pretrain_config_custom_training_parameters(self): + """Test pretrain_config with custom training parameters.""" + with tempfile.TemporaryDirectory() as tmpdir: + config = pretrain_config( + dir=tmpdir, + mock=True, + train_iters=5000, + global_batch_size=8, + micro_batch_size=2, + lr=5e-5, + ) + + assert config.train.train_iters == 5000 + assert config.train.global_batch_size == 8 + assert config.train.micro_batch_size == 2 + + def test_pretrain_config_custom_model_parameters(self): + """Test that model parameters propagate correctly.""" + with tempfile.TemporaryDirectory() as tmpdir: + config = pretrain_config( + dir=tmpdir, + mock=True, + num_joint_layers=12, + hidden_size=2048, + guidance_embed=True, + tensor_parallelism=2, + ) + + assert config.model.num_joint_layers == 12 + assert config.model.hidden_size == 2048 + assert config.model.guidance_embed is True + assert config.model.tensor_model_parallel_size == 2 + + def test_pretrain_config_mock_dataset_configuration(self): + """Test pretrain_config with mock dataset parameters.""" + with tempfile.TemporaryDirectory() as tmpdir: + config = pretrain_config( + dir=tmpdir, + mock=True, + image_H=512, + image_W=512, + vae_channels=16, + ) + + assert config.dataset.image_H == 512 + assert config.dataset.image_W == 512 + assert config.dataset.vae_channels == 16 + + def test_pretrain_config_with_real_dataset(self): + """Test pretrain_config with real dataset configuration.""" + with tempfile.TemporaryDirectory() as tmpdir: + data_path = os.path.join(tmpdir, "data") + os.makedirs(data_path, exist_ok=True) + + config = pretrain_config(dir=tmpdir, mock=False, data_paths=[data_path]) + + from megatron.bridge.diffusion.data.flux.flux_energon_datamodule import FluxDataModuleConfig + + assert isinstance(config.dataset, FluxDataModuleConfig) + assert config.dataset.path == [data_path]