From f9d7aac0fa4ef8270fe6e7c1c7172e7d9156296b Mon Sep 17 00:00:00 2001 From: Thiago Rocha Date: Tue, 24 Feb 2026 15:16:24 +0000 Subject: [PATCH 1/3] feat: per-layer quantization config with parser registry Introduce a structured per-layer quantization system that replaces ad-hoc dict-based config with: - LayerQuantSpec: frozen dataclass capturing quant_type, quant_dtype, is_dynamic, quant_method, checkpoint_dtype, and extensible flags. Supports two-phase online quant modeling (init as no_quant, post-load quantize to target spec via needs_online_quant predicate). - ParsedQuantConfig: holds global_spec, per-layer overrides (layer_specs), and exclude_layers list. - Parser registry: QuantConfigParser ABC with register_quant_parser() decorator. Built-in parsers: QuarkParser, CompressedTensorsParser, GenericParser (fallback). Extensible via decorator registration. - QuantizationConfig.resolve(prefix): single resolution point for per-layer config. Priority: layer_specs override > exclude-list > global spec. Supports exact match, prefix match, and re: regex patterns in exclude lists. - LinearBase: accepts optional LayerQuantSpec via layer_spec parameter, deriving quant_type, params_dtype, and source_quant_dtype from it. Falls back to building LayerQuantSpec from dict fields for backward compatibility. - QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear: use quant_config.resolve(prefix) to get per-layer specs, eliminating the need for get_quant_config_for_layer() at construction time. - qwen3_moe.py: wires prefix through attention and decoder layers. - models/utils.py: should_ignore_layer() delegates to resolve() when available, maintaining backward compatibility. 
All 34 unit tests pass covering: LayerQuantSpec properties, parser registry, QuarkParser/CompressedTensorsParser parsing, resolve() with exclusions/overrides, backward compat (dict access, should_ignore_layer, get_quant_config_for_layer), LinearBase layer_spec plumbing, and QKV/Row parallel linear resolve integration. --- atom/config.py | 112 ++++++++ atom/model_ops/linear.py | 130 ++++++++- atom/models/qwen3_moe.py | 2 + atom/models/utils.py | 20 +- atom/quant_spec.py | 396 +++++++++++++++++++++++++ tests/test_per_layer_quant.py | 525 ++++++++++++++++++++++++++++++++++ 6 files changed, 1166 insertions(+), 19 deletions(-) create mode 100644 atom/quant_spec.py create mode 100644 tests/test_per_layer_quant.py diff --git a/atom/config.py b/atom/config.py index e273fb73c..094813d6c 100644 --- a/atom/config.py +++ b/atom/config.py @@ -13,6 +13,11 @@ from aiter import QuantType from aiter.utility.dtypes import d_dtypes from atom.utils import envs, get_open_port +from atom.quant_spec import ( + LayerQuantSpec, + ParsedQuantConfig, + get_quant_parser, +) from atom.utils.distributed.utils import stateless_init_torch_distributed_process_group from torch.distributed import ProcessGroup, ReduceOp from transformers import AutoConfig, GenerationConfig, PretrainedConfig @@ -251,6 +256,15 @@ def set_splitting_ops_for_v1(self): class QuantizationConfig(dict): + """Model-wide quantization configuration. + + Still inherits from dict for backward compatibility with existing code + that accesses ``quant_config["quant_type"]``, etc. + + New code should prefer the :pyattr:`parsed` attribute and + :pymeth:`resolve` method. 
+ """ + def __init__( self, quant_type=QuantType.No, @@ -259,6 +273,8 @@ def __init__( quant_name="", quant_method=None, exclude_layers: Optional[list[str]] = None, + *, + parsed: Optional[ParsedQuantConfig] = None, ): super().__init__() self["quant_type"] = quant_type if quant_type is not None else QuantType.No @@ -268,9 +284,93 @@ def __init__( self["quant_method"] = quant_method self["exclude_layers"] = exclude_layers if exclude_layers is not None else [] + # --- New: structured parsed config --- + if parsed is not None: + self._parsed = parsed + else: + # Build a ParsedQuantConfig from the scalar fields so that + # manually-constructed QuantizationConfigs still work. + self._parsed = ParsedQuantConfig( + global_spec=LayerQuantSpec( + quant_type=self["quant_type"], + quant_dtype=self["quant_dtype"], + is_dynamic=self["is_dynamic"], + quant_method=self["quant_method"], + ), + exclude_layers=self["exclude_layers"], + ) + + # -- public API -------------------------------------------------------- + + @property + def parsed(self) -> ParsedQuantConfig: + """Access the structured :class:`ParsedQuantConfig`.""" + return self._parsed + + @property + def global_spec(self) -> LayerQuantSpec: + """Shortcut for ``self.parsed.global_spec``.""" + return self._parsed.global_spec + + def resolve(self, prefix: str) -> LayerQuantSpec: + """Return the :class:`LayerQuantSpec` for layer *prefix*. + + Resolution order: + 1. Explicit per-layer override in ``parsed.layer_specs[prefix]``. + 2. Check the exclude list -- if the layer is excluded, return + ``LayerQuantSpec.no_quant()``. + 3. fnmatch-style pattern match in ``parsed.layer_pattern_specs`` + (first matching pattern wins). + 4. Fall back to ``parsed.global_spec``. + """ + from fnmatch import fnmatch + + # 1. Explicit per-layer override + layer_specs = self._parsed.layer_specs + if prefix in layer_specs: + return layer_specs[prefix] + + # 2. 
Check exclude list + if self._is_excluded(prefix): + return LayerQuantSpec.no_quant() + + # 3. fnmatch-style pattern matching + for pattern, spec in self._parsed.layer_pattern_specs: + if fnmatch(prefix, pattern): + return spec + + # 4. Global default + return self._parsed.global_spec + + # -- backward compat --------------------------------------------------- + def get_name(self): return self["quant_name"] + # -- internals --------------------------------------------------------- + + def _is_excluded(self, prefix: str) -> bool: + """Check whether *prefix* matches the exclude list. + + Uses the same logic as the original ``should_ignore_layer`` + in ``atom.models.utils`` so behaviour is identical. + """ + exclude_layers: list[str] = self._parsed.exclude_layers + if not exclude_layers: + return False + for exclude_layer in exclude_layers: + if exclude_layer.startswith("re"): + # case "re:model.layers.*self_attn.*" + regex_pattern = exclude_layer[3:] + if re.search(regex_pattern, prefix): + return True + elif prefix in exclude_layer: + return True + else: + if prefix.split(".")[-1] == exclude_layer: + return True + return False + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -363,12 +463,24 @@ def get_quant_config(config: PretrainedConfig) -> QuantizationConfig: ) exclude_layers_key = "ignore" exclude_layers = orig_quant_config.get(exclude_layers_key, None) + + # Use the structured parser to build a ParsedQuantConfig that includes + # per-layer pattern overrides (layer_pattern_specs) from the HF config. + # This is needed for models like DeepSeek-R1 MXFP4 which have different + # quantization for attention vs MoE layers. 
+ if quant_method is not None: + parser = get_quant_parser(quant_method) + parsed = parser.parse(orig_quant_config) + else: + parsed = None + return QuantizationConfig( quant_type, quant_dtype, is_dynamic, quant_method=quant_method, exclude_layers=exclude_layers, + parsed=parsed, ) diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index a3d7b4ef7..7f3de11b5 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -18,8 +18,8 @@ from torch import nn from atom.config import QuantizationConfig, get_current_atom_config +from atom.quant_spec import LayerQuantSpec from atom.model_ops.utils import normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale -from atom.models.utils import get_quant_config_for_layer # import torch.distributed as dist from aiter.dist.parallel_state import get_tp_group @@ -200,12 +200,34 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = False, source_quant_dtype: torch.dtype | None = None, + layer_spec: Optional[LayerQuantSpec] = None, ): if quant_config is None: quant_config = QuantizationConfig() - self.source_quant_dtype = source_quant_dtype - quant_type = quant_config["quant_type"] - params_dtype = quant_config["quant_dtype"] + + # --- New: prefer LayerQuantSpec if provided --- + if layer_spec is not None: + self._layer_spec = layer_spec + else: + # Build a LayerQuantSpec from old-style dict fields for compat + self._layer_spec = LayerQuantSpec( + quant_type=quant_config["quant_type"], + quant_dtype=quant_config["quant_dtype"], + is_dynamic=quant_config.get("is_dynamic", True), + quant_method=quant_config.get("quant_method", None), + checkpoint_dtype=source_quant_dtype, + ) + + # Backward compat: source_quant_dtype can come from layer_spec + self.source_quant_dtype = ( + source_quant_dtype + if source_quant_dtype is not None + else self._layer_spec.checkpoint_dtype + ) + + # Effective quant params for this layer + quant_type = self._layer_spec.quant_type + params_dtype = 
self._layer_spec.quant_dtype super().__init__() self.reduce_results = reduce_results self.input_size = input_size @@ -259,7 +281,7 @@ def __init__( torch.empty(len(self.output_partition_sizes), 1, dtype=dtypes.fp32), requires_grad=False, ) - if not quant_config["is_dynamic"]: + if not self._layer_spec.is_dynamic: self.input_scale = nn.Parameter( torch.empty( len(self.output_partition_sizes), 1, dtype=dtypes.fp32 @@ -453,6 +475,13 @@ def __init__( source_quant_dtype: torch.dtype = None, **kwargs, ): + # Extract per-layer info from kwargs + prefix = kwargs.pop("prefix", "") + layer_spec = kwargs.pop("layer_spec", None) + if layer_spec is None and quant_config is not None and prefix: + layer_spec = quant_config.resolve(prefix) + if not layer_spec.is_quantized: + quant_config = None # backward compat: pass None to LinearBase super().__init__( input_size, output_size, @@ -460,10 +489,17 @@ def __init__( bias=bias, quant_config=quant_config, source_quant_dtype=source_quant_dtype, + layer_spec=layer_spec, ) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): param_data = param.data + # Checkpoint scales are often 1-D; reshape to match param shape (e.g. 
[N] -> [N, 1]) + if ( + loaded_weight.shape != param_data.shape + and loaded_weight.ndim < param_data.ndim + ): + loaded_weight = loaded_weight.view(param_data.shape) param.weight_loader_process(param_data, loaded_weight) @@ -475,9 +511,16 @@ def __init__( bias: bool = False, quant_config: Optional[QuantizationConfig] = None, source_quant_dtype: torch.dtype = None, + prefix: str = "", **kwargs, ): self.tp_dim = 0 + # Resolve per-layer spec via prefix + layer_spec = None + if quant_config is not None and prefix: + layer_spec = quant_config.resolve(prefix) + if not layer_spec.is_quantized: + quant_config = None # backward compat: pass None to LinearBase super().__init__( input_size, output_size, @@ -485,6 +528,7 @@ def __init__( bias, quant_config=quant_config, source_quant_dtype=source_quant_dtype, + layer_spec=layer_spec, ) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): @@ -492,6 +536,12 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): shard_size = param_data.size(self.tp_dim) start_idx = self.tp_rank * shard_size loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size) + # Checkpoint scales are often 1-D; reshape to match param shape (e.g. 
[N] -> [N, 1]) + if ( + loaded_weight.shape != param_data.shape + and loaded_weight.ndim < param_data.ndim + ): + loaded_weight = loaded_weight.view(param_data.shape) param.weight_loader_process(param_data, loaded_weight) @@ -507,8 +557,12 @@ def __init__( **kwargs, ): self.output_sizes = output_sizes + # Resolve per-layer spec via prefix + layer_spec = None if quant_config is not None and prefix: - quant_config = get_quant_config_for_layer(quant_config, prefix) + layer_spec = quant_config.resolve(prefix) + if not layer_spec.is_quantized: + quant_config = None # backward compat: pass None to LinearBase super().__init__( input_size, output_sizes, @@ -516,6 +570,7 @@ def __init__( bias=bias, quant_config=quant_config, source_quant_dtype=source_quant_dtype, + layer_spec=layer_spec, ) def weight_loader( @@ -619,6 +674,7 @@ def __init__( bias: bool = False, quant_config: Optional[QuantizationConfig] = None, source_quant_dtype: torch.dtype = None, + prefix: str = "", **kwargs, ): self.head_size = head_size @@ -644,12 +700,20 @@ def __init__( self.num_kv_heads * self.head_size * tp_size, ] + # Resolve per-layer spec via prefix + layer_spec = None + if quant_config is not None and prefix: + layer_spec = quant_config.resolve(prefix) + if not layer_spec.is_quantized: + quant_config = None # backward compat: pass None to LinearBase + super().__init__( input_size, output_sizes, bias=bias, quant_config=quant_config, source_quant_dtype=source_quant_dtype, + layer_spec=layer_spec, ) def weight_loader( @@ -701,8 +765,12 @@ def __init__( **kwargs, ): self.tp_rank = get_tp_group().rank_in_group + # Resolve per-layer spec via prefix + layer_spec = None if quant_config is not None and prefix: - quant_config = get_quant_config_for_layer(quant_config, prefix) + layer_spec = quant_config.resolve(prefix) + if not layer_spec.is_quantized: + quant_config = None # backward compat: pass None to LinearBase super().__init__( input_size, output_size, @@ -711,18 +779,33 @@ def __init__( 
quant_config=quant_config, reduce_results=reduce_results, source_quant_dtype=source_quant_dtype, + layer_spec=layer_spec, ) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): param_data = param.data + is_scale = param is getattr(self, "weight_scale", None) or param is getattr( + self, "input_scale", None + ) if param is not getattr(self, "bias", None): - shard_size = param_data.size(self.tp_dim) - if len(loaded_weight.shape) == 0: - loaded_weight = loaded_weight.view(1, 1) - if loaded_weight.size(self.tp_dim) == 1 and self.tp_size > 1: - loaded_weight = loaded_weight.repeat(1, self.tp_size) - start_idx = self.tp_rank * shard_size - loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size) + # For per-token / per-channel scales, the scale is NOT sharded + # along the input (tp_dim=1) dimension -- each rank holds the + # full scale. Only the weight itself (2-D) is narrowed. + if is_scale and loaded_weight.ndim <= 1: + # Checkpoint scale is 1-D [output_size]; reshape to [output_size, 1] + if ( + loaded_weight.shape != param_data.shape + and loaded_weight.ndim < param_data.ndim + ): + loaded_weight = loaded_weight.view(param_data.shape) + else: + shard_size = param_data.size(self.tp_dim) + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.view(1, 1) + if loaded_weight.size(self.tp_dim) == 1 and self.tp_size > 1: + loaded_weight = loaded_weight.repeat(1, self.tp_size) + start_idx = self.tp_rank * shard_size + loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size) else: if self.tp_size > 0 and self.tp_rank != 0: loaded_weight.zero_() @@ -742,10 +825,11 @@ def __init__( self.output_sizes = output_size super().__init__( input_size, - sum(output_size), # ? 
+ sum(output_size), bias=bias, quant_config=quant_config, source_quant_dtype=source_quant_dtype, + **kwargs, # forward prefix/layer_spec to ReplicatedLinear ) def weight_loader( @@ -768,8 +852,24 @@ def weight_loader( elif self.quant_type == QuantType.per_Tensor: shard_offset = loaded_shard_id shard_size = 1 + elif self.quant_type in (QuantType.per_Token, QuantType.per_1x32): + # per_Token: scale shape is (output_size, 1) + # per_1x32: scale shape is (output_size, ceil(input_size/32)) + # Both are sharded along dim-0 (output_size) like the weight. + shard_offset = sum(self.output_sizes[:loaded_shard_id]) + shard_size = self.output_sizes[loaded_shard_id] + else: + # Fallback: treat scale shard dims the same as weight dims + shard_offset = sum(self.output_sizes[:loaded_shard_id]) + shard_size = self.output_sizes[loaded_shard_id] else: shard_offset = sum(self.output_sizes[:loaded_shard_id]) shard_size = self.output_sizes[loaded_shard_id] param_data = param_data.narrow(0, shard_offset, shard_size) + # Checkpoint scales are often 1-D; reshape to match param shape (e.g. 
[N] -> [N, 1]) + if ( + loaded_weight.shape != param_data.shape + and loaded_weight.ndim < param_data.ndim + ): + loaded_weight = loaded_weight.view(param_data.shape) param.weight_loader_process(param_data, loaded_weight) diff --git a/atom/models/qwen3_moe.py b/atom/models/qwen3_moe.py index 9a0e1eba1..952aec522 100644 --- a/atom/models/qwen3_moe.py +++ b/atom/models/qwen3_moe.py @@ -183,6 +183,7 @@ def __init__( self.total_num_kv_heads, bias=qkv_bias, quant_config=atom_config.quant_config, + prefix=prefix, ) self.o_proj = RowParallelLinear( @@ -191,6 +192,7 @@ def __init__( bias=False, quant_config=atom_config.quant_config, reduce_results=not ENABLE_ALLREDUCE_RMSNORM_FUSION, + prefix=f"{prefix}.o_proj", ) self.rotary_emb = get_rope( diff --git a/atom/models/utils.py b/atom/models/utils.py index 60334d78c..8b61a39e7 100644 --- a/atom/models/utils.py +++ b/atom/models/utils.py @@ -240,23 +240,30 @@ def fast_topk(values, topk, dim): def should_ignore_layer( quantization_config: Optional[QuantizationConfig], prefix: str ) -> bool: + """Check whether *prefix* should skip quantization. + + Delegates to ``QuantizationConfig.resolve()`` when available (the new + ``LayerQuantSpec``-based path). Falls back to the legacy exclude-list + scan for plain-dict configs. 
+ """ if quantization_config is None: return True + # New path: use resolve() if available + if hasattr(quantization_config, "resolve"): + spec = quantization_config.resolve(prefix) + return not spec.is_quantized + # Legacy fallback exclude_layers: List[str] = quantization_config.get("exclude_layers", []) if not exclude_layers: return False for exclude_layer in exclude_layers: if exclude_layer.startswith("re"): - # case "re:model.layers.*self_attn.*", remove the 're:' prefix regex_pattern = exclude_layer[3:] if re.search(regex_pattern, prefix): return True elif prefix in exclude_layer: - # case exclude_layer like "model.layers.0.self_attn.q_a_proj" - # a common prefix for linear layers in attn like "model.layers.0.self_attn" return True else: - # case "lm_head". Common practice won't quant lm_head, however. if prefix.split(".")[-1] == exclude_layer: return True return False @@ -265,6 +272,11 @@ def should_ignore_layer( def get_quant_config_for_layer( quantization_config: Optional[QuantizationConfig], prefix: str ) -> Optional[QuantizationConfig]: + """Return *quantization_config* if *prefix* should be quantized, else None. + + This is the legacy helper — new code should prefer + ``quant_config.resolve(prefix)`` directly. + """ return ( None if should_ignore_layer(quantization_config, prefix) diff --git a/atom/quant_spec.py b/atom/quant_spec.py new file mode 100644 index 000000000..6f6e28d06 --- /dev/null +++ b/atom/quant_spec.py @@ -0,0 +1,396 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +""" +Per-layer quantization specification and parser registry. + +This module introduces: +- ``LayerQuantSpec``: a frozen dataclass describing how a single layer + should be quantized at init-time and (optionally) post-load-time. +- ``ParsedQuantConfig``: the output of a quant-method parser, containing + a global spec, per-layer overrides, and the raw exclude list. 
+- ``QuantConfigParser`` ABC + ``_QUANT_PARSERS`` registry so new + quant methods can be added without touching the core class. +""" + +from __future__ import annotations + +import logging +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, Optional + +import torch +from aiter import QuantType +from aiter.utility.dtypes import d_dtypes + +logger = logging.getLogger("atom") + +# --------------------------------------------------------------------------- +# LayerQuantSpec +# --------------------------------------------------------------------------- + +# Sentinel for "no quant" — used as the default +_NO_QUANT_SPEC: Optional["LayerQuantSpec"] = None # set after class def + + +@dataclass(frozen=True) +class LayerQuantSpec: + """Immutable description of how a single linear layer is quantized. + + By making this a frozen dataclass we get: + - Safe sharing between layers (immutable) + - Easy equality checks and hashing + - A single place to extend when new quantization knobs are needed + """ + + quant_type: QuantType = QuantType.No + quant_dtype: torch.dtype = torch.bfloat16 + is_dynamic: bool = True + quant_method: Optional[str] = None + + # Checkpoint format may differ from runtime quant_dtype. + # e.g. MXFP4 checkpoints store bf16 source weights that are + # quantized in-place during process_weights_after_loading. + checkpoint_dtype: Optional[torch.dtype] = None + + # Extra flags for specialised paths (e.g. triton-gemm selection, + # pre-shuffle hints, etc.). Using a dict keeps LayerQuantSpec + # extensible without adding a new field for every niche flag. 
+ flags: dict[str, Any] = field(default_factory=dict) + + # -- convenience predicates ------------------------------------------- + + @property + def is_quantized(self) -> bool: + return self.quant_type != QuantType.No + + @property + def needs_online_quant(self) -> bool: + """True when the checkpoint stores BF16 weights that must be + quantized in ``process_weights_after_loading``.""" + return ( + self.checkpoint_dtype is not None + and self.checkpoint_dtype != self.quant_dtype + ) + + @classmethod + def no_quant(cls) -> "LayerQuantSpec": + """Canonical "no quantization" spec — BF16 weights, no scales.""" + global _NO_QUANT_SPEC + if _NO_QUANT_SPEC is None: + _NO_QUANT_SPEC = cls() + return _NO_QUANT_SPEC + + +# --------------------------------------------------------------------------- +# ParsedQuantConfig — output of a parser +# --------------------------------------------------------------------------- + + +@dataclass +class ParsedQuantConfig: + """Result returned by every ``QuantConfigParser.parse()`` call. + + Attributes: + global_spec: The default spec applied to every layer unless + overridden. + layer_specs: Optional per-layer overrides keyed by exact layer + prefix. Checked first (highest priority). + layer_pattern_specs: Ordered list of ``(pattern, spec)`` pairs + where *pattern* is an fnmatch-style glob (e.g. + ``*self_attn*``). Checked after ``layer_specs`` but before + the global default. First matching pattern wins. + exclude_layers: The raw exclude-layer list from the HF config, + kept for backward-compat logging / debugging. 
+ """ + + global_spec: LayerQuantSpec + layer_specs: dict[str, LayerQuantSpec] = field(default_factory=dict) + layer_pattern_specs: list[tuple[str, LayerQuantSpec]] = field(default_factory=list) + exclude_layers: list[str] = field(default_factory=list) + + +# --------------------------------------------------------------------------- +# Parser registry +# --------------------------------------------------------------------------- + + +class QuantConfigParser(ABC): + """Abstract base for quant-method parsers. + + To register a new parser: + + @register_quant_parser("my_method") + class MyParser(QuantConfigParser): + def parse(self, raw_config: dict) -> ParsedQuantConfig: + ... + """ + + @abstractmethod + def parse(self, raw_config: dict) -> ParsedQuantConfig: + """Parse a raw HF ``quantization_config`` dict into a + ``ParsedQuantConfig``.""" + ... + + +_QUANT_PARSERS: dict[str, type[QuantConfigParser]] = {} + + +def register_quant_parser(method_name: str): + """Decorator to register a ``QuantConfigParser`` subclass for a given + ``quant_method`` string.""" + + def wrapper(cls: type[QuantConfigParser]): + if method_name in _QUANT_PARSERS: + logger.warning( + "Overwriting quant parser for method %r with %s", + method_name, + cls.__name__, + ) + _QUANT_PARSERS[method_name] = cls + return cls + + return wrapper + + +def get_quant_parser(method_name: str) -> QuantConfigParser: + """Instantiate and return the parser for *method_name*. + + Falls back to ``GenericParser`` when no specific parser is registered. 
+ """ + cls = _QUANT_PARSERS.get(method_name) + if cls is None: + logger.warning( + "No dedicated quant parser for method %r — falling back to GenericParser", + method_name, + ) + cls = _QUANT_PARSERS.get("__generic__", GenericParser) + return cls() + + +# --------------------------------------------------------------------------- +# Shared parsing helpers +# --------------------------------------------------------------------------- + +RE_QUANT_BLOCKSIZE = r"\'(?:group_size|weight_block_size)\'\:\s*(?:\[\n*)\s*(\d+)," +RE_QUANT_DTYPE = r"\'(?:d?type|weight_dtype|quant_method)\'\:\s*\'(\w+)\'" +RE_STATIC_QUANT = r"\'(?:activation_scheme)\'\:\s*\'(static)\'" + + +def _parse_quant_type(raw_str: str) -> QuantType: + """Infer QuantType from the stringified HF config.""" + if "channel'," in raw_str: + return QuantType.per_Token + if group_size := re.search(RE_QUANT_BLOCKSIZE, raw_str): + gs = int(group_size.group(1)) + assert gs in (32, 128), f"Unsupported group size {gs}" + return QuantType.per_1x128 if gs == 128 else QuantType.per_1x32 + return QuantType.per_Tensor + + +def _parse_quant_dtype(raw_str: str) -> torch.dtype: + """Infer torch dtype from stringified HF config.""" + m = re.search(RE_QUANT_DTYPE, raw_str) + if m and m.group(1).lower() in [ + "fp8", + "fp4", + "int8", + "int4", + "fp8_e4m3", + "mxfp4", + ]: + dtype = m.group(1).lower().split("_")[0] + if dtype == "mxfp4": + dtype = "fp4" + if dtype.endswith("4"): + dtype += "x2" + return d_dtypes[dtype] + + bit_match = re.search(r"\'(?:num_)?bits\'\:\s*(\d+)", raw_str) + if bit_match: + bit = int(bit_match.group(1)) + dtype_match = re.search(RE_QUANT_DTYPE, raw_str) + if dtype_match: + dtype = dtype_match.group(1).lower() + dtype_prefix = "i" if dtype.startswith("int") else "fp" + else: + dtype_prefix = "i" + quant_dtype_str = ( + f"{dtype_prefix}{bit}" if bit != 4 else f"{dtype_prefix}{bit}x2" + ) + result = d_dtypes.get(quant_dtype_str, None) + if result is not None: + return result + + raise 
ValueError(f"Cannot parse quant dtype from {raw_str}") + + +def _parse_is_dynamic(raw_str: str) -> bool: + return not bool(re.search(RE_STATIC_QUANT, raw_str)) + + +# --------------------------------------------------------------------------- +# Built-in parsers +# --------------------------------------------------------------------------- + + +def _build_quark_layer_spec(layer_cfg: dict) -> LayerQuantSpec: + """Build a :class:`LayerQuantSpec` from a single Quark + ``layer_quant_config`` entry (dict with ``weight``, ``input_tensors``, + etc.).""" + weight_cfg = layer_cfg.get("weight", {}) + weight_str = str(weight_cfg) if weight_cfg else str(layer_cfg) + raw_str = str(layer_cfg) + + quant_dtype = _parse_quant_dtype(weight_str) + is_dynamic = _parse_is_dynamic(raw_str) + + # Prefer structured group_size from the dict over regex + group_size = weight_cfg.get("group_size", None) if weight_cfg else None + if isinstance(group_size, list): + group_size = group_size[0] + if group_size in (32, 128): + quant_type = QuantType.per_1x128 if group_size == 128 else QuantType.per_1x32 + else: + quant_type = _parse_quant_type(weight_str) + + # For FP4 dtype, force per_1x32 + if quant_dtype == d_dtypes["fp4x2"]: + quant_type = QuantType.per_1x32 + + return LayerQuantSpec( + quant_type=quant_type, + quant_dtype=quant_dtype, + is_dynamic=is_dynamic, + quant_method="quark", + ) + + +@register_quant_parser("quark") +class QuarkParser(QuantConfigParser): + """Parser for the ``quark`` quantization method (used by AMD MXFP4 + checkpoints, among others).""" + + def parse(self, raw_config: dict) -> ParsedQuantConfig: + # Extract weight sub-config for dtype/quant_type parsing + weight_cfg = raw_config.get("global_quant_config", {}).get("weight", {}) + weight_str = str(weight_cfg) if weight_cfg else str(raw_config) + raw_str = str(raw_config) + + quant_dtype = _parse_quant_dtype(weight_str) + is_dynamic = _parse_is_dynamic(raw_str) + + # Prefer structured group_size from the dict over regex + 
group_size = weight_cfg.get("group_size", None) if weight_cfg else None + if isinstance(group_size, list): + group_size = group_size[0] + if group_size in (32, 128): + quant_type = ( + QuantType.per_1x128 if group_size == 128 else QuantType.per_1x32 + ) + else: + quant_type = _parse_quant_type(weight_str) + + # quark uses "exclude" for the exclude-layers key + exclude_layers = raw_config.get("exclude", []) or [] + + # For FP4 dtype, force per_1x32 + if quant_dtype == d_dtypes["fp4x2"]: + quant_type = QuantType.per_1x32 + + global_spec = LayerQuantSpec( + quant_type=quant_type, + quant_dtype=quant_dtype, + is_dynamic=is_dynamic, + quant_method="quark", + ) + + # --- layer_quant_config: per-layer pattern overrides --- + layer_pattern_specs: list[tuple[str, LayerQuantSpec]] = [] + layer_quant_cfg = raw_config.get("layer_quant_config", {}) + for layer_key, layer_cfg in layer_quant_cfg.items(): + if not isinstance(layer_cfg, dict): + continue + spec = _build_quark_layer_spec(layer_cfg) + layer_pattern_specs.append((layer_key, spec)) + + if layer_pattern_specs: + logger.info( + "QuarkParser: parsed %d layer-pattern override(s).", + len(layer_pattern_specs), + ) + + return ParsedQuantConfig( + global_spec=global_spec, + layer_pattern_specs=layer_pattern_specs, + exclude_layers=exclude_layers, + ) + + +@register_quant_parser("compressed-tensors") +class CompressedTensorsParser(QuantConfigParser): + """Parser for the ``compressed-tensors`` quantization method.""" + + def parse(self, raw_config: dict) -> ParsedQuantConfig: + raw_str = str(raw_config) + quant_type = QuantType.per_Token # compressed-tensors → per_Token + quant_dtype = _parse_quant_dtype(raw_str) + is_dynamic = _parse_is_dynamic(raw_str) + + # compressed-tensors uses "ignore" for exclude layers + exclude_layers = raw_config.get("ignore", []) or [] + + if quant_dtype == d_dtypes["fp4x2"]: + quant_type = QuantType.per_1x32 + + global_spec = LayerQuantSpec( + quant_type=quant_type, + quant_dtype=quant_dtype, + 
is_dynamic=is_dynamic, + quant_method="compressed-tensors", + ) + + return ParsedQuantConfig( + global_spec=global_spec, + exclude_layers=exclude_layers, + ) + + +@register_quant_parser("__generic__") +class GenericParser(QuantConfigParser): + """Fallback parser used when no method-specific parser is registered.""" + + def parse(self, raw_config: dict) -> ParsedQuantConfig: + raw_str = str(raw_config) + quant_method = raw_config.get("quant_method", None) + quant_type = _parse_quant_type(raw_str) + quant_dtype = _parse_quant_dtype(raw_str) + is_dynamic = _parse_is_dynamic(raw_str) + + if quant_dtype == d_dtypes["fp4x2"]: + quant_type = QuantType.per_1x32 + + # Best-effort: try "ignore" first, then "exclude" + exclude_layers = ( + raw_config.get("ignore", None) or raw_config.get("exclude", []) or [] + ) + + logger.warning( + "Using generic quant parser for method %r — " + "please verify exclude_layers key is correct.", + quant_method, + ) + + global_spec = LayerQuantSpec( + quant_type=quant_type, + quant_dtype=quant_dtype, + is_dynamic=is_dynamic, + quant_method=quant_method, + ) + + return ParsedQuantConfig( + global_spec=global_spec, + exclude_layers=exclude_layers, + ) diff --git a/tests/test_per_layer_quant.py b/tests/test_per_layer_quant.py new file mode 100644 index 000000000..577033e4e --- /dev/null +++ b/tests/test_per_layer_quant.py @@ -0,0 +1,525 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +"""Unit tests for the per-layer quantization config refactor. 
+
+Tests cover:
+- LayerQuantSpec: frozen dataclass, no_quant, predicates
+- ParsedQuantConfig: dataclass construction
+- Parser registry: registration, dispatch, fallback
+- QuarkParser / CompressedTensorsParser: parsing logic
+- QuantizationConfig.resolve(): exclude-list resolution, per-layer overrides
+- Backward compatibility: dict access, should_ignore_layer, get_quant_config_for_layer
+- LinearBase / QKVParallelLinear / RowParallelLinear: layer_spec and resolve() plumbing
+"""
+
+import sys
+import unittest
+from unittest.mock import MagicMock, patch
+
+# ── conftest.py stubs atom.config → purge only atom / atom.* before we import ──
+for mod_name in list(sys.modules):
+    if mod_name == "atom" or mod_name.startswith("atom."):
+        del sys.modules[mod_name]
+
+import torch  # noqa: E402
+from aiter import QuantType  # noqa: E402
+from aiter.utility.dtypes import d_dtypes  # noqa: E402
+
+
+# ====================================================================
+# 1. LayerQuantSpec
+# ====================================================================
+class TestLayerQuantSpec(unittest.TestCase):
+    def test_default_is_no_quant(self):
+        from atom.quant_spec import LayerQuantSpec
+
+        spec = LayerQuantSpec()
+        self.assertEqual(spec.quant_type, QuantType.No)
+        self.assertEqual(spec.quant_dtype, torch.bfloat16)
+        self.assertFalse(spec.is_quantized)
+        self.assertFalse(spec.needs_online_quant)
+
+    def test_no_quant_singleton(self):
+        from atom.quant_spec import LayerQuantSpec
+
+        a = LayerQuantSpec.no_quant()
+        b = LayerQuantSpec.no_quant()
+        self.assertIs(a, b)
+
+    def test_quantized_spec(self):
+        from atom.quant_spec import LayerQuantSpec
+
+        spec = LayerQuantSpec(
+            quant_type=QuantType.per_1x32,
+            quant_dtype=torch.float4_e2m1fn_x2,
+            quant_method="quark",
+        )
+        self.assertTrue(spec.is_quantized)
+        self.assertFalse(spec.needs_online_quant)
+
+    def test_online_quant(self):
+        from atom.quant_spec import LayerQuantSpec
+
+        spec = LayerQuantSpec(
+            quant_type=QuantType.per_1x32,
+            quant_dtype=torch.float4_e2m1fn_x2,
+            checkpoint_dtype=torch.bfloat16,
+        )
+        self.assertTrue(spec.needs_online_quant)
+
+    def test_frozen(self):
+        from atom.quant_spec import LayerQuantSpec
+
+        spec = LayerQuantSpec()
+        with self.assertRaises(AttributeError):
+            spec.quant_type = QuantType.per_Token  # type: ignore
+
+    def test_equality(self):
+        from atom.quant_spec import LayerQuantSpec
+
+        a = LayerQuantSpec(quant_type=QuantType.per_Token, quant_dtype=d_dtypes["fp8"])
+        b = LayerQuantSpec(quant_type=QuantType.per_Token, quant_dtype=d_dtypes["fp8"])
+        self.assertEqual(a, b)
+
+    def test_flags(self):
+        from atom.quant_spec import LayerQuantSpec
+
+        spec = LayerQuantSpec(flags={"use_triton": True})
+        self.assertTrue(spec.flags["use_triton"])
+
+
+# ====================================================================
+# 2. ParsedQuantConfig
+# ====================================================================
+class TestParsedQuantConfig(unittest.TestCase):
+    def test_default_fields(self):
+        from atom.quant_spec import LayerQuantSpec, ParsedQuantConfig
+
+        pqc = ParsedQuantConfig(global_spec=LayerQuantSpec())
+        self.assertEqual(pqc.layer_specs, {})
+        self.assertEqual(pqc.exclude_layers, [])
+
+    def test_with_overrides(self):
+        from atom.quant_spec import LayerQuantSpec, ParsedQuantConfig
+
+        global_spec = LayerQuantSpec(
+            quant_type=QuantType.per_1x32, quant_dtype=torch.float4_e2m1fn_x2
+        )
+        layer_specs = {"model.layers.0.self_attn.qkv_proj": LayerQuantSpec.no_quant()}
+        pqc = ParsedQuantConfig(
+            global_spec=global_spec,
+            layer_specs=layer_specs,
+            exclude_layers=["lm_head"],
+        )
+        self.assertEqual(
+            pqc.layer_specs["model.layers.0.self_attn.qkv_proj"].quant_type,
+            QuantType.No,
+        )
+        self.assertEqual(pqc.exclude_layers, ["lm_head"])
+
+
+# ====================================================================
+# 3. 
Parser registry
+# ====================================================================
+class TestParserRegistry(unittest.TestCase):
+    def test_builtin_parsers_registered(self):
+        from atom.quant_spec import _QUANT_PARSERS
+
+        self.assertIn("quark", _QUANT_PARSERS)
+        self.assertIn("compressed-tensors", _QUANT_PARSERS)
+        self.assertIn("__generic__", _QUANT_PARSERS)
+
+    def test_get_quant_parser_known(self):
+        from atom.quant_spec import get_quant_parser, QuarkParser
+
+        parser = get_quant_parser("quark")
+        self.assertIsInstance(parser, QuarkParser)
+
+    def test_get_quant_parser_unknown_falls_back(self):
+        from atom.quant_spec import get_quant_parser, GenericParser
+
+        parser = get_quant_parser("some_unknown_method")
+        self.assertIsInstance(parser, GenericParser)
+
+    def test_register_custom_parser(self):
+        from atom.quant_spec import (
+            QuantConfigParser,
+            ParsedQuantConfig,
+            LayerQuantSpec,
+            register_quant_parser,
+            get_quant_parser,
+            _QUANT_PARSERS,
+        )
+
+        # Always unregister, even when an assertion below fails, so the
+        # global registry is not polluted for other tests.
+        self.addCleanup(_QUANT_PARSERS.pop, "test_method", None)
+        @register_quant_parser("test_method")
+        class TestParser(QuantConfigParser):
+            def parse(self, raw_config: dict) -> ParsedQuantConfig:
+                return ParsedQuantConfig(global_spec=LayerQuantSpec())
+
+        parser = get_quant_parser("test_method")
+        self.assertIsInstance(parser, TestParser)
+
+
+# ====================================================================
+# 4. 
QuarkParser
+# ====================================================================
+class TestQuarkParser(unittest.TestCase):
+    def test_parse_mxfp4(self):
+        from atom.quant_spec import QuarkParser
+
+        # fp4 weights in groups of 32 are expected to map to per_1x32.
+        raw_cfg = {
+            "quant_method": "quark",
+            "global_quant_config": {"weight": {"dtype": "fp4", "group_size": 32}},
+            "exclude": ["lm_head", "model.layers.0.self_attn.q_proj"],
+        }
+        result = QuarkParser().parse(raw_cfg)
+        spec = result.global_spec
+        self.assertEqual(spec.quant_type, QuantType.per_1x32)
+        self.assertEqual(spec.quant_dtype, torch.float4_e2m1fn_x2)
+        self.assertEqual(spec.quant_method, "quark")
+        self.assertIn("lm_head", result.exclude_layers)
+
+    def test_parse_fp8(self):
+        from atom.quant_spec import QuarkParser
+
+        # fp8 weights in groups of 128 are expected to map to per_1x128.
+        raw_cfg = {
+            "quant_method": "quark",
+            "global_quant_config": {"weight": {"dtype": "fp8", "group_size": 128}},
+            "exclude": [],
+        }
+        result = QuarkParser().parse(raw_cfg)
+        spec = result.global_spec
+        self.assertEqual(spec.quant_type, QuantType.per_1x128)
+        self.assertEqual(spec.quant_dtype, d_dtypes["fp8"])
+
+
+# ====================================================================
+# 5. CompressedTensorsParser
+# ====================================================================
+class TestCompressedTensorsParser(unittest.TestCase):
+    def test_parse_ct(self):
+        from atom.quant_spec import CompressedTensorsParser
+
+        raw_cfg = {
+            "quant_method": "compressed-tensors",
+            "weight_dtype": "fp8",
+            "ignore": ["lm_head"],
+        }
+        result = CompressedTensorsParser().parse(raw_cfg)
+        self.assertEqual(result.global_spec.quant_type, QuantType.per_Token)
+        self.assertIn("lm_head", result.exclude_layers)
+
+
+# ====================================================================
+# 6. 
QuantizationConfig.resolve()
+# ====================================================================
+class TestQuantizationConfigResolve(unittest.TestCase):
+    def _make_qc(self, exclude_layers=None):
+        from atom.config import QuantizationConfig
+
+        return QuantizationConfig(
+            quant_type=QuantType.per_1x32,
+            quant_dtype=torch.float4_e2m1fn_x2,
+            quant_method="quark",
+            exclude_layers=exclude_layers or [],
+        )
+
+    def test_resolve_normal_layer(self):
+        qc = self._make_qc()
+        spec = qc.resolve("model.layers.0.mlp.down_proj")
+        self.assertTrue(spec.is_quantized)
+        self.assertEqual(spec.quant_type, QuantType.per_1x32)
+
+    def test_resolve_excluded_prefix(self):
+        qc = self._make_qc(exclude_layers=["model.layers.0.self_attn.q_proj"])
+        spec = qc.resolve("model.layers.0.self_attn")  # prefix of the excluded name
+        self.assertFalse(spec.is_quantized)
+
+    def test_resolve_excluded_exact(self):
+        qc = self._make_qc(exclude_layers=["lm_head"])
+        spec = qc.resolve("lm_head")
+        self.assertFalse(spec.is_quantized)
+
+    def test_resolve_excluded_regex(self):
+        qc = self._make_qc(exclude_layers=["re:model.layers.*self_attn.*"])
+        spec = qc.resolve("model.layers.5.self_attn.qkv_proj")
+        self.assertFalse(spec.is_quantized)
+
+    def test_resolve_per_layer_override(self):
+        from atom.config import QuantizationConfig
+        from atom.quant_spec import LayerQuantSpec, ParsedQuantConfig
+
+        global_spec = LayerQuantSpec(
+            quant_type=QuantType.per_1x32,
+            quant_dtype=torch.float4_e2m1fn_x2,
+            quant_method="quark",
+        )
+        custom_spec = LayerQuantSpec(
+            quant_type=QuantType.per_Token,
+            quant_dtype=d_dtypes["fp8"],
+            quant_method="custom",
+        )
+        parsed = ParsedQuantConfig(
+            global_spec=global_spec,
+            layer_specs={"model.layers.0.mlp.down_proj": custom_spec},
+        )
+        qc = QuantizationConfig(
+            quant_type=QuantType.per_1x32,
+            quant_dtype=torch.float4_e2m1fn_x2,
+            parsed=parsed,
+        )
+        # Per-layer override takes priority
+        spec = qc.resolve("model.layers.0.mlp.down_proj")
+        self.assertEqual(spec.quant_type, QuantType.per_Token)
+        self.assertEqual(spec.quant_method, "custom")
+
+        # Non-overridden layer gets global spec
+        spec2 = qc.resolve("model.layers.1.mlp.down_proj")
+        self.assertEqual(spec2.quant_type, QuantType.per_1x32)
+
+    def test_resolve_no_quant_config(self):
+        from atom.config import QuantizationConfig
+
+        qc = QuantizationConfig()
+        spec = qc.resolve("model.layers.0.self_attn")
+        self.assertFalse(spec.is_quantized)
+
+
+# ====================================================================
+# 7. Backward compatibility
+# ====================================================================
+class TestBackwardCompat(unittest.TestCase):
+    def test_dict_access(self):
+        from atom.config import QuantizationConfig
+
+        qc = QuantizationConfig(
+            quant_type=QuantType.per_1x32,
+            quant_dtype=torch.float4_e2m1fn_x2,
+            quant_method="quark",
+            exclude_layers=["lm_head"],
+        )
+        self.assertEqual(qc["quant_type"], QuantType.per_1x32)
+        self.assertEqual(qc["quant_dtype"], torch.float4_e2m1fn_x2)
+        self.assertEqual(qc["quant_method"], "quark")
+        self.assertEqual(qc["exclude_layers"], ["lm_head"])
+
+    def test_should_ignore_layer_uses_resolve(self):
+        from atom.config import QuantizationConfig
+        from atom.models.utils import should_ignore_layer
+
+        qc = QuantizationConfig(
+            quant_type=QuantType.per_1x32,
+            quant_dtype=torch.float4_e2m1fn_x2,
+            exclude_layers=["lm_head"],
+        )
+        self.assertTrue(should_ignore_layer(qc, "lm_head"))
+        self.assertFalse(should_ignore_layer(qc, "model.layers.0.mlp.down_proj"))
+
+    def test_get_quant_config_for_layer(self):
+        from atom.config import QuantizationConfig
+        from atom.models.utils import get_quant_config_for_layer
+
+        qc = QuantizationConfig(
+            quant_type=QuantType.per_1x32,
+            quant_dtype=torch.float4_e2m1fn_x2,
+            exclude_layers=["lm_head"],
+        )
+        self.assertIsNone(get_quant_config_for_layer(qc, "lm_head"))
+        self.assertIs(
+            get_quant_config_for_layer(qc, "model.layers.0.mlp.down_proj"), qc
+        )
+
+    def test_parsed_property(self):
+        from atom.config import QuantizationConfig
+        from atom.quant_spec import ParsedQuantConfig
+
+        qc = QuantizationConfig(quant_type=QuantType.per_Token)
+        self.assertIsInstance(qc.parsed, ParsedQuantConfig)
+
+    def test_global_spec_property(self):
+        from atom.config import QuantizationConfig
+        from atom.quant_spec import LayerQuantSpec
+
+        qc = QuantizationConfig(
+            quant_type=QuantType.per_Token, quant_dtype=d_dtypes["fp8"]
+        )
+        self.assertIsInstance(qc.global_spec, LayerQuantSpec)
+        self.assertEqual(qc.global_spec.quant_type, QuantType.per_Token)
+
+    def test_compute_hash_unchanged(self):
+        from atom.config import QuantizationConfig
+
+        qc1 = QuantizationConfig(
+            quant_type=QuantType.per_1x32, quant_dtype=torch.float4_e2m1fn_x2
+        )
+        qc2 = QuantizationConfig(
+            quant_type=QuantType.per_1x32, quant_dtype=torch.float4_e2m1fn_x2
+        )
+        self.assertEqual(qc1.compute_hash(), qc2.compute_hash())
+
+
+# ====================================================================
+# 8. LinearBase with layer_spec
+# ====================================================================
+class TestLinearBaseLayerSpec(unittest.TestCase):
+    @patch("atom.model_ops.linear.get_tp_group")
+    def test_layer_spec_overrides_quant_config(self, mock_tp):
+        """When layer_spec is provided, LinearBase uses it instead of
+        dict-based quant_config fields."""
+        from atom.quant_spec import LayerQuantSpec
+
+        mock_group = MagicMock()
+        mock_group.rank_in_group = 0
+        mock_group.world_size = 1
+        mock_tp.return_value = mock_group
+
+        from atom.model_ops.linear import LinearBase
+        from atom.config import QuantizationConfig
+
+        spec = LayerQuantSpec(
+            quant_type=QuantType.per_1x32,
+            quant_dtype=torch.float4_e2m1fn_x2,
+            quant_method="quark",
+        )
+        # Pass a no-quant QuantizationConfig but override with layer_spec
+        qc = QuantizationConfig()
+        lb = LinearBase(1024, 512, layer_spec=spec, quant_config=qc)
+        self.assertEqual(lb._layer_spec, spec)
+        self.assertEqual(lb.quant_type, QuantType.per_1x32)
+        self.assertEqual(lb.params_dtype, torch.float4_e2m1fn_x2)
+
+    @patch("atom.model_ops.linear.get_tp_group")
+    def test_layer_spec_with_checkpoint_dtype(self, mock_tp):
+        """LayerQuantSpec.checkpoint_dtype flows into source_quant_dtype."""
+        from atom.quant_spec import LayerQuantSpec
+
+        mock_group = MagicMock()
+        mock_group.rank_in_group = 0
+        mock_group.world_size = 1
+        mock_tp.return_value = mock_group
+
+        from atom.model_ops.linear import LinearBase
+
+        spec = LayerQuantSpec(
+            quant_type=QuantType.per_1x32,
+            quant_dtype=torch.float4_e2m1fn_x2,
+            checkpoint_dtype=torch.bfloat16,
+        )
+        lb = LinearBase(1024, 512, layer_spec=spec)
+        self.assertEqual(lb.source_quant_dtype, torch.bfloat16)
+        self.assertTrue(lb._layer_spec.needs_online_quant)
+
+    @patch("atom.model_ops.linear.get_tp_group")
+    def test_no_layer_spec_builds_from_dict(self, mock_tp):
+        """When layer_spec is not provided, LinearBase builds one from the dict."""
+        mock_group = MagicMock()
+        mock_group.rank_in_group = 0
+        mock_group.world_size = 1
+        mock_tp.return_value = mock_group
+
+        from atom.model_ops.linear import LinearBase
+        from atom.config import QuantizationConfig
+
+        qc = QuantizationConfig(
+            quant_type=QuantType.per_Token,
+            quant_dtype=d_dtypes["fp8"],
+        )
+        lb = LinearBase(1024, 512, quant_config=qc)
+        self.assertEqual(lb._layer_spec.quant_type, QuantType.per_Token)
+        self.assertEqual(lb._layer_spec.quant_dtype, d_dtypes["fp8"])
+
+
+# ====================================================================
+# 9. 
QKVParallelLinear resolve integration
+# ====================================================================
+class TestQKVResolve(unittest.TestCase):
+    @patch("atom.model_ops.linear.get_tp_group")
+    def test_qkv_excluded_layer_gets_no_quant(self, mock_tp):
+        """An exclude regex that matches the prefix turns quantization off."""
+        # Fake a tensor-parallel group with rank 0 and world size 1.
+        tp_group = MagicMock(rank_in_group=0, world_size=1)
+        mock_tp.return_value = tp_group
+
+        from atom.model_ops.linear import QKVParallelLinear
+        from atom.config import QuantizationConfig
+
+        quant_cfg = QuantizationConfig(
+            quant_type=QuantType.per_1x32,
+            quant_dtype=torch.float4_e2m1fn_x2,
+            quant_method="quark",
+            exclude_layers=["re:.*self_attn.*"],
+        )
+        layer = QKVParallelLinear(
+            hidden_size=1024,
+            head_size=128,
+            total_num_heads=8,
+            total_num_kv_heads=2,
+            quant_config=quant_cfg,
+            prefix="model.layers.0.self_attn",
+        )
+        # The prefix contains "self_attn", so the layer must stay unquantized.
+        self.assertEqual(layer.quant_type, QuantType.No)
+
+    @patch("atom.model_ops.linear.get_tp_group")
+    def test_qkv_non_excluded_layer_gets_quant(self, mock_tp):
+        """A layer that is not on the exclude list keeps the global quant type."""
+        # Fake a tensor-parallel group with rank 0 and world size 1.
+        tp_group = MagicMock(rank_in_group=0, world_size=1)
+        mock_tp.return_value = tp_group
+
+        from atom.model_ops.linear import QKVParallelLinear
+        from atom.config import QuantizationConfig
+
+        quant_cfg = QuantizationConfig(
+            quant_type=QuantType.per_1x32,
+            quant_dtype=torch.float4_e2m1fn_x2,
+            quant_method="quark",
+            exclude_layers=["lm_head"],
+        )
+        layer = QKVParallelLinear(
+            hidden_size=1024,
+            head_size=128,
+            total_num_heads=8,
+            total_num_kv_heads=2,
+            quant_config=quant_cfg,
+            prefix="model.layers.0.self_attn",
+        )
+        self.assertEqual(layer.quant_type, QuantType.per_1x32)
+
+
+# ====================================================================
+# 10. 
RowParallelLinear resolve integration
+# ====================================================================
+class TestRowResolve(unittest.TestCase):
+    @patch("atom.model_ops.linear.get_tp_group")
+    def test_row_excluded_gets_no_quant(self, mock_tp):
+        """A prefix listed verbatim in exclude_layers resolves to QuantType.No."""
+        # Fake a tensor-parallel group with rank 0 and world size 1.
+        tp_group = MagicMock(rank_in_group=0, world_size=1)
+        mock_tp.return_value = tp_group
+
+        from atom.model_ops.linear import RowParallelLinear
+        from atom.config import QuantizationConfig
+
+        quant_cfg = QuantizationConfig(
+            quant_type=QuantType.per_1x32,
+            quant_dtype=torch.float4_e2m1fn_x2,
+            exclude_layers=["lm_head"],
+        )
+        layer = RowParallelLinear(
+            input_size=1024,
+            output_size=512,
+            quant_config=quant_cfg,
+            prefix="lm_head",
+        )
+        self.assertEqual(layer.quant_type, QuantType.No)
+
+
+if __name__ == "__main__":
+    unittest.main()
From 097b7a8e016e113d2f512907a89e77c374c3c1aa Mon Sep 17 00:00:00 2001
From: Thiago Rocha
Date: Fri, 6 Mar 2026 15:54:35 +0000
Subject: [PATCH 2/3] fix: resolve per-layer quant config in FusedMoE init

When a model has per-layer quantization overrides (e.g., MXFP4 globally
but FP8 for MTP layers), the FusedMoE dispatch logic needs to see the
resolved per-layer dtype/type rather than the global config defaults.
- Add None guard after get_quant_config_for_layer (can return None for unquantized layers via should_ignore_layer + resolve path) - Use quant_config.resolve(prefix) to detect per-layer overrides and construct a layer-specific QuantizationConfig when they differ from the global config - Update self.quant_config and self.params_dtype to match the resolved config so weight_loader also uses the correct dtype --- atom/model_ops/moe.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index b6820cdda..ec9417119 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -1948,6 +1948,28 @@ def __init__( if quant_config is not None and prefix: quant_config = get_quant_config_for_layer(quant_config, prefix) + if quant_config is None: + quant_config = QuantizationConfig() + + # Resolve per-layer quant spec so the dispatch below sees the + # correct dtype/type when per-layer overrides differ from the + # global config (e.g., MXFP4 globally but FP8 for MTP layers). + if hasattr(quant_config, "resolve") and prefix: + _spec = quant_config.resolve(prefix) + if _spec.is_quantized and ( + _spec.quant_dtype != quant_config["quant_dtype"] + or _spec.quant_type != quant_config["quant_type"] + ): + quant_config = QuantizationConfig( + quant_type=_spec.quant_type, + quant_dtype=_spec.quant_dtype, + is_dynamic=quant_config.get("is_dynamic", True), + quant_name=quant_config.get("quant_name", ""), + quant_method=quant_config.get("quant_method", None), + ) + # Update instance attrs to match the (possibly resolved) config + self.quant_config = quant_config + self.params_dtype = quant_config["quant_dtype"] # Note: get_quant_method will look at the layer's local_num_experts # for heuristic purposes, so it must be initialized first. 
From 7392aec58e304aafe00a69bf4d80db9fcad89a9a Mon Sep 17 00:00:00 2001 From: Thiago Rocha Date: Fri, 6 Mar 2026 16:03:08 +0000 Subject: [PATCH 3/3] fix: route per-channel FP8 MoE to CompressedTensorsFp8MoEMethod Per-channel (per_Token) FP8 quantization needs per-channel weight scale allocation [E, N, 1] which CompressedTensorsFp8MoEMethod provides. Fp8MoEMethod only allocates scalar-per-expert scales [E, 2]/[E]. - Add dispatch case for quant_dtype==fp8 + quant_type==per_Token to use CompressedTensorsFp8MoEMethod - Fix _load_per_channel_weight_scale to unsqueeze 1D checkpoint scales to match 2D [N, 1] buffer shape --- atom/model_ops/moe.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index ec9417119..ca02116fc 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -1984,6 +1984,13 @@ def __init__( ): # Use CompressedTensorsFp8MoEMethod for compressed-tensors format self.quant_method = CompressedTensorsFp8MoEMethod(quant_config, moe) + elif ( + quant_config["quant_dtype"] == dtypes.fp8 + and quant_config["quant_type"] == QuantType.per_Token + ): + # Per-channel FP8 (e.g., Quark per_Token override for MTP layers) + # needs CompressedTensors-style weight scale handling + self.quant_method = CompressedTensorsFp8MoEMethod(quant_config, moe) elif quant_config["quant_dtype"] == dtypes.fp8: self.quant_method = Fp8MoEMethod(quant_config, moe) elif quant_config["quant_dtype"] == dtypes.fp4x2: @@ -2100,6 +2107,10 @@ def _load_per_channel_weight_scale( tp_rank: int, ): # for per channel weight quantization + # When scales are stored as [N,1] (CompressedTensors per-channel) + # but loaded from checkpoint as [N], reshape to match. + if loaded_weight.dim() == 1 and expert_data.dim() == 2: + loaded_weight = loaded_weight.unsqueeze(-1) if shard_id == "w2": expert_data.copy_(loaded_weight) elif shard_id in ("w1", "w3"):