diff --git a/src/pruna/evaluation/metrics/metric_evalharness.py b/src/pruna/evaluation/metrics/metric_evalharness.py
index 95e083ef..c3771abe 100644
--- a/src/pruna/evaluation/metrics/metric_evalharness.py
+++ b/src/pruna/evaluation/metrics/metric_evalharness.py
@@ -17,8 +17,15 @@
 from typing import Any, List, Tuple
 
 import torch
-from lm_eval.api import metrics  # noqa: F401  # needed to register lm-eval metrics
-from lm_eval.api import registry as lm_registry
+
+try:
+    # lm_eval.api.metrics import is needed to register lm-eval metrics
+    from lm_eval.api import metrics as _lm_metrics  # noqa: F401
+    from lm_eval.api import registry as lm_registry
+
+    _LM_EVAL_AVAILABLE = True
+except ImportError:
+    _LM_EVAL_AVAILABLE = False
 
 from pruna.evaluation.metrics.metric_stateful import StatefulMetric
 from pruna.evaluation.metrics.registry import MetricRegistry
@@ -29,10 +36,7 @@ METRIC_EVALHARNESS = "lm_eval_metric"
 
 # Perplexity from lm_eval requires the call type to be "y"
 # So requires us to handle turning the vector logits into a single log prob.
-LM_EVAL_METRICS = [
-    m for m in lm_registry.METRIC_REGISTRY
-    if m != "perplexity"  # We use the torchmetrics implementation for perplexity
-]
+LM_EVAL_METRICS = [m for m in lm_registry.METRIC_REGISTRY if m != "perplexity"] if _LM_EVAL_AVAILABLE else []
 
 
 @MetricRegistry.register_wrapper(available_metrics=LM_EVAL_METRICS)
@@ -54,6 +58,8 @@ class LMEvalMetric(StatefulMetric):
     pairs: List[Tuple[Any, Any]]  # dynamically added by add_state()
 
     def __init__(self, metric_name: str, call_type: str = "y_gt") -> None:
+        if not _LM_EVAL_AVAILABLE:
+            raise ImportError("lm-eval is required for LMEvalMetric. Install it with: pip install 'pruna[lmharness]'")
         super().__init__()
         self.metric_name = metric_name
         self.call_type = call_type