Skip to content

Commit 1c42798

Browse files
committed
update twinkle dpo
1 parent 4fe4551 commit 1c42798

File tree

6 files changed

+29
-7
lines changed

6 files changed

+29
-7
lines changed

cookbook/client/twinkle/self_host/dpo.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,9 @@
1515
from peft import LoraConfig
1616

1717
from twinkle import get_logger
18-
from twinkle.dataset import DatasetMeta
18+
from twinkle.dataset import Dataset, DatasetMeta
1919
from twinkle_client import init_twinkle_client
2020
from twinkle.dataloader import DataLoader
21-
from twinkle.dataset import LazyDataset
2221
from twinkle_client.model import MultiLoraTransformersModel
2322
from twinkle.loss import DPOLoss
2423
from twinkle.metric import DPOMetric
@@ -65,7 +64,7 @@
6564

6665
def create_dpo_dataset():
6766
"""Create DPO dataset with positive/negative format."""
68-
dataset = LazyDataset(dataset_meta=DatasetMeta(dataset_id, data_slice=range(6000)))
67+
dataset = Dataset(DatasetMeta(dataset_id, data_slice=range(600)))
6968
dataset.set_template('Qwen3_5Template', model_id=f'ms://{base_model}', max_length=max_length)
7069
dataset.map(
7170
EmojiDPOProcessor,
@@ -75,7 +74,7 @@ def create_dpo_dataset():
7574
)
7675
# DPO preprocessor returns {'positive': [...], 'negative': [...]}
7776
# batch_encode handles this format automatically
78-
dataset.encode(batched=True)
77+
dataset.encode()
7978
return dataset
8079

8180

@@ -179,7 +178,7 @@ def train():
179178
# Get reference outputs using base model (without LoRA adapter)
180179
# disable_lora=True tells the model to skip LoRA and use base weights
181180
ref_outputs = model.forward_only(inputs=dpo_batch, disable_lora=True)
182-
model.forward_backward(inputs=dpo_batch, ref_outputs=ref_outputs)
181+
model.forward_backward(inputs=dpo_batch, ref_outputs=ref_outputs.result)
183182
model.clip_grad_and_step()
184183

185184
optim_step += 1

src/twinkle/loss/dpo.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,8 @@ def __call__(
310310
reference_rejected_logps = ref_rejected_logps.to(device=device, dtype=dtype)
311311
elif ref_logps is not None:
312312
# Per-token reference log probs provided, need to align and sum
313+
if not torch.is_tensor(ref_logps):
314+
ref_logps = torch.as_tensor(ref_logps)
313315
ref_logps_aligned = self._align_logps(ref_logps, labels.shape, device, dtype)
314316
ref_chosen, ref_rejected = self._split_chosen_rejected(ref_logps_aligned)
315317
reference_chosen_logps = self._compute_sequence_logps(ref_chosen, chosen_labels)

src/twinkle/metric/dpo.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ def _align_logps(self, logps, target_shape, device, dtype):
5050
Aligned tensor with shape matching target_shape
5151
"""
5252
import torch
53+
54+
if not torch.is_tensor(logps):
55+
logps = torch.as_tensor(logps)
5356
logps = logps.to(device=device, dtype=dtype)
5457
batch_size, src_len = logps.shape
5558
_, target_len = target_shape

src/twinkle/server/model/backends/megatron_model.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from twinkle import remote_class, remote_function
1010
from twinkle.data_format import InputFeature, Trajectory
11+
from twinkle.infra import collect_tensor_dict
1112
from twinkle.model.megatron import MultiLoraMegatronModel
1213
from twinkle.server.common.datum import datum_to_input_feature, extract_rl_feature
1314
from twinkle.server.model.backends.common import (TwinkleCompatModelBase, clean_metrics,
@@ -119,7 +120,13 @@ def tinker_load(self, checkpoint_dir: str, **kwargs):
119120
# Twinkle-native methods (InputFeature/Trajectory-based I/O)
120121
# ------------------------------------------------------------------
121122

122-
@remote_function(dispatch='slice_dp', collect='mean')
123+
@remote_function(dispatch='slice_dp', collect=collect_tensor_dict)
124+
def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs):
125+
"""Forward-only for twinkle-native clients (InputFeature/Trajectory I/O)."""
126+
output = super().forward_only(inputs=inputs, **kwargs)
127+
return to_cpu_safe_output(output)
128+
129+
@remote_function(dispatch='slice_dp', collect=collect_tensor_dict)
123130
def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]],
124131
**kwargs):
125132
"""Forward+backward for twinkle-native clients (InputFeature/Trajectory I/O)."""

src/twinkle/server/model/backends/transformers_model.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from twinkle import remote_class, remote_function
1313
from twinkle.data_format import InputFeature, Trajectory
14+
from twinkle.infra import collect_tensor_dict
1415
from twinkle.model import MultiLoraTransformersModel
1516
from twinkle.server.common.datum import datum_to_input_feature, extract_rl_feature
1617
from twinkle.server.model.backends.common import (TwinkleCompatModelBase, clean_metrics,
@@ -106,7 +107,13 @@ def tinker_load(self, checkpoint_dir: str, **kwargs):
106107
# Twinkle-native methods (InputFeature/Trajectory-based I/O)
107108
# ------------------------------------------------------------------
108109

109-
@remote_function(dispatch='slice_dp', collect='mean')
110+
@remote_function(dispatch='slice_dp', collect=collect_tensor_dict)
111+
def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs):
112+
"""Forward-only for twinkle-native clients (InputFeature/Trajectory I/O)."""
113+
output = super().forward_only(inputs=inputs, **kwargs)
114+
return to_cpu_safe_output(output)
115+
116+
@remote_function(dispatch='slice_dp', collect=collect_tensor_dict)
110117
def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]],
111118
**kwargs):
112119
"""Forward+backward for twinkle-native clients (InputFeature/Trajectory I/O)."""

src/twinkle_client/common/serialize.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
from numbers import Number
44
from peft import LoraConfig
5+
from pydantic import BaseModel
56
from typing import Any, Mapping
67

78
from twinkle.dataset import DatasetMeta
@@ -56,6 +57,9 @@ def serialize_object(obj) -> str:
5657
}
5758
filtered_dict['_TWINKLE_TYPE_'] = 'LoraConfig'
5859
return json.dumps(filtered_dict, ensure_ascii=False)
60+
elif isinstance(obj, BaseModel):
61+
# Pydantic models: convert to dict for JSON serialization by requests
62+
return obj.model_dump(mode='json')
5963
elif isinstance(obj, Mapping):
6064
return json.dumps(obj, ensure_ascii=False)
6165
elif isinstance(obj, basic_types):

0 commit comments

Comments
 (0)