Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions emerging_optimizers/orthogonalized_optimizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import *
from emerging_optimizers.orthogonalized_optimizers.scion import *
from emerging_optimizers.orthogonalized_optimizers.spectral_clipping_utils import *
from emerging_optimizers.orthogonalized_optimizers.spel import *
146 changes: 146 additions & 0 deletions emerging_optimizers/orthogonalized_optimizers/spel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import override

import torch
from absl import logging
from torch.optim.optimizer import ParamsT

from emerging_optimizers import registry, triton_kernels, utils
from emerging_optimizers.mixin import WeightDecayT
from emerging_optimizers.orthogonalized_optimizers import muon_utils
from emerging_optimizers.orthogonalized_optimizers.muon_utils import NSCoeffT
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer, _args_doc
from emerging_optimizers.utils import FP32MatmulPrecT


__all__ = ["Spel"]


@registry.register_optimizer("spel")
class Spel(OrthogonalizedOptimizer):
    r"""SPEL: SPectral steepest descent on the stiefEL manifold

    SPEL is the spectral-norm specialization of Manifold Constrained Steepest Descent (MCSD)
    on the Stiefel manifold. It selects a norm-induced steepest-descent direction via the matrix
    sign function applied to the momentum, then projects back onto the manifold via Newton-Schulz iteration:

    .. math::

        x_{{t+1}} = \text{{msign}}\!\left(x_t - \alpha_t \, \text{{msign}}\!\left(\nabla_M f(x_t)\right)\right)

    The inner :math:`\text{{msign}}` orthogonalizes the gradient via Newton-Schulz iteration
    (identical to Muon, without scale factors). The outer :math:`\text{{msign}}` re-projects the
    updated weights onto the Stiefel manifold, keeping parameters on (or near) the orthogonal
    manifold. Both operations admit scalable implementations via fast matrix sign computations.

    Note:
        Weight decay still has an effect despite the re-orthogonalization. Before projection,
        :math:`(1 - \eta_t \lambda) \, x_t - \eta_t \, \text{{msign}}(\cdot)` acts as a convex-like
        linear combination that rebalances the proportion of the previous weight and the new update.
        The outer :math:`\text{{msign}}` then re-orthogonalizes this mixture, so weight decay controls
        the relative influence of the old parameters versus the update direction.

    References:
        - *Manifold Constrained Steepest Descent.* arXiv:2601.21487 (2026).
          [`arXiv:2601.21487 <https://arxiv.org/abs/2601.21487>`_]
        - Jordan, K. *Muon Optimizer Implementation.*
          [`GitHub <https://github.com/KellerJordan/Muon/blob/master/muon.py>`_]
        - *Modular Duality in Deep Learning.* arXiv:2410.21265 (2024).
          [`arXiv:2410.21265 <https://arxiv.org/abs/2410.21265>`_]

    Warning:
        - This optimizer requires that all parameters passed in are 2D.
        - It should not be used for the embedding layer, the final fully connected layer, or any 1-D
          parameters; those can all be optimized by a standard method (e.g., AdamW).

    Args:
        {_args_doc}
        coefficient_type: The type of coefficient set to use for the Newton-Schulz iteration. Can be one of
            ["simple", "quintic", "polar_express"].
        num_ns_steps: The number of Newton-Schulz iteration steps for both gradient orthogonalization
            and the post-update weight projection.
        use_syrk: Whether to use the Triton kernel for the Newton-Schulz iteration.
    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 3e-4,
        momentum_beta: float = 0.95,
        weight_decay: float = 0.1,
        *,
        use_nesterov: bool = False,
        weight_decay_method: WeightDecayT = "decoupled",
        fp32_matmul_prec: FP32MatmulPrecT = "medium",
        coefficient_type: NSCoeffT = "quintic",
        num_ns_steps: int = 5,
        use_syrk: bool = False,
    ) -> None:
        if num_ns_steps < 1:
            raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}")

        if use_syrk:
            # Validate the Triton fast path; fall back to the non-Triton kernel (log an
            # error rather than raise) so training can still proceed when unsupported.
            sm_version = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
            if not triton_kernels.HAS_TRITON_340:  # type: ignore[attr-defined]
                logging.error("Triton 3.4.0 or higher is required for use_syrk to be True.")
                use_syrk = False
            elif sm_version not in ((8, 0), (9, 0), (10, 0), (10, 3)):
                logging.error(
                    f"Correctness of Triton kernel on SM {sm_version} cannot be guaranteed. Setting use_syrk to False."
                )
                use_syrk = False

        def scaled_orthogonalize_fn(X: torch.Tensor) -> torch.Tensor:
            # Inner msign: Newton-Schulz orthogonalization of the momentum/gradient.
            logging.debug(f"Orthogonalizing with {num_ns_steps} steps, {coefficient_type} coefficient")
            return muon_utils.newton_schulz(
                X,
                steps=num_ns_steps,
                coefficient_type=coefficient_type,
                use_syrk=use_syrk,
            )

        super().__init__(
            params,
            lr,
            momentum_beta,
            weight_decay,
            use_nesterov=use_nesterov,
            weight_decay_method=weight_decay_method,
            fp32_matmul_prec=fp32_matmul_prec,
            scaled_orthogonalize_fn=scaled_orthogonalize_fn,
        )

    @override
    def post_weight_update_fn_inplace(self, p: torch.Tensor, update: torch.Tensor) -> None:
        """Re-orthogonalize the weight matrix after the update via Newton-Schulz.

        This projects the updated weight matrix back onto (or near) the orthogonal manifold,
        implementing the outer msign in: x_{t+1} = msign(x_t - lr * msign(momentum)).

        Args:
            p: The updated parameter tensor.
            update: The orthogonalized gradient tensor that was applied.
        """
        # NOTE: the base class wraps only the orthogonalize call of step() in the matmul
        # precision context; this hook runs outside it, so set the precision explicitly.
        with utils.fp32_matmul_precision(self.fp32_matmul_prec):
            p.copy_(self.scaled_orthogonalize_fn(p))
Comment on lines +141 to +143
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fp32_matmul_precision context manager is necessary here (not redundant)

The base class step() method in orthogonalized_optimizer.py:171 already wraps the orthogonalize call with utils.fp32_matmul_precision(self.fp32_matmul_prec), but post_weight_update_fn_inplace is called outside that context manager (line 178). So the explicit with utils.fp32_matmul_precision(self.fp32_matmul_prec): here is actually necessary and correct for the outer msign.

However, note that newton_schulz itself (in muon_utils.py:187) checks torch.get_float32_matmul_precision() to decide whether to use BF16 I/O kernels. So this context manager is indeed needed. Just wanted to confirm this is intentional — looks correct.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes this was intentional



Spel.__doc__ = Spel.__doc__.format(_args_doc=_args_doc) # type: ignore[union-attr]
1 change: 1 addition & 0 deletions tests/ci/L0_Tests_GPU.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ error=0
coverage run -p --source=emerging_optimizers tests/test_muon_utils.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_adaptive_muon.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_orthogonalized_optimizer.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_spel.py --device=cuda -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_soap_utils.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_soap.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/soap_mnist_test.py -v -2 || error=1
Expand Down
92 changes: 92 additions & 0 deletions tests/test_spel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
from absl import flags
from absl.testing import absltest, parameterized

from emerging_optimizers.orthogonalized_optimizers import spel


# Command-line flag selecting the torch device used by every test in this module.
flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'")

# Parsed flag values; populated by absltest.main() before tests run.
FLAGS = flags.FLAGS


class SpelTest(parameterized.TestCase):
    """Tests for the Spel optimizer."""

    @parameterized.product(
        shape=[(5, 7), (33, 65), (127, 257)],
        weight_decay_method=["decoupled", "independent", "l2"],
        use_nesterov=[True, False],
    )
    def test_smoke(self, shape, weight_decay_method, use_nesterov) -> None:
        """Smoke test Spel optimizer with various shapes, weight decay methods, and Nesterov."""
        # Integer-valued floats keep the fixture simple and deterministic in magnitude.
        param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=FLAGS.device))
        param.grad = torch.randint_like(param, -5, 5)

        optimizer = spel.Spel(
            [param],
            weight_decay_method=weight_decay_method,
            use_nesterov=use_nesterov,
        )
        optimizer.step()

    @parameterized.product(
        shape=[(5, 7), (33, 65), (127, 257)],
    )
    def test_post_update_produces_approximately_orthogonal_weights(self, shape) -> None:
        """Test that post_weight_update_fn_inplace produces approximately orthogonal matrices.

        After each optimizer step, the weight matrix W should satisfy W @ W^T ≈ I (up to scale)
        for the smaller dimension, which is the defining property of SPEL.
        """
        param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device))

        optimizer = spel.Spel(
            [param],
            lr=0.01,
            momentum_beta=0.0,
            weight_decay=0.0,
        )

        for _ in range(5):
            param.grad = torch.randn_like(param)
            optimizer.step()

        # The Gram matrix taken along the smaller dimension should approximate the
        # identity: W @ W^T for wide matrices, W^T @ W for tall ones. Newton-Schulz
        # normalizes the spectral norm to ~1, so the product should be close to I.
        W = param.data
        rows, cols = W.shape
        gram = W @ W.mT if rows <= cols else W.mT @ W
        identity = torch.eye(min(rows, cols), device=FLAGS.device)

        torch.testing.assert_close(
            gram,
            identity,
            atol=0.1,
            rtol=0.1,
            msg=f"Weight matrix of shape {shape} is not approximately orthogonal after step",
        )


# Entry point: absltest parses absl flags (e.g. --device) before running the suite.
if __name__ == "__main__":
    absltest.main()