Skip to content

Commit 7743eeb

Browse files
authored
Fix tinker loss device mismatch (#115)
1 parent d69a864 commit 7743eeb

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

src/twinkle/server/tinker/common/compat_base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor, logps:
142142
token_log_probs = logps[idx, :seq_len]
143143

144144
# elementwise_loss: positive NLL loss (0.0 where masked)
145+
token_log_probs = token_log_probs.to(weights.device)
145146
elementwise_loss = -token_log_probs * weights
146147

147148
results.append({

0 commit comments

Comments
 (0)