diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py index 11aa6c49585..8babff5d4f5 100644 --- a/megatron/core/optimizer/__init__.py +++ b/megatron/core/optimizer/__init__.py @@ -2,6 +2,7 @@ import copy import logging import warnings +from collections import defaultdict from dataclasses import astuple from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -47,7 +48,13 @@ from ..transformer.module import MegatronModule from ..utils import get_model_config, get_pg_rank, get_pg_size, is_te_min_version, log_single_rank from .distrib_optimizer import DistributedOptimizer +from .emerging_optimizers import ( + _EMERGING_OPTIMIZERS, + HAVE_EMERGING_OPTIMIZERS, + _create_emerging_optimizer, +) from .grad_scaler import ConstantGradScaler, DynamicGradScaler +from .layer_wise_optimizer import LayerWiseDistributedOptimizer from .optimizer import ( ChainedOptimizer, Float16OptimizerWithFloat16Params, @@ -55,6 +62,8 @@ MegatronOptimizer, param_group_identifier_keys, ) + +# Subclass aliases kept for backward compatibility; all are OptimizerConfig. from .optimizer_config import ( AdamOptimizerConfig, OptimizerConfig, @@ -134,14 +143,6 @@ def _get_param_groups( # Map (pg_overrides, is_expert_parallel) to params. params_map = {} - if config_overrides is None: - # TODO remove this default behavior eventually. - # This is only needed for backwards compatibility with the old config overrides API where - # the config_overrides argument by default lead to bias parameters and length 1 parameters. - # We assume that users of decoupled LR already provide config overrides so will adapt - # to the new API. - config_overrides = get_standard_config_overrides(config=config) - for model_chunk in model_chunks: for name, param in model_chunk.named_parameters(): if not param.requires_grad: @@ -276,7 +277,8 @@ def _get_megatron_optimizer_based_on_param_groups( intra_dist_opt_group: Optional[torch.distributed.ProcessGroup] = None, distributed_optimizer_instance_id: Optional[int] = 0, pg_collection: Optional[ProcessGroupCollection] = None, -) -> MegatronOptimizer: + skip_megatron_wrapping: bool = False, +) -> Union[MegatronOptimizer, Tuple[Optional[torch.optim.Optimizer], Optional[Callable]]]: """Get Megatron optimizer based on parameter groups. Args: @@ -292,12 +294,24 @@ def _get_megatron_optimizer_based_on_param_groups( optimizer. Defaults to None. distributed_optimizer_instance_id (int, optional): Distributed optimizer instance. Defaults 0. + skip_megatron_wrapping (bool): if True, return a + ``(optimizer, init_state_fn)`` tuple of the raw PyTorch optimizer + without any Megatron wrapping. Useful when the caller + (e.g. LayerWiseDistributedOptimizer) performs its own wrapping. Returns: - Instance of MegatronOptimizer. + Instance of MegatronOptimizer, or ``(optimizer, init_state_fn)`` when + *skip_megatron_wrapping=True*. """ - # TODO: Logic needs to be updated to handle different optimizer types (i.e., param_groups - # passed into this function need to correspond to the same optimizer). + # All param_groups passed here must belong to the same optimizer type (adam / sgd). + # Callers are responsible for splitting by optimizer type before calling this function. + + if skip_megatron_wrapping and config.use_precision_aware_optimizer: + raise ValueError( + "skip_megatron_wrapping=True is incompatible with use_precision_aware_optimizer." 
+ ) + if skip_megatron_wrapping and config.optimizer_cpu_offload: + raise ValueError("skip_megatron_wrapping=True is incompatible with optimizer_cpu_offload.") # When freezing sub-models we may have no trainable parameters on a rank and # hence an empty param_groups. However, we still need to create an optimizer @@ -412,6 +426,9 @@ def init_state_fn(opt, config=None): optimizer = None init_state_fn = None + if skip_megatron_wrapping: + return optimizer, init_state_fn + # Mixed precision optimizer. # - Note: both the Float16Optimizer and the DistributedOptimizer inherit # from the MixedPrecisionOptimizer, which manages any optimizer where @@ -502,6 +519,137 @@ def check_config_overrides_consistency( return True +def _get_megatron_emerging_optimizer( + config: OptimizerConfig, + model_chunks: List[MegatronModule], + config_overrides: Optional[Dict[ParamKey, Any]] = None, + pg_collection: Optional[ProcessGroupCollection] = None, +) -> MegatronOptimizer: + """Build an emerging optimizer (e.g. Muon) for the given model chunks. + + Parameter separation (e.g., linear weights -> Muon, rest -> Adam) is expressed as a + config_override, the same mechanism used for weight-decay and learning-rate overrides. + Adam/SGD groups are delegated to _get_megatron_optimizer_based_on_param_groups so they + go through the exact same code path as the standard optimizer factory. + + When ``config.use_layer_wise_distributed_optimizer`` is True, the underlying optimizers + are wrapped with :class:`LayerWiseDistributedOptimizer`. + """ + eopt_name = config.optimizer + use_layer_wise = config.use_layer_wise_distributed_optimizer + + # Handle legacy "dist_*" optimizer names (e.g. "dist_muon" → "muon" + layer-wise). + if eopt_name.startswith('dist_'): + bare_name = eopt_name[len('dist_') :] + warnings.warn( + f"optimizer='{eopt_name}' is deprecated. " + f"Use optimizer='{bare_name}' with use_layer_wise_distributed_optimizer=True.", + DeprecationWarning, + stacklevel=3, + ) + eopt_name = bare_name + use_layer_wise = True + + if not HAVE_EMERGING_OPTIMIZERS: + raise ImportError( + f"emerging-optimizers package is required for optimizer='{eopt_name}'. " + "Install it with: pip install emerging-optimizers" + ) + if eopt_name not in _EMERGING_OPTIMIZERS: + raise ValueError(f"Unsupported emerging optimizer: {eopt_name}") + if config.fp16: + raise ValueError('emerging optimizer with fp16 is not supported.') + + if pg_collection is None: + pg_collection = ProcessGroupCollection.use_mpu_process_groups() + + log_single_rank(logger, logging.INFO, f'Setting up emerging optimizer with config {config}') + + # Tag parameters with optimizer-specific attributes (expert_tp, is_qkv). + for model_chunk in model_chunks: + for name, param in model_chunk.named_parameters(): + if not param.requires_grad: + continue + if 'experts' in name and 'shared' not in name: + param.expert_tp = True + # TODO(deyuf): support MLA + if 'linear_qkv.weight' in name and len(param.shape) == 2: + param.is_qkv = True + + # Apply optimizer-specific default param overrides (e.g. muon: non-linear -> adam). + config_overrides.update(_EMERGING_OPTIMIZERS[eopt_name].default_param_overrides) + + # Build param groups and bucket by (optimizer_name, is_expert_parallel). + # Layer-wise distributed optimizer handles expert params internally so we skip that split. 
+ all_param_groups = _get_param_groups(model_chunks, config, config_overrides) + grouped_param_groups = defaultdict(list) + for group in all_param_groups: + opt_name = group.get('optimizer', eopt_name) + is_expert = group['is_expert_parallel'] and not use_layer_wise + grouped_param_groups[(opt_name, is_expert)].append(group) + + # Build an optimizer for each (optimizer_name, is_expert) bucket and combine. + results = [] + for (opt_name, is_expert), groups in grouped_param_groups.items(): + if not groups: + continue + + model_parallel_group = pg_collection.tp_ep_pp if is_expert else pg_collection.mp + + if opt_name in _EMERGING_OPTIMIZERS: + optimizer, init_state_fn = _create_emerging_optimizer( + config, groups, eopt_name, model_chunks, pg_collection + ) + if use_layer_wise: + result = (optimizer, init_state_fn) + else: + if config.bf16: + optimizer = Float16OptimizerWithFloat16Params( + optimizer, config, None, init_state_fn + ) + else: + optimizer = FP32Optimizer(optimizer, config, init_state_fn) + setattr(optimizer, 'grad_stats_parallel_group', model_parallel_group) + if pg_collection is None or not hasattr(pg_collection, 'tp'): + tp_group = parallel_state.get_tensor_model_parallel_group() + else: + tp_group = pg_collection.tp + setattr(optimizer, 'tp_group', tp_group) + result = optimizer + else: + fallback_config = copy.copy(config) + fallback_config.optimizer = opt_name + fallback_config.use_distributed_optimizer = False + result = _get_megatron_optimizer_based_on_param_groups( + config=fallback_config, + model_chunks=model_chunks, + param_groups=groups, + model_parallel_group=model_parallel_group, + pg_collection=pg_collection, + skip_megatron_wrapping=use_layer_wise, + ) + # TODO(deyuf): ChainedOptimizer currently asserts all sub-optimizers + # share the same config. Revisit this design now that emerging + # optimizers mix different optimizer types (e.g. Muon + Adam). + # For now, reset to the top-level config so the assertion holds. + if not use_layer_wise and hasattr(result, 'config'): + result.config = config + results.append(result) + + if use_layer_wise: + base_optimizers, init_fns = (), () + if results: + base_optimizers, init_fns = zip(*results) + log_single_rank( + logger, logging.INFO, f'Using LayerWiseDistributedOptimizer for {eopt_name}' + ) + return LayerWiseDistributedOptimizer( + list(base_optimizers), config, pg_collection, init_state_fn_list=list(init_fns) + ) + + return ChainedOptimizer(results) + + def get_megatron_optimizer( config: OptimizerConfig, model_chunks: List[MegatronModule], @@ -512,7 +660,10 @@ def get_megatron_optimizer( ) -> MegatronOptimizer: """Retrieve the Megatron optimizer for model chunks. + Handles both standard optimizers (Adam, SGD) and emerging optimizers (e.g. Muon). We use separate optimizers for expert parameters and non-expert parameters. + For emerging optimizers with ``config.use_layer_wise_distributed_optimizer=True``, + the optimizer is automatically wrapped with :class:`LayerWiseDistributedOptimizer`. Args: config (OptimizerConfig): optimizer configuration object. @@ -529,10 +680,25 @@ def get_megatron_optimizer( Instance of MegatronOptimizer. """ - log_single_rank(logger, logging.INFO, f'Setting up optimizer with config {config}') + # None → apply standard defaults. To extend defaults with custom overrides, + # start from get_standard_config_overrides(config) and merge yours in. 
+ if config_overrides is None: + config_overrides = get_standard_config_overrides(config) check_config_overrides_consistency(config, config_overrides) + # TODO: the standard and emerging optimizer paths handle pg_collection differently; + # unify them so both use a single pg_collection-based flow. + if config.optimizer not in ('adam', 'sgd'): + return _get_megatron_emerging_optimizer( + config=config, + model_chunks=model_chunks, + config_overrides=config_overrides, + pg_collection=pg_collection, + ) + + log_single_rank(logger, logging.INFO, f'Setting up optimizer with config {config}') + # Separate out first model chunk if overlapping param AG with optimizer step. if config.overlap_param_gather_with_optimizer_step: all_dense_model_chunks = [[model_chunks[0]], model_chunks[1:]] diff --git a/megatron/core/optimizer/emerging_optimizers.py b/megatron/core/optimizer/emerging_optimizers.py new file mode 100644 index 00000000000..3cf36670fd3 --- /dev/null +++ b/megatron/core/optimizer/emerging_optimizers.py @@ -0,0 +1,260 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +"""Emerging optimizer registry. + +To add a new emerging optimizer: + 1. Define its optimizer class (or import it). + 2. Write its ``__init_state_fn`` and ``__config_to_kwargs``. + 3. Add an ``EmergingOptimizerEntry`` to ``_EMERGING_OPTIMIZERS`` at the bottom. +""" + +import logging +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Literal, Optional + +import torch +from torch.optim.optimizer import ParamsT + +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.utils import get_pg_size, log_single_rank + +from .optimizer_config import ParamKey, ParamPredicate + +try: + from emerging_optimizers.orthogonalized_optimizers import ( + OrthogonalizedOptimizer, + get_muon_scale_factor, + ) + from emerging_optimizers.orthogonalized_optimizers.muon_utils import newton_schulz_tp + + HAVE_EMERGING_OPTIMIZERS = True +except ImportError: + HAVE_EMERGING_OPTIMIZERS = False + + +logger = logging.getLogger(__name__) + + +# =========================================================================== +# Registry dataclass and public API +# =========================================================================== + + +@dataclass +class EmergingOptimizerEntry: + """Everything needed to create and configure an emerging optimizer. + + Attributes: + optimizer_cls: The torch optimizer class. + init_state_fn: Lazily initialises optimizer state (needed for checkpoint formats). + config_to_kwargs: ``(config, model_chunks, pg_collection) -> dict`` of constructor kwargs. + default_param_overrides: Per-parameter config overrides applied automatically + (e.g. route non-linear params to Adam). 
+ """ + + optimizer_cls: type + init_state_fn: Callable + config_to_kwargs: Callable + default_param_overrides: Dict[ParamKey, Dict[str, Any]] = field(default_factory=dict) + + +def _create_emerging_optimizer(config, param_groups, eopt_name, model_chunks, pg_collection): + """Instantiate an emerging optimizer and return it with its init_state_fn.""" + entry = _EMERGING_OPTIMIZERS[eopt_name] + eopt_kwargs = entry.config_to_kwargs(config, model_chunks, pg_collection) + optimizer = entry.optimizer_cls(param_groups, **eopt_kwargs) + return optimizer, entry.init_state_fn + + +# =========================================================================== +# Shared helpers +# =========================================================================== + + +def _is_nonlinear_or_embedding(param): + """True for parameters that should NOT use the emerging optimizer.""" + return getattr(param, 'is_embedding_or_output_parameter', False) or len(param.shape) != 2 + + +def _get_qkv_split_shapes(model_cfg) -> List[int]: + """Compute QKV split shapes from model config.""" + return [ + model_cfg.num_attention_heads // model_cfg.num_query_groups * model_cfg.kv_channels, + model_cfg.kv_channels, + model_cfg.kv_channels, + ] + + +# =========================================================================== +# Registry – populated below only when emerging_optimizers is installed. +# =========================================================================== + +_EMERGING_OPTIMIZERS: Dict[str, EmergingOptimizerEntry] = {} + + +# =========================================================================== +# Muon +# =========================================================================== + +if HAVE_EMERGING_OPTIMIZERS: + + class TensorParallelMuon(OrthogonalizedOptimizer): + """Tensor Parallel Muon optimizer.""" + + def __init__( + self, + params: ParamsT, + lr: float = 3e-4, + momentum_beta: float = 0.95, + use_nesterov: bool = True, + weight_decay: float = 0.01, + use_decoupled_weight_decay: bool = True, + split_qkv: bool = False, + is_qkv_fn: Callable[[torch.Tensor], bool] | None = None, + qkv_split_shapes: tuple[int, int, int] | None = None, + fp32_matmul_prec: str = "medium", + coefficient_type: str = "quintic", + num_ns_steps: int = 5, + scale_mode: str = "spectral", + extra_scale_factor: float = 1.0, + pg_collection: Optional[ProcessGroupCollection] = None, + mode: Literal["blockwise", "duplicated", "distributed"] = "duplicated", + ) -> None: + if num_ns_steps < 1: + raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}") + + def scaled_orthogonalize_fn( + grad: torch.Tensor, + tp_group: torch.distributed.ProcessGroup, + partition_dim: int | None = None, + ) -> torch.Tensor: + log_single_rank( + logger, + logging.DEBUG, + f'Orthogonalizing grad with {num_ns_steps} steps, ' + f'{coefficient_type} coefficient, ' + f'{scale_mode} scale mode, extra_scale_factor={extra_scale_factor}', + ) + size = [grad.size(-2), grad.size(-1)] + if partition_dim is not None: + size[partition_dim] *= get_pg_size(tp_group) + orth_grad = newton_schulz_tp( + grad, + steps=num_ns_steps, + coefficient_type=coefficient_type, + tp_group=tp_group, + partition_dim=partition_dim, + mode="duplicated" if mode == "blockwise" else mode, + ) + scale_factor = get_muon_scale_factor(size[0], size[1], mode=scale_mode) + return orth_grad * scale_factor * extra_scale_factor + + self.pg_collection = pg_collection + self.mode = mode + self.split_qkv = split_qkv + self.is_qkv_fn = is_qkv_fn + self.qkv_split_shapes = qkv_split_shapes + + 
weight_decay_method = "decoupled" if use_decoupled_weight_decay else "l2" + super().__init__( + params, + lr, + momentum_beta, + use_nesterov=use_nesterov, + weight_decay=weight_decay, + weight_decay_method=weight_decay_method, + fp32_matmul_prec=fp32_matmul_prec, + scaled_orthogonalize_fn=scaled_orthogonalize_fn, + ) + + def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor: + """Orthogonalize the momentum. + + Args: + p: The parameter tensor. i is necessary to pass param tensor in addition to + momentum because a lot of information is only available in the param tensor, + attributes for example. + grad: The momentum tensor. + + Returns: + The orthogonalized gradient tensor. + """ + # TODO(deyuf): switch to group + if self.pg_collection: + tp_group = ( + self.pg_collection.expt_tp + if getattr(p, 'expert_tp', False) + else self.pg_collection.tp + ) + else: + tp_group = None + partition_dim = None if self.mode == "blockwise" else getattr(p, "partition_dim", None) + if partition_dim == -1: + partition_dim = None + + if self.split_qkv and self.is_qkv_fn(p): # type: ignore[misc] + grad_shape = grad.shape + log_single_rank( + logger, + logging.DEBUG, + f'qkv split grad shape {grad_shape}, ' f'split shapes {self.qkv_split_shapes}', + ) + num_query_groups = grad_shape[0] // sum(self.qkv_split_shapes) + qkv_grads = torch.split( + grad.view(num_query_groups, sum(self.qkv_split_shapes), -1), + self.qkv_split_shapes, + dim=1, + ) + qkv_grads = [g.reshape(-1, grad_shape[-1]) for g in qkv_grads] + + qkv_grads = [ + self.scaled_orthogonalize_fn(g, tp_group, partition_dim).view( + num_query_groups, -1, grad_shape[-1] + ) + for g in qkv_grads + ] + grad = torch.cat(qkv_grads, dim=1).view(grad_shape) + else: + grad = self.scaled_orthogonalize_fn(grad, tp_group, partition_dim) + return grad + + def _muon_init_state_fn(opt, config=None): + """Initialize Muon optimizer state for torch_dist checkpoint format.""" + for group in opt.param_groups: + for p in group['params']: + if len(opt.state[p]) == 0: + opt.state[p]['momentum_buffer'] = torch.zeros_like(p.data) + + def _muon_config_to_kwargs(config, model_chunks, pg_collection) -> Dict[str, Any]: + """Convert OptimizerConfig to TensorParallelMuon constructor kwargs.""" + return { + "lr": config.lr, + "weight_decay": config.weight_decay, + "momentum_beta": config.muon_momentum, + "use_nesterov": config.muon_use_nesterov, + "fp32_matmul_prec": config.muon_fp32_matmul_prec, + "num_ns_steps": config.muon_num_ns_steps, + "scale_mode": config.muon_scale_mode, + "extra_scale_factor": config.muon_extra_scale_factor, + "mode": config.muon_tp_mode, + "split_qkv": config.muon_split_qkv, + "is_qkv_fn": lambda p: getattr(p, "is_qkv", False), + "qkv_split_shapes": _get_qkv_split_shapes(model_chunks[0].config), + "pg_collection": pg_collection, + } + + # ----------------------------------------------------------------------- + # Register Muon + # ----------------------------------------------------------------------- + _EMERGING_OPTIMIZERS['muon'] = EmergingOptimizerEntry( + optimizer_cls=TensorParallelMuon, + init_state_fn=_muon_init_state_fn, + config_to_kwargs=_muon_config_to_kwargs, + default_param_overrides={ + ParamKey( + predicate=ParamPredicate( + name="nonlinear_or_embedding", fn=_is_nonlinear_or_embedding + ) + ): {'optimizer': 'adam'} + }, + ) diff --git a/megatron/core/optimizer/layer_wise_optimizer.py b/megatron/core/optimizer/layer_wise_optimizer.py index de4396a5b4f..d5dcef209a9 100644 --- 
a/megatron/core/optimizer/layer_wise_optimizer.py +++ b/megatron/core/optimizer/layer_wise_optimizer.py @@ -63,19 +63,17 @@ def __init__( optimizers ), "init_state_fn_list must be the same length as optimizers if provided" - # wrap optimizer after sharding to avoid unnecessary master weight creation - # for higher precision, optimizers are wrapped with megatron already + # Wrap base torch optimizers with Float16 for bf16 training. + # Callers pass base optimizers; wrapping happens here *after* + # shard_params so master weights are only created for the local shard. if config.bf16: - # unwrap FP32 optimizer, possibly from reusing get_megatron_optimizer for adam for i in range(len(optimizers)): opt = optimizers[i] - if isinstance(opt, Float16OptimizerWithFloat16Params): + if isinstance(opt, (Float16OptimizerWithFloat16Params, FP32Optimizer)): raise TypeError( - 'LayerWiseDistributedOptimizer received Float16 optimizer already.' + 'LayerWiseDistributedOptimizer expects base torch optimizers, ' + f'got {type(opt).__name__}. Do not pre-wrap with Megatron optimizers.' ) - # unwrap FP32 optimizer from reusing get_megatron_optimizer for adam - if isinstance(opt, FP32Optimizer): - opt = opt.optimizer optimizers[i] = Float16OptimizerWithFloat16Params( opt, config, None, init_state_fn_list[i] if init_state_fn_list else None ) diff --git a/megatron/core/optimizer/muon.py b/megatron/core/optimizer/muon.py index 57eb1e94478..a3f7506f941 100644 --- a/megatron/core/optimizer/muon.py +++ b/megatron/core/optimizer/muon.py @@ -1,350 +1,16 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -"""Megatron muon optimizer wrapper to handle tensor-parallel.""" +"""Backward-compatible shim — all code now lives in ``emerging_optimizers``.""" -import logging -from typing import Any, Callable, Dict, List, Literal, Optional +from typing import Any -import torch -from torch.optim.optimizer import ParamsT -from megatron.core.optimizer_param_scheduler import ParamGroupOverride -from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.transformer.module import MegatronModule -from megatron.core.utils import get_pg_size, log_single_rank +def get_megatron_muon_optimizer(*args: Any, **kwargs: Any) -> Any: + """Backward compatible muon optimizer getter. -from . 
import _get_param_groups, get_megatron_optimizer -from .layer_wise_optimizer import LayerWiseDistributedOptimizer -from .optimizer import ( - ChainedOptimizer, - Float16OptimizerWithFloat16Params, - FP32Optimizer, - MegatronOptimizer, -) -from .optimizer_config import OptimizerConfig, ParamKey - -try: - from emerging_optimizers.orthogonalized_optimizers import ( - OrthogonalizedOptimizer, - get_muon_scale_factor, - ) - from emerging_optimizers.orthogonalized_optimizers.muon_utils import newton_schulz_tp - - HAVE_EMERGING_OPTIMIZERS = True -except ImportError: - HAVE_EMERGING_OPTIMIZERS = False - OrthogonalizedOptimizer = object - - -logger = logging.getLogger(__name__) - - -class TensorParallelMuon(OrthogonalizedOptimizer): - """Tensor Parallel Muon optimizer.""" - - def __init__( - self, - params: ParamsT, - lr: float = 3e-4, - momentum_beta: float = 0.95, - use_nesterov: bool = True, - weight_decay: float = 0.01, - use_decoupled_weight_decay: bool = True, - split_qkv: bool = False, - is_qkv_fn: Callable[[torch.Tensor], bool] | None = None, - qkv_split_shapes: tuple[int, int, int] | None = None, - fp32_matmul_prec: str = "medium", - coefficient_type: str = "quintic", - num_ns_steps: int = 5, - scale_mode: str = "spectral", - extra_scale_factor: float = 1.0, - pg_collection: Optional[ProcessGroupCollection] = None, - mode: Literal["blockwise", "duplicated", "distributed"] = "duplicated", - ) -> None: - if num_ns_steps < 1: - raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}") - - def scaled_orthogonalize_fn( - grad: torch.Tensor, - tp_group: torch.distributed.ProcessGroup, - partition_dim: int | None = None, - ) -> torch.Tensor: - log_single_rank( - logger, - logging.DEBUG, - f'Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient, ' - f'{scale_mode} scale mode, extra_scale_factor={extra_scale_factor}', - ) - size = [grad.size(-2), grad.size(-1)] - if partition_dim is not None: - size[partition_dim] *= get_pg_size(tp_group) - orth_grad = newton_schulz_tp( - grad, - steps=num_ns_steps, - coefficient_type=coefficient_type, - tp_group=tp_group, - partition_dim=partition_dim, - mode="duplicated" if mode == "blockwise" else mode, - ) - scale_factor = get_muon_scale_factor(size[0], size[1], mode=scale_mode) - return orth_grad * scale_factor * extra_scale_factor - - self.pg_collection = pg_collection - self.mode = mode - self.split_qkv = split_qkv - self.is_qkv_fn = is_qkv_fn - self.qkv_split_shapes = qkv_split_shapes - - weight_decay_method = "decoupled" if use_decoupled_weight_decay else "l2" - super().__init__( - params, - lr, - momentum_beta, - use_nesterov=use_nesterov, - weight_decay=weight_decay, - weight_decay_method=weight_decay_method, - fp32_matmul_prec=fp32_matmul_prec, - scaled_orthogonalize_fn=scaled_orthogonalize_fn, - ) - - def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor: - """Orthogonalize the momentum. - - Args: - p: The parameter tensor. i is necessary to pass param tensor in addition to momentum - because a lot of information is only available in the param tensor, - attributes for example. - grad: The momentum tensor. - - Returns: - The orthogonalized gradient tensor. 
- """ - # TODO(deyuf): switch to group - if self.pg_collection: - tp_group = ( - self.pg_collection.expt_tp - if getattr(p, 'expert_tp', False) - else self.pg_collection.tp - ) - else: - tp_group = None - partition_dim = None if self.mode == "blockwise" else getattr(p, "partition_dim", None) - if partition_dim == -1: - # emerging-optimizers use None instead of -1 to indicate no tensor parallel - partition_dim = None - - if self.split_qkv and self.is_qkv_fn(p): # type: ignore[misc] - # split grouped attention parameters (e.g., QKV, GQA, etc.) - grad_shape = grad.shape - log_single_rank( - logger, - logging.DEBUG, - f'qkv split grad shape {grad_shape}, split shapes {self.qkv_split_shapes}', - ) - num_query_groups = grad_shape[0] // sum(self.qkv_split_shapes) - qkv_grads = torch.split( - grad.view(num_query_groups, sum(self.qkv_split_shapes), -1), - self.qkv_split_shapes, - dim=1, - ) - qkv_grads = [g.reshape(-1, grad_shape[-1]) for g in qkv_grads] - - # Apply Newton-Schulz and scales to each component, concat back - qkv_grads = [ - self.scaled_orthogonalize_fn(g, tp_group, partition_dim).view( - num_query_groups, -1, grad_shape[-1] - ) - for g in qkv_grads - ] - grad = torch.cat(qkv_grads, dim=1).view(grad_shape) - else: - grad = self.scaled_orthogonalize_fn(grad, tp_group, partition_dim) - return grad - - -def get_megatron_muon_optimizer( - config: OptimizerConfig, - model_chunks: List[MegatronModule], - config_overrides: Optional[Dict[ParamKey, ParamGroupOverride]] = None, - use_gloo_process_groups: bool = True, - layer_wise_distributed_optimizer: bool = False, - pg_collection: Optional[ProcessGroupCollection] = None, -) -> MegatronOptimizer: - """This function is used to get the muon optimizer for the model chunks. - It is used to get the muon optimizer for the model chunks. - - Args: - config (OptimizerConfig): optimizer configuration object. - model_chunks (List[MegatronModule]): model chunks to get optimizer for. - use_gloo_process_groups (bool): if false, disable use of Gloo process groups - in underlying Megatron optimizers. - layer_wise_distributed_optimizer (bool): if true, use layer-wise distributed optimizer. - Defaults to False. + .. deprecated:: + Use :func:`megatron.core.optimizer.get_megatron_optimizer` instead. """ - # Muon currently use adam config. setting str here to call regular get for adam creation - # side effect is muon optimizer will have wrong name, i.e. config.optimizer == 'adam' - config.optimizer = 'adam' - - assert HAVE_EMERGING_OPTIMIZERS, "Emerging Optimizers is not installed." 
- - # Dist-opt is not supported due to strong coupling with how DDP init grad buffer - # In theory we can change DDP to enable use muon and dist-opt-adam together - if config.use_distributed_optimizer: - raise Exception('muon with dist optimizer is not supported.') - # only support bf16 w/o loss scale now - if config.fp16: - raise Exception('muon with fp16 is not supported.') - - # before this function receive properly created collection - if pg_collection is None: - pg_collection = ProcessGroupCollection.use_mpu_process_groups() - - log_single_rank(logger, logging.INFO, f'Setting up emerging optimizer with config {config}') - - # Needed for torch_dist ckpt_format, unlike torch ckpt_format - # For other emerging optimizers, need to implement init_state_fn as well - # TODO(boxiangw): Improve usability after optimizer refactor - # TODO(boxiangw): support precision aware optimizer - def muon_init_state_fn(opt, config=None): - for group in opt.param_groups: - for p in group['params']: - if len(opt.state[p]) == 0: - opt.state[p]['momentum_buffer'] = torch.zeros_like(p.data) - - def adam_init_state_fn(opt, config=None): - for group in opt.param_groups: - for p in group['params']: - if len(opt.state[p]) == 0: - if config is None or not config.use_precision_aware_optimizer: - opt.state[p]['exp_avg'] = torch.zeros_like(p.data) - opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data) - else: - opt.initialize_state(p) - - optimizers = [] - # record list of non/linear params - linear_params = [] - nonlinear_params = [] - for model_chunk in model_chunks: - # use config to determine qkv split shapes. - # no need to check tp since tp splits by head and this is per head(group) dimension - num_attention_heads = model_chunk.config.num_attention_heads - num_query_groups = model_chunk.config.num_query_groups - kv_channels = model_chunk.config.kv_channels - qkv_split_shapes = [ - num_attention_heads // num_query_groups * kv_channels, - kv_channels, - kv_channels, - ] - for name, param in model_chunk.named_parameters(): - if not param.requires_grad: - continue - # add flag for expert weight so optimizer can figure which tp group it uses - # alternatively, create new param group and save tp_group. 
this require more - # change in optimizer - if 'experts' in name and 'shared' not in name: - param.expert_tp = True - # add flag for qkv parameter - # TODO(deyuf): support MLA - if 'linear_qkv.weight' in name and len(param.shape) == 2: - param.is_qkv = True - # TODO(deyuf): currently only allow 2D non-embedding weight to avoid breaking - if ( - not getattr(param, 'is_embedding_or_output_parameter', False) - and len(param.shape) == 2 - ): - linear_params.append(param) - else: - nonlinear_params.append(param) - - muon_kwargs = { - "lr": config.lr, - "momentum_beta": config.muon_momentum, - "use_nesterov": config.muon_use_nesterov, - "weight_decay": config.weight_decay, - "fp32_matmul_prec": config.muon_fp32_matmul_prec, - "num_ns_steps": config.muon_num_ns_steps, - "scale_mode": config.muon_scale_mode, - "split_qkv": config.muon_split_qkv, - "is_qkv_fn": lambda p: getattr(p, "is_qkv", False), - "qkv_split_shapes": qkv_split_shapes, - "extra_scale_factor": config.muon_extra_scale_factor, - "pg_collection": pg_collection, - "mode": config.muon_tp_mode, - } - - # freezing nonlinear params and get param groups for muon - for param in nonlinear_params: - param.requires_grad = False - - linear_param_groups = _get_param_groups(model_chunks, config, config_overrides) - # if layerwise distributed optimizer is not used, need to handle ep params separately - expert_param_groups = [] - if not layer_wise_distributed_optimizer: - for group in linear_param_groups: - if group['is_expert_parallel']: - expert_param_groups.append(group) - linear_param_groups.remove(group) - - optimizer = TensorParallelMuon(linear_param_groups, **muon_kwargs) - - reset_config_bf16 = False - if config.bf16: - if layer_wise_distributed_optimizer: - # creating master weight before layerwise sharding will lead to unnecessary master - # weight so here we delay master weight creation into layer_wise unset config.bf16 - # will also result in all optimizers below(adam) to also not be wrapped - config.bf16 = False - reset_config_bf16 = True - else: - # if not using layer_wise wrapper, just create master weight here is fine - optimizer = Float16OptimizerWithFloat16Params( - optimizer, config, None, muon_init_state_fn - ) - else: - optimizer = FP32Optimizer(optimizer, config, muon_init_state_fn) - - optimizers.append(optimizer) - - # expert optimizer exists meaning layerwise distributed optimizer is not used - if len(expert_param_groups) > 0: - expert_optimizer = TensorParallelMuon(expert_param_groups, **muon_kwargs) - if config.bf16: - expert_optimizer = Float16OptimizerWithFloat16Params( - expert_optimizer, config, None, muon_init_state_fn - ) - else: - expert_optimizer = FP32Optimizer(expert_optimizer, config, muon_init_state_fn) - setattr(expert_optimizer, 'grad_stats_parallel_group', pg_collection.tp_ep_pp) - optimizers.append(expert_optimizer) - - # done with muon, unfreeze nonlinear and freeze linear - for param in nonlinear_params: - param.requires_grad = True - for param in linear_params: - param.requires_grad = False - - # call original get. linear params will be skipped since they're freezed - chained_adam = get_megatron_optimizer( - config, - model_chunks, - config_overrides=config_overrides, - use_gloo_process_groups=use_gloo_process_groups, - ) - - # unfreeze everything - for param in linear_params: - param.requires_grad = True - - # chain everything together - init_fns = [muon_init_state_fn] + len(chained_adam.chained_optimizers) * [adam_init_state_fn] - optimizers += chained_adam.chained_optimizers + from . 
import get_megatron_optimizer - if layer_wise_distributed_optimizer: - log_single_rank(logger, logging.INFO, 'Using LayerWiseDistributedOptimizer for Muon') - if reset_config_bf16: - config.bf16 = True - return LayerWiseDistributedOptimizer( - optimizers, config, pg_collection, init_state_fn_list=init_fns - ) - return ChainedOptimizer(optimizers) + return get_megatron_optimizer(*args, **kwargs) diff --git a/megatron/core/optimizer/optimizer_config.py b/megatron/core/optimizer/optimizer_config.py index 94163102eb3..4b43e7b5c08 100644 --- a/megatron/core/optimizer/optimizer_config.py +++ b/megatron/core/optimizer/optimizer_config.py @@ -206,7 +206,8 @@ class OptimizerConfig: """dtype of exp_avg_sq when enabling precision-aware-optimizer""" optimizer: str = 'adam' - """Optimizer name. NOTE: Deprecated, use individual optimizer classes instead.""" + """Optimizer name (e.g., 'adam', 'sgd', 'muon'). Can be overridden per-parameter group + via config_overrides to use different optimizers for different parameters.""" ############### # Loss scaling @@ -229,7 +230,7 @@ class OptimizerConfig: """Hysteresis for dynamic loss scaling.""" ################################################################################### - # Optimizer (NOTE: Deprecated, use individual optimizer classes instead.). + # Optimizer-specific parameters. ################################################################################### # Adam. adam_beta1: float = 0.9 @@ -254,10 +255,9 @@ class OptimizerConfig: sgd_momentum: float = 0.9 """Momentum factor for SGD optimizer.""" - # Muon. - # TODO: move muon configs to it's own `MuonConfig`. + # Muon / emerging optimizers. muon_momentum: float = 0.95 - """The momentum used by the internal SGD.""" + """The momentum used by the internal SGD in Muon optimizer.""" muon_split_qkv: bool = True """Whether to split QKV parameters for Muon optimizer.""" @@ -286,6 +286,12 @@ class OptimizerConfig: use_distributed_optimizer: bool = False """Distribute optimizer state over data-parallel replicas.""" + use_layer_wise_distributed_optimizer: bool = False + """Use :class:`LayerWiseDistributedOptimizer` for emerging optimizers (e.g. Muon). + When set via ``--use-distributed-optimizer`` with an emerging optimizer, the training + arguments layer sets this flag and resets ``use_distributed_optimizer`` to False so + that the standard distributed-optimizer path is not triggered.""" + overlap_param_gather: bool = False """If true, overlap param all-gather with forward compute. This argument is intended to have the same value as the "overlap_param_gather" argument @@ -431,33 +437,6 @@ def __post_init__(self): ), "exp_avg_sq_dtype can only be fp32 when not using precision-aware optimizer" -@dataclass -class AdamOptimizerConfig(OptimizerConfig): - """Adam optimizer configuration object.""" - - optimizer: str = 'adam' - """Optimizer name.""" - - adam_beta1: float = 0.9 - """First coefficient for computing running averages of gradient and its square in Adam - optimizer. - """ - - adam_beta2: float = 0.999 - """Second coefficient for computing running averages of gradient and its square in Adam - optimizer. 
- """ - - adam_eps: float = 1e-08 - """Term added to the denominator to improve numerical stability in Adam optimizer.""" - - -@dataclass -class SGDOptimizerConfig(OptimizerConfig): - """SGD optimizer configuration object.""" - - optimizer: str = 'sgd' - """Optimizer name.""" - - sgd_momentum: float = 0.9 - """Momentum factor for SGD optimizer.""" +# Backward-compatible aliases (deprecated; use OptimizerConfig directly). +AdamOptimizerConfig = OptimizerConfig +SGDOptimizerConfig = OptimizerConfig diff --git a/megatron/core/optimizer_param_scheduler.py b/megatron/core/optimizer_param_scheduler.py index e01a708ce79..91ed362b1b2 100644 --- a/megatron/core/optimizer_param_scheduler.py +++ b/megatron/core/optimizer_param_scheduler.py @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) -class ParamGroupOverride(TypedDict): +class ParamGroupOverride(TypedDict, total=False): """Override values for a parameter group. These values may be optimizer-state/scheduler related. These are the values you see later in param_group.get(...) calls in the @@ -23,7 +23,7 @@ class ParamGroupOverride(TypedDict): Example: >>> param_group_override = ParamGroupOverride(min_lr=1e-4, wd_mult=0.1) - >>> param_group_override == ParamGroupOverride(newvar=3) # this is ok too + >>> param_group_override == ParamGroupOverride(optimizer='muon') # per-param optimizer """ @@ -32,6 +32,7 @@ class ParamGroupOverride(TypedDict): start_wd: float end_wd: float wd_mult: float + optimizer: str def get_canonical_lr_for_logging(param_groups: list[dict]) -> float | None: diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 5d5fa34b6c5..cab8a04c59f 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1325,12 +1325,23 @@ def validate_args(args, defaults={}): args.no_load_optim = True warn_rank_0('enabling --no-load-optim when skipping training.') - # Muon optimizer check - if 'muon' in args.optimizer: + # Muon / emerging optimizer check + if args.optimizer in ('muon', 'dist_muon'): + if args.optimizer == 'dist_muon': + warn_rank_0( + "optimizer='dist_muon' is deprecated. " + "Use --optimizer muon --use-distributed-optimizer instead." + ) + args.optimizer = 'muon' + args.use_layer_wise_distributed_optimizer = True + + if args.use_distributed_optimizer: + args.use_layer_wise_distributed_optimizer = True + args.use_distributed_optimizer = False + # TODO: remove these checks once we support them assert not args.overlap_grad_reduce, "Muon optimizer does not support overlap grad reduce for now." assert not args.overlap_param_gather, "Muon optimizer does not support overlap param gather for now." - assert not args.use_distributed_optimizer, "Muon optimizer does not support distributed optimizer for now." assert not args.use_torch_fsdp2, "Muon optimizer does not support Torch-FSDP2 for now." assert not args.use_megatron_fsdp, "Muon optimizer does not support Megatron-FSDP for now." assert args.ckpt_format in ["torch", "torch_dist"], "Muon optimizer supports torch and torch_dist checkpoint format." @@ -2248,7 +2259,9 @@ def _add_training_args(parser): 'https://arxiv.org/abs/2205.14135') group.add_argument('--optimizer', type=str, default='adam', choices=['adam', 'sgd', 'muon', 'dist_muon'], - help='Optimizer function') + help='Optimizer function. 
' + 'Note: dist_muon is deprecated; use --optimizer muon ' + 'with --use-distributed-optimizer instead.') group.add_argument('--optimizer-cpu-offload', action='store_true', help='Offload optimizer state to CPU') group.add_argument('--optimizer-offload-fraction', type=float, default=1.0, diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py index a64d0cd318c..d9204f9007d 100644 --- a/megatron/training/checkpointing.py +++ b/megatron/training/checkpointing.py @@ -563,7 +563,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati optimizer.save_parameter_state(optim_checkpoint_name) # LayerWiseDistributedOptimizer save optimizer state to file on different ranks - if getattr(args, "optimizer", "adam").startswith("dist_") and args.ckpt_format == 'torch': + if getattr(args, "use_layer_wise_distributed_optimizer", False) and args.ckpt_format == 'torch': dp_rank = mpu.get_data_parallel_rank() optim_checkpoint_name = os.path.join(os.path.dirname(checkpoint_name), f"layer_wise_optimizer_{dp_rank}.pt") ensure_directory_exists(optim_checkpoint_name) @@ -1809,7 +1809,7 @@ def load_model_state_dict(module, state_dict, strict: bool): if not release and not args.finetune and not args.no_load_optim: try: # Load state dict. - if getattr(args, "optimizer", "adam").startswith("dist_") and args.ckpt_format == 'torch': + if getattr(args, "use_layer_wise_distributed_optimizer", False) and args.ckpt_format == 'torch': # LayerWiseDistributedOptimizer load optimizer state from file on different ranks dp_rank = mpu.get_data_parallel_rank() optim_checkpoint_name = os.path.join(os.path.dirname(checkpoint_name), f"layer_wise_optimizer_{dp_rank}.pt") diff --git a/megatron/training/training.py b/megatron/training/training.py index 0c33206ba8b..6d7a548b947 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -126,8 +126,11 @@ def set_startup_timestamps(program_start=None, main_entry=None): from megatron.core.distributed import finalize_model_grads from megatron.core.enums import ModelType -from megatron.core.optimizer import get_megatron_optimizer, AdamOptimizerConfig, SGDOptimizerConfig, OptimizerConfig, ParamKey -from megatron.core.optimizer.muon import get_megatron_muon_optimizer +from megatron.core.optimizer import ( + get_megatron_optimizer, + OptimizerConfig, + ParamKey, +) from megatron.core.rerun_state_machine import ( get_rerun_state_machine, destroy_rerun_state_machine, @@ -1478,23 +1481,11 @@ def get_optimizer_param_scheduler(optimizer): def get_megatron_optimizer_config(args: Any) -> OptimizerConfig: """Return a Megatron optimizer config object from Megatron's arguments.""" - config = None - if args.optimizer == 'adam' or 'muon' in args.optimizer: - # TODO(deyuf): Muon needs both adam + muon but get() only receive one config - # So for now we keep using adam config that's back compat with old way - kwargs = {} - for f in dataclasses.fields(AdamOptimizerConfig): - if hasattr(args, f.name): - kwargs[f.name] = getattr(args, f.name) - config = AdamOptimizerConfig(**kwargs) - elif args.optimizer == 'sgd': - kwargs = {} - for f in dataclasses.fields(SGDOptimizerConfig): - if hasattr(args, f.name): - kwargs[f.name] = getattr(args, f.name) - config = SGDOptimizerConfig(**kwargs) - else: - raise ValueError("Invalid optimizer type!") + kwargs = {} + for f in dataclasses.fields(OptimizerConfig): + if hasattr(args, f.name): + kwargs[f.name] = getattr(args, f.name) + config = OptimizerConfig(**kwargs) # Construct the appropriate 
config_overrides object. This default handles many cases, but # can be added to as needed by the user, or replaced entirely with a custom override. @@ -1524,25 +1515,13 @@ def setup_model_and_optimizer( config, config_overrides = get_megatron_optimizer_config(args) config.timers = timers - if 'muon' not in config.optimizer: - # If the user is asking for a non-zero embedding init std, skip weight decay for embeddings - # to avoid embeddings from shrinking to zero as recommended in https://arxiv.org/abs/2312.16903 - # default_skip_embedding_weight_decay=args.embedding_init_method_std is not None, - optimizer = get_megatron_optimizer( - config, - model, - config_overrides=config_overrides, - use_gloo_process_groups=args.enable_gloo_process_groups, - dump_param_to_param_group_map=args.dump_param_to_param_group_map, - ) - else: - optimizer = get_megatron_muon_optimizer( - config, - model, - config_overrides=config_overrides, - use_gloo_process_groups=args.enable_gloo_process_groups, - layer_wise_distributed_optimizer='dist' in config.optimizer, - ) + optimizer = get_megatron_optimizer( + config, + model, + config_overrides=config_overrides, + use_gloo_process_groups=args.enable_gloo_process_groups, + dump_param_to_param_group_map=args.dump_param_to_param_group_map, + ) opt_param_scheduler = get_optimizer_param_scheduler(optimizer) one_logger and one_logger.log_metrics({"app_build_optimzer_finish_time": one_logger_utils.get_timestamp_in_ms()}) diff --git a/tests/unit_tests/dist_checkpointing/utils.py b/tests/unit_tests/dist_checkpointing/utils.py index dd12ecd7684..cf6662c72bf 100644 --- a/tests/unit_tests/dist_checkpointing/utils.py +++ b/tests/unit_tests/dist_checkpointing/utils.py @@ -12,7 +12,7 @@ get_gpt_layer_with_transformer_engine_spec, ) from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer -from megatron.core.optimizer.muon import get_megatron_muon_optimizer +from megatron.core.optimizer.optimizer import ChainedOptimizer from megatron.core.tensor_parallel import model_parallel_cuda_manual_seed from megatron.core.transformer import TransformerConfig from megatron.training.arguments import parse_args @@ -172,11 +172,6 @@ def init_checkpointing_mock_args(args, ckpt_dir, fully_parallel=False): def setup_model_and_optimizer( seed, tp, pp, initialize_fn=initialize_gpt_model, bf16=True, dist_opt=True, optimizer='adam' ): - if 'muon' in optimizer and dist_opt: - raise ValueError( - "Layer-wise distributed optimizer with Muon is not supported with distributed optimizer." - ) - mock_args = parse_args(ignore_unknown_args=True) with mock.patch('megatron.training.training.get_args', new=lambda: mock_args): init_basic_mock_args(mock_args, tp, pp, bf16=bf16) @@ -191,37 +186,39 @@ def setup_model_and_optimizer( ) ) + optimizer_type = optimizer + use_layer_wise = False + if optimizer_type == 'dist_muon': + optimizer = 'muon' + use_layer_wise = True + if optimizer_type in ('muon', 'dist_muon') and dist_opt: + use_layer_wise = True + dist_opt = False + config = OptimizerConfig( bf16=bf16, params_dtype=torch.bfloat16 if bf16 else torch.float, use_distributed_optimizer=dist_opt, + use_layer_wise_distributed_optimizer=use_layer_wise, optimizer=optimizer, ) - if 'muon' in optimizer: - # Use layer-wise distributed optimizer with Muon - optimizer_type = optimizer - # default lr None feels wrong. 
only change muon lr to avoid breaking old tests + if optimizer_type in ('muon', 'dist_muon'): config.lr = 0.0 - optimizer = get_megatron_muon_optimizer( - config, model, layer_wise_distributed_optimizer='dist' in optimizer_type - ) - else: - optimizer_type = optimizer - optimizer = get_megatron_optimizer(config, model) + optimizer = get_megatron_optimizer(config, model) torch.manual_seed(seed + 1) model_parallel_cuda_manual_seed(seed + 1) - if not 'muon' in optimizer_type: + if isinstance(optimizer, ChainedOptimizer): + for opt in optimizer.chained_optimizers: + opt.init_state_fn(opt) + else: for group in optimizer.optimizer.param_groups: for p in group['params']: if len(optimizer.optimizer.state[p]) == 0: optimizer.optimizer.state[p]['exp_avg'] = torch.rand_like(p.data) optimizer.optimizer.state[p]['exp_avg_sq'] = torch.rand_like(p.data) - else: - for opt in optimizer.chained_optimizers: - opt.init_state_fn(opt) optimizer.reload_model_params() @@ -266,10 +263,6 @@ def setup_moe_model_and_optimizer( use_glu=False, optimizer='adam', ): - if 'muon' in optimizer and dist_opt: - raise ValueError( - "Layer-wise distributed optimizer with Muon is not supported with distributed optimizer." - ) mock_args = parse_args(ignore_unknown_args=True) with mock.patch('megatron.training.training.get_args', new=lambda: mock_args): init_basic_mock_args(mock_args, tp, pp, bf16=bf16) @@ -289,37 +282,40 @@ def setup_moe_model_and_optimizer( ) ) + optimizer_type = optimizer + use_layer_wise = False + if optimizer_type == 'dist_muon': + optimizer = 'muon' + use_layer_wise = True + if optimizer_type in ('muon', 'dist_muon') and dist_opt: + use_layer_wise = True + dist_opt = False + config = OptimizerConfig( bf16=bf16, params_dtype=torch.bfloat16 if bf16 else torch.float, use_distributed_optimizer=dist_opt, + use_layer_wise_distributed_optimizer=use_layer_wise, optimizer=optimizer, ) - if 'muon' in optimizer: - optimizer_type = optimizer - # default lr None feels wrong. only change muon lr to avoid breaking old tests + if optimizer_type in ('muon', 'dist_muon'): config.lr = 0.0 - optimizer = get_megatron_muon_optimizer( - config, model, layer_wise_distributed_optimizer='dist' in optimizer_type - ) - else: - optimizer_type = optimizer - optimizer = get_megatron_optimizer(config, model) + optimizer = get_megatron_optimizer(config, model) torch.manual_seed(seed + 1) model_parallel_cuda_manual_seed(seed + 1) - if not 'muon' in optimizer_type: + if optimizer_type in ('muon', 'dist_muon'): + for opt in optimizer.chained_optimizers: + opt.init_state_fn(opt) + else: for opt in optimizer.chained_optimizers: for group in opt.param_groups: for p in group['params']: if len(opt.state[p]) == 0: opt.state[p]['exp_avg'] = torch.rand_like(p.data) opt.state[p]['exp_avg_sq'] = torch.rand_like(p.data) - else: - for opt in optimizer.chained_optimizers: - opt.init_state_fn(opt) optimizer.reload_model_params() diff --git a/tests/unit_tests/test_layer_wise_optimizer.py b/tests/unit_tests/test_layer_wise_optimizer.py index 05ce26bcfa0..9b404b388b4 100644 --- a/tests/unit_tests/test_layer_wise_optimizer.py +++ b/tests/unit_tests/test_layer_wise_optimizer.py @@ -124,9 +124,11 @@ def create_model_and_optimizer( optimizer = get_megatron_optimizer(optimizer_config, [model]) if use_layer_wise: + # Extract base torch optimizers from the FP32Optimizer wrappers. 
+ base_optimizers = [opt.optimizer for opt in optimizer.chained_optimizers] optimizer_config.bf16 = True optimizer = LayerWiseDistributedOptimizer( - optimizer.chained_optimizers, optimizer_config, pg_collection + base_optimizers, optimizer_config, pg_collection ) return model, optimizer, pg_collection @@ -281,19 +283,16 @@ def test_multiple_optimizers(self): param_groups_1 = [{'params': params[:mid_point]}] param_groups_2 = [{'params': params[mid_point:]}] - # Create two separate base optimizers + # Create two separate plain base optimizers (LayerWise wraps them itself) base_optimizer_1 = torch.optim.Adam(param_groups_1, lr=optimizer_config.lr) base_optimizer_2 = torch.optim.Adam(param_groups_2, lr=optimizer_config.lr) - wrapped_optimizer_1 = FP32Optimizer(base_optimizer_1, optimizer_config, None) - wrapped_optimizer_2 = FP32Optimizer(base_optimizer_2, optimizer_config, None) - pg_collection = ProcessGroupCollection.use_mpu_process_groups() pg_collection.dp_cp = parallel_state.get_data_parallel_group(with_context_parallel=True) pg_collection.expt_dp = parallel_state.get_expert_data_parallel_group() optimizer = LayerWiseDistributedOptimizer( - [wrapped_optimizer_1, wrapped_optimizer_2], optimizer_config, pg_collection + [base_optimizer_1, base_optimizer_2], optimizer_config, pg_collection ) assert len(optimizer.chained_optimizers) == 2, "Should have two chained optimizers" @@ -347,9 +346,9 @@ def test_bf16_error(self): pg_collection.dp_cp = parallel_state.get_data_parallel_group(with_context_parallel=True) pg_collection.expt_dp = parallel_state.get_expert_data_parallel_group() - # Should raise TypeError when receiving already-wrapped Float16 optimizer + # Should raise TypeError when receiving already-wrapped optimizer with pytest.raises( - TypeError, match='LayerWiseDistributedOptimizer received Float16 optimizer already' + TypeError, match='LayerWiseDistributedOptimizer expects base torch optimizers' ): LayerWiseDistributedOptimizer([wrapped_optimizer], optimizer_config, pg_collection) diff --git a/tests/unit_tests/test_muon_optimizer.py b/tests/unit_tests/test_muon_optimizer.py index cc99f7a16e6..86d75ee7a49 100644 --- a/tests/unit_tests/test_muon_optimizer.py +++ b/tests/unit_tests/test_muon_optimizer.py @@ -10,8 +10,8 @@ from megatron.core import parallel_state from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig -from megatron.core.optimizer import OptimizerConfig -from megatron.core.optimizer.muon import TensorParallelMuon, get_megatron_muon_optimizer +from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer +from megatron.core.optimizer.emerging_optimizers import TensorParallelMuon from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer import TransformerConfig from tests.unit_tests.test_utilities import Utils @@ -129,8 +129,8 @@ def create_ddp_model(self, model): TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, model ) - def test_get_megatron_muon_optimizer_smoke(self): - """Smoke test for get_megatron_muon_optimizer function.""" + def test_get_megatron_optimizer_smoke(self): + """Smoke test for get_megatron_optimizer function.""" model = Net().bfloat16().cuda() model.requires_grad_(True) model = self.create_ddp_model(model) @@ -155,11 +155,8 @@ def test_get_megatron_muon_optimizer_smoke(self): ) # Test creating the optimizer - optimizer = get_megatron_muon_optimizer( - config=optimizer_config, - model_chunks=[model], - use_gloo_process_groups=True, - 
layer_wise_distributed_optimizer=False, + optimizer = get_megatron_optimizer( + config=optimizer_config, model_chunks=[model], use_gloo_process_groups=True ) # Test basic properties @@ -204,24 +201,13 @@ def test_get_megatron_muon_optimizer_smoke(self): # Load state dict should not raise error optimizer.load_state_dict(state_dict) - def test_get_megatron_muon_optimizer_validation(self): - """Test validation logic for get_megatron_muon_optimizer.""" + def test_get_megatron_optimizer_validation(self): + """Test validation logic for get_megatron_optimizer.""" model = torch.nn.Linear(100, 50, bias=False, dtype=torch.bfloat16, device='cuda') model.requires_grad_(True) model = self.create_ddp_model(model) - # Test 1: Distributed optimizer should raise exception - optimizer_config_dist = OptimizerConfig( - optimizer='muon', - lr=0.01, - bf16=True, - use_distributed_optimizer=True, # This should cause an exception - ) - - with pytest.raises(Exception, match='muon with dist optimizer is not supported'): - get_megatron_muon_optimizer(config=optimizer_config_dist, model_chunks=[model]) - - # Test 2: FP16 should raise exception + # Test 1: FP16 should raise exception optimizer_config_fp16 = OptimizerConfig( optimizer='muon', lr=0.01, @@ -229,8 +215,8 @@ def test_get_megatron_muon_optimizer_validation(self): use_distributed_optimizer=False, ) - with pytest.raises(Exception, match='muon with fp16 is not supported'): - get_megatron_muon_optimizer(config=optimizer_config_fp16, model_chunks=[model]) + with pytest.raises(Exception, match='emerging optimizer with fp16 is not supported'): + get_megatron_optimizer(config=optimizer_config_fp16, model_chunks=[model]) # Test 3: Invalid num_ns_steps should raise exception optimizer_config_invalid_ns = OptimizerConfig( @@ -242,10 +228,10 @@ def test_get_megatron_muon_optimizer_validation(self): ) with pytest.raises(ValueError, match='num_ns_steps must be at least 1'): - get_megatron_muon_optimizer(config=optimizer_config_invalid_ns, model_chunks=[model]) + get_megatron_optimizer(config=optimizer_config_invalid_ns, model_chunks=[model]) - def test_get_megatron_muon_optimizer_layer_wise(self): - """Test get_megatron_muon_optimizer with layer-wise distributed optimizer.""" + def test_get_megatron_optimizer_layer_wise(self): + """Test get_megatron_optimizer with layer-wise distributed optimizer.""" model = Net().bfloat16().cuda() model.requires_grad_(True) model = self.create_ddp_model(model) @@ -255,7 +241,7 @@ def test_get_megatron_muon_optimizer_layer_wise(self): lr=0.01, weight_decay=0.01, bf16=True, - use_distributed_optimizer=False, + use_layer_wise_distributed_optimizer=True, muon_momentum=0.95, muon_use_nesterov=True, muon_fp32_matmul_prec="medium", @@ -264,12 +250,9 @@ def test_get_megatron_muon_optimizer_layer_wise(self): muon_tp_mode="duplicated", ) - # Test with layer_wise_distributed_optimizer=True - optimizer = get_megatron_muon_optimizer( - config=optimizer_config, - model_chunks=[model], - use_gloo_process_groups=True, - layer_wise_distributed_optimizer=True, + # use_layer_wise_distributed_optimizer=True triggers LayerWiseDistributedOptimizer + optimizer = get_megatron_optimizer( + config=optimizer_config, model_chunks=[model], use_gloo_process_groups=True ) # Verify it's a LayerWiseDistributedOptimizer diff --git a/tests/unit_tests/test_optimizer.py b/tests/unit_tests/test_optimizer.py index 2488900ba72..56af8545042 100644 --- a/tests/unit_tests/test_optimizer.py +++ b/tests/unit_tests/test_optimizer.py @@ -106,10 +106,10 @@ def 
test_get_param_groups_no_overrides(mock_get_world_size): def test_get_param_groups_default_overrides(mock_get_world_size): """Test that the default overrides are applied to the parameter groups.""" net = Net() - # NOTE: to get legacy default overrides, supply None. opt_config = OptimizerConfig(optimizer='adam', lr=0.01) - check_config_overrides_consistency(opt_config, None) - param_groups = _get_param_groups([net], opt_config, None) + config_overrides = get_standard_config_overrides(opt_config) + check_config_overrides_consistency(opt_config, config_overrides) + param_groups = _get_param_groups([net], opt_config, config_overrides) assert len(param_groups) == 2 pg0, pg1 = param_groups wd_mults = {pg0['wd_mult'], pg1['wd_mult']}
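
For reviewers, a minimal sketch of the call pattern this change converges on. It is illustrative only: `model_chunks` stands in for the DDP-wrapped Megatron model chunks produced by the usual training setup, and the hyperparameter values are placeholders.

import torch

from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer
from megatron.core.optimizer.layer_wise_optimizer import LayerWiseDistributedOptimizer

config = OptimizerConfig(
    optimizer='muon',                            # emerging optimizers dispatch on the name
    lr=3e-4,                                     # placeholder values
    weight_decay=0.01,
    bf16=True,
    params_dtype=torch.bfloat16,
    use_layer_wise_distributed_optimizer=True,   # replaces the old optimizer='dist_muon'
)

# model_chunks: list of DDP-wrapped MegatronModule chunks from the training setup (assumed).
optimizer = get_megatron_optimizer(config, model_chunks)

# With the layer-wise flag set, the result is the layer-wise wrapper; without it,
# a ChainedOptimizer of Muon + Adam sub-optimizers is returned.
assert isinstance(optimizer, LayerWiseDistributedOptimizer)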
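
The per-parameter `optimizer` key added to ParamGroupOverride can also be used directly. A hedged sketch, assuming `get_standard_config_overrides` and `ParamKey` are importable from `megatron.core.optimizer` as in the unit tests above; the predicate name and routing rule are invented for illustration.

import torch

from megatron.core.optimizer import (
    OptimizerConfig,
    ParamKey,
    get_megatron_optimizer,
    get_standard_config_overrides,
)
from megatron.core.optimizer.optimizer_config import ParamPredicate

config = OptimizerConfig(optimizer='muon', lr=3e-4, bf16=True, params_dtype=torch.bfloat16)

# Start from the standard defaults, then explicitly route embedding/output parameters to Adam
# (the Muon registry entry already routes all non-2D params there via default_param_overrides).
overrides = get_standard_config_overrides(config)
overrides[ParamKey(
    predicate=ParamPredicate(
        name='embedding_to_adam',
        fn=lambda p: getattr(p, 'is_embedding_or_output_parameter', False),
    )
)] = {'optimizer': 'adam'}

optimizer = get_megatron_optimizer(config, model_chunks, config_overrides=overrides)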
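
Finally, the recipe in the emerging_optimizers.py module docstring amounts to adding an entry like the one below at the bottom of that file (step 3). Everything named `my_*` is hypothetical; a real entry would use an actual optimizer class and its own config fields, and `EmergingOptimizerEntry` / `_EMERGING_OPTIMIZERS` come from the surrounding module.

import torch

class MyOrthoOptimizer(torch.optim.SGD):
    """Placeholder optimizer class; a real entry would define or import its own."""

def _my_init_state_fn(opt, config=None):
    # Lazily create state buffers so the torch_dist checkpoint format can shard them.
    for group in opt.param_groups:
        for p in group['params']:
            if len(opt.state[p]) == 0:
                opt.state[p]['momentum_buffer'] = torch.zeros_like(p.data)

def _my_config_to_kwargs(config, model_chunks, pg_collection):
    # Map OptimizerConfig fields onto the constructor kwargs of MyOrthoOptimizer.
    return {'lr': config.lr, 'momentum': 0.9}

_EMERGING_OPTIMIZERS['my_ortho'] = EmergingOptimizerEntry(
    optimizer_cls=MyOrthoOptimizer,
    init_state_fn=_my_init_state_fn,
    config_to_kwargs=_my_config_to_kwargs,
    default_param_overrides={},  # e.g. route some params to 'adam', as the Muon entry does
)
# get_megatron_optimizer(config, ...) with config.optimizer == 'my_ortho' then dispatches here,
# provided the emerging-optimizers package import (HAVE_EMERGING_OPTIMIZERS) succeeded.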