diff --git a/src/pruna/engine/load.py b/src/pruna/engine/load.py index 6edff441..74b04b56 100644 --- a/src/pruna/engine/load.py +++ b/src/pruna/engine/load.py @@ -18,7 +18,12 @@ import sys from copy import deepcopy from enum import Enum -from functools import partial + +try: + from enum import member +except ImportError: + # Python 3.10 compat: partial prevents Enum from treating functions as methods + from functools import partial as member from pathlib import Path from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union @@ -31,7 +36,11 @@ from pruna import SmashConfig from pruna.engine.load_artifacts import load_artifacts -from pruna.engine.utils import load_json_config, move_to_device, set_to_best_available_device +from pruna.engine.utils import ( + load_json_config, + move_to_device, + set_to_best_available_device, +) from pruna.logging.logger import pruna_logger PICKLED_FILE_NAME = "optimized_model.pt" @@ -244,7 +253,9 @@ def resmash(model: Any, smash_config: SmashConfig) -> Any: return smash(model=model, smash_config=smash_config_subset) -def load_transformers_model(path: str | Path, smash_config: SmashConfig | None = None, **kwargs) -> Any: +def load_transformers_model( + path: str | Path, smash_config: SmashConfig | None = None, **kwargs +) -> Any: """ Load a transformers model or pipeline from the given model path. @@ -286,11 +297,17 @@ def load_transformers_model(path: str | Path, smash_config: SmashConfig | None = device_map = "auto" else: device = smash_config.device if smash_config.device != "cuda" else "cuda:0" - device_map = smash_config.device_map if smash_config.device == "accelerate" else device + device_map = ( + smash_config.device_map + if smash_config.device == "accelerate" + else device + ) return cls.from_pretrained(path, device_map=device_map, **kwargs) -def load_diffusers_model(path: str | Path, smash_config: SmashConfig | None = None, **kwargs) -> Any: +def load_diffusers_model( + path: str | Path, smash_config: SmashConfig | None = None, **kwargs +) -> Any: """ Load a diffusers model from the given model path. @@ -366,7 +383,9 @@ def load_pickled(path: str | Path, smash_config: SmashConfig, **kwargs) -> Any: The loaded pickled model. """ # torch load has a target device but no interface to reproduce an accelerate-distributed model, we first map to cpu - target_device = "cpu" if smash_config.device == "accelerate" else smash_config.device + target_device = ( + "cpu" if smash_config.device == "accelerate" else smash_config.device + ) model = torch.load( Path(path) / PICKLED_FILE_NAME, weights_only=False, @@ -406,11 +425,17 @@ def load_hqq(model_path: str | Path, smash_config: SmashConfig, **kwargs) -> Any else: saved_smash_config = SmashConfig() saved_smash_config.load_from_json(model_path) - compute_dtype = torch.float16 if saved_smash_config["hqq_compute_dtype"] == "torch.float16" else torch.bfloat16 + compute_dtype = ( + torch.float16 + if saved_smash_config["hqq_compute_dtype"] == "torch.float16" + else torch.bfloat16 + ) has_config = (model_path / "config.json").exists() is_janus_model = ( - has_config and load_json_config(model_path, "config.json")["architectures"][0] == "JanusForConditionalGeneration" + has_config + and load_json_config(model_path, "config.json")["architectures"][0] + == "JanusForConditionalGeneration" ) def load_quantized_model(quantized_path: str | Path) -> Any: @@ -418,15 +443,21 @@ def load_quantized_model(quantized_path: str | Path) -> Any: quantized_model = algorithm_packages["HQQModelForCausalLM"].from_quantized( str(quantized_path), device=smash_config.device, - **filter_load_kwargs(algorithm_packages["HQQModelForCausalLM"].from_quantized, kwargs), + **filter_load_kwargs( + algorithm_packages["HQQModelForCausalLM"].from_quantized, kwargs + ), ) except Exception: # Default to generic HQQ pipeline if it fails - pruna_logger.info("Could not load HQQ model using HQQModelForCausalLM, trying generic AutoHQQHFModel...") + pruna_logger.info( + "Could not load HQQ model using HQQModelForCausalLM, trying generic AutoHQQHFModel..." + ) quantized_model = algorithm_packages["AutoHQQHFModel"].from_quantized( str(quantized_path), device=smash_config.device, compute_dtype=compute_dtype, - **filter_load_kwargs(algorithm_packages["AutoHQQHFModel"].from_quantized, kwargs), + **filter_load_kwargs( + algorithm_packages["AutoHQQHFModel"].from_quantized, kwargs + ), ) return quantized_model @@ -434,7 +465,9 @@ def load_quantized_model(quantized_path: str | Path) -> Any: # load the pipeline with a fake model on meta device with (model_path / PIPELINE_INFO_FILE_NAME).open("r") as f: task = json.load(f)["task"] - pipe = pipeline(task=task, model=str(model_path), model_kwargs={"device_map": "meta"}) + pipe = pipeline( + task=task, model=str(model_path), model_kwargs={"device_map": "meta"} + ) # load the quantized model pipe.model = load_quantized_model(model_path) move_to_device(pipe, smash_config.device) @@ -450,7 +483,9 @@ def load_quantized_model(quantized_path: str | Path) -> Any: # Janus language model must be patched to match HQQ's causal LM assumption quantized_path = str(hqq_model_dir / "qmodel.pt") weights = torch.load(quantized_path, map_location="cpu", weights_only=True) - is_already_patched = all(k == "lm_head" or k.startswith("model.") for k in weights) + is_already_patched = all( + k == "lm_head" or k.startswith("model.") for k in weights + ) if not is_already_patched: # load the weight on cpu to rename attr -> model.attr, weights = {f"model.{k}": v for k, v in weights.items()} @@ -459,7 +494,9 @@ def load_quantized_model(quantized_path: str | Path) -> Any: torch.save(weights, quantized_path) # patch weights quantized_causal_lm = load_quantized_model(hqq_model_dir) - model.model.language_model = quantized_causal_lm.model # drop the lm_head and causal_lm wrapper + model.model.language_model = ( + quantized_causal_lm.model + ) # drop the lm_head and causal_lm wrapper # some weights of the language_model are not on the correct device, so we move it afterwards. move_to_device(model, smash_config.device) return model @@ -498,8 +535,12 @@ def load_hqq_diffusers(path: str | Path, smash_config: SmashConfig, **kwargs) -> ) hf_quantizer = HQQDiffusers() - auto_hqq_hf_diffusers_model = construct_base_class(hf_quantizer.import_algorithm_packages(), []) - quantized_load_kwargs = filter_load_kwargs(auto_hqq_hf_diffusers_model.from_quantized, kwargs) + auto_hqq_hf_diffusers_model = construct_base_class( + hf_quantizer.import_algorithm_packages(), [] + ) + quantized_load_kwargs = filter_load_kwargs( + auto_hqq_hf_diffusers_model.from_quantized, kwargs + ) path = Path(path) if "compute_dtype" not in kwargs and (path / "dtype_info.json").exists(): @@ -508,7 +549,9 @@ def load_hqq_diffusers(path: str | Path, smash_config: SmashConfig, **kwargs) -> kwargs.setdefault("torch_dtype", dtype) qmodel_path = path / "qmodel.pt" - if qmodel_path.exists(): # the whole model was quantized, not a pipeline, load it directly + if ( + qmodel_path.exists() + ): # the whole model was quantized, not a pipeline, load it directly model = auto_hqq_hf_diffusers_model.from_quantized(path, **kwargs) # force dtype if specified, from_quantized does not set it properly if "torch_dtype" in kwargs: @@ -520,7 +563,9 @@ def load_hqq_diffusers(path: str | Path, smash_config: SmashConfig, **kwargs) -> # that can be quantized and saved separately. # by convention, each component has been saved in a directory f"{attr_name}_quantized". quantized_components: dict[str, Any] = {} - for quantized_path in [qpath for qpath in path.iterdir() if qpath.name.endswith("_quantized")]: + for quantized_path in [ + qpath for qpath in path.iterdir() if qpath.name.endswith("_quantized") + ]: attr_name = quantized_path.name.replace("_quantized", "") # legacy behavior: backbone_quantized -> target the transformer or unet attribute @@ -587,11 +632,11 @@ class LOAD_FUNCTIONS(Enum): # noqa: N801 """ - transformers = partial(load_transformers_model) - diffusers = partial(load_diffusers_model) - pickled = partial(load_pickled) - hqq = partial(load_hqq) - hqq_diffusers = partial(load_hqq_diffusers) + transformers = member(load_transformers_model) + diffusers = member(load_diffusers_model) + pickled = member(load_pickled) + hqq = member(load_hqq) + hqq_diffusers = member(load_hqq_diffusers) def __call__(self, *args, **kwargs) -> Any: """ @@ -636,7 +681,10 @@ def filter_load_kwargs(func: Callable, kwargs: dict) -> dict: signature = inspect.signature(func) # Check if function accepts arbitrary kwargs - has_kwargs = any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()) + has_kwargs = any( + param.kind == inspect.Parameter.VAR_KEYWORD + for param in signature.parameters.values() + ) if has_kwargs: return kwargs @@ -648,6 +696,8 @@ def filter_load_kwargs(func: Callable, kwargs: dict) -> dict: # Log the discarded kwargs if invalid_kwargs: - pruna_logger.info(f"Discarded unused loading kwargs: {list(invalid_kwargs.keys())}") + pruna_logger.info( + f"Discarded unused loading kwargs: {list(invalid_kwargs.keys())}" + ) return valid_kwargs diff --git a/src/pruna/engine/load_artifacts.py b/src/pruna/engine/load_artifacts.py index 607fe641..ec454069 100644 --- a/src/pruna/engine/load_artifacts.py +++ b/src/pruna/engine/load_artifacts.py @@ -15,7 +15,12 @@ import json from enum import Enum -from functools import partial + +try: + from enum import member +except ImportError: + # Python 3.10 compat: partial prevents Enum from treating functions as methods + from functools import partial as member from pathlib import Path from typing import Any @@ -181,8 +186,8 @@ class LOAD_ARTIFACTS_FUNCTIONS(Enum): # noqa: N801 # Torch artifacts loaded into the current runtime """ - torch_artifacts = partial(load_torch_artifacts) - moe_kernel_tuner_artifacts = partial(load_moe_kernel_tuner_artifacts) + torch_artifacts = member(load_torch_artifacts) + moe_kernel_tuner_artifacts = member(load_moe_kernel_tuner_artifacts) def __call__(self, *args, **kwargs) -> None: """Call the underlying load function.""" diff --git a/src/pruna/engine/save.py b/src/pruna/engine/save.py index 50f9ca78..27101b31 100644 --- a/src/pruna/engine/save.py +++ b/src/pruna/engine/save.py @@ -18,7 +18,12 @@ import shutil import tempfile from enum import Enum -from functools import partial + +try: + from enum import member +except ImportError: + # Python 3.10 compat: partial prevents Enum from treating functions as methods + from functools import partial as member from pathlib import Path from typing import TYPE_CHECKING, Any, List, cast @@ -513,11 +518,11 @@ class SAVE_FUNCTIONS(Enum): # noqa: N801 # Model saved to disk in pickled format """ - pickled = partial(save_pickled) - hqq = partial(save_model_hqq) - hqq_diffusers = partial(save_model_hqq_diffusers) - save_before_apply = partial(save_before_apply) - reapply = partial(reapply) + pickled = member(save_pickled) + hqq = member(save_model_hqq) + hqq_diffusers = member(save_model_hqq_diffusers) + save_before_apply = member(save_before_apply) + reapply = member(reapply) def __call__(self, *args, **kwargs) -> None: """ diff --git a/src/pruna/engine/save_artifacts.py b/src/pruna/engine/save_artifacts.py index 5acf58a5..c1090085 100644 --- a/src/pruna/engine/save_artifacts.py +++ b/src/pruna/engine/save_artifacts.py @@ -13,7 +13,12 @@ # limitations under the License. import shutil from enum import Enum -from functools import partial + +try: + from enum import member +except ImportError: + # Python 3.10 compat: partial prevents Enum from treating functions as methods + from functools import partial as member from pathlib import Path from typing import Any @@ -147,8 +152,8 @@ class SAVE_ARTIFACTS_FUNCTIONS(Enum): # noqa: N801 # Torch artifacts saved alongside the main model """ - torch_artifacts = partial(save_torch_artifacts) - moe_kernel_tuner_artifacts = partial(save_moe_kernel_tuner_artifacts) + torch_artifacts = member(save_torch_artifacts) + moe_kernel_tuner_artifacts = member(save_moe_kernel_tuner_artifacts) def __call__(self, *args, **kwargs) -> None: """ diff --git a/src/pruna/evaluation/metrics/metric_torch.py b/src/pruna/evaluation/metrics/metric_torch.py index cd96da78..4d329d86 100644 --- a/src/pruna/evaluation/metrics/metric_torch.py +++ b/src/pruna/evaluation/metrics/metric_torch.py @@ -150,44 +150,48 @@ def ssim_update( # Available metrics class TorchMetrics(Enum): """ - Enum for available torchmetrics. + Enumeration of torchmetrics metrics for evaluation. - The enum contains triplets of the metric class, the update function and the call type. + This enum provides a tuple per member (metric_factory, update_fn, call_type): + metric_factory builds the metric (typically a torchmetrics class, or + functools.partial when some constructor arguments are fixed); update_fn is + an optional custom update handler; call_type describes how inputs are paired + for the metric. Parameters ---------- - value : Callable - The function or class constructor for the metric. - names : List[str] - The available metric names. + value : tuple + Tuple holding metric_factory, update_fn, and call_type as described above. + names : str + The name of the enum member. module : str - The module in which the metric is defined. + The module where the enum is defined. qualname : str - Qualified name of the metric. - type : Type - The type of the enum value. + The qualified name of the enum. + type : type + The type of the enum. start : int - The starting value for the enum. + The start index for auto-numbering enum values. boundary : enum.FlagBoundary or None Boundary handling mode used by the Enum functional API for Flag and IntFlag enums. """ - fid = (partial(FrechetInceptionDistance), fid_update, "gt_y") - accuracy = (partial(Accuracy), None, "y_gt") - perplexity = (partial(Perplexity), None, "y_gt") - clip_score = (partial(CLIPScore), None, "y_x") - precision = (partial(Precision), None, "y_gt") - recall = (partial(Recall), None, "y_gt") + fid = (FrechetInceptionDistance, fid_update, "gt_y") + accuracy = (Accuracy, None, "y_gt") + perplexity = (Perplexity, None, "y_gt") + clip_score = (CLIPScore, None, "y_x") + precision = (Precision, None, "y_gt") + recall = (Recall, None, "y_gt") psnr = (partial(PeakSignalNoiseRatio, data_range=255.0), None, "pairwise_y_gt") - ssim = (partial(StructuralSimilarityIndexMeasure), ssim_update, "pairwise_y_gt") - msssim = (partial(MultiScaleStructuralSimilarityIndexMeasure), ssim_update, "pairwise_y_gt") - lpips = (partial(LearnedPerceptualImagePatchSimilarity), lpips_update, "pairwise_y_gt") - arniqa = (partial(ARNIQA), arniqa_update, "y") - clipiqa = (partial(CLIPImageQualityAssessment), None, "y") + ssim = (StructuralSimilarityIndexMeasure, ssim_update, "pairwise_y_gt") + msssim = (MultiScaleStructuralSimilarityIndexMeasure, ssim_update, "pairwise_y_gt") + lpips = (LearnedPerceptualImagePatchSimilarity, lpips_update, "pairwise_y_gt") + arniqa = (ARNIQA, arniqa_update, "y") + clipiqa = (CLIPImageQualityAssessment, None, "y") def __init__(self, *args, **kwargs) -> None: - self.tm = self.value[0] + self.tm: Callable[..., Metric] = self.value[0] self.update_fn = self.value[1] or default_update self.call_type = self.value[2]