@leloykun
Contributor

@leloykun leloykun commented Feb 24, 2025

Adds code for:

  1. Optimizing Newton-Schulz coefficients
  2. Tighter estimate of spectral norm using Gram iteration taken from https://arxiv.org/pdf/2305.16173

Usage note:

def zeropower_via_newtonschulz5(
    G: Tensor, steps: int, enable_better_spec_norm_est: bool = False
) -> Tensor:
    assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
-     a, b, c = (3.4445, -4.7750,  2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT

    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations
-     for i in range(steps):
+     for i, (a, b, c) in enumerate([
+         ...[insert the coefficients here]...
+     ]):
        A = X @ X.mT
        if i == 0 and enable_better_spec_norm_est:
            # Tighter estimate of spectral norm using 1st Gram iteration.
            # Taken from https://arxiv.org/pdf/2305.16173
            S_norm_est_over_f_norm__squared = A.norm(dim=(-2, -1), keepdim=True)
            X = X / (S_norm_est_over_f_norm__squared**0.5 + 1e-7)
            A = A / (S_norm_est_over_f_norm__squared + 1e-7)
        B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X
    
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X
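
To illustrate why the first Gram iteration gives a tighter bound than the Frobenius norm, here is a minimal sketch that works directly on singular values rather than on matrices. It relies on two standard identities: ||X||_F is the square root of the sum of squared singular values, and ||X X^T||_F is the square root of the sum of their fourth powers. The helper name `norm_estimates` and the sample values are hypothetical and not part of the PR. Dividing by the Frobenius norm first and then by the square root of the normalized Gram norm, as the snippet above does, is equivalent to dividing the original matrix by the Gram estimate computed here.

```python
import math


def norm_estimates(singular_values):
    """Return (spectral norm, Gram-iteration estimate, Frobenius norm).

    Hypothetical helper for illustration: it takes singular values directly
    instead of a matrix, so no linear algebra library is needed.
    """
    spectral = max(singular_values)
    # ||X||_F = sqrt(sum of s^2)
    frobenius = math.sqrt(sum(s ** 2 for s in singular_values))
    # ||X X^T||_F = sqrt(sum of s^4), so sqrt(||X X^T||_F) = (sum of s^4)^(1/4)
    gram = sum(s ** 4 for s in singular_values) ** 0.25
    return spectral, gram, frobenius


spectral, gram, frobenius = norm_estimates([3.0, 2.0, 1.0])
# The Gram estimate sits between the true spectral norm and the Frobenius norm.
print(spectral <= gram <= frobenius)
```

Because the Gram estimate is still an upper bound on the spectral norm, dividing by it keeps every singular value at or below 1 while shrinking the input less than Frobenius normalization does.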

@KellerJordan
Owner

Have we confirmed that this option never causes any instability?

It's potentially risky, so it's important to confirm. I will wait to accept it until I have evidence.

@toothacher17

Hey @leloykun, I tried your JAX scripts and got a new set of coefficients:

(4.0246, -6.4224, 2.6026)
(3.9872, -6.2793, 2.5377)
(3.3260, -4.8258, 1.9451)
(2.8778, -3.6189, 1.6208)
(3.0133, -3.6424, 1.6122)

Do you have any recommendations for which one to use, or should I just pick one at random?

@leloykun
Contributor Author

Hi @toothacher17,

In zeropower_via_newtonschulz5, you should replace

for i in range(steps):

with

for i, (a, b, c) in enumerate([
    ...[insert the coefficients here]...
]):
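
As a rough sketch of what the per-step schedule does: the update X = a * X + (b * A + c * A @ A) @ X with A = X @ X.mT maps each singular value s of X to a*s + b*s^3 + c*s^5, so a schedule can be explored on scalars alone. The function name `apply_ns_schedule` is hypothetical. `PER_STEP` uses the five tuples posted earlier in this thread, on the assumption that they are meant to be applied in order, one per iteration.

```python
def apply_ns_schedule(s, schedule):
    """Apply one quintic Newton-Schulz step per (a, b, c) tuple to a scalar.

    Hypothetical helper: `s` stands in for a single singular value of X.
    """
    for a, b, c in schedule:
        s = a * s + b * s ** 3 + c * s ** 5
    return s


# Original behaviour: the same coefficients repeated for every step.
FIXED = [(3.4445, -4.7750, 2.0315)] * 5

# Assumption: the tuples from the thread, used in order, one per step.
PER_STEP = [
    (4.0246, -6.4224, 2.6026),
    (3.9872, -6.2793, 2.5377),
    (3.3260, -4.8258, 1.9451),
    (2.8778, -3.6189, 1.6208),
    (3.0133, -3.6424, 1.6122),
]

for s0 in (0.01, 0.1, 0.5, 1.0):
    print(s0, apply_ns_schedule(s0, FIXED), apply_ns_schedule(s0, PER_STEP))
```

With this form the number of iterations is the length of the list, so the `steps` argument no longer controls the loop. A scalar sweep like this is only a quick sanity check of where singular values end up; it says nothing about bfloat16 round-off or training stability.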

@KellerJordan
Owner

I'm still a little afraid of this causing instability. Will test more and think about it.

@leloykun
Contributor Author

Same... I'll move this back to drafts until further analysis.

@leloykun leloykun marked this pull request as draft March 25, 2025 08:40