Skip to content

Commit 95a8a6b

Browse files
authored
[bugfix] Fix the multi-LoRA issue in Twinkle (#24)
1 parent 4636e98 commit 95a8a6b

1 file changed

Lines changed: 29 additions & 22 deletions

File tree

src/mcore_bridge/bridge/gpt_bridge.py

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -267,11 +267,16 @@ def _set_module(self, mg_module, hf_state_dict, hf_prefix: str, to_mcore: bool):
267267
new_state_dict = {}
268268
for k, v in hf_state_dict.items():
269269
if self._peft_format:
270-
if '.lora_A.' in k or '.lora_B.' in k or '.modules_to_save.' in k:
271-
k = k.replace(f'{self._adapter_name}.', '')
270+
# Match 'lora_A.' without a leading '.' (i.e., not '.lora_A.') so keys still
271+
# match when mg_module itself is a linear layer (such as proj1) and the key has no prefix.
272+
if ('lora_A.' in k or 'lora_B.' in k
273+
or 'modules_to_save.' in k) and f'.{self._adapter_name}.' in k:
274+
k = k.replace(f'.{self._adapter_name}.', '.')
272275
new_state_dict[k] = v
273276
else:
274-
if '.lora_A.' in k or '.lora_B.' in k or 'original_module.' in k:
277+
if 'lora_A.' in k or 'lora_B.' in k or 'original_module.' in k:
278+
continue
279+
if 'modules_to_save.' in k and f'modules_to_save.{self._adapter_name}.' not in k:
275280
continue
276281
k = k.replace('base_layer.', '')
277282
k = k.replace(f'modules_to_save.{self._adapter_name}.', '')
@@ -1324,24 +1329,26 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
13241329
hf_state_dict['in_proj_b.weight_scale_inv'] = scale_inv[qkv_block + z_block:-a_block].clone()
13251330
hf_state_dict['in_proj_a.weight_scale_inv'] = scale_inv[-a_block:].clone()
13261331
del in_proj_weight
1327-
if to_mcore:
1328-
conv1d = hf_state_dict['conv1d.weight'].load()
1329-
q_c, k_c, v_c = torch.split(conv1d, [key_dim, key_dim, value_dim], dim=0)
1330-
conv1d = torch.cat([
1331-
*(x.reshape(num_key_heads, -1, *conv1d.shape[-2:]) for x in [q_c, k_c, v_c]),
1332-
], dim=1).reshape((-1, *conv1d.shape[-2:]))
1333-
self._set_weight(mg_attn.conv1d.weight, conv1d, 'conv1d.weight')
1334-
else:
1335-
conv1d, _ = self._get_weight(None if mg_attn is None else mg_attn.conv1d.weight, 'conv1d.weight')
1336-
if conv1d is not None:
1337-
conv1d = conv1d.reshape(num_key_heads, -1, *conv1d.shape[-2:])
1338-
q_c, k_c, v_c = torch.split(
1339-
conv1d, [key_dim // num_key_heads, key_dim // num_key_heads, value_dim // num_key_heads], dim=1)
1340-
q_c = q_c.reshape(-1, *q_c.shape[-2:])
1341-
k_c = k_c.reshape(-1, *k_c.shape[-2:])
1342-
v_c = v_c.reshape(-1, *v_c.shape[-2:])
1343-
conv1d = torch.concat([q_c, k_c, v_c], dim=0)
1344-
hf_state_dict['conv1d.weight'] = conv1d
1332+
if not self._peft_format:
1333+
if to_mcore:
1334+
conv1d = hf_state_dict['conv1d.weight'].load()
1335+
q_c, k_c, v_c = torch.split(conv1d, [key_dim, key_dim, value_dim], dim=0)
1336+
conv1d = torch.cat([
1337+
*(x.reshape(num_key_heads, -1, *conv1d.shape[-2:]) for x in [q_c, k_c, v_c]),
1338+
],
1339+
dim=1).reshape((-1, *conv1d.shape[-2:]))
1340+
self._set_weight(mg_attn.conv1d.weight, conv1d, 'conv1d.weight')
1341+
else:
1342+
conv1d, _ = self._get_weight(None if mg_attn is None else mg_attn.conv1d.weight, 'conv1d.weight')
1343+
if conv1d is not None:
1344+
conv1d = conv1d.reshape(num_key_heads, -1, *conv1d.shape[-2:])
1345+
q_c, k_c, v_c = torch.split(
1346+
conv1d, [key_dim // num_key_heads, key_dim // num_key_heads, value_dim // num_key_heads], dim=1)
1347+
q_c = q_c.reshape(-1, *q_c.shape[-2:])
1348+
k_c = k_c.reshape(-1, *k_c.shape[-2:])
1349+
v_c = v_c.reshape(-1, *v_c.shape[-2:])
1350+
conv1d = torch.concat([q_c, k_c, v_c], dim=0)
1351+
hf_state_dict['conv1d.weight'] = conv1d
13451352
self._set_state_dict(mg_attn, 'dt_bias', hf_state_dict, 'dt_bias', to_mcore)
13461353
self._set_state_dict(mg_attn, 'A_log', hf_state_dict, 'A_log', to_mcore)
13471354
self._set_state_dict(mg_attn, 'out_norm.weight', hf_state_dict, 'norm.weight', to_mcore)
@@ -1703,7 +1710,7 @@ def export_weights(
17031710
self.config = mg_models[0].config
17041711
with torch.no_grad():
17051712
for k, v in self._convert(mg_models, {}, hf_prefix, False, tqdm_desc=tqdm_desc):
1706-
if converter:
1713+
if converter and v is not None:
17071714
kv = converter(k, v)
17081715
if kv is None:
17091716
continue

0 commit comments

Comments
 (0)