
Commit a222914

support transformers multi-modal grpo (#131)
1 parent a89ede5 commit a222914

File tree

10 files changed: +495 -96 lines changed


cookbook/rl/mm_grpo.py

Lines changed: 235 additions & 0 deletions
@@ -0,0 +1,235 @@
"""Multimodal GRPO training demo with Qwen3.5 VL model on CLEVR dataset.

This script demonstrates on-policy GRPO (Group Relative Policy Optimization)
for visual question answering using:
- Model: Qwen3.5-2B (vision-language model)
- Dataset: AI-ModelScope/clevr_cogen_a_train (CLEVR visual reasoning)
- Rewards: accuracy (answer correctness) + format (<think>/<answer> tags)
- Template: Qwen3_5Template (handles vision token embedding merge)

Architecture:
- Separate GPU groups for training model and vLLM sampler (Ray mode)
- LoRA fine-tuning with NCCL weight sync between model and sampler
- GRPO loss with PPO-style clipping (epsilon=0.2)

Usage:
    python mm_grpo.py

Environment variables:
    MODEL_ID        : Model path (default: ms://Qwen/Qwen3.5-2B)
    MODEL_GPUS      : GPUs for training model (default: 2)
    SAMPLER_GPUS    : GPUs for vLLM sampler (default: 1)
    NUM_GENERATIONS : Completions per prompt for GRPO grouping (default: 4)
    MAX_NEW_TOKENS  : Max generation length (default: 4096)
    LR              : Learning rate (default: 5e-5)
    MAX_STEPS       : Total optimization steps (default: 200)
    BATCH_SIZE      : Global prompt-level batch size (default: 1)
    MINI_BATCH_SIZE : Global completion-level mini-batch size (default: 4)
    MICRO_BATCH_SIZE: Per-device micro-batch size (default: 1)
    GRADIENT_ACCUMULATION_STEPS: Gradient accumulation steps (default: 1)
    DATA_SLICE      : Number of dataset samples to use (default: 2000)
    SAVE_STEPS      : Checkpoint save interval (default: 50)
"""
import os
from typing import Any, Dict, List, Tuple

from peft import LoraConfig

import twinkle
from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger
from twinkle.advantage import GRPOAdvantage
from twinkle.checkpoint_engine import CheckpointEngineManager
from twinkle.data_format import SamplingParams
from twinkle.dataloader import DataLoader
from twinkle.dataset import DatasetMeta, LazyDataset
from twinkle.metric import CompletionRewardMetric
from twinkle.model import TransformersModel
from twinkle.preprocessor.mm import CLEVRProcessor
from twinkle.processor import InputProcessor
from twinkle.reward import FormatReward, MultiModalAccuracyReward
from twinkle.sampler import vLLMSampler
from twinkle.template import Qwen3_5Template

logger = get_logger()

MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-2B')

MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 2))
SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 1))
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS

NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 4))
MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
LEARNING_RATE = float(os.environ.get('LR', 5e-5))
MAX_STEPS = int(os.environ.get('MAX_STEPS', 200))
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 1))
MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 4))
MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 1))
GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
DATA_SLICE = int(os.environ.get('DATA_SLICE', 2000))
ADAPTER_NAME = 'default'
SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 50))


def create_clevr_dataset():
    dataset = LazyDataset(
        DatasetMeta('ms://AI-ModelScope/clevr_cogen_a_train', split='train',
                    data_slice=range(DATA_SLICE)),
    )
    dataset.cast_column('image', decode=False)
    dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=4096)
    dataset.map(CLEVRProcessor(), remove_columns=['image', 'problem', 'solution'])
    dataset.encode(add_generation_prompt=True)
    return dataset


def compute_rewards(
    trajectories: List[Dict[str, Any]],
) -> Tuple[List[float], List[float], List[float]]:
    accuracy_reward_fn = MultiModalAccuracyReward()
    format_reward_fn = FormatReward()
    accuracy_rewards = accuracy_reward_fn(trajectories)
    format_rewards = format_reward_fn(trajectories, trajectories)
    total_rewards = [a + f for a, f in zip(accuracy_rewards, format_rewards)]
    return total_rewards, format_rewards, accuracy_rewards


def main():
    device_groups = [
        DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
        DeviceGroup(
            name='sampler',
            ranks=list(range(MODEL_GPUS, NUM_GPUS)),
            device_type='GPU',
            gpus_per_worker=SAMPLER_GPUS,
        ),
    ]
    model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
    sampler_mesh = DeviceMesh.from_sizes(world_size=1, dp_size=1)
    twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False)

    lora_config = LoraConfig(
        target_modules=[
            'q_proj', 'k_proj', 'v_proj', 'o_proj',
            'gate_proj', 'up_proj', 'down_proj',
            'in_proj_qkv', 'in_proj_z', 'in_proj_a', 'in_proj_b', 'out_proj',
        ],
    )

    from modelscope import Qwen3_5ForConditionalGeneration
    model = TransformersModel(
        model_id=MODEL_ID,
        model_cls=Qwen3_5ForConditionalGeneration,
        device_mesh=model_mesh,
        remote_group='model',
    )

    model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
    model.set_optimizer('AdamW', lr=LEARNING_RATE)
    model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0)
    model.set_loss('GRPOLoss', epsilon=0.2)
    model.set_processor(InputProcessor)
    model.set_template('Qwen3_5Template', model_id=MODEL_ID)

    sampler = vLLMSampler(
        model_id=MODEL_ID,
        engine_args={
            'gpu_memory_utilization': 0.8,
            'max_model_len': 8192,
            'max_lora_rank': 8,
            'enable_lora': True,
            'limit_mm_per_prompt': {'image': 1, 'video': 0},
            'mm_processor_cache_gb': 0,
        },
        device_mesh=sampler_mesh,
        remote_group='sampler',
    )
    sampler.set_template(Qwen3_5Template, model_id=MODEL_ID)

    ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler)

    GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
    dataloader = DataLoader(
        dataset=create_clevr_dataset,
        batch_size=GLOBAL_BATCH_SIZE,
        min_batch_size=GLOBAL_BATCH_SIZE,
        device_mesh=model_mesh,
        remote_group='model',
    )
    advantage_fn = GRPOAdvantage()
    metrics = CompletionRewardMetric()

    sampling_params = SamplingParams(
        max_tokens=MAX_NEW_TOKENS,
        num_samples=1,
        logprobs=1,
        temperature=1.0,
    )

    optim_step = 0
    logger.info(get_device_placement())

    for batch in dataloader:
        if optim_step >= MAX_STEPS:
            break
        metrics.reset()
        global_prompts = batch if isinstance(batch, list) else [batch]

        ckpt_manager.sync_weights(merge_and_sync=False)
        sampler.reset_prefix_cache()
        sample_responses = sampler.sample(
            global_prompts * NUM_GENERATIONS,
            sampling_params,
        )

        all_input_data: List[Dict[str, Any]] = []
        all_old_logps: List[List[float]] = []
        all_completion_lengths: List[int] = []
        for sample_response in sample_responses:
            for sequence in sample_response.sequences:
                all_input_data.append(sequence.new_input_feature)
                all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs])
                all_completion_lengths.append(len(sequence.tokens))

        total_rewards, format_rewards, accuracy_rewards = compute_rewards(all_input_data)
        metrics.accumulate(
            completion_lengths=all_completion_lengths,
            rewards={
                'total': total_rewards,
                'format': format_rewards,
                'accuracy': accuracy_rewards,
            },
        )

        advantages = advantage_fn(total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist()

        total_completions = len(all_input_data)
        for mb_start in range(0, total_completions, MINI_BATCH_SIZE):
            mb_end = min(mb_start + MINI_BATCH_SIZE, total_completions)
            mb_inputs = all_input_data[mb_start:mb_end]
            mb_old_logps = all_old_logps[mb_start:mb_end]
            mb_advantages = advantages[mb_start:mb_end]

            model.forward_backward(
                inputs=mb_inputs,
                old_logps=mb_old_logps,
                advantages=mb_advantages,
                micro_batch_size=MICRO_BATCH_SIZE,
            )
            model.clip_grad_and_step()
            optim_step += 1

            if optim_step >= MAX_STEPS:
                break
            if optim_step % SAVE_STEPS == 0:
                model.save(f'mm-grpo-clevr-checkpoint-{optim_step}')
        log_dict = metrics.calculate()
        log_dict.update(model.calculate_metric(is_training=True))
        metrics.reset()
        logger.info(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}')

    logger.info(f'Training completed. optim_steps={optim_step}')
    model.save('mm-grpo-clevr-checkpoint')


if __name__ == '__main__':
    main()
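
For orientation, the two GRPO pieces the demo configures — GRPOAdvantage with scale='group' and GRPOLoss with epsilon=0.2 — correspond to group-normalized rewards and a PPO-style clipped surrogate. Below is a minimal PyTorch sketch of that math, illustrative only and not twinkle's actual implementation; it assumes rewards arrive ordered group by group and that per-sequence advantages have already been broadcast to tokens.

import torch


def group_relative_advantages(rewards: torch.Tensor, num_generations: int,
                              eps: float = 1e-4) -> torch.Tensor:
    # Reshape the flat reward list into (num_prompts, num_generations) groups,
    # then normalize each group to zero mean / unit std ("scale='group'").
    groups = rewards.view(-1, num_generations)
    mean = groups.mean(dim=-1, keepdim=True)
    std = groups.std(dim=-1, keepdim=True)
    return ((groups - mean) / (std + eps)).view(-1)


def clipped_grpo_loss(logps: torch.Tensor, old_logps: torch.Tensor,
                      advantages: torch.Tensor, epsilon: float = 0.2) -> torch.Tensor:
    # PPO-style clipped surrogate: ratio of the training policy to the
    # sampler's (old) policy, clipped to [1 - epsilon, 1 + epsilon],
    # pessimistic minimum of the two terms, negated for gradient descent.
    ratio = torch.exp(logps - old_logps)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages
    return -torch.min(unclipped, clipped).mean()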
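
As a worked example of how the batch knobs nest under the defaults: each dataloader step yields BATCH_SIZE = 1 prompt, which sampling expands into NUM_GENERATIONS = 4 completions; MINI_BATCH_SIZE = 4 therefore gives exactly one forward_backward/clip_grad_and_step pair per sampled batch, and MICRO_BATCH_SIZE = 1 splits that mini-batch into four sequential per-device forward/backward passes. Raising NUM_GENERATIONS to 8 with the same mini-batch size would instead yield two optimizer steps per round of sampling.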

src/twinkle/dataset/base.py

Lines changed: 13 additions & 0 deletions
@@ -168,6 +168,19 @@ def _load_dataset(dataset_meta: DatasetMeta, **kwargs):
         dataset = dataset.select(iter_list)
         return dataset
 
+    @remote_function()
+    def cast_column(self, column: str, decode: bool = True) -> None:
+        """Cast an image/audio column's decode mode.
+
+        Useful for setting ``decode=False`` before ``.map()`` to keep media
+        as raw bytes and avoid expensive PIL encode/decode round-trips.
+        """
+        from datasets import Image as ImageFeature
+        for key in list(self.datasets.keys()):
+            self.datasets[key] = self.datasets[key].cast_column(column, ImageFeature(decode=decode))
+        if len(self.datasets) == 1:
+            self.dataset = self.datasets[next(iter(self.datasets.keys()))]
+
     @remote_function()
     def map(self,
             preprocess_func: Union[Preprocessor, Callable, str, Type[Preprocessor]],
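
A quick sketch of what the new hook does under the hood — the datasets.Image(decode=False) call is the real Hugging Face API; the dataset id is a placeholder:

from datasets import Image, load_dataset

ds = load_dataset('some/image-dataset', split='train')   # placeholder id
ds = ds.cast_column('image', Image(decode=False))        # keep raw bytes
print(type(ds[0]['image']))   # dict with 'bytes'/'path', not PIL.Image.Image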

src/twinkle/dataset/lazy_dataset.py

Lines changed: 6 additions & 1 deletion
@@ -22,6 +22,7 @@ def encode(self, **kwargs):
         assert self.template.truncation_strategy != 'split', ('Lazy tokenize does not support '
                                                               'truncation_strategy==`split`')
         self.do_encode = True
+        self.encode_kwargs = kwargs
 
     @remote_function()
     def check(self, **kwargs):
@@ -33,7 +34,11 @@ def __getitem__(self, idx):
         item = self.dataset[idx]
         # may raise errors
         if self.do_encode:
-            item = self.template.batch_encode([item])[0]
+            encoded = self.template.batch_encode([item], **self.encode_kwargs)[0]
+            for key in item:
+                if key not in encoded:
+                    encoded[key] = item[key]
+            item = encoded
         elif self.do_check:
             item = self.template.check(item)
         return item
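
The two hunks cooperate: encode() now remembers its kwargs so the lazy __getitem__ can forward them to batch_encode (this is what lets the cookbook call dataset.encode(add_generation_prompt=True)), and any keys the template does not emit — such as the user_data carrying ground_truth — now survive encoding. A standalone sketch of the merge behavior, with made-up values:

item = {'messages': ['...'], 'user_data': [('ground_truth', '3')]}
encoded = {'input_ids': [1, 2, 3], 'messages': ['...']}   # what batch_encode returns
for key in item:
    if key not in encoded:
        encoded[key] = item[key]
assert encoded['user_data'] == [('ground_truth', '3')]    # carried through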

src/twinkle/preprocessor/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,3 +3,4 @@
 from .dpo import EmojiDPOProcessor
 from .llm import (AlpacaProcessor, CompetitionMathGRPOProcessor, CompetitionMathProcessor, CountdownProcessor,
                   GSM8KProcessor, SelfCognitionProcessor)
+from .mm import CLEVRProcessor, VisionQAProcessor

src/twinkle/preprocessor/mm.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import re
from typing import Any, Dict, List, Optional

from twinkle.data_format import Message, Trajectory
from .base import Preprocessor


class CLEVRProcessor(Preprocessor):
    """Preprocessor for CLEVR-CoGenT visual reasoning dataset (prompt-only, for GRPO).

    Dataset fields: image (PIL.Image or dict), problem (str), solution (str with <answer> tags)
    Produces prompt-only trajectories with image in the user message and
    ground truth stored in user_data for reward computation.

    For fast ``.map()`` performance, call ``dataset.cast_column('image', decode=False)``
    before mapping so that images stay as Arrow-native bytes dicts.
    """

    DEFAULT_SYSTEM = ('A conversation between User and Assistant. The user asks a question, '
                      'and the Assistant solves it. The assistant first thinks about the reasoning '
                      'process in the mind and then provides the user with the answer. The reasoning '
                      'process and answer are enclosed within <think> </think> and <answer> </answer> '
                      'tags, respectively, i.e., <think> reasoning process here </think>'
                      '<answer> answer here </answer>')

    def __init__(self, system: Optional[str] = None):
        self.system = system if system is not None else self.DEFAULT_SYSTEM

    @staticmethod
    def extract_ground_truth(solution: str) -> str:
        """Extract answer text from <answer>...</answer> tags."""
        match = re.search(r'<answer>\s*(.*?)\s*</answer>', solution, re.DOTALL)
        return match.group(1).strip() if match else solution.strip()

    def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        rows = self.map_col_to_row(rows)
        rows = [self.preprocess(row) for row in rows]
        rows = self.map_row_to_col(rows)
        return rows

    def preprocess(self, row) -> Trajectory:
        image = row['image']
        problem = row['problem']
        solution = row.get('solution', '')
        ground_truth = self.extract_ground_truth(solution)

        messages = [
            Message(role='system', content=[{
                'type': 'text',
                'text': self.system
            }]),
            Message(role='user', content=[
                {
                    'type': 'image',
                    'image': image
                },
                {
                    'type': 'text',
                    'text': problem
                },
            ]),
        ]
        return Trajectory(
            messages=messages,
            user_data=[('ground_truth', ground_truth), ('solution', solution)],
        )
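
The extract_ground_truth regex is whitespace-tolerant and falls back to the stripped input when no tags are present; for example:

solution = '<think>Count the cubes: 2 red + 1 blue.</think>\n<answer> 3 </answer>'
assert CLEVRProcessor.extract_ground_truth(solution) == '3'
assert CLEVRProcessor.extract_ground_truth('  7  ') == '7'   # no tags: stripped fallback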

src/twinkle/reward/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,3 +3,4 @@
 from .format_reward import FormatReward
 from .gsm8k import GSM8KAccuracyReward, GSM8KFormatReward
 from .math_reward import MathReward
+from .mm_reward import MultiModalAccuracyReward
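
mm_grpo.py sums this reward with FormatReward. As a rough mental model only — the sketch below is a hypothetical stand-in, not twinkle's MultiModalAccuracyReward, and the trajectory field names are assumptions — an accuracy reward of this shape compares the extracted <answer> against the stored ground truth:

import re
from typing import Any, Dict, List


def toy_accuracy_reward(trajectories: List[Dict[str, Any]]) -> List[float]:
    # Hypothetical: 1.0 when the completion's <answer> matches ground_truth.
    rewards = []
    for traj in trajectories:
        completion = traj.get('completion', '')                      # assumed field
        truth = dict(traj.get('user_data', [])).get('ground_truth', '')
        match = re.search(r'<answer>\s*(.*?)\s*</answer>', completion, re.DOTALL)
        predicted = match.group(1).strip() if match else completion.strip()
        rewards.append(1.0 if predicted == truth else 0.0)
    return rewards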
