diff --git a/.github/workflows/config/.secrets.baseline b/.github/workflows/config/.secrets.baseline index c6860eec..83fa8afc 100644 --- a/.github/workflows/config/.secrets.baseline +++ b/.github/workflows/config/.secrets.baseline @@ -139,10 +139,10 @@ "filename": "examples/automodel/pretrain/cicd/wan21_cicd_nightly_video.yaml", "hashed_secret": "c70f071570ba65f9c4079d6051e955ff4f802eea", "is_verified": false, - "line_number": 61, + "line_number": 67, "is_secret": false } ] }, - "generated_at": "2026-01-29T18:49:17Z" + "generated_at": "2026-01-30T18:50:34Z" } diff --git a/dfm/src/automodel/_diffusers/__init__.py b/dfm/src/automodel/_diffusers/__init__.py index 45dbbf79..fc4cb5ce 100644 --- a/dfm/src/automodel/_diffusers/__init__.py +++ b/dfm/src/automodel/_diffusers/__init__.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .auto_diffusion_pipeline import NeMoAutoDiffusionPipeline +from .auto_diffusion_pipeline import NeMoAutoDiffusionPipeline, PipelineSpec __all__ = [ "NeMoAutoDiffusionPipeline", + "PipelineSpec", ] diff --git a/dfm/src/automodel/_diffusers/auto_diffusion_pipeline.py b/dfm/src/automodel/_diffusers/auto_diffusion_pipeline.py index 37bdebb9..d889a2ec 100644 --- a/dfm/src/automodel/_diffusers/auto_diffusion_pipeline.py +++ b/dfm/src/automodel/_diffusers/auto_diffusion_pipeline.py @@ -12,20 +12,51 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy +""" +NeMo Auto Diffusion Pipeline - Unified pipeline wrapper for all diffusion models. + +This module provides a single pipeline class that handles: +- Loading from pretrained weights (finetuning) via DiffusionPipeline auto-detection +- Loading from config with random weights (pretraining) via YAML-specified transformer class +- FSDP2/DDP parallelization for distributed training +- Gradient checkpointing for memory efficiency + +Usage: + # Finetuning (from_pretrained) - no pipeline_spec needed + pipe, managers = NeMoAutoDiffusionPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + load_for_training=True, + parallel_scheme={"transformer": manager_args}, + ) + + # Pretraining (from_config) - pipeline_spec required in YAML + pipe, managers = NeMoAutoDiffusionPipeline.from_config( + "black-forest-labs/FLUX.1-dev", + pipeline_spec={ + "transformer_cls": "FluxTransformer2DModel", + "subfolder": "transformer", + }, + parallel_scheme={"transformer": manager_args}, + ) +""" + import logging import os +from dataclasses import dataclass from typing import Any, Dict, Iterable, Optional, Tuple, Union import torch import torch.nn as nn -from diffusers import DiffusionPipeline, WanPipeline +from diffusers import DiffusionPipeline from nemo_automodel.components.distributed import parallelizer from nemo_automodel.components.distributed.ddp import DDPManager from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager from nemo_automodel.shared.utils import dtype_from_str -from dfm.src.automodel.distributed.dfm_parallelizer import HunyuanParallelizationStrategy, WanParallelizationStrategy +from dfm.src.automodel.distributed.dfm_parallelizer import ( + HunyuanParallelizationStrategy, + WanParallelizationStrategy, +) logger = logging.getLogger(__name__) @@ -34,12 +65,79 @@ ParallelManager = Union[FSDP2Manager, DDPManager] +@dataclass +class PipelineSpec: + """ + YAML-driven specification for loading a diffusion pipeline. 
+ + This is required for from_config (pretraining with random weights). + Not needed for from_pretrained (finetuning). + + Example YAML: + pipeline_spec: + transformer_cls: "FluxTransformer2DModel" + pipeline_cls: "FluxPipeline" # Optional + subfolder: "transformer" + load_full_pipeline: false + enable_gradient_checkpointing: true + """ + + # Required for from_config: transformer class name from diffusers + transformer_cls: str = "" + + # Optional: full pipeline class name (for loading VAE, text encoders, etc.) + pipeline_cls: Optional[str] = None + + # Subfolder for transformer weights in HF repo + subfolder: str = "transformer" + + # For from_config: whether to load full pipeline or just transformer + load_full_pipeline: bool = False + + # Training optimizations + enable_gradient_checkpointing: bool = True + low_cpu_mem_usage: bool = True + + @classmethod + def from_dict(cls, d: Optional[Dict[str, Any]]) -> "PipelineSpec": + """Create PipelineSpec from YAML dict.""" + if d is None: + return cls() + known_fields = {f.name for f in cls.__dataclass_fields__.values()} + filtered = {k: v for k, v in d.items() if k in known_fields} + return cls(**filtered) + + def validate_for_from_config(self): + """Validate spec has required fields for from_config.""" + if not self.transformer_cls: + raise ValueError( + "pipeline_spec.transformer_cls is required for from_config. " + "Example YAML:\n" + " pipeline_spec:\n" + " transformer_cls: 'FluxTransformer2DModel'\n" + " subfolder: 'transformer'" + ) + + +def _import_diffusers_class(class_name: str): + """Dynamically import a class from diffusers by name.""" + import diffusers + + if not hasattr(diffusers, class_name): + raise ImportError( + f"Class '{class_name}' not found in diffusers. Check pipeline_spec.transformer_cls in your YAML config." + ) + return getattr(diffusers, class_name) + + def _init_parallelizer(): + """Register custom parallelization strategies.""" parallelizer.PARALLELIZATION_STRATEGIES["WanTransformer3DModel"] = WanParallelizationStrategy() parallelizer.PARALLELIZATION_STRATEGIES["HunyuanVideo15Transformer3DModel"] = HunyuanParallelizationStrategy() def _choose_device(device: Optional[torch.device]) -> torch.device: + """Choose device, defaulting to CUDA with LOCAL_RANK if available.""" if device is not None: return device if torch.cuda.is_available(): @@ -48,7 +146,8 @@ def _choose_device(device: Optional[torch.device]) -> torch.device: return torch.device("cpu") -def _iter_pipeline_modules(pipe: DiffusionPipeline) -> Iterable[Tuple[str, nn.Module]]: +def _iter_pipeline_modules(pipe) -> Iterable[Tuple[str, nn.Module]]: + """Iterate over nn.Module components in a pipeline.""" # Prefer Diffusers' components registry when available if hasattr(pipe, "components") and isinstance(pipe.components, dict): for name, value in pipe.components.items(): @@ -69,7 +168,7 @@ def _iter_pipeline_modules(pipe: DiffusionPipeline) -> Iterable[Tuple[str, nn.Mo def _move_module_to_device(module: nn.Module, device: torch.device, torch_dtype: Any) -> None: - # torch_dtype can be "auto", torch.dtype, or string + """Move module to device with specified dtype.""" dtype: Optional[torch.dtype] if torch_dtype == "auto": dtype = None @@ -85,8 +184,7 @@ def _ensure_params_trainable(module: nn.Module, module_name: Optional[str] = Non """ Ensure that all parameters in the given module are trainable. - Returns the number of parameters marked trainable. If a module name is - provided, it will be used in the log message for clarity. 
+ Returns the number of parameters marked trainable. """ num_trainable_parameters = 0 for parameter in module.parameters(): @@ -131,21 +229,77 @@ def _create_parallel_manager(manager_args: Dict[str, Any]) -> ParallelManager: raise ValueError(f"Unknown manager type: '{manager_type}'. Expected 'ddp' or 'fsdp2'.") -class NeMoAutoDiffusionPipeline(DiffusionPipeline): +def _apply_parallelization( + pipe, + parallel_scheme: Optional[Dict[str, Dict[str, Any]]], +) -> Dict[str, ParallelManager]: + """Apply FSDP2/DDP parallelization to pipeline components.""" + created_managers: Dict[str, ParallelManager] = {} + if parallel_scheme is None: + return created_managers + + assert torch.distributed.is_initialized(), "Distributed environment must be initialized for parallelization" + _init_parallelizer() + + for comp_name, comp_module in _iter_pipeline_modules(pipe): + manager_args = parallel_scheme.get(comp_name) + if manager_args is None: + continue + logger.info("[INFO] Applying parallelization to %s", comp_name) + manager = _create_parallel_manager(manager_args) + created_managers[comp_name] = manager + parallel_module = manager.parallelize(comp_module) + setattr(pipe, comp_name, parallel_module) + + return created_managers + + +class NeMoAutoDiffusionPipeline: """ - Drop-in Diffusers pipeline that adds optional FSDP2/DDP parallelization during from_pretrained. + Unified diffusion pipeline wrapper for all model types. + + This class serves dual purposes: + 1. Provides class methods (from_pretrained, from_config) for loading pipelines + 2. Acts as a minimal wrapper when load_full_pipeline=False (transformer-only mode) + + Two loading paths: + - from_pretrained: Uses DiffusionPipeline auto-detection (for finetuning) + No pipeline_spec needed - pipeline type is auto-detected from model_index.json + + - from_config: Uses YAML-specified transformer class (for pretraining) + Requires pipeline_spec with transformer_cls in YAML config Features: - Accepts a per-component mapping from component name to parallel manager init args - Moves all nn.Module components to the chosen device/dtype - Parallelizes only components present in the mapping by constructing a manager per component - Supports both FSDP2Manager and DDPManager via '_manager_type' key in config + - Gradient checkpointing support for memory efficiency parallel_scheme: - Dict[str, Dict[str, Any]]: component name -> kwargs for parallel manager - Each component's kwargs should include '_manager_type': 'fsdp2' or 'ddp' (defaults to 'fsdp2') """ + def __init__(self, transformer=None, **components): + """ + Initialize NeMoAutoDiffusionPipeline. + + Args: + transformer: The transformer model instance + **components: Additional pipeline components (vae, text_encoder, etc.) 
+ """ + self.transformer = transformer + for k, v in components.items(): + setattr(self, k, v) + # Create components dict for compatibility with _iter_pipeline_modules + self._components = {"transformer": transformer, **components} + + @property + def components(self) -> Dict[str, Any]: + """Return components dict for compatibility.""" + return {k: v for k, v in self._components.items() if v is not None} + @classmethod def from_pretrained( cls, @@ -153,18 +307,48 @@ def from_pretrained( *model_args, parallel_scheme: Optional[Dict[str, Dict[str, Any]]] = None, device: Optional[torch.device] = None, - torch_dtype: Any = "auto", + torch_dtype: Any = torch.bfloat16, move_to_device: bool = True, load_for_training: bool = False, components_to_load: Optional[Iterable[str]] = None, + enable_gradient_checkpointing: bool = True, **kwargs, - ) -> tuple[DiffusionPipeline, Dict[str, ParallelManager]]: + ) -> Tuple[DiffusionPipeline, Dict[str, ParallelManager]]: + """ + Load pipeline from pretrained weights using DiffusionPipeline auto-detection. + + This method auto-detects the pipeline type from model_index.json and loads + all components. Use this for finetuning existing models. + + No pipeline_spec is needed - the pipeline type is determined automatically. + + Args: + pretrained_model_name_or_path: HuggingFace model ID or local path + parallel_scheme: Dict mapping component names to parallel manager kwargs. + Each component's kwargs should include '_manager_type': 'fsdp2' or 'ddp' + device: Device to load model to + torch_dtype: Data type for model parameters + move_to_device: Whether to move modules to device + load_for_training: Whether to make parameters trainable + components_to_load: Which components to process (default: all) + enable_gradient_checkpointing: Enable gradient checkpointing for transformer + **kwargs: Additional arguments passed to DiffusionPipeline.from_pretrained + + Returns: + Tuple of (DiffusionPipeline, Dict[str, ParallelManager]) + """ + logger.info("[INFO] Loading pipeline from pretrained: %s", pretrained_model_name_or_path) + + # Use DiffusionPipeline.from_pretrained for auto-detection pipe: DiffusionPipeline = DiffusionPipeline.from_pretrained( pretrained_model_name_or_path, *model_args, torch_dtype=torch_dtype, **kwargs, ) + + logger.info("[INFO] Loaded pipeline type: %s", type(pipe).__name__) + # Decide device dev = _choose_device(device) @@ -175,6 +359,12 @@ def from_pretrained( logger.info("[INFO] Moving module: %s to device/dtype", name) _move_module_to_device(module, dev, torch_dtype) + # Enable gradient checkpointing if configured + if enable_gradient_checkpointing: + if hasattr(pipe, "transformer") and hasattr(pipe.transformer, "enable_gradient_checkpointing"): + pipe.transformer.enable_gradient_checkpointing() + logger.info("[INFO] Enabled gradient checkpointing for transformer") + # If loading for training, ensure the target module parameters are trainable if load_for_training: for name, module in _iter_pipeline_modules(pipe): @@ -182,85 +372,106 @@ def from_pretrained( logger.info("[INFO] Ensuring params trainable: %s", name) _ensure_params_trainable(module, module_name=name) - # Use per-component manager init-args to parallelize components - created_managers: Dict[str, ParallelManager] = {} - if parallel_scheme is not None: - assert torch.distributed.is_initialized(), "Expect distributed environment to be initialized" - _init_parallelizer() - for comp_name, comp_module in _iter_pipeline_modules(pipe): - manager_args = parallel_scheme.get(comp_name) - if 
manager_args is None: - continue - manager = _create_parallel_manager(manager_args) - created_managers[comp_name] = manager - parallel_module = manager.parallelize(comp_module) - setattr(pipe, comp_name, parallel_module) - return pipe, created_managers + # Apply parallelization (FSDP2 or DDP) + created_managers = _apply_parallelization(pipe, parallel_scheme) - -class NeMoWanPipeline: - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - return NeMoAutoDiffusionPipeline.from_pretrained(*args, **kwargs) + return pipe, created_managers @classmethod def from_config( cls, - model_id, + model_id: str, + pipeline_spec: Dict[str, Any], torch_dtype: torch.dtype = torch.bfloat16, - config: dict = None, - parallel_scheme: Optional[Dict[str, Dict[str, Any]]] = None, device: Optional[torch.device] = None, + parallel_scheme: Optional[Dict[str, Dict[str, Any]]] = None, move_to_device: bool = True, components_to_load: Optional[Iterable[str]] = None, - ) -> tuple[WanPipeline, Dict[str, ParallelManager]]: - # Load just the config - from diffusers import WanTransformer3DModel - - if config is None: - transformer = WanTransformer3DModel.from_pretrained( - model_id, - subfolder="transformer", - torch_dtype=torch.bfloat16, - ) - - # Get config and reinitialize with random weights - config = copy.deepcopy(transformer.config) - del transformer - - # Initialize with random weights - transformer = WanTransformer3DModel.from_config(config) + **kwargs, + ) -> Tuple["NeMoAutoDiffusionPipeline", Dict[str, ParallelManager]]: + """ + Initialize pipeline with random weights using YAML-specified transformer class. + + This method uses the transformer_cls from pipeline_spec to create a model + with random weights. Use this for pretraining from scratch. + + Requires pipeline_spec in YAML config with at least: + pipeline_spec: + transformer_cls: "FluxTransformer2DModel" # or WanTransformer3DModel, etc. + subfolder: "transformer" + + Args: + model_id: HuggingFace model ID or local path (for loading config) + pipeline_spec: Dict from YAML config with transformer_cls, subfolder, etc. 
+ torch_dtype: Data type for model parameters + device: Device to load model to + parallel_scheme: Dict mapping component names to parallel manager kwargs + move_to_device: Whether to move modules to device + components_to_load: Which components to process (default: all) + **kwargs: Additional arguments + + Returns: + Tuple of (NeMoAutoDiffusionPipeline or DiffusionPipeline, Dict[str, ParallelManager]) + """ + # Parse and validate pipeline spec + spec = PipelineSpec.from_dict(pipeline_spec) + spec.validate_for_from_config() + + logger.info("[INFO] Initializing pipeline from config with random weights") + logger.info("[INFO] Model ID: %s", model_id) + logger.info("[INFO] Transformer class: %s", spec.transformer_cls) + + # Dynamically import transformer class from diffusers + TransformerCls = _import_diffusers_class(spec.transformer_cls) + + # Load config from the model_id + logger.info("[INFO] Loading config from %s/%s", model_id, spec.subfolder) + config = TransformerCls.load_config(model_id, subfolder=spec.subfolder) + + # Initialize transformer with random weights + logger.info("[INFO] Creating %s with random weights", spec.transformer_cls) + transformer = TransformerCls.from_config(config) + transformer = transformer.to(torch_dtype) - # Load pipeline with random transformer - pipe = WanPipeline.from_pretrained( - model_id, - transformer=transformer, - torch_dtype=torch_dtype, - ) # Decide device dev = _choose_device(device) - # Move modules to device/dtype first (helps avoid initial OOM during sharding) - if move_to_device: - for name, module in _iter_pipeline_modules(pipe): - if not components_to_load or name in components_to_load: - logger.info("[INFO] Moving module: %s to device/dtype", name) - _move_module_to_device(module, dev, torch_dtype) + # Either load full pipeline or just use transformer + if spec.load_full_pipeline and spec.pipeline_cls: + # Load full pipeline with random transformer injected + PipelineCls = _import_diffusers_class(spec.pipeline_cls) + logger.info("[INFO] Loading full pipeline %s with random transformer", spec.pipeline_cls) + pipe = PipelineCls.from_pretrained( + model_id, + transformer=transformer, + torch_dtype=torch_dtype, + ) + + # Move all modules to device + if move_to_device: + for name, module in _iter_pipeline_modules(pipe): + if not components_to_load or name in components_to_load: + logger.info("[INFO] Moving module: %s to device/dtype", name) + _move_module_to_device(module, dev, torch_dtype) + else: + # Transformer only mode - use this class as minimal wrapper + if move_to_device: + transformer = transformer.to(dev) + pipe = cls(transformer=transformer) + + # Enable gradient checkpointing if configured + if spec.enable_gradient_checkpointing: + target_transformer = getattr(pipe, "transformer", transformer) + if hasattr(target_transformer, "enable_gradient_checkpointing"): + target_transformer.enable_gradient_checkpointing() + logger.info("[INFO] Enabled gradient checkpointing for transformer") + + # Make parameters trainable (always true for from_config / pretraining) + for name, module in _iter_pipeline_modules(pipe): + if not components_to_load or name in components_to_load: + _ensure_params_trainable(module, module_name=name) + + # Apply parallelization (FSDP2 or DDP) + created_managers = _apply_parallelization(pipe, parallel_scheme) - # Use per-component manager init-args to parallelize components - created_managers: Dict[str, ParallelManager] = {} - if parallel_scheme is not None: - assert torch.distributed.is_initialized(), "Expect 
distributed environment to be initialized" - _init_parallelizer() - for comp_name, comp_module in _iter_pipeline_modules(pipe): - manager_args = parallel_scheme.get(comp_name) - if manager_args is None: - continue - manager = _create_parallel_manager(manager_args) - created_managers[comp_name] = manager - parallel_module = manager.parallelize(comp_module) - setattr(pipe, comp_name, parallel_module) return pipe, created_managers diff --git a/dfm/src/automodel/datasets/multiresolutionDataloader/__init__.py b/dfm/src/automodel/datasets/multiresolutionDataloader/__init__.py index cbddf1ce..fac6b8d4 100644 --- a/dfm/src/automodel/datasets/multiresolutionDataloader/__init__.py +++ b/dfm/src/automodel/datasets/multiresolutionDataloader/__init__.py @@ -17,6 +17,10 @@ build_multiresolution_dataloader, collate_fn_production, ) +from .flux_collate import ( + build_flux_multiresolution_dataloader, + collate_fn_flux, +) from .multi_tier_bucketing import MultiTierBucketCalculator from .text_to_image_dataset import TextToImageDataset @@ -27,4 +31,7 @@ "SequentialBucketSampler", "build_multiresolution_dataloader", "collate_fn_production", + # Flux-specific + "build_flux_multiresolution_dataloader", + "collate_fn_flux", ] diff --git a/dfm/src/automodel/datasets/multiresolutionDataloader/dataloader.py b/dfm/src/automodel/datasets/multiresolutionDataloader/dataloader.py index 7fb39d71..e535f378 100644 --- a/dfm/src/automodel/datasets/multiresolutionDataloader/dataloader.py +++ b/dfm/src/automodel/datasets/multiresolutionDataloader/dataloader.py @@ -107,7 +107,6 @@ def __init__( logger.info( f" Base batch size: {base_batch_size}" + (f" @ {base_resolution}" if dynamic_batch_size else " (fixed)") ) - logger.info(f" DDP: rank {self.rank} of {self.num_replicas}") def _get_batch_size(self, resolution: Tuple[int, int]) -> int: """Get batch size for resolution (dynamic or fixed based on setting).""" @@ -258,8 +257,8 @@ def collate_fn_production(batch: List[Dict]) -> Dict: # Handle text encodings if "clip_hidden" in batch[0]: output["clip_hidden"] = torch.stack([item["clip_hidden"] for item in batch]) - output["clip_pooled"] = torch.stack([item["clip_pooled"] for item in batch]) - output["t5_hidden"] = torch.stack([item["t5_hidden"] for item in batch]) + output["pooled_prompt_embeds"] = torch.stack([item["pooled_prompt_embeds"] for item in batch]) + output["prompt_embeds"] = torch.stack([item["prompt_embeds"] for item in batch]) else: output["clip_tokens"] = torch.stack([item["clip_tokens"] for item in batch]) output["t5_tokens"] = torch.stack([item["t5_tokens"] for item in batch]) diff --git a/dfm/src/automodel/datasets/multiresolutionDataloader/flux_collate.py b/dfm/src/automodel/datasets/multiresolutionDataloader/flux_collate.py new file mode 100644 index 00000000..8c9e4519 --- /dev/null +++ b/dfm/src/automodel/datasets/multiresolutionDataloader/flux_collate.py @@ -0,0 +1,165 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Flux-compatible collate function that wraps the multiresolution dataloader output +to match the FlowMatchingPipeline expected batch format. +""" + +import logging +from typing import Dict, List, Tuple + +from torch.utils.data import DataLoader + +from dfm.src.automodel.datasets.multiresolutionDataloader.dataloader import ( + SequentialBucketSampler, + collate_fn_production, +) +from dfm.src.automodel.datasets.multiresolutionDataloader.text_to_image_dataset import TextToImageDataset + + +logger = logging.getLogger(__name__) + + +def collate_fn_flux(batch: List[Dict]) -> Dict: + """ + Flux-compatible collate function that transforms multiresolution batch output + to match FlowMatchingPipeline expected format. + + Args: + batch: List of samples from TextToImageDataset + + Returns: + Dict compatible with FlowMatchingPipeline.step() + """ + # First, use the production collate to stack tensors + production_batch = collate_fn_production(batch) + + # Keep latent as 4D [B, C, H, W] for Flux (image model, not video) + latent = production_batch["latent"] + + # Use "image_latents" key for 4D tensors (FluxAdapter expects 4D) + flux_batch = { + "image_latents": latent, + "data_type": "image", + "metadata": { + "prompts": production_batch.get("prompt", []), + "image_paths": production_batch.get("image_path", []), + "bucket_ids": production_batch.get("bucket_id", []), + "aspect_ratios": production_batch.get("aspect_ratio", []), + "crop_resolution": production_batch.get("crop_resolution"), + "original_resolution": production_batch.get("original_resolution"), + "crop_offset": production_batch.get("crop_offset"), + }, + } + + # Handle text embeddings (pre-encoded vs tokenized) + if "prompt_embeds" in production_batch: + # Pre-encoded text embeddings + flux_batch["text_embeddings"] = production_batch["prompt_embeds"] + flux_batch["pooled_prompt_embeds"] = production_batch["pooled_prompt_embeds"] + # Also include CLIP hidden for models that need it + if "clip_hidden" in production_batch: + flux_batch["clip_hidden"] = production_batch["clip_hidden"] + else: + # Tokenized - need to encode during training (not supported yet) + flux_batch["t5_tokens"] = production_batch["t5_tokens"] + flux_batch["clip_tokens"] = production_batch["clip_tokens"] + raise NotImplementedError( + "On-the-fly text encoding not yet supported. Please use pre-encoded text embeddings in your dataset." + ) + + return flux_batch + + +def build_flux_multiresolution_dataloader( + *, + # TextToImageDataset parameters + cache_dir: str, + train_text_encoder: bool = False, + # Dataloader parameters + batch_size: int = 1, + dp_rank: int = 0, + dp_world_size: int = 1, + base_resolution: Tuple[int, int] = (256, 256), + drop_last: bool = True, + shuffle: bool = True, + dynamic_batch_size: bool = False, + num_workers: int = 4, + pin_memory: bool = True, + prefetch_factor: int = 2, +) -> Tuple[DataLoader, SequentialBucketSampler]: + """ + Build a Flux-compatible multiresolution dataloader for TrainDiffusionRecipe. + + This wraps the existing TextToImageDataset and SequentialBucketSampler + with a Flux-compatible collate function. 
+ + Args: + cache_dir: Directory containing preprocessed cache (metadata.json, shards, and resolution subdirs) + train_text_encoder: If True, returns tokens instead of embeddings + batch_size: Batch size per GPU + dp_rank: Data parallel rank + dp_world_size: Data parallel world size + base_resolution: Base resolution for dynamic batch sizing + drop_last: Drop incomplete batches + shuffle: Shuffle data + dynamic_batch_size: Scale batch size by resolution + num_workers: DataLoader workers + pin_memory: Pin memory for GPU transfer + prefetch_factor: Prefetch batches per worker + + Returns: + Tuple of (DataLoader, SequentialBucketSampler) + """ + logger.info("Building Flux multiresolution dataloader:") + logger.info(f" cache_dir: {cache_dir}") + logger.info(f" train_text_encoder: {train_text_encoder}") + logger.info(f" batch_size: {batch_size}") + logger.info(f" dp_rank: {dp_rank}, dp_world_size: {dp_world_size}") + + # Create dataset + dataset = TextToImageDataset( + cache_dir=cache_dir, + train_text_encoder=train_text_encoder, + ) + + # Create sampler + sampler = SequentialBucketSampler( + dataset, + base_batch_size=batch_size, + base_resolution=base_resolution, + drop_last=drop_last, + shuffle_buckets=shuffle, + shuffle_within_bucket=shuffle, + dynamic_batch_size=dynamic_batch_size, + num_replicas=dp_world_size, + rank=dp_rank, + ) + + # Create dataloader with Flux-compatible collate + dataloader = DataLoader( + dataset, + batch_sampler=sampler, + collate_fn=collate_fn_flux, # Use Flux-compatible collate + num_workers=num_workers, + pin_memory=pin_memory, + prefetch_factor=prefetch_factor if num_workers > 0 else None, + persistent_workers=num_workers > 0, + ) + + logger.info(f" Dataset size: {len(dataset)}") + logger.info(f" Batches per epoch: {len(sampler)}") + + return dataloader, sampler diff --git a/dfm/src/automodel/datasets/multiresolutionDataloader/text_to_image_dataset.py b/dfm/src/automodel/datasets/multiresolutionDataloader/text_to_image_dataset.py index c691f682..fb94f8a5 100644 --- a/dfm/src/automodel/datasets/multiresolutionDataloader/text_to_image_dataset.py +++ b/dfm/src/automodel/datasets/multiresolutionDataloader/text_to_image_dataset.py @@ -156,7 +156,7 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: output["t5_tokens"] = data["t5_tokens"].squeeze(0) else: output["clip_hidden"] = data["clip_hidden"].squeeze(0) - output["clip_pooled"] = data["clip_pooled"].squeeze(0) - output["t5_hidden"] = data["t5_hidden"].squeeze(0) + output["pooled_prompt_embeds"] = data["pooled_prompt_embeds"].squeeze(0) + output["prompt_embeds"] = data["prompt_embeds"].squeeze(0) return output diff --git a/dfm/src/automodel/flow_matching/adapters/__init__.py b/dfm/src/automodel/flow_matching/adapters/__init__.py index 15cffef5..eccfe579 100644 --- a/dfm/src/automodel/flow_matching/adapters/__init__.py +++ b/dfm/src/automodel/flow_matching/adapters/__init__.py @@ -22,15 +22,17 @@ - ModelAdapter: Abstract base class for all adapters - HunyuanAdapter: For HunyuanVideo 1.5 style models - SimpleAdapter: For simple transformer models (e.g., Wan) +- FluxAdapter: For FLUX.1 text-to-image models Usage: - from automodel.flow_matching.adapters import HunyuanAdapter, SimpleAdapter + from automodel.flow_matching.adapters import HunyuanAdapter, SimpleAdapter, FluxAdapter # Or import the base class to create custom adapters from automodel.flow_matching.adapters import ModelAdapter """ from .base import FlowMatchingContext, ModelAdapter +from .flux import FluxAdapter from .hunyuan import HunyuanAdapter 
from .simple import SimpleAdapter @@ -38,6 +40,7 @@ __all__ = [ "FlowMatchingContext", "ModelAdapter", + "FluxAdapter", "HunyuanAdapter", "SimpleAdapter", ] diff --git a/dfm/src/automodel/flow_matching/adapters/base.py b/dfm/src/automodel/flow_matching/adapters/base.py index d9b117af..a8a1def4 100644 --- a/dfm/src/automodel/flow_matching/adapters/base.py +++ b/dfm/src/automodel/flow_matching/adapters/base.py @@ -36,20 +36,23 @@ class FlowMatchingContext: without coupling to the batch dictionary structure. Attributes: - noisy_latents: [B, C, F, H, W] - Noisy latents after interpolation - video_latents: [B, C, F, H, W] - Original clean latents + noisy_latents: [B, C, F, H, W] or [B, C, H, W] - Noisy latents after interpolation + latents: [B, C, F, H, W] for video or [B, C, H, W] for image - Original clean latents + (also accessible via deprecated 'video_latents' property for backward compatibility) timesteps: [B] - Sampled timesteps sigma: [B] - Sigma values task_type: "t2v" or "i2v" data_type: "video" or "image" device: Device for tensor operations dtype: Data type for tensor operations + cfg_dropout_prob: Probability of dropping text embeddings (setting to 0) during + training for classifier-free guidance (CFG). Defaults to 0.0 for backward compatibility. batch: Original batch dictionary (for model-specific data) """ # Core tensors noisy_latents: torch.Tensor - video_latents: torch.Tensor + latents: torch.Tensor timesteps: torch.Tensor sigma: torch.Tensor @@ -64,6 +67,14 @@ class FlowMatchingContext: # Original batch (for model-specific data) batch: Dict[str, Any] + # CFG dropout probability (optional with default for backward compatibility) + cfg_dropout_prob: float = 0.0 + + @property + def video_latents(self) -> torch.Tensor: + """Backward compatibility alias for 'latents' field.""" + return self.latents + class ModelAdapter(ABC): """ diff --git a/dfm/src/automodel/flow_matching/adapters/flux.py b/dfm/src/automodel/flow_matching/adapters/flux.py new file mode 100644 index 00000000..4d05f464 --- /dev/null +++ b/dfm/src/automodel/flow_matching/adapters/flux.py @@ -0,0 +1,222 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Flux model adapter for FlowMatching Pipeline. + +This adapter supports FLUX.1 style models with: +- T5 text embeddings (text_embeddings) +- CLIP pooled embeddings (pooled_prompt_embeds) +- 2D image latents (treated as 1-frame video: [B, C, 1, H, W]) +""" + +import random +from typing import Any, Dict + +import torch +import torch.nn as nn + +from .base import FlowMatchingContext, ModelAdapter + + +class FluxAdapter(ModelAdapter): + """ + Model adapter for FLUX.1 image generation models. 
+ + Supports batch format from multiresolution dataloader: + - image_latents: [B, C, H, W] for images + - text_embeddings: T5 embeddings [B, seq_len, 4096] + - pooled_prompt_embeds: CLIP pooled [B, 768] + + FLUX model forward interface: + - hidden_states: Packed latents + - encoder_hidden_states: T5 text embeddings + - pooled_projections: CLIP pooled embeddings + - timestep: Normalized timesteps [0, 1] + - img_ids / txt_ids: Positional embeddings + """ + + def __init__( + self, + guidance_scale: float = 3.5, + use_guidance_embeds: bool = True, + ): + """ + Initialize FluxAdapter. + + Args: + guidance_scale: Guidance scale for classifier-free guidance + use_guidance_embeds: Whether to use guidance embeddings + """ + self.guidance_scale = guidance_scale + self.use_guidance_embeds = use_guidance_embeds + + def _pack_latents(self, latents: torch.Tensor) -> torch.Tensor: + """ + Pack latents from [B, C, H, W] to Flux format [B, (H//2)*(W//2), C*4]. + + Flux uses a 2x2 patch embedding, so latents are reshaped accordingly. + """ + b, c, h, w = latents.shape + # Reshape: [B, C, H, W] -> [B, C, H//2, 2, W//2, 2] + latents = latents.view(b, c, h // 2, 2, w // 2, 2) + # Permute: -> [B, H//2, W//2, C, 2, 2] + latents = latents.permute(0, 2, 4, 1, 3, 5) + # Reshape: -> [B, (H//2)*(W//2), C*4] + latents = latents.reshape(b, (h // 2) * (w // 2), c * 4) + return latents + + @staticmethod + def _unpack_latents(latents: torch.Tensor, height: int, width: int, vae_scale_factor: int = 8) -> torch.Tensor: + """ + Unpack latents from Flux format back to [B, C, H, W]. + + Args: + latents: Packed latents of shape [B, num_patches, channels] + height: Original image height in pixels + width: Original image width in pixels + vae_scale_factor: VAE compression factor (default: 8) + """ + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def _prepare_latent_image_ids( + self, + batch_size: int, + height: int, + width: int, + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor: + """ + Prepare positional IDs for image latents. + + Returns tensor of shape [B, (H//2)*(W//2), 3] containing (batch_idx, y, x). + """ + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = torch.arange(width // 2)[None, :] + + latent_image_ids = latent_image_ids.reshape(-1, 3) + return latent_image_ids.to(device=device, dtype=dtype) + + def prepare_inputs(self, context: FlowMatchingContext) -> Dict[str, Any]: + """ + Prepare inputs for Flux model from FlowMatchingContext. 
+ + Expects 4D image latents: [B, C, H, W] + """ + batch = context.batch + device = context.device + dtype = context.dtype + + # Flux only supports 4D image latents [B, C, H, W] + noisy_latents = context.noisy_latents + if noisy_latents.ndim != 4: + raise ValueError(f"FluxAdapter expects 4D latents [B, C, H, W], got {noisy_latents.ndim}D") + + batch_size, channels, height, width = noisy_latents.shape + + # Get text embeddings (T5) + text_embeddings = batch["text_embeddings"].to(device, dtype=dtype) + if text_embeddings.ndim == 2: + text_embeddings = text_embeddings.unsqueeze(0) + + # Get pooled embeddings (CLIP) - may or may not be present + if "pooled_prompt_embeds" in batch: + pooled_projections = batch["pooled_prompt_embeds"].to(device, dtype=dtype) + elif "clip_pooled" in batch: + pooled_projections = batch["clip_pooled"].to(device, dtype=dtype) + else: + # Create zero embeddings if not provided + pooled_projections = torch.zeros(batch_size, 768, device=device, dtype=dtype) + + if pooled_projections.ndim == 1: + pooled_projections = pooled_projections.unsqueeze(0) + + if random.random() < context.cfg_dropout_prob: + text_embeddings = torch.zeros_like(text_embeddings) + pooled_projections = torch.zeros_like(pooled_projections) + + # Pack latents for Flux transformer + packed_latents = self._pack_latents(noisy_latents) + + # Prepare positional IDs + img_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + # Text positional IDs + text_seq_len = text_embeddings.shape[1] + txt_ids = torch.zeros(batch_size, text_seq_len, 3, device=device, dtype=dtype) + + # Timesteps - Flux expects normalized [0, 1] range + # The pipeline provides timesteps in [0, num_train_timesteps] + timesteps = context.timesteps.to(dtype) / 1000.0 + + guidance = torch.full((batch_size,), 3.5, device=device, dtype=torch.float32) + + inputs = { + "hidden_states": packed_latents, + "encoder_hidden_states": text_embeddings, + "pooled_projections": pooled_projections, + "timestep": timesteps, + "img_ids": img_ids, + "txt_ids": txt_ids, + # Store original shape for unpacking + "_original_shape": (batch_size, channels, height, width), + "guidance": guidance, + } + + return inputs + + def forward(self, model: nn.Module, inputs: Dict[str, Any]) -> torch.Tensor: + """ + Execute forward pass for Flux model. + + Returns unpacked prediction in [B, C, H, W] format. 
+ """ + original_shape = inputs.pop("_original_shape") + batch_size, channels, height, width = original_shape + + # Flux forward pass + model_pred = model( + hidden_states=inputs["hidden_states"], + encoder_hidden_states=inputs["encoder_hidden_states"], + pooled_projections=inputs["pooled_projections"], + timestep=inputs["timestep"], + img_ids=inputs["img_ids"], + txt_ids=inputs["txt_ids"], + guidance=inputs["guidance"], + return_dict=False, + ) + + # Handle tuple output + pred = self.post_process_prediction(model_pred) + + # Unpack from Flux format back to [B, C, H, W] + # Pass pixel dimensions (latent * vae_scale_factor) to _unpack_latents + vae_scale_factor = 8 + pred = self._unpack_latents(pred, height * vae_scale_factor, width * vae_scale_factor) + + return pred diff --git a/dfm/src/automodel/flow_matching/adapters/hunyuan.py b/dfm/src/automodel/flow_matching/adapters/hunyuan.py index c60f8bbd..240fd3ca 100644 --- a/dfm/src/automodel/flow_matching/adapters/hunyuan.py +++ b/dfm/src/automodel/flow_matching/adapters/hunyuan.py @@ -142,7 +142,7 @@ def prepare_inputs(self, context: FlowMatchingContext) -> Dict[str, Any]: # Prepare latents (with or without condition) if self.use_condition_latents: - cond_latents = self.get_condition_latents(context.video_latents, context.task_type) + cond_latents = self.get_condition_latents(context.latents, context.task_type) latents = torch.cat([context.noisy_latents, cond_latents], dim=1) else: latents = context.noisy_latents diff --git a/dfm/src/automodel/flow_matching/flow_matching_pipeline.py b/dfm/src/automodel/flow_matching/flow_matching_pipeline.py index 89ab621b..c0c20573 100644 --- a/dfm/src/automodel/flow_matching/flow_matching_pipeline.py +++ b/dfm/src/automodel/flow_matching/flow_matching_pipeline.py @@ -39,6 +39,7 @@ # Import adapters from the adapters module from .adapters import ( FlowMatchingContext, + FluxAdapter, HunyuanAdapter, ModelAdapter, SimpleAdapter, @@ -114,6 +115,7 @@ def __init__( timestep_sampling: str = "logit_normal", flow_shift: float = 3.0, i2v_prob: float = 0.3, + cfg_dropout_prob: float = 0.1, # Logit-normal distribution parameters logit_mean: float = 0.0, logit_std: float = 1.0, @@ -143,6 +145,7 @@ def __init__( - "mix": Mix of lognorm and uniform flow_shift: Shift parameter for timestep transformation i2v_prob: Probability of using image-to-video conditioning + cfg_dropout_prob: Probability of dropping text embeddings for CFG training logit_mean: Mean for logit-normal distribution logit_std: Std for logit-normal distribution mix_uniform_ratio: Ratio of uniform samples when using mix @@ -158,6 +161,7 @@ def __init__( self.timestep_sampling = timestep_sampling self.flow_shift = flow_shift self.i2v_prob = i2v_prob + self.cfg_dropout_prob = cfg_dropout_prob self.logit_mean = logit_mean self.logit_std = logit_std self.mix_uniform_ratio = mix_uniform_ratio @@ -262,8 +266,8 @@ def compute_loss( model_pred: torch.Tensor, target: torch.Tensor, sigma: torch.Tensor, - batch: Dict[str, Any], - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + batch: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Compute flow matching loss with optional weighting. 
@@ -273,14 +277,18 @@ def compute_loss( model_pred: Model prediction target: Target (velocity = noise - clean) sigma: Sigma values for each sample + batch: Optional batch dictionary containing loss_mask Returns: - weighted_loss: Final loss to backprop - unweighted_loss: Raw MSE loss + weighted_loss: Per-element weighted loss + average_weighted_loss: Scalar average weighted loss + unweighted_loss: Per-element raw MSE loss + average_unweighted_loss: Scalar average unweighted loss loss_weight: Applied weights + loss_mask: Loss mask from batch (or None if not present) """ loss = nn.functional.mse_loss(model_pred.float(), target.float(), reduction="none") - loss_mask = batch["loss_mask"] if "loss_mask" in batch else None + loss_mask = batch.get("loss_mask") if batch is not None else None if self.use_loss_weighting: loss_weight = 1.0 + self.flow_shift * sigma @@ -304,13 +312,15 @@ def step( device: torch.device = torch.device("cuda"), dtype: torch.dtype = torch.bfloat16, global_step: int = 0, - ) -> Tuple[torch.Tensor, Dict[str, Any]]: + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Dict[str, Any]]: """ Execute a single training step with flow matching. Expected batch format: { - "video_latents": torch.Tensor, # [B, C, F, H, W] + "video_latents": torch.Tensor, # [B, C, F, H, W] for video + OR + "image_latents": torch.Tensor, # [B, C, H, W] for image "text_embeddings": torch.Tensor, # [B, seq_len, dim] "data_type": str, # "video" or "image" (optional) # ... additional model-specific keys handled by adapter @@ -324,21 +334,25 @@ def step( global_step: Current training step (for logging) Returns: - loss: The computed loss + weighted_loss: Per-element weighted loss + average_weighted_loss: Scalar average weighted loss + loss_mask: Mask indicating valid loss elements (or None) metrics: Dictionary of training metrics """ debug_mode = os.environ.get("DEBUG_TRAINING", "0") == "1" detailed_log = global_step % self.log_interval == 0 summary_log = global_step % self.summary_log_interval == 0 - # Extract and prepare batch data - video_latents = batch["video_latents"].to(device, dtype=dtype) - - # Handle tensor shapes - if video_latents.ndim == 4: - video_latents = video_latents.unsqueeze(0) + # Extract and prepare batch data (either image_latents or video_latents) + if "video_latents" in batch: + latents = batch["video_latents"].to(device, dtype=dtype) + elif "image_latents" in batch: + latents = batch["image_latents"].to(device, dtype=dtype) + else: + raise KeyError("Batch must contain either 'video_latents' or 'image_latents'") - batch_size = video_latents.shape[0] + # latents can be 4D [B, C, H, W] for images or 5D [B, C, F, H, W] for videos + batch_size = latents.shape[0] # Determine task type data_type = batch.get("data_type", "video") @@ -352,19 +366,19 @@ def step( # ==================================================================== # Flow Matching: Add Noise # ==================================================================== - noise = torch.randn_like(video_latents, dtype=torch.float32) + noise = torch.randn_like(latents, dtype=torch.float32) # x_t = (1 - σ) * x_0 + σ * ε - noisy_latents = self.noise_schedule.forward(video_latents.float(), noise, sigma) + noisy_latents = self.noise_schedule.forward(latents.float(), noise, sigma) # ==================================================================== # Logging # ==================================================================== - if debug_mode and detailed_log: + if detailed_log and debug_mode: self._log_detailed( - 
global_step, sampling_method, batch_size, sigma, timesteps, video_latents, noise, noisy_latents + global_step, sampling_method, batch_size, sigma, timesteps, latents, noise, noisy_latents ) - elif debug_mode and summary_log: + elif summary_log and debug_mode: logger.info( f"[STEP {global_step}] σ=[{sigma.min():.3f},{sigma.max():.3f}] | " f"t=[{timesteps.min():.1f},{timesteps.max():.1f}] | " @@ -380,13 +394,14 @@ def step( # ==================================================================== context = FlowMatchingContext( noisy_latents=noisy_latents, - video_latents=video_latents, + latents=latents, timesteps=timesteps, sigma=sigma, task_type=task_type, data_type=data_type, device=device, dtype=dtype, + cfg_dropout_prob=self.cfg_dropout_prob, batch=batch, ) @@ -397,7 +412,7 @@ def step( # Target: Flow Matching Velocity # ==================================================================== # v = ε - x_0 - target = noise - video_latents.float() + target = noise - latents.float() # ==================================================================== # Loss Computation @@ -412,9 +427,11 @@ def step( raise ValueError(f"Loss exploded: {average_weighted_loss.item()}") # Logging - if debug_mode and detailed_log: - self._log_loss_detailed(global_step, model_pred, target, loss_weight, unweighted_loss, weighted_loss) - elif debug_mode and summary_log: + if detailed_log and debug_mode: + self._log_loss_detailed( + global_step, model_pred, target, loss_weight, average_unweighted_loss, average_weighted_loss + ) + elif summary_log and debug_mode: logger.info( f"[STEP {global_step}] Loss: {average_weighted_loss.item():.6f} | " f"w=[{loss_weight.min():.2f},{loss_weight.max():.2f}]" @@ -447,7 +464,7 @@ def _log_detailed( batch_size: int, sigma: torch.Tensor, timesteps: torch.Tensor, - video_latents: torch.Tensor, + latents: torch.Tensor, noise: torch.Tensor, noisy_latents: torch.Tensor, ): @@ -469,15 +486,15 @@ def _log_detailed( logger.info("") logger.info(f"[TIMESTEPS] Range: [{timesteps.min():.2f}, {timesteps.max():.2f}]") logger.info("") - logger.info(f"[RANGES] Clean latents: [{video_latents.min():.4f}, {video_latents.max():.4f}]") + logger.info(f"[RANGES] Clean latents: [{latents.min():.4f}, {latents.max():.4f}]") logger.info(f"[RANGES] Noise: [{noise.min():.4f}, {noise.max():.4f}]") logger.info(f"[RANGES] Noisy latents: [{noisy_latents.min():.4f}, {noisy_latents.max():.4f}]") # Sanity check max_expected = ( max( - abs(video_latents.max().item()), - abs(video_latents.min().item()), + abs(latents.max().item()), + abs(latents.min().item()), abs(noise.max().item()), abs(noise.min().item()), ) @@ -511,9 +528,11 @@ def _log_loss_detailed( logger.info(f"[WEIGHTS] Range: [{loss_weight.min():.4f}, {loss_weight.max():.4f}]") logger.info(f"[WEIGHTS] Mean: {loss_weight.mean():.4f}") logger.info("") - logger.info(f"[LOSS] Unweighted: {unweighted_loss.item():.6f}") - logger.info(f"[LOSS] Weighted: {weighted_loss.item():.6f}") - logger.info(f"[LOSS] Impact: {(weighted_loss / max(unweighted_loss, 1e-8)):.3f}x") + unweighted_val = unweighted_loss.item() + weighted_val = weighted_loss.item() + logger.info(f"[LOSS] Unweighted: {unweighted_val:.6f}") + logger.info(f"[LOSS] Weighted: {weighted_val:.6f}") + logger.info(f"[LOSS] Impact: {(weighted_val / max(unweighted_val, 1e-8)):.3f}x") logger.info("=" * 80 + "\n") @@ -527,7 +546,7 @@ def create_adapter(adapter_type: str, **kwargs) -> ModelAdapter: Factory function to create a model adapter by name. 
Args: - adapter_type: Type of adapter ("hunyuan", "simple") + adapter_type: Type of adapter ("hunyuan", "simple", "flux") **kwargs: Additional arguments passed to the adapter constructor Returns: @@ -536,6 +555,7 @@ def create_adapter(adapter_type: str, **kwargs) -> ModelAdapter: adapters = { "hunyuan": HunyuanAdapter, "simple": SimpleAdapter, + "flux": FluxAdapter, } if adapter_type not in adapters: diff --git a/dfm/src/automodel/recipes/train.py b/dfm/src/automodel/recipes/train.py index 2d34e078..c67032b2 100644 --- a/dfm/src/automodel/recipes/train.py +++ b/dfm/src/automodel/recipes/train.py @@ -32,7 +32,7 @@ from torch.distributed.fsdp import MixedPrecisionPolicy from transformers.utils.hub import TRANSFORMERS_CACHE -from dfm.src.automodel._diffusers.auto_diffusion_pipeline import NeMoAutoDiffusionPipeline, NeMoWanPipeline +from dfm.src.automodel._diffusers.auto_diffusion_pipeline import NeMoAutoDiffusionPipeline from dfm.src.automodel.flow_matching.flow_matching_pipeline import FlowMatchingPipeline, create_adapter @@ -48,12 +48,13 @@ def build_model_and_optimizer( ddp_cfg: Optional[Dict[str, Any]] = None, attention_backend: Optional[str] = None, optimizer_cfg: Optional[Dict[str, Any]] = None, -) -> tuple[NeMoWanPipeline, dict[str, Dict[str, Any]], torch.optim.Optimizer, Any]: + pipeline_spec: Optional[Dict[str, Any]] = None, +) -> tuple[NeMoAutoDiffusionPipeline, torch.optim.Optimizer, Any]: """Build the diffusion model, parallel scheme, and optimizer. Args: model_id: Pretrained model name or path. - finetune_mode: Whether to load for finetuning. + finetune_mode: Whether to load for finetuning (True) or pretraining (False). learning_rate: Learning rate for optimizer. device: Target device. dtype: Model dtype. @@ -62,12 +63,18 @@ def build_model_and_optimizer( ddp_cfg: DDP configuration dict. Mutually exclusive with fsdp_cfg. attention_backend: Optional attention backend override. optimizer_cfg: Optional optimizer configuration. + pipeline_spec: Pipeline specification for pretraining (from_config). + Required when finetune_mode is False. Should contain: + - transformer_cls: str (e.g., "WanTransformer3DModel", "FluxTransformer2DModel") + - subfolder: str (e.g., "transformer") + - Optional: pipeline_cls, load_full_pipeline, enable_gradient_checkpointing Returns: Tuple of (pipeline, optimizer, device_mesh or None). Raises: ValueError: If both fsdp_cfg and ddp_cfg are provided. + ValueError: If finetune_mode is False and pipeline_spec is not provided. 
""" # Validate mutually exclusive configs if fsdp_cfg is not None and ddp_cfg is not None: @@ -124,23 +131,37 @@ def build_model_and_optimizer( parallel_scheme = {"transformer": manager_args} - kwargs = {} if finetune_mode: - kwargs["load_for_training"] = True - kwargs["low_cpu_mem_usage"] = True - if "wan" in model_id: - init_fn = NeMoWanPipeline.from_pretrained if finetune_mode else NeMoWanPipeline.from_config + # Finetuning: load from pretrained weights + logging.info("[INFO] Loading pretrained model for finetuning") + pipe, created_managers = NeMoAutoDiffusionPipeline.from_pretrained( + model_id, + torch_dtype=dtype, + device=device, + parallel_scheme=parallel_scheme, + components_to_load=["transformer"], + load_for_training=True, + low_cpu_mem_usage=True, + ) else: - init_fn = NeMoAutoDiffusionPipeline.from_pretrained - - pipe, created_managers = init_fn( - model_id, - torch_dtype=dtype, - device=device, - parallel_scheme=parallel_scheme, - components_to_load=["transformer"], - **kwargs, - ) + # Pretraining: initialize with random weights using pipeline_spec + if pipeline_spec is None: + raise ValueError( + "pipeline_spec is required for pretraining (finetune_mode=False). " + "Please provide pipeline_spec in your YAML config with at least:\n" + " pipeline_spec:\n" + " transformer_cls: 'WanTransformer3DModel' # or 'FluxTransformer2DModel', etc.\n" + " subfolder: 'transformer'" + ) + logging.info("[INFO] Initializing model with random weights for pretraining") + pipe, created_managers = NeMoAutoDiffusionPipeline.from_config( + model_id, + pipeline_spec=pipeline_spec, + torch_dtype=dtype, + device=device, + parallel_scheme=parallel_scheme, + components_to_load=["transformer"], + ) fsdp2_manager = created_managers["transformer"] transformer_module = pipe.transformer if attention_backend is not None: @@ -278,6 +299,10 @@ def setup(self): logging.info(f"[INFO] - Mix uniform ratio: {self.mix_uniform_ratio}") logging.info(f"[INFO] - Use loss weighting: {self.use_loss_weighting}") + # Get pipeline_spec for pretraining mode (required when mode != "finetune") + pipeline_spec_cfg = self.cfg.get("model.pipeline_spec", None) + pipeline_spec = pipeline_spec_cfg.to_dict() if pipeline_spec_cfg is not None else None + (self.pipe, self.optimizer, self.device_mesh) = build_model_and_optimizer( model_id=self.model_id, finetune_mode=self.cfg.get("model.mode", "finetune").lower() == "finetune", @@ -289,6 +314,7 @@ def setup(self): ddp_cfg=ddp_cfg, optimizer_cfg=self.cfg.get("optim.optimizer", {}), attention_backend=self.attention_backend, + pipeline_spec=pipeline_spec, ) self.model = self.pipe.transformer @@ -442,7 +468,7 @@ def run_train_validation_loop(self): micro_losses = [] for micro_batch in batch_group: try: - _, loss, _, metrics = self.flow_matching_pipeline.step( + weighted_loss, average_weighted_loss, loss_mask, metrics = self.flow_matching_pipeline.step( model=self.model, batch=micro_batch, device=self.device, @@ -456,8 +482,9 @@ def run_train_validation_loop(self): logging.info(f"[DEBUG] Batch shapes - video: {video_shape}, text: {text_shape}") raise - (loss / len(batch_group)).backward() - micro_losses.append(float(loss.item())) + # Use average_weighted_loss for backprop (scalar for gradient accumulation) + (average_weighted_loss / len(batch_group)).backward() + micro_losses.append(float(average_weighted_loss.item())) grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) grad_norm = float(grad_norm) if torch.is_tensor(grad_norm) else grad_norm diff --git 
a/dfm/src/automodel/utils/preprocessing_multiprocess.py b/dfm/src/automodel/utils/preprocessing_multiprocess.py new file mode 100644 index 00000000..6b7ef2bf --- /dev/null +++ b/dfm/src/automodel/utils/preprocessing_multiprocess.py @@ -0,0 +1,539 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import hashlib +import json +import os +import traceback +from multiprocessing import Pool +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import torch +from PIL import Image +from tqdm import tqdm + +from dfm.src.automodel.datasets.multiresolutionDataloader.multi_tier_bucketing import MultiTierBucketCalculator +from dfm.src.automodel.utils.processors import BaseModelProcessor, ProcessorRegistry + + +# Global worker state (initialized once per process) +_worker_models: Optional[Dict[str, Any]] = None +_worker_processor: Optional[BaseModelProcessor] = None +_worker_calculator: Optional[MultiTierBucketCalculator] = None +_worker_device: Optional[str] = None + + +def _init_worker(processor_name: str, model_name: str, gpu_id: int, max_pixels: int): + """Initialize worker process with models on assigned GPU.""" + global _worker_models, _worker_processor, _worker_calculator, _worker_device + + # Set CUDA_VISIBLE_DEVICES to isolate this GPU for the worker process. + # After this, the selected GPU becomes cuda:0 (not cuda:{gpu_id}). + os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) + _worker_device = "cuda:0" + + _worker_processor = ProcessorRegistry.get(processor_name) + _worker_models = _worker_processor.load_models(model_name, _worker_device) + _worker_calculator = MultiTierBucketCalculator(quantization=64, max_pixels=max_pixels) + + print(f"Worker initialized on GPU {gpu_id}") + + +def _load_caption(image_path: Path, caption_field: str = "internvl") -> Optional[str]: + """ + Load caption from JSON file for an image. + + DEPRECATED: Use _load_all_captions() instead for better performance. + This function is kept for backward compatibility only. + """ + image_name = image_path.name + + # Extract prefix: everything before '_sample' + if "_sample" in image_name: + prefix = image_name.rsplit("_sample", 1)[0] + else: + prefix = image_path.stem + + json_path = image_path.parent / f"{prefix}_internvl.json" + + if not json_path.exists(): + return None + + try: + with open(json_path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + try: + entry = json.loads(line) + if entry.get("file_name") == image_name: + return entry.get(caption_field, "") + except json.JSONDecodeError: + continue + except Exception: + pass + + return None + + +def _load_all_captions( + image_files: List[Path], caption_field: str = "internvl", verbose: bool = True +) -> Dict[str, str]: + """ + Pre-load all captions from JSONL files into memory. 
+ + This function eliminates the performance bottleneck of repeatedly opening + and parsing the same JSONL files by loading all captions once upfront. + + Args: + image_files: List of image file paths + caption_field: Field name in JSONL to use ('internvl' or 'usr') + verbose: Print progress information + + Returns: + Dictionary mapping image filename to caption text + """ + from collections import defaultdict + + if verbose: + print("\nPre-loading captions from JSONL files...") + + # Group images by their JSONL file + jsonl_to_images = defaultdict(list) + + for image_path in image_files: + image_name = image_path.name + + # Extract prefix: everything before '_sample' + if "_sample" in image_name: + prefix = image_name.rsplit("_sample", 1)[0] + else: + prefix = image_path.stem + + json_path = image_path.parent / f"{prefix}_internvl.json" + jsonl_to_images[json_path].append(image_name) + + # Load each JSONL file once and build caption dictionary + caption_cache = {} + loaded_files = 0 + missing_files = 0 + total_captions = 0 + + for json_path, image_names in tqdm(jsonl_to_images.items(), desc="Loading JSONL files", disable=not verbose): + if not json_path.exists(): + missing_files += 1 + # Images with missing JSONL will use filename fallback + continue + + try: + with open(json_path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + try: + entry = json.loads(line) + file_name = entry.get("file_name") + if file_name and file_name in image_names: + caption = entry.get(caption_field, "") + if caption: + caption_cache[file_name] = caption + total_captions += 1 + except json.JSONDecodeError: + continue + loaded_files += 1 + except Exception as e: + if verbose: + print(f"Warning: Failed to load {json_path}: {e}") + continue + + if verbose: + print(f"Loaded {total_captions} captions from {loaded_files} JSONL files") + if missing_files > 0: + print(f" {missing_files} JSONL files not found (will use filename fallback)") + missing_captions = len(image_files) - total_captions + if missing_captions > 0: + print(f" {missing_captions} images will use filename as caption") + + return caption_cache + + +def _validate_caption_files(image_files: List[Path], caption_field: str) -> Tuple[int, int, List[str]]: + """ + Validate that caption files exist and are parseable. 
+ + Args: + image_files: List of image file paths + caption_field: Field name to check in JSONL files + + Returns: + (num_valid_files, num_missing_files, error_messages) + """ + + # Group images by their JSONL file + jsonl_files = set() + + for image_path in image_files: + image_name = image_path.name + + # Extract prefix: everything before '_sample' + if "_sample" in image_name: + prefix = image_name.rsplit("_sample", 1)[0] + else: + prefix = image_path.stem + + json_path = image_path.parent / f"{prefix}_internvl.json" + jsonl_files.add(json_path) + + # Validate each JSONL file + valid_files = 0 + missing_files = 0 + errors = [] + + for json_path in jsonl_files: + if not json_path.exists(): + missing_files += 1 + errors.append(f"Missing: {json_path}") + continue + + try: + with open(json_path, "r", encoding="utf-8") as f: + line_count = 0 + for line in f: + line = line.strip() + if not line: + continue + line_count += 1 + try: + entry = json.loads(line) + # Basic validation: check structure + if "file_name" not in entry: + errors.append(f"Invalid format in {json_path}: missing 'file_name' field") + break + except json.JSONDecodeError as e: + errors.append(f"JSON error in {json_path} line {line_count}: {e}") + break + else: + # File parsed successfully + valid_files += 1 + except Exception as e: + errors.append(f"Failed to read {json_path}: {e}") + continue + + return valid_files, missing_files, errors + + +def _process_image(args: Tuple) -> Optional[Dict]: + """Process a single image using pre-initialized worker state.""" + image_path, output_dir, verify, caption = args + + try: + image = Image.open(image_path).convert("RGB") + orig_width, orig_height = image.size + + bucket = _worker_calculator.get_bucket_for_image(orig_width, orig_height) + target_width, target_height = bucket["resolution"] + + resized_image, crop_offset = _worker_calculator.resize_and_crop( + image, target_width, target_height, crop_mode="center" + ) + + image_tensor = _worker_processor.preprocess_image(resized_image) + latent = _worker_processor.encode_image(image_tensor, _worker_models, _worker_device) + + if verify and not _worker_processor.verify_latent(latent, _worker_models, _worker_device): + print(f"Verification failed: {image_path}") + return None + + # Use pre-loaded caption with fallback to filename + if not caption: + caption = Path(image_path).stem.replace("_", " ") + + text_encodings = _worker_processor.encode_text(caption, _worker_models, _worker_device) + + # Save cache file + resolution = f"{target_width}x{target_height}" + cache_subdir = Path(output_dir) / resolution + cache_subdir.mkdir(parents=True, exist_ok=True) + + cache_hash = hashlib.md5(f"{Path(image_path).absolute()}_{resolution}".encode()).hexdigest() + cache_file = cache_subdir / f"{cache_hash}.pt" + + metadata = { + "original_resolution": (orig_width, orig_height), + "crop_resolution": (target_width, target_height), + "crop_offset": crop_offset, + "prompt": caption, + "image_path": str(Path(image_path).absolute()), + "bucket_id": bucket["id"], + "aspect_ratio": bucket["aspect_ratio"], + } + + cache_data = _worker_processor.get_cache_data(latent, text_encodings, metadata) + torch.save(cache_data, cache_file) + + return { + "cache_file": str(cache_file), + "image_path": str(Path(image_path).absolute()), + "crop_resolution": [target_width, target_height], + "original_resolution": [orig_width, orig_height], + "prompt": caption, + "bucket_id": bucket["id"], + "aspect_ratio": bucket["aspect_ratio"], + "pixels": target_width * target_height, + 
"model_type": _worker_processor.model_type, + } + + except Exception as e: + print(f"Error processing {image_path}: {e}") + traceback.print_exc() + return None + + +def _get_image_files(image_dir: Path) -> List[Path]: + """ + Recursively get all image files efficiently. + + Uses os.walk() for better performance on large directories compared to rglob(). + """ + image_files = [] + valid_extensions = {"jpg", "jpeg", "png", "webp", "bmp"} + + # Use os.walk for better performance on large directories + for root, dirs, files in os.walk(image_dir): + root_path = Path(root) + for file in files: + # Extract extension and check if it's a valid image file + if "." in file: + ext = file.lower().rsplit(".", 1)[-1] + if ext in valid_extensions: + image_files.append(root_path / file) + + return sorted(image_files) + + +def _process_shard_on_gpu( + gpu_id: int, + image_files: List[Path], + output_dir: str, + processor_name: str, + model_name: str, + verify: bool, + caption_cache: Dict[str, str], + max_pixels: int, +) -> List[Dict]: + """Process a shard of images on a specific GPU.""" + _init_worker(processor_name, model_name, gpu_id, max_pixels) + + results = [] + for image_path in tqdm(image_files, desc=f"GPU {gpu_id}", position=gpu_id): + # Get caption from cache (or None if not found) + caption = caption_cache.get(image_path.name) + result = _process_image((str(image_path), output_dir, verify, caption)) + if result: + results.append(result) + + return results + + +def preprocess_dataset( + image_dir: str, + output_dir: str, + processor_name: str, + model_name: Optional[str] = None, + shard_size: int = 10000, + verify: bool = False, + caption_field: str = "internvl", + max_images: Optional[int] = None, + max_pixels: int = 256 * 256, +): + """ + Preprocess dataset with one process per GPU. + + Args: + image_dir: Directory containing images + output_dir: Output directory for cache + processor_name: Name of processor to use (e.g., 'flux', 'sdxl') + model_name: HuggingFace model name (uses processor default if None) + shard_size: Number of images per metadata shard + verify: Whether to verify latents can be decoded + caption_field: Field to use from JSON captions ('internvl' or 'usr') + max_images: Maximum number of images to process + max_pixels: Maximum pixels per image + """ + image_dir = Path(image_dir) + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Get processor and resolve model name + processor = ProcessorRegistry.get(processor_name) + if model_name is None: + model_name = processor.default_model_name + + num_gpus = torch.cuda.device_count() + if num_gpus == 0: + raise RuntimeError("No GPUs available") + + print(f"Processor: {processor_name} ({processor.model_type})") + print(f"Model: {model_name}") + print(f"GPUs: {num_gpus}") + print(f"Max pixels: {max_pixels}") + + # Get all image files + print("\nScanning for images...") + image_files = _get_image_files(image_dir) + + if max_images is not None: + image_files = image_files[:max_images] + + print(f"Processing {len(image_files)} images") + + if not image_files: + return + + # Validate caption files before processing + print("\nValidating caption files...") + num_valid, num_missing, errors = _validate_caption_files(image_files, caption_field) + print(f" Valid JSONL files: {num_valid}") + print(f" Missing JSONL files: {num_missing}") + + if errors and num_missing > len(set([img.parent / f"{img.stem}_internvl.json" for img in image_files])) * 0.5: + print("\nWARNING: Many caption files missing or invalid. 
First 10 errors:") + for err in errors[:10]: + print(f" {err}") + elif errors and len(errors) <= 5: + print("\nCaption file issues:") + for err in errors: + print(f" {err}") + + # Pre-load all captions (PERFORMANCE OPTIMIZATION) + caption_cache = _load_all_captions(image_files, caption_field, verbose=True) + + # Split images across GPUs + chunks = [image_files[i::num_gpus] for i in range(num_gpus)] + + # Process with one worker per GPU + all_metadata = [] + + with Pool(processes=num_gpus) as pool: + args = [ + (gpu_id, chunks[gpu_id], str(output_dir), processor_name, model_name, verify, caption_cache, max_pixels) + for gpu_id in range(num_gpus) + ] + + results = pool.starmap(_process_shard_on_gpu, args) + + for gpu_results in results: + all_metadata.extend(gpu_results) + + # Save metadata in shards + shard_files = [] + for shard_idx in range(0, len(all_metadata), shard_size): + shard_data = all_metadata[shard_idx : shard_idx + shard_size] + shard_file = output_dir / f"metadata_shard_{shard_idx // shard_size:04d}.json" + with open(shard_file, "w") as f: + json.dump(shard_data, f, indent=2) + shard_files.append(shard_file.name) + + # Save config metadata (references shards instead of duplicating items) + metadata_file = output_dir / "metadata.json" + with open(metadata_file, "w") as f: + json.dump( + { + "processor": processor_name, + "model_name": model_name, + "model_type": processor.model_type, + "caption_field": caption_field, + "max_pixels": max_pixels, + "total_images": len(all_metadata), + "num_shards": len(shard_files), + "shard_size": shard_size, + "shards": shard_files, + }, + f, + indent=2, + ) + + # Print summary + print(f"\n{'=' * 50}") + print(f"COMPLETE: {len(all_metadata)}/{len(image_files)} images") + print(f"Output: {output_dir}") + + bucket_counts: Dict[str, int] = {} + for item in all_metadata: + res = f"{item['crop_resolution'][0]}x{item['crop_resolution'][1]}" + bucket_counts[res] = bucket_counts.get(res, 0) + 1 + + print("\nBucket distribution:") + for res in sorted(bucket_counts.keys()): + print(f" {res}: {bucket_counts[res]}") + + +def main(): + parser = argparse.ArgumentParser(description="Preprocess images (one process per GPU)") + + parser.add_argument("--list_processors", action="store_true", help="List available processors") + parser.add_argument("--image_dir", type=str, help="Input image directory") + parser.add_argument("--output_dir", type=str, help="Output cache directory") + parser.add_argument("--processor", type=str, default="flux", help="Processor name") + parser.add_argument("--model_name", type=str, default=None, help="Model name") + parser.add_argument("--shard_size", type=int, default=10000, help="Metadata shard size") + parser.add_argument("--verify", action="store_true", help="Verify latents") + parser.add_argument("--caption_field", type=str, default="internvl", choices=["internvl", "usr"]) + parser.add_argument("--max_images", type=int, default=None, help="Max images to process") + parser.add_argument("--max_pixels", type=int, default=None, help="Max pixels per image") + parser.add_argument( + "--resolution_preset", type=str, default=None, choices=["256p", "512p", "768p", "1024p", "1536p"] + ) + + args = parser.parse_args() + + if args.list_processors: + print("Available processors:") + for name in ProcessorRegistry.list_available(): + proc = ProcessorRegistry.get(name) + print(f" {name}: {proc.model_type}") + return + + if not args.image_dir or not args.output_dir: + parser.error("--image_dir and --output_dir are required") + + if 
args.resolution_preset and args.max_pixels: + parser.error("Cannot specify both --resolution_preset and --max_pixels") + + if args.resolution_preset: + max_pixels = MultiTierBucketCalculator.RESOLUTION_PRESETS[args.resolution_preset] + elif args.max_pixels: + max_pixels = args.max_pixels + else: + max_pixels = 256 * 256 + + preprocess_dataset( + args.image_dir, + args.output_dir, + args.processor, + args.model_name, + args.shard_size, + args.verify, + args.caption_field, + args.max_images, + max_pixels, + ) + + +if __name__ == "__main__": + main() diff --git a/dfm/src/automodel/utils/processors/__init__.py b/dfm/src/automodel/utils/processors/__init__.py new file mode 100644 index 00000000..5991f3fc --- /dev/null +++ b/dfm/src/automodel/utils/processors/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base import BaseModelProcessor +from .flux import FluxProcessor +from .registry import ProcessorRegistry + + +__all__ = [ + "BaseModelProcessor", + "ProcessorRegistry", + "FluxProcessor", +] diff --git a/dfm/src/automodel/utils/processors/base.py b/dfm/src/automodel/utils/processors/base.py new file mode 100644 index 00000000..29ea979d --- /dev/null +++ b/dfm/src/automodel/utils/processors/base.py @@ -0,0 +1,194 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Any, Dict + +import torch +from PIL import Image + + +class BaseModelProcessor(ABC): + """ + Abstract base class for model-specific preprocessing logic. + + Each model architecture (FLUX, SDXL, SD1.5, SD3, etc.) should have its own + processor implementation that handles: + - Model loading (VAE, text encoders) + - Image encoding to latent space + - Text encoding to embeddings + - Verification of encoded latents + - Cache data structure formatting + """ + + @property + @abstractmethod + def model_type(self) -> str: + """ + Return the model type identifier. + + Returns: + str: Model type (e.g., 'flux', 'sdxl', 'sd15', 'sd3') + """ + pass + + @property + def default_model_name(self) -> str: + """ + Return the default HuggingFace model path for this processor. 
+ + Returns: + str: Default model name/path + """ + raise NotImplementedError(f"{self.__class__.__name__} does not specify a default model name") + + @abstractmethod + def load_models(self, model_name: str, device: str) -> Dict[str, Any]: + """ + Load all required models for this architecture. + + Args: + model_name: HuggingFace model name/path + device: Device to load models on (e.g., 'cuda', 'cuda:0', 'cpu') + + Returns: + Dict containing all loaded models and tokenizers + """ + pass + + @abstractmethod + def encode_image( + self, + image_tensor: torch.Tensor, + models: Dict[str, Any], + device: str, + ) -> torch.Tensor: + """ + Encode image tensor to latent space. + + Args: + image_tensor: Image tensor of shape (1, C, H, W), normalized to [-1, 1] + models: Dict of loaded models from load_models() + device: Device to use for encoding + + Returns: + Latent tensor (typically shape (C, H//8, W//8) for most VAEs) + """ + pass + + @abstractmethod + def encode_text( + self, + prompt: str, + models: Dict[str, Any], + device: str, + ) -> Dict[str, torch.Tensor]: + """ + Encode text prompt to embeddings. + + Args: + prompt: Text prompt to encode + models: Dict of loaded models from load_models() + device: Device to use for encoding + + Returns: + Dict containing all text embeddings (keys vary by model type) + """ + pass + + @abstractmethod + def verify_latent( + self, + latent: torch.Tensor, + models: Dict[str, Any], + device: str, + ) -> bool: + """ + Verify that a latent can be decoded back to a reasonable image. + + Args: + latent: Encoded latent tensor + models: Dict of loaded models from load_models() + device: Device to use for verification + + Returns: + True if verification passes, False otherwise + """ + pass + + @abstractmethod + def get_cache_data( + self, + latent: torch.Tensor, + text_encodings: Dict[str, torch.Tensor], + metadata: Dict[str, Any], + ) -> Dict[str, Any]: + """ + Construct the cache dictionary to save. + + Args: + latent: Encoded latent tensor + text_encodings: Dict of text embeddings from encode_text() + metadata: Dict containing: + - original_resolution: Tuple[int, int] + - crop_resolution: Tuple[int, int] + - crop_offset: Tuple[int, int] + - prompt: str + - image_path: str + - bucket_id: str + - tier: str + - aspect_ratio: float + + Returns: + Dict to be saved with torch.save() + """ + pass + + def preprocess_image(self, image: Image.Image) -> torch.Tensor: + """ + Convert PIL Image to normalized tensor. + + Default implementation handles standard preprocessing. + Override if model requires different preprocessing. + + Args: + image: PIL Image (RGB) + + Returns: + Tensor of shape (1, 3, H, W), normalized to [-1, 1] + """ + import numpy as np + + image_tensor = torch.from_numpy(np.array(image)).float() / 255.0 + image_tensor = (image_tensor - 0.5) / 0.5 # Normalize to [-1, 1] + + if image_tensor.ndim == 2: + image_tensor = image_tensor.unsqueeze(-1).repeat(1, 1, 3) + + image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0) + return image_tensor + + def get_vae_scaling_factor(self, models: Dict[str, Any]) -> float: + """ + Get the VAE scaling factor for this model. 
+ + Args: + models: Dict of loaded models + + Returns: + Scaling factor (typically from vae.config.scaling_factor) + """ + if "vae" in models and hasattr(models["vae"], "config"): + return models["vae"].config.scaling_factor + return 0.18215 # Default for most models diff --git a/dfm/src/automodel/utils/processors/flux.py b/dfm/src/automodel/utils/processors/flux.py new file mode 100644 index 00000000..c189f9d9 --- /dev/null +++ b/dfm/src/automodel/utils/processors/flux.py @@ -0,0 +1,274 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +FLUX model processor for preprocessing. + +Handles FLUX.1-dev and similar FLUX architecture models with: +- VAE for image encoding +- CLIP text encoder +- T5 text encoder +""" + +from typing import Any, Dict + +import torch +from torch import autocast + +from .base import BaseModelProcessor +from .registry import ProcessorRegistry + + +@ProcessorRegistry.register("flux") +class FluxProcessor(BaseModelProcessor): + """ + Processor for FLUX.1 architecture models. + + FLUX uses a VAE for image encoding and dual text encoders (CLIP + T5) + for text conditioning. + """ + + @property + def model_type(self) -> str: + return "flux" + + @property + def default_model_name(self) -> str: + return "black-forest-labs/FLUX.1-dev" + + def load_models(self, model_name: str, device: str) -> Dict[str, Any]: + """ + Load FLUX models from FluxPipeline. + + Args: + model_name: HuggingFace model path (e.g., 'black-forest-labs/FLUX.1-dev') + device: Device to load models on + + Returns: + Dict containing: + - vae: AutoencoderKL + - clip_tokenizer: CLIPTokenizer + - clip_encoder: CLIPTextModel + - t5_tokenizer: T5TokenizerFast + - t5_encoder: T5EncoderModel + """ + from diffusers import FluxPipeline + + print(f"[FLUX] Loading models from {model_name} via FluxPipeline...") + + # Load pipeline without transformer (not needed for preprocessing) + pipeline = FluxPipeline.from_pretrained( + model_name, + transformer=None, + torch_dtype=torch.bfloat16, + ) + + models = {} + + print(" Configuring VAE...") + models["vae"] = pipeline.vae.to(device=device, dtype=torch.bfloat16) + models["vae"].eval() + print(f"!!! VAE config: {models['vae'].config}") + print(f"!!! VAE shift_factor: {models['vae'].config.shift_factor}") + print(f"!!! 
VAE scaling_factor: {models['vae'].config.scaling_factor}") + + # Extract CLIP components + print(" Configuring CLIP...") + models["clip_tokenizer"] = pipeline.tokenizer + models["clip_encoder"] = pipeline.text_encoder.to(device) + models["clip_encoder"].eval() + + # Extract T5 components + print(" Configuring T5...") + models["t5_tokenizer"] = pipeline.tokenizer_2 + models["t5_encoder"] = pipeline.text_encoder_2.to(device) + models["t5_encoder"].eval() + + # Clean up pipeline reference to free memory + del pipeline + torch.cuda.empty_cache() + + print("[FLUX] Models loaded successfully!") + return models + + def encode_image( + self, + image_tensor: torch.Tensor, + models: Dict[str, Any], + device: str, + ) -> torch.Tensor: + """ + Encode image to latent space using VAE. + + Args: + image_tensor: Image tensor (1, 3, H, W), normalized to [-1, 1] + models: Dict containing 'vae' + device: Device to use + + Returns: + Latent tensor (C, H//8, W//8), FP16 + """ + vae = models["vae"] + image_tensor = image_tensor.to(device, dtype=torch.bfloat16) + + device_type = "cuda" if "cuda" in device else "cpu" + + with torch.no_grad(): + latent = vae.encode(image_tensor).latent_dist.sample() + + # Apply scaling factor + latent = (latent - vae.config.shift_factor) * vae.config.scaling_factor + + # Return as FP16 to save space, remove batch dimension + # Use detach() to ensure tensor can be serialized across process boundaries + return latent.detach().cpu().to(torch.float16).squeeze(0) + + def encode_text( + self, + prompt: str, + models: Dict[str, Any], + device: str, + ) -> Dict[str, torch.Tensor]: + """ + Encode text using CLIP and T5. + + Args: + prompt: Text prompt + models: Dict containing tokenizers and encoders + device: Device to use + + Returns: + Dict containing: + - clip_tokens: Token IDs + - clip_hidden: Hidden states from CLIP + - pooled_prompt_embeds: Pooled CLIP output + - t5_tokens: T5 token IDs + - prompt_embeds: T5 hidden states + """ + device_type = "cuda" if "cuda" in device else "cpu" + + # CLIP encoding + clip_tokens = models["clip_tokenizer"]( + prompt, + padding="max_length", + max_length=models["clip_tokenizer"].model_max_length, + truncation=True, + return_tensors="pt", + ) + + clip_output = models["clip_encoder"](clip_tokens.input_ids.to(device_type), output_hidden_states=True) + clip_hidden = clip_output.hidden_states[-2] + pooled_prompt_embeds = clip_output.pooler_output + + # T5 encoding + t5_tokens = models["t5_tokenizer"]( + prompt, + padding="max_length", + max_length=models["t5_tokenizer"].model_max_length, + truncation=True, + return_tensors="pt", + ) + t5_output = models["t5_encoder"](t5_tokens.input_ids.to(device_type), output_hidden_states=False) + prompt_embeds = t5_output.last_hidden_state + + return { + "clip_tokens": clip_tokens["input_ids"].cpu(), + "clip_hidden": clip_hidden.detach().cpu(), + "pooled_prompt_embeds": pooled_prompt_embeds.detach().cpu(), + "t5_tokens": t5_tokens["input_ids"].cpu(), + "prompt_embeds": prompt_embeds.detach().cpu(), + } + + def verify_latent( + self, + latent: torch.Tensor, + models: Dict[str, Any], + device: str, + ) -> bool: + """ + Verify latent can be decoded back to reasonable image. 
+
+        Args:
+            latent: Encoded latent (C, H, W)
+            models: Dict containing 'vae'
+            device: Device to use
+
+        Returns:
+            True if verification passes
+        """
+        try:
+            vae = models["vae"]
+            device_type = "cuda" if "cuda" in device else "cpu"
+
+            # Add batch dimension and move to device
+            latent = latent.unsqueeze(0).to(device).float()
+
+            with torch.no_grad(), autocast(device_type=device_type, dtype=torch.float32):
+                # Undo the scaling and shift applied in encode_image
+                latent = latent / vae.config.scaling_factor + vae.config.shift_factor
+                decoded = vae.decode(latent).sample
+
+            # Check shape
+            _, c, h, w = decoded.shape
+            if c != 3:
+                return False
+
+            # Check for NaN/Inf
+            if torch.isnan(decoded).any() or torch.isinf(decoded).any():
+                return False
+
+            return True
+
+        except Exception as e:
+            print(f"[FLUX] Verification failed: {e}")
+            return False
+
+    def get_cache_data(
+        self,
+        latent: torch.Tensor,
+        text_encodings: Dict[str, torch.Tensor],
+        metadata: Dict[str, Any],
+    ) -> Dict[str, Any]:
+        """
+        Construct cache dictionary for FLUX.
+
+        Args:
+            latent: Encoded latent
+            text_encodings: Dict from encode_text()
+            metadata: Additional metadata
+
+        Returns:
+            Dict to save with torch.save()
+        """
+        return {
+            # Image latent
+            "latent": latent,
+            # CLIP embeddings
+            "clip_tokens": text_encodings["clip_tokens"],
+            "clip_hidden": text_encodings["clip_hidden"],
+            "pooled_prompt_embeds": text_encodings["pooled_prompt_embeds"],
+            # T5 embeddings
+            "t5_tokens": text_encodings["t5_tokens"],
+            "prompt_embeds": text_encodings["prompt_embeds"],
+            # Metadata
+            "original_resolution": metadata["original_resolution"],
+            "crop_resolution": metadata["crop_resolution"],
+            "crop_offset": metadata["crop_offset"],
+            "prompt": metadata["prompt"],
+            "image_path": metadata["image_path"],
+            "bucket_id": metadata["bucket_id"],
+            "aspect_ratio": metadata["aspect_ratio"],
+            # Model info
+            "model_type": self.model_type,
+        }
diff --git a/dfm/src/automodel/utils/processors/registry.py b/dfm/src/automodel/utils/processors/registry.py
new file mode 100644
index 00000000..bffb3920
--- /dev/null
+++ b/dfm/src/automodel/utils/processors/registry.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Processor registry for model-agnostic preprocessing.
+
+This module provides a registry pattern for discovering and instantiating
+model-specific processors at runtime.
+"""
+
+from typing import Dict, List, Type
+
+from .base import BaseModelProcessor
+
+
+class ProcessorRegistry:
+    """
+    Registry for model processors.
+
+    Allows registering processor classes by name and retrieving them at runtime.
+    Uses a decorator pattern for easy registration.
+
+    Example:
+        @ProcessorRegistry.register("flux")
+        class FluxProcessor(BaseModelProcessor):
+            ...
+
+        # Later
+        processor = ProcessorRegistry.get("flux")
+    """
+
+    _processors: Dict[str, Type[BaseModelProcessor]] = {}
+
+    @classmethod
+    def register(cls, name: str):
+        """
+        Decorator to register a processor class.
+ + Args: + name: Name to register the processor under (e.g., 'flux', 'sdxl') + + Returns: + Decorator function + + Example: + @ProcessorRegistry.register("my_model") + class MyModelProcessor(BaseModelProcessor): + ... + """ + + def decorator(processor_class: Type[BaseModelProcessor]): + if not issubclass(processor_class, BaseModelProcessor): + raise TypeError(f"Processor {processor_class.__name__} must inherit from BaseModelProcessor") + cls._processors[name] = processor_class + return processor_class + + return decorator + + @classmethod + def get(cls, name: str) -> BaseModelProcessor: + """ + Get a processor instance by name. + + Args: + name: Registered processor name + + Returns: + Instantiated processor + + Raises: + ValueError: If processor name is not registered + """ + if name not in cls._processors: + available = ", ".join(sorted(cls._processors.keys())) + raise ValueError(f"Unknown processor: '{name}'. Available processors: {available}") + return cls._processors[name]() + + @classmethod + def get_class(cls, name: str) -> Type[BaseModelProcessor]: + """ + Get a processor class by name (without instantiating). + + Args: + name: Registered processor name + + Returns: + Processor class + + Raises: + ValueError: If processor name is not registered + """ + if name not in cls._processors: + available = ", ".join(sorted(cls._processors.keys())) + raise ValueError(f"Unknown processor: '{name}'. Available processors: {available}") + return cls._processors[name] + + @classmethod + def list_available(cls) -> List[str]: + """ + List all registered processor names. + + Returns: + List of registered processor names + """ + return sorted(cls._processors.keys()) + + @classmethod + def is_registered(cls, name: str) -> bool: + """ + Check if a processor is registered. 
+ + Args: + name: Processor name to check + + Returns: + True if registered, False otherwise + """ + return name in cls._processors diff --git a/examples/automodel/finetune/hunyuan_t2v_flow.yaml b/examples/automodel/finetune/hunyuan_t2v_flow.yaml index 7f0642d1..70da74fc 100644 --- a/examples/automodel/finetune/hunyuan_t2v_flow.yaml +++ b/examples/automodel/finetune/hunyuan_t2v_flow.yaml @@ -1,16 +1,9 @@ -# HunyuanVideo-1.5 720p T2V Training Configuration -# -# This configuration file is fully compatible with TrainDiffusionRecipe class -# (dfm/src/automodel/recipes/train.py) using FlowMatchingPipelineV2 - -# Model configuration model: pretrained_model_name_or_path: "hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-720p_t2v" - mode: "finetune" # "finetune" or "pretrain" - cache_dir: null # Optional: specify cache directory for model weights + mode: "finetune" + cache_dir: null attention_backend: "_flash_3_hub" -# Optimizer configuration optim: learning_rate: 5e-6 @@ -18,45 +11,41 @@ optim: weight_decay: 0.01 betas: [0.9, 0.999] -# FSDP (Fully Sharded Data Parallel) configuration fsdp: - dp_size: 8 # Auto-calculate based on world_size and other parallel dimensions + dp_size: 8 dp_replicate_size: 1 - tp_size: 1 # Tensor parallelism size - cp_size: 1 # Context parallelism size - pp_size: 1 # Pipeline parallelism size + tp_size: 1 + cp_size: 1 + pp_size: 1 cpu_offload: false activation_checkpointing: true use_hf_tp_plan: false -# Flow matching V2 configuration flow_matching: - adapter_type: "hunyuan" # Options: "hunyuan", "simple" + adapter_type: "hunyuan" adapter_kwargs: use_condition_latents: true default_image_embed_shape: [729, 1152] - timestep_sampling: "logit_normal" # Options: "uniform", "logit_normal", "lognorm", "mix", "mode" + timestep_sampling: "logit_normal" logit_mean: 0.0 logit_std: 1.0 - flow_shift: 3.0 # Flow shift for training - mix_uniform_ratio: 0.1 # For "mix" timestep sampling + flow_shift: 3.0 + mix_uniform_ratio: 0.1 sigma_min: 0.0 sigma_max: 1.0 num_train_timesteps: 1000 i2v_prob: 0.3 use_loss_weighting: false - log_interval: 1000 # Steps between detailed logs - summary_log_interval: 100 # Steps between summary logs + log_interval: 1000 + summary_log_interval: 100 -# Training step scheduler configuration step_scheduler: num_epochs: 30 - local_batch_size: 1 # Batch size per GPU - global_batch_size: 8 # Effective batch size across all GPUs (with gradient accumulation) - ckpt_every_steps: 1000 # Save checkpoint every N steps - log_every: 10 # Log metrics every N steps + local_batch_size: 1 + global_batch_size: 8 + ckpt_every_steps: 1000 + log_every: 10 -# Data configuration data: dataloader: _target_: dfm.src.automodel.datasets.build_dataloader @@ -64,7 +53,6 @@ data: num_workers: 2 device: cpu -# Checkpoint configuration checkpoint: enabled: true checkpoint_dir: /opt/DFM/hunyuan_t2v_flow_outputs_base_recipe_flowPipelineV2/ @@ -77,10 +65,8 @@ wandb: mode: online name: 720p_t2v_run -# Distributed environment configuration dist_env: backend: "nccl" init_method: "env://" -# Random seed seed: 42 diff --git a/examples/automodel/finetune/wan2_1_t2v_flow.yaml b/examples/automodel/finetune/wan2_1_t2v_flow.yaml index 525cacf1..fa3ca082 100644 --- a/examples/automodel/finetune/wan2_1_t2v_flow.yaml +++ b/examples/automodel/finetune/wan2_1_t2v_flow.yaml @@ -11,6 +11,7 @@ dist_env: model: pretrained_model_name_or_path: Wan-AI/Wan2.1-T2V-1.3B-Diffusers + mode: finetune # "finetune" loads pretrained weights, "pretrain" initializes random weights step_scheduler: global_batch_size: 8 diff 
--git a/examples/automodel/finetune/wan2_1_t2v_flow_multinode.yaml b/examples/automodel/finetune/wan2_1_t2v_flow_multinode.yaml index 47d8e975..76c88bfd 100644 --- a/examples/automodel/finetune/wan2_1_t2v_flow_multinode.yaml +++ b/examples/automodel/finetune/wan2_1_t2v_flow_multinode.yaml @@ -11,6 +11,7 @@ dist_env: model: pretrained_model_name_or_path: Wan-AI/Wan2.1-T2V-1.3B-Diffusers + mode: finetune # "finetune" loads pretrained weights, "pretrain" initializes random weights step_scheduler: global_batch_size: 8 diff --git a/examples/automodel/pretrain/cicd/wan21_cicd_nightly_image.yaml b/examples/automodel/pretrain/cicd/wan21_cicd_nightly_image.yaml index a12c1b9c..80c66fcf 100644 --- a/examples/automodel/pretrain/cicd/wan21_cicd_nightly_image.yaml +++ b/examples/automodel/pretrain/cicd/wan21_cicd_nightly_image.yaml @@ -11,7 +11,13 @@ dist_env: model: pretrained_model_name_or_path: Wan-AI/Wan2.1-T2V-1.3B-Diffusers - mode: pretrain + mode: pretrain # "pretrain" initializes with random weights using pipeline_spec + # Pipeline specification for pretraining (required when mode: pretrain) + pipeline_spec: + transformer_cls: "WanTransformer3DModel" + subfolder: "transformer" + enable_gradient_checkpointing: false + load_full_pipeline: false step_scheduler: global_batch_size: 8 diff --git a/examples/automodel/pretrain/cicd/wan21_cicd_nightly_video.yaml b/examples/automodel/pretrain/cicd/wan21_cicd_nightly_video.yaml index 38decd7f..0238752d 100644 --- a/examples/automodel/pretrain/cicd/wan21_cicd_nightly_video.yaml +++ b/examples/automodel/pretrain/cicd/wan21_cicd_nightly_video.yaml @@ -11,7 +11,13 @@ dist_env: model: pretrained_model_name_or_path: Wan-AI/Wan2.1-T2V-1.3B-Diffusers - mode: pretrain + mode: pretrain # "pretrain" initializes with random weights using pipeline_spec + # Pipeline specification for pretraining (required when mode: pretrain) + pipeline_spec: + transformer_cls: "WanTransformer3DModel" + subfolder: "transformer" + enable_gradient_checkpointing: true + load_full_pipeline: false step_scheduler: global_batch_size: 8 diff --git a/examples/automodel/pretrain/flux_t2i_flow.yaml b/examples/automodel/pretrain/flux_t2i_flow.yaml new file mode 100644 index 00000000..f444b7a2 --- /dev/null +++ b/examples/automodel/pretrain/flux_t2i_flow.yaml @@ -0,0 +1,80 @@ +model: + pretrained_model_name_or_path: "black-forest-labs/FLUX.1-dev" + mode: "pretrain" + cache_dir: null + attention_backend: "_flash_3_hub" + + pipeline_spec: + transformer_cls: "FluxTransformer2DModel" + subfolder: "transformer" + load_full_pipeline: false + enable_gradient_checkpointing: false + +optim: + learning_rate: 1e-5 + + optimizer: + weight_decay: 0.01 + betas: [0.9, 0.999] + +fsdp: + dp_size: 8 + tp_size: 1 + cp_size: 1 + pp_size: 1 + activation_checkpointing: false + cpu_offload: false + +flow_matching: + adapter_type: "flux" + adapter_kwargs: + guidance_scale: 3.5 + use_guidance_embeds: true + timestep_sampling: "logit_normal" + logit_mean: 0.0 + logit_std: 1.0 + flow_shift: 3.0 + mix_uniform_ratio: 0.1 + sigma_min: 0.0 + sigma_max: 1.0 + num_train_timesteps: 1000 + i2v_prob: 0.0 + use_loss_weighting: true + log_interval: 100 + summary_log_interval: 10 + +step_scheduler: + num_epochs: 5000 + local_batch_size: 1 + global_batch_size: 8 + ckpt_every_steps: 2000 + log_every: 1 + +data: + dataloader: + _target_: dfm.src.automodel.datasets.multiresolutionDataloader.build_flux_multiresolution_dataloader + cache_dir: /lustre/fsw/portfolios/coreai/users/pthombre/Automodel/FluxTraining/DFM/FluxData512Full/ + 
train_text_encoder: false + num_workers: 10 + base_resolution: [512, 512] + dynamic_batch_size: false + shuffle: true + drop_last: false + +checkpoint: + enabled: true + checkpoint_dir: /lustre/fsw/portfolios/coreai/users/pthombre/Automodel/FluxTraining/DFM/flux_ddp_test/ + model_save_format: torch_save + save_consolidated: false + restore_from: null + +wandb: + project: flux-pretraining + mode: online + name: flux_pretrain_ddp_test_run_1 + +dist_env: + backend: "nccl" + init_method: "env://" + +seed: 42 diff --git a/examples/automodel/pretrain/wan2_1_t2v_flow.yaml b/examples/automodel/pretrain/wan2_1_t2v_flow.yaml index 3f0e9d2b..2f9ff18c 100644 --- a/examples/automodel/pretrain/wan2_1_t2v_flow.yaml +++ b/examples/automodel/pretrain/wan2_1_t2v_flow.yaml @@ -12,6 +12,11 @@ dist_env: model: pretrained_model_name_or_path: Wan-AI/Wan2.1-T2V-1.3B-Diffusers mode: pretrain + pipeline_spec: + transformer_cls: "WanTransformer3DModel" + subfolder: "transformer" + enable_gradient_checkpointing: false + load_full_pipeline: false step_scheduler: global_batch_size: 8 diff --git a/pyproject.toml b/pyproject.toml index fccd7cc9..365eef57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ dependencies = [ "imageio-ffmpeg", "opencv-python-headless==4.10.0.84", "megatron-energon", + "sentencepiece" ] [build-system] diff --git a/tests/functional_tests/automodel/wan21/mock_configs/wan2_1_t2v_flow_mock.yaml b/tests/functional_tests/automodel/wan21/mock_configs/wan2_1_t2v_flow_mock.yaml index de3f45e0..85efa366 100644 --- a/tests/functional_tests/automodel/wan21/mock_configs/wan2_1_t2v_flow_mock.yaml +++ b/tests/functional_tests/automodel/wan21/mock_configs/wan2_1_t2v_flow_mock.yaml @@ -29,6 +29,11 @@ dist_env: model: pretrained_model_name_or_path: Wan-AI/Wan2.1-T2V-1.3B-Diffusers mode: pretrain + pipeline_spec: + transformer_cls: "WanTransformer3DModel" + subfolder: "transformer" + enable_gradient_checkpointing: false + load_full_pipeline: false step_scheduler: global_batch_size: 2 diff --git a/tests/unit_tests/automodel/adapters/test_hunyuan_adapter.py b/tests/unit_tests/automodel/adapters/test_hunyuan_adapter.py index 58cd0dc9..292f3592 100644 --- a/tests/unit_tests/automodel/adapters/test_hunyuan_adapter.py +++ b/tests/unit_tests/automodel/adapters/test_hunyuan_adapter.py @@ -133,13 +133,14 @@ def create_context(batch, task_type="t2v", data_type="video"): """Helper to create FlowMatchingContext.""" return FlowMatchingContext( noisy_latents=torch.randn(batch["video_latents"].shape), - video_latents=batch["video_latents"], + latents=batch["video_latents"], timesteps=torch.rand(batch["video_latents"].shape[0]) * 1000, sigma=torch.rand(batch["video_latents"].shape[0]), task_type=task_type, data_type=data_type, device=torch.device("cpu"), dtype=torch.float32, + cfg_dropout_prob=0.0, batch=batch, ) @@ -295,13 +296,14 @@ def test_prepare_inputs_2d_text_embeddings(self, hunyuan_adapter): context = FlowMatchingContext( noisy_latents=torch.randn(1, 16, 4, 8, 8), - video_latents=batch["video_latents"].unsqueeze(0), + latents=batch["video_latents"].unsqueeze(0), timesteps=torch.rand(1) * 1000, sigma=torch.rand(1), task_type="t2v", data_type="video", device=torch.device("cpu"), dtype=torch.float32, + cfg_dropout_prob=0.0, batch=batch, ) @@ -315,13 +317,14 @@ def test_prepare_inputs_dtype_conversion(self, hunyuan_adapter, sample_batch): """Test that inputs are converted to correct dtype.""" context = FlowMatchingContext( noisy_latents=torch.randn(2, 16, 4, 8, 8), - 
video_latents=sample_batch["video_latents"], + latents=sample_batch["video_latents"], timesteps=torch.rand(2) * 1000, sigma=torch.rand(2), task_type="t2v", data_type="video", device=torch.device("cpu"), dtype=torch.bfloat16, + cfg_dropout_prob=0.0, batch=sample_batch, ) diff --git a/tests/unit_tests/automodel/adapters/test_model_adapter_base.py b/tests/unit_tests/automodel/adapters/test_model_adapter_base.py index ee20582d..407f22c4 100644 --- a/tests/unit_tests/automodel/adapters/test_model_adapter_base.py +++ b/tests/unit_tests/automodel/adapters/test_model_adapter_base.py @@ -40,13 +40,14 @@ def test_context_creation(self): context = FlowMatchingContext( noisy_latents=torch.randn(2, 16, 4, 8, 8), - video_latents=torch.randn(2, 16, 4, 8, 8), + latents=torch.randn(2, 16, 4, 8, 8), timesteps=torch.rand(2) * 1000, sigma=torch.rand(2), task_type="t2v", data_type="video", device=torch.device("cpu"), dtype=torch.float32, + cfg_dropout_prob=0.0, batch=batch, ) @@ -62,13 +63,14 @@ def test_context_with_i2v_task(self): context = FlowMatchingContext( noisy_latents=torch.randn(2, 16, 4, 8, 8), - video_latents=torch.randn(2, 16, 4, 8, 8), + latents=torch.randn(2, 16, 4, 8, 8), timesteps=torch.rand(2) * 1000, sigma=torch.rand(2), task_type="i2v", data_type="video", device=torch.device("cpu"), dtype=torch.float32, + cfg_dropout_prob=0.0, batch=batch, ) @@ -80,13 +82,14 @@ def test_context_with_image_data(self): context = FlowMatchingContext( noisy_latents=torch.randn(2, 16, 1, 8, 8), - video_latents=torch.randn(2, 16, 1, 8, 8), + latents=torch.randn(2, 16, 1, 8, 8), timesteps=torch.rand(2) * 1000, sigma=torch.rand(2), task_type="t2v", data_type="image", device=torch.device("cpu"), dtype=torch.float32, + cfg_dropout_prob=0.0, batch=batch, ) @@ -102,13 +105,14 @@ def test_context_batch_access(self): context = FlowMatchingContext( noisy_latents=torch.randn(2, 16, 4, 8, 8), - video_latents=batch["video_latents"], + latents=batch["video_latents"], timesteps=torch.rand(2) * 1000, sigma=torch.rand(2), task_type="t2v", data_type="video", device=torch.device("cpu"), dtype=torch.float32, + cfg_dropout_prob=0.0, batch=batch, ) @@ -127,18 +131,19 @@ def test_context_tensor_shapes(self): batch = {"video_latents": torch.randn(shape)} context = FlowMatchingContext( noisy_latents=torch.randn(shape), - video_latents=torch.randn(shape), + latents=torch.randn(shape), timesteps=torch.rand(shape[0]) * 1000, sigma=torch.rand(shape[0]), task_type="t2v", data_type="video", device=torch.device("cpu"), dtype=torch.float32, + cfg_dropout_prob=0.0, batch=batch, ) assert context.noisy_latents.shape == shape - assert context.video_latents.shape == shape + assert context.latents.shape == shape assert context.timesteps.shape == (shape[0],) assert context.sigma.shape == (shape[0],) @@ -150,13 +155,14 @@ def test_context_different_dtypes(self): batch = {"video_latents": torch.randn(2, 16, 4, 8, 8)} context = FlowMatchingContext( noisy_latents=torch.randn(2, 16, 4, 8, 8), - video_latents=torch.randn(2, 16, 4, 8, 8), + latents=torch.randn(2, 16, 4, 8, 8), timesteps=torch.rand(2) * 1000, sigma=torch.rand(2), task_type="t2v", data_type="video", device=torch.device("cpu"), dtype=dtype, + cfg_dropout_prob=0.0, batch=batch, ) diff --git a/tests/unit_tests/automodel/adapters/test_simple_adapter.py b/tests/unit_tests/automodel/adapters/test_simple_adapter.py index 473cc8d3..5f1a96bd 100644 --- a/tests/unit_tests/automodel/adapters/test_simple_adapter.py +++ b/tests/unit_tests/automodel/adapters/test_simple_adapter.py @@ -79,13 +79,14 @@ def 
sample_context(): } return FlowMatchingContext( noisy_latents=torch.randn(2, 16, 4, 8, 8), - video_latents=batch["video_latents"], + latents=batch["video_latents"], timesteps=torch.rand(2) * 1000, sigma=torch.rand(2), task_type="t2v", data_type="video", device=torch.device("cpu"), dtype=torch.float32, + cfg_dropout_prob=0.0, batch=batch, ) @@ -137,13 +138,14 @@ def test_prepare_inputs_2d_text_embeddings(self, simple_adapter): context = FlowMatchingContext( noisy_latents=torch.randn(1, 16, 4, 8, 8), - video_latents=batch["video_latents"].unsqueeze(0), + latents=batch["video_latents"].unsqueeze(0), timesteps=torch.rand(1) * 1000, sigma=torch.rand(1), task_type="t2v", data_type="video", device=torch.device("cpu"), dtype=torch.float32, + cfg_dropout_prob=0.0, batch=batch, ) @@ -163,13 +165,14 @@ def test_prepare_inputs_different_batch_sizes(self, simple_adapter): context = FlowMatchingContext( noisy_latents=torch.randn(batch_size, 16, 4, 8, 8), - video_latents=batch["video_latents"], + latents=batch["video_latents"], timesteps=torch.rand(batch_size) * 1000, sigma=torch.rand(batch_size), task_type="t2v", data_type="video", device=torch.device("cpu"), dtype=torch.float32, + cfg_dropout_prob=0.0, batch=batch, ) @@ -189,13 +192,14 @@ def test_prepare_inputs_different_dtypes(self, simple_adapter): context = FlowMatchingContext( noisy_latents=torch.randn(2, 16, 4, 8, 8), - video_latents=batch["video_latents"], + latents=batch["video_latents"], timesteps=torch.rand(2) * 1000, sigma=torch.rand(2), task_type="t2v", data_type="video", device=torch.device("cpu"), dtype=dtype, + cfg_dropout_prob=0.0, batch=batch, ) @@ -247,13 +251,14 @@ def test_forward_output_shape(self, simple_adapter, mock_model): context = FlowMatchingContext( noisy_latents=torch.randn(shape), - video_latents=batch["video_latents"], + latents=batch["video_latents"], timesteps=torch.rand(shape[0]) * 1000, sigma=torch.rand(shape[0]), task_type="t2v", data_type="video", device=torch.device("cpu"), dtype=torch.float32, + cfg_dropout_prob=0.0, batch=batch, ) @@ -302,13 +307,14 @@ def test_full_workflow(self, simple_adapter, mock_model): context = FlowMatchingContext( noisy_latents=torch.randn(2, 16, 4, 8, 8), - video_latents=batch["video_latents"], + latents=batch["video_latents"], timesteps=torch.rand(2) * 1000, sigma=torch.rand(2), task_type="t2v", data_type="video", device=torch.device("cpu"), dtype=torch.float32, + cfg_dropout_prob=0.0, batch=batch, ) @@ -337,13 +343,14 @@ def test_multiple_forward_passes(self, simple_adapter, mock_model): context = FlowMatchingContext( noisy_latents=torch.randn(2, 16, 4, 8, 8), - video_latents=batch["video_latents"], + latents=batch["video_latents"], timesteps=torch.rand(2) * 1000, sigma=torch.rand(2), task_type="t2v", data_type="video", device=torch.device("cpu"), dtype=torch.float32, + cfg_dropout_prob=0.0, batch=batch, ) @@ -364,13 +371,14 @@ def test_with_different_task_types(self, simple_adapter, mock_model): context = FlowMatchingContext( noisy_latents=torch.randn(2, 16, 4, 8, 8), - video_latents=batch["video_latents"], + latents=batch["video_latents"], timesteps=torch.rand(2) * 1000, sigma=torch.rand(2), task_type=task_type, data_type="video", device=torch.device("cpu"), dtype=torch.float32, + cfg_dropout_prob=0.0, batch=batch, ) diff --git a/tests/unit_tests/automodel/data/test_dataloader.py b/tests/unit_tests/automodel/data/test_dataloader.py index f81913c2..8c41a38a 100644 --- a/tests/unit_tests/automodel/data/test_dataloader.py +++ b/tests/unit_tests/automodel/data/test_dataloader.py @@ 
-70,8 +70,8 @@ def create_sample( "prompt": f"Test prompt {idx}", "image_path": f"/fake/path/image_{idx}.jpg", "clip_hidden": torch.randn(1, 77, 768), - "clip_pooled": torch.randn(1, 768), - "t5_hidden": torch.randn(1, 256, 4096), + "pooled_prompt_embeds": torch.randn(1, 768), + "prompt_embeds": torch.randn(1, 256, 4096), "clip_tokens": torch.randint(0, 49408, (1, 77)), "t5_tokens": torch.randint(0, 32128, (1, 256)), } @@ -547,8 +547,8 @@ def test_collate_handles_embeddings(self, simple_dataset): if "clip_hidden" in items[0]: assert batch["clip_hidden"].shape[0] == 4 - assert batch["clip_pooled"].shape[0] == 4 - assert batch["t5_hidden"].shape[0] == 4 + assert batch["pooled_prompt_embeds"].shape[0] == 4 + assert batch["prompt_embeds"].shape[0] == 4 def test_collate_same_resolution_required(self, multi_resolution_dataset): """Test collate requires same resolution in batch.""" @@ -800,12 +800,12 @@ def test_dataloader_batch_to_gpu(self, simple_dataset): # Check embeddings if present if "clip_hidden" in batch: clip_hidden_gpu = batch["clip_hidden"].to(device) - clip_pooled_gpu = batch["clip_pooled"].to(device) - t5_hidden_gpu = batch["t5_hidden"].to(device) + pooled_prompt_embeds_gpu = batch["pooled_prompt_embeds"].to(device) + prompt_embeds_gpu = batch["prompt_embeds"].to(device) assert clip_hidden_gpu.is_cuda - assert clip_pooled_gpu.is_cuda - assert t5_hidden_gpu.is_cuda + assert pooled_prompt_embeds_gpu.is_cuda + assert prompt_embeds_gpu.is_cuda break diff --git a/tests/unit_tests/automodel/data/test_text_to_image_dataset.py b/tests/unit_tests/automodel/data/test_text_to_image_dataset.py index 69968208..53704775 100644 --- a/tests/unit_tests/automodel/data/test_text_to_image_dataset.py +++ b/tests/unit_tests/automodel/data/test_text_to_image_dataset.py @@ -52,8 +52,8 @@ def create_sample( "prompt": f"Test prompt {idx}", "image_path": f"/fake/path/image_{idx}.jpg", "clip_hidden": torch.randn(1, 77, 768), - "clip_pooled": torch.randn(1, 768), - "t5_hidden": torch.randn(1, 256, 4096), + "pooled_prompt_embeds": torch.randn(1, 768), + "prompt_embeds": torch.randn(1, 256, 4096), "clip_tokens": torch.randint(0, 49408, (1, 77)), "t5_tokens": torch.randint(0, 32128, (1, 256)), } @@ -250,8 +250,8 @@ def test_getitem_has_required_fields_embeddings(self, simple_cache_dir): "bucket_id", "aspect_ratio", "clip_hidden", - "clip_pooled", - "t5_hidden", + "pooled_prompt_embeds", + "prompt_embeds", } assert required_fields.issubset(item.keys()) @@ -309,11 +309,11 @@ def test_getitem_embeddings_shapes(self, simple_cache_dir): assert item["clip_hidden"].dim() == 2 assert item["clip_hidden"].shape[0] == 77 - # CLIP pooled should be [768] - assert item["clip_pooled"].dim() == 1 + # Pooled prompt embeds should be [768] + assert item["pooled_prompt_embeds"].dim() == 1 - # T5 hidden should be [256, 4096] - assert item["t5_hidden"].dim() == 2 + # Prompt embeds should be [256, 4096] + assert item["prompt_embeds"].dim() == 2 def test_getitem_tokens_shapes(self, simple_cache_dir): """Test token shapes when train_text_encoder=True.""" diff --git a/tests/unit_tests/automodel/test_flow_matching_pipeline.py b/tests/unit_tests/automodel/test_flow_matching_pipeline.py index 2977cd15..5a4f462e 100644 --- a/tests/unit_tests/automodel/test_flow_matching_pipeline.py +++ b/tests/unit_tests/automodel/test_flow_matching_pipeline.py @@ -335,14 +335,14 @@ def test_loss_weighting_enabled(self, simple_adapter): sigma = torch.tensor([0.3, 0.7]) batch = {} - # Returns: weighted_loss, average_weighted_loss, unweighted_loss, 
average_unweighted_loss, loss_weight, loss_mask - _, scalar_weighted_loss, _, scalar_unweighted_loss, loss_weight, _ = pipeline.compute_loss( - model_pred, target, sigma, batch + # Returns 6 values for megatron compatibility + weighted_loss, average_weighted_loss, unweighted_loss, average_unweighted_loss, loss_weight, _ = ( + pipeline.compute_loss(model_pred, target, sigma, batch) ) - # Verify shapes - assert scalar_weighted_loss.ndim == 0, "Weighted loss should be scalar" - assert scalar_unweighted_loss.ndim == 0, "Unweighted loss should be scalar" + # Verify shapes - average losses should be scalar + assert average_weighted_loss.ndim == 0, "Average weighted loss should be scalar" + assert average_unweighted_loss.ndim == 0, "Average unweighted loss should be scalar" # Verify weight formula: w = 1 + shift * σ expected_weights = 1.0 + 3.0 * sigma @@ -363,12 +363,10 @@ def test_loss_weighting_disabled(self, simple_adapter): sigma = torch.tensor([0.3, 0.7]) batch = {} - _, scalar_weighted_loss, _, scalar_unweighted_loss, loss_weight, _ = pipeline.compute_loss( - model_pred, target, sigma, batch - ) + weighted_loss, _, unweighted_loss, _, loss_weight, _ = pipeline.compute_loss(model_pred, target, sigma, batch) # Without weighting, weighted loss should equal unweighted loss - assert torch.allclose(scalar_weighted_loss, scalar_unweighted_loss, atol=1e-6) + assert torch.allclose(weighted_loss, unweighted_loss, atol=1e-6) # All weights should be 1 assert torch.allclose(loss_weight, torch.ones_like(loss_weight)) @@ -406,12 +404,12 @@ def test_loss_is_non_negative(self, simple_adapter): sigma = torch.rand(2) batch = {} - _, scalar_weighted_loss, _, scalar_unweighted_loss, _, _ = pipeline.compute_loss( + _, average_weighted_loss, _, average_unweighted_loss, _, _ = pipeline.compute_loss( model_pred, target, sigma, batch ) - assert scalar_weighted_loss >= 0, "Weighted loss should be non-negative" - assert scalar_unweighted_loss >= 0, "Unweighted loss should be non-negative" + assert average_weighted_loss >= 0, "Weighted loss should be non-negative" + assert average_unweighted_loss >= 0, "Unweighted loss should be non-negative" def test_loss_is_finite(self, simple_adapter): """Test that computed loss is finite.""" @@ -422,12 +420,12 @@ def test_loss_is_finite(self, simple_adapter): sigma = torch.rand(2) batch = {} - _, scalar_weighted_loss, _, scalar_unweighted_loss, _, _ = pipeline.compute_loss( + _, average_weighted_loss, _, average_unweighted_loss, _, _ = pipeline.compute_loss( model_pred, target, sigma, batch ) - assert torch.isfinite(scalar_weighted_loss), "Weighted loss should be finite" - assert torch.isfinite(scalar_unweighted_loss), "Unweighted loss should be finite" + assert torch.isfinite(average_weighted_loss), "Weighted loss should be finite" + assert torch.isfinite(average_unweighted_loss), "Unweighted loss should be finite" def test_loss_mse_correctness(self, simple_adapter): """Test that base loss is MSE.""" @@ -441,12 +439,12 @@ def test_loss_mse_correctness(self, simple_adapter): sigma = torch.rand(2) batch = {} - _, _, _, scalar_unweighted_loss, _, _ = pipeline.compute_loss(model_pred, target, sigma, batch) + _, _, _, average_unweighted_loss, _, _ = pipeline.compute_loss(model_pred, target, sigma, batch) # Manual MSE calculation expected_mse = nn.functional.mse_loss(model_pred.float(), target.float()) - assert torch.allclose(scalar_unweighted_loss, expected_mse, atol=1e-6) + assert torch.allclose(average_unweighted_loss, expected_mse, atol=1e-6) class TestFullTrainingStep: @@ -458,13 
+456,16 @@ def test_basic_training_step(self, pipeline, mock_model, sample_batch):
         dtype = torch.bfloat16

         # Returns: weighted_loss, average_weighted_loss, loss_mask, metrics
-        _, loss, _, metrics = pipeline.step(mock_model, sample_batch, device, dtype, global_step=0)
+        weighted_loss, average_weighted_loss, loss_mask, metrics = pipeline.step(
+            mock_model, sample_batch, device, dtype, global_step=0
+        )

         # Verify loss
-        assert isinstance(loss, torch.Tensor), "Loss should be a tensor"
-        assert loss.ndim == 0, "Loss should be scalar"
-        assert not torch.isnan(loss), "Loss should not be NaN"
-        assert torch.isfinite(loss), "Loss should be finite"
+        assert isinstance(weighted_loss, torch.Tensor), "Weighted loss should be a tensor"
+        assert isinstance(average_weighted_loss, torch.Tensor), "Average weighted loss should be a tensor"
+        assert average_weighted_loss.ndim == 0, "Average weighted loss should be scalar"
+        assert not torch.isnan(average_weighted_loss), "Loss should not be NaN"
+        assert torch.isfinite(average_weighted_loss), "Loss should be finite"

         # Verify metrics
         assert isinstance(metrics, dict), "Metrics should be a dictionary"
@@ -476,7 +477,7 @@ def test_basic_training_step(self, pipeline, mock_model, sample_batch):
         assert "timestep_min" in metrics
         assert "timestep_max" in metrics
         assert "sampling_method" in metrics
-        print(f"✓ Basic training step test passed - Loss: {loss.item():.4f}")
+        print(f"✓ Basic training step test passed - Loss: {average_weighted_loss.item():.4f}")

     def test_step_with_different_batch_sizes(self, simple_adapter, mock_model):
         """Test training step with different batch sizes."""
@@ -494,10 +495,12 @@ def test_step_with_different_batch_sizes(self, simple_adapter, mock_model):
                 "text_embeddings": torch.randn(batch_size, 77, 4096),
             }

-            _, loss, _, metrics = pipeline.step(mock_model, batch, device, dtype, global_step=0)
+            weighted_loss, average_weighted_loss, loss_mask, metrics = pipeline.step(
+                mock_model, batch, device, dtype, global_step=0
+            )

-            assert isinstance(loss, torch.Tensor), f"Loss should be tensor for batch_size={batch_size}"
-            assert not torch.isnan(loss), f"Loss should not be NaN for batch_size={batch_size}"
+            assert isinstance(weighted_loss, torch.Tensor), f"Loss should be tensor for batch_size={batch_size}"
+            assert not torch.isnan(average_weighted_loss), f"Loss should not be NaN for batch_size={batch_size}"

     def test_step_with_4d_video_latents(self, pipeline, mock_model):
         """Test that 4D video latents are handled (unsqueezed to 5D)."""
@@ -509,17 +512,21 @@ def test_step_with_4d_video_latents(self, pipeline, mock_model):
             "text_embeddings": torch.randn(77, 4096),  # 2D instead of 3D
         }

-        _, loss, _, metrics = pipeline.step(mock_model, batch, device, dtype, global_step=0)
+        weighted_loss, average_weighted_loss, loss_mask, metrics = pipeline.step(
+            mock_model, batch, device, dtype, global_step=0
+        )

-        assert isinstance(loss, torch.Tensor)
-        assert not torch.isnan(loss)
+        assert isinstance(weighted_loss, torch.Tensor)
+        assert not torch.isnan(average_weighted_loss)

     def test_step_metrics_collection(self, pipeline, mock_model, sample_batch):
         """Test that all expected metrics are collected."""
         device = torch.device("cpu")
         dtype = torch.bfloat16

-        _, loss, _, metrics = pipeline.step(mock_model, sample_batch, device, dtype, global_step=100)
+        weighted_loss, average_weighted_loss, loss_mask, metrics = pipeline.step(
+            mock_model, sample_batch, device, dtype, global_step=100
+        )

         expected_keys = [
             "loss",
@@ -546,7 +553,9 @@ def test_step_sigma_in_valid_range(self, pipeline, mock_model, sample_batch):
         device = torch.device("cpu")
         dtype = torch.bfloat16

-        _, loss, _, metrics = pipeline.step(mock_model, sample_batch, device, dtype, global_step=0)
+        weighted_loss, average_weighted_loss, loss_mask, metrics = pipeline.step(
+            mock_model, sample_batch, device, dtype, global_step=0
+        )

         assert 0.0 <= metrics["sigma_min"] <= 1.0, "Sigma min should be in [0, 1]"
         assert 0.0 <= metrics["sigma_max"] <= 1.0, "Sigma max should be in [0, 1]"
@@ -564,7 +573,9 @@ def test_step_timesteps_in_valid_range(self, simple_adapter, mock_model, sample_
         device = torch.device("cpu")
         dtype = torch.bfloat16

-        _, loss, _, metrics = pipeline.step(mock_model, sample_batch, device, dtype, global_step=0)
+        weighted_loss, average_weighted_loss, loss_mask, metrics = pipeline.step(
+            mock_model, sample_batch, device, dtype, global_step=0
+        )

         assert 0.0 <= metrics["timestep_min"] <= num_timesteps
         assert 0.0 <= metrics["timestep_max"] <= num_timesteps
@@ -574,7 +585,9 @@ def test_step_noisy_latents_are_finite(self, pipeline, mock_model, sample_batch)
         device = torch.device("cpu")
         dtype = torch.bfloat16

-        _, loss, _, metrics = pipeline.step(mock_model, sample_batch, device, dtype, global_step=0)
+        weighted_loss, average_weighted_loss, loss_mask, metrics = pipeline.step(
+            mock_model, sample_batch, device, dtype, global_step=0
+        )

         assert torch.isfinite(torch.tensor(metrics["noisy_min"])), "Noisy min should be finite"
         assert torch.isfinite(torch.tensor(metrics["noisy_max"])), "Noisy max should be finite"
@@ -584,10 +597,12 @@ def test_step_with_image_batch(self, pipeline, mock_model, image_batch):
         device = torch.device("cpu")
         dtype = torch.bfloat16

-        _, loss, _, metrics = pipeline.step(mock_model, image_batch, device, dtype, global_step=0)
+        weighted_loss, average_weighted_loss, loss_mask, metrics = pipeline.step(
+            mock_model, image_batch, device, dtype, global_step=0
+        )

-        assert isinstance(loss, torch.Tensor)
-        assert not torch.isnan(loss)
+        assert isinstance(weighted_loss, torch.Tensor)
+        assert not torch.isnan(average_weighted_loss)

         assert metrics["data_type"] == "image"
         assert metrics["task_type"] == "t2v"  # Image always uses t2v
@@ -691,9 +706,11 @@ def test_empty_batch_handling(self, simple_adapter):
         }

         mock_model = MockModel()
-        _, loss, _, metrics = pipeline.step(mock_model, batch, torch.device("cpu"), torch.float32, global_step=0)
+        weighted_loss, average_weighted_loss, loss_mask, metrics = pipeline.step(
+            mock_model, batch, torch.device("cpu"), torch.float32, global_step=0
+        )

-        assert not torch.isnan(loss)
+        assert not torch.isnan(average_weighted_loss)

     def test_large_batch_handling(self, simple_adapter):
         """Test handling of larger batch sizes."""
@@ -709,9 +726,11 @@ def test_large_batch_handling(self, simple_adapter):
         }

         mock_model = MockModel()
-        _, loss, _, metrics = pipeline.step(mock_model, batch, torch.device("cpu"), torch.float32, global_step=0)
+        weighted_loss, average_weighted_loss, loss_mask, metrics = pipeline.step(
+            mock_model, batch, torch.device("cpu"), torch.float32, global_step=0
+        )

-        assert not torch.isnan(loss)
+        assert not torch.isnan(average_weighted_loss)

     def test_extreme_flow_shift_values(self, simple_adapter):
         """Test with extreme flow shift values."""
@@ -732,9 +751,11 @@ def test_extreme_flow_shift_values(self, simple_adapter):
             }

             mock_model = MockModel()
-            _, loss, _, metrics = pipeline.step(mock_model, batch, torch.device("cpu"), torch.float32, global_step=0)
+            weighted_loss, average_weighted_loss, loss_mask, metrics = pipeline.step(
+                mock_model, batch, torch.device("cpu"), torch.float32, global_step=0
+            )

-            assert torch.isfinite(loss), f"Loss should be finite for shift={shift}"
+            assert torch.isfinite(average_weighted_loss), f"Loss should be finite for shift={shift}"

     def test_sigma_clamping_edge_cases(self, simple_adapter):
         """Test sigma clamping at boundary values."""
@@ -771,13 +792,13 @@ def test_multiple_training_steps(self, simple_adapter):
                 "text_embeddings": torch.randn(2, 77, 4096),
             }

