|
15 | 15 | from __future__ import annotations |
16 | 16 |
|
17 | 17 | import os |
18 | | -import tempfile |
19 | 18 | import subprocess |
| 19 | +import tempfile |
| 20 | +import shutil |
| 21 | +import urllib.request |
| 22 | +import sys |
20 | 23 | from typing import Any, Dict |
21 | 24 |
|
22 | 25 | from ConfigSpace import Constant, OrdinalHyperparameter |
23 | 26 |
|
24 | 27 | from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase |
25 | 28 | from pruna.algorithms.base.tags import AlgorithmTag as tags |
26 | | -from pruna.config.smash_config import SmashConfigPrefixWrapper |
| 29 | +from pruna.config.smash_config import SmashConfig, SmashConfigPrefixWrapper |
27 | 30 | from pruna.engine.save import SAVE_FUNCTIONS |
28 | 31 | from pruna.engine.model_checks import is_causal_lm, is_transformers_pipeline_with_causal_lm |
| 32 | +from pruna.engine.utils import verify_sha256 |
29 | 33 | from pruna.logging.logger import pruna_logger |
30 | 34 |
|
31 | 35 |
|
| 36 | +# SHA256 hash for the pinned version (b3600) of convert_hf_to_gguf.py |
| 37 | +LLAMA_CPP_CONVERSION_SCRIPT_SHA256 = "f62ab712618231b3e76050f94e45dcf94567312c209b4b99bfc142229360b018" |
| 38 | + |
| 39 | + |
32 | 40 | class LlamaCpp(PrunaAlgorithmBase): |
33 | 41 | """ |
34 | 42 | Implement Llama.cpp as a quantizer. |
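Note: `verify_sha256` is imported from `pruna.engine.utils`, but its implementation is outside this diff. For readers following along, here is a minimal sketch of what such a helper typically looks like, assuming it takes a file path and an expected hex digest and returns a bool (the name matches the import above; the signature is an assumption, not Pruna's actual API):

```python
import hashlib

def verify_sha256(file_path: str, expected_sha256: str) -> bool:
    # Hash the file in chunks so large downloads never need to fit in memory.
    digest = hashlib.sha256()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 16), b""):
            digest.update(chunk)
    # Hex digests are conventionally lowercase; normalize before comparing.
    return digest.hexdigest() == expected_sha256.lower()
```

If Pruna's real helper behaves differently (e.g. raises on mismatch instead of returning False), the call site below would change accordingly.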
@@ -128,31 +136,35 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: |
128 | 136 |
|
129 | 137 | # Create a temp directory to hold HF model, f16 GGUF, and optimized GGUF |
130 | 138 | temp_dir = tempfile.mkdtemp() |
131 | | - hf_model_dir = os.path.join(temp_dir, "hf_model") |
132 | 139 | f16_gguf_path = os.path.join(temp_dir, "model-f16.gguf") |
133 | 140 | quant_gguf_path = os.path.join(temp_dir, f"model-{quantization_method}.gguf") |
134 | 141 |
|
135 | 142 | try: |
136 | | - # save HF model |
137 | | - model_to_export.save_pretrained(hf_model_dir) |
138 | | - if hasattr(smash_config, "tokenizer") and smash_config.tokenizer: |
139 | | - smash_config.tokenizer.save_pretrained(hf_model_dir) |
140 | | - |
141 | | - # download the conversion script directly from llama.cpp |
142 | | - import urllib.request |
143 | | - import sys |
144 | | - script_url = "https://raw.githubusercontent.com/ggml-org/llama.cpp/b3600/convert_hf_to_gguf.py" |
145 | | - script_path = os.path.join(temp_dir, "convert_hf_to_gguf.py") |
146 | | - urllib.request.urlretrieve(script_url, script_path) |
147 | | - |
148 | | - pruna_logger.info("Converting Hugging Face model to GGUF format...") |
149 | | - convert_cmd = [ |
150 | | - sys.executable, script_path, |
151 | | - hf_model_dir, |
152 | | - "--outfile", f16_gguf_path, |
153 | | - "--outtype", "f16" |
154 | | - ] |
155 | | - subprocess.run(convert_cmd, check=True) |
| 143 | + # Use a TemporaryDirectory for the HF model to ensure automatic cleanup |
| 144 | + with tempfile.TemporaryDirectory(dir=temp_dir) as hf_model_dir: |
| 145 | + model_to_export.save_pretrained(hf_model_dir) |
| 146 | + if hasattr(smash_config, "tokenizer") and smash_config.tokenizer: |
| 147 | + smash_config.tokenizer.save_pretrained(hf_model_dir) |
| 148 | + |
 |  149 | +                # download the pinned (b3600) conversion script from llama.cpp and verify its checksum
| 150 | + script_url = "https://raw.githubusercontent.com/ggml-org/llama.cpp/b3600/convert_hf_to_gguf.py" |
| 151 | + script_path = os.path.join(hf_model_dir, "convert_hf_to_gguf.py") |
| 152 | + urllib.request.urlretrieve(script_url, script_path) |
| 153 | + |
| 154 | + if not verify_sha256(script_path, LLAMA_CPP_CONVERSION_SCRIPT_SHA256): |
| 155 | + raise ValueError( |
| 156 | + f"Integrity verification failed for {script_url}. " |
| 157 | + "The downloaded script may have been tampered with or the pinned version has changed." |
| 158 | + ) |
| 159 | + |
| 160 | + pruna_logger.info("Converting Hugging Face model to GGUF format...") |
| 161 | + convert_cmd = [ |
| 162 | + sys.executable, script_path, |
| 163 | + hf_model_dir, |
| 164 | + "--outfile", f16_gguf_path, |
| 165 | + "--outtype", "f16" |
| 166 | + ] |
| 167 | + subprocess.run(convert_cmd, check=True) |
156 | 168 |
|
157 | 169 | # quantize the GGUF model |
158 | 170 | if quantization_method != "f16": |
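Because the script URL is pinned to the b3600 tag, bumping the pin means `LLAMA_CPP_CONVERSION_SCRIPT_SHA256` must be recomputed, otherwise the integrity check above will (correctly) fail. A one-off sketch for regenerating the digest, using only the standard library and the URL pinned in the diff:

```python
import hashlib
import urllib.request

url = "https://raw.githubusercontent.com/ggml-org/llama.cpp/b3600/convert_hf_to_gguf.py"
with urllib.request.urlopen(url) as resp:
    # Print the SHA-256 hex digest of the fetched script.
    print(hashlib.sha256(resp.read()).hexdigest())
```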
@@ -194,6 +206,8 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: |
194 | 206 |
|
195 | 207 | except Exception as e: |
196 | 208 | pruna_logger.error(f"Error during llama.cpp quantization: {e}") |
 |  209 | +            if os.path.exists(temp_dir):
 |  210 | +                shutil.rmtree(temp_dir)
197 | 211 | raise |
198 | 212 |
|
199 | 213 | def import_algorithm_packages(self) -> Dict[str, Any]: |
|
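For context on the two temp-directory mechanisms mixed in this change: `tempfile.mkdtemp()` creates a directory that persists until removed explicitly, which is why the exception handler calls `shutil.rmtree` on `temp_dir`, whereas `tempfile.TemporaryDirectory` removes itself when its `with` block exits, which covers the nested `hf_model_dir`. A small self-contained demonstration of the difference:

```python
import os
import shutil
import tempfile

outer = tempfile.mkdtemp()  # persists until explicitly removed
with tempfile.TemporaryDirectory(dir=outer) as inner:
    assert os.path.isdir(inner)   # inner exists while the block runs
assert not os.path.isdir(inner)   # inner was removed automatically on exit
assert os.path.isdir(outer)       # outer still requires explicit cleanup
shutil.rmtree(outer)
```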