Commit 5e96763

update

1 parent 57e4482 commit 5e96763

6 files changed: +232 −52 lines
cookbook/client/tinker/transformer/grpo.py

Lines changed: 75 additions & 45 deletions

@@ -182,68 +182,98 @@ def main():
             step += 1
             continue

-        # ========== 6. Training step ==========
-        # Select samples with positive advantages for training
-        # Weight them by their advantage value for GRPO-style optimization
-        training_data = []
-        for i, seq in enumerate(all_sequences):
-            if advantages[i] <= 0:
-                continue
-            # Build a Datum from the completion tokens
-            # Prompt tokens: weight=0 (don't compute loss on prompt)
-            # Completion tokens: weight=advantage (advantage-weighted SFT)
-            prompt_feature = prompts[i // NUM_GENERATIONS]
-            prompt_ids = prompt_feature['input_ids']
-            if hasattr(prompt_ids, 'tolist'):
-                prompt_ids = prompt_ids.tolist()
-
-            full_tokens = prompt_ids + list(seq.tokens)
-            prompt_weights = [0.0] * len(prompt_ids)
-            # Scale completion weights by normalized advantage
-            completion_weights = [float(advantages[i])] * len(seq.tokens)
-
-            # Shift by one for next-token prediction
-            input_tokens = full_tokens[:-1]
-            target_tokens = full_tokens[1:]
-            weights = (prompt_weights + completion_weights)[1:]
-
-            datum = types.Datum(
-                model_input=types.ModelInput.from_ints(input_tokens),
-                loss_fn_inputs={
-                    'target_tokens': target_tokens,
-                    'weights': weights,
-                },
-            )
-            training_data.append(datum)
+        # Train the policy with the Group Relative Policy Optimization
+        # (GRPO) loss function.
+        #
+        # The GRPO loss function requires:
+        # 1. logprobs: The log probabilities of the tokens under the sampling policy
+        # 2. advantages: The advantage values for each completion
+        #
+        # The training data is constructed with:
+        # - model_input: The full prompt + completion tokens
+        # - target_tokens: The shifted tokens for next-token prediction
+        # - logprobs: The log probabilities from the sampling step
+        # - advantages: The computed advantage values
+        training_data = []
+        for i, seq in enumerate(all_sequences):
+            # Build a Datum from the completion tokens with logprobs and advantages
+            prompt_feature = prompts[i // NUM_GENERATIONS]
+            prompt_ids = prompt_feature['input_ids']
+            if hasattr(prompt_ids, 'tolist'):
+                prompt_ids = prompt_ids.tolist()
+
+            full_tokens = prompt_ids + list(seq.tokens)
+
+            # Shift by one for next-token prediction
+            input_tokens = full_tokens[:-1]
+            target_tokens = full_tokens[1:]
+
+            # Get logprobs from the sampling result
+            logprobs = seq.logprobs if seq.logprobs else [0.0] * len(seq.tokens)
+            # Pad logprobs to match full sequence length (prompt + completion)
+            # Prompt positions get 0.0 logprobs (no loss computed on prompt)
+            padded_logprobs = [0.0] * len(prompt_ids) + logprobs
+
+            # Get advantage for this sequence
+            advantage = float(advantages[i])
+
+            # Pad advantages to match full sequence length
+            # Only completion tokens get the advantage value, prompt gets 0.0
+            padded_advantages = [0.0] * len(prompt_ids) + [advantage] * len(seq.tokens)
+
+            # Verify lengths match
+            assert len(input_tokens) == len(target_tokens) == len(padded_logprobs) == len(padded_advantages), \
+                f"Length mismatch: input={len(input_tokens)}, target={len(target_tokens)}, " \
+                f"logprobs={len(padded_logprobs)}, advantages={len(padded_advantages)}"
+
+            datum = types.Datum(
+                model_input=types.ModelInput.from_ints(input_tokens),
+                loss_fn_inputs={
+                    'target_tokens': target_tokens,
+                    'logprobs': types.TensorData.from_numpy(np.array(padded_logprobs, dtype=np.float32)),
+                    'advantages': types.TensorData.from_numpy(np.array(padded_advantages, dtype=np.float32)),
+                },
+            )
+            training_data.append(datum)

         if not training_data:
             logger.info(
-                f"Step {step}: No positive-advantage samples, skipping")
+                f"Step {step}: No training data constructed, skipping")
             step += 1
             continue

-        # Forward-backward pass with cross-entropy on advantage-weighted data
+        # Forward-backward pass with importance_sampling (GRPO) loss
+        # The training data already contains logprobs and advantages for the GRPO loss
         fwdbwd_future = training_client.forward_backward(
-            training_data, "cross_entropy")
+            training_data, "importance_sampling")
         optim_future = training_client.optim_step(
             types.AdamParams(learning_rate=LEARNING_RATE))
-
+
         fwdbwd_result = fwdbwd_future.result()
         optim_result = optim_future.result()

-        # Compute weighted average loss for monitoring
-        logprobs = np.concatenate(
-            [output['logprobs'].tolist()
-             for output in fwdbwd_result.loss_fn_outputs])
-        weights = np.concatenate(
-            [d.loss_fn_inputs['weights'].tolist() for d in training_data])
-        loss_per_token = -np.dot(logprobs, weights) / max(weights.sum(), 1e-8)
+        # Compute metrics from the forward-backward result
+        # For importance_sampling, we get logprobs and elementwise_loss
+        logprobs_list = []
+        elementwise_losses = []
+        for output in fwdbwd_result.loss_fn_outputs:
+            if output.get('logprobs') is not None:
+                logprobs_list.append(output['logprobs'].to_numpy())
+            if output.get('elementwise_loss') is not None:
+                elementwise_losses.append(output['elementwise_loss'].to_numpy())
+
+        # Compute average loss per token (weighted by advantages)
+        if elementwise_losses:
+            all_losses = np.concatenate(elementwise_losses)
+            avg_loss = np.mean(all_losses) if len(all_losses) > 0 else 0.0
+        else:
+            avg_loss = 0.0

         gc.collect()

         # ========== 7. Log ==========
         log_dict = metrics.calculate()
-        log_dict['train/loss_per_token'] = loss_per_token
+        log_dict['train/loss_per_token'] = float(avg_loss)
         log_dict['train/frac_reward_zero_std'] = frac_zero_std
         log_dict['train/num_training_samples'] = len(training_data)
         logger.info(f"Step {step}: {log_dict}")
Lines changed: 1 addition & 0 deletions

@@ -1,3 +1,4 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
 from .datum import datum_to_input_feature, input_feature_to_datum
+from .transformers_model import _extract_rl_fields_from_inputs as extract_rl_fields_from_inputs
 from twinkle.utils import exists, requires

src/twinkle/server/tinker/common/datum.py

Lines changed: 11 additions & 0 deletions

@@ -22,6 +22,17 @@ def datum_to_input_feature(datum: types.Datum) -> InputFeature:
     labels = datum.loss_fn_inputs['target_tokens'].to_numpy()

     input_feature['labels'] = np.where(weights > 0, labels, -100).tolist()
+
+    # 3. Handle importance_sampling specific fields
+    # 'logprobs' -> 'old_logps' (for GRPO loss)
+    if 'logprobs' in datum.loss_fn_inputs:
+        old_logps = datum.loss_fn_inputs['logprobs'].to_numpy().tolist()
+        input_feature['old_logps'] = old_logps
+
+    # 'advantages' -> 'advantages' (for GRPO loss)
+    if 'advantages' in datum.loss_fn_inputs:
+        advantages = datum.loss_fn_inputs['advantages'].to_numpy().tolist()
+        input_feature['advantages'] = advantages

     return input_feature

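To make the mapping concrete, here is a hypothetical round trip with plain dicts standing in for types.Datum and InputFeature; the field names and masking logic match the code above, but the toy values are invented and do not come from the commit.

import numpy as np

loss_fn_inputs = {                       # what the client sent (toy values)
    'target_tokens': np.array([5, 6, 7]),
    'weights': np.array([0.0, 1.0, 1.0]),
    'logprobs': np.array([0.0, -1.2, -0.3], dtype=np.float32),
    'advantages': np.array([0.0, 0.8, 0.8], dtype=np.float32),
}

input_feature = {}
labels = loss_fn_inputs['target_tokens']
weights = loss_fn_inputs['weights']
# Positions with weight 0 are masked with -100 (ignored by the loss)
input_feature['labels'] = np.where(weights > 0, labels, -100).tolist()
# RL-specific fields: 'logprobs' is renamed to 'old_logps' for the GRPO loss
if 'logprobs' in loss_fn_inputs:
    input_feature['old_logps'] = loss_fn_inputs['logprobs'].tolist()
if 'advantages' in loss_fn_inputs:
    input_feature['advantages'] = loss_fn_inputs['advantages'].tolist()

print(input_feature['labels'])           # [-100, 6, 7]
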

src/twinkle/server/tinker/common/megatron_model.py

Lines changed: 46 additions & 2 deletions

@@ -2,12 +2,53 @@
 import numpy as np
 import torch
 from tinker import types
-from typing import List, TYPE_CHECKING
+from typing import List, TYPE_CHECKING, Tuple, Optional, Any
 from twinkle import remote_class, remote_function
 from twinkle.utils import exists, requires
 from .datum import datum_to_input_feature
 from .io_utils import create_checkpoint_manager

+
+def _extract_rl_fields_from_inputs(
+        input_features: List[dict],
+        kwargs: dict
+) -> Tuple[Optional[List], Optional[List], dict]:
+    """Extract old_logps and advantages from input features and kwargs.
+
+    This function handles the common logic for extracting reinforcement learning
+    fields (old_logps and advantages) from both input features and kwargs.
+
+    Args:
+        input_features: List of input feature dictionaries
+        kwargs: Keyword arguments dictionary
+
+    Returns:
+        Tuple of (old_logps, advantages, updated_kwargs)
+    """
+    # Extract from kwargs first (higher priority)
+    old_logps = kwargs.pop('old_logps', None)
+    advantages = kwargs.pop('advantages', None)
+
+    # If not in kwargs, check input features
+    if old_logps is None:
+        old_logps_list = [inp.get('old_logps') for inp in input_features if inp.get('old_logps') is not None]
+        if old_logps_list:
+            old_logps = old_logps_list
+
+    if advantages is None:
+        advantages_list = [inp.get('advantages') for inp in input_features if inp.get('advantages') is not None]
+        if advantages_list:
+            advantages = advantages_list
+
+    # Prepare kwargs for loss function
+    loss_kwargs = kwargs.copy()
+    if old_logps is not None:
+        loss_kwargs['old_logps'] = old_logps
+    if advantages is not None:
+        loss_kwargs['advantages'] = advantages
+
+    return old_logps, advantages, loss_kwargs
+
 if TYPE_CHECKING:
     from twinkle.model.megatron import MultiLoraMegatronModel as _MegatronBase
 elif exists('megatron_core'):

@@ -83,9 +124,12 @@ def forward_backward(self, *, inputs: List[types.Datum], **kwargs):
         # Convert Datum to InputFeature
         input_features = [datum_to_input_feature(datum) for datum in inputs]

+        # Extract old_logps and advantages using common utility
+        old_logps, advantages, loss_kwargs = _extract_rl_fields_from_inputs(input_features, kwargs)
+
         adapter_name = kwargs.get('adapter_name')
         # Megatron forward_backward returns loss directly
-        loss = super().forward_backward(inputs=input_features, **kwargs)
+        loss = super().forward_backward(inputs=input_features, **loss_kwargs)

         # Get logits from outputs
         optimizer_config = self.optimizer_group.get(adapter_name)
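
A small usage sketch of the helper's precedence rules: explicit kwargs win, otherwise values are gathered per-feature from the inputs, and unrelated kwargs pass through into loss_kwargs. The import path is an assumption following the src/ layout above, and the feature dicts are invented.

# Assumed module path, mirroring src/twinkle/server/tinker/common/megatron_model.py
from twinkle.server.tinker.common.megatron_model import _extract_rl_fields_from_inputs

features = [
    {'labels': [5, 6], 'old_logps': [-0.7, -1.2], 'advantages': [0.8, 0.8]},
    {'labels': [7, 8], 'old_logps': [-0.2, -0.9], 'advantages': [-0.3, -0.3]},
]

# Case 1: nothing in kwargs -> values are gathered per-feature from inputs
old_logps, advantages, loss_kwargs = _extract_rl_fields_from_inputs(
    features, {'adapter_name': 'a1'})
assert old_logps == [[-0.7, -1.2], [-0.2, -0.9]]
assert loss_kwargs['adapter_name'] == 'a1'   # unrelated kwargs pass through

# Case 2: explicit kwargs take priority over the per-feature values
old_logps, _, _ = _extract_rl_fields_from_inputs(
    features, {'old_logps': [[0.0, 0.0]]})
assert old_logps == [[0.0, 0.0]]
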

src/twinkle/server/tinker/common/transformers_model.py

Lines changed: 78 additions & 2 deletions

@@ -1,11 +1,52 @@
 import torch
 from tinker import types
-from typing import List
+from typing import List, Tuple, Optional, Any
 from twinkle.model import MultiLoraTransformersModel
 from twinkle import remote_class, remote_function
 from .datum import datum_to_input_feature
 from .io_utils import create_checkpoint_manager

+
+def _extract_rl_fields_from_inputs(
+        input_features: List[dict],
+        kwargs: dict
+) -> Tuple[Optional[List], Optional[List], dict]:
+    """Extract old_logps and advantages from input features and kwargs.
+
+    This function handles the common logic for extracting reinforcement learning
+    fields (old_logps and advantages) from both input features and kwargs.
+
+    Args:
+        input_features: List of input feature dictionaries
+        kwargs: Keyword arguments dictionary
+
+    Returns:
+        Tuple of (old_logps, advantages, updated_kwargs)
+    """
+    # Extract from kwargs first (higher priority)
+    old_logps = kwargs.pop('old_logps', None)
+    advantages = kwargs.pop('advantages', None)
+
+    # If not in kwargs, check input features
+    if old_logps is None:
+        old_logps_list = [inp.get('old_logps') for inp in input_features if inp.get('old_logps') is not None]
+        if old_logps_list:
+            old_logps = old_logps_list
+
+    if advantages is None:
+        advantages_list = [inp.get('advantages') for inp in input_features if inp.get('advantages') is not None]
+        if advantages_list:
+            advantages = advantages_list
+
+    # Prepare kwargs for loss function
+    loss_kwargs = kwargs.copy()
+    if old_logps is not None:
+        loss_kwargs['old_logps'] = old_logps
+    if advantages is not None:
+        loss_kwargs['advantages'] = advantages
+
+    return old_logps, advantages, loss_kwargs
+
 @remote_class()
 class TwinkleCompatTransformersModel(MultiLoraTransformersModel):
     """

@@ -45,6 +86,11 @@ def forward(self, *, inputs: List[types.Datum], **kwargs):
         # Convert Datum to InputFeature
         input_features = [datum_to_input_feature(datum) for datum in inputs]

+        # Extract old_logps and advantages using common utility
+        old_logps, advantages, loss_kwargs = _extract_rl_fields_from_inputs(input_features, kwargs)
+        # Update kwargs for forward pass (exclude loss-specific fields)
+        kwargs.update({k: v for k, v in loss_kwargs.items() if k not in ['old_logps', 'advantages']})
+
         outputs = super().forward(inputs=input_features, **kwargs)
         logits = outputs['logits'].detach().cpu()  # shape (batch_size, seq_len, vocab_size)
         results = self._get_forward_output(inputs, logits)

@@ -54,6 +100,11 @@ def forward(self, *, inputs: List[types.Datum], **kwargs):
     def forward_only(self, *, inputs: List[types.Datum], **kwargs):
         # Convert Datum to InputFeature
         input_features = [datum_to_input_feature(datum) for datum in inputs]
+
+        # Extract old_logps and advantages using common utility
+        old_logps, advantages, loss_kwargs = _extract_rl_fields_from_inputs(input_features, kwargs)
+        # Update kwargs for forward pass (exclude loss-specific fields)
+        kwargs.update({k: v for k, v in loss_kwargs.items() if k not in ['old_logps', 'advantages']})

         outputs = super().forward_only(inputs=input_features, **kwargs)
         logits = outputs['logits'].detach().cpu()  # shape (batch_size, seq_len, vocab_size)

@@ -62,9 +113,34 @@ def forward_only(self, *, inputs: List[types.Datum], **kwargs):

     @remote_function(collect='mean')
     def calculate_loss(self, **kwargs):
-        loss = super().calculate_loss(**kwargs)
+        # Extract old_logps and advantages using common utility (for importance_sampling loss)
+        # Note: We don't need the input_features here since this is called separately
+        old_logps, advantages, loss_kwargs = _extract_rl_fields_from_inputs([], kwargs)
+
+        loss = super().calculate_loss(**loss_kwargs)
         return loss

+    @remote_function(dispatch='slice_dp', collect='flatten')
+    def forward_backward(self, *, inputs: List[types.Datum], **kwargs):
+        # Convert Datum to InputFeature
+        input_features = [datum_to_input_feature(datum) for datum in inputs]
+
+        # Extract old_logps and advantages using common utility
+        old_logps, advantages, loss_kwargs = _extract_rl_fields_from_inputs(input_features, kwargs)
+
+        # Forward pass
+        outputs = super().forward(inputs=input_features, **kwargs)
+
+        # Calculate loss with extra parameters
+        loss = super().calculate_loss(**loss_kwargs)
+
+        # Backward pass
+        super().backward(**kwargs)
+
+        logits = outputs['logits'].detach().cpu()  # shape (batch_size, seq_len, vocab_size)
+        results = self._get_forward_output(inputs, logits)
+        return results, loss
+
     @remote_function()
     def step(self, *, adam_params: types.AdamParams, **kwargs):
         # Gradient clipping
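
The kwargs handling in forward and forward_only splits the RL fields from pass-through arguments: loss-specific fields stay in loss_kwargs for calculate_loss, everything else flows back into the kwargs used for the forward pass. A toy illustration with invented values:

kwargs = {'adapter_name': 'a1', 'old_logps': [[-0.7]], 'advantages': [[0.8]]}

# Same pop-then-copy pattern as _extract_rl_fields_from_inputs above
old_logps = kwargs.pop('old_logps', None)
advantages = kwargs.pop('advantages', None)
loss_kwargs = dict(kwargs, old_logps=old_logps, advantages=advantages)

# The forward pass must not receive the RL fields:
kwargs.update({k: v for k, v in loss_kwargs.items()
               if k not in ['old_logps', 'advantages']})
assert kwargs == {'adapter_name': 'a1'}
assert set(loss_kwargs) == {'adapter_name', 'old_logps', 'advantages'}
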

src/twinkle/server/tinker/model.py

Lines changed: 21 additions & 3 deletions

@@ -362,17 +362,34 @@ async def _do_forward_backward():

         if self.use_megatron:
             # Megatron uses combined forward_backward, no separate backward/calculate_loss
+            # Set loss first based on loss_fn
+            if loss_fn == 'cross_entropy':
+                self.model.set_loss('CrossEntropyLoss',
+                                    adapter_name=adapter_name)
+            elif loss_fn == 'importance_sampling':
+                self.model.set_loss('GRPOLoss',
+                                    adapter_name=adapter_name,
+                                    epsilon=0.2,  # Default GRPO epsilon
+                                    beta=0.0)  # No KL penalty by default
+            else:
+                raise ValueError(
+                    f'Unsupported loss function {loss_fn}')
+
             output, loss = self.model.forward_backward(
                 inputs=datum_list,
                 adapter_name=adapter_name,
                 **loss_fn_config)
         else:
             # Transformers uses separate forward, calculate_loss, backward
-            # When use_megatron is True, we don't need to set the loss
-            # Set loss first
+            # Set loss first based on loss_fn
             if loss_fn == 'cross_entropy':
                 self.model.set_loss('CrossEntropyLoss',
                                     adapter_name=adapter_name)
+            elif loss_fn == 'importance_sampling':
+                self.model.set_loss('GRPOLoss',
+                                    adapter_name=adapter_name,
+                                    epsilon=0.2,  # Default GRPO epsilon
+                                    beta=0.0)  # No KL penalty by default
             else:
                 raise ValueError(
                     f'Unsupported loss function {loss_fn}')

@@ -382,8 +399,9 @@ async def _do_forward_backward():
             loss = self.model.calculate_loss(adapter_name=adapter_name,
                                              **loss_fn_config)
             self.model.backward(adapter_name=adapter_name)
+            output_type = 'ImportanceSamplingLossReturn' if loss_fn == 'importance_sampling' else 'CrossEntropyLossReturn'
             return types.ForwardBackwardOutput(
-                loss_fn_output_type='CrossEntropyLossReturn',
+                loss_fn_output_type=output_type,
                 loss_fn_outputs=output,
                 metrics={'loss:avg': loss},
             )
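
The commit wires 'importance_sampling' to a GRPOLoss with epsilon=0.2 and beta=0.0 but does not show that class's internals. As a rough orientation only, here is a clipped importance-sampling objective of the usual GRPO/PPO surrogate form, under the assumption that GRPOLoss with beta=0.0 (no KL term) follows it; the function name and values are hypothetical.

import numpy as np

def grpo_token_loss(new_logps, old_logps, advantages, epsilon=0.2):
    # Ratio of current-policy to sampling-policy token probabilities
    ratio = np.exp(new_logps - old_logps)
    clipped = np.clip(ratio, 1.0 - epsilon, 1.0 + epsilon)
    # Pessimistic (min) of unclipped and clipped surrogate, negated as a loss
    return -np.minimum(ratio * advantages, clipped * advantages)

new_lp = np.array([-0.5, -1.0, -0.2])   # logprobs under the current policy
old_lp = np.array([-0.7, -1.2, -0.3])   # logprobs recorded at sampling time
adv = np.array([0.8, 0.8, -0.3])        # per-token advantages
print(grpo_token_loss(new_lp, old_lp, adv))  # approximately [-0.96 -0.96  0.33]
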
