-
Notifications
You must be signed in to change notification settings - Fork 88
feat: integrate Llama.cpp and enhance engine stability for cross-platform usage #584
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
1858c9c
0178d2f
b4ffcff
656f917
61766d1
f3c6733
32fdf0a
dd64356
6600fda
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
| @@ -0,0 +1,231 @@ | ||||
| # Copyright 2025 - Pruna AI GmbH. All rights reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
|
|
||||
| from __future__ import annotations | ||||
|
|
||||
| import shutil | ||||
| import subprocess | ||||
| import sys | ||||
| import tempfile | ||||
| import urllib.request | ||||
| from pathlib import Path | ||||
| from typing import Any, Dict | ||||
|
|
||||
| from ConfigSpace import OrdinalHyperparameter | ||||
|
|
||||
| from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase | ||||
| from pruna.algorithms.base.tags import AlgorithmTag as tags | ||||
| from pruna.config.smash_config import SmashConfigPrefixWrapper | ||||
| from pruna.engine.model_checks import ( | ||||
| is_causal_lm, | ||||
| is_transformers_pipeline_with_causal_lm, | ||||
| ) | ||||
| from pruna.engine.save import SAVE_FUNCTIONS | ||||
| from pruna.engine.utils import verify_sha256 | ||||
| from pruna.logging.logger import pruna_logger | ||||
|
|
||||
| # SHA256 hash for the pinned version (b3600) of convert_hf_to_gguf.py | ||||
| LLAMA_CPP_CONVERSION_SCRIPT_SHA256 = "f62ab712618231b3e76050f94e45dcf94567312c209b4b99bfc142229360b018" | ||||
|
|
||||
|
|
||||
class LlamaCpp(PrunaAlgorithmBase):
    """
    Implement Llama.cpp as a quantizer.

    Converts Hugging Face models to GGUF format and quantizes them using the llama.cpp tools.
    """

    algorithm_name: str = "llama_cpp"
    group_tags: list[tags] = [tags.QUANTIZER]
    references: dict[str, str] = {
        "GitHub": "https://github.com/ggml-org/llama.cpp",
        "Python Bindings": "https://github.com/abetlen/llama-cpp-python",
    }
    # Dedicated save routine: relocates the temp GGUF file into the save directory.
    save_fn: SAVE_FUNCTIONS = SAVE_FUNCTIONS.llama_cpp
    tokenizer_required: bool = False
    processor_required: bool = False
    dataset_required: bool = False
    # NOTE(review): "mps" is listed but not exercised anywhere in this file — confirm
    # llama.cpp Metal support is actually wired up before relying on it.
    runs_on: list[str] = ["cpu", "cuda", "mps"]
    compatible_before: list[str] = []
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that |
||||
| compatible_after: list[str] = [] | ||||
|
|
||||
def get_hyperparameters(self) -> list:
    """
    Configure all algorithm-specific hyperparameters with ConfigSpace.

    Returns
    -------
    list
        The hyperparameters.
    """
    supported_methods = ["q4_k_m", "q4_k_s", "q5_k_m", "q8_0", "f16"]
    quantization_method = OrdinalHyperparameter(
        "quantization_method",
        sequence=supported_methods,
        default_value="q4_k_m",
        meta={"desc": "Quantization method for llama.cpp. Examples: q4_k_m, q8_0, f16."},
    )
    return [quantization_method]
|
|
||||
def model_check_fn(self, model: Any) -> bool:
    """
    Check if the model is supported.

    Parameters
    ----------
    model : Any
        The model to check.

    Returns
    -------
    bool
        True if the model is supported, False otherwise.
    """
    # Accept either a bare causal LM or a transformers pipeline wrapping one.
    if is_causal_lm(model):
        return True
    return is_transformers_pipeline_with_causal_lm(model)
