add --overlap-param-gather support for layer-wise optimizer. lots of unit tests. #3524
mchrzanowski wants to merge 34 commits into NVIDIA:main
Conversation
Integrate async param all-gather from upstream PR NVIDIA#2787 so that dist_muon/dist_mop can overlap parameter all-gather with forward compute via DDP's existing bucket and forward-pre-hook infrastructure. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
torch.distributed.all_gather with tensor-list output internally creates a temporary contiguous buffer and copies chunks back to the individual output tensors when wait() is called. When wrapped in _coalescing_manager, the coalescing manager's handle only waits on the grouped NCCL operations but does not trigger the per-op copy-back, leaving output tensors uninitialized with garbage values that cause NaN at iteration 2. Fix by calling all_gather directly per-bucket (without _coalescing_manager) and storing individual work handles in a new _LayerWiseAllGatherHandle class that properly triggers copy-back on wait(). Also pins the flattened source tensor to prevent premature GC during async operations. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add 8 new test cases exercising the overlap-param-gather path (use_layer_wise_optimizer=True + overlap_param_gather=True + async_allgather):
- test_overlap_param_gather_basic: end-to-end with bucket-based param sync
- test_overlap_param_gather_parameter_updates: vs standard optimizer
- test_overlap_param_gather_vs_sync_allgather: async vs sync produce identical results
- test_overlap_param_gather_bucket_lw_params: bucket.lw_params_list populated correctly
- test_overlap_param_gather_vs_standard_ddp: padded vs unpadded DDP produce same results
- test_overlap_param_gather_insufficient_parameters: TinyModel with overlap path
- test_overlap_param_gather_broadcast_vs_allgather: broadcast vs allgather equivalence
- test_overlap_param_gather_multi_iteration: 3-iteration convergence test

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace with overlap_param_gather and !use_distributed_optimizer checks, which already capture the same semantics without a dedicated flag. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…_optimizer Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… OOM The overlap_param_gather feature allocates large temporary GPU buffers during the forward pass for async all-gather operations. Although freed after finish_param_sync, PyTorch's CUDA allocator caches the memory, which is invisible to the async checkpoint worker's CUDA context. This caused OOM during D2H tensor transfers, preventing checkpoints from finalizing and trapping the run in a restart loop. Add free_overlap_buffers() to _ParamAndGradBucketGroup and DDP that explicitly releases these buffers, and call it + torch.cuda.empty_cache() in save_checkpoint_and_time() before the save_checkpoint() call. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Tests that free_overlap_buffers():
- Clears lw_gather_tensor_list and nulls _lw_src_buffer on each bucket
- Waits on any pending param_gather_handle before freeing
- Is safe to call when no buffers are allocated (noop)
- DDP.free_overlap_buffers delegates to all bucket groups

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Each rank may own different numbers of params per bucket, creating variable flat sizes. NCCL's all_gather requires uniform sizes. Replace with dp_size broadcast calls per bucket: each rank broadcasts its actual-size flattened params to all others. This eliminates padding and uses only collectives (no P2P, which can deadlock with subsequent collectives on the same NCCL communicator). Memory cost: sum(lw_param_flat_sizes) per bucket — no padding waste. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
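The broadcast scheme can be simulated without torch. `broadcast_param_sync` below is a hypothetical helper that mimics the dp_size broadcast collectives with plain Python lists; it is not the real Megatron-LM API, just an illustration of why variable flat sizes need no padding under this scheme:

```python
# Simulation of per-rank broadcasts: every rank broadcasts its own
# actual-size flattened params, so receive buffers match exactly and
# total memory per bucket is sum(flat sizes) with no padding waste.

def broadcast_param_sync(local_flat_params_per_rank):
    """local_flat_params_per_rank[r] is rank r's flattened local params.

    Returns, for each destination rank, the list of flat buffers it
    ends up holding (one per source rank), mimicking dp_size broadcast
    collectives on one communicator.
    """
    dp_size = len(local_flat_params_per_rank)
    gathered = [[None] * dp_size for _ in range(dp_size)]
    for src in range(dp_size):          # one broadcast per source rank
        payload = list(local_flat_params_per_rank[src])
        for dst in range(dp_size):      # every rank receives it as-is
            gathered[dst][src] = payload
    return gathered


# Ranks own different numbers of params in this bucket: sizes 3, 1, 2.
flat = [[1.0, 2.0, 3.0], [4.0], [5.0, 6.0]]
result = broadcast_param_sync(flat)
print(result[0] == result[2])          # -> True  (all ranks agree)
print(sum(len(x) for x in result[0]))  # -> 6     (sum of flat sizes, no padding)
```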
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…er with muon optimizer Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… deadlock When a rank has no local params in a bucket, the empty src tensor was created with bucket.grad_data.dtype (fp32 when grad_reduce_in_fp32=True) instead of the actual param dtype (bf16). This caused receive buffers on ranks without local params to be fp32 while the broadcasting rank sends bf16 data, resulting in different buffer byte sizes across ranks in the same broadcast collective — an NCCL deadlock. The fix uses bucket.params_list[0].dtype to always get the correct param dtype, matching the dtype produced by _flatten_dense_tensors on ranks that do have local params. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
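The dtype fix can be sketched in isolation. `Bucket`/`Param` here are stand-ins (dtypes shown as strings rather than torch dtypes), not the real classes; the point is only which attribute the empty-source path must read:

```python
# Sketch of the dtype fix: on a rank with no local params, the empty
# src tensor must use the param dtype (what _flatten_dense_tensors
# yields on ranks that DO have params), not the grad buffer dtype,
# which is fp32 when grad_reduce_in_fp32=True.

from dataclasses import dataclass, field

@dataclass
class Param:
    dtype: str = "bf16"

@dataclass
class Bucket:
    grad_dtype: str = "fp32"  # grad buffer dtype under grad_reduce_in_fp32
    params_list: list = field(default_factory=lambda: [Param()])

def src_dtype(bucket, local_params):
    if local_params:
        # Ranks with local params: dtype of the flattened param tensor.
        return local_params[0].dtype
    # Buggy version returned bucket.grad_dtype here, so fp32 receive
    # buffers met bf16 sends: mismatched byte sizes -> NCCL deadlock.
    return bucket.params_list[0].dtype

bucket = Bucket()
print(src_dtype(bucket, []))  # -> bf16 (matches the broadcasting ranks)
```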
…ing only on last broadcast handle All per-rank broadcasts within a single start_param_sync() call are issued on the same NCCL communicator stream, so NCCL guarantees in-order completion. Waiting sequentially on each of the ~64 individual handles caused intermediate CUDA stream synchronizations that created a timing-dependent deadlock across ranks. Waiting on only the last handle is sufficient. Also removes debug logging added during investigation. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
```diff
-assert self.ddp_config.use_distributed_optimizer
+# overlap_param_gather covers the layer-wise optimizer case, which sets
+# overlap_param_gather=True without use_distributed_optimizer.
+assert self.ddp_config.use_distributed_optimizer or self.ddp_config.overlap_param_gather
 assert self.ddp_config.overlap_param_gather
```
The check for self.ddp_config.use_distributed_optimizer is redundant here due to the existing assertion for self.ddp_config.overlap_param_gather below it. So the line added by this PR could be removed. (Or should the existing line checking just for self.ddp_config.overlap_param_gather be removed?)
muon doesn't use the (AdamW) distributed optimizer, so you still have to check for either here
But you're doing assert self.overlap_param_gather right below the assert self.ddp_config.use_distributed_optimizer or self.ddp_config.overlap_param_gather.
So even if the first assertion passes with self.ddp_config.use_distributed_optimizer = True and self.ddp_config.overlap_param_gather = False, the second assertion fails anyway.
oh I'm sorry, I misunderstood, thank you. Fixed.
deepakn94 left a comment
@mchrzanowski is the gradient reduction still an all-reduce?
- Rename all `lw` prefixed names to `layerwise` for clarity
- Improve set_layerwise_params_list docstring to clarify that each inner list contains only the params owned by that rank's layer-wise optimizer that also belong to this bucket
- Clear self.handles at the end of _LayerwiseAllGatherHandle.wait()
- Rename local variable `src` to `flat_local_params`

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Force-pushed 4240979 to f30914e
The production code was renamed from lw_ to layerwise_ prefix but the tests were not updated, causing attribute mismatches and test failures. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Throughputs look good on a contrived case with high DP communication overhead (8B Transformer with TP=8).
The assertion was previously skipped for layer-wise optimizers (dist_muon), but it should be unconditional. Also spell out 'dist_muon' explicitly instead of using substring matching. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
/ok to test 079e2a1
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
/ok to test 64dd7ea
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
/ok to test 62b5ae2
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/22700476363
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/22703798183
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/22705408919
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/22709214862
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/22713754555
What does this PR do?
Adds --overlap-param-gather functionality to muon.
Builds on top of @deyuf's PR
Lots and lots of unit tests for synchronization and for the layer-wise optimizer.
Passes functional tests here: https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/merge_requests/4555