diff --git a/cookbook/client/tinker/custom_service/short_math_grpo.py b/cookbook/client/tinker/custom_service/short_math_grpo.py index 6b1e6cea..6e34f899 100644 --- a/cookbook/client/tinker/custom_service/short_math_grpo.py +++ b/cookbook/client/tinker/custom_service/short_math_grpo.py @@ -60,7 +60,13 @@ class MathPreprocessor(Preprocessor): - def __call__(self, sample): + def __call__(self, rows): + rows = self.map_col_to_row(rows) + rows = [self.preprocess(row) for row in rows] + rows = self.map_row_to_col(rows) + return rows + + def preprocess(self, sample): if sample['level'] not in ('Level 4', 'Level 5'): return Trajectory(messages=[], user_data=[]) diff --git a/cookbook/client/tinker/modelscope_service/short_math_grpo.py b/cookbook/client/tinker/modelscope_service/short_math_grpo.py index 3d3ad2c2..225f8219 100644 --- a/cookbook/client/tinker/modelscope_service/short_math_grpo.py +++ b/cookbook/client/tinker/modelscope_service/short_math_grpo.py @@ -60,7 +60,13 @@ class MathPreprocessor(Preprocessor): - def __call__(self, sample): + def __call__(self, rows): + rows = self.map_col_to_row(rows) + rows = [self.preprocess(row) for row in rows] + rows = self.map_row_to_col(rows) + return rows + + def preprocess(self, sample): if sample['level'] not in ('Level 4', 'Level 5'): return Trajectory(messages=[], user_data=[]) diff --git a/cookbook/mm/fsdp2.py b/cookbook/mm/fsdp2.py index a93cd705..cbe6f50d 100644 --- a/cookbook/mm/fsdp2.py +++ b/cookbook/mm/fsdp2.py @@ -19,7 +19,13 @@ class LatexOCRProcessor(Preprocessor): - def __call__(self, row) -> Trajectory: + def __call__(self, rows): + rows = self.map_col_to_row(rows) + rows = [self.preprocess(row) for row in rows] + rows = self.map_row_to_col(rows) + return rows + + def preprocess(self, row) -> Trajectory: return Trajectory( messages=[ Message(role='user', content='Using LaTeX to perform OCR on the image.', images=[row['image']]), diff --git a/docs/source_en/Components/Preprocessor and Filter/Preprocessor.md b/docs/source_en/Components/Preprocessor and Filter/Preprocessor.md index 51847bf2..828c8c5c 100644 --- a/docs/source_en/Components/Preprocessor and Filter/Preprocessor.md +++ b/docs/source_en/Components/Preprocessor and Filter/Preprocessor.md @@ -7,11 +7,11 @@ The base class of Preprocessor: ```python class Preprocessor: - def __call__(self, row) -> Trajectory: + def __call__(self, rows: List[Dict]) -> List[Trajectory]: ... ``` -The format is to pass in a raw sample and output a `Trajectory`. If the sample cannot be used, you can directly return None. +The format is to pass in a list of samples and output a list of `Trajectory`. If a sample cannot be used, you can directly ignore it. We provide some basic Preprocessors, such as `SelfCognitionProcessor`: @@ -22,7 +22,7 @@ dataset.map('SelfCognitionProcessor', model_name='some-model', model_author='som Preprocessor contains the __call__ method, which means you can use a function to replace the class: ```python -def self_cognition_preprocessor(row): +def self_cognition_preprocessor(rows): ... - return Trajectory(...) + return [Trajectory(...), ...] ``` diff --git "a/docs/source_zh/\347\273\204\344\273\266/\351\242\204\345\244\204\347\220\206\345\231\250\345\222\214\350\277\207\346\273\244\345\231\250/Preprocessor.md" "b/docs/source_zh/\347\273\204\344\273\266/\351\242\204\345\244\204\347\220\206\345\231\250\345\222\214\350\277\207\346\273\244\345\231\250/Preprocessor.md" index ab1c58bd..bdc11d6c 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\351\242\204\345\244\204\347\220\206\345\231\250\345\222\214\350\277\207\346\273\244\345\231\250/Preprocessor.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\351\242\204\345\244\204\347\220\206\345\231\250\345\222\214\350\277\207\346\273\244\345\231\250/Preprocessor.md" @@ -7,11 +7,11 @@ Preprocessor 的基类: ```python class Preprocessor: - def __call__(self, row) -> Trajectory: + def __call__(self, rows: List[Dict]) -> List[Trajectory]: ... ``` -格式为传入一个原始样本,输出一个`Trajectory`。如果样本无法使用,可以直接返回None。 +格式为传入一系列原始样本,输出对应的`Trajectory`。如果某个样本无法使用,可以直接忽略它。输入条数和输出条数不必相同。 我们提供了一些基本的 Preprocessor,例如 `SelfCognitionProcessor`: @@ -22,7 +22,7 @@ dataset.map('SelfCognitionProcessor', model_name='some-model', model_author='som Preprocessor 包含 __call__ 方法,这意味着你可以使用 function 来代替类: ```python -def self_cognition_preprocessor(row): +def self_cognition_preprocessor(rows): ... - return Trajectory(...) + return [Trajectory(...), ...] ``` diff --git a/src/twinkle/dataset/base.py b/src/twinkle/dataset/base.py index 5fc31a8e..10b3bb92 100644 --- a/src/twinkle/dataset/base.py +++ b/src/twinkle/dataset/base.py @@ -183,7 +183,7 @@ def map(self, key = next(iter(self.datasets.keys())) else: key = dataset_meta.get_id() - kwargs['batched'] = False # TODO temporary change to False, because the interface does not support batched + kwargs['batched'] = True with processing_lock(key): self.datasets[key] = self.datasets[key].map(preprocess_func, **kwargs) if len(self.datasets) == 1: @@ -210,7 +210,7 @@ def filter(self, key = next(iter(self.datasets.keys())) else: key = dataset_meta.get_id() - kwargs['batched'] = False # TODO temporary change to False, because the interface does not support batched + kwargs['batched'] = False with processing_lock(key): self.datasets[key] = self.datasets[key].filter(filter_func, **kwargs) if len(self.datasets) == 1: diff --git a/src/twinkle/preprocessor/base.py b/src/twinkle/preprocessor/base.py index 035c178e..06ad06ba 100644 --- a/src/twinkle/preprocessor/base.py +++ b/src/twinkle/preprocessor/base.py @@ -1,11 +1,38 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +from typing import Any, Dict, List from twinkle.data_format import Trajectory class Preprocessor: - def __call__(self, row) -> Trajectory: + @staticmethod + def map_col_to_row(rows: Dict[str, List[Any]]) -> List[Dict[str, Any]]: + if not rows: + return [] + _new_rows = [] + total_count = len(rows[next(iter(list(rows.keys())))]) + for i in range(total_count): + row = {} + for key in rows: + row[key] = rows[key][i] + _new_rows.append(row) + return _new_rows + + @staticmethod + def map_row_to_col(rows: List[Dict[str, Any]]) -> Dict[str, List[Any]]: + if not rows: + return {} + + columns: Dict[str, List[Any]] = {} + keys = rows[0].keys() + + for key in keys: + columns[key] = [row[key] for row in rows] + + return columns + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: ... diff --git a/src/twinkle/preprocessor/llm.py b/src/twinkle/preprocessor/llm.py index 45ba3125..ddafb351 100644 --- a/src/twinkle/preprocessor/llm.py +++ b/src/twinkle/preprocessor/llm.py @@ -1,5 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import re +from typing import Any, Dict, List from twinkle.data_format import Message, Trajectory from .base import Preprocessor @@ -7,7 +8,13 @@ class CompetitionMathProcessor(Preprocessor): - def __call__(self, row) -> Trajectory: + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + rows = [self.preprocess(row) for row in rows] + rows = self.map_row_to_col(rows) + return rows + + def preprocess(self, row) -> Dict[str, Any]: problem = row['problem'] solution = row['solution'] messages = [ @@ -19,7 +26,13 @@ def __call__(self, row) -> Trajectory: class CompetitionMathGRPOProcessor(Preprocessor): - def __call__(self, row) -> Trajectory: + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + rows = [self.preprocess(row) for row in rows] + rows = self.map_row_to_col(rows) + return rows + + def preprocess(self, row) -> Trajectory: problem = row['problem'] solution = row['solution'] messages = [ @@ -39,7 +52,13 @@ def __init__(self, model_name, model_author): self.model_name = model_name self.model_author = model_author - def __call__(self, row) -> Trajectory: + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + rows = [self.preprocess(row) for row in rows] + rows = self.map_row_to_col(rows) + return rows + + def preprocess(self, row) -> Trajectory: problem = row['query'].replace('{{NAME}}', self.model_name).replace('{{AUTHOR}}', self.model_author) solution = row['response'].replace('{{NAME}}', self.model_name).replace('{{AUTHOR}}', self.model_author) messages = [ @@ -52,7 +71,13 @@ def __call__(self, row) -> Trajectory: class AlpacaProcessor(Preprocessor): - def __call__(self, row) -> Trajectory: + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + rows = [self.preprocess(row) for row in rows] + rows = self.map_row_to_col(rows) + return rows + + def preprocess(self, row) -> Trajectory: instruction = row.get('instruction') or '' input_text = row.get('input') or '' output_text = row.get('output') or '' @@ -68,7 +93,13 @@ class CountdownProcessor(Preprocessor): system_prompt = ('You are a helpful assistant. You first thinks about the reasoning process ' 'in the mind and then provides the user with the answer.') - def __call__(self, row) -> Trajectory: + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + rows = [self.preprocess(row) for row in rows] + rows = self.map_row_to_col(rows) + return rows + + def preprocess(self, row) -> Trajectory: nums = row.get('nums', []) target = row.get('response', row.get('target', 0)) @@ -103,7 +134,13 @@ def extract_ground_truth(self, answer_str: str) -> str: return match.group(1).replace(',', '').strip() return '' - def __call__(self, row) -> Trajectory: + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + rows = [self.preprocess(row) for row in rows] + rows = self.map_row_to_col(rows) + return rows + + def preprocess(self, row) -> Trajectory: question = row['question'] answer = row.get('answer', '') ground_truth = self.extract_ground_truth(answer) diff --git a/tests/preprocessor/test_preprocessor.py b/tests/preprocessor/test_preprocessor.py index 8b7b90c3..a4a28b6c 100644 --- a/tests/preprocessor/test_preprocessor.py +++ b/tests/preprocessor/test_preprocessor.py @@ -293,8 +293,14 @@ def test_load_from_cache_file_false(self): # Modify processor, process again class ModifiedProcessor(CompetitionMathProcessor): - def __call__(self, row): - traj = super().__call__(row) + def __call__(self, rows): + rows = self.map_col_to_row(rows) + rows = [self.preprocess(row) for row in rows] + rows = self.map_row_to_col(rows) + return rows + + def preprocess(self, row): + traj = super().preprocess(row) traj['messages'][0]['content'] = 'Modified: ' + traj['messages'][0]['content'] return traj