1515from peft import LoraConfig
1616
1717from twinkle import get_logger
18- from twinkle .dataset import DatasetMeta
18+ from twinkle .dataset import Dataset , DatasetMeta
1919from twinkle_client import init_twinkle_client
2020from twinkle .dataloader import DataLoader
21- from twinkle .dataset import LazyDataset
2221from twinkle_client .model import MultiLoraTransformersModel
2322from twinkle .loss import DPOLoss
2423from twinkle .metric import DPOMetric
6564
6665def create_dpo_dataset ():
6766 """Create DPO dataset with positive/negative format."""
68- dataset = LazyDataset ( dataset_meta = DatasetMeta (dataset_id , data_slice = range (6000 )))
67+ dataset = Dataset ( DatasetMeta (dataset_id , data_slice = range (600 )))
6968 dataset .set_template ('Qwen3_5Template' , model_id = f'ms://{ base_model } ' , max_length = max_length )
7069 dataset .map (
7170 EmojiDPOProcessor ,
@@ -75,7 +74,7 @@ def create_dpo_dataset():
7574 )
7675 # DPO preprocessor returns {'positive': [...], 'negative': [...]}
7776 # batch_encode handles this format automatically
78- dataset .encode (batched = True )
77+ dataset .encode ()
7978 return dataset
8079
8180
@@ -179,7 +178,7 @@ def train():
179178 # Get reference outputs using base model (without LoRA adapter)
180179 # disable_lora=True tells the model to skip LoRA and use base weights
181180 ref_outputs = model .forward_only (inputs = dpo_batch , disable_lora = True )
182- model .forward_backward (inputs = dpo_batch , ref_outputs = ref_outputs )
181+ model .forward_backward (inputs = dpo_batch , ref_outputs = ref_outputs . result )
183182 model .clip_grad_and_step ()
184183
185184 optim_step += 1
0 commit comments