@@ -267,11 +267,16 @@ def _set_module(self, mg_module, hf_state_dict, hf_prefix: str, to_mcore: bool):
267267 new_state_dict = {}
268268 for k , v in hf_state_dict .items ():
269269 if self ._peft_format :
270- if '.lora_A.' in k or '.lora_B.' in k or '.modules_to_save.' in k :
271- k = k .replace (f'{ self ._adapter_name } .' , '' )
270+ # Match without a leading '.' (i.e., 'lora_A.' rather than '.lora_A.')
271+ # so keys still match when mg_module itself is a linear layer (e.g., proj1) and its keys carry no module prefix.
272+ if ('lora_A.' in k or 'lora_B.' in k
273+ or 'modules_to_save.' in k ) and f'.{ self ._adapter_name } .' in k :
274+ k = k .replace (f'.{ self ._adapter_name } .' , '.' )
272275 new_state_dict [k ] = v
273276 else :
274- if '.lora_A.' in k or '.lora_B.' in k or 'original_module.' in k :
277+ if 'lora_A.' in k or 'lora_B.' in k or 'original_module.' in k :
278+ continue
279+ if 'modules_to_save.' in k and f'modules_to_save.{ self ._adapter_name } .' not in k :
275280 continue
276281 k = k .replace ('base_layer.' , '' )
277282 k = k .replace (f'modules_to_save.{ self ._adapter_name } .' , '' )
@@ -1324,24 +1329,26 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
13241329 hf_state_dict ['in_proj_b.weight_scale_inv' ] = scale_inv [qkv_block + z_block :- a_block ].clone ()
13251330 hf_state_dict ['in_proj_a.weight_scale_inv' ] = scale_inv [- a_block :].clone ()
13261331 del in_proj_weight
1327- if to_mcore :
1328- conv1d = hf_state_dict ['conv1d.weight' ].load ()
1329- q_c , k_c , v_c = torch .split (conv1d , [key_dim , key_dim , value_dim ], dim = 0 )
1330- conv1d = torch .cat ([
1331- * (x .reshape (num_key_heads , - 1 , * conv1d .shape [- 2 :]) for x in [q_c , k_c , v_c ]),
1332- ], dim = 1 ).reshape ((- 1 , * conv1d .shape [- 2 :]))
1333- self ._set_weight (mg_attn .conv1d .weight , conv1d , 'conv1d.weight' )
1334- else :
1335- conv1d , _ = self ._get_weight (None if mg_attn is None else mg_attn .conv1d .weight , 'conv1d.weight' )
1336- if conv1d is not None :
1337- conv1d = conv1d .reshape (num_key_heads , - 1 , * conv1d .shape [- 2 :])
1338- q_c , k_c , v_c = torch .split (
1339- conv1d , [key_dim // num_key_heads , key_dim // num_key_heads , value_dim // num_key_heads ], dim = 1 )
1340- q_c = q_c .reshape (- 1 , * q_c .shape [- 2 :])
1341- k_c = k_c .reshape (- 1 , * k_c .shape [- 2 :])
1342- v_c = v_c .reshape (- 1 , * v_c .shape [- 2 :])
1343- conv1d = torch .concat ([q_c , k_c , v_c ], dim = 0 )
1344- hf_state_dict ['conv1d.weight' ] = conv1d
1332+ if not self ._peft_format :
1333+ if to_mcore :
1334+ conv1d = hf_state_dict ['conv1d.weight' ].load ()
1335+ q_c , k_c , v_c = torch .split (conv1d , [key_dim , key_dim , value_dim ], dim = 0 )
1336+ conv1d = torch .cat ([
1337+ * (x .reshape (num_key_heads , - 1 , * conv1d .shape [- 2 :]) for x in [q_c , k_c , v_c ]),
1338+ ],
1339+ dim = 1 ).reshape ((- 1 , * conv1d .shape [- 2 :]))
1340+ self ._set_weight (mg_attn .conv1d .weight , conv1d , 'conv1d.weight' )
1341+ else :
1342+ conv1d , _ = self ._get_weight (None if mg_attn is None else mg_attn .conv1d .weight , 'conv1d.weight' )
1343+ if conv1d is not None :
1344+ conv1d = conv1d .reshape (num_key_heads , - 1 , * conv1d .shape [- 2 :])
1345+ q_c , k_c , v_c = torch .split (
1346+ conv1d , [key_dim // num_key_heads , key_dim // num_key_heads , value_dim // num_key_heads ], dim = 1 )
1347+ q_c = q_c .reshape (- 1 , * q_c .shape [- 2 :])
1348+ k_c = k_c .reshape (- 1 , * k_c .shape [- 2 :])
1349+ v_c = v_c .reshape (- 1 , * v_c .shape [- 2 :])
1350+ conv1d = torch .concat ([q_c , k_c , v_c ], dim = 0 )
1351+ hf_state_dict ['conv1d.weight' ] = conv1d
13451352 self ._set_state_dict (mg_attn , 'dt_bias' , hf_state_dict , 'dt_bias' , to_mcore )
13461353 self ._set_state_dict (mg_attn , 'A_log' , hf_state_dict , 'A_log' , to_mcore )
13471354 self ._set_state_dict (mg_attn , 'out_norm.weight' , hf_state_dict , 'norm.weight' , to_mcore )
@@ -1703,7 +1710,7 @@ def export_weights(
17031710 self .config = mg_models [0 ].config
17041711 with torch .no_grad ():
17051712 for k , v in self ._convert (mg_models , {}, hf_prefix , False , tqdm_desc = tqdm_desc ):
1706- if converter :
1713+ if converter and v is not None :
17071714 kv = converter (k , v )
17081715 if kv is None :
17091716 continue
0 commit comments