diff --git a/megatron/core/datasets/data_schedule.py b/megatron/core/datasets/data_schedule.py index 0f016473b6a..00591e4c24d 100644 --- a/megatron/core/datasets/data_schedule.py +++ b/megatron/core/datasets/data_schedule.py @@ -1,10 +1,21 @@ # Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved. -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional, Type import torch from megatron.core import parallel_state +from megatron.core.datasets.data_schedule_utils import ( + broadcast_scalars, + broadcast_tensor, + broadcast_to_pp_group, + build_packed_microbatches, + create_data_iterator, + get_batch_and_global_seqlens, + get_cp_slice_for_thd, + reroute_samples_to_dcp_ranks, +) +from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.pipeline_parallel.hybrid_cp_schedule import BalancedCPScheduler from megatron.core.process_groups_config import ProcessGroupCollection @@ -299,3 +310,547 @@ def __next__(self) -> Any: batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets ) return samples_this_rank_with_id, sample_id_groups + + +class BasePackingScheduler: + """Base class for sequence packing schedulers.""" + + def __init__( + self, + max_seqlen_per_dp_cp_rank: int, + cp_size: int, + dp_size: int, + microbatch_group_size_per_vp_stage: Optional[int], + ): + """ + Args: + max_seqlen_per_dp_cp_rank: The maximum sequence length per DPxCP rank. + cp_size: The context parallel size. + dp_size: The data parallel size. + microbatch_group_size_per_vp_stage: The microbatch group size per virtual + pipeline stage, only used when enabling VPP, otherwise None. 
+ """ + self.max_seqlen_per_dp_cp_rank = max_seqlen_per_dp_cp_rank + self.cp_size = cp_size + self.dp_size = dp_size + self.microbatch_group_size_per_vp_stage = microbatch_group_size_per_vp_stage + + def get_required_sample_keys(self): + """Return the required key of each batch.""" + raise NotImplementedError + + def get_groups_and_subsamples(self, sample_id_seqlens): + """schedule the samples into groups""" + raise NotImplementedError + + def run( + self, + data_iterator, + num_microbatches, + dp_group, + tp_group, + pp_group, + dp_cp_group, + dev, + config, + ): + """ + Run the scheduler and return the new data_iterator. + + Args: + data_iterator: The data iterator. + num_microbatches: The number of microbatches to fetch. + dp_group: Data parallel process group. + tp_group: Tensor parallel process group. + pp_group: Pipeline parallel process group. + dp_cp_group: Data parallel + context parallel process group. + dev: CUDA device. + config: Model parallel config. + + Returns: + new_data_iterator: The new data iterator (or list for VPP). + num_micro_batches: Number of micro batches after scheduling. + seqlen_sum_this_global_batch: Total tokens for FLOPs calculation. + seqlen_squared_sum_this_global_batch: Sum of squared seqlens for FLOPs. + """ + raise NotImplementedError + + +class DpBalancedScheduler(BasePackingScheduler): + """Packs sequences in their original order until reaching the max limit of sequence length.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.max_seq_len_all_ranks = self.max_seqlen_per_dp_cp_rank * self.cp_size + + def get_required_sample_keys(self): + """Return the required key of each batch.""" + return [ + "tokens", + "labels", + "loss_mask", + "position_ids", + "original_seq_len", # Length of the original sequence length, should be a gpu tensor. + "padded_seq_len", # Length of the padded sequence length, should be a gpu tensor. 
+ ] + + def get_groups_and_subsamples(self, sample_id_seqlens): + """ + Packs sequences in their original order until reaching the max limit of sequence length. + """ + sample_id_groups = [] + packed_id_groups = [] + sum_seqlen = 0 + single_microbatch = [] + + for i in range(len(sample_id_seqlens)): + if sum_seqlen + sample_id_seqlens[i][1] <= self.max_seq_len_all_ranks: + single_microbatch.append(i) + sum_seqlen += sample_id_seqlens[i][1] + else: + packed_id_groups.append(single_microbatch) + single_microbatch = [i] + sum_seqlen = sample_id_seqlens[i][1] + if len(single_microbatch) > 0: + packed_id_groups.append(single_microbatch) + + # we want the number of packed sequences to be multiple of dp_size + # so we move few samples from previous microbatch + # to the end of the microbatches if needed + num_packed_sequence = len(packed_id_groups) + + # when enabling vpp, we want the number of packed sequences to be + # multiple of dp_size * microbatch_group_size_per_vp_stage + multiple = self.dp_size * ( + self.microbatch_group_size_per_vp_stage + if self.microbatch_group_size_per_vp_stage is not None + else 1 + ) + if num_packed_sequence % multiple != 0: + remainder = num_packed_sequence % multiple + num_to_move = multiple - remainder + i = num_packed_sequence - 1 + while num_to_move > 0: + assert i >= 0, "Not enough samples to move" + if len(packed_id_groups[i]) > 1: + seq_id = packed_id_groups[i].pop() + packed_id_groups.append([seq_id]) + num_to_move -= 1 + else: + i -= 1 + + num_micro_batches = int(len(packed_id_groups) / self.dp_size) + for i in range(num_micro_batches): + sample_id_groups.append([]) + for j in range(self.cp_size * self.dp_size): + seq_id = int(i * self.dp_size + j / self.cp_size) + sample_id_groups[i].append(packed_id_groups[seq_id]) + return sample_id_groups + + def run( + self, + data_iterator, + num_microbatches: int, + dp_group, + tp_group, + pp_group, + dp_cp_group, + dev: torch.device, + config, + ): + """ + Run the complete scheduling 
pipeline. + + Steps: + 1. Fetch batches and gather global sequence lengths + 2. Check required sample keys + 3. Schedule samples into groups + 4. Reroute samples to DCP ranks + 5. Build packed microbatches + 6. Calculate FLOPs info + 7. Broadcast to PP group (for middle PP stages) + 8. Broadcast to TP group (for non-TP-0 ranks) + 9. Handle VPP if enabled + + Args: + data_iterator: The data iterator. + num_microbatches: The number of microbatches to fetch. + dp_group: Data parallel process group. + tp_group: Tensor parallel process group. + pp_group: Pipeline parallel process group. + dp_cp_group: Data parallel + context parallel process group. + dev: CUDA device. + config: Model parallel config. + + Returns: + new_data_iterator: The new data iterator (or list for VPP). + num_micro_batches: Number of micro batches after scheduling. + seqlen_sum_this_global_batch: Total tokens for FLOPs calculation. + seqlen_squared_sum_this_global_batch: Sum of squared seqlens for FLOPs. + """ + + total_dcp_gpus = dp_cp_group.size() + + # Handle VPP: extract the correct data_iterator for this PP stage. + # When VPP is enabled, data_iterator is a list with one entry per VPP stage. + # We only need one data_iterator to run the schedule (all VPP stages on the + # same PP rank share the same underlying dataset), so pick the first non-None. + # Record which VPP stages had data so create_data_iterator knows which ones + # need full samples vs metadata only. + vpp_has_data = None + if ( + config.virtual_pipeline_model_parallel_size is not None + and config.virtual_pipeline_model_parallel_size > 1 + ): + assert len(data_iterator) == config.virtual_pipeline_model_parallel_size + vpp_has_data = [di is not None for di in data_iterator] + extracted = None + for di in data_iterator: + if di is not None: + extracted = di + break + data_iterator = extracted + + # data_iterator is not None on TP rank 0 for PP stages that need data + # (first stage, last stage, or any stage with MTP). 
+ if data_iterator is not None: + assert tp_group.rank() == 0, "Only TP rank 0 should have data_iterator" + + # Step 1: Fetch batches and gather global sequence lengths + batch, global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered = ( + get_batch_and_global_seqlens(data_iterator, num_microbatches, dp_group) + ) + + # Step 2: Check required sample keys + for key in self.get_required_sample_keys(): + assert ( + key in batch[0] + ), f"Batch missing required key {key}, provided keys: {batch[0].keys()}" + + # Step 3: Schedule samples into groups + sample_id_groups = self.get_groups_and_subsamples(global_id_seqlens) + + # Validate scheduling result + set_gbs = set() + for group in sample_id_groups: + for sub in group: + set_gbs.update(sub) + assert len(set_gbs) == len(global_id_seqlens), ( + f"set_gbs length: {len(set_gbs)} != " + f"global_id_seqlens length: {len(global_id_seqlens)}" + ) + + # Step 4: Reroute samples to DCP ranks + samples_this_rank_with_id = reroute_samples_to_dcp_ranks( + batch, + global_ids_this_rank, + global_id_seqlens, + sample_id_groups, + offsets, + dp_group, + tp_group, + dp_cp_group, + total_dcp_gpus, + ) + + dcp_rank = dp_cp_group.rank() + num_micro_batches = len(sample_id_groups) + + grouped_samples = [ + [ + samples_this_rank_with_id[sub_sample_id] + for sub_sample_id in sample_id_groups[i][dcp_rank] + ] + for i in range(num_micro_batches) + ] + + # Step 5: Build packed microbatches + new_samples = build_packed_microbatches(grouped_samples, dev) + + # Step 6: Calculate FLOPs info + seqlen_sum_this_global_batch = float(sum(seqlens_gathered)) + seqlen_squared_sum_this_global_batch = float( + sum(seqlen**2 for seqlen in seqlens_gathered) + ) + else: + ( + new_samples, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ) = (None, None, None, None) + + # Step 7: Broadcast to PP group (for middle PP stages) + if tp_group.rank() == 0: + ( + new_samples, + num_micro_batches, + 
seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ) = broadcast_to_pp_group( + new_samples, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + pp_group, + dev, + ) + + # Step 8: Broadcast to TP group (for non-TP-0 ranks) + (num_micro_batches, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch) = ( + broadcast_scalars( + [ + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ], + tp_group, + dev, + ) + ) + num_micro_batches = int(num_micro_batches) + + # Step 9: create data_iterator and handle VPP if enabled + new_data_iterator = create_data_iterator(new_samples, tp_group, config, vpp_has_data) + + return ( + new_data_iterator, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ) + + +scheduler_map: Dict[str, Type[BasePackingScheduler]] = {"dp_balanced": DpBalancedScheduler} + + +def wrap_data_iterator( + data_iterator, config, num_microbatches, pg_collection: Optional[ProcessGroupCollection] = None +): + """ + A wrapper function that wraps around an existing data_iterator + and return the num_micro_batches for sequence packing. + + Args: + data_iterator: The original data_iterator to wrap around + config: The config object containing the max_seqlen_per_dp_cp_rank + dp_cp_group: Data parallel context parallel group. + pg_collection: The process group collection. 
+ """ + + if pg_collection is None: + dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True) + dp_group = parallel_state.get_data_parallel_group() + tp_group = parallel_state.get_tensor_model_parallel_group() + pp_group = parallel_state.get_pipeline_model_parallel_group() + else: + dp_cp_group = pg_collection.dp_cp + dp_group = pg_collection.dp + tp_group = pg_collection.tp + pp_group = pg_collection.pp + assert ( + dp_cp_group is not None + and dp_group is not None + and tp_group is not None + and pp_group is not None + ), "dp_cp_group, dp_group, tp_group must not be None when using sequence packing" + + dev = torch.cuda.current_device() + dp_size = dp_group.size() + cp_size = dp_cp_group.size() // dp_size + + # Look up the scheduler class by name + scheduler_type = config.sequence_packing_scheduler + + scheduler = scheduler_map[scheduler_type]( + config.max_seqlen_per_dp_cp_rank, + cp_size, + dp_size, + # When VPP is enabled, align num_micro_batches to this multiple. + ( + None + if config.virtual_pipeline_model_parallel_size is None + else config.microbatch_group_size_per_vp_stage + ), + ) + + ( + new_data_iterator, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ) = scheduler.run( + data_iterator, num_microbatches, dp_group, tp_group, pp_group, dp_cp_group, dev, config + ) + + return ( + new_data_iterator, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ) + + +def get_batch_on_this_rank_for_sequence_packing( + data_iterator, + vpp_size: Optional[int] = None, + mtp_on_this_rank: bool = False, + vp_stage: Optional[int] = None, + pg_collection: Optional[ProcessGroupCollection] = None, +): + """ + Get a batch of data for sequence packing. + Args: + data_iterator (Iterator): The data iterator to get the batch from. + mtp_on_this_rank (bool): Whether to use multi-token prediction. + vp_stage (Optional[int]): The stage of the pipeline. 
+ Returns: + tuple of (tokens, labels, loss_mask, attention_mask, position_ids, packed_seq_params) + """ + + if pg_collection is None: + tp_group = parallel_state.get_tensor_model_parallel_group() + pp_group = parallel_state.get_pipeline_model_parallel_group() + cp_group = parallel_state.get_context_parallel_group() + else: + tp_group = pg_collection.tp + pp_group = pg_collection.pp + cp_group = pg_collection.cp + + tp_src_rank = torch.distributed.get_process_group_ranks(tp_group)[0] + + is_tp_rank_0 = tp_group.rank() == 0 + is_first_stage = pp_group.rank() == 0 and (vp_stage is None or vp_stage == 0) + is_last_stage = pp_group.rank() == pp_group.size() - 1 and ( + vp_stage is None or vp_stage == vpp_size - 1 + ) + + is_first_or_last_stage = is_first_stage or is_last_stage + dev = torch.cuda.current_device() + + # data_iterator should return a batch including the following keys. + batch_keys = ['cu_seqlens', 'cu_seqlens_padded', 'max_seqlen'] + if is_first_stage or mtp_on_this_rank: + batch_keys.append('tokens') + batch_keys.append('position_ids') + if is_last_stage or mtp_on_this_rank: + batch_keys.append('labels') + batch_keys.append('loss_mask') + + # Get a batch from data_iterator or create an emtpy batch. + if is_tp_rank_0: + assert data_iterator is not None + batch = next(data_iterator) + for key in batch_keys: + assert key in batch, f"{key} is missing in current batch." + else: + assert data_iterator is None, "Non TP 0 rank should not have data_iterator" + batch = {} + + # Partition tokens, position_ids, labels, loss_mask for context parallel. + # Only TP rank 0 on stages that have data (first/last PP stage or MTP stage) needs this. + if is_tp_rank_0 and (is_first_or_last_stage or mtp_on_this_rank): + get_cp_slice_for_thd(batch, cp_group) + + # Broadcast cu_seqlens_size because we need it to create placeholder for cu_seqlens and + # cu_seqlens_padded for non TP 0 ranks. 
+ if is_tp_rank_0: + cu_seqlen_size = torch.tensor(batch['cu_seqlens'].size(0), dtype=torch.int32, device=dev) + else: + cu_seqlen_size = torch.empty(1, dtype=torch.int32, device=dev) + broadcast_tensor(cu_seqlen_size, tp_src_rank, tp_group) + cu_seqlen_size = cu_seqlen_size.item() + + # Broadcast total_tokens because we need it to create placeholder for tokens, position_ids, + # labels, loss_mask for non TP 0 ranks. Only first stage, last stage, + # and stage with mtp need this. + + if is_first_or_last_stage or mtp_on_this_rank: + if is_tp_rank_0: + total_tokens = torch.tensor(batch['tokens'].size(0), dtype=torch.int32, device=dev) + else: + total_tokens = torch.empty(1, dtype=torch.int32, device=dev) + broadcast_tensor(total_tokens, tp_src_rank, tp_group) + total_tokens = total_tokens.item() + + # Step1: Prepare "tokens", "position_ids" for first stage and stage with mtp on all TP ranks. + if is_first_stage or mtp_on_this_rank: + if is_tp_rank_0: + assert batch['tokens'].dtype == torch.int64 + assert batch['position_ids'].dtype == torch.int64 + batch['tokens'] = batch['tokens'].view(1, total_tokens) + batch['position_ids'] = batch['position_ids'].view(1, total_tokens) + else: + batch['tokens'] = torch.empty([1, total_tokens], dtype=torch.int64, device=dev) + batch['position_ids'] = torch.empty([1, total_tokens], dtype=torch.int64, device=dev) + else: + # Non first stage rank doesn't need tokens and position_ids. + batch['tokens'] = None + batch['position_ids'] = None + + # Step2: Prepare "labels", "loss_mask" for last stage and stage with mtp on all TP ranks. 
+ if is_last_stage or mtp_on_this_rank: + if is_tp_rank_0: + assert batch['labels'].dtype == torch.int64 + assert batch['loss_mask'].dtype == torch.float32 + batch['labels'] = batch['labels'].view(1, total_tokens) + batch['loss_mask'] = batch['loss_mask'].view(1, total_tokens) + else: + batch['labels'] = torch.empty([1, total_tokens], dtype=torch.int64, device=dev) + batch['loss_mask'] = torch.empty([1, total_tokens], dtype=torch.float32, device=dev) + else: + # Non last stage rank doesn't need labels and loss_mask. + batch['labels'] = None + batch['loss_mask'] = None + + # Step3: Prepare "cu_seqlens", "cu_seqlens_padded", "max_seqlen" on all ranks. + if is_tp_rank_0: + assert batch['cu_seqlens'].dtype == torch.int32 + assert batch['cu_seqlens_padded'].dtype == torch.int32 + assert batch['cu_seqlens'].dim() == 1 + assert batch['cu_seqlens_padded'].dim() == 1 + if type(batch['max_seqlen']) == int: + batch['max_seqlen'] = torch.tensor(batch['max_seqlen'], dtype=torch.int32, device=dev) + else: + assert batch['max_seqlen'].dtype == torch.int32 + assert batch['max_seqlen'].numel() == 1 + else: + batch['cu_seqlens'] = torch.empty([cu_seqlen_size], dtype=torch.int32, device=dev) + batch['cu_seqlens_padded'] = torch.empty([cu_seqlen_size], dtype=torch.int32, device=dev) + batch['max_seqlen'] = torch.empty(1, dtype=torch.int32, device=dev) + + # Broadcast batch inside TP group. + broadcast_tensor(batch['tokens'], tp_src_rank, tp_group) + broadcast_tensor(batch['position_ids'], tp_src_rank, tp_group) + broadcast_tensor(batch['labels'], tp_src_rank, tp_group) + broadcast_tensor(batch['loss_mask'], tp_src_rank, tp_group) + broadcast_tensor(batch['cu_seqlens'], tp_src_rank, tp_group) + broadcast_tensor(batch['cu_seqlens_padded'], tp_src_rank, tp_group) + broadcast_tensor(batch['max_seqlen'], tp_src_rank, tp_group) + + # Extract the data from batch after broadcasting. 
+ tokens = batch['tokens'] + position_ids = batch['position_ids'] + labels = batch['labels'] + loss_mask = batch['loss_mask'] + cu_seqlens = batch['cu_seqlens'] + cu_seqlens_padded = batch['cu_seqlens_padded'] + max_seqlen = batch['max_seqlen'].item() + + # Transformer Engine has a bug of cu_seqlens, we must treat cu_seqlens_padded as cu_seqlens to + # get the correct result. + # TODO: Revert this workaround once TE fixes the issue. + packed_seq_params = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_seqlens_padded, + cu_seqlens_kv=cu_seqlens_padded, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + max_seqlen_q=max_seqlen, + max_seqlen_kv=max_seqlen, + local_cp_size=None, + cp_group=None, + ) + + # "attention_mask" is not valid for sequence packing, so set it to None. + return tokens, labels, loss_mask, None, position_ids, packed_seq_params diff --git a/megatron/core/datasets/data_schedule_utils.py b/megatron/core/datasets/data_schedule_utils.py new file mode 100644 index 00000000000..f3c637e4c79 --- /dev/null +++ b/megatron/core/datasets/data_schedule_utils.py @@ -0,0 +1,529 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved. + +from typing import Dict, List + +import numpy as np +import torch + +from megatron.core.extensions.transformer_engine import get_thd_partitioned_indices +from megatron.core.rerun_state_machine import RerunDataIterator + + +def get_cp_slice_for_thd(batch, cp_group): + """Partition sequence data for context parallelism in THD format. + + Uses TE's THD partitioned indices to split the packed sequence across CP ranks. + Only keys present in the batch are sliced. + + Args: + batch: Dict with packed sequence data. + cp_group: Context parallel process group. 
+ """ + cp_size = cp_group.size() + if cp_size <= 1: + return + cp_rank = cp_group.rank() + total_tokens = batch['tokens'].size(0) + # Transformer Engine has a bug of cu_seqlens, we must treat cu_seqlens_padded as + # cu_seqlens to get the correct result. + # TODO: Revert this workaround once TE fixes the issue. + cu_seqlens = batch["cu_seqlens_padded"] + index = get_thd_partitioned_indices(cu_seqlens, total_tokens, cp_size, cp_rank) + for key in ['tokens', 'position_ids', 'labels', 'loss_mask']: + if key in batch: + batch[key] = batch[key].index_select(0, index) + + +def _unpack_batch(batch: List[Dict[str, torch.Tensor]]) -> List[Dict[str, torch.Tensor]]: + """ + Unpacks the packed samples into a list of sub-samples. + Since each sub-sample may be routed to different DPxCP ranks, + we unpack the sample here to avoid unnecessarily transferring + the entire packed sample. + """ + batch_unpacked = [] + dev = batch[0]["tokens"].device + original_seq_lens = [] + padded_seq_lens = [] + for sample in batch: + for key in sample.keys(): + if len(sample[key].shape) == 2: + # squeeze the redundant batch dimension added by + # default collate_fn in pytorch dataloader + # we need a custom collate_fn for THD to avoid this + # current THD does not support micro_batch_size > 1 due to sft_dataset.py and + # data_loader in data_samples.py + sample[key] = sample[key].squeeze(0) + for sub_sample in range(sample["cu_seqlens"].shape[0] - 1): + sub_sample_dict = {} + start_idx = sample["cu_seqlens"][sub_sample] + end_idx = sample["cu_seqlens"][sub_sample + 1] + if end_idx - start_idx == 0: + continue + for key in ["tokens", "labels", "loss_mask", "position_ids"]: + sub_sample_dict[key] = sample[key][start_idx:end_idx] + # Since sft_dataset.py does not provide cu_seqlens_original, + # we assume original_seq_len equals padded_seq_len here. + # Ideally the dataset should define the pre-padding seq_len. 
+ seq_len = (end_idx - start_idx).item() + original_seq_lens.append(seq_len) + padded_seq_lens.append(seq_len) + batch_unpacked.append(sub_sample_dict) + + # Single H2D transfer for all seq lens + original_seq_lens_cuda = torch.tensor(original_seq_lens, device=dev) + padded_seq_lens_cuda = torch.tensor(padded_seq_lens, device=dev) + for i, sub_sample_dict in enumerate(batch_unpacked): + sub_sample_dict["original_seq_len"] = original_seq_lens_cuda[i : i + 1] + sub_sample_dict["padded_seq_len"] = padded_seq_lens_cuda[i : i + 1] + + return batch_unpacked + + +def _get_global_seqlens_and_ids(subsample_seqlens: torch.Tensor, dp_group): + """ + Gathers the sequence lengths of all subsamples from all DP ranks and calculates global IDs. + """ + # Collect the number of subsamples from all ranks + num_local_subsamples = subsample_seqlens.shape[0] + local_len = torch.tensor([num_local_subsamples], dtype=torch.int32).cuda() + dp_subsample_count = [torch.zeros_like(local_len) for _ in range(dp_group.size())] + torch.distributed.all_gather(dp_subsample_count, local_len, group=dp_group) + + # Find the max number of subsamples across all ranks and pad subsample_seqlens to max length + dp_subsample_counts = torch.stack(dp_subsample_count, dim=0).cpu().view(-1) + max_sub_samples = int(dp_subsample_counts.max().item()) + + if num_local_subsamples < max_sub_samples: + subsample_seqlens_padded = torch.cat( + [ + subsample_seqlens, + torch.zeros(max_sub_samples - num_local_subsamples, dtype=torch.int32).cuda(), + ], + dim=0, + ) + else: + subsample_seqlens_padded = subsample_seqlens + + # Gather the subsample_seqlens from all ranks + seqlens_gathered = [torch.empty_like(subsample_seqlens_padded) for _ in range(dp_group.size())] + torch.distributed.all_gather(seqlens_gathered, subsample_seqlens_padded, group=dp_group) + + # Trim each seqlens_gathered to the length of the correct sample + for dp_rank, seqlen in enumerate(seqlens_gathered): + seqlens_gathered[dp_rank] = seqlen[: 
dp_subsample_counts[dp_rank]] + + seqlens_gathered = torch.cat(seqlens_gathered, dim=0) + seqlens_gathered = seqlens_gathered.cpu().tolist() + + # Calculate the offsets to assign unique global ID to each subsample. + csum = torch.cumsum(dp_subsample_counts, dim=0, dtype=torch.int32) + offsets = torch.cat([torch.zeros(1, dtype=torch.int32), csum], dim=0) + + # Calculate global ID for each subsample + dp_rank = dp_group.rank() + global_ids = torch.arange(len(seqlens_gathered), dtype=torch.int32).cuda() + + # Create a list of (global_id, seqlen) tuples for scheduling + global_id_seqlens = [(i, seqlens_gathered[i]) for i in range(len(global_ids))] + + # Get the global IDs locally present on this rank + start_idx = offsets[dp_rank] + end_idx = offsets[dp_rank + 1] + + global_ids_this_rank = global_ids[start_idx:end_idx] + + return global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered + + +def _pack_sequences( + samples: List, padded_lengths: torch.Tensor, original_lengths: torch.Tensor, dev: torch.device +) -> Dict[str, torch.Tensor]: + """Pack multiple samples into a single packed sample.""" + + def _pack_tensors(tensors): + return torch.cat([t.reshape(-1) for t in tensors], dim=0) + + tokens = _pack_tensors([sample["tokens"] for sample in samples]) + labels = _pack_tensors([sample["labels"] for sample in samples]) + loss_mask = _pack_tensors([sample["loss_mask"] for sample in samples]) + position_ids = _pack_tensors([sample["position_ids"] for sample in samples]) + + new_sample = {} + new_sample["tokens"] = tokens + new_sample["labels"] = labels + new_sample["loss_mask"] = loss_mask + new_sample["position_ids"] = position_ids + + padded_lengths = padded_lengths.to(device=dev, dtype=torch.int32, non_blocking=True).reshape(-1) + cu_seqlens_padded = torch.empty(padded_lengths.numel() + 1, device=dev, dtype=torch.int32) + cu_seqlens_padded[0] = 0 + cu_seqlens_padded[1:] = torch.cumsum(padded_lengths, dim=0) + max_seqlen = 
torch.max(padded_lengths).to(dtype=torch.int32) + + new_sample["cu_seqlens_padded"] = cu_seqlens_padded + new_sample["max_seqlen"] = max_seqlen + + original_lengths = original_lengths.to( + device=dev, dtype=torch.int32, non_blocking=True + ).reshape(-1) + cu_seqlens = torch.empty(original_lengths.numel() + 1, device=dev, dtype=torch.int32) + cu_seqlens[0] = 0 + cu_seqlens[1:] = torch.cumsum(original_lengths, dim=0).reshape(-1) + new_sample["cu_seqlens"] = cu_seqlens + + return new_sample + + +def broadcast_tensor(item, src_rank, group) -> None: + """Broadcast a tensor from src_rank to all ranks in the group.""" + if item is not None: + torch.distributed.broadcast(item, src_rank, group=group) + + +def broadcast_to_pp_group( + new_samples, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + pp_group, + dev, +): + """ + Broadcast num_micro_batches, seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch and metadata to middle PP stages. + Before this broadcast, the new_samples on middle PP stages are None, + after this broadcast, the new_samples on middle PP stages contain the metadata but + without tokens, labels, loss_mask, position_ids. 
+ """ + + pp_src_rank = torch.distributed.get_process_group_ranks(pp_group)[0] + + if pp_group.size() > 2: + if pp_group.rank() == 0: + tensor_list = [ + torch.tensor( + [ + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ], + dtype=torch.float32, + ).cuda() + ] + for sample in new_samples: + tensor_list.append(sample["max_seqlen"].unsqueeze(0)) + for sample in new_samples: + tensor_list.append(sample["cu_seqlens"]) + tensor_list.append(sample["cu_seqlens_padded"]) + info_to_broadcast = torch.cat(tensor_list, dim=0).to(device=dev, dtype=torch.float32) + info_length_tensor = torch.tensor(info_to_broadcast.shape[0], dtype=torch.int32).cuda() + broadcast_tensor(info_length_tensor, pp_src_rank, pp_group) + broadcast_tensor(info_to_broadcast, pp_src_rank, pp_group) + else: + info_length_tensor = torch.tensor(0, dtype=torch.int32).cuda() + broadcast_tensor(info_length_tensor, pp_src_rank, pp_group) + info_to_broadcast = torch.empty(info_length_tensor.item(), dtype=torch.float32).cuda() + broadcast_tensor(info_to_broadcast, pp_src_rank, pp_group) + if pp_group.rank() != pp_group.size() - 1: + # middle PP stages receive the broadcasted info and unpack it + info_numpy = info_to_broadcast.cpu().numpy() + num_micro_batches = int(info_numpy[0]) + seqlen_sum_this_global_batch = info_numpy[1] + seqlen_squared_sum_this_global_batch = info_numpy[2] + max_seqlens = info_to_broadcast[3 : 3 + num_micro_batches] + cu_seqlens_list = [] + cu_seqlens_padded_list = [] + # cu_seqlens always starts with 0, and the other metadata values + # (num_micro_batches, seqlen_sum, seqlen_squared_sum, max_seqlens) + # are always positive, so we can use 0 as the delimiter to locate + # the start of each cu_seqlens / cu_seqlens_padded tensor. + # This avoids an extra broadcast for the lengths of cu_seqlens. 
+ indices = np.where(info_numpy == 0)[0] + for i in range(num_micro_batches): + cu_seqlens_list.append(info_to_broadcast[indices[i * 2] : indices[i * 2 + 1]]) + if i == num_micro_batches - 1: + cu_seqlens_padded_list.append(info_to_broadcast[indices[i * 2 + 1] :]) + else: + cu_seqlens_padded_list.append( + info_to_broadcast[indices[i * 2 + 1] : indices[i * 2 + 2]] + ) + + new_samples = [] + for i in range(num_micro_batches): + new_sample = {} + new_sample["max_seqlen"] = max_seqlens[i].to(torch.int32) + new_sample["cu_seqlens"] = cu_seqlens_list[i].to(torch.int32) + new_sample["cu_seqlens_padded"] = cu_seqlens_padded_list[i].to(torch.int32) + new_samples.append(new_sample) + + return ( + new_samples, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ) + + +def broadcast_scalars(values: List, group, dev, dtype=torch.float32) -> List: + """ + Broadcast scalar values from rank 0 to all ranks in the group. + + Args: + values: List of scalar values to broadcast (only used on rank 0). + group: The process group to broadcast within. + dev: The device to use for the tensor. + dtype: The data type for the tensor. + + Returns: + List of broadcasted values. + """ + if group.size() <= 1: + return values + + src_rank = torch.distributed.get_process_group_ranks(group)[0] + num_values = len(values) + + if group.rank() == 0: + info_to_broadcast = torch.tensor(values, dtype=dtype, device=dev) + else: + info_to_broadcast = torch.zeros(num_values, dtype=dtype, device=dev) + + broadcast_tensor(info_to_broadcast, src_rank, group) + + if group.rank() != 0: + values = info_to_broadcast.cpu().tolist() + + return values + + +def create_data_iterator(new_samples, tp_group, config, vpp_has_data=None): + """Handle virtual pipeline parallelism. + + For VPP, each PP rank needs a list of data iterators (one per VPP stage). 
+ VPP stages that originally had a data_iterator (indicated by vpp_has_data) + get full samples; others get metadata only (cu_seqlens, cu_seqlens_padded, + max_seqlen). + + Args: + new_samples: The packed samples after scheduling. + tp_group: Tensor parallel process group. + config: Model parallel config. + vpp_has_data: A list of booleans (one per VPP stage) indicating which + VPP stages originally had a data_iterator. None if VPP is disabled. + """ + if ( + config.virtual_pipeline_model_parallel_size is not None + and config.virtual_pipeline_model_parallel_size > 1 + ): + vpp_size = config.virtual_pipeline_model_parallel_size + if tp_group.rank() == 0: + metadata = [ + {k: sample[k] for k in ["max_seqlen", "cu_seqlens", "cu_seqlens_padded"]} + for sample in new_samples + ] + new_data_iterator = [] + for i in range(vpp_size): + if vpp_has_data is not None and vpp_has_data[i]: + new_data_iterator.append(RerunDataIterator(iter(new_samples))) + else: + new_data_iterator.append(RerunDataIterator(iter(metadata))) + else: + new_data_iterator = [None for _ in range(vpp_size)] + else: + new_data_iterator = RerunDataIterator(iter(new_samples)) if tp_group.rank() == 0 else None + + return new_data_iterator + + +def reroute_samples_to_dcp_ranks( + batch, + global_ids_this_rank, + global_id_seqlens, + sample_id_groups, + offsets, + dp_group, + tp_group, + dp_cp_group, + total_dcp_gpus, +): + """ + Reroutes the sub-samples to the correct rank after scheduling. + + For each key in the batch dict, we perform an all-to-all communication + to transfer the data to the correct ranks. 
+ """ + + def _gid_to_src_rank(gid: int) -> int: + dp_src_rank = torch.bucketize(gid, offsets[1:] - 1) + dcp_rank = ( + torch.distributed.get_process_group_ranks(dp_group)[dp_src_rank] // tp_group.size() + ) % dp_cp_group.size() + return dcp_rank + + gid2local_id = {int(gid): i for i, gid in enumerate(global_ids_this_rank)} + dcp_rank = dp_cp_group.rank() + dp_ranks = torch.distributed.get_process_group_ranks(dp_group) + dp_ranks = [(r // tp_group.size()) % dp_cp_group.size() for r in dp_ranks] + + data_keys = batch[0].keys() + + # Create the send plan + combined_sample_id_groups: List[List[int]] = [[] for _ in range(total_dcp_gpus)] + for d in range(total_dcp_gpus): + for sample_id_group in sample_id_groups: + combined_sample_id_groups[d].extend(sample_id_group[d]) + for dest_rank in range(total_dcp_gpus): + combined_sample_id_groups[dest_rank].sort() + + send_ids_sorted = [ + gid for d in dp_ranks for gid in combined_sample_id_groups[d] if gid in global_ids_this_rank + ] + + send_num_split = [0] * total_dcp_gpus + send_lens_split = [0] * total_dcp_gpus + for dest_rank in range(total_dcp_gpus): + if dest_rank in dp_ranks: + send_seq_lens = [ + global_id_seqlens[gid][1] + for gid in combined_sample_id_groups[dest_rank] + if gid in global_ids_this_rank + ] + send_num_split[dest_rank] = len(send_seq_lens) + send_lens_split[dest_rank] = sum(send_seq_lens) + else: + send_lens_split[dest_rank] = 0 + + # Create the recv plan + recv_sample_id_groups = [[] for _ in range(total_dcp_gpus)] + for gid in combined_sample_id_groups[dcp_rank]: + src_rank = _gid_to_src_rank(gid) + recv_sample_id_groups[src_rank].append(gid) + + recv_lens_split = [0] * total_dcp_gpus + for src_rank in range(total_dcp_gpus): + recv_lens_split[src_rank] = sum( + [global_id_seqlens[gid][1] for gid in recv_sample_id_groups[src_rank]] + ) + + recv_ids_sorted = [gid for d in range(total_dcp_gpus) for gid in recv_sample_id_groups[d]] + recv_counts = [len(recv_sample_id_groups[d]) for d in 
range(total_dcp_gpus)] + + recv_samples = [{k: None for k in data_keys} for _ in range(sum(recv_counts))] + + def _pack_sample_by_key(key: str) -> torch.Tensor: + flattened_tensors = [] + for gid in send_ids_sorted: + t = batch[gid2local_id[gid]][key].to(torch.cuda.current_device(), non_blocking=True) + flattened_tensors.append(t.reshape(-1)) + return ( + torch.cat(flattened_tensors, dim=0) + if flattened_tensors + else torch.empty(1, device=torch.cuda.current_device(), dtype=batch[0][key].dtype) + ) + + def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor): + cursor = 0 + for i, gid in enumerate(recv_ids_sorted): + sample_len = ( + 1 if key in ["original_seq_len", "padded_seq_len"] else global_id_seqlens[gid][1] + ) + recv_samples[i][key] = recv_tensor[cursor : cursor + sample_len] + cursor += sample_len + + for key in data_keys: + output_split_sizes, input_split_sizes = ( + (recv_counts, send_num_split) + if key in ["original_seq_len", "padded_seq_len"] + else (recv_lens_split, send_lens_split) + ) + send_tensor = _pack_sample_by_key(key) + recv_tensor_size = sum(output_split_sizes) + recv_tensor = torch.empty( + recv_tensor_size, device=torch.cuda.current_device(), dtype=send_tensor.dtype + ) + torch.distributed.all_to_all_single( + output=recv_tensor, + input=send_tensor, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=dp_cp_group, + ) + _unpack_sample_by_key(key, recv_tensor) + + recv_sample_with_id = {recv_id: recv_samples[i] for i, recv_id in enumerate(recv_ids_sorted)} + return recv_sample_with_id + + +def build_packed_microbatches( + grouped_samples: List[List[Dict[str, torch.Tensor]]], dev: torch.device +) -> List[Dict[str, torch.Tensor]]: + """Build packed samples for each microbatch.""" + num_micro_batches = len(grouped_samples) + seg_starts: List[int] = [0] + original_lens_tensors = [] + padded_lens_tensors = [] + + for i in range(num_micro_batches): + samples = grouped_samples[i] + 
seg_starts.append(seg_starts[-1] + len(samples)) + original_lens_tensors.extend([s["original_seq_len"].reshape(-1) for s in samples]) + padded_lens_tensors.extend([s["padded_seq_len"].reshape(-1) for s in samples]) + + padded_lens_all_gpu = torch.cat(padded_lens_tensors, dim=0).to(dtype=torch.int32) + original_lens_all_gpu = torch.cat(original_lens_tensors, dim=0).to(dtype=torch.int32) + + new_samples: List[Dict[str, torch.Tensor]] = [] + for i in range(num_micro_batches): + samples = grouped_samples[i] + lens_padded = padded_lens_all_gpu[seg_starts[i] : seg_starts[i + 1]] + lens_original = original_lens_all_gpu[seg_starts[i] : seg_starts[i + 1]] + new_sample = _pack_sequences(samples, lens_padded, lens_original, dev) + new_samples.append(new_sample) + + return new_samples + + +def get_batch_and_global_seqlens(data_iterator, num_microbatches, dp_group): + """ + Get the batch and global sequence lengths. + Each DP rank loads the same number of sequences, so we need to gather the sequence + lengths from all ranks then we can schedule the sequences into groups. + Args: + data_iterator: The data iterator. + num_microbatches: The number of microbatches. + dp_group: The data parallel group. + + Returns: + batch: The batch. + global_id_seqlens: The global sequence lengths. + global_ids_this_rank: The global IDs locally present on this rank. + """ + + batch_list = [next(data_iterator) for _ in range(num_microbatches)] + + batch = [] + for item in batch_list: + if isinstance(item, dict): + batch.append(item) + elif isinstance(item, list): + batch.extend(item) + else: + raise ValueError(f"Invalid item type: {type(item)}") + + # in sft_dataset.py, sequences are already packed before rescheduling, + # so we need to unpack them here and repack after rescheduling. + # This is only to adapt to the current megatron-lm sft_dataset. + # If you implement your own dataset, just have __getitem__ return List[Dict] + # and this step can be skipped. 
+ batch = _unpack_batch(batch) + + subsample_seqlens = torch.cat([sample["padded_seq_len"] for sample in batch]).to( + dtype=torch.int32, device=torch.cuda.current_device() + ) + + global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered = ( + _get_global_seqlens_and_ids(subsample_seqlens, dp_group) + ) + + return batch, global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py index cbe0652402d..04d2c279818 100644 --- a/megatron/core/datasets/gpt_dataset.py +++ b/megatron/core/datasets/gpt_dataset.py @@ -79,6 +79,9 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig): context_parallel_size: Optional[int] = None """The size of the context parallel group. Needed for padding in packed sequences.""" + sft_mock_dataset_config_json: Optional[str] = None + """This config provides the necessary information for the mock dataset.""" + def __post_init__(self) -> None: """Do asserts and set fields post init""" super().__post_init__() diff --git a/megatron/core/datasets/readme.md b/megatron/core/datasets/readme.md index 452bf24e4a2..a61c623d960 100644 --- a/megatron/core/datasets/readme.md +++ b/megatron/core/datasets/readme.md @@ -192,6 +192,68 @@ To query the `BlendedDataset` for the _k_-th sample we do the following To save time during initialization, each index is built/cached sequentially on one process rank and subsequently loaded in parallel on other process ranks. The cached indices are unique to a hash generated in the `BlendedDataset.__init__` function. +## Packing Scheduler + +The packing scheduler re-schedules variable-length sequences across DP×CP ranks to improve GPU utilization. It is built around two modules: `data_schedule.py` (high-level logic and entry points) and `data_schedule_utils.py` (utility functions). 
+ +### Call Hierarchy + +The scheduling pipeline has two phases connected by the data iterator: `wrap_data_iterator` consumes the **original** data iterator, performs global-batch scheduling, and produces a **wrapped** (packed) data iterator; `get_batch_on_this_rank_for_sequence_packing` then consumes this **wrapped** data iterator to fetch individual packed microbatches during training. + +``` + original wrapped (packed) + data_iterator data_iterator + │ │ + ▼ ▼ + ┌────────────────────────┐ ┌────────────────────────────────────┐ + │ wrap_data_iterator() │ │ get_batch_on_this_rank_for_ │ +Phase 1 │ (once per global │ ────────► │ sequence_packing() │ Phase 2 +(scheduling) │ batch) │ returns │ (once per microbatch, │ (fetching) + │ │ wrapped │ called by training loop) │ + └───────────┬────────────┘ iterator └──────────────┬─────────────────────┘ + │ │ + ▼ ▼ + DpBalancedScheduler.run() next(wrapped_data_iterator) + │ ├─ get_thd_partitioned_indices() [TE] + ├─ get_batch_and_global_seqlens() [utils] ├─ broadcast_tensor() [utils] + ├─ get_groups_and_subsamples() └─ PackedSeqParams(...) + ├─ reroute_samples_to_dcp_ranks() [utils] + ├─ build_packed_microbatches() [utils] + ├─ broadcast_to_pp_group() [utils] + ├─ broadcast_scalars() [utils] + └─ create_data_iterator() [utils] +``` + +### `data_schedule.py` + +#### Entry Points + +- **`wrap_data_iterator(original_data_iterator) → wrapped_data_iterator`** — Top-level entry point called once per global batch. Takes the **original** data iterator as input, resolves the scheduler class from `scheduler_map`, instantiates it, and delegates to `scheduler.run()` which consumes all microbatches from the original iterator, re-schedules them, and produces a **wrapped** (packed) data iterator along with the updated `num_microbatches` and FLOPs statistics. + +- **`get_batch_on_this_rank_for_sequence_packing(wrapped_data_iterator)`** — Per-microbatch entry point called by the training loop. 
Takes the **wrapped** data iterator returned by `wrap_data_iterator` as input. Fetches one packed microbatch via `next(wrapped_data_iterator)`, broadcasts batch fields across TP ranks, optionally partitions sequences across CP ranks using Transformer Engine's `thd_get_partitioned_indices`, and constructs `PackedSeqParams` (with `cu_seqlens`, `max_seqlen`, `qkv_format=thd`). + +#### Scheduler Classes + +- **`BasePackingScheduler`** — Abstract base class. Defines the interface: + - `get_groups_and_subsamples()` — pure scheduling algorithm (must be overridden). + - `run()` — full pipeline: fetch → schedule → reroute → pack → broadcast → VPP handling. + +- **`DpBalancedScheduler(BasePackingScheduler)`** — Concrete scheduler that packs sequences in their original order until reaching `max_seqlen_per_dp_cp_rank × cp_size`. Aligns the number of microbatches to `dp_size` (and VPP stage multiples when applicable). + +### `data_schedule_utils.py` + +Utility functions consumed by the schedulers above: + +| Function | Role | +|---|---| +| `get_batch_and_global_seqlens()` | Fetch `num_microbatches` batches from the data iterator and all-gather sequence lengths across DP ranks. | +| `reroute_samples_to_dcp_ranks()` | All-to-all communication to transfer sub-samples to their scheduled DP×CP rank. | +| `build_packed_microbatches()` | Concatenate sub-samples within each microbatch group and produce `cu_seqlens`. | +| `broadcast_to_pp_group()` | Broadcast packed samples and metadata from the first/last PP stage to middle stages. | +| `broadcast_scalars()` | Broadcast scalar values (e.g. `num_microbatches`, FLOPs stats) across a process group. | +| `broadcast_tensor()` | Broadcast a single tensor within a process group. | +| `create_data_iterator()` | Wrap packed sample lists into a data iterator; handles VPP stage splitting. 
| + ## Fast DataLoader initialization Especially for large-scale runs, DataLoader initialization can take several minutes, since it involves opening and memory-mapping multiple files and can significantly stress the filesystem. To speed up this process, we have developed the following three optimizations, controlled by configuration flags": diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index bb913d97446..20f0ece635e 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -2559,3 +2559,24 @@ def set_save_original_input(module): from transformer_engine.pytorch.float8_tensor import Float8Tensor except ImportError: Float8Tensor = None + + +def get_thd_partitioned_indices(cu_seqlens, total_tokens, cp_size, cp_rank): + """Get partitioned indices for THD format data in context parallel. + + Args: + cu_seqlens: Cumulative sequence lengths tensor. + total_tokens: Total number of tokens. + cp_size: Context parallel world size. + cp_rank: Context parallel rank. + + Returns: + Partitioned indices tensor. + """ + assert is_te_min_version("1.10.0"), ( + "Please update Transformer Engine to >= 1.10 to use " + "Context Parallel with THD format data" + ) + import transformer_engine_torch as tex + + return tex.thd_get_partitioned_indices(cu_seqlens, total_tokens, cp_size, cp_rank) diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index 5bbeef9b022..970b3b871fe 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -62,7 +62,7 @@ class ModelParallelConfig: can handle without overflowing the memory. Typically, a good starting point is to set this to maximum sequence length / context parallel size. This is used to calculate the number and length of sub-samples assigned to - each rank when using hybrid_context_parallel. + each rank when sequence_packing_scheduler is not None. 
""" hybrid_context_parallel: bool = False @@ -72,6 +72,12 @@ class ModelParallelConfig: Please set max_seqlen_per_dp_cp_rank when using hybrid_context_parallel. """ + sequence_packing_scheduler: Optional[Literal['dp_balanced']] = None + """ + Scheduler for sequence packing and hybrid context parallel. + dp_balanced: DP-balanced scheduler for sequence packing. + """ + expert_model_parallel_size: int = 1 """Distributes Moe Experts across sub data parallel dimension.""" diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 9da9a644a47..d48e29c1e71 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -2076,6 +2076,40 @@ def __post_init__(self): self.attention_backend == AttnBackend.flash ), "Batch invariant mode only supports FlashAttention" + if self.sequence_packing_scheduler is not None: + # Check TE version. + if not HAVE_PACKAGING: + raise ImportError( + "packaging is not installed. Please install it with `pip install packaging`." + ) + # TODO: remove this after we fix the convergence issue with TE < 2.9. + if not ( + is_te_min_version("2.9.0") or get_te_version() == PkgVersion("2.9.0.dev0+5b3092a") + ): + raise ValueError( + "SFT sequence packing requires Transformer Engine >= 2.9.0 " + f"but got {get_te_version()} (TE < 2.9.0 may have convergence issues)." + ) + + # Needed for passing variable sequences between pp stages. + self.variable_seq_lengths = True + + # TODO(tailaim): add support for other dispatcher types + assert self.moe_token_dispatcher_type == "alltoall", ( + f"sequence_packing only supports moe_token_dispatcher_type='alltoall', " + f"got '{self.moe_token_dispatcher_type}'" + ) + + supported_schedulers = ['dp_balanced'] + if ( + self.sequence_packing_scheduler is not None + and self.sequence_packing_scheduler not in supported_schedulers + ): + raise ValueError( + f"Unsupported scheduler: {self.sequence_packing_scheduler}. 
" + f"Available schedulers: {supported_schedulers}" + ) + @dataclass @experimental_api diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 5d5fa34b6c5..25f0d0d06d0 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -884,13 +884,6 @@ def validate_args(args, defaults={}): if args.rl_use_sequence_packing: args.consumed_train_bins = 0 - # Support for variable sequence lengths across batches/microbatches. - # set it if the dataloader supports generation of variable sequence lengths - # across batches/microbatches. Due to additional communication overhead - # during pipeline parallelism, it should not be set if sequence length - # is constant during training. - args.variable_seq_lengths = False - # Iteration-based training. if args.train_iters: # If we use iteration-based training, make sure the @@ -1061,6 +1054,11 @@ def validate_args(args, defaults={}): assert args.dataloader_type == 'single', 'Hybrid context parallelism only supported with single dataloader type' assert args.calculate_per_token_loss, 'Hybrid context parallelism must be used with --calculate-per-token-loss' + if args.sequence_packing_scheduler is not None: + assert args.context_parallel_size * args.max_seqlen_per_dp_cp_rank >= args.seq_length, \ + f'Packed sequence buffer size ({args.context_parallel_size * args.max_seqlen_per_dp_cp_rank}) ' \ + f'must be >= single sequence max length ({args.seq_length})' + # disable async_tensor_model_parallel_allreduce when # model parallel memory optimization is enabled if (args.tensor_model_parallel_size > 1 or args.context_parallel_size > 1) \ @@ -3061,4 +3059,8 @@ def _add_sft_args(parser): group.add_argument('--sft', action="store_true", help='Megatron SFT training') group.add_argument('--sft-tokenizer-prompt-format', type=str, default="nemotron-h-aligned", help='SFT prompt format.') + group.add_argument('--sft-mock-dataset-config-json', type=str, default=None, + help='This config provides the 
necessary information for the mock dataset. You can either specify a CSV file that contains sequence lengths, where each line stores the length of a sequence, for example: {"mode":"file","path":"/path/to/file"}. Alternatively, you can specify a distribution (currently only supporting lognormal distribution) along with the required parameters, for example, {"mode":"distribution","type":"lognormal","min_seq_len":1024,"max_seq_len":2048,"mean_seq_len":1536,"lognormal_sigma":1.1}, where sigma controls the variability of the lognormal distribution. ' + 'If not specified and --mock-data is set, defaults to a lognormal distribution with ' + 'min_seq_len=seq_length//2, max_seq_len=seq_length, mean_seq_len=seq_length*3//4, lognormal_sigma=1.1.') return parser diff --git a/megatron/training/datasets/sft_dataset.py b/megatron/training/datasets/sft_dataset.py index 9de5d2a52fe..3f2e6e7362c 100644 --- a/megatron/training/datasets/sft_dataset.py +++ b/megatron/training/datasets/sft_dataset.py @@ -2,12 +2,16 @@ import atexit, json from collections import Counter -from typing import Any, Dict, Optional +import json +import math +from typing import Any, Dict, Optional, List, Union import numpy as np +import pandas as pd import torch from megatron.core.datasets.gpt_dataset import GPTDatasetConfig +from megatron.core.datasets.indexed_dataset import IndexedDataset from megatron.core.datasets.megatron_dataset import LowLevelDataset, MegatronDataset from megatron.core.datasets.utils import Split @@ -88,6 +92,26 @@ def _split_conversations(self, merged_conversations): split_conversations.append(current) return split_conversations + def _calculate_padding_divisor(self) -> int: + """ + Calculate the divisor used for sequence padding. 
+ tp_pad = tp_size * 2 if tp_size > 1 else 1 + cp_pad = cp_size * 2 if cp_size > 1 else 1 + cp_pad = cp_pad * dp_size if hybrid_cp else cp_pad + divisor = cp_pad * tp_pad + """ + if self.config.hybrid_context_parallel: + # Hybrid CP: consider both CP and DP + cp_pad = self.config.data_parallel_size * self.config.context_parallel_size * 2 + else: + # Standard CP: only consider CP + cp_pad = self.config.context_parallel_size * 2 if self.config.context_parallel_size > 1 else 1 + tp_pad = self.config.sequence_parallel_size if self.config.sequence_parallel_size > 0 else 1 + divisor = cp_pad * tp_pad + # TODO(tailaim): do we need to pad for FP8 execution? + # divisor = ((divisor + 15) // 16) * 16 + return divisor + def __getitem__(self, idx: int) -> Dict[str, Any]: tokenizer = self.config.tokenizer @@ -124,12 +148,11 @@ def extend_with_padding(tokens, targets, positions, pad_len): assert not self.config.reset_position_ids pack_positions.extend(range(len(tokens_list))) - if self.config.context_parallel_size > 1: - pad_granularity = self.config.context_parallel_size * 2 - mod_token_count = len(pack_tokens) % pad_granularity - if mod_token_count != 0: - pad_len = pad_granularity - mod_token_count - extend_with_padding(pack_tokens, pack_targets, pack_positions, pad_len) + pad_granularity = self._calculate_padding_divisor() + mod_token_count = len(pack_tokens) % pad_granularity + if mod_token_count != 0: + pad_len = pad_granularity - mod_token_count + extend_with_padding(pack_tokens, pack_targets, pack_positions, pad_len) # TODO(duncan): Consider also padding to multiple of number of tokens here. This might # be needed for efficiency (and potentially set via command-line argument). @@ -190,3 +213,214 @@ def extend_with_padding(tokens, targets, positions, pad_len): 'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, } + + +class MockSFTLowLevelDataset: + """The low-level mock dataset for SFT + + Args: + mode (str): One of 'file', 'distribution', or 'verification'. 
+ **kwargs: Additional arguments depending on mode. + For mode='file': path (str) - path to a CSV file with sequence lengths. + For mode='distribution': type (str), min_seq_len (int), max_seq_len (int), + mean_seq_len (int), and distribution-specific params (e.g. lognormal_sigma). + For mode='verification': data_path (str) - prefix path to an IndexedDataset + (.bin/.idx files). Optional lognormal distribution params same as + 'distribution' mode (defaults: min_seq_len=100, max_seq_len=4096, + mean_seq_len=2048, lognormal_sigma=1.1). + format (str): Output format for MockSFTDataset. Either 'thd' (default, sequence + packing with cu_seqlens) or 'sbhd' (padded to seq_length, no cu_seqlens). + """ + + seed: int = 0 + """The hard-coded random seed to use to set the NumPy RNG""" + + size: int = 1000000 + """The hard-coded number of sequence to generate""" + + def __init__(self, mode: str, **kwargs) -> None: + np.random.seed(self.seed) + self.format = kwargs.get("format", "thd") + + if mode == "file": + self.sequence_lengths = np.array(pd.read_csv(kwargs["path"])).flatten() + self.size = len(self.sequence_lengths) + elif mode == "distribution": + min_seq_len = kwargs["min_seq_len"] + max_seq_len = kwargs["max_seq_len"] + mean_seq_len = kwargs["mean_seq_len"] + if kwargs["type"] == "lognormal": + lognormal_sigma = kwargs["lognormal_sigma"] + self.sequence_lengths = self.generate_lognormal_samples( + self.size, mean_seq_len, lognormal_sigma, min_seq_len, max_seq_len + ) + else: + raise ValueError(f"Unsupported distribution type {kwargs['type']}") + elif mode == "verification": + # Load real tokens from an IndexedDataset for realistic loss curves. + # Sequence lengths are drawn from a lognormal distribution (same as + # "distribution" mode) to allow controlled comparison of THD vs SBHD. 
+ self.indexed_dataset = IndexedDataset(kwargs["data_path"]) + min_seq_len = kwargs.get("min_seq_len", 100) + max_seq_len = kwargs.get("max_seq_len", 4096) + mean_seq_len = kwargs.get("mean_seq_len", 2048) + lognormal_sigma = kwargs.get("lognormal_sigma", 1.1) + self.sequence_lengths = self.generate_lognormal_samples( + self.size, mean_seq_len, lognormal_sigma, min_seq_len, max_seq_len + ) + else: + raise ValueError(f"Unsupported mode '{mode}', must be 'file', 'distribution', or 'verification'") + + def generate_lognormal_samples(self, size, mean, sigma, min_seq_len, max_seq_len): + mu = np.log(mean) - sigma**2 / 2 + samples = np.random.lognormal(mu, sigma, size) + samples = np.clip(samples, min_seq_len, max_seq_len) + return samples.astype(int) + + def __len__(self) -> int: + return self.size + + def __getitem__(self, idx: int) -> np.ndarray: + # The returned sample has 'length-1' tokens; an EOD token is appended + # later in MockSFTDataset.__getitem__, making the total 'length' tokens. + length = int(self.sequence_lengths[idx % self.size]) + if hasattr(self, 'indexed_dataset'): + target = length - 1 + num_docs = len(self.indexed_dataset) + doc_idx = idx % num_docs + raw = self.indexed_dataset[doc_idx] + if len(raw) >= target: + sample = raw[:target] + else: + # Concatenate documents until we reach the target length. 
+ chunks = [raw] + total = len(raw) + next_doc = doc_idx + 1 + while total < target: + raw_next = self.indexed_dataset[next_doc % num_docs] + need = target - total + chunks.append(raw_next[:need]) + total += min(len(raw_next), need) + next_doc += 1 + sample = np.concatenate(chunks)[:target] + assert len(sample) == target + return sample.astype(np.int64) + else: + return np.arange(1, length, dtype=np.int64) + + +class MockSFTDataset(SFTDataset): + """The mock dataset used during SFT""" + + def __init__( + self, + dataset: LowLevelDataset, + dataset_path: Optional[str], + indices: np.ndarray, + num_samples: Optional[int], + index_split: Split, + config: GPTDatasetConfig, + ) -> None: + super().__init__(dataset, dataset_path, indices, num_samples, index_split, config) + + @staticmethod + def build_low_level_dataset(dataset_path: str, config: GPTDatasetConfig) -> LowLevelDataset: + if config.sft_mock_dataset_config_json is None: + mock_config = { + "mode": "distribution", + "type": "lognormal", + "min_seq_len": config.sequence_length // 2, + "max_seq_len": config.sequence_length, + "mean_seq_len": config.sequence_length // 4 * 3, + "lognormal_sigma": 1.1, + } + else: + mock_config = json.loads(config.sft_mock_dataset_config_json) + return MockSFTLowLevelDataset(**mock_config) + + def __len__(self) -> int: + return self.num_samples + + def __getitem__(self, idx: int) -> Dict[str, Any]: + + tokenizer = self.config.tokenizer + pack_length = self.config.sequence_length + eod = tokenizer.eod + pad = tokenizer.pad + + tokens = self.dataset[int(self.indices[idx % len(self.indices)])] + + # Convert tokens to list and always append EOD to ensure length consistency. + # The low-level dataset returns length-1 tokens, and we add EOD to make it length tokens. + tokens_list = tokens.tolist() + tokens_list.append(eod) + + if self.dataset.format == "sbhd": + # SBHD format: single padded sequence without cu_seqlens. 
+ # Long sequences are truncated to pack_length tokens (including EOD). + if len(tokens_list) >= pack_length + 1: + tokens_list = tokens_list[:pack_length - 1] + [eod] + # Pad to pack_length + 1 (offset by 1 for input/label split). + pad_len = pack_length + 1 - len(tokens_list) + if pad_len > 0: + tokens_list = tokens_list + [pad] * pad_len + assert len(tokens_list) == pack_length + 1 + input_ids = torch.tensor(tokens_list[:-1], dtype=torch.int64) + labels = torch.tensor(tokens_list[1:], dtype=torch.int64) + # Position IDs are sequential across the entire sequence including padding, + # matching GPTDataset behavior for standard (non-packed) training. + position_ids = torch.arange(pack_length, dtype=torch.int64) + loss_mask = torch.ones(pack_length, dtype=torch.float32) + loss_mask[labels == pad] = 0.0 + return { + 'tokens': input_ids, + 'labels': labels, + 'loss_mask': loss_mask, + 'position_ids': position_ids, + } + + # THD format (sequence packing) below. + def extend_with_padding(tokens, positions, pad_len): + tokens.extend([pad] * pad_len) + positions.extend(range(positions[-1] + 1, positions[-1] + 1 + pad_len)) + + pack_tokens = list(tokens_list) + [pad] + pack_positions = list(range(len(pack_tokens))) + + # Truncate if sequence exceeds pack_length + 1 (need +1 for shift). + if len(pack_tokens) > pack_length + 1: + pack_tokens = pack_tokens[:pack_length - 1] + [eod, pad] + pack_positions = pack_positions[:pack_length + 1] + + # Pad to pad_granularity alignment (tp * cp * 2). + # We need final length (after shift) to be divisible by pad_granularity. + pad_granularity = self._calculate_padding_divisor() + final_len = len(pack_tokens) - 1 + mod_token_count = final_len % pad_granularity + if mod_token_count != 0: + pad_len = pad_granularity - mod_token_count + extend_with_padding(pack_tokens, pack_positions, pad_len) + + # Apply shift for next-token prediction. 
+ input_ids = torch.tensor(pack_tokens[:-1], dtype=torch.int64) + labels = torch.tensor(pack_tokens[1:], dtype=torch.int64) + position_ids = torch.tensor(pack_positions[:-1], dtype=torch.int64) + + seq_len = len(input_ids) + cu_seqlens = [0, seq_len] + + # Loss mask: mask padding tokens + loss_mask = torch.ones(seq_len, dtype=torch.float32) + loss_mask[labels == pad] = 0.0 + + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32) + max_seqlen = torch.tensor(seq_len, dtype=torch.int32) + + return { + 'tokens': input_ids, + 'labels': labels, + 'loss_mask': loss_mask, + 'position_ids': position_ids, + 'cu_seqlens': cu_seqlens, + 'max_seqlen': max_seqlen, + } diff --git a/megatron/training/training.py b/megatron/training/training.py index 0c33206ba8b..26769fabe96 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -169,6 +169,7 @@ def set_startup_timestamps(program_start=None, main_entry=None): get_num_microbatches, update_num_microbatches ) +from megatron.core.datasets.data_schedule import wrap_data_iterator from .async_utils import maybe_finalize_async_save from .utils import ( @@ -225,7 +226,7 @@ def print_datetime(string, override_timestamp=None): time_str = datetime.fromtimestamp(override_timestamp).strftime('%Y-%m-%d %H:%M:%S.%f') print_rank_0(f'[{string}] datetime: {time_str} ') -def num_floating_point_operations(args, batch_size): +def num_floating_point_operations(args, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch): def calculate_layer_counts(): """Calculate the number of attention, Mamba, and MLP layers.""" if args.hybrid_override_pattern: @@ -251,44 +252,42 @@ def calculate_layer_counts(): num_moe_layers = 0 return num_attn_layers, num_mamba_layers, num_mlp_layers, num_moe_layers - def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=False): + def mlp_layer_flops(seqlen_sum_this_global_batch, hidden_size, expansion=4.0, swiglu=False): """Calculate FLOPs for an MLP layer.""" 
scale_factor = 3.0 / 2.0 if swiglu else 1.0 - return 4 * expansion * scale_factor * batch_size * seq_len * hidden_size**2 + return 4 * expansion * scale_factor * seqlen_sum_this_global_batch * hidden_size**2 - def moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size, + def moe_layer_flops(seqlen_sum_this_global_batch, hidden_size, moe_ffn_hidden_size, shared_expert_ffn_hidden_size, num_experts_routed_to, moe_latent_size=None, swiglu=False): """Calculate FLOPs for an MoE layer.""" scale_factor = 3.0 / 2.0 if swiglu else 1.0 if moe_latent_size is None: - routed_flops = (4 * batch_size * seq_len * hidden_size * + routed_flops = (4 * seqlen_sum_this_global_batch * hidden_size * moe_ffn_hidden_size * num_experts_routed_to * scale_factor) else: # Routed experts run on moe_latent_size. - routed_flops = (4 * batch_size * seq_len * moe_latent_size * + routed_flops = (4 * seqlen_sum_this_global_batch * moe_latent_size * moe_ffn_hidden_size * num_experts_routed_to * scale_factor) # Up proj and down proj. 
- routed_flops += (4 * batch_size * seq_len * hidden_size * moe_latent_size) - shared_flops = 4 * batch_size * seq_len * hidden_size * shared_expert_ffn_hidden_size * scale_factor + routed_flops += (4 * seqlen_sum_this_global_batch * hidden_size * moe_latent_size) + shared_flops = 4 * seqlen_sum_this_global_batch * hidden_size * shared_expert_ffn_hidden_size * scale_factor return routed_flops + shared_flops def attn_layer_flops( - batch_size, seq_len, hidden_size, num_heads, gqa=True, gqa_groups=8, kv_channels=None + seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch, hidden_size, num_heads, gqa=True, gqa_groups=8, kv_channels=None ): """Calculate FLOPs for an attention layer.""" p = (kv_channels * num_heads / hidden_size) if kv_channels else 1 g = gqa_groups if gqa else num_heads return ( 4 - * batch_size - * seq_len * hidden_size * p - * (hidden_size + (hidden_size * (g / num_heads)) + (seq_len / 2)) + * (hidden_size * seqlen_sum_this_global_batch + (hidden_size * (g / num_heads)) * seqlen_sum_this_global_batch + (seqlen_squared_sum_this_global_batch / 2)) ) - def mamba_layer_flops(batch_size, seq_len, hidden_size, state_dim=16, + def mamba_layer_flops(seqlen_sum_this_global_batch, hidden_size, state_dim=16, head_dim=64, num_groups=1, num_heads=128): """Calculate FLOPs for a Mamba layer.""" # Note (rwaleffe): flops estimate for scan should be updated based on new SSD kernels, @@ -301,16 +300,15 @@ def mamba_layer_flops(batch_size, seq_len, hidden_size, state_dim=16, return ( ( 2 - * batch_size - * seq_len + * seqlen_sum_this_global_batch * hidden_size * (2 * d_in + 2 * num_groups * state_dim + nheads) ) # in_proj - + (7 * batch_size * seq_len * d_in * state_dim) # scan - + (2 * batch_size * seq_len * d_in * hidden_size) # out_proj + + (7 * seqlen_sum_this_global_batch * d_in * state_dim) # scan + + (2 * seqlen_sum_this_global_batch * d_in * hidden_size) # out_proj ) - def hybrid_flops(batch_size, seq_len, hidden_size, + def 
hybrid_flops(seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch, hidden_size, num_attn_layers, num_mamba_layers, num_mlp_layers, num_moe_layers, mamba_state_dim=128, mamba_head_dim=64, mamba_num_groups=8, mamba_num_heads=128, @@ -322,17 +320,17 @@ def hybrid_flops(batch_size, seq_len, hidden_size, vocab_size=256000, mtp_num_layers=0): """Calculate total FLOPs for the hybrid model.""" flops_fwd = ( - num_attn_layers * attn_layer_flops(batch_size, seq_len, hidden_size, + num_attn_layers * attn_layer_flops(seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch, hidden_size, num_attn_heads, gqa, gqa_groups, kv_channels) + - num_mlp_layers * mlp_layer_flops(batch_size, seq_len, hidden_size, + num_mlp_layers * mlp_layer_flops(seqlen_sum_this_global_batch, hidden_size, mlp_expansion, swiglu) + - num_mamba_layers * mamba_layer_flops(batch_size, seq_len, hidden_size, + num_mamba_layers * mamba_layer_flops(seqlen_sum_this_global_batch, hidden_size, mamba_state_dim, mamba_head_dim, mamba_num_groups, mamba_num_heads) + - num_moe_layers * moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size, + num_moe_layers * moe_layer_flops(seqlen_sum_this_global_batch, hidden_size, moe_ffn_hidden_size, shared_expert_ffn_hidden_size, num_experts_routed_to, moe_latent_size, swiglu) + - (2 * batch_size * seq_len * hidden_size * vocab_size * (1 + mtp_num_layers)) # logits computation + (2 * seqlen_sum_this_global_batch * hidden_size * vocab_size * (1 + mtp_num_layers)) # logits computation ) return flops_fwd * 3 @@ -403,13 +401,18 @@ def transformer_flops(): assert not args.group_query_attention ''' Basic arithmetic - let B is batch size, s is seq_len, h is embedding dim, - for one self_attnetion block (prenorm is not included) - qkv projection: 6Bsh^2 - attn: 2Bs^2h - attn over value: 2Bs^2h - oproj: 2Bsh^2 - + + Let h be the embedding dim. 
+ We use two statistics to unify BSHD and THD cases: + seqlen_sum_this_global_batch: total number of tokens in this global batch + seqlen_squared_sum_this_global_batch: sum of squared sequence lengths in this global batch + + For one self-attention block (prenorm not included): + qkv projection: 6 * seqlen_sum_this_global_batch * h^2 + attn: 2 * seqlen_squared_sum_this_global_batch * h + attn over value: 2 * seqlen_squared_sum_this_global_batch * h + oproj: 2 * seqlen_sum_this_global_batch * h^2 + references https://arxiv.org/abs/2305.10403 https://arxiv.org/abs/2205.05198 @@ -430,7 +433,7 @@ def transformer_flops(): standard_self_attn_term = ( forward_backward_expansion_factor * fma_expansion_factor - * ( + * ( seqlen_sum_this_global_batch * ( ## q lora + rope + q norm q_term ## kv lora + rope + kv norm @@ -442,12 +445,12 @@ def transformer_flops(): ) + args.hidden_size * args.qk_pos_emb_head_dim ## o proj - + (args.num_attention_heads * args.v_head_dim) * args.hidden_size + + (args.num_attention_heads * args.v_head_dim) * args.hidden_size) ## core attn - + args.seq_length + + seqlen_squared_sum_this_global_batch * (args.num_attention_heads * (args.qk_head_dim + args.qk_pos_emb_head_dim)) - / 2 # causal mask (only half of the mask is non-zero) - + args.seq_length * args.num_attention_heads * args.v_head_dim / 2 + / 2 # causal mask (only half of the mask is non-zero) + + seqlen_squared_sum_this_global_batch * args.num_attention_heads * args.v_head_dim / 2 ) ) @@ -460,7 +463,7 @@ def transformer_flops(): standard_self_attn_term = ( forward_backward_expansion_factor * fma_expansion_factor - * ( + * ( seqlen_sum_this_global_batch *( ## qkv proj args.hidden_size * ( @@ -468,14 +471,14 @@ def transformer_flops(): + key_projection_size + value_projection_size + gate_projection_size - ) + )) ## core attention + query_projection_size - * args.seq_length + * seqlen_squared_sum_this_global_batch / 2 # causal mask (only half of the mask is non-zero) * 2 # QK^T and (QK^T)V ## 
out proj - + query_projection_size + + seqlen_sum_this_global_batch * query_projection_size * args.hidden_size ) ) @@ -536,7 +539,7 @@ def transformer_flops(): + args.hidden_size * v_dim ) - ) + ) * seqlen_sum_this_global_batch else: raise ValueError( "Invalid experimental_attention_variant: " @@ -553,8 +556,7 @@ def transformer_flops(): ) total_floating_point_operations = ( - batch_size - * args.seq_length + seqlen_sum_this_global_batch * ( # MLP forward_backward_expansion_factor @@ -584,8 +586,6 @@ def transformer_flops(): + (shared_expert_ffn_hidden_size * ffn_expansion_factor) * num_moe_layers ) - # Self Attention - + self_attn_term # MTP norms and proj + forward_backward_expansion_factor * fma_expansion_factor @@ -603,6 +603,10 @@ def transformer_flops(): * args.padded_vocab_size * (mtp_num_layers + 1) # MTP + final logit ) + + + # Self Attention + self_attn_term + ) return total_floating_point_operations @@ -616,8 +620,8 @@ def transformer_flops(): mtp_num_layers = 0 # Compute hybrid model FLOPs. return hybrid_flops( - batch_size=batch_size, - seq_len=args.seq_length, + seqlen_sum_this_global_batch=seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch=seqlen_squared_sum_this_global_batch, hidden_size=args.hidden_size, num_attn_layers=num_attn_layers, num_mamba_layers=num_mamba_layers, @@ -1728,6 +1732,27 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch if isinstance(optim_instance, DistributedOptimizer): optim_instance.release_offloaded_gpu_states() + if config.sequence_packing_scheduler is not None: + # This wrapper is designed to support DP-balanced THD and dynamic-CP. + # Before wrapping, the data_iterator returns either a single sequence per get_item call, or a list where each element is a sequence. + # The wrapper is responsible for: + # 1. scheduling the sequences across ranks + # 2. packing them into THD format + # 3. 
broadcasting FLOPs parameters and num_microbatches to TP ranks to support unfixed num_microbatches + # 4. broadcasting metadata (cu_seqlens, cu_seqlens_padded, max_seqlen, etc.) to PP ranks + # 5. returning the packed data iterator and the FLOPs parameters + ( + data_iterator, + num_microbatches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ) = wrap_data_iterator(data_iterator, config, get_num_microbatches()) + else: + # data_iterator unchanged + num_microbatches = get_num_microbatches() + seqlen_sum_this_global_batch = args.seq_length * args.global_batch_size + seqlen_squared_sum_this_global_batch = args.seq_length ** 2 * args.global_batch_size + # Forward pass. if save_dgrads_in_this_iteration: enable_dgrad_logging(model, args.save) @@ -1735,7 +1760,7 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch forward_step_func=forward_step_func, data_iterator=data_iterator, model=model, - num_microbatches=get_num_microbatches(), + num_microbatches=num_microbatches, seq_length=args.seq_length, micro_batch_size=args.micro_batch_size, decoder_seq_length=args.decoder_seq_length, @@ -1768,7 +1793,7 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit() if should_exit: - return {}, True, should_checkpoint, should_exit, exit_code, None, None, 0 + return {}, True, should_checkpoint, should_exit, exit_code, None, None, 0, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch # Empty unused memory. 
if args.empty_unused_memory_level >= 1: @@ -1848,8 +1873,10 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch grad_norm, num_zeros_in_grad, log_max_attention_logit, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, ) - return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad, log_max_attention_logit + return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad, log_max_attention_logit, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch def training_log( @@ -1864,6 +1891,8 @@ def training_log( params_norm, num_zeros_in_grad, max_attention_logit, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, pg_collection=None, is_first_iteration=False, ): @@ -2096,7 +2125,7 @@ def training_log( elapsed_time = timers('interval-time').elapsed(barrier=True, reset=should_reset) elapsed_time_per_iteration = elapsed_time / total_iterations - throughput = num_floating_point_operations(args, batch_size) / ( + throughput = num_floating_point_operations(args,seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch) / ( elapsed_time_per_iteration * 10**12 * args.world_size ) @@ -2864,6 +2893,8 @@ def trace_handler(p): # Completely skip iteration if needed. if iteration in args.iterations_to_skip: + # TODO(tailaim): this need to be modified + assert config.sequence_packing_scheduler is None, "Sequence packing scheduler is not supported in skip iteration mode" # Dummy train_step to fast forward train_data_iterator. 
dummy_train_step(train_data_iterator) if iteration == start_iteration: @@ -2906,6 +2937,8 @@ def trace_handler(p): grad_norm, num_zeros_in_grad, max_attention_logit, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, ) = train_step( forward_step_func, train_data_iterator, model, optimizer, opt_param_scheduler, config, forward_backward_func, iteration=iteration ) @@ -2993,7 +3026,7 @@ def trace_handler(p): else: assert num_skipped_samples_in_batch == 0 args.skipped_train_samples += num_skipped_samples_in_batch - num_floating_point_operations_in_batch = num_floating_point_operations(args, batch_size) + num_floating_point_operations_in_batch = num_floating_point_operations(args, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch) num_floating_point_operations_so_far += num_floating_point_operations_in_batch num_floating_point_operations_since_last_log_event += num_floating_point_operations_in_batch @@ -3019,6 +3052,8 @@ def trace_handler(p): params_norm, num_zeros_in_grad, max_attention_logit, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, pg_collection=model_pg_collection, is_first_iteration=is_first_iteration, ) @@ -3214,9 +3249,30 @@ def evaluate( # Don't care about timing during evaluation config.timers = None ft_integration.on_eval_step_start() + if config.sequence_packing_scheduler is not None: + # This wrapper is designed to support DP-balanced THD and dynamic-CP. + # Before wrapping, the data_iterator returns either a single sequence per get_item call, or a list where each element is a sequence. + # The wrapper is responsible for: + # 1. scheduling the sequences across ranks + # 2. packing them into THD format + # 3. broadcasting FLOPs parameters and num_microbatches to TP ranks to support unfixed num_microbatches + # 4. broadcasting metadata (cu_seqlens, cu_seqlens_padded, max_seqlen, etc.) to PP ranks + # 5. 
returning the packed data iterator and the FLOPs parameters + try: + ( + packed_data_iterator, + eval_num_microbatches, + _, + _, + ) = wrap_data_iterator(data_iterator, config, eval_num_microbatches) + except StopIteration: + # Validation data iterator exhausted, stop evaluation early. + break + else: + packed_data_iterator = data_iterator loss_dicts = forward_backward_func( forward_step_func=forward_step_func, - data_iterator=data_iterator, + data_iterator=packed_data_iterator, model=model, num_microbatches=eval_num_microbatches, seq_length=args.seq_length, diff --git a/pretrain_gpt.py b/pretrain_gpt.py index e6ce7ac2a48..083f97b0a2f 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -25,6 +25,7 @@ from megatron.core import parallel_state from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset +from megatron.core.datasets.data_schedule import get_batch_on_this_rank_for_sequence_packing from megatron.core.enums import ModelType from megatron.core.models.gpt import GPTModel from megatron.core.rerun_state_machine import get_rerun_state_machine @@ -49,6 +50,7 @@ get_blend_and_blend_per_split, is_first_or_last_pipeline_stage, ) +from megatron.training.datasets.sft_dataset import SFTDataset, MockSFTDataset from model_provider import model_provider try: @@ -66,6 +68,15 @@ def get_batch(data_iterator, vp_stage: Optional[int] = None): """Generate a batch.""" args = get_args() config = core_transformer_config_from_args(args) + + if args.sequence_packing_scheduler is not None: + return get_batch_on_this_rank_for_sequence_packing( + data_iterator, + vpp_size=config.virtual_pipeline_model_parallel_size, + mtp_on_this_rank=mtp_on_this_rank(config, ignore_virtual=False, vp_stage=vp_stage), + vp_stage=vp_stage, + ) + # TODO: this is pretty hacky, find a better way if not is_first_or_last_pipeline_stage(vp_stage) and ( (not 
mtp_on_this_rank(config, ignore_virtual=False, vp_stage=vp_stage))): @@ -250,6 +261,7 @@ def core_gpt_dataset_config_from_args(args): "data_parallel_size": args.data_parallel_size, "sequence_parallel_size": args.tensor_model_parallel_size*args.sequence_parallel, "hybrid_context_parallel": args.hybrid_context_parallel, + "sft_mock_dataset_config_json":args.sft_mock_dataset_config_json, } # add FIM args to the config @@ -287,7 +299,10 @@ def train_valid_test_datasets_provider(train_val_test_num_samples, vp_stage=None config = core_gpt_dataset_config_from_args(args) if args.sft: - dataset_type = SFTDataset + if args.mock_data: + dataset_type = MockSFTDataset + else: + dataset_type = SFTDataset else: if args.mock_data: dataset_type = MockGPTDataset diff --git a/tests/unit_tests/models/test_mamba_moe_model.py b/tests/unit_tests/models/test_mamba_moe_model.py index 39b4a18e243..9797f5c20f7 100644 --- a/tests/unit_tests/models/test_mamba_moe_model.py +++ b/tests/unit_tests/models/test_mamba_moe_model.py @@ -275,6 +275,7 @@ "offload_modules": [], "hybrid_context_parallel": False, "max_seqlen_per_dp_cp_rank": None, + "sequence_packing_scheduler": None, "fallback_to_eager_attn": False, "linear_attention_type": None, "moe_router_force_biased": None, diff --git a/tests/unit_tests/test_sequence_packing.py b/tests/unit_tests/test_sequence_packing.py new file mode 100644 index 00000000000..60316b0236e --- /dev/null +++ b/tests/unit_tests/test_sequence_packing.py @@ -0,0 +1,479 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ +import random +from types import SimpleNamespace + +import numpy as np +import pytest +import torch + +from megatron.core import parallel_state +from megatron.core.datasets.data_schedule import ( + get_batch_on_this_rank_for_sequence_packing, + wrap_data_iterator, +) +from megatron.core.rerun_state_machine import RerunDataIterator +from megatron.training.global_vars import unset_global_variables +from tests.unit_tests.test_utilities import Utils + + +class MockVariableLengthSequencePackingDataIterator: + """ + Mock data iterator for testing get_batch_on_this_rank_for_sequence_packing. + + Generates variable-length (THD format) packed sequences with deterministic + data for verification across parallel ranks. + """ + + def __init__( + self, + total_seq_length: int, + sequence_lengths: list, + local_cp_size: int = None, + device: str = "cuda", + seed: int = 42, + ): + """ + Args: + total_seq_length: Total length of packed sequences + sequence_lengths: List of individual sequence lengths (variable-length). + Must sum to total_seq_length (validated by an assert below). 
+ device: Device to create tensors on + seed: Random seed for reproducibility + """ + self.total_seq_length = total_seq_length + self.sequence_lengths = sequence_lengths + self.local_cp_size = local_cp_size + self.device = device + self.seed = seed + assert ( + sum(self.sequence_lengths) == total_seq_length + ), f"Sequence lengths sum {sum(self.sequence_lengths)} != total {total_seq_length}" + + def __iter__(self): + """Interface for the data iterator.""" + return self + + def __next__(self): + """Generate a mock batch with variable-length THD format.""" + dev = self.device + torch.manual_seed(self.seed) + torch.cuda.manual_seed(self.seed) + + tokens = torch.randint(0, 16384, (self.total_seq_length,), dtype=torch.int64, device=dev) + + # Create position_ids that reset for each sequence (THD format) + position_ids = [] + for seq_len in self.sequence_lengths: + position_ids.extend(range(seq_len)) + position_ids = torch.tensor(position_ids, dtype=torch.int64, device=dev) + + # Labels are tokens shifted by 1 for easy verification + labels = tokens + 1 + + # Loss mask: 1.0 for all positions except padding (none here) + loss_mask = torch.ones(self.total_seq_length, dtype=torch.float32, device=dev) + + # Create cu_seqlens for variable-length packed sequences + cu_seqlens = [0] + for seq_len in self.sequence_lengths: + cu_seqlens.append(cu_seqlens[-1] + seq_len) + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=dev) + cu_seqlens_padded = cu_seqlens.clone() + + max_seqlen = torch.tensor([max(self.sequence_lengths)], dtype=torch.int32, device=dev) + + batch = { + "tokens": tokens, + "position_ids": position_ids, + "labels": labels, + "loss_mask": loss_mask, + "cu_seqlens": cu_seqlens, + "cu_seqlens_padded": cu_seqlens_padded, + "max_seqlen": max_seqlen, + } + + if not ( + parallel_state.is_pipeline_first_stage(ignore_virtual=True) + or parallel_state.is_pipeline_last_stage(ignore_virtual=True) + ): + batch["tokens"] = None + batch["position_ids"] = None + 
batch["labels"] = None + batch["loss_mask"] = None + + if self.local_cp_size is not None: + batch["local_cp_size"] = torch.tensor( + [self.local_cp_size], dtype=torch.int32, device=dev + ) + + return batch + + +def _gather_tensor_from_tp_group(tensor): + """Gather tensors from all TP ranks for comparison.""" + assert tensor is not None, "Tensor should not be None" + tp_size = parallel_state.get_tensor_model_parallel_world_size() + gathered = [torch.zeros_like(tensor) for _ in range(tp_size)] + torch.distributed.all_gather( + gathered, tensor, group=parallel_state.get_tensor_model_parallel_group() + ) + return gathered + + +def _gather_tensor_from_all_ranks(tensor): + """Gather tensors from all PP ranks for comparison.""" + assert tensor is not None, "Tensor should not be None" + if type(tensor) is int: + tensor = torch.tensor(tensor, dtype=torch.int32, device=torch.cuda.current_device()) + gathered = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(gathered, tensor) + return gathered + + +@pytest.mark.parametrize( + ("tp", "pp", "cp"), + [ + (1, 1, 1), # Basic case: no parallelism + (2, 1, 1), # Tensor parallel only + (1, 2, 1), # Pipeline parallel only + (2, 2, 1), # TP + PP + (1, 1, 2), # CP only + (2, 1, 2), # TP + CP + (1, 2, 2), # PP + CP + (1, 4, 1), # Has middle pp stage + ], +) +def test_get_batch_on_this_rank_for_sequence_packing(tp, pp, cp): + """ + Test get_batch_on_this_rank_for_sequence_packing function with variable-length THD format. + + This test verifies: + 1. TP ranks: All ranks within a TP group receive identical data after broadcast + 2. PP ranks: Middle PP ranks have the same packed_seq_params as first/last stages + 3. CP ranks: Data is correctly partitioned with proper shape and values + 4. 
Variable-length (THD) format: Different sequence lengths are handled correctly + """ + args = SimpleNamespace() + args.tensor_model_parallel_size = tp + args.pipeline_model_parallel_size = pp + args.context_parallel_size = cp + args.virtual_pipeline_model_parallel_size = None + args.data_parallel_size = 8 // (tp * pp * cp) + args.seq_length = 8192 + + # Skip invalid configurations + if args.data_parallel_size < 1: + raise ValueError(f"Invalid config: tp={tp}, pp={pp}, cp={cp} exceeds world size 8") + + # Initialize model parallel + Utils.initialize_model_parallel(tp, pp, None, context_parallel_size=cp) + + try: + # Create mock data iterator with variable-length sequences + # Only TP rank 0 needs the iterator; other TP ranks pass None + tp_rank = parallel_state.get_tensor_model_parallel_rank() + if tp_rank == 0: + # Use deterministic seed based on DP rank so same data within TP/PP/CP group + dp_rank = parallel_state.get_data_parallel_rank() + sequence_lengths = [1024, 2048, 512, 1536, 3072] + assert ( + sum(sequence_lengths) == args.seq_length + ), f"Sequence lengths sum {sum(sequence_lengths)} != total {args.seq_length}" + data_iterator = iter( + MockVariableLengthSequencePackingDataIterator( + total_seq_length=args.seq_length, + sequence_lengths=sequence_lengths, # Variable lengths, sum=8192 + seed=42 + dp_rank, # Same seed within PP/CP group + ) + ) + else: + # Non-TP-rank-0 ranks don't need the iterator + data_iterator = None + + # Call the function under test + result = get_batch_on_this_rank_for_sequence_packing( + data_iterator=data_iterator, mtp_on_this_rank=False, vp_stage=None + ) + + # Unpack the result + tokens, labels, loss_mask, attention_mask, position_ids, packed_seq_params = result + + # Get parallel state info + tp_rank = parallel_state.get_tensor_model_parallel_rank() + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + cp_rank = parallel_state.get_context_parallel_rank() + is_first_stage = 
parallel_state.is_pipeline_first_stage(ignore_virtual=True) + is_last_stage = parallel_state.is_pipeline_last_stage(ignore_virtual=True) + is_first_or_last = is_first_stage or is_last_stage + + # ===================================================================== + # TEST 1: Verify data based on pipeline stage + # ===================================================================== + if is_first_stage: + assert tokens is not None, "First stage should have tokens" + assert position_ids is not None, "First stage should have position_ids" + assert tokens.dim() == 2, "Tokens should be 2D (batch, seq)" + assert position_ids.dim() == 2, "Position IDs should be 2D (batch, seq)" + assert tokens.size(0) == 1, "batch should be 1 in THD format" + assert position_ids.size(0) == 1, "batch should be 1 in THD format" + else: + assert tokens is None, "Non-first stage should not have tokens" + assert position_ids is None, "Non-first stage should not have position_ids" + + if is_last_stage: + assert labels is not None, "Last stage should have labels" + assert loss_mask is not None, "Last stage should have loss_mask" + assert labels.dim() == 2, "Labels should be 2D (batch, seq)" + assert loss_mask.dim() == 2, "Loss mask should be 2D (batch, seq)" + assert labels.size(0) == 1, "batch should be 1 in THD format" + assert loss_mask.size(0) == 1, "batch should be 1 in THD format" + else: + assert labels is None, "Non-last stage should not have labels" + assert loss_mask is None, "Non-last stage should not have loss_mask" + + # ===================================================================== + # TEST 2: Verify all ranks have consistent packed_seq_params + # ===================================================================== + assert packed_seq_params is not None + assert packed_seq_params.qkv_format == "thd" + + test_keys = [ + "cu_seqlens_q", + "cu_seqlens_q_padded", + "max_seqlen_q", + "cu_seqlens_kv", + "cu_seqlens_kv_padded", + "max_seqlen_kv", + ] + for key in test_keys: + 
tensor = getattr(packed_seq_params, key) + assert tensor is not None + gathered_tensor = _gather_tensor_from_all_ranks(tensor) + for i in range(1, len(gathered_tensor)): + assert torch.equal( + gathered_tensor[0], gathered_tensor[i] + ), f"Rank 0 and rank {i} have different {key}" + + # ===================================================================== + # TEST 3: Verify TP ranks receive identical data after broadcast + # ===================================================================== + if tp > 1: + test_tensors = [] + if is_first_stage: + test_tensors.extend([tokens, position_ids]) + if is_last_stage: + test_tensors.extend([labels, loss_mask]) + + for tensor in test_tensors: + gathered_tensors = _gather_tensor_from_tp_group(tensor) + for i in range(1, tp): + assert torch.equal( + gathered_tensors[0], gathered_tensors[i] + ), f"TP rank 0 and rank {i} have different data" + + # ===================================================================== + # TEST 4: Verify CP partitioning + # ===================================================================== + if cp > 1: + # With CP, the sequence should be partitioned + expected_seq_len = args.seq_length // cp + + if is_first_stage: + actual_seq_len = tokens.shape[1] + assert ( + actual_seq_len == expected_seq_len + ), f"CP partitioned tokens have wrong shape: {actual_seq_len} != {expected_seq_len}" + + # Verify labels only if all CP ranks are at last stage + if is_last_stage: + actual_seq_len = labels.shape[1] + assert ( + actual_seq_len == expected_seq_len + ), f"CP partitioned labels have wrong shape: {actual_seq_len} != {expected_seq_len}" + + finally: + Utils.destroy_model_parallel() + unset_global_variables() + + +@pytest.mark.parametrize( + ("tp", "pp", "cp", "vpp", "scheduler_type"), + [ + (1, 1, 8, None, "dp_balanced"), + (2, 1, 4, None, "dp_balanced"), + (2, 4, 1, None, "dp_balanced"), + (2, 2, 1, None, "dp_balanced"), + (1, 4, 1, 4, "dp_balanced"), + ], +) +def test_wrap_dataloader(tp, pp, cp, vpp, 
scheduler_type): + ''' + Test wrap_dataloader function with different scheduler types. + ''' + args = SimpleNamespace() + args.tensor_model_parallel_size = tp + args.pipeline_model_parallel_size = pp + args.context_parallel_size = cp + args.virtual_pipeline_model_parallel_size = None + args.data_parallel_size = 8 // (tp * pp * cp) + args.seq_length = 8192 + args.max_seqlen_per_dp_cp_rank = 8192 + + # Skip invalid configurations + if args.data_parallel_size < 1: + raise ValueError(f"Invalid config: tp={tp}, pp={pp}, cp={cp} exceeds world size 8") + + def _create_single_sample(seq_len): + # hard code the padding size to 16 + pad_size = 16 + seq_len_padded = ((seq_len + pad_size - 1) // pad_size) * pad_size + device = torch.device("cuda", torch.cuda.current_device()) + tokens = torch.randint(0, 128, (seq_len_padded,), dtype=torch.int64, device=device) + labels = tokens + 1 + position_ids = torch.arange(seq_len_padded, dtype=torch.int64, device=device) + loss_mask = torch.ones(seq_len_padded, dtype=torch.float32, device=device) + loss_mask[0:seq_len] = 1 + loss_mask[seq_len:] = 0 + cu_seqlens = torch.tensor([0, seq_len_padded], dtype=torch.int32, device=device) + + return { + 'tokens': tokens, + 'labels': labels, + 'loss_mask': loss_mask, + 'position_ids': position_ids, + 'cu_seqlens': cu_seqlens, + } + + # Initialize model parallel + Utils.initialize_model_parallel(tp, pp, vpp, context_parallel_size=cp) + + global_batch_size = 64 + micro_batch_size = 1 + nums = [random.randint(2048, args.seq_length) for _ in range(global_batch_size)] # 64 sequences + + config = SimpleNamespace() + config.max_seqlen_per_dp_cp_rank = args.max_seqlen_per_dp_cp_rank + config.microbatch_group_size_per_vp_stage = pp + config.virtual_pipeline_model_parallel_size = vpp + config.sequence_packing_scheduler = scheduler_type + + dp_rank = parallel_state.get_data_parallel_rank() + dp_size = parallel_state.get_data_parallel_world_size() + + pp_rank = 
parallel_state.get_pipeline_model_parallel_rank() + tp_rank = parallel_state.get_tensor_model_parallel_rank() + + is_pp_first = pp_rank == 0 + is_pp_last = pp_rank == pp - 1 + is_pp_first_or_last = is_pp_first or is_pp_last + is_tp_first = tp_rank == 0 + + num_micro_batches_old = global_batch_size // micro_batch_size // dp_size + + if is_tp_first and (is_pp_first or is_pp_last): + samples = [ + _create_single_sample(num) + for num in nums[dp_rank * num_micro_batches_old : (dp_rank + 1) * num_micro_batches_old] + ] + data_iterator = RerunDataIterator(iter(samples)) + else: + data_iterator = None + + if is_tp_first: + if vpp is not None and vpp > 1: + if is_pp_first: + data_iterator = [data_iterator] + [None for _ in range(vpp - 1)] + elif is_pp_last: + data_iterator = [None for _ in range(vpp - 1)] + [data_iterator] + else: + data_iterator = [None for _ in range(vpp)] + try: + # Call the function under test + ( + new_data_iterator, + num_micro_batches, + num_total_tokens_this_global_batch, + sequence_square_sum_this_global_batch, + ) = wrap_data_iterator(data_iterator, config, num_micro_batches_old) + + # check the result + assert type(num_micro_batches) is int + assert ( + type(num_total_tokens_this_global_batch) is float + or type(num_total_tokens_this_global_batch) is np.float32 + ) + assert ( + type(sequence_square_sum_this_global_batch) is float + or type(sequence_square_sum_this_global_batch) is np.float32 + ) + + def _check_batch(batch_all, batch_keys): + for batch in batch_all: + assert set(batch_keys) <= set( + batch.keys() + ), f"batch keys: {set(batch.keys())} missing {set(batch_keys) - set(batch.keys())}" + for key in batch_keys: + assert batch[key] is not None + + if is_tp_first: + # CHECK KEYS + batch_keys = ["cu_seqlens", "max_seqlen", "cu_seqlens_padded"] + if vpp is not None and vpp > 1: + # check metadata for all stages (save batches to avoid re-consuming iterators) + all_stage_batches = [] + for temp_data_iterator in new_data_iterator: + 
stage_batch = [next(temp_data_iterator) for _ in range(num_micro_batches)] + all_stage_batches.append(stage_batch) + _check_batch(stage_batch, batch_keys) + + # check for first or last stage on first or last pp rank + if is_pp_first_or_last: + batch_all = all_stage_batches[0] if is_pp_first else all_stage_batches[-1] + batch_keys += ["tokens", "position_ids", "labels", "loss_mask"] + _check_batch(batch_all, batch_keys) + else: + # non-VPP: single iterator + batch_all = [next(new_data_iterator) for _ in range(num_micro_batches)] + if is_pp_first_or_last: + batch_keys += ["tokens", "position_ids", "labels", "loss_mask"] + _check_batch(batch_all, batch_keys) + + # CHECK TOKEN SUM ON FIRST OR LAST PP RANK + # Note: data_iterator is consumed by wrap_data_iterator, new_data_iterator is consumed above. + # Use `samples` for before-wrap, reuse `batch_all` from the check above for after-wrap. + if is_pp_first_or_last: + # Compute token sum before wrap + token_sum_before = torch.tensor(0, dtype=torch.int64, device='cuda') + for sample in samples: + token_sum_before += sample['tokens'].long().sum() + + # Compute token sum after wrap (batch_all already collected above with tokens) + token_sum_after = torch.tensor(0, dtype=torch.int64, device='cuda') + for batch in batch_all: + token_sum_after += batch['tokens'].long().sum() + + # Reduce sum across dp_cp group and verify equality + dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=False) + torch.distributed.all_reduce( + token_sum_before, op=torch.distributed.ReduceOp.SUM, group=dp_cp_group + ) + torch.distributed.all_reduce( + token_sum_after, op=torch.distributed.ReduceOp.SUM, group=dp_cp_group + ) + + assert ( + token_sum_before == token_sum_after + ), f"Token sum mismatch: before={token_sum_before.item()}, after={token_sum_after.item()}" + + else: + if vpp is not None and vpp > 1: + assert type(new_data_iterator) is list and len(new_data_iterator) == vpp + for data_iterator in new_data_iterator: + 
assert data_iterator is None + else: + assert new_data_iterator is None + + finally: + Utils.destroy_model_parallel() + unset_global_variables()