From 3b350bad5894f12c0c2da42cec0232a975dc188b Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Thu, 19 Mar 2026 23:37:28 +0000 Subject: [PATCH 1/7] feat: initial implementation for rapidata --- src/pruna/evaluation/evaluation_agent.py | 29 +- src/pruna/evaluation/metrics/__init__.py | 2 + src/pruna/evaluation/metrics/async_mixin.py | 53 ++ .../evaluation/metrics/metric_rapiddata.py | 459 ++++++++++++++++++ .../evaluation/metrics/metric_stateful.py | 32 +- src/pruna/evaluation/metrics/result.py | 101 +++- tests/evaluation/test_rapidata.py | 267 ++++++++++ 7 files changed, 927 insertions(+), 16 deletions(-) create mode 100644 src/pruna/evaluation/metrics/async_mixin.py create mode 100644 src/pruna/evaluation/metrics/metric_rapiddata.py create mode 100644 tests/evaluation/test_rapidata.py diff --git a/src/pruna/evaluation/evaluation_agent.py b/src/pruna/evaluation/evaluation_agent.py index 5b713dea..ba83c863 100644 --- a/src/pruna/evaluation/evaluation_agent.py +++ b/src/pruna/evaluation/evaluation_agent.py @@ -28,7 +28,7 @@ from pruna.engine.utils import get_device, move_to_device, safe_memory_cleanup, set_to_best_available_device from pruna.evaluation.metrics.metric_base import BaseMetric from pruna.evaluation.metrics.metric_stateful import StatefulMetric -from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.result import MetricResult, MetricResultProtocol from pruna.evaluation.metrics.utils import ensure_device_consistency, get_device_map, group_metrics_by_inheritance from pruna.evaluation.task import Task from pruna.logging.logger import pruna_logger @@ -71,8 +71,8 @@ def __init__( raise ValueError("When not using 'task' parameter, both 'request' and 'datamodule' must be provided.") self.task = Task(request=request, datamodule=datamodule, device=device) - self.first_model_results: List[MetricResult] = [] - self.subsequent_model_results: List[MetricResult] = [] + self.first_model_results: List[MetricResultProtocol] = [] + 
self.subsequent_model_results: List[MetricResultProtocol] = [] self.device = set_to_best_available_device(self.task.device) self.cache: List[Tensor] = [] self.evaluation_for_first_model: bool = True @@ -124,18 +124,20 @@ def from_benchmark( ) return cls(task=task) - def evaluate(self, model: Any) -> List[MetricResult]: + def evaluate(self, model: Any, model_name: str | None = None) -> List[MetricResultProtocol]: """ Evaluate models using different metric types. Parameters ---------- - model : PrunaModel + model : Any The model to evaluate. + model_name : str | None, optional + The name of the model to evaluate. Required for rapidata benchmark submission. Returns ------- - List[MetricResult] + List[MetricResultProtocol] The results of the model. """ results = [] @@ -146,6 +148,9 @@ def evaluate(self, model: Any) -> List[MetricResult]: pairwise_metrics = self.task.get_pairwise_stateful_metrics() stateless_metrics = self.task.get_stateless_metrics() + for metric in single_stateful_metrics + pairwise_metrics: + metric.set_current_context(model_name=model_name) + # Update and compute stateful metrics. pruna_logger.info("Evaluating stateful metrics.") with torch.no_grad(): @@ -278,7 +283,7 @@ def update_stateful_metrics( def compute_stateful_metrics( self, single_stateful_metrics: List[StatefulMetric], pairwise_metrics: List[StatefulMetric] - ) -> List[MetricResult]: + ) -> List[MetricResultProtocol]: """ Compute stateful metrics. 
@@ -296,16 +301,20 @@ def compute_stateful_metrics( """ results = [] for stateful_metric in single_stateful_metrics: - results.append(stateful_metric.compute()) + result = stateful_metric.compute() + if result is not None: + results.append(result) stateful_metric.reset() if not self.evaluation_for_first_model and self.task.is_pairwise_evaluation(): for pairwise_metric in pairwise_metrics: - results.append(pairwise_metric.compute()) + result = pairwise_metric.compute() + if result is not None: + results.append(result) pairwise_metric.reset() return results - def compute_stateless_metrics(self, model: PrunaModel, stateless_metrics: List[Any]) -> List[MetricResult]: + def compute_stateless_metrics(self, model: PrunaModel, stateless_metrics: List[Any]) -> List[MetricResultProtocol]: """ Compute stateless metrics. diff --git a/src/pruna/evaluation/metrics/__init__.py b/src/pruna/evaluation/metrics/__init__.py index 1a12f623..9653cb3a 100644 --- a/src/pruna/evaluation/metrics/__init__.py +++ b/src/pruna/evaluation/metrics/__init__.py @@ -23,6 +23,7 @@ from pruna.evaluation.metrics.metric_memory import DiskMemoryMetric, InferenceMemoryMetric, TrainingMemoryMetric from pruna.evaluation.metrics.metric_model_architecture import TotalMACsMetric, TotalParamsMetric from pruna.evaluation.metrics.metric_pairwise_clip import PairwiseClipScore +from pruna.evaluation.metrics.metric_rapiddata import RapidataMetric from pruna.evaluation.metrics.metric_sharpness import SharpnessMetric from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper @@ -45,4 +46,5 @@ "SharpnessMetric", "AestheticLAION", "LMEvalMetric", + "RapidataMetric", ] diff --git a/src/pruna/evaluation/metrics/async_mixin.py b/src/pruna/evaluation/metrics/async_mixin.py new file mode 100644 index 00000000..8ff3d274 --- /dev/null +++ b/src/pruna/evaluation/metrics/async_mixin.py @@ -0,0 +1,53 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from abc import ABC, abstractmethod +from typing import Any + + +class AsyncEvaluationMixin(ABC): + """ + Mixin for metrics that submit to external evaluation services and retrieve results asynchronously. + + Subclasses implement create_request() to set up an evaluation + (e.g., create a leaderboard) and retrieve_results() to retrieve + outcomes (e.g., standings from human evaluators). + """ + + @abstractmethod + def create_request(self, *args, **kwargs) -> Any: + """ + Create/configure an evaluation request on the external service. + + Parameters + ---------- + *args : + Variable length argument list. + **kwargs : + Arbitrary keyword arguments. + """ + + @abstractmethod + def retrieve_results(self, *args, **kwargs) -> Any: + """ + Retrieve results from the external service. + + Parameters + ---------- + *args : + Variable length argument list. + **kwargs : + Arbitrary keyword arguments. + """ diff --git a/src/pruna/evaluation/metrics/metric_rapiddata.py b/src/pruna/evaluation/metrics/metric_rapiddata.py new file mode 100644 index 00000000..0d79833d --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_rapiddata.py @@ -0,0 +1,459 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import shutil +import tempfile +from pathlib import Path +from typing import Any, List, Literal + +import PIL.Image +import torch +from rapidata import RapidataClient +from rapidata.rapidata_client.benchmark.rapidata_benchmark import RapidataBenchmark +from torch import Tensor +from torchvision.utils import save_image + +from pruna.data.pruna_datamodule import PrunaDataModule +from pruna.evaluation.metrics.async_mixin import AsyncEvaluationMixin +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.result import CompositeMetricResult +from pruna.evaluation.metrics.utils import SINGLE, get_call_type_for_single_metric, metric_data_processor +from pruna.logging.logger import pruna_logger + +METRIC_RAPIDATA = "rapidata" + + +class RapidataMetric(StatefulMetric, AsyncEvaluationMixin): + """ + Evaluate models with human feedback via the Rapidata platform. + + Parameters + ---------- + call_type : str + How to extract inputs from (x, gt, outputs). Default is "single". + client : RapidataClient | None + The Rapidata client to use. If None, a new one is created. + rapidata_client_id : str | None + The client ID of the Rapidata client. + If none, the credentials are read from the environment variable RAPIDATA_CLIENT_ID. + If credentials are not found in the environment variable, you will be prompted to login via browser. + rapidata_client_secret : str | None + The client secret of the Rapidata client. 
+ If none, the credentials are read from the environment variable RAPIDATA_CLIENT_SECRET. + If credentials are not found in the environment variable, you will be prompted to login via browser. + *args : + Additional arguments passed to StatefulMetric. + **kwargs : Any + Additional keyword arguments passed to StatefulMetric. + + Examples + -------- + Standalone usage:: + metric = RapidataMetric() + # OR metric = RapidataMetric.from_benchmark_id("69bc528fa858d3fbc1ea1475") + + metric.create_benchmark("my_bench", prompts) + metric.create_request("Quality", instruction="Which image looks better?") + + metric.set_current_context("model_a") + metric.update(prompts, ground_truths, outputs_a) + metric.compute() + + metric.set_current_context("model_b") + metric.update(prompts, ground_truths, outputs_b) + metric.compute() + + # wait for human votes + overall = metric.retrieve_results() + """ + + media_cache: List[torch.Tensor | PIL.Image.Image | str] + prompt_cache: List[str] + default_call_type: str = "x_y" + higher_is_better: bool = True + metric_name: str = METRIC_RAPIDATA + runs_on: List[str] = ["cpu", "cuda"] + + def __init__( + self, + call_type: str = SINGLE, + client: RapidataClient | None = None, + rapidata_client_id: str | None = None, + rapidata_client_secret: str | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.client = client or RapidataClient( + client_id=rapidata_client_id, + client_secret=rapidata_client_secret, + ) + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + self.add_state("media_cache", default=[]) + self.add_state("prompt_cache", default=[]) + self.benchmark: RapidataBenchmark | None = None + self.current_benchmarked_model: str | None = None + + @classmethod + def from_benchmark( + cls, + benchmark: RapidataBenchmark, + rapidata_client_id: str | None = None, + rapidata_client_secret: str | None = None + ) -> RapidataMetric: + """ + Create a RapidataMetric from an 
existing RapidataBenchmark. + + Parameters + ---------- + benchmark : RapidataBenchmark + The benchmark to attach to. + rapidata_client_id : str | None + The client ID of the Rapidata client. + rapidata_client_secret : str | None + The client secret of the Rapidata client. + + Returns + ------- + RapidataMetric + The created metric. + """ + metric = cls( + rapidata_client_id=rapidata_client_id, + rapidata_client_secret=rapidata_client_secret, + ) + metric.benchmark = benchmark + return metric + + @classmethod + def from_benchmark_id( + cls, + benchmark_id: str, + rapidata_client_id: str | None = None, + rapidata_client_secret: str | None = None, + ) -> RapidataMetric: + """ + Create a RapidataMetric from an existing benchmark ID. + + Parameters + ---------- + benchmark_id : str + The ID of the benchmark on the Rapidata platform. + rapidata_client_id : str | None + The client ID of the Rapidata client. + rapidata_client_secret : str | None + The client secret of the Rapidata client. + + Returns + ------- + RapidataMetric + The created metric. + """ + metric = cls( + rapidata_client_id=rapidata_client_id, + rapidata_client_secret=rapidata_client_secret, + ) + metric.benchmark = metric.client.mri.get_benchmark_by_id(benchmark_id) + return metric + + def create_benchmark( + self, + name: str, + data: list[str] | PrunaDataModule, + split: Literal["test", "val", "train"] = "test", + **kwargs, + ) -> None: + """ + Register a new benchmark on the Rapidata platform. + + The benchmark defines the prompt pool. Any data submitted to + leaderboards later must be drawn from this pool. + + Parameters + ---------- + name : str + The name of the benchmark. + data : list[str] | PrunaDataModule + The prompts or dataset to benchmark against. + split : str, optional + Which split to use when data is a PrunaDataModule. Default is "test". + **kwargs : Any + Additional keyword arguments passed to the Rapidata API. 
+ """ + if self.benchmark is not None: + raise ValueError("Benchmark already created. Use from_benchmark() to attach to an existing one.") + + if isinstance(data, PrunaDataModule): + split_map = {"test": data.test_dataset, "val": data.val_dataset, "train": data.train_dataset} + dataset = split_map[split] + # PrunaDataModule dataset loaders always renames the prompts column to "text" + if hasattr(dataset, "column_names") and "text" in dataset.column_names: + data = list(dataset["text"]) + else: + raise ValueError( + "Could not extract prompts from dataset.\n " + "Expected a 'text' column. Please use a suitable dataset from Pruna \ + or pass a list[str] directly instead." + ) + + self.benchmark = self.client.mri.create_new_benchmark(name, prompts=data, **kwargs) + + def create_request( + self, + name: str, + instruction: str, + show_prompt: bool = False, + **kwargs, + ) -> None: + """ + Add a leaderboard (evaluation criterion) to the benchmark. + + Each leaderboard defines a single instruction that human raters see + when comparing model outputs (e.g. "Which image has higher quality?" + or "Which image is more aligned with the prompt?"). + + You can create multiple leaderboards to evaluate different quality dimensions. + Must be called after :meth:`create_benchmark` (or after attaching a + benchmark via :meth:`from_benchmark` / :meth:`from_benchmark_id`). + + Parameters + ---------- + name : str + The name of the leaderboard. + instruction : str + The evaluation instruction shown to human raters. + show_prompt : bool, optional + Whether to show the prompt to raters. Default is False. + **kwargs : Any + Additional keyword arguments passed to the Rapidata API. + """ + self._require_benchmark() + self.benchmark.create_leaderboard(name, instruction, show_prompt, **kwargs) + + def set_current_context(self, model_name: str, **kwargs) -> None: + """ + Set which model is currently being evaluated. 
+ + Call this before the :meth:`update` / :meth:`compute` cycle for each + model. At least two models must be submitted before meaningful + human comparison can begin. + + Parameters + ---------- + model_name : str + The name of the model to evaluate. + **kwargs : Any + Additional keyword arguments. + """ + self.current_benchmarked_model = model_name + self.reset() + + def update(self, x: List[Any] | Tensor, gt: List[Any] | Tensor, outputs: Any) -> None: + """ + Accumulate model outputs for the current model. + + Parameters + ---------- + x : List[Any] | Tensor + The input data (prompts). + gt : List[Any] | Tensor + The ground truth data. + outputs : Any + The model outputs (generated media). + """ + self._require_benchmark() + self._require_model() + inputs = metric_data_processor(x, gt, outputs, self.call_type) + self.prompt_cache.extend(inputs[0]) + self.media_cache.extend(inputs[1]) + + def compute(self) -> None: + """ + Submit the accumulated outputs for the current model to Rapidata. + + Converts cached media to uploadable file paths if necessary (saving tensors and + PIL images to a temporary directory), submits them to the benchmark, + and cleans up temporary files. + + This method does **not** return a result — human evaluation is + asynchronous. Use :meth:`retrieve_results` or + :meth:`retrieve_granular_results` once enough votes have been + collected. + """ + self._require_model() + if not self.media_cache: + raise ValueError("No data accumulated. Call update() before compute().") + + media = self._prepare_media_for_upload() + + # Ignoring the type error because _require_model() has already been called, but ty can't see it. 
+ self.benchmark.evaluate_model( + self.current_benchmarked_model, # type: ignore[arg-type] + media=media, + prompts=self.prompt_cache, + ) + + self._cleanup_temp_media() + + pruna_logger.info( + "Sent evaluation request for model '%s' to Rapidata.\n " + "It may take a while to collect votes from human raters.\n " + "Use retrieve_results() to check scores later, " + "or monitor progress at: " + "https://app.rapidata.ai/mri/benchmarks/%s", + self.current_benchmarked_model, + self.benchmark.id, + ) + + def retrieve_results(self, *args, **kwargs) -> CompositeMetricResult | None: + """ + Retrieve aggregated standings across all leaderboards. + + Parameters + ---------- + *args : Any + Additional arguments passed to the Rapidata API. + **kwargs : Any + Additional keyword arguments passed to the Rapidata API. + + Returns + ------- + CompositeMetricResult | None + The overall standings, or None if not enough votes yet. + """ + self._require_benchmark() + + try: + standings = self.benchmark.get_overall_standings(*args, **kwargs) + except Exception as e: + if "ValidationError" in type(e).__name__: + pruna_logger.warning( + "The benchmark hasn't finished yet.\n " + "Please wait for more votes and try again." + "Skipping." + ) + return None + raise + + scores = dict(zip(standings["name"], standings["score"])) + return CompositeMetricResult( + name=self.metric_name, + params={}, + result=scores, + higher_is_better=self.higher_is_better, + ) + + def retrieve_granular_results(self, **kwargs) -> List[CompositeMetricResult]: + """ + Retrieve per-leaderboard results. + + Each leaderboard produces a separate CompositeMetricResult containing + scores for all evaluated models. + + Parameters + ---------- + **kwargs : Any + Additional keyword arguments passed to the Rapidata API. + + Returns + ------- + List[CompositeMetricResult] + A list of results, one per leaderboard. 
+ """ + self._require_benchmark() + + results = [] + for leaderboard in self.benchmark.leaderboards: + try: + standings = leaderboard.get_standings(**kwargs) + except Exception as e: + if "ValidationError" in type(e).__name__: + pruna_logger.warning( + "Leaderboard '%s' does not have results yet.\n " + "Not enough votes have been collected. Skipping.", + leaderboard.name, + ) + continue + raise + + scores = dict(zip(standings["name"], standings["score"])) + result = CompositeMetricResult( + name=leaderboard.name, + params={"instruction": leaderboard.instruction}, + result=scores, + higher_is_better=not leaderboard.inverse_ranking, + ) + results.append(result) + return results + + def _require_benchmark(self) -> None: + """Raise if no benchmark has been created or attached.""" + if self.benchmark is None: + raise ValueError( + "No benchmark configured. " + "Call create_benchmark(), or use from_benchmark() / from_benchmark_id()." + ) + + def _require_model(self) -> None: + """Raise if no model context has been set.""" + if self.current_benchmarked_model is None: + raise ValueError( + "No model set. Call set_current_context() first." + ) + + def _prepare_media_for_upload(self) -> list[str]: + """ + Convert cached media to file paths that Rapidata can upload. + + Handles three cases: + - str: assumed to be a URL or file path, passed through as-is + - PIL.Image: saved to a temporary file + - torch.Tensor: saved to a temporary file + + Returns + ------- + list[str] + A list of URLs or file paths. 
+ """ + self._temp_dir = Path(tempfile.mkdtemp(prefix="rapidata_")) + media_paths = [] + + for i, item in enumerate(self.media_cache): + if isinstance(item, str): + media_paths.append(item) + elif isinstance(item, PIL.Image.Image): + path = self._temp_dir / f"{i}.png" + item.save(path) + media_paths.append(str(path)) + elif isinstance(item, torch.Tensor): + path = self._temp_dir / f"{i}.png" + tensor = item.float() + if tensor.max() > 1.0: + tensor = tensor / 255.0 + save_image(tensor, path) + media_paths.append(str(path)) + else: + raise TypeError( + f"Unsupported media type: {type(item)}. " + "Expected str (URL/path), PIL.Image, or torch.Tensor." + ) + + return media_paths + + def _cleanup_temp_media(self) -> None: + """Remove temporary files created for upload.""" + if hasattr(self, "_temp_dir") and self._temp_dir.exists(): + shutil.rmtree(self._temp_dir) diff --git a/src/pruna/evaluation/metrics/metric_stateful.py b/src/pruna/evaluation/metrics/metric_stateful.py index 39fddcf6..f9cf7ebd 100644 --- a/src/pruna/evaluation/metrics/metric_stateful.py +++ b/src/pruna/evaluation/metrics/metric_stateful.py @@ -121,9 +121,37 @@ def update(self, *args, **kwargs) -> None: The keyword arguments to pass to the metric. """ + def set_current_context(self, *args, **kwargs) -> None: + """ + Set the current benchmarked model for the metric. + + Override this in subclasses that need to track which model or + configuration is being evaluated, such as async metrics that + submit results to external services. + + By default, this is a no-op. + + Parameters + ---------- + *args : Any + The arguments to pass to the metric. + **kwargs : Any + The keyword arguments to pass to the metric. + """ + pass + @abstractmethod - def compute(self) -> Any: - """Override this method to compute the final metric value.""" + def compute(self, *args, **kwargs) -> Any: + """ + Override this method to compute the final metric value. 
+ + Parameters + ---------- + *args : Any + The arguments to pass to the metric. + **kwargs : Any + The keyword arguments to pass to the metric. + """ def is_pairwise(self) -> bool: """ diff --git a/src/pruna/evaluation/metrics/result.py b/src/pruna/evaluation/metrics/result.py index f1e13ca8..d362474b 100644 --- a/src/pruna/evaluation/metrics/result.py +++ b/src/pruna/evaluation/metrics/result.py @@ -14,13 +14,52 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Protocol, runtime_checkable + + +@runtime_checkable +class MetricResultProtocol(Protocol): + """ + Protocol defining the shared interface for all metric results. + + Any metric result class should implement these attributes and methods + to be compatible with the evaluation pipeline. + + # Have to include this to prevent ty errors. + + Parameters + ---------- + *args : + Additional arguments passed to the MetricResultProtocol. + **kwargs : + Additional keyword arguments passed to the MetricResultProtocol. + + Attributes + ---------- + name : str + The name of the metric. + params : Dict[str, Any] + The parameters of the metric. + higher_is_better : Optional[bool] + Whether larger values mean better performance. + metric_units : Optional[str] + The units of the metric. + """ + + name: str + params: Dict[str, Any] + higher_is_better: Optional[bool] + metric_units: Optional[str] + + def __str__(self) -> str: + """Return a human-readable representation of the metric result.""" + ... @dataclass class MetricResult: """ - A class to store the results of a metric. + A class to store the result of a single-value metric. 
Parameters ---------- @@ -42,7 +81,7 @@ class MetricResult: higher_is_better: Optional[bool] = None metric_units: Optional[str] = None - def __post_init__(self): + def __post_init__(self) -> None: """Checker that metric_units and higher_is_better are consistent with the result.""" if self.metric_units is None: object.__setattr__(self, "metric_units", self.params.get("metric_units")) @@ -67,7 +106,7 @@ def from_results_dict( metric_name: str, metric_params: Dict[str, Any], results_dict: Dict[str, Any], - ) -> "MetricResult": + ) -> MetricResultProtocol: """ Create a MetricResult from a raw results dictionary. @@ -89,3 +128,57 @@ def from_results_dict( result = results_dict[metric_name] assert isinstance(result, (float, int)), f"Result for metric {metric_name} is not a float or int" return cls(metric_name, metric_params, result) + + +@dataclass +class CompositeMetricResult: + """ + A class to store the result of a metric that returns multiple labeled scores. + + This is used for metrics where a single evaluation request produces + scores for multiple entries, such as asynchronous metrics that + return labeled scores for different settings / models. + + Parameters + ---------- + name : str + The name of the metric. + params : Dict[str, Any] + The parameters of the metric. + result : Dict[str, float | int] + A mapping of labels to scores. + higher_is_better : Optional[bool] + Whether larger values mean better performance. + metric_units : Optional[str] + The units of the metric. 
+ """ + + name: str + params: Dict[str, Any] + result: Dict[str, float | int] + higher_is_better: Optional[bool] = None + metric_units: Optional[str] = None + + def __post_init__(self) -> None: + """Resolve metric_units and higher_is_better from params if not explicitly provided.""" + if self.metric_units is None: + object.__setattr__(self, "metric_units", self.params.get("metric_units")) + if self.higher_is_better is None: + object.__setattr__(self, "higher_is_better", self.params.get("higher_is_better")) + + def __str__(self) -> str: + """ + Return a string representation of the CompositeMetricResult. + + Each labeled score is displayed on its own line. + + Returns + ------- + str + A string representation of the CompositeMetricResult. + """ + lines = [f"{self.name}:"] + for key, score in self.result.items(): + units = f" {self.metric_units}" if self.metric_units else "" + lines.append(f" {key}: {score}{units}") + return "\n".join(lines) diff --git a/tests/evaluation/test_rapidata.py b/tests/evaluation/test_rapidata.py new file mode 100644 index 00000000..98ba4e44 --- /dev/null +++ b/tests/evaluation/test_rapidata.py @@ -0,0 +1,267 @@ +from pathlib import Path +from unittest.mock import MagicMock, patch + +import PIL.Image +import pytest +import torch +from datasets import Dataset + +from pruna.data.pruna_datamodule import PrunaDataModule +from pruna.evaluation.metrics.metric_rapiddata import METRIC_RAPIDATA, RapidataMetric +from pruna.evaluation.metrics.result import CompositeMetricResult + + +@pytest.fixture +def mock_client(): + return MagicMock() + + +@pytest.fixture +def metric(mock_client): + return RapidataMetric(client=mock_client) + + +@pytest.fixture +def metric_with_benchmark(metric): + benchmark = MagicMock() + benchmark.id = "bench-123" + benchmark.leaderboards = [] + metric.benchmark = benchmark + return metric + + +@pytest.fixture +def metric_ready(metric_with_benchmark): + metric_with_benchmark.set_current_context("test-model") + return 
metric_with_benchmark + + +# Initialization with / without a client +def test_default_client_created_when_none_provided(): + """Test that a RapidataClient is created when none is provided.""" + with patch("pruna.evaluation.metrics.metric_rapiddata.RapidataClient") as mock_cls: + mock_cls.return_value = MagicMock() + m = RapidataMetric() + mock_cls.assert_called_once() + + +def test_custom_client_used(mock_client): + """Test that a custom client is used when provided.""" + m = RapidataMetric(client=mock_client) + assert m.client is mock_client + + +# Creation from existing benchmark +def test_from_benchmark(): + """Test creating a metric from an existing benchmark.""" + benchmark = MagicMock() + with patch("pruna.evaluation.metrics.metric_rapiddata.RapidataClient"): + m = RapidataMetric.from_benchmark(benchmark) + assert m.benchmark is benchmark + + +# Creation from benchmark ID +def test_from_benchmark_id(): + """Test creating a metric from a benchmark ID.""" + with patch("pruna.evaluation.metrics.metric_rapiddata.RapidataClient") as mock_cls: + mock_instance = MagicMock() + mock_cls.return_value = mock_instance + mock_instance.mri.get_benchmark_by_id.return_value = MagicMock(id="abc") + m = RapidataMetric.from_benchmark_id("abc") + mock_instance.mri.get_benchmark_by_id.assert_called_once_with("abc") + assert m.benchmark is not None + + +def test_create_benchmark_with_prompt_list(metric, mock_client): + """Test creating a benchmark with a list of prompts.""" + prompts = ["a cat", "a dog"] + metric.create_benchmark("my-bench", data=prompts) + mock_client.mri.create_new_benchmark.assert_called_once_with("my-bench", prompts=prompts) + assert metric.benchmark is not None + + +def test_create_benchmark_from_datamodule(metric, mock_client): + """Test creating a benchmark from a PrunaDataModule.""" + ds = Dataset.from_dict({"text": ["prompt1", "prompt2"]}) + dm = PrunaDataModule(train_ds=ds, val_ds=ds, test_ds=ds, collate_fn=lambda x: x, dataloader_args={}) + + 
metric.create_benchmark("my-bench", data=dm, split="test") + mock_client.mri.create_new_benchmark.assert_called_once_with("my-bench", prompts=["prompt1", "prompt2"]) + + +def test_create_benchmark_raises_if_already_exists(metric_with_benchmark): + """Test that creating a benchmark twice raises.""" + with pytest.raises(ValueError, match="Benchmark already created"): + metric_with_benchmark.create_benchmark("dup", data=["x"]) + +def test_create_request_raises_without_benchmark(metric): + """Test that create_request raises without a benchmark.""" + with pytest.raises(ValueError, match="No benchmark configured"): + metric.create_request("quality", "Rate image quality") + + +def test_create_request_delegates_to_leaderboard(metric_with_benchmark): + """Test that create_request delegates to the benchmark.""" + metric_with_benchmark.create_request("quality", "Rate image quality") + metric_with_benchmark.benchmark.create_leaderboard.assert_called_once_with( + "quality", "Rate image quality", False + ) + + +def test_set_current_context_resets_caches(metric_ready): + """Test that set_current_context resets the caches.""" + metric_ready.prompt_cache.append("leftover") + metric_ready.media_cache.append("leftover") + metric_ready.set_current_context("model-b") + assert metric_ready.prompt_cache == [] + assert metric_ready.media_cache == [] + + +def test_update_accumulates_prompts_and_media(metric_ready): + """Test that update accumulates prompts and media.""" + x = ["a cat on a sofa", "a dog in rain"] + gt = [None, None] + outputs = [torch.rand(3, 64, 64), torch.rand(3, 64, 64)] + metric_ready.update(x, gt, outputs) + + assert metric_ready.prompt_cache == x + assert len(metric_ready.media_cache) == 2 + + +def test_update_raises_without_benchmark(metric): + """Test that update raises without a benchmark.""" + metric.current_benchmarked_model = "m" + with pytest.raises(ValueError, match="No benchmark configured"): + metric.update(["p"], [None], [torch.rand(3, 32, 32)]) + + +def 
test_update_raises_without_model(metric_with_benchmark): + """Test that update raises without a model context.""" + with pytest.raises(ValueError, match="No model set"): + metric_with_benchmark.update(["p"], [None], [torch.rand(3, 32, 32)]) + + +def test_prepare_media_string_passthrough(metric_ready): + """Test that string URLs/paths are passed through as-is.""" + metric_ready.media_cache = ["https://example.com/img.png", "/tmp/local.png"] + paths = metric_ready._prepare_media_for_upload() + assert paths == ["https://example.com/img.png", "/tmp/local.png"] + metric_ready._cleanup_temp_media() + + +def test_prepare_media_pil_image(metric_ready): + """Test that PIL images are saved to temp files.""" + img = PIL.Image.new("RGB", (64, 64), color="red") + metric_ready.media_cache = [img] + paths = metric_ready._prepare_media_for_upload() + assert len(paths) == 1 + assert Path(paths[0]).exists() + metric_ready._cleanup_temp_media() + + +def test_prepare_media_tensor(metric_ready): + """Test that tensors are saved to temp files.""" + tensor = torch.rand(3, 64, 64) + metric_ready.media_cache = [tensor] + paths = metric_ready._prepare_media_for_upload() + assert len(paths) == 1 + assert Path(paths[0]).exists() + metric_ready._cleanup_temp_media() + + +def test_prepare_media_tensor_uint8_range(metric_ready): + """Test that tensors in 0-255 range are normalised before saving.""" + tensor = torch.randint(0, 256, (3, 32, 32)).float() + assert tensor.max() > 1.0 + metric_ready.media_cache = [tensor] + paths = metric_ready._prepare_media_for_upload() + assert len(paths) == 1 + assert Path(paths[0]).exists() + metric_ready._cleanup_temp_media() + + +def test_prepare_media_unsupported_type_raises(metric_ready): + """Test that unsupported media types raise.""" + metric_ready.media_cache = [12345] + with pytest.raises(TypeError, match="Unsupported media type"): + metric_ready._prepare_media_for_upload() + + +def test_compute_submits_to_rapidata(metric_ready): + """Test that compute 
submits the accumulated data.""" + img = PIL.Image.new("RGB", (32, 32)) + metric_ready.media_cache = [img] + metric_ready.prompt_cache = ["a cat"] + metric_ready.compute() + metric_ready.benchmark.evaluate_model.assert_called_once() + call_kwargs = metric_ready.benchmark.evaluate_model.call_args + assert call_kwargs[0][0] == "test-model" + + +def test_compute_raises_when_cache_empty(metric_ready): + """Test that compute raises when no data has been accumulated.""" + with pytest.raises(ValueError, match="No data accumulated"): + metric_ready.compute() + + +def test_compute_raises_without_model_context(metric_with_benchmark): + """Test that compute raises without a model context.""" + with pytest.raises(ValueError, match="No model set"): + metric_with_benchmark.compute() + + +def test_compute_cleans_up_temp_dir(metric_ready): + """Test that compute removes the temp directory after submission.""" + metric_ready.media_cache = [torch.rand(3, 32, 32)] + metric_ready.prompt_cache = ["test"] + metric_ready.compute() + assert not hasattr(metric_ready, "_temp_dir") or not metric_ready._temp_dir.exists() + + +def test_retrieve_results_returns_composite_result(metric_with_benchmark): + """Test that retrieve_results returns a CompositeMetricResult.""" + metric_with_benchmark.benchmark.get_overall_standings.return_value = { + "name": ["model-a", "model-b"], + "score": [0.85, 0.72], + } + result = metric_with_benchmark.retrieve_results() + assert isinstance(result, CompositeMetricResult) + assert result.name == METRIC_RAPIDATA + assert result.result == {"model-a": 0.85, "model-b": 0.72} + assert result.higher_is_better is True + + +def test_retrieve_results_raises_without_benchmark(metric): + """Test that retrieve_results raises without a benchmark.""" + with pytest.raises(ValueError, match="No benchmark configured"): + metric.retrieve_results() + + +def test_retrieve_results_reraises_non_validation_error(metric_with_benchmark): + """Test that non-validation errors are 
re-raised.""" + metric_with_benchmark.benchmark.get_overall_standings.side_effect = RuntimeError("boom") + with pytest.raises(RuntimeError, match="boom"): + metric_with_benchmark.retrieve_results() + + +def test_retrieve_granular_results_per_leaderboard(metric_with_benchmark): + """Test that granular results returns one result per leaderboard.""" + lb = MagicMock() + lb.name = "quality" + lb.instruction = "Rate quality" + lb.get_standings.return_value = { + "name": ["model-a"], + "score": [0.9], + } + metric_with_benchmark.benchmark.leaderboards = [lb] + results = metric_with_benchmark.retrieve_granular_results() + assert len(results) == 1 + assert results[0].name == "quality" + assert results[0].params == {"instruction": "Rate quality"} + assert results[0].result == {"model-a": 0.9} + + +def test_retrieve_granular_results_raises_without_benchmark(metric): + """Test that granular results raises without a benchmark.""" + with pytest.raises(ValueError, match="No benchmark configured"): + metric.retrieve_granular_results() \ No newline at end of file From f3134b801f3ac4cd7712a8890e6cb7d681a2bc92 Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Fri, 20 Mar 2026 10:34:57 +0000 Subject: [PATCH 2/7] ci: add rapidata dependency and some cleanup --- .github/actions/setup-uv-project/action.yml | 2 +- pyproject.toml | 3 +++ src/pruna/evaluation/metrics/metric_rapiddata.py | 10 ++++++++-- src/pruna/evaluation/metrics/metric_stateful.py | 13 ++----------- src/pruna/evaluation/metrics/result.py | 7 ------- tests/evaluation/test_rapidata.py | 2 +- 6 files changed, 15 insertions(+), 22 deletions(-) diff --git a/.github/actions/setup-uv-project/action.yml b/.github/actions/setup-uv-project/action.yml index 74f1ac9b..10d95d33 100644 --- a/.github/actions/setup-uv-project/action.yml +++ b/.github/actions/setup-uv-project/action.yml @@ -12,4 +12,4 @@ runs: github-token: ${{ github.token }} - shell: bash - run: uv sync --extra dev --extra lmharness --extra vllm + run: uv sync --extra dev 
--extra lmharness --extra vllm --extra rapidata diff --git a/pyproject.toml b/pyproject.toml index 5b1eb704..33c47332 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -191,6 +191,9 @@ full = [ vbench = [ "vbench-pruna; sys_platform != 'darwin'", ] +rapidata = [ + "rapidata>=3.0.0" +] dev = [ "wget", "python-dotenv", diff --git a/src/pruna/evaluation/metrics/metric_rapiddata.py b/src/pruna/evaluation/metrics/metric_rapiddata.py index 0d79833d..5be686c0 100644 --- a/src/pruna/evaluation/metrics/metric_rapiddata.py +++ b/src/pruna/evaluation/metrics/metric_rapiddata.py @@ -30,12 +30,14 @@ from pruna.evaluation.metrics.async_mixin import AsyncEvaluationMixin from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.evaluation.metrics.result import CompositeMetricResult -from pruna.evaluation.metrics.utils import SINGLE, get_call_type_for_single_metric, metric_data_processor +from pruna.evaluation.metrics.utils import PAIRWISE, SINGLE, get_call_type_for_single_metric, metric_data_processor from pruna.logging.logger import pruna_logger METRIC_RAPIDATA = "rapidata" +# We don't use the MetricRegistry here +# because we need to instantiate the Metric directly with benchmark and leaderboards. class RapidataMetric(StatefulMetric, AsyncEvaluationMixin): """ Evaluate models with human feedback via the Rapidata platform. @@ -101,6 +103,8 @@ def __init__( client_id=rapidata_client_id, client_secret=rapidata_client_secret, ) + if call_type.startswith(PAIRWISE): + raise ValueError("RapidataMetric does not support pairwise metrics. Use a single metric instead.") self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) self.add_state("media_cache", default=[]) self.add_state("prompt_cache", default=[]) @@ -196,6 +200,8 @@ def create_benchmark( if self.benchmark is not None: raise ValueError("Benchmark already created. 
Use from_benchmark() to attach to an existing one.") + # Rapidata benchmarks only accept a list of strings, + # so we need to convert the PrunaDataModule to a list of strings. if isinstance(data, PrunaDataModule): split_map = {"test": data.test_dataset, "val": data.val_dataset, "train": data.train_dataset} dataset = split_map[split] @@ -259,7 +265,7 @@ def set_current_context(self, model_name: str, **kwargs) -> None: Additional keyword arguments. """ self.current_benchmarked_model = model_name - self.reset() + self.reset() # Clear the cache for the new model. def update(self, x: List[Any] | Tensor, gt: List[Any] | Tensor, outputs: Any) -> None: """ diff --git a/src/pruna/evaluation/metrics/metric_stateful.py b/src/pruna/evaluation/metrics/metric_stateful.py index f9cf7ebd..3464aaef 100644 --- a/src/pruna/evaluation/metrics/metric_stateful.py +++ b/src/pruna/evaluation/metrics/metric_stateful.py @@ -141,17 +141,8 @@ def set_current_context(self, *args, **kwargs) -> None: pass @abstractmethod - def compute(self, *args, **kwargs) -> Any: - """ - Override this method to compute the final metric value. - - Parameters - ---------- - *args : Any - The arguments to pass to the metric. - **kwargs : Any - The keyword arguments to pass to the metric.
- """ + def compute(self,) -> Any: + """Override this method to compute the final metric value.""" def is_pairwise(self) -> bool: """ diff --git a/src/pruna/evaluation/metrics/result.py b/src/pruna/evaluation/metrics/result.py index d362474b..93a9cd0e 100644 --- a/src/pruna/evaluation/metrics/result.py +++ b/src/pruna/evaluation/metrics/result.py @@ -159,13 +159,6 @@ class CompositeMetricResult: higher_is_better: Optional[bool] = None metric_units: Optional[str] = None - def __post_init__(self) -> None: - """Resolve metric_units and higher_is_better from params if not explicitly provided.""" - if self.metric_units is None: - object.__setattr__(self, "metric_units", self.params.get("metric_units")) - if self.higher_is_better is None: - object.__setattr__(self, "higher_is_better", self.params.get("higher_is_better")) - def __str__(self) -> str: """ Return a string representation of the CompositeMetricResult. diff --git a/tests/evaluation/test_rapidata.py b/tests/evaluation/test_rapidata.py index 98ba4e44..1f04fe87 100644 --- a/tests/evaluation/test_rapidata.py +++ b/tests/evaluation/test_rapidata.py @@ -41,7 +41,7 @@ def test_default_client_created_when_none_provided(): """Test that a RapidataClient is created when none is provided.""" with patch("pruna.evaluation.metrics.metric_rapiddata.RapidataClient") as mock_cls: mock_cls.return_value = MagicMock() - m = RapidataMetric() + _ = RapidataMetric() mock_cls.assert_called_once() From 8a081cc9a2355a9d02a0cb741426a1f408bd0416 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Fri, 20 Mar 2026 13:06:54 +0000 Subject: [PATCH 3/7] Guard optional rapidata metric import and tighten validation Applied via @cursor push command --- src/pruna/evaluation/metrics/__init__.py | 11 +++++++++-- src/pruna/evaluation/metrics/metric_rapiddata.py | 3 ++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/pruna/evaluation/metrics/__init__.py b/src/pruna/evaluation/metrics/__init__.py index 9653cb3a..e4a0aa31 100644 --- 
a/src/pruna/evaluation/metrics/__init__.py +++ b/src/pruna/evaluation/metrics/__init__.py @@ -23,10 +23,15 @@ from pruna.evaluation.metrics.metric_memory import DiskMemoryMetric, InferenceMemoryMetric, TrainingMemoryMetric from pruna.evaluation.metrics.metric_model_architecture import TotalMACsMetric, TotalParamsMetric from pruna.evaluation.metrics.metric_pairwise_clip import PairwiseClipScore -from pruna.evaluation.metrics.metric_rapiddata import RapidataMetric from pruna.evaluation.metrics.metric_sharpness import SharpnessMetric from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper +try: + from pruna.evaluation.metrics.metric_rapiddata import RapidataMetric +except ModuleNotFoundError as e: + if e.name != "rapidata": + raise + __all__ = [ "MetricRegistry", "TorchMetricWrapper", @@ -46,5 +51,7 @@ "SharpnessMetric", "AestheticLAION", "LMEvalMetric", - "RapidataMetric", ] + +if "RapidataMetric" in globals(): + __all__.append("RapidataMetric") diff --git a/src/pruna/evaluation/metrics/metric_rapiddata.py b/src/pruna/evaluation/metrics/metric_rapiddata.py index 5be686c0..537981e1 100644 --- a/src/pruna/evaluation/metrics/metric_rapiddata.py +++ b/src/pruna/evaluation/metrics/metric_rapiddata.py @@ -299,6 +299,7 @@ def compute(self) -> None: :meth:`retrieve_granular_results` once enough votes have been collected. """ + self._require_benchmark() self._require_model() if not self.media_cache: raise ValueError("No data accumulated. Call update() before compute().") @@ -348,7 +349,7 @@ def retrieve_results(self, *args, **kwargs) -> CompositeMetricResult | None: if "ValidationError" in type(e).__name__: pruna_logger.warning( "The benchmark hasn't finished yet.\n " - "Please wait for more votes and try again." + "Please wait for more votes and try again.\n " "Skipping." 
) return None From a3ba5b1a3a84b00162db7e91710d3b6e7d02cac5 Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Fri, 20 Mar 2026 14:30:44 +0000 Subject: [PATCH 4/7] refactor: address PR comments --- src/pruna/evaluation/metrics/__init__.py | 2 +- src/pruna/evaluation/metrics/async_mixin.py | 4 +- .../evaluation/metrics/metric_rapiddata.py | 118 ++++++++---------- tests/algorithms/testers/moe_kernel_tuner.py | 1 - tests/evaluation/test_rapidata.py | 21 ++-- 5 files changed, 66 insertions(+), 80 deletions(-) diff --git a/src/pruna/evaluation/metrics/__init__.py b/src/pruna/evaluation/metrics/__init__.py index e4a0aa31..357b2d48 100644 --- a/src/pruna/evaluation/metrics/__init__.py +++ b/src/pruna/evaluation/metrics/__init__.py @@ -27,7 +27,7 @@ from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper try: - from pruna.evaluation.metrics.metric_rapiddata import RapidataMetric + from pruna.evaluation.metrics.metric_rapiddata import RapidataMetric as RapidataMetric except ModuleNotFoundError as e: if e.name != "rapidata": raise diff --git a/src/pruna/evaluation/metrics/async_mixin.py b/src/pruna/evaluation/metrics/async_mixin.py index 8ff3d274..6ba137f3 100644 --- a/src/pruna/evaluation/metrics/async_mixin.py +++ b/src/pruna/evaluation/metrics/async_mixin.py @@ -27,7 +27,7 @@ class AsyncEvaluationMixin(ABC): """ @abstractmethod - def create_request(self, *args, **kwargs) -> Any: + def create_async_request(self, *args, **kwargs) -> Any: """ Create/configure an evaluation request on the external service. @@ -40,7 +40,7 @@ def create_request(self, *args, **kwargs) -> Any: """ @abstractmethod - def retrieve_results(self, *args, **kwargs) -> Any: + def retrieve_async_results(self, *args, **kwargs) -> Any: """ Retrieve results from the external service. 
diff --git a/src/pruna/evaluation/metrics/metric_rapiddata.py b/src/pruna/evaluation/metrics/metric_rapiddata.py index 537981e1..34528471 100644 --- a/src/pruna/evaluation/metrics/metric_rapiddata.py +++ b/src/pruna/evaluation/metrics/metric_rapiddata.py @@ -40,7 +40,7 @@ # because we need to instantiate the Metric directly with benchmark and leaderboards. class RapidataMetric(StatefulMetric, AsyncEvaluationMixin): """ - Evaluate models with human feedback via the Rapidata platform. + Evaluate models with human feedback via the Rapidata platform https://www.rapidata.ai/. Parameters ---------- @@ -84,10 +84,9 @@ class RapidataMetric(StatefulMetric, AsyncEvaluationMixin): media_cache: List[torch.Tensor | PIL.Image.Image | str] prompt_cache: List[str] + higher_is_better: bool default_call_type: str = "x_y" - higher_is_better: bool = True metric_name: str = METRIC_RAPIDATA - runs_on: List[str] = ["cpu", "cuda"] def __init__( self, @@ -112,9 +111,9 @@ def __init__( self.current_benchmarked_model: str | None = None @classmethod - def from_benchmark( + def from_rapidata_benchmark( cls, - benchmark: RapidataBenchmark, + benchmark: RapidataBenchmark | str, rapidata_client_id: str | None = None, rapidata_client_secret: str | None = None ) -> RapidataMetric: @@ -123,8 +122,8 @@ def from_benchmark( Parameters ---------- - benchmark : RapidataBenchmark - The benchmark to attach to. + benchmark : RapidataBenchmark | str + The benchmark to attach to. Can be a RapidataBenchmark object or a string (benchmark ID). rapidata_client_id : str | None The client ID of the Rapidata client. 
rapidata_client_secret : str | None @@ -139,38 +138,12 @@ def from_benchmark( rapidata_client_id=rapidata_client_id, rapidata_client_secret=rapidata_client_secret, ) - metric.benchmark = benchmark - return metric - - @classmethod - def from_benchmark_id( - cls, - benchmark_id: str, - rapidata_client_id: str | None = None, - rapidata_client_secret: str | None = None, - ) -> RapidataMetric: - """ - Create a RapidataMetric from an existing benchmark ID. - - Parameters - ---------- - benchmark_id : str - The ID of the benchmark on the Rapidata platform. - rapidata_client_id : str | None - The client ID of the Rapidata client. - rapidata_client_secret : str | None - The client secret of the Rapidata client. - - Returns - ------- - RapidataMetric - The created metric. - """ - metric = cls( - rapidata_client_id=rapidata_client_id, - rapidata_client_secret=rapidata_client_secret, - ) - metric.benchmark = metric.client.mri.get_benchmark_by_id(benchmark_id) + if isinstance(benchmark, RapidataBenchmark): + metric.benchmark = benchmark + elif isinstance(benchmark, str): + metric.benchmark = metric.client.mri.get_benchmark_by_id(benchmark) + else: + raise ValueError(f"Invalid benchmark: {benchmark}. 
Expected a RapidataBenchmark or a string.") return metric def create_benchmark( @@ -217,7 +190,7 @@ def create_benchmark( self.benchmark = self.client.mri.create_new_benchmark(name, prompts=data, **kwargs) - def create_request( + def create_async_request( self, name: str, instruction: str, @@ -315,7 +288,7 @@ def compute(self) -> None: self._cleanup_temp_media() - pruna_logger.info( + pruna_logger.warning( "Sent evaluation request for model '%s' to Rapidata.\n " "It may take a while to collect votes from human raters.\n " "Use retrieve_results() to check scores later, " @@ -325,12 +298,22 @@ def compute(self) -> None: self.benchmark.id, ) - def retrieve_results(self, *args, **kwargs) -> CompositeMetricResult | None: + def retrieve_async_results( + self, + granular: bool = False, + *args, + **kwargs, + ) -> List[CompositeMetricResult] | CompositeMetricResult | None: """ - Retrieve aggregated standings across all leaderboards. + Retrieve standings for all leaderboards. + + If granular is True, retrieve standings for each leaderboard separately. + Otherwise, retrieve aggregated standings across all leaderboards. Parameters ---------- + granular: bool, optional + Whether to retrieve granular results. Default is False. *args : Any Additional arguments passed to the Rapidata API. **kwargs : Any @@ -338,32 +321,35 @@ def retrieve_results(self, *args, **kwargs) -> CompositeMetricResult | None: Returns ------- - CompositeMetricResult | None - The overall standings, or None if not enough votes yet. + List[CompositeMetricResult] | CompositeMetricResult | None + If granular is True, a list of results, one per leaderboard. + If granular is False, the overall standings, or None if not enough votes yet. 
""" self._require_benchmark() - try: - standings = self.benchmark.get_overall_standings(*args, **kwargs) - except Exception as e: - if "ValidationError" in type(e).__name__: - pruna_logger.warning( - "The benchmark hasn't finished yet.\n " - "Please wait for more votes and try again.\n " - "Skipping." - ) - return None - raise - - scores = dict(zip(standings["name"], standings["score"])) - return CompositeMetricResult( - name=self.metric_name, - params={}, - result=scores, - higher_is_better=self.higher_is_better, - ) + if not granular: + try: + standings = self.benchmark.get_overall_standings(*args, **kwargs) + except Exception as e: + if "ValidationError" in type(e).__name__: + pruna_logger.warning( + "The benchmark hasn't finished yet.\n " + "Please wait for more votes and try again.\n " + "Skipping." + ) + return None + raise - def retrieve_granular_results(self, **kwargs) -> List[CompositeMetricResult]: + return CompositeMetricResult( + name=self.metric_name, + params={}, + result=dict(zip(standings["name"], standings["score"])), + higher_is_better=self.higher_is_better, + ) + + return self._retrieve_granular_results(**kwargs) + + def _retrieve_granular_results(self, **kwargs) -> List[CompositeMetricResult]: """ Retrieve per-leaderboard results. @@ -380,8 +366,6 @@ def retrieve_granular_results(self, **kwargs) -> List[CompositeMetricResult]: List[CompositeMetricResult] A list of results, one per leaderboard. 
""" - self._require_benchmark() - results = [] for leaderboard in self.benchmark.leaderboards: try: diff --git a/tests/algorithms/testers/moe_kernel_tuner.py b/tests/algorithms/testers/moe_kernel_tuner.py index 9a754cf3..85661a83 100644 --- a/tests/algorithms/testers/moe_kernel_tuner.py +++ b/tests/algorithms/testers/moe_kernel_tuner.py @@ -34,7 +34,6 @@ def post_smash_hook(self, model: PrunaModel) -> None: def _resolve_hf_cache_config_path(self) -> Path: """Read the saved artifact and compute the expected HF cache config path.""" - imported_packages = MoeKernelTuner().import_algorithm_packages() smash_cfg = SmashConfig() diff --git a/tests/evaluation/test_rapidata.py b/tests/evaluation/test_rapidata.py index 1f04fe87..85caa673 100644 --- a/tests/evaluation/test_rapidata.py +++ b/tests/evaluation/test_rapidata.py @@ -35,6 +35,12 @@ def metric_ready(metric_with_benchmark): metric_with_benchmark.set_current_context("test-model") return metric_with_benchmark +@pytest.fixture +def metric_ready_with_cleanup(metric_ready): + """metric_ready that auto-cleans temp media after the test.""" + yield metric_ready + metric_ready._cleanup_temp_media() + # Initialization with / without a client def test_default_client_created_when_none_provided(): @@ -140,23 +146,20 @@ def test_update_raises_without_model(metric_with_benchmark): with pytest.raises(ValueError, match="No model set"): metric_with_benchmark.update(["p"], [None], [torch.rand(3, 32, 32)]) - -def test_prepare_media_string_passthrough(metric_ready): +def test_prepare_media_string_passthrough(metric_ready_with_cleanup): """Test that string URLs/paths are passed through as-is.""" - metric_ready.media_cache = ["https://example.com/img.png", "/tmp/local.png"] - paths = metric_ready._prepare_media_for_upload() + metric_ready_with_cleanup.media_cache = ["https://example.com/img.png", "/tmp/local.png"] + paths = metric_ready_with_cleanup._prepare_media_for_upload() assert paths == ["https://example.com/img.png", 
"/tmp/local.png"] - metric_ready._cleanup_temp_media() -def test_prepare_media_pil_image(metric_ready): +def test_prepare_media_pil_image(metric_ready_with_cleanup): """Test that PIL images are saved to temp files.""" img = PIL.Image.new("RGB", (64, 64), color="red") - metric_ready.media_cache = [img] - paths = metric_ready._prepare_media_for_upload() + metric_ready_with_cleanup.media_cache = [img] + paths = metric_ready_with_cleanup._prepare_media_for_upload() assert len(paths) == 1 assert Path(paths[0]).exists() - metric_ready._cleanup_temp_media() def test_prepare_media_tensor(metric_ready): From f2b4a987cf8c46795da9caf9dd71398bd244c08b Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Mon, 23 Mar 2026 17:58:52 +0000 Subject: [PATCH 5/7] feat: add polling and address further PR comments --- .../evaluation/metrics/metric_rapiddata.py | 254 +++++++++++++----- tests/evaluation/test_rapidata.py | 165 +++++++++--- 2 files changed, 310 insertions(+), 109 deletions(-) diff --git a/src/pruna/evaluation/metrics/metric_rapiddata.py b/src/pruna/evaluation/metrics/metric_rapiddata.py index 34528471..76b20552 100644 --- a/src/pruna/evaluation/metrics/metric_rapiddata.py +++ b/src/pruna/evaluation/metrics/metric_rapiddata.py @@ -16,8 +16,9 @@ import shutil import tempfile +import time from pathlib import Path -from typing import Any, List, Literal +from typing import Any, Callable, List, Literal import PIL.Image import torch @@ -84,7 +85,10 @@ class RapidataMetric(StatefulMetric, AsyncEvaluationMixin): media_cache: List[torch.Tensor | PIL.Image.Image | str] prompt_cache: List[str] - higher_is_better: bool + # With every metric higher is actually better, + # Because for negative questions like "Which image has more errors?" + # We create the leaderboard with inverse_ranking=True, which reverses the ranking. 
+ higher_is_better: bool = True default_call_type: str = "x_y" metric_name: str = METRIC_RAPIDATA @@ -150,6 +154,7 @@ def create_benchmark( self, name: str, data: list[str] | PrunaDataModule, + data_assets: list[str] | None = None, split: Literal["test", "val", "train"] = "test", **kwargs, ) -> None: @@ -159,12 +164,20 @@ def create_benchmark( The benchmark defines the prompt pool. Any data submitted to leaderboards later must be drawn from this pool. + Prompts can be provided as a list of strings or as a PrunaDataModule. + When using a list of strings, you can optionally pass data_assets as a list of file paths or URLs. + When using a PrunaDataModule, data assets are extracted automatically from the datamodule, if available. + Parameters ---------- name : str The name of the benchmark. data : list[str] | PrunaDataModule The prompts or dataset to benchmark against. + data_assets : list[str] | None + The assets to attach to the prompts. + For instance, if you wish to benchmark an image editing model, + you can pass the original images as data_assets. split : str, optional Which split to use when data is a PrunaDataModule. Default is "test". **kwargs : Any @@ -181,6 +194,11 @@ def create_benchmark( # PrunaDataModule dataset loaders always renames the prompts column to "text" if hasattr(dataset, "column_names") and "text" in dataset.column_names: data = list(dataset["text"]) + data_assets = None # When using a PrunaDataModule, we need to get the data assets from the datamodule. + if "image" in dataset.column_names: + images = list(dataset["image"]) # Pruna text to image datasets always have an "image" column. + # Rapidata only accepts file paths or URLs, so we need to convert the images to file paths. + data_assets = self._prepare_media_for_upload(images) else: raise ValueError( "Could not extract prompts from dataset.\n " @@ -188,13 +206,14 @@ def create_benchmark( or pass a list[str] directly instead." 
) - self.benchmark = self.client.mri.create_new_benchmark(name, prompts=data, **kwargs) + self.benchmark = self.client.mri.create_new_benchmark(name, prompts=data, prompt_assets=data_assets, **kwargs) def create_async_request( self, name: str, instruction: str, show_prompt: bool = False, + show_prompt_assets: bool = False, **kwargs, ) -> None: """ @@ -216,11 +235,13 @@ def create_async_request( The evaluation instruction shown to human raters. show_prompt : bool, optional Whether to show the prompt to raters. Default is False. + show_prompt_assets : bool, optional + Whether to show the prompt assets to raters. Default is False. **kwargs : Any Additional keyword arguments passed to the Rapidata API. """ self._require_benchmark() - self.benchmark.create_leaderboard(name, instruction, show_prompt, **kwargs) + self.benchmark.create_leaderboard(name, instruction, show_prompt, show_prompt_assets, **kwargs) def set_current_context(self, model_name: str, **kwargs) -> None: """ @@ -298,22 +319,49 @@ def compute(self) -> None: self.benchmark.id, ) - def retrieve_async_results( - self, - granular: bool = False, - *args, - **kwargs, - ) -> List[CompositeMetricResult] | CompositeMetricResult | None: + @staticmethod + def _is_not_ready_error(exc: Exception) -> bool: """ - Retrieve standings for all leaderboards. + Check whether the exception is a pydantic ValidationError, matched by its type name. + + When the benchmark is not finished yet, the API throws a pydantic ValidationError; + we catch it and return None to indicate that the benchmark is not ready yet, + rather than failing outright with an exception. + """ + return "ValidationError" in type(exc).__name__ + + def _fetch_standings(self, api_call, *args, **kwargs): + """ + Barebones API call wrapper that catches ValidationError and returns None if the benchmark is not ready yet. + + Since the core logic between the overall and granular standings is the same, + we can use a single function to fetch the standings.
+ + Parameters + ---------- + api_call : callable + The API call to make. + *args : Any + Additional arguments passed to the API call. + **kwargs : Any + Additional keyword arguments passed to the API call. + """ + try: + return api_call(*args, **kwargs) + except Exception as e: + if not self._is_not_ready_error(e): + raise + return None + + def _fetch_overall_standings(self, *args, **kwargs) -> tuple[CompositeMetricResult | None, bool]: + """ + Retrieve overall standings for the benchmark. - If granular is True, retrieve standings for each leaderboard separately. - Otherwise, retrieve aggregated standings across all leaderboards. + Returns a tuple where the first element is the composite score of all leaderboards in the benchmark, + and the second element is a boolean indicating whether the benchmark is finished yet. Parameters ---------- - granular: bool, optional - Whether to retrieve granular results. Default is False. *args : Any Additional arguments passed to the Rapidata API. **kwargs : Any @@ -321,74 +369,146 @@ def retrieve_async_results( Returns ------- - List[CompositeMetricResult] | CompositeMetricResult | None - If granular is True, a list of results, one per leaderboard. - If granular is False, the overall standings, or None if not enough votes yet. + CompositeMetricResult | None + The overall standings or None if the benchmark is not finished yet. """ - self._require_benchmark() + standings = self._fetch_standings(self.benchmark.get_overall_standings, *args, **kwargs) + if standings is None: + return None, False + return CompositeMetricResult( + name=self.metric_name, + params={}, + result=dict(zip(standings["name"], standings["score"])), + higher_is_better=self.higher_is_better, + ), True + + def _fetch_granular_standings(self, *args, **kwargs) -> tuple[List[CompositeMetricResult] | None, bool]: + """ + Retrieve standings for all leaderboards. 
- if not granular: - try: - standings = self.benchmark.get_overall_standings(*args, **kwargs) - except Exception as e: - if "ValidationError" in type(e).__name__: - pruna_logger.warning( - "The benchmark hasn't finished yet.\n " - "Please wait for more votes and try again.\n " - "Skipping." - ) - return None - raise + Returns a tuple where the first element is a list of results, one per leaderboard, + and the second element is a boolean indicating whether all of the leaderboards (the benchmark) is finished yet. + + Parameters + ---------- + *args : Any + Additional arguments passed to the Rapidata API. + **kwargs : Any + Additional keyword arguments passed to the Rapidata API. - return CompositeMetricResult( - name=self.metric_name, - params={}, + Returns + ------- + List[CompositeMetricResult] | None + A list of results, one per leaderboard, or None if the benchmark is not finished yet. + """ + results = [] + all_finished = True + for leaderboard in self.benchmark.leaderboards: + standings = self._fetch_standings(leaderboard.get_standings, *args, **kwargs) + if standings is None: + all_finished = False + continue + results.append(CompositeMetricResult( + name=leaderboard.name, + params={"instruction": leaderboard.instruction}, result=dict(zip(standings["name"], standings["score"])), higher_is_better=self.higher_is_better, - ) + )) + return results, all_finished + + def _fetch_with_retry_option( + self, + fetch_fn: Callable, + is_blocking: bool, + timeout: float, + poll_interval: float, + *args, + **kwargs, + ) -> CompositeMetricResult | List[CompositeMetricResult] | None: + """ + Wait for the results or return whatever we have as is from the benchmark. - return self._retrieve_granular_results(**kwargs) + If is_blocking is True, it will poll until the results are ready or the timeout is reached. + If is_blocking is False, it will return the results immediately if they are ready, + otherwise it will return None and log a warning. 
- def _retrieve_granular_results(self, **kwargs) -> List[CompositeMetricResult]: + Parameters + ---------- + fetch_fn : callable + The function to fetch the standings from the benchmark. + is_blocking : bool + Whether to block and wait for the results to be ready. + timeout : float + The maximum time to wait for the results to be ready. + poll_interval : float + The interval in seconds to poll for the results. + *args : Any + Additional arguments passed to the Rapidata API. + **kwargs : Any + Additional keyword arguments passed to the Rapidata API. """ - Retrieve per-leaderboard results. + deadline = time.monotonic() + timeout + while True: + result, is_finished = fetch_fn(*args, **kwargs) + if is_finished: # The benchmark is finished, we don't need to check anything else, just return the result. + return result + if not is_blocking: # The benchmark is not finished yet, but the user doesn't want to keep on polling. + pruna_logger.warning( + "The benchmark hasn't finished yet. " + "Please wait for more votes and try again." + ) + return result # Return whatever we have as is. + if time.monotonic() + poll_interval > deadline: # The timeout is reached, we raise an exception. + raise TimeoutError( + f"Benchmark results not ready after {timeout:.0f}s. " + f"Monitor at: https://app.rapidata.ai/mri/benchmarks/{self.benchmark.id}" + ) + pruna_logger.info("Results not ready yet, retrying in %ds...", poll_interval) + time.sleep(poll_interval) - Each leaderboard produces a separate CompositeMetricResult containing - scores for all evaluated models. + def retrieve_async_results( + self, + is_granular: bool = False, + is_blocking: bool = False, + timeout: float = 3600, + poll_interval: float = 30, + *args, + **kwargs, + ) -> List[CompositeMetricResult] | CompositeMetricResult | None: + """ + Retrieve standings from the benchmark. 
Parameters ---------- + is_granular : bool, optional + If True, return per-leaderboard results (partial results + are returned for any leaderboard that is ready). + If False, return overall aggregated standings. + is_blocking : bool, optional + If True, poll until results are ready or *timeout* is reached. + timeout : float, optional + Maximum seconds to wait when blocking. Default is 3600. + poll_interval : float, optional + Seconds between polling attempts when blocking. Default is 30. + *args : Any + Additional arguments passed to the Rapidata API. **kwargs : Any Additional keyword arguments passed to the Rapidata API. Returns ------- - List[CompositeMetricResult] - A list of results, one per leaderboard. - """ - results = [] - for leaderboard in self.benchmark.leaderboards: - try: - standings = leaderboard.get_standings(**kwargs) - except Exception as e: - if "ValidationError" in type(e).__name__: - pruna_logger.warning( - "Leaderboard '%s' does not have results yet.\n " - "Not enough votes have been collected. Skipping.", - leaderboard.name, - ) - continue - raise + List[CompositeMetricResult] | CompositeMetricResult | None + Granular returns a list (possibly partial), overall returns + a single result or None if not ready. - scores = dict(zip(standings["name"], standings["score"])) - result = CompositeMetricResult( - name=leaderboard.name, - params={"instruction": leaderboard.instruction}, - result=scores, - higher_is_better=not leaderboard.inverse_ranking, - ) - results.append(result) - return results + Raises + ------ + TimeoutError + If *is_blocking* is True and results are not ready within *timeout*. 
+ """ + self._require_benchmark() + fetch_fn = self._fetch_granular_standings if is_granular else self._fetch_overall_standings + return self._fetch_with_retry_option(fetch_fn, is_blocking, timeout, poll_interval, **kwargs) def _require_benchmark(self) -> None: """Raise if no benchmark has been created or attached.""" @@ -405,7 +525,7 @@ def _require_model(self) -> None: "No model set. Call set_current_context() first." ) - def _prepare_media_for_upload(self) -> list[str]: + def _prepare_media_for_upload(self, media: list[torch.Tensor | PIL.Image.Image | str] | None = None) -> list[str]: """ Convert cached media to file paths that Rapidata can upload. @@ -422,7 +542,7 @@ def _prepare_media_for_upload(self) -> list[str]: self._temp_dir = Path(tempfile.mkdtemp(prefix="rapidata_")) media_paths = [] - for i, item in enumerate(self.media_cache): + for i, item in enumerate(media or self.media_cache): if isinstance(item, str): media_paths.append(item) elif isinstance(item, PIL.Image.Image): diff --git a/tests/evaluation/test_rapidata.py b/tests/evaluation/test_rapidata.py index 85caa673..0985fbf9 100644 --- a/tests/evaluation/test_rapidata.py +++ b/tests/evaluation/test_rapidata.py @@ -7,7 +7,8 @@ from datasets import Dataset from pruna.data.pruna_datamodule import PrunaDataModule -from pruna.evaluation.metrics.metric_rapiddata import METRIC_RAPIDATA, RapidataMetric +from pruna.evaluation.metrics.metric_rapiddata import RapidataMetric +from rapidata.rapidata_client.benchmark.rapidata_benchmark import RapidataBenchmark from pruna.evaluation.metrics.result import CompositeMetricResult @@ -27,6 +28,7 @@ def metric_with_benchmark(metric): benchmark.id = "bench-123" benchmark.leaderboards = [] metric.benchmark = benchmark + metric.higher_is_better = True return metric @@ -60,9 +62,9 @@ def test_custom_client_used(mock_client): # Creation from existing benchmark def test_from_benchmark(): """Test creating a metric from an existing benchmark.""" - benchmark = MagicMock() + 
benchmark = MagicMock(spec=RapidataBenchmark) with patch("pruna.evaluation.metrics.metric_rapiddata.RapidataClient"): - m = RapidataMetric.from_benchmark(benchmark) + m = RapidataMetric.from_rapidata_benchmark(benchmark) assert m.benchmark is benchmark @@ -73,7 +75,7 @@ def test_from_benchmark_id(): mock_instance = MagicMock() mock_cls.return_value = mock_instance mock_instance.mri.get_benchmark_by_id.return_value = MagicMock(id="abc") - m = RapidataMetric.from_benchmark_id("abc") + m = RapidataMetric.from_rapidata_benchmark("abc") mock_instance.mri.get_benchmark_by_id.assert_called_once_with("abc") assert m.benchmark is not None @@ -82,7 +84,7 @@ def test_create_benchmark_with_prompt_list(metric, mock_client): """Test creating a benchmark with a list of prompts.""" prompts = ["a cat", "a dog"] metric.create_benchmark("my-bench", data=prompts) - mock_client.mri.create_new_benchmark.assert_called_once_with("my-bench", prompts=prompts) + mock_client.mri.create_new_benchmark.assert_called_once_with("my-bench", prompts=prompts, prompt_assets=None) assert metric.benchmark is not None @@ -92,7 +94,7 @@ def test_create_benchmark_from_datamodule(metric, mock_client): dm = PrunaDataModule(train_ds=ds, val_ds=ds, test_ds=ds, collate_fn=lambda x: x, dataloader_args={}) metric.create_benchmark("my-bench", data=dm, split="test") - mock_client.mri.create_new_benchmark.assert_called_once_with("my-bench", prompts=["prompt1", "prompt2"]) + mock_client.mri.create_new_benchmark.assert_called_once_with("my-bench", prompts=["prompt1", "prompt2"], prompt_assets=None) def test_create_benchmark_raises_if_already_exists(metric_with_benchmark): @@ -103,14 +105,14 @@ def test_create_benchmark_raises_if_already_exists(metric_with_benchmark): def test_create_request_raises_without_benchmark(metric): """Test that create_request raises without a benchmark.""" with pytest.raises(ValueError, match="No benchmark configured"): - metric.create_request("quality", "Rate image quality") + 
metric.create_async_request("quality", "Rate image quality") def test_create_request_delegates_to_leaderboard(metric_with_benchmark): """Test that create_request delegates to the benchmark.""" - metric_with_benchmark.create_request("quality", "Rate image quality") + metric_with_benchmark.create_async_request("quality", "Rate image quality") metric_with_benchmark.benchmark.create_leaderboard.assert_called_once_with( - "quality", "Rate image quality", False + "quality", "Rate image quality", False, False ) @@ -221,50 +223,129 @@ def test_compute_cleans_up_temp_dir(metric_ready): assert not hasattr(metric_ready, "_temp_dir") or not metric_ready._temp_dir.exists() -def test_retrieve_results_returns_composite_result(metric_with_benchmark): - """Test that retrieve_results returns a CompositeMetricResult.""" +class _FakeValidationError(Exception): + pass + + +def test_is_not_ready_error_recognises_validation_error(): + assert RapidataMetric._is_not_ready_error(_FakeValidationError()) is True + assert RapidataMetric._is_not_ready_error(RuntimeError()) is False + + +def test_retrieve_non_blocking_returns_result_when_ready(metric_with_benchmark): metric_with_benchmark.benchmark.get_overall_standings.return_value = { - "name": ["model-a", "model-b"], - "score": [0.85, 0.72], + "name": ["model-a", "model-b"], "score": [0.85, 0.72], } - result = metric_with_benchmark.retrieve_results() + result = metric_with_benchmark.retrieve_async_results() assert isinstance(result, CompositeMetricResult) - assert result.name == METRIC_RAPIDATA assert result.result == {"model-a": 0.85, "model-b": 0.72} - assert result.higher_is_better is True -def test_retrieve_results_raises_without_benchmark(metric): - """Test that retrieve_results raises without a benchmark.""" - with pytest.raises(ValueError, match="No benchmark configured"): - metric.retrieve_results() +def test_retrieve_non_blocking_returns_none_when_not_ready(metric_with_benchmark): + 
metric_with_benchmark.benchmark.get_overall_standings.side_effect = _FakeValidationError() + assert metric_with_benchmark.retrieve_async_results() is None + + +def test_retrieve_non_blocking_granular_returns_partial(metric_with_benchmark): + lb_ready = MagicMock(name="quality", instruction="Rate quality", inverse_ranking=False) + lb_ready.get_standings.return_value = {"name": ["m-a"], "score": [0.9]} + lb_pending = MagicMock(name="alignment") + lb_pending.get_standings.side_effect = _FakeValidationError() + metric_with_benchmark.benchmark.leaderboards = [lb_ready, lb_pending] + + results = metric_with_benchmark.retrieve_async_results(is_granular=True) + assert len(results) == 1 + assert results[0].result == {"m-a": 0.9} -def test_retrieve_results_reraises_non_validation_error(metric_with_benchmark): - """Test that non-validation errors are re-raised.""" +def test_retrieve_reraises_non_validation_error(metric_with_benchmark): metric_with_benchmark.benchmark.get_overall_standings.side_effect = RuntimeError("boom") with pytest.raises(RuntimeError, match="boom"): - metric_with_benchmark.retrieve_results() + metric_with_benchmark.retrieve_async_results() -def test_retrieve_granular_results_per_leaderboard(metric_with_benchmark): - """Test that granular results returns one result per leaderboard.""" - lb = MagicMock() - lb.name = "quality" - lb.instruction = "Rate quality" - lb.get_standings.return_value = { - "name": ["model-a"], - "score": [0.9], - } - metric_with_benchmark.benchmark.leaderboards = [lb] - results = metric_with_benchmark.retrieve_granular_results() - assert len(results) == 1 - assert results[0].name == "quality" - assert results[0].params == {"instruction": "Rate quality"} - assert results[0].result == {"model-a": 0.9} +@patch("pruna.evaluation.metrics.metric_rapiddata.time") +def test_retrieve_blocking_polls_until_ready(mock_time, metric_with_benchmark): + _clock = iter(range(0, 1000, 10)) + mock_time.monotonic.side_effect = lambda: next(_clock) + 
standings = {"name": ["m-a"], "score": [0.9]} + metric_with_benchmark.benchmark.get_overall_standings.side_effect = [ + _FakeValidationError(), _FakeValidationError(), standings, + ] + result = metric_with_benchmark.retrieve_async_results(is_blocking=True, timeout=60, poll_interval=5) + assert isinstance(result, CompositeMetricResult) + assert result.result == {"m-a": 0.9} + assert mock_time.sleep.call_count == 2 -def test_retrieve_granular_results_raises_without_benchmark(metric): - """Test that granular results raises without a benchmark.""" - with pytest.raises(ValueError, match="No benchmark configured"): - metric.retrieve_granular_results() \ No newline at end of file + +@patch("pruna.evaluation.metrics.metric_rapiddata.time") +def test_retrieve_blocking_raises_timeout(mock_time, metric_with_benchmark): + _clock = iter(range(0, 1000, 30)) + mock_time.monotonic.side_effect = lambda: next(_clock) + metric_with_benchmark.benchmark.get_overall_standings.side_effect = _FakeValidationError() + + with pytest.raises(TimeoutError, match="not ready after 60s"): + metric_with_benchmark.retrieve_async_results(is_blocking=True, timeout=60, poll_interval=5) + + +def test_create_benchmark_forwards_explicit_data_assets(metric, mock_client): + """Explicit data_assets are forwarded as prompt_assets.""" + prompts = ["edit this", "fix that"] + assets = ["/imgs/a.png", "/imgs/b.png"] + metric.create_benchmark("bench", data=prompts, data_assets=assets) + mock_client.mri.create_new_benchmark.assert_called_once_with( + "bench", prompts=prompts, prompt_assets=assets, + ) + + +def test_create_benchmark_datamodule_extracts_images(metric, mock_client): + """PrunaDataModule with an 'image' column extracts and converts images to prompt_assets.""" + from datasets import Features, Image as HFImage, Value + img1 = PIL.Image.new("RGB", (32, 32), "red") + img2 = PIL.Image.new("RGB", (32, 32), "blue") + ds = Dataset.from_dict( + {"text": ["prompt1", "prompt2"], "image": [img1, img2]}, + 
features=Features({"text": Value("string"), "image": HFImage()}), + ) + dm = PrunaDataModule(train_ds=ds, val_ds=ds, test_ds=ds, collate_fn=lambda x: x, dataloader_args={}) + with patch.object(metric, "_prepare_media_for_upload", return_value=["/tmp/0.png", "/tmp/1.png"]) as mock_prep: + metric.create_benchmark("my-bench", data=dm, split="test") + mock_prep.assert_called_once() + images_arg = mock_prep.call_args[0][0] + assert len(images_arg) == 2 + assert all(isinstance(img, PIL.Image.Image) for img in images_arg) + mock_client.mri.create_new_benchmark.assert_called_once_with( + "my-bench", prompts=["prompt1", "prompt2"], prompt_assets=["/tmp/0.png", "/tmp/1.png"], + ) + + +def test_create_benchmark_datamodule_ignores_explicit_data_assets(metric, mock_client): + """When using a PrunaDataModule, explicit data_assets are overridden.""" + ds = Dataset.from_dict({"text": ["p1"]}) + dm = PrunaDataModule(train_ds=ds, val_ds=ds, test_ds=ds, collate_fn=lambda x: x, dataloader_args={}) + metric.create_benchmark("bench", data=dm, data_assets=["/should/be/ignored.png"]) + mock_client.mri.create_new_benchmark.assert_called_once_with( + "bench", prompts=["p1"], prompt_assets=None, + ) + + +def test_create_request_forwards_show_prompt_assets_true(metric_with_benchmark): + """show_prompt_assets=True is forwarded to create_leaderboard.""" + metric_with_benchmark.create_async_request("quality", "Rate quality", show_prompt_assets=True) + metric_with_benchmark.benchmark.create_leaderboard.assert_called_once_with( + "quality", "Rate quality", False, True, + ) + + +def test_prepare_media_uses_explicit_list_over_cache(metric_ready): + """Passing an explicit media list uses it instead of media_cache.""" + metric_ready.media_cache = [torch.rand(3, 32, 32)] # should be ignored + explicit = [PIL.Image.new("RGB", (16, 16))] + + paths = metric_ready._prepare_media_for_upload(explicit) + assert len(paths) == 1 + assert Path(paths[0]).exists() + loaded = PIL.Image.open(paths[0]) + assert 
loaded.size == (16, 16) + metric_ready._cleanup_temp_media() \ No newline at end of file From a772e7c1362726785e8009bb938b3dee32620ae6 Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Mon, 23 Mar 2026 18:31:26 +0000 Subject: [PATCH 6/7] refactor: add mixin for setting context --- src/pruna/evaluation/evaluation_agent.py | 6 +- src/pruna/evaluation/metrics/context_mixin.py | 62 +++++++++++++++++++ .../evaluation/metrics/metric_rapiddata.py | 37 ++++++----- .../evaluation/metrics/metric_stateful.py | 19 ------ tests/evaluation/test_rapidata.py | 12 ++-- 5 files changed, 92 insertions(+), 44 deletions(-) create mode 100644 src/pruna/evaluation/metrics/context_mixin.py diff --git a/src/pruna/evaluation/evaluation_agent.py b/src/pruna/evaluation/evaluation_agent.py index ba83c863..674bf962 100644 --- a/src/pruna/evaluation/evaluation_agent.py +++ b/src/pruna/evaluation/evaluation_agent.py @@ -26,6 +26,7 @@ from pruna.data.utils import move_batch_to_device from pruna.engine.pruna_model import PrunaModel from pruna.engine.utils import get_device, move_to_device, safe_memory_cleanup, set_to_best_available_device +from pruna.evaluation.metrics.context_mixin import EvaluationContextMixin from pruna.evaluation.metrics.metric_base import BaseMetric from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.evaluation.metrics.result import MetricResult, MetricResultProtocol @@ -148,8 +149,9 @@ def evaluate(self, model: Any, model_name: str | None = None) -> List[MetricResu pairwise_metrics = self.task.get_pairwise_stateful_metrics() stateless_metrics = self.task.get_stateless_metrics() - for metric in single_stateful_metrics + pairwise_metrics: - metric.set_current_context(model_name=model_name) + for metric in single_stateful_metrics: + if isinstance(metric, EvaluationContextMixin): + metric.current_context = model_name # Update and compute stateful metrics. 
pruna_logger.info("Evaluating stateful metrics.") diff --git a/src/pruna/evaluation/metrics/context_mixin.py b/src/pruna/evaluation/metrics/context_mixin.py new file mode 100644 index 00000000..ea4f5b2c --- /dev/null +++ b/src/pruna/evaluation/metrics/context_mixin.py @@ -0,0 +1,62 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from abc import ABC + + +class EvaluationContextMixin(ABC): + """ + Mixin for metrics that evaluate multiple models sequentially. + + Provides a current_context property that tracks which model is being + evaluated. Setting a new context triggers on_context_change(), which + subclasses can override to reset state between models. + """ + + _current_context: str | None = None + + @property + def current_context(self) -> str | None: + """ + Return the current context. + + Returns + ------- + str | None + The current context. + """ + return self._current_context + + @current_context.setter + def current_context(self, value: str | None) -> None: + """ + Set the current context. + + Parameters + ---------- + value : str + The new context. + """ + self._current_context = value + self.on_context_change() + + def on_context_change(self) -> None: + """Hook called when the context changes. Override to reset state.""" + pass + + def _require_context(self) -> None: + """Raise if no context has been set.""" + if self._current_context is None: + raise ValueError("No context set. 
Set current_context first.") diff --git a/src/pruna/evaluation/metrics/metric_rapiddata.py b/src/pruna/evaluation/metrics/metric_rapiddata.py index 76b20552..d715373a 100644 --- a/src/pruna/evaluation/metrics/metric_rapiddata.py +++ b/src/pruna/evaluation/metrics/metric_rapiddata.py @@ -29,6 +29,7 @@ from pruna.data.pruna_datamodule import PrunaDataModule from pruna.evaluation.metrics.async_mixin import AsyncEvaluationMixin +from pruna.evaluation.metrics.context_mixin import EvaluationContextMixin from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.evaluation.metrics.result import CompositeMetricResult from pruna.evaluation.metrics.utils import PAIRWISE, SINGLE, get_call_type_for_single_metric, metric_data_processor @@ -39,14 +40,14 @@ # We don't use the MetricRegistry here # because we need to instantiate the Metric directly with benchmark and leaderboards. -class RapidataMetric(StatefulMetric, AsyncEvaluationMixin): +class RapidataMetric(StatefulMetric, AsyncEvaluationMixin, EvaluationContextMixin): """ Evaluate models with human feedback via the Rapidata platform https://www.rapidata.ai/. Parameters ---------- call_type : str - How to extract inputs from (x, gt, outputs). Default is "single". + How to extract inputs from (x, gt, outputs). Only "single" is supported. client : RapidataClient | None The Rapidata client to use. If None, a new one is created. rapidata_client_id : str | None @@ -112,7 +113,7 @@ def __init__( self.add_state("media_cache", default=[]) self.add_state("prompt_cache", default=[]) self.benchmark: RapidataBenchmark | None = None - self.current_benchmarked_model: str | None = None + self.current_context: str | None = None @classmethod def from_rapidata_benchmark( @@ -225,7 +226,7 @@ def create_async_request( You can create multiple leaderboards to evaluate different quality dimensions. 
Must be called after :meth:`create_benchmark` (or after attaching a - benchmark via :meth:`from_benchmark` / :meth:`from_benchmark_id`). + benchmark via :meth:`from_rapidata_benchmark`). Parameters ---------- @@ -258,7 +259,7 @@ def set_current_context(self, model_name: str, **kwargs) -> None: **kwargs : Any Additional keyword arguments. """ - self.current_benchmarked_model = model_name + self.current_context = model_name self.reset() # Clear the cache for the new model. def update(self, x: List[Any] | Tensor, gt: List[Any] | Tensor, outputs: Any) -> None: @@ -275,7 +276,7 @@ def update(self, x: List[Any] | Tensor, gt: List[Any] | Tensor, outputs: Any) -> The model outputs (generated media). """ self._require_benchmark() - self._require_model() + self._require_context() inputs = metric_data_processor(x, gt, outputs, self.call_type) self.prompt_cache.extend(inputs[0]) self.media_cache.extend(inputs[1]) @@ -294,15 +295,15 @@ def compute(self) -> None: collected. """ self._require_benchmark() - self._require_model() + self._require_context() if not self.media_cache: raise ValueError("No data accumulated. Call update() before compute().") media = self._prepare_media_for_upload() - # Ignoring the type error because _require_model() has already been called, but ty can't see it. + # Ignoring the type error because _require_context() has already been called, but ty can't see it. 
self.benchmark.evaluate_model( - self.current_benchmarked_model, # type: ignore[arg-type] + self.current_context, # type: ignore[arg-type] media=media, prompts=self.prompt_cache, ) @@ -315,10 +316,14 @@ def compute(self) -> None: "Use retrieve_results() to check scores later, " "or monitor progress at: " "https://app.rapidata.ai/mri/benchmarks/%s", - self.current_benchmarked_model, + self.current_context, self.benchmark.id, ) + def on_context_change(self) -> None: + """Reset the cache when the context changes.""" + self.reset() + @staticmethod def _is_not_ready_error(exc: Exception) -> bool: """ @@ -518,13 +523,6 @@ def _require_benchmark(self) -> None: "Call create_benchmark(), or use from_benchmark() / from_benchmark_id()." ) - def _require_model(self) -> None: - """Raise if no model context has been set.""" - if self.current_benchmarked_model is None: - raise ValueError( - "No model set. Call set_current_context() first." - ) - def _prepare_media_for_upload(self, media: list[torch.Tensor | PIL.Image.Image | str] | None = None) -> list[str]: """ Convert cached media to file paths that Rapidata can upload. @@ -534,6 +532,11 @@ def _prepare_media_for_upload(self, media: list[torch.Tensor | PIL.Image.Image | - PIL.Image: saved to a temporary file - torch.Tensor: saved to a temporary file + Parameters + ---------- + media : list[torch.Tensor | PIL.Image.Image | str] | None + The media to prepare for upload. If None, the media cache is used. + Returns ------- list[str] diff --git a/src/pruna/evaluation/metrics/metric_stateful.py b/src/pruna/evaluation/metrics/metric_stateful.py index 3464aaef..02982980 100644 --- a/src/pruna/evaluation/metrics/metric_stateful.py +++ b/src/pruna/evaluation/metrics/metric_stateful.py @@ -121,25 +121,6 @@ def update(self, *args, **kwargs) -> None: The keyword arguments to pass to the metric. """ - def set_current_context(self, *args, **kwargs) -> None: - """ - Set the current benchmarked model for the metric. 
- - Override this in subclasses that need to track which model or - configuration is being evaluated, such as async metrics that - submit results to external services. - - By default, this is a no-op. - - Parameters - ---------- - *args : Any - The arguments to pass to the metric. - **kwargs : Any - The keyword arguments to pass to the metric. - """ - pass - @abstractmethod def compute(self,) -> Any: """Override this method to compute the final metric value.""" diff --git a/tests/evaluation/test_rapidata.py b/tests/evaluation/test_rapidata.py index 0985fbf9..ba736d3a 100644 --- a/tests/evaluation/test_rapidata.py +++ b/tests/evaluation/test_rapidata.py @@ -34,7 +34,7 @@ def metric_with_benchmark(metric): @pytest.fixture def metric_ready(metric_with_benchmark): - metric_with_benchmark.set_current_context("test-model") + metric_with_benchmark.current_context = "test-model" return metric_with_benchmark @pytest.fixture @@ -120,7 +120,7 @@ def test_set_current_context_resets_caches(metric_ready): """Test that set_current_context resets the caches.""" metric_ready.prompt_cache.append("leftover") metric_ready.media_cache.append("leftover") - metric_ready.set_current_context("model-b") + metric_ready.current_context = "model-b" assert metric_ready.prompt_cache == [] assert metric_ready.media_cache == [] @@ -138,14 +138,14 @@ def test_update_accumulates_prompts_and_media(metric_ready): def test_update_raises_without_benchmark(metric): """Test that update raises without a benchmark.""" - metric.current_benchmarked_model = "m" + metric.current_context = "m" with pytest.raises(ValueError, match="No benchmark configured"): metric.update(["p"], [None], [torch.rand(3, 32, 32)]) -def test_update_raises_without_model(metric_with_benchmark): +def test_update_raises_without_context(metric_with_benchmark): """Test that update raises without a model context.""" - with pytest.raises(ValueError, match="No model set"): + with pytest.raises(ValueError, match="No context set. 
Set current_context first."): metric_with_benchmark.update(["p"], [None], [torch.rand(3, 32, 32)]) def test_prepare_media_string_passthrough(metric_ready_with_cleanup): @@ -211,7 +211,7 @@ def test_compute_raises_when_cache_empty(metric_ready): def test_compute_raises_without_model_context(metric_with_benchmark): """Test that compute raises without a model context.""" - with pytest.raises(ValueError, match="No model set"): + with pytest.raises(ValueError, match="No context set. Set current_context first."): metric_with_benchmark.compute() From fdae70dcafaa718f5dba364665347fa002ea7f15 Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Mon, 23 Mar 2026 19:14:55 +0000 Subject: [PATCH 7/7] ci: add evaluation as an umbrella dep --- .github/actions/setup-uv-project/action.yml | 2 +- pyproject.toml | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/actions/setup-uv-project/action.yml b/.github/actions/setup-uv-project/action.yml index 10d95d33..a5f9a7fa 100644 --- a/.github/actions/setup-uv-project/action.yml +++ b/.github/actions/setup-uv-project/action.yml @@ -12,4 +12,4 @@ runs: github-token: ${{ github.token }} - shell: bash - run: uv sync --extra dev --extra lmharness --extra vllm --extra rapidata + run: uv sync --extra dev --extra vllm --extra evaluation diff --git a/pyproject.toml b/pyproject.toml index 33c47332..b22eebb7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -225,6 +225,10 @@ cpu = [] lmharness = [ "lm-eval>=0.4.0" ] +evaluation = [ + "pruna[rapidata]", + "pruna[lmharness]" +] intel = [ "intel-extension-for-pytorch>=2.7.0", ]