|
1 | | -# Tinker-Compatible Client - Math GRPO Training Example |
| 1 | +# Tinker-Compatible Client - GSM8K GRPO Training Example |
2 | 2 | # |
3 | | -# This script demonstrates Math problem training using the |
| 3 | +# This script demonstrates GSM8K math problem training using the |
4 | 4 | # Tinker-compatible client API with save_weights_for_sampler for weight sync. |
5 | 5 | # Instead of calling sync_weights directly, it periodically saves weights and |
6 | 6 | # creates a sampling client for generation. |
7 | 7 | # |
8 | 8 | # Flow: |
9 | | -# 1. Prepare Math dataset (client-side) |
| 9 | +# 1. Prepare GSM8K dataset (client-side) |
10 | 10 | # 2. Initialize Tinker-compatible training & sampling clients |
11 | 11 | # 3. Training loop: |
12 | 12 | # a. Every SYNC_INTERVAL steps: save_weights_for_sampler → sampling_client |
|
22 | 22 | import os |
23 | 23 | import re |
24 | 24 | from tinker import types |
25 | | -from typing import List, Tuple |
| 25 | +from typing import List, Tuple, Dict, Any |
26 | 26 |
|
27 | 27 | from twinkle import init_tinker_client |
28 | 28 | from twinkle import get_logger |
29 | 29 | from twinkle.advantage import GRPOAdvantage |
30 | | -from twinkle.data_format import Message, Trajectory |
31 | 30 | from twinkle.dataloader import DataLoader |
32 | 31 | from twinkle.dataset import Dataset, DatasetMeta |
33 | | -from twinkle.preprocessor import Preprocessor |
| 32 | +from twinkle.preprocessor.llm import GSM8KProcessor |
| 33 | +from twinkle.reward import GSM8KAccuracyReward |
34 | 34 | from twinkle.reward.base import Reward |
35 | 35 | from twinkle.metric import CompletionRewardMetric |
36 | | -from twinkle.template import Template |
| 36 | +from twinkle.template import Qwen3_5Template |
37 | 37 |
|
38 | 38 | logger = get_logger() |
39 | 39 |
|
40 | 40 | # ========== Configuration ========== |
41 | 41 | BASE_MODEL = 'Qwen/Qwen3.5-27B' |
42 | | -NUM_GENERATIONS = 8 |
| 42 | +NUM_GENERATIONS = 4 |
43 | 43 | MAX_NEW_TOKENS = 4096 |
44 | | -LEARNING_RATE = 1e-4 |
| 44 | +LEARNING_RATE = 2e-5 |
45 | 45 | MAX_STEPS = 1000 |
46 | 46 | BATCH_SIZE = 2 |
47 | 47 | TEMPERATURE = 1.0 |
48 | 48 | SYNC_INTERVAL = 1 # Save weights for sampler every N steps |
49 | | -LORA_RANK = 8 |
| 49 | +LORA_RANK = 16 |
50 | | -DATA_NUM = 2000 # Number of Math samples to use
| 50 | +DATA_NUM = 2000 # Number of GSM8K samples to use
51 | 51 |
|
52 | | -SYSTEM_PROMPT = ('You are a math assistant that values brevity. ' |
53 | | - 'Solve problems with minimal but correct reasoning.\n\n' |
54 | | - 'Rules:\n' |
55 | | - '1. Use <step> </step> tags for reasoning\n' |
56 | | - '2. Final answer after ####\n\n' |
57 | | - 'Example:\n<step>Key step1 -> Ket step 2 -> conclusion</step>\n#### 42') |
| 52 | +SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning ' |
| 53 | + 'and put your final answer within \\boxed{}.') |
58 | 54 |
|
59 | 55 |
|
| 56 | +# ========== Reward Functions ========== |
| 57 | +class GSM8KBrevityReward(Reward): |
| 58 | + """Brevity reward: rewards shorter completions that contain a valid answer. |
60 | 59 |
|
61 | | -class MathPreprocessor(Preprocessor): |
62 | | - |
63 | | - def __call__(self, rows): |
64 | | - rows = self.map_col_to_row(rows) |
65 | | - rows = [self.preprocess(row) for row in rows] |
66 | | - rows = self.map_row_to_col(rows) |
67 | | - return rows |
68 | | - |
69 | | - def preprocess(self, sample): |
70 | | - if sample['level'] not in ('Level 4', 'Level 5'): |
71 | | - return Trajectory(messages=[], user_data=[]) |
72 | | - |
73 | | - def get_boxed_answer(text): |
74 | | - match = re.search(r'\\boxed{([^}]*)}', text) |
75 | | - return match.group(1) if match else None |
76 | | - |
77 | | - ground_truth = get_boxed_answer(sample['solution']) |
78 | | - if ground_truth is None: |
79 | | - return Trajectory(messages=[], user_data=[]) |
80 | | - problem = sample['problem'] |
81 | | - return Trajectory( |
82 | | - messages=[ |
83 | | - Message(role='system', content=SYSTEM_PROMPT), |
84 | | - Message(role='user', content=problem), |
85 | | - ], |
86 | | - user_data=[('ground_truth', ground_truth)], |
87 | | - ) |
88 | | - |
89 | | - |
90 | | -# ========== Math Reward Functions ========== |
91 | | -class MathAccuracyReward(Reward): |
92 | | - """Accuracy reward for Math: checks if the model's answer matches ground truth. |
93 | | -
|
94 | | - Extracts the last '#### <number>' from model output and compares with ground truth. |
95 | | - Returns 1.0 for correct, 0.0 for incorrect. |
| 60 | + Returns 0.0 if no valid answer format (\\boxed{} or ####). |
| 61 | +    Otherwise returns a higher score for shorter completions (1.0 at <=200 chars).
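| | +    For example, a valid 1,700-character completion scores max(0.0, 1 - 1500/3000) = 0.5.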
96 | 62 | """ |
97 | 63 |
|
98 | | - @staticmethod |
99 | | - def extract_answer(completion: str) -> str: |
100 | | - """Extract the last #### answer from model completion.""" |
101 | | - # Only check last 500 chars for efficiency |
102 | | - text = completion[-500:] if len(completion) > 500 else completion |
103 | | - matches = re.findall(r'####\s*([\-\d,\.\s]+)', text) |
104 | | - if matches: |
105 | | - return matches[-1].replace(',', '').replace(' ', '').strip() |
106 | | - return '' |
107 | | - |
108 | | - def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]) -> List[float]: |
| 64 | + def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]: |
109 | 65 | rewards = [] |
110 | | - for trajectory in trajectories: |
111 | | - messages = trajectory.get('messages', []) |
112 | | - # Get model completion (last assistant message) |
| 66 | + for traj in trajectories: |
| 67 | + messages = traj.get('messages', []) |
113 | 68 | completion = '' |
114 | 69 | for msg in reversed(messages): |
115 | 70 | if msg.get('role') == 'assistant': |
116 | 71 | completion = msg.get('content', '') |
117 | 72 | break |
118 | 73 |
|
119 | | - # Get ground truth from user_data |
120 | | - gt = '' |
121 | | - user_data = trajectory.get('user_data', []) |
122 | | - if isinstance(user_data, list): |
123 | | - for item in user_data: |
124 | | - if isinstance(item, (list, tuple)) and len(item) == 2: |
125 | | - if item[0] == 'ground_truth': |
126 | | - gt = str(item[1]) |
127 | | - break |
128 | | - |
129 | | - predicted = self.extract_answer(completion) |
130 | | - |
131 | | - # Numeric comparison |
132 | | - correct = False |
133 | | - if predicted and gt: |
134 | | - try: |
135 | | - correct = abs(float(predicted) - float(gt)) < 1e-5 |
136 | | - except (ValueError, OverflowError): |
137 | | - correct = predicted == gt |
138 | | - |
139 | | - rewards.append(1.0 if correct else 0.0) |
140 | | - return rewards |
141 | | - |
142 | | - |
143 | | -class MathFormatReward(Reward): |
144 | | - """Format reward: checks format and rewards shorter completions. |
145 | | -
|
146 | | - Returns higher score for shorter completions (1.0 at length 100 or less). |
147 | | - Returns 0.0 if format is incorrect. |
148 | | - """ |
149 | | - |
150 | | - def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]) -> List[float]: |
151 | | - rewards = [] |
152 | | - for trajectory in trajectories: |
153 | | - messages = trajectory.get('messages', []) |
154 | | - completion = '' |
155 | | - for msg in reversed(messages): |
156 | | - if msg.get('role') == 'assistant': |
157 | | - completion = msg.get('content', '') |
158 | | - break |
159 | | - |
160 | | - has_think = bool(re.search(r'<step>.*?</step>', completion, re.DOTALL)) |
161 | | - has_answer = bool(re.search(r'####\s*[\-\d,\.]+', completion)) |
| 74 | + has_answer = bool( |
| 75 | + re.search(r'\\boxed\{[^}]+\}', completion) |
| 76 | + or re.search(r'####\s*[\-\d,\.]+', completion) |
| 77 | + ) |
162 | 78 |
|
163 | | - if not (has_think and has_answer): |
| 79 | + if not has_answer: |
164 | 80 | rewards.append(0.0) |
165 | 81 | else: |
166 | 82 | length = len(completion) |
167 | | - if length <= 100: |
| 83 | + if length <= 200: |
168 | 84 | rewards.append(1.0) |
169 | 85 | else: |
170 | | - reward = max(0.0, 1.0 - (length - 100) / 2000) |
171 | | - rewards.append(reward) |
172 | | - |
| 86 | + rewards.append(max(0.0, 1.0 - (length - 200) / 3000)) |
173 | 87 | return rewards |
174 | 88 |
|
175 | 89 |
|
176 | | -def create_math_dataset(): |
177 | | - """Create Math dataset.""" |
178 | | - meta = DatasetMeta( |
179 | | - 'ms://modelscope/competition_math', |
180 | | - subset_name='default', |
181 | | - split='train', |
182 | | - data_slice=range(DATA_NUM), |
183 | | - ) |
184 | | - dataset = Dataset(meta) |
185 | | - dataset.set_template('Qwen3_5Template', model_id=BASE_MODEL, max_length=4096, truncation_strategy='delete') |
186 | | - dataset.map(MathPreprocessor()) |
187 | | - dataset.filter(lambda row: bool(row['messages'])) |
| 90 | +# ========== Dataset ========== |
| 91 | +def create_gsm8k_dataset(): |
| 92 | + """Create GSM8K dataset.""" |
| 93 | + dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train', data_slice=range(DATA_NUM))) |
| 94 | + dataset.set_template('Qwen3_5Template', model_id=f'ms://{BASE_MODEL}', max_length=4096, |
| 95 | + truncation_strategy='delete', enable_thinking=True) |
| 96 | + dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT)) |
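| | +    # Assumption: encode() tokenizes each trajectory with the chat template, and
| | +    # add_generation_prompt=True appends the assistant prefix (as in standard chat
| | +    # templates) so sampling begins a fresh assistant turn.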
188 | 97 | dataset.encode(add_generation_prompt=True) |
189 | 98 | return dataset |
190 | 99 |
|
191 | 100 |
|
192 | | -def compute_rewards(trajectories: List[Trajectory], ) -> Tuple[List[float], List[float], List[float]]: |
193 | | - """Compute accuracy and format rewards for Math.""" |
194 | | - accuracy_reward_fn = MathAccuracyReward() |
195 | | - format_reward_fn = MathFormatReward() |
| 101 | +def compute_rewards( |
| 102 | + trajectories: List[Dict[str, Any]], |
| 103 | +) -> Tuple[List[float], List[float], List[float]]: |
| 104 | + """Compute accuracy and brevity rewards for GSM8K.""" |
| 105 | + accuracy_reward_fn = GSM8KAccuracyReward() |
| 106 | + brevity_reward_fn = GSM8KBrevityReward() |
196 | 107 |
|
197 | | - accuracy_rewards = accuracy_reward_fn(trajectories, []) |
198 | | - format_rewards = format_reward_fn(trajectories, []) |
199 | | - total_rewards = [a + f for a, f in zip(accuracy_rewards, format_rewards)] |
200 | | - return total_rewards, format_rewards, accuracy_rewards |
| 108 | + accuracy_rewards = accuracy_reward_fn(trajectories) |
| 109 | + brevity_rewards = brevity_reward_fn(trajectories) |
| 110 | + total_rewards = [a + b for a, b in zip(accuracy_rewards, brevity_rewards)] |
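| | +    # Illustration (assuming GSM8KAccuracyReward returns 1.0 for a correct answer):
| | +    # a correct completion under 200 chars totals 1.0 + 1.0 = 2.0.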
| 111 | + return total_rewards, brevity_rewards, accuracy_rewards |
201 | 112 |
|
202 | 113 |
|
203 | 114 | def main(): |
204 | | - logger.info('Starting Math GRPO training...') |
| 115 | + logger.info('Starting GSM8K GRPO training...') |
205 | 116 |
|
206 | 117 | # Step 1: Prepare dataset and dataloader (client-side) |
207 | | - dataset = create_math_dataset() |
208 | | - dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) |
209 | | - template = Template(model_id=f'ms://{BASE_MODEL}') |
| 118 | + dataset = create_gsm8k_dataset() |
| 119 | + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, num_workers=0) |
| 120 | + template = Qwen3_5Template(model_id=f'ms://{BASE_MODEL}') |
210 | 121 |
|
211 | 122 | logger.info('Dataset and template initialized') |
212 | 123 |
|
@@ -254,7 +165,7 @@ def main(): |
254 | 165 | if step % SYNC_INTERVAL == 0: |
255 | 166 | logger.info(f'Step {step}: Saving weights for sampler...') |
256 | 167 |
|
257 | | - sampling_client = (training_client.save_weights_and_get_sampling_client(name=f'Math-step-{step}')) |
| 168 | +        sampling_client = training_client.save_weights_and_get_sampling_client(name=f'GSM8K-step-{step}')
258 | 169 | logger.info(f'Step {step}: Sampling client ready') |
259 | 170 |
|
260 | 171 | if sampling_client is None: |
@@ -317,14 +228,12 @@ def main(): |
317 | 228 | completion_lengths.append(len(seq.tokens)) |
318 | 229 |
|
319 | 230 | # ========== 4. Compute rewards ========== |
320 | | - total_rewards, format_rewards, accuracy_rewards = compute_rewards(trajectories) |
| 231 | + total_rewards, brevity_rewards, accuracy_rewards = compute_rewards(trajectories) |
321 | 232 | metrics.accumulate( |
322 | | - None, |
323 | | - None, |
324 | 233 | completion_lengths=completion_lengths, |
325 | 234 | rewards={ |
326 | 235 | 'total': total_rewards, |
327 | | - 'format': format_rewards, |
| 236 | + 'brevity': brevity_rewards, |
328 | 237 | 'accuracy': accuracy_rewards, |
329 | 238 | }) |
330 | 239 |
|
@@ -407,7 +316,7 @@ def main(): |
407 | 316 | step += 1 |
408 | 317 |
|
409 | 318 | # Save final checkpoint |
410 | | - save_future = training_client.save_state('Math-grpo-final') |
| 319 | + save_future = training_client.save_state('gsm8k-grpo-final') |
411 | 320 | save_result = save_future.result() |
412 | 321 | logger.info(f'Saved final checkpoint to {save_result.path}') |
413 | 322 |
|
|