|
|
||||
def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
    """
    Quantize the model with Llama.cpp by converting to GGUF.

    Parameters
    ----------
    model : Any
        The model to quantize.
    smash_config : SmashConfigPrefixWrapper
        The configuration for the quantization.

    Returns
    -------
    Any
        The quantized Llama object.

    Raises
    ------
    ValueError
        If the downloaded conversion script fails integrity verification, or the
        quantization method has no matching llama.cpp ftype.
    subprocess.CalledProcessError
        If the HF -> GGUF conversion script exits with a non-zero status.
    RuntimeError
        If llama-cpp-python does not expose ``llama_model_quantize``.
    """
    imported_modules = self.import_algorithm_packages()
    llama_cpp = imported_modules["llama_cpp"]

    quantization_method = smash_config["quantization_method"]

    pruna_logger.info(f"Quantizing model with llama.cpp using method {quantization_method}")

    # Ensure we have the causal lm if it's a pipeline
    model_to_export = model.model if is_transformers_pipeline_with_causal_lm(model) else model

    # llama.cpp requires tensor dimensions to be divisible by a block size (usually 32)
    # fallback to f16 for tiny test models avoiding crashes
    if (
        hasattr(model_to_export, "config")
        and hasattr(model_to_export.config, "hidden_size")
        and model_to_export.config.hidden_size < 32
    ):
        pruna_logger.info("Tiny model detected. Bypassing quantized block sizes and using f16.")
        quantization_method = "f16"

    # Create a temp directory to hold HF model, f16 GGUF, and optimized GGUF.
    # NOTE(review): this directory is only removed on error or by the save
    # function; if save is never called it persists until process exit — TODO
    # confirm whether the pruna cache should own these artifacts instead.
    temp_dir = tempfile.mkdtemp()
    f16_gguf_path = Path(temp_dir) / "model-f16.gguf"
    quant_gguf_path = Path(temp_dir) / f"model-{quantization_method}.gguf"

    try:
        # Use a TemporaryDirectory for the HF model to ensure automatic cleanup
        with tempfile.TemporaryDirectory(dir=temp_dir) as hf_model_dir:
            model_to_export.save_pretrained(hf_model_dir)
            if hasattr(smash_config, "tokenizer") and smash_config.tokenizer:
                smash_config.tokenizer.save_pretrained(hf_model_dir)

            # Download the conversion script pinned to llama.cpp release b3600.
            # Keep this URL in sync with LLAMA_CPP_CONVERSION_SCRIPT_SHA256.
            script_url = "https://raw.githubusercontent.com/ggml-org/llama.cpp/b3600/convert_hf_to_gguf.py"
            script_path = Path(hf_model_dir) / "convert_hf_to_gguf.py"
            urllib.request.urlretrieve(script_url, script_path)

            # Refuse to execute the script if it does not match the pinned hash.
            if not verify_sha256(script_path, LLAMA_CPP_CONVERSION_SCRIPT_SHA256):
                raise ValueError(
                    f"Integrity verification failed for {script_url}. "
                    "The downloaded script may have been tampered with or the pinned version has changed."
                )

            pruna_logger.info("Converting Hugging Face model to GGUF format...")
            convert_cmd = [
                sys.executable, str(script_path),
                hf_model_dir,
                "--outfile", str(f16_gguf_path),
                "--outtype", "f16",
            ]
            # Capture output so a conversion failure surfaces the script's stderr
            # (previously `check=True` raised without showing why it failed).
            result = subprocess.run(convert_cmd, capture_output=True, text=True)
            if result.returncode != 0:
                pruna_logger.error(f"GGUF conversion failed:\n{result.stderr}")
                raise subprocess.CalledProcessError(
                    result.returncode, convert_cmd, output=result.stdout, stderr=result.stderr
                )

        # quantize the GGUF model
        if quantization_method != "f16":
            pruna_logger.info(f"Quantizing GGUF model to {quantization_method}...")

            # Retrieve quantize CLI from llama.cpp
            if hasattr(llama_cpp, "llama_model_quantize"):
                # Using API
                params = llama_cpp.llama_model_quantize_default_params()

                # Convert string to enum, e.g. "q4_k_m" -> llama_cpp.LLAMA_FTYPE_MOSTLY_Q4_K_M
                ftype_name = f"LLAMA_FTYPE_MOSTLY_{quantization_method.upper()}"
                if hasattr(llama_cpp, ftype_name):
                    params.ftype = getattr(llama_cpp, ftype_name)
                else:
                    raise ValueError(f"Unknown quantization method: {quantization_method}")

                llama_cpp.llama_model_quantize(
                    str(f16_gguf_path).encode("utf-8"),
                    str(quant_gguf_path).encode("utf-8"),
                    params,
                )
            else:
                raise RuntimeError("llama-cpp-python does not have llama_model_quantize available")
        else:
            quant_gguf_path = f16_gguf_path

        # Load the quantized model
        pruna_logger.info(f"Loading quantized model from {quant_gguf_path}")
        quantized_model = llama_cpp.Llama(model_path=str(quant_gguf_path))

        # Keep a reference to the temp file path so the save function can move it
        quantized_model._pruna_temp_dir = temp_dir
        quantized_model.model_path = str(quant_gguf_path)

        # The intermediate f16 file is no longer needed once quantized.
        if quantization_method != "f16":
            f16_gguf_path.unlink(missing_ok=True)

        return quantized_model

    except Exception as e:
        pruna_logger.error(f"Error during llama.cpp quantization: {e}")
        # temp_dir is always bound here (assigned before the try block), so the
        # old `"temp_dir" in locals()` guard was dead code.
        if Path(temp_dir).exists():
            shutil.rmtree(temp_dir)
        raise
