@@ -516,12 +516,13 @@ def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: boo
516516 self ._set_state_dict (mg_layer , 'input_layernorm.weight' , hf_state_dict , 'input_layernorm.weight' , to_mcore )
517517 return hf_state_dict
518518
def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool, is_mtp: bool = False):
    """Convert one decoder layer's MLP weights between Megatron and HF layouts.

    For every model type except ``qwen3_5`` the work is delegated to the
    parent implementation unchanged. For ``qwen3_5`` (dense MLP) the MLP
    weights are mapped under ``hf_mlp_prefix`` and the Megatron
    ``pre_mlp_layernorm`` is paired with HF's ``post_attention_layernorm``.

    Args:
        mg_layer: Megatron layer module, or ``None`` when exporting to HF only.
        hf_state_dict: HF-side state dict being read from / written into.
        layer_idx: index of the decoder layer being converted.
        to_mcore: ``True`` when converting HF -> Megatron, ``False`` otherwise.
        is_mtp: forwarded to the MLP mapping helpers; presumably marks a
            multi-token-prediction layer — TODO confirm against the base class.

    Returns:
        The (possibly updated) ``hf_state_dict``.
    """
    # Non-qwen3_5 models: the generic base-class mapping already applies.
    if self.model_type != 'qwen3_5':
        return super()._set_layer_mlp(mg_layer, hf_state_dict, layer_idx, to_mcore, is_mtp=is_mtp)

    # qwen3_5 dense MLP: guard against a missing Megatron layer (export path).
    mg_mlp = mg_layer.mlp if mg_layer is not None else None
    mapped = self._set_mlp_state(mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore, is_mtp=is_mtp)
    hf_state_dict.update(mapped)

    # Megatron's pre-MLP norm corresponds to HF's post-attention layernorm.
    self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict, 'post_attention_layernorm.weight',
                         to_mcore)
    return hf_state_dict
0 commit comments