@@ -88,9 +88,13 @@ def _split_chosen_rejected(
8888 self ,
8989 tensor : 'torch.Tensor' ,
9090 ) -> tuple :
91- """Split tensor into chosen (first half) and rejected (second half)."""
92- half = tensor .shape [0 ] // 2
93- return tensor [:half ], tensor [half :]
91+ """Split interleaved tensor into chosen and rejected.
92+
93+ Input format: [pos_1, neg_1, pos_2, neg_2, ...] (interleaved for DP-safe slicing)
94+ Output: (chosen [pos_1, pos_2, ...], rejected [neg_1, neg_2, ...])
95+ """
96+ # Even indices = chosen (positive), odd indices = rejected (negative)
97+ return tensor [0 ::2 ], tensor [1 ::2 ]
9498
9599
96100class DPOLoss (PreferenceLossBase ):
@@ -131,20 +135,18 @@ def __init__(
131135 self .loss_type = loss_type
132136 self .reference_free = reference_free
133137
134- def _pad_and_align_logps (
138+ def _align_logps (
135139 self ,
136- logps : Union [ 'torch.Tensor' , List [ List [ float ]]] ,
140+ logps : 'torch.Tensor' ,
137141 target_shape : tuple ,
138- loss_mask : 'torch.Tensor' ,
139142 device : 'torch.device' ,
140143 dtype : 'torch.dtype' ,
141144 ) -> 'torch.Tensor' :
142- """Pad and align log probabilities to target shape.
145+ """Align log probabilities to target shape.
143146
144147 Args:
145- logps: Input log probabilities ( tensor or ragged list)
148+ logps: Input log probabilities tensor
146149 target_shape: Target (batch, seq_len) shape
147- loss_mask: Boolean mask for valid positions
148150 device: Target device
149151 dtype: Target dtype
150152
@@ -153,40 +155,32 @@ def _pad_and_align_logps(
153155 """
154156 import torch
155157
156- if torch .is_tensor (logps ):
157- if logps .dim () == 1 :
158- logps = logps .unsqueeze (0 )
159- if logps .shape == target_shape :
160- return logps .to (device = device , dtype = dtype )
161- # Handle tensor with different sequence length - align to target shape
162- if logps .dim () == 2 and logps .shape [0 ] == target_shape [0 ]:
163- batch_size , target_seq_len = target_shape
164- src_seq_len = logps .shape [1 ]
165- logps = logps .to (device = device , dtype = dtype )
166- if src_seq_len > target_seq_len :
167- # Truncate: take the last target_seq_len tokens (response part)
168- return logps [:, - target_seq_len :]
169- else :
170- # Pad: add zeros at the beginning
171- padded = torch .zeros (target_shape , device = device , dtype = dtype )
172- padded [:, - src_seq_len :] = logps
173- return padded
174-
175- # Handle ragged list input
176- if isinstance (logps , (list , tuple )):
177- batch_size , seq_len = target_shape
178- padded = torch .zeros (target_shape , device = device , dtype = dtype )
179- for i , row in enumerate (logps ):
180- if row is None :
181- continue
182- row_t = torch .as_tensor (row , device = device , dtype = dtype )
183- valid_positions = loss_mask [i ].nonzero (as_tuple = True )[0 ]
184- length = min (len (row_t ), len (valid_positions ))
185- if length > 0 :
186- padded [i , valid_positions [:length ]] = row_t [:length ]
187- return padded
188-
189- return logps .to (device = device , dtype = dtype )
158+ if not torch .is_tensor (logps ):
159+ raise TypeError (f'Expected torch.Tensor, got { type (logps )} ' )
160+
161+ if logps .dim () == 1 :
162+ logps = logps .unsqueeze (0 )
163+
164+ if logps .shape == target_shape :
165+ return logps .to (device = device , dtype = dtype )
166+
167+ # Handle tensor with different sequence length
168+ if logps .dim () == 2 and logps .shape [0 ] == target_shape [0 ]:
169+ batch_size , target_seq_len = target_shape
170+ src_seq_len = logps .shape [1 ]
171+ logps = logps .to (device = device , dtype = dtype )
172+ if src_seq_len > target_seq_len :
173+ # Truncate right (keep left part) - may happen in Ray result merging
174+ return logps [:, :target_seq_len ]
175+ else :
176+ raise ValueError (
177+ f'ref_logps seq_len ({ src_seq_len } ) < target seq_len ({ target_seq_len } ). '
178+ f'This should not happen when both models process the same batch.'
179+ )
180+
181+ raise ValueError (
182+ f'Cannot align ref_logps shape { logps .shape } to target shape { target_shape } '
183+ )
190184
191185 def _compute_dpo_loss (
192186 self ,
@@ -254,6 +248,7 @@ def __call__(
254248 inputs : Dict ,
255249 outputs : Dict ,
256250 * ,
251+ ref_outputs : Optional [Dict ] = None ,
257252 ref_logps : Optional [Union ['torch.Tensor' , List [List [float ]]]] = None ,
258253 ref_chosen_logps : Optional ['torch.Tensor' ] = None ,
259254 ref_rejected_logps : Optional ['torch.Tensor' ] = None ,
@@ -271,6 +266,7 @@ def __call__(
271266 outputs: Dict containing either:
272267 - 'logps': [batch, seq_len] pre-computed log probs, OR
273268 - 'logits': [batch, seq_len, vocab] from which logps will be computed
269+ ref_outputs: Dict from reference model forward, containing 'logps'.
274270 ref_logps: [batch, seq_len] or List[List[float]] reference model log probs.
275271 Can also be provided as separate ref_chosen_logps and ref_rejected_logps.
276272 ref_chosen_logps: [batch/2] pre-computed reference log probs for chosen.
@@ -282,6 +278,10 @@ def __call__(
282278 """
283279 import torch
284280
281+ # Extract ref_logps from ref_outputs if provided
282+ if ref_outputs is not None and ref_logps is None :
283+ ref_logps = ref_outputs .get ('logps' )
284+
285285 labels = inputs .get ('labels' )
286286 assert labels is not None , "inputs must contain 'labels'"
287287 if not torch .is_tensor (labels ):
@@ -312,9 +312,8 @@ def __call__(
312312 reference_rejected_logps = ref_rejected_logps .to (device = device , dtype = dtype )
313313 elif ref_logps is not None :
314314 # Per-token reference log probs provided, need to align and sum
315- loss_mask = (labels != self .ignore_index ).bool ()
316- ref_logps_aligned = self ._pad_and_align_logps (
317- ref_logps , labels .shape , loss_mask , device , dtype
315+ ref_logps_aligned = self ._align_logps (
316+ ref_logps , labels .shape , device , dtype
318317 )
319318 ref_chosen , ref_rejected = self ._split_chosen_rejected (ref_logps_aligned )
320319 reference_chosen_logps = self ._compute_sequence_logps (ref_chosen , chosen_labels )