Skip to content

Commit 196aa45

Browse files
authored
update Qwen3.5 grpo demo (#124)
1 parent 502354a commit 196aa45

File tree

6 files changed

+69
-87
lines changed

6 files changed

+69
-87
lines changed

cookbook/rl/grpo.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
logger = get_logger()
2222

2323
MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
24-
USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '0')))
24+
USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1')))
2525

2626
MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
2727
SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS',4))
@@ -31,15 +31,16 @@
3131
MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
3232
LEARNING_RATE = float(os.environ.get('LR', 1e-5))
3333
MAX_STEPS = int(os.environ.get('MAX_STEPS', 200))
34-
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 16)) # global prompt-level, global completion-level batch size = BATCH_SIZE * num_generations * dp_size
35-
MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 16)) # global completion-level mini-batch-size
34+
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # global prompt-level, global completion-level batch size = BATCH_SIZE * num_generations * dp_size
35+
MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 8)) # global completion-level mini-batch-size
3636
MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) # per-device-micro-batch-size (completion-level), batch_size in forward_backward
3737
GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
3838
ADAPTER_NAME = 'default'
39+
SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 50))
3940

4041
def create_gsm8k_dataset():
4142
dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train'))
42-
dataset.set_template('Template', model_id=MODEL_ID, max_length=2048)
43+
dataset.set_template('Template', model_id=MODEL_ID, max_length=400)
4344
dataset.map(GSM8KProcessor())
4445
dataset.encode(add_generation_prompt=True)
4546
return dataset
@@ -68,13 +69,21 @@ def main():
6869
sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS)
6970
twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False)
7071

71-
lora_config = LoraConfig(target_modules='all-linear', r=32, lora_alpha=64, lora_dropout=0.05)
72-
72+
# lora_config = LoraConfig(target_modules='all-linear', r=32, lora_alpha=64, lora_dropout=0.05)
73+
lora_config = LoraConfig(
74+
target_modules=[
75+
'q_proj', 'k_proj', 'v_proj', 'o_proj',
76+
'gate_proj', 'up_proj', 'down_proj',
77+
'in_proj_qkv', 'in_proj_z', 'in_proj_a', 'in_proj_b', 'out_proj',
78+
],
79+
r=32, lora_alpha=64, lora_dropout=0.05,
80+
)
7381
if USE_MEGATRON:
7482
from twinkle.model.megatron import MegatronModel
7583
model = MegatronModel(model_id=MODEL_ID, device_mesh=model_mesh, remote_group='model', mixed_precision='bf16')
7684
else:
77-
model = TransformersModel(model_id=MODEL_ID, device_mesh=model_mesh, remote_group='model')
85+
from transformers import Qwen3_5ForConditionalGeneration
86+
model = TransformersModel(model_id=MODEL_ID, model_cls=Qwen3_5ForConditionalGeneration, device_mesh=model_mesh, remote_group='model')
7887

7988
model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1)
8089
if USE_MEGATRON:
@@ -91,8 +100,9 @@ def main():
91100
model_id=MODEL_ID,
92101
engine_args={
93102
'gpu_memory_utilization': 0.8,
94-
'max_model_len': 4096,
103+
'max_model_len': 4496,
95104
'max_lora_rank': 32, # same as lora_config
105+
# NOTE: To use enable_lora with qwen3.5, ensure vLLM includes PR https://github.com/vllm-project/vllm/pull/36976
96106
'enable_lora': True,
97107
},
98108
device_mesh=sampler_mesh,
@@ -172,6 +182,8 @@ def main():
172182

173183
if optim_step >= MAX_STEPS:
174184
break
185+
if optim_step % SAVE_STEPS == 0:
186+
model.save(f'grpo-gsm8k-checkpoint-{optim_step}')
175187
log_dict = metrics.calculate()
176188
log_dict.update(model.calculate_metric(is_training=True))
177189
metrics.reset()

src/twinkle/model/megatron/megatron.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1587,7 +1587,7 @@ def _trim_vocab(name, tensor):
15871587

15881588
if base_sync_done and adapter_name:
15891589
if merge_and_sync:
1590-
1590+
# LoRA training; sync the full model (merge_adapter)
15911591
def weight_generator():
15921592
for _model in self.strategy.unwrap_model(self.model):
15931593
if isinstance(_model, PeftModel):
@@ -1616,7 +1616,7 @@ def weight_generator():
16161616
yield name, tensor
16171617

16181618
else:
1619-
1619+
# First full base-model sync.
16201620
def _raw_weights():
16211621
for name, tensor in self.get_hf_state_dict(adapter_name=''):
16221622
if name is None or tensor is None:
@@ -1627,7 +1627,7 @@ def _raw_weights():
16271627
yield _trim_vocab(name, tensor)
16281628

