Skip to content

Commit aaddedf

Browse files
committed
Revert "group sync"
This reverts commit 00c1e52.
1 parent 00c1e52 commit aaddedf

File tree

1 file changed

+12
-22
lines changed

1 file changed

+12
-22
lines changed

src/twinkle/sampler/vllm_sampler/vllm_sampler.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -510,37 +510,27 @@ def receive_weights(
510510
engine = self._get_or_create_checkpoint_engine()
511511

512512
async def _receive_and_load():
513+
# Collect weights with original names — name conversion is done
514+
# in the vLLM worker subprocess (TwinkleWorkerExtension).
513515
weights = {}
514-
total_count = 0
515-
batch_size = 200
516-
517516
async for name, tensor in engine.receive_weights():
518517
weights[name] = tensor.clone()
519518

520-
if len(weights) >= batch_size:
521-
await self.engine.update_weights(
522-
weights,
523-
peft_config=peft_config,
524-
base_sync_done=base_sync_done,
525-
)
526-
total_count += len(weights)
527-
weights = {}
528-
529-
if weights:
530-
await self.engine.update_weights(
531-
weights,
532-
peft_config=peft_config,
533-
base_sync_done=base_sync_done,
534-
)
535-
total_count += len(weights)
536-
537-
if total_count == 0:
519+
if not weights:
538520
return 0
539521

522+
await self.engine.update_weights(
523+
weights,
524+
peft_config=peft_config,
525+
base_sync_done=base_sync_done,
526+
)
527+
528+
# After LoRA sync, mark that the synced LoRA is loaded so
529+
# sampling automatically uses it.
540530
if base_sync_done and peft_config:
541531
self._ckpt_lora_loaded = True
542532

543-
return total_count
533+
return len(weights)
544534

545535
return self._run_in_loop(_receive_and_load())
546536

0 commit comments

Comments
 (0)