From e52b79a9921d354843a89445f4dcf8f00eb64dc0 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 4 Mar 2026 07:57:28 -0800 Subject: [PATCH] worable code for checkpoint conversion, inferenceing, and training. No partial convergence test run --- .../wan/conversion/convert_checkpoints.py | 2 + .../common/flow_matching/__init__.py | 18 + .../common/flow_matching/adapters/__init__.py | 46 ++ .../common/flow_matching/adapters/base.py | 145 +++++ .../common/flow_matching/adapters/flux.py | 224 +++++++ .../common/flow_matching/adapters/hunyuan.py | 181 ++++++ .../common/flow_matching/adapters/simple.py | 92 +++ .../flow_matching/flow_matching_pipeline.py | 593 ++++++++++++++++++ .../common/flow_matching/time_shift_utils.py | 118 ++++ .../common/flow_matching/training_step_t2v.py | 294 +++++++++ .../flow_matching_pipeline_wan.py | 4 +- .../diffusion/models/wan/wan_provider.py | 2 +- .../test_flow_matching_pipeline_wan.py | 2 +- 13 files changed, 1717 insertions(+), 4 deletions(-) create mode 100644 src/megatron/bridge/diffusion/common/flow_matching/__init__.py create mode 100644 src/megatron/bridge/diffusion/common/flow_matching/adapters/__init__.py create mode 100644 src/megatron/bridge/diffusion/common/flow_matching/adapters/base.py create mode 100644 src/megatron/bridge/diffusion/common/flow_matching/adapters/flux.py create mode 100644 src/megatron/bridge/diffusion/common/flow_matching/adapters/hunyuan.py create mode 100644 src/megatron/bridge/diffusion/common/flow_matching/adapters/simple.py create mode 100644 src/megatron/bridge/diffusion/common/flow_matching/flow_matching_pipeline.py create mode 100644 src/megatron/bridge/diffusion/common/flow_matching/time_shift_utils.py create mode 100644 src/megatron/bridge/diffusion/common/flow_matching/training_step_t2v.py diff --git a/examples/diffusion/recipes/wan/conversion/convert_checkpoints.py b/examples/diffusion/recipes/wan/conversion/convert_checkpoints.py index fc8dedc6e6..9e982be362 100644 --- 
a/examples/diffusion/recipes/wan/conversion/convert_checkpoints.py +++ b/examples/diffusion/recipes/wan/conversion/convert_checkpoints.py @@ -143,6 +143,8 @@ def import_hf_to_megatron( bridge = WanBridge() provider = bridge.provider_bridge(hf) provider.perform_initialization = False + if hasattr(provider, "finalize"): + provider.finalize() megatron_models = provider.provide_distributed_model(wrap_with_ddp=False, use_cpu_initialization=True) bridge.load_weights_hf_to_megatron(hf, megatron_models) save_megatron_model(megatron_models, megatron_path, hf_tokenizer_path=None) diff --git a/src/megatron/bridge/diffusion/common/flow_matching/__init__.py b/src/megatron/bridge/diffusion/common/flow_matching/__init__.py new file mode 100644 index 0000000000..5187fe3a3d --- /dev/null +++ b/src/megatron/bridge/diffusion/common/flow_matching/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = [ + "training_step_t2v", + "time_shift_utils", +] diff --git a/src/megatron/bridge/diffusion/common/flow_matching/adapters/__init__.py b/src/megatron/bridge/diffusion/common/flow_matching/adapters/__init__.py new file mode 100644 index 0000000000..eccfe5794d --- /dev/null +++ b/src/megatron/bridge/diffusion/common/flow_matching/adapters/__init__.py @@ -0,0 +1,46 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Model adapters for FlowMatching Pipeline. + +This module provides model-specific adapters that decouple the flow matching +logic from model-specific implementation details. + +Available Adapters: +- ModelAdapter: Abstract base class for all adapters +- HunyuanAdapter: For HunyuanVideo 1.5 style models +- SimpleAdapter: For simple transformer models (e.g., Wan) +- FluxAdapter: For FLUX.1 text-to-image models + +Usage: + from automodel.flow_matching.adapters import HunyuanAdapter, SimpleAdapter, FluxAdapter + + # Or import the base class to create custom adapters + from automodel.flow_matching.adapters import ModelAdapter +""" + +from .base import FlowMatchingContext, ModelAdapter +from .flux import FluxAdapter +from .hunyuan import HunyuanAdapter +from .simple import SimpleAdapter + + +__all__ = [ + "FlowMatchingContext", + "ModelAdapter", + "FluxAdapter", + "HunyuanAdapter", + "SimpleAdapter", +] diff --git a/src/megatron/bridge/diffusion/common/flow_matching/adapters/base.py b/src/megatron/bridge/diffusion/common/flow_matching/adapters/base.py new file mode 100644 index 0000000000..a8a1def40f --- /dev/null +++ b/src/megatron/bridge/diffusion/common/flow_matching/adapters/base.py @@ -0,0 +1,145 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Base classes and data structures for model adapters. + +This module defines the abstract ModelAdapter class and the FlowMatchingContext +dataclass used to pass data between the pipeline and adapters. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict + +import torch +import torch.nn as nn + + +@dataclass +class FlowMatchingContext: + """ + Context object passed to model adapters containing all necessary data. + + This provides a clean interface for adapters to access the data they need + without coupling to the batch dictionary structure. + + Attributes: + noisy_latents: [B, C, F, H, W] or [B, C, H, W] - Noisy latents after interpolation + latents: [B, C, F, H, W] for video or [B, C, H, W] for image - Original clean latents + (also accessible via deprecated 'video_latents' property for backward compatibility) + timesteps: [B] - Sampled timesteps + sigma: [B] - Sigma values + task_type: "t2v" or "i2v" + data_type: "video" or "image" + device: Device for tensor operations + dtype: Data type for tensor operations + cfg_dropout_prob: Probability of dropping text embeddings (setting to 0) during + training for classifier-free guidance (CFG). Defaults to 0.0 for backward compatibility. 
+ batch: Original batch dictionary (for model-specific data) + """ + + # Core tensors + noisy_latents: torch.Tensor + latents: torch.Tensor + timesteps: torch.Tensor + sigma: torch.Tensor + + # Task info + task_type: str + data_type: str + + # Device/dtype + device: torch.device + dtype: torch.dtype + + # Original batch (for model-specific data) + batch: Dict[str, Any] + + # CFG dropout probability (optional with default for backward compatibility) + cfg_dropout_prob: float = 0.0 + + @property + def video_latents(self) -> torch.Tensor: + """Backward compatibility alias for 'latents' field.""" + return self.latents + + +class ModelAdapter(ABC): + """ + Abstract base class for model-specific forward pass logic. + + Implement this class to add support for new model architectures + without modifying the FlowMatchingPipeline. + + The adapter pattern decouples the flow matching logic from model-specific + details like input preparation and forward pass conventions. + + Example: + class MyCustomAdapter(ModelAdapter): + def prepare_inputs(self, context: FlowMatchingContext) -> Dict[str, Any]: + return { + "x": context.noisy_latents, + "t": context.timesteps, + "cond": context.batch["my_conditioning"], + } + + def forward(self, model: nn.Module, inputs: Dict[str, Any]) -> torch.Tensor: + return model(**inputs) + + pipeline = FlowMatchingPipelineV2(model_adapter=MyCustomAdapter()) + """ + + @abstractmethod + def prepare_inputs(self, context: FlowMatchingContext) -> Dict[str, Any]: + """ + Prepare model-specific inputs from the context. + + Args: + context: FlowMatchingContext containing all necessary data + + Returns: + Dictionary of inputs to pass to the model's forward method + """ + pass + + @abstractmethod + def forward(self, model: nn.Module, inputs: Dict[str, Any]) -> torch.Tensor: + """ + Execute the model forward pass. 
+ + Args: + model: The model to call + inputs: Dictionary of inputs from prepare_inputs() + + Returns: + Model prediction tensor + """ + pass + + def post_process_prediction(self, model_pred: torch.Tensor) -> torch.Tensor: + """ + Post-process model prediction if needed. + + Override this for models that return extra outputs or need transformation. + + Args: + model_pred: Raw model output + + Returns: + Processed prediction tensor + """ + if isinstance(model_pred, tuple): + return model_pred[0] + return model_pred diff --git a/src/megatron/bridge/diffusion/common/flow_matching/adapters/flux.py b/src/megatron/bridge/diffusion/common/flow_matching/adapters/flux.py new file mode 100644 index 0000000000..cdac6afa7b --- /dev/null +++ b/src/megatron/bridge/diffusion/common/flow_matching/adapters/flux.py @@ -0,0 +1,224 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Flux model adapter for FlowMatching Pipeline. + +This adapter supports FLUX.1 style models with: +- T5 text embeddings (text_embeddings) +- CLIP pooled embeddings (pooled_prompt_embeds) +- 2D image latents (treated as 1-frame video: [B, C, 1, H, W]) +""" + +import random +from typing import Any, Dict + +import torch +import torch.nn as nn + +from .base import FlowMatchingContext, ModelAdapter + + +class FluxAdapter(ModelAdapter): + """ + Model adapter for FLUX.1 image generation models. 
+ + Supports batch format from multiresolution dataloader: + - image_latents: [B, C, H, W] for images + - text_embeddings: T5 embeddings [B, seq_len, 4096] + - pooled_prompt_embeds: CLIP pooled [B, 768] + + FLUX model forward interface: + - hidden_states: Packed latents + - encoder_hidden_states: T5 text embeddings + - pooled_projections: CLIP pooled embeddings + - timestep: Normalized timesteps [0, 1] + - img_ids / txt_ids: Positional embeddings + """ + + def __init__( + self, + guidance_scale: float = 3.5, + use_guidance_embeds: bool = True, + ): + """ + Initialize FluxAdapter. + + Args: + guidance_scale: Guidance scale for classifier-free guidance + use_guidance_embeds: Whether to use guidance embeddings + """ + self.guidance_scale = guidance_scale + self.use_guidance_embeds = use_guidance_embeds + + def _pack_latents(self, latents: torch.Tensor) -> torch.Tensor: + """ + Pack latents from [B, C, H, W] to Flux format [B, (H//2)*(W//2), C*4]. + + Flux uses a 2x2 patch embedding, so latents are reshaped accordingly. + """ + b, c, h, w = latents.shape + # Reshape: [B, C, H, W] -> [B, C, H//2, 2, W//2, 2] + latents = latents.view(b, c, h // 2, 2, w // 2, 2) + # Permute: -> [B, H//2, W//2, C, 2, 2] + latents = latents.permute(0, 2, 4, 1, 3, 5) + # Reshape: -> [B, (H//2)*(W//2), C*4] + latents = latents.reshape(b, (h // 2) * (w // 2), c * 4) + return latents + + @staticmethod + def _unpack_latents(latents: torch.Tensor, height: int, width: int, vae_scale_factor: int = 8) -> torch.Tensor: + """ + Unpack latents from Flux format back to [B, C, H, W]. + + Args: + latents: Packed latents of shape [B, num_patches, channels] + height: Original image height in pixels + width: Original image width in pixels + vae_scale_factor: VAE compression factor (default: 8) + """ + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
+ height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def _prepare_latent_image_ids( + self, + batch_size: int, + height: int, + width: int, + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor: + """ + Prepare positional IDs for image latents. + + Returns tensor of shape [B, (H//2)*(W//2), 3] containing (batch_idx, y, x). + """ + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = torch.arange(width // 2)[None, :] + + latent_image_ids = latent_image_ids.reshape(-1, 3) + return latent_image_ids.to(device=device, dtype=dtype) + + def prepare_inputs(self, context: FlowMatchingContext) -> Dict[str, Any]: + """ + Prepare inputs for Flux model from FlowMatchingContext. 
+ + Expects 4D image latents: [B, C, H, W] + """ + batch = context.batch + device = context.device + dtype = context.dtype + + # Flux only supports 4D image latents [B, C, H, W] + noisy_latents = context.noisy_latents + if noisy_latents.ndim != 4: + raise ValueError(f"FluxAdapter expects 4D latents [B, C, H, W], got {noisy_latents.ndim}D") + + batch_size, channels, height, width = noisy_latents.shape + + # Get text embeddings (T5) + text_embeddings = batch["text_embeddings"].to(device, dtype=dtype) + if text_embeddings.ndim == 2: + text_embeddings = text_embeddings.unsqueeze(0) + + # Get pooled embeddings (CLIP) - may or may not be present + if "pooled_prompt_embeds" in batch: + pooled_projections = batch["pooled_prompt_embeds"].to(device, dtype=dtype) + elif "clip_pooled" in batch: + pooled_projections = batch["clip_pooled"].to(device, dtype=dtype) + else: + # Create zero embeddings if not provided + pooled_projections = torch.zeros(batch_size, 768, device=device, dtype=dtype) + + if pooled_projections.ndim == 1: + pooled_projections = pooled_projections.unsqueeze(0) + + if random.random() < context.cfg_dropout_prob: + text_embeddings = torch.zeros_like(text_embeddings) + pooled_projections = torch.zeros_like(pooled_projections) + + # Pack latents for Flux transformer + packed_latents = self._pack_latents(noisy_latents) + + # Prepare positional IDs + img_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + # Text positional IDs + text_seq_len = text_embeddings.shape[1] + txt_ids = torch.zeros(batch_size, text_seq_len, 3, device=device, dtype=dtype) + + # Timesteps - Flux expects normalized [0, 1] range + # The pipeline provides timesteps in [0, num_train_timesteps] + timesteps = context.timesteps.to(dtype) / 1000.0 + + # TODO: guidance scale is different across pretraining and finetuning, we need pass it as a hyperparamters. 
+ # needs verify by Pranav + guidance = torch.full((batch_size,), self.guidance_scale, device=device, dtype=torch.float32) + + inputs = { + "hidden_states": packed_latents, + "encoder_hidden_states": text_embeddings, + "pooled_projections": pooled_projections, + "timestep": timesteps, + "img_ids": img_ids, + "txt_ids": txt_ids, + # Store original shape for unpacking + "_original_shape": (batch_size, channels, height, width), + "guidance": guidance, + } + + return inputs + + def forward(self, model: nn.Module, inputs: Dict[str, Any]) -> torch.Tensor: + """ + Execute forward pass for Flux model. + + Returns unpacked prediction in [B, C, H, W] format. + """ + original_shape = inputs.pop("_original_shape") + batch_size, channels, height, width = original_shape + + # Flux forward pass + model_pred = model( + hidden_states=inputs["hidden_states"], + encoder_hidden_states=inputs["encoder_hidden_states"], + pooled_projections=inputs["pooled_projections"], + timestep=inputs["timestep"], + img_ids=inputs["img_ids"], + txt_ids=inputs["txt_ids"], + guidance=inputs["guidance"], + return_dict=False, + ) + + # Handle tuple output + pred = self.post_process_prediction(model_pred) + + # Unpack from Flux format back to [B, C, H, W] + # Pass pixel dimensions (latent * vae_scale_factor) to _unpack_latents + vae_scale_factor = 8 + pred = self._unpack_latents(pred, height * vae_scale_factor, width * vae_scale_factor) + + return pred diff --git a/src/megatron/bridge/diffusion/common/flow_matching/adapters/hunyuan.py b/src/megatron/bridge/diffusion/common/flow_matching/adapters/hunyuan.py new file mode 100644 index 0000000000..240fd3ca83 --- /dev/null +++ b/src/megatron/bridge/diffusion/common/flow_matching/adapters/hunyuan.py @@ -0,0 +1,181 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +HunyuanVideo model adapter for FlowMatching Pipeline. + +This adapter supports HunyuanVideo 1.5 style models with dual text encoders +and image embeddings for image-to-video conditioning. +""" + +from typing import Any, Dict, Tuple + +import torch +import torch.nn as nn + +from .base import FlowMatchingContext, ModelAdapter + + +class HunyuanAdapter(ModelAdapter): + """ + Model adapter for HunyuanVideo 1.5 style models. + + These models use: + - Condition latents concatenated with noisy latents + - Dual text encoders with attention masks + - Image embeddings for i2v + + Expected batch keys: + - text_embeddings: Primary text encoder output [B, seq_len, dim] + - text_mask: Attention mask for primary encoder [B, seq_len] (optional) + - text_embeddings_2: Secondary text encoder output [B, seq_len, dim] (optional) + - text_mask_2: Attention mask for secondary encoder [B, seq_len] (optional) + - image_embeds: Image embeddings for i2v [B, seq_len, dim] (optional) + + Example: + adapter = HunyuanAdapter() + pipeline = FlowMatchingPipelineV2(model_adapter=adapter) + """ + + def __init__( + self, + default_image_embed_shape: Tuple[int, int] = (729, 1152), + use_condition_latents: bool = True, + ): + """ + Initialize the HunyuanAdapter. + + Args: + default_image_embed_shape: Default shape for image embeddings (seq_len, dim) + when not provided in batch. Defaults to (729, 1152). + use_condition_latents: Whether to concatenate condition latents with + noisy latents. Defaults to True. 
+ """ + self.default_image_embed_shape = default_image_embed_shape + self.use_condition_latents = use_condition_latents + + def get_condition_latents(self, latents: torch.Tensor, task_type: str) -> torch.Tensor: + """ + Generate conditional latents based on task type. + + Args: + latents: Input latents [B, C, F, H, W] + task_type: Task type ("t2v" or "i2v") + + Returns: + Conditional latents [B, C+1, F, H, W] + """ + b, c, f, h, w = latents.shape + cond = torch.zeros([b, c + 1, f, h, w], device=latents.device, dtype=latents.dtype) + + if task_type == "t2v": + return cond + elif task_type == "i2v": + cond[:, :-1, :1] = latents[:, :, :1] + cond[:, -1, 0] = 1 + return cond + else: + raise ValueError(f"Unsupported task type: {task_type}") + + def prepare_inputs(self, context: FlowMatchingContext) -> Dict[str, Any]: + """ + Prepare inputs for HunyuanVideo model. + + Args: + context: FlowMatchingContext with batch data + + Returns: + Dictionary containing: + - latents: Noisy latents (optionally concatenated with condition latents) + - timesteps: Timestep values + - encoder_hidden_states: Primary text embeddings + - encoder_attention_mask: Primary attention mask + - encoder_hidden_states_2: Secondary text embeddings + - encoder_attention_mask_2: Secondary attention mask + - image_embeds: Image embeddings + """ + batch = context.batch + batch_size = context.noisy_latents.shape[0] + device = context.device + dtype = context.dtype + + # Get text embeddings + text_embeddings = batch["text_embeddings"].to(device, dtype=dtype) + if text_embeddings.ndim == 2: + text_embeddings = text_embeddings.unsqueeze(0) + + # Get optional elements + text_mask = batch.get("text_mask") + text_embeddings_2 = batch.get("text_embeddings_2") + text_mask_2 = batch.get("text_mask_2") + + if text_mask is not None: + text_mask = text_mask.to(device, dtype=dtype) + if text_embeddings_2 is not None: + text_embeddings_2 = text_embeddings_2.to(device, dtype=dtype) + if text_mask_2 is not None: + 
text_mask_2 = text_mask_2.to(device, dtype=dtype) + + # Handle image embeds for i2v + if context.task_type == "i2v" and "image_embeds" in batch: + image_embeds = batch["image_embeds"].to(device, dtype=dtype) + else: + seq_len, dim = self.default_image_embed_shape + image_embeds = torch.zeros( + batch_size, + seq_len, + dim, + dtype=dtype, + device=device, + ) + + # Prepare latents (with or without condition) + if self.use_condition_latents: + cond_latents = self.get_condition_latents(context.latents, context.task_type) + latents = torch.cat([context.noisy_latents, cond_latents], dim=1) + else: + latents = context.noisy_latents + + return { + "latents": latents, + "timesteps": context.timesteps.to(dtype), + "encoder_hidden_states": text_embeddings, + "encoder_attention_mask": text_mask, + "encoder_hidden_states_2": text_embeddings_2, + "encoder_attention_mask_2": text_mask_2, + "image_embeds": image_embeds, + } + + def forward(self, model: nn.Module, inputs: Dict[str, Any]) -> torch.Tensor: + """ + Execute forward pass for HunyuanVideo model. + + Args: + model: HunyuanVideo model + inputs: Dictionary from prepare_inputs() + + Returns: + Model prediction tensor + """ + model_pred = model( + inputs["latents"], + inputs["timesteps"], + encoder_hidden_states=inputs["encoder_hidden_states"], + encoder_attention_mask=inputs["encoder_attention_mask"], + encoder_hidden_states_2=inputs["encoder_hidden_states_2"], + encoder_attention_mask_2=inputs["encoder_attention_mask_2"], + image_embeds=inputs["image_embeds"], + return_dict=False, + ) + return self.post_process_prediction(model_pred) diff --git a/src/megatron/bridge/diffusion/common/flow_matching/adapters/simple.py b/src/megatron/bridge/diffusion/common/flow_matching/adapters/simple.py new file mode 100644 index 0000000000..efb7aeb4d7 --- /dev/null +++ b/src/megatron/bridge/diffusion/common/flow_matching/adapters/simple.py @@ -0,0 +1,92 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Simple transformer model adapter for FlowMatching Pipeline. + +This adapter supports simple transformer models with a basic interface, +such as Wan-style models. +""" + +from typing import Any, Dict + +import torch +import torch.nn as nn + +from .base import FlowMatchingContext, ModelAdapter + + +class SimpleAdapter(ModelAdapter): + """ + Model adapter for simple transformer models (e.g., Wan). + + These models use a simple interface with: + - hidden_states: noisy latents + - timestep: timestep values + - encoder_hidden_states: text embeddings + + Expected batch keys: + - text_embeddings: Text encoder output [B, seq_len, dim] + + Example: + adapter = SimpleAdapter() + pipeline = FlowMatchingPipelineV2(model_adapter=adapter) + """ + + def prepare_inputs(self, context: FlowMatchingContext) -> Dict[str, Any]: + """ + Prepare inputs for simple transformer model. 
+ + Args: + context: FlowMatchingContext with batch data + + Returns: + Dictionary containing: + - hidden_states: Noisy latents + - timestep: Timestep values + - encoder_hidden_states: Text embeddings + """ + batch = context.batch + device = context.device + dtype = context.dtype + + # Get text embeddings + text_embeddings = batch["text_embeddings"].to(device, dtype=dtype) + if text_embeddings.ndim == 2: + text_embeddings = text_embeddings.unsqueeze(0) + + return { + "hidden_states": context.noisy_latents, + "timestep": context.timesteps.to(dtype), + "encoder_hidden_states": text_embeddings, + } + + def forward(self, model: nn.Module, inputs: Dict[str, Any]) -> torch.Tensor: + """ + Execute forward pass for simple transformer model. + + Args: + model: Transformer model + inputs: Dictionary from prepare_inputs() + + Returns: + Model prediction tensor + """ + model_pred = model( + hidden_states=inputs["hidden_states"], + timestep=inputs["timestep"], + encoder_hidden_states=inputs["encoder_hidden_states"], + return_dict=False, + ) + return self.post_process_prediction(model_pred) diff --git a/src/megatron/bridge/diffusion/common/flow_matching/flow_matching_pipeline.py b/src/megatron/bridge/diffusion/common/flow_matching/flow_matching_pipeline.py new file mode 100644 index 0000000000..c0c2057331 --- /dev/null +++ b/src/megatron/bridge/diffusion/common/flow_matching/flow_matching_pipeline.py @@ -0,0 +1,593 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +FlowMatching Pipeline - Model-agnostic implementation with adapter pattern. + +This module provides a unified FlowMatchingPipeline class that is completely +independent of specific model implementations through the ModelAdapter abstraction. + +Features: +- Model-agnostic design via ModelAdapter protocol +- Various timestep sampling strategies (uniform, logit_normal, mode, lognorm) +- Flow shift transformation +- Sigma clamping for finetuning +- Loss weighting +- Detailed training logging +""" + +import logging +import math +import os +import random +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn as nn + +# Import adapters from the adapters module +from .adapters import ( + FlowMatchingContext, + FluxAdapter, + HunyuanAdapter, + ModelAdapter, + SimpleAdapter, +) + + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Noise Schedule +# ============================================================================= + + +class LinearInterpolationSchedule: + """Simple linear interpolation schedule for flow matching.""" + + def forward(self, x0: torch.Tensor, x1: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + """ + Linear interpolation: x_t = (1 - σ) * x_0 + σ * x_1 + + Args: + x0: Starting point (clean latents) + x1: Ending point (noise) + sigma: Sigma values in [0, 1] + + Returns: + Interpolated tensor at sigma + """ + sigma = sigma.view(-1, *([1] * (x0.ndim - 1))) + return (1.0 - sigma) * x0 + sigma * x1 + + +# ============================================================================= +# Flow Matching Pipeline +# ============================================================================= + + +class FlowMatchingPipeline: + """ + Flow Matching Pipeline - Model-agnostic implementation. 
+ + This pipeline handles all flow matching training logic while delegating + model-specific operations to a ModelAdapter. This allows adding support + for new model architectures without modifying the pipeline code. + + Features: + - Noise scheduling with linear interpolation + - Timestep sampling with various strategies + - Flow shift transformation + - Sigma clamping for finetuning + - Loss weighting + - Detailed training logging + + Example: + # Create pipeline with HunyuanVideo adapter + from automodel.flow_matching.adapters import HunyuanAdapter + + pipeline = FlowMatchingPipeline( + model_adapter=HunyuanAdapter(), + flow_shift=3.0, + timestep_sampling="logit_normal", + ) + + # Training step + weighted_loss, average_weighted_loss, loss_mask, metrics = pipeline.step(model, batch, device, dtype, global_step) + """ + + def __init__( + self, + model_adapter: ModelAdapter, + num_train_timesteps: int = 1000, + timestep_sampling: str = "logit_normal", + flow_shift: float = 3.0, + i2v_prob: float = 0.3, + cfg_dropout_prob: float = 0.1, + # Logit-normal distribution parameters + logit_mean: float = 0.0, + logit_std: float = 1.0, + # Mix sampling parameters + mix_uniform_ratio: float = 0.1, + # Sigma clamping for finetuning (pretrain uses [0.0, 1.0]) + sigma_min: float = 0.0, + sigma_max: float = 1.0, + # Loss weighting + use_loss_weighting: bool = True, + # Logging + log_interval: int = 100, + summary_log_interval: int = 10, + device: Optional[torch.device] = None, + ): + """ + Initialize the FlowMatching pipeline. 
+ + Args: + model_adapter: ModelAdapter instance for model-specific operations + num_train_timesteps: Total number of timesteps for the flow + timestep_sampling: Sampling strategy: + - "uniform": Pure uniform sampling + - "logit_normal": SD3-style logit-normal (recommended) + - "mode": Mode-based sampling + - "lognorm": Log-normal based sampling + - "mix": Mix of lognorm and uniform + flow_shift: Shift parameter for timestep transformation + i2v_prob: Probability of using image-to-video conditioning + cfg_dropout_prob: Probability of dropping text embeddings for CFG training + logit_mean: Mean for logit-normal distribution + logit_std: Std for logit-normal distribution + mix_uniform_ratio: Ratio of uniform samples when using mix + sigma_min: Minimum sigma (0.0 for pretrain) + sigma_max: Maximum sigma (1.0 for pretrain) + use_loss_weighting: Whether to apply flow-based loss weighting + log_interval: Steps between detailed logs + summary_log_interval: Steps between summary logs + device: Device to use for computations + """ + self.model_adapter = model_adapter + self.num_train_timesteps = num_train_timesteps + self.timestep_sampling = timestep_sampling + self.flow_shift = flow_shift + self.i2v_prob = i2v_prob + self.cfg_dropout_prob = cfg_dropout_prob + self.logit_mean = logit_mean + self.logit_std = logit_std + self.mix_uniform_ratio = mix_uniform_ratio + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.use_loss_weighting = use_loss_weighting + self.log_interval = log_interval + self.summary_log_interval = summary_log_interval + self.device = device if device is not None else torch.device("cuda") + + # Initialize noise schedule + self.noise_schedule = LinearInterpolationSchedule() + + def sample_timesteps( + self, + batch_size: int, + device: Optional[torch.device] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, str]: + """ + Sample timesteps and compute sigma values with flow shift. 
+ + Implements the flow shift transformation: + σ = shift / (shift + (1/u - 1)) + + Args: + batch_size: Number of timesteps to sample + device: Device for tensor operations + + Returns: + sigma: Sigma values in [sigma_min, sigma_max] + timesteps: Timesteps in [0, num_train_timesteps] + sampling_method: Name of the sampling method used + """ + if device is None: + device = self.device + + # Determine if we should use uniform (for mix strategy) + use_uniform = self.timestep_sampling == "uniform" or ( + self.mix_uniform_ratio > 0 and torch.rand(1).item() < self.mix_uniform_ratio + ) + + if use_uniform: + u = torch.rand(size=(batch_size,), device=device) + sampling_method = "uniform" + else: + u = self._sample_from_distribution(batch_size, device) + sampling_method = self.timestep_sampling + + # Apply flow shift: σ = shift / (shift + (1/u - 1)) + u_clamped = torch.clamp(u, min=1e-5) # Avoid division by zero + sigma = self.flow_shift / (self.flow_shift + (1.0 / u_clamped - 1.0)) + + # Apply sigma clamping + sigma = torch.clamp(sigma, self.sigma_min, self.sigma_max) + + # Convert sigma to timesteps [0, T] + timesteps = sigma * self.num_train_timesteps + + return sigma, timesteps, sampling_method + + def _sample_from_distribution(self, batch_size: int, device: torch.device) -> torch.Tensor: + """Sample u values from the configured distribution.""" + if self.timestep_sampling == "logit_normal": + u = torch.normal( + mean=self.logit_mean, + std=self.logit_std, + size=(batch_size,), + device=device, + ) + u = torch.sigmoid(u) + + elif self.timestep_sampling == "lognorm": + u = torch.normal(mean=0.0, std=1.0, size=(batch_size,), device=device) + u = torch.sigmoid(u) + + elif self.timestep_sampling == "mode": + mode_scale = 1.29 + u = torch.rand(size=(batch_size,), device=device) + u = 1.0 - u - mode_scale * (torch.cos(math.pi * u / 2.0) ** 2 - 1.0 + u) + u = torch.clamp(u, 0.0, 1.0) + + elif self.timestep_sampling == "mix": + u = torch.normal(mean=0.0, std=1.0, 
size=(batch_size,), device=device) + u = torch.sigmoid(u) + + else: + u = torch.rand(size=(batch_size,), device=device) + + return u + + def determine_task_type(self, data_type: str) -> str: + """Determine task type based on data type and randomization.""" + if data_type == "image": + return "t2v" + elif data_type == "video": + return "i2v" if random.random() < self.i2v_prob else "t2v" + else: + return "t2v" + + def compute_loss( + self, + model_pred: torch.Tensor, + target: torch.Tensor, + sigma: torch.Tensor, + batch: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Compute flow matching loss with optional weighting. + + Loss weight: w = 1 + flow_shift * σ + + Args: + model_pred: Model prediction + target: Target (velocity = noise - clean) + sigma: Sigma values for each sample + batch: Optional batch dictionary containing loss_mask + + Returns: + weighted_loss: Per-element weighted loss + average_weighted_loss: Scalar average weighted loss + unweighted_loss: Per-element raw MSE loss + average_unweighted_loss: Scalar average unweighted loss + loss_weight: Applied weights + loss_mask: Loss mask from batch (or None if not present) + """ + loss = nn.functional.mse_loss(model_pred.float(), target.float(), reduction="none") + loss_mask = batch.get("loss_mask") if batch is not None else None + + if self.use_loss_weighting: + loss_weight = 1.0 + self.flow_shift * sigma + loss_weight = loss_weight.view(-1, *([1] * (loss.ndim - 1))) + else: + loss_weight = torch.ones_like(sigma).view(-1, *([1] * (loss.ndim - 1))) + + loss_weight = loss_weight.to(model_pred.device) + + unweighted_loss = loss + weighted_loss = loss * loss_weight + average_unweighted_loss = unweighted_loss.mean() + average_weighted_loss = weighted_loss.mean() + + return weighted_loss, average_weighted_loss, unweighted_loss, average_unweighted_loss, loss_weight, loss_mask + + def step( + self, + model: 
nn.Module, + batch: Dict[str, Any], + device: torch.device = torch.device("cuda"), + dtype: torch.dtype = torch.bfloat16, + global_step: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Dict[str, Any]]: + """ + Execute a single training step with flow matching. + + Expected batch format: + { + "video_latents": torch.Tensor, # [B, C, F, H, W] for video + OR + "image_latents": torch.Tensor, # [B, C, H, W] for image + "text_embeddings": torch.Tensor, # [B, seq_len, dim] + "data_type": str, # "video" or "image" (optional) + # ... additional model-specific keys handled by adapter + } + + Args: + model: The model to train + batch: Batch of training data + device: Device to use + dtype: Data type for operations + global_step: Current training step (for logging) + + Returns: + weighted_loss: Per-element weighted loss + average_weighted_loss: Scalar average weighted loss + loss_mask: Mask indicating valid loss elements (or None) + metrics: Dictionary of training metrics + """ + debug_mode = os.environ.get("DEBUG_TRAINING", "0") == "1" + detailed_log = global_step % self.log_interval == 0 + summary_log = global_step % self.summary_log_interval == 0 + + # Extract and prepare batch data (either image_latents or video_latents) + if "video_latents" in batch: + latents = batch["video_latents"].to(device, dtype=dtype) + elif "image_latents" in batch: + latents = batch["image_latents"].to(device, dtype=dtype) + else: + raise KeyError("Batch must contain either 'video_latents' or 'image_latents'") + + # latents can be 4D [B, C, H, W] for images or 5D [B, C, F, H, W] for videos + batch_size = latents.shape[0] + + # Determine task type + data_type = batch.get("data_type", "video") + task_type = self.determine_task_type(data_type) + + # ==================================================================== + # Flow Matching: Sample Timesteps + # ==================================================================== + sigma, timesteps, sampling_method = 
self.sample_timesteps(batch_size, device) + + # ==================================================================== + # Flow Matching: Add Noise + # ==================================================================== + noise = torch.randn_like(latents, dtype=torch.float32) + + # x_t = (1 - σ) * x_0 + σ * ε + noisy_latents = self.noise_schedule.forward(latents.float(), noise, sigma) + + # ==================================================================== + # Logging + # ==================================================================== + if detailed_log and debug_mode: + self._log_detailed( + global_step, sampling_method, batch_size, sigma, timesteps, latents, noise, noisy_latents + ) + elif summary_log and debug_mode: + logger.info( + f"[STEP {global_step}] σ=[{sigma.min():.3f},{sigma.max():.3f}] | " + f"t=[{timesteps.min():.1f},{timesteps.max():.1f}] | " + f"noisy=[{noisy_latents.min():.1f},{noisy_latents.max():.1f}] | " + f"{sampling_method}" + ) + + # Convert to target dtype + noisy_latents = noisy_latents.to(dtype) + + # ==================================================================== + # Forward Pass (via adapter) + # ==================================================================== + context = FlowMatchingContext( + noisy_latents=noisy_latents, + latents=latents, + timesteps=timesteps, + sigma=sigma, + task_type=task_type, + data_type=data_type, + device=device, + dtype=dtype, + cfg_dropout_prob=self.cfg_dropout_prob, + batch=batch, + ) + + inputs = self.model_adapter.prepare_inputs(context) + model_pred = self.model_adapter.forward(model, inputs) + + # ==================================================================== + # Target: Flow Matching Velocity + # ==================================================================== + # v = ε - x_0 + target = noise - latents.float() + + # ==================================================================== + # Loss Computation + # ==================================================================== + 
weighted_loss, average_weighted_loss, unweighted_loss, average_unweighted_loss, loss_weight, loss_mask = ( + self.compute_loss(model_pred, target, sigma, batch) + ) + + # Safety check + if torch.isnan(average_weighted_loss) or average_weighted_loss > 100: + logger.error(f"[ERROR] Loss explosion! Loss={average_weighted_loss.item():.3f}") + raise ValueError(f"Loss exploded: {average_weighted_loss.item()}") + + # Logging + if detailed_log and debug_mode: + self._log_loss_detailed( + global_step, model_pred, target, loss_weight, average_unweighted_loss, average_weighted_loss + ) + elif summary_log and debug_mode: + logger.info( + f"[STEP {global_step}] Loss: {average_weighted_loss.item():.6f} | " + f"w=[{loss_weight.min():.2f},{loss_weight.max():.2f}]" + ) + + # Collect metrics + metrics = { + "loss": average_weighted_loss.item(), + "unweighted_loss": average_unweighted_loss.item(), + "sigma_min": sigma.min().item(), + "sigma_max": sigma.max().item(), + "sigma_mean": sigma.mean().item(), + "weight_min": loss_weight.min().item(), + "weight_max": loss_weight.max().item(), + "timestep_min": timesteps.min().item(), + "timestep_max": timesteps.max().item(), + "noisy_min": noisy_latents.min().item(), + "noisy_max": noisy_latents.max().item(), + "sampling_method": sampling_method, + "task_type": task_type, + "data_type": data_type, + } + + return weighted_loss, average_weighted_loss, loss_mask, metrics + + def _log_detailed( + self, + global_step: int, + sampling_method: str, + batch_size: int, + sigma: torch.Tensor, + timesteps: torch.Tensor, + latents: torch.Tensor, + noise: torch.Tensor, + noisy_latents: torch.Tensor, + ): + """Log detailed training information.""" + logger.info("\n" + "=" * 80) + logger.info(f"[STEP {global_step}] FLOW MATCHING") + logger.info("=" * 80) + logger.info("[INFO] Using: x_t = (1-σ)x_0 + σ*ε") + logger.info("") + logger.info(f"[SAMPLING] Method: {sampling_method}") + logger.info(f"[FLOW] Shift: {self.flow_shift}") + logger.info(f"[BATCH] Size: 
{batch_size}") + logger.info("") + logger.info(f"[SIGMA] Range: [{sigma.min():.4f}, {sigma.max():.4f}]") + if sigma.numel() > 1: + logger.info(f"[SIGMA] Mean: {sigma.mean():.4f}, Std: {sigma.std():.4f}") + else: + logger.info(f"[SIGMA] Value: {sigma.item():.4f}") + logger.info("") + logger.info(f"[TIMESTEPS] Range: [{timesteps.min():.2f}, {timesteps.max():.2f}]") + logger.info("") + logger.info(f"[RANGES] Clean latents: [{latents.min():.4f}, {latents.max():.4f}]") + logger.info(f"[RANGES] Noise: [{noise.min():.4f}, {noise.max():.4f}]") + logger.info(f"[RANGES] Noisy latents: [{noisy_latents.min():.4f}, {noisy_latents.max():.4f}]") + + # Sanity check + max_expected = ( + max( + abs(latents.max().item()), + abs(latents.min().item()), + abs(noise.max().item()), + abs(noise.min().item()), + ) + * 1.5 + ) + if abs(noisy_latents.max()) > max_expected or abs(noisy_latents.min()) > max_expected: + logger.info(f"\n⚠️ WARNING: Noisy range seems large! Expected ~{max_expected:.1f}") + else: + logger.info("\n✓ Noisy latents range is reasonable") + logger.info("=" * 80 + "\n") + + def _log_loss_detailed( + self, + global_step: int, + model_pred: torch.Tensor, + target: torch.Tensor, + loss_weight: torch.Tensor, + unweighted_loss: torch.Tensor, + weighted_loss: torch.Tensor, + ): + """Log detailed loss information.""" + logger.info("=" * 80) + logger.info(f"[STEP {global_step}] LOSS DEBUG") + logger.info("=" * 80) + logger.info("[TARGET] Flow matching: v = ε - x_0") + logger.info("") + logger.info(f"[RANGES] Model pred: [{model_pred.min():.4f}, {model_pred.max():.4f}]") + logger.info(f"[RANGES] Target (v): [{target.min():.4f}, {target.max():.4f}]") + logger.info("") + logger.info(f"[WEIGHTS] Formula: 1 + {self.flow_shift} * σ") + logger.info(f"[WEIGHTS] Range: [{loss_weight.min():.4f}, {loss_weight.max():.4f}]") + logger.info(f"[WEIGHTS] Mean: {loss_weight.mean():.4f}") + logger.info("") + unweighted_val = unweighted_loss.item() + weighted_val = weighted_loss.item() + 
logger.info(f"[LOSS] Unweighted: {unweighted_val:.6f}") + logger.info(f"[LOSS] Weighted: {weighted_val:.6f}") + logger.info(f"[LOSS] Impact: {(weighted_val / max(unweighted_val, 1e-8)):.3f}x") + logger.info("=" * 80 + "\n") + + +# ============================================================================= +# Factory Functions +# ============================================================================= + + +def create_adapter(adapter_type: str, **kwargs) -> ModelAdapter: + """ + Factory function to create a model adapter by name. + + Args: + adapter_type: Type of adapter ("hunyuan", "simple", "flux") + **kwargs: Additional arguments passed to the adapter constructor + + Returns: + ModelAdapter instance + """ + adapters = { + "hunyuan": HunyuanAdapter, + "simple": SimpleAdapter, + "flux": FluxAdapter, + } + + if adapter_type not in adapters: + raise ValueError(f"Unknown adapter type: {adapter_type}. Available: {list(adapters.keys())}") + + return adapters[adapter_type](**kwargs) + + +def create_pipeline( + adapter_type: str, + adapter_kwargs: Optional[Dict[str, Any]] = None, + **pipeline_kwargs, +) -> FlowMatchingPipeline: + """ + Factory function to create a pipeline with a specific adapter. 
+
+    Args:
+        adapter_type: Type of adapter ("hunyuan", "simple", "flux")
+        adapter_kwargs: Arguments for the adapter constructor
+        **pipeline_kwargs: Arguments for the pipeline constructor
+
+    Returns:
+        FlowMatchingPipeline instance
+
+    Example:
+        pipeline = create_pipeline(
+            adapter_type="hunyuan",
+            adapter_kwargs={"use_condition_latents": True},
+            flow_shift=3.0,
+            timestep_sampling="logit_normal",
+        )
+    """
+    adapter_kwargs = adapter_kwargs or {}
+    adapter = create_adapter(adapter_type, **adapter_kwargs)
+    return FlowMatchingPipeline(model_adapter=adapter, **pipeline_kwargs)
diff --git a/src/megatron/bridge/diffusion/common/flow_matching/time_shift_utils.py b/src/megatron/bridge/diffusion/common/flow_matching/time_shift_utils.py
new file mode 100644
index 0000000000..d1786dcfce
--- /dev/null
+++ b/src/megatron/bridge/diffusion/common/flow_matching/time_shift_utils.py
@@ -0,0 +1,118 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import math
+
+import numpy as np
+import torch
+
+
+def time_shift(
+    t: torch.Tensor,
+    image_seq_len: int,
+    shift_type: str = "constant",
+    base_shift: float = 0.5,
+    max_shift: float = 1.15,
+    constant: float = 3.0,
+):
+    """
+    Convert timesteps to sigmas with sequence-length-aware shifting.
+ + Args: + t: timesteps in range [0, 1] + image_seq_len: number of tokens (frames * height * width / patch_size^2) + shift_type: "linear", "sqrt", or "constant" + base_shift: base shift for linear mode + max_shift: max shift for linear mode + constant: shift value for constant mode (default 3.0 matches Pika) + + Returns: + sigma values for noise scheduling + """ + if shift_type == "linear": + # Linear interpolation based on sequence length + mu = base_shift + (max_shift - base_shift) * (image_seq_len / 4096) + return math.exp(mu) / (math.exp(mu) + (1 / t - 1)) + + elif shift_type == "sqrt": + # Square root scaling (Flux-style) + # Assuming 128x128 latent space (1024x1024 image) gives mu=3 + mu = np.maximum(1.0, np.sqrt(image_seq_len / (128.0 * 128.0)) * 3.0) + return mu / (mu + (1 / t - 1)) + + elif shift_type == "constant": + # Constant shift (Pika default) + return constant / (constant + (1 / t - 1)) + + else: + # No shift, return original t + return t + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, + batch_size: int, + logit_mean: float = 0.0, + logit_std: float = 1.0, + mode_scale: float = 1.29, +): + """ + Sample timesteps from different distributions for better training coverage. 
+ + Args: + weighting_scheme: "uniform", "logit_normal", or "mode" + batch_size: number of samples to generate + logit_mean: mean for logit-normal distribution + logit_std: std for logit-normal distribution + mode_scale: scale for mode-based sampling + + Returns: + Tensor of shape (batch_size,) with values in [0, 1] + """ + if weighting_scheme == "logit_normal": + # SD3-style logit-normal sampling + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + + elif weighting_scheme == "mode": + # Mode-based sampling (concentrates around certain timesteps) + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + + else: + # Uniform sampling (default) + u = torch.rand(size=(batch_size,), device="cpu") + + return u + + +def get_flow_match_loss_weight(sigma: torch.Tensor, shift: float = 3.0): + """ + Compute loss weights for flow matching based on sigma values. + + Higher sigma (more noise) typically gets higher weight. + + Args: + sigma: sigma values in range [0, 1] + shift: weight scaling factor + + Returns: + Loss weights with same shape as sigma + """ + # Flow matching weight: weight = 1 + shift * sigma + # This gives more weight to noisier timesteps + weight = 1.0 + shift * sigma + return weight diff --git a/src/megatron/bridge/diffusion/common/flow_matching/training_step_t2v.py b/src/megatron/bridge/diffusion/common/flow_matching/training_step_t2v.py new file mode 100644 index 0000000000..b208e666c3 --- /dev/null +++ b/src/megatron/bridge/diffusion/common/flow_matching/training_step_t2v.py @@ -0,0 +1,294 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +import os +from typing import Dict, Tuple + +import torch + +from .time_shift_utils import ( + compute_density_for_timestep_sampling, +) + + +logger = logging.getLogger(__name__) + + +def step_fsdp_transformer_t2v( + scheduler, + model, + batch, + device, + bf16, + # Flow matching parameters + use_sigma_noise: bool = True, + timestep_sampling: str = "uniform", + logit_mean: float = 0.0, + logit_std: float = 1.0, + flow_shift: float = 3.0, + mix_uniform_ratio: float = 0.1, + sigma_min: float = 0.0, # Default: no clamping (pretrain) + sigma_max: float = 1.0, # Default: no clamping (pretrain) + global_step: int = 0, +) -> Tuple[torch.Tensor, Dict]: + """ + Pure flow matching training - DO NOT use scheduler.add_noise(). + + The scheduler's add_noise() uses alpha_t/sigma_t which explodes at low timesteps. 
+ We use simple flow matching: x_t = (1-σ)x_0 + σ*ε + """ + debug_mode = os.environ.get("DEBUG_TRAINING", "0") == "1" + detailed_log = global_step % 100 == 0 + summary_log = global_step % 10 == 0 + + # Extract and prepare batch data + video_latents = batch["video_latents"].to(device, dtype=bf16) + text_embeddings = batch["text_embeddings"].to(device, dtype=bf16) + + assert video_latents.ndim in (4, 5), "Expected video_latents.ndim to be 4 or 5 " + assert text_embeddings.ndim in (2, 3), "Expected text_embeddings.ndim to be 2 or 3 " + # Handle tensor shapes + if video_latents.ndim == 4: + video_latents = video_latents.unsqueeze(0) + + if text_embeddings.ndim == 2: + text_embeddings = text_embeddings.unsqueeze(0) + + batch_size, channels, frames, height, width = video_latents.shape + + # ======================================================================== + # Flow Matching Timestep Sampling + # ======================================================================== + + num_train_timesteps = scheduler.config.num_train_timesteps + + if use_sigma_noise: + use_uniform = torch.rand(1).item() < mix_uniform_ratio + + if use_uniform or timestep_sampling == "uniform": + # Pure uniform: u ~ U(0, 1) + u = torch.rand(size=(batch_size,), device=device) + sampling_method = "uniform" + else: + # Density-based sampling + u = compute_density_for_timestep_sampling( + weighting_scheme=timestep_sampling, + batch_size=batch_size, + logit_mean=logit_mean, + logit_std=logit_std, + ).to(device) + sampling_method = timestep_sampling + + # Apply flow shift: σ = shift/(shift + (1/u - 1)) + u_clamped = torch.clamp(u, min=1e-5) # Avoid division by zero + sigma = flow_shift / (flow_shift + (1.0 / u_clamped - 1.0)) + + # Clamp sigma (only if not full range [0,1]) + # Pretrain uses [0, 1], finetune uses [0.02, 0.55] + if sigma_min > 0.0 or sigma_max < 1.0: + sigma = torch.clamp(sigma, sigma_min, sigma_max) + else: + sigma = torch.clamp(sigma, 0.0, 1.0) + + else: + # Simple uniform without shift 
+ u = torch.rand(size=(batch_size,), device=device) + + # Clamp sigma (only if not full range [0,1]) + if sigma_min > 0.0 or sigma_max < 1.0: + sigma = torch.clamp(u, sigma_min, sigma_max) + else: + sigma = u + sampling_method = "uniform_no_shift" + + # ======================================================================== + # Manual Flow Matching Noise Addition + # ======================================================================== + + # Generate noise + noise = torch.randn_like(video_latents, dtype=torch.float32) + + # CRITICAL: Manual flow matching (NOT scheduler.add_noise!) + # x_t = (1 - σ) * x_0 + σ * ε + sigma_reshaped = sigma.view(-1, 1, 1, 1, 1) + noisy_latents = (1.0 - sigma_reshaped) * video_latents.float() + sigma_reshaped * noise + + # Timesteps for model [0, 1000] + timesteps = sigma * num_train_timesteps + + # ==================================================================== + # DETAILED LOGGING + # ==================================================================== + if detailed_log or debug_mode: + logger.info("\n" + "=" * 80) + logger.info(f"[STEP {global_step}] MANUAL FLOW MATCHING") + logger.info("=" * 80) + logger.info("[WARNING] NOT using scheduler.add_noise() - it explodes!") + logger.info("[INFO] Using manual: x_t = (1-σ)x_0 + σ*ε") + logger.info("") + logger.info(f"[SAMPLING] Method: {sampling_method}") + logger.info(f"[FLOW] Shift: {flow_shift}") + logger.info(f"[BATCH] Size: {batch_size}") + logger.info("") + logger.info(f"[U] Range: [{u.min():.4f}, {u.max():.4f}]") + if u.numel() > 1: + logger.info(f"[U] Mean: {u.mean():.4f}, Std: {u.std():.4f}") + else: + logger.info(f"[U] Value: {u.item():.4f}") + logger.info("") + logger.info(f"[SIGMA] Range: [{sigma.min():.4f}, {sigma.max():.4f}]") + if sigma.numel() > 1: + logger.info(f"[SIGMA] Mean: {sigma.mean():.4f}, Std: {sigma.std():.4f}") + else: + logger.info(f"[SIGMA] Value: {sigma.item():.4f}") + logger.info("") + logger.info(f"[TIMESTEPS] Range: [{timesteps.min():.2f}, 
{timesteps.max():.2f}]") + logger.info("") + logger.info(f"[WEIGHTS] Clean: {(1 - sigma_reshaped).squeeze().cpu().numpy()}") + logger.info(f"[WEIGHTS] Noise: {sigma_reshaped.squeeze().cpu().numpy()}") + logger.info("") + logger.info(f"[RANGES] Clean latents: [{video_latents.min():.4f}, {video_latents.max():.4f}]") + logger.info(f"[RANGES] Noise: [{noise.min():.4f}, {noise.max():.4f}]") + logger.info(f"[RANGES] Noisy latents: [{noisy_latents.min():.4f}, {noisy_latents.max():.4f}]") + + # Sanity check + max_expected = ( + max( + abs(video_latents.max().item()), + abs(video_latents.min().item()), + abs(noise.max().item()), + abs(noise.min().item()), + ) + * 1.5 + ) + if abs(noisy_latents.max()) > max_expected or abs(noisy_latents.min()) > max_expected: + logger.info(f"\n⚠️ WARNING: Noisy range seems large! Expected ~{max_expected:.1f}") + else: + logger.info("\n✓ Noisy latents range is reasonable") + logger.info("=" * 80 + "\n") + + elif summary_log: + logger.info( + f"[STEP {global_step}] σ=[{sigma.min():.3f},{sigma.max():.3f}] | " + f"t=[{timesteps.min():.1f},{timesteps.max():.1f}] | " + f"noisy=[{noisy_latents.min():.1f},{noisy_latents.max():.1f}] | " + f"{sampling_method}" + ) + + # Convert to bf16 + noisy_latents = noisy_latents.to(bf16) + timesteps_for_model = timesteps.to(bf16) + + # ======================================================================== + # Forward Pass + # ======================================================================== + + try: + model_pred = model( + hidden_states=noisy_latents, + timestep=timesteps_for_model, + encoder_hidden_states=text_embeddings, + return_dict=False, + ) + + if isinstance(model_pred, tuple): + model_pred = model_pred[0] + + except Exception as e: + logger.info(f"[ERROR] Forward pass failed: {e}") + logger.info( + f"[DEBUG] noisy_latents: {noisy_latents.shape}, range: [{noisy_latents.min()}, {noisy_latents.max()}]" + ) + logger.info( + f"[DEBUG] timesteps: {timesteps_for_model.shape}, range: 
[{timesteps_for_model.min()}, {timesteps_for_model.max()}]" + ) + raise + + # ======================================================================== + # Target: Flow Matching Velocity + # ======================================================================== + + # Flow matching target: v = ε - x_0 + target = noise - video_latents.float() + + # ======================================================================== + # Loss with Flow Weighting + # ======================================================================== + + loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="none") + + # Flow weight: w = 1 + shift * σ + loss_weight = 1.0 + flow_shift * sigma + loss_weight = loss_weight.view(-1, 1, 1, 1, 1).to(device) + + unweighted_loss = loss.mean() + weighted_loss = (loss * loss_weight).mean() + + # Safety check + if torch.isnan(weighted_loss) or weighted_loss > 100: + logger.info(f"[ERROR] Loss explosion! Loss={weighted_loss.item():.3f}") + logger.info("[DEBUG] Stopping training - check hyperparameters") + raise ValueError(f"Loss exploded: {weighted_loss.item()}") + + # ==================================================================== + # LOSS LOGGING + # ==================================================================== + if detailed_log or debug_mode: + logger.info("=" * 80) + logger.info(f"[STEP {global_step}] LOSS DEBUG") + logger.info("=" * 80) + logger.info("[TARGET] Flow matching: v = ε - x_0") + logger.info(f"[PREDICTION] Scheduler type (inference only): {type(scheduler).__name__}") + logger.info("") + logger.info(f"[RANGES] Model pred: [{model_pred.min():.4f}, {model_pred.max():.4f}]") + logger.info(f"[RANGES] Target (v): [{target.min():.4f}, {target.max():.4f}]") + logger.info("") + logger.info(f"[WEIGHTS] Formula: 1 + {flow_shift} * σ") + logger.info(f"[WEIGHTS] Range: [{loss_weight.min():.4f}, {loss_weight.max():.4f}]") + if loss_weight.numel() > 1: + logger.info(f"[WEIGHTS] Mean: {loss_weight.mean():.4f}") + 
else: + logger.info(f"[WEIGHTS] Value: {loss_weight.mean():.4f}") + logger.info("") + logger.info(f"[LOSS] Unweighted: {unweighted_loss.item():.6f}") + logger.info(f"[LOSS] Weighted: {weighted_loss.item():.6f}") + logger.info(f"[LOSS] Impact: {(weighted_loss / max(unweighted_loss, 1e-8)):.3f}x") + logger.info("=" * 80 + "\n") + + elif summary_log: + logger.info( + f"[STEP {global_step}] Loss: {weighted_loss.item():.6f} | " + f"w=[{loss_weight.min():.2f},{loss_weight.max():.2f}]" + ) + + # Metrics + metrics = { + "loss": weighted_loss.item(), + "unweighted_loss": unweighted_loss.item(), + "sigma_min": sigma.min().item(), + "sigma_max": sigma.max().item(), + "sigma_mean": sigma.mean().item(), + "weight_min": loss_weight.min().item(), + "weight_max": loss_weight.max().item(), + "timestep_min": timesteps.min().item(), + "timestep_max": timesteps.max().item(), + "noisy_min": noisy_latents.min().item(), + "noisy_max": noisy_latents.max().item(), + "sampling_method": sampling_method, + } + + return weighted_loss, metrics diff --git a/src/megatron/bridge/diffusion/models/wan/flow_matching/flow_matching_pipeline_wan.py b/src/megatron/bridge/diffusion/models/wan/flow_matching/flow_matching_pipeline_wan.py index 5b18b6480a..badc3ee485 100644 --- a/src/megatron/bridge/diffusion/models/wan/flow_matching/flow_matching_pipeline_wan.py +++ b/src/megatron/bridge/diffusion/models/wan/flow_matching/flow_matching_pipeline_wan.py @@ -16,8 +16,8 @@ import torch import torch.nn as nn -from dfm.src.automodel.flow_matching.adapters.base import FlowMatchingContext, ModelAdapter -from dfm.src.automodel.flow_matching.flow_matching_pipeline import FlowMatchingPipeline +from megatron.bridge.diffusion.common.flow_matching.adapters.base import FlowMatchingContext, ModelAdapter +from megatron.bridge.diffusion.common.flow_matching.flow_matching_pipeline import FlowMatchingPipeline from megatron.core import parallel_state from megatron.bridge.diffusion.models.wan.utils import thd_split_inputs_cp 
diff --git a/src/megatron/bridge/diffusion/models/wan/wan_provider.py b/src/megatron/bridge/diffusion/models/wan/wan_provider.py index 2135c0e0c3..45b983aafd 100644 --- a/src/megatron/bridge/diffusion/models/wan/wan_provider.py +++ b/src/megatron/bridge/diffusion/models/wan/wan_provider.py @@ -54,7 +54,7 @@ class WanModelProvider(TransformerConfig, ModelProviderMixin[VisionModule]): # bf16: bool = False params_dtype: torch.dtype = torch.float32 qkv_format: str = "thd" # "sbhd". NOTE: if we use context parallelism, we need to use "thd" - apply_rope_fusion: bool = True + apply_rope_fusion: bool = False # currently, in Megatron-LM's TE, apply_rope_fusion + thd doesn't support interleaved RoPE bias_activation_fusion: bool = True # these attributes are unused for images/videos, we just set because bridge training requires for LLMs seq_length: int = 1024 diff --git a/tests/unit_tests/diffusion/model/wan/flow_matching/test_flow_matching_pipeline_wan.py b/tests/unit_tests/diffusion/model/wan/flow_matching/test_flow_matching_pipeline_wan.py index efe9b08516..ec62a5abc1 100644 --- a/tests/unit_tests/diffusion/model/wan/flow_matching/test_flow_matching_pipeline_wan.py +++ b/tests/unit_tests/diffusion/model/wan/flow_matching/test_flow_matching_pipeline_wan.py @@ -16,7 +16,7 @@ import pytest import torch -from dfm.src.automodel.flow_matching.adapters.base import FlowMatchingContext +from megatron.bridge.diffusion.common.flow_matching.adapters.base import FlowMatchingContext from megatron.bridge.diffusion.models.wan.flow_matching.flow_matching_pipeline_wan import ( WanAdapter,