From 2d25b60cae465aa66124485d77addc7b5943bf46 Mon Sep 17 00:00:00 2001 From: mikail Date: Wed, 18 Feb 2026 14:04:12 -0800 Subject: [PATCH 1/4] Added initial implementation of SPEL Signed-off-by: mikail --- .../orthogonalized_optimizers/__init__.py | 1 + .../orthogonalized_optimizers/spel.py | 141 ++++++++++++++++++ 2 files changed, 142 insertions(+) create mode 100644 emerging_optimizers/orthogonalized_optimizers/spel.py diff --git a/emerging_optimizers/orthogonalized_optimizers/__init__.py b/emerging_optimizers/orthogonalized_optimizers/__init__.py index 7e8ddc4..13f83f3 100644 --- a/emerging_optimizers/orthogonalized_optimizers/__init__.py +++ b/emerging_optimizers/orthogonalized_optimizers/__init__.py @@ -18,4 +18,5 @@ from emerging_optimizers.orthogonalized_optimizers.muon_hyperball import * from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import * from emerging_optimizers.orthogonalized_optimizers.scion import * +from emerging_optimizers.orthogonalized_optimizers.spel import * from emerging_optimizers.orthogonalized_optimizers.spectral_clipping_utils import * diff --git a/emerging_optimizers/orthogonalized_optimizers/spel.py b/emerging_optimizers/orthogonalized_optimizers/spel.py new file mode 100644 index 0000000..707c6a6 --- /dev/null +++ b/emerging_optimizers/orthogonalized_optimizers/spel.py @@ -0,0 +1,141 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import override

import torch
from absl import logging
from torch.optim.optimizer import ParamsT

from emerging_optimizers import registry, triton_kernels, utils
from emerging_optimizers.mixin import WeightDecayT
from emerging_optimizers.orthogonalized_optimizers import muon_utils
from emerging_optimizers.orthogonalized_optimizers.muon_utils import NSCoeffT
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer, _args_doc
from emerging_optimizers.utils import FP32MatmulPrecT


__all__ = ["Spel"]


