diff --git a/.gitignore b/.gitignore index 7e7db08e4c..d755ce3aa9 100644 --- a/.gitignore +++ b/.gitignore @@ -182,3 +182,5 @@ slurm*.out # UV package manager .uv/ + +*.mp4 \ No newline at end of file diff --git a/example_commands.sh b/example_commands.sh new file mode 100644 index 0000000000..d244d9cf4e --- /dev/null +++ b/example_commands.sh @@ -0,0 +1,129 @@ +# ### set path to Megatron-Bridge +# export MBRIDGE_PATH=/path/to/Megatron-Bridge +# export PYTHONPATH="${MBRIDGE_PATH}/.:${MBRIDGE_PATH}/src/.:/opt/NeMo-Framework-Launcher/launcher_scripts" + +export CUDA_VISIBLE_DEVICES=0,1 + +# ### install dependencies +# pip install --upgrade git+https://github.com/NVIDIA/Megatron-LM.git@ce8185cbbe04f38beb74360e878450f2e8525885 +# python3 -m pip install --upgrade diffusers +# pip install easydict +# pip install imageio +# pip install imageio-ffmpeg + + +# ### Convert checkpoint +# See examples/conversion/convert_wan_checkpoints.py for details. + + +# ### Finetuning +# export HF_TOKEN=... +# export WANDB_API_KEY=... +# EXP_NAME=... +# PRETRAINED_CHECKPOINT=/path/to/pretrained_checkpoint +# CHECKPOINT_DIR=/path/to/checkpoint_dir +# DATASET_PATH=/path/to/dataset +# cd $MBRIDGE_PATH +# NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=8 examples/recipes/wan/pretrain_wan.py \ +# model.tensor_model_parallel_size=1 \ +# model.pipeline_model_parallel_size=1 \ +# model.context_parallel_size=4 \ +# model.sequence_parallel=false \ +# model.qkv_format=thd \ +# dataset.path=${DATASET_PATH} \ +# checkpoint.save=${CHECKPOINT_DIR} \ +# checkpoint.load=${PRETRAINED_CHECKPOINT} \ +# checkpoint.load_optim=false \ +# checkpoint.save_interval=200 \ +# optimizer.lr=5e-6 \ +# optimizer.min_lr=5e-6 \ +# train.eval_iters=0 \ +# scheduler.lr_decay_style=constant \ +# scheduler.lr_warmup_iters=0 \ +# model.seq_length=2048 \ +# dataset.seq_length=2048 \ +# train.global_batch_size=1 \ +# train.micro_batch_size=1 \ +# dataset.global_batch_size=1 \ +# dataset.micro_batch_size=1 \ +# logger.log_interval=1 \ +# logger.wandb_project="wan" \ +# logger.wandb_exp_name=${EXP_NAME} \ +# logger.wandb_save_dir=${CHECKPOINT_DIR} + + +### Inferencing +# Download T5 weights and VAE weights from "https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B/tree/main" +# T5: models_t5_umt5-xxl-enc-bf16.pth, google +# VAE: Wan2.1_VAE.pth + +CHECKPOINT_DIR=/opt/megatron_checkpoint_WAN +T5_DIR=/opt/Wan2.1-T2V-1.3B +VAE_DIR=/opt/Wan2.1-T2V-1.3B +# cd $MBRIDGE_PATH +# NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 examples/recipes/wan/inference_wan.py \ +# --task t2v-1.3B \ +# --sizes 832*480 \ +# --checkpoint_dir ${CHECKPOINT_DIR} \ +# --checkpoint_step 0000 \ +# --t5_checkpoint_dir ${T5_DIR} \ +# --vae_checkpoint_dir ${VAE_DIR} \ +# --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ +# --frame_nums 81 \ +# --tensor_parallel_size 1 \ +# --context_parallel_size 1 \ +# --pipeline_parallel_size 1 \ +# --sequence_parallel False \ +# --base_seed 42 \ +# --sample_steps 50 + + +NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 --rdzv-backend=c10d --rdzv-endpoint=localhost:0 examples/recipes/wan/inference_wan.py \ + --task t2v-1.3B \ + --sizes 832*480 \ + --checkpoint_dir ${CHECKPOINT_DIR} \ + --checkpoint_step 0000 \ + --t5_checkpoint_dir ${T5_DIR} \ + --vae_checkpoint_dir ${VAE_DIR} \ + --prompts "Beautiful maple leaves across the mountain during the autumn." \ + --frame_nums 81 \ + --tensor_parallel_size 1 \ + --context_parallel_size 1 \ + --pipeline_parallel_size 1 \ + --sequence_parallel False \ + --base_seed 42 \ + --sample_steps 50 + + + # NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=2 examples/recipes/wan/inference_wan.py \ + # --task t2v-1.3B \ + # --sizes 832*480 \ + # --checkpoint_dir ${CHECKPOINT_DIR} \ + # --checkpoint_step 0000 \ + # --t5_checkpoint_dir ${T5_DIR} \ + # --vae_checkpoint_dir ${VAE_DIR} \ + # --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ + # --frame_nums 81 \ + # --tensor_parallel_size 1 \ + # --context_parallel_size 2 \ + # --pipeline_parallel_size 1 \ + # --sequence_parallel False \ + # --base_seed 42 \ + # --sample_steps 50 + + + # NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=2 examples/recipes/wan/inference_wan.py \ + # --task t2v-1.3B \ + # --sizes 832*480 \ + # --checkpoint_dir ${CHECKPOINT_DIR} \ + # --checkpoint_step 0000 \ + # --t5_checkpoint_dir ${T5_DIR} \ + # --vae_checkpoint_dir ${VAE_DIR} \ + # --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ + # --frame_nums 81 \ + # --tensor_parallel_size 1 \ + # --context_parallel_size 1 \ + # --pipeline_parallel_size 2 \ + # --sequence_parallel False \ + # --base_seed 42 \ + # --sample_steps 50 \ No newline at end of file diff --git a/examples/conversion/convert_vace_checkpoints.py b/examples/conversion/convert_vace_checkpoints.py new file mode 100644 index 0000000000..dd0eb6e378 --- /dev/null +++ b/examples/conversion/convert_vace_checkpoints.py @@ -0,0 +1,49 @@ +import os, random, multiprocessing as mp + +def main(): + from megatron.bridge.models.hf_pretrained.wan import PreTrainedVACE + from megatron.bridge.models.wan.wan_bridge import VACEBridge + from megatron.bridge.training.model_load_save import save_megatron_model + + # --- minimal torch.distributed single-rank env --- + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ.setdefault("MASTER_PORT", str(29500 + random.randint(0, 1000))) + os.environ.setdefault("RANK", "0") + os.environ.setdefault("WORLD_SIZE", "1") + os.environ.setdefault("LOCAL_RANK", "0") + + # --- build & load --- + hf = PreTrainedVACE("Wan-AI/Wan2.1-VACE-1.3B-Diffusers") + # hf = PreTrainedVACE("Wan-AI/Wan2.1-VACE-14B-Diffusers") + + bridge = VACEBridge() + provider = bridge.provider_bridge(hf) + provider.perform_initialization = False + + # If you're on GPU but want CPU init to reduce peak mem: + megatron_models = provider.provide_distributed_model( + wrap_with_ddp=False, use_cpu_initialization=True + ) + + bridge.load_weights_hf_to_megatron(hf, megatron_models) + + # Save Megatron-format checkpoint (this triggers async writer internally) + save_megatron_model( + megatron_models, + "/opt/megatron_checkpoint_VACE", + hf_tokenizer_path=None + ) + +if __name__ == "__main__": + # On Linux, prefer 'fork' to avoid re-importing the module on spawn. + try: + mp.set_start_method("fork") + except RuntimeError: + # already set (fine on re-entry or non-Linux) + pass + + # If you’re on macOS/Windows and still want to be extra safe: + # mp.freeze_support() + + main() + diff --git a/examples/conversion/convert_wan_checkpoints.py b/examples/conversion/convert_wan_checkpoints.py new file mode 100644 index 0000000000..c4cf0bfcf3 --- /dev/null +++ b/examples/conversion/convert_wan_checkpoints.py @@ -0,0 +1,74 @@ +# from megatron.bridge.models.hf_pretrained.wan import PreTrainedWAN +# from megatron.bridge.models.wan.wan_bridge import WanBridge +# from megatron.bridge.training.model_load_save import save_megatron_model +# import os, random +# os.environ["MASTER_ADDR"] = "127.0.0.1" +# os.environ["MASTER_PORT"] = str(29500 + random.randint(0, 1000)) +# os.environ["RANK"] = "0" +# os.environ["WORLD_SIZE"] = "1" +# os.environ["LOCAL_RANK"] = "0" +# # +# hf = PreTrainedWAN("Wan-AI/Wan2.1-T2V-1.3B-Diffusers") +# # hf = PreTrainedWAN("Wan-AI/Wan2.1-T2V-14B-Diffusers") +# bridge = WanBridge() +# # +# provider = bridge.provider_bridge(hf) +# provider.perform_initialization = False +# megatron_models = provider.provide_distributed_model(wrap_with_ddp=False, use_cpu_initialization=True) +# # +# bridge.load_weights_hf_to_megatron(hf, megatron_models) +# save_megatron_model(megatron_models, "/opt/megatron_checkpoint", hf_tokenizer_path=None) + + +# convert_wan_checkpoints.py + +import os, random, multiprocessing as mp + +def main(): + from megatron.bridge.models.hf_pretrained.wan import PreTrainedWAN + from megatron.bridge.models.wan.wan_bridge import WanBridge + from megatron.bridge.training.model_load_save import save_megatron_model + + # --- minimal torch.distributed single-rank env --- + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ.setdefault("MASTER_PORT", str(29500 + random.randint(0, 1000))) + os.environ.setdefault("RANK", "0") + os.environ.setdefault("WORLD_SIZE", "1") + os.environ.setdefault("LOCAL_RANK", "0") + + # --- build & load --- + hf = PreTrainedWAN("Wan-AI/Wan2.1-T2V-1.3B-Diffusers") + # hf = PreTrainedWAN("Wan-AI/Wan2.1-T2V-14B-Diffusers") + + bridge = WanBridge() + provider = bridge.provider_bridge(hf) + provider.perform_initialization = False + + # If you're on GPU but want CPU init to reduce peak mem: + megatron_models = provider.provide_distributed_model( + wrap_with_ddp=False, use_cpu_initialization=True + ) + print(megatron_models[0]) + bridge.load_weights_hf_to_megatron(hf, megatron_models) + + + # Save Megatron-format checkpoint (this triggers async writer internally) + save_megatron_model( + megatron_models, + "/opt/megatron_checkpoint_WAN", + hf_tokenizer_path=None + ) + +if __name__ == "__main__": + # On Linux, prefer 'fork' to avoid re-importing the module on spawn. + try: + mp.set_start_method("fork") + except RuntimeError: + # already set (fine on re-entry or non-Linux) + pass + + # If you’re on macOS/Windows and still want to be extra safe: + # mp.freeze_support() + + main() + diff --git a/examples/recipes/wan/inference_vace.py b/examples/recipes/wan/inference_vace.py new file mode 100644 index 0000000000..b0ce120fec --- /dev/null +++ b/examples/recipes/wan/inference_vace.py @@ -0,0 +1,378 @@ +import argparse +import logging +import os +import sys +import warnings +from datetime import datetime + +warnings.filterwarnings('ignore') + +import random + +import torch +import torch.distributed as dist +from PIL import Image + +from megatron.bridge.models.wan.flow_matching.flow_inference_pipeline import VACEFlowInferencePipeline +from megatron.bridge.models.wan.inference.configs import SIZE_CONFIGS, SUPPORTED_SIZES, MAX_AREA_CONFIGS, WAN_CONFIGS +from megatron.bridge.models.wan.inference.utils.utils import cache_video, cache_image, str2bool + + +EXAMPLE_PROMPT = { + "vace-1.3B": { + "src_ref_images": 'assets/images/girl.png,assets/images/snake.png', + "prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。" + }, + "vace-14B": { + "src_ref_images": 'assets/images/girl.png,assets/images/snake.png', + "prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。" + } +} + + + + +def validate_args(args): + # Basic check + assert args.checkpoint_dir is not None, "Please specify the checkpoint directory." + assert args.model_name in WAN_CONFIGS, f"Unsupport model name: {args.model_name}" + assert args.model_name in EXAMPLE_PROMPT, f"Unsupport model name: {args.model_name}" + + # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks. + if args.sample_steps is None: + args.sample_steps = 50 + + if args.sample_shift is None: + args.sample_shift = 16 + + # The default number of frames are 1 for text-to-image tasks and 81 for other tasks. + if args.frame_nums is None: + args.frame_nums = 81 + + args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(0, sys.maxsize) + # Size check + if args.sizes is not None and len(args.sizes) > 0: + for s in args.sizes: + assert s in SUPPORTED_SIZES[args.model_name], f"Unsupport size {s} for model name {args.model_name}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.model_name])}" + return args + + +def _parse_args(): + parser = argparse.ArgumentParser( + description="Generate a image or video from a text prompt or image using Wan" + ) + parser.add_argument( + "--model_name", + type=str, + default="vace-1.3B", + choices=list(WAN_CONFIGS.keys()), + help="The model name to run.") + parser.add_argument( + "--sizes", + type=str, + nargs="+", + default=None, + choices=list(SIZE_CONFIGS.keys()), + help="List of sizes to generate multiple images or videos (WIDTH*HEIGHT). Example: --sizes 1280*720 1920*1080" + ) + parser.add_argument( + "--frame_nums", + type=int, + nargs="+", + default=None, + help="List of frame counts (each should be 4n+1). Broadcasts if single value." + ) + parser.add_argument( + "--checkpoint_dir", + type=str, + default=None, + help="The path to the main VACE checkpoint directory.") + parser.add_argument( + "--checkpoint_step", + type=int, + default=None, + help=( + "Optional training step to load, e.g. 1800 -> iter_0001800. " + "If not provided, the latest (largest) step in --checkpoint_dir is used.") + ) + parser.add_argument( + "--t5_checkpoint_dir", + type=str, + default=None, + help="Optional directory containing T5 checkpoint/tokenizer") + parser.add_argument( + "--vae_checkpoint_dir", + type=str, + default=None, + help="Optional directory containing VAE checkpoint") + parser.add_argument( + "--offload_model", + type=str2bool, + default=None, + help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage." + ) + parser.add_argument( + "--t5_cpu", + action="store_true", + default=False, + help="Whether to place T5 model on CPU.") + parser.add_argument( + "--save_file", + type=str, + default=None, + help="The file to save the generated image or video to.") + parser.add_argument( + "--src_video", + type=str, + nargs="+", + default=None, + help="List of name of the source video. Default None.") + parser.add_argument( + "--src_mask", + type=str, + nargs="+", + default=None, + help="List of name of the source mask. Default None.") + parser.add_argument( + "--src_ref_images", + type=str, + nargs="+", + default=None, + help="List of list of the source reference images. Separated by ','. Default None.") + parser.add_argument( + "--prompts", + type=str, + nargs="+", + default=None, + help="List of prompt to generate the image or video from.") + parser.add_argument( + "--base_seed", + type=int, + default=-1, + help="The seed to use for generating the image or video.") + parser.add_argument( + "--sample_solver", + type=str, + default='unipc', + choices=['unipc', 'dpm++'], + help="The solver used to sample.") + parser.add_argument( + "--sample_steps", type=int, default=None, help="The sampling steps.") + parser.add_argument( + "--sample_shift", + type=float, + default=None, + help="Sampling shift factor for flow matching schedulers.") + parser.add_argument( + "--sample_guide_scale", + type=float, + default=5.0, + help="Classifier free guidance scale.") + parser.add_argument( + "--tensor_parallel_size", + type=int, + default=1, + help="Tensor parallel size.") + parser.add_argument( + "--context_parallel_size", + type=int, + default=1, + help="Context parallel size.") + parser.add_argument( + "--pipeline_parallel_size", + type=int, + default=1, + help="Pipeline parallel size.") + parser.add_argument( + "--sequence_parallel", + type=str2bool, + default=False, + help="Sequence parallel.") + + args = parser.parse_args() + + validate_args(args) + + return args + + +def _init_logging(rank): + # logging + if rank == 0: + # set format + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] %(levelname)s: %(message)s", + handlers=[logging.StreamHandler(stream=sys.stdout)]) + else: + logging.basicConfig(level=logging.ERROR) + + +def generate(args): + rank = int(os.getenv("RANK", 0)) + world_size = int(os.getenv("WORLD_SIZE", 1)) + local_rank = int(os.getenv("LOCAL_RANK", 0)) + device = local_rank + _init_logging(rank) + + if args.offload_model is None: + args.offload_model = False if world_size > 1 else True + logging.info( + f"offload_model is not specified, set to {args.offload_model}.") + if world_size > 1: + torch.cuda.set_device(local_rank) + dist.init_process_group( + backend="nccl", + init_method="env://", + rank=rank, + world_size=world_size) + + cfg = WAN_CONFIGS[args.model_name] + + logging.info(f"Generation job args: {args}") + logging.info(f"Generation model config: {cfg}") + + if dist.is_initialized(): + base_seed = [args.base_seed] if rank == 0 else [None] + dist.broadcast_object_list(base_seed, src=0) + args.base_seed = base_seed[0] + + if args.prompts is None: + prompts = [None] + else: + prompts = args.prompts * 8 + + if args.src_video is None: + src_video = [None] * len(prompts) + else: + src_video = args.src_video * 8 + + if args.src_mask is None: + src_mask = [None] * len(prompts) + else: + src_mask = args.src_mask * 8 + + if args.src_ref_images is None: + src_ref_images = [None] * len(prompts) + else: + src_ref_images = args.src_ref_images * 8 + + # Resolve sizes list (default to first supported size for task) + if args.sizes is not None and len(args.sizes) > 0: + size_keys = args.sizes * 8 + else: + size_keys = [SUPPORTED_SIZES[args.model_name][0]] + + # Resolve frame counts list (default 81) + if args.frame_nums is not None and len(args.frame_nums) > 0: + frame_nums = args.frame_nums * 8 + else: + frame_nums = [81] + + # Enforce 1:1 pairing across lists + assert len(prompts) == len(size_keys) == len(frame_nums), ( + f"prompts ({len(prompts)}), sizes ({len(size_keys)}), and frame_nums ({len(frame_nums)}) " + f"must have the same length") + + logging.info("Creating VACE flow inference pipeline.") + pipeline = VACEFlowInferencePipeline( + config=cfg, + checkpoint_dir=args.checkpoint_dir, + checkpoint_step=args.checkpoint_step, + t5_checkpoint_dir=args.t5_checkpoint_dir, + vae_checkpoint_dir=args.vae_checkpoint_dir, + device_id=device, + rank=rank, + t5_cpu=args.t5_cpu, + tensor_parallel_size=args.tensor_parallel_size, + context_parallel_size=args.context_parallel_size, + pipeline_parallel_size=args.pipeline_parallel_size, + sequence_parallel=args.sequence_parallel, + pipeline_dtype=torch.float32, + ) + + # DEBUGGING + rank = dist.get_rank() + if rank == 0: + print("tensor_parallel_size:", args.tensor_parallel_size) + print("context_parallel_size:", args.context_parallel_size) + print("pipeline_parallel_size:", args.pipeline_parallel_size) + print("sequence_parallel:", args.sequence_parallel) + print("\n\n\n") + + for i in range(len(src_video)): + sub_src_video, sub_src_mask, sub_src_ref_images = pipeline.prepare_source([src_video[i]], + [src_mask[i]], + [src_ref_images[i]], + frame_nums[i], SIZE_CONFIGS[size_keys[i]], device) + src_video[i], src_mask[i], src_ref_images[i] = *sub_src_video, *sub_src_mask, *sub_src_ref_images + + + logging.info( + f"Generating videos ...") + videos = pipeline.generate( + prompts=prompts, + input_frames=src_video, + input_masks=src_mask, + input_ref_images=src_ref_images, + sizes=[SIZE_CONFIGS[size] for size in size_keys], + frame_nums=frame_nums, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) + + if rank == 0: + for i, video in enumerate(videos): + formatted_experiment_name = (args.save_file) if args.save_file is not None else "DefaultExp" + formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S") + formatted_prompt = prompts[i].replace(" ", "_").replace("/", + "_")[:50] + suffix = '.mp4' + formatted_save_file = f"{args.model_name}_{formatted_experiment_name}_videoindex{int(i)}_size{size_keys[i].replace('*','x') if sys.platform=='win32' else size_keys[i]}_{formatted_prompt}_{formatted_time}" + suffix + + # if "t2v" in args.task: + logging.info(f"Saving generated video to {formatted_save_file}") + cache_video( + tensor=video[None], + save_file=formatted_save_file, + fps=cfg.sample_fps, + nrow=1, + normalize=True, + value_range=(-1, 1)) + + cache_video( + tensor=src_video[i][None], + save_file=f'{args.model_name}_{formatted_experiment_name}_index{i}_src_video_{formatted_time}.mp4', + fps=cfg.sample_fps, + nrow=1, + normalize=True, + value_range=(-1, 1)) + logging.info(f"Saving src_video to {args.model_name}_{formatted_experiment_name}_index{i}_src_video_{formatted_time}.mp4") + + cache_video( + tensor=src_mask[i][None], + save_file=f'{args.model_name}_{formatted_experiment_name}_index{i}_src_mask_{formatted_time}.mp4', + fps=cfg.sample_fps, + nrow=1, + normalize=True, + value_range=(0, 1)) + logging.info(f"Saving src_mask to {args.model_name}_{formatted_experiment_name}_index{i}_src_mask_{formatted_time}.mp4") + + if src_ref_images[i] is not None: + for j, ref_img in enumerate(src_ref_images[i]): + cache_image( + tensor=ref_img[:, 0, ...], + save_file=f'{args.model_name}_{formatted_experiment_name}_index{i}_src_ref_image_{j}_{formatted_time}.png', + nrow=1, + normalize=True, + value_range=(-1, 1)) + logging.info(f"Saving src_ref_image_{j} to {args.model_name}_{formatted_experiment_name}_index{i}_src_ref_image_{j}_{formatted_time}.png") + logging.info("Finished.") + + +if __name__ == "__main__": + args = _parse_args() + generate(args) diff --git a/examples/recipes/wan/inference_wan.py b/examples/recipes/wan/inference_wan.py new file mode 100644 index 0000000000..61f38ecdea --- /dev/null +++ b/examples/recipes/wan/inference_wan.py @@ -0,0 +1,324 @@ +# Example of running script for Wan inference. +# NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 examples/recipes/wan/inference_wan.py \ +# --task t2v-1.3B \ +# --sizes 480*832 \ +# --checkpoint_dir /path/to/wan_checkpoint_dir \ +# --t5_checkpoint_dir /path/to/t5_checkpoint_dir \ +# --vae_checkpoint_dir /path/to/vae_checkpoint_dir \ +# --frame_nums 81 \ +# --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ +# --tensor_parallel_size 1 \ +# --context_parallel_size 1 \ +# --pipeline_parallel_size 1 \ +# --sequence_parallel False \ +# --base_seed 42 \ +# --sample_steps 50 + +import argparse +import logging +import os +import sys +import warnings +from datetime import datetime + +warnings.filterwarnings('ignore') + +import random + +import torch +import torch.distributed as dist +from PIL import Image + +from megatron.bridge.models.wan.flow_matching.flow_inference_pipeline import FlowInferencePipeline +from megatron.bridge.models.wan.inference.configs import SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS +from megatron.bridge.models.wan.inference.utils.utils import cache_video, str2bool + +EXAMPLE_PROMPT = { + "t2v-1.3B": { + "prompt": + "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + }, + "t2v-14B": { + "prompt": + "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + }, +} + + +def _validate_args(args): + # Basic check + assert args.checkpoint_dir is not None, "Please specify the checkpoint directory." + assert args.t5_checkpoint_dir is not None, "Please specify the T5 checkpoint directory." + assert args.vae_checkpoint_dir is not None, "Please specify the VAE checkpoint directory." + assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}" + assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}" + + # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks. + if args.sample_steps is None: + args.sample_steps = 50 + + if args.sample_shift is None: + args.sample_shift = 5.0 + + # Frames default handled later; no single frame arg anymore + + args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint( + 0, sys.maxsize) + # Size check: only validate provided --sizes; default handled later + if args.sizes is not None and len(args.sizes) > 0: + for s in args.sizes: + assert s in SUPPORTED_SIZES[args.task], ( + f"Unsupport size {s} for task {args.task}, supported sizes are: " + f"{', '.join(SUPPORTED_SIZES[args.task])}") + + +def _parse_args(): + parser = argparse.ArgumentParser( + description="Generate a image or video from a text prompt or image using Wan" + ) + parser.add_argument( + "--task", + type=str, + default="t2v-14B", + choices=list(WAN_CONFIGS.keys()), + help="The task to run.") + parser.add_argument( + "--sizes", + type=str, + nargs="+", + default=None, + choices=list(SIZE_CONFIGS.keys()), + help="A list of sizes to generate multiple images or videos (WIDTH*HEIGHT). Example: --sizes 1280*720 1920*1080" + ) + parser.add_argument( + "--frame_nums", + type=int, + nargs="+", + default=None, + help="List of frame counts (each should be 4n+1). Broadcasts if single value." + ) + parser.add_argument( + "--checkpoint_dir", + type=str, + default=None, + help="The path to the main WAN checkpoint directory.") + parser.add_argument( + "--checkpoint_step", + type=int, + default=None, + help=( + "Optional training step to load, e.g. 1800 -> iter_0001800. " + "If not provided, the latest (largest) step in --checkpoint_dir is used.") + ) + parser.add_argument( + "--t5_checkpoint_dir", + type=str, + default=None, + help="Optional directory containing T5 checkpoint/tokenizer") + parser.add_argument( + "--vae_checkpoint_dir", + type=str, + default=None, + help="Optional directory containing VAE checkpoint") + parser.add_argument( + "--offload_model", + type=str2bool, + default=None, + help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage." + ) + parser.add_argument( + "--t5_cpu", + action="store_true", + default=False, + help="Whether to place T5 model on CPU.") + parser.add_argument( + "--save_file", + type=str, + default=None, + help="The file to save the generated image or video to.") + parser.add_argument( + "--prompts", + type=str, + nargs="+", + default=None, + help="A list of prompts to generate multiple images or videos. Example: --prompts 'a cat' 'a dog'" + ) + parser.add_argument( + "--base_seed", + type=int, + default=-1, + help="The seed to use for generating the image or video.") + parser.add_argument( + "--sample_solver", + type=str, + default='unipc', + choices=['unipc', 'dpm++'], + help="The solver used to sample.") + parser.add_argument( + "--sample_steps", type=int, default=None, help="The sampling steps.") + parser.add_argument( + "--sample_shift", + type=float, + default=None, + help="Sampling shift factor for flow matching schedulers.") + parser.add_argument( + "--sample_guide_scale", + type=float, + default=5.0, + help="Classifier free guidance scale.") + parser.add_argument( + "--tensor_parallel_size", + type=int, + default=1, + help="Tensor parallel size.") + parser.add_argument( + "--context_parallel_size", + type=int, + default=1, + help="Context parallel size.") + parser.add_argument( + "--pipeline_parallel_size", + type=int, + default=1, + help="Pipeline parallel size.") + parser.add_argument( + "--sequence_parallel", + type=str2bool, + default=False, + help="Sequence parallel.") + + args = parser.parse_args() + + _validate_args(args) + + return args + + +def _init_logging(rank): + # logging + if rank == 0: + # set format + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] %(levelname)s: %(message)s", + handlers=[logging.StreamHandler(stream=sys.stdout)]) + else: + logging.basicConfig(level=logging.ERROR) + + +def generate(args): + rank = int(os.getenv("RANK", 0)) + world_size = int(os.getenv("WORLD_SIZE", 1)) + local_rank = int(os.getenv("LOCAL_RANK", 0)) + device = local_rank + _init_logging(rank) + + if args.offload_model is None: + args.offload_model = False if world_size > 1 else True + logging.info( + f"offload_model is not specified, set to {args.offload_model}.") + if world_size > 1: + torch.cuda.set_device(local_rank) + dist.init_process_group( + backend="nccl", + init_method="env://", + rank=rank, + world_size=world_size) + + cfg = WAN_CONFIGS[args.task] + + logging.info(f"Generation job args: {args}") + logging.info(f"Generation model config: {cfg}") + + if dist.is_initialized(): + base_seed = [args.base_seed] if rank == 0 else [None] + dist.broadcast_object_list(base_seed, src=0) + args.base_seed = base_seed[0] + + if "t2v" in args.task: + # Resolve prompts list (default to example prompt) + if args.prompts is not None and len(args.prompts) > 0: + prompts = args.prompts + else: + prompts = [EXAMPLE_PROMPT[args.task]["prompt"]] + + # Resolve sizes list (default to first supported size for task) + if args.sizes is not None and len(args.sizes) > 0: + size_keys = args.sizes + else: + size_keys = [SUPPORTED_SIZES[args.task][0]] + + # Resolve frame counts list (default 81) + if args.frame_nums is not None and len(args.frame_nums) > 0: + frame_nums = args.frame_nums + else: + frame_nums = [81] + + # Enforce 1:1 pairing across lists + assert len(prompts) == len(size_keys) == len(frame_nums), ( + f"prompts ({len(prompts)}), sizes ({len(size_keys)}), and frame_nums ({len(frame_nums)}) " + f"must have the same length") + + logging.info("Creating flow inference pipeline.") + pipeline = FlowInferencePipeline( + config=cfg, + checkpoint_dir=args.checkpoint_dir, + checkpoint_step=args.checkpoint_step, + t5_checkpoint_dir=args.t5_checkpoint_dir, + vae_checkpoint_dir=args.vae_checkpoint_dir, + device_id=device, + rank=rank, + t5_cpu=args.t5_cpu, + tensor_parallel_size=args.tensor_parallel_size, + context_parallel_size=args.context_parallel_size, + pipeline_parallel_size=args.pipeline_parallel_size, + sequence_parallel=args.sequence_parallel, + pipeline_dtype=torch.float32, + ) + + # DEBUGGING + rank = dist.get_rank() + if rank == 0: + print("tensor_parallel_size:", args.tensor_parallel_size) + print("context_parallel_size:", args.context_parallel_size) + print("pipeline_parallel_size:", args.pipeline_parallel_size) + print("sequence_parallel:", args.sequence_parallel) + print("\n\n\n") + + logging.info( + f"Generating videos ...") + videos = pipeline.generate( + prompts=prompts, + sizes=[SIZE_CONFIGS[size] for size in size_keys], + frame_nums=frame_nums, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) + + if rank == 0: + for i, video in enumerate(videos): + formatted_experiment_name = (args.save_file) if args.save_file is not None else "DefaultExp" + formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S") + formatted_prompt = prompts[i].replace(" ", "_").replace("/", + "_")[:50] + suffix = '.mp4' + formatted_save_file = f"{args.task}_{formatted_experiment_name}_videoindex{int(i)}_size{size_keys[i].replace('*','x') if sys.platform=='win32' else size_keys[i]}_{formatted_prompt}_{formatted_time}" + suffix + + if "t2v" in args.task: + logging.info(f"Saving generated video to {formatted_save_file}") + cache_video( + tensor=video[None], + save_file=formatted_save_file, + fps=cfg.sample_fps, + nrow=1, + normalize=True, + value_range=(-1, 1)) + logging.info("Finished.") + + +if __name__ == "__main__": + args = _parse_args() + generate(args) diff --git a/examples/recipes/wan/pretrain_wan.py b/examples/recipes/wan/pretrain_wan.py new file mode 100644 index 0000000000..d6a492f655 --- /dev/null +++ b/examples/recipes/wan/pretrain_wan.py @@ -0,0 +1,184 @@ + +#!/usr/bin/env python3 +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Wan Pretraining Script with YAML and CLI Configuration Overrides. + +This script provides a flexible way to pretrain Wan models using Megatron-Bridge with support for +both YAML configuration files and command-line overrides using Hydra-style syntax. + +Examples: + Basic usage with default configuration: + $ torchrun --nproc_per_node=8 pretrain_wan.py + + Using a custom YAML config file: + $ torchrun --nproc_per_node=8 pretrain_wan.py --config-file my_custom_config.yaml + + Using CLI overrides only: + $ torchrun --nproc_per_node=8 pretrain_wan.py model.tensor_model_parallel_size=4 train.train_iters=100000 + + Combining YAML and CLI overrides (CLI takes precedence): + $ torchrun --nproc_per_node=8 pretrain_wan.py --config-file conf/my_config.yaml \ + model.pipeline_dtype=torch.float16 \ + train.global_batch_size=512 + +Configuration Precedence: + 1. Base configuration from pretrain_config() recipe + 2. YAML overrides from --config-file (if provided) + 3. CLI overrides (highest precedence) + +Supported Override Syntax: + - Standard assignment: key=value + - Nested assignment: section.subsection.key=value + - Addition: +new_key=value + - Deletion: ~key_to_remove + - Type conversion: Automatic for basic types (int, float, bool, str) + - Complex types: torch.dtype, enums, etc. are supported +""" + +import argparse +import logging +import os +import sys +from pathlib import Path +from typing import Tuple + +from omegaconf import OmegaConf + +from megatron.bridge.recipes.wan.wan import pretrain_config +from megatron.bridge.training.config import ConfigContainer +from megatron.bridge.models.wan.wan_step import WanForwardStep +from megatron.bridge.training.pretrain import pretrain +from megatron.bridge.training.utils.omegaconf_utils import ( + apply_overrides, + create_omegaconf_dict_config, + parse_hydra_overrides, +) +from megatron.bridge.utils.common_utils import get_rank_safe + + +logger: logging.Logger = logging.getLogger(__name__) + + +# Define paths relative to this script's location +# Assumes this script (pretrain_wan.py) is in Megatron-Bridge/examples/recipes/wan/ +# and the config is in a 'conf' subdirectory. +SCRIPT_DIR: Path = Path(__file__).parent.resolve() +DEFAULT_CONFIG_FILENAME: str = "wan_pretrain_override_example.yaml" +DEFAULT_CONFIG_FILE_PATH: Path = SCRIPT_DIR / "conf" / DEFAULT_CONFIG_FILENAME + +# DEBUGGING +import numpy as np +import torch +np.set_printoptions(precision=10, suppress=False) +torch.set_printoptions(precision=10, sci_mode=False) + +def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]: + """Parse command line arguments, separating known script args from OmegaConf overrides.""" + parser = argparse.ArgumentParser( + description="Pretrain Wan model using Megatron-Bridge with YAML and CLI overrides", + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + "--config-file", + type=str, + default=str(DEFAULT_CONFIG_FILE_PATH), + help="Path to the YAML OmegaConf override file. Default: conf/wan_pretrain_override_example.yaml", + ) + parser.add_argument("--debug", action="store_true", help="Enable debug logging") + + # Parse known args for the script, remaining will be treated as overrides + args, cli_dotlist_overrides = parser.parse_known_args() + return args, cli_dotlist_overrides + + +def main() -> None: + """ + Entry point for the Wan pretraining script. + + This function orchestrates the complete configuration workflow: + 1. Loads the base configuration from pretrain_config() recipe + 2. Applies YAML overrides from --config-file (if exists) + 3. Applies CLI overrides using Hydra-style syntax + 4. Starts Megatron pretraining with the final merged configuration + + Configuration merging preserves callable fields (like activation functions) + and handles type conversions automatically. + + Examples of CLI usage: + # Use default config with custom learning rate + torchrun --nproc_per_node=8 pretrain_wan.py optimizer.lr=0.0002 + + # Custom config file with additional overrides + torchrun --nproc_per_node=8 pretrain_wan.py --config-file my_config.yaml train.train_iters=50000 + + # Multiple overrides for distributed training + torchrun --nproc_per_node=8 pretrain_wan.py \ + model.tensor_model_parallel_size=4 \ + model.pipeline_model_parallel_size=2 \ + train.global_batch_size=512 + """ + args, cli_overrides = parse_cli_args() + + logger.info("Megatron-Bridge Wan Pretraining Script with YAML & CLI Overrides") + logger.info("------------------------------------------------------------------") + + # Load base configuration from the recipe as a Python dataclass + cfg: ConfigContainer = pretrain_config() + logger.info("Loaded base configuration") + + # Print configuration on rank 0 + if get_rank_safe() == 0: + cfg.print_yaml() + + # Convert the initial Python dataclass to an OmegaConf DictConfig for merging + merged_omega_conf, excluded_fields = create_omegaconf_dict_config(cfg) + + # Load and merge YAML overrides if a config file is provided + if args.config_file: + logger.debug(f"Loading YAML overrides from: {args.config_file}") + if not os.path.exists(args.config_file): + logger.error(f"Override YAML file not found: {args.config_file}") + sys.exit(1) + yaml_overrides_omega = OmegaConf.load(args.config_file) + merged_omega_conf = OmegaConf.merge(merged_omega_conf, yaml_overrides_omega) + logger.debug("YAML overrides merged successfully.") + + # Apply command-line overrides using Hydra-style parsing + if cli_overrides: + logger.debug(f"Applying Hydra-style command-line overrides: {cli_overrides}") + merged_omega_conf = parse_hydra_overrides(merged_omega_conf, cli_overrides) + logger.debug("Hydra-style command-line overrides applied successfully.") + + # Apply the final merged OmegaConf configuration back to the original ConfigContainer + logger.debug("Applying final merged configuration back to Python ConfigContainer...") + final_overrides_as_dict = OmegaConf.to_container(merged_omega_conf, resolve=True) + # Apply overrides while preserving excluded fields + apply_overrides(cfg, final_overrides_as_dict, excluded_fields) + + # Display final configuration + if get_rank_safe() == 0: + logger.info("--- Final Merged Configuration ---") + cfg.print_yaml() + logger.info("----------------------------------") + + # Start training + logger.debug("Starting pretraining...") + pretrain(config=cfg, forward_step_func=WanForwardStep()) + + +if __name__ == "__main__": + main() diff --git a/src/megatron/bridge/data/loaders.py b/src/megatron/bridge/data/loaders.py index 6c3aeda95c..7d45114436 100644 --- a/src/megatron/bridge/data/loaders.py +++ b/src/megatron/bridge/data/loaders.py @@ -219,7 +219,11 @@ def worker_init_fn(_): valid_dataloader = build_pretraining_data_loader( valid_ds, train_state.consumed_valid_samples, - "cyclic", + # DEBUGGING + # known issue: + # https://nvidia.slack.com/archives/C09MX7UEB0W/p1761316355203679 + # "cyclic", + "external", cfg.train.micro_batch_size, cfg.dataset.num_workers, cfg.dataset.data_sharding, diff --git a/src/megatron/bridge/data/wan/prepare_energon_dataset_wan.py b/src/megatron/bridge/data/wan/prepare_energon_dataset_wan.py new file mode 100644 index 0000000000..a8464aa6ec --- /dev/null +++ b/src/megatron/bridge/data/wan/prepare_energon_dataset_wan.py @@ -0,0 +1,404 @@ +import os +import json +import pickle +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import cv2 +import numpy as np +import torch +import webdataset as wds + +from diffusers import AutoencoderKLWan +from transformers import AutoTokenizer, UMT5EncoderModel + + +def _map_interpolation(resize_mode: str) -> int: + interpolation_map = { + "bilinear": cv2.INTER_LINEAR, + "bicubic": cv2.INTER_CUBIC, + "nearest": cv2.INTER_NEAREST, + "area": cv2.INTER_AREA, + "lanczos": cv2.INTER_LANCZOS4, + } + if resize_mode not in interpolation_map: + raise ValueError(f"Invalid resize_mode '{resize_mode}'. Choose from: {list(interpolation_map.keys())}") + return interpolation_map[resize_mode] + + +def _calculate_resize_dimensions( + original_height: int, + original_width: int, + target_size: Optional[Tuple[int, int]], + maintain_aspect_ratio: bool, +) -> Tuple[int, int]: + if target_size is None: + return original_height, original_width + + target_height, target_width = target_size + if not maintain_aspect_ratio: + return target_height, target_width + + original_aspect = original_width / max(1, original_height) + target_aspect = target_width / max(1, target_height) + + if original_aspect > target_aspect: + new_width = target_width + new_height = int(round(target_width / max(1e-6, original_aspect))) + else: + new_height = target_height + new_width = int(round(target_height * original_aspect)) + + return new_height, new_width + + +def _resize_frame( + frame: np.ndarray, + target_size: Optional[Tuple[int, int]], + resize_mode: str, + maintain_aspect_ratio: bool, + center_crop: bool, +) -> np.ndarray: + if target_size is None: + return frame + + original_height, original_width = frame.shape[:2] + resize_height, resize_width = _calculate_resize_dimensions( + original_height, original_width, target_size, maintain_aspect_ratio + ) + + interpolation = _map_interpolation(resize_mode) + resized_frame = cv2.resize(frame, (resize_width, resize_height), interpolation=interpolation) + + if maintain_aspect_ratio and center_crop: + target_height, target_width = target_size + if resize_height != target_height or resize_width != target_width: + y_start = max(0, (resize_height - target_height) // 2) + x_start = max(0, (resize_width - target_width) // 2) + y_end = min(resize_height, y_start + target_height) + x_end = min(resize_width, x_start + target_width) + resized_frame = resized_frame[y_start:y_end, x_start:x_end] + + if resized_frame.shape[0] < target_height or resized_frame.shape[1] < target_width: + pad_height = max(0, target_height - resized_frame.shape[0]) + pad_width = max(0, target_width - resized_frame.shape[1]) + resized_frame = np.pad( + resized_frame, ((0, pad_height), (0, pad_width), (0, 0)), mode="constant", constant_values=0 + ) + + return resized_frame + + +def _read_sidecar_caption(jsonl_path: Path) -> str: + if not jsonl_path.exists(): + return "" + try: + with open(jsonl_path, "r") as f: + for line in f: + line = line.strip() + if not line: + continue + try: + obj = json.loads(line) + except Exception: + continue + # Prefer keys used across datasets + for key in ("vila_caption", "gemini_v2_caption", "caption", "text"): + if key in obj and isinstance(obj[key], str): + return obj[key] + # If no known key, try first string value + for v in obj.values(): + if isinstance(v, str): + return v + break + except Exception: + return "" + return "" + + +def _get_total_frames(video_path: str) -> int: + cap = cv2.VideoCapture(video_path) + total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + cap.release() + return max(0, total) + + +def _load_metadata(video_folder: Path) -> List[Dict]: + meta_path = video_folder / "meta.json" + if meta_path.exists(): + with open(meta_path, "r") as f: + return json.load(f) + + # Fallback: scan for .mp4 files with sidecar .jsonl; use full frame range + items: List[Dict] = [] + for entry in sorted(video_folder.iterdir()): + if not entry.is_file(): + continue + if entry.suffix.lower() != ".mp4": + continue + video_name = entry.name + video_path = str(entry) + total_frames = _get_total_frames(video_path) + start_frame = 0 + end_frame = max(0, total_frames - 1) + sidecar = entry.with_suffix("") + # Handle names with additional dots gracefully + sidecar_jsonl = Path(str(entry).rsplit(".", 1)[0] + ".jsonl") + caption = _read_sidecar_caption(sidecar_jsonl) + items.append( + { + "file_name": video_name, + "start_frame": start_frame, + "end_frame": end_frame, + "vila_caption": caption, + } + ) + if not items: + raise FileNotFoundError(f"No meta.json and no .mp4 files found in {video_folder}") + return items + + +def _load_frames_cv2( + video_path: str, + start_frame: int, + end_frame: int, + target_size: Optional[Tuple[int, int]], + resize_mode: str, + maintain_aspect_ratio: bool, + center_crop: bool, + target_dtype: torch.dtype, +) -> torch.Tensor: + cap = cv2.VideoCapture(video_path) + frames: List[np.ndarray] = [] + + cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame) + for frame_idx in range(start_frame, end_frame + 1): + ret, frame = cap.read() + if not ret: + break + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame = _resize_frame(frame, target_size, resize_mode, maintain_aspect_ratio, center_crop) + frame = frame.astype(np.float32) / 255.0 + frames.append(frame) + cap.release() + + if not frames: + raise ValueError(f"No frames loaded from {video_path}") + + video_array = np.array(frames) # T, H, W, C in [0,1] + video_tensor = torch.from_numpy(video_array) # T, H, W, C + video_tensor = video_tensor.permute(3, 0, 1, 2).unsqueeze(0) # 1, C, T, H, W + video_tensor = video_tensor.to(dtype=target_dtype) + return video_tensor + + +@torch.no_grad() +def _init_hf_models( + model_id: str, + device: str, + enable_memory_optimization: bool, +): + dtype = torch.float16 if device.startswith("cuda") else torch.float32 + + text_encoder = UMT5EncoderModel.from_pretrained( + model_id, + subfolder="text_encoder", + torch_dtype=dtype, + ) + text_encoder.to(device) + text_encoder.eval() + + vae = AutoencoderKLWan.from_pretrained( + model_id, + subfolder="vae", + torch_dtype=dtype, + ) + vae.to(device) + vae.eval() + if enable_memory_optimization: + vae.enable_slicing() + vae.enable_tiling() + + tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer") + + return vae, text_encoder, tokenizer, dtype + + +@torch.no_grad() +def _encode_text( + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + device: str, + caption: str, +) -> torch.Tensor: + caption = caption.strip() + inputs = tokenizer( + caption, + max_length=512, + padding="max_length", + truncation=True, + return_tensors="pt", + return_attention_mask=True, + ) + inputs = {k: v.to(device) for k, v in inputs.items()} + outputs = text_encoder(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]).last_hidden_state + return outputs + + +@torch.no_grad() +def _encode_video_latents( + vae: AutoencoderKLWan, + device: str, + video_tensor: torch.Tensor, + deterministic_latents: bool, +) -> torch.Tensor: + video_tensor = video_tensor.to(device=device, dtype=vae.dtype) + video_tensor = video_tensor * 2.0 - 1.0 # [0,1] -> [-1,1] + + latent_dist = vae.encode(video_tensor) + if deterministic_latents: + video_latents = latent_dist.latent_dist.mean + else: + video_latents = latent_dist.latent_dist.sample() + + latent_mean = video_latents.mean().item() + latent_std = video_latents.std().item() + + if abs(latent_mean) < 0.5 and 0.5 < latent_std < 2.0: + final_latents = video_latents + else: + if not hasattr(vae.config, "latents_mean") or not hasattr(vae.config, "latents_std"): + raise ValueError("Wan2.1 VAE requires latents_mean and latents_std in config") + latents_mean = torch.tensor(vae.config.latents_mean, device=device, dtype=vae.dtype).view(1, -1, 1, 1, 1) + latents_std = torch.tensor(vae.config.latents_std, device=device, dtype=vae.dtype).view(1, -1, 1, 1, 1) + final_latents = (video_latents - latents_mean) / latents_std + + return final_latents + + +def main(): + import argparse + + parser = argparse.ArgumentParser( + description="Prepare WAN WebDataset shards using HF automodel encoders and resizing" + ) + parser.add_argument("--video_folder", type=str, required=True, help="Folder containing videos and meta.json") + parser.add_argument("--output_dir", type=str, required=True, help="Directory to write webdataset shards") + parser.add_argument( + "--model", + default="Wan-AI/Wan2.1-T2V-14B-Diffusers", + help="Wan2.1 model ID (e.g., Wan-AI/Wan2.1-T2V-14B-Diffusers or Wan-AI/Wan2.1-T2V-1.3B-Diffusers)", + ) + parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use") + parser.add_argument( + "--stochastic", + action="store_true", + help="Use stochastic encoding (sampling) instead of deterministic posterior mean", + ) + parser.add_argument("--no-memory-optimization", action="store_true", help="Disable VAE slicing/tiling") + parser.add_argument("--shard_maxcount", type=int, default=10000, help="Max samples per shard") + + # Resize arguments (match automodel) + parser.add_argument("--height", type=int, default=None, help="Target height for video frames") + parser.add_argument("--width", type=int, default=None, help="Target width for video frames") + parser.add_argument( + "--resize_mode", + default="bilinear", + choices=["bilinear", "bicubic", "nearest", "area", "lanczos"], + help="Interpolation mode for resizing", + ) + parser.add_argument("--no-aspect-ratio", action="store_true", help="Disable aspect ratio preservation") + parser.add_argument("--center-crop", action="store_true", help="Center crop to exact target size after resize") + + args = parser.parse_args() + + video_folder = Path(args.video_folder) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + shard_pattern = str(output_dir / "shard-%06d.tar") + + # Target size + target_size = None + if args.height is not None and args.width is not None: + target_size = (args.height, args.width) + elif (args.height is None) ^ (args.width is None): + parser.error("Both --height and --width must be specified together") + + # Init HF models + vae, text_encoder, tokenizer, model_dtype = _init_hf_models( + model_id=args.model, + device=args.device, + enable_memory_optimization=not args.no_memory_optimization, + ) + + # Load metadata list + metadata_list = _load_metadata(video_folder) + + with wds.ShardWriter(shard_pattern, maxcount=args.shard_maxcount) as sink: + written = 0 + for index, meta in enumerate(metadata_list): + video_name = meta["file_name"] + start_frame = int(meta["start_frame"]) # inclusive + end_frame = int(meta["end_frame"]) # inclusive + caption_text = meta.get("vila_caption", "") + + video_path = str(video_folder / video_name) + # Load frames using the same OpenCV + resize path as automodel + video_tensor = _load_frames_cv2( + video_path=video_path, + start_frame=start_frame, + end_frame=end_frame, + target_size=target_size, + resize_mode=args.resize_mode, + maintain_aspect_ratio=not args.no_aspect_ratio, + center_crop=args.center_crop, + target_dtype=model_dtype, + ) + + # Encode text and video with HF models exactly like automodel + text_embed = _encode_text(tokenizer, text_encoder, args.device, caption_text) + latents = _encode_video_latents(vae, args.device, video_tensor, deterministic_latents=not args.stochastic) + + # Move to CPU without changing dtype; keep exact values to match automodel outputs + text_embed_cpu = text_embed.detach().to(device="cpu") + latents_cpu = latents.detach().to(device="cpu") + + # Reshape to match Mcore's Wan input format + text_embed_cpu = text_embed_cpu[0] + latents_cpu = latents_cpu[0] + + # Build JSON side-info similar to prepare_energon script + C, T, H, W = video_tensor.shape[1:] # 1,C,T,H,W + json_data = { + "video_path": video_path, + "processed_frames": int(T), + "processed_height": int(H), + "processed_width": int(W), + "caption": caption_text, + "deterministic_latents": bool(not args.stochastic), + "memory_optimization": bool(not args.no_memory_optimization), + "model_version": "wan2.1", + "resize_settings": { + "target_size": target_size, + "resize_mode": args.resize_mode, + "maintain_aspect_ratio": bool(not args.no_aspect_ratio), + "center_crop": bool(args.center_crop), + }, + } + + sample = { + "__key__": f"{index:06}", + "pth": latents_cpu, + "pickle": pickle.dumps(text_embed_cpu), + "json": json_data, + } + sink.write(sample) + written += 1 + + print("Done writing shards using HF automodel encoders.") + + +if __name__ == "__main__": + main() + + diff --git a/src/megatron/bridge/data/wan/wan_energon_datamodule.py b/src/megatron/bridge/data/wan/wan_energon_datamodule.py new file mode 100644 index 0000000000..98774e8157 --- /dev/null +++ b/src/megatron/bridge/data/wan/wan_energon_datamodule.py @@ -0,0 +1,47 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +from dataclasses import dataclass +import logging +from typing import Any, Dict, Literal + +from torch import int_repr + +from megatron.bridge.data.Dit.data.diffusion_energon_datamodule import DiffusionDataModule +from megatron.bridge.data.wan.wan_taskencoder import WanTaskEncoder +from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider + +@dataclass(kw_only=True) +class WanDataModuleConfig(DatasetProvider): + path: str + seq_length: int + micro_batch_size: int + global_batch_size: int + num_workers: int_repr + dataloader_type: str = "external" + + def __post_init__(self): + self.dataset = DiffusionDataModule( + path=self.path, + seq_length=self.seq_length, + task_encoder=WanTaskEncoder(seq_length=self.seq_length), + micro_batch_size=self.micro_batch_size, + global_batch_size=self.global_batch_size, + num_workers=self.num_workers) + self.sequence_length = self.dataset.seq_length + + def build_datasets(self, context: DatasetBuildContext): + return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader() \ No newline at end of file diff --git a/src/megatron/bridge/data/wan/wan_taskencoder.py b/src/megatron/bridge/data/wan/wan_taskencoder.py new file mode 100644 index 0000000000..a19f755617 --- /dev/null +++ b/src/megatron/bridge/data/wan/wan_taskencoder.py @@ -0,0 +1,192 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +import torch +import torch.nn.functional as F +from megatron.energon import DefaultTaskEncoder, SkipSample +from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys +from megatron.bridge.models.wan.utils.utils import grid_sizes_calculation, patchify +from megatron.core import parallel_state + + +def cook(sample: dict) -> dict: + """ + Processes a raw sample dictionary from energon dataset and returns a new dictionary with specific keys. + + Args: + sample (dict): The input dictionary containing the raw sample data. + + Returns: + dict: A new dictionary containing the processed sample data with the following keys: + - All keys from the result of `basic_sample_keys(sample)` + - 'json': The contains meta data like resolution, aspect ratio, fps, etc. + - 'pth': contains video latent tensor + - 'pickle': contains text embeddings + """ + return dict( + **basic_sample_keys(sample), + json=sample["json"], + pth=sample["pth"], + pickle=sample["pickle"], + ) + + +class WanTaskEncoder(DefaultTaskEncoder): + """ + Task encoder for Wan dataset. + Attributes: + cookers (list): A list of Cooker objects used for processing. + patch_spatial (int): The spatial patch size. Defaults to 2. + patch_temporal (int): The temporal patch size. Defaults to 1. + seq_length (int): The sequence length. Defaults to 1024. + """ + + cookers = [ + Cooker(cook), + ] + + def __init__( + self, + *args, + max_frames: int = None, + patch_spatial: int = 2, + patch_temporal: int = 1, + seq_length: int = 1024, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.patch_spatial = patch_spatial + self.patch_temporal = patch_temporal + self.seq_length = seq_length + + + def encode_sample(self, sample: dict) -> dict: + + video_latent = sample["pth"] + context_embeddings = sample["pickle"] + video_metadata = sample["json"] + + # sanity quality check + if torch.isnan(video_latent).any() or torch.isinf(video_latent).any(): + raise SkipSample() + if torch.max(torch.abs(video_latent)) > 1e3: + raise SkipSample() + + # calculate grid size + grid_size = grid_sizes_calculation( + input_shape = video_latent.shape[1:], + patch_size = (self.patch_temporal, self.patch_spatial, self.patch_spatial), + ) + + ### Note: shape of sample's values + # video_latent: [latents_channels, F_latents, W_latents, H_latents] + # grid_size: [F_patches, W_patches, H_patches] + # context_embeddings: [context_seq_len, text_embedding_dim] + + return dict( + video_latent=video_latent, + grid_size=grid_size, + context_embeddings=context_embeddings, + video_metadata=video_metadata, + ) + + + # def encode_sample(self, sample: dict) -> dict: + + # # mock encode sample + # video_latent = torch.tensor(torch.randn(16, 3, 104, 60), dtype=torch.float32) + # # video_latent = torch.tensor(torch.randn(16, 24, 104, 60), dtype=torch.float32) + # grid_size = torch.tensor([video_latent.shape[1] // self.patch_temporal, video_latent.shape[2] // self.patch_spatial, video_latent.shape[3] // self.patch_spatial], dtype=torch.int32) + # context_embeddings = torch.tensor(torch.randn(512, 4096), dtype=torch.float32) + # video_metadata = {} + + # return dict( + # video_latent=video_latent, + # grid_size=grid_size, + # context_embeddings=context_embeddings, + # video_metadata=video_metadata, + # ) + + + def batch(self, samples: list[dict]) -> dict: + + # process video latents + # do padding here for video latents + self.patch_size = (self.patch_temporal, self.patch_spatial, self.patch_spatial) + + # running patchify + video_latents = patchify([sample["video_latent"] for sample in samples], self.patch_size) + + # build per-sample loss masks (1 for valid tokens pre-padding) + loss_masks = [torch.ones(v.shape[0]) for v in video_latents] + # calculate all sequence lengths of video latents for self-attention (for videos, we do this before padding to get original seq len) + seq_len_q = [v.shape[0] for v in video_latents] + seq_len_q = torch.tensor(seq_len_q, dtype=torch.int32) + + + # padding and stack video latents + max_video_seq_len = max([video_latent.shape[0] for video_latent in video_latents]) + # CAVEAT: + # when using pipeline parallelism, we need to set batch sequence length to DataModule's seq_length because + # because pipeline parallelism requires pre-specified sequence length to create buffer + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + if max_video_seq_len > self.seq_length: + raise ValueError(f"max_video_seq_len {max_video_seq_len} is greater than DataModule's seq_length {self.seq_length}") + else: + # set max_video_seq_len to DataModule's seq_length + max_video_seq_len = self.seq_length + # CAVEAT: + # when using context parallelism, we need to pad batch sequence length to be divisible by [cp_rank*2] + # (because TransformerEngine's context parallelism requires "AssertionError: Sequence length per GPU needs to be divisible by 2!") + if parallel_state.get_context_parallel_world_size() > 1: + batch_size = len(video_latents) + assert batch_size == 1, "Error: Batch size must be 1 when using context parallelism" + sharding_factor = parallel_state.get_context_parallel_world_size() * 2 + max_video_seq_len = ((max_video_seq_len + sharding_factor - 1) // sharding_factor) * sharding_factor + video_latents = [F.pad(video_latent, (0, 0, 0, max_video_seq_len - video_latent.shape[0])) for video_latent in video_latents] + video_latents = torch.stack(video_latents, dim=1) + # pad and stack loss masks to shape [S_max, B] + loss_masks = [F.pad(m, (0, max_video_seq_len - m.shape[0])) for m in loss_masks] + loss_masks = torch.stack(loss_masks, dim=1) + + # process grid sizes + grid_sizes = [torch.tensor(sample["grid_size"], dtype=torch.int32) for sample in samples] + grid_sizes = torch.stack(grid_sizes, dim=0) + + # process text embeddings + # pad here for text embeddings + context_max_len = 512 + context_embeddings = [sample["context_embeddings"] for sample in samples] + context_embeddings = [F.pad(context_embedding, (0, 0, 0, context_max_len - context_embedding.shape[0])) for context_embedding in context_embeddings] + # calculate all sequence lengths of context embeddings for cross-attention (for videos, we do this after padding to get padded seq len) + seq_len_kv = [c.shape[0] for c in context_embeddings] + seq_len_kv = torch.tensor(seq_len_kv, dtype=torch.int32) + # stack context embeddings + context_embeddings = torch.stack(context_embeddings, dim=1) + + # process video metadata + video_metadata = [sample["video_metadata"] for sample in samples] + + return dict( + video_latents = video_latents, + max_video_seq_len = max_video_seq_len, + grid_sizes = grid_sizes, + context_embeddings = context_embeddings, + loss_mask = loss_masks, + seq_len_q = seq_len_q, + seq_len_kv = seq_len_kv, + video_metadata = video_metadata, + ) \ No newline at end of file diff --git a/src/megatron/bridge/models/conversion/param_mapping.py b/src/megatron/bridge/models/conversion/param_mapping.py index 70e33b3734..e3014dcb49 100644 --- a/src/megatron/bridge/models/conversion/param_mapping.py +++ b/src/megatron/bridge/models/conversion/param_mapping.py @@ -1339,6 +1339,90 @@ def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping": ) +class KVMapping(MegatronParamMapping[Dict[str, torch.Tensor]]): + """ + Mapping for interleaved Key/Value projection weights. + + This mapping converts between separate K and V tensors used in external + checkpoints and Megatron's interleaved KV format following grouped-query + attention semantics. + + External format (HF) + - Separate tensors: k_proj, v_proj + - Shapes mirror QKV mappings but without Q + + Megatron format + - Single interleaved tensor with order: [k1, v1, k2, v2, ...] + where index corresponds to query-group id + + Tensor-parallel distribution is delegated to AutoMapping. + """ + + def __init__(self, megatron_param: str, k: str, v: str): + super().__init__(megatron_param, {"k": k, "v": v}) + # Delegate TP sharding/broadcasting + self._tp_mapping = AutoMapping(megatron_param, megatron_param) + + def hf_to_megatron( + self, + hf_weights: Dict[str, torch.Tensor], + megatron_module: nn.Module, + ) -> torch.Tensor: + """Merge K and V into interleaved format and distribute across TP.""" + if self.tp_rank == 0: + config = self._get_config(megatron_module) + + if hf_weights["k"].ndim == 1: + merged = merge_kv_biases(config, hf_weights["k"], hf_weights["v"]) + else: + merged = merge_kv_weights(config, hf_weights["k"], hf_weights["v"]) + else: + merged = None + + return self._tp_mapping.hf_to_megatron(merged, megatron_module) + + def megatron_to_hf( + self, + megatron_weights: Optional[torch.Tensor], + megatron_module: Optional[nn.Module], + ) -> Dict[str, torch.Tensor]: + """Gather KV shards and split into separate K and V tensors.""" + if megatron_weights is not None: + megatron_weights = self.maybe_dequantize(megatron_weights) + + # Ensure all PP ranks participate in config broadcast + if megatron_module is None: + config = self.broadcast_obj_from_pp_rank(None, "kv_config") + else: + config = self._get_config(megatron_module) + config = remove_non_pickleables(config, max_depth=2) + config = self.broadcast_obj_from_pp_rank(config, "kv_config") + + packed_dict = self._tp_mapping.megatron_to_hf(megatron_weights, megatron_module) + if not packed_dict: + return {} + + packed_kv = next(iter(packed_dict.values())) + + if packed_kv.ndim == 1: + k, v = split_kv_biases(config, packed_kv) + else: + k, v = split_kv_weights(config, packed_kv) + + return { + self.hf_param["k"]: k, + self.hf_param["v"]: v, + } + + def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping": + resolved_megatron_param, resolved_hf_param = self._resolve_names(captures) + return type(self)( + resolved_megatron_param, + resolved_hf_param["k"], + resolved_hf_param["v"], + ) + + class GatedMLPMapping(MegatronParamMapping[Dict[str, torch.Tensor]]): r"""Mapping for **gated-MLP** projection weights (SwiGLU / GeGLU). @@ -1652,3 +1736,71 @@ def split_qkv_weights( v = v.reshape(-1, hidden_size) return q, k, v + + +def merge_kv_biases(config: TransformerConfig, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """Merge separate K, V bias vectors into Megatron's interleaved KV format (1D).""" + num_query_groups = config.num_query_groups + head_size = config.kv_channels or (config.hidden_size // config.num_attention_heads) + + k = k.view(num_query_groups, head_size) + v = v.view(num_query_groups, head_size) + + pieces: List[torch.Tensor] = [] + for i in range(num_query_groups): + pieces.append(k[i : i + 1, :]) + pieces.append(v[i : i + 1, :]) + + kv = torch.cat(pieces, dim=0) + return kv.reshape(-1) + + +def split_kv_biases(config: TransformerConfig, kv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Split Megatron's interleaved KV bias (1D) into separate K and V biases.""" + num_query_groups = config.num_query_groups + head_size = config.kv_channels or (config.hidden_size // config.num_attention_heads) + kv_total_dim = 2 * num_query_groups + + kv_reshaped = kv.view(kv_total_dim, head_size) + + k_slice = torch.arange(0, kv_total_dim, 2) + v_slice = torch.arange(1, kv_total_dim, 2) + + k = kv_reshaped[k_slice].reshape(-1) + v = kv_reshaped[v_slice].reshape(-1) + return k, v + + +def merge_kv_weights(provider: TransformerConfig, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """Merge separate K, V weights into Megatron's interleaved KV format (2D).""" + num_query_groups = provider.num_query_groups + head_size = provider.kv_channels or (provider.hidden_size // provider.num_attention_heads) + hidden_size = provider.hidden_size + + k_reshaped = k.view(num_query_groups, head_size, hidden_size) + v_reshaped = v.view(num_query_groups, head_size, hidden_size) + + pieces: List[torch.Tensor] = [] + for i in range(num_query_groups): + pieces.append(k_reshaped[i : i + 1]) + pieces.append(v_reshaped[i : i + 1]) + + kv = torch.cat(pieces, dim=0) + return kv.view(-1, hidden_size) + + +def split_kv_weights(provider: TransformerConfig, kv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Split Megatron's interleaved KV weights (2D) into separate K and V matrices.""" + num_query_groups = provider.num_query_groups + head_size = provider.kv_channels or (provider.hidden_size // provider.num_attention_heads) + hidden_size = kv.shape[-1] + kv_total_dim = 2 * num_query_groups + + kv_reshaped = kv.view(kv_total_dim, head_size, hidden_size) + + k_slice = torch.arange(0, kv_total_dim, 2) + v_slice = torch.arange(1, kv_total_dim, 2) + + k = kv_reshaped[k_slice].reshape(-1, hidden_size) + v = kv_reshaped[v_slice].reshape(-1, hidden_size) + return k, v diff --git a/src/megatron/bridge/models/hf_pretrained/__init__.py b/src/megatron/bridge/models/hf_pretrained/__init__.py index de1604f253..9bfb9fd83f 100644 --- a/src/megatron/bridge/models/hf_pretrained/__init__.py +++ b/src/megatron/bridge/models/hf_pretrained/__init__.py @@ -14,6 +14,7 @@ from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM from megatron.bridge.models.hf_pretrained.vlm import PreTrainedVLM +from megatron.bridge.models.hf_pretrained.wan import PreTrainedWAN -__all__ = ["PreTrainedCausalLM", "PreTrainedVLM"] +__all__ = ["PreTrainedCausalLM", "PreTrainedVLM", "PreTrainedWAN"] diff --git a/src/megatron/bridge/models/hf_pretrained/state.py b/src/megatron/bridge/models/hf_pretrained/state.py index a47a22771d..b35f2c05f9 100644 --- a/src/megatron/bridge/models/hf_pretrained/state.py +++ b/src/megatron/bridge/models/hf_pretrained/state.py @@ -496,7 +496,8 @@ def key_to_filename_map(self) -> Dict[str, str]: from safetensors import safe_open key_map = {} - safetensor_files = file_glob(str(self.path / "*.safetensors")) + # DEBUGGING + safetensor_files = file_glob(str(self.path / "transformer" / "*.safetensors")) for file_path in safetensor_files: filename = os.path.basename(file_path) try: @@ -564,7 +565,8 @@ def get_all_keys(self) -> List[str]: all_keys.update(key_to_filename_map.keys()) if not all_keys: - safetensor_files = file_glob(str(self.path / "*.safetensors")) + # DEBUGGING + safetensor_files = file_glob(str(self.path / "transformer" / "*.safetensors")) if not safetensor_files and not key_to_filename_map: raise FileNotFoundError(f"No .safetensors files or index found in {self.model_name_or_path}") for safetensor_file in safetensor_files: @@ -603,7 +605,8 @@ def load_tensors(self, keys_to_load: List[str]) -> Dict[str, torch.Tensor]: remaining_keys.discard(key) if remaining_keys: - safetensor_files = file_glob(str(self.path / "*.safetensors")) + # DEBUGGING + safetensor_files = file_glob(str(self.path / "transformer" / "*.safetensors")) if not safetensor_files and not key_to_filename_map and not loaded_tensors: raise FileNotFoundError( f"No .safetensors files found in {self.model_name_or_path} to load keys: {remaining_keys}" @@ -650,7 +653,8 @@ def has_glob(self, pattern: str) -> bool: return False # If no index map, scan the files directly. - safetensor_files = file_glob(str(self.path / "*.safetensors")) + # DEBUGGING + safetensor_files = file_glob(str(self.path / "transformer" / "*.safetensors")) if not safetensor_files: return False diff --git a/src/megatron/bridge/models/hf_pretrained/wan.py b/src/megatron/bridge/models/hf_pretrained/wan.py new file mode 100644 index 0000000000..d682c5cf07 --- /dev/null +++ b/src/megatron/bridge/models/hf_pretrained/wan.py @@ -0,0 +1,81 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import Optional, Union + +from diffusers import WanTransformer3DModel, WanVACETransformer3DModel +from transformers import AutoConfig + +from megatron.bridge.models.hf_pretrained.base import PreTrainedBase + + +class PreTrainedWAN(PreTrainedBase): + """ + Lightweight pretrained wrapper for Diffusers WAN models. + + Provides access to WAN config and state through the common PreTrainedBase API + so bridges can consume `.config` and `.state` uniformly. + """ + + def __init__(self, model_name_or_path: Union[str, Path], **kwargs): + self._model_name_or_path = str(model_name_or_path) + super().__init__(**kwargs) + + @property + def model_name_or_path(self) -> str: + return self._model_name_or_path + + # Model loading is optional for conversion; implemented for completeness + def _load_model(self) -> WanTransformer3DModel: + return WanTransformer3DModel.from_pretrained(self.model_name_or_path, subfolder="transformer") + + # Config is required by the WAN bridge + def _load_config(self) -> AutoConfig: + # WanTransformer3DModel returns a config-like object with required fields + + print(f"Loading config from {self.model_name_or_path}") + + return WanTransformer3DModel.from_pretrained(self.model_name_or_path, subfolder="transformer").config + + +class PreTrainedVACE(PreTrainedBase): + """ + Lightweight pretrained wrapper for Diffusers WAN models. + + Provides access to WAN config and state through the common PreTrainedBase API + so bridges can consume `.config` and `.state` uniformly. + """ + + def __init__(self, model_name_or_path: Union[str, Path], **kwargs): + self._model_name_or_path = str(model_name_or_path) + super().__init__(**kwargs) + + @property + def model_name_or_path(self) -> str: + return self._model_name_or_path + + # Model loading is optional for conversion; implemented for completeness + def _load_model(self) -> WanVACETransformer3DModel: + return WanVACETransformer3DModel.from_pretrained(self.model_name_or_path, subfolder="transformer") + + # Config is required by the WAN bridge + def _load_config(self) -> AutoConfig: + # WanTransformer3DModel returns a config-like object with required fields + + print(f"Loading config from {self.model_name_or_path}") + + return WanVACETransformer3DModel.from_pretrained(self.model_name_or_path, subfolder="transformer").config + + diff --git a/src/megatron/bridge/models/wan/flow_matching/__init__.py b/src/megatron/bridge/models/wan/flow_matching/__init__.py new file mode 100644 index 0000000000..d9155f923f --- /dev/null +++ b/src/megatron/bridge/models/wan/flow_matching/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py b/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py new file mode 100644 index 0000000000..385bf6d741 --- /dev/null +++ b/src/megatron/bridge/models/wan/flow_matching/flow_inference_pipeline.py @@ -0,0 +1,1330 @@ +import gc +import logging +import math +import os +import random +import sys +import types +import re +from contextlib import contextmanager +from functools import partial + +from PIL import Image +import torchvision.transforms.functional as TF +import torch +import torch.cuda.amp as amp +import torch.distributed as dist +from tqdm import tqdm + +from megatron.bridge.models.wan.wan_model import WanModel, VACEModel +from megatron.bridge.models.wan.wan_provider import WanModelProvider, VACEModelProvider +from megatron.bridge.models.wan.modules.t5 import T5EncoderModel +from megatron.bridge.models.wan.modules import WanVAE +from megatron.bridge.models.wan.inference.utils.fm_solvers import ( + FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, + retrieve_timesteps, +) +from megatron.bridge.models.wan.inference.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from megatron.bridge.models.wan.utils.utils import grid_sizes_calculation, patchify +from megatron.core import parallel_state +from torch.nn import functional as F +from megatron.bridge.models.wan.utils.utils import split_inputs_cp, cat_outputs_cp, thd_split_inputs_cp, thd_cat_outputs_cp + +import math +from typing import Tuple, Union + +from ..utils.preprocessor import VaceVideoProcessor + +class FlowInferencePipeline: + + def __init__( + self, + config, + checkpoint_dir, + checkpoint_step=None, + t5_checkpoint_dir=None, + vae_checkpoint_dir=None, + device_id=0, + rank=0, + t5_cpu=False, + + tensor_parallel_size=1, + context_parallel_size=1, + pipeline_parallel_size=1, + sequence_parallel=False, + pipeline_dtype=torch.float32, + ): + r""" + Initializes the FlowInferencePipeline with the given parameters. + + Args: + config (EasyDict): + Object containing model parameters initialized from config.py + checkpoint_dir (`str`): + Path to directory containing model checkpoints + t5_checkpoint_dir (`str`, *optional*, defaults to None): + Optional directory containing T5 checkpoint and tokenizer; falls back to `checkpoint_dir` if None. + vae_checkpoint_dir (`str`, *optional*, defaults to None): + Optional directory containing VAE checkpoint; falls back to `checkpoint_dir` if None. + device_id (`int`, *optional*, defaults to 0): + Id of target GPU device + rank (`int`, *optional*, defaults to 0): + Process rank for distributed training + t5_cpu (`bool`, *optional*, defaults to False): + Whether to place T5 model on CPU. Only works without t5_fsdp. + """ + self.device = torch.device(f"cuda:{device_id}") + self.config = config + self.rank = rank + self.t5_cpu = t5_cpu + self.tensor_parallel_size = tensor_parallel_size + self.context_parallel_size = context_parallel_size + self.pipeline_parallel_size = pipeline_parallel_size + self.sequence_parallel = sequence_parallel + self.pipeline_dtype = pipeline_dtype + self.num_train_timesteps = config.num_train_timesteps + self.param_dtype = config.param_dtype + + self.text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + device=torch.device('cpu'), + checkpoint_path=os.path.join(t5_checkpoint_dir, config.t5_checkpoint), + tokenizer_path=os.path.join(t5_checkpoint_dir, config.t5_tokenizer), + shard_fn=None) + + log_checkpoint("before vae") + + self.vae_stride = config.vae_stride + self.patch_size = config.patch_size + self.vae = WanVAE( + vae_pth=os.path.join(vae_checkpoint_dir, config.vae_checkpoint), + device=self.device) + + wan_checkpoint_dir = self._select_checkpoint_dir(checkpoint_dir, checkpoint_step) + self.model = self.setup_model_from_checkpoint(wan_checkpoint_dir) + + # if we use context parallelism, we need to set qkv_format to "thd" for context parallelism + self.model.config.qkv_format = "thd" # "sbhd" + + # set self.sp_size=1 for later use, just to respect the original Wan inference code + self.sp_size = 1 + + if dist.is_initialized(): + dist.barrier() + self.model.to(self.device) + + log_checkpoint("after transformer") + + self.sample_neg_prompt = config.sample_neg_prompt + + + def unpatchify(self, x: torch.Tensor, grid_sizes: torch.Tensor, out_dim: int) -> list[torch.Tensor]: + r""" + Reconstruct video tensors from patch embeddings into a list of videotensors. + + Args: + x (torch.Tensor): + Tensor of patchified features, with shape [seq_len, c * pF * pH * pW] + grid_sizes (Tensor): + Original spatial-temporal grid dimensions before patching, + shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + list[torch.Tensor]: list of tensors, each with shape [c, F_latents, H_latents, W_latents] + """ + + c = out_dim + out = [] + for u, v in zip(x, grid_sizes.tolist()): + u = u[:math.prod(v)].view(*v, *self.patch_size, c) + u = torch.einsum('fhwpqrc->cfphqwr', u) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + out.append(u) + return out + + + def setup_model_from_checkpoint(self, checkpoint_dir): + provider = WanModelProvider() + provider.tensor_model_parallel_size = self.tensor_parallel_size + provider.pipeline_model_parallel_size = self.pipeline_parallel_size + provider.context_parallel_size = self.context_parallel_size + provider.sequence_parallel = self.sequence_parallel + provider.pipeline_dtype = self.pipeline_dtype + # Once all overrides are set, finalize the model provider to ensure the post initialization logic is run + provider.finalize() + provider.initialize_model_parallel(seed=0) + + ## Read from megatron checkpoint + from megatron.bridge.training.model_load_save import load_megatron_model as _load_megatron_model + model = _load_megatron_model( + checkpoint_dir, + mp_overrides={ + "tensor_model_parallel_size": self.tensor_parallel_size, + "pipeline_model_parallel_size": self.pipeline_parallel_size, + "context_parallel_size": self.context_parallel_size, + "sequence_parallel": self.sequence_parallel, + "pipeline_dtype": self.pipeline_dtype, + }, + ) + if isinstance(model, list): + model = model[0] + # for i in list(model.state_dict().keys()): + # print(i) + if hasattr(model, "module"): + model = model.module + # for ly in model.decoder.layers: + # print(ly.idx) + return model + + def _select_checkpoint_dir(self, base_dir: str, checkpoint_step) -> str: + """ + Resolve checkpoint directory: + - If checkpoint_step is provided, use base_dir/iter_{step:07d} + - Otherwise, pick the largest iter_######## subdirectory under base_dir + """ + if checkpoint_step is not None: + path = os.path.join(base_dir, f"iter_{int(checkpoint_step):07d}") + if os.path.isdir(path): + logging.info(f"Using specified checkpoint: {path}") + return path + raise FileNotFoundError(f"Specified checkpoint step {checkpoint_step} not found at {path}") + + if not os.path.isdir(base_dir): + raise FileNotFoundError(f"Checkpoint base directory does not exist: {base_dir}") + + pattern = re.compile(r"^iter_(\d+)$") + try: + _, latest_path = max( + ((int(pattern.match(e.name).group(1)), e.path) + for e in os.scandir(base_dir) + if e.is_dir() and pattern.match(e.name)), + key=lambda x: x[0], + ) + except ValueError: + raise FileNotFoundError( + f"No checkpoints found under {base_dir}. Expected subdirectories named like 'iter_0001800'.") + + logging.info(f"Auto-selected latest checkpoint: {latest_path}") + return latest_path + + + def forward_pp_step( + self, + latent_model_input: torch.Tensor, + grid_sizes: list[Tuple[int, int, int]], + max_video_seq_len: int, + timestep: torch.Tensor, + arg_c: dict, + ) -> torch.Tensor: + """ + Forward pass supporting pipeline parallelism. + """ + + from megatron.core import parallel_state + from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage, recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank + + pp_world_size = parallel_state.get_pipeline_model_parallel_world_size() + is_pp_first = parallel_state.is_pipeline_first_stage(ignore_virtual=True) + is_pp_last = parallel_state.is_pipeline_last_stage(ignore_virtual=True) + + # PP=1: no pipeline parallelism + if pp_world_size == 1: + noise_pred_pp = self.model( + latent_model_input, + grid_sizes=grid_sizes, + t=timestep, + **arg_c) + return noise_pred_pp + + # PP>1: pipeline parallelism + hidden_size = self.model.config.hidden_size + batch_size = latent_model_input.shape[1] + # noise prediction shape for communication between first and last pipeline stages + noise_pred_pp_shape = list(latent_model_input.shape) + + if is_pp_first: + # First stage: compute multimodal + first PP slice, send activations, then receive sampled token + hidden_states = self.model( + latent_model_input, + grid_sizes=grid_sizes, + t=timestep, + **arg_c) + send_to_next_pipeline_rank(hidden_states) + + noise_pred_pp = broadcast_from_last_pipeline_stage(noise_pred_pp_shape, dtype=torch.float32) + return noise_pred_pp + + if is_pp_last: + # Last stage: recv activations, run final slice + output, sample, broadcast + recv_buffer = torch.empty( + (max_video_seq_len, batch_size, hidden_size), + dtype=next(self.model.parameters()).dtype, + device=latent_model_input[0].device, + ) + recv_from_prev_pipeline_rank_(recv_buffer) + recv_buffer = recv_buffer.to(torch.bfloat16) # ???? + self.model.set_input_tensor(recv_buffer) + noise_pred_pp = self.model( + latent_model_input, + grid_sizes=grid_sizes, + t=timestep, + **arg_c) + + noise_pred_pp = broadcast_from_last_pipeline_stage(noise_pred_pp_shape, dtype=noise_pred_pp.dtype, tensor=noise_pred_pp.contiguous()) + return noise_pred_pp + + # Intermediate stages: recv -> run local slice -> send -> receive broadcast token + recv_buffer = torch.empty( + (max_video_seq_len, batch_size, hidden_size), + dtype=next(self.model.parameters()).dtype, + device=latent_model_input[0].device, + ) + recv_from_prev_pipeline_rank_(recv_buffer) + recv_buffer = recv_buffer.to(torch.bfloat16) # ???? + self.model.set_input_tensor(recv_buffer) + hidden_states = self.model( + latent_model_input, + grid_sizes=grid_sizes, + t=timestep, + **arg_c) + send_to_next_pipeline_rank(hidden_states) + + noise_pred_pp = broadcast_from_last_pipeline_stage(noise_pred_pp_shape, dtype=torch.float32) + return noise_pred_pp + + + def generate(self, + prompts, + sizes, + frame_nums, + shift=5.0, + sample_solver='unipc', + sampling_steps=50, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True): + r""" + Generates video frames from text prompt using diffusion process. + + Args: + prompts (`list[str]`): + Text prompt for content generation + sizes (list[tuple[int, int]]): + Controls video resolution, (width,height). + frame_nums (`list[int]`): + How many frames to sample from a video. The number should be 4n+1 + shift (`float`, *optional*, defaults to 5.0): + Noise schedule shift parameter. Affects temporal dynamics + sample_solver (`str`, *optional*, defaults to 'unipc'): + Solver used to sample the video. + sampling_steps (`int`, *optional*, defaults to 40): + Number of diffusion sampling steps. Higher values improve quality but slow generation + guide_scale (`float`, *optional*, defaults 5.0): + Classifier-free guidance scale. Controls prompt adherence vs. creativity + n_prompt (`str`, *optional*, defaults to ""): + Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` + seed (`int`, *optional*, defaults to -1): + Random seed for noise generation. If -1, use random seed. + offload_model (`bool`, *optional*, defaults to True): + If True, offloads models to CPU during generation to save VRAM + + Returns: + torch.Tensor: + Generated video frames tensor. Dimensions: (C, N H, W) where: + - C: Color channels (3 for RGB) + - N: Number of frames (81) + - H: Frame height (from size) + - W: Frame width from size) + """ + + # preprocess + target_shapes = [] + for size, frame_num in zip(sizes, frame_nums): + target_shapes.append((self.vae.model.z_dim, (frame_num - 1) // self.vae_stride[0] + 1, + size[1] // self.vae_stride[1], + size[0] // self.vae_stride[2])) + + max_video_seq_len = 0 + seq_lens = [] + for target_shape in target_shapes: + seq_len = math.ceil((target_shape[2] * target_shape[3]) / + (self.patch_size[1] * self.patch_size[2]) * + target_shape[1] / self.sp_size) * self.sp_size + seq_lens.append(seq_len) + max_video_seq_len = max(seq_lens) + + if n_prompt == "": + n_prompt = self.sample_neg_prompt + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) + + + ## process context + context_max_len = 512 + context_lens = [] + contexts = [] + contexts_null = [] + for prompt in prompts: + if not self.t5_cpu: + self.text_encoder.model.to(self.device) + context = self.text_encoder([prompt], self.device)[0] + context_null = self.text_encoder([n_prompt], self.device)[0] + if offload_model: + self.text_encoder.model.cpu() + else: + context = self.text_encoder([prompt], torch.device('cpu'))[0].to(self.device) + context_null = self.text_encoder([n_prompt], torch.device('cpu'))[0].to(self.device) + context_lens.append(context_max_len) # all samples have the same context_max_len + contexts.append(context) + contexts_null.append(context_null) + # pad to context_max_len tokens, and stack to a tensor of shape [s, b, hidden] + contexts = [F.pad(context, (0, 0, 0, context_max_len - context.shape[0])) for context in contexts] + contexts_null = [F.pad(context_null, (0, 0, 0, context_max_len - context_null.shape[0])) for context_null in contexts_null] + contexts = torch.stack(contexts, dim=1) + contexts_null = torch.stack(contexts_null, dim=1) + + + ## setup noise + noises = [] + for target_shape in target_shapes: + noises.append( + torch.randn( + target_shape[0], + target_shape[1], + target_shape[2], + target_shape[3], + dtype=torch.float32, + device=self.device, + generator=seed_g) + ) + + + # calculate grid_sizes + grid_sizes = [grid_sizes_calculation( + input_shape =u.shape[1:], + patch_size=self.model.patch_size, + ) for u in noises] + grid_sizes = torch.tensor(grid_sizes, dtype=torch.long) + + + @contextmanager + def noop_no_sync(): + yield + + no_sync = getattr(self.model, 'no_sync', noop_no_sync) + + # evaluation mode + with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): + + if sample_solver == 'unipc': + # Create a prototype scheduler to compute shared timesteps + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sample_scheduler.set_timesteps( + sampling_steps, device=self.device, shift=shift) + timesteps = sample_scheduler.timesteps + + # Instantiate per-sample schedulers so each sample maintains its own state + batch_size_for_schedulers = len(noises) + schedulers = [] + for _ in range(batch_size_for_schedulers): + s = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + s.set_timesteps(sampling_steps, device=self.device, shift=shift) + schedulers.append(s) + elif sample_solver == 'dpm++': + sample_scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + timesteps, _ = retrieve_timesteps( + sample_scheduler, + device=self.device, + sigmas=sampling_sigmas) + else: + raise NotImplementedError("Unsupported solver.") + + # sample videos + latents = noises + + from megatron.core.packed_seq_params import PackedSeqParams + cu_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)]) + cu_q = cu_q.to(torch.int32).to(self.device) + cu_kv_self = cu_q + cu_kv_cross = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(context_lens), dim=0)]) + cu_kv_cross = cu_kv_cross.to(torch.int32).to(self.device) + packed_seq_params = { + "self_attention": PackedSeqParams( + cu_seqlens_q=cu_q, + cu_seqlens_kv=cu_kv_self, + qkv_format=self.model.config.qkv_format, + ), + "cross_attention": PackedSeqParams( + cu_seqlens_q=cu_q, + cu_seqlens_kv=cu_kv_cross, + qkv_format=self.model.config.qkv_format, + ), + } + + + # context parallel + if parallel_state.get_context_parallel_world_size() > 1: + contexts = thd_split_inputs_cp(contexts, packed_seq_params['cross_attention'].cu_seqlens_kv, parallel_state.get_context_parallel_group()) + contexts_null = thd_split_inputs_cp(contexts_null, packed_seq_params['cross_attention'].cu_seqlens_kv, parallel_state.get_context_parallel_group()) + + + arg_c = {'context': contexts, 'max_seq_len': max_video_seq_len, 'packed_seq_params': packed_seq_params} + arg_null = {'context': contexts_null, 'max_seq_len': max_video_seq_len, 'packed_seq_params': packed_seq_params} + + for _, t in enumerate(tqdm(timesteps)): + + batch_size = len(latents) + + # patchify latents + unpatchified_latents = latents + latents = patchify(latents, self.patch_size) + # pad to have same length + for i in range(batch_size): + latents[i] = F.pad(latents[i], (0, 0, 0, max_video_seq_len - latents[i].shape[0])) + latents = torch.stack(latents, dim=1) + + + # context parallel + if parallel_state.get_context_parallel_world_size() > 1: + latents = thd_split_inputs_cp(latents, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) + + + latent_model_input = latents + timestep = [t] * batch_size + timestep = torch.stack(timestep) + + self.model.to(self.device) + noise_pred_cond = self.forward_pp_step( + latent_model_input, grid_sizes=grid_sizes, max_video_seq_len=max_video_seq_len, timestep=timestep, arg_c=arg_c) + + noise_pred_uncond = self.forward_pp_step( + latent_model_input, grid_sizes=grid_sizes, max_video_seq_len=max_video_seq_len, timestep=timestep, arg_c=arg_null) + + + # context parallel + if parallel_state.get_context_parallel_world_size() > 1: + noise_pred_cond = thd_cat_outputs_cp(noise_pred_cond, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) + noise_pred_uncond = thd_cat_outputs_cp(noise_pred_uncond, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) + + + # run unpatchify + unpatchified_noise_pred_cond = noise_pred_cond + unpatchified_noise_pred_cond = unpatchified_noise_pred_cond.transpose(0, 1) # bring sbhd -> bshd + # when unpatchifying, the code will truncate the padded videos into the original video shape, based on the grid_sizes. + unpatchified_noise_pred_cond = self.unpatchify(unpatchified_noise_pred_cond, grid_sizes, self.vae.model.z_dim) + unpatchified_noise_pred_uncond = noise_pred_uncond + unpatchified_noise_pred_uncond = unpatchified_noise_pred_uncond.transpose(0, 1) # bring sbhd -> bshd + # when unpatchifying, the code will truncate the padded videos into the original video shape, based on the grid_sizes. + unpatchified_noise_pred_uncond = self.unpatchify(unpatchified_noise_pred_uncond, grid_sizes, self.vae.model.z_dim) + + noise_preds = [] + for i in range(batch_size): + noise_pred = unpatchified_noise_pred_uncond[i] + guide_scale * ( + unpatchified_noise_pred_cond[i] - unpatchified_noise_pred_uncond[i]) + noise_preds.append(noise_pred) + + # step and update latents + latents = [] + for i in range(batch_size): + + if sample_solver == 'unipc': + temp_x0 = schedulers[i].step( + noise_preds[i].unsqueeze(0), + t, + unpatchified_latents[i].unsqueeze(0), + return_dict=False, + generator=seed_g)[0] + else: + temp_x0 = sample_scheduler.step( + noise_preds[i].unsqueeze(0), + t, + unpatchified_latents[i].unsqueeze(0), + return_dict=False, + generator=seed_g)[0] + latents.append(temp_x0.squeeze(0)) + + x0 = latents + if offload_model: + self.model.cpu() + torch.cuda.empty_cache() + if self.rank == 0: + videos = self.vae.decode(x0) + else: + videos = None + + del noises, latents + if sample_solver == 'unipc': + del schedulers + else: + del sample_scheduler + if offload_model: + gc.collect() + torch.cuda.synchronize() + if dist.is_initialized(): + dist.barrier() + + return videos if self.rank == 0 else None + + +def log_checkpoint(tag): + torch.cuda.synchronize() + alloc = torch.cuda.memory_allocated() / 1024**3 + reserved = torch.cuda.memory_reserved() / 1024**3 + print(f"[{tag}] alloc={alloc:.2f} GB reserved={reserved:.2f} GB") + + +class VACEFlowInferencePipeline: + + def __init__( + self, + config, + checkpoint_dir, + checkpoint_step=None, + t5_checkpoint_dir=None, + vae_checkpoint_dir=None, + device_id=0, + rank=0, + t5_cpu=False, + + tensor_parallel_size=1, + context_parallel_size=1, + pipeline_parallel_size=1, + sequence_parallel=False, + pipeline_dtype=torch.float32, + ): + r""" + Initializes the FlowInferencePipeline with the given parameters. + + Args: + config (EasyDict): + Object containing model parameters initialized from config.py + checkpoint_dir (`str`): + Path to directory containing model checkpoints + t5_checkpoint_dir (`str`, *optional*, defaults to None): + Optional directory containing T5 checkpoint and tokenizer; falls back to `checkpoint_dir` if None. + vae_checkpoint_dir (`str`, *optional*, defaults to None): + Optional directory containing VAE checkpoint; falls back to `checkpoint_dir` if None. + device_id (`int`, *optional*, defaults to 0): + Id of target GPU device + rank (`int`, *optional*, defaults to 0): + Process rank for distributed training + t5_cpu (`bool`, *optional*, defaults to False): + Whether to place T5 model on CPU. Only works without t5_fsdp. + """ + self.device = torch.device(f"cuda:{device_id}") + self.config = config + self.rank = rank + self.t5_cpu = t5_cpu + self.tensor_parallel_size = tensor_parallel_size + self.context_parallel_size = context_parallel_size + self.pipeline_parallel_size = pipeline_parallel_size + self.sequence_parallel = sequence_parallel + self.pipeline_dtype = pipeline_dtype + self.num_train_timesteps = config.num_train_timesteps + self.param_dtype = config.param_dtype + + self.text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + device=torch.device('cpu'), + checkpoint_path=os.path.join(t5_checkpoint_dir, config.t5_checkpoint), + tokenizer_path=os.path.join(t5_checkpoint_dir, config.t5_tokenizer), + shard_fn=None) + + log_checkpoint("before vae") + + self.vae_stride = config.vae_stride + self.patch_size = config.patch_size + self.vae = WanVAE( + vae_pth=os.path.join(vae_checkpoint_dir, config.vae_checkpoint), + device=self.device) + + wan_checkpoint_dir = self._select_checkpoint_dir(checkpoint_dir, checkpoint_step) + self.model = self.setup_model_from_checkpoint(wan_checkpoint_dir) + + # if we use context parallelism, we need to set qkv_format to "thd" for context parallelism + self.model.config.qkv_format = "thd" # "sbhd" + + # set self.sp_size=1 for later use, just to respect the original Wan inference code + self.sp_size = 1 + + if dist.is_initialized(): + dist.barrier() + self.model.to(self.device) + + log_checkpoint("after transformer") + + self.sample_neg_prompt = config.sample_neg_prompt + + self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(self.vae_stride, self.patch_size)]), + min_area=832 *480, + max_area=832 *480, + min_fps=self.config.sample_fps, + max_fps=self.config.sample_fps, + zero_start=True, + seq_len=32760, + keep_last=True) + + + def unpatchify(self, x: torch.Tensor, grid_sizes: torch.Tensor, out_dim: int) -> list[torch.Tensor]: + r""" + Reconstruct video tensors from patch embeddings into a list of videotensors. + + Args: + x (torch.Tensor): + Tensor of patchified features, with shape [seq_len, c * pF * pH * pW] + grid_sizes (Tensor): + Original spatial-temporal grid dimensions before patching, + shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + list[torch.Tensor]: list of tensors, each with shape [c, F_latents, H_latents, W_latents] + """ + + c = out_dim + out = [] + for u, v in zip(x, grid_sizes.tolist()): + u = u[:math.prod(v)].view(*v, *self.patch_size, c) + u = torch.einsum('fhwpqrc->cfphqwr', u) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + out.append(u) + return out + + + def setup_model_from_checkpoint(self, checkpoint_dir): + provider = VACEModelProvider() + provider.tensor_model_parallel_size = self.tensor_parallel_size + provider.pipeline_model_parallel_size = self.pipeline_parallel_size + provider.context_parallel_size = self.context_parallel_size + provider.sequence_parallel = self.sequence_parallel + provider.pipeline_dtype = self.pipeline_dtype + # Once all overrides are set, finalize the model provider to ensure the post initialization logic is run + provider.finalize() + provider.initialize_model_parallel(seed=0) + + ## Read from megatron checkpoint + from megatron.bridge.training.model_load_save import load_megatron_model as _load_megatron_model + model = _load_megatron_model( + checkpoint_dir, + mp_overrides={ + "tensor_model_parallel_size": self.tensor_parallel_size, + "pipeline_model_parallel_size": self.pipeline_parallel_size, + "context_parallel_size": self.context_parallel_size, + "sequence_parallel": self.sequence_parallel, + "pipeline_dtype": self.pipeline_dtype, + }, + ) + if isinstance(model, list): + model = model[0] + if hasattr(model, "module"): + model = model.module + return model + + def _select_checkpoint_dir(self, base_dir: str, checkpoint_step) -> str: + """ + Resolve checkpoint directory: + - If checkpoint_step is provided, use base_dir/iter_{step:07d} + - Otherwise, pick the largest iter_######## subdirectory under base_dir + """ + if checkpoint_step is not None: + path = os.path.join(base_dir, f"iter_{int(checkpoint_step):07d}") + if os.path.isdir(path): + logging.info(f"Using specified checkpoint: {path}") + return path + raise FileNotFoundError(f"Specified checkpoint step {checkpoint_step} not found at {path}") + + if not os.path.isdir(base_dir): + raise FileNotFoundError(f"Checkpoint base directory does not exist: {base_dir}") + + pattern = re.compile(r"^iter_(\d+)$") + try: + _, latest_path = max( + ((int(pattern.match(e.name).group(1)), e.path) + for e in os.scandir(base_dir) + if e.is_dir() and pattern.match(e.name)), + key=lambda x: x[0], + ) + except ValueError: + raise FileNotFoundError( + f"No checkpoints found under {base_dir}. Expected subdirectories named like 'iter_0001800'.") + + logging.info(f"Auto-selected latest checkpoint: {latest_path}") + return latest_path + + + def vace_encode_frames(self, frames, ref_images, masks=None): + vae = self.vae + if ref_images is None: + ref_images = [None] * len(frames) + else: + assert len(frames) == len(ref_images) + + if masks is None: + latents = vae.encode(frames) + else: + masks = [torch.where(m > 0.5, 1.0, 0.0) for m in masks] + inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)] + reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)] + inactive = vae.encode(inactive) + reactive = vae.encode(reactive) + latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] + + cat_latents = [] + for latent, refs in zip(latents, ref_images): + if refs is not None: + if masks is None: + ref_latent = vae.encode(refs) + else: + ref_latent = vae.encode(refs) + ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent] + assert all([x.shape[1] == 1 for x in ref_latent]) + latent = torch.cat([*ref_latent, latent], dim=1) + cat_latents.append(latent) + return cat_latents + + + def vace_encode_masks(self, masks, ref_images=None): + vae_stride = self.vae_stride + if ref_images is None: + ref_images = [None] * len(masks) + else: + assert len(masks) == len(ref_images) + + result_masks = [] + for mask, refs in zip(masks, ref_images): + c, depth, height, width = mask.shape + new_depth = int((depth + 3) // vae_stride[0]) + height = 2 * (int(height) // (vae_stride[1] * 2)) + width = 2 * (int(width) // (vae_stride[2] * 2)) + + # reshape + mask = mask[0, :, :, :] + mask = mask.view( + depth, height, vae_stride[1], width, vae_stride[1] + ) # depth, height, 8, width, 8 + mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width + mask = mask.reshape( + vae_stride[1] * vae_stride[2], depth, height, width + ) # 8*8, depth, height, width + + # interpolation + mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0) + + if refs is not None: + length = len(refs) + mask_pad = torch.zeros_like(mask[:, :length, :, :]) + mask = torch.cat((mask_pad, mask), dim=1) + result_masks.append(mask) + return result_masks + + + def vace_latent(self, z, m): + return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] + + + def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device): + area = image_size[0] * image_size[1] + self.vid_proc.set_area(area) + if area == 1280*720: + self.vid_proc.set_seq_len(75600) + elif area == 832*480: + self.vid_proc.set_seq_len(32760) + else: + raise NotImplementedError(f'image_size {image_size} is not supported') + + image_size = (image_size[1], image_size[0]) + image_sizes = [] + for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)): + if sub_src_mask is not None and sub_src_video is not None: + src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask) + src_video[i] = src_video[i].to(device) + src_mask[i] = src_mask[i].to(device) + src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1) + image_sizes.append(src_video[i].shape[2:]) + elif sub_src_video is None: + src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device) + src_mask[i] = torch.ones_like(src_video[i], device=device) + image_sizes.append(image_size) + else: + src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video) + src_video[i] = src_video[i].to(device) + src_mask[i] = torch.ones_like(src_video[i], device=device) + image_sizes.append(src_video[i].shape[2:]) + + for i, ref_images in enumerate(src_ref_images): + if ref_images is not None: + image_size = image_sizes[i] + for j, ref_img in enumerate(ref_images): + if ref_img is not None: + ref_img = Image.open(ref_img).convert("RGB") + ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) + if ref_img.shape[-2:] != image_size: + canvas_height, canvas_width = image_size + ref_height, ref_width = ref_img.shape[-2:] + white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1] + scale = min(canvas_height / ref_height, canvas_width / ref_width) + new_height = int(ref_height * scale) + new_width = int(ref_width * scale) + resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1) + top = (canvas_height - new_height) // 2 + left = (canvas_width - new_width) // 2 + white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image + ref_img = white_canvas + src_ref_images[i][j] = ref_img.to(device) + return src_video, src_mask, src_ref_images + + + def decode_latent(self, latent, ref_images=None): + vae = self.vae + if ref_images is None: + ref_images = [None] * len(latent) + else: + assert len(latent) == len(ref_images) + + trimed_latent = [] + for lat, refs in zip(latent, ref_images): + if refs is not None: + lat = lat[:, len(refs):, :, :] + trimed_latent.append(lat) + + return vae.decode(trimed_latent) + + + def forward_pp_step( + self, + latent_model_input: torch.Tensor, + grid_sizes: list[Tuple[int, int, int]], + max_video_seq_len: int, + timestep: torch.Tensor, + vace_context: torch.Tensor, + arg_c: dict, + ) -> torch.Tensor: + """ + Forward pass supporting pipeline parallelism. + """ + + from megatron.core import parallel_state + from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage, recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank + + pp_world_size = parallel_state.get_pipeline_model_parallel_world_size() + is_pp_first = parallel_state.is_pipeline_first_stage(ignore_virtual=True) + is_pp_last = parallel_state.is_pipeline_last_stage(ignore_virtual=True) + + # PP=1: no pipeline parallelism + if pp_world_size == 1: + noise_pred_pp = self.model( + latent_model_input, + grid_sizes=grid_sizes, + t=timestep, + vace_context=vace_context, + **arg_c) + return noise_pred_pp + + # # PP>1: pipeline parallelism + # hidden_size = self.model.config.hidden_size + # batch_size = latent_model_input.shape[1] + # # noise prediction shape for communication between first and last pipeline stages + # noise_pred_pp_shape = list(latent_model_input.shape) + + # if is_pp_first: + # # First stage: compute multimodal + first PP slice, send activations, then receive sampled token + # hidden_states = self.model( + # latent_model_input, + # grid_sizes=grid_sizes, + # t=timestep, + # **arg_c) + # send_to_next_pipeline_rank(hidden_states) + + # noise_pred_pp = broadcast_from_last_pipeline_stage(noise_pred_pp_shape, dtype=torch.float32) + # return noise_pred_pp + + # if is_pp_last: + # # Last stage: recv activations, run final slice + output, sample, broadcast + # recv_buffer = torch.empty( + # (max_video_seq_len, batch_size, hidden_size), + # dtype=next(self.model.parameters()).dtype, + # device=latent_model_input[0].device, + # ) + # recv_from_prev_pipeline_rank_(recv_buffer) + # recv_buffer = recv_buffer.to(torch.bfloat16) # ???? + # self.model.set_input_tensor(recv_buffer) + # noise_pred_pp = self.model( + # latent_model_input, + # grid_sizes=grid_sizes, + # t=timestep, + # **arg_c) + + # noise_pred_pp = broadcast_from_last_pipeline_stage(noise_pred_pp_shape, dtype=noise_pred_pp.dtype, tensor=noise_pred_pp.contiguous()) + # return noise_pred_pp + + # # Intermediate stages: recv -> run local slice -> send -> receive broadcast token + # recv_buffer = torch.empty( + # (max_video_seq_len, batch_size, hidden_size), + # dtype=next(self.model.parameters()).dtype, + # device=latent_model_input[0].device, + # ) + # recv_from_prev_pipeline_rank_(recv_buffer) + # recv_buffer = recv_buffer.to(torch.bfloat16) # ???? + # self.model.set_input_tensor(recv_buffer) + # hidden_states = self.model( + # latent_model_input, + # grid_sizes=grid_sizes, + # t=timestep, + # **arg_c) + # send_to_next_pipeline_rank(hidden_states) + + # noise_pred_pp = broadcast_from_last_pipeline_stage(noise_pred_pp_shape, dtype=torch.float32) + # return noise_pred_pp + + + def generate(self, + prompts, + input_frames, + input_masks, + input_ref_images, + sizes, + frame_nums, + shift=5.0, + sample_solver='unipc', + sampling_steps=50, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True): + r""" + Generates video frames from text prompt using diffusion process. + + Args: + prompts (`list[str]`): + Text prompt for content generation + Input_frames (`list[Tensor]`): + Input frames for content generation + Input_masks (`list[Tensor]`): + Input masks for content generation + Input_ref_images (`list[Tensor]`): + Input reference images for content generation + sizes (list[tuple[int, int]]): + Controls video resolution, (width,height). + frame_nums (`list[int]`): + How many frames to sample from a video. The number should be 4n+1 + shift (`float`, *optional*, defaults to 5.0): + Noise schedule shift parameter. Affects temporal dynamics + sample_solver (`str`, *optional*, defaults to 'unipc'): + Solver used to sample the video. + sampling_steps (`int`, *optional*, defaults to 40): + Number of diffusion sampling steps. Higher values improve quality but slow generation + guide_scale (`float`, *optional*, defaults 5.0): + Classifier-free guidance scale. Controls prompt adherence vs. creativity + n_prompt (`str`, *optional*, defaults to ""): + Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` + seed (`int`, *optional*, defaults to -1): + Random seed for noise generation. If -1, use random seed. + offload_model (`bool`, *optional*, defaults to True): + If True, offloads models to CPU during generation to save VRAM + + Returns: + torch.Tensor: + Generated video frames tensor. Dimensions: (C, N, H, W) where: + - C: Color channels (3 for RGB) + - N: Number of frames (81) + - H: Frame height (from size) + - W: Frame width from size) + """ + + + # process source video, mask, reference image + vace_context0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks) + mask0 = self.vace_encode_masks(input_masks, input_ref_images) + vace_context = self.vace_latent(vace_context0, mask0) + + # # for huggingface inference, latent shape: B, C_latent, N/4, H/8, W/8 + # vace_context_hf = torch.stack(vace_context) + + max_video_seq_len = 0 + seq_lens = [] + target_shapes = [] + for item in vace_context0: + target_shape = list(item.shape) + target_shape[0] = int(target_shape[0] / 2) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / + (self.patch_size[1] * self.patch_size[2]) * + target_shape[1] / self.sp_size) * self.sp_size + seq_lens.append(seq_len) + target_shapes.append(target_shape) + max_video_seq_len = max(seq_lens) + + vace_context = patchify(vace_context, self.patch_size) + # pad to have same length + for i in range(len(vace_context)): + vace_context[i] = F.pad(vace_context[i], (0, 0, 0, max_video_seq_len - vace_context[i].shape[0])) + vace_context = torch.stack(vace_context, dim=1) + + s, b, h = vace_context.shape + vace_context = vace_context.transpose(0, 1).reshape(s*b, 1, h) + + if n_prompt == "": + n_prompt = self.sample_neg_prompt + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) + + + ## process context + context_max_len = 512 + context_lens = [] + contexts = [] + contexts_null = [] + for prompt in prompts: + if not self.t5_cpu: + self.text_encoder.model.to(self.device) + context = self.text_encoder([prompt], self.device)[0] + context_null = self.text_encoder([n_prompt], self.device)[0] + if offload_model: + self.text_encoder.model.cpu() + else: + context = self.text_encoder([prompt], torch.device('cpu'))[0].to(self.device) + context_null = self.text_encoder([n_prompt], torch.device('cpu'))[0].to(self.device) + context_lens.append(context_max_len) # all samples have the same context_max_len + contexts.append(context) + contexts_null.append(context_null) + # pad to context_max_len tokens, and stack to a tensor of shape [s, b, hidden] + contexts = [F.pad(context, (0, 0, 0, context_max_len - context.shape[0])) for context in contexts] + contexts_null = [F.pad(context_null, (0, 0, 0, context_max_len - context_null.shape[0])) for context_null in contexts_null] + contexts = torch.stack(contexts, dim=1) + contexts_null = torch.stack(contexts_null, dim=1) + + s, b, h = contexts.shape + contexts = contexts.transpose(0, 1).reshape(s*b, 1, h) + contexts_null = contexts_null.transpose(0, 1).reshape(s*b, 1, h) + + ## setup noise + noises = [] + for target_shape in target_shapes: + noises.append( + torch.randn( + target_shape[0], + target_shape[1], + target_shape[2], + target_shape[3], + dtype=torch.float32, + device=self.device, + generator=seed_g) + ) + # noises = noises[:1] * len(noises) + + # calculate grid_sizes + grid_sizes = [grid_sizes_calculation( + input_shape =u.shape[1:], + patch_size=self.model.patch_size, + ) for u in noises] + grid_sizes = torch.tensor(grid_sizes, dtype=torch.long) + + + @contextmanager + def noop_no_sync(): + yield + + no_sync = getattr(self.model, 'no_sync', noop_no_sync) + + # evaluation mode + with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): + + if sample_solver == 'unipc': + # Create a prototype scheduler to compute shared timesteps + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sample_scheduler.set_timesteps( + sampling_steps, device=self.device, shift=shift) + timesteps = sample_scheduler.timesteps + + # Instantiate per-sample schedulers so each sample maintains its own state + batch_size_for_schedulers = len(noises) + schedulers = [] + for _ in range(batch_size_for_schedulers): + s = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + s.set_timesteps(sampling_steps, device=self.device, shift=shift) + schedulers.append(s) + elif sample_solver == 'dpm++': + sample_scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + timesteps, _ = retrieve_timesteps( + sample_scheduler, + device=self.device, + sigmas=sampling_sigmas) + else: + raise NotImplementedError("Unsupported solver.") + + # sample videos + latents = noises + + from megatron.core.packed_seq_params import PackedSeqParams + cu_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)]) + cu_q = cu_q.to(torch.int32).to(self.device) + cu_kv_self = cu_q + cu_kv_cross = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(context_lens), dim=0)]) + cu_kv_cross = cu_kv_cross.to(torch.int32).to(self.device) + packed_seq_params = { + "self_attention": PackedSeqParams( + cu_seqlens_q=cu_q, + cu_seqlens_kv=cu_kv_self, + qkv_format=self.model.config.qkv_format, + ), + "cross_attention": PackedSeqParams( + cu_seqlens_q=cu_q, + cu_seqlens_kv=cu_kv_cross, + qkv_format=self.model.config.qkv_format, + ), + } + + + # context parallel + if parallel_state.get_context_parallel_world_size() > 1: + vace_context = thd_split_inputs_cp(vace_context, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) + contexts = thd_split_inputs_cp(contexts, packed_seq_params['cross_attention'].cu_seqlens_kv, parallel_state.get_context_parallel_group()) + contexts_null = thd_split_inputs_cp(contexts_null, packed_seq_params['cross_attention'].cu_seqlens_kv, parallel_state.get_context_parallel_group()) + + + arg_c = {'context': contexts, 'max_seq_len': max_video_seq_len, 'packed_seq_params': packed_seq_params} + arg_null = {'context': contexts_null, 'max_seq_len': max_video_seq_len, 'packed_seq_params': packed_seq_params} + + + from megatron.bridge.models.hf_pretrained.wan import PreTrainedVACE + hf = PreTrainedVACE("Wan-AI/Wan2.1-VACE-1.3B-Diffusers")._load_model().to(self.device) + + + for _, t in enumerate(tqdm(timesteps)): + + batch_size = len(latents) + + # patchify latents + unpatchified_latents = latents + latents = patchify(latents, self.patch_size) + # pad to have same length + for i in range(batch_size): + latents[i] = F.pad(latents[i], (0, 0, 0, max_video_seq_len - latents[i].shape[0])) + latents = torch.stack(latents, dim=1) + + s, b, h = latents.shape + latents = latents.transpose(0, 1).reshape(s*b, 1, h) + + # context parallel + if parallel_state.get_context_parallel_world_size() > 1: + latents = thd_split_inputs_cp(latents, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) + + + latent_model_input = latents + timestep = [t] * 1 + timestep = torch.stack(timestep) + + self.model.to(self.device) + noise_pred_cond = self.forward_pp_step( + latent_model_input, grid_sizes=grid_sizes, max_video_seq_len=max_video_seq_len, timestep=timestep, vace_context=vace_context, arg_c=arg_c) + + noise_pred_uncond = self.forward_pp_step( + latent_model_input, grid_sizes=grid_sizes, max_video_seq_len=max_video_seq_len, timestep=timestep, vace_context=vace_context, arg_c=arg_null) + + + # context parallel + if parallel_state.get_context_parallel_world_size() > 1: + noise_pred_cond = thd_cat_outputs_cp(noise_pred_cond, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) + noise_pred_uncond = thd_cat_outputs_cp(noise_pred_uncond, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) + + noise_pred_cond = noise_pred_cond.reshape(b, s, h).transpose(0, 1) + noise_pred_uncond = noise_pred_uncond.reshape(b, s, h).transpose(0, 1) + + # run unpatchify + unpatchified_noise_pred_cond = noise_pred_cond + unpatchified_noise_pred_cond = unpatchified_noise_pred_cond.transpose(0, 1) # bring sbhd -> bshd + # when unpatchifying, the code will truncate the padded videos into the original video shape, based on the grid_sizes. + unpatchified_noise_pred_cond = self.unpatchify(unpatchified_noise_pred_cond, grid_sizes, self.vae.model.z_dim) + unpatchified_noise_pred_uncond = noise_pred_uncond + unpatchified_noise_pred_uncond = unpatchified_noise_pred_uncond.transpose(0, 1) # bring sbhd -> bshd + # when unpatchifying, the code will truncate the padded videos into the original video shape, based on the grid_sizes. + unpatchified_noise_pred_uncond = self.unpatchify(unpatchified_noise_pred_uncond, grid_sizes, self.vae.model.z_dim) + + + # # for huggingface inference + # unpatchified_latents = torch.stack(latents) + # timestep = [t] * batch_size + # timestep = torch.stack(timestep) + # unpatchified_noise_pred_cond=hf(hidden_states=unpatchified_latents, + # timestep=timestep, + # encoder_hidden_states=contexts.transpose(0,1), + # control_hidden_states=vace_context_hf, + # return_dict=False)[0] + # unpatchified_noise_pred_uncond=hf(hidden_states=unpatchified_latents, + # timestep=timestep, + # encoder_hidden_states=contexts_null.transpose(0,1), + # control_hidden_states=vace_context_hf, + # return_dict=False)[0] + + + noise_preds = [] + for i in range(batch_size): + noise_pred = unpatchified_noise_pred_uncond[i] + guide_scale * ( + unpatchified_noise_pred_cond[i] - unpatchified_noise_pred_uncond[i]) + noise_preds.append(noise_pred) + + # step and update latents + latents = [] + for i in range(batch_size): + + if sample_solver == 'unipc': + temp_x0 = schedulers[i].step( + noise_preds[i].unsqueeze(0), + t, + unpatchified_latents[i].unsqueeze(0), + return_dict=False, + generator=seed_g)[0] + else: + temp_x0 = sample_scheduler.step( + noise_preds[i].unsqueeze(0), + t, + unpatchified_latents[i].unsqueeze(0), + return_dict=False, + generator=seed_g)[0] + latents.append(temp_x0.squeeze(0)) + + x0 = latents + if offload_model: + self.model.cpu() + torch.cuda.empty_cache() + if self.rank == 0: + videos = self.decode_latent(x0, input_ref_images) + else: + videos = None + + del noises, latents + if sample_solver == 'unipc': + del schedulers + else: + del sample_scheduler + if offload_model: + gc.collect() + torch.cuda.synchronize() + if dist.is_initialized(): + dist.barrier() + + return videos if self.rank == 0 else None diff --git a/src/megatron/bridge/models/wan/flow_matching/flow_pipeline.py b/src/megatron/bridge/models/wan/flow_matching/flow_pipeline.py new file mode 100644 index 0000000000..f6b80c1f19 --- /dev/null +++ b/src/megatron/bridge/models/wan/flow_matching/flow_pipeline.py @@ -0,0 +1,223 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Dict, Optional, Tuple, List + +import numpy as np +import torch +from megatron.core import parallel_state +from torch import Tensor +from diffusers import WanPipeline +from megatron.bridge.models.wan.flow_matching.time_shift_utils import compute_density_for_timestep_sampling +from megatron.bridge.models.wan.utils.utils import patchify, thd_split_inputs_cp + +class FlowPipeline: + + def __init__( + self, + model_id="Wan-AI/Wan2.1-T2V-1.3B-Diffusers", + seed=1234, + ): + """ + Initializes the FlowPipeline with the given parameters. + """ + self.pipe = WanPipeline.from_pretrained(model_id, vae=None, torch_dtype=torch.float32, text_encoder=None) + + + def training_step( + self, + model, + data_batch: dict[str, torch.Tensor], + # Flow matching parameters + use_sigma_noise: bool = True, + timestep_sampling: str = "uniform", + logit_mean: float = 0.0, + logit_std: float = 1.0, + flow_shift: float = 3.0, + mix_uniform_ratio: float = 0.1, + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + """ + Performs a single training step using flow matching algorithm. + + This method is responsible for executing one iteration of the model's training. It involves: + 1. Generate noise and add it to the input data. + 2. Pass the noisy data through the network to generate predictions. + 3. Compute the loss based on the difference between the predictions and target. + """ + + video_latents = data_batch['video_latents'] + max_video_seq_len = data_batch['max_video_seq_len'] + context_embeddings = data_batch['context_embeddings'] + loss_mask = data_batch['loss_mask'] + grid_sizes = data_batch['grid_sizes'] + packed_seq_params = data_batch['packed_seq_params'] + video_metadata = data_batch['video_metadata'] + + self.model = model + + batch_size = video_latents.shape[1] + device = video_latents.device + + # # # DEBUGGING precision + # # import torch.cuda.amp as amp + # # with amp.autocast(dtype=torch.bfloat16): + # # # Pass through model + # # ... + + # ======================================================================== + # Flow Matching Timestep Sampling + # ======================================================================== + + num_train_timesteps = self.pipe.scheduler.config.num_train_timesteps + + if use_sigma_noise: + use_uniform = torch.rand(1).item() < mix_uniform_ratio + + if use_uniform or timestep_sampling == "uniform": + # Pure uniform: u ~ U(0, 1) + u = torch.rand(size=(batch_size,), device=device) + sampling_method = "uniform" + else: + # Density-based sampling + u = compute_density_for_timestep_sampling( + weighting_scheme=timestep_sampling, + batch_size=batch_size, + logit_mean=logit_mean, + logit_std=logit_std, + ).to(device) + sampling_method = timestep_sampling + + # Apply flow shift: σ = shift/(shift + (1/u - 1)) + u_clamped = torch.clamp(u, min=1e-5) # Avoid division by zero + sigma = flow_shift / (flow_shift + (1.0 / u_clamped - 1.0)) + sigma = torch.clamp(sigma, 0.0, 1.0) + + else: + # Simple uniform without shift + u = torch.rand(size=(batch_size,), device=device) + sigma = u + sampling_method = "uniform_no_shift" + + # ======================================================================== + # Manual Flow Matching Noise Addition + # ======================================================================== + + # Generate noise + noise = torch.randn_like(torch.ones([1, 16, grid_sizes[0][0], grid_sizes[0][1]*2, grid_sizes[0][2]*2], device=video_latents.device), dtype=torch.float32) + noise = patchify(noise, (1, 2, 2))[0].unsqueeze(1) + # DEBUGGING + # because video_latents might be padded, we need to make sure noise also be padded to have the same shape + seq_noise = noise.shape[0] + seq_video = video_latents.shape[0] + if seq_noise < seq_video: + pad_len = seq_video - seq_noise + pad = torch.zeros((pad_len, noise.shape[1], noise.shape[2]), device=noise.device, dtype=noise.dtype) + noise = torch.cat([noise, pad], dim=0) + + # CRITICAL: Manual flow matching (NOT scheduler.add_noise!) + # x_t = (1 - σ) * x_0 + σ * ε + sigma_reshaped = sigma.view(1, batch_size, 1) + noisy_latents = ( + (1.0 - sigma_reshaped) * video_latents.float() + + sigma_reshaped * noise + ) + + # Timesteps for model [0, 1000] + timesteps = sigma * num_train_timesteps + + # ======================================================================== + # Cast model inputs to bf16 + # ======================================================================== + + video_latents = video_latents.to(torch.bfloat16) + noisy_latents = noisy_latents.to(torch.bfloat16) + context_embeddings = context_embeddings.to(torch.bfloat16) + timesteps = timesteps.to(torch.bfloat16) + + # ======================================================================== + # Split accross context parallelism + # ======================================================================== + + if parallel_state.get_context_parallel_world_size() > 1: + video_latents = thd_split_inputs_cp(video_latents, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) + noisy_latents = thd_split_inputs_cp(noisy_latents, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) + noise = thd_split_inputs_cp(noise, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) + context_embeddings = thd_split_inputs_cp(context_embeddings, packed_seq_params['cross_attention'].cu_seqlens_kv, parallel_state.get_context_parallel_group()) + split_loss_mask = thd_split_inputs_cp(loss_mask, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) + else: + video_latents = video_latents + noisy_latents = noisy_latents + noise = noise + context_embeddings = context_embeddings + split_loss_mask = loss_mask + + + # ======================================================================== + # Forward Pass + # ======================================================================== + + if parallel_state.is_pipeline_last_stage(): + + model_pred = self.model( + x = noisy_latents, + grid_sizes = grid_sizes, + t = timesteps, + context = context_embeddings, + max_seq_len = max_video_seq_len, + packed_seq_params=packed_seq_params, + ) + + # ======================================================================== + # Target: Flow Matching Velocity + # ======================================================================== + + # Flow matching target: v = ε - x_0 + target = noise - video_latents.float() + + # ======================================================================== + # Loss with Flow Weighting + # ======================================================================== + + loss = torch.nn.functional.mse_loss( + model_pred.float(), + target.float(), + reduction="none" + ) + + # Flow weight: w = 1 + shift * σ + loss_weight = 1.0 + flow_shift * sigma # shape [batch_size] + loss_weight = loss_weight.view(1, batch_size, 1).to(device) # shape [1, batch_size, 1] + unweighted_loss = loss + weighted_loss = (loss * loss_weight) # shape [seq_length / cp_size, batch_size, -1] + + # Safety check + mean_weighted_loss = weighted_loss.mean() + if torch.isnan(mean_weighted_loss) or mean_weighted_loss > 100: + print(f"[ERROR] Loss explosion! Loss={mean_weighted_loss.item():.3f}") + print(f"[DEBUG] Stopping training - check hyperparameters") + raise ValueError(f"Loss exploded: {mean_weighted_loss.item()}") + + return model_pred, weighted_loss, split_loss_mask + + else: + hidden_states = self.model( + x = noisy_latents, + grid_sizes = grid_sizes, + t = timesteps, + context = context_embeddings, + max_seq_len = max_video_seq_len, + packed_seq_params=packed_seq_params, + ) + + return hidden_states \ No newline at end of file diff --git a/src/megatron/bridge/models/wan/flow_matching/time_shift_utils.py b/src/megatron/bridge/models/wan/flow_matching/time_shift_utils.py new file mode 100644 index 0000000000..56faee4b20 --- /dev/null +++ b/src/megatron/bridge/models/wan/flow_matching/time_shift_utils.py @@ -0,0 +1,108 @@ +# time_shift_utils.py - Timestep sampling and sigma computation utilities + +import math +import numpy as np +import torch + + +def time_shift( + t: torch.Tensor, + image_seq_len: int, + shift_type: str = "constant", + base_shift: float = 0.5, + max_shift: float = 1.15, + constant: float = 3.0, +): + """ + Convert timesteps to sigmas with sequence-length-aware shifting. + + Args: + t: timesteps in range [0, 1] + image_seq_len: number of tokens (frames * height * width / patch_size^2) + shift_type: "linear", "sqrt", or "constant" + base_shift: base shift for linear mode + max_shift: max shift for linear mode + constant: shift value for constant mode (default 3.0 matches Pika) + + Returns: + sigma values for noise scheduling + """ + if shift_type == "linear": + # Linear interpolation based on sequence length + mu = base_shift + (max_shift - base_shift) * (image_seq_len / 4096) + return math.exp(mu) / (math.exp(mu) + (1 / t - 1)) + + elif shift_type == "sqrt": + # Square root scaling (Flux-style) + # Assuming 128x128 latent space (1024x1024 image) gives mu=3 + mu = np.maximum(1.0, np.sqrt(image_seq_len / (128.0 * 128.0)) * 3.0) + return mu / (mu + (1 / t - 1)) + + elif shift_type == "constant": + # Constant shift (Pika default) + return constant / (constant + (1 / t - 1)) + + else: + # No shift, return original t + return t + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, + batch_size: int, + logit_mean: float = 0.0, + logit_std: float = 1.0, + mode_scale: float = 1.29, +): + """ + Sample timesteps from different distributions for better training coverage. + + Args: + weighting_scheme: "uniform", "logit_normal", or "mode" + batch_size: number of samples to generate + logit_mean: mean for logit-normal distribution + logit_std: std for logit-normal distribution + mode_scale: scale for mode-based sampling + + Returns: + Tensor of shape (batch_size,) with values in [0, 1] + """ + if weighting_scheme == "logit_normal": + # SD3-style logit-normal sampling + u = torch.normal( + mean=logit_mean, + std=logit_std, + size=(batch_size,), + device="cpu" + ) + u = torch.nn.functional.sigmoid(u) + + elif weighting_scheme == "mode": + # Mode-based sampling (concentrates around certain timesteps) + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + + else: + # Uniform sampling (default) + u = torch.rand(size=(batch_size,), device="cpu") + + return u + + +def get_flow_match_loss_weight(sigma: torch.Tensor, shift: float = 3.0): + """ + Compute loss weights for flow matching based on sigma values. + + Higher sigma (more noise) typically gets higher weight. + + Args: + sigma: sigma values in range [0, 1] + shift: weight scaling factor + + Returns: + Loss weights with same shape as sigma + """ + # Flow matching weight: weight = 1 + shift * sigma + # This gives more weight to noisier timesteps + weight = 1.0 + shift * sigma + return weight \ No newline at end of file diff --git a/src/megatron/bridge/models/wan/inference/configs/__init__.py b/src/megatron/bridge/models/wan/inference/configs/__init__.py new file mode 100644 index 0000000000..a28c03c5fd --- /dev/null +++ b/src/megatron/bridge/models/wan/inference/configs/__init__.py @@ -0,0 +1,52 @@ +import copy +import os + +os.environ['TOKENIZERS_PARALLELISM'] = 'false' + +from .wan_i2v_14B import i2v_14B +from .wan_t2v_1_3B import t2v_1_3B +from .wan_t2v_14B import t2v_14B + +# the config of t2i_14B is the same as t2v_14B +t2i_14B = copy.deepcopy(t2v_14B) +t2i_14B.__name__ = 'Config: Wan T2I 14B' + +# the config of flf2v_14B is the same as i2v_14B +flf2v_14B = copy.deepcopy(i2v_14B) +flf2v_14B.__name__ = 'Config: Wan FLF2V 14B' +flf2v_14B.sample_neg_prompt = "镜头切换," + flf2v_14B.sample_neg_prompt + +WAN_CONFIGS = { + 't2v-14B': t2v_14B, + 't2v-1.3B': t2v_1_3B, + 'i2v-14B': i2v_14B, + 't2i-14B': t2i_14B, + 'flf2v-14B': flf2v_14B, + 'vace-1.3B': t2v_1_3B, + 'vace-14B': t2v_14B, +} + +SIZE_CONFIGS = { + '720*1280': (720, 1280), + '1280*720': (1280, 720), + '480*832': (480, 832), + '832*480': (832, 480), + '1024*1024': (1024, 1024), +} + +MAX_AREA_CONFIGS = { + '720*1280': 720 * 1280, + '1280*720': 1280 * 720, + '480*832': 480 * 832, + '832*480': 832 * 480, +} + +SUPPORTED_SIZES = { + 't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), + 't2v-1.3B': ('480*832', '832*480'), + 'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), + 'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), + 't2i-14B': tuple(SIZE_CONFIGS.keys()), + 'vace-1.3B': ('480*832', '832*480'), + 'vace-14B': ('720*1280', '1280*720', '480*832', '832*480') +} diff --git a/src/megatron/bridge/models/wan/inference/configs/shared_config.py b/src/megatron/bridge/models/wan/inference/configs/shared_config.py new file mode 100644 index 0000000000..37d3ae0c43 --- /dev/null +++ b/src/megatron/bridge/models/wan/inference/configs/shared_config.py @@ -0,0 +1,18 @@ +import torch +from easydict import EasyDict + +#------------------------ Wan shared config ------------------------# +wan_shared_cfg = EasyDict() + +# t5 +wan_shared_cfg.t5_model = 'umt5_xxl' +wan_shared_cfg.t5_dtype = torch.bfloat16 +wan_shared_cfg.text_len = 512 + +# transformer +wan_shared_cfg.param_dtype = torch.bfloat16 + +# inference +wan_shared_cfg.num_train_timesteps = 1000 +wan_shared_cfg.sample_fps = 16 +wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走' diff --git a/src/megatron/bridge/models/wan/inference/configs/wan_i2v_14B.py b/src/megatron/bridge/models/wan/inference/configs/wan_i2v_14B.py new file mode 100644 index 0000000000..764d2ed8c3 --- /dev/null +++ b/src/megatron/bridge/models/wan/inference/configs/wan_i2v_14B.py @@ -0,0 +1,35 @@ +import torch +from easydict import EasyDict + +from .shared_config import wan_shared_cfg + +#------------------------ Wan I2V 14B ------------------------# + +i2v_14B = EasyDict(__name__='Config: Wan I2V 14B') +i2v_14B.update(wan_shared_cfg) +i2v_14B.sample_neg_prompt = "镜头晃动," + i2v_14B.sample_neg_prompt + +i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' +i2v_14B.t5_tokenizer = 'google/umt5-xxl' + +# clip +i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14' +i2v_14B.clip_dtype = torch.float16 +i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth' +i2v_14B.clip_tokenizer = 'xlm-roberta-large' + +# vae +i2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth' +i2v_14B.vae_stride = (4, 8, 8) + +# transformer +i2v_14B.patch_size = (1, 2, 2) +i2v_14B.dim = 5120 +i2v_14B.ffn_dim = 13824 +i2v_14B.freq_dim = 256 +i2v_14B.num_heads = 40 +i2v_14B.num_layers = 40 +i2v_14B.window_size = (-1, -1) +i2v_14B.qk_norm = True +i2v_14B.cross_attn_norm = True +i2v_14B.eps = 1e-6 diff --git a/src/megatron/bridge/models/wan/inference/configs/wan_t2v_14B.py b/src/megatron/bridge/models/wan/inference/configs/wan_t2v_14B.py new file mode 100644 index 0000000000..c793f7f6c3 --- /dev/null +++ b/src/megatron/bridge/models/wan/inference/configs/wan_t2v_14B.py @@ -0,0 +1,28 @@ +from easydict import EasyDict + +from .shared_config import wan_shared_cfg + +#------------------------ Wan T2V 14B ------------------------# + +t2v_14B = EasyDict(__name__='Config: Wan T2V 14B') +t2v_14B.update(wan_shared_cfg) + +# t5 +t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' +t2v_14B.t5_tokenizer = 'google/umt5-xxl' + +# vae +t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth' +t2v_14B.vae_stride = (4, 8, 8) + +# transformer +t2v_14B.patch_size = (1, 2, 2) +t2v_14B.dim = 5120 +t2v_14B.ffn_dim = 13824 +t2v_14B.freq_dim = 256 +t2v_14B.num_heads = 40 +t2v_14B.num_layers = 40 +t2v_14B.window_size = (-1, -1) +t2v_14B.qk_norm = True +t2v_14B.cross_attn_norm = True +t2v_14B.eps = 1e-6 diff --git a/src/megatron/bridge/models/wan/inference/configs/wan_t2v_1_3B.py b/src/megatron/bridge/models/wan/inference/configs/wan_t2v_1_3B.py new file mode 100644 index 0000000000..c8458ce804 --- /dev/null +++ b/src/megatron/bridge/models/wan/inference/configs/wan_t2v_1_3B.py @@ -0,0 +1,28 @@ +from easydict import EasyDict + +from .shared_config import wan_shared_cfg + +#------------------------ Wan T2V 1.3B ------------------------# + +t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B') +t2v_1_3B.update(wan_shared_cfg) + +# t5 +t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' +t2v_1_3B.t5_tokenizer = 'google/umt5-xxl' + +# vae +t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth' +t2v_1_3B.vae_stride = (4, 8, 8) + +# transformer +t2v_1_3B.patch_size = (1, 2, 2) +t2v_1_3B.dim = 1536 +t2v_1_3B.ffn_dim = 8960 +t2v_1_3B.freq_dim = 256 +t2v_1_3B.num_heads = 12 +t2v_1_3B.num_layers = 30 +t2v_1_3B.window_size = (-1, -1) +t2v_1_3B.qk_norm = True +t2v_1_3B.cross_attn_norm = True +t2v_1_3B.eps = 1e-6 diff --git a/src/megatron/bridge/models/wan/inference/utils/fm_solvers.py b/src/megatron/bridge/models/wan/inference/utils/fm_solvers.py new file mode 100644 index 0000000000..a38b755c40 --- /dev/null +++ b/src/megatron/bridge/models/wan/inference/utils/fm_solvers.py @@ -0,0 +1,858 @@ +# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +# Convert dpm solver for flow matching + +import inspect +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import ( + KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput, +) +from diffusers.utils import deprecate, is_scipy_available +from diffusers.utils.torch_utils import randn_tensor + +if is_scipy_available(): + pass + + +def get_sampling_sigmas(sampling_steps, shift): + sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps] + sigma = (shift * sigma / (1 + (shift - 1) * sigma)) + + return sigma + + +def retrieve_timesteps( + scheduler, + num_inference_steps=None, + device=None, + timesteps=None, + sigmas=None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `FlowDPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs. + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. This determines the resolution of the diffusion process. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1`, `2`, or `3`. It is recommended to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. This affects the number of model outputs stored + and used in multistep updates. + prediction_type (`str`, defaults to "flow_prediction"): + Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts + the flow of the diffusion process. + shift (`float`, *optional*, defaults to 1.0): + A factor used to adjust the sigmas in the noise schedule. It modifies the step sizes during the sampling + process. + use_dynamic_shifting (`bool`, defaults to `False`): + Whether to apply dynamic shifting to the timesteps based on image resolution. If `True`, the shifting is + applied on the fly. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This method adjusts the predicted sample to prevent + saturation and improve photorealism. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++"`. + algorithm_type (`str`, defaults to `dpmsolver++`): + Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The + `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) + paper, and the `dpmsolver++` type implements the algorithms in the + [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or + `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. + final_sigmas_type (`str`, *optional*, defaults to "zero"): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + lambda_min_clipped (`float`, defaults to `-inf`): + Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the + cosine (`squaredcos_cap_v2`) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output + contains the predicted Gaussian variance. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: Optional[float] = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + invert_sigmas: bool = False, + ): + if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" + deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", + deprecation_message) + + # settings for DPM-Solver + if algorithm_type not in [ + "dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++" + ]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError( + f"{algorithm_type} is not implemented for {self.__class__}") + + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError( + f"{solver_type} is not implemented for {self.__class__}") + + if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++" + ] and final_sigmas_type == "zero": + raise ValueError( + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead." + ) + + # setable values + self.num_inference_steps = None + alphas = np.linspace(1, 1 / num_train_timesteps, + num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self._begin_index = None + + # self.sigmas = self.sigmas.to( + # "cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps + def set_timesteps( + self, + num_inference_steps: Union[int, None] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[Union[float, None]] = None, + shift: Optional[Union[float, None]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Args: + num_inference_steps (`int`): + Total number of the spacing of the time steps. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError( + " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" + ) + + if sigmas is None: + sigmas = np.linspace(self.sigma_max, self.sigma_min, + num_inference_steps + + 1).copy()[:-1] # pyright: ignore + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore + else: + if shift is None: + shift = self.config.shift + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / + self.alphas_cumprod[0])**0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last] + ]).astype(np.float32) # pyright: ignore + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to( + device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + self._step_index = None + self._begin_index = None + # self.sigmas = self.sigmas.to( + # "cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float( + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile( + abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze( + 1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp( + sample, -s, s + ) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma) + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + "missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction`, or `flow_prediction` for the FlowDPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the FlowDPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + + return epsilon + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update + def dpm_solver_first_order_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the first-order DPMSolver (equivalent to DDIM). + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError( + " missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[ + self.step_index] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / + sigma_s) * sample - (alpha_t * + (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / + alpha_s) * sample - (sigma_t * + (torch.exp(h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ((sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ((alpha_t / alpha_s) * sample - 2.0 * + (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise) + return x_t # pyright: ignore + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the second-order multistep DPMSolver. + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop( + "timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError( + " missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], # pyright: ignore + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], # pyright: ignore + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ((sigma_t / sigma_s0) * sample - + (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 * + (alpha_t * (torch.exp(-h) - 1.0)) * D1) + elif self.config.solver_type == "heun": + x_t = ((sigma_t / sigma_s0) * sample - + (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ((alpha_t / alpha_s0) * sample - + (sigma_t * (torch.exp(h) - 1.0)) * D0 - 0.5 * + (sigma_t * (torch.exp(h) - 1.0)) * D1) + elif self.config.solver_type == "heun": + x_t = ((alpha_t / alpha_s0) * sample - + (sigma_t * (torch.exp(h) - 1.0)) * D0 - + (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + 0.5 * + (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) + elif self.config.solver_type == "heun": + x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / + (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ((alpha_t / alpha_s0) * sample - 2.0 * + (sigma_t * (torch.exp(h) - 1.0)) * D0 - + (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise) + elif self.config.solver_type == "heun": + x_t = ((alpha_t / alpha_s0) * sample - 2.0 * + (sigma_t * (torch.exp(h) - 1.0)) * D0 - 2.0 * + (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise) + return x_t # pyright: ignore + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update + def multistep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the third-order multistep DPMSolver. + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by diffusion process. + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + + timestep_list = args[0] if len(args) > 0 else kwargs.pop( + "timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError( + " missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], # pyright: ignore + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], # pyright: ignore + self.sigmas[self.step_index - 2], # pyright: ignore + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[ + -2], model_output_list[-3] + + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ((sigma_t / sigma_s0) * sample - + (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - + (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ((alpha_t / alpha_s0) * sample - (sigma_t * + (torch.exp(h) - 1.0)) * D0 - + (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - + (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2) + return x_t # pyright: ignore + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + # Modified from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + generator=None, + variance_noise: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep DPMSolver. + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.Tensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`LEdits++`]. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Improve numerical stability for small number of steps + lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( + self.config.euler_at_final or + (self.config.lower_order_final and len(self.timesteps) < 15) or + self.config.final_sigmas_type == "zero") + lower_order_second = ((self.step_index == len(self.timesteps) - 2) and + self.config.lower_order_final and + len(self.timesteps) < 15) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++" + ] and variance_noise is None: + noise = randn_tensor( + model_output.shape, + generator=generator, + device=model_output.device, + dtype=torch.float32) + elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + noise = variance_noise.to( + device=model_output.device, + dtype=torch.float32) # pyright: ignore + else: + noise = None + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update( + model_output, sample=sample, noise=noise) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + prev_sample = self.multistep_dpm_solver_second_order_update( + self.model_outputs, sample=sample, noise=noise) + else: + prev_sample = self.multistep_dpm_solver_third_order_update( + self.model_outputs, sample=sample) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # Cast sample back to expected dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 # pyright: ignore + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input + def scale_model_input(self, sample: torch.Tensor, *args, + **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + Args: + sample (`torch.Tensor`): + The input sample. + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to( + device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point( + timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to( + original_samples.device, dtype=torch.float32) + timesteps = timesteps.to( + original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) + for t in timesteps + ] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/megatron/bridge/models/wan/inference/utils/fm_solvers_unipc.py b/src/megatron/bridge/models/wan/inference/utils/fm_solvers_unipc.py new file mode 100644 index 0000000000..8d96058394 --- /dev/null +++ b/src/megatron/bridge/models/wan/inference/utils/fm_solvers_unipc.py @@ -0,0 +1,801 @@ +# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py +# Convert unipc for flow matching + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import ( + KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput, +) +from diffusers.utils import deprecate, is_scipy_available + +if is_scipy_available(): + import scipy.stats + + +class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + solver_order (`int`, default `2`): + The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1` + due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for + unconditional sampling. + prediction_type (`str`, defaults to "flow_prediction"): + Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts + the flow of the diffusion process. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. + predict_x0 (`bool`, defaults to `True`): + Whether to use the updating algorithm on the predicted x0. + solver_type (`str`, default `bh2`): + Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2` + otherwise. + lower_order_final (`bool`, default `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + disable_corrector (`list`, default `[]`): + Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)` + and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is + usually disabled during the first few steps. + solver_p (`SchedulerMixin`, default `None`): + Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: Optional[float] = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: List[int] = [], + solver_p: SchedulerMixin = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + ): + + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + self.register_to_config(solver_type="bh2") + else: + raise NotImplementedError( + f"{solver_type} is not implemented for {self.__class__}") + + self.predict_x0 = predict_x0 + # setable values + self.num_inference_steps = None + alphas = np.linspace(1, 1 / num_train_timesteps, + num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector + self.solver_p = solver_p + self.last_sample = None + self._step_index = None + self._begin_index = None + + self.sigmas = self.sigmas.to( + "cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps + def set_timesteps( + self, + num_inference_steps: Union[int, None] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[Union[float, None]] = None, + shift: Optional[Union[float, None]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Args: + num_inference_steps (`int`): + Total number of the spacing of the time steps. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError( + " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" + ) + + if sigmas is None: + sigmas = np.linspace(self.sigma_max, self.sigma_min, + num_inference_steps + + 1).copy()[:-1] # pyright: ignore + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore + else: + if shift is None: + shift = self.config.shift + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / + self.alphas_cumprod[0])**0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last] + ]).astype(np.float32) # pyright: ignore + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to( + device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + self.last_sample = None + if self.solver_p: + self.solver_p.set_timesteps(self.num_inference_steps, device=device) + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to( + "cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float( + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile( + abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze( + 1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp( + sample, -s, s + ) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma) + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + r""" + Convert the model output to the corresponding type the UniPC algorithm needs. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + "missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.predict_x0: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + + return epsilon + + def multistep_uni_p_bh_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + order: int = None, # pyright: ignore + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model at the current timestep. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + order (`int`): + The order of UniP at this timestep (corresponds to the *p* in UniPC-p). + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + prev_timestep = args[0] if len(args) > 0 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + " missing `sample` as a required keyward argument") + if order is None: + if len(args) > 2: + order = args[2] + else: + raise ValueError( + " missing `order` as a required keyward argument") + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + model_output_list = self.model_outputs + + s0 = self.timestep_list[-1] + m0 = model_output_list[-1] + x = sample + + if self.solver_p: + x_t = self.solver_p.step(model_output, s0, x).prev_sample + return x_t + + sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[ + self.step_index] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - i # pyright: ignore + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) # pyright: ignore + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], + b[:-1]).to(device).to(x.dtype) + else: + D1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, + D1s) # pyright: ignore + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, + D1s) # pyright: ignore + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + x_t = x_t.to(x.dtype) + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.Tensor, + *args, + last_sample: torch.Tensor = None, + this_sample: torch.Tensor = None, + order: int = None, # pyright: ignore + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniC (B(h) version). + + Args: + this_model_output (`torch.Tensor`): + The model outputs at `x_t`. + this_timestep (`int`): + The current timestep `t`. + last_sample (`torch.Tensor`): + The generated sample before the last predictor `x_{t-1}`. + this_sample (`torch.Tensor`): + The generated sample after the last predictor `x_{t}`. + order (`int`): + The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`. + + Returns: + `torch.Tensor`: + The corrected sample tensor at the current timestep. + """ + this_timestep = args[0] if len(args) > 0 else kwargs.pop( + "this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError( + " missing`last_sample` as a required keyward argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError( + " missing`this_sample` as a required keyward argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError( + " missing`order` as a required keyward argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + model_output_list = self.model_outputs + + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[ + self.step_index - 1] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = this_sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - (i + 1) # pyright: ignore + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) # pyright: ignore + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) + else: + D1s = None + + # for order 1, we use a simplified version + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + x_t = x_t.to(x.dtype) + return x_t + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step(self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + return_dict: bool = True, + generator=None) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep UniPC. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + use_corrector = ( + self.step_index > 0 and + self.step_index - 1 not in self.disable_corrector and + self.last_sample is not None # pyright: ignore + ) + + model_output_convert = self.convert_model_output( + model_output, sample=sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep # pyright: ignore + + if self.config.lower_order_final: + this_order = min(self.config.solver_order, + len(self.timesteps) - + self.step_index) # pyright: ignore + else: + this_order = self.config.solver_order + + self.this_order = min(this_order, + self.lower_order_nums + 1) # warmup for multistep + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, # pass the original non-converted model output, in case solver-p is used + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 # pyright: ignore + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, + **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to( + device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point( + timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to( + original_samples.device, dtype=torch.float32) + timesteps = timesteps.to( + original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) + for t in timesteps + ] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/megatron/bridge/models/wan/inference/utils/utils.py b/src/megatron/bridge/models/wan/inference/utils/utils.py new file mode 100644 index 0000000000..a57f9bb993 --- /dev/null +++ b/src/megatron/bridge/models/wan/inference/utils/utils.py @@ -0,0 +1,117 @@ +import argparse +import binascii +import os +import os.path as osp + +import imageio +import torch +import torchvision + +__all__ = ['cache_video', 'cache_image', 'str2bool'] + + +def rand_name(length=8, suffix=''): + name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') + if suffix: + if not suffix.startswith('.'): + suffix = '.' + suffix + name += suffix + return name + + +def cache_video(tensor, + save_file=None, + fps=30, + suffix='.mp4', + nrow=8, + normalize=True, + value_range=(-1, 1), + retry=5): + # cache file + cache_file = osp.join('/tmp', rand_name( + suffix=suffix)) if save_file is None else save_file + + # save to cache + error = None + for _ in range(retry): + try: + # preprocess + tensor = tensor.clamp(min(value_range), max(value_range)) + tensor = torch.stack([ + torchvision.utils.make_grid( + u, nrow=nrow, normalize=normalize, value_range=value_range) + for u in tensor.unbind(2) + ], + dim=1).permute(1, 2, 3, 0) + tensor = (tensor * 255).type(torch.uint8).cpu() + + # write video + writer = imageio.get_writer( + cache_file, fps=fps, codec='libx264', quality=8) + for frame in tensor.numpy(): + writer.append_data(frame) + writer.close() + return cache_file + except Exception as e: + error = e + continue + else: + print(f'cache_video failed, error: {error}', flush=True) + return None + + +def cache_image(tensor, + save_file, + nrow=8, + normalize=True, + value_range=(-1, 1), + retry=5): + # cache file + suffix = osp.splitext(save_file)[1] + if suffix.lower() not in [ + '.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp' + ]: + suffix = '.png' + + # save to cache + error = None + for _ in range(retry): + try: + tensor = tensor.clamp(min(value_range), max(value_range)) + torchvision.utils.save_image( + tensor, + save_file, + nrow=nrow, + normalize=normalize, + value_range=value_range) + return save_file + except Exception as e: + error = e + continue + + +def str2bool(v): + """ + Convert a string to a boolean. + + Supported true values: 'yes', 'true', 't', 'y', '1' + Supported false values: 'no', 'false', 'f', 'n', '0' + + Args: + v (str): String to convert. + + Returns: + bool: Converted boolean value. + + Raises: + argparse.ArgumentTypeError: If the value cannot be converted to boolean. + """ + if isinstance(v, bool): + return v + v_lower = v.lower() + if v_lower in ('yes', 'true', 't', 'y', '1'): + return True + elif v_lower in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected (True/False)') diff --git a/src/megatron/bridge/models/wan/modules/__init__.py b/src/megatron/bridge/models/wan/modules/__init__.py new file mode 100644 index 0000000000..435f1eef0d --- /dev/null +++ b/src/megatron/bridge/models/wan/modules/__init__.py @@ -0,0 +1,13 @@ +from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model +from .tokenizers import HuggingfaceTokenizer +from .vae import WanVAE + + +__all__ = [ + 'WanVAE', + 'T5Model', + 'T5Encoder', + 'T5Decoder', + 'T5EncoderModel', + 'HuggingfaceTokenizer', +] diff --git a/src/megatron/bridge/models/wan/modules/t5.py b/src/megatron/bridge/models/wan/modules/t5.py new file mode 100644 index 0000000000..fecd989e07 --- /dev/null +++ b/src/megatron/bridge/models/wan/modules/t5.py @@ -0,0 +1,512 @@ +# Modified from transformers.models.t5.modeling_t5 +import logging +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .tokenizers import HuggingfaceTokenizer + +__all__ = [ + 'T5Model', + 'T5Encoder', + 'T5Decoder', + 'T5EncoderModel', +] + + +def fp16_clamp(x): + if x.dtype == torch.float16 and torch.isinf(x).any(): + clamp = torch.finfo(x.dtype).max - 1000 + x = torch.clamp(x, min=-clamp, max=clamp) + return x + + +def init_weights(m): + if isinstance(m, T5LayerNorm): + nn.init.ones_(m.weight) + elif isinstance(m, T5Model): + nn.init.normal_(m.token_embedding.weight, std=1.0) + elif isinstance(m, T5FeedForward): + nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5) + nn.init.normal_(m.fc1.weight, std=m.dim**-0.5) + nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5) + elif isinstance(m, T5Attention): + nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5) + nn.init.normal_(m.k.weight, std=m.dim**-0.5) + nn.init.normal_(m.v.weight, std=m.dim**-0.5) + nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5) + elif isinstance(m, T5RelativeEmbedding): + nn.init.normal_( + m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5) + + +class GELU(nn.Module): + + def forward(self, x): + return 0.5 * x * (1.0 + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + + +class T5LayerNorm(nn.Module): + + def __init__(self, dim, eps=1e-6): + super(T5LayerNorm, self).__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + + self.eps) + if self.weight.dtype in [torch.float16, torch.bfloat16]: + x = x.type_as(self.weight) + return self.weight * x + + +class T5Attention(nn.Module): + + def __init__(self, dim, dim_attn, num_heads, dropout=0.1): + assert dim_attn % num_heads == 0 + super(T5Attention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.num_heads = num_heads + self.head_dim = dim_attn // num_heads + + # layers + self.q = nn.Linear(dim, dim_attn, bias=False) + self.k = nn.Linear(dim, dim_attn, bias=False) + self.v = nn.Linear(dim, dim_attn, bias=False) + self.o = nn.Linear(dim_attn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, context=None, mask=None, pos_bias=None): + """ + x: [B, L1, C]. + context: [B, L2, C] or None. + mask: [B, L2] or [B, L1, L2] or None. + """ + # check inputs + context = x if context is None else context + b, n, c = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).view(b, -1, n, c) + k = self.k(context).view(b, -1, n, c) + v = self.v(context).view(b, -1, n, c) + + # attention bias + attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) + if pos_bias is not None: + attn_bias += pos_bias + if mask is not None: + assert mask.ndim in [2, 3] + mask = mask.view(b, 1, 1, + -1) if mask.ndim == 2 else mask.unsqueeze(1) + attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min) + + # compute attention (T5 does not use scaling) + attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + x = torch.einsum('bnij,bjnc->binc', attn, v) + + # output + x = x.reshape(b, -1, n * c) + x = self.o(x) + x = self.dropout(x) + return x + + +class T5FeedForward(nn.Module): + + def __init__(self, dim, dim_ffn, dropout=0.1): + super(T5FeedForward, self).__init__() + self.dim = dim + self.dim_ffn = dim_ffn + + # layers + self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU()) + self.fc1 = nn.Linear(dim, dim_ffn, bias=False) + self.fc2 = nn.Linear(dim_ffn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = self.fc1(x) * self.gate(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class T5SelfAttention(nn.Module): + + def __init__(self, + dim, + dim_attn, + dim_ffn, + num_heads, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5SelfAttention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=True) + + def forward(self, x, mask=None, pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding( + x.size(1), x.size(1)) + x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.ffn(self.norm2(x))) + return x + + +class T5CrossAttention(nn.Module): + + def __init__(self, + dim, + dim_attn, + dim_ffn, + num_heads, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5CrossAttention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm3 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=False) + + def forward(self, + x, + mask=None, + encoder_states=None, + encoder_mask=None, + pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding( + x.size(1), x.size(1)) + x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.cross_attn( + self.norm2(x), context=encoder_states, mask=encoder_mask)) + x = fp16_clamp(x + self.ffn(self.norm3(x))) + return x + + +class T5RelativeEmbedding(nn.Module): + + def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128): + super(T5RelativeEmbedding, self).__init__() + self.num_buckets = num_buckets + self.num_heads = num_heads + self.bidirectional = bidirectional + self.max_dist = max_dist + + # layers + self.embedding = nn.Embedding(num_buckets, num_heads) + + def forward(self, lq, lk): + device = self.embedding.weight.device + # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \ + # torch.arange(lq).unsqueeze(1).to(device) + rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \ + torch.arange(lq, device=device).unsqueeze(1) + rel_pos = self._relative_position_bucket(rel_pos) + rel_pos_embeds = self.embedding(rel_pos) + rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze( + 0) # [1, N, Lq, Lk] + return rel_pos_embeds.contiguous() + + def _relative_position_bucket(self, rel_pos): + # preprocess + if self.bidirectional: + num_buckets = self.num_buckets // 2 + rel_buckets = (rel_pos > 0).long() * num_buckets + rel_pos = torch.abs(rel_pos) + else: + num_buckets = self.num_buckets + rel_buckets = 0 + rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos)) + + # embeddings for small and large positions + max_exact = num_buckets // 2 + rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / + math.log(self.max_dist / max_exact) * + (num_buckets - max_exact)).long() + rel_pos_large = torch.min( + rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) + rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) + return rel_buckets + + +class T5Encoder(nn.Module): + + def __init__(self, + vocab, + dim, + dim_attn, + dim_ffn, + num_heads, + num_layers, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5Encoder, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \ + else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=True) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList([ + T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, + shared_pos, dropout) for _ in range(num_layers) + ]) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def forward(self, ids, mask=None): + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), + x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + +class T5Decoder(nn.Module): + + def __init__(self, + vocab, + dim, + dim_attn, + dim_ffn, + num_heads, + num_layers, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5Decoder, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \ + else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=False) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList([ + T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, + shared_pos, dropout) for _ in range(num_layers) + ]) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None): + b, s = ids.size() + + # causal mask + if mask is None: + mask = torch.tril(torch.ones(1, s, s).to(ids.device)) + elif mask.ndim == 2: + mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1)) + + # layers + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), + x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, encoder_states, encoder_mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + +class T5Model(nn.Module): + + def __init__(self, + vocab_size, + dim, + dim_attn, + dim_ffn, + num_heads, + encoder_layers, + decoder_layers, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5Model, self).__init__() + self.vocab_size = vocab_size + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.encoder_layers = encoder_layers + self.decoder_layers = decoder_layers + self.num_buckets = num_buckets + + # layers + self.token_embedding = nn.Embedding(vocab_size, dim) + self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn, + num_heads, encoder_layers, num_buckets, + shared_pos, dropout) + self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn, + num_heads, decoder_layers, num_buckets, + shared_pos, dropout) + self.head = nn.Linear(dim, vocab_size, bias=False) + + # initialize weights + self.apply(init_weights) + + def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask): + x = self.encoder(encoder_ids, encoder_mask) + x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask) + x = self.head(x) + return x + + +def _t5(name, + encoder_only=False, + decoder_only=False, + return_tokenizer=False, + tokenizer_kwargs={}, + dtype=torch.float32, + device='cpu', + **kwargs): + # sanity check + assert not (encoder_only and decoder_only) + + # params + if encoder_only: + model_cls = T5Encoder + kwargs['vocab'] = kwargs.pop('vocab_size') + kwargs['num_layers'] = kwargs.pop('encoder_layers') + _ = kwargs.pop('decoder_layers') + elif decoder_only: + model_cls = T5Decoder + kwargs['vocab'] = kwargs.pop('vocab_size') + kwargs['num_layers'] = kwargs.pop('decoder_layers') + _ = kwargs.pop('encoder_layers') + else: + model_cls = T5Model + + # init model + with torch.device(device): + model = model_cls(**kwargs) + + # set device + model = model.to(dtype=dtype, device=device) + + # init tokenizer + if return_tokenizer: + from .tokenizers import HuggingfaceTokenizer + tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs) + return model, tokenizer + else: + return model + + +def umt5_xxl(**kwargs): + cfg = dict( + vocab_size=256384, + dim=4096, + dim_attn=4096, + dim_ffn=10240, + num_heads=64, + encoder_layers=24, + decoder_layers=24, + num_buckets=32, + shared_pos=False, + dropout=0.1) + cfg.update(**kwargs) + return _t5('umt5-xxl', **cfg) + + +class T5EncoderModel: + + def __init__( + self, + text_len, + dtype=torch.bfloat16, + device=torch.cuda.current_device(), + checkpoint_path=None, + tokenizer_path=None, + shard_fn=None, + ): + self.text_len = text_len + self.dtype = dtype + self.device = device + self.checkpoint_path = checkpoint_path + self.tokenizer_path = tokenizer_path + + # init model + model = umt5_xxl( + encoder_only=True, + return_tokenizer=False, + dtype=dtype, + device=device).eval().requires_grad_(False) + logging.info(f'loading {checkpoint_path}') + model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')) + self.model = model + if shard_fn is not None: + self.model = shard_fn(self.model, sync_module_states=False) + else: + self.model.to(self.device) + # init tokenizer + self.tokenizer = HuggingfaceTokenizer( + name=tokenizer_path, seq_len=text_len, clean='whitespace') + + def __call__(self, texts, device): + ids, mask = self.tokenizer( + texts, return_mask=True, add_special_tokens=True) + ids = ids.to(device) + mask = mask.to(device) + seq_lens = mask.gt(0).sum(dim=1).long() + context = self.model(ids, mask) + return [u[:v] for u, v in zip(context, seq_lens)] diff --git a/src/megatron/bridge/models/wan/modules/tokenizers.py b/src/megatron/bridge/models/wan/modules/tokenizers.py new file mode 100644 index 0000000000..a69972adf2 --- /dev/null +++ b/src/megatron/bridge/models/wan/modules/tokenizers.py @@ -0,0 +1,81 @@ +import html +import string + +import ftfy +import regex as re +from transformers import AutoTokenizer + +__all__ = ['HuggingfaceTokenizer'] + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +def canonicalize(text, keep_punctuation_exact_string=None): + text = text.replace('_', ' ') + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join( + part.translate(str.maketrans('', '', string.punctuation)) + for part in text.split(keep_punctuation_exact_string)) + else: + text = text.translate(str.maketrans('', '', string.punctuation)) + text = text.lower() + text = re.sub(r'\s+', ' ', text) + return text.strip() + + +class HuggingfaceTokenizer: + + def __init__(self, name, seq_len=None, clean=None, **kwargs): + assert clean in (None, 'whitespace', 'lower', 'canonicalize') + self.name = name + self.seq_len = seq_len + self.clean = clean + + # init tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) + self.vocab_size = self.tokenizer.vocab_size + + def __call__(self, sequence, **kwargs): + return_mask = kwargs.pop('return_mask', False) + + # arguments + _kwargs = {'return_tensors': 'pt'} + if self.seq_len is not None: + _kwargs.update({ + 'padding': 'max_length', + 'truncation': True, + 'max_length': self.seq_len + }) + _kwargs.update(**kwargs) + + # tokenization + if isinstance(sequence, str): + sequence = [sequence] + if self.clean: + sequence = [self._clean(u) for u in sequence] + ids = self.tokenizer(sequence, **_kwargs) + + # output + if return_mask: + return ids.input_ids, ids.attention_mask + else: + return ids.input_ids + + def _clean(self, text): + if self.clean == 'whitespace': + text = whitespace_clean(basic_clean(text)) + elif self.clean == 'lower': + text = whitespace_clean(basic_clean(text)).lower() + elif self.clean == 'canonicalize': + text = canonicalize(basic_clean(text)) + return text diff --git a/src/megatron/bridge/models/wan/modules/vae.py b/src/megatron/bridge/models/wan/modules/vae.py new file mode 100644 index 0000000000..d4f1ef1d0e --- /dev/null +++ b/src/megatron/bridge/models/wan/modules/vae.py @@ -0,0 +1,662 @@ +import logging + +import torch +import torch.cuda.amp as amp +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +__all__ = [ + 'WanVAE', +] + +CACHE_T = 2 + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolusion. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = (self.padding[2], self.padding[2], self.padding[1], + self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + return super().forward(x) + + +class RMS_norm(nn.Module): + + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. + + def forward(self, x): + return F.normalize( + x, dim=(1 if self.channel_first else + -1)) * self.scale * self.gamma + self.bias + + +class Upsample(nn.Upsample): + + def forward(self, x): + """ + Fix bfloat16 support for nearest neighbor interpolation. + """ + return super().forward(x.float()).type_as(x) + + +class Resample(nn.Module): + + def __init__(self, dim, mode): + assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', + 'downsample3d') + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == 'upsample2d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + elif mode == 'upsample3d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + self.time_conv = CausalConv3d( + dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == 'downsample2d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == 'downsample3d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == 'upsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = 'Rep' + feat_idx[0] += 1 + else: + + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] != 'Rep': + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] == 'Rep': + cache_x = torch.cat([ + torch.zeros_like(cache_x).to(cache_x.device), + cache_x + ], + dim=2) + if feat_cache[idx] == 'Rep': + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), + 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.resample(x) + x = rearrange(x, '(b t) c h w -> b c t h w', t=t) + + if self.mode == 'downsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + + cache_x = x[:, :, -1:, :, :].clone() + # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep': + # # cache last frame of last two chunk + # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.time_conv( + torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + def init_weight(self, conv): + conv_weight = conv.weight + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + one_matrix = torch.eye(c1, c2) + init_matrix = one_matrix + nn.init.zeros_(conv_weight) + #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5 + conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5 + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + def init_weight2(self, conv): + conv_weight = conv.weight.data + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + init_matrix = torch.eye(c1 // 2, c2) + #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2) + conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1)) + self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ + if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.shortcut(x) + for layer in self.residual: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +class AttentionBlock(nn.Module): + """ + Causal self-attention with a single head. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, t, h, w = x.size() + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.norm(x) + # compute query, key, value + q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, + -1).permute(0, 1, 3, + 2).contiguous().chunk( + 3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention( + q, + k, + v, + ) + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = rearrange(x, '(b t) c h w-> b c t h w', t=t) + return x + identity + + +class Encoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + downsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = 'downsample3d' if temperal_downsample[ + i] else 'downsample2d' + downsamples.append(Resample(out_dim, mode=mode)) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout)) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +class Decoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2**(len(dim_mult) - 2) + + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout)) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i == 1 or i == 2 or i == 3: + in_dim = in_dim // 2 + for _ in range(num_res_blocks + 1): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + upsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # upsample block + if i != len(dim_mult) - 1: + mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d' + upsamples.append(Resample(out_dim, mode=mode)) + scale *= 2.0 + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, 3, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, CausalConv3d): + count += 1 + return count + + +class WanVAE_(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, + attn_scales, self.temperal_downsample, dropout) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, + attn_scales, self.temperal_upsample, dropout) + + def forward(self, x): + mu, log_var = self.encode(x) + z = self.reparameterize(mu, log_var) + x_recon = self.decode(z) + return x_recon, mu, log_var + + def encode(self, x, scale): + self.clear_cache() + ## cache + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + ## 对encode输入的x,按时间拆分为1、4、4、4.... + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder( + x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] + self.clear_cache() + return mu + + def decode(self, z, scale): + self.clear_cache() + # z: [b,c,t,h,w] + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + else: + out_ = self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + self.clear_cache() + return out + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False): + mu, log_var = self.encode(imgs) + if deterministic: + return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + return mu + std * torch.randn_like(std) + + def clear_cache(self): + self._conv_num = count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + #cache encode + self._enc_conv_num = count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + +def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs): + """ + Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL. + """ + # params + cfg = dict( + dim=96, + z_dim=z_dim, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0) + cfg.update(**kwargs) + + # init model + with torch.device('meta'): + model = WanVAE_(**cfg) + + # load checkpoint + logging.info(f'loading {pretrained_path}') + model.load_state_dict( + torch.load(pretrained_path, map_location=device), assign=True) + + return model + + +class WanVAE: + + def __init__(self, + z_dim=16, + vae_pth='cache/vae_step_411000.pth', + dtype=torch.float, + device="cuda"): + self.dtype = dtype + self.device = device + + mean = [ + -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, + 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921 + ] + std = [ + 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, + 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160 + ] + self.mean = torch.tensor(mean, dtype=dtype, device=device) + self.std = torch.tensor(std, dtype=dtype, device=device) + self.scale = [self.mean, 1.0 / self.std] + + # init model + self.model = _video_vae( + pretrained_path=vae_pth, + z_dim=z_dim, + ).eval().requires_grad_(False).to(device) + + def encode(self, videos): + """ + videos: A list of videos each with shape [C, T, H, W]. + """ + with amp.autocast(dtype=self.dtype): + return [ + self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) + for u in videos + ] + + def decode(self, zs): + with amp.autocast(dtype=self.dtype): + return [ + self.model.decode(u.unsqueeze(0), + self.scale).float().clamp_(-1, 1).squeeze(0) + for u in zs + ] diff --git a/src/megatron/bridge/models/wan/rope_utils.py b/src/megatron/bridge/models/wan/rope_utils.py new file mode 100644 index 0000000000..1f79d8bc7c --- /dev/null +++ b/src/megatron/bridge/models/wan/rope_utils.py @@ -0,0 +1,65 @@ +import torch +from torch.cuda import amp +from megatron.bridge.models.wan.utils.utils import split_inputs_cp +from megatron.core import parallel_state + +class Wan3DRopeEmbeddings(torch.nn.Module): + """ + Wan 3D RoPE embeddings implementation. + Implements Wan's 3D RoPE embeddings for Mcore Attention based on Wan's implementation at https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/model.py. + """ + + def __init__(self, dim_head, max_position_len): + super().__init__() + self.freqs = torch.cat([ + self.rope_params(max_position_len, dim_head - 4 * (dim_head // 6)), + self.rope_params(max_position_len, 2 * (dim_head // 6)), + self.rope_params(max_position_len, 2 * (dim_head // 6)) + ], dim=1) + + def rope_params(self, max_position_len, dim_head, theta=10000): + assert dim_head % 2 == 0 + freqs = torch.outer( + torch.arange(max_position_len), + 1.0 / torch.pow(theta, + torch.arange(0, dim_head, 2).div(dim_head))) + return freqs + + def forward(self, n_head, dim_head, max_seq_len, grid_sizes, device): + self.freqs = self.freqs.to(device) # ??? do we need to put this here, or the when we move WanModel to device, it also move freqs to device? + + n, c = n_head, dim_head // 2 + + # split freqs + freqs = self.freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + freqs_real = [] + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + freqs_real_i = torch.cat([ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(seq_len, 1, 1, -1) # <-- add 1,1 for batch/head broadcasting + + # Double dimension from c -> 2c with rotating angles as (x0, x0, x1, x1, ...), for interleaving RoPE + freqs_real_i = freqs_real_i.unsqueeze(-1).expand(-1, -1, -1, -1, 2).reshape(seq_len, 1, 1, dim_head) + + # Pad freqs_real_i to (max_seq_len, 1, 1, dim_head) with 0s + if freqs_real_i.shape[0] < max_seq_len: + pad_shape = (max_seq_len - freqs_real_i.shape[0], 1, 1, dim_head) + freqs_real_i = torch.cat( + [freqs_real_i, torch.zeros(pad_shape, dtype=freqs_real_i.dtype, device=freqs_real_i.device)] + ) + freqs_real.append(freqs_real_i) + + # Each freqs_real[i] is (max_seq_len, 1, 1, dim_head) + # We concatenate them along dim=1 to get (max_seq_len, batch_size, 1, dim_head) + freqs_real = torch.cat(freqs_real, dim=1) + + # Note: + # when running context_parallel, which must use "thd" for qkv_format, + # we don't need to scatter the freqs to the context parallel region, + # because mcore rope_utils will automatically retrieve the correct freqs for each context parallel region + + return freqs_real \ No newline at end of file diff --git a/src/megatron/bridge/models/wan/utils/preprocessor.py b/src/megatron/bridge/models/wan/utils/preprocessor.py new file mode 100644 index 0000000000..fc5ea6a740 --- /dev/null +++ b/src/megatron/bridge/models/wan/utils/preprocessor.py @@ -0,0 +1,271 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. +import numpy as np +from PIL import Image +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF + + +class VaceImageProcessor(object): + def __init__(self, downsample=None, seq_len=None): + self.downsample = downsample + self.seq_len = seq_len + + def _pillow_convert(self, image, cvt_type='RGB'): + if image.mode != cvt_type: + if image.mode == 'P': + image = image.convert(f'{cvt_type}A') + if image.mode == f'{cvt_type}A': + bg = Image.new(cvt_type, + size=(image.width, image.height), + color=(255, 255, 255)) + bg.paste(image, (0, 0), mask=image) + image = bg + else: + image = image.convert(cvt_type) + return image + + def _load_image(self, img_path): + if img_path is None or img_path == '': + return None + img = Image.open(img_path) + img = self._pillow_convert(img) + return img + + def _resize_crop(self, img, oh, ow, normalize=True): + """ + Resize, center crop, convert to tensor, and normalize. + """ + # resize and crop + iw, ih = img.size + if iw != ow or ih != oh: + # resize + scale = max(ow / iw, oh / ih) + img = img.resize( + (round(scale * iw), round(scale * ih)), + resample=Image.Resampling.LANCZOS + ) + assert img.width >= ow and img.height >= oh + + # center crop + x1 = (img.width - ow) // 2 + y1 = (img.height - oh) // 2 + img = img.crop((x1, y1, x1 + ow, y1 + oh)) + + # normalize + if normalize: + img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(1) + return img + + def _image_preprocess(self, img, oh, ow, normalize=True, **kwargs): + return self._resize_crop(img, oh, ow, normalize) + + def load_image(self, data_key, **kwargs): + return self.load_image_batch(data_key, **kwargs) + + def load_image_pair(self, data_key, data_key2, **kwargs): + return self.load_image_batch(data_key, data_key2, **kwargs) + + def load_image_batch(self, *data_key_batch, normalize=True, seq_len=None, **kwargs): + seq_len = self.seq_len if seq_len is None else seq_len + imgs = [] + for data_key in data_key_batch: + img = self._load_image(data_key) + imgs.append(img) + w, h = imgs[0].size + dh, dw = self.downsample[1:] + + # compute output size + scale = min(1., np.sqrt(seq_len / ((h / dh) * (w / dw)))) + oh = int(h * scale) // dh * dh + ow = int(w * scale) // dw * dw + assert (oh // dh) * (ow // dw) <= seq_len + imgs = [self._image_preprocess(img, oh, ow, normalize) for img in imgs] + return *imgs, (oh, ow) + + +class VaceVideoProcessor(object): + def __init__(self, downsample, min_area, max_area, min_fps, max_fps, zero_start, seq_len, keep_last, **kwargs): + self.downsample = downsample + self.min_area = min_area + self.max_area = max_area + self.min_fps = min_fps + self.max_fps = max_fps + self.zero_start = zero_start + self.keep_last = keep_last + self.seq_len = seq_len + assert seq_len >= min_area / (self.downsample[1] * self.downsample[2]) + + def set_area(self, area): + self.min_area = area + self.max_area = area + + def set_seq_len(self, seq_len): + self.seq_len = seq_len + + @staticmethod + def resize_crop(video: torch.Tensor, oh: int, ow: int): + """ + Resize, center crop and normalize for decord loaded video (torch.Tensor type) + + Parameters: + video - video to process (torch.Tensor): Tensor from `reader.get_batch(frame_ids)`, in shape of (T, H, W, C) + oh - target height (int) + ow - target width (int) + + Returns: + The processed video (torch.Tensor): Normalized tensor range [-1, 1], in shape of (C, T, H, W) + + Raises: + """ + # permute ([t, h, w, c] -> [t, c, h, w]) + video = video.permute(0, 3, 1, 2) + + # resize and crop + ih, iw = video.shape[2:] + if ih != oh or iw != ow: + # resize + scale = max(ow / iw, oh / ih) + video = F.interpolate( + video, + size=(round(scale * ih), round(scale * iw)), + mode='bicubic', + antialias=True + ) + assert video.size(3) >= ow and video.size(2) >= oh + + # center crop + x1 = (video.size(3) - ow) // 2 + y1 = (video.size(2) - oh) // 2 + video = video[:, :, y1:y1 + oh, x1:x1 + ow] + + # permute ([t, c, h, w] -> [c, t, h, w]) and normalize + video = video.transpose(0, 1).float().div_(127.5).sub_(1.) + return video + + def _video_preprocess(self, video, oh, ow): + return self.resize_crop(video, oh, ow) + + def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, rng): + target_fps = min(fps, self.max_fps) + duration = frame_timestamps[-1].mean() + x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box + h, w = y2 - y1, x2 - x1 + ratio = h / w + df, dh, dw = self.downsample + + area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw)) + of = min( + (int(duration * target_fps) - 1) // df + 1, + int(self.seq_len / area_z) + ) + + # deduce target shape of the [latent video] + target_area_z = min(area_z, int(self.seq_len / of)) + oh = round(np.sqrt(target_area_z * ratio)) + ow = int(target_area_z / oh) + of = (of - 1) * df + 1 + oh *= dh + ow *= dw + + # sample frame ids + target_duration = of / target_fps + begin = 0. if self.zero_start else rng.uniform(0, duration - target_duration) + timestamps = np.linspace(begin, begin + target_duration, of) + frame_ids = np.argmax(np.logical_and( + timestamps[:, None] >= frame_timestamps[None, :, 0], + timestamps[:, None] < frame_timestamps[None, :, 1] + ), axis=1).tolist() + return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps + + def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng): + duration = frame_timestamps[-1].mean() + x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box + h, w = y2 - y1, x2 - x1 + ratio = h / w + df, dh, dw = self.downsample + + area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw)) + of = min( + (len(frame_timestamps) - 1) // df + 1, + int(self.seq_len / area_z) + ) + + # deduce target shape of the [latent video] + target_area_z = min(area_z, int(self.seq_len / of)) + oh = round(np.sqrt(target_area_z * ratio)) + ow = int(target_area_z / oh) + of = (of - 1) * df + 1 + oh *= dh + ow *= dw + + # sample frame ids + target_duration = duration + target_fps = of / target_duration + timestamps = np.linspace(0., target_duration, of) + frame_ids = np.argmax(np.logical_and( + timestamps[:, None] >= frame_timestamps[None, :, 0], + timestamps[:, None] <= frame_timestamps[None, :, 1] + ), axis=1).tolist() + # print(oh, ow, of, target_duration, target_fps, len(frame_timestamps), len(frame_ids)) + return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps + + + def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng): + if self.keep_last: + return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng) + else: + return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng) + + def load_video(self, data_key, crop_box=None, seed=2024, **kwargs): + return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs) + + def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs): + return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs) + + def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, **kwargs): + rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000) + # read video + import decord + decord.bridge.set_bridge('torch') + readers = [] + for data_k in data_key_batch: + reader = decord.VideoReader(data_k) + readers.append(reader) + + fps = readers[0].get_avg_fps() + length = min([len(r) for r in readers]) + frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)] + frame_timestamps = np.array(frame_timestamps, dtype=np.float32) + h, w = readers[0].next().shape[:2] + frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, frame_timestamps, h, w, crop_box, rng) + + # preprocess video + videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers] + videos = [self._video_preprocess(video, oh, ow) for video in videos] + return *videos, frame_ids, (oh, ow), fps + # return videos if len(videos) > 1 else videos[0] + + +def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size, device): + for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)): + if sub_src_video is None and sub_src_mask is None: + src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device) + src_mask[i] = torch.ones((1, num_frames, image_size[0], image_size[1]), device=device) + for i, ref_images in enumerate(src_ref_images): + if ref_images is not None: + for j, ref_img in enumerate(ref_images): + if ref_img is not None and ref_img.shape[-2:] != image_size: + canvas_height, canvas_width = image_size + ref_height, ref_width = ref_img.shape[-2:] + white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1] + scale = min(canvas_height / ref_height, canvas_width / ref_width) + new_height = int(ref_height * scale) + new_width = int(ref_width * scale) + resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1) + top = (canvas_height - new_height) // 2 + left = (canvas_width - new_width) // 2 + white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image + src_ref_images[i][j] = white_canvas + return src_video, src_mask, src_ref_images \ No newline at end of file diff --git a/src/megatron/bridge/models/wan/utils/utils.py b/src/megatron/bridge/models/wan/utils/utils.py new file mode 100644 index 0000000000..0f93526632 --- /dev/null +++ b/src/megatron/bridge/models/wan/utils/utils.py @@ -0,0 +1,221 @@ +import torch +from typing import Tuple +from torch.distributed import all_gather +import megatron.core.parallel_state as parallel_state +import math +import torch.distributed as dist +import transformer_engine_torch as tex + +def grid_sizes_calculation( + input_shape: Tuple[int, int, int], # (F_latents, H_latents, W_latents) + patch_size: Tuple[int, int, int], # (pF, pH, pW) +) -> Tuple[int, int, int]: + """ + Compute the (f,h,w) output spatial/temporal dimensions of a Conv3d patch embedder. + """ + + F_latents, H_latents, W_latents = input_shape + pF, pH, pW = patch_size + F_patches = F_latents // pF + H_patches = H_latents // pH + W_patches = W_latents // pW + + return [F_patches, H_patches, W_patches] + + +def patchify(x, patch_size): + """ + Convert a list of reconstructed video tensor into patch embeddings. + This method is the inverse of `unpatchify`. + + Args: + x (list[torch.Tensor]): list of tensors, each with shape [c, F_patches * pF, H_patches * pH, W_patches * pW] + patch_size (tuple): (pF, pH, pW) + + Returns: + torch.Tensor: shape [ (F_patches * H_patches * W_patches), (c * pF * pH * pW)], + """ + out = [] + for u in x: + c, F_pF, H_pH, W_pW = u.shape + pF, pH, pW = patch_size + assert F_pF % pF == 0 and H_pH % pH == 0 and W_pW % pW == 0, \ + "Spatial dimensions must be divisible by patch size." + + F_patches, H_patches, W_patches = F_pF // pF, H_pH // pH, W_pW // pW + + # split spatial dims into (grid, patch) and reorder to match original patch layout: + # start: (c, F_patches * pF, H_patches * pH, W_patches * pW) + # reshape -> (c, F_patches, pF, H_patches, pH, W_patches, pW) + # permute -> (F_patches, H_patches, W_patches, pF, pH, pW, c) + t = u.reshape(c, F_patches, pF, H_patches, pH, W_patches, pW) + t = t.permute(1, 3, 5, 2, 4, 6, 0) + + num_patches = F_patches * H_patches * W_patches + out.append(t.reshape(num_patches, c * (pF * pH * pW))) + return out + + +def unpatchify(x: list[torch.Tensor], grid_sizes: list[Tuple[int, int, int]], out_dim: int, patch_size: Tuple[int, int, int]) -> list[torch.Tensor]: + """ + Reconstruct video tensors from patch embeddings into a list of videotensors. + + Args: + x (list[torch.Tensor]): + list of tensors, each with shape [seq_len, c * pF * pH * pW] + grid_sizes (list[Tuple[int, int, int]]): + list of tensors, each with original spatial-temporal grid dimensions before patching, + (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + list[torch.Tensor]: list of tensors, each with shape [c, F_latents, H_latents, W_latents] + """ + + c = out_dim + out = [] + for u, v in zip(x, grid_sizes): + u = u[:math.prod(v)].view(*v, *patch_size, c) + u = torch.einsum('fhwpqrc->cfphqwr', u) + u = u.reshape(c, *[i * j for i, j in zip(v, patch_size)]) + out.append(u) + return out + + +def split_inputs_cp(x: torch.Tensor, seq_dim: int = 0) -> torch.Tensor: + """ + Split input tensor along the sequence dimension for context parallelism. + + Args: + x: Input tensor to be split. (e.g. shape [seq_len, batch_size, ...]) + seq_dim: The dimension along which to split the input (sequence dimension). + + Returns: + A slice of the input tensor corresponding to the current rank. (e.g. shape [seq_len/cp_size, batch_size, ...]) + """ + + cp_size = parallel_state.get_context_parallel_world_size() + if cp_size > 1: + cp_rank = parallel_state.get_context_parallel_rank() + assert x.shape[seq_dim] % cp_size == 0, f"{x.shape[seq_dim]} cannot divide cp_size {cp_size}" + x = x.view(*x.shape[:seq_dim], cp_size, x.shape[seq_dim] // cp_size, *x.shape[(seq_dim + 1) :]) + seq_idx = torch.tensor([cp_rank], device=x.device) + x = x.index_select(seq_dim, seq_idx) + # Note that the new sequence length is the original sequence length / cp_size + x = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) + return x + + +def cat_outputs_cp(x: torch.Tensor, seq_dim: int) -> torch.Tensor: + """ + Concatenate tensors from multiple processes along a specified dimension. + + Args: + x: Input tensor to be concatenated. (e.g. shape [seq_len/cp_size, batch_size, ...]) + seq_dim: The dimension along which to concatenate the input tensors. + + Returns: + A tensor with the concatenated tensors. (e.g. shape [seq_len, batch_size, ...]) + """ + + cp_group = parallel_state.get_context_parallel_group() + cp_size = parallel_state.get_context_parallel_world_size() + if cp_size > 1: + gathered_tensors = [torch.zeros_like(x) for _ in range(cp_size)] + # Attempt to gather tensors from all ranks + # PyTorch’s all_gather orders outputs by rank within the group, which matches how chunks were selected by cp_rank + all_gather(gathered_tensors, x, group=cp_group) + gathered_tensors = torch.cat(gathered_tensors, dim=seq_dim) + return gathered_tensors + else: + return x + + +def thd_split_inputs_cp(x: torch.Tensor, + cu_seqlens_q_padded: torch.Tensor, + cp_group: dist.ProcessGroup) -> torch.Tensor: + """ + Split a THD-packed tensor across CP ranks for inputs shaped [S, B, ...]. + + Args: + x: [S, B, ...] tensor (sequence first). + cu_seqlens_q_padded: 1D int32 THD cu_seqlens (padded) used for packing. + cp_group: context-parallel process group. + + Returns: + x_local: [S_local, B, ...] shard for this CP rank. + """ + # Move to [B, S, ...] to use THD partitioning along S + x_bs = x.transpose(0, 1).contiguous() # [B, S, ...] + + total_S = x_bs.size(1) + cp_size = dist.get_world_size(cp_group) + cp_rank = dist.get_rank(cp_group) + + # Compute this rank's THD partition indices (same API as during gather) + idx = tex.thd_get_partitioned_indices( + cu_seqlens_q_padded, # int32 offsets + total_S, + cp_size, + cp_rank, + ).to(device=x_bs.device, dtype=torch.long) # [S_local] + + # Take the shard along sequence dim + x_local_bs = x_bs.index_select(dim=1, index=idx).contiguous() # [B, S_local, ...] + + # Return to [S, B, ...] + x_local = x_local_bs.transpose(0, 1).contiguous() # [S_local, B, ...] + return x_local + + +def thd_cat_outputs_cp(x_local: torch.Tensor, + cu_seqlens_q_padded: torch.Tensor, + cp_group: dist.ProcessGroup) -> torch.Tensor: + """ + Reverse of thd_split_inputs_cp: gather THD-partitioned local shards back to global. + + Args: + x_local: [S_local, B, ...] tensor (this rank's shard, sequence first). + cu_seqlens_q_padded: 1D int32 THD cu_seqlens (padded) used for packing. + cp_group: context-parallel process group. + + Returns: + x_global: [S, B, ...] tensor reassembled across CP ranks. + """ + # Work in [B, S_local, ...] for easy indexing along S + x_local_bs = x_local.transpose(0, 1).contiguous() # [B, S_local, ...] + + cp_size = dist.get_world_size(cp_group) + cp_rank = dist.get_rank(cp_group) + + # Discover total S from cu_seqlens (last value) + # (Matches 'total_S' used during split.) + total_S = int(cu_seqlens_q_padded[-1].item()) + + # All-gather local shards across CP group + gather_list = [torch.empty_like(x_local_bs) for _ in range(cp_size)] + dist.all_gather(gather_list, x_local_bs, group=cp_group) # each is [B, S_r, ...] + + # Compute per-rank indices once (same device/dtype as input) + # NOTE: tex.thd_get_partitioned_indices returns indices along S for that rank. + idx_list = [] + for r in range(cp_size): + idx_r = tex.thd_get_partitioned_indices( + cu_seqlens_q_padded, # int32 offsets + total_S, + cp_size, + r, + ).to(device=x_local_bs.device, dtype=torch.long) # [S_r] + idx_list.append(idx_r) + + # Allocate output [B, S, ...] and place each rank's slice back + out_shape = list(x_local_bs.shape) + out_shape[1] = total_S # replace S_local with S + x_global_bs = x_local_bs.new_zeros(out_shape) # [B, S, ...] + + # index_copy_ along S dimension + for shard, idx in zip(gather_list, idx_list): + x_global_bs.index_copy_(dim=1, index=idx, source=shard) + + # Return to [S, B, ...] + x_global = x_global_bs.transpose(0, 1).contiguous() # [S, B, ...] + return x_global \ No newline at end of file diff --git a/src/megatron/bridge/models/wan/wan_bridge.py b/src/megatron/bridge/models/wan/wan_bridge.py new file mode 100644 index 0000000000..ebcbf8e1c4 --- /dev/null +++ b/src/megatron/bridge/models/wan/wan_bridge.py @@ -0,0 +1,411 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +import torch +from megatron.bridge.models.wan.wan_model import WanModel, VACEModel +from diffusers import WanTransformer3DModel, WanVACETransformer3DModel + +from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.bridge.models.conversion.param_mapping import ( + AutoMapping, + GatedMLPMapping, + QKVMapping, + KVMapping, + ReplicatedMapping, +) +from megatron.bridge.models.hf_pretrained.wan import PreTrainedWAN, PreTrainedVACE +from megatron.bridge.models.wan.wan_provider import WanModelProvider, VACEModelProvider +from megatron.core.transformer.utils import openai_gelu +from megatron.bridge.models.conversion.utils import get_module_and_param_from_name + + +@MegatronModelBridge.register_bridge(source=WanTransformer3DModel, target=WanModel) +class WanBridge(MegatronModelBridge): + """ + Megatron Bridge for WAN model. + + As a user you would not use this bridge directly, but through `AutoBridge`. + + Example: + >>> from megatron.bridge import AutoBridge + >>> bridge = AutoBridge.from_hf_pretrained("WAN-3D-1.3B-v1") + >>> provider = bridge.to_megatron_provider() + """ + + def provider_bridge(self, hf_pretrained: PreTrainedWAN) -> WanModelProvider: + hf_config = hf_pretrained.config + + cls = WanModelProvider + + provider = cls( + num_layers=hf_config.num_layers, + hidden_size=hf_config.num_attention_heads * hf_config.attention_head_dim, + kv_channels=hf_config.attention_head_dim, + num_query_groups=hf_config.num_attention_heads, + crossattn_emb_size=hf_config.num_attention_heads * hf_config.attention_head_dim, + ffn_hidden_size=hf_config.ffn_dim, + num_attention_heads=hf_config.num_attention_heads, + activation_func=openai_gelu, + in_channels=hf_config.in_channels, + out_channels=hf_config.out_channels, + text_dim=hf_config.text_dim, + patch_spatial=hf_config.patch_size[1], + patch_temporal=hf_config.patch_size[0], + layernorm_epsilon=hf_config.eps, + hidden_dropout=0, + attention_dropout=0, + use_cpu_initialization=True, + freq_dim=hf_config.freq_dim, + bf16=False, + params_dtype=torch.float32, + ) + + return provider + + + def mapping_registry(self) -> MegatronMappingRegistry: + """Return MegatronMappingRegistry containing parameter mappings from HF to Megatron format. + + Returns: + MegatronMappingRegistry: Registry of parameter mappings + """ + # Dictionary maps HF parameter names -> Megatron parameter names + # Supports wildcard (*) patterns for layer-specific parameters + param_mappings = { + "scale_shift_table": "head.modulation", + "patch_embedding.weight": "patch_embedding.weight", + "patch_embedding.bias": "patch_embedding.bias", + "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight", + "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias", + "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight", + "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias", + "condition_embedder.time_proj.weight": "time_projection.1.weight", + "condition_embedder.time_proj.bias": "time_projection.1.bias", + "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight", + "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias", + "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight", + "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias", + "blocks.*.scale_shift_table": "decoder.layers.*.adaLN.modulation", + "blocks.*.attn1.to_out.0.weight": "decoder.layers.*.full_self_attention.linear_proj.weight", + "blocks.*.attn1.to_out.0.bias": "decoder.layers.*.full_self_attention.linear_proj.bias", + "blocks.*.attn1.norm_q.weight": "decoder.layers.*.full_self_attention.q_layernorm.weight", + "blocks.*.attn1.norm_k.weight": "decoder.layers.*.full_self_attention.k_layernorm.weight", + "blocks.*.attn2.to_q.weight": "decoder.layers.*.cross_attention.linear_q.weight", + "blocks.*.attn2.to_q.bias": "decoder.layers.*.cross_attention.linear_q.bias", + "blocks.*.attn2.to_out.0.weight": "decoder.layers.*.cross_attention.linear_proj.weight", + "blocks.*.attn2.to_out.0.bias": "decoder.layers.*.cross_attention.linear_proj.bias", + "blocks.*.attn2.norm_q.weight": "decoder.layers.*.cross_attention.q_layernorm.weight", + "blocks.*.attn2.norm_k.weight": "decoder.layers.*.cross_attention.k_layernorm.weight", + "blocks.*.norm2.weight": "decoder.layers.*.norm3.weight", + "blocks.*.norm2.bias": "decoder.layers.*.norm3.bias", + "blocks.*.ffn.net.0.proj.weight": "decoder.layers.*.mlp.linear_fc1.weight", + "blocks.*.ffn.net.0.proj.bias": "decoder.layers.*.mlp.linear_fc1.bias", + "blocks.*.ffn.net.2.weight": "decoder.layers.*.mlp.linear_fc2.weight", + "blocks.*.ffn.net.2.bias": "decoder.layers.*.mlp.linear_fc2.bias", + "proj_out.weight": "head.head.weight", + "proj_out.bias": "head.head.bias", + } + + + # Custom WAN mapping to safely handle replicated params whose owning module + # does not expose a top-level `.weight` (e.g., Head.modulation) + class _ReplicatedByParamNameMapping(ReplicatedMapping): + def hf_to_megatron(self, hf_weights, megatron_module): + normalized_param = self._normalize_expert_param_name(self.megatron_param) + _, target_param = get_module_and_param_from_name(megatron_module, normalized_param) + + target_device = target_param.device + target_dtype = target_param.dtype + + hf_weights = hf_weights.to(device=target_device, dtype=target_dtype) + if self.tp_size == 1: + return hf_weights + + if target_device.type == "cuda" and torch.cuda.is_available(): + if target_device.index != torch.cuda.current_device(): + hf_weights = hf_weights.to(torch.cuda.current_device()) + + if self.tp_rank > 0: + hf_weights = torch.empty_like(hf_weights) + + return self.broadcast_tensor_to_tp_ranks(hf_weights, src_rank=0) + + + mapping_list = [] + # Convert each dictionary entry to AutoMapping(hf_param, megatron_param) + for hf_param, megatron_param in param_mappings.items(): + if hf_param in {"scale_shift_table", "blocks.*.scale_shift_table", "proj_out.weight", "proj_out.bias"}: + # Use WAN-specific replicated mapping that resolves the exact param + mapping_list.append(_ReplicatedByParamNameMapping(hf_param=hf_param, megatron_param=megatron_param)) + else: + mapping_list.append(AutoMapping(hf_param=hf_param, megatron_param=megatron_param)) + + # Adding custom module types for AutoMapping + AutoMapping.register_module_type("Linear", "replicated") + AutoMapping.register_module_type("Conv3d", "replicated") + AutoMapping.register_module_type("WanAdaLN", "replicated") + AutoMapping.register_module_type("Head", "replicated") + + # Add special mappings that require parameter concatenation/transformation + mapping_list.extend( + [ + # QKV: Combine separate Q, K, V matrices into single QKV matrix + QKVMapping( + q="blocks.*.attn1.to_q.weight", + k="blocks.*.attn1.to_k.weight", + v="blocks.*.attn1.to_v.weight", + megatron_param="decoder.layers.*.full_self_attention.linear_qkv.weight", + ), + # QKV bias: Combine separate Q, K, V bias into single QKV bias + QKVMapping( + q="blocks.*.attn1.to_q.bias", + k="blocks.*.attn1.to_k.bias", + v="blocks.*.attn1.to_v.bias", + megatron_param="decoder.layers.*.full_self_attention.linear_qkv.bias", + ), + # K, V: Combine separate K, V matrices into single KV matrix + KVMapping( + k="blocks.*.attn2.to_k.weight", + v="blocks.*.attn2.to_v.weight", + megatron_param="decoder.layers.*.cross_attention.linear_kv.weight", + ), + # K, V bias: Combine separate K, V bias into single KV bias + KVMapping( + k="blocks.*.attn2.to_k.bias", + v="blocks.*.attn2.to_v.bias", + megatron_param="decoder.layers.*.cross_attention.linear_kv.bias", + ), + ] + ) + + return MegatronMappingRegistry(*mapping_list) + + +@MegatronModelBridge.register_bridge(source=WanVACETransformer3DModel, target=VACEModel) +class VACEBridge(MegatronModelBridge): + """ + Megatron Bridge for VACE model. + + As a user you would not use this bridge directly, but through `AutoBridge`. + + Example: + >>> from megatron.bridge import AutoBridge + >>> bridge = AutoBridge.from_hf_pretrained("WAN-3D-1.3B-v1") + >>> provider = bridge.to_megatron_provider() + """ + + def provider_bridge(self, hf_pretrained: PreTrainedVACE) -> VACEModelProvider: + hf_config = hf_pretrained.config + + cls = VACEModelProvider + + provider = cls( + num_layers=hf_config.num_layers, + hidden_size=hf_config.num_attention_heads * hf_config.attention_head_dim, + kv_channels=hf_config.attention_head_dim, + num_query_groups=hf_config.num_attention_heads, + crossattn_emb_size=hf_config.num_attention_heads * hf_config.attention_head_dim, + ffn_hidden_size=hf_config.ffn_dim, + num_attention_heads=hf_config.num_attention_heads, + activation_func=openai_gelu, + in_channels=hf_config.in_channels, + out_channels=hf_config.out_channels, + text_dim=hf_config.text_dim, + patch_spatial=hf_config.patch_size[1], + patch_temporal=hf_config.patch_size[0], + layernorm_epsilon=hf_config.eps, + hidden_dropout=0, + attention_dropout=0, + use_cpu_initialization=True, + freq_dim=hf_config.freq_dim, + bf16=False, + params_dtype=torch.float32, + vace_in_channels=hf_config.vace_in_channels, + vace_layers=hf_config.vace_layers, + base_num_layers=hf_config.num_layers, + ) + + return provider + + + def mapping_registry(self) -> MegatronMappingRegistry: + """Return MegatronMappingRegistry containing parameter mappings from HF to Megatron format. + + Returns: + MegatronMappingRegistry: Registry of parameter mappings + """ + # Dictionary maps HF parameter names -> Megatron parameter names + # Supports wildcard (*) patterns for layer-specific parameters + param_mappings = { + "scale_shift_table": "head.modulation", + "patch_embedding.weight": "patch_embedding.weight", + "patch_embedding.bias": "patch_embedding.bias", + "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight", + "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias", + "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight", + "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias", + "condition_embedder.time_proj.weight": "time_projection.1.weight", + "condition_embedder.time_proj.bias": "time_projection.1.bias", + "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight", + "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias", + "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight", + "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias", + "blocks.*.scale_shift_table": "decoder.layers.*.adaLN.modulation", + "blocks.*.attn1.to_out.0.weight": "decoder.layers.*.full_self_attention.linear_proj.weight", + "blocks.*.attn1.to_out.0.bias": "decoder.layers.*.full_self_attention.linear_proj.bias", + "blocks.*.attn1.norm_q.weight": "decoder.layers.*.full_self_attention.q_layernorm.weight", + "blocks.*.attn1.norm_k.weight": "decoder.layers.*.full_self_attention.k_layernorm.weight", + "blocks.*.attn2.to_q.weight": "decoder.layers.*.cross_attention.linear_q.weight", + "blocks.*.attn2.to_q.bias": "decoder.layers.*.cross_attention.linear_q.bias", + "blocks.*.attn2.to_out.0.weight": "decoder.layers.*.cross_attention.linear_proj.weight", + "blocks.*.attn2.to_out.0.bias": "decoder.layers.*.cross_attention.linear_proj.bias", + "blocks.*.attn2.norm_q.weight": "decoder.layers.*.cross_attention.q_layernorm.weight", + "blocks.*.attn2.norm_k.weight": "decoder.layers.*.cross_attention.k_layernorm.weight", + "blocks.*.norm2.weight": "decoder.layers.*.norm3.weight", + "blocks.*.norm2.bias": "decoder.layers.*.norm3.bias", + "blocks.*.ffn.net.0.proj.weight": "decoder.layers.*.mlp.linear_fc1.weight", + "blocks.*.ffn.net.0.proj.bias": "decoder.layers.*.mlp.linear_fc1.bias", + "blocks.*.ffn.net.2.weight": "decoder.layers.*.mlp.linear_fc2.weight", + "blocks.*.ffn.net.2.bias": "decoder.layers.*.mlp.linear_fc2.bias", + "proj_out.weight": "head.head.weight", + "proj_out.bias": "head.head.bias", + + "vace_patch_embedding.weight": "vace_patch_embedding.weight", + "vace_patch_embedding.bias": "vace_patch_embedding.bias", + "vace_blocks.0.proj_in.weight": "vace_init_proj.weight", + "vace_blocks.0.proj_in.bias": "vace_init_proj.bias", + "vace_blocks.*.scale_shift_table": "vace_decoder.layers.*.adaLN.modulation", + "vace_blocks.*.attn1.to_out.0.weight": "vace_decoder.layers.*.full_self_attention.linear_proj.weight", + "vace_blocks.*.attn1.to_out.0.bias": "vace_decoder.layers.*.full_self_attention.linear_proj.bias", + "vace_blocks.*.attn1.norm_q.weight": "vace_decoder.layers.*.full_self_attention.q_layernorm.weight", + "vace_blocks.*.attn1.norm_k.weight": "vace_decoder.layers.*.full_self_attention.k_layernorm.weight", + "vace_blocks.*.attn2.to_q.weight": "vace_decoder.layers.*.cross_attention.linear_q.weight", + "vace_blocks.*.attn2.to_q.bias": "vace_decoder.layers.*.cross_attention.linear_q.bias", + "vace_blocks.*.attn2.to_out.0.weight": "vace_decoder.layers.*.cross_attention.linear_proj.weight", + "vace_blocks.*.attn2.to_out.0.bias": "vace_decoder.layers.*.cross_attention.linear_proj.bias", + "vace_blocks.*.attn2.norm_q.weight": "vace_decoder.layers.*.cross_attention.q_layernorm.weight", + "vace_blocks.*.attn2.norm_k.weight": "vace_decoder.layers.*.cross_attention.k_layernorm.weight", + "vace_blocks.*.norm2.weight": "vace_decoder.layers.*.norm3.weight", + "vace_blocks.*.norm2.bias": "vace_decoder.layers.*.norm3.bias", + "vace_blocks.*.ffn.net.0.proj.weight": "vace_decoder.layers.*.mlp.linear_fc1.weight", + "vace_blocks.*.ffn.net.0.proj.bias": "vace_decoder.layers.*.mlp.linear_fc1.bias", + "vace_blocks.*.ffn.net.2.weight": "vace_decoder.layers.*.mlp.linear_fc2.weight", + "vace_blocks.*.ffn.net.2.bias": "vace_decoder.layers.*.mlp.linear_fc2.bias", + "vace_blocks.*.proj_out.weight":"vace_decoder.layers.*.context_proj.weight", + "vace_blocks.*.proj_out.bias":"vace_decoder.layers.*.context_proj.bias", + } + + + # Custom WAN mapping to safely handle replicated params whose owning module + # does not expose a top-level `.weight` (e.g., Head.modulation) + class _ReplicatedByParamNameMapping(ReplicatedMapping): + def hf_to_megatron(self, hf_weights, megatron_module): + normalized_param = self._normalize_expert_param_name(self.megatron_param) + _, target_param = get_module_and_param_from_name(megatron_module, normalized_param) + + target_device = target_param.device + target_dtype = target_param.dtype + + hf_weights = hf_weights.to(device=target_device, dtype=target_dtype) + if self.tp_size == 1: + return hf_weights + + if target_device.type == "cuda" and torch.cuda.is_available(): + if target_device.index != torch.cuda.current_device(): + hf_weights = hf_weights.to(torch.cuda.current_device()) + + if self.tp_rank > 0: + hf_weights = torch.empty_like(hf_weights) + + return self.broadcast_tensor_to_tp_ranks(hf_weights, src_rank=0) + + + mapping_list = [] + # Convert each dictionary entry to AutoMapping(hf_param, megatron_param) + for hf_param, megatron_param in param_mappings.items(): + if hf_param in {"scale_shift_table", "blocks.*.scale_shift_table", "vace_blocks.*.scale_shift_table", "proj_out.weight", "proj_out.bias"}: + # Use WAN-specific replicated mapping that resolves the exact param + mapping_list.append(_ReplicatedByParamNameMapping(hf_param=hf_param, megatron_param=megatron_param)) + else: + mapping_list.append(AutoMapping(hf_param=hf_param, megatron_param=megatron_param)) + + # Adding custom module types for AutoMapping + AutoMapping.register_module_type("Linear", "replicated") + AutoMapping.register_module_type("Conv3d", "replicated") + AutoMapping.register_module_type("WanAdaLN", "replicated") + AutoMapping.register_module_type("Head", "replicated") + + # Add special mappings that require parameter concatenation/transformation + mapping_list.extend( + [ + # QKV: Combine separate Q, K, V matrices into single QKV matrix + QKVMapping( + q="blocks.*.attn1.to_q.weight", + k="blocks.*.attn1.to_k.weight", + v="blocks.*.attn1.to_v.weight", + megatron_param="decoder.layers.*.full_self_attention.linear_qkv.weight", + ), + # QKV bias: Combine separate Q, K, V bias into single QKV bias + QKVMapping( + q="blocks.*.attn1.to_q.bias", + k="blocks.*.attn1.to_k.bias", + v="blocks.*.attn1.to_v.bias", + megatron_param="decoder.layers.*.full_self_attention.linear_qkv.bias", + ), + # K, V: Combine separate K, V matrices into single KV matrix + KVMapping( + k="blocks.*.attn2.to_k.weight", + v="blocks.*.attn2.to_v.weight", + megatron_param="decoder.layers.*.cross_attention.linear_kv.weight", + ), + # K, V bias: Combine separate K, V bias into single KV bias + KVMapping( + k="blocks.*.attn2.to_k.bias", + v="blocks.*.attn2.to_v.bias", + megatron_param="decoder.layers.*.cross_attention.linear_kv.bias", + ), + + # QKV: Combine separate Q, K, V matrices into single QKV matrix + QKVMapping( + q="vace_blocks.*.attn1.to_q.weight", + k="vace_blocks.*.attn1.to_k.weight", + v="vace_blocks.*.attn1.to_v.weight", + megatron_param="vace_decoder.layers.*.full_self_attention.linear_qkv.weight", + ), + # QKV bias: Combine separate Q, K, V bias into single QKV bias + QKVMapping( + q="vace_blocks.*.attn1.to_q.bias", + k="vace_blocks.*.attn1.to_k.bias", + v="vace_blocks.*.attn1.to_v.bias", + megatron_param="vace_decoder.layers.*.full_self_attention.linear_qkv.bias", + ), + # K, V: Combine separate K, V matrices into single KV matrix + KVMapping( + k="vace_blocks.*.attn2.to_k.weight", + v="vace_blocks.*.attn2.to_v.weight", + megatron_param="vace_decoder.layers.*.cross_attention.linear_kv.weight", + ), + # K, V bias: Combine separate K, V bias into single KV bias + KVMapping( + k="vace_blocks.*.attn2.to_k.bias", + v="vace_blocks.*.attn2.to_v.bias", + megatron_param="vace_decoder.layers.*.cross_attention.linear_kv.bias", + ), + ] + ) + + return MegatronMappingRegistry(*mapping_list) \ No newline at end of file diff --git a/src/megatron/bridge/models/wan/wan_layer_spec.py b/src/megatron/bridge/models/wan/wan_layer_spec.py new file mode 100644 index 0000000000..8a839bcd89 --- /dev/null +++ b/src/megatron/bridge/models/wan/wan_layer_spec.py @@ -0,0 +1,862 @@ + +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +import copy +from dataclasses import dataclass +from typing import Union, Optional + +import torch +import torch.cuda.amp as amp +import torch.nn as nn +from megatron.core import parallel_state, tensor_parallel +from megatron.core.transformer.attention import ( + CrossAttention, + CrossAttentionSubmodules, + SelfAttention, + SelfAttentionSubmodules, +) +from megatron.core.transformer.custom_layers.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TERowParallelLinear, + TELinear, +) +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules +from megatron.core.utils import make_viewless_tensor +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.extensions.transformer_engine import TENorm + +try: + import transformer_engine # pylint: disable=unused-import + + HAVE_TE = True + from megatron.core.extensions.transformer_engine import SplitAlongDim + +except ImportError: + HAVE_TE = False + SplitAlongDim = None + + +class WanLayerNorm(nn.LayerNorm): + + def __init__(self, dim, eps=1e-6, elementwise_affine=False): + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return super().forward(x).type_as(x) + + +@dataclass +class WanSelfAttentionSubmodules: + """ + Configuration class for specifying the submodules of a self-attention. + """ + + linear_qkv: Union[ModuleSpec, type] = None + core_attention: Union[ModuleSpec, type] = None + linear_proj: Union[ModuleSpec, type] = None + layernorm_across_head: bool = False + q_layernorm: Union[ModuleSpec, type] = None + k_layernorm: Union[ModuleSpec, type] = None + + +@dataclass +class WanCrossAttentionSubmodules: + """ + Configuration class for specifying the submodules of a cross-attention. + """ + linear_q: Union[ModuleSpec, type] = None + linear_kv: Union[ModuleSpec, type] = None + core_attention: Union[ModuleSpec, type] = None + linear_proj: Union[ModuleSpec, type] = None + layernorm_across_head: bool = False + q_layernorm: Union[ModuleSpec, type] = None + k_layernorm: Union[ModuleSpec, type] = None + + +class WanSelfAttention(SelfAttention): + def __init__( + self, + config: TransformerConfig, + submodules: WanSelfAttentionSubmodules, + layer_number: int, + attn_mask_type: AttnMaskType, + cp_comm_type: str = None, + pg_collection: Optional[ProcessGroupCollection] = None, + ): + super().__init__( + config, + submodules, + layer_number, + attn_mask_type, + cp_comm_type, + pg_collection, + ) + + self.layernorm_across_head = submodules.layernorm_across_head + + # override q_layernorm + if submodules.q_layernorm is not None: + if self.layernorm_across_head: + q_layernorm_size = self.query_projection_size + else: + q_layernorm_size = self.hidden_size_per_attention_head + import transformer_engine as te + norm_config = copy.deepcopy(self.config) + norm_config.normalization = "RMSNorm" + self.q_layernorm = build_module( + submodules.q_layernorm, + eps=norm_config.layernorm_epsilon, + hidden_size=q_layernorm_size, + config=norm_config, + ) + else: + self.q_layernorm = None + + # override k_layernorm + if submodules.k_layernorm is not None: + if self.layernorm_across_head: + k_layernorm_size = self.kv_projection_size + else: + k_layernorm_size = self.hidden_size_per_attention_head + import transformer_engine as te + norm_config = copy.deepcopy(self.config) + norm_config.normalization = "RMSNorm" + self.k_layernorm = build_module( + submodules.k_layernorm, + eps=norm_config.layernorm_epsilon, + hidden_size=k_layernorm_size, + config=norm_config, + ) + else: + self.k_layernorm = None + + def get_query_key_value_tensors(self, hidden_states, key_value_states=None): + """ + Derives `query`, `key` and `value` tensors from `hidden_states`. + """ + # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)] + mixed_qkv, _ = self.linear_qkv(hidden_states) + + # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn] + new_tensor_shape = mixed_qkv.size()[:-1] + ( + self.num_query_groups_per_partition, + ( + (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2) + * self.hidden_size_per_attention_head + ), + ) + mixed_qkv = mixed_qkv.view(*new_tensor_shape) + + split_arg_list = [ + ( + self.num_attention_heads_per_partition + // self.num_query_groups_per_partition + * self.hidden_size_per_attention_head + ), + self.hidden_size_per_attention_head, + self.hidden_size_per_attention_head, + ] + + if SplitAlongDim is not None: + + # [sq, b, ng, (np/ng + 2) * hn] + # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] + (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list) + else: + + # [sq, b, ng, (np/ng + 2) * hn] + # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] + (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3) + + # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] + query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) + + # gather query and key heads across TP ranks if self.layernorm_across_head is True + if self.layernorm_across_head and parallel_state.get_tensor_model_parallel_world_size() > 1: + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = tensor_parallel.gather_from_tensor_model_parallel_region(query) + key = tensor_parallel.gather_from_tensor_model_parallel_region(key) + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + + if self.q_layernorm is not None: + if self.layernorm_across_head: + q_flat = query.reshape(query.size(0), query.size(1), -1).contiguous() # [sq, b, np*hn] + q_flat = self.q_layernorm(q_flat) + query = q_flat.view(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) # [sq, b, np, hn] + else: + query = self.q_layernorm(query.contiguous()) + + if self.k_layernorm is not None: + if self.layernorm_across_head: + k_flat = key.reshape(key.size(0), key.size(1), -1).contiguous() + k_flat = self.k_layernorm(k_flat) + key = k_flat.view(key.size(0), key.size(1), -1, self.hidden_size_per_attention_head) + else: + key = self.k_layernorm(key.contiguous()) + + # scatter query and key heads across TP ranks if self.layernorm_across_head is True + if self.layernorm_across_head and parallel_state.get_tensor_model_parallel_world_size() > 1: + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = tensor_parallel.scatter_to_tensor_model_parallel_region(query) + key = tensor_parallel.scatter_to_tensor_model_parallel_region(key) + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = query.contiguous() # important becuase TE attention expects contiguous tensors + key = key.contiguous() # important becuase TE attention expects contiguous tensors + + if self.config.test_mode: + self.run_realtime_tests() + + return query, key, value + + +class WanCrossAttention(CrossAttention): + def __init__( + self, + config: TransformerConfig, + submodules: WanCrossAttentionSubmodules, + layer_number: int, + attn_mask_type: AttnMaskType, + cp_comm_type: str = None, + pg_collection: Optional[ProcessGroupCollection] = None, + ): + super().__init__( + config, + submodules, + layer_number, + attn_mask_type, + cp_comm_type, + pg_collection, + ) + + self.layernorm_across_head = submodules.layernorm_across_head + + # override q_layernorm + if submodules.q_layernorm is not None: + if self.layernorm_across_head: + q_layernorm_size = self.query_projection_size + else: + q_layernorm_size = self.hidden_size_per_attention_head + import transformer_engine as te + norm_config = copy.deepcopy(self.config) + norm_config.normalization = "RMSNorm" + self.q_layernorm = build_module( + submodules.q_layernorm, + eps=norm_config.layernorm_epsilon, + hidden_size=q_layernorm_size, + config=norm_config, + ) + else: + self.q_layernorm = None + + # override k_layernorm + if submodules.k_layernorm is not None: + if self.layernorm_across_head: + k_layernorm_size = self.kv_projection_size + else: + k_layernorm_size = self.hidden_size_per_attention_head + import transformer_engine as te + norm_config = copy.deepcopy(self.config) + norm_config.normalization = "RMSNorm" + self.k_layernorm = build_module( + submodules.k_layernorm, + eps=norm_config.layernorm_epsilon, + hidden_size=k_layernorm_size, + config=norm_config, + ) + else: + self.k_layernorm = None + + def get_query_key_value_tensors(self, hidden_states, key_value_states): + """ + Derives `query` tensor from `hidden_states`, and `key`/`value` tensors + from `key_value_states`. + """ + # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] + mixed_kv, _ = self.linear_kv(key_value_states) + + # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] + new_tensor_shape = mixed_kv.size()[:-1] + ( + self.num_attention_heads_per_partition, + 2 * self.hidden_size_per_attention_head, + ) + mixed_kv = mixed_kv.view(*new_tensor_shape) + + # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] + (key, value) = tensor_parallel.split_tensor_along_last_dim(mixed_kv, 2) + + # Attention head [sq, b, h] --> [sq, b, hp] + query, _ = self.linear_q(hidden_states) + + # [sq, b, hp] --> [sq, b, np, hn] + new_tensor_shape = query.size()[:-1] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + query = query.view(*new_tensor_shape) + + # gather query and key heads across TP ranks if self.layernorm_across_head is True + if self.layernorm_across_head and parallel_state.get_tensor_model_parallel_world_size() > 1: + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = tensor_parallel.gather_from_tensor_model_parallel_region(query) + key = tensor_parallel.gather_from_tensor_model_parallel_region(key) + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + + if self.q_layernorm is not None: + if self.layernorm_across_head: + q_flat = query.reshape(query.size(0), query.size(1), -1).contiguous() # [sq, b, np*hn] + q_flat = self.q_layernorm(q_flat) + query = q_flat.view(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) # [sq, b, np, hn] + else: + query = self.q_layernorm(query.contiguous()) + + if self.k_layernorm is not None: + if self.layernorm_across_head: + k_flat = key.reshape(key.size(0), key.size(1), -1).contiguous() + k_flat = self.k_layernorm(k_flat) + key = k_flat.view(key.size(0), key.size(1), -1, self.hidden_size_per_attention_head) + else: + key = self.k_layernorm(key.contiguous()) + + # scatter query and key heads across TP ranks if self.layernorm_across_head is True + if self.layernorm_across_head and parallel_state.get_tensor_model_parallel_world_size() > 1: + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = tensor_parallel.scatter_to_tensor_model_parallel_region(query) + key = tensor_parallel.scatter_to_tensor_model_parallel_region(key) + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = query.contiguous() # important becuase TE attention expects contiguous tensors + key = key.contiguous() # important becuase TE attention expects contiguous tensors + + return query, key, value + + +@dataclass +class WanWithAdaLNSubmodules(TransformerLayerSubmodules): + temporal_self_attention: Union[ModuleSpec, type] = IdentityOp + full_self_attention: Union[ModuleSpec, type] = IdentityOp + norm1: Union[ModuleSpec, type] = None + norm3: Union[ModuleSpec, type] = None + norm2: Union[ModuleSpec, type] = None + context_proj: Union[ModuleSpec, type] = IdentityOp + + +# @dataclass +# class VACEContextLayerSubmodules(WanWithAdaLNSubmodules): + + + +class WanAdaLN(MegatronModule): + """ + Adaptive Layer Normalization Module for DiT. + """ + + def __init__( + self, config: TransformerConfig + ): + super().__init__(config) + # modulation + self.modulation = nn.Parameter(torch.randn(1, 6, config.hidden_size) / config.hidden_size**0.5) + + setattr(self.modulation, "sequence_parallel", config.sequence_parallel) + + def forward(self, timestep_emb): + e = (self.modulation + timestep_emb).chunk(6, dim=1) + return e + + # @jit_fuser + def modulate(self, x, shift, scale): + return x * (1 + scale) + shift + + # @jit_fuser + def scale_add(self, residual, x, gate): + return residual + gate * x + + +class WanLayerWithAdaLN(TransformerLayer): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + + DiT with Adapative Layer Normalization. + """ + + def __init__( + self, + config: TransformerConfig, + submodules: TransformerLayerSubmodules, + layer_number: int = 1, + hidden_dropout: float = None, + pg_collection: Optional[ProcessGroupCollection] = None, + vp_stage: Optional[int] = None, + ): + super().__init__( + config=config, submodules=submodules, layer_number=layer_number, hidden_dropout=hidden_dropout, pg_collection=pg_collection, vp_stage=vp_stage + ) + + # # TODO: Override Cross Attention to disable TP Comm overlap as well. ??? + # # Not disabling will attempt re-use of buffer size same as Q and lead to incorrect tensor shapes. + # cp_override_config = copy.deepcopy(config) + # cp_override_config.tp_comm_overlap = False + # self.cross_attention = build_module( + # submodules.cross_attention, + # config=cp_override_config, + # layer_number=layer_number, + # ) + + self.full_self_attention = build_module( + submodules.full_self_attention, + config=self.config, + layer_number=layer_number, + cp_comm_type=config.cp_comm_type, + pg_collection=pg_collection, + ) + + self.adaLN = WanAdaLN(config=self.config) + self.norm1 = build_module( + submodules.norm1, + dim=config.hidden_size, + eps=config.layernorm_epsilon, + elementwise_affine=False + ) + self.norm3 = build_module( + submodules.norm3, + dim=config.hidden_size, + eps=config.layernorm_epsilon, + elementwise_affine=True, + ) + self.norm2 = build_module( + submodules.norm2, + dim=config.hidden_size, + eps=config.layernorm_epsilon, + elementwise_affine=False, + ) + + + def forward( + self, + hidden_states, + attention_mask=None, + context=None, + context_mask=None, + rotary_pos_emb=None, + rotary_pos_cos=None, + rotary_pos_sin=None, + attention_bias=None, + inference_params=None, + packed_seq_params=None, + sequence_len_offset=None, + inference_context=None, + ): + + # log_checkpoint("before layer") + + # the timestep embedding is stored in attention_mask argument + timestep_emb = attention_mask + rope_emb = rotary_pos_emb + + shift_full, scale_full, gate_full, shift_mlp, scale_mlp, gate_mlp = self.adaLN(timestep_emb) + # transpose to bring it to [1, b, ...] format + shift_full = shift_full.transpose(0, 1) + scale_full = scale_full.transpose(0, 1) + gate_full = gate_full.transpose(0, 1) + shift_mlp = shift_mlp.transpose(0, 1) + scale_mlp = scale_mlp.transpose(0, 1) + gate_mlp = gate_mlp.transpose(0, 1) + + # ******************************************** full self attention ******************************************* + + # adaLN with scale + shift + gate + pre_full_attn_layernorm_output_ada = self.adaLN.modulate( + self.norm1(hidden_states), + shift=shift_full, + scale=scale_full, + ) + + attention_output, bias = self.full_self_attention( + pre_full_attn_layernorm_output_ada, + attention_mask=None, + rotary_pos_emb=rope_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + packed_seq_params=packed_seq_params['self_attention'], + ) + if bias is not None: + attention_output = attention_output + bias + + hidden_states = self.adaLN.scale_add(residual=hidden_states, x=attention_output, gate=gate_full) + + # ******************************************** cross attention ****************************************************** + + attention_output, bias = self.cross_attention( + self.norm3(hidden_states), + attention_mask=context_mask, + key_value_states=context, + packed_seq_params=packed_seq_params['cross_attention'], + ) + if bias is not None: + attention_output = attention_output + bias + + hidden_states = hidden_states + attention_output + + # ******************************************** mlp ****************************************************** + + pre_mlp_layernorm_output_ada = self.adaLN.modulate( + self.norm2(hidden_states), + shift=shift_mlp, + scale=scale_mlp, + ) + + mlp_output, bias = self.mlp(pre_mlp_layernorm_output_ada) + if bias is not None: + mlp_output = mlp_output + bias + + hidden_states = self.adaLN.scale_add(residual=hidden_states, x=mlp_output, gate=gate_mlp) + + # TODO: Jit compiled function creates 'view' tensor. This tensor + # potentially gets saved in the MPU checkpoint function context, + # which rejects view tensors. While making a viewless tensor here + # won't result in memory savings (like the data loader, or + # p2p_communication), it serves to document the origin of this + # 'view' tensor. ??? + output = make_viewless_tensor(inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True) + # output = hidden_states + + # log_checkpoint("after layer") + + return output, context + +def log_checkpoint(tag): + torch.cuda.synchronize() + alloc = torch.cuda.memory_allocated() / 1024**3 + reserved = torch.cuda.memory_reserved() / 1024**3 + print(f"[{tag}] alloc={alloc:.2f} GB reserved={reserved:.2f} GB") + +class VACEBaseLayer(WanLayerWithAdaLN): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + + DiT with Adapative Layer Normalization. + """ + + def __init__( + self, + config: TransformerConfig, + submodules: TransformerLayerSubmodules, + layer_number: int = 1, + hidden_dropout: float = None, + pg_collection: Optional[ProcessGroupCollection] = None, + vp_stage: Optional[int] = None, + ): + super().__init__( + config=config, submodules=submodules, layer_number=layer_number, hidden_dropout=hidden_dropout, pg_collection=pg_collection, vp_stage=vp_stage + ) + + + def forward( + self, + hidden_states, + attention_mask=None, + context=None, + context_mask=None, + context_signal=None, + rotary_pos_emb=None, + rotary_pos_cos=None, + rotary_pos_sin=None, + attention_bias=None, + inference_params=None, + packed_seq_params=None, + sequence_len_offset=None, + inference_context=None, + ): + + # log_checkpoint("before base") + + hidden_states, context = super().forward( + hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=None, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + inference_context=inference_context, + ) + # consider how to pass block id and context_scale + # the context_tokens from context branch is stored in context_signal argument + if self.idx is not None: + hidden_states = hidden_states + context_signal[self.idx] * self.context_scale + # hidden_states = hidden_states + context_signal[self.idx] * 2.0 + # hidden_states = hidden_states + torch.rand_like(context_signal[self.idx]) * 0.05 + + # log_checkpoint(f"after base {self.idx}") + + return hidden_states, context + + +class VACEContextLayer(WanLayerWithAdaLN): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + + DiT with Adapative Layer Normalization. + """ + + def __init__( + self, + config: TransformerConfig, + submodules: TransformerLayerSubmodules, + layer_number: int = 1, + hidden_dropout: float = None, + pg_collection: Optional[ProcessGroupCollection] = None, + vp_stage: Optional[int] = None, + ): + super().__init__( + config=config, submodules=submodules, layer_number=layer_number, hidden_dropout=hidden_dropout, pg_collection=pg_collection, vp_stage=vp_stage + ) + + # self.context_proj = build_module( + # submodules.context_proj, + # self.config.hidden_size, + # self.config.hidden_size, + # config=self.config, + # init_method=self.config.output_layer_init_method, + # bias=self.config.add_bias_linear, + # input_is_parallel=False, + # skip_bias_add=True, + # is_expert=False, + # tp_comm_buffer_name='proj', + # tp_group=self.pg_collection.tp, + # ) + self.context_proj = build_module( + submodules.context_proj, + self.config.hidden_size, + self.config.hidden_size, + parallel_mode="duplicated", + config=self.config, + init_method=self.config.output_layer_init_method, + bias=self.config.add_bias_linear, + skip_bias_add=False, + skip_weight_param_allocation=False, + is_expert=False, + symmetric_ar_type=self.config.symmetric_ar_type, + tp_comm_buffer_name='proj', + tp_group=None, + ) + + + def forward( + self, + hidden_states, + attention_mask=None, + context=None, + context_mask=None, + context_signal=None, + rotary_pos_emb=None, + rotary_pos_cos=None, + rotary_pos_sin=None, + attention_bias=None, + inference_params=None, + packed_seq_params=None, + sequence_len_offset=None, + inference_context=None, + ): + + # log_checkpoint("before context") + + hidden_states, context = super().forward( + hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=None, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + inference_context=inference_context, + ) + context_signal[self.idx] = self.context_proj(hidden_states)[0] + + # log_checkpoint("after context") + + return hidden_states, context_signal + + +import transformer_engine as te +def get_wan_block_with_transformer_engine_spec() -> ModuleSpec: + params = {"attn_mask_type": AttnMaskType.padding} + return ModuleSpec( + module=WanLayerWithAdaLN, + submodules=WanWithAdaLNSubmodules( + norm1=WanLayerNorm, + norm3=WanLayerNorm, + norm2=WanLayerNorm, + full_self_attention=ModuleSpec( + module=WanSelfAttention, + params=params, + submodules=WanSelfAttentionSubmodules( + linear_qkv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + layernorm_across_head=True, + q_layernorm=TENorm, + k_layernorm=TENorm, + ), + ), + cross_attention=ModuleSpec( + module=WanCrossAttention, + params=params, + submodules=WanCrossAttentionSubmodules( + linear_q=TEColumnParallelLinear, + linear_kv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + layernorm_across_head=True, + q_layernorm=TENorm, + k_layernorm=TENorm, + ), + ), + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear, + # by default, activation_func is openai_gelu, which is equivalent to nn.GELU(approximate='tanh') + linear_fc2=TERowParallelLinear, + ), + ), + ), + ) + + +def get_vace_base_block_with_transformer_engine_spec() -> ModuleSpec: + params = {"attn_mask_type": AttnMaskType.padding} + return ModuleSpec( + module=VACEBaseLayer, + submodules=WanWithAdaLNSubmodules( + norm1=WanLayerNorm, + norm3=WanLayerNorm, + norm2=WanLayerNorm, + full_self_attention=ModuleSpec( + module=WanSelfAttention, + params=params, + submodules=WanSelfAttentionSubmodules( + linear_qkv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + layernorm_across_head=True, + q_layernorm=TENorm, + k_layernorm=TENorm, + ), + ), + cross_attention=ModuleSpec( + module=WanCrossAttention, + params=params, + submodules=WanCrossAttentionSubmodules( + linear_q=TEColumnParallelLinear, + linear_kv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + layernorm_across_head=True, + q_layernorm=TENorm, + k_layernorm=TENorm, + ), + ), + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear, + # by default, activation_func is openai_gelu, which is equivalent to nn.GELU(approximate='tanh') + linear_fc2=TERowParallelLinear, + ), + ), + ), + ) + + +def get_vace_context_block_with_transformer_engine_spec() -> ModuleSpec: + params = {"attn_mask_type": AttnMaskType.padding} + return ModuleSpec( + module=VACEContextLayer, + submodules=WanWithAdaLNSubmodules( + norm1=WanLayerNorm, + norm3=WanLayerNorm, + norm2=WanLayerNorm, + full_self_attention=ModuleSpec( + module=WanSelfAttention, + params=params, + submodules=WanSelfAttentionSubmodules( + linear_qkv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + layernorm_across_head=True, + q_layernorm=TENorm, + k_layernorm=TENorm, + ), + ), + cross_attention=ModuleSpec( + module=WanCrossAttention, + params=params, + submodules=WanCrossAttentionSubmodules( + linear_q=TEColumnParallelLinear, + linear_kv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + layernorm_across_head=True, + q_layernorm=TENorm, + k_layernorm=TENorm, + ), + ), + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear, + # by default, activation_func is openai_gelu, which is equivalent to nn.GELU(approximate='tanh') + linear_fc2=TERowParallelLinear, + ), + ), + context_proj=TELinear + ), + ) + diff --git a/src/megatron/bridge/models/wan/wan_model.py b/src/megatron/bridge/models/wan/wan_model.py new file mode 100644 index 0000000000..9bec5c85b3 --- /dev/null +++ b/src/megatron/bridge/models/wan/wan_model.py @@ -0,0 +1,1356 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +from typing import Dict, Literal, Optional, Tuple, List, Union +import copy + +import math +import torch +import torch.cuda.amp as amp +import torch.nn as nn +from megatron.core import parallel_state, tensor_parallel +from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core.models.common.vision_module.vision_module import VisionModule +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.transformer_block import TransformerBlock, TransformerBlockSubmodules +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import make_sharded_tensor_for_checkpoint +from megatron.bridge.models.wan.wan_layer_spec import ( + get_wan_block_with_transformer_engine_spec as WanLayerWithAdaLNspec, + get_vace_base_block_with_transformer_engine_spec as VACEBaseLayerspec, + get_vace_context_block_with_transformer_engine_spec as VACEContextLayerspec, +) +from megatron.bridge.models.wan.wan_layer_spec import WanLayerNorm +from torch import Tensor +from .rope_utils import Wan3DRopeEmbeddings + +from contextlib import nullcontext +from megatron.core.fp4_utils import get_fp4_context +from megatron.core.fp8_utils import get_fp8_context +from megatron.core.enums import Fp8Recipe +from megatron.core.fusions.fused_layer_norm import FusedLayerNorm +from megatron.core.inference.contexts import BaseInferenceContext +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_layer import get_transformer_layer_offset +from megatron.core.utils import ( + WrappedTensor, + deprecate_inference_params, + get_pg_rank, + make_viewless_tensor, +) + +try: + import transformer_engine.pytorch as te # pylint: disable=unused-import + + HAVE_TE = True +except ImportError: + HAVE_TE = False + +try: + import apex # pylint: disable=unused-import + + HAVE_APEX = True +except ImportError: + HAVE_APEX = False + +get_cpu_offload_context = None +te_checkpoint = None + +if HAVE_TE: + from megatron.core.extensions.transformer_engine import ( + TENorm, + get_cpu_offload_context, + te_checkpoint, + ) + + LayerNormImpl = TENorm + +elif HAVE_APEX: + LayerNormImpl = FusedLayerNorm + +else: + from megatron.core.transformer.torch_norm import WrappedTorchNorm + + LayerNormImpl = WrappedTorchNorm + +class BaseTransformerBlock(TransformerBlock): + def __init__( + self, + config: TransformerConfig, + spec: Union[TransformerBlockSubmodules, ModuleSpec], + post_layer_norm: bool = True, + pre_process: bool = True, + post_process: bool = True, + pg_collection: ProcessGroupCollection = None, + vp_stage: Optional[int] = None, + ): + # Pass block id and context_scale + self.vace_layers = [i for i in range(0, config.num_layers, 2)] if config.vace_layers is None else config.vace_layers + print(self.vace_layers) + assert 0 in self.vace_layers + self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)} + + super().__init__( + config=config, + spec=spec, + post_layer_norm=post_layer_norm, + pre_process=pre_process, + post_process=post_process, + pg_collection=pg_collection, + vp_stage=vp_stage, + ) + + def _build_layers(self): + # Transformer layers. + # @jcasper can we improve how we deal with layer_number? + # currently it's only used in CoreAttention? + # if self.apply_query_key_layer_scaling: + # coeff = self.layer_number + # self.norm_factor *= coeff + def build_layer(layer_spec, layer_number): + global_layer_number = layer_number + get_transformer_layer_offset( + self.config, self.vp_stage, get_pg_rank(self.pg_collection.pp) + ) # 1-based index + if self.config.heterogeneous_block_specs: + layer_config = self.config.get_config_for_layer(global_layer_number) + else: + layer_config = self.config + + # Get appropriate quantization context (FP8 and FP4 are mutually exclusive) + if layer_config.fp8: + quantization_context = get_fp8_context( + layer_config, global_layer_number - 1, is_init=True + ) + elif layer_config.fp4: + quantization_context = get_fp4_context( + layer_config, global_layer_number - 1, is_init=True + ) + else: + quantization_context = nullcontext() + + with quantization_context: + module = build_module( + layer_spec, + config=layer_config, + layer_number=layer_number, + pg_collection=self.pg_collection, + vp_stage=self.vp_stage, + ) + idx = global_layer_number - 1 + if idx in self.vace_layers: + module.idx = self.vace_layers_mapping[idx] + module.context_scale = self.config.context_scale + else: + module.idx = None + return module + + # offset is implicit in TransformerLayer + self.layers = torch.nn.ModuleList( + [ + build_layer(layer_spec, i + 1) + for i, layer_spec in enumerate(self.submodules.layer_specs) + ] + ) + + # @TODO: add back account_for_embedding_in_pipeline_split (see issue #293) + # In pipeline parallelism, we want to add this LN only to the last stage of the pipeline + # self.post_process and self.post_layer_norm guide this behavior + if self.submodules.layer_norm and self.post_process and self.post_layer_norm: + self.final_layernorm = build_module( + self.submodules.layer_norm, + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + else: + self.final_layernorm = None # Either this or nn.Identity + + def _checkpointed_forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + context: Tensor, + context_mask: Tensor, + context_signal: Tensor, + rotary_pos_emb: Tensor, + attention_bias: Tensor, + packed_seq_params: PackedSeqParams, + use_inner_quantization_context: bool, + ): + """Forward method with activation checkpointing.""" + + def custom(start: int, end: int): + def custom_forward( + hidden_states, attention_mask, context, context_mask, context_signal, rotary_pos_emb + ): + for index in range(start, end): + layer = self._get_layer(index) + + # Get appropriate inner quantization context + if use_inner_quantization_context: + if self.config.fp8: + inner_quantization_context = get_fp8_context( + self.config, layer.layer_number - 1 + ) + # TODO: check if fp4 is supported in this case + elif self.config.fp4: + inner_quantization_context = get_fp4_context( + self.config, layer.layer_number - 1 + ) + else: + inner_quantization_context = nullcontext() + else: + inner_quantization_context = nullcontext() + + with inner_quantization_context: + hidden_states, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + context_signal=context_signal, + rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, + inference_context=None, + packed_seq_params=packed_seq_params, + ) + return hidden_states, context + + return custom_forward + + def checkpoint_handler(forward_func): + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" + # TODO: check if fp4 is supported in this case + if self.config.fp8 or self.config.fp4: + return te_checkpoint( + forward_func, + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + self.pg_collection.tp, + hidden_states, + attention_mask, + context, + context_mask, + context_signal, + rotary_pos_emb, + ) + else: + return tensor_parallel.checkpoint( + forward_func, + self.config.distribute_saved_activations, + hidden_states, + attention_mask, + context, + context_mask, + context_signal, + rotary_pos_emb, + ) + + if self.config.recompute_method == 'uniform': + # Uniformly divide the total number of Transformer layers and checkpoint + # the input activation of each divided chunk. + # A method to further reduce memory usage reducing checkpoints. + layer_idx = 0 + while layer_idx < self.num_layers_per_pipeline_rank: + hidden_states, context = checkpoint_handler( + custom(layer_idx, layer_idx + self.config.recompute_num_layers) + ) + + layer_idx += self.config.recompute_num_layers + + elif self.config.recompute_method == 'block': + # Checkpoint the input activation of only a set number of individual + # Transformer layers and skip the rest. + # A method fully use the device memory removing redundant re-computation. + recompute_skip_num_layers = 0 + for layer_idx in range(self.num_layers_per_pipeline_rank): + # Skip recomputation when input grad computation is not needed. + # Need to have at least one input tensor with gradient computation + # for re-enterant autograd engine. + # TODO: check if fp4 is supported in this case + if (self.config.fp8 or self.config.fp4) and not hidden_states.requires_grad: + recompute_skip_num_layers += 1 + if ( + layer_idx >= recompute_skip_num_layers + and layer_idx < self.config.recompute_num_layers + recompute_skip_num_layers + ): + hidden_states, context = checkpoint_handler(custom(layer_idx, layer_idx + 1)) + else: + hidden_states, context = custom(layer_idx, layer_idx + 1)( + hidden_states, attention_mask, context, context_mask, context_signal, rotary_pos_emb + ) + else: + raise ValueError("Invalid activation recompute method.") + + return hidden_states + + def forward( + self, + hidden_states: Union[Tensor, WrappedTensor], + attention_mask: Optional[Tensor], + context: Optional[Tensor] = None, + context_mask: Optional[Tensor] = None, + context_signal: Optional[Tensor] = None, + rotary_pos_emb: Optional[Tensor] = None, + rotary_pos_cos: Optional[Tensor] = None, + rotary_pos_sin: Optional[Tensor] = None, + rotary_pos_cos_sin: Optional[Tensor] = None, + attention_bias: Optional[Tensor] = None, + inference_context: Optional[BaseInferenceContext] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[Tensor] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + dynamic_inference_decode_only: Optional[bool] = None, + ): + """ + Perform the forward pass through the transformer block. + + This method handles the core computation of the transformer, including + self-attention, optional cross-attention, and feed-forward operations. + + Args: + hidden_states (Union[Tensor, WrappedTensor]): Input tensor of shape [s, b, h] + where s is the sequence length, b is the batch size, and h is the hidden size. + Can be passed as a WrappedTensor during inference to avoid an obsolete + reference in the calling function. + attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking + self-attention. + context (Tensor, optional): Context tensor for cross-attention. + context_mask (Tensor, optional): Mask for cross-attention context + context_signal (Tensor, optional): Signal from context tokens + rotary_pos_emb (Tensor, optional): Rotary positional embeddings. + rotary_pos_cos (Optional[Tensor]): Rotary embedding cosine. + rotary_pos_sin (Optional[Tensor]): Rotary embedding sine. + rotary_pos_cos_sin (Optional[Tensor]): Combined rotary embedding cosine and sine. + Currently used exclusively for inference with dynamic batching and flashinfer RoPE. + attention_bias (Tensor): Bias tensor for Q * K.T of shape in shape broadcastable + to [b, num_head, sq, skv], e.g. [1, 1, sq, skv]. + Used as an alternative to apply attention mask for TE cuDNN attention. + inference_context (BaseInferenceContext, optional): Parameters for inference-time + optimizations. + packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence + processing. + dynamic_inference_decode_only: Optional[bool]: If true, indicates that the current + inference context is for decode-only. This args is only used to uniquely + identify decode and non-decode cuda graph runners in the cuda graph manager. + + Returns: + Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape + [s, b, h], and optionally the updated context tensor if cross-attention is used. + """ + + inference_context = deprecate_inference_params(inference_context, inference_params) + # Remove 'dynamic_inference_decode_only' from kwargs if present + # this is only used to uniquely identify decode and non-decode cuda graph + # runners in the cuda graph manager + + # Delete the obsolete reference to the initial input tensor if necessary + if isinstance(hidden_states, WrappedTensor): + hidden_states = hidden_states.unwrap() + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Viewless tensor. + # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + + if self.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + # If fp8_recipe is delayed, wrap the entire pass with get_fp8_context(), + # otherwise do nothing extra at the outer level + # if we are using other fp8 recipes, then the context manager enter&exit are free + # we can wrap fp8_context within the for loop over layers, so that we can fine-grained + # control which layer will be fp8 or bf16 + # For FP4: NVFP4BlockScaling doesn't have delayed scaling, always uses inner context + if self.config.fp8: + use_outer_quantization_context = self.config.fp8_recipe == Fp8Recipe.delayed + use_inner_quantization_context = self.config.fp8_recipe != Fp8Recipe.delayed + outer_quantization_context = ( + get_fp8_context(self.config) if use_outer_quantization_context else nullcontext() + ) + elif self.config.fp4: + use_outer_quantization_context = False + use_inner_quantization_context = True + outer_quantization_context = nullcontext() + else: + # No quantization + use_outer_quantization_context = False + use_inner_quantization_context = False + outer_quantization_context = nullcontext() + + with rng_context, outer_quantization_context: + # Forward pass. + if self.config.recompute_granularity == 'full' and self.training: + hidden_states = self._checkpointed_forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + context_signal=context_signal, + rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + use_inner_quantization_context=use_inner_quantization_context, + ) + else: + for l_no, layer in enumerate(self.layers): + # Get appropriate inner quantization context + if use_inner_quantization_context: + if self.config.fp8: + inner_quantization_context = get_fp8_context( + self.config, layer.layer_number - 1 + ) + elif self.config.fp4: + inner_quantization_context = get_fp4_context( + self.config, layer.layer_number - 1 + ) + else: + inner_quantization_context = nullcontext() + else: + inner_quantization_context = nullcontext() + + with self.offload_context, inner_quantization_context: + hidden_states, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + context_signal=context_signal, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + # rotary_pos_cos_sin=rotary_pos_cos_sin, + attention_bias=attention_bias, + inference_context=inference_context, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + ) + + if ( + torch.is_grad_enabled() + and self.config.cpu_offloading + and self.group_prefetch_offload_commit_async is not None + ): + hidden_states = self.group_prefetch_offload_commit_async(hidden_states) + + # Final layer norm. + if self.final_layernorm is not None: + hidden_states = self.final_layernorm(hidden_states) + # TENorm produces a "viewed" tensor. This will result in schedule.py's + # deallocate_output_tensor() throwing an error, so a viewless tensor is + # created to prevent this. + hidden_states = make_viewless_tensor( + inp=hidden_states, requires_grad=True, keep_graph=True + ) + + # If this TransformerBlock is empty, input and output hidden states will be the same node + # on the computational graph and will lead to unexpected errors in pipeline schedules. + if not self.pre_process and len(self.layers) == 0 and not self.final_layernorm: + hidden_states = hidden_states.clone() + + return hidden_states + +class ContextTransformerBlock(TransformerBlock): + def __init__( + self, + config: TransformerConfig, + spec: Union[TransformerBlockSubmodules, ModuleSpec], + post_layer_norm: bool = True, + pre_process: bool = True, + post_process: bool = True, + pg_collection: ProcessGroupCollection = None, + vp_stage: Optional[int] = None, + ): + # Pass block id and context_scale + self.vace_id = [i for i in range(0, config.num_layers)] if config.vace_layers is None else [i for i in range(0, len(config.vace_layers))] + print(self.vace_id) + assert 0 in self.vace_id + + super().__init__( + config=config, + spec=spec, + post_layer_norm=post_layer_norm, + pre_process=pre_process, + post_process=post_process, + pg_collection=pg_collection, + vp_stage=vp_stage, + ) + + def _build_layers(self): + # Transformer layers. + # @jcasper can we improve how we deal with layer_number? + # currently it's only used in CoreAttention? + # if self.apply_query_key_layer_scaling: + # coeff = self.layer_number + # self.norm_factor *= coeff + def build_layer(layer_spec, layer_number): + global_layer_number = layer_number + get_transformer_layer_offset( + self.config, self.vp_stage, get_pg_rank(self.pg_collection.pp) + ) # 1-based index + if self.config.heterogeneous_block_specs: + layer_config = self.config.get_config_for_layer(global_layer_number) + else: + layer_config = self.config + + # Get appropriate quantization context (FP8 and FP4 are mutually exclusive) + if layer_config.fp8: + quantization_context = get_fp8_context( + layer_config, global_layer_number - 1, is_init=True + ) + elif layer_config.fp4: + quantization_context = get_fp4_context( + layer_config, global_layer_number - 1, is_init=True + ) + else: + quantization_context = nullcontext() + + with quantization_context: + module = build_module( + layer_spec, + config=layer_config, + layer_number=layer_number, + pg_collection=self.pg_collection, + vp_stage=self.vp_stage, + ) + idx = global_layer_number - 1 + if idx in self.vace_id: + module.idx = idx + else: + module.idx = None + return module + + # offset is implicit in TransformerLayer + self.layers = torch.nn.ModuleList( + [ + build_layer(layer_spec, i + 1) + for i, layer_spec in enumerate(self.submodules.layer_specs) + ] + ) + + # @TODO: add back account_for_embedding_in_pipeline_split (see issue #293) + # In pipeline parallelism, we want to add this LN only to the last stage of the pipeline + # self.post_process and self.post_layer_norm guide this behavior + if self.submodules.layer_norm and self.post_process and self.post_layer_norm: + self.final_layernorm = build_module( + self.submodules.layer_norm, + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + else: + self.final_layernorm = None # Either this or nn.Identity + + def _checkpointed_forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + context: Tensor, + context_mask: Tensor, + context_signal: Tensor, + rotary_pos_emb: Tensor, + attention_bias: Tensor, + packed_seq_params: PackedSeqParams, + use_inner_quantization_context: bool, + ): + """Forward method with activation checkpointing.""" + + def custom(start: int, end: int): + def custom_forward( + hidden_states, attention_mask, context, context_mask, context_signal, rotary_pos_emb + ): + for index in range(start, end): + layer = self._get_layer(index) + + # Get appropriate inner quantization context + if use_inner_quantization_context: + if self.config.fp8: + inner_quantization_context = get_fp8_context( + self.config, layer.layer_number - 1 + ) + # TODO: check if fp4 is supported in this case + elif self.config.fp4: + inner_quantization_context = get_fp4_context( + self.config, layer.layer_number - 1 + ) + else: + inner_quantization_context = nullcontext() + else: + inner_quantization_context = nullcontext() + + with inner_quantization_context: + hidden_states, context_signal = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + context_signal=context_signal, + rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, + inference_context=None, + packed_seq_params=packed_seq_params, + ) + return hidden_states, context_signal + + return custom_forward + + def checkpoint_handler(forward_func): + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" + # TODO: check if fp4 is supported in this case + if self.config.fp8 or self.config.fp4: + return te_checkpoint( + forward_func, + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + self.pg_collection.tp, + hidden_states, + attention_mask, + context, + context_mask, + context_signal, + rotary_pos_emb, + ) + else: + return tensor_parallel.checkpoint( + forward_func, + self.config.distribute_saved_activations, + hidden_states, + attention_mask, + context, + context_mask, + context_signal, + rotary_pos_emb, + ) + + if self.config.recompute_method == 'uniform': + # Uniformly divide the total number of Transformer layers and checkpoint + # the input activation of each divided chunk. + # A method to further reduce memory usage reducing checkpoints. + layer_idx = 0 + while layer_idx < self.num_layers_per_pipeline_rank: + hidden_states, context_signal = checkpoint_handler( + custom(layer_idx, layer_idx + self.config.recompute_num_layers) + ) + + layer_idx += self.config.recompute_num_layers + + elif self.config.recompute_method == 'block': + # Checkpoint the input activation of only a set number of individual + # Transformer layers and skip the rest. + # A method fully use the device memory removing redundant re-computation. + recompute_skip_num_layers = 0 + for layer_idx in range(self.num_layers_per_pipeline_rank): + # Skip recomputation when input grad computation is not needed. + # Need to have at least one input tensor with gradient computation + # for re-enterant autograd engine. + # TODO: check if fp4 is supported in this case + if (self.config.fp8 or self.config.fp4) and not hidden_states.requires_grad: + recompute_skip_num_layers += 1 + if ( + layer_idx >= recompute_skip_num_layers + and layer_idx < self.config.recompute_num_layers + recompute_skip_num_layers + ): + hidden_states, context_signal = checkpoint_handler(custom(layer_idx, layer_idx + 1)) + else: + hidden_states, context_signal = custom(layer_idx, layer_idx + 1)( + hidden_states, attention_mask, context, context_mask, context_signal, rotary_pos_emb + ) + else: + raise ValueError("Invalid activation recompute method.") + + return hidden_states, context_signal + + def forward( + self, + hidden_states: Union[Tensor, WrappedTensor], + attention_mask: Optional[Tensor], + context: Optional[Tensor] = None, + context_mask: Optional[Tensor] = None, + context_signal: Optional[Tensor] = None, + rotary_pos_emb: Optional[Tensor] = None, + rotary_pos_cos: Optional[Tensor] = None, + rotary_pos_sin: Optional[Tensor] = None, + rotary_pos_cos_sin: Optional[Tensor] = None, + attention_bias: Optional[Tensor] = None, + inference_context: Optional[BaseInferenceContext] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[Tensor] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + dynamic_inference_decode_only: Optional[bool] = None, + ): + """ + Perform the forward pass through the transformer block. + + This method handles the core computation of the transformer, including + self-attention, optional cross-attention, and feed-forward operations. + + Args: + hidden_states (Union[Tensor, WrappedTensor]): Input tensor of shape [s, b, h] + where s is the sequence length, b is the batch size, and h is the hidden size. + Can be passed as a WrappedTensor during inference to avoid an obsolete + reference in the calling function. + attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking + self-attention. + context (Tensor, optional): Context tensor for cross-attention. + context_mask (Tensor, optional): Mask for cross-attention context + context_signal (Tensor, optional): Signal from context tokens + rotary_pos_emb (Tensor, optional): Rotary positional embeddings. + rotary_pos_cos (Optional[Tensor]): Rotary embedding cosine. + rotary_pos_sin (Optional[Tensor]): Rotary embedding sine. + rotary_pos_cos_sin (Optional[Tensor]): Combined rotary embedding cosine and sine. + Currently used exclusively for inference with dynamic batching and flashinfer RoPE. + attention_bias (Tensor): Bias tensor for Q * K.T of shape in shape broadcastable + to [b, num_head, sq, skv], e.g. [1, 1, sq, skv]. + Used as an alternative to apply attention mask for TE cuDNN attention. + inference_context (BaseInferenceContext, optional): Parameters for inference-time + optimizations. + packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence + processing. + dynamic_inference_decode_only: Optional[bool]: If true, indicates that the current + inference context is for decode-only. This args is only used to uniquely + identify decode and non-decode cuda graph runners in the cuda graph manager. + + Returns: + Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape + [s, b, h], and optionally the updated context tensor if cross-attention is used. + """ + + inference_context = deprecate_inference_params(inference_context, inference_params) + # Remove 'dynamic_inference_decode_only' from kwargs if present + # this is only used to uniquely identify decode and non-decode cuda graph + # runners in the cuda graph manager + + # Delete the obsolete reference to the initial input tensor if necessary + if isinstance(hidden_states, WrappedTensor): + hidden_states = hidden_states.unwrap() + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Viewless tensor. + # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + + if self.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + # If fp8_recipe is delayed, wrap the entire pass with get_fp8_context(), + # otherwise do nothing extra at the outer level + # if we are using other fp8 recipes, then the context manager enter&exit are free + # we can wrap fp8_context within the for loop over layers, so that we can fine-grained + # control which layer will be fp8 or bf16 + # For FP4: NVFP4BlockScaling doesn't have delayed scaling, always uses inner context + if self.config.fp8: + use_outer_quantization_context = self.config.fp8_recipe == Fp8Recipe.delayed + use_inner_quantization_context = self.config.fp8_recipe != Fp8Recipe.delayed + outer_quantization_context = ( + get_fp8_context(self.config) if use_outer_quantization_context else nullcontext() + ) + elif self.config.fp4: + use_outer_quantization_context = False + use_inner_quantization_context = True + outer_quantization_context = nullcontext() + else: + # No quantization + use_outer_quantization_context = False + use_inner_quantization_context = False + outer_quantization_context = nullcontext() + + with rng_context, outer_quantization_context: + # Forward pass. + if self.config.recompute_granularity == 'full' and self.training: + hidden_states, context_signal = self._checkpointed_forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + context_signal=context_signal, + rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + use_inner_quantization_context=use_inner_quantization_context, + ) + else: + for l_no, layer in enumerate(self.layers): + # Get appropriate inner quantization context + if use_inner_quantization_context: + if self.config.fp8: + inner_quantization_context = get_fp8_context( + self.config, layer.layer_number - 1 + ) + elif self.config.fp4: + inner_quantization_context = get_fp4_context( + self.config, layer.layer_number - 1 + ) + else: + inner_quantization_context = nullcontext() + else: + inner_quantization_context = nullcontext() + + with self.offload_context, inner_quantization_context: + hidden_states, context_signal = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + context_signal=context_signal, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + # rotary_pos_cos_sin=rotary_pos_cos_sin, + attention_bias=attention_bias, + inference_context=inference_context, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + ) + + if ( + torch.is_grad_enabled() + and self.config.cpu_offloading + and self.group_prefetch_offload_commit_async is not None + ): + hidden_states = self.group_prefetch_offload_commit_async(hidden_states) + + # Final layer norm. + if self.final_layernorm is not None: + hidden_states = self.final_layernorm(hidden_states) + # TENorm produces a "viewed" tensor. This will result in schedule.py's + # deallocate_output_tensor() throwing an error, so a viewless tensor is + # created to prevent this. + hidden_states = make_viewless_tensor( + inp=hidden_states, requires_grad=True, keep_graph=True + ) + + # If this TransformerBlock is empty, input and output hidden states will be the same node + # on the computational graph and will lead to unexpected errors in pipeline schedules. + if not self.pre_process and len(self.layers) == 0 and not self.final_layernorm: + hidden_states = hidden_states.clone() + + return hidden_states, context_signal + +def sinusoidal_embedding_1d(dim, position): + # preprocess + assert dim % 2 == 0 + half = dim // 2 + position = position + + # calculation + sinusoid = torch.outer( + position, torch.pow(10000, -torch.arange(half).to(position).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x + +class Head(nn.Module): + + def __init__(self, dim, out_dim, patch_size, eps=1e-6): + super().__init__() + self.dim = dim + self.out_dim = out_dim + self.patch_size = patch_size + self.eps = eps + + # layers + out_dim = math.prod(patch_size) * out_dim + self.norm = WanLayerNorm(dim, eps) + self.head = nn.Linear(dim, out_dim) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, e): + r""" + Args: + x(Tensor): Shape [B, L1, C] + e(Tensor): Shape [B, C] + """ + e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) + x = (self.head(self.norm(x) * (1 + e[1]) + e[0])) + return x + + +class WanModel(VisionModule): + """ + WanModel is a VisionModule that implements a Wan model. + Attributes: + config (TransformerConfig): Configuration for the transformer. + pre_process (bool): Whether to apply pre-processing steps. + post_process (bool): Whether to apply post-processing steps. + fp16_lm_cross_entropy (bool): Whether to use fp16 for cross-entropy loss. + parallel_output (bool): Whether to use parallel output. + transformer_decoder_layer_spec (WanLayerWithAdaLNspec): Specification for the transformer decoder layer. + model_type (ModelType): Type of the model. + """ + + def __init__( + self, + config: TransformerConfig, + pre_process: bool = True, + post_process: bool = True, + fp16_lm_cross_entropy: bool = False, + parallel_output: bool = True, + transformer_decoder_layer_spec=WanLayerWithAdaLNspec, + **kwargs, + ): + super(WanModel, self).__init__(config=config) + + self.config: TransformerConfig = config + + self.transformer_decoder_layer_spec = transformer_decoder_layer_spec() + self.pre_process = pre_process + self.post_process = post_process + self.fp16_lm_cross_entropy = fp16_lm_cross_entropy + self.parallel_output = parallel_output + + # megatron core pipelining currently depends on model type + # TODO: remove this dependency ? + self.model_type = ModelType.encoder_or_decoder + + self.num_heads = self.config.num_attention_heads + self.freq_dim = self.config.freq_dim + self.in_channels = self.config.in_channels + self.out_channels = self.config.out_channels + self.patch_spatial = self.config.patch_spatial + self.patch_temporal = self.config.patch_temporal + self.patch_size = (self.patch_temporal, self.patch_spatial, self.patch_spatial) + + # these attributes are unused for images/videos, we just set because bridge training requires for LLMs + self.share_embeddings_and_output_weights = False + + ###################################### + ########## Wan architecture ########## + + # embeddings + if self.pre_process: + self.patch_embedding = nn.Conv3d( + self.in_channels, self.config.hidden_size, kernel_size=self.patch_size, stride=self.patch_size) + + self.text_embedding = nn.Sequential( + nn.Linear(self.config.text_dim, self.config.hidden_size), nn.GELU(approximate='tanh'), + nn.Linear(self.config.hidden_size, self.config.hidden_size)) + + self.time_embedding = nn.Sequential( + nn.Linear(self.freq_dim, self.config.hidden_size), nn.SiLU(), nn.Linear(self.config.hidden_size, self.config.hidden_size)) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(self.config.hidden_size, self.config.hidden_size * 6)) + + self.rope_embeddings = Wan3DRopeEmbeddings(dim_head = self.config.hidden_size // self.num_heads, max_position_len = 1024) + + # decoder blocks + self.decoder = TransformerBlock( + config=self.config, + spec=self.transformer_decoder_layer_spec, + pre_process=self.pre_process, + post_process=self.post_process, + post_layer_norm=False, + ) + + # output head + if self.post_process: + self.head = Head(self.config.hidden_size, self.out_channels, self.patch_size, eps = 1e-6) + + + def forward( + self, + x: Tensor, + grid_sizes: list[Tuple[int, int, int]], + t: Tensor, + context: Tensor, + max_seq_len: int, + packed_seq_params: PackedSeqParams = None, + **kwargs, + ) -> Tensor: + """Forward pass. + + Args: + x List[Tensor]: list of vae encoded data (s, b, c * pF * pH * pW) + grid_sizes List[Tuple[int, int, int]]: list of grid sizes (f, h, w) + t Tensor: timesteps + context List[Tensor]: list of context (text_len, hidden_size) + max_seq_len int: maximum sequence length + packed_seq_params PackedSeqParams: packed sequence parameters + + Returns: + Tensor: output tensor (still patchified) of shape [seq_len, batch_size, hidden_size] + """ + ################################# + ########## Wan forward ########## + + # ============= embedders ============= + + # run input embedding + if self.pre_process: + # x.shape [s, b, c * pF * pH * pW] + seq_len, batch_size, _ = x.shape + c = self.in_channels + pF, pH, pW = self.patch_size + x = x.reshape(seq_len * batch_size, pF, pH, pW, c) # output: x.shape [s * b, pF, pH, pW, c] + x = x.permute(0, 4, 1, 2, 3) # output: x.shape [s * b, c, pF, pH, pW] + x = self.patch_embedding(x) # output: x.shape [s * b, hidden_size, 1, 1, 1] + x = x.flatten(1) # output: x.shape [s * b, hidden_size] + x = x.reshape(seq_len, batch_size, -1) # output: x.shape [s, b, hidden_size] + + # split sequence for sequence_parallel + # TODO: for PP, do we move scatter_to_sequence_parallel_region here or after "x = self.decoder.input_tensor" ??? + if self.config.sequence_parallel: + x = tensor_parallel.scatter_to_sequence_parallel_region(x) # output: x.shape [s * b // tp_size, hidden_size] + + else: + # intermediate stage of pipeline + x = self.decoder.input_tensor + + # time embeddings + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).to(x.dtype) + ) + e0 = self.time_projection(e).unflatten(1, (6, self.config.hidden_size)) + + # context embeddings + context = self.text_embedding(context) # shape [text_len, b, hidden_size] + + + # ============= decoder ============= + # calculate rotary pos emb + n_head, dim_head = self.num_heads, self.config.hidden_size // self.num_heads + rotary_pos_emb = self.rope_embeddings(n_head, dim_head, max_seq_len, grid_sizes, t.device) # output: rotary_pos_emb.shape [s, b, 1, dim_head] + + # run decoder + x = self.decoder( + hidden_states=x, + attention_mask=e0, + context=context, + context_mask=None, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=None, + rotary_pos_sin=None, + packed_seq_params=packed_seq_params, + ) + + # return if not post_process + if not self.post_process: + return x + + # head + x = x.transpose(0, 1) # head expects shape [b, s, hidden_size] + x = self.head(x, e) # output: x.shape [b, s, c * pF * pH * pW] + x = x.transpose(0, 1) # reshape back to shape [s, b, c * pF * pH * pW] + + # gather outputs for sequence_parallel + # Note: in GPT models, because the vocab projection matrix is ColumnParallelLinear, the sequence is + # automatically gathered in ColumnParallelLinear forward pass. + # However, in Wan models, we need to gather the outputs manually. + if self.config.sequence_parallel: + x = tensor_parallel.gather_from_sequence_parallel_region(x) + + return x # output: x.shape [s, b, c * pF * pH * pW] + + + def set_input_tensor(self, input_tensor: Tensor) -> None: + """Sets input tensor to the model. + + See megatron.model.transformer.set_input_tensor() + + Args: + input_tensor (Tensor): Sets the input tensor for the model. + """ + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + + assert len(input_tensor) == 1, "input_tensor should only be length 1 for gpt/bert" + self.decoder.set_input_tensor(input_tensor[0]) + + + def sharded_state_dict( + self, prefix: str = "", sharded_offsets: tuple = (), metadata: Optional[Dict] = None + ) -> ShardedStateDict: + """Sharded state dict implementation for GPTModel backward-compatibility (removing extra state). + + Args: + prefix (str): Module name prefix. + sharded_offsets (tuple): PP related offsets, expected to be empty at this module level. + metadata (Optional[Dict]): metadata controlling sharded state dict creation. + + Returns: + ShardedStateDict: sharded state dict for the GPTModel + """ + sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) + + # DEBUGGING + # for module in ["t_embedder"]: + # for param_name, param in getattr(self, module).named_parameters(): + # weight_key = f"{prefix}{module}.{param_name}" + # self._set_embedder_weights_replica_id(param, sharded_state_dict, weight_key) + # DEBUGGING + # Ensure replica ids for non-transformer embedder weights include pipeline dimension + for module in ["text_embedding", "time_embedding", "time_projection"]: + if hasattr(self, module): + for param_name, param in getattr(self, module).named_parameters(): + weight_key = f"{prefix}{module}.{param_name}" + if weight_key in sharded_state_dict: + self._set_embedder_weights_replica_id(param, sharded_state_dict, weight_key) + + return sharded_state_dict + + + def _set_embedder_weights_replica_id( + self, tensor: Tensor, sharded_state_dict: ShardedStateDict, embedder_weight_key: str + ) -> None: + """set replica ids of the weights in t_embedder for sharded state dict. + + Args: + sharded_state_dict (ShardedStateDict): state dict with the weight to tie + weight_key (str): key of the weight in the state dict. + This entry will be replaced with a tied version + + Returns: None, acts in-place + """ + tp_rank = parallel_state.get_tensor_model_parallel_rank() + vpp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank() + vpp_rank = vpp_rank if vpp_rank else 0 + vpp_world = parallel_state.get_virtual_pipeline_model_parallel_world_size() + vpp_world = vpp_world if vpp_world else 1 + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + del sharded_state_dict[embedder_weight_key] + replica_id = ( + tp_rank, + (vpp_rank + pp_rank * vpp_world), + parallel_state.get_data_parallel_rank(with_context_parallel=True), + ) + + sharded_state_dict[embedder_weight_key] = make_sharded_tensor_for_checkpoint( + tensor=tensor, + key=embedder_weight_key, + replica_id=replica_id, + allow_shape_mismatch=False, + ) + + +class VACEModel(WanModel): + def __init__( + self, + config: TransformerConfig, + pre_process: bool = True, + post_process: bool = True, + fp16_lm_cross_entropy: bool = False, + parallel_output: bool = True, + transformer_decoder_layer_spec=VACEBaseLayerspec, + vace_transformer_decoder_layer_spec=VACEContextLayerspec, + **kwargs, + ): + super().__init__( + config, + pre_process, + post_process, + fp16_lm_cross_entropy, + parallel_output, + transformer_decoder_layer_spec, + **kwargs + ) + + self.vace_in_channels = self.config.vace_in_channels + self.vace_transformer_decoder_layer_spec = vace_transformer_decoder_layer_spec() + + if self.pre_process: + self.vace_patch_embedding = nn.Conv3d( + self.vace_in_channels, self.config.hidden_size, kernel_size=self.patch_size, stride=self.patch_size) + + self.decoder = BaseTransformerBlock( + config=self.config, + spec=self.transformer_decoder_layer_spec, + pre_process=self.pre_process, + post_process=self.post_process, + post_layer_norm=False, + ) + # print(self.decoder) + self.vace_config = copy.deepcopy(self.config) + self.vace_config.num_layers = len(self.decoder.vace_layers) + self.vace_decoder = ContextTransformerBlock( + config=self.vace_config, + spec=self.vace_transformer_decoder_layer_spec, + pre_process=self.pre_process, + post_process=self.post_process, + post_layer_norm=False, + ) + # print(self.vace_decoder.state_dict().keys()) + + self.vace_init_proj = nn.Linear(self.config.hidden_size, self.config.hidden_size) + + + def forward( + self, + x: Tensor, + grid_sizes: list[Tuple[int, int, int]], + t: Tensor, + context: Tensor, + vace_context: Tensor, + max_seq_len: int, + packed_seq_params: PackedSeqParams = None, + **kwargs, + ) -> Tensor: + """Forward pass. + + Args: + x List[Tensor]: list of vae encoded data (s, b, c * pF * pH * pW) + grid_sizes List[Tuple[int, int, int]]: list of grid sizes (f, h, w) + t Tensor: timesteps + context List[Tensor]: list of context (text_len, hidden_size) + max_seq_len int: maximum sequence length + packed_seq_params PackedSeqParams: packed sequence parameters + + Returns: + Tensor: output tensor (still patchified) of shape [seq_len, batch_size, hidden_size] + """ + ################################# + ########## Wan forward ########## + + # ============= embedders ============= + + # run input embedding + if self.pre_process: + # x.shape [s, b, c * pF * pH * pW] + seq_len, batch_size, _ = x.shape + c = self.in_channels + pF, pH, pW = self.patch_size + x = x.reshape(seq_len * batch_size, pF, pH, pW, c) # output: x.shape [s * b, pF, pH, pW, c] + x = x.permute(0, 4, 1, 2, 3) # output: x.shape [s * b, c, pF, pH, pW] + x = self.patch_embedding(x) # output: x.shape [s * b, hidden_size, 1, 1, 1] + x = x.flatten(1) # output: x.shape [s * b, hidden_size] + x = x.reshape(seq_len, batch_size, -1) # output: x.shape [s, b, hidden_size] + + # vace_context.shape [s, b, c * pF * pH * pW] + vace_seq_len, _, _ = vace_context.shape + vace_c = self.vace_in_channels + # pF, pH, pW = self.patch_size + vace_context = vace_context.reshape(vace_seq_len * batch_size, pF, pH, pW, vace_c) # output: vace_context.shape [s * b, pF, pH, pW, c] + vace_context = vace_context.permute(0, 4, 1, 2, 3) # output: vace_context.shape [s * b, c, pF, pH, pW] + vace_context = self.vace_patch_embedding(vace_context) # output: vace_context.shape [s * b, hidden_size, 1, 1, 1] + vace_context = vace_context.flatten(1) # output: vace_context.shape [s * b, hidden_size] + vace_context = vace_context.reshape(vace_seq_len, batch_size, -1) # output: vace_context.shape [s, b, hidden_size] + vace_context = self.vace_init_proj(vace_context) + x + # vace_context = vace_context.unsqueeze(0) + vace_context = torch.stack([vace_context] * (self.vace_config.num_layers)) + + # split sequence for sequence_parallel + # TODO: for PP, do we move scatter_to_sequence_parallel_region here or after "x = self.decoder.input_tensor" ??? + if self.config.sequence_parallel: + x = tensor_parallel.scatter_to_sequence_parallel_region(x) # output: x.shape [s * b // tp_size, hidden_size] + vace_context = tensor_parallel.scatter_to_sequence_parallel_region(vace_context) # output: vace_context.shape [s * b // tp_size, hidden_size] + + else: + # intermediate stage of pipeline + x = self.decoder.input_tensor + vace_context = self.vace_decoder.input_tensor + + # run context token embedding + + # time embeddings + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).to(x.dtype) + ) + e0 = self.time_projection(e).unflatten(1, (6, self.config.hidden_size)) + + # context embeddings + context = self.text_embedding(context) # shape [text_len, b, hidden_size] + + + # ============= decoder ============= + # calculate rotary pos emb + n_head, dim_head = self.num_heads, self.config.hidden_size // self.num_heads + rotary_pos_emb = self.rope_embeddings(n_head, dim_head, max_seq_len, grid_sizes, t.device) # output: rotary_pos_emb.shape [s, b, 1, dim_head] + + s, b, sq, h = rotary_pos_emb.shape + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).reshape(s*b, 1, sq, h) + + # run vace decoder + vace_context = self.vace_decoder( + hidden_states=vace_context[0], + attention_mask=e0, + context=context, + context_mask=None, + context_signal=vace_context, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=None, + rotary_pos_sin=None, + packed_seq_params=packed_seq_params, + )[1] + + # run decoder + x = self.decoder( + hidden_states=x, + attention_mask=e0, + context=context, + context_mask=None, + context_signal=vace_context, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=None, + rotary_pos_sin=None, + packed_seq_params=packed_seq_params, + ) + + # return if not post_process + if not self.post_process: + return x + + # head + x = x.transpose(0, 1) # head expects shape [b, s, hidden_size] + x = self.head(x, e) # output: x.shape [b, s, c * pF * pH * pW] + x = x.transpose(0, 1) # reshape back to shape [s, b, c * pF * pH * pW] + + # gather outputs for sequence_parallel + # Note: in GPT models, because the vocab projection matrix is ColumnParallelLinear, the sequence is + # automatically gathered in ColumnParallelLinear forward pass. + # However, in Wan models, we need to gather the outputs manually. + if self.config.sequence_parallel: + x = tensor_parallel.gather_from_sequence_parallel_region(x) + + return x # output: x.shape [s, b, c * pF * pH * pW] \ No newline at end of file diff --git a/src/megatron/bridge/models/wan/wan_provider.py b/src/megatron/bridge/models/wan/wan_provider.py new file mode 100644 index 0000000000..c48b103bfb --- /dev/null +++ b/src/megatron/bridge/models/wan/wan_provider.py @@ -0,0 +1,109 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from dataclasses import dataclass + +import torch +from megatron.core import parallel_state +from megatron.bridge.models.transformer_config import TransformerConfig + +from megatron.bridge.models.model_provider import ModelProviderMixin +from megatron.core.models.common.vision_module.vision_module import VisionModule +from megatron.bridge.models.wan.wan_model import WanModel, VACEModel + +logger = logging.getLogger(__name__) + +@dataclass +class WanModelProvider(TransformerConfig, ModelProviderMixin[VisionModule]): + crossattn_emb_size: int = 1536 + add_bias_linear: bool = True + gated_linear_unit: bool = False + + num_layers: int = 30 + hidden_size: int = 1536 + ffn_hidden_size: int = 8960 + num_attention_heads: int = 12 + layernorm_epsilon: float = 1e-6 + normalization: str = "RMSNorm" + layernorm_zero_centered_gamma: bool = False + add_qkv_bias: bool = True + rotary_interleaved: bool = True + hidden_dropout: float = 0 + attention_dropout: float = 0 + fp16_lm_cross_entropy: bool = False + parallel_output: bool = True + bf16: bool = False + params_dtype: torch.dtype = torch.float32 + # qkv_format: str = "sbhd" # "thd". NOTE: if we use context parallelism, we need to use "thd" + qkv_format: str = "thd" + # these attributes are unused for images/videos, we just set because bridge training requires for LLMs + seq_length: int = 1024 + share_embeddings_and_output_weights: bool = False + vocab_size: int = 25256 * 8 + make_vocab_size_divisible_by: int = 128 + + # images/videos attributes + in_channels: int = 16 + out_channels: int = 16 + patch_spatial: int = 2 + patch_temporal: int = 1 + freq_dim: int = 256 + text_len: int = 512 + text_dim: int = 4096 + + def provide(self, pre_process=None, post_process=None, vp_stage=None) -> WanModel: + vp_size = self.virtual_pipeline_model_parallel_size + if vp_size: + p_size = self.pipeline_model_parallel_size + assert (self.num_layers // p_size) % vp_size == 0, ( + "Make sure the number of model chunks is the same across all pipeline stages." + ) + + model = WanModel + + return model( + self, + pre_process=parallel_state.is_pipeline_first_stage(), + post_process=parallel_state.is_pipeline_last_stage(), + fp16_lm_cross_entropy=self.fp16_lm_cross_entropy, + parallel_output=self.parallel_output, + ) + + +@dataclass +class VACEModelProvider(WanModelProvider): + vace_layers: list = None + # vace_layers: list = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28] + vace_in_channels: int = 96 + base_num_layers: int = 30 + context_scale: float = 1.0 + + def provide(self, pre_process=None, post_process=None, vp_stage=None) -> VACEModel: + vp_size = self.virtual_pipeline_model_parallel_size + if vp_size: + p_size = self.pipeline_model_parallel_size + assert (self.num_layers // p_size) % vp_size == 0, ( + "Make sure the number of model chunks is the same across all pipeline stages." + ) + + model = VACEModel + + return model( + self, + pre_process=parallel_state.is_pipeline_first_stage(), + post_process=parallel_state.is_pipeline_last_stage(), + fp16_lm_cross_entropy=self.fp16_lm_cross_entropy, + parallel_output=self.parallel_output, + ) \ No newline at end of file diff --git a/src/megatron/bridge/models/wan/wan_step.py b/src/megatron/bridge/models/wan/wan_step.py new file mode 100644 index 0000000000..58429a6856 --- /dev/null +++ b/src/megatron/bridge/models/wan/wan_step.py @@ -0,0 +1,127 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from functools import partial +from typing import Iterable + +import torch +from megatron.core import parallel_state +from megatron.core.models.common.vision_module.vision_module import VisionModule +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.utils import get_model_config +from megatron.bridge.models.wan.flow_matching.flow_pipeline import FlowPipeline +from megatron.bridge.training.losses import masked_next_token_loss +from megatron.bridge.training.state import GlobalState + +logger = logging.getLogger(__name__) + +def wan_data_step(qkv_format, dataloader_iter): + batch = next(iter(dataloader_iter.iterable)) + + batch = {k: v.to(device="cuda", non_blocking=True) if torch.is_tensor(v) else v for k, v in batch.items()} + + # Construct packed sequence parameters + if ("seq_len_q" in batch) and ("seq_len_kv" in batch): + cu_seqlens = batch["seq_len_q"].cumsum(dim=0).to(torch.int32) + zero = torch.zeros(1, dtype=torch.int32, device="cuda") + cu_seqlens = torch.cat((zero, cu_seqlens)) + + cu_seqlens_kv = batch["seq_len_kv"].cumsum(dim=0).to(torch.int32) + cu_seqlens_kv = torch.cat((zero, cu_seqlens_kv)) + + batch["packed_seq_params"] = { + "self_attention": PackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + qkv_format=qkv_format, + ), + "cross_attention": PackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens_kv, + qkv_format=qkv_format, + ), + } + + return batch + + +class WanForwardStep: + def __init__(self): + self.diffusion_pipeline = FlowPipeline() + + + def __call__( + self, state: GlobalState, data_iterator: Iterable, model: VisionModule + ) -> tuple[torch.Tensor, partial]: + """ + Forward training step. + """ + timers = state.timers + straggler_timer = state.straggler_timer + + config = get_model_config(model) + + timers("batch-generator", log_level=2).start() + + qkv_format = getattr(config, "qkv_format", "sbhd") + with straggler_timer(bdata=True): + batch = wan_data_step( + qkv_format, data_iterator + ) + timers("batch-generator").stop() + + check_for_nan_in_loss = state.cfg.rerun_state_machine.check_for_nan_in_loss + check_for_spiky_loss = state.cfg.rerun_state_machine.check_for_spiky_loss + + # run diffusion training step + with straggler_timer: + if parallel_state.is_pipeline_last_stage(): + output_batch, loss, split_loss_mask = self.diffusion_pipeline.training_step(model, batch) + output_tensor = torch.mean(loss, dim=-1) + batch["loss_mask"] = split_loss_mask + else: + output_tensor = self.diffusion_pipeline.training_step(model, batch) + + # DEBUGGING + # TODO: do we need to gather output with sequence or context parallelism here + # especially when we have pipeline parallelism + + loss = output_tensor + if "loss_mask" not in batch or batch["loss_mask"] is None: + loss_mask = torch.ones_like(loss) + loss_mask = batch["loss_mask"] + + loss_function = self._create_loss_function(loss_mask, check_for_nan_in_loss, check_for_spiky_loss) + + return output_tensor, loss_function + + + def _create_loss_function(self, loss_mask: torch.Tensor, check_for_nan_in_loss: bool, check_for_spiky_loss: bool) -> partial: + """Create a partial loss function with the specified configuration. + + Args: + loss_mask: Used to mask out some portions of the loss + check_for_nan_in_loss: Whether to check for NaN values in the loss + check_for_spiky_loss: Whether to check for spiky loss values + + Returns: + A partial function that can be called with output_tensor to compute the loss + """ + return partial( + masked_next_token_loss, + loss_mask, + check_for_nan_in_loss=check_for_nan_in_loss, + check_for_spiky_loss=check_for_spiky_loss, + ) diff --git a/src/megatron/bridge/recipes/wan/wan.py b/src/megatron/bridge/recipes/wan/wan.py new file mode 100644 index 0000000000..b4975ad5a9 --- /dev/null +++ b/src/megatron/bridge/recipes/wan/wan.py @@ -0,0 +1,219 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import List, Optional, Union + +from megatron.bridge.data.wan.wan_energon_datamodule import WanDataModuleConfig +from megatron.bridge.models.wan.wan_provider import WanModelProvider +import torch +from megatron.core.distributed import DistributedDataParallelConfig + +from megatron.bridge.recipes.utils.optimizer_utils import distributed_fused_adam_with_cosine_annealing +from megatron.bridge.recipes.utils.tokenizer_utils import DEFAULT_NULL_TOKENIZER_VOCAB_SIZE +from megatron.bridge.training.comm_overlap import CommOverlapConfig +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + LoggerConfig, + RNGConfig, + TokenizerConfig, + TrainingConfig, +) +from megatron.bridge.training.mixed_precision import MixedPrecisionConfig, get_mixed_precision_config + + +def model_config( + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + pipeline_parallelism_dtype: Optional[torch.dtype] = torch.bfloat16, + virtual_pipeline_parallelism: Optional[int] = 1, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + seq_length: int = 1024, +) -> WanModelProvider: + """ + Configure the Wan model. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_dtype (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + seq_length (int): Sequence length for the model. + Returns: + WanModelProvider: Configuration for the Wan model. + """ + return WanModelProvider( + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_dtype, + virtual_pipeline_model_parallel_size=None, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + seq_length=seq_length, + ) + + +def pretrain_config( + dir: Optional[str] = None, + name: str = "default", + # Dataset configuration + data_paths: Optional[List[str]] = None, + data_args_path: Optional[str] = None, + train_data_path: Optional[List[str]] = None, + valid_data_path: Optional[List[str]] = None, + test_data_path: Optional[List[str]] = None, + per_split_data_args_path: Optional[str] = None, + mock: bool = False, + # Model configuration + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + pipeline_parallelism_dtype: Optional[torch.dtype] = torch.bfloat16, + virtual_pipeline_parallelism: Optional[int] = 1, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + use_megatron_fsdp: bool = False, + # Training hyperparameters + train_iters: int = 10000, + global_batch_size: int = 4, + micro_batch_size: int = 1, + lr: float = 0.9e-4, + lr_warmup_iters: int = 2000, + # Precision recipe + # DEBUGGING + precision_config: Optional[Union[MixedPrecisionConfig, str]] = "bf16_mixed", + # precision_config: Optional[Union[MixedPrecisionConfig, str]] = MixedPrecisionConfig( + # fp32=True, + # params_dtype=torch.float32, + # pipeline_dtype=torch.float32, + # autocast_enabled=False, + # ), + comm_overlap_config: Optional[CommOverlapConfig] = None, +) -> ConfigContainer: + """ + Create a pre-training configuration for GPT3 175B model. + + The default configuration is expected to run on 64 nodes with 8 GPUs each. + + Args: + dir (Optional[str]): Base directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + data_paths (Optional[List[str]]): List of paths to dataset files. If None, mock data will be used. + data_args_path (Optional[str]): Path to file containing data arguments. + train_data_path (Optional[List[str]]): List of training data paths. + valid_data_path (Optional[List[str]]): List of validation data paths. + test_data_path (Optional[List[str]]): List of test data paths. + per_split_data_args_path (Optional[str]): Path to JSON file with per-split data configuration. + mock (bool): Whether to use mock data. If True, ignores data_paths. + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_dtype (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism to be passed to model_config. + sequence_parallelism (bool): Whether to use sequence parallelism. + train_iters (int): Total number of training iterations. + global_batch_size (int): Global batch size for training. + micro_batch_size (int): Micro batch size for training. + seq_length (int): Sequence length for training data. + lr (float): Learning rate. + min_lr (float): Minimum learning rate for cosine decay. + lr_warmup_iters (int): Number of warmup iterations for the learning rate. + precision_config (Optional[Union[MixedPrecisionConfig, str]]): Precision configuration for the model. + comm_overlap_config (Optional[CommOverlapConfig]): Communication overlap configuration for the model. + + Returns: + ConfigContainer: Configuration for pre-training. + """ + base_output_dir = dir if dir is not None else os.path.join(os.getcwd(), "nemo_experiments") + run_output_dir = os.path.join(base_output_dir, name) + checkpoint_dir = os.path.join(run_output_dir, "checkpoints") + tensorboard_dir = os.path.join(run_output_dir, "tb_logs") + + + model_cfg = model_config( + tensor_parallelism=tensor_parallelism, + pipeline_parallelism=pipeline_parallelism, + pipeline_parallelism_dtype=pipeline_parallelism_dtype, + virtual_pipeline_parallelism=virtual_pipeline_parallelism, + context_parallelism=context_parallelism, + sequence_parallelism=sequence_parallelism, + seq_length=1024, + ) + + opt_config, scheduler = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=lr_warmup_iters, + lr_decay_iters=train_iters, + max_lr=lr, + ) + opt_config.use_precision_aware_optimizer = False + + if isinstance(precision_config, str): + precision_config = get_mixed_precision_config(precision_config) + + precision_config.grad_reduce_in_fp32 = False + + + # Config Container + cfg = ConfigContainer( + model=model_cfg, + train=TrainingConfig( + train_iters=train_iters, + eval_interval=2000, + eval_iters=32, + global_batch_size=global_batch_size, + micro_batch_size=micro_batch_size, + manual_gc=True, + manual_gc_interval=100, + manual_gc_eval=100, + ), + optimizer=opt_config, + scheduler=scheduler, + ddp=DistributedDataParallelConfig( + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=False, + overlap_param_gather=False, + average_in_collective=True, + use_distributed_optimizer=True, + use_megatron_fsdp=use_megatron_fsdp, # need use_distributed_optimizer=True + ), + dataset= WanDataModuleConfig( + path=None, + seq_length=1024, # we don't need to use this value, just add because Bridge training requires for LLMs + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + num_workers=10) + , + logger=LoggerConfig( + log_interval=10, + tensorboard_dir=tensorboard_dir, + log_timers_to_tensorboard=True, + ), + tokenizer=TokenizerConfig(tokenizer_type="NullTokenizer", vocab_size=DEFAULT_NULL_TOKENIZER_VOCAB_SIZE), + checkpoint=CheckpointConfig( + save_interval=2000, + save=checkpoint_dir, + load=checkpoint_dir, + ckpt_format="torch_dist", + fully_parallel_save=True, + ), + rng=RNGConfig(seed=1234), + comm_overlap=comm_overlap_config, + mixed_precision=precision_config, + ) + + return cfg diff --git a/vace.sh b/vace.sh new file mode 100644 index 0000000000..2ff32b6f0e --- /dev/null +++ b/vace.sh @@ -0,0 +1,46 @@ +export CUDA_VISIBLE_DEVICES=0,1 + +### Inferencing +# Download T5 weights and VAE weights from "https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B/tree/main" +# T5: models_t5_umt5-xxl-enc-bf16.pth, google +# VAE: Wan2.1_VAE.pth + +CHECKPOINT_DIR=/opt/megatron_checkpoint_VACE +T5_DIR=/opt/Wan2.1-T2V-1.3B +VAE_DIR=/opt/Wan2.1-T2V-1.3B + +NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=2 --rdzv-backend=c10d --rdzv-endpoint=localhost:0 examples/recipes/wan/inference_vace.py \ + --model_name vace-1.3B \ + --sizes 832*480 \ + --save_file "test" \ + --src_video "src_video_flow.mp4" \ + --checkpoint_dir ${CHECKPOINT_DIR} \ + --checkpoint_step 0000 \ + --t5_checkpoint_dir ${T5_DIR} \ + --vae_checkpoint_dir ${VAE_DIR} \ + --prompts "Two dogs hit each other during boxing." \ + --frame_nums 81 \ + --tensor_parallel_size 1 \ + --context_parallel_size 2 \ + --pipeline_parallel_size 1 \ + --sequence_parallel False \ + --base_seed 42 \ + --sample_steps 50 + +# NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=2 --rdzv-backend=c10d --rdzv-endpoint=localhost:0 examples/recipes/wan/inference_vace.py \ +# --model_name vace-1.3B \ +# --sizes 832*480 832*480 832*480 \ +# --save_file "test" \ +# --src_video "src_video_depth.mp4" "src_video_flow.mp4" "src_video_pose.mp4" \ +# --checkpoint_dir ${CHECKPOINT_DIR} \ +# --checkpoint_step 0000 \ +# --t5_checkpoint_dir ${T5_DIR} \ +# --vae_checkpoint_dir ${VAE_DIR} \ +# --prompts "Two dogs hit each other during boxing." "Two dogs hit each other during boxing." "Two dogs hit each other during boxing." \ +# --frame_nums 81 81 81 \ +# --tensor_parallel_size 1 \ +# --context_parallel_size 2 \ +# --pipeline_parallel_size 1 \ +# --sequence_parallel False \ +# --base_seed 42 \ +# --sample_steps 50 \ No newline at end of file