From 1de40ab48ead6d36d290cc864ed9cf56ae15eb78 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 2 Mar 2026 13:56:24 +0800 Subject: [PATCH] fix grpo loss --- src/twinkle/loss/grpo.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/twinkle/loss/grpo.py b/src/twinkle/loss/grpo.py index c0ff950a..f07a26c8 100644 --- a/src/twinkle/loss/grpo.py +++ b/src/twinkle/loss/grpo.py @@ -290,14 +290,13 @@ def __call__( labels = labels.unsqueeze(0) logps = outputs.get('logps') + loss_mask = (labels != self.ignore_index).bool() if logps is None: logits = outputs.get('logits') if logits.shape[1] != labels.shape[1]: # some mllm return logits with image tokens, exclude here logits = logits[:, -labels.shape[1]:] - # labels = torch.roll(labels, shifts=-1, dims=1) - loss_mask = (labels != self.ignore_index).bool() masked_labels = labels.clone() masked_labels[~loss_mask] = 0 logps = selective_log_softmax(logits, masked_labels)