From 670422e94c14fdaa29ee08ee6e4a30d06c7fe468 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Tue, 7 Apr 2026 23:01:47 +0800
Subject: [PATCH 1/2] compat transformers 5.4.0

---
 src/mcore_bridge/bridge/gpt_bridge.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py
index 95b55de..2716993 100644
--- a/src/mcore_bridge/bridge/gpt_bridge.py
+++ b/src/mcore_bridge/bridge/gpt_bridge.py
@@ -645,6 +645,7 @@ def _set_moe_state(
         hf_prefix: str,
         layer_idx: int,
         to_mcore: bool,
+        is_mtp: bool = False,
     ):
         if to_mcore:
             hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
@@ -689,6 +690,7 @@ def _set_moe_state(
                         layer_idx,
                         to_mcore,
                         ep_rank=ep_rank,
+                        is_mtp=is_mtp,
                     ))
         if to_mcore:
             hf_state_dict = {}
@@ -696,8 +698,10 @@ def _set_moe_state(
             hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
         return hf_state_dict
 
-    def _get_hf_experts_attr(self):
+    def _get_hf_experts_attr(self, is_mtp: bool = False):
         # return hf_grouped, is_gate_up
+        if self.model_type == 'qwen3_5_moe' and not is_mtp:
+            return True, True
         if self.model_type in {'glm4v_moe', 'kimi_vl', 'qwen3_omni_moe'} or self.llm_model_type in {
                 'qwen2_moe', 'qwen3_moe', 'deepseek_v2', 'deepseek_v3', 'kimi_k2', 'dots1', 'ernie4_5_moe', 'glm4_moe',
                 'glm4_moe_lite', 'minimax_m2', 'olmoe', 'qwen3_next', 'qwen3_5_moe', 'glm_moe_dsa', 'deepseek_v32'
@@ -723,6 +727,7 @@ def _set_mlp_state(
         layer_idx: int,
         to_mcore: bool,
         ep_rank: Optional[int] = None,
+        is_mtp: bool = False,
     ):
         if to_mcore:
             hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
@@ -738,7 +743,7 @@ def _set_mlp_state(
             is_gate_up = any('gate_up_proj' in k for k in hf_state_dict.keys())
             # transformers 5.0 compatibility
         if not to_mcore and is_expert:
-            hf_grouped, is_gate_up = self._get_hf_experts_attr()
+            hf_grouped, is_gate_up = self._get_hf_experts_attr(is_mtp)
             need_transpose = False
             if hf_grouped:
                 need_transpose = self._get_need_transpose()
@@ -1401,7 +1406,7 @@ def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: boo
                              'input_layernorm.weight', to_mcore)
         return hf_state_dict
 
-    def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool):
+    def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool, is_mtp: bool = False):
         mg_mlp = None if mg_layer is None else mg_layer.mlp
         is_moe = True if hasattr(mg_mlp, 'experts') else False
         if not to_mcore:
@@ -1410,7 +1415,8 @@ def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool
             dist.all_reduce(is_moe, group=self.pp_group)
         if is_moe:
             hf_state_dict.update(
-                self._set_moe_state(mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore))
+                self._set_moe_state(
+                    mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore, is_mtp=is_mtp))
             self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict,
                                  'post_attention_layernorm.weight', to_mcore)
         else:
@@ -1608,7 +1614,7 @@ def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx:
 
             self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, 'shared_head.head.weight', to_mcore)
             hf_state_dict.update(self._set_layer_attn(transformer_layer, hf_state_dict, -1, to_mcore))
-            hf_state_dict.update(self._set_layer_mlp(transformer_layer, hf_state_dict, -1, to_mcore))
+            hf_state_dict.update(self._set_layer_mlp(transformer_layer, hf_state_dict, -1, to_mcore, is_mtp=True))
         if to_mcore:
             hf_state_dict = {}
         else:
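Note on patch 1/2: as the in-diff comment ("transformers 5.0 compatibility") suggests, newer transformers releases store the expert weights of some MoE models as one stacked 3-D tensor per projection (hf_grouped) with gate and up projections fused (is_gate_up), instead of one 2-D weight per expert. `_get_hf_experts_attr` reports which layout the HF checkpoint uses so the bridge can regroup and transpose accordingly, and for qwen3_5_moe the diff implies the MTP layer keeps the ungrouped layout, hence the `is_mtp` plumbing. Below is a minimal sketch of the two layouts; the key names and shapes are illustrative assumptions, not taken from the mcore_bridge source:

    import torch

    num_experts, hidden, ffn = 4, 8, 16

    # Grouped layout (hf_grouped=True, is_gate_up=True): one stacked tensor
    # for all experts, with gate and up projections fused on the last axis.
    grouped = {'mlp.experts.gate_up_proj': torch.randn(num_experts, hidden, 2 * ffn)}

    # Ungrouped layout: one 2-D weight per expert, as older checkpoints store it.
    ungrouped = {}
    for i, w in enumerate(grouped['mlp.experts.gate_up_proj'].unbind(0)):
        gate, up = w.chunk(2, dim=-1)
        ungrouped[f'mlp.experts.{i}.gate_proj.weight'] = gate.T.contiguous()
        ungrouped[f'mlp.experts.{i}.up_proj.weight'] = up.T.contiguous()

The transpose above mirrors what `_get_need_transpose` decides for grouped weights: a per-expert nn.Linear weight is [out, in], while the stacked tensor in this sketch is [experts, in, out].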
From a331ddc1c293d7892563c64541e0b4ac9c27ca22 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Wed, 8 Apr 2026 11:57:21 +0800
Subject: [PATCH 2/2] fix

---
 src/mcore_bridge/model/gpt_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py
index 55443b7..4c108fb 100644
--- a/src/mcore_bridge/model/gpt_model.py
+++ b/src/mcore_bridge/model/gpt_model.py
@@ -402,7 +402,7 @@ def _postprocess(
         if self.config.is_multimodal and self.config.context_parallel_size > 1:
             input_ids = split_cp_inputs(input_ids, getattr(packed_seq_params, 'cu_seqlens_q', None), 1)
 
-        if self.mtp_process:
+        if self.mtp_process and labels is not None:
             if self.config.is_multimodal:
                 embedding_ = (self.embedding, decoder_input)
             else:
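Note on patch 2/2: `_postprocess` previously entered the MTP branch whenever `mtp_process` was set, including label-free forward passes such as generation, where the multi-token-prediction heads have no targets to score. Guarding on `labels is not None` confines the MTP computation to training steps. A minimal sketch of the failure mode, using a hypothetical shift helper standing in for the real MTP label handling (not the actual GPTModel code):

    import torch

    def mtp_labels(labels: torch.Tensor, depth: int) -> torch.Tensor:
        # The depth-th MTP head predicts `depth` extra tokens ahead, so its
        # targets are the labels shifted left by `depth` positions.
        return torch.roll(labels, shifts=-depth, dims=-1)

    labels = torch.arange(8).unsqueeze(0)
    mtp_labels(labels, 1)   # fine during training, where labels exist
    # mtp_labels(None, 1)   # fails at inference: labels is None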