11# Copyright (c) ModelScope Contributors. All rights reserved.
2- import numpy as np
32import os
43from transformers import AutoConfig
54
# Training hyper-parameters, overridable through environment variables so the
# same script can be launched with different settings without code changes.
TEMPLATE_ID = os.getenv('TEMPLATE_ID', 'Template')

# NUM_LAYERS is optional: leave the variable unset to fall back to None
# (presumably meaning "use the model's own layer count" — confirm with caller).
_num_layers_env = os.getenv('NUM_LAYERS')
NUM_LAYERS = None if _num_layers_env is None else int(_num_layers_env)

BATCH_SIZE = int(os.getenv('BATCH_SIZE', '4'))
GRAD_ACCUM_STEPS = int(os.getenv('GRAD_ACCUM_STEPS', '4'))
LR = float(os.getenv('LR', '1e-5'))
MAX_GRAD_NORM = float(os.getenv('MAX_GRAD_NORM', '1.0'))
# Only the literal string '1' enables router-logit retention; anything else
# (including unset) disables it.
KEEP_ROUTER_LOGITS = os.getenv('KEEP_ROUTER_LOGITS', '0') == '1'
2024
# Device layout for an 8-GPU run: no pure data-parallel replication
# (dp_size=1); all 8 ranks shard parameters via FSDP (fsdp_size=8) and the
# MoE experts are spread across all 8 ranks (ep_size=8, expert parallel).
# NOTE(review): this span previously contained diff residue mixing the old
# 4-GPU (dp=2, ep=2) DeviceMesh(...) construction with this one; only the
# current configuration is kept.
device_mesh = DeviceMesh.from_sizes(
    fsdp_size=8,
    dp_size=1,
    ep_size=8,
    device_type=Platform.get_platform().device_prefix(),
)
3032
3133twinkle .initialize (
@@ -41,7 +43,7 @@ def train():
4143 if hasattr (config , 'use_cache' ):
4244 config .use_cache = False
4345
44- dataset = Dataset (dataset_meta = DatasetMeta ('ms://swift/self-cognition' , data_slice = range (1000 )))
46+ dataset = Dataset (dataset_meta = DatasetMeta (DATASET_ID , data_slice = range (1000 )))
4547 try :
4648 dataset .set_template (TEMPLATE_ID , model_id = MODEL_ID )
4749 except ValueError :
@@ -51,11 +53,10 @@ def train():
5153 dataset .encode (batched = True )
5254 dataloader = DataLoader (
5355 dataset = dataset ,
54- batch_size = 4 ,
56+ batch_size = BATCH_SIZE ,
5557 device_mesh = device_mesh ,
5658 )
5759
58- grad_accum_steps = 4
5960 model = TransformersModel (
6061 model_id = MODEL_ID ,
6162 config = config ,
@@ -64,29 +65,43 @@ def train():
6465 'expert_parallel' : {
6566 'enabled' : True ,
6667 'router_dtype' : 'fp32' ,
67- 'all_to_all' : 'torch' ,
68- 'keep_router_logits' : False ,
68+ 'keep_router_logits' : KEEP_ROUTER_LOGITS ,
6969 }
7070 },
7171 )
7272 # Disable foreach to avoid DTensor mixed-type errors in EP runs.
73- model .set_optimizer ('AdamW' , foreach = False )
73+ model .set_optimizer ('AdamW' , lr = LR , foreach = False )
74+ model .set_lr_scheduler (
75+ scheduler_cls = 'CosineWarmupScheduler' ,
76+ num_warmup_steps = 5 ,
77+ num_training_steps = len (dataloader ),
78+ )
7479
7580 logger .info (get_device_placement ())
7681 logger .info (model .get_train_configs ())
82+ logger .info (
83+ f'Total steps: { len (dataloader )} , batch_size={ BATCH_SIZE } , grad_accum={ GRAD_ACCUM_STEPS } , '
84+ f'lr={ LR :.2e} , max_grad_norm={ MAX_GRAD_NORM } , '
85+ f'keep_router_logits={ KEEP_ROUTER_LOGITS } ' )
7786
7887 for step , batch in enumerate (dataloader ):
7988 if callable (batch ):
8089 batch = batch ()
81- model .forward_backward (inputs = batch , gradient_accumulation_steps = grad_accum_steps )
82- model .clip_grad_and_step (gradient_accumulation_steps = grad_accum_steps )
83- if step % grad_accum_steps == 0 :
90+ model .forward_backward (inputs = batch , gradient_accumulation_steps = GRAD_ACCUM_STEPS )
91+ model .clip_grad_and_step (
92+ max_grad_norm = MAX_GRAD_NORM ,
93+ gradient_accumulation_steps = GRAD_ACCUM_STEPS ,
94+ )
95+
96+ is_sync_step = ((step + 1 ) % GRAD_ACCUM_STEPS == 0 )
97+ if is_sync_step :
98+ optimizer_step = (step + 1 ) // GRAD_ACCUM_STEPS
8499 metric = model .calculate_metric (is_training = True )
85100 if callable (metric ):
86101 metric = metric ()
87- logger .info (f'Current is step { step // grad_accum_steps } , metric: { metric } ' )
88- if step > 0 and step % 50 == 0 :
89- model .save ('./output' )
102+ logger .info (f'Current optimizer_step { optimizer_step } , metric: { metric } ' )
103+ if optimizer_step > 0 and optimizer_step % 50 == 0 :
104+ model .save (name = f'checkpoint-step- { optimizer_step } ' , output_dir = './output' )
90105
91106
92107if __name__ == '__main__' :
0 commit comments