diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 98bc8912292..b4146d99576 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -327,7 +327,7 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): local_expert_indices_offset = ep_rank * self.num_local_experts prepend_axis_num = len(sharded_offsets) - replica_id = (0, 0, dp_rank) + replica_id = (0, tp_rank, dp_rank) local_ffn_dim_size = ( self.weight2.numel() // self.num_local_experts // self.config.hidden_size