|
39 | 39 | logger = get_logger() |
40 | 40 |
|
41 | 41 | # ========== Configuration ========== |
42 | | -MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen2.5-3B-Instruct') |
| 42 | +MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3-30B-A3B-Instruct-2507') |
43 | 43 | USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1'))) |
44 | 44 |
|
45 | 45 | MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) |
46 | | -SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 2)) |
| 46 | +SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4)) |
47 | 47 | NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS |
48 | 48 |
|
49 | 49 | NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8)) |
@@ -242,22 +242,21 @@ def main(): |
242 | 242 | ] |
243 | 243 | if USE_MEGATRON: |
244 | 244 | model_mesh = DeviceMesh.from_sizes( |
245 | | - dp_size=MODEL_GPUS, tp_size=1, pp_size=1, |
| 245 | + dp_size=1, tp_size=SAMPLER_GPUS, ep_size=SAMPLER_GPUS, |
246 | 246 | ) |
247 | 247 | else: |
248 | 248 | model_mesh = DeviceMesh.from_sizes( |
249 | 249 | world_size=MODEL_GPUS, dp_size=MODEL_GPUS, |
250 | 250 | ) |
251 | 251 | sampler_mesh = DeviceMesh.from_sizes( |
252 | | - world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS, |
| 252 | + world_size=SAMPLER_GPUS, dp_size=4 |
253 | 253 | ) |
254 | 254 | twinkle.initialize( |
255 | 255 | mode='ray', |
256 | 256 | nproc_per_node=NUM_GPUS, |
257 | 257 | groups=device_groups, |
258 | 258 | lazy_collect=False, |
259 | 259 | ) |
260 | | - logger.info(get_device_placement()) |
261 | 260 |
|
262 | 261 | lora_config = LoraConfig( |
263 | 262 | target_modules="all-linear", |
@@ -316,8 +315,8 @@ def main(): |
316 | 315 | model_id=MODEL_ID, |
317 | 316 | engine_args={ |
318 | 317 | 'gpu_memory_utilization': 0.7, |
319 | | - 'max_model_len': 4096, |
320 | | - 'max_lora_rank': 64, |
| 318 | + 'max_model_len': 2048, |
| 319 | + 'max_lora_rank': 32, |
321 | 320 | 'enforce_eager': True, |
322 | 321 | 'enable_sleep_mode': False, |
323 | 322 | 'enable_lora': False, |
@@ -351,6 +350,7 @@ def main(): |
351 | 350 |
|
352 | 351 | # ── Training loop ──────────────────────────────────────────────── |
353 | 352 | optim_step = 0 |
| 353 | + logger.info(get_device_placement()) |
354 | 354 |
|
355 | 355 | for batch in dataloader: |
356 | 356 | if optim_step >= MAX_STEPS: |
|
0 commit comments