|
32 | 32 | from twinkle.preprocessor import Preprocessor |
33 | 33 | from twinkle.processor import InputProcessor |
34 | 34 | from twinkle.reward.base import Reward |
35 | | -from twinkle.sampler import VLLMSampler |
| 35 | +from twinkle.sampler import vLLMSampler |
36 | 36 | from twinkle.template import Template |
37 | 37 | from twinkle.metric import CompletionRewardMetric |
38 | 38 |
|
|
47 | 47 | NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS |
48 | 48 |
|
49 | 49 | NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8)) |
50 | | -MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 1024)) |
| 50 | +MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 2048)) |
51 | 51 | LEARNING_RATE = float(os.environ.get('LR', 1e-5)) |
52 | 52 | GRPO_EPSILON = float(os.environ.get('GRPO_EPSILON', 0.2)) |
53 | 53 | GRPO_BETA = float(os.environ.get('GRPO_BETA', 0.0)) |
54 | 54 | MAX_STEPS = int(os.environ.get('MAX_STEPS', 200)) |
55 | 55 | BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4)) |
56 | | -GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2)) |
| 56 | +GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) |
57 | 57 | TEMPERATURE = float(os.environ.get('TEMPERATURE', 1.0)) |
58 | 58 | WEIGHT_SYNC_INTERVAL = int(os.environ.get('WEIGHT_SYNC_INTERVAL', 1)) |
59 | 59 | ADAPTER_NAME = 'default' |
|
80 | 80 | }) |
81 | 81 |
|
82 | 82 |
|
83 | | -# ========== GSM8K Data Processing ========== |
84 | 83 | SYSTEM_PROMPT = ( |
85 | 84 | "You are a helpful math assistant. Solve the problem step by step. " |
86 | 85 | "Show your reasoning in <think> </think> tags, then give the final " |
@@ -326,11 +325,12 @@ def main(): |
326 | 325 | model.set_template('Template', model_id=MODEL_ID, adapter_name=ADAPTER_NAME) |
327 | 326 |
|
328 | 327 | # ── Sampler (load real weights for meaningful generation) ───────── |
329 | | - sampler = VLLMSampler( |
| 328 | + sampler = vLLMSampler( |
330 | 329 | model_id=MODEL_ID, |
331 | 330 | engine_args={ |
332 | 331 | 'gpu_memory_utilization': 0.7, |
333 | | - 'max_model_len': 2048, |
| 332 | + 'max_model_len': 4096, |
| 333 | + 'max_lora_rank': 64, |
334 | 334 | 'enforce_eager': True, |
335 | 335 | 'enable_sleep_mode': False, |
336 | 336 | 'enable_lora': True, |
@@ -381,13 +381,12 @@ def main(): |
381 | 381 |
|
382 | 382 | global_prompts = batch if isinstance(batch, list) else [batch] |
383 | 383 |
|
384 | | - # ========== 1. Weight Sync ========== |
385 | 384 | t0 = time.perf_counter() |
386 | 385 | if optim_step % WEIGHT_SYNC_INTERVAL == 0: |
387 | 386 | ckpt_manager.sync_weights(adapter_name=ADAPTER_NAME) |
| 387 | + sampler.reset_prefix_cache() |
388 | 388 | timings['weight_sync'] = time.perf_counter() - t0 |
389 | 389 |
|
390 | | - # ========== 2. Generate ========== |
391 | 390 | t1 = time.perf_counter() |
392 | 391 | sample_response = sampler.sample( |
393 | 392 | global_prompts, |
@@ -445,32 +444,20 @@ def main(): |
445 | 444 | 1.0 if all(abs(a) < 1e-8 for a in advantages) else 0.0 |
446 | 445 | ) |
447 | 446 |
|
448 | | - # ========== 5. Training (micro-batches) ========== |
| 447 | + # ========== 5. Training ========== |
449 | 448 | t4 = time.perf_counter() |
450 | | - micro_batch_seqs = BATCH_SIZE * NUM_GENERATIONS |
451 | | - |
452 | | - for micro_idx in range(GRADIENT_ACCUMULATION_STEPS): |
453 | | - start = micro_idx * micro_batch_seqs |
454 | | - end = start + micro_batch_seqs |
455 | | - mb_inputs = all_input_data[start:end] |
456 | | - mb_old_logps = all_old_logps[start:end] |
457 | | - mb_advantages = advantages[start:end] |
458 | | - |
459 | | - if not mb_inputs: |
460 | | - break |
461 | | - |
462 | | - if all(abs(a) < 1e-8 for a in mb_advantages): |
463 | | - logger.info( |
464 | | - f"Optim step {optim_step}, micro {micro_idx}: " |
465 | | - f"All advantages zero, skipping" |
466 | | - ) |
467 | | - continue |
468 | 449 |
|
| 450 | + if all(abs(a) < 1e-8 for a in advantages): |
| 451 | + logger.info( |
| 452 | + f"Optim step {optim_step}: " |
| 453 | + f"All advantages zero, skipping training" |
| 454 | + ) |
| 455 | + else: |
469 | 456 | model.forward_backward( |
470 | | - inputs=mb_inputs, |
| 457 | + inputs=all_input_data, |
471 | 458 | adapter_name=ADAPTER_NAME, |
472 | | - advantages=mb_advantages, |
473 | | - old_logps=mb_old_logps, |
| 459 | + advantages=advantages, |
| 460 | + old_logps=all_old_logps, |
474 | 461 | ) |
475 | 462 |
|
476 | 463 | model.clip_grad_and_step(adapter_name=ADAPTER_NAME) |
|
0 commit comments