2 changes: 2 additions & 0 deletions .gitignore
@@ -182,3 +182,5 @@ slurm*.out

# UV package manager
.uv/

*.mp4
129 changes: 129 additions & 0 deletions example_commands.sh
@@ -0,0 +1,129 @@
# ### set path to Megatron-Bridge
# export MBRIDGE_PATH=/path/to/Megatron-Bridge
# export PYTHONPATH="${MBRIDGE_PATH}/.:${MBRIDGE_PATH}/src/.:/opt/NeMo-Framework-Launcher/launcher_scripts"

export CUDA_VISIBLE_DEVICES=0,1

# ### install dependencies
# pip install --upgrade git+https://github.com/NVIDIA/Megatron-LM.git@ce8185cbbe04f38beb74360e878450f2e8525885
# python3 -m pip install --upgrade diffusers
# pip install easydict
# pip install imageio
# pip install imageio-ffmpeg


# ### Convert checkpoint
# See examples/conversion/convert_wan_checkpoints.py for details.
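# A minimal way to run the converters (sketch): each script sets up a one-rank
# torch.distributed env itself, so no torchrun is needed. The output paths are
# hard-coded inside the scripts, so adjust them there first.
# python3 examples/conversion/convert_wan_checkpoints.py
# python3 examples/conversion/convert_vace_checkpoints.py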


# ### Finetuning
# export HF_TOKEN=...
# export WANDB_API_KEY=...
# EXP_NAME=...
# PRETRAINED_CHECKPOINT=/path/to/pretrained_checkpoint
# CHECKPOINT_DIR=/path/to/checkpoint_dir
# DATASET_PATH=/path/to/dataset
# cd $MBRIDGE_PATH
# NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=8 examples/recipes/wan/pretrain_wan.py \
# model.tensor_model_parallel_size=1 \
# model.pipeline_model_parallel_size=1 \
# model.context_parallel_size=4 \
# model.sequence_parallel=false \
# model.qkv_format=thd \
# dataset.path=${DATASET_PATH} \
# checkpoint.save=${CHECKPOINT_DIR} \
# checkpoint.load=${PRETRAINED_CHECKPOINT} \
# checkpoint.load_optim=false \
# checkpoint.save_interval=200 \
# optimizer.lr=5e-6 \
# optimizer.min_lr=5e-6 \
# train.eval_iters=0 \
# scheduler.lr_decay_style=constant \
# scheduler.lr_warmup_iters=0 \
# model.seq_length=2048 \
# dataset.seq_length=2048 \
# train.global_batch_size=1 \
# train.micro_batch_size=1 \
# dataset.global_batch_size=1 \
# dataset.micro_batch_size=1 \
# logger.log_interval=1 \
# logger.wandb_project="wan" \
# logger.wandb_exp_name=${EXP_NAME} \
# logger.wandb_save_dir=${CHECKPOINT_DIR}


# ### Inference
# Download the T5 and VAE weights from https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B/tree/main
# T5 encoder: models_t5_umt5-xxl-enc-bf16.pth (plus the google/ tokenizer directory)
# VAE: Wan2.1_VAE.pth
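# One way to fetch them (a sketch, assuming the huggingface_hub CLI is installed;
# check the --include patterns against the repo layout before running):
# huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B \
#     --include "models_t5_umt5-xxl-enc-bf16.pth" "Wan2.1_VAE.pth" "google/*" \
#     --local-dir /opt/Wan2.1-T2V-1.3B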

CHECKPOINT_DIR=/opt/megatron_checkpoint_WAN
T5_DIR=/opt/Wan2.1-T2V-1.3B
VAE_DIR=/opt/Wan2.1-T2V-1.3B
# cd $MBRIDGE_PATH
# NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 examples/recipes/wan/inference_wan.py \
# --task t2v-1.3B \
# --sizes 832*480 \
# --checkpoint_dir ${CHECKPOINT_DIR} \
# --checkpoint_step 0000 \
# --t5_checkpoint_dir ${T5_DIR} \
# --vae_checkpoint_dir ${VAE_DIR} \
# --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
# --frame_nums 81 \
# --tensor_parallel_size 1 \
# --context_parallel_size 1 \
# --pipeline_parallel_size 1 \
# --sequence_parallel False \
# --base_seed 42 \
# --sample_steps 50


NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 --rdzv-backend=c10d --rdzv-endpoint=localhost:0 examples/recipes/wan/inference_wan.py \
--task t2v-1.3B \
--sizes 832*480 \
--checkpoint_dir ${CHECKPOINT_DIR} \
--checkpoint_step 0000 \
--t5_checkpoint_dir ${T5_DIR} \
--vae_checkpoint_dir ${VAE_DIR} \
--prompts "Beautiful maple leaves across the mountain during the autumn." \
--frame_nums 81 \
--tensor_parallel_size 1 \
--context_parallel_size 1 \
--pipeline_parallel_size 1 \
--sequence_parallel False \
--base_seed 42 \
--sample_steps 50


# NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=2 examples/recipes/wan/inference_wan.py \
# --task t2v-1.3B \
# --sizes 832*480 \
# --checkpoint_dir ${CHECKPOINT_DIR} \
# --checkpoint_step 0000 \
# --t5_checkpoint_dir ${T5_DIR} \
# --vae_checkpoint_dir ${VAE_DIR} \
# --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
# --frame_nums 81 \
# --tensor_parallel_size 1 \
# --context_parallel_size 2 \
# --pipeline_parallel_size 1 \
# --sequence_parallel False \
# --base_seed 42 \
# --sample_steps 50


# NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=2 examples/recipes/wan/inference_wan.py \
# --task t2v-1.3B \
# --sizes 832*480 \
# --checkpoint_dir ${CHECKPOINT_DIR} \
# --checkpoint_step 0000 \
# --t5_checkpoint_dir ${T5_DIR} \
# --vae_checkpoint_dir ${VAE_DIR} \
# --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
# --frame_nums 81 \
# --tensor_parallel_size 1 \
# --context_parallel_size 1 \
# --pipeline_parallel_size 2 \
# --sequence_parallel False \
# --base_seed 42 \
# --sample_steps 50
49 changes: 49 additions & 0 deletions examples/conversion/convert_vace_checkpoints.py
@@ -0,0 +1,49 @@
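# convert_vace_checkpoints.py
# Converts a Hugging Face VACE checkpoint to a Megatron-format checkpoint.
# Intended to be run as a single process (the script sets up a one-rank
# torch.distributed env itself), e.g.:
#   python3 examples/conversion/convert_vace_checkpoints.py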
import os, random, multiprocessing as mp

def main():
    from megatron.bridge.models.hf_pretrained.wan import PreTrainedVACE
    from megatron.bridge.models.wan.wan_bridge import VACEBridge
    from megatron.bridge.training.model_load_save import save_megatron_model

    # --- minimal torch.distributed single-rank env ---
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", str(29500 + random.randint(0, 1000)))
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    os.environ.setdefault("LOCAL_RANK", "0")

    # --- build & load ---
    hf = PreTrainedVACE("Wan-AI/Wan2.1-VACE-1.3B-Diffusers")
    # hf = PreTrainedVACE("Wan-AI/Wan2.1-VACE-14B-Diffusers")

    bridge = VACEBridge()
    provider = bridge.provider_bridge(hf)
    provider.perform_initialization = False

    # Initialize on CPU to reduce peak GPU memory, even when a GPU is available.
    megatron_models = provider.provide_distributed_model(
        wrap_with_ddp=False, use_cpu_initialization=True
    )

    bridge.load_weights_hf_to_megatron(hf, megatron_models)

    # Save a Megatron-format checkpoint (this triggers the async writer internally).
    save_megatron_model(
        megatron_models,
        "/opt/megatron_checkpoint_VACE",
        hf_tokenizer_path=None,
    )


if __name__ == "__main__":
    # On Linux, prefer 'fork' to avoid re-importing the module on spawn.
    try:
        mp.set_start_method("fork")
    except RuntimeError:
        # Start method already set (fine on re-entry or on non-Linux platforms).
        pass

    # If you're on macOS/Windows and want to be extra safe:
    # mp.freeze_support()

    main()

74 changes: 74 additions & 0 deletions examples/conversion/convert_wan_checkpoints.py
@@ -0,0 +1,74 @@
# from megatron.bridge.models.hf_pretrained.wan import PreTrainedWAN
# from megatron.bridge.models.wan.wan_bridge import WanBridge
# from megatron.bridge.training.model_load_save import save_megatron_model
# import os, random
# os.environ["MASTER_ADDR"] = "127.0.0.1"
# os.environ["MASTER_PORT"] = str(29500 + random.randint(0, 1000))
# os.environ["RANK"] = "0"
# os.environ["WORLD_SIZE"] = "1"
# os.environ["LOCAL_RANK"] = "0"
# #
# hf = PreTrainedWAN("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
# # hf = PreTrainedWAN("Wan-AI/Wan2.1-T2V-14B-Diffusers")
# bridge = WanBridge()
# #
# provider = bridge.provider_bridge(hf)
# provider.perform_initialization = False
# megatron_models = provider.provide_distributed_model(wrap_with_ddp=False, use_cpu_initialization=True)
# #
# bridge.load_weights_hf_to_megatron(hf, megatron_models)
# save_megatron_model(megatron_models, "/opt/megatron_checkpoint", hf_tokenizer_path=None)


# convert_wan_checkpoints.py
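# Converts a Hugging Face Wan2.1 T2V checkpoint to a Megatron-format checkpoint.
# Intended to be run as a single process (the script sets up a one-rank
# torch.distributed env itself), e.g.:
#   python3 examples/conversion/convert_wan_checkpoints.py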

import os, random, multiprocessing as mp

def main():
    from megatron.bridge.models.hf_pretrained.wan import PreTrainedWAN
    from megatron.bridge.models.wan.wan_bridge import WanBridge
    from megatron.bridge.training.model_load_save import save_megatron_model

    # --- minimal torch.distributed single-rank env ---
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", str(29500 + random.randint(0, 1000)))
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    os.environ.setdefault("LOCAL_RANK", "0")

    # --- build & load ---
    hf = PreTrainedWAN("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
    # hf = PreTrainedWAN("Wan-AI/Wan2.1-T2V-14B-Diffusers")

    bridge = WanBridge()
    provider = bridge.provider_bridge(hf)
    provider.perform_initialization = False

    # Initialize on CPU to reduce peak GPU memory, even when a GPU is available.
    megatron_models = provider.provide_distributed_model(
        wrap_with_ddp=False, use_cpu_initialization=True
    )

    # Optional: inspect the constructed Megatron module.
    print(megatron_models[0])

    bridge.load_weights_hf_to_megatron(hf, megatron_models)

    # Save a Megatron-format checkpoint (this triggers the async writer internally).
    save_megatron_model(
        megatron_models,
        "/opt/megatron_checkpoint_WAN",
        hf_tokenizer_path=None,
    )


if __name__ == "__main__":
    # On Linux, prefer 'fork' to avoid re-importing the module on spawn.
    try:
        mp.set_start_method("fork")
    except RuntimeError:
        # Start method already set (fine on re-entry or on non-Linux platforms).
        pass

    # If you're on macOS/Windows and want to be extra safe:
    # mp.freeze_support()

    main()
