diff --git a/docs/user_manual/configure.rst b/docs/user_manual/configure.rst index 4bfb8a67..f1cbf9cd 100644 --- a/docs/user_manual/configure.rst +++ b/docs/user_manual/configure.rst @@ -253,7 +253,7 @@ Underneath you can find the list of all the available datasets. - ``text: str`` * - Image Generation - `LAION256 `_, `OpenImage `_, `COCO `_, `DrawBench `_, `PartiPrompts `_, `GenAIBench `_ - - ``image_generation_collate``, ``prompt_collate`` + - ``image_generation_collate``, ``prompt_with_auxiliaries_collate`` - ``text: str``, ``image: Optional[PIL.Image.Image]`` * - Image Classification - `ImageNet `_, `MNIST `_, `CIFAR10 `_ diff --git a/pyproject.toml b/pyproject.toml index 5b1eb704..86b69558 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -165,6 +165,10 @@ vllm = [ "vllm>=0.16.0", "ray", ] +evaluation = [ + "outlines>1.2.0,<2.0.0", + "litellm>=1.0.0", +] stable-fast = [ "xformers>=0.0.30", "stable-fast-pruna==1.0.8", @@ -217,6 +221,7 @@ dev = [ "types-PyYAML", "logbar", "pytest-xdist>=3.8.0", + "pruna[evaluation]", ] cpu = [] lmharness = [ diff --git a/src/pruna/data/__init__.py b/src/pruna/data/__init__.py index fd14a496..1a733662 100644 --- a/src/pruna/data/__init__.py +++ b/src/pruna/data/__init__.py @@ -103,13 +103,13 @@ "image_classification_collate", {"img_size": 224}, ), - "DrawBench": (setup_drawbench_dataset, "prompt_collate", {}), + "DrawBench": (setup_drawbench_dataset, "prompt_with_auxiliaries_collate", {}), "PartiPrompts": ( setup_parti_prompts_dataset, "prompt_with_auxiliaries_collate", {}, ), - "GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}), + "GenAIBench": (setup_genai_bench_dataset, "prompt_with_auxiliaries_collate", {}), "GenEval": (setup_geneval_dataset, "prompt_with_auxiliaries_collate", {}), "HPS": (setup_hps_dataset, "prompt_with_auxiliaries_collate", {}), "ImgEdit": (setup_imgedit_dataset, "prompt_with_auxiliaries_collate", {}), diff --git a/src/pruna/data/collate.py b/src/pruna/data/collate.py index 2803ee50..8a5df457 
100644 --- a/src/pruna/data/collate.py +++ b/src/pruna/data/collate.py @@ -321,6 +321,5 @@ def question_answering_collate( "image_classification_collate": image_classification_collate, "text_generation_collate": text_generation_collate, "question_answering_collate": question_answering_collate, - "prompt_collate": prompt_collate, "prompt_with_auxiliaries_collate": prompt_with_auxiliaries_collate, } diff --git a/src/pruna/data/datasets/prompt.py b/src/pruna/data/datasets/prompt.py index 7764d23b..c656aa61 100644 --- a/src/pruna/data/datasets/prompt.py +++ b/src/pruna/data/datasets/prompt.py @@ -123,6 +123,14 @@ DPGCategory = Literal["entity", "attribute", "relation", "global", "other"] +def _warn_ignored_benchmark_seed(seed: int | None, *, dataset: str) -> None: + if seed is not None: + pruna_logger.warning( + "%s: `seed` is ignored for this test-only benchmark; sampling does not shuffle the test split.", + dataset, + ) + + def _to_oneig_record(row: dict, questions_by_key: dict[str, dict]) -> dict: """Convert OneIG row to unified record format.""" row_category = row.get("category", "") @@ -159,7 +167,7 @@ def setup_drawbench_dataset() -> Tuple[Dataset, Dataset, Dataset]: def setup_parti_prompts_dataset( - seed: int, + seed: int | None = None, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None, @@ -172,8 +180,8 @@ def setup_parti_prompts_dataset( Parameters ---------- - seed : int - The seed to use. + seed : int | None, optional + Ignored; test order is deterministic. If not None, a warning is logged. fraction : float The fraction of the dataset to use. train_sample_size : int | None @@ -188,6 +196,7 @@ def setup_parti_prompts_dataset( Tuple[Dataset, Dataset, Dataset] The Parti Prompts dataset (dummy train, dummy val, test). 
""" + _warn_ignored_benchmark_seed(seed, dataset="PartiPrompts") ds = load_dataset("nateraw/parti-prompts")["train"] # type: ignore[index] if category is not None: @@ -226,7 +235,7 @@ def _generate_geneval_question(entry: dict) -> list[str]: def setup_geneval_dataset( - seed: int, + seed: int | None = None, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None, @@ -239,8 +248,8 @@ def setup_geneval_dataset( Parameters ---------- - seed : int - The seed to use. + seed : int | None, optional + Ignored; test order is deterministic. If not None, a warning is logged. fraction : float The fraction of the dataset to use. train_sample_size : int | None @@ -255,6 +264,7 @@ def setup_geneval_dataset( Tuple[Dataset, Dataset, Dataset] The GenEval dataset (dummy train, dummy val, test). """ + _warn_ignored_benchmark_seed(seed, dataset="GenEval") import json import requests @@ -286,7 +296,7 @@ def setup_geneval_dataset( def setup_hps_dataset( - seed: int, + seed: int | None = None, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None, @@ -299,8 +309,8 @@ def setup_hps_dataset( Parameters ---------- - seed : int - The seed to use. + seed : int | None, optional + Ignored; test order is deterministic. If not None, a warning is logged. fraction : float The fraction of the dataset to use. train_sample_size : int | None @@ -315,6 +325,7 @@ def setup_hps_dataset( Tuple[Dataset, Dataset, Dataset] The HPD dataset (dummy train, dummy val, test). """ + _warn_ignored_benchmark_seed(seed, dataset="HPS") import json from huggingface_hub import hf_hub_download @@ -338,7 +349,7 @@ def setup_hps_dataset( def setup_long_text_bench_dataset( - seed: int, + seed: int | None = None, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None, @@ -350,8 +361,8 @@ def setup_long_text_bench_dataset( Parameters ---------- - seed : int - The seed to use. 
+ seed : int | None, optional + Ignored; test order is deterministic. If not None, a warning is logged. fraction : float The fraction of the dataset to use. train_sample_size : int | None @@ -364,6 +375,7 @@ def setup_long_text_bench_dataset( Tuple[Dataset, Dataset, Dataset] The Long Text Bench dataset (dummy train, dummy val, test). """ + _warn_ignored_benchmark_seed(seed, dataset="LongTextBench") ds = load_dataset("X-Omni/LongText-Bench")["train"] # type: ignore[index] ds = ds.rename_column("text", "text_content") ds = ds.rename_column("prompt", "text") @@ -390,7 +402,7 @@ def setup_genai_bench_dataset() -> Tuple[Dataset, Dataset, Dataset]: def setup_imgedit_dataset( - seed: int, + seed: int | None = None, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None, @@ -403,8 +415,8 @@ def setup_imgedit_dataset( Parameters ---------- - seed : int - The seed to use. + seed : int | None, optional + Ignored; test order is deterministic. If not None, a warning is logged. fraction : float The fraction of the dataset to use. train_sample_size : int | None @@ -420,6 +432,7 @@ def setup_imgedit_dataset( Tuple[Dataset, Dataset, Dataset] The ImgEdit dataset (dummy train, dummy val, test). """ + _warn_ignored_benchmark_seed(seed, dataset="ImgEdit") import json import requests @@ -493,7 +506,7 @@ def _fetch_oneig_alignment() -> dict[str, dict]: def setup_oneig_dataset( - seed: int, + seed: int | None = None, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None, @@ -506,8 +519,8 @@ def setup_oneig_dataset( Parameters ---------- - seed : int - The seed to use. + seed : int | None, optional + Ignored; test order is deterministic. If not None, a warning is logged. fraction : float The fraction of the dataset to use. train_sample_size : int | None @@ -523,6 +536,7 @@ def setup_oneig_dataset( Tuple[Dataset, Dataset, Dataset] The OneIG dataset (dummy train, dummy val, test). 
""" + _warn_ignored_benchmark_seed(seed, dataset="OneIG") questions_by_key = _fetch_oneig_alignment() ds_raw = load_dataset("OneIG-Bench/OneIG-Bench", "OneIG-Bench")["train"] # type: ignore[index] @@ -545,7 +559,7 @@ def setup_oneig_dataset( def setup_gedit_dataset( - seed: int, + seed: int | None = None, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None, @@ -558,8 +572,8 @@ def setup_gedit_dataset( Parameters ---------- - seed : int - The seed to use. + seed : int | None, optional + Ignored; test order is deterministic. If not None, a warning is logged. fraction : float The fraction of the dataset to use. train_sample_size : int | None @@ -576,6 +590,7 @@ def setup_gedit_dataset( Tuple[Dataset, Dataset, Dataset] The GEditBench dataset (dummy train, dummy val, test). """ + _warn_ignored_benchmark_seed(seed, dataset="GEditBench") task_type_map = { "subject_add": "subject-add", "subject_remove": "subject-remove", @@ -613,7 +628,7 @@ def setup_gedit_dataset( def setup_dpg_dataset( - seed: int, + seed: int | None = None, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None, @@ -626,8 +641,8 @@ def setup_dpg_dataset( Parameters ---------- - seed : int - The seed to use. + seed : int | None, optional + Ignored; test order is deterministic. If not None, a warning is logged. fraction : float The fraction of the dataset to use. train_sample_size : int | None @@ -642,6 +657,7 @@ def setup_dpg_dataset( Tuple[Dataset, Dataset, Dataset] The DPG dataset (dummy train, dummy val, test). 
""" + _warn_ignored_benchmark_seed(seed, dataset="DPG") import csv import io from collections import defaultdict diff --git a/src/pruna/data/pruna_datamodule.py b/src/pruna/data/pruna_datamodule.py index 6d1eaadd..03003127 100644 --- a/src/pruna/data/pruna_datamodule.py +++ b/src/pruna/data/pruna_datamodule.py @@ -135,7 +135,7 @@ def from_string( tokenizer: AutoTokenizer | None = None, collate_fn_args: dict = dict(), dataloader_args: dict = dict(), - seed: int = 42, + seed: int | None = None, category: str | list[str] | None = None, fraction: float = 1.0, train_sample_size: int | None = None, @@ -154,8 +154,10 @@ def from_string( Any additional arguments for the collate function. dataloader_args : dict Any additional arguments for the dataloader. - seed : int - The seed to use. + seed : int | None, optional + Passed to dataset setup when the loader uses shuffled sampling. + If None, setups that require a seed default to 42; test-only benchmarks + omit seed so ordering stays deterministic without warnings. category : str | list[str] | None The category of the dataset. fraction : float @@ -177,7 +179,12 @@ def from_string( collate_fn_args = default_collate_fn_args if "seed" in inspect.signature(setup_fn).parameters: - setup_fn = partial(setup_fn, seed=seed) + seed_param = inspect.signature(setup_fn).parameters["seed"] + has_default = seed_param.default is not inspect.Parameter.empty + if seed is not None: + setup_fn = partial(setup_fn, seed=seed) + elif not has_default: + setup_fn = partial(setup_fn, seed=42) if "category" in inspect.signature(setup_fn).parameters: setup_fn = partial(setup_fn, category=category) diff --git a/src/pruna/data/utils.py b/src/pruna/data/utils.py index 2096f9e6..b021130d 100644 --- a/src/pruna/data/utils.py +++ b/src/pruna/data/utils.py @@ -34,18 +34,34 @@ from pruna.logging.logger import pruna_logger -class TokenizerMissingError(Exception): - """ - Custom exception raised when a tokenizer is required but not provided. 
+def _extract_literal_values(annotation: Any) -> list[str] | None: + if _is_none_annotation(annotation): + return None + literal_values = _literal_string_values(annotation) + if literal_values is not None: + return literal_values + for arg in _annotation_args(annotation): + found = _extract_literal_values(arg) + if found is not None: + return found + return None - Parameters - ---------- - message : str, optional - The message to display when the exception is raised. - """ - def __init__(self, message: str = "Tokenizer is missing. Please provide a valid tokenizer.") -> None: - super().__init__(message) +def _is_none_annotation(annotation: Any) -> bool: + return annotation is None or annotation is type(None) + + +def _literal_string_values(annotation: Any) -> list[str] | None: + if get_origin(annotation) is not Literal: + return None + args = get_args(annotation) + if args and all(isinstance(arg, str) for arg in args): + return list(args) + return None + + +def _annotation_args(annotation: Any) -> tuple[Any, ...]: + return get_args(annotation) or () def get_literal_values_from_param(func: Callable[..., Any], param_name: str) -> list[str] | None: @@ -78,18 +94,21 @@ def get_literal_values_from_param(func: Callable[..., Any], param_name: str) -> except Exception: return None - def extract(ann: Any) -> list[str] | None: - if ann is None or ann is type(None): - return None - if get_origin(ann) is Literal: - args = get_args(ann) - return list(args) if args and all(isinstance(a, str) for a in args) else None - for arg in get_args(ann) or (): - if (r := extract(arg)) is not None: - return r - return None + return _extract_literal_values(ann) + + +class TokenizerMissingError(Exception): + """ + Custom exception raised when a tokenizer is required but not provided. - return extract(ann) + Parameters + ---------- + message : str, optional + The message to display when the exception is raised. + """ + + def __init__(self, message: str = "Tokenizer is missing. 
Please provide a valid tokenizer.") -> None: + super().__init__(message) def split_train_into_train_val_test(dataset: Dataset | IterableDataset, seed: int) -> Tuple[Dataset, Dataset, Dataset]: diff --git a/src/pruna/evaluation/benchmarks.py b/src/pruna/evaluation/benchmarks.py index e52ae463..f4ad5f12 100644 --- a/src/pruna/evaluation/benchmarks.py +++ b/src/pruna/evaluation/benchmarks.py @@ -174,7 +174,7 @@ def list(cls, task_type: str | None = None) -> list[str]: "Covers basic skills (scene, attributes, spatial relationships) to advanced reasoning " "(counting, comparison, logic/negation) with over 24k human ratings." ), - metrics=[], # Paper uses VQAScore only; not in Pruna + metrics=["vqa"], task_type="text_to_image", reference="https://arxiv.org/abs/2406.13743", ), @@ -226,7 +226,7 @@ def list(cls, task_type: str | None = None) -> list[str]: "counting, colors, position, color attributes. Evaluates fine-grained alignment " "between prompts and generated images via VQA-style questions." ), - metrics=["clip_score"], # §3.2: Mask2Former; not in Pruna + metrics=["qa_accuracy"], task_type="text_to_image", reference="https://arxiv.org/abs/2310.11513", ), @@ -246,7 +246,7 @@ def list(cls, task_type: str | None = None) -> list[str]: "Image editing benchmark with 8 edit types: replace, add, remove, adjust, extract, " "style, background, compose. Evaluates instruction-following for inpainting and editing." ), - metrics=[], # Paper uses GPT-4o/ImgEdit-Judge; not in Pruna + metrics=["img_edit_score"], task_type="text_to_image", reference="https://arxiv.org/abs/2505.20275", ), @@ -256,7 +256,7 @@ def list(cls, task_type: str | None = None) -> list[str]: "Text-to-image benchmark for long, detailed prompts. Evaluates model ability to " "handle complex multi-clause descriptions and maintain coherence across long instructions." 
), - metrics=[], # Paper uses text_score/TIT-Score; not in Pruna + metrics=["text_score"], task_type="text_to_image", reference="https://arxiv.org/abs/2507.22058", ), @@ -267,7 +267,7 @@ def list(cls, task_type: str | None = None) -> list[str]: "material alter, motion change, style change, subject add/remove/replace, text change, " "tone transfer, and human retouching." ), - metrics=[], # Paper uses VIEScore; not in Pruna + metrics=["viescore"], task_type="text_to_image", reference="https://arxiv.org/abs/2504.17761", ), @@ -278,7 +278,7 @@ def list(cls, task_type: str | None = None) -> list[str]: "(Anime_Stylization, General_Object, Knowledge_Reasoning, Multilingualism, Portrait, " "Text_Rendering) plus fine-grained style classes. Includes alignment questions." ), - metrics=[], # Paper uses dimension-specific metrics; not in Pruna + metrics=["qa_accuracy"], task_type="text_to_image", reference="https://arxiv.org/abs/2506.07977", ), @@ -288,7 +288,7 @@ def list(cls, task_type: str | None = None) -> list[str]: "Dense Prompt Graph benchmark. Evaluates entity, attribute, relation, " "global, and other descriptive aspects with natural-language questions for alignment." 
), - metrics=[], # Paper uses custom evaluation; not in Pruna + metrics=["qa_accuracy"], task_type="text_to_image", reference="https://arxiv.org/abs/2403.05135", ), diff --git a/src/pruna/evaluation/evaluation_agent.py b/src/pruna/evaluation/evaluation_agent.py index 5b713dea..3e20e4a5 100644 --- a/src/pruna/evaluation/evaluation_agent.py +++ b/src/pruna/evaluation/evaluation_agent.py @@ -112,8 +112,8 @@ def from_benchmark( Examples -------- - >>> agent = EvaluationAgent.from_benchmark("Parti Prompts", model) - >>> agent = EvaluationAgent.from_benchmark("HPS", model, category="anime", fraction=0.1) + >>> agent = EvaluationAgent.from_benchmark("Parti Prompts") + >>> agent = EvaluationAgent.from_benchmark("HPS", category="anime", fraction=0.1) """ task = Task.from_benchmark( benchmark_name, diff --git a/src/pruna/evaluation/metrics/__init__.py b/src/pruna/evaluation/metrics/__init__.py index 1a12f623..82d6aff5 100644 --- a/src/pruna/evaluation/metrics/__init__.py +++ b/src/pruna/evaluation/metrics/__init__.py @@ -15,16 +15,28 @@ from pruna.evaluation.metrics.registry import MetricRegistry # isort:skip from pruna.evaluation.metrics.aesthetic_laion import AestheticLAION +from pruna.evaluation.metrics.metric_alignment_score import AlignmentScoreMetric from pruna.evaluation.metrics.metric_cmmd import CMMD from pruna.evaluation.metrics.metric_dino_score import DinoScore from pruna.evaluation.metrics.metric_elapsed_time import LatencyMetric, ThroughputMetric, TotalTimeMetric from pruna.evaluation.metrics.metric_energy import CO2EmissionsMetric, EnergyConsumedMetric from pruna.evaluation.metrics.metric_evalharness import LMEvalMetric +from pruna.evaluation.metrics.metric_img_edit_score import ImageEditScoreMetric from pruna.evaluation.metrics.metric_memory import DiskMemoryMetric, InferenceMemoryMetric, TrainingMemoryMetric from pruna.evaluation.metrics.metric_model_architecture import TotalMACsMetric, TotalParamsMetric from pruna.evaluation.metrics.metric_pairwise_clip 
import PairwiseClipScore +from pruna.evaluation.metrics.metric_qa_accuracy import QAAccuracyMetric from pruna.evaluation.metrics.metric_sharpness import SharpnessMetric +from pruna.evaluation.metrics.metric_text_score import TextScoreMetric from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper +from pruna.evaluation.metrics.metric_viescore import VieScoreMetric +from pruna.evaluation.metrics.metric_vqa import VQAMetric +from pruna.evaluation.metrics.vlm_base import ( + BaseVLM, + LitellmVLM, + TransformersVLM, + get_vlm, +) __all__ = [ "MetricRegistry", @@ -44,5 +56,15 @@ "DinoScore", "SharpnessMetric", "AestheticLAION", + "VQAMetric", + "AlignmentScoreMetric", + "ImageEditScoreMetric", + "QAAccuracyMetric", + "TextScoreMetric", + "VieScoreMetric", + "BaseVLM", + "LitellmVLM", + "TransformersVLM", + "get_vlm", "LMEvalMetric", ] diff --git a/src/pruna/evaluation/metrics/metric_alignment_score.py b/src/pruna/evaluation/metrics/metric_alignment_score.py new file mode 100644 index 00000000..ae0cfe6d --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_alignment_score.py @@ -0,0 +1,138 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Alignment Score metric using VLM for image-text alignment evaluation.""" + +from __future__ import annotations + +from typing import Any, List, Literal, Optional + +import numpy as np +import torch + +from pruna.engine.utils import set_to_best_available_device +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.metric_vlm_utils import VQAnswer, _process_images +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.utils import ( + SINGLE, + get_call_type_for_single_metric, + metric_data_processor, +) +from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm + + +@MetricRegistry.register("alignment_score") +class AlignmentScoreMetric(StatefulMetric): + """ + Alignment Score metric using VLM. + + Assesses how well generated images match text prompts through structured questioning. + Higher scores indicate better alignment. + + Parameters + ---------- + vlm : BaseVLM | None, optional + Custom VLM instance. If provided, vlm_type and model_name are ignored. + vlm_type : {"litellm", "transformers"}, optional + VLM backend. Default is "litellm". + model_name : str, optional + Model name. Default is "gpt-4o". + vlm_kwargs : dict, optional + Extra kwargs for VLM init (e.g. model_load_kwargs for transformers). + structured_output : bool, optional + Use structured generation. Default is True. + use_outlines : bool, optional + Use outlines for transformers. Default is True. + device : str | torch.device | None, optional + Device for transformers VLM. Default is "cpu". + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional arguments forwarded to the VLM backend constructor. 
+ """ + + scores: List[float] + default_call_type: str = "y_x" + higher_is_better: bool = True + metric_name: str = "alignment_score" + runs_on: List[str] = ["cpu"] + + def __init__( + self, + vlm: Optional[BaseVLM] = None, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str = "gpt-4o", + vlm_kwargs: Optional[dict] = None, + structured_output: bool = True, + use_outlines: bool = True, + device: str | torch.device = "cpu", + api_key: Optional[str] = None, + call_type: str = SINGLE, + **kwargs, + ): + super().__init__(device=device) + self.device = set_to_best_available_device(device) + + self.vlm = get_vlm( + vlm=vlm, + vlm_type=vlm_type, + model_name=model_name, + device=device, + api_key=api_key, + use_outlines=use_outlines, + **(vlm_kwargs or {}), + ) + self.response_format = VQAnswer if structured_output else None + + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + self.add_state("scores", []) + + def update(self, x: List[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None: + """ + Update the metric with new batch data. + + Parameters + ---------- + x : List[Any] | torch.Tensor + The input data (prompts). + gt : torch.Tensor + The ground truth / cached images. + outputs : torch.Tensor + The output images. + """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + prompts = inputs[1] if len(inputs) > 1 and isinstance(inputs[1], list) else [""] * len(images) + for i, image in enumerate(images): + prompt = prompts[i] if i < len(prompts) else "" + question = f'Does this image show "{prompt}"?' + score = self.vlm.score([image], [question], ["Yes"], response_format=self.response_format)[0] + self.scores.append(score) + + def compute(self) -> MetricResult: + """ + Compute the alignment score. + + Returns + ------- + MetricResult + The mean alignment score across all updates. 
+ """ + if not self.scores: + return MetricResult(self.metric_name, self.__dict__, 0.0) + return MetricResult(self.metric_name, self.__dict__, float(np.mean(self.scores))) diff --git a/src/pruna/evaluation/metrics/metric_img_edit_score.py b/src/pruna/evaluation/metrics/metric_img_edit_score.py new file mode 100644 index 00000000..8204a9ea --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_img_edit_score.py @@ -0,0 +1,157 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Image Edit Score metric. + +VLM-based instruction-following score for image editing. Evaluates how well an edited image +follows the given editing instruction on a 0-10 scale. Related work: EditScore (arXiv:2509.23909), +ADIEE (ICCV 2025). 
+""" + +from __future__ import annotations + +import re +from typing import Any, List, Literal, Optional + +import numpy as np +import torch + +from pruna.engine.utils import set_to_best_available_device +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.metric_vlm_utils import FloatOutput, _process_images +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.utils import ( + SINGLE, + get_call_type_for_single_metric, + metric_data_processor, +) +from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm + + +@MetricRegistry.register("img_edit_score") +class ImageEditScoreMetric(StatefulMetric): + """ + Image Edit Score metric. + + VLM-based instruction-following score for image editing. Evaluates how well an edited image + follows the given editing instruction. Higher scores indicate better editing quality. + + Related work: EditScore (arXiv:2509.23909), ADIEE (ICCV 2025). + + Parameters + ---------- + vlm : BaseVLM | None, optional + Custom VLM instance. If provided, vlm_type and model_name are ignored. + vlm_type : {"litellm", "transformers"}, optional + VLM backend. Default is "litellm". + model_name : str, optional + Model name. Default is "gpt-4o". + vlm_kwargs : dict, optional + Extra kwargs for VLM init (e.g. model_load_kwargs for transformers). + structured_output : bool, optional + Use structured generation. Default is True. + use_outlines : bool, optional + Use outlines for transformers. Default is True. + device : str | torch.device | None, optional + Device for transformers VLM. Default is "cpu". + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional arguments forwarded to the VLM backend constructor. 
+ """ + + scores: List[float] + default_call_type: str = "y_x" + higher_is_better: bool = True + metric_name: str = "img_edit_score" + runs_on: List[str] = ["cpu"] + + def __init__( + self, + vlm: Optional[BaseVLM] = None, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str = "gpt-4o", + vlm_kwargs: Optional[dict] = None, + structured_output: bool = True, + use_outlines: bool = True, + device: str | torch.device = "cpu", + api_key: Optional[str] = None, + call_type: str = SINGLE, + **kwargs, + ): + super().__init__(device=device) + self.device = set_to_best_available_device(device) + + self.vlm = get_vlm( + vlm=vlm, + vlm_type=vlm_type, + model_name=model_name, + device=device, + api_key=api_key, + use_outlines=use_outlines, + **(vlm_kwargs or {}), + ) + self.response_format = FloatOutput if structured_output else None + + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + self.add_state("scores", []) + + def update(self, x: List[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None: + """ + Update the metric with new batch data. + + Parameters + ---------- + x : List[Any] | torch.Tensor + The input data (editing instructions). + gt : torch.Tensor + The ground truth / cached images. + outputs : torch.Tensor + The output (edited) images. + """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + prompts = inputs[1] if len(inputs) > 1 and isinstance(inputs[1], list) else [""] * len(images) + for i, image in enumerate(images): + prompt = prompts[i] if i < len(prompts) else "" + question = ( + f'On a scale of 0 to 10, how well does this edited image follow the instruction "{prompt}"? ' + "0 = instruction not followed at all, 10 = perfectly executed. Reply with a single number." 
+ ) + responses = self.vlm.generate([image], [question], response_format=self.response_format) + score = self._parse_score(responses[0]) + self.scores.append(score) + + def _parse_score(self, response: str) -> float: + if isinstance(response, str): + numbers = re.findall(r"\d+", response) + return min(float(numbers[0]), 10.0) / 10.0 if numbers else 0.0 + return 0.0 + + def compute(self) -> MetricResult: + """ + Compute the image edit score. + + Returns + ------- + MetricResult + The mean image edit score across all updates. + """ + if not self.scores: + return MetricResult(self.metric_name, self.__dict__, 0.0) + return MetricResult(self.metric_name, self.__dict__, float(np.mean(self.scores))) diff --git a/src/pruna/evaluation/metrics/metric_qa_accuracy.py b/src/pruna/evaluation/metrics/metric_qa_accuracy.py new file mode 100644 index 00000000..647586d2 --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_qa_accuracy.py @@ -0,0 +1,164 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""QA Accuracy metric using VLM for image understanding evaluation.""" + +from __future__ import annotations + +from typing import Any, List, Literal, Optional + +import numpy as np +import torch + +from pruna.engine.utils import set_to_best_available_device +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.metric_vlm_utils import VQAnswer, _process_images +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.utils import ( + SINGLE, + get_call_type_for_single_metric, + metric_data_processor, +) +from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm + + +@MetricRegistry.register("qa_accuracy") +class QAAccuracyMetric(StatefulMetric): + """ + QA Accuracy metric. + + Uses VLM to answer questions about images. + Higher scores indicate better image understanding. + + Parameters + ---------- + vlm : BaseVLM | None, optional + Custom VLM instance. If provided, vlm_type and model_name are ignored. + vlm_type : {"litellm", "transformers"}, optional + VLM backend. Default is "litellm". + model_name : str, optional + Model name. Default is "gpt-4o". + vlm_kwargs : dict, optional + Extra kwargs for VLM init (e.g. model_load_kwargs for transformers). + structured_output : bool, optional + Use structured generation. Default is True. + use_outlines : bool, optional + Use outlines for transformers. Default is True. + device : str | torch.device | None, optional + Device for transformers VLM. Default is "cpu". + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional arguments forwarded to the VLM backend constructor. 
+ """ + + scores: List[float] + default_call_type: str = "y_gt" + higher_is_better: bool = True + metric_name: str = "qa_accuracy" + runs_on: List[str] = ["cpu"] + + def __init__( + self, + vlm: Optional[BaseVLM] = None, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str = "gpt-4o", + vlm_kwargs: Optional[dict] = None, + structured_output: bool = True, + use_outlines: bool = True, + device: str | torch.device = "cpu", + api_key: Optional[str] = None, + call_type: str = SINGLE, + **kwargs, + ): + super().__init__(device=device) + self.device = set_to_best_available_device(device) + + self.vlm = get_vlm( + vlm=vlm, + vlm_type=vlm_type, + model_name=model_name, + device=device, + api_key=api_key, + use_outlines=use_outlines, + **(vlm_kwargs or {}), + ) + self.response_format = VQAnswer if structured_output else None + + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + self.add_state("scores", []) + + def _extract_questions(self, gt: Any, n: int) -> List[List[str]]: + if isinstance(gt, (list, tuple)) and len(gt) >= n: + out = [] + for i in range(n): + v = gt[i] + if isinstance(v, dict) and "questions" in v: + qs = v["questions"] + out.append(list(qs.values()) if isinstance(qs, dict) else list(qs)) + else: + out.append([]) + return out + return [[]] * n + + def update(self, x: List[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None: + """ + Update the metric with new batch data. + + Parameters + ---------- + x : List[Any] | torch.Tensor + The input data. + gt : torch.Tensor + The ground truth (questions per image). + outputs : torch.Tensor + The output images. 
+ """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + auxiliaries = inputs[1] if len(inputs) > 1 else [] + questions_per_image = self._extract_questions(auxiliaries, len(images)) + for i, image in enumerate(images): + questions = questions_per_image[i] if i < len(questions_per_image) else [] + if not questions: + aux = auxiliaries[i] if i < len(auxiliaries) else {} + raise ValueError( + "qa_accuracy requires 'questions' in auxiliaries. " + "Use a benchmark that provides it (e.g. GenEval, DPG, OneIG). " + f"Got aux keys: {list(aux.keys()) if isinstance(aux, dict) else 'not a dict'}." + ) + scores = self.vlm.score( + [image] * len(questions), + questions, + ["Yes"] * len(questions), + response_format=self.response_format, + ) + score = float(np.mean(scores)) + self.scores.append(score) + + def compute(self) -> MetricResult: + """ + Compute the QA accuracy score. + + Returns + ------- + MetricResult + The mean QA accuracy across all updates. + """ + if not self.scores: + return MetricResult(self.metric_name, self.__dict__, 0.0) + return MetricResult(self.metric_name, self.__dict__, float(np.mean(self.scores))) diff --git a/src/pruna/evaluation/metrics/metric_stateful.py b/src/pruna/evaluation/metrics/metric_stateful.py index 39fddcf6..5aa33ad9 100644 --- a/src/pruna/evaluation/metrics/metric_stateful.py +++ b/src/pruna/evaluation/metrics/metric_stateful.py @@ -91,7 +91,7 @@ def forward(self, *args, **kwargs) -> None: **kwargs : Any The keyword arguments to pass to the metric. """ - pass + ... def reset(self) -> None: """ @@ -109,16 +109,18 @@ def reset(self) -> None: getattr(self, attr).clear() @abstractmethod - def update(self, *args, **kwargs) -> None: + def update(self, x: Any, gt: Any, outputs: Any) -> None: """ Override this method to update the state variables of your metric. Parameters ---------- - *args : Any - The arguments to pass to the metric. 
- **kwargs : Any - The keyword arguments to pass to the metric. + x : Any + Input/prompt data for the metric. + gt : Any + Ground-truth or auxiliary data for the metric. + outputs : Any + Model outputs to evaluate. """ @abstractmethod diff --git a/src/pruna/evaluation/metrics/metric_text_score.py b/src/pruna/evaluation/metrics/metric_text_score.py new file mode 100644 index 00000000..6e618391 --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_text_score.py @@ -0,0 +1,174 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Text Score metric for evaluating text rendering in images using VLM OCR.""" + +from __future__ import annotations + +import re +from typing import Any, List, Literal, Optional + +import numpy as np +import torch + +from pruna.engine.utils import set_to_best_available_device +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.metric_vlm_utils import TextOutput, _process_images, get_text_from_response +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.utils import ( + SINGLE, + get_call_type_for_single_metric, + metric_data_processor, +) +from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm + +OCR_PROMPT = ( + "Extract all text visible in this image. Include logos, stylized fonts, handwritten text, " + "and non-standard typography. 
Return only the extracted text, exactly as it appears—no preamble, " + "explanation, or markdown. Preserve words, numbers, punctuation, and spacing. " + "If no text is recognized, reply with exactly: No text recognized" +) + + +@MetricRegistry.register("text_score") +class TextScoreMetric(StatefulMetric): + """ + Text Score metric for evaluating text rendering in images. + + Uses VLM for OCR to extract text and compare with ground truth. + Lower scores (edit distance) are better. + + Parameters + ---------- + vlm : BaseVLM | None, optional + Custom VLM instance. If provided, vlm_type and model_name are ignored. + vlm_type : {"litellm", "transformers"}, optional + VLM backend. Default is "litellm". + model_name : str, optional + Model name. Default is "gpt-4o". + vlm_kwargs : dict, optional + Extra kwargs for VLM init (e.g. model_load_kwargs for transformers). + structured_output : bool, optional + Use structured generation. Default is True. + use_outlines : bool, optional + Use outlines for transformers. Default is True. + device : str | torch.device | None, optional + Device for transformers VLM. Default is "cpu". + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional arguments forwarded to the VLM backend constructor. 
+ """ + + scores: List[float] + default_call_type: str = "y_gt" + higher_is_better: bool = False + metric_name: str = "text_score" + runs_on: List[str] = ["cuda", "cpu"] + + def __init__( + self, + vlm: Optional[BaseVLM] = None, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str = "gpt-4o", + vlm_kwargs: Optional[dict] = None, + structured_output: bool = True, + use_outlines: bool = True, + device: str | torch.device = "cpu", + api_key: Optional[str] = None, + call_type: str = SINGLE, + **kwargs, + ): + super().__init__(device=device) + self.device = set_to_best_available_device(device) + + self.vlm = get_vlm( + vlm=vlm, + vlm_type=vlm_type, + model_name=model_name, + device=device, + api_key=api_key, + use_outlines=use_outlines, + **(vlm_kwargs or {}), + ) + self.response_format = TextOutput if structured_output else None + + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + self.add_state("scores", []) + + @staticmethod + def _normalize_text(s: str) -> str: + cleaned = re.sub(r"[^\u4e00-\u9fa5a-zA-Z0-9\sàâäéèêëîïôöùûüçÀÂÄÉÈÊËÎÏÔÖÙÛÜÇ]", "", s or "") + return re.sub(r"\s+", " ", cleaned).strip() + + @staticmethod + def _levenshtein(s1: str, s2: str) -> float: + if len(s1) < len(s2): + return TextScoreMetric._levenshtein(s2, s1) + prev = list(range(len(s2) + 1)) + for i, c1 in enumerate(s1): + curr = [i + 1] + for j, c2 in enumerate(s2): + curr.append(min(prev[j] + (c1 != c2), prev[j + 1] + 1, curr[-1] + 1)) + prev = curr + return float(prev[-1]) + + def update(self, x: List[Any] | torch.Tensor, gt: List[str], outputs: torch.Tensor) -> None: + """ + Update the metric with new batch data. + + Parameters + ---------- + x : List[Any] | torch.Tensor + The input data (prompts). + gt : List[dict] | List[str] + Ground truth auxiliaries. Each item must have 'text_content' key (e.g. from + LongTextBench, OneIG). Or a list of strings for backward compatibility. + outputs : torch.Tensor + The output images. 
+ """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + auxiliaries = inputs[1] if len(inputs) > 1 and isinstance(inputs[1], (list, tuple)) else [{}] * len(images) + for i, image in enumerate(images): + responses = self.vlm.generate([image], [OCR_PROMPT], response_format=self.response_format) + raw = responses[0] if responses else "" + ocr_text = get_text_from_response(raw) + aux = auxiliaries[i] if i < len(auxiliaries) else {} + text_gt = aux.get("text_content") if isinstance(aux, dict) else (aux if isinstance(aux, str) else None) + if text_gt is None: + raise ValueError( + "text_score requires 'text_content' in auxiliaries. " + "Use a benchmark that provides it (e.g. LongTextBench, OneIG)." + ) + norm_gt = self._normalize_text(text_gt) + norm_ocr = self._normalize_text(ocr_text) + score = self._levenshtein(norm_ocr, norm_gt) + self.scores.append(score) + + def compute(self) -> MetricResult: + """ + Compute the text score. + + Returns + ------- + MetricResult + The mean text score (edit distance) across all updates. + """ + if not self.scores: + return MetricResult(self.metric_name, self.__dict__, 0.0) + return MetricResult(self.metric_name, self.__dict__, float(np.mean(self.scores))) diff --git a/src/pruna/evaluation/metrics/metric_viescore.py b/src/pruna/evaluation/metrics/metric_viescore.py new file mode 100644 index 00000000..2bc0c044 --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_viescore.py @@ -0,0 +1,176 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +VIEScore metric for evaluating conditional image synthesis (semantic + quality). + +Reference: VIEScore: Towards Explainable Metrics for Conditional Image Synthesis Evaluation +(ACL 2024) - https://arxiv.org/abs/2312.14867, https://github.com/TIGER-AI-Lab/VIEScore +""" + +from __future__ import annotations + +import math +import re +from typing import Any, List, Literal, Optional + +import numpy as np +import torch + +from pruna.engine.utils import set_to_best_available_device +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.metric_vlm_utils import FloatOutput, _process_images +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.utils import ( + SINGLE, + get_call_type_for_single_metric, + metric_data_processor, +) +from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm + + +@MetricRegistry.register("viescore") +class VieScoreMetric(StatefulMetric): + """ + VIEScore metric for evaluating conditional image synthesis (semantic + quality). + + Uses VLM to assess both semantic alignment and visual quality. + Higher scores indicate better overall quality. + + Computes: + - Semantic score: How well image follows prompt + - Quality score: Naturalness and artifacts + - Overall: Geometric mean of semantic and quality + + Parameters + ---------- + vlm : BaseVLM | None, optional + Custom VLM instance. If provided, vlm_type and model_name are ignored. 
+ vlm_type : {"litellm", "transformers"}, optional + VLM backend. Default is "litellm". + model_name : str, optional + Model name. Default is "gpt-4o". + vlm_kwargs : dict, optional + Extra kwargs for VLM init (e.g. model_load_kwargs for transformers). + structured_output : bool, optional + Use structured generation. Default is True. + use_outlines : bool, optional + Use outlines for transformers. Default is True. + device : str | torch.device | None, optional + Device for transformers VLM. Default is "cpu". + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional arguments forwarded to the VLM backend constructor. + + References + ---------- + VIEScore: Towards Explainable Metrics for Conditional Image Synthesis Evaluation (ACL 2024) + https://arxiv.org/abs/2312.14867 + https://github.com/TIGER-AI-Lab/VIEScore + """ + + scores: List[float] + default_call_type: str = "y_x" + higher_is_better: bool = True + metric_name: str = "viescore" + runs_on: List[str] = ["cpu"] + + def __init__( + self, + vlm: Optional[BaseVLM] = None, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str = "gpt-4o", + vlm_kwargs: Optional[dict] = None, + structured_output: bool = True, + use_outlines: bool = True, + device: str | torch.device = "cpu", + api_key: Optional[str] = None, + call_type: str = SINGLE, + **kwargs, + ): + super().__init__(device=device) + self.device = set_to_best_available_device(device) + + self.vlm = get_vlm( + vlm=vlm, + vlm_type=vlm_type, + model_name=model_name, + device=device, + api_key=api_key, + use_outlines=use_outlines, + **(vlm_kwargs or {}), + ) + self.response_format = FloatOutput if structured_output else None + + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + self.add_state("scores", []) + + def update(self, x: List[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None: + """ + Update the 
metric with new batch data. + + Parameters + ---------- + x : List[Any] | torch.Tensor + The input data (prompts). + gt : torch.Tensor + The ground truth / cached images. + outputs : torch.Tensor + The output images. + """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + prompts = inputs[1] if len(inputs) > 1 and isinstance(inputs[1], list) else [""] * len(images) + for i, image in enumerate(images): + prompt = prompts[i] if i < len(prompts) else "" + + sem_prompt = ( + f'On a scale of 0 to 10, how well does this image match the prompt "{prompt}"? ' + "0 = no match, 10 = perfect match. Reply with a single number." + ) + sem_resp = self.vlm.generate([image], [sem_prompt], response_format=self.response_format)[0] + sem_score = self._parse_score(sem_resp) + + qual_prompt = ( + "On a scale of 0 to 10, rate this image's naturalness and absence of artifacts. " + "0 = unnatural, heavy artifacts; 10 = natural, no artifacts. Reply with a single number." + ) + qual_resp = self.vlm.generate([image], [qual_prompt], response_format=self.response_format)[0] + qual_score = self._parse_score(qual_resp) + + score = math.sqrt(sem_score * qual_score) / 10.0 + self.scores.append(score) + + def _parse_score(self, response: str) -> float: + if isinstance(response, str): + numbers = re.findall(r"\d+", response) + return min(float(numbers[0]), 10.0) if numbers else 0.0 + return 0.0 + + def compute(self) -> MetricResult: + """ + Compute the VIEScore metric. + + Returns + ------- + MetricResult + The mean VIEScore across all updates. 
+ """ + if not self.scores: + return MetricResult(self.metric_name, self.__dict__, 0.0) + return MetricResult(self.metric_name, self.__dict__, float(np.mean(self.scores))) diff --git a/src/pruna/evaluation/metrics/metric_vlm_utils.py b/src/pruna/evaluation/metrics/metric_vlm_utils.py new file mode 100644 index 00000000..04b088a8 --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_vlm_utils.py @@ -0,0 +1,176 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared utilities and Pydantic models for VLM metrics.""" + +from __future__ import annotations + +import json +import re +from typing import Any, List + +import torch +from PIL import Image +from pydantic import BaseModel, Field + + +def _tensor_to_pil(tensor: torch.Tensor) -> Image.Image: + if tensor.ndim == 4: + tensor = tensor[0] + if tensor.max() > 1: + tensor = tensor / 255.0 + np_img = (tensor.cpu().numpy() * 255).astype("uint8") + return Image.fromarray(np_img.transpose(1, 2, 0)) + + +def _process_images(images: torch.Tensor) -> List[Any]: + return [_tensor_to_pil(img) if isinstance(img, torch.Tensor) else img for img in images] + + +class VQAnswer(BaseModel): + """ + Structured output for VQA questions (Yes/No or open-ended). + + Parameters + ---------- + answer : str + Answer to the question. Typically "Yes" or "No" for alignment metrics, + but can be any string for open-ended questions. 
+ """ + + answer: str = Field(description="Answer to the question") + + +class FloatOutput(BaseModel): + """ + Structured output for numeric scoring (img_edit_score, viescore). + + Parameters + ---------- + score : float + Score from 0 to 10. + """ + + score: float = Field(ge=0, le=10, description="Score from 0 to 10") + + +class TextOutput(BaseModel): + """ + Structured output for text extraction (text_score). + + Parameters + ---------- + text : str + Extracted text from the image, or 'No text recognized' if empty. + """ + + text: str = Field(description="Extracted text from the image, or 'No text recognized' if empty") + + +def get_answer_from_response(response: str | BaseModel | dict) -> str: + """ + Extract answer string from a VLM score() response (VQAnswer, dict, or raw string). + + Parameters + ---------- + response : str | BaseModel | dict + Raw response from vlm.generate() or vlm.score(). + + Returns + ------- + str + Extracted answer string, or empty string. + """ + if response is None: + return "" + if isinstance(response, VQAnswer): + return response.answer + if isinstance(response, dict): + return response.get("answer", "") + raw = str(response).strip() + if raw.startswith("{"): + try: + return json.loads(raw).get("answer", raw) + except (json.JSONDecodeError, TypeError): + pass + return raw + + +def get_text_from_response(response: str | BaseModel | dict) -> str: + """ + Extract text from a VLM generate() response (str, pydantic, or dict). + + Parameters + ---------- + response : str | BaseModel | dict + Raw response from vlm.generate(). + + Returns + ------- + str + Extracted text, or empty string. 
+ """ + text = _extract_text_payload(response) + return _strip_no_text_markers(text) + + +def _extract_text_payload(response: str | BaseModel | dict) -> str: + if response is None: + return "" + if isinstance(response, TextOutput): + return response.text + if isinstance(response, dict): + return str(response.get("text", "") or "") + return _parse_json_text(str(response or "").strip()) + + +def _parse_json_text(text: str) -> str: + if not text.startswith("{"): + return text + try: + data = json.loads(text) + return str(data.get("text", text)) + except (json.JSONDecodeError, TypeError): + return text + + +def _strip_no_text_markers(text: str) -> str: + cleaned = text or "" + for phrase in ("No text recognized", "no text recognized", "No text"): + cleaned = cleaned.replace(phrase, "").strip() + return cleaned.strip() + + +def get_score_from_response(response: str | BaseModel | dict) -> float: + """ + Extract numeric score (0-10) from a VLM generate() response. + + Parameters + ---------- + response : str | BaseModel | dict + Raw response from vlm.generate(). + + Returns + ------- + float + Score in [0, 1] (normalized from 0-10). + """ + if response is None: + return 0.0 + if isinstance(response, FloatOutput): + return min(response.score, 10.0) / 10.0 + if isinstance(response, dict): + return min(float(response.get("score", 0)), 10.0) / 10.0 + numbers = re.findall(r"\d+", str(response or "")) + return min(float(numbers[0]), 10.0) / 10.0 if numbers else 0.0 diff --git a/src/pruna/evaluation/metrics/metric_vqa.py b/src/pruna/evaluation/metrics/metric_vqa.py new file mode 100644 index 00000000..e2fe6a0b --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_vqa.py @@ -0,0 +1,163 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +VQA (Visual Question Answering) metric. + +Reference: VQAScore - Evaluating Text-to-Visual Generation with Image-to-Text Generation +https://arxiv.org/abs/2404.01291 + +Note: VQAScore uses P(Yes) (probability of "Yes" answer) for ranking. With litellm, +use_probability=True (default) requests logprobs for soft scores when the provider supports it. +Set use_probability=False for binary 0/1. TransformersVLM always uses binary. +""" + +from __future__ import annotations + +from typing import Any, List, Literal, Optional + +import numpy as np +import torch + +from pruna.engine.utils import set_to_best_available_device +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.metric_vlm_utils import VQAnswer, _process_images +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.utils import ( + SINGLE, + get_call_type_for_single_metric, + metric_data_processor, +) +from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm + + +@MetricRegistry.register("vqa") +class VQAMetric(StatefulMetric): + """ + VQA (Visual Question Answering) metric. + + Uses VLM to answer "Does this image show '{prompt}'?" and scores alignment. + Higher scores indicate better image-text alignment. + + VQAScore (arXiv:2404.01291) uses P(Yes) for ranking. Default use_probability=True + with litellm requests logprobs for soft scores when supported. + + Parameters + ---------- + vlm : BaseVLM | None, optional + Custom VLM instance. 
If provided, vlm_type and model_name are ignored. + vlm_type : {"litellm", "transformers"}, optional + VLM backend to use. Default is "litellm". + model_name : str, optional + Model name (gpt-4o for litellm, model path for transformers). + vlm_kwargs : dict, optional + Extra kwargs for VLM init (e.g. model_load_kwargs for transformers). + structured_output : bool, optional + Use structured generation for stable outputs. Default is True. + use_outlines : bool, optional + Use outlines for transformers. Default is True. + device : str | torch.device | None, optional + Device for transformers VLM. Default is "cpu". + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + use_probability : bool, optional + If True, use P(Yes) when backend supports logprobs (litellm). Otherwise binary 0/1. + Default is True for paper alignment. + **kwargs : Any + Additional arguments forwarded to the VLM backend constructor. + """ + + scores: List[float] + default_call_type: str = "y_x" + higher_is_better: bool = True + metric_name: str = "vqa" + runs_on: List[str] = ["cpu"] + + def __init__( + self, + vlm: Optional[BaseVLM] = None, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str = "gpt-4o", + vlm_kwargs: Optional[dict] = None, + structured_output: bool = True, + use_outlines: bool = True, + device: str | torch.device = "cpu", + api_key: Optional[str] = None, + call_type: str = SINGLE, + use_probability: bool = True, + **kwargs, + ): + super().__init__(device=device) + self.device = set_to_best_available_device(device) + self.structured_output = structured_output + self.use_probability = use_probability + + self.vlm = get_vlm( + vlm=vlm, + vlm_type=vlm_type, + model_name=model_name, + device=device, + api_key=api_key, + use_outlines=use_outlines, + **(vlm_kwargs or {}), + ) + self.response_format = VQAnswer if structured_output else None + + self.call_type = get_call_type_for_single_metric(call_type, 
self.default_call_type) + self.add_state("scores", []) + + def update(self, x: List[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None: + """ + Update the metric with new batch data. + + Parameters + ---------- + x : List[Any] | torch.Tensor + The input data (prompts). + gt : torch.Tensor + The ground truth / cached images. + outputs : torch.Tensor + The output images. + """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + prompts = inputs[1] if len(inputs) > 1 and isinstance(inputs[1], list) else [""] * len(images) + + for i, image in enumerate(images): + prompt = prompts[i] if i < len(prompts) else "" + question = f'Does this image show "{prompt}"?' + score = self.vlm.score( + [image], + [question], + ["Yes"], + response_format=self.response_format, + use_probability=self.use_probability, + )[0] + self.scores.append(score) + + def compute(self) -> MetricResult: + """ + Compute the VQA score. + + Returns + ------- + MetricResult + The mean VQA score across all updates. 
+ """ + if not self.scores: + return MetricResult(self.metric_name, self.__dict__, 0.0) + return MetricResult(self.metric_name, self.__dict__, float(np.mean(self.scores))) diff --git a/src/pruna/evaluation/metrics/registry.py b/src/pruna/evaluation/metrics/registry.py index 5efd721a..a3201525 100644 --- a/src/pruna/evaluation/metrics/registry.py +++ b/src/pruna/evaluation/metrics/registry.py @@ -19,6 +19,7 @@ from typing import Any, Callable, Dict, Iterable, List from pruna.engine.load import filter_load_kwargs +from pruna.engine.utils import device_to_string, split_device from pruna.evaluation.metrics.metric_base import BaseMetric from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.logging.logger import pruna_logger @@ -32,6 +33,14 @@ class MetricRegistry: """ _registry: Dict[str, Callable[..., Any]] = {} + _cpu_default_stateful_metrics = { + "vqa", + "alignment_score", + "img_edit_score", + "qa_accuracy", + "text_score", + "viescore", + } @classmethod def register(cls, name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]: @@ -135,7 +144,14 @@ def get_metric(cls, name: str, **kwargs) -> BaseMetric | StatefulMetric: return metric_cls(**kwargs) elif isclass(metric_cls): if issubclass(metric_cls, StatefulMetric): - kwargs["device"] = stateful_metric_device if stateful_metric_device else device + metric_device = stateful_metric_device if stateful_metric_device else device + if metric_device is None and name in cls._cpu_default_stateful_metrics: + metric_device = "cpu" + elif metric_device is not None: + requested_device, _ = split_device(device_to_string(metric_device), strict=False) + if requested_device not in metric_cls.runs_on and name in cls._cpu_default_stateful_metrics: + metric_device = "cpu" + kwargs["device"] = metric_device elif issubclass(metric_cls, BaseMetric): kwargs["device"] = inference_device if inference_device else device return metric_cls(**filter_load_kwargs(metric_cls, kwargs)) diff --git 
a/src/pruna/evaluation/metrics/utils.py b/src/pruna/evaluation/metrics/utils.py index 29342701..c6813872 100644 --- a/src/pruna/evaluation/metrics/utils.py +++ b/src/pruna/evaluation/metrics/utils.py @@ -56,13 +56,17 @@ def metric_data_processor( This function determines the order and selection of inputs to be passed to various metrics. The function supports different input arrangements through the 'call_type' configuration: - - 'x_y': Uses input data (x) and model outputs - - 'gt_y': Uses ground truth (gt) and model outputs - - 'y_x': Uses model outputs and input data (x) - - 'y_gt': Uses model outputs and ground truth (gt) - - 'pairwise_gt_y': Uses cached base model outputs (gt) and smashed model outputs (y). - - 'pairwise_y_gt': Uses smashed model outputs (y) and cached base model outputs (gt). - The evaluation agent is expected to pass the cached base model outputs as gt. + + - 'y_gt': Model's output first, then ground truth. Returns [outputs, gt]. + - 'gt_y': Ground truth first, then model's output. Returns [gt, outputs]. + - 'y_x': Model's output first, then input data. Returns [outputs, x]. + Used by CLIPScore, AlignmentScore, VQA, ImageEditScore, VIEScore. + - 'x_y': Input data first, then model's output. Returns [x, outputs]. + - 'x_gt': Input data first, then ground truth. Returns [x, gt]. + - 'gt_x': Ground truth first, then input data. Returns [gt, x]. + - 'pairwise_y_gt': Base model's output first, then subsequent model's output. + - 'pairwise_gt_y': Subsequent model's output first, then base model's output. + - 'y': Only the output is used; the metric has an internal dataset. Returns [outputs]. Parameters ---------- @@ -85,7 +89,8 @@ def metric_data_processor( Raises ------ ValueError - If the specified call_type is not one of: 'x_y', 'gt_y', 'y_x', 'y_gt', 'pairwise'. + If the specified call_type is not one of: 'y_gt', 'gt_y', 'y_x', 'x_y', + 'x_gt', 'gt_x', 'pairwise_y_gt', 'pairwise_gt_y', 'y'. 
Examples -------- @@ -106,11 +111,15 @@ def metric_data_processor( return [outputs, x] elif call_type == "y_gt": return [outputs, gt] + elif call_type == "x_gt": + return [x, gt] + elif call_type == "gt_x": + return [gt, x] elif call_type == "pairwise_gt_y": return [gt, outputs] elif call_type == "pairwise_y_gt": return [outputs, gt] - elif call_type == "y": # IQA metrics that have an internal dataset + elif call_type == "y": return [outputs] else: raise ValueError(f"Invalid call type: {call_type}") diff --git a/src/pruna/evaluation/metrics/vlm_base.py b/src/pruna/evaluation/metrics/vlm_base.py new file mode 100644 index 00000000..3139028a --- /dev/null +++ b/src/pruna/evaluation/metrics/vlm_base.py @@ -0,0 +1,691 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +VLM (Vision-Language Model) base classes for metrics. + +This module provides two VLM implementations: +1. LitellmVLM - Uses litellm for API-based VLM calls (supports 100+ providers) +2. TransformersVLM - Uses local VLM models from HuggingFace Transformers + +Both support structured generation for stable outputs: +- LitellmVLM: Uses pydantic models with response_format +- TransformersVLM: Uses outlines for constrained decoding. 
+""" + +from __future__ import annotations + +import base64 +import io +import json +import math +import os +from abc import ABC, abstractmethod +from typing import Any, List, Literal, Optional, Type, TypeVar, Union + +import torch +from PIL import Image +from pydantic import BaseModel + +from pruna.logging.logger import pruna_logger + +T = TypeVar("T", bound=BaseModel) + + +def get_vlm( + vlm: Optional[BaseVLM] = None, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str = "gpt-4o", + device: Optional[str | torch.device] = None, + api_key: Optional[str] = None, + use_outlines: bool = False, + **vlm_kwargs: Any, +) -> BaseVLM: + """ + Create or return a VLM instance. + + Parameters + ---------- + vlm : BaseVLM | None + If provided, returned as-is. Otherwise a VLM is created. + vlm_type : {"litellm", "transformers"} + Backend when creating a VLM. + model_name : str + Model name for litellm or HuggingFace. + device : str | torch.device | None + Device for transformers VLM. + api_key : str | None + API key for litellm. + use_outlines : bool + Use outlines for transformers. + **vlm_kwargs : Any + Extra kwargs passed to LitellmVLM or TransformersVLM. + For TransformersVLM, use model_load_kwargs={"torch_dtype": torch.bfloat16} + to pass options to from_pretrained. + + Returns + ------- + BaseVLM + The VLM instance. 
+ """ + if vlm is not None: + return vlm + if vlm_type == "litellm": + return LitellmVLM(model_name=model_name, api_key=api_key, **vlm_kwargs) + model_load_kwargs = vlm_kwargs.pop("model_load_kwargs", {}) + return TransformersVLM( + model_name=model_name, + device=device, + use_outlines=use_outlines, + model_load_kwargs=model_load_kwargs, + **vlm_kwargs, + ) + + +class BaseVLM(ABC): + """Base class for Vision-Language Models.""" + + @abstractmethod + def generate( + self, + images: List[Image.Image], + prompts: List[str], + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[str]: + """ + Generate responses for images and prompts. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + prompts : List[str] + List of text prompts. + response_format : Type[BaseModel] | str | None + Optional pydantic model (litellm) or format string: "integer", "yes_no", "json" (transformers/outlines). + **kwargs : Any + Additional arguments passed to the implementation. + + Returns + ------- + List[str] + Generated responses. + """ + ... + + @abstractmethod + def score( + self, + images: List[Image.Image], + questions: List[str], + answers: List[str], + use_probability: bool = False, + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[float]: + """ + Score how well answers match images for given questions. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + questions : List[str] + List of questions. + answers : List[str] + List of expected answers. + use_probability : bool, optional + If True and supported, return P(expected answer) instead of binary 0/1. + response_format : Type[BaseModel] | str | None, optional + Structured output format. 
When set, uses generate() with this format and + extracts the answer field for comparison instead of raw string matching. + **kwargs : Any + Additional arguments passed to the implementation. + + Returns + ------- + List[float] + Scores for each image-question pair (0-1, or probability when use_probability). + """ + ... + + +class LitellmVLM(BaseVLM): + """ + VLM using litellm for API-based inference. + + Supports 100+ LLM providers (OpenAI, Anthropic, Azure, etc.) + Default model is gpt-4o. + + Parameters + ---------- + model_name : str, optional + Model name (e.g., gpt-4o). Default is "gpt-4o". + api_key : str | None, optional + API key for the provider. Uses LITELLM_API_KEY or OPENAI_API_KEY env if None. + **kwargs : Any + Additional arguments passed to litellm. + """ + + def __init__( + self, + model_name: str = "gpt-4o", + api_key: Optional[str] = None, + **kwargs: Any, + ) -> None: + self.model_name = model_name + self.api_key = api_key or os.getenv("LITELLM_API_KEY") or os.getenv("OPENAI_API_KEY") + self.extra_kwargs = kwargs + try: + import litellm + + litellm.drop_params = True + self._litellm = litellm + except ImportError: + pruna_logger.error("litellm not installed. Install with: pip install litellm") + raise + + def generate( + self, + images: List[Image.Image], + prompts: List[str], + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[str]: + """ + Generate responses for images and prompts. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + prompts : List[str] + List of text prompts. + response_format : Type[BaseModel] | str | None + Optional pydantic model for structured output (litellm uses BaseModel). + **kwargs : Any + Additional arguments passed to litellm completion. + + Returns + ------- + List[str] + Generated responses. 
+ """ + results = [] + for image, prompt in zip(images, prompts): + try: + content = self._build_litellm_content(image, prompt) + completion_kwargs = self._build_completion_kwargs(content, kwargs, response_format) + response = self._litellm.completion(**completion_kwargs) + results.append(self._extract_content_result(response, response_format)) + except Exception as e: + pruna_logger.error(f"Litellm generation failed: {e}") + results.append("") + return results + + def score( + self, + images: List[Image.Image], + questions: List[str], + answers: List[str], + use_probability: bool = False, + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[float]: + """ + Score how well answers match images for given questions. + + When use_probability=True, requests logprobs from the API and returns P(expected). + When response_format is set, uses structured generation and extracts the answer field. + Falls back to binary 0/1 if logprobs not available. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + questions : List[str] + List of questions. + answers : List[str] + List of expected answers. + use_probability : bool, optional + If True, return P(expected) from logprobs when available. Default is False. + response_format : Type[BaseModel] | str | None, optional + Structured output format for answer extraction. + **kwargs : Any + Additional arguments passed to litellm completion. + + Returns + ------- + List[float] + Scores for each image-question pair (0-1, or probability when use_probability). + """ + scores = [] + for image, question, answer in zip(images, questions, answers): + prompt = f"{question} Please answer yes or no." 
+ if use_probability: + score = self._score_with_logprobs(image, prompt, answer, **kwargs) + elif response_format is not None: + score = self._score_structured_response(image, prompt, answer, response_format, **kwargs) + else: + raw = self.generate([image], [prompt], **kwargs)[0] + score = self._normalize_binary_match(raw, answer) + scores.append(score) + return scores + + def _score_with_logprobs(self, image: Image.Image, prompt: str, expected: str, **kwargs: Any) -> float: + """ + Get P(expected) from logprobs when available. + + Parameters + ---------- + image : Image.Image + PIL Image to score. + prompt : str + Question prompt. + expected : str + Expected answer (e.g., "Yes"). + **kwargs : Any + Additional arguments passed to litellm completion. + + Returns + ------- + float + Probability of expected answer (0-1), or binary 0/1 on fallback. + """ + content = self._build_litellm_content(image, prompt) + completion_kwargs = { + "model": self.model_name, + "messages": [{"role": "user", "content": content}], + "api_key": self.api_key, + "logprobs": True, + "top_logprobs": 5, + **self.extra_kwargs, + **kwargs, + } + try: + response = self._litellm.completion(**completion_kwargs) + choice = response.choices[0] + logprobs = self._extract_logprobs(choice) + prob = self._prob_from_top_logprobs(logprobs, expected) + if prob is not None: + return prob + return self._binary_fallback_from_choice(choice, expected) + except Exception: + response = self.generate([image], [prompt], **kwargs)[0].lower() + return 1.0 if expected.lower() in response else 0.0 + + def _build_litellm_content(self, image: Image.Image, prompt: str) -> list[dict[str, Any]]: + return [ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": self._image_to_data_url(image)}}, + ] + + def _build_completion_kwargs( + self, + content: list[dict[str, Any]], + kwargs: dict[str, Any], + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], 
Literal["json"]]], + ) -> dict[str, Any]: + completion_kwargs: dict[str, Any] = { + "model": self.model_name, + "messages": [{"role": "user", "content": content}], + "api_key": self.api_key, + **self.extra_kwargs, + **kwargs, + } + if response_format is not None and isinstance(response_format, type): + completion_kwargs["response_format"] = response_format + return completion_kwargs + + def _extract_content_result( + self, + response: Any, + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]], + ) -> str: + content_result = response.choices[0].message.content + use_pydantic = response_format is not None and isinstance(response_format, type) and isinstance( + content_result, response_format + ) + if use_pydantic: + return content_result.model_dump_json() + return content_result + + @staticmethod + def _normalize_binary_match(response_text: str, expected: str) -> float: + return 1.0 if expected.lower() in response_text.lower() else 0.0 + + def _score_structured_response( + self, + image: Image.Image, + prompt: str, + expected: str, + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]], + **kwargs: Any, + ) -> float: + from pruna.evaluation.metrics.metric_vlm_utils import get_answer_from_response + + raw = self.generate([image], [prompt], response_format=response_format, **kwargs)[0] + response_answer = get_answer_from_response(raw) + return self._normalize_binary_match(response_answer, expected) + + @staticmethod + def _extract_logprobs(choice: Any) -> Any: + return getattr(choice, "logprobs", None) or getattr(choice.message, "logprobs", None) + + @staticmethod + def _prob_from_top_logprobs(logprobs: Any, expected: str) -> Optional[float]: + token_logprobs = LitellmVLM._iter_top_logprobs(logprobs) + if token_logprobs is None: + return None + expected_lower = expected.lower() + for token_logprob in token_logprobs: + if 
LitellmVLM._token_matches_expected(token_logprob, expected_lower): + return LitellmVLM._logprob_to_probability(token_logprob) + return None + + @staticmethod + def _iter_top_logprobs(logprobs: Any) -> Optional[list[Any]]: + if not (logprobs and hasattr(logprobs, "content")): + return None + flattened: list[Any] = [] + for tok in logprobs.content or []: + flattened.extend(getattr(tok, "top_logprobs", None) or []) + return flattened + + @staticmethod + def _token_matches_expected(token_logprob: Any, expected_lower: str) -> bool: + token_str = getattr(token_logprob, "token", "") or str(token_logprob) + return bool(token_str and expected_lower in token_str.lower()) + + @staticmethod + def _logprob_to_probability(token_logprob: Any) -> float: + logprob = float(getattr(token_logprob, "logprob", -1e9) or -1e9) + return min(1.0, max(0.0, math.exp(logprob))) + + def _binary_fallback_from_choice(self, choice: Any, expected: str) -> float: + content_str = (choice.message.content or "") + return self._normalize_binary_match(content_str, expected) + + def _image_to_data_url(self, image: Image.Image) -> str: + buffer = io.BytesIO() + image.save(buffer, format="PNG") + buffer.seek(0) + b64 = base64.b64encode(buffer.read()).decode("utf-8") + return f"data:image/png;base64,{b64}" + + +class TransformersVLM(BaseVLM): + """ + VLM using HuggingFace Transformers for local inference. + + Supports models like BLIP, LLaVA, SmolVLM, etc. + + Parameters + ---------- + model_name : str, optional + HuggingFace model name. Default is "Salesforce/blip2-opt-2.7b". + device : str | torch.device | None, optional + Device for inference. Auto-detected if None. + use_outlines : bool, optional + Use outlines for constrained decoding. Default is False. + model_load_kwargs : dict, optional + Kwargs passed to from_pretrained (e.g. torch_dtype, attn_implementation). + **kwargs : Any + Additional arguments passed to model.generate. 
+ """ + + def __init__( + self, + model_name: str = "Salesforce/blip2-opt-2.7b", + device: Optional[str | torch.device] = None, + use_outlines: bool = False, + model_load_kwargs: Optional[dict] = None, + **kwargs: Any, + ) -> None: + self.model_name = model_name + self.use_outlines = use_outlines + self.model_load_kwargs = model_load_kwargs or {} + if device is None: + if torch.cuda.is_available(): + self.device = torch.device("cuda") + elif torch.backends.mps.is_available(): + self.device = torch.device("mps") + else: + self.device = torch.device("cpu") + else: + self.device = torch.device(device) + self.extra_kwargs = kwargs + self._model = None + self._processor = None + self._outlines_model = None + + def _load_model(self) -> None: + if self._model is not None: + return + try: + from transformers import AutoModelForImageTextToText, AutoProcessor + except ImportError: + pruna_logger.error("transformers not installed. Install with: pip install transformers") + raise + pruna_logger.info(f"Loading VLM model: {self.model_name}") + self._processor = AutoProcessor.from_pretrained(self.model_name) + self._model = AutoModelForImageTextToText.from_pretrained(self.model_name, **self.model_load_kwargs) + device = self.device + self._model.to(device) # type: ignore[invalid-argument-type] + self._model.eval() + + def _load_outlines_model(self) -> None: + """Lazily wrap the loaded multimodal model for Outlines structured generation.""" + if self._outlines_model is not None: + return + try: + import outlines + except ImportError: + pruna_logger.warning("outlines not installed, using standard generation") + return + self._load_model() + if self._model is None or self._processor is None: + pruna_logger.warning("VLM model or processor failed to load, using standard generation") + return + self._outlines_model = outlines.from_transformers(self._model, self._processor) + + def _get_outlines_output_type( + self, + response_format: Optional[Union[Type[BaseModel], Literal["integer"], 
Literal["yes_no"], Literal["json"]]], + ) -> Any: + """Map current response formats to an Outlines-compatible output type.""" + if response_format is None: + return None + if isinstance(response_format, type) and issubclass(response_format, BaseModel): + return response_format + if response_format == "integer": + return int + if response_format == "yes_no": + return Literal["Yes", "No"] + if response_format == "json": + return dict + return None + + @staticmethod + def _serialize_outlines_result(result: Any) -> str: + """Normalize Outlines results so the existing response parsers still work.""" + if isinstance(result, BaseModel): + return result.model_dump_json() + if isinstance(result, (dict, list)): + return json.dumps(result) + return str(result) + + @staticmethod + def _to_outlines_input(image: Image.Image, prompt: str) -> list[Any]: + """Build a minimal multimodal input payload for Outlines.""" + from outlines.inputs import Image as OutlinesImage + + return [prompt, OutlinesImage(image)] + + def generate( + self, + images: List[Image.Image], + prompts: List[str], + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[str]: + """ + Generate responses using local VLM. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + prompts : List[str] + List of text prompts. + response_format : Type[BaseModel] | str | None + Format constraint for outlines ("integer", "yes_no") or None. + **kwargs : Any + Additional arguments passed to model generate. + + Returns + ------- + List[str] + Generated responses. 
+ """ + self._load_model() + max_new_tokens, gen_kwargs = self._prepare_transformers_generation_args(kwargs) + return self._run_structured_or_standard_generation(images, prompts, response_format, max_new_tokens, gen_kwargs) + + @staticmethod + def _prepare_transformers_generation_args(kwargs: dict[str, Any]) -> tuple[int, dict[str, Any]]: + max_new_tokens = kwargs.get("max_new_tokens", 128) + gen_kwargs = {k: v for k, v in kwargs.items() if k != "max_new_tokens"} + return max_new_tokens, gen_kwargs + + def _run_structured_or_standard_generation( + self, + images: List[Image.Image], + prompts: List[str], + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]], + max_new_tokens: int, + gen_kwargs: dict[str, Any], + ) -> List[str]: + if self.use_outlines and response_format is not None: + return self._generate_with_outlines(images, prompts, response_format, max_new_tokens) + return self._generate_standard(images, prompts, max_new_tokens, **gen_kwargs) + + def _generate_with_outlines( + self, + images: List[Image.Image], + prompts: List[str], + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]], + max_new_tokens: int, + ) -> List[str]: + """Generate using outlines for constrained decoding.""" + self._load_outlines_model() + if self._outlines_model is None: + return self._generate_standard(images, prompts, max_new_tokens) + output_type = self._get_outlines_output_type(response_format) + if output_type is None: + return self._generate_standard(images, prompts, max_new_tokens) + results = [] + for image, prompt in zip(images, prompts): + try: + model_input = self._to_outlines_input(image, prompt) + output = self._outlines_model(model_input, output_type, max_new_tokens=max_new_tokens) + results.append(self._serialize_outlines_result(output)) + except Exception as e: + pruna_logger.warning(f"Outlines generation failed: {e}, using standard") + 
results.extend(self._generate_standard([image], [prompt], max_new_tokens)) + return results + + def _prepare_inputs(self, image: Image.Image, prompt: str) -> dict: + """Prepare model inputs, supporting both BLIP-style and chat-template processors.""" + try: + inputs = self._processor(images=[image], text=prompt, return_tensors="pt") + except (ValueError, TypeError): + conversation = [ + {"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]} + ] + inputs = self._processor.apply_chat_template( + conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ) + return {k: v.to(self.device) for k, v in inputs.items()} + + def _decode_output(self, output_ids: torch.Tensor) -> str: + """Decode model output to text.""" + if hasattr(self._processor, "batch_decode"): + return self._processor.batch_decode([output_ids], skip_special_tokens=True)[0] + return self._processor.decode(output_ids, skip_special_tokens=True) + + def _generate_standard( + self, + images: List[Image.Image], + prompts: List[str], + max_new_tokens: int, + **kwargs: Any, + ) -> List[str]: + """Standard generation without outlines.""" + results = [] + with torch.inference_mode(): + for image, prompt in zip(images, prompts): + inputs = self._prepare_inputs(image, prompt) + output = self._model.generate(**inputs, max_new_tokens=max_new_tokens, **self.extra_kwargs, **kwargs) + response = self._decode_output(output[0]) + results.append(response) + return results + + def score( + self, + images: List[Image.Image], + questions: List[str], + answers: List[str], + use_probability: bool = False, + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[float]: + """ + Score how well answers match images for given questions. + + use_probability is not supported for TransformersVLM; uses binary 0/1. 
+ When response_format is set, uses structured generation and extracts the answer field. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + questions : List[str] + List of questions. + answers : List[str] + List of expected answers. + use_probability : bool, optional + Ignored; TransformersVLM always uses binary 0/1. + response_format : Type[BaseModel] | str | None, optional + Structured output format for answer extraction. + **kwargs : Any + Additional arguments passed to generate. + + Returns + ------- + List[float] + Scores for each image-question pair (0 or 1). + """ + from pruna.evaluation.metrics.metric_vlm_utils import get_answer_from_response + + scores = [] + for image, question, answer in zip(images, questions, answers): + prompt = f"{question} Please answer yes or no." + responses = self.generate([image], [prompt], response_format=response_format, **kwargs) + raw = responses[0] if responses else "" + response_answer = get_answer_from_response(raw) if response_format is not None else raw.lower() + score = 1.0 if answer.lower() in response_answer.lower() else 0.0 + scores.append(score) + return scores diff --git a/tests/evaluation/test_task.py b/tests/evaluation/test_task.py index 281c6b7e..855a4ef5 100644 --- a/tests/evaluation/test_task.py +++ b/tests/evaluation/test_task.py @@ -1,21 +1,29 @@ -import pytest +from functools import partial from unittest.mock import patch + +import pytest +from torchmetrics.classification import Accuracy, Precision, Recall from transformers import AutoTokenizer -from pruna.evaluation.task import Task + from pruna.data.pruna_datamodule import PrunaDataModule -from pruna.evaluation.metrics.registry import MetricRegistry -from pruna.data import base_datasets -from pruna.evaluation.metrics.metric_torch import TorchMetrics -from torchmetrics.classification import Accuracy, Precision, Recall -from functools import partial +from pruna.engine.utils import device_to_string, split_device from 
pruna.evaluation.metrics.metric_base import BaseMetric -from pruna.evaluation.metrics.metric_stateful import StatefulMetric -from pruna.engine.utils import split_device, device_to_string -from ..common import device_parametrized -from pruna.evaluation.metrics.metric_elapsed_time import LatencyMetric from pruna.evaluation.metrics.metric_cmmd import CMMD +from pruna.evaluation.metrics.metric_elapsed_time import LatencyMetric from pruna.evaluation.metrics.metric_pairwise_clip import PairwiseClipScore -from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.metric_torch import TorchMetrics, TorchMetricWrapper +from pruna.evaluation.metrics.metric_vqa import VQAMetric +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.task import Task + +from ..common import device_parametrized + + +def _require(condition: bool, message: str = "test condition failed") -> None: + if not condition: + pytest.fail(message) + @pytest.fixture(autouse=True) def _mock_torch_metrics(): @@ -36,23 +44,53 @@ def make_mock_metric(metric_class): with patch.object(TorchMetrics, '_member_map_', {**TorchMetrics._member_map_, **mock_metrics}): yield + @pytest.mark.parametrize("metric_name", MetricRegistry()._registry) def test_metric_initialization_from_metric_name(metric_name): + """All registered metric names should instantiate through Task.""" datamodule = PrunaDataModule.from_string("LAION256") - Task(request=[metric_name], datamodule=datamodule) + Task(request=[metric_name], datamodule=datamodule, device="cpu") + + +@patch("pruna.evaluation.task.set_to_best_available_device") +def test_vlm_metrics_fallback_to_cpu_on_auto_device(mock_set_to_best_available_device): + """VLM metrics should stay on CPU when task auto-selects CUDA.""" + def fake_best_device(device=None, *args, **kwargs): + if device is None: + return "cuda" + return device + + 
mock_set_to_best_available_device.side_effect = fake_best_device + + task = Task(request=["vqa"], datamodule=PrunaDataModule.from_string("PartiPrompts")) + + _require(split_device(device_to_string(task.device))[0] == "cuda") + _require(isinstance(task.metrics[0], VQAMetric)) + _require(split_device(device_to_string(task.metrics[0].device))[0] == "cpu") @device_parametrized -def test_device_is_set_correctly_for_metrics(device:str): - task = Task(request=['latency', 'cmmd', 'pairwise_clip_score'], datamodule=PrunaDataModule.from_string("LAION256"), device = device) - assert split_device(device_to_string(task.device)) == split_device(device_to_string(device)) +def test_device_is_set_correctly_for_metrics(device: str): + """Task and metric devices should align with the requested device.""" + task = Task( + request=["latency", "cmmd", "pairwise_clip_score"], + datamodule=PrunaDataModule.from_string("LAION256"), + device=device, + ) + _require(split_device(device_to_string(task.device)) == split_device(device_to_string(device))) for metric in task.metrics: if isinstance(metric, BaseMetric): - assert split_device(device_to_string(metric.device)) == split_device(device_to_string(device)) + _require(split_device(device_to_string(metric.device)) == split_device(device_to_string(device))) elif isinstance(metric, StatefulMetric): - if hasattr(metric, 'metric'): - assert split_device(device_to_string(metric.metric.device)) == split_device(device_to_string(task.stateful_metric_device)) - assert split_device(device_to_string(metric.device)) == split_device(device_to_string(task.stateful_metric_device)) + if hasattr(metric, "metric"): + _require( + split_device(device_to_string(metric.metric.device)) + == split_device(device_to_string(task.stateful_metric_device)) + ) + _require( + split_device(device_to_string(metric.device)) + == split_device(device_to_string(task.stateful_metric_device)) + ) @pytest.mark.cuda @@ -82,42 +120,58 @@ def 
test_device_is_set_correctly_for_metrics(device:str): ], ) def test_metric_device_adapts_to_task_device(inference_device: str, stateful_metric_device: str, task_device: str): - """ Test that the metrics in the task are moved to the task device if they are on a different device.""" + """Test that the metrics in the task are moved to the task device if they are on a different device.""" latency = LatencyMetric(device=inference_device) cmmd = CMMD(device=stateful_metric_device) pairwise_clip_score = PairwiseClipScore(device=stateful_metric_device) - psnr = TorchMetricWrapper('psnr', device=stateful_metric_device) + psnr = TorchMetricWrapper("psnr", device=stateful_metric_device) - task = Task(request=[latency, cmmd, pairwise_clip_score, psnr], datamodule=PrunaDataModule.from_string("LAION256"), device = task_device) - assert split_device(device_to_string(task.device)) == split_device(device_to_string(task_device)) + task = Task( + request=[latency, cmmd, pairwise_clip_score, psnr], + datamodule=PrunaDataModule.from_string("LAION256"), + device=task_device, + ) + _require(split_device(device_to_string(task.device)) == split_device(device_to_string(task_device))) for metric in task.metrics: if isinstance(metric, BaseMetric): - assert split_device(device_to_string(metric.device)) == split_device(device_to_string(task.device)) + _require(split_device(device_to_string(metric.device)) == split_device(device_to_string(task.device))) elif isinstance(metric, StatefulMetric): if hasattr(metric, "device"): - assert split_device(device_to_string(metric.device)) == split_device(device_to_string(task.stateful_metric_device)) - if hasattr(metric, "metric") and hasattr(metric.metric, "device"): # Wrapper metric - assert split_device(device_to_string(metric.metric.device)) == split_device(device_to_string(task.stateful_metric_device)) + _require( + split_device(device_to_string(metric.device)) + == split_device(device_to_string(task.stateful_metric_device)) + ) + if hasattr(metric, 
"metric") and hasattr(metric.metric, "device"): # Wrapper metric + _require( + split_device(device_to_string(metric.metric.device)) + == split_device(device_to_string(task.stateful_metric_device)) + ) if not hasattr(metric, "device") and not hasattr(metric.metric, "device"): raise ValueError("Could not find device for metric.") + @pytest.mark.cpu def test_task_from_string_request(): + """Task should instantiate requested metric wrappers by name.""" request = ["cmmd", "pairwise_clip_score", "psnr"] - task = Task(request=request, datamodule=PrunaDataModule.from_string("LAION256"), device = "cpu") - assert isinstance(task.metrics[0], CMMD) - assert isinstance(task.metrics[1], PairwiseClipScore) - assert isinstance(task.metrics[2], TorchMetricWrapper) + task = Task(request=request, datamodule=PrunaDataModule.from_string("LAION256"), device="cpu") + _require(isinstance(task.metrics[0], CMMD)) + _require(isinstance(task.metrics[1], PairwiseClipScore)) + _require(isinstance(task.metrics[2], TorchMetricWrapper)) @pytest.mark.cpu def test_task_text_generation_quality_request(): """Test that 'text_generation_quality' named request creates perplexity metric.""" tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") - task = Task(request="text_generation_quality", datamodule=PrunaDataModule.from_string("TinyWikiText", tokenizer=tokenizer), device="cpu") - assert len(task.metrics) == 1 - assert isinstance(task.metrics[0], TorchMetricWrapper) - assert task.metrics[0].metric_name == "perplexity" + task = Task( + request="text_generation_quality", + datamodule=PrunaDataModule.from_string("TinyWikiText", tokenizer=tokenizer), + device="cpu", + ) + _require(len(task.metrics) == 1) + _require(isinstance(task.metrics[0], TorchMetricWrapper)) + _require(task.metrics[0].metric_name == "perplexity") @pytest.mark.cpu diff --git a/tests/evaluation/test_vlm_metrics.py b/tests/evaluation/test_vlm_metrics.py new file mode 100644 index 00000000..3f4f3d2b --- /dev/null +++ 
"""Tests for VLM metrics (VQA, AlignmentScore, ImageEditScore, QAAccuracy, TextScore, VieScore)."""

import os
from unittest.mock import MagicMock, patch

import pytest
import torch
from datasets import Dataset
from pydantic import BaseModel

from pruna.data import base_datasets
from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.metrics.metric_alignment_score import AlignmentScoreMetric
from pruna.evaluation.metrics.metric_img_edit_score import ImageEditScoreMetric
from pruna.evaluation.metrics.metric_qa_accuracy import QAAccuracyMetric
from pruna.evaluation.metrics.metric_text_score import TextScoreMetric
from pruna.evaluation.metrics.metric_viescore import VieScoreMetric
from pruna.evaluation.metrics.metric_vlm_utils import VQAnswer, get_answer_from_response
from pruna.evaluation.metrics.metric_vqa import VQAMetric
from pruna.evaluation.metrics.vlm_base import BaseVLM, TransformersVLM, get_vlm
from pruna.evaluation.task import Task

# Small local model used by the (slow) transformers-backend tests.
SMOL_VLM = "HuggingFaceTB/SmolVLM-256M-Instruct"


def _require(condition: bool, message: str = "test condition failed") -> None:
    """Fail the current test with *message* when *condition* is falsy."""
    if not condition:
        pytest.fail(message)


def _dummy_image(batch: int = 1, size: int = 224) -> torch.Tensor:
    """Return a random float image tensor of shape (batch, 3, size, size)."""
    return torch.rand(batch, 3, size, size)


def _prompt_benchmark_datamodule(records: list[dict]) -> PrunaDataModule:
    """Build a tiny PrunaDataModule (same split for train/val/test) from raw records."""
    dataset = Dataset.from_list(records)
    return PrunaDataModule.from_datasets((dataset, dataset, dataset), "prompt_with_auxiliaries_collate")


def _update_metric(metric: object, prompts: list, images: torch.Tensor) -> None:
    """Update metric with appropriate gt type per metric contract."""
    if isinstance(metric, QAAccuracyMetric):
        # QA accuracy expects a list of auxiliary dicts carrying per-sample questions.
        metric.update(prompts, [{"questions": ["Is there a cat?"]}], images)
    elif isinstance(metric, TextScoreMetric):
        # Text score expects the expected rendered text as a list of strings.
        metric.update(prompts, ["cat"], images)
    else:
        # Remaining metrics take (prompts, gt_images, generated_images).
        metric.update(prompts, images, images)


@pytest.mark.cpu
@pytest.mark.slow
@pytest.mark.parametrize(
    "metric_cls",
    [
        VQAMetric,
        AlignmentScoreMetric,
        ImageEditScoreMetric,
        QAAccuracyMetric,
        TextScoreMetric,
        VieScoreMetric,
    ],
)
@pytest.mark.parametrize("structured_output", [False, True])
def test_vlm_metrics_transformers_smolvlm(metric_cls: type, structured_output: bool) -> None:
    """Test each VLM metric with local SmolVLM-256M-Instruct."""
    metric = metric_cls(
        vlm_type="transformers",
        model_name=SMOL_VLM,
        device="cpu",
        structured_output=structured_output,
    )
    images = _dummy_image(batch=1)
    prompts = ["a cat"]
    _update_metric(metric, prompts, images)
    result = metric.compute()
    _require(result.name == metric.metric_name)
    _require(isinstance(result.result, float))
    if metric.higher_is_better:
        # Higher-is-better metrics are normalized into [0, 1].
        _require(0.0 <= result.result <= 1.0)
    else:
        _require(result.result >= 0.0)


@pytest.mark.cpu
@pytest.mark.parametrize(
    "metric_cls",
    [
        VQAMetric,
        AlignmentScoreMetric,
        ImageEditScoreMetric,
        QAAccuracyMetric,
        TextScoreMetric,
        VieScoreMetric,
    ],
)
@pytest.mark.parametrize("structured_output", [False, True])
def test_vlm_metrics_litellm_mocked(metric_cls: type, structured_output: bool) -> None:
    """Test each VLM metric with mocked litellm API (requires litellm installed)."""
    pytest.importorskip("litellm")
    mock_response = MagicMock()
    mock_response.choices = [MagicMock()]
    # Answer-style metrics expect a yes/no answer; score-style metrics expect a numeric score.
    if metric_cls in (AlignmentScoreMetric, VQAMetric, QAAccuracyMetric):
        mock_response.choices[0].message.content = (
            '{"answer": "Yes"}' if structured_output else "Yes"
        )
    else:
        mock_response.choices[0].message.content = (
            '{"score": 8}' if structured_output else "8"
        )

    with patch("litellm.completion") as mock_completion:
        mock_completion.return_value = mock_response

        metric = metric_cls(
            vlm_type="litellm",
            model_name="gpt-4o",
            device="cpu",
            structured_output=structured_output,
        )
        images = _dummy_image(batch=1)
        prompts = ["a cat"]
        _update_metric(metric, prompts, images)
        result = metric.compute()

    _require(result.name == metric.metric_name)
    _require(isinstance(result.result, float))
    _require(mock_completion.called)


@pytest.mark.cpu
@pytest.mark.parametrize(
    "metric_cls",
    [
        VQAMetric,
        AlignmentScoreMetric,
        ImageEditScoreMetric,
        QAAccuracyMetric,
        TextScoreMetric,
        VieScoreMetric,
    ],
)
@pytest.mark.parametrize("structured_output", [False, True])
def test_vlm_metrics_empty_score(metric_cls: type, structured_output: bool) -> None:
    """Test that empty compute returns 0.0."""
    metric = metric_cls(
        vlm_type="transformers",
        model_name=SMOL_VLM,
        device="cpu",
        structured_output=structured_output,
    )
    # compute() without any update() must degrade gracefully to 0.0.
    result = metric.compute()
    _require(result.result == 0.0)


@pytest.mark.cpu
@pytest.mark.parametrize("structured_output", [False, True])
def test_vlm_metrics_custom_vlm(structured_output: bool) -> None:
    """Test metrics with a custom VLM instance."""
    mock_vlm = MagicMock(spec=BaseVLM)
    mock_vlm.generate.return_value = ["Yes"]
    mock_vlm.score.return_value = [1.0]

    metric = VQAMetric(
        vlm=mock_vlm, vlm_type="litellm", device="cpu", structured_output=structured_output
    )
    images = _dummy_image(batch=1)
    prompts = ["a cat"]
    metric.update(prompts, images, images)
    result = metric.compute()

    _require(result.result == 1.0)
    mock_vlm.score.assert_called()


@pytest.mark.cpu
def test_get_vlm_returns_custom() -> None:
    """Test get_vlm returns provided vlm as-is."""
    custom = MagicMock(spec=BaseVLM)
    out = get_vlm(vlm=custom, vlm_type="litellm", model_name="gpt-4o")
    _require(out is custom)


@pytest.mark.cpu
def test_vlm_metric_defaults_enable_structured_local_generation() -> None:
    """Transformers-backed VLM metrics should default to outlines-based structured generation."""
    metric = VQAMetric(vlm_type="transformers", model_name=SMOL_VLM)
    _require(metric.vlm.use_outlines is True)
    _require(metric.device == "cpu")


@pytest.mark.cpu
def test_transformers_generate_routes_pydantic_response_format_to_outlines() -> None:
    """Structured Pydantic responses should use the outlines path for transformers backends."""
    vlm = TransformersVLM(model_name=SMOL_VLM, device="cpu", use_outlines=True)

    with (
        patch.object(vlm, "_load_model") as mock_load_model,
        patch.object(vlm, "_generate_with_outlines", return_value=['{"answer":"Yes"}']) as mock_outlines,
        patch.object(vlm, "_generate_standard", return_value=["fallback"]) as mock_standard,
    ):
        result = vlm.generate([MagicMock()], ["question"], response_format=VQAnswer)

    mock_load_model.assert_called_once()
    mock_outlines.assert_called_once()
    # The standard (unstructured) path must never be taken for Pydantic response formats.
    mock_standard.assert_not_called()
    _require(result == ['{"answer":"Yes"}'])


@pytest.mark.cpu
def test_transformers_outlines_result_serialization() -> None:
    """Outlines outputs should be normalized into strings parseable by existing helpers."""

    class DummySchema(BaseModel):
        answer: str

    schema_result = TransformersVLM._serialize_outlines_result(DummySchema(answer="Yes"))
    dict_result = TransformersVLM._serialize_outlines_result({"answer": "No"})

    _require(get_answer_from_response(schema_result) == "Yes")
    _require(get_answer_from_response(dict_result) == "No")


@pytest.mark.cpu
def test_evaluation_agent_update_stateful_metrics_with_stub_vlm() -> None:
    """Smoke-test the real agent stateful update path with a stub VLM-backed metric."""
    stub_vlm = MagicMock(spec=BaseVLM)
    stub_vlm.score.return_value = [1.0]
    metric = VQAMetric(vlm=stub_vlm, vlm_type="litellm", device="cpu")
    task = Task(request=[metric], datamodule=PrunaDataModule.from_string("LAION256"), device="cpu")
    agent = EvaluationAgent(task=task)
    # Replace the real dataloader with a single (prompts, empty-gt) batch.
    agent.task.dataloader = [(["a cat"], torch.empty(0))]
    agent.device = "cpu"
    agent.device_map = None

    class FakeModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # One parameter so the module has a device/dtype for the agent to inspect.
            self.dummy = torch.nn.Parameter(torch.zeros(1))

        def run_inference(self, batch):
            return _dummy_image(batch=1)

    agent.update_stateful_metrics(FakeModel(), agent.task.get_single_stateful_metrics(), [])
    results = agent.compute_stateful_metrics(agent.task.get_single_stateful_metrics(), [])

    _require(len(results) == 1)
    _require(results[0].name == "vqa")
    _require(results[0].result == 1.0)
    stub_vlm.score.assert_called_once()


@pytest.mark.cpu
@pytest.mark.parametrize(
    ("benchmark_name", "dataset_key", "records", "module_patch", "score_value", "expected_name"),
    [
        (
            "GenAI Bench",
            "GenAIBench",
            [{"text": "a cat"}],
            "pruna.evaluation.metrics.metric_vqa.get_vlm",
            [1.0],
            "vqa",
        ),
        (
            "GenEval",
            "GenEval",
            [{"text": "a cat", "questions": ["Is there a cat?"]}],
            "pruna.evaluation.metrics.metric_qa_accuracy.get_vlm",
            [1.0],
            "qa_accuracy",
        ),
        (
            "Long Text Bench",
            "LongTextBench",
            [{"text": "draw text", "text_content": "HELLO"}],
            "pruna.evaluation.metrics.metric_text_score.get_vlm",
            ["HELLO"],
            "text_score",
        ),
        (
            "GEditBench",
            "GEditBench",
            [{"text": "add a hat", "category": "subject_add"}],
            "pruna.evaluation.metrics.metric_viescore.get_vlm",
            ['{"score": 8}'],
            "viescore",
        ),
        (
            "OneIG",
            "OneIG",
            [{"text": "a cat", "questions": {"q1": "Is there a cat?"}, "category": "General_Object"}],
            "pruna.evaluation.metrics.metric_qa_accuracy.get_vlm",
            [1.0],
            "qa_accuracy",
        ),
        (
            "DPG",
            "DPG",
            [{"text": "a cat", "questions": ["Is there a cat?"], "category": "entity"}],
            "pruna.evaluation.metrics.metric_qa_accuracy.get_vlm",
            [1.0],
            "qa_accuracy",
        ),
        (
            "ImgEdit",
            "ImgEdit",
            [{"text": "make it blue", "category": "adjust", "judge_prompt": "score the edit"}],
            "pruna.evaluation.metrics.metric_img_edit_score.get_vlm",
            ['{"score": 8}'],
            "img_edit_score",
        ),
    ],
)
def test_benchmark_vlm_metrics_end_to_end(
    monkeypatch,
    benchmark_name: str,
    dataset_key: str,
    records: list[dict],
    module_patch: str,
    score_value,
    expected_name: str,
) -> None:
    """Benchmark wiring should exercise VLM metrics end to end with benchmark auxiliaries."""
    datamodule = _prompt_benchmark_datamodule(records)
    # Swap the registered dataset factory for our tiny in-memory datamodule.
    # The lambda binds `datamodule` as a default to avoid late-binding surprises.
    monkeypatch.setitem(
        base_datasets,
        dataset_key,
        (
            lambda dm=datamodule: (dm.train_dataset, dm.val_dataset, dm.test_dataset),
            "prompt_with_auxiliaries_collate",
            {},
        ),
    )

    stub_vlm = MagicMock(spec=BaseVLM)
    if expected_name in {"vqa", "qa_accuracy"}:
        stub_vlm.score.return_value = score_value
    else:
        stub_vlm.generate.return_value = score_value

    # get_vlm is patched only during construction; the metric keeps the stub afterwards.
    with patch(module_patch, return_value=stub_vlm):
        agent = EvaluationAgent.from_benchmark(benchmark_name, device="cpu")

    class FakeModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.dummy = torch.nn.Parameter(torch.zeros(1))

        def run_inference(self, batch):
            return _dummy_image(batch=1)

    agent.device = "cpu"
    agent.device_map = None
    agent.update_stateful_metrics(FakeModel(), agent.task.get_single_stateful_metrics(), [])
    results = agent.compute_stateful_metrics(agent.task.get_single_stateful_metrics(), [])

    _require(len(results) == 1)
    _require(results[0].name == expected_name)
    _require(isinstance(results[0].result, float))
    if expected_name == "text_score":
        # Stubbed generation exactly matches the gt text -> zero edit distance.
        _require(results[0].result == 0.0)
    else:
        _require(results[0].result > 0.0)


@pytest.mark.cpu
def test_text_score_with_list_str_gt() -> None:
    """Test TextScoreMetric accepts List[str] ground truth from text_score_collate."""
    mock_vlm = MagicMock(spec=BaseVLM)
    mock_vlm.generate.return_value = ["hello world"]

    metric = TextScoreMetric(vlm=mock_vlm, vlm_type="litellm", device="cpu")
    images = _dummy_image(batch=1)
    metric.update(["a prompt"], ["hello world"], images)
    result = metric.compute()

    _require(result.result == 0.0)
    mock_vlm.generate.assert_called_once()


@pytest.mark.cpu
@pytest.mark.integration
# NOTE: an unconditional @pytest.mark.skip could not be re-enabled by `-m integration`
# (marker selection does not override skip), so gate on the API key instead.
@pytest.mark.skipif(
    not os.getenv("OPENAI_API_KEY"),
    reason="OPENAI_API_KEY not set; export it and run: pytest -m integration",
)
@pytest.mark.parametrize("structured_output", [False, True])
def test_vlm_metrics_litellm_api(structured_output: bool) -> None:
    """Integration test with real litellm API (requires OPENAI_API_KEY)."""
    metric = VQAMetric(
        vlm_type="litellm",
        model_name="gpt-4o",
        device="cpu",
        structured_output=structured_output,
    )
    images = _dummy_image(batch=1)
    prompts = ["a cat"]
    metric.update(prompts, images, images)
    result = metric.compute()
    _require(0.0 <= result.result <= 1.0)