Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
75e8921
Fix sequence padding for DiT. Add support for DiT Context Parallel wi…
sajadn Nov 13, 2025
38d2109
Enhance DiT and Wan layer specifications
abhinavg4 Nov 13, 2025
3d9bd13
Implement ProcessGroupCollection initialization in DiT and Wan models
abhinavg4 Nov 14, 2025
da3b3a2
Merge branch 'main' into megatron_fixes
abhinavg4 Nov 14, 2025
5824116
Update CONTRIBUTING.md to include detailed setup instructions for dev…
abhinavg4 Nov 14, 2025
0685906
Refactor import statements in dit_model.py to streamline dependencies…
abhinavg4 Nov 14, 2025
d2a7c6f
Refactor code style in DiT and Wan models
abhinavg4 Nov 14, 2025
471811f
Revert M4 changes
abhinavg4 Nov 15, 2025
f0a928b
Ruff
abhinavg4 Nov 15, 2025
f0aa573
Ruff
abhinavg4 Nov 15, 2025
e5c6b5b
Lint
abhinavg4 Nov 15, 2025
9af20cd
Merge branch 'main' into fix_dit_cp
abhinavg4 Nov 16, 2025
4be3eaf
Merge branch 'megatron_fixes' into fix_dit_cp
abhinavg4 Nov 16, 2025
344898f
Fix sequence padding for DiT. Add support for DiT Context Parallel wi…
sajadn Nov 13, 2025
34a8c50
fix cp inference. add cu_seqlen_kv_padded which was missing.
sajadn Nov 16, 2025
71ef76a
Add mock DiT dataset. Make DiT attention compatible with megatron bri…
sajadn Nov 16, 2025
0e41ab6
fix checkpoint loading issue.
sajadn Nov 16, 2025
27224d6
Merge branch 'fix_dit_cp' of github.com:NVIDIA-NeMo/DFM into fix_dit_cp
abhinavg4 Nov 17, 2025
6504613
Merge remote-tracking branch 'origin/main' into fix_dit_cp
abhinavg4 Nov 17, 2025
33bfbbe
Implement functional smoke tests for Mcore DiT pretrain and update te…
abhinavg4 Nov 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions dfm/src/common/utils/torch_split_tensor_for_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
import transformer_engine_torch as tex
from torch import Tensor
from torch.distributed import ProcessGroup, all_gather, get_world_size
from torch.distributed import ProcessGroup, all_gather, get_rank, get_world_size


def cat_outputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup, thd_cu_seqlens: Optional[Tensor] = None) -> Tensor:
    """
    Concatenates tensors from multiple processes along a specified dimension.

    Args:
        x (Tensor): The input tensor to be gathered and concatenated.
        seq_dim (int): The dimension along which to concatenate the gathered tensors.
        cp_group (ProcessGroup): The process group containing all the processes involved in the gathering.
        thd_cu_seqlens (Tensor, optional): THD cumulative sequence lengths used during partitioning. Provide
            this to restore the original token order after gathering.

    Returns:
        Tensor: A tensor resulting from the concatenation of tensors from all processes. If `thd_cu_seqlens`
        is provided, the tensor is reordered to match the original (pre-partition) sequence order.

    Raises:
        RuntimeError: If the gathered index permutation does not match the gathered tensor size.
    """
    # Number of processes in the group
    world_size = get_world_size(cp_group)

    # List to hold tensors from each rank
    gathered_tensors = [torch.zeros_like(x) for _ in range(world_size)]

    all_gather(gathered_tensors, x, group=cp_group)

    # Concatenate tensors along the specified dimension
    gathered = torch.cat(gathered_tensors, dim=seq_dim)

    # Without THD metadata the rank-order concatenation is already the final layout.
    # (Previously this path dereferenced thd_cu_seqlens unconditionally and crashed
    # with TypeError whenever the optional argument was omitted.)
    if thd_cu_seqlens is None:
        return gathered

    total_seq_len = int(thd_cu_seqlens[-1].item())
    # Rebuild the global index ordering used during THD partitioning.
    cp_rank = get_rank(cp_group)
    local_indices = tex.thd_get_partitioned_indices(thd_cu_seqlens, total_seq_len, world_size, cp_rank).to(
        device=x.device, dtype=torch.long
    )

    # Gather indices from all ranks to compute the inverse permutation.
    gathered_indices = [torch.empty_like(local_indices) for _ in range(world_size)]
    all_gather(gathered_indices, local_indices, group=cp_group)
    global_indices = torch.cat(gathered_indices, dim=0)

    if global_indices.numel() != gathered.size(seq_dim):
        raise RuntimeError("Gathered indices size does not match gathered tensor along sequence dimension.")

    # argsort of the partition indices is the inverse permutation back to original order.
    restore_order = torch.argsort(global_indices, dim=0)
    gathered = gathered.index_select(seq_dim, restore_order.to(device=gathered.device))
    return gathered.contiguous()
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def cat(attr):
__subflavors__=samples[0].__subflavors__,
video=cat("video"),
context_embeddings=cat("context_embeddings"),
context_mask=cat("context_mask"),
loss_mask=cat("loss_mask"),
seq_len_q=cat("seq_len_q"),
seq_len_q_padded=cat("seq_len_q_padded"),
Expand Down
165 changes: 165 additions & 0 deletions dfm/src/megatron/data/dit/dit_mock_datamodule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=C0115,C0116,C0301

