Skip to content

Commit b468bae

Browse files
committed
fix: integrity verification of remote scripts
1 parent f3c6733 commit b468bae

4 files changed

Lines changed: 90 additions & 37 deletions

File tree

src/pruna/algorithms/base/pruna_base.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
SAVE_FUNCTIONS,
2929
save_pruna_model,
3030
)
31+
from pruna.engine.utils import get_fn_name
3132
from pruna.logging.logger import pruna_logger
3233

3334

@@ -365,11 +366,7 @@ def apply(self, model: Any, smash_config: SmashConfig) -> Any:
365366

366367
# if the registered save function is None, the original saving function remains
367368
if self.save_fn is not None and self.save_fn != SAVE_FUNCTIONS.reapply:
368-
if isinstance(self.save_fn, functools.partial):
369-
fn_name = getattr(self.save_fn.func, 'name', getattr(self.save_fn.func, '__name__', str(self.save_fn.func)))
370-
else:
371-
fn_name = getattr(self.save_fn, 'name', getattr(self.save_fn, '__name__', str(self.save_fn)))
372-
369+
fn_name = get_fn_name(self.save_fn)
373370
smash_config.save_fns.append(fn_name)
374371

375372
prefix = self.algorithm_name + "_"

src/pruna/algorithms/llama_cpp.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,28 @@
1515
from __future__ import annotations
1616

1717
import os
18-
import tempfile
1918
import subprocess
19+
import tempfile
20+
import shutil
21+
import urllib.request
22+
import sys
2023
from typing import Any, Dict
2124

2225
from ConfigSpace import Constant, OrdinalHyperparameter
2326

2427
from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
2528
from pruna.algorithms.base.tags import AlgorithmTag as tags
26-
from pruna.config.smash_config import SmashConfigPrefixWrapper
29+
from pruna.config.smash_config import SmashConfig, SmashConfigPrefixWrapper
2730
from pruna.engine.save import SAVE_FUNCTIONS
2831
from pruna.engine.model_checks import is_causal_lm, is_transformers_pipeline_with_causal_lm
32+
from pruna.engine.utils import verify_sha256
2933
from pruna.logging.logger import pruna_logger
3034

3135

36+
# SHA256 hash for the pinned version (b3600) of convert_hf_to_gguf.py
37+
LLAMA_CPP_CONVERSION_SCRIPT_SHA256 = "f62ab712618231b3e76050f94e45dcf94567312c209b4b99bfc142229360b018"
38+
39+
3240
class LlamaCpp(PrunaAlgorithmBase):
3341
"""
3442
Implement Llama.cpp as a quantizer.
@@ -128,31 +136,35 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
128136

129137
# Create a temp directory to hold HF model, f16 GGUF, and optimized GGUF
130138
temp_dir = tempfile.mkdtemp()
131-
hf_model_dir = os.path.join(temp_dir, "hf_model")
132139
f16_gguf_path = os.path.join(temp_dir, "model-f16.gguf")
133140
quant_gguf_path = os.path.join(temp_dir, f"model-{quantization_method}.gguf")
134141

135142
try:
136-
# save HF model
137-
model_to_export.save_pretrained(hf_model_dir)
138-
if hasattr(smash_config, "tokenizer") and smash_config.tokenizer:
139-
smash_config.tokenizer.save_pretrained(hf_model_dir)
140-
141-
# download the conversion script directly from llama.cpp
142-
import urllib.request
143-
import sys
144-
script_url = "https://raw.githubusercontent.com/ggml-org/llama.cpp/b3600/convert_hf_to_gguf.py"
145-
script_path = os.path.join(temp_dir, "convert_hf_to_gguf.py")
146-
urllib.request.urlretrieve(script_url, script_path)
147-
148-
pruna_logger.info("Converting Hugging Face model to GGUF format...")
149-
convert_cmd = [
150-
sys.executable, script_path,
151-
hf_model_dir,
152-
"--outfile", f16_gguf_path,
153-
"--outtype", "f16"
154-
]
155-
subprocess.run(convert_cmd, check=True)
143+
# Use a TemporaryDirectory for the HF model to ensure automatic cleanup
144+
with tempfile.TemporaryDirectory(dir=temp_dir) as hf_model_dir:
145+
model_to_export.save_pretrained(hf_model_dir)
146+
if hasattr(smash_config, "tokenizer") and smash_config.tokenizer:
147+
smash_config.tokenizer.save_pretrained(hf_model_dir)
148+
149+
# download the conversion script directly from llama.cpp
150+
script_url = "https://raw.githubusercontent.com/ggml-org/llama.cpp/b3600/convert_hf_to_gguf.py"
151+
script_path = os.path.join(hf_model_dir, "convert_hf_to_gguf.py")
152+
urllib.request.urlretrieve(script_url, script_path)
153+
154+
if not verify_sha256(script_path, LLAMA_CPP_CONVERSION_SCRIPT_SHA256):
155+
raise ValueError(
156+
f"Integrity verification failed for {script_url}. "
157+
"The downloaded script may have been tampered with or the pinned version has changed."
158+
)
159+
160+
pruna_logger.info("Converting Hugging Face model to GGUF format...")
161+
convert_cmd = [
162+
sys.executable, script_path,
163+
hf_model_dir,
164+
"--outfile", f16_gguf_path,
165+
"--outtype", "f16"
166+
]
167+
subprocess.run(convert_cmd, check=True)
156168

157169
# quantize the GGUF model
158170
if quantization_method != "f16":
@@ -194,6 +206,8 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
194206

195207
except Exception as e:
196208
pruna_logger.error(f"Error during llama.cpp quantization: {e}")
209+
if 'temp_dir' in locals() and os.path.exists(temp_dir):
210+
shutil.rmtree(temp_dir)
197211
raise
198212

199213
def import_algorithm_packages(self) -> Dict[str, Any]:

src/pruna/engine/save.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
)
4444
from pruna.engine.model_checks import get_helpers, is_janus_llamagen_ar
4545
from pruna.engine.save_artifacts import save_artifacts
46-
from pruna.engine.utils import determine_dtype, monkeypatch
46+
from pruna.engine.utils import determine_dtype, get_fn_name, monkeypatch
4747
from pruna.logging.logger import pruna_logger
4848

4949
if TYPE_CHECKING:
@@ -65,11 +65,6 @@ def save_pruna_model(model: Any, model_path: str | Path, smash_config: SmashConf
6565
The SmashConfig object containing the save and load functions.
6666
"""
6767

