7 changes: 7 additions & 0 deletions docs/apidocs/orthogonalized-optimizers.md
@@ -47,6 +47,13 @@ emerging_optimizers.orthogonalized_optimizers
:members:


:hidden:`AdaptiveMuon`
~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: AdaptiveMuon
:members:


:hidden:`Newton-Schulz`
~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: emerging_optimizers.orthogonalized_optimizers.muon_utils
75 changes: 62 additions & 13 deletions emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py
@@ -50,7 +50,10 @@ class AdaptiveMuon(muon.Muon):
scale_mode: The type of scale factor to use for the update.
extra_scale_factor: The additional scale factor to use for the update.
use_syrk: Whether to use the Triton kernel for the Newton-Schulz iteration.
moment2_method: Method for second moment accumulation ("adamuon" or "normuon").
moment2_method: Method for second moment accumulation ("adamuon", "normuon", or "namo").
- "adamuon": Full elementwise second moment (like AdamW).
- "normuon": Row or column-wise second moment.
- "namo": Scalar adaptive scaling via Frobenius-norm ratio.
beta2: The exponential decay rate for second moment.
eps: Small constant for numerical stability.
"""
@@ -70,7 +73,7 @@ def __init__(
scale_mode: muon.MuonScaleT = "spectral",
extra_scale_factor: float = 1.0,
use_syrk: bool = False,
moment2_method: Literal["adamuon", "normuon"] = "adamuon",
moment2_method: Literal["adamuon", "normuon", "namo"] = "adamuon",
beta2: float = 0.95,
eps: float = 1e-8,
):
@@ -104,6 +107,7 @@ def _initialize_moment2(
The shape of the buffer depends on the moment2_method:
- "adamuon": Full elementwise buffer with same shape as grad
- "normuon": Reduced shape buffer (averaged along -1 if shape[-2] >= shape[-1], else -2)
- "namo": Scalar buffer (EMA of squared Frobenius norm of gradient)

Args:
state: The optimizer state dict for a parameter.
@@ -121,6 +125,9 @@
moment2_shape = list(grad.shape)
moment2_shape[avg_dim] = 1
moment2 = torch.zeros(moment2_shape, dtype=grad.dtype, device=grad.device)
elif self.moment2_method == "namo":
# Scalar second moment: EMA of ||G_t||_F^2
moment2 = torch.zeros((), dtype=grad.dtype, device=grad.device)
else:
raise TypeError(f"Invalid second moment method: {self.moment2_method}")

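As a standalone illustration of the three buffer shapes for an (8, 16) gradient (variable names here are illustrative, not part of the class):

import torch

grad = torch.randn(8, 16)           # G_t with shape[-2] = 8 < shape[-1] = 16
v_adamuon = torch.zeros_like(grad)  # "adamuon": elementwise, shape (8, 16)
v_normuon = torch.zeros(1, 16)      # "normuon": 8 < 16, so dim -2 is reduced
v_namo = torch.zeros(())            # "namo": scalar EMA of ||G_t||_F^2
assert v_namo.shape == torch.Size([])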
@@ -132,22 +139,37 @@ def _apply_moment2_normalization(
moment2: torch.Tensor,
beta2: float,
eps: float,
*,
grad_fro_sq: torch.Tensor | None = None,
pre_orth_norm: torch.Tensor | None = None,
) -> torch.Tensor:
"""Apply AdamW-style second moment accumulation and normalization.
"""Apply second moment accumulation and normalization.

This method supports two variants:
This method supports three variants:
- "adamuon": Full elementwise second moment (like AdamW, https://arxiv.org/abs/2507.11005)
- "normuon": Row or column-wise second moment (https://arxiv.org/abs/2510.05491)
- "namo": Scalar adaptive scaling using Frobenius-norm ratio (https://arxiv.org/abs/2602.17080).
Scales the orthogonalized momentum by

.. math::

\\alpha_t = \\frac{\\|g_t^{\\text{pre-orth}}\\|_F}{\\sqrt{v_t} + \\varepsilon}

where :math:`v_t` is the EMA of :math:`\\|G_t\\|_F^2`.

For both methods:
For all methods:
1. Updates the second moment as an EMA of squared gradients
1. Updates the second moment as an EMA of (some function of) squared gradients
2. Returns the adaptively scaled gradient

Args:
orth_grad: The orthogonalized gradient tensor.
moment2: The second moment buffer from state.
beta2: The exponential decay rate for second moment.
eps: Small constant for numerical stability.
grad_fro_sq: (NAMO only) Squared Frobenius norm of the raw gradient
(before momentum), ``||G_t||_F^2``.
pre_orth_norm: (NAMO only) Frobenius norm of the gradient that enters
orthogonalization (after momentum + Nesterov), ``||g_t^{pre-orth}||_F``.

Returns:
The adaptively scaled weight update tensor.
@@ -175,6 +197,17 @@ def _apply_moment2_normalization(
step_size = moment2.clamp_min(eps).rsqrt_()
return orth_grad * step_size

elif self.moment2_method == "namo":
assert grad_fro_sq is not None and pre_orth_norm is not None, "NAMO requires grad_fro_sq and pre_orth_norm"
# NAMO: Scalar adaptive scaling via Frobenius-norm ratio
# v_t = β2 * v_{t-1} + (1 - β2) * ||G_t||_F^2
moment2.lerp_(grad_fro_sq, 1 - beta2)

# α_t = ||pre_orth_grad||_F / (sqrt(v_t) + ε)
alpha_t = pre_orth_norm / (moment2.sqrt() + eps)

return orth_grad * alpha_t

else:
raise TypeError(f"Invalid second moment method: {self.moment2_method}")

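A self-contained numeric sketch of the NAMO branch, using a QR factor as a stand-in for the Newton-Schulz orthogonalization (all names illustrative):

import torch

beta2, eps = 0.95, 1e-8
v = torch.zeros(())                      # scalar second-moment buffer
raw_grad = torch.randn(8, 16)            # G_t, captured before the momentum update
pre_orth = torch.randn(8, 16)            # gradient entering orthogonalization
orth = torch.linalg.qr(pre_orth.T)[0].T  # stand-in for the Newton-Schulz output

grad_fro_sq = torch.linalg.vector_norm(raw_grad).square()
v.lerp_(grad_fro_sq, 1 - beta2)          # v_t = beta2 * v_{t-1} + (1 - beta2) * ||G_t||_F^2
alpha_t = pre_orth.norm() / (v.sqrt() + eps)
update = orth * alpha_t                  # alpha_t-scaled orthogonalized update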
@@ -219,6 +252,11 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
group["weight_decay"],
)

# NAMO: capture ||G_t||_F^2 before momentum update
grad_fro_sq: torch.Tensor | None = None
if self.moment2_method == "namo":
grad_fro_sq = torch.linalg.vector_norm(grad).square()

# update momentum buffer with EMA of gradient
exp_avg.lerp_(grad, 1 - group["momentum_beta"])

@@ -231,15 +269,26 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
group_kwargs = {k: v for k, v in group.items() if k != "params"}
orth_grad = self.orthogonalize(p, grad, **group_kwargs)

update = self._apply_moment2_normalization(
orth_grad=orth_grad,
moment2=state["moment2_buffer"],
beta2=group["beta2"],
eps=group["eps"],
)
if self.moment2_method == "namo":
# Capture ||pre_orth_grad||_F (after momentum + Nesterov)
update = self._apply_moment2_normalization(
orth_grad=orth_grad,
moment2=state["moment2_buffer"],
beta2=group["beta2"],
eps=group["eps"],
grad_fro_sq=grad_fro_sq,
pre_orth_norm=torch.linalg.vector_norm(grad),
)
else:
update = self._apply_moment2_normalization(
orth_grad=orth_grad,
moment2=state["moment2_buffer"],
beta2=group["beta2"],
eps=group["eps"],
)

# perform weight update
# scale is applied to have update RMS == 1
# scale is applied to have update RMS roughly 1
p.add_(update, alpha=-group["lr"])

return loss
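An end-to-end sketch of one step with the new mode; the lr argument and import path are assumptions (inherited from muon.Muon), not confirmed by this diff:

import torch
from emerging_optimizers.orthogonalized_optimizers.adaptive_muon import AdaptiveMuon

w = torch.nn.Parameter(torch.randn(16, 8))
opt = AdaptiveMuon([w], lr=1e-2, moment2_method="namo")
loss = (w @ torch.randn(8, 4)).square().mean()
loss.backward()
opt.step()        # captures ||G_t||_F^2, orthogonalizes, then applies the NAMO scale
opt.zero_grad()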
6 changes: 5 additions & 1 deletion tests/test_adaptive_muon.py
@@ -30,7 +30,7 @@
class AdaptiveMuonTest(parameterized.TestCase):
@parameterized.product(
shape=[(5, 7), (33, 65), (127, 257)],
second_moment_method=["adamuon", "normuon"],
second_moment_method=["adamuon", "normuon", "namo"],
use_nesterov=[True, False],
)
def test_smoke(self, shape, second_moment_method, use_nesterov) -> None:
@@ -55,6 +55,7 @@ def test_smoke(self, shape, second_moment_method, use_nesterov) -> None:
@parameterized.parameters(
{"shape": (8, 16), "second_moment_method": "adamuon"},
{"shape": (16, 8), "second_moment_method": "normuon"},
{"shape": (8, 16), "second_moment_method": "namo"},
)
def test_second_moment_matches_shapes(self, shape, second_moment_method) -> None:
"""Test that second moment buffers are properly initialized."""
@@ -93,6 +94,9 @@ def test_second_moment_matches_shapes(self, shape, second_moment_method) -> None
expected_shape = list(shape)
expected_shape[avg_dim] = 1
self.assertEqual(list(second_moment.shape), expected_shape)
elif second_moment_method == "namo":
# Scalar buffer
self.assertEqual(second_moment.shape, torch.Size([]))

def test_unknown_moment2_method_raise_type_error(self) -> None:
"""Test that AdaptiveMuon raises TypeError for unknown moment2_method."""