Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 75 additions & 25 deletions src/pruna/engine/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
import sys
from copy import deepcopy
from enum import Enum
from functools import partial

try:
from enum import member
except ImportError:
# Python 3.10 compat: partial prevents Enum from treating functions as methods
from functools import partial as member
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union

Expand All @@ -31,7 +36,11 @@

from pruna import SmashConfig
from pruna.engine.load_artifacts import load_artifacts
from pruna.engine.utils import load_json_config, move_to_device, set_to_best_available_device
from pruna.engine.utils import (
load_json_config,
move_to_device,
set_to_best_available_device,
)
from pruna.logging.logger import pruna_logger

PICKLED_FILE_NAME = "optimized_model.pt"
Expand Down Expand Up @@ -244,7 +253,9 @@ def resmash(model: Any, smash_config: SmashConfig) -> Any:
return smash(model=model, smash_config=smash_config_subset)


def load_transformers_model(path: str | Path, smash_config: SmashConfig | None = None, **kwargs) -> Any:
def load_transformers_model(
path: str | Path, smash_config: SmashConfig | None = None, **kwargs
) -> Any:
"""
Load a transformers model or pipeline from the given model path.

Expand Down Expand Up @@ -286,11 +297,17 @@ def load_transformers_model(path: str | Path, smash_config: SmashConfig | None =
device_map = "auto"
else:
device = smash_config.device if smash_config.device != "cuda" else "cuda:0"
device_map = smash_config.device_map if smash_config.device == "accelerate" else device
device_map = (
smash_config.device_map
if smash_config.device == "accelerate"
else device
)
return cls.from_pretrained(path, device_map=device_map, **kwargs)


def load_diffusers_model(path: str | Path, smash_config: SmashConfig | None = None, **kwargs) -> Any:
def load_diffusers_model(
path: str | Path, smash_config: SmashConfig | None = None, **kwargs
) -> Any:
"""
Load a diffusers model from the given model path.

