Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 42 additions & 31 deletions atom/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,14 +255,13 @@ def set_splitting_ops_for_v1(self):
]


class QuantizationConfig(dict):
class QuantizationConfig:
"""Model-wide quantization configuration.

Still inherits from dict for backward compatibility with existing code
that accesses ``quant_config["quant_type"]``, etc.

New code should prefer the :pyattr:`parsed` attribute and
:pymeth:`resolve` method.
The primary API is :pymeth:`resolve` and the :pyattr:`parsed` /
:pyattr:`global_spec` attributes. Scalar convenience properties
(``quant_type``, ``quant_dtype``, ``is_dynamic``, ``quant_method``)
delegate to ``global_spec``.
"""

def __init__(
Expand All @@ -276,28 +275,24 @@ def __init__(
*,
parsed: Optional[ParsedQuantConfig] = None,
):
super().__init__()
self["quant_type"] = quant_type if quant_type is not None else QuantType.No
self["quant_dtype"] = quant_dtype if quant_dtype is not None else torch.bfloat16
self["quant_name"] = quant_name
self["is_dynamic"] = is_dynamic
self["quant_method"] = quant_method
self["exclude_layers"] = exclude_layers if exclude_layers is not None else []

# --- New: structured parsed config ---
self._quant_name = quant_name

# --- Structured parsed config ---
if parsed is not None:
self._parsed = parsed
else:
# Build a ParsedQuantConfig from the scalar fields so that
# manually-constructed QuantizationConfigs still work.
self._parsed = ParsedQuantConfig(
global_spec=LayerQuantSpec(
quant_type=self["quant_type"],
quant_dtype=self["quant_dtype"],
is_dynamic=self["is_dynamic"],
quant_method=self["quant_method"],
quant_type=quant_type if quant_type is not None else QuantType.No,
quant_dtype=(
quant_dtype if quant_dtype is not None else torch.bfloat16
),
is_dynamic=is_dynamic,
quant_method=quant_method,
),
exclude_layers=self["exclude_layers"],
exclude_layers=exclude_layers if exclude_layers is not None else [],
)

# -- public API --------------------------------------------------------
Expand Down Expand Up @@ -342,18 +337,35 @@ def resolve(self, prefix: str) -> LayerQuantSpec:
# 4. Global default
return self._parsed.global_spec

# -- backward compat ---------------------------------------------------
# -- scalar convenience properties ------------------------------------

# Read-only convenience accessor: forwards to the ``quant_type`` field of
# the global spec stored on the parsed config (``self._parsed.global_spec``).
@property
def quant_type(self) -> "QuantType":
return self._parsed.global_spec.quant_type

# Read-only convenience accessor: forwards to the ``quant_dtype`` field of
# the global spec stored on the parsed config (``self._parsed.global_spec``).
@property
def quant_dtype(self) -> torch.dtype:
return self._parsed.global_spec.quant_dtype

# Read-only convenience accessor: forwards to the ``is_dynamic`` flag of
# the global spec stored on the parsed config (``self._parsed.global_spec``).
@property
def is_dynamic(self) -> bool:
return self._parsed.global_spec.is_dynamic

# Read-only convenience accessor: forwards to the ``quant_method`` field of
# the global spec stored on the parsed config; annotated Optional, so
# callers should expect None as a possible value.
@property
def quant_method(self) -> Optional[str]:
return self._parsed.global_spec.quant_method

# -- named accessor ---------------------------------------------------

def get_name(self):
return self["quant_name"]
def get_name(self) -> str:
return self._quant_name

# -- internals ---------------------------------------------------------

def _is_excluded(self, prefix: str) -> bool:
"""Check whether *prefix* matches the exclude list.

Uses the same logic as the original ``should_ignore_layer``
in ``atom.models.utils`` so behaviour is identical.
Supports bare suffix, substring, and ``re:`` prefix for regex patterns.
"""
exclude_layers: list[str] = self._parsed.exclude_layers
if not exclude_layers:
Expand Down Expand Up @@ -384,12 +396,11 @@ def compute_hash(self) -> str:
the final hidden states.
"""
factors: list[Any] = []
factors.append(self["quant_type"])
factors.append(self["quant_dtype"])
factors.append(self["quant_name"])
factors.append(self["is_dynamic"])
factors.append(self["quant_method"])
# assert_hashable(str_factors)
factors.append(self.quant_type)
factors.append(self.quant_dtype)
factors.append(self._quant_name)
factors.append(self.is_dynamic)
factors.append(self.quant_method)
return hashlib.sha256(str(factors).encode()).hexdigest()


Expand Down
6 changes: 2 additions & 4 deletions atom/model_ops/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,8 @@ def __init__(
if quant_config is None:
quant_config = QuantizationConfig()

quant_type = quant_config["quant_type"]
params_dtype = quant_config["quant_dtype"]
self.quant_type = quant_type
self.params_dtype = params_dtype
self.quant_type = quant_config.global_spec.quant_type
self.params_dtype = quant_config.global_spec.quant_dtype

def forward_native(
self, x: torch.Tensor, x_scale: Optional[torch.Tensor] = None
Expand Down
6 changes: 2 additions & 4 deletions atom/model_ops/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,8 @@ def __init__(

if quant_config is None:
quant_config = QuantizationConfig()
quant_type = quant_config["quant_type"]
params_dtype = quant_config["quant_dtype"]
self.quant_type = quant_type
self.params_dtype = params_dtype
self.quant_type = quant_config.global_spec.quant_type
self.params_dtype = quant_config.global_spec.quant_dtype

def forward(
self,
Expand Down
12 changes: 6 additions & 6 deletions atom/model_ops/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,16 +205,16 @@ def __init__(
if quant_config is None:
quant_config = QuantizationConfig()

# --- New: prefer LayerQuantSpec if provided ---
# layer_spec is always provided by the linear subclasses via resolve()
if layer_spec is not None:
self._layer_spec = layer_spec
else:
# Build a LayerQuantSpec from old-style dict fields for compat
# Fallback: build from global_spec when no prefix was supplied
self._layer_spec = LayerQuantSpec(
quant_type=quant_config["quant_type"],
quant_dtype=quant_config["quant_dtype"],
is_dynamic=quant_config.get("is_dynamic", True),
quant_method=quant_config.get("quant_method", None),
quant_type=quant_config.global_spec.quant_type,
quant_dtype=quant_config.global_spec.quant_dtype,
is_dynamic=quant_config.global_spec.is_dynamic,
quant_method=quant_config.global_spec.quant_method,
checkpoint_dtype=source_quant_dtype,
)

Expand Down
87 changes: 42 additions & 45 deletions atom/model_ops/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from aiter.ops.shuffle import shuffle_scale_a16w4, shuffle_weight_a16w4
from aiter.utility import fp4_utils
from atom.config import Config, QuantizationConfig, get_current_atom_config
from atom.models.utils import get_quant_config_for_layer
from atom.model_loader.weight_utils import set_weight_attrs
from atom.model_ops.base_config import QuantizeMethodBase
from atom.model_ops.fused_moe.config import (
Expand Down Expand Up @@ -633,9 +632,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: QuantizationConfig, moe: FusedMoEConfig):
super().__init__(moe)
self.quant_config = quant_config
self.quant_type = self.quant_config["quant_type"]
self.quant_dtype = self.quant_config["quant_dtype"]
self.quant_method = self.quant_config["quant_method"]
self.quant_type = self.quant_config.global_spec.quant_type
self.quant_dtype = self.quant_config.global_spec.quant_dtype
self.quant_method = self.quant_config.global_spec.quant_method
self.block_quant = (
self.quant_type == QuantType.per_1x128
or self.quant_type == QuantType.per_1x32
Expand Down Expand Up @@ -967,8 +966,8 @@ class CompressedTensorsFp8MoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: QuantizationConfig, moe: FusedMoEConfig):
super().__init__(moe)
self.quant_config = quant_config
self.quant_type = quant_config["quant_type"]
self.quant_dtype = quant_config["quant_dtype"]
self.quant_type = quant_config.global_spec.quant_type
self.quant_dtype = quant_config.global_spec.quant_dtype

# Check if we need to normalize e4m3fn to e4m3fnuz (AMD GPUs)
self.need_normalize_e4m3fn_to_e4m3fnuz = (
Expand All @@ -985,7 +984,7 @@ def __init__(self, quant_config: QuantizationConfig, moe: FusedMoEConfig):
self.per_channel = self.quant_type == QuantType.per_Token

# Check if static input scales (activation quantization)
self.static_input_scales = not quant_config.get("is_dynamic", True)
self.static_input_scales = not quant_config.global_spec.is_dynamic

# Block sizes for block quantization
if self.block_quant:
Expand Down Expand Up @@ -1372,8 +1371,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: QuantizationConfig, moe: FusedMoEConfig):
super().__init__(moe)
self.quant_config = quant_config
self.quant_type = self.quant_config["quant_type"]
self.quant_dtype = self.quant_config["quant_dtype"]
self.quant_type = self.quant_config.global_spec.quant_type
self.quant_dtype = self.quant_config.global_spec.quant_dtype
self.block_quant = (
self.quant_type == QuantType.per_1x128
or self.quant_type == QuantType.per_1x32
Expand Down Expand Up @@ -1479,7 +1478,7 @@ def create_weights(
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
assert self.quant_config["is_dynamic"]
assert self.quant_config.global_spec.is_dynamic

# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
Expand All @@ -1494,7 +1493,7 @@ def create_weights(
set_weight_attrs(w2_weight_scale, extra_weight_attrs)

# INPUT_SCALES
if not self.quant_config["is_dynamic"]:
if not self.quant_config.global_spec.is_dynamic:
w13_input_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
Expand All @@ -1520,7 +1519,7 @@ def process_weights_after_loading(self, layer: nn.Module) -> None:

# TODO (rob): refactor block quant into separate class.
if self.block_quant:
assert self.quant_config["is_dynamic"]
assert self.quant_config.global_spec.is_dynamic
if self.need_normalize_e4m3fn_to_e4m3fnuz:
w13_weight, w13_weight_scale, w13_input_scale = (
normalize_e4m3fn_to_e4m3fnuz(
Expand Down Expand Up @@ -1550,7 +1549,7 @@ def process_weights_after_loading(self, layer: nn.Module) -> None:
else:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if not self.quant_config["is_dynamic"]:
if not self.quant_config.global_spec.is_dynamic:
if layer.w13_input_scale is None or layer.w2_input_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
Expand Down Expand Up @@ -1829,7 +1828,9 @@ def __init__(
super().__init__()
self.prefix = prefix
self.params_dtype = (
quant_config["quant_dtype"] if quant_config else torch.get_default_dtype()
quant_config.global_spec.quant_dtype
if quant_config
else torch.get_default_dtype()
)
self.quant_config = quant_config
self.has_bias = has_bias
Expand Down Expand Up @@ -1947,56 +1948,49 @@ def __init__(
self.moe_config = moe

if quant_config is not None and prefix:
quant_config = get_quant_config_for_layer(quant_config, prefix)
if quant_config is None:
quant_config = QuantizationConfig()

# Resolve per-layer quant spec so the dispatch below sees the
# correct dtype/type when per-layer overrides differ from the
# global config (e.g., MXFP4 globally but FP8 for MTP layers).
if hasattr(quant_config, "resolve") and prefix:
_spec = quant_config.resolve(prefix)
if _spec.is_quantized and (
_spec.quant_dtype != quant_config["quant_dtype"]
or _spec.quant_type != quant_config["quant_type"]
_resolved = quant_config.resolve(prefix)
if not _resolved.is_quantized:
quant_config = None
elif (
_resolved.quant_dtype != quant_config.global_spec.quant_dtype
or _resolved.quant_type != quant_config.global_spec.quant_type
):
# Per-layer override differs from global config (e.g., MXFP4
# globally but FP8 for MTP layers). Build a layer-specific
# QuantizationConfig so the dispatch below sees the correct
# dtype/type.
quant_config = QuantizationConfig(
quant_type=_spec.quant_type,
quant_dtype=_spec.quant_dtype,
quant_type=_resolved.quant_type,
quant_dtype=_resolved.quant_dtype,
# NOTE(review): the three ``.get(...)`` calls below are dict-style lookups.
# The QuantizationConfig class shown in this same change no longer
# subclasses ``dict``, so unless a ``get`` method is defined in a part of
# the class not visible here, these presumably raise AttributeError —
# confirm. Likely replacements, to be verified against the class:
# ``quant_config.global_spec.is_dynamic``, ``quant_config.get_name()``,
# and ``quant_config.global_spec.quant_method``.
is_dynamic=quant_config.get("is_dynamic", True),
quant_name=quant_config.get("quant_name", ""),
quant_method=quant_config.get("quant_method", None),
)
# Update instance attrs to match the (possibly resolved) config
self.quant_config = quant_config
self.params_dtype = quant_config["quant_dtype"]
if quant_config is not None:
self.quant_config = quant_config
self.params_dtype = quant_config.global_spec.quant_dtype

# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
quant_method_str = quant_config.get("quant_method", None)
if quant_config["quant_type"] == QuantType.No:
_gs = quant_config.global_spec if quant_config is not None else None
if _gs is None or _gs.quant_type == QuantType.No:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
moe
)
elif (
quant_method_str == "compressed-tensors"
and quant_config["quant_dtype"] == dtypes.fp8
):
elif _gs.quant_method == "compressed-tensors" and _gs.quant_dtype == dtypes.fp8:
# Use CompressedTensorsFp8MoEMethod for compressed-tensors format
self.quant_method = CompressedTensorsFp8MoEMethod(quant_config, moe)
elif (
quant_config["quant_dtype"] == dtypes.fp8
and quant_config["quant_type"] == QuantType.per_Token
):
elif _gs.quant_dtype == dtypes.fp8 and _gs.quant_type == QuantType.per_Token:
# Per-channel FP8 (e.g., Quark per_Token override for MTP layers)
# needs CompressedTensors-style weight scale handling
self.quant_method = CompressedTensorsFp8MoEMethod(quant_config, moe)
elif quant_config["quant_dtype"] == dtypes.fp8:
elif _gs.quant_dtype == dtypes.fp8:
self.quant_method = Fp8MoEMethod(quant_config, moe)
elif quant_config["quant_dtype"] == dtypes.fp4x2:
elif _gs.quant_dtype == dtypes.fp4x2:
self.quant_method = Mxfp4MoEMethod(quant_config, moe)
else:
raise ValueError(f"Unsupported quant dtype: {quant_config['quant_dtype']}")
raise ValueError(f"Unsupported quant dtype: {_gs.quant_dtype}")

assert self.quant_method is not None

Expand Down Expand Up @@ -2284,7 +2278,10 @@ def weight_loader(
shard_id: str = "",
expert_id: int = 0,
) -> None:
if self.quant_config["quant_dtype"] == dtypes.fp4x2 and weight_name == "":
if (
self.quant_config.global_spec.quant_dtype == dtypes.fp4x2
and weight_name == ""
):
self.mxf4_merged_weight_loader(param, loaded_weight)
return

Expand Down Expand Up @@ -2362,7 +2359,7 @@ def weight_loader(
# FusedMoeWeightScaleSupported
# TODO @dsikka: once hardened, refactor to use vLLM Parameters
# specific to each case
quant_method = self.quant_config["quant_type"]
quant_method = self.quant_config.global_spec.quant_type
if quant_method == QuantType.per_Token:
self._load_per_channel_weight_scale(
shard_id=shard_id,
Expand Down
2 changes: 1 addition & 1 deletion atom/model_ops/topK.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def is_rocm_aiter_fusion_shared_expert_enabled() -> bool:
quant_config = config.quant_config
is_shared_experts_excluded = False
is_experts_excluded = False
exclude_layers = quant_config["exclude_layers"]
exclude_layers = quant_config.parsed.exclude_layers
for layer in exclude_layers:
if "shared_experts" in layer:
is_shared_experts_excluded = True
Expand Down
2 changes: 1 addition & 1 deletion atom/models/deepseek_mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self, atom_config: Config, prefix: str, layer_idx: int) -> None:
)

quant_config = atom_config.quant_config
if quant_config["quant_dtype"] == dtypes.fp4x2:
if quant_config.global_spec.quant_dtype == dtypes.fp4x2:
quant_config = QuantizationConfig()

self.mtp_block = DeepseekV2DecoderLayer(
Expand Down
Loading