logger = get_logger()

# Base model checkpoint; override via the MODEL_ID env var.
MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
# Train with the Megatron backend by default; set USE_MEGATRON=0 to use the
# plain Transformers backend instead.
USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1')))

MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4))
MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
LEARNING_RATE = float(os.environ.get('LR', 1e-5))
MAX_STEPS = int(os.environ.get('MAX_STEPS', 200))
# Global prompt-level batch size; the global completion-level batch size is
# BATCH_SIZE * num_generations * dp_size.
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8))
# Global completion-level mini-batch size.
MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 8))
# Per-device micro-batch size (completion-level); the batch_size used in
# forward_backward.
MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2))
GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
ADAPTER_NAME = 'default'
# Save a checkpoint every SAVE_STEPS optimizer steps.
SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 50))
def create_gsm8k_dataset():
    """Build the GSM8K training dataset, templated and encoded for MODEL_ID.

    Loads the ModelScope GSM8K 'main'/'train' split, applies the chat
    template, maps the GSM8K preprocessor over it, and encodes prompts with
    the generation prompt appended.

    Returns:
        The prepared ``Dataset`` instance, ready for sampling/training.
    """
    dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train'))
    # NOTE(review): max_length was lowered from 2048 to 400 — confirm GSM8K
    # prompts are not truncated at this length.
    dataset.set_template('Template', model_id=MODEL_ID, max_length=400)
    dataset.map(GSM8KProcessor())
    dataset.encode(add_generation_prompt=True)
    return dataset
@@ -68,13 +69,21 @@ def main():
6869 sampler_mesh = DeviceMesh .from_sizes (world_size = SAMPLER_GPUS , dp_size = SAMPLER_GPUS )
6970 twinkle .initialize (mode = 'ray' , nproc_per_node = NUM_GPUS , groups = device_groups , lazy_collect = False )
7071
71- lora_config = LoraConfig (target_modules = 'all-linear' , r = 32 , lora_alpha = 64 , lora_dropout = 0.05 )
72-
72+ # lora_config = LoraConfig(target_modules='all-linear', r=32, lora_alpha=64, lora_dropout=0.05)
73+ lora_config = LoraConfig (
74+ target_modules = [
75+ 'q_proj' , 'k_proj' , 'v_proj' , 'o_proj' ,
76+ 'gate_proj' , 'up_proj' , 'down_proj' ,
77+ 'in_proj_qkv' , 'in_proj_z' , 'in_proj_a' , 'in_proj_b' , 'out_proj' ,
78+ ],
79+ r = 32 , lora_alpha = 64 , lora_dropout = 0.05 ,
80+ )
7381 if USE_MEGATRON :
7482 from twinkle .model .megatron import MegatronModel
7583 model = MegatronModel (model_id = MODEL_ID , device_mesh = model_mesh , remote_group = 'model' , mixed_precision = 'bf16' )
7684 else :
77- model = TransformersModel (model_id = MODEL_ID , device_mesh = model_mesh , remote_group = 'model' )
85+ from transformers import Qwen3_5ForConditionalGeneration
86+ model = TransformersModel (model_id = MODEL_ID , model_cls = Qwen3_5ForConditionalGeneration , device_mesh = model_mesh , remote_group = 'model' )
7887
7988 model .add_adapter_to_model (ADAPTER_NAME , lora_config , gradient_accumulation_steps = 1 )
8089 if USE_MEGATRON :
@@ -91,8 +100,9 @@ def main():
91100 model_id = MODEL_ID ,
92101 engine_args = {
93102 'gpu_memory_utilization' : 0.8 ,
94- 'max_model_len' : 4096 ,
103+ 'max_model_len' : 4496 ,
95104 'max_lora_rank' : 32 , # save as lora_config
105+ # NOTE: To use enable_lora with qwen3.5, ensure vLLM includes PR https://github.com/vllm-project/vllm/pull/36976
96106 'enable_lora' : True ,
97107 },
98108 device_mesh = sampler_mesh ,
@@ -172,6 +182,8 @@ def main():
172182
173183 if optim_step >= MAX_STEPS :
174184 break
185+ if optim_step % SAVE_STEPS == 0 :
186+ model .save (f'grpo-gsm8k-checkpoint-{ optim_step } ' )
175187 log_dict = metrics .calculate ()
176188 log_dict .update (model .calculate_metric (is_training = True ))
177189 metrics .reset ()
0 commit comments