Skip to content

Commit 1d8094d

Browse files
committed
Merge remote-tracking branch 'origin/dev' into kernels_unittest_fix_ljl
2 parents 2d4a19f + 9b875d7 commit 1d8094d

File tree

15 files changed

+64
-149
lines changed

15 files changed

+64
-149
lines changed

cookbook/client/twinkle/grpo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajector
107107

108108
def compute_rewards(trajectories: List[dict], ) -> Tuple[List[float], List[float], List[float]]:
109109
"""Compute format and accuracy rewards for Countdown game."""
110-
from twinkle.reward import CountDownAccuracy, FormatReward
110+
from twinkle.reward import FormatReward
111111
format_rewards = FormatReward()(trajectories, [])
112112
accuracy_rewards = CountDownAccuracy()(trajectories, [])
113113
total_rewards = [a + b for a, b in zip(accuracy_rewards, format_rewards)]

cookbook/rl/grpo.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,6 @@ def main():
142142
all_input_data
143143
)
144144
metrics.accumulate(
145-
None,
146-
None,
147145
completion_lengths=all_completion_lengths,
148146
rewards={
149147
'total': total_rewards,
@@ -159,7 +157,7 @@ def main():
159157
)
160158
advantages = advantages.tolist()
161159

162-
model.forward_backward(inputs=all_input_data, old_logps=all_old_logps, advantages=advantages)
160+
model.forward_backward(inputs=all_input_data, old_logps=all_old_logps, advantages=advantages, micro_batch_size=2)
163161
model.clip_grad_and_step()
164162
optim_step += 1
165163
log_dict = metrics.calculate()

docs/source_en/Components/Reward/Reward.md

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,6 @@ reward_fn = FormatReward()
4444
rewards = reward_fn(trajectories, ground_truths)
4545
```
4646

47-
## CountDownAccuracyReward
48-
49-
The countdown accuracy reward function provides partial rewards when answers are close to correct.
50-
51-
```python
52-
from twinkle.reward import CountDownAccuracyReward
53-
54-
reward_fn = CountDownAccuracyReward()
55-
rewards = reward_fn(trajectories, ground_truths)
56-
```
57-
5847
## Custom Reward Functions
5948

6049
You can create custom rewards by inheriting from the Reward base class or using functions:

docs/source_en/Usage Guide/Train-as-a-Service.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ API endpoint: `base_url="https://www.modelscope.cn/twinkle"`
1616

1717
## Step 2. Review the Cookbook and Customize Development
1818

19-
We strongly recommend that developers review our [cookbook](https://github.com/modelscope/twinkle/tree/main/cookbook/client/) and build upon the training code provided there.
19+
We strongly recommend that developers review our [cookbook](https://github.com/modelscope/twinkle/tree/main/cookbook/client/tinker) and build upon the training code provided there.
20+
21+
> The ModelScope server is tinker-compatible, so use the tinker cookbooks. In a future version, we will support a server that works with both twinkle and tinker clients.
2022
2123
Developers can customize datasets, advantage functions, rewards, templates, and more. However, the Loss component is not currently customizable since it needs to be executed on the server side (for security reasons). If you need support for additional Loss functions, you can upload your Loss implementation to ModelHub and contact us via the Q&A group or through an issue to have the corresponding component added to the whitelist.
2224

docs/source_zh/使用指引/训练服务.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919

2020
## Step 2. 查看 Cookbook 并二次定制开发
2121

22-
我们强烈推荐开发者查看我们的 [cookbook](https://github.com/modelscope/twinkle/tree/main/cookbook/client/),并根据其中的训练代码进行二次开发。
22+
我们强烈推荐开发者查看我们的 [cookbook](https://github.com/modelscope/twinkle/tree/main/cookbook/client/tinker),并根据其中的训练代码进行二次开发。
23+
24+
> 目前的服务兼容tinker client,因此请使用tinker的cookbook进行训练。后续我们会支持单服务器支持twinkle/tinker双client。
2325
2426
开发者可以定制数据集/优势函数/奖励/模板等,其中 Loss 部分由于需要在服务端执行,因此当前暂不支持(安全性原因)。
2527
如果需要支持您的额外 Loss,可以将该 Loss 实现上传到 ModelHub 中,并在答疑群中或者 issue 中联系我们,将对应组件开放白名单即可使用。

docs/source_zh/组件/奖励/Reward.md

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,6 @@ reward_fn = FormatReward()
4444
rewards = reward_fn(trajectories, ground_truths)
4545
```
4646

47-
## CountDownAccuracyReward
48-
49-
倒计时准确率奖励函数,在答案接近正确时给予部分奖励。
50-
51-
```python
52-
from twinkle.reward import CountDownAccuracyReward
53-
54-
reward_fn = CountDownAccuracyReward()
55-
rewards = reward_fn(trajectories, ground_truths)
56-
```
57-
5847
## 自定义奖励函数
5948

6049
你可以通过继承 Reward 基类或使用函数来创建自定义奖励:

src/twinkle/loss/grpo.py

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -34,20 +34,6 @@ def __init__(
3434
self.beta = beta
3535
self.ignore_index = ignore_index
3636

37-
def _extract_advantages_from_trajectories(self, trajectories: List[Trajectory],
38-
device: 'torch.device') -> 'torch.Tensor':
39-
"""Extract advantages from trajectory objects."""
40-
import torch
41-
advantages_list = []
42-
for traj in trajectories:
43-
if isinstance(traj, dict):
44-
adv = traj.get('advantages', None)
45-
else:
46-
adv = getattr(traj, 'advantages', None)
47-
assert adv is not None, "trajectories must contain 'advantages'"
48-
advantages_list.append(float(adv))
49-
return torch.tensor(advantages_list, dtype=torch.float32, device=device)
50-
5137
def _compute_loss_mask(self, labels: 'torch.Tensor') -> 'torch.Tensor':
5238
"""
5339
Compute loss mask from labels.
@@ -275,7 +261,6 @@ def __call__(
275261
*,
276262
old_logps: Optional[Union['torch.Tensor', List[List[float]]]] = None,
277263
ref_logps: Optional['torch.Tensor'] = None,
278-
trajectories: Optional[List[Trajectory]] = None, # TODO: remove this argument
279264
advantages: Optional[Union['torch.Tensor', List[float], np.ndarray]] = None,
280265
**kwargs,
281266
) -> 'torch.Tensor':
@@ -326,13 +311,8 @@ def __call__(
326311
# In padding_free / packing mode the processor concatenates all
327312
# sequences into a single row [1, total_tokens]. We detect this
328313
# by checking: batch_size == 1 but the actual number of sequences
329-
# (known from trajectories or advantages) is greater than 1.
330-
if trajectories is not None:
331-
num_sequences = len(trajectories)
332-
elif advantages is not None:
333-
num_sequences = len(advantages) if isinstance(advantages, (list, tuple)) else advantages.shape[0]
334-
else:
335-
num_sequences = logps.shape[0]
314+
# is greater than 1.
315+
num_sequences = len(advantages) if isinstance(advantages, (list, tuple)) else advantages.shape[0]
336316
is_packed = (logps.shape[0] == 1 and num_sequences > 1)
337317
if is_packed:
338318
position_ids = inputs.get('position_ids')

src/twinkle/metric/completion_and_reward.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ def reset(self):
2121
self.completion_lengths = []
2222

2323
def accumulate(self,
24-
_,
25-
__,
24+
inputs=None, # ignore
25+
outputs=None,# ignore
2626
*,
2727
rewards=None,
2828
completion_lengths=None,
@@ -55,11 +55,11 @@ def _std(statistic_list: List[float]) -> float:
5555
return 0.0
5656

5757
def calculate(self) -> Dict[str, Any]:
58-
metric_dict = {
59-
'profiling/Time taken: move_model_to_sampler': self._mean(self.weight_sync_time),
60-
'profiling/Time taken: generate': self._mean(self.generate_time),
61-
}
62-
58+
metric_dict = {}
59+
if self.weight_sync_time is not None:
60+
metric_dict['profiling/Time taken: move_model_to_sampler'] = self._mean(self.weight_sync_time)
61+
if self.generate_time is not None:
62+
metric_dict['profiling/Time taken: generate'] = self._mean(self.generate_time)
6363
for key, values in self.rewards.items():
6464
metric_dict[f'train/{key}_reward'] = self._mean(values)
6565
metric_dict[f'train/{key}_reward_std'] = self._std(values)

src/twinkle/model/megatron/megatron.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -393,11 +393,35 @@ def forward_backward(self,
393393
else:
394394
seq_length = original_seq_length
395395

396-
loss_extra_kwargs = kwargs
396+
num_microbatches = len(inputs)
397+
loss_extra_kwargs_per_mb = []
398+
if num_microbatches <= 1:
399+
loss_extra_kwargs_per_mb = [kwargs]
400+
else:
401+
for mb_idx in range(num_microbatches):
402+
mb_start = mb_idx * micro_batch_size
403+
mb_end = mb_start + micro_batch_size
404+
mb_kwargs = {}
405+
for key, value in kwargs.items():
406+
if isinstance(value, torch.Tensor) and value.dim() >= 1 and value.shape[0] > micro_batch_size:
407+
mb_kwargs[key] = value[mb_start:mb_end]
408+
elif isinstance(value, np.ndarray) and value.ndim >= 1 and value.shape[0] > micro_batch_size:
409+
mb_kwargs[key] = value[mb_start:mb_end]
410+
elif isinstance(value, (list, tuple)) and len(value) > micro_batch_size:
411+
mb_kwargs[key] = value[mb_start:mb_end]
412+
else:
413+
# Scalars, small tensors, or non-sliceable values pass through as-is
414+
mb_kwargs[key] = value
415+
loss_extra_kwargs_per_mb.append(mb_kwargs)
416+
417+
_mb_counter = [0] # mutable counter for closure
397418

398419
def post_loss_function(output_tensor, inputs):
420+
mb_idx = _mb_counter[0]
421+
_mb_counter[0] += 1
422+
current_kwargs = loss_extra_kwargs_per_mb[mb_idx % len(loss_extra_kwargs_per_mb)]
399423
outputs = ModelOutput(logits=output_tensor)
400-
result = loss_instance(inputs, outputs, **loss_extra_kwargs)
424+
result = loss_instance(inputs, outputs, **current_kwargs)
401425
if isinstance(result, tuple):
402426
losses, counts = result
403427
else:
@@ -789,7 +813,7 @@ def clip_grad_and_step(self, max_grad_norm: float = 1.0, norm_type=2, **kwargs):
789813
self.zero_grad(**kwargs)
790814
self.lr_step(**kwargs)
791815

792-
@remote_function(dispatch='all', sync=True)
816+
@remote_function(dispatch='all', collect='first', sync=True)
793817
def save(self,
794818
name: Optional[str] = None,
795819
output_dir: Optional[str] = None,

src/twinkle/preprocessor/llm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ class GSM8KProcessor(Preprocessor):
9696
'numerical answer after ####.\n'
9797
'For example:\n<think> ... reasoning ... </think>\n#### 42')
9898

99-
def extract_ground_truth(answer_str: str) -> str:
99+
def extract_ground_truth(self, answer_str: str) -> str:
100100
"""Extract the number after '####' from GSM8K answer."""
101101
match = re.search(r'####\s*([\-\d,\.]+)', answer_str)
102102
if match:

0 commit comments

Comments
 (0)