We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 6b48d61 commit 295a2beCopy full SHA for 295a2be
src/twinkle/loss/grpo.py
@@ -290,14 +290,13 @@ def __call__(
290
labels = labels.unsqueeze(0)
291
292
logps = outputs.get('logps')
293
+ loss_mask = (labels != self.ignore_index).bool()
294
if logps is None:
295
logits = outputs.get('logits')
296
if logits.shape[1] != labels.shape[1]:
297
# some mllm return logits with image tokens, exclude here
298
logits = logits[:, -labels.shape[1]:]
-
299
# labels = torch.roll(labels, shifts=-1, dims=1)
300
- loss_mask = (labels != self.ignore_index).bool()
301
masked_labels = labels.clone()
302
masked_labels[~loss_mask] = 0
303
logps = selective_log_softmax(logits, masked_labels)
0 commit comments