diff --git a/atom/config.py b/atom/config.py index 8ef9be07..a813765f 100644 --- a/atom/config.py +++ b/atom/config.py @@ -255,6 +255,7 @@ def __init__( quant_name="", quant_method=None, exclude_layers: Optional[list[str]] = None, + packed_components: Optional[dict[str, list[str]]] = None, ): super().__init__() self["quant_type"] = quant_type if quant_type is not None else QuantType.No @@ -263,6 +264,9 @@ def __init__( self["is_dynamic"] = is_dynamic self["quant_method"] = quant_method self["exclude_layers"] = exclude_layers if exclude_layers is not None else [] + self["packed_components"] = ( + packed_components if packed_components is not None else {} + ) def get_name(self): return self["quant_name"] diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index bad55709..9b051f40 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -23,6 +23,7 @@ from atom.model_engine.scheduler import ScheduledBatch, ScheduledBatchOutput from atom.model_engine.sequence import Sequence, SequenceStatus, SequenceType from atom.model_loader.loader import load_model +from atom.models.utils import build_packed_components_mapping from atom.model_ops.rejection_sampler import RejectionSampler from atom.model_ops.sampler import Sampler from atom.spec_decode.eagle import EagleProposer @@ -563,6 +564,7 @@ def __init__(self, rank: int, config: Config): ) model_class = resolve_obj_by_qualname(support_model_arch_dict[hf_config.architectures[0]]) # type: ignore + self.build_inverse_mapping(model_class) self.model = model_class(config) torch.set_default_device(None) load_model(self.model, config.model, config.hf_config, config.load_dummy) @@ -585,6 +587,16 @@ def __init__(self, rank: int, config: Config): if self.config.compilation_config.level == 1: self.model = torch.compile(self.model, fullgraph=True, backend="eager") + def build_inverse_mapping(self, model_class: Any): + # Build inverse mapping from the model class's packed_modules_mapping + # BEFORE instantiation, so that get_quant_config_for_layer can resolve + # packed names (e.g. "gate_up_proj") during layer construction. + packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {}) + if packed_modules_mapping and self.config.quant_config.get("exclude_layers"): + self.config.quant_config["packed_components"] = ( + build_packed_components_mapping(packed_modules_mapping) + ) + def is_deepseek_mla(self) -> bool: if not hasattr(self.hf_text_config, "model_type"): return False diff --git a/atom/model_ops/fused_moe_triton.py b/atom/model_ops/fused_moe_triton.py index 01d83e6e..0f4ece52 100644 --- a/atom/model_ops/fused_moe_triton.py +++ b/atom/model_ops/fused_moe_triton.py @@ -68,6 +68,56 @@ def _swizzle_mxfp4(quant_tensor, scale): return quant_tensor, InFlexData(), scale +def routing_from_topk(topk_weights, topk_ids, n_expts_tot): + """Convert FusedMoE.select_experts output to triton routing data structures. + + This bridges the gap between ATOM's grouped topk / sigmoid routing + (which triton_kernels routing() does not support) and the triton + matmul_ogs compute kernels. + + Args: + topk_weights: (n_tokens, n_expts_act) routing weights from select_experts + topk_ids: (n_tokens, n_expts_act) expert indices from select_experts + n_expts_tot: total number of experts (global, before EP) + + Returns: + (RoutingData, GatherIndx, ScatterIndx) compatible with triton_kernel_fused_experts + """ + from triton_kernels.routing import ( + RoutingData, + GatherIndx, + ScatterIndx, + compute_expt_data, + ) + + n_tokens, n_expts_act = topk_weights.shape + n_gates_pad = n_tokens * n_expts_act + + # Sort each token's selected experts by expert_id (required by triton kernels) + expt_indx_sorted, sort_indices = torch.sort(topk_ids.int(), dim=1) + expt_scal_sorted = torch.gather(topk_weights, 1, sort_indices.long()) + + # Flatten to 1D + expt_scal = expt_scal_sorted.reshape(-1).to(topk_weights.dtype) + expt_indx = expt_indx_sorted.reshape(-1).to(torch.int32) + + # Sort by expert_id globally so experts are contiguous for the matmul + topk_indx = torch.argsort(expt_indx, stable=True).int() + gate_indx = torch.argsort(topk_indx, stable=True).int() + gate_scal = expt_scal[topk_indx.long()] + + # Histogram of tokens over experts + hist = torch.histc(expt_indx.float(), bins=n_expts_tot, max=n_expts_tot - 1).int() + + # Build routing data structures using triton-accelerated compute_expt_data + gather_indx = GatherIndx(src_indx=topk_indx, dst_indx=gate_indx) + scatter_indx = ScatterIndx(src_indx=gate_indx, dst_indx=topk_indx) + expt_data = compute_expt_data(hist, n_expts_tot, n_gates_pad) + + routing_data = RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data) + return routing_data, gather_indx, scatter_indx + + def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor: """ Shrink the given tensor and apply the given view to it. This is @@ -161,26 +211,37 @@ def triton_kernel_fused_experts( if global_num_experts == -1: global_num_experts = E + half_N = N // 2 + if intermediate_cache is None: intermediate_cache = torch.empty( - (batch_dim, M * topk, N // 2), + (batch_dim, M * topk, half_N), device=hidden_states.device, dtype=hidden_states.dtype, ) # Add batch_dim to output buffer because matmul_ogs expects 3D output intermediate_cache = _resize_cache( - intermediate_cache, (batch_dim, M * topk, N // 2) + intermediate_cache, (batch_dim, M * topk, half_N) ) output_tensor = _resize_cache(output_tensor, (batch_dim, M, K)) - act = FusedActivation( - FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")), - (swiglu_alpha, swiglu_limit), - 2, - ) gammas = routing_data.gate_scal if routing_data else None + # NOTE: We intentionally do NOT use the triton fused SwiGLU activation + # because it expects interleaved [gate0, up0, gate1, up1, ...] layout + # while our w13 weights produce concatenated [gate | up] output. + # It also uses a non-standard formula: s*sigmoid(alpha*s)*(linear+1) + # with alpha=1.702, which differs from the standard SiLU activation + # (x*sigmoid(x)*up) used by most MoE models. + # Instead, we compute the matmul without fused activation and apply + # standard silu(gate) * up manually. + raw_intermediate = torch.empty( + (batch_dim, M * topk, N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + matmul_ogs( hidden_states, w1, @@ -189,12 +250,17 @@ def triton_kernel_fused_experts( gather_indx=gather_indx, precision_config=w13_precision_config, gammas=gammas if apply_router_weight_on_input else None, - fused_activation=act, - y=intermediate_cache, + y=raw_intermediate, ) + # Standard SiLU/SwiGLU activation: silu(gate) * up + raw_2d = raw_intermediate.view(M * topk, N) + gate = raw_2d[:, :half_N] + up = raw_2d[:, half_N:] + intermediate_cache[0] = torch.nn.functional.silu(gate) * up + matmul_ogs( - intermediate_cache.view(M * topk, N // 2), + intermediate_cache.view(M * topk, half_N), w2, w2_bias, routing_data, @@ -203,5 +269,6 @@ def triton_kernel_fused_experts( gammas=None if apply_router_weight_on_input else gammas, y=output_tensor, ) + output_tensor = output_tensor.view(M, K) return output_tensor diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index 1bd3538a..90e7df97 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -476,9 +476,12 @@ def __init__( bias: bool = False, quant_config: Optional[QuantizationConfig] = None, source_quant_dtype: torch.dtype = None, + prefix: str = "", **kwargs, ): self.tp_dim = 0 + if quant_config is not None and prefix: + quant_config = get_quant_config_for_layer(quant_config, prefix) super().__init__( input_size, output_size, @@ -551,6 +554,7 @@ def __init__( bias: bool = False, quant_config: Optional[QuantizationConfig] = None, source_quant_dtype: torch.dtype = None, + prefix: str = "", **kwargs, ): self.head_size = head_size @@ -582,6 +586,7 @@ def __init__( bias=bias, quant_config=quant_config, source_quant_dtype=source_quant_dtype, + prefix=prefix, ) def weight_loader( diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index d9dc1e34..b2520c85 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -871,7 +871,68 @@ def apply( activation: ActivationType = ActivationType.Silu, ) -> torch.Tensor: if self.use_triton: - from atom.model_ops.fused_moe_triton import triton_kernel_moe_forward + from atom.model_ops.fused_moe_triton import ( + triton_kernel_moe_forward, + triton_kernel_fused_experts, + routing_from_topk, + ) + + # Check if the model needs custom routing that triton routing() + # does not support (grouped topk, sigmoid scoring, bias correction). + needs_custom_routing = ( + use_grouped_topk + or scoring_func != "softmax" + or e_score_correction_bias is not None + or custom_routing_function is not None + ) + + if needs_custom_routing: + # Use ATOM's full-featured select_experts for routing, + # then triton matmul_ogs for the actual MoE computation. + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + num_fused_shared_experts=layer.num_fused_shared_experts, + routed_scaling_factor=layer.routed_scaling_factor, + ) + + # Convert to triton routing data structures + n_expts_tot = router_logits.shape[-1] + if global_num_experts > 0: + n_expts_tot = global_num_experts + + routing_data, gather_idx, scatter_idx = routing_from_topk( + topk_weights, topk_ids, n_expts_tot + ) + + output = torch.empty_like(x) + _moe_result = triton_kernel_fused_experts( + output, + x, + layer.w13_weight, + layer.w2_weight, + routing_data, + gather_idx, + scatter_idx, + topk=top_k, + activation=activation, + w13_precision_config=self.w13_precision_config, + w2_precision_config=self.w2_precision_config, + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) + return _moe_result return triton_kernel_moe_forward( x, @@ -2077,18 +2138,27 @@ def _load_w13( # Index the loaded weight for tp sharding. # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim - shard_size = expert_data.shape[shard_dim] // 2 + expert_shard_size = expert_data.shape[shard_dim] // 2 + # Derive shard size from loaded_weight (unpadded checkpoint) to avoid + # out-of-bounds when expert_data is padded (e.g. MXFP4 alignment). + load_shard_size = loaded_weight.shape[shard_dim] // self.tp_size loaded_weight = loaded_weight.narrow( - shard_dim, shard_size * tp_rank, shard_size + shard_dim, load_shard_size * tp_rank, load_shard_size ) # Narrow parameter and load. # w1, gate_proj: Load into first logical weight of w13. if shard_id == "w1": - expert_data = expert_data.narrow(shard_dim, 0, shard_size) + expert_data = expert_data.narrow(shard_dim, 0, expert_shard_size) # w3, up_proj: Load into second logical weight of w13. else: assert shard_id == "w3" - expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) + expert_data = expert_data.narrow( + shard_dim, expert_shard_size, expert_shard_size + ) + # When expert_data is padded beyond the actual weight size, narrow to + # the loaded weight size so the copy shape matches. + if load_shard_size != expert_shard_size: + expert_data = expert_data.narrow(shard_dim, 0, load_shard_size) if expert_data.dtype != dtypes.fp4x2: expert_data.copy_(loaded_weight) else: @@ -2108,9 +2178,14 @@ def _load_w2( # Narrow parameter and load. shard_size = expert_data.shape[shard_dim] if not load_full: + # Derive shard size from loaded_weight (unpadded checkpoint) to + # avoid out-of-bounds when expert_data is padded (e.g. MXFP4). + load_shard_size = loaded_weight.shape[shard_dim] // self.tp_size loaded_weight = loaded_weight.narrow( - shard_dim, shard_size * tp_rank, shard_size + shard_dim, load_shard_size * tp_rank, load_shard_size ) + if load_shard_size != shard_size: + expert_data = expert_data.narrow(shard_dim, 0, load_shard_size) # w2, down_proj: Load into only logical weight of w2. if expert_data.dtype != dtypes.fp4x2: expert_data.copy_(loaded_weight) diff --git a/atom/models/glm4_moe.py b/atom/models/glm4_moe.py index 31fe1d24..73128138 100644 --- a/atom/models/glm4_moe.py +++ b/atom/models/glm4_moe.py @@ -152,6 +152,7 @@ def __init__( prefix=f"{prefix}.experts", scoring_func="sigmoid", e_score_correction_bias=self.gate.e_score_correction_bias, + has_bias=getattr(config, "moe_ffn_bias", False), config=config, ) diff --git a/atom/models/utils.py b/atom/models/utils.py index 60334d78..d55ffe06 100644 --- a/atom/models/utils.py +++ b/atom/models/utils.py @@ -237,6 +237,38 @@ def fast_topk(values, topk, dim): return torch.topk(values, topk, dim=dim) +def build_packed_components_mapping( + packed_modules_mapping: dict[str, tuple[str, object]], +) -> dict[str, list[str]]: + """Build an inverse mapping from packed parameter names to their original + checkpoint weight names. + + Args: + packed_modules_mapping: Model's mapping from checkpoint weight name to + (packed_param_name, shard_id), e.g.:: + + { + "q_proj": ("qkv_proj", "q"), + "k_proj": ("qkv_proj", "k"), + "v_proj": ("qkv_proj", "v"), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + + Returns: + Inverse mapping from packed name to list of checkpoint names, e.g.:: + + { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } + """ + inverse: dict[str, list[str]] = {} + for ckpt_name, (packed_name, _shard_id) in packed_modules_mapping.items(): + inverse.setdefault(packed_name, []).append(ckpt_name) + return inverse + + def should_ignore_layer( quantization_config: Optional[QuantizationConfig], prefix: str ) -> bool: @@ -259,6 +291,26 @@ def should_ignore_layer( # case "lm_head". Common practice won't quant lm_head, however. if prefix.split(".")[-1] == exclude_layer: return True + # Handle packed/merged module names (e.g. "gate_up_proj" -> "gate_proj"/"up_proj", + # "qkv_proj" -> "q_proj"/"k_proj"/"v_proj"). The exclude list uses checkpoint + # weight names, but the prefix may use the packed parameter name. + # The mapping is built from the model's own packed_modules_mapping and stored + # on the QuantizationConfig at model init time. + packed_components = quantization_config.get("packed_components", {}) + leaf = prefix.rsplit(".", 1)[-1] if "." in prefix else prefix + if leaf in packed_components: + parent = prefix.rsplit(".", 1)[0] if "." in prefix else "" + for component in packed_components[leaf]: + component_path = f"{parent}.{component}" if parent else component + for exclude_layer in exclude_layers: + if exclude_layer.startswith("re"): + regex_pattern = exclude_layer[3:] + if re.search(regex_pattern, component_path): + return True + elif component_path in exclude_layer: + return True + elif component_path.split(".")[-1] == exclude_layer: + return True return False