Draft
576 changes: 574 additions & 2 deletions megatron/core/datasets/data_schedule.py

Large diffs are not rendered by default.

492 changes: 492 additions & 0 deletions megatron/core/datasets/data_schedule_utils.py

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions megatron/core/datasets/gpt_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,14 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig):
context_parallel_size: Optional[int] = None
"""The size of the context parallel group. Needed for padding in packed sequences."""

sft_mock_dataset_config_json: Optional[str] = None
"""This config provides the necessary information for the mock dataset."""

sequence_packing_scheduler: Optional[str] = None
"""Scheduler for sequence packing and hybrid context parallel.
dp_balanced: DP-balanced scheduler for sequence packing.
"""

def __post_init__(self) -> None:
"""Do asserts and set fields post init"""
super().__post_init__()
Expand Down
22 changes: 22 additions & 0 deletions megatron/core/datasets/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,28 @@ To query the `BlendedDataset` for the _k_-th sample we do the following

To save time during initialization, each index is built/cached sequentially on one process rank and subsequently loaded in parallel on other process ranks. The cached indices are unique to a hash generated in the `BlendedDataset.__init__` function.

## Packing Scheduler

The packing scheduler re-schedules variable-length sequences across DP×CP ranks to improve GPU utilization. It is built around the following modules:

### `data_schedule.py`

This module contains the high-level scheduling logic and entry points:

- **`HybridCPDataLoaderWrapper`**: A wrapper class for hybrid context parallel (CP) scheduling. For every `__next__` call, it: (1) pulls a batch of packed samples from each DP rank, (2) gathers sequence lengths across the DP group, (3) schedules sub-samples using the `BalancedCPScheduler`, and (4) reroutes sub-samples to the correct DP×CP ranks via all-to-all communication.

- **`BasePackingScheduler`**: Abstract base class for packing schedulers. Defines the interface for `get_groups_and_subsamples()` (scheduling algorithm) and `run()` (full scheduling pipeline including fetch, schedule, reroute, pack, broadcast, and VPP handling).

- **`DpBalancedScheduler`**: A concrete scheduler that packs sequences in their original order until reaching the maximum sequence length per DP×CP rank. Supports aligning the number of microbatches to multiples of the DP size and the VPP stage count.

- **`wrap_data_iterator()`**: Top-level entry point that wraps an existing `data_iterator`. It creates the appropriate scheduler, runs the scheduling pipeline, broadcasts the metadata and the new `num_microbatches`, and returns a new data iterator along with the updated number of microbatches and FLOPs statistics.

- **`get_batch_on_this_rank_for_sequence_packing()`**: Fetches and broadcasts a single packed microbatch for the current rank. Handles TP/PP broadcasting, constructs `PackedSeqParams` (with `cu_seqlens`, `max_seqlen`, `qkv_format=thd`), and optionally partitions sequences across CP ranks using Transformer Engine's `thd_get_partitioned_indices`.
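
The in-order packing idea can be sketched as follows. This is a minimal illustration of the dp_balanced behavior described above, assuming a simple per-pack token budget; the function name and the exact `cu_seqlens` layout are illustrative, not the actual `DpBalancedScheduler` API:

```python
import itertools

def pack_in_order(seq_lens, max_tokens):
    """Pack sequences in their original order under a per-pack token budget.

    Returns the packs plus per-pack cumulative boundaries, analogous to the
    cu_seqlens carried by PackedSeqParams in THD format.
    """
    packs, cur = [], []
    for n in seq_lens:
        # Start a new pack once adding this sequence would exceed the budget.
        if cur and sum(cur) + n > max_tokens:
            packs.append(cur)
            cur = []
        cur.append(n)
    if cur:
        packs.append(cur)
    cu_seqlens = [[0] + list(itertools.accumulate(p)) for p in packs]
    return packs, cu_seqlens
```

For example, `pack_in_order([3, 5, 2, 6, 4], 8)` yields packs `[[3, 5], [2, 6], [4]]` with `cu_seqlens` `[[0, 3, 8], [0, 2, 8], [0, 4]]`.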

### `data_schedule_utils.py`

This module contains the utility functions used by the schedulers.

## Fast DataLoader initialization

Especially for large-scale runs, DataLoader initialization can take several minutes, since it involves opening and memory-mapping multiple files and can significantly stress the filesystem. To speed up this process, we have developed the following three optimizations, controlled by configuration flags:
Expand Down
21 changes: 21 additions & 0 deletions megatron/core/extensions/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2557,3 +2557,24 @@ def set_save_original_input(module):
from transformer_engine.pytorch.float8_tensor import Float8Tensor
except ImportError:
Float8Tensor = None


def get_thd_partitioned_indices(cu_seqlens, total_tokens, cp_size, cp_rank):
"""Get partitioned indices for THD format data in context parallel.

Args:
cu_seqlens: Cumulative sequence lengths tensor.
total_tokens: Total number of tokens.
cp_size: Context parallel world size.
cp_rank: Context parallel rank.

Returns:
Partitioned indices tensor.
"""
assert is_te_min_version("1.10.0"), (
"Please update Transformer Engine to >= 1.10 to use "
"Context Parallel with THD format data"
)
import transformer_engine_torch as tex

return tex.thd_get_partitioned_indices(cu_seqlens, total_tokens, cp_size, cp_rank)
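
As a mental model for what the TE kernel returns, here is a pure-Python sketch of load-balanced THD partitioning: each sequence is split into `2 * cp_size` chunks, and rank `r` keeps chunks `r` and `2*cp_size - 1 - r`, pairing a cheap early chunk with an expensive late chunk to balance causal-attention work. This chunk-pairing scheme is our assumption about the kernel's strategy, not a drop-in replacement for `tex.thd_get_partitioned_indices`:

```python
def thd_partition_indices_sketch(cu_seqlens, cp_size, cp_rank):
    """Illustrative THD partitioning: rank r keeps chunks r and 2*cp_size-1-r
    of every packed sequence (assumes each length is divisible by 2*cp_size,
    which padding normally guarantees)."""
    indices = []
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        chunk = (end - start) // (2 * cp_size)
        for c in (cp_rank, 2 * cp_size - 1 - cp_rank):
            indices.extend(range(start + c * chunk, start + (c + 1) * chunk))
    return indices
```

With `cu_seqlens=[0, 8]` and `cp_size=2`, rank 0 keeps tokens `[0, 1, 6, 7]` and rank 1 keeps `[2, 3, 4, 5]`; together the ranks cover every token exactly once.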
8 changes: 7 additions & 1 deletion megatron/core/model_parallel_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class ModelParallelConfig:
can handle without overflowing the memory. Typically, a good starting point is to set this
to maximum sequence length / context parallel size.
This is used to calculate the number and length of sub-samples assigned to
each rank when using hybrid_context_parallel.
each rank when sequence_packing_scheduler is not None.
"""

hybrid_context_parallel: bool = False
Expand All @@ -69,6 +69,12 @@ class ModelParallelConfig:
Please set max_seqlen_per_dp_cp_rank when using hybrid_context_parallel.
"""

sequence_packing_scheduler: Optional[Literal['dp_balanced']] = None
"""
Scheduler for sequence packing and hybrid context parallel.
dp_balanced: DP-balanced scheduler for sequence packing.
"""

expert_model_parallel_size: int = 1
"""Distributes Moe Experts across sub data parallel dimension."""

Expand Down
9 changes: 9 additions & 0 deletions megatron/core/pipeline_parallel/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,9 @@ def forward_backward_no_pipelining(
pg_collection.dp_cp = parallel_state.get_data_parallel_group(
with_context_parallel=True, partial_data_parallel=False
)
pg_collection.dp = parallel_state.get_data_parallel_group(
with_context_parallel=False, partial_data_parallel=False
)

elif pg_collection is not None:
assert hasattr(pg_collection, 'tp')
Expand Down Expand Up @@ -879,6 +882,9 @@ def forward_backward_pipelining_with_interleaving(
pg_collection.dp_cp = parallel_state.get_data_parallel_group(
with_context_parallel=True, partial_data_parallel=False
)
pg_collection.dp = parallel_state.get_data_parallel_group(
with_context_parallel=False, partial_data_parallel=False
)

elif p2p_communicator is not None and pg_collection is not None:
model_type = get_model_type(model[0])
Expand Down Expand Up @@ -2028,6 +2034,9 @@ def forward_backward_pipelining_without_interleaving(
pg_collection.dp_cp = parallel_state.get_data_parallel_group(
with_context_parallel=True, partial_data_parallel=False
)
pg_collection.dp = parallel_state.get_data_parallel_group(
with_context_parallel=False, partial_data_parallel=False
)
elif p2p_communicator is not None and pg_collection is not None:
model_type = get_model_type(model)
assert model_type != ModelType.encoder_and_decoder, (
Expand Down
1 change: 1 addition & 0 deletions megatron/core/tokenizers/text/libraries/sft_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(self, tokenizer_path: str, prompt_format: str):

self._prompt_format = prompt_format


def tokenize_conversation(
self, conversation: List[Dict], return_target: bool, add_generation_prompt: bool
):
Expand Down
34 changes: 34 additions & 0 deletions megatron/core/transformer/transformer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2019,6 +2019,40 @@ def __post_init__(self):
self.attention_backend == AttnBackend.flash
), "Batch invariant mode only supports FlashAttention"

if self.sequence_packing_scheduler is not None:
# Check TE version.
if not HAVE_PACKAGING:
raise ImportError(
"packaging is not installed. Please install it with `pip install packaging`."
)
# TODO: remove this after we fix the convergence issue with TE < 2.9.
if not (
is_te_min_version("2.9.0") or get_te_version() == PkgVersion("2.9.0.dev0+5b3092a")
):
raise ValueError(
"SFT sequence packing requires Transformer Engine >= 2.9.0 "
f"but got {get_te_version()} (TE < 2.9.0 may have convergence issues)."
)

# Needed for passing variable sequences between pp stages.
self.variable_seq_lengths = True

# TODO(tailaim): add support for other dispatcher types
assert self.moe_token_dispatcher_type == "alltoall", (
f"sequence_packing only supports moe_token_dispatcher_type='alltoall', "
f"got '{self.moe_token_dispatcher_type}'"
)

supported_schedulers = ['dp_balanced']
if (
self.sequence_packing_scheduler is not None
and self.sequence_packing_scheduler not in supported_schedulers
):
raise ValueError(
f"Unsupported scheduler: {self.sequence_packing_scheduler}. "
f"Available schedulers: {supported_schedulers}"
)


@dataclass
class MLATransformerConfig(TransformerConfig):
Expand Down
45 changes: 38 additions & 7 deletions megatron/training/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,13 +823,6 @@ def validate_args(args, defaults={}):
if args.rl_use_sequence_packing:
args.consumed_train_bins = 0

# Support for variable sequence lengths across batches/microbatches.
# set it if the dataloader supports generation of variable sequence lengths
# across batches/microbatches. Due to additional communication overhead
# during pipeline parallelism, it should not be set if sequence length
# is constant during training.
args.variable_seq_lengths = False

# Iteration-based training.
if args.train_iters:
# If we use iteration-based training, make sure the
Expand Down Expand Up @@ -1000,6 +993,29 @@ def validate_args(args, defaults={}):
assert args.dataloader_type == 'single', 'Hybrid context parallelism only supported with single dataloader type'
assert args.calculate_per_token_loss, 'Hybrid context parallelism must be used with --calculate-per-token-loss'

# Support for variable sequence lengths across batches/microbatches.
# set it if the dataloader supports generation of variable sequence lengths
# across batches/microbatches. Due to additional communication overhead
# during pipeline parallelism, it should not be set if sequence length
# is constant during training.
args.variable_seq_lengths = False
if args.sequence_packing_scheduler is not None:
args.variable_seq_lengths = True
assert args.context_parallel_size * args.max_seqlen_per_dp_cp_rank >= args.seq_length, \
f'Packed sequence buffer size ({args.context_parallel_size * args.max_seqlen_per_dp_cp_rank}) ' \
f'must be >= single sequence max length ({args.seq_length})'
if args.mock_data and args.sft_mock_dataset_config_json is None:
args.sft_mock_dataset_config_json = json.dumps(
{
"mode": "distribution",
"type": "lognormal",
"min_seq_len": args.seq_length // 2,
"max_seq_len": args.seq_length,
"mean_seq_len": args.seq_length // 4 * 3,
"lognormal_sigma": 1.1,
}
)
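
The "distribution" mode above implies drawing lengths from a lognormal clipped to `[min_seq_len, max_seq_len]`. A hedged sketch of one way to do this — the `mu` mapping (targeting the requested mean) and the hard clipping are assumptions about the mock dataset, not its actual implementation:

```python
import math
import random

def sample_seq_len(min_len, max_len, mean_len, sigma=1.1):
    # Pick mu so the lognormal's mean is roughly mean_len
    # (E[X] = exp(mu + sigma^2/2)), then clip into [min_len, max_len].
    mu = math.log(mean_len) - sigma ** 2 / 2
    x = int(random.lognormvariate(mu, sigma))
    return max(min_len, min(max_len, x))
```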

# disable async_tensor_model_parallel_allreduce when
# model parallel memory optimization is enabled
if (args.tensor_model_parallel_size > 1 or args.context_parallel_size > 1) \
Expand Down Expand Up @@ -1707,6 +1723,9 @@ def _add_network_size_args(parser):
"persist_layer_norm",
"bias_dropout_fusion",
"apply_rope_fusion",
"max_seqlen_per_dp_cp_rank",
"hybrid_context_parallel",
"sequence_packing_scheduler",
]
transformer_factory = ArgumentGroupFactory(TransformerConfig, exclude=exclude)
transformer_group = transformer_factory.build_group(parser, "transformer configuration")
Expand Down Expand Up @@ -2399,6 +2418,14 @@ def _add_distributed_args(parser):
'all layers will share the same communication type. Users can also '
'specify separated types for each layer like '
'--cp-comm-type p2p p2p a2a a2a a2a+p2p a2a+p2p')
group.add_argument('--max-seqlen-per-dp-cp-rank', type=int, default=None,
help='Maximum sequence length per DP×CP rank. This is used to calculate the '
'number and length of sub-samples assigned to each rank when using hybrid context parallel.')
group.add_argument('--hybrid-context-parallel', action='store_true', default=False,
help='Enables hybrid context parallel. This is used to balance the workload '
'of each CP rank when we use packed samples with variable sequence lengths. '
'Requires --max-seqlen-per-dp-cp-rank to be set.')
group.add_argument('--sequence-packing-scheduler', type=str, default=None, choices=['dp_balanced'],
help='Scheduler for sequence packing and hybrid context parallel; '
'dp_balanced packs sequences in order while balancing microbatches across DP ranks.')
group.add_argument('--fake-process-group', action='store_true', default=False,
help='If set, initialize with fake distributed process group and all distributed communication operations will be skipped. \
This is quite useful for profiling memory usage of distributed training with just one GPU. \
Expand Down Expand Up @@ -2940,4 +2967,8 @@ def _add_sft_args(parser):
group.add_argument('--sft', action="store_true", help='Megatron SFT training')
group.add_argument('--sft-tokenizer-prompt-format', type=str, default="nemotron-h-aligned",
help='SFT prompt format.')
group.add_argument('--sft-mock-dataset-config-json', type=str, default=None,
help='This config provides the necessary information for the mock dataset. You can either specify a CSV file that contains sequence lengths, where each line stores the length of a sequence, for example: {"mode":"file","path":"/path/to/file"}. Alternatively, you can specify a distribution (currently only supporting lognormal distribution) along with the required parameters, for example, {"mode":"distribution","type":"lognormal","min_seq_len":1024,"max_seq_len":2048,"mean_seq_len":1536,"lognormal_sigma":1.1}, where sigma controls the variability of the lognormal distribution. '
'If not specified and --mock-data is set, defaults to a lognormal distribution with '
'min_seq_len=seq_length//2, max_seq_len=seq_length, mean_seq_len=seq_length*3//4, lognormal_sigma=1.1.')
return parser