77from twinkle .dataloader import DataLoader
88from twinkle .dataset import Dataset , DatasetMeta
99from twinkle .model import TransformersModel
10- from twinkle .data_format import Message , Trajectory
11- from twinkle .preprocessor import SelfCognitionProcessor , Preprocessor
10+ from twinkle .preprocessor import SelfCognitionProcessor
1211
# Construct a device mesh with data parallelism of 2.
# (The pasted diff carried both the pre-change `dp_size=8` line and the
# post-change `dp_size=2` line; only the post-change value is kept — it is
# also the one that agrees with this comment.)
device_mesh = DeviceMesh.from_sizes(dp_size=2)
# NOTE(review): the original comment said "use torchrun mode" but the call
# passes mode='local' — confirm which launch mode is actually intended.
twinkle.initialize(mode='local', global_device_mesh=device_mesh)

# NOTE(review): the lines below are pasted GitHub-diff residue, not runnable
# Python — each line starts with fused old/new diff line numbers, a '-'/'+'
# pair shows both versions of the set_template() call, and the '@@' hunk
# header hides the middle of eval() (whatever computes `metrics` before the
# return). Code bytes are left untouched; only comments are added.
# NOTE(review): `eval` shadows the Python builtin — consider renaming
# (e.g. `evaluate`) when this file is restored to clean source.
2120def eval (model ):
2221    # 100 Samples
# Build an evaluation dataset of 100 samples from the self-cognition dataset.
2322    dataset = Dataset (dataset_meta = DatasetMeta ('ms://swift/self-cognition' , data_slice = range (100 )))
# '-' = pre-change template model (Qwen2.5-7B-Instruct); '+' = post-change
# (Qwen3.5-4B). Only the '+' line belongs in the restored file.
24-     dataset .set_template ('Template' , model_id = 'ms://Qwen/Qwen2 .5-7B-Instruct ' )
23+     dataset .set_template ('Template' , model_id = 'ms://Qwen/Qwen3 .5-4B ' )
# Rewrite self-cognition answers to the given model/author names, then encode.
2524    dataset .map (SelfCognitionProcessor ('twinkle大模型' , 'ModelScope社区' ))
2625    dataset .encode ()
2726    dataloader = DataLoader (dataset = dataset , batch_size = 8 )
# Hunk header: the evaluation loop that produces `metrics` is hidden here.
@@ -32,55 +31,19 @@ def eval(model):
3231    return metrics
3332
3433
class EmojiDPOProcessor(Preprocessor):
    """Turn raw DPO-style rows (prompt / chosen / rejected columns) into
    chat :class:`Trajectory` objects.

    Each row is expected to be a dict-like record; missing columns default
    to the empty string. This class was removed by the commit this paste
    came from, but is reconstructed here verbatim in behavior.
    """

    def __init__(
        self,
        system='You are a helpful assistant.',
        chosen_key: str = 'answer_zh',
        rejected_key: str = 'answer_en',
        prompt_key: str = 'prompt',
    ):
        # Column names of the raw dataset; `system` may be falsy to skip
        # emitting a system message.
        self.system = system
        self.chosen_key = chosen_key
        self.rejected_key = rejected_key
        self.prompt_key = prompt_key

    def __call__(self, rows):
        """Column-major -> per-row conversion -> column-major."""
        per_row = self.map_col_to_row(rows)
        converted = [self.preprocess(record) for record in per_row]
        return self.map_row_to_col(converted)

    def preprocess(self, row):
        """Build a Trajectory for a single raw row."""
        prompt_text = row.get(self.prompt_key, '')
        chosen_text = row.get(self.chosen_key, '')
        rejected_text = row.get(self.rejected_key, '')

        context = []
        if self.system:
            context.append(Message(role='system', content=self.system))
        context.append(Message(role='user', content=prompt_text))

        chosen_messages = context + [Message(role='assistant', content=chosen_text)]
        # NOTE(review): `rejected_messages` is built but never used — the
        # returned Trajectory only carries the chosen branch. For DPO the
        # rejected completion presumably needs to reach the Trajectory too;
        # confirm against Trajectory's API before "fixing". Preserved as-is
        # to keep behavior identical to the original.
        rejected_messages = context + [Message(role='assistant', content=rejected_text)]

        return Trajectory(messages=chosen_messages)
# NOTE(review): pasted GitHub-diff residue, not runnable Python. Lines carry
# fused old/new diff numbers, '-'/'+' pairs show both versions of several
# calls, and the '@@' hunk header below hides the optimizer setup and the
# training loop (including where `step` and `loss_metric` are first bound —
# presumably `loss_metric` starts at +inf or similar; confirm in the full
# file). Code bytes are left untouched; only comments are added.
7134def train ():
7235 # 1000 samples
# '-' = old DPO emoji dataset; '+' = new self-cognition dataset (1000 rows).
73- dataset = Dataset (dataset_meta = DatasetMeta ('ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji' ))
36+ dataset = Dataset (dataset_meta = DatasetMeta ('ms://swift/self-cognition' , data_slice = range ( 1000 ) ))
7437 # Set template to prepare encoding
# Template switched from Qwen2.5-7B-Instruct to Qwen3.5-4B in this commit.
75- dataset .set_template ('Template' , model_id = 'ms://Qwen/Qwen2 .5-7B-Instruct ' )
38+ dataset .set_template ('Template' , model_id = 'ms://Qwen/Qwen3 .5-4B ' )
7639 # Preprocess the dataset to standard format
# Preprocessor switched from the (now-deleted) EmojiDPOProcessor to
# SelfCognitionProcessor with the model/author names to substitute.
77- dataset .map (EmojiDPOProcessor )
40+ dataset .map (SelfCognitionProcessor ( 'twinkle大模型' , 'ModelScope社区' ) )
7841 # Encode dataset
7942 dataset .encode ()
8043 # Global batch size = 8, for GPUs, so 1 sample per GPU
8144 dataloader = DataLoader (dataset = dataset , batch_size = 8 )
8245 # Use a TransformersModel
83- model = TransformersModel (model_id = 'ms://Qwen/Qwen2 .5-7B-Instruct ' )
46+ model = TransformersModel (model_id = 'ms://Qwen/Qwen3 .5-4B ' )
# Prevent device_map sharding from splitting decoder layers across devices.
8447 model .model ._no_split_modules = {'Qwen3_5DecoderLayer' }
8548
# LoRA on all linear projections, rank 8.
8649 lora_config = LoraConfig (r = 8 , lora_alpha = 32 , target_modules = 'all-linear' )
# Hunk header: LoRA attach, optimizer, and the per-step training loop
# (forward/backward/step, binding `step`) are hidden here.
@@ -109,6 +72,13 @@ def train():
10972 # Print metric
11073 metric = model .calculate_metric (is_training = True )
11174 logger .info (f'Current is step { step } of { len (dataloader )} , metric: { metric } ' )
# New in this commit: every 40 steps (skipping step 0) run eval(), log it,
# and save a checkpoint only when eval loss improved on the best seen so far.
75+ if step > 0 and step % 40 == 0 :
76+ metrics = eval (model )
77+ logger .info (f'Eval metric: { metrics } ' )
78+ metrics ['step' ] = step
79+ if loss_metric > float (metrics ['loss' ]):
80+ model .save (f'checkpoint-{ step } ' )
81+ loss_metric = float (metrics ['loss' ])
# NOTE(review): f-string has no placeholders — plain 'last-checkpoint' would do.
11282 model .save (f'last-checkpoint' )
11383
11484
0 commit comments