Skip to content

Commit 3a25caa

Browse files
committed
wip
1 parent d1f223f commit 3a25caa

File tree

4 files changed

+124
-173
lines changed

4 files changed

+124
-173
lines changed

cookbook/rl/dpo.py

Lines changed: 31 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151
import os
5252
from typing import Any, Dict, List, Optional
5353

54-
import torch
5554
from peft import LoraConfig
5655

5756
import twinkle
@@ -63,7 +62,6 @@
6362
from twinkle.model import TransformersModel
6463
from twinkle.preprocessor import EmojiDPOProcessor
6564
from twinkle.processor import InputProcessor
66-
from twinkle.template import Template
6765

6866
logger = get_logger()
6967

@@ -75,8 +73,8 @@
7573
REF_MODEL_GPUS = int(os.environ.get('REF_MODEL_GPUS', 4))
7674
NUM_GPUS = MODEL_GPUS + REF_MODEL_GPUS
7775

78-
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs
79-
MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2))
76+
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4)) # Number of preference pairs
77+
MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 4))
8078
GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
8179
MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000))
8280
LEARNING_RATE = float(os.environ.get('LR', 5e-6))
@@ -100,31 +98,38 @@ def create_dpo_dataset():
10098
)
10199
# DPO preprocessor returns {'positive': [...], 'negative': [...]}
102100
# batch_encode handles this format automatically
103-
dataset.encode()
101+
dataset.encode(load_from_cache_file=True)
104102
return dataset
105103

106104

