A performance-optimized Muon optimizer for PyTorch.
Features:
- Foreach-native: uses `torch._foreach_*` ops for momentum, weight decay, and parameter updates.
- Batched Newton-Schulz: groups matrices by shape for parallel orthogonalization.
- Auto-parameter routing: partitions model parameters into Muon-eligible (≥2D hidden weights) and auxiliary (embeddings, heads, norms, biases).
- Composite optimizer: `CompositeMuon` combines Muon with any auxiliary optimizer (not just AdamW).
- Three LR modes: Keller Jordan's `"original"` (with aspect-ratio scaling), Moonshot AI's `"match_rms_adamw"`, and `"none"` (no scaling).
- Momentum conventions: `"ema"` (`m = beta*m + (1-beta)*g`, default) and `"classical"` (`m = beta*m + g`).
- Corrections: MARS, cautious updates, cautious weight decay, NorMuon, gradient/update clipping (all toggleable).
- Weight normalization: optional Frobenius-norm clamping to `sqrt(fan_out)` (from KJ's original Muon).
- Half-precision momentum: optional lower-precision momentum buffers for memory savings.
- Polar Express: optimal per-step Newton-Schulz coefficients (default).
- Distributed: `torch.distributed` gradient sharding via `all_gather`.
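Orthogonalization is the heart of Muon: each 2-D momentum matrix is mapped toward the nearest semi-orthogonal matrix. A minimal numpy sketch of the classic quintic Newton-Schulz iteration, with the fixed coefficients from Keller Jordan's original Muon (optimuon's Polar Express default uses optimal per-step coefficients instead):

```python
import numpy as np

def newton_schulz(G: np.ndarray, steps: int = 5) -> np.ndarray:
    """Approximate the orthogonal factor of G via KJ's quintic iteration."""
    a, b, c = 3.4445, -4.7750, 2.0315  # fixed quintic coefficients (KJ's Muon)
    X = G / (np.linalg.norm(G) + 1e-7)  # Frobenius-normalize: singular values <= 1
    transposed = G.shape[0] > G.shape[1]
    if transposed:                      # iterate on the short side for efficiency
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X

rng = np.random.default_rng(0)
O = newton_schulz(rng.standard_normal((4, 6)))
# singular values of O are driven toward 1 (approximately orthonormal rows)
```

Five steps of this iteration push all singular values into a narrow band around 1, which is accurate enough for an optimizer update even though it is far from machine-precision orthogonality.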
Installation:

```shell
uv pip install git+https://github.com/emaballarin/optimuon
```

Usage with a separate auxiliary optimizer:

```python
import torch
from optimuon import Muon

# Muon for ≥2D hidden weight matrices only
muon = Muon(muon_params, lr=0.02, momentum=0.95, weight_decay=0.01)

# Separate AdamW for everything else
adamw = torch.optim.AdamW(other_params, lr=3e-4)

# Training loop
for batch in dataloader:
    loss = model(batch).loss
    loss.backward()
    muon.step()
    adamw.step()
    muon.zero_grad()
    adamw.zero_grad()
```
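The `momentum=0.95` here follows the default `"ema"` convention; the two conventions from the feature list differ only in the `(1 - beta)` factor on the incoming gradient. A toy numpy illustration of the two accumulation rules (illustrative only, not optimuon's internals):

```python
import numpy as np

beta = 0.95
g = np.ones(3)  # a constant toy gradient
m_ema = np.zeros(3)
m_classical = np.zeros(3)

for _ in range(1000):
    m_ema = beta * m_ema + (1 - beta) * g  # "ema": fixed point is g
    m_classical = beta * m_classical + g   # "classical": fixed point is g / (1 - beta)
```

With a constant gradient, `"ema"` momentum converges to `g` itself while `"classical"` converges to `g / (1 - beta)` (here 20x larger), which is why the two conventions generally call for different learning rates.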
Or let `CompositeMuon` route parameters automatically:

```python
from optimuon import CompositeMuon

optimizer = CompositeMuon(
    model,
    muon_lr=0.02,
    muon_kwargs={"weight_decay": 0.01, "foreach": True},
    aux_optimizer_class=torch.optim.AdamW,
    aux_optimizer_kwargs={"lr": 3e-4, "betas": (0.9, 0.95), "weight_decay": 0.01},
    verbose=True,
)

for batch in dataloader:
    loss = model(batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```
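How `muon_lr` is scaled per matrix depends on the LR mode. A sketch of the three per-matrix multipliers, with constants taken from the cited papers (KJ's Muon and the Moonshot report); optimuon's exact formulas should be checked against its source:

```python
import math

def lr_scale(fan_out: int, fan_in: int, mode: str) -> float:
    """Per-matrix LR multiplier for the three modes (sketch, constants per the papers)."""
    if mode == "original":
        # KJ: compensate tall matrices by their aspect ratio
        return max(1.0, fan_out / fan_in) ** 0.5
    if mode == "match_rms_adamw":
        # Moonshot: scale the orthogonalized update to match AdamW's ~0.2 update RMS
        return 0.2 * math.sqrt(max(fan_out, fan_in))
    return 1.0  # "none": raw lr
```

Under `"match_rms_adamw"` a single learning rate can be shared with the AdamW-optimized auxiliary parameters, which is the main appeal of that mode.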
Enabling the optional corrections:

```python
from optimuon import CompositeMuon

optimizer = CompositeMuon(
    model,
    muon_lr=0.02,
    muon_kwargs={
        "weight_decay": 0.01,
        "mars": True,         # MARS gradient correction
        "cautious": True,     # cautious update masking
        "grad_clip": 1.0,     # gradient norm clipping
        "weight_norm": True,  # Frobenius-norm clamping
    },
    aux_optimizer_class=torch.optim.AdamW,
    aux_optimizer_kwargs={"lr": 3e-4},
)
```
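The `cautious` flag corresponds to the one-line rule from Liang et al.: zero out update components whose sign disagrees with the current gradient, then rescale to preserve magnitude. A schematic numpy version (illustrative, not optimuon's internal code):

```python
import numpy as np

def cautious(update: np.ndarray, grad: np.ndarray) -> np.ndarray:
    """Mask update entries that point against the gradient, rescale the rest."""
    mask = (update * grad > 0).astype(update.dtype)
    return update * mask * (mask.size / max(mask.sum(), 1.0))

u = np.array([0.5, -0.2, 0.1, -0.4])
g = np.array([1.0,  1.0, 1.0, -1.0])
masked = cautious(u, g)  # zeroes index 1, where update and gradient disagree
```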
Any callable can build the auxiliary optimizer:

```python
from optimuon import CompositeMuon

optimizer = CompositeMuon(
    model,
    muon_lr=0.02,
    aux_optimizer_factory=lambda param_groups: SomeExoticOptimizer(param_groups, lr=1e-3),
)
```
Inspect the automatic parameter split:

```python
from optimuon import partition_params

result = partition_params(model)
print(f"Muon: {result.muon_names}")
print(f"Aux: {result.aux_names}")
```

References:

- Keller Jordan et al., Muon: An optimizer for hidden layers in neural networks (2024)
- Huizhuo Yuan et al., MARS: Unleashing the Power of Variance Reduction for Training Large Models (2024)
- Kaizhao Liang et al., Cautious Optimizers: Improving Training with One Line of Code (2024)
- Moonshot AI, Muon is Scalable for LLM Training (2025)
- Essential AI, Practical Efficiency of Muon for Pretraining (2025)
- Noah Amsel et al., The Polar Express: Optimal Matrix Sign Methods and Their Application to the Muon Algorithm (2025)
- Zichong Li et al., NorMuon: Making Muon more efficient and scalable (2025)
- Lizhang Chen et al., Cautious Weight Decay (2025)
License: MIT