Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
325 changes: 222 additions & 103 deletions atom/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
import re
from dataclasses import dataclass, field
from typing import Any, Optional, Union
from typing import Any, cast, Optional, Union

import torch
from aiter import QuantType
Expand Down Expand Up @@ -250,126 +250,246 @@ def set_splitting_ops_for_v1(self):
]


class QuantizationConfig(dict):
class LayerQuantConfig(dict):
    """Per-layer quantization settings, stored dict-style.

    Keys: ``quant_type``, ``quant_dtype``, ``is_dynamic``, ``quant_method``,
    ``exclude_layers``, ``quant_name``. Explicit ``None`` for the first two
    falls back to ``QuantType.No`` / ``torch.bfloat16``.
    """

    def __init__(
        self,
        quant_type=QuantType.No,
        quant_dtype=torch.bfloat16,
        is_dynamic=True,
        quant_method=None,
        exclude_layers: Optional[list[str]] = None,
        quant_name="",
    ):
        """
        Core components of layer_quant
        """
        super().__init__()
        # Guard against explicit None so lookups always see usable values.
        self["quant_type"] = quant_type if quant_type is not None else QuantType.No
        self["quant_dtype"] = quant_dtype if quant_dtype is not None else torch.bfloat16
        self["is_dynamic"] = is_dynamic
        self["quant_method"] = quant_method
        self["exclude_layers"] = exclude_layers if exclude_layers is not None else []
        self["quant_name"] = quant_name

    def get_name(self):
        """Return the configured quant_name."""
        return self["quant_name"]

def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
class QuantizationConfig:
    """Model-level quantization settings parsed from a HF PretrainedConfig.

    Holds one global LayerQuantConfig, optional per-layer overrides
    (quark only), and a list of layer names/patterns excluded from
    quantization.
    """

    def __init__(self, config: PretrainedConfig = None):
        """Build the config; ``config=None`` yields an unquantized default."""
        if config is None:
            # Defaults for callers without a HF config: nothing quantized.
            self.torch_dtype = torch.bfloat16
            self.hf_quant_config = None
            self.global_quant_config = LayerQuantConfig()
            self.layer_quant_config = {}
            self.exclude_layers = []
            self.quant_method = ""
            return

        # NOTE(review): the fallback here is the string "bf16" while the
        # config-is-None branch uses torch.bfloat16 — confirm downstream
        # consumers accept both representations.
        self.torch_dtype = getattr(config, "torch_dtype", "bf16")
        self.hf_quant_config = getattr(config, "quantization_config", None)
        self.global_quant_config = None
        self.layer_quant_config = {}
        self.exclude_layers = []

        if self.hf_quant_config is None:
            # No quantization_config on the model: run at model dtype.
            self.global_quant_config = LayerQuantConfig(
                quant_type=QuantType.No, quant_dtype=self.torch_dtype
            )
            self.quant_method = None
            return

        self.quant_method = self.hf_quant_config.get("quant_method", "")
        if self.quant_method == "quark":
            # Quark ships structured per-layer and global sub-configs.
            layer_quant_config_dict = cast(
                dict[str, Any], self.hf_quant_config.get("layer_quant_config")
            )
            for layer_name, layer_cfg in layer_quant_config_dict.items():
                self.layer_quant_config[layer_name] = self.parse_quark_config_dict(
                    layer_cfg
                )

            global_quant_config_dict = cast(
                dict[str, Any], self.hf_quant_config.get("global_quant_config")
            )
            self.global_quant_config = self.parse_quark_config_dict(
                global_quant_config_dict
            )
            # NOTE(review): .get("exclude") may be None if the key is absent;
            # should_ignore_layer_quant treats falsy as "no exclusions".
            self.exclude_layers = cast(list[str], self.hf_quant_config.get("exclude"))
        else:
            # All other quant_methods go through the heuristic string parser.
            self.parse_other_config()

quant_method = orig_quant_config.get("quant_method", None)
RE_QUANT_BLOCKSIZE = r"\'(?:group_size|weight_block_size)\'\:\s*(?:\[\n*)\s*(\d+),"
orig_quant_config_str = str(orig_quant_config)
if quant_method == "compressed-tensors" or "channel'," in orig_quant_config_str:
quant_type = QuantType.per_Token
elif group_size := re.search(RE_QUANT_BLOCKSIZE, orig_quant_config_str):
group_size = int(group_size.group(1))
assert group_size in (32, 128), f"Unsupported group size {group_size}"
if group_size == 128:
quant_type = QuantType.per_1x128
elif group_size == 32:
def get_name(self):
    """Return the quant_method name (accessor kept from the original
    quant_config function)."""
    return self.quant_method

