diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py
index bad557093..73d1d473a 100644
--- a/atom/model_engine/model_runner.py
+++ b/atom/model_engine/model_runner.py
@@ -54,6 +54,7 @@
     "GptOssForCausalLM": "atom.models.gpt_oss.GptOssForCausalLM",
     "Glm4MoeForCausalLM": "atom.models.glm4_moe.Glm4MoeForCausalLM",
     "Qwen3NextForCausalLM": "atom.models.qwen3_next.Qwen3NextForCausalLM",
+    "MiniMaxM2ForCausalLM": "atom.models.minimax_m2.MiniMaxM2ForCausalLM",
 }
 # seed = 34567
 # np.random.seed(seed)
diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py
index d9dc1e34c..5aa414d0a 100644
--- a/atom/model_ops/moe.py
+++ b/atom/model_ops/moe.py
@@ -1904,10 +1904,6 @@ def __init__(
 
         self.use_chunked = get_dp_group().world_size > 1
 
-        if self.scoring_func != "softmax" and not self.use_grouped_topk:
-            raise ValueError(
-                "Only softmax scoring function is supported for " "non-grouped topk."
-            )
         moe = FusedMoEConfig(
             num_experts=self.global_num_experts,
             experts_per_token=self.top_k,
@@ -2078,21 +2074,34 @@ def _load_w13(
         # Index the loaded weight for tp sharding.
         # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
         shard_size = expert_data.shape[shard_dim] // 2
-        loaded_weight = loaded_weight.narrow(
-            shard_dim, shard_size * tp_rank, shard_size
+
+        # Calculate the per-rank shard size from the checkpoint tensor.
+        # loaded_weight is assumed to be the full (unsharded) tensor;
+        # the parameter shard may be padded wider than this.
+        original_shard_size = loaded_weight.shape[shard_dim] // self.tp_size
+        valid_shard_size = min(shard_size, original_shard_size)
+
+        # Load the valid part from loaded_weight.
+        loaded_shard = loaded_weight.narrow(
+            shard_dim, original_shard_size * tp_rank, valid_shard_size
         )
+
         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
         if shard_id == "w1":
-            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
+            expert_data_slice = expert_data.narrow(shard_dim, 0, shard_size)
         # w3, up_proj: Load into second logical weight of w13.
         else:
             assert shard_id == "w3"
-            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
+            expert_data_slice = expert_data.narrow(shard_dim, shard_size, shard_size)
+
+        # Determine the slice of expert_data to copy into (handles padding).
+        expert_data_valid = expert_data_slice.narrow(shard_dim, 0, valid_shard_size)
+
         if expert_data.dtype != dtypes.fp4x2:
-            expert_data.copy_(loaded_weight)
+            expert_data_valid.copy_(loaded_shard)
         else:
-            expert_data.view(torch.uint8).copy_(loaded_weight.view(torch.uint8))
+            expert_data_valid.view(torch.uint8).copy_(loaded_shard.view(torch.uint8))
 
     def _load_w2(
         self,
@@ -2107,15 +2116,28 @@ def _load_w2(
         # down_proj: "RowParallel" so tp sharding on input_dim
         # Narrow parameter and load.
         shard_size = expert_data.shape[shard_dim]
+
         if not load_full:
-            loaded_weight = loaded_weight.narrow(
-                shard_dim, shard_size * tp_rank, shard_size
+            original_shard_size = loaded_weight.shape[shard_dim] // self.tp_size
+            valid_shard_size = min(shard_size, original_shard_size)
+
+            loaded_shard = loaded_weight.narrow(
+                shard_dim, original_shard_size * tp_rank, valid_shard_size
             )
-        # w2, down_proj: Load into only logical weight of w2.
-        if expert_data.dtype != dtypes.fp4x2:
-            expert_data.copy_(loaded_weight)
+            expert_data_valid = expert_data.narrow(shard_dim, 0, valid_shard_size)
+
+            if expert_data.dtype != dtypes.fp4x2:
+                expert_data_valid.copy_(loaded_shard)
+            else:
+                expert_data_valid.view(torch.uint8).copy_(
+                    loaded_shard.view(torch.uint8)
+                )
         else:
-            expert_data.view(torch.uint8).copy_(loaded_weight.view(torch.uint8))
+            # Full load
+            if expert_data.dtype != dtypes.fp4x2:
+                expert_data.copy_(loaded_weight)
+            else:
+                expert_data.view(torch.uint8).copy_(loaded_weight.view(torch.uint8))
 
     def _load_single_value(
         self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int
@@ -2388,11 +2410,33 @@ def select_experts(
                 num_fused_shared_experts=num_fused_shared_experts,
             )
         else:
-            topk_weights, topk_ids = fused_topk(
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-            )
+            if scoring_func == "softmax":
+                topk_weights, topk_ids = fused_topk(
+                    gating_output=router_logits,
+                    topk=top_k,
+                    renormalize=renormalize,
+                )
+            elif scoring_func == "sigmoid":
+                routing_weights = torch.sigmoid(router_logits.float())
+                scores_for_choice = routing_weights
+                if e_score_correction_bias is not None:
+                    scores_for_choice = scores_for_choice + e_score_correction_bias
+
+                topk_ids = torch.topk(
+                    scores_for_choice, top_k, dim=-1, sorted=False
+                ).indices
+                topk_weights = routing_weights.gather(dim=-1, index=topk_ids)
+
+                if renormalize:
+                    topk_weights = topk_weights / topk_weights.sum(
+                        dim=-1, keepdim=True
+                    ).clamp_min(1e-20)
+
+                topk_ids = topk_ids.to(torch.int32)
+            else:
+                raise ValueError(
+                    f"Unsupported scoring function for non-grouped topk: {scoring_func}"
+                )
 
         return topk_weights, topk_ids
diff --git a/atom/models/minimax_m2.py b/atom/models/minimax_m2.py
new file mode 100644
index 000000000..2c956a5f5
--- /dev/null
+++ b/atom/models/minimax_m2.py
@@ -0,0 +1,431 @@
+from typing import Optional, Union
+
+import torch
+from aiter.dist.communication_op import tensor_model_parallel_all_reduce
+from aiter.dist.parallel_state import get_pp_group, get_tensor_model_parallel_world_size
+from aiter.rotary_embedding import get_rope
+from atom.config import Config, QuantizationConfig
+from atom.model_ops.base_attention import Attention
+from atom.model_ops.embed_head import ParallelLMHead, VocabParallelEmbedding
+from atom.model_ops.layernorm import RMSNorm
+from atom.model_ops.linear import QKVParallelLinear, ReplicatedLinear, RowParallelLinear
+from atom.model_ops.moe import FusedMoE
+from atom.models.utils import (
+    IntermediateTensors,
+    PPMissingLayer,
+    make_empty_intermediate_tensors_factory,
+    make_layers,
+    maybe_prefix,
+)
+from atom.utils import envs
+from atom.utils.decorators import support_torch_compile
+from torch import nn
+from transformers import PretrainedConfig
+
+ENABLE_ALLREDUCE_RMSNORM_FUSION = envs.ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION
+
+
+class MiniMaxM2SparseMoeBlock(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+        self.num_experts = config.num_local_experts
+        if self.tp_size > self.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {self.num_experts}."
+ ) + + if getattr(config, "use_routing_bias", False): + self.e_score_correction_bias = nn.Parameter( + torch.zeros(self.num_experts, dtype=torch.float32) + ) + else: + self.register_parameter("e_score_correction_bias", None) + + self.experts = FusedMoE( + num_experts=self.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + reduce_results=False, + renormalize=True, + quant_config=quant_config, + use_grouped_topk=False, + scoring_func=getattr(config, "scoring_func", "softmax"), + e_score_correction_bias=self.e_score_correction_bias, + prefix=f"{prefix}.experts", + has_bias=False, + config=config, + ) + + self.gate = ReplicatedLinear( + config.hidden_size, + self.num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + assert ( + hidden_states.dim() <= 2 + ), "MiniMaxM2SparseMoeBlock only supports 1D or 2D inputs" + is_input_1d = hidden_states.dim() == 1 + + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + router_logits = self.gate(hidden_states) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + + if self.tp_size > 1 and not ENABLE_ALLREDUCE_RMSNORM_FUSION: + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + + return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states + + +class MiniMaxM2Attention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int, + head_dim: int, + rotary_dim: int, + rms_norm_eps: float, + rope_theta: float, + rope_scaling: tuple | None, + qkv_bias: bool, + kv_cache_dtype: str, + layer_num: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_qk_norm: bool = True, + ) -> None: + super().__init__() + tp_size = get_tensor_model_parallel_world_size() + + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=prefix, + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=not ENABLE_ALLREDUCE_RMSNORM_FUSION, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=rotary_dim, + max_position=max_position, + base=rope_theta, + rope_scaling=rope_scaling, + ) + + self.use_qk_norm = use_qk_norm + if self.use_qk_norm: + self.q_norm = RMSNorm(self.q_size, eps=rms_norm_eps) + self.k_norm = RMSNorm(self.kv_size, eps=rms_norm_eps) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + self.num_kv_heads, + kv_cache_dtype=kv_cache_dtype, + layer_num=layer_num, + use_mla=False, + rotary_emb=self.rotary_emb, + ) + + def forward( + self, + positions: torch.Tensor, + 
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        qkv = self.qkv_proj(hidden_states)
+        q, k, v = torch.split(qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        if self.use_qk_norm:
+            q = self.q_norm(q)
+            k = self.k_norm(k)
+
+        attn_output = self.attn(q, k, v, positions)
+        output = self.o_proj(attn_output)
+        return output
+
+
+class MiniMaxM2DecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        prefix: str,
+        cache_config: str = "bf16",
+        quant_config: Optional[QuantizationConfig] = None,
+        layer_num: int = 0,
+    ) -> None:
+        super().__init__()
+
+        self.layer_idx = layer_num
+        self.hidden_size = config.hidden_size
+
+        self.self_attn = MiniMaxM2Attention(
+            hidden_size=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            max_position=getattr(config, "max_position_embeddings", 8192),
+            head_dim=getattr(
+                config, "head_dim", config.hidden_size // config.num_attention_heads
+            ),
+            rotary_dim=getattr(
+                config,
+                "rotary_dim",
+                getattr(
+                    config, "head_dim", config.hidden_size // config.num_attention_heads
+                ),
+            ),
+            rms_norm_eps=config.rms_norm_eps,
+            rope_theta=getattr(config, "rope_theta", 10000),
+            rope_scaling=getattr(config, "rope_scaling", None),
+            qkv_bias=bool(getattr(config, "attention_bias", False)),
+            kv_cache_dtype=cache_config,
+            layer_num=layer_num,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+            use_qk_norm=getattr(config, "use_qk_norm", True),
+        )
+
+        self.block_sparse_moe = MiniMaxM2SparseMoeBlock(
+            config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.block_sparse_moe",
+        )
+
+        self.input_layernorm = RMSNorm(
+            config.hidden_size,
+            eps=config.rms_norm_eps,
+            fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION and self.layer_idx > 0,
+        )
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size,
+            eps=config.rms_norm_eps,
+            fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)
+
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.block_sparse_moe(hidden_states)
+
+        return hidden_states, residual
+
+
+@support_torch_compile
+class MiniMaxM2Model(nn.Module):
+    def __init__(
+        self,
+        atom_config: Config,
+        prefix: str = "",
+        layer_type: type[nn.Module] = MiniMaxM2DecoderLayer,
+    ):
+        super().__init__()
+
+        config = atom_config.hf_config
+        cache_config = atom_config.kv_cache_dtype
+        quant_config = atom_config.quant_config
+        self.config = config
+
+        if get_pp_group().is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix, layer_num=None: layer_type(
+                config,
+                prefix,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                layer_num=layer_num,
+            ),
+            prefix=f"{prefix}.layers",
+            layer_num_offset=0,
+        )
+
+        if get_pp_group().is_last_rank:
+            self.norm = RMSNorm(
+                config.hidden_size,
+                eps=config.rms_norm_eps,
+                fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION,
+            )
+        else:
+            self.norm = PPMissingLayer()
+
+        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
+            ["hidden_states", "residual"], config.hidden_size
+        )
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: IntermediateTensors | None = None,
+        inputs_embeds: torch.Tensor | None = None,
+    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+
+        for layer in self.layers[self.start_layer : self.end_layer]:
+            hidden_states, residual = layer(positions, hidden_states, residual)
+
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors(
+                {"hidden_states": hidden_states, "residual": residual}
+            )
+
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="w1",
+            ckpt_down_proj_name="w2",
+            ckpt_up_proj_name="w3",
+            num_experts=self.config.num_local_experts,
+        )
+
+
+class MiniMaxM2ForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "q_proj": ("qkv_proj", "q"),
+        "k_proj": ("qkv_proj", "k"),
+        "v_proj": ("qkv_proj", "v"),
+    }
+
+    def __init__(
+        self,
+        atom_config: Config,
+        prefix: str = "",
+        layer_type: type[nn.Module] = MiniMaxM2DecoderLayer,
+    ):
+        super().__init__()
+        config = atom_config.hf_config
+        self.config = config
+
+        self.model = MiniMaxM2Model(
+            atom_config=atom_config,
+            prefix=maybe_prefix(prefix, "model"),
+            layer_type=layer_type,
+        )
+
+        if get_pp_group().is_last_rank:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                prefix=maybe_prefix(prefix, "lm_head"),
+            )
+        else:
+            self.lm_head = PPMissingLayer()
+
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
+
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors
+        )
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: IntermediateTensors | None = None,
+        inputs_embeds: torch.Tensor | None = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)
+
+    def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
+        return self.lm_head(hidden_states)
+
+    def make_empty_intermediate_tensors(
+        self, batch_size: int, dtype: torch.dtype, device: torch.device
+    ) -> IntermediateTensors:
+        return IntermediateTensors(
+            {
+                "hidden_states": torch.zeros(
+                    (batch_size, self.config.hidden_size), dtype=dtype, device=device
+                ),
+                "residual": torch.zeros(
+                    (batch_size, self.config.hidden_size), dtype=dtype, device=device
+                ),
+            }
+        )
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
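For reference, a minimal standalone sketch of the sigmoid-scored top-k routing that the select_experts change above introduces and that MiniMaxM2SparseMoeBlock relies on when the model config selects scoring_func="sigmoid" with a routing bias. The function name and the toy shapes are illustrative only, not part of the atom API:

import torch

def sigmoid_topk(router_logits: torch.Tensor, top_k: int,
                 bias: torch.Tensor | None = None, renormalize: bool = True):
    # Score each expert independently with a sigmoid (no softmax coupling).
    routing_weights = torch.sigmoid(router_logits.float())      # [tokens, experts]
    # The correction bias only influences which experts are picked ...
    scores_for_choice = routing_weights if bias is None else routing_weights + bias
    topk_ids = torch.topk(scores_for_choice, top_k, dim=-1, sorted=False).indices
    # ... while the mixing weights come from the unbiased sigmoid scores.
    topk_weights = routing_weights.gather(dim=-1, index=topk_ids)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True).clamp_min(1e-20)
    return topk_weights, topk_ids.to(torch.int32)

logits = torch.randn(4, 8)                  # 4 tokens, 8 experts
weights, ids = sigmoid_topk(logits, top_k=2, bias=torch.zeros(8))
print(weights.shape, ids.shape)             # torch.Size([4, 2]) torch.Size([4, 2])

Renormalizing makes the selected weights sum to one per token, matching what the fused softmax path produces after its own renormalization.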
diff --git a/atom/models/qwen3_moe.py b/atom/models/qwen3_moe.py
index 8da45b41e..1a05b0707 100644
--- a/atom/models/qwen3_moe.py
+++ b/atom/models/qwen3_moe.py
@@ -236,9 +236,13 @@ def forward(
         if ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION:
             attn_output = self.attn(q, k, v, positions, None, qkv)
         else:
-            # Add qk-norm
-            q = self.q_norm(q)
-            k = self.k_norm(k)
+            # Add qk-norm (per-head)
+            q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view(
+                -1, self.num_heads * self.head_dim
+            )
+            k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view(
+                -1, self.num_kv_heads * self.head_dim
+            )
             attn_output = self.attn(q, k, v, positions)
 
         output = self.o_proj(attn_output)
diff --git a/tests/test_mxfp4_moe_has_bias.py b/tests/test_mxfp4_moe_has_bias.py
new file mode 100644
index 000000000..daa960a28
--- /dev/null
+++ b/tests/test_mxfp4_moe_has_bias.py
@@ -0,0 +1,262 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+"""
+Regression test for the MXFP4 MoE uninitialized bias bug.
+
+Root cause:
+    FusedMoE defaulted has_bias=True, but Qwen3MoE experts have no bias
+    in the checkpoint. Mxfp4MoEMethod.create_weights allocated bias
+    parameters with torch.empty() that never got loaded, causing the
+    kernel to add garbage bias to every expert output.
+
+Fix:
+    - FusedMoE default changed to has_bias=False
+    - Qwen3MoeSparseMoeBlock and Qwen3NextSparseMoeBlock explicitly
+      pass has_bias=False
+"""
+
+import sys
+import unittest
+
+# Clear cached atom modules (conftest.py stubs)
+for mod_name in list(sys.modules):
+    if mod_name.startswith("atom"):
+        del sys.modules[mod_name]
+
+
+class TestFusedMoEDefaultHasBias(unittest.TestCase):
+    """FusedMoE must default to has_bias=False."""
+
+    def test_default_is_false(self):
+        import inspect
+        from atom.model_ops.moe import FusedMoE
+
+        sig = inspect.signature(FusedMoE.__init__)
+        default = sig.parameters["has_bias"].default
+        self.assertFalse(
+            default,
+            "FusedMoE default has_bias must be False to prevent "
+            "uninitialized bias when checkpoint has no expert bias",
+        )
+
+
+class TestQwen3MoeExplicitHasBias(unittest.TestCase):
+    """Qwen3 MoE models must explicitly pass has_bias=False."""
+
+    def _check_source_has_bias_false(self, module_path: str, class_name: str):
+        import importlib
+        import inspect
+
+        mod = importlib.import_module(module_path)
+        cls = getattr(mod, class_name)
+        source = inspect.getsource(cls.__init__)
+        self.assertIn(
+            "has_bias=False",
+            source,
+            f"{class_name} must pass has_bias=False to FusedMoE",
+        )
+
+    def test_qwen3_moe_sparse_block(self):
+        self._check_source_has_bias_false(
+            "atom.models.qwen3_moe", "Qwen3MoeSparseMoeBlock"
+        )
+
+    def test_qwen3_next_sparse_block(self):
+        self._check_source_has_bias_false(
+            "atom.models.qwen3_next", "Qwen3NextSparseMoeBlock"
+        )
+
+
+class TestGptOssKeepsBias(unittest.TestCase):
+    """gpt_oss explicitly uses has_bias=True and must not be affected."""
+
+    def test_gpt_oss_has_bias_true(self):
+        import inspect
+        from atom.models.gpt_oss import MLPBlock as SparseMoeBlock
+
+        source = inspect.getsource(SparseMoeBlock.__init__)
+        self.assertIn(
+            "has_bias=True",
+            source,
+            "gpt_oss SparseMoeBlock must keep has_bias=True",
+        )
+
+
+class TestMxfp4NoBiasCreated(unittest.TestCase):
+    """When has_bias=False, Mxfp4MoEMethod must not create bias parameters."""
+
+    def test_no_bias_when_has_bias_false(self):
+        import torch
+        from unittest.mock import MagicMock
+
+        from atom.model_ops.moe import Mxfp4MoEMethod
+        from atom.config import QuantizationConfig
+        from aiter import QuantType
+
+        qc = QuantizationConfig(
+            quant_type=QuantType.per_1x32,
+            quant_dtype=torch.float4_e2m1fn_x2,
+            quant_method="quark",
+        )
+        moe_config = MagicMock()
+        method = Mxfp4MoEMethod(qc, moe_config)
+
+        # Create a mock layer with has_bias=False
+        layer = MagicMock()
+        layer.has_bias = False
+        layer.hidden_size = 6144
+        layer.intermediate_size_per_partition = 2560
+        layer.activation = "silu"
+
+        # Track what register_parameter is called with
+        registered = {}
+
+        def mock_register(name, param):
+            registered[name] = param
+
+        layer.register_parameter = mock_register
+
+        method.create_weights(
+            layer=layer,
+            num_experts=8,
+            hidden_size=6144,
+            intermediate_size_per_partition=2560,
+            params_dtype=torch.float4_e2m1fn_x2,
+            weight_loader=lambda *a: None,
+        )
+
+        # Bias should be None when has_bias=False
+        self.assertIsNone(
+            registered.get("w13_bias"),
+            "w13_bias must be None when has_bias=False",
+        )
+        self.assertIsNone(
+            registered.get("w2_bias"),
+            "w2_bias must be None when has_bias=False",
+        )
+
+    def test_bias_created_when_has_bias_true(self):
+        import torch
+        from unittest.mock import MagicMock
+
+        from atom.model_ops.moe import Mxfp4MoEMethod
+        from atom.config import QuantizationConfig
+        from aiter import QuantType
+
+        qc = QuantizationConfig(
+            quant_type=QuantType.per_1x32,
+            quant_dtype=torch.float4_e2m1fn_x2,
+            quant_method="quark",
+        )
+        moe_config = MagicMock()
+        method = Mxfp4MoEMethod(qc, moe_config)
+
+        # Create a mock layer with has_bias=True
+        layer = MagicMock()
+        layer.has_bias = True
+        layer.hidden_size = 6144
+        layer.intermediate_size_per_partition = 2560
+        layer.activation = "silu"
+
+        registered = {}
+
+        def mock_register(name, param):
+            registered[name] = param
+
+        layer.register_parameter = mock_register
+
+        method.create_weights(
+            layer=layer,
+            num_experts=8,
+            hidden_size=6144,
+            intermediate_size_per_partition=2560,
+            params_dtype=torch.float4_e2m1fn_x2,
+            weight_loader=lambda *a: None,
+        )
+
+        # Bias should be a Parameter when has_bias=True
+        self.assertIsNotNone(registered.get("w13_bias"))
+        self.assertIsInstance(registered["w13_bias"], torch.nn.Parameter)
+        self.assertIsNotNone(registered.get("w2_bias"))
+        self.assertIsInstance(registered["w2_bias"], torch.nn.Parameter)
+
+
+class TestSwiGLUInterleavingWithoutBias(unittest.TestCase):
+    """SwiGLU weight interleaving must happen regardless of has_bias.
+
+    Root cause:
+        process_weights_after_loading guarded the SwiGLU interleaving branch
+        on ``layer.w13_bias is not None``. When has_bias=False (no bias),
+        it fell through to the generic 'else' branch that uses different
+        shuffling functions (shuffle_weights + e8m0_shuffle) which do NOT
+        interleave gate/up weights. The a16w4 kernel still expects
+        interleaved weights -> garbage output.
+
+    Fix:
+        Change condition from
+        ``layer.activation == ActivationType.Swiglu and layer.w13_bias is not None``
+        to
+        ``layer.activation == ActivationType.Swiglu``
+        and guard only the bias interleaving on ``layer.w13_bias is not None``.
+    """
+
+    def test_swiglu_branch_condition_no_bias_check(self):
+        """The SwiGLU branch must NOT require bias to be present."""
+        import inspect
+        from atom.model_ops.moe import Mxfp4MoEMethod
+
+        source = inspect.getsource(Mxfp4MoEMethod.process_weights_after_loading)
+
bias" + self.assertIn( + "layer.activation == ActivationType.Swiglu:", + source.replace("\n", ""), + "SwiGLU branch must trigger on activation type alone, " + "not conditionally on bias presence", + ) + + # Bias interleaving should be guarded separately + self.assertIn( + "if layer.w13_bias is not None:", + source, + "Bias interleaving should be a separate conditional inside " + "the SwiGLU branch", + ) + + def test_swiglu_branch_does_not_couple_bias_and_shuffle(self): + """Ensure the old coupled condition is gone.""" + import inspect + from atom.model_ops.moe import Mxfp4MoEMethod + + source = inspect.getsource(Mxfp4MoEMethod.process_weights_after_loading) + + self.assertNotIn( + "Swiglu and layer.w13_bias is not None", + source, + "Old coupled condition (Swiglu AND bias) must be removed", + ) + + +class TestQwen3MoeQKNormShape(unittest.TestCase): + """Qwen3MoeAttention must apply q/k norm per-head, not on flattened vectors.""" + + def test_qk_norm_is_per_head(self): + import inspect + from atom.models.qwen3_moe import Qwen3MoeAttention + + source = inspect.getsource(Qwen3MoeAttention.forward) + self.assertIn( + "self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view", + source, + "q_norm must reshape q to [tokens, num_heads, head_dim] before RMSNorm", + ) + self.assertIn( + "self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view", + source, + "k_norm must reshape k to [tokens, num_kv_heads, head_dim] before RMSNorm", + ) + + +if __name__ == "__main__": + unittest.main()