Expand Down Expand Up @@ -366,7 +383,9 @@ def load_pickled(path: str | Path, smash_config: SmashConfig, **kwargs) -> Any:
The loaded pickled model.
"""
# torch load has a target device but no interface to reproduce an accelerate-distributed model, we first map to cpu
target_device = "cpu" if smash_config.device == "accelerate" else smash_config.device
target_device = (
"cpu" if smash_config.device == "accelerate" else smash_config.device
)
model = torch.load(
Path(path) / PICKLED_FILE_NAME,
weights_only=False,
Expand Down Expand Up @@ -406,35 +425,49 @@ def load_hqq(model_path: str | Path, smash_config: SmashConfig, **kwargs) -> Any
else:
saved_smash_config = SmashConfig()
saved_smash_config.load_from_json(model_path)
compute_dtype = torch.float16 if saved_smash_config["hqq_compute_dtype"] == "torch.float16" else torch.bfloat16
compute_dtype = (
torch.float16
if saved_smash_config["hqq_compute_dtype"] == "torch.float16"
else torch.bfloat16
)

has_config = (model_path / "config.json").exists()
is_janus_model = (
has_config and load_json_config(model_path, "config.json")["architectures"][0] == "JanusForConditionalGeneration"
has_config
and load_json_config(model_path, "config.json")["architectures"][0]
== "JanusForConditionalGeneration"
)

def load_quantized_model(quantized_path: str | Path) -> Any:
try: # Try to use pipeline for HF specific HQQ quantization
quantized_model = algorithm_packages["HQQModelForCausalLM"].from_quantized(
str(quantized_path),
device=smash_config.device,
**filter_load_kwargs(algorithm_packages["HQQModelForCausalLM"].from_quantized, kwargs),
**filter_load_kwargs(
algorithm_packages["HQQModelForCausalLM"].from_quantized, kwargs
),
)
except Exception: # Default to generic HQQ pipeline if it fails
pruna_logger.info("Could not load HQQ model using HQQModelForCausalLM, trying generic AutoHQQHFModel...")
pruna_logger.info(
"Could not load HQQ model using HQQModelForCausalLM, trying generic AutoHQQHFModel..."
)
quantized_model = algorithm_packages["AutoHQQHFModel"].from_quantized(
str(quantized_path),
device=smash_config.device,
compute_dtype=compute_dtype,
**filter_load_kwargs(algorithm_packages["AutoHQQHFModel"].from_quantized, kwargs),
**filter_load_kwargs(
algorithm_packages["AutoHQQHFModel"].from_quantized, kwargs
),
)
return quantized_model

if (model_path / PIPELINE_INFO_FILE_NAME).exists(): # pipeline
# load the pipeline with a fake model on meta device
with (model_path / PIPELINE_INFO_FILE_NAME).open("r") as f:
task = json.load(f)["task"]
pipe = pipeline(task=task, model=str(model_path), model_kwargs={"device_map": "meta"})
pipe = pipeline(
task=task, model=str(model_path), model_kwargs={"device_map": "meta"}
)
# load the quantized model
pipe.model = load_quantized_model(model_path)
move_to_device(pipe, smash_config.device)
Expand All @@ -450,7 +483,9 @@ def load_quantized_model(quantized_path: str | Path) -> Any:
# Janus language model must be patched to match HQQ's causal LM assumption
quantized_path = str(hqq_model_dir / "qmodel.pt")
weights = torch.load(quantized_path, map_location="cpu", weights_only=True)
is_already_patched = all(k == "lm_head" or k.startswith("model.") for k in weights)
is_already_patched = all(
k == "lm_head" or k.startswith("model.") for k in weights
)
if not is_already_patched:
# load the weight on cpu to rename attr -> model.attr,
weights = {f"model.{k}": v for k, v in weights.items()}
Expand All @@ -459,7 +494,9 @@ def load_quantized_model(quantized_path: str | Path) -> Any:
torch.save(weights, quantized_path) # patch weights

quantized_causal_lm = load_quantized_model(hqq_model_dir)
model.model.language_model = quantized_causal_lm.model # drop the lm_head and causal_lm wrapper
model.model.language_model = (
quantized_causal_lm.model
) # drop the lm_head and causal_lm wrapper
# some weights of the language_model are not on the correct device, so we move it afterwards.
move_to_device(model, smash_config.device)
return model
Expand Down Expand Up @@ -498,8 +535,12 @@ def load_hqq_diffusers(path: str | Path, smash_config: SmashConfig, **kwargs) ->
)

hf_quantizer = HQQDiffusers()
auto_hqq_hf_diffusers_model = construct_base_class(hf_quantizer.import_algorithm_packages(), [])
quantized_load_kwargs = filter_load_kwargs(auto_hqq_hf_diffusers_model.from_quantized, kwargs)
auto_hqq_hf_diffusers_model = construct_base_class(
hf_quantizer.import_algorithm_packages(), []
)
quantized_load_kwargs = filter_load_kwargs(
auto_hqq_hf_diffusers_model.from_quantized, kwargs
)

path = Path(path)
if "compute_dtype" not in kwargs and (path / "dtype_info.json").exists():
Expand All @@ -508,7 +549,9 @@ def load_hqq_diffusers(path: str | Path, smash_config: SmashConfig, **kwargs) ->
kwargs.setdefault("torch_dtype", dtype)

qmodel_path = path / "qmodel.pt"
if qmodel_path.exists(): # the whole model was quantized, not a pipeline, load it directly
if (
qmodel_path.exists()
): # the whole model was quantized, not a pipeline, load it directly
model = auto_hqq_hf_diffusers_model.from_quantized(path, **kwargs)
# force dtype if specified, from_quantized does not set it properly
if "torch_dtype" in kwargs:
Expand All @@ -520,7 +563,9 @@ def load_hqq_diffusers(path: str | Path, smash_config: SmashConfig, **kwargs) ->
# that can be quantized and saved separately.
# by convention, each component has been saved in a directory f"{attr_name}_quantized".
quantized_components: dict[str, Any] = {}
for quantized_path in [qpath for qpath in path.iterdir() if qpath.name.endswith("_quantized")]:
for quantized_path in [
qpath for qpath in path.iterdir() if qpath.name.endswith("_quantized")
]:
attr_name = quantized_path.name.replace("_quantized", "")

# legacy behavior: backbone_quantized -> target the transformer or unet attribute
Expand Down Expand Up @@ -587,11 +632,11 @@ class LOAD_FUNCTIONS(Enum): # noqa: N801
<Loaded transformer model>
"""

