diff --git a/src/pruna/algorithms/moe_kernel_tuner.py b/src/pruna/algorithms/moe_kernel_tuner.py
index 99929f81..f4a9f44b 100644
--- a/src/pruna/algorithms/moe_kernel_tuner.py
+++ b/src/pruna/algorithms/moe_kernel_tuner.py
@@ -130,6 +130,18 @@ def get_hyperparameters(self) -> list:
                 default_value=8,
                 meta={"desc": "Maximum (log) block size for tiling through intermediate dimension."},
             ),
+            OrdinalHyperparameter(
+                "block_quant_shape_n",
+                sequence=[32, 64, 128, 256, 512, 1024, 2048, 4096, None],
+                default_value=None,
+                meta={"desc": "Block size for quantization through input dimension."},
+            ),
+            OrdinalHyperparameter(
+                "block_quant_shape_k",
+                sequence=[32, 64, 128, 256, 512, 1024, 2048, 4096, None],
+                default_value=None,
+                meta={"desc": "Block size for quantization through intermediate dimension."},
+            ),
         ]
 
     def model_check_fn(self, model: Any) -> bool:
@@ -178,6 +190,11 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
         model_config = getattr(model, "config", None)
         if model_config is None:
             raise ValueError(f"Model {model.__class__.__name__} has no config.")
+        # Multimodal MoE (e.g. Qwen3_5MoeForConditionalGeneration): MoE parameters live on text_config.
+        if getattr(model_config, "num_experts", None) is None:
+            text_cfg = getattr(model_config, "text_config", None)
+            if text_cfg is not None and getattr(text_cfg, "num_experts", None) is not None:
+                model_config = text_cfg
         tensor_parallel_size = int(smash_config["tensor_parallel_size"])
 
         if model.__class__.__name__ == "HunyuanImage3ForCausalMM":
@@ -194,6 +211,15 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
         dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float16
         use_fp8_w8a8 = smash_config["weight_dtype"] == "fp8_w8a8"
         use_int8_w8a16 = smash_config["weight_dtype"] == "int8_w8a16"
+        block_quant_shape = None
+        if (
+            smash_config["block_quant_shape_n"] is not None
+            and smash_config["block_quant_shape_k"] is not None
+        ):
+            block_quant_shape = [
+                smash_config["block_quant_shape_n"],
+                smash_config["block_quant_shape_k"],
+            ]
 
         # (iii) Tune the kernel over a range of batch sizes (single GPU per Ray worker).
         batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
@@ -206,6 +232,22 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
         ray.init(ignore_reinit_error=True)
 
         search_space = get_configs_compute_bound(smash_config)
+
+        # Remove configs incompatible with block quantization constraints:
+        # - BLOCK_SIZE_K must be divisible by block_quant_shape_k
+        # - BLOCK_SIZE_N must be divisible by block_quant_shape_n
+        if (
+            smash_config["block_quant_shape_n"] is not None
+            and smash_config["block_quant_shape_k"] is not None
+            and use_fp8_w8a8
+        ):
+            search_space = [
+                cfg
+                for cfg in search_space
+                if cfg["BLOCK_SIZE_K"] % smash_config["block_quant_shape_k"] == 0
+                and cfg["BLOCK_SIZE_N"] % smash_config["block_quant_shape_n"] == 0
+            ]
+
         pruna_logger.info(f"Start tuning over {len(search_space)} configurations...")
 
         start = time.time()
@@ -226,7 +268,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
                 use_fp8_w8a8,
                 use_int8_w8a16,
                 search_space,
-                None,
+                block_quant_shape,
                 False,
                 imported_packages,
                 0,  # fixed seed for reproducibility
@@ -266,7 +308,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
             dtype,
             use_fp8_w8a8,
             use_int8_w8a16,
-            None,
+            block_quant_shape,
             smash_config["path_to_huggingface_hub_cache"],
             smash_config["path_to_vllm_cache"],
             imported_packages,
diff --git a/src/pruna/engine/model_checks.py b/src/pruna/engine/model_checks.py
index fa5fb763..2f15a4eb 100644
--- a/src/pruna/engine/model_checks.py
+++ b/src/pruna/engine/model_checks.py
@@ -109,7 +109,9 @@ def is_moe_lm(model: Any) -> bool:
     """
     Check if the model is a MoE LM.
 
-    Currently all MoE LMs are based on Mixtral in transformers.
+    Detects MoE via ``config.num_experts`` (e.g. Mixtral, Qwen-MoE text-only)
+    or via nested ``config.text_config.num_experts`` (e.g. multimodal
+    ``*ForConditionalGeneration`` wrappers).
 
     Parameters
     ----------
@@ -121,7 +123,13 @@ def is_moe_lm(model: Any) -> bool:
     bool
         True if the model is a MoE LM, False otherwise.
     """
-    return hasattr(getattr(model, "config", None), "num_experts")
+    config = getattr(model, "config", None)
+    if config is None:
+        return False
+    if getattr(config, "num_experts", None) is not None:
+        return True
+    text_cfg = getattr(config, "text_config", None)
+    return text_cfg is not None and getattr(text_cfg, "num_experts", None) is not None
 
 
 def is_transformers_pipeline_with_causal_lm(model: Any) -> bool: