Commit d722128

update short math grpo

1 parent 21b3f7e

File tree: 7 files changed, +158 -172 lines changed

cookbook/client/server/megatron/server_config_4b.yaml
Lines changed: 2 additions & 1 deletion

@@ -76,13 +76,14 @@ applications:
       import_path: sampler
       args:
         model_id: "ms://Qwen/Qwen3.5-4B"  # ModelScope model identifier
-        nproc_per_node: 2  # Number of GPU processes per node
+        nproc_per_node: 1  # Number of GPU processes per node
         sampler_type: vllm  # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler)
         engine_args:  # vLLM engine-specific settings
           max_model_len: 16000  # Maximum sequence length the engine supports
           gpu_memory_utilization: 0.7  # Fraction of GPU memory to use (0.0-1.0)
           enable_lora: true  # Allow loading LoRA adapters during inference
           logprobs_mode: processed_logprobs  # Logprobs mode for sampling results
+          enable_tower_connector_lora: true
       device_group:  # Logical device group for the sampler
         name: sampler
         ranks: 1  # Number of GPUs to use
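The sampler now runs on a single GPU, so nproc_per_node must agree with device_group.ranks, and the new enable_tower_connector_lora flag only has an effect when enable_lora is already on. A minimal client-side sanity check for those invariants, assuming PyYAML and the key layout shown above (check_sampler_config is a hypothetical helper, not part of this commit):

# Hypothetical sanity check for the sampler section of server_config_4b.yaml.
# Assumes PyYAML and the key layout shown in the diff above.
import yaml

def check_sampler_config(path: str) -> dict:
    with open(path) as f:
        config = yaml.safe_load(f)
    # Find the sampler application entry (layout assumed from the diff).
    sampler = next(app for app in config['applications']
                   if app.get('import_path') == 'sampler')
    args, group = sampler['args'], sampler['device_group']
    # One GPU process per node should match the one-rank device group.
    assert args['nproc_per_node'] == group['ranks'], \
        'nproc_per_node should equal device_group.ranks'
    # Tower-connector LoRA presupposes LoRA loading in the engine.
    engine = args['engine_args']
    if engine.get('enable_tower_connector_lora'):
        assert engine.get('enable_lora'), \
            'enable_tower_connector_lora requires enable_lora: true'
    return args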

cookbook/client/tinker/modelscope/short_math_grpo.py
Lines changed: 52 additions & 143 deletions

@@ -1,12 +1,12 @@
-# Tinker-Compatible Client - Math GRPO Training Example
+# Tinker-Compatible Client - GSM8K GRPO Training Example
 #
-# This script demonstrates Math problem training using the
+# This script demonstrates GSM8K math problem training using the
 # Tinker-compatible client API with save_weights_for_sampler for weight sync.
 # Instead of calling sync_weights directly, it periodically saves weights and
 # creates a sampling client for generation.
 #
 # Flow:
-# 1. Prepare Math dataset (client-side)
+# 1. Prepare GSM8K dataset (client-side)
 # 2. Initialize Tinker-compatible training & sampling clients
 # 3. Training loop:
 #    a. Every SYNC_INTERVAL steps: save_weights_for_sampler → sampling_client

@@ -22,191 +22,102 @@
 import os
 import re
 from tinker import types
-from typing import List, Tuple
+from typing import List, Tuple, Dict, Any
 
 from twinkle import init_tinker_client
 from twinkle import get_logger
 from twinkle.advantage import GRPOAdvantage
-from twinkle.data_format import Message, Trajectory
 from twinkle.dataloader import DataLoader
 from twinkle.dataset import Dataset, DatasetMeta
-from twinkle.preprocessor import Preprocessor
+from twinkle.preprocessor.llm import GSM8KProcessor
+from twinkle.reward import GSM8KAccuracyReward
 from twinkle.reward.base import Reward
 from twinkle.metric import CompletionRewardMetric
-from twinkle.template import Template
+from twinkle.template import Qwen3_5Template
 
 logger = get_logger()
 
 # ========== Configuration ==========
 BASE_MODEL = 'Qwen/Qwen3.5-27B'
-NUM_GENERATIONS = 8
+NUM_GENERATIONS = 4
 MAX_NEW_TOKENS = 4096
-LEARNING_RATE = 1e-4
+LEARNING_RATE = 2e-5
 MAX_STEPS = 1000
 BATCH_SIZE = 2
 TEMPERATURE = 1.0
 SYNC_INTERVAL = 1  # Save weights for sampler every N steps
-LORA_RANK = 8
+LORA_RANK = 16
 DATA_NUM = 2000  # Number of Math samples to use
 
-SYSTEM_PROMPT = ('You are a math assistant that values brevity. '
-                 'Solve problems with minimal but correct reasoning.\n\n'
-                 'Rules:\n'
-                 '1. Use <step> </step> tags for reasoning\n'
-                 '2. Final answer after ####\n\n'
-                 'Example:\n<step>Key step1 -> Ket step 2 -> conclusion</step>\n#### 42')
+SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning '
+                 'and put your final answer within \\boxed{}.')
 
 
+# ========== Reward Functions ==========
+class GSM8KBrevityReward(Reward):
+    """Brevity reward: rewards shorter completions that contain a valid answer.
 
-class MathPreprocessor(Preprocessor):
-
-    def __call__(self, rows):
-        rows = self.map_col_to_row(rows)
-        rows = [self.preprocess(row) for row in rows]
-        rows = self.map_row_to_col(rows)
-        return rows
-
-    def preprocess(self, sample):
-        if sample['level'] not in ('Level 4', 'Level 5'):
-            return Trajectory(messages=[], user_data=[])
-
-        def get_boxed_answer(text):
-            match = re.search(r'\\boxed{([^}]*)}', text)
-            return match.group(1) if match else None
-
-        ground_truth = get_boxed_answer(sample['solution'])
-        if ground_truth is None:
-            return Trajectory(messages=[], user_data=[])
-        problem = sample['problem']
-        return Trajectory(
-            messages=[
-                Message(role='system', content=SYSTEM_PROMPT),
-                Message(role='user', content=problem),
-            ],
-            user_data=[('ground_truth', ground_truth)],
-        )
-
-
-# ========== Math Reward Functions ==========
-class MathAccuracyReward(Reward):
-    """Accuracy reward for Math: checks if the model's answer matches ground truth.
-
-    Extracts the last '#### <number>' from model output and compares with ground truth.
-    Returns 1.0 for correct, 0.0 for incorrect.
+    Returns 0.0 if no valid answer format (\\boxed{} or ####).
+    Otherwise returns higher score for shorter completions (1.0 at <=200 chars).
     """
 
-    @staticmethod
-    def extract_answer(completion: str) -> str:
-        """Extract the last #### answer from model completion."""
-        # Only check last 500 chars for efficiency
-        text = completion[-500:] if len(completion) > 500 else completion
-        matches = re.findall(r'####\s*([\-\d,\.\s]+)', text)
-        if matches:
-            return matches[-1].replace(',', '').replace(' ', '').strip()
-        return ''
-
-    def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]) -> List[float]:
+    def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]:
         rewards = []
-        for trajectory in trajectories:
-            messages = trajectory.get('messages', [])
-            # Get model completion (last assistant message)
+        for traj in trajectories:
+            messages = traj.get('messages', [])
             completion = ''
             for msg in reversed(messages):
                 if msg.get('role') == 'assistant':
                     completion = msg.get('content', '')
                     break
 
-            # Get ground truth from user_data
-            gt = ''
-            user_data = trajectory.get('user_data', [])
-            if isinstance(user_data, list):
-                for item in user_data:
-                    if isinstance(item, (list, tuple)) and len(item) == 2:
-                        if item[0] == 'ground_truth':
-                            gt = str(item[1])
-                            break
-
-            predicted = self.extract_answer(completion)
-
-            # Numeric comparison
-            correct = False
-            if predicted and gt:
-                try:
-                    correct = abs(float(predicted) - float(gt)) < 1e-5
-                except (ValueError, OverflowError):
-                    correct = predicted == gt
-
-            rewards.append(1.0 if correct else 0.0)
-        return rewards
-
-
-class MathFormatReward(Reward):
-    """Format reward: checks format and rewards shorter completions.
-
-    Returns higher score for shorter completions (1.0 at length 100 or less).
-    Returns 0.0 if format is incorrect.
-    """
-
-    def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]) -> List[float]:
-        rewards = []
-        for trajectory in trajectories:
-            messages = trajectory.get('messages', [])
-            completion = ''
-            for msg in reversed(messages):
-                if msg.get('role') == 'assistant':
-                    completion = msg.get('content', '')
-                    break
-
-            has_think = bool(re.search(r'<step>.*?</step>', completion, re.DOTALL))
-            has_answer = bool(re.search(r'####\s*[\-\d,\.]+', completion))
+            has_answer = bool(
+                re.search(r'\\boxed\{[^}]+\}', completion)
+                or re.search(r'####\s*[\-\d,\.]+', completion)
+            )
 
-            if not (has_think and has_answer):
+            if not has_answer:
                 rewards.append(0.0)
             else:
                 length = len(completion)
-                if length <= 100:
+                if length <= 200:
                     rewards.append(1.0)
                 else:
-                    reward = max(0.0, 1.0 - (length - 100) / 2000)
-                    rewards.append(reward)
-
+                    rewards.append(max(0.0, 1.0 - (length - 200) / 3000))
         return rewards
 
 
-def create_math_dataset():
-    """Create Math dataset."""
-    meta = DatasetMeta(
-        'ms://modelscope/competition_math',
-        subset_name='default',
-        split='train',
-        data_slice=range(DATA_NUM),
-    )
-    dataset = Dataset(meta)
-    dataset.set_template('Qwen3_5Template', model_id=BASE_MODEL, max_length=4096, truncation_strategy='delete')
-    dataset.map(MathPreprocessor())
-    dataset.filter(lambda row: bool(row['messages']))
+# ========== Dataset ==========
+def create_gsm8k_dataset():
+    """Create GSM8K dataset."""
+    dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train', data_slice=range(DATA_NUM)))
+    dataset.set_template('Qwen3_5Template', model_id=f'ms://{BASE_MODEL}', max_length=4096,
+                         truncation_strategy='delete', enable_thinking=True)
+    dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT))
     dataset.encode(add_generation_prompt=True)
     return dataset
 
 
-def compute_rewards(trajectories: List[Trajectory], ) -> Tuple[List[float], List[float], List[float]]:
-    """Compute accuracy and format rewards for Math."""
-    accuracy_reward_fn = MathAccuracyReward()
-    format_reward_fn = MathFormatReward()
+def compute_rewards(
+    trajectories: List[Dict[str, Any]],
+) -> Tuple[List[float], List[float], List[float]]:
+    """Compute accuracy and brevity rewards for GSM8K."""
+    accuracy_reward_fn = GSM8KAccuracyReward()
+    brevity_reward_fn = GSM8KBrevityReward()
 
-    accuracy_rewards = accuracy_reward_fn(trajectories, [])
-    format_rewards = format_reward_fn(trajectories, [])
-    total_rewards = [a + f for a, f in zip(accuracy_rewards, format_rewards)]
-    return total_rewards, format_rewards, accuracy_rewards
+    accuracy_rewards = accuracy_reward_fn(trajectories)
+    brevity_rewards = brevity_reward_fn(trajectories)
+    total_rewards = [a + b for a, b in zip(accuracy_rewards, brevity_rewards)]
+    return total_rewards, brevity_rewards, accuracy_rewards
 
 
 def main():
-    logger.info('Starting Math GRPO training...')
+    logger.info('Starting GSM8K GRPO training...')
 
     # Step 1: Prepare dataset and dataloader (client-side)
-    dataset = create_math_dataset()
-    dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)
-    template = Template(model_id=f'ms://{BASE_MODEL}')
+    dataset = create_gsm8k_dataset()
+    dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, num_workers=0)
+    template = Qwen3_5Template(model_id=f'ms://{BASE_MODEL}')
 
     logger.info('Dataset and template initialized')
 

@@ -254,7 +165,7 @@ def main():
        if step % SYNC_INTERVAL == 0:
            logger.info(f'Step {step}: Saving weights for sampler...')
 
-            sampling_client = (training_client.save_weights_and_get_sampling_client(name=f'Math-step-{step}'))
+            sampling_client = (training_client.save_weights_and_get_sampling_client(name=f'GSM8K-step-{step}'))
            logger.info(f'Step {step}: Sampling client ready')
 
            if sampling_client is None:

@@ -317,14 +228,12 @@ def main():
                completion_lengths.append(len(seq.tokens))
 
        # ========== 4. Compute rewards ==========
-        total_rewards, format_rewards, accuracy_rewards = compute_rewards(trajectories)
+        total_rewards, brevity_rewards, accuracy_rewards = compute_rewards(trajectories)
        metrics.accumulate(
-            None,
-            None,
            completion_lengths=completion_lengths,
            rewards={
                'total': total_rewards,
-                'format': format_rewards,
+                'brevity': brevity_rewards,
                'accuracy': accuracy_rewards,
            })
 

@@ -407,7 +316,7 @@ def main():
        step += 1
 
    # Save final checkpoint
-    save_future = training_client.save_state('Math-grpo-final')
+    save_future = training_client.save_state('gsm8k-grpo-final')
    save_result = save_future.result()
    logger.info(f'Saved final checkpoint to {save_result.path}')
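The reshaped reward is easy to verify in isolation: a completion earns the brevity bonus only if it carries a \boxed{} or #### answer, scores 1.0 up to 200 characters, then decays linearly to 0.0 at 3,200 characters. A standalone restatement of that shaping, with the twinkle dependencies stripped out so the numbers can be checked directly:

# Standalone check of the brevity shaping in GSM8KBrevityReward above,
# reimplemented without the twinkle dependencies.
import re

def brevity_reward(completion: str) -> float:
    has_answer = bool(
        re.search(r'\\boxed\{[^}]+\}', completion)
        or re.search(r'####\s*[\-\d,\.]+', completion)
    )
    if not has_answer:
        return 0.0  # no recognizable final answer -> no reward
    length = len(completion)
    if length <= 200:
        return 1.0
    return max(0.0, 1.0 - (length - 200) / 3000)

print(brevity_reward(r'The answer is \boxed{42}.'))  # 1.0 (short, valid answer)
print(brevity_reward('Let me think step by step'))   # 0.0 (no answer marker)
print(brevity_reward(r'\boxed{42}' + 'x' * 1690))    # 0.5 (1,700 chars, halfway down the ramp)

Since the accuracy term is 0 or 1 and the brevity term is at most 1, a correct, terse answer tops out at a total reward of 2.0.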

cookbook/client/tinker/self_host/short_math_grpo.py
Lines changed: 5 additions & 5 deletions

@@ -33,17 +33,17 @@
 from twinkle.reward import GSM8KAccuracyReward
 from twinkle.reward.base import Reward
 from twinkle.metric import CompletionRewardMetric
-from twinkle.template import Template
+from twinkle.template import Qwen3_5Template
 
 logger = get_logger()
 
 # ========== Configuration ==========
 BASE_MODEL = 'Qwen/Qwen3.5-4B'
 NUM_GENERATIONS = 4
 MAX_NEW_TOKENS = 4096
-LEARNING_RATE = 1e-5
+LEARNING_RATE = 2e-5
 MAX_STEPS = 1000
-BATCH_SIZE = 4
+BATCH_SIZE = 2
 TEMPERATURE = 1.0
 SYNC_INTERVAL = 1  # Save weights for sampler every N steps
 LORA_RANK = 16

@@ -92,7 +92,7 @@ def create_gsm8k_dataset():
     """Create GSM8K dataset."""
     dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train', data_slice=range(DATA_NUM)))
     dataset.set_template('Qwen3_5Template', model_id=f'ms://{BASE_MODEL}', max_length=4096,
-                         truncation_strategy='delete', enable_thinking=False)
+                         truncation_strategy='delete', enable_thinking=True)
     dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT))
     dataset.encode(add_generation_prompt=True)
     return dataset

@@ -117,7 +117,7 @@ def main():
     # Step 1: Prepare dataset and dataloader (client-side)
     dataset = create_gsm8k_dataset()
     dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, num_workers=0)
-    template = Template(model_id=f'ms://{BASE_MODEL}')
+    template = Qwen3_5Template(model_id=f'ms://{BASE_MODEL}')
 
     logger.info('Dataset and template initialized')
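With BATCH_SIZE = 2 and NUM_GENERATIONS = 4, each training step scores eight completions in two groups of four, and GRPO converts those raw rewards into group-relative advantages. A generic sketch of that normalization, using the standard GRPO formula rather than twinkle's GRPOAdvantage (whose exact implementation may differ):

# Generic group-relative advantage as in standard GRPO: each completion's
# reward is centered and scaled against the other completions for the
# same prompt. Illustrative only; twinkle's GRPOAdvantage may differ.
from statistics import mean, pstdev
from typing import List

NUM_GENERATIONS = 4  # group size, matching the config above

def grpo_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
    advantages = []
    for start in range(0, len(rewards), NUM_GENERATIONS):
        group = rewards[start:start + NUM_GENERATIONS]
        mu, sigma = mean(group), pstdev(group)
        advantages.extend((r - mu) / (sigma + eps) for r in group)
    return advantages

# Two prompts x four generations of total rewards (accuracy + brevity).
rewards = [2.0, 1.0, 0.0, 1.0,   # prompt 1: mixed outcomes
           0.5, 0.5, 0.5, 0.5]   # prompt 2: identical rewards -> zero advantage
print(grpo_advantages(rewards))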

cookbook/client/twinkle/self_host/self_congnition.py
Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@
 
 from twinkle import get_logger
 from twinkle.dataset import DatasetMeta
-from twinkle_client import init_twinkle_client
+from twinkle import init_twinkle_client
 from twinkle_client.dataloader import DataLoader
 from twinkle_client.dataset import Dataset
 from twinkle_client.model import MultiLoraTransformersModel
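The one-line fix imports init_twinkle_client from twinkle, matching where the symbol actually lives; the remaining twinkle_client imports are unchanged. If this cookbook script also needs to run against an older install that still exported the function from twinkle_client, a fallback import would cover both (an assumption for illustration, not part of this commit):

# Hypothetical compatibility shim, not part of this commit: prefer the
# current location of init_twinkle_client, fall back to the old one.
try:
    from twinkle import init_twinkle_client
except ImportError:  # older layouts exported it from twinkle_client
    from twinkle_client import init_twinkle_client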
