From 37c861029f061b01aa8610e3884705f768934ecf Mon Sep 17 00:00:00 2001
From: Xin Yao
Date: Mon, 2 Mar 2026 07:15:40 +0000
Subject: [PATCH] add a new score function

Signed-off-by: Xin Yao
---
 megatron/core/transformer/moe/moe_utils.py     | 44 ++++++++++++++-----
 .../core/transformer/transformer_config.py     | 14 +++---
 2 files changed, 42 insertions(+), 16 deletions(-)

diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py
index 51c8b51134f..265e563f71f 100644
--- a/megatron/core/transformer/moe/moe_utils.py
+++ b/megatron/core/transformer/moe/moe_utils.py
@@ -675,8 +675,8 @@ def topk_routing_with_score_function(
         group_topk (int, optional): Number of selected groups for each token. Defaults to None.
         scaling_factor (float, optional): Scaling factor of routing score in top-k selection.
             Defaults to None.
-        score_function (str, optional): The score function to use. Can be either "softmax" or
-            "sigmoid". Defaults to "softmax".
+        score_function (str, optional): The score function to use. Can be "softmax", "sigmoid"
+            or "sqrtsoftplus". Defaults to "softmax".
         expert_bias (torch.Tensor, optional): The bias added to logits for expert routing.
             Defaults to None.
         fused (bool, optional): Whether to use the fused version. Defaults to False.
@@ -710,6 +710,11 @@
             raise ValueError(
                 "fused_topk_with_score_function is not available. Please install TE >= 2.6.0."
             )
+        if score_function == "sqrtsoftplus" and not is_te_min_version("2.13.0"):
+            raise ValueError(
+                "Fused sqrtsoftplus score function requires TE >= 2.13.0. "
+                "Please upgrade Transformer Engine or disable moe_router_fusion."
+            )
         return fused_topk_with_score_function(
             logits=logits,
             topk=topk,
@@ -762,19 +767,26 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):
             scores, topk, num_groups, group_topk, _compute_topk
         )
 
+    # Precision notes:
+    # - Logits are converted to fp32 for score functions.
+    # - All the intermediate calculations are in fp32.
+    # - The final probs are cast to the same dtype as the logits.
     if score_function == "softmax":
         if use_pre_softmax:
-            scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
+            scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
             probs, top_indices = compute_topk(scores, topk, num_groups, group_topk)
         else:
             scores, top_indices = compute_topk(logits, topk, num_groups, group_topk)
-            probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
-    elif score_function == "sigmoid":
-        scores = torch.sigmoid(logits.float()).type_as(logits)
+            probs = torch.softmax(scores, dim=-1, dtype=torch.float32)
+    elif score_function in ("sigmoid", "sqrtsoftplus"):
+        if score_function == "sigmoid":
+            scores = torch.sigmoid(logits.float())
+        else:
+            scores = torch.nn.functional.softplus(logits.float()).sqrt()
         if expert_bias is not None:
-            scores_for_routing = scores + expert_bias
+            scores_for_routing = scores + expert_bias.float()
             _, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk)
-            scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
+            scores = torch.gather(scores, dim=1, index=top_indices)
         else:
             scores, top_indices = compute_topk(scores, topk, num_groups, group_topk)
         probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
@@ -784,6 +796,8 @@
     if scaling_factor:
         probs = probs * scaling_factor
 
+    probs = probs.type_as(logits)
+
     if dense_output:
         return probs, top_indices
 
@@ -818,7 +832,8 @@
     Args:
         logits (torch.Tensor): The logits tensor after gating, shape: [num_tokens, num_experts].
         topk (int): The number of top-k indices to compute.
-        score_function (str): The score function to use. Can be either "softmax" or "sigmoid".
+        score_function (str): The score function to use. Can be "softmax", "sigmoid"
+            or "sqrtsoftplus".
         fused (bool, optional): Whether to use the fused version. Defaults to False.
         padding_mask (torch.Tensor, optional): Boolean mask indicating non-padding tokens.
             Shape in [num_tokens]. True for valid tokens,
@@ -832,6 +847,11 @@
             raise ValueError(
                 "fused_compute_score_for_moe_aux_loss is not available. Please install TE >= 2.6.0."
             )
+        if score_function == "sqrtsoftplus" and not is_te_min_version("2.13.0"):
+            raise ValueError(
+                "Fused sqrtsoftplus score function requires TE >= 2.13.0. "
+                "Please upgrade Transformer Engine or disable moe_router_fusion."
+            )
         routing_map, scores = fused_compute_score_for_moe_aux_loss(
             logits=logits, topk=topk, score_function=score_function
         )
@@ -839,8 +859,10 @@
     if score_function == "softmax":
         scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
     elif score_function == "sigmoid":
-        # Cast logits to float32 before sigmoid for stability
-        scores = torch.sigmoid(logits.to(torch.float32))
+        scores = torch.sigmoid(logits.float())
+        scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
+    elif score_function == "sqrtsoftplus":
+        scores = torch.nn.functional.softplus(logits.float()).sqrt()
         scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
     else:
         raise ValueError(f"Invalid score_function: {score_function}")
diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
index 559f4226af2..7ebd25585ba 100644
--- a/megatron/core/transformer/transformer_config.py
+++ b/megatron/core/transformer/transformer_config.py
@@ -692,8 +692,8 @@ class TransformerConfig(ModelParallelConfig):
     """Scaling factor for routing score in top-k selection, only works when
     moe_router_pre_softmax enabled. Defaults to None, which means no scaling."""
 
-    moe_router_score_function: Literal['softmax', 'sigmoid'] = "softmax"
-    """Score function for MoE routing. Can be "softmax" or "sigmoid"."""
+    moe_router_score_function: Literal['softmax', 'sigmoid', 'sqrtsoftplus'] = "softmax"
+    """Score function for MoE routing. Can be "softmax", "sigmoid" or "sqrtsoftplus"."""
 
     moe_router_dtype: Optional[Literal['fp32', 'fp64']] = None
     """Data type for routing and expert output weighted averaging. Using fp32 or fp64 can
@@ -1770,10 +1770,14 @@ def __post_init__(self):
             self.expert_tensor_parallel_size == 1
         ), "Bias in Moe is only supported when ETP==1"
 
-        if self.moe_router_enable_expert_bias and self.moe_router_score_function != "sigmoid":
+        if self.moe_router_enable_expert_bias and self.moe_router_score_function not in (
+            "sigmoid",
+            "sqrtsoftplus",
+        ):
             raise ValueError(
-                "Expert bias for aux-loss-free routing only supports sigmoid score function."
-                "Please set --moe-router-score-function sigmoid for sigmoid score function."
+                "Expert bias for aux-loss-free routing only supports 'sigmoid' and 'sqrtsoftplus' "
+                "score functions. Please set --moe-router-score-function to 'sigmoid' or "
+                "'sqrtsoftplus', or unset --moe-router-enable-expert-bias."
             )
 
         if self.num_moe_experts and self.fp8: