Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 42 additions & 31 deletions megatron/core/optimizer/clip_grads.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,22 +53,24 @@ def get_grad_norm_fp32(
norm_type: Union[int, float] = 2,
grad_stats_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
) -> float:
"""Calculate the norm of gradients in fp32.
"""Calculate the p-norm of gradients in FP32 precision.

This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
added functionality to handle model parallel parameters.
This function is adapted from `torch.nn.utils.clip_grad.clip_grad_norm_`
and extends it with functionality to handle model-parallel parameters.
It ensures that the norm is correctly computed and reduced across
the specified process group (typically the model-parallel group for
non-distributed optimizers or the entire world for distributed optimizers).

Arguments:
grads_for_norm (Iterable[Tensor] or Tensor): an iterable of Tensors or a single
Tensor that will be used for calculating the grad norm.
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
grad_stats_parallel_group (group): Process group for reducing the grad norms. This is
generally the model-parallel group for non-distributed optimizers, and the entire
world for the distributed optimizer.
Args:
grads_for_norm (Union[List[torch.Tensor], torch.Tensor]): An iterable
of Tensors or a single Tensor used to calculate the gradient norm.
norm_type (Union[int, float]): The type of the p-norm to use. Can be
'inf' for infinity norm. Defaults to 2.
grad_stats_parallel_group (ProcessGroup, optional): The process group
used for reducing gradient statistics (e.g., norms and zero counts).

Returns:
Total norm of the parameters (viewed as a single vector).
float: The total norm of the gradients, treated as a single vector.
"""

if isinstance(grads_for_norm, torch.Tensor):
Expand Down Expand Up @@ -141,17 +143,19 @@ def clip_grad_by_total_norm_fp32(
total_norm: float,
use_decoupled_grad: bool = False,
):
"""Clips gradient of an iterable of parameters in fp32 by total norm.
"""Clips the gradients of an iterable of parameters in FP32 by total norm.

Note that the gradients are modified in place.
Note that the gradients are modified in-place.

Args:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized.
max_norm (float or int): max norm of the gradients.
total_norm (float): total norm of the gradients.
use_decoupled_grad (bool, optional): whether to read grad from ".grad" or ".decoupled_grad",
default value is False.
parameters (Union[List[torch.Tensor], torch.Tensor]): An iterable of
Tensors or a single Tensor that will have gradients normalized.
max_norm (Union[int, float]): The maximum permissible total norm
of the gradients.
total_norm (float): The current total norm of the gradients.
use_decoupled_grad (bool, optional): Whether to read from the
'.decoupled_grad' attribute instead of the standard '.grad'.
Defaults to False.
"""
# Grads.
params = []
Expand Down Expand Up @@ -183,18 +187,25 @@ def count_zeros_fp32(
use_decoupled_grad: bool = False,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
) -> float:
"""Counts the number of zeros in gradients associated with the passed-in list of
parameters.
"""Counts the number of zero values in the gradients of the given parameters.

The count is performed in FP32. This method filters parameters to ensure
gradients are not double-counted by checking if the gradient is not None,
the parameter is not shared, and the parameter is not a replica due
to tensor model parallelism. It also handles parameters managed by
Megatron FSDP specifically.

Args:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have the number of zeros in its corresponding
gradient counted.
grad_stats_parallel_group (group): Process group for reducing the num_zeros count. This is
generally the model-parallel group for non-distributed optimizers, and the entire
world for the distributed optimizer.
use_decoupled_grad (bool, optional) whether to read grad from ".grad" or ".decoupled_grad",
default value is False.
parameters (Union[List[torch.Tensor], torch.Tensor]): An iterable of
Tensors or a single Tensor whose gradients will be checked for zeros.
grad_stats_parallel_group (ProcessGroup): The process group used for
reducing the zero count across distributed ranks.
use_decoupled_grad (bool, optional): If True, reads from the
'.decoupled_grad' attribute instead of the standard '.grad'.
Defaults to False.
tp_group (ProcessGroup, optional): Tensor-model-parallel process group,
presumably used to identify tensor-parallel replicas so their
gradients are not double-counted — verify against the implementation.