16291629
def weight_generator():
1630-
if is_peft_format:
1630+
if is_peft_format and not merge_and_sync:
16311631
yield from _add_base_layer_suffix(_raw_weights())
16321632
else:
16331633
yield from _raw_weights()

src/twinkle/model/transformers/transformers.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,21 +1159,28 @@ def send_weights(
11591159
# Get state dict from unwrapped model
11601160
model = self.strategy.unwrap_model(self.model)
11611161

1162+
def _normalize(name: str, keep_base_layer: bool) -> str:
1163+
name = name.replace('base_model.model.', '')
1164+
if not keep_base_layer:
1165+
name = name.replace('.base_layer', '')
1166+
return name
1167+
1168+
def _is_lora_key(name: str) -> bool:
1169+
return 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name
1170+
11621171
if base_sync_done and adapter_name:
11631172
if merge_and_sync:
1164-
1173+
# LoRA training; sync the full model (merge_adapter)
1174+
# merge and skip LoRA weights (already merged)
1175+
# trim prefix(base_model.model.) and suffix(.base_layer)
11651176
def weight_generator():
11661177
if isinstance(model, PeftModel):
11671178
model.merge_adapter()
11681179
for name, tensor in model.state_dict().items():
1169-
# Skip LoRA-specific weights for base model sync
1170-
if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name:
1180+
if _is_lora_key(name):
11711181
continue
11721182
tensor = Torch.to_local_tensor(tensor)
1173-
# Keep original names (including .base_layer for PEFT models).
1174-
# The sampler side will strip .base_layer based on whether
1175-
# vLLM has enable_lora=True/False.
1176-
yield name, tensor
1183+
yield _normalize(name, keep_base_layer=False), tensor
11771184
if isinstance(model, PeftModel):
11781185
model.unmerge_adapter()
11791186
else:
@@ -1188,19 +1195,19 @@ def weight_generator():
11881195
yield name, tensor
11891196

11901197
else:
1191-
# Full model mode: send all weights (base model sync).
1198+
# First full base-model sync. Whether to keep ``.base_layer.``
1199+
# depends on whether the sampler uses ``enable_lora``:
1200+
# merge_and_sync=True → enable_lora=False → strip .base_layer
1201+
# merge_and_sync=False → enable_lora=True → keep .base_layer
1202+
keep_base_layer = not merge_and_sync
11921203
state_dict = model.state_dict()
11931204

11941205
def weight_generator():
11951206
for name, tensor in state_dict.items():
1196-
# Skip LoRA-specific weights for base model sync
1197-
if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name:
1207+
if _is_lora_key(name):
11981208
continue
11991209
tensor = Torch.to_local_tensor(tensor)
1200-
# Keep original names (including .base_layer for PEFT models).
1201-
# The sampler side will strip .base_layer based on whether
1202-
# vLLM has enable_lora=True/False.
1203-
yield name, tensor
1210+
yield _normalize(name, keep_base_layer=keep_base_layer), tensor
12041211

12051212
# Run async send_weights in a dedicated event loop thread.
12061213
# We cannot use the Ray worker's event loop because it may already

src/twinkle/preprocessor/llm.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,8 @@ class GSM8KProcessor(Preprocessor):
122122
Extracts the ground truth number and stores it in user_data for reward.
123123
Only includes system + user messages; assistant response is generated on-policy.
124124
"""
125-
system_prompt = ('You are a helpful math assistant. Solve the problem step by step. '
126-
'Show your reasoning in <think> </think> tags, then give the final '
127-
'numerical answer after ####.\n'
128-
'For example:\n<think> ... reasoning ... </think>\n#### 42')
125+
system_prompt = ('You are a helpful math assistant. Solve the problem step by step '
126+
'and put your final answer within \\boxed{}.')
129127

130128
def __init__(self, system=None, add_assistant=False):
131129
self.system = system

src/twinkle/reward/gsm8k.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,17 @@
77
class GSM8KAccuracyReward(Reward):
88
"""Accuracy reward for GSM8K: checks if the model's answer matches ground truth.
99
10-
Extracts the last '#### <number>' from model output and compares with ground truth.
10+
Extracts the answer from \\boxed{} (preferred) or #### format.
1111
Returns 1.0 for correct, 0.0 for incorrect.
1212
"""
1313

1414
@staticmethod
1515
def extract_answer(completion: str) -> str:
16-
"""Extract the last #### answer from model completion."""
17-
# Only check last 500 chars for efficiency
16+
"""Extract the answer from model completion, preferring \\boxed{} over ####."""
1817
text = completion[-500:] if len(completion) > 500 else completion
18+
boxed = re.findall(r'\\boxed\{([^}]+)\}', text)
19+
if boxed:
20+
return boxed[-1].replace(',', '').replace(' ', '').strip()
1921
matches = re.findall(r'####\s*([\-\d,\.\s]+)', text)
2022
if matches:
2123
return matches[-1].replace(',', '').replace(' ', '').strip()
@@ -54,9 +56,9 @@ def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]:
5456

