diff --git a/src/mcore_bridge/model/gpts/qwen3_next.py b/src/mcore_bridge/model/gpts/qwen3_next.py
index ec87ddb..bc8aee8 100644
--- a/src/mcore_bridge/model/gpts/qwen3_next.py
+++ b/src/mcore_bridge/model/gpts/qwen3_next.py
@@ -516,12 +516,13 @@ def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: boo
         self._set_state_dict(mg_layer, 'input_layernorm.weight', hf_state_dict, 'input_layernorm.weight', to_mcore)
         return hf_state_dict

-    def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool):
+    def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool, is_mtp: bool = False):
         if self.model_type != 'qwen3_5':
-            return super()._set_layer_mlp(mg_layer, hf_state_dict, layer_idx, to_mcore)
+            return super()._set_layer_mlp(mg_layer, hf_state_dict, layer_idx, to_mcore, is_mtp=is_mtp)
         # dense
         mg_mlp = None if mg_layer is None else mg_layer.mlp
-        hf_state_dict.update(self._set_mlp_state(mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore))
+        hf_state_dict.update(
+            self._set_mlp_state(mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore, is_mtp=is_mtp))
         self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict, 'post_attention_layernorm.weight', to_mcore)
         return hf_state_dict
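
For context, a minimal, self-contained sketch of the forwarding pattern the patch applies: the override must pass the new is_mtp keyword (presumably flagging a multi-token-prediction layer) through to super(), otherwise the base class only ever sees its default. All names below are made up for illustration; they are not the real mcore_bridge classes.

# Hypothetical stand-ins, not the actual qwen3_next.py code.
class _Base:
    def _set_layer_mlp(self, layer_idx: int, to_mcore: bool, is_mtp: bool = False):
        return {'layer_idx': layer_idx, 'is_mtp': is_mtp}

class _Qwen3NextLike(_Base):
    def _set_layer_mlp(self, layer_idx: int, to_mcore: bool, is_mtp: bool = False):
        # Without the explicit is_mtp=is_mtp forwarding, the base class would
        # always receive its default (False), so an MTP layer would be mapped
        # like an ordinary decoder layer.
        return super()._set_layer_mlp(layer_idx, to_mcore, is_mtp=is_mtp)

print(_Qwen3NextLike()._set_layer_mlp(0, True, is_mtp=True))  # {'layer_idx': 0, 'is_mtp': True}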