Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions atom/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
from aiter import QuantType
from aiter.utility.dtypes import d_dtypes
from atom.utils import envs, get_open_port
from atom.quant_spec import (
LayerQuantSpec,
ParsedQuantConfig,
get_quant_parser,
)
from atom.utils.distributed.utils import stateless_init_torch_distributed_process_group
from torch.distributed import ProcessGroup, ReduceOp
from transformers import AutoConfig, GenerationConfig, PretrainedConfig
Expand Down Expand Up @@ -251,6 +256,15 @@ def set_splitting_ops_for_v1(self):


class QuantizationConfig(dict):
"""Model-wide quantization configuration.

Still inherits from dict for backward compatibility with existing code
that accesses ``quant_config["quant_type"]``, etc.

    New code should prefer the :attr:`parsed` attribute and
    :meth:`resolve` method.
"""

def __init__(
self,
quant_type=QuantType.No,
Expand All @@ -259,6 +273,8 @@ def __init__(
quant_name="",
quant_method=None,
exclude_layers: Optional[list[str]] = None,
*,
parsed: Optional[ParsedQuantConfig] = None,
):
super().__init__()
self["quant_type"] = quant_type if quant_type is not None else QuantType.No
Expand All @@ -268,9 +284,93 @@ def __init__(
self["quant_method"] = quant_method
self["exclude_layers"] = exclude_layers if exclude_layers is not None else []

# --- New: structured parsed config ---
if parsed is not None:
self._parsed = parsed
else:
# Build a ParsedQuantConfig from the scalar fields so that
# manually-constructed QuantizationConfigs still work.
self._parsed = ParsedQuantConfig(
global_spec=LayerQuantSpec(
quant_type=self["quant_type"],
quant_dtype=self["quant_dtype"],
is_dynamic=self["is_dynamic"],
quant_method=self["quant_method"],
),
exclude_layers=self["exclude_layers"],
)

# -- public API --------------------------------------------------------

@property
def parsed(self) -> ParsedQuantConfig:
"""Access the structured :class:`ParsedQuantConfig`."""
return self._parsed

@property
def global_spec(self) -> LayerQuantSpec:
"""Shortcut for ``self.parsed.global_spec``."""
return self._parsed.global_spec

def resolve(self, prefix: str) -> LayerQuantSpec:
"""Return the :class:`LayerQuantSpec` for layer *prefix*.

Resolution order:
1. Explicit per-layer override in ``parsed.layer_specs[prefix]``.
2. Check the exclude list -- if the layer is excluded, return
``LayerQuantSpec.no_quant()``.
3. fnmatch-style pattern match in ``parsed.layer_pattern_specs``
(first matching pattern wins).
4. Fall back to ``parsed.global_spec``.
"""
from fnmatch import fnmatch

# 1. Explicit per-layer override
layer_specs = self._parsed.layer_specs
if prefix in layer_specs:
return layer_specs[prefix]

# 2. Check exclude list
if self._is_excluded(prefix):
return LayerQuantSpec.no_quant()

# 3. fnmatch-style pattern matching
for pattern, spec in self._parsed.layer_pattern_specs:
if fnmatch(prefix, pattern):
return spec

# 4. Global default
return self._parsed.global_spec

# -- backward compat ---------------------------------------------------

def get_name(self):
return self["quant_name"]

# -- internals ---------------------------------------------------------

def _is_excluded(self, prefix: str) -> bool:
"""Check whether *prefix* matches the exclude list.

Uses the same logic as the original ``should_ignore_layer``
in ``atom.models.utils`` so behaviour is identical.
"""
exclude_layers: list[str] = self._parsed.exclude_layers
if not exclude_layers:
return False
for exclude_layer in exclude_layers:
if exclude_layer.startswith("re"):
# case "re:model.layers.*self_attn.*"
regex_pattern = exclude_layer[3:]
if re.search(regex_pattern, prefix):
return True
elif prefix in exclude_layer:
return True
else:
if prefix.split(".")[-1] == exclude_layer:
return True
return False

def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
Expand Down Expand Up @@ -363,12 +463,24 @@ def get_quant_config(config: PretrainedConfig) -> QuantizationConfig:
)
exclude_layers_key = "ignore"
exclude_layers = orig_quant_config.get(exclude_layers_key, None)

# Use the structured parser to build a ParsedQuantConfig that includes
# per-layer pattern overrides (layer_pattern_specs) from the HF config.
# This is needed for models like DeepSeek-R1 MXFP4 which have different
# quantization for attention vs MoE layers.
if quant_method is not None:
parser = get_quant_parser(quant_method)
parsed = parser.parse(orig_quant_config)
else:
parsed = None

return QuantizationConfig(
quant_type,
quant_dtype,
is_dynamic,
quant_method=quant_method,
exclude_layers=exclude_layers,
parsed=parsed,
)


Expand Down
Loading
Loading