-            _, loss, _, metrics = pipeline.step(
+            weighted_loss, average_weighted_loss, loss_mask, metrics = pipeline.step(
                 mock_model, batch, torch.device("cpu"), torch.float32, global_step=step
             )
-            losses.append(loss.item())
+            losses.append(average_weighted_loss.item())

-            assert not torch.isnan(loss), f"Loss became NaN at step {step}"
-            assert torch.isfinite(loss), f"Loss became infinite at step {step}"
+            assert not torch.isnan(average_weighted_loss), f"Loss became NaN at step {step}"
+            assert torch.isfinite(average_weighted_loss), f"Loss became infinite at step {step}"

     def test_pipeline_with_all_sampling_methods(self, simple_adapter):
         """Test pipeline works with all sampling methods."""
@@ -797,9 +818,11 @@ def test_pipeline_with_all_sampling_methods(self, simple_adapter):
                 "text_embeddings": torch.randn(2, 77, 4096),
             }

-            _, loss, _, metrics = pipeline.step(mock_model, batch, torch.device("cpu"), torch.float32, global_step=0)
+            weighted_loss, average_weighted_loss, loss_mask, metrics = pipeline.step(
+                mock_model, batch, torch.device("cpu"), torch.float32, global_step=0
+            )

-            assert not torch.isnan(loss), f"Loss should not be NaN for method={method}"
+            assert not torch.isnan(average_weighted_loss), f"Loss should not be NaN for method={method}"

     def test_pipeline_state_consistency(self, simple_adapter):
         """Test that pipeline maintains consistent state."""
diff --git a/uv.lock b/uv.lock
index 7d341eba..7757604f 100644
--- a/uv.lock
+++ b/uv.lock
@@ -3642,6 +3642,7 @@ dependencies = [
     { name = "kernels" },
     { name = "megatron-energon" },
     { name = "opencv-python-headless" },
+    { name = "sentencepiece" },
 ]

 [package.dev-dependencies]
@@ -3704,6 +3705,7 @@ requires-dist = [
     { name = "kernels" },
     { name = "megatron-energon" },
     { name = "opencv-python-headless", specifier = "==4.10.0.84" },
+    { name = "sentencepiece" },
 ]

 [package.metadata.requires-dev]