Skip to content

Commit 25e89ab

Browse files
committed
gsm8k demo
1 parent b1b0e71 commit 25e89ab

File tree

1 file changed

+17
-30
lines changed

1 file changed

+17
-30
lines changed

cookbook/legacy/grpo/gsm8k.py

Lines changed: 17 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@
3232
from twinkle.preprocessor import Preprocessor
3333
from twinkle.processor import InputProcessor
3434
from twinkle.reward.base import Reward
35-
from twinkle.sampler import VLLMSampler
35+
from twinkle.sampler import vLLMSampler
3636
from twinkle.template import Template
3737
from twinkle.metric import CompletionRewardMetric
3838

@@ -47,13 +47,13 @@
4747
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS
4848

4949
NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8))
50-
MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 1024))
50+
MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 2048))
5151
LEARNING_RATE = float(os.environ.get('LR', 1e-5))
5252
GRPO_EPSILON = float(os.environ.get('GRPO_EPSILON', 0.2))
5353
GRPO_BETA = float(os.environ.get('GRPO_BETA', 0.0))
5454
MAX_STEPS = int(os.environ.get('MAX_STEPS', 200))
5555
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4))
56-
GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2))
56+
GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
5757
TEMPERATURE = float(os.environ.get('TEMPERATURE', 1.0))
5858
WEIGHT_SYNC_INTERVAL = int(os.environ.get('WEIGHT_SYNC_INTERVAL', 1))
5959
ADAPTER_NAME = 'default'
@@ -80,7 +80,6 @@
8080
})
8181

8282

83-
# ========== GSM8K Data Processing ==========
8483
SYSTEM_PROMPT = (
8584
"You are a helpful math assistant. Solve the problem step by step. "
8685
"Show your reasoning in <think> </think> tags, then give the final "
@@ -326,11 +325,12 @@ def main():
326325
model.set_template('Template', model_id=MODEL_ID, adapter_name=ADAPTER_NAME)
327326

328327
# ── Sampler (load real weights for meaningful generation) ─────────
329-
sampler = VLLMSampler(
328+
sampler = vLLMSampler(
330329
model_id=MODEL_ID,
331330
engine_args={
332331
'gpu_memory_utilization': 0.7,
333-
'max_model_len': 2048,
332+
'max_model_len': 4096,
333+
'max_lora_rank': 64,
334334
'enforce_eager': True,
335335
'enable_sleep_mode': False,
336336
'enable_lora': True,
@@ -381,13 +381,12 @@ def main():
381381

382382
global_prompts = batch if isinstance(batch, list) else [batch]
383383

384-
# ========== 1. Weight Sync ==========
385384
t0 = time.perf_counter()
386385
if optim_step % WEIGHT_SYNC_INTERVAL == 0:
387386
ckpt_manager.sync_weights(adapter_name=ADAPTER_NAME)
387+
sampler.reset_prefix_cache()
388388
timings['weight_sync'] = time.perf_counter() - t0
389389

390-
# ========== 2. Generate ==========
391390
t1 = time.perf_counter()
392391
sample_response = sampler.sample(
393392
global_prompts,
@@ -445,32 +444,20 @@ def main():
445444
1.0 if all(abs(a) < 1e-8 for a in advantages) else 0.0
446445
)
447446

448-
# ========== 5. Training (micro-batches) ==========
447+
# ========== 5. Training ==========
449448
t4 = time.perf_counter()
450-
micro_batch_seqs = BATCH_SIZE * NUM_GENERATIONS
451-
452-
for micro_idx in range(GRADIENT_ACCUMULATION_STEPS):
453-
start = micro_idx * micro_batch_seqs
454-
end = start + micro_batch_seqs
455-
mb_inputs = all_input_data[start:end]
456-
mb_old_logps = all_old_logps[start:end]
457-
mb_advantages = advantages[start:end]
458-
459-
if not mb_inputs:
460-
break
461-
462-
if all(abs(a) < 1e-8 for a in mb_advantages):
463-
logger.info(
464-
f"Optim step {optim_step}, micro {micro_idx}: "
465-
f"All advantages zero, skipping"
466-
)
467-
continue
468449

450+
if all(abs(a) < 1e-8 for a in advantages):
451+
logger.info(
452+
f"Optim step {optim_step}: "
453+
f"All advantages zero, skipping training"
454+
)
455+
else:
469456
model.forward_backward(
470-
inputs=mb_inputs,
457+
inputs=all_input_data,
471458
adapter_name=ADAPTER_NAME,
472-
advantages=mb_advantages,
473-
old_logps=mb_old_logps,
459+
advantages=advantages,
460+
old_logps=all_old_logps,
474461
)
475462

476463
model.clip_grad_and_step(adapter_name=ADAPTER_NAME)

0 commit comments

Comments (0)