99from typing import TYPE_CHECKING , Dict , List , Optional , Union
1010
1111from twinkle .data_format import LossOutput
12- from twinkle .utils .torch_utils import selective_log_softmax
1312from twinkle .loss .base import Loss
13+ from twinkle .utils .torch_utils import selective_log_softmax
1414
1515if TYPE_CHECKING :
1616 import torch
@@ -176,14 +176,10 @@ def _align_logps(
176176 # Truncate right (keep left part) - may happen in Ray result merging
177177 return logps [:, :target_seq_len ]
178178 else :
179- raise ValueError (
180- f'ref_logps seq_len ({ src_seq_len } ) < target seq_len ({ target_seq_len } ). '
181- f'This should not happen when both models process the same batch.'
182- )
179+ raise ValueError (f'ref_logps seq_len ({ src_seq_len } ) < target seq_len ({ target_seq_len } ). '
180+ f'This should not happen when both models process the same batch.' )
183181
184- raise ValueError (
185- f'Cannot align ref_logps shape { logps .shape } to target shape { target_shape } '
186- )
182+ raise ValueError (f'Cannot align ref_logps shape { logps .shape } to target shape { target_shape } ' )
187183
188184 def _compute_dpo_loss (
189185 self ,
@@ -227,7 +223,7 @@ def _compute_dpo_loss(
227223 elif self .loss_type == 'ipo' :
228224 # IPO (Identity Preference Optimization) loss
229225 # Reference: "A General Theoretical Paradigm to Understand Learning from Human Feedback"
230- losses = (logits - 1 / (2 * self .beta )) ** 2
226+ losses = (logits - 1 / (2 * self .beta ))** 2
231227 elif self .loss_type == 'kto_pair' :
232228 # KTO pair loss (simplified version)
233229 chosen_logratios_scaled = self .beta * chosen_logratios
@@ -236,7 +232,7 @@ def _compute_dpo_loss(
236232 rejected_losses = F .sigmoid (rejected_logratios_scaled )
237233 losses = chosen_losses + rejected_losses
238234 else :
239- raise ValueError (f" Unknown loss_type: { self .loss_type } " )
235+ raise ValueError (f' Unknown loss_type: { self .loss_type } ' )
240236
241237 # Apply label smoothing if specified
242238 if self .label_smoothing > 0 :
@@ -292,7 +288,7 @@ def __call__(
292288 labels = labels .unsqueeze (0 )
293289
294290 batch_size = labels .shape [0 ]
295- assert batch_size % 2 == 0 , " Batch size must be even (chosen + rejected pairs)"
291+ assert batch_size % 2 == 0 , ' Batch size must be even (chosen + rejected pairs)'
296292
297293 # Get log probabilities from outputs
298294 logps = self ._get_logps_from_outputs (outputs , labels )
@@ -314,9 +310,7 @@ def __call__(
314310 reference_rejected_logps = ref_rejected_logps .to (device = device , dtype = dtype )
315311 elif ref_logps is not None :
316312 # Per-token reference log probs provided, need to align and sum
317- ref_logps_aligned = self ._align_logps (
318- ref_logps , labels .shape , device , dtype
319- )
313+ ref_logps_aligned = self ._align_logps (ref_logps , labels .shape , device , dtype )
320314 ref_chosen , ref_rejected = self ._split_chosen_rejected (ref_logps_aligned )
321315 reference_chosen_logps = self ._compute_sequence_logps (ref_chosen , chosen_labels )
322316 reference_rejected_logps = self ._compute_sequence_logps (ref_rejected , rejected_labels )
@@ -392,7 +386,7 @@ def __call__(
392386 if labels .dim () == 1 :
393387 labels = labels .unsqueeze (0 )
394388
395- assert labels .shape [0 ] % 2 == 0 , " Batch size must be even (chosen + rejected pairs)"
389+ assert labels .shape [0 ] % 2 == 0 , ' Batch size must be even (chosen + rejected pairs)'
396390
397391 # Get log probabilities
398392 logps = self ._get_logps_from_outputs (outputs , labels )
@@ -455,7 +449,7 @@ def __call__(
455449 if labels .dim () == 1 :
456450 labels = labels .unsqueeze (0 )
457451
458- assert labels .shape [0 ] % 2 == 0 , " Batch size must be even"
452+ assert labels .shape [0 ] % 2 == 0 , ' Batch size must be even'
459453
460454 # Get log probabilities
461455 logps = self ._get_logps_from_outputs (outputs , labels )
@@ -521,7 +515,7 @@ def __call__(
521515 if labels .dim () == 1 :
522516 labels = labels .unsqueeze (0 )
523517
524- assert labels .shape [0 ] % 2 == 0 , " Batch size must be even"
518+ assert labels .shape [0 ] % 2 == 0 , ' Batch size must be even'
525519
526520 # Get log probabilities
527521 logps = self ._get_logps_from_outputs (outputs , labels )
@@ -540,8 +534,8 @@ def __call__(
540534 # Odds ratio: log(odds_chosen / odds_rejected)
541535 # log_odds = log(p/(1-p)) = log(p) - log(1-p)
542536 # Use numerically stable computation
543- prob_chosen = torch .exp (chosen_avg_logps ).clamp (min = 1e-7 , max = 1 - 1e-7 )
544- prob_rejected = torch .exp (rejected_avg_logps ).clamp (min = 1e-7 , max = 1 - 1e-7 )
537+ prob_chosen = torch .exp (chosen_avg_logps ).clamp (min = 1e-7 , max = 1 - 1e-7 )
538+ prob_rejected = torch .exp (rejected_avg_logps ).clamp (min = 1e-7 , max = 1 - 1e-7 )
545539 log_odds_chosen = torch .log (prob_chosen ) - torch .log (1 - prob_chosen )
546540 log_odds_rejected = torch .log (prob_rejected ) - torch .log (1 - prob_rejected )
547541
0 commit comments