Skip to content

Commit 821cdf5

Browse files
committed
fix bugs
1 parent 87999d7 commit 821cdf5

File tree

7 files changed: +23 −21 lines changed

cookbook/megatron/tp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def eval(model):
3030
dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct')
3131
dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
3232
dataset.encode()
33-
dataloader = DataLoader(dataset=dataset, batch_size=1)
33+
dataloader = DataLoader(dataset=dataset, batch_size=16)
3434
for step, batch in tqdm(enumerate(dataloader)):
3535
model.forward_only(inputs=batch)
3636
metrics = model.calculate_metric(is_training=False)

cookbook/megatron/tp_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def eval(model):
3030
dataset.set_template('Template', model_id='ms://Qwen/Qwen3-30B-A3B-Instruct-2507')
3131
dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
3232
dataset.encode()
33-
dataloader = DataLoader(dataset=dataset, batch_size=1)
33+
dataloader = DataLoader(dataset=dataset, batch_size=16)
3434
for step, batch in tqdm(enumerate(dataloader)):
3535
model.forward_only(inputs=batch)
3636
metrics = model.calculate_metric(is_training=False)

cookbook/rl/grpo.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@
3737
ADAPTER_NAME = 'default'
3838

3939
def create_gsm8k_dataset():
40-
dataset = Dataset(DatasetMeta("ms://modelscope/gsm8k", subset_name='main', split='train'))
41-
dataset.set_template("Template", model_id=MODEL_ID, max_length=2048)
40+
dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train'))
41+
dataset.set_template('Template', model_id=MODEL_ID, max_length=2048)
4242
dataset.map(GSM8KProcessor())
4343
dataset.encode(add_generation_prompt=True)
4444
return dataset
@@ -67,7 +67,7 @@ def main():
6767
sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS)
6868
twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False)
6969

70-
lora_config = LoraConfig(target_modules="all-linear", r=32, lora_alpha=64, lora_dropout=0.05)
70+
lora_config = LoraConfig(target_modules='all-linear', r=32, lora_alpha=64, lora_dropout=0.05)
7171

7272
if USE_MEGATRON:
7373
from twinkle.model.megatron import MegatronModel
@@ -164,9 +164,9 @@ def main():
164164
optim_step += 1
165165
log_dict = metrics.calculate()
166166
log_dict.update(model.calculate_metric(is_training=True))
167-
logger.info(f"[Step {optim_step}/{MAX_STEPS}] {log_dict}")
167+
logger.info(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}')
168168

169-
logger.info(f"Training completed. optim_steps={optim_step}")
169+
logger.info(f'Training completed. optim_steps={optim_step}')
170170
model.save('grpo-gsm8k-checkpoint')
171171

172172
if __name__ == '__main__':

src/twinkle/infra/_ray/resource_manager.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,11 @@ def __init__(self, nproc_per_node: int, ncpu_proc_per_node: int, groups: List[De
8686
for i in range(self.nnodes):
8787
# TODO not accurate, because placement_group cannot distribute to node same ordered with self.nodes
8888
node_idx = self.min_node_idx + i if device_type != 'CPU' else i
89-
node = self.nodes[node_idx]
89+
try:
90+
node = self.nodes[node_idx]
91+
except IndexError:
92+
# node_idx may not be continuous
93+
node = self.nodes[0]
9094
node_cpu = int(node['Resources']['CPU'])
9195
if device_type != 'CPU':
9296
bundles.append({device_type: nproc_per_node, 'CPU': max(node_cpu // 2, 1)}) # create bundles

src/twinkle/preprocessor/llm.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
2+
import re
3+
24
from twinkle.data_format import Message, Trajectory
35
from .base import Preprocessor
4-
import re
56

6-
from twinkle.data_format import Trajectory, Message
77

88
class CompetitionMathProcessor(Preprocessor):
99

@@ -83,19 +83,18 @@ def __call__(self, row) -> Trajectory:
8383
]
8484
return Trajectory(messages=messages, user_data=[{'target': target, 'nums': nums}])
8585

86+
8687
class GSM8KProcessor(Preprocessor):
8788
"""Preprocessor for GSM8K dataset.
8889
8990
GSM8K fields: question (str), answer (str ending with '#### <number>')
9091
Extracts the ground truth number and stores it in user_data for reward.
9192
"""
9293

93-
system_prompt = (
94-
"You are a helpful math assistant. Solve the problem step by step. "
95-
"Show your reasoning in <think> </think> tags, then give the final "
96-
"numerical answer after ####.\n"
97-
"For example:\n<think> ... reasoning ... </think>\n#### 42"
98-
)
94+
system_prompt = ('You are a helpful math assistant. Solve the problem step by step. '
95+
'Show your reasoning in <think> </think> tags, then give the final '
96+
'numerical answer after ####.\n'
97+
'For example:\n<think> ... reasoning ... </think>\n#### 42')
9998

10099
def extract_ground_truth(answer_str: str) -> str:
101100
"""Extract the number after '####' from GSM8K answer."""

src/twinkle/reward/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
from .base import Reward
33
from .count_down_accuracy import CountDownAccuracy
44
from .format_reward import FormatReward
5-
from .math_reward import MathReward
65
from .gsm8k import GSM8KAccuracyReward, GSM8KFormatReward
6+
from .math_reward import MathReward

src/twinkle/reward/gsm8k.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import List, Dict, Any
21
import re
2+
from typing import Any, Dict, List
3+
34
from twinkle.reward.base import Reward
45

56

@@ -64,9 +65,7 @@ def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]:
6465
if msg.get('role') == 'assistant':
6566
completion = msg.get('content', '')
6667
break
67-
has_think = bool(
68-
re.search(r'<think>.*?</think>', completion, re.DOTALL)
69-
)
68+
has_think = bool(re.search(r'<think>.*?</think>', completion, re.DOTALL))
7069
has_answer = bool(re.search(r'####\s*[\-\d,\.]+', completion))
7170
rewards.append(1.0 if (has_think and has_answer) else 0.0)
7271
return rewards

0 commit comments

Comments (0)