transformers = partial(load_transformers_model)
diffusers = partial(load_diffusers_model)
pickled = partial(load_pickled)
hqq = partial(load_hqq)
hqq_diffusers = partial(load_hqq_diffusers)
transformers = member(load_transformers_model)
diffusers = member(load_diffusers_model)
pickled = member(load_pickled)
hqq = member(load_hqq)
hqq_diffusers = member(load_hqq_diffusers)

def __call__(self, *args, **kwargs) -> Any:
"""
Expand Down Expand Up @@ -636,7 +681,10 @@ def filter_load_kwargs(func: Callable, kwargs: dict) -> dict:
signature = inspect.signature(func)

# Check if function accepts arbitrary kwargs
has_kwargs = any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values())
has_kwargs = any(
param.kind == inspect.Parameter.VAR_KEYWORD
for param in signature.parameters.values()
)

if has_kwargs:
return kwargs
Expand All @@ -648,6 +696,8 @@ def filter_load_kwargs(func: Callable, kwargs: dict) -> dict:

# Log the discarded kwargs
if invalid_kwargs:
pruna_logger.info(f"Discarded unused loading kwargs: {list(invalid_kwargs.keys())}")
pruna_logger.info(
f"Discarded unused loading kwargs: {list(invalid_kwargs.keys())}"
)

return valid_kwargs
11 changes: 8 additions & 3 deletions src/pruna/engine/load_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@

import json
from enum import Enum
from functools import partial

try:
from enum import member
except ImportError:
# Python 3.10 compat: partial prevents Enum from treating functions as methods
from functools import partial as member
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -181,8 +186,8 @@ class LOAD_ARTIFACTS_FUNCTIONS(Enum): # noqa: N801
# Torch artifacts loaded into the current runtime
"""

torch_artifacts = partial(load_torch_artifacts)
moe_kernel_tuner_artifacts = partial(load_moe_kernel_tuner_artifacts)
torch_artifacts = member(load_torch_artifacts)
moe_kernel_tuner_artifacts = member(load_moe_kernel_tuner_artifacts)

def __call__(self, *args, **kwargs) -> None:
"""Call the underlying load function."""
Expand Down
17 changes: 11 additions & 6 deletions src/pruna/engine/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
import shutil
import tempfile
from enum import Enum
from functools import partial

try:
from enum import member
except ImportError:
# Python 3.10 compat: partial prevents Enum from treating functions as methods
from functools import partial as member
from pathlib import Path
from typing import TYPE_CHECKING, Any, List, cast

Expand Down Expand Up @@ -513,11 +518,11 @@ class SAVE_FUNCTIONS(Enum): # noqa: N801
# Model saved to disk in pickled format
"""

pickled = partial(save_pickled)
hqq = partial(save_model_hqq)
hqq_diffusers = partial(save_model_hqq_diffusers)
save_before_apply = partial(save_before_apply)
reapply = partial(reapply)
pickled = member(save_pickled)
hqq = member(save_model_hqq)
hqq_diffusers = member(save_model_hqq_diffusers)
save_before_apply = member(save_before_apply)
reapply = member(reapply)

def __call__(self, *args, **kwargs) -> None:
"""
Expand Down
11 changes: 8 additions & 3 deletions src/pruna/engine/save_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
# limitations under the License.
import shutil
from enum import Enum
from functools import partial

