diff --git a/src/twinkle/server/tinker/common/compat_base.py b/src/twinkle/server/tinker/common/compat_base.py index 62e22ff6..0c4e5f2f 100644 --- a/src/twinkle/server/tinker/common/compat_base.py +++ b/src/twinkle/server/tinker/common/compat_base.py @@ -142,6 +142,7 @@ def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor, logps: token_log_probs = logps[idx, :seq_len] # elementwise_loss: positive NLL loss (0.0 where masked) + token_log_probs = token_log_probs.to(weights.device) elementwise_loss = -token_log_probs * weights results.append({