diff --git a/src/twinkle/loss/grpo.py b/src/twinkle/loss/grpo.py index c0ff950a..f07a26c8 100644 --- a/src/twinkle/loss/grpo.py +++ b/src/twinkle/loss/grpo.py @@ -290,14 +290,13 @@ def __call__( labels = labels.unsqueeze(0) logps = outputs.get('logps') + loss_mask = (labels != self.ignore_index).bool() if logps is None: logits = outputs.get('logits') if logits.shape[1] != labels.shape[1]: # some mllm return logits with image tokens, exclude here logits = logits[:, -labels.shape[1]:] - # labels = torch.roll(labels, shifts=-1, dims=1) - loss_mask = (labels != self.ignore_index).bool() masked_labels = labels.clone() masked_labels[~loss_mask] = 0 logps = selective_log_softmax(logits, masked_labels)