5557

5658
class GSM8KFormatReward(Reward):
57-
"""Format reward: checks if output contains <think>...</think> tag.
59+
"""Format reward: checks if output contains \\boxed{} or #### answer format.
5860
59-
Returns 1.0 if format is correct, 0.0 otherwise.
61+
Returns 1.0 if a valid answer format is present, 0.0 otherwise.
6062
"""
6163

6264
def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]:
@@ -68,7 +70,6 @@ def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]:
6870
if msg.get('role') == 'assistant':
6971
completion = msg.get('content', '')
7072
break
71-
has_think = bool(re.search(r'<think>.*?</think>', completion, re.DOTALL))
72-
has_answer = bool(re.search(r'####\s*[\-\d,\.]+', completion))
73-
rewards.append(1.0 if (has_think and has_answer) else 0.0)
73+
has_answer = bool(re.search(r'\\boxed\{[^}]+\}', completion) or re.search(r'####\s*[\-\d,\.]+', completion))
74+
rewards.append(1.0 if has_answer else 0.0)
7475
return rewards

src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py

Lines changed: 14 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -390,11 +390,17 @@ def _load_weights(
390390
"""Load a batch of weights into vLLM.
391391
392392
Two modes:
393-
- LoRA mode (``peft_config`` and ``base_sync_done``): Loads weights as
394-
a tensor-based LoRA adapter via ``add_lora()``.
395-
- Base model mode: Strips PEFT prefixes, merges split weights
396-
(q/k/v_proj -> qkv_proj, gate/up_proj -> gate_up_proj) into vLLM's
397-
stacked format, normalizes prefixes, then loads via direct param copy.
393+
394+
* **LoRA mode** (``peft_config`` set and ``base_sync_done=True``):
395+
loads weights as a tensor-based LoRA adapter via ``add_lora()``.
396+
* **Base model mode** (all other cases): delegates to
397+
``model.load_weights()`` which handles stacked-parameter merging
398+
(q/k/v → qkv, gate/up → gate_up) and prefix mapping internally.
399+
400+
Weight names are expected to arrive **already normalised** by the
401+
sender (``TransformersModel.send_weights`` /
402+
``MegatronModel.send_weights``), so no name transformation is done
403+
here.
398404
"""
399405
if peft_config and base_sync_done:
400406
# Remove existing LoRA before replacing
@@ -412,51 +418,9 @@ def _load_weights(
412418
)
413419
self.add_lora(lora_request)
414420
else:
415-
# Base model mode — strip PEFT prefixes and delegate to
416-
# vLLM's model.load_weights() which handles stacked params,
417-
# prefix normalization, and weight_loader internally.
418-
vllm_has_lora = getattr(
419-
getattr(self, 'vllm_config', None),
420-
'lora_config',
421-
None,
422-
) is not None
423-
424-
# When vLLM LoRA is enabled, some LinearBase modules are
425-
# replaced by *WithLoRA wrappers. Their parameters shift
426-
# from e.g. ``gate.weight`` to ``gate.base_layer.weight``.
427-
# HF checkpoint names do NOT contain ``.base_layer.``, so
428-
# vLLM's own ``load_weights`` will KeyError on them.
429-
#
430-
# Build a set of base-layer prefixes that need rewriting.
431-
lora_base_prefixes: set = set()
432-
if vllm_has_lora:
433-
from vllm.lora.layers import BaseLayerWithLoRA
434-
for mod_name, mod in self.model_runner.model.named_modules():
435-
if isinstance(mod, BaseLayerWithLoRA):
436-
# mod_name is e.g. "model.layers.0.mlp.gate"
437-
lora_base_prefixes.add(mod_name + '.')
438-
439-
converted = []
440-
for name, tensor in weights:
441-
if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name:
442-
continue
443-
name = name.removeprefix('model.base_model.model.')
444-
name = name.removeprefix('base_model.model.')
445-
if not vllm_has_lora:
446-
name = name.replace('.base_layer.', '.')
447-
else:
448-
# Insert ``.base_layer.`` for weights whose module
449-
# has been wrapped by LoRA and whose name does NOT
450-
# already contain it.
451-
if '.base_layer.' not in name:
452-
for pfx in lora_base_prefixes:
453-
if name.startswith(pfx):
454-
# e.g. "model.layers.0.mlp.gate.weight"
455-
# → "model.layers.0.mlp.gate.base_layer.weight"
456-
suffix = name[len(pfx):]
457-
name = pfx + 'base_layer.' + suffix
458-
break
459-
converted.append((name, tensor))
421+
# Base model mode — weights arrive in canonical HF format
422+
converted = [(n, t) for n, t in weights
423+
if 'lora_A' not in n and 'lora_B' not in n and 'lora_embedding' not in n]
460424

461425
if not converted:
462426
return

0 commit comments

Comments
 (0)