Skip to content

Commit f9bd261

Browse files
HIT-cwhCyCle1024
authored and committed
fix noaux_router balance loss bug
1 parent 8bbbef7 commit f9bd261

1 file changed

Lines changed: 8 additions & 0 deletions

File tree

xtuner/v1/module/router/noaux_router.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ def forward(self, logits) -> RouterResults:
111111
_, topk_idx = torch.topk(scores_for_choice, k=self.top_k, dim=-1)
112112
topk_weight = scores.gather(1, topk_idx)
113113

114+
# The returned `router_weights` is only used for computing balance loss
115+
# It should be normalized
116+
scores_for_choice = scores_for_choice / torch.sum(scores_for_choice, dim=-1, keepdim=True)
117+
114118
if self.top_k > 1 and self.norm_topk_prob:
115119
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
116120
topk_weight = topk_weight / denominator
@@ -169,6 +173,10 @@ def forward(self, logits) -> RouterResults:
169173
topk_weight = scores_for_choice.gather(1, topk_idx) # [seq, n_groups]
170174
scores_for_choice = scores_for_choice.view(seq, self.n_routed_experts)
171175

176+
# The returned `router_weights` is only used for computing balance loss
177+
# It should be normalized
178+
scores_for_choice = scores_for_choice / torch.sum(scores_for_choice, dim=-1, keepdim=True)
179+
172180
if self.top_k > 1 and self.norm_topk_prob:
173181
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
174182
topk_weight = topk_weight / denominator

0 commit comments

Comments
 (0)