Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/apidocs/orthogonalized-optimizers.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ emerging_optimizers.orthogonalized_optimizers
.. autoclass:: MuonHyperball
:members:

:hidden:`Spectron`
~~~~~~~~~~~~~~~~~~~

.. autoclass:: Spectron
:members:


:hidden:`Newton-Schulz`
~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
1 change: 1 addition & 0 deletions emerging_optimizers/orthogonalized_optimizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import *
from emerging_optimizers.orthogonalized_optimizers.scion import *
from emerging_optimizers.orthogonalized_optimizers.spectral_clipping_utils import *
from emerging_optimizers.orthogonalized_optimizers.spectron import *
265 changes: 265 additions & 0 deletions emerging_optimizers/orthogonalized_optimizers/spectron.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, overload, override

import torch
import torch.optim as optim
from absl import logging
from torch.optim.optimizer import ParamsT

from emerging_optimizers import mixin as opt_mixin
from emerging_optimizers import registry, utils
from emerging_optimizers.orthogonalized_optimizers import muon_utils
from emerging_optimizers.orthogonalized_optimizers.muon_utils import NSCoeffT
from emerging_optimizers.utils import FP32MatmulPrecT
from emerging_optimizers.utils.eig import power_iteration


__all__ = ["Spectron"]


@registry.register_optimizer("spectron")
class Spectron(opt_mixin.WeightDecayMixin, optim.Optimizer):
"""Spectron: Low-rank spectral optimizer with orthogonalized momentum.

Spectron maintains each 2D weight matrix W as a low-rank factorization W = A @ B^T,
where A ∈ R^(m×r) and B ∈ R^(n×r). It applies momentum, orthogonalizes the updates
using Newton-Schulz iteration, and scales the learning rate by the spectral radii
of both factors.

The algorithm:
1. Compute gradients with respect to A and B from parameter gradients
2. Apply momentum to both factors
3. Orthogonalize momentum buffers using Newton-Schulz iteration
4. Estimate spectral radius of A and B using power iteration
5. Update with scaled learning rate: η / (σ_A + σ_B + 1)
6. Reconstruct full weight matrix W = A @ B^T

References:
- Algorithm 1 (Spectron) and Algorithm 3 (PowerIter) from the Spectron paper (https://arxiv.org/abs/2602.12429).
Low-rank spectral optimization with orthogonalized momentum.

Warning:
- This optimizer requires that all parameters passed in are 2D.
- Low-rank factorization may not be suitable for all parameter types.

Args:
params: Iterable of parameters to optimize or dicts defining parameter groups
lr: The learning rate (η in the algorithm). Default: 3e-4
rank: The rank of the low-rank factorization. Default: 64
momentum_beta: The momentum decay coefficient (β). Default: 0.9
weight_decay: The weight decay coefficient. Default: 0.01
weight_decay_method: Method to apply weight decay. Default: "decoupled"
fp32_matmul_prec: Precision of matmul operations. Default: "medium"
num_ns_steps: Number of Newton-Schulz iteration steps. Default: 5
num_power_iter: Number of power iteration steps for spectral radius. Default: 1
coefficient_type: Type of coefficient set for Newton-Schulz. Default: "quintic"
"""

def __init__(
    self,
    params: ParamsT,
    lr: float = 3e-4,
    rank: int = 64,
    momentum_beta: float = 0.9,
    weight_decay: float = 0.01,
    *,
    weight_decay_method: opt_mixin.WeightDecayT = "decoupled",
    fp32_matmul_prec: FP32MatmulPrecT = "medium",
    num_ns_steps: int = 5,
    num_power_iter: int = 1,
    coefficient_type: NSCoeffT = "quintic",
) -> None:
    """Initialize Spectron. See the class docstring for argument semantics.

    Raises:
        ValueError: If any hyperparameter is outside its valid range.
    """
    # Validate hyperparameters up front so misconfiguration fails at
    # construction time rather than mid-training.
    if lr < 0.0:
        raise ValueError(f"Invalid learning rate: {lr}")
    if rank < 1:
        raise ValueError(f"Invalid rank: {rank}")
    if not 0.0 <= momentum_beta < 1.0:
        raise ValueError(f"Invalid momentum_beta: {momentum_beta}")
    if weight_decay < 0.0:
        raise ValueError(f"Invalid weight_decay: {weight_decay}")
    if num_ns_steps < 1:
        raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}")
    if num_power_iter < 1:
        raise ValueError(f"num_power_iter must be at least 1, got {num_power_iter}")

    self.fp32_matmul_prec = fp32_matmul_prec
    self.weight_decay_method = weight_decay_method
    self.rank = rank
    self.num_power_iter = num_power_iter

    # Create orthogonalization function following OrthogonalizedOptimizer pattern
    def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
        # Lazy %-style logging args: the message is only formatted when debug
        # logging is enabled, avoiding f-string cost on every call.
        logging.debug("Orthogonalizing grad with %d steps, %s coefficient", num_ns_steps, coefficient_type)
        return muon_utils.newton_schulz(
            grad,
            steps=num_ns_steps,
            coefficient_type=coefficient_type,
        )

    self.scaled_orthogonalize_fn = scaled_orthogonalize_fn

    defaults = dict(
        lr=lr,
        momentum_beta=momentum_beta,
        weight_decay=weight_decay,
    )

    super().__init__(params, defaults)

