@@ -182,68 +182,98 @@ def main():
182182 step += 1
183183 continue
184184
185- # ========== 6. Training step ==========
186- # Select samples with positive advantages for training
187- # Weight them by their advantage value for GRPO-style optimization
188- training_data = []
189- for i , seq in enumerate (all_sequences ):
190- if advantages [i ] <= 0 :
191- continue
192- # Build a Datum from the completion tokens
193- # Prompt tokens: weight=0 (don't compute loss on prompt)
194- # Completion tokens: weight=advantage (advantage-weighted SFT)
195- prompt_feature = prompts [i // NUM_GENERATIONS ]
196- prompt_ids = prompt_feature ['input_ids' ]
197- if hasattr (prompt_ids , 'tolist' ):
198- prompt_ids = prompt_ids .tolist ()
199-
200- full_tokens = prompt_ids + list (seq .tokens )
201- prompt_weights = [0.0 ] * len (prompt_ids )
202- # Scale completion weights by normalized advantage
203- completion_weights = [float (advantages [i ])] * len (seq .tokens )
204-
205- # Shift by one for next-token prediction
206- input_tokens = full_tokens [:- 1 ]
207- target_tokens = full_tokens [1 :]
208- weights = (prompt_weights + completion_weights )[1 :]
209-
210- datum = types .Datum (
211- model_input = types .ModelInput .from_ints (input_tokens ),
212- loss_fn_inputs = {
213- 'target_tokens' : target_tokens ,
214- 'weights' : weights ,
215- },
216- )
217- training_data .append (datum )
 185+ # Train the policy with the Group Relative Policy
 186+ # Optimization (GRPO) loss function.
187+ #
188+ # The GRPO loss function requires:
189+ # 1. logprobs: The log probabilities of the tokens under the current policy
190+ # 2. advantages: The advantage values for each completion
191+ #
192+ # The training data is constructed with:
193+ # - model_input: The full prompt + completion tokens
194+ # - target_tokens: The shifted tokens for next-token prediction
195+ # - logprobs: The log probabilities from the sampling step
196+ # - advantages: The computed advantage values
197+ training_data = []
198+ for i , seq in enumerate (all_sequences ):
199+ # Build a Datum from the completion tokens with logprobs and advantages
200+ prompt_feature = prompts [i // NUM_GENERATIONS ]
201+ prompt_ids = prompt_feature ['input_ids' ]
202+ if hasattr (prompt_ids , 'tolist' ):
203+ prompt_ids = prompt_ids .tolist ()
204+
205+ full_tokens = prompt_ids + list (seq .tokens )
206+
207+ # Shift by one for next-token prediction
208+ input_tokens = full_tokens [:- 1 ]
209+ target_tokens = full_tokens [1 :]
210+
211+ # Get logprobs from the sampling result
212+ logprobs = seq .logprobs if seq .logprobs else [0.0 ] * len (seq .tokens )
213+ # Pad logprobs to match full sequence length (prompt + completion)
214+ # Prompt positions get 0.0 logprobs (no loss computed on prompt)
215+ padded_logprobs = [0.0 ] * len (prompt_ids ) + logprobs
216+
217+ # Get advantage for this sequence
218+ advantage = float (advantages [i ])
219+
220+ # Pad advantages to match full sequence length
221+ # Only completion tokens get the advantage value, prompt gets 0.0
222+ padded_advantages = [0.0 ] * len (prompt_ids ) + [advantage ] * len (seq .tokens )
223+
 224+ # Verify lengths match — NOTE(review): input_tokens/target_tokens are len(full_tokens)-1 after the shift, but padded_logprobs/padded_advantages are len(full_tokens); this assert looks off-by-one (the removed code applied [1:] after concatenation) — confirm before merging
225+ assert len (input_tokens ) == len (target_tokens ) == len (padded_logprobs ) == len (padded_advantages ), \
226+ f"Length mismatch: input={ len (input_tokens )} , target={ len (target_tokens )} , " \
227+ f"logprobs={ len (padded_logprobs )} , advantages={ len (padded_advantages )} "
228+
229+ datum = types .Datum (
230+ model_input = types .ModelInput .from_ints (input_tokens ),
231+ loss_fn_inputs = {
232+ 'target_tokens' : target_tokens ,
233+ 'logprobs' : types .TensorData .from_numpy (np .array (padded_logprobs , dtype = np .float32 )),
234+ 'advantages' : types .TensorData .from_numpy (np .array (padded_advantages , dtype = np .float32 )),
235+ },
236+ )
237+ training_data .append (datum )
218238
219239 if not training_data :
220240 logger .info (
221- f"Step { step } : No positive-advantage samples , skipping" )
241+ f"Step { step } : No training data constructed , skipping" )
222242 step += 1
223243 continue
224244
225- # Forward-backward pass with cross-entropy on advantage-weighted data
245+ # Forward-backward pass with importance_sampling (GRPO) loss
246+ # The training data already contains logprobs and advantages for the GRPO loss
226247 fwdbwd_future = training_client .forward_backward (
227- training_data , "cross_entropy " )
248+ training_data , "importance_sampling " )
228249 optim_future = training_client .optim_step (
229250 types .AdamParams (learning_rate = LEARNING_RATE ))
230-
251+
231252 fwdbwd_result = fwdbwd_future .result ()
232253 optim_result = optim_future .result ()
233254
234- # Compute weighted average loss for monitoring
235- logprobs = np .concatenate (
236- [output ['logprobs' ].tolist ()
237- for output in fwdbwd_result .loss_fn_outputs ])
238- weights = np .concatenate (
239- [d .loss_fn_inputs ['weights' ].tolist () for d in training_data ])
240- loss_per_token = - np .dot (logprobs , weights ) / max (weights .sum (), 1e-8 )
255+ # Compute metrics from the forward-backward result
256+ # For importance_sampling, we get logprobs and elementwise_loss
257+ logprobs_list = []
258+ elementwise_losses = []
259+ for output in fwdbwd_result .loss_fn_outputs :
260+ if output .get ('logprobs' ) is not None :
261+ logprobs_list .append (output ['logprobs' ].to_numpy ())
262+ if output .get ('elementwise_loss' ) is not None :
263+ elementwise_losses .append (output ['elementwise_loss' ].to_numpy ())
264+
265+ # Compute average loss per token (weighted by advantages)
266+ if elementwise_losses :
267+ all_losses = np .concatenate (elementwise_losses )
268+ avg_loss = np .mean (all_losses ) if len (all_losses ) > 0 else 0.0
269+ else :
270+ avg_loss = 0.0
241271
242272 gc .collect ()
243273
244274 # ========== 7. Log ==========
245275 log_dict = metrics .calculate ()
246- log_dict ['train/loss_per_token' ] = loss_per_token
276+ log_dict ['train/loss_per_token' ] = float ( avg_loss )
247277 log_dict ['train/frac_reward_zero_std' ] = frac_zero_std
248278 log_dict ['train/num_training_samples' ] = len (training_data )
249279 logger .info (f"Step { step } : { log_dict } " )