from dataclasses import dataclass
from typing import Optional

import torch
from einops import rearrange
from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider
from torch.utils.data import DataLoader, Dataset

from dfm.src.megatron.data.dit.dit_taskencoder import PosID3D


pos_id_3d = PosID3D()


class _MockDataset(Dataset):
def __init__(self, length: int):
self.length = max(int(length), 1)

def __len__(self) -> int:
return self.length

def __getitem__(self, idx: int) -> dict:
return {}


def mock_batch(
    F_latents: int,
    H_latents: int,
    W_latents: int,
    patch_temporal: int,
    patch_spatial: int,
    number_packed_samples: int,
    context_seq_len: int,
    context_embeddings_dim: int,
) -> dict:
    """Fabricate one packed DiT batch (micro-batch of 1) from random data.

    Builds a single mock video/text sample, patchifies the video latent,
    then packs `number_packed_samples` copies along the sequence dimension,
    mirroring the tensor layout produced by the real DiT task encoder.

    Args:
        F_latents: Number of latent frames (temporal extent before patching).
        H_latents: Latent height.
        W_latents: Latent width.
        patch_temporal: Temporal patch size.
        patch_spatial: Spatial patch size.
        number_packed_samples: How many copies of the sample to pack together.
        context_seq_len: Text-context sequence length.
        context_embeddings_dim: Text-context embedding dimension.

    Returns:
        dict with batched (leading dim 1) video/context tensors, masks,
        per-sample sequence lengths, latent shape, and 3D position ids.
    """

    def _pack(t: torch.Tensor) -> torch.Tensor:
        # Concatenate `number_packed_samples` copies along the sequence dim.
        return torch.cat([t] * number_packed_samples, dim=0)

    # set mock values for one video sample (16 latent channels)
    video_latent = torch.randn(16, F_latents, H_latents, W_latents, dtype=torch.float32)
    C, T, H, W = video_latent.shape
    video_latent = rearrange(
        video_latent,
        "C (T pt) (H ph) (W pw) -> (T H W) (ph pw pt C)",
        ph=patch_spatial,
        pw=patch_spatial,
        pt=patch_temporal,
    )
    video_latent = video_latent.to(torch.bfloat16)

    context_embeddings = torch.randn(context_seq_len, context_embeddings_dim, dtype=torch.bfloat16)
    context_embeddings_seq_length = context_embeddings.shape[0]
    context_embeddings_mask = torch.ones(context_embeddings_seq_length, dtype=torch.bfloat16)

    pos_ids = rearrange(
        pos_id_3d.get_pos_id_3d(t=T // patch_temporal, h=H // patch_spatial, w=W // patch_spatial),
        "T H W d -> (T H W) d",
    )

    # Mock samples need no padding: padded lengths equal the raw lengths.
    seq_len_q = video_latent.shape[0]
    seq_len_q_padded = seq_len_q

    loss_mask = torch.ones(seq_len_q, dtype=torch.bfloat16)

    seq_len_kv = context_embeddings.shape[0]
    seq_len_kv_padded = seq_len_kv

    video_latents_packed = _pack(video_latent)
    # NOTE: previously computed twice with identical results; packed once here.
    context_embeddings_packed = _pack(context_embeddings)
    context_embeddings_mask_packed = _pack(context_embeddings_mask)
    loss_masks_packed = _pack(loss_mask)
    pos_ids_packed = _pack(pos_ids)

    seq_len_q_packed = torch.tensor([seq_len_q] * number_packed_samples, dtype=torch.int32)
    seq_len_q_padded_packed = torch.tensor([seq_len_q_padded] * number_packed_samples, dtype=torch.int32)

    seq_len_kv_packed = torch.tensor([seq_len_kv] * number_packed_samples, dtype=torch.int32)
    seq_len_kv_padded_packed = torch.tensor([seq_len_kv_padded] * number_packed_samples, dtype=torch.int32)

    batch = dict(
        video=video_latents_packed.unsqueeze(0),
        context_embeddings=context_embeddings_packed.unsqueeze(0),
        context_mask=context_embeddings_mask_packed.unsqueeze(0),
        loss_mask=loss_masks_packed.unsqueeze(0),
        seq_len_q=seq_len_q_packed,
        seq_len_q_padded=seq_len_q_padded_packed,
        seq_len_kv=seq_len_kv_packed,
        seq_len_kv_padded=seq_len_kv_padded_packed,
        latent_shape=torch.tensor([C, T, H, W], dtype=torch.int32),
        pos_ids=pos_ids_packed,
    )

    return batch


@dataclass(kw_only=True)
class DiTMockDataModuleConfig(DatasetProvider):
    """Dataset provider serving synthetic DiT batches for smoke tests.

    The underlying dataset items are empty; every batch is fabricated by
    `mock_batch` inside the DataLoader's collate_fn, so no data files are
    needed. Fields mirror the real datamodule config for interface parity.
    """

    path: str = ""  # unused by the mock; kept for config-interface parity
    seq_length: int
    packing_buffer_size: int
    micro_batch_size: int
    global_batch_size: int
    num_workers: int
    dataloader_type: str = "external"
    # Fixed annotation: default is None, so the type must be Optional[int].
    task_encoder_seq_length: Optional[int] = None
    F_latents: int = 1
    H_latents: int = 64
    W_latents: int = 96
    patch_spatial: int = 2
    patch_temporal: int = 1
    number_packed_samples: int = 3
    context_seq_len: int = 512
    context_embeddings_dim: int = 1024

    def __post_init__(self):
        """Build the mock train dataloader once all fields are populated."""
        mock_ds = _MockDataset(length=1024)
        # NOTE(review): a lambda collate_fn is not picklable; with num_workers > 0
        # this only works under the 'fork' start method — confirm on target platforms.
        self._train_dl = DataLoader(
            mock_ds,
            batch_size=self.micro_batch_size,
            num_workers=self.num_workers,
            collate_fn=lambda samples: mock_batch(
                F_latents=self.F_latents,
                H_latents=self.H_latents,
                W_latents=self.W_latents,
                patch_temporal=self.patch_temporal,
                patch_spatial=self.patch_spatial,
                number_packed_samples=self.number_packed_samples,
                context_seq_len=self.context_seq_len,
                context_embeddings_dim=self.context_embeddings_dim,
            ),
            shuffle=False,
            drop_last=False,
        )
        self.sequence_length = self.seq_length

    def build_datasets(self, _context: DatasetBuildContext):
        """Return (train, valid, test) dataloaders — the same mock loader for all three."""
        # Escape hatch: if a real dataset was attached externally, defer to it.
        if hasattr(self, "dataset"):
            return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader()
        return self._train_dl, self._train_dl, self._train_dl
54 changes: 25 additions & 29 deletions dfm/src/megatron/data/dit/dit_taskencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ def encode_sample(self, sample: dict) -> DiffusionSample:
pw=self.patch_spatial,
pt=self.patch_temporal,
)
sample["pickle"] = sample["pickle"].cpu().float().numpy().squeeze(0)
sample["pickle"] = sample["pickle"].cpu().float().numpy()
if sample["pickle"].shape[0] == 1:
sample["pickle"] = sample["pickle"][0]
if is_image:
t5_text_embeddings = torch.from_numpy(sample["pickle"]).to(torch.bfloat16)
else:
Expand All @@ -103,40 +105,30 @@ def encode_sample(self, sample: dict) -> DiffusionSample:

