# Patch summary (preamble text ahead of the `diff --git` header).
#
# Target: muon.py; both hunks sit inside `step(self, closure=None)`.
#
# Hunk 1 (@@ -76,6 +76,7 @@): adds `all_gather_futures = []` just before the
#   `for group in self.param_groups:` loop.
# Hunk 2 (@@ -91,7 +92,13 @@): replaces the single blocking `dist.all_gather(...)`
#   call with the same call plus `async_op=True`, appends the returned handle to
#   `all_gather_futures`, and adds a loop that calls `.wait()` on every collected
#   handle ahead of `return loss`.
#
# NOTE(review): the line breaks of this patch have been collapsed onto a single
#   line; the hunk line structure and indentation must be restored before it
#   can be applied.
# NOTE(review): presumably the intent is to let each gather overlap with the
#   parameter updates that follow it - confirm with the author.
# NOTE(review): the exact nesting of the new wait loop (after all groups vs.
#   inside the group loop) cannot be read from the flattened text - verify.
# NOTE(review): `[torch.empty_like(params[-1])] * k` repeats one tensor object
#   k times, so every padding slot in the gather output list is the same tensor;
#   confirm this is still acceptable now that gathers are in flight concurrently.
# NOTE(review): when len(params) is a multiple of the world size, the padding
#   expression evaluates to a full world-size worth of extra entries - confirm
#   whether the surrounding loop (not shown here) ever touches them.
diff --git a/muon.py b/muon.py index 8f11732..05b44fd 100644 --- a/muon.py +++ b/muon.py @@ -76,6 +76,7 @@ def step(self, closure=None): with torch.enable_grad(): loss = closure() + all_gather_futures = [] for group in self.param_groups: params = group["params"] params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size()) @@ -91,7 +92,13 @@ def step(self, closure=None): update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) p.mul_(1 - group["lr"] * group["weight_decay"]) p.add_(update.reshape(p.shape), alpha=-group["lr"]) - dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()]) + future = dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], + params_pad[base_i + dist.get_rank()], + async_op=True) + all_gather_futures.append(future) + + for fut in all_gather_futures: + fut.wait() return loss