Bring in improvements from modded-nanogpt repo
#14
Adds code for:
Usage note:
     def zeropower_via_newtonschulz5(
         G: Tensor, steps: int, enable_better_spec_norm_est: bool = False
     ) -> Tensor:
         assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
    -    a, b, c = (3.4445, -4.7750, 2.0315)
         X = G.bfloat16()
         if G.size(-2) > G.size(-1):
             X = X.mT

         # Ensure spectral norm is at most 1
         X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
         # Perform the NS iterations
    -    for i in range(steps):
    +    for i, (a, b, c) in enumerate([
    +        ...[insert the coefficients here]...
    +    ]):
             A = X @ X.mT
             if i == 0 and enable_better_spec_norm_est:
                 # Tighter estimate of spectral norm using 1st Gram iteration.
                 # Taken from https://arxiv.org/pdf/2305.16173
                 S_norm_est_over_f_norm__squared = A.norm(dim=(-2, -1), keepdim=True)
                 X = X / (S_norm_est_over_f_norm__squared**0.5 + 1e-7)
                 A = A / (S_norm_est_over_f_norm__squared + 1e-7)
             B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
             X = a * X + B @ X

         if G.size(-2) > G.size(-1):
             X = X.mT
         return X
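To see what the change does without a GPU, here is a rough scalar sketch in plain Python. It relies on the fact that the update `X = a * X + (b * A + c * A @ A) @ X` with `A = X @ X.mT` maps each singular value `s` of `X` to `a*s + b*s**3 + c*s**5`, and that `A.norm() ** 0.5` equals `(sum of s**4) ** 0.25`, which sits between the spectral norm and the Frobenius norm. The helper names (`quintic_step`, `normalize`, `run_schedule`) are made up for illustration, and the schedule simply repeats the single triple from the removed line, as a stand-in for the per-step coefficients that the diff leaves unspecified.

```python
import math

# The single triple from the removed line of the diff. Repeating it for every
# step is an assumption: the real per-step coefficients are not given here.
BASELINE_COEFFS = (3.4445, -4.7750, 2.0315)


def quintic_step(s, a, b, c):
    """Effect of one X = a*X + (b*A + c*A@A) @ X update on a singular value s."""
    return a * s + b * s**3 + c * s**5


def normalize(singular_values, better_spec_norm_est=False):
    """Scale singular values so the largest one is at most 1."""
    # Frobenius norm of X is sqrt(sum s^2); it upper-bounds the spectral norm.
    fro = math.sqrt(sum(s * s for s in singular_values))
    scaled = [s / (fro + 1e-7) for s in singular_values]
    if better_spec_norm_est:
        # Frobenius norm of A = X @ X.mT is sqrt(sum s^4). Its square root is
        # a tighter upper bound on the spectral norm than the Frobenius norm.
        gram_fro = math.sqrt(sum(s**4 for s in scaled))
        scaled = [s / (math.sqrt(gram_fro) + 1e-7) for s in scaled]
    return scaled


def run_schedule(singular_values, schedule, better_spec_norm_est=False):
    """Apply one quintic step per (a, b, c) triple in the schedule."""
    values = normalize(singular_values, better_spec_norm_est)
    for a, b, c in schedule:
        values = [quintic_step(s, a, b, c) for s in values]
    return values


if __name__ == "__main__":
    schedule = [BASELINE_COEFFS] * 5  # placeholder schedule, see note above
    print(run_schedule([3.0, 1.0, 0.1], schedule))
    print(run_schedule([3.0, 1.0, 0.1], schedule, better_spec_norm_est=True))
```

With the tighter estimate, the largest singular value starts closer to 1, so small singular values start larger as well and need fewer steps to grow. The number of iterations is now the length of the coefficient list rather than `steps`.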