> [!WARNING]
> Only verified with `DDPStrategy`; FSDP and DeepSpeed are still untested. A known issue exists in `deepspeed_stage_2` due to sharded gradients.
This repo explores the Muon optimizer and its variants, including convenient mixtures with AdamW or Scion, as well as versions under constraints such as the Stiefel manifold and the spectral sphere.
- Roadmap: TODO.md
This project can be easily installed via `pip install -e .`. For developers, you may install via `pip install --no-build-isolation -e .[dev]` to obtain pytest and related toolkits such as tilelang for developing custom kernels.
- Official implementation: https://github.com/MoonshotAI/Moonlight
- Core contribution:
- An extra "update-RMS equalization" step so the per-parameter update RMS lines up across matrix vs. non-matrix params, allowing a more unified LR strategy across groups (Muon vs. AdamW).
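The equalization step can be sketched as follows (a minimal illustrative sketch, not the repo's actual code; the function name `match_update_rms` and the 0.2 target are assumptions based on the Moonlight recipe):

```python
import torch

def match_update_rms(update: torch.Tensor, target_rms: float = 0.2) -> torch.Tensor:
    # A semi-orthogonal Muon update of shape (m, n) has RMS 1/sqrt(max(m, n)),
    # so multiplying by target_rms * sqrt(max(m, n)) matches its RMS to a
    # typical AdamW update RMS, letting both groups share one learning rate.
    m, n = update.shape[-2], update.shape[-1]
    return update * (target_rms * max(m, n) ** 0.5)
```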
- [TODO] Unofficial implementation here
- Core contribution:
- QK-Clip for controlling max-logit explosion (reference)
- Source: update function
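In outline, QK-Clip rescales the query/key projection weights whenever the observed maximum attention logit exceeds a threshold (a hedged sketch; `qk_clip_`, the default threshold, and the square-root split are illustrative assumptions, not the repo's actual update function):

```python
import torch

@torch.no_grad()
def qk_clip_(w_q: torch.Tensor, w_k: torch.Tensor,
             max_logit: float, tau: float = 100.0) -> None:
    # If the largest pre-softmax logit this step exceeded tau, shrink both
    # projections in place; splitting the factor as a square root applies
    # the full correction tau / max_logit to the q·k product.
    if max_logit > tau:
        scale = (tau / max_logit) ** 0.5
        w_q.mul_(scale)
        w_k.mul_(scale)
```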
- Reference: 《流形上的最速下降:4. Muon + 谱球面》 (Steepest Descent on Manifolds: 4. Muon + Spectral Sphere)
> [!NOTE]
> This implementation is not exposed by default due to unsatisfactory speed and accuracy. However, you may import it via `from manifold_muon.stiefel.stiefel_moonlight import StiefelMoonlight` to try it out.
- Dual-ascent-based method (source code), adapted from modula's blog.
- Fixed-point-based method (source code), adapted from 《流形上的最速下降:3. Muon + Stiefel》 (Steepest Descent on Manifolds: 3. Muon + Stiefel).
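Both methods ultimately keep the weight on the Stiefel manifold (orthonormal columns); the projection itself can be sketched via the polar factor of an SVD (an illustrative sketch, not either method's actual implementation):

```python
import torch

def stiefel_project(w: torch.Tensor) -> torch.Tensor:
    # Polar projection: for W = U S V^T, the nearest point on the Stiefel
    # manifold in Frobenius norm is U V^T (orthonormal columns for tall W).
    u, _, vh = torch.linalg.svd(w, full_matrices=False)
    return u @ vh
```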
[!] Weight decay is still under development
Our ManifoldMoonlight optimizer uses a new parameter-grouping scheme, which differs from the classical Moonlight or Muon implementations.
The currently valid grouping keys are `["use_muon", "use_adamw", "use_spectral_muon"]`.
```python
from manifold_muon import ManifoldMoonlight, deduplicate_and_check_missing_params

# A "matrix" param is 2-D or higher and is neither the embedding nor the
# LM head; q/k projections get their own spectral group.
def is_matrix(name, p):
    return p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name

def is_qk(name):
    return "q_proj" in name or "k_proj" in name

params = {
    "use_muon": [
        p for name, p in model.named_parameters()
        if is_matrix(name, p) and not is_qk(name)
    ],
    "use_adamw": [
        p for name, p in model.named_parameters()
        if not is_matrix(name, p)
    ],
    "use_spectral_muon": [
        p for name, p in model.named_parameters()
        if is_matrix(name, p) and is_qk(name)
    ],
}
# We highly suggest adding this line to check whether any missing or
# duplicate params exist among the groups.
deduplicate_and_check_missing_params(model, params)
optimizer = ManifoldMoonlight(
    grouped_params=params,
    ...
)
```
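As a quick sanity check of the grouping rule, the intended partition can be exercised on a tiny stand-in module (a standalone sketch that restates the predicates and does not import manifold_muon; the toy layer names are illustrative):

```python
import torch.nn as nn

def group_of(name: str, p) -> str:
    # Restates the intended partition: q/k projection matrices go to the
    # spectral group, other non-embedding/non-head matrices to Muon, and
    # everything else (biases, norms, embeddings, head) to AdamW.
    is_matrix = p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
    is_qk = "q_proj" in name or "k_proj" in name
    if is_matrix and is_qk:
        return "use_spectral_muon"
    if is_matrix:
        return "use_muon"
    return "use_adamw"

# Toy stand-in model covering all three cases.
model = nn.ModuleDict({
    "q_proj": nn.Linear(4, 4),
    "o_proj": nn.Linear(4, 4),
    "lm_head": nn.Linear(4, 4, bias=False),
})
groups = {k: [] for k in ("use_muon", "use_adamw", "use_spectral_muon")}
for name, p in model.named_parameters():
    groups[group_of(name, p)].append(name)
```

Every parameter lands in exactly one group, which is what `deduplicate_and_check_missing_params` verifies for a real model.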