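I think `p_s` should be computed with `F.softmax` instead of `F.log_softmax`. Otherwise `p_t` and `p_s` are not comparable: one is a log-probability and the other is a probability.
Could you please explain why `log_softmax` is applied to the student output only?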
RepDistiller/distiller_zoo/KD.py
Lines 13 to 17 in dcc0432
```python
def forward(self, y_s, y_t):
    p_s = F.log_softmax(y_s/self.T, dim=1)
    p_t = F.softmax(y_t/self.T, dim=1)
    loss = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / y_s.shape[0]
    return loss
```
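For context, a likely explanation is that `F.kl_div` is documented to take its first argument (the input) as log-probabilities and, by default, its target as plain probabilities. The asymmetry may therefore be intentional. Below is a minimal pure-Python sketch of that pointwise formula, `target * (log(target) - input)`. The helper names `softmax`, `log_softmax`, and `kl_div_pointwise_sum` and the example logits are made up for illustration and are not part of RepDistiller or PyTorch.

```python
import math

def softmax(logits, T=1.0):
    # Temperature-scaled softmax, stabilised by subtracting the max.
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def log_softmax(logits, T=1.0):
    # Log of the temperature-scaled softmax.
    return [math.log(p) for p in softmax(logits, T)]

def kl_div_pointwise_sum(inp, target):
    # Hypothetical helper mirroring the documented pointwise form of
    # F.kl_div with a non-log target: target * (log(target) - input), summed.
    return sum(t * (math.log(t) - x) for x, t in zip(inp, target) if t > 0)

T = 4.0
y_t = [2.0, 1.0, 0.1]   # illustrative teacher logits (assumption)
y_s = [1.5, 0.5, 0.3]   # illustrative student logits (assumption)

p_t = softmax(y_t, T)

# Sanity check with student == teacher: a true KL divergence must be zero.
same_log = kl_div_pointwise_sum(log_softmax(y_t, T), p_t)   # log-prob input
same_prob = kl_div_pointwise_sum(softmax(y_t, T), p_t)      # prob input

# Student != teacher, log-prob input: a positive KL(p_t || p_s).
kl = kl_div_pointwise_sum(log_softmax(y_s, T), p_t)

print(same_log, same_prob, kl)
```

Under this formula, passing log-probabilities for the student gives zero when the two distributions match, while passing plain probabilities gives a negative value that is not a KL divergence. If that reading of `F.kl_div` is right, `log_softmax` for `p_s` and `softmax` for `p_t` is the expected pairing.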