From 8ef34a053a26a7ee29bf67517c40149021758fbf Mon Sep 17 00:00:00 2001
From: Rodrigo Almeida
Date: Tue, 13 Jan 2026 13:35:40 +0100
Subject: [PATCH 1/2] Add ROC skill metrics and tests

---
 src/extremeweatherbench/metrics.py | 51 ++++++++++++++++
 tests/test_metrics.py              | 95 ++++++++++++++++++++++++++++++
 2 files changed, 146 insertions(+)

diff --git a/src/extremeweatherbench/metrics.py b/src/extremeweatherbench/metrics.py
index 21bed2b5..c93a87c1 100644
--- a/src/extremeweatherbench/metrics.py
+++ b/src/extremeweatherbench/metrics.py
@@ -658,6 +658,57 @@ def _compute_metric(
         return transformed.accuracy()
 
 
+class ReceiverOperatingCharacteristic(ThresholdMetric):
+    """Receiver Operating Characteristic metric."""
+
+    def __init__(self, name: str = "ReceiverOperatingCharacteristic", *args, **kwargs):
+        super().__init__(name, *args, **kwargs)
+
+    def _compute_metric(
+        self,
+        forecast: xr.DataArray,
+        target: xr.DataArray,
+        **kwargs: Any,
+    ) -> Any:
+        preserve_dims = kwargs.get("preserve_dims", self.preserve_dims)
+        op_func = utils.maybe_get_operator(kwargs.get("op_func", operator.ge))
+
+        # Binarize forecast and target using thresholds
+        binary_forecast = utils.maybe_densify_dataarray(
+            op_func(forecast, kwargs.get("forecast_threshold", self.forecast_threshold))
+        ).astype(float)
+        binary_target = utils.maybe_densify_dataarray(
+            op_func(target, kwargs.get("target_threshold", self.target_threshold))
+        ).astype(float)
+
+        return scores.probability.roc_curve_data(
+            binary_forecast,
+            binary_target,
+            thresholds="auto",
+            preserve_dims=preserve_dims,
+            weights=None,
+        )
+
+
+class ReceiverOperatingCharacteristicSkillScore(ReceiverOperatingCharacteristic):
+    """Receiver Operating Characteristic Skill Score metric."""
+
+    def __init__(
+        self, name: str = "ReceiverOperatingCharacteristicSkillScore", *args, **kwargs
+    ):
+        super().__init__(name, *args, **kwargs)
+
+    def _compute_metric(
+        self,
+        forecast: xr.DataArray,
+        target: xr.DataArray,
+        auc_reference: float = 0.5,
+        **kwargs: Any,
+    ) -> Any:
+        roc_curve_data = super()._compute_metric(forecast, target, **kwargs)
+        return (roc_curve_data["AUC"] - auc_reference) / (1 - auc_reference)
+
+
 class MeanSquaredError(BaseMetric):
     """Mean Squared Error metric.
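The new ROC metric works in two steps: it binarizes forecast and target against their respective thresholds with a comparison operator (operator.ge by default), then passes the binary fields to scores.probability.roc_curve_data. A minimal sketch of the thresholding step only, assuming plain dense xarray inputs and skipping the repo-internal utils.maybe_get_operator and utils.maybe_densify_dataarray wrappers; the data values are illustrative:

import operator

import xarray as xr

forecast = xr.DataArray([0.8, 0.3, 0.7, 0.2], dims=["sample"])
target = xr.DataArray([0.9, 0.1, 0.8, 0.6], dims=["sample"])

# An event occurs where the value meets or exceeds the threshold.
binary_forecast = operator.ge(forecast, 0.5).astype(float)  # [1., 0., 1., 0.]
binary_target = operator.ge(target, 0.5).astype(float)      # [1., 0., 1., 1.]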
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index a7056806..70c4d756 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -250,6 +250,15 @@ def test_accuracy_threshold_metric(self):
         assert acc_metric.forecast_threshold == 15000
         assert acc_metric.target_threshold == 0.3
 
+    def test_roc_threshold_metric(self):
+        """Test ROC threshold metric instantiation and properties."""
+        roc_metric = metrics.ReceiverOperatingCharacteristic(
+            forecast_threshold=15000, target_threshold=0.3
+        )
+        assert isinstance(roc_metric, metrics.ThresholdMetric)
+        assert roc_metric.forecast_threshold == 15000
+        assert roc_metric.target_threshold == 0.3
+
     def test_threshold_metric_instance_interface(self):
         """Test that instance callable interface works."""
         # Create test data
@@ -1979,6 +1988,92 @@ def test_compute_metric(self):
         assert isinstance(result, xr.DataArray)
 
 
+class TestROCSS:
+    """Tests for the ROCSS metric."""
+
+    def test_instantiation(self):
+        """Test that ROCSS can be instantiated."""
+        metric = metrics.ReceiverOperatingCharacteristicSkillScore()
+        assert isinstance(metric, metrics.ReceiverOperatingCharacteristic)
+
+    def test_compute_metric(self):
+        """Test ROCSS computation."""
+        metric = metrics.ReceiverOperatingCharacteristicSkillScore(
+            forecast_threshold=0.5, target_threshold=0.5
+        )
+
+        forecast = xr.DataArray(
+            data=[0.8, 0.3, 0.7, 0.2],
+            dims=["lead_time"],
+            coords={"lead_time": [0, 1, 2, 3]},
+        )
+        target = xr.DataArray(
+            data=[0.9, 0.1, 0.8, 0.6],
+            dims=["lead_time"],
+            coords={"lead_time": [0, 1, 2, 3]},
+        )
+
+        result = metric._compute_metric(forecast, target)
+        assert isinstance(result, xr.DataArray)
+
+    def test_skill_score_zero_when_auc_matches_reference(self):
+        """ROCSS should be zero when AUC equals the reference value."""
+        metric = metrics.ReceiverOperatingCharacteristicSkillScore(
+            forecast_threshold=0.5, target_threshold=0.5, preserve_dims=None
+        )
+
+        forecast = xr.DataArray(
+            data=[0.8, 0.3, 0.7, 0.2],
+            dims=["sample"],
+            coords={"sample": [0, 1, 2, 3]},
+        )
+        target = xr.DataArray(
+            data=[0.9, 0.1, 0.8, 0.6],
+            dims=["sample"],
+            coords={"sample": [0, 1, 2, 3]},
+        )
+
+        roc_metric = metrics.ReceiverOperatingCharacteristic(
+            forecast_threshold=0.5, target_threshold=0.5, preserve_dims=None
+        )
+        roc_curve_data = roc_metric._compute_metric(forecast, target)
+        auc = roc_curve_data["AUC"]
+
+        auc_reference = float(auc)
+        result = metric._compute_metric(
+            forecast, target, auc_reference=auc_reference
+        )
+
+        xr.testing.assert_allclose(result, xr.zeros_like(auc))
+
+    def test_skill_score_scales_auc_above_reference(self):
+        """ROCSS scales the AUC improvement over the reference."""
+        forecast = xr.DataArray(
+            data=[0.9, 0.7, 0.6, 0.2],
+            dims=["sample"],
+            coords={"sample": [0, 1, 2, 3]},
+        )
+        target = xr.DataArray(
+            data=[0.8, 0.4, 0.9, 0.3],
+            dims=["sample"],
+            coords={"sample": [0, 1, 2, 3]},
+        )
+
+        roc_metric = metrics.ReceiverOperatingCharacteristic(
+            forecast_threshold=0.6, target_threshold=0.5, preserve_dims=None
+        )
+        roc_curve_data = roc_metric._compute_metric(forecast, target)
+        auc = roc_curve_data["AUC"]
+
+        metric = metrics.ReceiverOperatingCharacteristicSkillScore(
+            forecast_threshold=0.6, target_threshold=0.5, preserve_dims=None
+        )
+        result = metric._compute_metric(forecast, target, auc_reference=0.5)
+
+        expected = (auc - 0.5) / (1 - 0.5)
+        xr.testing.assert_allclose(result, expected)
+
+
 class TestMetricIntegration:
     """Integration tests for metric classes."""
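Together the two classes implement the skill score ROCSS = (AUC - AUC_ref) / (1 - AUC_ref): 1 for a perfect discriminator, 0 when the AUC matches the reference (0.5 by default, the diagonal no-skill ROC curve), and negative when it is worse. A minimal usage sketch mirroring TestROCSS above; it calls _compute_metric directly as the tests do, and the data values are illustrative only:

import xarray as xr

from extremeweatherbench import metrics

forecast = xr.DataArray([0.9, 0.7, 0.6, 0.2], dims=["sample"])
target = xr.DataArray([0.8, 0.4, 0.9, 0.3], dims=["sample"])

rocss = metrics.ReceiverOperatingCharacteristicSkillScore(
    forecast_threshold=0.6, target_threshold=0.5, preserve_dims=None
)
# With auc_reference=0.5, an AUC of 0.75 rescales to a skill score of 0.5.
result = rocss._compute_metric(forecast, target, auc_reference=0.5)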
From e3c27975e081af0fe64ace1e0532c06f28047d01 Mon Sep 17 00:00:00 2001
From: Rodrigo Almeida
Date: Wed, 14 Jan 2026 12:23:59 +0100
Subject: [PATCH 2/2] format

---
 tests/test_metrics.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 70c4d756..ca1b63f8 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -2040,9 +2040,7 @@ def test_skill_score_zero_when_auc_matches_reference(self):
         auc = roc_curve_data["AUC"]
 
         auc_reference = float(auc)
-        result = metric._compute_metric(
-            forecast, target, auc_reference=auc_reference
-        )
+        result = metric._compute_metric(forecast, target, auc_reference=auc_reference)
 
         xr.testing.assert_allclose(result, xr.zeros_like(auc))