Skip to content

Commit 4db3c40

Browse files
Refactor megatron to mcore_bridge (#134)
1 parent ccf0d79 commit 4db3c40

Some content is hidden

Large commits have some of their content hidden by default. Use the search box below to find content that may be hidden.

56 files changed

+1855
-6887
lines changed

cookbook/megatron/tp.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from twinkle.dataset import Dataset, DatasetMeta
99
from twinkle.model import MegatronModel
1010
from twinkle.preprocessor import SelfCognitionProcessor
11-
# Construct a device_mesh, tp=pp=cp=2, dp=1
11+
# Construct a device_mesh, tp=pp=dp=2
1212
device_mesh = DeviceMesh.from_sizes(dp_size=2, tp_size=2, pp_size=2)
1313
# use torchrun mode
1414
twinkle.initialize(mode='local', global_device_mesh=device_mesh)
@@ -19,7 +19,7 @@
1919
def eval(model):
2020
# 100 Samples
2121
dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100)))
22-
dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B')
22+
dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B')
2323
dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
2424
dataset.encode()
2525
dataloader = DataLoader(dataset=dataset, batch_size=16)
@@ -33,7 +33,7 @@ def train():
3333
# 1000 samples
3434
dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
3535
# Set template to prepare encoding
36-
dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B')
36+
dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B')
3737
# Preprocess the dataset to standard format
3838
dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
3939
# Encode dataset

cookbook/rl/grpo.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040

4141
def create_gsm8k_dataset():
4242
dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train'))
43-
dataset.set_template('Template', model_id=MODEL_ID, max_length=400)
43+
dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=400)
4444
dataset.map(GSM8KProcessor())
4545
dataset.encode(add_generation_prompt=True)
4646
return dataset
@@ -94,7 +94,7 @@ def main():
9494
model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0)
9595
model.set_loss('GRPOLoss', epsilon=0.2)
9696
model.set_processor(InputProcessor)
97-
model.set_template('Template', model_id=MODEL_ID)
97+
model.set_template('Qwen3_5Template', model_id=MODEL_ID)
9898

9999
sampler = vLLMSampler(
100100
model_id=MODEL_ID,
@@ -108,7 +108,7 @@ def main():
108108
device_mesh=sampler_mesh,
109109
remote_group='sampler',
110110
)
111-
sampler.set_template(Template, model_id=MODEL_ID)
111+
sampler.set_template('Qwen3_5Template', model_id=MODEL_ID)
112112

113113
ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler)
114114

cookbook/rl/grpo_mm.py

Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
"""GRPO training script for OlympiadBench multimodal math/physics dataset.

Supports three subsets:
- OE_MM_maths_zh_CEE: Multimodal math problems (Chinese CEE)
- OE_MM_physics_zh_CEE: Multimodal physics problems (Chinese CEE)
- OE_TO_maths_zh_CEE: Text-only math problems (Chinese CEE)
"""
import os
from typing import List, Tuple, Dict, Any

from peft import LoraConfig

import twinkle
from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger
from twinkle.advantage import GRPOAdvantage
from twinkle.checkpoint_engine import CheckpointEngineManager
from twinkle.data_format import SamplingParams
from twinkle.dataloader import DataLoader
from twinkle.dataset import DatasetMeta, LazyDataset
from twinkle.metric import CompletionRewardMetric
from twinkle.model import TransformersModel
from twinkle.preprocessor.olympiad_bench import OlympiadBenchProcessor
from twinkle.reward.olympiad_bench import (
    OlympiadBenchAccuracyReward,
    OlympiadBenchFormatReward,
    OlympiadBenchQualityReward,
)
from twinkle.sampler import vLLMSampler

# Experiment tracking: every run logs to the 'twinkle' SwanLab project
# (import has the side effect of starting a run via swanlab.init below).
import swanlab
swanlab.init(
    project='twinkle',
)
logger = get_logger()

# All knobs below are overridable via environment variables so the same
# script can be launched with different configurations without edits.

# Model configuration
MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
# Env var is a string flag ('0'/'1'); int() then bool() converts it.
USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1')))

# GPU configuration: trainer and sampler get disjoint GPU pools.
MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4))
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS

# Training hyperparameters
NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8))  # completions per prompt (GRPO group size)
MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
LEARNING_RATE = float(os.environ.get('LR', 1e-5))
MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000))
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4))
MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 4))
MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 1))
# NOTE(review): read from the environment but not referenced elsewhere in this
# file — confirm whether it should be passed to the model setup.
GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
ADAPTER_NAME = 'default'
SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 50))

