diff --git a/atom/models/deepseek_mtp.py b/atom/models/deepseek_mtp.py
index 5394d49e..a5ccc2cf 100644
--- a/atom/models/deepseek_mtp.py
+++ b/atom/models/deepseek_mtp.py
@@ -56,7 +56,12 @@ def __init__(self, atom_config: Config, prefix: str, layer_idx: int) -> None:
         )
 
         quant_config = atom_config.quant_config
-        if quant_config["quant_dtype"] == dtypes.fp4x2:
+        if quant_config is not None and hasattr(quant_config, "resolve"):
+            _mtp_spec = quant_config.resolve(prefix)
+            if _mtp_spec.quant_dtype == dtypes.fp4x2:
+                # MTP layers don't support FP4 — fall back to unquantized
+                quant_config = QuantizationConfig()
+        elif quant_config is not None and quant_config["quant_dtype"] == dtypes.fp4x2:
             quant_config = QuantizationConfig()
 
         self.mtp_block = DeepseekV2DecoderLayer(
diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py
index f0342dce..f12df2cd 100644
--- a/atom/models/deepseek_v2.py
+++ b/atom/models/deepseek_v2.py
@@ -1245,9 +1245,21 @@ def __init__(
         self.max_position_embeddings = max_position_embeddings
         self.layer_num = layer_num
 
+        # Resolve the per-layer quant spec for this attention block.
+        # For mixed-precision models the attention layers may use a
+        # different dtype (e.g. FP8) than the global config (e.g. MXFP4).
+        _attn_spec = (
+            quant_config.resolve(prefix)
+            if quant_config is not None
+            else None
+        )
+        _attn_quant_dtype = (
+            _attn_spec.quant_dtype if _attn_spec is not None else None
+        )
+
         # For FP4 and use_triton_gemm(), fused_qkv_a_proj and q_b_proj are AITER-Triton FP4 GEMMs but o_proj remains AITER BF16 GEMMs,
         # For FP8 and use_triton_gemm(), fused_qkv_a_proj is AITER-Triton FP8 GEMMs while others remain AITER FP8 GEMMs
-        if quant_config["quant_dtype"] == dtypes.fp4x2:
+        if _attn_quant_dtype == dtypes.fp4x2:
             # normally linear layers in attn share the same quant config
             if should_ignore_layer(quant_config, prefix):
                 source_quant_dtype = None
@@ -1276,6 +1288,7 @@ def __init__(
             bias=False,
             quant_config=quant_config,
             source_quant_dtype=source_quant_dtype,
+            prefix=f"{prefix}.fused_qkv_a_proj",
         )
         self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
         self.q_b_proj = ColumnParallelLinear(
@@ -1407,10 +1420,10 @@ def __init__(
         self.quant_dtype = None
         self.fuse_qknorm_quant = False
         if quant_config is not None and ENABLE_DS_QKNORM_QUANT_FUSION:
-            if quant_config["quant_dtype"] == dtypes.fp8 or (
-                quant_config["quant_dtype"] == dtypes.fp4x2 and use_triton_gemm()
+            if _attn_quant_dtype == dtypes.fp8 or (
+                _attn_quant_dtype == dtypes.fp4x2 and use_triton_gemm()
             ):
-                self.quant_dtype = quant_config["quant_dtype"]
+                self.quant_dtype = _attn_quant_dtype
                 self.fuse_qknorm_quant = True
 
     def forward(