From 1430c3495b60613a6bb6f035ef38456529443a14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niccol=C3=B2=20Ajroldi?= Date: Mon, 3 Nov 2025 17:12:11 +0100 Subject: [PATCH 1/5] async AllGather --- muon.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/muon.py b/muon.py index 8f11732..5287e99 100644 --- a/muon.py +++ b/muon.py @@ -76,6 +76,7 @@ def step(self, closure=None): with torch.enable_grad(): loss = closure() + handles = [] for group in self.param_groups: params = group["params"] params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size()) @@ -91,7 +92,11 @@ def step(self, closure=None): update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) p.mul_(1 - group["lr"] * group["weight_decay"]) p.add_(update.reshape(p.shape), alpha=-group["lr"]) - dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()]) + handle = dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()]) + handles.append(handle) + + for handle in handles: + handle.wait() return loss From 07c2db30ca9f45f2f15a45582eb4248e1b604539 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niccol=C3=B2=20Ajroldi?= Date: Tue, 4 Nov 2025 09:10:47 +0100 Subject: [PATCH 2/5] async AllGather --- muon.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/muon.py b/muon.py index 5287e99..7f5e592 100644 --- a/muon.py +++ b/muon.py @@ -92,7 +92,7 @@ def step(self, closure=None): update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) p.mul_(1 - group["lr"] * group["weight_decay"]) p.add_(update.reshape(p.shape), alpha=-group["lr"]) - handle = dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()]) + handle = dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()], async_op=True) handles.append(handle) for handle in handles: From 1dbeb159f74f058644860bc24f0dd0c10478e7e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niccol=C3=B2=20Ajroldi?= <61059403+Niccolo-Ajroldi@users.noreply.github.com> Date: Sun, 30 Nov 2025 15:43:37 +0100 Subject: [PATCH 3/5] Refactor all_gather futures --- muon.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/muon.py b/muon.py index 7f5e592..e6042ec 100644 --- a/muon.py +++ b/muon.py @@ -76,7 +76,7 @@ def step(self, closure=None): with torch.enable_grad(): loss = closure() - handles = [] + all_gather_futures = [] for group in self.param_groups: params = group["params"] params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size()) @@ -92,11 +92,12 @@ def step(self, closure=None): update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) p.mul_(1 - group["lr"] * group["weight_decay"]) p.add_(update.reshape(p.shape), alpha=-group["lr"]) - handle = dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()], async_op=True) - handles.append(handle) + future = dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], + params_pad[base_i + dist.get_rank()], + async_op=True).get_future() + all_gather_futures.append(future) - for handle in handles: - handle.wait() + torch.futures.wait_all(all_gather_futures) return loss From 2e7a023888a148f631ad97a2048f10b23314e6c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niccol=C3=B2=20Ajroldi?= <61059403+Niccolo-Ajroldi@users.noreply.github.com> Date: Mon, 1 Dec 2025 18:30:53 +0100 Subject: [PATCH 4/5] Refactor future waiting --- muon.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/muon.py b/muon.py index e6042ec..3150a0f 100644 --- a/muon.py +++ b/muon.py @@ -94,10 +94,11 @@ def step(self, closure=None): p.add_(update.reshape(p.shape), alpha=-group["lr"]) future = dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()], - async_op=True).get_future() + async_op=True) all_gather_futures.append(future) - torch.futures.wait_all(all_gather_futures) + for fut in gather_futures: + fut.wait() return loss From 7b3d6602400ddcab1f85bfe659615d2eacce2bdb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niccol=C3=B2=20Ajroldi?= <61059403+Niccolo-Ajroldi@users.noreply.github.com> Date: Mon, 1 Dec 2025 18:33:48 +0100 Subject: [PATCH 5/5] Update muon.py --- muon.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/muon.py b/muon.py index 3150a0f..05b44fd 100644 --- a/muon.py +++ b/muon.py @@ -97,7 +97,7 @@ def step(self, closure=None): async_op=True) all_gather_futures.append(future) - for fut in gather_futures: + for fut in all_gather_futures: fut.wait() return loss