From 060e9cbfc64710a10ee171e2f6ee8157f03d3254 Mon Sep 17 00:00:00 2001 From: Thiago Rocha Date: Wed, 4 Mar 2026 21:50:58 +0000 Subject: [PATCH 1/2] feat: resolve per-layer quant dtype in DeepSeek attention init For mixed-precision models (e.g. MXFP4 MoE + FP8 attention), the attention block must resolve its own per-layer quant spec rather than using the global quant_config['quant_dtype']. - Add _attn_spec / _attn_quant_dtype via quant_config.resolve(prefix) - Use resolved dtype for FP4/FP8 decision in attention init - Pass prefix to MergedReplicatedLinear for fused_qkv_a_proj - Use resolved dtype for fuse_qknorm_quant decision Tested with DeepSeek-R1-0528-moe-mxfp4-other-ptpc on TP=4. --- atom/models/deepseek_v2.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index f0342dce..f12df2cd 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -1245,9 +1245,21 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.layer_num = layer_num + # Resolve the per-layer quant spec for this attention block. + # For mixed-precision models the attention layers may use a + # different dtype (e.g. FP8) than the global config (e.g. MXFP4). 
+ _attn_spec = ( + quant_config.resolve(prefix) + if quant_config is not None + else None + ) + _attn_quant_dtype = ( + _attn_spec.quant_dtype if _attn_spec is not None else None + ) + # For FP4 and use_triton_gemm(), fused_qkv_a_proj and q_b_proj are AITER-Triton FP4 GEMMs but o_proj remains AITER BF16 GEMMs, # For FP8 and use_triton_gemm(), fused_qkv_a_proj is AITER-Triton FP8 GEMMs while others remain AITER FP8 GEMMs - if quant_config["quant_dtype"] == dtypes.fp4x2: + if _attn_quant_dtype == dtypes.fp4x2: # normally linear layers in attn share the same quant config if should_ignore_layer(quant_config, prefix): source_quant_dtype = None @@ -1276,6 +1288,7 @@ def __init__( bias=False, quant_config=quant_config, source_quant_dtype=source_quant_dtype, + prefix=f"{prefix}.fused_qkv_a_proj", ) self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) self.q_b_proj = ColumnParallelLinear( @@ -1407,10 +1420,10 @@ def __init__( self.quant_dtype = None self.fuse_qknorm_quant = False if quant_config is not None and ENABLE_DS_QKNORM_QUANT_FUSION: - if quant_config["quant_dtype"] == dtypes.fp8 or ( - quant_config["quant_dtype"] == dtypes.fp4x2 and use_triton_gemm() + if _attn_quant_dtype == dtypes.fp8 or ( + _attn_quant_dtype == dtypes.fp4x2 and use_triton_gemm() ): - self.quant_dtype = quant_config["quant_dtype"] + self.quant_dtype = _attn_quant_dtype self.fuse_qknorm_quant = True def forward( From 6045cef75664dff915c02610ac529977af1b2675 Mon Sep 17 00:00:00 2001 From: Thiago Rocha Date: Fri, 6 Mar 2026 15:55:18 +0000 Subject: [PATCH 2/2] fix: use per-layer resolve for MTP FP4 bypass in deepseek_mtp Instead of checking the global quant_dtype to decide whether to bypass FP4 quantization for MTP layers, use quant_config.resolve(prefix) to check the per-layer spec. This correctly preserves FP8 quantization for MTP layer 61 when the global config is MXFP4 but the layer has an FP8 per_Token override (as in the PTPC model format). 
--- atom/models/deepseek_mtp.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/atom/models/deepseek_mtp.py b/atom/models/deepseek_mtp.py index 5394d49e..a5ccc2cf 100644 --- a/atom/models/deepseek_mtp.py +++ b/atom/models/deepseek_mtp.py @@ -56,7 +56,12 @@ def __init__(self, atom_config: Config, prefix: str, layer_idx: int) -> None: ) quant_config = atom_config.quant_config - if quant_config["quant_dtype"] == dtypes.fp4x2: + if quant_config is not None and hasattr(quant_config, "resolve"): + _mtp_spec = quant_config.resolve(prefix) + if _mtp_spec is not None and _mtp_spec.quant_dtype == dtypes.fp4x2: + # MTP layers don't support FP4 — fall back to unquantized + quant_config = QuantizationConfig() + elif quant_config is not None and quant_config["quant_dtype"] == dtypes.fp4x2: quant_config = QuantizationConfig() self.mtp_block = DeepseekV2DecoderLayer(