[dev] refactor to support emerging optimizers beyond muon #3618
Merged (+567 −527)
Commits (9):
a94046d  Move muon to its own config class (skyw)
8b75626  Generalize interface of muon to support more optimizers (skyw)
da1a405  Update muon reference to emerging optimizer (skyw)
c93efd6  generalize init_state_fn (skyw)
7b553b4  draft get opt refactor (FDecaYed)
7303c08  full refactor to allow emerging optimizer properly (FDecaYed)
6b21428  address comments (FDecaYed)
71d3b36  fix minor issue and initial test passes (FDecaYed)
3ce76a9  fix minor test issues (FDecaYed)
@@ -2,6 +2,7 @@
 import copy
 import logging
 import warnings
+from collections import defaultdict
 from dataclasses import astuple
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

@@ -47,14 +48,22 @@
 from ..transformer.module import MegatronModule
 from ..utils import get_model_config, get_pg_rank, get_pg_size, is_te_min_version, log_single_rank
 from .distrib_optimizer import DistributedOptimizer
+from .emerging_optimizers import (
+    _EMERGING_OPTIMIZERS,
+    HAVE_EMERGING_OPTIMIZERS,
+    _create_emerging_optimizer,
+)
 from .grad_scaler import ConstantGradScaler, DynamicGradScaler
 from .layer_wise_optimizer import LayerWiseDistributedOptimizer
 from .optimizer import (
     ChainedOptimizer,
     Float16OptimizerWithFloat16Params,
     FP32Optimizer,
     MegatronOptimizer,
     param_group_identifier_keys,
 )

 # Subclass aliases kept for backward compatibility; all are OptimizerConfig.
 from .optimizer_config import (
     AdamOptimizerConfig,
     OptimizerConfig,

@@ -134,14 +143,6 @@ def _get_param_groups(
     # Map (pg_overrides, is_expert_parallel) to params.
     params_map = {}

-    if config_overrides is None:
-        # TODO remove this default behavior eventually.
-        # This is only needed for backwards compatibility with the old config overrides API where
-        # the config_overrides argument by default lead to bias parameters and length 1 parameters.
-        # We assume that users of decoupled LR already provide config overrides so will adapt
-        # to the new API.
-        config_overrides = get_standard_config_overrides(config=config)

     for model_chunk in model_chunks:
         for name, param in model_chunk.named_parameters():
             if not param.requires_grad:

@@ -276,7 +277,8 @@ def _get_megatron_optimizer_based_on_param_groups(
     intra_dist_opt_group: Optional[torch.distributed.ProcessGroup] = None,
     distributed_optimizer_instance_id: Optional[int] = 0,
     pg_collection: Optional[ProcessGroupCollection] = None,
-) -> MegatronOptimizer:
+    skip_megatron_wrapping: bool = False,
+) -> Union[MegatronOptimizer, Tuple[Optional[torch.optim.Optimizer], Optional[Callable]]]:
     """Get Megatron optimizer based on parameter groups.

     Args:

@@ -292,12 +294,24 @@ def _get_megatron_optimizer_based_on_param_groups(
             optimizer. Defaults to None.
         distributed_optimizer_instance_id (int, optional): Distributed optimizer instance. Defaults
             0.
+        skip_megatron_wrapping (bool): if True, return a
+            ``(optimizer, init_state_fn)`` tuple of the raw PyTorch optimizer
+            without any Megatron wrapping. Useful when the caller
+            (e.g. LayerWiseDistributedOptimizer) performs its own wrapping.

     Returns:
-        Instance of MegatronOptimizer.
+        Instance of MegatronOptimizer, or ``(optimizer, init_state_fn)`` when
+        *skip_megatron_wrapping=True*.
     """
-    # TODO: Logic needs to be updated to handle different optimizer types (i.e., param_groups
-    # passed into this function need to correspond to the same optimizer).
+    # All param_groups passed here must belong to the same optimizer type (adam / sgd).
+    # Callers are responsible for splitting by optimizer type before calling this function.
+
+    if skip_megatron_wrapping and config.use_precision_aware_optimizer:
+        raise ValueError(
+            "skip_megatron_wrapping=True is incompatible with use_precision_aware_optimizer."
+        )
+    if skip_megatron_wrapping and config.optimizer_cpu_offload:
+        raise ValueError("skip_megatron_wrapping=True is incompatible with optimizer_cpu_offload.")

     # When freezing sub-models we may have no trainable parameters on a rank and
     # hence an empty param_groups. However, we still need to create an optimizer

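To make the new escape hatch concrete, here is a minimal usage sketch (not code from this PR; all `my_*` names are placeholders) of a caller that asks the factory for the raw optimizer and performs its own wrapping:

```python
# Minimal sketch of the new skip_megatron_wrapping escape hatch. The keyword arguments
# mirror the fallback call made later in this diff; every my_* name is a placeholder.
raw_optimizer, init_state_fn = _get_megatron_optimizer_based_on_param_groups(
    config=my_config,                  # OptimizerConfig for a standard adam/sgd group
    model_chunks=my_model_chunks,
    param_groups=my_param_groups,      # groups that all share the same optimizer type
    model_parallel_group=my_mp_group,
    pg_collection=my_pg_collection,
    skip_megatron_wrapping=True,       # -> (torch.optim.Optimizer, init_state_fn), unwrapped
)
# The caller (e.g. LayerWiseDistributedOptimizer later in this diff) is then responsible for
# any Megatron-side wrapping; precision-aware and CPU-offload configs are rejected up front.
```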
@@ -412,6 +426,9 @@ def init_state_fn(opt, config=None):
         optimizer = None
         init_state_fn = None

+    if skip_megatron_wrapping:
+        return optimizer, init_state_fn
+
     # Mixed precision optimizer.
     # - Note: both the Float16Optimizer and the DistributedOptimizer inherit
     #   from the MixedPrecisionOptimizer, which manages any optimizer where

@@ -502,6 +519,137 @@ def check_config_overrides_consistency(
     return True


+def _get_megatron_emerging_optimizer(
+    config: OptimizerConfig,
+    model_chunks: List[MegatronModule],
+    config_overrides: Optional[Dict[ParamKey, Any]] = None,
+    pg_collection: Optional[ProcessGroupCollection] = None,
+) -> MegatronOptimizer:
+    """Build an emerging optimizer (e.g. Muon) for the given model chunks.
+
+    Parameter separation (e.g., linear weights -> Muon, rest -> Adam) is expressed as a
+    config_override, the same mechanism used for weight-decay and learning-rate overrides.
+    Adam/SGD groups are delegated to _get_megatron_optimizer_based_on_param_groups so they
+    go through the exact same code path as the standard optimizer factory.
+
+    When ``config.use_layer_wise_distributed_optimizer`` is True, the underlying optimizers
+    are wrapped with :class:`LayerWiseDistributedOptimizer`.
+    """
+    eopt_name = config.optimizer
+    use_layer_wise = config.use_layer_wise_distributed_optimizer
+
+    # Handle legacy "dist_*" optimizer names (e.g. "dist_muon" → "muon" + layer-wise).
+    if eopt_name.startswith('dist_'):
+        bare_name = eopt_name[len('dist_') :]
+        warnings.warn(
+            f"optimizer='{eopt_name}' is deprecated. "
+            f"Use optimizer='{bare_name}' with use_layer_wise_distributed_optimizer=True.",
+            DeprecationWarning,
+            stacklevel=3,
+        )
+        eopt_name = bare_name
+        use_layer_wise = True
+
+    if not HAVE_EMERGING_OPTIMIZERS:
+        raise ImportError(
+            f"emerging-optimizers package is required for optimizer='{eopt_name}'. "
+            "Install it with: pip install emerging-optimizers"
+        )
+    if eopt_name not in _EMERGING_OPTIMIZERS:
+        raise ValueError(f"Unsupported emerging optimizer: {eopt_name}")
+    if config.fp16:
+        raise ValueError('emerging optimizer with fp16 is not supported.')
+
+    if pg_collection is None:
+        pg_collection = ProcessGroupCollection.use_mpu_process_groups()
+
+    log_single_rank(logger, logging.INFO, f'Setting up emerging optimizer with config {config}')
+
+    # Tag parameters with optimizer-specific attributes (expert_tp, is_qkv).
+    for model_chunk in model_chunks:
+        for name, param in model_chunk.named_parameters():
+            if not param.requires_grad:
+                continue
+            if 'experts' in name and 'shared' not in name:
+                param.expert_tp = True
+            # TODO(deyuf): support MLA
+            if 'linear_qkv.weight' in name and len(param.shape) == 2:
+                param.is_qkv = True
+
+    # Apply optimizer-specific default param overrides (e.g. muon: non-linear -> adam).
+    config_overrides.update(_EMERGING_OPTIMIZERS[eopt_name].default_param_overrides)
+
+    # Build param groups and bucket by (optimizer_name, is_expert_parallel).
+    # Layer-wise distributed optimizer handles expert params internally so we skip that split.
+    all_param_groups = _get_param_groups(model_chunks, config, config_overrides)
+    grouped_param_groups = defaultdict(list)
+    for group in all_param_groups:
+        opt_name = group.get('optimizer', eopt_name)
+        is_expert = group['is_expert_parallel'] and not use_layer_wise
+        grouped_param_groups[(opt_name, is_expert)].append(group)
+
+    # Build an optimizer for each (optimizer_name, is_expert) bucket and combine.
+    results = []
+    for (opt_name, is_expert), groups in grouped_param_groups.items():
+        if not groups:
+            continue
+
+        model_parallel_group = pg_collection.tp_ep_pp if is_expert else pg_collection.mp
+
+        if opt_name in _EMERGING_OPTIMIZERS:
+            optimizer, init_state_fn = _create_emerging_optimizer(
+                config, groups, eopt_name, model_chunks, pg_collection
+            )
+            if use_layer_wise:
+                result = (optimizer, init_state_fn)
+            else:
+                if config.bf16:
+                    optimizer = Float16OptimizerWithFloat16Params(
+                        optimizer, config, None, init_state_fn
+                    )
+                else:
+                    optimizer = FP32Optimizer(optimizer, config, init_state_fn)
+                setattr(optimizer, 'grad_stats_parallel_group', model_parallel_group)
+                if pg_collection is None or not hasattr(pg_collection, 'tp'):
+                    tp_group = parallel_state.get_tensor_model_parallel_group()
+                else:
+                    tp_group = pg_collection.tp
+                setattr(optimizer, 'tp_group', tp_group)
+                result = optimizer
+        else:
+            fallback_config = copy.copy(config)
Contributor comment on `fallback_config = copy.copy(config)`: Q: Does this need deepcopy? There could be a very heavy structure in config.
+            fallback_config.optimizer = opt_name
+            fallback_config.use_distributed_optimizer = False
+            result = _get_megatron_optimizer_based_on_param_groups(
+                config=fallback_config,
+                model_chunks=model_chunks,
+                param_groups=groups,
+                model_parallel_group=model_parallel_group,
+                pg_collection=pg_collection,
+                skip_megatron_wrapping=use_layer_wise,
+            )
+            # TODO(deyuf): ChainedOptimizer currently asserts all sub-optimizers
+            # share the same config. Revisit this design now that emerging
+            # optimizers mix different optimizer types (e.g. Muon + Adam).
+            # For now, reset to the top-level config so the assertion holds.
+            if not use_layer_wise and hasattr(result, 'config'):
+                result.config = config
+        results.append(result)
+
+    if use_layer_wise:
+        base_optimizers, init_fns = (), ()
+        if results:
+            base_optimizers, init_fns = zip(*results)
+        log_single_rank(
+            logger, logging.INFO, f'Using LayerWiseDistributedOptimizer for {eopt_name}'
+        )
+        return LayerWiseDistributedOptimizer(
+            list(base_optimizers), config, pg_collection, init_state_fn_list=list(init_fns)
+        )
+
+    return ChainedOptimizer(results)
+
+
 def get_megatron_optimizer(
     config: OptimizerConfig,
     model_chunks: List[MegatronModule],

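For orientation, a hedged usage sketch of the new entry point. The config fields used below all appear in this diff; the import path and the `model_chunks` placeholder are assumptions, and muon-specific hyperparameters (moved to their own config class by this PR) are omitted:

```python
# Sketch: selecting an emerging optimizer after this refactor. Any value of
# config.optimizer outside ('adam', 'sgd') is routed to _get_megatron_emerging_optimizer.
from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer

config = OptimizerConfig(
    optimizer='muon',                            # emerging optimizer name; must be registered
    bf16=True,                                   # fp16 is explicitly rejected on this path
    use_layer_wise_distributed_optimizer=True,   # wrap with LayerWiseDistributedOptimizer
)

# The legacy name optimizer='dist_muon' still works, but emits a DeprecationWarning and is
# rewritten to optimizer='muon' plus use_layer_wise_distributed_optimizer=True.
optimizer = get_megatron_optimizer(config, model_chunks)  # model_chunks: List[MegatronModule]
```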
@@ -512,7 +660,10 @@ def get_megatron_optimizer(
 ) -> MegatronOptimizer:
     """Retrieve the Megatron optimizer for model chunks.

     Handles both standard optimizers (Adam, SGD) and emerging optimizers (e.g. Muon).
     We use separate optimizers for expert parameters and non-expert parameters.
     For emerging optimizers with ``config.use_layer_wise_distributed_optimizer=True``,
     the optimizer is automatically wrapped with :class:`LayerWiseDistributedOptimizer`.

     Args:
         config (OptimizerConfig): optimizer configuration object.

@@ -529,10 +680,25 @@
         Instance of MegatronOptimizer.
     """

     log_single_rank(logger, logging.INFO, f'Setting up optimizer with config {config}')
+    # None → apply standard defaults. To extend defaults with custom overrides,
+    # start from get_standard_config_overrides(config) and merge yours in.
+    if config_overrides is None:
+        config_overrides = get_standard_config_overrides(config)
+
     check_config_overrides_consistency(config, config_overrides)
+
+    # TODO: the standard and emerging optimizer paths handle pg_collection differently;
+    # unify them so both use a single pg_collection-based flow.
+    if config.optimizer not in ('adam', 'sgd'):
+        return _get_megatron_emerging_optimizer(
+            config=config,
+            model_chunks=model_chunks,
+            config_overrides=config_overrides,
+            pg_collection=pg_collection,
+        )
+
     log_single_rank(logger, logging.INFO, f'Setting up optimizer with config {config}')

     # Separate out first model chunk if overlapping param AG with optimizer step.
     if config.overlap_param_gather_with_optimizer_step:
         all_dense_model_chunks = [[model_chunks[0]], model_chunks[1:]]

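The override comment above suggests the intended extension pattern; a small sketch, assuming get_megatron_optimizer accepts `config_overrides` as a keyword argument (as the None-handling in its body implies) and leaving the exact ParamKey/override schema to its definition elsewhere in this file:

```python
# Sketch: start from the standard defaults, layer custom per-parameter overrides on top,
# then hand the merged dict to the factory so the None -> defaults branch is skipped.
config_overrides = get_standard_config_overrides(config)
config_overrides.update(my_custom_overrides)  # placeholder: Dict[ParamKey, ...] of your tweaks

optimizer = get_megatron_optimizer(
    config,
    model_chunks,
    config_overrides=config_overrides,
)
```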