Skip to content

Commit 7640fb2

Browse files
committed
fix
1 parent b941858 commit 7640fb2

File tree

1 file changed

+35
-9
lines changed

1 file changed

+35
-9
lines changed

cookbook/legacy/grpo/gsm8k.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from twinkle.model import TransformersModel
3232
from twinkle.preprocessor import Preprocessor
3333
from twinkle.processor import InputProcessor
34+
from twinkle.reward import MathReward
3435
from twinkle.reward.base import Reward
3536
from twinkle.sampler import vLLMSampler
3637
from twinkle.template import Template
@@ -44,7 +45,7 @@
4445

4546
MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
4647
SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4))
47-
SAMPLER_TP = int(os.environ.get('SAMPLER_TP', SAMPLER_GPUS // 2))
48+
SAMPLER_TP = int(os.environ.get('SAMPLER_TP', 1))
4849
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS
4950

5051
PP_SIZE = 2
@@ -83,12 +84,37 @@
8384

8485
SYSTEM_PROMPT = (
8586
"You are a helpful math assistant. Solve the problem step by step. "
86-
"Show your reasoning in <think> </think> tags, then give the final "
87-
"numerical answer after ####.\n"
88-
"For example:\n<think> ... reasoning ... </think>\n#### 42"
87+
"YOU MUST Show your reasoning in <think> </think> tags, then give the final "
88+
"numerical answer with boxed \\boxed{}.\n"
89+
"For example:\n<think> ... reasoning ... </think>\n\\boxed{42}"
8990
)
9091

9192

93+
class MathPreprocessor(Preprocessor):
94+
95+
def __call__(self, sample):
96+
if sample['level'] not in ('Level 4', 'Level 5'):
97+
return None
98+
99+
def get_boxed_answer(text):
100+
match = re.search(r'\\boxed\{([^}]+)\}', text)
101+
return match.group(1) if match else None
102+
103+
ground_truth = get_boxed_answer(sample['solution'])
104+
if ground_truth is None:
105+
return None
106+
problem = sample['problem']
107+
solution = sample['solution']
108+
return Trajectory(
109+
messages=[
110+
Message(role='user', content=problem),
111+
Message(role='assistant', content=solution)
112+
],
113+
user_data=[('ground_truth', ground_truth)],
114+
)
115+
116+
117+
92118
class GSM8KProcessor(Preprocessor):
93119
"""Preprocessor for GSM8K dataset.
94120
@@ -202,13 +228,13 @@ def __call__(
202228
def create_gsm8k_dataset():
203229
"""Create GSM8K dataset."""
204230
meta = DatasetMeta(
205-
"ms://modelscope/gsm8k",
206-
subset_name='main', split='train',
207-
data_slice=range(DATA_NUM),
231+
"ms://modelscope/competition_math",
232+
subset_name='default', split='train',
233+
data_slice=range(2000),
208234
)
209235
dataset = Dataset(meta)
210236
dataset.set_template("Template", model_id=MODEL_ID, max_length=2048)
211-
dataset.map(GSM8KProcessor())
237+
dataset.map(MathPreprocessor())
212238
dataset.encode(add_generation_prompt=True)
213239
return dataset
214240

@@ -217,7 +243,7 @@ def compute_rewards(
217243
trajectories: List[Trajectory],
218244
) -> Tuple[List[float], List[float], List[float]]:
219245
"""Compute accuracy and format rewards for GSM8K."""
220-
accuracy_reward_fn = GSM8KAccuracyReward()
246+
accuracy_reward_fn = MathReward()
221247
format_reward_fn = GSM8KFormatReward()
222248

223249
accuracy_rewards = accuracy_reward_fn(trajectories, [])

0 commit comments

Comments
 (0)