diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py
index a426973a25b..ceccd70ffde 100644
--- a/megatron/core/transformer/moe/moe_utils.py
+++ b/megatron/core/transformer/moe/moe_utils.py
@@ -542,7 +542,6 @@ def save_to_aux_losses_tracker(
     loss: torch.Tensor,
     layer_number: int,
     num_layers: int,
-    reduce_group: torch.distributed.ProcessGroup = None,
     avg_group: torch.distributed.ProcessGroup = None,
 ):
     """Save the auxiliary loss for logging.
@@ -551,8 +550,7 @@ def save_to_aux_losses_tracker(
         loss (torch.Tensor): The loss tensor.
         layer_number (int): Layer index of the loss.
         num_layers (int): The number of total layers.
-        reduce_group (torch.distributed.ProcessGroup): The group for reducing the loss.
-        mean_group (torch.distributed.ProcessGroup): The group for averaging the loss.
+        avg_group (torch.distributed.ProcessGroup): The group for averaging the loss.
     """
     # Skip aux loss logging if layer_number is None.
     if layer_number is None:
@@ -563,7 +561,6 @@ def save_to_aux_losses_tracker(
         tracker[name] = {}
         tracker[name]["values"] = torch.zeros(num_layers, device=loss.device)
     tracker[name]["values"][layer_number - 1] += loss.detach()  # Aggregate the loss for the layer.
-    tracker[name]["reduce_group"] = reduce_group
     tracker[name]["avg_group"] = avg_group
 
 
@@ -576,8 +573,11 @@ def clear_aux_losses_tracker():
         tracker[name]["avg_group"] = None
 
 
-def reduce_aux_losses_tracker_across_ranks():
-    """Collect and reduce the auxiliary losses across ranks."""
+def reduce_aux_losses_tracker_across_ranks(moe_token_dispatcher_type: str) -> None:
+    """Collect and reduce the auxiliary losses across ranks.
+    Args:
+        moe_token_dispatcher_type: The type of token dispatcher to use
+    """
     tracker = parallel_state.get_moe_layer_wise_logging_tracker()
     for name in tracker:
         values = tracker[name]["values"]
@@ -586,8 +586,21 @@ def reduce_aux_losses_tracker_across_ranks():
             values, group=parallel_state.get_pipeline_model_parallel_group()
         )
         # Reduce aux losses across ranks.
-        if tracker[name].get('reduce_group') is not None:
-            torch.distributed.all_reduce(values, group=tracker[name].get('reduce_group'))
+        if name == "load_balancing_loss":
+            reduce_group = None
+            if moe_token_dispatcher_type == "alltoall_seq":
+                reduce_group = parallel_state.get_context_parallel_group()
+            elif parallel_state.get_tensor_and_context_parallel_world_size() > 1:
+                reduce_group = parallel_state.get_tensor_and_context_parallel_group()
+
+            if reduce_group:
+                torch.distributed.all_reduce(values, group=reduce_group)
+        elif name == "z_loss":
+            pass
+        else:
+            raise ValueError(
+                f"Unknown reduction method for metric {name}. Please specify how the reduction needs to take place"
+            )
         if tracker[name].get('avg_group') is not None:
             torch.distributed.all_reduce(
                 values, group=tracker[name]['avg_group'], op=torch.distributed.ReduceOp.AVG
@@ -595,11 +608,11 @@
 
 
 def track_moe_metrics(
-    loss_scale, iteration, writer, wandb_writer=None, total_loss_dict=None, per_layer_logging=False
+    loss_scale, iteration, writer, wandb_writer=None, total_loss_dict=None, per_layer_logging=False, moe_token_dispatcher_type="allgather"
 ):
     """Track the MoE metrics for logging."""
     # Aux loss logging
-    reduce_aux_losses_tracker_across_ranks()
+    reduce_aux_losses_tracker_across_ranks(moe_token_dispatcher_type)
     tracker = parallel_state.get_moe_layer_wise_logging_tracker()
     if writer is not None:
         aux_losses = {k: v['values'].float() * loss_scale for k, v in tracker.items()}
diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py
index 5965c16dc68..844dd553c5f 100644
--- a/megatron/core/transformer/moe/router.py
+++ b/megatron/core/transformer/moe/router.py
@@ -247,7 +247,6 @@ def apply_load_balancing_loss(
             aux_loss / moe_aux_loss_coeff,
             self.layer_number,
             self.config.num_layers,
-            reduce_group=sequence_partition_group,
         )
         activation = MoEAuxLossAutoScaler.apply(activation, aux_loss)
         return activation
diff --git a/megatron/training/training.py b/megatron/training/training.py
index e84980f048b..b8385b6628d 100644
--- a/megatron/training/training.py
+++ b/megatron/training/training.py
@@ -1037,7 +1037,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
             )
     if args.num_experts is not None:
         moe_loss_scale = 1 / get_num_microbatches()
-        track_moe_metrics(moe_loss_scale, iteration, writer, wandb_writer, total_loss_dict, args.moe_per_layer_logging)
+        track_moe_metrics(moe_loss_scale, iteration, writer, wandb_writer, total_loss_dict, args.moe_per_layer_logging, args.moe_token_dispatcher_type)
 
     if iteration % args.log_interval == 0:
         elapsed_time = timers('interval-time').elapsed(barrier=True)
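
Below is a minimal standalone sketch (not part of the patch) of the reduction-group selection that the reworked reduce_aux_losses_tracker_across_ranks applies to load_balancing_loss. The helper name pick_load_balancing_reduce_group and the string stand-ins for process groups are hypothetical; in Megatron-LM the actual groups come from parallel_state.get_context_parallel_group() and parallel_state.get_tensor_and_context_parallel_group().

from typing import Optional


def pick_load_balancing_reduce_group(
    moe_token_dispatcher_type: str, tp_cp_world_size: int
) -> Optional[str]:
    """Mirror the patch's branch logic; return a stand-in group name or None."""
    if moe_token_dispatcher_type == "alltoall_seq":
        # alltoall_seq reduces over the context-parallel group.
        return "context_parallel_group"
    if tp_cp_world_size > 1:
        # Other dispatchers reduce over the combined TP x CP group when it spans >1 rank.
        return "tensor_and_context_parallel_group"
    # Single TP x CP rank: no additional all-reduce is needed.
    return None


if __name__ == "__main__":
    assert pick_load_balancing_reduce_group("alltoall_seq", 1) == "context_parallel_group"
    assert pick_load_balancing_reduce_group("allgather", 4) == "tensor_and_context_parallel_group"
    assert pick_load_balancing_reduce_group("allgather", 1) is None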