
SPEL: Spectral steepest descent on the Stiefel manifold#106

Draft
mkhona-nvidia wants to merge 4 commits into NVIDIA-NeMo:main from mkhona-nvidia:mkhona/spel

Conversation

@mkhona-nvidia
Contributor

@mkhona-nvidia mkhona-nvidia commented Feb 18, 2026

Adds the Spel optimizer, the spectral-norm specialization of Manifold Constrained Steepest Descent (MCSD) on the Stiefel manifold, based on https://arxiv.org/abs/2601.21487.

Algorithm:
$$x_{t+1} = \text{msign}\left(x_t - \alpha_t \, \text{msign}\left(\nabla_M f(x_t)\right)\right)$$

Inner msign — orthogonalizes the momentum via Newton-Schulz iteration (same as Muon, without scale factors).
Outer msign — re-projects the updated weights onto the Stiefel manifold via Newton-Schulz iteration.
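The two-level update can be sketched in a few lines. Below is a minimal NumPy stand-in for the torch implementation: the quintic coefficient values are the ones popularized by Muon, and the simplified step (no momentum EMA, weight decay, or Nesterov) is an illustration, not code from this PR.

```python
import numpy as np

def msign(G: np.ndarray, steps: int = 5) -> np.ndarray:
    """Approximate the matrix sign (the U V^T polar factor) via Newton-Schulz."""
    a, b, c = 3.4445, -4.7750, 2.0315      # quintic coefficients, as in Muon
    X = G / (np.linalg.norm(G) + 1e-7)     # scale so singular values are <= 1
    transposed = X.shape[0] > X.shape[1]
    if transposed:
        X = X.T                            # iterate on the wide orientation
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * (A @ A)) @ X
    return X.T if transposed else X

def spel_step(x: np.ndarray, momentum: np.ndarray, lr: float = 3e-4) -> np.ndarray:
    update = msign(momentum)               # inner msign: orthogonalize momentum
    return msign(x - lr * update)          # outer msign: re-project onto the manifold
```

With five quintic steps the singular values of the output land near 1 rather than exactly at 1, which is why the outer projection only keeps the weights near the manifold rather than exactly on it.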

Signed-off-by: mikail <mkhona@nvidia.com>
@mkhona-nvidia mkhona-nvidia requested a review from skyw February 18, 2026 22:06
@copy-pr-bot

copy-pr-bot bot commented Feb 18, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@greptile-apps

greptile-apps bot commented Feb 18, 2026

Greptile Summary

Adds the SPEL (SPectral steepest descent on the stiefEL manifold) optimizer, a spectral-norm specialization of Manifold Constrained Steepest Descent on the Stiefel manifold based on arXiv:2601.21487. The optimizer performs double orthogonalization: inner msign orthogonalizes the gradient momentum via Newton-Schulz iteration (like Muon), and outer msign re-projects updated weights onto the orthogonal manifold.

Key Changes:

  • New Spel class in spel.py extending OrthogonalizedOptimizer with post_weight_update_fn_inplace for weight re-orthogonalization
  • Comprehensive test suite with smoke tests covering various shapes, weight decay methods, and Nesterov modes, plus orthogonality property tests
  • CI integration with GPU tests
  • Well-documented implementation with clear mathematical notation and references

Previous Review Comments:
The previous review threads identified documentation around weight decay behavior and test coverage for fp32_matmul_prec as areas that could be enhanced, but these are minor suggestions that don't affect correctness.

Confidence Score: 4/5

  • This PR is safe to merge with minimal risk - solid implementation with comprehensive tests
  • The SPEL optimizer is well-implemented following established patterns in the codebase (similar to Muon). The code includes proper error handling, input validation, comprehensive tests (both smoke and property-based), and clear documentation. Previous review threads have identified minor documentation suggestions around weight decay behavior, but these don't affect correctness. The implementation correctly uses Newton-Schulz iteration for both gradient orthogonalization and weight projection, properly manages precision contexts, and includes appropriate CI integration.
  • No files require special attention - all changes follow codebase conventions

Important Files Changed

Filename Overview
emerging_optimizers/orthogonalized_optimizers/spel.py New SPEL optimizer implementation with dual Newton-Schulz orthogonalization (gradient + weight projection). Well-documented with clear mathematical foundations. Previous threads identified weight decay behavior as documented but potentially confusing.
tests/test_spel.py Comprehensive test suite with smoke tests and orthogonality property tests. Previous thread suggested adding fp32_matmul_prec parameterization to match coverage of other optimizer tests.
emerging_optimizers/orthogonalized_optimizers/__init__.py Standard module export addition for new SPEL optimizer. No issues.
tests/ci/L0_Tests_GPU.sh Added SPEL tests to CI pipeline with correct flags. No issues.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Start: Weights x_t, Gradient ∇f] --> B[Apply Weight Decay]
    B --> C[Update Momentum Buffer<br/>EMA with gradient]
    C --> D{Nesterov?}
    D -->|Yes| E[Interpolate gradient<br/>with momentum]
    D -->|No| F[Use momentum buffer]
    E --> G[Inner msign:<br/>Orthogonalize momentum<br/>via Newton-Schulz]
    F --> G
    G --> H[Apply Update:<br/>x' = x_t - lr * msign_grad]
    H --> I[Outer msign:<br/>Re-project x' onto Stiefel manifold<br/>via Newton-Schulz]
    I --> J[End: Orthogonal weights x_t+1]
    
    style G fill:#e1f5ff
    style I fill:#e1f5ff
    style J fill:#d4edda

Last reviewed commit: 1a8beee


@greptile-apps greptile-apps bot left a comment


2 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

Comment on lines +136 to +138
with utils.fp32_matmul_precision(self.fp32_matmul_prec):
    orth_p = self.scaled_orthogonalize_fn(p)
p.copy_(orth_p)

fp32_matmul_precision context manager is necessary here, not redundant

The base class step() method in orthogonalized_optimizer.py:171 wraps the orthogonalize call with utils.fp32_matmul_precision(self.fp32_matmul_prec), but post_weight_update_fn_inplace is called outside that context manager (line 178). So the explicit with utils.fp32_matmul_precision(self.fp32_matmul_prec): here is necessary and correct for the outer msign.

Note also that newton_schulz itself (in muon_utils.py:187) checks torch.get_float32_matmul_precision() to decide whether to use BF16 I/O kernels, so this context manager genuinely changes behavior. Just wanted to confirm this is intentional; it looks correct.

Contributor Author


yes this was intentional

    params: ParamsT,
    lr: float = 3e-4,
    momentum_beta: float = 0.95,
    weight_decay: float = 0.1,

Decoupled weight decay has no effect

With the default weight_decay_method="decoupled" (or "independent"), weight decay shrinks p in-place before the gradient update. However, post_weight_update_fn_inplace then calls newton_schulz(p), which normalizes p by its Frobenius norm (muon_utils.py:171) before iterating. This normalization erases the scale change from weight decay, making decoupled/independent weight decay a no-op.

The default weight_decay=0.1 may mislead users into thinking weight decay is being applied, when it effectively isn't. Consider either:

  • Documenting that weight decay has no effect for this optimizer (since parameters are re-projected onto the Stiefel manifold each step), or
  • Defaulting weight_decay=0.0 to avoid confusion, or
  • Only supporting weight_decay_method="l2" which modifies the gradient direction (not the parameter scale) and would survive the re-projection.

Note: L2 weight decay (weight_decay_method="l2") would still have an effect since it modifies the gradient rather than the weight scale.

Contributor Author


not true, ignore
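For what it's worth, a small numerical check separates the two claims. The sketch below (plain NumPy, with an exact SVD-based msign standing in for Newton-Schulz; the constants are illustrative) shows that msign is invariant to a uniform positive rescaling of its argument, yet decay applied before the outer projection rebalances x_t against the update and does change the projected result, consistent with the docstring's Note.

```python
import numpy as np

def msign(X: np.ndarray) -> np.ndarray:
    # Exact matrix sign via SVD; Newton-Schulz approximates this polar factor.
    U, _, Vt = np.linalg.svd(X, full_matrices=False)
    return U @ Vt

rng = np.random.default_rng(0)
x = msign(rng.standard_normal((4, 6)))   # an (exactly) orthogonal weight
u = msign(rng.standard_normal((4, 6)))   # an orthogonalized update direction
lr, wd = 0.1, 0.1

# Uniform rescaling alone is erased by the projection...
assert np.allclose(msign(0.9 * x), msign(x))

# ...but decay applied before the outer msign changes the x-vs-update mixture,
# so the projected iterate genuinely differs:
with_wd = msign((1 - lr * wd) * x - lr * u)
without_wd = msign(x - lr * u)
assert not np.allclose(with_wd, without_wd)
```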

Signed-off-by: mikail <mkhona@nvidia.com>

@greptile-apps greptile-apps bot left a comment


2 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Comment on lines +1 to +148
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import override

import torch
from absl import logging
from torch.optim.optimizer import ParamsT

from emerging_optimizers import registry, triton_kernels, utils
from emerging_optimizers.mixin import WeightDecayT
from emerging_optimizers.orthogonalized_optimizers import muon_utils
from emerging_optimizers.orthogonalized_optimizers.muon_utils import NSCoeffT
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer, _args_doc
from emerging_optimizers.utils import FP32MatmulPrecT


__all__ = ["Spel"]


@registry.register_optimizer("spel")
class Spel(OrthogonalizedOptimizer):
    r"""SPEL: SPectral steepest descent on the stiefEL manifold

    SPEL is the spectral-norm specialization of Manifold Constrained Steepest Descent (MCSD)
    on the Stiefel manifold. It selects a norm-induced steepest-descent direction via the matrix
    sign function applied to the momentum, then projects back onto the manifold via Newton-Schulz
    iteration:

    .. math::

        x_{{t+1}} = \text{{msign}}\!\left(x_t - \alpha_t \, \text{{msign}}\!\left(\nabla_M f(x_t)\right)\right)

    The inner :math:`\text{{msign}}` orthogonalizes the gradient via Newton-Schulz iteration
    (identical to Muon, without scale factors). The outer :math:`\text{{msign}}` re-projects the
    updated weights onto the Stiefel manifold, keeping parameters on (or near) the orthogonal
    manifold. Both operations admit scalable implementations via fast matrix sign computations.

    Note:
        Weight decay still has an effect despite the re-orthogonalization. Before projection,
        :math:`(1 - \eta_t \lambda) \, x_t - \eta_t \, \text{{msign}}(\cdot)` acts as a convex-like
        linear combination that rebalances the proportion of the previous weight and the new update.
        The outer :math:`\text{{msign}}` then re-orthogonalizes this mixture, so weight decay controls
        the relative influence of the old parameters versus the update direction.

    References:
        - *Manifold Constrained Steepest Descent.* arXiv:2601.21487 (2026).
          [`arXiv:2601.21487 <https://arxiv.org/abs/2601.21487>`_]
        - Jordan, K. *Muon Optimizer Implementation.*
          [`GitHub <https://github.com/KellerJordan/Muon/blob/master/muon.py>`_]
        - *Modular Duality in Deep Learning.* arXiv:2410.21265 (2024).
          [`arXiv:2410.21265 <https://arxiv.org/abs/2410.21265>`_]

    Warning:
        - This optimizer requires that all parameters passed in are 2D.
        - It should not be used for the embedding layer, the final fully connected layer, or any
          1-D parameters; those can all be optimized by a standard method (e.g., AdamW).

    Args:
        {_args_doc}
        coefficient_type: The type of coefficient set to use for the Newton-Schulz iteration. Can
            be one of ["simple", "quintic", "polar_express"].
        num_ns_steps: The number of Newton-Schulz iteration steps for both gradient
            orthogonalization and the post-update weight projection.
        use_syrk: Whether to use the Triton kernel for the Newton-Schulz iteration.
    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 3e-4,
        momentum_beta: float = 0.95,
        weight_decay: float = 0.1,
        *,
        use_nesterov: bool = False,
        weight_decay_method: WeightDecayT = "decoupled",
        fp32_matmul_prec: FP32MatmulPrecT = "medium",
        coefficient_type: NSCoeffT = "quintic",
        num_ns_steps: int = 5,
        use_syrk: bool = False,
    ) -> None:
        if num_ns_steps < 1:
            raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}")

        if use_syrk:
            if torch.cuda.is_available():
                sm_version = torch.cuda.get_device_capability()
            else:
                sm_version = (0, 0)
            if not triton_kernels.HAS_TRITON_340:  # type: ignore[attr-defined]
                logging.error("Triton 3.4.0 or higher is required for use_syrk to be True.")
                use_syrk = False
            elif sm_version not in ((8, 0), (9, 0), (10, 0), (10, 3)):
                logging.error(
                    f"Correctness of Triton kernel on SM {sm_version} cannot be guaranteed. Setting use_syrk to False."
                )
                use_syrk = False

        def scaled_orthogonalize_fn(X: torch.Tensor) -> torch.Tensor:
            logging.debug(f"Orthogonalizing with {num_ns_steps} steps, {coefficient_type} coefficient")
            return muon_utils.newton_schulz(
                X,
                steps=num_ns_steps,
                coefficient_type=coefficient_type,
                use_syrk=use_syrk,
            )

        super().__init__(
            params,
            lr,
            momentum_beta,
            weight_decay,
            use_nesterov=use_nesterov,
            weight_decay_method=weight_decay_method,
            fp32_matmul_prec=fp32_matmul_prec,
            scaled_orthogonalize_fn=scaled_orthogonalize_fn,
        )

    @override
    def post_weight_update_fn_inplace(self, p: torch.Tensor, update: torch.Tensor) -> None:
        """Re-orthogonalize the weight matrix after the update via Newton-Schulz.

        This projects the updated weight matrix back onto (or near) the orthogonal manifold,
        implementing the outer msign in: x_{t+1} = msign(x_t - lr * msign(momentum)).

        Args:
            p: The updated parameter tensor.
            update: The orthogonalized gradient tensor that was applied.
        """
        with utils.fp32_matmul_precision(self.fp32_matmul_prec):
            orth_p = self.scaled_orthogonalize_fn(p)
        p.copy_(orth_p)


Spel.__doc__ = Spel.__doc__.format(_args_doc=_args_doc)  # type: ignore[union-attr]

Missing unit tests for Spel optimizer

Every other optimizer in this package (Muon, Scion, MOP, MuonHyperball) has corresponding tests in tests/test_orthogonalized_optimizer.py — at minimum a smoke test, and in some cases property-based tests (e.g., MuonHyperball tests norm preservation). Spel has no tests at all.

At minimum, consider adding:

  • A smoke test (similar to MuonTest.test_smoke) exercising different shapes, weight_decay_method, and use_nesterov combinations.
  • A property test verifying that the output of post_weight_update_fn_inplace produces (approximately) orthogonal matrices — this is the distinguishing property of Spel and should be validated.
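A property test along those lines might look like the following sketch (plain NumPy; in the real suite one would step the optimizer on torch parameters and feed the updated weights to the check; the helper name and tolerance are assumptions, not code from this PR):

```python
import numpy as np

def orthogonality_error(W: np.ndarray) -> float:
    """RMS deviation of the Gram matrix from identity; 0 for exactly orthogonal W."""
    # Use the smaller Gram matrix so the check works for tall and wide matrices.
    M = W @ W.T if W.shape[0] <= W.shape[1] else W.T @ W
    n = M.shape[0]
    return float(np.linalg.norm(M - np.eye(n)) / np.sqrt(n))

# Sanity checks of the helper itself: an exact polar factor is orthogonal,
# a raw Gaussian matrix is far from it. A Newton-Schulz output sits in
# between, so the real test would use a loose tolerance (e.g. ~0.3).
rng = np.random.default_rng(0)
U, _, Vt = np.linalg.svd(rng.standard_normal((7, 5)), full_matrices=False)
assert orthogonality_error(U @ Vt) < 1e-8
assert orthogonality_error(rng.standard_normal((5, 7))) > 0.3
```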

Contributor Author


added test

Signed-off-by: mikail <mkhona@nvidia.com>

@greptile-apps greptile-apps bot left a comment


4 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Comment on lines +29 to +33
@parameterized.product(
    shape=[(5, 7), (33, 65), (127, 257)],
    weight_decay_method=["decoupled", "independent", "l2"],
    use_nesterov=[True, False],
)

Smoke test doesn't cover fp32_matmul_prec

The base OrthogonalizedOptimizerTest.test_smoke parameterizes over fp32_matmul_prec=["highest", "medium", "low"], which exercises different matmul precision paths in Newton-Schulz (notably, "medium" triggers the bf16 I/O path at muon_utils.py:187-193).

Since SPEL runs Newton-Schulz twice per step (once for gradient orthogonalization, once in post_weight_update_fn_inplace), it has more surface area for precision-related issues. Consider adding fp32_matmul_prec to the parameterized product to match the coverage of other optimizer tests:
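For intuition about the cost of this suggestion: parameterized.product expands the Cartesian product of the keyword lists into one generated test case each, so adding a two-value fp32_matmul_prec axis doubles the case count. A plain-Python equivalent of that expansion (itertools stand-in; the dict of axes is illustrative):

```python
import itertools

axes = {
    "shape": [(5, 7), (33, 65), (127, 257)],
    "weight_decay_method": ["decoupled", "independent", "l2"],
    "use_nesterov": [True, False],
    "fp32_matmul_prec": ["highest", "medium"],
}
# One generated test case per combination of axis values.
cases = [dict(zip(axes, combo)) for combo in itertools.product(*axes.values())]
print(len(cases))  # 3 * 3 * 2 * 2 = 36 cases instead of 18
```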

Suggested change

@parameterized.product(
    shape=[(5, 7), (33, 65), (127, 257)],
    weight_decay_method=["decoupled", "independent", "l2"],
    use_nesterov=[True, False],
)

becomes

@parameterized.product(
    shape=[(5, 7), (33, 65), (127, 257)],
    weight_decay_method=["decoupled", "independent", "l2"],
    use_nesterov=[True, False],
    fp32_matmul_prec=["highest", "medium"],
)
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Contributor Author


unnecessary

Signed-off-by: mikail <mkhona@nvidia.com>

@greptile-apps greptile-apps bot left a comment


4 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

@skyw skyw marked this pull request as draft February 23, 2026 16:13