
Muon momentum update appears to deviate from the paper; seems to be using a running average #52

@RobotSail

Description

In the write-ups for Muon, as well as third-party sources such as Moonshot AI's *Muon is Scalable for LLM Training* paper, the momentum update rule is described as a decaying sum, $B_t \gets \mu B_{t-1} + G_t$.

But the current `muon_update` function updates the momentum buffer with `momentum.lerp_`:

```python
def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
    momentum.lerp_(grad, 1 - beta)
    update = grad.lerp_(momentum, beta) if nesterov else momentum
    if update.ndim == 4: # for the case of conv filters
        update = update.view(len(update), -1)
    update = zeropower_via_newtonschulz5(update, steps=ns_steps)
    update *= max(1, grad.size(-2) / grad.size(-1))**0.5
    return update
```

This instead computes the momentum as an exponential running average rather than the paper's decaying sum, i.e.:

$$ B_t \gets \mu B_{t-1} + (1 - \mu) G_t $$
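For reference, a quick scalar sanity check (my own sketch, not from the repo) shows how the two recursions relate: starting from a zero buffer, the lerp-style running average is exactly $(1 - \mu)$ times the paper's decaying sum.

```python
# Compare the paper's decaying sum B_t = mu * B_{t-1} + G_t against the
# lerp-style running average M_t = mu * M_{t-1} + (1 - mu) * G_t, both
# starting from a zero buffer. Scalars stand in for gradient tensors.
def decaying_sum(grads, mu=0.95):
    b = 0.0
    for g in grads:
        b = mu * b + g          # paper's update rule
    return b

def running_average(grads, mu=0.95):
    m = 0.0
    for g in grads:
        m = mu * m + (1 - mu) * g   # what momentum.lerp_(grad, 1 - beta) computes
    return m

grads = [1.0, -0.5, 2.0, 0.25]
b, m = decaying_sum(grads), running_average(grads)
# The running average is the decaying sum scaled by a constant (1 - mu).
assert abs(m - (1 - 0.95) * b) < 1e-12
```

So on any fixed gradient sequence the two buffers differ only by the constant factor $(1 - \mu)$.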

Looking at commit b8dda, the original implementation computed the decaying sum as described in the paper:

```python
if 'momentum_buffer' not in state:
    state['momentum_buffer'] = torch.zeros_like(g)
buf = state['momentum_buffer']
buf.mul_(momentum).add_(g)
if group['nesterov']:
    g = g.add(buf, alpha=momentum)
```
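To see how the two versions' full Nesterov updates compare, here is a scalar sketch (again my own check, not from either codebase) tracking both buffers over the same gradient sequence:

```python
# Scalar comparison of the two Nesterov updates on the same gradients.
# old: buf = beta * buf + g;           update = g + beta * buf
# new: mom = lerp(mom, g, 1 - beta);   update = lerp(g, mom, beta)
beta = 0.95
grads = [1.0, -0.5, 2.0, 0.25]

buf = 0.0   # original decaying-sum buffer
mom = 0.0   # current running-average buffer
for g in grads:
    buf = beta * buf + g                      # buf.mul_(momentum).add_(g)
    old_update = g + beta * buf               # g.add(buf, alpha=momentum)
    mom = mom + (1 - beta) * (g - mom)        # momentum.lerp_(grad, 1 - beta)
    new_update = (1 - beta) * g + beta * mom  # grad.lerp_(momentum, beta)
    # The two updates are proportional, differing by the factor (1 - beta).
    assert abs(new_update - (1 - beta) * old_update) < 1e-12
```

Under these assumptions the lerp-based update is the original one scaled by a constant $(1 - \beta)$ at every step.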
