diff --git a/.github/actions/setup-uv-project/action.yml b/.github/actions/setup-uv-project/action.yml index 74f1ac9b..a5f9a7fa 100644 --- a/.github/actions/setup-uv-project/action.yml +++ b/.github/actions/setup-uv-project/action.yml @@ -12,4 +12,4 @@ runs: github-token: ${{ github.token }} - shell: bash - run: uv sync --extra dev --extra lmharness --extra vllm + run: uv sync --extra dev --extra vllm --extra evaluation diff --git a/pyproject.toml b/pyproject.toml index 5b1eb704..b22eebb7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -191,6 +191,9 @@ full = [ vbench = [ "vbench-pruna; sys_platform != 'darwin'", ] +rapidata = [ + "rapidata>=3.0.0" +] dev = [ "wget", "python-dotenv", @@ -222,6 +225,10 @@ cpu = [] lmharness = [ "lm-eval>=0.4.0" ] +evaluation = [ + "pruna[rapidata]", + "pruna[lmharness]" +] intel = [ "intel-extension-for-pytorch>=2.7.0", ] diff --git a/src/pruna/evaluation/evaluation_agent.py b/src/pruna/evaluation/evaluation_agent.py index 5b713dea..674bf962 100644 --- a/src/pruna/evaluation/evaluation_agent.py +++ b/src/pruna/evaluation/evaluation_agent.py @@ -26,9 +26,10 @@ from pruna.data.utils import move_batch_to_device from pruna.engine.pruna_model import PrunaModel from pruna.engine.utils import get_device, move_to_device, safe_memory_cleanup, set_to_best_available_device +from pruna.evaluation.metrics.context_mixin import EvaluationContextMixin from pruna.evaluation.metrics.metric_base import BaseMetric from pruna.evaluation.metrics.metric_stateful import StatefulMetric -from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.result import MetricResult, MetricResultProtocol from pruna.evaluation.metrics.utils import ensure_device_consistency, get_device_map, group_metrics_by_inheritance from pruna.evaluation.task import Task from pruna.logging.logger import pruna_logger @@ -71,8 +72,8 @@ def __init__( raise ValueError("When not using 'task' parameter, both 'request' and 'datamodule' must be provided.") 
self.task = Task(request=request, datamodule=datamodule, device=device) - self.first_model_results: List[MetricResult] = [] - self.subsequent_model_results: List[MetricResult] = [] + self.first_model_results: List[MetricResultProtocol] = [] + self.subsequent_model_results: List[MetricResultProtocol] = [] self.device = set_to_best_available_device(self.task.device) self.cache: List[Tensor] = [] self.evaluation_for_first_model: bool = True @@ -124,18 +125,20 @@ def from_benchmark( ) return cls(task=task) - def evaluate(self, model: Any) -> List[MetricResult]: + def evaluate(self, model: Any, model_name: str | None = None) -> List[MetricResultProtocol]: """ Evaluate models using different metric types. Parameters ---------- - model : PrunaModel + model : Any The model to evaluate. + model_name : str | None, optional + The name of the model to evaluate. Required for rapidata benchmark submission. Returns ------- - List[MetricResult] + List[MetricResultProtocol] The results of the model. """ results = [] @@ -146,6 +149,10 @@ def evaluate(self, model: Any) -> List[MetricResult]: pairwise_metrics = self.task.get_pairwise_stateful_metrics() stateless_metrics = self.task.get_stateless_metrics() + for metric in single_stateful_metrics: + if isinstance(metric, EvaluationContextMixin): + metric.current_context = model_name + # Update and compute stateful metrics. pruna_logger.info("Evaluating stateful metrics.") with torch.no_grad(): @@ -278,7 +285,7 @@ def update_stateful_metrics( def compute_stateful_metrics( self, single_stateful_metrics: List[StatefulMetric], pairwise_metrics: List[StatefulMetric] - ) -> List[MetricResult]: + ) -> List[MetricResultProtocol]: """ Compute stateful metrics. 
@@ -296,16 +303,20 @@ def compute_stateful_metrics( """ results = [] for stateful_metric in single_stateful_metrics: - results.append(stateful_metric.compute()) + result = stateful_metric.compute() + if result is not None: + results.append(result) stateful_metric.reset() if not self.evaluation_for_first_model and self.task.is_pairwise_evaluation(): for pairwise_metric in pairwise_metrics: - results.append(pairwise_metric.compute()) + result = pairwise_metric.compute() + if result is not None: + results.append(result) pairwise_metric.reset() return results - def compute_stateless_metrics(self, model: PrunaModel, stateless_metrics: List[Any]) -> List[MetricResult]: + def compute_stateless_metrics(self, model: PrunaModel, stateless_metrics: List[Any]) -> List[MetricResultProtocol]: """ Compute stateless metrics. diff --git a/src/pruna/evaluation/metrics/__init__.py b/src/pruna/evaluation/metrics/__init__.py index 1a12f623..357b2d48 100644 --- a/src/pruna/evaluation/metrics/__init__.py +++ b/src/pruna/evaluation/metrics/__init__.py @@ -26,6 +26,12 @@ from pruna.evaluation.metrics.metric_sharpness import SharpnessMetric from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper +try: + from pruna.evaluation.metrics.metric_rapiddata import RapidataMetric as RapidataMetric +except ModuleNotFoundError as e: + if e.name != "rapidata": + raise + __all__ = [ "MetricRegistry", "TorchMetricWrapper", @@ -46,3 +52,6 @@ "AestheticLAION", "LMEvalMetric", ] + +if "RapidataMetric" in globals(): + __all__.append("RapidataMetric") diff --git a/src/pruna/evaluation/metrics/async_mixin.py b/src/pruna/evaluation/metrics/async_mixin.py new file mode 100644 index 00000000..6ba137f3 --- /dev/null +++ b/src/pruna/evaluation/metrics/async_mixin.py @@ -0,0 +1,53 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from abc import ABC, abstractmethod +from typing import Any + + +class AsyncEvaluationMixin(ABC): + """ + Mixin for metrics that submit to external evaluation services and retrieve results asynchronously. + + Subclasses implement create_request() to set up an evaluation + (e.g., create a leaderboard) and retrieve_results() to retrieve + outcomes (e.g., standings from human evaluators). + """ + + @abstractmethod + def create_async_request(self, *args, **kwargs) -> Any: + """ + Create/configure an evaluation request on the external service. + + Parameters + ---------- + *args : + Variable length argument list. + **kwargs : + Arbitrary keyword arguments. + """ + + @abstractmethod + def retrieve_async_results(self, *args, **kwargs) -> Any: + """ + Retrieve results from the external service. + + Parameters + ---------- + *args : + Variable length argument list. + **kwargs : + Arbitrary keyword arguments. + """ diff --git a/src/pruna/evaluation/metrics/context_mixin.py b/src/pruna/evaluation/metrics/context_mixin.py new file mode 100644 index 00000000..ea4f5b2c --- /dev/null +++ b/src/pruna/evaluation/metrics/context_mixin.py @@ -0,0 +1,62 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from abc import ABC + + +class EvaluationContextMixin(ABC): + """ + Mixin for metrics that evaluate multiple models sequentially. + + Provides a current_context property that tracks which model is being + evaluated. Setting a new context triggers on_context_change(), which + subclasses can override to reset state between models. + """ + + _current_context: str | None = None + + @property + def current_context(self) -> str | None: + """ + Return the current context. + + Returns + ------- + str | None + The current context. + """ + return self._current_context + + @current_context.setter + def current_context(self, value: str | None) -> None: + """ + Set the current context. + + Parameters + ---------- + value : str + The new context. + """ + self._current_context = value + self.on_context_change() + + def on_context_change(self) -> None: + """Hook called when the context changes. Override to reset state.""" + pass + + def _require_context(self) -> None: + """Raise if no context has been set.""" + if self._current_context is None: + raise ValueError("No context set. Set current_context first.") diff --git a/src/pruna/evaluation/metrics/metric_rapiddata.py b/src/pruna/evaluation/metrics/metric_rapiddata.py new file mode 100644 index 00000000..d715373a --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_rapiddata.py @@ -0,0 +1,573 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import shutil +import tempfile +import time +from pathlib import Path +from typing import Any, Callable, List, Literal + +import PIL.Image +import torch +from rapidata import RapidataClient +from rapidata.rapidata_client.benchmark.rapidata_benchmark import RapidataBenchmark +from torch import Tensor +from torchvision.utils import save_image + +from pruna.data.pruna_datamodule import PrunaDataModule +from pruna.evaluation.metrics.async_mixin import AsyncEvaluationMixin +from pruna.evaluation.metrics.context_mixin import EvaluationContextMixin +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.result import CompositeMetricResult +from pruna.evaluation.metrics.utils import PAIRWISE, SINGLE, get_call_type_for_single_metric, metric_data_processor +from pruna.logging.logger import pruna_logger + +METRIC_RAPIDATA = "rapidata" + + +# We don't use the MetricRegistry here +# because we need to instantiate the Metric directly with benchmark and leaderboards. +class RapidataMetric(StatefulMetric, AsyncEvaluationMixin, EvaluationContextMixin): + """ + Evaluate models with human feedback via the Rapidata platform https://www.rapidata.ai/. + + Parameters + ---------- + call_type : str + How to extract inputs from (x, gt, outputs). Only "single" is supported. + client : RapidataClient | None + The Rapidata client to use. If None, a new one is created. + rapidata_client_id : str | None + The client ID of the Rapidata client. 
+ If none, the credentials are read from the environment variable RAPIDATA_CLIENT_ID. + If credentials are not found in the environment variable, you will be prompted to login via browser. + rapidata_client_secret : str | None + The client secret of the Rapidata client. + If none, the credentials are read from the environment variable RAPIDATA_CLIENT_SECRET. + If credentials are not found in the environment variable, you will be prompted to login via browser. + *args : + Additional arguments passed to StatefulMetric. + **kwargs : Any + Additional keyword arguments passed to StatefulMetric. + + Examples + -------- + Standalone usage:: + metric = RapidataMetric() + # OR metric = RapidataMetric.from_benchmark_id("69bc528fa858d3fbc1ea1475") + + metric.create_benchmark("my_bench", prompts) + metric.create_request("Quality", instruction="Which image looks better?") + + metric.set_current_context("model_a") + metric.update(prompts, ground_truths, outputs_a) + metric.compute() + + metric.set_current_context("model_b") + metric.update(prompts, ground_truths, outputs_b) + metric.compute() + + # wait for human votes + overall = metric.retrieve_results() + """ + + media_cache: List[torch.Tensor | PIL.Image.Image | str] + prompt_cache: List[str] + # With every metric higher is actually better, + # Because for negative questions like "Which image has more errors?" + # We create the leaderboard with inverse_ranking=True, which reverses the ranking. 
+ higher_is_better: bool = True + default_call_type: str = "x_y" + metric_name: str = METRIC_RAPIDATA + + def __init__( + self, + call_type: str = SINGLE, + client: RapidataClient | None = None, + rapidata_client_id: str | None = None, + rapidata_client_secret: str | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.client = client or RapidataClient( + client_id=rapidata_client_id, + client_secret=rapidata_client_secret, + ) + if call_type.startswith(PAIRWISE): + raise ValueError("RapidataMetric does not support pairwise metrics. Use a single metric instead.") + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + self.add_state("media_cache", default=[]) + self.add_state("prompt_cache", default=[]) + self.benchmark: RapidataBenchmark | None = None + self.current_context: str | None = None + + @classmethod + def from_rapidata_benchmark( + cls, + benchmark: RapidataBenchmark | str, + rapidata_client_id: str | None = None, + rapidata_client_secret: str | None = None + ) -> RapidataMetric: + """ + Create a RapidataMetric from an existing RapidataBenchmark. + + Parameters + ---------- + benchmark : RapidataBenchmark | str + The benchmark to attach to. Can be a RapidataBenchmark object or a string (benchmark ID). + rapidata_client_id : str | None + The client ID of the Rapidata client. + rapidata_client_secret : str | None + The client secret of the Rapidata client. + + Returns + ------- + RapidataMetric + The created metric. + """ + metric = cls( + rapidata_client_id=rapidata_client_id, + rapidata_client_secret=rapidata_client_secret, + ) + if isinstance(benchmark, RapidataBenchmark): + metric.benchmark = benchmark + elif isinstance(benchmark, str): + metric.benchmark = metric.client.mri.get_benchmark_by_id(benchmark) + else: + raise ValueError(f"Invalid benchmark: {benchmark}. 
Expected a RapidataBenchmark or a string.") + return metric + + def create_benchmark( + self, + name: str, + data: list[str] | PrunaDataModule, + data_assets: list[str] | None = None, + split: Literal["test", "val", "train"] = "test", + **kwargs, + ) -> None: + """ + Register a new benchmark on the Rapidata platform. + + The benchmark defines the prompt pool. Any data submitted to + leaderboards later must be drawn from this pool. + + Prompts can be provided as a list of strings or as a PrunaDataModule. + When using a list of strings, you can optionally pass data_assets as a list of file paths or URLs. + When using a PrunaDataModule, data assets are extracted automatically from the datamodule, if available. + + Parameters + ---------- + name : str + The name of the benchmark. + data : list[str] | PrunaDataModule + The prompts or dataset to benchmark against. + data_assets : list[str] | None + The assets to attach to the prompts. + For instance, if you wish to benchmark an image editing model, + you can pass the original images as data_assets. + split : str, optional + Which split to use when data is a PrunaDataModule. Default is "test". + **kwargs : Any + Additional keyword arguments passed to the Rapidata API. + """ + if self.benchmark is not None: + raise ValueError("Benchmark already created. Use from_benchmark() to attach to an existing one.") + + # Rapidata benchmarks only accept a list of string, + # so we need to convert the PrunaDataModule to a list of strings. + if isinstance(data, PrunaDataModule): + split_map = {"test": data.test_dataset, "val": data.val_dataset, "train": data.train_dataset} + dataset = split_map[split] + # PrunaDataModule dataset loaders always renames the prompts column to "text" + if hasattr(dataset, "column_names") and "text" in dataset.column_names: + data = list(dataset["text"]) + data_assets = None # When using a PrunaDataModule, we need to get the data assets from the datamodule. 
+ if "image" in dataset.column_names: + images = list(dataset["image"]) # Pruna text to image datasets always have an "image" column. + # Rapidata only accepts file paths or URLs, so we need to convert the images to file paths. + data_assets = self._prepare_media_for_upload(images) + else: + raise ValueError( + "Could not extract prompts from dataset.\n " + "Expected a 'text' column. Please use a suitable dataset from Pruna \ + or pass a list[str] directly instead." + ) + + self.benchmark = self.client.mri.create_new_benchmark(name, prompts=data, prompt_assets=data_assets, **kwargs) + + def create_async_request( + self, + name: str, + instruction: str, + show_prompt: bool = False, + show_prompt_assets: bool = False, + **kwargs, + ) -> None: + """ + Add a leaderboard (evaluation criterion) to the benchmark. + + Each leaderboard defines a single instruction that human raters see + when comparing model outputs (e.g. "Which image has higher quality?" + or "Which image is more aligned with the prompt?"). + + You can create multiple leaderboards to evaluate different quality dimensions. + Must be called after :meth:`create_benchmark` (or after attaching a + benchmark via :meth:`from_rapidata_benchmark`). + + Parameters + ---------- + name : str + The name of the leaderboard. + instruction : str + The evaluation instruction shown to human raters. + show_prompt : bool, optional + Whether to show the prompt to raters. Default is False. + show_prompt_assets : bool, optional + Whether to show the prompt assets to raters. Default is False. + **kwargs : Any + Additional keyword arguments passed to the Rapidata API. + """ + self._require_benchmark() + self.benchmark.create_leaderboard(name, instruction, show_prompt, show_prompt_assets, **kwargs) + + def set_current_context(self, model_name: str, **kwargs) -> None: + """ + Set which model is currently being evaluated. + + Call this before the :meth:`update` / :meth:`compute` cycle for each + model. 
At least two models must be submitted before meaningful + human comparison can begin. + + Parameters + ---------- + model_name : str + The name of the model to evaluate. + **kwargs : Any + Additional keyword arguments. + """ + self.current_context = model_name + self.reset() # Clear the cache for the new model. + + def update(self, x: List[Any] | Tensor, gt: List[Any] | Tensor, outputs: Any) -> None: + """ + Accumulate model outputs for the current model. + + Parameters + ---------- + x : List[Any] | Tensor + The input data (prompts). + gt : List[Any] | Tensor + The ground truth data. + outputs : Any + The model outputs (generated media). + """ + self._require_benchmark() + self._require_context() + inputs = metric_data_processor(x, gt, outputs, self.call_type) + self.prompt_cache.extend(inputs[0]) + self.media_cache.extend(inputs[1]) + + def compute(self) -> None: + """ + Submit the accumulated outputs for the current model to Rapidata. + + Converts cached media to uploadable file paths if necessary (saving tensors and + PIL images to a temporary directory), submits them to the benchmark, + and cleans up temporary files. + + This method does **not** return a result — human evaluation is + asynchronous. Use :meth:`retrieve_results` or + :meth:`retrieve_granular_results` once enough votes have been + collected. + """ + self._require_benchmark() + self._require_context() + if not self.media_cache: + raise ValueError("No data accumulated. Call update() before compute().") + + media = self._prepare_media_for_upload() + + # Ignoring the type error because _require_context() has already been called, but ty can't see it. 
+ self.benchmark.evaluate_model( + self.current_context, # type: ignore[arg-type] + media=media, + prompts=self.prompt_cache, + ) + + self._cleanup_temp_media() + + pruna_logger.warning( + "Sent evaluation request for model '%s' to Rapidata.\n " + "It may take a while to collect votes from human raters.\n " + "Use retrieve_results() to check scores later, " + "or monitor progress at: " + "https://app.rapidata.ai/mri/benchmarks/%s", + self.current_context, + self.benchmark.id, + ) + + def on_context_change(self) -> None: + """Reset the cache when the context changes.""" + self.reset() + + @staticmethod + def _is_not_ready_error(exc: Exception) -> bool: + """ + Search for a ValidationError in the exception chain. + + When the benchmark is not finished yet, the API throws a pydantic ValidationError + we are catching it and returning None to indicate that the benchmark is not ready yet, + rather than straight up failing with an exception. + """ + return "ValidationError" in type(exc).__name__ + + def _fetch_standings(self, api_call, *args, **kwargs): + """ + Barebones API call wrapper that catches ValidationError and returns None if the benchmark is not ready yet. + + Since the core logic between the overall and granular standings is the same, + we can use a single function to fetch the standings. + + Parameters + ---------- + api_call : callable + The API call to make. + *args : Any + Additional arguments passed to the API call. + **kwargs : Any + Additional keyword arguments passed to the API call. + """ + try: + return api_call(*args, **kwargs) + except Exception as e: + if not self._is_not_ready_error(e): + raise + return None + + def _fetch_overall_standings(self, *args, **kwargs) -> tuple[CompositeMetricResult | None, bool]: + """ + Retrieve overall standings for the benchmark. 
+ + Returns a tuple where the first element is the composite score of all leaderboards in the benchmark, + and the second element is a boolean indicating whether the benchmark is finished yet. + + Parameters + ---------- + *args : Any + Additional arguments passed to the Rapidata API. + **kwargs : Any + Additional keyword arguments passed to the Rapidata API. + + Returns + ------- + CompositeMetricResult | None + The overall standings or None if the benchmark is not finished yet. + """ + standings = self._fetch_standings(self.benchmark.get_overall_standings, *args, **kwargs) + if standings is None: + return None, False + return CompositeMetricResult( + name=self.metric_name, + params={}, + result=dict(zip(standings["name"], standings["score"])), + higher_is_better=self.higher_is_better, + ), True + + def _fetch_granular_standings(self, *args, **kwargs) -> tuple[List[CompositeMetricResult] | None, bool]: + """ + Retrieve standings for all leaderboards. + + Returns a tuple where the first element is a list of results, one per leaderboard, + and the second element is a boolean indicating whether all of the leaderboards (the benchmark) is finished yet. + + Parameters + ---------- + *args : Any + Additional arguments passed to the Rapidata API. + **kwargs : Any + Additional keyword arguments passed to the Rapidata API. + + Returns + ------- + List[CompositeMetricResult] | None + A list of results, one per leaderboard, or None if the benchmark is not finished yet. 
+ """ + results = [] + all_finished = True + for leaderboard in self.benchmark.leaderboards: + standings = self._fetch_standings(leaderboard.get_standings, *args, **kwargs) + if standings is None: + all_finished = False + continue + results.append(CompositeMetricResult( + name=leaderboard.name, + params={"instruction": leaderboard.instruction}, + result=dict(zip(standings["name"], standings["score"])), + higher_is_better=self.higher_is_better, + )) + return results, all_finished + + def _fetch_with_retry_option( + self, + fetch_fn: Callable, + is_blocking: bool, + timeout: float, + poll_interval: float, + *args, + **kwargs, + ) -> CompositeMetricResult | List[CompositeMetricResult] | None: + """ + Wait for the results or return whatever we have as is from the benchmark. + + If is_blocking is True, it will poll until the results are ready or the timeout is reached. + If is_blocking is False, it will return the results immediately if they are ready, + otherwise it will return None and log a warning. + + Parameters + ---------- + fetch_fn : callable + The function to fetch the standings from the benchmark. + is_blocking : bool + Whether to block and wait for the results to be ready. + timeout : float + The maximum time to wait for the results to be ready. + poll_interval : float + The interval in seconds to poll for the results. + *args : Any + Additional arguments passed to the Rapidata API. + **kwargs : Any + Additional keyword arguments passed to the Rapidata API. + """ + deadline = time.monotonic() + timeout + while True: + result, is_finished = fetch_fn(*args, **kwargs) + if is_finished: # The benchmark is finished, we don't need to check anything else, just return the result. + return result + if not is_blocking: # The benchmark is not finished yet, but the user doesn't want to keep on polling. + pruna_logger.warning( + "The benchmark hasn't finished yet. " + "Please wait for more votes and try again." + ) + return result # Return whatever we have as is. 
+ if time.monotonic() + poll_interval > deadline: # The timeout is reached, we raise an exception. + raise TimeoutError( + f"Benchmark results not ready after {timeout:.0f}s. " + f"Monitor at: https://app.rapidata.ai/mri/benchmarks/{self.benchmark.id}" + ) + pruna_logger.info("Results not ready yet, retrying in %ds...", poll_interval) + time.sleep(poll_interval) + + def retrieve_async_results( + self, + is_granular: bool = False, + is_blocking: bool = False, + timeout: float = 3600, + poll_interval: float = 30, + *args, + **kwargs, + ) -> List[CompositeMetricResult] | CompositeMetricResult | None: + """ + Retrieve standings from the benchmark. + + Parameters + ---------- + is_granular : bool, optional + If True, return per-leaderboard results (partial results + are returned for any leaderboard that is ready). + If False, return overall aggregated standings. + is_blocking : bool, optional + If True, poll until results are ready or *timeout* is reached. + timeout : float, optional + Maximum seconds to wait when blocking. Default is 3600. + poll_interval : float, optional + Seconds between polling attempts when blocking. Default is 30. + *args : Any + Additional arguments passed to the Rapidata API. + **kwargs : Any + Additional keyword arguments passed to the Rapidata API. + + Returns + ------- + List[CompositeMetricResult] | CompositeMetricResult | None + Granular returns a list (possibly partial), overall returns + a single result or None if not ready. + + Raises + ------ + TimeoutError + If *is_blocking* is True and results are not ready within *timeout*. + """ + self._require_benchmark() + fetch_fn = self._fetch_granular_standings if is_granular else self._fetch_overall_standings + return self._fetch_with_retry_option(fetch_fn, is_blocking, timeout, poll_interval, **kwargs) + + def _require_benchmark(self) -> None: + """Raise if no benchmark has been created or attached.""" + if self.benchmark is None: + raise ValueError( + "No benchmark configured. 
" + "Call create_benchmark(), or use from_benchmark() / from_benchmark_id()." + ) + + def _prepare_media_for_upload(self, media: list[torch.Tensor | PIL.Image.Image | str] | None = None) -> list[str]: + """ + Convert cached media to file paths that Rapidata can upload. + + Handles three cases: + - str: assumed to be a URL or file path, passed through as-is + - PIL.Image: saved to a temporary file + - torch.Tensor: saved to a temporary file + + Parameters + ---------- + media : list[torch.Tensor | PIL.Image.Image | str] | None + The media to prepare for upload. If None, the media cache is used. + + Returns + ------- + list[str] + A list of URLs or file paths. + """ + self._temp_dir = Path(tempfile.mkdtemp(prefix="rapidata_")) + media_paths = [] + + for i, item in enumerate(media or self.media_cache): + if isinstance(item, str): + media_paths.append(item) + elif isinstance(item, PIL.Image.Image): + path = self._temp_dir / f"{i}.png" + item.save(path) + media_paths.append(str(path)) + elif isinstance(item, torch.Tensor): + path = self._temp_dir / f"{i}.png" + tensor = item.float() + if tensor.max() > 1.0: + tensor = tensor / 255.0 + save_image(tensor, path) + media_paths.append(str(path)) + else: + raise TypeError( + f"Unsupported media type: {type(item)}. " + "Expected str (URL/path), PIL.Image, or torch.Tensor." 
+ ) + + return media_paths + + def _cleanup_temp_media(self) -> None: + """Remove temporary files created for upload.""" + if hasattr(self, "_temp_dir") and self._temp_dir.exists(): + shutil.rmtree(self._temp_dir) diff --git a/src/pruna/evaluation/metrics/metric_stateful.py b/src/pruna/evaluation/metrics/metric_stateful.py index 39fddcf6..02982980 100644 --- a/src/pruna/evaluation/metrics/metric_stateful.py +++ b/src/pruna/evaluation/metrics/metric_stateful.py @@ -122,7 +122,7 @@ def update(self, *args, **kwargs) -> None: """ @abstractmethod - def compute(self) -> Any: + def compute(self,) -> Any: """Override this method to compute the final metric value.""" def is_pairwise(self) -> bool: diff --git a/src/pruna/evaluation/metrics/result.py b/src/pruna/evaluation/metrics/result.py index f1e13ca8..93a9cd0e 100644 --- a/src/pruna/evaluation/metrics/result.py +++ b/src/pruna/evaluation/metrics/result.py @@ -14,13 +14,52 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Protocol, runtime_checkable + + +@runtime_checkable +class MetricResultProtocol(Protocol): + """ + Protocol defining the shared interface for all metric results. + + Any metric result class should implement these attributes and methods + to be compatible with the evaluation pipeline. + + # Have to include this to prevent ty errors. + + Parameters + ---------- + *args : + Additional arguments passed to the MetricResultProtocol. + **kwargs : + Additional keyword arguments passed to the MetricResultProtocol. + + Attributes + ---------- + name : str + The name of the metric. + params : Dict[str, Any] + The parameters of the metric. + higher_is_better : Optional[bool] + Whether larger values mean better performance. + metric_units : Optional[str] + The units of the metric. 
+ """ + + name: str + params: Dict[str, Any] + higher_is_better: Optional[bool] + metric_units: Optional[str] + + def __str__(self) -> str: + """Return a human-readable representation of the metric result.""" + ... @dataclass class MetricResult: """ - A class to store the results of a metric. + A class to store the result of a single-value metric. Parameters ---------- @@ -42,7 +81,7 @@ class MetricResult: higher_is_better: Optional[bool] = None metric_units: Optional[str] = None - def __post_init__(self): + def __post_init__(self) -> None: """Checker that metric_units and higher_is_better are consistent with the result.""" if self.metric_units is None: object.__setattr__(self, "metric_units", self.params.get("metric_units")) @@ -67,7 +106,7 @@ def from_results_dict( metric_name: str, metric_params: Dict[str, Any], results_dict: Dict[str, Any], - ) -> "MetricResult": + ) -> MetricResultProtocol: """ Create a MetricResult from a raw results dictionary. @@ -89,3 +128,50 @@ def from_results_dict( result = results_dict[metric_name] assert isinstance(result, (float, int)), f"Result for metric {metric_name} is not a float or int" return cls(metric_name, metric_params, result) + + +@dataclass +class CompositeMetricResult: + """ + A class to store the result of a metric that returns multiple labeled scores. + + This is used for metrics where a single evaluation request produces + scores for multiple entries, such as asynchronous metrics that + return labeled scores for different settings / models. + + Parameters + ---------- + name : str + The name of the metric. + params : Dict[str, Any] + The parameters of the metric. + result : Dict[str, float | int] + A mapping of labels to scores. + higher_is_better : Optional[bool] + Whether larger values mean better performance. + metric_units : Optional[str] + The units of the metric. 
+ """ + + name: str + params: Dict[str, Any] + result: Dict[str, float | int] + higher_is_better: Optional[bool] = None + metric_units: Optional[str] = None + + def __str__(self) -> str: + """ + Return a string representation of the CompositeMetricResult. + + Each labeled score is displayed on its own line. + + Returns + ------- + str + A string representation of the CompositeMetricResult. + """ + lines = [f"{self.name}:"] + for key, score in self.result.items(): + units = f" {self.metric_units}" if self.metric_units else "" + lines.append(f" {key}: {score}{units}") + return "\n".join(lines) diff --git a/tests/algorithms/testers/moe_kernel_tuner.py b/tests/algorithms/testers/moe_kernel_tuner.py index 9a754cf3..85661a83 100644 --- a/tests/algorithms/testers/moe_kernel_tuner.py +++ b/tests/algorithms/testers/moe_kernel_tuner.py @@ -34,7 +34,6 @@ def post_smash_hook(self, model: PrunaModel) -> None: def _resolve_hf_cache_config_path(self) -> Path: """Read the saved artifact and compute the expected HF cache config path.""" - imported_packages = MoeKernelTuner().import_algorithm_packages() smash_cfg = SmashConfig() diff --git a/tests/evaluation/test_rapidata.py b/tests/evaluation/test_rapidata.py new file mode 100644 index 00000000..ba736d3a --- /dev/null +++ b/tests/evaluation/test_rapidata.py @@ -0,0 +1,351 @@ +from pathlib import Path +from unittest.mock import MagicMock, patch + +import PIL.Image +import pytest +import torch +from datasets import Dataset + +from pruna.data.pruna_datamodule import PrunaDataModule +from pruna.evaluation.metrics.metric_rapiddata import RapidataMetric +from rapidata.rapidata_client.benchmark.rapidata_benchmark import RapidataBenchmark +from pruna.evaluation.metrics.result import CompositeMetricResult + + +@pytest.fixture +def mock_client(): + return MagicMock() + + +@pytest.fixture +def metric(mock_client): + return RapidataMetric(client=mock_client) + + +@pytest.fixture +def metric_with_benchmark(metric): + benchmark = MagicMock() + 
benchmark.id = "bench-123" + benchmark.leaderboards = [] + metric.benchmark = benchmark + metric.higher_is_better = True + return metric + + +@pytest.fixture +def metric_ready(metric_with_benchmark): + metric_with_benchmark.current_context = "test-model" + return metric_with_benchmark + +@pytest.fixture +def metric_ready_with_cleanup(metric_ready): + """metric_ready that auto-cleans temp media after the test.""" + yield metric_ready + metric_ready._cleanup_temp_media() + + +# Initialization with / without a client +def test_default_client_created_when_none_provided(): + """Test that a RapidataClient is created when none is provided.""" + with patch("pruna.evaluation.metrics.metric_rapiddata.RapidataClient") as mock_cls: + mock_cls.return_value = MagicMock() + _ = RapidataMetric() + mock_cls.assert_called_once() + + +def test_custom_client_used(mock_client): + """Test that a custom client is used when provided.""" + m = RapidataMetric(client=mock_client) + assert m.client is mock_client + + +# Creation from existing benchmark +def test_from_benchmark(): + """Test creating a metric from an existing benchmark.""" + benchmark = MagicMock(spec=RapidataBenchmark) + with patch("pruna.evaluation.metrics.metric_rapiddata.RapidataClient"): + m = RapidataMetric.from_rapidata_benchmark(benchmark) + assert m.benchmark is benchmark + + +# Creation from benchmark ID +def test_from_benchmark_id(): + """Test creating a metric from a benchmark ID.""" + with patch("pruna.evaluation.metrics.metric_rapiddata.RapidataClient") as mock_cls: + mock_instance = MagicMock() + mock_cls.return_value = mock_instance + mock_instance.mri.get_benchmark_by_id.return_value = MagicMock(id="abc") + m = RapidataMetric.from_rapidata_benchmark("abc") + mock_instance.mri.get_benchmark_by_id.assert_called_once_with("abc") + assert m.benchmark is not None + + +def test_create_benchmark_with_prompt_list(metric, mock_client): + """Test creating a benchmark with a list of prompts.""" + prompts = ["a cat", "a 
dog"] + metric.create_benchmark("my-bench", data=prompts) + mock_client.mri.create_new_benchmark.assert_called_once_with("my-bench", prompts=prompts, prompt_assets=None) + assert metric.benchmark is not None + + +def test_create_benchmark_from_datamodule(metric, mock_client): + """Test creating a benchmark from a PrunaDataModule.""" + ds = Dataset.from_dict({"text": ["prompt1", "prompt2"]}) + dm = PrunaDataModule(train_ds=ds, val_ds=ds, test_ds=ds, collate_fn=lambda x: x, dataloader_args={}) + + metric.create_benchmark("my-bench", data=dm, split="test") + mock_client.mri.create_new_benchmark.assert_called_once_with("my-bench", prompts=["prompt1", "prompt2"], prompt_assets=None) + + +def test_create_benchmark_raises_if_already_exists(metric_with_benchmark): + """Test that creating a benchmark twice raises.""" + with pytest.raises(ValueError, match="Benchmark already created"): + metric_with_benchmark.create_benchmark("dup", data=["x"]) + +def test_create_request_raises_without_benchmark(metric): + """Test that create_request raises without a benchmark.""" + with pytest.raises(ValueError, match="No benchmark configured"): + metric.create_async_request("quality", "Rate image quality") + + +def test_create_request_delegates_to_leaderboard(metric_with_benchmark): + """Test that create_request delegates to the benchmark.""" + metric_with_benchmark.create_async_request("quality", "Rate image quality") + metric_with_benchmark.benchmark.create_leaderboard.assert_called_once_with( + "quality", "Rate image quality", False, False + ) + + +def test_set_current_context_resets_caches(metric_ready): + """Test that set_current_context resets the caches.""" + metric_ready.prompt_cache.append("leftover") + metric_ready.media_cache.append("leftover") + metric_ready.current_context = "model-b" + assert metric_ready.prompt_cache == [] + assert metric_ready.media_cache == [] + + +def test_update_accumulates_prompts_and_media(metric_ready): + """Test that update accumulates prompts and 
media.""" + x = ["a cat on a sofa", "a dog in rain"] + gt = [None, None] + outputs = [torch.rand(3, 64, 64), torch.rand(3, 64, 64)] + metric_ready.update(x, gt, outputs) + + assert metric_ready.prompt_cache == x + assert len(metric_ready.media_cache) == 2 + + +def test_update_raises_without_benchmark(metric): + """Test that update raises without a benchmark.""" + metric.current_context = "m" + with pytest.raises(ValueError, match="No benchmark configured"): + metric.update(["p"], [None], [torch.rand(3, 32, 32)]) + + +def test_update_raises_without_context(metric_with_benchmark): + """Test that update raises without a model context.""" + with pytest.raises(ValueError, match="No context set. Set current_context first."): + metric_with_benchmark.update(["p"], [None], [torch.rand(3, 32, 32)]) + +def test_prepare_media_string_passthrough(metric_ready_with_cleanup): + """Test that string URLs/paths are passed through as-is.""" + metric_ready_with_cleanup.media_cache = ["https://example.com/img.png", "/tmp/local.png"] + paths = metric_ready_with_cleanup._prepare_media_for_upload() + assert paths == ["https://example.com/img.png", "/tmp/local.png"] + + +def test_prepare_media_pil_image(metric_ready_with_cleanup): + """Test that PIL images are saved to temp files.""" + img = PIL.Image.new("RGB", (64, 64), color="red") + metric_ready_with_cleanup.media_cache = [img] + paths = metric_ready_with_cleanup._prepare_media_for_upload() + assert len(paths) == 1 + assert Path(paths[0]).exists() + + +def test_prepare_media_tensor(metric_ready): + """Test that tensors are saved to temp files.""" + tensor = torch.rand(3, 64, 64) + metric_ready.media_cache = [tensor] + paths = metric_ready._prepare_media_for_upload() + assert len(paths) == 1 + assert Path(paths[0]).exists() + metric_ready._cleanup_temp_media() + + +def test_prepare_media_tensor_uint8_range(metric_ready): + """Test that tensors in 0-255 range are normalised before saving.""" + tensor = torch.randint(0, 256, (3, 32, 
32)).float() + assert tensor.max() > 1.0 + metric_ready.media_cache = [tensor] + paths = metric_ready._prepare_media_for_upload() + assert len(paths) == 1 + assert Path(paths[0]).exists() + metric_ready._cleanup_temp_media() + + +def test_prepare_media_unsupported_type_raises(metric_ready): + """Test that unsupported media types raise.""" + metric_ready.media_cache = [12345] + with pytest.raises(TypeError, match="Unsupported media type"): + metric_ready._prepare_media_for_upload() + + +def test_compute_submits_to_rapidata(metric_ready): + """Test that compute submits the accumulated data.""" + img = PIL.Image.new("RGB", (32, 32)) + metric_ready.media_cache = [img] + metric_ready.prompt_cache = ["a cat"] + metric_ready.compute() + metric_ready.benchmark.evaluate_model.assert_called_once() + call_kwargs = metric_ready.benchmark.evaluate_model.call_args + assert call_kwargs[0][0] == "test-model" + + +def test_compute_raises_when_cache_empty(metric_ready): + """Test that compute raises when no data has been accumulated.""" + with pytest.raises(ValueError, match="No data accumulated"): + metric_ready.compute() + + +def test_compute_raises_without_model_context(metric_with_benchmark): + """Test that compute raises without a model context.""" + with pytest.raises(ValueError, match="No context set. 
Set current_context first."): + metric_with_benchmark.compute() + + +def test_compute_cleans_up_temp_dir(metric_ready): + """Test that compute removes the temp directory after submission.""" + metric_ready.media_cache = [torch.rand(3, 32, 32)] + metric_ready.prompt_cache = ["test"] + metric_ready.compute() + assert not hasattr(metric_ready, "_temp_dir") or not metric_ready._temp_dir.exists() + + +class _FakeValidationError(Exception): + pass + + +def test_is_not_ready_error_recognises_validation_error(): + assert RapidataMetric._is_not_ready_error(_FakeValidationError()) is True + assert RapidataMetric._is_not_ready_error(RuntimeError()) is False + + +def test_retrieve_non_blocking_returns_result_when_ready(metric_with_benchmark): + metric_with_benchmark.benchmark.get_overall_standings.return_value = { + "name": ["model-a", "model-b"], "score": [0.85, 0.72], + } + result = metric_with_benchmark.retrieve_async_results() + assert isinstance(result, CompositeMetricResult) + assert result.result == {"model-a": 0.85, "model-b": 0.72} + + +def test_retrieve_non_blocking_returns_none_when_not_ready(metric_with_benchmark): + metric_with_benchmark.benchmark.get_overall_standings.side_effect = _FakeValidationError() + assert metric_with_benchmark.retrieve_async_results() is None + + +def test_retrieve_non_blocking_granular_returns_partial(metric_with_benchmark): + lb_ready = MagicMock(name="quality", instruction="Rate quality", inverse_ranking=False) + lb_ready.get_standings.return_value = {"name": ["m-a"], "score": [0.9]} + lb_pending = MagicMock(name="alignment") + lb_pending.get_standings.side_effect = _FakeValidationError() + metric_with_benchmark.benchmark.leaderboards = [lb_ready, lb_pending] + + results = metric_with_benchmark.retrieve_async_results(is_granular=True) + assert len(results) == 1 + assert results[0].result == {"m-a": 0.9} + + +def test_retrieve_reraises_non_validation_error(metric_with_benchmark): + 
metric_with_benchmark.benchmark.get_overall_standings.side_effect = RuntimeError("boom") + with pytest.raises(RuntimeError, match="boom"): + metric_with_benchmark.retrieve_async_results() + + +@patch("pruna.evaluation.metrics.metric_rapiddata.time") +def test_retrieve_blocking_polls_until_ready(mock_time, metric_with_benchmark): + _clock = iter(range(0, 1000, 10)) + mock_time.monotonic.side_effect = lambda: next(_clock) + + standings = {"name": ["m-a"], "score": [0.9]} + metric_with_benchmark.benchmark.get_overall_standings.side_effect = [ + _FakeValidationError(), _FakeValidationError(), standings, + ] + result = metric_with_benchmark.retrieve_async_results(is_blocking=True, timeout=60, poll_interval=5) + assert isinstance(result, CompositeMetricResult) + assert result.result == {"m-a": 0.9} + assert mock_time.sleep.call_count == 2 + + +@patch("pruna.evaluation.metrics.metric_rapiddata.time") +def test_retrieve_blocking_raises_timeout(mock_time, metric_with_benchmark): + _clock = iter(range(0, 1000, 30)) + mock_time.monotonic.side_effect = lambda: next(_clock) + metric_with_benchmark.benchmark.get_overall_standings.side_effect = _FakeValidationError() + + with pytest.raises(TimeoutError, match="not ready after 60s"): + metric_with_benchmark.retrieve_async_results(is_blocking=True, timeout=60, poll_interval=5) + + +def test_create_benchmark_forwards_explicit_data_assets(metric, mock_client): + """Explicit data_assets are forwarded as prompt_assets.""" + prompts = ["edit this", "fix that"] + assets = ["/imgs/a.png", "/imgs/b.png"] + metric.create_benchmark("bench", data=prompts, data_assets=assets) + mock_client.mri.create_new_benchmark.assert_called_once_with( + "bench", prompts=prompts, prompt_assets=assets, + ) + + +def test_create_benchmark_datamodule_extracts_images(metric, mock_client): + """PrunaDataModule with an 'image' column extracts and converts images to prompt_assets.""" + from datasets import Features, Image as HFImage, Value + img1 = 
PIL.Image.new("RGB", (32, 32), "red") + img2 = PIL.Image.new("RGB", (32, 32), "blue") + ds = Dataset.from_dict( + {"text": ["prompt1", "prompt2"], "image": [img1, img2]}, + features=Features({"text": Value("string"), "image": HFImage()}), + ) + dm = PrunaDataModule(train_ds=ds, val_ds=ds, test_ds=ds, collate_fn=lambda x: x, dataloader_args={}) + with patch.object(metric, "_prepare_media_for_upload", return_value=["/tmp/0.png", "/tmp/1.png"]) as mock_prep: + metric.create_benchmark("my-bench", data=dm, split="test") + mock_prep.assert_called_once() + images_arg = mock_prep.call_args[0][0] + assert len(images_arg) == 2 + assert all(isinstance(img, PIL.Image.Image) for img in images_arg) + mock_client.mri.create_new_benchmark.assert_called_once_with( + "my-bench", prompts=["prompt1", "prompt2"], prompt_assets=["/tmp/0.png", "/tmp/1.png"], + ) + + +def test_create_benchmark_datamodule_ignores_explicit_data_assets(metric, mock_client): + """When using a PrunaDataModule, explicit data_assets are overridden.""" + ds = Dataset.from_dict({"text": ["p1"]}) + dm = PrunaDataModule(train_ds=ds, val_ds=ds, test_ds=ds, collate_fn=lambda x: x, dataloader_args={}) + metric.create_benchmark("bench", data=dm, data_assets=["/should/be/ignored.png"]) + mock_client.mri.create_new_benchmark.assert_called_once_with( + "bench", prompts=["p1"], prompt_assets=None, + ) + + +def test_create_request_forwards_show_prompt_assets_true(metric_with_benchmark): + """show_prompt_assets=True is forwarded to create_leaderboard.""" + metric_with_benchmark.create_async_request("quality", "Rate quality", show_prompt_assets=True) + metric_with_benchmark.benchmark.create_leaderboard.assert_called_once_with( + "quality", "Rate quality", False, True, + ) + + +def test_prepare_media_uses_explicit_list_over_cache(metric_ready): + """Passing an explicit media list uses it instead of media_cache.""" + metric_ready.media_cache = [torch.rand(3, 32, 32)] # should be ignored + explicit = [PIL.Image.new("RGB", (16, 
16))]
+
+    paths = metric_ready._prepare_media_for_upload(explicit)
+    assert len(paths) == 1
+    assert Path(paths[0]).exists()
+    with PIL.Image.open(paths[0]) as loaded:
+        assert loaded.size == (16, 16)
+    metric_ready._cleanup_temp_media()