From dbcb7e650fb83956193ab661db1c6668b696623b Mon Sep 17 00:00:00 2001
From: HuangWei-95
Date: Sat, 28 Feb 2026 15:01:51 +0800
Subject: [PATCH 1/2] feat(megatron): add muon optimizer support via
 get_megatron_optimizer patch

Patch megatron.core.optimizer.get_megatron_optimizer to automatically
dispatch to get_megatron_muon_optimizer when config.optimizer contains
"muon" (e.g., "muon", "distributed_muon"), enabling the muon optimizer in
the backends workflow without a separate branch in MegatronTrainer.
---
 .../patches/muon_optimizer_patches.py | 93 +++++++++++++++++++
 1 file changed, 93 insertions(+)
 create mode 100644 primus/backends/megatron/patches/muon_optimizer_patches.py

diff --git a/primus/backends/megatron/patches/muon_optimizer_patches.py b/primus/backends/megatron/patches/muon_optimizer_patches.py
new file mode 100644
index 000000000..35611a25a
--- /dev/null
+++ b/primus/backends/megatron/patches/muon_optimizer_patches.py
@@ -0,0 +1,93 @@
+###############################################################################
+# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+#
+# See LICENSE for license information.
+###############################################################################
+
+"""
+Megatron Muon Optimizer patches.
+
+This module patches megatron.core.optimizer.get_megatron_optimizer to
+automatically dispatch to get_megatron_muon_optimizer when config.optimizer
+contains "muon", enabling muon optimizer support in the backends workflow
+without maintaining a separate branch in MegatronTrainer.
+"""
+
+import dataclasses
+
+from primus.core.patches import PatchContext, register_patch
+from primus.modules.module_utils import log_rank_0
+
+
+@register_patch(
+    "megatron.optimizer.muon",
+    backend="megatron",
+    phase="before_train",
+    description="Patch get_megatron_optimizer to dispatch to the muon optimizer when the optimizer name contains 'muon'.",
+)
+def patch_get_megatron_optimizer_muon(ctx: PatchContext) -> None:
+    """
+    Patch megatron.core.optimizer.get_megatron_optimizer to delegate to
+    get_megatron_muon_optimizer when config.optimizer contains "muon".
+    """
+    import megatron.core.optimizer as optimizer_module
+
+    original_get_megatron_optimizer = optimizer_module.get_megatron_optimizer
+
+    if getattr(original_get_megatron_optimizer, "_primus_muon_wrapper", False):
+        return
+
+    def _patched_get_megatron_optimizer(
+        config,
+        model_chunks,
+        no_weight_decay_cond=None,
+        scale_lr_cond=None,
+        lr_mult=1.0,
+        use_gloo_process_groups=True,
+        default_skip_embedding_weight_decay=False,
+        pg_collection=None,
+    ):
+        if not config.optimizer or "muon" not in config.optimizer:
+            return original_get_megatron_optimizer(
+                config,
+                model_chunks,
+                no_weight_decay_cond=no_weight_decay_cond,
+                scale_lr_cond=scale_lr_cond,
+                lr_mult=lr_mult,
+                use_gloo_process_groups=use_gloo_process_groups,
+                default_skip_embedding_weight_decay=default_skip_embedding_weight_decay,
+                pg_collection=pg_collection,
+            )
+
+        from primus.backends.megatron.core.optimizer.moun import (
+            get_megatron_muon_optimizer,
+        )
+        from primus.backends.megatron.core.optimizer.moun_optimizer_config import (
+            MounOptimizerConfig,
+        )
+
+        args = ctx.extra.get("backend_args", {})
+        kwargs = {}
+        for f in dataclasses.fields(MounOptimizerConfig):
+            if hasattr(args, f.name):
+                kwargs[f.name] = getattr(args, f.name)
+
+        moun_config = MounOptimizerConfig(**kwargs)
+        moun_config.timers = config.timers
+
+        return get_megatron_muon_optimizer(
+            moun_config,
+            model_chunks,
+            no_weight_decay_cond=no_weight_decay_cond,
+            scale_lr_cond=scale_lr_cond,
+            lr_mult=lr_mult,
+            use_gloo_process_groups=use_gloo_process_groups,
+            layer_wise_distributed_optimizer="dist" in config.optimizer,
+            pg_collection=pg_collection,
+        )
+
+    setattr(_patched_get_megatron_optimizer, "_primus_muon_wrapper", True)
+    optimizer_module.get_megatron_optimizer = _patched_get_megatron_optimizer
+    log_rank_0(
+        "[Patch:megatron.optimizer.muon] Patched get_megatron_optimizer to dispatch to muon when optimizer contains 'muon'"
+    )
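
A note on the field-forwarding idiom in PATCH 1/2: the wrapper builds its
optimizer config by copying every dataclass field of MounOptimizerConfig that
also exists as an attribute on the backend args object, leaving all other
fields at their defaults. Below is a minimal, self-contained sketch of that
idiom; OptimizerConfig and Args are hypothetical stand-ins, not the actual
Primus or Megatron classes.

    import dataclasses
    from dataclasses import dataclass


    @dataclass
    class OptimizerConfig:
        # Hypothetical stand-in for MounOptimizerConfig.
        lr: float = 1e-4
        weight_decay: float = 0.0
        momentum: float = 0.95


    class Args:
        # Hypothetical stand-in for the backend args object in ctx.extra.
        lr = 3e-4
        momentum = 0.9
        unrelated_flag = True  # ignored: not a declared config field


    args = Args()

    # Forward only the attributes that correspond to declared config fields,
    # mirroring the dataclasses.fields(MounOptimizerConfig) loop in the patch.
    kwargs = {
        f.name: getattr(args, f.name)
        for f in dataclasses.fields(OptimizerConfig)
        if hasattr(args, f.name)
    }
    config = OptimizerConfig(**kwargs)
    assert (config.lr, config.momentum, config.weight_decay) == (3e-4, 0.9, 0.0)

The same shape tolerates the patch's {} fallback for a missing backend_args:
hasattr on a plain dict is False for every field name, so the config simply
keeps its declared defaults.
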
From c722ed81327f2421d65c0c37141d031cf7589518 Mon Sep 17 00:00:00 2001
From: HuangWei-95
Date: Sat, 28 Feb 2026 16:07:35 +0800
Subject: [PATCH 2/2] fix patch point

Patch megatron.training.training instead of megatron.core.optimizer:
training.py imports get_megatron_optimizer into its own namespace, so the
patch must rebind the name where it is actually used.
---
 .../patches/muon_optimizer_patches.py | 27 +++++++++++++------
 1 file changed, 19 insertions(+), 8 deletions(-)

diff --git a/primus/backends/megatron/patches/muon_optimizer_patches.py b/primus/backends/megatron/patches/muon_optimizer_patches.py
index 35611a25a..8737d7380 100644
--- a/primus/backends/megatron/patches/muon_optimizer_patches.py
+++ b/primus/backends/megatron/patches/muon_optimizer_patches.py
@@ -7,10 +7,11 @@
 """
 Megatron Muon Optimizer patches.
 
-This module patches megatron.core.optimizer.get_megatron_optimizer to
+This module patches megatron.training.training.get_megatron_optimizer to
 automatically dispatch to get_megatron_muon_optimizer when config.optimizer
-contains "muon", enabling muon optimizer support in the backends workflow
-without maintaining a separate branch in MegatronTrainer.
+contains "muon". Since training.py uses `from megatron.core.optimizer import
+get_megatron_optimizer`, we must patch the training module's namespace where
+the function is actually used, not megatron.core.optimizer.
 """
 
 import dataclasses
@@ -27,12 +28,21 @@
 )
 def patch_get_megatron_optimizer_muon(ctx: PatchContext) -> None:
     """
-    Patch megatron.core.optimizer.get_megatron_optimizer to delegate to
+    Patch megatron.training.training.get_megatron_optimizer to delegate to
     get_megatron_muon_optimizer when config.optimizer contains "muon".
+
+    We patch the training module (not megatron.core.optimizer) because
+    training.py imports get_megatron_optimizer into its namespace at import
+    time; patching the optimizer module would not affect the training module's
+    local reference.
     """
-    import megatron.core.optimizer as optimizer_module
+    try:
+        import megatron.training.training as training_module
+    except ImportError as e:
+        log_rank_0(f"[Patch:megatron.optimizer.muon] Skip patch (Megatron not available): {e}")
+        return
 
-    original_get_megatron_optimizer = optimizer_module.get_megatron_optimizer
+    original_get_megatron_optimizer = training_module.get_megatron_optimizer
 
     if getattr(original_get_megatron_optimizer, "_primus_muon_wrapper", False):
         return
@@ -87,7 +97,8 @@ def _patched_get_megatron_optimizer(
     )
 
     setattr(_patched_get_megatron_optimizer, "_primus_muon_wrapper", True)
-    optimizer_module.get_megatron_optimizer = _patched_get_megatron_optimizer
+    training_module.get_megatron_optimizer = _patched_get_megatron_optimizer
     log_rank_0(
-        "[Patch:megatron.optimizer.muon] Patched get_megatron_optimizer to dispatch to muon when optimizer contains 'muon'"
+        "[Patch:megatron.optimizer.muon] Patched get_megatron_optimizer in megatron.training.training "
+        "to dispatch to muon when optimizer contains 'muon'."
     )
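
Why PATCH 2/2 moves the patch point: once training.py has executed
`from megatron.core.optimizer import get_megatron_optimizer`, it holds its
own reference to the function object, and rebinding the name inside
megatron.core.optimizer no longer reaches that reference. A minimal sketch of
the aliasing behavior, using hypothetical module names rather than Megatron's:

    import types

    # Hypothetical producer/consumer modules, not Megatron's real ones.
    producer = types.ModuleType("producer")
    producer.get_optimizer = lambda: "adam"

    # Simulate `from producer import get_optimizer` inside a consumer module:
    # the consumer now owns an independent reference to the function object.
    training = types.ModuleType("training")
    training.get_optimizer = producer.get_optimizer

    # Rebinding the name in the defining module does not affect the consumer.
    producer.get_optimizer = lambda: "muon"
    assert training.get_optimizer() == "adam"

    # The fix in PATCH 2/2: rebind the name in the module that uses it.
    training.get_optimizer = lambda: "muon"
    assert training.get_optimizer() == "muon"

This is also why the wrapper tags itself with _primus_muon_wrapper: the
rebinding happens per consumer namespace, so without that guard a second
application of the patch would wrap the wrapper.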