import os

from peft import LoraConfig
from tqdm import tqdm

import twinkle
from twinkle import DeviceMesh, Platform
from twinkle import get_device_placement, get_logger
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import TransformersModel
from twinkle.preprocessor import SelfCognitionProcessor

if Platform.get_rank() == 0:
    # Record metrics with SwanLab on rank 0 only
    import swanlab
    swanlab.login(api_key=os.environ['SWANLAB_API_KEY'], save=True)

    run = swanlab.init(
        project="megatron-swift",
    )


# Construct a device mesh: dp=2, fsdp=2 (4 processes in total)
device_mesh = DeviceMesh.from_sizes(dp_size=2, fsdp_size=2)
# Initialize twinkle in local mode (launched via torchrun)
twinkle.initialize(mode='local', global_device_mesh=device_mesh)

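# Example launch command (script name is illustrative), matching dp_size * fsdp_size = 4:
#   torchrun --nproc_per_node=4 train_self_cognition_lora.py
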
logger = get_logger()


def eval(model):
    # Evaluation set: the first 100 samples of the dataset
    dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100)))
    # Set the template to prepare encoding
    dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct')
    # Preprocess the dataset into the standard format
    dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
    # Encode the dataset
    dataset.encode()
    dataloader = DataLoader(dataset=dataset, batch_size=4)
    for step, batch in tqdm(enumerate(dataloader)):
        # Forward pass only, no backward pass during evaluation
        model.forward_only(inputs=batch)
        model.calculate_loss()
    metrics = model.calculate_metric(is_training=False)
    return metrics


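# Training entry point: LoRA fine-tune Qwen2.5-7B-Instruct on the self-cognition
# dataset, evaluate every 40 steps, and checkpoint whenever the eval loss improves.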
def train():
    # Training set: the first 1000 samples of the dataset
    dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
    # Set the template to prepare encoding
    dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct')
    # Preprocess the dataset into the standard format
    dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
    # Encode the dataset
    dataset.encode()
    # Global batch size = 4 across the 4 GPUs, i.e. 1 sample per GPU
    dataloader = DataLoader(dataset=dataset, batch_size=4)
    # Wrap the base model as a TransformersModel
    model = TransformersModel(model_id='ms://Qwen/Qwen2.5-7B-Instruct')

    lora_config = LoraConfig(
        r=8,
        lora_alpha=32,
        target_modules='all-linear'
    )
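    # r is the LoRA rank and lora_alpha the scaling factor; 'all-linear' asks
    # PEFT to apply LoRA to all linear layers of the base model.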

    # Add a LoRA adapter named `default` to the model
    model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=4)
    # Add the optimizer for the `default` adapter
    model.set_optimizer(optimizer_cls='AdamW', lr=1e-4)
    # Add the LR scheduler for the `default` adapter
    model.set_lr_scheduler(scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader))
    logger.info(get_device_placement())
    # Log the training config
    logger.info(model.get_train_configs())
    logger.info(f'Total steps: {len(dataloader)}')
    # Track the best evaluation loss seen so far
    loss_metric = 99.0
    for step, batch in enumerate(dataloader):
        # Forward and backward pass
        model.forward_backward(inputs=batch)
        # Clip gradients and step the optimizer/scheduler
        model.clip_grad_and_step()
        if step % 20 == 0:
            # Log the training metric
            metric = model.calculate_metric(is_training=True)
            if Platform.get_rank() == 0:
                swanlab.log(metric)
            logger.info(f'Current step {step} of {len(dataloader)}, metric: {metric}')
        if step > 0 and step % 40 == 0:
            # Evaluate and save a checkpoint when the eval loss improves
            metrics = eval(model)
            logger.info(f'Eval metric: {metrics}')
            metrics['step'] = step
            if loss_metric > float(metrics['loss']):
                model.save(f'checkpoint-{step}')
                loss_metric = float(metrics['loss'])
    # Save the final adapter weights
    model.save('last-checkpoint', adapter_name='default')


if __name__ == '__main__':
    train()