In the write-ups for Muon, as well as other third-party sources such as Moonshot AI's "Muon is Scalable for LLM Training" paper, the momentum update rule is described as a decaying sum:

$$B_t = \mu B_{t-1} + G_t$$
But the current `muon_update` function updates the momentum using `Tensor.lerp_`:

```python
def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
    momentum.lerp_(grad, 1 - beta)
    update = grad.lerp_(momentum, beta) if nesterov else momentum
    if update.ndim == 4:  # for the case of conv filters
        update = update.view(len(update), -1)
    update = zeropower_via_newtonschulz5(update, steps=ns_steps)
    update *= max(1, grad.size(-2) / grad.size(-1))**0.5
    return update
```

This calculation instead appears to compute the momentum as a running average rather than a decaying sum, i.e.:

$$B_t = \mu B_{t-1} + (1 - \mu) G_t$$
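For reference, a tensor's in-place `lerp_(end, weight)` computes `self + weight * (end - self)`; a minimal scalar sketch of that identity (the variable names here are illustrative, not from the repo):

```python
# Sketch of what Tensor.lerp_(end, weight) computes on each element:
# start + weight * (end - start), i.e. a convex blend of start and end.
def lerp(start, end, weight):
    return start + weight * (end - start)

beta = 0.95
m, g = 0.4, 1.0  # illustrative momentum-buffer and gradient values

# momentum.lerp_(grad, 1 - beta) is therefore an exponential moving
# average: beta * momentum + (1 - beta) * grad.
assert abs(lerp(m, g, 1 - beta) - (beta * m + (1 - beta) * g)) < 1e-12
```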
Looking at the original implementation from commit b8dda, it appears this changed from the earlier version, which computed the decaying sum correctly:
```python
if 'momentum_buffer' not in state:
    state['momentum_buffer'] = torch.zeros_like(g)
buf = state['momentum_buffer']
buf.mul_(momentum).add_(g)
if group['nesterov']:
    g = g.add(buf, alpha=momentum)
```