
[Dev] Add E2E support for THD format#2924

Merged
yaox12 merged 9 commits into NVIDIA:dev from xiaoyao0115:thd_e2e
Mar 3, 2026

Conversation

@xiaoyao0115
Contributor

@xiaoyao0115 xiaoyao0115 commented Jan 13, 2026

Description

This PR adds Sequence Packing (THD format) E2E support to MCore. Main branch PR: #3386

The core THD functionalities missing from MCore are:

  • The data iterator cannot handle THD metadata such as cu_seqlens and max_seqlen.
  • num_microbatches is fixed.
  • PackParams are not passed between PP ranks.

Key Changes

1. Add a data_iterator wrapper (megatron/core/datasets/data_schedule.py::wrap_dataloader)

A wrapper function that intercepts the data iterator to perform scheduling and packing:

  • Schedule & Pack: Extracts data from the data iterator, schedules sequences across DP×CP ranks, and packs them into microbatches with cu_seqlens metadata.
  • Returns packing results: Returns the packed num_microbatches along with two parameters for FLOPs calculation: num_total_tokens_this_global_batch and sequence_square_sum_this_global_batch.
  • TP broadcast: Broadcasts num_microbatches and FLOPs parameters across TP ranks since only TP rank 0 has access to the data iterator.
  • PP broadcast: When using PP, middle PP stages (not first or last) require metadata (cu_seqlens, cu_seqlens_padded, max_seqlen, etc.) to be broadcast from PP rank 0 for correct computation.
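The schedule-and-pack step above can be sketched with a toy greedy packer (this is illustrative only, not the PR's actual scheduler; the function name, the first-fit policy, and the token capacity are assumptions). It shows how per-microbatch cu_seqlens and the two FLOPs statistics fall out of the sequence lengths:

```python
def pack_microbatches(seqlens, capacity):
    """Toy first-fit packing of variable-length sequences into microbatches.

    Returns per-microbatch cu_seqlens plus the two global-batch statistics
    used for FLOPs accounting (total tokens, and the sum of squared lengths
    that drives the attention cost).
    """
    microbatches = []  # each entry is a list of sequence lengths
    for s in seqlens:
        for mb in microbatches:
            if sum(mb) + s <= capacity:  # first microbatch with room wins
                mb.append(s)
                break
        else:
            microbatches.append([s])     # no room anywhere: open a new one
    cu_seqlens_per_mb = []
    for mb in microbatches:
        cu = [0]                         # cu_seqlens always starts at 0
        for s in mb:
            cu.append(cu[-1] + s)
        cu_seqlens_per_mb.append(cu)
    num_total_tokens = sum(seqlens)
    sequence_square_sum = sum(s * s for s in seqlens)
    return cu_seqlens_per_mb, len(microbatches), num_total_tokens, sequence_square_sum

cu, num_mbs, total_tokens, seq_sq_sum = pack_microbatches(
    [1024, 3072, 2048, 512], capacity=4096
)
# two microbatches: [1024, 3072] and [2048, 512]
```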

2. Mock SFT Dataset Support

Supports mock datasets for testing and benchmarking with configurable sequence-length distributions.
There are two modes for the mock SFT dataset:

  • File mode: Load sequence lengths from an external CSV file. Example JSON:
    {"mode": "file", "path": "/path/to/seqlens.csv"}
  • Distribution mode: Generate sequence lengths from a distribution (currently lognormal only). Example JSON:
    {"mode": "distribution", "type": "lognormal", "min_seq_len": 1024, "max_seq_len": 8192, "mean_seq_len": 4096, "lognormal_sigma": 1.1}
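For intuition, distribution mode can be approximated as below. This is a sketch, not the PR's sampler: the `sample_seqlens` helper is hypothetical, and choosing mu so the lognormal mean matches `mean_seq_len` is an assumption about how the config parameters relate.

```python
import json

import numpy as np

config = json.loads(
    '{"mode": "distribution", "type": "lognormal", "min_seq_len": 1024, '
    '"max_seq_len": 8192, "mean_seq_len": 4096, "lognormal_sigma": 1.1}'
)

def sample_seqlens(cfg, n, seed=0):
    """Draw n sequence lengths from a lognormal, clipped to the configured range."""
    rng = np.random.default_rng(seed)
    sigma = cfg["lognormal_sigma"]
    # Pick mu so that the lognormal mean exp(mu + sigma^2 / 2) equals mean_seq_len.
    mu = np.log(cfg["mean_seq_len"]) - sigma**2 / 2
    lens = rng.lognormal(mean=mu, sigma=sigma, size=n)
    return np.clip(lens, cfg["min_seq_len"], cfg["max_seq_len"]).astype(int)

lens = sample_seqlens(config, 8)
```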

Architecture

Before vs After

```mermaid
graph LR
    subgraph Before
        A1[DataIterator] --> B1[get_batch]
        B1 --> C1[forward_backward]
        C1 --> D1[Fixed seq_len FLOPs]
    end
    subgraph After
        A2[DataIterator] --> W[wrap_dataloader]
        W -->|schedule + pack| B2[PackedDataIterator]
        W -->|broadcast| M[num_microbatches + flops_params]
        B2 --> C2[get_batch_for_sequence_packing]
        C2 --> D2[forward_backward]
        D2 --> E2[Dynamic FLOPs]
        M   --> E2
    end
```

Execution Flow

```mermaid
sequenceDiagram
    participant Train as training.py
    participant Schedule as schedules.py
    participant Wrap as wrap_iterator_helper
    participant DataSched as data_schedule.py
    participant GetBatch as get_batch_for_seq_packing

    Train->>Schedule: forward_backward_*(data_iterator)
    Schedule->>Wrap: wrap_iterator_helper(config, data_iterator)
    Wrap->>DataSched: wrap_dataloader(data_iterator, scheduler_type)

    Note over DataSched: 1. Gather global seqlens across DP
    Note over DataSched: 2. Scheduler assigns sequences to microbatches
    Note over DataSched: 3. All-to-all redistribute samples
    Note over DataSched: 4. Pack into microbatches
    Note over DataSched: 5. Broadcast to TP/PP ranks

    DataSched-->>Schedule: (packed_iter, num_mbs, total_tokens, seq_sq_sum)

    loop for each microbatch
        Schedule->>GetBatch: get_batch_on_this_rank_for_sequence_packing
        Note over GetBatch: Broadcast tokens/labels to TP group
        Note over GetBatch: Partition for CP if needed
        GetBatch-->>Schedule: (tokens, labels, loss_mask, pos_ids, packed_seq_params)
    end

    Schedule-->>Train: forward_data_store + [total_tokens, seq_sq_sum]
```

New Arguments

| Argument | Type | Description |
|---|---|---|
| `--sequence-packing` | flag | Enable sequence packing (THD format) for training |
| `--sequence-packing-scheduler` | str | Scheduler type: `default` or empty |
| `--sft-mock-dataset-config-json` | str | JSON config for the mock dataset |
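For illustration, the new flags might be combined as follows (the launcher, script name, and elided arguments are placeholders, not taken from this PR):

```
torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
    ... \
    --sequence-packing \
    --sequence-packing-scheduler default \
    --sft-mock-dataset-config-json '{"mode": "distribution", "type": "lognormal", "min_seq_len": 1024, "max_seq_len": 8192, "mean_seq_len": 4096, "lognormal_sigma": 1.1}'
```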

Changes

| File | Description |
|---|---|
| megatron/core/datasets/data_schedule.py | Core scheduling and packing logic |
| megatron/core/pipeline_parallel/schedules.py | Integration with forward/backward schedules |
| megatron/training/training.py | Updated FLOPs calculation for variable-length sequences |
| megatron/training/datasets/sft_dataset.py | Mock dataset support |
| megatron/training/arguments.py | New CLI arguments |
| megatron/core/model_parallel_config.py | Configuration options |
| tests/unit_tests/test_sequence_packing.py | Unit tests |

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers' reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch The proposed review process for `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

@xiaoyao0115 xiaoyao0115 requested review from a team as code owners January 13, 2026 13:37

copy-pr-bot bot commented Jan 13, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@yanring yanring marked this pull request as draft January 14, 2026 00:44
@yanring yanring requested a review from ISEEKYAN January 15, 2026 06:29
@yanring yanring changed the title Add support for thd e2e Add E2E support for THD format Jan 21, 2026
@kunlunl kunlunl requested a review from lhb8125 January 27, 2026 06:42
@xiaoyao0115 xiaoyao0115 force-pushed the thd_e2e branch 2 times, most recently from 084682b to e10e050 Compare January 30, 2026 03:11
@xiaoyao0115
Contributor Author

/ok to test 92fefca


```python
# data_iterator is not None when TP rank 0, with PP stage 0 or -1.
if data_iterator is not None:
    assert tp_group.rank() == 0 and (
```
Contributor

@xiaoyao0115 One concern I have come to realize is the assumption that only TP rank 0 should have the data iterator and it always broadcasts information to other TP ranks.
This is true in Megatron-LM but does not hold true for Nemo-RL for example. In Nemo-RL, the head node distributes data to all the workers and each rank has the data_iterator.

It might be difficult to identify which regime we are in but maybe we need to support a mode where we don't perform broadcasts between TP ranks etc. This applies to other places in the code where we make this assumption as well.

Contributor Author

Yeah, different frameworks make different assumptions about data_iterator. Megatron-LM typically assumes only TP rank 0 owns the iterator and broadcasts to the other TP ranks, while frameworks like NeMo-RL and verl behave differently.
For this change, I think it's best to keep Megatron-LM's behavior rather than trying to cover every external data-loading regime. Supporting "no TP broadcast" is very framework-specific and is better implemented in the integration layer (e.g., NeMo-RL / verl) by adapting their dataloader to produce the same final inputs that Megatron expects.

Contributor

@parthmannan parthmannan commented Feb 24, 2026

I understand that the default codepath here should be what Megatron-LM expects, but I am wondering if we can reduce the friction while maintaining that expectation. We can keep the Scheduler and the MCore training logic separate. I am just brainstorming here to ease integration into other frameworks:

For example, for the DpBalancedScheduler, this change has

```python
if data_iterator is not None:
    assert tp_group.rank() == 0 and (
        pp_group.rank() == 0 or pp_group.rank() == pp_group.size() - 1
    ), f"Only TP rank 0 and PP stage 0 or -1 should have data_iterator"
```

This assert would mean NeMo-RL etc. have to write an entirely new DpBalancedScheduler or add complicated logic to schedule only on TP rank 0 and then exchange information. This exchange would be unnecessary since, in NeMo-RL, each rank already has the data_iterator.
Can we not separate these assumptions and keep the scheduling logic separate from the data processing logic?
As an idea, we could separate these and inside wrap_data_iterator, call these separately

```python
scheduler.process_data         # run the checks needed about data_iterator and TP/PP ranks
if tp_rank == 0:
    scheduler.run              # takes data_iterator and schedules
scheduler.broadcast_pp_tp      # perform broadcasts
scheduler.create_new_iterator  # create new iterator
```

NeMo-RL could implement its own wrap_data_iterator where

```python
scheduler.run                  # takes data_iterator and schedules
scheduler.create_new_iterator  # create new iterator
```

It prevents creating a duplicate scheduler in other frameworks and would help prevent code drifting apart as we update/improve logic.
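The proposed split could look roughly like the sketch below. Everything here is hypothetical: the method names come from the pseudocode above, and the toy run() just does longest-first round-robin over two microbatches so the class is self-contained and testable.

```python
class DpBalancedScheduler:
    """Toy stand-in for the proposed decomposed scheduler API (illustrative only)."""

    def __init__(self, data_iterator):
        self.data_iterator = data_iterator
        self.schedule = None

    def run(self):
        """Consume the iterator and assign sample indices to microbatches."""
        seqlens = [len(sample) for sample in self.data_iterator]
        # Toy balancing: sort longest-first, then deal round-robin into 2 microbatches.
        order = sorted(range(len(seqlens)), key=lambda i: -seqlens[i])
        self.schedule = [order[0::2], order[1::2]]
        return self.schedule

    def broadcast_pp_tp(self):
        """Megatron-LM path only: share metadata with TP/PP ranks (no-op in this toy)."""
        pass

    def create_new_iterator(self):
        return iter(self.schedule)

# A Megatron-LM-style wrapper would call run() on TP rank 0 and then broadcast_pp_tp();
# a NeMo-RL-style wrapper could call run() + create_new_iterator() on every rank.
sched = DpBalancedScheduler([[0] * 4, [0] * 2, [0] * 3, [0]])
mbs = sched.run()
```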



```python
def get_batch_on_this_rank_for_sequence_packing(
```
Contributor

Similar thing here. We should keep data sharding (i.e. get_thd_partitioned_indices etc.) separate from broadcasting logic instead of combining into 1 large function.

For example, in pretrain_gpt.py, we call get_batch_on_this_tp_rank if needed and then we can use get_thd_batch_on_this_cp_rank for THD format or get_batch_on_this_cp_rank for SBHD format.

Calling it out here in case you have plans to remove those independent functions in favor of this single utility.

Contributor Author

Folding THD into get_batch_on_this_tp_rank would add a lot of format-specific branching since THD carries extra keys like max_seqlen, cu_seqlens, etc.
Also consistent with my earlier point: other frameworks (NeMo-RL / verl) should implement their own data processing to match Megatron’s expected inputs rather than having Megatron cover every data-iterator regime.
For CP sharding, get_batch_on_this_cp_rank in data_schedule.py (around L749–L763) already makes the sharding logic explicit. In the current flow I also shard first then broadcast on TP, which reduces broadcast volume and makes the broadcast logic clearer.
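For readers unfamiliar with THD CP sharding, a common load-balanced scheme (the Transformer Engine style) splits each packed sequence into 2·cp_size chunks and gives rank r the pair of chunks r and 2·cp_size−1−r, so every rank sees a similar causal-attention cost. The sketch below illustrates that general technique; the helper name and divisibility assumption are mine, not the PR's get_thd_partitioned_indices.

```python
def thd_partition_indices(cu_seqlens, cp_size, cp_rank):
    """Token indices kept by one CP rank under 2*cp_size chunking (illustrative).

    Assumes each sequence length is divisible by 2 * cp_size (e.g. after padding).
    """
    indices = []
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        chunk = (end - start) // (2 * cp_size)
        # Rank r keeps an "early" and a "late" chunk to balance causal-attention work.
        for c in (cp_rank, 2 * cp_size - 1 - cp_rank):
            indices.extend(range(start + c * chunk, start + (c + 1) * chunk))
    return sorted(indices)

# One packed sequence of 8 tokens, cp_size=2: rank 0 keeps chunks 0 and 3.
idx = thd_partition_indices([0, 8], cp_size=2, cp_rank=0)
```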

Contributor

Oh I didn't mean to suggest that we should fold THD into get_batch_on_this_tp_rank but the opposite. We are folding CP sharding and TP broadcast into 1 function here but I was wondering if we could keep them separate so that other frameworks can use the functions they need instead of re-implementing everything.
Today, we do

```python
batch = get_batch_on_this_tp_rank
if thd:
    batch = get_thd_batch_on_this_cp_rank
else:
    batch = get_batch_on_this_cp_rank
```

This allows us to use them separately based on the need.

@xiaoyao0115
Contributor Author

/ok to test 0614bc4

@xiaoyao0115
Contributor Author

/ok to test 8180b57

Contributor

@lhb8125 lhb8125 left a comment

Overall LGTM except for the docs.

@xiaoyao0115
Contributor Author

/ok to test 0944609

@xiaoyao0115
Contributor Author

/ok to test 16157f2

```python
    batch['position_ids'] = batch['position_ids'].view(1, total_tokens)
else:
    batch['tokens'] = torch.empty([1, total_tokens], dtype=torch.int64, device=dev)
    batch['position_ids'] = torch.empty([1, total_tokens], dtype=torch.int64, device=dev)
```
Contributor

total_tokens undefined when mtp_on_this_rank=True on a middle PP stage
In get_batch_on_this_rank_for_sequence_packing, when mtp_on_this_rank is True but is_first_or_last_stage is False (possible with custom pipeline layouts that place MTP layers on middle stages), the variable total_tokens is never defined.

Contributor Author

Good catch, I will fix this.

```python
max_seqlens = info_to_broadcast[3 : 3 + num_micro_batches]
cu_seqlens_list = []
cu_seqlens_padded_list = []
indices = np.where(info_numpy == 0)[0]
```
Contributor

broadcast_to_pp_group uses np.where(info_numpy == 0) to find cu_seqlens boundaries, which is fragile and incorrect if any metadata value is zero
The receiver side uses np.where(info_numpy == 0) to locate cu_seqlens boundaries (since cu_seqlens always start with 0). Will this silently produce incorrect results if any other value in the tensor happens to be 0?

Contributor Author

@xiaoyao0115 xiaoyao0115 commented Feb 26, 2026

Yeah, this is pretty hacky. However, in practice the other metadata values should never be 0: num_micro_batches, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch and max_seqlen are always positive, and cu_seqlens strictly increases after the initial 0. So using np.where(info == 0) to locate cu_seqlens boundaries works correctly.
I can either add a comment clarifying the assumption, or add an explicit broadcast of the cu_seqlens lengths to make it more robust. Which would you prefer?
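A minimal, self-contained repro of that convention (the field layout follows the snippet above, but the concrete values here are made up for the demo):

```python
import numpy as np

# Flattened metadata: [num_micro_batches, token_sum, token_sq_sum,
#                      max_seqlens..., then one cu_seqlens list per microbatch].
# Every header field is strictly positive, so 0 only appears where a
# cu_seqlens list begins.
info_numpy = np.array([2, 6656, 14942208, 3072, 2048,   # header + max_seqlens
                       0, 1024, 4096,                    # cu_seqlens of microbatch 0
                       0, 2048, 2560])                   # cu_seqlens of microbatch 1

num_micro_batches = int(info_numpy[0])
max_seqlens = info_numpy[3 : 3 + num_micro_batches]
starts = np.where(info_numpy == 0)[0]            # boundary = each leading 0
bounds = list(starts) + [len(info_numpy)]
cu_seqlens_list = [info_numpy[a:b] for a, b in zip(bounds[:-1], bounds[1:])]
```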

Contributor

Let's add some comments.

@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/22435655277

xiaoyao0115 and others added 8 commits March 2, 2026 01:38
Signed-off-by: xiaoyao0115 <1804647152@qq.com>
Signed-off-by: tailaim <tailaim@nvidia.com>
Signed-off-by: xiaoyao0115 <1804647152@qq.com>
Signed-off-by: tailaim <tailaim@nvidia.com>
Signed-off-by: tailaim <tailaim@nvidia.com>
Signed-off-by: tailaim <tailaim@nvidia.com>
@kunlunl
Contributor

kunlunl commented Mar 2, 2026

/ok to test 816fca1

@kunlunl
Contributor

kunlunl commented Mar 2, 2026

/ok to test 2aae2d0

@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/22604966008


Labels

Expert Review Apply this label to indicate that your PR is ready for expert review.


10 participants