Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion cookbook/client/tinker/custom_service/short_math_grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,13 @@

class MathPreprocessor(Preprocessor):

def __call__(self, sample):
def __call__(self, rows):
    """Apply ``preprocess`` to every sample of a batch.

    The batch is converted with ``map_col_to_row``, each resulting sample is
    passed to ``preprocess``, and the results are converted back with
    ``map_row_to_col``.
    """
    samples = self.map_col_to_row(rows)
    processed = [self.preprocess(sample) for sample in samples]
    return self.map_row_to_col(processed)

def preprocess(self, sample):
if sample['level'] not in ('Level 4', 'Level 5'):
return Trajectory(messages=[], user_data=[])

Expand Down
8 changes: 7 additions & 1 deletion cookbook/client/tinker/modelscope_service/short_math_grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,13 @@

class MathPreprocessor(Preprocessor):

def __call__(self, sample):
def __call__(self, rows):
    """Apply ``preprocess`` to every sample of a batch.

    The batch is converted with ``map_col_to_row``, each resulting sample is
    passed to ``preprocess``, and the results are converted back with
    ``map_row_to_col``.
    """
    samples = self.map_col_to_row(rows)
    processed = [self.preprocess(sample) for sample in samples]
    return self.map_row_to_col(processed)

def preprocess(self, sample):
if sample['level'] not in ('Level 4', 'Level 5'):
return Trajectory(messages=[], user_data=[])

Expand Down
8 changes: 7 additions & 1 deletion cookbook/mm/fsdp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@

class LatexOCRProcessor(Preprocessor):

def __call__(self, row) -> Trajectory:
def __call__(self, rows):
    """Apply ``preprocess`` to every sample of a batch.

    The batch is converted with ``map_col_to_row``, each resulting sample is
    passed to ``preprocess``, and the results are converted back with
    ``map_row_to_col``.
    """
    samples = self.map_col_to_row(rows)
    processed = [self.preprocess(sample) for sample in samples]
    return self.map_row_to_col(processed)

