
add --overlap-param-gather support for layer-wise optimizer. lots of unit tests. #3524

Queued
mchrzanowski wants to merge 34 commits into NVIDIA:main from mchrzanowski:overlap-param-gather-muon-rebased

Conversation


@mchrzanowski mchrzanowski commented Feb 20, 2026

What does this PR do ?

  1. Adds --overlap-param-gather functionality to muon.
    Builds on top of @deyuf's PR

  2. Lots and lots of unit tests for synchronization and for the layer-wise optimizer.

Passes functional tests here: https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/merge_requests/4555

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

mchrzanowski and others added 5 commits February 17, 2026 20:30
Integrate async param all-gather from upstream PR NVIDIA#2787 so that
dist_muon/dist_mop can overlap parameter all-gather with forward
compute via DDP's existing bucket and forward-pre-hook infrastructure.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
torch.distributed.all_gather with tensor-list output internally creates a
temporary contiguous buffer and copies chunks back to the individual output
tensors when wait() is called. When wrapped in _coalescing_manager, the
coalescing manager's handle only waits on the grouped NCCL operations but
does not trigger the per-op copy-back, leaving output tensors uninitialized
with garbage values that cause NaN at iteration 2.

Fix by calling all_gather directly per-bucket (without _coalescing_manager)
and storing individual work handles in a new _LayerWiseAllGatherHandle class
that properly triggers copy-back on wait(). Also pins the flattened source
tensor to prevent premature GC during async operations.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
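The copy-back-on-wait idea above can be sketched in pure Python. This is a stand-in, not the actual Megatron-LM code: `FakeWork` models a `torch.distributed` work handle, and plain lists stand in for GPU tensors; only the name `_LayerWiseAllGatherHandle` comes from the commit message.

```python
class FakeWork:
    """Stand-in for a torch.distributed work handle."""
    def __init__(self):
        self.completed = False

    def wait(self):
        self.completed = True


class LayerWiseAllGatherHandle:
    """Holds one (work, flat_buffer, output_chunks) triple per bucket and
    performs the per-op copy-back that a coalescing-manager handle skips."""
    def __init__(self):
        self.ops = []  # list of (work, flat_buffer, output_chunks)

    def add(self, work, flat_buffer, output_chunks):
        self.ops.append((work, flat_buffer, output_chunks))

    def wait(self):
        for work, flat_buffer, output_chunks in self.ops:
            work.wait()  # ensure the collective has finished
            # copy-back: slice the contiguous buffer into the output tensors
            offset = 0
            for chunk in output_chunks:
                n = len(chunk)
                chunk[:] = flat_buffer[offset:offset + n]
                offset += n
        self.ops.clear()  # drop references so the buffers can be freed


# Usage: "gather" a flat buffer back into per-rank output chunks.
handle = LayerWiseAllGatherHandle()
outputs = [[0.0] * 2, [0.0] * 2]  # would otherwise stay uninitialized
flat = [1.0, 2.0, 3.0, 4.0]       # buffer filled by the collective
handle.add(FakeWork(), flat, outputs)
handle.wait()  # outputs is now [[1.0, 2.0], [3.0, 4.0]]
```

Without the explicit copy-back in `wait()`, the output chunks would keep their initial garbage values, which matches the NaN-at-iteration-2 symptom described above.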
Add 8 new test cases exercising the overlap-param-gather path
(use_layer_wise_optimizer=True + overlap_param_gather=True + async_allgather):
- test_overlap_param_gather_basic: end-to-end with bucket-based param sync
- test_overlap_param_gather_parameter_updates: vs standard optimizer
- test_overlap_param_gather_vs_sync_allgather: async vs sync produce identical results
- test_overlap_param_gather_bucket_lw_params: bucket.lw_params_list populated correctly
- test_overlap_param_gather_vs_standard_ddp: padded vs unpadded DDP produce same results
- test_overlap_param_gather_insufficient_parameters: TinyModel with overlap path
- test_overlap_param_gather_broadcast_vs_allgather: broadcast vs allgather equivalence
- test_overlap_param_gather_multi_iteration: 3-iteration convergence test

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace with overlap_param_gather and !use_distributed_optimizer checks,
which already capture the same semantics without a dedicated flag.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…_optimizer

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@mchrzanowski mchrzanowski requested review from a team as code owners February 20, 2026 20:10
@copy-pr-bot

copy-pr-bot bot commented Feb 20, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@svcnvidia-nemo-ci svcnvidia-nemo-ci requested a review from a team February 20, 2026 20:10
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

@mkhona-nvidia mkhona-nvidia left a comment


lgtm

mchrzanowski and others added 9 commits February 20, 2026 22:38
… OOM

The overlap_param_gather feature allocates large temporary GPU buffers
during the forward pass for async all-gather operations. Although freed
after finish_param_sync, PyTorch's CUDA allocator caches the memory,
which is invisible to the async checkpoint worker's CUDA context. This
caused OOM during D2H tensor transfers, preventing checkpoints from
finalizing and trapping the run in a restart loop.

Add free_overlap_buffers() to _ParamAndGradBucketGroup and DDP that
explicitly releases these buffers, and call it + torch.cuda.empty_cache()
in save_checkpoint_and_time() before the save_checkpoint() call.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Tests that free_overlap_buffers():
- Clears lw_gather_tensor_list and nulls _lw_src_buffer on each bucket
- Waits on any pending param_gather_handle before freeing
- Is safe to call when no buffers are allocated (noop)
- DDP.free_overlap_buffers delegates to all bucket groups

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
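The behavior those tests exercise can be sketched in plain Python. Everything here is a simplified stand-in: `DummyHandle`, the string "buffers", and the exact attribute names are illustrative, not the real bucket-group state.

```python
class DummyHandle:
    """Stand-in for a pending param-gather work handle."""
    def __init__(self):
        self.waited = False

    def wait(self):
        self.waited = True


class BucketGroup:
    def __init__(self):
        self.layerwise_gather_tensor_list = ["buf0", "buf1"]
        self._layerwise_src_buffer = "flat_src"
        self.param_gather_handle = DummyHandle()

    def free_overlap_buffers(self):
        # Wait on any in-flight all-gather before dropping its buffers.
        if self.param_gather_handle is not None:
            self.param_gather_handle.wait()
            self.param_gather_handle = None
        # Drop references so the CUDA caching allocator can release memory
        # (the real code pairs this with torch.cuda.empty_cache()).
        self.layerwise_gather_tensor_list = []
        self._layerwise_src_buffer = None


class DDP:
    def __init__(self, bucket_groups):
        self.bucket_groups = bucket_groups

    def free_overlap_buffers(self):
        # Delegate to every bucket group.
        for group in self.bucket_groups:
            group.free_overlap_buffers()


ddp = DDP([BucketGroup(), BucketGroup()])
ddp.free_overlap_buffers()
ddp.free_overlap_buffers()  # safe no-op when nothing is allocated
```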
Each rank may own different numbers of params per bucket, creating variable
flat sizes. NCCL's all_gather requires uniform sizes. Replace with dp_size
broadcast calls per bucket: each rank broadcasts its actual-size flattened
params to all others. This eliminates padding and uses only collectives
(no P2P, which can deadlock with subsequent collectives on the same NCCL
communicator).

Memory cost: sum(lw_param_flat_sizes) per bucket — no padding waste.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
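The memory argument can be made concrete with a small back-of-the-envelope sketch (the sizes are made up for illustration): per-rank broadcasts receive each rank's actual flat size, whereas a uniform all_gather would force every rank to pad up to the largest flat size.

```python
def broadcast_recv_elems(flat_sizes):
    """Total elements received when each rank broadcasts its own size."""
    return sum(flat_sizes)


def padded_allgather_elems(flat_sizes):
    """Total elements when all ranks pad to the max size for all_gather."""
    return len(flat_sizes) * max(flat_sizes)


# One rank owns most of the params in this bucket (illustrative sizes).
flat_sizes = [1000, 10, 10, 10]
assert broadcast_recv_elems(flat_sizes) == 1030   # sum of actual sizes
assert padded_allgather_elems(flat_sizes) == 4000  # dp_size * max size
```

The gap grows with skew: the more unevenly params are distributed across ranks in a bucket, the more padding a uniform-size all_gather wastes.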
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…er with muon optimizer

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… deadlock

When a rank has no local params in a bucket, the empty src tensor was
created with bucket.grad_data.dtype (fp32 when grad_reduce_in_fp32=True)
instead of the actual param dtype (bf16). This caused receive buffers on
ranks without local params to be fp32 while the broadcasting rank sends
bf16 data, resulting in different buffer byte sizes across ranks in the
same broadcast collective — an NCCL deadlock.

The fix uses bucket.params_list[0].dtype to always get the correct param
dtype, matching the dtype produced by _flatten_dense_tensors on ranks
that do have local params.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
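The byte-size mismatch behind the deadlock can be illustrated without torch, using only dtype widths (names and counts below are illustrative): a broadcast participant's receive buffer must match the sender's byte count, and dtype drives it.

```python
DTYPE_BYTES = {"bf16": 2, "fp32": 4}


def buffer_bytes(numel, dtype):
    """Byte size of a flat buffer with `numel` elements of `dtype`."""
    return numel * DTYPE_BYTES[dtype]


numel = 4096
sender = buffer_bytes(numel, "bf16")      # rank flattening its bf16 params
wrong_recv = buffer_bytes(numel, "fp32")  # empty src built from grad dtype
right_recv = buffer_bytes(numel, "bf16")  # empty src built from param dtype

assert sender != wrong_recv  # mismatched byte sizes in one collective: hang
assert sender == right_recv  # matching sizes: broadcast completes
```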
…ing only on last broadcast handle

All per-rank broadcasts within a single start_param_sync() call are issued on
the same NCCL communicator stream, so NCCL guarantees in-order completion.
Waiting sequentially on each of the ~64 individual handles caused intermediate
CUDA stream synchronizations that created a timing-dependent deadlock across
ranks. Waiting on only the last handle is sufficient.

Also removes debug logging added during investigation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
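The in-order completion argument can be modeled in a few lines of plain Python (a sketch, not NCCL itself): if ops issued on one communicator stream complete in issue order, then waiting on the last handle implies all earlier ops have completed.

```python
class OrderedWork:
    """Models an op on a single NCCL communicator stream: completing
    op k marks ops 0..k as done (in-order completion guarantee)."""
    def __init__(self, ledger, idx):
        self.ledger = ledger
        self.idx = idx

    def wait(self):
        for i in range(self.idx + 1):
            self.ledger[i] = True


ledger = [False] * 64
handles = [OrderedWork(ledger, i) for i in range(64)]

# Wait only on the last broadcast handle...
handles[-1].wait()
# ...yet every earlier op is complete, with no per-op stream syncs.
assert all(ledger)
```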
Comment on lines -320 to 445

-    assert self.ddp_config.use_distributed_optimizer
+    # overlap_param_gather covers the layer-wise optimizer case, which sets
+    # overlap_param_gather=True without use_distributed_optimizer.
+    assert self.ddp_config.use_distributed_optimizer or self.ddp_config.overlap_param_gather
     assert self.ddp_config.overlap_param_gather

@janEbert janEbert Feb 23, 2026


The check for self.ddp_config.use_distributed_optimizer is redundant here due to the existing assertion for self.ddp_config.overlap_param_gather below it. So the line added by this PR could be removed. (Or should the existing line checking just for self.ddp_config.overlap_param_gather be removed?)

@mchrzanowski (Author)
The reason will be displayed to describe this comment to others. Learn more.

muon doesn't use the (AdamW) distributed optimizer, so you still have to check for either here

@janEbert (Contributor)

But you're doing assert self.overlap_param_gather right below the assert self.ddp_config.use_distributed_optimizer or self.ddp_config.overlap_param_gather.
So even if the first assertion passes with self.ddp_config.use_distributed_optimizer = True and self.ddp_config.overlap_param_gather = False, the second assertion fails anyway.

@mchrzanowski (Author)

Oh, I'm sorry, I misunderstood, thank you. Fixed.

@janEbert janEbert added the Expert Review Apply this label to indicate that your PR is ready for expert review. label Feb 23, 2026
@mchrzanowski mchrzanowski changed the title add --overlap-param-gather support for layer-wise optimizer add --overlap-param-gather support for layer-wise optimizer. improve memory usage with per-rank broadcasts. Feb 23, 2026
@mchrzanowski mchrzanowski requested a review from deepakn94 March 2, 2026 17:58

@deepakn94 deepakn94 left a comment


@mchrzanowski is the gradient reduction still an all-reduce?

- Rename all `lw` prefixed names to `layerwise` for clarity
- Improve set_layerwise_params_list docstring to clarify that each
  inner list contains only the params owned by that rank's layer-wise
  optimizer that also belong to this bucket
- Clear self.handles at the end of _LayerwiseAllGatherHandle.wait()
- Rename local variable `src` to `flat_local_params`

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@mchrzanowski mchrzanowski force-pushed the overlap-param-gather-muon-rebased branch from 4240979 to f30914e Compare March 3, 2026 04:45
The production code was renamed from lw_ to layerwise_ prefix but the
tests were not updated, causing attribute mismatches and test failures.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@deepakn94

Throughputs look good on a contrived case with high DP communication overhead (8B Transformer with TP=8).

### Logfile: 8b_adam_nooverlap_101106_date_26-03-04_time_13-16-15.log
---------------------------------------------------------------------------
| Metric               | Throughput (TFLOP/s)   | Time per Iter (ms)     |
---------------------------------------------------------------------------
| Iterations Recorded  | 57                     | 57                     |
| Min                  | 183.60                 | 282.00                 |
| Max                  | 187.20                 | 287.40                 |
| Mean                 | 185.64                 | 284.30                 |
| Median               | 185.70                 | 284.20                 |
---------------------------------------------------------------------------


### Logfile: 8b_adam_overlap_101105_date_26-03-04_time_13-16-04.log
---------------------------------------------------------------------------
| Metric               | Throughput (TFLOP/s)   | Time per Iter (ms)     |
---------------------------------------------------------------------------
| Iterations Recorded  | 63                     | 63                     |
| Min                  | 199.60                 | 255.80                 |
| Max                  | 206.30                 | 264.40                 |
| Mean                 | 203.87                 | 258.89                 |
| Median               | 204.10                 | 258.60                 |
---------------------------------------------------------------------------


### Logfile: 8b_muon_nooverlap_101124_date_26-03-04_time_13-23-41.log
---------------------------------------------------------------------------
| Metric               | Throughput (TFLOP/s)   | Time per Iter (ms)     |
---------------------------------------------------------------------------
| Iterations Recorded  | 44                     | 44                     |
| Min                  | 142.70                 | 365.10                 |
| Max                  | 144.50                 | 369.80                 |
| Mean                 | 143.83                 | 366.95                 |
| Median               | 143.90                 | 366.75                 |
---------------------------------------------------------------------------


### Logfile: 8b_muon_overlap_101123_date_26-03-04_time_13-23-35.log
---------------------------------------------------------------------------
| Metric               | Throughput (TFLOP/s)   | Time per Iter (ms)     |
---------------------------------------------------------------------------
| Iterations Recorded  | 49                     | 49                     |
| Min                  | 158.50                 | 324.70                 |
| Max                  | 162.50                 | 333.00                 |
| Mean                 | 160.59                 | 328.66                 |
| Median               | 160.50                 | 328.90                 |
---------------------------------------------------------------------------

mchrzanowski and others added 2 commits March 4, 2026 14:08
The assertion was previously skipped for layer-wise optimizers (dist_muon),
but it should be unconditional. Also spell out 'dist_muon' explicitly
instead of using substring matching.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@deepakn94

/ok to test 079e2a1

@svcnvidia-nemo-ci svcnvidia-nemo-ci added this to the Core 0.16 milestone Mar 4, 2026
@deepakn94 deepakn94 self-requested a review March 4, 2026 22:18
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@mchrzanowski mchrzanowski enabled auto-merge March 4, 2026 22:28
@mchrzanowski mchrzanowski disabled auto-merge March 4, 2026 22:29
@deepakn94 deepakn94 requested review from skyw and removed request for skyw March 4, 2026 23:30
@deepakn94

/ok to test 64dd7ea

@deepakn94 deepakn94 enabled auto-merge March 4, 2026 23:31
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@mchrzanowski

/ok to test 62b5ae2

@deepakn94 deepakn94 added this pull request to the merge queue Mar 5, 2026
@svcnvidia-nemo-ci

🔄 Merge queue validation started! (five runs)

You can track the progress here:
https://github.com/NVIDIA/Megatron-LM/actions/runs/22700476363
https://github.com/NVIDIA/Megatron-LM/actions/runs/22703798183
https://github.com/NVIDIA/Megatron-LM/actions/runs/22705408919
https://github.com/NVIDIA/Megatron-LM/actions/runs/22709214862
https://github.com/NVIDIA/Megatron-LM/actions/runs/22713754555