Returns:
float: The total number of zeros in the gradients across the process group.
"""

if isinstance(parameters, torch.Tensor):
Expand Down Expand Up @@ -245,4 +256,4 @@ def count_zeros_fp32(

total_num_zeros = total_num_zeros.item()

return total_num_zeros
return total_num_zeros
95 changes: 54 additions & 41 deletions megatron/core/optimizer/distrib_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,15 @@ def __len__(self):


class DistributedOptimizer(MixedPrecisionOptimizer):
"""Distributed optimizer, for all data types (fp16, bf16, and fp32).
"""Optimizer that shards state across data-parallel ranks.

See __init__() below for argument details.
This class reduces memory usage by distributing optimizer states (like
momentum and variance buffers) across GPUs in the data-parallel group.

Attributes:
model_chunks (List[MegatronModule]): Model segments being optimized.
per_model_buffers (Dict): Buffers managing contiguous params/grads.
data_parallel_group (ProcessGroup): Group for sharding and all-gathers.
"""

# enumerates fully reshardable optimizer formats (as opposed to formats
Expand All @@ -115,8 +121,7 @@ def _build_model_gbuf_param_range_map(
gbuf_world_range: Range,
bucket_offset: int,
):
"""
Build mapping from param reference to grad buffer shard ranges.
"""Build mapping from param reference to grad buffer shard ranges.

