From 75e89210d1695cc5aa9ef2658ad4af281345d701 Mon Sep 17 00:00:00 2001 From: sajadn Date: Thu, 13 Nov 2025 09:58:25 -0800 Subject: [PATCH 01/15] Fix sequence padding for DiT. Add support for DiT Context Parallel with THD. Signed-off-by: sajadn --- .../common/diffusion_task_encoder_with_sp.py | 1 + dfm/src/megatron/data/dit/dit_taskencoder.py | 17 ++--- .../megatron/model/dit/dit_data_process.py | 63 +++++++++---------- dfm/src/megatron/model/dit/dit_layer_spec.py | 16 ++--- dfm/src/megatron/model/dit/dit_step.py | 27 +++++++- 5 files changed, 70 insertions(+), 54 deletions(-) diff --git a/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py b/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py index f3fae0b0..a44f36dd 100644 --- a/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py +++ b/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py @@ -100,6 +100,7 @@ def cat(attr): __subflavors__=samples[0].__subflavors__, video=cat("video"), context_embeddings=cat("context_embeddings"), + context_mask=cat("context_mask"), loss_mask=cat("loss_mask"), seq_len_q=cat("seq_len_q"), seq_len_q_padded=cat("seq_len_q_padded"), diff --git a/dfm/src/megatron/data/dit/dit_taskencoder.py b/dfm/src/megatron/data/dit/dit_taskencoder.py index fe3e6180..4668ce20 100644 --- a/dfm/src/megatron/data/dit/dit_taskencoder.py +++ b/dfm/src/megatron/data/dit/dit_taskencoder.py @@ -130,13 +130,14 @@ def encode_sample(self, sample: dict) -> DiffusionSample: "T H W d -> (T H W) d", ) - if self.packing_buffer_size is None: - pos_ids = F.pad(pos_ids, (0, 0, 0, self.seq_length - seq_len)) - loss_mask = torch.zeros(self.seq_length, dtype=torch.bfloat16) - loss_mask[:seq_len] = 1 - video_latent = F.pad(video_latent, (0, 0, 0, self.seq_length - seq_len)) - else: - loss_mask = torch.ones(seq_len, dtype=torch.bfloat16) + loss_mask = torch.ones(seq_len, dtype=torch.bfloat16) + sharding_factor = 64 + seq_len_q_padded = ((seq_len + sharding_factor - 1) // 
sharding_factor) * sharding_factor + + if seq_len < seq_len_q_padded: + video_latent = F.pad(video_latent, (0, 0, 0, seq_len_q_padded - seq_len)) + loss_mask = F.pad(loss_mask, (0, seq_len_q_padded - seq_len)) + pos_ids = F.pad(pos_ids, (0, 0, 0, seq_len_q_padded - seq_len)) return DiffusionSample( __key__=sample["__key__"], @@ -148,6 +149,7 @@ def encode_sample(self, sample: dict) -> DiffusionSample: context_mask=t5_text_mask, loss_mask=loss_mask, seq_len_q=torch.tensor([seq_len], dtype=torch.int32), + seq_len_q_padded=torch.tensor([seq_len_q_padded], dtype=torch.int32), seq_len_kv=torch.tensor([self.text_embedding_padding_size], dtype=torch.int32), pos_ids=pos_ids, latent_shape=torch.tensor([C, T, H, W], dtype=torch.int32), @@ -168,6 +170,7 @@ def batch(self, samples: List[DiffusionSample]) -> dict: context_mask=sample.context_mask.unsqueeze_(0) if sample.context_mask is not None else None, loss_mask=sample.loss_mask.unsqueeze_(0) if sample.loss_mask is not None else None, seq_len_q=sample.seq_len_q, + seq_len_q_padded=sample.seq_len_q_padded, seq_len_kv=sample.seq_len_kv, pos_ids=sample.pos_ids.unsqueeze_(0) if sample.pos_ids is not None else None, latent_shape=sample.latent_shape, diff --git a/dfm/src/megatron/model/dit/dit_data_process.py b/dfm/src/megatron/model/dit/dit_data_process.py index e9e9344c..7d732970 100644 --- a/dfm/src/megatron/model/dit/dit_data_process.py +++ b/dfm/src/megatron/model/dit/dit_data_process.py @@ -13,16 +13,18 @@ # limitations under the License. 
import torch +from megatron.core import parallel_state as ps from megatron.core.packed_seq_params import PackedSeqParams def dit_data_step(qkv_format, dataloader_iter): # import pdb;pdb.set_trace() batch = next(iter(dataloader_iter.iterable)) - batch = get_batch_on_this_cp_rank(batch) - batch = {k: v.to(device="cuda", non_blocking=True) if torch.is_tensor(v) else v for k, v in batch.items()} batch["is_preprocessed"] = True # assume data is preprocessed - return encode_seq_length(batch, format=qkv_format) + batch = {k: v.to(device="cuda", non_blocking=True) if torch.is_tensor(v) else v for k, v in batch.items()} + batch = encode_seq_length(batch, format=qkv_format) + batch = get_batch_on_this_cp_rank(batch) + return batch def encode_seq_length(batch, format): @@ -35,19 +37,20 @@ def encode_seq_length(batch, format): cu_seqlens_kv = batch["seq_len_kv"].cumsum(dim=0).to(torch.int32) cu_seqlens_kv = torch.cat((zero, cu_seqlens_kv)) + cu_seqlens_q_padded = batch["seq_len_q_padded"].cumsum(dim=0).to(torch.int32) + cu_seqlens_q_padded = torch.cat((zero, cu_seqlens_q_padded)) + batch["packed_seq_params"] = { "self_attention": PackedSeqParams( cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_q, - cu_seqlens_q_padded=None, - cu_seqlens_kv_padded=None, + cu_seqlens_q_padded=cu_seqlens_q_padded, qkv_format=format, ), "cross_attention": PackedSeqParams( cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, - cu_seqlens_q_padded=None, - cu_seqlens_kv_padded=None, + cu_seqlens_q_padded=cu_seqlens_q_padded, qkv_format=format, ), } @@ -57,34 +60,26 @@ def encode_seq_length(batch, format): def get_batch_on_this_cp_rank(data): """Split the data for context parallelism.""" - from megatron.core import mpu - - cp_size = mpu.get_context_parallel_world_size() - cp_rank = mpu.get_context_parallel_rank() - - t = 16 + cp_size = ps.get_context_parallel_world_size() if cp_size > 1: - # cp split on seq_length, for video_latent, noise_latent and pos_ids - assert t % cp_size == 0, "t must 
divisibly by cp_size" - num_valid_tokens_in_ub = None - if "loss_mask" in data and data["loss_mask"] is not None: - num_valid_tokens_in_ub = data["loss_mask"].sum() + import transformer_engine_torch as tex + + cp_rank = ps.get_context_parallel_rank() + for key in ["video", "loss_mask", "pos_ids"]: + if data[key] is not None: + index = tex.thd_get_partitioned_indices( + data["packed_seq_params"]["self_attention"].cu_seqlens_q_padded, + data[key].size(1), + cp_size, + cp_rank, + ).to(device=data[key].device, dtype=torch.long) + data[key] = data[key].index_select(1, index).contiguous() - for key, value in data.items(): - if (value is not None) and (key in ["video", "video_latent", "noise_latent", "pos_ids"]): - if len(value.shape) > 5: - value = value.squeeze(0) - B, C, T, H, W = value.shape - if T % cp_size == 0: - # FIXME packed sequencing - data[key] = value.view(B, C, cp_size, T // cp_size, H, W)[:, :, cp_rank, ...].contiguous() - else: - # FIXME packed sequencing - data[key] = value.view(B, C, T, cp_size, H // cp_size, W)[:, :, :, cp_rank, ...].contiguous() - loss_mask = data["loss_mask"] - data["loss_mask"] = loss_mask.view(loss_mask.shape[0], cp_size, loss_mask.shape[1] // cp_size)[ - :, cp_rank, ... - ].contiguous() - data["num_valid_tokens_in_ub"] = num_valid_tokens_in_ub + for key in ["context_embeddings", "context_mask"]: + if data[key] is not None: + index = tex.thd_get_partitioned_indices( + data["packed_seq_params"]["cross_attention"].cu_seqlens_kv, data[key].size(1), cp_size, cp_rank + ).to(device=data[key].device, dtype=torch.long) + data[key] = data[key].index_select(1, index).contiguous() return data diff --git a/dfm/src/megatron/model/dit/dit_layer_spec.py b/dfm/src/megatron/model/dit/dit_layer_spec.py index 2b19be2c..97afaf1d 100644 --- a/dfm/src/megatron/model/dit/dit_layer_spec.py +++ b/dfm/src/megatron/model/dit/dit_layer_spec.py @@ -144,17 +144,11 @@ def _replace_no_cp_submodules(submodules): # Override Cross Attention to disable CP. 
# Disable TP Comm overlap as well. Not disabling will attempt re-use of buffer size same as Q and lead to # incorrect tensor shapes. - if submodules.cross_attention != IdentityOp: - cp_override_config = copy.deepcopy(config) - cp_override_config.context_parallel_size = 1 - cp_override_config.tp_comm_overlap = False - self.cross_attention = build_module( - submodules.cross_attention, - config=cp_override_config, - layer_number=layer_number, - ) - else: - self.cross_attention = None + self.cross_attention = build_module( + submodules.cross_attention, + config=self.config, + layer_number=layer_number, + ) self.full_self_attention = build_module( submodules.full_self_attention, diff --git a/dfm/src/megatron/model/dit/dit_step.py b/dfm/src/megatron/model/dit/dit_step.py index 9bde05d5..49d97d62 100644 --- a/dfm/src/megatron/model/dit/dit_step.py +++ b/dfm/src/megatron/model/dit/dit_step.py @@ -18,6 +18,7 @@ from typing import Iterable import torch +import wandb from einops import rearrange from megatron.bridge.training.losses import masked_next_token_loss from megatron.bridge.training.state import GlobalState @@ -41,7 +42,7 @@ def __init__(self): self.train = True self.validation_step = 0 - def on_validation_start(self, batch, model, step): + def on_validation_start(self, state, batch, model, step): C, T, H, W = batch["latent_shape"][0] latent = self.diffusion_pipeline.generate_samples_from_batch( model, @@ -81,6 +82,28 @@ def on_validation_start(self, batch, model, step): video_save_path=f"{image_folder}/validation_step={step}_rank={rank}.mp4", ) + wandb_rank = parallel_state.get_data_parallel_world_size() - 1 + if torch.distributed.get_rank() == wandb_rank: + gather_list = [None for _ in range(parallel_state.get_data_parallel_world_size())] + else: + gather_list = None + + torch.distributed.gather_object( + obj=decoded_video[0], + object_gather_list=gather_list, + dst=wandb_rank, + group=parallel_state.get_data_parallel_group(), + ) + if torch.distributed.get_rank() 
== wandb_rank: + if gather_list is not None: + videos = [] + for video_data in gather_list: + video_data_transposed = video_data.transpose(0, 3, 1, 2) + videos.append(wandb.Video(video_data_transposed, fps=24, format="mp4")) + + if state.wandb_logger is not None: + state.wandb_logger.log({"prediction": videos}) + def __call__( self, state: GlobalState, data_iterator: Iterable, model: GPTModel, return_schedule_plan: bool = False ) -> tuple[torch.Tensor, partial]: @@ -103,7 +126,7 @@ def __call__( self.train = False self.valid = True self.validation_step += 1 - self.on_validation_start(batch, model, step=self.validation_step) + self.on_validation_start(state, batch, model, step=self.validation_step) return self.forward_step(state, batch, model, return_schedule_plan) def data_process( From 38d2109c5826f99afb3bccde04e699a351b2e8bf Mon Sep 17 00:00:00 2001 From: Abhinav Garg Date: Thu, 13 Nov 2025 22:58:50 +0000 Subject: [PATCH 02/15] Enhance DiT and Wan layer specifications - Updated `get_query_key_value_tensors` method in `dit_attention.py` to include an `output_gate` parameter and set `split_qkv` to default to `True`. - Modified `WanLayerWithAdaLN` class in `wan_layer_spec.py` to add `rotary_pos_cos_sin` parameter for improved positional encoding handling. 
--- dfm/src/megatron/model/common/dit_attention.py | 6 +++--- dfm/src/megatron/model/wan/wan_layer_spec.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/dfm/src/megatron/model/common/dit_attention.py b/dfm/src/megatron/model/common/dit_attention.py index 321e9b08..9f29ff49 100644 --- a/dfm/src/megatron/model/common/dit_attention.py +++ b/dfm/src/megatron/model/common/dit_attention.py @@ -100,7 +100,7 @@ def __init__( else: self.k_layernorm = None - def get_query_key_value_tensors(self, hidden_states, key_value_states=None, split_qkv=False): + def get_query_key_value_tensors(self, hidden_states, key_value_states=None, output_gate=None, split_qkv=True): """ Derives `query`, `key` and `value` tensors from `hidden_states`. """ @@ -251,13 +251,13 @@ def __init__( is_expert=False, ) - def get_query_key_value_tensors(self, hidden_states, key_value_states, split_qkv=False): + def get_query_key_value_tensors(self, hidden_states, key_value_states, output_gate=None, split_qkv=True): """ Derives `query` tensor from `hidden_states`, and `key`/`value` tensors from `key_value_states`. 
""" - query, key, value = super().get_query_key_value_tensors(hidden_states, key_value_states) + query, key, value = super().get_query_key_value_tensors(hidden_states, key_value_states, output_gate=output_gate, split_qkv=split_qkv) # gather query and key heads across TP ranks if self.layernorm_across_heads is True if self.layernorm_across_heads and parallel_state.get_tensor_model_parallel_world_size() > 1: diff --git a/dfm/src/megatron/model/wan/wan_layer_spec.py b/dfm/src/megatron/model/wan/wan_layer_spec.py index 2b355930..a0d6354e 100644 --- a/dfm/src/megatron/model/wan/wan_layer_spec.py +++ b/dfm/src/megatron/model/wan/wan_layer_spec.py @@ -162,6 +162,7 @@ def forward( packed_seq_params=None, sequence_len_offset=None, inference_context=None, + rotary_pos_cos_sin=None, ): # the timestep embedding is stored in attention_mask argument timestep_emb = attention_mask From 3d9bd13850de0cc72db1444c5fd0eca5e1c48db6 Mon Sep 17 00:00:00 2001 From: Abhinav Garg Date: Fri, 14 Nov 2025 10:01:25 +0000 Subject: [PATCH 03/15] Implement ProcessGroupCollection initialization in DiT and Wan models - Added initialization of `pg_collection` in both `DiTCrossAttentionModel` and `WanModel` to ensure proper handling of process groups. - This change checks if `pg_collection` exists and is not None before assigning it, enhancing the robustness of the models. 
--- dfm/src/megatron/model/dit/dit_model.py | 4 ++++ dfm/src/megatron/model/wan/wan_model.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/dfm/src/megatron/model/dit/dit_model.py b/dfm/src/megatron/model/dit/dit_model.py index e3ae8a29..213f7a0f 100644 --- a/dfm/src/megatron/model/dit/dit_model.py +++ b/dfm/src/megatron/model/dit/dit_model.py @@ -25,6 +25,7 @@ from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.transformer.enums import ModelType from megatron.core.transformer.transformer_block import TransformerBlock +from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import make_sharded_tensor_for_checkpoint from torch import Tensor @@ -103,6 +104,9 @@ def __init__( **kwargs, ): super(DiTCrossAttentionModel, self).__init__(config=config) + # Check if pg_collection exists and is not none then only do this + if not hasattr(self, 'pg_collection') or self.pg_collection is None: + self.pg_collection = ProcessGroupCollection.use_mpu_process_groups() self.config: TransformerConfig = config diff --git a/dfm/src/megatron/model/wan/wan_model.py b/dfm/src/megatron/model/wan/wan_model.py index 444aa597..61f47458 100644 --- a/dfm/src/megatron/model/wan/wan_model.py +++ b/dfm/src/megatron/model/wan/wan_model.py @@ -24,6 +24,7 @@ from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.models.common.vision_module.vision_module import VisionModule from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.enums import ModelType from megatron.core.transformer.transformer_block import TransformerBlock from megatron.core.transformer.transformer_config import TransformerConfig @@ -103,6 +104,8 @@ def __init__( super(WanModel, self).__init__(config=config) self.config: TransformerConfig = config + 
if not hasattr(self, 'pg_collection') or self.pg_collection is None:
+            self.pg_collection = ProcessGroupCollection.use_mpu_process_groups()
 
         self.transformer_decoder_layer_spec = transformer_decoder_layer_spec()
         self.pre_process = pre_process

From 582411689897cc766e56bccb37d3c238305b5133 Mon Sep 17 00:00:00 2001
From: Abhinav Garg
Date: Fri, 14 Nov 2025 13:28:53 +0000
Subject: [PATCH 04/15] Update CONTRIBUTING.md to include detailed setup instructions for development environment and Docker container usage. Added sections for building and running the container, as well as setting the PYTHONPATH for DFM.

---
 CONTRIBUTING.md | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 68ab66d4..aed9cf99 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,4 +1,32 @@
 # Contributing To NeMo DFM
+## 🛠️ Setting Up Your Environment
+
+Use the instructions below to setup a dev environment and a dev container
+
+### Building a container
+```bash
+# We recommend you to get the latest commits for Megatron-Bridge and Automodel
+# The easiest way to do that might be to remove the 3rdparty directory completely before running the following commands
+git submodule update --init --recursive --remote # Get all the 3rd party submodules
+cd 3rdparty/Megatron-Bridge/3rdparty/Megatron-LM # Megatron LM commit might be wrong
+# Get the right megatron commit from here: https://github.com/NVIDIA-NeMo/Megatron-Bridge/tree/main/3rdparty
+git checkout <megatron-commit-hash>
+cd ../../../../
+docker build -f docker/Dockerfile.ci -t dfm:latest .
+```
+
+### Run the container
+```bash
+docker run --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --gpus all -v $(pwd):/opt/DFM -it dfm:latest bash
+```
+
+### Inside the container
+```bash
+# Add DFM to PYTHONPATH
+export PYTHONPATH=$PYTHONPATH:/opt/DFM
+
+# Run a Mock Run:
+```
 
 ## Signing Your Work
 

From 06859062aaef1dfcc20327d9b30264ea853211df Mon Sep 17 00:00:00 2001
From: Abhinav Garg
Date: Fri, 14 Nov 2025 13:30:49 +0000
Subject: [PATCH 05/15] Refactor import statements in dit_model.py to streamline dependencies. Removed redundant import of ProcessGroupCollection, enhancing code clarity and maintainability.

---
 dfm/src/megatron/model/dit/dit_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dfm/src/megatron/model/dit/dit_model.py b/dfm/src/megatron/model/dit/dit_model.py
index 213f7a0f..5cc49c37 100644
--- a/dfm/src/megatron/model/dit/dit_model.py
+++ b/dfm/src/megatron/model/dit/dit_model.py
@@ -23,9 +23,9 @@
 from megatron.core.dist_checkpointing.mapping import ShardedStateDict
 from megatron.core.models.common.vision_module.vision_module import VisionModule
 from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.transformer.enums import ModelType
 from megatron.core.transformer.transformer_block import TransformerBlock
-from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.utils import make_sharded_tensor_for_checkpoint
 from torch import Tensor

From d2a7c6fc449f809253a3507dd638a18b68ca6077 Mon Sep 17 00:00:00 2001
From: Abhinav Garg
Date: Fri, 14 Nov 2025 13:44:55 +0000
Subject: [PATCH 06/15] Refactor code style in DiT and Wan models

- Updated string quotes in `dit_model.py` and `wan_model.py` for consistency, changing from single to double quotes.
- Reformatted the `get_query_key_value_tensors` method call in `dit_attention.py` for improved readability by breaking it into multiple lines. --- dfm/src/megatron/model/common/dit_attention.py | 4 +++- dfm/src/megatron/model/dit/dit_model.py | 2 +- dfm/src/megatron/model/wan/wan_model.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/dfm/src/megatron/model/common/dit_attention.py b/dfm/src/megatron/model/common/dit_attention.py index 9f29ff49..acf39d47 100644 --- a/dfm/src/megatron/model/common/dit_attention.py +++ b/dfm/src/megatron/model/common/dit_attention.py @@ -257,7 +257,9 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states, output_ga from `key_value_states`. """ - query, key, value = super().get_query_key_value_tensors(hidden_states, key_value_states, output_gate=output_gate, split_qkv=split_qkv) + query, key, value = super().get_query_key_value_tensors( + hidden_states, key_value_states, output_gate=output_gate, split_qkv=split_qkv + ) # gather query and key heads across TP ranks if self.layernorm_across_heads is True if self.layernorm_across_heads and parallel_state.get_tensor_model_parallel_world_size() > 1: diff --git a/dfm/src/megatron/model/dit/dit_model.py b/dfm/src/megatron/model/dit/dit_model.py index 5cc49c37..934e9a15 100644 --- a/dfm/src/megatron/model/dit/dit_model.py +++ b/dfm/src/megatron/model/dit/dit_model.py @@ -105,7 +105,7 @@ def __init__( ): super(DiTCrossAttentionModel, self).__init__(config=config) # Check if pg_collection exists and is not none then only do this - if not hasattr(self, 'pg_collection') or self.pg_collection is None: + if not hasattr(self, "pg_collection") or self.pg_collection is None: self.pg_collection = ProcessGroupCollection.use_mpu_process_groups() self.config: TransformerConfig = config diff --git a/dfm/src/megatron/model/wan/wan_model.py b/dfm/src/megatron/model/wan/wan_model.py index 61f47458..39d1dc39 100644 --- a/dfm/src/megatron/model/wan/wan_model.py +++ 
b/dfm/src/megatron/model/wan/wan_model.py @@ -104,7 +104,7 @@ def __init__( super(WanModel, self).__init__(config=config) self.config: TransformerConfig = config - if not hasattr(self, 'pg_collection') or self.pg_collection is None: + if not hasattr(self, "pg_collection") or self.pg_collection is None: self.pg_collection = ProcessGroupCollection.use_mpu_process_groups() self.transformer_decoder_layer_spec = transformer_decoder_layer_spec() From 471811fe40aa9bfac39c5883d7c79c286ae25896 Mon Sep 17 00:00:00 2001 From: Abhinav Garg Date: Sat, 15 Nov 2025 18:07:36 +0000 Subject: [PATCH 07/15] Revert M4 changes --- dfm/src/megatron/model/dit/dit_model.py | 4 ---- dfm/src/megatron/model/wan/wan_model.py | 2 -- 2 files changed, 6 deletions(-) diff --git a/dfm/src/megatron/model/dit/dit_model.py b/dfm/src/megatron/model/dit/dit_model.py index 934e9a15..ebfd5409 100644 --- a/dfm/src/megatron/model/dit/dit_model.py +++ b/dfm/src/megatron/model/dit/dit_model.py @@ -104,10 +104,6 @@ def __init__( **kwargs, ): super(DiTCrossAttentionModel, self).__init__(config=config) - # Check if pg_collection exists and is not none then only do this - if not hasattr(self, "pg_collection") or self.pg_collection is None: - self.pg_collection = ProcessGroupCollection.use_mpu_process_groups() - self.config: TransformerConfig = config self.transformer_decoder_layer_spec = transformer_decoder_layer_spec() diff --git a/dfm/src/megatron/model/wan/wan_model.py b/dfm/src/megatron/model/wan/wan_model.py index 39d1dc39..b5e6b6d8 100644 --- a/dfm/src/megatron/model/wan/wan_model.py +++ b/dfm/src/megatron/model/wan/wan_model.py @@ -104,8 +104,6 @@ def __init__( super(WanModel, self).__init__(config=config) self.config: TransformerConfig = config - if not hasattr(self, "pg_collection") or self.pg_collection is None: - self.pg_collection = ProcessGroupCollection.use_mpu_process_groups() self.transformer_decoder_layer_spec = transformer_decoder_layer_spec() self.pre_process = pre_process From 
f0a928b82c8470c9bc8d25f4913a6f46ec8d6dfc Mon Sep 17 00:00:00 2001 From: Abhinav Garg Date: Sat, 15 Nov 2025 18:08:29 +0000 Subject: [PATCH 08/15] Ruff --- dfm/src/megatron/model/dit/dit_model.py | 1 - dfm/src/megatron/model/wan/wan_model.py | 1 - 2 files changed, 2 deletions(-) diff --git a/dfm/src/megatron/model/dit/dit_model.py b/dfm/src/megatron/model/dit/dit_model.py index ebfd5409..5c42f495 100644 --- a/dfm/src/megatron/model/dit/dit_model.py +++ b/dfm/src/megatron/model/dit/dit_model.py @@ -105,7 +105,6 @@ def __init__( ): super(DiTCrossAttentionModel, self).__init__(config=config) self.config: TransformerConfig = config - self.transformer_decoder_layer_spec = transformer_decoder_layer_spec() self.pre_process = pre_process self.post_process = post_process diff --git a/dfm/src/megatron/model/wan/wan_model.py b/dfm/src/megatron/model/wan/wan_model.py index b5e6b6d8..444aa597 100644 --- a/dfm/src/megatron/model/wan/wan_model.py +++ b/dfm/src/megatron/model/wan/wan_model.py @@ -24,7 +24,6 @@ from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.models.common.vision_module.vision_module import VisionModule from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.enums import ModelType from megatron.core.transformer.transformer_block import TransformerBlock from megatron.core.transformer.transformer_config import TransformerConfig From f0aa57348c35363869cb79a2ca36fae9819e34a8 Mon Sep 17 00:00:00 2001 From: Abhinav Garg Date: Sat, 15 Nov 2025 18:09:50 +0000 Subject: [PATCH 09/15] Ruff --- dfm/src/megatron/model/dit/dit_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dfm/src/megatron/model/dit/dit_model.py b/dfm/src/megatron/model/dit/dit_model.py index 5c42f495..0cab5992 100644 --- a/dfm/src/megatron/model/dit/dit_model.py +++ b/dfm/src/megatron/model/dit/dit_model.py @@ -23,7 +23,6 @@ from 
megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.models.common.vision_module.vision_module import VisionModule from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.enums import ModelType from megatron.core.transformer.transformer_block import TransformerBlock from megatron.core.transformer.transformer_config import TransformerConfig @@ -104,6 +103,7 @@ def __init__( **kwargs, ): super(DiTCrossAttentionModel, self).__init__(config=config) + self.config: TransformerConfig = config self.transformer_decoder_layer_spec = transformer_decoder_layer_spec() self.pre_process = pre_process From e5c6b5b5b9d75bbab70f9d35ffba9c553acc7445 Mon Sep 17 00:00:00 2001 From: Abhinav Garg Date: Sat, 15 Nov 2025 18:12:02 +0000 Subject: [PATCH 10/15] Lint --- dfm/src/megatron/model/dit/dit_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dfm/src/megatron/model/dit/dit_model.py b/dfm/src/megatron/model/dit/dit_model.py index 0cab5992..38cb8422 100644 --- a/dfm/src/megatron/model/dit/dit_model.py +++ b/dfm/src/megatron/model/dit/dit_model.py @@ -103,7 +103,7 @@ def __init__( **kwargs, ): super(DiTCrossAttentionModel, self).__init__(config=config) - + self.config: TransformerConfig = config self.transformer_decoder_layer_spec = transformer_decoder_layer_spec() self.pre_process = pre_process From 344898ff88732e36b0ef098fbeb47676ebae0e76 Mon Sep 17 00:00:00 2001 From: sajadn Date: Thu, 13 Nov 2025 12:09:17 -0800 Subject: [PATCH 11/15] Fix sequence padding for DiT. Add support for DiT Context Parallel with THD. 
Signed-off-by: sajadn --- dfm/src/megatron/data/dit/dit_taskencoder.py | 2 ++ dfm/src/megatron/model/dit/dit_step.py | 14 +++++++++++--- dfm/src/megatron/model/dit/edm/edm_pipeline.py | 6 +----- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/dfm/src/megatron/data/dit/dit_taskencoder.py b/dfm/src/megatron/data/dit/dit_taskencoder.py index 4668ce20..0f574e74 100644 --- a/dfm/src/megatron/data/dit/dit_taskencoder.py +++ b/dfm/src/megatron/data/dit/dit_taskencoder.py @@ -151,6 +151,7 @@ def encode_sample(self, sample: dict) -> DiffusionSample: seq_len_q=torch.tensor([seq_len], dtype=torch.int32), seq_len_q_padded=torch.tensor([seq_len_q_padded], dtype=torch.int32), seq_len_kv=torch.tensor([self.text_embedding_padding_size], dtype=torch.int32), + seq_len_kv_padded=torch.tensor([self.text_embedding_padding_size], dtype=torch.int32), pos_ids=pos_ids, latent_shape=torch.tensor([C, T, H, W], dtype=torch.int32), ) @@ -172,6 +173,7 @@ def batch(self, samples: List[DiffusionSample]) -> dict: seq_len_q=sample.seq_len_q, seq_len_q_padded=sample.seq_len_q_padded, seq_len_kv=sample.seq_len_kv, + seq_len_kv_padded=sample.seq_len_kv_padded, pos_ids=sample.pos_ids.unsqueeze_(0) if sample.pos_ids is not None else None, latent_shape=sample.latent_shape, ) diff --git a/dfm/src/megatron/model/dit/dit_step.py b/dfm/src/megatron/model/dit/dit_step.py index 49d97d62..600bd82a 100644 --- a/dfm/src/megatron/model/dit/dit_step.py +++ b/dfm/src/megatron/model/dit/dit_step.py @@ -82,8 +82,16 @@ def on_validation_start(self, state, batch, model, step): video_save_path=f"{image_folder}/validation_step={step}_rank={rank}.mp4", ) - wandb_rank = parallel_state.get_data_parallel_world_size() - 1 - if torch.distributed.get_rank() == wandb_rank: + is_last_dp_rank = parallel_state.get_data_parallel_rank() == ( + parallel_state.get_data_parallel_world_size() - 1 + ) + + last_dp_local_rank = parallel_state.get_data_parallel_world_size() - 1 + dp_group = 
parallel_state.get_data_parallel_group() + dp_ranks = torch.distributed.get_process_group_ranks(dp_group) + wandb_rank = dp_ranks[last_dp_local_rank] + + if is_last_dp_rank: gather_list = [None for _ in range(parallel_state.get_data_parallel_world_size())] else: gather_list = None @@ -94,7 +102,7 @@ def on_validation_start(self, state, batch, model, step): dst=wandb_rank, group=parallel_state.get_data_parallel_group(), ) - if torch.distributed.get_rank() == wandb_rank: + if is_last_dp_rank: if gather_list is not None: videos = [] for video_data in gather_list: diff --git a/dfm/src/megatron/model/dit/edm/edm_pipeline.py b/dfm/src/megatron/model/dit/edm/edm_pipeline.py index 4f2f8ece..ff0d9581 100644 --- a/dfm/src/megatron/model/dit/edm/edm_pipeline.py +++ b/dfm/src/megatron/model/dit/edm/edm_pipeline.py @@ -360,18 +360,14 @@ def generate_samples_from_batch( if self._noise_generator is None: self._initialize_generators() x0_fn = self.get_x0_fn_from_batch(data_batch, guidance, is_negative_prompt=is_negative_prompt) - - state_shape = list(state_shape) - state_shape[1] //= parallel_state.get_context_parallel_world_size() x_sigma_max = ( torch.randn(state_shape, **self.tensor_kwargs, generator=self._noise_generator) * self.sde.sigma_max ) samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max) - if cp_enabled: cp_group = parallel_state.get_context_parallel_group() - samples = cat_outputs_cp(samples, seq_dim=2, cp_group=cp_group) + samples = cat_outputs_cp(samples, seq_dim=1, cp_group=cp_group) return samples From 34a8c509898c79a63a075df6271bc48556fc1709 Mon Sep 17 00:00:00 2001 From: sajadn Date: Sun, 16 Nov 2025 10:03:51 -0800 Subject: [PATCH 12/15] fix cp inference. add cu_seqlen_kv_padded which was missing. 
Signed-off-by: sajadn --- .../common/utils/torch_split_tensor_for_cp.py | 38 ++++++++++++++----- dfm/src/megatron/data/dit/dit_taskencoder.py | 37 +++++++----------- .../megatron/model/dit/dit_data_process.py | 5 +++ .../megatron/model/dit/edm/edm_pipeline.py | 3 +- 4 files changed, 50 insertions(+), 33 deletions(-) diff --git a/dfm/src/common/utils/torch_split_tensor_for_cp.py b/dfm/src/common/utils/torch_split_tensor_for_cp.py index f389b7ba..6d72050b 100644 --- a/dfm/src/common/utils/torch_split_tensor_for_cp.py +++ b/dfm/src/common/utils/torch_split_tensor_for_cp.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + import torch +import transformer_engine_torch as tex from torch import Tensor -from torch.distributed import ProcessGroup, all_gather, get_world_size +from torch.distributed import ProcessGroup, all_gather, get_rank, get_world_size -def cat_outputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor: +def cat_outputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup, thd_cu_seqlens: Optional[Tensor] = None) -> Tensor: """ Concatenates tensors from multiple processes along a specified dimension. @@ -28,24 +31,41 @@ def cat_outputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor: x (Tensor): The input tensor to be gathered and concatenated. seq_dim (int): The dimension along which to concatenate the gathered tensors. cp_group (ProcessGroup): The process group containing all the processes involved in the gathering. + thd_cu_seqlens (Tensor, optional): THD cumulative sequence lengths used during partitioning. Provide + this to restore the original token order after gathering. Returns: - Tensor: A tensor resulting from the concatenation of tensors from all processes. + Tensor: A tensor resulting from the concatenation of tensors from all processes. 
If `thd_cu_seqlens` + is provided, the tensor is reordered to match the original (pre-partition) sequence order. Raises: RuntimeError: If the gathering of tensors fails. """ # Number of processes in the group world_size = get_world_size(cp_group) - # List to hold tensors from each rank gathered_tensors = [torch.zeros_like(x) for _ in range(world_size)] # Attempt to gather tensors from all ranks - try: - all_gather(gathered_tensors, x, group=cp_group) - except RuntimeError as e: - raise RuntimeError(f"Gathering failed: {e}") + all_gather(gathered_tensors, x, group=cp_group) # Concatenate tensors along the specified dimension - return torch.cat(gathered_tensors, dim=seq_dim) + gathered = torch.cat(gathered_tensors, dim=seq_dim) + total_seq_len = int(thd_cu_seqlens[-1].item()) + # Rebuild the global index ordering used during THD partitioning. + cp_rank = get_rank(cp_group) + local_indices = tex.thd_get_partitioned_indices(thd_cu_seqlens, total_seq_len, world_size, cp_rank).to( + device=x.device, dtype=torch.long + ) + + # Gather indices from all ranks to compute the inverse permutation. 
+ gathered_indices = [torch.empty_like(local_indices) for _ in range(world_size)] + all_gather(gathered_indices, local_indices, group=cp_group) + global_indices = torch.cat(gathered_indices, dim=0) + + if global_indices.numel() != gathered.size(seq_dim): + raise RuntimeError("Gathered indices size does not match gathered tensor along sequence dimension.") + + restore_order = torch.argsort(global_indices, dim=0) + gathered = gathered.index_select(seq_dim, restore_order.to(device=gathered.device)) + return gathered.contiguous() diff --git a/dfm/src/megatron/data/dit/dit_taskencoder.py b/dfm/src/megatron/data/dit/dit_taskencoder.py index 0f574e74..f7553bce 100644 --- a/dfm/src/megatron/data/dit/dit_taskencoder.py +++ b/dfm/src/megatron/data/dit/dit_taskencoder.py @@ -94,7 +94,9 @@ def encode_sample(self, sample: dict) -> DiffusionSample: pw=self.patch_spatial, pt=self.patch_temporal, ) - sample["pickle"] = sample["pickle"].cpu().float().numpy().squeeze(0) + sample["pickle"] = sample["pickle"].cpu().float().numpy() + if sample["pickle"].shape[0] == 1: + sample["pickle"] = sample["pickle"][0] if is_image: t5_text_embeddings = torch.from_numpy(sample["pickle"]).to(torch.bfloat16) else: @@ -103,28 +105,8 @@ def encode_sample(self, sample: dict) -> DiffusionSample: if t5_text_embeddings_seq_length > self.text_embedding_padding_size: t5_text_embeddings = t5_text_embeddings[: self.text_embedding_padding_size] - else: - t5_text_embeddings = F.pad( - t5_text_embeddings, - ( - 0, - 0, - 0, - self.text_embedding_padding_size - t5_text_embeddings_seq_length, - ), - ) t5_text_mask = torch.ones(t5_text_embeddings_seq_length, dtype=torch.bfloat16) - if is_image: - h, w = info["image_height"], info["image_width"] - fps = torch.tensor([30] * 1, dtype=torch.bfloat16) - num_frames = torch.tensor([1] * 1, dtype=torch.bfloat16) - else: - h, w = info["height"], info["width"] - fps = torch.tensor([info["framerate"]] * 1, dtype=torch.bfloat16) - num_frames = torch.tensor([info["num_frames"]] 
* 1, dtype=torch.bfloat16) - image_size = torch.tensor([[h, w, h, w]] * 1, dtype=torch.bfloat16) - pos_ids = rearrange( pos_id_3d.get_pos_id_3d(t=T // self.patch_temporal, h=H // self.patch_spatial, w=W // self.patch_spatial), "T H W d -> (T H W) d", @@ -133,12 +115,21 @@ def encode_sample(self, sample: dict) -> DiffusionSample: loss_mask = torch.ones(seq_len, dtype=torch.bfloat16) sharding_factor = 64 seq_len_q_padded = ((seq_len + sharding_factor - 1) // sharding_factor) * sharding_factor + seq_len_kv_padded = ( + (t5_text_embeddings_seq_length + sharding_factor - 1) // sharding_factor + ) * sharding_factor if seq_len < seq_len_q_padded: video_latent = F.pad(video_latent, (0, 0, 0, seq_len_q_padded - seq_len)) loss_mask = F.pad(loss_mask, (0, seq_len_q_padded - seq_len)) pos_ids = F.pad(pos_ids, (0, 0, 0, seq_len_q_padded - seq_len)) + if t5_text_embeddings_seq_length < seq_len_kv_padded: + t5_text_embeddings = F.pad( + t5_text_embeddings, (0, 0, 0, seq_len_kv_padded - t5_text_embeddings_seq_length) + ) + t5_text_mask = F.pad(t5_text_mask, (0, seq_len_kv_padded - t5_text_embeddings_seq_length)) + return DiffusionSample( __key__=sample["__key__"], __restore_key__=sample["__restore_key__"], @@ -150,8 +141,8 @@ def encode_sample(self, sample: dict) -> DiffusionSample: loss_mask=loss_mask, seq_len_q=torch.tensor([seq_len], dtype=torch.int32), seq_len_q_padded=torch.tensor([seq_len_q_padded], dtype=torch.int32), - seq_len_kv=torch.tensor([self.text_embedding_padding_size], dtype=torch.int32), - seq_len_kv_padded=torch.tensor([self.text_embedding_padding_size], dtype=torch.int32), + seq_len_kv=torch.tensor([t5_text_embeddings_seq_length], dtype=torch.int32), + seq_len_kv_padded=torch.tensor([seq_len_kv_padded], dtype=torch.int32), pos_ids=pos_ids, latent_shape=torch.tensor([C, T, H, W], dtype=torch.int32), ) diff --git a/dfm/src/megatron/model/dit/dit_data_process.py b/dfm/src/megatron/model/dit/dit_data_process.py index 7d732970..e599581a 100644 --- 
a/dfm/src/megatron/model/dit/dit_data_process.py +++ b/dfm/src/megatron/model/dit/dit_data_process.py @@ -40,17 +40,22 @@ def encode_seq_length(batch, format): cu_seqlens_q_padded = batch["seq_len_q_padded"].cumsum(dim=0).to(torch.int32) cu_seqlens_q_padded = torch.cat((zero, cu_seqlens_q_padded)) + cu_seqlens_kv_padded = batch["seq_len_kv_padded"].cumsum(dim=0).to(torch.int32) + cu_seqlens_kv_padded = torch.cat((zero, cu_seqlens_kv_padded)) + batch["packed_seq_params"] = { "self_attention": PackedSeqParams( cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_q, cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_q_padded, qkv_format=format, ), "cross_attention": PackedSeqParams( cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, qkv_format=format, ), } diff --git a/dfm/src/megatron/model/dit/edm/edm_pipeline.py b/dfm/src/megatron/model/dit/edm/edm_pipeline.py index ff0d9581..dc4a6aba 100644 --- a/dfm/src/megatron/model/dit/edm/edm_pipeline.py +++ b/dfm/src/megatron/model/dit/edm/edm_pipeline.py @@ -367,7 +367,8 @@ def generate_samples_from_batch( samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max) if cp_enabled: cp_group = parallel_state.get_context_parallel_group() - samples = cat_outputs_cp(samples, seq_dim=1, cp_group=cp_group) + thd_cu_seqlen_q_padded = data_batch["packed_seq_params"]["self_attention"].cu_seqlens_q_padded + samples = cat_outputs_cp(samples, seq_dim=1, cp_group=cp_group, thd_cu_seqlens=thd_cu_seqlen_q_padded) return samples From 71ef76a9290d6ba1e03de324b62369ccc36efba0 Mon Sep 17 00:00:00 2001 From: Sajad Norouzi Date: Sun, 16 Nov 2025 14:11:34 -0800 Subject: [PATCH 13/15] Add mock DiT dataset. Make DiT attention compatible with megatron bridge. 
Signed-off-by: Sajad Norouzi --- .../megatron/data/dit/dit_mock_datamodule.py | 165 ++++++++++++++++++ .../megatron/model/common/dit_attention.py | 8 +- dfm/src/megatron/recipes/dit/dit.py | 38 +++- .../recipes/dit/pretrain_dit_model.py | 3 +- 4 files changed, 201 insertions(+), 13 deletions(-) create mode 100644 dfm/src/megatron/data/dit/dit_mock_datamodule.py diff --git a/dfm/src/megatron/data/dit/dit_mock_datamodule.py b/dfm/src/megatron/data/dit/dit_mock_datamodule.py new file mode 100644 index 00000000..22640292 --- /dev/null +++ b/dfm/src/megatron/data/dit/dit_mock_datamodule.py @@ -0,0 +1,165 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# pylint: disable=C0115,C0116,C0301 + +from dataclasses import dataclass + +import torch +from einops import rearrange +from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider +from torch.utils.data import DataLoader, Dataset + +from dfm.src.megatron.data.dit.dit_taskencoder import PosID3D + + +pos_id_3d = PosID3D() + + +class _MockDataset(Dataset): + def __init__(self, length: int): + self.length = max(int(length), 1) + + def __len__(self) -> int: + return self.length + + def __getitem__(self, idx: int) -> dict: + return {} + + +def mock_batch( + F_latents: int, + H_latents: int, + W_latents: int, + patch_temporal: int, + patch_spatial: int, + number_packed_samples: int, + context_seq_len: int, + context_embeddings_dim: int, +) -> dict: + # set mock values for one video sample + video_latent = torch.randn(16, F_latents, H_latents, W_latents, dtype=torch.float32) + C, T, H, W = video_latent.shape + video_latent = rearrange( + video_latent, + "C (T pt) (H ph) (W pw) -> (T H W) (ph pw pt C)", + ph=patch_spatial, + pw=patch_spatial, + pt=patch_temporal, + ) + video_latent = torch.as_tensor(video_latent, dtype=torch.bfloat16) + + context_embeddings = torch.randn(context_seq_len, context_embeddings_dim, dtype=torch.bfloat16) + context_embeddings_seq_length = context_embeddings.shape[0] + context_embeddings_mask = torch.ones(context_embeddings_seq_length, dtype=torch.bfloat16) + + pos_ids = rearrange( + pos_id_3d.get_pos_id_3d(t=T // patch_temporal, h=H // patch_spatial, w=W // patch_spatial), + "T H W d -> (T H W) d", + ) + + seq_len_q = video_latent.shape[0] + seq_len_q_padded = seq_len_q + + loss_mask = torch.ones(seq_len_q, dtype=torch.bfloat16) + + seq_len_kv = context_embeddings.shape[0] + seq_len_kv_padded = seq_len_kv + + video_latents_packed = [video_latent for _ in range(number_packed_samples)] + video_latents_packed = torch.cat(video_latents_packed, dim=0) + + context_embeddings_packed = [context_embeddings for _ in 
range(number_packed_samples)] + context_embeddings_packed = torch.cat(context_embeddings_packed, dim=0) + + context_embeddings_mask_packed = [context_embeddings_mask for _ in range(number_packed_samples)] + context_embeddings_mask_packed = torch.cat(context_embeddings_mask_packed, dim=0) + + loss_masks_packed = [loss_mask for _ in range(number_packed_samples)] + loss_masks_packed = torch.cat(loss_masks_packed, dim=0) + + seq_len_q_packed = torch.tensor([seq_len_q for _ in range(number_packed_samples)], dtype=torch.int32) + seq_len_q_padded_packed = torch.tensor([seq_len_q_padded for _ in range(number_packed_samples)], dtype=torch.int32) + + seq_len_kv_packed = torch.tensor([seq_len_kv for _ in range(number_packed_samples)], dtype=torch.int32) + seq_len_kv_padded_packed = torch.tensor( + [seq_len_kv_padded for _ in range(number_packed_samples)], dtype=torch.int32 + ) + + pos_ids_packed = [pos_ids for _ in range(number_packed_samples)] + pos_ids_packed = torch.cat(pos_ids_packed, dim=0) + + context_embeddings_packed = [context_embeddings for _ in range(number_packed_samples)] + context_embeddings_packed = torch.cat(context_embeddings_packed, dim=0) + + batch = dict( + video=video_latents_packed.unsqueeze(0), + context_embeddings=context_embeddings_packed.unsqueeze(0), + context_mask=context_embeddings_mask_packed.unsqueeze(0), + loss_mask=loss_masks_packed.unsqueeze(0), + seq_len_q=seq_len_q_packed, + seq_len_q_padded=seq_len_q_padded_packed, + seq_len_kv=seq_len_kv_packed, + seq_len_kv_padded=seq_len_kv_padded_packed, + latent_shape=torch.tensor([C, T, H, W], dtype=torch.int32), + pos_ids=pos_ids_packed, + ) + + return batch + + +@dataclass(kw_only=True) +class DiTMockDataModuleConfig(DatasetProvider): + path: str = "" + seq_length: int + packing_buffer_size: int + micro_batch_size: int + global_batch_size: int + num_workers: int + dataloader_type: str = "external" + task_encoder_seq_length: int = None + F_latents: int = 1 + H_latents: int = 64 + W_latents: int = 96 
+ patch_spatial: int = 2 + patch_temporal: int = 1 + number_packed_samples: int = 3 + context_seq_len: int = 512 + context_embeddings_dim: int = 1024 + + def __post_init__(self): + mock_ds = _MockDataset(length=1024) + self._train_dl = DataLoader( + mock_ds, + batch_size=self.micro_batch_size, + num_workers=self.num_workers, + collate_fn=lambda samples: mock_batch( + F_latents=self.F_latents, + H_latents=self.H_latents, + W_latents=self.W_latents, + patch_temporal=self.patch_temporal, + patch_spatial=self.patch_spatial, + number_packed_samples=self.number_packed_samples, + context_seq_len=self.context_seq_len, + context_embeddings_dim=self.context_embeddings_dim, + ), + shuffle=False, + drop_last=False, + ) + self.sequence_length = self.seq_length + + def build_datasets(self, _context: DatasetBuildContext): + if hasattr(self, "dataset"): + return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader() + return self._train_dl, self._train_dl, self._train_dl diff --git a/dfm/src/megatron/model/common/dit_attention.py b/dfm/src/megatron/model/common/dit_attention.py index 321e9b08..000e683d 100644 --- a/dfm/src/megatron/model/common/dit_attention.py +++ b/dfm/src/megatron/model/common/dit_attention.py @@ -100,7 +100,7 @@ def __init__( else: self.k_layernorm = None - def get_query_key_value_tensors(self, hidden_states, key_value_states=None, split_qkv=False): + def get_query_key_value_tensors(self, hidden_states, key_value_states=None, output_gate=None, split_qkv=True): """ Derives `query`, `key` and `value` tensors from `hidden_states`. """ @@ -251,13 +251,15 @@ def __init__( is_expert=False, ) - def get_query_key_value_tensors(self, hidden_states, key_value_states, split_qkv=False): + def get_query_key_value_tensors(self, hidden_states, key_value_states, output_gate=None, split_qkv=False): """ Derives `query` tensor from `hidden_states`, and `key`/`value` tensors from `key_value_states`. 
""" - query, key, value = super().get_query_key_value_tensors(hidden_states, key_value_states) + query, key, value = super().get_query_key_value_tensors( + hidden_states, key_value_states, output_gate=output_gate, split_qkv=split_qkv + ) # gather query and key heads across TP ranks if self.layernorm_across_heads is True if self.layernorm_across_heads and parallel_state.get_tensor_model_parallel_world_size() > 1: diff --git a/dfm/src/megatron/recipes/dit/dit.py b/dfm/src/megatron/recipes/dit/dit.py index 14c15e89..74c762cd 100644 --- a/dfm/src/megatron/recipes/dit/dit.py +++ b/dfm/src/megatron/recipes/dit/dit.py @@ -31,6 +31,7 @@ from megatron.core.distributed import DistributedDataParallelConfig from dfm.src.megatron.data.common.diffusion_energon_datamodule import DiffusionDataModuleConfig +from dfm.src.megatron.data.dit.dit_mock_datamodule import DiTMockDataModuleConfig from dfm.src.megatron.model.dit.dit_model_provider import DiTModelProvider @@ -158,6 +159,33 @@ def pretrain_config( precision_config.grad_reduce_in_fp32 = False + if mock: + dataset = DiTMockDataModuleConfig( + path=None, + seq_length=2048, + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + task_encoder_seq_length=8000, + packing_buffer_size=40, + num_workers=10, + # mock arguments + F_latents=1, + H_latents=96, + W_latents=64, + context_seq_len=512, + context_embeddings_dim=1024, + ) + else: + dataset = DiffusionDataModuleConfig( + path=dataset_path, + seq_length=2048, + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + task_encoder_seq_length=8000, + packing_buffer_size=40, + num_workers=10, + ) + # Config Container cfg = ConfigContainer( model=model_cfg, @@ -182,15 +210,7 @@ def pretrain_config( use_distributed_optimizer=True, use_megatron_fsdp=use_megatron_fsdp, # need use_distributed_optimizer=True ), - dataset=DiffusionDataModuleConfig( - path=dataset_path, - seq_length=2048, - task_encoder_seq_length=8000, - packing_buffer_size=40, - 
micro_batch_size=micro_batch_size, - global_batch_size=global_batch_size, - num_workers=10, - ), + dataset=dataset, logger=LoggerConfig( log_interval=10, tensorboard_dir=tensorboard_dir, diff --git a/examples/megatron/recipes/dit/pretrain_dit_model.py b/examples/megatron/recipes/dit/pretrain_dit_model.py index ae4dc54d..faf9f35c 100644 --- a/examples/megatron/recipes/dit/pretrain_dit_model.py +++ b/examples/megatron/recipes/dit/pretrain_dit_model.py @@ -86,6 +86,7 @@ def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]: description="Pretrain Llama3 8B model using Megatron-Bridge with YAML and CLI overrides", formatter_class=argparse.RawTextHelpFormatter, ) + parser.add_argument("--mock", action="store_true", help="Whether to use mock data.") parser.add_argument( "--config-file", type=str, @@ -139,7 +140,7 @@ def main() -> None: logger.info("------------------------------------------------------------------") # Load base configuration from the recipe as a Python dataclass - cfg: ConfigContainer = pretrain_config(dataset_path=args.dataset_path) + cfg: ConfigContainer = pretrain_config(dataset_path=args.dataset_path, mock=args.mock) logger.info("Loaded base configuration") # Print configuration on rank 0 From 0e41ab600b93529fdf400dfc6f42aea60573e71f Mon Sep 17 00:00:00 2001 From: sajadn Date: Sun, 16 Nov 2025 15:21:44 -0800 Subject: [PATCH 14/15] fix checkpoint loading issue. 
Signed-off-by: sajadn --- dfm/src/megatron/model/dit/dit_model.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/dfm/src/megatron/model/dit/dit_model.py b/dfm/src/megatron/model/dit/dit_model.py index e3ae8a29..46e8e742 100644 --- a/dfm/src/megatron/model/dit/dit_model.py +++ b/dfm/src/megatron/model/dit/dit_model.py @@ -287,6 +287,30 @@ def set_input_tensor(self, input_tensor: Tensor) -> None: assert len(input_tensor) == 1, "input_tensor should only be length 1 for gpt/bert" self.decoder.set_input_tensor(input_tensor[0]) + def load_state_dict(self, state_dict, strict=True): + """Load state dict with automatic handling of 'module.' prefix mismatch. + + This method handles the case where checkpoints saved with DistributedDataParallel + have a 'module.' prefix that needs to be removed when loading. + + Args: + state_dict (dict): The state dictionary to load + strict (bool): Whether to strictly enforce that the keys match + + Returns: + NamedTuple: with 'missing_keys' and 'unexpected_keys' fields + """ + # Check if state_dict has 'module.' prefix but model doesn't + has_module_prefix = any(k.startswith("module.") for k in state_dict.keys()) + if has_module_prefix: + new_state_dict = {} + for key, value in state_dict.items(): + new_key = key.replace("module.", "", 1) if key.startswith("module.") else key + new_state_dict[new_key] = value + state_dict = new_state_dict + + return super().load_state_dict(state_dict, strict=strict) + def sharded_state_dict( self, prefix: str = "module.", sharded_offsets: tuple = (), metadata: Optional[Dict] = None ) -> ShardedStateDict: From 33bfbbeb5102f720ba36a6a5d0dadad57bc2446f Mon Sep 17 00:00:00 2001 From: Abhinav Garg Date: Mon, 17 Nov 2025 08:00:43 +0000 Subject: [PATCH 15/15] Implement functional smoke tests for Mcore DiT pretrain and update test command in GPU mock tests. Added a new test file for DiT pretraining and modified the existing GPU test script to run all tests in the recipes directory. 
--- .../L2_Mcore_Mock_Tests_GPU.sh | 2 +- .../mcore/recipes/test_dit_pretrain.py | 86 +++++++++++++++++++ 2 files changed, 87 insertions(+), 1 deletion(-) create mode 100644 tests/functional_tests/mcore/recipes/test_dit_pretrain.py diff --git a/tests/functional_tests/L2_Mcore_Mock_Tests_GPU.sh b/tests/functional_tests/L2_Mcore_Mock_Tests_GPU.sh index 7977af26..477901da 100644 --- a/tests/functional_tests/L2_Mcore_Mock_Tests_GPU.sh +++ b/tests/functional_tests/L2_Mcore_Mock_Tests_GPU.sh @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -CUDA_VISIBLE_DEVICES="0,1" uv run --group megatron-bridge coverage run -a --data-file=/opt/DFM/.coverage --source=/opt/DFM/ -m pytest tests/functional_tests/mcore/recipes/test_wan_pretrain.py -m "not pleasefixme" --with_downloads -v +CUDA_VISIBLE_DEVICES="0,1" uv run --group megatron-bridge coverage run -a --data-file=/opt/DFM/.coverage --source=/opt/DFM/ -m pytest tests/functional_tests/mcore/recipes -m "not pleasefixme" --with_downloads -v diff --git a/tests/functional_tests/mcore/recipes/test_dit_pretrain.py b/tests/functional_tests/mcore/recipes/test_dit_pretrain.py new file mode 100644 index 00000000..87d5b0cf --- /dev/null +++ b/tests/functional_tests/mcore/recipes/test_dit_pretrain.py @@ -0,0 +1,86 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functional smoke tests for Mcore DiT pretrain mock runs.""" + +import os +import subprocess + +import pytest + + +class TestMcoreDiTPretrain: + """Test class for Mcore DiT pretrain functional tests.""" + + @pytest.mark.run_only_on("GPU") + def test_DiT_pretrain_mock(self, tmp_path): + """ + Functional test for DiT pretrain recipe with mock data. + + This test verifies that the DiT pretrain recipe can run successfully + in mock mode with minimal configuration, ensuring: + 1. The distributed training can start without errors + 2. Model initialization works correctly + 3. Forward/backward passes complete successfully + 4. The training loop executes without crashes + """ + # Set up temporary directories for dataset and checkpoints + dataset_path = os.path.join(tmp_path, "mock_dataset") + checkpoint_dir = os.path.join(tmp_path, "checkpoints") + os.makedirs(dataset_path, exist_ok=True) + os.makedirs(checkpoint_dir, exist_ok=True) + + # Build the command for the mock run + cmd = [ + "python", + "-m", + "torch.distributed.run", + "--nproc_per_node=2", + "examples/megatron/recipes/dit/pretrain_dit_model.py", + "train.train_iters=10", + "train.eval_iters=0", + "model.tensor_model_parallel_size=1", + "model.pipeline_model_parallel_size=1", + "model.context_parallel_size=1", + "model.qkv_format=thd", + "model.num_attention_heads=16", + "dataset.task_encoder_seq_length=4608", + "dataset.seq_length=4608", + "train.global_batch_size=2", + "train.micro_batch_size=1", + "--mock", + ] + + # Run the command with a timeout + try: + # Stream output in real-time instead of capturing it + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=1800, # 30 minute timeout + check=True, + ) + + # Print output for debugging if needed + print("STDOUT:", result.stdout) + print("STDERR:", result.stderr) + + # Basic verification that the run completed + assert 
result.returncode == 0, f"Command failed with return code {result.returncode}" + + except subprocess.TimeoutExpired: + pytest.fail("DiT pretrain mock run exceeded timeout of 1800 seconds (30 minutes)") + except subprocess.CalledProcessError as e: + pytest.fail(f"DiT pretrain mock run failed with return code {e.returncode}")