 from twinkle.model import TransformersModel
 from twinkle.reward import MathReward
 from twinkle.sampler import VLLMSampler, TorchSampler
-from twinkle.sampler.types import SamplingParams
+from twinkle.sampler.types import SamplingParams, SampleResponse
 from twinkle.weight_loader import NativeLoader
 from twinkle.rl import compute_advantages

 # Environment variable setup
 os.environ.setdefault('TRUST_REMOTE_CODE', '1')
 os.environ.setdefault('TWINKLE_SEED', '42')
 os.environ.setdefault('TWINKLE_FULL_DETERMINISM', '1')
+os.environ.setdefault('RAY_TMPDIR', os.path.expanduser('~/tmp/ray'))

 # Training configuration
 use_ref_model = os.environ.get('TWINKLE_USE_REF_MODEL', '1') != '0'
@@ -129,6 +130,29 @@ def get_sampling_params(eos_token_ids) -> SamplingParams:
     )


+def build_trajectories_from_sample_response(sample_response: SampleResponse, batch_list, tokenizer):
+    """Convert sampler output into GRPO trajectories."""
+    if not sample_response or not getattr(sample_response, 'sequences', None):
+        return []
+    if not batch_list:
+        return []
+
+    trajectories = []
+    for i, seq in enumerate(sample_response.sequences):
+        src_batch = batch_list[i % len(batch_list)]
+        src_messages = [dict(msg) for msg in src_batch.get('messages', [])]
+        if src_messages and src_messages[-1].get('role') == 'assistant':
+            # Drop the reference answer; the sampled reply is appended below.
+            src_messages = src_messages[:-1]
+
+        response_text = tokenizer.decode(seq.tokens, skip_special_tokens=True) if tokenizer is not None else ''
+        trajectories.append({
+            'messages': src_messages + [{'role': 'assistant', 'content': response_text}],
+            'user_data': list(src_batch.get('user_data', [])),
+        })
+    return trajectories
+
+
 def debug_print_rollout(step, trajectories, ground_truths, rewards=None):
     """Debug helper that prints rollout intermediates (sampling, rewards, etc.).

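A minimal usage sketch of the helper above, outside the diff. The SimpleNamespace objects are hypothetical stand-ins for the sampler's sequence type, assumed only to expose a .tokens list as the function expects; with tokenizer=None the decoded reply falls back to ''.

from types import SimpleNamespace

fake_response = SimpleNamespace(sequences=[SimpleNamespace(tokens=[1, 2, 3])])
fake_batch = [{
    'messages': [{'role': 'user', 'content': 'What is 2 + 2?'},
                 {'role': 'assistant', 'content': '4'}],  # reference answer, dropped
    'user_data': [],
}]
# Yields one trajectory whose last message is the (here empty) sampled reply.
print(build_trajectories_from_sample_response(fake_response, fake_batch, tokenizer=None))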
@@ -182,6 +206,19 @@ def debug_print_rollout(step, trajectories, ground_truths, rewards=None):
     )


+def _collect_sample_responses(results):
+    """Custom collect function to merge multiple SampleResponse objects."""
+    if not results:
+        return SampleResponse(sequences=[])
+    if len(results) == 1:
+        return results[0]
+    all_sequences = []
+    for resp in results:
+        if resp is not None and hasattr(resp, 'sequences'):
+            all_sequences.extend(resp.sequences)
+    return SampleResponse(sequences=all_sequences)
+
+
 @remote_class()
 class ActorGroup:

@@ -226,7 +263,7 @@ def __init__(self, engine_args=None, lora_config=None, adapter_name=None, **kwar
         self.adapter_name = adapter_name
         self.lora_config = lora_config

-    @remote_function(collect='flatten')
+    @remote_function(collect=_collect_sample_responses)
     def sample(self, batch, sampling_params: SamplingParams = None):
         return self.sampler.sample(batch, sampling_params=sampling_params, adapter_name=self.adapter_name)

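Why the collect change: each remote worker now presumably returns a whole SampleResponse rather than a plain list, so the list-oriented 'flatten' collect no longer applies; the custom collect concatenates the per-worker sequences into one response. A small illustration under that assumption, with strings as stand-ins for sequence objects:

r1 = SampleResponse(sequences=['seq-a', 'seq-b'])  # worker 1
r2 = SampleResponse(sequences=['seq-c'])           # worker 2
merged = _collect_sample_responses([r1, r2])
assert len(merged.sequences) == 3  # one merged response, order preserved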
@@ -293,6 +330,11 @@ def train():
     )

     eos_token_ids = get_eos_token_ids()
+    try:
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    except Exception:
+        tokenizer = None

     engine_args = {
         'model': model_path,
@@ -339,13 +381,18 @@ def train():
             batch_list = [batch]
         else:
             batch_list = list(batch)
-        ground_truths = batch_list.copy()
-
         sampling_params = get_sampling_params(eos_token_ids)

-        trajectories = actor_group.sample(batch_list, sampling_params)
-        if callable(trajectories):
-            trajectories = trajectories()
+        sample_response = actor_group.sample(batch_list, sampling_params)
+        if callable(sample_response):
+            sample_response = sample_response()
+        trajectories = build_trajectories_from_sample_response(sample_response, batch_list, tokenizer)
+        if not trajectories:
+            print(f'[step {step}] empty sampled trajectories, skip.', flush=True)
+            continue
+
+        # Expand ground truths to align with the sampled trajectory count.
+        ground_truths = [batch_list[i % len(batch_list)] for i in range(len(trajectories))]

         ref_logits = None
         if use_ref_model:
@@ -357,14 +404,19 @@ def train():
             else:
                 ref_logits = ref_outputs['logits'] if isinstance(ref_outputs, dict) else ref_outputs.logits

-        rewards = reward.calculate(trajectories, ground_truths)
+        rewards = reward(trajectories, ground_truths)
         if callable(rewards):
             rewards = rewards()

-        # Updated: compute advantages from rewards and store in trajectory
-        advantages = compute_advantages(rewards, num_generations=num_generations)
+        effective_num_generations = num_generations if len(rewards) % num_generations == 0 else 1
+        scale = 'group' if effective_num_generations > 1 else 'batch'
+        advantages = compute_advantages(
+            rewards,
+            num_generations=effective_num_generations,
+            scale=scale,
+        )
         for trajectory, advantage in zip(trajectories, advantages.tolist()):
-            trajectory['advantages'] = advantage
+            trajectory['advantages'] = float(advantage)

         # Debug: print reward statistics (enable via TWINKLE_DEBUG=1)
         debug_print_rollout(step, trajectories, ground_truths, rewards=rewards)
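For intuition about the scale switch: compute_advantages is twinkle's API, and the sketch below is only an assumption about what 'group' scaling means in GRPO-style training, not the library's implementation. Each group of num_generations completions sampled from the same prompt is whitened against its own mean and standard deviation, so rewards are compared within a prompt rather than across the whole batch.

import torch

def grpo_group_advantages(rewards, num_generations, eps=1e-4):
    # Reshape to (num_prompts, num_generations) and normalize per group.
    r = torch.as_tensor(rewards, dtype=torch.float32).view(-1, num_generations)
    adv = (r - r.mean(dim=1, keepdim=True)) / (r.std(dim=1, keepdim=True) + eps)
    return adv.view(-1)

# Two prompts, two generations each: only within-prompt differences matter,
# so the second prompt's identical rewards produce zero advantage.
print(grpo_group_advantages([1.0, 0.0, 1.0, 1.0], num_generations=2))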
@@ -383,5 +435,6 @@ def train():
             break


+
 if __name__ == '__main__':
     train()
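Assuming the script is saved as train_grpo.py (a hypothetical name; the diff does not show the filename), a quick smoke run with rollout debugging enabled and the reference model disabled could look like:

TWINKLE_DEBUG=1 TWINKLE_USE_REF_MODEL=0 python train_grpo.py

Both variables are read by the script itself: TWINKLE_USE_REF_MODEL in the training-configuration block at the top, and TWINKLE_DEBUG by debug_print_rollout, per the comment above its call site.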