
Global auxiliary loss gradient incorrectly scaled due to averaging over global token count #3672

@zyeric

Description


Describe the bug

In the MoE (Mixture of Experts) auxiliary loss functions _apply_aux_loss, _apply_seq_aux_loss, and _apply_global_aux_loss, the shared helper switch_load_balancing_loss_func is called. This helper computes a mean loss by dividing by its total_num_tokens argument.

When calculate_per_token_loss is disabled (the common setting), the mean loss for the first two auxiliary losses (_apply_aux_loss and _apply_seq_aux_loss) is correctly computed based on the number of tokens in the current training step (micro-batch). However, _apply_global_aux_loss averages its loss over the global total number of tokens (i.e., across the entire data parallel group).

Because the global total number of tokens is larger by a factor of self.tp_dp_cp_group.size() (the size of the combined tensor/data/context-parallel group), the gradient back-propagated from the global aux loss is proportionally smaller. Concretely, the gradient is scaled down by this factor relative to what aux_loss_coeff would imply, making the coefficient's effect on the global loss inconsistent with its effect on the other two aux losses.
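A back-of-envelope sketch of the scaling (plain Python, hypothetical numbers): for a mean loss of the form coeff * (sum over tokens) / total_num_tokens, the gradient with respect to each token's contribution is coeff / total_num_tokens, so using a globally aggregated denominator shrinks it by the group size.

```python
def per_token_grad(coeff, num_local_tokens, group_size=1):
    """Gradient of coeff * (sum over tokens) / total_num_tokens
    w.r.t. a single token's contribution."""
    total_num_tokens = num_local_tokens * group_size
    return coeff / total_num_tokens

local_grad = per_token_grad(0.01, num_local_tokens=4096)                 # local averaging
global_grad = per_token_grad(0.01, num_local_tokens=4096, group_size=8)  # global averaging
print(local_grad / global_grad)  # -> 8.0
```

With a group size of 8, the gradient from the globally averaged loss is exactly 8 times smaller for the same coefficient.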


Steps/Code to reproduce bug

  1. Enable MoE with auxiliary losses in a training setup that uses tensor/model parallelism (e.g., --num-experts 8 --moe-aux-loss-coeff 0.01).
  2. Ensure --calculate-per-token-loss is not set (the default behavior).
  3. Run a few iterations and compare the gradient magnitudes contributed by the global aux loss versus the per-token or sequence-level aux losses.
  4. Observe that the global aux loss gradients are significantly smaller (by a factor approximately equal to the DP group size) for the same coefficient value.

A minimal code inspection can also confirm the issue:

  • In megatron/core/transformer/moe/moe_utils.py, examine switch_load_balancing_loss_func. Note it returns loss / total_num_tokens.
  • Trace the calls to this function in the three auxiliary loss methods mentioned above. You will see that _apply_global_aux_loss passes a total_num_tokens that is aggregated across the entire parallel group, while the other two use the local micro-batch token count.
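To make the pattern concrete, here is a simplified sketch of a Switch-style load-balancing loss with the final division by total_num_tokens that the issue describes. The argument names and internals are hypothetical; this is illustrative, not the actual Megatron-LM code.

```python
def switch_load_balancing_loss_func(probs_per_expert, tokens_per_expert,
                                    total_num_tokens, num_experts, coeff):
    # probs_per_expert[i]: router probability mass for expert i, summed
    # over local tokens; tokens_per_expert[i]: local tokens routed to i.
    # The divisions by total_num_tokens are the normalization at issue:
    # passing a globally aggregated count here shrinks the result.
    aggregated = sum(p * t / total_num_tokens
                     for p, t in zip(probs_per_expert, tokens_per_expert))
    return coeff * num_experts * aggregated / total_num_tokens

# Perfectly balanced routing: 2 experts, 4 tokens, uniform 0.5 router probs.
loss = switch_load_balancing_loss_func([2.0, 2.0], [2, 2],
                                       total_num_tokens=4,
                                       num_experts=2, coeff=0.01)
print(loss)  # -> 0.01 (balanced routing yields exactly the coefficient)
```

If total_num_tokens were replaced by a count aggregated over the parallel group while probs_per_expert and tokens_per_expert stayed local, the returned loss would shrink accordingly.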

Expected behavior

The global auxiliary loss should be averaged over the same token count as the other auxiliary losses when calculate_per_token_loss is disabled. That is, it should be averaged over the number of tokens in the current training step (micro-batch), not the global aggregated token count. This would ensure that the aux_loss_coeff has a consistent scaling effect across all auxiliary loss components.

Alternatively, if the global averaging is intentional, its behavior should be clearly documented and the coefficient scaling should be adjusted accordingly. However, the current implementation appears to be an unintended inconsistency.


Additional context

  • The root cause is that total_num_tokens in _apply_global_aux_loss is the sum of tokens across the entire parallel group, while in the other loss functions it is the local token count.
  • This leads to a practical effect where the global aux loss contributes much less to the overall gradient than expected for a given coefficient. For example, with a world size of 8, the effective coefficient for the global loss becomes aux_loss_coeff / 8.
  • Fixing this would involve either:
    1. Using the local token count for the global loss as well, or
    2. Explicitly scaling the loss or coefficient to maintain the intended influence.
  • This issue is particularly relevant when tuning MoE auxiliary loss coefficients, as the same coefficient value will have different impacts depending on which aux loss function it applies to.
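Under the same simplified assumptions as above, the two fix options sketch out as follows (hypothetical helper names, not the actual Megatron-LM code):

```python
import math

def apply_global_aux_loss_fixed(unnormalized_loss, coeff, num_local_tokens):
    # Fix option 1: normalize by the local micro-batch token count,
    # matching _apply_aux_loss and _apply_seq_aux_loss.
    return coeff * unnormalized_loss / num_local_tokens

def apply_global_aux_loss_rescaled(unnormalized_loss, coeff,
                                   num_local_tokens, group_size):
    # Fix option 2: keep the global denominator but scale the coefficient
    # up by the group size so the effective influence is unchanged.
    return (coeff * group_size) * unnormalized_loss / (num_local_tokens * group_size)

# Both options give the same effective loss for the same inputs:
a = apply_global_aux_loss_fixed(100.0, 0.01, 4096)
b = apply_global_aux_loss_rescaled(100.0, 0.01, 4096, 8)
print(math.isclose(a, b))  # -> True
```

Either way, the effective coefficient becomes consistent with the other two aux losses.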

Tagging @mcore-oncall for visibility.
