3131from twinkle .model import TransformersModel
3232from twinkle .preprocessor import Preprocessor
3333from twinkle .processor import InputProcessor
34+ from twinkle .reward import MathReward
3435from twinkle .reward .base import Reward
3536from twinkle .sampler import vLLMSampler
3637from twinkle .template import Template
4445
4546MODEL_GPUS = int (os .environ .get ('MODEL_GPUS' , 4 ))
4647SAMPLER_GPUS = int (os .environ .get ('SAMPLER_GPUS' , 4 ))
47- SAMPLER_TP = int (os .environ .get ('SAMPLER_TP' , SAMPLER_GPUS // 2 ))
48+ SAMPLER_TP = int (os .environ .get ('SAMPLER_TP' , 1 ))
4849NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS
4950
5051PP_SIZE = 2
8384
8485SYSTEM_PROMPT = (
8586 "You are a helpful math assistant. Solve the problem step by step. "
86- "Show your reasoning in <think> </think> tags, then give the final "
87- "numerical answer after #### .\n "
88- "For example:\n <think> ... reasoning ... </think>\n #### 42 "
87+ "YOU MUST Show your reasoning in <think> </think> tags, then give the final "
88+ "numerical answer with boxed \\ boxed{} .\n "
89+ "For example:\n <think> ... reasoning ... </think>\n \\ boxed{42} "
8990)
9091
9192
93+ class MathPreprocessor (Preprocessor ):
94+
95+ def __call__ (self , sample ):
96+ if sample ['level' ] not in ('Level 4' , 'Level 5' ):
97+ return None
98+
99+ def get_boxed_answer (text ):
100+ match = re .search (r'\\boxed\{([^}]+)\}' , text )
101+ return match .group (1 ) if match else None
102+
103+ ground_truth = get_boxed_answer (sample ['solution' ])
104+ if ground_truth is None :
105+ return None
106+ problem = sample ['problem' ]
107+ solution = sample ['solution' ]
108+ return Trajectory (
109+ messages = [
110+ Message (role = 'user' , content = problem ),
111+ Message (role = 'assistant' , content = solution )
112+ ],
113+ user_data = [('ground_truth' , ground_truth )],
114+ )
115+
116+
117+
92118class GSM8KProcessor (Preprocessor ):
93119 """Preprocessor for GSM8K dataset.
94120
@@ -202,13 +228,13 @@ def __call__(
202228def create_gsm8k_dataset ():
203229 """Create GSM8K dataset."""
204230 meta = DatasetMeta (
205- "ms://modelscope/gsm8k " ,
206- subset_name = 'main ' , split = 'train' ,
207- data_slice = range (DATA_NUM ),
231+ "ms://modelscope/competition_math " ,
232+ subset_name = 'default ' , split = 'train' ,
233+ data_slice = range (2000 ),
208234 )
209235 dataset = Dataset (meta )
210236 dataset .set_template ("Template" , model_id = MODEL_ID , max_length = 2048 )
211- dataset .map (GSM8KProcessor ())
237+ dataset .map (MathPreprocessor ())
212238 dataset .encode (add_generation_prompt = True )
213239 return dataset
214240
@@ -217,7 +243,7 @@ def compute_rewards(
217243 trajectories : List [Trajectory ],
218244) -> Tuple [List [float ], List [float ], List [float ]]:
219245 """Compute accuracy and format rewards for GSM8K."""
220- accuracy_reward_fn = GSM8KAccuracyReward ()
246+ accuracy_reward_fn = MathReward ()
221247 format_reward_fn = GSM8KFormatReward ()
222248
223249 accuracy_rewards = accuracy_reward_fn (trajectories , [])
0 commit comments