diff --git a/src/twinkle/model/megatron/model/gpt_bridge.py b/src/twinkle/model/megatron/model/gpt_bridge.py
index a4b59c9c..d4f076bf 100644
--- a/src/twinkle/model/megatron/model/gpt_bridge.py
+++ b/src/twinkle/model/megatron/model/gpt_bridge.py
@@ -3,6 +3,8 @@
 import math
 import os
+import re
+import shutil
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
@@ -733,6 +735,15 @@ def _set_moe_state(
         hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
         return hf_state_dict
 
+    def _get_hf_grouped(self):
+        if self.args.hf_model_type in {
+                'qwen2_moe', 'qwen3_moe', 'deepseek_v2', 'deepseek_v3', 'dots1', 'ernie4_5_moe', 'glm4_moe',
+                'glm4_moe_lite', 'glm4v_moe', 'minimax_m2', 'olmoe', 'qwen3_next', 'kimi_vl', 'qwen3_omni_moe',
+                'qwen3_5_moe'
+        }:
+            return False, False
+        return None, None
+
     def _set_mlp_state(self,
                        mg_mlp,
                        hf_state_dict,
@@ -741,6 +752,8 @@ def _set_mlp_state(self,
                        to_mcore: bool,
                        ep_rank: Optional[int] = None,
                        hf_mlp=None):
+        if to_mcore:
+            hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
         if hf_mlp is None:
             hf_mlp = self._get_hf_mlp(layer_idx)
         is_expert = ep_rank is not None
@@ -748,18 +761,34 @@ def _set_mlp_state(self,
         hf_grouped = False
         args = self.args
         if is_expert:
-            hf_grouped = not hasattr(hf_mlp.experts, '__len__')
-            hf_mlp = hf_mlp.experts if hf_grouped else hf_mlp.experts[0]
+            hf_mlp = hf_mlp.experts
+            # When converting to_mcore, hf_grouped is inferred from the hf_state_dict keys;
+            # when converting to_hf, it is inferred from the structure of hf_mlp.
+            if to_mcore:
+                pattern = r'\d+\.down_proj'
+                hf_grouped = not any(re.match(pattern, k) is not None for k in hf_state_dict.keys())
+            else:
+                hf_grouped = not hasattr(hf_mlp, '__len__')
+            if hasattr(hf_mlp, '__len__'):
+                hf_mlp = hf_mlp[0]
             num_local_experts = args.num_experts // self.ep_size
-        # TODO: Temporary modification for transformers 5.0 compatibility with GLM4.6v, to be fixed later
-        is_gate_up = hasattr(hf_mlp, 'gate_up_proj')
-        if self.is_transformers_5 and self.args.hf_model_type in {'glm4v_moe', 'glm4_moe_lite'}:
-            hf_grouped = False
-            is_gate_up = False
-        if to_mcore or hf_grouped:
-            hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
+        if to_mcore:
+            is_gate_up = any('gate_up_proj' in k for k in hf_state_dict.keys())
         else:
+            is_gate_up = hasattr(hf_mlp, 'gate_up_proj')
+        # transformers 5.0 compatibility
+        if self.is_transformers_5 and not to_mcore and is_expert:
+            _hf_grouped, _is_gate_up = self._get_hf_grouped()
+            if _hf_grouped is not None:
+                hf_grouped = _hf_grouped
+            if _is_gate_up is not None:
+                is_gate_up = _is_gate_up
+
+        if hf_grouped and not to_mcore:
+            hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
+        elif not to_mcore:
             hf_state_dict = {}
 
+        # linear_fc1
         if to_mcore:
             has_scale_inv = any('_scale_inv' in k for k in hf_state_dict.keys())
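
For reference, a minimal sketch (not part of the patch; the function name and the state-dict keys below are hypothetical) of how the to_mcore branch infers hf_grouped from the HF state-dict keys once hf_prefix has been stripped:

import re

def infer_hf_grouped(hf_state_dict):
    # Per-expert checkpoints expose keys such as "0.down_proj.weight", "1.down_proj.weight", ...
    # A grouped-expert checkpoint has no "<index>.down_proj" keys, so hf_grouped ends up True.
    pattern = r'\d+\.down_proj'
    return not any(re.match(pattern, k) is not None for k in hf_state_dict.keys())

print(infer_hf_grouped({'0.down_proj.weight': None, '0.gate_up_proj.weight': None}))  # False: per-expert layout
print(infer_hf_grouped({'down_proj': None, 'gate_up_proj': None}))                    # True: grouped layout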