From 73b4bbb9e1a5b968fa21261bd7cc502fee919114 Mon Sep 17 00:00:00 2001 From: xiaoyao0115 <1804647152@qq.com> Date: Tue, 3 Feb 2026 01:46:21 -0800 Subject: [PATCH 1/9] add thd e2e support and mock dataset Signed-off-by: xiaoyao0115 <1804647152@qq.com> --- megatron/core/datasets/data_schedule.py | 983 +++++++++++++++++- megatron/core/datasets/gpt_dataset.py | 6 + megatron/core/model_parallel_config.py | 13 +- megatron/core/pipeline_parallel/schedules.py | 9 + .../text/libraries/sft_tokenizer.py | 5 + .../core/transformer/transformer_config.py | 32 + megatron/training/arguments.py | 37 +- megatron/training/datasets/sft_dataset.py | 174 +++- megatron/training/training.py | 129 ++- pretrain_gpt.py | 17 +- tests/unit_tests/test_sequence_packing.py | 322 ++++++ 11 files changed, 1664 insertions(+), 63 deletions(-) create mode 100644 tests/unit_tests/test_sequence_packing.py diff --git a/megatron/core/datasets/data_schedule.py b/megatron/core/datasets/data_schedule.py index 0f016473b6a..7a1dd9e587d 100644 --- a/megatron/core/datasets/data_schedule.py +++ b/megatron/core/datasets/data_schedule.py @@ -1,12 +1,27 @@ # Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved. -from typing import Any, List, Optional +import enum +from typing import Any, Dict, List, Optional, Type, Union +import numpy as np import torch from megatron.core import parallel_state +from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.pipeline_parallel.hybrid_cp_schedule import BalancedCPScheduler from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.rerun_state_machine import RerunDataIterator +from megatron.core.utils import is_te_min_version + +try: + # Register the TE CUDA kernels + import transformer_engine # pylint: disable=unused-import + + # Alias the PyTorch wrapper so we can call tex.* APIs + import transformer_engine_torch as tex +except ImportError: + # TE isn’t installed or the torch wrapper is missing + tex = None class HybridCPDataLoaderWrapper: @@ -299,3 +314,969 @@ def __next__(self) -> Any: batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets ) return samples_this_rank_with_id, sample_id_groups + + +class PackingScheduler(enum.Enum): + """Enum for supported sequence packing algorithms.""" + + DEFAULT_SEQUENCE_PACKING = "default_sequence_packing" + + +def _broadcast_tensor(item, src_rank, group) -> None: + """Broadcast a tensor from src_rank to all ranks in the group.""" + if item is not None: + torch.distributed.broadcast(item, src_rank, group=group) + + +class BaseScheduler: + """Base class for sequence packing schedulers.""" + + def __init__( + self, + max_seqlen_per_dp_cp_rank: int, + cp_size: int, + dp_size: int, + microbatch_group_size_per_vp_stage: Optional[int], + ): + self.max_seqlen_per_dp_cp_rank = max_seqlen_per_dp_cp_rank + self.cp_size = cp_size + self.dp_size = dp_size + self.microbatch_group_size_per_vp_stage = microbatch_group_size_per_vp_stage + + def get_require_sample_keys(self): + """Return the required key of each batch.""" + raise NotImplementedError + + def get_groups_and_subsamples(self, sample_id_seqlens): + """schedule the samples into groups""" + raise NotImplementedError + + def run( + self, + data_iterator, + num_microbatches, + dp_group, + tp_group, + pp_group, + dp_cp_group, + dev, + config, + ): + """Run the scheduler and return the new data_iterator.""" + raise NotImplementedError + + @staticmethod + def _get_global_seqlens(subsample_seqlens: torch.Tensor, dp_group) -> List[int]: + """ + Gathers the sequence lengths of all subsamples from all DP ranks. + + Each DP rank has the same number of subsamples (num_microbatches), + so we can directly all_gather without padding. + """ + dp_size = dp_group.size() + num_local_subsamples = subsample_seqlens.shape[0] + + # Gather the subsample_seqlens from all ranks + seqlens_gathered = [torch.empty_like(subsample_seqlens) for _ in range(dp_size)] + torch.distributed.all_gather(seqlens_gathered, subsample_seqlens, group=dp_group) + + seqlens_gathered = torch.cat(seqlens_gathered, dim=0) + seqlens_gathered = seqlens_gathered.cpu().tolist() + + # Calculate the offsets to assign unique global ID to each subsample. + # Since each rank has the same number of subsamples, offsets are evenly spaced. + offsets = torch.arange( + 0, dp_size * num_local_subsamples, num_local_subsamples, dtype=torch.int32 + ) + + return seqlens_gathered, offsets + + @staticmethod + def _get_global_id_seqlens(num_local_subsamples, offsets, seqlens_gathered, dp_group): + """ + Calculates the global ID for each subsample. + + We assign a unique global ID to each subsample. + + Returns: + global_id_seqlens: list of (global_id, seqlen) tuples for scheduling. + global_ids_this_rank: list of global IDs locally present on this rank. + """ + dp_rank = dp_group.rank() + global_ids = torch.arange(len(seqlens_gathered), dtype=torch.int32).cuda() + # Create a list of (global_id, seqlen) tuples for scheduling + global_id_seqlens = [(i, seqlens_gathered[i]) for i in range(len(global_ids))] + # Get the global IDs locally present on this rank + global_ids_this_rank = global_ids[ + offsets[dp_rank] : offsets[dp_rank] + num_local_subsamples + ] + + return global_id_seqlens, global_ids_this_rank + + @staticmethod + def _broadcast_to_pp_group( + new_samples, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + pp_group, + dev, + ): + """ + Broadcast num_micro_batches, seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch and metadata to middle PP stages. + """ + + pp_src_rank = torch.distributed.get_process_group_ranks(pp_group)[0] + + if pp_group.size() > 2: + if pp_group.rank() == 0: + tensor_list = [ + torch.tensor( + [ + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ], + dtype=torch.float32, + ).cuda() + ] + for sample in new_samples: + tensor_list.append(sample["max_seqlen"].unsqueeze(0)) + for sample in new_samples: + tensor_list.append(sample["cu_seqlens"]) + tensor_list.append(sample["cu_seqlens_padded"]) + info_to_broadcast = torch.cat(tensor_list, dim=0).to( + device=dev, dtype=torch.float32 + ) + info_length_tensor = torch.tensor( + info_to_broadcast.shape[0], dtype=torch.int32 + ).cuda() + _broadcast_tensor(info_length_tensor, pp_src_rank, pp_group) + _broadcast_tensor(info_to_broadcast, pp_src_rank, pp_group) + else: + info_length_tensor = torch.tensor(0, dtype=torch.int32).cuda() + _broadcast_tensor(info_length_tensor, pp_src_rank, pp_group) + info_to_broadcast = torch.empty( + info_length_tensor.item(), dtype=torch.float32 + ).cuda() + _broadcast_tensor(info_to_broadcast, pp_src_rank, pp_group) + if pp_group.rank() != pp_group.size() - 1: + # middle PP stages receive the broadcasted info and unpack it + info_numpy = info_to_broadcast.cpu().numpy() + num_micro_batches = int(info_numpy[0]) + seqlen_sum_this_global_batch = info_numpy[1] + seqlen_squared_sum_this_global_batch = info_numpy[2] + max_seqlens = info_to_broadcast[3 : 3 + num_micro_batches] + cu_seqlens_list = [] + cu_seqlens_padded_list = [] + indices = np.where(info_numpy == 0)[0] + for i in range(num_micro_batches): + cu_seqlens_list.append( + info_to_broadcast[indices[i * 2] : indices[i * 2 + 1]] + ) + if i == num_micro_batches - 1: + cu_seqlens_padded_list.append(info_to_broadcast[indices[i * 2 + 1] :]) + else: + cu_seqlens_padded_list.append( + info_to_broadcast[indices[i * 2 + 1] : indices[i * 2 + 2]] + ) + + new_samples = [] + for i in range(num_micro_batches): + new_sample = {} + new_sample["max_seqlen"] = max_seqlens[i].to(torch.int32) + new_sample["cu_seqlens"] = cu_seqlens_list[i].to(torch.int32) + new_sample["cu_seqlens_padded"] = cu_seqlens_padded_list[i].to(torch.int32) + new_samples.append(new_sample) + + return ( + new_samples, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ) + + @staticmethod + def _broadcast_scalars(values: List, group, dev, dtype=torch.float32) -> List: + """ + Broadcast scalar values from rank 0 to all ranks in the group. + + Args: + values: List of scalar values to broadcast (only used on rank 0). + group: The process group to broadcast within. + dev: The device to use for the tensor. + dtype: The data type for the tensor. + + Returns: + List of broadcasted values. + """ + if group.size() <= 1: + return values + + src_rank = torch.distributed.get_process_group_ranks(group)[0] + num_values = len(values) + + if group.rank() == 0: + info_to_broadcast = torch.tensor(values, dtype=dtype, device=dev) + else: + info_to_broadcast = torch.zeros(num_values, dtype=dtype, device=dev) + + _broadcast_tensor(info_to_broadcast, src_rank, group) + + if group.rank() != 0: + values = info_to_broadcast.cpu().tolist() + + return values + + @staticmethod + def _create_data_iterator(new_samples, pp_group, tp_group, config): + """Handle virtual pipeline parallelism.""" + if ( + config.virtual_pipeline_model_parallel_size is not None + and config.virtual_pipeline_model_parallel_size > 1 + ): + vpp_size = config.virtual_pipeline_model_parallel_size + if tp_group.rank() == 0: + if pp_group.rank() == 0 or pp_group.rank() == pp_group.size() - 1: + new_samples_for_other_ppstage = [] + for sample in new_samples: + new_sample_for_other_ppstage = {} + new_sample_for_other_ppstage["max_seqlen"] = sample["max_seqlen"] + new_sample_for_other_ppstage["cu_seqlens"] = sample["cu_seqlens"] + new_sample_for_other_ppstage["cu_seqlens_padded"] = sample[ + "cu_seqlens_padded" + ] + new_samples_for_other_ppstage.append(new_sample_for_other_ppstage) + if pp_group.rank() == 0: + new_data_iterator = [RerunDataIterator(iter(new_samples))] + [ + RerunDataIterator(iter(new_samples_for_other_ppstage)) + for _ in range(vpp_size - 1) + ] + else: + new_data_iterator = [ + RerunDataIterator(iter(new_samples_for_other_ppstage)) + for _ in range(vpp_size - 1) + ] + [RerunDataIterator(iter(new_samples))] + else: + new_data_iterator = [ + RerunDataIterator(iter(new_samples)) for _ in range(vpp_size) + ] + else: + new_data_iterator = [None for _ in range(vpp_size)] + else: + new_data_iterator = ( + RerunDataIterator(iter(new_samples)) if tp_group.rank() == 0 else None + ) + + return new_data_iterator + + @staticmethod + def _reroute_samples_to_dcp_ranks( + batch, + global_ids_this_rank, + global_id_seqlens, + sample_id_groups, + offsets, + dp_group, + tp_group, + dp_cp_group, + total_dcp_gpus, + ): + """ + Reroutes the sub-samples to the correct rank after scheduling. + + For each key in the batch dict, we perform an all-to-all communication + to transfer the data to the correct ranks. + """ + + def _gid_to_src_rank(gid: int) -> int: + dp_src_rank = torch.bucketize(gid, offsets[1:] - 1) + dcp_rank = ( + torch.distributed.get_process_group_ranks(dp_group)[dp_src_rank] // tp_group.size() + ) % dp_cp_group.size() + return dcp_rank + + gid2local_id = {int(gid): i for i, gid in enumerate(global_ids_this_rank)} + dcp_rank = dp_cp_group.rank() + dp_ranks = torch.distributed.get_process_group_ranks(dp_group) + dp_ranks = [(r // tp_group.size()) % dp_cp_group.size() for r in dp_ranks] + + data_keys = batch[0].keys() + + # Create the send plan + combined_sample_id_groups: List[List[int]] = [[] for _ in range(total_dcp_gpus)] + for d in range(total_dcp_gpus): + for sample_id_group in sample_id_groups: + combined_sample_id_groups[d].extend(sample_id_group[d]) + for dest_rank in range(total_dcp_gpus): + combined_sample_id_groups[dest_rank].sort() + + send_ids_sorted = [ + gid + for d in dp_ranks + for gid in combined_sample_id_groups[d] + if gid in global_ids_this_rank + ] + + send_num_split = [0] * total_dcp_gpus + send_lens_split = [0] * total_dcp_gpus + for dest_rank in range(total_dcp_gpus): + if dest_rank in dp_ranks: + send_seq_lens = [ + global_id_seqlens[gid][1] + for gid in combined_sample_id_groups[dest_rank] + if gid in global_ids_this_rank + ] + send_num_split[dest_rank] = len(send_seq_lens) + send_lens_split[dest_rank] = sum(send_seq_lens) + else: + send_lens_split[dest_rank] = 0 + + # Create the recv plan + recv_sample_id_groups = [[] for _ in range(total_dcp_gpus)] + for gid in combined_sample_id_groups[dcp_rank]: + src_rank = _gid_to_src_rank(gid) + recv_sample_id_groups[src_rank].append(gid) + + recv_lens_split = [0] * total_dcp_gpus + for src_rank in range(total_dcp_gpus): + recv_lens_split[src_rank] = sum( + [global_id_seqlens[gid][1] for gid in recv_sample_id_groups[src_rank]] + ) + + recv_ids_sorted = [gid for d in range(total_dcp_gpus) for gid in recv_sample_id_groups[d]] + recv_counts = [len(recv_sample_id_groups[d]) for d in range(total_dcp_gpus)] + + recv_samples = [{k: None for k in data_keys} for _ in range(sum(recv_counts))] + + def _pack_sample_by_key(key: str) -> torch.Tensor: + flattened_tensors = [] + for gid in send_ids_sorted: + t = batch[gid2local_id[gid]][key].to(torch.cuda.current_device(), non_blocking=True) + flattened_tensors.append(t.reshape(-1)) + return ( + torch.cat(flattened_tensors, dim=0) + if flattened_tensors + else torch.empty(1, device=torch.cuda.current_device(), dtype=batch[0][key].dtype) + ) + + def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor): + cursor = 0 + for i, gid in enumerate(recv_ids_sorted): + sample_len = ( + 1 + if key in ["original_seq_len", "padded_seq_len"] + else global_id_seqlens[gid][1] + ) + recv_samples[i][key] = recv_tensor[cursor : cursor + sample_len] + cursor += sample_len + + for key in data_keys: + output_split_sizes, input_split_sizes = ( + (recv_counts, send_num_split) + if key in ["original_seq_len", "padded_seq_len"] + else (recv_lens_split, send_lens_split) + ) + send_tensor = _pack_sample_by_key(key) + recv_tensor_size = sum(output_split_sizes) + recv_tensor = torch.empty( + recv_tensor_size, device=torch.cuda.current_device(), dtype=send_tensor.dtype + ) + torch.distributed.all_to_all_single( + output=recv_tensor, + input=send_tensor, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=dp_cp_group, + ) + _unpack_sample_by_key(key, recv_tensor) + + recv_sample_with_id = { + recv_id: recv_samples[i] for i, recv_id in enumerate(recv_ids_sorted) + } + return recv_sample_with_id + + @staticmethod + def _pack_sequences( + samples: List, + padded_lengths: torch.Tensor, + original_lengths: torch.Tensor, + dev: torch.device, + ) -> Dict[str, torch.Tensor]: + """Pack multiple samples into a single packed sample.""" + + def _pack_tensors(tensors): + return torch.cat([t.reshape(-1) for t in tensors], dim=0) + + tokens = _pack_tensors([sample["tokens"] for sample in samples]) + labels = _pack_tensors([sample["labels"] for sample in samples]) + loss_mask = _pack_tensors([sample["loss_mask"] for sample in samples]) + position_ids = _pack_tensors([sample["position_ids"] for sample in samples]) + + new_sample = {} + new_sample["tokens"] = tokens + new_sample["labels"] = labels + new_sample["loss_mask"] = loss_mask + new_sample["position_ids"] = position_ids + + padded_lengths = padded_lengths.to( + device=dev, dtype=torch.int32, non_blocking=True + ).reshape(-1) + cu_seqlens_padded = torch.empty(padded_lengths.numel() + 1, device=dev, dtype=torch.int32) + cu_seqlens_padded[0] = 0 + cu_seqlens_padded[1:] = torch.cumsum(padded_lengths, dim=0) + max_seqlen = torch.max(padded_lengths).to(dtype=torch.int32) + + new_sample["cu_seqlens_padded"] = cu_seqlens_padded + new_sample["max_seqlen"] = max_seqlen + + original_lengths = original_lengths.to( + device=dev, dtype=torch.int32, non_blocking=True + ).reshape(-1) + cu_seqlens = torch.empty(original_lengths.numel() + 1, device=dev, dtype=torch.int32) + cu_seqlens[0] = 0 + cu_seqlens[1:] = torch.cumsum(original_lengths, dim=0).reshape(-1) + new_sample["cu_seqlens"] = cu_seqlens + + return new_sample + + @staticmethod + def _build_packed_microbatches( + grouped_samples: List[List[Dict[str, torch.Tensor]]], dev: torch.device + ) -> List[Dict[str, torch.Tensor]]: + """Build packed samples for each microbatch.""" + num_micro_batches = len(grouped_samples) + seg_starts: List[int] = [0] + original_lens_tensors = [] + padded_lens_tensors = [] + + for i in range(num_micro_batches): + samples = grouped_samples[i] + seg_starts.append(seg_starts[-1] + len(samples)) + original_lens_tensors.extend([s["original_seq_len"].reshape(-1) for s in samples]) + padded_lens_tensors.extend([s["padded_seq_len"].reshape(-1) for s in samples]) + + padded_lens_all_gpu = torch.cat(padded_lens_tensors, dim=0).to(dtype=torch.int32) + original_lens_all_gpu = torch.cat(original_lens_tensors, dim=0).to(dtype=torch.int32) + + new_samples: List[Dict[str, torch.Tensor]] = [] + for i in range(num_micro_batches): + samples = grouped_samples[i] + lp = padded_lens_all_gpu[seg_starts[i] : seg_starts[i + 1]] + lo = original_lens_all_gpu[seg_starts[i] : seg_starts[i + 1]] + new_sample = BaseScheduler._pack_sequences(samples, lp, lo, dev) + new_samples.append(new_sample) + + return new_samples + + @staticmethod + def _get_batch_and_global_seqlens(data_iterator, num_microbatches, dp_group): + """ + Get the batch and global sequence lengths. + Each DP rank loads the same number of sequences, so we need to gather the sequence + lengths from all ranks then we can schedule the sequences into groups. + Args: + data_iterator: The data iterator. + num_microbatches: The number of microbatches. + dp_group: The data parallel group. + + Returns: + batch: The batch. + global_id_seqlens: The global sequence lengths. + global_ids_this_rank: The global IDs locally present on this rank. + """ + batch = [next(data_iterator) for _ in range(num_microbatches)] + subsample_seqlens = [] + for sample in batch: + subsample_seqlens.extend([sample["tokens"].numel()]) + subsample_seqlens = torch.tensor(subsample_seqlens, dtype=torch.int32).cuda() + + seqlens_gathered, offsets = BaseScheduler._get_global_seqlens(subsample_seqlens, dp_group) + + global_id_seqlens, global_ids_this_rank = BaseScheduler._get_global_id_seqlens( + subsample_seqlens.shape[0], offsets, seqlens_gathered, dp_group + ) + + return batch, global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered + + +class DefaultSequencePackingScheduler(BaseScheduler): + """Packs sequences in their original order until reaching the max limit of sequence length.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.max_seq_len_all_ranks = self.max_seqlen_per_dp_cp_rank * self.cp_size + + def get_require_sample_keys(self): + """Return the required key of each batch.""" + return [ + "tokens", + "labels", + "loss_mask", + "position_ids", + "original_seq_len", # Length of the original sequence length, should be a gpu tensor. + "padded_seq_len", # Length of the padded sequence length, should be a gpu tensor. + ] + + def get_groups_and_subsamples(self, sample_id_seqlens): + """ + Packs sequences in their original order until reaching the max limit of sequence length. + """ + sample_id_groups = [] + packed_id_groups = [] + sum_seqlen = 0 + single_microbatch = [] + + for i in range(len(sample_id_seqlens)): + if sum_seqlen + sample_id_seqlens[i][1] <= self.max_seq_len_all_ranks: + single_microbatch.append(i) + sum_seqlen += sample_id_seqlens[i][1] + else: + packed_id_groups.append(single_microbatch) + single_microbatch = [i] + sum_seqlen = sample_id_seqlens[i][1] + if len(single_microbatch) > 0: + packed_id_groups.append(single_microbatch) + + # we want the number of packed sequences to be multiple of dp_size + # so we move few samples from previous microbatch + # to the end of the microbatches if needed + num_packed_sequence = len(packed_id_groups) + + # when enabling vpp, we want the number of packed sequences to be + # multiple of dp_size * microbatch_group_size_per_vp_stage + multiple = self.dp_size * ( + self.microbatch_group_size_per_vp_stage + if self.microbatch_group_size_per_vp_stage is not None + else 1 + ) + if num_packed_sequence % multiple != 0: + remainder = num_packed_sequence % multiple + num_to_move = multiple - remainder + i = num_packed_sequence - 1 + while num_to_move > 0: + assert i > 0, "Not enough samples to move" + if len(packed_id_groups[i]) > 1: + seq_id = packed_id_groups[i].pop() + packed_id_groups.append([seq_id]) + num_to_move -= 1 + else: + i -= 1 + + num_micro_batches = int(len(packed_id_groups) / self.dp_size) + for i in range(num_micro_batches): + sample_id_groups.append([]) + for j in range(self.cp_size * self.dp_size): + seq_id = int(i * self.dp_size + j / self.cp_size) + sample_id_groups[i].append(packed_id_groups[seq_id]) + return sample_id_groups + + def run( + self, + data_iterator, + num_microbatches: int, + dp_group, + tp_group, + pp_group, + dp_cp_group, + dev: torch.device, + config, + ): + """ + Run the complete scheduling pipeline. + + Steps: + 1. Fetch batches and gather global sequence lengths + 2. Check required sample keys + 3. Schedule samples into groups + 4. Reroute samples to DCP ranks + 5. Build packed microbatches + 6. Calculate FLOPs info + 7. Broadcast to PP group (for middle PP stages) + 8. Broadcast to TP group (for non-TP-0 ranks) + 9. Handle VPP if enabled + + Args: + data_iterator: The data iterator. + num_microbatches: The number of microbatches to fetch. + dp_group: Data parallel process group. + tp_group: Tensor parallel process group. + pp_group: Pipeline parallel process group. + dp_cp_group: Data parallel + context parallel process group. + dev: CUDA device. + config: Model parallel config. + + Returns: + new_data_iterator: The new data iterator (or list for VPP). + num_micro_batches: Number of micro batches after scheduling. + seqlen_sum_this_global_batch: Total tokens for FLOPs calculation. + seqlen_squared_sum_this_global_batch: Sum of squared seqlens for FLOPs. + """ + + total_dcp_gpus = dp_cp_group.size() + + # Handle VPP: extract the correct data_iterator for this PP stage + if ( + config.virtual_pipeline_model_parallel_size is not None + and config.virtual_pipeline_model_parallel_size > 1 + ): + # if enable VPP, data_iterator is a list of data_iterators for each VPP stage, + # and only the first and last stage rank will have data_iterator, + # other stages will have None. + assert len(data_iterator) == config.virtual_pipeline_model_parallel_size + if pp_group.rank() == 0: + # the first stage + data_iterator = data_iterator[0] + elif pp_group.rank() == pp_group.size() - 1: + # the last stage + data_iterator = data_iterator[-1] + else: + data_iterator = None + + # data_iterator is not None when TP rank 0, with PP stage 0 or -1. + if data_iterator is not None: + assert tp_group.rank() == 0 and ( + pp_group.rank() == 0 or pp_group.rank() == pp_group.size() - 1 + ), f"Only TP rank 0 and PP stage 0 or -1 should have data_iterator" + + # Step 1: Fetch batches and gather global sequence lengths + batch, global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered = ( + self._get_batch_and_global_seqlens(data_iterator, num_microbatches, dp_group) + ) + + # Step 2: Check required sample keys + for key in self.get_require_sample_keys(): + assert key in batch[0], f"Batch missing required key {key}" + + # Step 3: Schedule samples into groups + sample_id_groups = self.get_groups_and_subsamples(global_id_seqlens) + + # Validate scheduling result + set_gbs = set() + for group in sample_id_groups: + for sub in group: + set_gbs.update(sub) + assert len(set_gbs) == len(global_id_seqlens), ( + f"set_gbs length: {len(set_gbs)} != " + f"global_id_seqlens length: {len(global_id_seqlens)}" + ) + + # Step 4: Reroute samples to DCP ranks + samples_this_rank_with_id = self._reroute_samples_to_dcp_ranks( + batch, + global_ids_this_rank, + global_id_seqlens, + sample_id_groups, + offsets, + dp_group, + tp_group, + dp_cp_group, + total_dcp_gpus, + ) + + dcp_rank = dp_cp_group.rank() + num_micro_batches = len(sample_id_groups) + + grouped_samples = [ + [ + samples_this_rank_with_id[sub_sample_id] + for sub_sample_id in sample_id_groups[i][dcp_rank] + ] + for i in range(num_micro_batches) + ] + + # Step 5: Build packed microbatches + new_samples = self._build_packed_microbatches(grouped_samples, dev) + + # Step 6: Calculate FLOPs info + seqlen_sum_this_global_batch = float(sum(seqlens_gathered)) + seqlen_squared_sum_this_global_batch = float( + sum(seqlen**2 for seqlen in seqlens_gathered) + ) + else: + ( + new_samples, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ) = (None, None, None, None) + + # Step 7: Broadcast to PP group (for middle PP stages) + if tp_group.rank() == 0: + ( + new_samples, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ) = self._broadcast_to_pp_group( + new_samples, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + pp_group, + dev, + ) + + # Step 8: Broadcast to TP group (for non-TP-0 ranks) + (num_micro_batches, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch) = ( + self._broadcast_scalars( + [ + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ], + tp_group, + dev, + ) + ) + num_micro_batches = int(num_micro_batches) + + # Step 9: create data_iterator and handle VPP if enabled + new_data_iterator = self._create_data_iterator(new_samples, pp_group, tp_group, config) + + return ( + new_data_iterator, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ) + + +def wrap_dataloader( + data_iterator, config, num_microbatches, pg_collection: Optional[ProcessGroupCollection] = None +): + """ + A wrapper function that wraps around an existing data_iterator + and return the num_micro_batches for sequence packing. + + Args: + data_iterator: The original data_iterator to wrap around + config: The config object containing the max_seqlen_per_dp_cp_rank + dp_cp_group: Data parallel context parallel group. + pg_collection: The process group collection. + """ + + scheduler_map: Dict[PackingScheduler, Type[BaseScheduler]] = { + PackingScheduler.DEFAULT_SEQUENCE_PACKING: DefaultSequencePackingScheduler + } + + if pg_collection is None: + dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True) + dp_group = parallel_state.get_data_parallel_group() + tp_group = parallel_state.get_tensor_model_parallel_group() + pp_group = parallel_state.get_pipeline_model_parallel_group() + else: + dp_cp_group = pg_collection.dp_cp + dp_group = pg_collection.dp + tp_group = pg_collection.tp + pp_group = pg_collection.pp + assert ( + dp_cp_group is not None + and dp_group is not None + and tp_group is not None + and pp_group is not None + ), "dp_cp_group, dp_group, tp_group must not be None when using sequence packing" + + dev = torch.cuda.current_device() + dp_size = dp_group.size() + cp_size = dp_cp_group.size() // dp_size + + # Convert string to enum + scheduler_type = config.sequence_packing_scheduler + scheduler_type = PackingScheduler[scheduler_type.upper()] + + scheduler = scheduler_map[scheduler_type]( + config.max_seqlen_per_dp_cp_rank, + cp_size, + dp_size, + # When VPP is enabled, align num_micro_batches to this multiple. + ( + None + if config.virtual_pipeline_model_parallel_size is None + else config.microbatch_group_size_per_vp_stage + ), + ) + + ( + new_data_iterator, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ) = scheduler.run( + data_iterator, num_microbatches, dp_group, tp_group, pp_group, dp_cp_group, dev, config + ) + + return ( + new_data_iterator, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ) + + +def get_batch_on_this_rank_for_sequence_packing( + data_iterator, mtp_on_this_rank: bool = False, vp_stage: Optional[int] = None +): + """ + Get a batch of data for sequence packing. + Args: + data_iterator (Iterator): The data iterator to get the batch from. + mtp_on_this_rank (bool): Whether to use multi-token prediction. + vp_stage (Optional[int]): The stage of the pipeline. + Returns: + tuple of (tokens, labels, loss_mask, attention_mask, position_ids, packed_seq_params) + """ + + tp_src_rank = parallel_state.get_tensor_model_parallel_src_rank() + tp_group = parallel_state.get_tensor_model_parallel_group() + + is_tp_rank_0 = parallel_state.get_tensor_model_parallel_rank() == 0 + is_first_stage = parallel_state.is_pipeline_first_stage( + ignore_virtual=vp_stage is None, vp_stage=vp_stage + ) + is_last_stage = parallel_state.is_pipeline_last_stage( + ignore_virtual=vp_stage is None, vp_stage=vp_stage + ) + is_first_or_last_stage = is_first_stage or is_last_stage + dev = torch.cuda.current_device() + + # data_iterator should return a batch including the following keys. + batch_keys = ['cu_seqlens', 'cu_seqlens_padded', 'max_seqlen'] + if is_first_stage: + batch_keys.append('tokens') + batch_keys.append('position_ids') + if is_last_stage: + batch_keys.append('labels') + batch_keys.append('loss_mask') + + # Get a batch from data_iterator or create an emtpy batch. + if is_tp_rank_0: + assert data_iterator is not None + batch = next(data_iterator) + for key in batch_keys: + assert key in batch, f"{key} is missing in current batch." + else: + assert data_iterator is None, "Non TP 0 rank should not have data_iterator" + batch = {} + + # Partition tokens, position_ids, labels, loss_mask for context parallel, currently only + # TP rank 0 and the first/last PP stage rank has these data. + if is_tp_rank_0 and is_first_or_last_stage: + cp_size = parallel_state.get_context_parallel_world_size() + cp_rank = parallel_state.get_context_parallel_rank() + # If cp_size == 1, no need to do further processing. + if cp_size > 1: + assert tex is not None and is_te_min_version("1.10.0"), ( + "Please update Transformer Engine to >= 1.10 to use " + "Context Parallel with THD format data" + ) + total_tokens = batch['tokens'].size(0) + # Transformer Engine has a bug of cu_seqlens, we must treat cu_seqlens_padded as + # cu_seqlens to get the correct result. + # TODO: Revert this workaround once TE fixes the issue. + cu_seqlens = batch["cu_seqlens_padded"] + index = tex.thd_get_partitioned_indices(cu_seqlens, total_tokens, cp_size, cp_rank) + for key in ['tokens', 'position_ids', 'labels', 'loss_mask']: + batch[key] = batch[key].index_select(0, index) + + # Broadcast cu_seqlens_size because we need it to create placeholder for cu_seqlens and + # cu_seqlens_padded for non TP 0 ranks. + if is_tp_rank_0: + cu_seqlen_size = torch.tensor(batch['cu_seqlens'].size(0), dtype=torch.int32, device=dev) + else: + cu_seqlen_size = torch.empty(1, dtype=torch.int32, device=dev) + _broadcast_tensor(cu_seqlen_size, tp_src_rank, tp_group) + cu_seqlen_size = cu_seqlen_size.item() + + # Broadcast total_tokens because we need it to create placeholder for tokens, position_ids, + # labels, loss_mask for non TP 0 ranks. Only first or last stage need this. + if is_first_or_last_stage: + if is_tp_rank_0: + total_tokens = torch.tensor(batch['tokens'].size(0), dtype=torch.int32, device=dev) + else: + total_tokens = torch.empty(1, dtype=torch.int32, device=dev) + _broadcast_tensor(total_tokens, tp_src_rank, tp_group) + total_tokens = total_tokens.item() + + # Step1: Prepare "tokens", "position_ids" on all ranks. + if is_first_stage or mtp_on_this_rank: + if is_tp_rank_0: + assert batch['tokens'].dtype == torch.int64 + assert batch['position_ids'].dtype == torch.int64 + batch['tokens'] = batch['tokens'].view(1, total_tokens) + batch['position_ids'] = batch['position_ids'].view(1, total_tokens) + else: + batch['tokens'] = torch.empty([1, total_tokens], dtype=torch.int64, device=dev) + batch['position_ids'] = torch.empty([1, total_tokens], dtype=torch.int64, device=dev) + else: + # Non first stage rank doesn't need tokens and position_ids. + batch['tokens'] = None + batch['position_ids'] = None + + # Step2: Prepare "labels", "loss_mask" on all ranks. + if is_last_stage: + if is_tp_rank_0: + assert batch['labels'].dtype == torch.int64 + assert batch['loss_mask'].dtype == torch.float32 + batch['labels'] = batch['labels'].view(1, total_tokens) + batch['loss_mask'] = batch['loss_mask'].view(1, total_tokens) + else: + batch['labels'] = torch.empty([1, total_tokens], dtype=torch.int64, device=dev) + batch['loss_mask'] = torch.empty([1, total_tokens], dtype=torch.float32, device=dev) + else: + # Non last stage rank doesn't need labels and loss_mask. + batch['labels'] = None + batch['loss_mask'] = None + + # Step3: Prepare "cu_seqlens", "cu_seqlens_padded", "max_seqlen" on all ranks. + if is_tp_rank_0: + assert batch['cu_seqlens'].dtype == torch.int32 + assert batch['cu_seqlens_padded'].dtype == torch.int32 + assert batch['cu_seqlens'].dim() == 1 + assert batch['cu_seqlens_padded'].dim() == 1 + if type(batch['max_seqlen']) == int: + batch['max_seqlen'] = torch.tensor(batch['max_seqlen'], dtype=torch.int32, device=dev) + else: + assert batch['max_seqlen'].dtype == torch.int32 + assert batch['max_seqlen'].numel() == 1 + else: + batch['cu_seqlens'] = torch.empty([cu_seqlen_size], dtype=torch.int32, device=dev) + batch['cu_seqlens_padded'] = torch.empty([cu_seqlen_size], dtype=torch.int32, device=dev) + batch['max_seqlen'] = torch.empty(1, dtype=torch.int32, device=dev) + + # Broadcast batch inside TP group. + _broadcast_tensor(batch['tokens'], tp_src_rank, tp_group) + _broadcast_tensor(batch['position_ids'], tp_src_rank, tp_group) + _broadcast_tensor(batch['labels'], tp_src_rank, tp_group) + _broadcast_tensor(batch['loss_mask'], tp_src_rank, tp_group) + _broadcast_tensor(batch['cu_seqlens'], tp_src_rank, tp_group) + _broadcast_tensor(batch['cu_seqlens_padded'], tp_src_rank, tp_group) + _broadcast_tensor(batch['max_seqlen'], tp_src_rank, tp_group) + + # Extract the data from batch after broadcasting. + tokens = batch['tokens'] + position_ids = batch['position_ids'] + labels = batch['labels'] + loss_mask = batch['loss_mask'] + cu_seqlens = batch['cu_seqlens'] + cu_seqlens_padded = batch['cu_seqlens_padded'] + max_seqlen = batch['max_seqlen'].item() + + # Transformer Engine has a bug of cu_seqlens, we must treat cu_seqlens_padded as cu_seqlens to + # get the correct result. + # TODO: Revert this workaround once TE fixes the issue. + packed_seq_params = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_seqlens_padded, + cu_seqlens_kv=cu_seqlens_padded, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + max_seqlen_q=max_seqlen, + max_seqlen_kv=max_seqlen, + local_cp_size=None, + cp_group=None, + ) + + # "attention_mask" is not valid for sequence packing, so set it to None. + return tokens, labels, loss_mask, None, position_ids, packed_seq_params diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py index cbe0652402d..fa421641db5 100644 --- a/megatron/core/datasets/gpt_dataset.py +++ b/megatron/core/datasets/gpt_dataset.py @@ -79,6 +79,12 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig): context_parallel_size: Optional[int] = None """The size of the context parallel group. Needed for padding in packed sequences.""" + sft_mock_dataset_config_json: Optional[str] = None + """This config provides the necessary information for the mock dataset.""" + + sequence_packing: bool = False + """Option to enable sequence packing for training.""" + def __post_init__(self) -> None: """Do asserts and set fields post init""" super().__post_init__() diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index 5bbeef9b022..c5cf76fb099 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -62,7 +62,7 @@ class ModelParallelConfig: can handle without overflowing the memory. Typically, a good starting point is to set this to maximum sequence length / context parallel size. This is used to calculate the number and length of sub-samples assigned to - each rank when using hybrid_context_parallel. + each rank when using sequence_packing. """ hybrid_context_parallel: bool = False @@ -72,6 +72,17 @@ class ModelParallelConfig: Please set max_seqlen_per_dp_cp_rank when using hybrid_context_parallel. """ + sequence_packing_scheduler: Optional[str] = None + """ + Scheduler for sequence packing and hybrid context parallel. + default_sequence_packing: default sequence packing scheduler for sequence packing. + """ + + sequence_packing: bool = False + """ + If true, enables sft sequence packing. + """ + expert_model_parallel_size: int = 1 """Distributes Moe Experts across sub data parallel dimension.""" diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index e903f392bf0..23727f9c751 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -554,6 +554,9 @@ def forward_backward_no_pipelining( pg_collection.dp_cp = parallel_state.get_data_parallel_group( with_context_parallel=True, partial_data_parallel=False ) + pg_collection.dp = parallel_state.get_data_parallel_group( + with_context_parallel=False, partial_data_parallel=False + ) elif pg_collection is not None: assert hasattr(pg_collection, 'tp') @@ -1011,6 +1014,9 @@ def forward_backward_pipelining_with_interleaving( pg_collection.dp_cp = parallel_state.get_data_parallel_group( with_context_parallel=True, partial_data_parallel=False ) + pg_collection.dp = parallel_state.get_data_parallel_group( + with_context_parallel=False, partial_data_parallel=False + ) elif p2p_communicator is not None and pg_collection is not None: model_type = get_model_type(model[0]) @@ -2156,6 +2162,9 @@ def forward_backward_pipelining_without_interleaving( pg_collection.dp_cp = parallel_state.get_data_parallel_group( with_context_parallel=True, partial_data_parallel=False ) + pg_collection.dp = parallel_state.get_data_parallel_group( + with_context_parallel=False, partial_data_parallel=False + ) elif p2p_communicator is not None and pg_collection is not None: model_type = get_model_type(model) assert hasattr(p2p_communicator, 'config'), "p2p_communicator must have a config" diff --git a/megatron/core/tokenizers/text/libraries/sft_tokenizer.py b/megatron/core/tokenizers/text/libraries/sft_tokenizer.py index 8a418f2dd7f..e06e51e5ee3 100644 --- a/megatron/core/tokenizers/text/libraries/sft_tokenizer.py +++ b/megatron/core/tokenizers/text/libraries/sft_tokenizer.py @@ -237,6 +237,11 @@ def bos_id(self): def eod(self): """End of sentence token ID.""" return self._tokenizer.eos_token_id + + @property + def eos(self): + """End of sentence token ID.""" + return self._tokenizer.eos_token_id @property def vocab(self): diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 9da9a644a47..aafdc79bbc6 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -2076,6 +2076,38 @@ def __post_init__(self): self.attention_backend == AttnBackend.flash ), "Batch invariant mode only supports FlashAttention" + if self.sequence_packing: + # Check TE version. + if not HAVE_PACKAGING: + raise ImportError( + "packaging is not installed. Please install it with `pip install packaging`." + ) + # TODO: remove this after we fix the convergence issue with TE < 2.9. + if not ( + is_te_min_version("2.9.0") or get_te_version() == PkgVersion("2.9.0.dev0+5b3092a") + ): + raise ValueError( + "SFT sequence packing requires Transformer Engine >= 2.9.0 " + f"but got {get_te_version()} (TE < 2.9.0 may have convergence issues)." + ) + + # Needed for passing variable sequences between pp stages. + self.variable_seq_lengths = True + + # TODO(tailaim): add support for other dispatcher types + warnings.warn("Setting moe_token_dispatcher_type to alltoall for sft sequence packing.") + self.moe_token_dispatcher_type = "alltoall" + + if self.sequence_packing_scheduler is None: + self.sequence_packing_scheduler = 'default_sequence_packing' + + supported_schedulers = ['default_sequence_packing'] + if self.sequence_packing_scheduler not in supported_schedulers: + raise ValueError( + f"Unknown scheduler: {self.sequence_packing_scheduler}. " + f"Available schedulers: {supported_schedulers}" + ) + @dataclass @experimental_api diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 5d5fa34b6c5..50e2eb3581a 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -884,13 +884,6 @@ def validate_args(args, defaults={}): if args.rl_use_sequence_packing: args.consumed_train_bins = 0 - # Support for variable sequence lengths across batches/microbatches. - # set it if the dataloader supports generation of variable sequence lengths - # across batches/microbatches. Due to additional communication overhead - # during pipeline parallelism, it should not be set if sequence length - # is constant during training. - args.variable_seq_lengths = False - # Iteration-based training. if args.train_iters: # If we use iteration-based training, make sure the @@ -1061,6 +1054,32 @@ def validate_args(args, defaults={}): assert args.dataloader_type == 'single', 'Hybrid context parallelism only supported with single dataloader type' assert args.calculate_per_token_loss, 'Hybrid context parallelism must be used with --calculate-per-token-loss' + # Support for variable sequence lengths across batches/microbatches. + # set it if the dataloader supports generation of variable sequence lengths + # across batches/microbatches. Due to additional communication overhead + # during pipeline parallelism, it should not be set if sequence length + # is constant during training. + args.variable_seq_lengths = False + if args.sequence_packing: + args.variable_seq_lengths = True + assert args.context_parallel_size * args.max_seqlen_per_dp_cp_rank >= args.seq_length, \ + f'Packed sequence buffer size ({args.context_parallel_size * args.max_seqlen_per_dp_cp_rank}) ' \ + f'must be >= single sequence max length ({args.seq_length})' + # TODO(tailaim): add support for other dispatcher types + print(f"Setting moe_token_dispatcher_type to alltoall for sft sequence packing with pipeline parallelism") + args.moe_token_dispatcher_type = "alltoall" + if args.mock_data and args.sft_mock_dataset_config_json is None: + args.sft_mock_dataset_config_json = json.dumps( + { + "mode": "distribution", + "type": "lognormal", + "min_seq_len": args.seq_length // 2, + "max_seq_len": args.seq_length, + "mean_seq_len": args.seq_length // 4 * 3, + "lognormal_sigma": 1.1, + } + ) + # disable async_tensor_model_parallel_allreduce when # model parallel memory optimization is enabled if (args.tensor_model_parallel_size > 1 or args.context_parallel_size > 1) \ @@ -3061,4 +3080,8 @@ def _add_sft_args(parser): group.add_argument('--sft', action="store_true", help='Megatron SFT training') group.add_argument('--sft-tokenizer-prompt-format', type=str, default="nemotron-h-aligned", help='SFT prompt format.') + group.add_argument('--sequence-packing', action='store_true', + help='use sequence packing(thd format) for training') + group.add_argument('--sft-mock-dataset-config-json', type=str, default=None, + help='This config provides the necessary information for the mock dataset. You can either specify a CSV file that contains sequence lengths, where each line stores the length of a sequence, for example: {"mode":"file","path":"/path/to/file"}. Alternatively, you can specify a distribution (currently only supporting lognormal distribution) along with the required parameters, for example, {"mode":"distribution","type":"lognormal","min_seq_len":1024,"max_seq_len":2048,"mean_seq_len":1536,"lognormal_sigma":1.1}, where sigma controls the variability of the lognormal distribution.') return parser diff --git a/megatron/training/datasets/sft_dataset.py b/megatron/training/datasets/sft_dataset.py index 9de5d2a52fe..b8b15e1a985 100644 --- a/megatron/training/datasets/sft_dataset.py +++ b/megatron/training/datasets/sft_dataset.py @@ -2,9 +2,12 @@ import atexit, json from collections import Counter -from typing import Any, Dict, Optional +import json +import math +from typing import Any, Dict, Optional, List import numpy as np +import pandas as pd import torch from megatron.core.datasets.gpt_dataset import GPTDatasetConfig @@ -61,6 +64,8 @@ def __init__( config: GPTDatasetConfig, ) -> None: super().__init__(dataset, dataset_path, indices, num_samples, index_split, config) + # Pre-calculate padding divisor to avoid redundant computation in get_padding_size + self.padding_divisor = self._calculate_padding_divisor() @staticmethod def numel_low_level_dataset(low_level_dataset: LowLevelDataset) -> int: @@ -88,8 +93,38 @@ def _split_conversations(self, merged_conversations): split_conversations.append(current) return split_conversations - def __getitem__(self, idx: int) -> Dict[str, Any]: + def _calculate_padding_divisor(self) -> int: + """ + Calculate the divisor used for sequence padding. + tp_pad = tp_size * 2 if tp_size > 1 else 1 + cp_pad = cp_size * 2 if cp_size > 1 else 1 + cp_pad = cp_pad * dp_size if hybrid_cp else cp_pad + divisor = cp_pad * tp_pad + """ + if self.config.hybrid_context_parallel: + # Hybrid CP: consider both CP and DP + cp_pad = self.config.data_parallel_size * self.config.context_parallel_size * 2 + else: + # Standard CP: only consider CP + cp_pad = self.config.context_parallel_size * 2 if self.config.context_parallel_size > 1 else 1 + tp_pad = self.config.sequence_parallel_size if self.config.sequence_parallel_size > 0 else 1 + divisor = cp_pad * tp_pad + # TODO(tailaim): do we need to pad for FP8 execution? + # divisor = ((divisor + 15) // 16) * 16 + return divisor + + def get_padding_size( + self, + seq_len: int, + ) -> int: + seq_len_padded = math.ceil(seq_len / self.padding_divisor) * self.padding_divisor + assert seq_len > seq_len_padded / 2 / self.config.context_parallel_size * (self.config.context_parallel_size - 1), \ + f"sequence length {seq_len} is too short, the divisor is {self.padding_divisor}, that means cp_rank \ + {self.config.context_parallel_size-1} will have no valid tokens" + return seq_len_padded + def __getitem__(self, idx: int) -> Dict[str, Any]: + sequence_packing = self.config.sequence_packing tokenizer = self.config.tokenizer pack_length = self.config.sequence_length @@ -190,3 +225,138 @@ def extend_with_padding(tokens, targets, positions, pad_len): 'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, } + +class MockSFTLowLevelDataset: + """The low-level mock dataset for SFT + + Args: + mock_config (dict): The config for mock dataset. + """ + + seed: int = 0 + """The hard-coded random seed to use to set the NumPy RNG""" + + size: int = 1000000 + """The hard-coded number of sequence to generate""" + + # This is to maintain consistency with the SFT dataset that uses real data. In the real dataset, an element in the low-level dataset often contains multiple sequences. So here, each element in the mock low-level dataset also contains num_sequence_per_sample sequences. This will be made more reasonable in the future. + + + def __init__(self, config: Dict) -> None: + np.random.seed(self.seed) + # either choose to load sequence lengths from external file, or generate random sequence lengths + + assert "mode" in config, f"mode must be set, either 'file' or 'distribution'" + + if config["mode"] == "file": + self.sequence_lengths = np.array(pd.read_csv(config["path"])).flatten() + self.size = len(self.sequence_lengths) + elif config["mode"] == "distribution": + min_seq_len = config["min_seq_len"] + max_seq_len = config["max_seq_len"] + mean_seq_len = config["mean_seq_len"] + if config["type"] == "lognormal": + lognormal_sigma = config["lognormal_sigma"] + self.sequence_lengths = self.generate_lognormal_samples(self.size, mean_seq_len,lognormal_sigma, min_seq_len, max_seq_len) + else: + raise ValueError(f"Unsupported sequence length distribution type {config['type']}") + + def generate_lognormal_samples(self, size, mean, sigma, min_seq_len, max_seq_len): + mu = np.log(mean) - sigma**2 / 2 + samples = np.random.lognormal(mu, sigma, size) + samples = np.clip(samples, min_seq_len, max_seq_len) + return samples.astype(int) + + def __len__(self) -> int: + return self.size + + def __getitem__(self, idx: int) -> List[np.ndarray]: + length = self.sequence_lengths[idx % self.size] + # the length of sample is 'length', but only length-1 elements are generated here, + # because an eod token will be appended at the end later in SFTDataset + sample = np.arange(1, length, dtype=np.int64) + return sample +class MockSFTDataset(SFTDataset): + """The mock dataset used during SFT""" + + def __init__( + self, + dataset: LowLevelDataset, + dataset_path: Optional[str], + indices: np.ndarray, + num_samples: Optional[int], + index_split: Split, + config: GPTDatasetConfig, + ) -> None: + super().__init__(dataset, dataset_path, indices, num_samples, index_split, config) + + @staticmethod + def build_low_level_dataset(dataset_path: str, config: GPTDatasetConfig) -> LowLevelDataset: + mock_config = json.loads(config.sft_mock_dataset_config_json) + return MockSFTLowLevelDataset(mock_config) + + def __len__(self) -> int: + return self.num_samples + + def __getitem__(self, idx: int) -> Dict[str, Any]: + sequence_packing = self.config.sequence_packing + tokenizer = self.config.tokenizer + max_seq_len = self.config.sequence_length + + tokens = self.dataset[int(self.indices[idx % len(self.indices)])] + target = np.array(tokens, dtype=np.int64) + + force_eod_length = int(tokenizer.force_eod) + + if len(tokens) > max_seq_len - force_eod_length: + # cut the right side + tokens = tokens[: max_seq_len - force_eod_length] + target = target[: max_seq_len - force_eod_length] + # tokens = tokens[(-max_seq_len + force_eod_length):] + # target = target[(-max_seq_len + force_eod_length):] + + # padding + num_tokens = len(tokens) + force_eod_length + if sequence_packing: + padding_len = self.get_padding_size(num_tokens) - num_tokens + else: + padding_len = max_seq_len - num_tokens + assert padding_len >= 0 + filler = [tokenizer.eod] * force_eod_length + [tokenizer.pad] * (padding_len + 1) + + tokens = np.array(tokens.tolist() + filler, dtype=np.int64) + target = np.array(target.tolist() + filler, dtype=np.int64) + + tokens = torch.tensor(tokens) + target = torch.tensor(target) + + tokens = tokens[:-1].contiguous() + target = target[1:].contiguous() + seq_len = tokens.numel() + + loss_mask, position_ids, attention_mask = self._get_ltor_masks_and_position_ids( + seq_len, target, tokenizer.pad + ) + + if self.config.create_attention_mask: + ret = { + 'tokens': tokens, + 'labels': target, + 'attention_mask': attention_mask, + 'loss_mask': loss_mask, + 'position_ids': position_ids, + } + else: + ret = { + 'tokens': tokens, + 'labels': target, + 'loss_mask': loss_mask, + 'position_ids': position_ids, + } + + if sequence_packing: + # sequence packing need both original sequence length and padded length + ret['original_seq_len'] = torch.tensor(num_tokens, dtype=torch.int32, device=tokens.device) + ret['padded_seq_len'] = torch.tensor(seq_len, dtype=torch.int32, device=tokens.device) + + return ret diff --git a/megatron/training/training.py b/megatron/training/training.py index 0c33206ba8b..4b811528fde 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -139,7 +139,6 @@ def set_startup_timestamps(program_start=None, main_entry=None): from megatron.training.initialize import set_jit_fusion_options from megatron.training.utils import get_batch_on_this_cp_rank, get_batch_on_this_tp_rank from megatron.training.datasets.data_samplers import build_pretraining_data_loader -from megatron.core.datasets.data_schedule import HybridCPDataLoaderWrapper from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler from megatron.core.transformer.moe import upcycling_utils from megatron.core.transformer.moe.moe_utils import track_moe_metrics, clear_aux_losses_tracker @@ -169,6 +168,7 @@ def set_startup_timestamps(program_start=None, main_entry=None): get_num_microbatches, update_num_microbatches ) +from megatron.core.datasets.data_schedule import wrap_dataloader from .async_utils import maybe_finalize_async_save from .utils import ( @@ -225,7 +225,7 @@ def print_datetime(string, override_timestamp=None): time_str = datetime.fromtimestamp(override_timestamp).strftime('%Y-%m-%d %H:%M:%S.%f') print_rank_0(f'[{string}] datetime: {time_str} ') -def num_floating_point_operations(args, batch_size): +def num_floating_point_operations(args, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch): def calculate_layer_counts(): """Calculate the number of attention, Mamba, and MLP layers.""" if args.hybrid_override_pattern: @@ -251,44 +251,42 @@ def calculate_layer_counts(): num_moe_layers = 0 return num_attn_layers, num_mamba_layers, num_mlp_layers, num_moe_layers - def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=False): + def mlp_layer_flops(seqlen_sum_this_global_batch, hidden_size, expansion=4.0, swiglu=False): """Calculate FLOPs for an MLP layer.""" scale_factor = 3.0 / 2.0 if swiglu else 1.0 - return 4 * expansion * scale_factor * batch_size * seq_len * hidden_size**2 + return 4 * expansion * scale_factor * seqlen_sum_this_global_batch * hidden_size**2 - def moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size, + def moe_layer_flops(seqlen_sum_this_global_batch, hidden_size, moe_ffn_hidden_size, shared_expert_ffn_hidden_size, num_experts_routed_to, moe_latent_size=None, swiglu=False): """Calculate FLOPs for an MoE layer.""" scale_factor = 3.0 / 2.0 if swiglu else 1.0 if moe_latent_size is None: - routed_flops = (4 * batch_size * seq_len * hidden_size * + routed_flops = (4 * seqlen_sum_this_global_batch * hidden_size * moe_ffn_hidden_size * num_experts_routed_to * scale_factor) else: # Routed experts run on moe_latent_size. - routed_flops = (4 * batch_size * seq_len * moe_latent_size * + routed_flops = (4 * seqlen_sum_this_global_batch * moe_latent_size * moe_ffn_hidden_size * num_experts_routed_to * scale_factor) # Up proj and down proj. - routed_flops += (4 * batch_size * seq_len * hidden_size * moe_latent_size) - shared_flops = 4 * batch_size * seq_len * hidden_size * shared_expert_ffn_hidden_size * scale_factor + routed_flops += (4 * seqlen_sum_this_global_batch * hidden_size * moe_latent_size) + shared_flops = 4 * seqlen_sum_this_global_batch * hidden_size * shared_expert_ffn_hidden_size * scale_factor return routed_flops + shared_flops def attn_layer_flops( - batch_size, seq_len, hidden_size, num_heads, gqa=True, gqa_groups=8, kv_channels=None + seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch, hidden_size, num_heads, gqa=True, gqa_groups=8, kv_channels=None ): """Calculate FLOPs for an attention layer.""" p = (kv_channels * num_heads / hidden_size) if kv_channels else 1 g = gqa_groups if gqa else num_heads return ( 4 - * batch_size - * seq_len * hidden_size * p - * (hidden_size + (hidden_size * (g / num_heads)) + (seq_len / 2)) + * (hidden_size * seqlen_sum_this_global_batch + (hidden_size * (g / num_heads)) * seqlen_sum_this_global_batch + (seqlen_squared_sum_this_global_batch / 2)) ) - def mamba_layer_flops(batch_size, seq_len, hidden_size, state_dim=16, + def mamba_layer_flops(seqlen_sum_this_global_batch, hidden_size, state_dim=16, head_dim=64, num_groups=1, num_heads=128): """Calculate FLOPs for a Mamba layer.""" # Note (rwaleffe): flops estimate for scan should be updated based on new SSD kernels, @@ -301,16 +299,15 @@ def mamba_layer_flops(batch_size, seq_len, hidden_size, state_dim=16, return ( ( 2 - * batch_size - * seq_len + * seqlen_sum_this_global_batch * hidden_size * (2 * d_in + 2 * num_groups * state_dim + nheads) ) # in_proj - + (7 * batch_size * seq_len * d_in * state_dim) # scan - + (2 * batch_size * seq_len * d_in * hidden_size) # out_proj + + (7 * seqlen_sum_this_global_batch * d_in * state_dim) # scan + + (2 * seqlen_sum_this_global_batch * d_in * hidden_size) # out_proj ) - def hybrid_flops(batch_size, seq_len, hidden_size, + def hybrid_flops(seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch, hidden_size, num_attn_layers, num_mamba_layers, num_mlp_layers, num_moe_layers, mamba_state_dim=128, mamba_head_dim=64, mamba_num_groups=8, mamba_num_heads=128, @@ -322,17 +319,17 @@ def hybrid_flops(batch_size, seq_len, hidden_size, vocab_size=256000, mtp_num_layers=0): """Calculate total FLOPs for the hybrid model.""" flops_fwd = ( - num_attn_layers * attn_layer_flops(batch_size, seq_len, hidden_size, + num_attn_layers * attn_layer_flops(seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch, hidden_size, num_attn_heads, gqa, gqa_groups, kv_channels) + - num_mlp_layers * mlp_layer_flops(batch_size, seq_len, hidden_size, + num_mlp_layers * mlp_layer_flops(seqlen_sum_this_global_batch, hidden_size, mlp_expansion, swiglu) + - num_mamba_layers * mamba_layer_flops(batch_size, seq_len, hidden_size, + num_mamba_layers * mamba_layer_flops(seqlen_sum_this_global_batch, hidden_size, mamba_state_dim, mamba_head_dim, mamba_num_groups, mamba_num_heads) + - num_moe_layers * moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size, + num_moe_layers * moe_layer_flops(seqlen_sum_this_global_batch, hidden_size, moe_ffn_hidden_size, shared_expert_ffn_hidden_size, num_experts_routed_to, moe_latent_size, swiglu) + - (2 * batch_size * seq_len * hidden_size * vocab_size * (1 + mtp_num_layers)) # logits computation + (2 * seqlen_sum_this_global_batch * hidden_size * vocab_size * (1 + mtp_num_layers)) # logits computation ) return flops_fwd * 3 @@ -403,13 +400,18 @@ def transformer_flops(): assert not args.group_query_attention ''' Basic arithmetic - let B is batch size, s is seq_len, h is embedding dim, - for one self_attnetion block (prenorm is not included) - qkv projection: 6Bsh^2 - attn: 2Bs^2h - attn over value: 2Bs^2h - oproj: 2Bsh^2 - + + Let h be the embedding dim. + We use two statistics to unify BSHD and THD cases: + seqlen_sum_this_global_batch: total number of tokens in this global batch + seqlen_squared_sum_this_global_batch: sum of squared sequence lengths in this global batch + + For one self-attention block (prenorm not included): + qkv projection: 6 * seqlen_sum_this_global_batch * h^2 + attn: 2 * seqlen_squared_sum_this_global_batch * h + attn over value: 2 * seqlen_squared_sum_this_global_batch * h + oproj: 2 * seqlen_sum_this_global_batch * h^2 + references https://arxiv.org/abs/2305.10403 https://arxiv.org/abs/2205.05198 @@ -430,7 +432,7 @@ def transformer_flops(): standard_self_attn_term = ( forward_backward_expansion_factor * fma_expansion_factor - * ( + * ( seqlen_sum_this_global_batch * ( ## q lora + rope + q norm q_term ## kv lora + rope + kv norm @@ -442,12 +444,12 @@ def transformer_flops(): ) + args.hidden_size * args.qk_pos_emb_head_dim ## o proj - + (args.num_attention_heads * args.v_head_dim) * args.hidden_size + + (args.num_attention_heads * args.v_head_dim) * args.hidden_size) ## core attn - + args.seq_length + + seqlen_squared_sum_this_global_batch * (args.num_attention_heads * (args.qk_head_dim + args.qk_pos_emb_head_dim)) - / 2 # causal mask (only half of the mask is non-zero) - + args.seq_length * args.num_attention_heads * args.v_head_dim / 2 + / 2 # causal mask (only half of the mask is non-zero) + + seqlen_squared_sum_this_global_batch * args.num_attention_heads * args.v_head_dim / 2 ) ) @@ -460,7 +462,7 @@ def transformer_flops(): standard_self_attn_term = ( forward_backward_expansion_factor * fma_expansion_factor - * ( + * ( seqlen_sum_this_global_batch *( ## qkv proj args.hidden_size * ( @@ -468,14 +470,14 @@ def transformer_flops(): + key_projection_size + value_projection_size + gate_projection_size - ) + )) ## core attention + query_projection_size - * args.seq_length + * seqlen_squared_sum_this_global_batch / 2 # causal mask (only half of the mask is non-zero) * 2 # QK^T and (QK^T)V ## out proj - + query_projection_size + + seqlen_sum_this_global_batch * query_projection_size * args.hidden_size ) ) @@ -553,8 +555,7 @@ def transformer_flops(): ) total_floating_point_operations = ( - batch_size - * args.seq_length + seqlen_sum_this_global_batch * ( # MLP forward_backward_expansion_factor @@ -584,8 +585,6 @@ def transformer_flops(): + (shared_expert_ffn_hidden_size * ffn_expansion_factor) * num_moe_layers ) - # Self Attention - + self_attn_term # MTP norms and proj + forward_backward_expansion_factor * fma_expansion_factor @@ -603,6 +602,10 @@ def transformer_flops(): * args.padded_vocab_size * (mtp_num_layers + 1) # MTP + final logit ) + + + # Self Attention + self_attn_term + ) return total_floating_point_operations @@ -616,8 +619,8 @@ def transformer_flops(): mtp_num_layers = 0 # Compute hybrid model FLOPs. return hybrid_flops( - batch_size=batch_size, - seq_len=args.seq_length, + seqlen_sum_this_global_batch=seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch=seqlen_squared_sum_this_global_batch, hidden_size=args.hidden_size, num_attn_layers=num_attn_layers, num_mamba_layers=num_mamba_layers, @@ -1728,6 +1731,19 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch if isinstance(optim_instance, DistributedOptimizer): optim_instance.release_offloaded_gpu_states() + if config.sequence_packing: + ( + data_iterator, + num_microbatches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ) = wrap_dataloader(data_iterator, config, num_microbatches) + else: + # data_iterator unchanged + num_microbatches = get_num_microbatches() + seqlen_sum_this_global_batch = args.seq_length * args.micro_batch_size * args.data_parallel_size * num_microbatches + seqlen_squared_sum_this_global_batch = args.seq_length ** 2 * args.micro_batch_size * args.data_parallel_size * num_microbatches + # Forward pass. if save_dgrads_in_this_iteration: enable_dgrad_logging(model, args.save) @@ -1735,7 +1751,7 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch forward_step_func=forward_step_func, data_iterator=data_iterator, model=model, - num_microbatches=get_num_microbatches(), + num_microbatches=num_microbatches, seq_length=args.seq_length, micro_batch_size=args.micro_batch_size, decoder_seq_length=args.decoder_seq_length, @@ -1766,9 +1782,10 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch # iteration is 0-indexed, move to 1-indexed for checkpoint name and logging. save_grads(args.save, state_dict, iteration + 1, "wgrads") + should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit() if should_exit: - return {}, True, should_checkpoint, should_exit, exit_code, None, None, 0 + return {}, True, should_checkpoint, should_exit, exit_code, None, None, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch, 0 # Empty unused memory. if args.empty_unused_memory_level >= 1: @@ -1848,8 +1865,10 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch grad_norm, num_zeros_in_grad, log_max_attention_logit, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, ) - return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad, log_max_attention_logit + return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad, log_max_attention_logit, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch def training_log( @@ -1864,6 +1883,8 @@ def training_log( params_norm, num_zeros_in_grad, max_attention_logit, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, pg_collection=None, is_first_iteration=False, ): @@ -2096,7 +2117,7 @@ def training_log( elapsed_time = timers('interval-time').elapsed(barrier=True, reset=should_reset) elapsed_time_per_iteration = elapsed_time / total_iterations - throughput = num_floating_point_operations(args, batch_size) / ( + throughput = num_floating_point_operations(args,seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch) / ( elapsed_time_per_iteration * 10**12 * args.world_size ) @@ -2864,6 +2885,8 @@ def trace_handler(p): # Completely skip iteration if needed. if iteration in args.iterations_to_skip: + # TODO(tailaim): this need to be modified + assert not args.sequence_packing, "Sequence packing is not supported in skip iteration mode" # Dummy train_step to fast forward train_data_iterator. dummy_train_step(train_data_iterator) if iteration == start_iteration: @@ -2906,6 +2929,8 @@ def trace_handler(p): grad_norm, num_zeros_in_grad, max_attention_logit, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, ) = train_step( forward_step_func, train_data_iterator, model, optimizer, opt_param_scheduler, config, forward_backward_func, iteration=iteration ) @@ -2993,7 +3018,7 @@ def trace_handler(p): else: assert num_skipped_samples_in_batch == 0 args.skipped_train_samples += num_skipped_samples_in_batch - num_floating_point_operations_in_batch = num_floating_point_operations(args, batch_size) + num_floating_point_operations_in_batch = num_floating_point_operations(args, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch) num_floating_point_operations_so_far += num_floating_point_operations_in_batch num_floating_point_operations_since_last_log_event += num_floating_point_operations_in_batch @@ -3019,6 +3044,8 @@ def trace_handler(p): params_norm, num_zeros_in_grad, max_attention_logit, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, pg_collection=model_pg_collection, is_first_iteration=is_first_iteration, ) diff --git a/pretrain_gpt.py b/pretrain_gpt.py index e6ce7ac2a48..b2d7ce192c1 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -25,6 +25,7 @@ from megatron.core import parallel_state from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset +from megatron.core.datasets.data_schedule import get_batch_on_this_rank_for_sequence_packing from megatron.core.enums import ModelType from megatron.core.models.gpt import GPTModel from megatron.core.rerun_state_machine import get_rerun_state_machine @@ -49,6 +50,7 @@ get_blend_and_blend_per_split, is_first_or_last_pipeline_stage, ) +from megatron.training.datasets.sft_dataset import SFTDataset, MockSFTDataset from model_provider import model_provider try: @@ -66,6 +68,14 @@ def get_batch(data_iterator, vp_stage: Optional[int] = None): """Generate a batch.""" args = get_args() config = core_transformer_config_from_args(args) + + if args.sequence_packing: + return get_batch_on_this_rank_for_sequence_packing( + data_iterator, + mtp_on_this_rank=mtp_on_this_rank(config, ignore_virtual=False, vp_stage=vp_stage), + vp_stage=vp_stage, + ) + # TODO: this is pretty hacky, find a better way if not is_first_or_last_pipeline_stage(vp_stage) and ( (not mtp_on_this_rank(config, ignore_virtual=False, vp_stage=vp_stage))): @@ -250,6 +260,8 @@ def core_gpt_dataset_config_from_args(args): "data_parallel_size": args.data_parallel_size, "sequence_parallel_size": args.tensor_model_parallel_size*args.sequence_parallel, "hybrid_context_parallel": args.hybrid_context_parallel, + "sft_mock_dataset_config_json":args.sft_mock_dataset_config_json, + "sequence_packing": args.sequence_packing, } # add FIM args to the config @@ -287,7 +299,10 @@ def train_valid_test_datasets_provider(train_val_test_num_samples, vp_stage=None config = core_gpt_dataset_config_from_args(args) if args.sft: - dataset_type = SFTDataset + if args.mock_data: + dataset_type = MockSFTDataset + else: + dataset_type = SFTDataset else: if args.mock_data: dataset_type = MockGPTDataset diff --git a/tests/unit_tests/test_sequence_packing.py b/tests/unit_tests/test_sequence_packing.py new file mode 100644 index 00000000000..bac7ea79db1 --- /dev/null +++ b/tests/unit_tests/test_sequence_packing.py @@ -0,0 +1,322 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +from types import SimpleNamespace + +import pytest +import torch + +from megatron.core import parallel_state +from megatron.core.datasets.data_schedule import get_batch_on_this_rank_for_sequence_packing +from megatron.training.global_vars import unset_global_variables +from tests.unit_tests.test_utilities import Utils + + +class MockVariableLengthSequencePackingDataIterator: + """ + Mock data iterator for testing get_batch_on_this_rank_for_sequence_packing. + + Generates variable-length (THD format) packed sequences with deterministic + data for verification across parallel ranks. + """ + + def __init__( + self, + total_seq_length: int, + sequence_lengths: list, + local_cp_size: int = None, + device: str = "cuda", + seed: int = 42, + ): + """ + Args: + total_seq_length: Total length of packed sequences + sequence_lengths: List of individual sequence lengths (variable-length). + If None, generates random variable lengths. + local_cp_size: Local CP size for hybrid context parallel + device: Device to create tensors on + seed: Random seed for reproducibility + """ + self.total_seq_length = total_seq_length + self.sequence_lengths = sequence_lengths + self.local_cp_size = local_cp_size + self.device = device + self.seed = seed + assert ( + sum(self.sequence_lengths) == total_seq_length + ), f"Sequence lengths sum {sum(self.sequence_lengths)} != total {total_seq_length}" + + def __iter__(self): + """Interface for the data iterator.""" + return self + + def __next__(self): + """Generate a mock batch with variable-length THD format.""" + dev = self.device + torch.manual_seed(self.seed) + torch.cuda.manual_seed(self.seed) + + tokens = torch.randint(0, 16384, (self.total_seq_length,), dtype=torch.int64, device=dev) + + # Create position_ids that reset for each sequence (THD format) + position_ids = [] + for seq_len in self.sequence_lengths: + position_ids.extend(range(seq_len)) + position_ids = torch.tensor(position_ids, dtype=torch.int64, device=dev) + + # Labels are tokens shifted by 1 for easy verification + labels = tokens + 1 + + # Loss mask: 1.0 for all positions except padding (none here) + loss_mask = torch.ones(self.total_seq_length, dtype=torch.float32, device=dev) + + # Create cu_seqlens for variable-length packed sequences + cu_seqlens = [0] + for seq_len in self.sequence_lengths: + cu_seqlens.append(cu_seqlens[-1] + seq_len) + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=dev) + cu_seqlens_padded = cu_seqlens.clone() + + max_seqlen = torch.tensor([max(self.sequence_lengths)], dtype=torch.int32, device=dev) + + batch = { + "tokens": tokens, + "position_ids": position_ids, + "labels": labels, + "loss_mask": loss_mask, + "cu_seqlens": cu_seqlens, + "cu_seqlens_padded": cu_seqlens_padded, + "max_seqlen": max_seqlen, + } + + if not ( + parallel_state.is_pipeline_first_stage(ignore_virtual=True) + or parallel_state.is_pipeline_last_stage(ignore_virtual=True) + ): + batch["tokens"] = None + batch["position_ids"] = None + batch["labels"] = None + batch["loss_mask"] = None + + if self.local_cp_size is not None: + batch["local_cp_size"] = torch.tensor( + [self.local_cp_size], dtype=torch.int32, device=dev + ) + + return batch + + +def _gather_tensor_from_tp_group(tensor): + """Gather tensors from all TP ranks for comparison.""" + assert tensor is not None, "Tensor should not be None" + tp_size = parallel_state.get_tensor_model_parallel_world_size() + gathered = [torch.zeros_like(tensor) for _ in range(tp_size)] + torch.distributed.all_gather( + gathered, tensor, group=parallel_state.get_tensor_model_parallel_group() + ) + return gathered + + +def _gather_tensor_from_all_ranks(tensor): + """Gather tensors from all PP ranks for comparison.""" + assert tensor is not None, "Tensor should not be None" + if type(tensor) is int: + tensor = torch.tensor(tensor, dtype=torch.int32, device=torch.cuda.current_device()) + gathered = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(gathered, tensor) + return gathered + + +@pytest.mark.parametrize( + ("tp", "pp", "cp", "hybrid_cp"), + [ + (1, 1, 1, False), # Basic case: no parallelism + (2, 1, 1, False), # Tensor parallel only + (1, 2, 1, False), # Pipeline parallel only + (2, 2, 1, False), # TP + PP + (1, 1, 2, False), # CP only + (2, 1, 2, False), # TP + CP + (1, 2, 2, False), # PP + CP + (1, 4, 1, False), # Has middle pp stage + (1, 1, 1, True), # Hybrid CP enabled (CP=1 with hybrid groups) + (2, 1, 1, True), # TP + Hybrid CP + ], +) +def test_get_batch_on_this_rank_for_sequence_packing(tp, pp, cp, hybrid_cp): + """ + Test get_batch_on_this_rank_for_sequence_packing function with variable-length THD format. + + This test verifies: + 1. TP ranks: All ranks within a TP group receive identical data after broadcast + 2. PP ranks: Middle PP ranks have the same packed_seq_params as first/last stages + 3. CP ranks: Data is correctly partitioned with proper shape and values + 4. Variable-length (THD) format: Different sequence lengths are handled correctly + """ + args = SimpleNamespace() + args.tensor_model_parallel_size = tp + args.pipeline_model_parallel_size = pp + args.context_parallel_size = cp + args.hybrid_context_parallel = hybrid_cp + args.virtual_pipeline_model_parallel_size = None + args.data_parallel_size = 8 // (tp * pp * cp) + args.seq_length = 8192 + + # Skip invalid configurations + if args.data_parallel_size < 1: + raise ValueError(f"Invalid config: tp={tp}, pp={pp}, cp={cp} exceeds world size 8") + + # Initialize model parallel + Utils.initialize_model_parallel( + tp, + pp, + None, + context_parallel_size=cp, + hybrid_context_parallel=hybrid_cp, + min_hybrid_context_parallel_size=1, + ) + + try: + # Create mock data iterator with variable-length sequences + # Only TP rank 0 needs the iterator; other TP ranks pass None + tp_rank = parallel_state.get_tensor_model_parallel_rank() + local_cp_size = 8 // (tp * pp) if hybrid_cp else None + + if tp_rank == 0: + # Use deterministic seed based on DP rank so same data within TP/PP/CP group + dp_rank = parallel_state.get_data_parallel_rank() + sequence_lengths = [1024, 2048, 512, 1536, 3072] + assert ( + sum(sequence_lengths) == args.seq_length + ), f"Sequence lengths sum {sum(sequence_lengths)} != total {args.seq_length}" + data_iterator = iter( + MockVariableLengthSequencePackingDataIterator( + total_seq_length=args.seq_length, + sequence_lengths=sequence_lengths, # Variable lengths, sum=8192 + local_cp_size=local_cp_size, + seed=42 + dp_rank, # Same seed within PP/CP group + ) + ) + else: + # Non-TP-rank-0 ranks don't need the iterator + data_iterator = None + + # Call the function under test + result = get_batch_on_this_rank_for_sequence_packing( + data_iterator=data_iterator, + mtp_on_this_rank=False, + vp_stage=None, + hybrid_context_parallel=hybrid_cp, + ) + + # Unpack the result + tokens, labels, loss_mask, attention_mask, position_ids, packed_seq_params = result + + # Get parallel state info + tp_rank = parallel_state.get_tensor_model_parallel_rank() + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + cp_rank = parallel_state.get_context_parallel_rank() + is_first_stage = parallel_state.is_pipeline_first_stage(ignore_virtual=True) + is_last_stage = parallel_state.is_pipeline_last_stage(ignore_virtual=True) + is_first_or_last = is_first_stage or is_last_stage + + # ===================================================================== + # TEST 1: Verify data based on pipeline stage + # ===================================================================== + if is_first_stage: + assert tokens is not None, "First stage should have tokens" + assert position_ids is not None, "First stage should have position_ids" + assert tokens.dim() == 2, "Tokens should be 2D (batch, seq)" + assert position_ids.dim() == 2, "Position IDs should be 2D (batch, seq)" + assert tokens.size(0) == 1, "batch should be 1 in THD format" + assert position_ids.size(0) == 1, "batch should be 1 in THD format" + else: + assert tokens is None, "Non-first stage should not have tokens" + assert position_ids is None, "Non-first stage should not have position_ids" + + if is_last_stage: + assert labels is not None, "Last stage should have labels" + assert loss_mask is not None, "Last stage should have loss_mask" + assert labels.dim() == 2, "Labels should be 2D (batch, seq)" + assert loss_mask.dim() == 2, "Loss mask should be 2D (batch, seq)" + assert labels.size(0) == 1, "batch should be 1 in THD format" + assert loss_mask.size(0) == 1, "batch should be 1 in THD format" + else: + assert labels is None, "Non-last stage should not have labels" + assert loss_mask is None, "Non-last stage should not have loss_mask" + + # ===================================================================== + # TEST 2: Verify all ranks have consistent packed_seq_params + # ===================================================================== + assert packed_seq_params is not None + assert packed_seq_params.qkv_format == "thd" + if hybrid_cp: + assert packed_seq_params.local_cp_size is not None + assert packed_seq_params.cp_group is not None + + test_keys = [ + "cu_seqlens_q", + "cu_seqlens_q_padded", + "max_seqlen_q", + "cu_seqlens_kv", + "cu_seqlens_kv_padded", + "max_seqlen_kv", + ] + if hybrid_cp: + test_keys.append("local_cp_size") + for key in test_keys: + tensor = getattr(packed_seq_params, key) + assert tensor is not None + gathered_tensor = _gather_tensor_from_all_ranks(tensor) + for i in range(1, len(gathered_tensor)): + assert torch.equal( + gathered_tensor[0], gathered_tensor[i] + ), f"Rank 0 and rank {i} have different {key}" + + # ===================================================================== + # TEST 3: Verify TP ranks receive identical data after broadcast + # ===================================================================== + if tp > 1: + test_tensors = [] + if is_first_stage: + test_tensors.extend([tokens, position_ids]) + if is_last_stage: + test_tensors.extend([labels, loss_mask]) + + for tensor in test_tensors: + gathered_tensors = _gather_tensor_from_tp_group(tensor) + for i in range(1, tp): + assert torch.equal( + gathered_tensors[0], gathered_tensors[i] + ), f"TP rank 0 and rank {i} have different data" + + # ===================================================================== + # TEST 4: Verify CP partitioning + # ===================================================================== + if cp > 1 or hybrid_cp: + if hybrid_cp: + assert packed_seq_params.local_cp_size is not None + cp_size = packed_seq_params.local_cp_size + assert packed_seq_params.cp_group == ( + parallel_state.get_hybrid_data_context_parallel_groups(group_size=cp_size) + ) + else: + cp_size = cp + + # With CP, the sequence should be partitioned + expected_seq_len = args.seq_length // cp_size + + if is_first_stage: + actual_seq_len = tokens.shape[1] + assert ( + actual_seq_len == expected_seq_len + ), f"CP partitioned tokens have wrong shape: {actual_seq_len} != {expected_seq_len}" + + # Verify labels only if all CP ranks are at last stage + if is_last_stage: + actual_seq_len = labels.shape[1] + assert ( + actual_seq_len == expected_seq_len + ), f"CP partitioned labels have wrong shape: {actual_seq_len} != {expected_seq_len}" + + finally: + Utils.destroy_model_parallel() + unset_global_variables() From 496431e01bcb9454204ae9019bdca36956c87d6a Mon Sep 17 00:00:00 2001 From: tailaim Date: Mon, 9 Feb 2026 21:36:01 -0800 Subject: [PATCH 2/9] refactor according to comments and the new sftdataset Signed-off-by: tailaim --- megatron/core/datasets/data_schedule.py | 581 +++--------------- .../core/datasets/data_scheduler_utils.py | 492 +++++++++++++++ megatron/core/datasets/gpt_dataset.py | 6 +- .../core/extensions/transformer_engine.py | 21 + megatron/core/model_parallel_config.py | 9 +- .../text/libraries/sft_tokenizer.py | 5 - .../core/transformer/transformer_config.py | 18 +- megatron/training/arguments.py | 10 +- megatron/training/datasets/sft_dataset.py | 161 ++--- megatron/training/training.py | 16 +- pretrain_gpt.py | 5 +- tests/unit_tests/test_sequence_packing.py | 251 ++++++-- 12 files changed, 926 insertions(+), 649 deletions(-) create mode 100644 megatron/core/datasets/data_scheduler_utils.py diff --git a/megatron/core/datasets/data_schedule.py b/megatron/core/datasets/data_schedule.py index 7a1dd9e587d..76e53c6fe37 100644 --- a/megatron/core/datasets/data_schedule.py +++ b/megatron/core/datasets/data_schedule.py @@ -1,27 +1,24 @@ # Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved. import enum -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Type -import numpy as np import torch from megatron.core import parallel_state +from megatron.core.datasets.data_scheduler_utils import ( + broadcast_scalars, + broadcast_tensor, + broadcast_to_pp_group, + build_packed_microbatches, + create_data_iterator, + get_batch_and_global_seqlens, + reroute_samples_to_dcp_ranks, +) +from megatron.core.extensions.transformer_engine import get_thd_partitioned_indices from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.pipeline_parallel.hybrid_cp_schedule import BalancedCPScheduler from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.rerun_state_machine import RerunDataIterator -from megatron.core.utils import is_te_min_version - -try: - # Register the TE CUDA kernels - import transformer_engine # pylint: disable=unused-import - - # Alias the PyTorch wrapper so we can call tex.* APIs - import transformer_engine_torch as tex -except ImportError: - # TE isn’t installed or the torch wrapper is missing - tex = None class HybridCPDataLoaderWrapper: @@ -72,7 +69,6 @@ def get_global_seqlens(self, subsample_seqlens: torch.Tensor) -> List[int]: Gathers the sequence lengths of all subsamples from all DP ranks. Each DP rank loads the same number of microbatches but each microbatch may have a different number of subsamples. - We find the number of subsamples each rank holds and then gather the sequence lengths of all subsamples from all ranks. """ @@ -316,18 +312,6 @@ def __next__(self) -> Any: return samples_this_rank_with_id, sample_id_groups -class PackingScheduler(enum.Enum): - """Enum for supported sequence packing algorithms.""" - - DEFAULT_SEQUENCE_PACKING = "default_sequence_packing" - - -def _broadcast_tensor(item, src_rank, group) -> None: - """Broadcast a tensor from src_rank to all ranks in the group.""" - if item is not None: - torch.distributed.broadcast(item, src_rank, group=group) - - class BaseScheduler: """Base class for sequence packing schedulers.""" @@ -338,6 +322,14 @@ def __init__( dp_size: int, microbatch_group_size_per_vp_stage: Optional[int], ): + """ + Args: + max_seqlen_per_dp_cp_rank: The maximum sequence length per DPxCP rank. + cp_size: The context parallel size. + dp_size: The data parallel size. + microbatch_group_size_per_vp_stage: The microbatch group size per virtual + pipeline stage, only used when enabling VPP, otherwise None. + """ self.max_seqlen_per_dp_cp_rank = max_seqlen_per_dp_cp_rank self.cp_size = cp_size self.dp_size = dp_size @@ -362,446 +354,29 @@ def run( dev, config, ): - """Run the scheduler and return the new data_iterator.""" - raise NotImplementedError - - @staticmethod - def _get_global_seqlens(subsample_seqlens: torch.Tensor, dp_group) -> List[int]: - """ - Gathers the sequence lengths of all subsamples from all DP ranks. - - Each DP rank has the same number of subsamples (num_microbatches), - so we can directly all_gather without padding. - """ - dp_size = dp_group.size() - num_local_subsamples = subsample_seqlens.shape[0] - - # Gather the subsample_seqlens from all ranks - seqlens_gathered = [torch.empty_like(subsample_seqlens) for _ in range(dp_size)] - torch.distributed.all_gather(seqlens_gathered, subsample_seqlens, group=dp_group) - - seqlens_gathered = torch.cat(seqlens_gathered, dim=0) - seqlens_gathered = seqlens_gathered.cpu().tolist() - - # Calculate the offsets to assign unique global ID to each subsample. - # Since each rank has the same number of subsamples, offsets are evenly spaced. - offsets = torch.arange( - 0, dp_size * num_local_subsamples, num_local_subsamples, dtype=torch.int32 - ) - - return seqlens_gathered, offsets - - @staticmethod - def _get_global_id_seqlens(num_local_subsamples, offsets, seqlens_gathered, dp_group): """ - Calculates the global ID for each subsample. - - We assign a unique global ID to each subsample. - - Returns: - global_id_seqlens: list of (global_id, seqlen) tuples for scheduling. - global_ids_this_rank: list of global IDs locally present on this rank. - """ - dp_rank = dp_group.rank() - global_ids = torch.arange(len(seqlens_gathered), dtype=torch.int32).cuda() - # Create a list of (global_id, seqlen) tuples for scheduling - global_id_seqlens = [(i, seqlens_gathered[i]) for i in range(len(global_ids))] - # Get the global IDs locally present on this rank - global_ids_this_rank = global_ids[ - offsets[dp_rank] : offsets[dp_rank] + num_local_subsamples - ] - - return global_id_seqlens, global_ids_this_rank - - @staticmethod - def _broadcast_to_pp_group( - new_samples, - num_micro_batches, - seqlen_sum_this_global_batch, - seqlen_squared_sum_this_global_batch, - pp_group, - dev, - ): - """ - Broadcast num_micro_batches, seqlen_sum_this_global_batch, - seqlen_squared_sum_this_global_batch and metadata to middle PP stages. - """ - - pp_src_rank = torch.distributed.get_process_group_ranks(pp_group)[0] - - if pp_group.size() > 2: - if pp_group.rank() == 0: - tensor_list = [ - torch.tensor( - [ - num_micro_batches, - seqlen_sum_this_global_batch, - seqlen_squared_sum_this_global_batch, - ], - dtype=torch.float32, - ).cuda() - ] - for sample in new_samples: - tensor_list.append(sample["max_seqlen"].unsqueeze(0)) - for sample in new_samples: - tensor_list.append(sample["cu_seqlens"]) - tensor_list.append(sample["cu_seqlens_padded"]) - info_to_broadcast = torch.cat(tensor_list, dim=0).to( - device=dev, dtype=torch.float32 - ) - info_length_tensor = torch.tensor( - info_to_broadcast.shape[0], dtype=torch.int32 - ).cuda() - _broadcast_tensor(info_length_tensor, pp_src_rank, pp_group) - _broadcast_tensor(info_to_broadcast, pp_src_rank, pp_group) - else: - info_length_tensor = torch.tensor(0, dtype=torch.int32).cuda() - _broadcast_tensor(info_length_tensor, pp_src_rank, pp_group) - info_to_broadcast = torch.empty( - info_length_tensor.item(), dtype=torch.float32 - ).cuda() - _broadcast_tensor(info_to_broadcast, pp_src_rank, pp_group) - if pp_group.rank() != pp_group.size() - 1: - # middle PP stages receive the broadcasted info and unpack it - info_numpy = info_to_broadcast.cpu().numpy() - num_micro_batches = int(info_numpy[0]) - seqlen_sum_this_global_batch = info_numpy[1] - seqlen_squared_sum_this_global_batch = info_numpy[2] - max_seqlens = info_to_broadcast[3 : 3 + num_micro_batches] - cu_seqlens_list = [] - cu_seqlens_padded_list = [] - indices = np.where(info_numpy == 0)[0] - for i in range(num_micro_batches): - cu_seqlens_list.append( - info_to_broadcast[indices[i * 2] : indices[i * 2 + 1]] - ) - if i == num_micro_batches - 1: - cu_seqlens_padded_list.append(info_to_broadcast[indices[i * 2 + 1] :]) - else: - cu_seqlens_padded_list.append( - info_to_broadcast[indices[i * 2 + 1] : indices[i * 2 + 2]] - ) - - new_samples = [] - for i in range(num_micro_batches): - new_sample = {} - new_sample["max_seqlen"] = max_seqlens[i].to(torch.int32) - new_sample["cu_seqlens"] = cu_seqlens_list[i].to(torch.int32) - new_sample["cu_seqlens_padded"] = cu_seqlens_padded_list[i].to(torch.int32) - new_samples.append(new_sample) - - return ( - new_samples, - num_micro_batches, - seqlen_sum_this_global_batch, - seqlen_squared_sum_this_global_batch, - ) - - @staticmethod - def _broadcast_scalars(values: List, group, dev, dtype=torch.float32) -> List: - """ - Broadcast scalar values from rank 0 to all ranks in the group. + Run the scheduler and return the new data_iterator. - Args: - values: List of scalar values to broadcast (only used on rank 0). - group: The process group to broadcast within. - dev: The device to use for the tensor. - dtype: The data type for the tensor. - - Returns: - List of broadcasted values. - """ - if group.size() <= 1: - return values - - src_rank = torch.distributed.get_process_group_ranks(group)[0] - num_values = len(values) - - if group.rank() == 0: - info_to_broadcast = torch.tensor(values, dtype=dtype, device=dev) - else: - info_to_broadcast = torch.zeros(num_values, dtype=dtype, device=dev) - - _broadcast_tensor(info_to_broadcast, src_rank, group) - - if group.rank() != 0: - values = info_to_broadcast.cpu().tolist() - - return values - - @staticmethod - def _create_data_iterator(new_samples, pp_group, tp_group, config): - """Handle virtual pipeline parallelism.""" - if ( - config.virtual_pipeline_model_parallel_size is not None - and config.virtual_pipeline_model_parallel_size > 1 - ): - vpp_size = config.virtual_pipeline_model_parallel_size - if tp_group.rank() == 0: - if pp_group.rank() == 0 or pp_group.rank() == pp_group.size() - 1: - new_samples_for_other_ppstage = [] - for sample in new_samples: - new_sample_for_other_ppstage = {} - new_sample_for_other_ppstage["max_seqlen"] = sample["max_seqlen"] - new_sample_for_other_ppstage["cu_seqlens"] = sample["cu_seqlens"] - new_sample_for_other_ppstage["cu_seqlens_padded"] = sample[ - "cu_seqlens_padded" - ] - new_samples_for_other_ppstage.append(new_sample_for_other_ppstage) - if pp_group.rank() == 0: - new_data_iterator = [RerunDataIterator(iter(new_samples))] + [ - RerunDataIterator(iter(new_samples_for_other_ppstage)) - for _ in range(vpp_size - 1) - ] - else: - new_data_iterator = [ - RerunDataIterator(iter(new_samples_for_other_ppstage)) - for _ in range(vpp_size - 1) - ] + [RerunDataIterator(iter(new_samples))] - else: - new_data_iterator = [ - RerunDataIterator(iter(new_samples)) for _ in range(vpp_size) - ] - else: - new_data_iterator = [None for _ in range(vpp_size)] - else: - new_data_iterator = ( - RerunDataIterator(iter(new_samples)) if tp_group.rank() == 0 else None - ) - - return new_data_iterator - - @staticmethod - def _reroute_samples_to_dcp_ranks( - batch, - global_ids_this_rank, - global_id_seqlens, - sample_id_groups, - offsets, - dp_group, - tp_group, - dp_cp_group, - total_dcp_gpus, - ): - """ - Reroutes the sub-samples to the correct rank after scheduling. - - For each key in the batch dict, we perform an all-to-all communication - to transfer the data to the correct ranks. - """ - - def _gid_to_src_rank(gid: int) -> int: - dp_src_rank = torch.bucketize(gid, offsets[1:] - 1) - dcp_rank = ( - torch.distributed.get_process_group_ranks(dp_group)[dp_src_rank] // tp_group.size() - ) % dp_cp_group.size() - return dcp_rank - - gid2local_id = {int(gid): i for i, gid in enumerate(global_ids_this_rank)} - dcp_rank = dp_cp_group.rank() - dp_ranks = torch.distributed.get_process_group_ranks(dp_group) - dp_ranks = [(r // tp_group.size()) % dp_cp_group.size() for r in dp_ranks] - - data_keys = batch[0].keys() - - # Create the send plan - combined_sample_id_groups: List[List[int]] = [[] for _ in range(total_dcp_gpus)] - for d in range(total_dcp_gpus): - for sample_id_group in sample_id_groups: - combined_sample_id_groups[d].extend(sample_id_group[d]) - for dest_rank in range(total_dcp_gpus): - combined_sample_id_groups[dest_rank].sort() - - send_ids_sorted = [ - gid - for d in dp_ranks - for gid in combined_sample_id_groups[d] - if gid in global_ids_this_rank - ] - - send_num_split = [0] * total_dcp_gpus - send_lens_split = [0] * total_dcp_gpus - for dest_rank in range(total_dcp_gpus): - if dest_rank in dp_ranks: - send_seq_lens = [ - global_id_seqlens[gid][1] - for gid in combined_sample_id_groups[dest_rank] - if gid in global_ids_this_rank - ] - send_num_split[dest_rank] = len(send_seq_lens) - send_lens_split[dest_rank] = sum(send_seq_lens) - else: - send_lens_split[dest_rank] = 0 - - # Create the recv plan - recv_sample_id_groups = [[] for _ in range(total_dcp_gpus)] - for gid in combined_sample_id_groups[dcp_rank]: - src_rank = _gid_to_src_rank(gid) - recv_sample_id_groups[src_rank].append(gid) - - recv_lens_split = [0] * total_dcp_gpus - for src_rank in range(total_dcp_gpus): - recv_lens_split[src_rank] = sum( - [global_id_seqlens[gid][1] for gid in recv_sample_id_groups[src_rank]] - ) - - recv_ids_sorted = [gid for d in range(total_dcp_gpus) for gid in recv_sample_id_groups[d]] - recv_counts = [len(recv_sample_id_groups[d]) for d in range(total_dcp_gpus)] - - recv_samples = [{k: None for k in data_keys} for _ in range(sum(recv_counts))] - - def _pack_sample_by_key(key: str) -> torch.Tensor: - flattened_tensors = [] - for gid in send_ids_sorted: - t = batch[gid2local_id[gid]][key].to(torch.cuda.current_device(), non_blocking=True) - flattened_tensors.append(t.reshape(-1)) - return ( - torch.cat(flattened_tensors, dim=0) - if flattened_tensors - else torch.empty(1, device=torch.cuda.current_device(), dtype=batch[0][key].dtype) - ) - - def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor): - cursor = 0 - for i, gid in enumerate(recv_ids_sorted): - sample_len = ( - 1 - if key in ["original_seq_len", "padded_seq_len"] - else global_id_seqlens[gid][1] - ) - recv_samples[i][key] = recv_tensor[cursor : cursor + sample_len] - cursor += sample_len - - for key in data_keys: - output_split_sizes, input_split_sizes = ( - (recv_counts, send_num_split) - if key in ["original_seq_len", "padded_seq_len"] - else (recv_lens_split, send_lens_split) - ) - send_tensor = _pack_sample_by_key(key) - recv_tensor_size = sum(output_split_sizes) - recv_tensor = torch.empty( - recv_tensor_size, device=torch.cuda.current_device(), dtype=send_tensor.dtype - ) - torch.distributed.all_to_all_single( - output=recv_tensor, - input=send_tensor, - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=dp_cp_group, - ) - _unpack_sample_by_key(key, recv_tensor) - - recv_sample_with_id = { - recv_id: recv_samples[i] for i, recv_id in enumerate(recv_ids_sorted) - } - return recv_sample_with_id - - @staticmethod - def _pack_sequences( - samples: List, - padded_lengths: torch.Tensor, - original_lengths: torch.Tensor, - dev: torch.device, - ) -> Dict[str, torch.Tensor]: - """Pack multiple samples into a single packed sample.""" - - def _pack_tensors(tensors): - return torch.cat([t.reshape(-1) for t in tensors], dim=0) - - tokens = _pack_tensors([sample["tokens"] for sample in samples]) - labels = _pack_tensors([sample["labels"] for sample in samples]) - loss_mask = _pack_tensors([sample["loss_mask"] for sample in samples]) - position_ids = _pack_tensors([sample["position_ids"] for sample in samples]) - - new_sample = {} - new_sample["tokens"] = tokens - new_sample["labels"] = labels - new_sample["loss_mask"] = loss_mask - new_sample["position_ids"] = position_ids - - padded_lengths = padded_lengths.to( - device=dev, dtype=torch.int32, non_blocking=True - ).reshape(-1) - cu_seqlens_padded = torch.empty(padded_lengths.numel() + 1, device=dev, dtype=torch.int32) - cu_seqlens_padded[0] = 0 - cu_seqlens_padded[1:] = torch.cumsum(padded_lengths, dim=0) - max_seqlen = torch.max(padded_lengths).to(dtype=torch.int32) - - new_sample["cu_seqlens_padded"] = cu_seqlens_padded - new_sample["max_seqlen"] = max_seqlen - - original_lengths = original_lengths.to( - device=dev, dtype=torch.int32, non_blocking=True - ).reshape(-1) - cu_seqlens = torch.empty(original_lengths.numel() + 1, device=dev, dtype=torch.int32) - cu_seqlens[0] = 0 - cu_seqlens[1:] = torch.cumsum(original_lengths, dim=0).reshape(-1) - new_sample["cu_seqlens"] = cu_seqlens - - return new_sample - - @staticmethod - def _build_packed_microbatches( - grouped_samples: List[List[Dict[str, torch.Tensor]]], dev: torch.device - ) -> List[Dict[str, torch.Tensor]]: - """Build packed samples for each microbatch.""" - num_micro_batches = len(grouped_samples) - seg_starts: List[int] = [0] - original_lens_tensors = [] - padded_lens_tensors = [] - - for i in range(num_micro_batches): - samples = grouped_samples[i] - seg_starts.append(seg_starts[-1] + len(samples)) - original_lens_tensors.extend([s["original_seq_len"].reshape(-1) for s in samples]) - padded_lens_tensors.extend([s["padded_seq_len"].reshape(-1) for s in samples]) - - padded_lens_all_gpu = torch.cat(padded_lens_tensors, dim=0).to(dtype=torch.int32) - original_lens_all_gpu = torch.cat(original_lens_tensors, dim=0).to(dtype=torch.int32) - - new_samples: List[Dict[str, torch.Tensor]] = [] - for i in range(num_micro_batches): - samples = grouped_samples[i] - lp = padded_lens_all_gpu[seg_starts[i] : seg_starts[i + 1]] - lo = original_lens_all_gpu[seg_starts[i] : seg_starts[i + 1]] - new_sample = BaseScheduler._pack_sequences(samples, lp, lo, dev) - new_samples.append(new_sample) - - return new_samples - - @staticmethod - def _get_batch_and_global_seqlens(data_iterator, num_microbatches, dp_group): - """ - Get the batch and global sequence lengths. - Each DP rank loads the same number of sequences, so we need to gather the sequence - lengths from all ranks then we can schedule the sequences into groups. Args: data_iterator: The data iterator. - num_microbatches: The number of microbatches. - dp_group: The data parallel group. + num_microbatches: The number of microbatches to fetch. + dp_group: Data parallel process group. + tp_group: Tensor parallel process group. + pp_group: Pipeline parallel process group. + dp_cp_group: Data parallel + context parallel process group. + dev: CUDA device. + config: Model parallel config. Returns: - batch: The batch. - global_id_seqlens: The global sequence lengths. - global_ids_this_rank: The global IDs locally present on this rank. + new_data_iterator: The new data iterator (or list for VPP). + num_micro_batches: Number of micro batches after scheduling. + seqlen_sum_this_global_batch: Total tokens for FLOPs calculation. + seqlen_squared_sum_this_global_batch: Sum of squared seqlens for FLOPs. """ - batch = [next(data_iterator) for _ in range(num_microbatches)] - subsample_seqlens = [] - for sample in batch: - subsample_seqlens.extend([sample["tokens"].numel()]) - subsample_seqlens = torch.tensor(subsample_seqlens, dtype=torch.int32).cuda() - - seqlens_gathered, offsets = BaseScheduler._get_global_seqlens(subsample_seqlens, dp_group) - - global_id_seqlens, global_ids_this_rank = BaseScheduler._get_global_id_seqlens( - subsample_seqlens.shape[0], offsets, seqlens_gathered, dp_group - ) - - return batch, global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered + raise NotImplementedError -class DefaultSequencePackingScheduler(BaseScheduler): +class DpBalancedScheduler(BaseScheduler): """Packs sequences in their original order until reaching the max limit of sequence length.""" def __init__(self, *args, **kwargs): @@ -942,12 +517,14 @@ def run( # Step 1: Fetch batches and gather global sequence lengths batch, global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered = ( - self._get_batch_and_global_seqlens(data_iterator, num_microbatches, dp_group) + get_batch_and_global_seqlens(data_iterator, num_microbatches, dp_group) ) # Step 2: Check required sample keys for key in self.get_require_sample_keys(): - assert key in batch[0], f"Batch missing required key {key}" + assert ( + key in batch[0] + ), f"Batch missing required key {key}, provided keys: {batch[0].keys()}" # Step 3: Schedule samples into groups sample_id_groups = self.get_groups_and_subsamples(global_id_seqlens) @@ -963,7 +540,7 @@ def run( ) # Step 4: Reroute samples to DCP ranks - samples_this_rank_with_id = self._reroute_samples_to_dcp_ranks( + samples_this_rank_with_id = reroute_samples_to_dcp_ranks( batch, global_ids_this_rank, global_id_seqlens, @@ -987,7 +564,7 @@ def run( ] # Step 5: Build packed microbatches - new_samples = self._build_packed_microbatches(grouped_samples, dev) + new_samples = build_packed_microbatches(grouped_samples, dev) # Step 6: Calculate FLOPs info seqlen_sum_this_global_batch = float(sum(seqlens_gathered)) @@ -1009,7 +586,7 @@ def run( num_micro_batches, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch, - ) = self._broadcast_to_pp_group( + ) = broadcast_to_pp_group( new_samples, num_micro_batches, seqlen_sum_this_global_batch, @@ -1020,7 +597,7 @@ def run( # Step 8: Broadcast to TP group (for non-TP-0 ranks) (num_micro_batches, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch) = ( - self._broadcast_scalars( + broadcast_scalars( [ num_micro_batches, seqlen_sum_this_global_batch, @@ -1033,7 +610,7 @@ def run( num_micro_batches = int(num_micro_batches) # Step 9: create data_iterator and handle VPP if enabled - new_data_iterator = self._create_data_iterator(new_samples, pp_group, tp_group, config) + new_data_iterator = create_data_iterator(new_samples, pp_group, tp_group, config) return ( new_data_iterator, @@ -1043,7 +620,18 @@ def run( ) -def wrap_dataloader( +class PackingScheduler(enum.Enum): + """Enum for supported sequence packing algorithms.""" + + DP_BALANCED = "dp_balanced" + + +scheduler_map: Dict[PackingScheduler, Type[BaseScheduler]] = { + PackingScheduler.DP_BALANCED: DpBalancedScheduler +} + + +def wrap_data_iterator( data_iterator, config, num_microbatches, pg_collection: Optional[ProcessGroupCollection] = None ): """ @@ -1057,10 +645,6 @@ def wrap_dataloader( pg_collection: The process group collection. """ - scheduler_map: Dict[PackingScheduler, Type[BaseScheduler]] = { - PackingScheduler.DEFAULT_SEQUENCE_PACKING: DefaultSequencePackingScheduler - } - if pg_collection is None: dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True) dp_group = parallel_state.get_data_parallel_group() @@ -1116,7 +700,11 @@ def wrap_dataloader( def get_batch_on_this_rank_for_sequence_packing( - data_iterator, mtp_on_this_rank: bool = False, vp_stage: Optional[int] = None + data_iterator, + vpp_size: Optional[int] = None, + mtp_on_this_rank: bool = False, + vp_stage: Optional[int] = None, + pg_collection: Optional[ProcessGroupCollection] = None, ): """ Get a batch of data for sequence packing. @@ -1128,16 +716,23 @@ def get_batch_on_this_rank_for_sequence_packing( tuple of (tokens, labels, loss_mask, attention_mask, position_ids, packed_seq_params) """ - tp_src_rank = parallel_state.get_tensor_model_parallel_src_rank() - tp_group = parallel_state.get_tensor_model_parallel_group() + if pg_collection is None: + tp_group = parallel_state.get_tensor_model_parallel_group() + pp_group = parallel_state.get_pipeline_model_parallel_group() + cp_group = parallel_state.get_context_parallel_group() + else: + tp_group = pg_collection.tp + pp_group = pg_collection.pp + cp_group = pg_collection.cp - is_tp_rank_0 = parallel_state.get_tensor_model_parallel_rank() == 0 - is_first_stage = parallel_state.is_pipeline_first_stage( - ignore_virtual=vp_stage is None, vp_stage=vp_stage - ) - is_last_stage = parallel_state.is_pipeline_last_stage( - ignore_virtual=vp_stage is None, vp_stage=vp_stage + tp_src_rank = torch.distributed.get_process_group_ranks(tp_group)[0] + + is_tp_rank_0 = tp_group.rank() == 0 + is_first_stage = pp_group.rank() == 0 and (vp_stage is None or vp_stage == 0) + is_last_stage = pp_group.rank() == pp_group.size() - 1 and ( + vp_stage is None or vp_stage == vpp_size - 1 ) + is_first_or_last_stage = is_first_stage or is_last_stage dev = torch.cuda.current_device() @@ -1163,20 +758,16 @@ def get_batch_on_this_rank_for_sequence_packing( # Partition tokens, position_ids, labels, loss_mask for context parallel, currently only # TP rank 0 and the first/last PP stage rank has these data. if is_tp_rank_0 and is_first_or_last_stage: - cp_size = parallel_state.get_context_parallel_world_size() - cp_rank = parallel_state.get_context_parallel_rank() + cp_size = cp_group.size() + cp_rank = cp_group.rank() # If cp_size == 1, no need to do further processing. if cp_size > 1: - assert tex is not None and is_te_min_version("1.10.0"), ( - "Please update Transformer Engine to >= 1.10 to use " - "Context Parallel with THD format data" - ) total_tokens = batch['tokens'].size(0) # Transformer Engine has a bug of cu_seqlens, we must treat cu_seqlens_padded as # cu_seqlens to get the correct result. # TODO: Revert this workaround once TE fixes the issue. cu_seqlens = batch["cu_seqlens_padded"] - index = tex.thd_get_partitioned_indices(cu_seqlens, total_tokens, cp_size, cp_rank) + index = get_thd_partitioned_indices(cu_seqlens, total_tokens, cp_size, cp_rank) for key in ['tokens', 'position_ids', 'labels', 'loss_mask']: batch[key] = batch[key].index_select(0, index) @@ -1186,7 +777,7 @@ def get_batch_on_this_rank_for_sequence_packing( cu_seqlen_size = torch.tensor(batch['cu_seqlens'].size(0), dtype=torch.int32, device=dev) else: cu_seqlen_size = torch.empty(1, dtype=torch.int32, device=dev) - _broadcast_tensor(cu_seqlen_size, tp_src_rank, tp_group) + broadcast_tensor(cu_seqlen_size, tp_src_rank, tp_group) cu_seqlen_size = cu_seqlen_size.item() # Broadcast total_tokens because we need it to create placeholder for tokens, position_ids, @@ -1196,7 +787,7 @@ def get_batch_on_this_rank_for_sequence_packing( total_tokens = torch.tensor(batch['tokens'].size(0), dtype=torch.int32, device=dev) else: total_tokens = torch.empty(1, dtype=torch.int32, device=dev) - _broadcast_tensor(total_tokens, tp_src_rank, tp_group) + broadcast_tensor(total_tokens, tp_src_rank, tp_group) total_tokens = total_tokens.item() # Step1: Prepare "tokens", "position_ids" on all ranks. @@ -1246,13 +837,13 @@ def get_batch_on_this_rank_for_sequence_packing( batch['max_seqlen'] = torch.empty(1, dtype=torch.int32, device=dev) # Broadcast batch inside TP group. - _broadcast_tensor(batch['tokens'], tp_src_rank, tp_group) - _broadcast_tensor(batch['position_ids'], tp_src_rank, tp_group) - _broadcast_tensor(batch['labels'], tp_src_rank, tp_group) - _broadcast_tensor(batch['loss_mask'], tp_src_rank, tp_group) - _broadcast_tensor(batch['cu_seqlens'], tp_src_rank, tp_group) - _broadcast_tensor(batch['cu_seqlens_padded'], tp_src_rank, tp_group) - _broadcast_tensor(batch['max_seqlen'], tp_src_rank, tp_group) + broadcast_tensor(batch['tokens'], tp_src_rank, tp_group) + broadcast_tensor(batch['position_ids'], tp_src_rank, tp_group) + broadcast_tensor(batch['labels'], tp_src_rank, tp_group) + broadcast_tensor(batch['loss_mask'], tp_src_rank, tp_group) + broadcast_tensor(batch['cu_seqlens'], tp_src_rank, tp_group) + broadcast_tensor(batch['cu_seqlens_padded'], tp_src_rank, tp_group) + broadcast_tensor(batch['max_seqlen'], tp_src_rank, tp_group) # Extract the data from batch after broadcasting. tokens = batch['tokens'] diff --git a/megatron/core/datasets/data_scheduler_utils.py b/megatron/core/datasets/data_scheduler_utils.py new file mode 100644 index 00000000000..2a1e2a6528d --- /dev/null +++ b/megatron/core/datasets/data_scheduler_utils.py @@ -0,0 +1,492 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved. + +from typing import Dict, List + +import numpy as np +import torch + +from megatron.core.rerun_state_machine import RerunDataIterator + + +def _unpack_batch(batch): + """ + Unpacks the packed samples into a list of sub-samples. + Since each sub-sample may be routed to different DPxCP ranks, + we unpack the sample here to avoid unnecessarily transferring + the entire packed sample. + """ + batch_unpacked = [] + dev = batch[0]["tokens"].device + original_seq_lens = [] + padded_seq_lens = [] + for sample in batch: + for key in sample.keys(): + if len(sample[key].shape) == 2: + # squeeze the redundant batch dimension added by + # default collate_fn in pytorch dataloader + # we need a custom collate_fn for THD to avoid this + # current THD does not support micro_batch_size > 1 due to sft_dataset.py and + # data_loader in data_samples.py + sample[key] = sample[key].squeeze(0) + for sub_sample in range(sample["cu_seqlens"].shape[0] - 1): + sub_sample_dict = {} + start_idx = sample["cu_seqlens"][sub_sample] + end_idx = sample["cu_seqlens"][sub_sample + 1] + if end_idx - start_idx == 0: + continue + for key in ["tokens", "labels", "loss_mask", "position_ids"]: + sub_sample_dict[key] = sample[key][start_idx:end_idx] + # Since sft_dataset.py does not provide cu_seqlens_original, + # we assume original_seq_len equals padded_seq_len here. + # Ideally the dataset should define the pre-padding seq_len. + seq_len = (end_idx - start_idx).item() + original_seq_lens.append(seq_len) + padded_seq_lens.append(seq_len) + batch_unpacked.append(sub_sample_dict) + + # Single H2D transfer for all seq lens + original_seq_lens_cuda = torch.tensor(original_seq_lens, device=dev) + padded_seq_lens_cuda = torch.tensor(padded_seq_lens, device=dev) + for i, sub_sample_dict in enumerate(batch_unpacked): + sub_sample_dict["original_seq_len"] = original_seq_lens_cuda[i : i + 1] + sub_sample_dict["padded_seq_len"] = padded_seq_lens_cuda[i : i + 1] + + return batch_unpacked + + +def _get_global_seqlens_and_ids(subsample_seqlens: torch.Tensor, dp_group): + """ + Gathers the sequence lengths of all subsamples from all DP ranks and calculates global IDs. + """ + # Collect the number of subsamples from all ranks + num_local_subsamples = subsample_seqlens.shape[0] + local_len = torch.tensor([num_local_subsamples], dtype=torch.int32).cuda() + dp_subsample_count = [torch.zeros_like(local_len) for _ in range(dp_group.size())] + torch.distributed.all_gather(dp_subsample_count, local_len, group=dp_group) + + # Find the max number of subsamples across all ranks and pad subsample_seqlens to max length + dp_subsample_counts = torch.stack(dp_subsample_count, dim=0).cpu().view(-1) + max_sub_samples = int(dp_subsample_counts.max().item()) + + if num_local_subsamples < max_sub_samples: + subsample_seqlens_padded = torch.cat( + [ + subsample_seqlens, + torch.zeros(max_sub_samples - num_local_subsamples, dtype=torch.int32).cuda(), + ], + dim=0, + ) + else: + subsample_seqlens_padded = subsample_seqlens + + # Gather the subsample_seqlens from all ranks + seqlens_gathered = [torch.empty_like(subsample_seqlens_padded) for _ in range(dp_group.size())] + torch.distributed.all_gather(seqlens_gathered, subsample_seqlens_padded, group=dp_group) + + # Trim each seqlens_gathered to the length of the correct sample + for dp_rank, seqlen in enumerate(seqlens_gathered): + seqlens_gathered[dp_rank] = seqlen[: dp_subsample_counts[dp_rank]] + + seqlens_gathered = torch.cat(seqlens_gathered, dim=0) + seqlens_gathered = seqlens_gathered.cpu().tolist() + + # Calculate the offsets to assign unique global ID to each subsample. + csum = torch.cumsum(dp_subsample_counts, dim=0, dtype=torch.int32) + offsets = torch.cat([torch.zeros(1, dtype=torch.int32), csum], dim=0) + + # Calculate global ID for each subsample + dp_rank = dp_group.rank() + global_ids = torch.arange(len(seqlens_gathered), dtype=torch.int32).cuda() + + # Create a list of (global_id, seqlen) tuples for scheduling + global_id_seqlens = [(i, seqlens_gathered[i]) for i in range(len(global_ids))] + + # Get the global IDs locally present on this rank + start_idx = offsets[dp_rank] + end_idx = offsets[dp_rank + 1] + + global_ids_this_rank = global_ids[start_idx:end_idx] + + return global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered + + +def _pack_sequences( + samples: List, padded_lengths: torch.Tensor, original_lengths: torch.Tensor, dev: torch.device +) -> Dict[str, torch.Tensor]: + """Pack multiple samples into a single packed sample.""" + + def _pack_tensors(tensors): + return torch.cat([t.reshape(-1) for t in tensors], dim=0) + + tokens = _pack_tensors([sample["tokens"] for sample in samples]) + labels = _pack_tensors([sample["labels"] for sample in samples]) + loss_mask = _pack_tensors([sample["loss_mask"] for sample in samples]) + position_ids = _pack_tensors([sample["position_ids"] for sample in samples]) + + new_sample = {} + new_sample["tokens"] = tokens + new_sample["labels"] = labels + new_sample["loss_mask"] = loss_mask + new_sample["position_ids"] = position_ids + + padded_lengths = padded_lengths.to(device=dev, dtype=torch.int32, non_blocking=True).reshape(-1) + cu_seqlens_padded = torch.empty(padded_lengths.numel() + 1, device=dev, dtype=torch.int32) + cu_seqlens_padded[0] = 0 + cu_seqlens_padded[1:] = torch.cumsum(padded_lengths, dim=0) + max_seqlen = torch.max(padded_lengths).to(dtype=torch.int32) + + new_sample["cu_seqlens_padded"] = cu_seqlens_padded + new_sample["max_seqlen"] = max_seqlen + + original_lengths = original_lengths.to( + device=dev, dtype=torch.int32, non_blocking=True + ).reshape(-1) + cu_seqlens = torch.empty(original_lengths.numel() + 1, device=dev, dtype=torch.int32) + cu_seqlens[0] = 0 + cu_seqlens[1:] = torch.cumsum(original_lengths, dim=0).reshape(-1) + new_sample["cu_seqlens"] = cu_seqlens + + return new_sample + + +def broadcast_tensor(item, src_rank, group) -> None: + """Broadcast a tensor from src_rank to all ranks in the group.""" + if item is not None: + torch.distributed.broadcast(item, src_rank, group=group) + + +def broadcast_to_pp_group( + new_samples, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + pp_group, + dev, +): + """ + Broadcast num_micro_batches, seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch and metadata to middle PP stages. + Before this broadcast, the new_samples on middle PP stages are None, + after this broadcast, the new_samples on middle PP stages contain the metadata but + without tokens, labels, loss_mask, position_ids. + """ + + pp_src_rank = torch.distributed.get_process_group_ranks(pp_group)[0] + + if pp_group.size() > 2: + if pp_group.rank() == 0: + tensor_list = [ + torch.tensor( + [ + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ], + dtype=torch.float32, + ).cuda() + ] + for sample in new_samples: + tensor_list.append(sample["max_seqlen"].unsqueeze(0)) + for sample in new_samples: + tensor_list.append(sample["cu_seqlens"]) + tensor_list.append(sample["cu_seqlens_padded"]) + info_to_broadcast = torch.cat(tensor_list, dim=0).to(device=dev, dtype=torch.float32) + info_length_tensor = torch.tensor(info_to_broadcast.shape[0], dtype=torch.int32).cuda() + broadcast_tensor(info_length_tensor, pp_src_rank, pp_group) + broadcast_tensor(info_to_broadcast, pp_src_rank, pp_group) + else: + info_length_tensor = torch.tensor(0, dtype=torch.int32).cuda() + broadcast_tensor(info_length_tensor, pp_src_rank, pp_group) + info_to_broadcast = torch.empty(info_length_tensor.item(), dtype=torch.float32).cuda() + broadcast_tensor(info_to_broadcast, pp_src_rank, pp_group) + if pp_group.rank() != pp_group.size() - 1: + # middle PP stages receive the broadcasted info and unpack it + info_numpy = info_to_broadcast.cpu().numpy() + num_micro_batches = int(info_numpy[0]) + seqlen_sum_this_global_batch = info_numpy[1] + seqlen_squared_sum_this_global_batch = info_numpy[2] + max_seqlens = info_to_broadcast[3 : 3 + num_micro_batches] + cu_seqlens_list = [] + cu_seqlens_padded_list = [] + indices = np.where(info_numpy == 0)[0] + for i in range(num_micro_batches): + cu_seqlens_list.append(info_to_broadcast[indices[i * 2] : indices[i * 2 + 1]]) + if i == num_micro_batches - 1: + cu_seqlens_padded_list.append(info_to_broadcast[indices[i * 2 + 1] :]) + else: + cu_seqlens_padded_list.append( + info_to_broadcast[indices[i * 2 + 1] : indices[i * 2 + 2]] + ) + + new_samples = [] + for i in range(num_micro_batches): + new_sample = {} + new_sample["max_seqlen"] = max_seqlens[i].to(torch.int32) + new_sample["cu_seqlens"] = cu_seqlens_list[i].to(torch.int32) + new_sample["cu_seqlens_padded"] = cu_seqlens_padded_list[i].to(torch.int32) + new_samples.append(new_sample) + + return ( + new_samples, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ) + + +def broadcast_scalars(values: List, group, dev, dtype=torch.float32) -> List: + """ + Broadcast scalar values from rank 0 to all ranks in the group. + + Args: + values: List of scalar values to broadcast (only used on rank 0). + group: The process group to broadcast within. + dev: The device to use for the tensor. + dtype: The data type for the tensor. + + Returns: + List of broadcasted values. + """ + if group.size() <= 1: + return values + + src_rank = torch.distributed.get_process_group_ranks(group)[0] + num_values = len(values) + + if group.rank() == 0: + info_to_broadcast = torch.tensor(values, dtype=dtype, device=dev) + else: + info_to_broadcast = torch.zeros(num_values, dtype=dtype, device=dev) + + broadcast_tensor(info_to_broadcast, src_rank, group) + + if group.rank() != 0: + values = info_to_broadcast.cpu().tolist() + + return values + + +def create_data_iterator(new_samples, pp_group, tp_group, config): + """Handle virtual pipeline parallelism.""" + if ( + config.virtual_pipeline_model_parallel_size is not None + and config.virtual_pipeline_model_parallel_size > 1 + ): + vpp_size = config.virtual_pipeline_model_parallel_size + if tp_group.rank() == 0: + if pp_group.rank() == 0 or pp_group.rank() == pp_group.size() - 1: + metadata = [ + {k: sample[k] for k in ["max_seqlen", "cu_seqlens", "cu_seqlens_padded"]} + for sample in new_samples + ] + if pp_group.rank() == 0: + new_data_iterator = [RerunDataIterator(iter(new_samples))] + [ + RerunDataIterator(iter(metadata)) for _ in range(vpp_size - 1) + ] + else: + new_data_iterator = [ + RerunDataIterator(iter(metadata)) for _ in range(vpp_size - 1) + ] + [RerunDataIterator(iter(new_samples))] + else: + # on middle PP stages, the new_samples are the metadata + metadata = new_samples + new_data_iterator = [RerunDataIterator(iter(metadata)) for _ in range(vpp_size)] + else: + new_data_iterator = [None for _ in range(vpp_size)] + else: + new_data_iterator = RerunDataIterator(iter(new_samples)) if tp_group.rank() == 0 else None + + return new_data_iterator + + +def reroute_samples_to_dcp_ranks( + batch, + global_ids_this_rank, + global_id_seqlens, + sample_id_groups, + offsets, + dp_group, + tp_group, + dp_cp_group, + total_dcp_gpus, +): + """ + Reroutes the sub-samples to the correct rank after scheduling. + + For each key in the batch dict, we perform an all-to-all communication + to transfer the data to the correct ranks. + """ + + def _gid_to_src_rank(gid: int) -> int: + dp_src_rank = torch.bucketize(gid, offsets[1:] - 1) + dcp_rank = ( + torch.distributed.get_process_group_ranks(dp_group)[dp_src_rank] // tp_group.size() + ) % dp_cp_group.size() + return dcp_rank + + gid2local_id = {int(gid): i for i, gid in enumerate(global_ids_this_rank)} + dcp_rank = dp_cp_group.rank() + dp_ranks = torch.distributed.get_process_group_ranks(dp_group) + dp_ranks = [(r // tp_group.size()) % dp_cp_group.size() for r in dp_ranks] + + data_keys = batch[0].keys() + + # Create the send plan + combined_sample_id_groups: List[List[int]] = [[] for _ in range(total_dcp_gpus)] + for d in range(total_dcp_gpus): + for sample_id_group in sample_id_groups: + combined_sample_id_groups[d].extend(sample_id_group[d]) + for dest_rank in range(total_dcp_gpus): + combined_sample_id_groups[dest_rank].sort() + + send_ids_sorted = [ + gid for d in dp_ranks for gid in combined_sample_id_groups[d] if gid in global_ids_this_rank + ] + + send_num_split = [0] * total_dcp_gpus + send_lens_split = [0] * total_dcp_gpus + for dest_rank in range(total_dcp_gpus): + if dest_rank in dp_ranks: + send_seq_lens = [ + global_id_seqlens[gid][1] + for gid in combined_sample_id_groups[dest_rank] + if gid in global_ids_this_rank + ] + send_num_split[dest_rank] = len(send_seq_lens) + send_lens_split[dest_rank] = sum(send_seq_lens) + else: + send_lens_split[dest_rank] = 0 + + # Create the recv plan + recv_sample_id_groups = [[] for _ in range(total_dcp_gpus)] + for gid in combined_sample_id_groups[dcp_rank]: + src_rank = _gid_to_src_rank(gid) + recv_sample_id_groups[src_rank].append(gid) + + recv_lens_split = [0] * total_dcp_gpus + for src_rank in range(total_dcp_gpus): + recv_lens_split[src_rank] = sum( + [global_id_seqlens[gid][1] for gid in recv_sample_id_groups[src_rank]] + ) + + recv_ids_sorted = [gid for d in range(total_dcp_gpus) for gid in recv_sample_id_groups[d]] + recv_counts = [len(recv_sample_id_groups[d]) for d in range(total_dcp_gpus)] + + recv_samples = [{k: None for k in data_keys} for _ in range(sum(recv_counts))] + + def _pack_sample_by_key(key: str) -> torch.Tensor: + flattened_tensors = [] + for gid in send_ids_sorted: + t = batch[gid2local_id[gid]][key].to(torch.cuda.current_device(), non_blocking=True) + flattened_tensors.append(t.reshape(-1)) + return ( + torch.cat(flattened_tensors, dim=0) + if flattened_tensors + else torch.empty(1, device=torch.cuda.current_device(), dtype=batch[0][key].dtype) + ) + + def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor): + cursor = 0 + for i, gid in enumerate(recv_ids_sorted): + sample_len = ( + 1 if key in ["original_seq_len", "padded_seq_len"] else global_id_seqlens[gid][1] + ) + recv_samples[i][key] = recv_tensor[cursor : cursor + sample_len] + cursor += sample_len + + for key in data_keys: + output_split_sizes, input_split_sizes = ( + (recv_counts, send_num_split) + if key in ["original_seq_len", "padded_seq_len"] + else (recv_lens_split, send_lens_split) + ) + send_tensor = _pack_sample_by_key(key) + recv_tensor_size = sum(output_split_sizes) + recv_tensor = torch.empty( + recv_tensor_size, device=torch.cuda.current_device(), dtype=send_tensor.dtype + ) + torch.distributed.all_to_all_single( + output=recv_tensor, + input=send_tensor, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=dp_cp_group, + ) + _unpack_sample_by_key(key, recv_tensor) + + recv_sample_with_id = {recv_id: recv_samples[i] for i, recv_id in enumerate(recv_ids_sorted)} + return recv_sample_with_id + + +def build_packed_microbatches( + grouped_samples: List[List[Dict[str, torch.Tensor]]], dev: torch.device +) -> List[Dict[str, torch.Tensor]]: + """Build packed samples for each microbatch.""" + num_micro_batches = len(grouped_samples) + seg_starts: List[int] = [0] + original_lens_tensors = [] + padded_lens_tensors = [] + + for i in range(num_micro_batches): + samples = grouped_samples[i] + seg_starts.append(seg_starts[-1] + len(samples)) + original_lens_tensors.extend([s["original_seq_len"].reshape(-1) for s in samples]) + padded_lens_tensors.extend([s["padded_seq_len"].reshape(-1) for s in samples]) + + padded_lens_all_gpu = torch.cat(padded_lens_tensors, dim=0).to(dtype=torch.int32) + original_lens_all_gpu = torch.cat(original_lens_tensors, dim=0).to(dtype=torch.int32) + + new_samples: List[Dict[str, torch.Tensor]] = [] + for i in range(num_micro_batches): + samples = grouped_samples[i] + lens_padded = padded_lens_all_gpu[seg_starts[i] : seg_starts[i + 1]] + lens_original = original_lens_all_gpu[seg_starts[i] : seg_starts[i + 1]] + new_sample = _pack_sequences(samples, lens_padded, lens_original, dev) + new_samples.append(new_sample) + + return new_samples + + +def get_batch_and_global_seqlens(data_iterator, num_microbatches, dp_group): + """ + Get the batch and global sequence lengths. + Each DP rank loads the same number of sequences, so we need to gather the sequence + lengths from all ranks then we can schedule the sequences into groups. + Args: + data_iterator: The data iterator. + num_microbatches: The number of microbatches. + dp_group: The data parallel group. + + Returns: + batch: The batch. + global_id_seqlens: The global sequence lengths. + global_ids_this_rank: The global IDs locally present on this rank. + """ + + batch_list = [next(data_iterator) for _ in range(num_microbatches)] + + batch = [] + for item in batch_list: + if isinstance(item, dict): + batch.append(item) + elif isinstance(item, list): + batch.extend(item) + else: + raise ValueError(f"Invalid item type: {type(item)}") + + # This unpack step is redundant: in sft_dataset.py, sequences are already packed before + # rescheduling, so we need to unpack them here and repack after rescheduling. This is only + # to adapt to the current megatron-lm sft_dataset. + # If you implement your own dataset, just have __getitem__ return List[Dict] + # and this step can be skipped. + batch = _unpack_batch(batch) + + subsample_seqlens = torch.cat([sample["padded_seq_len"] for sample in batch]).to( + dtype=torch.int32, device=torch.cuda.current_device() + ) + + global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered = ( + _get_global_seqlens_and_ids(subsample_seqlens, dp_group) + ) + + return batch, global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py index fa421641db5..bea7f4a4c18 100644 --- a/megatron/core/datasets/gpt_dataset.py +++ b/megatron/core/datasets/gpt_dataset.py @@ -82,8 +82,10 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig): sft_mock_dataset_config_json: Optional[str] = None """This config provides the necessary information for the mock dataset.""" - sequence_packing: bool = False - """Option to enable sequence packing for training.""" + sequence_packing_scheduler: Optional[str] = None + """Scheduler for sequence packing and hybrid context parallel. + dp_balanced: DP-balanced scheduler for sequence packing. + """ def __post_init__(self) -> None: """Do asserts and set fields post init""" diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index bb913d97446..20f0ece635e 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -2559,3 +2559,24 @@ def set_save_original_input(module): from transformer_engine.pytorch.float8_tensor import Float8Tensor except ImportError: Float8Tensor = None + + +def get_thd_partitioned_indices(cu_seqlens, total_tokens, cp_size, cp_rank): + """Get partitioned indices for THD format data in context parallel. + + Args: + cu_seqlens: Cumulative sequence lengths tensor. + total_tokens: Total number of tokens. + cp_size: Context parallel world size. + cp_rank: Context parallel rank. + + Returns: + Partitioned indices tensor. + """ + assert is_te_min_version("1.10.0"), ( + "Please update Transformer Engine to >= 1.10 to use " + "Context Parallel with THD format data" + ) + import transformer_engine_torch as tex + + return tex.thd_get_partitioned_indices(cu_seqlens, total_tokens, cp_size, cp_rank) diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index c5cf76fb099..442f5a31e3c 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -62,7 +62,7 @@ class ModelParallelConfig: can handle without overflowing the memory. Typically, a good starting point is to set this to maximum sequence length / context parallel size. This is used to calculate the number and length of sub-samples assigned to - each rank when using sequence_packing. + each rank when sequence_packing_scheduler is not None. """ hybrid_context_parallel: bool = False @@ -75,12 +75,7 @@ class ModelParallelConfig: sequence_packing_scheduler: Optional[str] = None """ Scheduler for sequence packing and hybrid context parallel. - default_sequence_packing: default sequence packing scheduler for sequence packing. - """ - - sequence_packing: bool = False - """ - If true, enables sft sequence packing. + dp_balanced: DP-balanced scheduler for sequence packing. """ expert_model_parallel_size: int = 1 diff --git a/megatron/core/tokenizers/text/libraries/sft_tokenizer.py b/megatron/core/tokenizers/text/libraries/sft_tokenizer.py index e06e51e5ee3..8a418f2dd7f 100644 --- a/megatron/core/tokenizers/text/libraries/sft_tokenizer.py +++ b/megatron/core/tokenizers/text/libraries/sft_tokenizer.py @@ -237,11 +237,6 @@ def bos_id(self): def eod(self): """End of sentence token ID.""" return self._tokenizer.eos_token_id - - @property - def eos(self): - """End of sentence token ID.""" - return self._tokenizer.eos_token_id @property def vocab(self): diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index aafdc79bbc6..71508401657 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -2076,7 +2076,7 @@ def __post_init__(self): self.attention_backend == AttnBackend.flash ), "Batch invariant mode only supports FlashAttention" - if self.sequence_packing: + if self.sequence_packing_scheduler is not None: # Check TE version. if not HAVE_PACKAGING: raise ImportError( @@ -2095,14 +2095,16 @@ def __post_init__(self): self.variable_seq_lengths = True # TODO(tailaim): add support for other dispatcher types - warnings.warn("Setting moe_token_dispatcher_type to alltoall for sft sequence packing.") - self.moe_token_dispatcher_type = "alltoall" - - if self.sequence_packing_scheduler is None: - self.sequence_packing_scheduler = 'default_sequence_packing' + assert self.moe_token_dispatcher_type == "alltoall", ( + f"sequence_packing only supports moe_token_dispatcher_type='alltoall', " + f"got '{self.moe_token_dispatcher_type}'" + ) - supported_schedulers = ['default_sequence_packing'] - if self.sequence_packing_scheduler not in supported_schedulers: + supported_schedulers = ['dp_balanced'] + if ( + self.sequence_packing_scheduler is not None + and self.sequence_packing_scheduler not in supported_schedulers + ): raise ValueError( f"Unknown scheduler: {self.sequence_packing_scheduler}. " f"Available schedulers: {supported_schedulers}" diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 50e2eb3581a..56c3c077d46 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1060,7 +1060,7 @@ def validate_args(args, defaults={}): # during pipeline parallelism, it should not be set if sequence length # is constant during training. args.variable_seq_lengths = False - if args.sequence_packing: + if args.sequence_packing_scheduler is not None: args.variable_seq_lengths = True assert args.context_parallel_size * args.max_seqlen_per_dp_cp_rank >= args.seq_length, \ f'Packed sequence buffer size ({args.context_parallel_size * args.max_seqlen_per_dp_cp_rank}) ' \ @@ -1834,6 +1834,9 @@ def _add_network_size_args(parser): "persist_layer_norm", "bias_dropout_fusion", "apply_rope_fusion", + "max_seqlen_per_dp_cp_rank", + "hybrid_context_parallel", + "sequence_packing_scheduler", ] transformer_factory = ArgumentGroupFactory(TransformerConfig, exclude=exclude) transformer_group = transformer_factory.build_group(parser, "transformer configuration") @@ -2535,6 +2538,7 @@ def _add_distributed_args(parser): 'all layers will share the same communication type. Users can also ' 'specify separated types for each layer like ' '--cp-comm-type p2p p2p a2a a2a a2a+p2p a2a+p2p') + group.add_argument('--sequence-packing-scheduler', type=str, default='default_sequence_packing', choices=['default_sequence_packing']) group.add_argument('--fake-process-group', action='store_true', default=False, help='If set, initialize with fake distributed process group and all distributed communication operations will be skipped. \ This is quite useful for profiling memory usage of distributed training with just one GPU. \ @@ -3083,5 +3087,7 @@ def _add_sft_args(parser): group.add_argument('--sequence-packing', action='store_true', help='use sequence packing(thd format) for training') group.add_argument('--sft-mock-dataset-config-json', type=str, default=None, - help='This config provides the necessary information for the mock dataset. You can either specify a CSV file that contains sequence lengths, where each line stores the length of a sequence, for example: {"mode":"file","path":"/path/to/file"}. Alternatively, you can specify a distribution (currently only supporting lognormal distribution) along with the required parameters, for example, {"mode":"distribution","type":"lognormal","min_seq_len":1024,"max_seq_len":2048,"mean_seq_len":1536,"lognormal_sigma":1.1}, where sigma controls the variability of the lognormal distribution.') + help='This config provides the necessary information for the mock dataset. You can either specify a CSV file that contains sequence lengths, where each line stores the length of a sequence, for example: {"mode":"file","path":"/path/to/file"}. Alternatively, you can specify a distribution (currently only supporting lognormal distribution) along with the required parameters, for example, {"mode":"distribution","type":"lognormal","min_seq_len":1024,"max_seq_len":2048,"mean_seq_len":1536,"lognormal_sigma":1.1}, where sigma controls the variability of the lognormal distribution. ' + 'If not specified and --mock-data is set, defaults to a lognormal distribution with ' + 'min_seq_len=seq_length//2, max_seq_len=seq_length, mean_seq_len=seq_length*3//4, lognormal_sigma=1.1.') return parser diff --git a/megatron/training/datasets/sft_dataset.py b/megatron/training/datasets/sft_dataset.py index b8b15e1a985..1672ad75310 100644 --- a/megatron/training/datasets/sft_dataset.py +++ b/megatron/training/datasets/sft_dataset.py @@ -4,7 +4,7 @@ from collections import Counter import json import math -from typing import Any, Dict, Optional, List +from typing import Any, Dict, Optional, List, Union import numpy as np import pandas as pd @@ -50,7 +50,6 @@ def __len__(self) -> int: def __getitem__(self, idx: int) -> list: return self.dataset[idx]["messages"] - class SFTDataset(MegatronDataset): """The dataset used during SFT""" @@ -64,7 +63,7 @@ def __init__( config: GPTDatasetConfig, ) -> None: super().__init__(dataset, dataset_path, indices, num_samples, index_split, config) - # Pre-calculate padding divisor to avoid redundant computation in get_padding_size + # Pre-calculate padding divisor to avoid redundant computation in get_item self.padding_divisor = self._calculate_padding_divisor() @staticmethod @@ -112,19 +111,9 @@ def _calculate_padding_divisor(self) -> int: # TODO(tailaim): do we need to pad for FP8 execution? # divisor = ((divisor + 15) // 16) * 16 return divisor - - def get_padding_size( - self, - seq_len: int, - ) -> int: - seq_len_padded = math.ceil(seq_len / self.padding_divisor) * self.padding_divisor - assert seq_len > seq_len_padded / 2 / self.config.context_parallel_size * (self.config.context_parallel_size - 1), \ - f"sequence length {seq_len} is too short, the divisor is {self.padding_divisor}, that means cp_rank \ - {self.config.context_parallel_size-1} will have no valid tokens" - return seq_len_padded def __getitem__(self, idx: int) -> Dict[str, Any]: - sequence_packing = self.config.sequence_packing + tokenizer = self.config.tokenizer pack_length = self.config.sequence_length @@ -159,12 +148,11 @@ def extend_with_padding(tokens, targets, positions, pad_len): assert not self.config.reset_position_ids pack_positions.extend(range(len(tokens_list))) - if self.config.context_parallel_size > 1: - pad_granularity = self.config.context_parallel_size * 2 - mod_token_count = len(pack_tokens) % pad_granularity - if mod_token_count != 0: - pad_len = pad_granularity - mod_token_count - extend_with_padding(pack_tokens, pack_targets, pack_positions, pad_len) + pad_granularity = self.padding_divisor + mod_token_count = len(pack_tokens) % pad_granularity + if mod_token_count != 0: + pad_len = pad_granularity - mod_token_count + extend_with_padding(pack_tokens, pack_targets, pack_positions, pad_len) # TODO(duncan): Consider also padding to multiple of number of tokens here. This might # be needed for efficiency (and potentially set via command-line argument). @@ -299,64 +287,81 @@ def __len__(self) -> int: return self.num_samples def __getitem__(self, idx: int) -> Dict[str, Any]: - sequence_packing = self.config.sequence_packing + tokenizer = self.config.tokenizer - max_seq_len = self.config.sequence_length + pack_length = self.config.sequence_length + eod = tokenizer.eod + pad = tokenizer.pad tokens = self.dataset[int(self.indices[idx % len(self.indices)])] - target = np.array(tokens, dtype=np.int64) - - force_eod_length = int(tokenizer.force_eod) - - if len(tokens) > max_seq_len - force_eod_length: - # cut the right side - tokens = tokens[: max_seq_len - force_eod_length] - target = target[: max_seq_len - force_eod_length] - # tokens = tokens[(-max_seq_len + force_eod_length):] - # target = target[(-max_seq_len + force_eod_length):] - - # padding - num_tokens = len(tokens) + force_eod_length - if sequence_packing: - padding_len = self.get_padding_size(num_tokens) - num_tokens - else: - padding_len = max_seq_len - num_tokens - assert padding_len >= 0 - filler = [tokenizer.eod] * force_eod_length + [tokenizer.pad] * (padding_len + 1) - - tokens = np.array(tokens.tolist() + filler, dtype=np.int64) - target = np.array(target.tolist() + filler, dtype=np.int64) - - tokens = torch.tensor(tokens) - target = torch.tensor(target) - - tokens = tokens[:-1].contiguous() - target = target[1:].contiguous() - seq_len = tokens.numel() - - loss_mask, position_ids, attention_mask = self._get_ltor_masks_and_position_ids( - seq_len, target, tokenizer.pad - ) - - if self.config.create_attention_mask: - ret = { - 'tokens': tokens, - 'labels': target, - 'attention_mask': attention_mask, - 'loss_mask': loss_mask, - 'position_ids': position_ids, - } - else: - ret = { - 'tokens': tokens, - 'labels': target, - 'loss_mask': loss_mask, - 'position_ids': position_ids, - } - - if sequence_packing: - # sequence packing need both original sequence length and padded length - ret['original_seq_len'] = torch.tensor(num_tokens, dtype=torch.int32, device=tokens.device) - ret['padded_seq_len'] = torch.tensor(seq_len, dtype=torch.int32, device=tokens.device) - - return ret + + def extend_with_padding(tokens, targets, positions, pad_len): + tokens.extend([pad] * pad_len) + targets.extend([pad] * pad_len) + positions.extend(range(positions[-1] + 1, positions[-1] + 1 + pad_len)) + + # Convert tokens to list and add EOD + tokens_list = tokens.tolist() + if tokens_list[-1] != eod: + tokens_list.append(eod) + targets_list = list(tokens_list) + + pack_tokens = list(tokens_list) + pack_targets = list(targets_list) + pack_positions = list(range(len(tokens_list))) + cu_seqlens = [0] + + # Pad to padding_divisor alignment + if self.padding_divisor > 1: + mod_token_count = len(pack_tokens) % self.padding_divisor + if mod_token_count != 0: + pad_len = self.padding_divisor - mod_token_count + extend_with_padding(pack_tokens, pack_targets, pack_positions, pad_len) + + # Record padded boundary after padding + cu_seqlens.append(len(pack_tokens)) + + # Handle any necessary truncation + if len(pack_tokens) >= pack_length + 1: # +1 here to account for later alignment + max_body = pack_length - 1 + pack_tokens = pack_tokens[:max_body] + pack_targets = pack_targets[:max_body] + pack_tokens.extend([eod, pad]) + pack_targets.extend([eod, pad]) + pack_positions = pack_positions[:pack_length + 1] + cu_seqlens[-1] = len(pack_tokens) - 1 + + # Handle any necessary padding + if len(pack_tokens) < pack_length + 1: # +1 here to account for later alignment + pad_len = pack_length + 1 - len(pack_tokens) + extend_with_padding(pack_tokens, pack_targets, pack_positions, pad_len) + cu_seqlens[-1] = len(pack_tokens) - 1 + + assert len(pack_tokens) == pack_length + 1 + assert len(pack_targets) == pack_length + 1 + assert len(pack_positions) == pack_length + 1 + + # Align and convert to tensors + input_ids = torch.tensor(pack_tokens[:-1], dtype=torch.int64) + labels = torch.tensor(pack_targets[1:], dtype=torch.int64) + position_ids = torch.tensor(pack_positions[:-1], dtype=torch.int64) + + # Loss mask + loss_mask = torch.ones(pack_length, dtype=torch.float32) + loss_mask[labels == pad] = 0.0 + loss_mask[labels == IGNORE_INDEX] = 0.0 + + assert len(cu_seqlens) >= 2 + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32) + # Calculating max_seqlen here because of possible effects of truncation and padding + adjacent_diffs = cu_seqlens[1:] - cu_seqlens[:-1] + max_seqlen = adjacent_diffs.max() # max_seqlen is a 0-D tensor + + return { + 'tokens': input_ids, + 'labels': labels, + 'loss_mask': loss_mask, + 'position_ids': position_ids, + 'cu_seqlens': cu_seqlens, + 'max_seqlen': max_seqlen, + } \ No newline at end of file diff --git a/megatron/training/training.py b/megatron/training/training.py index 4b811528fde..a9be5c40469 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -168,7 +168,7 @@ def set_startup_timestamps(program_start=None, main_entry=None): get_num_microbatches, update_num_microbatches ) -from megatron.core.datasets.data_schedule import wrap_dataloader +from megatron.core.datasets.data_schedule import wrap_data_iterator from .async_utils import maybe_finalize_async_save from .utils import ( @@ -1731,13 +1731,21 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch if isinstance(optim_instance, DistributedOptimizer): optim_instance.release_offloaded_gpu_states() - if config.sequence_packing: + if config.sequence_packing_scheduler is not None: + # This wrapper is designed to support DP-balanced THD and dynamic-CP. + # Before wrapping, the data_iterator returns either a single sequence per get_item call, or a list where each element is a sequence. + # The wrapper is responsible for: + # 1. scheduling the sequences across ranks + # 2. packing them into THD format + # 3. broadcast flops parametes and num_microbatches to TP ranks to support unfixed num_microbatches + # 4. broadcast metadata(cu_seqlens, cu_seqlens_padded, max_seqlen, etc.) to PP ranks to + # 5. returning the packed data iterator and the FLOPs parameters ( data_iterator, num_microbatches, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch, - ) = wrap_dataloader(data_iterator, config, num_microbatches) + ) = wrap_data_iterator(data_iterator, config, get_num_microbatches()) else: # data_iterator unchanged num_microbatches = get_num_microbatches() @@ -2886,7 +2894,7 @@ def trace_handler(p): # Completely skip iteration if needed. if iteration in args.iterations_to_skip: # TODO(tailaim): this need to be modified - assert not args.sequence_packing, "Sequence packing is not supported in skip iteration mode" + assert config.sequence_packing_scheduler is None, "Sequence packing scheduler is not supported in skip iteration mode" # Dummy train_step to fast forward train_data_iterator. dummy_train_step(train_data_iterator) if iteration == start_iteration: diff --git a/pretrain_gpt.py b/pretrain_gpt.py index b2d7ce192c1..b9222a8c001 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -69,9 +69,10 @@ def get_batch(data_iterator, vp_stage: Optional[int] = None): args = get_args() config = core_transformer_config_from_args(args) - if args.sequence_packing: + if args.sequence_packing_scheduler is not None: return get_batch_on_this_rank_for_sequence_packing( data_iterator, + vpp_size=config.virtual_pipeline_model_parallel_size, mtp_on_this_rank=mtp_on_this_rank(config, ignore_virtual=False, vp_stage=vp_stage), vp_stage=vp_stage, ) @@ -261,7 +262,7 @@ def core_gpt_dataset_config_from_args(args): "sequence_parallel_size": args.tensor_model_parallel_size*args.sequence_parallel, "hybrid_context_parallel": args.hybrid_context_parallel, "sft_mock_dataset_config_json":args.sft_mock_dataset_config_json, - "sequence_packing": args.sequence_packing, + "sequence_packing_scheduler": args.sequence_packing_scheduler, } # add FIM args to the config diff --git a/tests/unit_tests/test_sequence_packing.py b/tests/unit_tests/test_sequence_packing.py index bac7ea79db1..288e66f2f93 100644 --- a/tests/unit_tests/test_sequence_packing.py +++ b/tests/unit_tests/test_sequence_packing.py @@ -1,12 +1,20 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +import random from types import SimpleNamespace +import numpy as np import pytest import torch from megatron.core import parallel_state -from megatron.core.datasets.data_schedule import get_batch_on_this_rank_for_sequence_packing +from megatron.core.datasets.data_schedule import ( + PackingScheduler, + get_batch_on_this_rank_for_sequence_packing, + scheduler_map, + wrap_data_iterator, +) +from megatron.core.rerun_state_machine import RerunDataIterator from megatron.training.global_vars import unset_global_variables from tests.unit_tests.test_utilities import Utils @@ -32,7 +40,6 @@ def __init__( total_seq_length: Total length of packed sequences sequence_lengths: List of individual sequence lengths (variable-length). If None, generates random variable lengths. - local_cp_size: Local CP size for hybrid context parallel device: Device to create tensors on seed: Random seed for reproducibility """ @@ -127,21 +134,19 @@ def _gather_tensor_from_all_ranks(tensor): @pytest.mark.parametrize( - ("tp", "pp", "cp", "hybrid_cp"), + ("tp", "pp", "cp"), [ - (1, 1, 1, False), # Basic case: no parallelism - (2, 1, 1, False), # Tensor parallel only - (1, 2, 1, False), # Pipeline parallel only - (2, 2, 1, False), # TP + PP - (1, 1, 2, False), # CP only - (2, 1, 2, False), # TP + CP - (1, 2, 2, False), # PP + CP - (1, 4, 1, False), # Has middle pp stage - (1, 1, 1, True), # Hybrid CP enabled (CP=1 with hybrid groups) - (2, 1, 1, True), # TP + Hybrid CP + (1, 1, 1), # Basic case: no parallelism + (2, 1, 1), # Tensor parallel only + (1, 2, 1), # Pipeline parallel only + (2, 2, 1), # TP + PP + (1, 1, 2), # CP only + (2, 1, 2), # TP + CP + (1, 2, 2), # PP + CP + (1, 4, 1), # Has middle pp stage ], ) -def test_get_batch_on_this_rank_for_sequence_packing(tp, pp, cp, hybrid_cp): +def test_get_batch_on_this_rank_for_sequence_packing(tp, pp, cp): """ Test get_batch_on_this_rank_for_sequence_packing function with variable-length THD format. @@ -155,7 +160,6 @@ def test_get_batch_on_this_rank_for_sequence_packing(tp, pp, cp, hybrid_cp): args.tensor_model_parallel_size = tp args.pipeline_model_parallel_size = pp args.context_parallel_size = cp - args.hybrid_context_parallel = hybrid_cp args.virtual_pipeline_model_parallel_size = None args.data_parallel_size = 8 // (tp * pp * cp) args.seq_length = 8192 @@ -165,21 +169,12 @@ def test_get_batch_on_this_rank_for_sequence_packing(tp, pp, cp, hybrid_cp): raise ValueError(f"Invalid config: tp={tp}, pp={pp}, cp={cp} exceeds world size 8") # Initialize model parallel - Utils.initialize_model_parallel( - tp, - pp, - None, - context_parallel_size=cp, - hybrid_context_parallel=hybrid_cp, - min_hybrid_context_parallel_size=1, - ) + Utils.initialize_model_parallel(tp, pp, None, context_parallel_size=cp) try: # Create mock data iterator with variable-length sequences # Only TP rank 0 needs the iterator; other TP ranks pass None tp_rank = parallel_state.get_tensor_model_parallel_rank() - local_cp_size = 8 // (tp * pp) if hybrid_cp else None - if tp_rank == 0: # Use deterministic seed based on DP rank so same data within TP/PP/CP group dp_rank = parallel_state.get_data_parallel_rank() @@ -191,7 +186,6 @@ def test_get_batch_on_this_rank_for_sequence_packing(tp, pp, cp, hybrid_cp): MockVariableLengthSequencePackingDataIterator( total_seq_length=args.seq_length, sequence_lengths=sequence_lengths, # Variable lengths, sum=8192 - local_cp_size=local_cp_size, seed=42 + dp_rank, # Same seed within PP/CP group ) ) @@ -201,10 +195,7 @@ def test_get_batch_on_this_rank_for_sequence_packing(tp, pp, cp, hybrid_cp): # Call the function under test result = get_batch_on_this_rank_for_sequence_packing( - data_iterator=data_iterator, - mtp_on_this_rank=False, - vp_stage=None, - hybrid_context_parallel=hybrid_cp, + data_iterator=data_iterator, mtp_on_this_rank=False, vp_stage=None ) # Unpack the result @@ -248,9 +239,6 @@ def test_get_batch_on_this_rank_for_sequence_packing(tp, pp, cp, hybrid_cp): # ===================================================================== assert packed_seq_params is not None assert packed_seq_params.qkv_format == "thd" - if hybrid_cp: - assert packed_seq_params.local_cp_size is not None - assert packed_seq_params.cp_group is not None test_keys = [ "cu_seqlens_q", @@ -260,8 +248,6 @@ def test_get_batch_on_this_rank_for_sequence_packing(tp, pp, cp, hybrid_cp): "cu_seqlens_kv_padded", "max_seqlen_kv", ] - if hybrid_cp: - test_keys.append("local_cp_size") for key in test_keys: tensor = getattr(packed_seq_params, key) assert tensor is not None @@ -291,18 +277,9 @@ def test_get_batch_on_this_rank_for_sequence_packing(tp, pp, cp, hybrid_cp): # ===================================================================== # TEST 4: Verify CP partitioning # ===================================================================== - if cp > 1 or hybrid_cp: - if hybrid_cp: - assert packed_seq_params.local_cp_size is not None - cp_size = packed_seq_params.local_cp_size - assert packed_seq_params.cp_group == ( - parallel_state.get_hybrid_data_context_parallel_groups(group_size=cp_size) - ) - else: - cp_size = cp - + if cp > 1: # With CP, the sequence should be partitioned - expected_seq_len = args.seq_length // cp_size + expected_seq_len = args.seq_length // cp if is_first_stage: actual_seq_len = tokens.shape[1] @@ -320,3 +297,185 @@ def test_get_batch_on_this_rank_for_sequence_packing(tp, pp, cp, hybrid_cp): finally: Utils.destroy_model_parallel() unset_global_variables() + + +@pytest.mark.parametrize( + ("tp", "pp", "cp", "vpp", "scheduler_type"), + [ + (1, 1, 8, None, "dp_balanced"), + (2, 1, 4, None, "dp_balanced"), + (2, 4, 1, None, "dp_balanced"), + (2, 2, 1, None, "dp_balanced"), + (1, 4, 1, 4, "dp_balanced"), + ], +) +def test_wrap_dataloader(tp, pp, cp, vpp, scheduler_type): + ''' + Test wrap_dataloader function with different scheduler types. + ''' + args = SimpleNamespace() + args.tensor_model_parallel_size = tp + args.pipeline_model_parallel_size = pp + args.context_parallel_size = cp + args.virtual_pipeline_model_parallel_size = None + args.data_parallel_size = 8 // (tp * pp * cp) + args.seq_length = 8192 + args.max_seqlen_per_dp_cp_rank = 8192 + + # Skip invalid configurations + if args.data_parallel_size < 1: + raise ValueError(f"Invalid config: tp={tp}, pp={pp}, cp={cp} exceeds world size 8") + + def _create_single_sample(seq_len): + # hard code the padding size to 16 + pad_size = 16 + seq_len_padded = ((seq_len + pad_size - 1) // pad_size) * pad_size + device = torch.device("cuda", torch.cuda.current_device()) + tokens = torch.randint(0, 128, (seq_len_padded,), dtype=torch.int64, device=device) + labels = tokens + 1 + position_ids = torch.arange(seq_len_padded, dtype=torch.int64, device=device) + loss_mask = torch.ones(seq_len_padded, dtype=torch.float32, device=device) + loss_mask[0:seq_len] = 1 + loss_mask[seq_len:] = 0 + cu_seqlens = torch.tensor([0, seq_len_padded], dtype=torch.int32, device=device) + + return { + 'tokens': tokens, + 'labels': labels, + 'loss_mask': loss_mask, + 'position_ids': position_ids, + 'cu_seqlens': cu_seqlens, + } + + # Initialize model parallel + Utils.initialize_model_parallel(tp, pp, vpp, context_parallel_size=cp) + + global_batch_size = 64 + micro_batch_size = 1 + nums = [random.randint(2048, args.seq_length) for _ in range(global_batch_size)] # 64 sequences + + config = SimpleNamespace() + config.max_seqlen_per_dp_cp_rank = args.max_seqlen_per_dp_cp_rank + config.microbatch_group_size_per_vp_stage = pp + config.virtual_pipeline_model_parallel_size = vpp + config.sequence_packing_scheduler = scheduler_type + + dp_rank = parallel_state.get_data_parallel_rank() + dp_size = parallel_state.get_data_parallel_world_size() + + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + tp_rank = parallel_state.get_tensor_model_parallel_rank() + + is_pp_first = pp_rank == 0 + is_pp_last = pp_rank == pp - 1 + is_pp_first_or_last = is_pp_first or is_pp_last + is_tp_first = tp_rank == 0 + + num_micro_batches_old = global_batch_size // micro_batch_size // dp_size + + if is_tp_first and (is_pp_first or is_pp_last): + samples = [ + _create_single_sample(num) + for num in nums[dp_rank * num_micro_batches_old : (dp_rank + 1) * num_micro_batches_old] + ] + data_iterator = RerunDataIterator(iter(samples)) + else: + data_iterator = None + + if is_tp_first: + if vpp is not None and vpp > 1: + if is_pp_first: + data_iterator = [data_iterator] + [None for _ in range(vpp - 1)] + elif is_pp_last: + data_iterator = [None for _ in range(vpp - 1)] + [data_iterator] + else: + data_iterator = [None for _ in range(vpp)] + try: + # Call the function under test + ( + new_data_iterator, + num_micro_batches, + num_total_tokens_this_global_batch, + sequence_square_sum_this_global_batch, + ) = wrap_data_iterator(data_iterator, config, num_micro_batches_old) + + # check the result + assert type(num_micro_batches) is int + assert ( + type(num_total_tokens_this_global_batch) is float + or type(num_total_tokens_this_global_batch) is np.float32 + ) + assert ( + type(sequence_square_sum_this_global_batch) is float + or type(sequence_square_sum_this_global_batch) is np.float32 + ) + + def _check_batch(batch_all, batch_keys): + for batch in batch_all: + assert set(batch_keys) <= set( + batch.keys() + ), f"batch keys: {set(batch.keys())} missing {set(batch_keys) - set(batch.keys())}" + for key in batch_keys: + assert batch[key] is not None + + if is_tp_first: + # CHECK KEYS + batch_keys = ["cu_seqlens", "max_seqlen", "cu_seqlens_padded"] + if vpp is not None and vpp > 1: + # check metadata for all stages (save batches to avoid re-consuming iterators) + all_stage_batches = [] + for temp_data_iterator in new_data_iterator: + stage_batch = [next(temp_data_iterator) for _ in range(num_micro_batches)] + all_stage_batches.append(stage_batch) + _check_batch(stage_batch, batch_keys) + + # check for first or last stage on first or last pp rank + if is_pp_first_or_last: + batch_all = all_stage_batches[0] if is_pp_first else all_stage_batches[-1] + batch_keys += ["tokens", "position_ids", "labels", "loss_mask"] + _check_batch(batch_all, batch_keys) + else: + # non-VPP: single iterator + batch_all = [next(new_data_iterator) for _ in range(num_micro_batches)] + if is_pp_first_or_last: + batch_keys += ["tokens", "position_ids", "labels", "loss_mask"] + _check_batch(batch_all, batch_keys) + + # CHECK TOKEN SUM ON FIRST OR LAST PP RANK + # Note: data_iterator is consumed by wrap_data_iterator, new_data_iterator is consumed above. + # Use `samples` for before-wrap, reuse `batch_all` from the check above for after-wrap. + if is_pp_first_or_last: + # Compute token sum before wrap + token_sum_before = torch.tensor(0, dtype=torch.int64, device='cuda') + for sample in samples: + token_sum_before += sample['tokens'].long().sum() + + # Compute token sum after wrap (batch_all already collected above with tokens) + token_sum_after = torch.tensor(0, dtype=torch.int64, device='cuda') + for batch in batch_all: + token_sum_after += batch['tokens'].long().sum() + + # Reduce sum across dp_cp group and verify equality + dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=False) + torch.distributed.all_reduce( + token_sum_before, op=torch.distributed.ReduceOp.SUM, group=dp_cp_group + ) + torch.distributed.all_reduce( + token_sum_after, op=torch.distributed.ReduceOp.SUM, group=dp_cp_group + ) + + assert ( + token_sum_before == token_sum_after + ), f"Token sum mismatch: before={token_sum_before.item()}, after={token_sum_after.item()}" + + else: + if vpp is not None and vpp > 1: + assert type(new_data_iterator) is list and len(new_data_iterator) == vpp + for data_iterator in new_data_iterator: + assert data_iterator is None + else: + assert new_data_iterator is None + + finally: + Utils.destroy_model_parallel() + unset_global_variables() From 787d08e903bae5bea03421cd8276cc8a986e833b Mon Sep 17 00:00:00 2001 From: xiaoyao0115 <1804647152@qq.com> Date: Thu, 12 Feb 2026 05:42:09 -0800 Subject: [PATCH 3/9] small fixes according to comments Signed-off-by: xiaoyao0115 <1804647152@qq.com> --- megatron/core/datasets/data_schedule.py | 20 ++++----- ...eduler_utils.py => data_schedule_utils.py} | 8 ++-- megatron/core/datasets/readme.md | 22 +++++++++ megatron/core/model_parallel_config.py | 2 +- .../text/libraries/sft_tokenizer.py | 1 + .../core/transformer/transformer_config.py | 2 +- megatron/training/arguments.py | 5 --- megatron/training/datasets/sft_dataset.py | 45 ++++++++++--------- .../unit_tests/models/test_mamba_moe_model.py | 1 + tests/unit_tests/test_sequence_packing.py | 2 - 10 files changed, 64 insertions(+), 44 deletions(-) rename megatron/core/datasets/{data_scheduler_utils.py => data_schedule_utils.py} (98%) diff --git a/megatron/core/datasets/data_schedule.py b/megatron/core/datasets/data_schedule.py index 76e53c6fe37..deef93b8ebd 100644 --- a/megatron/core/datasets/data_schedule.py +++ b/megatron/core/datasets/data_schedule.py @@ -6,7 +6,7 @@ import torch from megatron.core import parallel_state -from megatron.core.datasets.data_scheduler_utils import ( +from megatron.core.datasets.data_schedule_utils import ( broadcast_scalars, broadcast_tensor, broadcast_to_pp_group, @@ -312,7 +312,7 @@ def __next__(self) -> Any: return samples_this_rank_with_id, sample_id_groups -class BaseScheduler: +class BasePackingScheduler: """Base class for sequence packing schedulers.""" def __init__( @@ -335,7 +335,7 @@ def __init__( self.dp_size = dp_size self.microbatch_group_size_per_vp_stage = microbatch_group_size_per_vp_stage - def get_require_sample_keys(self): + def get_required_sample_keys(self): """Return the required key of each batch.""" raise NotImplementedError @@ -376,14 +376,14 @@ def run( raise NotImplementedError -class DpBalancedScheduler(BaseScheduler): +class DpBalancedScheduler(BasePackingScheduler): """Packs sequences in their original order until reaching the max limit of sequence length.""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.max_seq_len_all_ranks = self.max_seqlen_per_dp_cp_rank * self.cp_size - def get_require_sample_keys(self): + def get_required_sample_keys(self): """Return the required key of each batch.""" return [ "tokens", @@ -521,7 +521,7 @@ def run( ) # Step 2: Check required sample keys - for key in self.get_require_sample_keys(): + for key in self.get_required_sample_keys(): assert ( key in batch[0] ), f"Batch missing required key {key}, provided keys: {batch[0].keys()}" @@ -620,14 +620,14 @@ def run( ) -class PackingScheduler(enum.Enum): +class PackingSchedulerEnum(enum.Enum): """Enum for supported sequence packing algorithms.""" DP_BALANCED = "dp_balanced" -scheduler_map: Dict[PackingScheduler, Type[BaseScheduler]] = { - PackingScheduler.DP_BALANCED: DpBalancedScheduler +scheduler_map: Dict[PackingSchedulerEnum, Type[BasePackingScheduler]] = { + PackingSchedulerEnum.DP_BALANCED: DpBalancedScheduler } @@ -668,7 +668,7 @@ def wrap_data_iterator( # Convert string to enum scheduler_type = config.sequence_packing_scheduler - scheduler_type = PackingScheduler[scheduler_type.upper()] + scheduler_type = PackingSchedulerEnum[scheduler_type.upper()] scheduler = scheduler_map[scheduler_type]( config.max_seqlen_per_dp_cp_rank, diff --git a/megatron/core/datasets/data_scheduler_utils.py b/megatron/core/datasets/data_schedule_utils.py similarity index 98% rename from megatron/core/datasets/data_scheduler_utils.py rename to megatron/core/datasets/data_schedule_utils.py index 2a1e2a6528d..2e5635dfcd8 100644 --- a/megatron/core/datasets/data_scheduler_utils.py +++ b/megatron/core/datasets/data_schedule_utils.py @@ -8,7 +8,7 @@ from megatron.core.rerun_state_machine import RerunDataIterator -def _unpack_batch(batch): +def _unpack_batch(batch: List[Dict[str, torch.Tensor]]) -> List[Dict[str, torch.Tensor]]: """ Unpacks the packed samples into a list of sub-samples. Since each sub-sample may be routed to different DPxCP ranks, @@ -474,9 +474,9 @@ def get_batch_and_global_seqlens(data_iterator, num_microbatches, dp_group): else: raise ValueError(f"Invalid item type: {type(item)}") - # This unpack step is redundant: in sft_dataset.py, sequences are already packed before - # rescheduling, so we need to unpack them here and repack after rescheduling. This is only - # to adapt to the current megatron-lm sft_dataset. + # in sft_dataset.py, sequences are already packed before rescheduling, + # so we need to unpack them here and repack after rescheduling. + # This is only to adapt to the current megatron-lm sft_dataset. # If you implement your own dataset, just have __getitem__ return List[Dict] # and this step can be skipped. batch = _unpack_batch(batch) diff --git a/megatron/core/datasets/readme.md b/megatron/core/datasets/readme.md index 452bf24e4a2..64889feb481 100644 --- a/megatron/core/datasets/readme.md +++ b/megatron/core/datasets/readme.md @@ -192,6 +192,28 @@ To query the `BlendedDataset` for the _k_-th sample we do the following To save time during initialization, each index is built/cached sequentially on one process rank and subsequently loaded in parallel on other process ranks. The cached indices are unique to a hash generated in the `BlendedDataset.__init__` function. +## Packing Scheduler + +The packing scheduler re-schedules variable-length sequences across DP×CP ranks to improve GPU utilization. It is built around the following modules: + +### `data_schedule` + +This module contains the high-level scheduling logic and entry points: + +- **`HybridCPDataLoaderWrapper`**: A wrapper class for hybrid context parallel (CP) scheduling. For every `__next__` call, it: (1) pulls a batch of packed samples from each DP rank, (2) gathers sequence lengths across the DP group, (3) schedules sub-samples using the `BalancedCPScheduler`, (4) reroutes sub-samples to the correct DPxCP ranks via all-to-all communication. + +- **`BasePackingScheduler`**: Abstract base class for packing schedulers. Defines the interface for `get_groups_and_subsamples()` (scheduling algorithm) and `run()` (full scheduling pipeline including fetch, schedule, reroute, pack, broadcast, and VPP handling). + +- **`DpBalancedScheduler`**: A concrete scheduler that packs sequences in their original order until reaching the max sequence length limit per DPxCP rank. Supports aligning the number of microbatches to DP size and VPP stage multiples. + +- **`wrap_data_iterator()`**: Top-level entry point that wraps an existing `data_iterator`. It creates the appropriate scheduler, runs the scheduling pipeline, broadcast metadata and new num_microbatches, returns a new data iterator along with the updated number of microbatches and FLOPs statistics. + +- **`get_batch_on_this_rank_for_sequence_packing()`**: Fetches and broadcasts a single packed microbatch for the current rank. Handles TP/PP broadcasting, constructs `PackedSeqParams` (with `cu_seqlens`, `max_seqlen`, `qkv_format=thd`), and optionally partitions sequences across CP ranks using Transformer Engine's `thd_get_partitioned_indices`. + +### `data_schedule_utils.py` + +This module contains the utility functions used by the schedulers. + ## Fast DataLoader initialization Especially for large-scale runs, DataLoader initialization can take several minutes, since it involves opening and memory-mapping multiple files and can significantly stress the filesystem. To speed up this process, we have developed the following three optimizations, controlled by configuration flags": diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index 442f5a31e3c..970b3b871fe 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -72,7 +72,7 @@ class ModelParallelConfig: Please set max_seqlen_per_dp_cp_rank when using hybrid_context_parallel. """ - sequence_packing_scheduler: Optional[str] = None + sequence_packing_scheduler: Optional[Literal['dp_balanced']] = None """ Scheduler for sequence packing and hybrid context parallel. dp_balanced: DP-balanced scheduler for sequence packing. diff --git a/megatron/core/tokenizers/text/libraries/sft_tokenizer.py b/megatron/core/tokenizers/text/libraries/sft_tokenizer.py index 8a418f2dd7f..b9a4f7b0e4b 100644 --- a/megatron/core/tokenizers/text/libraries/sft_tokenizer.py +++ b/megatron/core/tokenizers/text/libraries/sft_tokenizer.py @@ -108,6 +108,7 @@ def __init__(self, tokenizer_path: str, prompt_format: str): self._prompt_format = prompt_format + def tokenize_conversation( self, conversation: List[Dict], return_target: bool, add_generation_prompt: bool ): diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 71508401657..d48e29c1e71 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -2106,7 +2106,7 @@ def __post_init__(self): and self.sequence_packing_scheduler not in supported_schedulers ): raise ValueError( - f"Unknown scheduler: {self.sequence_packing_scheduler}. " + f"Unsupported scheduler: {self.sequence_packing_scheduler}. " f"Available schedulers: {supported_schedulers}" ) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 56c3c077d46..3326eebf833 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1065,9 +1065,6 @@ def validate_args(args, defaults={}): assert args.context_parallel_size * args.max_seqlen_per_dp_cp_rank >= args.seq_length, \ f'Packed sequence buffer size ({args.context_parallel_size * args.max_seqlen_per_dp_cp_rank}) ' \ f'must be >= single sequence max length ({args.seq_length})' - # TODO(tailaim): add support for other dispatcher types - print(f"Setting moe_token_dispatcher_type to alltoall for sft sequence packing with pipeline parallelism") - args.moe_token_dispatcher_type = "alltoall" if args.mock_data and args.sft_mock_dataset_config_json is None: args.sft_mock_dataset_config_json = json.dumps( { @@ -3084,8 +3081,6 @@ def _add_sft_args(parser): group.add_argument('--sft', action="store_true", help='Megatron SFT training') group.add_argument('--sft-tokenizer-prompt-format', type=str, default="nemotron-h-aligned", help='SFT prompt format.') - group.add_argument('--sequence-packing', action='store_true', - help='use sequence packing(thd format) for training') group.add_argument('--sft-mock-dataset-config-json', type=str, default=None, help='This config provides the necessary information for the mock dataset. You can either specify a CSV file that contains sequence lengths, where each line stores the length of a sequence, for example: {"mode":"file","path":"/path/to/file"}. Alternatively, you can specify a distribution (currently only supporting lognormal distribution) along with the required parameters, for example, {"mode":"distribution","type":"lognormal","min_seq_len":1024,"max_seq_len":2048,"mean_seq_len":1536,"lognormal_sigma":1.1}, where sigma controls the variability of the lognormal distribution. ' 'If not specified and --mock-data is set, defaults to a lognormal distribution with ' diff --git a/megatron/training/datasets/sft_dataset.py b/megatron/training/datasets/sft_dataset.py index 1672ad75310..818f3383377 100644 --- a/megatron/training/datasets/sft_dataset.py +++ b/megatron/training/datasets/sft_dataset.py @@ -218,7 +218,11 @@ class MockSFTLowLevelDataset: """The low-level mock dataset for SFT Args: - mock_config (dict): The config for mock dataset. + mode (str): Either 'file' or 'distribution'. + **kwargs: Additional arguments depending on mode. + For mode='file': path (str) - path to a CSV file with sequence lengths. + For mode='distribution': type (str), min_seq_len (int), max_seq_len (int), + mean_seq_len (int), and distribution-specific params (e.g. lognormal_sigma). """ seed: int = 0 @@ -226,28 +230,26 @@ class MockSFTLowLevelDataset: size: int = 1000000 """The hard-coded number of sequence to generate""" - - # This is to maintain consistency with the SFT dataset that uses real data. In the real dataset, an element in the low-level dataset often contains multiple sequences. So here, each element in the mock low-level dataset also contains num_sequence_per_sample sequences. This will be made more reasonable in the future. - - def __init__(self, config: Dict) -> None: + def __init__(self, mode: str, **kwargs) -> None: np.random.seed(self.seed) - # either choose to load sequence lengths from external file, or generate random sequence lengths - - assert "mode" in config, f"mode must be set, either 'file' or 'distribution'" - - if config["mode"] == "file": - self.sequence_lengths = np.array(pd.read_csv(config["path"])).flatten() + + if mode == "file": + self.sequence_lengths = np.array(pd.read_csv(kwargs["path"])).flatten() self.size = len(self.sequence_lengths) - elif config["mode"] == "distribution": - min_seq_len = config["min_seq_len"] - max_seq_len = config["max_seq_len"] - mean_seq_len = config["mean_seq_len"] - if config["type"] == "lognormal": - lognormal_sigma = config["lognormal_sigma"] - self.sequence_lengths = self.generate_lognormal_samples(self.size, mean_seq_len,lognormal_sigma, min_seq_len, max_seq_len) + elif mode == "distribution": + min_seq_len = kwargs["min_seq_len"] + max_seq_len = kwargs["max_seq_len"] + mean_seq_len = kwargs["mean_seq_len"] + if kwargs["type"] == "lognormal": + lognormal_sigma = kwargs["lognormal_sigma"] + self.sequence_lengths = self.generate_lognormal_samples( + self.size, mean_seq_len, lognormal_sigma, min_seq_len, max_seq_len + ) else: - raise ValueError(f"Unsupported sequence length distribution type {config['type']}") + raise ValueError(f"Unsupported distribution type {kwargs['type']}") + else: + raise ValueError(f"Unsupported mode '{mode}', must be 'file' or 'distribution'") def generate_lognormal_samples(self, size, mean, sigma, min_seq_len, max_seq_len): mu = np.log(mean) - sigma**2 / 2 @@ -259,9 +261,10 @@ def __len__(self) -> int: return self.size def __getitem__(self, idx: int) -> List[np.ndarray]: - length = self.sequence_lengths[idx % self.size] # the length of sample is 'length', but only length-1 elements are generated here, # because an eod token will be appended at the end later in SFTDataset + + length = self.sequence_lengths[idx % self.size] sample = np.arange(1, length, dtype=np.int64) return sample class MockSFTDataset(SFTDataset): @@ -281,7 +284,7 @@ def __init__( @staticmethod def build_low_level_dataset(dataset_path: str, config: GPTDatasetConfig) -> LowLevelDataset: mock_config = json.loads(config.sft_mock_dataset_config_json) - return MockSFTLowLevelDataset(mock_config) + return MockSFTLowLevelDataset(**mock_config) def __len__(self) -> int: return self.num_samples diff --git a/tests/unit_tests/models/test_mamba_moe_model.py b/tests/unit_tests/models/test_mamba_moe_model.py index 39b4a18e243..9797f5c20f7 100644 --- a/tests/unit_tests/models/test_mamba_moe_model.py +++ b/tests/unit_tests/models/test_mamba_moe_model.py @@ -275,6 +275,7 @@ "offload_modules": [], "hybrid_context_parallel": False, "max_seqlen_per_dp_cp_rank": None, + "sequence_packing_scheduler": None, "fallback_to_eager_attn": False, "linear_attention_type": None, "moe_router_force_biased": None, diff --git a/tests/unit_tests/test_sequence_packing.py b/tests/unit_tests/test_sequence_packing.py index 288e66f2f93..60316b0236e 100644 --- a/tests/unit_tests/test_sequence_packing.py +++ b/tests/unit_tests/test_sequence_packing.py @@ -9,9 +9,7 @@ from megatron.core import parallel_state from megatron.core.datasets.data_schedule import ( - PackingScheduler, get_batch_on_this_rank_for_sequence_packing, - scheduler_map, wrap_data_iterator, ) from megatron.core.rerun_state_machine import RerunDataIterator From fc75657faf82c464398bc48b6b47259c84064908 Mon Sep 17 00:00:00 2001 From: tailaim Date: Sat, 14 Feb 2026 05:25:44 -0800 Subject: [PATCH 4/9] minor fixes Signed-off-by: tailaim --- megatron/core/datasets/data_schedule.py | 15 ++------- megatron/core/datasets/gpt_dataset.py | 5 --- megatron/core/pipeline_parallel/schedules.py | 9 ------ megatron/training/arguments.py | 16 +--------- megatron/training/datasets/sft_dataset.py | 32 ++++++++++++++------ megatron/training/training.py | 31 ++++++++++++++++--- pretrain_gpt.py | 1 - 7 files changed, 52 insertions(+), 57 deletions(-) diff --git a/megatron/core/datasets/data_schedule.py b/megatron/core/datasets/data_schedule.py index deef93b8ebd..3039980e0aa 100644 --- a/megatron/core/datasets/data_schedule.py +++ b/megatron/core/datasets/data_schedule.py @@ -1,6 +1,5 @@ # Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved. -import enum from typing import Any, Dict, List, Optional, Type import torch @@ -69,6 +68,7 @@ def get_global_seqlens(self, subsample_seqlens: torch.Tensor) -> List[int]: Gathers the sequence lengths of all subsamples from all DP ranks. Each DP rank loads the same number of microbatches but each microbatch may have a different number of subsamples. + We find the number of subsamples each rank holds and then gather the sequence lengths of all subsamples from all ranks. """ @@ -620,15 +620,7 @@ def run( ) -class PackingSchedulerEnum(enum.Enum): - """Enum for supported sequence packing algorithms.""" - - DP_BALANCED = "dp_balanced" - - -scheduler_map: Dict[PackingSchedulerEnum, Type[BasePackingScheduler]] = { - PackingSchedulerEnum.DP_BALANCED: DpBalancedScheduler -} +scheduler_map: Dict[str, Type[BasePackingScheduler]] = {"dp_balanced": DpBalancedScheduler} def wrap_data_iterator( @@ -666,9 +658,8 @@ def wrap_data_iterator( dp_size = dp_group.size() cp_size = dp_cp_group.size() // dp_size - # Convert string to enum + # Look up the scheduler class by name scheduler_type = config.sequence_packing_scheduler - scheduler_type = PackingSchedulerEnum[scheduler_type.upper()] scheduler = scheduler_map[scheduler_type]( config.max_seqlen_per_dp_cp_rank, diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py index bea7f4a4c18..04d2c279818 100644 --- a/megatron/core/datasets/gpt_dataset.py +++ b/megatron/core/datasets/gpt_dataset.py @@ -82,11 +82,6 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig): sft_mock_dataset_config_json: Optional[str] = None """This config provides the necessary information for the mock dataset.""" - sequence_packing_scheduler: Optional[str] = None - """Scheduler for sequence packing and hybrid context parallel. - dp_balanced: DP-balanced scheduler for sequence packing. - """ - def __post_init__(self) -> None: """Do asserts and set fields post init""" super().__post_init__() diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 23727f9c751..e903f392bf0 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -554,9 +554,6 @@ def forward_backward_no_pipelining( pg_collection.dp_cp = parallel_state.get_data_parallel_group( with_context_parallel=True, partial_data_parallel=False ) - pg_collection.dp = parallel_state.get_data_parallel_group( - with_context_parallel=False, partial_data_parallel=False - ) elif pg_collection is not None: assert hasattr(pg_collection, 'tp') @@ -1014,9 +1011,6 @@ def forward_backward_pipelining_with_interleaving( pg_collection.dp_cp = parallel_state.get_data_parallel_group( with_context_parallel=True, partial_data_parallel=False ) - pg_collection.dp = parallel_state.get_data_parallel_group( - with_context_parallel=False, partial_data_parallel=False - ) elif p2p_communicator is not None and pg_collection is not None: model_type = get_model_type(model[0]) @@ -2162,9 +2156,6 @@ def forward_backward_pipelining_without_interleaving( pg_collection.dp_cp = parallel_state.get_data_parallel_group( with_context_parallel=True, partial_data_parallel=False ) - pg_collection.dp = parallel_state.get_data_parallel_group( - with_context_parallel=False, partial_data_parallel=False - ) elif p2p_communicator is not None and pg_collection is not None: model_type = get_model_type(model) assert hasattr(p2p_communicator, 'config'), "p2p_communicator must have a config" diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 3326eebf833..1e7db76526e 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1065,17 +1065,6 @@ def validate_args(args, defaults={}): assert args.context_parallel_size * args.max_seqlen_per_dp_cp_rank >= args.seq_length, \ f'Packed sequence buffer size ({args.context_parallel_size * args.max_seqlen_per_dp_cp_rank}) ' \ f'must be >= single sequence max length ({args.seq_length})' - if args.mock_data and args.sft_mock_dataset_config_json is None: - args.sft_mock_dataset_config_json = json.dumps( - { - "mode": "distribution", - "type": "lognormal", - "min_seq_len": args.seq_length // 2, - "max_seq_len": args.seq_length, - "mean_seq_len": args.seq_length // 4 * 3, - "lognormal_sigma": 1.1, - } - ) # disable async_tensor_model_parallel_allreduce when # model parallel memory optimization is enabled @@ -1831,9 +1820,6 @@ def _add_network_size_args(parser): "persist_layer_norm", "bias_dropout_fusion", "apply_rope_fusion", - "max_seqlen_per_dp_cp_rank", - "hybrid_context_parallel", - "sequence_packing_scheduler", ] transformer_factory = ArgumentGroupFactory(TransformerConfig, exclude=exclude) transformer_group = transformer_factory.build_group(parser, "transformer configuration") @@ -2535,7 +2521,7 @@ def _add_distributed_args(parser): 'all layers will share the same communication type. Users can also ' 'specify separated types for each layer like ' '--cp-comm-type p2p p2p a2a a2a a2a+p2p a2a+p2p') - group.add_argument('--sequence-packing-scheduler', type=str, default='default_sequence_packing', choices=['default_sequence_packing']) + group.add_argument('--sequence-packing-scheduler', type=str, default=None, choices=['dp_balanced']) group.add_argument('--fake-process-group', action='store_true', default=False, help='If set, initialize with fake distributed process group and all distributed communication operations will be skipped. \ This is quite useful for profiling memory usage of distributed training with just one GPU. \ diff --git a/megatron/training/datasets/sft_dataset.py b/megatron/training/datasets/sft_dataset.py index 818f3383377..60266f9a83f 100644 --- a/megatron/training/datasets/sft_dataset.py +++ b/megatron/training/datasets/sft_dataset.py @@ -50,6 +50,7 @@ def __len__(self) -> int: def __getitem__(self, idx: int) -> list: return self.dataset[idx]["messages"] + class SFTDataset(MegatronDataset): """The dataset used during SFT""" @@ -63,8 +64,6 @@ def __init__( config: GPTDatasetConfig, ) -> None: super().__init__(dataset, dataset_path, indices, num_samples, index_split, config) - # Pre-calculate padding divisor to avoid redundant computation in get_item - self.padding_divisor = self._calculate_padding_divisor() @staticmethod def numel_low_level_dataset(low_level_dataset: LowLevelDataset) -> int: @@ -148,7 +147,7 @@ def extend_with_padding(tokens, targets, positions, pad_len): assert not self.config.reset_position_ids pack_positions.extend(range(len(tokens_list))) - pad_granularity = self.padding_divisor + pad_granularity = self._calculate_padding_divisor() mod_token_count = len(pack_tokens) % pad_granularity if mod_token_count != 0: pad_len = pad_granularity - mod_token_count @@ -214,6 +213,7 @@ def extend_with_padding(tokens, targets, positions, pad_len): 'max_seqlen': max_seqlen, } + class MockSFTLowLevelDataset: """The low-level mock dataset for SFT @@ -267,6 +267,8 @@ def __getitem__(self, idx: int) -> List[np.ndarray]: length = self.sequence_lengths[idx % self.size] sample = np.arange(1, length, dtype=np.int64) return sample + + class MockSFTDataset(SFTDataset): """The mock dataset used during SFT""" @@ -283,7 +285,17 @@ def __init__( @staticmethod def build_low_level_dataset(dataset_path: str, config: GPTDatasetConfig) -> LowLevelDataset: - mock_config = json.loads(config.sft_mock_dataset_config_json) + if config.sft_mock_dataset_config_json is None: + mock_config = { + "mode": "distribution", + "type": "lognormal", + "min_seq_len": config.sequence_length // 2, + "max_seq_len": config.sequence_length, + "mean_seq_len": config.sequence_length // 4 * 3, + "lognormal_sigma": 1.1, + } + else: + mock_config = json.loads(config.sft_mock_dataset_config_json) return MockSFTLowLevelDataset(**mock_config) def __len__(self) -> int: @@ -315,11 +327,11 @@ def extend_with_padding(tokens, targets, positions, pad_len): cu_seqlens = [0] # Pad to padding_divisor alignment - if self.padding_divisor > 1: - mod_token_count = len(pack_tokens) % self.padding_divisor - if mod_token_count != 0: - pad_len = self.padding_divisor - mod_token_count - extend_with_padding(pack_tokens, pack_targets, pack_positions, pad_len) + pad_granularity = self._calculate_padding_divisor() + mod_token_count = len(pack_tokens) % pad_granularity + if mod_token_count != 0: + pad_len = pad_granularity - mod_token_count + extend_with_padding(pack_tokens, pack_targets, pack_positions, pad_len) # Record padded boundary after padding cu_seqlens.append(len(pack_tokens)) @@ -367,4 +379,4 @@ def extend_with_padding(tokens, targets, positions, pad_len): 'position_ids': position_ids, 'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, - } \ No newline at end of file + } diff --git a/megatron/training/training.py b/megatron/training/training.py index a9be5c40469..1d865eaa4f0 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -139,6 +139,7 @@ def set_startup_timestamps(program_start=None, main_entry=None): from megatron.training.initialize import set_jit_fusion_options from megatron.training.utils import get_batch_on_this_cp_rank, get_batch_on_this_tp_rank from megatron.training.datasets.data_samplers import build_pretraining_data_loader +from megatron.core.datasets.data_schedule import HybridCPDataLoaderWrapper from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler from megatron.core.transformer.moe import upcycling_utils from megatron.core.transformer.moe.moe_utils import track_moe_metrics, clear_aux_losses_tracker @@ -1749,8 +1750,8 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch else: # data_iterator unchanged num_microbatches = get_num_microbatches() - seqlen_sum_this_global_batch = args.seq_length * args.micro_batch_size * args.data_parallel_size * num_microbatches - seqlen_squared_sum_this_global_batch = args.seq_length ** 2 * args.micro_batch_size * args.data_parallel_size * num_microbatches + seqlen_sum_this_global_batch = args.seq_length * args.global_batch_size + seqlen_squared_sum_this_global_batch = args.seq_length ** 2 * args.global_batch_size # Forward pass. if save_dgrads_in_this_iteration: @@ -1790,10 +1791,9 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch # iteration is 0-indexed, move to 1-indexed for checkpoint name and logging. save_grads(args.save, state_dict, iteration + 1, "wgrads") - should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit() if should_exit: - return {}, True, should_checkpoint, should_exit, exit_code, None, None, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch, 0 + return {}, True, should_checkpoint, should_exit, exit_code, None, None, 0, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch # Empty unused memory. if args.empty_unused_memory_level >= 1: @@ -3249,9 +3249,30 @@ def evaluate( # Don't care about timing during evaluation config.timers = None ft_integration.on_eval_step_start() + if config.sequence_packing_scheduler is not None: + # This wrapper is designed to support DP-balanced THD and dynamic-CP. + # Before wrapping, the data_iterator returns either a single sequence per get_item call, or a list where each element is a sequence. + # The wrapper is responsible for: + # 1. scheduling the sequences across ranks + # 2. packing them into THD format + # 3. broadcast flops parametes and num_microbatches to TP ranks to support unfixed num_microbatches + # 4. broadcast metadata(cu_seqlens, cu_seqlens_padded, max_seqlen, etc.) to PP ranks to + # 5. returning the packed data iterator and the FLOPs parameters + try: + ( + packed_data_iterator, + eval_num_microbatches, + _, + _, + ) = wrap_data_iterator(data_iterator, config, eval_num_microbatches) + except StopIteration: + # Validation data iterator exhausted, stop evaluation early. + break + else: + packed_data_iterator = data_iterator loss_dicts = forward_backward_func( forward_step_func=forward_step_func, - data_iterator=data_iterator, + data_iterator=packed_data_iterator, model=model, num_microbatches=eval_num_microbatches, seq_length=args.seq_length, diff --git a/pretrain_gpt.py b/pretrain_gpt.py index b9222a8c001..083f97b0a2f 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -262,7 +262,6 @@ def core_gpt_dataset_config_from_args(args): "sequence_parallel_size": args.tensor_model_parallel_size*args.sequence_parallel, "hybrid_context_parallel": args.hybrid_context_parallel, "sft_mock_dataset_config_json":args.sft_mock_dataset_config_json, - "sequence_packing_scheduler": args.sequence_packing_scheduler, } # add FIM args to the config From 9b1ab25afc2eb4eb7571368473fa2582b6dbced6 Mon Sep 17 00:00:00 2001 From: tailaim Date: Tue, 24 Feb 2026 22:52:37 -0800 Subject: [PATCH 5/9] adjust readme Signed-off-by: tailaim --- megatron/core/datasets/readme.md | 58 +++++++++++++++++++++++++++----- 1 file changed, 49 insertions(+), 9 deletions(-) diff --git a/megatron/core/datasets/readme.md b/megatron/core/datasets/readme.md index 64889feb481..a61c623d960 100644 --- a/megatron/core/datasets/readme.md +++ b/megatron/core/datasets/readme.md @@ -194,25 +194,65 @@ To save time during initialization, each index is built/cached sequentially on o ## Packing Scheduler -The packing scheduler re-schedules variable-length sequences across DP×CP ranks to improve GPU utilization. It is built around the following modules: +The packing scheduler re-schedules variable-length sequences across DP×CP ranks to improve GPU utilization. It is built around two modules: `data_schedule.py` (high-level logic and entry points) and `data_schedule_utils.py` (utility functions). -### `data_schedule` +### Call Hierarchy -This module contains the high-level scheduling logic and entry points: +The scheduling pipeline has two phases connected by the data iterator: `wrap_data_iterator` consumes the **original** data iterator, performs global-batch scheduling, and produces a **wrapped** (packed) data iterator; `get_batch_on_this_rank_for_sequence_packing` then consumes this **wrapped** data iterator to fetch individual packed microbatches during training. -- **`HybridCPDataLoaderWrapper`**: A wrapper class for hybrid context parallel (CP) scheduling. For every `__next__` call, it: (1) pulls a batch of packed samples from each DP rank, (2) gathers sequence lengths across the DP group, (3) schedules sub-samples using the `BalancedCPScheduler`, (4) reroutes sub-samples to the correct DPxCP ranks via all-to-all communication. +``` + original wrapped (packed) + data_iterator data_iterator + │ │ + ▼ ▼ + ┌────────────────────────┐ ┌────────────────────────────────────┐ + │ wrap_data_iterator() │ │ get_batch_on_this_rank_for_ │ +Phase 1 │ (once per global │ ────────► │ sequence_packing() │ Phase 2 +(scheduling) │ batch) │ returns │ (once per microbatch, │ (fetching) + │ │ wrapped │ called by training loop) │ + └───────────┬────────────┘ iterator └──────────────┬─────────────────────┘ + │ │ + ▼ ▼ + DpBalancedScheduler.run() next(wrapped_data_iterator) + │ ├─ get_thd_partitioned_indices() [TE] + ├─ get_batch_and_global_seqlens() [utils] ├─ broadcast_tensor() [utils] + ├─ get_groups_and_subsamples() └─ PackedSeqParams(...) + ├─ reroute_samples_to_dcp_ranks() [utils] + ├─ build_packed_microbatches() [utils] + ├─ broadcast_to_pp_group() [utils] + ├─ broadcast_scalars() [utils] + └─ create_data_iterator() [utils] +``` -- **`BasePackingScheduler`**: Abstract base class for packing schedulers. Defines the interface for `get_groups_and_subsamples()` (scheduling algorithm) and `run()` (full scheduling pipeline including fetch, schedule, reroute, pack, broadcast, and VPP handling). +### `data_schedule.py` -- **`DpBalancedScheduler`**: A concrete scheduler that packs sequences in their original order until reaching the max sequence length limit per DPxCP rank. Supports aligning the number of microbatches to DP size and VPP stage multiples. +#### Entry Points -- **`wrap_data_iterator()`**: Top-level entry point that wraps an existing `data_iterator`. It creates the appropriate scheduler, runs the scheduling pipeline, broadcast metadata and new num_microbatches, returns a new data iterator along with the updated number of microbatches and FLOPs statistics. +- **`wrap_data_iterator(original_data_iterator) → wrapped_data_iterator`** — Top-level entry point called once per global batch. Takes the **original** data iterator as input, resolves the scheduler class from `scheduler_map`, instantiates it, and delegates to `scheduler.run()` which consumes all microbatches from the original iterator, re-schedules them, and produces a **wrapped** (packed) data iterator along with the updated `num_microbatches` and FLOPs statistics. -- **`get_batch_on_this_rank_for_sequence_packing()`**: Fetches and broadcasts a single packed microbatch for the current rank. Handles TP/PP broadcasting, constructs `PackedSeqParams` (with `cu_seqlens`, `max_seqlen`, `qkv_format=thd`), and optionally partitions sequences across CP ranks using Transformer Engine's `thd_get_partitioned_indices`. +- **`get_batch_on_this_rank_for_sequence_packing(wrapped_data_iterator)`** — Per-microbatch entry point called by the training loop. Takes the **wrapped** data iterator returned by `wrap_data_iterator` as input. Fetches one packed microbatch via `next(wrapped_data_iterator)`, broadcasts batch fields across TP ranks, optionally partitions sequences across CP ranks using Transformer Engine's `thd_get_partitioned_indices`, and constructs `PackedSeqParams` (with `cu_seqlens`, `max_seqlen`, `qkv_format=thd`). + +#### Scheduler Classes + +- **`BasePackingScheduler`** — Abstract base class. Defines the interface: + - `get_groups_and_subsamples()` — pure scheduling algorithm (must be overridden). + - `run()` — full pipeline: fetch → schedule → reroute → pack → broadcast → VPP handling. + +- **`DpBalancedScheduler(BasePackingScheduler)`** — Concrete scheduler that packs sequences in their original order until reaching `max_seqlen_per_dp_cp_rank × cp_size`. Aligns the number of microbatches to `dp_size` (and VPP stage multiples when applicable). ### `data_schedule_utils.py` -This module contains the utility functions used by the schedulers. +Utility functions consumed by the schedulers above: + +| Function | Role | +|---|---| +| `get_batch_and_global_seqlens()` | Fetch `num_microbatches` batches from the data iterator and all-gather sequence lengths across DP ranks. | +| `reroute_samples_to_dcp_ranks()` | All-to-all communication to transfer sub-samples to their scheduled DP×CP rank. | +| `build_packed_microbatches()` | Concatenate sub-samples within each microbatch group and produce `cu_seqlens`. | +| `broadcast_to_pp_group()` | Broadcast packed samples and metadata from the first/last PP stage to middle stages. | +| `broadcast_scalars()` | Broadcast scalar values (e.g. `num_microbatches`, FLOPs stats) across a process group. | +| `broadcast_tensor()` | Broadcast a single tensor within a process group. | +| `create_data_iterator()` | Wrap packed sample lists into a data iterator; handles VPP stage splitting. | ## Fast DataLoader initialization From eb7cee881723226daf0ec3d0e5cbddacd0ea37b8 Mon Sep 17 00:00:00 2001 From: kunlunl Date: Mon, 2 Mar 2026 16:44:13 +0800 Subject: [PATCH 6/9] Modify mock dataset to do convergence test --- megatron/core/datasets/data_schedule.py | 2 +- megatron/training/datasets/sft_dataset.py | 138 +++++++++++++--------- 2 files changed, 84 insertions(+), 56 deletions(-) diff --git a/megatron/core/datasets/data_schedule.py b/megatron/core/datasets/data_schedule.py index 3039980e0aa..7d556ac66f4 100644 --- a/megatron/core/datasets/data_schedule.py +++ b/megatron/core/datasets/data_schedule.py @@ -431,7 +431,7 @@ def get_groups_and_subsamples(self, sample_id_seqlens): num_to_move = multiple - remainder i = num_packed_sequence - 1 while num_to_move > 0: - assert i > 0, "Not enough samples to move" + assert i >= 0, "Not enough samples to move" if len(packed_id_groups[i]) > 1: seq_id = packed_id_groups[i].pop() packed_id_groups.append([seq_id]) diff --git a/megatron/training/datasets/sft_dataset.py b/megatron/training/datasets/sft_dataset.py index 60266f9a83f..a7dc5647ae5 100644 --- a/megatron/training/datasets/sft_dataset.py +++ b/megatron/training/datasets/sft_dataset.py @@ -11,6 +11,7 @@ import torch from megatron.core.datasets.gpt_dataset import GPTDatasetConfig +from megatron.core.datasets.indexed_dataset import IndexedDataset from megatron.core.datasets.megatron_dataset import LowLevelDataset, MegatronDataset from megatron.core.datasets.utils import Split @@ -218,11 +219,17 @@ class MockSFTLowLevelDataset: """The low-level mock dataset for SFT Args: - mode (str): Either 'file' or 'distribution'. + mode (str): One of 'file', 'distribution', or 'verification'. **kwargs: Additional arguments depending on mode. For mode='file': path (str) - path to a CSV file with sequence lengths. For mode='distribution': type (str), min_seq_len (int), max_seq_len (int), mean_seq_len (int), and distribution-specific params (e.g. lognormal_sigma). + For mode='verification': data_path (str) - prefix path to an IndexedDataset + (.bin/.idx files). Optional lognormal distribution params same as + 'distribution' mode (defaults: min_seq_len=100, max_seq_len=4096, + mean_seq_len=2048, lognormal_sigma=1.1). + format (str): Output format for MockSFTDataset. Either 'thd' (default, sequence + packing with cu_seqlens) or 'sbhd' (padded to seq_length, no cu_seqlens). """ seed: int = 0 @@ -233,6 +240,7 @@ class MockSFTLowLevelDataset: def __init__(self, mode: str, **kwargs) -> None: np.random.seed(self.seed) + self.format = kwargs.get("format", "thd") if mode == "file": self.sequence_lengths = np.array(pd.read_csv(kwargs["path"])).flatten() @@ -248,8 +256,20 @@ def __init__(self, mode: str, **kwargs) -> None: ) else: raise ValueError(f"Unsupported distribution type {kwargs['type']}") + elif mode == "verification": + # Load real tokens from an IndexedDataset for realistic loss curves. + # Sequence lengths are drawn from a lognormal distribution (same as + # "distribution" mode) to allow controlled comparison of THD vs SBHD. + self.indexed_dataset = IndexedDataset(kwargs["data_path"]) + min_seq_len = kwargs.get("min_seq_len", 100) + max_seq_len = kwargs.get("max_seq_len", 4096) + mean_seq_len = kwargs.get("mean_seq_len", 2048) + lognormal_sigma = kwargs.get("lognormal_sigma", 1.1) + self.sequence_lengths = self.generate_lognormal_samples( + self.size, mean_seq_len, lognormal_sigma, min_seq_len, max_seq_len + ) else: - raise ValueError(f"Unsupported mode '{mode}', must be 'file' or 'distribution'") + raise ValueError(f"Unsupported mode '{mode}', must be 'file', 'distribution', or 'verification'") def generate_lognormal_samples(self, size, mean, sigma, min_seq_len, max_seq_len): mu = np.log(mean) - sigma**2 / 2 @@ -260,13 +280,17 @@ def generate_lognormal_samples(self, size, mean, sigma, min_seq_len, max_seq_len def __len__(self) -> int: return self.size - def __getitem__(self, idx: int) -> List[np.ndarray]: - # the length of sample is 'length', but only length-1 elements are generated here, - # because an eod token will be appended at the end later in SFTDataset - - length = self.sequence_lengths[idx % self.size] - sample = np.arange(1, length, dtype=np.int64) - return sample + def __getitem__(self, idx: int) -> np.ndarray: + # The returned sample has 'length-1' tokens; an EOD token is appended + # later in MockSFTDataset.__getitem__, making the total 'length' tokens. + length = int(self.sequence_lengths[idx % self.size]) + if hasattr(self, 'indexed_dataset'): + doc_idx = idx % len(self.indexed_dataset) + raw = self.indexed_dataset[doc_idx] + sample = raw[:length - 1] + return sample.astype(np.int64) + else: + return np.arange(1, length, dtype=np.int64) class MockSFTDataset(SFTDataset): @@ -310,67 +334,71 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: tokens = self.dataset[int(self.indices[idx % len(self.indices)])] - def extend_with_padding(tokens, targets, positions, pad_len): + # Convert tokens to list and always append EOD to ensure length consistency. + # The low-level dataset returns length-1 tokens, and we add EOD to make it length tokens. + tokens_list = tokens.tolist() + tokens_list.append(eod) + + if self.dataset.format == "sbhd": + # SBHD format: single padded sequence without cu_seqlens. + # Long sequences are truncated to pack_length tokens (including EOD). + if len(tokens_list) >= pack_length + 1: + tokens_list = tokens_list[:pack_length - 1] + [eod] + # Pad to pack_length + 1 (offset by 1 for input/label split). + pad_len = pack_length + 1 - len(tokens_list) + if pad_len > 0: + tokens_list = tokens_list + [pad] * pad_len + assert len(tokens_list) == pack_length + 1 + input_ids = torch.tensor(tokens_list[:-1], dtype=torch.int64) + labels = torch.tensor(tokens_list[1:], dtype=torch.int64) + # Position IDs are sequential across the entire sequence including padding, + # matching GPTDataset behavior for standard (non-packed) training. + position_ids = torch.arange(pack_length, dtype=torch.int64) + loss_mask = torch.ones(pack_length, dtype=torch.float32) + loss_mask[labels == pad] = 0.0 + return { + 'tokens': input_ids, + 'labels': labels, + 'loss_mask': loss_mask, + 'position_ids': position_ids, + } + + # THD format (sequence packing) below. + def extend_with_padding(tokens, positions, pad_len): tokens.extend([pad] * pad_len) - targets.extend([pad] * pad_len) positions.extend(range(positions[-1] + 1, positions[-1] + 1 + pad_len)) - # Convert tokens to list and add EOD - tokens_list = tokens.tolist() - if tokens_list[-1] != eod: - tokens_list.append(eod) - targets_list = list(tokens_list) + pack_tokens = list(tokens_list) + [pad] + pack_positions = list(range(len(pack_tokens))) - pack_tokens = list(tokens_list) - pack_targets = list(targets_list) - pack_positions = list(range(len(tokens_list))) - cu_seqlens = [0] + # Truncate if sequence exceeds pack_length + 1 (need +1 for shift). + if len(pack_tokens) > pack_length + 1: + pack_tokens = pack_tokens[:pack_length - 1] + [eod, pad] + pack_positions = pack_positions[:pack_length + 1] - # Pad to padding_divisor alignment + # Pad to pad_granularity alignment (tp * cp * 2). + # We need final length (after shift) to be divisible by pad_granularity. pad_granularity = self._calculate_padding_divisor() - mod_token_count = len(pack_tokens) % pad_granularity + final_len = len(pack_tokens) - 1 + mod_token_count = final_len % pad_granularity if mod_token_count != 0: pad_len = pad_granularity - mod_token_count - extend_with_padding(pack_tokens, pack_targets, pack_positions, pad_len) - - # Record padded boundary after padding - cu_seqlens.append(len(pack_tokens)) - - # Handle any necessary truncation - if len(pack_tokens) >= pack_length + 1: # +1 here to account for later alignment - max_body = pack_length - 1 - pack_tokens = pack_tokens[:max_body] - pack_targets = pack_targets[:max_body] - pack_tokens.extend([eod, pad]) - pack_targets.extend([eod, pad]) - pack_positions = pack_positions[:pack_length + 1] - cu_seqlens[-1] = len(pack_tokens) - 1 - - # Handle any necessary padding - if len(pack_tokens) < pack_length + 1: # +1 here to account for later alignment - pad_len = pack_length + 1 - len(pack_tokens) - extend_with_padding(pack_tokens, pack_targets, pack_positions, pad_len) - cu_seqlens[-1] = len(pack_tokens) - 1 + extend_with_padding(pack_tokens, pack_positions, pad_len) - assert len(pack_tokens) == pack_length + 1 - assert len(pack_targets) == pack_length + 1 - assert len(pack_positions) == pack_length + 1 - - # Align and convert to tensors + # Apply shift for next-token prediction. input_ids = torch.tensor(pack_tokens[:-1], dtype=torch.int64) - labels = torch.tensor(pack_targets[1:], dtype=torch.int64) + labels = torch.tensor(pack_tokens[1:], dtype=torch.int64) position_ids = torch.tensor(pack_positions[:-1], dtype=torch.int64) - # Loss mask - loss_mask = torch.ones(pack_length, dtype=torch.float32) + seq_len = len(input_ids) + cu_seqlens = [0, seq_len] + + # Loss mask: mask padding tokens + loss_mask = torch.ones(seq_len, dtype=torch.float32) loss_mask[labels == pad] = 0.0 - loss_mask[labels == IGNORE_INDEX] = 0.0 - assert len(cu_seqlens) >= 2 cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32) - # Calculating max_seqlen here because of possible effects of truncation and padding - adjacent_diffs = cu_seqlens[1:] - cu_seqlens[:-1] - max_seqlen = adjacent_diffs.max() # max_seqlen is a 0-D tensor + max_seqlen = torch.tensor(seq_len, dtype=torch.int32) return { 'tokens': input_ids, From f6ccf6e2265114968f4ec58696e522c8b4b51b57 Mon Sep 17 00:00:00 2001 From: tailaim Date: Mon, 2 Mar 2026 02:10:00 -0800 Subject: [PATCH 7/9] fixes according to comments Signed-off-by: tailaim --- megatron/core/datasets/data_schedule.py | 74 +++++++++---------- megatron/core/datasets/data_schedule_utils.py | 73 +++++++++++++----- megatron/training/arguments.py | 8 -- megatron/training/training.py | 2 +- 4 files changed, 89 insertions(+), 68 deletions(-) diff --git a/megatron/core/datasets/data_schedule.py b/megatron/core/datasets/data_schedule.py index 7d556ac66f4..00591e4c24d 100644 --- a/megatron/core/datasets/data_schedule.py +++ b/megatron/core/datasets/data_schedule.py @@ -12,9 +12,9 @@ build_packed_microbatches, create_data_iterator, get_batch_and_global_seqlens, + get_cp_slice_for_thd, reroute_samples_to_dcp_ranks, ) -from megatron.core.extensions.transformer_engine import get_thd_partitioned_indices from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.pipeline_parallel.hybrid_cp_schedule import BalancedCPScheduler from megatron.core.process_groups_config import ProcessGroupCollection @@ -491,29 +491,30 @@ def run( total_dcp_gpus = dp_cp_group.size() - # Handle VPP: extract the correct data_iterator for this PP stage + # Handle VPP: extract the correct data_iterator for this PP stage. + # When VPP is enabled, data_iterator is a list with one entry per VPP stage. + # We only need one data_iterator to run the schedule (all VPP stages on the + # same PP rank share the same underlying dataset), so pick the first non-None. + # Record which VPP stages had data so create_data_iterator knows which ones + # need full samples vs metadata only. + vpp_has_data = None if ( config.virtual_pipeline_model_parallel_size is not None and config.virtual_pipeline_model_parallel_size > 1 ): - # if enable VPP, data_iterator is a list of data_iterators for each VPP stage, - # and only the first and last stage rank will have data_iterator, - # other stages will have None. assert len(data_iterator) == config.virtual_pipeline_model_parallel_size - if pp_group.rank() == 0: - # the first stage - data_iterator = data_iterator[0] - elif pp_group.rank() == pp_group.size() - 1: - # the last stage - data_iterator = data_iterator[-1] - else: - data_iterator = None - - # data_iterator is not None when TP rank 0, with PP stage 0 or -1. + vpp_has_data = [di is not None for di in data_iterator] + extracted = None + for di in data_iterator: + if di is not None: + extracted = di + break + data_iterator = extracted + + # data_iterator is not None on TP rank 0 for PP stages that need data + # (first stage, last stage, or any stage with MTP). if data_iterator is not None: - assert tp_group.rank() == 0 and ( - pp_group.rank() == 0 or pp_group.rank() == pp_group.size() - 1 - ), f"Only TP rank 0 and PP stage 0 or -1 should have data_iterator" + assert tp_group.rank() == 0, "Only TP rank 0 should have data_iterator" # Step 1: Fetch batches and gather global sequence lengths batch, global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered = ( @@ -610,7 +611,7 @@ def run( num_micro_batches = int(num_micro_batches) # Step 9: create data_iterator and handle VPP if enabled - new_data_iterator = create_data_iterator(new_samples, pp_group, tp_group, config) + new_data_iterator = create_data_iterator(new_samples, tp_group, config, vpp_has_data) return ( new_data_iterator, @@ -729,10 +730,10 @@ def get_batch_on_this_rank_for_sequence_packing( # data_iterator should return a batch including the following keys. batch_keys = ['cu_seqlens', 'cu_seqlens_padded', 'max_seqlen'] - if is_first_stage: + if is_first_stage or mtp_on_this_rank: batch_keys.append('tokens') batch_keys.append('position_ids') - if is_last_stage: + if is_last_stage or mtp_on_this_rank: batch_keys.append('labels') batch_keys.append('loss_mask') @@ -746,21 +747,10 @@ def get_batch_on_this_rank_for_sequence_packing( assert data_iterator is None, "Non TP 0 rank should not have data_iterator" batch = {} - # Partition tokens, position_ids, labels, loss_mask for context parallel, currently only - # TP rank 0 and the first/last PP stage rank has these data. - if is_tp_rank_0 and is_first_or_last_stage: - cp_size = cp_group.size() - cp_rank = cp_group.rank() - # If cp_size == 1, no need to do further processing. - if cp_size > 1: - total_tokens = batch['tokens'].size(0) - # Transformer Engine has a bug of cu_seqlens, we must treat cu_seqlens_padded as - # cu_seqlens to get the correct result. - # TODO: Revert this workaround once TE fixes the issue. - cu_seqlens = batch["cu_seqlens_padded"] - index = get_thd_partitioned_indices(cu_seqlens, total_tokens, cp_size, cp_rank) - for key in ['tokens', 'position_ids', 'labels', 'loss_mask']: - batch[key] = batch[key].index_select(0, index) + # Partition tokens, position_ids, labels, loss_mask for context parallel. + # Only TP rank 0 on stages that have data (first/last PP stage or MTP stage) needs this. + if is_tp_rank_0 and (is_first_or_last_stage or mtp_on_this_rank): + get_cp_slice_for_thd(batch, cp_group) # Broadcast cu_seqlens_size because we need it to create placeholder for cu_seqlens and # cu_seqlens_padded for non TP 0 ranks. @@ -772,8 +762,10 @@ def get_batch_on_this_rank_for_sequence_packing( cu_seqlen_size = cu_seqlen_size.item() # Broadcast total_tokens because we need it to create placeholder for tokens, position_ids, - # labels, loss_mask for non TP 0 ranks. Only first or last stage need this. - if is_first_or_last_stage: + # labels, loss_mask for non TP 0 ranks. Only first stage, last stage, + # and stage with mtp need this. + + if is_first_or_last_stage or mtp_on_this_rank: if is_tp_rank_0: total_tokens = torch.tensor(batch['tokens'].size(0), dtype=torch.int32, device=dev) else: @@ -781,7 +773,7 @@ def get_batch_on_this_rank_for_sequence_packing( broadcast_tensor(total_tokens, tp_src_rank, tp_group) total_tokens = total_tokens.item() - # Step1: Prepare "tokens", "position_ids" on all ranks. + # Step1: Prepare "tokens", "position_ids" for first stage and stage with mtp on all TP ranks. if is_first_stage or mtp_on_this_rank: if is_tp_rank_0: assert batch['tokens'].dtype == torch.int64 @@ -796,8 +788,8 @@ def get_batch_on_this_rank_for_sequence_packing( batch['tokens'] = None batch['position_ids'] = None - # Step2: Prepare "labels", "loss_mask" on all ranks. - if is_last_stage: + # Step2: Prepare "labels", "loss_mask" for last stage and stage with mtp on all TP ranks. + if is_last_stage or mtp_on_this_rank: if is_tp_rank_0: assert batch['labels'].dtype == torch.int64 assert batch['loss_mask'].dtype == torch.float32 diff --git a/megatron/core/datasets/data_schedule_utils.py b/megatron/core/datasets/data_schedule_utils.py index 2e5635dfcd8..f3c637e4c79 100644 --- a/megatron/core/datasets/data_schedule_utils.py +++ b/megatron/core/datasets/data_schedule_utils.py @@ -5,9 +5,35 @@ import numpy as np import torch +from megatron.core.extensions.transformer_engine import get_thd_partitioned_indices from megatron.core.rerun_state_machine import RerunDataIterator +def get_cp_slice_for_thd(batch, cp_group): + """Partition sequence data for context parallelism in THD format. + + Uses TE's THD partitioned indices to split the packed sequence across CP ranks. + Only keys present in the batch are sliced. + + Args: + batch: Dict with packed sequence data. + cp_group: Context parallel process group. + """ + cp_size = cp_group.size() + if cp_size <= 1: + return + cp_rank = cp_group.rank() + total_tokens = batch['tokens'].size(0) + # Transformer Engine has a bug of cu_seqlens, we must treat cu_seqlens_padded as + # cu_seqlens to get the correct result. + # TODO: Revert this workaround once TE fixes the issue. + cu_seqlens = batch["cu_seqlens_padded"] + index = get_thd_partitioned_indices(cu_seqlens, total_tokens, cp_size, cp_rank) + for key in ['tokens', 'position_ids', 'labels', 'loss_mask']: + if key in batch: + batch[key] = batch[key].index_select(0, index) + + def _unpack_batch(batch: List[Dict[str, torch.Tensor]]) -> List[Dict[str, torch.Tensor]]: """ Unpacks the packed samples into a list of sub-samples. @@ -208,6 +234,11 @@ def broadcast_to_pp_group( max_seqlens = info_to_broadcast[3 : 3 + num_micro_batches] cu_seqlens_list = [] cu_seqlens_padded_list = [] + # cu_seqlens always starts with 0, and the other metadata values + # (num_micro_batches, seqlen_sum, seqlen_squared_sum, max_seqlens) + # are always positive, so we can use 0 as the delimiter to locate + # the start of each cu_seqlens / cu_seqlens_padded tensor. + # This avoids an extra broadcast for the lengths of cu_seqlens. indices = np.where(info_numpy == 0)[0] for i in range(num_micro_batches): cu_seqlens_list.append(info_to_broadcast[indices[i * 2] : indices[i * 2 + 1]]) @@ -266,31 +297,37 @@ def broadcast_scalars(values: List, group, dev, dtype=torch.float32) -> List: return values -def create_data_iterator(new_samples, pp_group, tp_group, config): - """Handle virtual pipeline parallelism.""" +def create_data_iterator(new_samples, tp_group, config, vpp_has_data=None): + """Handle virtual pipeline parallelism. + + For VPP, each PP rank needs a list of data iterators (one per VPP stage). + VPP stages that originally had a data_iterator (indicated by vpp_has_data) + get full samples; others get metadata only (cu_seqlens, cu_seqlens_padded, + max_seqlen). + + Args: + new_samples: The packed samples after scheduling. + tp_group: Tensor parallel process group. + config: Model parallel config. + vpp_has_data: A list of booleans (one per VPP stage) indicating which + VPP stages originally had a data_iterator. None if VPP is disabled. + """ if ( config.virtual_pipeline_model_parallel_size is not None and config.virtual_pipeline_model_parallel_size > 1 ): vpp_size = config.virtual_pipeline_model_parallel_size if tp_group.rank() == 0: - if pp_group.rank() == 0 or pp_group.rank() == pp_group.size() - 1: - metadata = [ - {k: sample[k] for k in ["max_seqlen", "cu_seqlens", "cu_seqlens_padded"]} - for sample in new_samples - ] - if pp_group.rank() == 0: - new_data_iterator = [RerunDataIterator(iter(new_samples))] + [ - RerunDataIterator(iter(metadata)) for _ in range(vpp_size - 1) - ] + metadata = [ + {k: sample[k] for k in ["max_seqlen", "cu_seqlens", "cu_seqlens_padded"]} + for sample in new_samples + ] + new_data_iterator = [] + for i in range(vpp_size): + if vpp_has_data is not None and vpp_has_data[i]: + new_data_iterator.append(RerunDataIterator(iter(new_samples))) else: - new_data_iterator = [ - RerunDataIterator(iter(metadata)) for _ in range(vpp_size - 1) - ] + [RerunDataIterator(iter(new_samples))] - else: - # on middle PP stages, the new_samples are the metadata - metadata = new_samples - new_data_iterator = [RerunDataIterator(iter(metadata)) for _ in range(vpp_size)] + new_data_iterator.append(RerunDataIterator(iter(metadata))) else: new_data_iterator = [None for _ in range(vpp_size)] else: diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 1e7db76526e..25f0d0d06d0 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1054,14 +1054,7 @@ def validate_args(args, defaults={}): assert args.dataloader_type == 'single', 'Hybrid context parallelism only supported with single dataloader type' assert args.calculate_per_token_loss, 'Hybrid context parallelism must be used with --calculate-per-token-loss' - # Support for variable sequence lengths across batches/microbatches. - # set it if the dataloader supports generation of variable sequence lengths - # across batches/microbatches. Due to additional communication overhead - # during pipeline parallelism, it should not be set if sequence length - # is constant during training. - args.variable_seq_lengths = False if args.sequence_packing_scheduler is not None: - args.variable_seq_lengths = True assert args.context_parallel_size * args.max_seqlen_per_dp_cp_rank >= args.seq_length, \ f'Packed sequence buffer size ({args.context_parallel_size * args.max_seqlen_per_dp_cp_rank}) ' \ f'must be >= single sequence max length ({args.seq_length})' @@ -2521,7 +2514,6 @@ def _add_distributed_args(parser): 'all layers will share the same communication type. Users can also ' 'specify separated types for each layer like ' '--cp-comm-type p2p p2p a2a a2a a2a+p2p a2a+p2p') - group.add_argument('--sequence-packing-scheduler', type=str, default=None, choices=['dp_balanced']) group.add_argument('--fake-process-group', action='store_true', default=False, help='If set, initialize with fake distributed process group and all distributed communication operations will be skipped. \ This is quite useful for profiling memory usage of distributed training with just one GPU. \ diff --git a/megatron/training/training.py b/megatron/training/training.py index 1d865eaa4f0..26769fabe96 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -539,7 +539,7 @@ def transformer_flops(): + args.hidden_size * v_dim ) - ) + ) * seqlen_sum_this_global_batch else: raise ValueError( "Invalid experimental_attention_variant: " From 816fca15aae9a09a0ee44116d58211de76f02962 Mon Sep 17 00:00:00 2001 From: kunlunl Date: Tue, 3 Mar 2026 00:17:46 +0800 Subject: [PATCH 8/9] Fix small seqlen of mock dataset --- megatron/training/datasets/sft_dataset.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/megatron/training/datasets/sft_dataset.py b/megatron/training/datasets/sft_dataset.py index a7dc5647ae5..3f2e6e7362c 100644 --- a/megatron/training/datasets/sft_dataset.py +++ b/megatron/training/datasets/sft_dataset.py @@ -285,9 +285,25 @@ def __getitem__(self, idx: int) -> np.ndarray: # later in MockSFTDataset.__getitem__, making the total 'length' tokens. length = int(self.sequence_lengths[idx % self.size]) if hasattr(self, 'indexed_dataset'): - doc_idx = idx % len(self.indexed_dataset) + target = length - 1 + num_docs = len(self.indexed_dataset) + doc_idx = idx % num_docs raw = self.indexed_dataset[doc_idx] - sample = raw[:length - 1] + if len(raw) >= target: + sample = raw[:target] + else: + # Concatenate documents until we reach the target length. + chunks = [raw] + total = len(raw) + next_doc = doc_idx + 1 + while total < target: + raw_next = self.indexed_dataset[next_doc % num_docs] + need = target - total + chunks.append(raw_next[:need]) + total += min(len(raw_next), need) + next_doc += 1 + sample = np.concatenate(chunks)[:target] + assert len(sample) == target return sample.astype(np.int64) else: return np.arange(1, length, dtype=np.int64) From 2aae2d035dba66c5589610fb626f09366dcade56 Mon Sep 17 00:00:00 2001 From: kunlunl Date: Tue, 3 Mar 2026 00:27:41 +0800 Subject: [PATCH 9/9] Lint --- megatron/core/tokenizers/text/libraries/sft_tokenizer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/megatron/core/tokenizers/text/libraries/sft_tokenizer.py b/megatron/core/tokenizers/text/libraries/sft_tokenizer.py index b9a4f7b0e4b..8a418f2dd7f 100644 --- a/megatron/core/tokenizers/text/libraries/sft_tokenizer.py +++ b/megatron/core/tokenizers/text/libraries/sft_tokenizer.py @@ -108,7 +108,6 @@ def __init__(self, tokenizer_path: str, prompt_format: str): self._prompt_format = prompt_format - def tokenize_conversation( self, conversation: List[Dict], return_target: bool, add_generation_prompt: bool ):