# Dataset configuration: the three OlympiadBench subsets mixed for training.
SUBSETS = [
    'OE_MM_maths_zh_CEE',
    'OE_MM_physics_zh_CEE',
    'OE_TO_maths_zh_CEE',
]
def create_olympiad_dataset():
    """Build a LazyDataset that mixes all OlympiadBench subsets (interleaved).

    Every subset comes from the same repo/split; each is mapped through the
    Chinese-language OlympiadBench preprocessor before mixing.
    """
    def _subset_meta(subset_name):
        # One DatasetMeta per subset; repo and split are fixed.
        return DatasetMeta(
            'ms://AI-ModelScope/OlympiadBench',
            subset_name=subset_name,
            split='train',
        )

    # Seed the lazy dataset with the first subset, then attach the rest.
    seed_meta = _subset_meta(SUBSETS[0])
    dataset = LazyDataset(seed_meta)
    dataset.map(OlympiadBenchProcessor(language='zh'), dataset_meta=seed_meta)

    for subset_name in SUBSETS[1:]:
        extra_meta = _subset_meta(subset_name)
        dataset.add_dataset(extra_meta)
        dataset.map(OlympiadBenchProcessor(language='zh'), dataset_meta=extra_meta)

    # Template for encoding, then interleave all attached subsets.
    dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=2048, enable_thinking=False)
    dataset.mix_dataset(interleave=True)
    return dataset
def compute_rewards(
    trajectories: List[Dict[str, Any]],
) -> Tuple[List[float], Dict[str, List[float]]]:
    """Score trajectories with the three OlympiadBench reward functions.

    Three core rewards, each normalized to [0, 1]:
    - Accuracy: answer correctness (weight: 2.0)
    - Format: answer formatting and consistency (weight: 1.0)
    - Quality: reasoning, length, repetition (weight: 1.0)

    Returns:
        totals: Weighted sum, normalized back into [0, 1].
        components: Per-component reward lists for logging.
    """
    acc_scores = OlympiadBenchAccuracyReward()(trajectories)
    fmt_scores = OlympiadBenchFormatReward()(trajectories)
    qual_scores = OlympiadBenchQualityReward()(trajectories)

    # Accuracy counts double; dividing by 4.0 (= 2 + 1 + 1) keeps the
    # blended reward inside [0, 1].
    totals: List[float] = []
    for acc, fmt, qual in zip(acc_scores, fmt_scores, qual_scores):
        totals.append((2.0 * acc + fmt + qual) / 4.0)

    components = {
        'accuracy': acc_scores,
        'format': fmt_scores,
        'quality': qual_scores,
    }
    return totals, components
