Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
e317e8d
add --overlap-param-gather support for layer-wise optimizer (muon)
mchrzanowski Feb 13, 2026
d8189e4
Fix NaN in overlap-param-gather for layer-wise optimizer (Muon)
mchrzanowski Feb 15, 2026
2df43c9
Add unit tests for overlap-param-gather in layer-wise optimizer
mchrzanowski Feb 16, 2026
fa0ceb9
Remove use_layer_wise_optimizer from DDP config
mchrzanowski Feb 18, 2026
bbed683
Add comments explaining overlap_param_gather replacing use_layer_wise…
mchrzanowski Feb 18, 2026
97fba8b
Merge branch 'main' into overlap-param-gather-muon-rebased
mchrzanowski Feb 20, 2026
46bc317
Run autoformat (black, isort) on changed files
mchrzanowski Feb 20, 2026
7027091
Merge branch 'main' into overlap-param-gather-muon-rebased
mchrzanowski Feb 21, 2026
0c6a632
Merge branch 'main' into overlap-param-gather-muon-rebased
mchrzanowski Feb 22, 2026
3e88f38
Free overlap param-gather buffers before async checkpoint save to fix…
mchrzanowski Feb 22, 2026
3bf616c
Add unit tests for free_overlap_buffers
mchrzanowski Feb 22, 2026
f111804
Replace all_gather with per-rank broadcasts for layer-wise param gather
mchrzanowski Feb 22, 2026
211a0ca
Autoformat: black formatting fixes
mchrzanowski Feb 23, 2026
c7ee958
Remove assertions blocking overlap_grad_reduce and overlap_param_gath…
mchrzanowski Feb 23, 2026
cbed167
Fix dtype mismatch in layer-wise param gather broadcasts causing NCCL…
mchrzanowski Feb 23, 2026
7726a35
Fix timing-dependent NCCL deadlock in layer-wise param gather by wait…
mchrzanowski Feb 23, 2026
1b0f0b6
Merge branch 'main' into overlap-param-gather-muon-rebased
mchrzanowski Feb 23, 2026
5253153
Add missing regression tests for layer-wise optimizer overlap param g…
mchrzanowski Feb 23, 2026
189428b
Autoformat: black formatting fixes
mchrzanowski Feb 23, 2026
d07bb57
Switch layer-wise optimizer tests from adam to dist_muon
mchrzanowski Feb 23, 2026
3a0afd1
Add back overlap assertions for muon (but not dist_muon)
mchrzanowski Feb 23, 2026
97b9916
Remove redundant assert in finish_param_sync
mchrzanowski Feb 23, 2026
e5a53a2
Replace per-rank broadcasts with all_gather for layer-wise param gather
mchrzanowski Feb 24, 2026
be0d438
Merge branch 'main' into overlap-param-gather-muon-rebased
mchrzanowski Feb 25, 2026
6655a10
Merge branch 'main' into overlap-param-gather-muon-rebased
mchrzanowski Feb 25, 2026
87df1dd
Merge branch 'main' into overlap-param-gather-muon-rebased
mchrzanowski Feb 26, 2026
a566a7b
Skip --overlap-grad-reduce requirement for layer-wise optimizers
mchrzanowski Feb 28, 2026
f30914e
Rename lw -> layerwise, improve docstring, clear handles in wait()
mchrzanowski Mar 2, 2026
23c4516
Fix tests to use renamed layerwise_ attribute prefix
mchrzanowski Mar 3, 2026
d77abd5
Re-enable --overlap-grad-reduce assertion for --overlap-param-gather
mchrzanowski Mar 4, 2026
079e2a1
Merge branch 'main' into overlap-param-gather-muon-rebased
deepakn94 Mar 4, 2026
b22ecd3
Run autoformat.sh
mchrzanowski Mar 4, 2026
64dd7ea
Merge branch 'main' into overlap-param-gather-muon-rebased
mchrzanowski Mar 4, 2026
62b5ae2
Add missing docstring to fix lint error
mchrzanowski Mar 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions megatron/core/distributed/distributed_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,10 @@ def _allocate_buffers_for_parameters(

# Set `next_param_gather_bucket_group` for different bucket groups by iterating through
# buckets in reverse order (since all-gathers happen in reverse order of buckets).
if self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather:
# Note: overlap_param_gather covers both the distributed optimizer and the
# layer-wise optimizer cases; the latter sets overlap_param_gather=True
# without use_distributed_optimizer.
if self.ddp_config.overlap_param_gather:
num_bucket_groups = len(bucket_groups)
for i in range(1, num_bucket_groups):
bucket_groups[num_bucket_groups - i].next_param_gather_bucket_group = (
Expand Down Expand Up @@ -344,9 +347,10 @@ def unmap_weight_tensor(m):
grad_acc.register_hook(self._make_backward_post_hook(param))
self.grad_accs.append(grad_acc)

self.use_forward_hook = (
self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather
)
# Note: overlap_param_gather covers both the distributed optimizer and the
# layer-wise optimizer cases; the latter sets overlap_param_gather=True
# without use_distributed_optimizer.
self.use_forward_hook = self.ddp_config.overlap_param_gather
self.remove_forward_pre_hook_handles = {}
if self.use_forward_hook:
self.enable_forward_pre_hook()
Expand Down Expand Up @@ -534,6 +538,11 @@ def finish_grad_sync(self, force_all_reduce: Optional[bool] = False):
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.finish_grad_sync(force_all_reduce=force_all_reduce)

def free_overlap_buffers(self):
    """Free overlap param-gather GPU buffers across all bucket groups."""
    all_bucket_groups = self.bucket_groups + self.expert_parallel_bucket_groups
    for group in all_bucket_groups:
        group.free_overlap_buffers()

def scale_gradients(self, scaling_factor: float):
"""Scale all gradients inside the buffers by `scaling_factor`."""
for buffer in self.buffers + self.expert_parallel_buffers:
Expand Down
211 changes: 183 additions & 28 deletions megatron/core/distributed/param_and_grad_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Dict, List, Optional

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import _coalescing_manager

import megatron.core.nccl_allocator as nccl_allocator
Expand Down Expand Up @@ -112,6 +113,42 @@ def __init__(
global_start, global_end, _ = param_index_map[param]
self.param_to_index[param] = (global_start - offset, global_end - offset)

# Layer-wise optimizer attributes for async param gather.
self.layerwise_params_list = None
self.layerwise_param_flat_sizes = None
self.layerwise_gather_list = None
self._layerwise_src_buffer = None

def set_layerwise_params_list(self, layerwise_params_list: List[List[torch.nn.Parameter]]):
    """Set per-rank parameter lists for layer-wise async all-gather.

    Args:
        layerwise_params_list: List of param lists, one per rank in the DP group.
            Each inner list contains the parameters owned by that rank's
            layer-wise optimizer that also belong to this bucket.
    """
    self.layerwise_params_list = layerwise_params_list
    # Flattened element count per rank. A rank that owns no params in this
    # bucket contributes 0. Use a generator (not a list) inside sum() to
    # avoid materializing an intermediate list.
    self.layerwise_param_flat_sizes = [
        sum(p.numel() for p in param_list) for param_list in layerwise_params_list
    ]


class _LayerwiseAllGatherHandle:
"""Handle wrapping multiple async all-gather work objects.

NCCL guarantees in-order completion on the same communicator, so waiting
on only the last handle is sufficient.
"""

def __init__(self, handles):
self.handles = handles

def wait(self):
"""Wait on the last handle and clear all handles."""
if self.handles:
self.handles[-1].wait()
self.handles = None


class _ParamAndGradBucketGroup:
"""
Expand All @@ -138,11 +175,13 @@ def __init__(
self.buckets = buckets
self.ddp_config = ddp_config

if self.ddp_config.use_distributed_optimizer:
# overlap_param_gather covers the layer-wise optimizer case, which sets
# overlap_param_gather=True without use_distributed_optimizer.
if self.ddp_config.use_distributed_optimizer or self.ddp_config.overlap_param_gather:
self.intra_distributed_optimizer_instance_group = collective_group
self.intra_distributed_optimizer_instance_size = collective_group_size
self.intra_distributed_optimizer_instance_rank = collective_group.rank()
else:
if not self.ddp_config.use_distributed_optimizer:
self.data_parallel_group = collective_group

# State for bookkeeping: params is the set of parameters this bucket group is
Expand Down Expand Up @@ -262,7 +301,9 @@ def start_param_sync(self, force_sync: bool = False):
force_sync (bool, optional): force synchronous collective regardless of
other settings if true.
"""
assert self.ddp_config.use_distributed_optimizer
# overlap_param_gather covers the layer-wise optimizer case, which sets
# overlap_param_gather=True without use_distributed_optimizer.
assert self.ddp_config.use_distributed_optimizer or self.ddp_config.overlap_param_gather

if force_sync:
if self.param_gather_handle is not None:
Expand All @@ -273,33 +314,114 @@ def start_param_sync(self, force_sync: bool = False):
assert self.param_gather_handle is None

async_op = self.ddp_config.overlap_param_gather and not force_sync
# Coalesce communication kernels across buckets in the bucket group.
with _coalescing_manager(
self.intra_distributed_optimizer_instance_group, async_ops=async_op
) as cm:
for idx, bucket in enumerate(self.buckets):
if self.cached_param_buffer_shard_list[idx] is None:
self.cached_param_buffer_shard_list[idx] = shard_buffer(
bucket.param_data, self.intra_distributed_optimizer_instance_size

if not self.ddp_config.use_distributed_optimizer:
# Layer-wise optimizer path: use all_gather for variable-size
# param gather.
#
# Each rank may own a different number of params per bucket, so
# layerwise_param_flat_sizes can vary across ranks. PyTorch's NCCL
# backend handles uneven tensor sizes in torch.distributed.all_gather
# (falling back to grouped send/recv internally when sizes differ),
# so no manual padding is needed.
dp_size = self.intra_distributed_optimizer_instance_size
local_rank = self.intra_distributed_optimizer_instance_rank
group = self.intra_distributed_optimizer_instance_group
layerwise_work_handles = []
for bucket in self.buckets:
# Use param dtype (e.g., bf16), NOT grad dtype (which may be
# fp32 when grad_reduce_in_fp32 is enabled).
param_dtype = bucket.params_list[0].dtype

if max(bucket.layerwise_param_flat_sizes) == 0:
# All ranks have empty params for this bucket — skip.
bucket.layerwise_gather_list = None
continue

# Flatten local params. Detach from the autograd graph because
# start_param_sync can be called during the forward pass (where
# autograd is active) and all_gather will write into gather_list
# entries in-place.
local_size = bucket.layerwise_param_flat_sizes[local_rank]
if local_size > 0:
flat_local_params = _flatten_dense_tensors(
bucket.layerwise_params_list[local_rank]
).detach()
else:
flat_local_params = torch.empty(
0, device=bucket.grad_data.device, dtype=param_dtype
)
local_data_view = self.cached_param_buffer_shard_list[idx][
self.intra_distributed_optimizer_instance_rank
]
dist_all_gather_func(
bucket.param_data,
local_data_view,
group=self.intra_distributed_optimizer_instance_group,
async_op=async_op,
# Keep flat_local_params alive until the async operation completes.
bucket._layerwise_src_buffer = flat_local_params

# Allocate per-rank receive buffers with actual sizes (no padding).
# Reuse flat_local_params for local_rank's slot to avoid an extra allocation.
gather_list = []
for i in range(dp_size):
if i == local_rank:
gather_list.append(flat_local_params)
else:
gather_list.append(
torch.empty(
bucket.layerwise_param_flat_sizes[i],
device=flat_local_params.device,
dtype=flat_local_params.dtype,
)
)
bucket.layerwise_gather_list = gather_list

work = torch.distributed.all_gather(
gather_list, flat_local_params, group=group, async_op=async_op
)
if async_op:
self.param_gather_handle = cm
if async_op and work is not None:
layerwise_work_handles.append(work)

if async_op:
self.param_gather_handle = _LayerwiseAllGatherHandle(layerwise_work_handles)
else:
# Synchronous: unflatten and copy gathered params immediately.
for bucket in self.buckets:
if bucket.layerwise_gather_list is None:
continue
for idx, params in enumerate(bucket.layerwise_params_list):
if len(params) == 0 or idx == local_rank:
continue
updated_params = _unflatten_dense_tensors(
bucket.layerwise_gather_list[idx], params
)
for updated_p, model_p in zip(updated_params, params):
model_p.data.copy_(updated_p)
bucket.layerwise_gather_list = None
bucket._layerwise_src_buffer = None
self.param_gather_handle = None
else:
# When using `_coalescing_manager`, even if a synchronous op (async_op=False) is used,
# `cm` is not None, which is different from when `_coalescing_manager` is not used in
# which case the torch.distributed._all_gather_base() will return None. In order to
# maintain consistency with prior code, we need to manually set communication handle to
# None.
self.param_gather_handle = None
# Standard distributed optimizer path: use _coalescing_manager.
# all_gather_into_tensor writes directly into a contiguous output buffer and
# does not need a copy-back step, so coalescing works correctly.
with _coalescing_manager(
self.intra_distributed_optimizer_instance_group, async_ops=async_op
) as cm:
for idx, bucket in enumerate(self.buckets):
if self.cached_param_buffer_shard_list[idx] is None:
self.cached_param_buffer_shard_list[idx] = shard_buffer(
bucket.param_data, self.intra_distributed_optimizer_instance_size
)
local_data_view = self.cached_param_buffer_shard_list[idx][
self.intra_distributed_optimizer_instance_rank
]
dist_all_gather_func(
bucket.param_data,
local_data_view,
group=self.intra_distributed_optimizer_instance_group,
async_op=async_op,
)
if async_op:
self.param_gather_handle = cm
else:
# When using `_coalescing_manager`, even if a synchronous op
# (async_op=False) is used, `cm` is not None. Manually set to None for
# consistency with prior code.
self.param_gather_handle = None
self.param_gather_dispatched = True

def finish_param_sync(self, skip_next_bucket_dispatch: bool = False):
Expand All @@ -317,7 +439,6 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False):
skip_next_bucket_dispatch (bool, optional): if true, dispatch next
bucket's communication if available.
"""
assert self.ddp_config.use_distributed_optimizer
assert self.ddp_config.overlap_param_gather

# If current bucket's param AG has not been dispatched, dispatch it now (e.g., first
Expand Down Expand Up @@ -355,6 +476,25 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False):
# correspond to multiple param buffers. If we zero out the entire grad buffer,
# it would clear the data of those param buffers that have not yet completed AG.
bucket.param_data.zero_()
elif not self.ddp_config.use_distributed_optimizer:
for bucket in self.buckets:
if bucket.layerwise_gather_list is None:
continue
# Unflatten and copy gathered params for each rank.
for idx, params in enumerate(bucket.layerwise_params_list):
# Skip local params and empty tensors.
if (
len(params) == 0
or idx == self.intra_distributed_optimizer_instance_rank
):
continue
updated_params = _unflatten_dense_tensors(
bucket.layerwise_gather_list[idx], params
)
for updated_p, model_p in zip(updated_params, params):
model_p.data.copy_(updated_p)
bucket.layerwise_gather_list = None
bucket._layerwise_src_buffer = None
else:
fp8_params = []
for bucket in self.buckets:
Expand Down Expand Up @@ -539,6 +679,21 @@ def finish_grad_sync(self, force_all_reduce: Optional[bool] = False):
self.grad_reduce_handle.wait()
self.grad_reduce_handle = None

def free_overlap_buffers(self):
    """Free GPU buffers used by overlap param gather.

    Waits on any pending param all-gather handle, then releases the
    per-bucket temporary buffers so that the CUDA memory allocator can
    reclaim them. Called before async checkpoint saves to avoid OOM in
    the persistent checkpoint worker process.
    """
    pending = self.param_gather_handle
    if pending is not None:
        pending.wait()
        self.param_gather_handle = None
    for bucket in self.buckets:
        # Drop references so the allocator can reclaim the per-rank gather
        # buffers and the flattened local source buffer.
        bucket.layerwise_gather_list = None
        bucket._layerwise_src_buffer = None

def register_grad_ready(
self, param: torch.nn.Parameter, force_all_reduce: Optional[bool] = False
):
Expand Down
Loading
Loading