def parse_quark_config_dict(self, config: dict) -> LayerQuantConfig:
    """Translate one quark-style quant config dict into a LayerQuantConfig.

    Reads the "weight" sub-dict for scheme/dtype and the optional
    "input_tensors" sub-dict for the dynamic-quantization flag.
    """
    quant_type = None
    quant_dtype = None
    is_dynamic = False
    # parse quark config dict
    weight_config = cast(dict[str, Any], config.get("weight"))
    input_config = cast(dict[str, Any], config.get("input_tensors"))
    weight_qscheme = cast(str, weight_config.get("qscheme"))
    weight_dtype = weight_config.get("dtype")

    # quant_type
    if weight_qscheme == "per_channel":
        quant_type = QuantType.per_Token
    elif weight_qscheme == "per_tensor":
        quant_type = QuantType.per_Tensor
    elif weight_qscheme == "per_group":
        # Currently, quark only supports group_size=32
        quant_type = QuantType.per_1x32
    else:
        # Unknown scheme: fall back to per-tensor.
        quant_type = QuantType.per_Tensor

    # quant_dtype: strip any suffix (e.g. "fp8_e4m3" -> "fp8"); 4-bit
    # dtypes are packed two-per-byte, hence the "x2" key.
    dtype = weight_dtype.split("_")[0]
    if dtype.endswith("4"):
        dtype += "x2"
    quant_dtype = d_dtypes[dtype]

    # is_dynamic
    if input_config is not None:
        # NOTE(review): this inverts the config's "is_dynamic" flag — a quark
        # config declaring is_dynamic=True produces is_dynamic=False here.
        # Confirm the inversion is intentional.
        is_dynamic = not cast(bool, input_config.get("is_dynamic"))
    return LayerQuantConfig(
        quant_type=quant_type,
        quant_dtype=quant_dtype,
        is_dynamic=is_dynamic,
        quant_method="quark",
    )

# TODO: For now, it's just a temporary migration.
# We should subsequently refine them in a targeted manner.
def parse_other_config(self):
    """Heuristically parse a non-quark HF quantization_config.

    The config dict is stringified and regex-scanned (a migration of the
    original get_quant_config() logic). Populates
    ``self.global_quant_config`` and ``self.exclude_layers``.
    """
    RE_QUANT_BLOCKSIZE = (
        r"\'(?:group_size|weight_block_size)\'\:\s*(?:\[\n*)\s*(\d+),"
    )
    orig_quant_config = self.hf_quant_config
    quant_method = self.quant_method
    orig_quant_config_str = str(orig_quant_config)
    # quant_type: compressed-tensors / per-channel configs -> per-token;
    # otherwise a group/block size of 128 or 32 selects a blocked type.
    if quant_method == "compressed-tensors" or "channel'," in orig_quant_config_str:
        quant_type = QuantType.per_Token
    elif group_size := re.search(RE_QUANT_BLOCKSIZE, orig_quant_config_str):
        group_size = int(group_size.group(1))
        assert group_size in (32, 128), f"Unsupported group size {group_size}"
        if group_size == 128:
            quant_type = QuantType.per_1x128
        elif group_size == 32:
            quant_type = QuantType.per_1x32
    else:
        quant_type = QuantType.per_Tensor

    # quant_dtype: prefer an explicit dtype-like field; 4-bit dtypes are
    # packed two-per-byte ("x2").
    RE_QUANT_DTYPE = r"\'(?:d?type|weight_dtype|quant_method)\'\:\s*\'(\w+)\'"
    quant_dtype = None
    m = re.search(RE_QUANT_DTYPE, orig_quant_config_str)
    if m and m.group(1).lower() in [
        "fp8",
        "fp4",
        "int8",
        "int4",
        "fp8_e4m3",
        "mxfp4",
    ]:
        dtype = m.group(1).lower().split("_")[0]
        if dtype == "mxfp4":
            dtype = "fp4"
        if dtype.endswith("4"):
            dtype += "x2"
        quant_dtype = d_dtypes[dtype]
    else:
        # Fall back to a bit-width field plus an optional dtype prefix.
        bit_match = re.search(r"\'(?:num_)?bits\'\:\s*(\d+)", orig_quant_config_str)
        if bit_match:
            bit = int(bit_match.group(1))
            dtype_match = re.search(RE_QUANT_DTYPE, orig_quant_config_str)
            if dtype_match:
                dtype = dtype_match.group(1).lower()
                dtype_prefix = "i" if dtype.startswith("int") else "fp"
            else:
                dtype_prefix = "i"
            quant_dtype_str = (
                f"{dtype_prefix}{bit}" if bit != 4 else f"{dtype_prefix}{bit}x2"
            )
            quant_dtype = d_dtypes.get(quant_dtype_str, None)
    assert (
        quant_dtype is not None
    ), f"Cannot parse quant dtype from {orig_quant_config_str}"
    if quant_dtype == d_dtypes["fp4x2"]:
        quant_type = QuantType.per_1x32

    # is_dynamic: only an explicit 'activation_scheme': 'static' disables it.
    RE_STATIC_QUANT = r"\'(?:activation_scheme)\'\:\s*\'(static)\'"
    if re.search(RE_STATIC_QUANT, orig_quant_config_str):
        is_dynamic = False
    else:
        is_dynamic = True
    # Each quant_method names its excluded-layer list differently.
    if quant_method == "compressed-tensors":
        exclude_layers_key = "ignore"
    elif quant_method == "quark":
        exclude_layers_key = "exclude"
    else:
        logger.warning(
            f"Using 'ignore' as key for exclude layers with quant_method {quant_method}, \
please double check the quantization config."
        )
        exclude_layers_key = "ignore"
    exclude_layers = orig_quant_config.get(exclude_layers_key, [])

    # NOTE(review): unlike the quark path, quant_method is not recorded on
    # the resulting LayerQuantConfig — confirm that is intentional.
    self.global_quant_config = LayerQuantConfig(
        quant_type=quant_type, quant_dtype=quant_dtype, is_dynamic=is_dynamic
    )
    self.exclude_layers = exclude_layers