def main():
    """Run GRPO training on OlympiadBench with a disaggregated model/sampler.

    Orchestration: initialize twinkle with separate GPU groups for the trainer
    and the vLLM sampler, attach a LoRA adapter, then loop
    sample -> reward -> advantage -> mini-batch optimize, logging to SwanLab.
    """
    # Device groups: model and sampler on separate GPUs
    device_groups = [
        DeviceGroup(name='model', ranks=MODEL_GPUS, device_type='GPU'),
        DeviceGroup(name='sampler', ranks=SAMPLER_GPUS, device_type='GPU'),
    ]

    # Pure data-parallel meshes on both sides (dp_size == world_size).
    model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
    sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS)
    twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False)

    # LoRA configuration shared by both backends.
    lora_config = LoraConfig(
        target_modules=['all-linear'],
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
    )

    # Model setup: Megatron backend by default, Transformers as fallback.
    if USE_MEGATRON:
        from twinkle.model.megatron import MegatronModel
        model = MegatronModel(
            model_id=MODEL_ID,
            device_mesh=model_mesh,
            remote_group='model',
        )
    else:
        from transformers import Qwen3_5ForConditionalGeneration
        model = TransformersModel(
            model_id=MODEL_ID,
            model_cls=Qwen3_5ForConditionalGeneration,
            device_mesh=model_mesh,
            remote_group='model',
        )

    model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1)

    # The two backends expose different optimizer/scheduler registries.
    if USE_MEGATRON:
        model.set_optimizer('default', lr=LEARNING_RATE, adapter_name=ADAPTER_NAME)
        model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, max_lr=LEARNING_RATE, adapter_name=ADAPTER_NAME)
    else:
        model.set_optimizer('AdamW', lr=LEARNING_RATE)
        model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0)

    model.set_loss('GRPOLoss', epsilon=0.2, adapter_name=ADAPTER_NAME)
    model.set_template('Qwen3_5Template', model_id=MODEL_ID, adapter_name=ADAPTER_NAME, enable_thinking=False)

    # Sampler setup: vLLM with LoRA enabled so synced adapter weights apply.
    sampler = vLLMSampler(
        model_id=MODEL_ID,
        engine_args={
            'gpu_memory_utilization': 0.8,
            'max_model_len': 32000,
            'max_lora_rank': 32,
            'enable_lora': True,
            'limit_mm_per_prompt': {'image': 9},  # OlympiadBench has up to 9 images
        },
        device_mesh=sampler_mesh,
        remote_group='sampler',
    )
    sampler.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False)

    # Checkpoint manager bridges trainer weights to the sampler.
    ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler)

    # DataLoader over the mixed OlympiadBench dataset (factory, not instance).
    GLOBAL_BATCH_SIZE = BATCH_SIZE
    dataloader = DataLoader(
        dataset=create_olympiad_dataset,
        batch_size=GLOBAL_BATCH_SIZE,
        min_batch_size=GLOBAL_BATCH_SIZE,
        device_mesh=model_mesh,
    )

    # RL components
    advantage_fn = GRPOAdvantage()
    metrics = CompletionRewardMetric()

    # logprobs=1 so old-policy log-probabilities are returned for GRPO.
    sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1, logprobs=1)

    optim_step = 0
    logger.info(f'Starting OlympiadBench GRPO training on subsets: {SUBSETS}')
    logger.info(get_device_placement())

    for batch in dataloader:
        if optim_step >= MAX_STEPS:
            break

        metrics.reset()

        # Sync latest (LoRA-only) weights to sampler before sampling.
        ckpt_manager.sync_weights(merge_and_sync=False)
        sampler.reset_prefix_cache()

        # Sample NUM_GENERATIONS completions per prompt by repeating the batch.
        sample_responses = sampler.sample(
            batch * NUM_GENERATIONS,
            sampling_params,
        )

        all_input_data: List[Dict[str, Any]] = []
        all_old_logps: List[List[float]] = []
        all_completion_lengths: List[int] = []

        # Flatten responses; keep old logps and lengths aligned with inputs.
        for sample_response in sample_responses:
            for sequence in sample_response.sequences:
                all_input_data.append(sequence.new_input_feature)
                # NOTE(review): assumes each logprobs entry is [(token, logp), ...]
                # and [0][1] is the sampled token's logp — confirm against
                # the sampler's Sequence schema.
                all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs])
                all_completion_lengths.append(len(sequence.tokens))

        # Compute rewards (total plus per-component breakdown for logging).
        total_rewards, reward_dict = compute_rewards(all_input_data)

        metrics.accumulate(
            completion_lengths=all_completion_lengths,
            rewards={
                'total': total_rewards,
                **{k: v for k, v in reward_dict.items()},
            },
        )

        # Group-normalized GRPO advantages, one scalar per completion.
        advantages = advantage_fn(total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist()

        # Mini-batch training: one optimizer step per mini-batch slice.
        total_completions = len(all_input_data)
        for mb_start in range(0, total_completions, MINI_BATCH_SIZE):
            mb_end = min(mb_start + MINI_BATCH_SIZE, total_completions)
            mb_inputs = all_input_data[mb_start:mb_end]
            mb_old_logps = all_old_logps[mb_start:mb_end]
            mb_advantages = advantages[mb_start:mb_end]

            model.forward_backward(
                inputs=mb_inputs,
                old_logps=mb_old_logps,
                advantages=mb_advantages,
                micro_batch_size=MICRO_BATCH_SIZE,
                adapter_name=ADAPTER_NAME,
            )
            model.clip_grad_and_step(adapter_name=ADAPTER_NAME)
            optim_step += 1

            if optim_step >= MAX_STEPS:
                break

            # NOTE(review): periodic checkpointing placed inside the
            # mini-batch loop (per optimizer step); indentation was
            # reconstructed from a whitespace-mangled source — confirm.
            if optim_step % SAVE_STEPS == 0:
                model.save(f'olympiad-grpo-mixed-checkpoint-{optim_step}', adapter_name=ADAPTER_NAME)

        # Per-batch logging: reward metrics plus model training metrics.
        log_dict = metrics.calculate()
        log_dict.update(model.calculate_metric(is_training=True, adapter_name=ADAPTER_NAME))
        metrics.reset()
        logger.info(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}')
        swanlab.log(log_dict)

    logger.info(f'Training completed. optim_steps={optim_step}')
    model.save('olympiad-grpo-mixed-final', adapter_name=ADAPTER_NAME)
# Script entry point: run the full GRPO training loop.
if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)