Skip to content

Commit 00c1e52

Browse files
committed
group sync
1 parent 1a49246 commit 00c1e52

1 file changed

Lines changed: 22 additions & 12 deletions

File tree

src/twinkle/sampler/vllm_sampler/vllm_sampler.py

Lines changed: 22 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -510,27 +510,37 @@ def receive_weights(
510510
engine = self._get_or_create_checkpoint_engine()
511511

512512
async def _receive_and_load():
513-
# Collect weights with original names — name conversion is done
514-
# in the vLLM worker subprocess (TwinkleWorkerExtension).
515513
weights = {}
514+
total_count = 0
515+
batch_size = 200
516+
516517
async for name, tensor in engine.receive_weights():
517518
weights[name] = tensor.clone()
518519

519-
if not weights:
520+
if len(weights) >= batch_size:
521+
await self.engine.update_weights(
522+
weights,
523+
peft_config=peft_config,
524+
base_sync_done=base_sync_done,
525+
)
526+
total_count += len(weights)
527+
weights = {}
528+
529+
if weights:
530+
await self.engine.update_weights(
531+
weights,
532+
peft_config=peft_config,
533+
base_sync_done=base_sync_done,
534+
)
535+
total_count += len(weights)
536+
537+
if total_count == 0:
520538
return 0
521539

522-
await self.engine.update_weights(
523-
weights,
524-
peft_config=peft_config,
525-
base_sync_done=base_sync_done,
526-
)
527-
528-
# After LoRA sync, mark that the synced LoRA is loaded so
529-
# sampling automatically uses it.
530540
if base_sync_done and peft_config:
531541
self._ckpt_lora_loaded = True
532542

533-
return len(weights)
543+
return total_count
534544

535545
return self._run_in_loop(_receive_and_load())
536546

0 commit comments

Comments
 (0)