Skip to content

Commit 4316583

Browse files
committed
fix moe sp
1 parent d7503d8 commit 4316583

File tree

3 files changed

+21
-10
lines changed

3 files changed

+21
-10
lines changed

cookbook/legacy/grpo/dapo_math.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,14 @@
5454
SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4))
5555
SAMPLER_TP = int(os.environ.get('SAMPLER_TP', 1))
5656
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS
57-
PP_SIZE = 2
58-
NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8))
57+
PP_SIZE = 4
58+
NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 4))
5959
MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
6060
LEARNING_RATE = float(os.environ.get('LR', 1e-5))
6161
GRPO_EPSILON = float(os.environ.get('GRPO_EPSILON', 0.2))
6262
GRPO_BETA = float(os.environ.get('GRPO_BETA', 0.0))
6363
MAX_STEPS = int(os.environ.get('MAX_STEPS', 200))
64-
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 2))
64+
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 1))
6565
GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
6666
TEMPERATURE = float(os.environ.get('TEMPERATURE', 1.0))
6767
WEIGHT_SYNC_INTERVAL = int(os.environ.get('WEIGHT_SYNC_INTERVAL', 1))
@@ -334,6 +334,9 @@ def compute_rewards(
334334

335335
# ========== Main ==========
336336
def main():
337+
from twinkle.utils.import_utils import requires
338+
requires("vllm>=0.15.0")
339+
337340
device_groups = [
338341
DeviceGroup(
339342
name='model',
@@ -350,8 +353,10 @@ def main():
350353
]
351354
if USE_MEGATRON:
352355
model_mesh = DeviceMesh.from_sizes(
353-
dp_size=MODEL_GPUS // PP_SIZE, pp_size=PP_SIZE,
354-
ep_size=MODEL_GPUS // PP_SIZE,
356+
dp_size=1,
357+
tp_size=2,
358+
pp_size=2,
359+
ep_size=2,
355360
)
356361
else:
357362
model_mesh = DeviceMesh.from_sizes(
@@ -370,7 +375,7 @@ def main():
370375
)
371376

372377
lora_config = LoraConfig(
373-
target_modules='all-linear',
378+
target_modules=['linear_qkv', 'linear_proj'],
374379
r=8,
375380
lora_alpha=32,
376381
lora_dropout=0.05,

src/twinkle/model/megatron/args.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -557,9 +557,11 @@ def _get_base_model(m):
557557
use_sequence_parallel = self.sequence_parallel and self.tp_size > 1
558558
if num_experts > 0 and self.tp_size > 1 and not use_sequence_parallel:
559559
use_sequence_parallel = True
560-
print(
561-
f'Auto-enabling sequence_parallel for MoE with TP={self.tp_size}'
562-
)
560+
# Sync the flag back so that callers (e.g. padding logic in
561+
# megatron.py) see the auto-enabled value.
562+
self.sequence_parallel = True
563+
if self.device_mesh is not None:
564+
self.device_mesh.sequence_parallel = True
563565

564566
# For MoE models, ffn_hidden_size should be moe_ffn_hidden_size if not specified
565567
ffn_hidden_size = mg_config_dict.get('ffn_hidden_size')

src/twinkle/model/megatron/strategy/megatron.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,15 @@ def __init__(
1818
**kwargs,
1919
):
2020
self.device_mesh = device_mesh
21-
self.sequence_parallel = self.device_mesh.sequence_parallel
2221
self.use_distributed_optimizer = use_distributed_optimizer
2322
self.mixed_precision = mixed_precision
2423
self._params_dtype = params_dtype
2524

25+
@property
26+
def sequence_parallel(self) -> bool:
27+
"""Read from device_mesh so auto-enable in args.py is visible."""
28+
return getattr(self.device_mesh, 'sequence_parallel', False)
29+
2630
def _check_device_mesh(self):
2731
from megatron.core import parallel_state as mpu
2832

0 commit comments

Comments
 (0)