diff --git a/improver/cli/threshold.py b/improver/cli/threshold.py index d1ad30a807..eaf6eeb7bf 100755 --- a/improver/cli/threshold.py +++ b/improver/cli/threshold.py @@ -56,7 +56,7 @@ def process( or with structure "THRESHOLD_VALUE": "None" (no fuzzy bounds). Repeated thresholds with different bounds are ignored; only the last duplicate will be used. - threshold_values and and threshold_config are mutually exclusive + threshold_values and threshold_config are mutually exclusive arguments, defining both will lead to an exception. threshold_units (str): Units of the threshold values. If not provided the units are diff --git a/improver/cli/threshold_interpolation.py b/improver/cli/threshold_interpolation.py index 5d012feb04..9711f6dc95 100755 --- a/improver/cli/threshold_interpolation.py +++ b/improver/cli/threshold_interpolation.py @@ -13,7 +13,9 @@ def process( forecast_at_thresholds: cli.inputcube, *, - thresholds: cli.comma_separated_list, + threshold_values: cli.comma_separated_list = None, + threshold_config: cli.inputjson = None, + threshold_units: str = None, ): """ Use this CLI to modify the probability thresholds in an existing probability @@ -22,8 +24,22 @@ def process( Args: forecast_at_thresholds: Cube expected to contain a threshold coordinate. - thresholds: - List of the desired output thresholds. + threshold_values: + The desired output thresholds, either as a list of float values or a + single float value. + threshold_config: + Threshold configuration containing threshold values. It should contain + either a list of float values or a dictionary of strings that can be + interpreted as floats with the structure: "THRESHOLD_VALUE": "None". + The latter follows the format expected in improver/cli/threshold.py, + however fuzzy bounds will be ignored here. + Repeated thresholds with different bounds are ignored; only the + last duplicate will be used. + Threshold_values and threshold_config are mutually exclusive + arguments, defining both will lead to an exception. + threshold_units: + Units of the threshold values. If not provided the units are + assumed to be the same as those of the input cube. Returns: Cube with forecast values at the desired set of thresholds. @@ -31,6 +47,8 @@ def process( """ from improver.utilities.threshold_interpolation import ThresholdInterpolation - result = ThresholdInterpolation(thresholds)(forecast_at_thresholds) + result = ThresholdInterpolation( + threshold_values, threshold_config, threshold_units + )(forecast_at_thresholds) return result diff --git a/improver/utilities/threshold_interpolation.py b/improver/utilities/threshold_interpolation.py index 3ffc6a1f34..06372fa8cd 100644 --- a/improver/utilities/threshold_interpolation.py +++ b/improver/utilities/threshold_interpolation.py @@ -4,10 +4,12 @@ # See LICENSE in the root of the repository for full licensing details. """Script to linearly interpolate thresholds""" -from typing import List, Optional +import numbers +from typing import Dict, List, Optional, Union import iris import numpy as np +from cf_units import Unit from iris.cube import Cube from numpy import ndarray @@ -27,24 +29,93 @@ class ThresholdInterpolation(PostProcessingPlugin): - def __init__(self, thresholds: List[float]): + def __init__( + self, + threshold_values: Optional[Union[List[float], float]] = None, + threshold_config: Optional[Union[List[float], Dict[str, str]]] = None, + threshold_units: Optional[str] = None, + ): """ Args: - thresholds: - List of the desired output thresholds. 
+ threshold_values: + The desired output thresholds, either as a list of float values or a + single float value. + threshold_config: + Threshold configuration containing threshold values. It should contain + either a list of float values or a dictionary of strings that can be + interpreted as floats with the structure: "THRESHOLD_VALUE": "None". + Repeated thresholds with different bounds are ignored; only the + last duplicate will be used. + Threshold_values and threshold_config are mutually exclusive + arguments, defining both will lead to an exception. + threshold_units: + Units of the threshold values. If not provided the units are + assumed to be the same as those of the input cube. Raises: - ValueError: - If the thresholds list is empty. + ValueError: If threshold_config and threshold_values are both set + ValueError: If neither threshold_config or threshold_values are set """ - if not thresholds: - raise ValueError("The thresholds list cannot be empty.") - self.thresholds = thresholds + if threshold_config and threshold_values: + raise ValueError( + "Threshold_config and threshold_values are mutually exclusive " + "arguments - please provide one or the other, not both" + ) + if threshold_config is None and threshold_values is None: + raise ValueError( + "One of threshold_config or threshold_values must be provided." + ) + self.threshold_values = threshold_values self.threshold_coord = None + self.threshold_config = threshold_config + + thresholds = self._set_thresholds(threshold_values, threshold_config) + self.thresholds = [thresholds] if np.isscalar(thresholds) else thresholds + self.threshold_units = ( + None if threshold_units is None else Unit(threshold_units) + ) + + self.original_units = None + + @staticmethod + def _set_thresholds( + threshold_values: Optional[Union[List[float], float]] = None, + threshold_config: Optional[Union[List[float], Dict[str, str]]] = None, + ) -> List[float]: + """ + Interprets a threshold_config dictionary if provided, or ensures that + a list of thresholds has suitable precision. + + Args: + threshold_values: + A list of threshold values or a single threshold value. + threshold_config: + Either a list of float values or a dictionary of strings that can be + interpreted as floats with the structure: "THRESHOLD_VALUE": "None". + + Returns: + thresholds: + A list of input thresholds as float64 type. + """ + if threshold_config and isinstance(threshold_config, dict): + thresholds = [] + for key in threshold_config.keys(): + # Ensure thresholds are float64 to avoid rounding errors during + # possible unit conversion. + thresholds.append(float(key)) + elif threshold_config and isinstance(threshold_config, list): + thresholds = [float(x) for x in threshold_config] + else: + # Convert threshold_values to a list if it is a single value. + if isinstance(threshold_values, numbers.Number): + threshold_values = [threshold_values] + thresholds = [float(x) for x in threshold_values] + return thresholds def mask_checking(self, forecast_at_thresholds: Cube) -> Optional[np.ndarray]: """ - Check if the mask is consistent across different slices of the threshold coordinate. + Check if the mask is consistent across different slices of the threshold + coordinate. Args: forecast_at_thresholds: @@ -52,19 +123,21 @@ def mask_checking(self, forecast_at_thresholds: Cube) -> Optional[np.ndarray]: Returns: original_mask: - The original mask if the data is masked and the mask is consistent across - different slices of the threshold coordinate, otherwise None. 
+ The original mask if the data is masked and the mask is consistent + across different slices of the threshold coordinate, otherwise None. Raises: - ValueError: If the mask varies across different slices of the threshold coordinate. + ValueError: If the mask varies across different slices of the threshold + coordinate. """ original_mask = None if np.ma.is_masked(forecast_at_thresholds.data): (crd_dim,) = forecast_at_thresholds.coord_dims(self.threshold_coord.name()) if np.diff(forecast_at_thresholds.data.mask, axis=crd_dim).any(): raise ValueError( - f"The mask is expected to be constant across different slices of the {self.threshold_coord.name()}" - f" dimension, however, in the dataset provided, the mask varies across the {self.threshold_coord.name()}" + "The mask is expected to be constant across different slices of" + f"the {self.threshold_coord.name()} dimension, however, in the dataset" + f"provided, the mask varies across the {self.threshold_coord.name()}" f" dimension. This is not currently supported." ) else: @@ -149,6 +222,9 @@ def create_cube_with_thresholds( ) template_cube.remove_coord(self.threshold_coord) + if self.threshold_units is not None: + template_cube.units = self.threshold_units + # create cube with new threshold dimension cubes = iris.cube.CubeList([]) for point in self.thresholds: @@ -175,11 +251,16 @@ def process( This method performs the following steps: 1. Identifies the threshold coordinate in the input cube. - 2. Checks if the mask is consistent across different slices of the threshold coordinate. - 3. Collapses the realizations if present. - 4. Interpolates the forecast data to the new set of thresholds. - 5. Creates a new cube with the interpolated threshold data. - 6. Applies the original mask to the new cube if it exists. + 2. Checks if the mask is consistent across different slices of the threshold + coordinate. + 3. Converts the threshold coordinate to the specified units if provided. + 4. Collapses the realizations if present. + 5. Interpolates the data to the new set of thresholds. + 6. Creates a new cube with the interpolated threshold data. + 7. Applies the original mask to the new cube if it exists. + 8. Converts the threshold coordinate units back to the original units. + 9. Converts the interpolated cube data units to the original units, in case + these have been changed by the processing. Args: forecast_at_thresholds: @@ -191,9 +272,17 @@ def process( The threshold coordinate is always the zeroth dimension. 
""" self.threshold_coord = find_threshold_coordinate(forecast_at_thresholds) + self.threshold_coord_name = self.threshold_coord.name() + self.original_units = forecast_at_thresholds.units + self.original_threshold_units = self.threshold_coord.units original_mask = self.mask_checking(forecast_at_thresholds) + if self.threshold_units is not None: + forecast_at_thresholds.coord(self.threshold_coord_name).convert_units( + self.threshold_units + ) + if forecast_at_thresholds.coords("realization"): forecast_at_thresholds = collapse_realizations(forecast_at_thresholds) @@ -204,10 +293,19 @@ def process( forecast_at_thresholds, forecast_at_thresholds_data, ) + if original_mask is not None: original_mask = np.broadcast_to(original_mask, threshold_cube.shape) threshold_cube.data = np.ma.MaskedArray( threshold_cube.data, mask=original_mask ) + # Revert the threshold coordinate's units + threshold_cube.coord(self.threshold_coord_name).convert_units( + self.original_threshold_units + ) + + # Ensure the cube's overall units are restored + threshold_cube.units = self.original_units + return threshold_cube diff --git a/improver_tests/acceptance/SHA256SUMS b/improver_tests/acceptance/SHA256SUMS index 2f6b32572d..ffbee924f3 100644 --- a/improver_tests/acceptance/SHA256SUMS +++ b/improver_tests/acceptance/SHA256SUMS @@ -936,12 +936,14 @@ eb6f7c3f646c4c51a0964b9a19367f43d6e3762ff5523b982cfaf7bf2610f091 ./temporal-int e3b8f51a0be52c4fead55f95c0e3da29ee3d93f92deed26314e60ad43e8fd5ef ./temporal-interpolate/uv/20181220T1200Z-PT0024H00M-uv_index.nc b3fde693b3a8e144cb8f9ee9ff23c51ef92701858667cff850b2a49986bacaab ./temporal-interpolate/uv/kgo_t1.nc 1065ae1f25e6bc6df8d02e61c1f8ef92ab3dae679595d5165bd94d9c740adb2c ./temporal-interpolate/uv/kgo_t1_daynight.nc -3335761a3c15c0fd4336cb852970376abd6f6dac99907fe9b081e6a7672e530c ./threshold-interpolation/extra_thresholds_kgo.nc -022657626d7ae4608781c390ca9c30d9cbb949d71bedf74a2687228f5964b3e9 ./threshold-interpolation/input.nc +29eb38825b3a5e20b73128d3d2e65b5fc7e9a7670bc47bee0b45811a1139fd9c ./threshold-interpolation/extra_thresholds_kgo.nc +8e9c724ebc9b275777f15f92b16451cebe0a272a2caad0ef386b64d199e299fc ./threshold-interpolation/input.nc 12acca08e123437e07ad4e3aab81cc2fc0a3cfb72b5cb2fd06343bd5beb13f00 ./threshold-interpolation/input_realization.nc 7b172ce0d98c0f7fbfea1cde23a126d7116871bb62a221348c7ddddc35c29a0a ./threshold-interpolation/masked_cube_kgo.nc ec73679ff5e308a2bb4d21283262118f8d9fbb6a425309b76d5865a97a773c40 ./threshold-interpolation/masked_input.nc 6058009963941b539117ea44792277253d87c7a1c81318e4836406b5c0b88525 ./threshold-interpolation/realization_collapse_kgo.nc +a8663a8344931ad655645facb59bf73cbd4d83b687f441c988bf64af32fea9ed ./threshold-interpolation/threshold_config_dict.json +837187b7282b1a1ea9ab15ed1f9ff2b21ccf340dda6c9d10b3c127fc818f4815 ./threshold-interpolation/threshold_config_list.json ac93ed67c9947547e5879af6faaa329fede18afd822c720ac3afcb18fa41077a ./threshold/basic/input.nc eb3fdc9400401ec47d95961553aed452abcbd91891d0fbca106b3a05131adaa9 ./threshold/basic/kgo.nc 6b50fa16b663869b3e3fbff36197603886ff7383b2df2a8ba92579bcc9461a16 ./threshold/below_threshold/kgo.nc diff --git a/improver_tests/acceptance/test_threshold_interpolation.py b/improver_tests/acceptance/test_threshold_interpolation.py index 704dae9219..d934647276 100644 --- a/improver_tests/acceptance/test_threshold_interpolation.py +++ b/improver_tests/acceptance/test_threshold_interpolation.py @@ -15,12 +15,18 @@ def test_basic(tmp_path): """Test basic invocation with threshold 
argument""" - thresholds = "50.0,200.0,400.0,600.0,1000.0,2000.0,10000.0,25000.0,40000.0" + threshold_values = "50.0,200.0,400.0,600.0,1000.0,2000.0,10000.0,25000.0,40000.0" kgo_dir = acc.kgo_root() / "threshold-interpolation" kgo_path = kgo_dir / "extra_thresholds_kgo.nc" input_path = kgo_dir / "input.nc" output_path = tmp_path / "output.nc" - args = [input_path, "--thresholds", thresholds, "--output", f"{output_path}"] + args = [ + input_path, + "--threshold-values", + threshold_values, + "--output", + f"{output_path}", + ] run_cli(args) acc.compare(output_path, kgo_path) @@ -28,12 +34,18 @@ def test_basic(tmp_path): def test_realization_collapse(tmp_path): """Test realization coordinate is collapsed""" - thresholds = "50.0,200.0,400.0,600.0,1000.0,2000.0,10000.0,25000.0,40000.0" + threshold_values = "50.0,200.0,400.0,600.0,1000.0,2000.0,10000.0,25000.0,40000.0" kgo_dir = acc.kgo_root() / "threshold-interpolation" kgo_path = kgo_dir / "realization_collapse_kgo.nc" input_path = kgo_dir / "input_realization.nc" output_path = tmp_path / "output.nc" - args = [input_path, "--thresholds", thresholds, "--output", f"{output_path}"] + args = [ + input_path, + "--threshold-values", + threshold_values, + "--output", + f"{output_path}", + ] run_cli(args) acc.compare(output_path, kgo_path) @@ -41,12 +53,54 @@ def test_realization_collapse(tmp_path): def test_masked_cube(tmp_path): """Test masked cube""" - thresholds = "50.0,200.0,400.0,600.0,1000.0,2000.0,10000.0,25000.0,40000.0" + threshold_values = "50.0,200.0,400.0,600.0,1000.0,2000.0,10000.0,25000.0,40000.0" kgo_dir = acc.kgo_root() / "threshold-interpolation" kgo_path = kgo_dir / "masked_cube_kgo.nc" input_path = kgo_dir / "masked_input.nc" output_path = tmp_path / "output.nc" - args = [input_path, "--thresholds", thresholds, "--output", f"{output_path}"] + args = [ + input_path, + "--threshold-values", + threshold_values, + "--output", + f"{output_path}", + ] run_cli(args) acc.compare(output_path, kgo_path) + + +def test_json_dict_input(tmp_path): + """Test JSON dictionary input""" + kgo_dir = acc.kgo_root() / "threshold-interpolation" + threshold_config = kgo_dir / "threshold_config_dict.json" + kgo_path = kgo_dir / "realization_collapse_kgo.nc" + input_path = kgo_dir / "input_realization.nc" + output_path = tmp_path / "output.nc" + args = [ + input_path, + "--threshold-config", + threshold_config, + "--output", + f"{output_path}", + ] + run_cli(args) + acc.compare(output_path, kgo_path) + + +def test_json_list_input(tmp_path): + """Test JSON list input""" + kgo_dir = acc.kgo_root() / "threshold-interpolation" + threshold_config = kgo_dir / "threshold_config_list.json" + kgo_path = kgo_dir / "realization_collapse_kgo.nc" + input_path = kgo_dir / "input_realization.nc" + output_path = tmp_path / "output.nc" + args = [ + input_path, + "--threshold-config", + threshold_config, + "--output", + f"{output_path}", + ] + run_cli(args) + acc.compare(output_path, kgo_path) diff --git a/improver_tests/utilities/test_ThresholdInterpolation.py b/improver_tests/utilities/test_ThresholdInterpolation.py index a4edfef0bc..dd90604d84 100644 --- a/improver_tests/utilities/test_ThresholdInterpolation.py +++ b/improver_tests/utilities/test_ThresholdInterpolation.py @@ -62,26 +62,61 @@ def masked_cube() -> Cube: return masked_cube +@pytest.fixture +def threshold_values(): + """Set up a list of threshold values.""" + return [100, 150, 200, 250, 300] + + +@pytest.fixture +def threshold_config_dict(): + """Set up a threshold_config dictionary.""" + return { + "100.0": 
"None", + "150.0": "None", + "200.0": "None", + "250.0": "None", + "300.0": "None", + } + + +@pytest.fixture +def threshold_config_list(): + """Set up a threshold_config list.""" + return [100, 150, 200, 250, 300] + + @pytest.mark.parametrize("input", ["input_cube", "masked_cube"]) -def test_cube_returned(request, input): +def test_cube_returned(request, input, threshold_values): """ - Test that the plugin returns an Iris.cube.Cube with suitable units. + Test that the plugin returns an Iris.cube.Cube with suitable units and thresholds. """ cube = request.getfixturevalue(input) - thresholds = [100, 150, 200, 250, 300] - result = ThresholdInterpolation(thresholds)(cube) + result = ThresholdInterpolation(threshold_values)(cube) assert isinstance(result, Cube) assert result.units == cube.units + np.testing.assert_array_equal( + result.coord("visibility_in_air").points, threshold_values + ) + + +def test_nthreshold_units(input_cube, threshold_values): + """ + Test that the plugin can handle different threshold units. + """ + original_units = input_cube.coord("visibility_in_air").units + result = ThresholdInterpolation(threshold_values, threshold_units="km")(input_cube) + # Check that units are converted back to the original input cube's units + assert result.coord("visibility_in_air").units == original_units @pytest.mark.parametrize("input", ["input_cube", "masked_cube"]) -def test_interpolated_values(request, input): +def test_interpolated_values(request, input, threshold_values): """ Test that the interpolated values are as expected. """ cube = request.getfixturevalue(input) - thresholds = [100, 150, 200, 250, 300] - result = ThresholdInterpolation(thresholds)(cube) + result = ThresholdInterpolation(threshold_values)(cube) expected_interpolated_values = np.array( [ [[1.0, 0.9, 1.0], [0.8, 0.9, 0.5], [0.5, 0.2, 0.0]], @@ -95,48 +130,98 @@ def test_interpolated_values(request, input): np.testing.assert_array_equal(result.data, expected_interpolated_values) -def test_empty_threshold_list(): +@pytest.mark.parametrize( + "input, threshold_config", + [("input_cube", "threshold_config_dict"), ("masked_cube", "threshold_config_list")], +) +def test_threshold_config_provided(request, input, threshold_config, threshold_values): """ - Test that a ValueError is raised if the threshold list is empty. + Test that the plugin can handle threshold_config (and so JSON files) being provided + as a list of float values or a dictionary of strings that can be interpreted as + floats with the structure: "THRESHOLD_VALUE": "None". 
""" - with pytest.raises(ValueError, match="The thresholds list cannot be empty."): - ThresholdInterpolation([]) + cube = request.getfixturevalue(input) + threshold_config = request.getfixturevalue(threshold_config) + result = ThresholdInterpolation(threshold_config=threshold_config)(cube) + expected_interpolated_values = np.array( + [ + [[1.0, 0.9, 1.0], [0.8, 0.9, 0.5], [0.5, 0.2, 0.0]], + [[1.0, 0.7, 1.0], [0.65, 0.7, 0.4], [0.35, 0.1, 0.0]], + [[1.0, 0.5, 1.0], [0.5, 0.5, 0.3], [0.2, 0.0, 0.0]], + [[1.0, 0.35, 0.75], [0.35, 0.25, 0.2], [0.1, 0.0, 0.0]], + [[1.0, 0.2, 0.5], [0.2, 0.0, 0.1], [0.0, 0.0, 0.0]], + ], + dtype=np.float32, + ) + np.testing.assert_array_equal(result.data, expected_interpolated_values) + assert ( + result.coord("visibility_in_air").units == cube.coord("visibility_in_air").units + ) + np.testing.assert_array_equal( + result.coord("visibility_in_air").points, threshold_values + ) -def test_metadata_copy(input_cube): +def test_single_threshold_value_provided(input_cube): + """ + Test that the plugin can handle a single numeric input for threshold_values. + """ + result = ThresholdInterpolation(threshold_values=250)(input_cube) + expected_interpolated_values = np.array( + [ + [1.0, 0.35, 0.75], + [0.35, 0.25, 0.2], + [0.1, 0.0, 0.0], + ], + dtype=np.float32, + ) + np.testing.assert_array_equal(result.data, expected_interpolated_values) + + +def test_no_thresholds_provided(): + """ + Test that a ValueError is raised if neither threshold_config or threshold_values + are set. + """ + with pytest.raises( + ValueError, + match="One of threshold_config or threshold_values must be provided.", + ): + ThresholdInterpolation() + + +def test_metadata_copy(input_cube, threshold_values): """ Test that the metadata dictionaries within the input cube are also present on the output cube. """ input_cube.attributes = {"source": "ukv"} - thresholds = [100, 150, 200, 250, 300] - result = ThresholdInterpolation(thresholds)(input_cube) + result = ThresholdInterpolation(threshold_values)(input_cube) assert input_cube.metadata._asdict() == result.metadata._asdict() -def test_thresholds_different_mask(masked_cube): +def test_thresholds_different_mask(masked_cube, threshold_values): """ - Testing that a value error message is raised if masks are different across thresholds. + Testing that a value error message is raised if masks are different across + thresholds. """ masked_cube.data.mask[0, 0, 0] = True - thresholds = [100, 150, 200, 250, 300] - error_msg = "The mask is expected to be constant across different slices of the" + error_msg = "The mask is expected to be constant across different slices of" with pytest.raises(ValueError, match=error_msg): - ThresholdInterpolation(thresholds)(masked_cube) + ThresholdInterpolation(threshold_values)(masked_cube) -def test_mask_consistency(masked_cube): +def test_mask_consistency(masked_cube, threshold_values): """ Test that the mask is the same before and after ThresholdInterpolation. """ - thresholds = [100, 150, 200, 250, 300] original_mask = masked_cube.data.mask.copy() - result = ThresholdInterpolation(thresholds)(masked_cube).data.mask + result = ThresholdInterpolation(threshold_values)(masked_cube).data.mask np.testing.assert_array_equal(original_mask[0], result[0]) -def test_collapse_realizations(input_cube): +def test_collapse_realizations(input_cube, threshold_values): """ Test that the realizations are collapsed if present in the input cube. 
""" @@ -154,6 +239,23 @@ def test_collapse_realizations(input_cube): cube = cubes.merge_cube() - thresholds = [100, 150, 200, 250, 300] - result = ThresholdInterpolation(thresholds)(cube.copy())[0::2] + result = ThresholdInterpolation(threshold_values)(cube.copy())[0::2] np.testing.assert_array_equal(result.data, 0.5 * (cube[0].data + cube[1].data)) + + +def test_error_set_thresholds_with_config_and_values( + input_cube, threshold_config_dict, threshold_values +): + """ + Test that a ValueError is raised if both threshold_values and threshold_config + are provided. + """ + with pytest.raises( + ValueError, + match="Threshold_config and threshold_values are mutually " + "exclusive arguments - please provide one or the other, " + "not both", + ): + ThresholdInterpolation( + threshold_values, threshold_config=threshold_config_dict + )(input_cube)