diff --git a/dfm/src/common/utils/torch_split_tensor_for_cp.py b/dfm/src/common/utils/torch_split_tensor_for_cp.py
index f389b7ba..6d72050b 100644
--- a/dfm/src/common/utils/torch_split_tensor_for_cp.py
+++ b/dfm/src/common/utils/torch_split_tensor_for_cp.py
@@ -12,12 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional
+
 import torch
+import transformer_engine_torch as tex
 from torch import Tensor
-from torch.distributed import ProcessGroup, all_gather, get_world_size
+from torch.distributed import ProcessGroup, all_gather, get_rank, get_world_size
 
 
-def cat_outputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor:
+def cat_outputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup, thd_cu_seqlens: Optional[Tensor] = None) -> Tensor:
     """
     Concatenates tensors from multiple processes along a specified dimension.
@@ -28,24 +31,46 @@ def cat_outputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor:
         x (Tensor): The input tensor to be gathered and concatenated.
         seq_dim (int): The dimension along which to concatenate the gathered tensors.
         cp_group (ProcessGroup): The process group containing all the processes involved in the gathering.
+        thd_cu_seqlens (Tensor, optional): THD cumulative sequence lengths used during partitioning. Provide
+            this to restore the original token order after gathering.
 
     Returns:
-        Tensor: A tensor resulting from the concatenation of tensors from all processes.
+        Tensor: A tensor resulting from the concatenation of tensors from all processes. If `thd_cu_seqlens`
+            is provided, the tensor is reordered to match the original (pre-partition) sequence order.
 
     Raises:
-        RuntimeError: If the gathering of tensors fails.
+        RuntimeError: If the gathered indices do not match the gathered tensor along the sequence dimension.
     """
     # Number of processes in the group
     world_size = get_world_size(cp_group)
-
     # List to hold tensors from each rank
     gathered_tensors = [torch.zeros_like(x) for _ in range(world_size)]
 
     # Attempt to gather tensors from all ranks
-    try:
-        all_gather(gathered_tensors, x, group=cp_group)
-    except RuntimeError as e:
-        raise RuntimeError(f"Gathering failed: {e}")
+    all_gather(gathered_tensors, x, group=cp_group)
 
     # Concatenate tensors along the specified dimension
-    return torch.cat(gathered_tensors, dim=seq_dim)
+    gathered = torch.cat(gathered_tensors, dim=seq_dim)
+
+    # Without THD metadata there is nothing to reorder; return the plain concatenation.
+    if thd_cu_seqlens is None:
+        return gathered
+
+    total_seq_len = int(thd_cu_seqlens[-1].item())
+    # Rebuild the global index ordering used during THD partitioning.
+    cp_rank = get_rank(cp_group)
+    local_indices = tex.thd_get_partitioned_indices(thd_cu_seqlens, total_seq_len, world_size, cp_rank).to(
+        device=x.device, dtype=torch.long
+    )
+
+    # Gather indices from all ranks to compute the inverse permutation.
+    gathered_indices = [torch.empty_like(local_indices) for _ in range(world_size)]
+    all_gather(gathered_indices, local_indices, group=cp_group)
+    global_indices = torch.cat(gathered_indices, dim=0)
+
+    if global_indices.numel() != gathered.size(seq_dim):
+        raise RuntimeError("Gathered indices size does not match gathered tensor along sequence dimension.")
+
+    restore_order = torch.argsort(global_indices, dim=0)
+    gathered = gathered.index_select(seq_dim, restore_order.to(device=gathered.device))
+    return gathered.contiguous()
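The reordering step above is a plain inverse permutation: every rank contributes the global token indices it owned, and `argsort` over the concatenated index list maps gathered positions back to their original order. A toy single-process sketch of the idea (illustrative values only):

```python
import torch

# Suppose the gather produced tokens in the order given by global_indices.
global_indices = torch.tensor([0, 3, 1, 2])        # original position of each gathered token
gathered = torch.tensor([10.0, 13.0, 11.0, 12.0])  # token "values" in gathered order

# argsort of the owned-index list is the inverse permutation.
restore_order = torch.argsort(global_indices)
restored = gathered.index_select(0, restore_order)
print(restored)  # tensor([10., 11., 12., 13.])
```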
diff --git a/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py b/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py
index f3fae0b0..a44f36dd 100644
--- a/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py
+++ b/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py
@@ -100,6 +100,7 @@ def cat(attr):
             __subflavors__=samples[0].__subflavors__,
             video=cat("video"),
             context_embeddings=cat("context_embeddings"),
+            context_mask=cat("context_mask"),
             loss_mask=cat("loss_mask"),
             seq_len_q=cat("seq_len_q"),
             seq_len_q_padded=cat("seq_len_q_padded"),
diff --git a/dfm/src/megatron/data/dit/dit_mock_datamodule.py b/dfm/src/megatron/data/dit/dit_mock_datamodule.py
new file mode 100644
index 00000000..22640292
--- /dev/null
+++ b/dfm/src/megatron/data/dit/dit_mock_datamodule.py
@@ -0,0 +1,163 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=C0115,C0116,C0301
+
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+from einops import rearrange
+from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider
+from torch.utils.data import DataLoader, Dataset
+
+from dfm.src.megatron.data.dit.dit_taskencoder import PosID3D
+
+
+pos_id_3d = PosID3D()
+
+
+class _MockDataset(Dataset):
+    def __init__(self, length: int):
+        self.length = max(int(length), 1)
+
+    def __len__(self) -> int:
+        return self.length
+
+    def __getitem__(self, idx: int) -> dict:
+        return {}
+
+
+def mock_batch(
+    F_latents: int,
+    H_latents: int,
+    W_latents: int,
+    patch_temporal: int,
+    patch_spatial: int,
+    number_packed_samples: int,
+    context_seq_len: int,
+    context_embeddings_dim: int,
+) -> dict:
+    # Set mock values for one video sample
+    video_latent = torch.randn(16, F_latents, H_latents, W_latents, dtype=torch.float32)
+    C, T, H, W = video_latent.shape
+    video_latent = rearrange(
+        video_latent,
+        "C (T pt) (H ph) (W pw) -> (T H W) (ph pw pt C)",
+        ph=patch_spatial,
+        pw=patch_spatial,
+        pt=patch_temporal,
+    )
+    video_latent = torch.as_tensor(video_latent, dtype=torch.bfloat16)
+
+    context_embeddings = torch.randn(context_seq_len, context_embeddings_dim, dtype=torch.bfloat16)
+    context_embeddings_seq_length = context_embeddings.shape[0]
+    context_embeddings_mask = torch.ones(context_embeddings_seq_length, dtype=torch.bfloat16)
+
+    pos_ids = rearrange(
+        pos_id_3d.get_pos_id_3d(t=T // patch_temporal, h=H // patch_spatial, w=W // patch_spatial),
+        "T H W d -> (T H W) d",
+    )
+
+    seq_len_q = video_latent.shape[0]
+    seq_len_q_padded = seq_len_q
+
+    loss_mask = torch.ones(seq_len_q, dtype=torch.bfloat16)
+
+    seq_len_kv = context_embeddings.shape[0]
+    seq_len_kv_padded = seq_len_kv
+
+    video_latents_packed = [video_latent for _ in range(number_packed_samples)]
+    video_latents_packed = torch.cat(video_latents_packed, dim=0)
+
+    context_embeddings_packed = [context_embeddings for _ in range(number_packed_samples)]
+    context_embeddings_packed = torch.cat(context_embeddings_packed, dim=0)
+
+    context_embeddings_mask_packed = [context_embeddings_mask for _ in range(number_packed_samples)]
+    context_embeddings_mask_packed = torch.cat(context_embeddings_mask_packed, dim=0)
+
+    loss_masks_packed = [loss_mask for _ in range(number_packed_samples)]
+    loss_masks_packed = torch.cat(loss_masks_packed, dim=0)
+
+    seq_len_q_packed = torch.tensor([seq_len_q for _ in range(number_packed_samples)], dtype=torch.int32)
+    seq_len_q_padded_packed = torch.tensor([seq_len_q_padded for _ in range(number_packed_samples)], dtype=torch.int32)
+
+    seq_len_kv_packed = torch.tensor([seq_len_kv for _ in range(number_packed_samples)], dtype=torch.int32)
+    seq_len_kv_padded_packed = torch.tensor(
+        [seq_len_kv_padded for _ in range(number_packed_samples)], dtype=torch.int32
+    )
+
+    pos_ids_packed = [pos_ids for _ in range(number_packed_samples)]
+    pos_ids_packed = torch.cat(pos_ids_packed, dim=0)
+
+    batch = dict(
+        video=video_latents_packed.unsqueeze(0),
+        context_embeddings=context_embeddings_packed.unsqueeze(0),
+        context_mask=context_embeddings_mask_packed.unsqueeze(0),
+        loss_mask=loss_masks_packed.unsqueeze(0),
+        seq_len_q=seq_len_q_packed,
+        seq_len_q_padded=seq_len_q_padded_packed,
+        seq_len_kv=seq_len_kv_packed,
+        seq_len_kv_padded=seq_len_kv_padded_packed,
+        latent_shape=torch.tensor([C, T, H, W], dtype=torch.int32),
+        pos_ids=pos_ids_packed,
+    )
+
+    return batch
+
+
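Each mock batch packs `number_packed_samples` copies of one sample along the token dimension (THD layout), so downstream code only needs the per-sample lengths to recover sequence boundaries. A minimal sketch of how those lengths become cumulative offsets, mirroring `encode_seq_length` in `dit_data_process.py` (the lengths are illustrative; with the recipe's mock settings each sample is 1 * (96/2) * (64/2) = 1536 tokens, and three packed samples give the 4608-token sequence used in the functional test):

```python
import torch

# Per-sample token counts for three packed samples.
seq_len_q = torch.tensor([1536, 1536, 1536], dtype=torch.int32)

# Prepend 0 and take a running sum: cu_seqlens[i]..cu_seqlens[i+1] spans sample i.
zero = torch.zeros(1, dtype=torch.int32)
cu_seqlens_q = torch.cat((zero, seq_len_q.cumsum(dim=0).to(torch.int32)))
print(cu_seqlens_q)  # tensor([   0, 1536, 3072, 4608], dtype=torch.int32)
```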
+@dataclass(kw_only=True)
+class DiTMockDataModuleConfig(DatasetProvider):
+    path: str = ""
+    seq_length: int
+    packing_buffer_size: int
+    micro_batch_size: int
+    global_batch_size: int
+    num_workers: int
+    dataloader_type: str = "external"
+    task_encoder_seq_length: Optional[int] = None
+    F_latents: int = 1
+    H_latents: int = 64
+    W_latents: int = 96
+    patch_spatial: int = 2
+    patch_temporal: int = 1
+    number_packed_samples: int = 3
+    context_seq_len: int = 512
+    context_embeddings_dim: int = 1024
+
+    def __post_init__(self):
+        mock_ds = _MockDataset(length=1024)
+        self._train_dl = DataLoader(
+            mock_ds,
+            batch_size=self.micro_batch_size,
+            num_workers=self.num_workers,
+            collate_fn=lambda samples: mock_batch(
+                F_latents=self.F_latents,
+                H_latents=self.H_latents,
+                W_latents=self.W_latents,
+                patch_temporal=self.patch_temporal,
+                patch_spatial=self.patch_spatial,
+                number_packed_samples=self.number_packed_samples,
+                context_seq_len=self.context_seq_len,
+                context_embeddings_dim=self.context_embeddings_dim,
+            ),
+            shuffle=False,
+            drop_last=False,
+        )
+        self.sequence_length = self.seq_length
+
+    def build_datasets(self, _context: DatasetBuildContext):
+        if hasattr(self, "dataset"):
+            return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader()
+        return self._train_dl, self._train_dl, self._train_dl
diff --git a/dfm/src/megatron/data/dit/dit_taskencoder.py b/dfm/src/megatron/data/dit/dit_taskencoder.py
index fe3e6180..f7553bce 100644
--- a/dfm/src/megatron/data/dit/dit_taskencoder.py
+++ b/dfm/src/megatron/data/dit/dit_taskencoder.py
@@ -94,7 +94,9 @@ def encode_sample(self, sample: dict) -> DiffusionSample:
             pw=self.patch_spatial,
             pt=self.patch_temporal,
         )
-        sample["pickle"] = sample["pickle"].cpu().float().numpy().squeeze(0)
+        sample["pickle"] = sample["pickle"].cpu().float().numpy()
+        if sample["pickle"].shape[0] == 1:
+            sample["pickle"] = sample["pickle"][0]
         if is_image:
             t5_text_embeddings = torch.from_numpy(sample["pickle"]).to(torch.bfloat16)
         else:
@@ -103,40 +105,30 @@ def encode_sample(self, sample: dict) -> DiffusionSample:
 
         if t5_text_embeddings_seq_length > self.text_embedding_padding_size:
             t5_text_embeddings = t5_text_embeddings[: self.text_embedding_padding_size]
-        else:
-            t5_text_embeddings = F.pad(
-                t5_text_embeddings,
-                (
-                    0,
-                    0,
-                    0,
-                    self.text_embedding_padding_size - t5_text_embeddings_seq_length,
-                ),
-            )
 
         t5_text_mask = torch.ones(t5_text_embeddings_seq_length, dtype=torch.bfloat16)
 
-        if is_image:
-            h, w = info["image_height"], info["image_width"]
-            fps = torch.tensor([30] * 1, dtype=torch.bfloat16)
-            num_frames = torch.tensor([1] * 1, dtype=torch.bfloat16)
-        else:
-            h, w = info["height"], info["width"]
-            fps = torch.tensor([info["framerate"]] * 1, dtype=torch.bfloat16)
-            num_frames = torch.tensor([info["num_frames"]] * 1, dtype=torch.bfloat16)
-        image_size = torch.tensor([[h, w, h, w]] * 1, dtype=torch.bfloat16)
-
         pos_ids = rearrange(
             pos_id_3d.get_pos_id_3d(t=T // self.patch_temporal, h=H // self.patch_spatial, w=W // self.patch_spatial),
             "T H W d -> (T H W) d",
         )
-        if self.packing_buffer_size is None:
-            pos_ids = F.pad(pos_ids, (0, 0, 0, self.seq_length - seq_len))
-            loss_mask = torch.zeros(self.seq_length, dtype=torch.bfloat16)
-            loss_mask[:seq_len] = 1
-            video_latent = F.pad(video_latent, (0, 0, 0, self.seq_length - seq_len))
-        else:
-            loss_mask = torch.ones(seq_len, dtype=torch.bfloat16)
+        loss_mask = torch.ones(seq_len, dtype=torch.bfloat16)
+        
sharding_factor = 64 + seq_len_q_padded = ((seq_len + sharding_factor - 1) // sharding_factor) * sharding_factor + seq_len_kv_padded = ( + (t5_text_embeddings_seq_length + sharding_factor - 1) // sharding_factor + ) * sharding_factor + + if seq_len < seq_len_q_padded: + video_latent = F.pad(video_latent, (0, 0, 0, seq_len_q_padded - seq_len)) + loss_mask = F.pad(loss_mask, (0, seq_len_q_padded - seq_len)) + pos_ids = F.pad(pos_ids, (0, 0, 0, seq_len_q_padded - seq_len)) + + if t5_text_embeddings_seq_length < seq_len_kv_padded: + t5_text_embeddings = F.pad( + t5_text_embeddings, (0, 0, 0, seq_len_kv_padded - t5_text_embeddings_seq_length) + ) + t5_text_mask = F.pad(t5_text_mask, (0, seq_len_kv_padded - t5_text_embeddings_seq_length)) return DiffusionSample( __key__=sample["__key__"], @@ -148,7 +140,9 @@ def encode_sample(self, sample: dict) -> DiffusionSample: context_mask=t5_text_mask, loss_mask=loss_mask, seq_len_q=torch.tensor([seq_len], dtype=torch.int32), - seq_len_kv=torch.tensor([self.text_embedding_padding_size], dtype=torch.int32), + seq_len_q_padded=torch.tensor([seq_len_q_padded], dtype=torch.int32), + seq_len_kv=torch.tensor([t5_text_embeddings_seq_length], dtype=torch.int32), + seq_len_kv_padded=torch.tensor([seq_len_kv_padded], dtype=torch.int32), pos_ids=pos_ids, latent_shape=torch.tensor([C, T, H, W], dtype=torch.int32), ) @@ -168,7 +162,9 @@ def batch(self, samples: List[DiffusionSample]) -> dict: context_mask=sample.context_mask.unsqueeze_(0) if sample.context_mask is not None else None, loss_mask=sample.loss_mask.unsqueeze_(0) if sample.loss_mask is not None else None, seq_len_q=sample.seq_len_q, + seq_len_q_padded=sample.seq_len_q_padded, seq_len_kv=sample.seq_len_kv, + seq_len_kv_padded=sample.seq_len_kv_padded, pos_ids=sample.pos_ids.unsqueeze_(0) if sample.pos_ids is not None else None, latent_shape=sample.latent_shape, ) diff --git a/dfm/src/megatron/model/dit/dit_data_process.py b/dfm/src/megatron/model/dit/dit_data_process.py index e9e9344c..e599581a 100644 --- a/dfm/src/megatron/model/dit/dit_data_process.py +++ b/dfm/src/megatron/model/dit/dit_data_process.py @@ -13,16 +13,18 @@ # limitations under the License. 
import torch +from megatron.core import parallel_state as ps from megatron.core.packed_seq_params import PackedSeqParams def dit_data_step(qkv_format, dataloader_iter): # import pdb;pdb.set_trace() batch = next(iter(dataloader_iter.iterable)) - batch = get_batch_on_this_cp_rank(batch) - batch = {k: v.to(device="cuda", non_blocking=True) if torch.is_tensor(v) else v for k, v in batch.items()} batch["is_preprocessed"] = True # assume data is preprocessed - return encode_seq_length(batch, format=qkv_format) + batch = {k: v.to(device="cuda", non_blocking=True) if torch.is_tensor(v) else v for k, v in batch.items()} + batch = encode_seq_length(batch, format=qkv_format) + batch = get_batch_on_this_cp_rank(batch) + return batch def encode_seq_length(batch, format): @@ -35,19 +37,25 @@ def encode_seq_length(batch, format): cu_seqlens_kv = batch["seq_len_kv"].cumsum(dim=0).to(torch.int32) cu_seqlens_kv = torch.cat((zero, cu_seqlens_kv)) + cu_seqlens_q_padded = batch["seq_len_q_padded"].cumsum(dim=0).to(torch.int32) + cu_seqlens_q_padded = torch.cat((zero, cu_seqlens_q_padded)) + + cu_seqlens_kv_padded = batch["seq_len_kv_padded"].cumsum(dim=0).to(torch.int32) + cu_seqlens_kv_padded = torch.cat((zero, cu_seqlens_kv_padded)) + batch["packed_seq_params"] = { "self_attention": PackedSeqParams( cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_q, - cu_seqlens_q_padded=None, - cu_seqlens_kv_padded=None, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_q_padded, qkv_format=format, ), "cross_attention": PackedSeqParams( cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, - cu_seqlens_q_padded=None, - cu_seqlens_kv_padded=None, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, qkv_format=format, ), } @@ -57,34 +65,26 @@ def encode_seq_length(batch, format): def get_batch_on_this_cp_rank(data): """Split the data for context parallelism.""" - from megatron.core import mpu + cp_size = ps.get_context_parallel_world_size() + if cp_size > 1: + import transformer_engine_torch as tex - cp_size = mpu.get_context_parallel_world_size() - cp_rank = mpu.get_context_parallel_rank() + cp_rank = ps.get_context_parallel_rank() + for key in ["video", "loss_mask", "pos_ids"]: + if data[key] is not None: + index = tex.thd_get_partitioned_indices( + data["packed_seq_params"]["self_attention"].cu_seqlens_q_padded, + data[key].size(1), + cp_size, + cp_rank, + ).to(device=data[key].device, dtype=torch.long) + data[key] = data[key].index_select(1, index).contiguous() - t = 16 - if cp_size > 1: - # cp split on seq_length, for video_latent, noise_latent and pos_ids - assert t % cp_size == 0, "t must divisibly by cp_size" - num_valid_tokens_in_ub = None - if "loss_mask" in data and data["loss_mask"] is not None: - num_valid_tokens_in_ub = data["loss_mask"].sum() - - for key, value in data.items(): - if (value is not None) and (key in ["video", "video_latent", "noise_latent", "pos_ids"]): - if len(value.shape) > 5: - value = value.squeeze(0) - B, C, T, H, W = value.shape - if T % cp_size == 0: - # FIXME packed sequencing - data[key] = value.view(B, C, cp_size, T // cp_size, H, W)[:, :, cp_rank, ...].contiguous() - else: - # FIXME packed sequencing - data[key] = value.view(B, C, T, cp_size, H // cp_size, W)[:, :, :, cp_rank, ...].contiguous() - loss_mask = data["loss_mask"] - data["loss_mask"] = loss_mask.view(loss_mask.shape[0], cp_size, loss_mask.shape[1] // cp_size)[ - :, cp_rank, ... 
-        ].contiguous()
-        data["num_valid_tokens_in_ub"] = num_valid_tokens_in_ub
+        for key in ["context_embeddings", "context_mask"]:
+            if data[key] is not None:
+                index = tex.thd_get_partitioned_indices(
+                    # Use the padded offsets: the context tensors were padded to
+                    # seq_len_kv_padded in the task encoder, matching the self-attention branch.
+                    data["packed_seq_params"]["cross_attention"].cu_seqlens_kv_padded, data[key].size(1), cp_size, cp_rank
+                ).to(device=data[key].device, dtype=torch.long)
+                data[key] = data[key].index_select(1, index).contiguous()
 
     return data
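For orientation, `tex.thd_get_partitioned_indices` implements Transformer Engine's load-balanced THD split: each packed sequence is cut into `2 * cp_size` chunks, and rank `r` keeps chunks `r` and `2 * cp_size - 1 - r`, so every rank sees both early and late positions of every sequence. A rough standalone sketch of that index computation, assuming chunk-divisible padded lengths (the real kernel handles the general case on device):

```python
import torch

def thd_partition_indices_sketch(cu_seqlens: torch.Tensor, cp_size: int, cp_rank: int) -> torch.Tensor:
    # For every packed sequence, keep the cp_rank-th chunk and its mirror
    # from the end, giving each rank a balanced mix of positions.
    indices = []
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        chunk = (end - start) // (2 * cp_size)
        for c in (cp_rank, 2 * cp_size - 1 - cp_rank):
            indices.append(torch.arange(start + c * chunk, start + (c + 1) * chunk))
    return torch.cat(indices)

cu = torch.tensor([0, 8, 16], dtype=torch.int32)  # two packed sequences of 8 tokens
print(thd_partition_indices_sketch(cu, cp_size=2, cp_rank=0))
# tensor([ 0,  1,  6,  7,  8,  9, 14, 15])
```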
diff --git a/dfm/src/megatron/model/dit/dit_layer_spec.py b/dfm/src/megatron/model/dit/dit_layer_spec.py
index 2b19be2c..97afaf1d 100644
--- a/dfm/src/megatron/model/dit/dit_layer_spec.py
+++ b/dfm/src/megatron/model/dit/dit_layer_spec.py
@@ -144,17 +144,10 @@ def _replace_no_cp_submodules(submodules):
-        # Override Cross Attention to disable CP.
-        # Disable TP Comm overlap as well. Not disabling will attempt re-use of buffer size same as Q and lead to
-        # incorrect tensor shapes.
-        if submodules.cross_attention != IdentityOp:
-            cp_override_config = copy.deepcopy(config)
-            cp_override_config.context_parallel_size = 1
-            cp_override_config.tp_comm_overlap = False
-            self.cross_attention = build_module(
-                submodules.cross_attention,
-                config=cp_override_config,
-                layer_number=layer_number,
-            )
-        else:
-            self.cross_attention = None
+        # Cross attention is built with the shared layer config; CP and TP comm
+        # overlap are no longer disabled for it.
+        self.cross_attention = build_module(
+            submodules.cross_attention,
+            config=self.config,
+            layer_number=layer_number,
+        )
 
         self.full_self_attention = build_module(
             submodules.full_self_attention,
diff --git a/dfm/src/megatron/model/dit/dit_model.py b/dfm/src/megatron/model/dit/dit_model.py
index 38cb8422..55f89054 100644
--- a/dfm/src/megatron/model/dit/dit_model.py
+++ b/dfm/src/megatron/model/dit/dit_model.py
@@ -286,6 +286,30 @@ def set_input_tensor(self, input_tensor: Tensor) -> None:
         assert len(input_tensor) == 1, "input_tensor should only be length 1 for gpt/bert"
         self.decoder.set_input_tensor(input_tensor[0])
 
+    def load_state_dict(self, state_dict, strict=True):
+        """Load state dict with automatic handling of 'module.' prefix mismatch.
+
+        This method handles the case where checkpoints saved with DistributedDataParallel
+        have a 'module.' prefix that needs to be removed when loading.
+
+        Args:
+            state_dict (dict): The state dictionary to load
+            strict (bool): Whether to strictly enforce that the keys match
+
+        Returns:
+            NamedTuple: with 'missing_keys' and 'unexpected_keys' fields
+        """
+        # Check if state_dict has 'module.' prefix but model doesn't
+        has_module_prefix = any(k.startswith("module.") for k in state_dict.keys())
+        if has_module_prefix:
+            new_state_dict = {}
+            for key, value in state_dict.items():
+                new_key = key.replace("module.", "", 1) if key.startswith("module.") else key
+                new_state_dict[new_key] = value
+            state_dict = new_state_dict
+
+        return super().load_state_dict(state_dict, strict=strict)
+
     def sharded_state_dict(
         self, prefix: str = "module.", sharded_offsets: tuple = (), metadata: Optional[Dict] = None
     ) -> ShardedStateDict:
diff --git a/dfm/src/megatron/model/dit/dit_step.py b/dfm/src/megatron/model/dit/dit_step.py
index 9bde05d5..600bd82a 100644
--- a/dfm/src/megatron/model/dit/dit_step.py
+++ b/dfm/src/megatron/model/dit/dit_step.py
@@ -18,6 +18,7 @@
 from typing import Iterable
 
 import torch
+import wandb
 from einops import rearrange
 from megatron.bridge.training.losses import masked_next_token_loss
 from megatron.bridge.training.state import GlobalState
@@ -41,7 +42,7 @@ def __init__(self):
         self.train = True
         self.validation_step = 0
 
-    def on_validation_start(self, batch, model, step):
+    def on_validation_start(self, state, batch, model, step):
         C, T, H, W = batch["latent_shape"][0]
         latent = self.diffusion_pipeline.generate_samples_from_batch(
             model,
@@ -81,6 +82,35 @@ def on_validation_start(self, batch, model, step):
             video_save_path=f"{image_folder}/validation_step={step}_rank={rank}.mp4",
         )
 
+        is_last_dp_rank = parallel_state.get_data_parallel_rank() == (
+            parallel_state.get_data_parallel_world_size() - 1
+        )
+
+        last_dp_local_rank = parallel_state.get_data_parallel_world_size() - 1
+        dp_group = parallel_state.get_data_parallel_group()
+        dp_ranks = torch.distributed.get_process_group_ranks(dp_group)
+        wandb_rank = dp_ranks[last_dp_local_rank]
+
+        if is_last_dp_rank:
+            gather_list = [None for _ in range(parallel_state.get_data_parallel_world_size())]
+        else:
+            gather_list = None
+
+        torch.distributed.gather_object(
+            obj=decoded_video[0],
+            object_gather_list=gather_list,
+            dst=wandb_rank,
+            group=parallel_state.get_data_parallel_group(),
+        )
+        # gather_list is populated only on the destination rank.
+        if is_last_dp_rank and gather_list is not None:
+            videos = []
+            for video_data in gather_list:
+                video_data_transposed = video_data.transpose(0, 3, 1, 2)
+                videos.append(wandb.Video(video_data_transposed, fps=24, format="mp4"))
+
+            if state.wandb_logger is not None:
+                state.wandb_logger.log({"prediction": videos})
+
     def __call__(
         self, state: GlobalState, data_iterator: Iterable, model: GPTModel, return_schedule_plan: bool = False
     ) -> tuple[torch.Tensor, partial]:
@@ -103,7 +133,7 @@ def __call__(
             self.train = False
             self.valid = True
             self.validation_step += 1
-            self.on_validation_start(batch, model, step=self.validation_step)
+            self.on_validation_start(state, batch, model, step=self.validation_step)
         return self.forward_step(state, batch, model, return_schedule_plan)
 
     def data_process(
diff --git a/dfm/src/megatron/model/dit/edm/edm_pipeline.py b/dfm/src/megatron/model/dit/edm/edm_pipeline.py
index 4f2f8ece..dc4a6aba 100644
--- a/dfm/src/megatron/model/dit/edm/edm_pipeline.py
+++ b/dfm/src/megatron/model/dit/edm/edm_pipeline.py
@@ -360,18 +360,15 @@ def generate_samples_from_batch(
         if self._noise_generator is None:
             self._initialize_generators()
         x0_fn = self.get_x0_fn_from_batch(data_batch, guidance, is_negative_prompt=is_negative_prompt)
-
-        state_shape = list(state_shape)
-        state_shape[1] //= parallel_state.get_context_parallel_world_size()
         x_sigma_max = (
            torch.randn(state_shape, **self.tensor_kwargs, generator=self._noise_generator) * 
self.sde.sigma_max ) samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max) - if cp_enabled: cp_group = parallel_state.get_context_parallel_group() - samples = cat_outputs_cp(samples, seq_dim=2, cp_group=cp_group) + thd_cu_seqlen_q_padded = data_batch["packed_seq_params"]["self_attention"].cu_seqlens_q_padded + samples = cat_outputs_cp(samples, seq_dim=1, cp_group=cp_group, thd_cu_seqlens=thd_cu_seqlen_q_padded) return samples diff --git a/dfm/src/megatron/recipes/dit/dit.py b/dfm/src/megatron/recipes/dit/dit.py index 14c15e89..74c762cd 100644 --- a/dfm/src/megatron/recipes/dit/dit.py +++ b/dfm/src/megatron/recipes/dit/dit.py @@ -31,6 +31,7 @@ from megatron.core.distributed import DistributedDataParallelConfig from dfm.src.megatron.data.common.diffusion_energon_datamodule import DiffusionDataModuleConfig +from dfm.src.megatron.data.dit.dit_mock_datamodule import DiTMockDataModuleConfig from dfm.src.megatron.model.dit.dit_model_provider import DiTModelProvider @@ -158,6 +159,33 @@ def pretrain_config( precision_config.grad_reduce_in_fp32 = False + if mock: + dataset = DiTMockDataModuleConfig( + path=None, + seq_length=2048, + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + task_encoder_seq_length=8000, + packing_buffer_size=40, + num_workers=10, + # mock arguments + F_latents=1, + H_latents=96, + W_latents=64, + context_seq_len=512, + context_embeddings_dim=1024, + ) + else: + dataset = DiffusionDataModuleConfig( + path=dataset_path, + seq_length=2048, + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + task_encoder_seq_length=8000, + packing_buffer_size=40, + num_workers=10, + ) + # Config Container cfg = ConfigContainer( model=model_cfg, @@ -182,15 +210,7 @@ def pretrain_config( use_distributed_optimizer=True, use_megatron_fsdp=use_megatron_fsdp, # need use_distributed_optimizer=True ), - dataset=DiffusionDataModuleConfig( - path=dataset_path, - seq_length=2048, - task_encoder_seq_length=8000, - packing_buffer_size=40, - micro_batch_size=micro_batch_size, - global_batch_size=global_batch_size, - num_workers=10, - ), + dataset=dataset, logger=LoggerConfig( log_interval=10, tensorboard_dir=tensorboard_dir, diff --git a/examples/megatron/recipes/dit/pretrain_dit_model.py b/examples/megatron/recipes/dit/pretrain_dit_model.py index ae4dc54d..faf9f35c 100644 --- a/examples/megatron/recipes/dit/pretrain_dit_model.py +++ b/examples/megatron/recipes/dit/pretrain_dit_model.py @@ -86,6 +86,7 @@ def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]: description="Pretrain Llama3 8B model using Megatron-Bridge with YAML and CLI overrides", formatter_class=argparse.RawTextHelpFormatter, ) + parser.add_argument("--mock", action="store_true", help="Whether to use mock data.") parser.add_argument( "--config-file", type=str, @@ -139,7 +140,7 @@ def main() -> None: logger.info("------------------------------------------------------------------") # Load base configuration from the recipe as a Python dataclass - cfg: ConfigContainer = pretrain_config(dataset_path=args.dataset_path) + cfg: ConfigContainer = pretrain_config(dataset_path=args.dataset_path, mock=args.mock) logger.info("Loaded base configuration") # Print configuration on rank 0 diff --git a/tests/functional_tests/L2_Mcore_Mock_Tests_GPU.sh b/tests/functional_tests/L2_Mcore_Mock_Tests_GPU.sh index 7977af26..477901da 100644 --- a/tests/functional_tests/L2_Mcore_Mock_Tests_GPU.sh +++ b/tests/functional_tests/L2_Mcore_Mock_Tests_GPU.sh @@ -11,4 +11,4 @@ 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -CUDA_VISIBLE_DEVICES="0,1" uv run --group megatron-bridge coverage run -a --data-file=/opt/DFM/.coverage --source=/opt/DFM/ -m pytest tests/functional_tests/mcore/recipes/test_wan_pretrain.py -m "not pleasefixme" --with_downloads -v +CUDA_VISIBLE_DEVICES="0,1" uv run --group megatron-bridge coverage run -a --data-file=/opt/DFM/.coverage --source=/opt/DFM/ -m pytest tests/functional_tests/mcore/recipes -m "not pleasefixme" --with_downloads -v diff --git a/tests/functional_tests/mcore/recipes/test_dit_pretrain.py b/tests/functional_tests/mcore/recipes/test_dit_pretrain.py new file mode 100644 index 00000000..87d5b0cf --- /dev/null +++ b/tests/functional_tests/mcore/recipes/test_dit_pretrain.py @@ -0,0 +1,86 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functional smoke tests for Mcore DiT pretrain mock runs.""" + +import os +import subprocess + +import pytest + + +class TestMcoreDiTPretrain: + """Test class for Mcore DiT pretrain functional tests.""" + + @pytest.mark.run_only_on("GPU") + def test_DiT_pretrain_mock(self, tmp_path): + """ + Functional test for DiT pretrain recipe with mock data. + + This test verifies that the DiT pretrain recipe can run successfully + in mock mode with minimal configuration, ensuring: + 1. The distributed training can start without errors + 2. Model initialization works correctly + 3. Forward/backward passes complete successfully + 4. 
The training loop executes without crashes
+        """
+        # Set up temporary directories for dataset and checkpoints
+        dataset_path = os.path.join(tmp_path, "mock_dataset")
+        checkpoint_dir = os.path.join(tmp_path, "checkpoints")
+        os.makedirs(dataset_path, exist_ok=True)
+        os.makedirs(checkpoint_dir, exist_ok=True)
+
+        # Build the command for the mock run
+        cmd = [
+            "python",
+            "-m",
+            "torch.distributed.run",
+            "--nproc_per_node=2",
+            "examples/megatron/recipes/dit/pretrain_dit_model.py",
+            "train.train_iters=10",
+            "train.eval_iters=0",
+            "model.tensor_model_parallel_size=1",
+            "model.pipeline_model_parallel_size=1",
+            "model.context_parallel_size=1",
+            "model.qkv_format=thd",
+            "model.num_attention_heads=16",
+            "dataset.task_encoder_seq_length=4608",
+            "dataset.seq_length=4608",
+            "train.global_batch_size=2",
+            "train.micro_batch_size=1",
+            "--mock",
+        ]
+
+        # Run the command with a timeout
+        try:
+            # Capture output so it can be printed for debugging after the run
+            result = subprocess.run(
+                cmd,
+                capture_output=True,
+                text=True,
+                timeout=1800,  # 30 minute timeout
+                check=True,
+            )
+
+            # Print output for debugging if needed
+            print("STDOUT:", result.stdout)
+            print("STDERR:", result.stderr)
+
+            # Basic verification that the run completed (check=True already raises on failure)
+            assert result.returncode == 0, f"Command failed with return code {result.returncode}"
+
+        except subprocess.TimeoutExpired:
+            pytest.fail("DiT pretrain mock run exceeded timeout of 1800 seconds (30 minutes)")
+        except subprocess.CalledProcessError as e:
+            pytest.fail(f"DiT pretrain mock run failed with return code {e.returncode}")
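One behavior worth knowing when reading the test: with `check=True`, `subprocess.run` raises `CalledProcessError` on any nonzero exit, so control reaches the final assert only on success. A small standalone demonstration (assumes `python` is on PATH, as the test itself does):

```python
import subprocess

try:
    subprocess.run(["python", "-c", "raise SystemExit(3)"], check=True)
except subprocess.CalledProcessError as e:
    # The nonzero exit becomes an exception before any post-run assert would execute.
    print("caught exit code:", e.returncode)  # caught exit code: 3
```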