From b28ef12c0e6601bb44e47b07d149ccdff3ddb007 Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Thu, 2 Apr 2026 17:05:40 +0000 Subject: [PATCH 1/2] ci: add guards for lm eval imports to make importing the metric okay --- .../evaluation/metrics/metric_evalharness.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/pruna/evaluation/metrics/metric_evalharness.py b/src/pruna/evaluation/metrics/metric_evalharness.py index 95e083ef..5bf72680 100644 --- a/src/pruna/evaluation/metrics/metric_evalharness.py +++ b/src/pruna/evaluation/metrics/metric_evalharness.py @@ -17,8 +17,13 @@ from typing import Any, List, Tuple import torch -from lm_eval.api import metrics # noqa: F401 # needed to register lm-eval metrics -from lm_eval.api import registry as lm_registry + +try: + from lm_eval.api import metrics as _lm_metrics # noqa: F401 + from lm_eval.api import registry as lm_registry + _LM_EVAL_AVAILABLE = True +except ImportError: + _LM_EVAL_AVAILABLE = False from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.evaluation.metrics.registry import MetricRegistry @@ -31,8 +36,8 @@ # So requires us to handle turning the vector logits into a single log prob. LM_EVAL_METRICS = [ m for m in lm_registry.METRIC_REGISTRY - if m != "perplexity" # We use the torchmetrics implementation for perplexity -] + if m != "perplexity" +] if _LM_EVAL_AVAILABLE else [] @MetricRegistry.register_wrapper(available_metrics=LM_EVAL_METRICS) @@ -54,6 +59,11 @@ class LMEvalMetric(StatefulMetric): pairs: List[Tuple[Any, Any]] # dynamically added by add_state() def __init__(self, metric_name: str, call_type: str = "y_gt") -> None: + if not _LM_EVAL_AVAILABLE: + raise ImportError( "lm-eval is required for LMEvalMetric. 
" + "Install it with: pip install 'pruna[lmharness]'" + ) super().__init__() self.metric_name = metric_name self.call_type = call_type From c3a01a3b70ddda5d6c6f1398f392aea152ed6adc Mon Sep 17 00:00:00 2001 From: Gaspar Rochette Date: Fri, 3 Apr 2026 10:38:28 +0000 Subject: [PATCH 2/2] style: add back comment explaining lm_eval.api.metrics import and slight reformating --- src/pruna/evaluation/metrics/metric_evalharness.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/pruna/evaluation/metrics/metric_evalharness.py b/src/pruna/evaluation/metrics/metric_evalharness.py index 5bf72680..c3771abe 100644 --- a/src/pruna/evaluation/metrics/metric_evalharness.py +++ b/src/pruna/evaluation/metrics/metric_evalharness.py @@ -19,8 +19,10 @@ import torch try: + # lm_eval.api.metrics import is needed to register lm-eval metrics from lm_eval.api import metrics as _lm_metrics # noqa: F401 from lm_eval.api import registry as lm_registry + _LM_EVAL_AVAILABLE = True except ImportError: _LM_EVAL_AVAILABLE = False @@ -34,10 +36,7 @@ METRIC_EVALHARNESS = "lm_eval_metric" # Perplexity from lm_eval requires the call type to be "y" # So requires us to handle turning the vector logits into a single log prob. -LM_EVAL_METRICS = [ - m for m in lm_registry.METRIC_REGISTRY - if m != "perplexity" -] if _LM_EVAL_AVAILABLE else [] +LM_EVAL_METRICS = [m for m in lm_registry.METRIC_REGISTRY if m != "perplexity"] if _LM_EVAL_AVAILABLE else [] @MetricRegistry.register_wrapper(available_metrics=LM_EVAL_METRICS) @@ -60,10 +59,7 @@ class LMEvalMetric(StatefulMetric): def __init__(self, metric_name: str, call_type: str = "y_gt") -> None: if not _LM_EVAL_AVAILABLE: - raise ImportError( - "lm-eval is required for LMEvalMetric. " - "Install it with: pip install 'pruna[lmharness]'" - ) + raise ImportError("lm-eval is required for LMEvalMetric. Install it with: pip install 'pruna[lmharness]'") super().__init__() self.metric_name = metric_name self.call_type = call_type