From 59c26d3972183050961cb7636ebbd7f7855a936a Mon Sep 17 00:00:00 2001 From: xiaoyao0115 <1804647152@qq.com> Date: Tue, 3 Feb 2026 01:46:21 -0800 Subject: [PATCH 1/3] add thd e2e support and mock dataset Signed-off-by: xiaoyao0115 <1804647152@qq.com> --- megatron/core/datasets/data_schedule.py | 983 +++++++++++++++++- megatron/core/datasets/gpt_dataset.py | 6 + megatron/core/model_parallel_config.py | 13 +- megatron/core/pipeline_parallel/schedules.py | 9 + .../text/libraries/sft_tokenizer.py | 5 + .../core/transformer/transformer_config.py | 32 + megatron/training/arguments.py | 45 +- megatron/training/datasets/sft_dataset.py | 174 +++- megatron/training/training.py | 129 ++- pretrain_gpt.py | 17 +- tests/unit_tests/test_sequence_packing.py | 322 ++++++ 11 files changed, 1672 insertions(+), 63 deletions(-) create mode 100644 tests/unit_tests/test_sequence_packing.py diff --git a/megatron/core/datasets/data_schedule.py b/megatron/core/datasets/data_schedule.py index 0f016473b6a..7a1dd9e587d 100644 --- a/megatron/core/datasets/data_schedule.py +++ b/megatron/core/datasets/data_schedule.py @@ -1,12 +1,27 @@ # Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved. -from typing import Any, List, Optional +import enum +from typing import Any, Dict, List, Optional, Type, Union +import numpy as np import torch from megatron.core import parallel_state +from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.pipeline_parallel.hybrid_cp_schedule import BalancedCPScheduler from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.rerun_state_machine import RerunDataIterator +from megatron.core.utils import is_te_min_version + +try: + # Register the TE CUDA kernels + import transformer_engine # pylint: disable=unused-import + + # Alias the PyTorch wrapper so we can call tex.* APIs + import transformer_engine_torch as tex +except ImportError: + # TE isn’t installed or the torch wrapper is missing + tex = None class HybridCPDataLoaderWrapper: @@ -299,3 +314,969 @@ def __next__(self) -> Any: batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets ) return samples_this_rank_with_id, sample_id_groups + + +class PackingScheduler(enum.Enum): + """Enum for supported sequence packing algorithms.""" + + DEFAULT_SEQUENCE_PACKING = "default_sequence_packing" + + +def _broadcast_tensor(item, src_rank, group) -> None: + """Broadcast a tensor from src_rank to all ranks in the group.""" + if item is not None: + torch.distributed.broadcast(item, src_rank, group=group) + + +class BaseScheduler: + """Base class for sequence packing schedulers.""" + + def __init__( + self, + max_seqlen_per_dp_cp_rank: int, + cp_size: int, + dp_size: int, + microbatch_group_size_per_vp_stage: Optional[int], + ): + self.max_seqlen_per_dp_cp_rank = max_seqlen_per_dp_cp_rank + self.cp_size = cp_size + self.dp_size = dp_size + self.microbatch_group_size_per_vp_stage = microbatch_group_size_per_vp_stage + + def get_require_sample_keys(self): + """Return the required key of each batch.""" + raise NotImplementedError + + def get_groups_and_subsamples(self, sample_id_seqlens): + """schedule the samples into groups""" + raise NotImplementedError + + def run( + self, + data_iterator, + num_microbatches, + dp_group, + tp_group, + pp_group, + dp_cp_group, + dev, + config, + ): + """Run the scheduler and return the new data_iterator.""" + raise NotImplementedError + + @staticmethod + def _get_global_seqlens(subsample_seqlens: torch.Tensor, dp_group) -> List[int]: + """ + Gathers the sequence lengths of all subsamples from all DP ranks. + + Each DP rank has the same number of subsamples (num_microbatches), + so we can directly all_gather without padding. + """ + dp_size = dp_group.size() + num_local_subsamples = subsample_seqlens.shape[0] + + # Gather the subsample_seqlens from all ranks + seqlens_gathered = [torch.empty_like(subsample_seqlens) for _ in range(dp_size)] + torch.distributed.all_gather(seqlens_gathered, subsample_seqlens, group=dp_group) + + seqlens_gathered = torch.cat(seqlens_gathered, dim=0) + seqlens_gathered = seqlens_gathered.cpu().tolist() + + # Calculate the offsets to assign unique global ID to each subsample. + # Since each rank has the same number of subsamples, offsets are evenly spaced. + offsets = torch.arange( + 0, dp_size * num_local_subsamples, num_local_subsamples, dtype=torch.int32 + ) + + return seqlens_gathered, offsets + + @staticmethod + def _get_global_id_seqlens(num_local_subsamples, offsets, seqlens_gathered, dp_group): + """ + Calculates the global ID for each subsample. + + We assign a unique global ID to each subsample. + + Returns: + global_id_seqlens: list of (global_id, seqlen) tuples for scheduling. + global_ids_this_rank: list of global IDs locally present on this rank. + """ + dp_rank = dp_group.rank() + global_ids = torch.arange(len(seqlens_gathered), dtype=torch.int32).cuda() + # Create a list of (global_id, seqlen) tuples for scheduling + global_id_seqlens = [(i, seqlens_gathered[i]) for i in range(len(global_ids))] + # Get the global IDs locally present on this rank + global_ids_this_rank = global_ids[ + offsets[dp_rank] : offsets[dp_rank] + num_local_subsamples + ] + + return global_id_seqlens, global_ids_this_rank + + @staticmethod + def _broadcast_to_pp_group( + new_samples, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + pp_group, + dev, + ): + """ + Broadcast num_micro_batches, seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch and metadata to middle PP stages. + """ + + pp_src_rank = torch.distributed.get_process_group_ranks(pp_group)[0] + + if pp_group.size() > 2: + if pp_group.rank() == 0: + tensor_list = [ + torch.tensor( + [ + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ], + dtype=torch.float32, + ).cuda() + ] + for sample in new_samples: + tensor_list.append(sample["max_seqlen"].unsqueeze(0)) + for sample in new_samples: + tensor_list.append(sample["cu_seqlens"]) + tensor_list.append(sample["cu_seqlens_padded"]) + info_to_broadcast = torch.cat(tensor_list, dim=0).to( + device=dev, dtype=torch.float32 + ) + info_length_tensor = torch.tensor( + info_to_broadcast.shape[0], dtype=torch.int32 + ).cuda() + _broadcast_tensor(info_length_tensor, pp_src_rank, pp_group) + _broadcast_tensor(info_to_broadcast, pp_src_rank, pp_group) + else: + info_length_tensor = torch.tensor(0, dtype=torch.int32).cuda() + _broadcast_tensor(info_length_tensor, pp_src_rank, pp_group) + info_to_broadcast = torch.empty( + info_length_tensor.item(), dtype=torch.float32 + ).cuda() + _broadcast_tensor(info_to_broadcast, pp_src_rank, pp_group) + if pp_group.rank() != pp_group.size() - 1: + # middle PP stages receive the broadcasted info and unpack it + info_numpy = info_to_broadcast.cpu().numpy() + num_micro_batches = int(info_numpy[0]) + seqlen_sum_this_global_batch = info_numpy[1] + seqlen_squared_sum_this_global_batch = info_numpy[2] + max_seqlens = info_to_broadcast[3 : 3 + num_micro_batches] + cu_seqlens_list = [] + cu_seqlens_padded_list = [] + indices = np.where(info_numpy == 0)[0] + for i in range(num_micro_batches): + cu_seqlens_list.append( + info_to_broadcast[indices[i * 2] : indices[i * 2 + 1]] + ) + if i == num_micro_batches - 1: + cu_seqlens_padded_list.append(info_to_broadcast[indices[i * 2 + 1] :]) + else: + cu_seqlens_padded_list.append( + info_to_broadcast[indices[i * 2 + 1] : indices[i * 2 + 2]] + ) + + new_samples = [] + for i in range(num_micro_batches): + new_sample = {} + new_sample["max_seqlen"] = max_seqlens[i].to(torch.int32) + new_sample["cu_seqlens"] = cu_seqlens_list[i].to(torch.int32) + new_sample["cu_seqlens_padded"] = cu_seqlens_padded_list[i].to(torch.int32) + new_samples.append(new_sample) + + return ( + new_samples, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ) + + @staticmethod + def _broadcast_scalars(values: List, group, dev, dtype=torch.float32) -> List: + """ + Broadcast scalar values from rank 0 to all ranks in the group. + + Args: + values: List of scalar values to broadcast (only used on rank 0). + group: The process group to broadcast within. + dev: The device to use for the tensor. + dtype: The data type for the tensor. + + Returns: + List of broadcasted values. + """ + if group.size() <= 1: + return values + + src_rank = torch.distributed.get_process_group_ranks(group)[0] + num_values = len(values) + + if group.rank() == 0: + info_to_broadcast = torch.tensor(values, dtype=dtype, device=dev) + else: + info_to_broadcast = torch.zeros(num_values, dtype=dtype, device=dev) + + _broadcast_tensor(info_to_broadcast, src_rank, group) + + if group.rank() != 0: + values = info_to_broadcast.cpu().tolist() + + return values + + @staticmethod + def _create_data_iterator(new_samples, pp_group, tp_group, config): + """Handle virtual pipeline parallelism.""" + if ( + config.virtual_pipeline_model_parallel_size is not None + and config.virtual_pipeline_model_parallel_size > 1 + ): + vpp_size = config.virtual_pipeline_model_parallel_size + if tp_group.rank() == 0: + if pp_group.rank() == 0 or pp_group.rank() == pp_group.size() - 1: + new_samples_for_other_ppstage = [] + for sample in new_samples: + new_sample_for_other_ppstage = {} + new_sample_for_other_ppstage["max_seqlen"] = sample["max_seqlen"] + new_sample_for_other_ppstage["cu_seqlens"] = sample["cu_seqlens"] + new_sample_for_other_ppstage["cu_seqlens_padded"] = sample[ + "cu_seqlens_padded" + ] + new_samples_for_other_ppstage.append(new_sample_for_other_ppstage) + if pp_group.rank() == 0: + new_data_iterator = [RerunDataIterator(iter(new_samples))] + [ + RerunDataIterator(iter(new_samples_for_other_ppstage)) + for _ in range(vpp_size - 1) + ] + else: + new_data_iterator = [ + RerunDataIterator(iter(new_samples_for_other_ppstage)) + for _ in range(vpp_size - 1) + ] + [RerunDataIterator(iter(new_samples))] + else: + new_data_iterator = [ + RerunDataIterator(iter(new_samples)) for _ in range(vpp_size) + ] + else: + new_data_iterator = [None for _ in range(vpp_size)] + else: + new_data_iterator = ( + RerunDataIterator(iter(new_samples)) if tp_group.rank() == 0 else None + ) + + return new_data_iterator + + @staticmethod + def _reroute_samples_to_dcp_ranks( + batch, + global_ids_this_rank, + global_id_seqlens, + sample_id_groups, + offsets, + dp_group, + tp_group, + dp_cp_group, + total_dcp_gpus, + ): + """ + Reroutes the sub-samples to the correct rank after scheduling. + + For each key in the batch dict, we perform an all-to-all communication + to transfer the data to the correct ranks. + """ + + def _gid_to_src_rank(gid: int) -> int: + dp_src_rank = torch.bucketize(gid, offsets[1:] - 1) + dcp_rank = ( + torch.distributed.get_process_group_ranks(dp_group)[dp_src_rank] // tp_group.size() + ) % dp_cp_group.size() + return dcp_rank + + gid2local_id = {int(gid): i for i, gid in enumerate(global_ids_this_rank)} + dcp_rank = dp_cp_group.rank() + dp_ranks = torch.distributed.get_process_group_ranks(dp_group) + dp_ranks = [(r // tp_group.size()) % dp_cp_group.size() for r in dp_ranks] + + data_keys = batch[0].keys() + + # Create the send plan + combined_sample_id_groups: List[List[int]] = [[] for _ in range(total_dcp_gpus)] + for d in range(total_dcp_gpus): + for sample_id_group in sample_id_groups: + combined_sample_id_groups[d].extend(sample_id_group[d]) + for dest_rank in range(total_dcp_gpus): + combined_sample_id_groups[dest_rank].sort() + + send_ids_sorted = [ + gid + for d in dp_ranks + for gid in combined_sample_id_groups[d] + if gid in global_ids_this_rank + ] + + send_num_split = [0] * total_dcp_gpus + send_lens_split = [0] * total_dcp_gpus + for dest_rank in range(total_dcp_gpus): + if dest_rank in dp_ranks: + send_seq_lens = [ + global_id_seqlens[gid][1] + for gid in combined_sample_id_groups[dest_rank] + if gid in global_ids_this_rank + ] + send_num_split[dest_rank] = len(send_seq_lens) + send_lens_split[dest_rank] = sum(send_seq_lens) + else: + send_lens_split[dest_rank] = 0 + + # Create the recv plan + recv_sample_id_groups = [[] for _ in range(total_dcp_gpus)] + for gid in combined_sample_id_groups[dcp_rank]: + src_rank = _gid_to_src_rank(gid) + recv_sample_id_groups[src_rank].append(gid) + + recv_lens_split = [0] * total_dcp_gpus + for src_rank in range(total_dcp_gpus): + recv_lens_split[src_rank] = sum( + [global_id_seqlens[gid][1] for gid in recv_sample_id_groups[src_rank]] + ) + + recv_ids_sorted = [gid for d in range(total_dcp_gpus) for gid in recv_sample_id_groups[d]] + recv_counts = [len(recv_sample_id_groups[d]) for d in range(total_dcp_gpus)] + + recv_samples = [{k: None for k in data_keys} for _ in range(sum(recv_counts))] + + def _pack_sample_by_key(key: str) -> torch.Tensor: + flattened_tensors = [] + for gid in send_ids_sorted: + t = batch[gid2local_id[gid]][key].to(torch.cuda.current_device(), non_blocking=True) + flattened_tensors.append(t.reshape(-1)) + return ( + torch.cat(flattened_tensors, dim=0) + if flattened_tensors + else torch.empty(1, device=torch.cuda.current_device(), dtype=batch[0][key].dtype) + ) + + def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor): + cursor = 0 + for i, gid in enumerate(recv_ids_sorted): + sample_len = ( + 1 + if key in ["original_seq_len", "padded_seq_len"] + else global_id_seqlens[gid][1] + ) + recv_samples[i][key] = recv_tensor[cursor : cursor + sample_len] + cursor += sample_len + + for key in data_keys: + output_split_sizes, input_split_sizes = ( + (recv_counts, send_num_split) + if key in ["original_seq_len", "padded_seq_len"] + else (recv_lens_split, send_lens_split) + ) + send_tensor = _pack_sample_by_key(key) + recv_tensor_size = sum(output_split_sizes) + recv_tensor = torch.empty( + recv_tensor_size, device=torch.cuda.current_device(), dtype=send_tensor.dtype + ) + torch.distributed.all_to_all_single( + output=recv_tensor, + input=send_tensor, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=dp_cp_group, + ) + _unpack_sample_by_key(key, recv_tensor) + + recv_sample_with_id = { + recv_id: recv_samples[i] for i, recv_id in enumerate(recv_ids_sorted) + } + return recv_sample_with_id + + @staticmethod + def _pack_sequences( + samples: List, + padded_lengths: torch.Tensor, + original_lengths: torch.Tensor, + dev: torch.device, + ) -> Dict[str, torch.Tensor]: + """Pack multiple samples into a single packed sample.""" + + def _pack_tensors(tensors): + return torch.cat([t.reshape(-1) for t in tensors], dim=0) + + tokens = _pack_tensors([sample["tokens"] for sample in samples]) + labels = _pack_tensors([sample["labels"] for sample in samples]) + loss_mask = _pack_tensors([sample["loss_mask"] for sample in samples]) + position_ids = _pack_tensors([sample["position_ids"] for sample in samples]) + + new_sample = {} + new_sample["tokens"] = tokens + new_sample["labels"] = labels + new_sample["loss_mask"] = loss_mask + new_sample["position_ids"] = position_ids + + padded_lengths = padded_lengths.to( + device=dev, dtype=torch.int32, non_blocking=True + ).reshape(-1) + cu_seqlens_padded = torch.empty(padded_lengths.numel() + 1, device=dev, dtype=torch.int32) + cu_seqlens_padded[0] = 0 + cu_seqlens_padded[1:] = torch.cumsum(padded_lengths, dim=0) + max_seqlen = torch.max(padded_lengths).to(dtype=torch.int32) + + new_sample["cu_seqlens_padded"] = cu_seqlens_padded + new_sample["max_seqlen"] = max_seqlen + + original_lengths = original_lengths.to( + device=dev, dtype=torch.int32, non_blocking=True + ).reshape(-1) + cu_seqlens = torch.empty(original_lengths.numel() + 1, device=dev, dtype=torch.int32) + cu_seqlens[0] = 0 + cu_seqlens[1:] = torch.cumsum(original_lengths, dim=0).reshape(-1) + new_sample["cu_seqlens"] = cu_seqlens + + return new_sample + + @staticmethod + def _build_packed_microbatches( + grouped_samples: List[List[Dict[str, torch.Tensor]]], dev: torch.device + ) -> List[Dict[str, torch.Tensor]]: + """Build packed samples for each microbatch.""" + num_micro_batches = len(grouped_samples) + seg_starts: List[int] = [0] + original_lens_tensors = [] + padded_lens_tensors = [] + + for i in range(num_micro_batches): + samples = grouped_samples[i] + seg_starts.append(seg_starts[-1] + len(samples)) + original_lens_tensors.extend([s["original_seq_len"].reshape(-1) for s in samples]) + padded_lens_tensors.extend([s["padded_seq_len"].reshape(-1) for s in samples]) + + padded_lens_all_gpu = torch.cat(padded_lens_tensors, dim=0).to(dtype=torch.int32) + original_lens_all_gpu = torch.cat(original_lens_tensors, dim=0).to(dtype=torch.int32) + + new_samples: List[Dict[str, torch.Tensor]] = [] + for i in range(num_micro_batches): + samples = grouped_samples[i] + lp = padded_lens_all_gpu[seg_starts[i] : seg_starts[i + 1]] + lo = original_lens_all_gpu[seg_starts[i] : seg_starts[i + 1]] + new_sample = BaseScheduler._pack_sequences(samples, lp, lo, dev) + new_samples.append(new_sample) + + return new_samples + + @staticmethod + def _get_batch_and_global_seqlens(data_iterator, num_microbatches, dp_group): + """ + Get the batch and global sequence lengths. + Each DP rank loads the same number of sequences, so we need to gather the sequence + lengths from all ranks then we can schedule the sequences into groups. + Args: + data_iterator: The data iterator. + num_microbatches: The number of microbatches. + dp_group: The data parallel group. + + Returns: + batch: The batch. + global_id_seqlens: The global sequence lengths. + global_ids_this_rank: The global IDs locally present on this rank. + """ + batch = [next(data_iterator) for _ in range(num_microbatches)] + subsample_seqlens = [] + for sample in batch: + subsample_seqlens.extend([sample["tokens"].numel()]) + subsample_seqlens = torch.tensor(subsample_seqlens, dtype=torch.int32).cuda() + + seqlens_gathered, offsets = BaseScheduler._get_global_seqlens(subsample_seqlens, dp_group) + + global_id_seqlens, global_ids_this_rank = BaseScheduler._get_global_id_seqlens( + subsample_seqlens.shape[0], offsets, seqlens_gathered, dp_group + ) + + return batch, global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered + + +class DefaultSequencePackingScheduler(BaseScheduler): + """Packs sequences in their original order until reaching the max limit of sequence length.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.max_seq_len_all_ranks = self.max_seqlen_per_dp_cp_rank * self.cp_size + + def get_require_sample_keys(self): + """Return the required key of each batch.""" + return [ + "tokens", + "labels", + "loss_mask", + "position_ids", + "original_seq_len", # Length of the original sequence length, should be a gpu tensor. + "padded_seq_len", # Length of the padded sequence length, should be a gpu tensor. + ] + + def get_groups_and_subsamples(self, sample_id_seqlens): + """ + Packs sequences in their original order until reaching the max limit of sequence length. + """ + sample_id_groups = [] + packed_id_groups = [] + sum_seqlen = 0 + single_microbatch = [] + + for i in range(len(sample_id_seqlens)): + if sum_seqlen + sample_id_seqlens[i][1] <= self.max_seq_len_all_ranks: + single_microbatch.append(i) + sum_seqlen += sample_id_seqlens[i][1] + else: + packed_id_groups.append(single_microbatch) + single_microbatch = [i] + sum_seqlen = sample_id_seqlens[i][1] + if len(single_microbatch) > 0: + packed_id_groups.append(single_microbatch) + + # we want the number of packed sequences to be multiple of dp_size + # so we move few samples from previous microbatch + # to the end of the microbatches if needed + num_packed_sequence = len(packed_id_groups) + + # when enabling vpp, we want the number of packed sequences to be + # multiple of dp_size * microbatch_group_size_per_vp_stage + multiple = self.dp_size * ( + self.microbatch_group_size_per_vp_stage + if self.microbatch_group_size_per_vp_stage is not None + else 1 + ) + if num_packed_sequence % multiple != 0: + remainder = num_packed_sequence % multiple + num_to_move = multiple - remainder + i = num_packed_sequence - 1 + while num_to_move > 0: + assert i > 0, "Not enough samples to move" + if len(packed_id_groups[i]) > 1: + seq_id = packed_id_groups[i].pop() + packed_id_groups.append([seq_id]) + num_to_move -= 1 + else: + i -= 1 + + num_micro_batches = int(len(packed_id_groups) / self.dp_size) + for i in range(num_micro_batches): + sample_id_groups.append([]) + for j in range(self.cp_size * self.dp_size): + seq_id = int(i * self.dp_size + j / self.cp_size) + sample_id_groups[i].append(packed_id_groups[seq_id]) + return sample_id_groups + + def run( + self, + data_iterator, + num_microbatches: int, + dp_group, + tp_group, + pp_group, + dp_cp_group, + dev: torch.device, + config, + ): + """ + Run the complete scheduling pipeline. + + Steps: + 1. Fetch batches and gather global sequence lengths + 2. Check required sample keys + 3. Schedule samples into groups + 4. Reroute samples to DCP ranks + 5. Build packed microbatches + 6. Calculate FLOPs info + 7. Broadcast to PP group (for middle PP stages) + 8. Broadcast to TP group (for non-TP-0 ranks) + 9. Handle VPP if enabled + + Args: + data_iterator: The data iterator. + num_microbatches: The number of microbatches to fetch. + dp_group: Data parallel process group. + tp_group: Tensor parallel process group. + pp_group: Pipeline parallel process group. + dp_cp_group: Data parallel + context parallel process group. + dev: CUDA device. + config: Model parallel config. + + Returns: + new_data_iterator: The new data iterator (or list for VPP). + num_micro_batches: Number of micro batches after scheduling. + seqlen_sum_this_global_batch: Total tokens for FLOPs calculation. + seqlen_squared_sum_this_global_batch: Sum of squared seqlens for FLOPs. + """ + + total_dcp_gpus = dp_cp_group.size() + + # Handle VPP: extract the correct data_iterator for this PP stage + if ( + config.virtual_pipeline_model_parallel_size is not None + and config.virtual_pipeline_model_parallel_size > 1 + ): + # if enable VPP, data_iterator is a list of data_iterators for each VPP stage, + # and only the first and last stage rank will have data_iterator, + # other stages will have None. + assert len(data_iterator) == config.virtual_pipeline_model_parallel_size + if pp_group.rank() == 0: + # the first stage + data_iterator = data_iterator[0] + elif pp_group.rank() == pp_group.size() - 1: + # the last stage + data_iterator = data_iterator[-1] + else: + data_iterator = None + + # data_iterator is not None when TP rank 0, with PP stage 0 or -1. + if data_iterator is not None: + assert tp_group.rank() == 0 and ( + pp_group.rank() == 0 or pp_group.rank() == pp_group.size() - 1 + ), f"Only TP rank 0 and PP stage 0 or -1 should have data_iterator" + + # Step 1: Fetch batches and gather global sequence lengths + batch, global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered = ( + self._get_batch_and_global_seqlens(data_iterator, num_microbatches, dp_group) + ) + + # Step 2: Check required sample keys + for key in self.get_require_sample_keys(): + assert key in batch[0], f"Batch missing required key {key}" + + # Step 3: Schedule samples into groups + sample_id_groups = self.get_groups_and_subsamples(global_id_seqlens) + + # Validate scheduling result + set_gbs = set() + for group in sample_id_groups: + for sub in group: + set_gbs.update(sub) + assert len(set_gbs) == len(global_id_seqlens), ( + f"set_gbs length: {len(set_gbs)} != " + f"global_id_seqlens length: {len(global_id_seqlens)}" + ) + + # Step 4: Reroute samples to DCP ranks + samples_this_rank_with_id = self._reroute_samples_to_dcp_ranks( + batch, + global_ids_this_rank, + global_id_seqlens, + sample_id_groups, + offsets, + dp_group, + tp_group, + dp_cp_group, + total_dcp_gpus, + ) + + dcp_rank = dp_cp_group.rank() + num_micro_batches = len(sample_id_groups) + + grouped_samples = [ + [ + samples_this_rank_with_id[sub_sample_id] + for sub_sample_id in sample_id_groups[i][dcp_rank] + ] + for i in range(num_micro_batches) + ] + + # Step 5: Build packed microbatches + new_samples = self._build_packed_microbatches(grouped_samples, dev) + + # Step 6: Calculate FLOPs info + seqlen_sum_this_global_batch = float(sum(seqlens_gathered)) + seqlen_squared_sum_this_global_batch = float( + sum(seqlen**2 for seqlen in seqlens_gathered) + ) + else: + ( + new_samples, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ) = (None, None, None, None) + + # Step 7: Broadcast to PP group (for middle PP stages) + if tp_group.rank() == 0: + ( + new_samples, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ) = self._broadcast_to_pp_group( + new_samples, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + pp_group, + dev, + ) + + # Step 8: Broadcast to TP group (for non-TP-0 ranks) + (num_micro_batches, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch) = ( + self._broadcast_scalars( + [ + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ], + tp_group, + dev, + ) + ) + num_micro_batches = int(num_micro_batches) + + # Step 9: create data_iterator and handle VPP if enabled + new_data_iterator = self._create_data_iterator(new_samples, pp_group, tp_group, config) + + return ( + new_data_iterator, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ) + + +def wrap_dataloader( + data_iterator, config, num_microbatches, pg_collection: Optional[ProcessGroupCollection] = None +): + """ + A wrapper function that wraps around an existing data_iterator + and return the num_micro_batches for sequence packing. + + Args: + data_iterator: The original data_iterator to wrap around + config: The config object containing the max_seqlen_per_dp_cp_rank + dp_cp_group: Data parallel context parallel group. + pg_collection: The process group collection. + """ + + scheduler_map: Dict[PackingScheduler, Type[BaseScheduler]] = { + PackingScheduler.DEFAULT_SEQUENCE_PACKING: DefaultSequencePackingScheduler + } + + if pg_collection is None: + dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True) + dp_group = parallel_state.get_data_parallel_group() + tp_group = parallel_state.get_tensor_model_parallel_group() + pp_group = parallel_state.get_pipeline_model_parallel_group() + else: + dp_cp_group = pg_collection.dp_cp + dp_group = pg_collection.dp + tp_group = pg_collection.tp + pp_group = pg_collection.pp + assert ( + dp_cp_group is not None + and dp_group is not None + and tp_group is not None + and pp_group is not None + ), "dp_cp_group, dp_group, tp_group must not be None when using sequence packing" + + dev = torch.cuda.current_device() + dp_size = dp_group.size() + cp_size = dp_cp_group.size() // dp_size + + # Convert string to enum + scheduler_type = config.sequence_packing_scheduler + scheduler_type = PackingScheduler[scheduler_type.upper()] + + scheduler = scheduler_map[scheduler_type]( + config.max_seqlen_per_dp_cp_rank, + cp_size, + dp_size, + # When VPP is enabled, align num_micro_batches to this multiple. + ( + None + if config.virtual_pipeline_model_parallel_size is None + else config.microbatch_group_size_per_vp_stage + ), + ) + + ( + new_data_iterator, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ) = scheduler.run( + data_iterator, num_microbatches, dp_group, tp_group, pp_group, dp_cp_group, dev, config + ) + + return ( + new_data_iterator, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ) + + +def get_batch_on_this_rank_for_sequence_packing( + data_iterator, mtp_on_this_rank: bool = False, vp_stage: Optional[int] = None +): + """ + Get a batch of data for sequence packing. + Args: + data_iterator (Iterator): The data iterator to get the batch from. + mtp_on_this_rank (bool): Whether to use multi-token prediction. + vp_stage (Optional[int]): The stage of the pipeline. + Returns: + tuple of (tokens, labels, loss_mask, attention_mask, position_ids, packed_seq_params) + """ + + tp_src_rank = parallel_state.get_tensor_model_parallel_src_rank() + tp_group = parallel_state.get_tensor_model_parallel_group() + + is_tp_rank_0 = parallel_state.get_tensor_model_parallel_rank() == 0 + is_first_stage = parallel_state.is_pipeline_first_stage( + ignore_virtual=vp_stage is None, vp_stage=vp_stage + ) + is_last_stage = parallel_state.is_pipeline_last_stage( + ignore_virtual=vp_stage is None, vp_stage=vp_stage + ) + is_first_or_last_stage = is_first_stage or is_last_stage + dev = torch.cuda.current_device() + + # data_iterator should return a batch including the following keys. + batch_keys = ['cu_seqlens', 'cu_seqlens_padded', 'max_seqlen'] + if is_first_stage: + batch_keys.append('tokens') + batch_keys.append('position_ids') + if is_last_stage: + batch_keys.append('labels') + batch_keys.append('loss_mask') + + # Get a batch from data_iterator or create an emtpy batch. + if is_tp_rank_0: + assert data_iterator is not None + batch = next(data_iterator) + for key in batch_keys: + assert key in batch, f"{key} is missing in current batch." + else: + assert data_iterator is None, "Non TP 0 rank should not have data_iterator" + batch = {} + + # Partition tokens, position_ids, labels, loss_mask for context parallel, currently only + # TP rank 0 and the first/last PP stage rank has these data. + if is_tp_rank_0 and is_first_or_last_stage: + cp_size = parallel_state.get_context_parallel_world_size() + cp_rank = parallel_state.get_context_parallel_rank() + # If cp_size == 1, no need to do further processing. + if cp_size > 1: + assert tex is not None and is_te_min_version("1.10.0"), ( + "Please update Transformer Engine to >= 1.10 to use " + "Context Parallel with THD format data" + ) + total_tokens = batch['tokens'].size(0) + # Transformer Engine has a bug of cu_seqlens, we must treat cu_seqlens_padded as + # cu_seqlens to get the correct result. + # TODO: Revert this workaround once TE fixes the issue. + cu_seqlens = batch["cu_seqlens_padded"] + index = tex.thd_get_partitioned_indices(cu_seqlens, total_tokens, cp_size, cp_rank) + for key in ['tokens', 'position_ids', 'labels', 'loss_mask']: + batch[key] = batch[key].index_select(0, index) + + # Broadcast cu_seqlens_size because we need it to create placeholder for cu_seqlens and + # cu_seqlens_padded for non TP 0 ranks. + if is_tp_rank_0: + cu_seqlen_size = torch.tensor(batch['cu_seqlens'].size(0), dtype=torch.int32, device=dev) + else: + cu_seqlen_size = torch.empty(1, dtype=torch.int32, device=dev) + _broadcast_tensor(cu_seqlen_size, tp_src_rank, tp_group) + cu_seqlen_size = cu_seqlen_size.item() + + # Broadcast total_tokens because we need it to create placeholder for tokens, position_ids, + # labels, loss_mask for non TP 0 ranks. Only first or last stage need this. + if is_first_or_last_stage: + if is_tp_rank_0: + total_tokens = torch.tensor(batch['tokens'].size(0), dtype=torch.int32, device=dev) + else: + total_tokens = torch.empty(1, dtype=torch.int32, device=dev) + _broadcast_tensor(total_tokens, tp_src_rank, tp_group) + total_tokens = total_tokens.item() + + # Step1: Prepare "tokens", "position_ids" on all ranks. + if is_first_stage or mtp_on_this_rank: + if is_tp_rank_0: + assert batch['tokens'].dtype == torch.int64 + assert batch['position_ids'].dtype == torch.int64 + batch['tokens'] = batch['tokens'].view(1, total_tokens) + batch['position_ids'] = batch['position_ids'].view(1, total_tokens) + else: + batch['tokens'] = torch.empty([1, total_tokens], dtype=torch.int64, device=dev) + batch['position_ids'] = torch.empty([1, total_tokens], dtype=torch.int64, device=dev) + else: + # Non first stage rank doesn't need tokens and position_ids. + batch['tokens'] = None + batch['position_ids'] = None + + # Step2: Prepare "labels", "loss_mask" on all ranks. + if is_last_stage: + if is_tp_rank_0: + assert batch['labels'].dtype == torch.int64 + assert batch['loss_mask'].dtype == torch.float32 + batch['labels'] = batch['labels'].view(1, total_tokens) + batch['loss_mask'] = batch['loss_mask'].view(1, total_tokens) + else: + batch['labels'] = torch.empty([1, total_tokens], dtype=torch.int64, device=dev) + batch['loss_mask'] = torch.empty([1, total_tokens], dtype=torch.float32, device=dev) + else: + # Non last stage rank doesn't need labels and loss_mask. + batch['labels'] = None + batch['loss_mask'] = None + + # Step3: Prepare "cu_seqlens", "cu_seqlens_padded", "max_seqlen" on all ranks. + if is_tp_rank_0: + assert batch['cu_seqlens'].dtype == torch.int32 + assert batch['cu_seqlens_padded'].dtype == torch.int32 + assert batch['cu_seqlens'].dim() == 1 + assert batch['cu_seqlens_padded'].dim() == 1 + if type(batch['max_seqlen']) == int: + batch['max_seqlen'] = torch.tensor(batch['max_seqlen'], dtype=torch.int32, device=dev) + else: + assert batch['max_seqlen'].dtype == torch.int32 + assert batch['max_seqlen'].numel() == 1 + else: + batch['cu_seqlens'] = torch.empty([cu_seqlen_size], dtype=torch.int32, device=dev) + batch['cu_seqlens_padded'] = torch.empty([cu_seqlen_size], dtype=torch.int32, device=dev) + batch['max_seqlen'] = torch.empty(1, dtype=torch.int32, device=dev) + + # Broadcast batch inside TP group. + _broadcast_tensor(batch['tokens'], tp_src_rank, tp_group) + _broadcast_tensor(batch['position_ids'], tp_src_rank, tp_group) + _broadcast_tensor(batch['labels'], tp_src_rank, tp_group) + _broadcast_tensor(batch['loss_mask'], tp_src_rank, tp_group) + _broadcast_tensor(batch['cu_seqlens'], tp_src_rank, tp_group) + _broadcast_tensor(batch['cu_seqlens_padded'], tp_src_rank, tp_group) + _broadcast_tensor(batch['max_seqlen'], tp_src_rank, tp_group) + + # Extract the data from batch after broadcasting. + tokens = batch['tokens'] + position_ids = batch['position_ids'] + labels = batch['labels'] + loss_mask = batch['loss_mask'] + cu_seqlens = batch['cu_seqlens'] + cu_seqlens_padded = batch['cu_seqlens_padded'] + max_seqlen = batch['max_seqlen'].item() + + # Transformer Engine has a bug of cu_seqlens, we must treat cu_seqlens_padded as cu_seqlens to + # get the correct result. + # TODO: Revert this workaround once TE fixes the issue. + packed_seq_params = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_seqlens_padded, + cu_seqlens_kv=cu_seqlens_padded, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + max_seqlen_q=max_seqlen, + max_seqlen_kv=max_seqlen, + local_cp_size=None, + cp_group=None, + ) + + # "attention_mask" is not valid for sequence packing, so set it to None. + return tokens, labels, loss_mask, None, position_ids, packed_seq_params diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py index cbe0652402d..fa421641db5 100644 --- a/megatron/core/datasets/gpt_dataset.py +++ b/megatron/core/datasets/gpt_dataset.py @@ -79,6 +79,12 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig): context_parallel_size: Optional[int] = None """The size of the context parallel group. Needed for padding in packed sequences.""" + sft_mock_dataset_config_json: Optional[str] = None + """This config provides the necessary information for the mock dataset.""" + + sequence_packing: bool = False + """Option to enable sequence packing for training.""" + def __post_init__(self) -> None: """Do asserts and set fields post init""" super().__post_init__() diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index 3c6ff04d3b0..18b0144041b 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -59,7 +59,7 @@ class ModelParallelConfig: can handle without overflowing the memory. Typically, a good starting point is to set this to maximum sequence length / context parallel size. This is used to calculate the number and length of sub-samples assigned to - each rank when using hybrid_context_parallel. + each rank when using sequence_packing. """ hybrid_context_parallel: bool = False @@ -69,6 +69,17 @@ class ModelParallelConfig: Please set max_seqlen_per_dp_cp_rank when using hybrid_context_parallel. """ + sequence_packing_scheduler: Optional[str] = None + """ + Scheduler for sequence packing and hybrid context parallel. + default_sequence_packing: default sequence packing scheduler for sequence packing. + """ + + sequence_packing: bool = False + """ + If true, enables sft sequence packing. + """ + expert_model_parallel_size: int = 1 """Distributes Moe Experts across sub data parallel dimension.""" diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 15c5adfc7a2..6ed6317d61c 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -552,6 +552,9 @@ def forward_backward_no_pipelining( pg_collection.dp_cp = parallel_state.get_data_parallel_group( with_context_parallel=True, partial_data_parallel=False ) + pg_collection.dp = parallel_state.get_data_parallel_group( + with_context_parallel=False, partial_data_parallel=False + ) elif pg_collection is not None: assert hasattr(pg_collection, 'tp') @@ -879,6 +882,9 @@ def forward_backward_pipelining_with_interleaving( pg_collection.dp_cp = parallel_state.get_data_parallel_group( with_context_parallel=True, partial_data_parallel=False ) + pg_collection.dp = parallel_state.get_data_parallel_group( + with_context_parallel=False, partial_data_parallel=False + ) elif p2p_communicator is not None and pg_collection is not None: model_type = get_model_type(model[0]) @@ -2028,6 +2034,9 @@ def forward_backward_pipelining_without_interleaving( pg_collection.dp_cp = parallel_state.get_data_parallel_group( with_context_parallel=True, partial_data_parallel=False ) + pg_collection.dp = parallel_state.get_data_parallel_group( + with_context_parallel=False, partial_data_parallel=False + ) elif p2p_communicator is not None and pg_collection is not None: model_type = get_model_type(model) assert model_type != ModelType.encoder_and_decoder, ( diff --git a/megatron/core/tokenizers/text/libraries/sft_tokenizer.py b/megatron/core/tokenizers/text/libraries/sft_tokenizer.py index 8a418f2dd7f..e06e51e5ee3 100644 --- a/megatron/core/tokenizers/text/libraries/sft_tokenizer.py +++ b/megatron/core/tokenizers/text/libraries/sft_tokenizer.py @@ -237,6 +237,11 @@ def bos_id(self): def eod(self): """End of sentence token ID.""" return self._tokenizer.eos_token_id + + @property + def eos(self): + """End of sentence token ID.""" + return self._tokenizer.eos_token_id @property def vocab(self): diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 1aad3c4b89f..736f3573c81 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -2019,6 +2019,38 @@ def __post_init__(self): self.attention_backend == AttnBackend.flash ), "Batch invariant mode only supports FlashAttention" + if self.sequence_packing: + # Check TE version. + if not HAVE_PACKAGING: + raise ImportError( + "packaging is not installed. Please install it with `pip install packaging`." + ) + # TODO: remove this after we fix the convergence issue with TE < 2.9. + if not ( + is_te_min_version("2.9.0") or get_te_version() == PkgVersion("2.9.0.dev0+5b3092a") + ): + raise ValueError( + "SFT sequence packing requires Transformer Engine >= 2.9.0 " + f"but got {get_te_version()} (TE < 2.9.0 may have convergence issues)." + ) + + # Needed for passing variable sequences between pp stages. + self.variable_seq_lengths = True + + # TODO(tailaim): add support for other dispatcher types + warnings.warn("Setting moe_token_dispatcher_type to alltoall for sft sequence packing.") + self.moe_token_dispatcher_type = "alltoall" + + if self.sequence_packing_scheduler is None: + self.sequence_packing_scheduler = 'default_sequence_packing' + + supported_schedulers = ['default_sequence_packing'] + if self.sequence_packing_scheduler not in supported_schedulers: + raise ValueError( + f"Unknown scheduler: {self.sequence_packing_scheduler}. " + f"Available schedulers: {supported_schedulers}" + ) + @dataclass class MLATransformerConfig(TransformerConfig): diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index eed05652e09..a3a9cce9922 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -823,13 +823,6 @@ def validate_args(args, defaults={}): if args.rl_use_sequence_packing: args.consumed_train_bins = 0 - # Support for variable sequence lengths across batches/microbatches. - # set it if the dataloader supports generation of variable sequence lengths - # across batches/microbatches. Due to additional communication overhead - # during pipeline parallelism, it should not be set if sequence length - # is constant during training. - args.variable_seq_lengths = False - # Iteration-based training. if args.train_iters: # If we use iteration-based training, make sure the @@ -1000,6 +993,32 @@ def validate_args(args, defaults={}): assert args.dataloader_type == 'single', 'Hybrid context parallelism only supported with single dataloader type' assert args.calculate_per_token_loss, 'Hybrid context parallelism must be used with --calculate-per-token-loss' + # Support for variable sequence lengths across batches/microbatches. + # set it if the dataloader supports generation of variable sequence lengths + # across batches/microbatches. Due to additional communication overhead + # during pipeline parallelism, it should not be set if sequence length + # is constant during training. + args.variable_seq_lengths = False + if args.sequence_packing: + args.variable_seq_lengths = True + assert args.context_parallel_size * args.max_seqlen_per_dp_cp_rank >= args.seq_length, \ + f'Packed sequence buffer size ({args.context_parallel_size * args.max_seqlen_per_dp_cp_rank}) ' \ + f'must be >= single sequence max length ({args.seq_length})' + # TODO(tailaim): add support for other dispatcher types + print(f"Setting moe_token_dispatcher_type to alltoall for sft sequence packing with pipeline parallelism") + args.moe_token_dispatcher_type = "alltoall" + if args.mock_data and args.sft_mock_dataset_config_json is None: + args.sft_mock_dataset_config_json = json.dumps( + { + "mode": "distribution", + "type": "lognormal", + "min_seq_len": args.seq_length // 2, + "max_seq_len": args.seq_length, + "mean_seq_len": args.seq_length // 4 * 3, + "lognormal_sigma": 1.1, + } + ) + # disable async_tensor_model_parallel_allreduce when # model parallel memory optimization is enabled if (args.tensor_model_parallel_size > 1 or args.context_parallel_size > 1) \ @@ -2399,6 +2418,14 @@ def _add_distributed_args(parser): 'all layers will share the same communication type. Users can also ' 'specify separated types for each layer like ' '--cp-comm-type p2p p2p a2a a2a a2a+p2p a2a+p2p') + group.add_argument('--max-seqlen-per-dp-cp-rank', type=int, default=None, + help='Maximum sequence length per CP rank. This is used to calculate the ' + 'number of sub-samples assigned to each CP rank when using heterogeneous context parallel.') + group.add_argument('--hybrid-context-parallel', action='store_true', default=False, + help='Enables hybrid context parallel. This is used to balance the workload ' + 'of each CP rank when we use packed samples with variable sequence lengths. ' + 'Requires --max-seqlen-per-dp-cp-rank to be set.') + group.add_argument('--sequence-packing-scheduler', type=str, default='default_sequence_packing', choices=['default_sequence_packing']) group.add_argument('--fake-process-group', action='store_true', default=False, help='If set, initialize with fake distributed process group and all distributed communication operations will be skipped. \ This is quite useful for profiling memory usage of distributed training with just one GPU. \ @@ -2940,4 +2967,8 @@ def _add_sft_args(parser): group.add_argument('--sft', action="store_true", help='Megatron SFT training') group.add_argument('--sft-tokenizer-prompt-format', type=str, default="nemotron-h-aligned", help='SFT prompt format.') + group.add_argument('--sequence-packing', action='store_true', + help='use sequence packing(thd format) for training') + group.add_argument('--sft-mock-dataset-config-json', type=str, default=None, + help='This config provides the necessary information for the mock dataset. You can either specify a CSV file that contains sequence lengths, where each line stores the length of a sequence, for example: {"mode":"file","path":"/path/to/file"}. Alternatively, you can specify a distribution (currently only supporting lognormal distribution) along with the required parameters, for example, {"mode":"distribution","type":"lognormal","min_seq_len":1024,"max_seq_len":2048,"mean_seq_len":1536,"lognormal_sigma":1.1}, where sigma controls the variability of the lognormal distribution.') return parser diff --git a/megatron/training/datasets/sft_dataset.py b/megatron/training/datasets/sft_dataset.py index fd9d1fe7c14..60ed2b45975 100644 --- a/megatron/training/datasets/sft_dataset.py +++ b/megatron/training/datasets/sft_dataset.py @@ -2,9 +2,12 @@ import atexit, json from collections import Counter -from typing import Any, Dict, Optional +import json +import math +from typing import Any, Dict, Optional, List import numpy as np +import pandas as pd import torch from megatron.core.datasets.gpt_dataset import GPTDatasetConfig @@ -61,6 +64,8 @@ def __init__( config: GPTDatasetConfig, ) -> None: super().__init__(dataset, dataset_path, indices, num_samples, index_split, config) + # Pre-calculate padding divisor to avoid redundant computation in get_padding_size + self.padding_divisor = self._calculate_padding_divisor() @staticmethod def numel_low_level_dataset(low_level_dataset: LowLevelDataset) -> int: @@ -88,8 +93,38 @@ def _split_conversations(self, merged_conversations): split_conversations.append(current) return split_conversations - def __getitem__(self, idx: int) -> Dict[str, Any]: + def _calculate_padding_divisor(self) -> int: + """ + Calculate the divisor used for sequence padding. + tp_pad = tp_size * 2 if tp_size > 1 else 1 + cp_pad = cp_size * 2 if cp_size > 1 else 1 + cp_pad = cp_pad * dp_size if hybrid_cp else cp_pad + divisor = cp_pad * tp_pad + """ + if self.config.hybrid_context_parallel: + # Hybrid CP: consider both CP and DP + cp_pad = self.config.data_parallel_size * self.config.context_parallel_size * 2 + else: + # Standard CP: only consider CP + cp_pad = self.config.context_parallel_size * 2 if self.config.context_parallel_size > 1 else 1 + tp_pad = self.config.sequence_parallel_size if self.config.sequence_parallel_size > 0 else 1 + divisor = cp_pad * tp_pad + # TODO(tailaim): do we need to pad for FP8 execution? + # divisor = ((divisor + 15) // 16) * 16 + return divisor + + def get_padding_size( + self, + seq_len: int, + ) -> int: + seq_len_padded = math.ceil(seq_len / self.padding_divisor) * self.padding_divisor + assert seq_len > seq_len_padded / 2 / self.config.context_parallel_size * (self.config.context_parallel_size - 1), \ + f"sequence length {seq_len} is too short, the divisor is {self.padding_divisor}, that means cp_rank \ + {self.config.context_parallel_size-1} will have no valid tokens" + return seq_len_padded + def __getitem__(self, idx: int) -> Dict[str, Any]: + sequence_packing = self.config.sequence_packing tokenizer = self.config.tokenizer pack_length = self.config.sequence_length @@ -194,3 +229,138 @@ def extend_with_padding(tokens, targets, positions, pad_len): 'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, } + +class MockSFTLowLevelDataset: + """The low-level mock dataset for SFT + + Args: + mock_config (dict): The config for mock dataset. + """ + + seed: int = 0 + """The hard-coded random seed to use to set the NumPy RNG""" + + size: int = 1000000 + """The hard-coded number of sequence to generate""" + + # This is to maintain consistency with the SFT dataset that uses real data. In the real dataset, an element in the low-level dataset often contains multiple sequences. So here, each element in the mock low-level dataset also contains num_sequence_per_sample sequences. This will be made more reasonable in the future. + + + def __init__(self, config: Dict) -> None: + np.random.seed(self.seed) + # either choose to load sequence lengths from external file, or generate random sequence lengths + + assert "mode" in config, f"mode must be set, either 'file' or 'distribution'" + + if config["mode"] == "file": + self.sequence_lengths = np.array(pd.read_csv(config["path"])).flatten() + self.size = len(self.sequence_lengths) + elif config["mode"] == "distribution": + min_seq_len = config["min_seq_len"] + max_seq_len = config["max_seq_len"] + mean_seq_len = config["mean_seq_len"] + if config["type"] == "lognormal": + lognormal_sigma = config["lognormal_sigma"] + self.sequence_lengths = self.generate_lognormal_samples(self.size, mean_seq_len,lognormal_sigma, min_seq_len, max_seq_len) + else: + raise ValueError(f"Unsupported sequence length distribution type {config['type']}") + + def generate_lognormal_samples(self, size, mean, sigma, min_seq_len, max_seq_len): + mu = np.log(mean) - sigma**2 / 2 + samples = np.random.lognormal(mu, sigma, size) + samples = np.clip(samples, min_seq_len, max_seq_len) + return samples.astype(int) + + def __len__(self) -> int: + return self.size + + def __getitem__(self, idx: int) -> List[np.ndarray]: + length = self.sequence_lengths[idx % self.size] + # the length of sample is 'length', but only length-1 elements are generated here, + # because an eod token will be appended at the end later in SFTDataset + sample = np.arange(1, length, dtype=np.int64) + return sample +class MockSFTDataset(SFTDataset): + """The mock dataset used during SFT""" + + def __init__( + self, + dataset: LowLevelDataset, + dataset_path: Optional[str], + indices: np.ndarray, + num_samples: Optional[int], + index_split: Split, + config: GPTDatasetConfig, + ) -> None: + super().__init__(dataset, dataset_path, indices, num_samples, index_split, config) + + @staticmethod + def build_low_level_dataset(dataset_path: str, config: GPTDatasetConfig) -> LowLevelDataset: + mock_config = json.loads(config.sft_mock_dataset_config_json) + return MockSFTLowLevelDataset(mock_config) + + def __len__(self) -> int: + return self.num_samples + + def __getitem__(self, idx: int) -> Dict[str, Any]: + sequence_packing = self.config.sequence_packing + tokenizer = self.config.tokenizer + max_seq_len = self.config.sequence_length + + tokens = self.dataset[int(self.indices[idx % len(self.indices)])] + target = np.array(tokens, dtype=np.int64) + + force_eod_length = int(tokenizer.force_eod) + + if len(tokens) > max_seq_len - force_eod_length: + # cut the right side + tokens = tokens[: max_seq_len - force_eod_length] + target = target[: max_seq_len - force_eod_length] + # tokens = tokens[(-max_seq_len + force_eod_length):] + # target = target[(-max_seq_len + force_eod_length):] + + # padding + num_tokens = len(tokens) + force_eod_length + if sequence_packing: + padding_len = self.get_padding_size(num_tokens) - num_tokens + else: + padding_len = max_seq_len - num_tokens + assert padding_len >= 0 + filler = [tokenizer.eod] * force_eod_length + [tokenizer.pad] * (padding_len + 1) + + tokens = np.array(tokens.tolist() + filler, dtype=np.int64) + target = np.array(target.tolist() + filler, dtype=np.int64) + + tokens = torch.tensor(tokens) + target = torch.tensor(target) + + tokens = tokens[:-1].contiguous() + target = target[1:].contiguous() + seq_len = tokens.numel() + + loss_mask, position_ids, attention_mask = self._get_ltor_masks_and_position_ids( + seq_len, target, tokenizer.pad + ) + + if self.config.create_attention_mask: + ret = { + 'tokens': tokens, + 'labels': target, + 'attention_mask': attention_mask, + 'loss_mask': loss_mask, + 'position_ids': position_ids, + } + else: + ret = { + 'tokens': tokens, + 'labels': target, + 'loss_mask': loss_mask, + 'position_ids': position_ids, + } + + if sequence_packing: + # sequence packing need both original sequence length and padded length + ret['original_seq_len'] = torch.tensor(num_tokens, dtype=torch.int32, device=tokens.device) + ret['padded_seq_len'] = torch.tensor(seq_len, dtype=torch.int32, device=tokens.device) + + return ret diff --git a/megatron/training/training.py b/megatron/training/training.py index 2c68c70735d..a5b225b2c82 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -139,7 +139,6 @@ def set_startup_timestamps(program_start=None, main_entry=None): from megatron.training.initialize import set_jit_fusion_options from megatron.training.utils import get_batch_on_this_cp_rank, get_batch_on_this_tp_rank from megatron.training.datasets.data_samplers import build_pretraining_data_loader -from megatron.core.datasets.data_schedule import HybridCPDataLoaderWrapper from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler from megatron.core.transformer.moe import upcycling_utils from megatron.core.transformer.moe.moe_utils import track_moe_metrics, clear_aux_losses_tracker @@ -169,6 +168,7 @@ def set_startup_timestamps(program_start=None, main_entry=None): get_num_microbatches, update_num_microbatches ) +from megatron.core.datasets.data_schedule import wrap_dataloader from .async_utils import maybe_finalize_async_save from .utils import ( @@ -225,7 +225,7 @@ def print_datetime(string, override_timestamp=None): time_str = datetime.fromtimestamp(override_timestamp).strftime('%Y-%m-%d %H:%M:%S.%f') print_rank_0(f'[{string}] datetime: {time_str} ') -def num_floating_point_operations(args, batch_size): +def num_floating_point_operations(args, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch): def calculate_layer_counts(): """Calculate the number of attention, Mamba, and MLP layers.""" if args.hybrid_override_pattern: @@ -251,44 +251,42 @@ def calculate_layer_counts(): num_moe_layers = 0 return num_attn_layers, num_mamba_layers, num_mlp_layers, num_moe_layers - def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=False): + def mlp_layer_flops(seqlen_sum_this_global_batch, hidden_size, expansion=4.0, swiglu=False): """Calculate FLOPs for an MLP layer.""" scale_factor = 3.0 / 2.0 if swiglu else 1.0 - return 4 * expansion * scale_factor * batch_size * seq_len * hidden_size**2 + return 4 * expansion * scale_factor * seqlen_sum_this_global_batch * hidden_size**2 - def moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size, + def moe_layer_flops(seqlen_sum_this_global_batch, hidden_size, moe_ffn_hidden_size, shared_expert_ffn_hidden_size, num_experts_routed_to, moe_latent_size=None, swiglu=False): """Calculate FLOPs for an MoE layer.""" scale_factor = 3.0 / 2.0 if swiglu else 1.0 if moe_latent_size is None: - routed_flops = (4 * batch_size * seq_len * hidden_size * + routed_flops = (4 * seqlen_sum_this_global_batch * hidden_size * moe_ffn_hidden_size * num_experts_routed_to * scale_factor) else: # Routed experts run on moe_latent_size. - routed_flops = (4 * batch_size * seq_len * moe_latent_size * + routed_flops = (4 * seqlen_sum_this_global_batch * moe_latent_size * moe_ffn_hidden_size * num_experts_routed_to * scale_factor) # Up proj and down proj. - routed_flops += (4 * batch_size * seq_len * hidden_size * moe_latent_size) - shared_flops = 4 * batch_size * seq_len * hidden_size * shared_expert_ffn_hidden_size * scale_factor + routed_flops += (4 * seqlen_sum_this_global_batch * hidden_size * moe_latent_size) + shared_flops = 4 * seqlen_sum_this_global_batch * hidden_size * shared_expert_ffn_hidden_size * scale_factor return routed_flops + shared_flops def attn_layer_flops( - batch_size, seq_len, hidden_size, num_heads, gqa=True, gqa_groups=8, kv_channels=None + seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch, hidden_size, num_heads, gqa=True, gqa_groups=8, kv_channels=None ): """Calculate FLOPs for an attention layer.""" p = (kv_channels * num_heads / hidden_size) if kv_channels else 1 g = gqa_groups if gqa else num_heads return ( 4 - * batch_size - * seq_len * hidden_size * p - * (hidden_size + (hidden_size * (g / num_heads)) + (seq_len / 2)) + * (hidden_size * seqlen_sum_this_global_batch + (hidden_size * (g / num_heads)) * seqlen_sum_this_global_batch + (seqlen_squared_sum_this_global_batch / 2)) ) - def mamba_layer_flops(batch_size, seq_len, hidden_size, state_dim=16, + def mamba_layer_flops(seqlen_sum_this_global_batch, hidden_size, state_dim=16, head_dim=64, num_groups=1, num_heads=128): """Calculate FLOPs for a Mamba layer.""" # Note (rwaleffe): flops estimate for scan should be updated based on new SSD kernels, @@ -301,16 +299,15 @@ def mamba_layer_flops(batch_size, seq_len, hidden_size, state_dim=16, return ( ( 2 - * batch_size - * seq_len + * seqlen_sum_this_global_batch * hidden_size * (2 * d_in + 2 * num_groups * state_dim + nheads) ) # in_proj - + (7 * batch_size * seq_len * d_in * state_dim) # scan - + (2 * batch_size * seq_len * d_in * hidden_size) # out_proj + + (7 * seqlen_sum_this_global_batch * d_in * state_dim) # scan + + (2 * seqlen_sum_this_global_batch * d_in * hidden_size) # out_proj ) - def hybrid_flops(batch_size, seq_len, hidden_size, + def hybrid_flops(seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch, hidden_size, num_attn_layers, num_mamba_layers, num_mlp_layers, num_moe_layers, mamba_state_dim=128, mamba_head_dim=64, mamba_num_groups=8, mamba_num_heads=128, @@ -322,17 +319,17 @@ def hybrid_flops(batch_size, seq_len, hidden_size, vocab_size=256000, mtp_num_layers=0): """Calculate total FLOPs for the hybrid model.""" flops_fwd = ( - num_attn_layers * attn_layer_flops(batch_size, seq_len, hidden_size, + num_attn_layers * attn_layer_flops(seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch, hidden_size, num_attn_heads, gqa, gqa_groups, kv_channels) + - num_mlp_layers * mlp_layer_flops(batch_size, seq_len, hidden_size, + num_mlp_layers * mlp_layer_flops(seqlen_sum_this_global_batch, hidden_size, mlp_expansion, swiglu) + - num_mamba_layers * mamba_layer_flops(batch_size, seq_len, hidden_size, + num_mamba_layers * mamba_layer_flops(seqlen_sum_this_global_batch, hidden_size, mamba_state_dim, mamba_head_dim, mamba_num_groups, mamba_num_heads) + - num_moe_layers * moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size, + num_moe_layers * moe_layer_flops(seqlen_sum_this_global_batch, hidden_size, moe_ffn_hidden_size, shared_expert_ffn_hidden_size, num_experts_routed_to, moe_latent_size, swiglu) + - (2 * batch_size * seq_len * hidden_size * vocab_size * (1 + mtp_num_layers)) # logits computation + (2 * seqlen_sum_this_global_batch * hidden_size * vocab_size * (1 + mtp_num_layers)) # logits computation ) return flops_fwd * 3 @@ -403,13 +400,18 @@ def transformer_flops(): assert not args.group_query_attention ''' Basic arithmetic - let B is batch size, s is seq_len, h is embedding dim, - for one self_attnetion block (prenorm is not included) - qkv projection: 6Bsh^2 - attn: 2Bs^2h - attn over value: 2Bs^2h - oproj: 2Bsh^2 - + + Let h be the embedding dim. + We use two statistics to unify BSHD and THD cases: + seqlen_sum_this_global_batch: total number of tokens in this global batch + seqlen_squared_sum_this_global_batch: sum of squared sequence lengths in this global batch + + For one self-attention block (prenorm not included): + qkv projection: 6 * seqlen_sum_this_global_batch * h^2 + attn: 2 * seqlen_squared_sum_this_global_batch * h + attn over value: 2 * seqlen_squared_sum_this_global_batch * h + oproj: 2 * seqlen_sum_this_global_batch * h^2 + references https://arxiv.org/abs/2305.10403 https://arxiv.org/abs/2205.05198 @@ -430,7 +432,7 @@ def transformer_flops(): standard_self_attn_term = ( forward_backward_expansion_factor * fma_expansion_factor - * ( + * ( seqlen_sum_this_global_batch * ( ## q lora + rope + q norm q_term ## kv lora + rope + kv norm @@ -442,12 +444,12 @@ def transformer_flops(): ) + args.hidden_size * args.qk_pos_emb_head_dim ## o proj - + (args.num_attention_heads * args.v_head_dim) * args.hidden_size + + (args.num_attention_heads * args.v_head_dim) * args.hidden_size) ## core attn - + args.seq_length + + seqlen_squared_sum_this_global_batch * (args.num_attention_heads * (args.qk_head_dim + args.qk_pos_emb_head_dim)) - / 2 # causal mask (only half of the mask is non-zero) - + args.seq_length * args.num_attention_heads * args.v_head_dim / 2 + / 2 # causal mask (only half of the mask is non-zero) + + seqlen_squared_sum_this_global_batch * args.num_attention_heads * args.v_head_dim / 2 ) ) @@ -460,7 +462,7 @@ def transformer_flops(): standard_self_attn_term = ( forward_backward_expansion_factor * fma_expansion_factor - * ( + * ( seqlen_sum_this_global_batch *( ## qkv proj args.hidden_size * ( @@ -468,14 +470,14 @@ def transformer_flops(): + key_projection_size + value_projection_size + gate_projection_size - ) + )) ## core attention + query_projection_size - * args.seq_length + * seqlen_squared_sum_this_global_batch / 2 # causal mask (only half of the mask is non-zero) * 2 # QK^T and (QK^T)V ## out proj - + query_projection_size + + seqlen_sum_this_global_batch * query_projection_size * args.hidden_size ) ) @@ -553,8 +555,7 @@ def transformer_flops(): ) total_floating_point_operations = ( - batch_size - * args.seq_length + seqlen_sum_this_global_batch * ( # MLP forward_backward_expansion_factor @@ -584,8 +585,6 @@ def transformer_flops(): + (shared_expert_ffn_hidden_size * ffn_expansion_factor) * num_moe_layers ) - # Self Attention - + self_attn_term # MTP norms and proj + forward_backward_expansion_factor * fma_expansion_factor @@ -603,6 +602,10 @@ def transformer_flops(): * args.padded_vocab_size * (mtp_num_layers + 1) # MTP + final logit ) + + + # Self Attention + self_attn_term + ) return total_floating_point_operations @@ -616,8 +619,8 @@ def transformer_flops(): mtp_num_layers = 0 # Compute hybrid model FLOPs. return hybrid_flops( - batch_size=batch_size, - seq_len=args.seq_length, + seqlen_sum_this_global_batch=seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch=seqlen_squared_sum_this_global_batch, hidden_size=args.hidden_size, num_attn_layers=num_attn_layers, num_mamba_layers=num_mamba_layers, @@ -1712,6 +1715,19 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch if isinstance(optim_instance, DistributedOptimizer): optim_instance._copy_main_params_to_param_buffer() + if config.sequence_packing: + ( + data_iterator, + num_microbatches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ) = wrap_dataloader(data_iterator, config, num_microbatches) + else: + # data_iterator unchanged + num_microbatches = get_num_microbatches() + seqlen_sum_this_global_batch = args.seq_length * args.micro_batch_size * args.data_parallel_size * num_microbatches + seqlen_squared_sum_this_global_batch = args.seq_length ** 2 * args.micro_batch_size * args.data_parallel_size * num_microbatches + # Forward pass. if save_dgrads_in_this_iteration: enable_dgrad_logging(model, args.save) @@ -1719,7 +1735,7 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch forward_step_func=forward_step_func, data_iterator=data_iterator, model=model, - num_microbatches=get_num_microbatches(), + num_microbatches=num_microbatches, seq_length=args.seq_length, micro_batch_size=args.micro_batch_size, decoder_seq_length=args.decoder_seq_length, @@ -1750,9 +1766,10 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch # iteration is 0-indexed, move to 1-indexed for checkpoint name and logging. save_grads(args.save, state_dict, iteration + 1, "wgrads") + should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit() if should_exit: - return {}, True, should_checkpoint, should_exit, exit_code, None, None, 0 + return {}, True, should_checkpoint, should_exit, exit_code, None, None, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch, 0 # Empty unused memory. if args.empty_unused_memory_level >= 1: @@ -1832,8 +1849,10 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch grad_norm, num_zeros_in_grad, log_max_attention_logit, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, ) - return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad, log_max_attention_logit + return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad, log_max_attention_logit, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch def training_log( @@ -1848,6 +1867,8 @@ def training_log( params_norm, num_zeros_in_grad, max_attention_logit, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, pg_collection=None, is_first_iteration=False, ): @@ -2082,7 +2103,7 @@ def training_log( elapsed_time = timers('interval-time').elapsed(barrier=True, reset=should_reset) elapsed_time_per_iteration = elapsed_time / total_iterations - throughput = num_floating_point_operations(args, batch_size) / ( + throughput = num_floating_point_operations(args,seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch) / ( elapsed_time_per_iteration * 10**12 * args.world_size ) @@ -2833,6 +2854,8 @@ def trace_handler(p): # Completely skip iteration if needed. if iteration in args.iterations_to_skip: + # TODO(tailaim): this need to be modified + assert not args.sequence_packing, "Sequence packing is not supported in skip iteration mode" # Dummy train_step to fast forward train_data_iterator. dummy_train_step(train_data_iterator) if iteration == start_iteration: @@ -2875,6 +2898,8 @@ def trace_handler(p): grad_norm, num_zeros_in_grad, max_attention_logit, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, ) = train_step( forward_step_func, train_data_iterator, model, optimizer, opt_param_scheduler, config, forward_backward_func, iteration=iteration ) @@ -2962,7 +2987,7 @@ def trace_handler(p): else: assert num_skipped_samples_in_batch == 0 args.skipped_train_samples += num_skipped_samples_in_batch - num_floating_point_operations_in_batch = num_floating_point_operations(args, batch_size) + num_floating_point_operations_in_batch = num_floating_point_operations(args, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch) num_floating_point_operations_so_far += num_floating_point_operations_in_batch num_floating_point_operations_since_last_log_event += num_floating_point_operations_in_batch @@ -2988,6 +3013,8 @@ def trace_handler(p): params_norm, num_zeros_in_grad, max_attention_logit, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, pg_collection=model_pg_collection, is_first_iteration=is_first_iteration, ) diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 81768944623..1ce598d6868 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -25,6 +25,7 @@ from megatron.core import parallel_state from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset +from megatron.core.datasets.data_schedule import get_batch_on_this_rank_for_sequence_packing from megatron.core.enums import ModelType from megatron.core.models.gpt import GPTModel from megatron.core.rerun_state_machine import get_rerun_state_machine @@ -48,6 +49,7 @@ get_blend_and_blend_per_split, is_first_or_last_pipeline_stage, ) +from megatron.training.datasets.sft_dataset import SFTDataset, MockSFTDataset from model_provider import model_provider try: @@ -65,6 +67,14 @@ def get_batch(data_iterator, vp_stage: Optional[int] = None): """Generate a batch.""" args = get_args() config = core_transformer_config_from_args(args) + + if args.sequence_packing: + return get_batch_on_this_rank_for_sequence_packing( + data_iterator, + mtp_on_this_rank=mtp_on_this_rank(config, ignore_virtual=False, vp_stage=vp_stage), + vp_stage=vp_stage, + ) + # TODO: this is pretty hacky, find a better way if not is_first_or_last_pipeline_stage(vp_stage) and ( (not mtp_on_this_rank(config, ignore_virtual=False, vp_stage=vp_stage))): @@ -249,6 +259,8 @@ def core_gpt_dataset_config_from_args(args): "data_parallel_size": args.data_parallel_size, "sequence_parallel_size": args.tensor_model_parallel_size*args.sequence_parallel, "hybrid_context_parallel": args.hybrid_context_parallel, + "sft_mock_dataset_config_json":args.sft_mock_dataset_config_json, + "sequence_packing": args.sequence_packing, } # add FIM args to the config @@ -286,7 +298,10 @@ def train_valid_test_datasets_provider(train_val_test_num_samples, vp_stage=None config = core_gpt_dataset_config_from_args(args) if args.sft: - dataset_type = SFTDataset + if args.mock_data: + dataset_type = MockSFTDataset + else: + dataset_type = SFTDataset else: if args.mock_data: dataset_type = MockGPTDataset diff --git a/tests/unit_tests/test_sequence_packing.py b/tests/unit_tests/test_sequence_packing.py new file mode 100644 index 00000000000..bac7ea79db1 --- /dev/null +++ b/tests/unit_tests/test_sequence_packing.py @@ -0,0 +1,322 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +from types import SimpleNamespace + +import pytest +import torch + +from megatron.core import parallel_state +from megatron.core.datasets.data_schedule import get_batch_on_this_rank_for_sequence_packing +from megatron.training.global_vars import unset_global_variables +from tests.unit_tests.test_utilities import Utils + + +class MockVariableLengthSequencePackingDataIterator: + """ + Mock data iterator for testing get_batch_on_this_rank_for_sequence_packing. + + Generates variable-length (THD format) packed sequences with deterministic + data for verification across parallel ranks. + """ + + def __init__( + self, + total_seq_length: int, + sequence_lengths: list, + local_cp_size: int = None, + device: str = "cuda", + seed: int = 42, + ): + """ + Args: + total_seq_length: Total length of packed sequences + sequence_lengths: List of individual sequence lengths (variable-length). + If None, generates random variable lengths. + local_cp_size: Local CP size for hybrid context parallel + device: Device to create tensors on + seed: Random seed for reproducibility + """ + self.total_seq_length = total_seq_length + self.sequence_lengths = sequence_lengths + self.local_cp_size = local_cp_size + self.device = device + self.seed = seed + assert ( + sum(self.sequence_lengths) == total_seq_length + ), f"Sequence lengths sum {sum(self.sequence_lengths)} != total {total_seq_length}" + + def __iter__(self): + """Interface for the data iterator.""" + return self + + def __next__(self): + """Generate a mock batch with variable-length THD format.""" + dev = self.device + torch.manual_seed(self.seed) + torch.cuda.manual_seed(self.seed) + + tokens = torch.randint(0, 16384, (self.total_seq_length,), dtype=torch.int64, device=dev) + + # Create position_ids that reset for each sequence (THD format) + position_ids = [] + for seq_len in self.sequence_lengths: + position_ids.extend(range(seq_len)) + position_ids = torch.tensor(position_ids, dtype=torch.int64, device=dev) + + # Labels are tokens shifted by 1 for easy verification + labels = tokens + 1 + + # Loss mask: 1.0 for all positions except padding (none here) + loss_mask = torch.ones(self.total_seq_length, dtype=torch.float32, device=dev) + + # Create cu_seqlens for variable-length packed sequences + cu_seqlens = [0] + for seq_len in self.sequence_lengths: + cu_seqlens.append(cu_seqlens[-1] + seq_len) + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=dev) + cu_seqlens_padded = cu_seqlens.clone() + + max_seqlen = torch.tensor([max(self.sequence_lengths)], dtype=torch.int32, device=dev) + + batch = { + "tokens": tokens, + "position_ids": position_ids, + "labels": labels, + "loss_mask": loss_mask, + "cu_seqlens": cu_seqlens, + "cu_seqlens_padded": cu_seqlens_padded, + "max_seqlen": max_seqlen, + } + + if not ( + parallel_state.is_pipeline_first_stage(ignore_virtual=True) + or parallel_state.is_pipeline_last_stage(ignore_virtual=True) + ): + batch["tokens"] = None + batch["position_ids"] = None + batch["labels"] = None + batch["loss_mask"] = None + + if self.local_cp_size is not None: + batch["local_cp_size"] = torch.tensor( + [self.local_cp_size], dtype=torch.int32, device=dev + ) + + return batch + + +def _gather_tensor_from_tp_group(tensor): + """Gather tensors from all TP ranks for comparison.""" + assert tensor is not None, "Tensor should not be None" + tp_size = parallel_state.get_tensor_model_parallel_world_size() + gathered = [torch.zeros_like(tensor) for _ in range(tp_size)] + torch.distributed.all_gather( + gathered, tensor, group=parallel_state.get_tensor_model_parallel_group() + ) + return gathered + + +def _gather_tensor_from_all_ranks(tensor): + """Gather tensors from all PP ranks for comparison.""" + assert tensor is not None, "Tensor should not be None" + if type(tensor) is int: + tensor = torch.tensor(tensor, dtype=torch.int32, device=torch.cuda.current_device()) + gathered = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(gathered, tensor) + return gathered + + +@pytest.mark.parametrize( + ("tp", "pp", "cp", "hybrid_cp"), + [ + (1, 1, 1, False), # Basic case: no parallelism + (2, 1, 1, False), # Tensor parallel only + (1, 2, 1, False), # Pipeline parallel only + (2, 2, 1, False), # TP + PP + (1, 1, 2, False), # CP only + (2, 1, 2, False), # TP + CP + (1, 2, 2, False), # PP + CP + (1, 4, 1, False), # Has middle pp stage + (1, 1, 1, True), # Hybrid CP enabled (CP=1 with hybrid groups) + (2, 1, 1, True), # TP + Hybrid CP + ], +) +def test_get_batch_on_this_rank_for_sequence_packing(tp, pp, cp, hybrid_cp): + """ + Test get_batch_on_this_rank_for_sequence_packing function with variable-length THD format. + + This test verifies: + 1. TP ranks: All ranks within a TP group receive identical data after broadcast + 2. PP ranks: Middle PP ranks have the same packed_seq_params as first/last stages + 3. CP ranks: Data is correctly partitioned with proper shape and values + 4. Variable-length (THD) format: Different sequence lengths are handled correctly + """ + args = SimpleNamespace() + args.tensor_model_parallel_size = tp + args.pipeline_model_parallel_size = pp + args.context_parallel_size = cp + args.hybrid_context_parallel = hybrid_cp + args.virtual_pipeline_model_parallel_size = None + args.data_parallel_size = 8 // (tp * pp * cp) + args.seq_length = 8192 + + # Skip invalid configurations + if args.data_parallel_size < 1: + raise ValueError(f"Invalid config: tp={tp}, pp={pp}, cp={cp} exceeds world size 8") + + # Initialize model parallel + Utils.initialize_model_parallel( + tp, + pp, + None, + context_parallel_size=cp, + hybrid_context_parallel=hybrid_cp, + min_hybrid_context_parallel_size=1, + ) + + try: + # Create mock data iterator with variable-length sequences + # Only TP rank 0 needs the iterator; other TP ranks pass None + tp_rank = parallel_state.get_tensor_model_parallel_rank() + local_cp_size = 8 // (tp * pp) if hybrid_cp else None + + if tp_rank == 0: + # Use deterministic seed based on DP rank so same data within TP/PP/CP group + dp_rank = parallel_state.get_data_parallel_rank() + sequence_lengths = [1024, 2048, 512, 1536, 3072] + assert ( + sum(sequence_lengths) == args.seq_length + ), f"Sequence lengths sum {sum(sequence_lengths)} != total {args.seq_length}" + data_iterator = iter( + MockVariableLengthSequencePackingDataIterator( + total_seq_length=args.seq_length, + sequence_lengths=sequence_lengths, # Variable lengths, sum=8192 + local_cp_size=local_cp_size, + seed=42 + dp_rank, # Same seed within PP/CP group + ) + ) + else: + # Non-TP-rank-0 ranks don't need the iterator + data_iterator = None + + # Call the function under test + result = get_batch_on_this_rank_for_sequence_packing( + data_iterator=data_iterator, + mtp_on_this_rank=False, + vp_stage=None, + hybrid_context_parallel=hybrid_cp, + ) + + # Unpack the result + tokens, labels, loss_mask, attention_mask, position_ids, packed_seq_params = result + + # Get parallel state info + tp_rank = parallel_state.get_tensor_model_parallel_rank() + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + cp_rank = parallel_state.get_context_parallel_rank() + is_first_stage = parallel_state.is_pipeline_first_stage(ignore_virtual=True) + is_last_stage = parallel_state.is_pipeline_last_stage(ignore_virtual=True) + is_first_or_last = is_first_stage or is_last_stage + + # ===================================================================== + # TEST 1: Verify data based on pipeline stage + # ===================================================================== + if is_first_stage: + assert tokens is not None, "First stage should have tokens" + assert position_ids is not None, "First stage should have position_ids" + assert tokens.dim() == 2, "Tokens should be 2D (batch, seq)" + assert position_ids.dim() == 2, "Position IDs should be 2D (batch, seq)" + assert tokens.size(0) == 1, "batch should be 1 in THD format" + assert position_ids.size(0) == 1, "batch should be 1 in THD format" + else: + assert tokens is None, "Non-first stage should not have tokens" + assert position_ids is None, "Non-first stage should not have position_ids" + + if is_last_stage: + assert labels is not None, "Last stage should have labels" + assert loss_mask is not None, "Last stage should have loss_mask" + assert labels.dim() == 2, "Labels should be 2D (batch, seq)" + assert loss_mask.dim() == 2, "Loss mask should be 2D (batch, seq)" + assert labels.size(0) == 1, "batch should be 1 in THD format" + assert loss_mask.size(0) == 1, "batch should be 1 in THD format" + else: + assert labels is None, "Non-last stage should not have labels" + assert loss_mask is None, "Non-last stage should not have loss_mask" + + # ===================================================================== + # TEST 2: Verify all ranks have consistent packed_seq_params + # ===================================================================== + assert packed_seq_params is not None + assert packed_seq_params.qkv_format == "thd" + if hybrid_cp: + assert packed_seq_params.local_cp_size is not None + assert packed_seq_params.cp_group is not None + + test_keys = [ + "cu_seqlens_q", + "cu_seqlens_q_padded", + "max_seqlen_q", + "cu_seqlens_kv", + "cu_seqlens_kv_padded", + "max_seqlen_kv", + ] + if hybrid_cp: + test_keys.append("local_cp_size") + for key in test_keys: + tensor = getattr(packed_seq_params, key) + assert tensor is not None + gathered_tensor = _gather_tensor_from_all_ranks(tensor) + for i in range(1, len(gathered_tensor)): + assert torch.equal( + gathered_tensor[0], gathered_tensor[i] + ), f"Rank 0 and rank {i} have different {key}" + + # ===================================================================== + # TEST 3: Verify TP ranks receive identical data after broadcast + # ===================================================================== + if tp > 1: + test_tensors = [] + if is_first_stage: + test_tensors.extend([tokens, position_ids]) + if is_last_stage: + test_tensors.extend([labels, loss_mask]) + + for tensor in test_tensors: + gathered_tensors = _gather_tensor_from_tp_group(tensor) + for i in range(1, tp): + assert torch.equal( + gathered_tensors[0], gathered_tensors[i] + ), f"TP rank 0 and rank {i} have different data" + + # ===================================================================== + # TEST 4: Verify CP partitioning + # ===================================================================== + if cp > 1 or hybrid_cp: + if hybrid_cp: + assert packed_seq_params.local_cp_size is not None + cp_size = packed_seq_params.local_cp_size + assert packed_seq_params.cp_group == ( + parallel_state.get_hybrid_data_context_parallel_groups(group_size=cp_size) + ) + else: + cp_size = cp + + # With CP, the sequence should be partitioned + expected_seq_len = args.seq_length // cp_size + + if is_first_stage: + actual_seq_len = tokens.shape[1] + assert ( + actual_seq_len == expected_seq_len + ), f"CP partitioned tokens have wrong shape: {actual_seq_len} != {expected_seq_len}" + + # Verify labels only if all CP ranks are at last stage + if is_last_stage: + actual_seq_len = labels.shape[1] + assert ( + actual_seq_len == expected_seq_len + ), f"CP partitioned labels have wrong shape: {actual_seq_len} != {expected_seq_len}" + + finally: + Utils.destroy_model_parallel() + unset_global_variables() From cd7d04542d11b579d99157d51981bb0293c333dc Mon Sep 17 00:00:00 2001 From: tailaim Date: Mon, 9 Feb 2026 21:36:01 -0800 Subject: [PATCH 2/3] refactor according to comments and the new sftdataset Signed-off-by: tailaim --- megatron/core/datasets/data_schedule.py | 581 +++--------------- .../core/datasets/data_scheduler_utils.py | 492 +++++++++++++++ megatron/core/datasets/gpt_dataset.py | 6 +- .../core/extensions/transformer_engine.py | 21 + megatron/core/model_parallel_config.py | 9 +- .../text/libraries/sft_tokenizer.py | 5 - .../core/transformer/transformer_config.py | 18 +- megatron/training/arguments.py | 9 +- megatron/training/datasets/sft_dataset.py | 161 ++--- megatron/training/training.py | 6 +- pretrain_gpt.py | 5 +- tests/unit_tests/test_sequence_packing.py | 251 ++++++-- 12 files changed, 916 insertions(+), 648 deletions(-) create mode 100644 megatron/core/datasets/data_scheduler_utils.py diff --git a/megatron/core/datasets/data_schedule.py b/megatron/core/datasets/data_schedule.py index 7a1dd9e587d..76e53c6fe37 100644 --- a/megatron/core/datasets/data_schedule.py +++ b/megatron/core/datasets/data_schedule.py @@ -1,27 +1,24 @@ # Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved. import enum -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Type -import numpy as np import torch from megatron.core import parallel_state +from megatron.core.datasets.data_scheduler_utils import ( + broadcast_scalars, + broadcast_tensor, + broadcast_to_pp_group, + build_packed_microbatches, + create_data_iterator, + get_batch_and_global_seqlens, + reroute_samples_to_dcp_ranks, +) +from megatron.core.extensions.transformer_engine import get_thd_partitioned_indices from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.pipeline_parallel.hybrid_cp_schedule import BalancedCPScheduler from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.rerun_state_machine import RerunDataIterator -from megatron.core.utils import is_te_min_version - -try: - # Register the TE CUDA kernels - import transformer_engine # pylint: disable=unused-import - - # Alias the PyTorch wrapper so we can call tex.* APIs - import transformer_engine_torch as tex -except ImportError: - # TE isn’t installed or the torch wrapper is missing - tex = None class HybridCPDataLoaderWrapper: @@ -72,7 +69,6 @@ def get_global_seqlens(self, subsample_seqlens: torch.Tensor) -> List[int]: Gathers the sequence lengths of all subsamples from all DP ranks. Each DP rank loads the same number of microbatches but each microbatch may have a different number of subsamples. - We find the number of subsamples each rank holds and then gather the sequence lengths of all subsamples from all ranks. """ @@ -316,18 +312,6 @@ def __next__(self) -> Any: return samples_this_rank_with_id, sample_id_groups -class PackingScheduler(enum.Enum): - """Enum for supported sequence packing algorithms.""" - - DEFAULT_SEQUENCE_PACKING = "default_sequence_packing" - - -def _broadcast_tensor(item, src_rank, group) -> None: - """Broadcast a tensor from src_rank to all ranks in the group.""" - if item is not None: - torch.distributed.broadcast(item, src_rank, group=group) - - class BaseScheduler: """Base class for sequence packing schedulers.""" @@ -338,6 +322,14 @@ def __init__( dp_size: int, microbatch_group_size_per_vp_stage: Optional[int], ): + """ + Args: + max_seqlen_per_dp_cp_rank: The maximum sequence length per DPxCP rank. + cp_size: The context parallel size. + dp_size: The data parallel size. + microbatch_group_size_per_vp_stage: The microbatch group size per virtual + pipeline stage, only used when enabling VPP, otherwise None. + """ self.max_seqlen_per_dp_cp_rank = max_seqlen_per_dp_cp_rank self.cp_size = cp_size self.dp_size = dp_size @@ -362,446 +354,29 @@ def run( dev, config, ): - """Run the scheduler and return the new data_iterator.""" - raise NotImplementedError - - @staticmethod - def _get_global_seqlens(subsample_seqlens: torch.Tensor, dp_group) -> List[int]: - """ - Gathers the sequence lengths of all subsamples from all DP ranks. - - Each DP rank has the same number of subsamples (num_microbatches), - so we can directly all_gather without padding. - """ - dp_size = dp_group.size() - num_local_subsamples = subsample_seqlens.shape[0] - - # Gather the subsample_seqlens from all ranks - seqlens_gathered = [torch.empty_like(subsample_seqlens) for _ in range(dp_size)] - torch.distributed.all_gather(seqlens_gathered, subsample_seqlens, group=dp_group) - - seqlens_gathered = torch.cat(seqlens_gathered, dim=0) - seqlens_gathered = seqlens_gathered.cpu().tolist() - - # Calculate the offsets to assign unique global ID to each subsample. - # Since each rank has the same number of subsamples, offsets are evenly spaced. - offsets = torch.arange( - 0, dp_size * num_local_subsamples, num_local_subsamples, dtype=torch.int32 - ) - - return seqlens_gathered, offsets - - @staticmethod - def _get_global_id_seqlens(num_local_subsamples, offsets, seqlens_gathered, dp_group): """ - Calculates the global ID for each subsample. - - We assign a unique global ID to each subsample. - - Returns: - global_id_seqlens: list of (global_id, seqlen) tuples for scheduling. - global_ids_this_rank: list of global IDs locally present on this rank. - """ - dp_rank = dp_group.rank() - global_ids = torch.arange(len(seqlens_gathered), dtype=torch.int32).cuda() - # Create a list of (global_id, seqlen) tuples for scheduling - global_id_seqlens = [(i, seqlens_gathered[i]) for i in range(len(global_ids))] - # Get the global IDs locally present on this rank - global_ids_this_rank = global_ids[ - offsets[dp_rank] : offsets[dp_rank] + num_local_subsamples - ] - - return global_id_seqlens, global_ids_this_rank - - @staticmethod - def _broadcast_to_pp_group( - new_samples, - num_micro_batches, - seqlen_sum_this_global_batch, - seqlen_squared_sum_this_global_batch, - pp_group, - dev, - ): - """ - Broadcast num_micro_batches, seqlen_sum_this_global_batch, - seqlen_squared_sum_this_global_batch and metadata to middle PP stages. - """ - - pp_src_rank = torch.distributed.get_process_group_ranks(pp_group)[0] - - if pp_group.size() > 2: - if pp_group.rank() == 0: - tensor_list = [ - torch.tensor( - [ - num_micro_batches, - seqlen_sum_this_global_batch, - seqlen_squared_sum_this_global_batch, - ], - dtype=torch.float32, - ).cuda() - ] - for sample in new_samples: - tensor_list.append(sample["max_seqlen"].unsqueeze(0)) - for sample in new_samples: - tensor_list.append(sample["cu_seqlens"]) - tensor_list.append(sample["cu_seqlens_padded"]) - info_to_broadcast = torch.cat(tensor_list, dim=0).to( - device=dev, dtype=torch.float32 - ) - info_length_tensor = torch.tensor( - info_to_broadcast.shape[0], dtype=torch.int32 - ).cuda() - _broadcast_tensor(info_length_tensor, pp_src_rank, pp_group) - _broadcast_tensor(info_to_broadcast, pp_src_rank, pp_group) - else: - info_length_tensor = torch.tensor(0, dtype=torch.int32).cuda() - _broadcast_tensor(info_length_tensor, pp_src_rank, pp_group) - info_to_broadcast = torch.empty( - info_length_tensor.item(), dtype=torch.float32 - ).cuda() - _broadcast_tensor(info_to_broadcast, pp_src_rank, pp_group) - if pp_group.rank() != pp_group.size() - 1: - # middle PP stages receive the broadcasted info and unpack it - info_numpy = info_to_broadcast.cpu().numpy() - num_micro_batches = int(info_numpy[0]) - seqlen_sum_this_global_batch = info_numpy[1] - seqlen_squared_sum_this_global_batch = info_numpy[2] - max_seqlens = info_to_broadcast[3 : 3 + num_micro_batches] - cu_seqlens_list = [] - cu_seqlens_padded_list = [] - indices = np.where(info_numpy == 0)[0] - for i in range(num_micro_batches): - cu_seqlens_list.append( - info_to_broadcast[indices[i * 2] : indices[i * 2 + 1]] - ) - if i == num_micro_batches - 1: - cu_seqlens_padded_list.append(info_to_broadcast[indices[i * 2 + 1] :]) - else: - cu_seqlens_padded_list.append( - info_to_broadcast[indices[i * 2 + 1] : indices[i * 2 + 2]] - ) - - new_samples = [] - for i in range(num_micro_batches): - new_sample = {} - new_sample["max_seqlen"] = max_seqlens[i].to(torch.int32) - new_sample["cu_seqlens"] = cu_seqlens_list[i].to(torch.int32) - new_sample["cu_seqlens_padded"] = cu_seqlens_padded_list[i].to(torch.int32) - new_samples.append(new_sample) - - return ( - new_samples, - num_micro_batches, - seqlen_sum_this_global_batch, - seqlen_squared_sum_this_global_batch, - ) - - @staticmethod - def _broadcast_scalars(values: List, group, dev, dtype=torch.float32) -> List: - """ - Broadcast scalar values from rank 0 to all ranks in the group. + Run the scheduler and return the new data_iterator. - Args: - values: List of scalar values to broadcast (only used on rank 0). - group: The process group to broadcast within. - dev: The device to use for the tensor. - dtype: The data type for the tensor. - - Returns: - List of broadcasted values. - """ - if group.size() <= 1: - return values - - src_rank = torch.distributed.get_process_group_ranks(group)[0] - num_values = len(values) - - if group.rank() == 0: - info_to_broadcast = torch.tensor(values, dtype=dtype, device=dev) - else: - info_to_broadcast = torch.zeros(num_values, dtype=dtype, device=dev) - - _broadcast_tensor(info_to_broadcast, src_rank, group) - - if group.rank() != 0: - values = info_to_broadcast.cpu().tolist() - - return values - - @staticmethod - def _create_data_iterator(new_samples, pp_group, tp_group, config): - """Handle virtual pipeline parallelism.""" - if ( - config.virtual_pipeline_model_parallel_size is not None - and config.virtual_pipeline_model_parallel_size > 1 - ): - vpp_size = config.virtual_pipeline_model_parallel_size - if tp_group.rank() == 0: - if pp_group.rank() == 0 or pp_group.rank() == pp_group.size() - 1: - new_samples_for_other_ppstage = [] - for sample in new_samples: - new_sample_for_other_ppstage = {} - new_sample_for_other_ppstage["max_seqlen"] = sample["max_seqlen"] - new_sample_for_other_ppstage["cu_seqlens"] = sample["cu_seqlens"] - new_sample_for_other_ppstage["cu_seqlens_padded"] = sample[ - "cu_seqlens_padded" - ] - new_samples_for_other_ppstage.append(new_sample_for_other_ppstage) - if pp_group.rank() == 0: - new_data_iterator = [RerunDataIterator(iter(new_samples))] + [ - RerunDataIterator(iter(new_samples_for_other_ppstage)) - for _ in range(vpp_size - 1) - ] - else: - new_data_iterator = [ - RerunDataIterator(iter(new_samples_for_other_ppstage)) - for _ in range(vpp_size - 1) - ] + [RerunDataIterator(iter(new_samples))] - else: - new_data_iterator = [ - RerunDataIterator(iter(new_samples)) for _ in range(vpp_size) - ] - else: - new_data_iterator = [None for _ in range(vpp_size)] - else: - new_data_iterator = ( - RerunDataIterator(iter(new_samples)) if tp_group.rank() == 0 else None - ) - - return new_data_iterator - - @staticmethod - def _reroute_samples_to_dcp_ranks( - batch, - global_ids_this_rank, - global_id_seqlens, - sample_id_groups, - offsets, - dp_group, - tp_group, - dp_cp_group, - total_dcp_gpus, - ): - """ - Reroutes the sub-samples to the correct rank after scheduling. - - For each key in the batch dict, we perform an all-to-all communication - to transfer the data to the correct ranks. - """ - - def _gid_to_src_rank(gid: int) -> int: - dp_src_rank = torch.bucketize(gid, offsets[1:] - 1) - dcp_rank = ( - torch.distributed.get_process_group_ranks(dp_group)[dp_src_rank] // tp_group.size() - ) % dp_cp_group.size() - return dcp_rank - - gid2local_id = {int(gid): i for i, gid in enumerate(global_ids_this_rank)} - dcp_rank = dp_cp_group.rank() - dp_ranks = torch.distributed.get_process_group_ranks(dp_group) - dp_ranks = [(r // tp_group.size()) % dp_cp_group.size() for r in dp_ranks] - - data_keys = batch[0].keys() - - # Create the send plan - combined_sample_id_groups: List[List[int]] = [[] for _ in range(total_dcp_gpus)] - for d in range(total_dcp_gpus): - for sample_id_group in sample_id_groups: - combined_sample_id_groups[d].extend(sample_id_group[d]) - for dest_rank in range(total_dcp_gpus): - combined_sample_id_groups[dest_rank].sort() - - send_ids_sorted = [ - gid - for d in dp_ranks - for gid in combined_sample_id_groups[d] - if gid in global_ids_this_rank - ] - - send_num_split = [0] * total_dcp_gpus - send_lens_split = [0] * total_dcp_gpus - for dest_rank in range(total_dcp_gpus): - if dest_rank in dp_ranks: - send_seq_lens = [ - global_id_seqlens[gid][1] - for gid in combined_sample_id_groups[dest_rank] - if gid in global_ids_this_rank - ] - send_num_split[dest_rank] = len(send_seq_lens) - send_lens_split[dest_rank] = sum(send_seq_lens) - else: - send_lens_split[dest_rank] = 0 - - # Create the recv plan - recv_sample_id_groups = [[] for _ in range(total_dcp_gpus)] - for gid in combined_sample_id_groups[dcp_rank]: - src_rank = _gid_to_src_rank(gid) - recv_sample_id_groups[src_rank].append(gid) - - recv_lens_split = [0] * total_dcp_gpus - for src_rank in range(total_dcp_gpus): - recv_lens_split[src_rank] = sum( - [global_id_seqlens[gid][1] for gid in recv_sample_id_groups[src_rank]] - ) - - recv_ids_sorted = [gid for d in range(total_dcp_gpus) for gid in recv_sample_id_groups[d]] - recv_counts = [len(recv_sample_id_groups[d]) for d in range(total_dcp_gpus)] - - recv_samples = [{k: None for k in data_keys} for _ in range(sum(recv_counts))] - - def _pack_sample_by_key(key: str) -> torch.Tensor: - flattened_tensors = [] - for gid in send_ids_sorted: - t = batch[gid2local_id[gid]][key].to(torch.cuda.current_device(), non_blocking=True) - flattened_tensors.append(t.reshape(-1)) - return ( - torch.cat(flattened_tensors, dim=0) - if flattened_tensors - else torch.empty(1, device=torch.cuda.current_device(), dtype=batch[0][key].dtype) - ) - - def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor): - cursor = 0 - for i, gid in enumerate(recv_ids_sorted): - sample_len = ( - 1 - if key in ["original_seq_len", "padded_seq_len"] - else global_id_seqlens[gid][1] - ) - recv_samples[i][key] = recv_tensor[cursor : cursor + sample_len] - cursor += sample_len - - for key in data_keys: - output_split_sizes, input_split_sizes = ( - (recv_counts, send_num_split) - if key in ["original_seq_len", "padded_seq_len"] - else (recv_lens_split, send_lens_split) - ) - send_tensor = _pack_sample_by_key(key) - recv_tensor_size = sum(output_split_sizes) - recv_tensor = torch.empty( - recv_tensor_size, device=torch.cuda.current_device(), dtype=send_tensor.dtype - ) - torch.distributed.all_to_all_single( - output=recv_tensor, - input=send_tensor, - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=dp_cp_group, - ) - _unpack_sample_by_key(key, recv_tensor) - - recv_sample_with_id = { - recv_id: recv_samples[i] for i, recv_id in enumerate(recv_ids_sorted) - } - return recv_sample_with_id - - @staticmethod - def _pack_sequences( - samples: List, - padded_lengths: torch.Tensor, - original_lengths: torch.Tensor, - dev: torch.device, - ) -> Dict[str, torch.Tensor]: - """Pack multiple samples into a single packed sample.""" - - def _pack_tensors(tensors): - return torch.cat([t.reshape(-1) for t in tensors], dim=0) - - tokens = _pack_tensors([sample["tokens"] for sample in samples]) - labels = _pack_tensors([sample["labels"] for sample in samples]) - loss_mask = _pack_tensors([sample["loss_mask"] for sample in samples]) - position_ids = _pack_tensors([sample["position_ids"] for sample in samples]) - - new_sample = {} - new_sample["tokens"] = tokens - new_sample["labels"] = labels - new_sample["loss_mask"] = loss_mask - new_sample["position_ids"] = position_ids - - padded_lengths = padded_lengths.to( - device=dev, dtype=torch.int32, non_blocking=True - ).reshape(-1) - cu_seqlens_padded = torch.empty(padded_lengths.numel() + 1, device=dev, dtype=torch.int32) - cu_seqlens_padded[0] = 0 - cu_seqlens_padded[1:] = torch.cumsum(padded_lengths, dim=0) - max_seqlen = torch.max(padded_lengths).to(dtype=torch.int32) - - new_sample["cu_seqlens_padded"] = cu_seqlens_padded - new_sample["max_seqlen"] = max_seqlen - - original_lengths = original_lengths.to( - device=dev, dtype=torch.int32, non_blocking=True - ).reshape(-1) - cu_seqlens = torch.empty(original_lengths.numel() + 1, device=dev, dtype=torch.int32) - cu_seqlens[0] = 0 - cu_seqlens[1:] = torch.cumsum(original_lengths, dim=0).reshape(-1) - new_sample["cu_seqlens"] = cu_seqlens - - return new_sample - - @staticmethod - def _build_packed_microbatches( - grouped_samples: List[List[Dict[str, torch.Tensor]]], dev: torch.device - ) -> List[Dict[str, torch.Tensor]]: - """Build packed samples for each microbatch.""" - num_micro_batches = len(grouped_samples) - seg_starts: List[int] = [0] - original_lens_tensors = [] - padded_lens_tensors = [] - - for i in range(num_micro_batches): - samples = grouped_samples[i] - seg_starts.append(seg_starts[-1] + len(samples)) - original_lens_tensors.extend([s["original_seq_len"].reshape(-1) for s in samples]) - padded_lens_tensors.extend([s["padded_seq_len"].reshape(-1) for s in samples]) - - padded_lens_all_gpu = torch.cat(padded_lens_tensors, dim=0).to(dtype=torch.int32) - original_lens_all_gpu = torch.cat(original_lens_tensors, dim=0).to(dtype=torch.int32) - - new_samples: List[Dict[str, torch.Tensor]] = [] - for i in range(num_micro_batches): - samples = grouped_samples[i] - lp = padded_lens_all_gpu[seg_starts[i] : seg_starts[i + 1]] - lo = original_lens_all_gpu[seg_starts[i] : seg_starts[i + 1]] - new_sample = BaseScheduler._pack_sequences(samples, lp, lo, dev) - new_samples.append(new_sample) - - return new_samples - - @staticmethod - def _get_batch_and_global_seqlens(data_iterator, num_microbatches, dp_group): - """ - Get the batch and global sequence lengths. - Each DP rank loads the same number of sequences, so we need to gather the sequence - lengths from all ranks then we can schedule the sequences into groups. Args: data_iterator: The data iterator. - num_microbatches: The number of microbatches. - dp_group: The data parallel group. + num_microbatches: The number of microbatches to fetch. + dp_group: Data parallel process group. + tp_group: Tensor parallel process group. + pp_group: Pipeline parallel process group. + dp_cp_group: Data parallel + context parallel process group. + dev: CUDA device. + config: Model parallel config. Returns: - batch: The batch. - global_id_seqlens: The global sequence lengths. - global_ids_this_rank: The global IDs locally present on this rank. + new_data_iterator: The new data iterator (or list for VPP). + num_micro_batches: Number of micro batches after scheduling. + seqlen_sum_this_global_batch: Total tokens for FLOPs calculation. + seqlen_squared_sum_this_global_batch: Sum of squared seqlens for FLOPs. """ - batch = [next(data_iterator) for _ in range(num_microbatches)] - subsample_seqlens = [] - for sample in batch: - subsample_seqlens.extend([sample["tokens"].numel()]) - subsample_seqlens = torch.tensor(subsample_seqlens, dtype=torch.int32).cuda() - - seqlens_gathered, offsets = BaseScheduler._get_global_seqlens(subsample_seqlens, dp_group) - - global_id_seqlens, global_ids_this_rank = BaseScheduler._get_global_id_seqlens( - subsample_seqlens.shape[0], offsets, seqlens_gathered, dp_group - ) - - return batch, global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered + raise NotImplementedError -class DefaultSequencePackingScheduler(BaseScheduler): +class DpBalancedScheduler(BaseScheduler): """Packs sequences in their original order until reaching the max limit of sequence length.""" def __init__(self, *args, **kwargs): @@ -942,12 +517,14 @@ def run( # Step 1: Fetch batches and gather global sequence lengths batch, global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered = ( - self._get_batch_and_global_seqlens(data_iterator, num_microbatches, dp_group) + get_batch_and_global_seqlens(data_iterator, num_microbatches, dp_group) ) # Step 2: Check required sample keys for key in self.get_require_sample_keys(): - assert key in batch[0], f"Batch missing required key {key}" + assert ( + key in batch[0] + ), f"Batch missing required key {key}, provided keys: {batch[0].keys()}" # Step 3: Schedule samples into groups sample_id_groups = self.get_groups_and_subsamples(global_id_seqlens) @@ -963,7 +540,7 @@ def run( ) # Step 4: Reroute samples to DCP ranks - samples_this_rank_with_id = self._reroute_samples_to_dcp_ranks( + samples_this_rank_with_id = reroute_samples_to_dcp_ranks( batch, global_ids_this_rank, global_id_seqlens, @@ -987,7 +564,7 @@ def run( ] # Step 5: Build packed microbatches - new_samples = self._build_packed_microbatches(grouped_samples, dev) + new_samples = build_packed_microbatches(grouped_samples, dev) # Step 6: Calculate FLOPs info seqlen_sum_this_global_batch = float(sum(seqlens_gathered)) @@ -1009,7 +586,7 @@ def run( num_micro_batches, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch, - ) = self._broadcast_to_pp_group( + ) = broadcast_to_pp_group( new_samples, num_micro_batches, seqlen_sum_this_global_batch, @@ -1020,7 +597,7 @@ def run( # Step 8: Broadcast to TP group (for non-TP-0 ranks) (num_micro_batches, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch) = ( - self._broadcast_scalars( + broadcast_scalars( [ num_micro_batches, seqlen_sum_this_global_batch, @@ -1033,7 +610,7 @@ def run( num_micro_batches = int(num_micro_batches) # Step 9: create data_iterator and handle VPP if enabled - new_data_iterator = self._create_data_iterator(new_samples, pp_group, tp_group, config) + new_data_iterator = create_data_iterator(new_samples, pp_group, tp_group, config) return ( new_data_iterator, @@ -1043,7 +620,18 @@ def run( ) -def wrap_dataloader( +class PackingScheduler(enum.Enum): + """Enum for supported sequence packing algorithms.""" + + DP_BALANCED = "dp_balanced" + + +scheduler_map: Dict[PackingScheduler, Type[BaseScheduler]] = { + PackingScheduler.DP_BALANCED: DpBalancedScheduler +} + + +def wrap_data_iterator( data_iterator, config, num_microbatches, pg_collection: Optional[ProcessGroupCollection] = None ): """ @@ -1057,10 +645,6 @@ def wrap_dataloader( pg_collection: The process group collection. """ - scheduler_map: Dict[PackingScheduler, Type[BaseScheduler]] = { - PackingScheduler.DEFAULT_SEQUENCE_PACKING: DefaultSequencePackingScheduler - } - if pg_collection is None: dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True) dp_group = parallel_state.get_data_parallel_group() @@ -1116,7 +700,11 @@ def wrap_dataloader( def get_batch_on_this_rank_for_sequence_packing( - data_iterator, mtp_on_this_rank: bool = False, vp_stage: Optional[int] = None + data_iterator, + vpp_size: Optional[int] = None, + mtp_on_this_rank: bool = False, + vp_stage: Optional[int] = None, + pg_collection: Optional[ProcessGroupCollection] = None, ): """ Get a batch of data for sequence packing. @@ -1128,16 +716,23 @@ def get_batch_on_this_rank_for_sequence_packing( tuple of (tokens, labels, loss_mask, attention_mask, position_ids, packed_seq_params) """ - tp_src_rank = parallel_state.get_tensor_model_parallel_src_rank() - tp_group = parallel_state.get_tensor_model_parallel_group() + if pg_collection is None: + tp_group = parallel_state.get_tensor_model_parallel_group() + pp_group = parallel_state.get_pipeline_model_parallel_group() + cp_group = parallel_state.get_context_parallel_group() + else: + tp_group = pg_collection.tp + pp_group = pg_collection.pp + cp_group = pg_collection.cp - is_tp_rank_0 = parallel_state.get_tensor_model_parallel_rank() == 0 - is_first_stage = parallel_state.is_pipeline_first_stage( - ignore_virtual=vp_stage is None, vp_stage=vp_stage - ) - is_last_stage = parallel_state.is_pipeline_last_stage( - ignore_virtual=vp_stage is None, vp_stage=vp_stage + tp_src_rank = torch.distributed.get_process_group_ranks(tp_group)[0] + + is_tp_rank_0 = tp_group.rank() == 0 + is_first_stage = pp_group.rank() == 0 and (vp_stage is None or vp_stage == 0) + is_last_stage = pp_group.rank() == pp_group.size() - 1 and ( + vp_stage is None or vp_stage == vpp_size - 1 ) + is_first_or_last_stage = is_first_stage or is_last_stage dev = torch.cuda.current_device() @@ -1163,20 +758,16 @@ def get_batch_on_this_rank_for_sequence_packing( # Partition tokens, position_ids, labels, loss_mask for context parallel, currently only # TP rank 0 and the first/last PP stage rank has these data. if is_tp_rank_0 and is_first_or_last_stage: - cp_size = parallel_state.get_context_parallel_world_size() - cp_rank = parallel_state.get_context_parallel_rank() + cp_size = cp_group.size() + cp_rank = cp_group.rank() # If cp_size == 1, no need to do further processing. if cp_size > 1: - assert tex is not None and is_te_min_version("1.10.0"), ( - "Please update Transformer Engine to >= 1.10 to use " - "Context Parallel with THD format data" - ) total_tokens = batch['tokens'].size(0) # Transformer Engine has a bug of cu_seqlens, we must treat cu_seqlens_padded as # cu_seqlens to get the correct result. # TODO: Revert this workaround once TE fixes the issue. cu_seqlens = batch["cu_seqlens_padded"] - index = tex.thd_get_partitioned_indices(cu_seqlens, total_tokens, cp_size, cp_rank) + index = get_thd_partitioned_indices(cu_seqlens, total_tokens, cp_size, cp_rank) for key in ['tokens', 'position_ids', 'labels', 'loss_mask']: batch[key] = batch[key].index_select(0, index) @@ -1186,7 +777,7 @@ def get_batch_on_this_rank_for_sequence_packing( cu_seqlen_size = torch.tensor(batch['cu_seqlens'].size(0), dtype=torch.int32, device=dev) else: cu_seqlen_size = torch.empty(1, dtype=torch.int32, device=dev) - _broadcast_tensor(cu_seqlen_size, tp_src_rank, tp_group) + broadcast_tensor(cu_seqlen_size, tp_src_rank, tp_group) cu_seqlen_size = cu_seqlen_size.item() # Broadcast total_tokens because we need it to create placeholder for tokens, position_ids, @@ -1196,7 +787,7 @@ def get_batch_on_this_rank_for_sequence_packing( total_tokens = torch.tensor(batch['tokens'].size(0), dtype=torch.int32, device=dev) else: total_tokens = torch.empty(1, dtype=torch.int32, device=dev) - _broadcast_tensor(total_tokens, tp_src_rank, tp_group) + broadcast_tensor(total_tokens, tp_src_rank, tp_group) total_tokens = total_tokens.item() # Step1: Prepare "tokens", "position_ids" on all ranks. @@ -1246,13 +837,13 @@ def get_batch_on_this_rank_for_sequence_packing( batch['max_seqlen'] = torch.empty(1, dtype=torch.int32, device=dev) # Broadcast batch inside TP group. - _broadcast_tensor(batch['tokens'], tp_src_rank, tp_group) - _broadcast_tensor(batch['position_ids'], tp_src_rank, tp_group) - _broadcast_tensor(batch['labels'], tp_src_rank, tp_group) - _broadcast_tensor(batch['loss_mask'], tp_src_rank, tp_group) - _broadcast_tensor(batch['cu_seqlens'], tp_src_rank, tp_group) - _broadcast_tensor(batch['cu_seqlens_padded'], tp_src_rank, tp_group) - _broadcast_tensor(batch['max_seqlen'], tp_src_rank, tp_group) + broadcast_tensor(batch['tokens'], tp_src_rank, tp_group) + broadcast_tensor(batch['position_ids'], tp_src_rank, tp_group) + broadcast_tensor(batch['labels'], tp_src_rank, tp_group) + broadcast_tensor(batch['loss_mask'], tp_src_rank, tp_group) + broadcast_tensor(batch['cu_seqlens'], tp_src_rank, tp_group) + broadcast_tensor(batch['cu_seqlens_padded'], tp_src_rank, tp_group) + broadcast_tensor(batch['max_seqlen'], tp_src_rank, tp_group) # Extract the data from batch after broadcasting. tokens = batch['tokens'] diff --git a/megatron/core/datasets/data_scheduler_utils.py b/megatron/core/datasets/data_scheduler_utils.py new file mode 100644 index 00000000000..2a1e2a6528d --- /dev/null +++ b/megatron/core/datasets/data_scheduler_utils.py @@ -0,0 +1,492 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved. + +from typing import Dict, List + +import numpy as np +import torch + +from megatron.core.rerun_state_machine import RerunDataIterator + + +def _unpack_batch(batch): + """ + Unpacks the packed samples into a list of sub-samples. + Since each sub-sample may be routed to different DPxCP ranks, + we unpack the sample here to avoid unnecessarily transferring + the entire packed sample. + """ + batch_unpacked = [] + dev = batch[0]["tokens"].device + original_seq_lens = [] + padded_seq_lens = [] + for sample in batch: + for key in sample.keys(): + if len(sample[key].shape) == 2: + # squeeze the redundant batch dimension added by + # default collate_fn in pytorch dataloader + # we need a custom collate_fn for THD to avoid this + # current THD does not support micro_batch_size > 1 due to sft_dataset.py and + # data_loader in data_samples.py + sample[key] = sample[key].squeeze(0) + for sub_sample in range(sample["cu_seqlens"].shape[0] - 1): + sub_sample_dict = {} + start_idx = sample["cu_seqlens"][sub_sample] + end_idx = sample["cu_seqlens"][sub_sample + 1] + if end_idx - start_idx == 0: + continue + for key in ["tokens", "labels", "loss_mask", "position_ids"]: + sub_sample_dict[key] = sample[key][start_idx:end_idx] + # Since sft_dataset.py does not provide cu_seqlens_original, + # we assume original_seq_len equals padded_seq_len here. + # Ideally the dataset should define the pre-padding seq_len. + seq_len = (end_idx - start_idx).item() + original_seq_lens.append(seq_len) + padded_seq_lens.append(seq_len) + batch_unpacked.append(sub_sample_dict) + + # Single H2D transfer for all seq lens + original_seq_lens_cuda = torch.tensor(original_seq_lens, device=dev) + padded_seq_lens_cuda = torch.tensor(padded_seq_lens, device=dev) + for i, sub_sample_dict in enumerate(batch_unpacked): + sub_sample_dict["original_seq_len"] = original_seq_lens_cuda[i : i + 1] + sub_sample_dict["padded_seq_len"] = padded_seq_lens_cuda[i : i + 1] + + return batch_unpacked + + +def _get_global_seqlens_and_ids(subsample_seqlens: torch.Tensor, dp_group): + """ + Gathers the sequence lengths of all subsamples from all DP ranks and calculates global IDs. + """ + # Collect the number of subsamples from all ranks + num_local_subsamples = subsample_seqlens.shape[0] + local_len = torch.tensor([num_local_subsamples], dtype=torch.int32).cuda() + dp_subsample_count = [torch.zeros_like(local_len) for _ in range(dp_group.size())] + torch.distributed.all_gather(dp_subsample_count, local_len, group=dp_group) + + # Find the max number of subsamples across all ranks and pad subsample_seqlens to max length + dp_subsample_counts = torch.stack(dp_subsample_count, dim=0).cpu().view(-1) + max_sub_samples = int(dp_subsample_counts.max().item()) + + if num_local_subsamples < max_sub_samples: + subsample_seqlens_padded = torch.cat( + [ + subsample_seqlens, + torch.zeros(max_sub_samples - num_local_subsamples, dtype=torch.int32).cuda(), + ], + dim=0, + ) + else: + subsample_seqlens_padded = subsample_seqlens + + # Gather the subsample_seqlens from all ranks + seqlens_gathered = [torch.empty_like(subsample_seqlens_padded) for _ in range(dp_group.size())] + torch.distributed.all_gather(seqlens_gathered, subsample_seqlens_padded, group=dp_group) + + # Trim each seqlens_gathered to the length of the correct sample + for dp_rank, seqlen in enumerate(seqlens_gathered): + seqlens_gathered[dp_rank] = seqlen[: dp_subsample_counts[dp_rank]] + + seqlens_gathered = torch.cat(seqlens_gathered, dim=0) + seqlens_gathered = seqlens_gathered.cpu().tolist() + + # Calculate the offsets to assign unique global ID to each subsample. + csum = torch.cumsum(dp_subsample_counts, dim=0, dtype=torch.int32) + offsets = torch.cat([torch.zeros(1, dtype=torch.int32), csum], dim=0) + + # Calculate global ID for each subsample + dp_rank = dp_group.rank() + global_ids = torch.arange(len(seqlens_gathered), dtype=torch.int32).cuda() + + # Create a list of (global_id, seqlen) tuples for scheduling + global_id_seqlens = [(i, seqlens_gathered[i]) for i in range(len(global_ids))] + + # Get the global IDs locally present on this rank + start_idx = offsets[dp_rank] + end_idx = offsets[dp_rank + 1] + + global_ids_this_rank = global_ids[start_idx:end_idx] + + return global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered + + +def _pack_sequences( + samples: List, padded_lengths: torch.Tensor, original_lengths: torch.Tensor, dev: torch.device +) -> Dict[str, torch.Tensor]: + """Pack multiple samples into a single packed sample.""" + + def _pack_tensors(tensors): + return torch.cat([t.reshape(-1) for t in tensors], dim=0) + + tokens = _pack_tensors([sample["tokens"] for sample in samples]) + labels = _pack_tensors([sample["labels"] for sample in samples]) + loss_mask = _pack_tensors([sample["loss_mask"] for sample in samples]) + position_ids = _pack_tensors([sample["position_ids"] for sample in samples]) + + new_sample = {} + new_sample["tokens"] = tokens + new_sample["labels"] = labels + new_sample["loss_mask"] = loss_mask + new_sample["position_ids"] = position_ids + + padded_lengths = padded_lengths.to(device=dev, dtype=torch.int32, non_blocking=True).reshape(-1) + cu_seqlens_padded = torch.empty(padded_lengths.numel() + 1, device=dev, dtype=torch.int32) + cu_seqlens_padded[0] = 0 + cu_seqlens_padded[1:] = torch.cumsum(padded_lengths, dim=0) + max_seqlen = torch.max(padded_lengths).to(dtype=torch.int32) + + new_sample["cu_seqlens_padded"] = cu_seqlens_padded + new_sample["max_seqlen"] = max_seqlen + + original_lengths = original_lengths.to( + device=dev, dtype=torch.int32, non_blocking=True + ).reshape(-1) + cu_seqlens = torch.empty(original_lengths.numel() + 1, device=dev, dtype=torch.int32) + cu_seqlens[0] = 0 + cu_seqlens[1:] = torch.cumsum(original_lengths, dim=0).reshape(-1) + new_sample["cu_seqlens"] = cu_seqlens + + return new_sample + + +def broadcast_tensor(item, src_rank, group) -> None: + """Broadcast a tensor from src_rank to all ranks in the group.""" + if item is not None: + torch.distributed.broadcast(item, src_rank, group=group) + + +def broadcast_to_pp_group( + new_samples, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + pp_group, + dev, +): + """ + Broadcast num_micro_batches, seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch and metadata to middle PP stages. + Before this broadcast, the new_samples on middle PP stages are None, + after this broadcast, the new_samples on middle PP stages contain the metadata but + without tokens, labels, loss_mask, position_ids. + """ + + pp_src_rank = torch.distributed.get_process_group_ranks(pp_group)[0] + + if pp_group.size() > 2: + if pp_group.rank() == 0: + tensor_list = [ + torch.tensor( + [ + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ], + dtype=torch.float32, + ).cuda() + ] + for sample in new_samples: + tensor_list.append(sample["max_seqlen"].unsqueeze(0)) + for sample in new_samples: + tensor_list.append(sample["cu_seqlens"]) + tensor_list.append(sample["cu_seqlens_padded"]) + info_to_broadcast = torch.cat(tensor_list, dim=0).to(device=dev, dtype=torch.float32) + info_length_tensor = torch.tensor(info_to_broadcast.shape[0], dtype=torch.int32).cuda() + broadcast_tensor(info_length_tensor, pp_src_rank, pp_group) + broadcast_tensor(info_to_broadcast, pp_src_rank, pp_group) + else: + info_length_tensor = torch.tensor(0, dtype=torch.int32).cuda() + broadcast_tensor(info_length_tensor, pp_src_rank, pp_group) + info_to_broadcast = torch.empty(info_length_tensor.item(), dtype=torch.float32).cuda() + broadcast_tensor(info_to_broadcast, pp_src_rank, pp_group) + if pp_group.rank() != pp_group.size() - 1: + # middle PP stages receive the broadcasted info and unpack it + info_numpy = info_to_broadcast.cpu().numpy() + num_micro_batches = int(info_numpy[0]) + seqlen_sum_this_global_batch = info_numpy[1] + seqlen_squared_sum_this_global_batch = info_numpy[2] + max_seqlens = info_to_broadcast[3 : 3 + num_micro_batches] + cu_seqlens_list = [] + cu_seqlens_padded_list = [] + indices = np.where(info_numpy == 0)[0] + for i in range(num_micro_batches): + cu_seqlens_list.append(info_to_broadcast[indices[i * 2] : indices[i * 2 + 1]]) + if i == num_micro_batches - 1: + cu_seqlens_padded_list.append(info_to_broadcast[indices[i * 2 + 1] :]) + else: + cu_seqlens_padded_list.append( + info_to_broadcast[indices[i * 2 + 1] : indices[i * 2 + 2]] + ) + + new_samples = [] + for i in range(num_micro_batches): + new_sample = {} + new_sample["max_seqlen"] = max_seqlens[i].to(torch.int32) + new_sample["cu_seqlens"] = cu_seqlens_list[i].to(torch.int32) + new_sample["cu_seqlens_padded"] = cu_seqlens_padded_list[i].to(torch.int32) + new_samples.append(new_sample) + + return ( + new_samples, + num_micro_batches, + seqlen_sum_this_global_batch, + seqlen_squared_sum_this_global_batch, + ) + + +def broadcast_scalars(values: List, group, dev, dtype=torch.float32) -> List: + """ + Broadcast scalar values from rank 0 to all ranks in the group. + + Args: + values: List of scalar values to broadcast (only used on rank 0). + group: The process group to broadcast within. + dev: The device to use for the tensor. + dtype: The data type for the tensor. + + Returns: + List of broadcasted values. + """ + if group.size() <= 1: + return values + + src_rank = torch.distributed.get_process_group_ranks(group)[0] + num_values = len(values) + + if group.rank() == 0: + info_to_broadcast = torch.tensor(values, dtype=dtype, device=dev) + else: + info_to_broadcast = torch.zeros(num_values, dtype=dtype, device=dev) + + broadcast_tensor(info_to_broadcast, src_rank, group) + + if group.rank() != 0: + values = info_to_broadcast.cpu().tolist() + + return values + + +def create_data_iterator(new_samples, pp_group, tp_group, config): + """Handle virtual pipeline parallelism.""" + if ( + config.virtual_pipeline_model_parallel_size is not None + and config.virtual_pipeline_model_parallel_size > 1 + ): + vpp_size = config.virtual_pipeline_model_parallel_size + if tp_group.rank() == 0: + if pp_group.rank() == 0 or pp_group.rank() == pp_group.size() - 1: + metadata = [ + {k: sample[k] for k in ["max_seqlen", "cu_seqlens", "cu_seqlens_padded"]} + for sample in new_samples + ] + if pp_group.rank() == 0: + new_data_iterator = [RerunDataIterator(iter(new_samples))] + [ + RerunDataIterator(iter(metadata)) for _ in range(vpp_size - 1) + ] + else: + new_data_iterator = [ + RerunDataIterator(iter(metadata)) for _ in range(vpp_size - 1) + ] + [RerunDataIterator(iter(new_samples))] + else: + # on middle PP stages, the new_samples are the metadata + metadata = new_samples + new_data_iterator = [RerunDataIterator(iter(metadata)) for _ in range(vpp_size)] + else: + new_data_iterator = [None for _ in range(vpp_size)] + else: + new_data_iterator = RerunDataIterator(iter(new_samples)) if tp_group.rank() == 0 else None + + return new_data_iterator + + +def reroute_samples_to_dcp_ranks( + batch, + global_ids_this_rank, + global_id_seqlens, + sample_id_groups, + offsets, + dp_group, + tp_group, + dp_cp_group, + total_dcp_gpus, +): + """ + Reroutes the sub-samples to the correct rank after scheduling. + + For each key in the batch dict, we perform an all-to-all communication + to transfer the data to the correct ranks. + """ + + def _gid_to_src_rank(gid: int) -> int: + dp_src_rank = torch.bucketize(gid, offsets[1:] - 1) + dcp_rank = ( + torch.distributed.get_process_group_ranks(dp_group)[dp_src_rank] // tp_group.size() + ) % dp_cp_group.size() + return dcp_rank + + gid2local_id = {int(gid): i for i, gid in enumerate(global_ids_this_rank)} + dcp_rank = dp_cp_group.rank() + dp_ranks = torch.distributed.get_process_group_ranks(dp_group) + dp_ranks = [(r // tp_group.size()) % dp_cp_group.size() for r in dp_ranks] + + data_keys = batch[0].keys() + + # Create the send plan + combined_sample_id_groups: List[List[int]] = [[] for _ in range(total_dcp_gpus)] + for d in range(total_dcp_gpus): + for sample_id_group in sample_id_groups: + combined_sample_id_groups[d].extend(sample_id_group[d]) + for dest_rank in range(total_dcp_gpus): + combined_sample_id_groups[dest_rank].sort() + + send_ids_sorted = [ + gid for d in dp_ranks for gid in combined_sample_id_groups[d] if gid in global_ids_this_rank + ] + + send_num_split = [0] * total_dcp_gpus + send_lens_split = [0] * total_dcp_gpus + for dest_rank in range(total_dcp_gpus): + if dest_rank in dp_ranks: + send_seq_lens = [ + global_id_seqlens[gid][1] + for gid in combined_sample_id_groups[dest_rank] + if gid in global_ids_this_rank + ] + send_num_split[dest_rank] = len(send_seq_lens) + send_lens_split[dest_rank] = sum(send_seq_lens) + else: + send_lens_split[dest_rank] = 0 + + # Create the recv plan + recv_sample_id_groups = [[] for _ in range(total_dcp_gpus)] + for gid in combined_sample_id_groups[dcp_rank]: + src_rank = _gid_to_src_rank(gid) + recv_sample_id_groups[src_rank].append(gid) + + recv_lens_split = [0] * total_dcp_gpus + for src_rank in range(total_dcp_gpus): + recv_lens_split[src_rank] = sum( + [global_id_seqlens[gid][1] for gid in recv_sample_id_groups[src_rank]] + ) + + recv_ids_sorted = [gid for d in range(total_dcp_gpus) for gid in recv_sample_id_groups[d]] + recv_counts = [len(recv_sample_id_groups[d]) for d in range(total_dcp_gpus)] + + recv_samples = [{k: None for k in data_keys} for _ in range(sum(recv_counts))] + + def _pack_sample_by_key(key: str) -> torch.Tensor: + flattened_tensors = [] + for gid in send_ids_sorted: + t = batch[gid2local_id[gid]][key].to(torch.cuda.current_device(), non_blocking=True) + flattened_tensors.append(t.reshape(-1)) + return ( + torch.cat(flattened_tensors, dim=0) + if flattened_tensors + else torch.empty(1, device=torch.cuda.current_device(), dtype=batch[0][key].dtype) + ) + + def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor): + cursor = 0 + for i, gid in enumerate(recv_ids_sorted): + sample_len = ( + 1 if key in ["original_seq_len", "padded_seq_len"] else global_id_seqlens[gid][1] + ) + recv_samples[i][key] = recv_tensor[cursor : cursor + sample_len] + cursor += sample_len + + for key in data_keys: + output_split_sizes, input_split_sizes = ( + (recv_counts, send_num_split) + if key in ["original_seq_len", "padded_seq_len"] + else (recv_lens_split, send_lens_split) + ) + send_tensor = _pack_sample_by_key(key) + recv_tensor_size = sum(output_split_sizes) + recv_tensor = torch.empty( + recv_tensor_size, device=torch.cuda.current_device(), dtype=send_tensor.dtype + ) + torch.distributed.all_to_all_single( + output=recv_tensor, + input=send_tensor, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=dp_cp_group, + ) + _unpack_sample_by_key(key, recv_tensor) + + recv_sample_with_id = {recv_id: recv_samples[i] for i, recv_id in enumerate(recv_ids_sorted)} + return recv_sample_with_id + + +def build_packed_microbatches( + grouped_samples: List[List[Dict[str, torch.Tensor]]], dev: torch.device +) -> List[Dict[str, torch.Tensor]]: + """Build packed samples for each microbatch.""" + num_micro_batches = len(grouped_samples) + seg_starts: List[int] = [0] + original_lens_tensors = [] + padded_lens_tensors = [] + + for i in range(num_micro_batches): + samples = grouped_samples[i] + seg_starts.append(seg_starts[-1] + len(samples)) + original_lens_tensors.extend([s["original_seq_len"].reshape(-1) for s in samples]) + padded_lens_tensors.extend([s["padded_seq_len"].reshape(-1) for s in samples]) + + padded_lens_all_gpu = torch.cat(padded_lens_tensors, dim=0).to(dtype=torch.int32) + original_lens_all_gpu = torch.cat(original_lens_tensors, dim=0).to(dtype=torch.int32) + + new_samples: List[Dict[str, torch.Tensor]] = [] + for i in range(num_micro_batches): + samples = grouped_samples[i] + lens_padded = padded_lens_all_gpu[seg_starts[i] : seg_starts[i + 1]] + lens_original = original_lens_all_gpu[seg_starts[i] : seg_starts[i + 1]] + new_sample = _pack_sequences(samples, lens_padded, lens_original, dev) + new_samples.append(new_sample) + + return new_samples + + +def get_batch_and_global_seqlens(data_iterator, num_microbatches, dp_group): + """ + Get the batch and global sequence lengths. + Each DP rank loads the same number of sequences, so we need to gather the sequence + lengths from all ranks then we can schedule the sequences into groups. + Args: + data_iterator: The data iterator. + num_microbatches: The number of microbatches. + dp_group: The data parallel group. + + Returns: + batch: The batch. + global_id_seqlens: The global sequence lengths. + global_ids_this_rank: The global IDs locally present on this rank. + """ + + batch_list = [next(data_iterator) for _ in range(num_microbatches)] + + batch = [] + for item in batch_list: + if isinstance(item, dict): + batch.append(item) + elif isinstance(item, list): + batch.extend(item) + else: + raise ValueError(f"Invalid item type: {type(item)}") + + # This unpack step is redundant: in sft_dataset.py, sequences are already packed before + # rescheduling, so we need to unpack them here and repack after rescheduling. This is only + # to adapt to the current megatron-lm sft_dataset. + # If you implement your own dataset, just have __getitem__ return List[Dict] + # and this step can be skipped. + batch = _unpack_batch(batch) + + subsample_seqlens = torch.cat([sample["padded_seq_len"] for sample in batch]).to( + dtype=torch.int32, device=torch.cuda.current_device() + ) + + global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered = ( + _get_global_seqlens_and_ids(subsample_seqlens, dp_group) + ) + + return batch, global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py index fa421641db5..bea7f4a4c18 100644 --- a/megatron/core/datasets/gpt_dataset.py +++ b/megatron/core/datasets/gpt_dataset.py @@ -82,8 +82,10 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig): sft_mock_dataset_config_json: Optional[str] = None """This config provides the necessary information for the mock dataset.""" - sequence_packing: bool = False - """Option to enable sequence packing for training.""" + sequence_packing_scheduler: Optional[str] = None + """Scheduler for sequence packing and hybrid context parallel. + dp_balanced: DP-balanced scheduler for sequence packing. + """ def __post_init__(self) -> None: """Do asserts and set fields post init""" diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 996330f5674..537e9ad216a 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -2557,3 +2557,24 @@ def set_save_original_input(module): from transformer_engine.pytorch.float8_tensor import Float8Tensor except ImportError: Float8Tensor = None + + +def get_thd_partitioned_indices(cu_seqlens, total_tokens, cp_size, cp_rank): + """Get partitioned indices for THD format data in context parallel. + + Args: + cu_seqlens: Cumulative sequence lengths tensor. + total_tokens: Total number of tokens. + cp_size: Context parallel world size. + cp_rank: Context parallel rank. + + Returns: + Partitioned indices tensor. + """ + assert is_te_min_version("1.10.0"), ( + "Please update Transformer Engine to >= 1.10 to use " + "Context Parallel with THD format data" + ) + import transformer_engine_torch as tex + + return tex.thd_get_partitioned_indices(cu_seqlens, total_tokens, cp_size, cp_rank) diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index 18b0144041b..ca3dbbcfa05 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -59,7 +59,7 @@ class ModelParallelConfig: can handle without overflowing the memory. Typically, a good starting point is to set this to maximum sequence length / context parallel size. This is used to calculate the number and length of sub-samples assigned to - each rank when using sequence_packing. + each rank when sequence_packing_scheduler is not None. """ hybrid_context_parallel: bool = False @@ -72,12 +72,7 @@ class ModelParallelConfig: sequence_packing_scheduler: Optional[str] = None """ Scheduler for sequence packing and hybrid context parallel. - default_sequence_packing: default sequence packing scheduler for sequence packing. - """ - - sequence_packing: bool = False - """ - If true, enables sft sequence packing. + dp_balanced: DP-balanced scheduler for sequence packing. """ expert_model_parallel_size: int = 1 diff --git a/megatron/core/tokenizers/text/libraries/sft_tokenizer.py b/megatron/core/tokenizers/text/libraries/sft_tokenizer.py index e06e51e5ee3..8a418f2dd7f 100644 --- a/megatron/core/tokenizers/text/libraries/sft_tokenizer.py +++ b/megatron/core/tokenizers/text/libraries/sft_tokenizer.py @@ -237,11 +237,6 @@ def bos_id(self): def eod(self): """End of sentence token ID.""" return self._tokenizer.eos_token_id - - @property - def eos(self): - """End of sentence token ID.""" - return self._tokenizer.eos_token_id @property def vocab(self): diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 736f3573c81..6ec252f82fa 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -2019,7 +2019,7 @@ def __post_init__(self): self.attention_backend == AttnBackend.flash ), "Batch invariant mode only supports FlashAttention" - if self.sequence_packing: + if self.sequence_packing_scheduler is not None: # Check TE version. if not HAVE_PACKAGING: raise ImportError( @@ -2038,14 +2038,16 @@ def __post_init__(self): self.variable_seq_lengths = True # TODO(tailaim): add support for other dispatcher types - warnings.warn("Setting moe_token_dispatcher_type to alltoall for sft sequence packing.") - self.moe_token_dispatcher_type = "alltoall" - - if self.sequence_packing_scheduler is None: - self.sequence_packing_scheduler = 'default_sequence_packing' + assert self.moe_token_dispatcher_type == "alltoall", ( + f"sequence_packing only supports moe_token_dispatcher_type='alltoall', " + f"got '{self.moe_token_dispatcher_type}'" + ) - supported_schedulers = ['default_sequence_packing'] - if self.sequence_packing_scheduler not in supported_schedulers: + supported_schedulers = ['dp_balanced'] + if ( + self.sequence_packing_scheduler is not None + and self.sequence_packing_scheduler not in supported_schedulers + ): raise ValueError( f"Unknown scheduler: {self.sequence_packing_scheduler}. " f"Available schedulers: {supported_schedulers}" diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index a3a9cce9922..b6eebe6d41f 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -999,7 +999,7 @@ def validate_args(args, defaults={}): # during pipeline parallelism, it should not be set if sequence length # is constant during training. args.variable_seq_lengths = False - if args.sequence_packing: + if args.sequence_packing_scheduler is not None: args.variable_seq_lengths = True assert args.context_parallel_size * args.max_seqlen_per_dp_cp_rank >= args.seq_length, \ f'Packed sequence buffer size ({args.context_parallel_size * args.max_seqlen_per_dp_cp_rank}) ' \ @@ -1726,6 +1726,9 @@ def _add_network_size_args(parser): "persist_layer_norm", "bias_dropout_fusion", "apply_rope_fusion", + "max_seqlen_per_dp_cp_rank", + "hybrid_context_parallel", + "sequence_packing_scheduler", ] transformer_factory = ArgumentGroupFactory(TransformerConfig, exclude=exclude) transformer_group = transformer_factory.build_group(parser, "transformer configuration") @@ -2970,5 +2973,7 @@ def _add_sft_args(parser): group.add_argument('--sequence-packing', action='store_true', help='use sequence packing(thd format) for training') group.add_argument('--sft-mock-dataset-config-json', type=str, default=None, - help='This config provides the necessary information for the mock dataset. You can either specify a CSV file that contains sequence lengths, where each line stores the length of a sequence, for example: {"mode":"file","path":"/path/to/file"}. Alternatively, you can specify a distribution (currently only supporting lognormal distribution) along with the required parameters, for example, {"mode":"distribution","type":"lognormal","min_seq_len":1024,"max_seq_len":2048,"mean_seq_len":1536,"lognormal_sigma":1.1}, where sigma controls the variability of the lognormal distribution.') + help='This config provides the necessary information for the mock dataset. You can either specify a CSV file that contains sequence lengths, where each line stores the length of a sequence, for example: {"mode":"file","path":"/path/to/file"}. Alternatively, you can specify a distribution (currently only supporting lognormal distribution) along with the required parameters, for example, {"mode":"distribution","type":"lognormal","min_seq_len":1024,"max_seq_len":2048,"mean_seq_len":1536,"lognormal_sigma":1.1}, where sigma controls the variability of the lognormal distribution. ' + 'If not specified and --mock-data is set, defaults to a lognormal distribution with ' + 'min_seq_len=seq_length//2, max_seq_len=seq_length, mean_seq_len=seq_length*3//4, lognormal_sigma=1.1.') return parser diff --git a/megatron/training/datasets/sft_dataset.py b/megatron/training/datasets/sft_dataset.py index 60ed2b45975..27de8277630 100644 --- a/megatron/training/datasets/sft_dataset.py +++ b/megatron/training/datasets/sft_dataset.py @@ -4,7 +4,7 @@ from collections import Counter import json import math -from typing import Any, Dict, Optional, List +from typing import Any, Dict, Optional, List, Union import numpy as np import pandas as pd @@ -50,7 +50,6 @@ def __len__(self) -> int: def __getitem__(self, idx: int) -> list: return self.dataset[idx]["messages"] - class SFTDataset(MegatronDataset): """The dataset used during SFT""" @@ -64,7 +63,7 @@ def __init__( config: GPTDatasetConfig, ) -> None: super().__init__(dataset, dataset_path, indices, num_samples, index_split, config) - # Pre-calculate padding divisor to avoid redundant computation in get_padding_size + # Pre-calculate padding divisor to avoid redundant computation in get_item self.padding_divisor = self._calculate_padding_divisor() @staticmethod @@ -112,19 +111,9 @@ def _calculate_padding_divisor(self) -> int: # TODO(tailaim): do we need to pad for FP8 execution? # divisor = ((divisor + 15) // 16) * 16 return divisor - - def get_padding_size( - self, - seq_len: int, - ) -> int: - seq_len_padded = math.ceil(seq_len / self.padding_divisor) * self.padding_divisor - assert seq_len > seq_len_padded / 2 / self.config.context_parallel_size * (self.config.context_parallel_size - 1), \ - f"sequence length {seq_len} is too short, the divisor is {self.padding_divisor}, that means cp_rank \ - {self.config.context_parallel_size-1} will have no valid tokens" - return seq_len_padded def __getitem__(self, idx: int) -> Dict[str, Any]: - sequence_packing = self.config.sequence_packing + tokenizer = self.config.tokenizer pack_length = self.config.sequence_length @@ -163,12 +152,11 @@ def extend_with_padding(tokens, targets, positions, pad_len): assert not self.config.reset_position_ids pack_positions.extend(range(len(tokens_list))) - if self.config.context_parallel_size > 1: - pad_granularity = self.config.context_parallel_size * 2 - mod_token_count = len(pack_tokens) % pad_granularity - if mod_token_count != 0: - pad_len = pad_granularity - mod_token_count - extend_with_padding(pack_tokens, pack_targets, pack_positions, pad_len) + pad_granularity = self.padding_divisor + mod_token_count = len(pack_tokens) % pad_granularity + if mod_token_count != 0: + pad_len = pad_granularity - mod_token_count + extend_with_padding(pack_tokens, pack_targets, pack_positions, pad_len) # TODO(duncan): Consider also padding to multiple of number of tokens here. This might # be needed for efficiency (and potentially set via command-line argument). @@ -303,64 +291,81 @@ def __len__(self) -> int: return self.num_samples def __getitem__(self, idx: int) -> Dict[str, Any]: - sequence_packing = self.config.sequence_packing + tokenizer = self.config.tokenizer - max_seq_len = self.config.sequence_length + pack_length = self.config.sequence_length + eod = tokenizer.eod + pad = tokenizer.pad tokens = self.dataset[int(self.indices[idx % len(self.indices)])] - target = np.array(tokens, dtype=np.int64) - - force_eod_length = int(tokenizer.force_eod) - - if len(tokens) > max_seq_len - force_eod_length: - # cut the right side - tokens = tokens[: max_seq_len - force_eod_length] - target = target[: max_seq_len - force_eod_length] - # tokens = tokens[(-max_seq_len + force_eod_length):] - # target = target[(-max_seq_len + force_eod_length):] - - # padding - num_tokens = len(tokens) + force_eod_length - if sequence_packing: - padding_len = self.get_padding_size(num_tokens) - num_tokens - else: - padding_len = max_seq_len - num_tokens - assert padding_len >= 0 - filler = [tokenizer.eod] * force_eod_length + [tokenizer.pad] * (padding_len + 1) - - tokens = np.array(tokens.tolist() + filler, dtype=np.int64) - target = np.array(target.tolist() + filler, dtype=np.int64) - - tokens = torch.tensor(tokens) - target = torch.tensor(target) - - tokens = tokens[:-1].contiguous() - target = target[1:].contiguous() - seq_len = tokens.numel() - - loss_mask, position_ids, attention_mask = self._get_ltor_masks_and_position_ids( - seq_len, target, tokenizer.pad - ) - - if self.config.create_attention_mask: - ret = { - 'tokens': tokens, - 'labels': target, - 'attention_mask': attention_mask, - 'loss_mask': loss_mask, - 'position_ids': position_ids, - } - else: - ret = { - 'tokens': tokens, - 'labels': target, - 'loss_mask': loss_mask, - 'position_ids': position_ids, - } - - if sequence_packing: - # sequence packing need both original sequence length and padded length - ret['original_seq_len'] = torch.tensor(num_tokens, dtype=torch.int32, device=tokens.device) - ret['padded_seq_len'] = torch.tensor(seq_len, dtype=torch.int32, device=tokens.device) - - return ret + + def extend_with_padding(tokens, targets, positions, pad_len): + tokens.extend([pad] * pad_len) + targets.extend([pad] * pad_len) + positions.extend(range(positions[-1] + 1, positions[-1] + 1 + pad_len)) + + # Convert tokens to list and add EOD + tokens_list = tokens.tolist() + if tokens_list[-1] != eod: + tokens_list.append(eod) + targets_list = list(tokens_list) + + pack_tokens = list(tokens_list) + pack_targets = list(targets_list) + pack_positions = list(range(len(tokens_list))) + cu_seqlens = [0] + + # Pad to padding_divisor alignment + if self.padding_divisor > 1: + mod_token_count = len(pack_tokens) % self.padding_divisor + if mod_token_count != 0: + pad_len = self.padding_divisor - mod_token_count + extend_with_padding(pack_tokens, pack_targets, pack_positions, pad_len) + + # Record padded boundary after padding + cu_seqlens.append(len(pack_tokens)) + + # Handle any necessary truncation + if len(pack_tokens) >= pack_length + 1: # +1 here to account for later alignment + max_body = pack_length - 1 + pack_tokens = pack_tokens[:max_body] + pack_targets = pack_targets[:max_body] + pack_tokens.extend([eod, pad]) + pack_targets.extend([eod, pad]) + pack_positions = pack_positions[:pack_length + 1] + cu_seqlens[-1] = len(pack_tokens) - 1 + + # Handle any necessary padding + if len(pack_tokens) < pack_length + 1: # +1 here to account for later alignment + pad_len = pack_length + 1 - len(pack_tokens) + extend_with_padding(pack_tokens, pack_targets, pack_positions, pad_len) + cu_seqlens[-1] = len(pack_tokens) - 1 + + assert len(pack_tokens) == pack_length + 1 + assert len(pack_targets) == pack_length + 1 + assert len(pack_positions) == pack_length + 1 + + # Align and convert to tensors + input_ids = torch.tensor(pack_tokens[:-1], dtype=torch.int64) + labels = torch.tensor(pack_targets[1:], dtype=torch.int64) + position_ids = torch.tensor(pack_positions[:-1], dtype=torch.int64) + + # Loss mask + loss_mask = torch.ones(pack_length, dtype=torch.float32) + loss_mask[labels == pad] = 0.0 + loss_mask[labels == IGNORE_INDEX] = 0.0 + + assert len(cu_seqlens) >= 2 + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32) + # Calculating max_seqlen here because of possible effects of truncation and padding + adjacent_diffs = cu_seqlens[1:] - cu_seqlens[:-1] + max_seqlen = adjacent_diffs.max() # max_seqlen is a 0-D tensor + + return { + 'tokens': input_ids, + 'labels': labels, + 'loss_mask': loss_mask, + 'position_ids': position_ids, + 'cu_seqlens': cu_seqlens, + 'max_seqlen': max_seqlen, + } \ No newline at end of file diff --git a/megatron/training/training.py b/megatron/training/training.py index a5b225b2c82..1326dc39634 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -168,7 +168,7 @@ def set_startup_timestamps(program_start=None, main_entry=None): get_num_microbatches, update_num_microbatches ) -from megatron.core.datasets.data_schedule import wrap_dataloader +from megatron.core.datasets.data_schedule import wrap_data_iterator from .async_utils import maybe_finalize_async_save from .utils import ( @@ -1721,7 +1721,7 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch num_microbatches, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch, - ) = wrap_dataloader(data_iterator, config, num_microbatches) + ) = wrap_data_iterator(data_iterator, config, get_num_microbatches()) else: # data_iterator unchanged num_microbatches = get_num_microbatches() @@ -2855,7 +2855,7 @@ def trace_handler(p): # Completely skip iteration if needed. if iteration in args.iterations_to_skip: # TODO(tailaim): this need to be modified - assert not args.sequence_packing, "Sequence packing is not supported in skip iteration mode" + assert config.sequence_packing_scheduler is None, "Sequence packing scheduler is not supported in skip iteration mode" # Dummy train_step to fast forward train_data_iterator. dummy_train_step(train_data_iterator) if iteration == start_iteration: diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 1ce598d6868..9beadc1ec37 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -68,9 +68,10 @@ def get_batch(data_iterator, vp_stage: Optional[int] = None): args = get_args() config = core_transformer_config_from_args(args) - if args.sequence_packing: + if args.sequence_packing_scheduler is not None: return get_batch_on_this_rank_for_sequence_packing( data_iterator, + vpp_size=config.virtual_pipeline_model_parallel_size, mtp_on_this_rank=mtp_on_this_rank(config, ignore_virtual=False, vp_stage=vp_stage), vp_stage=vp_stage, ) @@ -260,7 +261,7 @@ def core_gpt_dataset_config_from_args(args): "sequence_parallel_size": args.tensor_model_parallel_size*args.sequence_parallel, "hybrid_context_parallel": args.hybrid_context_parallel, "sft_mock_dataset_config_json":args.sft_mock_dataset_config_json, - "sequence_packing": args.sequence_packing, + "sequence_packing_scheduler": args.sequence_packing_scheduler, } # add FIM args to the config diff --git a/tests/unit_tests/test_sequence_packing.py b/tests/unit_tests/test_sequence_packing.py index bac7ea79db1..288e66f2f93 100644 --- a/tests/unit_tests/test_sequence_packing.py +++ b/tests/unit_tests/test_sequence_packing.py @@ -1,12 +1,20 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +import random from types import SimpleNamespace +import numpy as np import pytest import torch from megatron.core import parallel_state -from megatron.core.datasets.data_schedule import get_batch_on_this_rank_for_sequence_packing +from megatron.core.datasets.data_schedule import ( + PackingScheduler, + get_batch_on_this_rank_for_sequence_packing, + scheduler_map, + wrap_data_iterator, +) +from megatron.core.rerun_state_machine import RerunDataIterator from megatron.training.global_vars import unset_global_variables from tests.unit_tests.test_utilities import Utils @@ -32,7 +40,6 @@ def __init__( total_seq_length: Total length of packed sequences sequence_lengths: List of individual sequence lengths (variable-length). If None, generates random variable lengths. - local_cp_size: Local CP size for hybrid context parallel device: Device to create tensors on seed: Random seed for reproducibility """ @@ -127,21 +134,19 @@ def _gather_tensor_from_all_ranks(tensor): @pytest.mark.parametrize( - ("tp", "pp", "cp", "hybrid_cp"), + ("tp", "pp", "cp"), [ - (1, 1, 1, False), # Basic case: no parallelism - (2, 1, 1, False), # Tensor parallel only - (1, 2, 1, False), # Pipeline parallel only - (2, 2, 1, False), # TP + PP - (1, 1, 2, False), # CP only - (2, 1, 2, False), # TP + CP - (1, 2, 2, False), # PP + CP - (1, 4, 1, False), # Has middle pp stage - (1, 1, 1, True), # Hybrid CP enabled (CP=1 with hybrid groups) - (2, 1, 1, True), # TP + Hybrid CP + (1, 1, 1), # Basic case: no parallelism + (2, 1, 1), # Tensor parallel only + (1, 2, 1), # Pipeline parallel only + (2, 2, 1), # TP + PP + (1, 1, 2), # CP only + (2, 1, 2), # TP + CP + (1, 2, 2), # PP + CP + (1, 4, 1), # Has middle pp stage ], ) -def test_get_batch_on_this_rank_for_sequence_packing(tp, pp, cp, hybrid_cp): +def test_get_batch_on_this_rank_for_sequence_packing(tp, pp, cp): """ Test get_batch_on_this_rank_for_sequence_packing function with variable-length THD format. @@ -155,7 +160,6 @@ def test_get_batch_on_this_rank_for_sequence_packing(tp, pp, cp, hybrid_cp): args.tensor_model_parallel_size = tp args.pipeline_model_parallel_size = pp args.context_parallel_size = cp - args.hybrid_context_parallel = hybrid_cp args.virtual_pipeline_model_parallel_size = None args.data_parallel_size = 8 // (tp * pp * cp) args.seq_length = 8192 @@ -165,21 +169,12 @@ def test_get_batch_on_this_rank_for_sequence_packing(tp, pp, cp, hybrid_cp): raise ValueError(f"Invalid config: tp={tp}, pp={pp}, cp={cp} exceeds world size 8") # Initialize model parallel - Utils.initialize_model_parallel( - tp, - pp, - None, - context_parallel_size=cp, - hybrid_context_parallel=hybrid_cp, - min_hybrid_context_parallel_size=1, - ) + Utils.initialize_model_parallel(tp, pp, None, context_parallel_size=cp) try: # Create mock data iterator with variable-length sequences # Only TP rank 0 needs the iterator; other TP ranks pass None tp_rank = parallel_state.get_tensor_model_parallel_rank() - local_cp_size = 8 // (tp * pp) if hybrid_cp else None - if tp_rank == 0: # Use deterministic seed based on DP rank so same data within TP/PP/CP group dp_rank = parallel_state.get_data_parallel_rank() @@ -191,7 +186,6 @@ def test_get_batch_on_this_rank_for_sequence_packing(tp, pp, cp, hybrid_cp): MockVariableLengthSequencePackingDataIterator( total_seq_length=args.seq_length, sequence_lengths=sequence_lengths, # Variable lengths, sum=8192 - local_cp_size=local_cp_size, seed=42 + dp_rank, # Same seed within PP/CP group ) ) @@ -201,10 +195,7 @@ def test_get_batch_on_this_rank_for_sequence_packing(tp, pp, cp, hybrid_cp): # Call the function under test result = get_batch_on_this_rank_for_sequence_packing( - data_iterator=data_iterator, - mtp_on_this_rank=False, - vp_stage=None, - hybrid_context_parallel=hybrid_cp, + data_iterator=data_iterator, mtp_on_this_rank=False, vp_stage=None ) # Unpack the result @@ -248,9 +239,6 @@ def test_get_batch_on_this_rank_for_sequence_packing(tp, pp, cp, hybrid_cp): # ===================================================================== assert packed_seq_params is not None assert packed_seq_params.qkv_format == "thd" - if hybrid_cp: - assert packed_seq_params.local_cp_size is not None - assert packed_seq_params.cp_group is not None test_keys = [ "cu_seqlens_q", @@ -260,8 +248,6 @@ def test_get_batch_on_this_rank_for_sequence_packing(tp, pp, cp, hybrid_cp): "cu_seqlens_kv_padded", "max_seqlen_kv", ] - if hybrid_cp: - test_keys.append("local_cp_size") for key in test_keys: tensor = getattr(packed_seq_params, key) assert tensor is not None @@ -291,18 +277,9 @@ def test_get_batch_on_this_rank_for_sequence_packing(tp, pp, cp, hybrid_cp): # ===================================================================== # TEST 4: Verify CP partitioning # ===================================================================== - if cp > 1 or hybrid_cp: - if hybrid_cp: - assert packed_seq_params.local_cp_size is not None - cp_size = packed_seq_params.local_cp_size - assert packed_seq_params.cp_group == ( - parallel_state.get_hybrid_data_context_parallel_groups(group_size=cp_size) - ) - else: - cp_size = cp - + if cp > 1: # With CP, the sequence should be partitioned - expected_seq_len = args.seq_length // cp_size + expected_seq_len = args.seq_length // cp if is_first_stage: actual_seq_len = tokens.shape[1] @@ -320,3 +297,185 @@ def test_get_batch_on_this_rank_for_sequence_packing(tp, pp, cp, hybrid_cp): finally: Utils.destroy_model_parallel() unset_global_variables() + + +@pytest.mark.parametrize( + ("tp", "pp", "cp", "vpp", "scheduler_type"), + [ + (1, 1, 8, None, "dp_balanced"), + (2, 1, 4, None, "dp_balanced"), + (2, 4, 1, None, "dp_balanced"), + (2, 2, 1, None, "dp_balanced"), + (1, 4, 1, 4, "dp_balanced"), + ], +) +def test_wrap_dataloader(tp, pp, cp, vpp, scheduler_type): + ''' + Test wrap_dataloader function with different scheduler types. + ''' + args = SimpleNamespace() + args.tensor_model_parallel_size = tp + args.pipeline_model_parallel_size = pp + args.context_parallel_size = cp + args.virtual_pipeline_model_parallel_size = None + args.data_parallel_size = 8 // (tp * pp * cp) + args.seq_length = 8192 + args.max_seqlen_per_dp_cp_rank = 8192 + + # Skip invalid configurations + if args.data_parallel_size < 1: + raise ValueError(f"Invalid config: tp={tp}, pp={pp}, cp={cp} exceeds world size 8") + + def _create_single_sample(seq_len): + # hard code the padding size to 16 + pad_size = 16 + seq_len_padded = ((seq_len + pad_size - 1) // pad_size) * pad_size + device = torch.device("cuda", torch.cuda.current_device()) + tokens = torch.randint(0, 128, (seq_len_padded,), dtype=torch.int64, device=device) + labels = tokens + 1 + position_ids = torch.arange(seq_len_padded, dtype=torch.int64, device=device) + loss_mask = torch.ones(seq_len_padded, dtype=torch.float32, device=device) + loss_mask[0:seq_len] = 1 + loss_mask[seq_len:] = 0 + cu_seqlens = torch.tensor([0, seq_len_padded], dtype=torch.int32, device=device) + + return { + 'tokens': tokens, + 'labels': labels, + 'loss_mask': loss_mask, + 'position_ids': position_ids, + 'cu_seqlens': cu_seqlens, + } + + # Initialize model parallel + Utils.initialize_model_parallel(tp, pp, vpp, context_parallel_size=cp) + + global_batch_size = 64 + micro_batch_size = 1 + nums = [random.randint(2048, args.seq_length) for _ in range(global_batch_size)] # 64 sequences + + config = SimpleNamespace() + config.max_seqlen_per_dp_cp_rank = args.max_seqlen_per_dp_cp_rank + config.microbatch_group_size_per_vp_stage = pp + config.virtual_pipeline_model_parallel_size = vpp + config.sequence_packing_scheduler = scheduler_type + + dp_rank = parallel_state.get_data_parallel_rank() + dp_size = parallel_state.get_data_parallel_world_size() + + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + tp_rank = parallel_state.get_tensor_model_parallel_rank() + + is_pp_first = pp_rank == 0 + is_pp_last = pp_rank == pp - 1 + is_pp_first_or_last = is_pp_first or is_pp_last + is_tp_first = tp_rank == 0 + + num_micro_batches_old = global_batch_size // micro_batch_size // dp_size + + if is_tp_first and (is_pp_first or is_pp_last): + samples = [ + _create_single_sample(num) + for num in nums[dp_rank * num_micro_batches_old : (dp_rank + 1) * num_micro_batches_old] + ] + data_iterator = RerunDataIterator(iter(samples)) + else: + data_iterator = None + + if is_tp_first: + if vpp is not None and vpp > 1: + if is_pp_first: + data_iterator = [data_iterator] + [None for _ in range(vpp - 1)] + elif is_pp_last: + data_iterator = [None for _ in range(vpp - 1)] + [data_iterator] + else: + data_iterator = [None for _ in range(vpp)] + try: + # Call the function under test + ( + new_data_iterator, + num_micro_batches, + num_total_tokens_this_global_batch, + sequence_square_sum_this_global_batch, + ) = wrap_data_iterator(data_iterator, config, num_micro_batches_old) + + # check the result + assert type(num_micro_batches) is int + assert ( + type(num_total_tokens_this_global_batch) is float + or type(num_total_tokens_this_global_batch) is np.float32 + ) + assert ( + type(sequence_square_sum_this_global_batch) is float + or type(sequence_square_sum_this_global_batch) is np.float32 + ) + + def _check_batch(batch_all, batch_keys): + for batch in batch_all: + assert set(batch_keys) <= set( + batch.keys() + ), f"batch keys: {set(batch.keys())} missing {set(batch_keys) - set(batch.keys())}" + for key in batch_keys: + assert batch[key] is not None + + if is_tp_first: + # CHECK KEYS + batch_keys = ["cu_seqlens", "max_seqlen", "cu_seqlens_padded"] + if vpp is not None and vpp > 1: + # check metadata for all stages (save batches to avoid re-consuming iterators) + all_stage_batches = [] + for temp_data_iterator in new_data_iterator: + stage_batch = [next(temp_data_iterator) for _ in range(num_micro_batches)] + all_stage_batches.append(stage_batch) + _check_batch(stage_batch, batch_keys) + + # check for first or last stage on first or last pp rank + if is_pp_first_or_last: + batch_all = all_stage_batches[0] if is_pp_first else all_stage_batches[-1] + batch_keys += ["tokens", "position_ids", "labels", "loss_mask"] + _check_batch(batch_all, batch_keys) + else: + # non-VPP: single iterator + batch_all = [next(new_data_iterator) for _ in range(num_micro_batches)] + if is_pp_first_or_last: + batch_keys += ["tokens", "position_ids", "labels", "loss_mask"] + _check_batch(batch_all, batch_keys) + + # CHECK TOKEN SUM ON FIRST OR LAST PP RANK + # Note: data_iterator is consumed by wrap_data_iterator, new_data_iterator is consumed above. + # Use `samples` for before-wrap, reuse `batch_all` from the check above for after-wrap. + if is_pp_first_or_last: + # Compute token sum before wrap + token_sum_before = torch.tensor(0, dtype=torch.int64, device='cuda') + for sample in samples: + token_sum_before += sample['tokens'].long().sum() + + # Compute token sum after wrap (batch_all already collected above with tokens) + token_sum_after = torch.tensor(0, dtype=torch.int64, device='cuda') + for batch in batch_all: + token_sum_after += batch['tokens'].long().sum() + + # Reduce sum across dp_cp group and verify equality + dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=False) + torch.distributed.all_reduce( + token_sum_before, op=torch.distributed.ReduceOp.SUM, group=dp_cp_group + ) + torch.distributed.all_reduce( + token_sum_after, op=torch.distributed.ReduceOp.SUM, group=dp_cp_group + ) + + assert ( + token_sum_before == token_sum_after + ), f"Token sum mismatch: before={token_sum_before.item()}, after={token_sum_after.item()}" + + else: + if vpp is not None and vpp > 1: + assert type(new_data_iterator) is list and len(new_data_iterator) == vpp + for data_iterator in new_data_iterator: + assert data_iterator is None + else: + assert new_data_iterator is None + + finally: + Utils.destroy_model_parallel() + unset_global_variables() From c226649307f4a31f83064f30a5cce5afd6d58d36 Mon Sep 17 00:00:00 2001 From: xiaoyao0115 <1804647152@qq.com> Date: Thu, 12 Feb 2026 05:42:09 -0800 Subject: [PATCH 3/3] small fixes according to comments Signed-off-by: xiaoyao0115 <1804647152@qq.com> --- megatron/core/datasets/data_schedule.py | 20 ++++----- ...eduler_utils.py => data_schedule_utils.py} | 8 ++-- megatron/core/datasets/readme.md | 22 +++++++++ megatron/core/model_parallel_config.py | 2 +- .../text/libraries/sft_tokenizer.py | 1 + .../core/transformer/transformer_config.py | 2 +- megatron/training/arguments.py | 5 --- megatron/training/datasets/sft_dataset.py | 45 ++++++++++--------- .../unit_tests/models/test_mamba_moe_model.py | 1 + tests/unit_tests/test_sequence_packing.py | 2 - 10 files changed, 64 insertions(+), 44 deletions(-) rename megatron/core/datasets/{data_scheduler_utils.py => data_schedule_utils.py} (98%) diff --git a/megatron/core/datasets/data_schedule.py b/megatron/core/datasets/data_schedule.py index 76e53c6fe37..deef93b8ebd 100644 --- a/megatron/core/datasets/data_schedule.py +++ b/megatron/core/datasets/data_schedule.py @@ -6,7 +6,7 @@ import torch from megatron.core import parallel_state -from megatron.core.datasets.data_scheduler_utils import ( +from megatron.core.datasets.data_schedule_utils import ( broadcast_scalars, broadcast_tensor, broadcast_to_pp_group, @@ -312,7 +312,7 @@ def __next__(self) -> Any: return samples_this_rank_with_id, sample_id_groups -class BaseScheduler: +class BasePackingScheduler: """Base class for sequence packing schedulers.""" def __init__( @@ -335,7 +335,7 @@ def __init__( self.dp_size = dp_size self.microbatch_group_size_per_vp_stage = microbatch_group_size_per_vp_stage - def get_require_sample_keys(self): + def get_required_sample_keys(self): """Return the required key of each batch.""" raise NotImplementedError @@ -376,14 +376,14 @@ def run( raise NotImplementedError -class DpBalancedScheduler(BaseScheduler): +class DpBalancedScheduler(BasePackingScheduler): """Packs sequences in their original order until reaching the max limit of sequence length.""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.max_seq_len_all_ranks = self.max_seqlen_per_dp_cp_rank * self.cp_size - def get_require_sample_keys(self): + def get_required_sample_keys(self): """Return the required key of each batch.""" return [ "tokens", @@ -521,7 +521,7 @@ def run( ) # Step 2: Check required sample keys - for key in self.get_require_sample_keys(): + for key in self.get_required_sample_keys(): assert ( key in batch[0] ), f"Batch missing required key {key}, provided keys: {batch[0].keys()}" @@ -620,14 +620,14 @@ def run( ) -class PackingScheduler(enum.Enum): +class PackingSchedulerEnum(enum.Enum): """Enum for supported sequence packing algorithms.""" DP_BALANCED = "dp_balanced" -scheduler_map: Dict[PackingScheduler, Type[BaseScheduler]] = { - PackingScheduler.DP_BALANCED: DpBalancedScheduler +scheduler_map: Dict[PackingSchedulerEnum, Type[BasePackingScheduler]] = { + PackingSchedulerEnum.DP_BALANCED: DpBalancedScheduler } @@ -668,7 +668,7 @@ def wrap_data_iterator( # Convert string to enum scheduler_type = config.sequence_packing_scheduler - scheduler_type = PackingScheduler[scheduler_type.upper()] + scheduler_type = PackingSchedulerEnum[scheduler_type.upper()] scheduler = scheduler_map[scheduler_type]( config.max_seqlen_per_dp_cp_rank, diff --git a/megatron/core/datasets/data_scheduler_utils.py b/megatron/core/datasets/data_schedule_utils.py similarity index 98% rename from megatron/core/datasets/data_scheduler_utils.py rename to megatron/core/datasets/data_schedule_utils.py index 2a1e2a6528d..2e5635dfcd8 100644 --- a/megatron/core/datasets/data_scheduler_utils.py +++ b/megatron/core/datasets/data_schedule_utils.py @@ -8,7 +8,7 @@ from megatron.core.rerun_state_machine import RerunDataIterator -def _unpack_batch(batch): +def _unpack_batch(batch: List[Dict[str, torch.Tensor]]) -> List[Dict[str, torch.Tensor]]: """ Unpacks the packed samples into a list of sub-samples. Since each sub-sample may be routed to different DPxCP ranks, @@ -474,9 +474,9 @@ def get_batch_and_global_seqlens(data_iterator, num_microbatches, dp_group): else: raise ValueError(f"Invalid item type: {type(item)}") - # This unpack step is redundant: in sft_dataset.py, sequences are already packed before - # rescheduling, so we need to unpack them here and repack after rescheduling. This is only - # to adapt to the current megatron-lm sft_dataset. + # in sft_dataset.py, sequences are already packed before rescheduling, + # so we need to unpack them here and repack after rescheduling. + # This is only to adapt to the current megatron-lm sft_dataset. # If you implement your own dataset, just have __getitem__ return List[Dict] # and this step can be skipped. batch = _unpack_batch(batch) diff --git a/megatron/core/datasets/readme.md b/megatron/core/datasets/readme.md index 452bf24e4a2..64889feb481 100644 --- a/megatron/core/datasets/readme.md +++ b/megatron/core/datasets/readme.md @@ -192,6 +192,28 @@ To query the `BlendedDataset` for the _k_-th sample we do the following To save time during initialization, each index is built/cached sequentially on one process rank and subsequently loaded in parallel on other process ranks. The cached indices are unique to a hash generated in the `BlendedDataset.__init__` function. +## Packing Scheduler + +The packing scheduler re-schedules variable-length sequences across DP×CP ranks to improve GPU utilization. It is built around the following modules: + +### `data_schedule` + +This module contains the high-level scheduling logic and entry points: + +- **`HybridCPDataLoaderWrapper`**: A wrapper class for hybrid context parallel (CP) scheduling. For every `__next__` call, it: (1) pulls a batch of packed samples from each DP rank, (2) gathers sequence lengths across the DP group, (3) schedules sub-samples using the `BalancedCPScheduler`, (4) reroutes sub-samples to the correct DPxCP ranks via all-to-all communication. + +- **`BasePackingScheduler`**: Abstract base class for packing schedulers. Defines the interface for `get_groups_and_subsamples()` (scheduling algorithm) and `run()` (full scheduling pipeline including fetch, schedule, reroute, pack, broadcast, and VPP handling). + +- **`DpBalancedScheduler`**: A concrete scheduler that packs sequences in their original order until reaching the max sequence length limit per DPxCP rank. Supports aligning the number of microbatches to DP size and VPP stage multiples. + +- **`wrap_data_iterator()`**: Top-level entry point that wraps an existing `data_iterator`. It creates the appropriate scheduler, runs the scheduling pipeline, broadcast metadata and new num_microbatches, returns a new data iterator along with the updated number of microbatches and FLOPs statistics. + +- **`get_batch_on_this_rank_for_sequence_packing()`**: Fetches and broadcasts a single packed microbatch for the current rank. Handles TP/PP broadcasting, constructs `PackedSeqParams` (with `cu_seqlens`, `max_seqlen`, `qkv_format=thd`), and optionally partitions sequences across CP ranks using Transformer Engine's `thd_get_partitioned_indices`. + +### `data_schedule_utils.py` + +This module contains the utility functions used by the schedulers. + ## Fast DataLoader initialization Especially for large-scale runs, DataLoader initialization can take several minutes, since it involves opening and memory-mapping multiple files and can significantly stress the filesystem. To speed up this process, we have developed the following three optimizations, controlled by configuration flags": diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index ca3dbbcfa05..b8d23a4735c 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -69,7 +69,7 @@ class ModelParallelConfig: Please set max_seqlen_per_dp_cp_rank when using hybrid_context_parallel. """ - sequence_packing_scheduler: Optional[str] = None + sequence_packing_scheduler: Optional[Literal['dp_balanced']] = None """ Scheduler for sequence packing and hybrid context parallel. dp_balanced: DP-balanced scheduler for sequence packing. diff --git a/megatron/core/tokenizers/text/libraries/sft_tokenizer.py b/megatron/core/tokenizers/text/libraries/sft_tokenizer.py index 8a418f2dd7f..b9a4f7b0e4b 100644 --- a/megatron/core/tokenizers/text/libraries/sft_tokenizer.py +++ b/megatron/core/tokenizers/text/libraries/sft_tokenizer.py @@ -108,6 +108,7 @@ def __init__(self, tokenizer_path: str, prompt_format: str): self._prompt_format = prompt_format + def tokenize_conversation( self, conversation: List[Dict], return_target: bool, add_generation_prompt: bool ): diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 6ec252f82fa..20327d58112 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -2049,7 +2049,7 @@ def __post_init__(self): and self.sequence_packing_scheduler not in supported_schedulers ): raise ValueError( - f"Unknown scheduler: {self.sequence_packing_scheduler}. " + f"Unsupported scheduler: {self.sequence_packing_scheduler}. " f"Available schedulers: {supported_schedulers}" ) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index b6eebe6d41f..a4f4b982031 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1004,9 +1004,6 @@ def validate_args(args, defaults={}): assert args.context_parallel_size * args.max_seqlen_per_dp_cp_rank >= args.seq_length, \ f'Packed sequence buffer size ({args.context_parallel_size * args.max_seqlen_per_dp_cp_rank}) ' \ f'must be >= single sequence max length ({args.seq_length})' - # TODO(tailaim): add support for other dispatcher types - print(f"Setting moe_token_dispatcher_type to alltoall for sft sequence packing with pipeline parallelism") - args.moe_token_dispatcher_type = "alltoall" if args.mock_data and args.sft_mock_dataset_config_json is None: args.sft_mock_dataset_config_json = json.dumps( { @@ -2970,8 +2967,6 @@ def _add_sft_args(parser): group.add_argument('--sft', action="store_true", help='Megatron SFT training') group.add_argument('--sft-tokenizer-prompt-format', type=str, default="nemotron-h-aligned", help='SFT prompt format.') - group.add_argument('--sequence-packing', action='store_true', - help='use sequence packing(thd format) for training') group.add_argument('--sft-mock-dataset-config-json', type=str, default=None, help='This config provides the necessary information for the mock dataset. You can either specify a CSV file that contains sequence lengths, where each line stores the length of a sequence, for example: {"mode":"file","path":"/path/to/file"}. Alternatively, you can specify a distribution (currently only supporting lognormal distribution) along with the required parameters, for example, {"mode":"distribution","type":"lognormal","min_seq_len":1024,"max_seq_len":2048,"mean_seq_len":1536,"lognormal_sigma":1.1}, where sigma controls the variability of the lognormal distribution. ' 'If not specified and --mock-data is set, defaults to a lognormal distribution with ' diff --git a/megatron/training/datasets/sft_dataset.py b/megatron/training/datasets/sft_dataset.py index 27de8277630..d7044247e6c 100644 --- a/megatron/training/datasets/sft_dataset.py +++ b/megatron/training/datasets/sft_dataset.py @@ -222,7 +222,11 @@ class MockSFTLowLevelDataset: """The low-level mock dataset for SFT Args: - mock_config (dict): The config for mock dataset. + mode (str): Either 'file' or 'distribution'. + **kwargs: Additional arguments depending on mode. + For mode='file': path (str) - path to a CSV file with sequence lengths. + For mode='distribution': type (str), min_seq_len (int), max_seq_len (int), + mean_seq_len (int), and distribution-specific params (e.g. lognormal_sigma). """ seed: int = 0 @@ -230,28 +234,26 @@ class MockSFTLowLevelDataset: size: int = 1000000 """The hard-coded number of sequence to generate""" - - # This is to maintain consistency with the SFT dataset that uses real data. In the real dataset, an element in the low-level dataset often contains multiple sequences. So here, each element in the mock low-level dataset also contains num_sequence_per_sample sequences. This will be made more reasonable in the future. - - def __init__(self, config: Dict) -> None: + def __init__(self, mode: str, **kwargs) -> None: np.random.seed(self.seed) - # either choose to load sequence lengths from external file, or generate random sequence lengths - - assert "mode" in config, f"mode must be set, either 'file' or 'distribution'" - - if config["mode"] == "file": - self.sequence_lengths = np.array(pd.read_csv(config["path"])).flatten() + + if mode == "file": + self.sequence_lengths = np.array(pd.read_csv(kwargs["path"])).flatten() self.size = len(self.sequence_lengths) - elif config["mode"] == "distribution": - min_seq_len = config["min_seq_len"] - max_seq_len = config["max_seq_len"] - mean_seq_len = config["mean_seq_len"] - if config["type"] == "lognormal": - lognormal_sigma = config["lognormal_sigma"] - self.sequence_lengths = self.generate_lognormal_samples(self.size, mean_seq_len,lognormal_sigma, min_seq_len, max_seq_len) + elif mode == "distribution": + min_seq_len = kwargs["min_seq_len"] + max_seq_len = kwargs["max_seq_len"] + mean_seq_len = kwargs["mean_seq_len"] + if kwargs["type"] == "lognormal": + lognormal_sigma = kwargs["lognormal_sigma"] + self.sequence_lengths = self.generate_lognormal_samples( + self.size, mean_seq_len, lognormal_sigma, min_seq_len, max_seq_len + ) else: - raise ValueError(f"Unsupported sequence length distribution type {config['type']}") + raise ValueError(f"Unsupported distribution type {kwargs['type']}") + else: + raise ValueError(f"Unsupported mode '{mode}', must be 'file' or 'distribution'") def generate_lognormal_samples(self, size, mean, sigma, min_seq_len, max_seq_len): mu = np.log(mean) - sigma**2 / 2 @@ -263,9 +265,10 @@ def __len__(self) -> int: return self.size def __getitem__(self, idx: int) -> List[np.ndarray]: - length = self.sequence_lengths[idx % self.size] # the length of sample is 'length', but only length-1 elements are generated here, # because an eod token will be appended at the end later in SFTDataset + + length = self.sequence_lengths[idx % self.size] sample = np.arange(1, length, dtype=np.int64) return sample class MockSFTDataset(SFTDataset): @@ -285,7 +288,7 @@ def __init__( @staticmethod def build_low_level_dataset(dataset_path: str, config: GPTDatasetConfig) -> LowLevelDataset: mock_config = json.loads(config.sft_mock_dataset_config_json) - return MockSFTLowLevelDataset(mock_config) + return MockSFTLowLevelDataset(**mock_config) def __len__(self) -> int: return self.num_samples diff --git a/tests/unit_tests/models/test_mamba_moe_model.py b/tests/unit_tests/models/test_mamba_moe_model.py index f933d811779..912d97528e7 100644 --- a/tests/unit_tests/models/test_mamba_moe_model.py +++ b/tests/unit_tests/models/test_mamba_moe_model.py @@ -277,6 +277,7 @@ "offload_modules": [], "hybrid_context_parallel": False, "max_seqlen_per_dp_cp_rank": None, + "sequence_packing_scheduler": None, } # Fields to ignore entirely (ephemeral, environment-specific, very large). SKIP_FIELDS = set() diff --git a/tests/unit_tests/test_sequence_packing.py b/tests/unit_tests/test_sequence_packing.py index 288e66f2f93..60316b0236e 100644 --- a/tests/unit_tests/test_sequence_packing.py +++ b/tests/unit_tests/test_sequence_packing.py @@ -9,9 +9,7 @@ from megatron.core import parallel_state from megatron.core.datasets.data_schedule import ( - PackingScheduler, get_batch_on_this_rank_for_sequence_packing, - scheduler_map, wrap_data_iterator, ) from megatron.core.rerun_state_machine import RerunDataIterator