def preprocess(self, row) -> Trajectory:
return Trajectory(
messages=[
Message(role='user', content='<image>Using LaTeX to perform OCR on the image.', images=[row['image']]),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ The base class of Preprocessor:
```python
class Preprocessor:

def __call__(self, row) -> Trajectory:
def __call__(self, rows: List[Dict]) -> List[Trajectory]:
...
```

The format is to pass in a raw sample and output a `Trajectory`. If the sample cannot be used, you can directly return None.
The format is to pass in a list of samples and output a list of `Trajectory` objects. If a sample cannot be used, you can simply skip it; the number of outputs does not have to match the number of inputs.

We provide some basic Preprocessors, such as `SelfCognitionProcessor`:

Expand All @@ -22,7 +22,7 @@ dataset.map('SelfCognitionProcessor', model_name='some-model', model_author='som
Preprocessor contains the `__call__` method, which means you can use a function to replace the class:

```python
def self_cognition_preprocessor(row):
def self_cognition_preprocessor(rows):
...
return Trajectory(...)
return [Trajectory(...), ...]
```
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ Preprocessor 的基类:
```python
class Preprocessor:

def __call__(self, row) -> Trajectory:
def __call__(self, rows: List[Dict]) -> List[Trajectory]:
...
```

格式为传入一个原始样本,输出一个`Trajectory`。如果样本无法使用,可以直接返回None
格式为传入一系列原始样本,输出对应的`Trajectory`。如果某个样本无法使用,可以直接忽略它。输入条数和输出条数不必相同。

我们提供了一些基本的 Preprocessor,例如 `SelfCognitionProcessor`:

Expand All @@ -22,7 +22,7 @@ dataset.map('SelfCognitionProcessor', model_name='some-model', model_author='som
Preprocessor 包含 `__call__` 方法,这意味着你可以使用 function 来代替类:

```python
def self_cognition_preprocessor(row):
def self_cognition_preprocessor(rows):
...
return Trajectory(...)
return [Trajectory(...), ...]
```
4 changes: 2 additions & 2 deletions src/twinkle/dataset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def map(self,
key = next(iter(self.datasets.keys()))
else:
key = dataset_meta.get_id()
kwargs['batched'] = False # TODO temporary change to False, because the interface does not support batched
kwargs['batched'] = True
with processing_lock(key):
self.datasets[key] = self.datasets[key].map(preprocess_func, **kwargs)
if len(self.datasets) == 1:
Expand All @@ -210,7 +210,7 @@ def filter(self,
key = next(iter(self.datasets.keys()))
else:
key = dataset_meta.get_id()
kwargs['batched'] = False # TODO temporary change to False, because the interface does not support batched
kwargs['batched'] = False
with processing_lock(key):
self.datasets[key] = self.datasets[key].filter(filter_func, **kwargs)
if len(self.datasets) == 1:
Expand Down
29 changes: 28 additions & 1 deletion src/twinkle/preprocessor/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,38 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from typing import Any, Dict, List

from twinkle.data_format import Trajectory


class Preprocessor:

def __call__(self, row) -> Trajectory:
@staticmethod
def map_col_to_row(rows: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Convert a column-oriented batch into a list of per-sample dicts.

    Args:
        rows: Mapping from column name to the list of values in that column.

    Returns:
        One dict per sample, each holding every column's value at that
        position. A falsy (e.g. empty) batch yields an empty list.

    Raises:
        IndexError: If a list column is shorter than the first column.
    """
    if not rows:
        return []
    # The first column defines the batch size; every column is indexed up to
    # that length, so longer columns are silently truncated to it.
    total_count = len(rows[next(iter(rows))])
    return [{key: rows[key][i] for key in rows} for i in range(total_count)]

@staticmethod
def map_row_to_col(rows: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Convert a list of per-sample dicts into a column-oriented batch.

    Entries that are ``None`` are dropped before conversion, so a per-sample
    step may return ``None`` to skip a sample instead of crashing the batch.

    Args:
        rows: Per-sample dicts; ``None`` entries are ignored.

    Returns:
        Mapping from each key of the first kept sample to the list of that
        key's values across all kept samples. Returns an empty dict when
        there is nothing to convert.

    Raises:
        KeyError: If a later sample lacks a key present in the first one.
    """
    # Filter out skipped samples first; `rows or ()` keeps the original
    # tolerance for a falsy argument.
    kept = [row for row in rows or () if row is not None]
    if not kept:
        return {}
    # The first kept sample defines the output columns.
    return {key: [row[key] for row in kept] for key in kept[0].keys()}

def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    """Placeholder entry point for processing one batch.

    The base implementation performs no work and returns ``None``; concrete
    preprocessors are expected to override it.
    """


Expand Down
49 changes: 43 additions & 6 deletions src/twinkle/preprocessor/llm.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import re
from typing import Any, Dict, List

from twinkle.data_format import Message, Trajectory
from .base import Preprocessor


class CompetitionMathProcessor(Preprocessor):

def __call__(self, row) -> Trajectory:
def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    """Apply ``preprocess`` to every sample of a column-oriented batch.

    The batch is split into per-sample dicts, each one is passed to
    ``preprocess``, and the results are merged back into column form.
    """
    samples = self.map_col_to_row(rows)
    processed = [self.preprocess(sample) for sample in samples]
    return self.map_row_to_col(processed)

def preprocess(self, row) -> Dict[str, Any]:
problem = row['problem']
solution = row['solution']
messages = [
Expand All @@ -19,7 +26,13 @@ def __call__(self, row) -> Trajectory:

class CompetitionMathGRPOProcessor(Preprocessor):

def __call__(self, row) -> Trajectory:
def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    """Apply ``preprocess`` to every sample of a column-oriented batch.

    The batch is split into per-sample dicts, each one is passed to
    ``preprocess``, and the results are merged back into column form.
    """
    samples = self.map_col_to_row(rows)
    processed = [self.preprocess(sample) for sample in samples]
    return self.map_row_to_col(processed)

def preprocess(self, row) -> Trajectory:
problem = row['problem']
solution = row['solution']
messages = [
Expand All @@ -39,7 +52,13 @@ def __init__(self, model_name, model_author):
self.model_name = model_name
self.model_author = model_author

def __call__(self, row) -> Trajectory:
def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    """Apply ``preprocess`` to every sample of a column-oriented batch.

    The batch is split into per-sample dicts, each one is passed to
    ``preprocess``, and the results are merged back into column form.
    """
    samples = self.map_col_to_row(rows)
    processed = [self.preprocess(sample) for sample in samples]
    return self.map_row_to_col(processed)

def preprocess(self, row) -> Trajectory:
problem = row['query'].replace('{{NAME}}', self.model_name).replace('{{AUTHOR}}', self.model_author)
solution = row['response'].replace('{{NAME}}', self.model_name).replace('{{AUTHOR}}', self.model_author)
messages = [
Expand All @@ -52,7 +71,13 @@ def __call__(self, row) -> Trajectory:

class AlpacaProcessor(Preprocessor):

def __call__(self, row) -> Trajectory:
def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    """Apply ``preprocess`` to every sample of a column-oriented batch.

    The batch is split into per-sample dicts, each one is passed to
    ``preprocess``, and the results are merged back into column form.
    """
    samples = self.map_col_to_row(rows)
    processed = [self.preprocess(sample) for sample in samples]
    return self.map_row_to_col(processed)

def preprocess(self, row) -> Trajectory:
instruction = row.get('instruction') or ''
input_text = row.get('input') or ''
output_text = row.get('output') or ''
Expand All @@ -68,7 +93,13 @@ class CountdownProcessor(Preprocessor):
system_prompt = ('You are a helpful assistant. You first thinks about the reasoning process '
'in the mind and then provides the user with the answer.')

def __call__(self, row) -> Trajectory:
def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    """Apply ``preprocess`` to every sample of a column-oriented batch.

    The batch is split into per-sample dicts, each one is passed to
    ``preprocess``, and the results are merged back into column form.
    """
    samples = self.map_col_to_row(rows)
    processed = [self.preprocess(sample) for sample in samples]
    return self.map_row_to_col(processed)

def preprocess(self, row) -> Trajectory:
nums = row.get('nums', [])
target = row.get('response', row.get('target', 0))

Expand Down Expand Up @@ -103,7 +134,13 @@ def extract_ground_truth(self, answer_str: str) -> str:
return match.group(1).replace(',', '').strip()
return ''

def __call__(self, row) -> Trajectory:
def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    """Apply ``preprocess`` to every sample of a column-oriented batch.

    The batch is split into per-sample dicts, each one is passed to
    ``preprocess``, and the results are merged back into column form.
    """
    samples = self.map_col_to_row(rows)
    processed = [self.preprocess(sample) for sample in samples]
    return self.map_row_to_col(processed)

def preprocess(self, row) -> Trajectory:
question = row['question']
answer = row.get('answer', '')
ground_truth = self.extract_ground_truth(answer)
Expand Down
10 changes: 8 additions & 2 deletions tests/preprocessor/test_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,14 @@ def test_load_from_cache_file_false(self):
# Modify processor, process again
class ModifiedProcessor(CompetitionMathProcessor):

def __call__(self, row):
traj = super().__call__(row)
def __call__(self, rows):
    """Apply this class's ``preprocess`` to every sample of a batch.

    The batch is converted with ``map_col_to_row``, each sample is passed to
    ``preprocess``, and the results are converted back with
    ``map_row_to_col``.
    """
    samples = self.map_col_to_row(rows)
    processed = [self.preprocess(sample) for sample in samples]
    return self.map_row_to_col(processed)

def preprocess(self, row):
    """Return the parent's trajectory with the first message prefixed by 'Modified: '."""
    trajectory = super().preprocess(row)
    first_message = trajectory['messages'][0]
    # Mutate the parent's result in place so the change is visible in the output.
    first_message['content'] = 'Modified: ' + first_message['content']
    return trajectory

Expand Down
Loading