Merged
557 changes: 556 additions & 1 deletion megatron/core/datasets/data_schedule.py

Large diffs are not rendered by default.

529 changes: 529 additions & 0 deletions megatron/core/datasets/data_schedule_utils.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions megatron/core/datasets/gpt_dataset.py
@@ -79,6 +79,9 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig):
context_parallel_size: Optional[int] = None
"""The size of the context parallel group. Needed for padding in packed sequences."""

sft_mock_dataset_config_json: Optional[str] = None
"""This config provides the necessary information for the mock dataset."""
Comment on lines +82 to +83
Contributor
Anti-pattern? Optional dataset config within a dataset config? Why aren't we using inheritance here?

Contributor Author @xiaoyao0115 (Feb 11, 2026)

The mock/SFT dataset config has a large number of parameters, and threading them one-by-one into the gpt_dataset would be noisy.


def __post_init__(self) -> None:
"""Do asserts and set fields post init"""
super().__post_init__()
62 changes: 62 additions & 0 deletions megatron/core/datasets/readme.md
@@ -192,6 +192,68 @@ To query the `BlendedDataset` for the _k_-th sample we do the following
To save time during initialization, each index is built/cached sequentially on one process rank and subsequently loaded in parallel on other process ranks. The cached indices are unique to a hash generated in the `BlendedDataset.__init__` function.
## Packing Scheduler
The packing scheduler re-schedules variable-length sequences across DP×CP ranks to improve GPU utilization. It is built around two modules: `data_schedule.py` (high-level logic and entry points) and `data_schedule_utils.py` (utility functions).
### Call Hierarchy
The scheduling pipeline has two phases connected by the data iterator: `wrap_data_iterator` consumes the **original** data iterator, performs global-batch scheduling, and produces a **wrapped** (packed) data iterator; `get_batch_on_this_rank_for_sequence_packing` then consumes this **wrapped** data iterator to fetch individual packed microbatches during training.
```
original wrapped (packed)
data_iterator data_iterator
│ │
▼ ▼
┌────────────────────────┐ ┌────────────────────────────────────┐
│ wrap_data_iterator() │ │ get_batch_on_this_rank_for_ │
Phase 1 │ (once per global │ ────────► │ sequence_packing() │ Phase 2
(scheduling) │ batch) │ returns │ (once per microbatch, │ (fetching)
│ │ wrapped │ called by training loop) │
└───────────┬────────────┘ iterator └──────────────┬─────────────────────┘
│ │
▼ ▼
DpBalancedScheduler.run() next(wrapped_data_iterator)
│ ├─ get_thd_partitioned_indices() [TE]
├─ get_batch_and_global_seqlens() [utils] ├─ broadcast_tensor() [utils]
├─ get_groups_and_subsamples() └─ PackedSeqParams(...)
├─ reroute_samples_to_dcp_ranks() [utils]
├─ build_packed_microbatches() [utils]
├─ broadcast_to_pp_group() [utils]
├─ broadcast_scalars() [utils]
└─ create_data_iterator() [utils]
```
### `data_schedule.py`
#### Entry Points
- **`wrap_data_iterator(original_data_iterator) → wrapped_data_iterator`** — Top-level entry point called once per global batch. Takes the **original** data iterator as input, resolves the scheduler class from `scheduler_map`, instantiates it, and delegates to `scheduler.run()` which consumes all microbatches from the original iterator, re-schedules them, and produces a **wrapped** (packed) data iterator along with the updated `num_microbatches` and FLOPs statistics.
- **`get_batch_on_this_rank_for_sequence_packing(wrapped_data_iterator)`** — Per-microbatch entry point called by the training loop. Takes the **wrapped** data iterator returned by `wrap_data_iterator` as input. Fetches one packed microbatch via `next(wrapped_data_iterator)`, broadcasts batch fields across TP ranks, optionally partitions sequences across CP ranks using Transformer Engine's `thd_get_partitioned_indices`, and constructs `PackedSeqParams` (with `cu_seqlens`, `max_seqlen`, `qkv_format=thd`).
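The central artifact in Phase 2 is the `cu_seqlens` boundary vector. A minimal sketch of how it relates the packed token buffer to its constituent sequences (the dict is a hypothetical stand-in for the real `PackedSeqParams` object; the field names shown are assumptions based on the description above):

```python
import itertools

def make_cu_seqlens(seqlens):
    """cu_seqlens marks sequence boundaries in a packed (THD) buffer:
    token i belongs to sequence j iff cu_seqlens[j] <= i < cu_seqlens[j+1]."""
    return [0] + list(itertools.accumulate(seqlens))

# A packed microbatch holding three sequences of lengths 3, 5, and 2:
cu = make_cu_seqlens([3, 5, 2])  # [0, 3, 8, 10]
packed_seq_params = {  # hypothetical stand-in for PackedSeqParams(...)
    "cu_seqlens_q": cu,
    "cu_seqlens_kv": cu,
    "max_seqlen_q": max([3, 5, 2]),
    "qkv_format": "thd",
}
```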
#### Scheduler Classes
- **`BasePackingScheduler`** — Abstract base class. Defines the interface:
- `get_groups_and_subsamples()` — pure scheduling algorithm (must be overridden).
- `run()` — full pipeline: fetch → schedule → reroute → pack → broadcast → VPP handling.
- **`DpBalancedScheduler(BasePackingScheduler)`** — Concrete scheduler that packs sequences in their original order until reaching `max_seqlen_per_dp_cp_rank × cp_size`. Aligns the number of microbatches to `dp_size` (and VPP stage multiples when applicable).
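The in-order packing rule can be sketched as below. This is a simplification: the real scheduler also splits oversize samples into sub-samples across CP ranks and aligns the number of packs to `dp_size`; `pack_in_order` and `capacity` are illustrative names, not the actual API.

```python
def pack_in_order(seqlens, capacity):
    """Greedy in-order packing: start a new pack whenever adding the next
    sequence would exceed the per-(DP x CP)-rank token budget (capacity)."""
    packs, cur, cur_tokens = [], [], 0
    for idx, n in enumerate(seqlens):
        if cur and cur_tokens + n > capacity:
            packs.append(cur)
            cur, cur_tokens = [], 0
        cur.append(idx)   # store sample indices, not tokens
        cur_tokens += n
    if cur:
        packs.append(cur)
    return packs

packs = pack_in_order([4, 3, 5, 2], capacity=8)
# packs == [[0, 1], [2, 3]]: 4+3 fits in 8, then 5+2 fits in 8
```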
### `data_schedule_utils.py`
Utility functions consumed by the schedulers above:
| Function | Role |
|---|---|
| `get_batch_and_global_seqlens()` | Fetch `num_microbatches` batches from the data iterator and all-gather sequence lengths across DP ranks. |
| `reroute_samples_to_dcp_ranks()` | All-to-all communication to transfer sub-samples to their scheduled DP×CP rank. |
| `build_packed_microbatches()` | Concatenate sub-samples within each microbatch group and produce `cu_seqlens`. |
| `broadcast_to_pp_group()` | Broadcast packed samples and metadata from the first/last PP stage to middle stages. |
| `broadcast_scalars()` | Broadcast scalar values (e.g. `num_microbatches`, FLOPs stats) across a process group. |
| `broadcast_tensor()` | Broadcast a single tensor within a process group. |
| `create_data_iterator()` | Wrap packed sample lists into a data iterator; handles VPP stage splitting. |
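For instance, `build_packed_microbatches` conceptually reduces to concatenating each group's sub-samples and recording the boundaries. A sketch with plain lists (the real function operates on tensors and carries additional batch fields such as labels and masks):

```python
def build_packed_microbatch(subsamples):
    """Concatenate one microbatch group's sub-samples (lists of token ids)
    into a single flat buffer and record cu_seqlens boundaries."""
    tokens, cu_seqlens = [], [0]
    for sample in subsamples:
        tokens.extend(sample)
        cu_seqlens.append(len(tokens))
    return tokens, cu_seqlens
```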
## Fast DataLoader initialization
Especially for large-scale runs, DataLoader initialization can take several minutes, since it involves opening and memory-mapping multiple files and can significantly stress the filesystem. To speed up this process, we have developed the following three optimizations, controlled by configuration flags:
21 changes: 21 additions & 0 deletions megatron/core/extensions/transformer_engine.py
@@ -2559,3 +2559,24 @@ def set_save_original_input(module):
from transformer_engine.pytorch.float8_tensor import Float8Tensor
except ImportError:
Float8Tensor = None


def get_thd_partitioned_indices(cu_seqlens, total_tokens, cp_size, cp_rank):
"""Get partitioned indices for THD format data in context parallel.

Args:
cu_seqlens: Cumulative sequence lengths tensor.
total_tokens: Total number of tokens.
cp_size: Context parallel world size.
cp_rank: Context parallel rank.

Returns:
Partitioned indices tensor.
"""
assert is_te_min_version("1.10.0"), (
"Please update Transformer Engine to >= 1.10 to use "
"Context Parallel with THD format data"
)
import transformer_engine_torch as tex

return tex.thd_get_partitioned_indices(cu_seqlens, total_tokens, cp_size, cp_rank)
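For reference, a pure-Python sketch of the load-balanced THD split that the TE kernel computes, under the assumption that each sequence length is divisible by `2 * cp_size`: each sequence is cut into `2 * cp_size` equal chunks, and rank `r` keeps chunks `r` and `2 * cp_size - 1 - r`, so every rank sees both early and late tokens of the causal mask. This is an unoptimized reference, not the actual kernel.

```python
def thd_partitioned_indices_reference(cu_seqlens, cp_size, cp_rank):
    """Reference version of the load-balanced THD split: for each sequence,
    keep chunks cp_rank and (2*cp_size - 1 - cp_rank) of 2*cp_size chunks."""
    indices = []
    for start, end in zip(cu_seqlens, cu_seqlens[1:]):
        chunk = (end - start) // (2 * cp_size)  # assumes divisibility
        for c in sorted((cp_rank, 2 * cp_size - 1 - cp_rank)):
            indices.extend(range(start + c * chunk, start + (c + 1) * chunk))
    return indices

# One sequence of 8 tokens, cp_size=2: rank 0 gets chunks 0 and 3,
# rank 1 gets chunks 1 and 2.
assert thd_partitioned_indices_reference([0, 8], 2, 0) == [0, 1, 6, 7]
assert thd_partitioned_indices_reference([0, 8], 2, 1) == [2, 3, 4, 5]
```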
8 changes: 7 additions & 1 deletion megatron/core/model_parallel_config.py
@@ -62,7 +62,7 @@ class ModelParallelConfig:
can handle without overflowing the memory. Typically, a good starting point is to set this
to maximum sequence length / context parallel size.
This is used to calculate the number and length of sub-samples assigned to
each rank when using hybrid_context_parallel.
each rank when sequence_packing_scheduler is not None.
"""

hybrid_context_parallel: bool = False
@@ -72,6 +72,12 @@
Please set max_seqlen_per_dp_cp_rank when using hybrid_context_parallel.
"""

sequence_packing_scheduler: Optional[Literal['dp_balanced']] = None
"""
Scheduler for sequence packing and hybrid context parallel.
dp_balanced: DP-balanced scheduler for sequence packing.
"""

expert_model_parallel_size: int = 1
"""Distributes Moe Experts across sub data parallel dimension."""

34 changes: 34 additions & 0 deletions megatron/core/transformer/transformer_config.py
@@ -2076,6 +2076,40 @@ def __post_init__(self):
self.attention_backend == AttnBackend.flash
), "Batch invariant mode only supports FlashAttention"

if self.sequence_packing_scheduler is not None:
# Check TE version.
if not HAVE_PACKAGING:
raise ImportError(
"packaging is not installed. Please install it with `pip install packaging`."
)
# TODO: remove this after we fix the convergence issue with TE < 2.9.
if not (
is_te_min_version("2.9.0") or get_te_version() == PkgVersion("2.9.0.dev0+5b3092a")
):
raise ValueError(
"SFT sequence packing requires Transformer Engine >= 2.9.0 "
f"but got {get_te_version()} (TE < 2.9.0 may have convergence issues)."
)

# Needed for passing variable sequences between pp stages.
self.variable_seq_lengths = True

# TODO(tailaim): add support for other dispatcher types
assert self.moe_token_dispatcher_type == "alltoall", (
f"sequence_packing only supports moe_token_dispatcher_type='alltoall', "
f"got '{self.moe_token_dispatcher_type}'"
)

supported_schedulers = ['dp_balanced']
if (
self.sequence_packing_scheduler is not None
and self.sequence_packing_scheduler not in supported_schedulers
):
raise ValueError(
f"Unsupported scheduler: {self.sequence_packing_scheduler}. "
f"Available schedulers: {supported_schedulers}"
)
Comment on lines +2103 to +2111
Contributor

Why do we need to keep track of the supported schedulers here? Can we make it a part of the data class?

nit: and in the error message, it's more that the scheduler isn't supported as opposed to being unknown

Contributor Author

ok, i'll make the changes.



@dataclass
@experimental_api
16 changes: 9 additions & 7 deletions megatron/training/arguments.py
@@ -884,13 +884,6 @@ def validate_args(args, defaults={}):
if args.rl_use_sequence_packing:
args.consumed_train_bins = 0

# Support for variable sequence lengths across batches/microbatches.
# set it if the dataloader supports generation of variable sequence lengths
# across batches/microbatches. Due to additional communication overhead
# during pipeline parallelism, it should not be set if sequence length
# is constant during training.
args.variable_seq_lengths = False

# Iteration-based training.
if args.train_iters:
# If we use iteration-based training, make sure the
@@ -1061,6 +1054,11 @@ def validate_args(args, defaults={}):
assert args.dataloader_type == 'single', 'Hybrid context parallelism only supported with single dataloader type'
assert args.calculate_per_token_loss, 'Hybrid context parallelism must be used with --calculate-per-token-loss'

if args.sequence_packing_scheduler is not None:
assert args.context_parallel_size * args.max_seqlen_per_dp_cp_rank >= args.seq_length, \
f'Packed sequence buffer size ({args.context_parallel_size * args.max_seqlen_per_dp_cp_rank}) ' \
f'must be >= single sequence max length ({args.seq_length})'

# disable async_tensor_model_parallel_allreduce when
# model parallel memory optimization is enabled
if (args.tensor_model_parallel_size > 1 or args.context_parallel_size > 1) \
@@ -3061,4 +3059,8 @@ def _add_sft_args(parser):
group.add_argument('--sft', action="store_true", help='Megatron SFT training')
group.add_argument('--sft-tokenizer-prompt-format', type=str, default="nemotron-h-aligned",
help='SFT prompt format.')
group.add_argument('--sft-mock-dataset-config-json', type=str, default=None,
help='This config provides the necessary information for the mock dataset. You can either specify a CSV file that contains sequence lengths, where each line stores the length of a sequence, for example: {"mode":"file","path":"/path/to/file"}. Alternatively, you can specify a distribution (currently only supporting lognormal distribution) along with the required parameters, for example, {"mode":"distribution","type":"lognormal","min_seq_len":1024,"max_seq_len":2048,"mean_seq_len":1536,"lognormal_sigma":1.1}, where sigma controls the variability of the lognormal distribution. '
'If not specified and --mock-data is set, defaults to a lognormal distribution with '
'min_seq_len=seq_length//2, max_seq_len=seq_length, mean_seq_len=seq_length*3//4, lognormal_sigma=1.1.')
return parser
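As a sketch, one plausible reading of the distribution-mode config above (the parameterization and clipping details are assumptions; the actual implementation may differ): a lognormal distribution has mean `exp(mu + sigma^2/2)`, so choose `mu = ln(mean_seq_len) - sigma^2/2` and clip draws to `[min_seq_len, max_seq_len]`.

```python
import math
import random

def sample_mock_seqlens(n, min_len, max_len, mean_len, sigma, seed=0):
    """Draw n mock sequence lengths from a lognormal whose (pre-clipping)
    mean is mean_len, clipped to [min_len, max_len]. Illustrative only."""
    rng = random.Random(seed)
    mu = math.log(mean_len) - sigma * sigma / 2
    return [
        max(min_len, min(max_len, int(rng.lognormvariate(mu, sigma))))
        for _ in range(n)
    ]

lens = sample_mock_seqlens(100, min_len=1024, max_len=2048,
                           mean_len=1536, sigma=1.1)
```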