104 changes: 104 additions & 0 deletions primus/backends/megatron/patches/muon_optimizer_patches.py
@@ -0,0 +1,104 @@
###############################################################################
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################

"""
Megatron Muon Optimizer patches.

This module patches megatron.training.training.get_megatron_optimizer to
automatically dispatch to get_megatron_muon_optimizer when args.optimizer
contains "muon". Since training.py uses `from megatron.core.optimizer import
get_megatron_optimizer`, we must patch the training module's namespace where
the function is actually used, not megatron.core.optimizer.
"""

import dataclasses

from primus.core.patches import PatchContext, register_patch
from primus.modules.module_utils import log_rank_0


@register_patch(
"megatron.optimizer.muon",
backend="megatron",
phase="before_train",
description="Patch get_megatron_optimizer to dispatch to muon optimizer when optimizer name contains 'muon'.",
)
def patch_get_megatron_optimizer_muon(ctx: PatchContext) -> None:
"""
Patch megatron.training.training.get_megatron_optimizer to delegate to
get_megatron_muon_optimizer when config.optimizer contains "muon".

We patch the training module (not megatron.core.optimizer) because
training.py imports get_megatron_optimizer into its namespace at import
time; patching the optimizer module would not affect the training module's
local reference.
"""
try:
import megatron.training.training as training_module
except ImportError as e:
log_rank_0(f"[Patch:megatron.optimizer.muon] Skip patch (Megatron not available): {e}")
return

original_get_megatron_optimizer = training_module.get_megatron_optimizer

    # Idempotence guard: if the patch has already run, the current function is
    # our wrapper and carries the marker flag, so do not wrap it again.
    if getattr(original_get_megatron_optimizer, "_primus_muon_wrapper", False):
        return

def _patched_get_megatron_optimizer(
config,
model_chunks,
no_weight_decay_cond=None,
scale_lr_cond=None,
lr_mult=1.0,
use_gloo_process_groups=True,
default_skip_embedding_weight_decay=False,
pg_collection=None,
):
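        # Delegate to the stock Megatron optimizer unless the configured
        # optimizer name asks for Muon.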
if not config.optimizer or "muon" not in config.optimizer:
return original_get_megatron_optimizer(
config,
model_chunks,
no_weight_decay_cond=no_weight_decay_cond,
scale_lr_cond=scale_lr_cond,
lr_mult=lr_mult,
use_gloo_process_groups=use_gloo_process_groups,
default_skip_embedding_weight_decay=default_skip_embedding_weight_decay,
pg_collection=pg_collection,
)

from primus.backends.megatron.core.optimizer.moun import (
get_megatron_muon_optimizer,
)
from primus.backends.megatron.core.optimizer.moun_optimizer_config import (
MounOptimizerConfig,
)

        # Build the Muon optimizer config from the backend args captured in the
        # patch context, copying every MounOptimizerConfig field for which the
        # args object exposes an attribute of the same name.
        args = ctx.extra.get("backend_args", {})
        kwargs = {}
        for f in dataclasses.fields(MounOptimizerConfig):
            if hasattr(args, f.name):
                kwargs[f.name] = getattr(args, f.name)

moun_config = MounOptimizerConfig(**kwargs)
moun_config.timers = config.timers

return get_megatron_muon_optimizer(
moun_config,
model_chunks,
no_weight_decay_cond=no_weight_decay_cond,
scale_lr_cond=scale_lr_cond,
lr_mult=lr_mult,
use_gloo_process_groups=use_gloo_process_groups,
            # A "dist" substring in the optimizer name selects the layer-wise
            # distributed Muon variant.
            layer_wise_distributed_optimizer="dist" in config.optimizer,
pg_collection=pg_collection,
)

setattr(_patched_get_megatron_optimizer, "_primus_muon_wrapper", True)
training_module.get_megatron_optimizer = _patched_get_megatron_optimizer
log_rank_0(
"[Patch:megatron.optimizer.muon] Patched get_megatron_optimizer in megatron.training.training "
"to dispatch to muon when optimizer contains 'muon'."
)
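
The namespace caveat in the module docstring is the crux of this patch: because training.py does `from megatron.core.optimizer import get_megatron_optimizer`, the training module holds its own binding to the function, and reassigning the attribute on megatron.core.optimizer afterwards would leave that binding untouched. A minimal, self-contained sketch of the same mechanism, using hypothetical provider/consumer modules in place of the real Megatron ones:

    import sys
    import types

    # Hypothetical stand-in for megatron.core.optimizer.
    provider = types.ModuleType("provider")
    exec("def get_optimizer():\n    return 'original'", provider.__dict__)
    sys.modules["provider"] = provider

    # Hypothetical stand-in for megatron.training.training, which imports the
    # function into its own namespace at import time.
    consumer = types.ModuleType("consumer")
    exec(
        "from provider import get_optimizer\n"
        "def build():\n"
        "    return get_optimizer()\n",
        consumer.__dict__,
    )

    # Patching the provider after the import does not affect the consumer's copy.
    provider.get_optimizer = lambda: "patched"
    print(consumer.build())  # prints: original

    # Patching the consumer's own namespace, as this file does, takes effect.
    consumer.get_optimizer = lambda: "patched"
    print(consumer.build())  # prints: patched

This is why the patch reassigns training_module.get_megatron_optimizer rather than the attribute on megatron.core.optimizer.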