Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/actions/setup-uv-project/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ runs:
github-token: ${{ github.token }}

- shell: bash
run: uv sync --extra dev --extra lmharness --extra vllm
run: uv sync --extra dev --extra vllm --extra evaluation
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ full = [
vbench = [
"vbench-pruna; sys_platform != 'darwin'",
]
rapidata = [
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

alright, so we already start with the extra separation, nice!
@begumcig I know that we can also do something like.

evaluation = [
    rapidata,
    vbench
]
could be nice to already start structuring like this, right?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I don't think vbench and rapidata have a lot of shared dependencies, so doesn't really make sense to me to group them together

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assumed it could be a shared dependency group for all evaluation metrics. You can add extras to extras, but perhaps you'd like to keep them separate?

"rapidata>=3.0.0"
]
dev = [
"wget",
"python-dotenv",
Expand Down Expand Up @@ -222,6 +225,10 @@ cpu = []
lmharness = [
"lm-eval>=0.4.0"
]
evaluation = [
"pruna[rapidata]",
"pruna[lmharness]"
]
intel = [
"intel-extension-for-pytorch>=2.7.0",
]
Expand Down
31 changes: 21 additions & 10 deletions src/pruna/evaluation/evaluation_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@
from pruna.data.utils import move_batch_to_device
from pruna.engine.pruna_model import PrunaModel
from pruna.engine.utils import get_device, move_to_device, safe_memory_cleanup, set_to_best_available_device
from pruna.evaluation.metrics.context_mixin import EvaluationContextMixin
from pruna.evaluation.metrics.metric_base import BaseMetric
from pruna.evaluation.metrics.metric_stateful import StatefulMetric
from pruna.evaluation.metrics.result import MetricResult
from pruna.evaluation.metrics.result import MetricResult, MetricResultProtocol
from pruna.evaluation.metrics.utils import ensure_device_consistency, get_device_map, group_metrics_by_inheritance
from pruna.evaluation.task import Task
from pruna.logging.logger import pruna_logger
Expand Down Expand Up @@ -71,8 +72,8 @@ def __init__(
raise ValueError("When not using 'task' parameter, both 'request' and 'datamodule' must be provided.")
self.task = Task(request=request, datamodule=datamodule, device=device)

self.first_model_results: List[MetricResult] = []
self.subsequent_model_results: List[MetricResult] = []
self.first_model_results: List[MetricResultProtocol] = []
self.subsequent_model_results: List[MetricResultProtocol] = []
self.device = set_to_best_available_device(self.task.device)
self.cache: List[Tensor] = []
self.evaluation_for_first_model: bool = True
Expand Down Expand Up @@ -124,18 +125,20 @@ def from_benchmark(
)
return cls(task=task)

def evaluate(self, model: Any) -> List[MetricResult]:
def evaluate(self, model: Any, model_name: str | None = None) -> List[MetricResultProtocol]:
"""
Evaluate models using different metric types.

Parameters
----------
model : PrunaModel
model : Any
The model to evaluate.
model_name : str | None, optional
The name of the model to evaluate. Required for rapidata benchmark submission.

Returns
-------
List[MetricResult]
List[MetricResultProtocol]
The results of the model.
"""
results = []
Expand All @@ -146,6 +149,10 @@ def evaluate(self, model: Any) -> List[MetricResult]:
pairwise_metrics = self.task.get_pairwise_stateful_metrics()
stateless_metrics = self.task.get_stateless_metrics()

for metric in single_stateful_metrics:
if isinstance(metric, EvaluationContextMixin):
metric.current_context = model_name

# Update and compute stateful metrics.
pruna_logger.info("Evaluating stateful metrics.")
with torch.no_grad():
Expand Down Expand Up @@ -278,7 +285,7 @@ def update_stateful_metrics(

def compute_stateful_metrics(
self, single_stateful_metrics: List[StatefulMetric], pairwise_metrics: List[StatefulMetric]
) -> List[MetricResult]:
) -> List[MetricResultProtocol]:
"""
Compute stateful metrics.

Expand All @@ -296,16 +303,20 @@ def compute_stateful_metrics(
"""
results = []
for stateful_metric in single_stateful_metrics:
results.append(stateful_metric.compute())
result = stateful_metric.compute()
if result is not None:
results.append(result)
stateful_metric.reset()

if not self.evaluation_for_first_model and self.task.is_pairwise_evaluation():
for pairwise_metric in pairwise_metrics:
results.append(pairwise_metric.compute())
result = pairwise_metric.compute()
if result is not None:
results.append(result)
pairwise_metric.reset()
return results

def compute_stateless_metrics(self, model: PrunaModel, stateless_metrics: List[Any]) -> List[MetricResult]:
def compute_stateless_metrics(self, model: PrunaModel, stateless_metrics: List[Any]) -> List[MetricResultProtocol]:
"""
Compute stateless metrics.

Expand Down
9 changes: 9 additions & 0 deletions src/pruna/evaluation/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@
from pruna.evaluation.metrics.metric_sharpness import SharpnessMetric
from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper

# Rapidata support is an optional extra: import its metric only when the
# `rapidata` package is installed.  The `X as X` alias marks the name as an
# intentional public re-export for type checkers.
# NOTE(review): the module file is spelled "metric_rapiddata" (double "d")
# while the package/extra is "rapidata" — confirm the filename is intentional.
try:
    from pruna.evaluation.metrics.metric_rapiddata import RapidataMetric as RapidataMetric
except ModuleNotFoundError as e:
    # Swallow only a missing `rapidata` dependency; any other missing module
    # is a real error and must propagate.
    if e.name != "rapidata":
        raise

__all__ = [
"MetricRegistry",
"TorchMetricWrapper",
Expand All @@ -46,3 +52,6 @@
"AestheticLAION",
"LMEvalMetric",
]

if "RapidataMetric" in globals():
__all__.append("RapidataMetric")
53 changes: 53 additions & 0 deletions src/pruna/evaluation/metrics/async_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from abc import ABC, abstractmethod
from typing import Any


class AsyncEvaluationMixin(ABC):
    """
    Mixin for metrics that submit to external evaluation services and retrieve results asynchronously.

    Subclasses implement create_async_request() to set up an evaluation
    (e.g., create a leaderboard) and retrieve_async_results() to retrieve
    outcomes (e.g., standings from human evaluators).
    """

    @abstractmethod
    def create_async_request(self, *args, **kwargs) -> Any:
        """
        Create/configure an evaluation request on the external service.

        Parameters
        ----------
        *args :
            Variable length argument list.
        **kwargs :
            Arbitrary keyword arguments.
        """

    @abstractmethod
    def retrieve_async_results(self, *args, **kwargs) -> Any:
        """
        Retrieve results from the external service.

        Parameters
        ----------
        *args :
            Variable length argument list.
        **kwargs :
            Arbitrary keyword arguments.
        """
62 changes: 62 additions & 0 deletions src/pruna/evaluation/metrics/context_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from abc import ABC


class EvaluationContextMixin(ABC):
"""
Mixin for metrics that evaluate multiple models sequentially.

Provides a current_context property that tracks which model is being
evaluated. Setting a new context triggers on_context_change(), which
subclasses can override to reset state between models.
"""

_current_context: str | None = None

@property
def current_context(self) -> str | None:
"""
Return the current context.

Returns
-------
str | None
The current context.
"""
return self._current_context

@current_context.setter
def current_context(self, value: str | None) -> None:
"""
Set the current context.

Parameters
----------
value : str
The new context.
"""
self._current_context = value
self.on_context_change()
Comment on lines +52 to +53
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what do you think about checking if the current_context value actually changed an only triggering then?
back in my frontend days, it was always done like this, but completely up to you

Suggested change
self._current_context = value
self.on_context_change()
if self._current_context != value:
self._current_context = value
self.on_context_change()


def on_context_change(self) -> None:
"""Hook called when the context changes. Override to reset state."""
pass

def _require_context(self) -> None:
"""Raise if no context has been set."""
if self._current_context is None:
raise ValueError("No context set. Set current_context first.")
Loading
Loading