@@ -111,6 +111,10 @@ def forward(self, logits) -> RouterResults:
111111 _ , topk_idx = torch .topk (scores_for_choice , k = self .top_k , dim = - 1 )
112112 topk_weight = scores .gather (1 , topk_idx )
113113
114+ # The returned `router_weights` is only used for computing the balance loss,
115+ # so the scores are normalized to sum to 1 along the last dimension.
116+ scores_for_choice = scores_for_choice / torch .sum (scores_for_choice , dim = - 1 , keepdim = True )
117+
114118 if self .top_k > 1 and self .norm_topk_prob :
115119 denominator = topk_weight .sum (dim = - 1 , keepdim = True ) + 1e-20
116120 topk_weight = topk_weight / denominator
@@ -169,6 +173,10 @@ def forward(self, logits) -> RouterResults:
169173 topk_weight = scores_for_choice .gather (1 , topk_idx ) # [seq, n_groups]
170174 scores_for_choice = scores_for_choice .view (seq , self .n_routed_experts )
171175
176+ # The returned `router_weights` is only used for computing the balance loss,
177+ # so the scores are normalized to sum to 1 along the last dimension.
178+ scores_for_choice = scores_for_choice / torch .sum (scores_for_choice , dim = - 1 , keepdim = True )
179+
172180 if self .top_k > 1 and self .norm_topk_prob :
173181 denominator = topk_weight .sum (dim = - 1 , keepdim = True ) + 1e-20
174182 topk_weight = topk_weight / denominator
0 commit comments