
Commit 0cf1ac3

wip
1 parent 3a25caa commit 0cf1ac3

6 files changed, +217 -16 lines changed

cookbook/rl/dpo.py

Lines changed: 13 additions & 12 deletions
@@ -59,14 +59,15 @@
 from twinkle.dataloader import DataLoader
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.loss import CPOLoss, DPOLoss, ORPOLoss, SimPOLoss
+from twinkle.metric import DPOMetric
 from twinkle.model import TransformersModel
 from twinkle.preprocessor import EmojiDPOProcessor
 from twinkle.processor import InputProcessor
 
 logger = get_logger()
 
 # ── Configuration ─────────────────────────────────────────────────────────────
-MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
+MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen2.5-7B-Instruct')
 DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji')
 
 MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
@@ -75,20 +76,21 @@
 
 BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4))  # Number of preference pairs
 MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 4))
-GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
+GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 8))
 MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000))
-LEARNING_RATE = float(os.environ.get('LR', 5e-6))
+LEARNING_RATE = float(os.environ.get('LR', 5e-5))
 DPO_BETA = float(os.environ.get('DPO_BETA', 0.1))
+SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 0.1))  # SFT loss weight for regularization
 LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid')  # sigmoid, hinge, ipo, simpo, orpo, cpo
-SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 100))
+SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 200))
 MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048))
 ADAPTER_NAME = 'default'
 SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', 'You are a helpful assistant.')
 
 
 def create_dpo_dataset():
     """Create DPO dataset with positive/negative format."""
-    dataset = Dataset(DatasetMeta(DATASET_ID))
+    dataset = Dataset(DatasetMeta(DATASET_ID, data_slice=range(15000)))
     dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH)
     dataset.map(
         EmojiDPOProcessor,
@@ -134,7 +136,7 @@ def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
 
 # ── Loss Factory ──────────────────────────────────────────────────────────────
 
-def create_loss(loss_type: str, beta: float, reference_free: bool = False):
+def create_loss(loss_type: str, beta: float, sft_weight: float = 0.0, reference_free: bool = False):
     """Create the appropriate loss function based on configuration."""
     if loss_type == 'simpo':
         return SimPOLoss(beta=beta, gamma=0.5)
@@ -148,6 +150,7 @@ def create_loss(loss_type: str, beta: float, reference_free: bool = False):
         beta=beta,
         loss_type=loss_type,
         reference_free=reference_free,
+        sft_weight=sft_weight,
     )
 
 
@@ -174,10 +177,7 @@ def main():
 
     # ── Policy Model Setup ────────────────────────────────────────────────────
     lora_config = LoraConfig(
-        target_modules=[
-            'q_proj', 'k_proj', 'v_proj', 'o_proj',
-            'gate_proj', 'up_proj', 'down_proj',
-        ],
+        target_modules='all-linear',
         r=16,
         lora_alpha=32,
         lora_dropout=0.05,
@@ -195,9 +195,10 @@ def main():
     # Determine if we need reference model based on loss type
     reference_free = LOSS_TYPE in ['simpo', 'orpo', 'cpo']
 
-    # Set up loss function
-    loss_fn = create_loss(LOSS_TYPE, DPO_BETA, reference_free=False)
+    # Set up loss function and metrics
+    loss_fn = create_loss(LOSS_TYPE, DPO_BETA, sft_weight=SFT_WEIGHT, reference_free=False)
     policy_model.set_loss(loss_fn)
+    policy_model.add_metric(DPOMetric, beta=DPO_BETA)
     policy_model.set_processor(InputProcessor)
     policy_model.set_template('Template', model_id=MODEL_ID)

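For orientation: with these defaults the cookbook now optimizes L = L_DPO + SFT_WEIGHT * L_SFT, where the SFT term is the negative log-likelihood of the chosen responses. Below is a minimal standalone sketch of that composition for the 'sigmoid' variant; it is not the repo's _compute_dpo_loss/_compute_nll_loss, and the per-token normalization of the SFT term is an assumption here.

import torch
import torch.nn.functional as F

def dpo_plus_sft(policy_chosen_logps, policy_rejected_logps,
                 ref_chosen_logps, ref_rejected_logps,
                 chosen_token_counts, beta=0.1, sft_weight=0.1):
    """Sketch: 'sigmoid' DPO loss plus a weighted NLL term on chosen responses."""
    # DPO logits: policy log-ratio margin minus reference log-ratio margin
    logits = (policy_chosen_logps - policy_rejected_logps) \
        - (ref_chosen_logps - ref_rejected_logps)
    dpo_loss = -F.logsigmoid(beta * logits).mean()
    # SFT regularizer: mean per-token NLL of the chosen responses (assumed normalization)
    sft_loss = -(policy_chosen_logps / chosen_token_counts).mean()
    return dpo_loss + sft_weight * sft_loss

# Toy sequence-level logps for two preference pairs
loss = dpo_plus_sft(
    policy_chosen_logps=torch.tensor([-12.0, -9.0]),
    policy_rejected_logps=torch.tensor([-20.0, -15.0]),
    ref_chosen_logps=torch.tensor([-14.0, -10.0]),
    ref_rejected_logps=torch.tensor([-18.0, -16.0]),
    chosen_token_counts=torch.tensor([24.0, 18.0]))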
src/twinkle/loss/dpo.py

Lines changed: 17 additions & 2 deletions
@@ -118,6 +118,7 @@ class DPOLoss(PreferenceLossBase):
         ignore_index: Index to ignore in labels (default: -100).
         loss_type: Type of DPO loss variant ('sigmoid', 'hinge', 'ipo', 'kto_pair') (default: 'sigmoid').
         reference_free: Whether to use reference-free DPO (default: False).
+        sft_weight: Weight for SFT loss on chosen responses to prevent likelihood displacement (default: 0.0).
     """
 
     def __init__(
@@ -127,13 +128,15 @@ def __init__(
         ignore_index: int = -100,
         loss_type: str = 'sigmoid',
         reference_free: bool = False,
+        sft_weight: float = 0.0,
         **kwargs,
     ):
         super().__init__(ignore_index=ignore_index)
         self.beta = beta
         self.label_smoothing = label_smoothing
         self.loss_type = loss_type
         self.reference_free = reference_free
+        self.sft_weight = sft_weight
 
     def _align_logps(
         self,
@@ -329,14 +332,26 @@ def __call__(
         )
 
         # Compute DPO loss
-        loss = self._compute_dpo_loss(
+        dpo_loss = self._compute_dpo_loss(
             policy_chosen_logps,
             policy_rejected_logps,
             reference_chosen_logps,
             reference_rejected_logps,
         )
 
-        return LossOutput(loss=loss, num_tokens=0)
+        # Add SFT loss on chosen responses to prevent likelihood displacement
+        if self.sft_weight > 0:
+            sft_loss = self._compute_nll_loss(chosen_logps, chosen_labels)
+            loss = dpo_loss + self.sft_weight * sft_loss
+        else:
+            loss = dpo_loss
+
+        # Return sample count for gradient normalization (not token count)
+        # DPO loss is already per-sample mean, so we just count samples for accumulation
+        import torch
+        num_samples = torch.tensor(chosen_labels.shape[0], device=loss.device)
+
+        return LossOutput(loss=loss, num_tokens=num_samples)
 
 
 class SimPOLoss(PreferenceLossBase):
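A note on the num_tokens change: the old LossOutput(loss=loss, num_tokens=0) left the token denominator at zero, so downstream averaging had nothing meaningful to divide by. Since the DPO loss is already a per-sample mean, reporting the pair count lets aggregation weight micro-batches by size. A toy check of that bookkeeping, with hypothetical numbers and assuming the metric sums loss weighted by the returned count (as the loss.py change further down suggests):

# Two micro-batches of 4 preference pairs each (e.g. under gradient accumulation)
micro_losses = [0.693, 0.512]   # per-micro-batch losses, each already a per-sample mean
micro_samples = [4, 4]          # chosen_labels.shape[0], reported via num_tokens

total = sum(l * n for l, n in zip(micro_losses, micro_samples))
avg_loss = total / sum(micro_samples)   # size-weighted per-sample mean
print(f'{avg_loss:.4f}')                # -> 0.6025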

src/twinkle/metric/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,5 +2,6 @@
 from .accuracy import Accuracy
 from .base import Metric
 from .completion_and_reward import CompletionRewardMetric
+from .dpo import DPOMetric
 from .loss import LossMetric
 from .train_metric import TrainMetric

src/twinkle/metric/dpo.py

Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""DPO-specific metrics for preference optimization training."""
+from typing import List, Union
+
+from twinkle.data_format import InputFeature, ModelOutput
+from .base import Metric
+
+
+class DPOMetric(Metric):
+    """Metrics for DPO (Direct Preference Optimization) training.
+
+    Computes TRL-style metrics:
+    - logps/chosen: Average sequence-level log prob of chosen responses
+    - logps/rejected: Average sequence-level log prob of rejected responses
+    - rewards/chosen: β * (policy_chosen - ref_chosen)
+    - rewards/rejected: β * (policy_rejected - ref_rejected)
+    - rewards/margins: chosen_reward - rejected_reward
+    - rewards/accuracies: Percentage where chosen_reward > rejected_reward
+
+    Args:
+        device_mesh: The device mesh
+        process_group: The process group to collect data from
+        ignore_index: Label index to ignore (default: -100)
+        beta: DPO beta parameter for reward scaling (default: 0.1)
+    """
+
+    def __init__(self, device_mesh, process_group, ignore_index: int = -100, beta: float = 0.1, **kwargs):
+        super().__init__(device_mesh, process_group, **kwargs)
+        self.ignore_index = ignore_index
+        self.beta = beta
+        self.reset()
+
+    def _compute_sequence_logps(self, per_token_logps, labels):
+        """Compute sequence-level log probs by summing valid token logps."""
+        import torch
+        loss_mask = (labels != self.ignore_index).float()
+        return (per_token_logps * loss_mask).sum(dim=-1)
+
+    def _split_chosen_rejected(self, tensor):
+        """Split interleaved tensor into chosen and rejected.
+
+        Input format: [pos_1, neg_1, pos_2, neg_2, ...] (interleaved for DP-safe slicing)
+        Output: (chosen [pos_1, pos_2, ...], rejected [neg_1, neg_2, ...])
+        """
+        return tensor[0::2], tensor[1::2]
+
+    def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: ModelOutput, **kwargs):
+        """Accumulate DPO metrics from model outputs.
+
+        Expects:
+        - outputs['logps']: [batch, seq_len] per-token log probabilities
+        - inputs['labels']: [batch, seq_len] labels with ignore_index for non-target tokens
+        - kwargs['ref_outputs']: Optional reference model outputs with 'logps'
+        """
+        import torch
+
+        logps = outputs.get('logps')
+        if logps is None:
+            return
+
+        # Get labels from inputs
+        if isinstance(inputs, list):
+            # Stack labels from list of inputs
+            labels_list = [torch.as_tensor(inp['labels']) for inp in inputs]
+            max_len = max(l.shape[0] for l in labels_list)
+            padded = []
+            for l in labels_list:
+                if l.shape[0] < max_len:
+                    pad = torch.full((max_len - l.shape[0],), self.ignore_index, dtype=l.dtype)
+                    l = torch.cat([pad, l])
+                padded.append(l)
+            labels = torch.stack(padded)
+        else:
+            labels = torch.as_tensor(inputs['labels'])
+            if labels.dim() == 1:
+                labels = labels.unsqueeze(0)
+
+        # Ensure logps and labels have same device
+        if logps.device != labels.device:
+            labels = labels.to(logps.device)
+
+        # Align sequence lengths if needed (truncate right)
+        if logps.shape[1] != labels.shape[1]:
+            min_len = min(logps.shape[1], labels.shape[1])
+            logps = logps[:, :min_len]
+            labels = labels[:, :min_len]
+
+        # Compute sequence-level logps
+        seq_logps = self._compute_sequence_logps(logps, labels)
+
+        # Split into chosen and rejected (interleaved format)
+        chosen_logps, rejected_logps = self._split_chosen_rejected(seq_logps)
+        chosen_labels, rejected_labels = self._split_chosen_rejected(labels)
+
+        # Accumulate policy logps
+        self.total_chosen_logps += chosen_logps.sum().item()
+        self.total_rejected_logps += rejected_logps.sum().item()
+
+        # Compute rewards if ref_outputs available
+        ref_outputs = kwargs.get('ref_outputs')
+        if ref_outputs is not None:
+            ref_logps = ref_outputs.get('logps')
+            if ref_logps is not None:
+                # Align ref_logps
+                if ref_logps.device != labels.device:
+                    ref_logps = ref_logps.to(labels.device)
+                if ref_logps.shape[1] != labels.shape[1]:
+                    min_len = min(ref_logps.shape[1], labels.shape[1])
+                    ref_logps = ref_logps[:, :min_len]
+
+                ref_seq_logps = self._compute_sequence_logps(ref_logps, labels)
+                ref_chosen_logps, ref_rejected_logps = self._split_chosen_rejected(ref_seq_logps)
+
+                # Compute rewards: β * (policy - ref)
+                chosen_rewards = self.beta * (chosen_logps - ref_chosen_logps)
+                rejected_rewards = self.beta * (rejected_logps - ref_rejected_logps)
+
+                self.total_chosen_rewards += chosen_rewards.sum().item()
+                self.total_rejected_rewards += rejected_rewards.sum().item()
+                margins = chosen_rewards - rejected_rewards
+                self.total_reward_margin += margins.sum().item()
+                self.total_reward_correct += (margins > 0).sum().item()
+                self.has_rewards = True
+
+        self.total_count += chosen_logps.shape[0]
+
+    def reset(self):
+        """Reset all accumulated values."""
+        self.total_chosen_logps = 0.0
+        self.total_rejected_logps = 0.0
+        self.total_chosen_rewards = 0.0
+        self.total_rejected_rewards = 0.0
+        self.total_reward_margin = 0.0
+        self.total_reward_correct = 0
+        self.total_count = 0
+        self.has_rewards = False
+
+    def calculate(self):
+        """Calculate and return aggregated metrics."""
+        local_results = [{
+            'chosen_logps': self.total_chosen_logps,
+            'rejected_logps': self.total_rejected_logps,
+            'chosen_rewards': self.total_chosen_rewards,
+            'rejected_rewards': self.total_rejected_rewards,
+            'reward_margin': self.total_reward_margin,
+            'reward_correct': self.total_reward_correct,
+            'count': self.total_count,
+            'has_rewards': self.has_rewards,
+        }]
+        all_results = self.gather_results(local_results)
+
+        total_chosen_logps = sum(r['chosen_logps'] for r in all_results)
+        total_rejected_logps = sum(r['rejected_logps'] for r in all_results)
+        total_chosen_rewards = sum(r['chosen_rewards'] for r in all_results)
+        total_rejected_rewards = sum(r['rejected_rewards'] for r in all_results)
+        total_reward_margin = sum(r['reward_margin'] for r in all_results)
+        total_reward_correct = sum(r['reward_correct'] for r in all_results)
+        total_count = sum(r['count'] for r in all_results)
+        has_rewards = any(r['has_rewards'] for r in all_results)
+
+        self.reset()
+
+        if total_count == 0:
+            return {}
+
+        results = {
+            'logps/chosen': f'{total_chosen_logps / total_count:.2f}',
+            'logps/rejected': f'{total_rejected_logps / total_count:.2f}',
+        }
+
+        if has_rewards:
+            results['rewards/chosen'] = f'{total_chosen_rewards / total_count:.4f}'
+            results['rewards/rejected'] = f'{total_rejected_rewards / total_count:.4f}'
+            results['rewards/margins'] = f'{total_reward_margin / total_count:.4f}'
+            results['rewards/accuracies'] = f'{total_reward_correct / total_count * 100:.1f}%'
+
+        return results
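To make the interleaved layout and the TRL-style reward formulas concrete, here is a toy walk-through with made-up sequence logps; it is plain tensor arithmetic mirroring _split_chosen_rejected and the β-scaled rewards, so no device_mesh/process_group is needed.

import torch

beta = 0.1
# Interleaved batch of 2 preference pairs: [chosen_1, rejected_1, chosen_2, rejected_2]
seq_logps = torch.tensor([-12.0, -20.0, -9.0, -15.0])       # policy sequence logps
ref_seq_logps = torch.tensor([-14.0, -18.0, -10.0, -16.0])  # reference sequence logps

chosen, rejected = seq_logps[0::2], seq_logps[1::2]              # _split_chosen_rejected
ref_chosen, ref_rejected = ref_seq_logps[0::2], ref_seq_logps[1::2]

chosen_rewards = beta * (chosen - ref_chosen)        # rewards/chosen   -> [0.2, 0.1]
rejected_rewards = beta * (rejected - ref_rejected)  # rewards/rejected -> [-0.2, 0.1]
margins = chosen_rewards - rejected_rewards          # rewards/margins  -> [0.4, 0.0]
accuracy = (margins > 0).float().mean()              # rewards/accuracies -> 0.5 (50%)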

src/twinkle/metric/loss.py

Lines changed: 3 additions & 1 deletion
@@ -60,8 +60,10 @@ def calculate(self):
         num_tokens = sum(r['num_tokens'] for r in all_results)
         if num_tokens > 0:
             avg_loss = total_loss / num_tokens
-        else:
+        elif total_count > 0:
             avg_loss = total_loss / total_count
+        else:
+            avg_loss = 0.0
         self.reset()
         results = {}
         if avg_loss is not None:

src/twinkle/model/transformers/transformers.py

Lines changed: 6 additions & 1 deletion
@@ -121,6 +121,8 @@ def accumulate_metrics(self, is_training):
             metrics = self.train_metrics
         else:
             metrics = self.eval_metrics
+        # Get stored forward_kwargs from previous forward
+        forward_kwargs = getattr(self, 'forward_kwargs', None) or {}
         if len(metrics) > 0 and self.inputs is not None and self.outputs is not None:
             for metric in metrics:
                 metric.accumulate(
@@ -130,7 +132,8 @@
                     step=self.cur_step - 1,
                     gradient_accumulation_steps=self.gradient_accumulation_steps,
                     grad_norm=self._last_grad_norm,
-                    loss_reduction=getattr(self.loss_instance, 'reduction', 'mean'))
+                    loss_reduction=getattr(self.loss_instance, 'reduction', 'mean'),
+                    **forward_kwargs)
 
     def calculate_metrics(self, is_training):
         self.accumulate_metrics(is_training)
@@ -405,6 +408,7 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec
             inputs['labels'] = labels
         optimizer_config.inputs = inputs
         optimizer_config.outputs = outputs
+        optimizer_config.forward_kwargs = kwargs  # Store for next metric accumulation
         optimizer_config.loss_value = outputs.get('aux_loss', 0)
        if labels is not None:
             loss_mask = (labels != -100).bool()
@@ -1086,6 +1090,7 @@ def set_grad_scaler(self, **kwargs):
         grad_scaler_config.update(kwargs)
         optimizer_config.scaler = GradScaler(**grad_scaler_config)
 
+    @remote_function()
     def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool] = None, **kwargs):
         """Add an eval metric

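The hunks above are one feature: forward() stashes its extra keyword arguments, and accumulate_metrics() replays them into every metric, which is how DPOMetric receives ref_outputs. A minimal stand-in for that hand-off (sketch only; the class and attribute names here are illustrative, not the twinkle API):

class ForwardKwargsHandoff:
    """Sketch of the pattern: stash forward kwargs, replay them into metrics."""

    def forward(self, inputs, **kwargs):
        outputs = {'logps': None}      # actual model pass elided
        self.inputs, self.outputs = inputs, outputs
        self.forward_kwargs = kwargs   # mirrors optimizer_config.forward_kwargs = kwargs
        return outputs

    def accumulate_metrics(self, metrics):
        forward_kwargs = getattr(self, 'forward_kwargs', None) or {}
        for metric in metrics:
            # e.g. DPOMetric.accumulate(..., ref_outputs=...) sees the stored kwarg
            metric.accumulate(inputs=self.inputs, outputs=self.outputs, **forward_kwargs)

So a call like policy_model.forward(inputs=batch, ref_outputs=ref_outputs) makes the reference logps visible to DPOMetric.accumulate on the next accumulation pass without widening the metric interface.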