From f2c3f3b1a6ae8e817e37a4da8127f25619b48910 Mon Sep 17 00:00:00 2001 From: Akshat Kumar Date: Tue, 3 Mar 2026 11:10:05 +0530 Subject: [PATCH 1/3] docs: overhaul mcore optimizer docstrings (final review fixes) Signed-off-by: Akshat Kumar --- megatron/core/optimizer/clip_grads.py | 73 +++++++----- megatron/core/optimizer/distrib_optimizer.py | 119 ++++++++----------- megatron/core/optimizer/grad_scaler.py | 35 +++++- megatron/core/optimizer/optimizer.py | 25 ++-- megatron/core/optimizer/optimizer_config.py | 37 +++++- megatron/core/optimizer/qk_clip.py | 8 +- 6 files changed, 176 insertions(+), 121 deletions(-) diff --git a/megatron/core/optimizer/clip_grads.py b/megatron/core/optimizer/clip_grads.py index cb2f23a685f..d1e35ad0670 100644 --- a/megatron/core/optimizer/clip_grads.py +++ b/megatron/core/optimizer/clip_grads.py @@ -53,22 +53,24 @@ def get_grad_norm_fp32( norm_type: Union[int, float] = 2, grad_stats_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> float: - """Calculate the norm of gradients in fp32. + """Calculate the p-norm of gradients in FP32 precision. - This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and - added functionality to handle model parallel parameters. + This function is adapted from `torch.nn.utils.clip_grad.clip_grad_norm_` + and extends it with functionality to handle model-parallel parameters. + It ensures that the norm is correctly computed and reduced across + the specified process group (typically the model-parallel group for + non-distributed optimizers or the entire world for distributed optimizers). - Arguments: - grads_for_norm (Iterable[Tensor] or Tensor): an iterable of Tensors or a single - Tensor that will be used for calculating the grad norm. - norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for - infinity norm. - grad_stats_parallel_group (group): Process group for reducing the grad norms. This is - generally the model-parallel group for non-distributed optimizers, and the entire - world for the distributed optimizer. + Args: + grads_for_norm (Union[List[torch.Tensor], torch.Tensor]): An iterable + of Tensors or a single Tensor used to calculate the gradient norm. + norm_type (Union[int, float]): The type of the p-norm to use. Can be + 'inf' for infinity norm. Defaults to 2. + grad_stats_parallel_group (ProcessGroup, optional): The process group + used for reducing gradient statistics (e.g., norms and zero counts). Returns: - Total norm of the parameters (viewed as a single vector). + float: The total norm of the parameters, treated as a single vector. """ if isinstance(grads_for_norm, torch.Tensor): @@ -141,17 +143,19 @@ def clip_grad_by_total_norm_fp32( total_norm: float, use_decoupled_grad: bool = False, ): - """Clips gradient of an iterable of parameters in fp32 by total norm. + """Clips the gradients of an iterable of parameters in FP32 by total norm. - Note that the gradients are modified in place. + Note that the gradients are modified in-place. Args: - parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a - single Tensor that will have gradients normalized. - max_norm (float or int): max norm of the gradients. - total_norm (float): total norm of the gradients. - use_decoupled_grad (bool, optional): whether to read grad from ".grad" or ".decoupled_grad", - default value is False. + parameters (Union[List[torch.Tensor], torch.Tensor]): An iterable of + Tensors or a single Tensor that will have gradients normalized. + max_norm (Union[int, float]): The maximum permissible total norm + of the gradients. + total_norm (float): The current total norm of the gradients. + use_decoupled_grad (bool, optional): Whether to read from the + '.decoupled_grad' attribute instead of the standard '.grad'. + Defaults to False. """ # Grads. params = [] @@ -183,18 +187,25 @@ def count_zeros_fp32( use_decoupled_grad: bool = False, tp_group: Optional[torch.distributed.ProcessGroup] = None, ) -> float: - """Counts the number of zeros in gradients associated with the passed-in list of - parameters. + """Counts the number of zero values in the gradients of the given parameters. + + The count is performed in FP32. This method filters parameters to ensure + gradients are not double-counted by checking if the gradient is not None, + the parameter is not shared, and the parameter is not a replica due + to tensor model parallelism. It also handles parameters managed by + Megatron FSDP specifically. Args: - parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a - single Tensor that will have the number of zeros in its corresponding - gradient counted. - grad_stats_parallel_group (group): Process group for reducing the num_zeros count. This is - generally the model-parallel group for non-distributed optimizers, and the entire - world for the distributed optimizer. - use_decoupled_grad (bool, optional) whether to read grad from ".grad" or ".decoupled_grad", - default value is False. + parameters (Union[List[torch.Tensor], torch.Tensor]): An iterable of + Tensors or a single Tensor whose gradients will be checked for zeros. + grad_stats_parallel_group (ProcessGroup): The process group used for + reducing the zero count across distributed ranks. + use_decoupled_grad (bool, optional): If True, reads from the + '.decoupled_grad' attribute instead of the standard '.grad'. + Defaults to False. + + Returns: + float: The total number of zeros in the gradients across the process group. """ if isinstance(parameters, torch.Tensor): @@ -245,4 +256,4 @@ def count_zeros_fp32( total_num_zeros = total_num_zeros.item() - return total_num_zeros + return total_num_zeros \ No newline at end of file diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py index eeda383a75d..e820caa727b 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -95,9 +95,15 @@ def __len__(self): class DistributedOptimizer(MixedPrecisionOptimizer): - """Distributed optimizer, for all data types (fp16, bf16, and fp32). + """Optimizer that shards state across data-parallel ranks. - See __init__() below for argument details. + This class reduces memory usage by distributing optimizer states (like + momentum and variance buffers) across GPUs in the data-parallel group. + + Attributes: + model_chunks (List[MegatronModule]): Model segments being optimized. + per_model_buffers (Dict): Buffers managing contiguous params/grads. + data_parallel_group (ProcessGroup): Group for sharding and all-gathers. """ # enumerates fully reshardable optimizer formats (as opposed to formats @@ -118,27 +124,13 @@ def _build_model_gbuf_param_range_map( """ Build mapping from param reference to grad buffer shard ranges. - This method builds a mapping from parameter references to grad - buffer shard ranges, specific to each data-parallel (DP) rank's - set of 'owned' parameters. Each grad buffer (padded to be an even - multiple of DP-world-size) is conceptually divided into DP-world-size - contiguous regions, where each DP rank 'owns' a contiguous region. - Ownership in this sense means DP rank is responsible for reducing - the relevant subset of grads, and updating the relevant subset of - params. - - This conceptual partitioning of the grad buffer does NOT respect - parameter boundaries, and as such it is assumed that each created - range references a shard (or subset) of the full parameter. It is - easiest to think of each DP rank as operating (i.e., reducing, - gathering) purely on views into the grad buffer, for all model-to- - main & main-to-model operations. - - This method creates four ranges: - - The param's range within the entire grad buffer (i.e., world index). - - The param's range within the relevant grad bucket's buffer. - - The param's range within the DP rank's local view of the grad buffer. - - The param's range within itself (i.e., its shard). + Each grad buffer is conceptually divided into DP-world-size contiguous + regions, where each DP rank 'owns' a region. This method creates + mappings for four specific ranges: + - The param's range within the entire grad buffer (world index). + - The param's range within the relevant grad bucket's buffer. + - The param's range within the DP rank's local view of the grad buffer. + - The param's range within itself (its shard). """ # Param range map. @@ -219,16 +211,17 @@ def _build_model_gbuf_range(cls, param_and_grad_buffer: _ParamAndGradBuffer, buc @classmethod def _build_gbuf_range_map(cls, param_and_grad_buffer: _ParamAndGradBuffer): - """ - Build mapping between params and their grad buffers. These mappings are - partitioned according to data type. + """Builds a map between parameters and their ranges in the grad buffer. Iterate through all buckets of grad buffer to construct param ranges that this rank "owns" (the dp_rank'th shard of each bucket, where each shard is 1/dp_world_size of the bucket). Args: - param_and_grad_buffer (_ParamAndGradBuffer): buffer to build mapping for. + param_and_grad_buffer (_ParamAndGradBuffer): The buffer to map. + + Returns: + Dict: Mapping of parameter dtypes to bucket ranges. """ return { (param_and_grad_buffer.param_dtype, param_and_grad_buffer.grad_dtype): [ @@ -469,36 +462,36 @@ def __init__( data_parallel_group_idx: int, distributed_optimizer_instance_id: int, ): - """ - Distributed optimizer, for all data types (fp16, bf16, and fp32). + """Initializes the distributed optimizer for FP16, BF16, and FP32. - The steps in this method create the core mapping between param and grad buffers, - parameters, and parameter shard ranges, that is needed for converting between model - param indexes and main parameter shard indexes. This method also updates the optimizer - parameter groups with the newly created shards. + The steps in this method create the core mapping between param and grad + buffers, parameters, and parameter shard ranges, that is needed for + converting between model param indexes and main parameter shard indexes. + This method also updates the optimizer parameter groups with the + newly created shards. Args: - optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD. - config (OptimizerConfig): configuration object for optimizer. - grad_scaler (MegatronGradScaler): used for scaling gradients. Note that - this can be None. This case happens when `bf16 = True` and we don't - use any loss scale. Note that for `bf16 = True`, we can have - a constant gradient scaler. Also for `bf16 = False`, we - always require a grad scaler. - init_state_fn (Callable, optional): function to initialize state in the optimizer. - model_chunks (List[MegatronModule]): list of model chunks. - per_model_buffers (Dict[int, List[_ParamAndGradBuffer]]): the implementation of the - distributed optimizer is centered on using a contiguous buffer for - communicating grads & params between the model state and the optimizer state. - You can find a more detailed description in - https://github.com/NVIDIA/Megatron-LM/blob/main/docs/source/distrib_optimizer.md. - data_parallel_group (torch.distributed.ProcessGroup): data-parallel group to use to + optimizer (torch.optim.Optimizer): Base optimizer such as Adam or SGD. + config (OptimizerConfig): Configuration object for the optimizer. + grad_scaler (MegatronGradScaler): Used for scaling gradients. Note that + this can be None for BF16 training if no loss scale is used. + For FP16, a grad scaler is always required. + init_state_fn (Callable, optional): Function to initialize state in + the optimizer. + model_chunks (List[MegatronModule]): List of model chunks to optimize. + per_model_buffers (Dict[int, List[_ParamAndGradBuffer]]): The + implementation of the distributed optimizer is centered on using + a contiguous buffer for communicating grads & params between + the model state and the optimizer state. For a detailed + description, see `docs/source/distrib_optimizer.md`. + data_parallel_group (ProcessGroup): Data-parallel group used to all-gather params after optimizer.step(). - data_parallel_group_gloo (torch.distributed.ProcessGroup): gloo data-parallel group - (used in checkpoint loading and saving). - data_parallel_group_idx (int): index in data-parallel group (used by - distributed checkpointing logic). - distributed_optimizer_instance_id (int): index of the Distributed Optimizer instance. + data_parallel_group_gloo (ProcessGroup, optional): Gloo data-parallel + group used specifically for checkpoint loading and saving. + data_parallel_group_idx (int): Index in the data-parallel group + used by distributed checkpointing logic. + distributed_optimizer_instance_id (int): Unique identifier for the + distributed optimizer instance. """ if has_config_logger_enabled(config): @@ -896,8 +889,6 @@ def _get_main_param_and_optimizer_states(self, model_param): sharded_model_param = self.optimizer.param_groups[group_index]["params"][group_order] tensors = {} for k in self.optimizer.state[sharded_model_param]: - if not isinstance(self.optimizer.state[sharded_model_param][k], torch.Tensor): - continue if isinstance(self.optimizer, HybridDeviceOptimizer): tensors[k] = self.optimizer.state[sharded_model_param][k] continue @@ -907,10 +898,7 @@ def _get_main_param_and_optimizer_states(self, model_param): else: main_param = self.optimizer.param_groups[group_index]["params"][group_order] optim_state = self.optimizer.state[main_param] - tensors = {"param": main_param} - for k, v in optim_state.items(): - if isinstance(v, torch.Tensor): - tensors[k] = v + tensors = {"param": main_param, **optim_state} return tensors def _set_main_param_and_optimizer_states(self, model_param, tensors): @@ -927,8 +915,6 @@ def _set_main_param_and_optimizer_states(self, model_param, tensors): if self.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8: sharded_model_param = self.optimizer.param_groups[group_index]["params"][group_order] for k, v in tensors.items(): - if not isinstance(v, torch.Tensor): - continue if isinstance(self.optimizer, HybridDeviceOptimizer): if k == "param": k = "master_param" @@ -942,13 +928,8 @@ def _set_main_param_and_optimizer_states(self, model_param, tensors): else: main_param = self.optimizer.param_groups[group_index]["params"][group_order] optim_state = self.optimizer.state[main_param] - dst_tensors = {"param": main_param} - for k, v in optim_state.items(): - if isinstance(v, torch.Tensor): - dst_tensors[k] = v + dst_tensors = {"param": main_param, **optim_state} for key in dst_tensors: - if not isinstance(tensors[key], torch.Tensor): - continue dst_tensors[key].copy_(tensors[key]) def get_parameter_state_dp_reshardable(self): @@ -2539,7 +2520,7 @@ def _build_model_param_to_state_dict_param_map(self, state_dict): for name, model_param in model_chunk.named_parameters(): while name.startswith("module."): name = name[len("module.") :] - matched_keys = [k for k in names_in_state_dict if k.endswith(name)] + matched_keys = [k for k in names_in_state_dict if name in k] assert ( len(matched_keys) == 1 ), f"Parameter {name} has {len(matched_keys)} matches in state dict" @@ -2632,4 +2613,4 @@ def step_with_ready_grads(self) -> bool: if timers is not None: timers('params-all-gather').stop() - return update_successful + return update_successful \ No newline at end of file diff --git a/megatron/core/optimizer/grad_scaler.py b/megatron/core/optimizer/grad_scaler.py index abdd1e7b606..03656488982 100644 --- a/megatron/core/optimizer/grad_scaler.py +++ b/megatron/core/optimizer/grad_scaler.py @@ -9,6 +9,11 @@ class MegatronGradScaler(ABC): + """Abstract base class for gradient scalers. + + Args: + initial_scale (float): The initial value for the loss scale. + """ def __init__(self, initial_scale: float): """Initialize scale value with the input initial scale.""" assert initial_scale > 0.0 @@ -37,7 +42,10 @@ def load_state_dict(self, state_dict: Dict): class ConstantGradScaler(MegatronGradScaler): """ - Constant grad scaler (loss scale is never adjusted regardless of NaNs seen in gradients). + Grad scaler with a fixed scale factor. + + The loss scale is never adjusted, regardless of whether NaNs or Infs + are detected in the gradients. """ def update(self, found_inf: bool): @@ -51,11 +59,26 @@ def load_state_dict(self, state_dict): class DynamicGradScaler(MegatronGradScaler): - """ - Grad scaler with dynamic scale that gets adjusted during training. - - Reduces loss scale by `backoff_factor` if `hysteresis` number of NaNs are seen in a row. Increases - loss scale by `growth_factor` if NaNs are not seen for `growth_interval` iterations. + """Gradient scaler with a dynamic scale factor adjusted during training. + + This class implements a loss scaling strategy to prevent numerical underflow + during mixed-precision training. It reduces the loss scale by a + `backoff_factor` if a `hysteresis` number of NaNs/Infs are detected in + consecutive iterations. Conversely, it increases the loss scale by a + `growth_factor` if no non-finite gradients are seen for a specified + `growth_interval` of iterations. + + Args: + initial_scale (float): The starting value for the loss scale. + min_scale (float): The lower bound for the loss scale. + growth_factor (float): The multiplier used to increase the scale when + gradients are stable. Must be greater than 1.0. + backoff_factor (float): The multiplier used to decrease the scale when + non-finite gradients are detected. Must be between 0.0 and 1.0. + growth_interval (int): The number of consecutive stable iterations + required before increasing the scale. + hysteresis (int): The number of consecutive non-finite iterations + required before decreasing the scale. """ def __init__( diff --git a/megatron/core/optimizer/optimizer.py b/megatron/core/optimizer/optimizer.py index df8ec8ef613..9f3174ca77d 100644 --- a/megatron/core/optimizer/optimizer.py +++ b/megatron/core/optimizer/optimizer.py @@ -101,10 +101,13 @@ class MegatronOptimizer(ABC): """ Base class for all Megatron optimizers. + Provides a consistent interface for gradient management, parameter + access, and state-dict handling across different optimization types. + Args: - optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD. - config (OptimizerConfig): configuration object for optimizer. - init_state_fn (Callable, optional): function to initialize state in the optimizer. + optimizer (torch.optim.Optimizer): The base PyTorch optimizer. + config (OptimizerConfig): The optimizer configuration. + init_state_fn (Callable, optional): Function to initialize optimizer state. """ def __init__( @@ -135,13 +138,15 @@ def get_parameters(self) -> List[torch.nn.Parameter]: return params def get_main_grads_for_grad_norm(self) -> List[torch.Tensor]: - """ - Get main_grads that should be taken into account to compute the grad norm. - Filter parameters based on: - - grad should not be None. - - parameter should not be shared (i.e., grads shouldn't be double counted while - computing norms). - - should not be a replica due to tensor model parallelism. + + """Collects gradients for norm calculation, filtering duplicates. + + This method filters parameters based on whether the gradient is not None, + the parameter is not shared (to avoid double-counting gradients), and + the parameter is not a replica due to tensor model parallelism. + + Returns: + List[torch.Tensor]: A list of gradient tensors filtered for norm calculation. """ params = self.get_parameters() grads_for_norm = [] diff --git a/megatron/core/optimizer/optimizer_config.py b/megatron/core/optimizer/optimizer_config.py index 2d3e3ca08e0..33175589adf 100644 --- a/megatron/core/optimizer/optimizer_config.py +++ b/megatron/core/optimizer/optimizer_config.py @@ -137,7 +137,42 @@ def matches(self, param: torch.nn.Parameter, param_name: str) -> bool: @dataclass class OptimizerConfig: - """Base optimizer configuration object.""" + """Configuration object for Megatron optimizers. + + Attributes: + lr (float, optional): Initial learning rate. Defaults to None. + min_lr (float, optional): Minimum learning rate for scheduler clipping. + weight_decay (float): L2 regularization coefficient. Defaults to 0.01. + + fp8_recipe (str, optional): Type of FP8 recipe affecting distributed + optimizer logic. + fp16 (bool): If True, use FP16 mixed precision. Defaults to False. + bf16 (bool): If True, use BF16 mixed precision. Defaults to False. + params_dtype (torch.dtype): Dtype used for weight initialization. + use_precision_aware_optimizer (bool): Allows lower precision for + master params and states. + + loss_scale (float, optional): Static loss scaling factor. + initial_loss_scale (float): Initial scale for dynamic scaling. + min_loss_scale (float): Minimum scale for dynamic scaling. + loss_scale_window (float): Window size for dynamic scale adjustments. + hysteresis (int): Delay iterations for dynamic scale reduction. + + adam_beta1 (float): Adam beta1 coefficient. Defaults to 0.9. + adam_beta2 (float): Adam beta2 coefficient. Defaults to 0.999. + adam_eps (float): Adam epsilon for numerical stability. + decoupled_weight_decay (bool): Whether to use AdamW-style decay. + + use_distributed_optimizer (bool): If True, shard state across DP ranks. + overlap_param_gather (bool): Overlap param all-gather with forward compute. + + optimizer_cpu_offload (bool): If True, offload state and compute to CPU. + optimizer_offload_fraction (float): Percentage of state to offload. + + clip_grad (float): Global L2 norm threshold for gradient clipping. + log_num_zeros_in_grad (bool): Whether to log count of zero gradients. + timers (Callable, optional): Timer utility function. + """ ############## # General diff --git a/megatron/core/optimizer/qk_clip.py b/megatron/core/optimizer/qk_clip.py index 26b5787cd50..284a39daa22 100644 --- a/megatron/core/optimizer/qk_clip.py +++ b/megatron/core/optimizer/qk_clip.py @@ -7,14 +7,14 @@ def clip_qk(model, log_max_only=False) -> float: """ - Clip the QK attention logits to the threshold, recommended for Muon optimizer. + Clips QK attention logits to prevent numerical instability. Args: - model: The model to clip the QK attention logits, a list of model chunks. - log_only: Whether to only log the max attention logit, without updating the weights. + model (List[MegatronModule]): Model chunks containing attention layers. + log_max_only (bool): If True, only computes max logit without clipping. Returns: - The maximum attention logit, a float. + float: The maximum QK logit value across all chunks. """ with torch.no_grad(): From 5493532137f6c39d6b9eb27b02cf65ed93514dc7 Mon Sep 17 00:00:00 2001 From: Akshat Kumar Date: Thu, 5 Mar 2026 09:42:48 +0530 Subject: [PATCH 2/3] removing logical changes causes by mergeconflict Signed-off-by: Akshat Kumar --- megatron/core/optimizer/distrib_optimizer.py | 21 ++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py index e820caa727b..6326f6b6ec4 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -889,6 +889,8 @@ def _get_main_param_and_optimizer_states(self, model_param): sharded_model_param = self.optimizer.param_groups[group_index]["params"][group_order] tensors = {} for k in self.optimizer.state[sharded_model_param]: + if not isinstance(self.optimizer.state[sharded_model_param][k], torch.Tensor): + continue if isinstance(self.optimizer, HybridDeviceOptimizer): tensors[k] = self.optimizer.state[sharded_model_param][k] continue @@ -898,7 +900,10 @@ def _get_main_param_and_optimizer_states(self, model_param): else: main_param = self.optimizer.param_groups[group_index]["params"][group_order] optim_state = self.optimizer.state[main_param] - tensors = {"param": main_param, **optim_state} + tensors = {"param": main_param} + for k, v in optim_state.items(): + if isinstance(v, torch.Tensor): + tensors[k] = v return tensors def _set_main_param_and_optimizer_states(self, model_param, tensors): @@ -915,6 +920,8 @@ def _set_main_param_and_optimizer_states(self, model_param, tensors): if self.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8: sharded_model_param = self.optimizer.param_groups[group_index]["params"][group_order] for k, v in tensors.items(): + if not isinstance(v, torch.Tensor): + continue if isinstance(self.optimizer, HybridDeviceOptimizer): if k == "param": k = "master_param" @@ -928,8 +935,13 @@ def _set_main_param_and_optimizer_states(self, model_param, tensors): else: main_param = self.optimizer.param_groups[group_index]["params"][group_order] optim_state = self.optimizer.state[main_param] - dst_tensors = {"param": main_param, **optim_state} + dst_tensors = {"param": main_param} + for k, v in optim_state.items(): + if isinstance(v, torch.Tensor): + dst_tensors[k] = v for key in dst_tensors: + if not isinstance(tensors[key], torch.Tensor): + continue dst_tensors[key].copy_(tensors[key]) def get_parameter_state_dp_reshardable(self): @@ -2520,7 +2532,7 @@ def _build_model_param_to_state_dict_param_map(self, state_dict): for name, model_param in model_chunk.named_parameters(): while name.startswith("module."): name = name[len("module.") :] - matched_keys = [k for k in names_in_state_dict if name in k] + matched_keys = [k for k in names_in_state_dict if k.endswith(name)] assert ( len(matched_keys) == 1 ), f"Parameter {name} has {len(matched_keys)} matches in state dict" @@ -2613,4 +2625,5 @@ def step_with_ready_grads(self) -> bool: if timers is not None: timers('params-all-gather').stop() - return update_successful \ No newline at end of file + return update_successful + \ No newline at end of file From 70e67d6f1e23ebd2dfa2b8268600f801b2431afb Mon Sep 17 00:00:00 2001 From: Akshat8510 Date: Thu, 5 Mar 2026 10:49:40 +0530 Subject: [PATCH 3/3] Refine docstrings for grad buffer mapping methods Updated docstrings to improve clarity and detail regarding grad buffer shard ranges and ownership. --- megatron/core/optimizer/distrib_optimizer.py | 47 ++++++++++++++------ 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py index 6326f6b6ec4..a18206ed591 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -121,16 +121,34 @@ def _build_model_gbuf_param_range_map( gbuf_world_range: Range, bucket_offset: int, ): - """ - Build mapping from param reference to grad buffer shard ranges. - - Each grad buffer is conceptually divided into DP-world-size contiguous - regions, where each DP rank 'owns' a region. This method creates - mappings for four specific ranges: - - The param's range within the entire grad buffer (world index). - - The param's range within the relevant grad bucket's buffer. - - The param's range within the DP rank's local view of the grad buffer. - - The param's range within itself (its shard). + """Build mapping from param reference to grad buffer shard ranges. + + This method builds a mapping from parameter references to grad + buffer shard ranges, specific to each data-parallel (DP) rank's + set of 'owned' parameters. Each grad buffer (padded to be an even + multiple of DP-world-size) is conceptually divided into DP-world-size + contiguous regions, where each DP rank 'owns' a contiguous region. + Ownership in this sense means DP rank is responsible for reducing + the relevant subset of grads, and updating the relevant subset of + params. + + This conceptual partitioning of the grad buffer does NOT respect + parameter boundaries, and as such it is assumed that each created + range references a shard (or subset) of the full parameter. It is + easiest to think of each DP rank as operating (i.e., reducing, + gathering) purely on views into the grad buffer, for all model-to- + main & main-to-model operations. + + This method creates four ranges: + - gbuf_world: The param's range within the entire grad buffer (world index). + - gbuf_world_in_bucket: The param's range within the relevant grad bucket's buffer. + - gbuf_local: The param's range within the DP rank's local view of the grad buffer. + - param: The param's range within itself (i.e., its shard). + + Args: + param_world_index_map (Dict): Mapping from parameter to its world indexes. + gbuf_world_range (Range): The range of the grad buffer owned by this rank. + bucket_offset (int): The offset of the current bucket within the grad buffer. """ # Param range map. @@ -213,9 +231,10 @@ def _build_model_gbuf_range(cls, param_and_grad_buffer: _ParamAndGradBuffer, buc def _build_gbuf_range_map(cls, param_and_grad_buffer: _ParamAndGradBuffer): """Builds a map between parameters and their ranges in the grad buffer. - Iterate through all buckets of grad buffer to construct param ranges - that this rank "owns" (the dp_rank'th shard of each bucket, where each - shard is 1/dp_world_size of the bucket). + These mappings are partitioned according to data type. This method + iterates through all buckets of a grad buffer to construct param + ranges that this rank "owns" (the dp_rank'th shard of each bucket, + where each shard is 1/dp_world_size of the bucket). Args: param_and_grad_buffer (_ParamAndGradBuffer): The buffer to map. @@ -2626,4 +2645,4 @@ def step_with_ready_grads(self) -> bool: timers('params-all-gather').stop() return update_successful - \ No newline at end of file +