SPEL: Spectral steepest descent on the Stiefel manifold #106
New file `spel.py` (`@@ -0,0 +1,146 @@`):

```python
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import override

import torch
from absl import logging
from torch.optim.optimizer import ParamsT

from emerging_optimizers import registry, triton_kernels, utils
from emerging_optimizers.mixin import WeightDecayT
from emerging_optimizers.orthogonalized_optimizers import muon_utils
from emerging_optimizers.orthogonalized_optimizers.muon_utils import NSCoeffT
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer, _args_doc
from emerging_optimizers.utils import FP32MatmulPrecT

__all__ = ["Spel"]


@registry.register_optimizer("spel")
class Spel(OrthogonalizedOptimizer):
    r"""SPEL: SPectral steepest descent on the stiefEL manifold

    SPEL is the spectral-norm specialization of Manifold Constrained Steepest Descent (MCSD)
    on the Stiefel manifold. It selects a norm-induced steepest-descent direction via the matrix
    sign function applied to the momentum, then projects back onto the manifold via Newton-Schulz
    iteration:

    .. math::

        x_{{t+1}} = \text{{msign}}\!\left(x_t - \alpha_t \, \text{{msign}}\!\left(\nabla_M f(x_t)\right)\right)

    The inner :math:`\text{{msign}}` orthogonalizes the gradient via Newton-Schulz iteration
    (identical to Muon, without scale factors). The outer :math:`\text{{msign}}` re-projects the
    updated weights onto the Stiefel manifold, keeping parameters on (or near) the orthogonal
    manifold. Both operations admit scalable implementations via fast matrix sign computations.

    Note:
        Weight decay still has an effect despite the re-orthogonalization. Before projection,
        :math:`(1 - \eta_t \lambda) \, x_t - \eta_t \, \text{{msign}}(\cdot)` acts as a convex-like
        linear combination that rebalances the proportion of the previous weight and the new update.
        The outer :math:`\text{{msign}}` then re-orthogonalizes this mixture, so weight decay controls
        the relative influence of the old parameters versus the update direction.

    References:
        - *Manifold Constrained Steepest Descent.* arXiv:2601.21487 (2026).
          [`arXiv:2601.21487 <https://arxiv.org/abs/2601.21487>`_]
        - Jordan, K. *Muon Optimizer Implementation.*
          [`GitHub <https://github.com/KellerJordan/Muon/blob/master/muon.py>`_]
        - *Modular Duality in Deep Learning.* arXiv:2410.21265 (2024).
          [`arXiv:2410.21265 <https://arxiv.org/abs/2410.21265>`_]

    Warning:
        - This optimizer requires that all parameters passed in are 2D.
        - It should not be used for the embedding layer, the final fully connected layer, or any
          1-D parameters; those can all be optimized by a standard method (e.g., AdamW).

    Args:
        {{_args_doc}}
        coefficient_type: The type of coefficient set to use for the Newton-Schulz iteration.
            Can be one of ["simple", "quintic", "polar_express"].
        num_ns_steps: The number of Newton-Schulz iteration steps for both gradient
            orthogonalization and the post-update weight projection.
        use_syrk: Whether to use the Triton kernel for the Newton-Schulz iteration.
    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 3e-4,
        momentum_beta: float = 0.95,
        weight_decay: float = 0.1,
        *,
        use_nesterov: bool = False,
        weight_decay_method: WeightDecayT = "decoupled",
        fp32_matmul_prec: FP32MatmulPrecT = "medium",
        coefficient_type: NSCoeffT = "quintic",
        num_ns_steps: int = 5,
        use_syrk: bool = False,
    ) -> None:
        if num_ns_steps < 1:
            raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}")

        if use_syrk:
            if torch.cuda.is_available():
                sm_version = torch.cuda.get_device_capability()
            else:
                sm_version = (0, 0)
            if not triton_kernels.HAS_TRITON_340:  # type: ignore[attr-defined]
                logging.error("Triton 3.4.0 or higher is required for use_syrk to be True.")
                use_syrk = False
            elif sm_version not in ((8, 0), (9, 0), (10, 0), (10, 3)):
                logging.error(
                    f"Correctness of Triton kernel on SM {sm_version} cannot be guaranteed. Setting use_syrk to False."
                )
                use_syrk = False

        def scaled_orthogonalize_fn(X: torch.Tensor) -> torch.Tensor:
            logging.debug(f"Orthogonalizing with {num_ns_steps} steps, {coefficient_type} coefficient")
            return muon_utils.newton_schulz(
                X,
                steps=num_ns_steps,
                coefficient_type=coefficient_type,
                use_syrk=use_syrk,
            )

        super().__init__(
            params,
            lr,
            momentum_beta,
            weight_decay,
            use_nesterov=use_nesterov,
            weight_decay_method=weight_decay_method,
            fp32_matmul_prec=fp32_matmul_prec,
            scaled_orthogonalize_fn=scaled_orthogonalize_fn,
        )

    @override
    def post_weight_update_fn_inplace(self, p: torch.Tensor, update: torch.Tensor) -> None:
        """Re-orthogonalize the weight matrix after the update via Newton-Schulz.

        This projects the updated weight matrix back onto (or near) the orthogonal manifold,
        implementing the outer msign in: x_{t+1} = msign(x_t - lr * msign(momentum)).

        Args:
            p: The updated parameter tensor.
            update: The orthogonalized gradient tensor that was applied.
        """
        with utils.fp32_matmul_precision(self.fp32_matmul_prec):
            orth_p = self.scaled_orthogonalize_fn(p)
            p.copy_(orth_p)
```
**Review comment on lines +141 to +143** (body truncated in this capture):

> Redundant [...] The base class [...] However, note that [...]

**Contributor (author) reply:** yes this was intentional
```python
Spel.__doc__ = Spel.__doc__.format(_args_doc=_args_doc)  # type: ignore[union-attr]
```
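To see the update rule from the class docstring in action, here is a hedged sketch that replaces both Newton-Schulz calls with an exact SVD-based matrix sign; `msign` is a local helper defined for this sketch, not a library function:

```python
import torch


def msign(x: torch.Tensor) -> torch.Tensor:
    """Exact matrix sign: the polar factor U @ V^T, the nearest (semi-)orthogonal matrix."""
    u, _, vh = torch.linalg.svd(x, full_matrices=False)
    return u @ vh


torch.manual_seed(0)
x = msign(torch.randn(4, 6))  # start on the Stiefel manifold
lr = 0.1

for _ in range(3):
    grad = torch.randn(4, 6)
    # SPEL step: x_{t+1} = msign(x_t - lr * msign(grad))
    x = msign(x - lr * msign(grad))

# The iterates never leave the manifold: x @ x^T = I along the smaller dimension.
err = torch.linalg.norm(x @ x.mT - torch.eye(4))
```

With exact `msign` the constraint holds to machine precision; with a finite number of Newton-Schulz steps, as in the PR, the weights stay only approximately orthogonal, which is what the accompanying test checks with `atol=0.1`.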
New test file (`@@ -0,0 +1,92 @@`); the diff breaks after the `@parameterized.product` decorator, where an inline review comment was attached:

```python
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
from absl import flags
from absl.testing import absltest, parameterized

from emerging_optimizers.orthogonalized_optimizers import spel

flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'")

FLAGS = flags.FLAGS


class SpelTest(parameterized.TestCase):
    @parameterized.product(
        shape=[(5, 7), (33, 65), (127, 257)],
        weight_decay_method=["decoupled", "independent", "l2"],
        use_nesterov=[True, False],
    )
```
**Review comment on lines +30 to +34** (body truncated in this capture):

> Smoke test doesn't cover [...] The base [...] Since SPEL runs Newton-Schulz twice per step (once for gradient orthogonalization, once in [...]
>
> A suggested change was attached, but its body did not survive capture.

**Contributor (author) reply:** unnecessary
```python
    def test_smoke(self, shape, weight_decay_method, use_nesterov) -> None:
        """Smoke test Spel optimizer with various shapes, weight decay methods, and Nesterov."""
        test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=FLAGS.device))
        test_param.grad = torch.randint_like(test_param, -5, 5)

        spel_opt = spel.Spel(
            [test_param],
            weight_decay_method=weight_decay_method,
            use_nesterov=use_nesterov,
        )
        spel_opt.step()

    @parameterized.product(
        shape=[(5, 7), (33, 65), (127, 257)],
    )
    def test_post_update_produces_approximately_orthogonal_weights(self, shape) -> None:
        """Test that post_weight_update_fn_inplace produces approximately orthogonal matrices.

        After each optimizer step, the weight matrix W should satisfy W @ W^T ≈ I (up to scale)
        for the smaller dimension, which is the defining property of SPEL.
        """
        test_param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device))

        spel_opt = spel.Spel(
            [test_param],
            lr=0.01,
            momentum_beta=0.0,
            weight_decay=0.0,
        )

        for _ in range(5):
            test_param.grad = torch.randn_like(test_param)
            spel_opt.step()

        # After post_weight_update_fn_inplace, W should be approximately orthogonal.
        # For an m x n matrix with m <= n: W @ W^T ≈ I_m
        # For an m x n matrix with m > n:  W^T @ W ≈ I_n
        W = test_param.data
        m, n = W.shape
        if m <= n:
            WWT = W @ W.mT
            eye = torch.eye(m, device=FLAGS.device)
        else:
            WWT = W.mT @ W
            eye = torch.eye(n, device=FLAGS.device)

        # Newton-Schulz normalizes the spectral norm to ~1, so WWT ≈ I
        torch.testing.assert_close(
            WWT,
            eye,
            atol=0.1,
            rtol=0.1,
            msg=f"Weight matrix of shape {shape} is not approximately orthogonal after step",
        )


if __name__ == "__main__":
    absltest.main()
```
**Review comment: Decoupled weight decay has no effect**

> With the default `weight_decay_method="decoupled"` (or `"independent"`), weight decay shrinks `p` in-place before the gradient update. However, `post_weight_update_fn_inplace` then calls `newton_schulz(p)`, which normalizes `p` by its Frobenius norm (muon_utils.py:171) before iterating. This normalization erases the scale change from weight decay, making decoupled/independent weight decay a no-op.
>
> The default `weight_decay=0.1` may mislead users into thinking weight decay is being applied, when it effectively isn't. Consider either:
>
> - defaulting to `weight_decay=0.0` to avoid confusion, or
> - using `weight_decay_method="l2"`, which modifies the gradient direction (not the parameter scale) and would survive the re-projection.
>
> Note: L2 weight decay (`weight_decay_method="l2"`) would still have an effect since it modifies the gradient rather than the weight scale.

**Contributor (author) reply:** not true, ignore
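A small numerical check (again using an exact SVD-based `msign` as a stand-in for Newton-Schulz; `msign` is a local helper, not a library function) is consistent with the author's reply and with the class docstring's Note: `msign` is invariant to positive rescaling of its whole argument, but decoupled decay shrinks `x_t` *relative to the update* before the mixture is projected, so the projected result does change:

```python
import torch


def msign(x: torch.Tensor) -> torch.Tensor:
    """Exact matrix sign (polar factor) via SVD."""
    u, _, vh = torch.linalg.svd(x, full_matrices=False)
    return u @ vh


torch.manual_seed(0)
x = msign(torch.randn(4, 6))       # an orthogonal "weight"
update = msign(torch.randn(4, 6))  # an orthogonalized "gradient"
lr, wd = 0.1, 0.1

no_decay = msign(x - lr * update)
with_decay = msign((1 - lr * wd) * x - lr * update)

# 1) Rescaling the whole argument is a no-op, so Frobenius normalization
#    alone cannot change the projection...
scale_diff = torch.linalg.norm(msign(0.5 * (x - lr * update)) - no_decay)

# 2) ...but shrinking x_t relative to the update does change the result.
decay_diff = torch.linalg.norm(no_decay - with_decay)
```

So the normalization inside `newton_schulz` removes only the overall scale of the mixture, not the rebalancing between the old weight and the update that weight decay introduces.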