From ad4cef076eb83fb0e67dc2d61d8adb120f328fb9 Mon Sep 17 00:00:00 2001 From: Thiago Rocha Date: Fri, 6 Feb 2026 22:07:46 +0000 Subject: [PATCH 1/2] Add support for Kimi-K2 and Kimi-K25 models in configuration and model runner - Updated configuration to include mappings for Kimi-K2 and Kimi-K25. - Enhanced ModelRunner to handle Kimi-K2 in model type checks. - Modified model loading to skip unnecessary weight prefixes for multimodal models. --- atom/config.py | 77 ++++++++ atom/model_engine/model_runner.py | 2 + atom/model_loader/loader.py | 8 + atom/model_ops/moe.py | 306 ++++++++++++++++++++++++++++++ atom/models/deepseek_v2.py | 7 +- atom/models/kimi_k25.py | 92 +++++++++ recipes/Kimi-K2.5.md | 67 +++++++ 7 files changed, 558 insertions(+), 1 deletion(-) create mode 100644 atom/models/kimi_k25.py create mode 100644 recipes/Kimi-K2.5.md diff --git a/atom/config.py b/atom/config.py index 677ae41c9..31367ab72 100644 --- a/atom/config.py +++ b/atom/config.py @@ -255,6 +255,7 @@ def __init__( quant_name="", quant_method=None, exclude_layers: Optional[list[str]] = None, + quark_w4a8: bool = False, ): super().__init__() self["quant_type"] = quant_type if quant_type is not None else QuantType.No @@ -263,6 +264,9 @@ def __init__( self["is_dynamic"] = is_dynamic self["quant_method"] = quant_method self["exclude_layers"] = exclude_layers if exclude_layers is not None else [] + # When True the checkpoint stores INT4-packed expert weights that must + # be dequantised and re-quantised to FP8 during model loading. + self["quark_w4a8"] = quark_w4a8 def get_name(self): return self["quant_name"] @@ -290,6 +294,27 @@ def compute_hash(self) -> str: return hashlib.sha256(str(factors).encode()).hexdigest() +def _is_quark_w4a8(orig_quant_config: dict) -> bool: + """Detect Quark W4A8 checkpoint format. + + The signature is ``quant_method == "quark"`` with a + ``global_quant_config.weight`` list that contains an INT4 entry and + FP8 activations. + """ + if orig_quant_config.get("quant_method") != "quark": + return False + global_cfg = orig_quant_config.get("global_quant_config", {}) + weight_specs = global_cfg.get("weight", []) + if not isinstance(weight_specs, list): + return False + has_int4_weight = any( + w.get("dtype", "").startswith("int4") for w in weight_specs if isinstance(w, dict) + ) + input_cfg = global_cfg.get("input_tensors") or {} + has_fp8_act = isinstance(input_cfg, dict) and "fp8" in input_cfg.get("dtype", "") + return has_int4_weight and has_fp8_act + + def get_quant_config(config: PretrainedConfig) -> QuantizationConfig: torch_dtype = getattr(config, "torch_dtype", "bf16") orig_quant_config = getattr(config, "quantization_config", None) @@ -300,6 +325,25 @@ def get_quant_config(config: PretrainedConfig) -> QuantizationConfig: ) quant_method = orig_quant_config.get("quant_method", None) + + # ------------------------------------------------------------------ + # Quark W4A8: INT4 weights + FP8 activations. At runtime the INT4 + # expert weights are dequantised and re-quantised to FP8, so we + # advertise FP8 per-tensor as the runtime format and set a flag so + # that the MoE layer knows to do the conversion. + # ------------------------------------------------------------------ + if _is_quark_w4a8(orig_quant_config): + logger.info("Detected Quark W4A8 checkpoint – will convert INT4 experts to FP8") + exclude_layers = orig_quant_config.get("exclude", None) + return QuantizationConfig( + quant_type=QuantType.per_Tensor, + quant_dtype=d_dtypes["fp8"], + is_dynamic=True, + quant_method=quant_method, + exclude_layers=exclude_layers, + quark_w4a8=True, + ) + RE_QUANT_BLOCKSIZE = r"\'(?:group_size|weight_block_size)\'\:\s*(?:\[\n*)\s*(\d+)," orig_quant_config_str = str(orig_quant_config) if quant_method == "compressed-tensors" or "channel'," in orig_quant_config_str: @@ -371,6 +415,13 @@ def get_quant_config(config: PretrainedConfig) -> QuantizationConfig: _CONFIG_REGISTRY: dict[str, str] = { "deepseek_v32": "deepseek_v3", + "kimi_k2": "deepseek_v3", +} + + +_MULTIMODAL_MODEL_TYPES: dict[str, str] = { + # Maps multimodal model_type -> key in config_dict for the text sub-config + "kimi_k25": "text_config", } @@ -386,6 +437,32 @@ def _get_hf_token() -> str | None: return token return None + # For multimodal models, extract the text sub-config so the rest of ATOM + # (which is text-only today) works transparently. + if model_type in _MULTIMODAL_MODEL_TYPES: + text_config_key = _MULTIMODAL_MODEL_TYPES[model_type] + text_config_dict = config_dict.get(text_config_key, {}).copy() + # Remove auto_map to avoid trust_remote_code issues + text_config_dict.pop("auto_map", None) + # Propagate quantization_config from root level into text config + # (quantization_config lives alongside text_config, not inside it). + if "quantization_config" not in text_config_dict and "quantization_config" in config_dict: + text_config_dict["quantization_config"] = config_dict["quantization_config"] + text_model_type = text_config_dict.get("model_type", "deepseek_v3") + mapped_type = _CONFIG_REGISTRY.get(text_model_type, text_model_type) + config_class = AutoConfig.for_model(mapped_type) + hf_config = config_class.from_dict(text_config_dict) + # Override architectures so that ATOM selects the correct model class + # which can handle the multimodal weight prefix during loading. + original_arch = config_dict.get("architectures", []) + if original_arch: + hf_config.architectures = original_arch + # Propagate top-level token IDs if missing in text config + for field in ("bos_token_id", "eos_token_id", "pad_token_id"): + if getattr(hf_config, field, None) is None and field in config_dict: + setattr(hf_config, field, config_dict[field]) + return hf_config + if model_type in _CONFIG_REGISTRY: config_class = AutoConfig.for_model(_CONFIG_REGISTRY[model_type]) return config_class.from_pretrained( diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index 6861af392..877cde8a6 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -54,6 +54,7 @@ "DeepseekV32ForCausalLM": "atom.models.deepseek_v2.DeepseekV2ForCausalLM", "GptOssForCausalLM": "atom.models.gpt_oss.GptOssForCausalLM", "Glm4MoeForCausalLM": "atom.models.glm4_moe.Glm4MoeForCausalLM", + "KimiK25ForConditionalGeneration": "atom.models.kimi_k25.KimiK25ForCausalLM", } # seed = 34567 # np.random.seed(seed) @@ -628,6 +629,7 @@ def is_deepseek_mla(self) -> bool: "deepseek_v3", "deepseek_v32", "deepseek_mtp", + "kimi_k2", ): return self.hf_text_config.kv_lora_rank is not None elif self.hf_text_config.model_type == "eagle": diff --git a/atom/model_loader/loader.py b/atom/model_loader/loader.py index 032b51385..d65e40742 100644 --- a/atom/model_loader/loader.py +++ b/atom/model_loader/loader.py @@ -89,6 +89,7 @@ def load_model( ): packed_modules_mapping = getattr(model, "packed_modules_mapping", {}) weights_mapping = getattr(model, "weights_mapping", {}) + skip_weight_prefixes = getattr(model, "skip_weight_prefixes", []) params_dict = dict(model.named_parameters()) with concurrent.futures.ThreadPoolExecutor() as executor: futures = [] @@ -100,6 +101,13 @@ def load_model( continue if name.endswith("kv_scale"): continue + # Skip weights matching model-defined prefixes (e.g. vision encoder + # weights in multimodal checkpoints that are not needed for text-only + # inference). + if skip_weight_prefixes and any( + name.startswith(p) for p in skip_weight_prefixes + ): + continue if spec_decode: spec_layer = get_spec_layer_idx_from_weight_name(hf_config, name) if spec_layer is None: diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index b22fe317f..3c0fa5be9 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -1684,6 +1684,297 @@ def apply( ) +class QuarkW4A8MoEMethod(FusedMoEMethodBase): + """MoE method for Quark W4A8 checkpoints (INT4 weights + FP8 activations). + + At model-load time the INT4-packed expert weights are dequantised via + their dual scales (per-tensor ``weight_scale`` + per-channel + ``weight_scale_2``), then re-quantised to FP8 per-tensor. After that + the standard FP8 MoE inference kernels are used. + """ + + def __init__(self, quant_config: QuantizationConfig, moe: FusedMoEConfig): + super().__init__(moe) + self.quant_config = quant_config + self.quant_type = QuantType.per_Tensor # runtime FP8 per-tensor + self.quant_dtype = quant_config["quant_dtype"] + + # ------------------------------------------------------------------ # + # Weight allocation # + # ------------------------------------------------------------------ # + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + self.num_experts = num_experts + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size_per_partition + + # --- staging buffers for INT4-packed weights (stored as INT32) --- + # gate/up fused: [E, 2*N, K//8] int32 (K = hidden_size) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // 8, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + # down: [E, K_out, N//8] int32 (K_out = hidden_size) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // 8, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # --- per-tensor scales (weight_scale from checkpoint) --- + # w13: [E, 2] (one scalar each for gate and up per expert) + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # --- per-channel scales (weight_scale_2 from checkpoint) --- + w13_weight_scale_2 = torch.nn.Parameter( + torch.ones( + num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32 + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2) + set_weight_attrs(w13_weight_scale_2, extra_weight_attrs) + + w2_weight_scale_2 = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2) + set_weight_attrs(w2_weight_scale_2, extra_weight_attrs) + + # No input scales – activations are quantised dynamically at runtime. + layer.w13_input_scale = None + layer.w2_input_scale = None + + # ------------------------------------------------------------------ # + # INT4 unpacking helper # + # ------------------------------------------------------------------ # + @staticmethod + def _unpack_int4_to_bf16(packed: torch.Tensor) -> torch.Tensor: + """Unpack 8 signed INT4 values packed per INT32 element. + + Args: + packed: ``[..., K // 8]`` dtype ``int32`` + Returns: + ``[..., K]`` dtype ``bfloat16`` + """ + # View as uint8: each int32 → 4 bytes, each byte → 2 nibbles + u8 = packed.contiguous().view(torch.uint8) # [..., K//8 * 4] + low = (u8 & 0x0F).to(torch.int8) + high = ((u8 >> 4) & 0x0F).to(torch.int8) + # Sign-extend 4-bit → 8-bit + low = torch.where(low >= 8, low - 16, low) + high = torch.where(high >= 8, high - 16, high) + # Interleave low/high nibbles: byte order is little-endian + interleaved = torch.stack([low, high], dim=-1) # [..., K//8*4, 2] + out_shape = list(packed.shape[:-1]) + [packed.shape[-1] * 8] + return interleaved.reshape(out_shape).to(torch.bfloat16) + + # ------------------------------------------------------------------ # + # Post-load conversion: INT4 → BF16 → FP8 # + # ------------------------------------------------------------------ # + def process_weights_after_loading(self, layer: nn.Module) -> None: + num_experts = self.num_experts + intermediate = self.intermediate_size + hidden = self.hidden_size + + # Allocate FP8 destination buffers + device = layer.w13_weight.device + fp8_dtype = torch.float8_e4m3fn + fp8_w13 = torch.empty( + num_experts, 2 * intermediate, hidden, dtype=fp8_dtype, device=device + ) + fp8_w2 = torch.empty( + num_experts, hidden, intermediate, dtype=fp8_dtype, device=device + ) + fp8_w13_scale = torch.empty(num_experts, 2, dtype=torch.float32, device=device) + fp8_w2_scale = torch.empty(num_experts, dtype=torch.float32, device=device) + + quant_func = get_hip_quant(self.quant_type) + + for expert_id in range(num_experts): + # ---- w13 (gate + up) ---- + for shard_id in range(2): + start = shard_id * intermediate + end = start + intermediate + + # 1) Unpack INT4 → BF16 + bf16_w = self._unpack_int4_to_bf16( + layer.w13_weight.data[expert_id, start:end, :] + ) # [intermediate, hidden] + + # 2) Apply per-channel scale (weight_scale_2) + per_ch = layer.w13_weight_scale_2.data[expert_id, start:end] + bf16_w = bf16_w * per_ch.unsqueeze(1) + + # 3) Apply per-tensor scale (weight_scale) + per_tensor = layer.w13_weight_scale.data[expert_id, shard_id] + bf16_w = (bf16_w * per_tensor).to(torch.bfloat16) + + # 4) Quantise to FP8 + fp8_w13[expert_id, start:end, :], fp8_w13_scale[expert_id, shard_id] = ( + quant_func(bf16_w) + ) + + # ---- w2 (down) ---- + bf16_w = self._unpack_int4_to_bf16( + layer.w2_weight.data[expert_id] + ) # [hidden, intermediate] + + per_ch = layer.w2_weight_scale_2.data[expert_id] # [hidden] + bf16_w = bf16_w * per_ch.unsqueeze(1) + + per_tensor = layer.w2_weight_scale.data[expert_id] + bf16_w = (bf16_w * per_tensor).to(torch.bfloat16) + + fp8_w2[expert_id], fp8_w2_scale[expert_id] = quant_func(bf16_w) + + # Merge w13 per-shard scales to a single per-expert scale + # (required by the FP8 MoE kernel). + max_w13_scales = fp8_w13_scale.max(dim=1).values + shard_size = intermediate + for expert_id in range(num_experts): + for shard_id in range(2): + start = shard_id * shard_size + end = start + shard_size + dq_weight = per_tensor_dequantize( + fp8_w13[expert_id][start:end, :], + fp8_w13_scale[expert_id][shard_id], + ) + fp8_w13[expert_id][start:end, :], _ = quant_func( + dq_weight, max_w13_scales[expert_id] + ) + + # Replace layer parameters with FP8 versions + layer.w13_weight = nn.Parameter(fp8_w13, requires_grad=False) + layer.w2_weight = nn.Parameter(fp8_w2, requires_grad=False) + layer.w13_weight_scale = nn.Parameter(max_w13_scales, requires_grad=False) + layer.w2_weight_scale = nn.Parameter(fp8_w2_scale, requires_grad=False) + + # Clean up staging scale buffers + if hasattr(layer, "w13_weight_scale_2"): + del layer.w13_weight_scale_2 + if hasattr(layer, "w2_weight_scale_2"): + del layer.w2_weight_scale_2 + + # Shuffle weights for AITER kernel layout + shuffle_weights(layer.w13_weight, layer.w2_weight) + + # ------------------------------------------------------------------ # + # Inference – identical to Fp8MoEMethod (per-tensor, dynamic act) # + # ------------------------------------------------------------------ # + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> FusedMoEQuantConfig | None: + return fp8_w8a8_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + block_shape=None, + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: ActivationType = ActivationType.Silu, + ) -> torch.Tensor: + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + num_fused_shared_experts=layer.num_fused_shared_experts, + routed_scaling_factor=layer.routed_scaling_factor, + ) + # per_Tensor doesn't support num_local_tokens → no mori + if self.fused_experts is None: + return torch.ops.aiter.rocm_aiter_fused_moe( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + expert_mask=expert_map, + activation=activation, + quant_type=self.quant_type, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + return self.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, + activation=activation, + quant_type=self.quant_type, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + ) + + def determine_expert_map( ep_size: int, ep_rank: int, global_num_experts: int ) -> Tuple[int, Optional[torch.Tensor]]: @@ -1931,6 +2222,9 @@ def __init__( self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod( moe ) + elif quant_config.get("quark_w4a8", False): + # Quark W4A8: INT4 checkpoint weights converted to FP8 at load time + self.quant_method = QuarkW4A8MoEMethod(quant_config, moe) elif ( quant_method_str == "compressed-tensors" and quant_config["quant_dtype"] == dtypes.fp8 @@ -2299,6 +2593,18 @@ def weight_loader( # Case weight scales, zero_points and offset if "scale" in weight_name or "zero" in weight_name or "offset" in weight_name: + # Quark W4A8 checkpoints have both weight_scale (per-tensor) and + # weight_scale_2 (per-channel). Route weight_scale_2 through the + # per-channel loader regardless of the runtime quant_type. + if "weight_scale_2" in weight_name: + self._load_per_channel_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=self.tp_rank, + ) + return # load the weight scales and zp based on the quantization scheme # supported weight scales/zp can be found in # FusedMoeWeightScaleSupported diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index 4a88fb863..80426329f 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -1269,7 +1269,12 @@ def __init__( base_quant_config = None else: source_quant_dtype = None - base_quant_config = quant_config + # Check exclude patterns (e.g. W4A8 checkpoints exclude attention) + if should_ignore_layer(quant_config, prefix): + quant_config = None + base_quant_config = None + else: + base_quant_config = quant_config if self.q_lora_rank is not None: # self.q_a_proj = ReplicatedLinear(self.hidden_size, diff --git a/atom/models/kimi_k25.py b/atom/models/kimi_k25.py new file mode 100644 index 000000000..eb4aaafb6 --- /dev/null +++ b/atom/models/kimi_k25.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +"""Inference-only Kimi-K2.5 model (text-only backbone). + +Kimi-K2.5 is a multimodal model whose language backbone is a DeepseekV3-style +MoE transformer with MLA attention. For text-only serving we load only the +``language_model.*`` weights and delegate to the existing +:class:`DeepseekV2ForCausalLM` implementation. + +Vision encoder and multimodal projector weights are skipped during loading +via :pyattr:`skip_weight_prefixes`. +""" + +from typing import Optional, Union + +import torch +from torch import nn + +from atom.config import Config +from atom.models.deepseek_v2 import DeepseekV2ForCausalLM +from atom.models.utils import IntermediateTensors + + +class KimiK25ForCausalLM(nn.Module): + """Kimi-K2.5 text-only wrapper around :class:`DeepseekV2ForCausalLM`. + + The HuggingFace checkpoint stores the LLM weights under the + ``language_model.*`` prefix. By placing the underlying model as + ``self.language_model``, PyTorch's parameter naming automatically + matches the checkpoint layout so no explicit prefix stripping is needed. + + Vision tower and multimodal projector weights are excluded via + :pyattr:`skip_weight_prefixes` which the model loader respects. + """ + + # Weight prefixes that should be silently skipped during loading + # (these belong to the vision encoder / MM projector that we don't use). + skip_weight_prefixes = [ + "vision_tower.", + "mm_projector.", + ] + + def __init__( + self, + atom_config: Config, + prefix: str = "", + ): + super().__init__() + self.config = atom_config.hf_config + + # The underlying LLM – named ``language_model`` so that its parameter + # names match the ``language_model.*`` keys in the checkpoint. + self.language_model = DeepseekV2ForCausalLM( + atom_config=atom_config, + prefix="", + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + # ---- properties forwarded to the inner model ---- + + @property + def packed_modules_mapping(self): + return self.language_model.packed_modules_mapping + + # ---- forward / inference API ---- + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.language_model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + return self.language_model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.language_model.get_expert_mapping() diff --git a/recipes/Kimi-K2.5.md b/recipes/Kimi-K2.5.md new file mode 100644 index 000000000..4b05e1021 --- /dev/null +++ b/recipes/Kimi-K2.5.md @@ -0,0 +1,67 @@ +# Kimi-K2.5 Usage Guide + +[Kimi-K2.5](https://huggingface.co/moonshotai/Kimi-K2.5) is a native multimodal agentic model developed by Moonshot AI, built through continual pretraining on approximately 15 trillion mixed visual and text tokens atop Kimi-K2-Base. + +ATOM currently supports the **text-only** backbone of Kimi-K2.5 (i.e. the DeepseekV3-style MoE language model with MLA attention). The model uses native INT4 quantization (`compressed-tensors`, group_size=32) for the routed MoE expert weights. + +## Preparing environment +Pull the nightly docker from https://hub.docker.com/r/rocm/atom/. +All the operations below will be executed inside the container. + +## Launching server +ATOM supports running the model with different parallelism, e.g., tensor parallel, expert parallel, data parallel. +Here we consider the parallelism of TP4 as an example. + +### Serving on 4xMI355 GPUs + +```bash +#!/bin/bash +export HIP_VISIBLE_DEVICES=0,1,2,3 + +python -m atom.entrypoints.openai_server \ + --model moonshotai/Kimi-K2.5 \ + --trust-remote-code \ + -tp 4 \ + --kv_cache_dtype fp8 +``` + +**Notes**: +- The `--trust-remote-code` flag is required for loading the model's custom tokenizer. +- Kimi-K2.5 uses a DeepseekV3-style architecture with MLA attention, so it leverages the same optimized kernels (MLA, FP8 KV cache, etc.) as DeepSeek models. +- The model uses native INT4 quantization for routed MoE expert weights via `compressed-tensors`. + +## Performance baseline + +The following script can be used to benchmark the performance: + +```bash +python -m atom.benchmarks.benchmark_serving \ + --model=moonshotai/Kimi-K2.5 --backend=vllm --base-url=http://localhost:$PORT \ + --trust-remote-code --dataset-name=random \ + --random-input-len=${ISL} --random-output-len=${OSL} \ + --random-range-ratio 0.8 \ + --num-prompts=$(( $CONC * 10 )) \ + --max-concurrency=$CONC \ + --request-rate=inf --ignore-eos \ + --save-result --result-dir=${result_dir} --result-filename=$RESULT_FILENAME.json \ + --percentile-metrics="ttft,tpot,itl,e2el" +``` + +### Accuracy test +You can verify accuracy using the lm_eval framework: +```bash +lm_eval \ +--model local-completions \ +--model_args model=moonshotai/Kimi-K2.5,base_url=http://localhost:8000/v1/completions,num_concurrent=64,max_retries=3,tokenized_requests=False,trust_remote_code=True \ +--tasks gsm8k \ +--num_fewshot 3 +``` + +## Architecture Details + +Kimi-K2.5 is a multimodal model (`KimiK25ForConditionalGeneration`) that wraps: +- **Language model**: A DeepseekV3-style MoE transformer with MLA attention (61 layers, 7168 hidden size, 64 attention heads, 384 routed experts, 8 experts per token) +- **Vision encoder**: MoonViT3d (not loaded in text-only mode) +- **MM Projector**: PatchMerger (not loaded in text-only mode) + +ATOM loads only the language model backbone, skipping vision and projector weights for efficient text-only inference. From 6cade5a97b63d3d4566903742251230ceab39daa Mon Sep 17 00:00:00 2001 From: Thiago Rocha Date: Mon, 23 Feb 2026 22:14:05 +0000 Subject: [PATCH 2/2] Add MXFP4 handling on top of INT4 --- atom/model_ops/moe.py | 22 ++++++++++++++++------ recipes/Kimi-K2.5.md | 43 +++++++++++++++++++++++++++++++++++++------ 2 files changed, 53 insertions(+), 12 deletions(-) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 3c0fa5be9..645beeaf0 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -2381,10 +2381,16 @@ def _load_w13( else: assert shard_id == "w3" expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) - if expert_data.dtype != dtypes.fp4x2: - expert_data.copy_(loaded_weight) - else: + if expert_data.dtype == dtypes.fp4x2: expert_data.view(torch.uint8).copy_(loaded_weight.view(torch.uint8)) + elif loaded_weight.dtype == torch.uint8 and expert_data.element_size() > 1: + # MXFP4 per-expert: packed FP4 bytes stored in a wider-dtype + # parameter (e.g. bf16). Raw byte copy into the leading bytes + # of each element so that shuffle_weight operates correctly. + dst = expert_data.view(torch.uint8) + dst[..., :loaded_weight.shape[-1]].copy_(loaded_weight) + else: + expert_data.copy_(loaded_weight) def _load_w2( self, @@ -2404,10 +2410,14 @@ def _load_w2( shard_dim, shard_size * tp_rank, shard_size ) # w2, down_proj: Load into only logical weight of w2. - if expert_data.dtype != dtypes.fp4x2: - expert_data.copy_(loaded_weight) - else: + if expert_data.dtype == dtypes.fp4x2: expert_data.view(torch.uint8).copy_(loaded_weight.view(torch.uint8)) + elif loaded_weight.dtype == torch.uint8 and expert_data.element_size() > 1: + # MXFP4 per-expert: same raw byte copy as _load_w13. + dst = expert_data.view(torch.uint8) + dst[..., :loaded_weight.shape[-1]].copy_(loaded_weight) + else: + expert_data.copy_(loaded_weight) def _load_single_value( self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int diff --git a/recipes/Kimi-K2.5.md b/recipes/Kimi-K2.5.md index 4b05e1021..4f011a216 100644 --- a/recipes/Kimi-K2.5.md +++ b/recipes/Kimi-K2.5.md @@ -2,7 +2,12 @@ [Kimi-K2.5](https://huggingface.co/moonshotai/Kimi-K2.5) is a native multimodal agentic model developed by Moonshot AI, built through continual pretraining on approximately 15 trillion mixed visual and text tokens atop Kimi-K2-Base. -ATOM currently supports the **text-only** backbone of Kimi-K2.5 (i.e. the DeepseekV3-style MoE language model with MLA attention). The model uses native INT4 quantization (`compressed-tensors`, group_size=32) for the routed MoE expert weights. +ATOM currently supports the **text-only** backbone of Kimi-K2.5 (i.e. the DeepseekV3-style MoE language model with MLA attention). Two quantized variants are available: + +| Variant | Quantization | Description | +|---------|-------------|-------------| +| **MXFP4** | Quark MXFP4 (w4a4, e8m0 scales, group_size=32) | Routed MoE expert weights in microscale FP4 format. Activations quantised dynamically at runtime. | +| **INT4→FP8** | Quark W4A8 (INT4 weights + FP8 activations) | Routed MoE expert weights stored as INT4, dequantised and re-quantised to FP8 during model loading. | ## Preparing environment Pull the nightly docker from https://hub.docker.com/r/rocm/atom/. @@ -10,9 +15,34 @@ All the operations below will be executed inside the container. ## Launching server ATOM supports running the model with different parallelism, e.g., tensor parallel, expert parallel, data parallel. -Here we consider the parallelism of TP4 as an example. -### Serving on 4xMI355 GPUs +### MXFP4 variant on 8×MI355 GPUs (TP8) + +```bash +#!/bin/bash +export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +HSA_NO_SCRATCH_RECLAIM=1 python -m atom.entrypoints.openai_server \ + --model \ + --trust-remote-code \ + -tp 8 \ + --kv_cache_dtype fp8 +``` + +### INT4→FP8 variant on 4×MI355 GPUs (TP4) + +```bash +#!/bin/bash +export HIP_VISIBLE_DEVICES=0,1,2,3 + +python -m atom.entrypoints.openai_server \ + --model \ + --trust-remote-code \ + -tp 4 \ + --kv_cache_dtype fp8 +``` + +### BF16 (unquantized) on 4×MI355 GPUs (TP4) ```bash #!/bin/bash @@ -28,7 +58,8 @@ python -m atom.entrypoints.openai_server \ **Notes**: - The `--trust-remote-code` flag is required for loading the model's custom tokenizer. - Kimi-K2.5 uses a DeepseekV3-style architecture with MLA attention, so it leverages the same optimized kernels (MLA, FP8 KV cache, etc.) as DeepSeek models. -- The model uses native INT4 quantization for routed MoE expert weights via `compressed-tensors`. +- For the MXFP4 variant, `HSA_NO_SCRATCH_RECLAIM=1` is recommended for stability. +- Non-MoE layers (attention, shared experts, dense MLPs) remain in BF16 for all quantized variants. ## Performance baseline @@ -36,7 +67,7 @@ The following script can be used to benchmark the performance: ```bash python -m atom.benchmarks.benchmark_serving \ - --model=moonshotai/Kimi-K2.5 --backend=vllm --base-url=http://localhost:$PORT \ + --model= --backend=vllm --base-url=http://localhost:$PORT \ --trust-remote-code --dataset-name=random \ --random-input-len=${ISL} --random-output-len=${OSL} \ --random-range-ratio 0.8 \ @@ -52,7 +83,7 @@ You can verify accuracy using the lm_eval framework: ```bash lm_eval \ --model local-completions \ ---model_args model=moonshotai/Kimi-K2.5,base_url=http://localhost:8000/v1/completions,num_concurrent=64,max_retries=3,tokenized_requests=False,trust_remote_code=True \ +--model_args model=,base_url=http://localhost:8000/v1/completions,num_concurrent=64,max_retries=3,tokenized_requests=False,trust_remote_code=True \ --tasks gsm8k \ --num_fewshot 3 ```