Skip to content

Commit 21a7c17

Browse files
committed
update twinkle dpo
1 parent f45fc13 commit 21a7c17

File tree

3 files changed

+217
-1
lines changed

3 files changed

+217
-1
lines changed
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
# Tinker-Compatible Client - DPO (Direct Preference Optimization) Training with LoRA
2+
#
3+
# This script demonstrates how to fine-tune a language model using DPO
4+
# through the Tinker-compatible client API.
5+
#
6+
# Training flow per step:
7+
# 1. forward_backward with 'cross_entropy' + disable_lora=True
8+
# → base-model forward pass; LoRA weights are NOT in the computation graph
9+
# so backward accumulates zero LoRA gradients (safe to discard).
10+
# 2. Attach returned per-token ref logps to each datum's loss_fn_inputs.
11+
# 3. forward_backward with 'importance_sampling'
12+
# → server detects ref_logps and switches to DPOLoss + DPOMetric.
13+
# 4. optim_step → update LoRA, DPO metrics returned automatically.
14+
#
15+
# The server must be running first (see server.py and server_config.yaml).
16+
17+
import numpy as np
18+
import torch
19+
from tqdm import tqdm
20+
from typing import Any, Dict, List
21+
22+
from tinker import types
23+
from twinkle import init_tinker_client, get_logger
24+
from twinkle.dataset import Dataset, DatasetMeta
25+
from twinkle.dataloader import DataLoader
26+
from twinkle.preprocessor import EmojiDPOProcessor
27+
from twinkle.server.common import input_feature_to_datum
28+
29+
logger = get_logger()
30+
31+
# Initialize the Tinker client before importing ServiceClient
32+
init_tinker_client()
33+
34+
from tinker import ServiceClient # noqa: E402 (must follow init_tinker_client)
35+
36+
# ---------------------------------------------------------------------------
37+
# Configuration
38+
# ---------------------------------------------------------------------------
39+
base_model = 'Qwen/Qwen3.5-4B'
40+
base_url = 'http://localhost:8000'
41+
api_key = 'EMPTY_API_KEY'
42+
dataset_id = 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji'
43+
44+
batch_size = 4
45+
learning_rate = 1e-4
46+
dpo_beta = 0.1
47+
sft_weight = 1.0
48+
loss_type = 'sigmoid'
49+
max_length = 2048
50+
lora_rank = 8
51+
system_prompt = 'You are a helpful assistant.'
52+
53+
54+
# ---------------------------------------------------------------------------
55+
# Dataset helpers (reused from twinkle/self_host/dpo.py)
56+
# ---------------------------------------------------------------------------
57+
58+
def create_dpo_dataset():
    """Build and encode the emoji preference dataset used for DPO training."""
    meta = DatasetMeta(dataset_id, data_slice=range(600))
    ds = Dataset(meta)
    ds.set_template(
        'Qwen3_5Template',
        model_id=f'ms://{base_model}',
        max_length=max_length,
    )
    # Each processed row becomes {'positive': InputFeature, 'negative': InputFeature, ...};
    # encode() understands this pairwise layout automatically.
    ds.map(EmojiDPOProcessor, init_args={'system': system_prompt})
    ds.encode()
    return ds
