diff --git a/improver/calibration/quantile_mapping.py b/improver/calibration/quantile_mapping.py
new file mode 100644
index 0000000000..6421b4eba9
--- /dev/null
+++ b/improver/calibration/quantile_mapping.py
@@ -0,0 +1,325 @@
+# (C) Crown Copyright, Met Office. All rights reserved.
+#
+# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license.
+# See LICENSE in the root of the repository for full licensing details.
+"""Module containing quantile mapping bias correction.
+
+Quantile mapping is a statistical calibration technique that adjusts forecast
+values to match the distribution of reference (observed) data. It works by:
+1. Finding each forecast value's position (quantile) in the forecast distribution
+2. Mapping that quantile to the corresponding value in the reference distribution
+
+This corrects systematic biases while preserving spatial patterns.
+"""
+
+from typing import Optional
+
+import numpy as np
+from iris.cube import Cube
+
+from improver import PostProcessingPlugin
+
+
+class QuantileMapping(PostProcessingPlugin):
+    """Apply quantile mapping bias correction to forecast data."""
+
+    def __init__(self, preservation_threshold: Optional[float] = None) -> None:
+        """Initialize the quantile mapping plugin.
+
+        Args:
+            preservation_threshold:
+                Optional threshold value below which (exclusive) the forecast
+                values are not adjusted to be like the reference. Useful for
+                variables such as precipitation, where a user may be wary of
+                mapping 0mm/hr precipitation values to non-zero values.
+        """
+        self.preservation_threshold = preservation_threshold
+
+    @staticmethod
+    def _build_empirical_cdf(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+        """Build empirical cumulative distribution function (CDF).
+
+        Args:
+            data: 1D array of input data values.
+
+        Returns:
+            Tuple of (sorted_values, quantiles) representing the empirical CDF.
+        """
+        sorted_values = np.sort(data)
+        num_points = sorted_values.shape[0]
+        quantiles = np.arange(1, num_points + 1) / num_points
+        return sorted_values, quantiles
+
+    @staticmethod
+    def _inverted_cdf(data: np.ndarray, quantiles: np.ndarray) -> np.ndarray:
+        """Get distribution values at specified quantiles (discrete step method).
+
+        Uses floored index lookup, rounding each quantile down to the nearest
+        available data point. This creates a step-function mapping that's faster
+        but less smooth than interpolation.
+
+        Taken from:
+        https://github.com/ecmwf-projects/ibicus/blob/main/ibicus/utils/_math_utils.py
+
+        Args:
+            data:
+                1D array of data values defining the distribution.
+            quantiles:
+                Quantiles to evaluate (values between 0 and 1).
+
+        Returns:
+            Values from the data corresponding to the requested quantiles.
+        """
+        sorted_values = np.sort(data)
+        num_points = sorted_values.shape[0]
+        floored_indices = np.array(
+            np.floor((num_points - 1) * quantiles), dtype=np.int32
+        )
+        return sorted_values[floored_indices]
+
+    def _map_quantiles(
+        self,
+        reference_data: np.ndarray,
+        forecast_data: np.ndarray,
+    ) -> np.ndarray:
+        """Transform forecast values to match the reference distribution.
+
+        For each forecast value:
+        1. Find its quantile position in the forecast distribution
+        2. Map that quantile to the corresponding value in the reference
+           distribution using discrete (floor) method
+
+        Example:
+            - reference_data: [10, 20, 30, 40, 50]
+            - forecast_data: [5, 15, 25, 35, 45]
+
+            The forecast systematically underestimates by 5 units.
+            Corrected values: [10, 20, 30, 40, 50] (mapped to reference distribution)
+
+        Args:
+            reference_data:
+                Target distribution (observed/historical data).
+            forecast_data:
+                Source distribution (biased forecasts to correct).
+
+        Returns:
+            Bias-corrected forecast values matching the reference distribution.
+        """
+        # Build empirical CDF for the forecast distribution
+        sorted_forecast_values, forecast_empirical_quantiles = (
+            self._build_empirical_cdf(forecast_data)
+        )
+
+        # Find where each forecast value sits in the forecast distribution
+        # (i.e., determine its quantile, clipped to [0, 1])
+        forecast_quantiles = np.interp(
+            forecast_data, sorted_forecast_values, forecast_empirical_quantiles
+        )
+
+        # Map the quantiles to values in the reference distribution
+        corrected_values = self._inverted_cdf(reference_data, forecast_quantiles)
+
+        return corrected_values
+
+    @staticmethod
+    def _convert_reference_cube_to_forecast_units(
+        reference_cube: Cube,
+        forecast_cube: Cube,
+    ) -> tuple[Cube, Cube]:
+        """Ensure reference cube uses the same units as forecast cube.
+
+        Args:
+            reference_cube:
+                The reference data cube.
+            forecast_cube:
+                The forecast data cube.
+
+        Returns:
+            Tuple of (reference_cube, forecast_cube) with matching units.
+
+        Raises:
+            ValueError: If units are incompatible and cannot be converted.
+        """
+        target_units = forecast_cube.units
+
+        # Convert reference_cube to target_units if needed
+        if reference_cube.units != target_units:
+            try:
+                reference_cube = reference_cube.copy()
+                reference_cube.convert_units(target_units)
+            except ValueError as err:
+                # Chain the original conversion error so the underlying
+                # cf-units failure is not lost from the traceback.
+                raise ValueError(
+                    f"Cannot convert cube with units {reference_cube.units} "
+                    f"to target units {target_units}"
+                ) from err
+
+        return (reference_cube, forecast_cube)
+
+    def _process_masked_data(
+        self,
+        reference_cube: Cube,
+        forecast_cube: Cube,
+    ) -> tuple[np.ndarray, Optional[np.ndarray]]:
+        """Apply quantile mapping while properly handling masked data.
+
+        Masked values are excluded from the calibration CDFs to avoid
+        contaminating the statistics. They are preserved in their original
+        (masked) state in the output.
+
+        Args:
+            reference_cube:
+                The reference cube (with units already converted).
+            forecast_cube:
+                The forecast cube to calibrate.
+
+        Returns:
+            Tuple of:
+                - corrected_data_flat: 1D array with corrected values.
+                - output_mask: The mask to apply, or None if data is not masked.
+        """
+        # Determine if either cube has masked data
+        forecast_is_masked = np.ma.is_masked(forecast_cube.data)
+        reference_is_masked = np.ma.is_masked(reference_cube.data)
+
+        if forecast_is_masked or reference_is_masked:
+            # Create combined mask using getmaskarray (returns False array if not masked)
+            combined_mask = np.ma.getmaskarray(forecast_cube.data) | np.ma.getmaskarray(
+                reference_cube.data
+            )
+
+            # Flatten and get valid (non-masked) indices
+            combined_mask_flat = combined_mask.flatten()
+            valid_mask = ~combined_mask_flat
+
+            # Extract underlying data arrays (ignoring masks temporarily)
+            # We need the full arrays to reconstruct later, but will only
+            # use valid_mask indices for quantile mapping calculations
+            reference_data_flat = np.ma.getdata(reference_cube.data).flatten()
+            forecast_data_flat = np.ma.getdata(forecast_cube.data).flatten()
+
+            # Extract ONLY valid (non-masked) values for CDF calculations
+            # Masked values are not included in these arrays
+            reference_valid = reference_data_flat[valid_mask]
+            forecast_valid = forecast_data_flat[valid_mask]
+
+            # Apply quantile mapping using only valid values
+            corrected_valid = self._map_quantiles(reference_valid, forecast_valid)
+
+            # Reconstruct full array with corrected values at valid positions
+            corrected_values_flat = forecast_data_flat.copy()
+            corrected_values_flat[valid_mask] = corrected_valid
+
+            output_mask = combined_mask
+        else:
+            # No masking needed
+            output_mask = None
+            corrected_values_flat = self._map_quantiles(
+                reference_cube.data.flatten(),
+                forecast_cube.data.flatten(),
+            )
+
+        return corrected_values_flat, output_mask
+
+    def _apply_preservation_threshold(
+        self, output_cube: Cube, forecast_cube: Cube
+    ) -> None:
+        """Preserve original values below preservation threshold.
+
+        Modifies output_cube.data in-place.
+
+        Args:
+            output_cube:
+                The cube with calibrated data to modify.
+            forecast_cube:
+                The original forecast cube with values to preserve.
+        """
+        if self.preservation_threshold is None:
+            return
+
+        mask_below_threshold = np.ma.less(
+            forecast_cube.data, self.preservation_threshold
+        )
+        # np.ma.where works for both masked and non-masked arrays
+        output_cube.data = np.ma.where(
+            mask_below_threshold, forecast_cube.data, output_cube.data
+        )
+
+    def _finalise_output_cube(
+        self,
+        corrected_values_flat: np.ndarray,
+        forecast_cube: Cube,
+        output_cube: Cube,
+        output_mask,
+    ) -> None:
+        """Make final adjustments to output cube metadata and data type.
+
+        Args:
+            corrected_values_flat:
+                1D array of calibrated values to reshape into the output.
+            forecast_cube:
+                The original forecast cube; provides the target shape and the
+                values preserved below the preservation threshold.
+            output_cube:
+                The cube to finalize; modified in-place.
+            output_mask:
+                Mask to reinstate on the output, or None if the data is not
+                masked.
+        """
+        # Reshape corrected data to match original shape and set data type to float32
+        if corrected_values_flat.dtype != np.float32:
+            corrected_values_flat = corrected_values_flat.astype(np.float32)
+
+        corrected_data_reshaped = np.reshape(corrected_values_flat, forecast_cube.shape)
+
+        # Reinstate original mask if applicable
+        if output_mask is not None:
+            output_cube.data = np.ma.masked_array(
+                corrected_data_reshaped, mask=output_mask
+            )
+        else:
+            output_cube.data = corrected_data_reshaped
+
+        # Preserve low values if threshold is set, modifying in-place
+        self._apply_preservation_threshold(output_cube, forecast_cube)
+
+    def process(
+        self,
+        reference_cube: Cube,
+        forecast_cube: Cube,
+    ) -> Cube:
+        """Adjust forecast values to match the statistical distribution of reference
+        data.
+
+        This calibration method corrects biases in forecast data by transforming its
+        values to follow the same distribution as a reference dataset.
+        Unlike grid-point methods that match values at each location, this approach
+        uses all data across the spatial domain to build the statistical
+        distributions.
+
+        This is particularly useful when forecasts have been smoothed and you want to
+        restore realistic variation in the values while preserving the spatial
+        patterns.
+
+        Uses the discrete (floor) method for quantile lookup, which maps each quantile
+        to the nearest available reference value, creating a step-function mapping.
+
+        Args:
+            reference_cube:
+                The reference data that define what the "correct" distribution
+                should look like.
+            forecast_cube:
+                The forecast data you want to correct (e.g. smoothed model output).
+
+        Returns:
+            Calibrated forecast cube with quantiles mapped to the reference
+            distribution.
+        """
+
+        # Ensure both cubes use the same units
+        reference_cube, forecast_cube = self._convert_reference_cube_to_forecast_units(
+            reference_cube, forecast_cube
+        )
+
+        # Create output cube to preserve metadata
+        output_cube = forecast_cube.copy()
+
+        # Apply quantile mapping (handles masked data automatically)
+        corrected_values_flat, output_mask = self._process_masked_data(
+            reference_cube, forecast_cube
+        )
+
+        self._finalise_output_cube(
+            corrected_values_flat, forecast_cube, output_cube, output_mask
+        )
+
+        return output_cube
diff --git a/improver/cli/quantile_mapping.py b/improver/cli/quantile_mapping.py
new file mode 100644
index 0000000000..8e6337f761
--- /dev/null
+++ b/improver/cli/quantile_mapping.py
@@ -0,0 +1,59 @@
+#!/usr/bin/env python
+# (C) Crown Copyright, Met Office. All rights reserved.
+#
+# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license.
+# See LICENSE in the root of the repository for full licensing details.
+"""CLI to apply quantile mapping"""
+
+from improver import cli
+
+
+@cli.clizefy
+@cli.with_output
+def process(
+    *cubes: cli.inputcube,
+    truth_attribute: str,
+    preservation_threshold: float = None,
+):
+    """Adjust forecast values to match the statistical distribution of reference
+    data.
+
+    Unlike grid-point methods that match values at each location, this approach uses
+    all data across the spatial domain to build the statistical distributions.
+    This is particularly useful when forecasts have been smoothed and you want to
+    restore realistic variation in the values while preserving the spatial patterns.
+
+    Args:
+        cubes:
+            A list of cubes containing the forecasts and corresponding truth
+            (reference) used for calibration. They must have the same cube name
+            and will be separated based on the truth attribute.
+        truth_attribute:
+            An attribute and its value in the format of "attribute=value",
+            which must be present on historical truth cubes.
+        preservation_threshold:
+            Optional threshold value below which (exclusive) the forecast values
+            are not adjusted. Useful for variables like precipitation where you
+            may want to preserve small/zero values.
+
+    Returns:
+        Calibrated forecast cube with quantiles mapped to the reference
+        distribution.
+
+    Raises:
+        ValueError: If reference and forecast cubes have incompatible units.
+    """
+    from improver.calibration import split_forecasts_and_truth
+    from improver.calibration.quantile_mapping import QuantileMapping
+
+    forecast_cube, reference_cube, _ = split_forecasts_and_truth(cubes, truth_attribute)
+    plugin = QuantileMapping(preservation_threshold=preservation_threshold)
+    return plugin.process(
+        reference_cube,
+        forecast_cube,
+    )
diff --git a/improver_tests/acceptance/SHA256SUMS b/improver_tests/acceptance/SHA256SUMS
index 992e92034a..efe4fb3c32 100644
--- a/improver_tests/acceptance/SHA256SUMS
+++ b/improver_tests/acceptance/SHA256SUMS
@@ -878,6 +878,10 @@ a89ba9668fd878ed5c5cc017e46a25ab1f9d205b1a6913457a8f3af770cc49e1 ./precipitatio
 f69103cececd76e27bbff5a96e9c74c0e708dcb7f18459ade3eb448639992b34 ./precipitation_duration/standard_names/kgo_acc_1.00_rate_4.0.nc
 39730b1c6f60d0ffc1a79629b29c84ee063e465f1110fd179338478277c69b03 ./precipitation_duration/standard_names/kgo_multi_threshold.nc
 6a6394f52409d218e7e8d87c95a71c1f844d904bc3cbef3421f03e8d3afe98ac ./precipitation_duration/standard_names/kgo_short_period.nc
+61d60afd98d8cafd2f010565e967258aaae928380b72aaebda6d4b2632d08f0b ./quantile-mapping/basic/kgo.nc
+31983e9237750163a0e717d323d36c54946be5fec1b6ba0e84a1a76c6aa26f25 ./quantile-mapping/forecast.nc
+5adad3bdb97e79d5a497a71c52998d641733871f86643e7703b0c7fa128e0f06 ./quantile-mapping/reference.nc
+30908114a1347e9bd321aac6321f99e08b0e4236f1e255ac4f1087afa10f3f6f ./quantile-mapping/with_preservation_threshold/kgo.nc
 ae048c636992e80b79c6cbb44b36339b30ea8d0ef1db72cd3f4de8766346fa1d ./recursive-filter/input.nc
 b6cdb8bf877bb0b3b78ad224b50b9272b65732bf9e39a88df704209e228bf4c0 ./recursive-filter/input_masked.nc
 11c428f6fb0202ab0f975e58e52d17342c50f607aee4fd0e387a2a62c188790e ./recursive-filter/input_variable_masked.nc
diff --git a/improver_tests/acceptance/test_quantile_mapping.py b/improver_tests/acceptance/test_quantile_mapping.py
new file mode 100644
index 0000000000..edfd9dff32
--- /dev/null
+++ b/improver_tests/acceptance/test_quantile_mapping.py
@@ -0,0 +1,55 @@
+# (C) Crown Copyright, Met Office. All rights reserved.
+#
+# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license.
+# See LICENSE in the root of the repository for full licensing details.
+"""Tests for the quantile-mapping CLI"""
+
+import pytest
+
+from . import acceptance as acc
+
+pytestmark = [pytest.mark.acc, acc.skip_if_kgo_missing]
+CLI = acc.cli_name_with_dashes(__file__)
+run_cli = acc.run_cli(CLI)
+
+
+def test_floor_no_threshold(tmp_path):
+    """Test quantile mapping with floor method and no preservation threshold."""
+    # Derive all inputs from one base directory to avoid repeated path literals
+    data_dir = acc.kgo_root() / "quantile-mapping"
+    kgo_path = data_dir / "basic" / "kgo.nc"
+    reference_path = data_dir / "reference.nc"
+    forecast_path = data_dir / "forecast.nc"
+    output_path = tmp_path / "output.nc"
+
+    args = [
+        reference_path,
+        forecast_path,
+        "--truth-attribute",
+        "mosg__model_configuration=uk_det",
+        "--output",
+        output_path,
+    ]
+    run_cli(args)
+    acc.compare(output_path, kgo_path)
+
+
+def test_floor_with_threshold(tmp_path):
+    """Test quantile mapping with floor method and preservation threshold."""
+    data_dir = acc.kgo_root() / "quantile-mapping"
+    kgo_path = data_dir / "with_preservation_threshold" / "kgo.nc"
+    reference_path = data_dir / "reference.nc"
+    forecast_path = data_dir / "forecast.nc"
+    output_path = tmp_path / "output.nc"
+
+    args = [
+        reference_path,
+        forecast_path,
+        "--preservation-threshold",
+        "2.0",
+        "--truth-attribute",
+        "mosg__model_configuration=uk_det",
+        "--output",
+        output_path,
+    ]
+    run_cli(args)
+    acc.compare(output_path, kgo_path)
diff --git a/improver_tests/calibration/quantile_mapping/__init__.py b/improver_tests/calibration/quantile_mapping/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/improver_tests/calibration/quantile_mapping/test_QuantileMapping.py
b/improver_tests/calibration/quantile_mapping/test_QuantileMapping.py
new file mode 100644
index 0000000000..0b7a8c9e3b
--- /dev/null
+++ b/improver_tests/calibration/quantile_mapping/test_QuantileMapping.py
@@ -0,0 +1,246 @@
+# (C) Crown Copyright, Met Office. All rights reserved.
+#
+# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license.
+# See LICENSE in the root of the repository for full licensing details.
+
+import numpy as np
+import pytest
+from iris.cube import Cube
+
+from improver.calibration.quantile_mapping import QuantileMapping
+from improver.synthetic_data.set_up_test_cubes import set_up_variable_cube
+
+
+@pytest.fixture
+def simple_reference_array():
+    """Fixture for creating a simple reference array."""
+    return np.array([10, 20, 30, 40, 50])
+
+
+@pytest.fixture
+def simple_forecast_array():
+    """Fixture for creating a simple forecast array"""
+    return np.array([5, 15, 25, 35, 45])
+
+
+def test__build_empirical_cdf(simple_reference_array):
+    """Test _build_empirical_cdf returns the correct empirical CDF."""
+    sorted_values, quantiles = QuantileMapping()._build_empirical_cdf(
+        simple_reference_array
+    )
+
+    np.testing.assert_array_equal(sorted_values, np.array([10, 20, 30, 40, 50]))
+    np.testing.assert_array_equal(quantiles, np.array([0.2, 0.4, 0.6, 0.8, 1.0]))
+
+
+def test__inverted_cdf(simple_reference_array):
+    """Test _inverted_cdf returns the correct values. Values output should be the
+    same as values input in this case."""
+    _, quantiles = QuantileMapping()._build_empirical_cdf(simple_reference_array)
+    result = QuantileMapping()._inverted_cdf(simple_reference_array, quantiles)
+    np.testing.assert_array_equal(result, np.array([10, 20, 30, 40, 50]))
+
+
+def test__map_quantiles(
+    simple_reference_array,
+    simple_forecast_array,
+):
+    """Test _map_quantiles maps forecast values onto the reference distribution."""
+    expected = np.array([10, 20, 30, 40, 50])
+    result = QuantileMapping()._map_quantiles(
+        simple_reference_array,
+        simple_forecast_array,
+    )
+    np.testing.assert_array_equal(result, expected)
+
+
+@pytest.fixture
+def reference_cube():
+    """Fixture for creating a reference precipitation rate (mm/h) cube."""
+    data = np.array(
+        [
+            [
+                [1.0, 2.0, 3.0],
+                [4.0, 5.0, 6.0],
+                [7.0, 8.0, 9.0],
+            ],
+            [
+                [0.7, 1.8, 2.8],
+                [3.8, 4.9, 5.8],
+                [6.8, 7.7, 8.7],
+            ],
+        ],
+        dtype=np.float32,
+    )
+
+    return set_up_variable_cube(data, name="lwe_precipitation_rate", units="mm h-1")
+
+
+@pytest.fixture
+def forecast_cube():
+    """Fixture for creating a forecast precipitation rate (mm/h) cube."""
+    data = np.array(
+        [
+            [
+                [0.6, 1.7, 2.7],
+                [3.7, 4.8, 5.7],
+                [6.7, 7.6, 8.6],
+            ],
+            [
+                [0.5, 1.6, 2.6],
+                [3.6, 4.7, 5.6],
+                [6.6, 7.5, 8.5],
+            ],
+        ],
+        dtype=np.float32,
+    )
+    return set_up_variable_cube(data, name="lwe_precipitation_rate", units="mm h-1")
+
+
+@pytest.fixture
+def expected_result_no_threshold():
+    """Expected result for quantile mapping without a preservation threshold."""
+    return np.array(
+        [
+            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
+            [[0.7, 1.8, 2.8], [3.8, 4.9, 5.8], [6.8, 7.7, 8.7]],
+        ],
+        dtype=np.float32,
+    )
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "same_units",
+        "different_units",
+        "incompatible_units",
+    ],
+)
+def test__convert_reference_cube_to_forecast(
+    reference_cube,
+    forecast_cube,
+    test_case,
+):
+    """Test handling of cubes with same, different, and incompatible units."""
+    plugin = QuantileMapping()
+
+    if test_case == "same_units":
+        # Both cubes already in mm h-1, should work normally
+        result = plugin.process(reference_cube, forecast_cube)
+        assert result.units == forecast_cube.units
+
+    elif test_case == "different_units":
+        # Convert forecast to different (but compatible) units
+        forecast_cube_copy = forecast_cube.copy()
+        forecast_cube_copy.convert_units("m s-1")
+        result = plugin.process(reference_cube, forecast_cube_copy)
+        # Result should be in forecast units (m s-1)
+        assert result.units == forecast_cube_copy.units
+
+    elif test_case == "incompatible_units":
+        # Set incompatible units and expect error
+        forecast_cube_copy = forecast_cube.copy()
+        forecast_cube_copy.units = "Celsius"
+        with pytest.raises(ValueError, match="Cannot convert cube with units"):
+            plugin.process(reference_cube, forecast_cube_copy)
+
+
+def test_quantile_mapping_process_no_threshold(
+    reference_cube, forecast_cube, expected_result_no_threshold
+):
+    """Test quantile mapping with no preservation threshold."""
+    plugin = QuantileMapping()
+    result = plugin.process(reference_cube, forecast_cube)
+
+    assert isinstance(result, Cube)
+    assert result.shape == forecast_cube.shape
+    assert result.data.dtype == np.float32
+    assert not np.ma.is_masked(result.data)
+    np.testing.assert_array_equal(result.data, expected_result_no_threshold)
+
+
+def test_quantile_mapping_process_with_threshold(reference_cube, forecast_cube):
+    """Test quantile mapping with preservation threshold.
+    Index [1,0,0] should remain 0.5, despite the reference normally transforming
+    it to the reference value of 0.7.
+    """
+    plugin = QuantileMapping(preservation_threshold=0.51)
+    result = plugin.process(reference_cube, forecast_cube)
+
+    expected_result = np.array(
+        [
+            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
+            [[0.5, 1.8, 2.8], [3.8, 4.9, 5.8], [6.8, 7.7, 8.7]],
+        ],
+        dtype=np.float32,
+    )
+
+    assert isinstance(result, Cube)
+    assert result.shape == forecast_cube.shape
+    assert result.data.dtype == np.float32
+    # np.ma.where produces a masked array even for unmasked inputs; check the
+    # type explicitly ("mask is not False" is always true for np.ma arrays,
+    # whose empty mask is np.bool_(False), never the Python False singleton)
+    assert isinstance(result.data, np.ma.MaskedArray)
+    np.testing.assert_array_equal(result.data, expected_result)
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "one_input_masked",
+        "both_inputs_masked",
+    ],
+)
+def test_masked_input(reference_cube, forecast_cube, test_case):
+    """Test behaviour when one or both inputs have masked values.
+    In both cases, the mask should be a union of cube masks."""
+
+    # Make copies to avoid fixture mutation
+    reference_cube = reference_cube.copy()
+    forecast_cube = forecast_cube.copy()
+
+    # Mask reference at position [0, 0, 0]
+    reference_cube.data = np.ma.masked_array(
+        reference_cube.data, mask=np.zeros_like(reference_cube.data, dtype=bool)
+    )
+    reference_cube.data[0, 0, 0] = np.ma.masked
+
+    if test_case == "one_input_masked":
+        expected_mask_count = 1
+
+    elif test_case == "both_inputs_masked":
+        # Also mask forecast at position [0, 0, 1]
+        forecast_cube.data = np.ma.masked_array(
+            forecast_cube.data, mask=np.zeros_like(forecast_cube.data, dtype=bool)
+        )
+        forecast_cube.data[0, 0, 1] = np.ma.masked
+        expected_mask_count = 2
+
+    plugin = QuantileMapping()
+    result = plugin.process(reference_cube, forecast_cube)
+
+    # Check that result is masked
+    assert np.ma.is_masked(result.data)
+    # Check mask count matches expected (union of input masks)
+    assert expected_mask_count == np.ma.count_masked(result.data)
+    # Check that the correct positions are masked
+    if test_case == "one_input_masked":
+        assert result.data.mask[0, 0, 0]
+        assert not result.data.mask[0, 0, 1]
+    elif test_case == "both_inputs_masked":
+        assert result.data.mask[0, 0, 0]
+        assert result.data.mask[0, 0, 1]
+
+
+def test_metadata_preservation(reference_cube, forecast_cube):
+    """Test that metadata from forecast cube is preserved."""
+    plugin = QuantileMapping()
+    reference_cube.long_name = "kittens"
+    result = plugin.process(reference_cube, forecast_cube)
+
+    # Check key metadata is preserved
+    assert result.long_name == forecast_cube.long_name