diff --git a/cookbook/megatron/qwen3_5.py b/cookbook/megatron/qwen3_5.py new file mode 100644 index 00000000..8807c066 --- /dev/null +++ b/cookbook/megatron/qwen3_5.py @@ -0,0 +1,46 @@ +from peft import LoraConfig + +import twinkle +from twinkle import DeviceMesh, get_device_placement, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import MegatronModel +from twinkle.preprocessor import SelfCognitionProcessor + +device_mesh = DeviceMesh.from_sizes(dp_size=4, tp_size=1, pp_size=1, ep_size=4) +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + +logger = get_logger() + +MODEL_ID = 'Qwen/Qwen3.5-35B-A3B' + +def train(): + dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000))) + dataset.set_template('Template', model_id=MODEL_ID) + dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) + dataset.encode() + dataloader = DataLoader(dataset=dataset, batch_size=4) + + model = MegatronModel(model_id=MODEL_ID) + lora_config = LoraConfig(r=8, lora_alpha=16, target_modules='all-linear') + model.add_adapter_to_model('default', lora_config) + model.set_optimizer(optimizer_cls='default', lr=1e-4) + model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=2, lr_decay_steps=len(dataloader)) + logger.info(get_device_placement()) + logger.info(model.get_train_configs()) + logger.info(f'Total steps: {len(dataloader)}') + + for step, batch in enumerate(dataloader): + model.forward_backward(inputs=batch) + model.clip_grad_and_step() + if step % 5 == 0: + metric = model.calculate_metric(is_training=True) + logger.info(f'Step {step}/{len(dataloader)}, metric: {metric}') + + # NOTE: you should merge lora for Qwen3.5 model when using Megatron + model.save('last-checkpoint', merge_lora=True) + logger.info('Training completed.') + + +if __name__ == '__main__': + train() diff --git a/src/twinkle/model/megatron/args.py b/src/twinkle/model/megatron/args.py index 858c2f0d..80759d82 100644 --- a/src/twinkle/model/megatron/args.py +++ b/src/twinkle/model/megatron/args.py @@ -107,6 +107,7 @@ class TwinkleMegatronArgs: num_experts: int = 0 num_experts_per_tok: int = 2 shared_expert_intermediate_size: int = 0 + moe_router_enable_expert_bias: bool = False # ========================================================================= # Training/inference settings @@ -137,9 +138,6 @@ class TwinkleMegatronArgs: # ========================================================================= merge_lora: bool = False target_modules: List[str] = field(default_factory=list) - freeze_llm: bool = False - freeze_vit: bool = False - freeze_aligner: bool = False # ========================================================================= # FP8 quantization settings @@ -160,7 +158,6 @@ class TwinkleMegatronArgs: # ========================================================================= untie_embeddings_and_output_weights: bool = True max_shard_size: str = '5GB' - llm_model_type: str = 'gpt' # For transformers 5.0 compatibility use_cpu_initialization: bool = False def __post_init__(self): @@ -260,6 +257,10 @@ def head_dim(self) -> int: def intermediate_size(self) -> int: return self.ffn_hidden_size + @property + def moe_shared_expert_intermediate_size(self) -> int: + return self.shared_expert_intermediate_size + @property def num_query_groups(self) -> int: """Alias for num_key_value_heads (Megatron naming).""" @@ -330,9 +331,12 @@ def from_hf_config( # Get rope_scaling rope_scaling = 
getattr(text_config, 'rope_scaling', None) - # Detect multimodal model model_type = getattr(hf_config, 'model_type', 'qwen2') - is_multimodal = 'vl' in model_type.lower() or 'vision' in model_type.lower() or 'omni' in model_type.lower() + + # Detect multimodal model from the registered MegatronModelMeta + from .model.register import get_megatron_model_meta + model_meta = get_megatron_model_meta(model_type) + is_multimodal = model_meta.is_multimodal if model_meta is not None else False # Determine QKV bias if hasattr(text_config, 'attention_bias'): @@ -435,7 +439,6 @@ def create_model(self, ) -> List[nn.Module]: if self._model is not None: return self._model from megatron.core import parallel_state as mpu - from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec from megatron.core.transformer import TransformerConfig from megatron.core.transformer.enums import AttnBackend @@ -611,24 +614,37 @@ def _get_base_model(m): if exists('megatron_core>=0.13'): config.expert_tensor_parallel_size = self.etp_size - # Save transformer config for later use (e.g., DDP wrapping) + if mg_config_dict.get('use_shared_expert_gate'): + config.moe_use_shared_expert_gate = True + if mg_config_dict.get('rotary_interleaved'): + config.rotary_interleaved = True + partial_rotary_factor = mg_config_dict.get('partial_rotary_factor') + if partial_rotary_factor is not None: + config.rotary_percent = partial_rotary_factor + config.apply_rope_fusion = False + mrope_section = mg_config_dict.get('mrope_section') + if mrope_section is not None: + config.mrope_section = mrope_section + if mg_config_dict.get('mrope_interleaved'): + config.mrope_interleaved = True + self.config = config - # Get layer spec - enable moe_grouped_gemm for MoE models - moe_grouped_gemm = num_experts > 0 - try: - layer_spec = get_gpt_layer_with_transformer_engine_spec( - num_experts=mg_config_dict.get('num_experts'), - moe_grouped_gemm=moe_grouped_gemm, - qk_layernorm=mg_config_dict.get('qk_layernorm', False), - ) - except (ImportError, AttributeError): - raise RuntimeError( - 'TransformerEngine is not installed or not compatible with this version of Megatron-Core.') + # Delegate model-specific config & layer spec construction to the loader + loader = model_meta.loader() if model_meta else None + if loader is not None: + loader.post_config(config, self, mg_config_dict) + layer_spec = loader.get_layer_spec(config, self, mg_config_dict) + else: + from .model.register import MegatronModelLoader + default_loader = MegatronModelLoader() + default_loader.post_config(config, self, mg_config_dict) + layer_spec = default_loader.get_layer_spec(config, self, mg_config_dict) # Create model max_seq_length = getattr(hf_config, 'max_position_embeddings', 4096) rotary_base = mg_config_dict.get('rotary_base', 10000) + position_embedding_type = mg_config_dict.get('position_embedding_type', 'rope') extra_init_args = {} if hasattr(hf_config, 'rope_scaling') and hf_config.rope_scaling is not None and 'factor' in hf_config.rope_scaling: @@ -651,7 +667,7 @@ def _get_base_model(m): post_process=mpu.is_pipeline_last_stage(**extra_kwargs), parallel_output=True, share_embeddings_and_output_weights=getattr(hf_config, 'tie_word_embeddings', False), - position_embedding_type='rope', + position_embedding_type=position_embedding_type, rotary_base=rotary_base, **extra_init_args) model.append(_model) @@ -666,7 +682,7 @@ def _get_base_model(m): post_process=mpu.is_pipeline_last_stage(), parallel_output=True, 
share_embeddings_and_output_weights=getattr(hf_config, 'tie_word_embeddings', False), - position_embedding_type='rope', + position_embedding_type=position_embedding_type, rotary_base=rotary_base, **extra_init_args, ) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index aa74e72e..38b34a7e 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -819,6 +819,7 @@ def save(self, output_dir: Optional[str] = None, interval: int = 1, save_optimizer: bool = False, + merge_lora: bool = False, **kwargs): """Save model checkpoint. @@ -832,6 +833,9 @@ def save(self, interval: Save each *interval* steps. save_optimizer: If True, save optimizer + lr_scheduler + RNG state alongside the HF weights for checkpoint resumption. + merge_lora: If True, merge LoRA adapters into base weights and save + the full merged model instead of PEFT adapter format. The merge + is reversed after saving so training can continue. **kwargs: Additional arguments forwarded to the underlying save methods (e.g. ``adapter_name``). """ @@ -846,8 +850,16 @@ def save(self, output_dir = 'output' checkpoint_dir = os.path.join(output_dir, name) - # Always save HF-format weights (for inference / deployment). - self._save_hf_format(checkpoint_dir, optimizer_config.adapter_name) + is_lora = (optimizer_config.adapter_name != _default_adapter_name) + + if merge_lora and is_lora: + self._merge_lora_adapters(optimizer_config.adapter_name) + self._save_hf_format(checkpoint_dir, _default_adapter_name) + self._save_tokenizer(checkpoint_dir, adapter_name=adapter_name) + self._unmerge_lora_adapters() + else: + self._save_hf_format(checkpoint_dir, optimizer_config.adapter_name) + self._save_tokenizer(checkpoint_dir, adapter_name=adapter_name) # Optionally save mcore optimizer state (for training resumption). if save_optimizer: @@ -857,8 +869,6 @@ def save(self, **kwargs, ) - self._save_tokenizer(checkpoint_dir, adapter_name=adapter_name) - # Final synchronization to ensure all ranks complete save. if dist.is_initialized(): dist.barrier() @@ -1160,6 +1170,24 @@ def _read_iteration(tracker_path: str) -> int: iteration = iters_cuda[0].item() return iteration + def _merge_lora_adapters(self, adapter_name: str = 'default'): + """Merge LoRA adapters into base model weights.""" + from .tuners.lora import LoraParallelLinear + with torch.no_grad(): + for model in self.strategy.unwrap_model(self.model): + for module in model.modules(): + if isinstance(module, (LoraParallelLinear, LoraLinear)): + module.merge(adapter_names=[adapter_name]) + + def _unmerge_lora_adapters(self): + """Unmerge LoRA adapters to restore training state.""" + from .tuners.lora import LoraParallelLinear + with torch.no_grad(): + for model in self.strategy.unwrap_model(self.model): + for module in model.modules(): + if isinstance(module, (LoraParallelLinear, LoraLinear)): + module.unmerge() + def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=None): """Save in HuggingFace format using bridge adapter. diff --git a/src/twinkle/model/megatron/model/__init__.py b/src/twinkle/model/megatron/model/__init__.py index 28bae1ad..c61acef9 100644 --- a/src/twinkle/model/megatron/model/__init__.py +++ b/src/twinkle/model/megatron/model/__init__.py @@ -1,4 +1,4 @@ from . 
import gpts, mm_gpts from .constant import MegatronModelType from .gpt_bridge import GPTBridge -from .register import MegatronModelMeta, get_megatron_model_meta, register_megatron_model +from .register import MegatronModelLoader, MegatronModelMeta, get_megatron_model_meta, register_megatron_model diff --git a/src/twinkle/model/megatron/model/constant.py b/src/twinkle/model/megatron/model/constant.py index b3ea8807..33ac637c 100644 --- a/src/twinkle/model/megatron/model/constant.py +++ b/src/twinkle/model/megatron/model/constant.py @@ -14,6 +14,8 @@ class MLLMModelType: qwen2_5_vl = 'qwen2_5_vl' qwen3_vl = 'qwen3_vl' qwen3_vl_moe = 'qwen3_vl_moe' + qwen3_5 = 'qwen3_5' + qwen3_5_moe = 'qwen3_5_moe' class ModelType(LLMModelType, MLLMModelType): @@ -29,6 +31,7 @@ class MLLMMegatronModelType: qwen2_vl = 'qwen2_vl' qwen2_5_vl = 'qwen2_5_vl' qwen3_vl = 'qwen3_vl' + qwen3_5 = 'qwen3_5' class MegatronModelType(LLMMegatronModelType, MLLMMegatronModelType): diff --git a/src/twinkle/model/megatron/model/gpt_bridge.py b/src/twinkle/model/megatron/model/gpt_bridge.py index d4f076bf..a5190ae0 100644 --- a/src/twinkle/model/megatron/model/gpt_bridge.py +++ b/src/twinkle/model/megatron/model/gpt_bridge.py @@ -20,7 +20,7 @@ from twinkle.hub import HubOperation from twinkle.model.megatron.args import get_args # Use twinkle's get_args from twinkle.utils import (MxFp4Dequantizer, SafetensorLazyLoader, StreamingSafetensorSaver, deep_getattr, get_logger, - get_modules_to_not_convert, get_multimodal_target_regex, is_last_rank, requires) + get_modules_to_not_convert, is_last_rank, requires) logger = get_logger() @@ -111,6 +111,9 @@ def get_hf_mlp_prefix(self, layer_idx): def _get_hf_mlp(self, layer_idx): return getattr(self.hf_layers[layer_idx], self.get_hf_mlp_prefix(layer_idx)) + def _get_transpose(self): + return self.args.hf_model_type in {'qwen3_vl_moe', 'gpt_oss', 'llama4'} + def _init_meta_hf_model(self): import copy @@ -300,6 +306,9 @@ def _set_module(self, mg_module, hf_state_dict, hf_prefix: str, to_mcore: bool): if self._is_peft_format: if '.lora_A.' in k or '.lora_B.' in k or '.modules_to_save.' in k: k = k.replace(f'{self._adapter_name}.', '') + if '.lora_A.' in k: + module_name = k.split('.lora_A.')[0].rsplit('.', 1)[-1] + self._peft_target_modules.add(module_name) new_state_dict[k] = v else: if '.lora_A.' in k or '.lora_B.' in k or 'original_module.'
in k: @@ -784,6 +793,10 @@ def _set_mlp_state(self, if _is_gate_up is not None: is_gate_up = _is_gate_up + need_transpose = True + if self.is_transformers_5 and hf_grouped: + need_transpose = self._get_transpose() + if hf_grouped and not to_mcore: hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) elif not to_mcore: @@ -858,18 +871,22 @@ def _set_mlp_state(self, gate_up_proj_weight = self.mxfp4_quantizer.convert(blocks, scales) else: gate_up_proj_weight = hf_state_dict['gate_up_proj'].load() - gate_up_proj_weight = gate_up_proj_weight.transpose(1, 2) + if need_transpose: + gate_up_proj_weight = gate_up_proj_weight.transpose(1, 2) + gate_up_proj_weight = gate_up_proj_weight[ep_rank * num_local_experts:(ep_rank + 1) * num_local_experts] if has_scale_inv: - gate_up_scale_inv = hf_state_dict['gate_up_proj_scale_inv'].load().transpose(1, 2) + gate_up_scale_inv = hf_state_dict['gate_up_proj_scale_inv'].load() + if need_transpose: + gate_up_scale_inv = gate_up_scale_inv.transpose(1, 2) gate_up_scale_inv = gate_up_scale_inv[ep_rank * num_local_experts:(ep_rank + 1) * num_local_experts] if fc1_bias is not None: gate_up_proj_bias = hf_state_dict['gate_up_proj_bias'].load() gate_up_proj_bias = gate_up_proj_bias[ep_rank * num_local_experts:(ep_rank + 1) * num_local_experts] - if args.llm_model_type == 'gpt_oss': + if args.hf_model_type == 'gpt_oss': gate_proj_weight = gate_up_proj_weight[:, ::2] up_proj_weight = gate_up_proj_weight[:, 1::2] gate_proj_bias, up_proj_bias = gate_up_proj_bias[:, ::2], gate_up_proj_bias[:, 1::2] @@ -1018,12 +1035,13 @@ def _set_mlp_state(self, if is_gate_up: if is_expert: if hf_grouped: - gate_up_proj_weight = gate_up_proj_weight.transpose(1, 2) + if need_transpose: + gate_up_proj_weight = gate_up_proj_weight.transpose(1, 2) if 'gate_up_proj' in hf_state_dict: gate_up_proj_weight = torch.concat( [hf_state_dict['gate_up_proj'], gate_up_proj_weight], dim=0) is_last_ckpt = gate_up_proj_weight.shape[0] == args.num_experts - if args.llm_model_type == 'gpt_oss' and is_last_ckpt: + if args.hf_model_type == 'gpt_oss' and is_last_ckpt: gate_proj_weight, up_proj_weight = gate_up_proj_weight.chunk(2, dim=2) new_gate_up_proj_weight = torch.empty_like(gate_up_proj_weight) new_gate_up_proj_weight[..., ::2] = gate_proj_weight @@ -1032,7 +1050,8 @@ def _set_mlp_state(self, del new_gate_up_proj_weight, gate_proj_weight, up_proj_weight hf_state_dict['gate_up_proj'] = gate_up_proj_weight.clone() if scale_inv is not None: - scale_inv = scale_inv.transpose(1, 2) + if need_transpose: + scale_inv = scale_inv.transpose(1, 2) if 'gate_up_proj_scale_inv' in hf_state_dict: scale_inv = torch.concat([hf_state_dict['gate_up_proj_scale_inv'], scale_inv], dim=0) @@ -1042,7 +1061,7 @@ def _set_mlp_state(self, if 'gate_up_proj_bias' in hf_state_dict: gate_up_proj_bias = torch.concat( [hf_state_dict['gate_up_proj_bias'], gate_up_proj_bias], dim=0) - if args.llm_model_type == 'gpt_oss' and is_last_ckpt: + if args.hf_model_type == 'gpt_oss' and is_last_ckpt: gate_proj_bias, up_proj_bias = gate_up_proj_bias.chunk(2, dim=1) new_gate_up_proj_bias = torch.empty_like(gate_up_proj_bias) new_gate_up_proj_bias[:, ::2] = gate_proj_bias @@ -1124,12 +1143,15 @@ def _set_mlp_state(self, down_proj_weight = self.mxfp4_quantizer.convert(blocks, scales) else: down_proj_weight = hf_state_dict['down_proj'].load() - down_proj_weight = down_proj_weight.transpose(1, 2) + if need_transpose: + down_proj_weight = down_proj_weight.transpose(1, 2) down_proj_weight = down_proj_weight[ep_rank * num_local_experts:(ep_rank + 1) 
* num_local_experts].reshape( -1, down_proj_weight.shape[-1]) if has_scale_inv: - down_scale_inv = hf_state_dict['down_proj_scale_inv'].load().transpose(1, 2) + down_scale_inv = hf_state_dict['down_proj_scale_inv'].load() + if need_transpose: + down_scale_inv = down_scale_inv.transpose(1, 2) down_scale_inv = down_scale_inv[ep_rank * num_local_experts:(ep_rank + 1) * num_local_experts].reshape(-1, down_scale_inv.shape[-1]) if fc2_bias is not None: @@ -1209,12 +1231,14 @@ def _set_mlp_state(self, del fc2_weight, fc2_bias if down_proj_weight is not None: if hf_grouped: - down_proj_weight = down_proj_weight.transpose(1, 2) + if need_transpose: + down_proj_weight = down_proj_weight.transpose(1, 2) if 'down_proj' in hf_state_dict: down_proj_weight = torch.concat([hf_state_dict['down_proj'], down_proj_weight], dim=0) hf_state_dict['down_proj'] = down_proj_weight.clone() if scale_inv is not None: - scale_inv = scale_inv.transpose(1, 2) + if need_transpose: + scale_inv = scale_inv.transpose(1, 2) if 'down_proj_scale_inv' in hf_state_dict: scale_inv = torch.concat([hf_state_dict['down_proj_scale_inv'], scale_inv], dim=0) hf_state_dict['down_proj_scale_inv'] = scale_inv.clone() @@ -1467,7 +1491,7 @@ def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx: hf_state_dict = {} self._convert_mtp_extra(mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict) transformer_layer = None if mtp_layer is None else mtp_layer.transformer_layer - if not to_mcore and not self.args.hf_model_type.startswith('qwen3_next'): + if not to_mcore and not self.args.hf_model_type.startswith(('qwen3_next', 'qwen3_5')): self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, 'embed_tokens.weight', to_mcore) if self.args.untie_embeddings_and_output_weights: @@ -1547,16 +1571,7 @@ def save_weights(self, peft_config = copy(mg_models[0].peft_config[self._adapter_name]) if args.task_type == 'seq_cls': peft_config.task_type = 'SEQ_CLS' - if args.is_multimodal and 'all-linear' in args.target_modules: - peft_config.target_modules = get_multimodal_target_regex( - self.hf_model, - freeze_llm=args.freeze_llm, - freeze_vit=args.freeze_vit, - freeze_aligner=args.freeze_aligner, - include_embedding='all-embedding' in args.target_modules, - exclude_router='all-router' not in args.target_modules) - else: - peft_config.target_modules = self._peft_target_modules + peft_config.target_modules = self._peft_target_modules peft_config.modules_to_save = self._peft_modules_to_save peft_config.save_pretrained(output_dir) else: diff --git a/src/twinkle/model/megatron/model/gpt_model.py b/src/twinkle/model/megatron/model/gpt_model.py index 477ccaf5..85e3f251 100644 --- a/src/twinkle/model/megatron/model/gpt_model.py +++ b/src/twinkle/model/megatron/model/gpt_model.py @@ -211,10 +211,12 @@ def _preprocess( f'current_attention_scaling: {attention_scaling}.') packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' if self.position_embedding_type == 'mrope': + mrope_position_ids = position_ids + if mrope_position_ids.dim() == 2: + mrope_position_ids = mrope_position_ids.unsqueeze(0).expand(3, -1, -1) rotary_pos_emb = self.rotary_pos_emb( - position_ids, + mrope_position_ids, mrope_section=self.mrope_section, - packed_seq=packed_seq, ) else: rotary_pos_emb = self.rotary_pos_emb( diff --git a/src/twinkle/model/megatron/model/gpts/qwen3_next.py b/src/twinkle/model/megatron/model/gpts/qwen3_next.py new file mode 100644 index 00000000..b589a0f2 --- /dev/null +++ 
b/src/twinkle/model/megatron/model/gpts/qwen3_next.py @@ -0,0 +1,508 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +# Reference: swift/swift/megatron/model/gpts/qwen3_next.py +# Qwen3-Next / Qwen3.5 series model support for Megatron + +import megatron.core +import torch +from copy import deepcopy +from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, _get_extra_te_kwargs +from megatron.core.inference.contexts import BaseInferenceContext +from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec, get_gpt_mtp_block_spec +from megatron.core.models.huggingface import HuggingFaceModule as _HuggingFaceModule +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.tensor_parallel import (gather_from_sequence_parallel_region, + reduce_scatter_to_sequence_parallel_region) +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.spec_utils import build_module +from megatron.core.transformer.transformer_block import TransformerBlockSubmodules +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import deprecate_inference_params, is_fa_min_version +from packaging import version +from typing import List, Optional, Tuple, Union + +from twinkle import get_logger +from twinkle.model.megatron.args import get_args +from twinkle.model.megatron.model.register import MegatronModelLoader + +mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') +mcore_015 = version.parse(megatron.core.__version__) >= version.parse('0.15.0rc0') +try: + from flashattn_hopper.flash_attn_interface import _flash_attn_forward + from flashattn_hopper.flash_attn_interface import flash_attn_with_kvcache as flash_attn3_with_kvcache + HAVE_FA3 = True +except Exception: + HAVE_FA3 = False + +try: + from einops import rearrange +except ImportError: + rearrange = None + +try: + import transformer_engine # pylint: disable=unused-import + HAVE_TE = True + from megatron.core.extensions.transformer_engine import SplitAlongDim +except ImportError: + HAVE_TE = False + SplitAlongDim = None + +logger = get_logger() + + +class Qwen3NextRMSNorm(torch.nn.Module): + """ + Zero-Centered RMSNorm for Qwen3-Next/Qwen3.5. + Uses (1 + weight) scaling to match HuggingFace implementation exactly. + This eliminates the need for +1/-1 offset during weight conversion. + """ + + def __init__(self, config, hidden_size: int, eps: float = 1e-5): + super().__init__() + self.config = config + self.eps = eps + self.weight = torch.nn.Parameter(torch.zeros(hidden_size)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, hidden_states): + output = self._norm(hidden_states.float()) + output = output * (1.0 + self.weight.float()) + return output.type_as(hidden_states) + + +class Qwen3NextSelfAttention(SelfAttention): + """Full attention with output gate for Qwen3-Next/Qwen3.5 models. + + QKV projection produces [Q_heads, gate_heads, K_heads, V_heads] where + Q and gate are interleaved: Q0, gate0, Q1, gate1, ... 
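+ The gate channels are split off in get_query_key_value_tensors and applied as a sigmoid output gate on core_attn_out just before linear_proj (see forward below).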
+ """ + + def __init__(self, config, submodules: SelfAttentionSubmodules, *args, **kwargs): + super(SelfAttention, self).__init__(config, submodules, *args, attention_type='self', **kwargs) + kwargs_pg = {} + if mcore_015: + kwargs_pg['tp_group'] = self.pg_collection.tp + elif mcore_013: + kwargs_pg['tp_group'] = self.model_comm_pgs.tp + self.linear_qkv = build_module( + submodules.linear_qkv, + self.config.hidden_size, + 2 * self.query_projection_size + 2 * self.kv_projection_size, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=self.config.add_bias_linear or self.config.add_qkv_bias, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name='qkv', + **kwargs_pg, + ) + + if submodules.q_layernorm is not None: + self.q_layernorm = build_module( + submodules.q_layernorm, + hidden_size=self.hidden_size_per_attention_head, + config=self.config, + eps=self.config.layernorm_epsilon, + ) + else: + self.q_layernorm = None + + if submodules.k_layernorm is not None: + self.k_layernorm = build_module( + submodules.k_layernorm, + hidden_size=self.hidden_size_per_attention_head, + config=self.config, + eps=self.config.layernorm_epsilon, + ) + else: + self.k_layernorm = None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + inference_context: Optional[BaseInferenceContext] = None, + rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + rotary_pos_cos: Optional[torch.Tensor] = None, + rotary_pos_sin: Optional[torch.Tensor] = None, + attention_bias: Optional[torch.Tensor] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[int] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + try: + from megatron.core.utils import nvtx_range_pop, nvtx_range_push + except ImportError: + + def nvtx_range_pop(*args, **kwargs): + return + + def nvtx_range_push(*args, **kwargs): + return + + if hasattr(self.config, 'no_rope_freq'): + no_rope = (self.config.no_rope_freq[self.layer_number - 1] if self.config.no_rope_freq else False) + if no_rope: + rotary_pos_emb = None + + inference_context = deprecate_inference_params(inference_context, inference_params) + + if inference_context and inference_context.is_dynamic_batching(): + assert HAVE_FA3 or is_fa_min_version( '2.7.3'), 'flash attn version v2.7.3 and above is required for dynamic batching.'
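+ # On the flash-decode path, rotary embeddings are applied through the rotary_pos_cos/rotary_pos_sin tables passed to self.flash_decode below, so the tuple-form rotary_pos_emb is dropped.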
+ + if self.config.flash_decode and not self.training and inference_context is not None: + rotary_pos_emb = None + else: + assert rotary_pos_cos is None and rotary_pos_sin is None + + if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = (rotary_pos_emb, ) * 2 + + nvtx_range_push(suffix='qkv') + query, key, value, gate = self.get_query_key_value_tensors(hidden_states, key_value_states) + nvtx_range_pop(suffix='qkv') + + in_decode_mode = (inference_context is not None and inference_context.is_decode_only() and not self.training) + + nvtx_range_push(suffix='adjust_key_value') + if in_decode_mode and self.config.flash_decode: + assert self.layer_number in inference_context.key_value_memory_dict + assert inference_context.sequence_len_offset is not None + inference_key_memory, inference_value_memory = inference_context.key_value_memory_dict[self.layer_number] + output = self.flash_decode( + sequence_len_offset=sequence_len_offset, + query_layer=query, + key_layer=key, + value_layer=value, + inference_key_memory=inference_key_memory, + inference_value_memory=inference_value_memory, + rotary_cos=rotary_pos_cos, + rotary_sin=rotary_pos_sin, + rotary_interleaved=self.config.rotary_interleaved, + ) + out = output.transpose(0, 1).contiguous() + context_layer = out.view(out.size(0), out.size(1), -1) + output, bias = self.linear_proj(context_layer) + return output, bias + + if (in_decode_mode and self.config.enable_cuda_graph and inference_context.is_static_batching()): + raise ValueError('CUDA graphs must use flash decode with static batching!') + + result = self._adjust_key_value_for_inference( + inference_context, + query, + key, + value, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + sequence_len_offset, + ) + if mcore_013: + query, key, value, rotary_pos_emb, attn_mask_type, block_table = result + else: + query, key, value, rotary_pos_emb, attn_mask_type = result + + if packed_seq_params is not None: + query = query.squeeze(1) + key = key.squeeze(1) + value = value.squeeze(1) + nvtx_range_pop(suffix='adjust_key_value') + + kwargs_cp = {} + if mcore_015: + kwargs_cp['cp_group'] = self.pg_collection.cp + elif mcore_013: + kwargs_cp['cp_group'] = self.model_comm_pgs.cp + nvtx_range_push(suffix='rotary_pos_emb') + if rotary_pos_emb is not None and not self.config.flash_decode: + q_pos_emb, k_pos_emb = rotary_pos_emb + + if packed_seq_params is not None: + cu_seqlens_q = ( + packed_seq_params.cu_seqlens_q_padded + if packed_seq_params.cu_seqlens_q_padded is not None else packed_seq_params.cu_seqlens_q) + cu_seqlens_kv = ( + packed_seq_params.cu_seqlens_kv_padded + if packed_seq_params.cu_seqlens_kv_padded is not None else packed_seq_params.cu_seqlens_kv) + else: + cu_seqlens_q = cu_seqlens_kv = None + + if q_pos_emb is not None: + if inference_context is None or inference_context.is_static_batching(): + query = apply_rotary_pos_emb( + query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q, **kwargs_cp) + else: + query = inference_context.apply_rotary_emb_query(query, q_pos_emb, self.config, cu_seqlens_q, + **kwargs_cp) + if k_pos_emb is not None: + key = apply_rotary_pos_emb(key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv, **kwargs_cp) + nvtx_range_pop(suffix='rotary_pos_emb') + + nvtx_range_push(suffix='core_attention') + if self.checkpoint_core_attention and self.training: + core_attn_out = self._checkpointed_attention_forward( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + 
attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) + else: + if inference_context is None or inference_context.is_static_batching(): + core_attn_out = self.core_attention( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) + else: + q, k, v = (query, key, value) + cu_query_lengths, max_seqlen_q = inference_context.cu_query_lengths() + cu_kv_lengths, kv_lengths, kv_lengths_decode_only, max_seqlen_k = (inference_context.cu_kv_lengths()) + core_attn_out = self.flash_decode_and_prefill( + q, + k, + v, + max_seqlen_q, + max_seqlen_k, + cu_query_lengths, + cu_kv_lengths, + kv_lengths, + kv_lengths_decode_only, + block_table, + ) + core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)') + + if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': + core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) + nvtx_range_pop(suffix='core_attention') + + core_attn_out = core_attn_out * torch.sigmoid(gate.reshape_as(core_attn_out)) + nvtx_range_push(suffix='linear_proj') + output, bias = self.linear_proj(core_attn_out) + nvtx_range_pop(suffix='linear_proj') + + return output, bias + + def get_query_key_value_tensors(self, hidden_states, key_value_states=None): + mixed_qkv, _ = self.linear_qkv(hidden_states) + + new_tensor_shape = mixed_qkv.size()[:-1] + ( + self.num_query_groups_per_partition, + ((self.num_attention_heads_per_partition // self.num_query_groups_per_partition * 2 + 2) + * self.hidden_size_per_attention_head), + ) + mixed_qkv = mixed_qkv.view(*new_tensor_shape) + split_arg_list = [ + (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + * self.hidden_size_per_attention_head * 2), + self.hidden_size_per_attention_head, + self.hidden_size_per_attention_head, + ] + + if SplitAlongDim is not None: + (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list) + else: + (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3) + + query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) + query, gate = query[:, :, ::2], query[:, :, 1::2] + if self.q_layernorm is not None: + query = self.q_layernorm(query) + if self.k_layernorm is not None: + key = self.k_layernorm(key) + + if self.config.test_mode: + self.run_realtime_tests() + + return query, key, value, gate + + +def _gated_delta_net_forward(self, hidden_states: torch.Tensor, **kwargs): + """Shared forward logic for all GatedDeltaNet variants.""" + args = get_args() + if args.sequence_parallel and args.tensor_model_parallel_size > 1: + hidden_states = gather_from_sequence_parallel_region(hidden_states) + seq_len = hidden_states.shape[0] + packed_seq_params = kwargs.get('packed_seq_params') + thd_format = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' + if thd_format and not getattr(args, 'packing', False): + new_hidden_states = hidden_states.new_zeros( + (packed_seq_params.num_samples, packed_seq_params.max_seqlen_q.item(), hidden_states.shape[-1])) + attention_mask = hidden_states.new_zeros((packed_seq_params.num_samples, packed_seq_params.max_seqlen_q.item()), + dtype=torch.bool) + cu_seqlens_q = packed_seq_params.cu_seqlens_q + for i in range(packed_seq_params.num_samples): + start, end = cu_seqlens_q[i], cu_seqlens_q[i + 1] + attention_mask[i, :end - start] = True + new_hidden_states[i, :end - start] = hidden_states[start:end, 0] + hidden_states = new_hidden_states + else: + 
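+ # The HF GatedDeltaNet module is batch-first, so convert from Megatron's [seq, batch, hidden] layout before delegating and transpose back afterwards.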
hidden_states = hidden_states.transpose(0, 1) + attention_mask = kwargs.get('attention_mask') + if attention_mask is not None: + attention_mask = (~attention_mask).sum(dim=(1, 2)) > 0 + res = super(type(self), self).forward(hidden_states=hidden_states, attention_mask=attention_mask) + if thd_format and not getattr(args, 'packing', False): + res = res[attention_mask][:, None] + res = torch.concat([res, res.new_zeros(seq_len - res.shape[0], 1, res.shape[2])]) + else: + res = res.transpose(0, 1) + if args.sequence_parallel and args.tensor_model_parallel_size > 1: + res = reduce_scatter_to_sequence_parallel_region(res) / args.tensor_model_parallel_size + return res, None + + +def _gated_delta_net_init(self, hf_cls, config, submodules, layer_number, **kwargs): + """Shared __init__ logic for all GatedDeltaNet variants.""" + assert config.context_parallel_size == 1, 'Qwen3-Next/Qwen3.5 currently does not support context parallel.' + hf_cls.__init__(self, config, layer_number) + self.config = config + extra_kwargs = _get_extra_te_kwargs(config) + self.to(dtype=extra_kwargs['params_dtype'], device=extra_kwargs['device']) + + +try: + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeGatedDeltaNet as _Qwen3_5MoeGatedDeltaNet +except ImportError: + _Qwen3_5MoeGatedDeltaNet = object + +try: + from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextGatedDeltaNet as _Qwen3NextGatedDeltaNet +except ImportError: + _Qwen3NextGatedDeltaNet = object + + +class Qwen3NextGatedDeltaNet(_HuggingFaceModule, _Qwen3NextGatedDeltaNet): + """GatedDeltaNet for linear attention layers in Qwen3-Next models.""" + + def __init__(self, config, submodules: SelfAttentionSubmodules, layer_number: int, **kwargs): + assert _Qwen3NextGatedDeltaNet is not object, 'please update the `transformers` version.' + _gated_delta_net_init(self, _Qwen3NextGatedDeltaNet, config, submodules, layer_number, **kwargs) + + def forward(self, hidden_states: torch.Tensor, **kwargs): + return _gated_delta_net_forward(self, hidden_states, **kwargs) + + +class Qwen3_5MoeGatedDeltaNet(_HuggingFaceModule, _Qwen3_5MoeGatedDeltaNet): + """GatedDeltaNet for Qwen3.5-MoE linear attention layers.""" + + def __init__(self, config, submodules: SelfAttentionSubmodules, layer_number: int, **kwargs): + assert _Qwen3_5MoeGatedDeltaNet is not object, 'please update the `transformers` version.' + _gated_delta_net_init(self, _Qwen3_5MoeGatedDeltaNet, config, submodules, layer_number, **kwargs) + + def forward(self, hidden_states: torch.Tensor, **kwargs): + return _gated_delta_net_forward(self, hidden_states, **kwargs) + + +def get_local_layer_specs(config, layer_specs, vp_stage=None): + """Get the layer specs for layers assigned to this pipeline stage. + + Mirrors swift.megatron.utils.get_local_layer_specs for distributing + heterogeneous layer specs across pipeline stages. + """ + from megatron.core import mpu + pp_size = mpu.get_pipeline_model_parallel_world_size() + pp_rank = mpu.get_pipeline_model_parallel_rank() + if pp_size <= 1: + return layer_specs + num_layers = len(layer_specs) + layers_per_stage = num_layers // pp_size + remainder = num_layers % pp_size + start = pp_rank * layers_per_stage + min(pp_rank, remainder) + if pp_rank < remainder: + layers_per_stage += 1 + return layer_specs[start:start + layers_per_stage] + + +def get_qwen3_next_layer_spec(config, args, gated_delta_net_cls): + """Build the heterogeneous transformer layer specs for Qwen3-Next/Qwen3.5. 
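+ Layers marked linear_attention swap the self-attention module for the GatedDeltaNet class, while full_attention layers keep a TE linear_qkv with the gated Qwen3NextSelfAttention; all norms use the zero-centered Qwen3NextRMSNorm.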
+ + Returns a TransformerBlockSubmodules with per-layer specs matching + the model's layer_types (linear_attention / full_attention). + """ + config.hetereogenous_dist_checkpoint = True + config.hidden_act = 'silu' + config.rms_norm_eps = config.layernorm_epsilon + config.dtype = args.params_dtype + + layer_norm_impl = Qwen3NextRMSNorm + kwargs = {'use_kitchen': config.use_kitchen} if hasattr(config, 'use_kitchen') and mcore_013 else {} + moe_layer_spec = get_gpt_layer_with_transformer_engine_spec( + num_experts=config.num_moe_experts, + moe_grouped_gemm=getattr(config, 'moe_grouped_gemm', True), + qk_layernorm=config.qk_layernorm, + multi_latent_attention=config.multi_latent_attention, + moe_use_legacy_grouped_gemm=getattr(config, 'moe_use_legacy_grouped_gemm', False), + **kwargs, + ) + layer_specs = [] + for layer_type in config.layer_types: + layer_spec = deepcopy(moe_layer_spec) + if layer_type == 'linear_attention': + layer_spec.submodules.self_attention.module = gated_delta_net_cls + elif layer_type == 'full_attention': + layer_spec.submodules.self_attention.submodules.linear_qkv = TEColumnParallelLinear + layer_spec.submodules.self_attention.module = Qwen3NextSelfAttention + layer_spec.submodules.input_layernorm = layer_norm_impl + if hasattr(layer_spec.submodules, + 'pre_mlp_layernorm') and layer_spec.submodules.pre_mlp_layernorm is not IdentityOp: + layer_spec.submodules.pre_mlp_layernorm = layer_norm_impl + if hasattr(layer_spec.submodules.self_attention.submodules, 'q_layernorm'): + layer_spec.submodules.self_attention.submodules.q_layernorm = layer_norm_impl + if hasattr(layer_spec.submodules.self_attention.submodules, 'k_layernorm'): + layer_spec.submodules.self_attention.submodules.k_layernorm = layer_norm_impl + if (getattr(config, 'moe_use_shared_expert_gate', False) and hasattr(layer_spec.submodules, 'mlp') + and hasattr(layer_spec.submodules.mlp.submodules, 'shared_experts')): + layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True} + layer_specs.append(layer_spec) + + local_layer_specs = get_local_layer_specs(config, layer_specs) + block_spec = TransformerBlockSubmodules(layer_specs=local_layer_specs, layer_norm=layer_norm_impl) + + return block_spec + + +def get_qwen3_next_mtp_block_spec(config, transformer_layer_spec, **kwargs): + """Build MTP block spec with Qwen3NextRMSNorm.""" + mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=True, **kwargs) + for layer_spec in mtp_block_spec.layer_specs: + layer_spec.submodules.enorm = Qwen3NextRMSNorm + layer_spec.submodules.hnorm = Qwen3NextRMSNorm + layer_spec.submodules.layer_norm = Qwen3NextRMSNorm + return mtp_block_spec + + +class Qwen3NextLoader(MegatronModelLoader): + """Loader for Qwen3-Next models with heterogeneous linear/full attention layers.""" + gated_delta_net = Qwen3NextGatedDeltaNet + + def post_config(self, config, args, mg_config_dict): + layer_types = mg_config_dict.get('layer_types') + if layer_types is not None: + config.layer_types = layer_types + for attr in ('linear_num_value_heads', 'linear_num_key_heads', 'linear_key_head_dim', + 'linear_value_head_dim', 'linear_conv_kernel_dim'): + val = mg_config_dict.get(attr) + if val is not None: + setattr(config, attr, val) + + def get_layer_spec(self, config, args, mg_config_dict): + return get_qwen3_next_layer_spec(config, args, self.gated_delta_net) + + def get_mtp_block_spec(self, config, layer_spec, **kwargs): + return get_qwen3_next_mtp_block_spec(config, layer_spec, **kwargs) diff --git 
a/src/twinkle/model/megatron/model/mm_gpt_model.py b/src/twinkle/model/megatron/model/mm_gpt_model.py index 4e2aa4d1..83a86ef5 100644 --- a/src/twinkle/model/megatron/model/mm_gpt_model.py +++ b/src/twinkle/model/megatron/model/mm_gpt_model.py @@ -82,7 +82,8 @@ def forward(_self, input_): if reduce_scatter_embeddings: res = res.transpose(0, 1).contiguous() group_kwargs = {'group': _self.tp_group} if mcore_013 else {} - res = reduce_scatter_to_sequence_parallel_region(res, **group_kwargs) / args.tensor_model_parallel_size + tp_size = mpu.get_tensor_model_parallel_world_size() + res = reduce_scatter_to_sequence_parallel_region(res, **group_kwargs) / tp_size return res VocabParallelEmbedding.forward = forward diff --git a/src/twinkle/model/megatron/model/mm_gpts/__init__.py b/src/twinkle/model/megatron/model/mm_gpts/__init__.py index 2cee28f6..30f10d89 100644 --- a/src/twinkle/model/megatron/model/mm_gpts/__init__.py +++ b/src/twinkle/model/megatron/model/mm_gpts/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -from . import qwen, qwen3_vl, utils +from . import qwen, qwen3_5, qwen3_vl, utils diff --git a/src/twinkle/model/megatron/model/mm_gpts/qwen3_5.py b/src/twinkle/model/megatron/model/mm_gpts/qwen3_5.py new file mode 100644 index 00000000..f0dec64f --- /dev/null +++ b/src/twinkle/model/megatron/model/mm_gpts/qwen3_5.py @@ -0,0 +1,166 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +# Reference: swift/swift/megatron/model/mm_gpts/qwen3_5.py +# Qwen3.5 / Qwen3.5-MoE multimodal model support for Megatron + +import torch +from PIL import Image + +from twinkle.model.megatron.args import get_args +from twinkle.utils.torch_utils import to_device +from ..constant import MegatronModelType, ModelType +from ..gpt_bridge import GPTBridge, MultimodalGPTBridge +from ..gpts.qwen3_next import Qwen3_5MoeGatedDeltaNet, Qwen3NextLoader +from ..register import MegatronModelMeta, register_megatron_model +from .utils import HuggingFaceModule + + +class Qwen3_5Vit(HuggingFaceModule): + """Vision module for Qwen3.5 / Qwen3.5-MoE models. + + Maps 'model.visual' from HF model to 'visual' in Megatron, + with merger as aligner. 
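+ When a batch contains no images or videos, a 32x32 black placeholder image is encoded and its embedding added to inputs_embeds with zero weight, keeping the vision tower's parameters in the autograd graph.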
+ """ + module_mapping = {'model.visual': 'visual'} + _vision_tower = ['visual'] + _aligner = ['visual.merger'] + + def __init__(self, config): + try: + from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5TextModel + except ImportError: + Qwen3_5TextModel = None + try: + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeTextModel + except ImportError: + Qwen3_5MoeTextModel = None + ignore_cls = [c for c in [Qwen3_5TextModel, Qwen3_5MoeTextModel] if c is not None] + super().__init__(config, ignore_cls) + + def get_inputs_embeds(self, inputs_embeds, **kwargs): + return self._get_inputs_embeds_hf(inputs_embeds, kwargs, self.visual, self.processor, self.model_config) + + def _get_inputs_embeds_hf(self, inputs_embeds, inputs, visual, processor, config): + input_ids = inputs['input_ids'] + pixel_values = inputs.get('pixel_values') + pixel_values_videos = inputs.get('pixel_values_videos') + image_grid_thw = inputs.get('image_grid_thw') + video_grid_thw = inputs.get('video_grid_thw') + dtype = visual.dtype + if pixel_values is None and pixel_values_videos is None: + images = [Image.new('RGB', (32, 32), (0, 0, 0))] + media_inputs = processor.image_processor(images=images, return_tensors='pt') + media_inputs = to_device(media_inputs, input_ids.device) + pixel_values = media_inputs['pixel_values'].type(dtype) + image_embeds = visual(pixel_values, grid_thw=media_inputs['image_grid_thw']) + if hasattr(image_embeds, 'pooler_output'): + image_embeds = image_embeds.pooler_output + inputs_embeds = inputs_embeds + image_embeds.mean().to(device=inputs_embeds.device) * 0. + else: + if pixel_values is None: + pixel_values_mixed = pixel_values_videos + grid_thw = video_grid_thw + elif pixel_values_videos is None: + pixel_values_mixed = pixel_values + grid_thw = image_grid_thw + else: + pixel_values_mixed = torch.concat([pixel_values, pixel_values_videos], dim=0) + grid_thw = torch.concat([image_grid_thw, video_grid_thw], dim=0) + pixel_values_mixed = pixel_values_mixed.type(dtype) + mixed_embeds = visual(pixel_values_mixed, grid_thw=grid_thw) + if hasattr(mixed_embeds, 'pooler_output'): + mixed_embeds = mixed_embeds.pooler_output + if pixel_values is None: + image_embeds = None + video_embeds = mixed_embeds + elif pixel_values_videos is None: + image_embeds = mixed_embeds + video_embeds = None + else: + merge_length = processor.image_processor.merge_size**2 + image_tokens = (image_grid_thw.prod(dim=-1) // merge_length).sum() + image_embeds = mixed_embeds[:image_tokens] + video_embeds = mixed_embeds[image_tokens:] + + if image_embeds is not None: + image_mask = (input_ids == config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + image_mask = image_mask.to(inputs_embeds.device) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if video_embeds is not None: + video_mask = (input_ids == config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + video_mask = video_mask.to(inputs_embeds.device) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + return inputs_embeds + + +class Qwen3_5Bridge(MultimodalGPTBridge): + """Bridge for Qwen3.5 multimodal models. + + Uses language_model prefix for the LLM backbone since Qwen3.5 has a + multimodal architecture with model.language_model.layers structure. 
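+ Following args.layer_types, linear-attention layers are converted under the HF 'linear_attn.' prefix and full-attention layers under 'self_attn.'.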
+ + Overrides _set_layer_attn to handle the mixed linear/full attention + architecture specific to Qwen3-Next/Qwen3.5. + """ + hf_layers_prefix = 'model.language_model.layers' + hf_embed_key = 'model.language_model.embed_tokens.weight' + hf_final_layernorm_key = 'model.language_model.norm.weight' + hf_mtp_prefix = 'mtp.layers' + + def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool): + args = self.args + layer_types = getattr(args, 'layer_types', None) + if layer_types is None: + return super()._set_layer_attn(mg_layer, hf_state_dict, layer_idx, to_mcore) + + layer_type = layer_types[layer_idx] if 0 <= layer_idx < len(layer_types) else 'full_attention' + mg_attn = None if mg_layer is None else mg_layer.self_attention + if layer_type == 'linear_attention': + hf_state_dict.update(self._set_module(mg_attn, hf_state_dict, 'linear_attn.', to_mcore)) + elif layer_type == 'full_attention': + hf_state_dict.update(self._set_attn_state(mg_attn, hf_state_dict, 'self_attn.', layer_idx, to_mcore)) + self._set_state_dict(mg_layer, 'input_layernorm.weight', hf_state_dict, 'input_layernorm.weight', to_mcore) + return hf_state_dict + + def _convert_mtp_extra(self, mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict): + hf_state_dict = self._remove_prefix(origin_hf_state_dict, 'mtp.') + for mg_key, key in zip(['enorm.weight', 'hnorm.weight', 'eh_proj.weight'], + ['pre_fc_norm_embedding.weight', 'pre_fc_norm_hidden.weight', 'fc.weight']): + self._set_state_dict(mtp_layer, mg_key, hf_state_dict, key, to_mcore) + self._set_state_dict(mtp_layer, 'final_layernorm.weight', hf_state_dict, 'norm.weight', to_mcore) + if not to_mcore: + origin_hf_state_dict.update(self._add_prefix(hf_state_dict, 'mtp.')) + + +try: + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeForConditionalGeneration +except ImportError: + Qwen3_5MoeForConditionalGeneration = None + +_auto_model_cls = Qwen3_5MoeForConditionalGeneration +if _auto_model_cls is None: + try: + from transformers import AutoModel + _auto_model_cls = AutoModel + except ImportError: + _auto_model_cls = None + + +class Qwen3_5MoeLoader(Qwen3NextLoader): + gated_delta_net = Qwen3_5MoeGatedDeltaNet + + +register_megatron_model( + MegatronModelMeta( + MegatronModelType.qwen3_5, + [ + ModelType.qwen3_5, + ModelType.qwen3_5_moe, + ], + bridge_cls=Qwen3_5Bridge, + visual_cls=Qwen3_5Vit, + auto_model_cls=_auto_model_cls, + loader=Qwen3_5MoeLoader, + )) diff --git a/src/twinkle/model/megatron/model/register.py b/src/twinkle/model/megatron/model/register.py index f7ef917d..ab59569d 100644 --- a/src/twinkle/model/megatron/model/register.py +++ b/src/twinkle/model/megatron/model/register.py @@ -1,8 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
import torch.nn as nn -from argparse import ArgumentParser from dataclasses import dataclass -from typing import Callable, List, Optional, Type +from typing import List, Optional, Type from .constant import MLLMMegatronModelType @@ -17,14 +16,9 @@ class MegatronModelMeta: is_multimodal: bool = False bridge_cls: Optional[Type] = None model_cls: Optional[Type[nn.Module]] = None - get_transformer_layer_spec: Optional[Callable] = None - model_provider: Optional[Callable[[], nn.Module]] = None visual_cls: Optional[Type[nn.Module]] = None - get_mtp_block_spec: Optional[Callable] = None - # AutoModel class for loading HF model (AutoModelForCausalLM for text, AutoModel for multimodal) auto_model_cls: Optional[Type] = None - - extra_args_provider: Optional[Callable[[ArgumentParser], ArgumentParser]] = None + loader: Optional[Type['MegatronModelLoader']] = None def __post_init__(self): if self.megatron_model_type in MLLMMegatronModelType.__dict__: @@ -39,11 +33,51 @@ def __post_init__(self): if self.auto_model_cls is None: from transformers import AutoModel, AutoModelForCausalLM self.auto_model_cls = AutoModel if self.is_multimodal else AutoModelForCausalLM + if self.loader is None: + self.loader = MegatronModelLoader + + +class MegatronModelLoader: + """Default loader that builds TransformerConfig + layer specs for a model. + + Subclass this to customize layer spec construction (e.g. heterogeneous + attention types, custom layer norms). Register the subclass via + ``MegatronModelMeta(loader=MyLoader)``. + """ + + def get_layer_spec(self, config, args, mg_config_dict): + """Build a transformer layer spec from *config* (``TransformerConfig``). + + The default implementation delegates to Megatron-Core's + ``get_gpt_layer_with_transformer_engine_spec``. + + Returns: + A ``ModuleSpec`` or ``TransformerBlockSubmodules`` instance. + """ + from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec + num_experts = mg_config_dict.get('num_experts', 0) or 0 + return get_gpt_layer_with_transformer_engine_spec( + num_experts=num_experts, + moe_grouped_gemm=num_experts > 0, + qk_layernorm=mg_config_dict.get('qk_layernorm', False), + ) + + def get_mtp_block_spec(self, config, layer_spec, **kwargs): + """Build MTP block spec. Override for custom layer norms etc.""" + from megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec + return get_gpt_mtp_block_spec(config, layer_spec, use_transformer_engine=True, **kwargs) + + def post_config(self, config, args, mg_config_dict): + """Hook called after TransformerConfig is created but before layer specs. + + Use this to set model-specific config attributes (e.g. ``layer_types``, + ``moe_use_shared_expert_gate``). 
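+ For example, Qwen3NextLoader.post_config copies layer_types and the linear_* GatedDeltaNet dimensions from the converted HF config onto the TransformerConfig.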
+ """ + pass def register_megatron_model(megatron_model_meta: MegatronModelMeta, *, exist_ok: bool = False): megatron_model_type = megatron_model_meta.megatron_model_type - # diff here if not exist_ok and megatron_model_type in MEGATRON_MODEL_MAPPING: raise ValueError(f'The `{megatron_model_type}` has already been registered in the MEGATRON_MODEL_MAPPING.') MEGATRON_MODEL_MAPPING[megatron_model_type] = megatron_model_meta diff --git a/src/twinkle/model/megatron/utils/config.py b/src/twinkle/model/megatron/utils/config.py index ef44b4b1..eca0edbb 100644 --- a/src/twinkle/model/megatron/utils/config.py +++ b/src/twinkle/model/megatron/utils/config.py @@ -37,6 +37,13 @@ 'qk_pos_emb_head_dim': ['qk_rope_head_dim'], 'moe_router_topk_scaling_factor': ['routed_scaling_factor'], 'qk_layernorm': ['use_qk_norm'], + # qwen3_next / qwen3_5 + 'linear_num_value_heads': ['linear_num_value_heads'], + 'linear_num_key_heads': ['linear_num_key_heads'], + 'linear_key_head_dim': ['linear_key_head_dim'], + 'linear_value_head_dim': ['linear_value_head_dim'], + 'linear_conv_kernel_dim': ['linear_conv_kernel_dim'], + 'full_attention_interval': ['full_attention_interval'], # other 'original_max_position_embeddings': ['original_max_position_embeddings'], 'partial_rotary_factor': ['partial_rotary_factor'], @@ -95,13 +102,14 @@ def convert_hf_config(config) -> Dict[str, Any]: interleave_moe_layer_step = res.pop('interleave_moe_layer_step', None) window_size = res.pop('window_size', None) rope_scaling = res.get('rope_scaling') or {} - if llm_model_type in {'qwen3', 'qwen3_moe', 'qwen3_next' - } or hf_model_type in {'qwen3_omni_moe', 'qwen3_omni', 'qwen3_vl', 'qwen3_vl_moe'}: + if llm_model_type in {'qwen3', 'qwen3_moe', 'qwen3_next'} or hf_model_type in { + 'qwen3_omni_moe', 'qwen3_omni', 'qwen3_vl', 'qwen3_vl_moe', 'qwen3_5', 'qwen3_5_moe' + }: res['qk_layernorm'] = True if llm_model_type in {'qwen2_moe', 'qwen3_moe', 'qwen3_next' - } or hf_model_type in {'qwen3_omni_moe', 'qwen3_vl_moe'}: + } or hf_model_type in {'qwen3_omni_moe', 'qwen3_vl_moe', 'qwen3_5_moe'}: res.pop('ffn_hidden_size', None) - if llm_model_type in {'qwen2_moe', 'qwen3_next'}: + if llm_model_type in {'qwen2_moe', 'qwen3_next'} or hf_model_type == 'qwen3_5_moe': res['use_shared_expert_gate'] = True if llm_model_type in { 'deepseek', @@ -145,8 +153,8 @@ def convert_hf_config(config) -> Dict[str, Any]: if llm_model_type == 'glm4_moe_lite': res['qk_layernorm'] = True res.pop('num_query_groups', None) - elif llm_model_type == 'qwen3_next': - full_attention_interval = res.pop('full_attention_interval') + elif llm_model_type == 'qwen3_next' or hf_model_type in {'qwen3_5', 'qwen3_5_moe'}: + full_attention_interval = res.pop('full_attention_interval', 4) num_layers = res['num_layers'] res['layer_types'] = [ 'full_attention' if (i + 1) % full_attention_interval == 0 else 'linear_attention' diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index b75603bb..fe9733f1 100644 --- a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -363,7 +363,7 @@ def collate_fn(self, micro_batch_size: Optional[int] = None, variable_seq_lengths=False, **kwargs) -> List[InputFeature]: - if len(inputs) == 1: + if len(inputs) == 1 and self.framework != 'megatron': return inputs if micro_batch_size is None: # normal collate diff --git a/src/twinkle/utils/__init__.py b/src/twinkle/utils/__init__.py index edcefc34..1d7f9028 100644 --- a/src/twinkle/utils/__init__.py +++ b/src/twinkle/utils/__init__.py @@ -11,6 +11,6 @@ from .platforms import 
GPU, NPU, Platform, ensure_hccl_socket_env, ensure_npu_backend from .safetensors import LazyTensor, SafetensorLazyLoader, StreamingSafetensorSaver from .torch_utils import pad_sequence_to_length, selective_log_softmax, stateless_init_process_group, to_device -from .transformers_utils import find_all_linears, find_layers, get_modules_to_not_convert, get_multimodal_target_regex +from .transformers_utils import find_all_linears, find_layers, get_modules_to_not_convert from .unsafe import check_unsafe, trust_remote_code from .utils import copy_files_by_pattern, deep_getattr diff --git a/src/twinkle/utils/transformers_utils.py b/src/twinkle/utils/transformers_utils.py index ee751c90..036f7538 100644 --- a/src/twinkle/utils/transformers_utils.py +++ b/src/twinkle/utils/transformers_utils.py @@ -71,54 +71,6 @@ def _cond(name, module): return find_layers(model, _cond, sub_module=sub_module) -def get_multimodal_target_regex( - model, - *, - freeze_llm: bool = False, - freeze_vit: bool = True, - freeze_aligner: bool = True, - include_embedding: bool = False, - exclude_router: bool = False, -) -> str: - import torch.nn as nn - model_arch = model.model_meta.model_arch - modules = [] - if not freeze_llm: - modules += model_arch.language_model - if not freeze_vit: - modules += model_arch.vision_tower - if not freeze_aligner: - modules += model_arch.aligner - assert len(modules) > 0, f'modules: {modules}' - - extra_layers = [] - if include_embedding: - extra_layers.append(nn.Embedding) - res = [] - for module in modules: - rejected_modules = [] - if not freeze_vit or not freeze_llm: - for aligner in model_arch.aligner: - if aligner.startswith(f'{module}.'): - rejected_modules.append(aligner) - - sub_module = deep_getattr(model, module) - if isinstance(sub_module, nn.Linear) and module.endswith('lm_head'): - target_modules = [] - else: - target_modules = find_all_linears(sub_module, model_arch, extra_layers) - if exclude_router and model.model_info.is_moe_model: - target_modules = [tm for tm in target_modules if tm not in {'gate'}] - if not target_modules: - continue - target_modules = [tm for tm in target_modules if tm] - target_pattern = rf'.*\.({"|".join(target_modules)})' if target_modules else '' - rejected_pattern = rf'(?!({"|".join(rejected_modules)}))' if rejected_modules else '' - res.append(rf'{rejected_pattern}{module}{target_pattern}') - - return rf'^({"|".join(res)})$' - - def get_modules_to_not_convert(model): if not hasattr(model, 'model_meta') or not hasattr(model, 'model_info'): return