68-
def get_fn_name(obj):
69-
if isinstance(obj, partial):
70-
return get_fn_name(obj.func)
71-
return getattr(obj, 'name', getattr(obj, '__name__', str(obj)))
72-
7368
model_path = Path(model_path)
7469
if not model_path.exists():
7570
model_path.mkdir(parents=True, exist_ok=True)
@@ -495,12 +490,15 @@ def save_model_llama_cpp(model: Any, model_path: str | Path, smash_config: Smash
495490
gguf_file = Path(model.model_path)
496491
if gguf_file.exists():
497492
target_file = model_path / "model.gguf"
498-
shutil.copy(gguf_file, target_file)
493+
shutil.move(gguf_file, target_file)
494+
# Cleanup the temporary directory
495+
if gguf_file.parent.exists():
496+
shutil.rmtree(gguf_file.parent)
499497
smash_config.load_fns.append(LOAD_FUNCTIONS.llama_cpp.name)
500498
else:
501-
pruna_logger.error(f"GGUF file not found at {gguf_file}")
499+
raise FileNotFoundError(f"GGUF file not found at {gguf_file}")
502500
else:
503-
pruna_logger.error("Llama object does not have model_path attribute.")
501+
raise AttributeError("Llama object does not have model_path attribute.")
504502

505503

506504
def reapply(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None:

src/pruna/engine/utils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616

1717
import contextlib
1818
import gc
19+
import hashlib
1920
import inspect
2021
import json
2122
from contextlib import AbstractContextManager, contextmanager
23+
from functools import partial
2224
from pathlib import Path
2325
from typing import Any
2426

@@ -38,6 +40,48 @@ def safe_memory_cleanup() -> None:
3840
torch.cuda.empty_cache()
3941

4042

43+
def get_fn_name(obj: Any) -> str:
44+
"""
45+
Get the name of a function or a partial function.
46+
47+
Parameters
48+
----------
49+
obj : Any
50+
The function or partial function to get the name of.
51+
52+
Returns
53+
-------
54+
str
55+
The name of the function.
56+
"""
57+
if isinstance(obj, partial):
58+
return get_fn_name(obj.func)
59+
return getattr(obj, "name", getattr(obj, "__name__", str(obj)))
60+
61+
62+
def verify_sha256(file_path: str | Path, expected_hash: str) -> bool:
    """
    Check whether a file's SHA256 digest matches an expected value.

    The file is streamed in 4 KiB chunks so arbitrarily large files can be
    verified without loading them fully into memory.

    Parameters
    ----------
    file_path : str | Path
        The path to the file whose contents should be hashed.
    expected_hash : str
        The expected hexadecimal SHA256 digest.

    Returns
    -------
    bool
        True if the computed digest equals ``expected_hash``, False otherwise.
    """
    digest = hashlib.sha256()
    with Path(file_path).open("rb") as stream:
        # Stream the file chunk-by-chunk to keep memory usage constant.
        while chunk := stream.read(4096):
            digest.update(chunk)
    return digest.hexdigest() == expected_hash
83+
84+
4185
def load_json_config(path: str | Path, json_name: str) -> dict:
4286
"""
4387
Load and parse a JSON configuration file.

0 commit comments

Comments
 (0)