5454SAMPLER_GPUS = int (os .environ .get ('SAMPLER_GPUS' , 4 ))
5555SAMPLER_TP = int (os .environ .get ('SAMPLER_TP' , 1 ))
5656NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS
57- PP_SIZE = 2
58- NUM_GENERATIONS = int (os .environ .get ('NUM_GENERATIONS' , 8 ))
57+ PP_SIZE = 4
58+ NUM_GENERATIONS = int (os .environ .get ('NUM_GENERATIONS' , 4 ))
5959MAX_NEW_TOKENS = int (os .environ .get ('MAX_NEW_TOKENS' , 4096 ))
6060LEARNING_RATE = float (os .environ .get ('LR' , 1e-5 ))
6161GRPO_EPSILON = float (os .environ .get ('GRPO_EPSILON' , 0.2 ))
6262GRPO_BETA = float (os .environ .get ('GRPO_BETA' , 0.0 ))
6363MAX_STEPS = int (os .environ .get ('MAX_STEPS' , 200 ))
64- BATCH_SIZE = int (os .environ .get ('BATCH_SIZE' , 2 ))
64+ BATCH_SIZE = int (os .environ .get ('BATCH_SIZE' , 1 ))
6565GRADIENT_ACCUMULATION_STEPS = int (os .environ .get ('GRADIENT_ACCUMULATION_STEPS' , 1 ))
6666TEMPERATURE = float (os .environ .get ('TEMPERATURE' , 1.0 ))
6767WEIGHT_SYNC_INTERVAL = int (os .environ .get ('WEIGHT_SYNC_INTERVAL' , 1 ))
@@ -334,6 +334,9 @@ def compute_rewards(
334334
335335# ========== Main ==========
336336def main ():
337+ from twinkle .utils .import_utils import requires
338+ requires ("vllm>=0.15.0" )
339+
337340 device_groups = [
338341 DeviceGroup (
339342 name = 'model' ,
@@ -350,8 +353,10 @@ def main():
350353 ]
351354 if USE_MEGATRON :
352355 model_mesh = DeviceMesh .from_sizes (
353- dp_size = MODEL_GPUS // PP_SIZE , pp_size = PP_SIZE ,
354- ep_size = MODEL_GPUS // PP_SIZE ,
356+ dp_size = 1 ,
357+ tp_size = 2 ,
358+ pp_size = 2 ,
359+ ep_size = 2 ,
355360 )
356361 else :
357362 model_mesh = DeviceMesh .from_sizes (
@@ -370,7 +375,7 @@ def main():
370375 )
371376
372377 lora_config = LoraConfig (
373- target_modules = 'all-linear' ,
378+ target_modules = [ 'linear_qkv' , 'linear_proj' ] ,
374379 r = 8 ,
375380 lora_alpha = 32 ,
376381 lora_dropout = 0.05 ,
0 commit comments