Skip to content

Commit d3317bb

Browse files
authored
[Models] Lfm2Moe: minor name changes for resolving lora conflicts (vllm-project#29063)
Signed-off-by: Paul Pak <paulpak58@gmail.com>
1 parent 8e61425 commit d3317bb

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

vllm/model_executor/models/lfm2.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def __init__(
248248
) -> None:
249249
super().__init__()
250250
self.layer_idx = layer_idx
251-
self.conv = ShortConv(
251+
self.short_conv = ShortConv(
252252
config=config,
253253
dim=config.conv_dim,
254254
layer_idx=layer_idx,
@@ -281,7 +281,7 @@ def forward(
281281
else:
282282
hidden_states, residual = self.operator_norm(hidden_states, residual)
283283
output = torch.empty_like(hidden_states)
284-
self.conv(
284+
self.short_conv(
285285
hidden_states,
286286
output,
287287
)
@@ -380,6 +380,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
380380
params_dict = dict(self.named_parameters())
381381
loaded_params: set[str] = set()
382382
for name, loaded_weight in weights:
383+
if ".conv." in name:
384+
name = name.replace(".conv.", ".short_conv.", 1)
385+
383386
for param_name, weight_name, shard_id in stacked_params_mapping:
384387
if weight_name not in name:
385388
continue
@@ -414,6 +417,7 @@ class Lfm2ForCausalLM(
414417
"w1",
415418
"w3",
416419
],
420+
"in_proj": ["in_proj"],
417421
}
418422

419423
# LoRA specific attributes

vllm/model_executor/models/lfm2_moe.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def __init__(
349349
) -> None:
350350
super().__init__()
351351
self.layer_idx = layer_idx
352-
self.conv = ShortConv(
352+
self.short_conv = ShortConv(
353353
config=config,
354354
dim=config.hidden_size,
355355
layer_idx=layer_idx,
@@ -388,7 +388,7 @@ def forward(
388388
else:
389389
hidden_states, residual = self.operator_norm(hidden_states, residual)
390390
output = torch.empty_like(hidden_states)
391-
self.conv(
391+
self.short_conv(
392392
hidden_states,
393393
output,
394394
)
@@ -509,6 +509,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
509509
if "expert_bias" in name:
510510
name = name.replace("expert_bias", "gate.e_score_correction_bias")
511511

512+
if ".conv." in name:
513+
name = name.replace(".conv.", ".short_conv.", 1)
514+
512515
for param_name, weight_name, shard_id in stacked_params_mapping:
513516
# Skip non-stacked layers and experts (experts handled below).
514517
if weight_name not in name:
@@ -595,6 +598,7 @@ class Lfm2MoeForCausalLM(
595598
"w1",
596599
"w3",
597600
],
601+
"in_proj": ["in_proj"],
598602
}
599603

600604
# LoRA specific attributes

0 commit comments

Comments (0)