Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions src/pruna/evaluation/metrics/metric_evalharness.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,15 @@
from typing import Any, List, Tuple

import torch
from lm_eval.api import metrics # noqa: F401 # needed to register lm-eval metrics
from lm_eval.api import registry as lm_registry

try:
# lm_eval.api.metrics import is needed to register lm-eval metrics
from lm_eval.api import metrics as _lm_metrics # noqa: F401
from lm_eval.api import registry as lm_registry

_LM_EVAL_AVAILABLE = True
except ImportError:
_LM_EVAL_AVAILABLE = False

from pruna.evaluation.metrics.metric_stateful import StatefulMetric
from pruna.evaluation.metrics.registry import MetricRegistry
Expand All @@ -29,10 +36,7 @@
METRIC_EVALHARNESS = "lm_eval_metric"
# Perplexity from lm_eval requires the call type to be "y",
# which would require us to turn the vector of logits into a single log prob.
LM_EVAL_METRICS = [
m for m in lm_registry.METRIC_REGISTRY
if m != "perplexity" # We use the torchmetrics implementation for perplexity
]
# Exclude "perplexity": we use the torchmetrics implementation for perplexity instead.
# Falls back to an empty list when lm-eval is not installed (see _LM_EVAL_AVAILABLE guard).
LM_EVAL_METRICS = [m for m in lm_registry.METRIC_REGISTRY if m != "perplexity"] if _LM_EVAL_AVAILABLE else []


@MetricRegistry.register_wrapper(available_metrics=LM_EVAL_METRICS)
Expand All @@ -54,6 +58,8 @@ class LMEvalMetric(StatefulMetric):
pairs: List[Tuple[Any, Any]] # dynamically added by add_state()

def __init__(self, metric_name: str, call_type: str = "y_gt") -> None:
if not _LM_EVAL_AVAILABLE:
raise ImportError("lm-eval is required for LMEvalMetric. Install it with: pip install 'pruna[lmharness]'")
super().__init__()
self.metric_name = metric_name
self.call_type = call_type
Expand Down
Loading