diff --git a/megatron/core/distributed/distributed_data_parallel.py b/megatron/core/distributed/distributed_data_parallel.py index 55179ff3024..f4647d764aa 100644 --- a/megatron/core/distributed/distributed_data_parallel.py +++ b/megatron/core/distributed/distributed_data_parallel.py @@ -236,7 +236,10 @@ def _allocate_buffers_for_parameters( # Set `next_param_gather_bucket_group` for different bucket groups by iterating through # buckets in reverse order (since all-gathers happen in reverse order of buckets). - if self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather: + # Note: overlap_param_gather covers both the distributed optimizer and the + # layer-wise optimizer cases; the latter sets overlap_param_gather=True + # without use_distributed_optimizer. + if self.ddp_config.overlap_param_gather: num_bucket_groups = len(bucket_groups) for i in range(1, num_bucket_groups): bucket_groups[num_bucket_groups - i].next_param_gather_bucket_group = ( @@ -344,9 +347,10 @@ def unmap_weight_tensor(m): grad_acc.register_hook(self._make_backward_post_hook(param)) self.grad_accs.append(grad_acc) - self.use_forward_hook = ( - self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather - ) + # Note: overlap_param_gather covers both the distributed optimizer and the + # layer-wise optimizer cases; the latter sets overlap_param_gather=True + # without use_distributed_optimizer. 
+ self.use_forward_hook = self.ddp_config.overlap_param_gather self.remove_forward_pre_hook_handles = {} if self.use_forward_hook: self.enable_forward_pre_hook() @@ -534,6 +538,11 @@ def finish_grad_sync(self, force_all_reduce: Optional[bool] = False): for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: bucket_group.finish_grad_sync(force_all_reduce=force_all_reduce) + def free_overlap_buffers(self): + """Free overlap param-gather GPU buffers across all bucket groups.""" + for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: + bucket_group.free_overlap_buffers() + def scale_gradients(self, scaling_factor: float): """Scale all gradients inside the buffers by `scaling_factor`.""" for buffer in self.buffers + self.expert_parallel_buffers: diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py index 088374fbf13..112b0f20697 100644 --- a/megatron/core/distributed/param_and_grad_buffer.py +++ b/megatron/core/distributed/param_and_grad_buffer.py @@ -10,6 +10,7 @@ from typing import Dict, List, Optional import torch +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch.distributed import _coalescing_manager import megatron.core.nccl_allocator as nccl_allocator @@ -112,6 +113,42 @@ def __init__( global_start, global_end, _ = param_index_map[param] self.param_to_index[param] = (global_start - offset, global_end - offset) + # Layer-wise optimizer attributes for async param gather. + self.layerwise_params_list = None + self.layerwise_param_flat_sizes = None + self.layerwise_gather_list = None + self._layerwise_src_buffer = None + + def set_layerwise_params_list(self, layerwise_params_list: List[List[torch.nn.Parameter]]): + """Set per-rank parameter lists for layer-wise async all-gather. + + Args: + layerwise_params_list: List of param lists, one per rank in the DP group. 
+ Each inner list contains the parameters owned by that rank's + layer-wise optimizer that also belong to this bucket. + """ + self.layerwise_params_list = layerwise_params_list + self.layerwise_param_flat_sizes = [ + sum([p.numel() for p in param_list]) for param_list in layerwise_params_list + ] + + +class _LayerwiseAllGatherHandle: + """Handle wrapping multiple async all-gather work objects. + + NCCL guarantees in-order completion on the same communicator, so waiting + on only the last handle is sufficient. + """ + + def __init__(self, handles): + self.handles = handles + + def wait(self): + """Wait on the last handle and clear all handles.""" + if self.handles: + self.handles[-1].wait() + self.handles = None + class _ParamAndGradBucketGroup: """ @@ -138,11 +175,13 @@ def __init__( self.buckets = buckets self.ddp_config = ddp_config - if self.ddp_config.use_distributed_optimizer: + # overlap_param_gather covers the layer-wise optimizer case, which sets + # overlap_param_gather=True without use_distributed_optimizer. + if self.ddp_config.use_distributed_optimizer or self.ddp_config.overlap_param_gather: self.intra_distributed_optimizer_instance_group = collective_group self.intra_distributed_optimizer_instance_size = collective_group_size self.intra_distributed_optimizer_instance_rank = collective_group.rank() - else: + if not self.ddp_config.use_distributed_optimizer: self.data_parallel_group = collective_group # State for bookkeeping: params is the set of parameters this bucket group is @@ -262,7 +301,9 @@ def start_param_sync(self, force_sync: bool = False): force_sync (bool, optional): force synchronous collective regardless of other settings if true. """ - assert self.ddp_config.use_distributed_optimizer + # overlap_param_gather covers the layer-wise optimizer case, which sets + # overlap_param_gather=True without use_distributed_optimizer. 
+ assert self.ddp_config.use_distributed_optimizer or self.ddp_config.overlap_param_gather if force_sync: if self.param_gather_handle is not None: @@ -273,33 +314,114 @@ def start_param_sync(self, force_sync: bool = False): assert self.param_gather_handle is None async_op = self.ddp_config.overlap_param_gather and not force_sync - # Coalesce communication kernels across buckets in the bucket group. - with _coalescing_manager( - self.intra_distributed_optimizer_instance_group, async_ops=async_op - ) as cm: - for idx, bucket in enumerate(self.buckets): - if self.cached_param_buffer_shard_list[idx] is None: - self.cached_param_buffer_shard_list[idx] = shard_buffer( - bucket.param_data, self.intra_distributed_optimizer_instance_size + + if not self.ddp_config.use_distributed_optimizer: + # Layer-wise optimizer path: use all_gather for variable-size + # param gather. + # + # Each rank may own a different number of params per bucket, so + # layerwise_param_flat_sizes can vary across ranks. PyTorch's NCCL + # backend handles uneven tensor sizes in torch.distributed.all_gather + # (falling back to grouped send/recv internally when sizes differ), + # so no manual padding is needed. + dp_size = self.intra_distributed_optimizer_instance_size + local_rank = self.intra_distributed_optimizer_instance_rank + group = self.intra_distributed_optimizer_instance_group + layerwise_work_handles = [] + for bucket in self.buckets: + # Use param dtype (e.g., bf16), NOT grad dtype (which may be + # fp32 when grad_reduce_in_fp32 is enabled). + param_dtype = bucket.params_list[0].dtype + + if max(bucket.layerwise_param_flat_sizes) == 0: + # All ranks have empty params for this bucket — skip. + bucket.layerwise_gather_list = None + continue + + # Flatten local params. Detach from the autograd graph because + # start_param_sync can be called during the forward pass (where + # autograd is active) and all_gather will write into gather_list + # entries in-place. 
+ local_size = bucket.layerwise_param_flat_sizes[local_rank] + if local_size > 0: + flat_local_params = _flatten_dense_tensors( + bucket.layerwise_params_list[local_rank] + ).detach() + else: + flat_local_params = torch.empty( + 0, device=bucket.grad_data.device, dtype=param_dtype ) - local_data_view = self.cached_param_buffer_shard_list[idx][ - self.intra_distributed_optimizer_instance_rank - ] - dist_all_gather_func( - bucket.param_data, - local_data_view, - group=self.intra_distributed_optimizer_instance_group, - async_op=async_op, + # Keep flat_local_params alive until the async operation completes. + bucket._layerwise_src_buffer = flat_local_params + + # Allocate per-rank receive buffers with actual sizes (no padding). + # Reuse flat_local_params for local_rank's slot to avoid an extra allocation. + gather_list = [] + for i in range(dp_size): + if i == local_rank: + gather_list.append(flat_local_params) + else: + gather_list.append( + torch.empty( + bucket.layerwise_param_flat_sizes[i], + device=flat_local_params.device, + dtype=flat_local_params.dtype, + ) + ) + bucket.layerwise_gather_list = gather_list + + work = torch.distributed.all_gather( + gather_list, flat_local_params, group=group, async_op=async_op ) - if async_op: - self.param_gather_handle = cm + if async_op and work is not None: + layerwise_work_handles.append(work) + + if async_op: + self.param_gather_handle = _LayerwiseAllGatherHandle(layerwise_work_handles) + else: + # Synchronous: unflatten and copy gathered params immediately. 
+ for bucket in self.buckets: + if bucket.layerwise_gather_list is None: + continue + for idx, params in enumerate(bucket.layerwise_params_list): + if len(params) == 0 or idx == local_rank: + continue + updated_params = _unflatten_dense_tensors( + bucket.layerwise_gather_list[idx], params + ) + for updated_p, model_p in zip(updated_params, params): + model_p.data.copy_(updated_p) + bucket.layerwise_gather_list = None + bucket._layerwise_src_buffer = None + self.param_gather_handle = None else: - # When using `_coalescing_manager`, even if a synchronous op (async_op=False) is used, - # `cm` is not None, which is different from when `_coalescing_manager` is not used in - # which case the torch.distributed._all_gather_base() will return None. In order to - # maintain consistency with prior code, we need to manually set communication handle to - # None. - self.param_gather_handle = None + # Standard distributed optimizer path: use _coalescing_manager. + # all_gather_into_tensor writes directly into a contiguous output buffer and + # does not need a copy-back step, so coalescing works correctly. + with _coalescing_manager( + self.intra_distributed_optimizer_instance_group, async_ops=async_op + ) as cm: + for idx, bucket in enumerate(self.buckets): + if self.cached_param_buffer_shard_list[idx] is None: + self.cached_param_buffer_shard_list[idx] = shard_buffer( + bucket.param_data, self.intra_distributed_optimizer_instance_size + ) + local_data_view = self.cached_param_buffer_shard_list[idx][ + self.intra_distributed_optimizer_instance_rank + ] + dist_all_gather_func( + bucket.param_data, + local_data_view, + group=self.intra_distributed_optimizer_instance_group, + async_op=async_op, + ) + if async_op: + self.param_gather_handle = cm + else: + # When using `_coalescing_manager`, even if a synchronous op + # (async_op=False) is used, `cm` is not None. Manually set to None for + # consistency with prior code. 
+ self.param_gather_handle = None self.param_gather_dispatched = True def finish_param_sync(self, skip_next_bucket_dispatch: bool = False): @@ -317,7 +439,6 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False): skip_next_bucket_dispatch (bool, optional): if true, dispatch next bucket's communication if available. """ - assert self.ddp_config.use_distributed_optimizer assert self.ddp_config.overlap_param_gather # If current bucket's param AG has not been dispatched, dispatch it now (e.g., first @@ -355,6 +476,25 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False): # correspond to multiple param buffers. If we zero out the entire grad buffer, # it would clear the data of those param buffers that have not yet completed AG. bucket.param_data.zero_() + elif not self.ddp_config.use_distributed_optimizer: + for bucket in self.buckets: + if bucket.layerwise_gather_list is None: + continue + # Unflatten and copy gathered params for each rank. + for idx, params in enumerate(bucket.layerwise_params_list): + # Skip local params and empty tensors. + if ( + len(params) == 0 + or idx == self.intra_distributed_optimizer_instance_rank + ): + continue + updated_params = _unflatten_dense_tensors( + bucket.layerwise_gather_list[idx], params + ) + for updated_p, model_p in zip(updated_params, params): + model_p.data.copy_(updated_p) + bucket.layerwise_gather_list = None + bucket._layerwise_src_buffer = None else: fp8_params = [] for bucket in self.buckets: @@ -539,6 +679,21 @@ def finish_grad_sync(self, force_all_reduce: Optional[bool] = False): self.grad_reduce_handle.wait() self.grad_reduce_handle = None + def free_overlap_buffers(self): + """Free GPU buffers used by overlap param gather. + + Waits on any pending param all-gather handle, then releases the + per-bucket temporary buffers so that the CUDA memory allocator can + reclaim them. Called before async checkpoint saves to avoid OOM in + the persistent checkpoint worker process. 
+ """ + if self.param_gather_handle is not None: + self.param_gather_handle.wait() + self.param_gather_handle = None + for bucket in self.buckets: + bucket.layerwise_gather_list = None + bucket._layerwise_src_buffer = None + def register_grad_ready( self, param: torch.nn.Parameter, force_all_reduce: Optional[bool] = False ): diff --git a/megatron/core/optimizer/layer_wise_optimizer.py b/megatron/core/optimizer/layer_wise_optimizer.py index de4396a5b4f..a9fdc7ba72f 100644 --- a/megatron/core/optimizer/layer_wise_optimizer.py +++ b/megatron/core/optimizer/layer_wise_optimizer.py @@ -45,6 +45,8 @@ def __init__( config: OptimizerConfig, pg_collection: Optional[ProcessGroupCollection] = None, init_state_fn_list: Optional[List[Callable]] = None, + model_chunks: Optional[List] = None, + async_allgather: bool = False, ) -> None: """ Initialize LayerWiseDistributedOptimizer. @@ -54,10 +56,21 @@ def __init__( config: OptimizerConfig. pg_collection: ProcessGroupCollection. init_state_fn_list: List of init state functions. + model_chunks: DDP-wrapped model chunks (needed for async_allgather). + async_allgather: If True, defer param all-gather to forward pre-hooks. """ self.pg_collection = pg_collection self.shard_params(optimizers) + + # Set up async all-gather using DDP bucket infrastructure. + self.async_allgather = async_allgather + if self.async_allgather: + assert ( + model_chunks is not None + ), "model_chunks must be provided if async_allgather is True" + self.set_bucket_layerwise_params_list(model_chunks) + if init_state_fn_list: assert len(init_state_fn_list) == len( optimizers @@ -143,34 +156,79 @@ def shard_params(self, optimizers): if expt_dp_size == 1 or len(self.expt_dp_params_list[0]) == 0: self.expt_dp_params_list = None + def set_bucket_layerwise_params_list(self, model_chunks): + """Map sharded params to DDP buckets for async all-gather. 
+ + For each bucket in each model chunk's bucket groups, build per-rank param lists + by cross-referencing the layer-wise sharded param lists with the bucket's params. + + Args: + model_chunks: DDP-wrapped model chunks with bucket_groups. + """ + for model_chunk in model_chunks: + for group in model_chunk.bucket_groups: + for bucket in group.buckets: + bucket_params_list = [[] for _ in range(get_pg_size(self.pg_collection.dp_cp))] + for bucket_list, full_params_list in zip( + bucket_params_list, self.dp_cp_params_list + ): + for param in full_params_list: + if param in bucket.params: + bucket_list.append(param) + bucket.set_layerwise_params_list(bucket_params_list) + # Do the same for expert parallel bucket groups. + if self.expt_dp_params_list is not None: + for group in model_chunk.expert_parallel_bucket_groups: + for bucket in group.buckets: + bucket_params_list = [ + [] for _ in range(get_pg_size(self.pg_collection.expt_dp)) + ] + for bucket_list, full_params_list in zip( + bucket_params_list, self.expt_dp_params_list + ): + for param in full_params_list: + if param in bucket.params: + bucket_list.append(param) + bucket.set_layerwise_params_list(bucket_params_list) + @torch.no_grad() def allgather_params(self) -> None: """All-gather updated params from all ranks.""" - # helper function to flatten local params, allgather, unflatten and copy to model params + # helper function to flatten local params, all-gather, + # unflatten and copy to model params def _allgather_helper(params_list, group): - # flatten this rank's params and create empty tensor output list device = params_list[0][0].device dtype = params_list[0][0].dtype rank = get_pg_rank(group) - # for rank without params create empty tensor and participate in allgather + dp_size = get_pg_size(group) + # Flatten this rank's params. 
src = ( _flatten_dense_tensors(params_list[rank]) if len(params_list[rank]) > 0 else torch.empty(0, device=device, dtype=dtype) ) - output_list = [ - torch.empty(sum([p.numel() for p in params]), device=device, dtype=dtype) - for params in params_list - ] - # single all_gather_v to collect all updated params - torch.distributed.all_gather(output_list, src, group=group) - # unflatten and copy gathered params for each rank i - for idx, (flat_params, params) in enumerate(zip(output_list, params_list)): - # skip local params and empty tensors + flat_sizes = [sum(p.numel() for p in params) for params in params_list] + if max(flat_sizes) == 0: + return + + # Allocate per-rank receive buffers with actual sizes (no padding). + # PyTorch's NCCL backend handles uneven sizes in all_gather via + # grouped send/recv internally. Reuse src for local rank's slot. + gather_list = [] + for i in range(dp_size): + if i == rank: + gather_list.append(src) + else: + gather_list.append(torch.empty(flat_sizes[i], device=device, dtype=dtype)) + + torch.distributed.all_gather(gather_list, src, group=group) + + # Unflatten and copy gathered params for each rank. + for idx, params in enumerate(params_list): if len(params) == 0 or idx == rank: continue - updated_params = _unflatten_dense_tensors(flat_params, params) + updated_params = _unflatten_dense_tensors(gather_list[idx], params) for updated_p, model_p in zip(updated_params, params): model_p.data.copy_(updated_p) @@ -223,8 +281,10 @@ def step(self): # type: ignore[no-untyped-def] """step function for layer-wise optimizer.""" update_successful, grad_norm, num_zeros_in_grad = super().step() - # All gather updated params. - self.allgather_params() + # All gather updated params. If async_allgather is True, the allgather + # is deferred to the forward pre-hooks via DDP bucket infrastructure. 
+ if not self.async_allgather: + self.allgather_params() return update_successful, grad_norm, num_zeros_in_grad diff --git a/megatron/core/optimizer/muon.py b/megatron/core/optimizer/muon.py index 57eb1e94478..c637eaa5442 100644 --- a/megatron/core/optimizer/muon.py +++ b/megatron/core/optimizer/muon.py @@ -345,6 +345,11 @@ def adam_init_state_fn(opt, config=None): if reset_config_bf16: config.bf16 = True return LayerWiseDistributedOptimizer( - optimizers, config, pg_collection, init_state_fn_list=init_fns + optimizers, + config, + pg_collection, + init_state_fn_list=init_fns, + model_chunks=model_chunks, + async_allgather=config.overlap_param_gather, ) return ChainedOptimizer(optimizers) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 91e26af99c6..dd76d310275 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -813,8 +813,9 @@ def validate_args(args, defaults={}): ) if args.overlap_param_gather: - assert args.use_distributed_optimizer or args.use_megatron_fsdp, \ - '--overlap-param-gather only supported with distributed optimizer or megatron fsdp' + assert args.use_distributed_optimizer or args.use_megatron_fsdp \ + or args.optimizer == 'dist_muon', \ + '--overlap-param-gather only supported with distributed optimizer, megatron fsdp, or dist_muon' assert args.overlap_grad_reduce, \ 'Must use --overlap-param-gather with --overlap-grad-reduce' assert not args.use_legacy_models, \ @@ -1418,9 +1419,9 @@ def validate_args(args, defaults={}): # Muon optimizer check if 'muon' in args.optimizer: - # TODO: remove these checks once we support them - assert not args.overlap_grad_reduce, "Muon optimizer does not support overlap grad reduce for now." - assert not args.overlap_param_gather, "Muon optimizer does not support overlap param gather for now." + if args.optimizer == 'muon': + assert not args.overlap_grad_reduce, "Muon optimizer does not support overlap grad reduce. Use dist_muon instead." 
+ assert not args.overlap_param_gather, "Muon optimizer does not support overlap param gather. Use dist_muon instead." assert not args.use_distributed_optimizer, "Muon optimizer does not support distributed optimizer for now." assert not args.use_torch_fsdp2, "Muon optimizer does not support Torch-FSDP2 for now." diff --git a/megatron/training/training.py b/megatron/training/training.py index b508da02ef1..91626e6a5dc 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -2263,6 +2263,13 @@ def save_checkpoint_and_time( one_logger_utils.track_e2e_metrics() if should_disable_forward_pre_hook(args): force_param_sync(model) + # Free overlap param-gather buffers and release cached GPU memory so + # that the async checkpoint worker process has enough GPU headroom for + # D2H tensor transfers. + for model_chunk in model: + if hasattr(model_chunk, 'free_overlap_buffers'): + model_chunk.free_overlap_buffers() + torch.cuda.empty_cache() global num_checkpoints_memory_reported, MAX_NUM_CHECKPOINTS_MEMORY_REPORTED should_report_memory = num_checkpoints_memory_reported < MAX_NUM_CHECKPOINTS_MEMORY_REPORTED @@ -3597,4 +3604,8 @@ def _get_iterator(dataloader_type, dataloader): def should_disable_forward_pre_hook(args): """Block forward pre-hook for certain configurations.""" - return not args.use_megatron_fsdp and args.use_distributed_optimizer and args.overlap_param_gather + return ( + not args.use_megatron_fsdp + and (args.use_distributed_optimizer or 'dist' in args.optimizer) + and args.overlap_param_gather + ) diff --git a/tests/unit_tests/distributed/test_param_and_grad_buffer.py b/tests/unit_tests/distributed/test_param_and_grad_buffer.py index b60dfb1791b..48f815531c5 100644 --- a/tests/unit_tests/distributed/test_param_and_grad_buffer.py +++ b/tests/unit_tests/distributed/test_param_and_grad_buffer.py @@ -373,3 +373,94 @@ def test_force_all_reduce_uses_correct_collective(force_all_reduce: bool): ), "Expected all_reduce NOT to be called when 
force_all_reduce=False" Utils.destroy_model_parallel() + + +class TestFreeOverlapBuffers: + """Tests for free_overlap_buffers() which releases GPU memory before async checkpoint saves.""" + + @staticmethod + def _make_model(): + """Create a DDP-wrapped model with overlap_param_gather enabled.""" + Utils.initialize_model_parallel() + ddp_config = DistributedDataParallelConfig( + grad_reduce_in_fp32=True, + use_distributed_optimizer=False, + overlap_grad_reduce=True, + overlap_param_gather=True, + bucket_size=None, + ) + module = TestModel( + input_dim=32, output_dim=32, num_layers=2, bias=False, shared_embedding=False + ).bfloat16() + model = DistributedDataParallel( + TransformerConfig(num_attention_heads=1, num_layers=1), + ddp_config=ddp_config, + module=module, + ) + return model + + def test_bucket_group_clears_buffers(self): + """free_overlap_buffers on a bucket group should None-out per-bucket layerwise buffers.""" + model = self._make_model() + + for bg in model.bucket_groups: + # Simulate buffers that would be allocated by start_param_sync. 
+ for bucket in bg.buckets: + bucket.layerwise_gather_list = [torch.empty(8), torch.empty(8)] + bucket._layerwise_src_buffer = torch.empty(16) + + bg.free_overlap_buffers() + + for bucket in bg.buckets: + assert ( + bucket.layerwise_gather_list is None + ), "layerwise_gather_list should be None after free_overlap_buffers" + assert ( + bucket._layerwise_src_buffer is None + ), "_layerwise_src_buffer should be None after free_overlap_buffers" + + Utils.destroy_model_parallel() + + def test_bucket_group_waits_on_pending_handle(self): + """free_overlap_buffers should wait() on any pending param_gather_handle.""" + model = self._make_model() + + for bg in model.bucket_groups: + mock_handle = mock.MagicMock() + bg.param_gather_handle = mock_handle + + bg.free_overlap_buffers() + + mock_handle.wait.assert_called_once() + assert ( + bg.param_gather_handle is None + ), "param_gather_handle should be None after free_overlap_buffers" + + Utils.destroy_model_parallel() + + def test_bucket_group_noop_when_no_buffers(self): + """free_overlap_buffers should be safe to call when no buffers are allocated.""" + model = self._make_model() + + for bg in model.bucket_groups: + assert bg.param_gather_handle is None + for bucket in bg.buckets: + assert bucket.layerwise_gather_list is None + assert bucket._layerwise_src_buffer is None + + # Should not raise. 
+ bg.free_overlap_buffers() + + Utils.destroy_model_parallel() + + def test_ddp_free_overlap_buffers_delegates(self): + """DDP.free_overlap_buffers should call free_overlap_buffers on all bucket groups.""" + model = self._make_model() + + with mock.patch.object(type(model.bucket_groups[0]), 'free_overlap_buffers') as mock_free: + model.free_overlap_buffers() + assert mock_free.call_count == len( + model.bucket_groups + model.expert_parallel_bucket_groups + ), "free_overlap_buffers should be called on every bucket group" + + Utils.destroy_model_parallel() diff --git a/tests/unit_tests/test_layer_wise_optimizer.py b/tests/unit_tests/test_layer_wise_optimizer.py index 05ce26bcfa0..c484ca104ee 100644 --- a/tests/unit_tests/test_layer_wise_optimizer.py +++ b/tests/unit_tests/test_layer_wise_optimizer.py @@ -10,12 +10,13 @@ from megatron.core import parallel_state from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig -from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer +from megatron.core.optimizer import OptimizerConfig from megatron.core.optimizer.layer_wise_optimizer import LayerWiseDistributedOptimizer -from megatron.core.optimizer.optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer +from megatron.core.optimizer.muon import get_megatron_muon_optimizer +from megatron.core.optimizer.optimizer import Float16OptimizerWithFloat16Params from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer import TransformerConfig -from megatron.core.utils import get_pg_size +from megatron.core.utils import get_pg_rank, get_pg_size from tests.unit_tests.test_utilities import Utils # Skip all tests in this file for LTS versions @@ -88,8 +89,8 @@ def create_model_and_optimizer( model_class: Model class to instantiate clip_grad: Optional gradient clipping value model_kwargs: Optional kwargs for model initialization - use_layer_wise: If True, wrap optimizer in 
LayerWiseDistributedOptimizer; - if False, use get_megatron_optimizer instead (for reference) + use_layer_wise: If True, use LayerWiseDistributedOptimizer via dist_muon; + if False, use standard muon ChainedOptimizer (for reference) Returns: tuple: (model, optimizer, pg_collection) @@ -110,24 +111,99 @@ def create_model_and_optimizer( model.broadcast_params() optimizer_config = OptimizerConfig( - optimizer='adam', + optimizer='muon', lr=0.01, weight_decay=0.01, - bf16=not use_layer_wise, + bf16=True, use_distributed_optimizer=False, clip_grad=clip_grad, + muon_tp_mode="duplicated", ) pg_collection = ProcessGroupCollection.use_mpu_process_groups() pg_collection.dp_cp = parallel_state.get_data_parallel_group(with_context_parallel=True) pg_collection.expt_dp = parallel_state.get_expert_data_parallel_group() - optimizer = get_megatron_optimizer(optimizer_config, [model]) - if use_layer_wise: - optimizer_config.bf16 = True - optimizer = LayerWiseDistributedOptimizer( - optimizer.chained_optimizers, optimizer_config, pg_collection - ) + optimizer = get_megatron_muon_optimizer( + config=optimizer_config, + model_chunks=[model], + use_gloo_process_groups=True, + layer_wise_distributed_optimizer=use_layer_wise, + pg_collection=pg_collection, + ) + return model, optimizer, pg_collection + + def create_model_and_optimizer_with_overlap_param_gather( + self, + model_class=SimpleModel, + clip_grad=1.0, + model_kwargs=None, + copy_from=None, + async_allgather=True, + grad_reduce_in_fp32=False, + bucket_size=None, + ): + """Create model, DDP wrapper, and optimizer with overlap-param-gather enabled. + + This variant sets overlap_param_gather=True in DDP config and uses + get_megatron_muon_optimizer with layer_wise_distributed_optimizer=True, + enabling the bucket-based async param gather path. 
+ + Args: + model_class: Model class to instantiate + clip_grad: Optional gradient clipping value + model_kwargs: Optional kwargs for model initialization + copy_from: Optional DDP model to copy weights from + async_allgather: If True, defer param all-gather to bucket infrastructure + grad_reduce_in_fp32: If True, reduce grads in fp32 (regression test for dtype fix) + bucket_size: Maximum number of parameters per bucket (None = single bucket) + + Returns: + tuple: (model, optimizer, pg_collection) + """ + if model_kwargs is None: + model_kwargs = {} + + model = model_class(**model_kwargs).bfloat16().cuda() + model.requires_grad_(True) + + ddp_config = DistributedDataParallelConfig( + use_distributed_optimizer=False, + overlap_param_gather=True, + overlap_grad_reduce=True, + grad_reduce_in_fp32=grad_reduce_in_fp32, + bucket_size=bucket_size, + ) + model = DistributedDataParallel( + TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, model + ) + if copy_from: + model.module.load_state_dict(copy_from.module.state_dict()) + else: + model.broadcast_params() + + optimizer_config = OptimizerConfig( + optimizer='muon', + lr=0.01, + weight_decay=0.01, + bf16=True, + use_distributed_optimizer=False, + clip_grad=clip_grad, + overlap_param_gather=async_allgather, + muon_tp_mode="duplicated", + ) + + pg_collection = ProcessGroupCollection.use_mpu_process_groups() + pg_collection.dp_cp = parallel_state.get_data_parallel_group(with_context_parallel=True) + pg_collection.expt_dp = parallel_state.get_expert_data_parallel_group() + + optimizer = get_megatron_muon_optimizer( + config=optimizer_config, + model_chunks=[model], + use_gloo_process_groups=True, + layer_wise_distributed_optimizer=True, + pg_collection=pg_collection, + ) return model, optimizer, pg_collection def create_reference_model(self, model): @@ -239,11 +315,19 @@ def test_sharded_state_dict(self): # Test sharded_state_dict sharded_state_dict = optimizer.sharded_state_dict(model_sharded_state_dict) - 
# Verify the sharded_state_dict is not None and has expected structure + # Verify the sharded_state_dict is not None and has expected structure. + # With multiple chained optimizers (muon + adam), the top-level keys are + # integer indices; each sub-dict should contain an 'optimizer' key. assert sharded_state_dict is not None, "Sharded state dict should not be None" - assert ( - 'optimizer' in sharded_state_dict - ), "Sharded state dict should contain 'optimizer' key" + if isinstance(sharded_state_dict, dict) and all( + isinstance(k, int) for k in sharded_state_dict.keys() + ): + for idx, sub_dict in sharded_state_dict.items(): + assert 'optimizer' in sub_dict, f"Sub-dict {idx} should contain 'optimizer' key" + else: + assert ( + 'optimizer' in sharded_state_dict + ), "Sharded state dict should contain 'optimizer' key" # Verify that replica_id is set correctly (should be 0 for DP dimension) from megatron.core.dist_checkpointing import ShardedTensor @@ -261,42 +345,13 @@ def test_sharded_state_dict(self): def test_multiple_optimizers(self): """Test LayerWiseDistributedOptimizer with multiple chained optimizers. - This test properly tests allgather functionality with multiple ranks. + Uses get_megatron_muon_optimizer which produces multiple chained optimizers + (muon for 2D weights + adam for 1D biases). Tests allgather with multiple ranks. 
""" - model = SimpleModel().bfloat16().cuda() - model.requires_grad_(True) - - ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=False) - model = DistributedDataParallel( - TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, model - ) - - optimizer_config = OptimizerConfig( - optimizer='adam', lr=0.01, bf16=True, use_distributed_optimizer=False - ) - - # Split parameters into two groups for testing multiple optimizers - params = list(model.parameters()) - mid_point = len(params) // 2 - param_groups_1 = [{'params': params[:mid_point]}] - param_groups_2 = [{'params': params[mid_point:]}] - - # Create two separate base optimizers - base_optimizer_1 = torch.optim.Adam(param_groups_1, lr=optimizer_config.lr) - base_optimizer_2 = torch.optim.Adam(param_groups_2, lr=optimizer_config.lr) - - wrapped_optimizer_1 = FP32Optimizer(base_optimizer_1, optimizer_config, None) - wrapped_optimizer_2 = FP32Optimizer(base_optimizer_2, optimizer_config, None) - - pg_collection = ProcessGroupCollection.use_mpu_process_groups() - pg_collection.dp_cp = parallel_state.get_data_parallel_group(with_context_parallel=True) - pg_collection.expt_dp = parallel_state.get_expert_data_parallel_group() - - optimizer = LayerWiseDistributedOptimizer( - [wrapped_optimizer_1, wrapped_optimizer_2], optimizer_config, pg_collection - ) + model, optimizer, pg_collection = self.create_model_and_optimizer() - assert len(optimizer.chained_optimizers) == 2, "Should have two chained optimizers" + # get_megatron_muon_optimizer produces muon + adam chained optimizers + assert len(optimizer.chained_optimizers) >= 2, "Should have multiple chained optimizers" # Set gradients and test optimizer step - this will trigger allgather for param in model.parameters(): @@ -332,26 +387,39 @@ def test_bf16_error(self): TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, model ) + pg_collection = ProcessGroupCollection.use_mpu_process_groups() + pg_collection.dp_cp = 
parallel_state.get_data_parallel_group(with_context_parallel=True) + pg_collection.expt_dp = parallel_state.get_expert_data_parallel_group() + + # Create muon optimizer (non-layer-wise) — produces Float16-wrapped chained optimizers optimizer_config = OptimizerConfig( - optimizer='adam', lr=0.01, bf16=True, use_distributed_optimizer=False + optimizer='muon', + lr=0.01, + bf16=True, + use_distributed_optimizer=False, + muon_tp_mode="duplicated", ) - - # Create base optimizer and manually wrap in Float16 optimizer - param_groups = [{'params': list(model.parameters())}] - base_optimizer = torch.optim.Adam(param_groups, lr=optimizer_config.lr) - wrapped_optimizer = Float16OptimizerWithFloat16Params( - base_optimizer, optimizer_config, None, None + muon_optimizer = get_megatron_muon_optimizer( + config=optimizer_config, + model_chunks=[model], + use_gloo_process_groups=True, + layer_wise_distributed_optimizer=False, + pg_collection=pg_collection, ) - pg_collection = ProcessGroupCollection.use_mpu_process_groups() - pg_collection.dp_cp = parallel_state.get_data_parallel_group(with_context_parallel=True) - pg_collection.expt_dp = parallel_state.get_expert_data_parallel_group() + # Extract a Float16-wrapped chained optimizer + wrapped_optimizer = muon_optimizer.chained_optimizers[0] + assert isinstance(wrapped_optimizer, Float16OptimizerWithFloat16Params) # Should raise TypeError when receiving already-wrapped Float16 optimizer + # Use a fresh config since get_megatron_muon_optimizer mutates config.optimizer + lw_config = OptimizerConfig( + optimizer='muon', lr=0.01, bf16=True, use_distributed_optimizer=False + ) with pytest.raises( TypeError, match='LayerWiseDistributedOptimizer received Float16 optimizer already' ): - LayerWiseDistributedOptimizer([wrapped_optimizer], optimizer_config, pg_collection) + LayerWiseDistributedOptimizer([wrapped_optimizer], lw_config, pg_collection) def _run_parameter_update_test(self, model_class=SimpleModel): """Helper method to test 
parameter updates with a given model class. @@ -438,3 +506,662 @@ def test_broadcast_vs_allgather(self): # Verify updated values match reference optimizer for param, ref_param in zip(model.parameters(), reference_model.parameters()): torch.testing.assert_close(param.data, ref_param.data, rtol=0, atol=0) + + # ---- Overlap-param-gather tests ---- + + def test_overlap_param_gather_basic(self): + """Test overlap-param-gather path: init, forward/backward/step, bucket-based param sync.""" + model, optimizer, pg_collection = ( + self.create_model_and_optimizer_with_overlap_param_gather() + ) + + assert optimizer is not None, "Optimizer should not be None" + assert optimizer.async_allgather, "async_allgather should be True" + + reference_model = self.create_reference_model(model) + + input_tensor = torch.randn(16, 80, dtype=torch.bfloat16, device='cuda') + output = model(input_tensor) + loss = output.sum() + loss.backward() + + # step() updates local params but skips allgather (async_allgather=True) + update_successful, grad_norm, num_zeros = optimizer.step() + + assert update_successful, "Optimizer step should be successful" + + # Manually sync params through the bucket-based param sync path + # force_sync=True does synchronous allgather via bucket infrastructure + model.start_param_sync(force_sync=True) + + # Verify parameters were updated + params_updated = 0 + for param, ref_param in zip(model.parameters(), reference_model.parameters()): + if not torch.equal(param.data, ref_param.data): + params_updated += 1 + + assert params_updated > 0, "At least some parameters should be updated" + + # Verify all ranks have the same updated parameters + dp_size = get_pg_size(pg_collection.dp_cp) + + if dp_size > 1: + for name, param in model.named_parameters(): + param_list = [torch.zeros_like(param.data) for _ in range(dp_size)] + torch.distributed.all_gather(param_list, param.data, group=pg_collection.dp_cp) + + for i in range(1, dp_size): + torch.testing.assert_close( + 
param_list[0], + param_list[i], + msg=f"Parameter {name} differs between rank 0 and rank {i}", + ) + + def test_overlap_param_gather_parameter_updates(self): + """Test overlap-param-gather produces same parameter updates as standard optimizer.""" + model, optimizer, pg_collection = ( + self.create_model_and_optimizer_with_overlap_param_gather() + ) + + # Create reference model with standard (non-layer-wise) optimizer + reference_model, reference_optimizer, _ = self.create_model_and_optimizer( + use_layer_wise=False, copy_from=model + ) + + # Set same gradients on both models + for param, ref_param in zip(model.parameters(), reference_model.parameters()): + assert torch.equal(param.data, ref_param.data) + grad_value = torch.randn_like(param) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + param.main_grad = grad_value.clone().detach() + ref_param.main_grad = grad_value.clone().detach() + + # step() with async_allgather=True: updates but no allgather + optimizer.step() + # Manually sync params via bucket infrastructure + model.start_param_sync(force_sync=True) + + reference_optimizer.step() + + # Verify updated values match reference optimizer + for param, ref_param in zip(model.parameters(), reference_model.parameters()): + torch.testing.assert_close(param.data, ref_param.data, rtol=1e-5, atol=1e-5) + + def test_overlap_param_gather_vs_sync_allgather(self): + """Key correctness test: overlap path and sync allgather produce identical updates. 
+ + Compares: + - Overlap path: async_allgather=True, bucket-based param sync + - Sync path: async_allgather=False, optimizer.allgather_params() in step() + """ + # Create overlap model + overlap_model, overlap_optimizer, pg_collection = ( + self.create_model_and_optimizer_with_overlap_param_gather(async_allgather=True) + ) + + # Create sync model with same weights (overlap_param_gather=True but sync allgather) + sync_model, sync_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=False, copy_from=overlap_model + ) + + # Verify initial parameters match + for op, sp in zip(overlap_model.parameters(), sync_model.parameters()): + assert torch.equal(op.data, sp.data) + + # Set identical gradients on both + for op, sp in zip(overlap_model.parameters(), sync_model.parameters()): + grad_value = torch.randn_like(op) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + op.main_grad = grad_value.clone().detach() + sp.main_grad = grad_value.clone().detach() + + # Overlap path: step + manual sync + overlap_optimizer.step() + overlap_model.start_param_sync(force_sync=True) + + # Sync path: step (includes allgather_params) + sync_optimizer.step() + + # Both paths should produce identical parameter values + for op, sp in zip(overlap_model.parameters(), sync_model.parameters()): + torch.testing.assert_close( + op.data, + sp.data, + rtol=0, + atol=0, + msg="Overlap and sync allgather paths produced different parameter updates", + ) + + def test_overlap_param_gather_bucket_lw_params(self): + """Verify bucket.layerwise_params_list is populated when async_allgather is enabled.""" + model, optimizer, pg_collection = ( + self.create_model_and_optimizer_with_overlap_param_gather() + ) + + dp_size = get_pg_size(pg_collection.dp_cp) + + for bucket_group in model.bucket_groups: + for bucket in bucket_group.buckets: + # layerwise_params_list should be populated by set_bucket_layerwise_params_list + assert ( + 
bucket.layerwise_params_list is not None + ), "bucket.layerwise_params_list should be populated" + assert ( + len(bucket.layerwise_params_list) == dp_size + ), f"Expected {dp_size} per-rank lists, got {len(bucket.layerwise_params_list)}" + + # The union of all per-rank param lists should cover all bucket params + all_lw_params = set() + for rank_params in bucket.layerwise_params_list: + for p in rank_params: + all_lw_params.add(p) + assert ( + all_lw_params == bucket.params + ), "Union of per-rank layerwise_params should equal bucket params" + + # layerwise_param_flat_sizes should be populated and have correct length + assert bucket.layerwise_param_flat_sizes is not None + assert len(bucket.layerwise_param_flat_sizes) == dp_size + + # Each flat size should equal the sum of param numels for that rank + for rank_idx in range(dp_size): + expected_size = sum(p.numel() for p in bucket.layerwise_params_list[rank_idx]) + assert bucket.layerwise_param_flat_sizes[rank_idx] == expected_size, ( + f"Rank {rank_idx}: expected flat_size {expected_size}, " + f"got {bucket.layerwise_param_flat_sizes[rank_idx]}" + ) + + def test_overlap_param_gather_vs_standard_ddp(self): + """Verify DDP with overlap_param_gather=True produces same results as standard DDP. 
+ + Both use LayerWiseDistributedOptimizer but with different DDP configs: + - Overlap path: overlap_param_gather=True (padded buffers) + - Standard path: overlap_param_gather=False (unpadded buffers) + """ + # Create overlap-param-gather model (sync allgather for simpler comparison) + opg_model, opg_optimizer, pg_collection = ( + self.create_model_and_optimizer_with_overlap_param_gather(async_allgather=False) + ) + + # Create standard model with same weights + std_model, std_optimizer, _ = self.create_model_and_optimizer(copy_from=opg_model) + + # Set identical gradients + for op, sp in zip(opg_model.parameters(), std_model.parameters()): + assert torch.equal(op.data, sp.data) + grad_value = torch.randn_like(op) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + op.main_grad = grad_value.clone().detach() + sp.main_grad = grad_value.clone().detach() + + opg_optimizer.step() + std_optimizer.step() + + # Both should produce identical parameter values + for op, sp in zip(opg_model.parameters(), std_model.parameters()): + torch.testing.assert_close( + op.data, + sp.data, + rtol=1e-5, + atol=1e-5, + msg="Overlap-param-gather and standard paths produced different updates", + ) + + def test_overlap_param_gather_insufficient_parameters(self): + """Test overlap-param-gather with TinyModel (only 2 params). + + Many ranks will have no assigned params when world_size > 2. 
+ """ + model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( + model_class=TinyModel + ) + + # Create reference model with standard (non-layer-wise) optimizer + reference_model, reference_optimizer, _ = self.create_model_and_optimizer( + model_class=TinyModel, use_layer_wise=False, copy_from=model + ) + + # Set same gradients on both models + for param, ref_param in zip(model.parameters(), reference_model.parameters()): + assert torch.equal(param.data, ref_param.data) + grad_value = torch.randn_like(param) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + param.main_grad = grad_value.clone().detach() + ref_param.main_grad = grad_value.clone().detach() + + optimizer.step() + model.start_param_sync(force_sync=True) + + reference_optimizer.step() + + # Verify updated values match reference optimizer + for param, ref_param in zip(model.parameters(), reference_model.parameters()): + torch.testing.assert_close(param.data, ref_param.data, rtol=1e-5, atol=1e-5) + + def test_overlap_param_gather_broadcast_vs_allgather(self): + """Test overlap-param-gather: allgather vs broadcast produce same results.""" + model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( + model_class=SimpleModel, async_allgather=False + ) + + # Create reference model with overlap-param-gather path too + reference_model, reference_optimizer, _ = ( + self.create_model_and_optimizer_with_overlap_param_gather( + model_class=SimpleModel, async_allgather=False, copy_from=model + ) + ) + + # Set same gradients on both models + for param, ref_param in zip(model.parameters(), reference_model.parameters()): + assert torch.equal(param.data, ref_param.data) + torch.testing.assert_close(param.data, ref_param.data, rtol=0, atol=0) + grad_value = torch.randn_like(param) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + param.main_grad = grad_value.clone().detach() + 
ref_param.main_grad = grad_value.clone().detach() + + optimizer.step() + + # Verify at least some parameters were updated + params_updated = 0 + for param, ref_param in zip(model.parameters(), reference_model.parameters()): + if not torch.equal(param.data, ref_param.data): + params_updated += 1 + + assert params_updated > 0, "At least some parameters should be updated" + + # step() internally calls allgather_params. Replace reference with broadcast. + reference_optimizer.allgather_params = reference_optimizer.broadcast_params + reference_optimizer.step() + + # Verify updated values match reference optimizer + for param, ref_param in zip(model.parameters(), reference_model.parameters()): + torch.testing.assert_close(param.data, ref_param.data, rtol=0, atol=0) + + def test_overlap_param_gather_multi_iteration(self): + """Test overlap-param-gather correctness over multiple training iterations. + + Runs multiple forward/backward/step iterations using the async allgather path. + After each iteration, manually syncs params and verifies they match a reference + model using the sync path. 
+ """ + model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=True + ) + + # Create reference model with sync allgather for comparison + ref_model, ref_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=False, copy_from=model + ) + + for iteration in range(3): + # Set identical gradients on both models + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + grad_value = torch.randn_like(param) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + param.main_grad = grad_value.clone().detach() + ref_param.main_grad = grad_value.clone().detach() + + # Async path: step (no allgather) + manual sync + optimizer.step() + model.start_param_sync(force_sync=True) + + # Sync path: step (includes allgather) + ref_optimizer.step() + + # Verify parameters match after each iteration + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + torch.testing.assert_close( + param.data, + ref_param.data, + rtol=0, + atol=0, + msg=f"Parameters diverged at iteration {iteration}", + ) + + def test_overlap_param_gather_async_dispatch_and_finish(self): + """Test async dispatch + finish_param_sync cycle (the actual runtime path). + + start_param_sync() (no force_sync) dispatches async all-gathers, then + finish_param_sync() waits on the handle and unflattens gathered params. 
+ """ + model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=True + ) + ref_model, ref_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=False, copy_from=model + ) + + # Set identical gradients on both models + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + grad_value = torch.randn_like(param) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + param.main_grad = grad_value.clone().detach() + ref_param.main_grad = grad_value.clone().detach() + + # Async path: step (no allgather) + async dispatch + explicit finish + optimizer.step() + model.start_param_sync() # async dispatch to all bucket groups + for bucket_group in model.bucket_groups: + bucket_group.finish_param_sync(skip_next_bucket_dispatch=True) + + # Sync path: step (includes allgather) + ref_optimizer.step() + + # Verify params match sync path + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + torch.testing.assert_close( + param.data, + ref_param.data, + rtol=0, + atol=0, + msg="Async dispatch + finish path produced different params than sync path", + ) + + # Verify all ranks have identical parameters + dp_size = get_pg_size(pg_collection.dp_cp) + if dp_size > 1: + for name, param in model.named_parameters(): + param_list = [torch.zeros_like(param.data) for _ in range(dp_size)] + torch.distributed.all_gather(param_list, param.data, group=pg_collection.dp_cp) + for i in range(1, dp_size): + torch.testing.assert_close( + param_list[0], + param_list[i], + msg=f"Parameter {name} differs between rank 0 and rank {i}", + ) + + def test_overlap_param_gather_finish_chains_next_bucket(self): + """Test that finish_param_sync() dispatches next_param_gather_bucket_group. + + Uses a small bucket_size to force multiple bucket groups, then dispatches + only the last bucket group and verifies that finishing it chains to the next. 
+ """ + model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=True, bucket_size=2000 + ) + + bucket_groups = model.bucket_groups + if len(bucket_groups) <= 1: + pytest.skip("Need multiple bucket groups to test chaining") + + ref_model, ref_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=False, copy_from=model, bucket_size=2000 + ) + + # Set identical gradients on both models + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + grad_value = torch.randn_like(param) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + param.main_grad = grad_value.clone().detach() + ref_param.main_grad = grad_value.clone().detach() + + optimizer.step() + + # Dispatch ONLY the last bucket group (which has next_param_gather_bucket_group set) + last_bg = bucket_groups[-1] + last_bg.start_param_sync() + + # Verify: next bucket group has NOT been dispatched yet + next_bg = last_bg.next_param_gather_bucket_group + assert next_bg is not None, "Last bucket group should have a next" + assert not next_bg.param_gather_dispatched, "Next bucket should not be dispatched yet" + + # Finish the last bucket group — should chain-dispatch the next one + last_bg.finish_param_sync() + + # Verify: next bucket group IS now dispatched via chaining + assert ( + next_bg.param_gather_dispatched + ), "finish_param_sync should have dispatched next bucket group" + + # Finish remaining bucket groups through the chain + for bg in reversed(bucket_groups[:-1]): + bg.finish_param_sync() + + # Reference: sync step + ref_optimizer.step() + + # Verify params match + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + torch.testing.assert_close( + param.data, + ref_param.data, + rtol=0, + atol=0, + msg="Chained bucket finish produced different params than sync path", + ) + + def test_overlap_param_gather_forward_pre_hook(self): + """Test forward pre-hooks 
trigger finish_param_sync during model(input).
+
+        After async dispatch, running model(input) fires forward pre-hooks that
+        call finish_param_sync() on each bucket group, completing the param sync.
+        """
+        model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather(
+            async_allgather=True
+        )
+        ref_model, ref_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather(
+            async_allgather=False, copy_from=model
+        )
+
+        # Set identical gradients on both models
+        for param, ref_param in zip(model.parameters(), ref_model.parameters()):
+            grad_value = torch.randn_like(param)
+            torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp)
+            param.main_grad = grad_value.clone().detach()
+            ref_param.main_grad = grad_value.clone().detach()
+
+        # Async path: step (no allgather) + async dispatch
+        optimizer.step()
+        model.start_param_sync()  # dispatch async all-gathers
+
+        # Forward pass triggers hooks that call finish_param_sync()
+        input_tensor = torch.randn(16, 80, dtype=torch.bfloat16, device='cuda')
+        output = model(input_tensor)
+
+        # Sync path: step (includes allgather)
+        ref_optimizer.step()
+
+        # Verify params match
+        for param, ref_param in zip(model.parameters(), ref_model.parameters()):
+            torch.testing.assert_close(
+                param.data,
+                ref_param.data,
+                rtol=0,
+                atol=0,
+                msg="Forward pre-hook path produced different params than sync path",
+            )
+
+    def test_overlap_param_gather_grad_reduce_in_fp32(self):
+        """Regression test: grad_reduce_in_fp32 must not cause dtype mismatch in broadcasts.
+
+        When grad_reduce_in_fp32=True, the grad buffer dtype is fp32 but broadcast
+        buffers must use param dtype (bf16). Without the fix (commit cbed167fc), this
+        would cause a dtype mismatch error in the per-rank broadcast calls.
+ """ + model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=True, grad_reduce_in_fp32=True + ) + ref_model, ref_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=False, copy_from=model, grad_reduce_in_fp32=True + ) + + # Set identical gradients on both models + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + grad_value = torch.randn_like(param) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + param.main_grad = grad_value.clone().detach() + ref_param.main_grad = grad_value.clone().detach() + + # Async path: step + force_sync + optimizer.step() + model.start_param_sync(force_sync=True) + + # Sync path: step (includes allgather) + ref_optimizer.step() + + # Verify params match + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + torch.testing.assert_close( + param.data, + ref_param.data, + rtol=1e-5, + atol=1e-5, + msg="grad_reduce_in_fp32 path produced different params than reference", + ) + + def test_overlap_param_gather_hook_enable_disable_cycle(self): + """Test the training loop's hook lifecycle: disable → manual sync → enable → forward. + + The training loop disables hooks before iteration 1 (for initialization), + then enables them for subsequent iterations. This test exercises that cycle. 
+ """ + model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=True + ) + ref_model, ref_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=False, copy_from=model + ) + + input_tensor = torch.randn(16, 80, dtype=torch.bfloat16, device='cuda') + + # Iteration 1: hooks disabled, manual sync + model.disable_forward_pre_hook(param_sync=False) + + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + grad_value = torch.randn_like(param) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + param.main_grad = grad_value.clone().detach() + ref_param.main_grad = grad_value.clone().detach() + + optimizer.step() + model.start_param_sync(force_sync=True) # manual sync + + ref_optimizer.step() + + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + torch.testing.assert_close( + param.data, + ref_param.data, + rtol=0, + atol=0, + msg="Params diverged after iteration 1 (hooks disabled)", + ) + + # Iteration 2: hooks re-enabled, forward pass triggers sync + model.enable_forward_pre_hook() + + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + grad_value = torch.randn_like(param) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + param.main_grad = grad_value.clone().detach() + ref_param.main_grad = grad_value.clone().detach() + + optimizer.step() + model.start_param_sync() # async dispatch + output = model(input_tensor) # hooks finish sync + + ref_optimizer.step() + + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + torch.testing.assert_close( + param.data, + ref_param.data, + rtol=0, + atol=0, + msg="Params diverged after iteration 2 (hooks re-enabled)", + ) + + def test_overlap_param_gather_multi_iteration_with_hooks(self): + """Test multiple iterations using forward pre-hooks (not manual force_sync). 
+ + Runs 3 iterations where each iteration uses: set grads → step → async dispatch → + forward pass (hooks wait+unflatten). Compares against reference model using sync + allgather after each iteration. + """ + model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=True + ) + ref_model, ref_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=False, copy_from=model + ) + + input_tensor = torch.randn(16, 80, dtype=torch.bfloat16, device='cuda') + + for iteration in range(3): + # Set identical gradients on both models + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + grad_value = torch.randn_like(param) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + param.main_grad = grad_value.clone().detach() + ref_param.main_grad = grad_value.clone().detach() + + # Async path: step + dispatch + forward (hooks wait+unflatten) + optimizer.step() + model.start_param_sync() # async dispatch + output = model(input_tensor) # hooks trigger finish_param_sync + + # Sync path: step (includes allgather) + ref_optimizer.step() + + # Verify parameters match after each iteration + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + torch.testing.assert_close( + param.data, + ref_param.data, + rtol=0, + atol=0, + msg=f"Parameters diverged at iteration {iteration}", + ) + + def test_overlap_param_gather_start_sync_with_autograd(self): + """Regression test: start_param_sync must work when autograd is active. + + _flatten_dense_tensors on params with requires_grad=True produces a tensor + that also requires grad. Since all_gather writes into gather_list entries + in-place and the local rank's slot reuses src, this triggers: + RuntimeError: a view of a leaf Variable that requires grad is being + used in an in-place operation. + The fix is to .detach() the flattened tensor before using it as src. 
+ + This test calls start_param_sync (synchronous via force_sync) WITHOUT + torch.no_grad() to reproduce the exact scenario that occurs during the + forward pass when finish_param_sync chains to start_param_sync for the + next bucket group. + """ + model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=True + ) + ref_model, ref_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather( + async_allgather=False, copy_from=model + ) + + # Confirm params require grad (the precondition for this bug). + for param in model.parameters(): + assert param.requires_grad, "Test requires params with requires_grad=True" + + # Set identical gradients on both models. + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + grad_value = torch.randn_like(param) + torch.distributed.broadcast(grad_value, src=0, group=pg_collection.dp_cp) + param.main_grad = grad_value.clone().detach() + ref_param.main_grad = grad_value.clone().detach() + + # Step both optimizers (async path skips allgather, ref path includes it). + optimizer.step() + ref_optimizer.step() + + # Call start_param_sync with autograd ENABLED (no torch.no_grad()). + # Before the .detach() fix, this would raise RuntimeError. + model.start_param_sync(force_sync=True) + + # Verify gathered params match the reference. + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + torch.testing.assert_close( + param.data, + ref_param.data, + rtol=0, + atol=0, + msg="Params incorrect after start_param_sync with autograd enabled", + )