cursor[bot] marked this conversation as resolved.
Show resolved
Hide resolved
cursor[bot] marked this conversation as resolved.
Show resolved
Hide resolved
krishjp marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
|
|
||||
def import_algorithm_packages(self) -> Dict[str, Any]:
    """
    Provide algorithm packages.

    Returns
    -------
    Dict[str, Any]
        The algorithm packages.

    Raises
    ------
    ImportError
        If llama-cpp-python is not installed.
    """
    try:
        import llama_cpp
        return dict(llama_cpp=llama_cpp)
    except ImportError as err:
        # Chain explicitly so the original import failure is marked as the
        # direct cause of the user-facing error.
        raise ImportError(
            "Could not import llama_cpp. Please install it with `pip install llama-cpp-python`."
        ) from err
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -217,8 +217,12 @@ def __eq__(self, other: Any) -> bool: | |
|
|
||
def cleanup_cache_dir(self) -> None:
    """Clean up the cache directory, tolerating interpreter shutdown.

    Safe to call when ``cache_dir`` was never set or is ``None``.
    """
    try:
        cache_dir = getattr(self, "cache_dir", None)
        # A plain exists() check is enough here; the extra
        # hasattr(self.cache_dir, 'exists') probe was redundant — any
        # AttributeError it guarded against is caught below anyway.
        if cache_dir is not None and cache_dir.exists():
            shutil.rmtree(cache_dir)
    except (AttributeError, TypeError, ImportError):
        # This can happen during interpreter shutdown when modules are already None
        pass
|
|
||
| def reset_cache_dir(self) -> None: | ||
| """Reset the cache directory.""" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,14 @@ | |
| from pathlib import Path | ||
| from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union | ||
|
|
||
try:
    from enum import member
except ImportError:
    # enum.member was added in Python 3.11; on older interpreters fall back to
    # an identity decorator so Enum members can still be declared uniformly.
    def member(x):
        """Standard member decorator fallback for older python versions."""
        return x
|
|
||
| import diffusers | ||
| import torch | ||
| import transformers | ||
|
|
@@ -469,6 +477,38 @@ def load_quantized_model(quantized_path: str | Path) -> Any: | |
| ) | ||
|
|
||
|
|
||
def load_llama_cpp(path: str | Path, smash_config: SmashConfig, **kwargs) -> Any:
    """
    Load a model quantized with llama.cpp from the given model path.

    Parameters
    ----------
    path : str | Path
        The path to the model directory.
    smash_config : SmashConfig
        The SmashConfig object containing the device and device_map.
    **kwargs : Any
        Additional keyword arguments to pass to the model loading function.

    Returns
    -------
    Any
        The loaded llama.cpp model.
    """
    # Local import avoids a circular dependency between save/load and algorithms.
    from pruna.algorithms.llama_cpp import LlamaCpp

    llama_cpp = LlamaCpp().import_algorithm_packages()["llama_cpp"]

    model_path = Path(path) / "model.gguf"
    if not model_path.exists():
        raise FileNotFoundError(f"GGUF file not found at {model_path}")

    # Only forward kwargs that Llama.__init__ actually accepts.
    accepted_kwargs = filter_load_kwargs(llama_cpp.Llama.__init__, kwargs)
    model = llama_cpp.Llama(model_path=str(model_path), **accepted_kwargs)
    model.model_path = str(model_path)
    return model
|
|
||
|
|
||
| def load_hqq_diffusers(path: str | Path, smash_config: SmashConfig, **kwargs) -> Any: | ||
| """ | ||
| Load a diffusers model from the given model path. | ||
|
|
@@ -587,11 +627,12 @@ class LOAD_FUNCTIONS(Enum): # noqa: N801 | |
| <Loaded transformer model> | ||
| """ | ||
|
|
||
| transformers = partial(load_transformers_model) | ||
| diffusers = partial(load_diffusers_model) | ||
| pickled = partial(load_pickled) | ||
| hqq = partial(load_hqq) | ||
| hqq_diffusers = partial(load_hqq_diffusers) | ||
| transformers = member(partial(load_transformers_model)) | ||
| diffusers = member(partial(load_diffusers_model)) | ||
| pickled = member(partial(load_pickled)) | ||
| hqq = member(partial(load_hqq)) | ||
| hqq_diffusers = member(partial(load_hqq_diffusers)) | ||
| llama_cpp = member(partial(load_llama_cpp)) | ||
|
|
||
| def __call__(self, *args, **kwargs) -> Any: | ||
| """ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -178,6 +178,10 @@ def set_to_eval(self) -> None: | |
| """Set the model to evaluation mode.""" | ||
| set_to_eval(self.model) | ||
|
|
||
def save(self, model_path: str) -> None:
    """Save the model (convenience alias delegating to ``save_pretrained``)."""
    self.save_pretrained(model_path)
|
|
||
| def save_pretrained(self, model_path: str) -> None: | ||
| """ | ||
| Save the smashed model to the specified model path. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No code in this PR directly uses the
ggufpackage. I guess because you need it for the conversion script you download? If yes, please add a comment to make this choice transparent ;) (and also check if we need other lib for running this script)