From f9b938de3e6fdd1b3374b8c3144a8deb559f8a5a Mon Sep 17 00:00:00 2001
From: xiaotaoliu
Date: Fri, 17 Oct 2025 11:08:25 +0800
Subject: [PATCH] feat: qwen3next support

feat: qwen3 next
fix: qwen3next gated conv weight split tp
feat: qwen3next mcore to hf
remove useless log
---
 .../qwen3_next/cmp_origin_and_reconverted.py |  90 ++++++
 example/qwen3_next/load_model_and_forward.py | 274 ++++++++++++++++++
 example/qwen3_next/offload_hf_fwd.py         |  54 ++++
 example/qwen3_next/run.sh                    |  12 +
 mbridge/core/bridge.py                       | 242 +++++++++++++++-
 mbridge/core/safetensor_io.py                |  10 +-
 mbridge/models/__init__.py                   |   1 +
 mbridge/models/qwen3_next.py                 |  83 ++++++
 mbridge/utils/post_creation_callbacks.py     |   2 +-
 9 files changed, 750 insertions(+), 18 deletions(-)
 create mode 100644 example/qwen3_next/cmp_origin_and_reconverted.py
 create mode 100644 example/qwen3_next/load_model_and_forward.py
 create mode 100644 example/qwen3_next/offload_hf_fwd.py
 create mode 100644 example/qwen3_next/run.sh
 create mode 100644 mbridge/models/qwen3_next.py