70+
71+
72+
def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Reorganise batch into DP-safe interleaved format [pos_1, neg_1, pos_2, neg_2, ...].

    Args:
        batch: List of rows, each with 'positive' and 'negative' InputFeatures.

    Returns:
        Interleaved list so each DP worker slice contains complete pairs.
    """
    interleaved: List[Dict[str, Any]] = []
    for row in batch:
        # Fields shared by both halves of the pair (everything except the
        # 'positive'/'negative' InputFeatures themselves).
        shared = {key: value for key, value in row.items() if key != 'positive' and key != 'negative'}
        interleaved.extend((
            {**shared, **row['positive']},
            {**shared, **row['negative']},
        ))
    return interleaved
89+
90+
91+
# ---------------------------------------------------------------------------
92+
# Training
93+
# ---------------------------------------------------------------------------
94+
95+
def _pythonize_batch(batch: List[Dict[str, Any]]) -> None:
    """Convert numpy / torch values in each row to plain Python lists, in place.

    Rows are serialised before being sent to the server, so tensor-like
    values must become nested Python lists first.
    """
    for row in batch:
        for key in list(row.keys()):
            if isinstance(row[key], np.ndarray):
                row[key] = row[key].tolist()
            elif isinstance(row[key], torch.Tensor):
                row[key] = row[key].cpu().numpy().tolist()


def _attach_ref_logps(input_datums, ref_result) -> None:
    """Copy per-token reference logprobs onto each datum's loss_fn_inputs.

    The server switches to DPOLoss + DPOMetric when it sees the
    'ref_logps' key in loss_fn_inputs.
    """
    for datum, ref_out in zip(input_datums, ref_result.loss_fn_outputs):
        ref_logprobs_np = np.array(ref_out['logprobs'].tolist(), dtype=np.float32)
        datum.loss_fn_inputs['ref_logps'] = types.TensorData.from_numpy(ref_logprobs_np)


def train():
    """Run one epoch of DPO LoRA training against the Tinker-compatible server.

    Per step:
      A. reference forward pass with LoRA disabled (base-model logps),
      B. attach per-token ref logps to each datum,
      C. DPO forward_backward (server picks DPOLoss on seeing ref_logps),
      D. optim_step — DPO metrics are returned in optim_result.metrics.
    """
    # Step 1: Prepare dataset & dataloader
    logger.info('Loading DPO dataset...')
    dataset = create_dpo_dataset()
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size)
    logger.info(f'Dataset ready: {len(dataloader)} steps per epoch')

    # Step 2: Connect to server and create LoRA training client
    service_client = ServiceClient(base_url=base_url, api_key=api_key)
    training_client = service_client.create_lora_training_client(
        base_model=base_model,
        rank=lora_rank,
    )
    logger.info(f'LoRA training client created (rank={lora_rank})')
    logger.info(f'Starting DPO training: loss_type={loss_type}, beta={dpo_beta}, lr={learning_rate}')

    # Step 3: Training loop
    for step, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
        # Normalise numpy / torch tensors to plain Python lists for serialisation
        _pythonize_batch(batch)

        # Build interleaved [pos, neg, pos, neg, ...] batch and convert each
        # InputFeature dict to a Tinker Datum.
        dpo_batch = prepare_dpo_batch(batch)
        input_datums = [input_feature_to_datum(row) for row in dpo_batch]

        # -----------------------------------------------------------------
        # A. Reference forward pass (base model, disable_lora=True)
        #    LoRA weights are outside the computation graph -> backward
        #    produces zero LoRA gradients, so this call is safe.
        # -----------------------------------------------------------------
        ref_result = training_client.forward_backward(
            input_datums,
            'cross_entropy',
            loss_fn_config={'disable_lora': True},
        ).result()

        # -----------------------------------------------------------------
        # B. Attach per-token ref logps to each datum's loss_fn_inputs
        # -----------------------------------------------------------------
        _attach_ref_logps(input_datums, ref_result)

        # -----------------------------------------------------------------
        # C. DPO forward_backward
        #    Server detects ref_logps -> sets DPOLoss + DPOMetric automatically.
        #    Optional DPO hyper-params can be forwarded via loss_fn_config.
        # -----------------------------------------------------------------
        fwdbwd_result = training_client.forward_backward(
            input_datums,
            'importance_sampling',
            loss_fn_config={
                'dpo_beta': dpo_beta,
                'dpo_loss_type': loss_type,
                'dpo_sft_weight': sft_weight,
            },
        ).result()

        # -----------------------------------------------------------------
        # D. Optimizer step — DPOMetric is calculated automatically on the
        #    server and returned inside optim_result.metrics.
        # -----------------------------------------------------------------
        optim_result = training_client.optim_step(
            types.AdamParams(learning_rate=learning_rate)
        ).result()

        dpo_loss = fwdbwd_result.metrics.get('loss:avg', 'N/A')
        logger.info(f'[Step {step}] dpo_loss={dpo_loss} | metrics={optim_result.metrics}')

    # Step 4: Save checkpoint
    save_result = training_client.save_state('dpo-lora-final').result()
    logger.info(f'Saved checkpoint: {save_result.path}')

    # Step 5: (Optional) Upload to ModelScope Hub
    # YOUR_USER_NAME = 'your_username'
    # hub_model_id = f'{YOUR_USER_NAME}/twinkle-tinker-dpo-lora'
    # training_client.publish_checkpoint_from_tinker_path(save_result.path).result()
    # logger.info(f'Uploaded checkpoint to hub: {hub_model_id}')
180+
181+
182+
# Script entry point: run the full DPO training loop when executed directly.
if __name__ == '__main__':
    train()

src/twinkle/server/common/datum.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ def extract_rl_feature(datum: types.Datum | list[types.Datum]) -> dict:
7171
if 'advantages' in d.loss_fn_inputs:
7272
advantages = d.loss_fn_inputs['advantages'].to_numpy().tolist()
7373
result['advantages'].append(advantages)
74+
75+
# 'ref_logps' -> 'ref_logps' (for DPO loss)
76+
if 'ref_logps' in d.loss_fn_inputs:
77+
ref_logps = d.loss_fn_inputs['ref_logps'].to_numpy().tolist()
78+
result['ref_logps'].append(ref_logps)
7479
return result
7580

7681

src/twinkle/server/model/backends/transformers_model.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,42 @@ def tinker_forward_backward(self, *, inputs: List[types.Datum], adapter_name: st
5050
if loss_fn == 'cross_entropy':
5151
super().set_loss('CrossEntropyLoss', adapter_name=adapter_name)
5252
elif loss_fn == 'importance_sampling':
53-
super().set_loss('GRPOLoss', adapter_name=adapter_name, epsilon=0.2, beta=0.0)
53+
# Detect DPO format: datums contain ref_logps in loss_fn_inputs
54+
has_ref_logps = any('ref_logps' in d.loss_fn_inputs for d in inputs)
55+
if has_ref_logps:
56+
# DPO mode: read optional DPO params from loss_fn_config kwargs
57+
beta = kwargs.pop('dpo_beta', 0.1)
58+
loss_type = kwargs.pop('dpo_loss_type', 'sigmoid')
59+
sft_weight = kwargs.pop('dpo_sft_weight', 0.0)
60+
super().set_loss(
61+
'DPOLoss', adapter_name=adapter_name, beta=beta, loss_type=loss_type, sft_weight=sft_weight)
62+
super().add_metric('DPOMetric', adapter_name=adapter_name, beta=beta)
63+
else:
64+
# GRPO mode: read optional GRPO params from loss_fn_config kwargs
65+
# Also pop DPO-specific kwargs to prevent leaking into forward/backward
66+
epsilon = kwargs.pop('epsilon', 0.2)
67+
grpo_beta = kwargs.pop('beta', 0.0)
68+
super().set_loss('GRPOLoss', adapter_name=adapter_name, epsilon=epsilon, beta=grpo_beta)
5469
else:
5570
super().set_loss('CrossEntropyLoss', adapter_name=adapter_name)
5671
template = self.get_template(adapter_name)
5772
input_features = datum_to_input_feature(inputs, template)
5873
outputs = super().forward(inputs=input_features, adapter_name=adapter_name, **kwargs)
5974
loss_values = extract_rl_feature(inputs)
6075
loss_kwargs = kwargs.copy()
76+
# Convert ref_logps list-of-lists into a padded tensor wrapped in ref_outputs
77+
# so that DPOLoss and DPOMetric can consume it via ref_outputs.get('logps').
78+
# if 'ref_logps' in loss_values:
79+
# import torch
80+
# import torch.nn.functional as F
81+
# ref_logps_lists = loss_values.pop('ref_logps')
82+
# max_len = max(len(r) for r in ref_logps_lists)
83+
# padded = [
84+
# F.pad(torch.tensor(r, dtype=torch.float32), (0, max_len - len(r)))
85+
# for r in ref_logps_lists
86+
# ]
87+
# ref_logps_tensor = torch.stack(padded) # [batch, max_seq_len]
88+
# loss_kwargs['ref_outputs'] = {'logps': ref_logps_tensor}
6189
loss_kwargs.update(loss_values)
6290
loss = super().calculate_loss(adapter_name=adapter_name, **loss_kwargs)
6391
super().backward(adapter_name=adapter_name, **kwargs)

0 commit comments

Comments
 (0)