diff --git a/megatron/core/datasets/data_schedule.py b/megatron/core/datasets/data_schedule.py index 0f016473b6a..deef93b8ebd 100644 --- a/megatron/core/datasets/data_schedule.py +++ b/megatron/core/datasets/data_schedule.py @@ -1,10 +1,22 @@ # Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved. -from typing import Any, List, Optional +import enum +from typing import Any, Dict, List, Optional, Type import torch from megatron.core import parallel_state +from megatron.core.datasets.data_schedule_utils import ( + broadcast_scalars, + broadcast_tensor, + broadcast_to_pp_group, + build_packed_microbatches, + create_data_iterator, + get_batch_and_global_seqlens, + reroute_samples_to_dcp_ranks, +) +from megatron.core.extensions.transformer_engine import get_thd_partitioned_indices +from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.pipeline_parallel.hybrid_cp_schedule import BalancedCPScheduler from megatron.core.process_groups_config import ProcessGroupCollection @@ -57,7 +69,6 @@ def get_global_seqlens(self, subsample_seqlens: torch.Tensor) -> List[int]: Gathers the sequence lengths of all subsamples from all DP ranks. Each DP rank loads the same number of microbatches but each microbatch may have a different number of subsamples. - We find the number of subsamples each rank holds and then gather the sequence lengths of all subsamples from all ranks. """ @@ -299,3 +310,564 @@ def __next__(self) -> Any: batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets ) return samples_this_rank_with_id, sample_id_groups + + +class BasePackingScheduler: + """Base class for sequence packing schedulers.""" + + def __init__( + self, + max_seqlen_per_dp_cp_rank: int, + cp_size: int, + dp_size: int, + microbatch_group_size_per_vp_stage: Optional[int], + ): + """ + Args: + max_seqlen_per_dp_cp_rank: The maximum sequence length per DPxCP rank. + cp_size: The context parallel size. + dp_size: The data parallel size. 
+ microbatch_group_size_per_vp_stage: The microbatch group size per virtual + pipeline stage, only used when enabling VPP, otherwise None. + """ + self.max_seqlen_per_dp_cp_rank = max_seqlen_per_dp_cp_rank + self.cp_size = cp_size + self.dp_size = dp_size + self.microbatch_group_size_per_vp_stage = microbatch_group_size_per_vp_stage + + def get_required_sample_keys(self): + """Return the required key of each batch.""" + raise NotImplementedError + + def get_groups_and_subsamples(self, sample_id_seqlens): + """schedule the samples into groups""" + raise NotImplementedError + + def run( + self, + data_iterator, + num_microbatches, + dp_group, + tp_group, + pp_group, + dp_cp_group, + dev, + config, + ): + """ + Run the scheduler and return the new data_iterator. + + Args: + data_iterator: The data iterator. + num_microbatches: The number of microbatches to fetch. + dp_group: Data parallel process group. + tp_group: Tensor parallel process group. + pp_group: Pipeline parallel process group. + dp_cp_group: Data parallel + context parallel process group. + dev: CUDA device. + config: Model parallel config. + + Returns: + new_data_iterator: The new data iterator (or list for VPP). + num_micro_batches: Number of micro batches after scheduling. + seqlen_sum_this_global_batch: Total tokens for FLOPs calculation. + seqlen_squared_sum_this_global_batch: Sum of squared seqlens for FLOPs. + """ + raise NotImplementedError + + +class DpBalancedScheduler(BasePackingScheduler): + """Packs sequences in their original order until reaching the max limit of sequence length.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.max_seq_len_all_ranks = self.max_seqlen_per_dp_cp_rank * self.cp_size + + def get_required_sample_keys(self): + """Return the required key of each batch.""" + return [ + "tokens", + "labels", + "loss_mask", + "position_ids", + "original_seq_len", # Length of the original sequence length, should be a gpu tensor. 
def get_groups_and_subsamples(self, sample_id_seqlens):
    """
    Pack sequences in their original arrival order until a pack reaches the
    max sequence-length budget across all CP ranks.

    Args:
        sample_id_seqlens: List of (global_id, seqlen) tuples; the list index
            equals the global id.

    Returns:
        sample_id_groups: One entry per microbatch; each entry holds
            cp_size * dp_size inner lists of global sample ids, where each
            packed group is repeated cp_size times so every CP rank of a DP
            group receives the same pack.
    """
    # Greedy first-fit packing in arrival order.
    packed_id_groups = []
    current_pack = []
    current_len = 0
    for idx in range(len(sample_id_seqlens)):
        seqlen = sample_id_seqlens[idx][1]
        if current_len + seqlen <= self.max_seq_len_all_ranks:
            current_pack.append(idx)
            current_len += seqlen
        else:
            packed_id_groups.append(current_pack)
            current_pack = [idx]
            current_len = seqlen
    if current_pack:
        packed_id_groups.append(current_pack)

    # The number of packs must be a multiple of dp_size (and, with VPP, of
    # dp_size * microbatch_group_size_per_vp_stage). Split spare samples off
    # multi-sample packs into new single-sample packs until aligned.
    multiple = self.dp_size * (
        self.microbatch_group_size_per_vp_stage
        if self.microbatch_group_size_per_vp_stage is not None
        else 1
    )
    remainder = len(packed_id_groups) % multiple
    if remainder != 0:
        num_to_move = multiple - remainder
        donor = len(packed_id_groups) - 1
        while num_to_move > 0:
            # Bug fix: the donor scan must be allowed to reach group 0. The
            # previous `> 0` check rejected group 0 as a donor even when it
            # had spare samples (`len > 1` already prevents emptying it).
            assert donor >= 0, "Not enough samples to move"
            if len(packed_id_groups[donor]) > 1:
                packed_id_groups.append([packed_id_groups[donor].pop()])
                num_to_move -= 1
            else:
                donor -= 1

    # Lay packs out per microbatch; each pack appears cp_size times so the
    # whole CP group of a DP rank sees the same pack.
    num_micro_batches = len(packed_id_groups) // self.dp_size
    sample_id_groups = []
    for mb in range(num_micro_batches):
        group = []
        for j in range(self.cp_size * self.dp_size):
            group.append(packed_id_groups[mb * self.dp_size + j // self.cp_size])
        sample_id_groups.append(group)
    return sample_id_groups
dp_cp_group, + dev: torch.device, + config, + ): + """ + Run the complete scheduling pipeline. + + Steps: + 1. Fetch batches and gather global sequence lengths + 2. Check required sample keys + 3. Schedule samples into groups + 4. Reroute samples to DCP ranks + 5. Build packed microbatches + 6. Calculate FLOPs info + 7. Broadcast to PP group (for middle PP stages) + 8. Broadcast to TP group (for non-TP-0 ranks) + 9. Handle VPP if enabled + + Args: + data_iterator: The data iterator. + num_microbatches: The number of microbatches to fetch. + dp_group: Data parallel process group. + tp_group: Tensor parallel process group. + pp_group: Pipeline parallel process group. + dp_cp_group: Data parallel + context parallel process group. + dev: CUDA device. + config: Model parallel config. + + Returns: + new_data_iterator: The new data iterator (or list for VPP). + num_micro_batches: Number of micro batches after scheduling. + seqlen_sum_this_global_batch: Total tokens for FLOPs calculation. + seqlen_squared_sum_this_global_batch: Sum of squared seqlens for FLOPs. + """ + + total_dcp_gpus = dp_cp_group.size() + + # Handle VPP: extract the correct data_iterator for this PP stage + if ( + config.virtual_pipeline_model_parallel_size is not None + and config.virtual_pipeline_model_parallel_size > 1 + ): + # if enable VPP, data_iterator is a list of data_iterators for each VPP stage, + # and only the first and last stage rank will have data_iterator, + # other stages will have None. + assert len(data_iterator) == config.virtual_pipeline_model_parallel_size + if pp_group.rank() == 0: + # the first stage + data_iterator = data_iterator[0] + elif pp_group.rank() == pp_group.size() - 1: + # the last stage + data_iterator = data_iterator[-1] + else: + data_iterator = None + + # data_iterator is not None when TP rank 0, with PP stage 0 or -1. 
+ if data_iterator is not None: + assert tp_group.rank() == 0 and ( + pp_group.rank() == 0 or pp_group.rank() == pp_group.size() - 1 + ), f"Only TP rank 0 and PP stage 0 or -1 should have data_iterator" + + # Step 1: Fetch batches and gather global sequence lengths + batch, global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered = ( + get_batch_and_global_seqlens(data_iterator, num_microbatches, dp_group) + ) + + # Step 2: Check required sample keys + for key in self.get_required_sample_keys(): + assert ( + key in batch[0] + ), f"Batch missing required key {key}, provided keys: {batch[0].keys()}" + + # Step 3: Schedule samples into groups + sample_id_groups = self.get_groups_and_subsamples(global_id_seqlens) + + # Validate scheduling result + set_gbs = set() + for group in sample_id_groups: + for sub in group: + set_gbs.update(sub) + assert len(set_gbs) == len(global_id_seqlens), ( + f"set_gbs length: {len(set_gbs)} != " + f"global_id_seqlens length: {len(global_id_seqlens)}" + ) + + # Step 4: Reroute samples to DCP ranks + samples_this_rank_with_id = reroute_samples_to_dcp_ranks( + batch, + global_ids_this_rank, + global_id_seqlens, + sample_id_groups, + offsets, + dp_group, + tp_group, + dp_cp_group, + total_dcp_gpus, + ) + + dcp_rank = dp_cp_group.rank() + num_micro_batches = len(sample_id_groups) + + grouped_samples = [ + [ + samples_this_rank_with_id[sub_sample_id] + for sub_sample_id in sample_id_groups[i][dcp_rank] + ] + for i in range(num_micro_batches) + ] + + # Step 5: Build packed microbatches + new_samples = build_packed_microbatches(grouped_samples, dev) + + # Step 6: Calculate FLOPs info + seqlen_sum_this_global_batch = float(sum(seqlens_gathered)) + seqlen_squared_sum_this_global_batch = float( + sum(seqlen**2 for seqlen in seqlens_gathered) + ) + else: + ( + new_samples, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ) = (None, None, None, None) + + # Step 7: Broadcast to PP group (for 
def wrap_data_iterator(
    data_iterator, config, num_microbatches, pg_collection: Optional[ProcessGroupCollection] = None
):
    """
    Wrap an existing data_iterator with sequence-packing scheduling.

    Args:
        data_iterator: The original data_iterator to wrap around.
        config: The config object; must provide max_seqlen_per_dp_cp_rank and
            sequence_packing_scheduler (plus the VPP-related fields when VPP
            is enabled).
        num_microbatches: The number of microbatches to fetch from data_iterator.
        pg_collection: Optional process group collection; when None, the groups
            are taken from parallel_state.

    Returns:
        Tuple of (new_data_iterator, num_micro_batches,
        seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch).
    """

    if pg_collection is None:
        dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True)
        dp_group = parallel_state.get_data_parallel_group()
        tp_group = parallel_state.get_tensor_model_parallel_group()
        pp_group = parallel_state.get_pipeline_model_parallel_group()
    else:
        dp_cp_group = pg_collection.dp_cp
        dp_group = pg_collection.dp
        tp_group = pg_collection.tp
        pp_group = pg_collection.pp
    # Bug fix: the message previously omitted pp_group although it is checked.
    assert (
        dp_cp_group is not None
        and dp_group is not None
        and tp_group is not None
        and pp_group is not None
    ), "dp_cp_group, dp_group, tp_group and pp_group must not be None when using sequence packing"

    dev = torch.cuda.current_device()
    dp_size = dp_group.size()
    cp_size = dp_cp_group.size() // dp_size

    # Convert the configured scheduler name into its enum value.
    scheduler_type = PackingSchedulerEnum[config.sequence_packing_scheduler.upper()]

    scheduler = scheduler_map[scheduler_type](
        config.max_seqlen_per_dp_cp_rank,
        cp_size,
        dp_size,
        # When VPP is enabled, align num_micro_batches to this multiple.
        (
            None
            if config.virtual_pipeline_model_parallel_size is None
            else config.microbatch_group_size_per_vp_stage
        ),
    )

    # scheduler.run already returns the 4-tuple documented above.
    return scheduler.run(
        data_iterator, num_microbatches, dp_group, tp_group, pp_group, dp_cp_group, dev, config
    )
+ vp_stage (Optional[int]): The stage of the pipeline. + Returns: + tuple of (tokens, labels, loss_mask, attention_mask, position_ids, packed_seq_params) + """ + + if pg_collection is None: + tp_group = parallel_state.get_tensor_model_parallel_group() + pp_group = parallel_state.get_pipeline_model_parallel_group() + cp_group = parallel_state.get_context_parallel_group() + else: + tp_group = pg_collection.tp + pp_group = pg_collection.pp + cp_group = pg_collection.cp + + tp_src_rank = torch.distributed.get_process_group_ranks(tp_group)[0] + + is_tp_rank_0 = tp_group.rank() == 0 + is_first_stage = pp_group.rank() == 0 and (vp_stage is None or vp_stage == 0) + is_last_stage = pp_group.rank() == pp_group.size() - 1 and ( + vp_stage is None or vp_stage == vpp_size - 1 + ) + + is_first_or_last_stage = is_first_stage or is_last_stage + dev = torch.cuda.current_device() + + # data_iterator should return a batch including the following keys. + batch_keys = ['cu_seqlens', 'cu_seqlens_padded', 'max_seqlen'] + if is_first_stage: + batch_keys.append('tokens') + batch_keys.append('position_ids') + if is_last_stage: + batch_keys.append('labels') + batch_keys.append('loss_mask') + + # Get a batch from data_iterator or create an emtpy batch. + if is_tp_rank_0: + assert data_iterator is not None + batch = next(data_iterator) + for key in batch_keys: + assert key in batch, f"{key} is missing in current batch." + else: + assert data_iterator is None, "Non TP 0 rank should not have data_iterator" + batch = {} + + # Partition tokens, position_ids, labels, loss_mask for context parallel, currently only + # TP rank 0 and the first/last PP stage rank has these data. + if is_tp_rank_0 and is_first_or_last_stage: + cp_size = cp_group.size() + cp_rank = cp_group.rank() + # If cp_size == 1, no need to do further processing. 
+ if cp_size > 1: + total_tokens = batch['tokens'].size(0) + # Transformer Engine has a bug of cu_seqlens, we must treat cu_seqlens_padded as + # cu_seqlens to get the correct result. + # TODO: Revert this workaround once TE fixes the issue. + cu_seqlens = batch["cu_seqlens_padded"] + index = get_thd_partitioned_indices(cu_seqlens, total_tokens, cp_size, cp_rank) + for key in ['tokens', 'position_ids', 'labels', 'loss_mask']: + batch[key] = batch[key].index_select(0, index) + + # Broadcast cu_seqlens_size because we need it to create placeholder for cu_seqlens and + # cu_seqlens_padded for non TP 0 ranks. + if is_tp_rank_0: + cu_seqlen_size = torch.tensor(batch['cu_seqlens'].size(0), dtype=torch.int32, device=dev) + else: + cu_seqlen_size = torch.empty(1, dtype=torch.int32, device=dev) + broadcast_tensor(cu_seqlen_size, tp_src_rank, tp_group) + cu_seqlen_size = cu_seqlen_size.item() + + # Broadcast total_tokens because we need it to create placeholder for tokens, position_ids, + # labels, loss_mask for non TP 0 ranks. Only first or last stage need this. + if is_first_or_last_stage: + if is_tp_rank_0: + total_tokens = torch.tensor(batch['tokens'].size(0), dtype=torch.int32, device=dev) + else: + total_tokens = torch.empty(1, dtype=torch.int32, device=dev) + broadcast_tensor(total_tokens, tp_src_rank, tp_group) + total_tokens = total_tokens.item() + + # Step1: Prepare "tokens", "position_ids" on all ranks. + if is_first_stage or mtp_on_this_rank: + if is_tp_rank_0: + assert batch['tokens'].dtype == torch.int64 + assert batch['position_ids'].dtype == torch.int64 + batch['tokens'] = batch['tokens'].view(1, total_tokens) + batch['position_ids'] = batch['position_ids'].view(1, total_tokens) + else: + batch['tokens'] = torch.empty([1, total_tokens], dtype=torch.int64, device=dev) + batch['position_ids'] = torch.empty([1, total_tokens], dtype=torch.int64, device=dev) + else: + # Non first stage rank doesn't need tokens and position_ids. 
+ batch['tokens'] = None + batch['position_ids'] = None + + # Step2: Prepare "labels", "loss_mask" on all ranks. + if is_last_stage: + if is_tp_rank_0: + assert batch['labels'].dtype == torch.int64 + assert batch['loss_mask'].dtype == torch.float32 + batch['labels'] = batch['labels'].view(1, total_tokens) + batch['loss_mask'] = batch['loss_mask'].view(1, total_tokens) + else: + batch['labels'] = torch.empty([1, total_tokens], dtype=torch.int64, device=dev) + batch['loss_mask'] = torch.empty([1, total_tokens], dtype=torch.float32, device=dev) + else: + # Non last stage rank doesn't need labels and loss_mask. + batch['labels'] = None + batch['loss_mask'] = None + + # Step3: Prepare "cu_seqlens", "cu_seqlens_padded", "max_seqlen" on all ranks. + if is_tp_rank_0: + assert batch['cu_seqlens'].dtype == torch.int32 + assert batch['cu_seqlens_padded'].dtype == torch.int32 + assert batch['cu_seqlens'].dim() == 1 + assert batch['cu_seqlens_padded'].dim() == 1 + if type(batch['max_seqlen']) == int: + batch['max_seqlen'] = torch.tensor(batch['max_seqlen'], dtype=torch.int32, device=dev) + else: + assert batch['max_seqlen'].dtype == torch.int32 + assert batch['max_seqlen'].numel() == 1 + else: + batch['cu_seqlens'] = torch.empty([cu_seqlen_size], dtype=torch.int32, device=dev) + batch['cu_seqlens_padded'] = torch.empty([cu_seqlen_size], dtype=torch.int32, device=dev) + batch['max_seqlen'] = torch.empty(1, dtype=torch.int32, device=dev) + + # Broadcast batch inside TP group. + broadcast_tensor(batch['tokens'], tp_src_rank, tp_group) + broadcast_tensor(batch['position_ids'], tp_src_rank, tp_group) + broadcast_tensor(batch['labels'], tp_src_rank, tp_group) + broadcast_tensor(batch['loss_mask'], tp_src_rank, tp_group) + broadcast_tensor(batch['cu_seqlens'], tp_src_rank, tp_group) + broadcast_tensor(batch['cu_seqlens_padded'], tp_src_rank, tp_group) + broadcast_tensor(batch['max_seqlen'], tp_src_rank, tp_group) + + # Extract the data from batch after broadcasting. 
+ tokens = batch['tokens'] + position_ids = batch['position_ids'] + labels = batch['labels'] + loss_mask = batch['loss_mask'] + cu_seqlens = batch['cu_seqlens'] + cu_seqlens_padded = batch['cu_seqlens_padded'] + max_seqlen = batch['max_seqlen'].item() + + # Transformer Engine has a bug of cu_seqlens, we must treat cu_seqlens_padded as cu_seqlens to + # get the correct result. + # TODO: Revert this workaround once TE fixes the issue. + packed_seq_params = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_seqlens_padded, + cu_seqlens_kv=cu_seqlens_padded, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + max_seqlen_q=max_seqlen, + max_seqlen_kv=max_seqlen, + local_cp_size=None, + cp_group=None, + ) + + # "attention_mask" is not valid for sequence packing, so set it to None. + return tokens, labels, loss_mask, None, position_ids, packed_seq_params diff --git a/megatron/core/datasets/data_schedule_utils.py b/megatron/core/datasets/data_schedule_utils.py new file mode 100644 index 00000000000..2e5635dfcd8 --- /dev/null +++ b/megatron/core/datasets/data_schedule_utils.py @@ -0,0 +1,492 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved. + +from typing import Dict, List + +import numpy as np +import torch + +from megatron.core.rerun_state_machine import RerunDataIterator + + +def _unpack_batch(batch: List[Dict[str, torch.Tensor]]) -> List[Dict[str, torch.Tensor]]: + """ + Unpacks the packed samples into a list of sub-samples. + Since each sub-sample may be routed to different DPxCP ranks, + we unpack the sample here to avoid unnecessarily transferring + the entire packed sample. 
+ """ + batch_unpacked = [] + dev = batch[0]["tokens"].device + original_seq_lens = [] + padded_seq_lens = [] + for sample in batch: + for key in sample.keys(): + if len(sample[key].shape) == 2: + # squeeze the redundant batch dimension added by + # default collate_fn in pytorch dataloader + # we need a custom collate_fn for THD to avoid this + # current THD does not support micro_batch_size > 1 due to sft_dataset.py and + # data_loader in data_samples.py + sample[key] = sample[key].squeeze(0) + for sub_sample in range(sample["cu_seqlens"].shape[0] - 1): + sub_sample_dict = {} + start_idx = sample["cu_seqlens"][sub_sample] + end_idx = sample["cu_seqlens"][sub_sample + 1] + if end_idx - start_idx == 0: + continue + for key in ["tokens", "labels", "loss_mask", "position_ids"]: + sub_sample_dict[key] = sample[key][start_idx:end_idx] + # Since sft_dataset.py does not provide cu_seqlens_original, + # we assume original_seq_len equals padded_seq_len here. + # Ideally the dataset should define the pre-padding seq_len. + seq_len = (end_idx - start_idx).item() + original_seq_lens.append(seq_len) + padded_seq_lens.append(seq_len) + batch_unpacked.append(sub_sample_dict) + + # Single H2D transfer for all seq lens + original_seq_lens_cuda = torch.tensor(original_seq_lens, device=dev) + padded_seq_lens_cuda = torch.tensor(padded_seq_lens, device=dev) + for i, sub_sample_dict in enumerate(batch_unpacked): + sub_sample_dict["original_seq_len"] = original_seq_lens_cuda[i : i + 1] + sub_sample_dict["padded_seq_len"] = padded_seq_lens_cuda[i : i + 1] + + return batch_unpacked + + +def _get_global_seqlens_and_ids(subsample_seqlens: torch.Tensor, dp_group): + """ + Gathers the sequence lengths of all subsamples from all DP ranks and calculates global IDs. 
+ """ + # Collect the number of subsamples from all ranks + num_local_subsamples = subsample_seqlens.shape[0] + local_len = torch.tensor([num_local_subsamples], dtype=torch.int32).cuda() + dp_subsample_count = [torch.zeros_like(local_len) for _ in range(dp_group.size())] + torch.distributed.all_gather(dp_subsample_count, local_len, group=dp_group) + + # Find the max number of subsamples across all ranks and pad subsample_seqlens to max length + dp_subsample_counts = torch.stack(dp_subsample_count, dim=0).cpu().view(-1) + max_sub_samples = int(dp_subsample_counts.max().item()) + + if num_local_subsamples < max_sub_samples: + subsample_seqlens_padded = torch.cat( + [ + subsample_seqlens, + torch.zeros(max_sub_samples - num_local_subsamples, dtype=torch.int32).cuda(), + ], + dim=0, + ) + else: + subsample_seqlens_padded = subsample_seqlens + + # Gather the subsample_seqlens from all ranks + seqlens_gathered = [torch.empty_like(subsample_seqlens_padded) for _ in range(dp_group.size())] + torch.distributed.all_gather(seqlens_gathered, subsample_seqlens_padded, group=dp_group) + + # Trim each seqlens_gathered to the length of the correct sample + for dp_rank, seqlen in enumerate(seqlens_gathered): + seqlens_gathered[dp_rank] = seqlen[: dp_subsample_counts[dp_rank]] + + seqlens_gathered = torch.cat(seqlens_gathered, dim=0) + seqlens_gathered = seqlens_gathered.cpu().tolist() + + # Calculate the offsets to assign unique global ID to each subsample. 
+ csum = torch.cumsum(dp_subsample_counts, dim=0, dtype=torch.int32) + offsets = torch.cat([torch.zeros(1, dtype=torch.int32), csum], dim=0) + + # Calculate global ID for each subsample + dp_rank = dp_group.rank() + global_ids = torch.arange(len(seqlens_gathered), dtype=torch.int32).cuda() + + # Create a list of (global_id, seqlen) tuples for scheduling + global_id_seqlens = [(i, seqlens_gathered[i]) for i in range(len(global_ids))] + + # Get the global IDs locally present on this rank + start_idx = offsets[dp_rank] + end_idx = offsets[dp_rank + 1] + + global_ids_this_rank = global_ids[start_idx:end_idx] + + return global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered + + +def _pack_sequences( + samples: List, padded_lengths: torch.Tensor, original_lengths: torch.Tensor, dev: torch.device +) -> Dict[str, torch.Tensor]: + """Pack multiple samples into a single packed sample.""" + + def _pack_tensors(tensors): + return torch.cat([t.reshape(-1) for t in tensors], dim=0) + + tokens = _pack_tensors([sample["tokens"] for sample in samples]) + labels = _pack_tensors([sample["labels"] for sample in samples]) + loss_mask = _pack_tensors([sample["loss_mask"] for sample in samples]) + position_ids = _pack_tensors([sample["position_ids"] for sample in samples]) + + new_sample = {} + new_sample["tokens"] = tokens + new_sample["labels"] = labels + new_sample["loss_mask"] = loss_mask + new_sample["position_ids"] = position_ids + + padded_lengths = padded_lengths.to(device=dev, dtype=torch.int32, non_blocking=True).reshape(-1) + cu_seqlens_padded = torch.empty(padded_lengths.numel() + 1, device=dev, dtype=torch.int32) + cu_seqlens_padded[0] = 0 + cu_seqlens_padded[1:] = torch.cumsum(padded_lengths, dim=0) + max_seqlen = torch.max(padded_lengths).to(dtype=torch.int32) + + new_sample["cu_seqlens_padded"] = cu_seqlens_padded + new_sample["max_seqlen"] = max_seqlen + + original_lengths = original_lengths.to( + device=dev, dtype=torch.int32, non_blocking=True + 
).reshape(-1) + cu_seqlens = torch.empty(original_lengths.numel() + 1, device=dev, dtype=torch.int32) + cu_seqlens[0] = 0 + cu_seqlens[1:] = torch.cumsum(original_lengths, dim=0).reshape(-1) + new_sample["cu_seqlens"] = cu_seqlens + + return new_sample + + +def broadcast_tensor(item, src_rank, group) -> None: + """Broadcast a tensor from src_rank to all ranks in the group.""" + if item is not None: + torch.distributed.broadcast(item, src_rank, group=group) + + +def broadcast_to_pp_group( + new_samples, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + pp_group, + dev, +): + """ + Broadcast num_micro_batches, seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch and metadata to middle PP stages. + Before this broadcast, the new_samples on middle PP stages are None, + after this broadcast, the new_samples on middle PP stages contain the metadata but + without tokens, labels, loss_mask, position_ids. + """ + + pp_src_rank = torch.distributed.get_process_group_ranks(pp_group)[0] + + if pp_group.size() > 2: + if pp_group.rank() == 0: + tensor_list = [ + torch.tensor( + [ + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ], + dtype=torch.float32, + ).cuda() + ] + for sample in new_samples: + tensor_list.append(sample["max_seqlen"].unsqueeze(0)) + for sample in new_samples: + tensor_list.append(sample["cu_seqlens"]) + tensor_list.append(sample["cu_seqlens_padded"]) + info_to_broadcast = torch.cat(tensor_list, dim=0).to(device=dev, dtype=torch.float32) + info_length_tensor = torch.tensor(info_to_broadcast.shape[0], dtype=torch.int32).cuda() + broadcast_tensor(info_length_tensor, pp_src_rank, pp_group) + broadcast_tensor(info_to_broadcast, pp_src_rank, pp_group) + else: + info_length_tensor = torch.tensor(0, dtype=torch.int32).cuda() + broadcast_tensor(info_length_tensor, pp_src_rank, pp_group) + info_to_broadcast = torch.empty(info_length_tensor.item(), 
def broadcast_scalars(values: List, group, dev, dtype=torch.float32) -> List:
    """
    Broadcast a list of scalars from the first rank of `group` to all ranks.

    Args:
        values: Scalars to send (meaningful only on group rank 0).
        group: The process group to broadcast within.
        dev: Device used for the staging tensor.
        dtype: Data type of the staging tensor.

    Returns:
        The broadcasted list of values.
    """
    if group.size() <= 1:
        # Trivial group: nothing to communicate.
        return values

    src_rank = torch.distributed.get_process_group_ranks(group)[0]

    if group.rank() == 0:
        staging = torch.tensor(values, dtype=dtype, device=dev)
    else:
        staging = torch.zeros(len(values), dtype=dtype, device=dev)

    broadcast_tensor(staging, src_rank, group)

    # Non-source ranks read the received values back to the host.
    if group.rank() != 0:
        values = staging.cpu().tolist()

    return values
+ + For each key in the batch dict, we perform an all-to-all communication + to transfer the data to the correct ranks. + """ + + def _gid_to_src_rank(gid: int) -> int: + dp_src_rank = torch.bucketize(gid, offsets[1:] - 1) + dcp_rank = ( + torch.distributed.get_process_group_ranks(dp_group)[dp_src_rank] // tp_group.size() + ) % dp_cp_group.size() + return dcp_rank + + gid2local_id = {int(gid): i for i, gid in enumerate(global_ids_this_rank)} + dcp_rank = dp_cp_group.rank() + dp_ranks = torch.distributed.get_process_group_ranks(dp_group) + dp_ranks = [(r // tp_group.size()) % dp_cp_group.size() for r in dp_ranks] + + data_keys = batch[0].keys() + + # Create the send plan + combined_sample_id_groups: List[List[int]] = [[] for _ in range(total_dcp_gpus)] + for d in range(total_dcp_gpus): + for sample_id_group in sample_id_groups: + combined_sample_id_groups[d].extend(sample_id_group[d]) + for dest_rank in range(total_dcp_gpus): + combined_sample_id_groups[dest_rank].sort() + + send_ids_sorted = [ + gid for d in dp_ranks for gid in combined_sample_id_groups[d] if gid in global_ids_this_rank + ] + + send_num_split = [0] * total_dcp_gpus + send_lens_split = [0] * total_dcp_gpus + for dest_rank in range(total_dcp_gpus): + if dest_rank in dp_ranks: + send_seq_lens = [ + global_id_seqlens[gid][1] + for gid in combined_sample_id_groups[dest_rank] + if gid in global_ids_this_rank + ] + send_num_split[dest_rank] = len(send_seq_lens) + send_lens_split[dest_rank] = sum(send_seq_lens) + else: + send_lens_split[dest_rank] = 0 + + # Create the recv plan + recv_sample_id_groups = [[] for _ in range(total_dcp_gpus)] + for gid in combined_sample_id_groups[dcp_rank]: + src_rank = _gid_to_src_rank(gid) + recv_sample_id_groups[src_rank].append(gid) + + recv_lens_split = [0] * total_dcp_gpus + for src_rank in range(total_dcp_gpus): + recv_lens_split[src_rank] = sum( + [global_id_seqlens[gid][1] for gid in recv_sample_id_groups[src_rank]] + ) + + recv_ids_sorted = [gid for d in 
range(total_dcp_gpus) for gid in recv_sample_id_groups[d]] + recv_counts = [len(recv_sample_id_groups[d]) for d in range(total_dcp_gpus)] + + recv_samples = [{k: None for k in data_keys} for _ in range(sum(recv_counts))] + + def _pack_sample_by_key(key: str) -> torch.Tensor: + flattened_tensors = [] + for gid in send_ids_sorted: + t = batch[gid2local_id[gid]][key].to(torch.cuda.current_device(), non_blocking=True) + flattened_tensors.append(t.reshape(-1)) + return ( + torch.cat(flattened_tensors, dim=0) + if flattened_tensors + else torch.empty(1, device=torch.cuda.current_device(), dtype=batch[0][key].dtype) + ) + + def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor): + cursor = 0 + for i, gid in enumerate(recv_ids_sorted): + sample_len = ( + 1 if key in ["original_seq_len", "padded_seq_len"] else global_id_seqlens[gid][1] + ) + recv_samples[i][key] = recv_tensor[cursor : cursor + sample_len] + cursor += sample_len + + for key in data_keys: + output_split_sizes, input_split_sizes = ( + (recv_counts, send_num_split) + if key in ["original_seq_len", "padded_seq_len"] + else (recv_lens_split, send_lens_split) + ) + send_tensor = _pack_sample_by_key(key) + recv_tensor_size = sum(output_split_sizes) + recv_tensor = torch.empty( + recv_tensor_size, device=torch.cuda.current_device(), dtype=send_tensor.dtype + ) + torch.distributed.all_to_all_single( + output=recv_tensor, + input=send_tensor, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=dp_cp_group, + ) + _unpack_sample_by_key(key, recv_tensor) + + recv_sample_with_id = {recv_id: recv_samples[i] for i, recv_id in enumerate(recv_ids_sorted)} + return recv_sample_with_id + + +def build_packed_microbatches( + grouped_samples: List[List[Dict[str, torch.Tensor]]], dev: torch.device +) -> List[Dict[str, torch.Tensor]]: + """Build packed samples for each microbatch.""" + num_micro_batches = len(grouped_samples) + seg_starts: List[int] = [0] + original_lens_tensors = [] + 
padded_lens_tensors = []
+
+    for i in range(num_micro_batches):
+        samples = grouped_samples[i]
+        seg_starts.append(seg_starts[-1] + len(samples))
+        original_lens_tensors.extend([s["original_seq_len"].reshape(-1) for s in samples])
+        padded_lens_tensors.extend([s["padded_seq_len"].reshape(-1) for s in samples])
+
+    padded_lens_all_gpu = torch.cat(padded_lens_tensors, dim=0).to(dtype=torch.int32)
+    original_lens_all_gpu = torch.cat(original_lens_tensors, dim=0).to(dtype=torch.int32)
+
+    new_samples: List[Dict[str, torch.Tensor]] = []
+    for i in range(num_micro_batches):
+        samples = grouped_samples[i]
+        lens_padded = padded_lens_all_gpu[seg_starts[i] : seg_starts[i + 1]]
+        lens_original = original_lens_all_gpu[seg_starts[i] : seg_starts[i + 1]]
+        new_sample = _pack_sequences(samples, lens_padded, lens_original, dev)
+        new_samples.append(new_sample)
+
+    return new_samples
+
+
+def get_batch_and_global_seqlens(data_iterator, num_microbatches, dp_group):
+    """
+    Get the batch and global sequence lengths.
+    Each DP rank loads the same number of sequences, so we need to gather the sequence
+    lengths from all ranks then we can schedule the sequences into groups.
+    Args:
+        data_iterator: The data iterator.
+        num_microbatches: The number of microbatches.
+        dp_group: The data parallel group.
+
+    Returns:
+        batch: The batch.
+        global_id_seqlens: The global sequence lengths.
+        global_ids_this_rank: The global IDs locally present on this rank.
+        offsets: Prefix-sum offsets of per-rank subsample counts, used to map a
+            global ID back to the DP rank that owns it.
+        seqlens_gathered: The subsample sequence lengths gathered from all DP ranks.
+    """
+
+    batch_list = [next(data_iterator) for _ in range(num_microbatches)]
+
+    batch = []
+    for item in batch_list:
+        if isinstance(item, dict):
+            batch.append(item)
+        elif isinstance(item, list):
+            batch.extend(item)
+        else:
+            raise ValueError(f"Invalid item type: {type(item)}")
+
+    # in sft_dataset.py, sequences are already packed before rescheduling,
+    # so we need to unpack them here and repack after rescheduling.
+    # This is only to adapt to the current megatron-lm sft_dataset. 
+ # If you implement your own dataset, just have __getitem__ return List[Dict] + # and this step can be skipped. + batch = _unpack_batch(batch) + + subsample_seqlens = torch.cat([sample["padded_seq_len"] for sample in batch]).to( + dtype=torch.int32, device=torch.cuda.current_device() + ) + + global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered = ( + _get_global_seqlens_and_ids(subsample_seqlens, dp_group) + ) + + return batch, global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py index cbe0652402d..bea7f4a4c18 100644 --- a/megatron/core/datasets/gpt_dataset.py +++ b/megatron/core/datasets/gpt_dataset.py @@ -79,6 +79,14 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig): context_parallel_size: Optional[int] = None """The size of the context parallel group. Needed for padding in packed sequences.""" + sft_mock_dataset_config_json: Optional[str] = None + """This config provides the necessary information for the mock dataset.""" + + sequence_packing_scheduler: Optional[str] = None + """Scheduler for sequence packing and hybrid context parallel. + dp_balanced: DP-balanced scheduler for sequence packing. + """ + def __post_init__(self) -> None: """Do asserts and set fields post init""" super().__post_init__() diff --git a/megatron/core/datasets/readme.md b/megatron/core/datasets/readme.md index 452bf24e4a2..64889feb481 100644 --- a/megatron/core/datasets/readme.md +++ b/megatron/core/datasets/readme.md @@ -192,6 +192,28 @@ To query the `BlendedDataset` for the _k_-th sample we do the following To save time during initialization, each index is built/cached sequentially on one process rank and subsequently loaded in parallel on other process ranks. The cached indices are unique to a hash generated in the `BlendedDataset.__init__` function. 
+## Packing Scheduler
+
+The packing scheduler re-schedules variable-length sequences across DP×CP ranks to improve GPU utilization. It is built around the following modules:
+
+### `data_schedule`
+
+This module contains the high-level scheduling logic and entry points:
+
+- **`HybridCPDataLoaderWrapper`**: A wrapper class for hybrid context parallel (CP) scheduling. For every `__next__` call, it: (1) pulls a batch of packed samples from each DP rank, (2) gathers sequence lengths across the DP group, (3) schedules sub-samples using the `BalancedCPScheduler`, (4) reroutes sub-samples to the correct DPxCP ranks via all-to-all communication.
+
+- **`BasePackingScheduler`**: Abstract base class for packing schedulers. Defines the interface for `get_groups_and_subsamples()` (scheduling algorithm) and `run()` (full scheduling pipeline including fetch, schedule, reroute, pack, broadcast, and VPP handling).
+
+- **`DpBalancedScheduler`**: A concrete scheduler that packs sequences in their original order until reaching the max sequence length limit per DPxCP rank. Supports aligning the number of microbatches to DP size and VPP stage multiples.
+
+- **`wrap_data_iterator()`**: Top-level entry point that wraps an existing `data_iterator`. It creates the appropriate scheduler, runs the scheduling pipeline, broadcasts the metadata and the new number of microbatches, and returns a new data iterator along with the updated number of microbatches and FLOPs statistics.
+
+- **`get_batch_on_this_rank_for_sequence_packing()`**: Fetches and broadcasts a single packed microbatch for the current rank. Handles TP/PP broadcasting, constructs `PackedSeqParams` (with `cu_seqlens`, `max_seqlen`, `qkv_format=thd`), and optionally partitions sequences across CP ranks using Transformer Engine's `thd_get_partitioned_indices`.
+
+### `data_schedule_utils.py`
+
+This module contains the utility functions used by the schedulers. 
+ ## Fast DataLoader initialization Especially for large-scale runs, DataLoader initialization can take several minutes, since it involves opening and memory-mapping multiple files and can significantly stress the filesystem. To speed up this process, we have developed the following three optimizations, controlled by configuration flags": diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 996330f5674..537e9ad216a 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -2557,3 +2557,24 @@ def set_save_original_input(module): from transformer_engine.pytorch.float8_tensor import Float8Tensor except ImportError: Float8Tensor = None + + +def get_thd_partitioned_indices(cu_seqlens, total_tokens, cp_size, cp_rank): + """Get partitioned indices for THD format data in context parallel. + + Args: + cu_seqlens: Cumulative sequence lengths tensor. + total_tokens: Total number of tokens. + cp_size: Context parallel world size. + cp_rank: Context parallel rank. + + Returns: + Partitioned indices tensor. + """ + assert is_te_min_version("1.10.0"), ( + "Please update Transformer Engine to >= 1.10 to use " + "Context Parallel with THD format data" + ) + import transformer_engine_torch as tex + + return tex.thd_get_partitioned_indices(cu_seqlens, total_tokens, cp_size, cp_rank) diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index 3c6ff04d3b0..b8d23a4735c 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -59,7 +59,7 @@ class ModelParallelConfig: can handle without overflowing the memory. Typically, a good starting point is to set this to maximum sequence length / context parallel size. This is used to calculate the number and length of sub-samples assigned to - each rank when using hybrid_context_parallel. + each rank when sequence_packing_scheduler is not None. 
""" hybrid_context_parallel: bool = False @@ -69,6 +69,12 @@ class ModelParallelConfig: Please set max_seqlen_per_dp_cp_rank when using hybrid_context_parallel. """ + sequence_packing_scheduler: Optional[Literal['dp_balanced']] = None + """ + Scheduler for sequence packing and hybrid context parallel. + dp_balanced: DP-balanced scheduler for sequence packing. + """ + expert_model_parallel_size: int = 1 """Distributes Moe Experts across sub data parallel dimension.""" diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 15c5adfc7a2..6ed6317d61c 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -552,6 +552,9 @@ def forward_backward_no_pipelining( pg_collection.dp_cp = parallel_state.get_data_parallel_group( with_context_parallel=True, partial_data_parallel=False ) + pg_collection.dp = parallel_state.get_data_parallel_group( + with_context_parallel=False, partial_data_parallel=False + ) elif pg_collection is not None: assert hasattr(pg_collection, 'tp') @@ -879,6 +882,9 @@ def forward_backward_pipelining_with_interleaving( pg_collection.dp_cp = parallel_state.get_data_parallel_group( with_context_parallel=True, partial_data_parallel=False ) + pg_collection.dp = parallel_state.get_data_parallel_group( + with_context_parallel=False, partial_data_parallel=False + ) elif p2p_communicator is not None and pg_collection is not None: model_type = get_model_type(model[0]) @@ -2028,6 +2034,9 @@ def forward_backward_pipelining_without_interleaving( pg_collection.dp_cp = parallel_state.get_data_parallel_group( with_context_parallel=True, partial_data_parallel=False ) + pg_collection.dp = parallel_state.get_data_parallel_group( + with_context_parallel=False, partial_data_parallel=False + ) elif p2p_communicator is not None and pg_collection is not None: model_type = get_model_type(model) assert model_type != ModelType.encoder_and_decoder, ( diff --git 
a/megatron/core/tokenizers/text/libraries/sft_tokenizer.py b/megatron/core/tokenizers/text/libraries/sft_tokenizer.py index 8a418f2dd7f..b9a4f7b0e4b 100644 --- a/megatron/core/tokenizers/text/libraries/sft_tokenizer.py +++ b/megatron/core/tokenizers/text/libraries/sft_tokenizer.py @@ -108,6 +108,7 @@ def __init__(self, tokenizer_path: str, prompt_format: str): self._prompt_format = prompt_format + def tokenize_conversation( self, conversation: List[Dict], return_target: bool, add_generation_prompt: bool ): diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 1aad3c4b89f..20327d58112 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -2019,6 +2019,40 @@ def __post_init__(self): self.attention_backend == AttnBackend.flash ), "Batch invariant mode only supports FlashAttention" + if self.sequence_packing_scheduler is not None: + # Check TE version. + if not HAVE_PACKAGING: + raise ImportError( + "packaging is not installed. Please install it with `pip install packaging`." + ) + # TODO: remove this after we fix the convergence issue with TE < 2.9. + if not ( + is_te_min_version("2.9.0") or get_te_version() == PkgVersion("2.9.0.dev0+5b3092a") + ): + raise ValueError( + "SFT sequence packing requires Transformer Engine >= 2.9.0 " + f"but got {get_te_version()} (TE < 2.9.0 may have convergence issues)." + ) + + # Needed for passing variable sequences between pp stages. 
+ self.variable_seq_lengths = True + + # TODO(tailaim): add support for other dispatcher types + assert self.moe_token_dispatcher_type == "alltoall", ( + f"sequence_packing only supports moe_token_dispatcher_type='alltoall', " + f"got '{self.moe_token_dispatcher_type}'" + ) + + supported_schedulers = ['dp_balanced'] + if ( + self.sequence_packing_scheduler is not None + and self.sequence_packing_scheduler not in supported_schedulers + ): + raise ValueError( + f"Unsupported scheduler: {self.sequence_packing_scheduler}. " + f"Available schedulers: {supported_schedulers}" + ) + @dataclass class MLATransformerConfig(TransformerConfig): diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index eed05652e09..a4f4b982031 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -823,13 +823,6 @@ def validate_args(args, defaults={}): if args.rl_use_sequence_packing: args.consumed_train_bins = 0 - # Support for variable sequence lengths across batches/microbatches. - # set it if the dataloader supports generation of variable sequence lengths - # across batches/microbatches. Due to additional communication overhead - # during pipeline parallelism, it should not be set if sequence length - # is constant during training. - args.variable_seq_lengths = False - # Iteration-based training. if args.train_iters: # If we use iteration-based training, make sure the @@ -1000,6 +993,29 @@ def validate_args(args, defaults={}): assert args.dataloader_type == 'single', 'Hybrid context parallelism only supported with single dataloader type' assert args.calculate_per_token_loss, 'Hybrid context parallelism must be used with --calculate-per-token-loss' + # Support for variable sequence lengths across batches/microbatches. + # set it if the dataloader supports generation of variable sequence lengths + # across batches/microbatches. 
Due to additional communication overhead + # during pipeline parallelism, it should not be set if sequence length + # is constant during training. + args.variable_seq_lengths = False + if args.sequence_packing_scheduler is not None: + args.variable_seq_lengths = True + assert args.context_parallel_size * args.max_seqlen_per_dp_cp_rank >= args.seq_length, \ + f'Packed sequence buffer size ({args.context_parallel_size * args.max_seqlen_per_dp_cp_rank}) ' \ + f'must be >= single sequence max length ({args.seq_length})' + if args.mock_data and args.sft_mock_dataset_config_json is None: + args.sft_mock_dataset_config_json = json.dumps( + { + "mode": "distribution", + "type": "lognormal", + "min_seq_len": args.seq_length // 2, + "max_seq_len": args.seq_length, + "mean_seq_len": args.seq_length // 4 * 3, + "lognormal_sigma": 1.1, + } + ) + # disable async_tensor_model_parallel_allreduce when # model parallel memory optimization is enabled if (args.tensor_model_parallel_size > 1 or args.context_parallel_size > 1) \ @@ -1707,6 +1723,9 @@ def _add_network_size_args(parser): "persist_layer_norm", "bias_dropout_fusion", "apply_rope_fusion", + "max_seqlen_per_dp_cp_rank", + "hybrid_context_parallel", + "sequence_packing_scheduler", ] transformer_factory = ArgumentGroupFactory(TransformerConfig, exclude=exclude) transformer_group = transformer_factory.build_group(parser, "transformer configuration") @@ -2399,6 +2418,14 @@ def _add_distributed_args(parser): 'all layers will share the same communication type. Users can also ' 'specify separated types for each layer like ' '--cp-comm-type p2p p2p a2a a2a a2a+p2p a2a+p2p') + group.add_argument('--max-seqlen-per-dp-cp-rank', type=int, default=None, + help='Maximum sequence length per CP rank. 
This is used to calculate the '
+                       'number of sub-samples assigned to each CP rank when using heterogeneous context parallel.')
+    group.add_argument('--hybrid-context-parallel', action='store_true', default=False,
+                       help='Enables hybrid context parallel. This is used to balance the workload '
+                       'of each CP rank when we use packed samples with variable sequence lengths. '
+                       'Requires --max-seqlen-per-dp-cp-rank to be set.')
+    group.add_argument('--sequence-packing-scheduler', type=str, default=None, choices=['dp_balanced'],
+                       help='Scheduler for sequence packing and hybrid context parallel. '
+                       'dp_balanced: DP-balanced scheduler for sequence packing. '
+                       'Defaults to None (sequence packing scheduling disabled).')
     group.add_argument('--fake-process-group', action='store_true', default=False,
                        help='If set, initialize with fake distributed process group and all distributed communication operations will be skipped. \
                        This is quite useful for profiling memory usage of distributed training with just one GPU. \
@@ -2940,4 +2967,8 @@ def _add_sft_args(parser):
     group.add_argument('--sft', action="store_true", help='Megatron SFT training')
     group.add_argument('--sft-tokenizer-prompt-format', type=str, default="nemotron-h-aligned",
                        help='SFT prompt format.')
+    group.add_argument('--sft-mock-dataset-config-json', type=str, default=None,
+                       help='This config provides the necessary information for the mock dataset. You can either specify a CSV file that contains sequence lengths, where each line stores the length of a sequence, for example: {"mode":"file","path":"/path/to/file"}. Alternatively, you can specify a distribution (currently only supporting lognormal distribution) along with the required parameters, for example, {"mode":"distribution","type":"lognormal","min_seq_len":1024,"max_seq_len":2048,"mean_seq_len":1536,"lognormal_sigma":1.1}, where sigma controls the variability of the lognormal distribution. 
' + 'If not specified and --mock-data is set, defaults to a lognormal distribution with ' + 'min_seq_len=seq_length//2, max_seq_len=seq_length, mean_seq_len=seq_length*3//4, lognormal_sigma=1.1.') return parser diff --git a/megatron/training/datasets/sft_dataset.py b/megatron/training/datasets/sft_dataset.py index fd9d1fe7c14..d7044247e6c 100644 --- a/megatron/training/datasets/sft_dataset.py +++ b/megatron/training/datasets/sft_dataset.py @@ -2,9 +2,12 @@ import atexit, json from collections import Counter -from typing import Any, Dict, Optional +import json +import math +from typing import Any, Dict, Optional, List, Union import numpy as np +import pandas as pd import torch from megatron.core.datasets.gpt_dataset import GPTDatasetConfig @@ -47,7 +50,6 @@ def __len__(self) -> int: def __getitem__(self, idx: int) -> list: return self.dataset[idx]["messages"] - class SFTDataset(MegatronDataset): """The dataset used during SFT""" @@ -61,6 +63,8 @@ def __init__( config: GPTDatasetConfig, ) -> None: super().__init__(dataset, dataset_path, indices, num_samples, index_split, config) + # Pre-calculate padding divisor to avoid redundant computation in get_item + self.padding_divisor = self._calculate_padding_divisor() @staticmethod def numel_low_level_dataset(low_level_dataset: LowLevelDataset) -> int: @@ -88,6 +92,26 @@ def _split_conversations(self, merged_conversations): split_conversations.append(current) return split_conversations + def _calculate_padding_divisor(self) -> int: + """ + Calculate the divisor used for sequence padding. 
+ tp_pad = tp_size * 2 if tp_size > 1 else 1 + cp_pad = cp_size * 2 if cp_size > 1 else 1 + cp_pad = cp_pad * dp_size if hybrid_cp else cp_pad + divisor = cp_pad * tp_pad + """ + if self.config.hybrid_context_parallel: + # Hybrid CP: consider both CP and DP + cp_pad = self.config.data_parallel_size * self.config.context_parallel_size * 2 + else: + # Standard CP: only consider CP + cp_pad = self.config.context_parallel_size * 2 if self.config.context_parallel_size > 1 else 1 + tp_pad = self.config.sequence_parallel_size if self.config.sequence_parallel_size > 0 else 1 + divisor = cp_pad * tp_pad + # TODO(tailaim): do we need to pad for FP8 execution? + # divisor = ((divisor + 15) // 16) * 16 + return divisor + def __getitem__(self, idx: int) -> Dict[str, Any]: tokenizer = self.config.tokenizer @@ -128,12 +152,11 @@ def extend_with_padding(tokens, targets, positions, pad_len): assert not self.config.reset_position_ids pack_positions.extend(range(len(tokens_list))) - if self.config.context_parallel_size > 1: - pad_granularity = self.config.context_parallel_size * 2 - mod_token_count = len(pack_tokens) % pad_granularity - if mod_token_count != 0: - pad_len = pad_granularity - mod_token_count - extend_with_padding(pack_tokens, pack_targets, pack_positions, pad_len) + pad_granularity = self.padding_divisor + mod_token_count = len(pack_tokens) % pad_granularity + if mod_token_count != 0: + pad_len = pad_granularity - mod_token_count + extend_with_padding(pack_tokens, pack_targets, pack_positions, pad_len) # TODO(duncan): Consider also padding to multiple of number of tokens here. This might # be needed for efficiency (and potentially set via command-line argument). @@ -194,3 +217,158 @@ def extend_with_padding(tokens, targets, positions, pad_len): 'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, } + +class MockSFTLowLevelDataset: + """The low-level mock dataset for SFT + + Args: + mode (str): Either 'file' or 'distribution'. 
+ **kwargs: Additional arguments depending on mode. + For mode='file': path (str) - path to a CSV file with sequence lengths. + For mode='distribution': type (str), min_seq_len (int), max_seq_len (int), + mean_seq_len (int), and distribution-specific params (e.g. lognormal_sigma). + """ + + seed: int = 0 + """The hard-coded random seed to use to set the NumPy RNG""" + + size: int = 1000000 + """The hard-coded number of sequence to generate""" + + def __init__(self, mode: str, **kwargs) -> None: + np.random.seed(self.seed) + + if mode == "file": + self.sequence_lengths = np.array(pd.read_csv(kwargs["path"])).flatten() + self.size = len(self.sequence_lengths) + elif mode == "distribution": + min_seq_len = kwargs["min_seq_len"] + max_seq_len = kwargs["max_seq_len"] + mean_seq_len = kwargs["mean_seq_len"] + if kwargs["type"] == "lognormal": + lognormal_sigma = kwargs["lognormal_sigma"] + self.sequence_lengths = self.generate_lognormal_samples( + self.size, mean_seq_len, lognormal_sigma, min_seq_len, max_seq_len + ) + else: + raise ValueError(f"Unsupported distribution type {kwargs['type']}") + else: + raise ValueError(f"Unsupported mode '{mode}', must be 'file' or 'distribution'") + + def generate_lognormal_samples(self, size, mean, sigma, min_seq_len, max_seq_len): + mu = np.log(mean) - sigma**2 / 2 + samples = np.random.lognormal(mu, sigma, size) + samples = np.clip(samples, min_seq_len, max_seq_len) + return samples.astype(int) + + def __len__(self) -> int: + return self.size + + def __getitem__(self, idx: int) -> List[np.ndarray]: + # the length of sample is 'length', but only length-1 elements are generated here, + # because an eod token will be appended at the end later in SFTDataset + + length = self.sequence_lengths[idx % self.size] + sample = np.arange(1, length, dtype=np.int64) + return sample +class MockSFTDataset(SFTDataset): + """The mock dataset used during SFT""" + + def __init__( + self, + dataset: LowLevelDataset, + dataset_path: Optional[str], + 
indices: np.ndarray, + num_samples: Optional[int], + index_split: Split, + config: GPTDatasetConfig, + ) -> None: + super().__init__(dataset, dataset_path, indices, num_samples, index_split, config) + + @staticmethod + def build_low_level_dataset(dataset_path: str, config: GPTDatasetConfig) -> LowLevelDataset: + mock_config = json.loads(config.sft_mock_dataset_config_json) + return MockSFTLowLevelDataset(**mock_config) + + def __len__(self) -> int: + return self.num_samples + + def __getitem__(self, idx: int) -> Dict[str, Any]: + + tokenizer = self.config.tokenizer + pack_length = self.config.sequence_length + eod = tokenizer.eod + pad = tokenizer.pad + + tokens = self.dataset[int(self.indices[idx % len(self.indices)])] + + def extend_with_padding(tokens, targets, positions, pad_len): + tokens.extend([pad] * pad_len) + targets.extend([pad] * pad_len) + positions.extend(range(positions[-1] + 1, positions[-1] + 1 + pad_len)) + + # Convert tokens to list and add EOD + tokens_list = tokens.tolist() + if tokens_list[-1] != eod: + tokens_list.append(eod) + targets_list = list(tokens_list) + + pack_tokens = list(tokens_list) + pack_targets = list(targets_list) + pack_positions = list(range(len(tokens_list))) + cu_seqlens = [0] + + # Pad to padding_divisor alignment + if self.padding_divisor > 1: + mod_token_count = len(pack_tokens) % self.padding_divisor + if mod_token_count != 0: + pad_len = self.padding_divisor - mod_token_count + extend_with_padding(pack_tokens, pack_targets, pack_positions, pad_len) + + # Record padded boundary after padding + cu_seqlens.append(len(pack_tokens)) + + # Handle any necessary truncation + if len(pack_tokens) >= pack_length + 1: # +1 here to account for later alignment + max_body = pack_length - 1 + pack_tokens = pack_tokens[:max_body] + pack_targets = pack_targets[:max_body] + pack_tokens.extend([eod, pad]) + pack_targets.extend([eod, pad]) + pack_positions = pack_positions[:pack_length + 1] + cu_seqlens[-1] = len(pack_tokens) - 1 + + # 
Handle any necessary padding + if len(pack_tokens) < pack_length + 1: # +1 here to account for later alignment + pad_len = pack_length + 1 - len(pack_tokens) + extend_with_padding(pack_tokens, pack_targets, pack_positions, pad_len) + cu_seqlens[-1] = len(pack_tokens) - 1 + + assert len(pack_tokens) == pack_length + 1 + assert len(pack_targets) == pack_length + 1 + assert len(pack_positions) == pack_length + 1 + + # Align and convert to tensors + input_ids = torch.tensor(pack_tokens[:-1], dtype=torch.int64) + labels = torch.tensor(pack_targets[1:], dtype=torch.int64) + position_ids = torch.tensor(pack_positions[:-1], dtype=torch.int64) + + # Loss mask + loss_mask = torch.ones(pack_length, dtype=torch.float32) + loss_mask[labels == pad] = 0.0 + loss_mask[labels == IGNORE_INDEX] = 0.0 + + assert len(cu_seqlens) >= 2 + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32) + # Calculating max_seqlen here because of possible effects of truncation and padding + adjacent_diffs = cu_seqlens[1:] - cu_seqlens[:-1] + max_seqlen = adjacent_diffs.max() # max_seqlen is a 0-D tensor + + return { + 'tokens': input_ids, + 'labels': labels, + 'loss_mask': loss_mask, + 'position_ids': position_ids, + 'cu_seqlens': cu_seqlens, + 'max_seqlen': max_seqlen, + } \ No newline at end of file diff --git a/megatron/training/training.py b/megatron/training/training.py index 2c68c70735d..1326dc39634 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -139,7 +139,6 @@ def set_startup_timestamps(program_start=None, main_entry=None): from megatron.training.initialize import set_jit_fusion_options from megatron.training.utils import get_batch_on_this_cp_rank, get_batch_on_this_tp_rank from megatron.training.datasets.data_samplers import build_pretraining_data_loader -from megatron.core.datasets.data_schedule import HybridCPDataLoaderWrapper from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler from megatron.core.transformer.moe import 
upcycling_utils from megatron.core.transformer.moe.moe_utils import track_moe_metrics, clear_aux_losses_tracker @@ -169,6 +168,7 @@ def set_startup_timestamps(program_start=None, main_entry=None): get_num_microbatches, update_num_microbatches ) +from megatron.core.datasets.data_schedule import wrap_data_iterator from .async_utils import maybe_finalize_async_save from .utils import ( @@ -225,7 +225,7 @@ def print_datetime(string, override_timestamp=None): time_str = datetime.fromtimestamp(override_timestamp).strftime('%Y-%m-%d %H:%M:%S.%f') print_rank_0(f'[{string}] datetime: {time_str} ') -def num_floating_point_operations(args, batch_size): +def num_floating_point_operations(args, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch): def calculate_layer_counts(): """Calculate the number of attention, Mamba, and MLP layers.""" if args.hybrid_override_pattern: @@ -251,44 +251,42 @@ def calculate_layer_counts(): num_moe_layers = 0 return num_attn_layers, num_mamba_layers, num_mlp_layers, num_moe_layers - def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=False): + def mlp_layer_flops(seqlen_sum_this_global_batch, hidden_size, expansion=4.0, swiglu=False): """Calculate FLOPs for an MLP layer.""" scale_factor = 3.0 / 2.0 if swiglu else 1.0 - return 4 * expansion * scale_factor * batch_size * seq_len * hidden_size**2 + return 4 * expansion * scale_factor * seqlen_sum_this_global_batch * hidden_size**2 - def moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size, + def moe_layer_flops(seqlen_sum_this_global_batch, hidden_size, moe_ffn_hidden_size, shared_expert_ffn_hidden_size, num_experts_routed_to, moe_latent_size=None, swiglu=False): """Calculate FLOPs for an MoE layer.""" scale_factor = 3.0 / 2.0 if swiglu else 1.0 if moe_latent_size is None: - routed_flops = (4 * batch_size * seq_len * hidden_size * + routed_flops = (4 * seqlen_sum_this_global_batch * hidden_size * moe_ffn_hidden_size * num_experts_routed_to * 
scale_factor) else: # Routed experts run on moe_latent_size. - routed_flops = (4 * batch_size * seq_len * moe_latent_size * + routed_flops = (4 * seqlen_sum_this_global_batch * moe_latent_size * moe_ffn_hidden_size * num_experts_routed_to * scale_factor) # Up proj and down proj. - routed_flops += (4 * batch_size * seq_len * hidden_size * moe_latent_size) - shared_flops = 4 * batch_size * seq_len * hidden_size * shared_expert_ffn_hidden_size * scale_factor + routed_flops += (4 * seqlen_sum_this_global_batch * hidden_size * moe_latent_size) + shared_flops = 4 * seqlen_sum_this_global_batch * hidden_size * shared_expert_ffn_hidden_size * scale_factor return routed_flops + shared_flops def attn_layer_flops( - batch_size, seq_len, hidden_size, num_heads, gqa=True, gqa_groups=8, kv_channels=None + seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch, hidden_size, num_heads, gqa=True, gqa_groups=8, kv_channels=None ): """Calculate FLOPs for an attention layer.""" p = (kv_channels * num_heads / hidden_size) if kv_channels else 1 g = gqa_groups if gqa else num_heads return ( 4 - * batch_size - * seq_len * hidden_size * p - * (hidden_size + (hidden_size * (g / num_heads)) + (seq_len / 2)) + * (hidden_size * seqlen_sum_this_global_batch + (hidden_size * (g / num_heads)) * seqlen_sum_this_global_batch + (seqlen_squared_sum_this_global_batch / 2)) ) - def mamba_layer_flops(batch_size, seq_len, hidden_size, state_dim=16, + def mamba_layer_flops(seqlen_sum_this_global_batch, hidden_size, state_dim=16, head_dim=64, num_groups=1, num_heads=128): """Calculate FLOPs for a Mamba layer.""" # Note (rwaleffe): flops estimate for scan should be updated based on new SSD kernels, @@ -301,16 +299,15 @@ def mamba_layer_flops(batch_size, seq_len, hidden_size, state_dim=16, return ( ( 2 - * batch_size - * seq_len + * seqlen_sum_this_global_batch * hidden_size * (2 * d_in + 2 * num_groups * state_dim + nheads) ) # in_proj - + (7 * batch_size * seq_len * d_in * state_dim) # scan - + 
(2 * batch_size * seq_len * d_in * hidden_size) # out_proj + + (7 * seqlen_sum_this_global_batch * d_in * state_dim) # scan + + (2 * seqlen_sum_this_global_batch * d_in * hidden_size) # out_proj ) - def hybrid_flops(batch_size, seq_len, hidden_size, + def hybrid_flops(seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch, hidden_size, num_attn_layers, num_mamba_layers, num_mlp_layers, num_moe_layers, mamba_state_dim=128, mamba_head_dim=64, mamba_num_groups=8, mamba_num_heads=128, @@ -322,17 +319,17 @@ def hybrid_flops(batch_size, seq_len, hidden_size, vocab_size=256000, mtp_num_layers=0): """Calculate total FLOPs for the hybrid model.""" flops_fwd = ( - num_attn_layers * attn_layer_flops(batch_size, seq_len, hidden_size, + num_attn_layers * attn_layer_flops(seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch, hidden_size, num_attn_heads, gqa, gqa_groups, kv_channels) + - num_mlp_layers * mlp_layer_flops(batch_size, seq_len, hidden_size, + num_mlp_layers * mlp_layer_flops(seqlen_sum_this_global_batch, hidden_size, mlp_expansion, swiglu) + - num_mamba_layers * mamba_layer_flops(batch_size, seq_len, hidden_size, + num_mamba_layers * mamba_layer_flops(seqlen_sum_this_global_batch, hidden_size, mamba_state_dim, mamba_head_dim, mamba_num_groups, mamba_num_heads) + - num_moe_layers * moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size, + num_moe_layers * moe_layer_flops(seqlen_sum_this_global_batch, hidden_size, moe_ffn_hidden_size, shared_expert_ffn_hidden_size, num_experts_routed_to, moe_latent_size, swiglu) + - (2 * batch_size * seq_len * hidden_size * vocab_size * (1 + mtp_num_layers)) # logits computation + (2 * seqlen_sum_this_global_batch * hidden_size * vocab_size * (1 + mtp_num_layers)) # logits computation ) return flops_fwd * 3 @@ -403,13 +400,18 @@ def transformer_flops(): assert not args.group_query_attention ''' Basic arithmetic - let B is batch size, s is seq_len, h is embedding dim, - for one self_attnetion 
block (prenorm is not included) - qkv projection: 6Bsh^2 - attn: 2Bs^2h - attn over value: 2Bs^2h - oproj: 2Bsh^2 - + + Let h be the embedding dim. + We use two statistics to unify BSHD and THD cases: + seqlen_sum_this_global_batch: total number of tokens in this global batch + seqlen_squared_sum_this_global_batch: sum of squared sequence lengths in this global batch + + For one self-attention block (prenorm not included): + qkv projection: 6 * seqlen_sum_this_global_batch * h^2 + attn: 2 * seqlen_squared_sum_this_global_batch * h + attn over value: 2 * seqlen_squared_sum_this_global_batch * h + oproj: 2 * seqlen_sum_this_global_batch * h^2 + references https://arxiv.org/abs/2305.10403 https://arxiv.org/abs/2205.05198 @@ -430,7 +432,7 @@ def transformer_flops(): standard_self_attn_term = ( forward_backward_expansion_factor * fma_expansion_factor - * ( + * ( seqlen_sum_this_global_batch * ( ## q lora + rope + q norm q_term ## kv lora + rope + kv norm @@ -442,12 +444,12 @@ def transformer_flops(): ) + args.hidden_size * args.qk_pos_emb_head_dim ## o proj - + (args.num_attention_heads * args.v_head_dim) * args.hidden_size + + (args.num_attention_heads * args.v_head_dim) * args.hidden_size) ## core attn - + args.seq_length + + seqlen_squared_sum_this_global_batch * (args.num_attention_heads * (args.qk_head_dim + args.qk_pos_emb_head_dim)) - / 2 # causal mask (only half of the mask is non-zero) - + args.seq_length * args.num_attention_heads * args.v_head_dim / 2 + / 2 # causal mask (only half of the mask is non-zero) + + seqlen_squared_sum_this_global_batch * args.num_attention_heads * args.v_head_dim / 2 ) ) @@ -460,7 +462,7 @@ def transformer_flops(): standard_self_attn_term = ( forward_backward_expansion_factor * fma_expansion_factor - * ( + * ( seqlen_sum_this_global_batch *( ## qkv proj args.hidden_size * ( @@ -468,14 +470,14 @@ def transformer_flops(): + key_projection_size + value_projection_size + gate_projection_size - ) + )) ## core attention + 
query_projection_size - * args.seq_length + * seqlen_squared_sum_this_global_batch / 2 # causal mask (only half of the mask is non-zero) * 2 # QK^T and (QK^T)V ## out proj - + query_projection_size + + seqlen_sum_this_global_batch * query_projection_size * args.hidden_size ) ) @@ -553,8 +555,7 @@ def transformer_flops(): ) total_floating_point_operations = ( - batch_size - * args.seq_length + seqlen_sum_this_global_batch * ( # MLP forward_backward_expansion_factor @@ -584,8 +585,6 @@ def transformer_flops(): + (shared_expert_ffn_hidden_size * ffn_expansion_factor) * num_moe_layers ) - # Self Attention - + self_attn_term # MTP norms and proj + forward_backward_expansion_factor * fma_expansion_factor @@ -603,6 +602,10 @@ def transformer_flops(): * args.padded_vocab_size * (mtp_num_layers + 1) # MTP + final logit ) + + + # Self Attention + self_attn_term + ) return total_floating_point_operations @@ -616,8 +619,8 @@ def transformer_flops(): mtp_num_layers = 0 # Compute hybrid model FLOPs. return hybrid_flops( - batch_size=batch_size, - seq_len=args.seq_length, + seqlen_sum_this_global_batch=seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch=seqlen_squared_sum_this_global_batch, hidden_size=args.hidden_size, num_attn_layers=num_attn_layers, num_mamba_layers=num_mamba_layers, @@ -1712,6 +1715,19 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch if isinstance(optim_instance, DistributedOptimizer): optim_instance._copy_main_params_to_param_buffer() + if config.sequence_packing: + ( + data_iterator, + num_microbatches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ) = wrap_data_iterator(data_iterator, config, get_num_microbatches()) + else: + # data_iterator unchanged + num_microbatches = get_num_microbatches() + seqlen_sum_this_global_batch = args.seq_length * args.micro_batch_size * args.data_parallel_size * num_microbatches + seqlen_squared_sum_this_global_batch = args.seq_length ** 2 * 
args.micro_batch_size * args.data_parallel_size * num_microbatches + # Forward pass. if save_dgrads_in_this_iteration: enable_dgrad_logging(model, args.save) @@ -1719,7 +1735,7 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch forward_step_func=forward_step_func, data_iterator=data_iterator, model=model, - num_microbatches=get_num_microbatches(), + num_microbatches=num_microbatches, seq_length=args.seq_length, micro_batch_size=args.micro_batch_size, decoder_seq_length=args.decoder_seq_length, @@ -1750,9 +1766,10 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch # iteration is 0-indexed, move to 1-indexed for checkpoint name and logging. save_grads(args.save, state_dict, iteration + 1, "wgrads") + should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit() if should_exit: - return {}, True, should_checkpoint, should_exit, exit_code, None, None, 0 + return {}, True, should_checkpoint, should_exit, exit_code, None, None, 0, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch # Empty unused memory. 
if args.empty_unused_memory_level >= 1: @@ -1832,8 +1849,10 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch grad_norm, num_zeros_in_grad, log_max_attention_logit, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, ) - return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad, log_max_attention_logit + return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad, log_max_attention_logit, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch def training_log( @@ -1848,6 +1867,8 @@ def training_log( params_norm, num_zeros_in_grad, max_attention_logit, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, pg_collection=None, is_first_iteration=False, ): @@ -2082,7 +2103,7 @@ def training_log( elapsed_time = timers('interval-time').elapsed(barrier=True, reset=should_reset) elapsed_time_per_iteration = elapsed_time / total_iterations - throughput = num_floating_point_operations(args, batch_size) / ( + throughput = num_floating_point_operations(args,seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch) / ( elapsed_time_per_iteration * 10**12 * args.world_size ) @@ -2833,6 +2854,8 @@ def trace_handler(p): # Completely skip iteration if needed. if iteration in args.iterations_to_skip: + # TODO(tailaim): this need to be modified + assert config.sequence_packing_scheduler is None, "Sequence packing scheduler is not supported in skip iteration mode" # Dummy train_step to fast forward train_data_iterator. 
dummy_train_step(train_data_iterator) if iteration == start_iteration: @@ -2875,6 +2898,8 @@ def trace_handler(p): grad_norm, num_zeros_in_grad, max_attention_logit, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, ) = train_step( forward_step_func, train_data_iterator, model, optimizer, opt_param_scheduler, config, forward_backward_func, iteration=iteration ) @@ -2962,7 +2987,7 @@ def trace_handler(p): else: assert num_skipped_samples_in_batch == 0 args.skipped_train_samples += num_skipped_samples_in_batch - num_floating_point_operations_in_batch = num_floating_point_operations(args, batch_size) + num_floating_point_operations_in_batch = num_floating_point_operations(args, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch) num_floating_point_operations_so_far += num_floating_point_operations_in_batch num_floating_point_operations_since_last_log_event += num_floating_point_operations_in_batch @@ -2988,6 +3013,8 @@ def trace_handler(p): params_norm, num_zeros_in_grad, max_attention_logit, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, pg_collection=model_pg_collection, is_first_iteration=is_first_iteration, ) diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 81768944623..9beadc1ec37 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -25,6 +25,7 @@ from megatron.core import parallel_state from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset +from megatron.core.datasets.data_schedule import get_batch_on_this_rank_for_sequence_packing from megatron.core.enums import ModelType from megatron.core.models.gpt import GPTModel from megatron.core.rerun_state_machine import get_rerun_state_machine @@ -48,6 +49,7 @@ get_blend_and_blend_per_split, is_first_or_last_pipeline_stage, ) +from megatron.training.datasets.sft_dataset import SFTDataset, MockSFTDataset from 
model_provider import model_provider try: @@ -65,6 +67,15 @@ def get_batch(data_iterator, vp_stage: Optional[int] = None): """Generate a batch.""" args = get_args() config = core_transformer_config_from_args(args) + + if args.sequence_packing_scheduler is not None: + return get_batch_on_this_rank_for_sequence_packing( + data_iterator, + vpp_size=config.virtual_pipeline_model_parallel_size, + mtp_on_this_rank=mtp_on_this_rank(config, ignore_virtual=False, vp_stage=vp_stage), + vp_stage=vp_stage, + ) + # TODO: this is pretty hacky, find a better way if not is_first_or_last_pipeline_stage(vp_stage) and ( (not mtp_on_this_rank(config, ignore_virtual=False, vp_stage=vp_stage))): @@ -249,6 +260,8 @@ def core_gpt_dataset_config_from_args(args): "data_parallel_size": args.data_parallel_size, "sequence_parallel_size": args.tensor_model_parallel_size*args.sequence_parallel, "hybrid_context_parallel": args.hybrid_context_parallel, + "sft_mock_dataset_config_json":args.sft_mock_dataset_config_json, + "sequence_packing_scheduler": args.sequence_packing_scheduler, } # add FIM args to the config @@ -286,7 +299,10 @@ def train_valid_test_datasets_provider(train_val_test_num_samples, vp_stage=None config = core_gpt_dataset_config_from_args(args) if args.sft: - dataset_type = SFTDataset + if args.mock_data: + dataset_type = MockSFTDataset + else: + dataset_type = SFTDataset else: if args.mock_data: dataset_type = MockGPTDataset diff --git a/tests/unit_tests/models/test_mamba_moe_model.py b/tests/unit_tests/models/test_mamba_moe_model.py index f933d811779..912d97528e7 100644 --- a/tests/unit_tests/models/test_mamba_moe_model.py +++ b/tests/unit_tests/models/test_mamba_moe_model.py @@ -277,6 +277,7 @@ "offload_modules": [], "hybrid_context_parallel": False, "max_seqlen_per_dp_cp_rank": None, + "sequence_packing_scheduler": None, } # Fields to ignore entirely (ephemeral, environment-specific, very large). 
SKIP_FIELDS = set() diff --git a/tests/unit_tests/test_sequence_packing.py b/tests/unit_tests/test_sequence_packing.py new file mode 100644 index 00000000000..60316b0236e --- /dev/null +++ b/tests/unit_tests/test_sequence_packing.py @@ -0,0 +1,479 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +import random +from types import SimpleNamespace + +import numpy as np +import pytest +import torch + +from megatron.core import parallel_state +from megatron.core.datasets.data_schedule import ( + get_batch_on_this_rank_for_sequence_packing, + wrap_data_iterator, +) +from megatron.core.rerun_state_machine import RerunDataIterator +from megatron.training.global_vars import unset_global_variables +from tests.unit_tests.test_utilities import Utils + + +class MockVariableLengthSequencePackingDataIterator: + """ + Mock data iterator for testing get_batch_on_this_rank_for_sequence_packing. + + Generates variable-length (THD format) packed sequences with deterministic + data for verification across parallel ranks. + """ + + def __init__( + self, + total_seq_length: int, + sequence_lengths: list, + local_cp_size: int = None, + device: str = "cuda", + seed: int = 42, + ): + """ + Args: + total_seq_length: Total length of packed sequences + sequence_lengths: List of individual sequence lengths (variable-length). + If None, generates random variable lengths. 
+ device: Device to create tensors on + seed: Random seed for reproducibility + """ + self.total_seq_length = total_seq_length + self.sequence_lengths = sequence_lengths + self.local_cp_size = local_cp_size + self.device = device + self.seed = seed + assert ( + sum(self.sequence_lengths) == total_seq_length + ), f"Sequence lengths sum {sum(self.sequence_lengths)} != total {total_seq_length}" + + def __iter__(self): + """Interface for the data iterator.""" + return self + + def __next__(self): + """Generate a mock batch with variable-length THD format.""" + dev = self.device + torch.manual_seed(self.seed) + torch.cuda.manual_seed(self.seed) + + tokens = torch.randint(0, 16384, (self.total_seq_length,), dtype=torch.int64, device=dev) + + # Create position_ids that reset for each sequence (THD format) + position_ids = [] + for seq_len in self.sequence_lengths: + position_ids.extend(range(seq_len)) + position_ids = torch.tensor(position_ids, dtype=torch.int64, device=dev) + + # Labels are tokens shifted by 1 for easy verification + labels = tokens + 1 + + # Loss mask: 1.0 for all positions except padding (none here) + loss_mask = torch.ones(self.total_seq_length, dtype=torch.float32, device=dev) + + # Create cu_seqlens for variable-length packed sequences + cu_seqlens = [0] + for seq_len in self.sequence_lengths: + cu_seqlens.append(cu_seqlens[-1] + seq_len) + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=dev) + cu_seqlens_padded = cu_seqlens.clone() + + max_seqlen = torch.tensor([max(self.sequence_lengths)], dtype=torch.int32, device=dev) + + batch = { + "tokens": tokens, + "position_ids": position_ids, + "labels": labels, + "loss_mask": loss_mask, + "cu_seqlens": cu_seqlens, + "cu_seqlens_padded": cu_seqlens_padded, + "max_seqlen": max_seqlen, + } + + if not ( + parallel_state.is_pipeline_first_stage(ignore_virtual=True) + or parallel_state.is_pipeline_last_stage(ignore_virtual=True) + ): + batch["tokens"] = None + batch["position_ids"] = None + 
batch["labels"] = None + batch["loss_mask"] = None + + if self.local_cp_size is not None: + batch["local_cp_size"] = torch.tensor( + [self.local_cp_size], dtype=torch.int32, device=dev + ) + + return batch + + +def _gather_tensor_from_tp_group(tensor): + """Gather tensors from all TP ranks for comparison.""" + assert tensor is not None, "Tensor should not be None" + tp_size = parallel_state.get_tensor_model_parallel_world_size() + gathered = [torch.zeros_like(tensor) for _ in range(tp_size)] + torch.distributed.all_gather( + gathered, tensor, group=parallel_state.get_tensor_model_parallel_group() + ) + return gathered + + +def _gather_tensor_from_all_ranks(tensor): + """Gather tensors from all PP ranks for comparison.""" + assert tensor is not None, "Tensor should not be None" + if type(tensor) is int: + tensor = torch.tensor(tensor, dtype=torch.int32, device=torch.cuda.current_device()) + gathered = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(gathered, tensor) + return gathered + + +@pytest.mark.parametrize( + ("tp", "pp", "cp"), + [ + (1, 1, 1), # Basic case: no parallelism + (2, 1, 1), # Tensor parallel only + (1, 2, 1), # Pipeline parallel only + (2, 2, 1), # TP + PP + (1, 1, 2), # CP only + (2, 1, 2), # TP + CP + (1, 2, 2), # PP + CP + (1, 4, 1), # Has middle pp stage + ], +) +def test_get_batch_on_this_rank_for_sequence_packing(tp, pp, cp): + """ + Test get_batch_on_this_rank_for_sequence_packing function with variable-length THD format. + + This test verifies: + 1. TP ranks: All ranks within a TP group receive identical data after broadcast + 2. PP ranks: Middle PP ranks have the same packed_seq_params as first/last stages + 3. CP ranks: Data is correctly partitioned with proper shape and values + 4. 
Variable-length (THD) format: Different sequence lengths are handled correctly + """ + args = SimpleNamespace() + args.tensor_model_parallel_size = tp + args.pipeline_model_parallel_size = pp + args.context_parallel_size = cp + args.virtual_pipeline_model_parallel_size = None + args.data_parallel_size = 8 // (tp * pp * cp) + args.seq_length = 8192 + + # Skip invalid configurations + if args.data_parallel_size < 1: + raise ValueError(f"Invalid config: tp={tp}, pp={pp}, cp={cp} exceeds world size 8") + + # Initialize model parallel + Utils.initialize_model_parallel(tp, pp, None, context_parallel_size=cp) + + try: + # Create mock data iterator with variable-length sequences + # Only TP rank 0 needs the iterator; other TP ranks pass None + tp_rank = parallel_state.get_tensor_model_parallel_rank() + if tp_rank == 0: + # Use deterministic seed based on DP rank so same data within TP/PP/CP group + dp_rank = parallel_state.get_data_parallel_rank() + sequence_lengths = [1024, 2048, 512, 1536, 3072] + assert ( + sum(sequence_lengths) == args.seq_length + ), f"Sequence lengths sum {sum(sequence_lengths)} != total {args.seq_length}" + data_iterator = iter( + MockVariableLengthSequencePackingDataIterator( + total_seq_length=args.seq_length, + sequence_lengths=sequence_lengths, # Variable lengths, sum=8192 + seed=42 + dp_rank, # Same seed within PP/CP group + ) + ) + else: + # Non-TP-rank-0 ranks don't need the iterator + data_iterator = None + + # Call the function under test + result = get_batch_on_this_rank_for_sequence_packing( + data_iterator=data_iterator, mtp_on_this_rank=False, vp_stage=None + ) + + # Unpack the result + tokens, labels, loss_mask, attention_mask, position_ids, packed_seq_params = result + + # Get parallel state info + tp_rank = parallel_state.get_tensor_model_parallel_rank() + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + cp_rank = parallel_state.get_context_parallel_rank() + is_first_stage = 
parallel_state.is_pipeline_first_stage(ignore_virtual=True) + is_last_stage = parallel_state.is_pipeline_last_stage(ignore_virtual=True) + is_first_or_last = is_first_stage or is_last_stage + + # ===================================================================== + # TEST 1: Verify data based on pipeline stage + # ===================================================================== + if is_first_stage: + assert tokens is not None, "First stage should have tokens" + assert position_ids is not None, "First stage should have position_ids" + assert tokens.dim() == 2, "Tokens should be 2D (batch, seq)" + assert position_ids.dim() == 2, "Position IDs should be 2D (batch, seq)" + assert tokens.size(0) == 1, "batch should be 1 in THD format" + assert position_ids.size(0) == 1, "batch should be 1 in THD format" + else: + assert tokens is None, "Non-first stage should not have tokens" + assert position_ids is None, "Non-first stage should not have position_ids" + + if is_last_stage: + assert labels is not None, "Last stage should have labels" + assert loss_mask is not None, "Last stage should have loss_mask" + assert labels.dim() == 2, "Labels should be 2D (batch, seq)" + assert loss_mask.dim() == 2, "Loss mask should be 2D (batch, seq)" + assert labels.size(0) == 1, "batch should be 1 in THD format" + assert loss_mask.size(0) == 1, "batch should be 1 in THD format" + else: + assert labels is None, "Non-last stage should not have labels" + assert loss_mask is None, "Non-last stage should not have loss_mask" + + # ===================================================================== + # TEST 2: Verify all ranks have consistent packed_seq_params + # ===================================================================== + assert packed_seq_params is not None + assert packed_seq_params.qkv_format == "thd" + + test_keys = [ + "cu_seqlens_q", + "cu_seqlens_q_padded", + "max_seqlen_q", + "cu_seqlens_kv", + "cu_seqlens_kv_padded", + "max_seqlen_kv", + ] + for key in test_keys: + 
tensor = getattr(packed_seq_params, key) + assert tensor is not None + gathered_tensor = _gather_tensor_from_all_ranks(tensor) + for i in range(1, len(gathered_tensor)): + assert torch.equal( + gathered_tensor[0], gathered_tensor[i] + ), f"Rank 0 and rank {i} have different {key}" + + # ===================================================================== + # TEST 3: Verify TP ranks receive identical data after broadcast + # ===================================================================== + if tp > 1: + test_tensors = [] + if is_first_stage: + test_tensors.extend([tokens, position_ids]) + if is_last_stage: + test_tensors.extend([labels, loss_mask]) + + for tensor in test_tensors: + gathered_tensors = _gather_tensor_from_tp_group(tensor) + for i in range(1, tp): + assert torch.equal( + gathered_tensors[0], gathered_tensors[i] + ), f"TP rank 0 and rank {i} have different data" + + # ===================================================================== + # TEST 4: Verify CP partitioning + # ===================================================================== + if cp > 1: + # With CP, the sequence should be partitioned + expected_seq_len = args.seq_length // cp + + if is_first_stage: + actual_seq_len = tokens.shape[1] + assert ( + actual_seq_len == expected_seq_len + ), f"CP partitioned tokens have wrong shape: {actual_seq_len} != {expected_seq_len}" + + # Verify labels only if all CP ranks are at last stage + if is_last_stage: + actual_seq_len = labels.shape[1] + assert ( + actual_seq_len == expected_seq_len + ), f"CP partitioned labels have wrong shape: {actual_seq_len} != {expected_seq_len}" + + finally: + Utils.destroy_model_parallel() + unset_global_variables() + + +@pytest.mark.parametrize( + ("tp", "pp", "cp", "vpp", "scheduler_type"), + [ + (1, 1, 8, None, "dp_balanced"), + (2, 1, 4, None, "dp_balanced"), + (2, 4, 1, None, "dp_balanced"), + (2, 2, 1, None, "dp_balanced"), + (1, 4, 1, 4, "dp_balanced"), + ], +) +def test_wrap_dataloader(tp, pp, cp, vpp, 
scheduler_type): + ''' + Test wrap_dataloader function with different scheduler types. + ''' + args = SimpleNamespace() + args.tensor_model_parallel_size = tp + args.pipeline_model_parallel_size = pp + args.context_parallel_size = cp + args.virtual_pipeline_model_parallel_size = None + args.data_parallel_size = 8 // (tp * pp * cp) + args.seq_length = 8192 + args.max_seqlen_per_dp_cp_rank = 8192 + + # Skip invalid configurations + if args.data_parallel_size < 1: + raise ValueError(f"Invalid config: tp={tp}, pp={pp}, cp={cp} exceeds world size 8") + + def _create_single_sample(seq_len): + # hard code the padding size to 16 + pad_size = 16 + seq_len_padded = ((seq_len + pad_size - 1) // pad_size) * pad_size + device = torch.device("cuda", torch.cuda.current_device()) + tokens = torch.randint(0, 128, (seq_len_padded,), dtype=torch.int64, device=device) + labels = tokens + 1 + position_ids = torch.arange(seq_len_padded, dtype=torch.int64, device=device) + loss_mask = torch.ones(seq_len_padded, dtype=torch.float32, device=device) + loss_mask[0:seq_len] = 1 + loss_mask[seq_len:] = 0 + cu_seqlens = torch.tensor([0, seq_len_padded], dtype=torch.int32, device=device) + + return { + 'tokens': tokens, + 'labels': labels, + 'loss_mask': loss_mask, + 'position_ids': position_ids, + 'cu_seqlens': cu_seqlens, + } + + # Initialize model parallel + Utils.initialize_model_parallel(tp, pp, vpp, context_parallel_size=cp) + + global_batch_size = 64 + micro_batch_size = 1 + nums = [random.randint(2048, args.seq_length) for _ in range(global_batch_size)] # 64 sequences + + config = SimpleNamespace() + config.max_seqlen_per_dp_cp_rank = args.max_seqlen_per_dp_cp_rank + config.microbatch_group_size_per_vp_stage = pp + config.virtual_pipeline_model_parallel_size = vpp + config.sequence_packing_scheduler = scheduler_type + + dp_rank = parallel_state.get_data_parallel_rank() + dp_size = parallel_state.get_data_parallel_world_size() + + pp_rank = 
parallel_state.get_pipeline_model_parallel_rank() + tp_rank = parallel_state.get_tensor_model_parallel_rank() + + is_pp_first = pp_rank == 0 + is_pp_last = pp_rank == pp - 1 + is_pp_first_or_last = is_pp_first or is_pp_last + is_tp_first = tp_rank == 0 + + num_micro_batches_old = global_batch_size // micro_batch_size // dp_size + + if is_tp_first and (is_pp_first or is_pp_last): + samples = [ + _create_single_sample(num) + for num in nums[dp_rank * num_micro_batches_old : (dp_rank + 1) * num_micro_batches_old] + ] + data_iterator = RerunDataIterator(iter(samples)) + else: + data_iterator = None + + if is_tp_first: + if vpp is not None and vpp > 1: + if is_pp_first: + data_iterator = [data_iterator] + [None for _ in range(vpp - 1)] + elif is_pp_last: + data_iterator = [None for _ in range(vpp - 1)] + [data_iterator] + else: + data_iterator = [None for _ in range(vpp)] + try: + # Call the function under test + ( + new_data_iterator, + num_micro_batches, + num_total_tokens_this_global_batch, + sequence_square_sum_this_global_batch, + ) = wrap_data_iterator(data_iterator, config, num_micro_batches_old) + + # check the result + assert type(num_micro_batches) is int + assert ( + type(num_total_tokens_this_global_batch) is float + or type(num_total_tokens_this_global_batch) is np.float32 + ) + assert ( + type(sequence_square_sum_this_global_batch) is float + or type(sequence_square_sum_this_global_batch) is np.float32 + ) + + def _check_batch(batch_all, batch_keys): + for batch in batch_all: + assert set(batch_keys) <= set( + batch.keys() + ), f"batch keys: {set(batch.keys())} missing {set(batch_keys) - set(batch.keys())}" + for key in batch_keys: + assert batch[key] is not None + + if is_tp_first: + # CHECK KEYS + batch_keys = ["cu_seqlens", "max_seqlen", "cu_seqlens_padded"] + if vpp is not None and vpp > 1: + # check metadata for all stages (save batches to avoid re-consuming iterators) + all_stage_batches = [] + for temp_data_iterator in new_data_iterator: + 
stage_batch = [next(temp_data_iterator) for _ in range(num_micro_batches)] + all_stage_batches.append(stage_batch) + _check_batch(stage_batch, batch_keys) + + # check for first or last stage on first or last pp rank + if is_pp_first_or_last: + batch_all = all_stage_batches[0] if is_pp_first else all_stage_batches[-1] + batch_keys += ["tokens", "position_ids", "labels", "loss_mask"] + _check_batch(batch_all, batch_keys) + else: + # non-VPP: single iterator + batch_all = [next(new_data_iterator) for _ in range(num_micro_batches)] + if is_pp_first_or_last: + batch_keys += ["tokens", "position_ids", "labels", "loss_mask"] + _check_batch(batch_all, batch_keys) + + # CHECK TOKEN SUM ON FIRST OR LAST PP RANK + # Note: data_iterator is consumed by wrap_data_iterator, new_data_iterator is consumed above. + # Use `samples` for before-wrap, reuse `batch_all` from the check above for after-wrap. + if is_pp_first_or_last: + # Compute token sum before wrap + token_sum_before = torch.tensor(0, dtype=torch.int64, device='cuda') + for sample in samples: + token_sum_before += sample['tokens'].long().sum() + + # Compute token sum after wrap (batch_all already collected above with tokens) + token_sum_after = torch.tensor(0, dtype=torch.int64, device='cuda') + for batch in batch_all: + token_sum_after += batch['tokens'].long().sum() + + # Reduce sum across dp_cp group and verify equality + dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=False) + torch.distributed.all_reduce( + token_sum_before, op=torch.distributed.ReduceOp.SUM, group=dp_cp_group + ) + torch.distributed.all_reduce( + token_sum_after, op=torch.distributed.ReduceOp.SUM, group=dp_cp_group + ) + + assert ( + token_sum_before == token_sum_after + ), f"Token sum mismatch: before={token_sum_before.item()}, after={token_sum_after.item()}" + + else: + if vpp is not None and vpp > 1: + assert type(new_data_iterator) is list and len(new_data_iterator) == vpp + for data_iterator in new_data_iterator: + 
assert data_iterator is None + else: + assert new_data_iterator is None + + finally: + Utils.destroy_model_parallel() + unset_global_variables()