Skip to content

Commit 0095fc0

Browse files
committed
wip
1 parent 1451afc commit 0095fc0

1 file changed

Lines changed: 7 additions & 7 deletions

File tree

cookbook/legacy/grpo/gsm8k.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -39,11 +39,11 @@
3939
logger = get_logger()
4040

4141
# ========== Configuration ==========
42-
MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen2.5-3B-Instruct')
42+
MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3-30B-A3B-Instruct-2507')
4343
USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1')))
4444

4545
MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
46-
SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 2))
46+
SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4))
4747
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS
4848

4949
NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8))
@@ -242,22 +242,21 @@ def main():
242242
]
243243
if USE_MEGATRON:
244244
model_mesh = DeviceMesh.from_sizes(
245-
dp_size=MODEL_GPUS, tp_size=1, pp_size=1,
245+
dp_size=1, tp_size=SAMPLER_GPUS, ep_size=SAMPLER_GPUS,
246246
)
247247
else:
248248
model_mesh = DeviceMesh.from_sizes(
249249
world_size=MODEL_GPUS, dp_size=MODEL_GPUS,
250250
)
251251
sampler_mesh = DeviceMesh.from_sizes(
252-
world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS,
252+
world_size=SAMPLER_GPUS, dp_size=4
253253
)
254254
twinkle.initialize(
255255
mode='ray',
256256
nproc_per_node=NUM_GPUS,
257257
groups=device_groups,
258258
lazy_collect=False,
259259
)
260-
logger.info(get_device_placement())
261260

262261
lora_config = LoraConfig(
263262
target_modules="all-linear",
@@ -316,8 +315,8 @@ def main():
316315
model_id=MODEL_ID,
317316
engine_args={
318317
'gpu_memory_utilization': 0.7,
319-
'max_model_len': 4096,
320-
'max_lora_rank': 64,
318+
'max_model_len': 2048,
319+
'max_lora_rank': 32,
321320
'enforce_eager': True,
322321
'enable_sleep_mode': False,
323322
'enable_lora': False,
@@ -351,6 +350,7 @@ def main():
351350

352351
# ── Training loop ────────────────────────────────────────────────
353352
optim_step = 0
353+
logger.info(get_device_placement())
354354

355355
for batch in dataloader:
356356
if optim_step >= MAX_STEPS:

0 commit comments

Comments
 (0)