if t5_text_embeddings_seq_length > self.text_embedding_padding_size:
t5_text_embeddings = t5_text_embeddings[: self.text_embedding_padding_size]
else:
t5_text_embeddings = F.pad(
t5_text_embeddings,
(
0,
0,
0,
self.text_embedding_padding_size - t5_text_embeddings_seq_length,
),
)
t5_text_mask = torch.ones(t5_text_embeddings_seq_length, dtype=torch.bfloat16)

if is_image:
h, w = info["image_height"], info["image_width"]
fps = torch.tensor([30] * 1, dtype=torch.bfloat16)
num_frames = torch.tensor([1] * 1, dtype=torch.bfloat16)
else:
h, w = info["height"], info["width"]
fps = torch.tensor([info["framerate"]] * 1, dtype=torch.bfloat16)
num_frames = torch.tensor([info["num_frames"]] * 1, dtype=torch.bfloat16)
image_size = torch.tensor([[h, w, h, w]] * 1, dtype=torch.bfloat16)

pos_ids = rearrange(
pos_id_3d.get_pos_id_3d(t=T // self.patch_temporal, h=H // self.patch_spatial, w=W // self.patch_spatial),
"T H W d -> (T H W) d",
)

if self.packing_buffer_size is None:
pos_ids = F.pad(pos_ids, (0, 0, 0, self.seq_length - seq_len))
loss_mask = torch.zeros(self.seq_length, dtype=torch.bfloat16)
loss_mask[:seq_len] = 1
video_latent = F.pad(video_latent, (0, 0, 0, self.seq_length - seq_len))
else:
loss_mask = torch.ones(seq_len, dtype=torch.bfloat16)
loss_mask = torch.ones(seq_len, dtype=torch.bfloat16)
sharding_factor = 64
seq_len_q_padded = ((seq_len + sharding_factor - 1) // sharding_factor) * sharding_factor
seq_len_kv_padded = (
(t5_text_embeddings_seq_length + sharding_factor - 1) // sharding_factor
) * sharding_factor

if seq_len < seq_len_q_padded:
video_latent = F.pad(video_latent, (0, 0, 0, seq_len_q_padded - seq_len))
loss_mask = F.pad(loss_mask, (0, seq_len_q_padded - seq_len))
pos_ids = F.pad(pos_ids, (0, 0, 0, seq_len_q_padded - seq_len))

if t5_text_embeddings_seq_length < seq_len_kv_padded:
t5_text_embeddings = F.pad(
t5_text_embeddings, (0, 0, 0, seq_len_kv_padded - t5_text_embeddings_seq_length)
)
t5_text_mask = F.pad(t5_text_mask, (0, seq_len_kv_padded - t5_text_embeddings_seq_length))

return DiffusionSample(
__key__=sample["__key__"],
Expand All @@ -148,7 +140,9 @@ def encode_sample(self, sample: dict) -> DiffusionSample:
context_mask=t5_text_mask,
loss_mask=loss_mask,
seq_len_q=torch.tensor([seq_len], dtype=torch.int32),
seq_len_kv=torch.tensor([self.text_embedding_padding_size], dtype=torch.int32),
seq_len_q_padded=torch.tensor([seq_len_q_padded], dtype=torch.int32),
seq_len_kv=torch.tensor([t5_text_embeddings_seq_length], dtype=torch.int32),
seq_len_kv_padded=torch.tensor([seq_len_kv_padded], dtype=torch.int32),
pos_ids=pos_ids,
latent_shape=torch.tensor([C, T, H, W], dtype=torch.int32),
)
Expand All @@ -168,7 +162,9 @@ def batch(self, samples: List[DiffusionSample]) -> dict:
context_mask=sample.context_mask.unsqueeze_(0) if sample.context_mask is not None else None,
loss_mask=sample.loss_mask.unsqueeze_(0) if sample.loss_mask is not None else None,
seq_len_q=sample.seq_len_q,
seq_len_q_padded=sample.seq_len_q_padded,
seq_len_kv=sample.seq_len_kv,
seq_len_kv_padded=sample.seq_len_kv_padded,
pos_ids=sample.pos_ids.unsqueeze_(0) if sample.pos_ids is not None else None,
latent_shape=sample.latent_shape,
)
Expand Down
Loading