diff --git a/pyproject.toml b/pyproject.toml index 5b1eb704..6606096d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -165,6 +165,10 @@ vllm = [ "vllm>=0.16.0", "ray", ] +llamacpp = [ + "llama-cpp-python>=0.2.78", + "gguf>=0.6.0", +] stable-fast = [ "xformers>=0.0.30", "stable-fast-pruna==1.0.8", @@ -187,6 +191,8 @@ awq = [ full = [ "xformers>=0.0.30", "stable-fast-pruna==1.0.8", + "llama-cpp-python>=0.2.78", + "gguf>=0.6.0", ] vbench = [ "vbench-pruna; sys_platform != 'darwin'", diff --git a/src/pruna/algorithms/base/pruna_base.py b/src/pruna/algorithms/base/pruna_base.py index 0784069b..4d585eda 100644 --- a/src/pruna/algorithms/base/pruna_base.py +++ b/src/pruna/algorithms/base/pruna_base.py @@ -28,6 +28,7 @@ SAVE_FUNCTIONS, save_pruna_model, ) +from pruna.engine.utils import get_fn_name from pruna.logging.logger import pruna_logger @@ -365,7 +366,8 @@ def apply(self, model: Any, smash_config: SmashConfig) -> Any: # if the registered save function is None, the original saving function remains if self.save_fn is not None and self.save_fn != SAVE_FUNCTIONS.reapply: - smash_config.save_fns.append(self.save_fn.name) + fn_name = get_fn_name(self.save_fn) + smash_config.save_fns.append(fn_name) prefix = self.algorithm_name + "_" wrapped_config = SmashConfigPrefixWrapper(smash_config, prefix) diff --git a/src/pruna/algorithms/llama_cpp.py b/src/pruna/algorithms/llama_cpp.py new file mode 100644 index 00000000..86d70271 --- /dev/null +++ b/src/pruna/algorithms/llama_cpp.py @@ -0,0 +1,231 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import shutil +import subprocess +import sys +import tempfile +import urllib.request +from pathlib import Path +from typing import Any, Dict + +from ConfigSpace import OrdinalHyperparameter + +from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase +from pruna.algorithms.base.tags import AlgorithmTag as tags +from pruna.config.smash_config import SmashConfigPrefixWrapper +from pruna.engine.model_checks import ( + is_causal_lm, + is_transformers_pipeline_with_causal_lm, +) +from pruna.engine.save import SAVE_FUNCTIONS +from pruna.engine.utils import verify_sha256 +from pruna.logging.logger import pruna_logger + +# SHA256 hash for the pinned version (b3600) of convert_hf_to_gguf.py +LLAMA_CPP_CONVERSION_SCRIPT_SHA256 = "f62ab712618231b3e76050f94e45dcf94567312c209b4b99bfc142229360b018" + + +class LlamaCpp(PrunaAlgorithmBase): + """ + Implement Llama.cpp as a quantizer. + + Converts Hugging Face models to GGUF format and quantizes them using the llama.cpp tools. + """ + + algorithm_name: str = "llama_cpp" + group_tags: list[tags] = [tags.QUANTIZER] + references: dict[str, str] = { + "GitHub": "https://github.com/ggml-org/llama.cpp", + "Python Bindings": "https://github.com/abetlen/llama-cpp-python", + } + save_fn: SAVE_FUNCTIONS = SAVE_FUNCTIONS.llama_cpp + tokenizer_required: bool = False + processor_required: bool = False + dataset_required: bool = False + runs_on: list[str] = ["cpu", "cuda", "mps"] + compatible_before: list[str] = [] + compatible_after: list[str] = [] + + def get_hyperparameters(self) -> list: + """ + Configure all algorithm-specific hyperparameters with ConfigSpace. + + Returns + ------- + list + The hyperparameters. + """ + return [ + OrdinalHyperparameter( + "quantization_method", + sequence=[ + "q4_k_m", + "q4_k_s", + "q5_k_m", + "q8_0", + "f16" + ], + default_value="q4_k_m", + meta={"desc": "Quantization method for llama.cpp. Examples: q4_k_m, q8_0, f16."}, + ), + ] + + def model_check_fn(self, model: Any) -> bool: + """ + Check if the model is supported. + + Parameters + ---------- + model : Any + The model to check. + + Returns + ------- + bool + True if the model is supported, False otherwise. + """ + return is_causal_lm(model) or is_transformers_pipeline_with_causal_lm(model) + + def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: + """ + Quantize the model with Llama.cpp by converting to GGUF. + + Parameters + ---------- + model : Any + The model to quantize. + smash_config : SmashConfigPrefixWrapper + The configuration for the quantization. + + Returns + ------- + Any + The quantized Llama object. + """ + imported_modules = self.import_algorithm_packages() + llama_cpp = imported_modules["llama_cpp"] + + quantization_method = smash_config["quantization_method"] + + pruna_logger.info(f"Quantizing model with llama.cpp using method {quantization_method}") + + # Ensure we have the causal lm if it's a pipeline + model_to_export = model.model if is_transformers_pipeline_with_causal_lm(model) else model + + # llama.cpp requires tensor dimensions to be divisible by a block size (usually 32) + # fallback to f16 for tiny test models avoiding crashes + if ( + hasattr(model_to_export, "config") + and hasattr(model_to_export.config, "hidden_size") + and model_to_export.config.hidden_size < 32 + ): + pruna_logger.info("Tiny model detected. Bypassing quantized block sizes and using f16.") + quantization_method = "f16" + + # Create a temp directory to hold HF model, f16 GGUF, and optimized GGUF + temp_dir = tempfile.mkdtemp() + f16_gguf_path = Path(temp_dir) / "model-f16.gguf" + quant_gguf_path = Path(temp_dir) / f"model-{quantization_method}.gguf" + + try: + # Use a TemporaryDirectory for the HF model to ensure automatic cleanup + with tempfile.TemporaryDirectory(dir=temp_dir) as hf_model_dir: + model_to_export.save_pretrained(hf_model_dir) + if hasattr(smash_config, "tokenizer") and smash_config.tokenizer: + smash_config.tokenizer.save_pretrained(hf_model_dir) + + # download the conversion script directly from llama.cpp + script_url = "https://raw.githubusercontent.com/ggml-org/llama.cpp/b3600/convert_hf_to_gguf.py" + script_path = Path(hf_model_dir) / "convert_hf_to_gguf.py" + urllib.request.urlretrieve(script_url, script_path) + + if not verify_sha256(script_path, LLAMA_CPP_CONVERSION_SCRIPT_SHA256): + raise ValueError( + f"Integrity verification failed for {script_url}. " + "The downloaded script may have been tampered with or the pinned version has changed." + ) + + pruna_logger.info("Converting Hugging Face model to GGUF format...") + convert_cmd = [ + sys.executable, script_path, + hf_model_dir, + "--outfile", f16_gguf_path, + "--outtype", "f16" + ] + subprocess.run(convert_cmd, check=True) + + # quantize the GGUF model + if quantization_method != "f16": + pruna_logger.info(f"Quantizing GGUF model to {quantization_method}...") + + # Retrieve quantize CLI from llama.cpp + if hasattr(llama_cpp, "llama_model_quantize"): + # Using API + params = llama_cpp.llama_model_quantize_default_params() + + # Convert string to enum, e.g. "q4_k_m" -> llama_cpp.LLAMA_FTYPE_MOSTLY_Q4_K_M + ftype_name = f"LLAMA_FTYPE_MOSTLY_{quantization_method.upper()}" + if hasattr(llama_cpp, ftype_name): + params.ftype = getattr(llama_cpp, ftype_name) + else: + raise ValueError(f"Unknown quantization method: {quantization_method}") + + llama_cpp.llama_model_quantize( + str(f16_gguf_path).encode("utf-8"), + str(quant_gguf_path).encode("utf-8"), + params, + ) + else: + raise RuntimeError("llama-cpp-python does not have llama_model_quantize available") + else: + quant_gguf_path = f16_gguf_path + + # Load the quantized model + pruna_logger.info(f"Loading quantized model from {quant_gguf_path}") + quantized_model = llama_cpp.Llama(model_path=str(quant_gguf_path)) + + # Keep a reference to the temp file path so the save function can move it + quantized_model._pruna_temp_dir = temp_dir + quantized_model.model_path = str(quant_gguf_path) + + if quantization_method != "f16": + f16_gguf_path.unlink(missing_ok=True) + + return quantized_model + + except Exception as e: + pruna_logger.error(f"Error during llama.cpp quantization: {e}") + if "temp_dir" in locals() and Path(temp_dir).exists(): + shutil.rmtree(temp_dir) + raise + + def import_algorithm_packages(self) -> Dict[str, Any]: + """ + Provide algorithm packages. + + Returns + ------- + Dict[str, Any] + The algorithm packages. + """ + try: + import llama_cpp + return dict(llama_cpp=llama_cpp) + except ImportError: + raise ImportError( + "Could not import llama_cpp. Please install it with `pip install llama-cpp-python`." + ) diff --git a/src/pruna/config/smash_config.py b/src/pruna/config/smash_config.py index 00a9865a..fb7e7981 100644 --- a/src/pruna/config/smash_config.py +++ b/src/pruna/config/smash_config.py @@ -217,8 +217,12 @@ def __eq__(self, other: Any) -> bool: def cleanup_cache_dir(self) -> None: """Clean up the cache directory.""" - if self.cache_dir.exists(): - shutil.rmtree(self.cache_dir) + try: + if hasattr(self, 'cache_dir') and self.cache_dir is not None and hasattr(self.cache_dir, 'exists') and self.cache_dir.exists(): + shutil.rmtree(self.cache_dir) + except (AttributeError, TypeError, ImportError): + # This can happen during interpreter shutdown when modules are already None + pass def reset_cache_dir(self) -> None: """Reset the cache directory.""" diff --git a/src/pruna/engine/load.py b/src/pruna/engine/load.py index 6edff441..62fc90fe 100644 --- a/src/pruna/engine/load.py +++ b/src/pruna/engine/load.py @@ -22,6 +22,14 @@ from pathlib import Path from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union +try: + from enum import member +except ImportError: + # member was added in 3.11 + def member(x): + """Standard member decorator fallback for older python versions.""" + return x + import diffusers import torch import transformers @@ -469,6 +477,38 @@ def load_quantized_model(quantized_path: str | Path) -> Any: ) +def load_llama_cpp(path: str | Path, smash_config: SmashConfig, **kwargs) -> Any: + """ + Load a model quantized with llama.cpp from the given model path. + + Parameters + ---------- + path : str | Path + The path to the model directory. + smash_config : SmashConfig + The SmashConfig object containing the device and device_map. + **kwargs : Any + Additional keyword arguments to pass to the model loading function. + + Returns + ------- + Any + The loaded llama.cpp model. + """ + from pruna.algorithms.llama_cpp import LlamaCpp + + algorithm_packages = LlamaCpp().import_algorithm_packages() + llama_cpp = algorithm_packages["llama_cpp"] + + model_path = Path(path) / "model.gguf" + if not model_path.exists(): + raise FileNotFoundError(f"GGUF file not found at {model_path}") + + model = llama_cpp.Llama(model_path=str(model_path), **filter_load_kwargs(llama_cpp.Llama.__init__, kwargs)) + model.model_path = str(model_path) + return model + + def load_hqq_diffusers(path: str | Path, smash_config: SmashConfig, **kwargs) -> Any: """ Load a diffusers model from the given model path. @@ -587,11 +627,12 @@ class LOAD_FUNCTIONS(Enum): # noqa: N801 """ - transformers = partial(load_transformers_model) - diffusers = partial(load_diffusers_model) - pickled = partial(load_pickled) - hqq = partial(load_hqq) - hqq_diffusers = partial(load_hqq_diffusers) + transformers = member(partial(load_transformers_model)) + diffusers = member(partial(load_diffusers_model)) + pickled = member(partial(load_pickled)) + hqq = member(partial(load_hqq)) + hqq_diffusers = member(partial(load_hqq_diffusers)) + llama_cpp = member(partial(load_llama_cpp)) def __call__(self, *args, **kwargs) -> Any: """ diff --git a/src/pruna/engine/pruna_model.py b/src/pruna/engine/pruna_model.py index a0f34728..ce274bc6 100644 --- a/src/pruna/engine/pruna_model.py +++ b/src/pruna/engine/pruna_model.py @@ -178,6 +178,10 @@ def set_to_eval(self) -> None: """Set the model to evaluation mode.""" set_to_eval(self.model) + def save(self, model_path: str) -> None: + """Save the model.""" + self.save_pretrained(model_path) + def save_pretrained(self, model_path: str) -> None: """ Save the smashed model to the specified model path. diff --git a/src/pruna/engine/save.py b/src/pruna/engine/save.py index 50f9ca78..40dc653e 100644 --- a/src/pruna/engine/save.py +++ b/src/pruna/engine/save.py @@ -22,6 +22,14 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, List, cast +try: + from enum import member +except ImportError: + # member was added in 3.11 + def member(x): + """Standard member decorator fallback for older python versions.""" + return x + import torch import transformers from huggingface_hub import ModelCard, ModelCardData, login, repo_exists, upload_large_folder @@ -37,7 +45,7 @@ ) from pruna.engine.model_checks import get_helpers, is_janus_llamagen_ar from pruna.engine.save_artifacts import save_artifacts -from pruna.engine.utils import determine_dtype, monkeypatch +from pruna.engine.utils import determine_dtype, get_fn_name, monkeypatch from pruna.logging.logger import pruna_logger if TYPE_CHECKING: @@ -67,8 +75,7 @@ def save_pruna_model(model: Any, model_path: str | Path, smash_config: SmashConf pruna_logger.debug("Using model's original save function...") save_fn = original_save_fn - # if save-before-move was the last operation, we simply move the already saved files, we have delt with them before - elif smash_config.save_fns[-1] == SAVE_FUNCTIONS.save_before_apply.name: + elif len(smash_config.save_fns) > 0 and smash_config.save_fns[-1] == get_fn_name(SAVE_FUNCTIONS.save_before_apply): pruna_logger.debug("Moving saved model...") save_fn = save_before_apply @@ -465,6 +472,53 @@ def save_component(attr_name: str | None, module: torch.nn.Module, subpaths: lis smash_config.load_fns.append(LOAD_FUNCTIONS.hqq_diffusers.name) +def save_model_llama_cpp(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None: + """ + Save the model with llama.cpp functionality. + + Parameters + ---------- + model : Any + The model to save. + model_path : str | Path + The directory to save the model to. + smash_config : SmashConfig + The SmashConfig object containing the save and load functions. + """ + model_path = Path(model_path) + + if hasattr(model, "model_path"): + gguf_file = Path(model.model_path) + if gguf_file.exists(): + target_file = model_path / "model.gguf" + if gguf_file.resolve() != target_file.resolve(): + if ( + hasattr(model, "_pruna_temp_dir") + and Path(model._pruna_temp_dir).resolve() == gguf_file.parent.resolve() + ): + try: + shutil.move(gguf_file, target_file) + shutil.rmtree(model._pruna_temp_dir) + delattr(model, "_pruna_temp_dir") + except PermissionError: + pruna_logger.warning( + f"Could not move GGUF file from {gguf_file} to {target_file} " + "(likely memory-mapped on Windows). " + "Copying instead, but the temporary directory will persist " + "until process exit." + ) + shutil.copy(gguf_file, target_file) + else: + shutil.copy(gguf_file, target_file) + + model.model_path = str(target_file) + smash_config.load_fns.append(LOAD_FUNCTIONS.llama_cpp.name) + else: + raise FileNotFoundError(f"GGUF file not found at {gguf_file}") + else: + raise AttributeError("Llama object does not have model_path attribute.") + + def reapply(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None: """ Reapply the model. @@ -513,11 +567,12 @@ class SAVE_FUNCTIONS(Enum): # noqa: N801 # Model saved to disk in pickled format """ - pickled = partial(save_pickled) - hqq = partial(save_model_hqq) - hqq_diffusers = partial(save_model_hqq_diffusers) - save_before_apply = partial(save_before_apply) - reapply = partial(reapply) + pickled = member(partial(save_pickled)) + hqq = member(partial(save_model_hqq)) + hqq_diffusers = member(partial(save_model_hqq_diffusers)) + llama_cpp = member(partial(save_model_llama_cpp)) + save_before_apply = member(partial(save_before_apply)) + reapply = member(partial(reapply)) def __call__(self, *args, **kwargs) -> None: """ diff --git a/src/pruna/engine/utils.py b/src/pruna/engine/utils.py index a039fc24..64af5a53 100644 --- a/src/pruna/engine/utils.py +++ b/src/pruna/engine/utils.py @@ -16,9 +16,11 @@ import contextlib import gc +import hashlib import inspect import json from contextlib import AbstractContextManager, contextmanager +from functools import partial from pathlib import Path from typing import Any @@ -38,6 +40,48 @@ def safe_memory_cleanup() -> None: torch.cuda.empty_cache() +def get_fn_name(obj: Any) -> str: + """ + Get the name of a function or a partial function. + + Parameters + ---------- + obj : Any + The function or partial function to get the name of. + + Returns + ------- + str + The name of the function. + """ + if isinstance(obj, partial): + return get_fn_name(obj.func) + return getattr(obj, "name", getattr(obj, "__name__", str(obj))) + + +def verify_sha256(file_path: str | Path, expected_hash: str) -> bool: + """ + Verify the SHA256 hash of a file. + + Parameters + ---------- + file_path : str | Path + The path to the file to verify. + expected_hash : str + The expected SHA256 hash. + + Returns + ------- + bool + True if the hash matches, False otherwise. + """ + sha256_hash = hashlib.sha256() + with Path(file_path).open("rb") as f: + for byte_block in iter(lambda: f.read(4096), b""): + sha256_hash.update(byte_block) + return sha256_hash.hexdigest() == expected_hash + + def load_json_config(path: str | Path, json_name: str) -> dict: """ Load and parse a JSON configuration file. @@ -375,6 +419,9 @@ def get_device(model: Any) -> str: model_device = next(model.parameters()).device except StopIteration: raise ValueError("Could not determine device of model, model has no device attribute.") + except AttributeError: + # Model does not use PyTorch parameters natively (e.g. llama_cpp), default to cpu string mapping + model_device = "cpu" # model_device.type ignores the device index. Added a new function to convert to string. model_device = device_to_string(model_device) diff --git a/tests/algorithms/testers/llama_cpp.py b/tests/algorithms/testers/llama_cpp.py new file mode 100644 index 00000000..6eaf0fc1 --- /dev/null +++ b/tests/algorithms/testers/llama_cpp.py @@ -0,0 +1,36 @@ +from pruna.algorithms.llama_cpp import LlamaCpp +from .base_tester import AlgorithmTesterBase + + +class TestLlamaCpp(AlgorithmTesterBase): + """Test the LlamaCpp quantizer.""" + + __test__ = False + + models = ["llama_3_tiny_random"] + reject_models = ["sd_tiny_random"] + allow_pickle_files = False + algorithm_class = LlamaCpp + metrics = [] + + def pre_smash_hook(self, model): + import pytest + pytest.importorskip("llama_cpp") + + def execute_smash(self, model, smash_config): + """Execute the smash operation without device checking.""" + self.pre_smash_hook(model) + from pruna.smash import smash + smashed_model = smash(model, smash_config=smash_config) + self.post_smash_hook(smashed_model) + # Bypassed device checks because llama_cpp doesn't expose native PyTorch .parameters() for checking + return smashed_model + + def execute_load(self): + """Load the smashed model without device checking.""" + from pruna.engine.pruna_model import PrunaModel + model = PrunaModel.from_pretrained(str(self._saving_path)) + assert isinstance(model, PrunaModel) + self.post_load_hook(model) + # Bypassed device checks because llama_cpp doesn't expose native PyTorch .parameters() for checking + return model