[dev] refactor to support emerging optimizers beyond muon#3618

Queued
FDecaYed wants to merge 9 commits into NVIDIA:dev from FDecaYed:deyuf/emerging_opt_refactor

Conversation

@FDecaYed
Contributor

@FDecaYed FDecaYed commented Feb 26, 2026

What does this PR do ?

Refactor the optimizer config, distributed-optimizer dispatch, and Megatron optimizer get() in preparation for allowing more optimizers.

Built on top of #3325

main PR: #3638

Summary

Unify optimizer creation so standard (Adam/SGD) and emerging (Muon) optimizers go through a single get_megatron_optimizer() entry point.

  • Single factory — All emerging-optimizer logic is consolidated into one new function _get_megatron_emerging_optimizer() in __init__.py. It reuses the same param-grouping and config-override mechanism as standard optimizers, removing the need to manually group parameters into separate optimizers with a freeze/unfreeze hack.
  • Single config — Collapsed the AdamOptimizerConfig/SGDOptimizerConfig subclasses back into one OptimizerConfig. Since every param group can override via config_overrides anyway, one default config plus overrides is cleaner than passing multiple config objects.
  • Emerging optimizer registry — New emerging_optimizers.py with a pluggable registry. Each supported emerging optimizer maps to {optimizer_cls, config_to_kwargs, default_param_overrides, init_state_fn}. TensorParallelMuon and all Muon helpers moved here from muon.py (only a backward-compat shim is left in muon.py).
  • dist_muon deprecated — muon + --use-distributed-optimizer is now resolved at the argument level into a new use_layer_wise_distributed_optimizer flag, which replaces dist_muon. use_distributed_optimizer is reset to False to avoid side effects in the standard distributed-optimizer code path.
  • LayerWiseDistributedOptimizer — Now expects plain torch optimizers (wrapping happens internally); this is cleaner than unwrapping Megatron optimizers.
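The registry described above could be sketched as follows. This is an illustrative reconstruction based only on the entry fields named in the summary ({optimizer_cls, config_to_kwargs, default_param_overrides, init_state_fn}); the helper names and surrounding structure are assumptions, not the PR's actual code:

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Optional

# Hypothetical registry entry holding the four pieces named in the PR summary.
@dataclass
class EmergingOptimizerEntry:
    optimizer_cls: type
    config_to_kwargs: Callable[[Any], dict]
    default_param_overrides: dict = field(default_factory=dict)
    init_state_fn: Optional[Callable] = None

# Pluggable registry: optimizer name -> entry.
_EMERGING_OPTIMIZER_REGISTRY: dict = {}

def register_emerging_optimizer(name: str, entry: EmergingOptimizerEntry) -> None:
    """Register a new emerging optimizer under `name`."""
    _EMERGING_OPTIMIZER_REGISTRY[name] = entry

def get_emerging_optimizer_entry(name: str) -> EmergingOptimizerEntry:
    """Look up a registered emerging optimizer, failing loudly if unknown."""
    try:
        return _EMERGING_OPTIMIZER_REGISTRY[name]
    except KeyError:
        raise ValueError(f"Unknown emerging optimizer: {name!r}") from None
```

A single factory like _get_megatron_emerging_optimizer() can then look up the entry by name, build kwargs from the one shared OptimizerConfig, and apply per-group config_overrides, instead of branching per optimizer.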

Contribution process

```mermaid
flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]
```

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

Feel free to message or comment @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

@copy-pr-bot

copy-pr-bot bot commented Feb 26, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@FDecaYed FDecaYed changed the title [dev] emerging opt refactor [dev] refactor to support emerging optimizers beyond muon Feb 26, 2026
Contributor

@skyw skyw left a comment


Looks good overall.

The terminology "plain" could be improved.

```python
    setattr(optimizer, 'tp_group', tp_group)
    result = optimizer
else:
    fallback_config = copy.copy(config)
```
Contributor


Q: Does this need deepcopy? There could be very heavy structure in config.
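For context on the question: copy.copy makes a shallow copy, so nested mutable attributes are shared with the original config, while copy.deepcopy duplicates them at the cost of copying any heavy structures. A minimal illustration with a made-up config class (not the PR's OptimizerConfig):

```python
import copy
from dataclasses import dataclass, field

@dataclass
class DummyConfig:
    lr: float = 1e-3
    overrides: dict = field(default_factory=dict)

cfg = DummyConfig()
shallow = copy.copy(cfg)
deep = copy.deepcopy(cfg)

# Shallow copy shares the nested dict: mutating it is visible on the original.
shallow.overrides["wd"] = 0.1
assert cfg.overrides == {"wd": 0.1}

# Deep copy duplicated the dict: the original is unaffected.
deep.overrides["momentum"] = 0.9
assert "momentum" not in cfg.overrides
```

So copy.copy is safe only if the copied config's mutable fields are never mutated afterwards (or are replaced wholesale rather than edited in place).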

"num_ns_steps": config.muon_num_ns_steps,
"scale_mode": config.muon_scale_mode,
"extra_scale_factor": config.muon_extra_scale_factor,
"mode": config.muon_tp_mode,
Contributor


How about changing this to tp_mode?
Then the contract is that everything with a muon_ prefix in the config translates to kwargs by stripping the prefix, and the code can be generalized.

Contributor Author


SG. We can do it in the next PR, when we bump emerging_optimizers and actually add support for other optimizers.

```diff
 # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

-"""Megatron muon optimizer wrapper to handle tensor-parallel."""
+"""Backward-compatible shim — all code now lives in ``emerging_optimizers``."""
```
Contributor


Q: Does everything exposed through muon.py still work, just with a deprecation warning?

Contributor Author


I'll test it. I haven't run the draft yet.

```diff
-# Muon optimizer check
-if 'muon' in args.optimizer:
+# Muon / emerging optimizer check
+if args.optimizer in ('muon', 'dist_muon'):
```
Contributor


Consider creating an emerging-optimizer group for everything with muon_, soap_, or other prefixes.

Contributor Author


Same: we'll change these parts properly in the next step, when we add support.

@FDecaYed FDecaYed force-pushed the deyuf/emerging_opt_refactor branch from 1ab146d to 6b21428 Compare February 27, 2026 06:20
@FDecaYed FDecaYed marked this pull request as ready for review February 27, 2026 09:21
@FDecaYed FDecaYed requested review from a team as code owners February 27, 2026 09:21
@FDecaYed
Contributor Author

/ok to test 3ce76a9

@yaoyu-33
Contributor

yaoyu-33 commented Mar 2, 2026

@FDecaYed : I think Megatron-Bridge needs some api support change to cover this PR as well?

@FDecaYed
Contributor Author

FDecaYed commented Mar 4, 2026

@FDecaYed : I think Megatron-Bridge needs some api support change to cover this PR as well?

Yes, but I'm hoping this will be the last refactor.

@FDecaYed FDecaYed added this pull request to the merge queue Mar 5, 2026
github-merge-queue bot pushed a commit that referenced this pull request Mar 5, 2026
Signed-off-by: Hao Wu <skyw@nvidia.com>
Co-authored-by: Hao Wu <skyw@nvidia.com>
@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/22701547164

github-merge-queue bot pushed a commit that referenced this pull request Mar 5, 2026
Signed-off-by: Hao Wu <skyw@nvidia.com>
Co-authored-by: Hao Wu <skyw@nvidia.com>
@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/22701733585
