Skip to content

Commit 9b4d0f0

Browse files
authored
support rl vit lora with vLLM (#147)
1 parent 37dc52b commit 9b4d0f0

File tree

4 files changed

+6
-7
lines changed

4 files changed

+6
-7
lines changed

cookbook/rl/grpo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def main():
7070
twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False)
7171

7272
# lora_config = LoraConfig(target_modules='all-linear', r=32, lora_alpha=64, lora_dropout=0.05)
73+
# Since we are training on text-only data, we avoid using 'all-linear' which would include the ViT layers.
7374
lora_config = LoraConfig(
7475
target_modules=[
7576
'q_proj', 'k_proj', 'v_proj', 'o_proj',

cookbook/rl/grpo_mm.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,7 @@ def main():
134134

135135
# LoRA configuration
136136
lora_config = LoraConfig(
137-
target_modules=[
138-
'q_proj', 'k_proj', 'v_proj', 'o_proj',
139-
'gate_proj', 'up_proj', 'down_proj',
140-
'in_proj_qkv', 'in_proj_z', 'in_proj_a', 'in_proj_b', 'out_proj',
141-
],
137+
target_modules=['all-linear'], # including ViT and Merger/Connector
142138
r=16,
143139
lora_alpha=32,
144140
lora_dropout=0.05,
@@ -185,6 +181,7 @@ def main():
185181
# lora will be merged into the base model and sync all weights to vLLM
186182
'enable_lora': True,
187183
'limit_mm_per_prompt': {'image': 9}, # OlympiadBench has up to 9 images
184+
'enable_tower_connector_lora': True, # enable ViT(tower) and Merger(connector) LoRA on vLLM side
188185
},
189186
device_mesh=sampler_mesh,
190187
remote_group='sampler',

cookbook/rl/short_math_grpo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def main():
116116
sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS)
117117
twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False)
118118

119+
# Since we are training on text-only data, we avoid using 'all-linear' which would include the ViT layers.
119120
lora_config = LoraConfig(
120121
target_modules=[
121122
'q_proj', 'k_proj', 'v_proj', 'o_proj',

src/twinkle/model/megatron/megatron.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,11 +1417,11 @@ def _print_weight_example(names):
14171417
def _add_base_layer_suffix(name):
14181418
if name.endswith('.weight'):
14191419
base_layer_name = f'{name[:-7]}.base_layer.weight'
1420-
if base_layer_name in model_keys or not model_keys:
1420+
if not model_keys or base_layer_name in model_keys:
14211421
name = base_layer_name
14221422
elif name.endswith('.bias'):
14231423
base_layer_name = f'{name[:-5]}.base_layer.bias'
1424-
if base_layer_name in model_keys or not model_keys:
1424+
if not model_keys or base_layer_name in model_keys:
14251425
name = base_layer_name
14261426
return name
14271427

0 commit comments

Comments (0)