47 changes: 38 additions & 9 deletions src/twinkle/model/megatron/model/gpt_bridge.py
@@ -3,6 +3,8 @@
 
 import math
 import os
+import re
+import shutil
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
@@ -733,6 +735,15 @@ def _set_moe_state(
         hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
         return hf_state_dict
 
+    def _get_hf_grouped(self):
+        if self.args.hf_model_type in {
+                'qwen2_moe', 'qwen3_moe', 'deepseek_v2', 'deepseek_v3', 'dots1', 'ernie4_5_moe', 'glm4_moe',
+                'glm4_moe_lite', 'glm4v_moe', 'minimax_m2', 'olmoe', 'qwen3_next', 'kimi_vl', 'qwen3_omni_moe',
+                'qwen3_5_moe'
+        }:
+            return False, False
+        return None, None
+
     def _set_mlp_state(self,
                        mg_mlp,
                        hf_state_dict,
@@ -741,25 +752,43 @@ def _set_mlp_state(self,
                        to_mcore: bool,
                        ep_rank: Optional[int] = None,
                        hf_mlp=None):
-        if to_mcore:
-            hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
         if hf_mlp is None:
             hf_mlp = self._get_hf_mlp(layer_idx)
         is_expert = ep_rank is not None
         num_local_experts = 1
         hf_grouped = False
         args = self.args
         if is_expert:
-            hf_grouped = not hasattr(hf_mlp.experts, '__len__')
-            hf_mlp = hf_mlp.experts if hf_grouped else hf_mlp.experts[0]
+            hf_mlp = hf_mlp.experts
+            # to_mcore: infer hf_grouped from the hf_state_dict keys (grouped checkpoints
+            # have no per-expert "<idx>.down_proj" keys); to_hf: infer it from whether hf_mlp is a module list.
+            if to_mcore:
+                pattern = r'\d+\.down_proj'
+                hf_grouped = not any(re.match(pattern, k) is not None for k in hf_state_dict.keys())
+            else:
+                hf_grouped = not hasattr(hf_mlp, '__len__')
+            if hasattr(hf_mlp, '__len__'):
+                hf_mlp = hf_mlp[0]
             num_local_experts = args.num_experts // self.ep_size
-        # TODO: Temporary modification for transformers 5.0 compatibility with GLM4.6v, to be fixed later
-        is_gate_up = hasattr(hf_mlp, 'gate_up_proj')
-        if self.is_transformers_5 and self.args.hf_model_type in {'glm4v_moe', 'glm4_moe_lite'}:
-            hf_grouped = False
-            is_gate_up = False
+        if to_mcore or hf_grouped:
+            hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
+        if to_mcore:
+            is_gate_up = any('gate_up_proj' in k for k in hf_state_dict.keys())
+        else:
+            is_gate_up = hasattr(hf_mlp, 'gate_up_proj')
+        # transformers 5.0 compatibility
+        if self.is_transformers_5 and not to_mcore and is_expert:
+            _hf_grouped, _is_gate_up = self._get_hf_grouped()
+            if _hf_grouped is not None:
+                hf_grouped = _hf_grouped
+            if _is_gate_up is not None:
+                is_gate_up = _is_gate_up
 
+        if hf_grouped and not to_mcore:
+            hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
+        elif not to_mcore:
+            hf_state_dict = {}
 
         # linear_fc1
         if to_mcore:
             has_scale_inv = any('_scale_inv' in k for k in hf_state_dict.keys())
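
For reference, a minimal standalone sketch of the new to_mcore detection above. The key names are hypothetical stand-ins for an hf_state_dict whose expert prefix (hf_prefix) has already been stripped: per-expert checkpoints keep an expert index at the front of each key, while grouped (fused-expert) checkpoints do not.

import re

pattern = r'\d+\.down_proj'

# Hypothetical key sets, standing in for hf_state_dict after prefix removal.
per_expert = ['0.down_proj.weight', '1.down_proj.weight']  # one module per expert
grouped = ['down_proj', 'gate_up_proj']                    # fused expert tensors

hf_grouped = not any(re.match(pattern, k) is not None for k in per_expert)
print(hf_grouped)  # False: a leading expert index matches the pattern
hf_grouped = not any(re.match(pattern, k) is not None for k in grouped)
print(hf_grouped)  # True: no per-expert down_proj keys exist

Since re.match anchors at the start of the key, the expert index has to lead; a key such as 'experts.0.down_proj' would not match, which is why the prefix must be stripped before the check.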
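
Similarly, a sketch of how the _get_hf_grouped override composes with the detected values on the to_hf path; resolve_layout is a hypothetical wrapper for illustration, mirroring the four-line merge in _set_mlp_state.

def resolve_layout(detected_grouped, detected_gate_up, override):
    # override is the (hf_grouped, is_gate_up) pair from _get_hf_grouped();
    # None means "keep the detected value".
    hf_grouped, is_gate_up = override
    if hf_grouped is None:
        hf_grouped = detected_grouped
    if is_gate_up is None:
        is_gate_up = detected_gate_up
    return hf_grouped, is_gate_up

# A listed MoE type (e.g. qwen3_moe) returns (False, False) under transformers 5.0:
# force per-expert modules with separate gate/up projections.
print(resolve_layout(True, True, (False, False)))  # (False, False)
# An unlisted type returns (None, None): the detected values win.
print(resolve_layout(True, True, (None, None)))    # (True, True)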