107-
def prepare_dpo_batch(
108-
batch: Dict[str, List[Any]],
109-
template: Template,
110-
) -> List[Dict[str, Any]]:
111-
"""Prepare DPO batch: convert encoded batch to list format for training.
105+
def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
106+
"""Prepare DPO batch: reorganize batch for training with DP-safe interleaving.
112107
113108
Args:
114-
batch: Dict with 'positive' and 'negative' keys, each containing List[InputFeature]
109+
batch: List of rows, each with 'positive' and 'negative' InputFeatures
110+
and other fields (question, etc.)
115111
116112
Returns:
117-
List organized as [positive_1, ..., positive_n, negative_1, ..., negative_n]
113+
List interleaved as [pos_1, neg_1, pos_2, neg_2, ...] to ensure each DP
114+
worker gets complete positive/negative pairs after slicing.
115+
Each item contains all original fields plus the InputFeature fields.
118116
"""
119-
positive_features = batch.get('positive', [])
120-
negative_features = batch.get('negative', [])
117+
result = []
121118

122-
# Convert to list of dicts
123-
positive_samples = [dict(f) for f in positive_features]
124-
negative_samples = [dict(f) for f in negative_features]
119+
for row in batch:
120+
# Get base fields (excluding positive/negative)
121+
base_fields = {k: v for k, v in row.items() if k not in ('positive', 'negative')}
125122

126-
# Return [positive..., negative...]
127-
return positive_samples + negative_samples
123+
# Positive sample: merge base fields with positive InputFeature
124+
pos_sample = {**base_fields, **row['positive']}
125+
# Negative sample: merge base fields with negative InputFeature
126+
neg_sample = {**base_fields, **row['negative']}
127+
128+
# Interleave: [pos, neg] per pair for DP-safe slicing
129+
result.append(pos_sample)
130+
result.append(neg_sample)
131+
132+
return result
128133

129134

130135
# ── Loss Factory ──────────────────────────────────────────────────────────────
@@ -196,9 +201,6 @@ def main():
196201
policy_model.set_processor(InputProcessor)
197202
policy_model.set_template('Template', model_id=MODEL_ID)
198203

199-
# Get template for encoding rejected messages
200-
template = Template(model_id=MODEL_ID, max_length=MAX_LENGTH)
201-
202204
# ── Reference Model Setup ─────────────────────────────────────────────────
203205
ref_model = None
204206
if not reference_free:
@@ -223,50 +225,19 @@ def main():
223225
if optim_step >= MAX_STEPS:
224226
break
225227

226-
# batch is Dict[str, List[Trajectory]] with 'positive' and 'negative' keys
227-
dpo_batch = prepare_dpo_batch(batch, template)
228+
# batch is List[Dict] with 'positive' and 'negative' keys
229+
dpo_batch = prepare_dpo_batch(batch)
228230

229-
# Compute reference log probabilities if using reference model
230-
# We compute sequence-level logps here to avoid alignment issues with micro-batching
231-
ref_chosen_logps = None
232-
ref_rejected_logps = None
231+
# Get reference outputs (lazy - not collected to driver)
232+
ref_outputs = None
233233
if ref_model is not None:
234-
with torch.no_grad():
235-
ref_outputs = ref_model.forward_only(inputs=dpo_batch)
236-
ref_logps = ref_outputs.get('logps') # [batch, seq_len]
237-
if ref_logps is not None:
238-
# Get labels and pad to same length for stacking
239-
label_tensors = [torch.as_tensor(s['labels']) for s in dpo_batch]
240-
max_len = max(t.shape[0] for t in label_tensors)
241-
# Pad labels with -100 (ignore_index) to max length
242-
padded_labels = []
243-
for t in label_tensors:
244-
if t.shape[0] < max_len:
245-
pad_size = max_len - t.shape[0]
246-
t = torch.cat([torch.full((pad_size,), -100, dtype=t.dtype), t])
247-
padded_labels.append(t)
248-
ref_labels = torch.stack(padded_labels)
249-
if ref_labels.device != ref_logps.device:
250-
ref_labels = ref_labels.to(ref_logps.device)
251-
# Align sequence lengths if needed
252-
if ref_logps.shape[1] != ref_labels.shape[1]:
253-
min_len = min(ref_logps.shape[1], ref_labels.shape[1])
254-
ref_logps = ref_logps[:, -min_len:]
255-
ref_labels = ref_labels[:, -min_len:]
256-
# Compute sequence-level logps (sum of valid token logps)
257-
loss_mask = (ref_labels != -100).float()
258-
seq_logps = (ref_logps * loss_mask).sum(dim=-1) # [batch]
259-
260-
# Split into chosen and rejected
261-
half = seq_logps.shape[0] // 2
262-
ref_chosen_logps = seq_logps[:half]
263-
ref_rejected_logps = seq_logps[half:]
234+
ref_outputs = ref_model.forward_only(inputs=dpo_batch)
264235

265236
# Forward-backward pass with DPO loss
237+
# ref_outputs is passed to loss which extracts logps internally
266238
policy_model.forward_backward(
267239
inputs=dpo_batch,
268-
ref_chosen_logps=ref_chosen_logps,
269-
ref_rejected_logps=ref_rejected_logps,
240+
ref_outputs=ref_outputs,
270241
)
271242

272243
# Gradient clipping and optimizer step

src/twinkle/loss/dpo.py

Lines changed: 45 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,13 @@ def _split_chosen_rejected(
8888
self,
8989
tensor: 'torch.Tensor',
9090
) -> tuple:
91-
"""Split tensor into chosen (first half) and rejected (second half)."""
92-
half = tensor.shape[0] // 2
93-
return tensor[:half], tensor[half:]
91+
"""Split interleaved tensor into chosen and rejected.
92+
93+
Input format: [pos_1, neg_1, pos_2, neg_2, ...] (interleaved for DP-safe slicing)
94+
Output: (chosen [pos_1, pos_2, ...], rejected [neg_1, neg_2, ...])
95+
"""
96+
# Even indices = chosen (positive), odd indices = rejected (negative)
97+
return tensor[0::2], tensor[1::2]
9498

9599

96100
class DPOLoss(PreferenceLossBase):
@@ -131,20 +135,18 @@ def __init__(
131135
self.loss_type = loss_type
132136
self.reference_free = reference_free
133137

134-
def _pad_and_align_logps(
138+
def _align_logps(
135139
self,
136-
logps: Union['torch.Tensor', List[List[float]]],
140+
logps: 'torch.Tensor',
137141
target_shape: tuple,
138-
loss_mask: 'torch.Tensor',
139142
device: 'torch.device',
140143
dtype: 'torch.dtype',
141144
) -> 'torch.Tensor':
142-
"""Pad and align log probabilities to target shape.
145+
"""Align log probabilities to target shape.
143146
144147
Args:
145-
logps: Input log probabilities (tensor or ragged list)
148+
logps: Input log probabilities tensor
146149
target_shape: Target (batch, seq_len) shape
147-
loss_mask: Boolean mask for valid positions
148150
device: Target device
149151
dtype: Target dtype
150152
@@ -153,40 +155,32 @@ def _pad_and_align_logps(
153155
"""
154156
import torch
155157

