From 91b0d753ac1ce6e6805ab8e516af85d5524ee833 Mon Sep 17 00:00:00 2001 From: addsubmuldiv Date: Sat, 14 Mar 2026 15:30:00 +0800 Subject: [PATCH] Fix tinker loss device mismatch --- src/twinkle/server/tinker/common/compat_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/twinkle/server/tinker/common/compat_base.py b/src/twinkle/server/tinker/common/compat_base.py index 62e22ff6..0c4e5f2f 100644 --- a/src/twinkle/server/tinker/common/compat_base.py +++ b/src/twinkle/server/tinker/common/compat_base.py @@ -142,6 +142,7 @@ def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor, logps: token_log_probs = logps[idx, :seq_len] # elementwise_loss: positive NLL loss (0.0 where masked) + token_log_probs = token_log_probs.to(weights.device) elementwise_loss = -token_log_probs * weights results.append({