From 9ef958cc0f69a7649077ad1ca4b2625516313ba9 Mon Sep 17 00:00:00 2001 From: Anshita Date: Thu, 30 Mar 2023 02:30:31 -0400 Subject: [PATCH] Fix the same device issue --- climatenet/utils/losses.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/climatenet/utils/losses.py b/climatenet/utils/losses.py index 42e0389..bd96652 100644 --- a/climatenet/utils/losses.py +++ b/climatenet/utils/losses.py @@ -18,7 +18,8 @@ def jaccard_loss(logits, true, eps=1e-7): jacc_loss: the Jaccard loss. """ num_classes = logits.shape[1] - true_1_hot = torch.eye(num_classes)[true.squeeze(1)] + # Keep on same device + true_1_hot = torch.eye(num_classes).to(true.device)[true.squeeze(1)] true_1_hot = true_1_hot.permute(0, 3, 1, 2).float() probas = F.softmax(logits, dim=1) true_1_hot = true_1_hot.type(logits.type())