try:
from enum import member
except ImportError:
# Python 3.10 compat: partial prevents Enum from treating functions as methods
from functools import partial as member
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -147,8 +152,8 @@ class SAVE_ARTIFACTS_FUNCTIONS(Enum): # noqa: N801
# Torch artifacts saved alongside the main model
"""

torch_artifacts = partial(save_torch_artifacts)
moe_kernel_tuner_artifacts = partial(save_moe_kernel_tuner_artifacts)
torch_artifacts = member(save_torch_artifacts)
moe_kernel_tuner_artifacts = member(save_moe_kernel_tuner_artifacts)

def __call__(self, *args, **kwargs) -> None:
"""
Expand Down
50 changes: 27 additions & 23 deletions src/pruna/evaluation/metrics/metric_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,44 +150,48 @@ def ssim_update(
# Available metrics
class TorchMetrics(Enum):
"""
Enum for available torchmetrics.
Enumeration of torchmetrics metrics for evaluation.

The enum contains triplets of the metric class, the update function and the call type.
This enum provides a tuple per member (metric_factory, update_fn, call_type):
metric_factory builds the metric (typically a torchmetrics class, or
functools.partial when some constructor arguments are fixed); update_fn is
an optional custom update handler; call_type describes how inputs are paired
for the metric.

Parameters
----------
value : Callable
The function or class constructor for the metric.
names : List[str]
The available metric names.
value : tuple
Tuple holding metric_factory, update_fn, and call_type as described above.
names : str
The name of the enum member.
module : str
The module in which the metric is defined.
The module where the enum is defined.
qualname : str
Qualified name of the metric.
type : Type
The type of the enum value.
The qualified name of the enum.
type : type
The type of the enum.
start : int
The starting value for the enum.
The start index for auto-numbering enum values.
boundary : enum.FlagBoundary or None
Boundary handling mode used by the Enum functional API for Flag and
IntFlag enums.
"""

fid = (partial(FrechetInceptionDistance), fid_update, "gt_y")
accuracy = (partial(Accuracy), None, "y_gt")
perplexity = (partial(Perplexity), None, "y_gt")
clip_score = (partial(CLIPScore), None, "y_x")
precision = (partial(Precision), None, "y_gt")
recall = (partial(Recall), None, "y_gt")
fid = (FrechetInceptionDistance, fid_update, "gt_y")
accuracy = (Accuracy, None, "y_gt")
perplexity = (Perplexity, None, "y_gt")
clip_score = (CLIPScore, None, "y_x")
precision = (Precision, None, "y_gt")
recall = (Recall, None, "y_gt")
psnr = (partial(PeakSignalNoiseRatio, data_range=255.0), None, "pairwise_y_gt")
ssim = (partial(StructuralSimilarityIndexMeasure), ssim_update, "pairwise_y_gt")
msssim = (partial(MultiScaleStructuralSimilarityIndexMeasure), ssim_update, "pairwise_y_gt")
lpips = (partial(LearnedPerceptualImagePatchSimilarity), lpips_update, "pairwise_y_gt")
arniqa = (partial(ARNIQA), arniqa_update, "y")
clipiqa = (partial(CLIPImageQualityAssessment), None, "y")
ssim = (StructuralSimilarityIndexMeasure, ssim_update, "pairwise_y_gt")
msssim = (MultiScaleStructuralSimilarityIndexMeasure, ssim_update, "pairwise_y_gt")
lpips = (LearnedPerceptualImagePatchSimilarity, lpips_update, "pairwise_y_gt")
arniqa = (ARNIQA, arniqa_update, "y")
clipiqa = (CLIPImageQualityAssessment, None, "y")

def __init__(self, *args, **kwargs) -> None:
self.tm = self.value[0]
self.tm: Callable[..., Metric] = self.value[0]
self.update_fn = self.value[1] or default_update
self.call_type = self.value[2]

Expand Down
Loading