Skip to content

Commit 295a2be

Browse files
authored
[bugfix] fix grpo loss (#93)
1 parent 6b48d61 commit 295a2be

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

src/twinkle/loss/grpo.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -290,14 +290,13 @@ def __call__(
290 | 290 |   labels = labels.unsqueeze(0)
291 | 291 |
292 | 292 |   logps = outputs.get('logps')
    | 293 | + loss_mask = (labels != self.ignore_index).bool()
293 | 294 |   if logps is None:
294 | 295 |   logits = outputs.get('logits')
295 | 296 |   if logits.shape[1] != labels.shape[1]:
296 | 297 |   # some mllm return logits with image tokens, exclude here
297 | 298 |   logits = logits[:, -labels.shape[1]:]
298 |     | -
299 | 299 |   # labels = torch.roll(labels, shifts=-1, dims=1)
300 |     | - loss_mask = (labels != self.ignore_index).bool()
301 | 300 |   masked_labels = labels.clone()
302 | 301 |   masked_labels[~loss_mask] = 0
303 | 302 |   logps = selective_log_softmax(logits, masked_labels)

0 commit comments

Comments
 (0)