diff --git a/muon.py b/muon.py
index 8f11732..a0ab7f3 100644
--- a/muon.py
+++ b/muon.py
@@ -61,11 +61,14 @@ class Muon(torch.optim.Optimizer):
         lr: The learning rate, in units of spectral norm per update.
         weight_decay: The AdamW-style weight decay.
         momentum: The momentum. A value of 0.95 here is usually fine.
+        distributed: Whether to use distributed training. Defaults to `torch.distributed.is_initialized()`.
     """
-    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
+    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95, distributed=None):
         defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
         assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter)
-        params = sorted(params, key=lambda x: x.size(), reverse=True)
+        self.distributed = distributed if distributed is not None else dist.is_initialized()
+        if self.distributed:
+            params = sorted(params, key=lambda x: x.size(), reverse=True)
         super().__init__(params, defaults)
 
     @torch.no_grad()
@@ -78,51 +81,33 @@ def step(self, closure=None):
 
         for group in self.param_groups:
             params = group["params"]
-            params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
-            for base_i in range(len(params))[::dist.get_world_size()]:
-                if base_i + dist.get_rank() < len(params):
-                    p = params[base_i + dist.get_rank()]
+
+            if self.distributed:
+                world_size = dist.get_world_size()
+                rank = dist.get_rank()
+                params_pad = params + [torch.empty_like(params[-1])] * (world_size - len(params) % world_size)
+                for base_i in range(len(params))[::world_size]:
+                    if base_i + rank < len(params):
+                        p = params[base_i + rank]
+                        if p.grad is None:
+                            p.grad = torch.zeros_like(p) # Force synchronization
+                        state = self.state[p]
+                        if len(state) == 0:
+                            state["momentum_buffer"] = torch.zeros_like(p)
+                        update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
+                        p.mul_(1 - group["lr"] * group["weight_decay"])
+                        p.add_(update.reshape(p.shape), alpha=-group["lr"])
+                    dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank])
+            else:
+                for p in params:
                     if p.grad is None:
-                        # continue
-                        p.grad = torch.zeros_like(p) # Force synchronization
+                        p.grad = torch.zeros_like(p)
                     state = self.state[p]
                     if len(state) == 0:
                         state["momentum_buffer"] = torch.zeros_like(p)
                     update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                     p.mul_(1 - group["lr"] * group["weight_decay"])
                     p.add_(update.reshape(p.shape), alpha=-group["lr"])
-                dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
-
-        return loss
-
-
-class SingleDeviceMuon(torch.optim.Optimizer):
-    """
-    Muon variant for usage in non-distributed settings.
-    """
-    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
-        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
-        super().__init__(params, defaults)
-
-    @torch.no_grad()
-    def step(self, closure=None):
-
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
-        for group in self.param_groups:
-            for p in group["params"]:
-                if p.grad is None:
-                    # continue
-                    p.grad = torch.zeros_like(p) # Force synchronization
-                state = self.state[p]
-                if len(state) == 0:
-                    state["momentum_buffer"] = torch.zeros_like(p)
-                update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
-                p.mul_(1 - group["lr"] * group["weight_decay"])
-                p.add_(update.reshape(p.shape), alpha=-group["lr"])
 
         return loss
 
@@ -162,11 +147,13 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
     optimizer = MuonWithAuxAdam(param_groups)
     ```
     """
-    def __init__(self, param_groups):
+    def __init__(self, param_groups, distributed=None):
+        self.distributed = distributed if distributed is not None else dist.is_initialized()
         for group in param_groups:
             assert "use_muon" in group
             if group["use_muon"]:
-                group["params"] = sorted(group["params"], key=lambda x: x.size(), reverse=True)
+                if self.distributed:
+                    group["params"] = sorted(group["params"], key=lambda x: x.size(), reverse=True)
                 # defaults
                 group["lr"] = group.get("lr", 0.02)
                 group["momentum"] = group.get("momentum", 0.95)
@@ -192,25 +179,36 @@ def step(self, closure=None):
         for group in self.param_groups:
             if group["use_muon"]:
                 params = group["params"]
