@@ -510,27 +510,37 @@ def receive_weights(
510510 engine = self._get_or_create_checkpoint_engine()
511511
512512 async def _receive_and_load():
513- # Collect weights with original names — name conversion is done
514- # in the vLLM worker subprocess (TwinkleWorkerExtension).
515513 weights = {}
514+ total_count = 0
515+ batch_size = 200
516+
516517 async for name, tensor in engine.receive_weights():
517518 weights[name] = tensor.clone()
518519
519- if not weights:
520+ if len(weights) >= batch_size:
521+ await self.engine.update_weights(
522+ weights,
523+ peft_config=peft_config,
524+ base_sync_done=base_sync_done,
525+ )
526+ total_count += len(weights)
527+ weights = {}
528+
529+ if weights:
530+ await self.engine.update_weights(
531+ weights,
532+ peft_config=peft_config,
533+ base_sync_done=base_sync_done,
534+ )
535+ total_count += len(weights)
536+
537+ if total_count == 0:
520538 return 0
521539
522- await self.engine.update_weights(
523- weights,
524- peft_config=peft_config,
525- base_sync_done=base_sync_done,
526- )
527-
528- # After LoRA sync, mark that the synced LoRA is loaded so
529- # sampling automatically uses it.
530540 if base_sync_done and peft_config:
531541 self._ckpt_lora_loaded = True
532542
533- return len(weights)
543+ return total_count
534544
535545 return self ._run_in_loop (_receive_and_load ())
536546
0 commit comments