156-
if torch.is_tensor(logps):
157-
if logps.dim() == 1:
158-
logps = logps.unsqueeze(0)
159-
if logps.shape == target_shape:
160-
return logps.to(device=device, dtype=dtype)
161-
# Handle tensor with different sequence length - align to target shape
162-
if logps.dim() == 2 and logps.shape[0] == target_shape[0]:
163-
batch_size, target_seq_len = target_shape
164-
src_seq_len = logps.shape[1]
165-
logps = logps.to(device=device, dtype=dtype)
166-
if src_seq_len > target_seq_len:
167-
# Truncate: take the last target_seq_len tokens (response part)
168-
return logps[:, -target_seq_len:]
169-
else:
170-
# Pad: add zeros at the beginning
171-
padded = torch.zeros(target_shape, device=device, dtype=dtype)
172-
padded[:, -src_seq_len:] = logps
173-
return padded
174-
175-
# Handle ragged list input
176-
if isinstance(logps, (list, tuple)):
177-
batch_size, seq_len = target_shape
178-
padded = torch.zeros(target_shape, device=device, dtype=dtype)
179-
for i, row in enumerate(logps):
180-
if row is None:
181-
continue
182-
row_t = torch.as_tensor(row, device=device, dtype=dtype)
183-
valid_positions = loss_mask[i].nonzero(as_tuple=True)[0]
184-
length = min(len(row_t), len(valid_positions))
185-
if length > 0:
186-
padded[i, valid_positions[:length]] = row_t[:length]
187-
return padded
188-
189-
return logps.to(device=device, dtype=dtype)
158+
if not torch.is_tensor(logps):
159+
raise TypeError(f'Expected torch.Tensor, got {type(logps)}')
160+
161+
if logps.dim() == 1:
162+
logps = logps.unsqueeze(0)
163+
164+
if logps.shape == target_shape:
165+
return logps.to(device=device, dtype=dtype)
166+
167+
# Handle tensor with different sequence length
168+
if logps.dim() == 2 and logps.shape[0] == target_shape[0]:
169+
batch_size, target_seq_len = target_shape
170+
src_seq_len = logps.shape[1]
171+
logps = logps.to(device=device, dtype=dtype)
172+
if src_seq_len > target_seq_len:
173+
# Truncate right (keep left part) - may happen in Ray result merging
174+
return logps[:, :target_seq_len]
175+
else:
176+
raise ValueError(
177+
f'ref_logps seq_len ({src_seq_len}) < target seq_len ({target_seq_len}). '
178+
f'This should not happen when both models process the same batch.'
179+
)
180+
181+
raise ValueError(
182+
f'Cannot align ref_logps shape {logps.shape} to target shape {target_shape}'
183+
)
190184

191185
def _compute_dpo_loss(
192186
self,
@@ -254,6 +248,7 @@ def __call__(
254248
inputs: Dict,
255249
outputs: Dict,
256250
*,
251+
ref_outputs: Optional[Dict] = None,
257252
ref_logps: Optional[Union['torch.Tensor', List[List[float]]]] = None,
258253
ref_chosen_logps: Optional['torch.Tensor'] = None,
259254
ref_rejected_logps: Optional['torch.Tensor'] = None,
@@ -271,6 +266,7 @@ def __call__(
271266
outputs: Dict containing either:
272267
- 'logps': [batch, seq_len] pre-computed log probs, OR
273268
- 'logits': [batch, seq_len, vocab] from which logps will be computed
269+
ref_outputs: Dict from reference model forward, containing 'logps'.
274270
ref_logps: [batch, seq_len] or List[List[float]] reference model log probs.
275271
Can also be provided as separate ref_chosen_logps and ref_rejected_logps.
276272
ref_chosen_logps: [batch/2] pre-computed reference log probs for chosen.
@@ -282,6 +278,10 @@ def __call__(
282278
"""
283279
import torch
284280

281+
# Extract ref_logps from ref_outputs if provided
282+
if ref_outputs is not None and ref_logps is None:
283+
ref_logps = ref_outputs.get('logps')
284+
285285
labels = inputs.get('labels')
286286
assert labels is not None, "inputs must contain 'labels'"
287287
if not torch.is_tensor(labels):
@@ -312,9 +312,8 @@ def __call__(
312312
reference_rejected_logps = ref_rejected_logps.to(device=device, dtype=dtype)
313313
elif ref_logps is not None:
314314
# Per-token reference log probs provided, need to align and sum
315-
loss_mask = (labels != self.ignore_index).bool()
316-
ref_logps_aligned = self._pad_and_align_logps(
317-
ref_logps, labels.shape, loss_mask, device, dtype
315+
ref_logps_aligned = self._align_logps(
316+
ref_logps, labels.shape, device, dtype
318317
)
319318
ref_chosen, ref_rejected = self._split_chosen_rejected(ref_logps_aligned)
320319
reference_chosen_logps = self._compute_sequence_logps(ref_chosen, chosen_labels)

src/twinkle/metric/loss.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ def calculate(self):
5252
'grad_norm': self.grad_norm,
5353
'num_tokens': self.num_tokens
5454
}]
55-
5655
all_results = self.gather_results(local_results)
5756

5857
total_loss = sum(r['loss'] for r in all_results)

0 commit comments

Comments
 (0)