@registry.register_optimizer("spel")
class Spel(OrthogonalizedOptimizer):
    r"""SPEL: SPectral steepest descent on the stiefEL manifold.

    SPEL is the spectral-norm specialization of Manifold Constrained Steepest Descent (MCSD)
    on the Stiefel manifold. It selects a norm-induced steepest-descent direction via the matrix
    sign function applied to the momentum, then projects back onto the manifold via Newton-Schulz
    iteration:

    .. math::

        x_{{t+1}} = \text{{msign}}\!\left(x_t - \alpha_t \, \text{{msign}}\!\left(\nabla_M f(x_t)\right)\right)

    The inner :math:`\text{{msign}}` orthogonalizes the gradient via Newton-Schulz iteration
    (identical to Muon, without scale factors). The outer :math:`\text{{msign}}` re-projects the
    updated weights onto the Stiefel manifold, keeping parameters on (or near) the orthogonal
    manifold. Both operations admit scalable implementations via fast matrix sign computations.

    Note:
        Weight decay still has an effect despite the re-orthogonalization. Before projection,
        :math:`(1 - \eta_t \lambda) \, x_t - \eta_t \, \text{{msign}}(\cdot)` acts as a convex-like
        linear combination that rebalances the proportion of the previous weight and the new update.
        The outer :math:`\text{{msign}}` then re-orthogonalizes this mixture, so weight decay controls
        the relative influence of the old parameters versus the update direction.

    References:
        - *Manifold Constrained Steepest Descent.* arXiv:2601.21487 (2026).
          [`arXiv:2601.21487 <https://arxiv.org/abs/2601.21487>`_]
        - Jordan, K. *Muon Optimizer Implementation.*
          [`GitHub <https://github.com/KellerJordan/Muon>`_]
        - *Modular Duality in Deep Learning.* arXiv:2410.21265 (2024).
          [`arXiv:2410.21265 <https://arxiv.org/abs/2410.21265>`_]

    Warning:
        - This optimizer requires that all parameters passed in are 2D.
        - It should not be used for the embedding layer, the final fully connected layer, or any 1-D
          parameters; those can all be optimized by a standard method (e.g., AdamW).

    Args:
        {_args_doc}
        coefficient_type: The type of coefficient set to use for the Newton-Schulz iteration. Can be one of
            ["simple", "quintic", "polar_express"].
        num_ns_steps: The number of Newton-Schulz iteration steps for both gradient orthogonalization
            and the post-update weight projection.
        use_syrk: Whether to use the Triton kernel for the Newton-Schulz iteration.
    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 3e-4,
        momentum_beta: float = 0.95,
        weight_decay: float = 0.1,
        *,
        use_nesterov: bool = False,
        weight_decay_method: WeightDecayT = "decoupled",
        fp32_matmul_prec: FP32MatmulPrecT = "medium",
        coefficient_type: NSCoeffT = "quintic",
        num_ns_steps: int = 5,
        use_syrk: bool = False,
    ) -> None:
        if num_ns_steps < 1:
            raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}")

        if use_syrk:
            # The syrk Triton kernel needs Triton >= 3.4.0 and has only been validated on a
            # fixed set of SM architectures; otherwise fall back to the plain matmul path.
            if torch.cuda.is_available():
                sm_version = torch.cuda.get_device_capability()
            else:
                # No CUDA device visible; (0, 0) can never match the validated SM list below.
                sm_version = (0, 0)
            if not triton_kernels.HAS_TRITON_340:  # type: ignore[attr-defined]
                logging.error("Triton 3.4.0 or higher is required for use_syrk to be True.")
                use_syrk = False
            elif sm_version not in ((8, 0), (9, 0), (10, 0), (10, 3)):
                logging.error(
                    f"Correctness of Triton kernel on SM {sm_version} cannot be guaranteed. Setting use_syrk to False."
                )
                use_syrk = False

        def scaled_orthogonalize_fn(X: torch.Tensor) -> torch.Tensor:
            # Orthogonalize via Newton-Schulz; lazy %-style args avoid paying f-string
            # formatting cost on every step when debug logging is disabled.
            logging.debug("Orthogonalizing with %d steps, %s coefficient", num_ns_steps, coefficient_type)
            return muon_utils.newton_schulz(
                X,
                steps=num_ns_steps,
                coefficient_type=coefficient_type,
                use_syrk=use_syrk,
            )

        super().__init__(
            params,
            lr,
            momentum_beta,
            weight_decay,
            use_nesterov=use_nesterov,
            weight_decay_method=weight_decay_method,
            fp32_matmul_prec=fp32_matmul_prec,
            scaled_orthogonalize_fn=scaled_orthogonalize_fn,
        )

    @override
    def post_weight_update_fn_inplace(self, p: torch.Tensor, update: torch.Tensor) -> None:
        """Re-orthogonalize the weight matrix after the update via Newton-Schulz.

        This projects the updated weight matrix back onto (or near) the orthogonal manifold,
        implementing the outer msign in: x_{t+1} = msign(x_t - lr * msign(momentum)).

        Args:
            p: The updated parameter tensor.
            update: The orthogonalized gradient tensor that was applied (unused here; the
                projection depends only on the post-update weights).
        """
        with utils.fp32_matmul_precision(self.fp32_matmul_prec):
            orth_p = self.scaled_orthogonalize_fn(p)
            p.copy_(orth_p)


# NOTE: single braces ({_args_doc}) are required for substitution; doubled braces would be
# treated as escapes by str.format and leave the placeholder unexpanded in the docstring.
Spel.__doc__ = Spel.__doc__.format(_args_doc=_args_doc)  # type: ignore[union-attr]
Both operations admit scalable implementations via fast matrix sign computations. + Note: + Weight decay still has an effect despite the re-orthogonalization. Before projection, + :math:`(1 - \eta_t*\lambda) \, x_t - \eta_t \, \text{{msign}}(\cdot)` acts as a convex-like + linear combination that rebalances the proportion of the previous weight and the new update. + The outer :math:`\text{{msign}}` then re-orthogonalizes this mixture, so weight decay controls + the relative influence of the old parameters versus the update direction. + References: - *Manifold Constrained Steepest Descent.* arXiv:2601.21487 (2026). [`arXiv:2601.21487 `_] From f58488a220a7e20f9192ceded47d969bd93ebf12 Mon Sep 17 00:00:00 2001 From: mikail Date: Wed, 18 Feb 2026 14:23:28 -0800 Subject: [PATCH 3/4] added unit test to ci Signed-off-by: mikail --- tests/ci/L0_Tests_GPU.sh | 1 + tests/test_spel.py | 91 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+) create mode 100644 tests/test_spel.py diff --git a/tests/ci/L0_Tests_GPU.sh b/tests/ci/L0_Tests_GPU.sh index 25866ce..0a644dc 100644 --- a/tests/ci/L0_Tests_GPU.sh +++ b/tests/ci/L0_Tests_GPU.sh @@ -19,6 +19,7 @@ error=0 coverage run -p --source=emerging_optimizers tests/test_muon_utils.py -v -2 || error=1 coverage run -p --source=emerging_optimizers tests/test_adaptive_muon.py -v -2 || error=1 coverage run -p --source=emerging_optimizers tests/test_orthogonalized_optimizer.py -v -2 || error=1 +coverage run -p --source=emerging_optimizers tests/test_spel.py --device=cuda -v -2 || error=1 coverage run -p --source=emerging_optimizers tests/test_soap_utils.py -v -2 || error=1 coverage run -p --source=emerging_optimizers tests/test_soap.py -v -2 || error=1 coverage run -p --source=emerging_optimizers tests/soap_mnist_test.py -v -2 || error=1 diff --git a/tests/test_spel.py b/tests/test_spel.py new file mode 100644 index 0000000..f6f8dcb --- /dev/null +++ b/tests/test_spel.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: 
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
from absl import flags
from absl.testing import absltest, parameterized

from emerging_optimizers.orthogonalized_optimizers import spel


flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'")

FLAGS = flags.FLAGS


class SpelTest(parameterized.TestCase):
    @parameterized.product(
        shape=[(5, 7), (33, 65), (127, 257)],
        weight_decay_method=["decoupled", "independent", "l2"],
        use_nesterov=[True, False],
    )
    def test_smoke(self, shape, weight_decay_method, use_nesterov) -> None:
        """Smoke test: a single Spel step runs across shapes, decay methods, and Nesterov settings."""
        param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=FLAGS.device))
        param.grad = torch.randint_like(param, -5, 5)

        optimizer = spel.Spel(
            [param],
            weight_decay_method=weight_decay_method,
            use_nesterov=use_nesterov,
        )
        optimizer.step()

    @parameterized.product(
        shape=[(5, 7), (33, 65), (127, 257)],
    )
    def test_post_update_produces_approximately_orthogonal_weights(self, shape) -> None:
        """Verify the post-update projection leaves the weights approximately orthogonal.

        After each step, the weight matrix W should satisfy W @ W^T ≈ I along its smaller
        dimension (W^T @ W ≈ I for tall matrices) — the defining property of SPEL's outer
        msign projection.
        """
        param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device))

        optimizer = spel.Spel(
            [param],
            lr=0.01,
            momentum_beta=0.0,
            weight_decay=0.0,
        )

        for _ in range(5):
            param.grad = torch.randn_like(param)
            optimizer.step()

        weights = param.data
        rows, cols = weights.shape
        # Form the Gram matrix along the smaller dimension; Newton-Schulz drives the
        # singular values toward 1, so the product should be close to the identity.
        if rows <= cols:
            gram = weights @ weights.mT
        else:
            gram = weights.mT @ weights
        identity = torch.eye(min(rows, cols), device=FLAGS.device)

        torch.testing.assert_close(
            gram,
            identity,
            atol=0.1,
            rtol=0.1,
            msg=f"Weight matrix of shape {shape} is not approximately orthogonal after step",
        )


if __name__ == "__main__":
    absltest.main()
emerging_optimizers.orthogonalized_optimizers.spel import * from emerging_optimizers.orthogonalized_optimizers.spectral_clipping_utils import * +from emerging_optimizers.orthogonalized_optimizers.spel import * diff --git a/emerging_optimizers/orthogonalized_optimizers/spel.py b/emerging_optimizers/orthogonalized_optimizers/spel.py index 8780729..f2c72e3 100644 --- a/emerging_optimizers/orthogonalized_optimizers/spel.py +++ b/emerging_optimizers/orthogonalized_optimizers/spel.py @@ -108,9 +108,7 @@ def __init__( use_syrk = False def scaled_orthogonalize_fn(X: torch.Tensor) -> torch.Tensor: - logging.debug( - f"Orthogonalizing with {num_ns_steps} steps, {coefficient_type} coefficient" - ) + logging.debug(f"Orthogonalizing with {num_ns_steps} steps, {coefficient_type} coefficient") return muon_utils.newton_schulz( X, steps=num_ns_steps, diff --git a/tests/test_spel.py b/tests/test_spel.py index f6f8dcb..c16b1dd 100644 --- a/tests/test_spel.py +++ b/tests/test_spel.py @@ -20,6 +20,7 @@ from emerging_optimizers.orthogonalized_optimizers import spel + flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'") FLAGS = flags.FLAGS