Skip to content

Commit 76a20a3

Browse files
Support batched preprocessing (#101)
1 parent 14346a7 commit 76a20a3

File tree

9 files changed

+110
-22
lines changed

9 files changed

+110
-22
lines changed

cookbook/client/tinker/custom_service/short_math_grpo.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,13 @@
6060

6161
class MathPreprocessor(Preprocessor):
6262

63-
def __call__(self, sample):
63+
def __call__(self, rows):
64+
rows = self.map_col_to_row(rows)
65+
rows = [self.preprocess(row) for row in rows]
66+
rows = self.map_row_to_col(rows)
67+
return rows
68+
69+
def preprocess(self, sample):
6470
if sample['level'] not in ('Level 4', 'Level 5'):
6571
return Trajectory(messages=[], user_data=[])
6672

cookbook/client/tinker/modelscope_service/short_math_grpo.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,13 @@
6060

6161
class MathPreprocessor(Preprocessor):
6262

63-
def __call__(self, sample):
63+
def __call__(self, rows):
64+
rows = self.map_col_to_row(rows)
65+
rows = [self.preprocess(row) for row in rows]
66+
rows = self.map_row_to_col(rows)
67+
return rows
68+
69+
def preprocess(self, sample):
6470
if sample['level'] not in ('Level 4', 'Level 5'):
6571
return Trajectory(messages=[], user_data=[])
6672

cookbook/mm/fsdp2.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,13 @@
1919

2020
class LatexOCRProcessor(Preprocessor):
2121

22-
def __call__(self, row) -> Trajectory:
22+
def __call__(self, rows):
23+
rows = self.map_col_to_row(rows)
24+
rows = [self.preprocess(row) for row in rows]
25+
rows = self.map_row_to_col(rows)
26+
return rows
27+
28+
def preprocess(self, row) -> Trajectory:
2329
return Trajectory(
2430
messages=[
2531
Message(role='user', content='<image>Using LaTeX to perform OCR on the image.', images=[row['image']]),

docs/source_en/Components/Preprocessor and Filter/Preprocessor.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ The base class of Preprocessor:
77
```python
88
class Preprocessor:
99

10-
def __call__(self, row) -> Trajectory:
10+
def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
1111
...
1212
```
1313

14-
The format is to pass in a raw sample and output a `Trajectory`. If the sample cannot be used, you can directly return None.
14+
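A preprocessor receives a batch of raw samples in column format (a dict mapping each column name to a list of values) and returns the processed batch in the same column format, where each row corresponds to one `Trajectory`. If a sample cannot be used, you can simply leave it out of the output; the number of output rows does not have to match the number of input rows.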
1515

1616
We provide some basic Preprocessors, such as `SelfCognitionProcessor`:
1717

@@ -22,7 +22,7 @@ dataset.map('SelfCognitionProcessor', model_name='some-model', model_author='som
2222
Preprocessor contains the __call__ method, which means you can use a function to replace the class:
2323

2424
```python
25-
def self_cognition_preprocessor(row):
25+
def self_cognition_preprocessor(rows):
2626
...
27-
return Trajectory(...)
27+
return [Trajectory(...), ...]
2828
```

docs/source_zh/组件/预处理器和过滤器/Preprocessor.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ Preprocessor 的基类:
77
```python
88
class Preprocessor:
99

10-
def __call__(self, row) -> Trajectory:
10+
def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
1111
...
1212
```
1313

14-
格式为传入一个原始样本,输出一个`Trajectory`如果样本无法使用,可以直接返回None
14+
格式为传入一批按列组织的原始样本(即由列名映射到取值列表的字典),输出同样按列组织的处理结果,其中每一行对应一个 `Trajectory`。如果某个样本无法使用,可以直接忽略它。输入条数和输出条数不必相同。
1515

1616
我们提供了一些基本的 Preprocessor,例如 `SelfCognitionProcessor`
1717

@@ -22,7 +22,7 @@ dataset.map('SelfCognitionProcessor', model_name='some-model', model_author='som
2222
Preprocessor 包含 __call__ 方法,这意味着你可以使用 function 来代替类:
2323

2424
```python
25-
def self_cognition_preprocessor(row):
25+
def self_cognition_preprocessor(rows):
2626
...
27-
return Trajectory(...)
27+
return [Trajectory(...), ...]
2828
```

src/twinkle/dataset/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def map(self,
183183
key = next(iter(self.datasets.keys()))
184184
else:
185185
key = dataset_meta.get_id()
186-
kwargs['batched'] = False # TODO temporary change to False, because the interface does not support batched
186+
kwargs['batched'] = True
187187
with processing_lock(key):
188188
self.datasets[key] = self.datasets[key].map(preprocess_func, **kwargs)
189189
if len(self.datasets) == 1:
@@ -210,7 +210,7 @@ def filter(self,
210210
key = next(iter(self.datasets.keys()))
211211
else:
212212
key = dataset_meta.get_id()
213-
kwargs['batched'] = False # TODO temporary change to False, because the interface does not support batched
213+
kwargs['batched'] = False
214214
with processing_lock(key):
215215
self.datasets[key] = self.datasets[key].filter(filter_func, **kwargs)
216216
if len(self.datasets) == 1:

src/twinkle/preprocessor/base.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,38 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
2+
from typing import Any, Dict, List
23

34
from twinkle.data_format import Trajectory
45

56

67
class Preprocessor:
    """Base class for dataset preprocessors that operate on batches.

    ``__call__`` receives a column-oriented batch (a dict mapping each column
    name to a list of values) and returns a batch in the same layout. The two
    static helpers convert between that layout and a list of per-sample dicts,
    so a subclass can convert to rows, transform each row, and convert back.
    """

    @staticmethod
    def map_col_to_row(rows: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
        """Convert a column-oriented batch into a list of per-sample dicts.

        Args:
            rows: Mapping from column name to the list of values in that column.

        Returns:
            One dict per sample, keyed by column name. An empty batch yields
            an empty list.
        """
        if not rows:
            return []
        # The first column determines how many samples the batch holds.
        total_count = len(rows[next(iter(rows))])
        return [{key: rows[key][i] for key in rows} for i in range(total_count)]

    @staticmethod
    def map_row_to_col(rows: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
        """Convert a list of per-sample dicts into a column-oriented batch.

        Rows that are ``None`` are dropped. Column names are the union of the
        keys of all remaining rows, in first-seen order; a row lacking a key
        contributes ``None`` for that column so every column has equal length.

        Args:
            rows: Per-sample dicts; entries may be ``None``.

        Returns:
            Mapping from column name to the list of values in that column.
            An empty dict when no usable rows remain.
        """
        rows = [row for row in rows if row is not None]
        if not rows:
            return {}

        # Take the keys from every row, not only the first one: rows may carry
        # different optional fields, and reading rows[0] alone would either
        # raise KeyError or silently drop columns.
        keys = dict.fromkeys(key for row in rows for key in row.keys())
        return {key: [row[key] if key in row else None for row in rows] for key in keys}

    def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        """Process one column-oriented batch; intended to be overridden.

        The base implementation is a placeholder and returns ``None``.
        """
        ...
1037

1138

src/twinkle/preprocessor/llm.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
22
import re
3+
from typing import Any, Dict, List
34

45
from twinkle.data_format import Message, Trajectory
56
from .base import Preprocessor
67

78

89
class CompetitionMathProcessor(Preprocessor):
910

10-
def __call__(self, row) -> Trajectory:
11+
def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
12+
rows = self.map_col_to_row(rows)
13+
rows = [self.preprocess(row) for row in rows]
14+
rows = self.map_row_to_col(rows)
15+
return rows
16+
17+
def preprocess(self, row) -> Dict[str, Any]:
1118
problem = row['problem']
1219
solution = row['solution']
1320
messages = [
@@ -19,7 +26,13 @@ def __call__(self, row) -> Trajectory:
1926

2027
class CompetitionMathGRPOProcessor(Preprocessor):
2128

22-
def __call__(self, row) -> Trajectory:
29+
def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
30+
rows = self.map_col_to_row(rows)
31+
rows = [self.preprocess(row) for row in rows]
32+
rows = self.map_row_to_col(rows)
33+
return rows
34+
35+
def preprocess(self, row) -> Trajectory:
2336
problem = row['problem']
2437
solution = row['solution']
2538
messages = [
@@ -39,7 +52,13 @@ def __init__(self, model_name, model_author):
3952
self.model_name = model_name
4053
self.model_author = model_author
4154

42-
def __call__(self, row) -> Trajectory:
55+
def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
56+
rows = self.map_col_to_row(rows)
57+
rows = [self.preprocess(row) for row in rows]
58+
rows = self.map_row_to_col(rows)
59+
return rows
60+
61+
def preprocess(self, row) -> Trajectory:
4362
problem = row['query'].replace('{{NAME}}', self.model_name).replace('{{AUTHOR}}', self.model_author)
4463
solution = row['response'].replace('{{NAME}}', self.model_name).replace('{{AUTHOR}}', self.model_author)
4564
messages = [
@@ -52,7 +71,13 @@ def __call__(self, row) -> Trajectory:
5271

5372
class AlpacaProcessor(Preprocessor):
5473

55-
def __call__(self, row) -> Trajectory:
74+
def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
75+
rows = self.map_col_to_row(rows)
76+
rows = [self.preprocess(row) for row in rows]
77+
rows = self.map_row_to_col(rows)
78+
return rows
79+
80+
def preprocess(self, row) -> Trajectory:
5681
instruction = row.get('instruction') or ''
5782
input_text = row.get('input') or ''
5883
output_text = row.get('output') or ''
@@ -68,7 +93,13 @@ class CountdownProcessor(Preprocessor):
6893
system_prompt = ('You are a helpful assistant. You first thinks about the reasoning process '
6994
'in the mind and then provides the user with the answer.')
7095

71-
def __call__(self, row) -> Trajectory:
96+
def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
97+
rows = self.map_col_to_row(rows)
98+
rows = [self.preprocess(row) for row in rows]
99+
rows = self.map_row_to_col(rows)
100+
return rows
101+
102+
def preprocess(self, row) -> Trajectory:
72103
nums = row.get('nums', [])
73104
target = row.get('response', row.get('target', 0))
74105

@@ -103,7 +134,13 @@ def extract_ground_truth(self, answer_str: str) -> str:
103134
return match.group(1).replace(',', '').strip()
104135
return ''
105136

106-
def __call__(self, row) -> Trajectory:
137+
def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
138+
rows = self.map_col_to_row(rows)
139+
rows = [self.preprocess(row) for row in rows]
140+
rows = self.map_row_to_col(rows)
141+
return rows
142+
143+
def preprocess(self, row) -> Trajectory:
107144
question = row['question']
108145
answer = row.get('answer', '')
109146
ground_truth = self.extract_ground_truth(answer)

tests/preprocessor/test_preprocessor.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,8 +293,14 @@ def test_load_from_cache_file_false(self):
293293
# Modify processor, process again
294294
class ModifiedProcessor(CompetitionMathProcessor):
295295

296-
def __call__(self, row):
297-
traj = super().__call__(row)
296+
def __call__(self, rows):
297+
rows = self.map_col_to_row(rows)
298+
rows = [self.preprocess(row) for row in rows]
299+
rows = self.map_row_to_col(rows)
300+
return rows
301+
302+
def preprocess(self, row):
303+
traj = super().preprocess(row)
298304
traj['messages'][0]['content'] = 'Modified: ' + traj['messages'][0]['content']
299305
return traj
300306

0 commit comments

Comments
 (0)