77 changes: 77 additions & 0 deletions atom/config.py
@@ -255,6 +255,7 @@ def __init__(
quant_name="",
quant_method=None,
exclude_layers: Optional[list[str]] = None,
quark_w4a8: bool = False,
):
super().__init__()
self["quant_type"] = quant_type if quant_type is not None else QuantType.No
@@ -263,6 +264,9 @@ def __init__(
self["is_dynamic"] = is_dynamic
self["quant_method"] = quant_method
self["exclude_layers"] = exclude_layers if exclude_layers is not None else []
# When True, the checkpoint stores INT4-packed expert weights that must
# be dequantised and re-quantised to FP8 during model loading.
self["quark_w4a8"] = quark_w4a8

def get_name(self):
return self["quant_name"]
@@ -290,6 +294,27 @@ def compute_hash(self) -> str:
return hashlib.sha256(str(factors).encode()).hexdigest()


def _is_quark_w4a8(orig_quant_config: dict) -> bool:
"""Detect Quark W4A8 checkpoint format.

The signature is ``quant_method == "quark"`` with a
``global_quant_config.weight`` list containing an INT4 entry, plus
FP8 ``input_tensors`` for the activations.
"""
if orig_quant_config.get("quant_method") != "quark":
return False
global_cfg = orig_quant_config.get("global_quant_config", {})
weight_specs = global_cfg.get("weight", [])
if not isinstance(weight_specs, list):
return False
has_int4_weight = any(
w.get("dtype", "").startswith("int4") for w in weight_specs if isinstance(w, dict)
)
input_cfg = global_cfg.get("input_tensors") or {}
has_fp8_act = isinstance(input_cfg, dict) and "fp8" in input_cfg.get("dtype", "")
return has_int4_weight and has_fp8_act
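
A quick illustration (not part of the patch) of the shape of quantization_config this detector is meant to match; the field values below are assumptions inferred from the checks above, not taken from a real Quark checkpoint:

example_cfg = {
    "quant_method": "quark",
    "global_quant_config": {
        "weight": [{"dtype": "int4", "group_size": 128}],
        "input_tensors": {"dtype": "fp8_e4m3", "is_dynamic": True},
    },
}
assert _is_quark_w4a8(example_cfg)                  # INT4 weights + FP8 activations
assert not _is_quark_w4a8({"quant_method": "fp8"})  # any other quant_method is ignored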


def get_quant_config(config: PretrainedConfig) -> QuantizationConfig:
torch_dtype = getattr(config, "torch_dtype", "bf16")
orig_quant_config = getattr(config, "quantization_config", None)
@@ -300,6 +325,25 @@ def get_quant_config(config: PretrainedConfig) -> QuantizationConfig:
)

quant_method = orig_quant_config.get("quant_method", None)

# ------------------------------------------------------------------
# Quark W4A8: INT4 weights + FP8 activations. At runtime the INT4
# expert weights are dequantised and re-quantised to FP8, so we
# advertise FP8 per-tensor as the runtime format and set a flag so
# that the MoE layer knows to do the conversion.
# ------------------------------------------------------------------
if _is_quark_w4a8(orig_quant_config):
logger.info("Detected Quark W4A8 checkpoint – will convert INT4 experts to FP8")
exclude_layers = orig_quant_config.get("exclude", None)
return QuantizationConfig(
quant_type=QuantType.per_Tensor,
quant_dtype=d_dtypes["fp8"],
is_dynamic=True,
quant_method=quant_method,
exclude_layers=exclude_layers,
quark_w4a8=True,
)
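
For reference, a minimal sketch of the dequantise-then-requantise step the comment above describes. The packing layout (two signed nibbles per byte, low nibble first), per-group weight scales, and a single per-tensor FP8 scale are assumptions; the actual MoE loading code may differ:

import torch

def int4_packed_to_fp8(packed: torch.Tensor, scales: torch.Tensor, group_size: int = 128):
    # Unpack two signed 4-bit values from each uint8 byte (assumed low-nibble-first layout).
    low = (packed & 0x0F).to(torch.int8)
    high = (packed >> 4).to(torch.int8)
    low = torch.where(low >= 8, low - 16, low)      # sign-extend the 4-bit values
    high = torch.where(high >= 8, high - 16, high)
    int4 = torch.stack((low, high), dim=-1).reshape(*packed.shape[:-1], -1)
    # Dequantise with per-group scales, then re-quantise to a single FP8 tensor scale.
    deq = int4.to(torch.float32).reshape(*int4.shape[:-1], -1, group_size) * scales.unsqueeze(-1)
    deq = deq.reshape(*int4.shape)
    fp8_scale = deq.abs().amax() / torch.finfo(torch.float8_e4m3fn).max
    return (deq / fp8_scale).to(torch.float8_e4m3fn), fp8_scale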

RE_QUANT_BLOCKSIZE = r"\'(?:group_size|weight_block_size)\'\:\s*(?:\[\n*)\s*(\d+),"
orig_quant_config_str = str(orig_quant_config)
if quant_method == "compressed-tensors" or "channel'," in orig_quant_config_str:
@@ -371,6 +415,13 @@ def get_quant_config(config: PretrainedConfig) -> QuantizationConfig:

_CONFIG_REGISTRY: dict[str, str] = {
"deepseek_v32": "deepseek_v3",
"kimi_k2": "deepseek_v3",
}


_MULTIMODAL_MODEL_TYPES: dict[str, str] = {
# Maps multimodal model_type -> key in config_dict for the text sub-config
"kimi_k25": "text_config",
}


@@ -386,6 +437,32 @@ def _get_hf_token() -> str | None:
return token
return None

# For multimodal models, extract the text sub-config so the rest of ATOM
# (which is text-only today) works transparently.
if model_type in _MULTIMODAL_MODEL_TYPES:
text_config_key = _MULTIMODAL_MODEL_TYPES[model_type]
text_config_dict = config_dict.get(text_config_key, {}).copy()
# Remove auto_map to avoid trust_remote_code issues
text_config_dict.pop("auto_map", None)
# Propagate quantization_config from root level into text config
# (quantization_config lives alongside text_config, not inside it).
if "quantization_config" not in text_config_dict and "quantization_config" in config_dict:
text_config_dict["quantization_config"] = config_dict["quantization_config"]
text_model_type = text_config_dict.get("model_type", "deepseek_v3")
mapped_type = _CONFIG_REGISTRY.get(text_model_type, text_model_type)
config_class = AutoConfig.for_model(mapped_type)
hf_config = config_class.from_dict(text_config_dict)
# Override architectures so that ATOM selects the correct model class
# which can handle the multimodal weight prefix during loading.
original_arch = config_dict.get("architectures", [])
if original_arch:
hf_config.architectures = original_arch
# Propagate top-level token IDs if missing in text config
for field in ("bos_token_id", "eos_token_id", "pad_token_id"):
if getattr(hf_config, field, None) is None and field in config_dict:
setattr(hf_config, field, config_dict[field])
return hf_config
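
Sketch of what the branch above does with a multimodal checkpoint; the field names and values below are illustrative, not copied from a real Kimi K2.5 config:

config_dict = {
    "model_type": "kimi_k25",
    "architectures": ["KimiK25ForConditionalGeneration"],
    "quantization_config": {"quant_method": "quark"},
    "text_config": {"model_type": "kimi_k2", "kv_lora_rank": 512},
    "vision_config": {"hidden_size": 1024},
    "eos_token_id": 2,
}
# The text sub-config is extracted and, since "kimi_k2" maps to "deepseek_v3" in
# _CONFIG_REGISTRY, instantiated via AutoConfig.for_model("deepseek_v3"). The
# root-level quantization_config, architectures, and eos_token_id are carried
# over, while vision_config is simply not used for text-only inference.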

if model_type in _CONFIG_REGISTRY:
config_class = AutoConfig.for_model(_CONFIG_REGISTRY[model_type])
return config_class.from_pretrained(
2 changes: 2 additions & 0 deletions atom/model_engine/model_runner.py
@@ -54,6 +54,7 @@
"DeepseekV32ForCausalLM": "atom.models.deepseek_v2.DeepseekV2ForCausalLM",
"GptOssForCausalLM": "atom.models.gpt_oss.GptOssForCausalLM",
"Glm4MoeForCausalLM": "atom.models.glm4_moe.Glm4MoeForCausalLM",
"KimiK25ForConditionalGeneration": "atom.models.kimi_k25.KimiK25ForCausalLM",
}
# seed = 34567
# np.random.seed(seed)
@@ -628,6 +629,7 @@ def is_deepseek_mla(self) -> bool:
"deepseek_v3",
"deepseek_v32",
"deepseek_mtp",
"kimi_k2",
):
return self.hf_text_config.kv_lora_rank is not None
elif self.hf_text_config.model_type == "eagle":
8 changes: 8 additions & 0 deletions atom/model_loader/loader.py
@@ -89,6 +89,7 @@ def load_model(
):
packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
weights_mapping = getattr(model, "weights_mapping", {})
skip_weight_prefixes = getattr(model, "skip_weight_prefixes", [])
params_dict = dict(model.named_parameters())
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = []
@@ -100,6 +101,13 @@
continue
if name.endswith("kv_scale"):
continue
# Skip weights matching model-defined prefixes (e.g. vision encoder
# weights in multimodal checkpoints that are not needed for text-only
# inference).
if skip_weight_prefixes and any(
name.startswith(p) for p in skip_weight_prefixes
):
continue
if spec_decode:
spec_layer = get_spec_layer_idx_from_weight_name(hf_config, name)
if spec_layer is None:
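
To round this off, a sketch of how a model class would opt into the skip_weight_prefixes check added in loader.py above; the attribute name comes from this change, but the class body and prefix values are assumptions about how a multimodal checkpoint might name its vision weights:

import torch.nn as nn

class KimiK25ForCausalLM(nn.Module):
    # Checkpoint tensors whose names start with these prefixes are skipped by
    # load_model, e.g. vision-encoder weights that text-only inference never touches.
    skip_weight_prefixes = ["vision_tower.", "multi_modal_projector."]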