diff --git a/atom/config.py b/atom/config.py index e273fb73c..ce7e1e4f0 100644 --- a/atom/config.py +++ b/atom/config.py @@ -7,7 +7,7 @@ import os import re from dataclasses import dataclass, field -from typing import Any, Optional, Union +from typing import Any, cast, Optional, Union import torch from aiter import QuantType @@ -250,126 +250,246 @@ def set_splitting_ops_for_v1(self): ] -class QuantizationConfig(dict): +class LayerQuantConfig(dict): def __init__( self, quant_type=QuantType.No, quant_dtype=torch.bfloat16, is_dynamic=True, - quant_name="", quant_method=None, - exclude_layers: Optional[list[str]] = None, + quant_name="", ): + """ + Core components of layer_quant + """ super().__init__() self["quant_type"] = quant_type if quant_type is not None else QuantType.No self["quant_dtype"] = quant_dtype if quant_dtype is not None else torch.bfloat16 - self["quant_name"] = quant_name self["is_dynamic"] = is_dynamic self["quant_method"] = quant_method - self["exclude_layers"] = exclude_layers if exclude_layers is not None else [] + self["quant_name"] = quant_name - def get_name(self): - return self["quant_name"] - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. 
+class QuantizationConfig: + def __init__(self, config: PretrainedConfig = None): + if config is None: + self.torch_dtype = torch.bfloat16 + self.hf_quant_config = None + self.global_quant_config = LayerQuantConfig() + self.layer_quant_config = {} + self.exclude_layers = [] + self.quant_method = "" + return + + self.torch_dtype = getattr(config, "torch_dtype", "bf16") + self.hf_quant_config = getattr(config, "quantization_config", None) + self.global_quant_config = None + self.layer_quant_config = {} + self.exclude_layers = [] + + if self.hf_quant_config is None: + self.global_quant_config = LayerQuantConfig( + quant_type=QuantType.No, quant_dtype=self.torch_dtype + ) + self.quant_method = None + return - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - factors: list[Any] = [] - factors.append(self["quant_type"]) - factors.append(self["quant_dtype"]) - factors.append(self["quant_name"]) - factors.append(self["is_dynamic"]) - factors.append(self["quant_method"]) - # assert_hashable(str_factors) - return hashlib.sha256(str(factors).encode()).hexdigest() + self.quant_method = self.hf_quant_config.get("quant_method", "") + if self.quant_method == "quark": + layer_quant_config_dict = cast( + dict[str, Any], self.hf_quant_config.get("layer_quant_config") + ) + for layer_name, layer_cfg in layer_quant_config_dict.items(): + self.layer_quant_config[layer_name] = self.parse_quark_config_dict( + layer_cfg + ) + global_quant_config_dict = cast( + dict[str, Any], self.hf_quant_config.get("global_quant_config") + ) + self.global_quant_config = self.parse_quark_config_dict( + global_quant_config_dict + ) -def get_quant_config(config: PretrainedConfig) -> QuantizationConfig: - torch_dtype = getattr(config, "dtype", "bf16") - orig_quant_config = getattr(config, 
"quantization_config", None) - if orig_quant_config is None: - return QuantizationConfig( - quant_type=QuantType.No, - quant_dtype=torch_dtype, - ) + self.exclude_layers = cast(list[str], self.hf_quant_config.get("exclude")) + else: + self.parse_other_config() - quant_method = orig_quant_config.get("quant_method", None) - RE_QUANT_BLOCKSIZE = r"\'(?:group_size|weight_block_size)\'\:\s*(?:\[\n*)\s*(\d+)," - orig_quant_config_str = str(orig_quant_config) - if quant_method == "compressed-tensors" or "channel'," in orig_quant_config_str: - quant_type = QuantType.per_Token - elif group_size := re.search(RE_QUANT_BLOCKSIZE, orig_quant_config_str): - group_size = int(group_size.group(1)) - assert group_size in (32, 128), f"Unsupported group size {group_size}" - if group_size == 128: - quant_type = QuantType.per_1x128 - elif group_size == 32: + def get_name(self): + """ + from original quant_config func + """ + return self.quant_method + + def parse_quark_config_dict(self, config: dict) -> LayerQuantConfig: + quant_type = None + quant_dtype = None + is_dynamic = False + # parse quark config dict + weight_config = cast(dict[str, Any], config.get("weight")) + input_config = cast(dict[str, Any], config.get("input_tensors")) + weight_qscheme = cast(str, weight_config.get("qscheme")) + weight_dtype = weight_config.get("dtype") + + # quant_type + if weight_qscheme == "per_channel": + quant_type = QuantType.per_Token + elif weight_qscheme == "per_tensor": + quant_type = QuantType.per_Tensor + elif weight_qscheme == "per_group": + # Currently, quark only supports group_size=32 quant_type = QuantType.per_1x32 - else: - quant_type = QuantType.per_Tensor - - RE_QUANT_DTYPE = r"\'(?:d?type|weight_dtype|quant_method)\'\:\s*\'(\w+)\'" - quant_dtype = None - m = re.search(RE_QUANT_DTYPE, orig_quant_config_str) - if m and m.group(1).lower() in ["fp8", "fp4", "int8", "int4", "fp8_e4m3", "mxfp4"]: - dtype = m.group(1).lower().split("_")[0] - if dtype == "mxfp4": - dtype = "fp4" + # 
quant_dtype + dtype = weight_dtype.split("_")[0] if dtype.endswith("4"): dtype += "x2" quant_dtype = d_dtypes[dtype] - else: - bit_match = re.search(r"\'(?:num_)?bits\'\:\s*(\d+)", orig_quant_config_str) - if bit_match: - bit = int(bit_match.group(1)) - dtype_match = re.search(RE_QUANT_DTYPE, orig_quant_config_str) - if dtype_match: - dtype = dtype_match.group(1).lower() - dtype_prefix = "i" if dtype.startswith("int") else "fp" - else: - dtype_prefix = "i" - quant_dtype_str = ( - f"{dtype_prefix}{bit}" if bit != 4 else f"{dtype_prefix}{bit}x2" + + # is_dynamic + if input_config is not None: + # input_dtype = input_config.get("dtype") + # input_qscheme = cast(str, input_config.get("qscheme")) + is_dynamic = not cast(bool, input_config.get("is_dynamic")) + return LayerQuantConfig( + quant_type=quant_type, + quant_dtype=quant_dtype, + is_dynamic=is_dynamic, + quant_method="quark", + ) + + # TODO: For now, it's just a temporary migration. + # We should subsequently refine them in a targeted manner. 
+ def parse_other_config(self): + RE_QUANT_BLOCKSIZE = ( + r"\'(?:group_size|weight_block_size)\'\:\s*(?:\[\n*)\s*(\d+)," + ) + orig_quant_config = self.hf_quant_config + quant_method = self.quant_method + orig_quant_config_str = str(orig_quant_config) + if quant_method == "compressed-tensors" or "channel'," in orig_quant_config_str: + quant_type = QuantType.per_Token + elif group_size := re.search(RE_QUANT_BLOCKSIZE, orig_quant_config_str): + group_size = int(group_size.group(1)) + assert group_size in (32, 128), f"Unsupported group size {group_size}" + if group_size == 128: + quant_type = QuantType.per_1x128 + elif group_size == 32: + quant_type = QuantType.per_1x32 + else: + quant_type = QuantType.per_Tensor + + RE_QUANT_DTYPE = r"\'(?:d?type|weight_dtype|quant_method)\'\:\s*\'(\w+)\'" + quant_dtype = None + m = re.search(RE_QUANT_DTYPE, orig_quant_config_str) + if m and m.group(1).lower() in [ + "fp8", + "fp4", + "int8", + "int4", + "fp8_e4m3", + "mxfp4", + ]: + dtype = m.group(1).lower().split("_")[0] + if dtype == "mxfp4": + dtype = "fp4" + if dtype.endswith("4"): + dtype += "x2" + quant_dtype = d_dtypes[dtype] + else: + bit_match = re.search(r"\'(?:num_)?bits\'\:\s*(\d+)", orig_quant_config_str) + if bit_match: + bit = int(bit_match.group(1)) + dtype_match = re.search(RE_QUANT_DTYPE, orig_quant_config_str) + if dtype_match: + dtype = dtype_match.group(1).lower() + dtype_prefix = "i" if dtype.startswith("int") else "fp" + else: + dtype_prefix = "i" + quant_dtype_str = ( + f"{dtype_prefix}{bit}" if bit != 4 else f"{dtype_prefix}{bit}x2" + ) + quant_dtype = d_dtypes.get(quant_dtype_str, None) + assert ( + quant_dtype is not None + ), f"Cannot parse quant dtype from {orig_quant_config_str}" + if quant_dtype == d_dtypes["fp4x2"]: + quant_type = QuantType.per_1x32 + + RE_STATIC_QUANT = r"\'(?:activation_scheme)\'\:\s*\'(static)\'" + if re.search(RE_STATIC_QUANT, orig_quant_config_str): + is_dynamic = False + else: + is_dynamic = True + if quant_method == 
"compressed-tensors": + exclude_layers_key = "ignore" + elif quant_method == "quark": + exclude_layers_key = "exclude" + else: + logger.warning( + f"Using 'ignore' as key for exclude layers with quant_method {quant_method}, \ + please double check the quantization config." ) - quant_dtype = d_dtypes.get(quant_dtype_str, None) - assert ( - quant_dtype is not None - ), f"Cannot parse quant dtype from {orig_quant_config_str}" - if quant_dtype == d_dtypes["fp4x2"]: - quant_type = QuantType.per_1x32 - - RE_STATIC_QUANT = r"\'(?:activation_scheme)\'\:\s*\'(static)\'" - if re.search(RE_STATIC_QUANT, orig_quant_config_str): - is_dynamic = False - else: - is_dynamic = True - if quant_method == "compressed-tensors": - exclude_layers_key = "ignore" - elif quant_method == "quark": - exclude_layers_key = "exclude" - else: - logger.warning( - f"Using 'ignore' as key for exclude layers with quant_method {quant_method}, \ - please double check the quantization config." + exclude_layers_key = "ignore" + exclude_layers = orig_quant_config.get(exclude_layers_key, []) + + self.global_quant_config = LayerQuantConfig( + quant_type=quant_type, quant_dtype=quant_dtype, is_dynamic=is_dynamic ) - exclude_layers_key = "ignore" - exclude_layers = orig_quant_config.get(exclude_layers_key, None) - return QuantizationConfig( - quant_type, - quant_dtype, - is_dynamic, - quant_method=quant_method, - exclude_layers=exclude_layers, - ) + # self.layer_quant_config = None + self.exclude_layers = exclude_layers + + def should_ignore_layer_quant(self, layer_name: str) -> bool: + # TODO: solve fused_mapping case + if layer_name is None or not self.exclude_layers: + return False + return any( + self.is_equal_or_regex_match(layer_name, ignore_str) + for ignore_str in self.exclude_layers + ) + + def is_equal_or_regex_match( + self, layer_name: str, ignore_str: str, check_contains: bool = False + ) -> bool: + """Match the target string or regular expression""" + if ignore_str.startswith("re:"): + pattern = 
ignore_str[3:] + if re.match(pattern, layer_name): + return True + elif check_contains: + if ignore_str.lower() in layer_name.lower(): + return True + elif ignore_str == layer_name: + return True + return False + + def get_layer_quant_config(self, layer_name: str) -> LayerQuantConfig: + if self.should_ignore_layer_quant(layer_name=layer_name): + # return unquantized config + return LayerQuantConfig(quant_dtype=self.torch_dtype) + # layer quant config + layer_quant_config = None + if self.layer_quant_config: + import fnmatch + + def _matches_pattern(layer_name, pattern): + if "*" not in pattern: + return layer_name in pattern + return fnmatch.fnmatch(layer_name, pattern) + + for name_pattern, config in self.layer_quant_config.items(): + if _matches_pattern(layer_name, name_pattern): + layer_quant_config = config + + layer_quant_config = ( + self.global_quant_config + if layer_quant_config is None + else layer_quant_config + ) + # TODO: if use_aiter, we can customize the quantization format here, such as dpsk + # For FP4 and use_triton_gemm(), fused_qkv_a_proj and q_b_proj are AITER-Triton FP4 GEMMs but o_proj remains AITER BF16 GEMMs, + # For FP8 and use_triton_gemm(), fused_qkv_a_proj is AITER-Triton FP8 GEMMs while others remain AITER FP8 GEMMs + + return layer_quant_config _CONFIG_REGISTRY: dict[str, str] = { @@ -590,9 +710,7 @@ class Config: port: int = 8006 torch_profiler_dir: str | None = os.getenv("ATOM_TORCH_PROFILER_DIR", None) compilation_config: CompilationConfig = field(default_factory=CompilationConfig) - quant_config: QuantizationConfig = field( - default_factory=lambda: QuantizationConfig() - ) + quant_config: QuantizationConfig = field(init=False) asyncio_mode: bool = False load_dummy: bool = False enable_expert_parallel: bool = False @@ -637,7 +755,7 @@ def __post_init__(self): eos_ids := getattr(self.generation_config, "eos_token_id", None) ) is not None: self.stop_token_ids = [eos_ids] if isinstance(eos_ids, int) else eos_ids - self.quant_config = 
get_quant_config(self.hf_config) + self.quant_config = QuantizationConfig(self.hf_config) hf_config_max_position_embeddings = getattr( self.hf_config, "max_position_embeddings", 8192 ) @@ -695,8 +813,9 @@ def compute_hash(self) -> str: # summarize vllm config vllm_factors: list[Any] = [] - if self.quant_config: - vllm_factors.append(self.quant_config.compute_hash()) + # TODO: fix here + # if self.quant_config: + # vllm_factors.append(self.quant_config.compute_hash()) if self.compilation_config: vllm_factors.append(self.compilation_config.compute_hash()) diff --git a/atom/model_ops/activation.py b/atom/model_ops/activation.py index db5d8033d..99d491163 100644 --- a/atom/model_ops/activation.py +++ b/atom/model_ops/activation.py @@ -6,7 +6,7 @@ from torch import nn import torch.nn.functional as F from aiter import silu_and_mul -from atom.config import QuantizationConfig +from atom.config import QuantizationConfig, LayerQuantConfig from aiter.jit.utils.torch_guard import torch_compile_guard from aiter import ( @@ -63,10 +63,12 @@ def __init__( super().__init__() self.fused_quant = fused_quant if quant_config is None: - quant_config = QuantizationConfig() + layer_quant_config = LayerQuantConfig() + else: + layer_quant_config = quant_config.global_quant_config - quant_type = quant_config["quant_type"] - params_dtype = quant_config["quant_dtype"] + quant_type = layer_quant_config["quant_type"] + params_dtype = layer_quant_config["quant_dtype"] self.quant_type = quant_type self.params_dtype = params_dtype diff --git a/atom/model_ops/attentions/aiter_attention.py b/atom/model_ops/attentions/aiter_attention.py index 68a8567d5..f91c73e27 100644 --- a/atom/model_ops/attentions/aiter_attention.py +++ b/atom/model_ops/attentions/aiter_attention.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
-import itertools from typing import Type import aiter diff --git a/atom/model_ops/fused_moe/config.py b/atom/model_ops/fused_moe/config.py index ff3a25269..5ab04b544 100644 --- a/atom/model_ops/fused_moe/config.py +++ b/atom/model_ops/fused_moe/config.py @@ -198,7 +198,9 @@ def make( if weight_dtype is None: weight_dtype = quant_dtype - a_shape, w_shape = _quant_flags_to_group_shape(quant_dtype, False, block_shape) + a_shape, w_shape = _quant_flags_to_group_shape( + quant_dtype, per_act_token_quant, block_shape + ) quant_config = FusedMoEQuantConfig( _a1=FusedMoEQuantDesc(quant_dtype, a_shape, a1_scale), _a2=FusedMoEQuantDesc(quant_dtype, a_shape, a2_scale), diff --git a/atom/model_ops/layernorm.py b/atom/model_ops/layernorm.py index ce9cfe1ec..a42811aad 100644 --- a/atom/model_ops/layernorm.py +++ b/atom/model_ops/layernorm.py @@ -8,7 +8,7 @@ has_torch_function_unary, handle_torch_function, ) -from atom.config import QuantizationConfig +from atom.config import QuantizationConfig, LayerQuantConfig from torch import nn from aiter import ( rmsnorm2d_fwd, @@ -188,9 +188,12 @@ def __init__( self.tp_size = get_tensor_model_parallel_world_size() if quant_config is None: - quant_config = QuantizationConfig() - quant_type = quant_config["quant_type"] - params_dtype = quant_config["quant_dtype"] + layer_quant_config = LayerQuantConfig() + else: + layer_quant_config = quant_config.global_quant_config + + quant_type = layer_quant_config["quant_type"] + params_dtype = layer_quant_config["quant_dtype"] self.quant_type = quant_type self.params_dtype = params_dtype diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index a3d7b4ef7..f5849cec1 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -17,9 +17,8 @@ ) from torch import nn -from atom.config import QuantizationConfig, get_current_atom_config +from atom.config import QuantizationConfig, get_current_atom_config, LayerQuantConfig from atom.model_ops.utils import normalize_e4m3fn_to_e4m3fnuz, 
requantize_with_max_scale -from atom.models.utils import get_quant_config_for_layer # import torch.distributed as dist from aiter.dist.parallel_state import get_tp_group @@ -200,12 +199,18 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = False, source_quant_dtype: torch.dtype | None = None, + prefix: str = "", ): - if quant_config is None: - quant_config = QuantizationConfig() + self.prefix = prefix + layer_quant_config = ( + quant_config.get_layer_quant_config(prefix) + if quant_config is not None + else LayerQuantConfig() + ) + quant_type = layer_quant_config["quant_type"] + params_dtype = layer_quant_config["quant_dtype"] self.source_quant_dtype = source_quant_dtype - quant_type = quant_config["quant_type"] - params_dtype = quant_config["quant_dtype"] + self.layer_quant_config = layer_quant_config super().__init__() self.reduce_results = reduce_results self.input_size = input_size @@ -259,7 +264,7 @@ def __init__( torch.empty(len(self.output_partition_sizes), 1, dtype=dtypes.fp32), requires_grad=False, ) - if not quant_config["is_dynamic"]: + if not layer_quant_config["is_dynamic"]: self.input_scale = nn.Parameter( torch.empty( len(self.output_partition_sizes), 1, dtype=dtypes.fp32 @@ -314,7 +319,13 @@ def weight_loader_process( and param.data.element_size() == loaded_weight.element_size() ): param.data = param.data.view(loaded_weight.dtype) - param.data.copy_(post_process_func(loaded_weight)) + loaded_weight = post_process_func(loaded_weight) + if ( + loaded_weight.shape != param.data.shape + and loaded_weight.numel() == param.data.numel() + ): + loaded_weight = loaded_weight.reshape(param.data.shape) + param.data.copy_(loaded_weight) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): param_data = param.data @@ -451,6 +462,7 @@ def __init__( bias: bool = False, quant_config: Optional[QuantizationConfig] = None, source_quant_dtype: torch.dtype = None, + prefix: str = "", **kwargs, ): 
super().__init__( @@ -460,6 +472,7 @@ def __init__( bias=bias, quant_config=quant_config, source_quant_dtype=source_quant_dtype, + prefix=prefix, ) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): @@ -475,6 +488,7 @@ def __init__( bias: bool = False, quant_config: Optional[QuantizationConfig] = None, source_quant_dtype: torch.dtype = None, + prefix: str = "", **kwargs, ): self.tp_dim = 0 @@ -485,6 +499,7 @@ def __init__( bias, quant_config=quant_config, source_quant_dtype=source_quant_dtype, + prefix=prefix, ) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): @@ -507,8 +522,6 @@ def __init__( **kwargs, ): self.output_sizes = output_sizes - if quant_config is not None and prefix: - quant_config = get_quant_config_for_layer(quant_config, prefix) super().__init__( input_size, output_sizes, @@ -516,6 +529,7 @@ def __init__( bias=bias, quant_config=quant_config, source_quant_dtype=source_quant_dtype, + prefix=prefix, ) def weight_loader( @@ -551,6 +565,7 @@ def __init__( bias: bool = False, quant_config: Optional[QuantizationConfig] = None, source_quant_dtype: torch.dtype = None, + prefix: str = "", **kwargs, ): self.head_k_dim = head_k_dim @@ -571,6 +586,7 @@ def __init__( bias=bias, quant_config=quant_config, source_quant_dtype=source_quant_dtype, + prefix=prefix, ) def weight_loader( @@ -619,6 +635,7 @@ def __init__( bias: bool = False, quant_config: Optional[QuantizationConfig] = None, source_quant_dtype: torch.dtype = None, + prefix: str = "", **kwargs, ): self.head_size = head_size @@ -650,6 +667,7 @@ def __init__( bias=bias, quant_config=quant_config, source_quant_dtype=source_quant_dtype, + prefix=prefix, ) def weight_loader( @@ -701,8 +719,6 @@ def __init__( **kwargs, ): self.tp_rank = get_tp_group().rank_in_group - if quant_config is not None and prefix: - quant_config = get_quant_config_for_layer(quant_config, prefix) super().__init__( input_size, output_size, @@ -711,14 +727,20 @@ def __init__( 
quant_config=quant_config, reduce_results=reduce_results, source_quant_dtype=source_quant_dtype, + prefix=prefix, ) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): param_data = param.data if param is not getattr(self, "bias", None): - shard_size = param_data.size(self.tp_dim) if len(loaded_weight.shape) == 0: loaded_weight = loaded_weight.view(1, 1) + if loaded_weight.ndim <= self.tp_dim: + # dims < tp_dim (1D per-channel scale with + # tp_dim=1) + param.weight_loader_process(param_data, loaded_weight) + return + shard_size = param_data.size(self.tp_dim) if loaded_weight.size(self.tp_dim) == 1 and self.tp_size > 1: loaded_weight = loaded_weight.repeat(1, self.tp_size) start_idx = self.tp_rank * shard_size @@ -737,6 +759,7 @@ def __init__( bias: bool = False, quant_config: Optional[QuantizationConfig] = None, source_quant_dtype: torch.dtype = None, + prefix: str = "", **kwargs, ): self.output_sizes = output_size @@ -746,6 +769,7 @@ def __init__( bias=bias, quant_config=quant_config, source_quant_dtype=source_quant_dtype, + prefix=prefix, ) def weight_loader( @@ -768,6 +792,10 @@ def weight_loader( elif self.quant_type == QuantType.per_Tensor: shard_offset = loaded_shard_id shard_size = 1 + else: + # Per-channel same layout as weights + shard_offset = sum(self.output_sizes[:loaded_shard_id]) + shard_size = self.output_sizes[loaded_shard_id] else: shard_offset = sum(self.output_sizes[:loaded_shard_id]) shard_size = self.output_sizes[loaded_shard_id] diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index b6820cdda..27ad9ddaa 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -14,8 +14,12 @@ from aiter.jit.utils.torch_guard import torch_compile_guard from aiter.ops.shuffle import shuffle_scale_a16w4, shuffle_weight_a16w4 from aiter.utility import fp4_utils -from atom.config import Config, QuantizationConfig, get_current_atom_config -from atom.models.utils import get_quant_config_for_layer +from atom.config import ( + 
Config, + QuantizationConfig, + get_current_atom_config, + LayerQuantConfig, +) from atom.model_loader.weight_utils import set_weight_attrs from atom.model_ops.base_config import QuantizeMethodBase from atom.model_ops.fused_moe.config import ( @@ -630,7 +634,7 @@ def rocm_aiter_fused_moe_fake( class Mxfp4MoEMethod(FusedMoEMethodBase): - def __init__(self, quant_config: QuantizationConfig, moe: FusedMoEConfig): + def __init__(self, quant_config: LayerQuantConfig, moe: FusedMoEConfig): super().__init__(moe) self.quant_config = quant_config self.quant_type = self.quant_config["quant_type"] @@ -964,7 +968,7 @@ def apply( # Refer to CompressedTensorsW8A8Fp8MoEMethod in vllm class CompressedTensorsFp8MoEMethod(FusedMoEMethodBase): - def __init__(self, quant_config: QuantizationConfig, moe: FusedMoEConfig): + def __init__(self, quant_config: LayerQuantConfig, moe: FusedMoEConfig): super().__init__(moe) self.quant_config = quant_config self.quant_type = quant_config["quant_type"] @@ -1358,8 +1362,10 @@ def apply( class Fp8MoEMethod(FusedMoEMethodBase): """MoE method for FP8. - Supports loading FP8 checkpoints with static weight scale and - dynamic/static activation scale. + Supports three quantization strategies: + - per_Tensor: per-tensor weight scale, static/dynamic activation scale + - per_Token (PTPTC): per-channel weight scale, dynamic per-token activation + - per_1x128 / per_1x32 (block): block-wise weight scale, dynamic activation Also supports loading quantized FP16/BF16 model checkpoints with dynamic activation scaling. The weight scaling factor will be initialized after @@ -1369,7 +1375,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): quant_config: The quantization config. 
""" - def __init__(self, quant_config: QuantizationConfig, moe: FusedMoEConfig): + def __init__(self, quant_config: LayerQuantConfig, moe: FusedMoEConfig): super().__init__(moe) self.quant_config = quant_config self.quant_type = self.quant_config["quant_type"] @@ -1378,6 +1384,7 @@ def __init__(self, quant_config: QuantizationConfig, moe: FusedMoEConfig): self.quant_type == QuantType.per_1x128 or self.quant_type == QuantType.per_1x32 ) + self.channel_quant = self.quant_type == QuantType.per_Token self.need_normalize_e4m3fn_to_e4m3fnuz = ( self.quant_dtype == torch.float8_e4m3fnuz ) @@ -1447,18 +1454,24 @@ def create_weights( set_weight_attrs(w2_weight, extra_weight_attrs) # WEIGHT_SCALES - if not self.block_quant: - # Allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. + if self.channel_quant: + # Per-channel (PTPTC): one scale per output channel per expert. + # w13: [E, 2*N], w2: [E, hidden_size] w13_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + dtype=torch.float32, + ), + requires_grad=False, ) w2_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False + torch.ones(num_experts, hidden_size, dtype=torch.float32), + requires_grad=False, ) layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) - else: + elif self.block_quant: w13_weight_scale = torch.nn.Parameter( torch.ones( num_experts, @@ -1480,21 +1493,26 @@ def create_weights( layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) assert self.quant_config["is_dynamic"] + else: + # Per-tensor + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + w2_weight_scale = 
torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) - # Add the quantization method used (per tensor/grouped/channel) - # to ensure the weight scales are loaded in properly - # extra_weight_attrs.update( - # {"quant_method": FusedMoeWeightScaleSupported.BLOCK. - # value} if self.block_quant else - # {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) - # If loading fp8 checkpoint, pass the weight loaders. - # If loading an fp16 checkpoint, do not (we will quantize in - # process_weights_after_loading() set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) # INPUT_SCALES - if not self.quant_config["is_dynamic"]: + # Per-channel uses dynamic per-token activation, no static input scales. + if self.channel_quant or self.quant_config["is_dynamic"]: + layer.w13_input_scale = None + layer.w2_input_scale = None + else: w13_input_scale = torch.nn.Parameter( torch.ones(num_experts, dtype=torch.float32), requires_grad=False ) @@ -1505,134 +1523,125 @@ def create_weights( ) layer.register_parameter("w2_input_scale", w2_input_scale) set_weight_attrs(w2_input_scale, extra_weight_attrs) - else: - layer.w13_input_scale = None - layer.w2_input_scale = None + + def _normalize_weights_and_scales(self, layer: nn.Module): + if not self.need_normalize_e4m3fn_to_e4m3fnuz: + return + w13_weight, w13_weight_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz( + layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale + ) + w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale + ) + layer.w13_weight = nn.Parameter(w13_weight, requires_grad=False) + layer.w13_weight_scale = nn.Parameter(w13_weight_scale, requires_grad=False) + layer.w2_weight = nn.Parameter(w2_weight, 
requires_grad=False) + layer.w2_weight_scale = nn.Parameter(w2_weight_scale, requires_grad=False) + if w13_input_scale is not None: + layer.w13_input_scale = nn.Parameter(w13_input_scale, requires_grad=False) + if w2_input_scale is not None: + layer.w2_input_scale = nn.Parameter(w2_input_scale, requires_grad=False) def process_weights_after_loading(self, layer: nn.Module) -> None: - # Lazy import to avoid importing triton too early. - # from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - # is_rocm_aiter_moe_enabled, shuffle_weights) + if self.block_quant: + self._process_block_quant(layer) + elif self.channel_quant: + self._process_channel_quant(layer) + else: + self._process_tensor_quant(layer) - # self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() - # self.rocm_aiter_use_asm = (self.rocm_aiter_moe_enabled - # and envs.VLLM_ROCM_USE_AITER_ASMMOE) + def _process_block_quant(self, layer: nn.Module) -> None: + assert self.quant_config["is_dynamic"] + self._normalize_weights_and_scales(layer) - # TODO (rob): refactor block quant into separate class. 
- if self.block_quant: - assert self.quant_config["is_dynamic"] - if self.need_normalize_e4m3fn_to_e4m3fnuz: - w13_weight, w13_weight_scale, w13_input_scale = ( - normalize_e4m3fn_to_e4m3fnuz( - layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale - ) - ) - w2_weight, w2_weight_scale, w2_input_scale = ( - normalize_e4m3fn_to_e4m3fnuz( - layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale - ) - ) - else: - w13_weight = layer.w13_weight.data - w13_weight_scale = layer.w13_weight_scale.data - w2_weight = layer.w2_weight - w2_weight_scale = layer.w2_weight_scale + if not self.need_normalize_e4m3fn_to_e4m3fnuz: + layer.w13_weight = nn.Parameter(layer.w13_weight.data, requires_grad=False) + layer.w13_weight_scale = nn.Parameter( + layer.w13_weight_scale.data, requires_grad=False + ) + layer.w2_weight = nn.Parameter(layer.w2_weight.data, requires_grad=False) + layer.w2_weight_scale = nn.Parameter( + layer.w2_weight_scale.data, requires_grad=False + ) - # torch.compile() cannot use Parameter subclasses. - layer.w13_weight = nn.Parameter(w13_weight, requires_grad=False) - layer.w13_weight_scale = nn.Parameter(w13_weight_scale, requires_grad=False) - layer.w2_weight = nn.Parameter(w2_weight, requires_grad=False) - layer.w2_weight_scale = nn.Parameter(w2_weight_scale, requires_grad=False) + shuffle_weights(layer.w13_weight, layer.w2_weight) - shuffle_weights(layer.w13_weight, layer.w2_weight) + def _process_channel_quant(self, layer: nn.Module) -> None: + """PTPTC""" + self._normalize_weights_and_scales(layer) - return - else: - # Fp8 moe kernels require a single activation scale. - # We take the max of all the scales in case they differ. - if not self.quant_config["is_dynamic"]: - if layer.w13_input_scale is None or layer.w2_input_scale is None: - raise ValueError( - "QuantConfig has static quantization, but found " - "activation scales are None." 
- ) - # if (not all_close_1d(layer.w13_input_scale) - # or not all_close_1d(layer.w2_input_scale)): - # print( - # "Found input_scales that are not equal for " - # "fp8 MoE layer. Using the maximum across experts " - # "for each layer.") - layer.w13_input_scale = torch.nn.Parameter( - layer.w13_input_scale.max(), requires_grad=False - ) - layer.w2_input_scale = torch.nn.Parameter( - layer.w2_input_scale.max(), requires_grad=False + if layer.w13_weight.data.dtype in (torch.bfloat16, torch.float16): + quant_func = get_hip_quant(QuantType.per_Token) + for expert_id in range(layer.local_num_experts): + w13_q, w13_s = quant_func( + layer.w13_weight.data[expert_id], quant_dtype=dtypes.fp8 ) - if self.need_normalize_e4m3fn_to_e4m3fnuz: - # Normalize the weights and scales - w13_weight, w13_weight_scale, w13_input_scale = ( - normalize_e4m3fn_to_e4m3fnuz( - layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale - ) + layer.w13_weight.data[expert_id] = w13_q + layer.w13_weight_scale.data[expert_id] = w13_s.squeeze(-1) + + w2_q, w2_s = quant_func( + layer.w2_weight.data[expert_id], quant_dtype=dtypes.fp8 ) - w2_weight, w2_weight_scale, w2_input_scale = ( - normalize_e4m3fn_to_e4m3fnuz( - layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale - ) + layer.w2_weight.data[expert_id] = w2_q + layer.w2_weight_scale.data[expert_id] = w2_s.squeeze(-1) + + shuffle_weights(layer.w13_weight, layer.w2_weight) + + def _process_tensor_quant(self, layer: nn.Module) -> None: + if not self.quant_config["is_dynamic"]: + if layer.w13_input_scale is None or layer.w2_input_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None." 
) - # Reset the parameter - layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) - layer.w13_weight_scale = torch.nn.Parameter( - w13_weight_scale, requires_grad=False + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False + ) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False + ) + + self._normalize_weights_and_scales(layer) + + assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.local_num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start : start + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id], ) - if w13_input_scale is not None: - layer.w13_input_scale = torch.nn.Parameter( - w13_input_scale, requires_grad=False - ) - layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) - layer.w2_weight_scale = torch.nn.Parameter( - w2_weight_scale, requires_grad=False + quant_func = get_hip_quant(self.quant_type) + layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( + quant_func(dq_weight, max_w13_scales[expert_id]) ) - if w2_input_scale is not None: - layer.w2_input_scale = torch.nn.Parameter( - w2_input_scale, requires_grad=False - ) - - # Fp8 moe kernel needs single weight scale for w13 per expert. - # We take the max then dequant and requant each expert. 
- assert layer.w13_weight_scale is not None - shard_size = layer.intermediate_size_per_partition - max_w13_scales = layer.w13_weight_scale.max(dim=1).values - for expert_id in range(layer.local_num_experts): - start = 0 - for shard_id in range(2): - dq_weight = per_tensor_dequantize( - layer.w13_weight[expert_id][start : start + shard_size, :], - layer.w13_weight_scale[expert_id][shard_id], - ) - quant_func = get_hip_quant(self.quant_type) - layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( - quant_func(dq_weight, max_w13_scales[expert_id]) - ) - start += shard_size + start += shard_size - shuffle_weights(layer.w13_weight, layer.w2_weight) + shuffle_weights(layer.w13_weight, layer.w2_weight) - layer.w13_weight_scale = torch.nn.Parameter( - max_w13_scales, requires_grad=False - ) - return + layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False) def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: - return fp8_w8a8_moe_quant_config( - w1_scale=(layer.w13_weight_scale), - w2_scale=(layer.w2_weight_scale), - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - block_shape=None, - ) + if self.channel_quant: + return fp8_w8a8_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + per_act_token_quant=True, + ) + else: + return fp8_w8a8_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + block_shape=None, + ) def apply( self, @@ -1669,7 +1678,8 @@ def apply( num_fused_shared_experts=layer.num_fused_shared_experts, routed_scaling_factor=layer.routed_scaling_factor, ) - # per_Tensor not support num_local_tokens so not use mori + # per_Tensor doesn't support num_local_tokens, so fallback to + # rocm_aiter_fused_moe when using per-tensor or no modular kernel. 
if self.quant_type == QuantType.per_Tensor or self.fused_experts is None: return torch.ops.aiter.rocm_aiter_fused_moe( x, @@ -1828,10 +1838,15 @@ def __init__( ): super().__init__() self.prefix = prefix + layer_quant_config = ( + quant_config.get_layer_quant_config(prefix) if quant_config else None + ) self.params_dtype = ( - quant_config["quant_dtype"] if quant_config else torch.get_default_dtype() + layer_quant_config["quant_dtype"] + if layer_quant_config + else torch.get_default_dtype() ) - self.quant_config = quant_config + self.layer_quant_config = layer_quant_config self.has_bias = has_bias # Note: here we guard against accessing the TP and DP groups when # uninitialized (this happens when testing) @@ -1946,28 +1961,28 @@ def __init__( ) self.moe_config = moe - if quant_config is not None and prefix: - quant_config = get_quant_config_for_layer(quant_config, prefix) - # Note: get_quant_method will look at the layer's local_num_experts # for heuristic purposes, so it must be initialized first. 
- quant_method_str = quant_config.get("quant_method", None) - if quant_config["quant_type"] == QuantType.No: + + quant_method_str = layer_quant_config.get("quant_method", None) + if layer_quant_config["quant_type"] == QuantType.No: self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod( moe ) elif ( quant_method_str == "compressed-tensors" - and quant_config["quant_dtype"] == dtypes.fp8 + and layer_quant_config["quant_dtype"] == dtypes.fp8 ): # Use CompressedTensorsFp8MoEMethod for compressed-tensors format - self.quant_method = CompressedTensorsFp8MoEMethod(quant_config, moe) - elif quant_config["quant_dtype"] == dtypes.fp8: - self.quant_method = Fp8MoEMethod(quant_config, moe) - elif quant_config["quant_dtype"] == dtypes.fp4x2: - self.quant_method = Mxfp4MoEMethod(quant_config, moe) + self.quant_method = CompressedTensorsFp8MoEMethod(layer_quant_config, moe) + elif layer_quant_config["quant_dtype"] == dtypes.fp8: + self.quant_method = Fp8MoEMethod(layer_quant_config, moe) + elif layer_quant_config["quant_dtype"] == dtypes.fp4x2: + self.quant_method = Mxfp4MoEMethod(layer_quant_config, moe) else: - raise ValueError(f"Unsupported quant dtype: {quant_config['quant_dtype']}") + raise ValueError( + f"Unsupported quant dtype: {layer_quant_config['quant_dtype']}" + ) assert self.quant_method is not None @@ -2251,7 +2266,7 @@ def weight_loader( shard_id: str = "", expert_id: int = 0, ) -> None: - if self.quant_config["quant_dtype"] == dtypes.fp4x2 and weight_name == "": + if self.layer_quant_config["quant_dtype"] == dtypes.fp4x2 and weight_name == "": self.mxf4_merged_weight_loader(param, loaded_weight) return @@ -2329,7 +2344,7 @@ def weight_loader( # FusedMoeWeightScaleSupported # TODO @dsikka: once hardened, refactor to use vLLM Parameters # specific to each case - quant_method = self.quant_config["quant_type"] + quant_method = self.layer_quant_config["quant_type"] if quant_method == QuantType.per_Token: self._load_per_channel_weight_scale( 
shard_id=shard_id,
diff --git a/atom/model_ops/topK.py b/atom/model_ops/topK.py
index 5103f9256..5f187383a 100644
--- a/atom/model_ops/topK.py
+++ b/atom/model_ops/topK.py
@@ -19,7 +19,7 @@ def is_rocm_aiter_fusion_shared_expert_enabled() -> bool:
     quant_config = config.quant_config
     is_shared_experts_excluded = False
     is_experts_excluded = False
-    exclude_layers = quant_config["exclude_layers"]
+    exclude_layers = quant_config.exclude_layers
     for layer in exclude_layers:
         if "shared_experts" in layer:
             is_shared_experts_excluded = True
diff --git a/atom/models/deepseek_mtp.py b/atom/models/deepseek_mtp.py
index 5394d49e5..4ea300835 100644
--- a/atom/models/deepseek_mtp.py
+++ b/atom/models/deepseek_mtp.py
@@ -56,7 +56,8 @@ def __init__(self, atom_config: Config, prefix: str, layer_idx: int) -> None:
         )
         quant_config = atom_config.quant_config
-        if quant_config["quant_dtype"] == dtypes.fp4x2:
+        layer_quant_config = quant_config.get_layer_quant_config(prefix)
+        if layer_quant_config["quant_dtype"] == dtypes.fp4x2:
             quant_config = QuantizationConfig()
         self.mtp_block = DeepseekV2DecoderLayer(
diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py
index f0342dce1..d084c5fd3 100644
--- a/atom/models/deepseek_v2.py
+++ b/atom/models/deepseek_v2.py
@@ -1247,7 +1247,10 @@ def __init__(
 
         # For FP4 and use_triton_gemm(), fused_qkv_a_proj and q_b_proj are AITER-Triton FP4 GEMMs but o_proj remains AITER BF16 GEMMs,
         # For FP8 and use_triton_gemm(), fused_qkv_a_proj is AITER-Triton FP8 GEMMs while others remain AITER FP8 GEMMs
-        if quant_config["quant_dtype"] == dtypes.fp4x2:
+        if (
+            quant_config.get_layer_quant_config(f"{prefix}.fused_qkv_a_proj")
+            ["quant_dtype"] == dtypes.fp4x2
+        ):
             # normally linear layers in attn share the same quant config
             if should_ignore_layer(quant_config, prefix):
                 source_quant_dtype = None
@@ -1276,6 +1279,7 @@ def __init__(
             bias=False,
             quant_config=quant_config,
             source_quant_dtype=source_quant_dtype,
+            prefix=f"{prefix}.fused_qkv_a_proj",
         )
         self.q_a_layernorm = 
RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
         self.q_b_proj = ColumnParallelLinear(
@@ -1407,10 +1411,16 @@ def __init__(
         self.quant_dtype = None
         self.fuse_qknorm_quant = False
         if quant_config is not None and ENABLE_DS_QKNORM_QUANT_FUSION:
-            if quant_config["quant_dtype"] == dtypes.fp8 or (
-                quant_config["quant_dtype"] == dtypes.fp4x2 and use_triton_gemm()
+            if quant_config.get_layer_quant_config(
+                f"{prefix}.fused_qkv_a_proj"
+            )["quant_dtype"] == dtypes.fp8 or (
+                quant_config.get_layer_quant_config(f"{prefix}.fused_qkv_a_proj")
+                ["quant_dtype"] == dtypes.fp4x2
+                and use_triton_gemm()
             ):
-                self.quant_dtype = quant_config["quant_dtype"]
+                self.quant_dtype = quant_config.get_layer_quant_config(
+                    f"{prefix}.fused_qkv_a_proj"
+                )["quant_dtype"]
                 self.fuse_qknorm_quant = True
 
     def forward(
@@ -1554,10 +1564,10 @@ def __init__(
         self.fuse_ar_input_norm = ENABLE_ALLREDUCE_RMSNORM_FUSION
         if quant_config is not None and ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION:
             if (
-                quant_config["quant_dtype"] == dtypes.fp8
-                or quant_config["quant_dtype"] == dtypes.fp4x2
+                quant_config.global_quant_config["quant_dtype"] == dtypes.fp8
+                or quant_config.global_quant_config["quant_dtype"] == dtypes.fp4x2
             ) and use_triton_gemm():
-                self.quant_dtype = quant_config["quant_dtype"]
+                self.quant_dtype = quant_config.global_quant_config["quant_dtype"]
                 self.fuse_input_norm_quant = True
             if self.fuse_ar_input_norm:
                 self.fuse_ar_input_norm = False
@@ -1604,7 +1614,9 @@ def __init__(
             fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION,
         )
         self.routed_scaling_factor = config.routed_scaling_factor
-        self.quant_dtype = quant_config["quant_dtype"] if quant_config else None
+        self.quant_dtype = (
+            quant_config.global_quant_config["quant_dtype"] if quant_config else None
+        )
         self.fuse_rmsnorm_quant = (
             ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION and self.quant_dtype is not None
         )
diff --git a/atom/models/llama.py b/atom/models/llama.py
index 4f485de69..cf70ca619 100644
--- a/atom/models/llama.py
+++ b/atom/models/llama.py
@@ -99,7 +99,7 @@ def 
__init__( self.act_fn = SiluAndMul( fused_quant=self.fused_act_quant, quant_config=quant_config ) - self.quant_type = quant_config["quant_type"] + self.quant_type = quant_config.global_quant_config["quant_type"] def forward(self, x, x_scale: Optional[torch.Tensor] = None): x = self.gate_up_proj(x, x_scale=x_scale) @@ -271,7 +271,7 @@ def __init__( ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_RMSNORM_QUANT ) - self.quant_type = quant_config["quant_type"] + self.quant_type = quant_config.global_quant_config["quant_type"] self.self_attn = LlamaAttention( config=config, diff --git a/atom/models/utils.py b/atom/models/utils.py index 60334d78c..58fece010 100644 --- a/atom/models/utils.py +++ b/atom/models/utils.py @@ -262,16 +262,6 @@ def should_ignore_layer( return False -def get_quant_config_for_layer( - quantization_config: Optional[QuantizationConfig], prefix: str -) -> Optional[QuantizationConfig]: - return ( - None - if should_ignore_layer(quantization_config, prefix) - else quantization_config - ) - - def extract_layer_index(layer_name: str, num_attn_module: int = 1) -> int: """ Extract the layer index from the module name.