Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 33 additions & 11 deletions megatron/core/transformer/moe/moe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,8 +675,8 @@ def topk_routing_with_score_function(
group_topk (int, optional): Number of selected groups for each token. Defaults to None.
scaling_factor (float, optional): Scaling factor of routing score in top-k selection.
Defaults to None.
score_function (str, optional): The score function to use. Can be either "softmax" or
"sigmoid". Defaults to "softmax".
score_function (str, optional): The score function to use. Can be "softmax", "sigmoid"
or "sqrtsoftplus". Defaults to "softmax".
expert_bias (torch.Tensor, optional): The bias added to logits for expert routing.
Defaults to None.
fused (bool, optional): Whether to use the fused version. Defaults to False.
Expand Down Expand Up @@ -710,6 +710,11 @@ def topk_routing_with_score_function(
raise ValueError(
"fused_topk_with_score_function is not available. Please install TE >= 2.6.0."
)
if score_function == "sqrtsoftplus" and not is_te_min_version("2.13.0"):
raise ValueError(
"Fused sqrtsoftplus score function requires TE >= 2.13.0. "
"Please upgrade Transformer Engine or disable moe_router_fusion."
)
return fused_topk_with_score_function(
logits=logits,
topk=topk,
Expand Down Expand Up @@ -762,19 +767,26 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):
scores, topk, num_groups, group_topk, _compute_topk
)

# Precision notes:
# - Logits are converted to fp32 for score functions.
# - All the intermediate calculations are in fp32.
# - The final probs are cast to the same dtype as the logits.
if score_function == "softmax":
if use_pre_softmax:
scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
probs, top_indices = compute_topk(scores, topk, num_groups, group_topk)
else:
scores, top_indices = compute_topk(logits, topk, num_groups, group_topk)
probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
elif score_function == "sigmoid":
scores = torch.sigmoid(logits.float()).type_as(logits)
probs = torch.softmax(scores, dim=-1, dtype=torch.float32)
elif score_function in ("sigmoid", "sqrtsoftplus"):
if score_function == "sigmoid":
scores = torch.sigmoid(logits.float())
else:
scores = torch.nn.functional.softplus(logits.float()).sqrt()
if expert_bias is not None:
scores_for_routing = scores + expert_bias
scores_for_routing = scores + expert_bias.float()
_, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk)
scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
scores = torch.gather(scores, dim=1, index=top_indices)
else:
scores, top_indices = compute_topk(scores, topk, num_groups, group_topk)
probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
Expand All @@ -784,6 +796,8 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):
if scaling_factor:
probs = probs * scaling_factor

probs = probs.type_as(logits)

if dense_output:
return probs, top_indices

Expand Down Expand Up @@ -818,7 +832,8 @@ def compute_routing_scores_for_aux_loss(
Args:
logits (torch.Tensor): The logits tensor after gating, shape: [num_tokens, num_experts].
topk (int): The number of top-k indices to compute.
score_function (str): The score function to use. Can be either "softmax" or "sigmoid".
score_function (str): The score function to use. Can be "softmax", "sigmoid"
or "sqrtsoftplus".
fused (bool, optional): Whether to use the fused version. Defaults to False.
padding_mask (torch.Tensor, optional): Boolean mask indicating non-padding tokens.
Shape in [num_tokens]. True for valid tokens,
Expand All @@ -832,15 +847,22 @@ def compute_routing_scores_for_aux_loss(
raise ValueError(
"fused_compute_score_for_moe_aux_loss is not available. Please install TE >= 2.6.0."
)
if score_function == "sqrtsoftplus" and not is_te_min_version("2.13.0"):
raise ValueError(
"Fused sqrtsoftplus score function requires TE >= 2.13.0. "
"Please upgrade Transformer Engine or disable moe_router_fusion."
)
routing_map, scores = fused_compute_score_for_moe_aux_loss(
logits=logits, topk=topk, score_function=score_function
)
else:
if score_function == "softmax":
scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
elif score_function == "sigmoid":
# Cast logits to float32 before sigmoid for stability
scores = torch.sigmoid(logits.to(torch.float32))
scores = torch.sigmoid(logits.float())
scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
elif score_function == "sqrtsoftplus":
scores = torch.nn.functional.softplus(logits.float()).sqrt()
scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
else:
raise ValueError(f"Invalid score_function: {score_function}")
Expand Down
14 changes: 9 additions & 5 deletions megatron/core/transformer/transformer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,8 +692,8 @@ class TransformerConfig(ModelParallelConfig):
"""Scaling factor for routing score in top-k selection, only works when moe_router_pre_softmax
enabled. Defaults to None, which means no scaling."""

moe_router_score_function: Literal['softmax', 'sigmoid'] = "softmax"
"""Score function for MoE routing. Can be "softmax" or "sigmoid"."""
moe_router_score_function: Literal['softmax', 'sigmoid', 'sqrtsoftplus'] = "softmax"
"""Score function for MoE routing. Can be "softmax", "sigmoid" or "sqrtsoftplus"."""

moe_router_dtype: Optional[Literal['fp32', 'fp64']] = None
"""Data type for routing and expert output weighted averaging. Using fp32 or fp64 can
Expand Down Expand Up @@ -1770,10 +1770,14 @@ def __post_init__(self):
self.expert_tensor_parallel_size == 1
), "Bias in Moe is only supported when ETP==1"

if self.moe_router_enable_expert_bias and self.moe_router_score_function != "sigmoid":
if self.moe_router_enable_expert_bias and self.moe_router_score_function not in (
"sigmoid",
"sqrtsoftplus",
):
raise ValueError(
"Expert bias for aux-loss-free routing only supports sigmoid score function."
"Please set --moe-router-score-function sigmoid for sigmoid score function."
"Expert bias for aux-loss-free routing only supports 'sigmoid' and 'sqrtsoftplus' "
"score functions. Please set --moe-router-score-function to 'sigmoid' or "
"'sqrtsoftplus', or unset --moe-router-enable-expert-bias."
)

if self.num_moe_experts and self.fp8:
Expand Down
Loading