@overload
def step(self, closure: None = ...) -> None: ...

@overload
def step(self, closure: Callable[[], float]) -> float: ...

@torch.no_grad()  # type: ignore[misc]
@override
def step(self, closure: Callable[[], float] | None = None) -> float | None:
    """Performs a single optimization step.

    Args:
        closure: A closure that reevaluates the model and returns the loss.

    Returns:
        The loss returned by ``closure``, or ``None`` if no closure was given.

    Raises:
        ValueError: If any parameter with a gradient is not 2D.
    """
    if closure is None:
        loss = None
    else:
        loss = closure()

    for group in self.param_groups:
        for p in group["params"]:
            if p.grad is None:
                continue

            if p.ndim != 2:
                raise ValueError(f"Spectron only supports 2D parameters, got shape {p.shape}")

            # Cast the gradient to fp32 once, up front. The low-rank factors and
            # momentum buffers are always stored in fp32 (see _initialize_state),
            # so a bf16/fp16 parameter gradient would otherwise cause a dtype
            # mismatch in the matmuls below and in momentum lerp_.
            grad = p.grad.float()
            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state["step"] = 0

            if state["step"] == 0:
                assert all(
                    key not in state for key in ["factor_A", "factor_B", "momentum_A", "momentum_B", "u_A", "u_B"]
                ), (
                    "factor_A, factor_B, momentum_A, momentum_B, u_A, u_B should not be initialized at step 0. "
                    "Some mismatch has been created likely in checkpointing"
                )
                self._initialize_state(p, state)

            state["step"] += 1

            # Get state variables
            factor_A = state["factor_A"]
            factor_B = state["factor_B"]
            momentum_A = state["momentum_A"]
            momentum_B = state["momentum_B"]
            u_A = state["u_A"]
            u_B = state["u_B"]

            # Compute gradients for A and B from parameter gradient
            # Using chain rule: ∂L/∂A = ∂L/∂W @ B, ∂L/∂B = ∂L/∂W^T @ A
            with utils.fp32_matmul_precision("highest"):
                grad_A = grad @ factor_B  # shape: (m, r)
                grad_B = grad.mT @ factor_A  # shape: (n, r)

            # Apply weight decay
            self._apply_weight_decay_inplace(factor_A, grad_A, group["lr"], group["weight_decay"])
            self._apply_weight_decay_inplace(factor_B, grad_B, group["lr"], group["weight_decay"])

            # Update momentum buffers (EMA of gradients)
            momentum_A.lerp_(grad_A, 1 - group["momentum_beta"])
            momentum_B.lerp_(grad_B, 1 - group["momentum_beta"])

            # Orthogonalize momentum using Newton-Schulz
            with utils.fp32_matmul_precision(self.fp32_matmul_prec):
                orth_momentum_A = self.scaled_orthogonalize_fn(momentum_A)
                orth_momentum_B = self.scaled_orthogonalize_fn(momentum_B)

            with utils.fp32_matmul_precision("highest"):
                # Estimate spectral radius using power iteration
                sigma_A, u_A = self._power_iteration(factor_A, u_A, self.num_power_iter)
                sigma_B, u_B = self._power_iteration(factor_B, u_B, self.num_power_iter)

            # Update power iteration vectors
            state["u_A"] = u_A
            state["u_B"] = u_B

            # Compute scaled learning rate: η / (σ_A + σ_B + 1)
            scaled_lr = group["lr"] / (sigma_A + sigma_B + 1.0)

            # Update low-rank factors
            factor_A.add_(orth_momentum_A, alpha=-scaled_lr)
            factor_B.add_(orth_momentum_B, alpha=-scaled_lr)

            # Reconstruct full weight matrix: W = A @ B^T
            # (copy_ casts the fp32 product back to p's storage dtype)
            with utils.fp32_matmul_precision(self.fp32_matmul_prec):
                p.copy_(factor_A @ factor_B.mT)

    return loss

def _initialize_state(self, p: torch.Tensor, state: dict[str, torch.Tensor]) -> None:
    """Set up low-rank factors, momentum buffers, and power-iteration vectors.

    The factors are seeded from a truncated SVD of the parameter so that
    A @ B^T starts close to the original weights; all state is kept in fp32
    for numerical stability.

    Args:
        p: The parameter tensor (shape: m × n)
        state: The state dictionary for this parameter
    """
    num_rows, num_cols = p.shape
    # Clamp the configured rank so it never exceeds either matrix dimension.
    effective_rank = min(self.rank, num_rows, num_cols)

    # Split the top singular values evenly between the two factors:
    # A = U_r * sqrt(S_r), B = V_r * sqrt(S_r), so A @ B^T = U_r S_r V_r^T.
    with torch.no_grad():
        U, S, Vh = torch.linalg.svd(p.float(), full_matrices=False)
        scale = torch.sqrt(S[:effective_rank])
        state["factor_A"] = U[:, :effective_rank] * scale
        state["factor_B"] = Vh[:effective_rank, :].mT * scale

    # Momentum buffers mirror the factor shapes, always fp32.
    state["momentum_A"] = torch.zeros_like(state["factor_A"], dtype=torch.float32)
    state["momentum_B"] = torch.zeros_like(state["factor_B"], dtype=torch.float32)

    # Unit-norm random start vectors for power iteration (u_A first, then u_B,
    # preserving the RNG call order).
    for key, dim in (("u_A", num_rows), ("u_B", num_cols)):
        vec = torch.randn(dim, dtype=torch.float32, device=p.device)
        state[key] = vec / vec.norm()

def _power_iteration(self, X: torch.Tensor, u: torch.Tensor, num_iters: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Estimate the largest singular value using power iteration.

    Thin wrapper around :func:`power_iteration` that drops the right singular
    vector, which Spectron does not track.

    Args:
        X: The matrix to estimate largest singular value for
        u: The current approximation of the dominant left singular vector
        num_iters: Number of power iteration steps

    Returns:
        Tuple of (largest singular value, updated_u)
    """
    sigma, updated_u, _unused_v = power_iteration(X, u, k=num_iters)
    return sigma, updated_u
51 changes: 51 additions & 0 deletions emerging_optimizers/utils/eig.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,60 @@
"met_approx_eigvals_criteria",
"conjugate",
"orthogonal_iteration",
"power_iteration",
]


def power_iteration(
    W: torch.Tensor,
    u: torch.Tensor,
    k: int = 1,
    eps: float = 1e-8,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Approximate largest singular value and left/right singular vectors using power iteration.

    Implements Algorithm 3 from the Spectron paper (https://arxiv.org/abs/2602.12429). This method iteratively refines
    estimates of the dominant singular value and corresponding left and right singular vectors
    of a matrix W.

    Args:
        W: Matrix of shape (p, q) to analyze
        u: Initial left singular vector of shape (p,), should be normalized
        k: Number of power iteration steps. Default: 1
        eps: Small constant for numerical stability. Default: 1e-8

    Returns:
        Tuple of (sigma, u, v) where:
            - sigma: Approximation of the largest singular value (scalar tensor)
            - u: Updated left singular vector of shape (p,)
            - v: Updated right singular vector of shape (q,)
    """

    def _unit(x: torch.Tensor) -> torch.Tensor:
        # Scale to unit L2 norm; clamp_min guards against division by ~zero.
        return x / x.norm(p=2).clamp_min(eps)

    # Defensively re-normalize the caller-supplied start vector.
    u = _unit(u)

    # Alternate W^T / W applications, normalizing after each, so u and v
    # converge to the dominant left/right singular directions.
    for _ in range(k):
        v = _unit(W.mT @ u)
        u = _unit(W @ v)

    # Refresh v against the final u, then take the Rayleigh quotient
    # sigma = u^T W v as the singular-value estimate.
    v = _unit(W.mT @ u)
    sigma = u @ (W @ v)

    return sigma, u, v


def eigh_with_fallback(
x: Tensor,
force_double: bool = False,
Expand Down
1 change: 1 addition & 0 deletions tests/ci/L0_Tests_CPU.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ error=0
torchrun --nproc_per_node=8 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py -v -2 || error=1
torchrun --nproc_per_node=4 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cpu -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_spectron.py --device=cpu -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_procrustes_step.py --device=cpu -v -2 || error=1

exit "${error}"
1 change: 1 addition & 0 deletions tests/ci/L0_Tests_GPU.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ error=0
coverage run -p --source=emerging_optimizers tests/test_muon_utils.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_adaptive_muon.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_orthogonalized_optimizer.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_spectron.py --device=cuda -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_soap_utils.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_soap.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/soap_mnist_test.py -v -2 || error=1
Expand Down
Loading