 from twinkle.model import TransformersModel
 from twinkle.preprocessor import SelfCognitionProcessor

-if Platform.get_rank() == 0:
+if Platform.get_rank() == 0 and os.environ.get('SWANLAB_API_KEY'):
     # rank0 recording
     import swanlab
     swanlab.login(api_key=os.environ['SWANLAB_API_KEY'], save=True)
 )


-# Construct a device_mesh, fsdp=2, dp=2
-device_mesh = DeviceMesh.from_sizes(fsdp_size=4)
+# Construct a device_mesh, fsdp=4, dp=2
+device_mesh = DeviceMesh.from_sizes(fsdp_size=4, dp_size=2)
 # use torchrun mode
 twinkle.initialize(mode='local', global_device_mesh=device_mesh)

@@ -53,7 +53,7 @@ def train():
     # Encode dataset
     dataset.encode()
-    # Global batch size = 4, for GPUs, so 1 sample per GPU
-    dataloader = DataLoader(dataset=dataset, batch_size=4)
+    # Global batch size = 8 across 8 GPUs, so 1 sample per GPU
+    dataloader = DataLoader(dataset=dataset, batch_size=8)
     # Use a TransformersModel
     model = TransformersModel(model_id='ms://Qwen/Qwen2.5-7B-Instruct')

@@ -64,7 +64,8 @@ def train():
     )

     # Add a lora to model, with name `default`
-    model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=4)
+    # Comment this out to use full-parameter training
+    model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2)
     # Add Optimizer for lora `default`
     model.set_optimizer(optimizer_cls='AdamW', lr=1e-4)
     # Add LRScheduler for lora `default`
@@ -84,7 +85,7 @@ def train():
         if step % 20 == 0:
             # Print metric
             metric = model.calculate_metric(is_training=True)
-            if Platform.get_rank() == 0:
+            if Platform.get_rank() == 0 and os.environ.get('SWANLAB_API_KEY'):
                 swanlab.log(metric)
             logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
         if step > 0 and step % 40 == 0:
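
The numbers in this commit are coupled: `fsdp_size=4, dp_size=2` implies an 8-GPU world, `batch_size=8` then gives 1 sample per GPU, and `gradient_accumulation_steps=2` doubles the samples seen per optimizer update. A minimal sketch of that arithmetic, assuming the standard definition of effective batch size under gradient accumulation (the `mesh_math` helper below is hypothetical, not part of twinkle):

```python
# Hypothetical helper, not part of twinkle: checks that the commit's
# device-mesh, dataloader, and gradient-accumulation numbers line up.

def mesh_math(fsdp_size: int, dp_size: int,
              global_batch_size: int, grad_accum_steps: int) -> None:
    world_size = fsdp_size * dp_size                  # 4 * 2 = 8 GPUs
    per_gpu = global_batch_size // world_size         # 8 // 8 = 1 sample per GPU
    effective = global_batch_size * grad_accum_steps  # 8 * 2 = 16 samples per update
    print(f'world_size={world_size}, per_gpu={per_gpu}, effective={effective}')

mesh_math(fsdp_size=4, dp_size=2, global_batch_size=8, grad_accum_steps=2)
# world_size=8, per_gpu=1, effective=16
```

The two `SWANLAB_API_KEY` guards follow the same pattern: `os.environ.get(...)` returns None when the key is absent, so SwanLab login and logging become opt-in on rank 0 instead of failing on the later `os.environ['SWANLAB_API_KEY']` lookup.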