diff --git a/example/qwen3_next/cmp_origin_and_reconverted.py b/example/qwen3_next/cmp_origin_and_reconverted.py
new file mode 100644
index 0000000..bc5c448
--- /dev/null
+++ b/example/qwen3_next/cmp_origin_and_reconverted.py
@@ -0,0 +1,90 @@
+import torch
+import os
+from transformers import AutoModelForCausalLM, AutoConfig
+import argparse
+
+def compare_checkpoints(path1, path2, tolerance=1e-6):
+    print(f"loading: {path1}")
+    model1 = AutoModelForCausalLM.from_pretrained(path1,
+                                                  dtype="auto",
+                                                  device_map="auto")
+    print(f"loading: {path2}")
+    model2 = AutoModelForCausalLM.from_pretrained(path2,
+                                                  dtype="auto",
+                                                  device_map="auto")
+
+    state_dict1 = model1.state_dict()
+    state_dict2 = model2.state_dict()
+
+    print(f"number of params: {len(state_dict1)} {len(state_dict2)}")
+
+    keys1 = set(state_dict1.keys())
+    keys2 = set(state_dict2.keys())
+
+    if keys1 != keys2:
+        print("❌ param names do not match")
+        only_in_1 = keys1 - keys2
+        only_in_2 = keys2 - keys1
+        if only_in_1:
+            print(f"params only in model1: {only_in_1}")
+        if only_in_2:
+            print(f"params only in model2: {only_in_2}")
+        return False
+
+    print("✅ param names match")
+
+    all_match = True
+    mismatch_count = 0
+
+    error_key = []
+
+    for key in state_dict1.keys():
+        param1 = state_dict1[key]
+        param2 = state_dict2[key]
+
+        if param1.shape != param2.shape:
+            print(f"❌ {key}: shape mismatch {param1.shape} vs {param2.shape}")
+            all_match = False
+            mismatch_count += 1
+            error_key.append(key)
+            continue
+
+        if torch.allclose(param1, param2, atol=tolerance):
+            print(f"✅ {key} shape {param1.shape}, weights are equal")
+        else:
+            # compute difference statistics
+            diff = torch.abs(param1 - param2)
+            max_diff = torch.max(diff).item()
+            mean_diff = torch.mean(diff).item()
+            print(f"❌ {key}: shape {param1.shape}, max diff: {max_diff:.6f}, avg diff: {mean_diff:.6f} {param1.sum()} {param2.sum()}")
+            all_match = False
+            mismatch_count += 1
+            error_key.append(key)
+
+    if all_match:
+        print("all params match")
+    else:
+        print(f"{mismatch_count} params mismatched {error_key=}")
+
+    return all_match
+
+def main():
+    parser = argparse.ArgumentParser(description='compare two checkpoints')
+    parser.add_argument('--path1', type=str, required=True, help='first checkpoint')
+    parser.add_argument('--path2', type=str, required=True, help='second checkpoint')
+    parser.add_argument('--tolerance', type=float, default=1e-6, help='absolute tolerance for torch.allclose')
+
+    args = parser.parse_args()
+
+    if not os.path.exists(args.path1):
+        print(f"error: path does not exist: {args.path1}")
+        return
+
+    if not os.path.exists(args.path2):
+        print(f"error: path does not exist: 
{args.path2}") + return + + compare_checkpoints(args.path1, args.path2, args.tolerance) + +if __name__ == "__main__": + main() diff --git a/example/qwen3_next/load_model_and_forward.py b/example/qwen3_next/load_model_and_forward.py new file mode 100644 index 0000000..00631f6 --- /dev/null +++ b/example/qwen3_next/load_model_and_forward.py @@ -0,0 +1,274 @@ +# Example to use tp/pp/cp/vpp to test dense model +# torchrun --nproc_per_node=8 load_model_and_export.py --model_path /path/to/model + + +import argparse +import json +import os +from typing import List +import requests + +import torch +import torch.nn.functional as F +from transformers import AutoModelForCausalLM, AutoTokenizer + +from megatron.core import parallel_state +from megatron.core import parallel_state as mpu +from megatron.core.models.gpt.gpt_model import ModelType +from megatron.core.inference.inference_request import InferenceRequest +from megatron.core.pipeline_parallel.schedules import get_forward_backward_func +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.tensor_parallel.mappings import ( + gather_from_tensor_model_parallel_region, +) + +from mbridge import AutoBridge +from mbridge.utils.post_creation_callbacks import freeze_moe_router + + +# hf logits vs megatron logits +def cos_similarity(a, b): + print(f"a {a.shape} b {b.shape}") + a = a.to(b.device) + a = a.float() + # a = a / a.norm(dim=-1, keepdim=True) + a = torch.exp(a) + a = a / a.norm(dim=-1, keepdim=True) + """ + a = (a - a.mean(dim=-1, keepdim=True)) + a = a / a.norm(dim=-1, keepdim=True) + """ + b = b.float() + # b = b / b.norm(dim=-1, keepdim=True) + b = torch.exp(b) + b = b / b.norm(dim=-1, keepdim=True) + """ + b = (b - b.mean(dim=-1, keepdim=True)) + b = b / b.norm(dim=-1, keepdim=True) + """ + sim = (a * b).sum(dim=-1) + print( + f"hf vs megatron cos_similarity min: {sim.min()}; max: {sim.max()}; mean: {sim.mean()}" + ) + + +def get_ltor_masks_and_position_ids(data, + eod_token, + pad_token, + reset_position_ids, + reset_attention_mask, + eod_mask_loss, + pad_mask_loss): + """Build masks and position id for left to right model.""" + + # Extract batch size and sequence length. + micro_batch_size, seq_length = data.size() + + # Attention mask (lower triangular). + if reset_attention_mask: + att_mask_batch = micro_batch_size + else: + att_mask_batch = 1 + attention_mask = torch.tril( + torch.ones((att_mask_batch, seq_length, seq_length), device=data.device) + ).view(att_mask_batch, 1, seq_length, seq_length) + + # Loss mask. + loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) + if eod_mask_loss: + loss_mask[data == eod_token] = 0.0 + if pad_mask_loss: + loss_mask[data == pad_token] = 0.0 + + # Position ids. + position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device) + position_ids = position_ids.unsqueeze(0).expand_as(data) + # We need to clone as the ids will be modifed based on batch index. + if reset_position_ids: + position_ids = position_ids.clone() + + if reset_position_ids or reset_attention_mask: + # Loop through the batches: + for b in range(micro_batch_size): + + # Find indecies where EOD token is. + eod_index = position_ids[b, data[b] == eod_token] & position_ids[b, data[b] == pad_token] + # Detach indecies from positions if going to modify positions. + if reset_position_ids: + eod_index = eod_index.clone() + + # Loop through EOD indecies: + prev_index = 0 + for j in range(eod_index.size()[0]): + i = eod_index[j] + # Mask attention loss. 
+ if reset_attention_mask: + attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0 + # Reset positions. + if reset_position_ids: + position_ids[b, (i + 1) :] -= i + 1 - prev_index + prev_index = i + 1 + + # Convert attention mask to binary: + attention_mask = attention_mask < 0.5 + + return attention_mask, loss_mask, position_ids + + +def is_first_rank(): + """First tensor and pipeline parallel rank.""" + return ( + parallel_state.is_pipeline_first_stage(ignore_virtual=True) + and parallel_state.get_tensor_model_parallel_rank() == 0 + ) + + +def init_distributed(tp=2, pp=1, cp=1, vpp=1, ep=1, etp=None): + """Initialize distributed environment""" + torch.distributed.init_process_group("nccl") + torch.cuda.set_device(torch.distributed.get_rank()) + if pp <= 1: + vpp = None + mpu.initialize_model_parallel( + tensor_model_parallel_size=tp, + pipeline_model_parallel_size=pp, + virtual_pipeline_model_parallel_size=vpp, + context_parallel_size=cp, + expert_model_parallel_size=ep, + expert_tensor_parallel_size=etp, + ) + model_parallel_cuda_manual_seed(0) + + +def get_args(): + parser = argparse.ArgumentParser(description="Load model and generate text") + parser.add_argument( + "--model_path", type=str, required=True, help="HuggingFace model path" + ) + parser.add_argument("--tp", type=int, default=2, help="Tensor model parallel size") + parser.add_argument( + "--pp", type=int, default=1, help="Pipeline model parallel size" + ) + parser.add_argument("--cp", type=int, default=1, help="Context parallel size") + parser.add_argument( + "--vpp", type=int, default=None, help="Virtual pipeline model parallel size" + ) + parser.add_argument("--ep", type=int, default=1, help="Expert model parallel size") + parser.add_argument( + "--etp", type=int, default=None, help="Expert tensor parallel size" + ) + parser.add_argument( + "--save_path", type=str, default=None, help="Path to save weights" + ) + args = parser.parse_args() + return args + + +def main(args): + # Parse command line arguments + # Initialize distributed environment + init_distributed( + tp=args.tp, + pp=args.pp, + cp=args.cp, + vpp=args.vpp, + ep=args.ep, + etp=args.etp, + ) + + # Load megatron model + hf_model_path = args.model_path + print(f"rank{torch.distributed.get_rank()}: {args=} start loading model ...") + bridge = AutoBridge.from_pretrained(hf_model_path) + bridge.config.sequence_parallel = True if args.tp > 1 else False + model = bridge.get_model() + # if torch.distributed.get_rank() == 0: + # print(f"Model arch {model} len {len(model)}") + + # torch.distributed.barrier() + bridge.load_weights(model, hf_model_path, memory_efficient=True) + print(f"rank{torch.distributed.get_rank()}: end load weight, start forward ...") + torch.distributed.barrier() + for pname, params in model[0].named_parameters(): + if torch.distributed.get_rank() == torch.distributed.get_world_size() - 1: + print(f"Trace export_weights {pname=} shape {params.shape=} dtype {params.dtype=} {params.sum()}") + + tokenizer = AutoTokenizer.from_pretrained(hf_model_path) + prompt = "李白,字太白,号" + messages = [ + {"role": "user", "content": prompt}, + ] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + input_ids = tokenizer([text], return_tensors="pt")["input_ids"] + + attn_mask, _, pids = get_ltor_masks_and_position_ids( + input_ids, None, tokenizer.pad_token_id, False, False, False, True + ) + sample_list = [{"input_ids": input_ids, "attention_mask": attn_mask, "position_ids": pids}] + print(f"model input {input_ids.shape=} 
{input_ids=} {attn_mask.shape=} {pids.shape=}" + f" {attn_mask=} {pids=}") + + with torch.no_grad(): + fwd_bwd_function = get_forward_backward_func() + real_seq_length = input_ids.shape[-1] + seq_length = real_seq_length + if real_seq_length % args.tp != 0: + seq_length = (real_seq_length + args.tp - 1) // args.tp * args.tp + sample_list[0]["input_ids"] = F.pad( + sample_list[0]["input_ids"], + (0, seq_length - real_seq_length, 0, 0), + value=0, + ) + + def mcore_fwd_fn(data_iter, model): + sample = next(data_iter) + + output_tensor = model( + input_ids=sample['input_ids'].cuda(), + position_ids=sample['position_ids'].cuda(), + attention_mask=sample['attention_mask'].cuda(), + ) + if isinstance(output_tensor, tuple): + output_tensor = output_tensor[0] + assert isinstance(output_tensor, torch.Tensor) + def loss_func(output_tensor, non_loss_data=True): + loss = output_tensor.mean() + return loss, { + "loss": loss.detach(), + "logits": output_tensor.detach(), + } + return output_tensor, loss_func + + mcore_output = fwd_bwd_function( + forward_step_func=mcore_fwd_fn, + data_iterator=iter(sample_list), + model=model, + num_microbatches=1, + forward_only=True, + seq_length=seq_length, + decoder_seq_length=seq_length, + micro_batch_size=1, + ) + + if mpu.is_pipeline_last_stage(): + megatron_output = mcore_output[0]["logits"] + if mpu.get_tensor_model_parallel_world_size() > 1: + megatron_output = gather_from_tensor_model_parallel_region( + megatron_output + ) + megatron_output = megatron_output[:, :real_seq_length, :] + torch.save(megatron_output, f"./megatron_qwen3next_tp{args.tp}.pt") + + # hf_output = torch.load("./hf_qwen3next.pt") + # cos_similarity(hf_output, megatron_output) + + torch.distributed.barrier() + torch.distributed.destroy_process_group() + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/example/qwen3_next/offload_hf_fwd.py b/example/qwen3_next/offload_hf_fwd.py new file mode 100644 index 0000000..cec642d --- /dev/null +++ b/example/qwen3_next/offload_hf_fwd.py @@ -0,0 +1,54 @@ +# Example to use tp/pp/cp/vpp to test dense model +# python3 offload_hf_fwd.py --model_path /path/to/model + + +import argparse + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + + +def get_args(): + parser = argparse.ArgumentParser(description="Load model and generate text") + parser.add_argument( + "--model_path", type=str, required=True, help="HuggingFace model path" + ) + parser.add_argument( + "--save_path", type=str, default=None, help="Path to save weights" + ) + args = parser.parse_args() + return args + + +def load_hf_model_and_forward(args): + hf_model_path = args.model_path + hf_model = AutoModelForCausalLM.from_pretrained( + hf_model_path, + dtype="auto", + device_map="auto",) + + tokenizer = AutoTokenizer.from_pretrained(hf_model_path) + prompt = "李白,字太白,号" + messages = [ + {"role": "user", "content": prompt}, + ] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + for pname, params in hf_model.named_parameters(): + print(f"Trace export_weights name={pname}: {params.shape=} {params.dtype=} {params.sum()}") + + model_inputs = tokenizer([text], return_tensors="pt").to(hf_model.device) + + with torch.no_grad(): + hf_output = hf_model( + **model_inputs) + print(f"rank hf_output {hf_output.logits}") + torch.save(hf_output.logits, "./hf_qwen3next.pt") + + +if __name__ == "__main__": + args = get_args() + load_hf_model_and_forward(args) diff --git a/example/qwen3_next/run.sh 
b/example/qwen3_next/run.sh new file mode 100644 index 0000000..a5adbfb --- /dev/null +++ b/example/qwen3_next/run.sh @@ -0,0 +1,12 @@ + +MCORE="../3rdparty/Megatron-LM/" +export PYTHONPATH="$PWD:$MCORE:$PYTHONPATH" + +# model_path="/home/hf-hub/Qwen/Qwen3-Next-80B-A3B-Instruct/" +model_path="/home/hf_reconverted_release" +torchrun --nproc_per_node=8 example/qwen3_next/load_model_and_forward.py \ + --model_path $model_path \ + --tp 2 \ + --pp 2 \ + --ep 4 \ + --etp 1 \ No newline at end of file diff --git a/mbridge/core/bridge.py b/mbridge/core/bridge.py index 5a034c7..62372d1 100644 --- a/mbridge/core/bridge.py +++ b/mbridge/core/bridge.py @@ -143,8 +143,8 @@ def get_model( self.load_weights(model, self._get_actual_hf_path(weight_path)) return model - def _get_safetensor_io(self, weights_path: str): - return SafeTensorIO(self._get_actual_hf_path(weights_path)) + def _get_safetensor_io(self, weights_path: str, ignore_mtp: bool = False): + return SafeTensorIO(self._get_actual_hf_path(weights_path), ignore_mtp=ignore_mtp) def _get_mcore_config_by_name(self, mcore_weights_name: str): return self.config @@ -163,6 +163,7 @@ def load_weights( weights_path: Path to the weights file or Hugging Face model identifier """ self.safetensor_io = self._get_safetensor_io(weights_path) + self_attn_output_gate = getattr(self.config, "attention_output_gate", False) for i, model in enumerate(models): # map local weight names to global weight names @@ -173,6 +174,7 @@ def load_weights( for k, v in local_to_global_map.items() if "_extra_state" not in k } + # only tp_rank0/etp_rank0 load from disk, others load from tp_rank0/etp_rank0 to_load_from_disk = [] for local_name, hf_names in local_to_hf_map.items(): @@ -205,7 +207,8 @@ def load_weights( hf_weights = [ self.safetensor_io.load_one_hf_weight(x) for x in hf_names ] - mcore_weight = self._weight_to_mcore_format(local_name, hf_weights) + mcore_weight = self._weight_to_mcore_format(local_name, hf_weights, + self_attn_output_gate=self_attn_output_gate) else: mcore_weight = None if hf_names[0] in {"lm_head.weight", "model.embed_tokens.weight"}: @@ -214,6 +217,10 @@ def load_weights( ): # skip lm_head.weight when the model is a value model continue + + mcore_weight_shape = None + if mcore_weight is not None: + mcore_weight_shape = mcore_weight.shape param_to_load = torch.empty_like(param) if ".mlp.experts.linear_fc" in local_name: @@ -351,6 +358,8 @@ def export_weights( ), f"should be empty {self.export_weights_buff=}" models = [unwrap_model(model) for model in models] + self_attn_output_gate = getattr(self.config, "attention_output_gate", False) + def get_model_chunk_generator(): for model in models: existing_keys = set() @@ -457,7 +466,7 @@ def get_model_chunk_generator(): name, params, broad_pp_param ) converted_names, converted_params = self._weight_to_hf_format( - name, merge_params + name, merge_params, self_attn_output_gate=self_attn_output_gate ) # Some moe models require multiple weights to be merge into one, such as qwen3vl if len(converted_names) == 0: @@ -489,7 +498,7 @@ def get_model_chunk_generator(): infer_params = broad_pp_param converted_names, converted_params = self._weight_to_hf_format( - name, infer_params + name, infer_params, self_attn_output_gate=self_attn_output_gate ) # Some moe models require multiple weights to be merge into one, such as qwen3vl if len(converted_names) == 0: @@ -736,7 +745,7 @@ def _weight_name_mapping_mcore_to_hf(self, mcore_weights_name: str) -> list[str] return self._weight_name_mapping_other(mcore_weights_name) def 
_weight_to_hf_format( - self, mcore_weights_name: str, mcore_weights: torch.Tensor + self, mcore_weights_name: str, mcore_weights: torch.Tensor, self_attn_output_gate: bool = False ) -> tuple[list[str], list[torch.Tensor]]: """ Export MCore weights to Hugging Face format. @@ -793,9 +802,25 @@ def _weight_to_hf_format( single_out_shape = ( [-1, hidden_dim] if ".bias" not in mcore_weights_name else [-1] ) + q = qkv[:, :q_len].reshape(*single_out_shape) + g = None + if self_attn_output_gate: + g = qkv[:, q_len : q_len + q_len].reshape(*single_out_shape) + q_len += q_len + k = qkv[:, q_len : q_len + k_len].reshape(*single_out_shape) v = qkv[:, q_len + k_len :].reshape(*single_out_shape) + + if self_attn_output_gate: + _out_shape = ( + [num_attention_heads, -1, hidden_dim] + if ".bias" not in mcore_weights_name + else [num_attention_heads, -1] + ) + q = q.view(_out_shape) + g = g.view(_out_shape) + q = torch.cat([q, g], dim=1).view(*single_out_shape).contiguous() return hf_names, [q, k, v] elif ( @@ -806,10 +831,46 @@ def _weight_to_hf_format( assert len(hf_names) == 2 gate, up = mcore_weights.chunk(2) return hf_names, [gate, up] + elif "self_attention.in_proj.weight" in mcore_weights_name: + assert len(hf_names) == 2 + hidden_size = self.hf_config.hidden_size + linear_num_key_heads = self.hf_config.linear_num_key_heads + linear_key_head_dim = self.hf_config.linear_key_head_dim + linear_num_value_heads = self.hf_config.linear_num_value_heads + linear_value_head_dim = self.hf_config.linear_value_head_dim + + k_dim = linear_num_key_heads * linear_key_head_dim + v_dim = linear_num_value_heads * linear_value_head_dim + split_shape = [ + k_dim, + k_dim, + v_dim, + v_dim, + linear_num_value_heads, + linear_num_value_heads, + ] + weight_lst = mcore_weights.split(split_shape, dim=0) + # weight_lst: [wq, wk, wv, wz, wb, wa] + assert len(weight_lst) == 6 + wq, wk, wv, wz, wb, wa = weight_lst + + qk_shape = [linear_num_key_heads, linear_key_head_dim, -1] + vz_shape = [linear_num_key_heads, v_dim // linear_num_key_heads, -1] + ba_shape = [linear_num_key_heads, linear_num_value_heads // linear_num_key_heads, -1] + wq = wq.view(qk_shape) + wk = wk.view(qk_shape) + wv = wv.view(vz_shape) + wz = wz.view(vz_shape) + wb = wb.view(ba_shape) + wa = wa.view(ba_shape) + + qkvz_weight = torch.cat([wq, wk, wv, wz], dim=1).view([-1, hidden_size]).contiguous() + ba_weight = torch.cat([wb, wa], dim=1).view([-1, hidden_size]).contiguous() + return hf_names, [qkvz_weight, ba_weight] raise NotImplementedError(f"Unsupported parameter name: {mcore_weights_name}") def _weight_to_mcore_format( - self, mcore_weights_name: str, hf_weights: list[torch.Tensor] + self, mcore_weights_name: str, hf_weights: list[torch.Tensor], self_attn_output_gate: bool=False ) -> torch.Tensor: """ Import Hugging Face weights to MCore format. 
@@ -872,19 +933,44 @@ def _weight_to_mcore_format( group_dim = head_dim * num_attention_heads // num_key_value_heads q, k, v = hf_weights # q k v might be tp split - real_num_key_value_heads = q.shape[0] // group_dim - q = q.view( - [ - real_num_key_value_heads, - group_dim, - -1, - ] - ) + if self_attn_output_gate: + real_num_key_value_heads = q.shape[0] // 2 // group_dim + + combined_w = q.reshape((num_attention_heads, 2 * head_dim, -1)) + q_w = combined_w.narrow(1, 0, head_dim).reshape((num_attention_heads * head_dim, -1)) + g_w = combined_w.narrow(1, head_dim, head_dim).reshape((num_attention_heads * head_dim, -1)) + + q = q_w.view( + [ + real_num_key_value_heads, + group_dim, + -1, + ] + ) + g = g_w.view( + [ + real_num_key_value_heads, + group_dim, + -1, + ] + ) + else: + real_num_key_value_heads = q.shape[0] // group_dim + q = q.view( + [ + real_num_key_value_heads, + group_dim, + -1, + ] + ) k = k.view([real_num_key_value_heads, head_dim, -1]) v = v.view([real_num_key_value_heads, head_dim, -1]) out_shape = [-1, hidden_dim] if ".bias" not in mcore_weights_name else [-1] - qkv = torch.cat([q, k, v], dim=1).view(*out_shape).contiguous() + if self_attn_output_gate: + qkv = torch.cat([q, g, k, v], dim=1).view(*out_shape).contiguous() + else: + qkv = torch.cat([q, k, v], dim=1).view(*out_shape).contiguous() return qkv elif ( "linear_fc1.weight" in mcore_weights_name @@ -894,6 +980,33 @@ def _weight_to_mcore_format( assert len(hf_weights) == 2 gate, up = hf_weights return torch.cat([gate, up], dim=0) + elif "self_attention.in_proj.weight" in mcore_weights_name: + assert len(hf_weights) == 2 + qkvz_weight, ba_weight = hf_weights + linear_num_key_heads = self.hf_config.linear_num_key_heads + linear_key_head_dim = self.hf_config.linear_key_head_dim + linear_num_value_heads = self.hf_config.linear_num_value_heads + linear_value_head_dim = self.hf_config.linear_value_head_dim + key_dim = linear_key_head_dim * linear_num_key_heads + value_dim = linear_value_head_dim * linear_num_value_heads + + qkvz_dim_per_partition = 2 * linear_key_head_dim + 2 * value_dim // linear_num_key_heads + qkvz_weight_ = qkvz_weight.reshape((linear_num_key_heads, qkvz_dim_per_partition, -1)) + wq = qkvz_weight_.narrow(1, 0, linear_key_head_dim).reshape(key_dim, -1) + wk = qkvz_weight_.narrow(1, linear_key_head_dim, linear_key_head_dim).reshape(key_dim, -1) + wv = qkvz_weight_.narrow(1, 2 * linear_key_head_dim, + value_dim // linear_num_key_heads).reshape(value_dim, -1) + wz = qkvz_weight_.narrow(1, 2 * linear_key_head_dim + value_dim // linear_num_key_heads, + value_dim // linear_num_key_heads).reshape(value_dim, -1) + + ba_weight_ = ba_weight.reshape((linear_num_key_heads, 2 * (linear_num_value_heads // linear_num_key_heads), -1)) + wb = ba_weight_.narrow(1, 0, linear_num_value_heads // linear_num_key_heads).reshape((linear_num_value_heads, -1)) + wa = ba_weight_.narrow(1, linear_num_value_heads // linear_num_key_heads, + linear_num_value_heads // linear_num_key_heads).reshape((linear_num_value_heads, -1)) + return torch.cat([wq, wk, wv, wz, wb, wa], dim=0) + else: + print(f"others weights {len(hf_weights)} {[hf_w.shape for hf_w in hf_weights]}") + raise NotImplementedError(f"Unsupported parameter name: {mcore_weights_name}") def _weight_merge_across_tp( @@ -947,6 +1060,36 @@ def _weight_merge_across_tp( elif "mlp.experts.linear_fc2.weight" in mcore_weights_name: # moe ret = torch.cat(mcore_weights, dim=1) + elif "self_attention.in_proj.weight" in mcore_weights_name: + mcore_config = 
self._get_mcore_config_by_name(mcore_weights_name) + tp_size = len(mcore_weights) + k_dim = mcore_config.linear_num_key_heads * mcore_config.linear_key_head_dim + v_dim = mcore_config.linear_num_value_heads * mcore_config.linear_value_head_dim + split_shape = [ + k_dim // tp_size, + k_dim // tp_size, + v_dim // tp_size, + v_dim // tp_size, + mcore_config.linear_num_value_heads // tp_size, + mcore_config.linear_num_value_heads // tp_size, + ] + # split_shape for [wq, wk, wv, wz, wb, wa] + ret = self._split_weight_by_size_and_merge_across_tp(mcore_weights, split_shape) + elif "self_attention.conv1d" in mcore_weights_name: + if "weight" in mcore_weights_name: + mcore_config = self._get_mcore_config_by_name(mcore_weights_name) + tp_size = len(mcore_weights) + k_dim = mcore_config.linear_num_key_heads * mcore_config.linear_key_head_dim + v_dim = mcore_config.linear_num_value_heads * mcore_config.linear_value_head_dim + split_shape = [ + k_dim // tp_size, + k_dim // tp_size, + v_dim // tp_size, + ] + # split_shape for [X, B, C] + ret = self._split_weight_by_size_and_merge_across_tp(mcore_weights, split_shape) + else: + raise NotImplementedError(f"{mcore_weights_name} not supported yet") else: assert ( hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel @@ -995,6 +1138,50 @@ def _weight_split_across_tp( ret = [torch.cat([g, u], dim=0) for g, u in zip(gates, ups)] elif "mlp.experts.linear_fc2.weight" in mcore_weights_name: # moe ret = mcore_weights.chunk(tp_split_size, dim=1) + elif "self_attention.in_proj.weight" in mcore_weights_name: + mcore_config = self._get_mcore_config_by_name(mcore_weights_name) + k_dim = mcore_config.linear_num_key_heads * mcore_config.linear_key_head_dim + v_dim = mcore_config.linear_num_value_heads * mcore_config.linear_value_head_dim + split_shape = [ + k_dim, + k_dim, + v_dim, + v_dim, + mcore_config.linear_num_value_heads, + mcore_config.linear_num_value_heads, + ] + split_w_lst = mcore_weights.split(split_shape, dim=0) + # split_w_lst: [wq, wk, wv, wz, wb, wa] + assert len(split_w_lst) == 6, f"split_shape {split_shape} not supported" + weight_list = [] + for weight in split_w_lst: + weight_list.append(weight.chunk(tp_split_size)) + ret = [ + torch.cat([wq_slice, wk_slice, wv_slice, wz_slice, wb_slice, wa_slice], dim=0) + for wq_slice, wk_slice, wv_slice, wz_slice, wb_slice, wa_slice in zip(*weight_list) + ] + elif "self_attention.conv1d" in mcore_weights_name: + if "weight" in mcore_weights_name: + mcore_config = self._get_mcore_config_by_name(mcore_weights_name) + k_dim = mcore_config.linear_num_key_heads * mcore_config.linear_key_head_dim + v_dim = mcore_config.linear_num_value_heads * mcore_config.linear_value_head_dim + split_shape = [ + k_dim, + k_dim, + v_dim, + ] + split_w_lst = mcore_weights.split(split_shape, dim=0) + # split_w_lst: [X, B, C] + assert len(split_w_lst) == 3, f"split_shape {split_shape} not supported" + weight_list = [] + for weight in split_w_lst: + weight_list.append(weight.chunk(tp_split_size)) + ret = [ + torch.cat([x_slice, b_slice, c_slice], dim=0) + for x_slice, b_slice, c_slice in zip(*weight_list) + ] + else: + raise NotImplementedError(f"{mcore_weights_name} not supported yet") else: if param.shape == mcore_weights.shape: return [mcore_weights for _ in range(tp_split_size)] @@ -1021,6 +1208,29 @@ def _get_actual_hf_path(self, weight_path: str) -> str: return os.path.dirname(cached_file(weight_path, "config.json")) + def _split_weight_by_size_and_merge_across_tp( + self, + mcore_weights: list[torch.Tensor], + 
        split_shape: list[int],
+    ) -> torch.Tensor:
+        '''
+        First split the weight by split_shape, then merge across tensor parallel ranks.
+
+        Used for the linear attention in_proj and conv1d layer weights.
+        '''
+        tp_size = len(mcore_weights)
+
+        weight_lst = [[] for _ in range(len(split_shape))]
+        for mcore_weight in mcore_weights:
+            split_w_lst = mcore_weight.split(split_shape, dim=0)
+            assert len(split_w_lst) == len(weight_lst)
+            for wi, split_w in enumerate(split_w_lst):
+                weight_lst[wi].append(split_w)
+        for weight in weight_lst:
+            assert len(weight) == tp_size
+        ret = torch.cat([torch.cat(w_split, dim=0) for w_split in weight_lst], dim=0)
+        return ret
+

 # Model registry
 _MODEL_REGISTRY = {}
diff --git a/mbridge/core/safetensor_io.py b/mbridge/core/safetensor_io.py
index 3e6157d..0177537 100644
--- a/mbridge/core/safetensor_io.py
+++ b/mbridge/core/safetensor_io.py
@@ -11,7 +11,7 @@


 class SafeTensorIO:
-    def __init__(self, hf_dir: str):
+    def __init__(self, hf_dir: str, ignore_mtp: bool = False):
         index_file = os.path.join(hf_dir, "model.safetensors.index.json")

         self.index = {}
@@ -19,6 +19,14 @@ def __init__(self, hf_dir: str):
         if os.path.exists(index_file):
             with open(index_file, "r") as f:
                 origin_index = json.load(f)
+
+            filtered_index = {}
+            for key, value in origin_index["weight_map"].items():
+                if ignore_mtp and "mtp" in key:
+                    continue
+                filtered_index[key] = value
+            origin_index["weight_map"] = filtered_index
+
             self.index = origin_index["weight_map"]
             self.origin_index = origin_index
diff --git a/mbridge/models/__init__.py b/mbridge/models/__init__.py
index e195cef..012826a 100644
--- a/mbridge/models/__init__.py
+++ b/mbridge/models/__init__.py
@@ -30,3 +30,4 @@
 from .gemma3 import Gemma3Bridge
 from .internvl3 import InternVL3Bridge
 from .qwen3_vl import Qwen3VLBridge, Qwen3VLBridge
+from .qwen3_next import Qwen3NextBridge
diff --git a/mbridge/models/qwen3_next.py b/mbridge/models/qwen3_next.py
new file mode 100644
index 0000000..31b266a
--- /dev/null
+++ b/mbridge/models/qwen3_next.py
@@ -0,0 +1,83 @@
+from ..core import register_model
+from .qwen3moe import Qwen3MoEBridge
+
+
+@register_model("qwen3_next")
+class Qwen3NextBridge(Qwen3MoEBridge):
+
+    _ATTENTION_MAPPING = {
+        **(Qwen3MoEBridge._ATTENTION_MAPPING),
+        "self_attention.dt_bias": [
+            "model.layers.{layer_number}.linear_attn.dt_bias"
+        ],
+        "self_attention.A_log": [
+            "model.layers.{layer_number}.linear_attn.A_log"
+        ],
+        "self_attention.in_proj.weight": [
+            "model.layers.{layer_number}.linear_attn.in_proj_qkvz.weight",
+            "model.layers.{layer_number}.linear_attn.in_proj_ba.weight"
+        ],
+        "self_attention.conv1d.weight": [
+            "model.layers.{layer_number}.linear_attn.conv1d.weight"
+        ],
+        "self_attention.out_norm.weight": [
+            "model.layers.{layer_number}.linear_attn.norm.weight"
+        ],
+        "self_attention.out_proj.weight": [
+            "model.layers.{layer_number}.linear_attn.out_proj.weight"
+        ],
+        "self_attention.in_proj.layer_norm_weight": [
+            "model.layers.{layer_number}.input_layernorm.weight"
+        ],
+    }
+
+    def _get_gptmodel_args(self) -> dict:
+        """
+        Gets the arguments for GPTModel initialization.
+
+        Constructs a dictionary of arguments required to initialize a GPTModel
+        based on the configuration.
+
+        Returns:
+            dict: A dictionary of arguments for GPTModel initialization
+        """
+        return dict(
+            vocab_size=self.hf_config.vocab_size,
+            max_sequence_length=self.hf_config.max_position_embeddings,
+            position_embedding_type="rope",
+            rotary_base=self.hf_config.rope_theta,
+            rotary_percent=self.hf_config.partial_rotary_factor,
+        )
+
+    def _build_config(self):
+        return self._build_base_config(
+            use_cpu_initialization=False,
+            # MoE specific
+            moe_ffn_hidden_size=self.hf_config.moe_intermediate_size,
+            moe_router_bias_update_rate=0.001,
+            moe_router_topk=self.hf_config.num_experts_per_tok,
+            num_moe_experts=self.hf_config.num_experts,
+            moe_aux_loss_coeff=self.hf_config.router_aux_loss_coef,
+            moe_router_load_balancing_type="none",  # default None for RL
+            moe_shared_expert_overlap=True,
+            moe_grouped_gemm=True,
+            moe_router_score_function="softmax",
+            moe_shared_expert_intermediate_size=self.hf_config.shared_expert_intermediate_size,
+            moe_shared_expert_gate=self.hf_config.shared_expert_intermediate_size > 0,
+            # Qwen specific
+            moe_router_pre_softmax=False,
+            qk_layernorm=True,
+            layernorm_zero_centered_gamma=True,
+            attention_output_gate=True,
+            # Qwen3-next and linear attention
+            kv_channels=self.hf_config.head_dim,
+            linear_attention_type="gated_delta_net",
+            linear_attention_freq=self.hf_config.full_attention_interval,
+            linear_conv_kernel_dim=self.hf_config.linear_conv_kernel_dim,
+            linear_key_head_dim=self.hf_config.linear_key_head_dim,
+            linear_value_head_dim=self.hf_config.linear_value_head_dim,
+            linear_num_key_heads=self.hf_config.linear_num_key_heads,
+            linear_num_value_heads=self.hf_config.linear_num_value_heads,
+            zero_centered_gated_delta_norm=False,
+            # TODO: MTP-related parameters are not added yet
+        )
diff --git a/mbridge/utils/post_creation_callbacks.py b/mbridge/utils/post_creation_callbacks.py
index 3c3f9a4..239cf8b 100644
--- a/mbridge/utils/post_creation_callbacks.py
+++ b/mbridge/utils/post_creation_callbacks.py
@@ -15,7 +15,7 @@ def freeze_moe_router(model, pre_process, post_process, config, hf_config):
         if hasattr(layer.mlp, "router"):
             if hasattr(layer.mlp.router, "weight"):
                 layer.mlp.router.weight.requires_grad = False
-            if hasattr(layer.mlp.router, "bias"):
+            if hasattr(layer.mlp.router, "bias") and layer.mlp.router.bias is not None:
                 layer.mlp.router.bias.requires_grad = False
         if hasattr(layer.mlp, "shared_experts"):
             if hasattr(layer.mlp.shared_experts, "gate_weight"):