From a86856a56cb727b974b7dc31234cd30e3534644b Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 22 Oct 2025 18:42:10 -0700 Subject: [PATCH 01/17] runnanle mcore Wan inference --- examples/recipes/wan/inference_wan.py | 291 ++++++ .../models/wan/flow_matching/__init__.py | 13 + .../flow_matching/flow_inference_pipeline.py | 741 +++++++++++++++ .../models/wan/flow_matching/flow_pipeline.py | 246 +++++ .../models/wan/inference/configs/__init__.py | 53 ++ .../wan/inference/configs/shared_config.py | 21 + .../wan/inference/configs/wan_i2v_14B.py | 36 + .../wan/inference/configs/wan_t2v_14B.py | 29 + .../wan/inference/configs/wan_t2v_1_3B.py | 29 + .../models/wan/inference/utils/fm_solvers.py | 859 ++++++++++++++++++ .../wan/inference/utils/fm_solvers_unipc.py | 802 ++++++++++++++++ .../models/wan/inference/utils/utils.py | 118 +++ .../bridge/models/wan/modules/__init__.py | 13 + src/megatron/bridge/models/wan/modules/t5.py | 513 +++++++++++ .../bridge/models/wan/modules/tokenizers.py | 82 ++ src/megatron/bridge/models/wan/modules/vae.py | 663 ++++++++++++++ src/megatron/bridge/models/wan/rope_utils.py | 61 ++ src/megatron/bridge/models/wan/wan_bridge.py | 225 +++++ .../bridge/models/wan/wan_layer_spec.py | 674 ++++++++++++++ src/megatron/bridge/models/wan/wan_model.py | 387 ++++++++ .../bridge/models/wan/wan_provider.py | 121 +++ src/megatron/bridge/models/wan/wan_step.py | 194 ++++ 22 files changed, 6171 insertions(+) create mode 100644 examples/recipes/wan/inference_wan.py create mode 100644 src/megatron/bridge/models/wan/flow_matching/__init__.py create mode 100644 src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py create mode 100644 src/megatron/bridge/models/wan/flow_matching/flow_pipeline.py create mode 100644 src/megatron/bridge/models/wan/inference/configs/__init__.py create mode 100644 src/megatron/bridge/models/wan/inference/configs/shared_config.py create mode 100644 src/megatron/bridge/models/wan/inference/configs/wan_i2v_14B.py create 
mode 100644 src/megatron/bridge/models/wan/inference/configs/wan_t2v_14B.py create mode 100644 src/megatron/bridge/models/wan/inference/configs/wan_t2v_1_3B.py create mode 100644 src/megatron/bridge/models/wan/inference/utils/fm_solvers.py create mode 100644 src/megatron/bridge/models/wan/inference/utils/fm_solvers_unipc.py create mode 100644 src/megatron/bridge/models/wan/inference/utils/utils.py create mode 100644 src/megatron/bridge/models/wan/modules/__init__.py create mode 100644 src/megatron/bridge/models/wan/modules/t5.py create mode 100644 src/megatron/bridge/models/wan/modules/tokenizers.py create mode 100644 src/megatron/bridge/models/wan/modules/vae.py create mode 100644 src/megatron/bridge/models/wan/rope_utils.py create mode 100644 src/megatron/bridge/models/wan/wan_bridge.py create mode 100644 src/megatron/bridge/models/wan/wan_layer_spec.py create mode 100644 src/megatron/bridge/models/wan/wan_model.py create mode 100644 src/megatron/bridge/models/wan/wan_provider.py create mode 100644 src/megatron/bridge/models/wan/wan_step.py diff --git a/examples/recipes/wan/inference_wan.py b/examples/recipes/wan/inference_wan.py new file mode 100644 index 0000000000..a593f73e0d --- /dev/null +++ b/examples/recipes/wan/inference_wan.py @@ -0,0 +1,291 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
+import argparse +import logging +import os +import sys +import warnings +from datetime import datetime + +warnings.filterwarnings('ignore') + +import random + +import torch +import torch.distributed as dist +from PIL import Image + +from megatron.bridge.models.wan.flow_matching.flow_inference_pipeline import FlowInferencePipeline +from megatron.bridge.models.wan.inference.configs import SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS +from megatron.bridge.models.wan.inference.utils.utils import cache_video, str2bool + +# DEBUGGING +import numpy as np +np.set_printoptions(precision=10, suppress=False) +torch.set_printoptions(precision=6, sci_mode=False) + +EXAMPLE_PROMPT = { + "t2v-1.3B": { + "prompt": + "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + }, + "t2v-14B": { + "prompt": + "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + }, +} + + +def _validate_args(args): + # Basic check + assert args.ckpt_dir is not None, "Please specify the checkpoint directory." + assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}" + assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}" + + # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks. 
+ if args.sample_steps is None: + args.sample_steps = 50 + + if args.sample_shift is None: + args.sample_shift = 5.0 + + # Frames default handled later; no single frame arg anymore + + args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint( + 0, sys.maxsize) + # Size check: only validate provided --sizes; default handled later + if args.sizes is not None and len(args.sizes) > 0: + for s in args.sizes: + assert s in SUPPORTED_SIZES[args.task], ( + f"Unsupport size {s} for task {args.task}, supported sizes are: " + f"{', '.join(SUPPORTED_SIZES[args.task])}") + + +def _parse_args(): + parser = argparse.ArgumentParser( + description="Generate a image or video from a text prompt or image using Wan" + ) + parser.add_argument( + "--task", + type=str, + default="t2v-14B", + choices=list(WAN_CONFIGS.keys()), + help="The task to run.") + parser.add_argument( + "--sizes", + type=str, + nargs="+", + default=None, + choices=list(SIZE_CONFIGS.keys()), + help="A list of sizes to generate multiple images or videos. Example: --sizes 1280*720 1920*1080" + ) + parser.add_argument( + "--frame_nums", + type=int, + nargs="+", + default=None, + help="List of frame counts (each should be 4n+1). Broadcasts if single value." + ) + parser.add_argument( + "--ckpt_dir", + type=str, + default=None, + help="The path to the checkpoint directory.") + parser.add_argument( + "--offload_model", + type=str2bool, + default=None, + help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage." + ) + parser.add_argument( + "--t5_cpu", + action="store_true", + default=False, + help="Whether to place T5 model on CPU.") + parser.add_argument( + "--save_file", + type=str, + default=None, + help="The file to save the generated image or video to.") + parser.add_argument( + "--prompts", + type=str, + nargs="+", + default=None, + help="A list of prompts to generate multiple images or videos. 
Example: --prompts 'a cat' 'a dog'" + ) + parser.add_argument( + "--base_seed", + type=int, + default=-1, + help="The seed to use for generating the image or video.") + parser.add_argument( + "--sample_solver", + type=str, + default='unipc', + choices=['unipc', 'dpm++'], + help="The solver used to sample.") + parser.add_argument( + "--sample_steps", type=int, default=None, help="The sampling steps.") + parser.add_argument( + "--sample_shift", + type=float, + default=None, + help="Sampling shift factor for flow matching schedulers.") + parser.add_argument( + "--sample_guide_scale", + type=float, + default=5.0, + help="Classifier free guidance scale.") + parser.add_argument( + "--tensor_parallel_size", + type=int, + default=1, + help="Tensor parallel size.") + parser.add_argument( + "--context_parallel_size", + type=int, + default=1, + help="Context parallel size.") + parser.add_argument( + "--pipeline_parallel_size", + type=int, + default=1, + help="Pipeline parallel size.") + parser.add_argument( + "--sequence_parallel", + type=str2bool, + default=False, + help="Sequence parallel.") + + args = parser.parse_args() + + _validate_args(args) + + return args + + +def _init_logging(rank): + # logging + if rank == 0: + # set format + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] %(levelname)s: %(message)s", + handlers=[logging.StreamHandler(stream=sys.stdout)]) + else: + logging.basicConfig(level=logging.ERROR) + + +def generate(args): + rank = int(os.getenv("RANK", 0)) + world_size = int(os.getenv("WORLD_SIZE", 1)) + local_rank = int(os.getenv("LOCAL_RANK", 0)) + device = local_rank + _init_logging(rank) + + if args.offload_model is None: + args.offload_model = False if world_size > 1 else True + logging.info( + f"offload_model is not specified, set to {args.offload_model}.") + if world_size > 1: + torch.cuda.set_device(local_rank) + dist.init_process_group( + backend="nccl", + init_method="env://", + rank=rank, + world_size=world_size) + + cfg = 
WAN_CONFIGS[args.task] + + logging.info(f"Generation job args: {args}") + logging.info(f"Generation model config: {cfg}") + + if dist.is_initialized(): + base_seed = [args.base_seed] if rank == 0 else [None] + dist.broadcast_object_list(base_seed, src=0) + args.base_seed = base_seed[0] + + if "t2v" in args.task: + # Resolve prompts list (default to example prompt) + if args.prompts is not None and len(args.prompts) > 0: + prompts = args.prompts + else: + prompts = [EXAMPLE_PROMPT[args.task]["prompt"]] + + # Resolve sizes list (default to first supported size for task) + if args.sizes is not None and len(args.sizes) > 0: + size_keys = args.sizes + else: + size_keys = [SUPPORTED_SIZES[args.task][0]] + + # Resolve frame counts list (default 81) + if args.frame_nums is not None and len(args.frame_nums) > 0: + frame_nums = args.frame_nums + else: + frame_nums = [81] + + # Enforce 1:1 pairing across lists + assert len(prompts) == len(size_keys) == len(frame_nums), ( + f"prompts ({len(prompts)}), sizes ({len(size_keys)}), and frame_nums ({len(frame_nums)}) " + f"must have the same length") + + logging.info("Creating flow inference pipeline.") + pipeline = FlowInferencePipeline( + config=cfg, + checkpoint_dir=args.ckpt_dir, + device_id=device, + rank=rank, + t5_cpu=args.t5_cpu, + tensor_parallel_size=args.tensor_parallel_size, + context_parallel_size=args.context_parallel_size, + pipeline_parallel_size=args.pipeline_parallel_size, + sequence_parallel=args.sequence_parallel, + pipeline_dtype=torch.float32, + ) + + # DEBUGGING + rank = dist.get_rank() + if rank == 0: + print("tensor_parallel_size:", args.tensor_parallel_size) + print("context_parallel_size:", args.context_parallel_size) + print("pipeline_parallel_size:", args.pipeline_parallel_size) + print("sequence_parallel:", args.sequence_parallel) + print("\n\n\n") + + logging.info( + f"Generating videos ...") + videos = pipeline.generate( + prompts=prompts, + sizes=[SIZE_CONFIGS[size] for size in size_keys], + 
frame_nums=frame_nums, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) + + if rank == 0: + for i, video in enumerate(videos): + formatted_experiment_name = (args.save_file) if args.save_file is not None else "DefaultExp" + formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S") + formatted_prompt = prompts[i].replace(" ", "_").replace("/", + "_")[:50] + suffix = '.mp4' + formatted_save_file = f"{args.task}_{formatted_experiment_name}_videoindex{int(i)}_size{size_keys[i].replace('*','x') if sys.platform=='win32' else size_keys[i]}_{formatted_prompt}_{formatted_time}" + suffix + + if "t2v" in args.task: + logging.info(f"Saving generated video to {formatted_save_file}") + cache_video( + tensor=video[None], + save_file=formatted_save_file, + fps=cfg.sample_fps, + nrow=1, + normalize=True, + value_range=(-1, 1)) + logging.info("Finished.") + + +if __name__ == "__main__": + args = _parse_args() + generate(args) diff --git a/src/megatron/bridge/models/wan/flow_matching/__init__.py b/src/megatron/bridge/models/wan/flow_matching/__init__.py new file mode 100644 index 0000000000..d9155f923f --- /dev/null +++ b/src/megatron/bridge/models/wan/flow_matching/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py b/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py new file mode 100644 index 0000000000..5b905cabee --- /dev/null +++ b/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py @@ -0,0 +1,741 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import gc +import logging +import math +import os +import random +import sys +import types +from contextlib import contextmanager +from functools import partial + +import torch +import torch.cuda.amp as amp +import torch.distributed as dist +from tqdm import tqdm + +from megatron.bridge.models.wan.wan_model import WanModel +from megatron.bridge.models.wan.wan_provider import WanModelProvider +from megatron.bridge.models.wan.modules.t5 import T5EncoderModel +from megatron.bridge.models.wan.modules import WanVAE +from megatron.bridge.models.wan.inference.utils.fm_solvers import ( + FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, + retrieve_timesteps, +) +from megatron.bridge.models.wan.inference.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from megatron.core.dist_checkpointing.validation import StrictHandling +from megatron.core import dist_checkpointing, parallel_state +from torch.nn import functional as F + +import math +from typing import Tuple, Union + +class FlowInferencePipeline: + + def __init__( + self, + config, + checkpoint_dir, + device_id=0, + rank=0, + t5_cpu=False, + tensor_parallel_size=1, + context_parallel_size=1, + pipeline_parallel_size=1, + sequence_parallel=False, + pipeline_dtype=torch.float32, + ): + r""" + Initializes the FlowInferencePipeline with the given parameters. 
+ + Args: + config (EasyDict): + Object containing model parameters initialized from config.py + checkpoint_dir (`str`): + Path to directory containing model checkpoints + device_id (`int`, *optional*, defaults to 0): + Id of target GPU device + rank (`int`, *optional*, defaults to 0): + Process rank for distributed training + t5_cpu (`bool`, *optional*, defaults to False): + Whether to place T5 model on CPU. Only works without t5_fsdp. + """ + self.device = torch.device(f"cuda:{device_id}") + self.config = config + self.rank = rank + self.t5_cpu = t5_cpu + self.tensor_parallel_size = tensor_parallel_size + self.context_parallel_size = context_parallel_size + self.pipeline_parallel_size = pipeline_parallel_size + self.sequence_parallel = sequence_parallel + self.pipeline_dtype = pipeline_dtype + self.num_train_timesteps = config.num_train_timesteps + self.param_dtype = config.param_dtype + + self.text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + device=torch.device('cpu'), + checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), + tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), + shard_fn=None) + + self.vae_stride = config.vae_stride + self.patch_size = config.patch_size + self.vae = WanVAE( + vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), + device=self.device) + + wan_checkpoint_dir = os.path.join(checkpoint_dir, "iter_0000000") + self.model = self.setup_model_from_checkpoint(wan_checkpoint_dir) + + self.sp_size = 1 + + if dist.is_initialized(): + dist.barrier() + self.model.to(self.device) + + self.sample_neg_prompt = config.sample_neg_prompt + + + def patchify(self, x, patch_size): + """ + Convert a list of reconstructed video tensor into patch embeddings (inverse of `unpatchify`). 
+ + Args: + x (list[torch.Tensor]): list of tensors, each with shape [C, F * pF, H * pH, W * pW] + patch_size (tuple): (pF, pH, pW) + + Returns: + torch.Tensor: shape [num_patches, C * prod(patch_size)], + where num_patches = F * H * W + """ + out = [] + for u in x: + c, F_pF, H_pH, W_pW = u.shape + pF, pH, pW = patch_size + assert F_pF % pF == 0 and H_pH % pH == 0 and W_pW % pW == 0, \ + "Spatial dimensions must be divisible by patch size." + + F, H, W = F_pF // pF, H_pH // pH, W_pW // pW + + # split spatial dims into (grid, patch) and reorder to match original patch layout: + # start: (C, F_pF, H_pW, W_pW) + # reshape -> (C, F, pF, H, pH, W, pW) + # permute -> (F, H, W, pF, pH, pW, C) + # DEBUGGING + t = u.reshape(c, F, pF, H, pH, W, pW) + # t = u.reshape(c, F, pF, W, pW, H, pH) + t = t.permute(1, 3, 5, 0, 2, 4, 6) + + num_patches = F * H * W + out.append(t.reshape(num_patches, c * (pF * pH * pW))) + return out + + + def unpatchify(self, x: torch.Tensor, grid_sizes: torch.Tensor, out_dim: int) -> torch.Tensor: + r""" + Reconstruct video tensors from patch embeddings. + + Args: + x (Tensor): + Tensor of patchified features, with shape [L, C_out * prod(patch_size)] + grid_sizes (Tensor): + Original spatial-temporal grid dimensions before patching, + shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + Tensor: + # Reconstructed video tensor with shape [C_out, F, H / 8, W / 8] + # ??? list of tensors, because each sample in the batch has a different video shape, the original video shape is determined by the grid_sizes. 
+ list[Tensor]: list of tensors, each with shape [C_out, F, H / 8, W / 8] + """ + + c = out_dim + out = [] + for u, v in zip(x, grid_sizes.tolist()): + u = u[:math.prod(v)].view(*v, *self.patch_size, c) + u = torch.einsum('fhwpqrc->cfphqwr', u) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + out.append(u) + # because the video shapes are different for each sample in the batch, we cannot stack the videos into a single tensor. + # out = torch.stack(out, dim=0) + return out + + + def setup_model_from_checkpoint(self, checkpoint_dir): + + # def init_distributed(tp_size: int = 1, pp_size: int = 1, cp_size: int = 1): + # rank = int(os.environ.get("LOCAL_RANK", 0)) + # world_size = int(os.environ.get("WORLD_SIZE", 1)) + # torch.cuda.set_device(rank % torch.cuda.device_count()) + # torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size) + # parallel_state.initialize_model_parallel(tp_size, pp_size, context_parallel_size=cp_size) + # init_distributed(self.tensor_parallel_size, self.pipeline_parallel_size, self.context_parallel_size) + + provider = WanModelProvider() + provider.tensor_model_parallel_size = self.tensor_parallel_size + provider.pipeline_model_parallel_size = self.pipeline_parallel_size + provider.context_parallel_size = self.context_parallel_size + provider.sequence_parallel = self.sequence_parallel + print(f"provider.sequence_parallel: {provider.sequence_parallel}") + provider.pipeline_dtype = self.pipeline_dtype + # Once all overrides are set, finalize the model provider to ensure the post initialization logic is run + provider.finalize() + provider.initialize_model_parallel(seed=0) + + + ## Method 1: Read from megatron checkpoint + from megatron.bridge.training.model_load_save import load_megatron_model as _load_megatron_model + model = _load_megatron_model( + checkpoint_dir, + mp_overrides={ + "tensor_model_parallel_size": self.tensor_parallel_size, + "pipeline_model_parallel_size": self.pipeline_parallel_size, + 
"context_parallel_size": self.context_parallel_size, + "sequence_parallel": self.sequence_parallel, + "pipeline_dtype": self.pipeline_dtype, + }, + ) + if isinstance(model, list): + model = model[0] + # ## Method 2: Read from megatron checkpoint + # model = provider.provide_distributed_model(wrap_with_ddp=False) + ## Method 3 (not loading checkpoint) + # model = provider.provide() + + return model + + + def grid_sizes_calculation( + self, + input_shape: Tuple[int, int, int], # (D_in, H_in, W_in) + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + dilation: Union[int, Tuple[int, int, int]] = 1 + ) -> Tuple[int, int, int]: + """ + Compute the (f,h,w) output spatial/temporal dimensions of a Conv3d patch embedder. + + Args: + input_shape: (D_in, H_in, W_in) + kernel_size, stride, padding, dilation of the Conv3d patch embedder: either int or 3-tuple + + Returns: + (D_out, H_out, W_out) + """ + + def to_tuple(x): + return (x, x, x) if isinstance(x, int) else x + + kernel_size = to_tuple(kernel_size) + stride = to_tuple(stride) + padding = to_tuple(padding) + dilation = to_tuple(dilation) + + D_in, H_in, W_in = input_shape + + def calc_out(in_size, k, s, p, d): + return math.floor((in_size + 2*p - d*(k - 1) - 1) / s + 1) + + D_out = calc_out(D_in, kernel_size[0], stride[0], padding[0], dilation[0]) + H_out = calc_out(H_in, kernel_size[1], stride[1], padding[1], dilation[1]) + W_out = calc_out(W_in, kernel_size[2], stride[2], padding[2], dilation[2]) + + return [D_out, H_out, W_out] + + + def forward_pp_step( + self, + latent_model_input: torch.Tensor, + grid_sizes: list[Tuple[int, int, int]], + max_video_seq_len: int, + timestep: torch.Tensor, + arg_c: dict, + ) -> torch.Tensor: + """One decode step supporting pipeline parallelism for batch_size=1. + + Returns a tensor containing the noise prediction. 
+ """ + + from megatron.core import parallel_state + from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage, recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank + + pp_world_size = parallel_state.get_pipeline_model_parallel_world_size() + is_pp_first = parallel_state.is_pipeline_first_stage(ignore_virtual=True) + is_pp_last = parallel_state.is_pipeline_last_stage(ignore_virtual=True) + + # TP-only or single-rank + if pp_world_size == 1: + noise_pred_pp = self.model( + latent_model_input, + grid_sizes=grid_sizes, + t=timestep, + **arg_c) + return noise_pred_pp + + # Pipeline-parallel path + hidden_size = self.model.config.hidden_size + batch_size = latent_model_input.shape[1] + noise_pred_pp_shape = list(latent_model_input.shape) + print(f"batch_size: {batch_size}") + + # DEBUGGING + # we should bring x unpatchify out of the model + # x_after_patch_embedding_shape = [16, 3, 104, 60] # ???? + # when bring unpatchified out, for pp communicate last stage to first stage, this should be + # x_after_patch_embedding_shape = [max_video_seq_len, batch_size, (ph pw pt C)] + + if is_pp_first: + # First stage: compute multimodal + first PP slice, send activations, then receive sampled token + hidden_states = self.model( + latent_model_input, + grid_sizes=grid_sizes, + t=timestep, + **arg_c) + print(f"[rank {torch.distributed.get_rank()}] Got here! - self.model") + send_to_next_pipeline_rank(hidden_states) + print(f"[rank {torch.distributed.get_rank()}] Got here! - hidden_states.shape: {hidden_states.shape} - hidden_states.dtype: {hidden_states.dtype}") + print(f"[rank {torch.distributed.get_rank()}] Got here! 
- send_to_next_pipeline_rank") + + noise_pred_pp = broadcast_from_last_pipeline_stage(noise_pred_pp_shape, dtype=torch.float32) + return noise_pred_pp + + if is_pp_last: + # Last stage: recv activations, run final slice + output, sample, broadcast + recv_buffer = torch.empty( + (max_video_seq_len, batch_size, hidden_size), + dtype=next(self.model.parameters()).dtype, + device=latent_model_input[0].device, + ) + recv_from_prev_pipeline_rank_(recv_buffer) + # DEBUGGING + recv_buffer = recv_buffer.to(torch.bfloat16) # ???? + self.model.set_input_tensor(recv_buffer) + noise_pred_pp = self.model( + latent_model_input, + grid_sizes=grid_sizes, + t=timestep, + **arg_c) + + + print("noise_pred_pp_shape: ", noise_pred_pp_shape) + + noise_pred_pp = broadcast_from_last_pipeline_stage(noise_pred_pp_shape, dtype=noise_pred_pp.dtype, tensor=noise_pred_pp.contiguous()) + return noise_pred_pp + + # Intermediate stages: recv -> run local slice -> send -> receive broadcast token + recv_buffer = torch.empty( + (max_video_seq_len, batch_size, hidden_size), + dtype=next(self.model.parameters()).dtype, + device=latent_model_input[0].device, + ) + print(f"[rank {torch.distributed.get_rank()}] Got here! - recv_buffer.shape: {recv_buffer.shape} - recv_buffer.dtype: {recv_buffer.dtype}") + recv_from_prev_pipeline_rank_(recv_buffer) + print(f"[rank {torch.distributed.get_rank()}] Got here! - recv_from_prev_pipeline_rank_") + # DEBUGGING + recv_buffer = recv_buffer.to(torch.bfloat16) # ???? + self.model.set_input_tensor(recv_buffer) + print(f"[rank {torch.distributed.get_rank()}] Got here! 
- self.model.set_input_tensor") + hidden_states = self.model( + latent_model_input, + grid_sizes=grid_sizes, + t=timestep, + **arg_c) + send_to_next_pipeline_rank(hidden_states) + + noise_pred_pp = broadcast_from_last_pipeline_stage(noise_pred_pp_shape, dtype=torch.float32) + return noise_pred_pp + + + def generate(self, + prompts, + sizes, + frame_nums, + shift=5.0, + sample_solver='unipc', + sampling_steps=50, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True): + r""" + Generates video frames from text prompt using diffusion process. + + Args: + input_prompt (`str`): + Text prompt for content generation + size (tupele[`int`], *optional*, defaults to (1280,720)): + Controls video resolution, (width,height). + frame_num (`int`, *optional*, defaults to 81): + How many frames to sample from a video. The number should be 4n+1 + shift (`float`, *optional*, defaults to 5.0): + Noise schedule shift parameter. Affects temporal dynamics + sample_solver (`str`, *optional*, defaults to 'unipc'): + Solver used to sample the video. + sampling_steps (`int`, *optional*, defaults to 40): + Number of diffusion sampling steps. Higher values improve quality but slow generation + guide_scale (`float`, *optional*, defaults 5.0): + Classifier-free guidance scale. Controls prompt adherence vs. creativity + n_prompt (`str`, *optional*, defaults to ""): + Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` + seed (`int`, *optional*, defaults to -1): + Random seed for noise generation. If -1, use random seed. + offload_model (`bool`, *optional*, defaults to True): + If True, offloads models to CPU during generation to save VRAM + + Returns: + torch.Tensor: + Generated video frames tensor. 
Dimensions: (C, N H, W) where: + - C: Color channels (3 for RGB) + - N: Number of frames (81) + - H: Frame height (from size) + - W: Frame width from size) + """ + + # DEBUGGING + run_debug = True + + # size = sizes[0] + # input_prompt = prompts[0] + # frame_num = frame_nums[0] + + # preprocess + target_shapes = [] + for size, frame_num in zip(sizes, frame_nums): + target_shapes.append((self.vae.model.z_dim, (frame_num - 1) // self.vae_stride[0] + 1, + size[1] // self.vae_stride[1], + size[0] // self.vae_stride[2])) + + max_video_seq_len = 0 + seq_lens = [] + for target_shape in target_shapes: + seq_len = math.ceil((target_shape[2] * target_shape[3]) / + (self.patch_size[1] * self.patch_size[2]) * + target_shape[1] / self.sp_size) * self.sp_size + seq_lens.append(seq_len) + max_video_seq_len = max(seq_lens) + + if n_prompt == "": + n_prompt = self.sample_neg_prompt + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) + + ## process context + context_max_len = 512 + context_lens = [] + contexts = [] + contexts_null = [] + for prompt in prompts: + if not self.t5_cpu: + self.text_encoder.model.to(self.device) + context = self.text_encoder([prompt], self.device)[0] + context_null = self.text_encoder([n_prompt], self.device)[0] + if offload_model: + self.text_encoder.model.cpu() + else: + context = self.text_encoder([prompt], torch.device('cpu'))[0].to(self.device) + context_null = self.text_encoder([n_prompt], torch.device('cpu'))[0].to(self.device) + context_lens.append(context_max_len) # all samples have the same context_max_len + contexts.append(context) + contexts_null.append(context_null) + # pad to context_max_len tokens, and stack to a tensor of shape [s, b, hidden] + contexts = [F.pad(context, (0, 0, 0, context_max_len - context.shape[0])) for context in contexts] + contexts_null = [F.pad(context_null, (0, 0, 0, context_max_len - context_null.shape[0])) for context_null in 
contexts_null] + contexts = torch.stack(contexts, dim=1) + contexts_null = torch.stack(contexts_null, dim=1) + + + + ## setup noise + noises = [] + for target_shape in target_shapes: + noises.append( + torch.randn( + target_shape[0], + target_shape[1], + target_shape[2], + target_shape[3], + dtype=torch.float32, + device=self.device, + generator=seed_g) + ) + + # DEBUGGING + print("[DEBUG] noises[0].shape - noises[0].dtype - noises[0].mean() - noises[0].std() - noises[0].norm():", noises[0].shape, noises[0].dtype, noises[0].mean(), noises[0].std(), noises[0].norm()) + print("[DEBUG] noises[0]:", noises[0]) + + # calculate grid_sizes + grid_sizes = [self.grid_sizes_calculation( + input_shape =u.shape[1:], + kernel_size=self.model.patch_size, + stride=self.model.patch_size, + ) for u in noises] + grid_sizes = torch.tensor(grid_sizes, dtype=torch.long) + + + @contextmanager + def noop_no_sync(): + yield + + no_sync = getattr(self.model, 'no_sync', noop_no_sync) + + # evaluation mode + with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): + + if sample_solver == 'unipc': + # Create a prototype scheduler to compute shared timesteps + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sample_scheduler.set_timesteps( + sampling_steps, device=self.device, shift=shift) + timesteps = sample_scheduler.timesteps + + # Instantiate per-sample schedulers so each sample maintains its own state + batch_size_for_schedulers = len(noises) + schedulers = [] + for _ in range(batch_size_for_schedulers): + s = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + s.set_timesteps(sampling_steps, device=self.device, shift=shift) + schedulers.append(s) + elif sample_solver == 'dpm++': + sample_scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + 
use_dynamic_shifting=False) + sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + timesteps, _ = retrieve_timesteps( + sample_scheduler, + device=self.device, + sigmas=sampling_sigmas) + else: + raise NotImplementedError("Unsupported solver.") + + # sample videos + latents = noises + + from megatron.core.packed_seq_params import PackedSeqParams + cu_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)]) + cu_q = cu_q.to(torch.int32).to(self.device) + cu_kv_self = cu_q + cu_kv_cross = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(context_lens), dim=0)]) + cu_kv_cross = cu_kv_cross.to(torch.int32).to(self.device) + packed_seq_params = { + "self_attention": PackedSeqParams( + cu_seqlens_q=cu_q, + cu_seqlens_kv=cu_kv_self, + qkv_format="sbhd", + ), + "cross_attention": PackedSeqParams( + cu_seqlens_q=cu_q, + cu_seqlens_kv=cu_kv_cross, + qkv_format="sbhd", + ), + } + + + arg_c = {'context': contexts, 'max_seq_len': max_video_seq_len, 'packed_seq_params': packed_seq_params} + arg_null = {'context': contexts_null, 'max_seq_len': max_video_seq_len, 'packed_seq_params': packed_seq_params} + + for _, t in enumerate(tqdm(timesteps)): + + batch_size = len(latents) + + # patchify latents + # ??? 
when batch_size > 1, we need to pad to have same length + unpatchified_latents = latents + latents = self.patchify(latents, self.patch_size) + # pad to have same length + for i in range(batch_size): + latents[i] = F.pad(latents[i], (0, 0, 0, max_video_seq_len - latents[i].shape[0])) + latents = torch.stack(latents, dim=1) + + + latent_model_input = latents + timestep = [t] * batch_size + timestep = torch.stack(timestep) + + # DEBUGGING + if run_debug and torch.distributed.get_rank()==0: + print(f"[DEBUG] [rank {torch.distributed.get_rank()}] contexts.shape: {contexts.shape}") + print(f"[DEBUG] [rank {torch.distributed.get_rank()}] max_video_seq_len: {max_video_seq_len}") + print(f"[DEBUG] [rank {torch.distributed.get_rank()}] grid_sizes: {grid_sizes}") + print(f"[DEBUG] [rank {torch.distributed.get_rank()}] latent_model_input.shape: {latent_model_input.shape}") + print(f"[DEBUG] [rank {torch.distributed.get_rank()}] timestep.shape: {timestep.shape}") + + + self.model.to(self.device) + noise_pred_cond = self.forward_pp_step( + latent_model_input, grid_sizes=grid_sizes, max_video_seq_len=max_video_seq_len, timestep=timestep, arg_c=arg_c) + + noise_pred_uncond = self.forward_pp_step( + latent_model_input, grid_sizes=grid_sizes, max_video_seq_len=max_video_seq_len, timestep=timestep, arg_c=arg_null) + + + # noise_pred = noise_pred_uncond + guide_scale * ( + # noise_pred_cond - noise_pred_uncond) + + # DEBUGGING + unpatchified_noise_pred_cond = noise_pred_cond + unpatchified_noise_pred_cond = unpatchified_noise_pred_cond.transpose(0, 1) # bring sbhd -> bshd + # when unpatchifying, the code will truncate the padded videos into the original video shape, based on the grid_sizes. ??? 
+ unpatchified_noise_pred_cond = self.unpatchify(unpatchified_noise_pred_cond, grid_sizes, self.vae.model.z_dim) + + unpatchified_noise_pred_uncond = noise_pred_uncond + unpatchified_noise_pred_uncond = unpatchified_noise_pred_uncond.transpose(0, 1) # bring sbhd -> bshd + # when unpatchifying, the code will truncate the padded videos into the original video shape, based on the grid_sizes. ??? + unpatchified_noise_pred_uncond = self.unpatchify(unpatchified_noise_pred_uncond, grid_sizes, self.vae.model.z_dim) + + # DEBUGGING + if run_debug and torch.distributed.get_rank()==0: + print(f"[DEBUG] unpatchified_noise_pred_cond[0].shape - unpatchified_noise_pred_cond[0].dtype - unpatchified_noise_pred_cond[0].mean() - unpatchified_noise_pred_cond[0].std() - unpatchified_noise_pred_cond[0].norm(): {unpatchified_noise_pred_cond[0].shape} - {unpatchified_noise_pred_cond[0].dtype} - {unpatchified_noise_pred_cond[0].mean()} - {unpatchified_noise_pred_cond[0].std()} - {unpatchified_noise_pred_cond[0].norm()}") + print(f"[DEBUG] unpatchified_noise_pred_uncond[0].shape - unpatchified_noise_pred_uncond[0].dtype - unpatchified_noise_pred_uncond[0].mean() - unpatchified_noise_pred_uncond[0].std() - unpatchified_noise_pred_uncond[0].norm(): {unpatchified_noise_pred_uncond[0].shape} - {unpatchified_noise_pred_uncond[0].dtype} - {unpatchified_noise_pred_uncond[0].mean()} - {unpatchified_noise_pred_uncond[0].std()} - {unpatchified_noise_pred_uncond[0].norm()}") + + + noise_preds = [] + for i in range(batch_size): + noise_pred = unpatchified_noise_pred_uncond[i] + guide_scale * ( + unpatchified_noise_pred_cond[i] - unpatchified_noise_pred_uncond[i]) + noise_preds.append(noise_pred) + + # unpatchified_noise_pred_uncond = unpatchified_noise_pred_uncond[0] + # unpatchified_noise_pred_cond = unpatchified_noise_pred_cond[0] + + # noise_pred = unpatchified_noise_pred_uncond + guide_scale * ( + # unpatchified_noise_pred_cond - unpatchified_noise_pred_uncond) + + # # DEBUGGING + # # we will be 
running unpatchify here??? + # # x0 = latents + # if run_debug and torch.distributed.get_rank()==0: + # print(f"[DEBUG] [rank {torch.distributed.get_rank()}] (before unpatchify) noise_pred_cond.shape: {noise_pred_cond.shape}") + # print(f"[DEBUG] [rank {torch.distributed.get_rank()}] (before unpatchify) noise_pred_uncond.shape: {noise_pred_uncond.shape}") + # noise_pred_cond = noise_pred_cond.transpose(0, 1) + # noise_pred_cond = self.unpatchify(noise_pred_cond, grid_sizes, self.vae.model.z_dim) + # noise_pred_cond = noise_pred_cond.transpose(0, 1) + # noise_pred_uncond = noise_pred_uncond.transpose(0, 1) + # noise_pred_uncond = self.unpatchify(noise_pred_uncond, grid_sizes, self.vae.model.z_dim) + # noise_pred_uncond = noise_pred_uncond.transpose(0, 1) + # if run_debug and torch.distributed.get_rank()==0: + # print(f"[DEBUG] [rank {torch.distributed.get_rank()}] (after unpatchify) noise_pred_cond.shape: {noise_pred_cond.shape}") + # print(f"[DEBUG] [rank {torch.distributed.get_rank()}] (after unpatchify) noise_pred_uncond.shape: {noise_pred_uncond.shape}") + # print(stop_here) + + # # we run unpatchify here, but unpatchify should be run seprately for each sample in the batch, because the video shape is different for each sample in the batch. + # # ??? when batch_size > 1, we need to run sample_scheduler.step seprately for each sample in the batch. 
+ # noise_pred = noise_pred.transpose(0, 1) # bring sbhd -> bshd + # noise_pred = self.unpatchify(noise_pred, grid_sizes, self.vae.model.z_dim) + + # print("[DEBUG] len(noise_pred): ", len(noise_pred)) + # print("[DEBUG] len(unpatchified_latents): ", len(unpatchified_latents)) + # print("[DEBUG] noise_pred[0].shape - noise_pred[0].dtype - noise_pred[0].mean() - noise_pred[0].std() - noise_pred[0].norm(): ", noise_pred[0].shape, noise_pred[0].dtype, noise_pred[0].mean(), noise_pred[0].std(), noise_pred[0].norm()) + # print("[DEBUG] unpatchified_latents[0].shape - unpatchified_latents[0].dtype - unpatchified_latents[0].mean() - unpatchified_latents[0].std() - unpatchified_latents[0].norm(): ", unpatchified_latents[0].shape, unpatchified_latents[0].dtype, unpatchified_latents[0].mean(), unpatchified_latents[0].std(), unpatchified_latents[0].norm()) + + # latents = [] + # for i in range(len(noise_pred)): + # temp_x0 = sample_scheduler.step( + # noise_pred[i].unsqueeze(0), + # t, + # unpatchified_latents[i].unsqueeze(0), + # return_dict=False, + # generator=seed_g)[0] + # latents.append(temp_x0.squeeze(0)) + + # print("len(latents): ", len(latents)) + # print("latents[0].shape: ", latents[0].shape) + + # latents = unpatchified_latents + # print(f"[DEBUG] noise_pred.shape - noise_pred.dtype - noise_pred.mean() - noise_pred.std() - noise_pred.norm(): {noise_pred.shape} - {noise_pred.dtype} - {noise_pred.mean()} - {noise_pred.std()} - {noise_pred.norm()}") + # print(f"[DEBUG] latents[0].shape - latents[0].dtype - latents[0].mean() - latents[0].std() - latents[0].norm(): {latents[0].shape} - {latents[0].dtype} - {latents[0].mean()} - {latents[0].std()} - {latents[0].norm()}") + # print(f"[DEBUG] noise_pred: {noise_pred}") + # print(f"[DEBUG] latents[0]: {latents[0]}") + + print("batch_size: ", batch_size) + + # step and update latents + latents = [] + for i in range(batch_size): + + # DEBUGGING + if run_debug and torch.distributed.get_rank()==0: + print("[DEBUG] 
len(unpatchified_latents): ", len(unpatchified_latents)) + print("[DEBUG] len(noise_preds): ", len(noise_preds)) + print("[DEBUG] unpatchified_latents[i].shape - unpatchified_latents[i].dtype - unpatchified_latents[i].mean() - unpatchified_latents[i].std() - unpatchified_latents[i].norm(): ", unpatchified_latents[i].shape, unpatchified_latents[i].dtype, unpatchified_latents[i].mean(), unpatchified_latents[i].std(), unpatchified_latents[i].norm()) + print("[DEBUG] noise_preds[i].shape - noise_preds[i].dtype - noise_preds[i].mean() - noise_preds[i].std() - noise_preds[i].norm(): ", noise_preds[i].shape, noise_preds[i].dtype, noise_preds[i].mean(), noise_preds[i].std(), noise_preds[i].norm()) + + + if sample_solver == 'unipc': + temp_x0 = schedulers[i].step( + noise_preds[i].unsqueeze(0), + t, + unpatchified_latents[i].unsqueeze(0), + return_dict=False, + generator=seed_g)[0] + else: + temp_x0 = sample_scheduler.step( + noise_preds[i].unsqueeze(0), + t, + unpatchified_latents[i].unsqueeze(0), + return_dict=False, + generator=seed_g)[0] + latents.append(temp_x0.squeeze(0)) + + # # DEBUGGING + # # we will be running unpatchify here??? 
+ # # x0 = latents + # x0 = self.unpatchify(latents, grid_sizes) + + # # loop through each sample in the batch + # videos = [] + # if offload_model: + # self.model.cpu() + # torch.cuda.empty_cache() + # x0 = latents + # if self.rank == 0: + # videos = self.vae.decode(x0) + + # DEBUGGING + print("[DEBUG] len(latents): ", len(latents)) + print("[DEBUG] latents[0].shape - latents[0].dtype - latents[0].mean() - latents[0].std() - latents[0].norm(): ", latents[0].shape, latents[0].dtype, latents[0].mean(), latents[0].std(), latents[0].norm()) + print("[DEBUG] latents[0]: ", latents[0]) + + x0 = latents + if offload_model: + self.model.cpu() + torch.cuda.empty_cache() + if self.rank == 0: + videos = self.vae.decode(x0) + else: + videos = None + + + # # DEBUGGING + # print("len(latents): ", len(latents)) + # print("latents[0].shape - latents[0].dtype - latents[0].mean() - latents[0].std() - latents[0].norm(): ", latents[0].shape, latents[0].dtype, latents[0].mean(), latents[0].std(), latents[0].norm()) + # print("latents[0]: ", latents[0]) + # print("len(videos): ", len(videos)) + if videos is not None: + print("len(videos): ", len(videos)) + print("[DEBUG] videos[0].shape - videos[0].dtype - videos[0].mean() - videos[0].std() - videos[0].norm(): ", videos[0].shape, videos[0].dtype, videos[0].mean(), videos[0].std(), videos[0].norm()) + print("[DEBUG] videos[0]: ", videos[0]) + + del noises, latents + if sample_solver == 'unipc': + del schedulers + else: + del sample_scheduler + if offload_model: + gc.collect() + torch.cuda.synchronize() + if dist.is_initialized(): + dist.barrier() + + return videos if self.rank == 0 else None diff --git a/src/megatron/bridge/models/wan/flow_matching/flow_pipeline.py b/src/megatron/bridge/models/wan/flow_matching/flow_pipeline.py new file mode 100644 index 0000000000..850230eced --- /dev/null +++ b/src/megatron/bridge/models/wan/flow_matching/flow_pipeline.py @@ -0,0 +1,246 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Dict, Optional, Tuple, List + +import numpy as np +import torch +import torch.distributed +from megatron.core import parallel_state +# from megatron.bridge.models.DiTModel.sampler.context_parallel import cat_outputs_cp ??? +from torch import Tensor +from diffusers import WanPipeline + +class FlowPipeline: + """ + FlowPipeline is a class that implements a diffusion model pipeline for video generation. It includes methods for + initializing the pipeline, encoding and decoding video data, performing training steps, denoising, and generating + samples. + Attributes: + ... + Methods: + ... + """ + + def __init__( + self, + model_id="Wan-AI/Wan2.2-T2V-A14B-Diffusers", + vae=None, + seed=1234, + ): + """ + Initializes the FlowPipeline with the given parameters. + + Args: + net: The DiT model. + vae: The Video Tokenizer (optional). + seed (int): Random seed for reproducibility. + + Attributes: + vae: The Video Tokenizer. + net: The DiT model. + _noise_generator: Generator for noise. + seed (int): Random seed for reproducibility. + input_data_key (str): Key for input data. + input_image_key (str): Key for input images. + tensor_kwargs (dict): Tensor keyword arguments for device and dtype. 
+ """ + self.vae = vae + + self.seed = seed + self._noise_generator = None + + self.input_data_key = "video" + self.input_image_key = "images_1024" + self.tensor_kwargs = {"device": "cuda", "dtype": torch.bfloat16} + + pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.float32) + self.scheduler = pipe.scheduler + + + def _initialize_generators(self): + """ + Initializes the random number generators for noise + + This method sets up a generator: + 1. A PyTorch generator for noise, seeded with a combination of the base seed and the data parallel rank. + + Returns: + None + """ + noise_seed = self.seed + 100 * parallel_state.get_data_parallel_rank(with_context_parallel=True) + noise_level_seed = self.seed + 100 * parallel_state.get_data_parallel_rank(with_context_parallel=False) + self._noise_generator = torch.Generator(device="cuda") + self._noise_generator.manual_seed(noise_seed) + + def training_step( + self, model, data_batch: dict[str, torch.Tensor] + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + """ + Performs a single training step for the diffusion model. + + This method is responsible for executing one iteration of the model's training. It involves: + 1. Adding noise to the input data using the SDE process. + 2. Passing the noisy data through the network to generate predictions. + 3. Computing the loss based on the difference between the predictions and the original data. + + Args: + data_batch (dict): raw data batch draw from the training data loader. + + Returns: + A tuple with the output batch and the computed loss. 
+ """ + + # DEBUGGING + run_debug = False + if run_debug and torch.distributed.get_rank()==0: + print("---- Sample info [FlowPipeline.training_step] ----") + print(f"data_batch['video_latents'] shape: {data_batch['video_latents'].shape}") + print(f"data_batch['context_embeddings'] shape: {data_batch['context_embeddings'].shape}") + print(f"data_batch['loss_mask'] shape: {data_batch['loss_mask'].shape}") + print(f"data_batch['grid_sizes']: {data_batch['grid_sizes']}") + print(f"data_batch['packed_seq_params']: {data_batch['packed_seq_params']}") + print(f"data_batch['max_video_seq_len']: {data_batch['max_video_seq_len']}") + + + video_latents = data_batch['video_latents'] + max_video_seq_len = data_batch['max_video_seq_len'] + context_embeddings = data_batch['context_embeddings'] + grid_sizes = data_batch['grid_sizes'] + packed_seq_params = data_batch['packed_seq_params'] + + + # Get the input data to noise and denoise~(image, video) and the corresponding conditioner. + self.model = model + + + # Get timesteps + batch_size = video_latents.shape[1] + device = video_latents.device + timesteps = torch.randint(0, self.scheduler.config.num_train_timesteps, (batch_size,), device=device) + + # Generate noise + # shape of latents is [S, B, (C pF pH pW)] + noise_batch = torch.randn_like(video_latents) + + + # DEBUGGING + if run_debug and torch.distributed.get_rank()==0: + print("---- Sample info [FlowPipeline.training_step] ----") + print(f"noise_batch shape: {noise_batch.shape}") + print(f"timesteps shape: {timesteps.shape}") + print(f"video_latents shape: {video_latents.shape}") + print("--------------------------------") + + # ??? can this add_noise method used for videos of different sizes and just padding? + # => it should be, because the main formula is: noisy_latents = alpha_t * original_samples + sigma_t * noise + # Apply scheduler noise based on timesteps + # DEBUGGING + # bring to shape [batch_size, ...] 
to run add_noise
+        noisy_latents = self.scheduler.add_noise(video_latents.transpose(0, 1), noise_batch.transpose(0, 1), timesteps)
+        noisy_latents = noisy_latents.transpose(0, 1)
+
+        # Pass through model
+        # noise only needed at the last stage
+        if parallel_state.is_pipeline_last_stage():
+            output_batch, loss = self.compute_loss(
+                noisy_latents, noise_batch, timesteps, context_embeddings, grid_sizes, packed_seq_params, max_video_seq_len
+            )
+
+            return output_batch, loss
+        else:
+            hidden_states = self.compute_loss(
+                noisy_latents, noise_batch, timesteps, context_embeddings, grid_sizes, packed_seq_params, max_video_seq_len  # BUGFIX: noise_batch was omitted, shifting all later positional args by one slot
+            )
+            return hidden_states
+
+    # def get_data_and_condition(self, data_batch: dict[str, Tensor]) -> Tuple[Tensor]:
+    #     """
+    #     Retrieves data and conditioning for model input.
+
+    #     Args:
+    #         data_batch: Batch of input data.
+
+    #     Returns:
+    #         ...
+    #     """
+    #     ...
+    #     return None
+
+    def compute_loss(
+        self,
+        video_latents: torch.Tensor,
+        noise_batch: torch.Tensor,
+        timesteps: torch.Tensor,
+        context_embeddings: torch.Tensor,
+        grid_sizes: List[Tuple[int, int, int]],
+        packed_seq_params: dict,
+        max_video_seq_len: int,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Computes the loss for the given latents, timesteps, context_embeddings, grid_sizes, and packed_seq_params.
+        """
+
+        # ??? the shape of latents is [S, B, (ph pw pt C)]
+        # ??? 
the shape of noise is [S, B, (ph pw pt C)]
+        # loss_mask is [S, B], will be transferred in WanForwardStep to combine with loss to get the final loss
+
+        # condition would be:
+        # t5_text_embeddings, t5_text_mask, seq_len_q, seq_len_kv, pos_ids, latent_shape, grid_sizes
+        # the shape of t5_text_embeddings is [S, B, (ph pw pt C)]
+        # the shape of t5_text_mask is [S, B]
+        # the shape of seq_len_q is [B]
+        # the shape of seq_len_kv is [B]
+        # the shape of pos_ids is [S, B, (ph pw pt C)]
+        # the shape of latent_shape is [B, 4]
+        # the shape of grid_sizes is [B, 3]
+
+        # Pass through model
+        if parallel_state.is_pipeline_last_stage():
+            model_predict = self.model(
+                x = video_latents,
+                grid_sizes = grid_sizes,
+                t = timesteps,
+                context = context_embeddings,
+                max_seq_len = max_video_seq_len,
+                packed_seq_params=packed_seq_params,
+            )
+
+            # Compute target based on prediction type
+            if self.scheduler.config.prediction_type == "epsilon":
+                target = noise_batch
+            elif self.scheduler.config.prediction_type == "v_prediction":
+                target = self.scheduler.get_velocity(video_latents, noise_batch, timesteps)  # BUGFIX: `latents` was an undefined name (NameError)
+            elif self.scheduler.config.prediction_type == "flow_prediction":
+                # Flow matching
+                target = video_latents - noise_batch
+            else:
+                raise ValueError(f"Unknown prediction type: {self.scheduler.config.prediction_type}")
+
+            # Compute loss
+            loss = torch.nn.functional.mse_loss(model_predict, target, reduction="mean")
+
+            return model_predict, loss
+
+        else:
+            hidden_states = self.model(
+                x = video_latents,
+                grid_sizes = grid_sizes,
+                t = timesteps,
+                context = context_embeddings,
+                max_seq_len = max_video_seq_len,
+                packed_seq_params=packed_seq_params,
+            )
+
+            return hidden_states
diff --git a/src/megatron/bridge/models/wan/inference/configs/__init__.py b/src/megatron/bridge/models/wan/inference/configs/__init__.py
new file mode 100644
index 0000000000..e7f95d7125
--- /dev/null
+++ b/src/megatron/bridge/models/wan/inference/configs/__init__.py
@@ -0,0 +1,53 @@
+# Copyright 
2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import copy +import os + +os.environ['TOKENIZERS_PARALLELISM'] = 'false' + +from .wan_i2v_14B import i2v_14B +from .wan_t2v_1_3B import t2v_1_3B +from .wan_t2v_14B import t2v_14B + +# the config of t2i_14B is the same as t2v_14B +t2i_14B = copy.deepcopy(t2v_14B) +t2i_14B.__name__ = 'Config: Wan T2I 14B' + +# the config of flf2v_14B is the same as i2v_14B +flf2v_14B = copy.deepcopy(i2v_14B) +flf2v_14B.__name__ = 'Config: Wan FLF2V 14B' +flf2v_14B.sample_neg_prompt = "镜头切换," + flf2v_14B.sample_neg_prompt + +WAN_CONFIGS = { + 't2v-14B': t2v_14B, + 't2v-1.3B': t2v_1_3B, + 'i2v-14B': i2v_14B, + 't2i-14B': t2i_14B, + 'flf2v-14B': flf2v_14B, + 'vace-1.3B': t2v_1_3B, + 'vace-14B': t2v_14B, +} + +SIZE_CONFIGS = { + '720*1280': (720, 1280), + '1280*720': (1280, 720), + '480*832': (480, 832), + '832*480': (832, 480), + '1024*1024': (1024, 1024), +} + +MAX_AREA_CONFIGS = { + '720*1280': 720 * 1280, + '1280*720': 1280 * 720, + '480*832': 480 * 832, + '832*480': 832 * 480, +} + +SUPPORTED_SIZES = { + 't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), + 't2v-1.3B': ('480*832', '832*480'), + 'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), + 'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), + 't2i-14B': tuple(SIZE_CONFIGS.keys()), + 'vace-1.3B': ('480*832', '832*480'), + 'vace-14B': ('720*1280', '1280*720', '480*832', '832*480') +} diff --git a/src/megatron/bridge/models/wan/inference/configs/shared_config.py b/src/megatron/bridge/models/wan/inference/configs/shared_config.py new file mode 100644 index 0000000000..56a99ad433 --- /dev/null +++ b/src/megatron/bridge/models/wan/inference/configs/shared_config.py @@ -0,0 +1,21 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
+import torch +from easydict import EasyDict + +#------------------------ Wan shared config ------------------------# +wan_shared_cfg = EasyDict() + +# t5 +wan_shared_cfg.t5_model = 'umt5_xxl' +wan_shared_cfg.t5_dtype = torch.bfloat16 +wan_shared_cfg.text_len = 512 + +# transformer +# DEBUGGING +wan_shared_cfg.param_dtype = torch.bfloat16 +# wan_shared_cfg.param_dtype = torch.float32 + +# inference +wan_shared_cfg.num_train_timesteps = 1000 +wan_shared_cfg.sample_fps = 16 +wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走' diff --git a/src/megatron/bridge/models/wan/inference/configs/wan_i2v_14B.py b/src/megatron/bridge/models/wan/inference/configs/wan_i2v_14B.py new file mode 100644 index 0000000000..53bf2211b8 --- /dev/null +++ b/src/megatron/bridge/models/wan/inference/configs/wan_i2v_14B.py @@ -0,0 +1,36 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
+import torch +from easydict import EasyDict + +from .shared_config import wan_shared_cfg + +#------------------------ Wan I2V 14B ------------------------# + +i2v_14B = EasyDict(__name__='Config: Wan I2V 14B') +i2v_14B.update(wan_shared_cfg) +i2v_14B.sample_neg_prompt = "镜头晃动," + i2v_14B.sample_neg_prompt + +i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' +i2v_14B.t5_tokenizer = 'google/umt5-xxl' + +# clip +i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14' +i2v_14B.clip_dtype = torch.float16 +i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth' +i2v_14B.clip_tokenizer = 'xlm-roberta-large' + +# vae +i2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth' +i2v_14B.vae_stride = (4, 8, 8) + +# transformer +i2v_14B.patch_size = (1, 2, 2) +i2v_14B.dim = 5120 +i2v_14B.ffn_dim = 13824 +i2v_14B.freq_dim = 256 +i2v_14B.num_heads = 40 +i2v_14B.num_layers = 40 +i2v_14B.window_size = (-1, -1) +i2v_14B.qk_norm = True +i2v_14B.cross_attn_norm = True +i2v_14B.eps = 1e-6 diff --git a/src/megatron/bridge/models/wan/inference/configs/wan_t2v_14B.py b/src/megatron/bridge/models/wan/inference/configs/wan_t2v_14B.py new file mode 100644 index 0000000000..9d0ee69dea --- /dev/null +++ b/src/megatron/bridge/models/wan/inference/configs/wan_t2v_14B.py @@ -0,0 +1,29 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
+from easydict import EasyDict + +from .shared_config import wan_shared_cfg + +#------------------------ Wan T2V 14B ------------------------# + +t2v_14B = EasyDict(__name__='Config: Wan T2V 14B') +t2v_14B.update(wan_shared_cfg) + +# t5 +t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' +t2v_14B.t5_tokenizer = 'google/umt5-xxl' + +# vae +t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth' +t2v_14B.vae_stride = (4, 8, 8) + +# transformer +t2v_14B.patch_size = (1, 2, 2) +t2v_14B.dim = 5120 +t2v_14B.ffn_dim = 13824 +t2v_14B.freq_dim = 256 +t2v_14B.num_heads = 40 +t2v_14B.num_layers = 40 +t2v_14B.window_size = (-1, -1) +t2v_14B.qk_norm = True +t2v_14B.cross_attn_norm = True +t2v_14B.eps = 1e-6 diff --git a/src/megatron/bridge/models/wan/inference/configs/wan_t2v_1_3B.py b/src/megatron/bridge/models/wan/inference/configs/wan_t2v_1_3B.py new file mode 100644 index 0000000000..ea9502b0df --- /dev/null +++ b/src/megatron/bridge/models/wan/inference/configs/wan_t2v_1_3B.py @@ -0,0 +1,29 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
+from easydict import EasyDict + +from .shared_config import wan_shared_cfg + +#------------------------ Wan T2V 1.3B ------------------------# + +t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B') +t2v_1_3B.update(wan_shared_cfg) + +# t5 +t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' +t2v_1_3B.t5_tokenizer = 'google/umt5-xxl' + +# vae +t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth' +t2v_1_3B.vae_stride = (4, 8, 8) + +# transformer +t2v_1_3B.patch_size = (1, 2, 2) +t2v_1_3B.dim = 1536 +t2v_1_3B.ffn_dim = 8960 +t2v_1_3B.freq_dim = 256 +t2v_1_3B.num_heads = 12 +t2v_1_3B.num_layers = 30 +t2v_1_3B.window_size = (-1, -1) +t2v_1_3B.qk_norm = True +t2v_1_3B.cross_attn_norm = True +t2v_1_3B.eps = 1e-6 diff --git a/src/megatron/bridge/models/wan/inference/utils/fm_solvers.py b/src/megatron/bridge/models/wan/inference/utils/fm_solvers.py new file mode 100644 index 0000000000..17bef85000 --- /dev/null +++ b/src/megatron/bridge/models/wan/inference/utils/fm_solvers.py @@ -0,0 +1,859 @@ +# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +# Convert dpm solver for flow matching +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
+ +import inspect +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import ( + KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput, +) +from diffusers.utils import deprecate, is_scipy_available +from diffusers.utils.torch_utils import randn_tensor + +if is_scipy_available(): + pass + + +def get_sampling_sigmas(sampling_steps, shift): + sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps] + sigma = (shift * sigma / (1 + (shift - 1) * sigma)) + + return sigma + + +def retrieve_timesteps( + scheduler, + num_inference_steps=None, + device=None, + timesteps=None, + sigmas=None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `FlowDPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs. + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. This determines the resolution of the diffusion process. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1`, `2`, or `3`. It is recommended to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. This affects the number of model outputs stored + and used in multistep updates. + prediction_type (`str`, defaults to "flow_prediction"): + Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts + the flow of the diffusion process. + shift (`float`, *optional*, defaults to 1.0): + A factor used to adjust the sigmas in the noise schedule. It modifies the step sizes during the sampling + process. + use_dynamic_shifting (`bool`, defaults to `False`): + Whether to apply dynamic shifting to the timesteps based on image resolution. If `True`, the shifting is + applied on the fly. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This method adjusts the predicted sample to prevent + saturation and improve photorealism. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. 
Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++"`. + algorithm_type (`str`, defaults to `dpmsolver++`): + Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The + `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) + paper, and the `dpmsolver++` type implements the algorithms in the + [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or + `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. + final_sigmas_type (`str`, *optional*, defaults to "zero"): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + lambda_min_clipped (`float`, defaults to `-inf`): + Clipping threshold for the minimum value of `lambda(t)` for numerical stability. 
This is critical for the + cosine (`squaredcos_cap_v2`) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output + contains the predicted Gaussian variance. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: Optional[float] = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + invert_sigmas: bool = False, + ): + if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. 
Choose from `dpmsolver++` or `sde-dpmsolver++` instead" + deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", + deprecation_message) + + # settings for DPM-Solver + if algorithm_type not in [ + "dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++" + ]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError( + f"{algorithm_type} is not implemented for {self.__class__}") + + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError( + f"{solver_type} is not implemented for {self.__class__}") + + if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++" + ] and final_sigmas_type == "zero": + raise ValueError( + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead." + ) + + # setable values + self.num_inference_steps = None + alphas = np.linspace(1, 1 / num_train_timesteps, + num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self._begin_index = None + + # self.sigmas = self.sigmas.to( + # "cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. 
+ """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps + def set_timesteps( + self, + num_inference_steps: Union[int, None] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[Union[float, None]] = None, + shift: Optional[Union[float, None]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Args: + num_inference_steps (`int`): + Total number of the spacing of the time steps. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
+ """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError( + " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" + ) + + if sigmas is None: + sigmas = np.linspace(self.sigma_max, self.sigma_min, + num_inference_steps + + 1).copy()[:-1] # pyright: ignore + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore + else: + if shift is None: + shift = self.config.shift + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / + self.alphas_cumprod[0])**0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last] + ]).astype(np.float32) # pyright: ignore + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to( + device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + self._step_index = None + self._begin_index = None + # self.sigmas = self.sigmas.to( + # "cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. 
We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float( + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile( + abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze( + 1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp( + sample, -s, s + ) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma) + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + Convert the model output to the corresponding type 
the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + "missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction`, or `flow_prediction` for the FlowDPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + + # DPM-Solver needs to solve an integral of the noise prediction model. 
+ elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the FlowDPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + + return epsilon + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update + def dpm_solver_first_order_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the first-order DPMSolver (equivalent to DDIM). + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. 
+ """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError( + " missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[ + self.step_index] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / + sigma_s) * sample - (alpha_t * + (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / + alpha_s) * sample - (sigma_t * + (torch.exp(h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ((sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ((alpha_t / alpha_s) * sample - 2.0 * + (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise) + return x_t # pyright: ignore + + # Copied from 
diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the second-order multistep DPMSolver. + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop( + "timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError( + " missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], # pyright: ignore + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], # pyright: ignore + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + 
m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ((sigma_t / sigma_s0) * sample - + (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 * + (alpha_t * (torch.exp(-h) - 1.0)) * D1) + elif self.config.solver_type == "heun": + x_t = ((sigma_t / sigma_s0) * sample - + (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ((alpha_t / alpha_s0) * sample - + (sigma_t * (torch.exp(h) - 1.0)) * D0 - 0.5 * + (sigma_t * (torch.exp(h) - 1.0)) * D1) + elif self.config.solver_type == "heun": + x_t = ((alpha_t / alpha_s0) * sample - + (sigma_t * (torch.exp(h) - 1.0)) * D0 - + (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + 0.5 * + (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) + elif self.config.solver_type == "heun": + x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / + (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ((alpha_t / alpha_s0) * sample - 2.0 * + (sigma_t * (torch.exp(h) - 1.0)) * D0 - + (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * 
torch.sqrt(torch.exp(2 * h) - 1.0) * noise) + elif self.config.solver_type == "heun": + x_t = ((alpha_t / alpha_s0) * sample - 2.0 * + (sigma_t * (torch.exp(h) - 1.0)) * D0 - 2.0 * + (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise) + return x_t # pyright: ignore + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update + def multistep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the third-order multistep DPMSolver. + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by diffusion process. + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + + timestep_list = args[0] if len(args) > 0 else kwargs.pop( + "timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError( + " missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], # pyright: ignore + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], # pyright: ignore + self.sigmas[self.step_index - 2], # pyright: ignore + ) 
+ + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[ + -2], model_output_list[-3] + + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ((sigma_t / sigma_s0) * sample - + (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - + (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ((alpha_t / alpha_s0) * sample - (sigma_t * + (torch.exp(h) - 1.0)) * D0 - + (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - + (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2) + return x_t # pyright: ignore + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. 
for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + # Modified from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + generator=None, + variance_noise: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep DPMSolver. + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.Tensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`LEdits++`]. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. 
+ """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Improve numerical stability for small number of steps + lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( + self.config.euler_at_final or + (self.config.lower_order_final and len(self.timesteps) < 15) or + self.config.final_sigmas_type == "zero") + lower_order_second = ((self.step_index == len(self.timesteps) - 2) and + self.config.lower_order_final and + len(self.timesteps) < 15) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++" + ] and variance_noise is None: + noise = randn_tensor( + model_output.shape, + generator=generator, + device=model_output.device, + dtype=torch.float32) + elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + noise = variance_noise.to( + device=model_output.device, + dtype=torch.float32) # pyright: ignore + else: + noise = None + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update( + model_output, sample=sample, noise=noise) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + prev_sample = self.multistep_dpm_solver_second_order_update( + self.model_outputs, sample=sample, noise=noise) + else: + prev_sample = self.multistep_dpm_solver_third_order_update( + self.model_outputs, sample=sample) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # Cast sample back to 
expected dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 # pyright: ignore + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input + def scale_model_input(self, sample: torch.Tensor, *args, + **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + Args: + sample (`torch.Tensor`): + The input sample. + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to( + device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point( + timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to( + original_samples.device, dtype=torch.float32) + timesteps = timesteps.to( + original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) + for t in timesteps + ] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before 
first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/megatron/bridge/models/wan/inference/utils/fm_solvers_unipc.py b/src/megatron/bridge/models/wan/inference/utils/fm_solvers_unipc.py new file mode 100644 index 0000000000..fb502f2eb2 --- /dev/null +++ b/src/megatron/bridge/models/wan/inference/utils/fm_solvers_unipc.py @@ -0,0 +1,802 @@ +# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py +# Convert unipc for flow matching +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import ( + KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput, +) +from diffusers.utils import deprecate, is_scipy_available + +if is_scipy_available(): + import scipy.stats + + +class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + solver_order (`int`, default `2`): + The UniPC order which can be any positive integer. 
The effective order of accuracy is `solver_order + 1` + due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for + unconditional sampling. + prediction_type (`str`, defaults to "flow_prediction"): + Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts + the flow of the diffusion process. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. + predict_x0 (`bool`, defaults to `True`): + Whether to use the updating algorithm on the predicted x0. + solver_type (`str`, default `bh2`): + Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2` + otherwise. + lower_order_final (`bool`, default `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + disable_corrector (`list`, default `[]`): + Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)` + and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is + usually disabled during the first few steps. + solver_p (`SchedulerMixin`, default `None`): + Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. 
If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: Optional[float] = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: List[int] = [], + solver_p: SchedulerMixin = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + ): + + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + self.register_to_config(solver_type="bh2") + else: + raise NotImplementedError( + f"{solver_type} is not implemented for {self.__class__}") + + self.predict_x0 = predict_x0 + # setable values + self.num_inference_steps = None + alphas = np.linspace(1, 1 / num_train_timesteps, + num_train_timesteps)[::-1].copy() + sigmas 
= 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector + self.solver_p = solver_p + self.last_sample = None + self._step_index = None + self._begin_index = None + + self.sigmas = self.sigmas.to( + "cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. 
+ """ + self._begin_index = begin_index + + # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps + def set_timesteps( + self, + num_inference_steps: Union[int, None] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[Union[float, None]] = None, + shift: Optional[Union[float, None]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Args: + num_inference_steps (`int`): + Total number of the spacing of the time steps. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError( + " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" + ) + + if sigmas is None: + sigmas = np.linspace(self.sigma_max, self.sigma_min, + num_inference_steps + + 1).copy()[:-1] # pyright: ignore + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore + else: + if shift is None: + shift = self.config.shift + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / + self.alphas_cumprod[0])**0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last] + ]).astype(np.float32) # pyright: ignore + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to( + device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * 
self.config.solver_order + self.lower_order_nums = 0 + self.last_sample = None + if self.solver_p: + self.solver_p.set_timesteps(self.num_inference_steps, device=device) + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to( + "cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." 
+ + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float( + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile( + abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze( + 1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp( + sample, -s, s + ) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma) + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + r""" + Convert the model output to the corresponding type the UniPC algorithm needs. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. 
+ sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + "missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.predict_x0: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + + return epsilon + + def multistep_uni_p_bh_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + order: int = None, # pyright: ignore + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniP (B(h) version). 
Alternatively, `self.solver_p` is used if is specified. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model at the current timestep. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + order (`int`): + The order of UniP at this timestep (corresponds to the *p* in UniPC-p). + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + prev_timestep = args[0] if len(args) > 0 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + " missing `sample` as a required keyward argument") + if order is None: + if len(args) > 2: + order = args[2] + else: + raise ValueError( + " missing `order` as a required keyward argument") + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + model_output_list = self.model_outputs + + s0 = self.timestep_list[-1] + m0 = model_output_list[-1] + x = sample + + if self.solver_p: + x_t = self.solver_p.step(model_output, s0, x).prev_sample + return x_t + + sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[ + self.step_index] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - i # pyright: ignore + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - 
lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) # pyright: ignore + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], + b[:-1]).to(device).to(x.dtype) + else: + D1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, + D1s) # pyright: ignore + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, + D1s) # pyright: ignore + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + x_t = x_t.to(x.dtype) + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.Tensor, + *args, + last_sample: torch.Tensor = None, + this_sample: torch.Tensor = None, + order: int = None, # pyright: ignore + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniC (B(h) version). + + Args: + this_model_output (`torch.Tensor`): + The model outputs at `x_t`. + this_timestep (`int`): + The current timestep `t`. + last_sample (`torch.Tensor`): + The generated sample before the last predictor `x_{t-1}`. 
+ this_sample (`torch.Tensor`): + The generated sample after the last predictor `x_{t}`. + order (`int`): + The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`. + + Returns: + `torch.Tensor`: + The corrected sample tensor at the current timestep. + """ + this_timestep = args[0] if len(args) > 0 else kwargs.pop( + "this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError( + " missing`last_sample` as a required keyward argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError( + " missing`this_sample` as a required keyward argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError( + " missing`order` as a required keyward argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + model_output_list = self.model_outputs + + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[ + self.step_index - 1] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = this_sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - (i + 1) # pyright: ignore + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) # pyright: ignore + + rks.append(1.0) + rks = torch.tensor(rks, 
device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) + else: + D1s = None + + # for order 1, we use a simplified version + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + x_t = x_t.to(x.dtype) + return x_t + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. 
for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step(self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + return_dict: bool = True, + generator=None) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep UniPC. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. 
+ + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + use_corrector = ( + self.step_index > 0 and + self.step_index - 1 not in self.disable_corrector and + self.last_sample is not None # pyright: ignore + ) + + model_output_convert = self.convert_model_output( + model_output, sample=sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep # pyright: ignore + + if self.config.lower_order_final: + this_order = min(self.config.solver_order, + len(self.timesteps) - + self.step_index) # pyright: ignore + else: + this_order = self.config.solver_order + + self.this_order = min(this_order, + self.lower_order_nums + 1) # warmup for multistep + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, # pass the original non-converted model output, in case solver-p is used + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 # pyright: ignore + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, + **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. 
+ + Args: + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to( + device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point( + timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to( + original_samples.device, dtype=torch.float32) + timesteps = timesteps.to( + original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) + for t in timesteps + ] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/megatron/bridge/models/wan/inference/utils/utils.py b/src/megatron/bridge/models/wan/inference/utils/utils.py new file mode 100644 
# --- patch metadata (preserved): new file
# --- src/megatron/bridge/models/wan/inference/utils/utils.py
# --- index 0000000000..d72599967f, @@ -0,0 +1,118 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
"""Small inference-side I/O helpers.

Provides best-effort caching of video/image tensors to disk and a CLI
boolean parser.  ``imageio``/``torch``/``torchvision`` are optional heavy
dependencies and are imported lazily inside the functions that need them,
so importing this module never requires them.
"""
import argparse
import binascii
import os
import os.path as osp

__all__ = ['cache_video', 'cache_image', 'str2bool']


def rand_name(length=8, suffix=''):
    """Return a random hex name of ``2 * length`` characters.

    Args:
        length (int): Number of random bytes; the hex encoding doubles it.
        suffix (str): Optional file extension; a leading ``.`` is added
            when missing.

    Returns:
        str: e.g. ``'a1b2c3d4.mp4'`` for ``length=4, suffix='mp4'``.
    """
    name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
    if suffix:
        if not suffix.startswith('.'):
            suffix = '.' + suffix
        name += suffix
    return name


def cache_video(tensor,
                save_file=None,
                fps=30,
                suffix='.mp4',
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """Encode a video tensor to an H.264 file, retrying on failure.

    Frames are taken from dim 2 of ``tensor`` (``unbind(2)``); each frame
    batch is tiled into a grid via ``torchvision.utils.make_grid``.
    Assumes layout ``(B, C, T, H, W)`` — TODO confirm against callers.

    Args:
        tensor: Video tensor; values clamped to ``value_range``.
        save_file (str | None): Target path; a random ``/tmp`` path with
            ``suffix`` is used when ``None``.
        fps (int): Output frame rate.
        suffix (str): Extension for the auto-generated path.
        nrow (int): Grid width passed to ``make_grid``.
        normalize (bool): Forwarded to ``make_grid``.
        value_range (tuple): Clamp range and ``make_grid`` value range.
        retry (int): Number of attempts before giving up.

    Returns:
        str | None: The written file path, or ``None`` when every attempt
        failed (the last error is printed).
    """
    # pick a /tmp cache path when the caller did not provide one
    cache_file = osp.join('/tmp', rand_name(
        suffix=suffix)) if save_file is None else save_file

    error = None
    for _ in range(retry):
        try:
            # Lazy imports live inside the retry block so a missing
            # optional dependency is reported like any other failure.
            import imageio
            import torch
            import torchvision

            # clamp, tile each frame into a grid, then convert to
            # time-major HWC uint8 for the video writer
            tensor = tensor.clamp(min(value_range), max(value_range))
            tensor = torch.stack([
                torchvision.utils.make_grid(
                    u, nrow=nrow, normalize=normalize,
                    value_range=value_range) for u in tensor.unbind(2)
            ], dim=1).permute(1, 2, 3, 0)
            tensor = (tensor * 255).type(torch.uint8).cpu()

            # write video; close the writer even if a frame append fails
            writer = imageio.get_writer(
                cache_file, fps=fps, codec='libx264', quality=8)
            try:
                for frame in tensor.numpy():
                    writer.append_data(frame)
            finally:
                writer.close()
            return cache_file
        except Exception as e:
            error = e
            continue
    print(f'cache_video failed, error: {error}', flush=True)
    return None


def cache_image(tensor,
                save_file,
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """Save an image tensor via ``torchvision.utils.save_image``.

    Args:
        tensor: Image tensor; values clamped to ``value_range``.
        save_file (str): Target path (required, unlike ``cache_video``).
        nrow (int): Grid width passed to ``save_image``.
        normalize (bool): Forwarded to ``save_image``.
        value_range (tuple): Clamp range and ``save_image`` value range.
        retry (int): Number of attempts before giving up.

    Returns:
        str | None: ``save_file`` on success, ``None`` when every attempt
        failed (the last error is printed — previously failures were
        silent via an implicit ``None`` return).
    """
    # NOTE(review): the computed `suffix` below is never applied to
    # `save_file` in the original code; kept for parity — TODO either use
    # it to rewrite the extension or drop it.
    suffix = osp.splitext(save_file)[1]
    if suffix.lower() not in [
            '.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp'
    ]:
        suffix = '.png'

    error = None
    for _ in range(retry):
        try:
            # lazy import: see cache_video
            import torchvision

            tensor = tensor.clamp(min(value_range), max(value_range))
            torchvision.utils.save_image(
                tensor,
                save_file,
                nrow=nrow,
                normalize=normalize,
                value_range=value_range)
            return save_file
        except Exception as e:
            error = e
            continue
    # Fix: report exhaustion instead of silently returning None,
    # consistent with cache_video.
    print(f'cache_image failed, error: {error}', flush=True)
    return None


def str2bool(v):
    """
    Convert a string to a boolean.

    Supported true values: 'yes', 'true', 't', 'y', '1'
    Supported false values: 'no', 'false', 'f', 'n', '0'

    Args:
        v (str): String to convert.

    Returns:
        bool: Converted boolean value.

    Raises:
        argparse.ArgumentTypeError: If the value cannot be converted to boolean.
    """
    if isinstance(v, bool):
        return v
    v_lower = v.lower()
    if v_lower in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v_lower in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected (True/False)')


# --- patch metadata (preserved): new file
# --- src/megatron/bridge/models/wan/modules/__init__.py
# --- index 0000000000..435f1eef0d, @@ -0,0 +1,13 @@
# Content preserved as comments: it belongs to a separate module and its
# relative imports cannot execute in this file.
#
#     from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
#     from .tokenizers import HuggingfaceTokenizer
#     from .vae import WanVAE
#
#     __all__ = [
#         'WanVAE',
#         'T5Model',
#         'T5Encoder',
#         'T5Decoder',
#         'T5EncoderModel',
#         'HuggingfaceTokenizer',
#     ]
#
# --- patch metadata (preserved): new file
# --- src/megatron/bridge/models/wan/modules/t5.py
# --- index 0000000000..c841b044a2, @@ -0,0 +1,513 @@  (header only here)
#
#     # Modified from transformers.models.t5.modeling_t5
#     # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import logging +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .tokenizers import HuggingfaceTokenizer + +__all__ = [ + 'T5Model', + 'T5Encoder', + 'T5Decoder', + 'T5EncoderModel', +] + + +def fp16_clamp(x): + if x.dtype == torch.float16 and torch.isinf(x).any(): + clamp = torch.finfo(x.dtype).max - 1000 + x = torch.clamp(x, min=-clamp, max=clamp) + return x + + +def init_weights(m): + if isinstance(m, T5LayerNorm): + nn.init.ones_(m.weight) + elif isinstance(m, T5Model): + nn.init.normal_(m.token_embedding.weight, std=1.0) + elif isinstance(m, T5FeedForward): + nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5) + nn.init.normal_(m.fc1.weight, std=m.dim**-0.5) + nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5) + elif isinstance(m, T5Attention): + nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5) + nn.init.normal_(m.k.weight, std=m.dim**-0.5) + nn.init.normal_(m.v.weight, std=m.dim**-0.5) + nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5) + elif isinstance(m, T5RelativeEmbedding): + nn.init.normal_( + m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5) + + +class GELU(nn.Module): + + def forward(self, x): + return 0.5 * x * (1.0 + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + + +class T5LayerNorm(nn.Module): + + def __init__(self, dim, eps=1e-6): + super(T5LayerNorm, self).__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + + self.eps) + if self.weight.dtype in [torch.float16, torch.bfloat16]: + x = x.type_as(self.weight) + return self.weight * x + + +class T5Attention(nn.Module): + + def __init__(self, dim, dim_attn, num_heads, dropout=0.1): + assert dim_attn % num_heads == 0 + super(T5Attention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.num_heads = num_heads + self.head_dim = dim_attn // 
num_heads + + # layers + self.q = nn.Linear(dim, dim_attn, bias=False) + self.k = nn.Linear(dim, dim_attn, bias=False) + self.v = nn.Linear(dim, dim_attn, bias=False) + self.o = nn.Linear(dim_attn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, context=None, mask=None, pos_bias=None): + """ + x: [B, L1, C]. + context: [B, L2, C] or None. + mask: [B, L2] or [B, L1, L2] or None. + """ + # check inputs + context = x if context is None else context + b, n, c = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).view(b, -1, n, c) + k = self.k(context).view(b, -1, n, c) + v = self.v(context).view(b, -1, n, c) + + # attention bias + attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) + if pos_bias is not None: + attn_bias += pos_bias + if mask is not None: + assert mask.ndim in [2, 3] + mask = mask.view(b, 1, 1, + -1) if mask.ndim == 2 else mask.unsqueeze(1) + attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min) + + # compute attention (T5 does not use scaling) + attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + x = torch.einsum('bnij,bjnc->binc', attn, v) + + # output + x = x.reshape(b, -1, n * c) + x = self.o(x) + x = self.dropout(x) + return x + + +class T5FeedForward(nn.Module): + + def __init__(self, dim, dim_ffn, dropout=0.1): + super(T5FeedForward, self).__init__() + self.dim = dim + self.dim_ffn = dim_ffn + + # layers + self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU()) + self.fc1 = nn.Linear(dim, dim_ffn, bias=False) + self.fc2 = nn.Linear(dim_ffn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = self.fc1(x) * self.gate(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class T5SelfAttention(nn.Module): + + def __init__(self, + dim, + dim_attn, + dim_ffn, + num_heads, + num_buckets, + shared_pos=True, + dropout=0.1): + 
super(T5SelfAttention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=True) + + def forward(self, x, mask=None, pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding( + x.size(1), x.size(1)) + x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.ffn(self.norm2(x))) + return x + + +class T5CrossAttention(nn.Module): + + def __init__(self, + dim, + dim_attn, + dim_ffn, + num_heads, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5CrossAttention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm3 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=False) + + def forward(self, + x, + mask=None, + encoder_states=None, + encoder_mask=None, + pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding( + x.size(1), x.size(1)) + x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.cross_attn( + self.norm2(x), context=encoder_states, mask=encoder_mask)) + x = fp16_clamp(x + self.ffn(self.norm3(x))) + return x + + +class T5RelativeEmbedding(nn.Module): + + 
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128): + super(T5RelativeEmbedding, self).__init__() + self.num_buckets = num_buckets + self.num_heads = num_heads + self.bidirectional = bidirectional + self.max_dist = max_dist + + # layers + self.embedding = nn.Embedding(num_buckets, num_heads) + + def forward(self, lq, lk): + device = self.embedding.weight.device + # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \ + # torch.arange(lq).unsqueeze(1).to(device) + rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \ + torch.arange(lq, device=device).unsqueeze(1) + rel_pos = self._relative_position_bucket(rel_pos) + rel_pos_embeds = self.embedding(rel_pos) + rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze( + 0) # [1, N, Lq, Lk] + return rel_pos_embeds.contiguous() + + def _relative_position_bucket(self, rel_pos): + # preprocess + if self.bidirectional: + num_buckets = self.num_buckets // 2 + rel_buckets = (rel_pos > 0).long() * num_buckets + rel_pos = torch.abs(rel_pos) + else: + num_buckets = self.num_buckets + rel_buckets = 0 + rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos)) + + # embeddings for small and large positions + max_exact = num_buckets // 2 + rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / + math.log(self.max_dist / max_exact) * + (num_buckets - max_exact)).long() + rel_pos_large = torch.min( + rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) + rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) + return rel_buckets + + +class T5Encoder(nn.Module): + + def __init__(self, + vocab, + dim, + dim_attn, + dim_ffn, + num_heads, + num_layers, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5Encoder, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = 
vocab if isinstance(vocab, nn.Embedding) \ + else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=True) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList([ + T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, + shared_pos, dropout) for _ in range(num_layers) + ]) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def forward(self, ids, mask=None): + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), + x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + +class T5Decoder(nn.Module): + + def __init__(self, + vocab, + dim, + dim_attn, + dim_ffn, + num_heads, + num_layers, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5Decoder, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \ + else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=False) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList([ + T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, + shared_pos, dropout) for _ in range(num_layers) + ]) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None): + b, s = ids.size() + + # causal mask + if mask is None: + mask = torch.tril(torch.ones(1, s, s).to(ids.device)) + elif mask.ndim == 2: + mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1)) + + # layers + x = self.token_embedding(ids) + x = self.dropout(x) + e = 
self.pos_embedding(x.size(1), + x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, encoder_states, encoder_mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + +class T5Model(nn.Module): + + def __init__(self, + vocab_size, + dim, + dim_attn, + dim_ffn, + num_heads, + encoder_layers, + decoder_layers, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5Model, self).__init__() + self.vocab_size = vocab_size + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.encoder_layers = encoder_layers + self.decoder_layers = decoder_layers + self.num_buckets = num_buckets + + # layers + self.token_embedding = nn.Embedding(vocab_size, dim) + self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn, + num_heads, encoder_layers, num_buckets, + shared_pos, dropout) + self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn, + num_heads, decoder_layers, num_buckets, + shared_pos, dropout) + self.head = nn.Linear(dim, vocab_size, bias=False) + + # initialize weights + self.apply(init_weights) + + def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask): + x = self.encoder(encoder_ids, encoder_mask) + x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask) + x = self.head(x) + return x + + +def _t5(name, + encoder_only=False, + decoder_only=False, + return_tokenizer=False, + tokenizer_kwargs={}, + dtype=torch.float32, + device='cpu', + **kwargs): + # sanity check + assert not (encoder_only and decoder_only) + + # params + if encoder_only: + model_cls = T5Encoder + kwargs['vocab'] = kwargs.pop('vocab_size') + kwargs['num_layers'] = kwargs.pop('encoder_layers') + _ = kwargs.pop('decoder_layers') + elif decoder_only: + model_cls = T5Decoder + kwargs['vocab'] = kwargs.pop('vocab_size') + kwargs['num_layers'] = kwargs.pop('decoder_layers') + _ = kwargs.pop('encoder_layers') + else: + model_cls = T5Model + + # init model + with 
torch.device(device): + model = model_cls(**kwargs) + + # set device + model = model.to(dtype=dtype, device=device) + + # init tokenizer + if return_tokenizer: + from .tokenizers import HuggingfaceTokenizer + tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs) + return model, tokenizer + else: + return model + + +def umt5_xxl(**kwargs): + cfg = dict( + vocab_size=256384, + dim=4096, + dim_attn=4096, + dim_ffn=10240, + num_heads=64, + encoder_layers=24, + decoder_layers=24, + num_buckets=32, + shared_pos=False, + dropout=0.1) + cfg.update(**kwargs) + return _t5('umt5-xxl', **cfg) + + +class T5EncoderModel: + + def __init__( + self, + text_len, + dtype=torch.bfloat16, + device=torch.cuda.current_device(), + checkpoint_path=None, + tokenizer_path=None, + shard_fn=None, + ): + self.text_len = text_len + self.dtype = dtype + self.device = device + self.checkpoint_path = checkpoint_path + self.tokenizer_path = tokenizer_path + + # init model + model = umt5_xxl( + encoder_only=True, + return_tokenizer=False, + dtype=dtype, + device=device).eval().requires_grad_(False) + logging.info(f'loading {checkpoint_path}') + model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')) + self.model = model + if shard_fn is not None: + self.model = shard_fn(self.model, sync_module_states=False) + else: + self.model.to(self.device) + # init tokenizer + self.tokenizer = HuggingfaceTokenizer( + name=tokenizer_path, seq_len=text_len, clean='whitespace') + + def __call__(self, texts, device): + ids, mask = self.tokenizer( + texts, return_mask=True, add_special_tokens=True) + ids = ids.to(device) + mask = mask.to(device) + seq_lens = mask.gt(0).sum(dim=1).long() + context = self.model(ids, mask) + return [u[:v] for u, v in zip(context, seq_lens)] diff --git a/src/megatron/bridge/models/wan/modules/tokenizers.py b/src/megatron/bridge/models/wan/modules/tokenizers.py new file mode 100644 index 0000000000..121e591c48 --- /dev/null +++ 
b/src/megatron/bridge/models/wan/modules/tokenizers.py @@ -0,0 +1,82 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import html +import string + +import ftfy +import regex as re +from transformers import AutoTokenizer + +__all__ = ['HuggingfaceTokenizer'] + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +def canonicalize(text, keep_punctuation_exact_string=None): + text = text.replace('_', ' ') + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join( + part.translate(str.maketrans('', '', string.punctuation)) + for part in text.split(keep_punctuation_exact_string)) + else: + text = text.translate(str.maketrans('', '', string.punctuation)) + text = text.lower() + text = re.sub(r'\s+', ' ', text) + return text.strip() + + +class HuggingfaceTokenizer: + + def __init__(self, name, seq_len=None, clean=None, **kwargs): + assert clean in (None, 'whitespace', 'lower', 'canonicalize') + self.name = name + self.seq_len = seq_len + self.clean = clean + + # init tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) + self.vocab_size = self.tokenizer.vocab_size + + def __call__(self, sequence, **kwargs): + return_mask = kwargs.pop('return_mask', False) + + # arguments + _kwargs = {'return_tensors': 'pt'} + if self.seq_len is not None: + _kwargs.update({ + 'padding': 'max_length', + 'truncation': True, + 'max_length': self.seq_len + }) + _kwargs.update(**kwargs) + + # tokenization + if isinstance(sequence, str): + sequence = [sequence] + if self.clean: + sequence = [self._clean(u) for u in sequence] + ids = self.tokenizer(sequence, **_kwargs) + + # output + if return_mask: + return ids.input_ids, ids.attention_mask + else: + return ids.input_ids + + def _clean(self, text): + if self.clean == 'whitespace': + text = 
whitespace_clean(basic_clean(text)) + elif self.clean == 'lower': + text = whitespace_clean(basic_clean(text)).lower() + elif self.clean == 'canonicalize': + text = canonicalize(basic_clean(text)) + return text diff --git a/src/megatron/bridge/models/wan/modules/vae.py b/src/megatron/bridge/models/wan/modules/vae.py new file mode 100644 index 0000000000..5c6da57235 --- /dev/null +++ b/src/megatron/bridge/models/wan/modules/vae.py @@ -0,0 +1,663 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import logging + +import torch +import torch.cuda.amp as amp +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +__all__ = [ + 'WanVAE', +] + +CACHE_T = 2 + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolusion. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = (self.padding[2], self.padding[2], self.padding[1], + self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + return super().forward(x) + + +class RMS_norm(nn.Module): + + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. + + def forward(self, x): + return F.normalize( + x, dim=(1 if self.channel_first else + -1)) * self.scale * self.gamma + self.bias + + +class Upsample(nn.Upsample): + + def forward(self, x): + """ + Fix bfloat16 support for nearest neighbor interpolation. 
+ """ + return super().forward(x.float()).type_as(x) + + +class Resample(nn.Module): + + def __init__(self, dim, mode): + assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', + 'downsample3d') + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == 'upsample2d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + elif mode == 'upsample3d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + self.time_conv = CausalConv3d( + dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == 'downsample2d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == 'downsample3d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == 'upsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = 'Rep' + feat_idx[0] += 1 + else: + + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] != 'Rep': + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] == 'Rep': + cache_x = torch.cat([ + torch.zeros_like(cache_x).to(cache_x.device), + cache_x + ], + dim=2) + if feat_cache[idx] == 'Rep': + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) 
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), + 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.resample(x) + x = rearrange(x, '(b t) c h w -> b c t h w', t=t) + + if self.mode == 'downsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + + cache_x = x[:, :, -1:, :, :].clone() + # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep': + # # cache last frame of last two chunk + # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.time_conv( + torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + def init_weight(self, conv): + conv_weight = conv.weight + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + one_matrix = torch.eye(c1, c2) + init_matrix = one_matrix + nn.init.zeros_(conv_weight) + #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5 + conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5 + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + def init_weight2(self, conv): + conv_weight = conv.weight.data + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + init_matrix = torch.eye(c1 // 2, c2) + #init_matrix = repeat(init_matrix, 'o ... 
-> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2) + conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1)) + self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ + if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.shortcut(x) + for layer in self.residual: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +class AttentionBlock(nn.Module): + """ + Causal self-attention with a single head. 
+ """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, t, h, w = x.size() + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.norm(x) + # compute query, key, value + q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, + -1).permute(0, 1, 3, + 2).contiguous().chunk( + 3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention( + q, + k, + v, + ) + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = rearrange(x, '(b t) c h w-> b c t h w', t=t) + return x + identity + + +class Encoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + downsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = 'downsample3d' if temperal_downsample[ + i] else 'downsample2d' + downsamples.append(Resample(out_dim, mode=mode)) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, 
dropout), AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout)) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +class Decoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 
2**(len(dim_mult) - 2) + + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout)) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i == 1 or i == 2 or i == 3: + in_dim = in_dim // 2 + for _ in range(num_res_blocks + 1): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + upsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # upsample block + if i != len(dim_mult) - 1: + mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d' + upsamples.append(Resample(out_dim, mode=mode)) + scale *= 2.0 + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, 3, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and 
feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, CausalConv3d): + count += 1 + return count + + +class WanVAE_(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, + attn_scales, self.temperal_downsample, dropout) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, + attn_scales, self.temperal_upsample, dropout) + + def forward(self, x): + mu, log_var = self.encode(x) + z = self.reparameterize(mu, log_var) + x_recon = self.decode(z) + return x_recon, mu, log_var + + def encode(self, x, scale): + self.clear_cache() + ## cache + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + ## 对encode输入的x,按时间拆分为1、4、4、4.... 
+ for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder( + x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] + self.clear_cache() + return mu + + def decode(self, z, scale): + self.clear_cache() + # z: [b,c,t,h,w] + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + else: + out_ = self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + self.clear_cache() + return out + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False): + mu, log_var = self.encode(imgs) + if deterministic: + return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + return mu + std * torch.randn_like(std) + + def clear_cache(self): + self._conv_num = count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + #cache encode + self._enc_conv_num = count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + +def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs): + """ + Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and 
XL. + """ + # params + cfg = dict( + dim=96, + z_dim=z_dim, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0) + cfg.update(**kwargs) + + # init model + with torch.device('meta'): + model = WanVAE_(**cfg) + + # load checkpoint + logging.info(f'loading {pretrained_path}') + model.load_state_dict( + torch.load(pretrained_path, map_location=device), assign=True) + + return model + + +class WanVAE: + + def __init__(self, + z_dim=16, + vae_pth='cache/vae_step_411000.pth', + dtype=torch.float, + device="cuda"): + self.dtype = dtype + self.device = device + + mean = [ + -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, + 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921 + ] + std = [ + 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, + 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160 + ] + self.mean = torch.tensor(mean, dtype=dtype, device=device) + self.std = torch.tensor(std, dtype=dtype, device=device) + self.scale = [self.mean, 1.0 / self.std] + + # init model + self.model = _video_vae( + pretrained_path=vae_pth, + z_dim=z_dim, + ).eval().requires_grad_(False).to(device) + + def encode(self, videos): + """ + videos: A list of videos each with shape [C, T, H, W]. 
+ """ + with amp.autocast(dtype=self.dtype): + return [ + self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) + for u in videos + ] + + def decode(self, zs): + with amp.autocast(dtype=self.dtype): + return [ + self.model.decode(u.unsqueeze(0), + self.scale).float().clamp_(-1, 1).squeeze(0) + for u in zs + ] diff --git a/src/megatron/bridge/models/wan/rope_utils.py b/src/megatron/bridge/models/wan/rope_utils.py new file mode 100644 index 0000000000..6e25fdb24b --- /dev/null +++ b/src/megatron/bridge/models/wan/rope_utils.py @@ -0,0 +1,61 @@ +import torch +from torch.cuda import amp + +class Wan3DRopeEmbeddings(torch.nn.Module): + """ + Wan 3D RoPE embeddings implementation. + Implements Wan's 3D RoPE embeddings for Mcore Attention based on Wan's implementation at https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/model.py. + """ + + def __init__(self, dim_head, max_position_len): + super().__init__() + self.freqs = torch.cat([ + self.rope_params(max_position_len, dim_head - 4 * (dim_head // 6)), + self.rope_params(max_position_len, 2 * (dim_head // 6)), + self.rope_params(max_position_len, 2 * (dim_head // 6)) + ], dim=1) + + def rope_params(self, max_position_len, dim_head, theta=10000): + assert dim_head % 2 == 0 + freqs = torch.outer( + torch.arange(max_position_len), + 1.0 / torch.pow(theta, + torch.arange(0, dim_head, 2).to(torch.float64).div(dim_head))) + return freqs + + def forward(self, n_head, dim_head, max_seq_len, grid_sizes, device): + self.freqs = self.freqs.to(device) # ??? do we need to put this here, or the when we move WanModel to device, it also move freqs to device? 
+ + n, c = n_head, dim_head // 2 + + # split freqs + freqs = self.freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + freqs_real = [] + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + freqs_real_i = torch.cat([ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(seq_len, 1, 1, -1) # <-- add 1,1 for batch/head broadcasting + + # Double dimension from c -> 2c with rotating angles as (x0, x0, x1, x1, ...), for interleaving RoPE + freqs_real_i = freqs_real_i.unsqueeze(-1).expand(-1, -1, -1, -1, 2).reshape(seq_len, 1, 1, dim_head) + + # Pad freqs_real_i to (max_seq_len, 1, 1, dim_head) with 0s + if freqs_real_i.shape[0] < max_seq_len: + pad_shape = (max_seq_len - freqs_real_i.shape[0], 1, 1, dim_head) + freqs_real_i = torch.cat( + [freqs_real_i, torch.zeros(pad_shape, dtype=freqs_real_i.dtype, device=freqs_real_i.device)] + ) + freqs_real.append(freqs_real_i) + + # Each freqs_real[i] is (max_seq_len, 1, 1, dim_head) + # We concatenate them along dim=1 to get (max_seq_len, batch_size, 1, dim_head) + freqs_real = torch.cat(freqs_real, dim=1) + + # TODO: if run context/sequence related parallel, then we need to scatter + # the freqs_real to the context parallel region, using specific method "get_pos_emb_on_this_cp_rank" + + return freqs_real \ No newline at end of file diff --git a/src/megatron/bridge/models/wan/wan_bridge.py b/src/megatron/bridge/models/wan/wan_bridge.py new file mode 100644 index 0000000000..80d7eafafe --- /dev/null +++ b/src/megatron/bridge/models/wan/wan_bridge.py @@ -0,0 +1,225 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import torch
from megatron.bridge.models.wan.wan_model import WanModel
from diffusers import WanTransformer3DModel

from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry
from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge
from megatron.bridge.models.conversion.param_mapping import (
    AutoMapping,
    GatedMLPMapping,
    QKVMapping,
    KVMapping,
    ReplicatedMapping,
)
from megatron.bridge.models.hf_pretrained.wan import PreTrainedWAN
from megatron.bridge.models.wan.wan_provider import WanModelProvider
from megatron.core.transformer.utils import openai_gelu
from megatron.bridge.models.conversion.utils import get_module_and_param_from_name


@MegatronModelBridge.register_bridge(source=WanTransformer3DModel, target=WanModel)
class WanBridge(MegatronModelBridge):
    """
    Megatron Bridge for the WAN model.

    As a user you would not use this bridge directly, but through `AutoBridge`.

    Example:
        >>> from megatron.bridge import AutoBridge
        >>> bridge = AutoBridge.from_hf_pretrained("WAN-3D-1.3B-v1")
        >>> provider = bridge.to_megatron_provider()
    """

    def provider_bridge(self, hf_pretrained: PreTrainedWAN) -> WanModelProvider:
        """Translate the Hugging Face WAN config into a ``WanModelProvider``.

        Args:
            hf_pretrained: Pretrained HF wrapper whose ``config`` carries the
                ``WanTransformer3DModel`` hyperparameters.

        Returns:
            WanModelProvider: Provider configured to build an equivalent
            Megatron ``WanModel``.
        """
        hf_config = hf_pretrained.config

        provider = WanModelProvider(
            num_layers=hf_config.num_layers,
            hidden_size=hf_config.num_attention_heads * hf_config.attention_head_dim,
            kv_channels=hf_config.attention_head_dim,
            num_query_groups=hf_config.num_attention_heads,
            crossattn_emb_size=hf_config.num_attention_heads * hf_config.attention_head_dim,
            ffn_hidden_size=hf_config.ffn_dim,
            num_attention_heads=hf_config.num_attention_heads,
            activation_func=openai_gelu,
            add_qkv_bias=True,
            in_channels=hf_config.in_channels,
            out_channels=hf_config.out_channels,
            text_dim=hf_config.text_dim,
            patch_spatial=hf_config.patch_size[1],
            patch_temporal=hf_config.patch_size[0],
            # NOTE(review): patch_size duplicates patch_spatial/patch_temporal —
            # confirm which one downstream consumers read before removing either.
            patch_size=hf_config.patch_size,
            rotary_interleaved=True,
            layernorm_epsilon=hf_config.eps,
            hidden_dropout=0,
            attention_dropout=0,
            use_cpu_initialization=True,
            freq_dim=hf_config.freq_dim,
            qk_layernorm_per_head=False,
            bf16=False,
            params_dtype=torch.float32,
        )

        return provider

    def mapping_registry(self) -> MegatronMappingRegistry:
        """Return MegatronMappingRegistry containing parameter mappings from HF to Megatron format.

        Returns:
            MegatronMappingRegistry: Registry of parameter mappings
        """
        # Dictionary maps HF parameter names -> Megatron parameter names.
        # Supports wildcard (*) patterns for layer-specific parameters.
        param_mappings = {
            "scale_shift_table": "head.modulation",
            "patch_embedding.weight": "patch_embedding.weight",
            "patch_embedding.bias": "patch_embedding.bias",
            "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
            "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
            "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
            "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
            "condition_embedder.time_proj.weight": "time_projection.1.weight",
            "condition_embedder.time_proj.bias": "time_projection.1.bias",
            "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
            "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
            "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
            "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
            "blocks.*.scale_shift_table": "decoder.layers.*.adaLN.modulation",
            "blocks.*.attn1.to_out.0.weight": "decoder.layers.*.full_self_attention.linear_proj.weight",
            "blocks.*.attn1.to_out.0.bias": "decoder.layers.*.full_self_attention.linear_proj.bias",
            "blocks.*.attn1.norm_q.weight": "decoder.layers.*.full_self_attention.q_layernorm.weight",
            "blocks.*.attn1.norm_k.weight": "decoder.layers.*.full_self_attention.k_layernorm.weight",
            "blocks.*.attn2.to_q.weight": "decoder.layers.*.cross_attention.linear_q.weight",
            "blocks.*.attn2.to_q.bias": "decoder.layers.*.cross_attention.linear_q.bias",
            "blocks.*.attn2.to_out.0.weight": "decoder.layers.*.cross_attention.linear_proj.weight",
            "blocks.*.attn2.to_out.0.bias": "decoder.layers.*.cross_attention.linear_proj.bias",
            "blocks.*.attn2.norm_q.weight": "decoder.layers.*.cross_attention.q_layernorm.weight",
            "blocks.*.attn2.norm_k.weight": "decoder.layers.*.cross_attention.k_layernorm.weight",
            "blocks.*.norm2.weight": "decoder.layers.*.norm3.weight",
            "blocks.*.norm2.bias": "decoder.layers.*.norm3.bias",
            "blocks.*.ffn.net.0.proj.weight": "decoder.layers.*.mlp.linear_fc1.weight",
            "blocks.*.ffn.net.0.proj.bias": "decoder.layers.*.mlp.linear_fc1.bias",
            "blocks.*.ffn.net.2.weight": "decoder.layers.*.mlp.linear_fc2.weight",
            "blocks.*.ffn.net.2.bias": "decoder.layers.*.mlp.linear_fc2.bias",
            "proj_out.weight": "head.head.weight",
            "proj_out.bias": "head.head.bias",
        }

        # Custom WAN mapping to safely handle replicated params whose owning module
        # does not expose a top-level `.weight` (e.g., Head.modulation).
        class _ReplicatedByParamNameMapping(ReplicatedMapping):
            def hf_to_megatron(self, hf_weights, megatron_module):
                # Resolve the exact target parameter by (normalized) name instead
                # of assuming the module owns a `.weight` attribute.
                normalized_param = self._normalize_expert_param_name(self.megatron_param)
                _, target_param = get_module_and_param_from_name(megatron_module, normalized_param)

                target_device = target_param.device
                target_dtype = target_param.dtype

                hf_weights = hf_weights.to(device=target_device, dtype=target_dtype)
                if self.tp_size == 1:
                    return hf_weights

                # Broadcast expects the tensor on the current CUDA device.
                if target_device.type == "cuda" and torch.cuda.is_available():
                    if target_device.index != torch.cuda.current_device():
                        hf_weights = hf_weights.to(torch.cuda.current_device())

                # Non-source TP ranks only need a correctly-shaped buffer;
                # its contents are overwritten by the broadcast below.
                if self.tp_rank > 0:
                    hf_weights = torch.empty_like(hf_weights)

                return self.broadcast_tensor_to_tp_ranks(hf_weights, src_rank=0)

        mapping_list = []
        # Convert each dictionary entry to
AutoMapping(hf_param, megatron_param) + for hf_param, megatron_param in param_mappings.items(): + if hf_param in {"scale_shift_table", "blocks.*.scale_shift_table", "proj_out.weight", "proj_out.bias"}: + # Use WAN-specific replicated mapping that resolves the exact param + mapping_list.append(_ReplicatedByParamNameMapping(hf_param=hf_param, megatron_param=megatron_param)) + else: + mapping_list.append(AutoMapping(hf_param=hf_param, megatron_param=megatron_param)) + + # Adding custom module types for AutoMapping + AutoMapping.register_module_type("Linear", "replicated") + AutoMapping.register_module_type("Conv3d", "replicated") + AutoMapping.register_module_type("WanAdaLN", "replicated") + AutoMapping.register_module_type("Head", "replicated") + + # Add special mappings that require parameter concatenation/transformation + mapping_list.extend( + [ + # QKV: Combine separate Q, K, V matrices into single QKV matrix + QKVMapping( + q="blocks.*.attn1.to_q.weight", + k="blocks.*.attn1.to_k.weight", + v="blocks.*.attn1.to_v.weight", + megatron_param="decoder.layers.*.full_self_attention.linear_qkv.weight", + ), + # QKV bias: Combine separate Q, K, V bias into single QKV bias + QKVMapping( + q="blocks.*.attn1.to_q.bias", + k="blocks.*.attn1.to_k.bias", + v="blocks.*.attn1.to_v.bias", + megatron_param="decoder.layers.*.full_self_attention.linear_qkv.bias", + ), + # K, V: Combine separate K, V matrices into single KV matrix + KVMapping( + k="blocks.*.attn2.to_k.weight", + v="blocks.*.attn2.to_v.weight", + megatron_param="decoder.layers.*.cross_attention.linear_kv.weight", + ), + # K, V bias: Combine separate K, V bias into single KV bias + KVMapping( + k="blocks.*.attn2.to_k.bias", + v="blocks.*.attn2.to_v.bias", + megatron_param="decoder.layers.*.cross_attention.linear_kv.bias", + ), + ] + ) + + return MegatronMappingRegistry(*mapping_list) \ No newline at end of file diff --git a/src/megatron/bridge/models/wan/wan_layer_spec.py 
b/src/megatron/bridge/models/wan/wan_layer_spec.py new file mode 100644 index 0000000000..3b014140cf --- /dev/null +++ b/src/megatron/bridge/models/wan/wan_layer_spec.py @@ -0,0 +1,674 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +import copy +from dataclasses import dataclass +from typing import Union, Optional + +import torch +import torch.cuda.amp as amp +import torch.nn as nn +from megatron.core import parallel_state, tensor_parallel +from megatron.core.transformer.attention import ( + CrossAttention, + CrossAttentionSubmodules, + SelfAttention, + SelfAttentionSubmodules, +) +from megatron.core.transformer.custom_layers.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TERowParallelLinear, +) +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules +from megatron.core.utils import make_viewless_tensor +from megatron.core.process_groups_config import ProcessGroupCollection +from 
megatron.core.extensions.transformer_engine import TENorm + +try: + import transformer_engine # pylint: disable=unused-import + + HAVE_TE = True + from megatron.core.extensions.transformer_engine import SplitAlongDim + +except ImportError: + HAVE_TE = False + SplitAlongDim = None + + +class WanLayerNorm(nn.LayerNorm): + + def __init__(self, dim, eps=1e-6, elementwise_affine=False): + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return super().forward(x.float()).type_as(x) + + +@dataclass +class WanSelfAttentionSubmodules: + """ + Configuration class for specifying the submodules of a self-attention. + """ + + linear_qkv: Union[ModuleSpec, type] = None + core_attention: Union[ModuleSpec, type] = None + linear_proj: Union[ModuleSpec, type] = None + layernorm_across_head: bool = False + q_layernorm: Union[ModuleSpec, type] = None + k_layernorm: Union[ModuleSpec, type] = None + + +@dataclass +class WanCrossAttentionSubmodules: + """ + Configuration class for specifying the submodules of a cross-attention. 
+ """ + linear_q: Union[ModuleSpec, type] = None + linear_kv: Union[ModuleSpec, type] = None + core_attention: Union[ModuleSpec, type] = None + linear_proj: Union[ModuleSpec, type] = None + layernorm_across_head: bool = False + q_layernorm: Union[ModuleSpec, type] = None + k_layernorm: Union[ModuleSpec, type] = None + + +class WanSelfAttention(SelfAttention): + def __init__( + self, + config: TransformerConfig, + submodules: WanSelfAttentionSubmodules, + layer_number: int, + attn_mask_type: AttnMaskType, + cp_comm_type: str = None, + pg_collection: Optional[ProcessGroupCollection] = None, + ): + super().__init__( + config, + submodules, + layer_number, + attn_mask_type, + cp_comm_type, + pg_collection, + ) + + self.layernorm_across_head = submodules.layernorm_across_head + + # override q_layernorm + if submodules.q_layernorm is not None: + if self.layernorm_across_head: + q_layernorm_size = self.query_projection_size + else: + q_layernorm_size = self.hidden_size_per_attention_head + import transformer_engine as te + norm_config = copy.deepcopy(self.config) + norm_config.normalization = "RMSNorm" + self.q_layernorm = build_module( + submodules.q_layernorm, + eps=1e-6, + hidden_size=q_layernorm_size, + config=norm_config, + ) + else: + self.q_layernorm = None + + # override k_layernorm + if submodules.k_layernorm is not None: + if self.layernorm_across_head: + k_layernorm_size = self.kv_projection_size + else: + k_layernorm_size = self.hidden_size_per_attention_head + import transformer_engine as te + norm_config = copy.deepcopy(self.config) + norm_config.normalization = "RMSNorm" + self.k_layernorm = build_module( + submodules.k_layernorm, + eps=1e-6, + hidden_size=k_layernorm_size, + config=norm_config, + ) + else: + self.k_layernorm = None + + def get_query_key_value_tensors(self, hidden_states, key_value_states=None): + """ + Derives `query`, `key` and `value` tensors from `hidden_states`. 
+ """ + # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)] + mixed_qkv, _ = self.linear_qkv(hidden_states) + + # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn] + new_tensor_shape = mixed_qkv.size()[:-1] + ( + self.num_query_groups_per_partition, + ( + (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2) + * self.hidden_size_per_attention_head + ), + ) + mixed_qkv = mixed_qkv.view(*new_tensor_shape) + + split_arg_list = [ + ( + self.num_attention_heads_per_partition + // self.num_query_groups_per_partition + * self.hidden_size_per_attention_head + ), + self.hidden_size_per_attention_head, + self.hidden_size_per_attention_head, + ] + + if SplitAlongDim is not None: + + # [sq, b, ng, (np/ng + 2) * hn] + # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] + (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list) + else: + + # [sq, b, ng, (np/ng + 2) * hn] + # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] + (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3) + + # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] + query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) + + # gather query and key heads across TP ranks if self.layernorm_across_head is True + if self.layernorm_across_head and parallel_state.get_tensor_model_parallel_world_size() > 1: + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = tensor_parallel.gather_from_tensor_model_parallel_region(query) + key = tensor_parallel.gather_from_tensor_model_parallel_region(key) + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + + if self.q_layernorm is not None: + if self.layernorm_across_head: + q_flat = query.reshape(query.size(0), query.size(1), -1).contiguous() # [sq, b, np*hn] + q_flat = self.q_layernorm(q_flat.float()) # Wan RMSNorm cast input to float32 + query = q_flat.view(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) # 
[sq, b, np, hn] + else: + query = self.q_layernorm(query.contiguous()) + + if self.k_layernorm is not None: + if self.layernorm_across_head: + k_flat = key.reshape(key.size(0), key.size(1), -1).contiguous() + k_flat = self.k_layernorm(k_flat.float()) # Wan RMSNorm cast input to float32 + key = k_flat.view(key.size(0), key.size(1), -1, self.hidden_size_per_attention_head) + else: + key = self.k_layernorm(key.contiguous()) + + # scatter query and key heads across TP ranks if self.layernorm_across_head is True + if self.layernorm_across_head and parallel_state.get_tensor_model_parallel_world_size() > 1: + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = tensor_parallel.scatter_to_tensor_model_parallel_region(query) + key = tensor_parallel.scatter_to_tensor_model_parallel_region(key) + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = query.contiguous() # important becuase TE attention expects contiguous tensors + key = key.contiguous() # important becuase TE attention expects contiguous tensors + + if self.config.test_mode: + self.run_realtime_tests() + + return query, key, value + + +class WanCrossAttention(CrossAttention): + def __init__( + self, + config: TransformerConfig, + submodules: WanCrossAttentionSubmodules, + layer_number: int, + attn_mask_type: AttnMaskType, + cp_comm_type: str = None, + pg_collection: Optional[ProcessGroupCollection] = None, + ): + super().__init__( + config, + submodules, + layer_number, + attn_mask_type, + cp_comm_type, + pg_collection, + ) + + self.layernorm_across_head = submodules.layernorm_across_head + + # override q_layernorm + if submodules.q_layernorm is not None: + if self.layernorm_across_head: + q_layernorm_size = self.query_projection_size + else: + q_layernorm_size = self.hidden_size_per_attention_head + import transformer_engine as te + norm_config = copy.deepcopy(self.config) + norm_config.normalization = "RMSNorm" + self.q_layernorm = build_module( + submodules.q_layernorm, + 
eps=1e-6, + hidden_size=q_layernorm_size, + config=norm_config, + ) + else: + self.q_layernorm = None + + # override k_layernorm + if submodules.k_layernorm is not None: + if self.layernorm_across_head: + k_layernorm_size = self.kv_projection_size + else: + k_layernorm_size = self.hidden_size_per_attention_head + import transformer_engine as te + norm_config = copy.deepcopy(self.config) + norm_config.normalization = "RMSNorm" + self.k_layernorm = build_module( + submodules.k_layernorm, + eps=1e-6, + hidden_size=k_layernorm_size, + config=norm_config, + ) + else: + self.k_layernorm = None + + def get_query_key_value_tensors(self, hidden_states, key_value_states): + """ + Derives `query` tensor from `hidden_states`, and `key`/`value` tensors + from `key_value_states`. + """ + # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] + mixed_kv, _ = self.linear_kv(key_value_states) + + # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] + new_tensor_shape = mixed_kv.size()[:-1] + ( + self.num_attention_heads_per_partition, + 2 * self.hidden_size_per_attention_head, + ) + mixed_kv = mixed_kv.view(*new_tensor_shape) + + # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] + (key, value) = tensor_parallel.split_tensor_along_last_dim(mixed_kv, 2) + + # Attention head [sq, b, h] --> [sq, b, hp] + query, _ = self.linear_q(hidden_states) + + # [sq, b, hp] --> [sq, b, np, hn] + new_tensor_shape = query.size()[:-1] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + query = query.view(*new_tensor_shape) + + # gather query and key heads across TP ranks if self.layernorm_across_head is True + if self.layernorm_across_head and parallel_state.get_tensor_model_parallel_world_size() > 1: + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = tensor_parallel.gather_from_tensor_model_parallel_region(query) + key = tensor_parallel.gather_from_tensor_model_parallel_region(key) + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + + 
if self.q_layernorm is not None: + if self.layernorm_across_head: + q_flat = query.reshape(query.size(0), query.size(1), -1).contiguous() # [sq, b, np*hn] + q_flat = self.q_layernorm(q_flat.float()) # Wan RMSNorm cast input to float32 + query = q_flat.view(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) # [sq, b, np, hn] + else: + query = self.q_layernorm(query.contiguous()) + + if self.k_layernorm is not None: + if self.layernorm_across_head: + k_flat = key.reshape(key.size(0), key.size(1), -1).contiguous() + k_flat = self.k_layernorm(k_flat.float()) # Wan RMSNorm cast input to float32 + key = k_flat.view(key.size(0), key.size(1), -1, self.hidden_size_per_attention_head) + else: + key = self.k_layernorm(key.contiguous()) + + # scatter query and key heads across TP ranks if self.layernorm_across_head is True + if self.layernorm_across_head and parallel_state.get_tensor_model_parallel_world_size() > 1: + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = tensor_parallel.scatter_to_tensor_model_parallel_region(query) + key = tensor_parallel.scatter_to_tensor_model_parallel_region(key) + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = query.contiguous() # important becuase TE attention expects contiguous tensors + key = key.contiguous() # important becuase TE attention expects contiguous tensors + + return query, key, value + + +@dataclass +class WanWithAdaLNSubmodules(TransformerLayerSubmodules): + temporal_self_attention: Union[ModuleSpec, type] = IdentityOp + full_self_attention: Union[ModuleSpec, type] = IdentityOp + norm1: Union[ModuleSpec, type] = None + norm3: Union[ModuleSpec, type] = None + norm2: Union[ModuleSpec, type] = None + + +class WanAdaLN(MegatronModule): + """ + Adaptive Layer Normalization Module for DiT. 
+ """ + + def __init__( + self, config: TransformerConfig + ): + super().__init__(config) + # modulation + self.modulation = nn.Parameter(torch.randn(1, 6, config.hidden_size) / config.hidden_size**0.5) + + setattr(self.modulation, "sequence_parallel", config.sequence_parallel) + + def forward(self, timestep_emb): + assert timestep_emb.dtype == torch.float32 + with amp.autocast(dtype=torch.float32): + e = (self.modulation + timestep_emb).chunk(6, dim=1) + assert e[0].dtype == torch.float32 + return e + + # @jit_fuser + def modulate(self, x, shift, scale): + return x * (1 + scale) + shift + + # @jit_fuser + def scale_add(self, residual, x, gate): + return residual + gate * x + + +class WanLayerWithAdaLN(TransformerLayer): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + + DiT with Adapative Layer Normalization. + """ + + def __init__( + self, + config: TransformerConfig, + submodules: TransformerLayerSubmodules, + layer_number: int = 1, + hidden_dropout: float = None, + pg_collection: Optional[ProcessGroupCollection] = None, + vp_stage: Optional[int] = None, + ): + super().__init__( + config=config, submodules=submodules, layer_number=layer_number, hidden_dropout=hidden_dropout + ) + + # # TODO: Override Cross Attention to disable TP Comm overlap as well. ??? + # # Not disabling will attempt re-use of buffer size same as Q and lead to incorrect tensor shapes. 
+ # cp_override_config = copy.deepcopy(config) + # cp_override_config.tp_comm_overlap = False + # self.cross_attention = build_module( + # submodules.cross_attention, + # config=cp_override_config, + # layer_number=layer_number, + # ) + + self.full_self_attention = build_module( + submodules.full_self_attention, + config=self.config, + layer_number=layer_number, + ) + + self.adaLN = WanAdaLN(config=self.config) + self.norm1 = build_module( + submodules.norm1, + dim=config.hidden_size, + eps=1e-6, + elementwise_affine=False + ) + self.norm3 = build_module( + submodules.norm3, + dim=config.hidden_size, + eps=1e-6, + elementwise_affine=True, + ) + self.norm2 = build_module( + submodules.norm2, + dim=config.hidden_size, + eps=1e-6, + elementwise_affine=False, + ) + + + def forward( + self, + hidden_states, + attention_mask=None, + context=None, + context_mask=None, + rotary_pos_emb=None, + rotary_pos_cos=None, + rotary_pos_sin=None, + attention_bias=None, + inference_params=None, + packed_seq_params=None, + sequence_len_offset=None, + inference_context=None, + ): + # the timestep embedding is stored in attention_mask argument + timestep_emb = attention_mask + rope_emb = rotary_pos_emb + + # DEBUGGING + run_debug = False + + # DEBUGGING + if run_debug and torch.distributed.get_rank()==0: + print("[DEBUG][WanLayerWithAdaLN] ================================") + print("[DEBUG][WanLayerWithAdaLN][forward_input] hidden_states.shape - hidden_states.dtype - hidden_states.mean() - hidden_states.std() - hidden_states.norm():", hidden_states.shape, hidden_states.dtype, hidden_states.mean(), hidden_states.std(), hidden_states.norm()) + print("[DEBUG][WanLayerWithAdaLN][forward_input] timestep_emb.shape - timestep_emb.dtype - timestep_emb.mean() - timestep_emb.std() - timestep_emb.norm():", timestep_emb.shape, timestep_emb.dtype, timestep_emb.mean(), timestep_emb.std(), timestep_emb.norm()) + print("[DEBUG][WanLayerWithAdaLN][forward_input] context.shape - context.dtype - 
context.mean() - context.std() - context.norm():", context.shape, context.dtype, context.mean(), context.std(), context.norm()) + if context_mask is not None: + print("[DEBUG][WanLayerWithAdaLN][forward_input] context_mask.shape - context_mask.dtype - context_mask.mean() - context_mask.std() - context_mask.norm():", context_mask.shape, context_mask.dtype, context_mask.mean(), context_mask.std(), context_mask.norm()) + if rotary_pos_emb is not None: + print("[DEBUG][WanLayerWithAdaLN][forward_input] rotary_pos_emb.shape - rotary_pos_emb.dtype - rotary_pos_emb.mean() - rotary_pos_emb.std() - rotary_pos_emb.norm():", rotary_pos_emb.shape, rotary_pos_emb.dtype, rotary_pos_emb.mean(), rotary_pos_emb.std(), rotary_pos_emb.norm()) + if rotary_pos_cos is not None: + print("[DEBUG][WanLayerWithAdaLN][forward_input] rotary_pos_cos.shape - rotary_pos_cos.dtype - rotary_pos_cos.mean() - rotary_pos_cos.std() - rotary_pos_cos.norm():", rotary_pos_cos.shape, rotary_pos_cos.dtype, rotary_pos_cos.mean(), rotary_pos_cos.std(), rotary_pos_cos.norm()) + if rotary_pos_sin is not None: + print("[DEBUG][WanLayerWithAdaLN][forward_input] rotary_pos_sin.shape - rotary_pos_sin.dtype - rotary_pos_sin.mean() - rotary_pos_sin.std() - rotary_pos_sin.norm():", rotary_pos_sin.shape, rotary_pos_sin.dtype, rotary_pos_sin.mean(), rotary_pos_sin.std(), rotary_pos_sin.norm()) + if attention_bias is not None: + print("[DEBUG][WanLayerWithAdaLN][forward_input] attention_bias.shape - attention_bias.dtype - attention_bias.mean() - attention_bias.std() - attention_bias.norm():", attention_bias.shape, attention_bias.dtype, attention_bias.mean(), attention_bias.std(), attention_bias.norm()) + if inference_params is not None: + print("[DEBUG][WanLayerWithAdaLN][forward_input] inference_params.shape - inference_params.dtype - inference_params.mean() - inference_params.std() - inference_params.norm():", inference_params.shape, inference_params.dtype, inference_params.mean(), inference_params.std(), 
inference_params.norm()) + if packed_seq_params is not None: + print("[DEBUG][WanLayerWithAdaLN][forward_input] packed_seq_params:", packed_seq_params) + if sequence_len_offset is not None: + print("[DEBUG][WanLayerWithAdaLN][forward_input] sequence_len_offset.shape - sequence_len_offset.dtype - sequence_len_offset.mean() - sequence_len_offset.std() - sequence_len_offset.norm():", sequence_len_offset.shape, sequence_len_offset.dtype, sequence_len_offset.mean(), sequence_len_offset.std(), sequence_len_offset.norm()) + + shift_full, scale_full, gate_full, shift_mlp, scale_mlp, gate_mlp = self.adaLN(timestep_emb) + # transpose to bring it to [1, b, ...] format + shift_full = shift_full.transpose(0, 1) + scale_full = scale_full.transpose(0, 1) + gate_full = gate_full.transpose(0, 1) + shift_mlp = shift_mlp.transpose(0, 1) + scale_mlp = scale_mlp.transpose(0, 1) + gate_mlp = gate_mlp.transpose(0, 1) + + # ******************************************** full self attention ******************************************* + + if run_debug and torch.distributed.get_rank()==0: + print("[DEBUG][WanLayerWithAdaLN] shift_full.shape - shift_full.dtype - shift_full.mean() - shift_full.std():", shift_full.shape, shift_full.dtype, float(shift_full.mean().item()), float(shift_full.std().item())) + print("[DEBUG][WanLayerWithAdaLN] scale_full.shape - scale_full.dtype - scale_full.mean() - scale_full.std():", scale_full.shape, scale_full.dtype, scale_full.mean(), scale_full.std()) + print("[DEBUG][WanLayerWithAdaLN] gate_full.shape - gate_full.dtype - gate_full.mean() - gate_full.std():", gate_full.shape, gate_full.dtype, gate_full.mean(), gate_full.std()) + print("[DEBUG][WanLayerWithAdaLN] shift_mlp.shape - shift_mlp.dtype - shift_mlp.mean() - shift_mlp.std():", shift_mlp.shape, shift_mlp.dtype, shift_mlp.mean(), shift_mlp.std()) + print("[DEBUG][WanLayerWithAdaLN] scale_mlp.shape - scale_mlp.dtype - scale_mlp.mean() - scale_mlp.std():", scale_mlp.shape, scale_mlp.dtype, scale_mlp.mean(), 
scale_mlp.std()) + print("[DEBUG][WanLayerWithAdaLN] gate_mlp.shape - gate_mlp.dtype - gate_mlp.mean() - gate_mlp.std():", gate_mlp.shape, gate_mlp.dtype, gate_mlp.mean(), gate_mlp.std()) + + # DEBUGGING + # if run_debug and torch.distributed.get_rank()==0: + if run_debug: + x_debug = hidden_states # DEBUGGING + print(f"[DEBUG][WanLayerWithAdaLN] [rank {torch.distributed.get_rank()}] hidden_states.shape - hidden_states.dtype - hidden_states.mean() - hidden_states.std():", hidden_states.shape, hidden_states.dtype, float(hidden_states.mean().item()), float(hidden_states.std().item())) + print(f"[DEBUG][WanLayerWithAdaLN] [rank {torch.distributed.get_rank()}] self.norm1(hidden_states).shape - self.norm1(hidden_states).dtype - self.norm1(hidden_states).mean() - self.norm1(hidden_states).std():", self.norm1(hidden_states).shape, self.norm1(hidden_states).dtype, float(self.norm1(hidden_states).mean().item()), float(self.norm1(hidden_states).std().item())) + print(f"[DEBUG][WanLayerWithAdaLN] [rank {torch.distributed.get_rank()}] shift_full.shape - shift_full.dtype - shift_full.mean() - shift_full.std():", shift_full.shape, shift_full.dtype, float(shift_full.mean().item()), float(shift_full.std().item())) + print(f"[DEBUG][WanLayerWithAdaLN] [rank {torch.distributed.get_rank()}] scale_full.shape - scale_full.dtype - scale_full.mean() - scale_full.std():", scale_full.shape, scale_full.dtype, float(scale_full.mean().item()), float(scale_full.std().item())) + + + # adaLN with scale + shift + gate + pre_full_attn_layernorm_output_ada = self.adaLN.modulate( + self.norm1(hidden_states.float()), # Wan's LayerNorm implementation forward pass casts input to float32 + shift=shift_full, + scale=scale_full, + ) + + attention_output, bias = self.full_self_attention( + pre_full_attn_layernorm_output_ada, + attention_mask=None, + rotary_pos_emb=rope_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + packed_seq_params=packed_seq_params['self_attention'], + ) + if 
bias is not None: + attention_output = attention_output + bias + + with amp.autocast(dtype=torch.float32): + hidden_states = self.adaLN.scale_add(residual=hidden_states, x=attention_output, gate=gate_full) + + # DEBUGGING + if run_debug and torch.distributed.get_rank()==0: + print("[DEBUG][WanLayerWithAdaLN][self_attention] x_debug.shape - x_debug.dtype - x_debug.mean() - x_debug.std() - x.norm:", x_debug.shape, x_debug.dtype, x_debug.mean(), x_debug.std(), x_debug.norm()) + print("[DEBUG][WanLayerWithAdaLN][self_attention] pre_full_attn_layernorm_output_ada.shape - pre_full_attn_layernorm_output_ada.dtype - pre_full_attn_layernorm_output_ada.mean() - pre_full_attn_layernorm_output_ada.std() - pre_full_attn_layernorm_output_ada.norm:", pre_full_attn_layernorm_output_ada.shape, pre_full_attn_layernorm_output_ada.dtype, pre_full_attn_layernorm_output_ada.mean(), pre_full_attn_layernorm_output_ada.std(), pre_full_attn_layernorm_output_ada.norm()) + print("[DEBUG][WanLayerWithAdaLN][self_attention] attention_output.shape - attention_output.dtype - attention_output.mean() - attention_output.std() - attention_output.norm():", attention_output.shape, attention_output.dtype, attention_output.mean(), attention_output.std(), attention_output.norm()) + print("[DEBUG][WanLayerWithAdaLN][self_attention] gate_full.shape - gate_full.dtype - gate_full.mean() - gate_full.std() - gate_full.norm():", gate_full.shape, gate_full.dtype, gate_full.mean(), gate_full.std(), gate_full.norm()) + print("[DEBUG][WanLayerWithAdaLN][self_attention] hidden_states.shape - hidden_states.dtype - hidden_states.mean() - hidden_states.std() - hidden_states.norm():", hidden_states.shape, hidden_states.dtype, hidden_states.mean(), hidden_states.std(), hidden_states.norm()) + + + # ******************************************** cross attention ****************************************************** + + attention_output, bias = self.cross_attention( + self.norm3(hidden_states.float()), # Wan's LayerNorm 
implementation forward pass casts input to float32 + attention_mask=context_mask, + key_value_states=context, + packed_seq_params=packed_seq_params['cross_attention'], + ) + if bias is not None: + attention_output = attention_output + bias + + hidden_states = hidden_states + attention_output + + # DEBUGGING + if run_debug and torch.distributed.get_rank()==0: + print("[DEBUG][WanLayerWithAdaLN][cross_attention] attention_output.shape - attention_output.dtype - attention_output.mean() - attention_output.std() - attention_output.norm():", attention_output.shape, attention_output.dtype, attention_output.mean(), attention_output.std(), attention_output.norm()) + print("[DEBUG][WanLayerWithAdaLN][cross_attention] hidden_states.shape - hidden_states.dtype - hidden_states.mean() - hidden_states.std() - hidden_states.norm():", hidden_states.shape, hidden_states.dtype, hidden_states.mean(), hidden_states.std(), hidden_states.norm()) + + # ******************************************** mlp ****************************************************** + + pre_mlp_layernorm_output_ada = self.adaLN.modulate( + self.norm2(hidden_states.float()), # Wan's LayerNorm implementation forward pass casts input to float32 + shift=shift_mlp, + scale=scale_mlp, + ) + + mlp_output, bias = self.mlp(pre_mlp_layernorm_output_ada) + if bias is not None: + mlp_output = mlp_output + bias + + # DEBUGGING + print("self.mlp.activation_func:", self.mlp.activation_func) + + with amp.autocast(dtype=torch.float32): + hidden_states = self.adaLN.scale_add(residual=hidden_states, x=mlp_output, gate=gate_mlp) + + + # TODO: Jit compiled function creates 'view' tensor. This tensor + # potentially gets saved in the MPU checkpoint function context, + # which rejects view tensors. While making a viewless tensor here + # won't result in memory savings (like the data loader, or + # p2p_communication), it serves to document the origin of this + # 'view' tensor. ??? 
+ output = make_viewless_tensor(inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True) + # output = hidden_states + + # DEBUGGING + if run_debug and torch.distributed.get_rank()==0: + print("[DEBUG][WanLayerWithAdaLN][mlp] pre_mlp_layernorm_output_ada.shape - pre_mlp_layernorm_output_ada.dtype - pre_mlp_layernorm_output_ada.mean() - pre_mlp_layernorm_output_ada.std() - pre_mlp_layernorm_output_ada.norm():", pre_mlp_layernorm_output_ada.shape, pre_mlp_layernorm_output_ada.dtype, pre_mlp_layernorm_output_ada.mean(), pre_mlp_layernorm_output_ada.std(), pre_mlp_layernorm_output_ada.norm()) + print("[DEBUG][WanLayerWithAdaLN][mlp] mlp_output.shape - mlp_output.dtype - mlp_output.mean() - mlp_output.std() - mlp_output.norm():", mlp_output.shape, mlp_output.dtype, mlp_output.mean(), mlp_output.std(), mlp_output.norm()) + print("[DEBUG][WanLayerWithAdaLN][mlp] gate_mlp.shape - gate_mlp.dtype - gate_mlp.mean() - gate_mlp.std() - gate_mlp.norm():", gate_mlp.shape, gate_mlp.dtype, gate_mlp.mean(), gate_mlp.std(), gate_mlp.norm()) + print("[DEBUG][WanLayerWithAdaLN][mlp] hidden_states.shape - hidden_states.dtype - hidden_states.mean() - hidden_states.std() - hidden_states.norm():", hidden_states.shape, hidden_states.dtype, hidden_states.mean(), hidden_states.std(), hidden_states.norm()) + + # DEBUGGING + if run_debug: + hidden_states_concatenated = cat_outputs_cp(hidden_states, 0, parallel_state.get_context_parallel_group()) + if torch.distributed.get_rank()==0: + print("[DEBUG][WanLayerWithAdaLN][mlp] (after cat_outputs_cp) hidden_states_concatenated.shape - hidden_states_concatenated.dtype - hidden_states_concatenated.mean() - hidden_states_concatenated.std() - hidden_states_concatenated.norm():", hidden_states_concatenated.shape, hidden_states_concatenated.dtype, hidden_states_concatenated.mean(), hidden_states_concatenated.std(), hidden_states_concatenated.norm()) + + # # DEBUGGING + # if run_debug and torch.distributed.get_rank()==0: + # 
import transformer_engine as te


def get_wan_block_with_transformer_engine_spec() -> ModuleSpec:
    """Build the ModuleSpec for a Wan decoder block backed by Transformer Engine.

    Wires ``WanLayerWithAdaLN`` with TE kernels: padded-mask self attention,
    cross attention (both with per-head Q/K layernorm via ``TENorm``), and a
    two-layer MLP.

    Returns:
        ModuleSpec: spec consumed by ``TransformerBlock`` to build each layer.
    """
    attn_params = {"attn_mask_type": AttnMaskType.padding}

    self_attention_spec = ModuleSpec(
        module=WanSelfAttention,
        params=attn_params,
        submodules=WanSelfAttentionSubmodules(
            linear_qkv=TEColumnParallelLinear,
            core_attention=TEDotProductAttention,
            linear_proj=TERowParallelLinear,
            layernorm_across_head=True,
            q_layernorm=TENorm,
            k_layernorm=TENorm,
        ),
    )

    cross_attention_spec = ModuleSpec(
        module=WanCrossAttention,
        params=attn_params,
        submodules=WanCrossAttentionSubmodules(
            linear_q=TEColumnParallelLinear,
            linear_kv=TEColumnParallelLinear,
            core_attention=TEDotProductAttention,
            linear_proj=TERowParallelLinear,
            layernorm_across_head=True,
            q_layernorm=TENorm,
            k_layernorm=TENorm,
        ),
    )

    # By default MLP's activation_func is openai_gelu, which is equivalent to
    # nn.GELU(approximate='tanh') as used by the reference Wan implementation.
    mlp_spec = ModuleSpec(
        module=MLP,
        submodules=MLPSubmodules(
            linear_fc1=TEColumnParallelLinear,
            linear_fc2=TERowParallelLinear,
        ),
    )

    return ModuleSpec(
        module=WanLayerWithAdaLN,
        submodules=WanWithAdaLNSubmodules(
            norm1=WanLayerNorm,
            norm3=WanLayerNorm,
            norm2=WanLayerNorm,
            full_self_attention=self_attention_spec,
            cross_attention=cross_attention_spec,
            mlp=mlp_spec,
        ),
    )
def sinusoidal_embedding_1d(dim, position):
    """Return 1-D sinusoidal embeddings of width ``dim`` for ``position``.

    Computed in float64; layout along the last axis is [cos half | sin half].

    Args:
        dim: embedding width, must be even.
        position: 1-D tensor of positions.

    Returns:
        Tensor of shape ``[len(position), dim]`` in float64.
    """
    assert dim % 2 == 0
    half = dim // 2
    pos = position.type(torch.float64)

    # Frequencies 10000^(-k/half) for k = 0..half-1, then outer product
    # position x frequency gives the phase angles.
    freqs = torch.pow(10000, -torch.arange(half).to(pos).div(half))
    angles = torch.outer(pos, freqs)
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=1)


class Head(nn.Module):
    """Output head: modulated LayerNorm followed by a linear patch projector."""

    def __init__(self, dim, out_dim, patch_size, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.eps = eps

        # Each token is projected to a full patch worth of output channels.
        out_dim = math.prod(patch_size) * out_dim
        self.norm = WanLayerNorm(dim, eps)
        self.head = nn.Linear(dim, out_dim)

        # Learned base shift/scale, combined with the timestep embedding.
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, e):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            e(Tensor): Shape [B, C]
        """
        assert e.dtype == torch.float32
        with amp.autocast(dtype=torch.float32):
            shift, scale = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
            x = self.head(self.norm(x) * (1 + scale) + shift)
        return x
+ """ + + def __init__( + self, + config: TransformerConfig, + pre_process: bool = True, + post_process: bool = True, + fp16_lm_cross_entropy: bool = False, + parallel_output: bool = True, + in_channels: int = 16, + out_channels: int = 16, + transformer_decoder_layer_spec=WanLayerWithAdaLNspec, + **kwargs, + ): + super(WanModel, self).__init__(config=config) + + self.config: TransformerConfig = config + + self.transformer_decoder_layer_spec = transformer_decoder_layer_spec() + self.pre_process = pre_process + self.post_process = post_process + self.add_encoder = True + self.add_decoder = True + self.fp16_lm_cross_entropy = fp16_lm_cross_entropy + self.parallel_output = parallel_output + self.in_channels = in_channels + self.out_channels = out_channels + + # megatron core pipelining currently depends on model type + # TODO: remove this dependency ? + self.model_type = ModelType.encoder_or_decoder + + self.num_heads = self.config.num_attention_heads + self.freq_dim = self.config.freq_dim + self.patch_spatial = self.config.patch_spatial + self.patch_temporal = self.config.patch_temporal + self.patch_size = (self.patch_temporal, self.patch_spatial, self.patch_spatial) + + + ###################################### + ########## Wan architecture ########## + + # embeddings + if self.pre_process: + self.patch_embedding = nn.Conv3d( + self.in_channels, self.config.hidden_size, kernel_size=self.patch_size, stride=self.patch_size) + + self.text_embedding = nn.Sequential( + nn.Linear(self.config.text_dim, self.config.hidden_size), nn.GELU(approximate='tanh'), + nn.Linear(self.config.hidden_size, self.config.hidden_size)) + + self.time_embedding = nn.Sequential( + nn.Linear(self.freq_dim, self.config.hidden_size), nn.SiLU(), nn.Linear(self.config.hidden_size, self.config.hidden_size)) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(self.config.hidden_size, self.config.hidden_size * 6)) + + self.rope_embeddings = Wan3DRopeEmbeddings(dim_head = self.config.hidden_size 
// self.num_heads, max_position_len = 1024) + + # decoder blocks + self.decoder = TransformerBlock( + config=self.config, + spec=self.transformer_decoder_layer_spec, + pre_process=self.pre_process, + post_process=self.post_process, + post_layer_norm=False, + ) + + # output head + if self.post_process: + self.head = Head(self.config.hidden_size, self.out_channels, self.patch_size, eps = 1e-6) + + + def forward( + self, + x: Tensor, + grid_sizes: list[Tuple[int, int, int]], + t: Tensor, + context: Tensor, + max_seq_len: int, + packed_seq_params: PackedSeqParams = None, + **kwargs, + ) -> Tensor: + """Forward pass. + + Args: + x List[Tensor]: list of vae encoded data (in_channel, f, h, w) + grid_sizes List[Tuple[int, int, int]]: list of grid sizes (f, h, w) + t Tensor: timesteps + context List[Tensor]: list of context (text_len, hidden_size) + max_seq_len int: maximum sequence length + packed_seq_params PackedSeqParams: packed sequence parameters + + Returns: + Tensor: output tensor (still patchified) of shape [seq_len, batch_size, hidden_size] + """ + ################################# + ########## Wan forward ########## + + # DEBUGGING + run_debug = False + + # DEBUGGING + if run_debug and torch.distributed.get_rank()==0: + print("[DEBUG] state_dict keys:") + for k, v in self.state_dict().items(): + if "_extra_state" in k: + continue + if hasattr(v, "shape"): + print(f"[DEBUG] {k} | shape - dtype - mean - std - norm: {tuple(v.shape)} - {v.dtype} - {v.mean().item()} - {v.std().item()} - {v.norm().item()}") + else: + print(f"[DEBUG] {k}") + print("\n\n\n") + + + # DEBUGGING + if run_debug and torch.distributed.get_rank()==0: + print("[DEBUG] [WanModel forward] x.shape - x.dtype - x.mean() - x.std() - x.norm(): ", x.shape, x.dtype, x.mean(), x.std(), x.norm()) + print("[DEBUG] [WanModel forward] grid_sizes: ", grid_sizes) + print("[DEBUG] [WanModel forward] t: ", t) + print("[DEBUG] [WanModel forward] context.shape - context.dtype - context.mean() - context.std() - 
context.norm(): ", context.shape, context.dtype, context.mean(), context.std(), context.norm()) + print("[DEBUG] [WanModel forward] max_seq_len: ", max_seq_len) + print("[DEBUG] [WanModel forward] packed_seq_params: ", packed_seq_params) + + + # ============= embedders ============= + + # run input embedding + if self.pre_process: + # x.shape [s, b, c * pF * pH * pW] + seq_len, batch_size, _ = x.shape + c = self.out_channels + pF, pH, pW = self.patch_size + x = x.reshape(seq_len * batch_size, c, pF, pH, pW) # output: x.shape [s * b, c, pF, pH, pW] + x = self.patch_embedding(x) # output: x.shape [s * b, hidden_size, 1, 1, 1] + x = x.flatten(1) # output: x.shape [s * b, hidden_size] + x = x.reshape(seq_len, batch_size, -1) # output: x.shape [s, b, hidden_size] + + # split sequence for sequence_parallel + # TODO: for PP, do we move scatter_to_sequence_parallel_region here or after "x = self.decoder.input_tensor" ??? + if self.config.sequence_parallel: + x = tensor_parallel.scatter_to_sequence_parallel_region(x) # output: x.shape [s * b // tp_size, hidden_size] + + else: + # intermediate stage of pipeline + x = self.decoder.input_tensor + + # DEBUGGING + if run_debug and torch.distributed.get_rank()==0: + print("[DEBUG] [WanModel forward] (after patch_embedding) x.shape - x.dtype - x.mean() - x.std() - x.norm(): ", x.shape, x.dtype, x.mean(), x.std(), x.norm()) + print("[DEBUG] [WanModel forward] (after patch_embedding) x:", x) + + # time embeddings + with amp.autocast(dtype=torch.float32): + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0 = self.time_projection(e).unflatten(1, (6, self.config.hidden_size)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context embeddings + context = self.text_embedding(context) # shape [text_len, b, hidden_size] + + + # ============= decoder ============= + # calculate rotary pos emb + n_head, dim_head = self.num_heads, self.config.hidden_size // self.num_heads + rotary_pos_emb 
= self.rope_embeddings(n_head, dim_head, max_seq_len, grid_sizes, t.device) # output: rotary_pos_emb.shape [s, b, 1, dim_head] + + # DEBUGGING + if run_debug and torch.distributed.get_rank()==0: + print("[DEBUG] [WanModel forward] (before self.decoder) x.shape - x.dtype - x.mean() - x.std() - x.norm(): ", x.shape, x.dtype, x.mean(), x.std(), x.norm()) + print("[DEBUG] [WanModel forward] (before self.decoder) context.shape - context.dtype - context.mean() - context.std() - context.norm(): ", context.shape, context.dtype, context.mean(), context.std(), context.norm()) + print("[DEBUG] [WanModel forward] (before self.decoder) e0.shape - e0.dtype - e0.mean() - e0.std() - e0.norm(): ", e0.shape, e0.dtype, e0.mean(), e0.std(), e0.norm()) + print("[DEBUG] [WanModel forward] (before self.decoder) rotary_pos_emb.shape - rotary_pos_emb.dtype - rotary_pos_emb.mean() - rotary_pos_emb.std() - rotary_pos_emb.norm(): ", rotary_pos_emb.shape, rotary_pos_emb.dtype, rotary_pos_emb.mean(), rotary_pos_emb.std(), rotary_pos_emb.norm()) + print("[DEBUG] [WanModel forward] (before self.decoder) packed_seq_params: ", packed_seq_params) + + # run decoder + x = self.decoder( + hidden_states=x, + attention_mask=e0, + context=context, + context_mask=None, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=None, + rotary_pos_sin=None, + packed_seq_params=packed_seq_params, + ) + + # DEBUGGING + if run_debug and torch.distributed.get_rank()==0: + print("[DEBUG] [WanModel forward] (after self.decoder) x.shape - x.dtype - x.mean() - x.std() - x.norm(): ", x.shape, x.dtype, x.mean(), x.std(), x.norm()) + + # return if not post_process + if not self.post_process: + return x + + # head + x = x.transpose(0, 1) # head expects shape [b, s, hidden_size] + x = self.head(x, e) # output: x.shape [b, s, c * pF * pH * pW] + x = x.transpose(0, 1) # reshape back to shape [s, b, c * pF * pH * pW] + + # gather outputs for sequence_parallel + # Note: in GPT models, because the vocab projection matrix is 
ColumnParallelLinear, the sequence is + # automatically gathered in ColumnParallelLinear forward pass. + # However, in Wan models, we need to gather the outputs manually. + if self.config.sequence_parallel: + x = tensor_parallel.gather_from_sequence_parallel_region(x) + + # DEBUGGING + if run_debug and torch.distributed.get_rank()==0: + print("[DEBUG] [WanModel forward] (after self.head) x.shape - x.dtype - x.mean() - x.std() - x.norm(): ", x.shape, x.dtype, x.mean(), x.std(), x.norm()) + + return x # output: x.shape [s, b, c * pF * pH * pW] + + + def set_input_tensor(self, input_tensor: Tensor) -> None: + """Sets input tensor to the model. + + See megatron.model.transformer.set_input_tensor() + + Args: + input_tensor (Tensor): Sets the input tensor for the model. + """ + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + + assert len(input_tensor) == 1, "input_tensor should only be length 1 for gpt/bert" + self.decoder.set_input_tensor(input_tensor[0]) + + + def sharded_state_dict( + self, prefix: str = "module.", sharded_offsets: tuple = (), metadata: Optional[Dict] = None + ) -> ShardedStateDict: + """Sharded state dict implementation for GPTModel backward-compatibility (removing extra state). + + Args: + prefix (str): Module name prefix. + sharded_offsets (tuple): PP related offsets, expected to be empty at this module level. + metadata (Optional[Dict]): metadata controlling sharded state dict creation. 
+ + Returns: + ShardedStateDict: sharded state dict for the GPTModel + """ + sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) + + # DEBUGGING + # for module in ["t_embedder"]: + # for param_name, param in getattr(self, module).named_parameters(): + # weight_key = f"{prefix}{module}.{param_name}" + # self._set_embedder_weights_replica_id(param, sharded_state_dict, weight_key) + # DEBUGGING + # Ensure replica ids for non-transformer embedder weights include pipeline dimension + for module in ["text_embedding", "time_embedding", "time_projection"]: + if hasattr(self, module): + for param_name, param in getattr(self, module).named_parameters(): + weight_key = f"{prefix}{module}.{param_name}" + if weight_key in sharded_state_dict: + self._set_embedder_weights_replica_id(param, sharded_state_dict, weight_key) + + return sharded_state_dict + + + def _set_embedder_weights_replica_id( + self, tensor: Tensor, sharded_state_dict: ShardedStateDict, embedder_weight_key: str + ) -> None: + """set replica ids of the weights in t_embedder for sharded state dict. + + Args: + sharded_state_dict (ShardedStateDict): state dict with the weight to tie + weight_key (str): key of the weight in the state dict. 
+ This entry will be replaced with a tied version + + Returns: None, acts in-place + """ + tp_rank = parallel_state.get_tensor_model_parallel_rank() + vpp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank() + vpp_rank = vpp_rank if vpp_rank else 0 + vpp_world = parallel_state.get_virtual_pipeline_model_parallel_world_size() + vpp_world = vpp_world if vpp_world else 1 + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + del sharded_state_dict[embedder_weight_key] + replica_id = ( + tp_rank, + (vpp_rank + pp_rank * vpp_world), + parallel_state.get_data_parallel_rank(with_context_parallel=True), + ) + + sharded_state_dict[embedder_weight_key] = make_sharded_tensor_for_checkpoint( + tensor=tensor, + key=embedder_weight_key, + replica_id=replica_id, + allow_shape_mismatch=False, + ) diff --git a/src/megatron/bridge/models/wan/wan_provider.py b/src/megatron/bridge/models/wan/wan_provider.py new file mode 100644 index 0000000000..0003761f5e --- /dev/null +++ b/src/megatron/bridge/models/wan/wan_provider.py @@ -0,0 +1,121 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
@dataclass
class WanModelProvider(TransformerConfig, ModelProviderMixin[VisionModule]):
    """Provider/config dataclass for WanModel (defaults match Wan 1.3B).

    BUG FIX: ``layernorm_epsilon``, ``normalization``,
    ``layernorm_zero_centered_gamma`` and ``replicated_t_embedder`` were plain
    (unannotated) class attributes; ``@dataclass`` only turns annotated names
    into fields, so the intended overrides never reached instances (the
    inherited field defaults silently won). They are now annotated fields.
    """

    crossattn_emb_size: int = 1536
    add_bias_linear: bool = True
    gated_linear_unit: bool = False

    num_layers: int = 30
    hidden_size: int = 1536
    ffn_hidden_size: int = 8960
    max_img_h: int = 80
    max_img_w: int = 80
    max_frames: int = 34
    patch_spatial: int = 2
    patch_temporal: int = 1
    num_attention_heads: int = 12
    layernorm_epsilon: float = 1e-6
    normalization: str = "RMSNorm"
    qk_layernorm_per_head: bool = False
    layernorm_zero_centered_gamma: bool = False

    fp16_lm_cross_entropy: bool = False
    parallel_output: bool = True
    share_embeddings_and_output_weights: bool = True

    hidden_dropout: float = 0
    attention_dropout: float = 0

    bf16: bool = False
    params_dtype: torch.dtype = torch.float32

    vae_module: str = "nemo_vfm.diffusion.vae.diffusers_vae.AutoencoderKLVAE"
    vae_path: Optional[str] = None
    sigma_data: float = 0.5

    in_channels: int = 16
    out_channels: int = 16

    replicated_t_embedder: bool = True
    qkv_format: str = 'sbhd'

    # Wan-specific dimensions
    text_dim: int = 4096
    patch_size: list = field(default_factory=lambda: [1, 2, 2])
    freq_dim: int = 256
    out_dim: int = 16
    text_len: int = 512

    # Unused by Wan; set only because bridge training requires them for LLMs.
    seq_length: int = 1024
    vocab_size: Optional[int] = None
    make_vocab_size_divisible_by: int = 128

    def provide(self, pre_process=None, post_process=None, vp_stage=None) -> WanModel:
        """Instantiate a WanModel for the current parallel configuration.

        NOTE(review): the ``pre_process``/``post_process``/``vp_stage``
        arguments are currently ignored; stage flags are derived from
        ``parallel_state`` instead — confirm callers expect this.
        """
        vp_size = self.virtual_pipeline_model_parallel_size
        if vp_size:
            p_size = self.pipeline_model_parallel_size
            assert (self.num_layers // p_size) % vp_size == 0, (
                "Make sure the number of model chunks is the same across all pipeline stages."
            )

        return WanModel(
            self,
            fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
            parallel_output=self.parallel_output,
            pre_process=parallel_state.is_pipeline_first_stage(),
            post_process=parallel_state.is_pipeline_last_stage(),
            max_img_h=self.max_img_h,
            max_img_w=self.max_img_w,
            max_frames=self.max_frames,
            patch_spatial=self.patch_spatial,
        )

    def configure_vae(self):
        """Instantiate the VAE via dotted-path import of ``vae_module``."""
        return dynamic_import(self.vae_module)(self.vae_path)
def wan_data_step(qkv_format, dataloader_iter):
    """Fetch one batch, move its tensors to CUDA, and attach PackedSeqParams.

    Args:
        qkv_format: layout string forwarded to PackedSeqParams (e.g. 'sbhd').
        dataloader_iter: wrapper exposing the underlying iterable via `.iterable`.

    Returns:
        The batch dict; a "packed_seq_params" entry (self/cross attention) is
        added when per-sample lengths "seq_len_q"/"seq_len_kv" are present.
    """
    # NOTE(review): next(iter(...)) builds a fresh iterator from `.iterable` on
    # every call — confirm this is intended rather than next(dataloader_iter).
    batch = next(iter(dataloader_iter.iterable))

    # # can we do this ???
    # batch = get_batch_on_this_cp_rank(batch)

    # Move every tensor to the current CUDA device; non-tensors pass through.
    batch = {k: v.to(device="cuda", non_blocking=True) if torch.is_tensor(v) else v for k, v in batch.items()}

    # TODO(review): decide whether padding (to the longest sequence in the
    # batch) belongs here or in the TaskEncoder; currently done in neither.

    # Construct packed sequence parameters: cumulative sequence lengths with a
    # leading zero, the format fused attention kernels expect.
    if ("seq_len_q" in batch) and ("seq_len_kv" in batch):
        cu_seqlens = batch["seq_len_q"].cumsum(dim=0).to(torch.int32)
        zero = torch.zeros(1, dtype=torch.int32, device="cuda")
        cu_seqlens = torch.cat((zero, cu_seqlens))

        cu_seqlens_kv = batch["seq_len_kv"].cumsum(dim=0).to(torch.int32)
        cu_seqlens_kv = torch.cat((zero, cu_seqlens_kv))

        batch["packed_seq_params"] = {
            # Self attention: queries and keys/values share the same lengths.
            "self_attention": PackedSeqParams(
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_kv=cu_seqlens,
                qkv_format=qkv_format,
            ),
            # Cross attention: keys/values come from the text context.
            "cross_attention": PackedSeqParams(
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_kv=cu_seqlens_kv,
                qkv_format=qkv_format,
            ),
        }

    return batch


def get_batch_on_this_cp_rank(data):
    """Split the data for context parallelism.

    Slices latent tensors along the temporal (or, as a fallback, height) axis so
    each context-parallel rank sees its own shard, and shards the loss mask.

    NOTE(review): this local definition shadows the get_batch_on_this_cp_rank
    imported from megatron.core.utils at the top of this module — confirm the
    shadowing is intentional.
    """
    from megatron.core import mpu

    cp_size = mpu.get_context_parallel_world_size()
    cp_rank = mpu.get_context_parallel_rank()

    # NOTE(review): hard-coded latent frame count; presumably should be derived
    # from the data shape — TODO confirm.
    t = 16
    if cp_size > 1:
        # cp split on seq_length, for video_latent, noise_latent and pos_ids
        assert t % cp_size == 0, "t must divisibly by cp_size"
        num_valid_tokens_in_ub = None
        if "loss_mask" in data and data["loss_mask"] is not None:
            num_valid_tokens_in_ub = data["loss_mask"].sum()

        for key, value in data.items():
            if (value is not None) and (key in ["video", "video_latent", "noise_latent", "pos_ids"]):
                if len(value.shape) > 5:
                    value = value.squeeze(0)
                B, C, T, H, W = value.shape
                if T % cp_size == 0:
                    # Shard along time: each rank keeps T // cp_size frames.
                    # FIXME packed sequencing
                    data[key] = value.view(B, C, cp_size, T // cp_size, H, W)[:, :, cp_rank, ...].contiguous()
                else:
                    # Time not divisible: shard along height instead.
                    # FIXME packed sequencing
                    data[key] = value.view(B, C, T, cp_size, H // cp_size, W)[:, :, :, cp_rank, ...].contiguous()
        # Shard the loss mask the same way along its sequence dimension.
        loss_mask = data["loss_mask"]
        data["loss_mask"] = loss_mask.view(loss_mask.shape[0], cp_size, loss_mask.shape[1] // cp_size)[
            :, cp_rank, ...
        ].contiguous()
        data["num_valid_tokens_in_ub"] = num_valid_tokens_in_ub

    return data
class WanForwardStep:
    """Forward training step for Wan diffusion models.

    Wraps a FlowPipeline: fetches a batch, runs the flow-matching training
    step, and returns the output tensor plus a masked loss function.
    """

    def __init__(self):
        # Flow-matching pipeline that owns the diffusion training step.
        self.diffusion_pipeline = FlowPipeline()

    def __call__(
        self, state: GlobalState, data_iterator: Iterable, model: GPTModel, return_schedule_plan: bool = False
    ) -> tuple[torch.Tensor, partial]:
        """Forward training step.

        Args:
            state: Global state for the run
            data_iterator: Input data iterator
            model: The GPT Model
            return_schedule_plan (bool): Whether to return the schedule plan instead of the output tensor

        Returns:
            tuple containing the output tensor and the loss function
        """
        timers = state.timers
        straggler_timer = state.straggler_timer

        config = get_model_config(model)

        timers("batch-generator", log_level=2).start()

        qkv_format = getattr(config, "qkv_format", "sbhd")
        with straggler_timer(bdata=True):
            batch = wan_data_step(qkv_format, data_iterator)
        timers("batch-generator").stop()

        check_for_nan_in_loss = state.cfg.rerun_state_machine.check_for_nan_in_loss
        check_for_spiky_loss = state.cfg.rerun_state_machine.check_for_spiky_loss

        # Run the diffusion training step; only the last pipeline stage
        # produces a loss tensor, earlier stages pass activations forward.
        with straggler_timer:
            if parallel_state.is_pipeline_last_stage():
                output_batch, loss = self.diffusion_pipeline.training_step(model, batch)
                output_tensor = torch.mean(loss, dim=-1)
            else:
                output_tensor = self.diffusion_pipeline.training_step(model, batch)

        # TODO(review): with sequence/context parallelism (especially combined
        # with pipeline parallelism) the output may need gathering here.

        # BUG FIX: the guard's fallback was previously overwritten by an
        # unconditional `loss_mask = batch["loss_mask"]`, which raised KeyError
        # (or propagated None) exactly when the mask was absent.
        loss_mask = batch.get("loss_mask")
        if loss_mask is None:
            loss_mask = torch.ones_like(output_tensor)

        loss_function = self._create_loss_function(loss_mask, check_for_nan_in_loss, check_for_spiky_loss)

        return output_tensor, loss_function

    def _create_loss_function(self, loss_mask: torch.Tensor, check_for_nan_in_loss: bool, check_for_spiky_loss: bool) -> partial:
        """Create a partial loss function with the specified configuration.

        Args:
            loss_mask: Used to mask out some portions of the loss
            check_for_nan_in_loss: Whether to check for NaN values in the loss
            check_for_spiky_loss: Whether to check for spiky loss values

        Returns:
            A partial function that can be called with output_tensor to compute the loss
        """
        return partial(
            masked_next_token_loss,
            loss_mask,
            check_for_nan_in_loss=check_for_nan_in_loss,
            check_for_spiky_loss=check_for_spiky_loss,
        )
+# NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 examples/recipes/wan/inference_wan.py \ +# --task t2v-1.3B \ +# --sizes 480*832 \ +# --ckpt_dir /path/to/wan_checkpoints \ +# --frame_nums 81 \ +# --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ +# --tensor_parallel_size 1 \ +# --context_parallel_size 1 \ +# --pipeline_parallel_size 1 \ +# --sequence_parallel False \ +# --base_seed 42 \ +# --sample_steps 50 + import argparse import logging import os diff --git a/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py b/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py index 5b905cabee..83314df11c 100644 --- a/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py +++ b/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py @@ -24,8 +24,7 @@ retrieve_timesteps, ) from megatron.bridge.models.wan.inference.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler -from megatron.core.dist_checkpointing.validation import StrictHandling -from megatron.core import dist_checkpointing, parallel_state +from megatron.core import parallel_state from torch.nn import functional as F import math @@ -90,6 +89,7 @@ def __init__( wan_checkpoint_dir = os.path.join(checkpoint_dir, "iter_0000000") self.model = self.setup_model_from_checkpoint(wan_checkpoint_dir) + # set self.sp_size=1 for later use, just to respect the original Wan inference code self.sp_size = 1 if dist.is_initialized(): @@ -101,15 +101,15 @@ def __init__( def patchify(self, x, patch_size): """ - Convert a list of reconstructed video tensor into patch embeddings (inverse of `unpatchify`). + Convert a list of reconstructed video tensor into patch embeddings. + This method is the inverse of `unpatchify`. 
Args: - x (list[torch.Tensor]): list of tensors, each with shape [C, F * pF, H * pH, W * pW] + x (list[torch.Tensor]): list of tensors, each with shape [c, F_patches * pF, H_patches * pH, W_patches * pW] patch_size (tuple): (pF, pH, pW) Returns: - torch.Tensor: shape [num_patches, C * prod(patch_size)], - where num_patches = F * H * W + torch.Tensor: shape [ (F_patches * H_patches * W_patches), (c * pF * pH * pW)], """ out = [] for u in x: @@ -118,38 +118,33 @@ def patchify(self, x, patch_size): assert F_pF % pF == 0 and H_pH % pH == 0 and W_pW % pW == 0, \ "Spatial dimensions must be divisible by patch size." - F, H, W = F_pF // pF, H_pH // pH, W_pW // pW + F_patches, H_patches, W_patches = F_pF // pF, H_pH // pH, W_pW // pW # split spatial dims into (grid, patch) and reorder to match original patch layout: - # start: (C, F_pF, H_pW, W_pW) - # reshape -> (C, F, pF, H, pH, W, pW) - # permute -> (F, H, W, pF, pH, pW, C) - # DEBUGGING - t = u.reshape(c, F, pF, H, pH, W, pW) - # t = u.reshape(c, F, pF, W, pW, H, pH) + # start: (c, F_patches * pF, H_patches * pH, W_patches * pW) + # reshape -> (c, F_patches, pF, H_patches, pH, W_patches, pW) + # permute -> (F_patches, H_patches, W_patches, pF, pH, pW, c) + t = u.reshape(c, F_patches, pF, H_patches, pH, W_patches, pW) t = t.permute(1, 3, 5, 0, 2, 4, 6) - num_patches = F * H * W + num_patches = F_patches * H_patches * W_patches out.append(t.reshape(num_patches, c * (pF * pH * pW))) return out - def unpatchify(self, x: torch.Tensor, grid_sizes: torch.Tensor, out_dim: int) -> torch.Tensor: + def unpatchify(self, x: torch.Tensor, grid_sizes: torch.Tensor, out_dim: int) -> list[torch.Tensor]: r""" - Reconstruct video tensors from patch embeddings. + Reconstruct video tensors from patch embeddings into a list of videotensors. 
Args: - x (Tensor): - Tensor of patchified features, with shape [L, C_out * prod(patch_size)] + x (torch.Tensor): + Tensor of patchified features, with shape [seq_len, c * pF * pH * pW] grid_sizes (Tensor): Original spatial-temporal grid dimensions before patching, shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) Returns: - Tensor: - # Reconstructed video tensor with shape [C_out, F, H / 8, W / 8] - # ??? list of tensors, because each sample in the batch has a different video shape, the original video shape is determined by the grid_sizes. - list[Tensor]: list of tensors, each with shape [C_out, F, H / 8, W / 8] + list[torch.Tensor]: list of tensors, each with shape [c, F_latents, H_latents, W_latents] """ c = out_dim @@ -159,34 +154,21 @@ def unpatchify(self, x: torch.Tensor, grid_sizes: torch.Tensor, out_dim: int) -> u = torch.einsum('fhwpqrc->cfphqwr', u) u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) out.append(u) - # because the video shapes are different for each sample in the batch, we cannot stack the videos into a single tensor. 
- # out = torch.stack(out, dim=0) return out def setup_model_from_checkpoint(self, checkpoint_dir): - - # def init_distributed(tp_size: int = 1, pp_size: int = 1, cp_size: int = 1): - # rank = int(os.environ.get("LOCAL_RANK", 0)) - # world_size = int(os.environ.get("WORLD_SIZE", 1)) - # torch.cuda.set_device(rank % torch.cuda.device_count()) - # torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size) - # parallel_state.initialize_model_parallel(tp_size, pp_size, context_parallel_size=cp_size) - # init_distributed(self.tensor_parallel_size, self.pipeline_parallel_size, self.context_parallel_size) - provider = WanModelProvider() provider.tensor_model_parallel_size = self.tensor_parallel_size provider.pipeline_model_parallel_size = self.pipeline_parallel_size provider.context_parallel_size = self.context_parallel_size provider.sequence_parallel = self.sequence_parallel - print(f"provider.sequence_parallel: {provider.sequence_parallel}") provider.pipeline_dtype = self.pipeline_dtype # Once all overrides are set, finalize the model provider to ensure the post initialization logic is run provider.finalize() provider.initialize_model_parallel(seed=0) - - ## Method 1: Read from megatron checkpoint + ## Read from megatron checkpoint from megatron.bridge.training.model_load_save import load_megatron_model as _load_megatron_model model = _load_megatron_model( checkpoint_dir, @@ -200,17 +182,13 @@ def setup_model_from_checkpoint(self, checkpoint_dir): ) if isinstance(model, list): model = model[0] - # ## Method 2: Read from megatron checkpoint - # model = provider.provide_distributed_model(wrap_with_ddp=False) - ## Method 3 (not loading checkpoint) - # model = provider.provide() return model def grid_sizes_calculation( self, - input_shape: Tuple[int, int, int], # (D_in, H_in, W_in) + input_shape: Tuple[int, int, int], # (F_latents, H_latents, W_latents) kernel_size: Union[int, Tuple[int, int, int]], stride: Union[int, Tuple[int, int, int]] = 1, padding: 
Union[int, Tuple[int, int, int]] = 0, @@ -220,11 +198,11 @@ def grid_sizes_calculation( Compute the (f,h,w) output spatial/temporal dimensions of a Conv3d patch embedder. Args: - input_shape: (D_in, H_in, W_in) + input_shape: (F_latents, H_latents, W_latents) kernel_size, stride, padding, dilation of the Conv3d patch embedder: either int or 3-tuple Returns: - (D_out, H_out, W_out) + (F_patches, H_patches, W_patches) """ def to_tuple(x): @@ -255,9 +233,8 @@ def forward_pp_step( timestep: torch.Tensor, arg_c: dict, ) -> torch.Tensor: - """One decode step supporting pipeline parallelism for batch_size=1. - - Returns a tensor containing the noise prediction. + """ + Forward pass supporting pipeline parallelism. """ from megatron.core import parallel_state @@ -267,7 +244,7 @@ def forward_pp_step( is_pp_first = parallel_state.is_pipeline_first_stage(ignore_virtual=True) is_pp_last = parallel_state.is_pipeline_last_stage(ignore_virtual=True) - # TP-only or single-rank + # PP=1: no pipeline parallelism if pp_world_size == 1: noise_pred_pp = self.model( latent_model_input, @@ -276,17 +253,11 @@ def forward_pp_step( **arg_c) return noise_pred_pp - # Pipeline-parallel path + # PP>1: pipeline parallelism hidden_size = self.model.config.hidden_size batch_size = latent_model_input.shape[1] + # noise prediction shape for communication between first and last pipeline stages noise_pred_pp_shape = list(latent_model_input.shape) - print(f"batch_size: {batch_size}") - - # DEBUGGING - # we should bring x unpatchify out of the model - # x_after_patch_embedding_shape = [16, 3, 104, 60] # ???? 
- # when bring unpatchified out, for pp communicate last stage to first stage, this should be - # x_after_patch_embedding_shape = [max_video_seq_len, batch_size, (ph pw pt C)] if is_pp_first: # First stage: compute multimodal + first PP slice, send activations, then receive sampled token @@ -295,10 +266,7 @@ def forward_pp_step( grid_sizes=grid_sizes, t=timestep, **arg_c) - print(f"[rank {torch.distributed.get_rank()}] Got here! - self.model") send_to_next_pipeline_rank(hidden_states) - print(f"[rank {torch.distributed.get_rank()}] Got here! - hidden_states.shape: {hidden_states.shape} - hidden_states.dtype: {hidden_states.dtype}") - print(f"[rank {torch.distributed.get_rank()}] Got here! - send_to_next_pipeline_rank") noise_pred_pp = broadcast_from_last_pipeline_stage(noise_pred_pp_shape, dtype=torch.float32) return noise_pred_pp @@ -311,7 +279,6 @@ def forward_pp_step( device=latent_model_input[0].device, ) recv_from_prev_pipeline_rank_(recv_buffer) - # DEBUGGING recv_buffer = recv_buffer.to(torch.bfloat16) # ???? self.model.set_input_tensor(recv_buffer) noise_pred_pp = self.model( @@ -320,9 +287,6 @@ def forward_pp_step( t=timestep, **arg_c) - - print("noise_pred_pp_shape: ", noise_pred_pp_shape) - noise_pred_pp = broadcast_from_last_pipeline_stage(noise_pred_pp_shape, dtype=noise_pred_pp.dtype, tensor=noise_pred_pp.contiguous()) return noise_pred_pp @@ -332,13 +296,9 @@ def forward_pp_step( dtype=next(self.model.parameters()).dtype, device=latent_model_input[0].device, ) - print(f"[rank {torch.distributed.get_rank()}] Got here! - recv_buffer.shape: {recv_buffer.shape} - recv_buffer.dtype: {recv_buffer.dtype}") recv_from_prev_pipeline_rank_(recv_buffer) - print(f"[rank {torch.distributed.get_rank()}] Got here! - recv_from_prev_pipeline_rank_") - # DEBUGGING recv_buffer = recv_buffer.to(torch.bfloat16) # ???? self.model.set_input_tensor(recv_buffer) - print(f"[rank {torch.distributed.get_rank()}] Got here! 
- self.model.set_input_tensor") hidden_states = self.model( latent_model_input, grid_sizes=grid_sizes, @@ -365,11 +325,11 @@ def generate(self, Generates video frames from text prompt using diffusion process. Args: - input_prompt (`str`): + prompts (`list[str]`): Text prompt for content generation - size (tupele[`int`], *optional*, defaults to (1280,720)): + sizes (list[tuple[int, int]]): Controls video resolution, (width,height). - frame_num (`int`, *optional*, defaults to 81): + frame_nums (`list[int]`): How many frames to sample from a video. The number should be 4n+1 shift (`float`, *optional*, defaults to 5.0): Noise schedule shift parameter. Affects temporal dynamics @@ -395,13 +355,6 @@ def generate(self, - W: Frame width from size) """ - # DEBUGGING - run_debug = True - - # size = sizes[0] - # input_prompt = prompts[0] - # frame_num = frame_nums[0] - # preprocess target_shapes = [] for size, frame_num in zip(sizes, frame_nums): @@ -424,6 +377,7 @@ def generate(self, seed_g = torch.Generator(device=self.device) seed_g.manual_seed(seed) + ## process context context_max_len = 512 context_lens = [] @@ -449,7 +403,6 @@ def generate(self, contexts_null = torch.stack(contexts_null, dim=1) - ## setup noise noises = [] for target_shape in target_shapes: @@ -464,9 +417,6 @@ def generate(self, generator=seed_g) ) - # DEBUGGING - print("[DEBUG] noises[0].shape - noises[0].dtype - noises[0].mean() - noises[0].std() - noises[0].norm():", noises[0].shape, noises[0].dtype, noises[0].mean(), noises[0].std(), noises[0].norm()) - print("[DEBUG] noises[0]:", noises[0]) # calculate grid_sizes grid_sizes = [self.grid_sizes_calculation( @@ -550,7 +500,6 @@ def noop_no_sync(): batch_size = len(latents) # patchify latents - # ??? 
when batch_size > 1, we need to pad to have same length unpatchified_latents = latents latents = self.patchify(latents, self.patch_size) # pad to have same length @@ -563,15 +512,6 @@ def noop_no_sync(): timestep = [t] * batch_size timestep = torch.stack(timestep) - # DEBUGGING - if run_debug and torch.distributed.get_rank()==0: - print(f"[DEBUG] [rank {torch.distributed.get_rank()}] contexts.shape: {contexts.shape}") - print(f"[DEBUG] [rank {torch.distributed.get_rank()}] max_video_seq_len: {max_video_seq_len}") - print(f"[DEBUG] [rank {torch.distributed.get_rank()}] grid_sizes: {grid_sizes}") - print(f"[DEBUG] [rank {torch.distributed.get_rank()}] latent_model_input.shape: {latent_model_input.shape}") - print(f"[DEBUG] [rank {torch.distributed.get_rank()}] timestep.shape: {timestep.shape}") - - self.model.to(self.device) noise_pred_cond = self.forward_pp_step( latent_model_input, grid_sizes=grid_sizes, max_video_seq_len=max_video_seq_len, timestep=timestep, arg_c=arg_c) @@ -580,98 +520,26 @@ def noop_no_sync(): latent_model_input, grid_sizes=grid_sizes, max_video_seq_len=max_video_seq_len, timestep=timestep, arg_c=arg_null) - # noise_pred = noise_pred_uncond + guide_scale * ( - # noise_pred_cond - noise_pred_uncond) - - # DEBUGGING + # run unpatchify unpatchified_noise_pred_cond = noise_pred_cond unpatchified_noise_pred_cond = unpatchified_noise_pred_cond.transpose(0, 1) # bring sbhd -> bshd - # when unpatchifying, the code will truncate the padded videos into the original video shape, based on the grid_sizes. ??? + # when unpatchifying, the code will truncate the padded videos into the original video shape, based on the grid_sizes. 
unpatchified_noise_pred_cond = self.unpatchify(unpatchified_noise_pred_cond, grid_sizes, self.vae.model.z_dim) - unpatchified_noise_pred_uncond = noise_pred_uncond unpatchified_noise_pred_uncond = unpatchified_noise_pred_uncond.transpose(0, 1) # bring sbhd -> bshd - # when unpatchifying, the code will truncate the padded videos into the original video shape, based on the grid_sizes. ??? + # when unpatchifying, the code will truncate the padded videos into the original video shape, based on the grid_sizes. unpatchified_noise_pred_uncond = self.unpatchify(unpatchified_noise_pred_uncond, grid_sizes, self.vae.model.z_dim) - # DEBUGGING - if run_debug and torch.distributed.get_rank()==0: - print(f"[DEBUG] unpatchified_noise_pred_cond[0].shape - unpatchified_noise_pred_cond[0].dtype - unpatchified_noise_pred_cond[0].mean() - unpatchified_noise_pred_cond[0].std() - unpatchified_noise_pred_cond[0].norm(): {unpatchified_noise_pred_cond[0].shape} - {unpatchified_noise_pred_cond[0].dtype} - {unpatchified_noise_pred_cond[0].mean()} - {unpatchified_noise_pred_cond[0].std()} - {unpatchified_noise_pred_cond[0].norm()}") - print(f"[DEBUG] unpatchified_noise_pred_uncond[0].shape - unpatchified_noise_pred_uncond[0].dtype - unpatchified_noise_pred_uncond[0].mean() - unpatchified_noise_pred_uncond[0].std() - unpatchified_noise_pred_uncond[0].norm(): {unpatchified_noise_pred_uncond[0].shape} - {unpatchified_noise_pred_uncond[0].dtype} - {unpatchified_noise_pred_uncond[0].mean()} - {unpatchified_noise_pred_uncond[0].std()} - {unpatchified_noise_pred_uncond[0].norm()}") - - noise_preds = [] for i in range(batch_size): noise_pred = unpatchified_noise_pred_uncond[i] + guide_scale * ( unpatchified_noise_pred_cond[i] - unpatchified_noise_pred_uncond[i]) noise_preds.append(noise_pred) - # unpatchified_noise_pred_uncond = unpatchified_noise_pred_uncond[0] - # unpatchified_noise_pred_cond = unpatchified_noise_pred_cond[0] - - # noise_pred = unpatchified_noise_pred_uncond + guide_scale * ( - # 
unpatchified_noise_pred_cond - unpatchified_noise_pred_uncond) - - # # DEBUGGING - # # we will be running unpatchify here??? - # # x0 = latents - # if run_debug and torch.distributed.get_rank()==0: - # print(f"[DEBUG] [rank {torch.distributed.get_rank()}] (before unpatchify) noise_pred_cond.shape: {noise_pred_cond.shape}") - # print(f"[DEBUG] [rank {torch.distributed.get_rank()}] (before unpatchify) noise_pred_uncond.shape: {noise_pred_uncond.shape}") - # noise_pred_cond = noise_pred_cond.transpose(0, 1) - # noise_pred_cond = self.unpatchify(noise_pred_cond, grid_sizes, self.vae.model.z_dim) - # noise_pred_cond = noise_pred_cond.transpose(0, 1) - # noise_pred_uncond = noise_pred_uncond.transpose(0, 1) - # noise_pred_uncond = self.unpatchify(noise_pred_uncond, grid_sizes, self.vae.model.z_dim) - # noise_pred_uncond = noise_pred_uncond.transpose(0, 1) - # if run_debug and torch.distributed.get_rank()==0: - # print(f"[DEBUG] [rank {torch.distributed.get_rank()}] (after unpatchify) noise_pred_cond.shape: {noise_pred_cond.shape}") - # print(f"[DEBUG] [rank {torch.distributed.get_rank()}] (after unpatchify) noise_pred_uncond.shape: {noise_pred_uncond.shape}") - # print(stop_here) - - # # we run unpatchify here, but unpatchify should be run seprately for each sample in the batch, because the video shape is different for each sample in the batch. - # # ??? when batch_size > 1, we need to run sample_scheduler.step seprately for each sample in the batch. 
- # noise_pred = noise_pred.transpose(0, 1) # bring sbhd -> bshd - # noise_pred = self.unpatchify(noise_pred, grid_sizes, self.vae.model.z_dim) - - # print("[DEBUG] len(noise_pred): ", len(noise_pred)) - # print("[DEBUG] len(unpatchified_latents): ", len(unpatchified_latents)) - # print("[DEBUG] noise_pred[0].shape - noise_pred[0].dtype - noise_pred[0].mean() - noise_pred[0].std() - noise_pred[0].norm(): ", noise_pred[0].shape, noise_pred[0].dtype, noise_pred[0].mean(), noise_pred[0].std(), noise_pred[0].norm()) - # print("[DEBUG] unpatchified_latents[0].shape - unpatchified_latents[0].dtype - unpatchified_latents[0].mean() - unpatchified_latents[0].std() - unpatchified_latents[0].norm(): ", unpatchified_latents[0].shape, unpatchified_latents[0].dtype, unpatchified_latents[0].mean(), unpatchified_latents[0].std(), unpatchified_latents[0].norm()) - - # latents = [] - # for i in range(len(noise_pred)): - # temp_x0 = sample_scheduler.step( - # noise_pred[i].unsqueeze(0), - # t, - # unpatchified_latents[i].unsqueeze(0), - # return_dict=False, - # generator=seed_g)[0] - # latents.append(temp_x0.squeeze(0)) - - # print("len(latents): ", len(latents)) - # print("latents[0].shape: ", latents[0].shape) - - # latents = unpatchified_latents - # print(f"[DEBUG] noise_pred.shape - noise_pred.dtype - noise_pred.mean() - noise_pred.std() - noise_pred.norm(): {noise_pred.shape} - {noise_pred.dtype} - {noise_pred.mean()} - {noise_pred.std()} - {noise_pred.norm()}") - # print(f"[DEBUG] latents[0].shape - latents[0].dtype - latents[0].mean() - latents[0].std() - latents[0].norm(): {latents[0].shape} - {latents[0].dtype} - {latents[0].mean()} - {latents[0].std()} - {latents[0].norm()}") - # print(f"[DEBUG] noise_pred: {noise_pred}") - # print(f"[DEBUG] latents[0]: {latents[0]}") - - print("batch_size: ", batch_size) - # step and update latents latents = [] for i in range(batch_size): - # DEBUGGING - if run_debug and torch.distributed.get_rank()==0: - print("[DEBUG] 
len(unpatchified_latents): ", len(unpatchified_latents)) - print("[DEBUG] len(noise_preds): ", len(noise_preds)) - print("[DEBUG] unpatchified_latents[i].shape - unpatchified_latents[i].dtype - unpatchified_latents[i].mean() - unpatchified_latents[i].std() - unpatchified_latents[i].norm(): ", unpatchified_latents[i].shape, unpatchified_latents[i].dtype, unpatchified_latents[i].mean(), unpatchified_latents[i].std(), unpatchified_latents[i].norm()) - print("[DEBUG] noise_preds[i].shape - noise_preds[i].dtype - noise_preds[i].mean() - noise_preds[i].std() - noise_preds[i].norm(): ", noise_preds[i].shape, noise_preds[i].dtype, noise_preds[i].mean(), noise_preds[i].std(), noise_preds[i].norm()) - - if sample_solver == 'unipc': temp_x0 = schedulers[i].step( noise_preds[i].unsqueeze(0), @@ -688,25 +556,6 @@ def noop_no_sync(): generator=seed_g)[0] latents.append(temp_x0.squeeze(0)) - # # DEBUGGING - # # we will be running unpatchify here??? - # # x0 = latents - # x0 = self.unpatchify(latents, grid_sizes) - - # # loop through each sample in the batch - # videos = [] - # if offload_model: - # self.model.cpu() - # torch.cuda.empty_cache() - # x0 = latents - # if self.rank == 0: - # videos = self.vae.decode(x0) - - # DEBUGGING - print("[DEBUG] len(latents): ", len(latents)) - print("[DEBUG] latents[0].shape - latents[0].dtype - latents[0].mean() - latents[0].std() - latents[0].norm(): ", latents[0].shape, latents[0].dtype, latents[0].mean(), latents[0].std(), latents[0].norm()) - print("[DEBUG] latents[0]: ", latents[0]) - x0 = latents if offload_model: self.model.cpu() @@ -716,17 +565,6 @@ def noop_no_sync(): else: videos = None - - # # DEBUGGING - # print("len(latents): ", len(latents)) - # print("latents[0].shape - latents[0].dtype - latents[0].mean() - latents[0].std() - latents[0].norm(): ", latents[0].shape, latents[0].dtype, latents[0].mean(), latents[0].std(), latents[0].norm()) - # print("latents[0]: ", latents[0]) - # print("len(videos): ", len(videos)) - if videos 
is not None: - print("len(videos): ", len(videos)) - print("[DEBUG] videos[0].shape - videos[0].dtype - videos[0].mean() - videos[0].std() - videos[0].norm(): ", videos[0].shape, videos[0].dtype, videos[0].mean(), videos[0].std(), videos[0].norm()) - print("[DEBUG] videos[0]: ", videos[0]) - del noises, latents if sample_solver == 'unipc': del schedulers diff --git a/src/megatron/bridge/models/wan/inference/configs/shared_config.py b/src/megatron/bridge/models/wan/inference/configs/shared_config.py index 56a99ad433..04a9f45421 100644 --- a/src/megatron/bridge/models/wan/inference/configs/shared_config.py +++ b/src/megatron/bridge/models/wan/inference/configs/shared_config.py @@ -11,9 +11,7 @@ wan_shared_cfg.text_len = 512 # transformer -# DEBUGGING wan_shared_cfg.param_dtype = torch.bfloat16 -# wan_shared_cfg.param_dtype = torch.float32 # inference wan_shared_cfg.num_train_timesteps = 1000 diff --git a/src/megatron/bridge/models/wan/wan_layer_spec.py b/src/megatron/bridge/models/wan/wan_layer_spec.py index 3b014140cf..fdd4d9957f 100644 --- a/src/megatron/bridge/models/wan/wan_layer_spec.py +++ b/src/megatron/bridge/models/wan/wan_layer_spec.py @@ -128,7 +128,7 @@ def __init__( norm_config.normalization = "RMSNorm" self.q_layernorm = build_module( submodules.q_layernorm, - eps=1e-6, + eps=norm_config.layernorm_epsilon, hidden_size=q_layernorm_size, config=norm_config, ) @@ -146,7 +146,7 @@ def __init__( norm_config.normalization = "RMSNorm" self.k_layernorm = build_module( submodules.k_layernorm, - eps=1e-6, + eps=norm_config.layernorm_epsilon, hidden_size=k_layernorm_size, config=norm_config, ) @@ -268,7 +268,7 @@ def __init__( norm_config.normalization = "RMSNorm" self.q_layernorm = build_module( submodules.q_layernorm, - eps=1e-6, + eps=norm_config.layernorm_epsilon, hidden_size=q_layernorm_size, config=norm_config, ) @@ -286,7 +286,7 @@ def __init__( norm_config.normalization = "RMSNorm" self.k_layernorm = build_module( submodules.k_layernorm, - eps=1e-6, + 
eps=norm_config.layernorm_epsilon, hidden_size=k_layernorm_size, config=norm_config, ) @@ -441,19 +441,19 @@ def __init__( self.norm1 = build_module( submodules.norm1, dim=config.hidden_size, - eps=1e-6, + eps=config.layernorm_epsilon, elementwise_affine=False ) self.norm3 = build_module( submodules.norm3, dim=config.hidden_size, - eps=1e-6, + eps=config.layernorm_epsilon, elementwise_affine=True, ) self.norm2 = build_module( submodules.norm2, dim=config.hidden_size, - eps=1e-6, + eps=config.layernorm_epsilon, elementwise_affine=False, ) @@ -477,32 +477,6 @@ def forward( timestep_emb = attention_mask rope_emb = rotary_pos_emb - # DEBUGGING - run_debug = False - - # DEBUGGING - if run_debug and torch.distributed.get_rank()==0: - print("[DEBUG][WanLayerWithAdaLN] ================================") - print("[DEBUG][WanLayerWithAdaLN][forward_input] hidden_states.shape - hidden_states.dtype - hidden_states.mean() - hidden_states.std() - hidden_states.norm():", hidden_states.shape, hidden_states.dtype, hidden_states.mean(), hidden_states.std(), hidden_states.norm()) - print("[DEBUG][WanLayerWithAdaLN][forward_input] timestep_emb.shape - timestep_emb.dtype - timestep_emb.mean() - timestep_emb.std() - timestep_emb.norm():", timestep_emb.shape, timestep_emb.dtype, timestep_emb.mean(), timestep_emb.std(), timestep_emb.norm()) - print("[DEBUG][WanLayerWithAdaLN][forward_input] context.shape - context.dtype - context.mean() - context.std() - context.norm():", context.shape, context.dtype, context.mean(), context.std(), context.norm()) - if context_mask is not None: - print("[DEBUG][WanLayerWithAdaLN][forward_input] context_mask.shape - context_mask.dtype - context_mask.mean() - context_mask.std() - context_mask.norm():", context_mask.shape, context_mask.dtype, context_mask.mean(), context_mask.std(), context_mask.norm()) - if rotary_pos_emb is not None: - print("[DEBUG][WanLayerWithAdaLN][forward_input] rotary_pos_emb.shape - rotary_pos_emb.dtype - rotary_pos_emb.mean() - 
rotary_pos_emb.std() - rotary_pos_emb.norm():", rotary_pos_emb.shape, rotary_pos_emb.dtype, rotary_pos_emb.mean(), rotary_pos_emb.std(), rotary_pos_emb.norm()) - if rotary_pos_cos is not None: - print("[DEBUG][WanLayerWithAdaLN][forward_input] rotary_pos_cos.shape - rotary_pos_cos.dtype - rotary_pos_cos.mean() - rotary_pos_cos.std() - rotary_pos_cos.norm():", rotary_pos_cos.shape, rotary_pos_cos.dtype, rotary_pos_cos.mean(), rotary_pos_cos.std(), rotary_pos_cos.norm()) - if rotary_pos_sin is not None: - print("[DEBUG][WanLayerWithAdaLN][forward_input] rotary_pos_sin.shape - rotary_pos_sin.dtype - rotary_pos_sin.mean() - rotary_pos_sin.std() - rotary_pos_sin.norm():", rotary_pos_sin.shape, rotary_pos_sin.dtype, rotary_pos_sin.mean(), rotary_pos_sin.std(), rotary_pos_sin.norm()) - if attention_bias is not None: - print("[DEBUG][WanLayerWithAdaLN][forward_input] attention_bias.shape - attention_bias.dtype - attention_bias.mean() - attention_bias.std() - attention_bias.norm():", attention_bias.shape, attention_bias.dtype, attention_bias.mean(), attention_bias.std(), attention_bias.norm()) - if inference_params is not None: - print("[DEBUG][WanLayerWithAdaLN][forward_input] inference_params.shape - inference_params.dtype - inference_params.mean() - inference_params.std() - inference_params.norm():", inference_params.shape, inference_params.dtype, inference_params.mean(), inference_params.std(), inference_params.norm()) - if packed_seq_params is not None: - print("[DEBUG][WanLayerWithAdaLN][forward_input] packed_seq_params:", packed_seq_params) - if sequence_len_offset is not None: - print("[DEBUG][WanLayerWithAdaLN][forward_input] sequence_len_offset.shape - sequence_len_offset.dtype - sequence_len_offset.mean() - sequence_len_offset.std() - sequence_len_offset.norm():", sequence_len_offset.shape, sequence_len_offset.dtype, sequence_len_offset.mean(), sequence_len_offset.std(), sequence_len_offset.norm()) - shift_full, scale_full, gate_full, shift_mlp, scale_mlp, 
gate_mlp = self.adaLN(timestep_emb) # transpose to bring it to [1, b, ...] format shift_full = shift_full.transpose(0, 1) @@ -514,24 +488,6 @@ def forward( # ******************************************** full self attention ******************************************* - if run_debug and torch.distributed.get_rank()==0: - print("[DEBUG][WanLayerWithAdaLN] shift_full.shape - shift_full.dtype - shift_full.mean() - shift_full.std():", shift_full.shape, shift_full.dtype, float(shift_full.mean().item()), float(shift_full.std().item())) - print("[DEBUG][WanLayerWithAdaLN] scale_full.shape - scale_full.dtype - scale_full.mean() - scale_full.std():", scale_full.shape, scale_full.dtype, scale_full.mean(), scale_full.std()) - print("[DEBUG][WanLayerWithAdaLN] gate_full.shape - gate_full.dtype - gate_full.mean() - gate_full.std():", gate_full.shape, gate_full.dtype, gate_full.mean(), gate_full.std()) - print("[DEBUG][WanLayerWithAdaLN] shift_mlp.shape - shift_mlp.dtype - shift_mlp.mean() - shift_mlp.std():", shift_mlp.shape, shift_mlp.dtype, shift_mlp.mean(), shift_mlp.std()) - print("[DEBUG][WanLayerWithAdaLN] scale_mlp.shape - scale_mlp.dtype - scale_mlp.mean() - scale_mlp.std():", scale_mlp.shape, scale_mlp.dtype, scale_mlp.mean(), scale_mlp.std()) - print("[DEBUG][WanLayerWithAdaLN] gate_mlp.shape - gate_mlp.dtype - gate_mlp.mean() - gate_mlp.std():", gate_mlp.shape, gate_mlp.dtype, gate_mlp.mean(), gate_mlp.std()) - - # DEBUGGING - # if run_debug and torch.distributed.get_rank()==0: - if run_debug: - x_debug = hidden_states # DEBUGGING - print(f"[DEBUG][WanLayerWithAdaLN] [rank {torch.distributed.get_rank()}] hidden_states.shape - hidden_states.dtype - hidden_states.mean() - hidden_states.std():", hidden_states.shape, hidden_states.dtype, float(hidden_states.mean().item()), float(hidden_states.std().item())) - print(f"[DEBUG][WanLayerWithAdaLN] [rank {torch.distributed.get_rank()}] self.norm1(hidden_states).shape - self.norm1(hidden_states).dtype - 
self.norm1(hidden_states).mean() - self.norm1(hidden_states).std():", self.norm1(hidden_states).shape, self.norm1(hidden_states).dtype, float(self.norm1(hidden_states).mean().item()), float(self.norm1(hidden_states).std().item())) - print(f"[DEBUG][WanLayerWithAdaLN] [rank {torch.distributed.get_rank()}] shift_full.shape - shift_full.dtype - shift_full.mean() - shift_full.std():", shift_full.shape, shift_full.dtype, float(shift_full.mean().item()), float(shift_full.std().item())) - print(f"[DEBUG][WanLayerWithAdaLN] [rank {torch.distributed.get_rank()}] scale_full.shape - scale_full.dtype - scale_full.mean() - scale_full.std():", scale_full.shape, scale_full.dtype, float(scale_full.mean().item()), float(scale_full.std().item())) - - # adaLN with scale + shift + gate pre_full_attn_layernorm_output_ada = self.adaLN.modulate( self.norm1(hidden_states.float()), # Wan's LayerNorm implementation forward pass casts input to float32 @@ -553,15 +509,6 @@ def forward( with amp.autocast(dtype=torch.float32): hidden_states = self.adaLN.scale_add(residual=hidden_states, x=attention_output, gate=gate_full) - # DEBUGGING - if run_debug and torch.distributed.get_rank()==0: - print("[DEBUG][WanLayerWithAdaLN][self_attention] x_debug.shape - x_debug.dtype - x_debug.mean() - x_debug.std() - x.norm:", x_debug.shape, x_debug.dtype, x_debug.mean(), x_debug.std(), x_debug.norm()) - print("[DEBUG][WanLayerWithAdaLN][self_attention] pre_full_attn_layernorm_output_ada.shape - pre_full_attn_layernorm_output_ada.dtype - pre_full_attn_layernorm_output_ada.mean() - pre_full_attn_layernorm_output_ada.std() - pre_full_attn_layernorm_output_ada.norm:", pre_full_attn_layernorm_output_ada.shape, pre_full_attn_layernorm_output_ada.dtype, pre_full_attn_layernorm_output_ada.mean(), pre_full_attn_layernorm_output_ada.std(), pre_full_attn_layernorm_output_ada.norm()) - print("[DEBUG][WanLayerWithAdaLN][self_attention] attention_output.shape - attention_output.dtype - attention_output.mean() - 
attention_output.std() - attention_output.norm():", attention_output.shape, attention_output.dtype, attention_output.mean(), attention_output.std(), attention_output.norm()) - print("[DEBUG][WanLayerWithAdaLN][self_attention] gate_full.shape - gate_full.dtype - gate_full.mean() - gate_full.std() - gate_full.norm():", gate_full.shape, gate_full.dtype, gate_full.mean(), gate_full.std(), gate_full.norm()) - print("[DEBUG][WanLayerWithAdaLN][self_attention] hidden_states.shape - hidden_states.dtype - hidden_states.mean() - hidden_states.std() - hidden_states.norm():", hidden_states.shape, hidden_states.dtype, hidden_states.mean(), hidden_states.std(), hidden_states.norm()) - - # ******************************************** cross attention ****************************************************** attention_output, bias = self.cross_attention( @@ -575,11 +522,6 @@ def forward( hidden_states = hidden_states + attention_output - # DEBUGGING - if run_debug and torch.distributed.get_rank()==0: - print("[DEBUG][WanLayerWithAdaLN][cross_attention] attention_output.shape - attention_output.dtype - attention_output.mean() - attention_output.std() - attention_output.norm():", attention_output.shape, attention_output.dtype, attention_output.mean(), attention_output.std(), attention_output.norm()) - print("[DEBUG][WanLayerWithAdaLN][cross_attention] hidden_states.shape - hidden_states.dtype - hidden_states.mean() - hidden_states.std() - hidden_states.norm():", hidden_states.shape, hidden_states.dtype, hidden_states.mean(), hidden_states.std(), hidden_states.norm()) - # ******************************************** mlp ****************************************************** pre_mlp_layernorm_output_ada = self.adaLN.modulate( @@ -592,9 +534,6 @@ def forward( if bias is not None: mlp_output = mlp_output + bias - # DEBUGGING - print("self.mlp.activation_func:", self.mlp.activation_func) - with amp.autocast(dtype=torch.float32): hidden_states = self.adaLN.scale_add(residual=hidden_states, 
x=mlp_output, gate=gate_mlp) @@ -608,23 +547,6 @@ def forward( output = make_viewless_tensor(inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True) # output = hidden_states - # DEBUGGING - if run_debug and torch.distributed.get_rank()==0: - print("[DEBUG][WanLayerWithAdaLN][mlp] pre_mlp_layernorm_output_ada.shape - pre_mlp_layernorm_output_ada.dtype - pre_mlp_layernorm_output_ada.mean() - pre_mlp_layernorm_output_ada.std() - pre_mlp_layernorm_output_ada.norm():", pre_mlp_layernorm_output_ada.shape, pre_mlp_layernorm_output_ada.dtype, pre_mlp_layernorm_output_ada.mean(), pre_mlp_layernorm_output_ada.std(), pre_mlp_layernorm_output_ada.norm()) - print("[DEBUG][WanLayerWithAdaLN][mlp] mlp_output.shape - mlp_output.dtype - mlp_output.mean() - mlp_output.std() - mlp_output.norm():", mlp_output.shape, mlp_output.dtype, mlp_output.mean(), mlp_output.std(), mlp_output.norm()) - print("[DEBUG][WanLayerWithAdaLN][mlp] gate_mlp.shape - gate_mlp.dtype - gate_mlp.mean() - gate_mlp.std() - gate_mlp.norm():", gate_mlp.shape, gate_mlp.dtype, gate_mlp.mean(), gate_mlp.std(), gate_mlp.norm()) - print("[DEBUG][WanLayerWithAdaLN][mlp] hidden_states.shape - hidden_states.dtype - hidden_states.mean() - hidden_states.std() - hidden_states.norm():", hidden_states.shape, hidden_states.dtype, hidden_states.mean(), hidden_states.std(), hidden_states.norm()) - - # DEBUGGING - if run_debug: - hidden_states_concatenated = cat_outputs_cp(hidden_states, 0, parallel_state.get_context_parallel_group()) - if torch.distributed.get_rank()==0: - print("[DEBUG][WanLayerWithAdaLN][mlp] (after cat_outputs_cp) hidden_states_concatenated.shape - hidden_states_concatenated.dtype - hidden_states_concatenated.mean() - hidden_states_concatenated.std() - hidden_states_concatenated.norm():", hidden_states_concatenated.shape, hidden_states_concatenated.dtype, hidden_states_concatenated.mean(), hidden_states_concatenated.std(), hidden_states_concatenated.norm()) - - # # DEBUGGING - # if 
run_debug and torch.distributed.get_rank()==0: - # print(stop_here) - return output, context diff --git a/src/megatron/bridge/models/wan/wan_model.py b/src/megatron/bridge/models/wan/wan_model.py index adb2d6eaad..47662dbcc7 100644 --- a/src/megatron/bridge/models/wan/wan_model.py +++ b/src/megatron/bridge/models/wan/wan_model.py @@ -86,11 +86,7 @@ class WanModel(VisionModule): post_process (bool): Whether to apply post-processing steps. fp16_lm_cross_entropy (bool): Whether to use fp16 for cross-entropy loss. parallel_output (bool): Whether to use parallel output. - in_channels (int): Number of input channels. - out_channels (int): Number of output channels. transformer_decoder_layer_spec (WanLayerWithAdaLNspec): Specification for the transformer decoder layer. - add_encoder (bool): Whether to add an encoder. - add_decoder (bool): Whether to add a decoder. model_type (ModelType): Type of the model. """ @@ -101,8 +97,6 @@ def __init__( post_process: bool = True, fp16_lm_cross_entropy: bool = False, parallel_output: bool = True, - in_channels: int = 16, - out_channels: int = 16, transformer_decoder_layer_spec=WanLayerWithAdaLNspec, **kwargs, ): @@ -113,12 +107,8 @@ def __init__( self.transformer_decoder_layer_spec = transformer_decoder_layer_spec() self.pre_process = pre_process self.post_process = post_process - self.add_encoder = True - self.add_decoder = True self.fp16_lm_cross_entropy = fp16_lm_cross_entropy self.parallel_output = parallel_output - self.in_channels = in_channels - self.out_channels = out_channels # megatron core pipelining currently depends on model type # TODO: remove this dependency ? 
@@ -126,6 +116,8 @@ def __init__( self.num_heads = self.config.num_attention_heads self.freq_dim = self.config.freq_dim + self.in_channels = self.config.in_channels + self.out_channels = self.config.out_channels self.patch_spatial = self.config.patch_spatial self.patch_temporal = self.config.patch_temporal self.patch_size = (self.patch_temporal, self.patch_spatial, self.patch_spatial) @@ -189,32 +181,6 @@ def forward( ################################# ########## Wan forward ########## - # DEBUGGING - run_debug = False - - # DEBUGGING - if run_debug and torch.distributed.get_rank()==0: - print("[DEBUG] state_dict keys:") - for k, v in self.state_dict().items(): - if "_extra_state" in k: - continue - if hasattr(v, "shape"): - print(f"[DEBUG] {k} | shape - dtype - mean - std - norm: {tuple(v.shape)} - {v.dtype} - {v.mean().item()} - {v.std().item()} - {v.norm().item()}") - else: - print(f"[DEBUG] {k}") - print("\n\n\n") - - - # DEBUGGING - if run_debug and torch.distributed.get_rank()==0: - print("[DEBUG] [WanModel forward] x.shape - x.dtype - x.mean() - x.std() - x.norm(): ", x.shape, x.dtype, x.mean(), x.std(), x.norm()) - print("[DEBUG] [WanModel forward] grid_sizes: ", grid_sizes) - print("[DEBUG] [WanModel forward] t: ", t) - print("[DEBUG] [WanModel forward] context.shape - context.dtype - context.mean() - context.std() - context.norm(): ", context.shape, context.dtype, context.mean(), context.std(), context.norm()) - print("[DEBUG] [WanModel forward] max_seq_len: ", max_seq_len) - print("[DEBUG] [WanModel forward] packed_seq_params: ", packed_seq_params) - - # ============= embedders ============= # run input embedding @@ -237,11 +203,6 @@ def forward( # intermediate stage of pipeline x = self.decoder.input_tensor - # DEBUGGING - if run_debug and torch.distributed.get_rank()==0: - print("[DEBUG] [WanModel forward] (after patch_embedding) x.shape - x.dtype - x.mean() - x.std() - x.norm(): ", x.shape, x.dtype, x.mean(), x.std(), x.norm()) - print("[DEBUG] 
[WanModel forward] (after patch_embedding) x:", x) - # time embeddings with amp.autocast(dtype=torch.float32): e = self.time_embedding( @@ -258,14 +219,6 @@ def forward( n_head, dim_head = self.num_heads, self.config.hidden_size // self.num_heads rotary_pos_emb = self.rope_embeddings(n_head, dim_head, max_seq_len, grid_sizes, t.device) # output: rotary_pos_emb.shape [s, b, 1, dim_head] - # DEBUGGING - if run_debug and torch.distributed.get_rank()==0: - print("[DEBUG] [WanModel forward] (before self.decoder) x.shape - x.dtype - x.mean() - x.std() - x.norm(): ", x.shape, x.dtype, x.mean(), x.std(), x.norm()) - print("[DEBUG] [WanModel forward] (before self.decoder) context.shape - context.dtype - context.mean() - context.std() - context.norm(): ", context.shape, context.dtype, context.mean(), context.std(), context.norm()) - print("[DEBUG] [WanModel forward] (before self.decoder) e0.shape - e0.dtype - e0.mean() - e0.std() - e0.norm(): ", e0.shape, e0.dtype, e0.mean(), e0.std(), e0.norm()) - print("[DEBUG] [WanModel forward] (before self.decoder) rotary_pos_emb.shape - rotary_pos_emb.dtype - rotary_pos_emb.mean() - rotary_pos_emb.std() - rotary_pos_emb.norm(): ", rotary_pos_emb.shape, rotary_pos_emb.dtype, rotary_pos_emb.mean(), rotary_pos_emb.std(), rotary_pos_emb.norm()) - print("[DEBUG] [WanModel forward] (before self.decoder) packed_seq_params: ", packed_seq_params) - # run decoder x = self.decoder( hidden_states=x, @@ -278,10 +231,6 @@ def forward( packed_seq_params=packed_seq_params, ) - # DEBUGGING - if run_debug and torch.distributed.get_rank()==0: - print("[DEBUG] [WanModel forward] (after self.decoder) x.shape - x.dtype - x.mean() - x.std() - x.norm(): ", x.shape, x.dtype, x.mean(), x.std(), x.norm()) - # return if not post_process if not self.post_process: return x @@ -298,10 +247,6 @@ def forward( if self.config.sequence_parallel: x = tensor_parallel.gather_from_sequence_parallel_region(x) - # DEBUGGING - if run_debug and torch.distributed.get_rank()==0: - 
print("[DEBUG] [WanModel forward] (after self.head) x.shape - x.dtype - x.mean() - x.std() - x.norm(): ", x.shape, x.dtype, x.mean(), x.std(), x.norm()) - return x # output: x.shape [s, b, c * pF * pH * pW] diff --git a/src/megatron/bridge/models/wan/wan_provider.py b/src/megatron/bridge/models/wan/wan_provider.py index 0003761f5e..de7487f3ac 100644 --- a/src/megatron/bridge/models/wan/wan_provider.py +++ b/src/megatron/bridge/models/wan/wan_provider.py @@ -12,27 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib -import inspect import logging -from dataclasses import dataclass, field -from functools import partial -from typing import Any, Callable, Dict, Literal, Optional, Union +from dataclasses import dataclass import torch from megatron.core import parallel_state -from megatron.core.models.gpt import GPTModel as MCoreGPTModel -from megatron.core.models.gpt.gpt_layer_specs import ( - get_gpt_layer_local_spec, - get_gpt_layer_with_transformer_engine_spec, -) -from megatron.core.transformer import ModuleSpec from megatron.bridge.models.transformer_config import TransformerConfig -from megatron.bridge.models.DiTModel.dit_utils import dynamic_import from megatron.bridge.models.model_provider import ModelProviderMixin -from megatron.bridge.utils import fusions -from megatron.bridge.utils.vocab_utils import calculate_padded_vocab_size from megatron.core.models.common.vision_module.vision_module import VisionModule from megatron.bridge.models.wan.wan_model import WanModel @@ -47,53 +34,29 @@ class WanModelProvider(TransformerConfig, ModelProviderMixin[VisionModule]): num_layers: int = 30 hidden_size: int = 1536 ffn_hidden_size: int = 8960 - max_img_h: int = 80 - max_img_w: int = 80 - max_frames: int = 34 - patch_spatial: int = 2 - patch_temporal: int = 1 num_attention_heads: int = 12 - layernorm_epsilon = 1e-6 - normalization = "RMSNorm" - qk_layernorm_per_head: bool = False - 
layernorm_zero_centered_gamma = False - - fp16_lm_cross_entropy: bool = False - parallel_output: bool = True - share_embeddings_and_output_weights: bool = True - + layernorm_epsilon: float = 1e-6 + normalization: str = "RMSNorm" + layernorm_zero_centered_gamma: bool = False hidden_dropout: float = 0 attention_dropout: float = 0 - + fp16_lm_cross_entropy: bool = False + parallel_output: bool = True bf16: bool = False params_dtype: torch.dtype = torch.float32 + qkv_format: str = 'sbhd' + # these attributes are unused for images/videos, we just set because bridge training requires for LLMs + seq_length: int = 1024 + share_embeddings_and_output_weights: bool = False - vae_module: str = "nemo_vfm.diffusion.vae.diffusers_vae.AutoencoderKLVAE" - vae_path: str = None - sigma_data: float = 0.5 - + # images/videos attributes in_channels: int = 16 out_channels: int = 16 - - replicated_t_embedder = True - qkv_format: str = 'sbhd' - - # DEBUGGING - # adding more attributes - text_dim: int = 4096 - patch_size: list = field(default_factory=lambda: [1, 2, 2]) + patch_spatial: int = 2 + patch_temporal: int = 1 freq_dim: int = 256 - out_dim: int = 16 - text_len: int = 512 - - - - # DEBUGGING - # unused, we just set because bridge training requires this for LLMs - seq_length: int = 1024 - vocab_size: int = None - make_vocab_size_divisible_by: int = 128 - + text_len: int = 512 + text_dim: int = 4096 def provide(self, pre_process=None, post_process=None, vp_stage=None) -> WanModel: vp_size = self.virtual_pipeline_model_parallel_size @@ -107,15 +70,8 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> WanMode return model( self, - fp16_lm_cross_entropy=self.fp16_lm_cross_entropy, - parallel_output=self.parallel_output, pre_process=parallel_state.is_pipeline_first_stage(), post_process=parallel_state.is_pipeline_last_stage(), - max_img_h=self.max_img_h, - max_img_w=self.max_img_w, - max_frames=self.max_frames, - patch_spatial=self.patch_spatial, - ) - - def 
configure_vae(self): - return dynamic_import(self.vae_module)(self.vae_path) \ No newline at end of file + fp16_lm_cross_entropy=self.fp16_lm_cross_entropy, + parallel_output=self.parallel_output, + ) \ No newline at end of file From e41b3d12d1566503cbd1c6200a9aef01e2be59d4 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 29 Oct 2025 20:10:34 -0700 Subject: [PATCH 03/17] workable model implementation, inference, finetuning --- .../conversion/convert_wan_checkpoints.py | 20 + examples/recipes/wan/inference_wan.py | 43 +- examples/recipes/wan/pretrain_wan.py | 184 ++++++++ src/megatron/bridge/data/loaders.py | 6 +- .../data/wan/prepare_energon_dataset_wan.py | 404 ++++++++++++++++++ .../bridge/data/wan/wan_energon_datamodule.py | 47 ++ .../bridge/data/wan/wan_taskencoder.py | 190 ++++++++ .../bridge/models/conversion/param_mapping.py | 152 +++++++ .../bridge/models/hf_pretrained/__init__.py | 3 +- .../bridge/models/hf_pretrained/state.py | 12 +- .../bridge/models/hf_pretrained/wan.py | 52 +++ .../flow_matching/flow_inference_pipeline.py | 149 +++---- .../models/wan/flow_matching/flow_pipeline.py | 305 ++++++------- .../wan/flow_matching/time_shift_utils.py | 108 +++++ .../models/wan/inference/configs/__init__.py | 1 - .../wan/inference/configs/shared_config.py | 1 - .../wan/inference/configs/wan_i2v_14B.py | 1 - .../wan/inference/configs/wan_t2v_14B.py | 1 - .../wan/inference/configs/wan_t2v_1_3B.py | 1 - .../models/wan/inference/utils/fm_solvers.py | 1 - .../wan/inference/utils/fm_solvers_unipc.py | 1 - .../models/wan/inference/utils/utils.py | 1 - src/megatron/bridge/models/wan/modules/t5.py | 1 - .../bridge/models/wan/modules/tokenizers.py | 1 - src/megatron/bridge/models/wan/modules/vae.py | 1 - src/megatron/bridge/models/wan/rope_utils.py | 8 +- src/megatron/bridge/models/wan/utils/utils.py | 128 ++++++ src/megatron/bridge/models/wan/wan_bridge.py | 30 -- .../bridge/models/wan/wan_layer_spec.py | 29 +- src/megatron/bridge/models/wan/wan_model.py | 22 +- 
.../bridge/models/wan/wan_provider.py | 4 + src/megatron/bridge/models/wan/wan_step.py | 85 +--- src/megatron/bridge/recipes/wan/wan.py | 219 ++++++++++ 33 files changed, 1798 insertions(+), 413 deletions(-) create mode 100644 examples/conversion/convert_wan_checkpoints.py create mode 100644 examples/recipes/wan/pretrain_wan.py create mode 100644 src/megatron/bridge/data/wan/prepare_energon_dataset_wan.py create mode 100644 src/megatron/bridge/data/wan/wan_energon_datamodule.py create mode 100644 src/megatron/bridge/data/wan/wan_taskencoder.py create mode 100644 src/megatron/bridge/models/hf_pretrained/wan.py create mode 100644 src/megatron/bridge/models/wan/flow_matching/time_shift_utils.py create mode 100644 src/megatron/bridge/models/wan/utils/utils.py create mode 100644 src/megatron/bridge/recipes/wan/wan.py diff --git a/examples/conversion/convert_wan_checkpoints.py b/examples/conversion/convert_wan_checkpoints.py new file mode 100644 index 0000000000..4594ebaa5e --- /dev/null +++ b/examples/conversion/convert_wan_checkpoints.py @@ -0,0 +1,20 @@ +from megatron.bridge.models.hf_pretrained.wan import PreTrainedWAN +from megatron.bridge.models.wan.wan_bridge import WanBridge +from megatron.bridge.training.model_load_save import save_megatron_model +import os, random +os.environ["MASTER_ADDR"] = "127.0.0.1" +os.environ["MASTER_PORT"] = str(29500 + random.randint(0, 1000)) +os.environ["RANK"] = "0" +os.environ["WORLD_SIZE"] = "1" +os.environ["LOCAL_RANK"] = "0" +# +# hf = PreTrainedWAN("Wan-AI/Wan2.1-T2V-1.3B-Diffusers") +hf = PreTrainedWAN("Wan-AI/Wan2.1-T2V-14B-Diffusers") +bridge = WanBridge() +# +provider = bridge.provider_bridge(hf) +provider.perform_initialization = False +megatron_models = provider.provide_distributed_model(wrap_with_ddp=False, use_cpu_initialization=True) +# +bridge.load_weights_hf_to_megatron(hf, megatron_models) +save_megatron_model(megatron_models, "/opt/megatron_checkpoint", hf_tokenizer_path=None) \ No newline at end of file diff --git 
a/examples/recipes/wan/inference_wan.py b/examples/recipes/wan/inference_wan.py index 8edd890f9c..61f38ecdea 100644 --- a/examples/recipes/wan/inference_wan.py +++ b/examples/recipes/wan/inference_wan.py @@ -1,9 +1,10 @@ -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. # Example of running script for Wan inference. # NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 examples/recipes/wan/inference_wan.py \ # --task t2v-1.3B \ # --sizes 480*832 \ -# --ckpt_dir /path/to/wan_checkpoints \ +# --checkpoint_dir /path/to/wan_checkpoint_dir \ +# --t5_checkpoint_dir /path/to/t5_checkpoint_dir \ +# --vae_checkpoint_dir /path/to/vae_checkpoint_dir \ # --frame_nums 81 \ # --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ # --tensor_parallel_size 1 \ @@ -32,11 +33,6 @@ from megatron.bridge.models.wan.inference.configs import SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS from megatron.bridge.models.wan.inference.utils.utils import cache_video, str2bool -# DEBUGGING -import numpy as np -np.set_printoptions(precision=10, suppress=False) -torch.set_printoptions(precision=6, sci_mode=False) - EXAMPLE_PROMPT = { "t2v-1.3B": { "prompt": @@ -51,7 +47,9 @@ def _validate_args(args): # Basic check - assert args.ckpt_dir is not None, "Please specify the checkpoint directory." + assert args.checkpoint_dir is not None, "Please specify the checkpoint directory." + assert args.t5_checkpoint_dir is not None, "Please specify the T5 checkpoint directory." + assert args.vae_checkpoint_dir is not None, "Please specify the VAE checkpoint directory." assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}" assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}" @@ -90,7 +88,7 @@ def _parse_args(): nargs="+", default=None, choices=list(SIZE_CONFIGS.keys()), - help="A list of sizes to generate multiple images or videos. 
Example: --sizes 1280*720 1920*1080" + help="A list of sizes to generate multiple images or videos (WIDTH*HEIGHT). Example: --sizes 1280*720 1920*1080" ) parser.add_argument( "--frame_nums", @@ -100,10 +98,28 @@ def _parse_args(): help="List of frame counts (each should be 4n+1). Broadcasts if single value." ) parser.add_argument( - "--ckpt_dir", + "--checkpoint_dir", + type=str, + default=None, + help="The path to the main WAN checkpoint directory.") + parser.add_argument( + "--checkpoint_step", + type=int, + default=None, + help=( + "Optional training step to load, e.g. 1800 -> iter_0001800. " + "If not provided, the latest (largest) step in --checkpoint_dir is used.") + ) + parser.add_argument( + "--t5_checkpoint_dir", + type=str, + default=None, + help="Optional directory containing T5 checkpoint/tokenizer") + parser.add_argument( + "--vae_checkpoint_dir", type=str, default=None, - help="The path to the checkpoint directory.") + help="Optional directory containing VAE checkpoint") parser.add_argument( "--offload_model", type=str2bool, @@ -246,7 +262,10 @@ def generate(args): logging.info("Creating flow inference pipeline.") pipeline = FlowInferencePipeline( config=cfg, - checkpoint_dir=args.ckpt_dir, + checkpoint_dir=args.checkpoint_dir, + checkpoint_step=args.checkpoint_step, + t5_checkpoint_dir=args.t5_checkpoint_dir, + vae_checkpoint_dir=args.vae_checkpoint_dir, device_id=device, rank=rank, t5_cpu=args.t5_cpu, diff --git a/examples/recipes/wan/pretrain_wan.py b/examples/recipes/wan/pretrain_wan.py new file mode 100644 index 0000000000..d6a492f655 --- /dev/null +++ b/examples/recipes/wan/pretrain_wan.py @@ -0,0 +1,184 @@ + +#!/usr/bin/env python3 +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Wan Pretraining Script with YAML and CLI Configuration Overrides. + +This script provides a flexible way to pretrain Wan models using Megatron-Bridge with support for +both YAML configuration files and command-line overrides using Hydra-style syntax. + +Examples: + Basic usage with default configuration: + $ torchrun --nproc_per_node=8 pretrain_wan.py + + Using a custom YAML config file: + $ torchrun --nproc_per_node=8 pretrain_wan.py --config-file my_custom_config.yaml + + Using CLI overrides only: + $ torchrun --nproc_per_node=8 pretrain_wan.py model.tensor_model_parallel_size=4 train.train_iters=100000 + + Combining YAML and CLI overrides (CLI takes precedence): + $ torchrun --nproc_per_node=8 pretrain_wan.py --config-file conf/my_config.yaml \ + model.pipeline_dtype=torch.float16 \ + train.global_batch_size=512 + +Configuration Precedence: + 1. Base configuration from pretrain_config() recipe + 2. YAML overrides from --config-file (if provided) + 3. CLI overrides (highest precedence) + +Supported Override Syntax: + - Standard assignment: key=value + - Nested assignment: section.subsection.key=value + - Addition: +new_key=value + - Deletion: ~key_to_remove + - Type conversion: Automatic for basic types (int, float, bool, str) + - Complex types: torch.dtype, enums, etc. 
are supported +""" + +import argparse +import logging +import os +import sys +from pathlib import Path +from typing import Tuple + +from omegaconf import OmegaConf + +from megatron.bridge.recipes.wan.wan import pretrain_config +from megatron.bridge.training.config import ConfigContainer +from megatron.bridge.models.wan.wan_step import WanForwardStep +from megatron.bridge.training.pretrain import pretrain +from megatron.bridge.training.utils.omegaconf_utils import ( + apply_overrides, + create_omegaconf_dict_config, + parse_hydra_overrides, +) +from megatron.bridge.utils.common_utils import get_rank_safe + + +logger: logging.Logger = logging.getLogger(__name__) + + +# Define paths relative to this script's location +# Assumes this script (pretrain_wan.py) is in Megatron-Bridge/examples/recipes/wan/ +# and the config is in a 'conf' subdirectory. +SCRIPT_DIR: Path = Path(__file__).parent.resolve() +DEFAULT_CONFIG_FILENAME: str = "wan_pretrain_override_example.yaml" +DEFAULT_CONFIG_FILE_PATH: Path = SCRIPT_DIR / "conf" / DEFAULT_CONFIG_FILENAME + +# DEBUGGING +import numpy as np +import torch +np.set_printoptions(precision=10, suppress=False) +torch.set_printoptions(precision=10, sci_mode=False) + +def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]: + """Parse command line arguments, separating known script args from OmegaConf overrides.""" + parser = argparse.ArgumentParser( + description="Pretrain Wan model using Megatron-Bridge with YAML and CLI overrides", + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + "--config-file", + type=str, + default=str(DEFAULT_CONFIG_FILE_PATH), + help="Path to the YAML OmegaConf override file. 
Default: conf/wan_pretrain_override_example.yaml", + ) + parser.add_argument("--debug", action="store_true", help="Enable debug logging") + + # Parse known args for the script, remaining will be treated as overrides + args, cli_dotlist_overrides = parser.parse_known_args() + return args, cli_dotlist_overrides + + +def main() -> None: + """ + Entry point for the Wan pretraining script. + + This function orchestrates the complete configuration workflow: + 1. Loads the base configuration from pretrain_config() recipe + 2. Applies YAML overrides from --config-file (if exists) + 3. Applies CLI overrides using Hydra-style syntax + 4. Starts Megatron pretraining with the final merged configuration + + Configuration merging preserves callable fields (like activation functions) + and handles type conversions automatically. + + Examples of CLI usage: + # Use default config with custom learning rate + torchrun --nproc_per_node=8 pretrain_wan.py optimizer.lr=0.0002 + + # Custom config file with additional overrides + torchrun --nproc_per_node=8 pretrain_wan.py --config-file my_config.yaml train.train_iters=50000 + + # Multiple overrides for distributed training + torchrun --nproc_per_node=8 pretrain_wan.py \ + model.tensor_model_parallel_size=4 \ + model.pipeline_model_parallel_size=2 \ + train.global_batch_size=512 + """ + args, cli_overrides = parse_cli_args() + + logger.info("Megatron-Bridge Wan Pretraining Script with YAML & CLI Overrides") + logger.info("------------------------------------------------------------------") + + # Load base configuration from the recipe as a Python dataclass + cfg: ConfigContainer = pretrain_config() + logger.info("Loaded base configuration") + + # Print configuration on rank 0 + if get_rank_safe() == 0: + cfg.print_yaml() + + # Convert the initial Python dataclass to an OmegaConf DictConfig for merging + merged_omega_conf, excluded_fields = create_omegaconf_dict_config(cfg) + + # Load and merge YAML overrides if a config file is provided + 
if args.config_file: + logger.debug(f"Loading YAML overrides from: {args.config_file}") + if not os.path.exists(args.config_file): + logger.error(f"Override YAML file not found: {args.config_file}") + sys.exit(1) + yaml_overrides_omega = OmegaConf.load(args.config_file) + merged_omega_conf = OmegaConf.merge(merged_omega_conf, yaml_overrides_omega) + logger.debug("YAML overrides merged successfully.") + + # Apply command-line overrides using Hydra-style parsing + if cli_overrides: + logger.debug(f"Applying Hydra-style command-line overrides: {cli_overrides}") + merged_omega_conf = parse_hydra_overrides(merged_omega_conf, cli_overrides) + logger.debug("Hydra-style command-line overrides applied successfully.") + + # Apply the final merged OmegaConf configuration back to the original ConfigContainer + logger.debug("Applying final merged configuration back to Python ConfigContainer...") + final_overrides_as_dict = OmegaConf.to_container(merged_omega_conf, resolve=True) + # Apply overrides while preserving excluded fields + apply_overrides(cfg, final_overrides_as_dict, excluded_fields) + + # Display final configuration + if get_rank_safe() == 0: + logger.info("--- Final Merged Configuration ---") + cfg.print_yaml() + logger.info("----------------------------------") + + # Start training + logger.debug("Starting pretraining...") + pretrain(config=cfg, forward_step_func=WanForwardStep()) + + +if __name__ == "__main__": + main() diff --git a/src/megatron/bridge/data/loaders.py b/src/megatron/bridge/data/loaders.py index 6c3aeda95c..7d45114436 100644 --- a/src/megatron/bridge/data/loaders.py +++ b/src/megatron/bridge/data/loaders.py @@ -219,7 +219,11 @@ def worker_init_fn(_): valid_dataloader = build_pretraining_data_loader( valid_ds, train_state.consumed_valid_samples, - "cyclic", + # DEBUGGING + # known issue: + # https://nvidia.slack.com/archives/C09MX7UEB0W/p1761316355203679 + # "cyclic", + "external", cfg.train.micro_batch_size, cfg.dataset.num_workers, 
cfg.dataset.data_sharding, diff --git a/src/megatron/bridge/data/wan/prepare_energon_dataset_wan.py b/src/megatron/bridge/data/wan/prepare_energon_dataset_wan.py new file mode 100644 index 0000000000..a8464aa6ec --- /dev/null +++ b/src/megatron/bridge/data/wan/prepare_energon_dataset_wan.py @@ -0,0 +1,404 @@ +import os +import json +import pickle +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import cv2 +import numpy as np +import torch +import webdataset as wds + +from diffusers import AutoencoderKLWan +from transformers import AutoTokenizer, UMT5EncoderModel + + +def _map_interpolation(resize_mode: str) -> int: + interpolation_map = { + "bilinear": cv2.INTER_LINEAR, + "bicubic": cv2.INTER_CUBIC, + "nearest": cv2.INTER_NEAREST, + "area": cv2.INTER_AREA, + "lanczos": cv2.INTER_LANCZOS4, + } + if resize_mode not in interpolation_map: + raise ValueError(f"Invalid resize_mode '{resize_mode}'. Choose from: {list(interpolation_map.keys())}") + return interpolation_map[resize_mode] + + +def _calculate_resize_dimensions( + original_height: int, + original_width: int, + target_size: Optional[Tuple[int, int]], + maintain_aspect_ratio: bool, +) -> Tuple[int, int]: + if target_size is None: + return original_height, original_width + + target_height, target_width = target_size + if not maintain_aspect_ratio: + return target_height, target_width + + original_aspect = original_width / max(1, original_height) + target_aspect = target_width / max(1, target_height) + + if original_aspect > target_aspect: + new_width = target_width + new_height = int(round(target_width / max(1e-6, original_aspect))) + else: + new_height = target_height + new_width = int(round(target_height * original_aspect)) + + return new_height, new_width + + +def _resize_frame( + frame: np.ndarray, + target_size: Optional[Tuple[int, int]], + resize_mode: str, + maintain_aspect_ratio: bool, + center_crop: bool, +) -> np.ndarray: + if target_size is None: + return frame + + 
original_height, original_width = frame.shape[:2] + resize_height, resize_width = _calculate_resize_dimensions( + original_height, original_width, target_size, maintain_aspect_ratio + ) + + interpolation = _map_interpolation(resize_mode) + resized_frame = cv2.resize(frame, (resize_width, resize_height), interpolation=interpolation) + + if maintain_aspect_ratio and center_crop: + target_height, target_width = target_size + if resize_height != target_height or resize_width != target_width: + y_start = max(0, (resize_height - target_height) // 2) + x_start = max(0, (resize_width - target_width) // 2) + y_end = min(resize_height, y_start + target_height) + x_end = min(resize_width, x_start + target_width) + resized_frame = resized_frame[y_start:y_end, x_start:x_end] + + if resized_frame.shape[0] < target_height or resized_frame.shape[1] < target_width: + pad_height = max(0, target_height - resized_frame.shape[0]) + pad_width = max(0, target_width - resized_frame.shape[1]) + resized_frame = np.pad( + resized_frame, ((0, pad_height), (0, pad_width), (0, 0)), mode="constant", constant_values=0 + ) + + return resized_frame + + +def _read_sidecar_caption(jsonl_path: Path) -> str: + if not jsonl_path.exists(): + return "" + try: + with open(jsonl_path, "r") as f: + for line in f: + line = line.strip() + if not line: + continue + try: + obj = json.loads(line) + except Exception: + continue + # Prefer keys used across datasets + for key in ("vila_caption", "gemini_v2_caption", "caption", "text"): + if key in obj and isinstance(obj[key], str): + return obj[key] + # If no known key, try first string value + for v in obj.values(): + if isinstance(v, str): + return v + break + except Exception: + return "" + return "" + + +def _get_total_frames(video_path: str) -> int: + cap = cv2.VideoCapture(video_path) + total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + cap.release() + return max(0, total) + + +def _load_metadata(video_folder: Path) -> List[Dict]: + meta_path = video_folder / 
"meta.json" + if meta_path.exists(): + with open(meta_path, "r") as f: + return json.load(f) + + # Fallback: scan for .mp4 files with sidecar .jsonl; use full frame range + items: List[Dict] = [] + for entry in sorted(video_folder.iterdir()): + if not entry.is_file(): + continue + if entry.suffix.lower() != ".mp4": + continue + video_name = entry.name + video_path = str(entry) + total_frames = _get_total_frames(video_path) + start_frame = 0 + end_frame = max(0, total_frames - 1) + sidecar = entry.with_suffix("") + # Handle names with additional dots gracefully + sidecar_jsonl = Path(str(entry).rsplit(".", 1)[0] + ".jsonl") + caption = _read_sidecar_caption(sidecar_jsonl) + items.append( + { + "file_name": video_name, + "start_frame": start_frame, + "end_frame": end_frame, + "vila_caption": caption, + } + ) + if not items: + raise FileNotFoundError(f"No meta.json and no .mp4 files found in {video_folder}") + return items + + +def _load_frames_cv2( + video_path: str, + start_frame: int, + end_frame: int, + target_size: Optional[Tuple[int, int]], + resize_mode: str, + maintain_aspect_ratio: bool, + center_crop: bool, + target_dtype: torch.dtype, +) -> torch.Tensor: + cap = cv2.VideoCapture(video_path) + frames: List[np.ndarray] = [] + + cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame) + for frame_idx in range(start_frame, end_frame + 1): + ret, frame = cap.read() + if not ret: + break + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame = _resize_frame(frame, target_size, resize_mode, maintain_aspect_ratio, center_crop) + frame = frame.astype(np.float32) / 255.0 + frames.append(frame) + cap.release() + + if not frames: + raise ValueError(f"No frames loaded from {video_path}") + + video_array = np.array(frames) # T, H, W, C in [0,1] + video_tensor = torch.from_numpy(video_array) # T, H, W, C + video_tensor = video_tensor.permute(3, 0, 1, 2).unsqueeze(0) # 1, C, T, H, W + video_tensor = video_tensor.to(dtype=target_dtype) + return video_tensor + + +@torch.no_grad() 
+def _init_hf_models( + model_id: str, + device: str, + enable_memory_optimization: bool, +): + dtype = torch.float16 if device.startswith("cuda") else torch.float32 + + text_encoder = UMT5EncoderModel.from_pretrained( + model_id, + subfolder="text_encoder", + torch_dtype=dtype, + ) + text_encoder.to(device) + text_encoder.eval() + + vae = AutoencoderKLWan.from_pretrained( + model_id, + subfolder="vae", + torch_dtype=dtype, + ) + vae.to(device) + vae.eval() + if enable_memory_optimization: + vae.enable_slicing() + vae.enable_tiling() + + tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer") + + return vae, text_encoder, tokenizer, dtype + + +@torch.no_grad() +def _encode_text( + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + device: str, + caption: str, +) -> torch.Tensor: + caption = caption.strip() + inputs = tokenizer( + caption, + max_length=512, + padding="max_length", + truncation=True, + return_tensors="pt", + return_attention_mask=True, + ) + inputs = {k: v.to(device) for k, v in inputs.items()} + outputs = text_encoder(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]).last_hidden_state + return outputs + + +@torch.no_grad() +def _encode_video_latents( + vae: AutoencoderKLWan, + device: str, + video_tensor: torch.Tensor, + deterministic_latents: bool, +) -> torch.Tensor: + video_tensor = video_tensor.to(device=device, dtype=vae.dtype) + video_tensor = video_tensor * 2.0 - 1.0 # [0,1] -> [-1,1] + + latent_dist = vae.encode(video_tensor) + if deterministic_latents: + video_latents = latent_dist.latent_dist.mean + else: + video_latents = latent_dist.latent_dist.sample() + + latent_mean = video_latents.mean().item() + latent_std = video_latents.std().item() + + if abs(latent_mean) < 0.5 and 0.5 < latent_std < 2.0: + final_latents = video_latents + else: + if not hasattr(vae.config, "latents_mean") or not hasattr(vae.config, "latents_std"): + raise ValueError("Wan2.1 VAE requires latents_mean and 
latents_std in config") + latents_mean = torch.tensor(vae.config.latents_mean, device=device, dtype=vae.dtype).view(1, -1, 1, 1, 1) + latents_std = torch.tensor(vae.config.latents_std, device=device, dtype=vae.dtype).view(1, -1, 1, 1, 1) + final_latents = (video_latents - latents_mean) / latents_std + + return final_latents + + +def main(): + import argparse + + parser = argparse.ArgumentParser( + description="Prepare WAN WebDataset shards using HF automodel encoders and resizing" + ) + parser.add_argument("--video_folder", type=str, required=True, help="Folder containing videos and meta.json") + parser.add_argument("--output_dir", type=str, required=True, help="Directory to write webdataset shards") + parser.add_argument( + "--model", + default="Wan-AI/Wan2.1-T2V-14B-Diffusers", + help="Wan2.1 model ID (e.g., Wan-AI/Wan2.1-T2V-14B-Diffusers or Wan-AI/Wan2.1-T2V-1.3B-Diffusers)", + ) + parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use") + parser.add_argument( + "--stochastic", + action="store_true", + help="Use stochastic encoding (sampling) instead of deterministic posterior mean", + ) + parser.add_argument("--no-memory-optimization", action="store_true", help="Disable VAE slicing/tiling") + parser.add_argument("--shard_maxcount", type=int, default=10000, help="Max samples per shard") + + # Resize arguments (match automodel) + parser.add_argument("--height", type=int, default=None, help="Target height for video frames") + parser.add_argument("--width", type=int, default=None, help="Target width for video frames") + parser.add_argument( + "--resize_mode", + default="bilinear", + choices=["bilinear", "bicubic", "nearest", "area", "lanczos"], + help="Interpolation mode for resizing", + ) + parser.add_argument("--no-aspect-ratio", action="store_true", help="Disable aspect ratio preservation") + parser.add_argument("--center-crop", action="store_true", help="Center crop to exact target size after resize") + + 
args = parser.parse_args() + + video_folder = Path(args.video_folder) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + shard_pattern = str(output_dir / "shard-%06d.tar") + + # Target size + target_size = None + if args.height is not None and args.width is not None: + target_size = (args.height, args.width) + elif (args.height is None) ^ (args.width is None): + parser.error("Both --height and --width must be specified together") + + # Init HF models + vae, text_encoder, tokenizer, model_dtype = _init_hf_models( + model_id=args.model, + device=args.device, + enable_memory_optimization=not args.no_memory_optimization, + ) + + # Load metadata list + metadata_list = _load_metadata(video_folder) + + with wds.ShardWriter(shard_pattern, maxcount=args.shard_maxcount) as sink: + written = 0 + for index, meta in enumerate(metadata_list): + video_name = meta["file_name"] + start_frame = int(meta["start_frame"]) # inclusive + end_frame = int(meta["end_frame"]) # inclusive + caption_text = meta.get("vila_caption", "") + + video_path = str(video_folder / video_name) + # Load frames using the same OpenCV + resize path as automodel + video_tensor = _load_frames_cv2( + video_path=video_path, + start_frame=start_frame, + end_frame=end_frame, + target_size=target_size, + resize_mode=args.resize_mode, + maintain_aspect_ratio=not args.no_aspect_ratio, + center_crop=args.center_crop, + target_dtype=model_dtype, + ) + + # Encode text and video with HF models exactly like automodel + text_embed = _encode_text(tokenizer, text_encoder, args.device, caption_text) + latents = _encode_video_latents(vae, args.device, video_tensor, deterministic_latents=not args.stochastic) + + # Move to CPU without changing dtype; keep exact values to match automodel outputs + text_embed_cpu = text_embed.detach().to(device="cpu") + latents_cpu = latents.detach().to(device="cpu") + + # Reshape to match Mcore's Wan input format + text_embed_cpu = text_embed_cpu[0] + 
latents_cpu = latents_cpu[0] + + # Build JSON side-info similar to prepare_energon script + C, T, H, W = video_tensor.shape[1:] # 1,C,T,H,W + json_data = { + "video_path": video_path, + "processed_frames": int(T), + "processed_height": int(H), + "processed_width": int(W), + "caption": caption_text, + "deterministic_latents": bool(not args.stochastic), + "memory_optimization": bool(not args.no_memory_optimization), + "model_version": "wan2.1", + "resize_settings": { + "target_size": target_size, + "resize_mode": args.resize_mode, + "maintain_aspect_ratio": bool(not args.no_aspect_ratio), + "center_crop": bool(args.center_crop), + }, + } + + sample = { + "__key__": f"{index:06}", + "pth": latents_cpu, + "pickle": pickle.dumps(text_embed_cpu), + "json": json_data, + } + sink.write(sample) + written += 1 + + print("Done writing shards using HF automodel encoders.") + + +if __name__ == "__main__": + main() + + diff --git a/src/megatron/bridge/data/wan/wan_energon_datamodule.py b/src/megatron/bridge/data/wan/wan_energon_datamodule.py new file mode 100644 index 0000000000..98774e8157 --- /dev/null +++ b/src/megatron/bridge/data/wan/wan_energon_datamodule.py @@ -0,0 +1,47 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+# pylint: disable=C0115,C0116,C0301
+
+from dataclasses import dataclass
+import logging
+from typing import Any, Dict, Literal
+
+from torch import int_repr
+
+from megatron.bridge.data.Dit.data.diffusion_energon_datamodule import DiffusionDataModule
+from megatron.bridge.data.wan.wan_taskencoder import WanTaskEncoder
+from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider
+
+@dataclass(kw_only=True)
+class WanDataModuleConfig(DatasetProvider):
+    path: str
+    seq_length: int
+    micro_batch_size: int
+    global_batch_size: int
+    num_workers: int
+    dataloader_type: str = "external"
+
+    def __post_init__(self):
+        self.dataset = DiffusionDataModule(
+            path=self.path,
+            seq_length=self.seq_length,
+            task_encoder=WanTaskEncoder(seq_length=self.seq_length),
+            micro_batch_size=self.micro_batch_size,
+            global_batch_size=self.global_batch_size,
+            num_workers=self.num_workers)
+        self.sequence_length = self.dataset.seq_length
+
+    def build_datasets(self, context: DatasetBuildContext):
+        return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader()
\ No newline at end of file
diff --git a/src/megatron/bridge/data/wan/wan_taskencoder.py b/src/megatron/bridge/data/wan/wan_taskencoder.py
new file mode 100644
index 0000000000..63f67bd721
--- /dev/null
+++ b/src/megatron/bridge/data/wan/wan_taskencoder.py
@@ -0,0 +1,190 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +import torch +import torch.nn.functional as F +from megatron.energon import DefaultTaskEncoder, SkipSample +from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys +from megatron.bridge.models.wan.utils.utils import grid_sizes_calculation, patchify +from megatron.core import parallel_state + + +def cook(sample: dict) -> dict: + """ + Processes a raw sample dictionary from energon dataset and returns a new dictionary with specific keys. + + Args: + sample (dict): The input dictionary containing the raw sample data. + + Returns: + dict: A new dictionary containing the processed sample data with the following keys: + - All keys from the result of `basic_sample_keys(sample)` + - 'json': The contains meta data like resolution, aspect ratio, fps, etc. + - 'pth': contains video latent tensor + - 'pickle': contains text embeddings + """ + return dict( + **basic_sample_keys(sample), + json=sample["json"], + pth=sample["pth"], + pickle=sample["pickle"], + ) + + +class WanTaskEncoder(DefaultTaskEncoder): + """ + Task encoder for Wan dataset. + Attributes: + cookers (list): A list of Cooker objects used for processing. + patch_spatial (int): The spatial patch size. Defaults to 2. + patch_temporal (int): The temporal patch size. Defaults to 1. + seq_length (int): The sequence length. Defaults to 1024. 
+ """ + + cookers = [ + Cooker(cook), + ] + + def __init__( + self, + *args, + max_frames: int = None, + patch_spatial: int = 2, + patch_temporal: int = 1, + seq_length: int = 1024, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.patch_spatial = patch_spatial + self.patch_temporal = patch_temporal + self.seq_length = seq_length + + + def encode_sample(self, sample: dict) -> dict: + + video_latent = sample["pth"] + context_embeddings = sample["pickle"] + video_metadata = sample["json"] + + # sanity quality check + if torch.isnan(video_latent).any() or torch.isinf(video_latent).any(): + raise SkipSample() + if torch.max(torch.abs(video_latent)) > 1e3: + raise SkipSample() + + # calculate grid size + grid_size = grid_sizes_calculation( + input_shape = video_latent.shape[1:], + patch_size = (self.patch_temporal, self.patch_spatial, self.patch_spatial), + ) + + ### Note: shape of sample's values + # video_latent: [latents_channels, F_latents, W_latents, H_latents] + # grid_size: [F_patches, W_patches, H_patches] + # context_embeddings: [context_seq_len, text_embedding_dim] + + return dict( + video_latent=video_latent, + grid_size=grid_size, + context_embeddings=context_embeddings, + video_metadata=video_metadata, + ) + + + # def mock_encode_sample(self, sample: dict) -> dict: + + # # mock encode sample + # video_latent = torch.tensor(torch.randn(16, 3, 104, 60), dtype=torch.float32) + # # video_latent = torch.tensor(torch.randn(16, 24, 104, 60), dtype=torch.float32) + # grid_size = torch.tensor([video_latent.shape[1] // self.patch_temporal, video_latent.shape[2] // self.patch_spatial, video_latent.shape[3] // self.patch_spatial], dtype=torch.int32) + # context_embeddings = torch.tensor(torch.randn(512, 4096), dtype=torch.float32) + + # return dict( + # video_latent=video_latent, + # grid_size=grid_size, + # context_embeddings=context_embeddings, + # ) + + + def batch(self, samples: list[dict]) -> dict: + + # process video latents + # do padding here for 
video latents + self.patch_size = (self.patch_temporal, self.patch_spatial, self.patch_spatial) + + # running patchify + video_latents = patchify([sample["video_latent"] for sample in samples], self.patch_size) + + # build per-sample loss masks (1 for valid tokens pre-padding) + loss_masks = [torch.ones(v.shape[0]) for v in video_latents] + # calculate all sequence lengths of video latents for self-attention (for videos, we do this before padding to get original seq len) + seq_len_q = [v.shape[0] for v in video_latents] + seq_len_q = torch.tensor(seq_len_q, dtype=torch.int32) + + + # padding and stack video latents + max_video_seq_len = max([video_latent.shape[0] for video_latent in video_latents]) + # CAVEAT: + # when using pipeline parallelism, we need to set batch sequence length to DataModule's seq_length because + # because pipeline parallelism requires pre-specified sequence length to create buffer + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + if max_video_seq_len > self.seq_length: + raise ValueError(f"max_video_seq_len {max_video_seq_len} is greater than DataModule's seq_length {self.seq_length}") + else: + # set max_video_seq_len to DataModule's seq_length + max_video_seq_len = self.seq_length + # CAVEAT: + # when using context parallelism, we need to pad batch sequence length to be divisible by [cp_rank*2] + # (because TransformerEngine's context parallelism requires "AssertionError: Sequence length per GPU needs to be divisible by 2!") + if parallel_state.get_context_parallel_world_size() > 1: + batch_size = len(video_latents) + assert batch_size == 1, "Error: Batch size must be 1 when using context parallelism" + sharding_factor = parallel_state.get_context_parallel_world_size() * 2 + max_video_seq_len = ((max_video_seq_len + sharding_factor - 1) // sharding_factor) * sharding_factor + video_latents = [F.pad(video_latent, (0, 0, 0, max_video_seq_len - video_latent.shape[0])) for video_latent in video_latents] + video_latents = 
torch.stack(video_latents, dim=1) + # pad and stack loss masks to shape [S_max, B] + loss_masks = [F.pad(m, (0, max_video_seq_len - m.shape[0])) for m in loss_masks] + loss_masks = torch.stack(loss_masks, dim=1) + + # process grid sizes + grid_sizes = [torch.tensor(sample["grid_size"], dtype=torch.int32) for sample in samples] + grid_sizes = torch.stack(grid_sizes, dim=0) + + # process text embeddings + # pad here for text embeddings + context_max_len = 512 + context_embeddings = [sample["context_embeddings"] for sample in samples] + context_embeddings = [F.pad(context_embedding, (0, 0, 0, context_max_len - context_embedding.shape[0])) for context_embedding in context_embeddings] + # calculate all sequence lengths of context embeddings for cross-attention (for videos, we do this after padding to get padded seq len) + seq_len_kv = [c.shape[0] for c in context_embeddings] + seq_len_kv = torch.tensor(seq_len_kv, dtype=torch.int32) + # stack context embeddings + context_embeddings = torch.stack(context_embeddings, dim=1) + + # process video metadata + video_metadata = [sample["video_metadata"] for sample in samples] + + return dict( + video_latents = video_latents, + max_video_seq_len = max_video_seq_len, + grid_sizes = grid_sizes, + context_embeddings = context_embeddings, + loss_mask = loss_masks, + seq_len_q = seq_len_q, + seq_len_kv = seq_len_kv, + video_metadata = video_metadata, + ) \ No newline at end of file diff --git a/src/megatron/bridge/models/conversion/param_mapping.py b/src/megatron/bridge/models/conversion/param_mapping.py index 70e33b3734..e3014dcb49 100644 --- a/src/megatron/bridge/models/conversion/param_mapping.py +++ b/src/megatron/bridge/models/conversion/param_mapping.py @@ -1339,6 +1339,90 @@ def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping": ) +class KVMapping(MegatronParamMapping[Dict[str, torch.Tensor]]): + """ + Mapping for interleaved Key/Value projection weights. 
+ + This mapping converts between separate K and V tensors used in external + checkpoints and Megatron's interleaved KV format following grouped-query + attention semantics. + + External format (HF) + - Separate tensors: k_proj, v_proj + - Shapes mirror QKV mappings but without Q + + Megatron format + - Single interleaved tensor with order: [k1, v1, k2, v2, ...] + where index corresponds to query-group id + + Tensor-parallel distribution is delegated to AutoMapping. + """ + + def __init__(self, megatron_param: str, k: str, v: str): + super().__init__(megatron_param, {"k": k, "v": v}) + # Delegate TP sharding/broadcasting + self._tp_mapping = AutoMapping(megatron_param, megatron_param) + + def hf_to_megatron( + self, + hf_weights: Dict[str, torch.Tensor], + megatron_module: nn.Module, + ) -> torch.Tensor: + """Merge K and V into interleaved format and distribute across TP.""" + if self.tp_rank == 0: + config = self._get_config(megatron_module) + + if hf_weights["k"].ndim == 1: + merged = merge_kv_biases(config, hf_weights["k"], hf_weights["v"]) + else: + merged = merge_kv_weights(config, hf_weights["k"], hf_weights["v"]) + else: + merged = None + + return self._tp_mapping.hf_to_megatron(merged, megatron_module) + + def megatron_to_hf( + self, + megatron_weights: Optional[torch.Tensor], + megatron_module: Optional[nn.Module], + ) -> Dict[str, torch.Tensor]: + """Gather KV shards and split into separate K and V tensors.""" + if megatron_weights is not None: + megatron_weights = self.maybe_dequantize(megatron_weights) + + # Ensure all PP ranks participate in config broadcast + if megatron_module is None: + config = self.broadcast_obj_from_pp_rank(None, "kv_config") + else: + config = self._get_config(megatron_module) + config = remove_non_pickleables(config, max_depth=2) + config = self.broadcast_obj_from_pp_rank(config, "kv_config") + + packed_dict = self._tp_mapping.megatron_to_hf(megatron_weights, megatron_module) + if not packed_dict: + return {} + + packed_kv = 
next(iter(packed_dict.values())) + + if packed_kv.ndim == 1: + k, v = split_kv_biases(config, packed_kv) + else: + k, v = split_kv_weights(config, packed_kv) + + return { + self.hf_param["k"]: k, + self.hf_param["v"]: v, + } + + def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping": + resolved_megatron_param, resolved_hf_param = self._resolve_names(captures) + return type(self)( + resolved_megatron_param, + resolved_hf_param["k"], + resolved_hf_param["v"], + ) + + class GatedMLPMapping(MegatronParamMapping[Dict[str, torch.Tensor]]): r"""Mapping for **gated-MLP** projection weights (SwiGLU / GeGLU). @@ -1652,3 +1736,71 @@ def split_qkv_weights( v = v.reshape(-1, hidden_size) return q, k, v + + +def merge_kv_biases(config: TransformerConfig, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """Merge separate K, V bias vectors into Megatron's interleaved KV format (1D).""" + num_query_groups = config.num_query_groups + head_size = config.kv_channels or (config.hidden_size // config.num_attention_heads) + + k = k.view(num_query_groups, head_size) + v = v.view(num_query_groups, head_size) + + pieces: List[torch.Tensor] = [] + for i in range(num_query_groups): + pieces.append(k[i : i + 1, :]) + pieces.append(v[i : i + 1, :]) + + kv = torch.cat(pieces, dim=0) + return kv.reshape(-1) + + +def split_kv_biases(config: TransformerConfig, kv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Split Megatron's interleaved KV bias (1D) into separate K and V biases.""" + num_query_groups = config.num_query_groups + head_size = config.kv_channels or (config.hidden_size // config.num_attention_heads) + kv_total_dim = 2 * num_query_groups + + kv_reshaped = kv.view(kv_total_dim, head_size) + + k_slice = torch.arange(0, kv_total_dim, 2) + v_slice = torch.arange(1, kv_total_dim, 2) + + k = kv_reshaped[k_slice].reshape(-1) + v = kv_reshaped[v_slice].reshape(-1) + return k, v + + +def merge_kv_weights(provider: TransformerConfig, k: torch.Tensor, v: 
torch.Tensor) -> torch.Tensor: + """Merge separate K, V weights into Megatron's interleaved KV format (2D).""" + num_query_groups = provider.num_query_groups + head_size = provider.kv_channels or (provider.hidden_size // provider.num_attention_heads) + hidden_size = provider.hidden_size + + k_reshaped = k.view(num_query_groups, head_size, hidden_size) + v_reshaped = v.view(num_query_groups, head_size, hidden_size) + + pieces: List[torch.Tensor] = [] + for i in range(num_query_groups): + pieces.append(k_reshaped[i : i + 1]) + pieces.append(v_reshaped[i : i + 1]) + + kv = torch.cat(pieces, dim=0) + return kv.view(-1, hidden_size) + + +def split_kv_weights(provider: TransformerConfig, kv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Split Megatron's interleaved KV weights (2D) into separate K and V matrices.""" + num_query_groups = provider.num_query_groups + head_size = provider.kv_channels or (provider.hidden_size // provider.num_attention_heads) + hidden_size = kv.shape[-1] + kv_total_dim = 2 * num_query_groups + + kv_reshaped = kv.view(kv_total_dim, head_size, hidden_size) + + k_slice = torch.arange(0, kv_total_dim, 2) + v_slice = torch.arange(1, kv_total_dim, 2) + + k = kv_reshaped[k_slice].reshape(-1, hidden_size) + v = kv_reshaped[v_slice].reshape(-1, hidden_size) + return k, v diff --git a/src/megatron/bridge/models/hf_pretrained/__init__.py b/src/megatron/bridge/models/hf_pretrained/__init__.py index de1604f253..9bfb9fd83f 100644 --- a/src/megatron/bridge/models/hf_pretrained/__init__.py +++ b/src/megatron/bridge/models/hf_pretrained/__init__.py @@ -14,6 +14,7 @@ from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM from megatron.bridge.models.hf_pretrained.vlm import PreTrainedVLM +from megatron.bridge.models.hf_pretrained.wan import PreTrainedWAN -__all__ = ["PreTrainedCausalLM", "PreTrainedVLM"] +__all__ = ["PreTrainedCausalLM", "PreTrainedVLM", "PreTrainedWAN"] diff --git 
a/src/megatron/bridge/models/hf_pretrained/state.py b/src/megatron/bridge/models/hf_pretrained/state.py index a47a22771d..b35f2c05f9 100644 --- a/src/megatron/bridge/models/hf_pretrained/state.py +++ b/src/megatron/bridge/models/hf_pretrained/state.py @@ -496,7 +496,8 @@ def key_to_filename_map(self) -> Dict[str, str]: from safetensors import safe_open key_map = {} - safetensor_files = file_glob(str(self.path / "*.safetensors")) + # DEBUGGING + safetensor_files = file_glob(str(self.path / "transformer" / "*.safetensors")) for file_path in safetensor_files: filename = os.path.basename(file_path) try: @@ -564,7 +565,8 @@ def get_all_keys(self) -> List[str]: all_keys.update(key_to_filename_map.keys()) if not all_keys: - safetensor_files = file_glob(str(self.path / "*.safetensors")) + # DEBUGGING + safetensor_files = file_glob(str(self.path / "transformer" / "*.safetensors")) if not safetensor_files and not key_to_filename_map: raise FileNotFoundError(f"No .safetensors files or index found in {self.model_name_or_path}") for safetensor_file in safetensor_files: @@ -603,7 +605,8 @@ def load_tensors(self, keys_to_load: List[str]) -> Dict[str, torch.Tensor]: remaining_keys.discard(key) if remaining_keys: - safetensor_files = file_glob(str(self.path / "*.safetensors")) + # DEBUGGING + safetensor_files = file_glob(str(self.path / "transformer" / "*.safetensors")) if not safetensor_files and not key_to_filename_map and not loaded_tensors: raise FileNotFoundError( f"No .safetensors files found in {self.model_name_or_path} to load keys: {remaining_keys}" @@ -650,7 +653,8 @@ def has_glob(self, pattern: str) -> bool: return False # If no index map, scan the files directly. 
- safetensor_files = file_glob(str(self.path / "*.safetensors")) + # DEBUGGING + safetensor_files = file_glob(str(self.path / "transformer" / "*.safetensors")) if not safetensor_files: return False diff --git a/src/megatron/bridge/models/hf_pretrained/wan.py b/src/megatron/bridge/models/hf_pretrained/wan.py new file mode 100644 index 0000000000..97aa6f853c --- /dev/null +++ b/src/megatron/bridge/models/hf_pretrained/wan.py @@ -0,0 +1,52 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import Optional, Union + +from diffusers import WanTransformer3DModel +from transformers import AutoConfig + +from megatron.bridge.models.hf_pretrained.base import PreTrainedBase + + +class PreTrainedWAN(PreTrainedBase): + """ + Lightweight pretrained wrapper for Diffusers WAN models. + + Provides access to WAN config and state through the common PreTrainedBase API + so bridges can consume `.config` and `.state` uniformly. 
+ """ + + def __init__(self, model_name_or_path: Union[str, Path], **kwargs): + self._model_name_or_path = str(model_name_or_path) + super().__init__(**kwargs) + + @property + def model_name_or_path(self) -> str: + return self._model_name_or_path + + # Model loading is optional for conversion; implemented for completeness + def _load_model(self) -> WanTransformer3DModel: + return WanTransformer3DModel.from_pretrained(self.model_name_or_path) + + # Config is required by the WAN bridge + def _load_config(self) -> AutoConfig: + # WanTransformer3DModel returns a config-like object with required fields + + print(f"Loading config from {self.model_name_or_path}") + + return WanTransformer3DModel.from_pretrained(self.model_name_or_path, subfolder="transformer").config + + diff --git a/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py b/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py index 83314df11c..fedef4f40d 100644 --- a/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py +++ b/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py @@ -1,4 +1,3 @@ -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
import gc import logging import math @@ -6,6 +5,7 @@ import random import sys import types +import re from contextlib import contextmanager from functools import partial @@ -24,8 +24,10 @@ retrieve_timesteps, ) from megatron.bridge.models.wan.inference.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from megatron.bridge.models.wan.utils.utils import grid_sizes_calculation, patchify from megatron.core import parallel_state from torch.nn import functional as F +from megatron.bridge.models.wan.utils.utils import split_inputs_cp, cat_outputs_cp import math from typing import Tuple, Union @@ -36,9 +38,13 @@ def __init__( self, config, checkpoint_dir, + checkpoint_step=None, + t5_checkpoint_dir=None, + vae_checkpoint_dir=None, device_id=0, rank=0, t5_cpu=False, + tensor_parallel_size=1, context_parallel_size=1, pipeline_parallel_size=1, @@ -53,6 +59,10 @@ def __init__( Object containing model parameters initialized from config.py checkpoint_dir (`str`): Path to directory containing model checkpoints + t5_checkpoint_dir (`str`, *optional*, defaults to None): + Optional directory containing T5 checkpoint and tokenizer; falls back to `checkpoint_dir` if None. + vae_checkpoint_dir (`str`, *optional*, defaults to None): + Optional directory containing VAE checkpoint; falls back to `checkpoint_dir` if None. 
device_id (`int`, *optional*, defaults to 0): Id of target GPU device rank (`int`, *optional*, defaults to 0): @@ -76,18 +86,22 @@ def __init__( text_len=config.text_len, dtype=config.t5_dtype, device=torch.device('cpu'), - checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), - tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), + checkpoint_path=os.path.join(t5_checkpoint_dir, config.t5_checkpoint), + tokenizer_path=os.path.join(t5_checkpoint_dir, config.t5_tokenizer), shard_fn=None) self.vae_stride = config.vae_stride self.patch_size = config.patch_size self.vae = WanVAE( - vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), + vae_pth=os.path.join(vae_checkpoint_dir, config.vae_checkpoint), device=self.device) - wan_checkpoint_dir = os.path.join(checkpoint_dir, "iter_0000000") + wan_checkpoint_dir = self._select_checkpoint_dir(checkpoint_dir, checkpoint_step) self.model = self.setup_model_from_checkpoint(wan_checkpoint_dir) + + # DEBUGGING + # set qkv_format to to "thd" for context parallelism + self.model.config.qkv_format = "sbhd" # set self.sp_size=1 for later use, just to respect the original Wan inference code self.sp_size = 1 @@ -97,39 +111,6 @@ def __init__( self.model.to(self.device) self.sample_neg_prompt = config.sample_neg_prompt - - - def patchify(self, x, patch_size): - """ - Convert a list of reconstructed video tensor into patch embeddings. - This method is the inverse of `unpatchify`. - - Args: - x (list[torch.Tensor]): list of tensors, each with shape [c, F_patches * pF, H_patches * pH, W_patches * pW] - patch_size (tuple): (pF, pH, pW) - - Returns: - torch.Tensor: shape [ (F_patches * H_patches * W_patches), (c * pF * pH * pW)], - """ - out = [] - for u in x: - c, F_pF, H_pH, W_pW = u.shape - pF, pH, pW = patch_size - assert F_pF % pF == 0 and H_pH % pH == 0 and W_pW % pW == 0, \ - "Spatial dimensions must be divisible by patch size." 
- - F_patches, H_patches, W_patches = F_pF // pF, H_pH // pH, W_pW // pW - - # split spatial dims into (grid, patch) and reorder to match original patch layout: - # start: (c, F_patches * pF, H_patches * pH, W_patches * pW) - # reshape -> (c, F_patches, pF, H_patches, pH, W_patches, pW) - # permute -> (F_patches, H_patches, W_patches, pF, pH, pW, c) - t = u.reshape(c, F_patches, pF, H_patches, pH, W_patches, pW) - t = t.permute(1, 3, 5, 0, 2, 4, 6) - - num_patches = F_patches * H_patches * W_patches - out.append(t.reshape(num_patches, c * (pF * pH * pW))) - return out def unpatchify(self, x: torch.Tensor, grid_sizes: torch.Tensor, out_dim: int) -> list[torch.Tensor]: @@ -182,47 +163,41 @@ def setup_model_from_checkpoint(self, checkpoint_dir): ) if isinstance(model, list): model = model[0] + if hasattr(model, "module"): + model = model.module return model - - def grid_sizes_calculation( - self, - input_shape: Tuple[int, int, int], # (F_latents, H_latents, W_latents) - kernel_size: Union[int, Tuple[int, int, int]], - stride: Union[int, Tuple[int, int, int]] = 1, - padding: Union[int, Tuple[int, int, int]] = 0, - dilation: Union[int, Tuple[int, int, int]] = 1 - ) -> Tuple[int, int, int]: + def _select_checkpoint_dir(self, base_dir: str, checkpoint_step) -> str: """ - Compute the (f,h,w) output spatial/temporal dimensions of a Conv3d patch embedder. 
- - Args: - input_shape: (F_latents, H_latents, W_latents) - kernel_size, stride, padding, dilation of the Conv3d patch embedder: either int or 3-tuple - - Returns: - (F_patches, H_patches, W_patches) + Resolve checkpoint directory: + - If checkpoint_step is provided, use base_dir/iter_{step:07d} + - Otherwise, pick the largest iter_######## subdirectory under base_dir """ - - def to_tuple(x): - return (x, x, x) if isinstance(x, int) else x - - kernel_size = to_tuple(kernel_size) - stride = to_tuple(stride) - padding = to_tuple(padding) - dilation = to_tuple(dilation) - - D_in, H_in, W_in = input_shape - - def calc_out(in_size, k, s, p, d): - return math.floor((in_size + 2*p - d*(k - 1) - 1) / s + 1) - - D_out = calc_out(D_in, kernel_size[0], stride[0], padding[0], dilation[0]) - H_out = calc_out(H_in, kernel_size[1], stride[1], padding[1], dilation[1]) - W_out = calc_out(W_in, kernel_size[2], stride[2], padding[2], dilation[2]) - - return [D_out, H_out, W_out] + if checkpoint_step is not None: + path = os.path.join(base_dir, f"iter_{int(checkpoint_step):07d}") + if os.path.isdir(path): + logging.info(f"Using specified checkpoint: {path}") + return path + raise FileNotFoundError(f"Specified checkpoint step {checkpoint_step} not found at {path}") + + if not os.path.isdir(base_dir): + raise FileNotFoundError(f"Checkpoint base directory does not exist: {base_dir}") + + pattern = re.compile(r"^iter_(\d+)$") + try: + _, latest_path = max( + ((int(pattern.match(e.name).group(1)), e.path) + for e in os.scandir(base_dir) + if e.is_dir() and pattern.match(e.name)), + key=lambda x: x[0], + ) + except ValueError: + raise FileNotFoundError( + f"No checkpoints found under {base_dir}. 
Expected subdirectories named like 'iter_0001800'.") + + logging.info(f"Auto-selected latest checkpoint: {latest_path}") + return latest_path def forward_pp_step( @@ -419,10 +394,9 @@ def generate(self, # calculate grid_sizes - grid_sizes = [self.grid_sizes_calculation( + grid_sizes = [grid_sizes_calculation( input_shape =u.shape[1:], - kernel_size=self.model.patch_size, - stride=self.model.patch_size, + patch_size=self.model.patch_size, ) for u in noises] grid_sizes = torch.tensor(grid_sizes, dtype=torch.long) @@ -482,12 +456,12 @@ def noop_no_sync(): "self_attention": PackedSeqParams( cu_seqlens_q=cu_q, cu_seqlens_kv=cu_kv_self, - qkv_format="sbhd", + qkv_format=self.model.config.qkv_format, ), "cross_attention": PackedSeqParams( cu_seqlens_q=cu_q, cu_seqlens_kv=cu_kv_cross, - qkv_format="sbhd", + qkv_format=self.model.config.qkv_format, ), } @@ -501,7 +475,7 @@ def noop_no_sync(): # patchify latents unpatchified_latents = latents - latents = self.patchify(latents, self.patch_size) + latents = patchify(latents, self.patch_size) # pad to have same length for i in range(batch_size): latents[i] = F.pad(latents[i], (0, 0, 0, max_video_seq_len - latents[i].shape[0])) @@ -512,6 +486,12 @@ def noop_no_sync(): timestep = [t] * batch_size timestep = torch.stack(timestep) + # run context parallelism slitting + if parallel_state.get_context_parallel_world_size() > 1: + latent_model_input = split_inputs_cp(latent_model_input, 0) + arg_c['context'] = split_inputs_cp(arg_c['context'], 0) + arg_null['context'] = split_inputs_cp(arg_null['context'], 0) + self.model.to(self.device) noise_pred_cond = self.forward_pp_step( latent_model_input, grid_sizes=grid_sizes, max_video_seq_len=max_video_seq_len, timestep=timestep, arg_c=arg_c) @@ -519,6 +499,15 @@ def noop_no_sync(): noise_pred_uncond = self.forward_pp_step( latent_model_input, grid_sizes=grid_sizes, max_video_seq_len=max_video_seq_len, timestep=timestep, arg_c=arg_null) + # run context parallelism gathering + if 
def __init__(
    self,
    model_id="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    seed=1234,
):
    """
    Initialize the FlowPipeline.

    Args:
        model_id: Diffusers hub id of the WAN pipeline; only the scheduler
            is used (VAE and text encoder are not loaded).
        seed: Random seed retained for API compatibility with the previous
            generator-based implementation; stored on the instance so
            downstream seeding can use it.
    """
    # Previously the seed parameter was accepted but silently dropped;
    # keep it visible on the instance.
    self.seed = seed
    # Load only the scheduler/config machinery — no VAE, no text encoder —
    # since encoding happens upstream of this pipeline.
    self.pipe = WanPipeline.from_pretrained(
        model_id, vae=None, torch_dtype=torch.float32, text_encoder=None
    )
Adding noise to the input data using the SDE process. - 2. Passing the noisy data through the network to generate predictions. - 3. Computing the loss based on the difference between the predictions and the original data. - - Args: - data_batch (dict): raw data batch draw from the training data loader. - - Returns: - A tuple with the output batch and the computed loss. + 1. Generate noise and add it to the input data. + 2. Pass the noisy data through the network to generate predictions. + 3. Compute the loss based on the difference between the predictions and target. """ - # DEBUGGING - run_debug = False - if run_debug and torch.distributed.get_rank()==0: - print("---- Sample info [FlowPipeline.training_step] ----") - print(f"data_batch['video_latents'] shape: {data_batch['video_latents'].shape}") - print(f"data_batch['context_embeddings'] shape: {data_batch['context_embeddings'].shape}") - print(f"data_batch['loss_mask'] shape: {data_batch['loss_mask'].shape}") - print(f"data_batch['grid_sizes']: {data_batch['grid_sizes']}") - print(f"data_batch['packed_seq_params']: {data_batch['packed_seq_params']}") - print(f"data_batch['max_video_seq_len']: {data_batch['max_video_seq_len']}") - - video_latents = data_batch['video_latents'] max_video_seq_len = data_batch['max_video_seq_len'] context_embeddings = data_batch['context_embeddings'] + loss_mask = data_batch['loss_mask'] grid_sizes = data_batch['grid_sizes'] packed_seq_params = data_batch['packed_seq_params'] + video_metadata = data_batch['video_metadata'] - - # Get the input data to noise and denoise~(image, video) and the corresponding conditioner. 
self.model = model - - # Get timesteps batch_size = video_latents.shape[1] device = video_latents.device - timesteps = torch.randint(0, self.scheduler.config.num_train_timesteps, (batch_size,), device=device) - # Generate noise - # shape of latents is [S, B, (C pF pH pW)] - noise_batch = torch.randn_like(video_latents) - - - # DEBUGGING - if run_debug and torch.distributed.get_rank()==0: - print("---- Sample info [FlowPipeline.training_step] ----") - print(f"noise_batch shape: {noise_batch.shape}") - print(f"timesteps shape: {timesteps.shape}") - print(f"video_latents shape: {video_latents.shape}") - print("--------------------------------") - - # ??? can this add_noise method used for videos of different sizes and just padding? - # => it should be, because the main formula is: noisy_latents = alpha_t * original_samples + sigma_t * noise - # Apply scheduler noise based on timesteps - # DEBUGGING - # bring to shape [batch_size, ...] to run add_noise - noisy_latents = self.scheduler.add_noise(video_latents.transpose(0, 1), noise_batch.transpose(0, 1), timesteps) - noisy_latents = noisy_latents.transpose(0, 1) - - # Pass through model - # noise only needed at the last stage - if parallel_state.is_pipeline_last_stage(): - output_batch, loss = self.compute_loss( - noisy_latents, noise_batch, timesteps, context_embeddings, grid_sizes, packed_seq_params, max_video_seq_len - ) + # # # DEBUGGING precision + # # import torch.cuda.amp as amp + # # with amp.autocast(dtype=torch.bfloat16): + # # # Pass through model + # # ... 
- return output_batch, loss + # ======================================================================== + # Flow Matching Timestep Sampling + # ======================================================================== + + num_train_timesteps = self.pipe.scheduler.config.num_train_timesteps + + if use_sigma_noise: + use_uniform = torch.rand(1).item() < mix_uniform_ratio + + if use_uniform or timestep_sampling == "uniform": + # Pure uniform: u ~ U(0, 1) + u = torch.rand(size=(batch_size,), device=device) + sampling_method = "uniform" + else: + # Density-based sampling + u = compute_density_for_timestep_sampling( + weighting_scheme=timestep_sampling, + batch_size=batch_size, + logit_mean=logit_mean, + logit_std=logit_std, + ).to(device) + sampling_method = timestep_sampling + + # Apply flow shift: σ = shift/(shift + (1/u - 1)) + u_clamped = torch.clamp(u, min=1e-5) # Avoid division by zero + sigma = flow_shift / (flow_shift + (1.0 / u_clamped - 1.0)) + sigma = torch.clamp(sigma, 0.0, 1.0) + else: - hidden_states = self.compute_loss( - noisy_latents, timesteps, context_embeddings, grid_sizes, packed_seq_params, max_video_seq_len - ) - return hidden_states + # Simple uniform without shift + u = torch.rand(size=(batch_size,), device=device) + sigma = u + sampling_method = "uniform_no_shift" + + # ======================================================================== + # Manual Flow Matching Noise Addition + # ======================================================================== + + # Generate noise + noise = torch.randn_like(torch.ones([1, 16, grid_sizes[0][0], grid_sizes[0][1]*2, grid_sizes[0][2]*2], device=video_latents.device), dtype=torch.float32) + noise = patchify(noise, (1, 2, 2))[0].unsqueeze(1) + + # CRITICAL: Manual flow matching (NOT scheduler.add_noise!) 
+ # x_t = (1 - σ) * x_0 + σ * ε + sigma_reshaped = sigma.view(1, batch_size, 1) + noisy_latents = ( + (1.0 - sigma_reshaped) * video_latents.float() + + sigma_reshaped * noise + ) + + # Timesteps for model [0, 1000] + timesteps = sigma * num_train_timesteps - # def get_data_and_condition(self, data_batch: dict[str, Tensor]) -> Tuple[Tensor]: - # """ - # Retrieves data and conditioning for model input. + # ======================================================================== + # Cast model inputs to bf16 + # ======================================================================== - # Args: - # data_batch: Batch of input data. + video_latents = video_latents.to(torch.bfloat16) + noisy_latents = noisy_latents.to(torch.bfloat16) + context_embeddings = context_embeddings.to(torch.bfloat16) + timesteps = timesteps.to(torch.bfloat16) - # Returns: - # ... - # """ - # ... - # return None + # ======================================================================== + # Split accross context parallelism + # ======================================================================== + + if parallel_state.get_context_parallel_world_size() > 1: + video_latents = split_inputs_cp(video_latents, 0) + noisy_latents = split_inputs_cp(noisy_latents, 0) + noise = split_inputs_cp(noise, 0) + context_embeddings = split_inputs_cp(context_embeddings, 0) + split_loss_mask = split_inputs_cp(loss_mask, 0) + else: + video_latents = video_latents + noisy_latents = noisy_latents + noise = noise + context_embeddings = context_embeddings + split_loss_mask = loss_mask - def compute_loss( - self, - video_latents: torch.Tensor, - noise_batch: torch.Tensor, - timesteps: torch.Tensor, - context_embeddings: torch.Tensor, - grid_sizes: List[Tuple[int, int, int]], - packed_seq_params: dict, - max_video_seq_len: int, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Computes the loss for the given latents, timesteps, context_embeddings, grid_sizes, and packed_seq_params. - """ - # ??? 
the shape of latents is [S, B, (ph pw pt C)] - # ??? the shape of noise is [S, B, (ph pw pt C)] - # loss_mask is [S, B], will be transffered in WanForwardStep to combine with loss to get the final loss - - # condition would be: - # t5_text_embeddings, t5_text_mask, seq_len_q, seq_len_kv, pos_ids, latent_shape, grid_sizes - # the shape of t5_text_embeddings is [S, B, (ph pw pt C)] - # the shape of t5_text_mask is [S, B] - # the shape of seq_len_q is [B] - # the shape of seq_len_kv is [B] - # the shape of pos_ids is [S, B, (ph pw pt C)] - # the shape of latent_shape is [B, 4] - # the shape of grid_sizes is [B, 3] - - # Pass through model + # ======================================================================== + # Forward Pass + # ======================================================================== + if parallel_state.is_pipeline_last_stage(): - model_predict = self.model( - x = video_latents, + + model_pred = self.model( + x = noisy_latents, grid_sizes = grid_sizes, t = timesteps, context = context_embeddings, @@ -217,25 +170,41 @@ def compute_loss( packed_seq_params=packed_seq_params, ) - # Compute target based on prediction type - if self.scheduler.config.prediction_type == "epsilon": - target = noise_batch - elif self.scheduler.config.prediction_type == "v_prediction": - target = self.scheduler.get_velocity(latents, noise_batch, timesteps) - elif self.scheduler.config.prediction_type == "flow_prediction": - # Flow matching - target = video_latents - noise_batch - else: - raise ValueError(f"Unknown prediction type: {self.scheduler.config.prediction_type}") + # ======================================================================== + # Target: Flow Matching Velocity + # ======================================================================== + + # Flow matching target: v = ε - x_0 + target = noise - video_latents.float() + + # ======================================================================== + # Loss with Flow Weighting + # 
def time_shift(
    t: torch.Tensor,
    image_seq_len: int,
    shift_type: str = "constant",
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    constant: float = 3.0,
):
    """
    Map uniform timesteps in [0, 1] to sigma values, optionally shifted by
    sequence length.

    Args:
        t: timesteps in range [0, 1]
        image_seq_len: number of tokens (frames * height * width / patch_size^2)
        shift_type: "linear", "sqrt", or "constant"
        base_shift: base shift for linear mode
        max_shift: max shift for linear mode
        constant: shift value for constant mode (default 3.0 matches Pika)

    Returns:
        sigma values for noise scheduling; ``t`` unchanged for unknown modes.
    """
    inv = 1 / t - 1  # shared odds-style term used by every shift formula

    if shift_type == "constant":
        # Constant shift (Pika default)
        return constant / (constant + inv)

    if shift_type == "linear":
        # Shift grows linearly with sequence length
        mu = base_shift + (image_seq_len / 4096) * (max_shift - base_shift)
        e_mu = math.exp(mu)
        return e_mu / (e_mu + inv)

    if shift_type == "sqrt":
        # Flux-style sqrt scaling; 128x128 latents (1024x1024 image) -> mu=3
        mu = np.maximum(1.0, np.sqrt(image_seq_len / (128.0 * 128.0)) * 3.0)
        return mu / (mu + inv)

    # Unknown mode: no shift applied
    return t


def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    logit_mean: float = 0.0,
    logit_std: float = 1.0,
    mode_scale: float = 1.29,
):
    """
    Draw timestep samples from one of several distributions.

    Args:
        weighting_scheme: "uniform", "logit_normal", or "mode"
        batch_size: number of samples to generate
        logit_mean: mean for logit-normal distribution
        logit_std: std for logit-normal distribution
        mode_scale: scale for mode-based sampling

    Returns:
        Tensor of shape (batch_size,) with values in [0, 1].
    """
    if weighting_scheme == "logit_normal":
        # SD3-style: gaussian in logit space, squashed back to (0, 1)
        raw = torch.normal(
            mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu"
        )
        return torch.nn.functional.sigmoid(raw)

    draws = torch.rand(size=(batch_size,), device="cpu")
    if weighting_scheme == "mode":
        # Mode-based: concentrates probability mass around certain timesteps
        return 1 - draws - mode_scale * (torch.cos(math.pi * draws / 2) ** 2 - 1 + draws)

    # Default: plain uniform
    return draws


def get_flow_match_loss_weight(sigma: torch.Tensor, shift: float = 3.0):
    """
    Loss weights for flow matching: w = 1 + shift * sigma, so noisier
    timesteps (larger sigma) contribute more to the loss.

    Args:
        sigma: sigma values in range [0, 1]
        shift: weight scaling factor

    Returns:
        Loss weights with the same shape as ``sigma``.
    """
    return shift * sigma + 1.0
import copy import os diff --git a/src/megatron/bridge/models/wan/inference/configs/shared_config.py b/src/megatron/bridge/models/wan/inference/configs/shared_config.py index 04a9f45421..37d3ae0c43 100644 --- a/src/megatron/bridge/models/wan/inference/configs/shared_config.py +++ b/src/megatron/bridge/models/wan/inference/configs/shared_config.py @@ -1,4 +1,3 @@ -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import torch from easydict import EasyDict diff --git a/src/megatron/bridge/models/wan/inference/configs/wan_i2v_14B.py b/src/megatron/bridge/models/wan/inference/configs/wan_i2v_14B.py index 53bf2211b8..764d2ed8c3 100644 --- a/src/megatron/bridge/models/wan/inference/configs/wan_i2v_14B.py +++ b/src/megatron/bridge/models/wan/inference/configs/wan_i2v_14B.py @@ -1,4 +1,3 @@ -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import torch from easydict import EasyDict diff --git a/src/megatron/bridge/models/wan/inference/configs/wan_t2v_14B.py b/src/megatron/bridge/models/wan/inference/configs/wan_t2v_14B.py index 9d0ee69dea..c793f7f6c3 100644 --- a/src/megatron/bridge/models/wan/inference/configs/wan_t2v_14B.py +++ b/src/megatron/bridge/models/wan/inference/configs/wan_t2v_14B.py @@ -1,4 +1,3 @@ -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. from easydict import EasyDict from .shared_config import wan_shared_cfg diff --git a/src/megatron/bridge/models/wan/inference/configs/wan_t2v_1_3B.py b/src/megatron/bridge/models/wan/inference/configs/wan_t2v_1_3B.py index ea9502b0df..c8458ce804 100644 --- a/src/megatron/bridge/models/wan/inference/configs/wan_t2v_1_3B.py +++ b/src/megatron/bridge/models/wan/inference/configs/wan_t2v_1_3B.py @@ -1,4 +1,3 @@ -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
from easydict import EasyDict from .shared_config import wan_shared_cfg diff --git a/src/megatron/bridge/models/wan/inference/utils/fm_solvers.py b/src/megatron/bridge/models/wan/inference/utils/fm_solvers.py index 17bef85000..a38b755c40 100644 --- a/src/megatron/bridge/models/wan/inference/utils/fm_solvers.py +++ b/src/megatron/bridge/models/wan/inference/utils/fm_solvers.py @@ -1,6 +1,5 @@ # Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py # Convert dpm solver for flow matching -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import inspect import math diff --git a/src/megatron/bridge/models/wan/inference/utils/fm_solvers_unipc.py b/src/megatron/bridge/models/wan/inference/utils/fm_solvers_unipc.py index fb502f2eb2..8d96058394 100644 --- a/src/megatron/bridge/models/wan/inference/utils/fm_solvers_unipc.py +++ b/src/megatron/bridge/models/wan/inference/utils/fm_solvers_unipc.py @@ -1,6 +1,5 @@ # Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py # Convert unipc for flow matching -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import math from typing import List, Optional, Tuple, Union diff --git a/src/megatron/bridge/models/wan/inference/utils/utils.py b/src/megatron/bridge/models/wan/inference/utils/utils.py index d72599967f..a57f9bb993 100644 --- a/src/megatron/bridge/models/wan/inference/utils/utils.py +++ b/src/megatron/bridge/models/wan/inference/utils/utils.py @@ -1,4 +1,3 @@ -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
import argparse import binascii import os diff --git a/src/megatron/bridge/models/wan/modules/t5.py b/src/megatron/bridge/models/wan/modules/t5.py index c841b044a2..fecd989e07 100644 --- a/src/megatron/bridge/models/wan/modules/t5.py +++ b/src/megatron/bridge/models/wan/modules/t5.py @@ -1,5 +1,4 @@ # Modified from transformers.models.t5.modeling_t5 -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import logging import math diff --git a/src/megatron/bridge/models/wan/modules/tokenizers.py b/src/megatron/bridge/models/wan/modules/tokenizers.py index 121e591c48..a69972adf2 100644 --- a/src/megatron/bridge/models/wan/modules/tokenizers.py +++ b/src/megatron/bridge/models/wan/modules/tokenizers.py @@ -1,4 +1,3 @@ -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import html import string diff --git a/src/megatron/bridge/models/wan/modules/vae.py b/src/megatron/bridge/models/wan/modules/vae.py index 5c6da57235..d4f1ef1d0e 100644 --- a/src/megatron/bridge/models/wan/modules/vae.py +++ b/src/megatron/bridge/models/wan/modules/vae.py @@ -1,4 +1,3 @@ -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
import logging import torch diff --git a/src/megatron/bridge/models/wan/rope_utils.py b/src/megatron/bridge/models/wan/rope_utils.py index 6e25fdb24b..93d0e93363 100644 --- a/src/megatron/bridge/models/wan/rope_utils.py +++ b/src/megatron/bridge/models/wan/rope_utils.py @@ -1,5 +1,7 @@ import torch from torch.cuda import amp +from megatron.bridge.models.wan.utils.utils import split_inputs_cp +from megatron.core import parallel_state class Wan3DRopeEmbeddings(torch.nn.Module): """ @@ -20,7 +22,7 @@ def rope_params(self, max_position_len, dim_head, theta=10000): freqs = torch.outer( torch.arange(max_position_len), 1.0 / torch.pow(theta, - torch.arange(0, dim_head, 2).to(torch.float64).div(dim_head))) + torch.arange(0, dim_head, 2).div(dim_head))) return freqs def forward(self, n_head, dim_head, max_seq_len, grid_sizes, device): @@ -56,6 +58,8 @@ def forward(self, n_head, dim_head, max_seq_len, grid_sizes, device): freqs_real = torch.cat(freqs_real, dim=1) # TODO: if run context/sequence related parallel, then we need to scatter - # the freqs_real to the context parallel region, using specific method "get_pos_emb_on_this_cp_rank" + # the freqs_real to the context parallel region, using specific cp_rank split method + if parallel_state.get_context_parallel_world_size() > 1: + freqs_real = split_inputs_cp(freqs_real, 0) return freqs_real \ No newline at end of file diff --git a/src/megatron/bridge/models/wan/utils/utils.py b/src/megatron/bridge/models/wan/utils/utils.py new file mode 100644 index 0000000000..8551c6fc50 --- /dev/null +++ b/src/megatron/bridge/models/wan/utils/utils.py @@ -0,0 +1,128 @@ +import torch +from typing import Tuple +from torch.distributed import all_gather +import megatron.core.parallel_state as parallel_state +import math + +def grid_sizes_calculation( + input_shape: Tuple[int, int, int], # (F_latents, H_latents, W_latents) + patch_size: Tuple[int, int, int], # (pF, pH, pW) +) -> Tuple[int, int, int]: + """ + Compute the (f,h,w) output 
spatial/temporal dimensions of a Conv3d patch embedder. + """ + + F_latents, H_latents, W_latents = input_shape + pF, pH, pW = patch_size + F_patches = F_latents // pF + H_patches = H_latents // pH + W_patches = W_latents // pW + + return [F_patches, H_patches, W_patches] + + +def patchify(x, patch_size): + """ + Convert a list of reconstructed video tensor into patch embeddings. + This method is the inverse of `unpatchify`. + + Args: + x (list[torch.Tensor]): list of tensors, each with shape [c, F_patches * pF, H_patches * pH, W_patches * pW] + patch_size (tuple): (pF, pH, pW) + + Returns: + torch.Tensor: shape [ (F_patches * H_patches * W_patches), (c * pF * pH * pW)], + """ + out = [] + for u in x: + c, F_pF, H_pH, W_pW = u.shape + pF, pH, pW = patch_size + assert F_pF % pF == 0 and H_pH % pH == 0 and W_pW % pW == 0, \ + "Spatial dimensions must be divisible by patch size." + + F_patches, H_patches, W_patches = F_pF // pF, H_pH // pH, W_pW // pW + + # split spatial dims into (grid, patch) and reorder to match original patch layout: + # start: (c, F_patches * pF, H_patches * pH, W_patches * pW) + # reshape -> (c, F_patches, pF, H_patches, pH, W_patches, pW) + # permute -> (F_patches, H_patches, W_patches, pF, pH, pW, c) + t = u.reshape(c, F_patches, pF, H_patches, pH, W_patches, pW) + t = t.permute(1, 3, 5, 2, 4, 6, 0) + + num_patches = F_patches * H_patches * W_patches + out.append(t.reshape(num_patches, c * (pF * pH * pW))) + return out + + +def unpatchify(x: list[torch.Tensor], grid_sizes: list[Tuple[int, int, int]], out_dim: int, patch_size: Tuple[int, int, int]) -> list[torch.Tensor]: + """ + Reconstruct video tensors from patch embeddings into a list of videotensors. 
+ + Args: + x (list[torch.Tensor]): + list of tensors, each with shape [seq_len, c * pF * pH * pW] + grid_sizes (list[Tuple[int, int, int]]): + list of tensors, each with original spatial-temporal grid dimensions before patching, + (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + list[torch.Tensor]: list of tensors, each with shape [c, F_latents, H_latents, W_latents] + """ + + c = out_dim + out = [] + for u, v in zip(x, grid_sizes): + u = u[:math.prod(v)].view(*v, *patch_size, c) + u = torch.einsum('fhwpqrc->cfphqwr', u) + u = u.reshape(c, *[i * j for i, j in zip(v, patch_size)]) + out.append(u) + return out + + +def split_inputs_cp(x: torch.Tensor, seq_dim: int = 0) -> torch.Tensor: + """ + Split input tensor along the sequence dimension for context parallelism. + + Args: + x: Input tensor to be split. (e.g. shape [seq_len, batch_size, ...]) + seq_dim: The dimension along which to split the input (sequence dimension). + + Returns: + A slice of the input tensor corresponding to the current rank. (e.g. shape [seq_len/cp_size, batch_size, ...]) + """ + + cp_size = parallel_state.get_context_parallel_world_size() + if cp_size > 1: + cp_rank = parallel_state.get_context_parallel_rank() + assert x.shape[seq_dim] % cp_size == 0, f"{x.shape[seq_dim]} cannot divide cp_size {cp_size}" + x = x.view(*x.shape[:seq_dim], cp_size, x.shape[seq_dim] // cp_size, *x.shape[(seq_dim + 1) :]) + seq_idx = torch.tensor([cp_rank], device=x.device) + x = x.index_select(seq_dim, seq_idx) + # Note that the new sequence length is the original sequence length / cp_size + x = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) + return x + + +def cat_outputs_cp(x: torch.Tensor, seq_dim: int) -> torch.Tensor: + """ + Concatenate tensors from multiple processes along a specified dimension. + + Args: + x: Input tensor to be concatenated. (e.g. shape [seq_len/cp_size, batch_size, ...]) + seq_dim: The dimension along which to concatenate the input tensors. 
+ + Returns: + A tensor with the concatenated tensors. (e.g. shape [seq_len, batch_size, ...]) + """ + + cp_group = parallel_state.get_context_parallel_group() + cp_size = parallel_state.get_context_parallel_world_size() + if cp_size > 1: + gathered_tensors = [torch.zeros_like(x) for _ in range(cp_size)] + # Attempt to gather tensors from all ranks + # PyTorch’s all_gather orders outputs by rank within the group, which matches how chunks were selected by cp_rank + all_gather(gathered_tensors, x, group=cp_group) + gathered_tensors = torch.cat(gathered_tensors, dim=seq_dim) + return gathered_tensors + else: + return x diff --git a/src/megatron/bridge/models/wan/wan_bridge.py b/src/megatron/bridge/models/wan/wan_bridge.py index 80d7eafafe..b37540bcc9 100644 --- a/src/megatron/bridge/models/wan/wan_bridge.py +++ b/src/megatron/bridge/models/wan/wan_bridge.py @@ -60,50 +60,20 @@ def provider_bridge(self, hf_pretrained: PreTrainedWAN) -> WanModelProvider: ffn_hidden_size=hf_config.ffn_dim, num_attention_heads=hf_config.num_attention_heads, activation_func=openai_gelu, - add_qkv_bias=True, in_channels=hf_config.in_channels, out_channels=hf_config.out_channels, text_dim=hf_config.text_dim, patch_spatial=hf_config.patch_size[1], patch_temporal=hf_config.patch_size[0], - patch_size=hf_config.patch_size, # ??? 
adundant variable - rotary_interleaved=True, layernorm_epsilon=hf_config.eps, hidden_dropout=0, attention_dropout=0, use_cpu_initialization=True, freq_dim=hf_config.freq_dim, - qk_layernorm_per_head=False, bf16=False, params_dtype=torch.float32, ) - # num_layers=source_config.num_layers, # dummy setting - # hidden_size=source_config.num_attention_heads * source_config.attention_head_dim, - # crossattn_emb_size=source_config.num_attention_heads * source_config.attention_head_dim, - # ffn_hidden_size=source_config.ffn_dim, - # num_attention_heads=source_config.num_attention_heads, - # activation_func=openai_gelu, - # add_qkv_bias=True, - # in_channels=source_config.in_channels, - # text_dim=source_config.text_dim, - # # model_channels=256, - # # DEBUGGING - # patch_spatial=source_config.patch_size[1], - # patch_temporal=source_config.patch_size[0], - # patch_size=source_config.patch_size, - # rotary_interleaved=True, - # layernorm_epsilon=1e-06, - # hidden_dropout=0, - # attention_dropout=0, - # use_cpu_initialization=True, - # # DEBUGGING - # freq_dim=source_config.freq_dim, - # bf16=False, - # params_dtype=torch.float32, - # # DEBUGGING - # qk_layernorm_per_head=False, - return provider diff --git a/src/megatron/bridge/models/wan/wan_layer_spec.py b/src/megatron/bridge/models/wan/wan_layer_spec.py index fdd4d9957f..f98576ada1 100644 --- a/src/megatron/bridge/models/wan/wan_layer_spec.py +++ b/src/megatron/bridge/models/wan/wan_layer_spec.py @@ -1,3 +1,4 @@ + # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -65,7 +66,7 @@ def forward(self, x): Args: x(Tensor): Shape [B, L, C] """ - return super().forward(x.float()).type_as(x) + return super().forward(x).type_as(x) @dataclass @@ -206,7 +207,7 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None): if self.q_layernorm is not None: if self.layernorm_across_head: q_flat = query.reshape(query.size(0), query.size(1), -1).contiguous() # [sq, b, np*hn] - q_flat = self.q_layernorm(q_flat.float()) # Wan RMSNorm cast input to float32 + q_flat = self.q_layernorm(q_flat) query = q_flat.view(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) # [sq, b, np, hn] else: query = self.q_layernorm(query.contiguous()) @@ -214,7 +215,7 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None): if self.k_layernorm is not None: if self.layernorm_across_head: k_flat = key.reshape(key.size(0), key.size(1), -1).contiguous() - k_flat = self.k_layernorm(k_flat.float()) # Wan RMSNorm cast input to float32 + k_flat = self.k_layernorm(k_flat) key = k_flat.view(key.size(0), key.size(1), -1, self.hidden_size_per_attention_head) else: key = self.k_layernorm(key.contiguous()) @@ -333,7 +334,7 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states): if self.q_layernorm is not None: if self.layernorm_across_head: q_flat = query.reshape(query.size(0), query.size(1), -1).contiguous() # [sq, b, np*hn] - q_flat = self.q_layernorm(q_flat.float()) # Wan RMSNorm cast input to float32 + q_flat = self.q_layernorm(q_flat) query = q_flat.view(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) # [sq, b, np, hn] else: query = self.q_layernorm(query.contiguous()) @@ -341,7 +342,7 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states): if self.k_layernorm is not None: if self.layernorm_across_head: k_flat = key.reshape(key.size(0), key.size(1), -1).contiguous() - k_flat = 
self.k_layernorm(k_flat.float()) # Wan RMSNorm cast input to float32 + k_flat = self.k_layernorm(k_flat) key = k_flat.view(key.size(0), key.size(1), -1, self.hidden_size_per_attention_head) else: key = self.k_layernorm(key.contiguous()) @@ -384,10 +385,7 @@ def __init__( setattr(self.modulation, "sequence_parallel", config.sequence_parallel) def forward(self, timestep_emb): - assert timestep_emb.dtype == torch.float32 - with amp.autocast(dtype=torch.float32): - e = (self.modulation + timestep_emb).chunk(6, dim=1) - assert e[0].dtype == torch.float32 + e = (self.modulation + timestep_emb).chunk(6, dim=1) return e # @jit_fuser @@ -490,7 +488,7 @@ def forward( # adaLN with scale + shift + gate pre_full_attn_layernorm_output_ada = self.adaLN.modulate( - self.norm1(hidden_states.float()), # Wan's LayerNorm implementation forward pass casts input to float32 + self.norm1(hidden_states), shift=shift_full, scale=scale_full, ) @@ -506,13 +504,12 @@ def forward( if bias is not None: attention_output = attention_output + bias - with amp.autocast(dtype=torch.float32): - hidden_states = self.adaLN.scale_add(residual=hidden_states, x=attention_output, gate=gate_full) + hidden_states = self.adaLN.scale_add(residual=hidden_states, x=attention_output, gate=gate_full) # ******************************************** cross attention ****************************************************** attention_output, bias = self.cross_attention( - self.norm3(hidden_states.float()), # Wan's LayerNorm implementation forward pass casts input to float32 + self.norm3(hidden_states), attention_mask=context_mask, key_value_states=context, packed_seq_params=packed_seq_params['cross_attention'], @@ -525,7 +522,7 @@ def forward( # ******************************************** mlp ****************************************************** pre_mlp_layernorm_output_ada = self.adaLN.modulate( - self.norm2(hidden_states.float()), # Wan's LayerNorm implementation forward pass casts input to float32 + 
self.norm2(hidden_states), shift=shift_mlp, scale=scale_mlp, ) @@ -534,9 +531,7 @@ def forward( if bias is not None: mlp_output = mlp_output + bias - with amp.autocast(dtype=torch.float32): - hidden_states = self.adaLN.scale_add(residual=hidden_states, x=mlp_output, gate=gate_mlp) - + hidden_states = self.adaLN.scale_add(residual=hidden_states, x=mlp_output, gate=gate_mlp) # TODO: Jit compiled function creates 'view' tensor. This tensor # potentially gets saved in the MPU checkpoint function context, diff --git a/src/megatron/bridge/models/wan/wan_model.py b/src/megatron/bridge/models/wan/wan_model.py index 47662dbcc7..d11b780313 100644 --- a/src/megatron/bridge/models/wan/wan_model.py +++ b/src/megatron/bridge/models/wan/wan_model.py @@ -39,7 +39,7 @@ def sinusoidal_embedding_1d(dim, position): # preprocess assert dim % 2 == 0 half = dim // 2 - position = position.type(torch.float64) + position = position # calculation sinusoid = torch.outer( @@ -70,10 +70,8 @@ def forward(self, x, e): x(Tensor): Shape [B, L1, C] e(Tensor): Shape [B, C] """ - assert e.dtype == torch.float32 - with amp.autocast(dtype=torch.float32): - e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) - x = (self.head(self.norm(x) * (1 + e[1]) + e[0])) + e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) + x = (self.head(self.norm(x) * (1 + e[1]) + e[0])) return x @@ -122,6 +120,8 @@ def __init__( self.patch_temporal = self.config.patch_temporal self.patch_size = (self.patch_temporal, self.patch_spatial, self.patch_spatial) + # these attributes are unused for images/videos, we just set because bridge training requires for LLMs + self.share_embeddings_and_output_weights = False ###################################### ########## Wan architecture ########## @@ -189,7 +189,8 @@ def forward( seq_len, batch_size, _ = x.shape c = self.out_channels pF, pH, pW = self.patch_size - x = x.reshape(seq_len * batch_size, c, pF, pH, pW) # output: x.shape [s * b, c, pF, pH, pW] + x = x.reshape(seq_len * 
batch_size, pF, pH, pW, c) # output: x.shape [s * b, pF, pH, pW, c] + x = x.permute(0, 4, 1, 2, 3) # output: x.shape [s * b, c, pF, pH, pW] x = self.patch_embedding(x) # output: x.shape [s * b, hidden_size, 1, 1, 1] x = x.flatten(1) # output: x.shape [s * b, hidden_size] x = x.reshape(seq_len, batch_size, -1) # output: x.shape [s, b, hidden_size] @@ -204,11 +205,10 @@ def forward( x = self.decoder.input_tensor # time embeddings - with amp.autocast(dtype=torch.float32): - e = self.time_embedding( - sinusoidal_embedding_1d(self.freq_dim, t).float()) - e0 = self.time_projection(e).unflatten(1, (6, self.config.hidden_size)) - assert e.dtype == torch.float32 and e0.dtype == torch.float32 + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).to(x.dtype) + ) + e0 = self.time_projection(e).unflatten(1, (6, self.config.hidden_size)) # context embeddings context = self.text_embedding(context) # shape [text_len, b, hidden_size] diff --git a/src/megatron/bridge/models/wan/wan_provider.py b/src/megatron/bridge/models/wan/wan_provider.py index de7487f3ac..a162a65c56 100644 --- a/src/megatron/bridge/models/wan/wan_provider.py +++ b/src/megatron/bridge/models/wan/wan_provider.py @@ -38,6 +38,8 @@ class WanModelProvider(TransformerConfig, ModelProviderMixin[VisionModule]): layernorm_epsilon: float = 1e-6 normalization: str = "RMSNorm" layernorm_zero_centered_gamma: bool = False + add_qkv_bias: bool = True + rotary_interleaved: bool = True hidden_dropout: float = 0 attention_dropout: float = 0 fp16_lm_cross_entropy: bool = False @@ -48,6 +50,8 @@ class WanModelProvider(TransformerConfig, ModelProviderMixin[VisionModule]): # these attributes are unused for images/videos, we just set because bridge training requires for LLMs seq_length: int = 1024 share_embeddings_and_output_weights: bool = False + vocab_size: int = 25256 * 8 + make_vocab_size_divisible_by: int = 128 # images/videos attributes in_channels: int = 16 diff --git 
a/src/megatron/bridge/models/wan/wan_step.py b/src/megatron/bridge/models/wan/wan_step.py index a969f30135..58429a6856 100644 --- a/src/megatron/bridge/models/wan/wan_step.py +++ b/src/megatron/bridge/models/wan/wan_step.py @@ -18,32 +18,20 @@ import torch from megatron.core import parallel_state -from megatron.core.models.gpt import GPTModel +from megatron.core.models.common.vision_module.vision_module import VisionModule from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.utils import get_batch_on_this_cp_rank, get_model_config -# from megatron.bridge.models.DiTModel.edm.edm_pipeline import EDMPipeline +from megatron.core.utils import get_model_config from megatron.bridge.models.wan.flow_matching.flow_pipeline import FlowPipeline - -from megatron.bridge.training.config import ConfigContainer, FinetuningDatasetConfig from megatron.bridge.training.losses import masked_next_token_loss from megatron.bridge.training.state import GlobalState - logger = logging.getLogger(__name__) def wan_data_step(qkv_format, dataloader_iter): batch = next(iter(dataloader_iter.iterable)) - # # can we do this ??? - # batch = get_batch_on_this_cp_rank(batch) - batch = {k: v.to(device="cuda", non_blocking=True) if torch.is_tensor(v) else v for k, v in batch.items()} - - # ??? Should we do the padding here, by padding to the longest sequence length in the batch? - # ??? Or should we do the padding in the TaskEncoder? 
- # => do task encoder padding here - # Construct packed sequence parameters if ("seq_len_q" in batch) and ("seq_len_kv" in batch): cu_seqlens = batch["seq_len_q"].cumsum(dim=0).to(torch.int32) @@ -69,59 +57,16 @@ def wan_data_step(qkv_format, dataloader_iter): return batch -def get_batch_on_this_cp_rank(data): - """Split the data for context parallelism.""" - from megatron.core import mpu - - cp_size = mpu.get_context_parallel_world_size() - cp_rank = mpu.get_context_parallel_rank() - - t = 16 - if cp_size > 1: - # cp split on seq_length, for video_latent, noise_latent and pos_ids - assert t % cp_size == 0, "t must divisibly by cp_size" - num_valid_tokens_in_ub = None - if "loss_mask" in data and data["loss_mask"] is not None: - num_valid_tokens_in_ub = data["loss_mask"].sum() - - for key, value in data.items(): - if (value is not None) and (key in ["video", "video_latent", "noise_latent", "pos_ids"]): - if len(value.shape) > 5: - value = value.squeeze(0) - B, C, T, H, W = value.shape - if T % cp_size == 0: - # FIXME packed sequencing - data[key] = value.view(B, C, cp_size, T // cp_size, H, W)[:, :, cp_rank, ...].contiguous() - else: - # FIXME packed sequencing - data[key] = value.view(B, C, T, cp_size, H // cp_size, W)[:, :, :, cp_rank, ...].contiguous() - loss_mask = data["loss_mask"] - data["loss_mask"] = loss_mask.view(loss_mask.shape[0], cp_size, loss_mask.shape[1] // cp_size)[ - :, cp_rank, ... - ].contiguous() - data["num_valid_tokens_in_ub"] = num_valid_tokens_in_ub - - return data - - class WanForwardStep: def __init__(self): self.diffusion_pipeline = FlowPipeline() def __call__( - self, state: GlobalState, data_iterator: Iterable, model: GPTModel, return_schedule_plan: bool = False + self, state: GlobalState, data_iterator: Iterable, model: VisionModule ) -> tuple[torch.Tensor, partial]: - """Forward training step. 
- - Args: - state: Global state for the run - data_iterator: Input data iterator - model: The GPT Model - return_schedule_plan (bool): Whether to return the schedule plan instead of the output tensor - - Returns: - tuple containing the output tensor and the loss function + """ + Forward training step. """ timers = state.timers straggler_timer = state.straggler_timer @@ -140,30 +85,18 @@ def __call__( check_for_nan_in_loss = state.cfg.rerun_state_machine.check_for_nan_in_loss check_for_spiky_loss = state.cfg.rerun_state_machine.check_for_spiky_loss - # DEBUGGING - run_debug = False - if run_debug: - print("---- Sample info [WanForwardStep] ----") - print(f"batch['video_latents'] shape: {batch['video_latents'].shape}") - print(f"batch['context_embeddings'] shape: {batch['context_embeddings'].shape}") - print(f"batch['loss_mask'] shape: {batch['loss_mask'].shape}") - print(f"batch['grid_sizes']: {batch['grid_sizes']}") - print(f"batch['packed_seq_params']: {batch['packed_seq_params']}") - - # run diffusion training step with straggler_timer: if parallel_state.is_pipeline_last_stage(): - output_batch, loss = self.diffusion_pipeline.training_step(model, batch) + output_batch, loss, split_loss_mask = self.diffusion_pipeline.training_step(model, batch) output_tensor = torch.mean(loss, dim=-1) + batch["loss_mask"] = split_loss_mask else: output_tensor = self.diffusion_pipeline.training_step(model, batch) - # DEBUGGING - # ??? do we need to gather output with sequence or context parallelism here - # ??? 
especially when we have pipeline parallelism - + # TODO: do we need to gather output with sequence or context parallelism here + # especially when we have pipeline parallelism loss = output_tensor if "loss_mask" not in batch or batch["loss_mask"] is None: diff --git a/src/megatron/bridge/recipes/wan/wan.py b/src/megatron/bridge/recipes/wan/wan.py new file mode 100644 index 0000000000..b4975ad5a9 --- /dev/null +++ b/src/megatron/bridge/recipes/wan/wan.py @@ -0,0 +1,219 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from typing import List, Optional, Union + +from megatron.bridge.data.wan.wan_energon_datamodule import WanDataModuleConfig +from megatron.bridge.models.wan.wan_provider import WanModelProvider +import torch +from megatron.core.distributed import DistributedDataParallelConfig + +from megatron.bridge.recipes.utils.optimizer_utils import distributed_fused_adam_with_cosine_annealing +from megatron.bridge.recipes.utils.tokenizer_utils import DEFAULT_NULL_TOKENIZER_VOCAB_SIZE +from megatron.bridge.training.comm_overlap import CommOverlapConfig +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + LoggerConfig, + RNGConfig, + TokenizerConfig, + TrainingConfig, +) +from megatron.bridge.training.mixed_precision import MixedPrecisionConfig, get_mixed_precision_config + + +def model_config( + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + pipeline_parallelism_dtype: Optional[torch.dtype] = torch.bfloat16, + virtual_pipeline_parallelism: Optional[int] = 1, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + seq_length: int = 1024, +) -> WanModelProvider: + """ + Configure the Wan model. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_dtype (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + seq_length (int): Sequence length for the model. + Returns: + WanModelProvider: Configuration for the Wan model. 
+ """ + return WanModelProvider( + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_dtype, + virtual_pipeline_model_parallel_size=None, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + seq_length=seq_length, + ) + + +def pretrain_config( + dir: Optional[str] = None, + name: str = "default", + # Dataset configuration + data_paths: Optional[List[str]] = None, + data_args_path: Optional[str] = None, + train_data_path: Optional[List[str]] = None, + valid_data_path: Optional[List[str]] = None, + test_data_path: Optional[List[str]] = None, + per_split_data_args_path: Optional[str] = None, + mock: bool = False, + # Model configuration + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + pipeline_parallelism_dtype: Optional[torch.dtype] = torch.bfloat16, + virtual_pipeline_parallelism: Optional[int] = 1, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + use_megatron_fsdp: bool = False, + # Training hyperparameters + train_iters: int = 10000, + global_batch_size: int = 4, + micro_batch_size: int = 1, + lr: float = 0.9e-4, + lr_warmup_iters: int = 2000, + # Precision recipe + # DEBUGGING + precision_config: Optional[Union[MixedPrecisionConfig, str]] = "bf16_mixed", + # precision_config: Optional[Union[MixedPrecisionConfig, str]] = MixedPrecisionConfig( + # fp32=True, + # params_dtype=torch.float32, + # pipeline_dtype=torch.float32, + # autocast_enabled=False, + # ), + comm_overlap_config: Optional[CommOverlapConfig] = None, +) -> ConfigContainer: + """ + Create a pre-training configuration for GPT3 175B model. + + The default configuration is expected to run on 64 nodes with 8 GPUs each. + + Args: + dir (Optional[str]): Base directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + data_paths (Optional[List[str]]): List of paths to dataset files. If None, mock data will be used. 
+ data_args_path (Optional[str]): Path to file containing data arguments. + train_data_path (Optional[List[str]]): List of training data paths. + valid_data_path (Optional[List[str]]): List of validation data paths. + test_data_path (Optional[List[str]]): List of test data paths. + per_split_data_args_path (Optional[str]): Path to JSON file with per-split data configuration. + mock (bool): Whether to use mock data. If True, ignores data_paths. + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_dtype (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism to be passed to model_config. + sequence_parallelism (bool): Whether to use sequence parallelism. + train_iters (int): Total number of training iterations. + global_batch_size (int): Global batch size for training. + micro_batch_size (int): Micro batch size for training. + seq_length (int): Sequence length for training data. + lr (float): Learning rate. + min_lr (float): Minimum learning rate for cosine decay. + lr_warmup_iters (int): Number of warmup iterations for the learning rate. + precision_config (Optional[Union[MixedPrecisionConfig, str]]): Precision configuration for the model. + comm_overlap_config (Optional[CommOverlapConfig]): Communication overlap configuration for the model. + + Returns: + ConfigContainer: Configuration for pre-training. 
+ """ + base_output_dir = dir if dir is not None else os.path.join(os.getcwd(), "nemo_experiments") + run_output_dir = os.path.join(base_output_dir, name) + checkpoint_dir = os.path.join(run_output_dir, "checkpoints") + tensorboard_dir = os.path.join(run_output_dir, "tb_logs") + + + model_cfg = model_config( + tensor_parallelism=tensor_parallelism, + pipeline_parallelism=pipeline_parallelism, + pipeline_parallelism_dtype=pipeline_parallelism_dtype, + virtual_pipeline_parallelism=virtual_pipeline_parallelism, + context_parallelism=context_parallelism, + sequence_parallelism=sequence_parallelism, + seq_length=1024, + ) + + opt_config, scheduler = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=lr_warmup_iters, + lr_decay_iters=train_iters, + max_lr=lr, + ) + opt_config.use_precision_aware_optimizer = False + + if isinstance(precision_config, str): + precision_config = get_mixed_precision_config(precision_config) + + precision_config.grad_reduce_in_fp32 = False + + + # Config Container + cfg = ConfigContainer( + model=model_cfg, + train=TrainingConfig( + train_iters=train_iters, + eval_interval=2000, + eval_iters=32, + global_batch_size=global_batch_size, + micro_batch_size=micro_batch_size, + manual_gc=True, + manual_gc_interval=100, + manual_gc_eval=100, + ), + optimizer=opt_config, + scheduler=scheduler, + ddp=DistributedDataParallelConfig( + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=False, + overlap_param_gather=False, + average_in_collective=True, + use_distributed_optimizer=True, + use_megatron_fsdp=use_megatron_fsdp, # need use_distributed_optimizer=True + ), + dataset= WanDataModuleConfig( + path=None, + seq_length=1024, # we don't need to use this value, just add because Bridge training requires for LLMs + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + num_workers=10) + , + logger=LoggerConfig( + log_interval=10, + tensorboard_dir=tensorboard_dir, + 
log_timers_to_tensorboard=True, + ), + tokenizer=TokenizerConfig(tokenizer_type="NullTokenizer", vocab_size=DEFAULT_NULL_TOKENIZER_VOCAB_SIZE), + checkpoint=CheckpointConfig( + save_interval=2000, + save=checkpoint_dir, + load=checkpoint_dir, + ckpt_format="torch_dist", + fully_parallel_save=True, + ), + rng=RNGConfig(seed=1234), + comm_overlap=comm_overlap_config, + mixed_precision=precision_config, + ) + + return cfg From 74da525025d03d57a821bd0d6429e550d9955142 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 29 Oct 2025 20:19:02 -0700 Subject: [PATCH 04/17] add example commands --- example_commands.sh | 54 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 example_commands.sh diff --git a/example_commands.sh b/example_commands.sh new file mode 100644 index 0000000000..8f6a6ac048 --- /dev/null +++ b/example_commands.sh @@ -0,0 +1,54 @@ +### Finetuning +export HF_TOKEN=... +export WANDB_API_KEY=... +EXP_NAME=... +PRETRAINED_CHECKPOINT=/path/to/pretrained_checkpoint +CHECKPOINT_DIR=/path/to/checkpoint_dir +DATASET_PATH=/path/to/dataset +NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=4 examples/recipes/wan/pretrain_wan.py \ + model.tensor_model_parallel_size=1 \ + model.pipeline_model_parallel_size=1 \ + model.context_parallel_size=4 \ + model.sequence_parallel=false \ + dataset.path=${DATASET_PATH} \ + checkpoint.save=${CHECKPOINT_DIR} \ + checkpoint.load=${PRETRAINED_CHECKPOINT} \ + checkpoint.load_optim=false \ + checkpoint.save_interval=200 \ + optimizer.lr=5e-6 \ + optimizer.min_lr=5e-6 \ + train.eval_iters=0 \ + scheduler.lr_decay_style=constant \ + scheduler.lr_warmup_iters=0 \ + model.seq_length=2048 \ + dataset.seq_length=2048 \ + train.global_batch_size=1 \ + train.micro_batch_size=1 \ + dataset.global_batch_size=1 \ + dataset.micro_batch_size=1 \ + logger.log_interval=1 \ + logger.wandb_project="wan" \ + logger.wandb_exp_name=${EXP_NAME} \ + logger.wandb_save_dir=${CHECKPOINT_DIR} + + +### Inferencing +export 
HF_TOKEN=... +CHECKPOINT_DIR=/path/to/checkpoint_dir +T5_DIR=/path/to/t5_weights +VAE_DIR=/path/to/vae_weights +NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=4 examples/recipes/wan/inference_wan.py \ + --task t2v-1.3B \ + --sizes 832*480 \ + --checkpoint_dir ${CHECKPOINT_DIR} \ + --checkpoint_step 4000 \ + --t5_checkpoint_dir ${T5_DIR} \ + --vae_checkpoint_dir ${VAE_DIR} \ + --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ + --frame_nums 81 \ + --tensor_parallel_size 1 \ + --context_parallel_size 4 \ + --pipeline_parallel_size 1 \ + --sequence_parallel False \ + --base_seed 42 \ + --sample_steps 50 \ No newline at end of file From 01898124d526c7d421913e9bf9b13eadd634c875 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 29 Oct 2025 20:22:11 -0700 Subject: [PATCH 05/17] add example commands --- example_commands.sh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/example_commands.sh b/example_commands.sh index 8f6a6ac048..d221a66c46 100644 --- a/example_commands.sh +++ b/example_commands.sh @@ -1,3 +1,7 @@ +### Convert checkpoint +See examples/conversion/convert_wan_checkpoints.py for details. + + ### Finetuning export HF_TOKEN=... export WANDB_API_KEY=... 
From a2a2580da1f7510473de936898c44a6732eab18a Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Fri, 31 Oct 2025 07:38:46 -0700 Subject: [PATCH 06/17] runnable thd, without containers edits --- example_commands.sh | 16 ++++++-- .../bridge/data/wan/wan_taskencoder.py | 4 +- .../flow_matching/flow_inference_pipeline.py | 23 ++--------- .../models/wan/flow_matching/flow_pipeline.py | 22 +++++++---- src/megatron/bridge/models/wan/rope_utils.py | 8 ++-- src/megatron/bridge/models/wan/utils/utils.py | 39 +++++++++++++++++++ .../bridge/models/wan/wan_provider.py | 2 +- 7 files changed, 78 insertions(+), 36 deletions(-) diff --git a/example_commands.sh b/example_commands.sh index d221a66c46..56622697af 100644 --- a/example_commands.sh +++ b/example_commands.sh @@ -1,3 +1,9 @@ +### install dependencies +python3 -m pip install --upgrade diffusers +pip install easydict +pip install imageio +pip install imageio-ffmpeg + ### Convert checkpoint See examples/conversion/convert_wan_checkpoints.py for details. @@ -14,6 +20,7 @@ NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=4 examples/recipes/wan/pretrain_wan. model.pipeline_model_parallel_size=1 \ model.context_parallel_size=4 \ model.sequence_parallel=false \ + model.qkv_format=thd \ dataset.path=${DATASET_PATH} \ checkpoint.save=${CHECKPOINT_DIR} \ checkpoint.load=${PRETRAINED_CHECKPOINT} \ @@ -37,21 +44,24 @@ NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=4 examples/recipes/wan/pretrain_wan. ### Inferencing +# Download T5 weights and VAE weights from "https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B/tree/main" +# T5: models_t5_umt5-xxl-enc-bf16.pth, google +# VAE: Wan2.1_VAE.pth export HF_TOKEN=... 
CHECKPOINT_DIR=/path/to/checkpoint_dir T5_DIR=/path/to/t5_weights VAE_DIR=/path/to/vae_weights -NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=4 examples/recipes/wan/inference_wan.py \ +NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 examples/recipes/wan/inference_wan.py \ --task t2v-1.3B \ --sizes 832*480 \ --checkpoint_dir ${CHECKPOINT_DIR} \ - --checkpoint_step 4000 \ + --checkpoint_step 1000 \ --t5_checkpoint_dir ${T5_DIR} \ --vae_checkpoint_dir ${VAE_DIR} \ --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ --frame_nums 81 \ --tensor_parallel_size 1 \ - --context_parallel_size 4 \ + --context_parallel_size 1 \ --pipeline_parallel_size 1 \ --sequence_parallel False \ --base_seed 42 \ diff --git a/src/megatron/bridge/data/wan/wan_taskencoder.py b/src/megatron/bridge/data/wan/wan_taskencoder.py index 63f67bd721..a19f755617 100644 --- a/src/megatron/bridge/data/wan/wan_taskencoder.py +++ b/src/megatron/bridge/data/wan/wan_taskencoder.py @@ -104,18 +104,20 @@ def encode_sample(self, sample: dict) -> dict: ) - # def mock_encode_sample(self, sample: dict) -> dict: + # def encode_sample(self, sample: dict) -> dict: # # mock encode sample # video_latent = torch.tensor(torch.randn(16, 3, 104, 60), dtype=torch.float32) # # video_latent = torch.tensor(torch.randn(16, 24, 104, 60), dtype=torch.float32) # grid_size = torch.tensor([video_latent.shape[1] // self.patch_temporal, video_latent.shape[2] // self.patch_spatial, video_latent.shape[3] // self.patch_spatial], dtype=torch.int32) # context_embeddings = torch.tensor(torch.randn(512, 4096), dtype=torch.float32) + # video_metadata = {} # return dict( # video_latent=video_latent, # grid_size=grid_size, # context_embeddings=context_embeddings, + # video_metadata=video_metadata, # ) diff --git a/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py b/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py index 
fedef4f40d..893bc8a4cf 100644 --- a/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py +++ b/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py @@ -27,7 +27,7 @@ from megatron.bridge.models.wan.utils.utils import grid_sizes_calculation, patchify from megatron.core import parallel_state from torch.nn import functional as F -from megatron.bridge.models.wan.utils.utils import split_inputs_cp, cat_outputs_cp +from megatron.bridge.models.wan.utils.utils import cat_outputs_cp import math from typing import Tuple, Union @@ -99,9 +99,8 @@ def __init__( wan_checkpoint_dir = self._select_checkpoint_dir(checkpoint_dir, checkpoint_step) self.model = self.setup_model_from_checkpoint(wan_checkpoint_dir) - # DEBUGGING - # set qkv_format to to "thd" for context parallelism - self.model.config.qkv_format = "sbhd" + # if we use context parallelism, we need to set qkv_format to "thd" for context parallelism + self.model.config.qkv_format = "thd" # "sbhd" # set self.sp_size=1 for later use, just to respect the original Wan inference code self.sp_size = 1 @@ -486,12 +485,6 @@ def noop_no_sync(): timestep = [t] * batch_size timestep = torch.stack(timestep) - # run context parallelism slitting - if parallel_state.get_context_parallel_world_size() > 1: - latent_model_input = split_inputs_cp(latent_model_input, 0) - arg_c['context'] = split_inputs_cp(arg_c['context'], 0) - arg_null['context'] = split_inputs_cp(arg_null['context'], 0) - self.model.to(self.device) noise_pred_cond = self.forward_pp_step( latent_model_input, grid_sizes=grid_sizes, max_video_seq_len=max_video_seq_len, timestep=timestep, arg_c=arg_c) @@ -499,16 +492,6 @@ def noop_no_sync(): noise_pred_uncond = self.forward_pp_step( latent_model_input, grid_sizes=grid_sizes, max_video_seq_len=max_video_seq_len, timestep=timestep, arg_c=arg_null) - # run context parallelism gathering - if parallel_state.get_context_parallel_world_size() > 1: - arg_c['context'] = 
cat_outputs_cp(arg_c['context'], 0) # we need to cat the context back together for the next timestep - arg_null['context'] = cat_outputs_cp(arg_null['context'], 0) # we need to cat the context back together for the next timestep - # TODO: does this step slow down speed??? - noise_pred_cond = noise_pred_cond.contiguous() - noise_pred_uncond = noise_pred_uncond.contiguous() - noise_pred_cond = cat_outputs_cp(noise_pred_cond, 0) - noise_pred_uncond = cat_outputs_cp(noise_pred_uncond, 0) - # run unpatchify unpatchified_noise_pred_cond = noise_pred_cond unpatchified_noise_pred_cond = unpatchified_noise_pred_cond.transpose(0, 1) # bring sbhd -> bshd diff --git a/src/megatron/bridge/models/wan/flow_matching/flow_pipeline.py b/src/megatron/bridge/models/wan/flow_matching/flow_pipeline.py index 9d272a131e..f6b80c1f19 100644 --- a/src/megatron/bridge/models/wan/flow_matching/flow_pipeline.py +++ b/src/megatron/bridge/models/wan/flow_matching/flow_pipeline.py @@ -20,7 +20,7 @@ from torch import Tensor from diffusers import WanPipeline from megatron.bridge.models.wan.flow_matching.time_shift_utils import compute_density_for_timestep_sampling -from megatron.bridge.models.wan.utils.utils import patchify, split_inputs_cp +from megatron.bridge.models.wan.utils.utils import patchify, thd_split_inputs_cp class FlowPipeline: @@ -116,6 +116,14 @@ def training_step( # Generate noise noise = torch.randn_like(torch.ones([1, 16, grid_sizes[0][0], grid_sizes[0][1]*2, grid_sizes[0][2]*2], device=video_latents.device), dtype=torch.float32) noise = patchify(noise, (1, 2, 2))[0].unsqueeze(1) + # DEBUGGING + # because video_latents might be padded, we need to make sure noise also be padded to have the same shape + seq_noise = noise.shape[0] + seq_video = video_latents.shape[0] + if seq_noise < seq_video: + pad_len = seq_video - seq_noise + pad = torch.zeros((pad_len, noise.shape[1], noise.shape[2]), device=noise.device, dtype=noise.dtype) + noise = torch.cat([noise, pad], dim=0) # CRITICAL: 
Manual flow matching (NOT scheduler.add_noise!) # x_t = (1 - σ) * x_0 + σ * ε @@ -140,13 +148,13 @@ def training_step( # ======================================================================== # Split accross context parallelism # ======================================================================== - + if parallel_state.get_context_parallel_world_size() > 1: - video_latents = split_inputs_cp(video_latents, 0) - noisy_latents = split_inputs_cp(noisy_latents, 0) - noise = split_inputs_cp(noise, 0) - context_embeddings = split_inputs_cp(context_embeddings, 0) - split_loss_mask = split_inputs_cp(loss_mask, 0) + video_latents = thd_split_inputs_cp(video_latents, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) + noisy_latents = thd_split_inputs_cp(noisy_latents, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) + noise = thd_split_inputs_cp(noise, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) + context_embeddings = thd_split_inputs_cp(context_embeddings, packed_seq_params['cross_attention'].cu_seqlens_kv, parallel_state.get_context_parallel_group()) + split_loss_mask = thd_split_inputs_cp(loss_mask, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) else: video_latents = video_latents noisy_latents = noisy_latents diff --git a/src/megatron/bridge/models/wan/rope_utils.py b/src/megatron/bridge/models/wan/rope_utils.py index 93d0e93363..1f79d8bc7c 100644 --- a/src/megatron/bridge/models/wan/rope_utils.py +++ b/src/megatron/bridge/models/wan/rope_utils.py @@ -57,9 +57,9 @@ def forward(self, n_head, dim_head, max_seq_len, grid_sizes, device): # We concatenate them along dim=1 to get (max_seq_len, batch_size, 1, dim_head) freqs_real = torch.cat(freqs_real, dim=1) - # TODO: if run context/sequence related parallel, then we need to scatter - # the freqs_real to the context parallel region, 
using specific cp_rank split method - if parallel_state.get_context_parallel_world_size() > 1: - freqs_real = split_inputs_cp(freqs_real, 0) + # Note: + # when running context_parallel, which must use "thd" for qkv_format, + # we don't need to scatter the freqs to the context parallel region, + # because mcore rope_utils will automatically retrieve the correct freqs for each context parallel region return freqs_real \ No newline at end of file diff --git a/src/megatron/bridge/models/wan/utils/utils.py b/src/megatron/bridge/models/wan/utils/utils.py index 8551c6fc50..9fc8655592 100644 --- a/src/megatron/bridge/models/wan/utils/utils.py +++ b/src/megatron/bridge/models/wan/utils/utils.py @@ -3,6 +3,8 @@ from torch.distributed import all_gather import megatron.core.parallel_state as parallel_state import math +import torch.distributed as dist +import transformer_engine_torch as tex def grid_sizes_calculation( input_shape: Tuple[int, int, int], # (F_latents, H_latents, W_latents) @@ -126,3 +128,40 @@ def cat_outputs_cp(x: torch.Tensor, seq_dim: int) -> torch.Tensor: return gathered_tensors else: return x + + +def thd_split_inputs_cp(x: torch.Tensor, + cu_seqlens_q_padded: torch.Tensor, + cp_group: dist.ProcessGroup) -> torch.Tensor: + """ + Split a THD-packed tensor across CP ranks for inputs shaped [S, B, ...]. + + Args: + x: [S, B, ...] tensor (sequence first). + cu_seqlens_q_padded: 1D int32 THD cu_seqlens (padded) used for packing. + cp_group: context-parallel process group. + + Returns: + x_local: [S_local, B, ...] shard for this CP rank. + """ + # Move to [B, S, ...] to use THD partitioning along S + x_bs = x.transpose(0, 1).contiguous() # [B, S, ...] 
+ + total_S = x_bs.size(1) + cp_size = dist.get_world_size(cp_group) + cp_rank = dist.get_rank(cp_group) + + # Compute this rank's THD partition indices (same API as during gather) + idx = tex.thd_get_partitioned_indices( + cu_seqlens_q_padded, # int32 offsets + total_S, + cp_size, + cp_rank, + ).to(device=x_bs.device, dtype=torch.long) # [S_local] + + # Take the shard along sequence dim + x_local_bs = x_bs.index_select(dim=1, index=idx).contiguous() # [B, S_local, ...] + + # Return to [S, B, ...] + x_local = x_local_bs.transpose(0, 1).contiguous() # [S_local, B, ...] + return x_local \ No newline at end of file diff --git a/src/megatron/bridge/models/wan/wan_provider.py b/src/megatron/bridge/models/wan/wan_provider.py index a162a65c56..fab72afcc4 100644 --- a/src/megatron/bridge/models/wan/wan_provider.py +++ b/src/megatron/bridge/models/wan/wan_provider.py @@ -46,7 +46,7 @@ class WanModelProvider(TransformerConfig, ModelProviderMixin[VisionModule]): parallel_output: bool = True bf16: bool = False params_dtype: torch.dtype = torch.float32 - qkv_format: str = 'sbhd' + qkv_format: str = "sbhd" # "thd". 
NOTE: if we use context parallelism, we need to use "thd" # these attributes are unused for images/videos, we just set because bridge training requires for LLMs seq_length: int = 1024 share_embeddings_and_output_weights: bool = False From 77f2673f97a51e041696e122d5cad1118db39658 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Fri, 31 Oct 2025 07:57:16 -0700 Subject: [PATCH 07/17] update commands --- example_commands.sh | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/example_commands.sh b/example_commands.sh index 56622697af..f434a48ee8 100644 --- a/example_commands.sh +++ b/example_commands.sh @@ -1,9 +1,16 @@ +### set path to Megatron-Bridge +export MBRIDGE_PATH=/path/to/Megatron-Bridge +export PYTHONPATH="${MBRIDGE_PATH}/.:${MBRIDGE_PATH}/src/.:/opt/NeMo-Framework-Launcher/launcher_scripts" + + ### install dependencies +pip install --upgrade git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.15.0rc7 python3 -m pip install --upgrade diffusers pip install easydict pip install imageio pip install imageio-ffmpeg + ### Convert checkpoint See examples/conversion/convert_wan_checkpoints.py for details. From bf4b65252429d61de61f135cd1faafca9c63b283 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Fri, 31 Oct 2025 08:19:23 -0700 Subject: [PATCH 08/17] add example commands --- example_commands.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/example_commands.sh b/example_commands.sh index f434a48ee8..40986b538e 100644 --- a/example_commands.sh +++ b/example_commands.sh @@ -22,6 +22,7 @@ EXP_NAME=... PRETRAINED_CHECKPOINT=/path/to/pretrained_checkpoint CHECKPOINT_DIR=/path/to/checkpoint_dir DATASET_PATH=/path/to/dataset +cd $MBRIDGE_PATH NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=4 examples/recipes/wan/pretrain_wan.py \ model.tensor_model_parallel_size=1 \ model.pipeline_model_parallel_size=1 \ @@ -58,6 +59,7 @@ export HF_TOKEN=... 
CHECKPOINT_DIR=/path/to/checkpoint_dir T5_DIR=/path/to/t5_weights VAE_DIR=/path/to/vae_weights +cd $MBRIDGE_PATH NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 examples/recipes/wan/inference_wan.py \ --task t2v-1.3B \ --sizes 832*480 \ From 2b4fd60dfce13ec87451a23f8a5ae960e1768028 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Fri, 31 Oct 2025 08:51:37 -0700 Subject: [PATCH 09/17] add example commands --- example_commands.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example_commands.sh b/example_commands.sh index 40986b538e..c3613cd4c0 100644 --- a/example_commands.sh +++ b/example_commands.sh @@ -23,7 +23,7 @@ PRETRAINED_CHECKPOINT=/path/to/pretrained_checkpoint CHECKPOINT_DIR=/path/to/checkpoint_dir DATASET_PATH=/path/to/dataset cd $MBRIDGE_PATH -NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=4 examples/recipes/wan/pretrain_wan.py \ +NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=8 examples/recipes/wan/pretrain_wan.py \ model.tensor_model_parallel_size=1 \ model.pipeline_model_parallel_size=1 \ model.context_parallel_size=4 \ From a263c00c567d2194d6ec5a560bed3a1c0a871762 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Fri, 31 Oct 2025 14:35:33 -0700 Subject: [PATCH 10/17] fix example_commands.sh --- example_commands.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example_commands.sh b/example_commands.sh index c3613cd4c0..d95b75453a 100644 --- a/example_commands.sh +++ b/example_commands.sh @@ -4,7 +4,7 @@ export PYTHONPATH="${MBRIDGE_PATH}/.:${MBRIDGE_PATH}/src/.:/opt/NeMo-Framework-L ### install dependencies -pip install --upgrade git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.15.0rc7 +pip install --upgrade git+https://github.com/NVIDIA/Megatron-LM.git@ce8185cbbe04f38beb74360e878450f2e8525885 python3 -m pip install --upgrade diffusers pip install easydict pip install imageio From ea6bb12b41f495b10ed743671f66129e792a71c6 Mon Sep 17 00:00:00 2001 From: Tiancheng Zhao Date: Thu, 13 Nov 2025 20:32:05 +0000 Subject: [PATCH 
11/17] vace --- .gitignore | 2 + example_commands.sh | 160 ++-- .../conversion/convert_vace_checkpoints.py | 49 ++ .../conversion/convert_wan_checkpoints.py | 92 ++- examples/recipes/wan/inference_vace.py | 378 ++++++++++ .../bridge/models/hf_pretrained/wan.py | 33 +- .../flow_matching/flow_inference_pipeline.py | 701 +++++++++++++++++- .../bridge/models/wan/utils/preprocessor.py | 271 +++++++ src/megatron/bridge/models/wan/wan_bridge.py | 224 +++++- .../bridge/models/wan/wan_layer_spec.py | 237 +++++- src/megatron/bridge/models/wan/wan_model.py | 286 ++++++- .../bridge/models/wan/wan_provider.py | 29 +- vace.sh | 28 + 13 files changed, 2402 insertions(+), 88 deletions(-) create mode 100644 examples/conversion/convert_vace_checkpoints.py create mode 100644 examples/recipes/wan/inference_vace.py create mode 100644 src/megatron/bridge/models/wan/utils/preprocessor.py create mode 100644 vace.sh diff --git a/.gitignore b/.gitignore index 7e7db08e4c..d755ce3aa9 100644 --- a/.gitignore +++ b/.gitignore @@ -182,3 +182,5 @@ slurm*.out # UV package manager .uv/ + +*.mp4 \ No newline at end of file diff --git a/example_commands.sh b/example_commands.sh index d95b75453a..ee68def7b0 100644 --- a/example_commands.sh +++ b/example_commands.sh @@ -1,77 +1,129 @@ -### set path to Megatron-Bridge -export MBRIDGE_PATH=/path/to/Megatron-Bridge -export PYTHONPATH="${MBRIDGE_PATH}/.:${MBRIDGE_PATH}/src/.:/opt/NeMo-Framework-Launcher/launcher_scripts" +# ### set path to Megatron-Bridge +# export MBRIDGE_PATH=/path/to/Megatron-Bridge +# export PYTHONPATH="${MBRIDGE_PATH}/.:${MBRIDGE_PATH}/src/.:/opt/NeMo-Framework-Launcher/launcher_scripts" +export CUDA_VISIBLE_DEVICES=0 -### install dependencies -pip install --upgrade git+https://github.com/NVIDIA/Megatron-LM.git@ce8185cbbe04f38beb74360e878450f2e8525885 -python3 -m pip install --upgrade diffusers -pip install easydict -pip install imageio -pip install imageio-ffmpeg +# ### install dependencies +# pip install --upgrade 
git+https://github.com/NVIDIA/Megatron-LM.git@ce8185cbbe04f38beb74360e878450f2e8525885 +# python3 -m pip install --upgrade diffusers +# pip install easydict +# pip install imageio +# pip install imageio-ffmpeg -### Convert checkpoint -See examples/conversion/convert_wan_checkpoints.py for details. +# ### Convert checkpoint +# See examples/conversion/convert_wan_checkpoints.py for details. -### Finetuning -export HF_TOKEN=... -export WANDB_API_KEY=... -EXP_NAME=... -PRETRAINED_CHECKPOINT=/path/to/pretrained_checkpoint -CHECKPOINT_DIR=/path/to/checkpoint_dir -DATASET_PATH=/path/to/dataset -cd $MBRIDGE_PATH -NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=8 examples/recipes/wan/pretrain_wan.py \ - model.tensor_model_parallel_size=1 \ - model.pipeline_model_parallel_size=1 \ - model.context_parallel_size=4 \ - model.sequence_parallel=false \ - model.qkv_format=thd \ - dataset.path=${DATASET_PATH} \ - checkpoint.save=${CHECKPOINT_DIR} \ - checkpoint.load=${PRETRAINED_CHECKPOINT} \ - checkpoint.load_optim=false \ - checkpoint.save_interval=200 \ - optimizer.lr=5e-6 \ - optimizer.min_lr=5e-6 \ - train.eval_iters=0 \ - scheduler.lr_decay_style=constant \ - scheduler.lr_warmup_iters=0 \ - model.seq_length=2048 \ - dataset.seq_length=2048 \ - train.global_batch_size=1 \ - train.micro_batch_size=1 \ - dataset.global_batch_size=1 \ - dataset.micro_batch_size=1 \ - logger.log_interval=1 \ - logger.wandb_project="wan" \ - logger.wandb_exp_name=${EXP_NAME} \ - logger.wandb_save_dir=${CHECKPOINT_DIR} +# ### Finetuning +# export HF_TOKEN=... +# export WANDB_API_KEY=... +# EXP_NAME=... 
+# PRETRAINED_CHECKPOINT=/path/to/pretrained_checkpoint +# CHECKPOINT_DIR=/path/to/checkpoint_dir +# DATASET_PATH=/path/to/dataset +# cd $MBRIDGE_PATH +# NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=8 examples/recipes/wan/pretrain_wan.py \ +# model.tensor_model_parallel_size=1 \ +# model.pipeline_model_parallel_size=1 \ +# model.context_parallel_size=4 \ +# model.sequence_parallel=false \ +# model.qkv_format=thd \ +# dataset.path=${DATASET_PATH} \ +# checkpoint.save=${CHECKPOINT_DIR} \ +# checkpoint.load=${PRETRAINED_CHECKPOINT} \ +# checkpoint.load_optim=false \ +# checkpoint.save_interval=200 \ +# optimizer.lr=5e-6 \ +# optimizer.min_lr=5e-6 \ +# train.eval_iters=0 \ +# scheduler.lr_decay_style=constant \ +# scheduler.lr_warmup_iters=0 \ +# model.seq_length=2048 \ +# dataset.seq_length=2048 \ +# train.global_batch_size=1 \ +# train.micro_batch_size=1 \ +# dataset.global_batch_size=1 \ +# dataset.micro_batch_size=1 \ +# logger.log_interval=1 \ +# logger.wandb_project="wan" \ +# logger.wandb_exp_name=${EXP_NAME} \ +# logger.wandb_save_dir=${CHECKPOINT_DIR} ### Inferencing # Download T5 weights and VAE weights from "https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B/tree/main" # T5: models_t5_umt5-xxl-enc-bf16.pth, google # VAE: Wan2.1_VAE.pth -export HF_TOKEN=... 
-CHECKPOINT_DIR=/path/to/checkpoint_dir -T5_DIR=/path/to/t5_weights -VAE_DIR=/path/to/vae_weights -cd $MBRIDGE_PATH -NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 examples/recipes/wan/inference_wan.py \ + +CHECKPOINT_DIR=/opt/megatron_checkpoint_WAN +T5_DIR=~/.cache/huggingface/hub/models--Wan-AI--Wan2.1-T2V-1.3B/snapshots/37ec512624d61f7aa208f7ea8140a131f93afc9a +VAE_DIR=~/.cache/huggingface/hub/models--Wan-AI--Wan2.1-T2V-1.3B/snapshots/37ec512624d61f7aa208f7ea8140a131f93afc9a +# cd $MBRIDGE_PATH +# NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 examples/recipes/wan/inference_wan.py \ +# --task t2v-1.3B \ +# --sizes 832*480 \ +# --checkpoint_dir ${CHECKPOINT_DIR} \ +# --checkpoint_step 0000 \ +# --t5_checkpoint_dir ${T5_DIR} \ +# --vae_checkpoint_dir ${VAE_DIR} \ +# --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ +# --frame_nums 81 \ +# --tensor_parallel_size 1 \ +# --context_parallel_size 1 \ +# --pipeline_parallel_size 1 \ +# --sequence_parallel False \ +# --base_seed 42 \ +# --sample_steps 50 + + +NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 --rdzv-backend=c10d --rdzv-endpoint=localhost:0 examples/recipes/wan/inference_wan.py \ --task t2v-1.3B \ --sizes 832*480 \ --checkpoint_dir ${CHECKPOINT_DIR} \ - --checkpoint_step 1000 \ + --checkpoint_step 0000 \ --t5_checkpoint_dir ${T5_DIR} \ --vae_checkpoint_dir ${VAE_DIR} \ - --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ + --prompts "Beautiful maple leaves across the mountain during the autumn." 
\ --frame_nums 81 \ --tensor_parallel_size 1 \ --context_parallel_size 1 \ --pipeline_parallel_size 1 \ --sequence_parallel False \ --base_seed 42 \ - --sample_steps 50 \ No newline at end of file + --sample_steps 50 + + + # NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=2 examples/recipes/wan/inference_wan.py \ + # --task t2v-1.3B \ + # --sizes 832*480 \ + # --checkpoint_dir ${CHECKPOINT_DIR} \ + # --checkpoint_step 0000 \ + # --t5_checkpoint_dir ${T5_DIR} \ + # --vae_checkpoint_dir ${VAE_DIR} \ + # --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ + # --frame_nums 81 \ + # --tensor_parallel_size 1 \ + # --context_parallel_size 2 \ + # --pipeline_parallel_size 1 \ + # --sequence_parallel False \ + # --base_seed 42 \ + # --sample_steps 50 + + + # NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=2 examples/recipes/wan/inference_wan.py \ + # --task t2v-1.3B \ + # --sizes 832*480 \ + # --checkpoint_dir ${CHECKPOINT_DIR} \ + # --checkpoint_step 0000 \ + # --t5_checkpoint_dir ${T5_DIR} \ + # --vae_checkpoint_dir ${VAE_DIR} \ + # --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." 
\ + # --frame_nums 81 \ + # --tensor_parallel_size 1 \ + # --context_parallel_size 1 \ + # --pipeline_parallel_size 2 \ + # --sequence_parallel False \ + # --base_seed 42 \ + # --sample_steps 50 \ No newline at end of file diff --git a/examples/conversion/convert_vace_checkpoints.py b/examples/conversion/convert_vace_checkpoints.py new file mode 100644 index 0000000000..dd0eb6e378 --- /dev/null +++ b/examples/conversion/convert_vace_checkpoints.py @@ -0,0 +1,49 @@ +import os, random, multiprocessing as mp + +def main(): + from megatron.bridge.models.hf_pretrained.wan import PreTrainedVACE + from megatron.bridge.models.wan.wan_bridge import VACEBridge + from megatron.bridge.training.model_load_save import save_megatron_model + + # --- minimal torch.distributed single-rank env --- + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ.setdefault("MASTER_PORT", str(29500 + random.randint(0, 1000))) + os.environ.setdefault("RANK", "0") + os.environ.setdefault("WORLD_SIZE", "1") + os.environ.setdefault("LOCAL_RANK", "0") + + # --- build & load --- + hf = PreTrainedVACE("Wan-AI/Wan2.1-VACE-1.3B-Diffusers") + # hf = PreTrainedVACE("Wan-AI/Wan2.1-VACE-14B-Diffusers") + + bridge = VACEBridge() + provider = bridge.provider_bridge(hf) + provider.perform_initialization = False + + # If you're on GPU but want CPU init to reduce peak mem: + megatron_models = provider.provide_distributed_model( + wrap_with_ddp=False, use_cpu_initialization=True + ) + + bridge.load_weights_hf_to_megatron(hf, megatron_models) + + # Save Megatron-format checkpoint (this triggers async writer internally) + save_megatron_model( + megatron_models, + "/opt/megatron_checkpoint_VACE", + hf_tokenizer_path=None + ) + +if __name__ == "__main__": + # On Linux, prefer 'fork' to avoid re-importing the module on spawn. 
+ try: + mp.set_start_method("fork") + except RuntimeError: + # already set (fine on re-entry or non-Linux) + pass + + # If you’re on macOS/Windows and still want to be extra safe: + # mp.freeze_support() + + main() + diff --git a/examples/conversion/convert_wan_checkpoints.py b/examples/conversion/convert_wan_checkpoints.py index 4594ebaa5e..c4cf0bfcf3 100644 --- a/examples/conversion/convert_wan_checkpoints.py +++ b/examples/conversion/convert_wan_checkpoints.py @@ -1,20 +1,74 @@ -from megatron.bridge.models.hf_pretrained.wan import PreTrainedWAN -from megatron.bridge.models.wan.wan_bridge import WanBridge -from megatron.bridge.training.model_load_save import save_megatron_model -import os, random -os.environ["MASTER_ADDR"] = "127.0.0.1" -os.environ["MASTER_PORT"] = str(29500 + random.randint(0, 1000)) -os.environ["RANK"] = "0" -os.environ["WORLD_SIZE"] = "1" -os.environ["LOCAL_RANK"] = "0" -# +# from megatron.bridge.models.hf_pretrained.wan import PreTrainedWAN +# from megatron.bridge.models.wan.wan_bridge import WanBridge +# from megatron.bridge.training.model_load_save import save_megatron_model +# import os, random +# os.environ["MASTER_ADDR"] = "127.0.0.1" +# os.environ["MASTER_PORT"] = str(29500 + random.randint(0, 1000)) +# os.environ["RANK"] = "0" +# os.environ["WORLD_SIZE"] = "1" +# os.environ["LOCAL_RANK"] = "0" +# # # hf = PreTrainedWAN("Wan-AI/Wan2.1-T2V-1.3B-Diffusers") -hf = PreTrainedWAN("Wan-AI/Wan2.1-T2V-14B-Diffusers") -bridge = WanBridge() -# -provider = bridge.provider_bridge(hf) -provider.perform_initialization = False -megatron_models = provider.provide_distributed_model(wrap_with_ddp=False, use_cpu_initialization=True) -# -bridge.load_weights_hf_to_megatron(hf, megatron_models) -save_megatron_model(megatron_models, "/opt/megatron_checkpoint", hf_tokenizer_path=None) \ No newline at end of file +# # hf = PreTrainedWAN("Wan-AI/Wan2.1-T2V-14B-Diffusers") +# bridge = WanBridge() +# # +# provider = bridge.provider_bridge(hf) +# 
provider.perform_initialization = False +# megatron_models = provider.provide_distributed_model(wrap_with_ddp=False, use_cpu_initialization=True) +# # +# bridge.load_weights_hf_to_megatron(hf, megatron_models) +# save_megatron_model(megatron_models, "/opt/megatron_checkpoint", hf_tokenizer_path=None) + + +# convert_wan_checkpoints.py + +import os, random, multiprocessing as mp + +def main(): + from megatron.bridge.models.hf_pretrained.wan import PreTrainedWAN + from megatron.bridge.models.wan.wan_bridge import WanBridge + from megatron.bridge.training.model_load_save import save_megatron_model + + # --- minimal torch.distributed single-rank env --- + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ.setdefault("MASTER_PORT", str(29500 + random.randint(0, 1000))) + os.environ.setdefault("RANK", "0") + os.environ.setdefault("WORLD_SIZE", "1") + os.environ.setdefault("LOCAL_RANK", "0") + + # --- build & load --- + hf = PreTrainedWAN("Wan-AI/Wan2.1-T2V-1.3B-Diffusers") + # hf = PreTrainedWAN("Wan-AI/Wan2.1-T2V-14B-Diffusers") + + bridge = WanBridge() + provider = bridge.provider_bridge(hf) + provider.perform_initialization = False + + # If you're on GPU but want CPU init to reduce peak mem: + megatron_models = provider.provide_distributed_model( + wrap_with_ddp=False, use_cpu_initialization=True + ) + print(megatron_models[0]) + bridge.load_weights_hf_to_megatron(hf, megatron_models) + + + # Save Megatron-format checkpoint (this triggers async writer internally) + save_megatron_model( + megatron_models, + "/opt/megatron_checkpoint_WAN", + hf_tokenizer_path=None + ) + +if __name__ == "__main__": + # On Linux, prefer 'fork' to avoid re-importing the module on spawn. 
+ try: + mp.set_start_method("fork") + except RuntimeError: + # already set (fine on re-entry or non-Linux) + pass + + # If you’re on macOS/Windows and still want to be extra safe: + # mp.freeze_support() + + main() + diff --git a/examples/recipes/wan/inference_vace.py b/examples/recipes/wan/inference_vace.py new file mode 100644 index 0000000000..382cb2dd3a --- /dev/null +++ b/examples/recipes/wan/inference_vace.py @@ -0,0 +1,378 @@ +import argparse +import logging +import os +import sys +import warnings +from datetime import datetime + +warnings.filterwarnings('ignore') + +import random + +import torch +import torch.distributed as dist +from PIL import Image + +from megatron.bridge.models.wan.flow_matching.flow_inference_pipeline import VACEFlowInferencePipeline +from megatron.bridge.models.wan.inference.configs import SIZE_CONFIGS, SUPPORTED_SIZES, MAX_AREA_CONFIGS, WAN_CONFIGS +from megatron.bridge.models.wan.inference.utils.utils import cache_video, cache_image, str2bool + + +EXAMPLE_PROMPT = { + "vace-1.3B": { + "src_ref_images": 'assets/images/girl.png,assets/images/snake.png', + "prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。" + }, + "vace-14B": { + "src_ref_images": 'assets/images/girl.png,assets/images/snake.png', + "prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。" + } +} + + + + +def validate_args(args): + # Basic check + assert args.checkpoint_dir is not None, "Please specify the checkpoint directory." 
+ assert args.model_name in WAN_CONFIGS, f"Unsupport model name: {args.model_name}" + assert args.model_name in EXAMPLE_PROMPT, f"Unsupport model name: {args.model_name}" + + # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks. + if args.sample_steps is None: + args.sample_steps = 50 + + if args.sample_shift is None: + args.sample_shift = 16 + + # The default number of frames are 1 for text-to-image tasks and 81 for other tasks. + if args.frame_nums is None: + args.frame_nums = 81 + + args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(0, sys.maxsize) + # Size check + if args.sizes is not None and len(args.sizes) > 0: + for s in args.sizes: + assert s in SUPPORTED_SIZES[args.model_name], f"Unsupport size {s} for model name {args.model_name}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.model_name])}" + return args + + +def _parse_args(): + parser = argparse.ArgumentParser( + description="Generate a image or video from a text prompt or image using Wan" + ) + parser.add_argument( + "--model_name", + type=str, + default="vace-1.3B", + choices=list(WAN_CONFIGS.keys()), + help="The model name to run.") + parser.add_argument( + "--sizes", + type=str, + nargs="+", + default=None, + choices=list(SIZE_CONFIGS.keys()), + help="List of sizes to generate multiple images or videos (WIDTH*HEIGHT). Example: --sizes 1280*720 1920*1080" + ) + parser.add_argument( + "--frame_nums", + type=int, + nargs="+", + default=None, + help="List of frame counts (each should be 4n+1). Broadcasts if single value." + ) + parser.add_argument( + "--checkpoint_dir", + type=str, + default=None, + help="The path to the main VACE checkpoint directory.") + parser.add_argument( + "--checkpoint_step", + type=int, + default=None, + help=( + "Optional training step to load, e.g. 1800 -> iter_0001800. 
" + "If not provided, the latest (largest) step in --checkpoint_dir is used.") + ) + parser.add_argument( + "--t5_checkpoint_dir", + type=str, + default=None, + help="Optional directory containing T5 checkpoint/tokenizer") + parser.add_argument( + "--vae_checkpoint_dir", + type=str, + default=None, + help="Optional directory containing VAE checkpoint") + parser.add_argument( + "--offload_model", + type=str2bool, + default=None, + help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage." + ) + parser.add_argument( + "--t5_cpu", + action="store_true", + default=False, + help="Whether to place T5 model on CPU.") + parser.add_argument( + "--save_file", + type=str, + default=None, + help="The file to save the generated image or video to.") + parser.add_argument( + "--src_video", + type=str, + nargs="+", + default=None, + help="List of name of the source video. Default None.") + parser.add_argument( + "--src_mask", + type=str, + nargs="+", + default=None, + help="List of name of the source mask. Default None.") + parser.add_argument( + "--src_ref_images", + type=str, + nargs="+", + default=None, + help="List of list of the source reference images. Separated by ','. 
Default None.") + parser.add_argument( + "--prompts", + type=str, + nargs="+", + default=None, + help="List of prompt to generate the image or video from.") + parser.add_argument( + "--base_seed", + type=int, + default=-1, + help="The seed to use for generating the image or video.") + parser.add_argument( + "--sample_solver", + type=str, + default='unipc', + choices=['unipc', 'dpm++'], + help="The solver used to sample.") + parser.add_argument( + "--sample_steps", type=int, default=None, help="The sampling steps.") + parser.add_argument( + "--sample_shift", + type=float, + default=None, + help="Sampling shift factor for flow matching schedulers.") + parser.add_argument( + "--sample_guide_scale", + type=float, + default=5.0, + help="Classifier free guidance scale.") + parser.add_argument( + "--tensor_parallel_size", + type=int, + default=1, + help="Tensor parallel size.") + parser.add_argument( + "--context_parallel_size", + type=int, + default=1, + help="Context parallel size.") + parser.add_argument( + "--pipeline_parallel_size", + type=int, + default=1, + help="Pipeline parallel size.") + parser.add_argument( + "--sequence_parallel", + type=str2bool, + default=False, + help="Sequence parallel.") + + args = parser.parse_args() + + validate_args(args) + + return args + + +def _init_logging(rank): + # logging + if rank == 0: + # set format + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] %(levelname)s: %(message)s", + handlers=[logging.StreamHandler(stream=sys.stdout)]) + else: + logging.basicConfig(level=logging.ERROR) + + +def generate(args): + rank = int(os.getenv("RANK", 0)) + world_size = int(os.getenv("WORLD_SIZE", 1)) + local_rank = int(os.getenv("LOCAL_RANK", 0)) + device = local_rank + _init_logging(rank) + + if args.offload_model is None: + args.offload_model = False if world_size > 1 else True + logging.info( + f"offload_model is not specified, set to {args.offload_model}.") + if world_size > 1: + torch.cuda.set_device(local_rank) + 
dist.init_process_group( + backend="nccl", + init_method="env://", + rank=rank, + world_size=world_size) + + cfg = WAN_CONFIGS[args.model_name] + + logging.info(f"Generation job args: {args}") + logging.info(f"Generation model config: {cfg}") + + if dist.is_initialized(): + base_seed = [args.base_seed] if rank == 0 else [None] + dist.broadcast_object_list(base_seed, src=0) + args.base_seed = base_seed[0] + + if args.prompts is None: + prompts = [EXAMPLE_PROMPT[args.model_name]["prompt"]] + else: + prompts = args.prompts + + if args.src_video is None: + src_video = [EXAMPLE_PROMPT[args.model_name].get("src_video", None)] + else: + src_video = args.src_video + + if args.src_mask is None: + src_mask = [EXAMPLE_PROMPT[args.model_name].get("src_mask", None)] + else: + src_mask = args.src_mask + + if args.src_ref_images is None: + src_ref_images = [EXAMPLE_PROMPT[args.model_name].get("src_ref_images", None)] + else: + src_ref_images = args.src_ref_images + + # Resolve sizes list (default to first supported size for task) + if args.sizes is not None and len(args.sizes) > 0: + size_keys = args.sizes + else: + size_keys = [SUPPORTED_SIZES[args.model_name][0]] + + # Resolve frame counts list (default 81) + if args.frame_nums is not None and len(args.frame_nums) > 0: + frame_nums = args.frame_nums + else: + frame_nums = [81] + + # Enforce 1:1 pairing across lists + assert len(prompts) == len(size_keys) == len(frame_nums), ( + f"prompts ({len(prompts)}), sizes ({len(size_keys)}), and frame_nums ({len(frame_nums)}) " + f"must have the same length") + + logging.info("Creating VACE flow inference pipeline.") + pipeline = VACEFlowInferencePipeline( + config=cfg, + checkpoint_dir=args.checkpoint_dir, + checkpoint_step=args.checkpoint_step, + t5_checkpoint_dir=args.t5_checkpoint_dir, + vae_checkpoint_dir=args.vae_checkpoint_dir, + device_id=device, + rank=rank, + t5_cpu=args.t5_cpu, + tensor_parallel_size=args.tensor_parallel_size, + context_parallel_size=args.context_parallel_size, 
+ pipeline_parallel_size=args.pipeline_parallel_size, + sequence_parallel=args.sequence_parallel, + pipeline_dtype=torch.float32, + ) + + # DEBUGGING + rank = dist.get_rank() + if rank == 0: + print("tensor_parallel_size:", args.tensor_parallel_size) + print("context_parallel_size:", args.context_parallel_size) + print("pipeline_parallel_size:", args.pipeline_parallel_size) + print("sequence_parallel:", args.sequence_parallel) + print("\n\n\n") + + for i in range(len(src_video)): + sub_src_video, sub_src_mask, sub_src_ref_images = pipeline.prepare_source([src_video[i]], + [None], + [None], + frame_nums[i], SIZE_CONFIGS[size_keys[i]], device) + src_video[i], src_mask[i], src_ref_images[i] = *sub_src_video, *sub_src_mask, *sub_src_ref_images + + + logging.info( + f"Generating videos ...") + videos = pipeline.generate( + prompts=prompts, + input_frames=src_video, + input_masks=src_mask, + input_ref_images=src_ref_images, + sizes=[SIZE_CONFIGS[size] for size in size_keys], + frame_nums=frame_nums, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) + + if rank == 0: + for i, video in enumerate(videos): + formatted_experiment_name = (args.save_file) if args.save_file is not None else "DefaultExp" + formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S") + formatted_prompt = prompts[i].replace(" ", "_").replace("/", + "_")[:50] + suffix = '.mp4' + formatted_save_file = f"{args.model_name}_{formatted_experiment_name}_videoindex{int(i)}_size{size_keys[i].replace('*','x') if sys.platform=='win32' else size_keys[i]}_{formatted_prompt}_{formatted_time}" + suffix + + # if "t2v" in args.task: + logging.info(f"Saving generated video to {formatted_save_file}") + cache_video( + tensor=video[None], + save_file=formatted_save_file, + fps=cfg.sample_fps, + nrow=1, + normalize=True, + value_range=(-1, 1)) + + cache_video( + 
tensor=src_video[i][None], + save_file=f'{i}_src_video.mp4', + fps=cfg.sample_fps, + nrow=1, + normalize=True, + value_range=(-1, 1)) + logging.info(f"Saving src_video to {i}_src_video.mp4") + + cache_video( + tensor=src_mask[i][None], + save_file=f'{i}_src_mask.mp4', + fps=cfg.sample_fps, + nrow=1, + normalize=True, + value_range=(0, 1)) + logging.info(f"Saving src_mask to {i}_src_mask.mp4") + + if src_ref_images[i] is not None: + for j, ref_img in enumerate(src_ref_images[i]): + cache_image( + tensor=ref_img[:, 0, ...], + save_file=f'{i}_src_ref_image_{j}.png', + nrow=1, + normalize=True, + value_range=(-1, 1)) + logging.info(f"Saving src_ref_image_{j} to {i}_src_ref_image_{j}.png") + logging.info("Finished.") + + +if __name__ == "__main__": + args = _parse_args() + generate(args) diff --git a/src/megatron/bridge/models/hf_pretrained/wan.py b/src/megatron/bridge/models/hf_pretrained/wan.py index 97aa6f853c..d682c5cf07 100644 --- a/src/megatron/bridge/models/hf_pretrained/wan.py +++ b/src/megatron/bridge/models/hf_pretrained/wan.py @@ -15,7 +15,7 @@ from pathlib import Path from typing import Optional, Union -from diffusers import WanTransformer3DModel +from diffusers import WanTransformer3DModel, WanVACETransformer3DModel from transformers import AutoConfig from megatron.bridge.models.hf_pretrained.base import PreTrainedBase @@ -39,7 +39,7 @@ def model_name_or_path(self) -> str: # Model loading is optional for conversion; implemented for completeness def _load_model(self) -> WanTransformer3DModel: - return WanTransformer3DModel.from_pretrained(self.model_name_or_path) + return WanTransformer3DModel.from_pretrained(self.model_name_or_path, subfolder="transformer") # Config is required by the WAN bridge def _load_config(self) -> AutoConfig: @@ -48,5 +48,34 @@ def _load_config(self) -> AutoConfig: print(f"Loading config from {self.model_name_or_path}") return WanTransformer3DModel.from_pretrained(self.model_name_or_path, subfolder="transformer").config + + +class 
PreTrainedVACE(PreTrainedBase): + """ + Lightweight pretrained wrapper for Diffusers WAN models. + + Provides access to WAN config and state through the common PreTrainedBase API + so bridges can consume `.config` and `.state` uniformly. + """ + + def __init__(self, model_name_or_path: Union[str, Path], **kwargs): + self._model_name_or_path = str(model_name_or_path) + super().__init__(**kwargs) + + @property + def model_name_or_path(self) -> str: + return self._model_name_or_path + + # Model loading is optional for conversion; implemented for completeness + def _load_model(self) -> WanVACETransformer3DModel: + return WanVACETransformer3DModel.from_pretrained(self.model_name_or_path, subfolder="transformer") + + # Config is required by the WAN bridge + def _load_config(self) -> AutoConfig: + # WanTransformer3DModel returns a config-like object with required fields + + print(f"Loading config from {self.model_name_or_path}") + + return WanVACETransformer3DModel.from_pretrained(self.model_name_or_path, subfolder="transformer").config diff --git a/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py b/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py index 893bc8a4cf..2f82c7f962 100644 --- a/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py +++ b/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py @@ -9,13 +9,15 @@ from contextlib import contextmanager from functools import partial +from PIL import Image +import torchvision.transforms.functional as TF import torch import torch.cuda.amp as amp import torch.distributed as dist from tqdm import tqdm -from megatron.bridge.models.wan.wan_model import WanModel -from megatron.bridge.models.wan.wan_provider import WanModelProvider +from megatron.bridge.models.wan.wan_model import WanModel, VACEModel +from megatron.bridge.models.wan.wan_provider import WanModelProvider, VACEModelProvider from megatron.bridge.models.wan.modules.t5 import 
T5EncoderModel from megatron.bridge.models.wan.modules import WanVAE from megatron.bridge.models.wan.inference.utils.fm_solvers import ( @@ -32,6 +34,8 @@ import math from typing import Tuple, Union +from ..utils.preprocessor import VaceVideoProcessor + class FlowInferencePipeline: def __init__( @@ -162,9 +166,12 @@ def setup_model_from_checkpoint(self, checkpoint_dir): ) if isinstance(model, list): model = model[0] + # for i in list(model.state_dict().keys()): + # print(i) if hasattr(model, "module"): model = model.module - + # for ly in model.decoder.layers: + # print(ly.idx) return model def _select_checkpoint_dir(self, base_dir: str, checkpoint_step) -> str: @@ -549,3 +556,691 @@ def noop_no_sync(): dist.barrier() return videos if self.rank == 0 else None + + + + +class VACEFlowInferencePipeline: + + def __init__( + self, + config, + checkpoint_dir, + checkpoint_step=None, + t5_checkpoint_dir=None, + vae_checkpoint_dir=None, + device_id=0, + rank=0, + t5_cpu=False, + + tensor_parallel_size=1, + context_parallel_size=1, + pipeline_parallel_size=1, + sequence_parallel=False, + pipeline_dtype=torch.float32, + ): + r""" + Initializes the FlowInferencePipeline with the given parameters. + + Args: + config (EasyDict): + Object containing model parameters initialized from config.py + checkpoint_dir (`str`): + Path to directory containing model checkpoints + t5_checkpoint_dir (`str`, *optional*, defaults to None): + Optional directory containing T5 checkpoint and tokenizer; falls back to `checkpoint_dir` if None. + vae_checkpoint_dir (`str`, *optional*, defaults to None): + Optional directory containing VAE checkpoint; falls back to `checkpoint_dir` if None. + device_id (`int`, *optional*, defaults to 0): + Id of target GPU device + rank (`int`, *optional*, defaults to 0): + Process rank for distributed training + t5_cpu (`bool`, *optional*, defaults to False): + Whether to place T5 model on CPU. Only works without t5_fsdp. 
+ """ + self.device = torch.device(f"cuda:{device_id}") + self.config = config + self.rank = rank + self.t5_cpu = t5_cpu + self.tensor_parallel_size = tensor_parallel_size + self.context_parallel_size = context_parallel_size + self.pipeline_parallel_size = pipeline_parallel_size + self.sequence_parallel = sequence_parallel + self.pipeline_dtype = pipeline_dtype + self.num_train_timesteps = config.num_train_timesteps + self.param_dtype = config.param_dtype + + self.text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + device=torch.device('cpu'), + checkpoint_path=os.path.join(t5_checkpoint_dir, config.t5_checkpoint), + tokenizer_path=os.path.join(t5_checkpoint_dir, config.t5_tokenizer), + shard_fn=None) + + self.vae_stride = config.vae_stride + self.patch_size = config.patch_size + self.vae = WanVAE( + vae_pth=os.path.join(vae_checkpoint_dir, config.vae_checkpoint), + device=self.device) + + wan_checkpoint_dir = self._select_checkpoint_dir(checkpoint_dir, checkpoint_step) + self.model = self.setup_model_from_checkpoint(wan_checkpoint_dir) + + # if we use context parallelism, we need to set qkv_format to "thd" for context parallelism + self.model.config.qkv_format = "thd" # "sbhd" + + # set self.sp_size=1 for later use, just to respect the original Wan inference code + self.sp_size = 1 + + if dist.is_initialized(): + dist.barrier() + self.model.to(self.device) + + self.sample_neg_prompt = config.sample_neg_prompt + + self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(self.vae_stride, self.patch_size)]), + min_area=832 *480, + max_area=832 *480, + min_fps=self.config.sample_fps, + max_fps=self.config.sample_fps, + zero_start=True, + seq_len=32760, + keep_last=True) + + + def unpatchify(self, x: torch.Tensor, grid_sizes: torch.Tensor, out_dim: int) -> list[torch.Tensor]: + r""" + Reconstruct video tensors from patch embeddings into a list of videotensors. 
+ + Args: + x (torch.Tensor): + Tensor of patchified features, with shape [seq_len, c * pF * pH * pW] + grid_sizes (Tensor): + Original spatial-temporal grid dimensions before patching, + shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + list[torch.Tensor]: list of tensors, each with shape [c, F_latents, H_latents, W_latents] + """ + + c = out_dim + out = [] + for u, v in zip(x, grid_sizes.tolist()): + u = u[:math.prod(v)].view(*v, *self.patch_size, c) + u = torch.einsum('fhwpqrc->cfphqwr', u) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + out.append(u) + return out + + + def setup_model_from_checkpoint(self, checkpoint_dir): + provider = VACEModelProvider() + provider.tensor_model_parallel_size = self.tensor_parallel_size + provider.pipeline_model_parallel_size = self.pipeline_parallel_size + provider.context_parallel_size = self.context_parallel_size + provider.sequence_parallel = self.sequence_parallel + provider.pipeline_dtype = self.pipeline_dtype + # Once all overrides are set, finalize the model provider to ensure the post initialization logic is run + provider.finalize() + provider.initialize_model_parallel(seed=0) + + ## Read from megatron checkpoint + from megatron.bridge.training.model_load_save import load_megatron_model as _load_megatron_model + model = _load_megatron_model( + checkpoint_dir, + mp_overrides={ + "tensor_model_parallel_size": self.tensor_parallel_size, + "pipeline_model_parallel_size": self.pipeline_parallel_size, + "context_parallel_size": self.context_parallel_size, + "sequence_parallel": self.sequence_parallel, + "pipeline_dtype": self.pipeline_dtype, + }, + ) + if isinstance(model, list): + model = model[0] + if hasattr(model, "module"): + model = model.module + return model + + def _select_checkpoint_dir(self, base_dir: str, checkpoint_step) -> str: + """ + Resolve checkpoint directory: + - If checkpoint_step is provided, use base_dir/iter_{step:07d} + - Otherwise, pick the 
largest iter_######## subdirectory under base_dir + """ + if checkpoint_step is not None: + path = os.path.join(base_dir, f"iter_{int(checkpoint_step):07d}") + if os.path.isdir(path): + logging.info(f"Using specified checkpoint: {path}") + return path + raise FileNotFoundError(f"Specified checkpoint step {checkpoint_step} not found at {path}") + + if not os.path.isdir(base_dir): + raise FileNotFoundError(f"Checkpoint base directory does not exist: {base_dir}") + + pattern = re.compile(r"^iter_(\d+)$") + try: + _, latest_path = max( + ((int(pattern.match(e.name).group(1)), e.path) + for e in os.scandir(base_dir) + if e.is_dir() and pattern.match(e.name)), + key=lambda x: x[0], + ) + except ValueError: + raise FileNotFoundError( + f"No checkpoints found under {base_dir}. Expected subdirectories named like 'iter_0001800'.") + + logging.info(f"Auto-selected latest checkpoint: {latest_path}") + return latest_path + + + def vace_encode_frames(self, frames, ref_images, masks=None): + vae = self.vae + if ref_images is None: + ref_images = [None] * len(frames) + else: + assert len(frames) == len(ref_images) + + if masks is None: + latents = vae.encode(frames) + else: + masks = [torch.where(m > 0.5, 1.0, 0.0) for m in masks] + inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)] + reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)] + inactive = vae.encode(inactive) + reactive = vae.encode(reactive) + latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] + + cat_latents = [] + for latent, refs in zip(latents, ref_images): + if refs is not None: + if masks is None: + ref_latent = vae.encode(refs) + else: + ref_latent = vae.encode(refs) + ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent] + assert all([x.shape[1] == 1 for x in ref_latent]) + latent = torch.cat([*ref_latent, latent], dim=1) + cat_latents.append(latent) + return cat_latents + + + def vace_encode_masks(self, masks, ref_images=None): + vae_stride 
= self.vae_stride + if ref_images is None: + ref_images = [None] * len(masks) + else: + assert len(masks) == len(ref_images) + + result_masks = [] + for mask, refs in zip(masks, ref_images): + c, depth, height, width = mask.shape + new_depth = int((depth + 3) // vae_stride[0]) + height = 2 * (int(height) // (vae_stride[1] * 2)) + width = 2 * (int(width) // (vae_stride[2] * 2)) + + # reshape + mask = mask[0, :, :, :] + mask = mask.view( + depth, height, vae_stride[1], width, vae_stride[1] + ) # depth, height, 8, width, 8 + mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width + mask = mask.reshape( + vae_stride[1] * vae_stride[2], depth, height, width + ) # 8*8, depth, height, width + + # interpolation + mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0) + + if refs is not None: + length = len(refs) + mask_pad = torch.zeros_like(mask[:, :length, :, :]) + mask = torch.cat((mask_pad, mask), dim=1) + result_masks.append(mask) + return result_masks + + + def vace_latent(self, z, m): + return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] + + + def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device): + area = image_size[0] * image_size[1] + self.vid_proc.set_area(area) + if area == 1280*720: + self.vid_proc.set_seq_len(75600) + elif area == 832*480: + self.vid_proc.set_seq_len(32760) + else: + raise NotImplementedError(f'image_size {image_size} is not supported') + + image_size = (image_size[1], image_size[0]) + image_sizes = [] + for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)): + if sub_src_mask is not None and sub_src_video is not None: + src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask) + src_video[i] = src_video[i].to(device) + src_mask[i] = src_mask[i].to(device) + src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1) + image_sizes.append(src_video[i].shape[2:]) + 
elif sub_src_video is None: + src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device) + src_mask[i] = torch.ones_like(src_video[i], device=device) + image_sizes.append(image_size) + else: + src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video) + src_video[i] = src_video[i].to(device) + src_mask[i] = torch.ones_like(src_video[i], device=device) + image_sizes.append(src_video[i].shape[2:]) + + for i, ref_images in enumerate(src_ref_images): + if ref_images is not None: + image_size = image_sizes[i] + for j, ref_img in enumerate(ref_images): + if ref_img is not None: + ref_img = Image.open(ref_img).convert("RGB") + ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) + if ref_img.shape[-2:] != image_size: + canvas_height, canvas_width = image_size + ref_height, ref_width = ref_img.shape[-2:] + white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1] + scale = min(canvas_height / ref_height, canvas_width / ref_width) + new_height = int(ref_height * scale) + new_width = int(ref_width * scale) + resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1) + top = (canvas_height - new_height) // 2 + left = (canvas_width - new_width) // 2 + white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image + ref_img = white_canvas + src_ref_images[i][j] = ref_img.to(device) + return src_video, src_mask, src_ref_images + + + def decode_latent(self, latent, ref_images=None): + vae = self.vae + if ref_images is None: + ref_images = [None] * len(latent) + else: + assert len(latent) == len(ref_images) + + trimed_latent = [] + for lat, refs in zip(latent, ref_images): + if refs is not None: + lat = lat[:, len(refs):, :, :] + trimed_latent.append(lat) + + return vae.decode(trimed_latent) + + + def forward_pp_step( + self, + latent_model_input: torch.Tensor, + grid_sizes: list[Tuple[int, 
int, int]], + max_video_seq_len: int, + timestep: torch.Tensor, + vace_context: torch.Tensor, + arg_c: dict, + ) -> torch.Tensor: + """ + Forward pass supporting pipeline parallelism. + """ + + from megatron.core import parallel_state + from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage, recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank + + pp_world_size = parallel_state.get_pipeline_model_parallel_world_size() + is_pp_first = parallel_state.is_pipeline_first_stage(ignore_virtual=True) + is_pp_last = parallel_state.is_pipeline_last_stage(ignore_virtual=True) + + # PP=1: no pipeline parallelism + if pp_world_size == 1: + noise_pred_pp = self.model( + latent_model_input, + grid_sizes=grid_sizes, + t=timestep, + vace_context=vace_context, + **arg_c) + return noise_pred_pp + + # # PP>1: pipeline parallelism + # hidden_size = self.model.config.hidden_size + # batch_size = latent_model_input.shape[1] + # # noise prediction shape for communication between first and last pipeline stages + # noise_pred_pp_shape = list(latent_model_input.shape) + + # if is_pp_first: + # # First stage: compute multimodal + first PP slice, send activations, then receive sampled token + # hidden_states = self.model( + # latent_model_input, + # grid_sizes=grid_sizes, + # t=timestep, + # **arg_c) + # send_to_next_pipeline_rank(hidden_states) + + # noise_pred_pp = broadcast_from_last_pipeline_stage(noise_pred_pp_shape, dtype=torch.float32) + # return noise_pred_pp + + # if is_pp_last: + # # Last stage: recv activations, run final slice + output, sample, broadcast + # recv_buffer = torch.empty( + # (max_video_seq_len, batch_size, hidden_size), + # dtype=next(self.model.parameters()).dtype, + # device=latent_model_input[0].device, + # ) + # recv_from_prev_pipeline_rank_(recv_buffer) + # recv_buffer = recv_buffer.to(torch.bfloat16) # ???? 
+ # self.model.set_input_tensor(recv_buffer) + # noise_pred_pp = self.model( + # latent_model_input, + # grid_sizes=grid_sizes, + # t=timestep, + # **arg_c) + + # noise_pred_pp = broadcast_from_last_pipeline_stage(noise_pred_pp_shape, dtype=noise_pred_pp.dtype, tensor=noise_pred_pp.contiguous()) + # return noise_pred_pp + + # # Intermediate stages: recv -> run local slice -> send -> receive broadcast token + # recv_buffer = torch.empty( + # (max_video_seq_len, batch_size, hidden_size), + # dtype=next(self.model.parameters()).dtype, + # device=latent_model_input[0].device, + # ) + # recv_from_prev_pipeline_rank_(recv_buffer) + # recv_buffer = recv_buffer.to(torch.bfloat16) # ???? + # self.model.set_input_tensor(recv_buffer) + # hidden_states = self.model( + # latent_model_input, + # grid_sizes=grid_sizes, + # t=timestep, + # **arg_c) + # send_to_next_pipeline_rank(hidden_states) + + # noise_pred_pp = broadcast_from_last_pipeline_stage(noise_pred_pp_shape, dtype=torch.float32) + # return noise_pred_pp + + + def generate(self, + prompts, + input_frames, + input_masks, + input_ref_images, + sizes, + frame_nums, + shift=5.0, + sample_solver='unipc', + sampling_steps=50, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True): + r""" + Generates video frames from text prompt using diffusion process. + + Args: + prompts (`list[str]`): + Text prompt for content generation + Input_frames (`list[Tensor]`): + Input frames for content generation + Input_masks (`list[Tensor]`): + Input masks for content generation + Input_ref_images (`list[Tensor]`): + Input reference images for content generation + sizes (list[tuple[int, int]]): + Controls video resolution, (width,height). + frame_nums (`list[int]`): + How many frames to sample from a video. The number should be 4n+1 + shift (`float`, *optional*, defaults to 5.0): + Noise schedule shift parameter. Affects temporal dynamics + sample_solver (`str`, *optional*, defaults to 'unipc'): + Solver used to sample the video. 
+ sampling_steps (`int`, *optional*, defaults to 40): + Number of diffusion sampling steps. Higher values improve quality but slow generation + guide_scale (`float`, *optional*, defaults 5.0): + Classifier-free guidance scale. Controls prompt adherence vs. creativity + n_prompt (`str`, *optional*, defaults to ""): + Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` + seed (`int`, *optional*, defaults to -1): + Random seed for noise generation. If -1, use random seed. + offload_model (`bool`, *optional*, defaults to True): + If True, offloads models to CPU during generation to save VRAM + + Returns: + torch.Tensor: + Generated video frames tensor. Dimensions: (C, N, H, W) where: + - C: Color channels (3 for RGB) + - N: Number of frames (81) + - H: Frame height (from size) + - W: Frame width from size) + """ + + + # process source video, mask, reference image + vace_context0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks) + mask0 = self.vace_encode_masks(input_masks, input_ref_images) + vace_context = self.vace_latent(vace_context0, mask0) + + max_video_seq_len = 0 + seq_lens = [] + target_shapes = [] + for item in vace_context0: + target_shape = list(item.shape) + target_shape[0] = int(target_shape[0] / 2) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / + (self.patch_size[1] * self.patch_size[2]) * + target_shape[1] / self.sp_size) * self.sp_size + seq_lens.append(seq_len) + target_shapes.append(target_shape) + max_video_seq_len = max(seq_lens) + + vace_context = patchify(vace_context, self.patch_size) + # pad to have same length + for i in range(len(vace_context)): + vace_context[i] = F.pad(vace_context[i], (0, 0, 0, max_video_seq_len - vace_context[i].shape[0])) + vace_context = torch.stack(vace_context, dim=1) + + + if n_prompt == "": + n_prompt = self.sample_neg_prompt + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=self.device) + 
seed_g.manual_seed(seed) + + + ## process context + context_max_len = 512 + context_lens = [] + contexts = [] + contexts_null = [] + for prompt in prompts: + if not self.t5_cpu: + self.text_encoder.model.to(self.device) + context = self.text_encoder([prompt], self.device)[0] + context_null = self.text_encoder([n_prompt], self.device)[0] + if offload_model: + self.text_encoder.model.cpu() + else: + context = self.text_encoder([prompt], torch.device('cpu'))[0].to(self.device) + context_null = self.text_encoder([n_prompt], torch.device('cpu'))[0].to(self.device) + context_lens.append(context_max_len) # all samples have the same context_max_len + contexts.append(context) + contexts_null.append(context_null) + # pad to context_max_len tokens, and stack to a tensor of shape [s, b, hidden] + contexts = [F.pad(context, (0, 0, 0, context_max_len - context.shape[0])) for context in contexts] + contexts_null = [F.pad(context_null, (0, 0, 0, context_max_len - context_null.shape[0])) for context_null in contexts_null] + contexts = torch.stack(contexts, dim=1) + contexts_null = torch.stack(contexts_null, dim=1) + + + ## setup noise + noises = [] + for target_shape in target_shapes: + noises.append( + torch.randn( + target_shape[0], + target_shape[1], + target_shape[2], + target_shape[3], + dtype=torch.float32, + device=self.device, + generator=seed_g) + ) + + + # calculate grid_sizes + grid_sizes = [grid_sizes_calculation( + input_shape =u.shape[1:], + patch_size=self.model.patch_size, + ) for u in noises] + grid_sizes = torch.tensor(grid_sizes, dtype=torch.long) + + + @contextmanager + def noop_no_sync(): + yield + + no_sync = getattr(self.model, 'no_sync', noop_no_sync) + + # evaluation mode + with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): + + if sample_solver == 'unipc': + # Create a prototype scheduler to compute shared timesteps + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + 
use_dynamic_shifting=False) + sample_scheduler.set_timesteps( + sampling_steps, device=self.device, shift=shift) + timesteps = sample_scheduler.timesteps + + # Instantiate per-sample schedulers so each sample maintains its own state + batch_size_for_schedulers = len(noises) + schedulers = [] + for _ in range(batch_size_for_schedulers): + s = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + s.set_timesteps(sampling_steps, device=self.device, shift=shift) + schedulers.append(s) + elif sample_solver == 'dpm++': + sample_scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + timesteps, _ = retrieve_timesteps( + sample_scheduler, + device=self.device, + sigmas=sampling_sigmas) + else: + raise NotImplementedError("Unsupported solver.") + + # sample videos + latents = noises + + from megatron.core.packed_seq_params import PackedSeqParams + cu_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)]) + cu_q = cu_q.to(torch.int32).to(self.device) + cu_kv_self = cu_q + cu_kv_cross = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(context_lens), dim=0)]) + cu_kv_cross = cu_kv_cross.to(torch.int32).to(self.device) + packed_seq_params = { + "self_attention": PackedSeqParams( + cu_seqlens_q=cu_q, + cu_seqlens_kv=cu_kv_self, + qkv_format=self.model.config.qkv_format, + ), + "cross_attention": PackedSeqParams( + cu_seqlens_q=cu_q, + cu_seqlens_kv=cu_kv_cross, + qkv_format=self.model.config.qkv_format, + ), + } + + + arg_c = {'context': contexts, 'max_seq_len': max_video_seq_len, 'packed_seq_params': packed_seq_params} + arg_null = {'context': contexts_null, 'max_seq_len': max_video_seq_len, 'packed_seq_params': packed_seq_params} + + for _, t in enumerate(tqdm(timesteps)): + + batch_size = len(latents) + + # patchify latents + 
unpatchified_latents = latents + latents = patchify(latents, self.patch_size) + # pad to have same length + for i in range(batch_size): + latents[i] = F.pad(latents[i], (0, 0, 0, max_video_seq_len - latents[i].shape[0])) + latents = torch.stack(latents, dim=1) + + + latent_model_input = latents + timestep = [t] * batch_size + timestep = torch.stack(timestep) + + self.model.to(self.device) + noise_pred_cond = self.forward_pp_step( + latent_model_input, grid_sizes=grid_sizes, max_video_seq_len=max_video_seq_len, timestep=timestep, vace_context=vace_context, arg_c=arg_c) + + noise_pred_uncond = self.forward_pp_step( + latent_model_input, grid_sizes=grid_sizes, max_video_seq_len=max_video_seq_len, timestep=timestep, vace_context=vace_context, arg_c=arg_null) + + # run unpatchify + unpatchified_noise_pred_cond = noise_pred_cond + unpatchified_noise_pred_cond = unpatchified_noise_pred_cond.transpose(0, 1) # bring sbhd -> bshd + # when unpatchifying, the code will truncate the padded videos into the original video shape, based on the grid_sizes. + unpatchified_noise_pred_cond = self.unpatchify(unpatchified_noise_pred_cond, grid_sizes, self.vae.model.z_dim) + unpatchified_noise_pred_uncond = noise_pred_uncond + unpatchified_noise_pred_uncond = unpatchified_noise_pred_uncond.transpose(0, 1) # bring sbhd -> bshd + # when unpatchifying, the code will truncate the padded videos into the original video shape, based on the grid_sizes. 
+ unpatchified_noise_pred_uncond = self.unpatchify(unpatchified_noise_pred_uncond, grid_sizes, self.vae.model.z_dim) + + noise_preds = [] + for i in range(batch_size): + noise_pred = unpatchified_noise_pred_uncond[i] + guide_scale * ( + unpatchified_noise_pred_cond[i] - unpatchified_noise_pred_uncond[i]) + noise_preds.append(noise_pred) + + # step and update latents + latents = [] + for i in range(batch_size): + + if sample_solver == 'unipc': + temp_x0 = schedulers[i].step( + noise_preds[i].unsqueeze(0), + t, + unpatchified_latents[i].unsqueeze(0), + return_dict=False, + generator=seed_g)[0] + else: + temp_x0 = sample_scheduler.step( + noise_preds[i].unsqueeze(0), + t, + unpatchified_latents[i].unsqueeze(0), + return_dict=False, + generator=seed_g)[0] + latents.append(temp_x0.squeeze(0)) + + x0 = latents + if offload_model: + self.model.cpu() + torch.cuda.empty_cache() + if self.rank == 0: + videos = self.decode_latent(x0, input_ref_images) + else: + videos = None + + del noises, latents + if sample_solver == 'unipc': + del schedulers + else: + del sample_scheduler + if offload_model: + gc.collect() + torch.cuda.synchronize() + if dist.is_initialized(): + dist.barrier() + + return videos if self.rank == 0 else None diff --git a/src/megatron/bridge/models/wan/utils/preprocessor.py b/src/megatron/bridge/models/wan/utils/preprocessor.py new file mode 100644 index 0000000000..fc5ea6a740 --- /dev/null +++ b/src/megatron/bridge/models/wan/utils/preprocessor.py @@ -0,0 +1,271 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. 
"""Image and video preprocessing utilities for VACE conditioning inputs.

Provides:
  * ``VaceImageProcessor`` -- loads reference images, resizes/center-crops
    them to a latent-budget-compatible resolution, and normalizes to [-1, 1].
  * ``VaceVideoProcessor`` -- samples frame ids from a video under a latent
    sequence-length budget, then resizes/center-crops/normalizes the frames.
  * ``prepare_source`` -- fills in missing source videos/masks and letterboxes
    reference images onto a white canvas.
"""
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF


class VaceImageProcessor(object):
    """Loads and preprocesses reference images for VACE.

    Args:
        downsample: (df, dh, dw) VAE downsample factors; only the spatial
            factors ``downsample[1:]`` are used here.
        seq_len: latent token budget used to bound the output resolution.
    """

    def __init__(self, downsample=None, seq_len=None):
        self.downsample = downsample
        self.seq_len = seq_len

    def _pillow_convert(self, image, cvt_type='RGB'):
        """Convert a PIL image to `cvt_type`, compositing any alpha onto white."""
        if image.mode != cvt_type:
            if image.mode == 'P':
                # palette images may carry transparency; go through the alpha mode first
                image = image.convert(f'{cvt_type}A')
            if image.mode == f'{cvt_type}A':
                # flatten the alpha channel onto a white background
                bg = Image.new(cvt_type,
                               size=(image.width, image.height),
                               color=(255, 255, 255))
                bg.paste(image, (0, 0), mask=image)
                image = bg
            else:
                image = image.convert(cvt_type)
        return image

    def _load_image(self, img_path):
        """Open `img_path` as an RGB PIL image; returns None for empty/missing paths."""
        if img_path is None or img_path == '':
            return None
        img = Image.open(img_path)
        img = self._pillow_convert(img)
        return img

    def _resize_crop(self, img, oh, ow, normalize=True):
        """
        Resize, center crop, convert to tensor, and normalize.

        Returns a PIL image when `normalize` is False, otherwise a
        [-1, 1]-normalized tensor of shape (C, 1, H, W).
        """
        # resize and crop
        iw, ih = img.size
        if iw != ow or ih != oh:
            # resize so both dimensions cover the target, then crop the excess
            scale = max(ow / iw, oh / ih)
            img = img.resize(
                (round(scale * iw), round(scale * ih)),
                resample=Image.Resampling.LANCZOS
            )
            assert img.width >= ow and img.height >= oh

            # center crop
            x1 = (img.width - ow) // 2
            y1 = (img.height - oh) // 2
            img = img.crop((x1, y1, x1 + ow, y1 + oh))

        # normalize to [-1, 1] and add a singleton frame dim -> (C, 1, H, W)
        if normalize:
            img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(1)
        return img

    def _image_preprocess(self, img, oh, ow, normalize=True, **kwargs):
        return self._resize_crop(img, oh, ow, normalize)

    def load_image(self, data_key, **kwargs):
        """Load a single image; see `load_image_batch`."""
        return self.load_image_batch(data_key, **kwargs)

    def load_image_pair(self, data_key, data_key2, **kwargs):
        """Load two images with a shared output size; see `load_image_batch`."""
        return self.load_image_batch(data_key, data_key2, **kwargs)

    def load_image_batch(self, *data_key_batch, normalize=True, seq_len=None, **kwargs):
        """Load N images and size them all by the FIRST image and the token budget.

        Returns ``(*processed_images, (oh, ow))``.

        NOTE(review): `_load_image` can return None for an empty path, which
        would crash at `imgs[0].size` -- presumably callers always pass valid
        paths; confirm before relying on this for optional inputs.
        """
        seq_len = self.seq_len if seq_len is None else seq_len
        imgs = []
        for data_key in data_key_batch:
            img = self._load_image(data_key)
            imgs.append(img)
        w, h = imgs[0].size
        dh, dw = self.downsample[1:]

        # compute output size: the largest (oh, ow) divisible by (dh, dw) whose
        # latent grid fits within seq_len, never upscaling (scale capped at 1)
        scale = min(1., np.sqrt(seq_len / ((h / dh) * (w / dw))))
        oh = int(h * scale) // dh * dh
        ow = int(w * scale) // dw * dw
        assert (oh // dh) * (ow // dw) <= seq_len
        imgs = [self._image_preprocess(img, oh, ow, normalize) for img in imgs]
        return *imgs, (oh, ow)


class VaceVideoProcessor(object):
    """Samples and preprocesses video frames under a latent token budget.

    Args:
        downsample: (df, dh, dw) temporal/spatial VAE downsample factors.
        min_area / max_area: bounds on the output pixel area.
        min_fps / max_fps: bounds on the sampling frame rate.
        zero_start: if True, always sample starting at the beginning of the clip.
        seq_len: latent token budget (latent frames x latent height x latent width).
        keep_last: if True, use the duration-preserving frame-id strategy.
    """

    def __init__(self, downsample, min_area, max_area, min_fps, max_fps, zero_start, seq_len, keep_last, **kwargs):
        self.downsample = downsample
        self.min_area = min_area
        self.max_area = max_area
        self.min_fps = min_fps
        self.max_fps = max_fps
        self.zero_start = zero_start
        self.keep_last = keep_last
        self.seq_len = seq_len
        # at least one frame at min_area must fit within the token budget
        assert seq_len >= min_area / (self.downsample[1] * self.downsample[2])

    def set_area(self, area):
        """Pin both area bounds to a single value."""
        self.min_area = area
        self.max_area = area

    def set_seq_len(self, seq_len):
        self.seq_len = seq_len

    @staticmethod
    def resize_crop(video: torch.Tensor, oh: int, ow: int):
        """
        Resize, center crop and normalize for decord loaded video (torch.Tensor type)

        Parameters:
            video - video to process (torch.Tensor): Tensor from `reader.get_batch(frame_ids)`, in shape of (T, H, W, C)
            oh - target height (int)
            ow - target width (int)

        Returns:
            The processed video (torch.Tensor): Normalized tensor range [-1, 1], in shape of (C, T, H, W)
        """
        # permute ([t, h, w, c] -> [t, c, h, w])
        video = video.permute(0, 3, 1, 2)

        # resize and crop
        ih, iw = video.shape[2:]
        if ih != oh or iw != ow:
            # resize so both dimensions cover the target, then crop the excess
            scale = max(ow / iw, oh / ih)
            video = F.interpolate(
                video,
                size=(round(scale * ih), round(scale * iw)),
                mode='bicubic',
                antialias=True
            )
            assert video.size(3) >= ow and video.size(2) >= oh

            # center crop
            x1 = (video.size(3) - ow) // 2
            y1 = (video.size(2) - oh) // 2
            video = video[:, :, y1:y1 + oh, x1:x1 + ow]

        # permute ([t, c, h, w] -> [c, t, h, w]) and normalize from uint8 range to [-1, 1]
        video = video.transpose(0, 1).float().div_(127.5).sub_(1.)
        return video

    def _video_preprocess(self, video, oh, ow):
        return self.resize_crop(video, oh, ow)

    def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, rng):
        """Frame-id strategy for keep_last=False: fps capped at max_fps, random start.

        Returns ``(frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps)``.
        """
        target_fps = min(fps, self.max_fps)
        duration = frame_timestamps[-1].mean()
        x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
        h, w = y2 - y1, x2 - x1
        ratio = h / w
        df, dh, dw = self.downsample

        # per-frame latent area, capped by budget, max_area, and source resolution
        area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
        of = min(
            (int(duration * target_fps) - 1) // df + 1,
            int(self.seq_len / area_z)
        )

        # deduce target shape of the [latent video]
        target_area_z = min(area_z, int(self.seq_len / of))
        oh = round(np.sqrt(target_area_z * ratio))
        ow = int(target_area_z / oh)
        of = (of - 1) * df + 1
        oh *= dh
        ow *= dw

        # sample frame ids: pick the frame whose [start, end) interval contains
        # each evenly spaced timestamp
        target_duration = of / target_fps
        begin = 0. if self.zero_start else rng.uniform(0, duration - target_duration)
        timestamps = np.linspace(begin, begin + target_duration, of)
        frame_ids = np.argmax(np.logical_and(
            timestamps[:, None] >= frame_timestamps[None, :, 0],
            timestamps[:, None] < frame_timestamps[None, :, 1]
        ), axis=1).tolist()
        return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps

    def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng):
        """Frame-id strategy for keep_last=True: spread samples over the FULL clip.

        The effective fps is derived from the clip duration rather than capped,
        so the last frame is always reachable.
        """
        duration = frame_timestamps[-1].mean()
        x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
        h, w = y2 - y1, x2 - x1
        ratio = h / w
        df, dh, dw = self.downsample

        # per-frame latent area, capped by budget, max_area, and source resolution
        area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
        of = min(
            (len(frame_timestamps) - 1) // df + 1,
            int(self.seq_len / area_z)
        )

        # deduce target shape of the [latent video]
        target_area_z = min(area_z, int(self.seq_len / of))
        oh = round(np.sqrt(target_area_z * ratio))
        ow = int(target_area_z / oh)
        of = (of - 1) * df + 1
        oh *= dh
        ow *= dw

        # sample frame ids across the whole duration (inclusive upper bound so
        # the final frame can be selected)
        target_duration = duration
        target_fps = of / target_duration
        timestamps = np.linspace(0., target_duration, of)
        frame_ids = np.argmax(np.logical_and(
            timestamps[:, None] >= frame_timestamps[None, :, 0],
            timestamps[:, None] <= frame_timestamps[None, :, 1]
        ), axis=1).tolist()
        return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps

    def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng):
        """Dispatch to the keep_last or default frame-id strategy."""
        if self.keep_last:
            return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng)
        else:
            return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng)

    def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
        """Load a single video; see `load_video_batch`."""
        return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs)

    def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs):
        """Load two videos with shared frame ids and output size; see `load_video_batch`."""
        return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)

    def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, **kwargs):
        """Load N videos, sampling frame ids/size from the FIRST video's metadata.

        Returns ``(*processed_videos, frame_ids, (oh, ow), fps)``.

        NOTE(review): `hash()` on a str is salted per process (PYTHONHASHSEED),
        so this RNG seed is NOT reproducible across runs despite the `seed`
        argument -- confirm whether cross-run determinism is required here.
        """
        rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)
        # read video
        import decord
        decord.bridge.set_bridge('torch')
        readers = []
        for data_k in data_key_batch:
            reader = decord.VideoReader(data_k)
            readers.append(reader)

        # timestamps/fps/shape come from the first reader; frame count from the shortest
        fps = readers[0].get_avg_fps()
        length = min([len(r) for r in readers])
        frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)]
        frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
        h, w = readers[0].next().shape[:2]
        frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, frame_timestamps, h, w, crop_box, rng)

        # preprocess video: crop to bbox first, then resize/normalize
        videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
        videos = [self._video_preprocess(video, oh, ow) for video in videos]
        return *videos, frame_ids, (oh, ow), fps


def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size, device):
    """Fill missing video/mask slots and letterbox reference images.

    Mutates the three input lists in place and returns them:
      * a fully missing (video, mask) pair becomes a zero video plus an
        all-ones mask of `num_frames` frames at `image_size`;
      * each reference image that does not match `image_size` is scaled to
        fit (aspect ratio preserved) and pasted centered onto a white
        (value 1.0, i.e. [-1, 1] range) canvas.
    """
    for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
        if sub_src_video is None and sub_src_mask is None:
            src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
            src_mask[i] = torch.ones((1, num_frames, image_size[0], image_size[1]), device=device)
    for i, ref_images in enumerate(src_ref_images):
        if ref_images is not None:
            for j, ref_img in enumerate(ref_images):
                if ref_img is not None and ref_img.shape[-2:] != image_size:
                    canvas_height, canvas_width = image_size
                    ref_height, ref_width = ref_img.shape[-2:]
                    white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device)  # [-1, 1]
                    # scale to fit inside the canvas, preserving aspect ratio
                    scale = min(canvas_height / ref_height, canvas_width / ref_width)
                    new_height = int(ref_height * scale)
                    new_width = int(ref_width * scale)
                    resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
                    # paste centered
                    top = (canvas_height - new_height) // 2
                    left = (canvas_width - new_width) // 2
                    white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
                    src_ref_images[i][j] = white_canvas
    return src_video, src_mask, src_ref_images
@MegatronModelBridge.register_bridge(source=WanVACETransformer3DModel, target=VACEModel)
class VACEBridge(MegatronModelBridge):
    """
    Megatron Bridge for VACE model.

    As a user you would not use this bridge directly, but through `AutoBridge`.

    Example:
        >>> from megatron.bridge import AutoBridge
        >>> bridge = AutoBridge.from_hf_pretrained("WAN-3D-1.3B-v1")
        >>> provider = bridge.to_megatron_provider()
    """

    def provider_bridge(self, hf_pretrained: PreTrainedVACE) -> VACEModelProvider:
        """Translate the HF WanVACETransformer3DModel config into a VACEModelProvider.

        hidden_size / crossattn_emb_size are derived as heads * head_dim;
        weights are kept in fp32 on CPU (use_cpu_initialization) for conversion.
        """
        hf_config = hf_pretrained.config

        cls = VACEModelProvider

        provider = cls(
            num_layers=hf_config.num_layers,
            hidden_size=hf_config.num_attention_heads * hf_config.attention_head_dim,
            kv_channels=hf_config.attention_head_dim,
            num_query_groups=hf_config.num_attention_heads,
            crossattn_emb_size=hf_config.num_attention_heads * hf_config.attention_head_dim,
            ffn_hidden_size=hf_config.ffn_dim,
            num_attention_heads=hf_config.num_attention_heads,
            activation_func=openai_gelu,
            in_channels=hf_config.in_channels,
            out_channels=hf_config.out_channels,
            text_dim=hf_config.text_dim,
            # HF patch_size is (temporal, spatial, spatial)
            patch_spatial=hf_config.patch_size[1],
            patch_temporal=hf_config.patch_size[0],
            layernorm_epsilon=hf_config.eps,
            hidden_dropout=0,
            attention_dropout=0,
            use_cpu_initialization=True,
            freq_dim=hf_config.freq_dim,
            bf16=False,
            params_dtype=torch.float32,
            vace_in_channels=hf_config.vace_in_channels,
            vace_layers=hf_config.vace_layers,
            base_num_layers=hf_config.num_layers,
        )

        return provider

    def mapping_registry(self) -> MegatronMappingRegistry:
        """Return MegatronMappingRegistry containing parameter mappings from HF to Megatron format.

        Returns:
            MegatronMappingRegistry: Registry of parameter mappings
        """
        # Dictionary maps HF parameter names -> Megatron parameter names
        # Supports wildcard (*) patterns for layer-specific parameters
        param_mappings = {
            "scale_shift_table": "head.modulation",
            "patch_embedding.weight": "patch_embedding.weight",
            "patch_embedding.bias": "patch_embedding.bias",
            "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
            "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
            "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
            "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
            "condition_embedder.time_proj.weight": "time_projection.1.weight",
            "condition_embedder.time_proj.bias": "time_projection.1.bias",
            "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
            "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
            "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
            "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
            "blocks.*.scale_shift_table": "decoder.layers.*.adaLN.modulation",
            "blocks.*.attn1.to_out.0.weight": "decoder.layers.*.full_self_attention.linear_proj.weight",
            "blocks.*.attn1.to_out.0.bias": "decoder.layers.*.full_self_attention.linear_proj.bias",
            "blocks.*.attn1.norm_q.weight": "decoder.layers.*.full_self_attention.q_layernorm.weight",
            "blocks.*.attn1.norm_k.weight": "decoder.layers.*.full_self_attention.k_layernorm.weight",
            "blocks.*.attn2.to_q.weight": "decoder.layers.*.cross_attention.linear_q.weight",
            "blocks.*.attn2.to_q.bias": "decoder.layers.*.cross_attention.linear_q.bias",
            "blocks.*.attn2.to_out.0.weight": "decoder.layers.*.cross_attention.linear_proj.weight",
            "blocks.*.attn2.to_out.0.bias": "decoder.layers.*.cross_attention.linear_proj.bias",
            "blocks.*.attn2.norm_q.weight": "decoder.layers.*.cross_attention.q_layernorm.weight",
            "blocks.*.attn2.norm_k.weight": "decoder.layers.*.cross_attention.k_layernorm.weight",
            "blocks.*.norm2.weight": "decoder.layers.*.norm3.weight",
            "blocks.*.norm2.bias": "decoder.layers.*.norm3.bias",
            "blocks.*.ffn.net.0.proj.weight": "decoder.layers.*.mlp.linear_fc1.weight",
            "blocks.*.ffn.net.0.proj.bias": "decoder.layers.*.mlp.linear_fc1.bias",
            "blocks.*.ffn.net.2.weight": "decoder.layers.*.mlp.linear_fc2.weight",
            "blocks.*.ffn.net.2.bias": "decoder.layers.*.mlp.linear_fc2.bias",
            "proj_out.weight": "head.head.weight",
            "proj_out.bias": "head.head.bias",

            # VACE context-branch parameters
            "vace_patch_embedding.weight": "vace_patch_embedding.weight",
            "vace_patch_embedding.bias": "vace_patch_embedding.bias",
            # only the first VACE block has proj_in; it becomes the shared init projection
            "vace_blocks.0.proj_in.weight": "vace_init_proj.weight",
            "vace_blocks.0.proj_in.bias": "vace_init_proj.bias",
            "vace_blocks.*.scale_shift_table": "vace_decoder.layers.*.adaLN.modulation",
            "vace_blocks.*.attn1.to_out.0.weight": "vace_decoder.layers.*.full_self_attention.linear_proj.weight",
            "vace_blocks.*.attn1.to_out.0.bias": "vace_decoder.layers.*.full_self_attention.linear_proj.bias",
            "vace_blocks.*.attn1.norm_q.weight": "vace_decoder.layers.*.full_self_attention.q_layernorm.weight",
            "vace_blocks.*.attn1.norm_k.weight": "vace_decoder.layers.*.full_self_attention.k_layernorm.weight",
            "vace_blocks.*.attn2.to_q.weight": "vace_decoder.layers.*.cross_attention.linear_q.weight",
            "vace_blocks.*.attn2.to_q.bias": "vace_decoder.layers.*.cross_attention.linear_q.bias",
            "vace_blocks.*.attn2.to_out.0.weight": "vace_decoder.layers.*.cross_attention.linear_proj.weight",
            "vace_blocks.*.attn2.to_out.0.bias": "vace_decoder.layers.*.cross_attention.linear_proj.bias",
            "vace_blocks.*.attn2.norm_q.weight": "vace_decoder.layers.*.cross_attention.q_layernorm.weight",
            "vace_blocks.*.attn2.norm_k.weight": "vace_decoder.layers.*.cross_attention.k_layernorm.weight",
            "vace_blocks.*.norm2.weight": "vace_decoder.layers.*.norm3.weight",
            "vace_blocks.*.norm2.bias": "vace_decoder.layers.*.norm3.bias",
            "vace_blocks.*.ffn.net.0.proj.weight": "vace_decoder.layers.*.mlp.linear_fc1.weight",
            "vace_blocks.*.ffn.net.0.proj.bias": "vace_decoder.layers.*.mlp.linear_fc1.bias",
            "vace_blocks.*.ffn.net.2.weight": "vace_decoder.layers.*.mlp.linear_fc2.weight",
            "vace_blocks.*.ffn.net.2.bias": "vace_decoder.layers.*.mlp.linear_fc2.bias",
            "vace_blocks.*.proj_out.weight": "vace_decoder.layers.*.context_proj.weight",
            "vace_blocks.*.proj_out.bias": "vace_decoder.layers.*.context_proj.bias",
        }

        # Custom WAN mapping to safely handle replicated params whose owning module
        # does not expose a top-level `.weight` (e.g., Head.modulation)
        class _ReplicatedByParamNameMapping(ReplicatedMapping):
            def hf_to_megatron(self, hf_weights, megatron_module):
                """Resolve the exact Megatron param by name, move the HF weight onto
                its device/dtype, and broadcast it across TP ranks (rank 0 is the
                source; other ranks allocate a same-shaped buffer to receive)."""
                normalized_param = self._normalize_expert_param_name(self.megatron_param)
                _, target_param = get_module_and_param_from_name(megatron_module, normalized_param)

                target_device = target_param.device
                target_dtype = target_param.dtype

                hf_weights = hf_weights.to(device=target_device, dtype=target_dtype)
                if self.tp_size == 1:
                    return hf_weights

                if target_device.type == "cuda" and torch.cuda.is_available():
                    if target_device.index != torch.cuda.current_device():
                        hf_weights = hf_weights.to(torch.cuda.current_device())

                # non-source ranks only need a correctly shaped receive buffer
                if self.tp_rank > 0:
                    hf_weights = torch.empty_like(hf_weights)

                return self.broadcast_tensor_to_tp_ranks(hf_weights, src_rank=0)

        mapping_list = []
        # Convert each dictionary entry to AutoMapping(hf_param, megatron_param)
        for hf_param, megatron_param in param_mappings.items():
            if hf_param in {"scale_shift_table", "blocks.*.scale_shift_table", "vace_blocks.*.scale_shift_table", "proj_out.weight", "proj_out.bias"}:
                # Use WAN-specific replicated mapping that resolves the exact param
                mapping_list.append(_ReplicatedByParamNameMapping(hf_param=hf_param, megatron_param=megatron_param))
            else:
                mapping_list.append(AutoMapping(hf_param=hf_param, megatron_param=megatron_param))

        # Adding custom module types for AutoMapping
        AutoMapping.register_module_type("Linear", "replicated")
        AutoMapping.register_module_type("Conv3d", "replicated")
        AutoMapping.register_module_type("WanAdaLN", "replicated")
        AutoMapping.register_module_type("Head", "replicated")

        # Add special mappings that require parameter concatenation/transformation
        mapping_list.extend(
            [
                # QKV: Combine separate Q, K, V matrices into single QKV matrix
                QKVMapping(
                    q="blocks.*.attn1.to_q.weight",
                    k="blocks.*.attn1.to_k.weight",
                    v="blocks.*.attn1.to_v.weight",
                    megatron_param="decoder.layers.*.full_self_attention.linear_qkv.weight",
                ),
                # QKV bias: Combine separate Q, K, V bias into single QKV bias
                QKVMapping(
                    q="blocks.*.attn1.to_q.bias",
                    k="blocks.*.attn1.to_k.bias",
                    v="blocks.*.attn1.to_v.bias",
                    megatron_param="decoder.layers.*.full_self_attention.linear_qkv.bias",
                ),
                # K, V: Combine separate K, V matrices into single KV matrix
                KVMapping(
                    k="blocks.*.attn2.to_k.weight",
                    v="blocks.*.attn2.to_v.weight",
                    megatron_param="decoder.layers.*.cross_attention.linear_kv.weight",
                ),
                # K, V bias: Combine separate K, V bias into single KV bias
                KVMapping(
                    k="blocks.*.attn2.to_k.bias",
                    v="blocks.*.attn2.to_v.bias",
                    megatron_param="decoder.layers.*.cross_attention.linear_kv.bias",
                ),

                # QKV: Combine separate Q, K, V matrices into single QKV matrix (VACE branch)
                QKVMapping(
                    q="vace_blocks.*.attn1.to_q.weight",
                    k="vace_blocks.*.attn1.to_k.weight",
                    v="vace_blocks.*.attn1.to_v.weight",
                    megatron_param="vace_decoder.layers.*.full_self_attention.linear_qkv.weight",
                ),
                # QKV bias: Combine separate Q, K, V bias into single QKV bias (VACE branch)
                QKVMapping(
                    q="vace_blocks.*.attn1.to_q.bias",
                    k="vace_blocks.*.attn1.to_k.bias",
                    v="vace_blocks.*.attn1.to_v.bias",
                    megatron_param="vace_decoder.layers.*.full_self_attention.linear_qkv.bias",
                ),
                # K, V: Combine separate K, V matrices into single KV matrix (VACE branch)
                KVMapping(
                    k="vace_blocks.*.attn2.to_k.weight",
                    v="vace_blocks.*.attn2.to_v.weight",
                    megatron_param="vace_decoder.layers.*.cross_attention.linear_kv.weight",
                ),
                # K, V bias: Combine separate K, V bias into single KV bias (VACE branch)
                KVMapping(
                    k="vace_blocks.*.attn2.to_k.bias",
                    v="vace_blocks.*.attn2.to_v.bias",
                    megatron_param="vace_decoder.layers.*.cross_attention.linear_kv.bias",
                ),
            ]
        )
        return MegatronMappingRegistry(*mapping_list)
class VACEBaseLayer(WanLayerWithAdaLN):
    """WAN AdaLN transformer layer that additively injects a VACE hint.

    Takes input with size [s, b, h] and returns an output of the same size.
    `IndexTransformerBlock` assigns `self.idx` (the position of this layer's
    hint in the VACE hint stack, or None when the layer consumes no hint) and
    `self.context_scale`. The hint stack produced by the context branch is
    smuggled in through the `context_mask` argument.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: TransformerLayerSubmodules,
        layer_number: int = 1,
        hidden_dropout: float = None,
        pg_collection: Optional[ProcessGroupCollection] = None,
        vp_stage: Optional[int] = None,
    ):
        super().__init__(
            config=config, submodules=submodules, layer_number=layer_number, hidden_dropout=hidden_dropout, pg_collection=pg_collection, vp_stage=vp_stage
        )

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        context=None,
        context_mask=None,
        rotary_pos_emb=None,
        rotary_pos_cos=None,
        rotary_pos_sin=None,
        attention_bias=None,
        inference_params=None,
        packed_seq_params=None,
        sequence_len_offset=None,
        inference_context=None,
    ):
        # `context_mask` is repurposed to carry the VACE hint stack, so the
        # parent layer always receives context_mask=None.
        hidden_states, context = super().forward(
            hidden_states,
            attention_mask=attention_mask,
            context=context,
            context_mask=None,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            attention_bias=attention_bias,
            inference_params=inference_params,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
            inference_context=inference_context,
        )
        # BUG FIX: was `if self.idx:`, which is False for idx == 0 and thus
        # silently skipped hint injection for the first VACE-coupled layer
        # (IndexTransformerBlock asserts layer 0 is always in vace_layers and
        # maps it to hint index 0). Compare against None explicitly.
        if self.idx is not None:
            hidden_states = hidden_states + context_mask[self.idx] * self.context_scale

        return hidden_states, context


class VACEContextLayer(WanLayerWithAdaLN):
    """WAN AdaLN transformer layer for the VACE context (hint-producing) branch.

    Takes input with size [s, b, h] and returns an output of the same size.
    The incoming `hidden_states` is a stack (dim 0) of
    [hint_0, ..., hint_{k-1}, running_context]: the layer runs the parent
    forward on the running context, projects the result into a new hint, and
    returns the re-stacked [hint_0, ..., hint_k, running_context].
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: TransformerLayerSubmodules,
        layer_number: int = 1,
        hidden_dropout: float = None,
        pg_collection: Optional[ProcessGroupCollection] = None,
        vp_stage: Optional[int] = None,
    ):
        super().__init__(
            config=config, submodules=submodules, layer_number=layer_number, hidden_dropout=hidden_dropout, pg_collection=pg_collection, vp_stage=vp_stage
        )

        # per-layer output projection that turns this layer's activations into
        # a hint consumable by the paired VACEBaseLayer
        self.context_proj = build_module(
            submodules.context_proj,
            self.config.hidden_size,
            self.config.hidden_size,
            config=self.config,
            init_method=self.config.output_layer_init_method,
            bias=self.config.add_bias_linear,
            input_is_parallel=True,
            skip_bias_add=True,
            is_expert=False,
            tp_comm_buffer_name='proj',
            tp_group=self.pg_collection.tp,
        )

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        context=None,
        context_mask=None,
        rotary_pos_emb=None,
        rotary_pos_cos=None,
        rotary_pos_sin=None,
        attention_bias=None,
        inference_params=None,
        packed_seq_params=None,
        sequence_len_offset=None,
        inference_context=None,
    ):
        # the last slice of the stack is the running context; the others are
        # hints already produced by earlier context layers
        all_hidden_states = list(torch.unbind(hidden_states))
        hidden_states = all_hidden_states.pop(-1)
        hidden_states, context = super().forward(
            hidden_states,
            attention_mask=attention_mask,
            context=context,
            context_mask=None,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            attention_bias=attention_bias,
            inference_params=inference_params,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
            inference_context=inference_context,
        )
        # NOTE(review): skip_bias_add=True returns (output, bias); the addition
        # below assumes add_bias_linear is True so `bias` is a tensor -- with
        # biases disabled this would raise. Confirm against the provider config.
        hidden_states_proj, bias = self.context_proj(hidden_states)
        all_hidden_states += [hidden_states_proj + bias, hidden_states]
        hidden_states = torch.stack(all_hidden_states)

        return hidden_states, context
def _vace_layer_submodules(context_proj=None):
    """Shared Transformer-Engine submodule wiring for the VACE layer specs.

    `context_proj` is only supplied for the context branch; the base branch
    keeps the `WanWithAdaLNSubmodules` dataclass default (IdentityOp).
    """
    attn_params = {"attn_mask_type": AttnMaskType.padding}
    self_attn_spec = ModuleSpec(
        module=WanSelfAttention,
        params=attn_params,
        submodules=WanSelfAttentionSubmodules(
            linear_qkv=TEColumnParallelLinear,
            core_attention=TEDotProductAttention,
            linear_proj=TERowParallelLinear,
            layernorm_across_head=True,
            q_layernorm=TENorm,
            k_layernorm=TENorm,
        ),
    )
    cross_attn_spec = ModuleSpec(
        module=WanCrossAttention,
        params=attn_params,
        submodules=WanCrossAttentionSubmodules(
            linear_q=TEColumnParallelLinear,
            linear_kv=TEColumnParallelLinear,
            core_attention=TEDotProductAttention,
            linear_proj=TERowParallelLinear,
            layernorm_across_head=True,
            q_layernorm=TENorm,
            k_layernorm=TENorm,
        ),
    )
    # activation_func defaults to openai_gelu, equivalent to nn.GELU(approximate='tanh')
    mlp_spec = ModuleSpec(
        module=MLP,
        submodules=MLPSubmodules(
            linear_fc1=TEColumnParallelLinear,
            linear_fc2=TERowParallelLinear,
        ),
    )
    submodule_kwargs = dict(
        norm1=WanLayerNorm,
        norm3=WanLayerNorm,
        norm2=WanLayerNorm,
        full_self_attention=self_attn_spec,
        cross_attention=cross_attn_spec,
        mlp=mlp_spec,
    )
    if context_proj is not None:
        submodule_kwargs["context_proj"] = context_proj
    return WanWithAdaLNSubmodules(**submodule_kwargs)


def get_vace_base_block_with_transformer_engine_spec() -> ModuleSpec:
    """TE layer spec for the VACE base branch (hint consumer)."""
    return ModuleSpec(module=VACEBaseLayer, submodules=_vace_layer_submodules())


def get_vace_context_block_with_transformer_engine_spec() -> ModuleSpec:
    """TE layer spec for the VACE context branch (hint producer)."""
    return ModuleSpec(
        module=VACEContextLayer,
        submodules=_vace_layer_submodules(context_proj=TERowParallelLinear),
    )
+from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_layer import get_transformer_layer_offset +from megatron.core.utils import get_pg_rank + +class IndexTransformerBlock(TransformerBlock): + def __init__( + self, + config: TransformerConfig, + spec: Union[TransformerBlockSubmodules, ModuleSpec], + post_layer_norm: bool = True, + pre_process: bool = True, + post_process: bool = True, + pg_collection: ProcessGroupCollection = None, + vp_stage: Optional[int] = None, + ): + # Pass block id and context_scale + self.vace_layers = [i for i in range(0, config.num_layers, 2)] if config.vace_layers is None else config.vace_layers + print(self.vace_layers) + assert 0 in self.vace_layers + self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)} + + super().__init__( + config=config, + spec=spec, + post_layer_norm=post_layer_norm, + pre_process=pre_process, + post_process=post_process, + pg_collection=pg_collection, + vp_stage=vp_stage, + ) + + def _build_layers(self): + # Transformer layers. + # @jcasper can we improve how we deal with layer_number? + # currently it's only used in CoreAttention? 
+ # if self.apply_query_key_layer_scaling: + # coeff = self.layer_number + # self.norm_factor *= coeff + def build_layer(layer_spec, layer_number): + global_layer_number = layer_number + get_transformer_layer_offset( + self.config, self.vp_stage, get_pg_rank(self.pg_collection.pp) + ) # 1-based index + if self.config.heterogeneous_block_specs: + layer_config = self.config.get_config_for_layer(global_layer_number) + else: + layer_config = self.config + + # Get appropriate quantization context (FP8 and FP4 are mutually exclusive) + if layer_config.fp8: + quantization_context = get_fp8_context( + layer_config, global_layer_number - 1, is_init=True + ) + elif layer_config.fp4: + quantization_context = get_fp4_context( + layer_config, global_layer_number - 1, is_init=True + ) + else: + quantization_context = nullcontext() + + with quantization_context: + module = build_module( + layer_spec, + config=layer_config, + layer_number=layer_number, + pg_collection=self.pg_collection, + vp_stage=self.vp_stage, + ) + idx = global_layer_number - 1 + if idx in self.vace_layers: + module.idx = self.vace_layers_mapping[idx] + module.context_scale = self.config.context_scale + else: + module.idx = None + return module + + # offset is implicit in TransformerLayer + self.layers = torch.nn.ModuleList( + [ + build_layer(layer_spec, i + 1) + for i, layer_spec in enumerate(self.submodules.layer_specs) + ] + ) + + # @TODO: add back account_for_embedding_in_pipeline_split (see issue #293) + # In pipeline parallelism, we want to add this LN only to the last stage of the pipeline + # self.post_process and self.post_layer_norm guide this behavior + if self.submodules.layer_norm and self.post_process and self.post_layer_norm: + self.final_layernorm = build_module( + self.submodules.layer_norm, + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + else: + self.final_layernorm = None # Either this or nn.Identity + def sinusoidal_embedding_1d(dim, 
position): # preprocess assert dim % 2 == 0 @@ -168,7 +271,7 @@ def forward( """Forward pass. Args: - x List[Tensor]: list of vae encoded data (in_channel, f, h, w) + x List[Tensor]: list of vae encoded data (s, b, c * pF * pH * pW) grid_sizes List[Tuple[int, int, int]]: list of grid sizes (f, h, w) t Tensor: timesteps context List[Tensor]: list of context (text_len, hidden_size) @@ -187,7 +290,7 @@ def forward( if self.pre_process: # x.shape [s, b, c * pF * pH * pW] seq_len, batch_size, _ = x.shape - c = self.out_channels + c = self.in_channels pF, pH, pW = self.patch_size x = x.reshape(seq_len * batch_size, pF, pH, pW, c) # output: x.shape [s * b, pF, pH, pW, c] x = x.permute(0, 4, 1, 2, 3) # output: x.shape [s * b, c, pF, pH, pW] @@ -268,7 +371,7 @@ def set_input_tensor(self, input_tensor: Tensor) -> None: def sharded_state_dict( - self, prefix: str = "module.", sharded_offsets: tuple = (), metadata: Optional[Dict] = None + self, prefix: str = "", sharded_offsets: tuple = (), metadata: Optional[Dict] = None ) -> ShardedStateDict: """Sharded state dict implementation for GPTModel backward-compatibility (removing extra state). 
@@ -330,3 +433,178 @@ def _set_embedder_weights_replica_id( replica_id=replica_id, allow_shape_mismatch=False, ) + + +class VACEModel(WanModel): + def __init__( + self, + config: TransformerConfig, + pre_process: bool = True, + post_process: bool = True, + fp16_lm_cross_entropy: bool = False, + parallel_output: bool = True, + transformer_decoder_layer_spec=VACEBaseLayerspec, + vace_transformer_decoder_layer_spec=VACEContextLayerspec, + **kwargs, + ): + super().__init__( + config, + pre_process, + post_process, + fp16_lm_cross_entropy, + parallel_output, + transformer_decoder_layer_spec, + **kwargs + ) + + self.vace_in_channels = self.config.vace_in_channels + self.vace_transformer_decoder_layer_spec = vace_transformer_decoder_layer_spec() + + if self.pre_process: + self.vace_patch_embedding = nn.Conv3d( + self.vace_in_channels, self.config.hidden_size, kernel_size=self.patch_size, stride=self.patch_size) + + self.decoder = IndexTransformerBlock( + config=self.config, + spec=self.transformer_decoder_layer_spec, + pre_process=self.pre_process, + post_process=self.post_process, + post_layer_norm=False, + ) + # print(self.decoder) + self.vace_config = copy.deepcopy(self.config) + self.vace_config.num_layers = len(self.decoder.vace_layers) + self.vace_decoder = TransformerBlock( + config=self.vace_config, + spec=self.vace_transformer_decoder_layer_spec, + pre_process=self.pre_process, + post_process=self.post_process, + post_layer_norm=False, + ) + # print(self.vace_decoder.state_dict().keys()) + + self.vace_init_proj = nn.Linear(self.config.hidden_size, self.config.hidden_size) + + + def forward( + self, + x: Tensor, + grid_sizes: list[Tuple[int, int, int]], + t: Tensor, + context: Tensor, + vace_context: Tensor, + max_seq_len: int, + packed_seq_params: PackedSeqParams = None, + **kwargs, + ) -> Tensor: + """Forward pass. 
+ + Args: + x List[Tensor]: list of vae encoded data (s, b, c * pF * pH * pW) + grid_sizes List[Tuple[int, int, int]]: list of grid sizes (f, h, w) + t Tensor: timesteps + context List[Tensor]: list of context (text_len, hidden_size) + max_seq_len int: maximum sequence length + packed_seq_params PackedSeqParams: packed sequence parameters + + Returns: + Tensor: output tensor (still patchified) of shape [seq_len, batch_size, hidden_size] + """ + ################################# + ########## Wan forward ########## + + # ============= embedders ============= + + # run input embedding + if self.pre_process: + # x.shape [s, b, c * pF * pH * pW] + seq_len, batch_size, _ = x.shape + c = self.in_channels + pF, pH, pW = self.patch_size + x = x.reshape(seq_len * batch_size, pF, pH, pW, c) # output: x.shape [s * b, pF, pH, pW, c] + x = x.permute(0, 4, 1, 2, 3) # output: x.shape [s * b, c, pF, pH, pW] + x = self.patch_embedding(x) # output: x.shape [s * b, hidden_size, 1, 1, 1] + x = x.flatten(1) # output: x.shape [s * b, hidden_size] + x = x.reshape(seq_len, batch_size, -1) # output: x.shape [s, b, hidden_size] + + # vace_context.shape [s, b, c * pF * pH * pW] + vace_seq_len, _, _ = vace_context.shape + vace_c = self.vace_in_channels + # pF, pH, pW = self.patch_size + vace_context = vace_context.reshape(vace_seq_len * batch_size, pF, pH, pW, vace_c) # output: vace_context.shape [s * b, pF, pH, pW, c] + vace_context = vace_context.permute(0, 4, 1, 2, 3) # output: vace_context.shape [s * b, c, pF, pH, pW] + vace_context = self.vace_patch_embedding(vace_context) # output: vace_context.shape [s * b, hidden_size, 1, 1, 1] + vace_context = vace_context.flatten(1) # output: vace_context.shape [s * b, hidden_size] + vace_context = vace_context.reshape(vace_seq_len, batch_size, -1) # output: vace_context.shape [s, b, hidden_size] + vace_context = self.vace_init_proj(vace_context) + x + vace_context = vace_context.unsqueeze(0) + + # split sequence for sequence_parallel + # TODO: for 
PP, do we move scatter_to_sequence_parallel_region here or after "x = self.decoder.input_tensor" ??? + if self.config.sequence_parallel: + x = tensor_parallel.scatter_to_sequence_parallel_region(x) # output: x.shape [s * b // tp_size, hidden_size] + vace_context = tensor_parallel.scatter_to_sequence_parallel_region(vace_context) # output: vace_context.shape [s * b // tp_size, hidden_size] + + else: + # intermediate stage of pipeline + x = self.decoder.input_tensor + vace_context = self.vace_decoder.input_tensor + + # run context token embedding + + # time embeddings + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).to(x.dtype) + ) + e0 = self.time_projection(e).unflatten(1, (6, self.config.hidden_size)) + + # context embeddings + context = self.text_embedding(context) # shape [text_len, b, hidden_size] + + + # ============= decoder ============= + # calculate rotary pos emb + n_head, dim_head = self.num_heads, self.config.hidden_size // self.num_heads + rotary_pos_emb = self.rope_embeddings(n_head, dim_head, max_seq_len, grid_sizes, t.device) # output: rotary_pos_emb.shape [s, b, 1, dim_head] + + # run vace decoder + vace_context = self.vace_decoder( + hidden_states=vace_context, + attention_mask=e0, + context=context, + context_mask=None, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=None, + rotary_pos_sin=None, + packed_seq_params=packed_seq_params, + )[:-1] + + # run decoder + x = self.decoder( + hidden_states=x, + attention_mask=e0, + context=context, + context_mask=vace_context, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=None, + rotary_pos_sin=None, + packed_seq_params=packed_seq_params, + ) + + # return if not post_process + if not self.post_process: + return x + + # head + x = x.transpose(0, 1) # head expects shape [b, s, hidden_size] + x = self.head(x, e) # output: x.shape [b, s, c * pF * pH * pW] + x = x.transpose(0, 1) # reshape back to shape [s, b, c * pF * pH * pW] + + # gather outputs for sequence_parallel + # Note: in 
GPT models, because the vocab projection matrix is ColumnParallelLinear, the sequence is + # automatically gathered in ColumnParallelLinear forward pass. + # However, in Wan models, we need to gather the outputs manually. + if self.config.sequence_parallel: + x = tensor_parallel.gather_from_sequence_parallel_region(x) + + return x # output: x.shape [s, b, c * pF * pH * pW] \ No newline at end of file diff --git a/src/megatron/bridge/models/wan/wan_provider.py b/src/megatron/bridge/models/wan/wan_provider.py index fab72afcc4..2663745fa4 100644 --- a/src/megatron/bridge/models/wan/wan_provider.py +++ b/src/megatron/bridge/models/wan/wan_provider.py @@ -21,7 +21,7 @@ from megatron.bridge.models.model_provider import ModelProviderMixin from megatron.core.models.common.vision_module.vision_module import VisionModule -from megatron.bridge.models.wan.wan_model import WanModel +from megatron.bridge.models.wan.wan_model import WanModel, VACEModel logger = logging.getLogger(__name__) @@ -72,6 +72,33 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> WanMode model = WanModel + return model( + self, + pre_process=parallel_state.is_pipeline_first_stage(), + post_process=parallel_state.is_pipeline_last_stage(), + fp16_lm_cross_entropy=self.fp16_lm_cross_entropy, + parallel_output=self.parallel_output, + ) + + +@dataclass +class VACEModelProvider(WanModelProvider): + vace_layers: list = None + # vace_layers: list = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28] + vace_in_channels: int = 96 + base_num_layers: int = 30 + context_scale: float = 1.0 + + def provide(self, pre_process=None, post_process=None, vp_stage=None) -> VACEModel: + vp_size = self.virtual_pipeline_model_parallel_size + if vp_size: + p_size = self.pipeline_model_parallel_size + assert (self.num_layers // p_size) % vp_size == 0, ( + "Make sure the number of model chunks is the same across all pipeline stages." 
+ ) + + model = VACEModel + return model( self, pre_process=parallel_state.is_pipeline_first_stage(), diff --git a/vace.sh b/vace.sh new file mode 100644 index 0000000000..b913f9cb8c --- /dev/null +++ b/vace.sh @@ -0,0 +1,28 @@ +export CUDA_VISIBLE_DEVICES=0 + +### Inferencing +# Download T5 weights and VAE weights from "https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B/tree/main" +# T5: models_t5_umt5-xxl-enc-bf16.pth, google +# VAE: Wan2.1_VAE.pth + +CHECKPOINT_DIR=/opt/megatron_checkpoint_VACE +T5_DIR=~/.cache/huggingface/hub/models--Wan-AI--Wan2.1-T2V-1.3B/snapshots/37ec512624d61f7aa208f7ea8140a131f93afc9a +VAE_DIR=~/.cache/huggingface/hub/models--Wan-AI--Wan2.1-T2V-1.3B/snapshots/37ec512624d61f7aa208f7ea8140a131f93afc9a + +NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 --rdzv-backend=c10d --rdzv-endpoint=localhost:0 examples/recipes/wan/inference_vace.py \ + --model_name vace-1.3B \ + --sizes 832*480 \ + --src_video "test.mp4" \ + --src_mask "src_mask.mp4" \ + --checkpoint_dir ${CHECKPOINT_DIR} \ + --checkpoint_step 0000 \ + --t5_checkpoint_dir ${T5_DIR} \ + --vae_checkpoint_dir ${VAE_DIR} \ + --prompts "Two dogs hit each other during boxing." 
\ + --frame_nums 81 \ + --tensor_parallel_size 1 \ + --context_parallel_size 1 \ + --pipeline_parallel_size 1 \ + --sequence_parallel False \ + --base_seed 42 \ + --sample_steps 50 \ No newline at end of file From e8e30d2e9c779ab5eec4fe5e9ed0a5944c410dbb Mon Sep 17 00:00:00 2001 From: Tiancheng Zhao Date: Sat, 15 Nov 2025 05:13:43 +0000 Subject: [PATCH 12/17] hf verification --- examples/recipes/wan/inference_vace.py | 24 +++++++++--------- .../flow_matching/flow_inference_pipeline.py | 25 +++++++++++++++++++ .../bridge/models/wan/wan_layer_spec.py | 4 ++- vace.sh | 14 +++++------ 4 files changed, 47 insertions(+), 20 deletions(-) diff --git a/examples/recipes/wan/inference_vace.py b/examples/recipes/wan/inference_vace.py index 382cb2dd3a..a1d66f4003 100644 --- a/examples/recipes/wan/inference_vace.py +++ b/examples/recipes/wan/inference_vace.py @@ -238,22 +238,22 @@ def generate(args): args.base_seed = base_seed[0] if args.prompts is None: - prompts = [EXAMPLE_PROMPT[args.model_name]["prompt"]] + prompts = [None] else: prompts = args.prompts if args.src_video is None: - src_video = [EXAMPLE_PROMPT[args.model_name].get("src_video", None)] + src_video = [None] else: src_video = args.src_video if args.src_mask is None: - src_mask = [EXAMPLE_PROMPT[args.model_name].get("src_mask", None)] + src_mask = [None] else: src_mask = args.src_mask if args.src_ref_images is None: - src_ref_images = [EXAMPLE_PROMPT[args.model_name].get("src_ref_images", None)] + src_ref_images = [None] else: src_ref_images = args.src_ref_images @@ -302,8 +302,8 @@ def generate(args): for i in range(len(src_video)): sub_src_video, sub_src_mask, sub_src_ref_images = pipeline.prepare_source([src_video[i]], - [None], - [None], + [src_mask[i]], + [src_ref_images[i]], frame_nums[i], SIZE_CONFIGS[size_keys[i]], device) src_video[i], src_mask[i], src_ref_images[i] = *sub_src_video, *sub_src_mask, *sub_src_ref_images @@ -345,31 +345,31 @@ def generate(args): cache_video( tensor=src_video[i][None], - 
save_file=f'{i}_src_video.mp4', + save_file=f'{args.model_name}_{formatted_experiment_name}_index{i}_src_video_{formatted_time}.mp4', fps=cfg.sample_fps, nrow=1, normalize=True, value_range=(-1, 1)) - logging.info(f"Saving src_video to {i}_src_video.mp4") + logging.info(f"Saving src_video to {args.model_name}_{formatted_experiment_name}_index{i}_src_video_{formatted_time}.mp4") cache_video( tensor=src_mask[i][None], - save_file=f'{i}_src_mask.mp4', + save_file=f'{args.model_name}_{formatted_experiment_name}_index{i}_src_mask_{formatted_time}.mp4', fps=cfg.sample_fps, nrow=1, normalize=True, value_range=(0, 1)) - logging.info(f"Saving src_mask to {i}_src_mask.mp4") + logging.info(f"Saving src_mask to {args.model_name}_{formatted_experiment_name}_index{i}_src_mask_{formatted_time}.mp4") if src_ref_images[i] is not None: for j, ref_img in enumerate(src_ref_images[i]): cache_image( tensor=ref_img[:, 0, ...], - save_file=f'{i}_src_ref_image_{j}.png', + save_file=f'{args.model_name}_{formatted_experiment_name}_index{i}_src_ref_image_{j}_{formatted_time}.png', nrow=1, normalize=True, value_range=(-1, 1)) - logging.info(f"Saving src_ref_image_{j} to {i}_src_ref_image_{j}.png") + logging.info(f"Saving src_ref_image_{j} to {args.model_name}_{formatted_experiment_name}_index{i}_src_ref_image_{j}_{formatted_time}.png") logging.info("Finished.") diff --git a/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py b/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py index 2f82c7f962..7da300c706 100644 --- a/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py +++ b/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py @@ -1019,6 +1019,9 @@ def generate(self, vace_context0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks) mask0 = self.vace_encode_masks(input_masks, input_ref_images) vace_context = self.vace_latent(vace_context0, mask0) + + # # for huggingface inference, latent 
shape: B, C_latent, N/4, H/8, W/8 + # vace_context_hf = torch.stack(vace_context) max_video_seq_len = 0 seq_lens = [] @@ -1163,6 +1166,11 @@ def noop_no_sync(): arg_c = {'context': contexts, 'max_seq_len': max_video_seq_len, 'packed_seq_params': packed_seq_params} arg_null = {'context': contexts_null, 'max_seq_len': max_video_seq_len, 'packed_seq_params': packed_seq_params} + + from megatron.bridge.models.hf_pretrained.wan import PreTrainedVACE + hf = PreTrainedVACE("Wan-AI/Wan2.1-VACE-1.3B-Diffusers")._load_model().to(self.device) + + for _, t in enumerate(tqdm(timesteps)): batch_size = len(latents) @@ -1197,6 +1205,23 @@ def noop_no_sync(): # when unpatchifying, the code will truncate the padded videos into the original video shape, based on the grid_sizes. unpatchified_noise_pred_uncond = self.unpatchify(unpatchified_noise_pred_uncond, grid_sizes, self.vae.model.z_dim) + + # # for huggingface inference + # unpatchified_latents = torch.stack(latents) + # timestep = [t] * batch_size + # timestep = torch.stack(timestep) + # unpatchified_noise_pred_cond=hf(hidden_states=unpatchified_latents, + # timestep=timestep, + # encoder_hidden_states=contexts.transpose(0,1), + # control_hidden_states=vace_context_hf, + # return_dict=False)[0] + # unpatchified_noise_pred_uncond=hf(hidden_states=unpatchified_latents, + # timestep=timestep, + # encoder_hidden_states=contexts_null.transpose(0,1), + # control_hidden_states=vace_context_hf, + # return_dict=False)[0] + + noise_preds = [] for i in range(batch_size): noise_pred = unpatchified_noise_pred_uncond[i] + guide_scale * ( diff --git a/src/megatron/bridge/models/wan/wan_layer_spec.py b/src/megatron/bridge/models/wan/wan_layer_spec.py index b3b652ec1b..16d3931e29 100644 --- a/src/megatron/bridge/models/wan/wan_layer_spec.py +++ b/src/megatron/bridge/models/wan/wan_layer_spec.py @@ -606,8 +606,10 @@ def forward( ) # consider how to pass block id and context_scale # the context_tokens from context branch is stored in context_mask 
argument - if self.idx: + if self.idx is not None: hidden_states = hidden_states + context_mask[self.idx] * self.context_scale + # hidden_states = hidden_states + context_mask[self.idx] * 2.0 + # hidden_states = hidden_states + torch.rand_like(context_mask[self.idx]) * 0.05 return hidden_states, context diff --git a/vace.sh b/vace.sh index b913f9cb8c..9d8778ec17 100644 --- a/vace.sh +++ b/vace.sh @@ -1,4 +1,4 @@ -export CUDA_VISIBLE_DEVICES=0 +export CUDA_VISIBLE_DEVICES=0,1 ### Inferencing # Download T5 weights and VAE weights from "https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B/tree/main" @@ -6,21 +6,21 @@ export CUDA_VISIBLE_DEVICES=0 # VAE: Wan2.1_VAE.pth CHECKPOINT_DIR=/opt/megatron_checkpoint_VACE -T5_DIR=~/.cache/huggingface/hub/models--Wan-AI--Wan2.1-T2V-1.3B/snapshots/37ec512624d61f7aa208f7ea8140a131f93afc9a -VAE_DIR=~/.cache/huggingface/hub/models--Wan-AI--Wan2.1-T2V-1.3B/snapshots/37ec512624d61f7aa208f7ea8140a131f93afc9a +T5_DIR=/opt/Wan2.1-T2V-1.3B +VAE_DIR=/opt/Wan2.1-T2V-1.3B -NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 --rdzv-backend=c10d --rdzv-endpoint=localhost:0 examples/recipes/wan/inference_vace.py \ +NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=2 --rdzv-backend=c10d --rdzv-endpoint=localhost:0 examples/recipes/wan/inference_vace.py \ --model_name vace-1.3B \ --sizes 832*480 \ - --src_video "test.mp4" \ - --src_mask "src_mask.mp4" \ + --save_file "depth" \ + --src_video "src_video_depth.mp4" \ --checkpoint_dir ${CHECKPOINT_DIR} \ --checkpoint_step 0000 \ --t5_checkpoint_dir ${T5_DIR} \ --vae_checkpoint_dir ${VAE_DIR} \ --prompts "Two dogs hit each other during boxing." 
\ --frame_nums 81 \ - --tensor_parallel_size 1 \ + --tensor_parallel_size 2 \ --context_parallel_size 1 \ --pipeline_parallel_size 1 \ --sequence_parallel False \ From 59d3e990f0afdea9b5c1807991ad4b6b3361a0a7 Mon Sep 17 00:00:00 2001 From: Tiancheng Zhao Date: Tue, 18 Nov 2025 17:18:04 +0000 Subject: [PATCH 13/17] add support for tp and cp --- .../flow_matching/flow_inference_pipeline.py | 41 +++++++++++++- src/megatron/bridge/models/wan/utils/utils.py | 56 ++++++++++++++++++- .../bridge/models/wan/wan_layer_spec.py | 28 ++++++++-- .../bridge/models/wan/wan_provider.py | 3 +- vace.sh | 4 +- 5 files changed, 121 insertions(+), 11 deletions(-) diff --git a/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py b/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py index 7da300c706..ddf523f795 100644 --- a/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py +++ b/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py @@ -29,7 +29,7 @@ from megatron.bridge.models.wan.utils.utils import grid_sizes_calculation, patchify from megatron.core import parallel_state from torch.nn import functional as F -from megatron.bridge.models.wan.utils.utils import cat_outputs_cp +from megatron.bridge.models.wan.utils.utils import split_inputs_cp, cat_outputs_cp, thd_split_inputs_cp, thd_cat_outputs_cp import math from typing import Tuple, Union @@ -470,6 +470,12 @@ def noop_no_sync(): qkv_format=self.model.config.qkv_format, ), } + + + # context parallel + if parallel_state.get_context_parallel_world_size() > 1: + contexts = thd_split_inputs_cp(contexts, packed_seq_params['cross_attention'].cu_seqlens_kv, parallel_state.get_context_parallel_group()) + contexts_null = thd_split_inputs_cp(contexts_null, packed_seq_params['cross_attention'].cu_seqlens_kv, parallel_state.get_context_parallel_group()) arg_c = {'context': contexts, 'max_seq_len': max_video_seq_len, 'packed_seq_params': packed_seq_params} @@ -488,6 +494,11 
@@ def noop_no_sync(): latents = torch.stack(latents, dim=1) + # context parallel + if parallel_state.get_context_parallel_world_size() > 1: + latents = thd_split_inputs_cp(latents, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) + + latent_model_input = latents timestep = [t] * batch_size timestep = torch.stack(timestep) @@ -499,6 +510,13 @@ def noop_no_sync(): noise_pred_uncond = self.forward_pp_step( latent_model_input, grid_sizes=grid_sizes, max_video_seq_len=max_video_seq_len, timestep=timestep, arg_c=arg_null) + + # context parallel + if parallel_state.get_context_parallel_world_size() > 1: + noise_pred_cond = thd_cat_outputs_cp(noise_pred_cond, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) + noise_pred_uncond = thd_cat_outputs_cp(noise_pred_uncond, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) + + # run unpatchify unpatchified_noise_pred_cond = noise_pred_cond unpatchified_noise_pred_cond = unpatchified_noise_pred_cond.transpose(0, 1) # bring sbhd -> bshd @@ -1161,7 +1179,14 @@ def noop_no_sync(): qkv_format=self.model.config.qkv_format, ), } - + + + # context parallel + if parallel_state.get_context_parallel_world_size() > 1: + vace_context = thd_split_inputs_cp(vace_context, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) + contexts = thd_split_inputs_cp(contexts, packed_seq_params['cross_attention'].cu_seqlens_kv, parallel_state.get_context_parallel_group()) + contexts_null = thd_split_inputs_cp(contexts_null, packed_seq_params['cross_attention'].cu_seqlens_kv, parallel_state.get_context_parallel_group()) + arg_c = {'context': contexts, 'max_seq_len': max_video_seq_len, 'packed_seq_params': packed_seq_params} arg_null = {'context': contexts_null, 'max_seq_len': max_video_seq_len, 'packed_seq_params': packed_seq_params} @@ -1184,6 +1209,11 @@ def noop_no_sync(): 
latents = torch.stack(latents, dim=1) + # context parallel + if parallel_state.get_context_parallel_world_size() > 1: + latents = thd_split_inputs_cp(latents, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) + + latent_model_input = latents timestep = [t] * batch_size timestep = torch.stack(timestep) @@ -1195,6 +1225,13 @@ def noop_no_sync(): noise_pred_uncond = self.forward_pp_step( latent_model_input, grid_sizes=grid_sizes, max_video_seq_len=max_video_seq_len, timestep=timestep, vace_context=vace_context, arg_c=arg_null) + + # context parallel + if parallel_state.get_context_parallel_world_size() > 1: + noise_pred_cond = thd_cat_outputs_cp(noise_pred_cond, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) + noise_pred_uncond = thd_cat_outputs_cp(noise_pred_uncond, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) + + # run unpatchify unpatchified_noise_pred_cond = noise_pred_cond unpatchified_noise_pred_cond = unpatchified_noise_pred_cond.transpose(0, 1) # bring sbhd -> bshd diff --git a/src/megatron/bridge/models/wan/utils/utils.py b/src/megatron/bridge/models/wan/utils/utils.py index 9fc8655592..0f93526632 100644 --- a/src/megatron/bridge/models/wan/utils/utils.py +++ b/src/megatron/bridge/models/wan/utils/utils.py @@ -164,4 +164,58 @@ def thd_split_inputs_cp(x: torch.Tensor, # Return to [S, B, ...] x_local = x_local_bs.transpose(0, 1).contiguous() # [S_local, B, ...] - return x_local \ No newline at end of file + return x_local + + +def thd_cat_outputs_cp(x_local: torch.Tensor, + cu_seqlens_q_padded: torch.Tensor, + cp_group: dist.ProcessGroup) -> torch.Tensor: + """ + Reverse of thd_split_inputs_cp: gather THD-partitioned local shards back to global. + + Args: + x_local: [S_local, B, ...] tensor (this rank's shard, sequence first). + cu_seqlens_q_padded: 1D int32 THD cu_seqlens (padded) used for packing. 
+ cp_group: context-parallel process group. + + Returns: + x_global: [S, B, ...] tensor reassembled across CP ranks. + """ + # Work in [B, S_local, ...] for easy indexing along S + x_local_bs = x_local.transpose(0, 1).contiguous() # [B, S_local, ...] + + cp_size = dist.get_world_size(cp_group) + cp_rank = dist.get_rank(cp_group) + + # Discover total S from cu_seqlens (last value) + # (Matches 'total_S' used during split.) + total_S = int(cu_seqlens_q_padded[-1].item()) + + # All-gather local shards across CP group + gather_list = [torch.empty_like(x_local_bs) for _ in range(cp_size)] + dist.all_gather(gather_list, x_local_bs, group=cp_group) # each is [B, S_r, ...] + + # Compute per-rank indices once (same device/dtype as input) + # NOTE: tex.thd_get_partitioned_indices returns indices along S for that rank. + idx_list = [] + for r in range(cp_size): + idx_r = tex.thd_get_partitioned_indices( + cu_seqlens_q_padded, # int32 offsets + total_S, + cp_size, + r, + ).to(device=x_local_bs.device, dtype=torch.long) # [S_r] + idx_list.append(idx_r) + + # Allocate output [B, S, ...] and place each rank's slice back + out_shape = list(x_local_bs.shape) + out_shape[1] = total_S # replace S_local with S + x_global_bs = x_local_bs.new_zeros(out_shape) # [B, S, ...] + + # index_copy_ along S dimension + for shard, idx in zip(gather_list, idx_list): + x_global_bs.index_copy_(dim=1, index=idx, source=shard) + + # Return to [S, B, ...] + x_global = x_global_bs.transpose(0, 1).contiguous() # [S, B, ...] 
+ return x_global \ No newline at end of file diff --git a/src/megatron/bridge/models/wan/wan_layer_spec.py b/src/megatron/bridge/models/wan/wan_layer_spec.py index 16d3931e29..51dccdd2f3 100644 --- a/src/megatron/bridge/models/wan/wan_layer_spec.py +++ b/src/megatron/bridge/models/wan/wan_layer_spec.py @@ -33,6 +33,7 @@ TEColumnParallelLinear, TEDotProductAttention, TERowParallelLinear, + TELinear, ) from megatron.core.transformer.enums import AttnMaskType from megatron.core.transformer.identity_op import IdentityOp @@ -439,6 +440,8 @@ def __init__( submodules.full_self_attention, config=self.config, layer_number=layer_number, + cp_comm_type=config.cp_comm_type, + pg_collection=pg_collection, ) self.adaLN = WanAdaLN(config=self.config) @@ -636,18 +639,33 @@ def __init__( config=config, submodules=submodules, layer_number=layer_number, hidden_dropout=hidden_dropout, pg_collection=pg_collection, vp_stage=vp_stage ) + # self.context_proj = build_module( + # submodules.context_proj, + # self.config.hidden_size, + # self.config.hidden_size, + # config=self.config, + # init_method=self.config.output_layer_init_method, + # bias=self.config.add_bias_linear, + # input_is_parallel=False, + # skip_bias_add=True, + # is_expert=False, + # tp_comm_buffer_name='proj', + # tp_group=self.pg_collection.tp, + # ) self.context_proj = build_module( submodules.context_proj, self.config.hidden_size, self.config.hidden_size, + parallel_mode="duplicated", config=self.config, init_method=self.config.output_layer_init_method, bias=self.config.add_bias_linear, - input_is_parallel=True, - skip_bias_add=True, + skip_bias_add=False, + skip_weight_param_allocation=False, is_expert=False, + symmetric_ar_type=self.config.symmetric_ar_type, tp_comm_buffer_name='proj', - tp_group=self.pg_collection.tp, + tp_group=None, ) @@ -684,7 +702,7 @@ def forward( inference_context=inference_context, ) hidden_states_proj, bias = self.context_proj(hidden_states) - all_hidden_states += [hidden_states_proj + 
bias, hidden_states] + all_hidden_states += [hidden_states_proj, hidden_states] hidden_states = torch.stack(all_hidden_states) return hidden_states, context @@ -822,7 +840,7 @@ def get_vace_context_block_with_transformer_engine_spec() -> ModuleSpec: linear_fc2=TERowParallelLinear, ), ), - context_proj=TERowParallelLinear + context_proj=TELinear ), ) diff --git a/src/megatron/bridge/models/wan/wan_provider.py b/src/megatron/bridge/models/wan/wan_provider.py index 2663745fa4..c48b103bfb 100644 --- a/src/megatron/bridge/models/wan/wan_provider.py +++ b/src/megatron/bridge/models/wan/wan_provider.py @@ -46,7 +46,8 @@ class WanModelProvider(TransformerConfig, ModelProviderMixin[VisionModule]): parallel_output: bool = True bf16: bool = False params_dtype: torch.dtype = torch.float32 - qkv_format: str = "sbhd" # "thd". NOTE: if we use context parallelism, we need to use "thd" + # qkv_format: str = "sbhd" # "thd". NOTE: if we use context parallelism, we need to use "thd" + qkv_format: str = "thd" # these attributes are unused for images/videos, we just set because bridge training requires for LLMs seq_length: int = 1024 share_embeddings_and_output_weights: bool = False diff --git a/vace.sh b/vace.sh index 9d8778ec17..54ce722a3d 100644 --- a/vace.sh +++ b/vace.sh @@ -20,8 +20,8 @@ NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=2 --rdzv-backend=c10d --rdzv-endpoin --vae_checkpoint_dir ${VAE_DIR} \ --prompts "Two dogs hit each other during boxing." 
\ --frame_nums 81 \ - --tensor_parallel_size 2 \ - --context_parallel_size 1 \ + --tensor_parallel_size 1 \ + --context_parallel_size 2 \ --pipeline_parallel_size 1 \ --sequence_parallel False \ --base_seed 42 \ From afdd3c6c6ac50f735b57eee591203238d9d83e25 Mon Sep 17 00:00:00 2001 From: Tiancheng Zhao Date: Wed, 19 Nov 2025 01:53:04 +0000 Subject: [PATCH 14/17] add profiling --- example_commands.sh | 2 +- .../flow_matching/flow_inference_pipeline.py | 11 +++++++++-- .../bridge/models/wan/wan_layer_spec.py | 17 +++++++++++++++-- vace.sh | 4 ++-- 4 files changed, 27 insertions(+), 7 deletions(-) diff --git a/example_commands.sh b/example_commands.sh index ee68def7b0..9643b10d62 100644 --- a/example_commands.sh +++ b/example_commands.sh @@ -2,7 +2,7 @@ # export MBRIDGE_PATH=/path/to/Megatron-Bridge # export PYTHONPATH="${MBRIDGE_PATH}/.:${MBRIDGE_PATH}/src/.:/opt/NeMo-Framework-Launcher/launcher_scripts" -export CUDA_VISIBLE_DEVICES=0 +export CUDA_VISIBLE_DEVICES=0,1 # ### install dependencies # pip install --upgrade git+https://github.com/NVIDIA/Megatron-LM.git@ce8185cbbe04f38beb74360e878450f2e8525885 diff --git a/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py b/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py index ddf523f795..b6d0b555e8 100644 --- a/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py +++ b/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py @@ -576,6 +576,11 @@ def noop_no_sync(): return videos if self.rank == 0 else None +def log_checkpoint(tag): + torch.cuda.synchronize() + alloc = torch.cuda.memory_allocated() / 1024**3 + reserved = torch.cuda.memory_reserved() / 1024**3 + print(f"[{tag}] alloc={alloc:.2f} GB reserved={reserved:.2f} GB") class VACEFlowInferencePipeline: @@ -635,7 +640,8 @@ def __init__( checkpoint_path=os.path.join(t5_checkpoint_dir, config.t5_checkpoint), tokenizer_path=os.path.join(t5_checkpoint_dir, config.t5_tokenizer), shard_fn=None) - 
+ + log_checkpoint("before vae") self.vae_stride = config.vae_stride self.patch_size = config.patch_size self.vae = WanVAE( @@ -654,7 +660,8 @@ def __init__( if dist.is_initialized(): dist.barrier() self.model.to(self.device) - + log_checkpoint("after transformer") + self.sample_neg_prompt = config.sample_neg_prompt self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(self.vae_stride, self.patch_size)]), diff --git a/src/megatron/bridge/models/wan/wan_layer_spec.py b/src/megatron/bridge/models/wan/wan_layer_spec.py index 51dccdd2f3..9a6bd7c965 100644 --- a/src/megatron/bridge/models/wan/wan_layer_spec.py +++ b/src/megatron/bridge/models/wan/wan_layer_spec.py @@ -553,6 +553,11 @@ def forward( return output, context +def log_checkpoint(tag): + torch.cuda.synchronize() + alloc = torch.cuda.memory_allocated() / 1024**3 + reserved = torch.cuda.memory_reserved() / 1024**3 + print(f"[{tag}] alloc={alloc:.2f} GB reserved={reserved:.2f} GB") class VACEBaseLayer(WanLayerWithAdaLN): """A single transformer layer. 
@@ -593,6 +598,8 @@ def forward( inference_context=None, ): + log_checkpoint("before base") + hidden_states, context = super().forward( hidden_states, attention_mask=attention_mask, @@ -613,7 +620,9 @@ def forward( hidden_states = hidden_states + context_mask[self.idx] * self.context_scale # hidden_states = hidden_states + context_mask[self.idx] * 2.0 # hidden_states = hidden_states + torch.rand_like(context_mask[self.idx]) * 0.05 - + + log_checkpoint(f"after base {self.idx}") + return hidden_states, context @@ -685,6 +694,8 @@ def forward( inference_context=None, ): + log_checkpoint("before context") + all_hidden_states = list(torch.unbind(hidden_states)) hidden_states = all_hidden_states.pop(-1) hidden_states, context = super().forward( @@ -704,7 +715,9 @@ def forward( hidden_states_proj, bias = self.context_proj(hidden_states) all_hidden_states += [hidden_states_proj, hidden_states] hidden_states = torch.stack(all_hidden_states) - + + log_checkpoint("after context") + return hidden_states, context diff --git a/vace.sh b/vace.sh index 54ce722a3d..d879a4d7b3 100644 --- a/vace.sh +++ b/vace.sh @@ -9,7 +9,7 @@ CHECKPOINT_DIR=/opt/megatron_checkpoint_VACE T5_DIR=/opt/Wan2.1-T2V-1.3B VAE_DIR=/opt/Wan2.1-T2V-1.3B -NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=2 --rdzv-backend=c10d --rdzv-endpoint=localhost:0 examples/recipes/wan/inference_vace.py \ +NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 --rdzv-backend=c10d --rdzv-endpoint=localhost:0 examples/recipes/wan/inference_vace.py \ --model_name vace-1.3B \ --sizes 832*480 \ --save_file "depth" \ @@ -21,7 +21,7 @@ NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=2 --rdzv-backend=c10d --rdzv-endpoin --prompts "Two dogs hit each other during boxing." 
\ --frame_nums 81 \ --tensor_parallel_size 1 \ - --context_parallel_size 2 \ + --context_parallel_size 1 \ --pipeline_parallel_size 1 \ --sequence_parallel False \ --base_seed 42 \ From 59964562e4dc61d502171d4d7da28f5e16a3f7b7 Mon Sep 17 00:00:00 2001 From: Tiancheng Zhao Date: Thu, 20 Nov 2025 07:58:39 +0000 Subject: [PATCH 15/17] fix memory issues --- example_commands.sh | 4 +- .../flow_matching/flow_inference_pipeline.py | 6 ++ .../bridge/models/wan/wan_layer_spec.py | 24 +++-- src/megatron/bridge/models/wan/wan_model.py | 99 ++++++++++++++++++- vace.sh | 24 ++++- 5 files changed, 139 insertions(+), 18 deletions(-) diff --git a/example_commands.sh b/example_commands.sh index 9643b10d62..d244d9cf4e 100644 --- a/example_commands.sh +++ b/example_commands.sh @@ -58,8 +58,8 @@ export CUDA_VISIBLE_DEVICES=0,1 # VAE: Wan2.1_VAE.pth CHECKPOINT_DIR=/opt/megatron_checkpoint_WAN -T5_DIR=~/.cache/huggingface/hub/models--Wan-AI--Wan2.1-T2V-1.3B/snapshots/37ec512624d61f7aa208f7ea8140a131f93afc9a -VAE_DIR=~/.cache/huggingface/hub/models--Wan-AI--Wan2.1-T2V-1.3B/snapshots/37ec512624d61f7aa208f7ea8140a131f93afc9a +T5_DIR=/opt/Wan2.1-T2V-1.3B +VAE_DIR=/opt/Wan2.1-T2V-1.3B # cd $MBRIDGE_PATH # NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 examples/recipes/wan/inference_wan.py \ # --task t2v-1.3B \ diff --git a/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py b/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py index b6d0b555e8..608c5642d5 100644 --- a/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py +++ b/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py @@ -94,6 +94,8 @@ def __init__( tokenizer_path=os.path.join(t5_checkpoint_dir, config.t5_tokenizer), shard_fn=None) + log_checkpoint("before vae") + self.vae_stride = config.vae_stride self.patch_size = config.patch_size self.vae = WanVAE( @@ -112,6 +114,8 @@ def __init__( if dist.is_initialized(): dist.barrier() self.model.to(self.device) + 
+ log_checkpoint("after transformer") self.sample_neg_prompt = config.sample_neg_prompt @@ -642,6 +646,7 @@ def __init__( shard_fn=None) log_checkpoint("before vae") + self.vae_stride = config.vae_stride self.patch_size = config.patch_size self.vae = WanVAE( @@ -660,6 +665,7 @@ def __init__( if dist.is_initialized(): dist.barrier() self.model.to(self.device) + log_checkpoint("after transformer") self.sample_neg_prompt = config.sample_neg_prompt diff --git a/src/megatron/bridge/models/wan/wan_layer_spec.py b/src/megatron/bridge/models/wan/wan_layer_spec.py index 9a6bd7c965..e6cf0a30f9 100644 --- a/src/megatron/bridge/models/wan/wan_layer_spec.py +++ b/src/megatron/bridge/models/wan/wan_layer_spec.py @@ -480,6 +480,9 @@ def forward( sequence_len_offset=None, inference_context=None, ): + + # log_checkpoint("before layer") + # the timestep embedding is stored in attention_mask argument timestep_emb = attention_mask rope_emb = rotary_pos_emb @@ -550,7 +553,9 @@ def forward( # 'view' tensor. ??? 
output = make_viewless_tensor(inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True) # output = hidden_states - + + # log_checkpoint("after layer") + return output, context def log_checkpoint(tag): @@ -695,11 +700,11 @@ def forward( ): log_checkpoint("before context") - - all_hidden_states = list(torch.unbind(hidden_states)) - hidden_states = all_hidden_states.pop(-1) - hidden_states, context = super().forward( - hidden_states, + + # all_hidden_states = list(torch.unbind(hidden_states)) + # hidden_states = all_hidden_states.pop(-1) + hidden_state, context = super().forward( + hidden_states[self.idx], attention_mask=attention_mask, context=context, context_mask=None, @@ -712,9 +717,10 @@ def forward( sequence_len_offset=sequence_len_offset, inference_context=inference_context, ) - hidden_states_proj, bias = self.context_proj(hidden_states) - all_hidden_states += [hidden_states_proj, hidden_states] - hidden_states = torch.stack(all_hidden_states) + hidden_states[self.idx] = self.context_proj(hidden_state)[0] + hidden_states[self.idx + 1] = hidden_state + # all_hidden_states += [hidden_states_proj, hidden_states] + # hidden_states = torch.stack(all_hidden_states) log_checkpoint("after context") diff --git a/src/megatron/bridge/models/wan/wan_model.py b/src/megatron/bridge/models/wan/wan_model.py index 800a26c37f..ae8a456be0 100644 --- a/src/megatron/bridge/models/wan/wan_model.py +++ b/src/megatron/bridge/models/wan/wan_model.py @@ -46,7 +46,7 @@ from megatron.core.transformer.transformer_layer import get_transformer_layer_offset from megatron.core.utils import get_pg_rank -class IndexTransformerBlock(TransformerBlock): +class BaseTransformerBlock(TransformerBlock): def __init__( self, config: TransformerConfig, @@ -137,6 +137,96 @@ def build_layer(layer_spec, layer_number): ) else: self.final_layernorm = None # Either this or nn.Identity + +class ContextTransformerBlock(TransformerBlock): + def __init__( + self, + config: TransformerConfig, + 
spec: Union[TransformerBlockSubmodules, ModuleSpec], + post_layer_norm: bool = True, + pre_process: bool = True, + post_process: bool = True, + pg_collection: ProcessGroupCollection = None, + vp_stage: Optional[int] = None, + ): + # Pass block id and context_scale + self.vace_id = [i for i in range(0, config.num_layers)] if config.vace_layers is None else [i for i in range(0, len(config.vace_layers))] + print(self.vace_id) + assert 0 in self.vace_id + + super().__init__( + config=config, + spec=spec, + post_layer_norm=post_layer_norm, + pre_process=pre_process, + post_process=post_process, + pg_collection=pg_collection, + vp_stage=vp_stage, + ) + + def _build_layers(self): + # Transformer layers. + # @jcasper can we improve how we deal with layer_number? + # currently it's only used in CoreAttention? + # if self.apply_query_key_layer_scaling: + # coeff = self.layer_number + # self.norm_factor *= coeff + def build_layer(layer_spec, layer_number): + global_layer_number = layer_number + get_transformer_layer_offset( + self.config, self.vp_stage, get_pg_rank(self.pg_collection.pp) + ) # 1-based index + if self.config.heterogeneous_block_specs: + layer_config = self.config.get_config_for_layer(global_layer_number) + else: + layer_config = self.config + + # Get appropriate quantization context (FP8 and FP4 are mutually exclusive) + if layer_config.fp8: + quantization_context = get_fp8_context( + layer_config, global_layer_number - 1, is_init=True + ) + elif layer_config.fp4: + quantization_context = get_fp4_context( + layer_config, global_layer_number - 1, is_init=True + ) + else: + quantization_context = nullcontext() + + with quantization_context: + module = build_module( + layer_spec, + config=layer_config, + layer_number=layer_number, + pg_collection=self.pg_collection, + vp_stage=self.vp_stage, + ) + idx = global_layer_number - 1 + if idx in self.vace_id: + module.idx = idx + else: + module.idx = None + return module + + # offset is implicit in TransformerLayer + 
self.layers = torch.nn.ModuleList( + [ + build_layer(layer_spec, i + 1) + for i, layer_spec in enumerate(self.submodules.layer_specs) + ] + ) + + # @TODO: add back account_for_embedding_in_pipeline_split (see issue #293) + # In pipeline parallelism, we want to add this LN only to the last stage of the pipeline + # self.post_process and self.post_layer_norm guide this behavior + if self.submodules.layer_norm and self.post_process and self.post_layer_norm: + self.final_layernorm = build_module( + self.submodules.layer_norm, + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + else: + self.final_layernorm = None # Either this or nn.Identity def sinusoidal_embedding_1d(dim, position): # preprocess @@ -464,7 +554,7 @@ def __init__( self.vace_patch_embedding = nn.Conv3d( self.vace_in_channels, self.config.hidden_size, kernel_size=self.patch_size, stride=self.patch_size) - self.decoder = IndexTransformerBlock( + self.decoder = BaseTransformerBlock( config=self.config, spec=self.transformer_decoder_layer_spec, pre_process=self.pre_process, @@ -474,7 +564,7 @@ def __init__( # print(self.decoder) self.vace_config = copy.deepcopy(self.config) self.vace_config.num_layers = len(self.decoder.vace_layers) - self.vace_decoder = TransformerBlock( + self.vace_decoder = ContextTransformerBlock( config=self.vace_config, spec=self.vace_transformer_decoder_layer_spec, pre_process=self.pre_process, @@ -537,7 +627,8 @@ def forward( vace_context = vace_context.flatten(1) # output: vace_context.shape [s * b, hidden_size] vace_context = vace_context.reshape(vace_seq_len, batch_size, -1) # output: vace_context.shape [s, b, hidden_size] vace_context = self.vace_init_proj(vace_context) + x - vace_context = vace_context.unsqueeze(0) + # vace_context = vace_context.unsqueeze(0) + vace_context = torch.stack([vace_context] * (self.vace_config.num_layers + 1)) # split sequence for sequence_parallel # TODO: for PP, do we move 
scatter_to_sequence_parallel_region here or after "x = self.decoder.input_tensor" ??? diff --git a/vace.sh b/vace.sh index d879a4d7b3..5e6a37895f 100644 --- a/vace.sh +++ b/vace.sh @@ -9,7 +9,7 @@ CHECKPOINT_DIR=/opt/megatron_checkpoint_VACE T5_DIR=/opt/Wan2.1-T2V-1.3B VAE_DIR=/opt/Wan2.1-T2V-1.3B -NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 --rdzv-backend=c10d --rdzv-endpoint=localhost:0 examples/recipes/wan/inference_vace.py \ +NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=2 --rdzv-backend=c10d --rdzv-endpoint=localhost:0 examples/recipes/wan/inference_vace.py \ --model_name vace-1.3B \ --sizes 832*480 \ --save_file "depth" \ @@ -21,8 +21,26 @@ NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 --rdzv-backend=c10d --rdzv-endpoin --prompts "Two dogs hit each other during boxing." \ --frame_nums 81 \ --tensor_parallel_size 1 \ - --context_parallel_size 1 \ + --context_parallel_size 2 \ --pipeline_parallel_size 1 \ --sequence_parallel False \ --base_seed 42 \ - --sample_steps 50 \ No newline at end of file + --sample_steps 50 + +# NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=2 --rdzv-backend=c10d --rdzv-endpoint=localhost:0 examples/recipes/wan/inference_vace.py \ +# --model_name vace-1.3B \ +# --sizes 832*480 832*480 \ +# --save_file "depth" \ +# --src_video "src_video_depth.mp4" \ +# --checkpoint_dir ${CHECKPOINT_DIR} \ +# --checkpoint_step 0000 \ +# --t5_checkpoint_dir ${T5_DIR} \ +# --vae_checkpoint_dir ${VAE_DIR} \ +# --prompts "Two dogs hit each other during boxing." "Two dogs hit each other during boxing." 
\ +# --frame_nums 81 81 \ +# --tensor_parallel_size 1 \ +# --context_parallel_size 2 \ +# --pipeline_parallel_size 1 \ +# --sequence_parallel False \ +# --base_seed 42 \ +# --sample_steps 50 \ No newline at end of file From f25c81a02af5979ea281517f692f1fc269574b91 Mon Sep 17 00:00:00 2001 From: Tiancheng Zhao Date: Sat, 22 Nov 2025 23:26:06 +0000 Subject: [PATCH 16/17] enable batch size more than 1 --- examples/recipes/wan/inference_vace.py | 18 +++++++++--------- .../flow_matching/flow_inference_pipeline.py | 13 +++++++++++-- .../bridge/models/wan/wan_layer_spec.py | 8 ++++---- src/megatron/bridge/models/wan/wan_model.py | 3 +++ vace.sh | 14 +++++++------- 5 files changed, 34 insertions(+), 22 deletions(-) diff --git a/examples/recipes/wan/inference_vace.py b/examples/recipes/wan/inference_vace.py index a1d66f4003..b0ce120fec 100644 --- a/examples/recipes/wan/inference_vace.py +++ b/examples/recipes/wan/inference_vace.py @@ -240,32 +240,32 @@ def generate(args): if args.prompts is None: prompts = [None] else: - prompts = args.prompts + prompts = args.prompts * 8 if args.src_video is None: - src_video = [None] + src_video = [None] * len(prompts) else: - src_video = args.src_video + src_video = args.src_video * 8 if args.src_mask is None: - src_mask = [None] + src_mask = [None] * len(prompts) else: - src_mask = args.src_mask + src_mask = args.src_mask * 8 if args.src_ref_images is None: - src_ref_images = [None] + src_ref_images = [None] * len(prompts) else: - src_ref_images = args.src_ref_images + src_ref_images = args.src_ref_images * 8 # Resolve sizes list (default to first supported size for task) if args.sizes is not None and len(args.sizes) > 0: - size_keys = args.sizes + size_keys = args.sizes * 8 else: size_keys = [SUPPORTED_SIZES[args.model_name][0]] # Resolve frame counts list (default 81) if args.frame_nums is not None and len(args.frame_nums) > 0: - frame_nums = args.frame_nums + frame_nums = args.frame_nums * 8 else: frame_nums = [81] diff --git 
a/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py b/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py index 608c5642d5..385bf6d741 100644 --- a/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py +++ b/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py @@ -1073,6 +1073,8 @@ def generate(self, vace_context[i] = F.pad(vace_context[i], (0, 0, 0, max_video_seq_len - vace_context[i].shape[0])) vace_context = torch.stack(vace_context, dim=1) + s, b, h = vace_context.shape + vace_context = vace_context.transpose(0, 1).reshape(s*b, 1, h) if n_prompt == "": n_prompt = self.sample_neg_prompt @@ -1105,6 +1107,9 @@ def generate(self, contexts = torch.stack(contexts, dim=1) contexts_null = torch.stack(contexts_null, dim=1) + s, b, h = contexts.shape + contexts = contexts.transpose(0, 1).reshape(s*b, 1, h) + contexts_null = contexts_null.transpose(0, 1).reshape(s*b, 1, h) ## setup noise noises = [] @@ -1119,7 +1124,7 @@ def generate(self, device=self.device, generator=seed_g) ) - + # noises = noises[:1] * len(noises) # calculate grid_sizes grid_sizes = [grid_sizes_calculation( @@ -1221,6 +1226,8 @@ def noop_no_sync(): latents[i] = F.pad(latents[i], (0, 0, 0, max_video_seq_len - latents[i].shape[0])) latents = torch.stack(latents, dim=1) + s, b, h = latents.shape + latents = latents.transpose(0, 1).reshape(s*b, 1, h) # context parallel if parallel_state.get_context_parallel_world_size() > 1: @@ -1228,7 +1235,7 @@ def noop_no_sync(): latent_model_input = latents - timestep = [t] * batch_size + timestep = [t] * 1 timestep = torch.stack(timestep) self.model.to(self.device) @@ -1244,6 +1251,8 @@ def noop_no_sync(): noise_pred_cond = thd_cat_outputs_cp(noise_pred_cond, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) noise_pred_uncond = thd_cat_outputs_cp(noise_pred_uncond, packed_seq_params['self_attention'].cu_seqlens_q, 
parallel_state.get_context_parallel_group()) + noise_pred_cond = noise_pred_cond.reshape(b, s, h).transpose(0, 1) + noise_pred_uncond = noise_pred_uncond.reshape(b, s, h).transpose(0, 1) # run unpatchify unpatchified_noise_pred_cond = noise_pred_cond diff --git a/src/megatron/bridge/models/wan/wan_layer_spec.py b/src/megatron/bridge/models/wan/wan_layer_spec.py index e6cf0a30f9..f54649bb4d 100644 --- a/src/megatron/bridge/models/wan/wan_layer_spec.py +++ b/src/megatron/bridge/models/wan/wan_layer_spec.py @@ -603,7 +603,7 @@ def forward( inference_context=None, ): - log_checkpoint("before base") + # log_checkpoint("before base") hidden_states, context = super().forward( hidden_states, @@ -626,7 +626,7 @@ def forward( # hidden_states = hidden_states + context_mask[self.idx] * 2.0 # hidden_states = hidden_states + torch.rand_like(context_mask[self.idx]) * 0.05 - log_checkpoint(f"after base {self.idx}") + # log_checkpoint(f"after base {self.idx}") return hidden_states, context @@ -699,7 +699,7 @@ def forward( inference_context=None, ): - log_checkpoint("before context") + # log_checkpoint("before context") # all_hidden_states = list(torch.unbind(hidden_states)) # hidden_states = all_hidden_states.pop(-1) @@ -722,7 +722,7 @@ def forward( # all_hidden_states += [hidden_states_proj, hidden_states] # hidden_states = torch.stack(all_hidden_states) - log_checkpoint("after context") + # log_checkpoint("after context") return hidden_states, context diff --git a/src/megatron/bridge/models/wan/wan_model.py b/src/megatron/bridge/models/wan/wan_model.py index ae8a456be0..30f9ebe003 100644 --- a/src/megatron/bridge/models/wan/wan_model.py +++ b/src/megatron/bridge/models/wan/wan_model.py @@ -658,6 +658,9 @@ def forward( n_head, dim_head = self.num_heads, self.config.hidden_size // self.num_heads rotary_pos_emb = self.rope_embeddings(n_head, dim_head, max_seq_len, grid_sizes, t.device) # output: rotary_pos_emb.shape [s, b, 1, dim_head] + s, b, sq, h = rotary_pos_emb.shape + 
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).reshape(s*b, 1, sq, h) + # run vace decoder vace_context = self.vace_decoder( hidden_states=vace_context, diff --git a/vace.sh b/vace.sh index 5e6a37895f..2ff32b6f0e 100644 --- a/vace.sh +++ b/vace.sh @@ -12,8 +12,8 @@ VAE_DIR=/opt/Wan2.1-T2V-1.3B NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=2 --rdzv-backend=c10d --rdzv-endpoint=localhost:0 examples/recipes/wan/inference_vace.py \ --model_name vace-1.3B \ --sizes 832*480 \ - --save_file "depth" \ - --src_video "src_video_depth.mp4" \ + --save_file "test" \ + --src_video "src_video_flow.mp4" \ --checkpoint_dir ${CHECKPOINT_DIR} \ --checkpoint_step 0000 \ --t5_checkpoint_dir ${T5_DIR} \ @@ -29,15 +29,15 @@ NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=2 --rdzv-backend=c10d --rdzv-endpoin # NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=2 --rdzv-backend=c10d --rdzv-endpoint=localhost:0 examples/recipes/wan/inference_vace.py \ # --model_name vace-1.3B \ -# --sizes 832*480 832*480 \ -# --save_file "depth" \ -# --src_video "src_video_depth.mp4" \ +# --sizes 832*480 832*480 832*480 \ +# --save_file "test" \ +# --src_video "src_video_depth.mp4" "src_video_flow.mp4" "src_video_pose.mp4" \ # --checkpoint_dir ${CHECKPOINT_DIR} \ # --checkpoint_step 0000 \ # --t5_checkpoint_dir ${T5_DIR} \ # --vae_checkpoint_dir ${VAE_DIR} \ -# --prompts "Two dogs hit each other during boxing." "Two dogs hit each other during boxing." \ -# --frame_nums 81 81 \ +# --prompts "Two dogs hit each other during boxing." "Two dogs hit each other during boxing." "Two dogs hit each other during boxing." 
\ +# --frame_nums 81 81 81 \ # --tensor_parallel_size 1 \ # --context_parallel_size 2 \ # --pipeline_parallel_size 1 \ From 7eba8456e1d19db0bf2c7963728efc6c4a0d90b7 Mon Sep 17 00:00:00 2001 From: Tiancheng Zhao Date: Fri, 28 Nov 2025 06:42:53 +0000 Subject: [PATCH 17/17] add additional output for context branch and additional input for base branch --- .../bridge/models/wan/wan_layer_spec.py | 23 +- src/megatron/bridge/models/wan/wan_model.py | 662 +++++++++++++++++- 2 files changed, 667 insertions(+), 18 deletions(-) diff --git a/src/megatron/bridge/models/wan/wan_layer_spec.py b/src/megatron/bridge/models/wan/wan_layer_spec.py index f54649bb4d..8a839bcd89 100644 --- a/src/megatron/bridge/models/wan/wan_layer_spec.py +++ b/src/megatron/bridge/models/wan/wan_layer_spec.py @@ -593,6 +593,7 @@ def forward( attention_mask=None, context=None, context_mask=None, + context_signal=None, rotary_pos_emb=None, rotary_pos_cos=None, rotary_pos_sin=None, @@ -620,11 +621,11 @@ def forward( inference_context=inference_context, ) # consider how to pass block id and context_scale - # the context_tokens from context branch is stored in context_mask argument + # the context_tokens from context branch is stored in context_signal argument if self.idx is not None: - hidden_states = hidden_states + context_mask[self.idx] * self.context_scale - # hidden_states = hidden_states + context_mask[self.idx] * 2.0 - # hidden_states = hidden_states + torch.rand_like(context_mask[self.idx]) * 0.05 + hidden_states = hidden_states + context_signal[self.idx] * self.context_scale + # hidden_states = hidden_states + context_signal[self.idx] * 2.0 + # hidden_states = hidden_states + torch.rand_like(context_signal[self.idx]) * 0.05 # log_checkpoint(f"after base {self.idx}") @@ -689,6 +690,7 @@ def forward( attention_mask=None, context=None, context_mask=None, + context_signal=None, rotary_pos_emb=None, rotary_pos_cos=None, rotary_pos_sin=None, @@ -701,10 +703,8 @@ def forward( # log_checkpoint("before 
context") - # all_hidden_states = list(torch.unbind(hidden_states)) - # hidden_states = all_hidden_states.pop(-1) - hidden_state, context = super().forward( - hidden_states[self.idx], + hidden_states, context = super().forward( + hidden_states, attention_mask=attention_mask, context=context, context_mask=None, @@ -717,14 +717,11 @@ def forward( sequence_len_offset=sequence_len_offset, inference_context=inference_context, ) - hidden_states[self.idx] = self.context_proj(hidden_state)[0] - hidden_states[self.idx + 1] = hidden_state - # all_hidden_states += [hidden_states_proj, hidden_states] - # hidden_states = torch.stack(all_hidden_states) + context_signal[self.idx] = self.context_proj(hidden_states)[0] # log_checkpoint("after context") - return hidden_states, context + return hidden_states, context_signal import transformer_engine as te diff --git a/src/megatron/bridge/models/wan/wan_model.py b/src/megatron/bridge/models/wan/wan_model.py index 30f9ebe003..9bec5c85b3 100644 --- a/src/megatron/bridge/models/wan/wan_model.py +++ b/src/megatron/bridge/models/wan/wan_model.py @@ -41,10 +41,52 @@ from contextlib import nullcontext from megatron.core.fp4_utils import get_fp4_context from megatron.core.fp8_utils import get_fp8_context +from megatron.core.enums import Fp8Recipe +from megatron.core.fusions.fused_layer_norm import FusedLayerNorm +from megatron.core.inference.contexts import BaseInferenceContext from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.transformer_layer import get_transformer_layer_offset -from megatron.core.utils import get_pg_rank +from megatron.core.utils import ( + WrappedTensor, + deprecate_inference_params, + get_pg_rank, + make_viewless_tensor, +) + +try: + import transformer_engine.pytorch as te # pylint: disable=unused-import + + HAVE_TE = True +except ImportError: + HAVE_TE = False + +try: + import apex # pylint: 
disable=unused-import + + HAVE_APEX = True +except ImportError: + HAVE_APEX = False + +get_cpu_offload_context = None +te_checkpoint = None + +if HAVE_TE: + from megatron.core.extensions.transformer_engine import ( + TENorm, + get_cpu_offload_context, + te_checkpoint, + ) + + LayerNormImpl = TENorm + +elif HAVE_APEX: + LayerNormImpl = FusedLayerNorm + +else: + from megatron.core.transformer.torch_norm import WrappedTorchNorm + + LayerNormImpl = WrappedTorchNorm class BaseTransformerBlock(TransformerBlock): def __init__( @@ -138,6 +180,310 @@ def build_layer(layer_spec, layer_number): else: self.final_layernorm = None # Either this or nn.Identity + def _checkpointed_forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + context: Tensor, + context_mask: Tensor, + context_signal: Tensor, + rotary_pos_emb: Tensor, + attention_bias: Tensor, + packed_seq_params: PackedSeqParams, + use_inner_quantization_context: bool, + ): + """Forward method with activation checkpointing.""" + + def custom(start: int, end: int): + def custom_forward( + hidden_states, attention_mask, context, context_mask, context_signal, rotary_pos_emb + ): + for index in range(start, end): + layer = self._get_layer(index) + + # Get appropriate inner quantization context + if use_inner_quantization_context: + if self.config.fp8: + inner_quantization_context = get_fp8_context( + self.config, layer.layer_number - 1 + ) + # TODO: check if fp4 is supported in this case + elif self.config.fp4: + inner_quantization_context = get_fp4_context( + self.config, layer.layer_number - 1 + ) + else: + inner_quantization_context = nullcontext() + else: + inner_quantization_context = nullcontext() + + with inner_quantization_context: + hidden_states, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + context_signal=context_signal, + rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, + inference_context=None, + 
packed_seq_params=packed_seq_params, + ) + return hidden_states, context + + return custom_forward + + def checkpoint_handler(forward_func): + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" + # TODO: check if fp4 is supported in this case + if self.config.fp8 or self.config.fp4: + return te_checkpoint( + forward_func, + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + self.pg_collection.tp, + hidden_states, + attention_mask, + context, + context_mask, + context_signal, + rotary_pos_emb, + ) + else: + return tensor_parallel.checkpoint( + forward_func, + self.config.distribute_saved_activations, + hidden_states, + attention_mask, + context, + context_mask, + context_signal, + rotary_pos_emb, + ) + + if self.config.recompute_method == 'uniform': + # Uniformly divide the total number of Transformer layers and checkpoint + # the input activation of each divided chunk. + # A method to further reduce memory usage reducing checkpoints. + layer_idx = 0 + while layer_idx < self.num_layers_per_pipeline_rank: + hidden_states, context = checkpoint_handler( + custom(layer_idx, layer_idx + self.config.recompute_num_layers) + ) + + layer_idx += self.config.recompute_num_layers + + elif self.config.recompute_method == 'block': + # Checkpoint the input activation of only a set number of individual + # Transformer layers and skip the rest. + # A method fully use the device memory removing redundant re-computation. + recompute_skip_num_layers = 0 + for layer_idx in range(self.num_layers_per_pipeline_rank): + # Skip recomputation when input grad computation is not needed. + # Need to have at least one input tensor with gradient computation + # for re-enterant autograd engine. 
+ # TODO: check if fp4 is supported in this case + if (self.config.fp8 or self.config.fp4) and not hidden_states.requires_grad: + recompute_skip_num_layers += 1 + if ( + layer_idx >= recompute_skip_num_layers + and layer_idx < self.config.recompute_num_layers + recompute_skip_num_layers + ): + hidden_states, context = checkpoint_handler(custom(layer_idx, layer_idx + 1)) + else: + hidden_states, context = custom(layer_idx, layer_idx + 1)( + hidden_states, attention_mask, context, context_mask, context_signal, rotary_pos_emb + ) + else: + raise ValueError("Invalid activation recompute method.") + + return hidden_states + + def forward( + self, + hidden_states: Union[Tensor, WrappedTensor], + attention_mask: Optional[Tensor], + context: Optional[Tensor] = None, + context_mask: Optional[Tensor] = None, + context_signal: Optional[Tensor] = None, + rotary_pos_emb: Optional[Tensor] = None, + rotary_pos_cos: Optional[Tensor] = None, + rotary_pos_sin: Optional[Tensor] = None, + rotary_pos_cos_sin: Optional[Tensor] = None, + attention_bias: Optional[Tensor] = None, + inference_context: Optional[BaseInferenceContext] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[Tensor] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + dynamic_inference_decode_only: Optional[bool] = None, + ): + """ + Perform the forward pass through the transformer block. + + This method handles the core computation of the transformer, including + self-attention, optional cross-attention, and feed-forward operations. + + Args: + hidden_states (Union[Tensor, WrappedTensor]): Input tensor of shape [s, b, h] + where s is the sequence length, b is the batch size, and h is the hidden size. + Can be passed as a WrappedTensor during inference to avoid an obsolete + reference in the calling function. + attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking + self-attention. 
+ context (Tensor, optional): Context tensor for cross-attention. + context_mask (Tensor, optional): Mask for cross-attention context + context_signal (Tensor, optional): Signal from context tokens + rotary_pos_emb (Tensor, optional): Rotary positional embeddings. + rotary_pos_cos (Optional[Tensor]): Rotary embedding cosine. + rotary_pos_sin (Optional[Tensor]): Rotary embedding sine. + rotary_pos_cos_sin (Optional[Tensor]): Combined rotary embedding cosine and sine. + Currently used exclusively for inference with dynamic batching and flashinfer RoPE. + attention_bias (Tensor): Bias tensor for Q * K.T of shape in shape broadcastable + to [b, num_head, sq, skv], e.g. [1, 1, sq, skv]. + Used as an alternative to apply attention mask for TE cuDNN attention. + inference_context (BaseInferenceContext, optional): Parameters for inference-time + optimizations. + packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence + processing. + dynamic_inference_decode_only: Optional[bool]: If true, indicates that the current + inference context is for decode-only. This args is only used to uniquely + identify decode and non-decode cuda graph runners in the cuda graph manager. + + Returns: + Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape + [s, b, h], and optionally the updated context tensor if cross-attention is used. + """ + + inference_context = deprecate_inference_params(inference_context, inference_params) + # Remove 'dynamic_inference_decode_only' from kwargs if present + # this is only used to uniquely identify decode and non-decode cuda graph + # runners in the cuda graph manager + + # Delete the obsolete reference to the initial input tensor if necessary + if isinstance(hidden_states, WrappedTensor): + hidden_states = hidden_states.unwrap() + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Viewless tensor. 
+ # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + + if self.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + # If fp8_recipe is delayed, wrap the entire pass with get_fp8_context(), + # otherwise do nothing extra at the outer level + # if we are using other fp8 recipes, then the context manager enter&exit are free + # we can wrap fp8_context within the for loop over layers, so that we can fine-grained + # control which layer will be fp8 or bf16 + # For FP4: NVFP4BlockScaling doesn't have delayed scaling, always uses inner context + if self.config.fp8: + use_outer_quantization_context = self.config.fp8_recipe == Fp8Recipe.delayed + use_inner_quantization_context = self.config.fp8_recipe != Fp8Recipe.delayed + outer_quantization_context = ( + get_fp8_context(self.config) if use_outer_quantization_context else nullcontext() + ) + elif self.config.fp4: + use_outer_quantization_context = False + use_inner_quantization_context = True + outer_quantization_context = nullcontext() + else: + # No quantization + use_outer_quantization_context = False + use_inner_quantization_context = False + 
outer_quantization_context = nullcontext() + + with rng_context, outer_quantization_context: + # Forward pass. + if self.config.recompute_granularity == 'full' and self.training: + hidden_states = self._checkpointed_forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + context_signal=context_signal, + rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + use_inner_quantization_context=use_inner_quantization_context, + ) + else: + for l_no, layer in enumerate(self.layers): + # Get appropriate inner quantization context + if use_inner_quantization_context: + if self.config.fp8: + inner_quantization_context = get_fp8_context( + self.config, layer.layer_number - 1 + ) + elif self.config.fp4: + inner_quantization_context = get_fp4_context( + self.config, layer.layer_number - 1 + ) + else: + inner_quantization_context = nullcontext() + else: + inner_quantization_context = nullcontext() + + with self.offload_context, inner_quantization_context: + hidden_states, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + context_signal=context_signal, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + # rotary_pos_cos_sin=rotary_pos_cos_sin, + attention_bias=attention_bias, + inference_context=inference_context, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + ) + + if ( + torch.is_grad_enabled() + and self.config.cpu_offloading + and self.group_prefetch_offload_commit_async is not None + ): + hidden_states = self.group_prefetch_offload_commit_async(hidden_states) + + # Final layer norm. + if self.final_layernorm is not None: + hidden_states = self.final_layernorm(hidden_states) + # TENorm produces a "viewed" tensor. 
This will result in schedule.py's + # deallocate_output_tensor() throwing an error, so a viewless tensor is + # created to prevent this. + hidden_states = make_viewless_tensor( + inp=hidden_states, requires_grad=True, keep_graph=True + ) + + # If this TransformerBlock is empty, input and output hidden states will be the same node + # on the computational graph and will lead to unexpected errors in pipeline schedules. + if not self.pre_process and len(self.layers) == 0 and not self.final_layernorm: + hidden_states = hidden_states.clone() + + return hidden_states + class ContextTransformerBlock(TransformerBlock): def __init__( self, @@ -227,6 +573,310 @@ def build_layer(layer_spec, layer_number): ) else: self.final_layernorm = None # Either this or nn.Identity + + def _checkpointed_forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + context: Tensor, + context_mask: Tensor, + context_signal: Tensor, + rotary_pos_emb: Tensor, + attention_bias: Tensor, + packed_seq_params: PackedSeqParams, + use_inner_quantization_context: bool, + ): + """Forward method with activation checkpointing.""" + + def custom(start: int, end: int): + def custom_forward( + hidden_states, attention_mask, context, context_mask, context_signal, rotary_pos_emb + ): + for index in range(start, end): + layer = self._get_layer(index) + + # Get appropriate inner quantization context + if use_inner_quantization_context: + if self.config.fp8: + inner_quantization_context = get_fp8_context( + self.config, layer.layer_number - 1 + ) + # TODO: check if fp4 is supported in this case + elif self.config.fp4: + inner_quantization_context = get_fp4_context( + self.config, layer.layer_number - 1 + ) + else: + inner_quantization_context = nullcontext() + else: + inner_quantization_context = nullcontext() + + with inner_quantization_context: + hidden_states, context_signal = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + 
context_signal=context_signal, + rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, + inference_context=None, + packed_seq_params=packed_seq_params, + ) + return hidden_states, context_signal + + return custom_forward + + def checkpoint_handler(forward_func): + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" + # TODO: check if fp4 is supported in this case + if self.config.fp8 or self.config.fp4: + return te_checkpoint( + forward_func, + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + self.pg_collection.tp, + hidden_states, + attention_mask, + context, + context_mask, + context_signal, + rotary_pos_emb, + ) + else: + return tensor_parallel.checkpoint( + forward_func, + self.config.distribute_saved_activations, + hidden_states, + attention_mask, + context, + context_mask, + context_signal, + rotary_pos_emb, + ) + + if self.config.recompute_method == 'uniform': + # Uniformly divide the total number of Transformer layers and checkpoint + # the input activation of each divided chunk. + # A method to further reduce memory usage reducing checkpoints. + layer_idx = 0 + while layer_idx < self.num_layers_per_pipeline_rank: + hidden_states, context_signal = checkpoint_handler( + custom(layer_idx, layer_idx + self.config.recompute_num_layers) + ) + + layer_idx += self.config.recompute_num_layers + + elif self.config.recompute_method == 'block': + # Checkpoint the input activation of only a set number of individual + # Transformer layers and skip the rest. + # A method fully use the device memory removing redundant re-computation. + recompute_skip_num_layers = 0 + for layer_idx in range(self.num_layers_per_pipeline_rank): + # Skip recomputation when input grad computation is not needed. + # Need to have at least one input tensor with gradient computation + # for re-enterant autograd engine. 
+ # TODO: check if fp4 is supported in this case + if (self.config.fp8 or self.config.fp4) and not hidden_states.requires_grad: + recompute_skip_num_layers += 1 + if ( + layer_idx >= recompute_skip_num_layers + and layer_idx < self.config.recompute_num_layers + recompute_skip_num_layers + ): + hidden_states, context_signal = checkpoint_handler(custom(layer_idx, layer_idx + 1)) + else: + hidden_states, context_signal = custom(layer_idx, layer_idx + 1)( + hidden_states, attention_mask, context, context_mask, context_signal, rotary_pos_emb + ) + else: + raise ValueError("Invalid activation recompute method.") + + return hidden_states, context_signal + + def forward( + self, + hidden_states: Union[Tensor, WrappedTensor], + attention_mask: Optional[Tensor], + context: Optional[Tensor] = None, + context_mask: Optional[Tensor] = None, + context_signal: Optional[Tensor] = None, + rotary_pos_emb: Optional[Tensor] = None, + rotary_pos_cos: Optional[Tensor] = None, + rotary_pos_sin: Optional[Tensor] = None, + rotary_pos_cos_sin: Optional[Tensor] = None, + attention_bias: Optional[Tensor] = None, + inference_context: Optional[BaseInferenceContext] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[Tensor] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + dynamic_inference_decode_only: Optional[bool] = None, + ): + """ + Perform the forward pass through the transformer block. + + This method handles the core computation of the transformer, including + self-attention, optional cross-attention, and feed-forward operations. + + Args: + hidden_states (Union[Tensor, WrappedTensor]): Input tensor of shape [s, b, h] + where s is the sequence length, b is the batch size, and h is the hidden size. + Can be passed as a WrappedTensor during inference to avoid an obsolete + reference in the calling function. + attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking + self-attention. 
+ context (Tensor, optional): Context tensor for cross-attention. + context_mask (Tensor, optional): Mask for cross-attention context + context_signal (Tensor, optional): Signal from context tokens + rotary_pos_emb (Tensor, optional): Rotary positional embeddings. + rotary_pos_cos (Optional[Tensor]): Rotary embedding cosine. + rotary_pos_sin (Optional[Tensor]): Rotary embedding sine. + rotary_pos_cos_sin (Optional[Tensor]): Combined rotary embedding cosine and sine. + Currently used exclusively for inference with dynamic batching and flashinfer RoPE. + attention_bias (Tensor): Bias tensor for Q * K.T of shape in shape broadcastable + to [b, num_head, sq, skv], e.g. [1, 1, sq, skv]. + Used as an alternative to apply attention mask for TE cuDNN attention. + inference_context (BaseInferenceContext, optional): Parameters for inference-time + optimizations. + packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence + processing. + dynamic_inference_decode_only: Optional[bool]: If true, indicates that the current + inference context is for decode-only. This args is only used to uniquely + identify decode and non-decode cuda graph runners in the cuda graph manager. + + Returns: + Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape + [s, b, h], and optionally the updated context tensor if cross-attention is used. + """ + + inference_context = deprecate_inference_params(inference_context, inference_params) + # Remove 'dynamic_inference_decode_only' from kwargs if present + # this is only used to uniquely identify decode and non-decode cuda graph + # runners in the cuda graph manager + + # Delete the obsolete reference to the initial input tensor if necessary + if isinstance(hidden_states, WrappedTensor): + hidden_states = hidden_states.unwrap() + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Viewless tensor. 
+ # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + + if self.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + # If fp8_recipe is delayed, wrap the entire pass with get_fp8_context(), + # otherwise do nothing extra at the outer level + # if we are using other fp8 recipes, then the context manager enter&exit are free + # we can wrap fp8_context within the for loop over layers, so that we can fine-grained + # control which layer will be fp8 or bf16 + # For FP4: NVFP4BlockScaling doesn't have delayed scaling, always uses inner context + if self.config.fp8: + use_outer_quantization_context = self.config.fp8_recipe == Fp8Recipe.delayed + use_inner_quantization_context = self.config.fp8_recipe != Fp8Recipe.delayed + outer_quantization_context = ( + get_fp8_context(self.config) if use_outer_quantization_context else nullcontext() + ) + elif self.config.fp4: + use_outer_quantization_context = False + use_inner_quantization_context = True + outer_quantization_context = nullcontext() + else: + # No quantization + use_outer_quantization_context = False + use_inner_quantization_context = False + 
outer_quantization_context = nullcontext() + + with rng_context, outer_quantization_context: + # Forward pass. + if self.config.recompute_granularity == 'full' and self.training: + hidden_states, context_signal = self._checkpointed_forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + context_signal=context_signal, + rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + use_inner_quantization_context=use_inner_quantization_context, + ) + else: + for l_no, layer in enumerate(self.layers): + # Get appropriate inner quantization context + if use_inner_quantization_context: + if self.config.fp8: + inner_quantization_context = get_fp8_context( + self.config, layer.layer_number - 1 + ) + elif self.config.fp4: + inner_quantization_context = get_fp4_context( + self.config, layer.layer_number - 1 + ) + else: + inner_quantization_context = nullcontext() + else: + inner_quantization_context = nullcontext() + + with self.offload_context, inner_quantization_context: + hidden_states, context_signal = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + context_signal=context_signal, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + # rotary_pos_cos_sin=rotary_pos_cos_sin, + attention_bias=attention_bias, + inference_context=inference_context, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + ) + + if ( + torch.is_grad_enabled() + and self.config.cpu_offloading + and self.group_prefetch_offload_commit_async is not None + ): + hidden_states = self.group_prefetch_offload_commit_async(hidden_states) + + # Final layer norm. + if self.final_layernorm is not None: + hidden_states = self.final_layernorm(hidden_states) + # TENorm produces a "viewed" tensor. 
This will result in schedule.py's + # deallocate_output_tensor() throwing an error, so a viewless tensor is + # created to prevent this. + hidden_states = make_viewless_tensor( + inp=hidden_states, requires_grad=True, keep_graph=True + ) + + # If this TransformerBlock is empty, input and output hidden states will be the same node + # on the computational graph and will lead to unexpected errors in pipeline schedules. + if not self.pre_process and len(self.layers) == 0 and not self.final_layernorm: + hidden_states = hidden_states.clone() + + return hidden_states, context_signal def sinusoidal_embedding_1d(dim, position): # preprocess @@ -628,7 +1278,7 @@ def forward( vace_context = vace_context.reshape(vace_seq_len, batch_size, -1) # output: vace_context.shape [s, b, hidden_size] vace_context = self.vace_init_proj(vace_context) + x # vace_context = vace_context.unsqueeze(0) - vace_context = torch.stack([vace_context] * (self.vace_config.num_layers + 1)) + vace_context = torch.stack([vace_context] * (self.vace_config.num_layers)) # split sequence for sequence_parallel # TODO: for PP, do we move scatter_to_sequence_parallel_region here or after "x = self.decoder.input_tensor" ??? @@ -663,22 +1313,24 @@ def forward( # run vace decoder vace_context = self.vace_decoder( - hidden_states=vace_context, + hidden_states=vace_context[0], attention_mask=e0, context=context, context_mask=None, + context_signal=vace_context, rotary_pos_emb=rotary_pos_emb, rotary_pos_cos=None, rotary_pos_sin=None, packed_seq_params=packed_seq_params, - )[:-1] + )[1] # run decoder x = self.decoder( hidden_states=x, attention_mask=e0, context=context, - context_mask=vace_context, + context_mask=None, + context_signal=vace_context, rotary_pos_emb=rotary_pos_emb, rotary_pos_cos=None, rotary_pos_sin=None,