diff --git a/cookbook/rl/grpo.py b/cookbook/rl/grpo.py index 30d5d898..7fc3f2fd 100644 --- a/cookbook/rl/grpo.py +++ b/cookbook/rl/grpo.py @@ -70,6 +70,7 @@ def main(): twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False) # lora_config = LoraConfig(target_modules='all-linear', r=32, lora_alpha=64, lora_dropout=0.05) + # Since we are training on text-only data, we avoid using 'all-linear' which would include the ViT layers. lora_config = LoraConfig( target_modules=[ 'q_proj', 'k_proj', 'v_proj', 'o_proj', diff --git a/cookbook/rl/grpo_mm.py b/cookbook/rl/grpo_mm.py index 0705febb..9398ce8f 100644 --- a/cookbook/rl/grpo_mm.py +++ b/cookbook/rl/grpo_mm.py @@ -134,11 +134,7 @@ def main(): # LoRA configuration lora_config = LoraConfig( - target_modules=[ - 'q_proj', 'k_proj', 'v_proj', 'o_proj', - 'gate_proj', 'up_proj', 'down_proj', - 'in_proj_qkv', 'in_proj_z', 'in_proj_a', 'in_proj_b', 'out_proj', - ], + target_modules=['all-linear'], # 'all-linear' covers every linear layer, including the ViT and the Merger/Connector r=16, lora_alpha=32, lora_dropout=0.05, @@ -185,6 +181,7 @@ def main(): # lora will be merged into the base model and sync all weights to vLLM 'enable_lora': True, 'limit_mm_per_prompt': {'image': 9}, # OlympiadBench has up to 9 images + 'enable_tower_connector_lora': True, # enable ViT (tower) and Merger (connector) LoRA on the vLLM side }, device_mesh=sampler_mesh, remote_group='sampler', diff --git a/cookbook/rl/short_math_grpo.py b/cookbook/rl/short_math_grpo.py index 8f498923..bbfda68b 100644 --- a/cookbook/rl/short_math_grpo.py +++ b/cookbook/rl/short_math_grpo.py @@ -116,6 +116,7 @@ def main(): sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS) twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False) + # Since we are training on text-only data, we avoid using 'all-linear' which would include the ViT layers. 
lora_config = LoraConfig( target_modules=[ 'q_proj', 'k_proj', 'v_proj', 'o_proj', diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 4160f591..c134e41c 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -1417,11 +1417,11 @@ def _print_weight_example(names): def _add_base_layer_suffix(name): if name.endswith('.weight'): base_layer_name = f'{name[:-7]}.base_layer.weight' - if base_layer_name in model_keys or not model_keys: + if not model_keys or base_layer_name in model_keys: name = base_layer_name elif name.endswith('.bias'): base_layer_name = f'{name[:-5]}.base_layer.bias' - if base_layer_name in model_keys or not model_keys: + if not model_keys or base_layer_name in model_keys: name = base_layer_name return name