muon.py: 163 changes (52 additions, 111 deletions)
@@ -61,11 +61,14 @@ class Muon(torch.optim.Optimizer):
         lr: The learning rate, in units of spectral norm per update.
         weight_decay: The AdamW-style weight decay.
         momentum: The momentum. A value of 0.95 here is usually fine.
+        distributed: Whether to use distributed training. Defaults to `torch.distributed.is_initialized()`.
     """
-    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
+    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95, distributed=None):
         defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
         assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter)
-        params = sorted(params, key=lambda x: x.size(), reverse=True)
+        self.distributed = distributed if distributed is not None else dist.is_initialized()
+        if self.distributed:
+            params = sorted(params, key=lambda x: x.size(), reverse=True)
         super().__init__(params, defaults)
 
     @torch.no_grad()
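With this hunk the single Muon class covers both the distributed and the single-process case; `distributed=None` falls back to `torch.distributed.is_initialized()`. A minimal usage sketch (the model and the parameter selection below are illustrative assumptions, not part of this PR):

```
import torch
from muon import Muon

# Illustrative model; Muon takes the list of 2D weight matrices it should optimize
# (biases, gains, embeddings and the like are meant to go to a separate AdamW).
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64))
muon_params = [p for p in model.parameters() if p.ndim >= 2]

# Single process: force the non-distributed path regardless of the environment.
opt = Muon(muon_params, lr=0.02, momentum=0.95, distributed=False)

# Under torchrun with an initialized process group, leaving distributed=None
# falls back to dist.is_initialized() and enables the sharded update path.
```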
@@ -78,51 +81,33 @@ def step(self, closure=None):
 
         for group in self.param_groups:
             params = group["params"]
-            params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
-            for base_i in range(len(params))[::dist.get_world_size()]:
-                if base_i + dist.get_rank() < len(params):
-                    p = params[base_i + dist.get_rank()]
+
+            if self.distributed:
+                world_size = dist.get_world_size()
+                rank = dist.get_rank()
+                params_pad = params + [torch.empty_like(params[-1])] * (world_size - len(params) % world_size)
+                for base_i in range(len(params))[::world_size]:
+                    if base_i + rank < len(params):
+                        p = params[base_i + rank]
+                        if p.grad is None:
+                            p.grad = torch.zeros_like(p) # Force synchronization
+                        state = self.state[p]
+                        if len(state) == 0:
+                            state["momentum_buffer"] = torch.zeros_like(p)
+                        update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
+                        p.mul_(1 - group["lr"] * group["weight_decay"])
+                        p.add_(update.reshape(p.shape), alpha=-group["lr"])
+                    dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank])
+            else:
+                for p in params:
                     if p.grad is None:
-                        # continue
-                        p.grad = torch.zeros_like(p) # Force synchronization
+                        p.grad = torch.zeros_like(p)
                     state = self.state[p]
                     if len(state) == 0:
                         state["momentum_buffer"] = torch.zeros_like(p)
                     update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                     p.mul_(1 - group["lr"] * group["weight_decay"])
                     p.add_(update.reshape(p.shape), alpha=-group["lr"])
-                dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
 
         return loss
 
-
-class SingleDeviceMuon(torch.optim.Optimizer):
-    """
-    Muon variant for usage in non-distributed settings.
-    """
-    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
-        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
-        super().__init__(params, defaults)
-
-    @torch.no_grad()
-    def step(self, closure=None):
-
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
-        for group in self.param_groups:
-            for p in group["params"]:
-                if p.grad is None:
-                    # continue
-                    p.grad = torch.zeros_like(p) # Force synchronization
-                state = self.state[p]
-                if len(state) == 0:
-                    state["momentum_buffer"] = torch.zeros_like(p)
-                update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
-                p.mul_(1 - group["lr"] * group["weight_decay"])
-                p.add_(update.reshape(p.shape), alpha=-group["lr"])
-
-        return loss
-
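The distributed branch keeps the existing sharding scheme: the parameter list is padded with dummy tensors to a multiple of the world size, each rank applies the Muon update to `params[base_i + rank]` within its block, and `dist.all_gather` then copies the whole updated block to every rank so the weights stay replicated. The padding with `torch.empty_like(params[-1])` exists so that every rank issues the same number of collective calls. A small index-only illustration of the assignment (not part of the diff, no process group required):

```
# Index-only illustration of the round-robin sharding above.
# Parameters are processed in blocks of `world_size`; within a block, rank r updates
# params[base_i + r] locally, then dist.all_gather replicates the whole block everywhere.
def owned_indices(num_params, world_size, rank):
    return [base_i + rank
            for base_i in range(0, num_params, world_size)
            if base_i + rank < num_params]

# e.g. 7 parameters across 4 ranks:
for r in range(4):
    print(r, owned_indices(7, 4, r))
# 0 [0, 4]
# 1 [1, 5]
# 2 [2, 6]
# 3 [3]
```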
@@ -162,11 +147,13 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
     optimizer = MuonWithAuxAdam(param_groups)
     ```
     """
-    def __init__(self, param_groups):
+    def __init__(self, param_groups, distributed=None):
+        self.distributed = distributed if distributed is not None else dist.is_initialized()
         for group in param_groups:
             assert "use_muon" in group
             if group["use_muon"]:
-                group["params"] = sorted(group["params"], key=lambda x: x.size(), reverse=True)
+                if self.distributed:
+                    group["params"] = sorted(group["params"], key=lambda x: x.size(), reverse=True)
                 # defaults
                 group["lr"] = group.get("lr", 0.02)
                 group["momentum"] = group.get("momentum", 0.95)
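As in the docstring example just above, the constructor expects explicit param groups flagged with `use_muon`, and the new `distributed` argument threads through here as well. A sketch of how a caller might build those groups (the model and the Muon/Adam split are illustrative assumptions, not from this PR):

```
import torch
from muon import MuonWithAuxAdam

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10))
matrix_params = [p for p in model.parameters() if p.ndim >= 2]  # handled by Muon
other_params = [p for p in model.parameters() if p.ndim < 2]    # handled by the aux AdamW

param_groups = [
    dict(params=matrix_params, use_muon=True),           # lr/momentum/weight_decay filled by defaults
    dict(params=other_params, lr=3e-4, use_muon=False),  # betas/eps/weight_decay filled by defaults
]
optimizer = MuonWithAuxAdam(param_groups, distributed=False)  # or omit `distributed` to auto-detect
```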
@@ -192,25 +179,36 @@ def step(self, closure=None):
         for group in self.param_groups:
             if group["use_muon"]:
                 params = group["params"]
-                params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
-                for base_i in range(len(params))[::dist.get_world_size()]:
-                    if base_i + dist.get_rank() < len(params):
-                        p = params[base_i + dist.get_rank()]
+                if self.distributed:
+                    world_size = dist.get_world_size()
+                    rank = dist.get_rank()
+                    params_pad = params + [torch.empty_like(params[-1])] * (world_size - len(params) % world_size)
+                    for base_i in range(len(params))[::world_size]:
+                        if base_i + rank < len(params):
+                            p = params[base_i + rank]
+                            if p.grad is None:
+                                p.grad = torch.zeros_like(p) # Force synchronization
+                            state = self.state[p]
+                            if len(state) == 0:
+                                state["momentum_buffer"] = torch.zeros_like(p)
+                            update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
+                            p.mul_(1 - group["lr"] * group["weight_decay"])
+                            p.add_(update.reshape(p.shape), alpha=-group["lr"])
+                        dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank])
+                else:
+                    for p in params:
                         if p.grad is None:
-                            # continue
-                            p.grad = torch.zeros_like(p) # Force synchronization
+                            p.grad = torch.zeros_like(p)
                         state = self.state[p]
                         if len(state) == 0:
                             state["momentum_buffer"] = torch.zeros_like(p)
                         update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                         p.mul_(1 - group["lr"] * group["weight_decay"])
                         p.add_(update.reshape(p.shape), alpha=-group["lr"])
-                    dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
             else:
                 for p in group["params"]:
                     if p.grad is None:
-                        # continue
-                        p.grad = torch.zeros_like(p) # Force synchronization
+                        p.grad = torch.zeros_like(p)
                     state = self.state[p]
                     if len(state) == 0:
                         state["exp_avg"] = torch.zeros_like(p)
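For reference while reviewing this branch: `adam_update` is defined elsewhere in muon.py and is untouched by this diff. Judging only from the call site, a standard bias-corrected Adam step with that signature could look roughly like the sketch below; this is an assumption about its behavior, not the repository's actual implementation.

```
import torch

# Hedged sketch of an adam_update(grad, buf1, buf2, step, betas, eps) helper.
def adam_update(grad, buf1, buf2, step, betas, eps):
    # Running first/second moments, updated in place.
    buf1.lerp_(grad, 1 - betas[0])
    buf2.lerp_(grad.square(), 1 - betas[1])
    # Bias-corrected estimates, then the usual Adam direction.
    buf1c = buf1 / (1 - betas[0] ** step)
    buf2c = buf2 / (1 - betas[1] ** step)
    return buf1c / (buf2c.sqrt() + eps)
```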
@@ -224,63 +222,6 @@
 
         return loss
 
-
-class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
-    """
-    Non-distributed variant of MuonWithAuxAdam.
-    """
-    def __init__(self, param_groups):
-        for group in param_groups:
-            assert "use_muon" in group
-            if group["use_muon"]:
-                # defaults
-                group["lr"] = group.get("lr", 0.02)
-                group["momentum"] = group.get("momentum", 0.95)
-                group["weight_decay"] = group.get("weight_decay", 0)
-                assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
-            else:
-                # defaults
-                group["lr"] = group.get("lr", 3e-4)
-                group["betas"] = group.get("betas", (0.9, 0.95))
-                group["eps"] = group.get("eps", 1e-10)
-                group["weight_decay"] = group.get("weight_decay", 0)
-                assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
-        super().__init__(param_groups, dict())
-
-    @torch.no_grad()
-    def step(self, closure=None):
-
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
-        for group in self.param_groups:
-            if group["use_muon"]:
-                for p in group["params"]:
-                    if p.grad is None:
-                        # continue
-                        p.grad = torch.zeros_like(p) # Force synchronization
-                    state = self.state[p]
-                    if len(state) == 0:
-                        state["momentum_buffer"] = torch.zeros_like(p)
-                    update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
-                    p.mul_(1 - group["lr"] * group["weight_decay"])
-                    p.add_(update.reshape(p.shape), alpha=-group["lr"])
-            else:
-                for p in group["params"]:
-                    if p.grad is None:
-                        # continue
-                        p.grad = torch.zeros_like(p) # Force synchronization
-                    state = self.state[p]
-                    if len(state) == 0:
-                        state["exp_avg"] = torch.zeros_like(p)
-                        state["exp_avg_sq"] = torch.zeros_like(p)
-                        state["step"] = 0
-                    state["step"] += 1
-                    update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
-                                         state["step"], group["betas"], group["eps"])
-                    p.mul_(1 - group["lr"] * group["weight_decay"])
-                    p.add_(update, alpha=-group["lr"])
-
-        return loss
+
+# Alias for backward compatibility
+SingleDeviceMuon = Muon
+SingleDeviceMuonWithAuxAdam = MuonWithAuxAdam
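Since the single-device classes are now plain aliases, existing imports keep working, but they inherit the auto-detection: on a process where `torch.distributed` is initialized, `SingleDeviceMuon` takes the distributed path unless `distributed=False` is passed, and it now also runs the `assert isinstance(params, list) ...` check from Muon's constructor. A small illustrative snippet (not from the diff):

```
import torch
from muon import SingleDeviceMuon  # after this PR, the same object as Muon

layer = torch.nn.Linear(32, 32)
opt = SingleDeviceMuon([layer.weight], lr=0.02)                     # auto-detects via dist.is_initialized()
opt = SingleDeviceMuon([layer.weight], lr=0.02, distributed=False)  # pins the single-device path
```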