From 88037523a4814b4f53b173803d196f9adc8363c8 Mon Sep 17 00:00:00 2001 From: llcnt Date: Tue, 31 Mar 2026 07:47:12 +0000 Subject: [PATCH 1/5] feat: extend moe model check to multimodal ones --- src/pruna/algorithms/moe_kernel_tuner.py | 5 +++++ src/pruna/engine/model_checks.py | 14 +++++++++++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/pruna/algorithms/moe_kernel_tuner.py b/src/pruna/algorithms/moe_kernel_tuner.py index 99929f81..d08b72d5 100644 --- a/src/pruna/algorithms/moe_kernel_tuner.py +++ b/src/pruna/algorithms/moe_kernel_tuner.py @@ -178,6 +178,11 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: model_config = getattr(model, "config", None) if model_config is None: raise ValueError(f"Model {model.__class__.__name__} has no config.") + # Multimodal MoE (e.g. Qwen3_5MoeForConditionalGeneration): MoE parameters live on text_config. + if getattr(model_config, "num_experts", None) is None: + text_cfg = getattr(model_config, "text_config", None) + if text_cfg is not None and getattr(text_cfg, "num_experts", None) is not None: + model_config = text_cfg tensor_parallel_size = int(smash_config["tensor_parallel_size"]) if model.__class__.__name__ == "HunyuanImage3ForCausalMM": diff --git a/src/pruna/engine/model_checks.py b/src/pruna/engine/model_checks.py index fa5fb763..ca7e7287 100644 --- a/src/pruna/engine/model_checks.py +++ b/src/pruna/engine/model_checks.py @@ -25,7 +25,7 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, ) from transformers.pipelines.automatic_speech_recognition import AutomaticSpeechRecognitionPipeline -from transformers.pipelines.text2text_generation import Text2TextGenerationPipeline +#from transformers.pipelines.text2text_generation import Text2TextGenerationPipeline from transformers.pipelines.text_generation import TextGenerationPipeline from pruna.engine.utils import ModelContext @@ -109,7 +109,9 @@ def is_moe_lm(model: Any) -> bool: """ Check if the model is a MoE LM. - Currently all MoE LMs are based on Mixtral in transformers. + Detects MoE via ``config.num_experts`` (e.g. Mixtral, Qwen-MoE text-only) + or via nested ``config.text_config.num_experts`` (e.g. multimodal + ``*ForConditionalGeneration`` wrappers). Parameters ---------- @@ -121,7 +123,13 @@ def is_moe_lm(model: Any) -> bool: bool True if the model is a MoE LM, False otherwise. """ - return hasattr(getattr(model, "config", None), "num_experts") + config = getattr(model, "config", None) + if config is None: + return False + if getattr(config, "num_experts", None) is not None: + return True + text_cfg = getattr(config, "text_config", None) + return text_cfg is not None and getattr(text_cfg, "num_experts", None) is not None def is_transformers_pipeline_with_causal_lm(model: Any) -> bool: From 9b0c6f1db5e10fb9d85191900d94b74df8ffc63a Mon Sep 17 00:00:00 2001 From: llcnt Date: Tue, 31 Mar 2026 07:59:10 +0000 Subject: [PATCH 2/5] fix: remove the commented out import (should be fixed with later transformers versions) --- src/pruna/engine/model_checks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pruna/engine/model_checks.py b/src/pruna/engine/model_checks.py index ca7e7287..2f15a4eb 100644 --- a/src/pruna/engine/model_checks.py +++ b/src/pruna/engine/model_checks.py @@ -25,7 +25,7 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, ) from transformers.pipelines.automatic_speech_recognition import AutomaticSpeechRecognitionPipeline -#from transformers.pipelines.text2text_generation import Text2TextGenerationPipeline +from transformers.pipelines.text2text_generation import Text2TextGenerationPipeline from transformers.pipelines.text_generation import TextGenerationPipeline from pruna.engine.utils import ModelContext From 4fd38f4e35c0f6897a71848548c5babf3490e865 Mon Sep 17 00:00:00 2001 From: llcnt Date: Tue, 31 Mar 2026 14:14:31 +0000 Subject: [PATCH 3/5] feat: add support for block quantization --- src/pruna/algorithms/moe_kernel_tuner.py | 59 +++++++++++++++++++++++- 1 file changed, 57 insertions(+), 2 deletions(-) diff --git a/src/pruna/algorithms/moe_kernel_tuner.py b/src/pruna/algorithms/moe_kernel_tuner.py index d08b72d5..3d3cfa7f 100644 --- a/src/pruna/algorithms/moe_kernel_tuner.py +++ b/src/pruna/algorithms/moe_kernel_tuner.py @@ -200,6 +200,13 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: use_fp8_w8a8 = smash_config["weight_dtype"] == "fp8_w8a8" use_int8_w8a16 = smash_config["weight_dtype"] == "int8_w8a16" + # (ii.b) Extract block quantization shape so the tuned config file + # name matches what vLLM looks for at serving time (e.g. + # ``E=256,N=512,...,block_shape=[128,128].json``). + block_quant_shape = _get_block_quant_shape(getattr(model, "config", None)) + if block_quant_shape is None: + block_quant_shape = _get_block_quant_shape(model_config) + # (iii) Tune the kernel over a range of batch sizes (single GPU per Ray worker). batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] ray = imported_packages["ray"] @@ -211,6 +218,16 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: ray.init(ignore_reinit_error=True) search_space = get_configs_compute_bound(smash_config) + + # Remove configs incompatible with block quantisation constraints + # (BLOCK_SIZE_K must be divisible by block_k, BLOCK_SIZE_N by block_n). + if block_quant_shape is not None and use_fp8_w8a8: + bq_n, bq_k = block_quant_shape[0], block_quant_shape[1] + search_space = [ + cfg for cfg in search_space + if cfg["BLOCK_SIZE_K"] % bq_k == 0 and cfg["BLOCK_SIZE_N"] % bq_n == 0 + ] + pruna_logger.info(f"Start tuning over {len(search_space)} configurations...") start = time.time() @@ -231,7 +248,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: use_fp8_w8a8, use_int8_w8a16, search_space, - None, + block_quant_shape, False, imported_packages, 0, # fixed seed for reproducibility @@ -271,7 +288,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: dtype, use_fp8_w8a8, use_int8_w8a16, - None, + block_quant_shape, smash_config["path_to_huggingface_hub_cache"], smash_config["path_to_vllm_cache"], imported_packages, @@ -345,6 +362,44 @@ def import_algorithm_packages(self) -> dict[str, Any]: ) +def _get_block_quant_shape(config: Any) -> list[int] | None: + """ + Extract block quantisation shape from a HuggingFace model config. + + Mirrors vLLM's ``get_weight_block_size_safety`` in + ``benchmarks/kernels/benchmark_moe.py``. Checks ``weight_block_size`` + (AWQ / GPTQ / FP8 style) and ``config_groups.*.weights.block_structure`` + (compressed-tensors style). + + Parameters + ---------- + config : Any + A HuggingFace ``PretrainedConfig`` (or its ``text_config``). + + Returns + ------- + list[int] | None + ``[block_n, block_k]`` if found, else ``None``. + """ + if config is None: + return None + quantization_config = getattr(config, "quantization_config", None) + if quantization_config is None: + return None + if isinstance(quantization_config, dict): + wbs = quantization_config.get("weight_block_size") + if wbs is not None: + return list(wbs) + config_groups = quantization_config.get("config_groups", {}) + for group_cfg in config_groups.values(): + if isinstance(group_cfg, dict): + weights = group_cfg.get("weights", {}) + bs = weights.get("block_structure") + if bs is not None: + return list(bs) + return None + + def extract_hunyuan_dimensions( model: Any, model_config: Any, From 21f015eaa425744a868838546171b76a8e191055 Mon Sep 17 00:00:00 2001 From: llcnt Date: Tue, 31 Mar 2026 14:39:28 +0000 Subject: [PATCH 4/5] feat: make block quant specified by the user --- src/pruna/algorithms/moe_kernel_tuner.py | 79 +++++++++--------------- 1 file changed, 28 insertions(+), 51 deletions(-) diff --git a/src/pruna/algorithms/moe_kernel_tuner.py b/src/pruna/algorithms/moe_kernel_tuner.py index 3d3cfa7f..aaf8f4a3 100644 --- a/src/pruna/algorithms/moe_kernel_tuner.py +++ b/src/pruna/algorithms/moe_kernel_tuner.py @@ -130,6 +130,18 @@ def get_hyperparameters(self) -> list: default_value=8, meta={"desc": "Maximum (log) block size for tiling through intermediate dimension."}, ), + OrdinalHyperparameter( + "block_quant_shape_n", + sequence=[32, 64, 128, 256, 512, 1024, 2048, 4096, None], + default_value=None, + meta={"desc": "Block size for quantization through input dimension."}, + ), + OrdinalHyperparameter( + "block_quant_shape_k", + sequence=[32, 64, 128, 256, 512, 1024, 2048, 4096, None], + default_value=None, + meta={"desc": "Block size for quantization through intermediate dimension."}, + ), ] def model_check_fn(self, model: Any) -> bool: @@ -199,13 +211,10 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float16 use_fp8_w8a8 = smash_config["weight_dtype"] == "fp8_w8a8" use_int8_w8a16 = smash_config["weight_dtype"] == "int8_w8a16" - - # (ii.b) Extract block quantization shape so the tuned config file - # name matches what vLLM looks for at serving time (e.g. - # ``E=256,N=512,...,block_shape=[128,128].json``). - block_quant_shape = _get_block_quant_shape(getattr(model, "config", None)) - if block_quant_shape is None: - block_quant_shape = _get_block_quant_shape(model_config) + block_quant_shape = [ + smash_config["block_quant_shape_n"], + smash_config["block_quant_shape_k"], + ] # (iii) Tune the kernel over a range of batch sizes (single GPU per Ray worker). batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] @@ -219,13 +228,19 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: ray.init(ignore_reinit_error=True) search_space = get_configs_compute_bound(smash_config) - # Remove configs incompatible with block quantisation constraints - # (BLOCK_SIZE_K must be divisible by block_k, BLOCK_SIZE_N by block_n). - if block_quant_shape is not None and use_fp8_w8a8: - bq_n, bq_k = block_quant_shape[0], block_quant_shape[1] + # Remove configs incompatible with block quantisation constraints: + # - BLOCK_SIZE_K must be divisible by block_quant_shape_k + # - BLOCK_SIZE_N must be divisible by block_quant_shape_n + if ( + smash_config["block_quant_shape_n"] is not None + and smash_config["block_quant_shape_k"] is not None + and use_fp8_w8a8 + ): search_space = [ - cfg for cfg in search_space - if cfg["BLOCK_SIZE_K"] % bq_k == 0 and cfg["BLOCK_SIZE_N"] % bq_n == 0 + cfg + for cfg in search_space + if cfg["BLOCK_SIZE_K"] % smash_config["block_quant_shape_k"] == 0 + and cfg["BLOCK_SIZE_N"] % smash_config["block_quant_shape_n"] == 0 ] pruna_logger.info(f"Start tuning over {len(search_space)} configurations...") @@ -362,44 +377,6 @@ def import_algorithm_packages(self) -> dict[str, Any]: ) -def _get_block_quant_shape(config: Any) -> list[int] | None: - """ - Extract block quantisation shape from a HuggingFace model config. - - Mirrors vLLM's ``get_weight_block_size_safety`` in - ``benchmarks/kernels/benchmark_moe.py``. Checks ``weight_block_size`` - (AWQ / GPTQ / FP8 style) and ``config_groups.*.weights.block_structure`` - (compressed-tensors style). - - Parameters - ---------- - config : Any - A HuggingFace ``PretrainedConfig`` (or its ``text_config``). - - Returns - ------- - list[int] | None - ``[block_n, block_k]`` if found, else ``None``. - """ - if config is None: - return None - quantization_config = getattr(config, "quantization_config", None) - if quantization_config is None: - return None - if isinstance(quantization_config, dict): - wbs = quantization_config.get("weight_block_size") - if wbs is not None: - return list(wbs) - config_groups = quantization_config.get("config_groups", {}) - for group_cfg in config_groups.values(): - if isinstance(group_cfg, dict): - weights = group_cfg.get("weights", {}) - bs = weights.get("block_structure") - if bs is not None: - return list(bs) - return None - - def extract_hunyuan_dimensions( model: Any, model_config: Any, From ef9e4858b68b0f1bf8dbf110a3d1402dec2a0ded Mon Sep 17 00:00:00 2001 From: llcnt Date: Tue, 31 Mar 2026 15:07:44 +0000 Subject: [PATCH 5/5] feat: make block quant specified by the user --- src/pruna/algorithms/moe_kernel_tuner.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/pruna/algorithms/moe_kernel_tuner.py b/src/pruna/algorithms/moe_kernel_tuner.py index aaf8f4a3..f4a9f44b 100644 --- a/src/pruna/algorithms/moe_kernel_tuner.py +++ b/src/pruna/algorithms/moe_kernel_tuner.py @@ -211,10 +211,15 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float16 use_fp8_w8a8 = smash_config["weight_dtype"] == "fp8_w8a8" use_int8_w8a16 = smash_config["weight_dtype"] == "int8_w8a16" - block_quant_shape = [ - smash_config["block_quant_shape_n"], - smash_config["block_quant_shape_k"], - ] + block_quant_shape = None + if ( + smash_config["block_quant_shape_n"] is not None + and smash_config["block_quant_shape_k"] is not None + ): + block_quant_shape = [ + smash_config["block_quant_shape_n"], + smash_config["block_quant_shape_k"], + ] # (iii) Tune the kernel over a range of batch sizes (single GPU per Ray worker). batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]