From e317e8d177a650f88d360bebef70d33fe0c30c39 Mon Sep 17 00:00:00 2001 From: Mike Chrzanowski Date: Fri, 13 Feb 2026 01:17:00 +0000 Subject: [PATCH 01/25] add --overlap-param-gather support for layer-wise optimizer (muon) Integrate async param all-gather from upstream PR #2787 so that dist_muon/dist_mop can overlap parameter all-gather with forward compute via DDP's existing bucket and forward-pre-hook infrastructure. Co-Authored-By: Claude Opus 4.6 --- .../distributed/distributed_data_parallel.py | 8 +- .../distributed_data_parallel_config.py | 3 + .../core/distributed/param_and_grad_buffer.py | 107 +++++++++++++++--- .../core/optimizer/layer_wise_optimizer.py | 56 ++++++++- megatron/core/optimizer/muon.py | 7 +- megatron/training/arguments.py | 5 +- megatron/training/training.py | 10 +- 7 files changed, 171 insertions(+), 25 deletions(-) diff --git a/megatron/core/distributed/distributed_data_parallel.py b/megatron/core/distributed/distributed_data_parallel.py index 55179ff3024..8425d1548df 100644 --- a/megatron/core/distributed/distributed_data_parallel.py +++ b/megatron/core/distributed/distributed_data_parallel.py @@ -236,7 +236,10 @@ def _allocate_buffers_for_parameters( # Set `next_param_gather_bucket_group` for different bucket groups by iterating through # buckets in reverse order (since all-gathers happen in reverse order of buckets). - if self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather: + if ( + self.ddp_config.use_distributed_optimizer + or self.ddp_config.use_layer_wise_optimizer + ) and self.ddp_config.overlap_param_gather: num_bucket_groups = len(bucket_groups) for i in range(1, num_bucket_groups): bucket_groups[num_bucket_groups - i].next_param_gather_bucket_group = ( @@ -345,7 +348,8 @@ def unmap_weight_tensor(m): self.grad_accs.append(grad_acc) self.use_forward_hook = ( - self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather + (self.ddp_config.use_distributed_optimizer or self.ddp_config.use_layer_wise_optimizer) + and self.ddp_config.overlap_param_gather ) self.remove_forward_pre_hook_handles = {} if self.use_forward_hook: diff --git a/megatron/core/distributed/distributed_data_parallel_config.py b/megatron/core/distributed/distributed_data_parallel_config.py index c4b25b9f85c..67d2534ee39 100644 --- a/megatron/core/distributed/distributed_data_parallel_config.py +++ b/megatron/core/distributed/distributed_data_parallel_config.py @@ -27,6 +27,9 @@ class DistributedDataParallelConfig: originally allocated model parameters, otherwise issue all-reduce collectives. """ + use_layer_wise_optimizer: bool = False + """If true, use layer-wise distributed optimizer for param all-gather overlap.""" + num_distributed_optimizer_instances: int = 1 """Sets the factor by which the DP domain is sharded to have the partial DistOpt enabled. Defaults to 1, which means DistOpt is across entire DP domain. diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py index 088374fbf13..a49cfcc955e 100644 --- a/megatron/core/distributed/param_and_grad_buffer.py +++ b/megatron/core/distributed/param_and_grad_buffer.py @@ -11,6 +11,7 @@ import torch from torch.distributed import _coalescing_manager +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors import megatron.core.nccl_allocator as nccl_allocator from megatron.core import parallel_state @@ -112,6 +113,22 @@ def __init__( global_start, global_end, _ = param_index_map[param] self.param_to_index[param] = (global_start - offset, global_end - offset) + # Layer-wise optimizer attributes for async param gather. + self.lw_params_list = None + self.lw_param_flat_sizes = None + self.lw_gather_tensor_list = None + + def set_lw_params_list(self, lw_params_list: List[List[torch.nn.Parameter]]): + """Set per-rank parameter lists for layer-wise async all-gather. + + Args: + lw_params_list: List of param lists, one per rank in the DP group. + """ + self.lw_params_list = lw_params_list + self.lw_param_flat_sizes = [ + sum([p.numel() for p in param_list]) for param_list in lw_params_list + ] + class _ParamAndGradBucketGroup: """ @@ -138,11 +155,11 @@ def __init__( self.buckets = buckets self.ddp_config = ddp_config - if self.ddp_config.use_distributed_optimizer: + if self.ddp_config.use_distributed_optimizer or self.ddp_config.use_layer_wise_optimizer: self.intra_distributed_optimizer_instance_group = collective_group self.intra_distributed_optimizer_instance_size = collective_group_size self.intra_distributed_optimizer_instance_rank = collective_group.rank() - else: + if not self.ddp_config.use_distributed_optimizer: self.data_parallel_group = collective_group # State for bookkeeping: params is the set of parameters this bucket group is @@ -262,7 +279,7 @@ def start_param_sync(self, force_sync: bool = False): force_sync (bool, optional): force synchronous collective regardless of other settings if true. """ - assert self.ddp_config.use_distributed_optimizer + assert self.ddp_config.use_distributed_optimizer or self.ddp_config.use_layer_wise_optimizer if force_sync: if self.param_gather_handle is not None: @@ -277,22 +294,62 @@ def start_param_sync(self, force_sync: bool = False): with _coalescing_manager( self.intra_distributed_optimizer_instance_group, async_ops=async_op ) as cm: - for idx, bucket in enumerate(self.buckets): - if self.cached_param_buffer_shard_list[idx] is None: - self.cached_param_buffer_shard_list[idx] = shard_buffer( - bucket.param_data, self.intra_distributed_optimizer_instance_size + if not self.ddp_config.use_layer_wise_optimizer: + for idx, bucket in enumerate(self.buckets): + if self.cached_param_buffer_shard_list[idx] is None: + self.cached_param_buffer_shard_list[idx] = shard_buffer( + bucket.param_data, self.intra_distributed_optimizer_instance_size + ) + local_data_view = self.cached_param_buffer_shard_list[idx][ + self.intra_distributed_optimizer_instance_rank + ] + dist_all_gather_func( + bucket.param_data, + local_data_view, + group=self.intra_distributed_optimizer_instance_group, + async_op=async_op, + ) + else: + for bucket in self.buckets: + local_rank = self.intra_distributed_optimizer_instance_rank + src = ( + _flatten_dense_tensors(bucket.lw_params_list[local_rank]) + if len(bucket.lw_params_list[local_rank]) > 0 + else torch.empty( + 0, + device=bucket.grad_data.device, + dtype=bucket.grad_data.dtype, + ) + ) + bucket.lw_gather_tensor_list = [ + torch.empty(size, device=src.device, dtype=src.dtype) + for size in bucket.lw_param_flat_sizes + ] + torch.distributed.all_gather( + bucket.lw_gather_tensor_list, + src, + group=self.intra_distributed_optimizer_instance_group, + async_op=async_op, ) - local_data_view = self.cached_param_buffer_shard_list[idx][ - self.intra_distributed_optimizer_instance_rank - ] - dist_all_gather_func( - bucket.param_data, - local_data_view, - group=self.intra_distributed_optimizer_instance_group, - async_op=async_op, - ) if async_op: self.param_gather_handle = cm + elif self.ddp_config.use_layer_wise_optimizer: + # Synchronous layer-wise case (e.g., force_sync=True for checkpointing): + # unflatten and copy gathered params immediately. + for bucket in self.buckets: + for idx, (flat_params, params) in enumerate( + zip(bucket.lw_gather_tensor_list, bucket.lw_params_list) + ): + if ( + len(params) == 0 + or idx == self.intra_distributed_optimizer_instance_rank + ): + continue + updated_params = _unflatten_dense_tensors(flat_params, params) + for updated_p, model_p in zip(updated_params, params): + model_p.data.copy_(updated_p) + bucket.lw_gather_tensor_list.clear() + self.param_gather_handle = None else: # When using `_coalescing_manager`, even if a synchronous op (async_op=False) is used, # `cm` is not None, which is different from when `_coalescing_manager` is not used in @@ -317,7 +374,7 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False): skip_next_bucket_dispatch (bool, optional): if true, dispatch next bucket's communication if available. """ - assert self.ddp_config.use_distributed_optimizer + assert self.ddp_config.use_distributed_optimizer or self.ddp_config.use_layer_wise_optimizer assert self.ddp_config.overlap_param_gather # If current bucket's param AG has not been dispatched, dispatch it now (e.g., first @@ -355,6 +412,22 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False): # correspond to multiple param buffers. If we zero out the entire grad buffer, # it would clear the data of those param buffers that have not yet completed AG. bucket.param_data.zero_() + elif self.ddp_config.use_layer_wise_optimizer: + for bucket in self.buckets: + # Unflatten and copy gathered params for each rank. + for idx, (flat_params, params) in enumerate( + zip(bucket.lw_gather_tensor_list, bucket.lw_params_list) + ): + # Skip local params and empty tensors. + if ( + len(params) == 0 + or idx == self.intra_distributed_optimizer_instance_rank + ): + continue + updated_params = _unflatten_dense_tensors(flat_params, params) + for updated_p, model_p in zip(updated_params, params): + model_p.data.copy_(updated_p) + bucket.lw_gather_tensor_list.clear() else: fp8_params = [] for bucket in self.buckets: diff --git a/megatron/core/optimizer/layer_wise_optimizer.py b/megatron/core/optimizer/layer_wise_optimizer.py index de4396a5b4f..84a37a4e439 100644 --- a/megatron/core/optimizer/layer_wise_optimizer.py +++ b/megatron/core/optimizer/layer_wise_optimizer.py @@ -45,6 +45,8 @@ def __init__( config: OptimizerConfig, pg_collection: Optional[ProcessGroupCollection] = None, init_state_fn_list: Optional[List[Callable]] = None, + model_chunks: Optional[List] = None, + async_allgather: bool = False, ) -> None: """ Initialize LayerWiseDistributedOptimizer. @@ -54,10 +56,21 @@ def __init__( config: OptimizerConfig. pg_collection: ProcessGroupCollection. init_state_fn_list: List of init state functions. + model_chunks: DDP-wrapped model chunks (needed for async_allgather). + async_allgather: If True, defer param all-gather to forward pre-hooks. """ self.pg_collection = pg_collection self.shard_params(optimizers) + + # Set up async all-gather using DDP bucket infrastructure. + self.async_allgather = async_allgather + if self.async_allgather: + assert model_chunks is not None, ( + "model_chunks must be provided if async_allgather is True" + ) + self.set_bucket_lw_params_list(model_chunks) + if init_state_fn_list: assert len(init_state_fn_list) == len( optimizers @@ -143,6 +156,43 @@ def shard_params(self, optimizers): if expt_dp_size == 1 or len(self.expt_dp_params_list[0]) == 0: self.expt_dp_params_list = None + def set_bucket_lw_params_list(self, model_chunks): + """Map sharded params to DDP buckets for async all-gather. + + For each bucket in each model chunk's bucket groups, build per-rank param lists + by cross-referencing the layer-wise sharded param lists with the bucket's params. + + Args: + model_chunks: DDP-wrapped model chunks with bucket_groups. + """ + for model_chunk in model_chunks: + for group in model_chunk.bucket_groups: + for bucket in group.buckets: + bucket_params_list = [ + [] for _ in range(get_pg_size(self.pg_collection.dp_cp)) + ] + for bucket_list, full_params_list in zip( + bucket_params_list, self.dp_cp_params_list + ): + for param in full_params_list: + if param in bucket.params: + bucket_list.append(param) + bucket.set_lw_params_list(bucket_params_list) + # Do the same for expert parallel bucket groups. + if self.expt_dp_params_list is not None: + for group in model_chunk.expert_parallel_bucket_groups: + for bucket in group.buckets: + bucket_params_list = [ + [] for _ in range(get_pg_size(self.pg_collection.expt_dp)) + ] + for bucket_list, full_params_list in zip( + bucket_params_list, self.expt_dp_params_list + ): + for param in full_params_list: + if param in bucket.params: + bucket_list.append(param) + bucket.set_lw_params_list(bucket_params_list) + @torch.no_grad() def allgather_params(self) -> None: """All-gather updated params from all ranks.""" @@ -223,8 +273,10 @@ def step(self): # type: ignore[no-untyped-def] """step function for layer-wise optimizer.""" update_successful, grad_norm, num_zeros_in_grad = super().step() - # All gather updated params. - self.allgather_params() + # All gather updated params. If async_allgather is True, the allgather + # is deferred to the forward pre-hooks via DDP bucket infrastructure. + if not self.async_allgather: + self.allgather_params() return update_successful, grad_norm, num_zeros_in_grad diff --git a/megatron/core/optimizer/muon.py b/megatron/core/optimizer/muon.py index 57eb1e94478..c637eaa5442 100644 --- a/megatron/core/optimizer/muon.py +++ b/megatron/core/optimizer/muon.py @@ -345,6 +345,11 @@ def adam_init_state_fn(opt, config=None): if reset_config_bf16: config.bf16 = True return LayerWiseDistributedOptimizer( - optimizers, config, pg_collection, init_state_fn_list=init_fns + optimizers, + config, + pg_collection, + init_state_fn_list=init_fns, + model_chunks=model_chunks, + async_allgather=config.overlap_param_gather, ) return ChainedOptimizer(optimizers) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 3fcc6437758..81ec555dfd4 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -673,8 +673,9 @@ def validate_args(args, defaults={}): ) if args.overlap_param_gather: - assert args.use_distributed_optimizer or args.use_megatron_fsdp, \ - '--overlap-param-gather only supported with distributed optimizer or megatron fsdp' + assert args.use_distributed_optimizer or args.use_megatron_fsdp \ + or ('dist' in args.optimizer), \ + '--overlap-param-gather only supported with distributed optimizer, megatron fsdp, or layer-wise optimizer' assert args.overlap_grad_reduce, \ 'Must use --overlap-param-gather with --overlap-grad-reduce' assert not args.use_legacy_models, \ diff --git a/megatron/training/training.py b/megatron/training/training.py index 2c68c70735d..c5a81722c96 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -1374,6 +1374,10 @@ def build_model(): kwargs['average_in_collective'] = args.ddp_average_in_collective ddp_config = DistributedDataParallelConfig(**kwargs) + # Enable layer-wise optimizer flag for distributed muon/mop variants. + if 'dist' in args.optimizer: + ddp_config.use_layer_wise_optimizer = True + # In the Megatron FSDP and DDP use path, we need to initialize the bucket size. # If bucket_size is not provided as an input, use sane default. # If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL @@ -3593,4 +3597,8 @@ def _get_iterator(dataloader_type, dataloader): def should_disable_forward_pre_hook(args): """Block forward pre-hook for certain configurations.""" - return not args.use_megatron_fsdp and args.use_distributed_optimizer and args.overlap_param_gather + return ( + not args.use_megatron_fsdp + and (args.use_distributed_optimizer or 'dist' in args.optimizer) + and args.overlap_param_gather + ) From d8189e4715769b1e37a8058fb3c11d29b97a9ef0 Mon Sep 17 00:00:00 2001 From: Mike Chrzanowski Date: Sat, 14 Feb 2026 17:41:50 -0800 Subject: [PATCH 02/25] Fix NaN in overlap-param-gather for layer-wise optimizer (Muon) torch.distributed.all_gather with tensor-list output internally creates a temporary contiguous buffer and copies chunks back to the individual output tensors when wait() is called. When wrapped in _coalescing_manager, the coalescing manager's handle only waits on the grouped NCCL operations but does not trigger the per-op copy-back, leaving output tensors uninitialized with garbage values that cause NaN at iteration 2. Fix by calling all_gather directly per-bucket (without _coalescing_manager) and storing individual work handles in a new _LayerWiseAllGatherHandle class that properly triggers copy-back on wait(). Also pins the flattened source tensor to prevent premature GC during async operations. Co-Authored-By: Claude Opus 4.6 --- .../core/distributed/param_and_grad_buffer.py | 144 +++++++++++------- 1 file changed, 92 insertions(+), 52 deletions(-) diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py index a49cfcc955e..95b2c17884b 100644 --- a/megatron/core/distributed/param_and_grad_buffer.py +++ b/megatron/core/distributed/param_and_grad_buffer.py @@ -117,6 +117,7 @@ def __init__( self.lw_params_list = None self.lw_param_flat_sizes = None self.lw_gather_tensor_list = None + self._lw_src_buffer = None def set_lw_params_list(self, lw_params_list: List[List[torch.nn.Parameter]]): """Set per-rank parameter lists for layer-wise async all-gather. @@ -130,6 +131,26 @@ def set_lw_params_list(self, lw_params_list: List[List[torch.nn.Parameter]]): ] +class _LayerWiseAllGatherHandle: + """Handle for multiple async all-gather operations used by the layer-wise optimizer. + + torch.distributed.all_gather with a tensor list output internally creates a temporary + contiguous buffer and copies chunks back to the individual output tensors when wait() + is called on the returned work handle. When wrapped in _coalescing_manager, this + copy-back step is lost because the coalescing manager's handle only waits on the + NCCL operations, not the individual copy-back callbacks. This class stores the + individual work handles so that wait() properly triggers the copy-back for each + all_gather call. + """ + + def __init__(self, handles): + self.handles = handles + + def wait(self): + for h in self.handles: + h.wait() + + class _ParamAndGradBucketGroup: """ Put multiple buckets into a group so that their communications can be aggregated together. @@ -290,11 +311,70 @@ def start_param_sync(self, force_sync: bool = False): assert self.param_gather_handle is None async_op = self.ddp_config.overlap_param_gather and not force_sync - # Coalesce communication kernels across buckets in the bucket group. - with _coalescing_manager( - self.intra_distributed_optimizer_instance_group, async_ops=async_op - ) as cm: - if not self.ddp_config.use_layer_wise_optimizer: + + if self.ddp_config.use_layer_wise_optimizer: + # Layer-wise optimizer path: do NOT use _coalescing_manager. + # + # torch.distributed.all_gather with a tensor-list output internally creates a + # temporary contiguous buffer, calls ncclAllGather into it, and then copies chunks + # back to the individual output tensors when wait() is called on the work handle. + # When wrapped in _coalescing_manager, the coalescing manager's handle only waits + # on the grouped NCCL operations but does NOT trigger the per-op copy-back step, + # leaving the output tensors uninitialized. We avoid this by calling all_gather + # directly and storing the individual work handles. + lw_work_handles = [] + for bucket in self.buckets: + local_rank = self.intra_distributed_optimizer_instance_rank + src = ( + _flatten_dense_tensors(bucket.lw_params_list[local_rank]) + if len(bucket.lw_params_list[local_rank]) > 0 + else torch.empty( + 0, + device=bucket.grad_data.device, + dtype=bucket.grad_data.dtype, + ) + ) + # Keep src alive until the async operation completes. + bucket._lw_src_buffer = src + bucket.lw_gather_tensor_list = [ + torch.empty(size, device=src.device, dtype=src.dtype) + for size in bucket.lw_param_flat_sizes + ] + work = torch.distributed.all_gather( + bucket.lw_gather_tensor_list, + src, + group=self.intra_distributed_optimizer_instance_group, + async_op=async_op, + ) + if async_op and work is not None: + lw_work_handles.append(work) + if async_op: + self.param_gather_handle = _LayerWiseAllGatherHandle(lw_work_handles) + else: + # Synchronous layer-wise case (e.g., force_sync=True for checkpointing): + # unflatten and copy gathered params immediately. + for bucket in self.buckets: + for idx, (flat_params, params) in enumerate( + zip(bucket.lw_gather_tensor_list, bucket.lw_params_list) + ): + if ( + len(params) == 0 + or idx == self.intra_distributed_optimizer_instance_rank + ): + continue + updated_params = _unflatten_dense_tensors(flat_params, params) + for updated_p, model_p in zip(updated_params, params): + model_p.data.copy_(updated_p) + bucket.lw_gather_tensor_list.clear() + bucket._lw_src_buffer = None + self.param_gather_handle = None + else: + # Standard distributed optimizer path: use _coalescing_manager. + # all_gather_into_tensor writes directly into a contiguous output buffer and + # does not need a copy-back step, so coalescing works correctly. + with _coalescing_manager( + self.intra_distributed_optimizer_instance_group, async_ops=async_op + ) as cm: for idx, bucket in enumerate(self.buckets): if self.cached_param_buffer_shard_list[idx] is None: self.cached_param_buffer_shard_list[idx] = shard_buffer( @@ -309,54 +389,13 @@ def start_param_sync(self, force_sync: bool = False): group=self.intra_distributed_optimizer_instance_group, async_op=async_op, ) + if async_op: + self.param_gather_handle = cm else: - for bucket in self.buckets: - local_rank = self.intra_distributed_optimizer_instance_rank - src = ( - _flatten_dense_tensors(bucket.lw_params_list[local_rank]) - if len(bucket.lw_params_list[local_rank]) > 0 - else torch.empty( - 0, - device=bucket.grad_data.device, - dtype=bucket.grad_data.dtype, - ) - ) - bucket.lw_gather_tensor_list = [ - torch.empty(size, device=src.device, dtype=src.dtype) - for size in bucket.lw_param_flat_sizes - ] - torch.distributed.all_gather( - bucket.lw_gather_tensor_list, - src, - group=self.intra_distributed_optimizer_instance_group, - async_op=async_op, - ) - if async_op: - self.param_gather_handle = cm - elif self.ddp_config.use_layer_wise_optimizer: - # Synchronous layer-wise case (e.g., force_sync=True for checkpointing): - # unflatten and copy gathered params immediately. - for bucket in self.buckets: - for idx, (flat_params, params) in enumerate( - zip(bucket.lw_gather_tensor_list, bucket.lw_params_list) - ): - if ( - len(params) == 0 - or idx == self.intra_distributed_optimizer_instance_rank - ): - continue - updated_params = _unflatten_dense_tensors(flat_params, params) - for updated_p, model_p in zip(updated_params, params): - model_p.data.copy_(updated_p) - bucket.lw_gather_tensor_list.clear() - self.param_gather_handle = None - else: - # When using `_coalescing_manager`, even if a synchronous op (async_op=False) is used, - # `cm` is not None, which is different from when `_coalescing_manager` is not used in - # which case the torch.distributed._all_gather_base() will return None. In order to - # maintain consistency with prior code, we need to manually set communication handle to - # None. - self.param_gather_handle = None + # When using `_coalescing_manager`, even if a synchronous op + # (async_op=False) is used, `cm` is not None. Manually set to None for + # consistency with prior code. + self.param_gather_handle = None self.param_gather_dispatched = True def finish_param_sync(self, skip_next_bucket_dispatch: bool = False): @@ -428,6 +467,7 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False): for updated_p, model_p in zip(updated_params, params): model_p.data.copy_(updated_p) bucket.lw_gather_tensor_list.clear() + bucket._lw_src_buffer = None else: fp8_params = [] for bucket in self.buckets: From 2df43c9c5f08ab55c914195bf59086e1d98ccb4c Mon Sep 17 00:00:00 2001 From: Mike Chrzanowski Date: Sun, 15 Feb 2026 21:53:08 -0800 Subject: [PATCH 03/25] Add unit tests for overlap-param-gather in layer-wise optimizer Add 8 new test cases exercising the overlap-param-gather path (use_layer_wise_optimizer=True + overlap_param_gather=True + async_allgather): - test_overlap_param_gather_basic: end-to-end with bucket-based param sync - test_overlap_param_gather_parameter_updates: vs standard optimizer - test_overlap_param_gather_vs_sync_allgather: async vs sync produce identical results - test_overlap_param_gather_bucket_lw_params: bucket.lw_params_list populated correctly - test_overlap_param_gather_vs_standard_ddp: padded vs unpadded DDP produce same results - test_overlap_param_gather_insufficient_parameters: TinyModel with overlap path - test_overlap_param_gather_broadcast_vs_allgather: broadcast vs allgather equivalence - test_overlap_param_gather_multi_iteration: 3-iteration convergence test Co-Authored-By: Claude Opus 4.6 --- tests/unit_tests/test_layer_wise_optimizer.py | 377 +++++++++++++++++- 1 file changed, 376 insertions(+), 1 deletion(-) diff --git a/tests/unit_tests/test_layer_wise_optimizer.py b/tests/unit_tests/test_layer_wise_optimizer.py index 05ce26bcfa0..50bc971f3ac 100644 --- a/tests/unit_tests/test_layer_wise_optimizer.py +++ b/tests/unit_tests/test_layer_wise_optimizer.py @@ -15,7 +15,7 @@ from megatron.core.optimizer.optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer import TransformerConfig -from megatron.core.utils import get_pg_size +from megatron.core.utils import get_pg_rank, get_pg_size from tests.unit_tests.test_utilities import Utils # Skip all tests in this file for LTS versions @@ -130,6 +130,71 @@ def create_model_and_optimizer( ) return model, optimizer, pg_collection + def create_model_and_optimizer_with_overlap_param_gather( + self, + model_class=SimpleModel, + clip_grad=1.0, + model_kwargs=None, + copy_from=None, + async_allgather=True, + ): + """Create model, DDP wrapper, and optimizer with overlap-param-gather enabled. + + This variant sets use_layer_wise_optimizer=True and overlap_param_gather=True + in DDP config and passes model_chunks=[model] + async_allgather to + LayerWiseDistributedOptimizer, enabling the bucket-based async param gather path. + + Args: + model_class: Model class to instantiate + clip_grad: Optional gradient clipping value + model_kwargs: Optional kwargs for model initialization + copy_from: Optional DDP model to copy weights from + async_allgather: If True, defer param all-gather to bucket infrastructure + + Returns: + tuple: (model, optimizer, pg_collection) + """ + if model_kwargs is None: + model_kwargs = {} + + model = model_class(**model_kwargs).bfloat16().cuda() + model.requires_grad_(True) + + ddp_config = DistributedDataParallelConfig( + use_distributed_optimizer=False, + use_layer_wise_optimizer=True, + overlap_param_gather=True, + ) + model = DistributedDataParallel( + TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, model + ) + if copy_from: + model.module.load_state_dict(copy_from.module.state_dict()) + else: + model.broadcast_params() + + optimizer_config = OptimizerConfig( + optimizer='adam', + lr=0.01, + weight_decay=0.01, + bf16=False, + use_distributed_optimizer=False, + clip_grad=clip_grad, + ) + + pg_collection = ProcessGroupCollection.use_mpu_process_groups() + pg_collection.dp_cp = parallel_state.get_data_parallel_group(with_context_parallel=True) + pg_collection.expt_dp = parallel_state.get_expert_data_parallel_group() + + optimizer = get_megatron_optimizer(optimizer_config, [model]) + optimizer_config.bf16 = True + optimizer = LayerWiseDistributedOptimizer( + optimizer.chained_optimizers, optimizer_config, pg_collection, + model_chunks=[model], + async_allgather=async_allgather, + ) + return model, optimizer, pg_collection + def create_reference_model(self, model): """Create a reference model by cloning the current model.""" reference_model = type(model.module)().bfloat16().cuda() @@ -438,3 +503,313 @@ def test_broadcast_vs_allgather(self): # Verify updated values match reference optimizer for param, ref_param in zip(model.parameters(), reference_model.parameters()): torch.testing.assert_close(param.data, ref_param.data, rtol=0, atol=0) + + # ---- Overlap-param-gather tests ---- + + def test_overlap_param_gather_basic(self): + """Test overlap-param-gather path: init, forward/backward/step, bucket-based param sync.""" + model, optimizer, pg_collection = ( + self.create_model_and_optimizer_with_overlap_param_gather() + ) + + assert optimizer is not None, "Optimizer should not be None" + assert optimizer.async_allgather, "async_allgather should be True" + + reference_model = self.create_reference_model(model) + + input_tensor = torch.randn(16, 80, dtype=torch.bfloat16, device='cuda') + output = model(input_tensor) + loss = output.sum() + loss.backward() + + # step() updates local params but skips allgather (async_allgather=True) + update_successful, grad_norm, num_zeros = optimizer.step() + + assert update_successful, "Optimizer step should be successful" + + # Manually sync params through the bucket-based param sync path + # force_sync=True does synchronous allgather via bucket infrastructure + model.start_param_sync(force_sync=True) + + # Verify parameters were updated + params_updated = 0 + for param, ref_param in zip(model.parameters(), reference_model.parameters()): + if not torch.equal(param.data, ref_param.data): + params_updated += 1 + + assert params_updated > 0, "At least some parameters should be updated" + + # Verify all ranks have the same updated parameters + dp_size = get_pg_size(pg_collection.dp_cp) + + if dp_size > 1: + for name, param in model.named_parameters(): + param_list = [torch.zeros_like(param.data) for _ in range(dp_size)] + torch.distributed.all_gather(param_list, param.data, group=pg_collection.dp_cp) + + for i in range(1, dp_size): + torch.testing.assert_close( + param_list[0], param_list[i], + msg=f"Parameter {name} differs between rank 0 and rank {i}", + ) + + def test_overlap_param_gather_parameter_updates(self): + """Test overlap-param-gather produces same parameter updates as standard optimizer.""" + model, optimizer, pg_collection = ( + self.create_model_and_optimizer_with_overlap_param_gather() + ) + + # Create reference model with standard (non-layer-wise) optimizer + reference_model, reference_optimizer, _ = self.create_model_and_optimizer( + use_layer_wise=False, copy_from=model + ) + + # Set same gradients on both models + for param, ref_param in zip(model.parameters(), reference_model.parameters()): + assert torch.equal(param.data, ref_param.data) + grad_value = torch.randn_like(param) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + param.main_grad = grad_value.clone().detach() + ref_param.main_grad = grad_value.clone().detach() + + # step() with async_allgather=True: updates but no allgather + optimizer.step() + # Manually sync params via bucket infrastructure + model.start_param_sync(force_sync=True) + + reference_optimizer.step() + + # Verify updated values match reference optimizer + for param, ref_param in zip(model.parameters(), reference_model.parameters()): + torch.testing.assert_close(param.data, ref_param.data, rtol=1e-5, atol=1e-5) + + def test_overlap_param_gather_vs_sync_allgather(self): + """Key correctness test: overlap path and sync allgather produce identical updates. + + Compares: + - Overlap path: async_allgather=True, bucket-based param sync + - Sync path: async_allgather=False, optimizer.allgather_params() in step() + """ + # Create overlap model + overlap_model, overlap_optimizer, pg_collection = ( + self.create_model_and_optimizer_with_overlap_param_gather(async_allgather=True) + ) + + # Create sync model with same weights (use_layer_wise_optimizer=True but sync allgather) + sync_model, sync_optimizer, _ = ( + self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=False, copy_from=overlap_model + ) + ) + + # Verify initial parameters match + for op, sp in zip(overlap_model.parameters(), sync_model.parameters()): + assert torch.equal(op.data, sp.data) + + # Set identical gradients on both + for op, sp in zip(overlap_model.parameters(), sync_model.parameters()): + grad_value = torch.randn_like(op) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + op.main_grad = grad_value.clone().detach() + sp.main_grad = grad_value.clone().detach() + + # Overlap path: step + manual sync + overlap_optimizer.step() + overlap_model.start_param_sync(force_sync=True) + + # Sync path: step (includes allgather_params) + sync_optimizer.step() + + # Both paths should produce identical parameter values + for op, sp in zip(overlap_model.parameters(), sync_model.parameters()): + torch.testing.assert_close( + op.data, sp.data, rtol=0, atol=0, + msg="Overlap and sync allgather paths produced different parameter updates", + ) + + def test_overlap_param_gather_bucket_lw_params(self): + """Verify bucket.lw_params_list is populated when async_allgather is enabled.""" + model, optimizer, pg_collection = ( + self.create_model_and_optimizer_with_overlap_param_gather() + ) + + dp_size = get_pg_size(pg_collection.dp_cp) + + for bucket_group in model.bucket_groups: + for bucket in bucket_group.buckets: + # lw_params_list should be populated by set_bucket_lw_params_list + assert bucket.lw_params_list is not None, ( + "bucket.lw_params_list should be populated" + ) + assert len(bucket.lw_params_list) == dp_size, ( + f"Expected {dp_size} per-rank lists, got {len(bucket.lw_params_list)}" + ) + + # The union of all per-rank param lists should cover all bucket params + all_lw_params = set() + for rank_params in bucket.lw_params_list: + for p in rank_params: + all_lw_params.add(p) + assert all_lw_params == bucket.params, ( + "Union of per-rank lw_params should equal bucket params" + ) + + # lw_param_flat_sizes should be populated and have correct length + assert bucket.lw_param_flat_sizes is not None + assert len(bucket.lw_param_flat_sizes) == dp_size + + # Each flat size should equal the sum of param numels for that rank + for rank_idx in range(dp_size): + expected_size = sum( + p.numel() for p in bucket.lw_params_list[rank_idx] + ) + assert bucket.lw_param_flat_sizes[rank_idx] == expected_size, ( + f"Rank {rank_idx}: expected flat_size {expected_size}, " + f"got {bucket.lw_param_flat_sizes[rank_idx]}" + ) + + def test_overlap_param_gather_vs_standard_ddp(self): + """Verify DDP with use_layer_wise_optimizer=True produces same results as standard DDP. + + Both use LayerWiseDistributedOptimizer but with different DDP configs: + - Overlap path: use_layer_wise_optimizer=True (padded buffers) + - Standard path: use_layer_wise_optimizer=False (unpadded buffers) + """ + # Create overlap-param-gather model (sync allgather for simpler comparison) + opg_model, opg_optimizer, pg_collection = ( + self.create_model_and_optimizer_with_overlap_param_gather(async_allgather=False) + ) + + # Create standard model with same weights + std_model, std_optimizer, _ = self.create_model_and_optimizer(copy_from=opg_model) + + # Set identical gradients + for op, sp in zip(opg_model.parameters(), std_model.parameters()): + assert torch.equal(op.data, sp.data) + grad_value = torch.randn_like(op) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + op.main_grad = grad_value.clone().detach() + sp.main_grad = grad_value.clone().detach() + + opg_optimizer.step() + std_optimizer.step() + + # Both should produce identical parameter values + for op, sp in zip(opg_model.parameters(), std_model.parameters()): + torch.testing.assert_close( + op.data, sp.data, rtol=1e-5, atol=1e-5, + msg="Overlap-param-gather and standard paths produced different updates", + ) + + def test_overlap_param_gather_insufficient_parameters(self): + """Test overlap-param-gather with TinyModel (only 2 params). + + Many ranks will have no assigned params when world_size > 2. + """ + model, optimizer, pg_collection = ( + self.create_model_and_optimizer_with_overlap_param_gather(model_class=TinyModel) + ) + + # Create reference model with standard (non-layer-wise) optimizer + reference_model, reference_optimizer, _ = self.create_model_and_optimizer( + model_class=TinyModel, use_layer_wise=False, copy_from=model + ) + + # Set same gradients on both models + for param, ref_param in zip(model.parameters(), reference_model.parameters()): + assert torch.equal(param.data, ref_param.data) + grad_value = torch.randn_like(param) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + param.main_grad = grad_value.clone().detach() + ref_param.main_grad = grad_value.clone().detach() + + optimizer.step() + model.start_param_sync(force_sync=True) + + reference_optimizer.step() + + # Verify updated values match reference optimizer + for param, ref_param in zip(model.parameters(), reference_model.parameters()): + torch.testing.assert_close(param.data, ref_param.data, rtol=1e-5, atol=1e-5) + + def test_overlap_param_gather_broadcast_vs_allgather(self): + """Test overlap-param-gather: allgather vs broadcast produce same results.""" + model, optimizer, pg_collection = ( + self.create_model_and_optimizer_with_overlap_param_gather( + model_class=SimpleModel, async_allgather=False + ) + ) + + # Create reference model with overlap-param-gather path too + reference_model, reference_optimizer, _ = ( + self.create_model_and_optimizer_with_overlap_param_gather( + model_class=SimpleModel, async_allgather=False, copy_from=model + ) + ) + + # Set same gradients on both models + for param, ref_param in zip(model.parameters(), reference_model.parameters()): + assert torch.equal(param.data, ref_param.data) + torch.testing.assert_close(param.data, ref_param.data, rtol=0, atol=0) + grad_value = torch.randn_like(param) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + param.main_grad = grad_value.clone().detach() + ref_param.main_grad = grad_value.clone().detach() + + optimizer.step() + + # Verify at least some parameters were updated + params_updated = 0 + for param, ref_param in zip(model.parameters(), reference_model.parameters()): + if not torch.equal(param.data, ref_param.data): + params_updated += 1 + + assert params_updated > 0, "At least some parameters should be updated" + + # step() internally calls allgather_params. Replace reference with broadcast. + reference_optimizer.allgather_params = reference_optimizer.broadcast_params + reference_optimizer.step() + + # Verify updated values match reference optimizer + for param, ref_param in zip(model.parameters(), reference_model.parameters()): + torch.testing.assert_close(param.data, ref_param.data, rtol=0, atol=0) + + def test_overlap_param_gather_multi_iteration(self): + """Test overlap-param-gather correctness over multiple training iterations. + + Runs multiple forward/backward/step iterations using the async allgather path. + After each iteration, manually syncs params and verifies they match a reference + model using the sync path. + """ + model, optimizer, pg_collection = ( + self.create_model_and_optimizer_with_overlap_param_gather(async_allgather=True) + ) + + # Create reference model with sync allgather for comparison + ref_model, ref_optimizer, _ = ( + self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=False, copy_from=model + ) + ) + + for iteration in range(3): + # Set identical gradients on both models + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + grad_value = torch.randn_like(param) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + param.main_grad = grad_value.clone().detach() + ref_param.main_grad = grad_value.clone().detach() + + # Async path: step (no allgather) + manual sync + optimizer.step() + model.start_param_sync(force_sync=True) + + # Sync path: step (includes allgather) + ref_optimizer.step() + + # Verify parameters match after each iteration + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + torch.testing.assert_close( + param.data, ref_param.data, rtol=0, atol=0, + msg=f"Parameters diverged at iteration {iteration}", + ) From fa0ceb97ed45efa78efd51aab027bdddf1759bda Mon Sep 17 00:00:00 2001 From: Mike Chrzanowski Date: Tue, 17 Feb 2026 20:41:32 -0800 Subject: [PATCH 04/25] Remove use_layer_wise_optimizer from DDP config Replace with overlap_param_gather and !use_distributed_optimizer checks, which already capture the same semantics without a dedicated flag. Co-Authored-By: Claude Opus 4.6 --- .../core/distributed/distributed_data_parallel.py | 10 ++-------- .../distributed/distributed_data_parallel_config.py | 3 --- megatron/core/distributed/param_and_grad_buffer.py | 10 +++++----- megatron/training/training.py | 4 ---- tests/unit_tests/test_layer_wise_optimizer.py | 13 ++++++------- 5 files changed, 13 insertions(+), 27 deletions(-) diff --git a/megatron/core/distributed/distributed_data_parallel.py b/megatron/core/distributed/distributed_data_parallel.py index 8425d1548df..e2a19832acc 100644 --- a/megatron/core/distributed/distributed_data_parallel.py +++ b/megatron/core/distributed/distributed_data_parallel.py @@ -236,10 +236,7 @@ def _allocate_buffers_for_parameters( # Set `next_param_gather_bucket_group` for different bucket groups by iterating through # buckets in reverse order (since all-gathers happen in reverse order of buckets). - if ( - self.ddp_config.use_distributed_optimizer - or self.ddp_config.use_layer_wise_optimizer - ) and self.ddp_config.overlap_param_gather: + if self.ddp_config.overlap_param_gather: num_bucket_groups = len(bucket_groups) for i in range(1, num_bucket_groups): bucket_groups[num_bucket_groups - i].next_param_gather_bucket_group = ( @@ -347,10 +344,7 @@ def unmap_weight_tensor(m): grad_acc.register_hook(self._make_backward_post_hook(param)) self.grad_accs.append(grad_acc) - self.use_forward_hook = ( - (self.ddp_config.use_distributed_optimizer or self.ddp_config.use_layer_wise_optimizer) - and self.ddp_config.overlap_param_gather - ) + self.use_forward_hook = self.ddp_config.overlap_param_gather self.remove_forward_pre_hook_handles = {} if self.use_forward_hook: self.enable_forward_pre_hook() diff --git a/megatron/core/distributed/distributed_data_parallel_config.py b/megatron/core/distributed/distributed_data_parallel_config.py index 67d2534ee39..c4b25b9f85c 100644 --- a/megatron/core/distributed/distributed_data_parallel_config.py +++ b/megatron/core/distributed/distributed_data_parallel_config.py @@ -27,9 +27,6 @@ class DistributedDataParallelConfig: originally allocated model parameters, otherwise issue all-reduce collectives. """ - use_layer_wise_optimizer: bool = False - """If true, use layer-wise distributed optimizer for param all-gather overlap.""" - num_distributed_optimizer_instances: int = 1 """Sets the factor by which the DP domain is sharded to have the partial DistOpt enabled. Defaults to 1, which means DistOpt is across entire DP domain. diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py index 95b2c17884b..8b8ff3291b4 100644 --- a/megatron/core/distributed/param_and_grad_buffer.py +++ b/megatron/core/distributed/param_and_grad_buffer.py @@ -176,7 +176,7 @@ def __init__( self.buckets = buckets self.ddp_config = ddp_config - if self.ddp_config.use_distributed_optimizer or self.ddp_config.use_layer_wise_optimizer: + if self.ddp_config.use_distributed_optimizer or self.ddp_config.overlap_param_gather: self.intra_distributed_optimizer_instance_group = collective_group self.intra_distributed_optimizer_instance_size = collective_group_size self.intra_distributed_optimizer_instance_rank = collective_group.rank() @@ -300,7 +300,7 @@ def start_param_sync(self, force_sync: bool = False): force_sync (bool, optional): force synchronous collective regardless of other settings if true. """ - assert self.ddp_config.use_distributed_optimizer or self.ddp_config.use_layer_wise_optimizer + assert self.ddp_config.use_distributed_optimizer or self.ddp_config.overlap_param_gather if force_sync: if self.param_gather_handle is not None: @@ -312,7 +312,7 @@ def start_param_sync(self, force_sync: bool = False): async_op = self.ddp_config.overlap_param_gather and not force_sync - if self.ddp_config.use_layer_wise_optimizer: + if not self.ddp_config.use_distributed_optimizer: # Layer-wise optimizer path: do NOT use _coalescing_manager. # # torch.distributed.all_gather with a tensor-list output internally creates a @@ -413,7 +413,7 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False): skip_next_bucket_dispatch (bool, optional): if true, dispatch next bucket's communication if available. """ - assert self.ddp_config.use_distributed_optimizer or self.ddp_config.use_layer_wise_optimizer + assert self.ddp_config.use_distributed_optimizer or self.ddp_config.overlap_param_gather assert self.ddp_config.overlap_param_gather # If current bucket's param AG has not been dispatched, dispatch it now (e.g., first @@ -451,7 +451,7 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False): # correspond to multiple param buffers. If we zero out the entire grad buffer, # it would clear the data of those param buffers that have not yet completed AG. bucket.param_data.zero_() - elif self.ddp_config.use_layer_wise_optimizer: + elif not self.ddp_config.use_distributed_optimizer: for bucket in self.buckets: # Unflatten and copy gathered params for each rank. for idx, (flat_params, params) in enumerate( diff --git a/megatron/training/training.py b/megatron/training/training.py index c5a81722c96..3e8b8c14ee0 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -1374,10 +1374,6 @@ def build_model(): kwargs['average_in_collective'] = args.ddp_average_in_collective ddp_config = DistributedDataParallelConfig(**kwargs) - # Enable layer-wise optimizer flag for distributed muon/mop variants. - if 'dist' in args.optimizer: - ddp_config.use_layer_wise_optimizer = True - # In the Megatron FSDP and DDP use path, we need to initialize the bucket size. # If bucket_size is not provided as an input, use sane default. # If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL diff --git a/tests/unit_tests/test_layer_wise_optimizer.py b/tests/unit_tests/test_layer_wise_optimizer.py index 50bc971f3ac..d67703e1ad3 100644 --- a/tests/unit_tests/test_layer_wise_optimizer.py +++ b/tests/unit_tests/test_layer_wise_optimizer.py @@ -140,8 +140,8 @@ def create_model_and_optimizer_with_overlap_param_gather( ): """Create model, DDP wrapper, and optimizer with overlap-param-gather enabled. - This variant sets use_layer_wise_optimizer=True and overlap_param_gather=True - in DDP config and passes model_chunks=[model] + async_allgather to + This variant sets overlap_param_gather=True in DDP config and passes + model_chunks=[model] + async_allgather to LayerWiseDistributedOptimizer, enabling the bucket-based async param gather path. Args: @@ -162,7 +162,6 @@ def create_model_and_optimizer_with_overlap_param_gather( ddp_config = DistributedDataParallelConfig( use_distributed_optimizer=False, - use_layer_wise_optimizer=True, overlap_param_gather=True, ) model = DistributedDataParallel( @@ -595,7 +594,7 @@ def test_overlap_param_gather_vs_sync_allgather(self): self.create_model_and_optimizer_with_overlap_param_gather(async_allgather=True) ) - # Create sync model with same weights (use_layer_wise_optimizer=True but sync allgather) + # Create sync model with same weights (overlap_param_gather=True but sync allgather) sync_model, sync_optimizer, _ = ( self.create_model_and_optimizer_with_overlap_param_gather( async_allgather=False, copy_from=overlap_model @@ -669,11 +668,11 @@ def test_overlap_param_gather_bucket_lw_params(self): ) def test_overlap_param_gather_vs_standard_ddp(self): - """Verify DDP with use_layer_wise_optimizer=True produces same results as standard DDP. + """Verify DDP with overlap_param_gather=True produces same results as standard DDP. Both use LayerWiseDistributedOptimizer but with different DDP configs: - - Overlap path: use_layer_wise_optimizer=True (padded buffers) - - Standard path: use_layer_wise_optimizer=False (unpadded buffers) + - Overlap path: overlap_param_gather=True (padded buffers) + - Standard path: overlap_param_gather=False (unpadded buffers) """ # Create overlap-param-gather model (sync allgather for simpler comparison) opg_model, opg_optimizer, pg_collection = ( From bbed683fabc4807a5f7b22e7cc5df8f506aeb563 Mon Sep 17 00:00:00 2001 From: Mike Chrzanowski Date: Tue, 17 Feb 2026 20:48:59 -0800 Subject: [PATCH 05/25] Add comments explaining overlap_param_gather replacing use_layer_wise_optimizer Co-Authored-By: Claude Opus 4.6 --- megatron/core/distributed/distributed_data_parallel.py | 6 ++++++ megatron/core/distributed/param_and_grad_buffer.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/megatron/core/distributed/distributed_data_parallel.py b/megatron/core/distributed/distributed_data_parallel.py index e2a19832acc..32067347d44 100644 --- a/megatron/core/distributed/distributed_data_parallel.py +++ b/megatron/core/distributed/distributed_data_parallel.py @@ -236,6 +236,9 @@ def _allocate_buffers_for_parameters( # Set `next_param_gather_bucket_group` for different bucket groups by iterating through # buckets in reverse order (since all-gathers happen in reverse order of buckets). + # Note: overlap_param_gather covers both the distributed optimizer and the + # layer-wise optimizer cases; the latter sets overlap_param_gather=True + # without use_distributed_optimizer. if self.ddp_config.overlap_param_gather: num_bucket_groups = len(bucket_groups) for i in range(1, num_bucket_groups): @@ -344,6 +347,9 @@ def unmap_weight_tensor(m): grad_acc.register_hook(self._make_backward_post_hook(param)) self.grad_accs.append(grad_acc) + # Note: overlap_param_gather covers both the distributed optimizer and the + # layer-wise optimizer cases; the latter sets overlap_param_gather=True + # without use_distributed_optimizer. self.use_forward_hook = self.ddp_config.overlap_param_gather self.remove_forward_pre_hook_handles = {} if self.use_forward_hook: diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py index 8b8ff3291b4..cc9454dccd0 100644 --- a/megatron/core/distributed/param_and_grad_buffer.py +++ b/megatron/core/distributed/param_and_grad_buffer.py @@ -176,6 +176,8 @@ def __init__( self.buckets = buckets self.ddp_config = ddp_config + # overlap_param_gather covers the layer-wise optimizer case, which sets + # overlap_param_gather=True without use_distributed_optimizer. if self.ddp_config.use_distributed_optimizer or self.ddp_config.overlap_param_gather: self.intra_distributed_optimizer_instance_group = collective_group self.intra_distributed_optimizer_instance_size = collective_group_size @@ -300,6 +302,8 @@ def start_param_sync(self, force_sync: bool = False): force_sync (bool, optional): force synchronous collective regardless of other settings if true. """ + # overlap_param_gather covers the layer-wise optimizer case, which sets + # overlap_param_gather=True without use_distributed_optimizer. assert self.ddp_config.use_distributed_optimizer or self.ddp_config.overlap_param_gather if force_sync: @@ -413,6 +417,8 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False): skip_next_bucket_dispatch (bool, optional): if true, dispatch next bucket's communication if available. """ + # overlap_param_gather covers the layer-wise optimizer case, which sets + # overlap_param_gather=True without use_distributed_optimizer. assert self.ddp_config.use_distributed_optimizer or self.ddp_config.overlap_param_gather assert self.ddp_config.overlap_param_gather From 46bc317c6f89eb3c375a18f62469f18a060eeb3a Mon Sep 17 00:00:00 2001 From: Mike Chrzanowski Date: Fri, 20 Feb 2026 12:19:39 -0800 Subject: [PATCH 06/25] Run autoformat (black, isort) on changed files Co-Authored-By: Claude Opus 4.6 --- .../core/distributed/param_and_grad_buffer.py | 6 +- .../core/optimizer/layer_wise_optimizer.py | 10 +-- tests/unit_tests/test_layer_wise_optimizer.py | 73 ++++++++++--------- 3 files changed, 44 insertions(+), 45 deletions(-) diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py index cc9454dccd0..e8cfccf8afe 100644 --- a/megatron/core/distributed/param_and_grad_buffer.py +++ b/megatron/core/distributed/param_and_grad_buffer.py @@ -10,8 +10,8 @@ from typing import Dict, List, Optional import torch -from torch.distributed import _coalescing_manager from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from torch.distributed import _coalescing_manager import megatron.core.nccl_allocator as nccl_allocator from megatron.core import parallel_state @@ -333,9 +333,7 @@ def start_param_sync(self, force_sync: bool = False): _flatten_dense_tensors(bucket.lw_params_list[local_rank]) if len(bucket.lw_params_list[local_rank]) > 0 else torch.empty( - 0, - device=bucket.grad_data.device, - dtype=bucket.grad_data.dtype, + 0, device=bucket.grad_data.device, dtype=bucket.grad_data.dtype ) ) # Keep src alive until the async operation completes. diff --git a/megatron/core/optimizer/layer_wise_optimizer.py b/megatron/core/optimizer/layer_wise_optimizer.py index 84a37a4e439..42fc6e2b8a0 100644 --- a/megatron/core/optimizer/layer_wise_optimizer.py +++ b/megatron/core/optimizer/layer_wise_optimizer.py @@ -66,9 +66,9 @@ def __init__( # Set up async all-gather using DDP bucket infrastructure. self.async_allgather = async_allgather if self.async_allgather: - assert model_chunks is not None, ( - "model_chunks must be provided if async_allgather is True" - ) + assert ( + model_chunks is not None + ), "model_chunks must be provided if async_allgather is True" self.set_bucket_lw_params_list(model_chunks) if init_state_fn_list: @@ -168,9 +168,7 @@ def set_bucket_lw_params_list(self, model_chunks): for model_chunk in model_chunks: for group in model_chunk.bucket_groups: for bucket in group.buckets: - bucket_params_list = [ - [] for _ in range(get_pg_size(self.pg_collection.dp_cp)) - ] + bucket_params_list = [[] for _ in range(get_pg_size(self.pg_collection.dp_cp))] for bucket_list, full_params_list in zip( bucket_params_list, self.dp_cp_params_list ): diff --git a/tests/unit_tests/test_layer_wise_optimizer.py b/tests/unit_tests/test_layer_wise_optimizer.py index d67703e1ad3..043b7f12588 100644 --- a/tests/unit_tests/test_layer_wise_optimizer.py +++ b/tests/unit_tests/test_layer_wise_optimizer.py @@ -161,8 +161,7 @@ def create_model_and_optimizer_with_overlap_param_gather( model.requires_grad_(True) ddp_config = DistributedDataParallelConfig( - use_distributed_optimizer=False, - overlap_param_gather=True, + use_distributed_optimizer=False, overlap_param_gather=True ) model = DistributedDataParallel( TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, model @@ -188,7 +187,9 @@ def create_model_and_optimizer_with_overlap_param_gather( optimizer = get_megatron_optimizer(optimizer_config, [model]) optimizer_config.bf16 = True optimizer = LayerWiseDistributedOptimizer( - optimizer.chained_optimizers, optimizer_config, pg_collection, + optimizer.chained_optimizers, + optimizer_config, + pg_collection, model_chunks=[model], async_allgather=async_allgather, ) @@ -548,7 +549,8 @@ def test_overlap_param_gather_basic(self): for i in range(1, dp_size): torch.testing.assert_close( - param_list[0], param_list[i], + param_list[0], + param_list[i], msg=f"Parameter {name} differs between rank 0 and rank {i}", ) @@ -595,10 +597,8 @@ def test_overlap_param_gather_vs_sync_allgather(self): ) # Create sync model with same weights (overlap_param_gather=True but sync allgather) - sync_model, sync_optimizer, _ = ( - self.create_model_and_optimizer_with_overlap_param_gather( - async_allgather=False, copy_from=overlap_model - ) + sync_model, sync_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=False, copy_from=overlap_model ) # Verify initial parameters match @@ -622,7 +622,10 @@ def test_overlap_param_gather_vs_sync_allgather(self): # Both paths should produce identical parameter values for op, sp in zip(overlap_model.parameters(), sync_model.parameters()): torch.testing.assert_close( - op.data, sp.data, rtol=0, atol=0, + op.data, + sp.data, + rtol=0, + atol=0, msg="Overlap and sync allgather paths produced different parameter updates", ) @@ -637,21 +640,21 @@ def test_overlap_param_gather_bucket_lw_params(self): for bucket_group in model.bucket_groups: for bucket in bucket_group.buckets: # lw_params_list should be populated by set_bucket_lw_params_list - assert bucket.lw_params_list is not None, ( - "bucket.lw_params_list should be populated" - ) - assert len(bucket.lw_params_list) == dp_size, ( - f"Expected {dp_size} per-rank lists, got {len(bucket.lw_params_list)}" - ) + assert ( + bucket.lw_params_list is not None + ), "bucket.lw_params_list should be populated" + assert ( + len(bucket.lw_params_list) == dp_size + ), f"Expected {dp_size} per-rank lists, got {len(bucket.lw_params_list)}" # The union of all per-rank param lists should cover all bucket params all_lw_params = set() for rank_params in bucket.lw_params_list: for p in rank_params: all_lw_params.add(p) - assert all_lw_params == bucket.params, ( - "Union of per-rank lw_params should equal bucket params" - ) + assert ( + all_lw_params == bucket.params + ), "Union of per-rank lw_params should equal bucket params" # lw_param_flat_sizes should be populated and have correct length assert bucket.lw_param_flat_sizes is not None @@ -659,9 +662,7 @@ def test_overlap_param_gather_bucket_lw_params(self): # Each flat size should equal the sum of param numels for that rank for rank_idx in range(dp_size): - expected_size = sum( - p.numel() for p in bucket.lw_params_list[rank_idx] - ) + expected_size = sum(p.numel() for p in bucket.lw_params_list[rank_idx]) assert bucket.lw_param_flat_sizes[rank_idx] == expected_size, ( f"Rank {rank_idx}: expected flat_size {expected_size}, " f"got {bucket.lw_param_flat_sizes[rank_idx]}" @@ -696,7 +697,10 @@ def test_overlap_param_gather_vs_standard_ddp(self): # Both should produce identical parameter values for op, sp in zip(opg_model.parameters(), std_model.parameters()): torch.testing.assert_close( - op.data, sp.data, rtol=1e-5, atol=1e-5, + op.data, + sp.data, + rtol=1e-5, + atol=1e-5, msg="Overlap-param-gather and standard paths produced different updates", ) @@ -705,8 +709,8 @@ def test_overlap_param_gather_insufficient_parameters(self): Many ranks will have no assigned params when world_size > 2. """ - model, optimizer, pg_collection = ( - self.create_model_and_optimizer_with_overlap_param_gather(model_class=TinyModel) + model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( + model_class=TinyModel ) # Create reference model with standard (non-layer-wise) optimizer @@ -733,10 +737,8 @@ def test_overlap_param_gather_insufficient_parameters(self): def test_overlap_param_gather_broadcast_vs_allgather(self): """Test overlap-param-gather: allgather vs broadcast produce same results.""" - model, optimizer, pg_collection = ( - self.create_model_and_optimizer_with_overlap_param_gather( - model_class=SimpleModel, async_allgather=False - ) + model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( + model_class=SimpleModel, async_allgather=False ) # Create reference model with overlap-param-gather path too @@ -780,15 +782,13 @@ def test_overlap_param_gather_multi_iteration(self): After each iteration, manually syncs params and verifies they match a reference model using the sync path. """ - model, optimizer, pg_collection = ( - self.create_model_and_optimizer_with_overlap_param_gather(async_allgather=True) + model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=True ) # Create reference model with sync allgather for comparison - ref_model, ref_optimizer, _ = ( - self.create_model_and_optimizer_with_overlap_param_gather( - async_allgather=False, copy_from=model - ) + ref_model, ref_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=False, copy_from=model ) for iteration in range(3): @@ -809,6 +809,9 @@ def test_overlap_param_gather_multi_iteration(self): # Verify parameters match after each iteration for param, ref_param in zip(model.parameters(), ref_model.parameters()): torch.testing.assert_close( - param.data, ref_param.data, rtol=0, atol=0, + param.data, + ref_param.data, + rtol=0, + atol=0, msg=f"Parameters diverged at iteration {iteration}", ) From 3e88f38393a5e60fb5a9cf8ce3a74b15f80bbe42 Mon Sep 17 00:00:00 2001 From: Mike Chrzanowski Date: Sun, 22 Feb 2026 09:48:43 -0800 Subject: [PATCH 07/25] Free overlap param-gather buffers before async checkpoint save to fix OOM The overlap_param_gather feature allocates large temporary GPU buffers during the forward pass for async all-gather operations. Although freed after finish_param_sync, PyTorch's CUDA allocator caches the memory, which is invisible to the async checkpoint worker's CUDA context. This caused OOM during D2H tensor transfers, preventing checkpoints from finalizing and trapping the run in a restart loop. Add free_overlap_buffers() to _ParamAndGradBucketGroup and DDP that explicitly releases these buffers, and call it + torch.cuda.empty_cache() in save_checkpoint_and_time() before the save_checkpoint() call. Co-Authored-By: Claude Opus 4.6 --- .../distributed/distributed_data_parallel.py | 5 +++++ .../core/distributed/param_and_grad_buffer.py | 16 ++++++++++++++++ megatron/training/training.py | 7 +++++++ 3 files changed, 28 insertions(+) diff --git a/megatron/core/distributed/distributed_data_parallel.py b/megatron/core/distributed/distributed_data_parallel.py index 32067347d44..f4647d764aa 100644 --- a/megatron/core/distributed/distributed_data_parallel.py +++ b/megatron/core/distributed/distributed_data_parallel.py @@ -538,6 +538,11 @@ def finish_grad_sync(self, force_all_reduce: Optional[bool] = False): for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: bucket_group.finish_grad_sync(force_all_reduce=force_all_reduce) + def free_overlap_buffers(self): + """Free overlap param-gather GPU buffers across all bucket groups.""" + for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: + bucket_group.free_overlap_buffers() + def scale_gradients(self, scaling_factor: float): """Scale all gradients inside the buffers by `scaling_factor`.""" for buffer in self.buffers + self.expert_parallel_buffers: diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py index e8cfccf8afe..9535a669fd0 100644 --- a/megatron/core/distributed/param_and_grad_buffer.py +++ b/megatron/core/distributed/param_and_grad_buffer.py @@ -656,6 +656,22 @@ def finish_grad_sync(self, force_all_reduce: Optional[bool] = False): self.grad_reduce_handle.wait() self.grad_reduce_handle = None + def free_overlap_buffers(self): + """Free GPU buffers used by overlap param gather. + + Waits on any pending param all-gather handle, then releases the + per-bucket temporary buffers so that the CUDA memory allocator can + reclaim them. Called before async checkpoint saves to avoid OOM in + the persistent checkpoint worker process. + """ + if self.param_gather_handle is not None: + self.param_gather_handle.wait() + self.param_gather_handle = None + for bucket in self.buckets: + if bucket.lw_gather_tensor_list is not None: + bucket.lw_gather_tensor_list.clear() + bucket._lw_src_buffer = None + def register_grad_ready( self, param: torch.nn.Parameter, force_all_reduce: Optional[bool] = False ): diff --git a/megatron/training/training.py b/megatron/training/training.py index 3e8b8c14ee0..1d0b15df13f 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -2257,6 +2257,13 @@ def save_checkpoint_and_time( one_logger_utils.track_e2e_metrics() if should_disable_forward_pre_hook(args): force_param_sync(model) + # Free overlap param-gather buffers and release cached GPU memory so + # that the async checkpoint worker process has enough GPU headroom for + # D2H tensor transfers. + for model_chunk in model: + if hasattr(model_chunk, 'free_overlap_buffers'): + model_chunk.free_overlap_buffers() + torch.cuda.empty_cache() global num_checkpoints_memory_reported, MAX_NUM_CHECKPOINTS_MEMORY_REPORTED should_report_memory = num_checkpoints_memory_reported < MAX_NUM_CHECKPOINTS_MEMORY_REPORTED From 3bf616c018645a407fab7770b6e0ef79c44fc7de Mon Sep 17 00:00:00 2001 From: Mike Chrzanowski Date: Sun, 22 Feb 2026 09:54:46 -0800 Subject: [PATCH 08/25] Add unit tests for free_overlap_buffers Tests that free_overlap_buffers(): - Clears lw_gather_tensor_list and nulls _lw_src_buffer on each bucket - Waits on any pending param_gather_handle before freeing - Is safe to call when no buffers are allocated (noop) - DDP.free_overlap_buffers delegates to all bucket groups Co-Authored-By: Claude Opus 4.6 --- .../distributed/test_param_and_grad_buffer.py | 94 +++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/tests/unit_tests/distributed/test_param_and_grad_buffer.py b/tests/unit_tests/distributed/test_param_and_grad_buffer.py index b60dfb1791b..cab2f06c3a0 100644 --- a/tests/unit_tests/distributed/test_param_and_grad_buffer.py +++ b/tests/unit_tests/distributed/test_param_and_grad_buffer.py @@ -373,3 +373,97 @@ def test_force_all_reduce_uses_correct_collective(force_all_reduce: bool): ), "Expected all_reduce NOT to be called when force_all_reduce=False" Utils.destroy_model_parallel() + + +class TestFreeOverlapBuffers: + """Tests for free_overlap_buffers() which releases GPU memory before async checkpoint saves.""" + + @staticmethod + def _make_model(): + """Create a DDP-wrapped model with overlap_param_gather enabled.""" + Utils.initialize_model_parallel() + ddp_config = DistributedDataParallelConfig( + grad_reduce_in_fp32=True, + use_distributed_optimizer=False, + overlap_grad_reduce=True, + overlap_param_gather=True, + bucket_size=None, + ) + module = TestModel( + input_dim=32, output_dim=32, num_layers=2, bias=False, shared_embedding=False, + ).bfloat16() + model = DistributedDataParallel( + TransformerConfig(num_attention_heads=1, num_layers=1), + ddp_config=ddp_config, + module=module, + ) + return model + + def test_bucket_group_clears_buffers(self): + """free_overlap_buffers on a bucket group should None-out per-bucket lw buffers.""" + model = self._make_model() + + for bg in model.bucket_groups: + # Simulate buffers that would be allocated by start_param_sync. + for bucket in bg.buckets: + bucket.lw_gather_tensor_list = [torch.empty(8), torch.empty(8)] + bucket._lw_src_buffer = torch.empty(16) + + bg.free_overlap_buffers() + + for bucket in bg.buckets: + assert ( + bucket.lw_gather_tensor_list is not None + and len(bucket.lw_gather_tensor_list) == 0 + ), "lw_gather_tensor_list should be empty after free_overlap_buffers" + assert ( + bucket._lw_src_buffer is None + ), "_lw_src_buffer should be None after free_overlap_buffers" + + Utils.destroy_model_parallel() + + def test_bucket_group_waits_on_pending_handle(self): + """free_overlap_buffers should wait() on any pending param_gather_handle.""" + model = self._make_model() + + for bg in model.bucket_groups: + mock_handle = mock.MagicMock() + bg.param_gather_handle = mock_handle + + bg.free_overlap_buffers() + + mock_handle.wait.assert_called_once() + assert bg.param_gather_handle is None, ( + "param_gather_handle should be None after free_overlap_buffers" + ) + + Utils.destroy_model_parallel() + + def test_bucket_group_noop_when_no_buffers(self): + """free_overlap_buffers should be safe to call when no buffers are allocated.""" + model = self._make_model() + + for bg in model.bucket_groups: + assert bg.param_gather_handle is None + for bucket in bg.buckets: + assert bucket.lw_gather_tensor_list is None + assert bucket._lw_src_buffer is None + + # Should not raise. + bg.free_overlap_buffers() + + Utils.destroy_model_parallel() + + def test_ddp_free_overlap_buffers_delegates(self): + """DDP.free_overlap_buffers should call free_overlap_buffers on all bucket groups.""" + model = self._make_model() + + with mock.patch.object( + type(model.bucket_groups[0]), 'free_overlap_buffers' + ) as mock_free: + model.free_overlap_buffers() + assert mock_free.call_count == len( + model.bucket_groups + model.expert_parallel_bucket_groups + ), "free_overlap_buffers should be called on every bucket group" + + Utils.destroy_model_parallel() From f111804d83f81757251135702dc8b778ace5ca88 Mon Sep 17 00:00:00 2001 From: Mike Chrzanowski Date: Sun, 22 Feb 2026 15:57:36 -0800 Subject: [PATCH 09/25] Replace all_gather with per-rank broadcasts for layer-wise param gather MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Each rank may own different numbers of params per bucket, creating variable flat sizes. NCCL's all_gather requires uniform sizes. Replace with dp_size broadcast calls per bucket: each rank broadcasts its actual-size flattened params to all others. This eliminates padding and uses only collectives (no P2P, which can deadlock with subsequent collectives on the same NCCL communicator). Memory cost: sum(lw_param_flat_sizes) per bucket — no padding waste. Co-Authored-By: Claude Opus 4.6 --- .../core/distributed/param_and_grad_buffer.py | 98 ++++++++++++------- .../core/optimizer/layer_wise_optimizer.py | 46 ++++++--- 2 files changed, 99 insertions(+), 45 deletions(-) diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py index 9535a669fd0..afc50c56de2 100644 --- a/megatron/core/distributed/param_and_grad_buffer.py +++ b/megatron/core/distributed/param_and_grad_buffer.py @@ -132,15 +132,10 @@ def set_lw_params_list(self, lw_params_list: List[List[torch.nn.Parameter]]): class _LayerWiseAllGatherHandle: - """Handle for multiple async all-gather operations used by the layer-wise optimizer. - - torch.distributed.all_gather with a tensor list output internally creates a temporary - contiguous buffer and copies chunks back to the individual output tensors when wait() - is called on the returned work handle. When wrapped in _coalescing_manager, this - copy-back step is lost because the coalescing manager's handle only waits on the - NCCL operations, not the individual copy-back callbacks. This class stores the - individual work handles so that wait() properly triggers the copy-back for each - all_gather call. + """Handle for multiple async broadcast operations used by the layer-wise optimizer. + + Wraps the list of work handles returned by per-rank broadcast calls so that + a single wait() call waits on all outstanding operations. """ def __init__(self, handles): @@ -317,18 +312,27 @@ def start_param_sync(self, force_sync: bool = False): async_op = self.ddp_config.overlap_param_gather and not force_sync if not self.ddp_config.use_distributed_optimizer: - # Layer-wise optimizer path: do NOT use _coalescing_manager. + # Layer-wise optimizer path: use per-rank broadcasts for + # variable-size param gather. + # + # Each rank may own a different number of params per bucket, so + # lw_param_flat_sizes can vary across ranks. NCCL's ncclAllGather + # requires uniform send sizes, so we cannot use all_gather directly. # - # torch.distributed.all_gather with a tensor-list output internally creates a - # temporary contiguous buffer, calls ncclAllGather into it, and then copies chunks - # back to the individual output tensors when wait() is called on the work handle. - # When wrapped in _coalescing_manager, the coalescing manager's handle only waits - # on the grouped NCCL operations but does NOT trigger the per-op copy-back step, - # leaving the output tensors uninitialized. We avoid this by calling all_gather - # directly and storing the individual work handles. + # Instead, we issue dp_size broadcasts per bucket: for each rank i, + # all ranks call broadcast with rank i as the source. On the source + # rank the buffer is the flattened local params; on other ranks the + # buffer is a correctly-sized receive allocation. This avoids padding + # and uses only collectives (no P2P send/recv, which can deadlock + # with subsequent collectives on the same NCCL communicator). + # + # Memory cost: sum(lw_param_flat_sizes) elements per bucket (the + # receive buffers for each remote rank, plus the local flattened src). + dp_size = self.intra_distributed_optimizer_instance_size + local_rank = self.intra_distributed_optimizer_instance_rank + group = self.intra_distributed_optimizer_instance_group lw_work_handles = [] for bucket in self.buckets: - local_rank = self.intra_distributed_optimizer_instance_rank src = ( _flatten_dense_tensors(bucket.lw_params_list[local_rank]) if len(bucket.lw_params_list[local_rank]) > 0 @@ -338,30 +342,58 @@ def start_param_sync(self, force_sync: bool = False): ) # Keep src alive until the async operation completes. bucket._lw_src_buffer = src - bucket.lw_gather_tensor_list = [ - torch.empty(size, device=src.device, dtype=src.dtype) - for size in bucket.lw_param_flat_sizes - ] - work = torch.distributed.all_gather( - bucket.lw_gather_tensor_list, - src, - group=self.intra_distributed_optimizer_instance_group, - async_op=async_op, - ) - if async_op and work is not None: - lw_work_handles.append(work) + + if max(bucket.lw_param_flat_sizes) == 0: + # All ranks have empty params for this bucket — skip. + bucket.lw_gather_tensor_list = [ + torch.empty(0, device=src.device, dtype=src.dtype) + for _ in range(dp_size) + ] + continue + + # Allocate per-rank receive buffers (actual sizes, NO padding). + gather_list = [] + for i in range(dp_size): + if i == local_rank: + gather_list.append( + torch.empty(0, device=src.device, dtype=src.dtype) + ) + else: + gather_list.append( + torch.empty( + bucket.lw_param_flat_sizes[i], + device=src.device, + dtype=src.dtype, + ) + ) + bucket.lw_gather_tensor_list = gather_list + + # Broadcast each rank's params to all other ranks. + for i in range(dp_size): + if bucket.lw_param_flat_sizes[i] == 0: + continue + src_global = torch.distributed.get_global_rank(group, i) + if i == local_rank: + buf = src + else: + buf = gather_list[i] + work = torch.distributed.broadcast( + buf, src_global, group=group, async_op=async_op, + ) + if async_op and work is not None: + lw_work_handles.append(work) + if async_op: self.param_gather_handle = _LayerWiseAllGatherHandle(lw_work_handles) else: - # Synchronous layer-wise case (e.g., force_sync=True for checkpointing): - # unflatten and copy gathered params immediately. + # Synchronous: unflatten and copy gathered params immediately. for bucket in self.buckets: for idx, (flat_params, params) in enumerate( zip(bucket.lw_gather_tensor_list, bucket.lw_params_list) ): if ( len(params) == 0 - or idx == self.intra_distributed_optimizer_instance_rank + or idx == local_rank ): continue updated_params = _unflatten_dense_tensors(flat_params, params) diff --git a/megatron/core/optimizer/layer_wise_optimizer.py b/megatron/core/optimizer/layer_wise_optimizer.py index 42fc6e2b8a0..9a756c7a119 100644 --- a/megatron/core/optimizer/layer_wise_optimizer.py +++ b/megatron/core/optimizer/layer_wise_optimizer.py @@ -195,30 +195,52 @@ def set_bucket_lw_params_list(self, model_chunks): def allgather_params(self) -> None: """All-gather updated params from all ranks.""" - # helper function to flatten local params, allgather, unflatten and copy to model params + # helper function to flatten local params, broadcast to gather, + # unflatten and copy to model params def _allgather_helper(params_list, group): - # flatten this rank's params and create empty tensor output list device = params_list[0][0].device dtype = params_list[0][0].dtype rank = get_pg_rank(group) - # for rank without params create empty tensor and participate in allgather + dp_size = get_pg_size(group) + # Flatten this rank's params. src = ( _flatten_dense_tensors(params_list[rank]) if len(params_list[rank]) > 0 else torch.empty(0, device=device, dtype=dtype) ) - output_list = [ - torch.empty(sum([p.numel() for p in params]), device=device, dtype=dtype) - for params in params_list + flat_sizes = [ + sum(p.numel() for p in params) for params in params_list ] - # single all_gather_v to collect all updated params - torch.distributed.all_gather(output_list, src, group=group) - # unflatten and copy gathered params for each rank i - for idx, (flat_params, params) in enumerate(zip(output_list, params_list)): - # skip local params and empty tensors + if max(flat_sizes) == 0: + return + + # Allocate per-rank receive buffers (actual sizes, NO padding). + recv_buffers = [] + for i in range(dp_size): + if i == rank: + recv_buffers.append(None) + else: + recv_buffers.append( + torch.empty(flat_sizes[i], device=device, dtype=dtype) + ) + + # Broadcast each rank's params to all other ranks. + # All ranks must participate in every broadcast call (collective). + for i in range(dp_size): + if flat_sizes[i] == 0: + continue + src_global = torch.distributed.get_global_rank(group, i) + if i == rank: + buf = src + else: + buf = recv_buffers[i] + torch.distributed.broadcast(buf, src_global, group=group) + + # Unflatten and copy gathered params for each rank. + for idx, params in enumerate(params_list): if len(params) == 0 or idx == rank: continue - updated_params = _unflatten_dense_tensors(flat_params, params) + updated_params = _unflatten_dense_tensors(recv_buffers[idx], params) for updated_p, model_p in zip(updated_params, params): model_p.data.copy_(updated_p) From 211a0cada67b524770d7d54336d05168781763c9 Mon Sep 17 00:00:00 2001 From: Mike Chrzanowski Date: Sun, 22 Feb 2026 16:01:16 -0800 Subject: [PATCH 10/25] Autoformat: black formatting fixes Co-Authored-By: Claude Opus 4.6 --- .../core/distributed/param_and_grad_buffer.py | 18 +++++------------- .../core/optimizer/layer_wise_optimizer.py | 8 ++------ .../distributed/test_param_and_grad_buffer.py | 12 +++++------- 3 files changed, 12 insertions(+), 26 deletions(-) diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py index afc50c56de2..1f7d9c2420c 100644 --- a/megatron/core/distributed/param_and_grad_buffer.py +++ b/megatron/core/distributed/param_and_grad_buffer.py @@ -346,8 +346,7 @@ def start_param_sync(self, force_sync: bool = False): if max(bucket.lw_param_flat_sizes) == 0: # All ranks have empty params for this bucket — skip. bucket.lw_gather_tensor_list = [ - torch.empty(0, device=src.device, dtype=src.dtype) - for _ in range(dp_size) + torch.empty(0, device=src.device, dtype=src.dtype) for _ in range(dp_size) ] continue @@ -355,15 +354,11 @@ def start_param_sync(self, force_sync: bool = False): gather_list = [] for i in range(dp_size): if i == local_rank: - gather_list.append( - torch.empty(0, device=src.device, dtype=src.dtype) - ) + gather_list.append(torch.empty(0, device=src.device, dtype=src.dtype)) else: gather_list.append( torch.empty( - bucket.lw_param_flat_sizes[i], - device=src.device, - dtype=src.dtype, + bucket.lw_param_flat_sizes[i], device=src.device, dtype=src.dtype ) ) bucket.lw_gather_tensor_list = gather_list @@ -378,7 +373,7 @@ def start_param_sync(self, force_sync: bool = False): else: buf = gather_list[i] work = torch.distributed.broadcast( - buf, src_global, group=group, async_op=async_op, + buf, src_global, group=group, async_op=async_op ) if async_op and work is not None: lw_work_handles.append(work) @@ -391,10 +386,7 @@ def start_param_sync(self, force_sync: bool = False): for idx, (flat_params, params) in enumerate( zip(bucket.lw_gather_tensor_list, bucket.lw_params_list) ): - if ( - len(params) == 0 - or idx == local_rank - ): + if len(params) == 0 or idx == local_rank: continue updated_params = _unflatten_dense_tensors(flat_params, params) for updated_p, model_p in zip(updated_params, params): diff --git a/megatron/core/optimizer/layer_wise_optimizer.py b/megatron/core/optimizer/layer_wise_optimizer.py index 9a756c7a119..53bffb24abc 100644 --- a/megatron/core/optimizer/layer_wise_optimizer.py +++ b/megatron/core/optimizer/layer_wise_optimizer.py @@ -208,9 +208,7 @@ def _allgather_helper(params_list, group): if len(params_list[rank]) > 0 else torch.empty(0, device=device, dtype=dtype) ) - flat_sizes = [ - sum(p.numel() for p in params) for params in params_list - ] + flat_sizes = [sum(p.numel() for p in params) for params in params_list] if max(flat_sizes) == 0: return @@ -220,9 +218,7 @@ def _allgather_helper(params_list, group): if i == rank: recv_buffers.append(None) else: - recv_buffers.append( - torch.empty(flat_sizes[i], device=device, dtype=dtype) - ) + recv_buffers.append(torch.empty(flat_sizes[i], device=device, dtype=dtype)) # Broadcast each rank's params to all other ranks. # All ranks must participate in every broadcast call (collective). diff --git a/tests/unit_tests/distributed/test_param_and_grad_buffer.py b/tests/unit_tests/distributed/test_param_and_grad_buffer.py index cab2f06c3a0..aa83da28553 100644 --- a/tests/unit_tests/distributed/test_param_and_grad_buffer.py +++ b/tests/unit_tests/distributed/test_param_and_grad_buffer.py @@ -390,7 +390,7 @@ def _make_model(): bucket_size=None, ) module = TestModel( - input_dim=32, output_dim=32, num_layers=2, bias=False, shared_embedding=False, + input_dim=32, output_dim=32, num_layers=2, bias=False, shared_embedding=False ).bfloat16() model = DistributedDataParallel( TransformerConfig(num_attention_heads=1, num_layers=1), @@ -433,9 +433,9 @@ def test_bucket_group_waits_on_pending_handle(self): bg.free_overlap_buffers() mock_handle.wait.assert_called_once() - assert bg.param_gather_handle is None, ( - "param_gather_handle should be None after free_overlap_buffers" - ) + assert ( + bg.param_gather_handle is None + ), "param_gather_handle should be None after free_overlap_buffers" Utils.destroy_model_parallel() @@ -458,9 +458,7 @@ def test_ddp_free_overlap_buffers_delegates(self): """DDP.free_overlap_buffers should call free_overlap_buffers on all bucket groups.""" model = self._make_model() - with mock.patch.object( - type(model.bucket_groups[0]), 'free_overlap_buffers' - ) as mock_free: + with mock.patch.object(type(model.bucket_groups[0]), 'free_overlap_buffers') as mock_free: model.free_overlap_buffers() assert mock_free.call_count == len( model.bucket_groups + model.expert_parallel_bucket_groups From c7ee958bb2c84d7f6cd59e152f8f82e5b45a6cca Mon Sep 17 00:00:00 2001 From: Mike Chrzanowski Date: Sun, 22 Feb 2026 17:53:04 -0800 Subject: [PATCH 11/25] Remove assertions blocking overlap_grad_reduce and overlap_param_gather with muon optimizer Co-Authored-By: Claude Opus 4.6 --- megatron/training/arguments.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 45df3b6b2b0..93dc68906f3 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1304,10 +1304,6 @@ def validate_args(args, defaults={}): # Muon optimizer check if 'muon' in args.optimizer: - # TODO: remove these checks once we support them - assert not args.overlap_grad_reduce, "Muon optimizer does not support overlap grad reduce for now." - assert not args.overlap_param_gather, "Muon optimizer does not support overlap param gather for now." - assert not args.use_distributed_optimizer, "Muon optimizer does not support distributed optimizer for now." assert not args.use_torch_fsdp2, "Muon optimizer does not support Torch-FSDP2 for now." assert not args.use_megatron_fsdp, "Muon optimizer does not support Megatron-FSDP for now." From cbed167fc60ede4fd28dc29b571226d7ba971545 Mon Sep 17 00:00:00 2001 From: Mike Chrzanowski Date: Sun, 22 Feb 2026 18:57:12 -0800 Subject: [PATCH 12/25] Fix dtype mismatch in layer-wise param gather broadcasts causing NCCL deadlock MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When a rank has no local params in a bucket, the empty src tensor was created with bucket.grad_data.dtype (fp32 when grad_reduce_in_fp32=True) instead of the actual param dtype (bf16). This caused receive buffers on ranks without local params to be fp32 while the broadcasting rank sends bf16 data, resulting in different buffer byte sizes across ranks in the same broadcast collective — an NCCL deadlock. The fix uses bucket.params_list[0].dtype to always get the correct param dtype, matching the dtype produced by _flatten_dense_tensors on ranks that do have local params. Co-Authored-By: Claude Opus 4.6 --- megatron/core/distributed/param_and_grad_buffer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py index 1f7d9c2420c..c8ec0bdc913 100644 --- a/megatron/core/distributed/param_and_grad_buffer.py +++ b/megatron/core/distributed/param_and_grad_buffer.py @@ -333,11 +333,16 @@ def start_param_sync(self, force_sync: bool = False): group = self.intra_distributed_optimizer_instance_group lw_work_handles = [] for bucket in self.buckets: + # Use param dtype (e.g., bf16), NOT grad dtype (which may be fp32 + # when grad_reduce_in_fp32 is enabled). All ranks must use the same + # dtype for broadcast buffers, and ranks with local params will have + # bf16 src tensors from _flatten_dense_tensors. + param_dtype = bucket.params_list[0].dtype src = ( _flatten_dense_tensors(bucket.lw_params_list[local_rank]) if len(bucket.lw_params_list[local_rank]) > 0 else torch.empty( - 0, device=bucket.grad_data.device, dtype=bucket.grad_data.dtype + 0, device=bucket.grad_data.device, dtype=param_dtype ) ) # Keep src alive until the async operation completes. From 7726a355db4d393d8055b2e3c8ccae7b8a64f800 Mon Sep 17 00:00:00 2001 From: Mike Chrzanowski Date: Sun, 22 Feb 2026 23:49:29 -0800 Subject: [PATCH 13/25] Fix timing-dependent NCCL deadlock in layer-wise param gather by waiting only on last broadcast handle All per-rank broadcasts within a single start_param_sync() call are issued on the same NCCL communicator stream, so NCCL guarantees in-order completion. Waiting sequentially on each of the ~64 individual handles caused intermediate CUDA stream synchronizations that created a timing-dependent deadlock across ranks. Waiting on only the last handle is sufficient. Also removes debug logging added during investigation. Co-Authored-By: Claude Opus 4.6 --- megatron/core/distributed/param_and_grad_buffer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py index c8ec0bdc913..57c23207823 100644 --- a/megatron/core/distributed/param_and_grad_buffer.py +++ b/megatron/core/distributed/param_and_grad_buffer.py @@ -142,8 +142,12 @@ def __init__(self, handles): self.handles = handles def wait(self): - for h in self.handles: - h.wait() + # All broadcasts are on the same NCCL communicator stream, so NCCL + # guarantees in-order completion. Waiting on only the last handle is + # sufficient and avoids intermediate CUDA stream synchronizations that + # can cause timing-dependent deadlocks across ranks. + if self.handles: + self.handles[-1].wait() class _ParamAndGradBucketGroup: From 5253153bd90c4189ebd7a7a44874406d8c163b75 Mon Sep 17 00:00:00 2001 From: Mike Chrzanowski Date: Mon, 23 Feb 2026 09:08:29 -0800 Subject: [PATCH 14/25] Add missing regression tests for layer-wise optimizer overlap param gather Cover gaps in test coverage: async dispatch + finish_param_sync cycle, finish_param_sync chaining to next bucket group, forward pre-hook integration, grad_reduce_in_fp32 dtype mismatch regression, hook enable/disable lifecycle, and multi-iteration with hooks. Co-Authored-By: Claude Opus 4.6 --- tests/unit_tests/test_layer_wise_optimizer.py | 322 +++++++++++++++++- 1 file changed, 321 insertions(+), 1 deletion(-) diff --git a/tests/unit_tests/test_layer_wise_optimizer.py b/tests/unit_tests/test_layer_wise_optimizer.py index 043b7f12588..eef979daadf 100644 --- a/tests/unit_tests/test_layer_wise_optimizer.py +++ b/tests/unit_tests/test_layer_wise_optimizer.py @@ -137,6 +137,8 @@ def create_model_and_optimizer_with_overlap_param_gather( model_kwargs=None, copy_from=None, async_allgather=True, + grad_reduce_in_fp32=False, + bucket_size=None, ): """Create model, DDP wrapper, and optimizer with overlap-param-gather enabled. @@ -150,6 +152,8 @@ def create_model_and_optimizer_with_overlap_param_gather( model_kwargs: Optional kwargs for model initialization copy_from: Optional DDP model to copy weights from async_allgather: If True, defer param all-gather to bucket infrastructure + grad_reduce_in_fp32: If True, reduce grads in fp32 (regression test for dtype fix) + bucket_size: Maximum number of parameters per bucket (None = single bucket) Returns: tuple: (model, optimizer, pg_collection) @@ -161,7 +165,10 @@ def create_model_and_optimizer_with_overlap_param_gather( model.requires_grad_(True) ddp_config = DistributedDataParallelConfig( - use_distributed_optimizer=False, overlap_param_gather=True + use_distributed_optimizer=False, + overlap_param_gather=True, + grad_reduce_in_fp32=grad_reduce_in_fp32, + bucket_size=bucket_size, ) model = DistributedDataParallel( TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, model @@ -815,3 +822,316 @@ def test_overlap_param_gather_multi_iteration(self): atol=0, msg=f"Parameters diverged at iteration {iteration}", ) + + def test_overlap_param_gather_async_dispatch_and_finish(self): + """Test async dispatch + finish_param_sync cycle (the actual runtime path). + + Exercises _LayerWiseAllGatherHandle.wait() through the async dispatch path: + start_param_sync() (no force_sync) dispatches async broadcasts, then + finish_param_sync() waits on the handle and unflattens gathered params. + """ + model, optimizer, pg_collection = ( + self.create_model_and_optimizer_with_overlap_param_gather(async_allgather=True) + ) + ref_model, ref_optimizer, _ = ( + self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=False, copy_from=model + ) + ) + + # Set identical gradients on both models + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + grad_value = torch.randn_like(param) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + param.main_grad = grad_value.clone().detach() + ref_param.main_grad = grad_value.clone().detach() + + # Async path: step (no allgather) + async dispatch + explicit finish + optimizer.step() + model.start_param_sync() # async dispatch to all bucket groups + for bucket_group in model.bucket_groups: + bucket_group.finish_param_sync(skip_next_bucket_dispatch=True) + + # Sync path: step (includes allgather) + ref_optimizer.step() + + # Verify params match sync path + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + torch.testing.assert_close( + param.data, + ref_param.data, + rtol=0, + atol=0, + msg="Async dispatch + finish path produced different params than sync path", + ) + + # Verify all ranks have identical parameters + dp_size = get_pg_size(pg_collection.dp_cp) + if dp_size > 1: + for name, param in model.named_parameters(): + param_list = [torch.zeros_like(param.data) for _ in range(dp_size)] + torch.distributed.all_gather(param_list, param.data, group=pg_collection.dp_cp) + for i in range(1, dp_size): + torch.testing.assert_close( + param_list[0], + param_list[i], + msg=f"Parameter {name} differs between rank 0 and rank {i}", + ) + + def test_overlap_param_gather_finish_chains_next_bucket(self): + """Test that finish_param_sync() dispatches next_param_gather_bucket_group. + + Uses a small bucket_size to force multiple bucket groups, then dispatches + only the last bucket group and verifies that finishing it chains to the next. + """ + model, optimizer, pg_collection = ( + self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=True, bucket_size=2000 + ) + ) + + bucket_groups = model.bucket_groups + if len(bucket_groups) <= 1: + pytest.skip("Need multiple bucket groups to test chaining") + + ref_model, ref_optimizer, _ = ( + self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=False, copy_from=model, bucket_size=2000 + ) + ) + + # Set identical gradients on both models + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + grad_value = torch.randn_like(param) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + param.main_grad = grad_value.clone().detach() + ref_param.main_grad = grad_value.clone().detach() + + optimizer.step() + + # Dispatch ONLY the last bucket group (which has next_param_gather_bucket_group set) + last_bg = bucket_groups[-1] + last_bg.start_param_sync() + + # Verify: next bucket group has NOT been dispatched yet + next_bg = last_bg.next_param_gather_bucket_group + assert next_bg is not None, "Last bucket group should have a next" + assert not next_bg.param_gather_dispatched, "Next bucket should not be dispatched yet" + + # Finish the last bucket group — should chain-dispatch the next one + last_bg.finish_param_sync() + + # Verify: next bucket group IS now dispatched via chaining + assert next_bg.param_gather_dispatched, ( + "finish_param_sync should have dispatched next bucket group" + ) + + # Finish remaining bucket groups through the chain + for bg in reversed(bucket_groups[:-1]): + bg.finish_param_sync() + + # Reference: sync step + ref_optimizer.step() + + # Verify params match + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + torch.testing.assert_close( + param.data, + ref_param.data, + rtol=0, + atol=0, + msg="Chained bucket finish produced different params than sync path", + ) + + def test_overlap_param_gather_forward_pre_hook(self): + """Test forward pre-hooks trigger finish_param_sync during model(input). + + After async dispatch, running model(input) fires forward pre-hooks that + call finish_param_sync() on each bucket group, completing the param sync. + """ + model, optimizer, pg_collection = ( + self.create_model_and_optimizer_with_overlap_param_gather(async_allgather=True) + ) + ref_model, ref_optimizer, _ = ( + self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=False, copy_from=model + ) + ) + + # Set identical gradients on both models + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + grad_value = torch.randn_like(param) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + param.main_grad = grad_value.clone().detach() + ref_param.main_grad = grad_value.clone().detach() + + # Async path: step (no allgather) + async dispatch + optimizer.step() + model.start_param_sync() # dispatch async broadcasts + + # Forward pass triggers hooks that call finish_param_sync() + input_tensor = torch.randn(16, 80, dtype=torch.bfloat16, device='cuda') + output = model(input_tensor) + + # Sync path: step (includes allgather) + ref_optimizer.step() + + # Verify params match + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + torch.testing.assert_close( + param.data, + ref_param.data, + rtol=0, + atol=0, + msg="Forward pre-hook path produced different params than sync path", + ) + + def test_overlap_param_gather_grad_reduce_in_fp32(self): + """Regression test: grad_reduce_in_fp32 must not cause dtype mismatch in broadcasts. + + When grad_reduce_in_fp32=True, the grad buffer dtype is fp32 but broadcast + buffers must use param dtype (bf16). Without the fix (commit cbed167fc), this + would cause a dtype mismatch error in the per-rank broadcast calls. + """ + model, optimizer, pg_collection = ( + self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=True, grad_reduce_in_fp32=True + ) + ) + ref_model, ref_optimizer, _ = ( + self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=False, copy_from=model, grad_reduce_in_fp32=True + ) + ) + + # Set identical gradients on both models + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + grad_value = torch.randn_like(param) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + param.main_grad = grad_value.clone().detach() + ref_param.main_grad = grad_value.clone().detach() + + # Async path: step + force_sync + optimizer.step() + model.start_param_sync(force_sync=True) + + # Sync path: step (includes allgather) + ref_optimizer.step() + + # Verify params match + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + torch.testing.assert_close( + param.data, + ref_param.data, + rtol=1e-5, + atol=1e-5, + msg="grad_reduce_in_fp32 path produced different params than reference", + ) + + def test_overlap_param_gather_hook_enable_disable_cycle(self): + """Test the training loop's hook lifecycle: disable → manual sync → enable → forward. + + The training loop disables hooks before iteration 1 (for initialization), + then enables them for subsequent iterations. This test exercises that cycle. + """ + model, optimizer, pg_collection = ( + self.create_model_and_optimizer_with_overlap_param_gather(async_allgather=True) + ) + ref_model, ref_optimizer, _ = ( + self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=False, copy_from=model + ) + ) + + input_tensor = torch.randn(16, 80, dtype=torch.bfloat16, device='cuda') + + # Iteration 1: hooks disabled, manual sync + model.disable_forward_pre_hook(param_sync=False) + + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + grad_value = torch.randn_like(param) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + param.main_grad = grad_value.clone().detach() + ref_param.main_grad = grad_value.clone().detach() + + optimizer.step() + model.start_param_sync(force_sync=True) # manual sync + + ref_optimizer.step() + + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + torch.testing.assert_close( + param.data, + ref_param.data, + rtol=0, + atol=0, + msg="Params diverged after iteration 1 (hooks disabled)", + ) + + # Iteration 2: hooks re-enabled, forward pass triggers sync + model.enable_forward_pre_hook() + + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + grad_value = torch.randn_like(param) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + param.main_grad = grad_value.clone().detach() + ref_param.main_grad = grad_value.clone().detach() + + optimizer.step() + model.start_param_sync() # async dispatch + output = model(input_tensor) # hooks finish sync + + ref_optimizer.step() + + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + torch.testing.assert_close( + param.data, + ref_param.data, + rtol=0, + atol=0, + msg="Params diverged after iteration 2 (hooks re-enabled)", + ) + + def test_overlap_param_gather_multi_iteration_with_hooks(self): + """Test multiple iterations using forward pre-hooks (not manual force_sync). + + Runs 3 iterations where each iteration uses: set grads → step → async dispatch → + forward pass (hooks wait+unflatten). Compares against reference model using sync + allgather after each iteration. + """ + model, optimizer, pg_collection = ( + self.create_model_and_optimizer_with_overlap_param_gather(async_allgather=True) + ) + ref_model, ref_optimizer, _ = ( + self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=False, copy_from=model + ) + ) + + input_tensor = torch.randn(16, 80, dtype=torch.bfloat16, device='cuda') + + for iteration in range(3): + # Set identical gradients on both models + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + grad_value = torch.randn_like(param) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + param.main_grad = grad_value.clone().detach() + ref_param.main_grad = grad_value.clone().detach() + + # Async path: step + dispatch + forward (hooks wait+unflatten) + optimizer.step() + model.start_param_sync() # async dispatch + output = model(input_tensor) # hooks trigger finish_param_sync + + # Sync path: step (includes allgather) + ref_optimizer.step() + + # Verify parameters match after each iteration + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + torch.testing.assert_close( + param.data, + ref_param.data, + rtol=0, + atol=0, + msg=f"Parameters diverged at iteration {iteration}", + ) From 189428b15e988cee7c15a1ba756d5862b82787b0 Mon Sep 17 00:00:00 2001 From: Mike Chrzanowski Date: Mon, 23 Feb 2026 09:08:58 -0800 Subject: [PATCH 15/25] Autoformat: black formatting fixes Co-Authored-By: Claude Opus 4.6 --- tests/unit_tests/test_layer_wise_optimizer.py | 70 +++++++------------ 1 file changed, 27 insertions(+), 43 deletions(-) diff --git a/tests/unit_tests/test_layer_wise_optimizer.py b/tests/unit_tests/test_layer_wise_optimizer.py index eef979daadf..baa0a57df6d 100644 --- a/tests/unit_tests/test_layer_wise_optimizer.py +++ b/tests/unit_tests/test_layer_wise_optimizer.py @@ -830,13 +830,11 @@ def test_overlap_param_gather_async_dispatch_and_finish(self): start_param_sync() (no force_sync) dispatches async broadcasts, then finish_param_sync() waits on the handle and unflattens gathered params. """ - model, optimizer, pg_collection = ( - self.create_model_and_optimizer_with_overlap_param_gather(async_allgather=True) + model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=True ) - ref_model, ref_optimizer, _ = ( - self.create_model_and_optimizer_with_overlap_param_gather( - async_allgather=False, copy_from=model - ) + ref_model, ref_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=False, copy_from=model ) # Set identical gradients on both models @@ -884,20 +882,16 @@ def test_overlap_param_gather_finish_chains_next_bucket(self): Uses a small bucket_size to force multiple bucket groups, then dispatches only the last bucket group and verifies that finishing it chains to the next. """ - model, optimizer, pg_collection = ( - self.create_model_and_optimizer_with_overlap_param_gather( - async_allgather=True, bucket_size=2000 - ) + model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=True, bucket_size=2000 ) bucket_groups = model.bucket_groups if len(bucket_groups) <= 1: pytest.skip("Need multiple bucket groups to test chaining") - ref_model, ref_optimizer, _ = ( - self.create_model_and_optimizer_with_overlap_param_gather( - async_allgather=False, copy_from=model, bucket_size=2000 - ) + ref_model, ref_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=False, copy_from=model, bucket_size=2000 ) # Set identical gradients on both models @@ -922,9 +916,9 @@ def test_overlap_param_gather_finish_chains_next_bucket(self): last_bg.finish_param_sync() # Verify: next bucket group IS now dispatched via chaining - assert next_bg.param_gather_dispatched, ( - "finish_param_sync should have dispatched next bucket group" - ) + assert ( + next_bg.param_gather_dispatched + ), "finish_param_sync should have dispatched next bucket group" # Finish remaining bucket groups through the chain for bg in reversed(bucket_groups[:-1]): @@ -949,13 +943,11 @@ def test_overlap_param_gather_forward_pre_hook(self): After async dispatch, running model(input) fires forward pre-hooks that call finish_param_sync() on each bucket group, completing the param sync. """ - model, optimizer, pg_collection = ( - self.create_model_and_optimizer_with_overlap_param_gather(async_allgather=True) + model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=True ) - ref_model, ref_optimizer, _ = ( - self.create_model_and_optimizer_with_overlap_param_gather( - async_allgather=False, copy_from=model - ) + ref_model, ref_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=False, copy_from=model ) # Set identical gradients on both models @@ -993,15 +985,11 @@ def test_overlap_param_gather_grad_reduce_in_fp32(self): buffers must use param dtype (bf16). Without the fix (commit cbed167fc), this would cause a dtype mismatch error in the per-rank broadcast calls. """ - model, optimizer, pg_collection = ( - self.create_model_and_optimizer_with_overlap_param_gather( - async_allgather=True, grad_reduce_in_fp32=True - ) + model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=True, grad_reduce_in_fp32=True ) - ref_model, ref_optimizer, _ = ( - self.create_model_and_optimizer_with_overlap_param_gather( - async_allgather=False, copy_from=model, grad_reduce_in_fp32=True - ) + ref_model, ref_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=False, copy_from=model, grad_reduce_in_fp32=True ) # Set identical gradients on both models @@ -1034,13 +1022,11 @@ def test_overlap_param_gather_hook_enable_disable_cycle(self): The training loop disables hooks before iteration 1 (for initialization), then enables them for subsequent iterations. This test exercises that cycle. """ - model, optimizer, pg_collection = ( - self.create_model_and_optimizer_with_overlap_param_gather(async_allgather=True) + model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=True ) - ref_model, ref_optimizer, _ = ( - self.create_model_and_optimizer_with_overlap_param_gather( - async_allgather=False, copy_from=model - ) + ref_model, ref_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=False, copy_from=model ) input_tensor = torch.randn(16, 80, dtype=torch.bfloat16, device='cuda') @@ -1099,13 +1085,11 @@ def test_overlap_param_gather_multi_iteration_with_hooks(self): forward pass (hooks wait+unflatten). Compares against reference model using sync allgather after each iteration. """ - model, optimizer, pg_collection = ( - self.create_model_and_optimizer_with_overlap_param_gather(async_allgather=True) + model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=True ) - ref_model, ref_optimizer, _ = ( - self.create_model_and_optimizer_with_overlap_param_gather( - async_allgather=False, copy_from=model - ) + ref_model, ref_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=False, copy_from=model ) input_tensor = torch.randn(16, 80, dtype=torch.bfloat16, device='cuda') From d07bb5799e0105ab92b5d789dd0eb511dbcab27b Mon Sep 17 00:00:00 2001 From: Mike Chrzanowski Date: Mon, 23 Feb 2026 09:26:07 -0800 Subject: [PATCH 16/25] Switch layer-wise optimizer tests from adam to dist_muon Replace all adam optimizer usage with get_megatron_muon_optimizer (muon for 2D weights + adam for 1D biases), which is the actual production use case for LayerWiseDistributedOptimizer. Co-Authored-By: Claude Opus 4.6 --- tests/unit_tests/test_layer_wise_optimizer.py | 125 ++++++++---------- 1 file changed, 56 insertions(+), 69 deletions(-) diff --git a/tests/unit_tests/test_layer_wise_optimizer.py b/tests/unit_tests/test_layer_wise_optimizer.py index baa0a57df6d..1a1f2622105 100644 --- a/tests/unit_tests/test_layer_wise_optimizer.py +++ b/tests/unit_tests/test_layer_wise_optimizer.py @@ -10,9 +10,10 @@ from megatron.core import parallel_state from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig -from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer +from megatron.core.optimizer import OptimizerConfig from megatron.core.optimizer.layer_wise_optimizer import LayerWiseDistributedOptimizer -from megatron.core.optimizer.optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer +from megatron.core.optimizer.muon import get_megatron_muon_optimizer +from megatron.core.optimizer.optimizer import Float16OptimizerWithFloat16Params from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer import TransformerConfig from megatron.core.utils import get_pg_rank, get_pg_size @@ -88,8 +89,8 @@ def create_model_and_optimizer( model_class: Model class to instantiate clip_grad: Optional gradient clipping value model_kwargs: Optional kwargs for model initialization - use_layer_wise: If True, wrap optimizer in LayerWiseDistributedOptimizer; - if False, use get_megatron_optimizer instead (for reference) + use_layer_wise: If True, use LayerWiseDistributedOptimizer via dist_muon; + if False, use standard muon ChainedOptimizer (for reference) Returns: tuple: (model, optimizer, pg_collection) @@ -110,24 +111,26 @@ def create_model_and_optimizer( model.broadcast_params() optimizer_config = OptimizerConfig( - optimizer='adam', + optimizer='muon', lr=0.01, weight_decay=0.01, - bf16=not use_layer_wise, + bf16=True, use_distributed_optimizer=False, clip_grad=clip_grad, + muon_tp_mode="duplicated", ) pg_collection = ProcessGroupCollection.use_mpu_process_groups() pg_collection.dp_cp = parallel_state.get_data_parallel_group(with_context_parallel=True) pg_collection.expt_dp = parallel_state.get_expert_data_parallel_group() - optimizer = get_megatron_optimizer(optimizer_config, [model]) - if use_layer_wise: - optimizer_config.bf16 = True - optimizer = LayerWiseDistributedOptimizer( - optimizer.chained_optimizers, optimizer_config, pg_collection - ) + optimizer = get_megatron_muon_optimizer( + config=optimizer_config, + model_chunks=[model], + use_gloo_process_groups=True, + layer_wise_distributed_optimizer=use_layer_wise, + pg_collection=pg_collection, + ) return model, optimizer, pg_collection def create_model_and_optimizer_with_overlap_param_gather( @@ -142,9 +145,9 @@ def create_model_and_optimizer_with_overlap_param_gather( ): """Create model, DDP wrapper, and optimizer with overlap-param-gather enabled. - This variant sets overlap_param_gather=True in DDP config and passes - model_chunks=[model] + async_allgather to - LayerWiseDistributedOptimizer, enabling the bucket-based async param gather path. + This variant sets overlap_param_gather=True in DDP config and uses + get_megatron_muon_optimizer with layer_wise_distributed_optimizer=True, + enabling the bucket-based async param gather path. Args: model_class: Model class to instantiate @@ -179,26 +182,26 @@ def create_model_and_optimizer_with_overlap_param_gather( model.broadcast_params() optimizer_config = OptimizerConfig( - optimizer='adam', + optimizer='muon', lr=0.01, weight_decay=0.01, - bf16=False, + bf16=True, use_distributed_optimizer=False, clip_grad=clip_grad, + overlap_param_gather=async_allgather, + muon_tp_mode="duplicated", ) pg_collection = ProcessGroupCollection.use_mpu_process_groups() pg_collection.dp_cp = parallel_state.get_data_parallel_group(with_context_parallel=True) pg_collection.expt_dp = parallel_state.get_expert_data_parallel_group() - optimizer = get_megatron_optimizer(optimizer_config, [model]) - optimizer_config.bf16 = True - optimizer = LayerWiseDistributedOptimizer( - optimizer.chained_optimizers, - optimizer_config, - pg_collection, + optimizer = get_megatron_muon_optimizer( + config=optimizer_config, model_chunks=[model], - async_allgather=async_allgather, + use_gloo_process_groups=True, + layer_wise_distributed_optimizer=True, + pg_collection=pg_collection, ) return model, optimizer, pg_collection @@ -333,42 +336,13 @@ def test_sharded_state_dict(self): def test_multiple_optimizers(self): """Test LayerWiseDistributedOptimizer with multiple chained optimizers. - This test properly tests allgather functionality with multiple ranks. + Uses get_megatron_muon_optimizer which produces multiple chained optimizers + (muon for 2D weights + adam for 1D biases). Tests allgather with multiple ranks. """ - model = SimpleModel().bfloat16().cuda() - model.requires_grad_(True) - - ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=False) - model = DistributedDataParallel( - TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, model - ) - - optimizer_config = OptimizerConfig( - optimizer='adam', lr=0.01, bf16=True, use_distributed_optimizer=False - ) - - # Split parameters into two groups for testing multiple optimizers - params = list(model.parameters()) - mid_point = len(params) // 2 - param_groups_1 = [{'params': params[:mid_point]}] - param_groups_2 = [{'params': params[mid_point:]}] - - # Create two separate base optimizers - base_optimizer_1 = torch.optim.Adam(param_groups_1, lr=optimizer_config.lr) - base_optimizer_2 = torch.optim.Adam(param_groups_2, lr=optimizer_config.lr) - - wrapped_optimizer_1 = FP32Optimizer(base_optimizer_1, optimizer_config, None) - wrapped_optimizer_2 = FP32Optimizer(base_optimizer_2, optimizer_config, None) - - pg_collection = ProcessGroupCollection.use_mpu_process_groups() - pg_collection.dp_cp = parallel_state.get_data_parallel_group(with_context_parallel=True) - pg_collection.expt_dp = parallel_state.get_expert_data_parallel_group() - - optimizer = LayerWiseDistributedOptimizer( - [wrapped_optimizer_1, wrapped_optimizer_2], optimizer_config, pg_collection - ) + model, optimizer, pg_collection = self.create_model_and_optimizer() - assert len(optimizer.chained_optimizers) == 2, "Should have two chained optimizers" + # get_megatron_muon_optimizer produces muon + adam chained optimizers + assert len(optimizer.chained_optimizers) >= 2, "Should have multiple chained optimizers" # Set gradients and test optimizer step - this will trigger allgather for param in model.parameters(): @@ -404,26 +378,39 @@ def test_bf16_error(self): TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, model ) + pg_collection = ProcessGroupCollection.use_mpu_process_groups() + pg_collection.dp_cp = parallel_state.get_data_parallel_group(with_context_parallel=True) + pg_collection.expt_dp = parallel_state.get_expert_data_parallel_group() + + # Create muon optimizer (non-layer-wise) — produces Float16-wrapped chained optimizers optimizer_config = OptimizerConfig( - optimizer='adam', lr=0.01, bf16=True, use_distributed_optimizer=False + optimizer='muon', + lr=0.01, + bf16=True, + use_distributed_optimizer=False, + muon_tp_mode="duplicated", ) - - # Create base optimizer and manually wrap in Float16 optimizer - param_groups = [{'params': list(model.parameters())}] - base_optimizer = torch.optim.Adam(param_groups, lr=optimizer_config.lr) - wrapped_optimizer = Float16OptimizerWithFloat16Params( - base_optimizer, optimizer_config, None, None + muon_optimizer = get_megatron_muon_optimizer( + config=optimizer_config, + model_chunks=[model], + use_gloo_process_groups=True, + layer_wise_distributed_optimizer=False, + pg_collection=pg_collection, ) - pg_collection = ProcessGroupCollection.use_mpu_process_groups() - pg_collection.dp_cp = parallel_state.get_data_parallel_group(with_context_parallel=True) - pg_collection.expt_dp = parallel_state.get_expert_data_parallel_group() + # Extract a Float16-wrapped chained optimizer + wrapped_optimizer = muon_optimizer.chained_optimizers[0] + assert isinstance(wrapped_optimizer, Float16OptimizerWithFloat16Params) # Should raise TypeError when receiving already-wrapped Float16 optimizer + # Use a fresh config since get_megatron_muon_optimizer mutates config.optimizer + lw_config = OptimizerConfig( + optimizer='muon', lr=0.01, bf16=True, use_distributed_optimizer=False + ) with pytest.raises( TypeError, match='LayerWiseDistributedOptimizer received Float16 optimizer already' ): - LayerWiseDistributedOptimizer([wrapped_optimizer], optimizer_config, pg_collection) + LayerWiseDistributedOptimizer([wrapped_optimizer], lw_config, pg_collection) def _run_parameter_update_test(self, model_class=SimpleModel): """Helper method to test parameter updates with a given model class. From 3a0afd19c23feb5ca01fa3b90e656d752449866b Mon Sep 17 00:00:00 2001 From: Mike Chrzanowski Date: Mon, 23 Feb 2026 09:47:42 -0800 Subject: [PATCH 17/25] Add back overlap assertions for muon (but not dist_muon) Block --overlap-grad-reduce and --overlap-param-gather when using plain muon optimizer, with message directing users to dist_muon. These flags are only supported with the layer-wise distributed optimizer path (dist_muon). Co-Authored-By: Claude Opus 4.6 --- megatron/training/arguments.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 93dc68906f3..4044f4dd493 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1304,6 +1304,10 @@ def validate_args(args, defaults={}): # Muon optimizer check if 'muon' in args.optimizer: + if args.optimizer == 'muon': + assert not args.overlap_grad_reduce, "Muon optimizer does not support overlap grad reduce. Use dist_muon instead." + assert not args.overlap_param_gather, "Muon optimizer does not support overlap param gather. Use dist_muon instead." + assert not args.use_distributed_optimizer, "Muon optimizer does not support distributed optimizer for now." assert not args.use_torch_fsdp2, "Muon optimizer does not support Torch-FSDP2 for now." assert not args.use_megatron_fsdp, "Muon optimizer does not support Megatron-FSDP for now." From 97b9916af57cc80e5785bbdd66feb577bbd459e1 Mon Sep 17 00:00:00 2001 From: Mike Chrzanowski Date: Mon, 23 Feb 2026 10:12:28 -0800 Subject: [PATCH 18/25] Remove redundant assert in finish_param_sync The first assert (use_distributed_optimizer or overlap_param_gather) is subsumed by the second (overlap_param_gather), so drop it. Co-Authored-By: Claude Opus 4.6 --- megatron/core/distributed/param_and_grad_buffer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py index 57c23207823..a0ae206704b 100644 --- a/megatron/core/distributed/param_and_grad_buffer.py +++ b/megatron/core/distributed/param_and_grad_buffer.py @@ -448,9 +448,6 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False): skip_next_bucket_dispatch (bool, optional): if true, dispatch next bucket's communication if available. """ - # overlap_param_gather covers the layer-wise optimizer case, which sets - # overlap_param_gather=True without use_distributed_optimizer. - assert self.ddp_config.use_distributed_optimizer or self.ddp_config.overlap_param_gather assert self.ddp_config.overlap_param_gather # If current bucket's param AG has not been dispatched, dispatch it now (e.g., first From e5a53a290280f4c210ba2f60b036e8908504617a Mon Sep 17 00:00:00 2001 From: Mike Chrzanowski Date: Tue, 24 Feb 2026 10:12:38 -0800 Subject: [PATCH 19/25] Replace per-rank broadcasts with all_gather for layer-wise param gather Switch from dp_size broadcast calls per bucket to a single torch.distributed.all_gather per bucket. PyTorch's NCCL backend handles uneven tensor sizes internally (via grouped send/recv), so no manual padding is needed. This simplifies the code and addresses reviewer feedback to use all_gather instead of broadcasts. Detach the flattened src tensor from the autograd graph because start_param_sync can be called during the forward pass (where autograd is active) and all_gather writes into gather_list entries in-place. Co-Authored-By: Claude Opus 4.6 --- .../core/distributed/param_and_grad_buffer.py | 126 ++++++++---------- .../core/optimizer/layer_wise_optimizer.py | 26 ++-- .../distributed/test_param_and_grad_buffer.py | 9 +- tests/unit_tests/test_layer_wise_optimizer.py | 72 +++++++++- 4 files changed, 135 insertions(+), 98 deletions(-) diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py index a0ae206704b..2d4d9d7fc0d 100644 --- a/megatron/core/distributed/param_and_grad_buffer.py +++ b/megatron/core/distributed/param_and_grad_buffer.py @@ -116,7 +116,7 @@ def __init__( # Layer-wise optimizer attributes for async param gather. self.lw_params_list = None self.lw_param_flat_sizes = None - self.lw_gather_tensor_list = None + self.lw_gather_list = None self._lw_src_buffer = None def set_lw_params_list(self, lw_params_list: List[List[torch.nn.Parameter]]): @@ -131,21 +131,17 @@ def set_lw_params_list(self, lw_params_list: List[List[torch.nn.Parameter]]): ] -class _LayerWiseAllGatherHandle: - """Handle for multiple async broadcast operations used by the layer-wise optimizer. +class _LWAllGatherHandle: + """Handle wrapping multiple async all-gather work objects. - Wraps the list of work handles returned by per-rank broadcast calls so that - a single wait() call waits on all outstanding operations. + NCCL guarantees in-order completion on the same communicator, so waiting + on only the last handle is sufficient. """ def __init__(self, handles): self.handles = handles def wait(self): - # All broadcasts are on the same NCCL communicator stream, so NCCL - # guarantees in-order completion. Waiting on only the last handle is - # sufficient and avoids intermediate CUDA stream synchronizations that - # can cause timing-dependent deadlocks across ranks. if self.handles: self.handles[-1].wait() @@ -316,91 +312,80 @@ def start_param_sync(self, force_sync: bool = False): async_op = self.ddp_config.overlap_param_gather and not force_sync if not self.ddp_config.use_distributed_optimizer: - # Layer-wise optimizer path: use per-rank broadcasts for - # variable-size param gather. + # Layer-wise optimizer path: use all_gather for variable-size + # param gather. # # Each rank may own a different number of params per bucket, so - # lw_param_flat_sizes can vary across ranks. NCCL's ncclAllGather - # requires uniform send sizes, so we cannot use all_gather directly. - # - # Instead, we issue dp_size broadcasts per bucket: for each rank i, - # all ranks call broadcast with rank i as the source. On the source - # rank the buffer is the flattened local params; on other ranks the - # buffer is a correctly-sized receive allocation. This avoids padding - # and uses only collectives (no P2P send/recv, which can deadlock - # with subsequent collectives on the same NCCL communicator). - # - # Memory cost: sum(lw_param_flat_sizes) elements per bucket (the - # receive buffers for each remote rank, plus the local flattened src). + # lw_param_flat_sizes can vary across ranks. PyTorch's NCCL + # backend handles uneven tensor sizes in torch.distributed.all_gather + # (falling back to grouped send/recv internally when sizes differ), + # so no manual padding is needed. dp_size = self.intra_distributed_optimizer_instance_size local_rank = self.intra_distributed_optimizer_instance_rank group = self.intra_distributed_optimizer_instance_group lw_work_handles = [] for bucket in self.buckets: - # Use param dtype (e.g., bf16), NOT grad dtype (which may be fp32 - # when grad_reduce_in_fp32 is enabled). All ranks must use the same - # dtype for broadcast buffers, and ranks with local params will have - # bf16 src tensors from _flatten_dense_tensors. + # Use param dtype (e.g., bf16), NOT grad dtype (which may be + # fp32 when grad_reduce_in_fp32 is enabled). param_dtype = bucket.params_list[0].dtype - src = ( - _flatten_dense_tensors(bucket.lw_params_list[local_rank]) - if len(bucket.lw_params_list[local_rank]) > 0 - else torch.empty( - 0, device=bucket.grad_data.device, dtype=param_dtype - ) - ) - # Keep src alive until the async operation completes. - bucket._lw_src_buffer = src if max(bucket.lw_param_flat_sizes) == 0: # All ranks have empty params for this bucket — skip. - bucket.lw_gather_tensor_list = [ - torch.empty(0, device=src.device, dtype=src.dtype) for _ in range(dp_size) - ] + bucket.lw_gather_list = None continue - # Allocate per-rank receive buffers (actual sizes, NO padding). + # Flatten local params. Detach from the autograd graph because + # start_param_sync can be called during the forward pass (where + # autograd is active) and all_gather will write into gather_list + # entries in-place. + local_size = bucket.lw_param_flat_sizes[local_rank] + if local_size > 0: + src = _flatten_dense_tensors(bucket.lw_params_list[local_rank]).detach() + else: + src = torch.empty( + 0, device=bucket.grad_data.device, dtype=param_dtype + ) + # Keep src alive until the async operation completes. + bucket._lw_src_buffer = src + + # Allocate per-rank receive buffers with actual sizes (no padding). + # Reuse src for local_rank's slot to avoid an extra allocation. gather_list = [] for i in range(dp_size): if i == local_rank: - gather_list.append(torch.empty(0, device=src.device, dtype=src.dtype)) + gather_list.append(src) else: gather_list.append( torch.empty( - bucket.lw_param_flat_sizes[i], device=src.device, dtype=src.dtype + bucket.lw_param_flat_sizes[i], + device=src.device, + dtype=src.dtype, ) ) - bucket.lw_gather_tensor_list = gather_list + bucket.lw_gather_list = gather_list - # Broadcast each rank's params to all other ranks. - for i in range(dp_size): - if bucket.lw_param_flat_sizes[i] == 0: - continue - src_global = torch.distributed.get_global_rank(group, i) - if i == local_rank: - buf = src - else: - buf = gather_list[i] - work = torch.distributed.broadcast( - buf, src_global, group=group, async_op=async_op - ) - if async_op and work is not None: - lw_work_handles.append(work) + work = torch.distributed.all_gather( + gather_list, src, group=group, async_op=async_op + ) + if async_op and work is not None: + lw_work_handles.append(work) if async_op: - self.param_gather_handle = _LayerWiseAllGatherHandle(lw_work_handles) + self.param_gather_handle = _LWAllGatherHandle(lw_work_handles) else: # Synchronous: unflatten and copy gathered params immediately. for bucket in self.buckets: - for idx, (flat_params, params) in enumerate( - zip(bucket.lw_gather_tensor_list, bucket.lw_params_list) - ): + if bucket.lw_gather_list is None: + continue + for idx, params in enumerate(bucket.lw_params_list): if len(params) == 0 or idx == local_rank: continue - updated_params = _unflatten_dense_tensors(flat_params, params) + updated_params = _unflatten_dense_tensors( + bucket.lw_gather_list[idx], params + ) for updated_p, model_p in zip(updated_params, params): model_p.data.copy_(updated_p) - bucket.lw_gather_tensor_list.clear() + bucket.lw_gather_list = None bucket._lw_src_buffer = None self.param_gather_handle = None else: @@ -487,20 +472,22 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False): bucket.param_data.zero_() elif not self.ddp_config.use_distributed_optimizer: for bucket in self.buckets: + if bucket.lw_gather_list is None: + continue # Unflatten and copy gathered params for each rank. - for idx, (flat_params, params) in enumerate( - zip(bucket.lw_gather_tensor_list, bucket.lw_params_list) - ): + for idx, params in enumerate(bucket.lw_params_list): # Skip local params and empty tensors. if ( len(params) == 0 or idx == self.intra_distributed_optimizer_instance_rank ): continue - updated_params = _unflatten_dense_tensors(flat_params, params) + updated_params = _unflatten_dense_tensors( + bucket.lw_gather_list[idx], params + ) for updated_p, model_p in zip(updated_params, params): model_p.data.copy_(updated_p) - bucket.lw_gather_tensor_list.clear() + bucket.lw_gather_list = None bucket._lw_src_buffer = None else: fp8_params = [] @@ -698,8 +685,7 @@ def free_overlap_buffers(self): self.param_gather_handle.wait() self.param_gather_handle = None for bucket in self.buckets: - if bucket.lw_gather_tensor_list is not None: - bucket.lw_gather_tensor_list.clear() + bucket.lw_gather_list = None bucket._lw_src_buffer = None def register_grad_ready( diff --git a/megatron/core/optimizer/layer_wise_optimizer.py b/megatron/core/optimizer/layer_wise_optimizer.py index 53bffb24abc..b95a7165fae 100644 --- a/megatron/core/optimizer/layer_wise_optimizer.py +++ b/megatron/core/optimizer/layer_wise_optimizer.py @@ -195,7 +195,7 @@ def set_bucket_lw_params_list(self, model_chunks): def allgather_params(self) -> None: """All-gather updated params from all ranks.""" - # helper function to flatten local params, broadcast to gather, + # helper function to flatten local params, all-gather, # unflatten and copy to model params def _allgather_helper(params_list, group): device = params_list[0][0].device @@ -212,31 +212,23 @@ def _allgather_helper(params_list, group): if max(flat_sizes) == 0: return - # Allocate per-rank receive buffers (actual sizes, NO padding). - recv_buffers = [] + # Allocate per-rank receive buffers with actual sizes (no padding). + # PyTorch's NCCL backend handles uneven sizes in all_gather via + # grouped send/recv internally. Reuse src for local rank's slot. + gather_list = [] for i in range(dp_size): if i == rank: - recv_buffers.append(None) + gather_list.append(src) else: - recv_buffers.append(torch.empty(flat_sizes[i], device=device, dtype=dtype)) + gather_list.append(torch.empty(flat_sizes[i], device=device, dtype=dtype)) - # Broadcast each rank's params to all other ranks. - # All ranks must participate in every broadcast call (collective). - for i in range(dp_size): - if flat_sizes[i] == 0: - continue - src_global = torch.distributed.get_global_rank(group, i) - if i == rank: - buf = src - else: - buf = recv_buffers[i] - torch.distributed.broadcast(buf, src_global, group=group) + torch.distributed.all_gather(gather_list, src, group=group) # Unflatten and copy gathered params for each rank. for idx, params in enumerate(params_list): if len(params) == 0 or idx == rank: continue - updated_params = _unflatten_dense_tensors(recv_buffers[idx], params) + updated_params = _unflatten_dense_tensors(gather_list[idx], params) for updated_p, model_p in zip(updated_params, params): model_p.data.copy_(updated_p) diff --git a/tests/unit_tests/distributed/test_param_and_grad_buffer.py b/tests/unit_tests/distributed/test_param_and_grad_buffer.py index aa83da28553..62cec92b9f0 100644 --- a/tests/unit_tests/distributed/test_param_and_grad_buffer.py +++ b/tests/unit_tests/distributed/test_param_and_grad_buffer.py @@ -406,16 +406,15 @@ def test_bucket_group_clears_buffers(self): for bg in model.bucket_groups: # Simulate buffers that would be allocated by start_param_sync. for bucket in bg.buckets: - bucket.lw_gather_tensor_list = [torch.empty(8), torch.empty(8)] + bucket.lw_gather_list = [torch.empty(8), torch.empty(8)] bucket._lw_src_buffer = torch.empty(16) bg.free_overlap_buffers() for bucket in bg.buckets: assert ( - bucket.lw_gather_tensor_list is not None - and len(bucket.lw_gather_tensor_list) == 0 - ), "lw_gather_tensor_list should be empty after free_overlap_buffers" + bucket.lw_gather_list is None + ), "lw_gather_list should be None after free_overlap_buffers" assert ( bucket._lw_src_buffer is None ), "_lw_src_buffer should be None after free_overlap_buffers" @@ -446,7 +445,7 @@ def test_bucket_group_noop_when_no_buffers(self): for bg in model.bucket_groups: assert bg.param_gather_handle is None for bucket in bg.buckets: - assert bucket.lw_gather_tensor_list is None + assert bucket.lw_gather_list is None assert bucket._lw_src_buffer is None # Should not raise. diff --git a/tests/unit_tests/test_layer_wise_optimizer.py b/tests/unit_tests/test_layer_wise_optimizer.py index 1a1f2622105..75c479da555 100644 --- a/tests/unit_tests/test_layer_wise_optimizer.py +++ b/tests/unit_tests/test_layer_wise_optimizer.py @@ -314,11 +314,21 @@ def test_sharded_state_dict(self): # Test sharded_state_dict sharded_state_dict = optimizer.sharded_state_dict(model_sharded_state_dict) - # Verify the sharded_state_dict is not None and has expected structure + # Verify the sharded_state_dict is not None and has expected structure. + # With multiple chained optimizers (muon + adam), the top-level keys are + # integer indices; each sub-dict should contain an 'optimizer' key. assert sharded_state_dict is not None, "Sharded state dict should not be None" - assert ( - 'optimizer' in sharded_state_dict - ), "Sharded state dict should contain 'optimizer' key" + if isinstance(sharded_state_dict, dict) and all( + isinstance(k, int) for k in sharded_state_dict.keys() + ): + for idx, sub_dict in sharded_state_dict.items(): + assert ( + 'optimizer' in sub_dict + ), f"Sub-dict {idx} should contain 'optimizer' key" + else: + assert ( + 'optimizer' in sharded_state_dict + ), "Sharded state dict should contain 'optimizer' key" # Verify that replica_id is set correctly (should be 0 for DP dimension) from megatron.core.dist_checkpointing import ShardedTensor @@ -813,8 +823,7 @@ def test_overlap_param_gather_multi_iteration(self): def test_overlap_param_gather_async_dispatch_and_finish(self): """Test async dispatch + finish_param_sync cycle (the actual runtime path). - Exercises _LayerWiseAllGatherHandle.wait() through the async dispatch path: - start_param_sync() (no force_sync) dispatches async broadcasts, then + start_param_sync() (no force_sync) dispatches async all-gathers, then finish_param_sync() waits on the handle and unflattens gathered params. """ model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( @@ -1106,3 +1115,54 @@ def test_overlap_param_gather_multi_iteration_with_hooks(self): atol=0, msg=f"Parameters diverged at iteration {iteration}", ) + + def test_overlap_param_gather_start_sync_with_autograd(self): + """Regression test: start_param_sync must work when autograd is active. + + _flatten_dense_tensors on params with requires_grad=True produces a tensor + that also requires grad. Since all_gather writes into gather_list entries + in-place and the local rank's slot reuses src, this triggers: + RuntimeError: a view of a leaf Variable that requires grad is being + used in an in-place operation. + The fix is to .detach() the flattened tensor before using it as src. + + This test calls start_param_sync (synchronous via force_sync) WITHOUT + torch.no_grad() to reproduce the exact scenario that occurs during the + forward pass when finish_param_sync chains to start_param_sync for the + next bucket group. + """ + model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=True + ) + ref_model, ref_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=False, copy_from=model + ) + + # Confirm params require grad (the precondition for this bug). + for param in model.parameters(): + assert param.requires_grad, "Test requires params with requires_grad=True" + + # Set identical gradients on both models. + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + grad_value = torch.randn_like(param) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + param.main_grad = grad_value.clone().detach() + ref_param.main_grad = grad_value.clone().detach() + + # Step both optimizers (async path skips allgather, ref path includes it). + optimizer.step() + ref_optimizer.step() + + # Call start_param_sync with autograd ENABLED (no torch.no_grad()). + # Before the .detach() fix, this would raise RuntimeError. + model.start_param_sync(force_sync=True) + + # Verify gathered params match the reference. + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + torch.testing.assert_close( + param.data, + ref_param.data, + rtol=0, + atol=0, + msg="Params incorrect after start_param_sync with autograd enabled", + ) From a566a7b6f5eb4dbd7c0369c528f246be623549ed Mon Sep 17 00:00:00 2001 From: Mike Chrzanowski Date: Fri, 27 Feb 2026 22:07:53 -0800 Subject: [PATCH 20/25] Skip --overlap-grad-reduce requirement for layer-wise optimizers Layer-wise optimizers like dist_muon handle gradient reduction internally and do not use the DDP grad-reduce overlap path, so requiring --overlap-grad-reduce with --overlap-param-gather is unnecessary. Co-Authored-By: Claude Opus 4.6 (1M context) --- megatron/training/arguments.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index ee737315932..686443bbab0 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -726,8 +726,9 @@ def validate_args(args, defaults={}): assert args.use_distributed_optimizer or args.use_megatron_fsdp \ or ('dist' in args.optimizer), \ '--overlap-param-gather only supported with distributed optimizer, megatron fsdp, or layer-wise optimizer' - assert args.overlap_grad_reduce, \ - 'Must use --overlap-param-gather with --overlap-grad-reduce' + if 'dist' not in args.optimizer: + assert args.overlap_grad_reduce, \ + 'Must use --overlap-param-gather with --overlap-grad-reduce' assert not args.use_legacy_models, \ '--overlap-param-gather only supported with MCore models' From f30914ec5b55dd8f57ed90c1b64b3ad7435813aa Mon Sep 17 00:00:00 2001 From: Mike Chrzanowski Date: Mon, 2 Mar 2026 15:38:47 -0800 Subject: [PATCH 21/25] Rename lw -> layerwise, improve docstring, clear handles in wait() - Rename all `lw` prefixed names to `layerwise` for clarity - Improve set_layerwise_params_list docstring to clarify that each inner list contains only the params owned by that rank's layer-wise optimizer that also belong to this bucket - Clear self.handles at the end of _LayerwiseAllGatherHandle.wait() - Rename local variable `src` to `flat_local_params` Co-Authored-By: Claude Opus 4.6 (1M context) --- .../core/distributed/param_and_grad_buffer.py | 87 ++++++++++--------- .../core/optimizer/layer_wise_optimizer.py | 8 +- 2 files changed, 51 insertions(+), 44 deletions(-) diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py index 2d4d9d7fc0d..4f779b33000 100644 --- a/megatron/core/distributed/param_and_grad_buffer.py +++ b/megatron/core/distributed/param_and_grad_buffer.py @@ -114,24 +114,28 @@ def __init__( self.param_to_index[param] = (global_start - offset, global_end - offset) # Layer-wise optimizer attributes for async param gather. - self.lw_params_list = None - self.lw_param_flat_sizes = None - self.lw_gather_list = None - self._lw_src_buffer = None + self.layerwise_params_list = None + self.layerwise_param_flat_sizes = None + self.layerwise_gather_list = None + self._layerwise_src_buffer = None - def set_lw_params_list(self, lw_params_list: List[List[torch.nn.Parameter]]): + def set_layerwise_params_list( + self, layerwise_params_list: List[List[torch.nn.Parameter]] + ): """Set per-rank parameter lists for layer-wise async all-gather. Args: - lw_params_list: List of param lists, one per rank in the DP group. + layerwise_params_list: List of param lists, one per rank in the DP group. + Each inner list contains the parameters owned by that rank's + layer-wise optimizer that also belong to this bucket. """ - self.lw_params_list = lw_params_list - self.lw_param_flat_sizes = [ - sum([p.numel() for p in param_list]) for param_list in lw_params_list + self.layerwise_params_list = layerwise_params_list + self.layerwise_param_flat_sizes = [ + sum([p.numel() for p in param_list]) for param_list in layerwise_params_list ] -class _LWAllGatherHandle: +class _LayerwiseAllGatherHandle: """Handle wrapping multiple async all-gather work objects. NCCL guarantees in-order completion on the same communicator, so waiting @@ -144,6 +148,7 @@ def __init__(self, handles): def wait(self): if self.handles: self.handles[-1].wait() + self.handles = None class _ParamAndGradBucketGroup: @@ -316,77 +321,79 @@ def start_param_sync(self, force_sync: bool = False): # param gather. # # Each rank may own a different number of params per bucket, so - # lw_param_flat_sizes can vary across ranks. PyTorch's NCCL + # layerwise_param_flat_sizes can vary across ranks. PyTorch's NCCL # backend handles uneven tensor sizes in torch.distributed.all_gather # (falling back to grouped send/recv internally when sizes differ), # so no manual padding is needed. dp_size = self.intra_distributed_optimizer_instance_size local_rank = self.intra_distributed_optimizer_instance_rank group = self.intra_distributed_optimizer_instance_group - lw_work_handles = [] + layerwise_work_handles = [] for bucket in self.buckets: # Use param dtype (e.g., bf16), NOT grad dtype (which may be # fp32 when grad_reduce_in_fp32 is enabled). param_dtype = bucket.params_list[0].dtype - if max(bucket.lw_param_flat_sizes) == 0: + if max(bucket.layerwise_param_flat_sizes) == 0: # All ranks have empty params for this bucket — skip. - bucket.lw_gather_list = None + bucket.layerwise_gather_list = None continue # Flatten local params. Detach from the autograd graph because # start_param_sync can be called during the forward pass (where # autograd is active) and all_gather will write into gather_list # entries in-place. - local_size = bucket.lw_param_flat_sizes[local_rank] + local_size = bucket.layerwise_param_flat_sizes[local_rank] if local_size > 0: - src = _flatten_dense_tensors(bucket.lw_params_list[local_rank]).detach() + flat_local_params = _flatten_dense_tensors( + bucket.layerwise_params_list[local_rank] + ).detach() else: - src = torch.empty( + flat_local_params = torch.empty( 0, device=bucket.grad_data.device, dtype=param_dtype ) - # Keep src alive until the async operation completes. - bucket._lw_src_buffer = src + # Keep flat_local_params alive until the async operation completes. + bucket._layerwise_src_buffer = flat_local_params # Allocate per-rank receive buffers with actual sizes (no padding). - # Reuse src for local_rank's slot to avoid an extra allocation. + # Reuse flat_local_params for local_rank's slot to avoid an extra allocation. gather_list = [] for i in range(dp_size): if i == local_rank: - gather_list.append(src) + gather_list.append(flat_local_params) else: gather_list.append( torch.empty( - bucket.lw_param_flat_sizes[i], - device=src.device, - dtype=src.dtype, + bucket.layerwise_param_flat_sizes[i], + device=flat_local_params.device, + dtype=flat_local_params.dtype, ) ) - bucket.lw_gather_list = gather_list + bucket.layerwise_gather_list = gather_list work = torch.distributed.all_gather( - gather_list, src, group=group, async_op=async_op + gather_list, flat_local_params, group=group, async_op=async_op ) if async_op and work is not None: - lw_work_handles.append(work) + layerwise_work_handles.append(work) if async_op: - self.param_gather_handle = _LWAllGatherHandle(lw_work_handles) + self.param_gather_handle = _LayerwiseAllGatherHandle(layerwise_work_handles) else: # Synchronous: unflatten and copy gathered params immediately. for bucket in self.buckets: - if bucket.lw_gather_list is None: + if bucket.layerwise_gather_list is None: continue - for idx, params in enumerate(bucket.lw_params_list): + for idx, params in enumerate(bucket.layerwise_params_list): if len(params) == 0 or idx == local_rank: continue updated_params = _unflatten_dense_tensors( - bucket.lw_gather_list[idx], params + bucket.layerwise_gather_list[idx], params ) for updated_p, model_p in zip(updated_params, params): model_p.data.copy_(updated_p) - bucket.lw_gather_list = None - bucket._lw_src_buffer = None + bucket.layerwise_gather_list = None + bucket._layerwise_src_buffer = None self.param_gather_handle = None else: # Standard distributed optimizer path: use _coalescing_manager. @@ -472,10 +479,10 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False): bucket.param_data.zero_() elif not self.ddp_config.use_distributed_optimizer: for bucket in self.buckets: - if bucket.lw_gather_list is None: + if bucket.layerwise_gather_list is None: continue # Unflatten and copy gathered params for each rank. - for idx, params in enumerate(bucket.lw_params_list): + for idx, params in enumerate(bucket.layerwise_params_list): # Skip local params and empty tensors. if ( len(params) == 0 @@ -483,12 +490,12 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False): ): continue updated_params = _unflatten_dense_tensors( - bucket.lw_gather_list[idx], params + bucket.layerwise_gather_list[idx], params ) for updated_p, model_p in zip(updated_params, params): model_p.data.copy_(updated_p) - bucket.lw_gather_list = None - bucket._lw_src_buffer = None + bucket.layerwise_gather_list = None + bucket._layerwise_src_buffer = None else: fp8_params = [] for bucket in self.buckets: @@ -685,8 +692,8 @@ def free_overlap_buffers(self): self.param_gather_handle.wait() self.param_gather_handle = None for bucket in self.buckets: - bucket.lw_gather_list = None - bucket._lw_src_buffer = None + bucket.layerwise_gather_list = None + bucket._layerwise_src_buffer = None def register_grad_ready( self, param: torch.nn.Parameter, force_all_reduce: Optional[bool] = False diff --git a/megatron/core/optimizer/layer_wise_optimizer.py b/megatron/core/optimizer/layer_wise_optimizer.py index b95a7165fae..a9fdc7ba72f 100644 --- a/megatron/core/optimizer/layer_wise_optimizer.py +++ b/megatron/core/optimizer/layer_wise_optimizer.py @@ -69,7 +69,7 @@ def __init__( assert ( model_chunks is not None ), "model_chunks must be provided if async_allgather is True" - self.set_bucket_lw_params_list(model_chunks) + self.set_bucket_layerwise_params_list(model_chunks) if init_state_fn_list: assert len(init_state_fn_list) == len( @@ -156,7 +156,7 @@ def shard_params(self, optimizers): if expt_dp_size == 1 or len(self.expt_dp_params_list[0]) == 0: self.expt_dp_params_list = None - def set_bucket_lw_params_list(self, model_chunks): + def set_bucket_layerwise_params_list(self, model_chunks): """Map sharded params to DDP buckets for async all-gather. For each bucket in each model chunk's bucket groups, build per-rank param lists @@ -175,7 +175,7 @@ def set_bucket_lw_params_list(self, model_chunks): for param in full_params_list: if param in bucket.params: bucket_list.append(param) - bucket.set_lw_params_list(bucket_params_list) + bucket.set_layerwise_params_list(bucket_params_list) # Do the same for expert parallel bucket groups. if self.expt_dp_params_list is not None: for group in model_chunk.expert_parallel_bucket_groups: @@ -189,7 +189,7 @@ def set_bucket_lw_params_list(self, model_chunks): for param in full_params_list: if param in bucket.params: bucket_list.append(param) - bucket.set_lw_params_list(bucket_params_list) + bucket.set_layerwise_params_list(bucket_params_list) @torch.no_grad() def allgather_params(self) -> None: From 23c4516b7ba77821f6762fc19ac22740082fd4f6 Mon Sep 17 00:00:00 2001 From: Mike Chrzanowski Date: Mon, 2 Mar 2026 21:42:57 -0800 Subject: [PATCH 22/25] Fix tests to use renamed layerwise_ attribute prefix The production code was renamed from lw_ to layerwise_ prefix but the tests were not updated, causing attribute mismatches and test failures. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../distributed/test_param_and_grad_buffer.py | 18 ++++++------ tests/unit_tests/test_layer_wise_optimizer.py | 28 +++++++++---------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/tests/unit_tests/distributed/test_param_and_grad_buffer.py b/tests/unit_tests/distributed/test_param_and_grad_buffer.py index 62cec92b9f0..48f815531c5 100644 --- a/tests/unit_tests/distributed/test_param_and_grad_buffer.py +++ b/tests/unit_tests/distributed/test_param_and_grad_buffer.py @@ -400,24 +400,24 @@ def _make_model(): return model def test_bucket_group_clears_buffers(self): - """free_overlap_buffers on a bucket group should None-out per-bucket lw buffers.""" + """free_overlap_buffers on a bucket group should None-out per-bucket layerwise buffers.""" model = self._make_model() for bg in model.bucket_groups: # Simulate buffers that would be allocated by start_param_sync. for bucket in bg.buckets: - bucket.lw_gather_list = [torch.empty(8), torch.empty(8)] - bucket._lw_src_buffer = torch.empty(16) + bucket.layerwise_gather_list = [torch.empty(8), torch.empty(8)] + bucket._layerwise_src_buffer = torch.empty(16) bg.free_overlap_buffers() for bucket in bg.buckets: assert ( - bucket.lw_gather_list is None - ), "lw_gather_list should be None after free_overlap_buffers" + bucket.layerwise_gather_list is None + ), "layerwise_gather_list should be None after free_overlap_buffers" assert ( - bucket._lw_src_buffer is None - ), "_lw_src_buffer should be None after free_overlap_buffers" + bucket._layerwise_src_buffer is None + ), "_layerwise_src_buffer should be None after free_overlap_buffers" Utils.destroy_model_parallel() @@ -445,8 +445,8 @@ def test_bucket_group_noop_when_no_buffers(self): for bg in model.bucket_groups: assert bg.param_gather_handle is None for bucket in bg.buckets: - assert bucket.lw_gather_list is None - assert bucket._lw_src_buffer is None + assert bucket.layerwise_gather_list is None + assert bucket._layerwise_src_buffer is None # Should not raise. bg.free_overlap_buffers() diff --git a/tests/unit_tests/test_layer_wise_optimizer.py b/tests/unit_tests/test_layer_wise_optimizer.py index 75c479da555..6cceed64ea7 100644 --- a/tests/unit_tests/test_layer_wise_optimizer.py +++ b/tests/unit_tests/test_layer_wise_optimizer.py @@ -634,7 +634,7 @@ def test_overlap_param_gather_vs_sync_allgather(self): ) def test_overlap_param_gather_bucket_lw_params(self): - """Verify bucket.lw_params_list is populated when async_allgather is enabled.""" + """Verify bucket.layerwise_params_list is populated when async_allgather is enabled.""" model, optimizer, pg_collection = ( self.create_model_and_optimizer_with_overlap_param_gather() ) @@ -643,33 +643,33 @@ def test_overlap_param_gather_bucket_lw_params(self): for bucket_group in model.bucket_groups: for bucket in bucket_group.buckets: - # lw_params_list should be populated by set_bucket_lw_params_list + # layerwise_params_list should be populated by set_bucket_layerwise_params_list assert ( - bucket.lw_params_list is not None - ), "bucket.lw_params_list should be populated" + bucket.layerwise_params_list is not None + ), "bucket.layerwise_params_list should be populated" assert ( - len(bucket.lw_params_list) == dp_size - ), f"Expected {dp_size} per-rank lists, got {len(bucket.lw_params_list)}" + len(bucket.layerwise_params_list) == dp_size + ), f"Expected {dp_size} per-rank lists, got {len(bucket.layerwise_params_list)}" # The union of all per-rank param lists should cover all bucket params all_lw_params = set() - for rank_params in bucket.lw_params_list: + for rank_params in bucket.layerwise_params_list: for p in rank_params: all_lw_params.add(p) assert ( all_lw_params == bucket.params - ), "Union of per-rank lw_params should equal bucket params" + ), "Union of per-rank layerwise_params should equal bucket params" - # lw_param_flat_sizes should be populated and have correct length - assert bucket.lw_param_flat_sizes is not None - assert len(bucket.lw_param_flat_sizes) == dp_size + # layerwise_param_flat_sizes should be populated and have correct length + assert bucket.layerwise_param_flat_sizes is not None + assert len(bucket.layerwise_param_flat_sizes) == dp_size # Each flat size should equal the sum of param numels for that rank for rank_idx in range(dp_size): - expected_size = sum(p.numel() for p in bucket.lw_params_list[rank_idx]) - assert bucket.lw_param_flat_sizes[rank_idx] == expected_size, ( + expected_size = sum(p.numel() for p in bucket.layerwise_params_list[rank_idx]) + assert bucket.layerwise_param_flat_sizes[rank_idx] == expected_size, ( f"Rank {rank_idx}: expected flat_size {expected_size}, " - f"got {bucket.lw_param_flat_sizes[rank_idx]}" + f"got {bucket.layerwise_param_flat_sizes[rank_idx]}" ) def test_overlap_param_gather_vs_standard_ddp(self): From d77abd5604610e04d31b7804b23add75e055eada Mon Sep 17 00:00:00 2001 From: Mike Chrzanowski Date: Wed, 4 Mar 2026 14:08:01 -0800 Subject: [PATCH 23/25] Re-enable --overlap-grad-reduce assertion for --overlap-param-gather The assertion was previously skipped for layer-wise optimizers (dist_muon), but it should be unconditional. Also spell out 'dist_muon' explicitly instead of using substring matching. Co-Authored-By: Claude Opus 4.6 (1M context) --- megatron/training/arguments.py | 9 ++++----- tests/unit_tests/test_layer_wise_optimizer.py | 1 + 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 686443bbab0..99507ccd443 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -724,11 +724,10 @@ def validate_args(args, defaults={}): if args.overlap_param_gather: assert args.use_distributed_optimizer or args.use_megatron_fsdp \ - or ('dist' in args.optimizer), \ - '--overlap-param-gather only supported with distributed optimizer, megatron fsdp, or layer-wise optimizer' - if 'dist' not in args.optimizer: - assert args.overlap_grad_reduce, \ - 'Must use --overlap-param-gather with --overlap-grad-reduce' + or args.optimizer == 'dist_muon', \ + '--overlap-param-gather only supported with distributed optimizer, megatron fsdp, or dist_muon' + assert args.overlap_grad_reduce, \ + 'Must use --overlap-param-gather with --overlap-grad-reduce' assert not args.use_legacy_models, \ '--overlap-param-gather only supported with MCore models' diff --git a/tests/unit_tests/test_layer_wise_optimizer.py b/tests/unit_tests/test_layer_wise_optimizer.py index 6cceed64ea7..161d9e8a6d3 100644 --- a/tests/unit_tests/test_layer_wise_optimizer.py +++ b/tests/unit_tests/test_layer_wise_optimizer.py @@ -170,6 +170,7 @@ def create_model_and_optimizer_with_overlap_param_gather( ddp_config = DistributedDataParallelConfig( use_distributed_optimizer=False, overlap_param_gather=True, + overlap_grad_reduce=True, grad_reduce_in_fp32=grad_reduce_in_fp32, bucket_size=bucket_size, ) From b22ecd375596c5fe0b3e5606e9101508bad837f3 Mon Sep 17 00:00:00 2001 From: Mike Chrzanowski Date: Wed, 4 Mar 2026 14:22:51 -0800 Subject: [PATCH 24/25] Run autoformat.sh Co-Authored-By: Claude Opus 4.6 (1M context) --- megatron/core/distributed/param_and_grad_buffer.py | 4 +--- tests/unit_tests/test_layer_wise_optimizer.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py index 4f779b33000..968eb4eb866 100644 --- a/megatron/core/distributed/param_and_grad_buffer.py +++ b/megatron/core/distributed/param_and_grad_buffer.py @@ -119,9 +119,7 @@ def __init__( self.layerwise_gather_list = None self._layerwise_src_buffer = None - def set_layerwise_params_list( - self, layerwise_params_list: List[List[torch.nn.Parameter]] - ): + def set_layerwise_params_list(self, layerwise_params_list: List[List[torch.nn.Parameter]]): """Set per-rank parameter lists for layer-wise async all-gather. Args: diff --git a/tests/unit_tests/test_layer_wise_optimizer.py b/tests/unit_tests/test_layer_wise_optimizer.py index 161d9e8a6d3..c484ca104ee 100644 --- a/tests/unit_tests/test_layer_wise_optimizer.py +++ b/tests/unit_tests/test_layer_wise_optimizer.py @@ -323,9 +323,7 @@ def test_sharded_state_dict(self): isinstance(k, int) for k in sharded_state_dict.keys() ): for idx, sub_dict in sharded_state_dict.items(): - assert ( - 'optimizer' in sub_dict - ), f"Sub-dict {idx} should contain 'optimizer' key" + assert 'optimizer' in sub_dict, f"Sub-dict {idx} should contain 'optimizer' key" else: assert ( 'optimizer' in sharded_state_dict From 62b5ae24071f4048f399a390a41e1e271c275a42 Mon Sep 17 00:00:00 2001 From: Mike Chrzanowski Date: Wed, 4 Mar 2026 15:37:11 -0800 Subject: [PATCH 25/25] Add missing docstring to fix lint error Co-Authored-By: Claude Opus 4.6 (1M context) --- megatron/core/distributed/param_and_grad_buffer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py index 968eb4eb866..112b0f20697 100644 --- a/megatron/core/distributed/param_and_grad_buffer.py +++ b/megatron/core/distributed/param_and_grad_buffer.py @@ -144,6 +144,7 @@ def __init__(self, handles): self.handles = handles def wait(self): + """Wait on the last handle and clear all handles.""" if self.handles: self.handles[-1].wait() self.handles = None