11# Copyright (c) ModelScope Contributors. All rights reserved.
22import re
3+ from typing import Any , Dict , List
34
45from twinkle .data_format import Message , Trajectory
56from .base import Preprocessor
67
78
89class CompetitionMathProcessor (Preprocessor ):
910
10- def __call__ (self , row ) -> Trajectory :
11+ def __call__ (self , rows : Dict [str , List [Any ]]) -> Dict [str , List [Any ]]:
12+ rows = self .map_col_to_row (rows )
13+ rows = [self .preprocess (row ) for row in rows ]
14+ rows = self .map_row_to_col (rows )
15+ return rows
16+
17+ def preprocess (self , row ) -> Dict [str , Any ]:
1118 problem = row ['problem' ]
1219 solution = row ['solution' ]
1320 messages = [
@@ -19,7 +26,13 @@ def __call__(self, row) -> Trajectory:
1926
2027class CompetitionMathGRPOProcessor (Preprocessor ):
2128
22- def __call__ (self , row ) -> Trajectory :
29+ def __call__ (self , rows : Dict [str , List [Any ]]) -> Dict [str , List [Any ]]:
30+ rows = self .map_col_to_row (rows )
31+ rows = [self .preprocess (row ) for row in rows ]
32+ rows = self .map_row_to_col (rows )
33+ return rows
34+
35+ def preprocess (self , row ) -> Trajectory :
2336 problem = row ['problem' ]
2437 solution = row ['solution' ]
2538 messages = [
@@ -39,7 +52,13 @@ def __init__(self, model_name, model_author):
3952 self .model_name = model_name
4053 self .model_author = model_author
4154
42- def __call__ (self , row ) -> Trajectory :
55+ def __call__ (self , rows : Dict [str , List [Any ]]) -> Dict [str , List [Any ]]:
56+ rows = self .map_col_to_row (rows )
57+ rows = [self .preprocess (row ) for row in rows ]
58+ rows = self .map_row_to_col (rows )
59+ return rows
60+
61+ def preprocess (self , row ) -> Trajectory :
4362 problem = row ['query' ].replace ('{{NAME}}' , self .model_name ).replace ('{{AUTHOR}}' , self .model_author )
4463 solution = row ['response' ].replace ('{{NAME}}' , self .model_name ).replace ('{{AUTHOR}}' , self .model_author )
4564 messages = [
@@ -52,7 +71,13 @@ def __call__(self, row) -> Trajectory:
5271
5372class AlpacaProcessor (Preprocessor ):
5473
55- def __call__ (self , row ) -> Trajectory :
74+ def __call__ (self , rows : Dict [str , List [Any ]]) -> Dict [str , List [Any ]]:
75+ rows = self .map_col_to_row (rows )
76+ rows = [self .preprocess (row ) for row in rows ]
77+ rows = self .map_row_to_col (rows )
78+ return rows
79+
80+ def preprocess (self , row ) -> Trajectory :
5681 instruction = row .get ('instruction' ) or ''
5782 input_text = row .get ('input' ) or ''
5883 output_text = row .get ('output' ) or ''
@@ -68,7 +93,13 @@ class CountdownProcessor(Preprocessor):
6893 system_prompt = ('You are a helpful assistant. You first thinks about the reasoning process '
6994 'in the mind and then provides the user with the answer.' )
7095
71- def __call__ (self , row ) -> Trajectory :
96+ def __call__ (self , rows : Dict [str , List [Any ]]) -> Dict [str , List [Any ]]:
97+ rows = self .map_col_to_row (rows )
98+ rows = [self .preprocess (row ) for row in rows ]
99+ rows = self .map_row_to_col (rows )
100+ return rows
101+
102+ def preprocess (self , row ) -> Trajectory :
72103 nums = row .get ('nums' , [])
73104 target = row .get ('response' , row .get ('target' , 0 ))
74105
@@ -103,7 +134,13 @@ def extract_ground_truth(self, answer_str: str) -> str:
103134 return match .group (1 ).replace (',' , '' ).strip ()
104135 return ''
105136
106- def __call__ (self , row ) -> Trajectory :
137+ def __call__ (self , rows : Dict [str , List [Any ]]) -> Dict [str , List [Any ]]:
138+ rows = self .map_col_to_row (rows )
139+ rows = [self .preprocess (row ) for row in rows ]
140+ rows = self .map_row_to_col (rows )
141+ return rows
142+
143+ def preprocess (self , row ) -> Trajectory :
107144 question = row ['question' ]
108145 answer = row .get ('answer' , '' )
109146 ground_truth = self .extract_ground_truth (answer )
0 commit comments