From 66e940bac189fe01404b1d1dd58300fd97281b9a Mon Sep 17 00:00:00 2001 From: Thiago Rocha Date: Thu, 5 Mar 2026 19:11:15 +0000 Subject: [PATCH] Phase 2: replace dict-style quant_config access with LayerQuantSpec across all models and ops --- atom/config.py | 73 ++++++++++++++++------------- atom/model_ops/activation.py | 6 +-- atom/model_ops/layernorm.py | 6 +-- atom/model_ops/linear.py | 12 ++--- atom/model_ops/moe.py | 87 +++++++++++++++++------------------ atom/model_ops/topK.py | 2 +- atom/models/deepseek_mtp.py | 2 +- atom/models/deepseek_v2.py | 27 +++++++---- atom/models/llama.py | 4 +- atom/models/utils.py | 51 -------------------- tests/test_per_layer_quant.py | 39 ++++++++-------- 11 files changed, 135 insertions(+), 174 deletions(-) diff --git a/atom/config.py b/atom/config.py index 094813d6c..81bbc390d 100644 --- a/atom/config.py +++ b/atom/config.py @@ -255,14 +255,13 @@ def set_splitting_ops_for_v1(self): ] -class QuantizationConfig(dict): +class QuantizationConfig: """Model-wide quantization configuration. - Still inherits from dict for backward compatibility with existing code - that accesses ``quant_config["quant_type"]``, etc. - - New code should prefer the :pyattr:`parsed` attribute and - :pymeth:`resolve` method. + The primary API is :pymeth:`resolve` and the :pyattr:`parsed` / + :pyattr:`global_spec` attributes. Scalar convenience properties + (``quant_type``, ``quant_dtype``, ``is_dynamic``, ``quant_method``) + delegate to ``global_spec``. 
""" def __init__( @@ -276,15 +275,9 @@ def __init__( *, parsed: Optional[ParsedQuantConfig] = None, ): - super().__init__() - self["quant_type"] = quant_type if quant_type is not None else QuantType.No - self["quant_dtype"] = quant_dtype if quant_dtype is not None else torch.bfloat16 - self["quant_name"] = quant_name - self["is_dynamic"] = is_dynamic - self["quant_method"] = quant_method - self["exclude_layers"] = exclude_layers if exclude_layers is not None else [] - - # --- New: structured parsed config --- + self._quant_name = quant_name + + # --- Structured parsed config --- if parsed is not None: self._parsed = parsed else: @@ -292,12 +285,14 @@ def __init__( # manually-constructed QuantizationConfigs still work. self._parsed = ParsedQuantConfig( global_spec=LayerQuantSpec( - quant_type=self["quant_type"], - quant_dtype=self["quant_dtype"], - is_dynamic=self["is_dynamic"], - quant_method=self["quant_method"], + quant_type=quant_type if quant_type is not None else QuantType.No, + quant_dtype=( + quant_dtype if quant_dtype is not None else torch.bfloat16 + ), + is_dynamic=is_dynamic, + quant_method=quant_method, ), - exclude_layers=self["exclude_layers"], + exclude_layers=exclude_layers if exclude_layers is not None else [], ) # -- public API -------------------------------------------------------- @@ -342,18 +337,35 @@ def resolve(self, prefix: str) -> LayerQuantSpec: # 4. 
Global default return self._parsed.global_spec - # -- backward compat --------------------------------------------------- + # -- scalar convenience properties ------------------------------------ + + @property + def quant_type(self) -> "QuantType": + return self._parsed.global_spec.quant_type + + @property + def quant_dtype(self) -> torch.dtype: + return self._parsed.global_spec.quant_dtype + + @property + def is_dynamic(self) -> bool: + return self._parsed.global_spec.is_dynamic + + @property + def quant_method(self) -> Optional[str]: + return self._parsed.global_spec.quant_method + + # -- named accessor --------------------------------------------------- - def get_name(self): - return self["quant_name"] + def get_name(self) -> str: + return self._quant_name # -- internals --------------------------------------------------------- def _is_excluded(self, prefix: str) -> bool: """Check whether *prefix* matches the exclude list. - Uses the same logic as the original ``should_ignore_layer`` - in ``atom.models.utils`` so behaviour is identical. + Supports bare suffix, substring, and ``re:`` prefix for regex patterns. """ exclude_layers: list[str] = self._parsed.exclude_layers if not exclude_layers: @@ -384,12 +396,11 @@ def compute_hash(self) -> str: the final hidden states. 
""" factors: list[Any] = [] - factors.append(self["quant_type"]) - factors.append(self["quant_dtype"]) - factors.append(self["quant_name"]) - factors.append(self["is_dynamic"]) - factors.append(self["quant_method"]) - # assert_hashable(str_factors) + factors.append(self.quant_type) + factors.append(self.quant_dtype) + factors.append(self._quant_name) + factors.append(self.is_dynamic) + factors.append(self.quant_method) return hashlib.sha256(str(factors).encode()).hexdigest() diff --git a/atom/model_ops/activation.py b/atom/model_ops/activation.py index db5d8033d..b8715f1f7 100644 --- a/atom/model_ops/activation.py +++ b/atom/model_ops/activation.py @@ -65,10 +65,8 @@ def __init__( if quant_config is None: quant_config = QuantizationConfig() - quant_type = quant_config["quant_type"] - params_dtype = quant_config["quant_dtype"] - self.quant_type = quant_type - self.params_dtype = params_dtype + self.quant_type = quant_config.global_spec.quant_type + self.params_dtype = quant_config.global_spec.quant_dtype def forward_native( self, x: torch.Tensor, x_scale: Optional[torch.Tensor] = None diff --git a/atom/model_ops/layernorm.py b/atom/model_ops/layernorm.py index ce9cfe1ec..22fbda91d 100644 --- a/atom/model_ops/layernorm.py +++ b/atom/model_ops/layernorm.py @@ -189,10 +189,8 @@ def __init__( if quant_config is None: quant_config = QuantizationConfig() - quant_type = quant_config["quant_type"] - params_dtype = quant_config["quant_dtype"] - self.quant_type = quant_type - self.params_dtype = params_dtype + self.quant_type = quant_config.global_spec.quant_type + self.params_dtype = quant_config.global_spec.quant_dtype def forward( self, diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index 7f3de11b5..d4aa0e117 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -205,16 +205,16 @@ def __init__( if quant_config is None: quant_config = QuantizationConfig() - # --- New: prefer LayerQuantSpec if provided --- + # layer_spec is always provided 
by the linear subclasses via resolve() if layer_spec is not None: self._layer_spec = layer_spec else: - # Build a LayerQuantSpec from old-style dict fields for compat + # Fallback: build from global_spec when no prefix was supplied self._layer_spec = LayerQuantSpec( - quant_type=quant_config["quant_type"], - quant_dtype=quant_config["quant_dtype"], - is_dynamic=quant_config.get("is_dynamic", True), - quant_method=quant_config.get("quant_method", None), + quant_type=quant_config.global_spec.quant_type, + quant_dtype=quant_config.global_spec.quant_dtype, + is_dynamic=quant_config.global_spec.is_dynamic, + quant_method=quant_config.global_spec.quant_method, checkpoint_dtype=source_quant_dtype, ) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index ca02116fc..4d4e0d2e7 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -15,7 +15,6 @@ from aiter.ops.shuffle import shuffle_scale_a16w4, shuffle_weight_a16w4 from aiter.utility import fp4_utils from atom.config import Config, QuantizationConfig, get_current_atom_config -from atom.models.utils import get_quant_config_for_layer from atom.model_loader.weight_utils import set_weight_attrs from atom.model_ops.base_config import QuantizeMethodBase from atom.model_ops.fused_moe.config import ( @@ -633,9 +632,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: QuantizationConfig, moe: FusedMoEConfig): super().__init__(moe) self.quant_config = quant_config - self.quant_type = self.quant_config["quant_type"] - self.quant_dtype = self.quant_config["quant_dtype"] - self.quant_method = self.quant_config["quant_method"] + self.quant_type = self.quant_config.global_spec.quant_type + self.quant_dtype = self.quant_config.global_spec.quant_dtype + self.quant_method = self.quant_config.global_spec.quant_method self.block_quant = ( self.quant_type == QuantType.per_1x128 or self.quant_type == QuantType.per_1x32 @@ -967,8 +966,8 @@ class CompressedTensorsFp8MoEMethod(FusedMoEMethodBase): def 
__init__(self, quant_config: QuantizationConfig, moe: FusedMoEConfig): super().__init__(moe) self.quant_config = quant_config - self.quant_type = quant_config["quant_type"] - self.quant_dtype = quant_config["quant_dtype"] + self.quant_type = quant_config.global_spec.quant_type + self.quant_dtype = quant_config.global_spec.quant_dtype # Check if we need to normalize e4m3fn to e4m3fnuz (AMD GPUs) self.need_normalize_e4m3fn_to_e4m3fnuz = ( @@ -985,7 +984,7 @@ def __init__(self, quant_config: QuantizationConfig, moe: FusedMoEConfig): self.per_channel = self.quant_type == QuantType.per_Token # Check if static input scales (activation quantization) - self.static_input_scales = not quant_config.get("is_dynamic", True) + self.static_input_scales = not quant_config.global_spec.is_dynamic # Block sizes for block quantization if self.block_quant: @@ -1372,8 +1371,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: QuantizationConfig, moe: FusedMoEConfig): super().__init__(moe) self.quant_config = quant_config - self.quant_type = self.quant_config["quant_type"] - self.quant_dtype = self.quant_config["quant_dtype"] + self.quant_type = self.quant_config.global_spec.quant_type + self.quant_dtype = self.quant_config.global_spec.quant_dtype self.block_quant = ( self.quant_type == QuantType.per_1x128 or self.quant_type == QuantType.per_1x32 @@ -1479,7 +1478,7 @@ def create_weights( ) layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) - assert self.quant_config["is_dynamic"] + assert self.quant_config.global_spec.is_dynamic # Add the quantization method used (per tensor/grouped/channel) # to ensure the weight scales are loaded in properly @@ -1494,7 +1493,7 @@ def create_weights( set_weight_attrs(w2_weight_scale, extra_weight_attrs) # INPUT_SCALES - if not self.quant_config["is_dynamic"]: + if not self.quant_config.global_spec.is_dynamic: w13_input_scale = torch.nn.Parameter( 
torch.ones(num_experts, dtype=torch.float32), requires_grad=False ) @@ -1520,7 +1519,7 @@ def process_weights_after_loading(self, layer: nn.Module) -> None: # TODO (rob): refactor block quant into separate class. if self.block_quant: - assert self.quant_config["is_dynamic"] + assert self.quant_config.global_spec.is_dynamic if self.need_normalize_e4m3fn_to_e4m3fnuz: w13_weight, w13_weight_scale, w13_input_scale = ( normalize_e4m3fn_to_e4m3fnuz( @@ -1550,7 +1549,7 @@ def process_weights_after_loading(self, layer: nn.Module) -> None: else: # Fp8 moe kernels require a single activation scale. # We take the max of all the scales in case they differ. - if not self.quant_config["is_dynamic"]: + if not self.quant_config.global_spec.is_dynamic: if layer.w13_input_scale is None or layer.w2_input_scale is None: raise ValueError( "QuantConfig has static quantization, but found " @@ -1829,7 +1828,9 @@ def __init__( super().__init__() self.prefix = prefix self.params_dtype = ( - quant_config["quant_dtype"] if quant_config else torch.get_default_dtype() + quant_config.global_spec.quant_dtype + if quant_config + else torch.get_default_dtype() ) self.quant_config = quant_config self.has_bias = has_bias @@ -1947,56 +1948,49 @@ def __init__( self.moe_config = moe if quant_config is not None and prefix: - quant_config = get_quant_config_for_layer(quant_config, prefix) - if quant_config is None: - quant_config = QuantizationConfig() - - # Resolve per-layer quant spec so the dispatch below sees the - # correct dtype/type when per-layer overrides differ from the - # global config (e.g., MXFP4 globally but FP8 for MTP layers). 
-        if hasattr(quant_config, "resolve") and prefix:
-            _spec = quant_config.resolve(prefix)
-            if _spec.is_quantized and (
-                _spec.quant_dtype != quant_config["quant_dtype"]
-                or _spec.quant_type != quant_config["quant_type"]
+            _resolved = quant_config.resolve(prefix)
+            if not _resolved.is_quantized:
+                quant_config = None
+            elif (
+                _resolved.quant_dtype != quant_config.global_spec.quant_dtype
+                or _resolved.quant_type != quant_config.global_spec.quant_type
             ):
+                # Per-layer override differs from global config (e.g., MXFP4
+                # globally but FP8 for MTP layers). Build a layer-specific
+                # QuantizationConfig so the dispatch below sees the correct
+                # dtype/type.
                 quant_config = QuantizationConfig(
-                    quant_type=_spec.quant_type,
-                    quant_dtype=_spec.quant_dtype,
-                    is_dynamic=quant_config.get("is_dynamic", True),
-                    quant_name=quant_config.get("quant_name", ""),
-                    quant_method=quant_config.get("quant_method", None),
+                    quant_type=_resolved.quant_type,
+                    quant_dtype=_resolved.quant_dtype,
+                    is_dynamic=quant_config.global_spec.is_dynamic,
+                    quant_name=quant_config.get_name(),
+                    quant_method=quant_config.global_spec.quant_method,
                 )
 
         # Update instance attrs to match the (possibly resolved) config
-        self.quant_config = quant_config
-        self.params_dtype = quant_config["quant_dtype"]
+        if quant_config is not None:
+            self.quant_config = quant_config
+            self.params_dtype = quant_config.global_spec.quant_dtype
 
         # Note: get_quant_method will look at the layer's local_num_experts
         # for heuristic purposes, so it must be initialized first.
- quant_method_str = quant_config.get("quant_method", None) - if quant_config["quant_type"] == QuantType.No: + _gs = quant_config.global_spec if quant_config is not None else None + if _gs is None or _gs.quant_type == QuantType.No: self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod( moe ) - elif ( - quant_method_str == "compressed-tensors" - and quant_config["quant_dtype"] == dtypes.fp8 - ): + elif _gs.quant_method == "compressed-tensors" and _gs.quant_dtype == dtypes.fp8: # Use CompressedTensorsFp8MoEMethod for compressed-tensors format self.quant_method = CompressedTensorsFp8MoEMethod(quant_config, moe) - elif ( - quant_config["quant_dtype"] == dtypes.fp8 - and quant_config["quant_type"] == QuantType.per_Token - ): + elif _gs.quant_dtype == dtypes.fp8 and _gs.quant_type == QuantType.per_Token: # Per-channel FP8 (e.g., Quark per_Token override for MTP layers) # needs CompressedTensors-style weight scale handling self.quant_method = CompressedTensorsFp8MoEMethod(quant_config, moe) - elif quant_config["quant_dtype"] == dtypes.fp8: + elif _gs.quant_dtype == dtypes.fp8: self.quant_method = Fp8MoEMethod(quant_config, moe) - elif quant_config["quant_dtype"] == dtypes.fp4x2: + elif _gs.quant_dtype == dtypes.fp4x2: self.quant_method = Mxfp4MoEMethod(quant_config, moe) else: - raise ValueError(f"Unsupported quant dtype: {quant_config['quant_dtype']}") + raise ValueError(f"Unsupported quant dtype: {_gs.quant_dtype}") assert self.quant_method is not None @@ -2284,7 +2278,10 @@ def weight_loader( shard_id: str = "", expert_id: int = 0, ) -> None: - if self.quant_config["quant_dtype"] == dtypes.fp4x2 and weight_name == "": + if ( + self.quant_config.global_spec.quant_dtype == dtypes.fp4x2 + and weight_name == "" + ): self.mxf4_merged_weight_loader(param, loaded_weight) return @@ -2362,7 +2359,7 @@ def weight_loader( # FusedMoeWeightScaleSupported # TODO @dsikka: once hardened, refactor to use vLLM Parameters # specific to each case - quant_method = 
self.quant_config["quant_type"] + quant_method = self.quant_config.global_spec.quant_type if quant_method == QuantType.per_Token: self._load_per_channel_weight_scale( shard_id=shard_id, diff --git a/atom/model_ops/topK.py b/atom/model_ops/topK.py index 5103f9256..50d8eb0c3 100644 --- a/atom/model_ops/topK.py +++ b/atom/model_ops/topK.py @@ -19,7 +19,7 @@ def is_rocm_aiter_fusion_shared_expert_enabled() -> bool: quant_config = config.quant_config is_shared_experts_excluded = False is_experts_excluded = False - exclude_layers = quant_config["exclude_layers"] + exclude_layers = quant_config.parsed.exclude_layers for layer in exclude_layers: if "shared_experts" in layer: is_shared_experts_excluded = True diff --git a/atom/models/deepseek_mtp.py b/atom/models/deepseek_mtp.py index 5394d49e5..e03a9c344 100644 --- a/atom/models/deepseek_mtp.py +++ b/atom/models/deepseek_mtp.py @@ -56,7 +56,7 @@ def __init__(self, atom_config: Config, prefix: str, layer_idx: int) -> None: ) quant_config = atom_config.quant_config - if quant_config["quant_dtype"] == dtypes.fp4x2: + if quant_config.global_spec.quant_dtype == dtypes.fp4x2: quant_config = QuantizationConfig() self.mtp_block = DeepseekV2DecoderLayer( diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index f0342dce1..b99adecc2 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -82,7 +82,6 @@ make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, - should_ignore_layer, ) from atom.utils import envs from atom.utils.custom_register import direct_register_custom_op @@ -1247,9 +1246,14 @@ def __init__( # For FP4 and use_triton_gemm(), fused_qkv_a_proj and q_b_proj are AITER-Triton FP4 GEMMs but o_proj remains AITER BF16 GEMMs, # For FP8 and use_triton_gemm(), fused_qkv_a_proj is AITER-Triton FP8 GEMMs while others remain AITER FP8 GEMMs - if quant_config["quant_dtype"] == dtypes.fp4x2: + _attn_quant_dtype = ( + quant_config.resolve(prefix).quant_dtype + if quant_config is not 
None + else None + ) + if _attn_quant_dtype == dtypes.fp4x2: # normally linear layers in attn share the same quant config - if should_ignore_layer(quant_config, prefix): + if not quant_config.resolve(prefix).is_quantized: source_quant_dtype = None quant_config = None base_quant_config = None @@ -1407,10 +1411,11 @@ def __init__( self.quant_dtype = None self.fuse_qknorm_quant = False if quant_config is not None and ENABLE_DS_QKNORM_QUANT_FUSION: - if quant_config["quant_dtype"] == dtypes.fp8 or ( - quant_config["quant_dtype"] == dtypes.fp4x2 and use_triton_gemm() + _qkn_dtype = quant_config.resolve(prefix).quant_dtype + if _qkn_dtype == dtypes.fp8 or ( + _qkn_dtype == dtypes.fp4x2 and use_triton_gemm() ): - self.quant_dtype = quant_config["quant_dtype"] + self.quant_dtype = _qkn_dtype self.fuse_qknorm_quant = True def forward( @@ -1553,11 +1558,11 @@ def __init__( self.fuse_input_norm_quant = False self.fuse_ar_input_norm = ENABLE_ALLREDUCE_RMSNORM_FUSION if quant_config is not None and ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION: + _input_norm_dtype = quant_config.global_spec.quant_dtype if ( - quant_config["quant_dtype"] == dtypes.fp8 - or quant_config["quant_dtype"] == dtypes.fp4x2 + _input_norm_dtype == dtypes.fp8 or _input_norm_dtype == dtypes.fp4x2 ) and use_triton_gemm(): - self.quant_dtype = quant_config["quant_dtype"] + self.quant_dtype = _input_norm_dtype self.fuse_input_norm_quant = True if self.fuse_ar_input_norm: self.fuse_ar_input_norm = False @@ -1604,7 +1609,9 @@ def __init__( fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION, ) self.routed_scaling_factor = config.routed_scaling_factor - self.quant_dtype = quant_config["quant_dtype"] if quant_config else None + self.quant_dtype = ( + quant_config.global_spec.quant_dtype if quant_config else None + ) self.fuse_rmsnorm_quant = ( ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION and self.quant_dtype is not None ) diff --git a/atom/models/llama.py b/atom/models/llama.py index 4f485de69..df77c6a34 100644 --- 
a/atom/models/llama.py +++ b/atom/models/llama.py @@ -99,7 +99,7 @@ def __init__( self.act_fn = SiluAndMul( fused_quant=self.fused_act_quant, quant_config=quant_config ) - self.quant_type = quant_config["quant_type"] + self.quant_type = quant_config.global_spec.quant_type def forward(self, x, x_scale: Optional[torch.Tensor] = None): x = self.gate_up_proj(x, x_scale=x_scale) @@ -271,7 +271,7 @@ def __init__( ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_RMSNORM_QUANT ) - self.quant_type = quant_config["quant_type"] + self.quant_type = quant_config.global_spec.quant_type self.self_attn = LlamaAttention( config=config, diff --git a/atom/models/utils.py b/atom/models/utils.py index 8b61a39e7..8002cbe74 100644 --- a/atom/models/utils.py +++ b/atom/models/utils.py @@ -7,14 +7,10 @@ Protocol, Tuple, Union, - Optional, ) import torch import os -import re - -from atom.config import QuantizationConfig import logging @@ -237,53 +233,6 @@ def fast_topk(values, topk, dim): return torch.topk(values, topk, dim=dim) -def should_ignore_layer( - quantization_config: Optional[QuantizationConfig], prefix: str -) -> bool: - """Check whether *prefix* should skip quantization. - - Delegates to ``QuantizationConfig.resolve()`` when available (the new - ``LayerQuantSpec``-based path). Falls back to the legacy exclude-list - scan for plain-dict configs. 
- """ - if quantization_config is None: - return True - # New path: use resolve() if available - if hasattr(quantization_config, "resolve"): - spec = quantization_config.resolve(prefix) - return not spec.is_quantized - # Legacy fallback - exclude_layers: List[str] = quantization_config.get("exclude_layers", []) - if not exclude_layers: - return False - for exclude_layer in exclude_layers: - if exclude_layer.startswith("re"): - regex_pattern = exclude_layer[3:] - if re.search(regex_pattern, prefix): - return True - elif prefix in exclude_layer: - return True - else: - if prefix.split(".")[-1] == exclude_layer: - return True - return False - - -def get_quant_config_for_layer( - quantization_config: Optional[QuantizationConfig], prefix: str -) -> Optional[QuantizationConfig]: - """Return *quantization_config* if *prefix* should be quantized, else None. - - This is the legacy helper — new code should prefer - ``quant_config.resolve(prefix)`` directly. - """ - return ( - None - if should_ignore_layer(quantization_config, prefix) - else quantization_config - ) - - def extract_layer_index(layer_name: str, num_attn_module: int = 1) -> int: """ Extract the layer index from the module name. diff --git a/tests/test_per_layer_quant.py b/tests/test_per_layer_quant.py index 577033e4e..42c6c4b4b 100644 --- a/tests/test_per_layer_quant.py +++ b/tests/test_per_layer_quant.py @@ -9,7 +9,7 @@ - Parser registry: registration, dispatch, fallback - QuarkParser / CompressedTensorsParser / GenericParser: parsing logic - QuantizationConfig.resolve(): exclude-list resolution, per-layer overrides -- Backward compatibility: dict access, should_ignore_layer, get_quant_config_for_layer +- QuantizationConfig scalar properties: quant_type, quant_dtype, is_dynamic, quant_method - LinearBase: layer_spec parameter plumbing """ @@ -292,10 +292,11 @@ def test_resolve_no_quant_config(self): # ==================================================================== -# 7. Backward compatibility +# 7. 
Scalar properties and resolve API # ==================================================================== class TestBackwardCompat(unittest.TestCase): - def test_dict_access(self): + def test_scalar_properties(self): + """QuantizationConfig exposes scalar properties that delegate to global_spec.""" from atom.config import QuantizationConfig qc = QuantizationConfig( @@ -304,36 +305,36 @@ def test_dict_access(self): quant_method="quark", exclude_layers=["lm_head"], ) - self.assertEqual(qc["quant_type"], QuantType.per_1x32) - self.assertEqual(qc["quant_dtype"], torch.float4_e2m1fn_x2) - self.assertEqual(qc["quant_method"], "quark") - self.assertEqual(qc["exclude_layers"], ["lm_head"]) + self.assertEqual(qc.quant_type, QuantType.per_1x32) + self.assertEqual(qc.quant_dtype, torch.float4_e2m1fn_x2) + self.assertEqual(qc.quant_method, "quark") + self.assertEqual(qc.parsed.exclude_layers, ["lm_head"]) - def test_should_ignore_layer_uses_resolve(self): + def test_resolve_excluded_layer(self): + """resolve() returns no_quant spec for excluded layers.""" from atom.config import QuantizationConfig - from atom.models.utils import should_ignore_layer qc = QuantizationConfig( quant_type=QuantType.per_1x32, quant_dtype=torch.float4_e2m1fn_x2, exclude_layers=["lm_head"], ) - self.assertTrue(should_ignore_layer(qc, "lm_head")) - self.assertFalse(should_ignore_layer(qc, "model.layers.0.mlp.down_proj")) + self.assertFalse(qc.resolve("lm_head").is_quantized) + self.assertTrue(qc.resolve("model.layers.0.mlp.down_proj").is_quantized) - def test_get_quant_config_for_layer(self): + def test_resolve_replaces_legacy_helpers(self): + """resolve() subsumes what get_quant_config_for_layer used to return.""" from atom.config import QuantizationConfig - from atom.models.utils import get_quant_config_for_layer qc = QuantizationConfig( quant_type=QuantType.per_1x32, quant_dtype=torch.float4_e2m1fn_x2, exclude_layers=["lm_head"], ) - self.assertIsNone(get_quant_config_for_layer(qc, "lm_head")) - 
self.assertIs( - get_quant_config_for_layer(qc, "model.layers.0.mlp.down_proj"), qc - ) + excluded_spec = qc.resolve("lm_head") + included_spec = qc.resolve("model.layers.0.mlp.down_proj") + self.assertFalse(excluded_spec.is_quantized) + self.assertTrue(included_spec.is_quantized) def test_parsed_property(self): from atom.config import QuantizationConfig @@ -416,8 +417,8 @@ def test_layer_spec_with_checkpoint_dtype(self, mock_tp): self.assertTrue(lb._layer_spec.needs_online_quant) @patch("atom.model_ops.linear.get_tp_group") - def test_no_layer_spec_builds_from_dict(self, mock_tp): - """When layer_spec is not provided, LinearBase builds one from the dict.""" + def test_no_layer_spec_builds_from_global_spec(self, mock_tp): + """When layer_spec is not provided, LinearBase builds one from global_spec.""" mock_group = MagicMock() mock_group.rank_in_group = 0 mock_group.world_size = 1