Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion atom/models/deepseek_mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@ def __init__(self, atom_config: Config, prefix: str, layer_idx: int) -> None:
)

quant_config = atom_config.quant_config
if quant_config["quant_dtype"] == dtypes.fp4x2:
if quant_config is not None and hasattr(quant_config, "resolve"):
_mtp_spec = quant_config.resolve(prefix)
if _mtp_spec.quant_dtype == dtypes.fp4x2:
# MTP layers don't support FP4 — fall back to unquantized
quant_config = QuantizationConfig()
elif quant_config is not None and quant_config["quant_dtype"] == dtypes.fp4x2:
quant_config = QuantizationConfig()

self.mtp_block = DeepseekV2DecoderLayer(
Expand Down
21 changes: 17 additions & 4 deletions atom/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1245,9 +1245,21 @@ def __init__(
self.max_position_embeddings = max_position_embeddings
self.layer_num = layer_num

# Resolve the per-layer quant spec for this attention block.
# For mixed-precision models the attention layers may use a
# different dtype (e.g. FP8) than the global config (e.g. MXFP4).
_attn_spec = (
quant_config.resolve(prefix)
if quant_config is not None
else None
)
_attn_quant_dtype = (
_attn_spec.quant_dtype if _attn_spec is not None else None
)

# With FP4 and use_triton_gemm(), fused_qkv_a_proj and q_b_proj use AITER-Triton FP4 GEMMs, while o_proj remains an AITER BF16 GEMM.
# With FP8 and use_triton_gemm(), fused_qkv_a_proj uses an AITER-Triton FP8 GEMM, while the other projections remain AITER FP8 GEMMs.
if quant_config["quant_dtype"] == dtypes.fp4x2:
if _attn_quant_dtype == dtypes.fp4x2:
# normally linear layers in attn share the same quant config
if should_ignore_layer(quant_config, prefix):
source_quant_dtype = None
Expand Down Expand Up @@ -1276,6 +1288,7 @@ def __init__(
bias=False,
quant_config=quant_config,
source_quant_dtype=source_quant_dtype,
prefix=f"{prefix}.fused_qkv_a_proj",
)
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(
Expand Down Expand Up @@ -1407,10 +1420,10 @@ def __init__(
self.quant_dtype = None
self.fuse_qknorm_quant = False
if quant_config is not None and ENABLE_DS_QKNORM_QUANT_FUSION:
if quant_config["quant_dtype"] == dtypes.fp8 or (
quant_config["quant_dtype"] == dtypes.fp4x2 and use_triton_gemm()
if _attn_quant_dtype == dtypes.fp8 or (
_attn_quant_dtype == dtypes.fp4x2 and use_triton_gemm()
):
self.quant_dtype = quant_config["quant_dtype"]
self.quant_dtype = _attn_quant_dtype
self.fuse_qknorm_quant = True

def forward(
Expand Down