33 changes: 23 additions & 10 deletions megatron/core/transformer/moe/moe_utils.py
@@ -542,7 +542,6 @@ def save_to_aux_losses_tracker(
     loss: torch.Tensor,
     layer_number: int,
     num_layers: int,
-    reduce_group: torch.distributed.ProcessGroup = None,
     avg_group: torch.distributed.ProcessGroup = None,
 ):
     """Save the auxiliary loss for logging.
@@ -551,8 +550,7 @@ def save_to_aux_losses_tracker(
         loss (torch.Tensor): The loss tensor.
         layer_number (int): Layer index of the loss.
         num_layers (int): The number of total layers.
-        reduce_group (torch.distributed.ProcessGroup): The group for reducing the loss.
-        mean_group (torch.distributed.ProcessGroup): The group for averaging the loss.
+        avg_group (torch.distributed.ProcessGroup): The group for averaging the loss.
     """
     # Skip aux loss logging if layer_number is None.
     if layer_number is None:
@@ -563,7 +561,6 @@ def save_to_aux_losses_tracker(
         tracker[name] = {}
         tracker[name]["values"] = torch.zeros(num_layers, device=loss.device)
     tracker[name]["values"][layer_number - 1] += loss.detach()  # Aggregate the loss for the layer.
-    tracker[name]["reduce_group"] = reduce_group
     tracker[name]["avg_group"] = avg_group
 
 
@@ -576,8 +573,11 @@ def clear_aux_losses_tracker():
tracker[name]["avg_group"] = None


def reduce_aux_losses_tracker_across_ranks():
"""Collect and reduce the auxiliary losses across ranks."""
def reduce_aux_losses_tracker_across_ranks(moe_token_dispatcher_type: str) -> None:
"""Collect and reduce the auxiliary losses across ranks.
Args:
moe_token_dispatcher_type: The type of token dispatcher to use
"""
tracker = parallel_state.get_moe_layer_wise_logging_tracker()
for name in tracker:
values = tracker[name]["values"]
@@ -586,20 +586,33 @@ def reduce_aux_losses_tracker_across_ranks():
             values, group=parallel_state.get_pipeline_model_parallel_group()
         )
         # Reduce aux losses across ranks.
-        if tracker[name].get('reduce_group') is not None:
-            torch.distributed.all_reduce(values, group=tracker[name].get('reduce_group'))
+        if name == "load_balancing_loss":
+            reduce_group = None
+            if moe_token_dispatcher_type == "alltoall_seq":
+                reduce_group = parallel_state.get_context_parallel_group()
+            elif parallel_state.get_tensor_and_context_parallel_world_size() > 1:
+                reduce_group = parallel_state.get_tensor_and_context_parallel_group()
+
+            if reduce_group:
+                torch.distributed.all_reduce(values, group=reduce_group)
+        elif name == "z_loss":
+            pass
+        else:
+            raise ValueError(
+                f"Unknown reduction method for metric {name}. Please specify how the reduction needs to take place"
+            )
         if tracker[name].get('avg_group') is not None:
             torch.distributed.all_reduce(
                 values, group=tracker[name]['avg_group'], op=torch.distributed.ReduceOp.AVG
             )
 
 
 def track_moe_metrics(
-    loss_scale, iteration, writer, wandb_writer=None, total_loss_dict=None, per_layer_logging=False
+    loss_scale, iteration, writer, wandb_writer=None, total_loss_dict=None, per_layer_logging=False, moe_token_dispatcher_type="allgather"
 ):
     """Track the MoE metrics for logging."""
     # Aux loss logging
-    reduce_aux_losses_tracker_across_ranks()
+    reduce_aux_losses_tracker_across_ranks(moe_token_dispatcher_type)
     tracker = parallel_state.get_moe_layer_wise_logging_tracker()
     if writer is not None:
         aux_losses = {k: v['values'].float() * loss_scale for k, v in tracker.items()}
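A minimal, self-contained sketch of the new reduce-group selection for the load-balancing loss; the stub arguments stand in for Megatron's parallel_state groups, and only the branching mirrors the hunk above:

def pick_reduce_group(moe_token_dispatcher_type, tp_cp_world_size, cp_group, tp_cp_group):
    # Mirrors the branching added above; the group objects are opaque here.
    if moe_token_dispatcher_type == "alltoall_seq":
        # alltoall_seq splits the sequence across context-parallel ranks,
        # so the per-layer loss must be summed over the CP group.
        return cp_group
    if tp_cp_world_size > 1:
        # Other dispatchers see the router input replicated across TP x CP,
        # so the sum runs over the combined tensor-and-context group.
        return tp_cp_group
    return None  # a single rank holds everything; no all_reduce needed

# With alltoall_seq the CP group is chosen regardless of world size:
assert pick_reduce_group("alltoall_seq", 1, "cp", "tp_cp") == "cp"
assert pick_reduce_group("allgather", 8, "cp", "tp_cp") == "tp_cp"
assert pick_reduce_group("allgather", 1, "cp", "tp_cp") is None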
1 change: 0 additions & 1 deletion megatron/core/transformer/moe/router.py
@@ -247,7 +247,6 @@ def apply_load_balancing_loss(
             aux_loss / moe_aux_loss_coeff,
             self.layer_number,
             self.config.num_layers,
-            reduce_group=sequence_partition_group,
         )
         activation = MoEAuxLossAutoScaler.apply(activation, aux_loss)
         return activation
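For context, the router's call site after this deletion looks like the sketch below; the tracker function is stubbed so the snippet runs standalone, and the loss value is illustrative:

import torch

def save_to_aux_losses_tracker(name, loss, layer_number, num_layers, avg_group=None):
    # Stub matching the trimmed signature from moe_utils.py above.
    print(f"{name}: layer {layer_number}/{num_layers} -> {loss.item():.4f}")

# No reduce_group kwarg anymore: which ranks to sum over is decided
# centrally in reduce_aux_losses_tracker_across_ranks(), from the
# dispatcher type, rather than per call site.
save_to_aux_losses_tracker(
    "load_balancing_loss",
    torch.tensor(0.02),  # illustrative scaled aux loss
    layer_number=1,
    num_layers=24,
)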
2 changes: 1 addition & 1 deletion megatron/training/training.py
@@ -1037,7 +1037,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
        )
        if args.num_experts is not None:
            moe_loss_scale = 1 / get_num_microbatches()
-            track_moe_metrics(moe_loss_scale, iteration, writer, wandb_writer, total_loss_dict, args.moe_per_layer_logging)
+            track_moe_metrics(moe_loss_scale, iteration, writer, wandb_writer, total_loss_dict, args.moe_per_layer_logging, args.moe_token_dispatcher_type)
 
     if iteration % args.log_interval == 0:
         elapsed_time = timers('interval-time').elapsed(barrier=True)
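Since the new parameter is keyword-defaulted to "allgather", call sites that predate this change keep working; a small self-contained check of that claim (the stub below only echoes the signature, it is not the real implementation):

def track_moe_metrics_stub(loss_scale, iteration, writer, wandb_writer=None,
                           total_loss_dict=None, per_layer_logging=False,
                           moe_token_dispatcher_type="allgather"):
    # Echoes the dispatcher type so the default can be observed.
    return moe_token_dispatcher_type

# Pre-PR style call (no dispatcher argument) still resolves:
assert track_moe_metrics_stub(1.0, 0, None) == "allgather"
# training.py now threads args.moe_token_dispatcher_type through explicitly:
assert track_moe_metrics_stub(1.0, 0, None, None, None, False, "alltoall_seq") == "alltoall_seq"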