diff --git a/megatron/core/optimizer/clip_grads.py b/megatron/core/optimizer/clip_grads.py index cb2f23a685f..d1e35ad0670 100644 --- a/megatron/core/optimizer/clip_grads.py +++ b/megatron/core/optimizer/clip_grads.py @@ -53,22 +53,24 @@ def get_grad_norm_fp32( norm_type: Union[int, float] = 2, grad_stats_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> float: - """Calculate the norm of gradients in fp32. + """Calculate the p-norm of gradients in FP32 precision. - This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and - added functionality to handle model parallel parameters. + This function is adapted from `torch.nn.utils.clip_grad.clip_grad_norm_` + and extends it with functionality to handle model-parallel parameters. + It ensures that the norm is correctly computed and reduced across + the specified process group (typically the model-parallel group for + non-distributed optimizers or the entire world for distributed optimizers). - Arguments: - grads_for_norm (Iterable[Tensor] or Tensor): an iterable of Tensors or a single - Tensor that will be used for calculating the grad norm. - norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for - infinity norm. - grad_stats_parallel_group (group): Process group for reducing the grad norms. This is - generally the model-parallel group for non-distributed optimizers, and the entire - world for the distributed optimizer. + Args: + grads_for_norm (Union[List[torch.Tensor], torch.Tensor]): An iterable + of Tensors or a single Tensor used to calculate the gradient norm. + norm_type (Union[int, float]): The type of the p-norm to use. Can be + 'inf' for infinity norm. Defaults to 2. + grad_stats_parallel_group (ProcessGroup, optional): The process group + used for reducing gradient statistics (e.g., norms and zero counts). Returns: - Total norm of the parameters (viewed as a single vector). + float: The total norm of the parameters, treated as a single vector. 
""" if isinstance(grads_for_norm, torch.Tensor): @@ -141,17 +143,19 @@ def clip_grad_by_total_norm_fp32( total_norm: float, use_decoupled_grad: bool = False, ): - """Clips gradient of an iterable of parameters in fp32 by total norm. + """Clips the gradients of an iterable of parameters in FP32 by total norm. - Note that the gradients are modified in place. + Note that the gradients are modified in-place. Args: - parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a - single Tensor that will have gradients normalized. - max_norm (float or int): max norm of the gradients. - total_norm (float): total norm of the gradients. - use_decoupled_grad (bool, optional): whether to read grad from ".grad" or ".decoupled_grad", - default value is False. + parameters (Union[List[torch.Tensor], torch.Tensor]): An iterable of + Tensors or a single Tensor that will have gradients normalized. + max_norm (Union[int, float]): The maximum permissible total norm + of the gradients. + total_norm (float): The current total norm of the gradients. + use_decoupled_grad (bool, optional): Whether to read from the + '.decoupled_grad' attribute instead of the standard '.grad'. + Defaults to False. """ # Grads. params = [] @@ -183,18 +187,25 @@ def count_zeros_fp32( use_decoupled_grad: bool = False, tp_group: Optional[torch.distributed.ProcessGroup] = None, ) -> float: - """Counts the number of zeros in gradients associated with the passed-in list of - parameters. + """Counts the number of zero values in the gradients of the given parameters. + + The count is performed in FP32. This method filters parameters to ensure + gradients are not double-counted by checking if the gradient is not None, + the parameter is not shared, and the parameter is not a replica due + to tensor model parallelism. It also handles parameters managed by + Megatron FSDP specifically. 
Args: -        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a -            single Tensor that will have the number of zeros in its corresponding -            gradient counted. -        grad_stats_parallel_group (group): Process group for reducing the num_zeros count. This is -            generally the model-parallel group for non-distributed optimizers, and the entire -            world for the distributed optimizer. -        use_decoupled_grad (bool, optional) whether to read grad from ".grad" or ".decoupled_grad", -            default value is False. +        parameters (Union[List[torch.Tensor], torch.Tensor]): An iterable of +            Tensors or a single Tensor whose gradients will be checked for zeros. +        grad_stats_parallel_group (ProcessGroup): The process group used for +            reducing the zero count across distributed ranks. +        use_decoupled_grad (bool, optional): If True, reads from the +            '.decoupled_grad' attribute instead of the standard '.grad'. +            Defaults to False. + +    Returns: +        float: The total number of zeros in the gradients across the process group. """ if isinstance(parameters, torch.Tensor): @@ -245,4 +256,4 @@ def count_zeros_fp32( total_num_zeros = total_num_zeros.item() -    return total_num_zeros +    return total_num_zeros diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py index eeda383a75d..a18206ed591 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -95,9 +95,15 @@ class DistributedOptimizer(MixedPrecisionOptimizer): -    """Distributed optimizer, for all data types (fp16, bf16, and fp32). +    """Optimizer that shards state across data-parallel ranks. -    See __init__() below for argument details. +    This class reduces memory usage by distributing optimizer states (like +    momentum and variance buffers) across GPUs in the data-parallel group. + +    Attributes: +        model_chunks (List[MegatronModule]): Model segments being optimized. 
+ per_model_buffers (Dict): Buffers managing contiguous params/grads. + data_parallel_group (ProcessGroup): Group for sharding and all-gathers. """ # enumerates fully reshardable optimizer formats (as opposed to formats @@ -115,8 +121,7 @@ def _build_model_gbuf_param_range_map( gbuf_world_range: Range, bucket_offset: int, ): - """ - Build mapping from param reference to grad buffer shard ranges. + """Build mapping from param reference to grad buffer shard ranges. This method builds a mapping from parameter references to grad buffer shard ranges, specific to each data-parallel (DP) rank's @@ -135,10 +140,15 @@ def _build_model_gbuf_param_range_map( main & main-to-model operations. This method creates four ranges: - - The param's range within the entire grad buffer (i.e., world index). - - The param's range within the relevant grad bucket's buffer. - - The param's range within the DP rank's local view of the grad buffer. - - The param's range within itself (i.e., its shard). + - gbuf_world: The param's range within the entire grad buffer (world index). + - gbuf_world_in_bucket: The param's range within the relevant grad bucket's buffer. + - gbuf_local: The param's range within the DP rank's local view of the grad buffer. + - param: The param's range within itself (i.e., its shard). + + Args: + param_world_index_map (Dict): Mapping from parameter to its world indexes. + gbuf_world_range (Range): The range of the grad buffer owned by this rank. + bucket_offset (int): The offset of the current bucket within the grad buffer. """ # Param range map. @@ -219,16 +229,18 @@ def _build_model_gbuf_range(cls, param_and_grad_buffer: _ParamAndGradBuffer, buc @classmethod def _build_gbuf_range_map(cls, param_and_grad_buffer: _ParamAndGradBuffer): - """ - Build mapping between params and their grad buffers. These mappings are - partitioned according to data type. + """Builds a map between parameters and their ranges in the grad buffer. 
- Iterate through all buckets of grad buffer to construct param ranges - that this rank "owns" (the dp_rank'th shard of each bucket, where each - shard is 1/dp_world_size of the bucket). + These mappings are partitioned according to data type. This method + iterates through all buckets of a grad buffer to construct param + ranges that this rank "owns" (the dp_rank'th shard of each bucket, + where each shard is 1/dp_world_size of the bucket). Args: - param_and_grad_buffer (_ParamAndGradBuffer): buffer to build mapping for. + param_and_grad_buffer (_ParamAndGradBuffer): The buffer to map. + + Returns: + Dict: Mapping of parameter dtypes to bucket ranges. """ return { (param_and_grad_buffer.param_dtype, param_and_grad_buffer.grad_dtype): [ @@ -469,36 +481,36 @@ def __init__( data_parallel_group_idx: int, distributed_optimizer_instance_id: int, ): - """ - Distributed optimizer, for all data types (fp16, bf16, and fp32). + """Initializes the distributed optimizer for FP16, BF16, and FP32. - The steps in this method create the core mapping between param and grad buffers, - parameters, and parameter shard ranges, that is needed for converting between model - param indexes and main parameter shard indexes. This method also updates the optimizer - parameter groups with the newly created shards. + The steps in this method create the core mapping between param and grad + buffers, parameters, and parameter shard ranges, that is needed for + converting between model param indexes and main parameter shard indexes. + This method also updates the optimizer parameter groups with the + newly created shards. Args: - optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD. - config (OptimizerConfig): configuration object for optimizer. - grad_scaler (MegatronGradScaler): used for scaling gradients. Note that - this can be None. This case happens when `bf16 = True` and we don't - use any loss scale. Note that for `bf16 = True`, we can have - a constant gradient scaler. 
Also for `bf16 = False`, we - always require a grad scaler. - init_state_fn (Callable, optional): function to initialize state in the optimizer. - model_chunks (List[MegatronModule]): list of model chunks. - per_model_buffers (Dict[int, List[_ParamAndGradBuffer]]): the implementation of the - distributed optimizer is centered on using a contiguous buffer for - communicating grads & params between the model state and the optimizer state. - You can find a more detailed description in - https://github.com/NVIDIA/Megatron-LM/blob/main/docs/source/distrib_optimizer.md. - data_parallel_group (torch.distributed.ProcessGroup): data-parallel group to use to + optimizer (torch.optim.Optimizer): Base optimizer such as Adam or SGD. + config (OptimizerConfig): Configuration object for the optimizer. + grad_scaler (MegatronGradScaler): Used for scaling gradients. Note that + this can be None for BF16 training if no loss scale is used. + For FP16, a grad scaler is always required. + init_state_fn (Callable, optional): Function to initialize state in + the optimizer. + model_chunks (List[MegatronModule]): List of model chunks to optimize. + per_model_buffers (Dict[int, List[_ParamAndGradBuffer]]): The + implementation of the distributed optimizer is centered on using + a contiguous buffer for communicating grads & params between + the model state and the optimizer state. For a detailed + description, see `docs/source/distrib_optimizer.md`. + data_parallel_group (ProcessGroup): Data-parallel group used to all-gather params after optimizer.step(). - data_parallel_group_gloo (torch.distributed.ProcessGroup): gloo data-parallel group - (used in checkpoint loading and saving). - data_parallel_group_idx (int): index in data-parallel group (used by - distributed checkpointing logic). - distributed_optimizer_instance_id (int): index of the Distributed Optimizer instance. 
+ data_parallel_group_gloo (ProcessGroup, optional): Gloo data-parallel + group used specifically for checkpoint loading and saving. + data_parallel_group_idx (int): Index in the data-parallel group + used by distributed checkpointing logic. + distributed_optimizer_instance_id (int): Unique identifier for the + distributed optimizer instance. """ if has_config_logger_enabled(config): @@ -2633,3 +2645,4 @@ def step_with_ready_grads(self) -> bool: timers('params-all-gather').stop() return update_successful + diff --git a/megatron/core/optimizer/grad_scaler.py b/megatron/core/optimizer/grad_scaler.py index abdd1e7b606..03656488982 100644 --- a/megatron/core/optimizer/grad_scaler.py +++ b/megatron/core/optimizer/grad_scaler.py @@ -9,6 +9,11 @@ class MegatronGradScaler(ABC): + """Abstract base class for gradient scalers. + + Args: + initial_scale (float): The initial value for the loss scale. + """ def __init__(self, initial_scale: float): """Initialize scale value with the input initial scale.""" assert initial_scale > 0.0 @@ -37,7 +42,10 @@ def load_state_dict(self, state_dict: Dict): class ConstantGradScaler(MegatronGradScaler): """ - Constant grad scaler (loss scale is never adjusted regardless of NaNs seen in gradients). + Grad scaler with a fixed scale factor. + + The loss scale is never adjusted, regardless of whether NaNs or Infs + are detected in the gradients. """ def update(self, found_inf: bool): @@ -51,11 +59,26 @@ def load_state_dict(self, state_dict): class DynamicGradScaler(MegatronGradScaler): - """ - Grad scaler with dynamic scale that gets adjusted during training. - - Reduces loss scale by `backoff_factor` if `hysteresis` number of NaNs are seen in a row. Increases - loss scale by `growth_factor` if NaNs are not seen for `growth_interval` iterations. + """Gradient scaler with a dynamic scale factor adjusted during training. + + This class implements a loss scaling strategy to prevent numerical underflow + during mixed-precision training. 
It reduces the loss scale by a + `backoff_factor` if a `hysteresis` number of NaNs/Infs are detected in + consecutive iterations. Conversely, it increases the loss scale by a + `growth_factor` if no non-finite gradients are seen for a specified + `growth_interval` of iterations. + + Args: + initial_scale (float): The starting value for the loss scale. + min_scale (float): The lower bound for the loss scale. + growth_factor (float): The multiplier used to increase the scale when + gradients are stable. Must be greater than 1.0. + backoff_factor (float): The multiplier used to decrease the scale when + non-finite gradients are detected. Must be between 0.0 and 1.0. + growth_interval (int): The number of consecutive stable iterations + required before increasing the scale. + hysteresis (int): The number of consecutive non-finite iterations + required before decreasing the scale. """ def __init__( diff --git a/megatron/core/optimizer/optimizer.py b/megatron/core/optimizer/optimizer.py index df8ec8ef613..9f3174ca77d 100644 --- a/megatron/core/optimizer/optimizer.py +++ b/megatron/core/optimizer/optimizer.py @@ -101,10 +101,13 @@ class MegatronOptimizer(ABC): """ Base class for all Megatron optimizers. + Provides a consistent interface for gradient management, parameter + access, and state-dict handling across different optimization types. + Args: - optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD. - config (OptimizerConfig): configuration object for optimizer. - init_state_fn (Callable, optional): function to initialize state in the optimizer. + optimizer (torch.optim.Optimizer): The base PyTorch optimizer. + config (OptimizerConfig): The optimizer configuration. + init_state_fn (Callable, optional): Function to initialize optimizer state. 
""" def __init__( @@ -135,13 +138,15 @@ def get_parameters(self) -> List[torch.nn.Parameter]: return params def get_main_grads_for_grad_norm(self) -> List[torch.Tensor]: - """ - Get main_grads that should be taken into account to compute the grad norm. - Filter parameters based on: - - grad should not be None. - - parameter should not be shared (i.e., grads shouldn't be double counted while - computing norms). - - should not be a replica due to tensor model parallelism. + + """Collects gradients for norm calculation, filtering duplicates. + + This method filters parameters based on whether the gradient is not None, + the parameter is not shared (to avoid double-counting gradients), and + the parameter is not a replica due to tensor model parallelism. + + Returns: + List[torch.Tensor]: A list of gradient tensors filtered for norm calculation. """ params = self.get_parameters() grads_for_norm = [] diff --git a/megatron/core/optimizer/optimizer_config.py b/megatron/core/optimizer/optimizer_config.py index 2d3e3ca08e0..33175589adf 100644 --- a/megatron/core/optimizer/optimizer_config.py +++ b/megatron/core/optimizer/optimizer_config.py @@ -137,7 +137,42 @@ def matches(self, param: torch.nn.Parameter, param_name: str) -> bool: @dataclass class OptimizerConfig: - """Base optimizer configuration object.""" + """Configuration object for Megatron optimizers. + + Attributes: + lr (float, optional): Initial learning rate. Defaults to None. + min_lr (float, optional): Minimum learning rate for scheduler clipping. + weight_decay (float): L2 regularization coefficient. Defaults to 0.01. + + fp8_recipe (str, optional): Type of FP8 recipe affecting distributed + optimizer logic. + fp16 (bool): If True, use FP16 mixed precision. Defaults to False. + bf16 (bool): If True, use BF16 mixed precision. Defaults to False. + params_dtype (torch.dtype): Dtype used for weight initialization. + use_precision_aware_optimizer (bool): Allows lower precision for + master params and states. 
+ + loss_scale (float, optional): Static loss scaling factor. + initial_loss_scale (float): Initial scale for dynamic scaling. + min_loss_scale (float): Minimum scale for dynamic scaling. + loss_scale_window (float): Window size for dynamic scale adjustments. + hysteresis (int): Delay iterations for dynamic scale reduction. + + adam_beta1 (float): Adam beta1 coefficient. Defaults to 0.9. + adam_beta2 (float): Adam beta2 coefficient. Defaults to 0.999. + adam_eps (float): Adam epsilon for numerical stability. + decoupled_weight_decay (bool): Whether to use AdamW-style decay. + + use_distributed_optimizer (bool): If True, shard state across DP ranks. + overlap_param_gather (bool): Overlap param all-gather with forward compute. + + optimizer_cpu_offload (bool): If True, offload state and compute to CPU. + optimizer_offload_fraction (float): Percentage of state to offload. + + clip_grad (float): Global L2 norm threshold for gradient clipping. + log_num_zeros_in_grad (bool): Whether to log count of zero gradients. + timers (Callable, optional): Timer utility function. + """ ############## # General diff --git a/megatron/core/optimizer/qk_clip.py b/megatron/core/optimizer/qk_clip.py index 26b5787cd50..284a39daa22 100644 --- a/megatron/core/optimizer/qk_clip.py +++ b/megatron/core/optimizer/qk_clip.py @@ -7,14 +7,14 @@ def clip_qk(model, log_max_only=False) -> float: """ - Clip the QK attention logits to the threshold, recommended for Muon optimizer. + Clips QK attention logits to prevent numerical instability. Args: - model: The model to clip the QK attention logits, a list of model chunks. - log_only: Whether to only log the max attention logit, without updating the weights. + model (List[MegatronModule]): Model chunks containing attention layers. + log_max_only (bool): If True, only computes max logit without clipping. Returns: - The maximum attention logit, a float. + float: The maximum QK logit value across all chunks. """ with torch.no_grad():