Skip to content

Commit af9ee69

Browse files
committed
Merge commit '9871f3cde444687ba53d8bcda41d82398fba0740' into dev
2 parents 905cc57 + 9871f3c commit af9ee69

File tree

5 files changed

+41
-32
lines changed

5 files changed

+41
-32
lines changed

cookbook/legacy/grpo/dapo_math.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,12 @@
5555
SAMPLER_TP = int(os.environ.get('SAMPLER_TP', 1))
5656
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS
5757
PP_SIZE = 4
58-
NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 4))
59-
MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 2048))
58+
NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8))
59+
MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
6060
LEARNING_RATE = float(os.environ.get('LR', 1e-5))
6161
GRPO_EPSILON = float(os.environ.get('GRPO_EPSILON', 0.2))
6262
GRPO_BETA = float(os.environ.get('GRPO_BETA', 0.0))
63-
MAX_STEPS = int(os.environ.get('MAX_STEPS', 200))
63+
MAX_STEPS = int(os.environ.get('MAX_STEPS', 2000))
6464
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 1))
6565
GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
6666
TEMPERATURE = float(os.environ.get('TEMPERATURE', 1.0))
@@ -387,7 +387,7 @@ def main():
387387
model_id=MODEL_ID,
388388
engine_args={
389389
'gpu_memory_utilization': 0.8,
390-
'max_model_len': 8192,
390+
'max_model_len': 6000,
391391
'max_loras': 1,
392392
'max_lora_rank': 32,
393393
'enable_sleep_mode': False,
@@ -408,7 +408,7 @@ def main():
408408
remote_group='model',
409409
mixed_precision='bf16',
410410
recompute_granularity='full',
411-
recompute_num_layers=None,
411+
recompute_num_layers=1,
412412
)
413413
else:
414414
model = TransformersModel(

cookbook/legacy/grpo/gsm8k_dense.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,8 +261,8 @@ def main():
261261

262262
lora_config = LoraConfig(
263263
target_modules="all-linear",
264-
r=8,
265-
lora_alpha=32,
264+
r=32,
265+
lora_alpha=64,
266266
lora_dropout=0.05,
267267
)
268268

@@ -274,7 +274,7 @@ def main():
274274
remote_group='model',
275275
mixed_precision='bf16',
276276
recompute_granularity='full',
277-
recompute_num_layers=None,
277+
recompute_num_layers=1,
278278
)
279279
else:
280280
model = TransformersModel(

src/twinkle/model/megatron/megatron.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,28 +1240,41 @@ def weight_generator():
12401240
else:
12411241
yield from _raw_weights()
12421242

1243-
async def _send():
1244-
await engine.send_weights(weight_generator())
1243+
is_sender = (engine.rank is not None and engine.rank == 0)
12451244

1246-
result_container = {'error': None}
1245+
if not is_sender:
1246+
for _name, _tensor in weight_generator():
1247+
pass
1248+
return
1249+
1250+
import queue
1251+
buf: queue.Queue = queue.Queue(maxsize=4)
1252+
error: list = []
12471253

1248-
def _run():
1254+
def _send():
1255+
def _iter():
1256+
while (item := buf.get()) is not None:
1257+
yield item
1258+
loop = asyncio.new_event_loop()
12491259
try:
1250-
loop = asyncio.new_event_loop()
1251-
asyncio.set_event_loop(loop)
1252-
try:
1253-
loop.run_until_complete(_send())
1254-
finally:
1255-
loop.close()
1256-
except Exception as e:
1257-
result_container['error'] = e
1258-
1259-
thread = threading.Thread(target=_run)
1260-
thread.start()
1261-
thread.join()
1262-
1263-
if result_container['error'] is not None:
1264-
raise result_container['error']
1260+
loop.run_until_complete(engine.send_weights(_iter()))
1261+
except Exception as exc:
1262+
error.append(exc)
1263+
finally:
1264+
loop.close()
1265+
1266+
sender = threading.Thread(target=_send, name="ce-broadcast", daemon=True)
1267+
sender.start()
1268+
try:
1269+
for name, tensor in weight_generator():
1270+
buf.put((name, tensor.clone()))
1271+
if error:
1272+
break
1273+
finally:
1274+
buf.put(None) # sentinel
1275+
sender.join()
1276+
if error:
1277+
raise error[0]
12651278

12661279
@remote_function(collect='first')
12671280
def get_peft_config_dict(self, adapter_name: str = None) -> dict:

src/twinkle/utils/torch_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,6 @@ def selective_log_softmax(logits, index) -> 'torch.Tensor':
131131
print(traceback.format_exc())
132132
except Exception:
133133
pass
134-
135134
if logits.dtype in [torch.float32, torch.float64]:
136135
selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
137136
# loop to reduce peak mem consumption

tests/sampler/test_weight_sync.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,14 +140,11 @@ def test_standalone_weight_sync(model_gpus: int = 1, sampler_gpus: int = 1):
140140
)
141141
from peft import LoraConfig
142142
model.add_adapter_to_model('default', LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05, target_modules="all-linear"), gradient_accumulation_steps=1)
143-
lora_path = '/mnt/nas2/hujinghan.hjh/swift/output/v1168-20260209-194533/checkpoint-32/default/output/v0-20260209-212154/checkpoint-32'
144-
145-
model.load('default', output_dir=lora_path, adapter_name='default')
146143
# ── Create Sampler (dummy weights) ────────────────────────────────
147144
sampler = vLLMSampler(
148145
model_id=model_path,
149146
engine_args={
150-
# 'load_format': 'dummy', # start with random weights
147+
'load_format': 'dummy', # start with random weights
151148
'gpu_memory_utilization': 0.3,
152149
'max_model_len': 256,
153150
'enforce_eager': True,

0 commit comments

Comments (0)