def should_ignore_layer_quant(self, layer_name: str) -> bool:
    """Return True when ``layer_name`` matches any exclude_layers entry."""
    # TODO: solve fused_mapping case
    if layer_name is None:
        return False
    if not self.exclude_layers:
        return False
    for excluded in self.exclude_layers:
        if self.is_equal_or_regex_match(layer_name, excluded):
            return True
    return False

def is_equal_or_regex_match(
    self, layer_name: str, ignore_str: str, check_contains: bool = False
) -> bool:
    """Match ``layer_name`` against ``ignore_str``.

    ``ignore_str`` is either an exact name, a regex when prefixed with
    "re:", or — when ``check_contains`` is set — a case-insensitive
    substring to look for inside ``layer_name``.
    """
    if ignore_str.startswith("re:"):
        # Strip the "re:" marker and match from the start of the name.
        return re.match(ignore_str[3:], layer_name) is not None
    if check_contains:
        return ignore_str.lower() in layer_name.lower()
    return ignore_str == layer_name

def get_layer_quant_config(self, layer_name: str) -> LayerQuantConfig:
    """Resolve the quant config for one layer.

    Excluded layers get an unquantized config at the model dtype;
    otherwise the last matching per-layer override wins, falling back to
    the global config.
    """
    if self.should_ignore_layer_quant(layer_name=layer_name):
        # return unquantized config
        return LayerQuantConfig(quant_dtype=self.torch_dtype)

    selected = None
    if self.layer_quant_config:
        import fnmatch

        def _pattern_hit(name, pattern):
            if "*" in pattern:
                return fnmatch.fnmatch(name, pattern)
            # NOTE(review): substring containment ("name in pattern"), not
            # equality — a short layer_name can match inside a longer
            # literal pattern. Confirm this is intended.
            return name in pattern

        for pattern, override in self.layer_quant_config.items():
            if _pattern_hit(layer_name, pattern):
                selected = override  # last match wins, as before

    if selected is None:
        selected = self.global_quant_config
    # TODO: if use_aiter, we can customize the quantization format here, such as dpsk
    # For FP4 and use_triton_gemm(), fused_qkv_a_proj and q_b_proj are AITER-Triton FP4 GEMMs but o_proj remains AITER BF16 GEMMs,
    # For FP8 and use_triton_gemm(), fused_qkv_a_proj is AITER-Triton FP8 GEMMs while others remain AITER FP8 GEMMs

    return selected


_CONFIG_REGISTRY: dict[str, str] = {
Expand Down Expand Up @@ -590,9 +710,7 @@ class Config:
port: int = 8006
torch_profiler_dir: str | None = os.getenv("ATOM_TORCH_PROFILER_DIR", None)
compilation_config: CompilationConfig = field(default_factory=CompilationConfig)
quant_config: QuantizationConfig = field(
default_factory=lambda: QuantizationConfig()
)
quant_config: QuantizationConfig = field(init=False)
asyncio_mode: bool = False
load_dummy: bool = False
enable_expert_parallel: bool = False
Expand Down Expand Up @@ -637,7 +755,7 @@ def __post_init__(self):
eos_ids := getattr(self.generation_config, "eos_token_id", None)
) is not None:
self.stop_token_ids = [eos_ids] if isinstance(eos_ids, int) else eos_ids
self.quant_config = get_quant_config(self.hf_config)
self.quant_config = QuantizationConfig(self.hf_config)
hf_config_max_position_embeddings = getattr(
self.hf_config, "max_position_embeddings", 8192
)
Expand Down Expand Up @@ -695,8 +813,9 @@ def compute_hash(self) -> str:

# summarize vllm config
vllm_factors: list[Any] = []
if self.quant_config:
vllm_factors.append(self.quant_config.compute_hash())
# TODO: fix here
# if self.quant_config:
# vllm_factors.append(self.quant_config.compute_hash())

if self.compilation_config:
vllm_factors.append(self.compilation_config.compute_hash())
Expand Down
Loading
Loading