Skip to content
Open
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ gptqmodel = [
{ index = "pypi", marker = "sys_platform == 'darwin' and platform_machine == 'arm64'"},
]

[tool.uv]
extra-build-dependencies = { image-reward = ["setuptools"] }

[project]
name = "pruna"
version = "0.3.1"
Expand Down Expand Up @@ -141,6 +144,9 @@ dependencies = [
]

[project.optional-dependencies]
evaluation = [
"image-reward"
]
stable-fast = [
"xformers>=0.0.30",
"stable-fast-pruna==1.0.8",
Expand Down
2 changes: 2 additions & 0 deletions src/pruna/evaluation/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pruna.evaluation.metrics.metric_dino_score import DinoScore
from pruna.evaluation.metrics.metric_elapsed_time import LatencyMetric, ThroughputMetric, TotalTimeMetric
from pruna.evaluation.metrics.metric_energy import CO2EmissionsMetric, EnergyConsumedMetric
from pruna.evaluation.metrics.metric_image_reward import ImageRewardMetric
from pruna.evaluation.metrics.metric_memory import DiskMemoryMetric, InferenceMemoryMetric, TrainingMemoryMetric
from pruna.evaluation.metrics.metric_model_architecture import TotalMACsMetric, TotalParamsMetric
from pruna.evaluation.metrics.metric_pairwise_clip import PairwiseClipScore
Expand All @@ -43,4 +44,5 @@
"DinoScore",
"SharpnessMetric",
"AestheticLAION",
"ImageRewardMetric",
]
141 changes: 141 additions & 0 deletions src/pruna/evaluation/metrics/metric_image_reward.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Image Reward Metric for Pruna.

This metric computes image reward scores using the ImageReward library.

Based on the InferBench implementation:
https://github.com/PrunaAI/InferBench
"""

from __future__ import annotations

from typing import Any, List

import torch
from PIL import Image

from pruna.engine.utils import set_to_best_available_device
from pruna.evaluation.metrics.metric_stateful import StatefulMetric
from pruna.evaluation.metrics.registry import MetricRegistry
from pruna.evaluation.metrics.result import MetricResult
from pruna.evaluation.metrics.utils import SINGLE, get_call_type_for_single_metric, metric_data_processor

import ImageReward as ImageRewardModule

METRIC_IMAGE_REWARD = "image_reward"


@MetricRegistry.register(METRIC_IMAGE_REWARD)
class ImageRewardMetric(StatefulMetric):
"""
Image Reward metric for evaluating image-text alignment.

This metric uses the ImageReward model to compute how well generated images
match their text prompts based on learned human preferences.
Higher scores indicate better alignment.

Reference
----------
ImageReward: https://github.com/thaosu/ImageReward

Parameters
----------
*args : Any
Additional arguments.
device : str | torch.device | None, optional
The device to be used, e.g., 'cuda' or 'cpu'. Default is None.
If None, the best available device will be used.
model_name : str, optional
The ImageReward model to use. Default is "ImageReward-v1.0".
call_type : str, optional
The type of call to use for the metric.
**kwargs : Any
Additional keyword arguments.
"""

scores: List[float]
default_call_type: str = "y"
higher_is_better: bool = True
metric_name: str = METRIC_IMAGE_REWARD
runs_on: List[str] = ["cpu", "cuda", "mps"]

def __init__(
self,
*args,
device: str | torch.device | None = None,
model_name: str = "ImageReward-v1.0",
call_type: str = SINGLE,
**kwargs,
) -> None:
super().__init__(device=device)
self.device = set_to_best_available_device(device)
self.model_name = model_name
self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type)
self.model = ImageRewardModule.load(self.model_name, device=str(self.device))
self.add_state("scores", [])

@torch.no_grad()
def update(self, x: List[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None:
"""
Update the metric with new batch data.

Parameters
----------
x : List[Any] | torch.Tensor
The input data (text prompts).
gt : torch.Tensor
The ground truth / cached images.
outputs : torch.Tensor
The output images to score.
"""
inputs = metric_data_processor(x, gt, outputs, self.call_type)
images = inputs[0]
prompts = x if isinstance(x, list) else [""] * len(images)

for i, image in enumerate(images):
if isinstance(image, torch.Tensor):
image = self._tensor_to_pil(image)

prompt = prompts[i] if i < len(prompts) else ""
score = self.model.score(prompt, image)
self.scores.append(score)

def compute(self) -> MetricResult:
"""
Compute the mean ImageReward metric.

Returns
-------
MetricResult
The mean ImageReward metric.
"""
if not self.scores:
return MetricResult(self.metric_name, self.__dict__, 0.0)

import numpy as np

mean_score = float(np.mean(self.scores))
return MetricResult(self.metric_name, self.__dict__, mean_score)

def _tensor_to_pil(self, tensor: torch.Tensor) -> Image.Image:
"""Convert tensor to PIL Image."""
if tensor.ndim == 4:
tensor = tensor[0]
if tensor.max() > 1:
tensor = tensor / 255.0
np_img = (tensor.cpu().numpy() * 255).astype("uint8")
return Image.fromarray(np_img.transpose(1, 2, 0))
Loading