-                params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
-                for base_i in range(len(params))[::dist.get_world_size()]:
-                    if base_i + dist.get_rank() < len(params):
-                        p = params[base_i + dist.get_rank()]
+                if self.distributed:
+                    world_size = dist.get_world_size()
+                    rank = dist.get_rank()
+                    params_pad = params + [torch.empty_like(params[-1])] * (world_size - len(params) % world_size)
+                    for base_i in range(len(params))[::world_size]:
+                        if base_i + rank < len(params):
+                            p = params[base_i + rank]
+                            if p.grad is None:
+                                p.grad = torch.zeros_like(p) # Force synchronization
+                            state = self.state[p]
+                            if len(state) == 0:
+                                state["momentum_buffer"] = torch.zeros_like(p)
+                            update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
+                            p.mul_(1 - group["lr"] * group["weight_decay"])
+                            p.add_(update.reshape(p.shape), alpha=-group["lr"])
+                        dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank])
+                else:
+                    for p in params:
                         if p.grad is None:
-                            # continue
-                            p.grad = torch.zeros_like(p) # Force synchronization
+                            p.grad = torch.zeros_like(p)
                         state = self.state[p]
                         if len(state) == 0:
                             state["momentum_buffer"] = torch.zeros_like(p)
                         update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                         p.mul_(1 - group["lr"] * group["weight_decay"])
                         p.add_(update.reshape(p.shape), alpha=-group["lr"])
-                    dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
             else:
                 for p in group["params"]:
                     if p.grad is None:
-                        # continue
-                        p.grad = torch.zeros_like(p) # Force synchronization
+                        p.grad = torch.zeros_like(p)
                     state = self.state[p]
                     if len(state) == 0:
                         state["exp_avg"] = torch.zeros_like(p)
@@ -224,63 +222,6 @@ def step(self, closure=None):
 
         return loss
-
-
-class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
-    """
-    Non-distributed variant of MuonWithAuxAdam.
-    """
-    def __init__(self, param_groups):
-        for group in param_groups:
-            assert "use_muon" in group
-            if group["use_muon"]:
-                # defaults
-                group["lr"] = group.get("lr", 0.02)
-                group["momentum"] = group.get("momentum", 0.95)
-                group["weight_decay"] = group.get("weight_decay", 0)
-                assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
-            else:
-                # defaults
-                group["lr"] = group.get("lr", 3e-4)
-                group["betas"] = group.get("betas", (0.9, 0.95))
-                group["eps"] = group.get("eps", 1e-10)
-                group["weight_decay"] = group.get("weight_decay", 0)
-                assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
-        super().__init__(param_groups, dict())
-
-    @torch.no_grad()
-    def step(self, closure=None):
-
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
-        for group in self.param_groups:
-            if group["use_muon"]:
-                for p in group["params"]:
-                    if p.grad is None:
-                        # continue
-                        p.grad = torch.zeros_like(p) # Force synchronization
-                    state = self.state[p]
-                    if len(state) == 0:
-                        state["momentum_buffer"] = torch.zeros_like(p)
-                    update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
-                    p.mul_(1 - group["lr"] * group["weight_decay"])
-                    p.add_(update.reshape(p.shape), alpha=-group["lr"])
-            else:
-                for p in group["params"]:
-                    if p.grad is None:
-                        # continue
-                        p.grad = torch.zeros_like(p) # Force synchronization
-                    state = self.state[p]
-                    if len(state) == 0:
-                        state["exp_avg"] = torch.zeros_like(p)
-                        state["exp_avg_sq"] = torch.zeros_like(p)
-                        state["step"] = 0
-                    state["step"] += 1
-                    update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
-                                         state["step"], group["betas"], group["eps"])
-                    p.mul_(1 - group["lr"] * group["weight_decay"])
-                    p.add_(update, alpha=-group["lr"])
-
-        return loss
+# Alias for backward compatibility
+SingleDeviceMuon = Muon
+SingleDeviceMuonWithAuxAdam = MuonWithAuxAdam