192 changes: 179 additions & 13 deletions megatron/core/optimizer/__init__.py
@@ -2,6 +2,7 @@
import copy
import logging
import warnings
from collections import defaultdict
from dataclasses import astuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

@@ -47,14 +48,22 @@
from ..transformer.module import MegatronModule
from ..utils import get_model_config, get_pg_rank, get_pg_size, is_te_min_version, log_single_rank
from .distrib_optimizer import DistributedOptimizer
from .emerging_optimizers import (
_EMERGING_OPTIMIZERS,
HAVE_EMERGING_OPTIMIZERS,
_create_emerging_optimizer,
)
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .layer_wise_optimizer import LayerWiseDistributedOptimizer
from .optimizer import (
ChainedOptimizer,
Float16OptimizerWithFloat16Params,
FP32Optimizer,
MegatronOptimizer,
param_group_identifier_keys,
)

# Subclass aliases kept for backward compatibility; all are OptimizerConfig.
from .optimizer_config import (
AdamOptimizerConfig,
OptimizerConfig,
@@ -134,14 +143,6 @@ def _get_param_groups(
# Map (pg_overrides, is_expert_parallel) to params.
params_map = {}

if config_overrides is None:
# TODO remove this default behavior eventually.
# This is only needed for backwards compatibility with the old config overrides API where
# the config_overrides argument by default lead to bias parameters and length 1 parameters.
# We assume that users of decoupled LR already provide config overrides so will adapt
# to the new API.
config_overrides = get_standard_config_overrides(config=config)

for model_chunk in model_chunks:
for name, param in model_chunk.named_parameters():
if not param.requires_grad:
@@ -276,7 +277,8 @@ def _get_megatron_optimizer_based_on_param_groups(
intra_dist_opt_group: Optional[torch.distributed.ProcessGroup] = None,
distributed_optimizer_instance_id: Optional[int] = 0,
pg_collection: Optional[ProcessGroupCollection] = None,
) -> MegatronOptimizer:
skip_megatron_wrapping: bool = False,
) -> Union[MegatronOptimizer, Tuple[Optional[torch.optim.Optimizer], Optional[Callable]]]:
"""Get Megatron optimizer based on parameter groups.

Args:
@@ -292,12 +294,24 @@
optimizer. Defaults to None.
distributed_optimizer_instance_id (int, optional): Distributed optimizer instance. Defaults
to 0.
skip_megatron_wrapping (bool): if True, return a
``(optimizer, init_state_fn)`` tuple of the raw PyTorch optimizer
without any Megatron wrapping. Useful when the caller
(e.g. LayerWiseDistributedOptimizer) performs its own wrapping.

Returns:
Instance of MegatronOptimizer.
Instance of MegatronOptimizer, or ``(optimizer, init_state_fn)`` when
*skip_megatron_wrapping=True*.
"""
# TODO: Logic needs to be updated to handle different optimizer types (i.e., param_groups
# passed into this function need to correspond to the same optimizer).
# All param_groups passed here must belong to the same optimizer type (adam / sgd).
# Callers are responsible for splitting by optimizer type before calling this function.

if skip_megatron_wrapping and config.use_precision_aware_optimizer:
raise ValueError(
"skip_megatron_wrapping=True is incompatible with use_precision_aware_optimizer."
)
if skip_megatron_wrapping and config.optimizer_cpu_offload:
raise ValueError("skip_megatron_wrapping=True is incompatible with optimizer_cpu_offload.")

# When freezing sub-models we may have no trainable parameters on a rank and
# hence an empty param_groups. However, we still need to create an optimizer
@@ -412,6 +426,9 @@ def init_state_fn(opt, config=None):
optimizer = None
init_state_fn = None

if skip_megatron_wrapping:
return optimizer, init_state_fn

# Mixed precision optimizer.
# - Note: both the Float16Optimizer and the DistributedOptimizer inherit
# from the MixedPrecisionOptimizer, which manages any optimizer where
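To make the new skip_megatron_wrapping contract concrete, here is a minimal caller-side sketch. It is not part of this diff: config, model_chunks, dense_param_groups, and pg_collection are assumed to come from the usual Megatron optimizer setup, and dense_param_groups is an illustrative name.

# Hypothetical caller that wants the raw PyTorch optimizer, e.g. for its own
# layer-wise wrapping. With skip_megatron_wrapping=True the function returns
# (torch.optim.Optimizer, init_state_fn) instead of a MegatronOptimizer; per the
# Optional return annotation it may return (None, None), e.g. when a rank has
# no trainable parameters.
raw_optimizer, init_state_fn = _get_megatron_optimizer_based_on_param_groups(
    config=config,
    model_chunks=model_chunks,
    param_groups=dense_param_groups,
    model_parallel_group=pg_collection.mp,
    pg_collection=pg_collection,
    skip_megatron_wrapping=True,
)
if raw_optimizer is not None:
    # The caller now owns any further wrapping (this diff itself uses
    # FP32Optimizer / Float16OptimizerWithFloat16Params for that step).
    wrapped = FP32Optimizer(raw_optimizer, config, init_state_fn)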
@@ -502,6 +519,137 @@ def check_config_overrides_consistency(
return True


def _get_megatron_emerging_optimizer(
config: OptimizerConfig,
model_chunks: List[MegatronModule],
config_overrides: Optional[Dict[ParamKey, Any]] = None,
pg_collection: Optional[ProcessGroupCollection] = None,
) -> MegatronOptimizer:
"""Build an emerging optimizer (e.g. Muon) for the given model chunks.

Parameter separation (e.g., linear weights -> Muon, rest -> Adam) is expressed as a
config_override, the same mechanism used for weight-decay and learning-rate overrides.
Adam/SGD groups are delegated to _get_megatron_optimizer_based_on_param_groups so they
go through the exact same code path as the standard optimizer factory.

When ``config.use_layer_wise_distributed_optimizer`` is True, the underlying optimizers
are wrapped with :class:`LayerWiseDistributedOptimizer`.
"""
eopt_name = config.optimizer
use_layer_wise = config.use_layer_wise_distributed_optimizer

# Handle legacy "dist_*" optimizer names (e.g. "dist_muon" → "muon" + layer-wise).
if eopt_name.startswith('dist_'):
bare_name = eopt_name[len('dist_') :]
warnings.warn(
f"optimizer='{eopt_name}' is deprecated. "
f"Use optimizer='{bare_name}' with use_layer_wise_distributed_optimizer=True.",
DeprecationWarning,
stacklevel=3,
)
eopt_name = bare_name
use_layer_wise = True

if not HAVE_EMERGING_OPTIMIZERS:
raise ImportError(
f"emerging-optimizers package is required for optimizer='{eopt_name}'. "
"Install it with: pip install emerging-optimizers"
)
if eopt_name not in _EMERGING_OPTIMIZERS:
raise ValueError(f"Unsupported emerging optimizer: {eopt_name}")
if config.fp16:
raise ValueError('emerging optimizer with fp16 is not supported.')

if pg_collection is None:
pg_collection = ProcessGroupCollection.use_mpu_process_groups()

log_single_rank(logger, logging.INFO, f'Setting up emerging optimizer with config {config}')

# Tag parameters with optimizer-specific attributes (expert_tp, is_qkv).
for model_chunk in model_chunks:
for name, param in model_chunk.named_parameters():
if not param.requires_grad:
continue
if 'experts' in name and 'shared' not in name:
param.expert_tp = True
# TODO(deyuf): support MLA
if 'linear_qkv.weight' in name and len(param.shape) == 2:
param.is_qkv = True

# Apply optimizer-specific default param overrides (e.g. muon: non-linear -> adam).
config_overrides.update(_EMERGING_OPTIMIZERS[eopt_name].default_param_overrides)

# Build param groups and bucket by (optimizer_name, is_expert_parallel).
# Layer-wise distributed optimizer handles expert params internally so we skip that split.
all_param_groups = _get_param_groups(model_chunks, config, config_overrides)
grouped_param_groups = defaultdict(list)
for group in all_param_groups:
opt_name = group.get('optimizer', eopt_name)
is_expert = group['is_expert_parallel'] and not use_layer_wise
grouped_param_groups[(opt_name, is_expert)].append(group)

# Build an optimizer for each (optimizer_name, is_expert) bucket and combine.
results = []
for (opt_name, is_expert), groups in grouped_param_groups.items():
if not groups:
continue

model_parallel_group = pg_collection.tp_ep_pp if is_expert else pg_collection.mp

if opt_name in _EMERGING_OPTIMIZERS:
optimizer, init_state_fn = _create_emerging_optimizer(
config, groups, eopt_name, model_chunks, pg_collection
)
if use_layer_wise:
result = (optimizer, init_state_fn)
else:
if config.bf16:
optimizer = Float16OptimizerWithFloat16Params(
optimizer, config, None, init_state_fn
)
else:
optimizer = FP32Optimizer(optimizer, config, init_state_fn)
setattr(optimizer, 'grad_stats_parallel_group', model_parallel_group)
if pg_collection is None or not hasattr(pg_collection, 'tp'):
tp_group = parallel_state.get_tensor_model_parallel_group()
else:
tp_group = pg_collection.tp
setattr(optimizer, 'tp_group', tp_group)
result = optimizer
else:
fallback_config = copy.copy(config)
Review comment (Contributor): "Q: Does this need deepcopy? there could be very heavy structure in config."
fallback_config.optimizer = opt_name
fallback_config.use_distributed_optimizer = False
result = _get_megatron_optimizer_based_on_param_groups(
config=fallback_config,
model_chunks=model_chunks,
param_groups=groups,
model_parallel_group=model_parallel_group,
pg_collection=pg_collection,
skip_megatron_wrapping=use_layer_wise,
)
# TODO(deyuf): ChainedOptimizer currently asserts all sub-optimizers
# share the same config. Revisit this design now that emerging
# optimizers mix different optimizer types (e.g. Muon + Adam).
# For now, reset to the top-level config so the assertion holds.
if not use_layer_wise and hasattr(result, 'config'):
result.config = config
results.append(result)

if use_layer_wise:
base_optimizers, init_fns = (), ()
if results:
base_optimizers, init_fns = zip(*results)
log_single_rank(
logger, logging.INFO, f'Using LayerWiseDistributedOptimizer for {eopt_name}'
)
return LayerWiseDistributedOptimizer(
list(base_optimizers), config, pg_collection, init_state_fn_list=list(init_fns)
)

return ChainedOptimizer(results)
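
For readers skimming the bucketing step above, a minimal self-contained sketch of the same (optimizer_name, is_expert_parallel) grouping follows; the toy dictionaries stand in for the real param groups produced by _get_param_groups and are illustrative only.

from collections import defaultdict

eopt_name = 'muon'
use_layer_wise = False  # the layer-wise optimizer handles expert params internally

# Toy stand-ins for the param groups built by _get_param_groups().
all_param_groups = [
    {'name': 'decoder.linear_qkv.weight', 'optimizer': 'muon', 'is_expert_parallel': False},
    {'name': 'embedding.word_embeddings.weight', 'optimizer': 'adam', 'is_expert_parallel': False},
    {'name': 'experts.linear_fc1.weight', 'optimizer': 'muon', 'is_expert_parallel': True},
]

grouped_param_groups = defaultdict(list)
for group in all_param_groups:
    opt_name = group.get('optimizer', eopt_name)
    is_expert = group['is_expert_parallel'] and not use_layer_wise
    grouped_param_groups[(opt_name, is_expert)].append(group)

# Three buckets result: ('muon', False), ('adam', False), ('muon', True).
# Each bucket becomes one optimizer, later combined via ChainedOptimizer or,
# when use_layer_wise is True, LayerWiseDistributedOptimizer.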


def get_megatron_optimizer(
config: OptimizerConfig,
model_chunks: List[MegatronModule],
@@ -512,7 +660,10 @@ def get_megatron_optimizer(
) -> MegatronOptimizer:
"""Retrieve the Megatron optimizer for model chunks.

Handles both standard optimizers (Adam, SGD) and emerging optimizers (e.g. Muon).
We use separate optimizers for expert parameters and non-expert parameters.
For emerging optimizers with ``config.use_layer_wise_distributed_optimizer=True``,
the optimizer is automatically wrapped with :class:`LayerWiseDistributedOptimizer`.

Args:
config (OptimizerConfig): optimizer configuration object.
Expand All @@ -529,10 +680,25 @@ def get_megatron_optimizer(
Instance of MegatronOptimizer.
"""

log_single_rank(logger, logging.INFO, f'Setting up optimizer with config {config}')
# None → apply standard defaults. To extend defaults with custom overrides,
# start from get_standard_config_overrides(config) and merge yours in.
if config_overrides is None:
config_overrides = get_standard_config_overrides(config)

check_config_overrides_consistency(config, config_overrides)

# TODO: the standard and emerging optimizer paths handle pg_collection differently;
# unify them so both use a single pg_collection-based flow.
if config.optimizer not in ('adam', 'sgd'):
return _get_megatron_emerging_optimizer(
config=config,
model_chunks=model_chunks,
config_overrides=config_overrides,
pg_collection=pg_collection,
)

log_single_rank(logger, logging.INFO, f'Setting up optimizer with config {config}')

# Separate out first model chunk if overlapping param AG with optimizer step.
if config.overlap_param_gather_with_optimizer_step:
all_dense_model_chunks = [[model_chunks[0]], model_chunks[1:]]
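
Finally, a hedged end-to-end usage sketch of the new dispatch in get_megatron_optimizer. It is not part of this diff and assumes that OptimizerConfig exposes the fields shown (with defaults for the rest), that get_standard_config_overrides is importable from megatron.core.optimizer, that the emerging-optimizers package is installed, and that model_chunks is the caller's list of MegatronModule chunks.

from megatron.core.optimizer import (
    OptimizerConfig,
    get_megatron_optimizer,
    get_standard_config_overrides,
)

# Hypothetical field values; any optimizer name other than 'adam' or 'sgd'
# is routed through _get_megatron_emerging_optimizer.
config = OptimizerConfig(
    optimizer='muon',
    lr=3e-4,
    bf16=True,
    use_layer_wise_distributed_optimizer=True,  # wrap with LayerWiseDistributedOptimizer
)

# Extend (rather than replace) the standard per-parameter overrides.
config_overrides = get_standard_config_overrides(config=config)
# config_overrides.update(my_project_overrides)  # optional ParamKey -> override mapping

optimizer = get_megatron_optimizer(
    config=config,
    model_chunks=model_chunks,
    config_overrides=config_overrides,
)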