This method builds a mapping from parameter references to grad
buffer shard ranges, specific to each data-parallel (DP) rank's
Expand All @@ -135,10 +140,15 @@ def _build_model_gbuf_param_range_map(
main & main-to-model operations.

This method creates four ranges:
- The param's range within the entire grad buffer (i.e., world index).
- The param's range within the relevant grad bucket's buffer.
- The param's range within the DP rank's local view of the grad buffer.
- The param's range within itself (i.e., its shard).
- gbuf_world: The param's range within the entire grad buffer (world index).
- gbuf_world_in_bucket: The param's range within the relevant grad bucket's buffer.
- gbuf_local: The param's range within the DP rank's local view of the grad buffer.
- param: The param's range within itself (i.e., its shard).

Args:
param_world_index_map (Dict): Mapping from parameter to its world indexes.
gbuf_world_range (Range): The range of the grad buffer owned by this rank.
bucket_offset (int): The offset of the current bucket within the grad buffer.
"""

# Param range map.
Expand Down Expand Up @@ -219,16 +229,18 @@ def _build_model_gbuf_range(cls, param_and_grad_buffer: _ParamAndGradBuffer, buc

@classmethod
def _build_gbuf_range_map(cls, param_and_grad_buffer: _ParamAndGradBuffer):
"""
Build mapping between params and their grad buffers. These mappings are
partitioned according to data type.
"""Builds a map between parameters and their ranges in the grad buffer.

Iterate through all buckets of grad buffer to construct param ranges
that this rank "owns" (the dp_rank'th shard of each bucket, where each
shard is 1/dp_world_size of the bucket).
These mappings are partitioned according to data type. This method
iterates through all buckets of a grad buffer to construct param
ranges that this rank "owns" (the dp_rank'th shard of each bucket,
where each shard is 1/dp_world_size of the bucket).

Args:
param_and_grad_buffer (_ParamAndGradBuffer): buffer to build mapping for.
param_and_grad_buffer (_ParamAndGradBuffer): The buffer to map.

Returns:
Dict: Mapping of parameter dtypes to bucket ranges.
"""
return {
(param_and_grad_buffer.param_dtype, param_and_grad_buffer.grad_dtype): [
Expand Down Expand Up @@ -469,36 +481,36 @@ def __init__(
data_parallel_group_idx: int,
distributed_optimizer_instance_id: int,
):
"""
Distributed optimizer, for all data types (fp16, bf16, and fp32).
"""Initializes the distributed optimizer for FP16, BF16, and FP32.

The steps in this method create the core mapping between param and grad buffers,
parameters, and parameter shard ranges, that is needed for converting between model
param indexes and main parameter shard indexes. This method also updates the optimizer
parameter groups with the newly created shards.
The steps in this method create the core mapping between param and grad
buffers, parameters, and parameter shard ranges, that is needed for
converting between model param indexes and main parameter shard indexes.
This method also updates the optimizer parameter groups with the
newly created shards.

Args:
optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD.
config (OptimizerConfig): configuration object for optimizer.
grad_scaler (MegatronGradScaler): used for scaling gradients. Note that
this can be None. This case happens when `bf16 = True` and we don't
use any loss scale. Note that for `bf16 = True`, we can have
a constant gradient scaler. Also for `bf16 = False`, we
always require a grad scaler.
init_state_fn (Callable, optional): function to initialize state in the optimizer.
model_chunks (List[MegatronModule]): list of model chunks.
per_model_buffers (Dict[int, List[_ParamAndGradBuffer]]): the implementation of the
distributed optimizer is centered on using a contiguous buffer for
communicating grads & params between the model state and the optimizer state.
You can find a more detailed description in
https://github.com/NVIDIA/Megatron-LM/blob/main/docs/source/distrib_optimizer.md.
data_parallel_group (torch.distributed.ProcessGroup): data-parallel group to use to
optimizer (torch.optim.Optimizer): Base optimizer such as Adam or SGD.
config (OptimizerConfig): Configuration object for the optimizer.
grad_scaler (MegatronGradScaler): Used for scaling gradients. Note that
this can be None when `bf16 = True` and no loss scale is used; for
`bf16 = True` a constant gradient scaler may also be used. For
`bf16 = False` (i.e., FP16), a grad scaler is always required.
init_state_fn (Callable, optional): Function to initialize state in
the optimizer.
model_chunks (List[MegatronModule]): List of model chunks to optimize.
per_model_buffers (Dict[int, List[_ParamAndGradBuffer]]): The
implementation of the distributed optimizer is centered on using
a contiguous buffer for communicating grads & params between
the model state and the optimizer state. For a detailed
description, see `docs/source/distrib_optimizer.md`.
data_parallel_group (ProcessGroup): Data-parallel group used to
all-gather params after optimizer.step().
data_parallel_group_gloo (torch.distributed.ProcessGroup): gloo data-parallel group
(used in checkpoint loading and saving).
data_parallel_group_idx (int): index in data-parallel group (used by
distributed checkpointing logic).
distributed_optimizer_instance_id (int): index of the Distributed Optimizer instance.
data_parallel_group_gloo (ProcessGroup, optional): Gloo data-parallel
group used specifically for checkpoint loading and saving.
data_parallel_group_idx (int): Index in the data-parallel group
used by distributed checkpointing logic.
distributed_optimizer_instance_id (int): Unique identifier for the
distributed optimizer instance.
"""

if has_config_logger_enabled(config):
Expand Down Expand Up @@ -2633,3 +2645,4 @@ def step_with_ready_grads(self) -> bool:
timers('params-all-gather').stop()

return update_successful

35 changes: 29 additions & 6 deletions megatron/core/optimizer/grad_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@


class MegatronGradScaler(ABC):
"""Abstract base class for gradient scalers.

Args:
initial_scale (float): The initial value for the loss scale.
"""
def __init__(self, initial_scale: float):
"""Initialize scale value with the input initial scale."""
assert initial_scale > 0.0
Expand Down Expand Up @@ -37,7 +42,10 @@ def load_state_dict(self, state_dict: Dict):

class ConstantGradScaler(MegatronGradScaler):
"""
Constant grad scaler (loss scale is never adjusted regardless of NaNs seen in gradients).
Grad scaler with a fixed scale factor.

The loss scale is never adjusted, regardless of whether NaNs or Infs
are detected in the gradients.
"""

def update(self, found_inf: bool):
Expand All @@ -51,11 +59,26 @@ def load_state_dict(self, state_dict):


class DynamicGradScaler(MegatronGradScaler):
"""
Grad scaler with dynamic scale that gets adjusted during training.

Reduces loss scale by `backoff_factor` if `hysteresis` number of NaNs are seen in a row. Increases
loss scale by `growth_factor` if NaNs are not seen for `growth_interval` iterations.
"""Gradient scaler with a dynamic scale factor adjusted during training.

This class implements a loss scaling strategy to prevent numerical underflow
during mixed-precision training. It reduces the loss scale by a
`backoff_factor` if a `hysteresis` number of NaNs/Infs are detected in
consecutive iterations. Conversely, it increases the loss scale by a
`growth_factor` if no non-finite gradients are seen for a specified
`growth_interval` of iterations.

Args:
initial_scale (float): The starting value for the loss scale.
min_scale (float): The lower bound for the loss scale.
growth_factor (float): The multiplier used to increase the scale when
gradients are stable. Must be greater than 1.0.
backoff_factor (float): The multiplier used to decrease the scale when
non-finite gradients are detected. Must be between 0.0 and 1.0.
growth_interval (int): The number of consecutive stable iterations
required before increasing the scale.
hysteresis (int): The number of consecutive non-finite iterations
required before decreasing the scale.
"""

def __init__(
Expand Down
25 changes: 15 additions & 10 deletions megatron/core/optimizer/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,13 @@ class MegatronOptimizer(ABC):
"""
Base class for all Megatron optimizers.

Provides a consistent interface for gradient management, parameter
access, and state-dict handling across different optimization types.

Args:
optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD.
config (OptimizerConfig): configuration object for optimizer.
init_state_fn (Callable, optional): function to initialize state in the optimizer.
optimizer (torch.optim.Optimizer): The base PyTorch optimizer.
config (OptimizerConfig): The optimizer configuration.
init_state_fn (Callable, optional): Function to initialize optimizer state.
"""

def __init__(
Expand Down Expand Up @@ -135,13 +138,15 @@ def get_parameters(self) -> List[torch.nn.Parameter]:
return params

def get_main_grads_for_grad_norm(self) -> List[torch.Tensor]:
"""
Get main_grads that should be taken into account to compute the grad norm.
Filter parameters based on:
- grad should not be None.
- parameter should not be shared (i.e., grads shouldn't be double counted while
computing norms).
- should not be a replica due to tensor model parallelism.

"""Collects gradients for norm calculation, filtering duplicates.

This method filters parameters based on whether the gradient is not None,
the parameter is not shared (to avoid double-counting gradients), and
the parameter is not a replica due to tensor model parallelism.

Returns:
List[torch.Tensor]: A list of gradient tensors filtered for norm calculation.
"""
params = self.get_parameters()
grads_for_norm = []
Expand Down
37 changes: 36 additions & 1 deletion megatron/core/optimizer/optimizer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,42 @@ def matches(self, param: torch.nn.Parameter, param_name: str) -> bool:

@dataclass
class OptimizerConfig:
"""Base optimizer configuration object."""
"""Configuration object for Megatron optimizers.

Attributes:
lr (float, optional): Initial learning rate. Defaults to None.
min_lr (float, optional): Minimum learning rate for scheduler clipping.
weight_decay (float): L2 regularization coefficient. Defaults to 0.01.

fp8_recipe (str, optional): Type of FP8 recipe affecting distributed
optimizer logic.
fp16 (bool): If True, use FP16 mixed precision. Defaults to False.
bf16 (bool): If True, use BF16 mixed precision. Defaults to False.
params_dtype (torch.dtype): Dtype used for weight initialization.
use_precision_aware_optimizer (bool): Allows lower precision for
master params and states.

loss_scale (float, optional): Static loss scaling factor.
initial_loss_scale (float): Initial scale for dynamic scaling.
min_loss_scale (float): Minimum scale for dynamic scaling.
loss_scale_window (float): Window size for dynamic scale adjustments.
hysteresis (int): Delay iterations for dynamic scale reduction.

adam_beta1 (float): Adam beta1 coefficient. Defaults to 0.9.
adam_beta2 (float): Adam beta2 coefficient. Defaults to 0.999.
adam_eps (float): Adam epsilon for numerical stability.
decoupled_weight_decay (bool): Whether to use AdamW-style decay.

use_distributed_optimizer (bool): If True, shard state across DP ranks.
overlap_param_gather (bool): Overlap param all-gather with forward compute.

optimizer_cpu_offload (bool): If True, offload state and compute to CPU.
optimizer_offload_fraction (float): Fraction of the optimizer state to offload.

clip_grad (float): Global L2 norm threshold for gradient clipping.
log_num_zeros_in_grad (bool): Whether to log count of zero gradients.
timers (Callable, optional): Timer utility function.
"""

##############
# General
Expand Down
Loading