Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion improver/cli/threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def process(
or with structure "THRESHOLD_VALUE": "None" (no fuzzy bounds).
Repeated thresholds with different bounds are ignored; only the
last duplicate will be used.
threshold_values and and threshold_config are mutually exclusive
threshold_values and threshold_config are mutually exclusive
arguments, defining both will lead to an exception.
threshold_units (str):
Units of the threshold values. If not provided the units are
Expand Down
26 changes: 22 additions & 4 deletions improver/cli/threshold_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
def process(
forecast_at_thresholds: cli.inputcube,
*,
thresholds: cli.comma_separated_list,
threshold_values: cli.comma_separated_list = None,
threshold_config: cli.inputjson = None,
threshold_units: str = None,
):
"""
Use this CLI to modify the probability thresholds in an existing probability
Expand All @@ -22,15 +24,31 @@ def process(
Args:
forecast_at_thresholds:
Cube expected to contain a threshold coordinate.
thresholds:
List of the desired output thresholds.
threshold_values:
The desired output thresholds, either as a list of float values or a
single float value.
threshold_config:
Threshold configuration containing threshold values. It should contain
either a list of float values or a dictionary of strings that can be
interpreted as floats with the structure: "THRESHOLD_VALUE": "None".
The latter follows the format expected in improver/cli/threshold.py,
however fuzzy bounds will be ignored here.
Repeated thresholds with different bounds are ignored; only the
last duplicate will be used.
threshold_values and threshold_config are mutually exclusive
arguments, defining both will lead to an exception.
threshold_units:
Units of the threshold values. If not provided the units are
assumed to be the same as those of the input cube.

Returns:
Cube with forecast values at the desired set of thresholds.
The threshold coordinate is always the zeroth dimension.
"""
from improver.utilities.threshold_interpolation import ThresholdInterpolation

result = ThresholdInterpolation(thresholds)(forecast_at_thresholds)
result = ThresholdInterpolation(
threshold_values, threshold_config, threshold_units
)(forecast_at_thresholds)

return result
138 changes: 118 additions & 20 deletions improver/utilities/threshold_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
# See LICENSE in the root of the repository for full licensing details.
"""Script to linearly interpolate thresholds"""

from typing import List, Optional
import numbers
from typing import Dict, List, Optional, Union

import iris
import numpy as np
from cf_units import Unit
from iris.cube import Cube
from numpy import ndarray

Expand All @@ -27,44 +29,115 @@


class ThresholdInterpolation(PostProcessingPlugin):
def __init__(
    self,
    threshold_values: Optional[Union[List[float], float]] = None,
    # Dictionary values are typed as plain strings ("None") on purpose: unlike
    # improver/cli/threshold.py, fuzzy bounds are not used for threshold
    # interpolation, so only the dictionary keys carry information.
    threshold_config: Optional[Union[List[float], Dict[str, str]]] = None,
    threshold_units: Optional[str] = None,
):
    """
    Args:
        threshold_values:
            The desired output thresholds, either as a list of float values or a
            single float value.
        threshold_config:
            Threshold configuration containing threshold values. It should contain
            either a list of float values or a dictionary of strings that can be
            interpreted as floats with the structure: "THRESHOLD_VALUE": "None".
            Only the dictionary keys are read; fuzzy bounds are not used for
            threshold interpolation.
            threshold_values and threshold_config are mutually exclusive
            arguments, defining both will lead to an exception.
        threshold_units:
            Units of the threshold values. If not provided the units are
            assumed to be the same as those of the input cube.

    Raises:
        ValueError: If threshold_config and threshold_values are both set.
        ValueError: If neither threshold_config nor threshold_values is set.
        ValueError: If the supplied thresholds are empty.
    """
    # Compare against None rather than relying on truthiness so that valid but
    # falsy inputs (e.g. a single threshold of 0) are still recognised as set.
    if threshold_config is not None and threshold_values is not None:
        raise ValueError(
            "Threshold_config and threshold_values are mutually exclusive "
            "arguments - please provide one or the other, not both"
        )
    if threshold_config is None and threshold_values is None:
        raise ValueError(
            "One of threshold_config or threshold_values must be provided."
        )

    # Reject empty collections up front: interpolating to zero thresholds is
    # meaningless and would otherwise fail later with an obscure error.
    supplied = threshold_values if threshold_values is not None else threshold_config
    if not np.isscalar(supplied) and len(supplied) == 0:
        raise ValueError("The thresholds list cannot be empty.")

    self.threshold_values = threshold_values
    self.threshold_coord = None
    self.threshold_config = threshold_config

    thresholds = self._set_thresholds(threshold_values, threshold_config)
    # Defensive: guarantee a list even if a scalar is ever returned.
    self.thresholds = [thresholds] if np.isscalar(thresholds) else thresholds
    self.threshold_units = (
        None if threshold_units is None else Unit(threshold_units)
    )

    # Populated by process() so the input cube's units can be restored.
    self.original_units = None

@staticmethod
def _set_thresholds(
    threshold_values: Optional[Union[List[float], float]] = None,
    # Dictionary values are always the string "None" here: fuzzy bounds are not
    # used for threshold interpolation, so only the keys are read.
    threshold_config: Optional[Union[List[float], Dict[str, str]]] = None,
) -> List[float]:
    """
    Interprets a threshold_config dictionary or list if provided, otherwise
    normalises threshold_values, returning the thresholds as a list of floats.

    Args:
        threshold_values:
            A list of threshold values or a single threshold value.
        threshold_config:
            Either a list of float values or a dictionary of strings that can be
            interpreted as floats with the structure: "THRESHOLD_VALUE": "None".
            A non-empty threshold_config takes precedence over threshold_values.

    Returns:
        thresholds:
            A list of the input thresholds converted to Python floats. An empty
            threshold_config with no threshold_values yields an empty list.

    Raises:
        ValueError: If neither threshold_values nor a dictionary/list
            threshold_config is provided.
    """
    # Use the config when it is a dict/list and either holds entries or is the
    # only input given. Testing truthiness alone would send an empty config
    # into the threshold_values branch and fail while iterating over None.
    use_config = isinstance(threshold_config, (dict, list)) and (
        bool(threshold_config) or threshold_values is None
    )
    source = threshold_config if use_config else threshold_values

    if source is None:
        raise ValueError(
            "One of threshold_config or threshold_values must be provided."
        )

    # Convert a single value to a list so it is handled like any other input.
    if isinstance(source, numbers.Number):
        source = [source]

    # Iterating a dictionary yields its keys (the threshold values). Casting to
    # float avoids rounding errors during any later unit conversion.
    return [float(item) for item in source]

def mask_checking(self, forecast_at_thresholds: Cube) -> Optional[np.ndarray]:
"""
Check if the mask is consistent across different slices of the threshold coordinate.
Check if the mask is consistent across different slices of the threshold
coordinate.

Args:
forecast_at_thresholds:
The input cube containing forecast data with a threshold coordinate.

Returns:
original_mask:
The original mask if the data is masked and the mask is consistent across
different slices of the threshold coordinate, otherwise None.
The original mask if the data is masked and the mask is consistent
across different slices of the threshold coordinate, otherwise None.

Raises:
ValueError: If the mask varies across different slices of the threshold coordinate.
ValueError: If the mask varies across different slices of the threshold
coordinate.
"""
original_mask = None
if np.ma.is_masked(forecast_at_thresholds.data):
(crd_dim,) = forecast_at_thresholds.coord_dims(self.threshold_coord.name())
if np.diff(forecast_at_thresholds.data.mask, axis=crd_dim).any():
# Adjacent string literals are concatenated verbatim, so every fragment that
# is followed by more text must end with an explicit space.
raise ValueError(
    "The mask is expected to be constant across different slices of "
    f"the {self.threshold_coord.name()} dimension, however, in the dataset "
    f"provided, the mask varies across the {self.threshold_coord.name()} "
    "dimension. This is not currently supported."
)
else:
Expand Down Expand Up @@ -149,6 +222,9 @@ def create_cube_with_thresholds(
)
template_cube.remove_coord(self.threshold_coord)

if self.threshold_units is not None:
template_cube.units = self.threshold_units

# create cube with new threshold dimension
cubes = iris.cube.CubeList([])
for point in self.thresholds:
Expand All @@ -175,11 +251,16 @@ def process(

This method performs the following steps:
1. Identifies the threshold coordinate in the input cube.
2. Checks if the mask is consistent across different slices of the threshold coordinate.
3. Collapses the realizations if present.
4. Interpolates the forecast data to the new set of thresholds.
5. Creates a new cube with the interpolated threshold data.
6. Applies the original mask to the new cube if it exists.
2. Checks if the mask is consistent across different slices of the threshold
coordinate.
3. Converts the threshold coordinate to the specified units if provided.
4. Collapses the realizations if present.
5. Interpolates the data to the new set of thresholds.
6. Creates a new cube with the interpolated threshold data.
7. Applies the original mask to the new cube if it exists.
8. Converts the threshold coordinate units back to the original units.
9. Converts the interpolated cube data units to the original units, in case
these have been changed by the processing.

Args:
forecast_at_thresholds:
Expand All @@ -191,9 +272,17 @@ def process(
The threshold coordinate is always the zeroth dimension.
"""
self.threshold_coord = find_threshold_coordinate(forecast_at_thresholds)
self.threshold_coord_name = self.threshold_coord.name()
self.original_units = forecast_at_thresholds.units
self.original_threshold_units = self.threshold_coord.units

original_mask = self.mask_checking(forecast_at_thresholds)

if self.threshold_units is not None:
forecast_at_thresholds.coord(self.threshold_coord_name).convert_units(
self.threshold_units
)

if forecast_at_thresholds.coords("realization"):
forecast_at_thresholds = collapse_realizations(forecast_at_thresholds)

Expand All @@ -204,10 +293,19 @@ def process(
forecast_at_thresholds,
forecast_at_thresholds_data,
)

if original_mask is not None:
original_mask = np.broadcast_to(original_mask, threshold_cube.shape)
threshold_cube.data = np.ma.MaskedArray(
threshold_cube.data, mask=original_mask
)

# Revert the threshold coordinate's units
threshold_cube.coord(self.threshold_coord_name).convert_units(
self.original_threshold_units
)

# Ensure the cube's overall units are restored
threshold_cube.units = self.original_units

return threshold_cube
6 changes: 4 additions & 2 deletions improver_tests/acceptance/SHA256SUMS
Original file line number Diff line number Diff line change
Expand Up @@ -936,12 +936,14 @@ eb6f7c3f646c4c51a0964b9a19367f43d6e3762ff5523b982cfaf7bf2610f091 ./temporal-int
e3b8f51a0be52c4fead55f95c0e3da29ee3d93f92deed26314e60ad43e8fd5ef ./temporal-interpolate/uv/20181220T1200Z-PT0024H00M-uv_index.nc
b3fde693b3a8e144cb8f9ee9ff23c51ef92701858667cff850b2a49986bacaab ./temporal-interpolate/uv/kgo_t1.nc
1065ae1f25e6bc6df8d02e61c1f8ef92ab3dae679595d5165bd94d9c740adb2c ./temporal-interpolate/uv/kgo_t1_daynight.nc
3335761a3c15c0fd4336cb852970376abd6f6dac99907fe9b081e6a7672e530c ./threshold-interpolation/extra_thresholds_kgo.nc
022657626d7ae4608781c390ca9c30d9cbb949d71bedf74a2687228f5964b3e9 ./threshold-interpolation/input.nc
29eb38825b3a5e20b73128d3d2e65b5fc7e9a7670bc47bee0b45811a1139fd9c ./threshold-interpolation/extra_thresholds_kgo.nc
8e9c724ebc9b275777f15f92b16451cebe0a272a2caad0ef386b64d199e299fc ./threshold-interpolation/input.nc
12acca08e123437e07ad4e3aab81cc2fc0a3cfb72b5cb2fd06343bd5beb13f00 ./threshold-interpolation/input_realization.nc
7b172ce0d98c0f7fbfea1cde23a126d7116871bb62a221348c7ddddc35c29a0a ./threshold-interpolation/masked_cube_kgo.nc
ec73679ff5e308a2bb4d21283262118f8d9fbb6a425309b76d5865a97a773c40 ./threshold-interpolation/masked_input.nc
6058009963941b539117ea44792277253d87c7a1c81318e4836406b5c0b88525 ./threshold-interpolation/realization_collapse_kgo.nc
a8663a8344931ad655645facb59bf73cbd4d83b687f441c988bf64af32fea9ed ./threshold-interpolation/threshold_config_dict.json
837187b7282b1a1ea9ab15ed1f9ff2b21ccf340dda6c9d10b3c127fc818f4815 ./threshold-interpolation/threshold_config_list.json
ac93ed67c9947547e5879af6faaa329fede18afd822c720ac3afcb18fa41077a ./threshold/basic/input.nc
eb3fdc9400401ec47d95961553aed452abcbd91891d0fbca106b3a05131adaa9 ./threshold/basic/kgo.nc
6b50fa16b663869b3e3fbff36197603886ff7383b2df2a8ba92579bcc9461a16 ./threshold/below_threshold/kgo.nc
Expand Down
66 changes: 60 additions & 6 deletions improver_tests/acceptance/test_threshold_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,38 +15,92 @@

def test_basic(tmp_path):
    """Test basic invocation with the --threshold-values argument"""
    threshold_values = "50.0,200.0,400.0,600.0,1000.0,2000.0,10000.0,25000.0,40000.0"
    kgo_dir = acc.kgo_root() / "threshold-interpolation"
    kgo_path = kgo_dir / "extra_thresholds_kgo.nc"
    input_path = kgo_dir / "input.nc"
    output_path = tmp_path / "output.nc"
    # Only the --threshold-values form is kept; the earlier --thresholds
    # argument list was overwritten before use and has been removed.
    args = [
        input_path,
        "--threshold-values",
        threshold_values,
        "--output",
        f"{output_path}",
    ]

    run_cli(args)
    acc.compare(output_path, kgo_path)


def test_realization_collapse(tmp_path):
    """Test realization coordinate is collapsed"""
    threshold_values = "50.0,200.0,400.0,600.0,1000.0,2000.0,10000.0,25000.0,40000.0"
    kgo_dir = acc.kgo_root() / "threshold-interpolation"
    kgo_path = kgo_dir / "realization_collapse_kgo.nc"
    input_path = kgo_dir / "input_realization.nc"
    output_path = tmp_path / "output.nc"
    # Only the --threshold-values form is kept; the earlier --thresholds
    # argument list was overwritten before use and has been removed.
    args = [
        input_path,
        "--threshold-values",
        threshold_values,
        "--output",
        f"{output_path}",
    ]

    run_cli(args)
    acc.compare(output_path, kgo_path)


def test_masked_cube(tmp_path):
    """Test masked cube"""
    threshold_values = "50.0,200.0,400.0,600.0,1000.0,2000.0,10000.0,25000.0,40000.0"
    kgo_dir = acc.kgo_root() / "threshold-interpolation"
    kgo_path = kgo_dir / "masked_cube_kgo.nc"
    input_path = kgo_dir / "masked_input.nc"
    output_path = tmp_path / "output.nc"
    # Only the --threshold-values form is kept; the earlier --thresholds
    # argument list was overwritten before use and has been removed.
    args = [
        input_path,
        "--threshold-values",
        threshold_values,
        "--output",
        f"{output_path}",
    ]

    run_cli(args)
    acc.compare(output_path, kgo_path)


def test_json_dict_input(tmp_path):
    """Test thresholds supplied through a JSON dictionary config file"""
    data_dir = acc.kgo_root() / "threshold-interpolation"
    output_path = tmp_path / "output.nc"
    # The config is passed via --threshold-config instead of an inline list;
    # the expected output is the same KGO as the realization-collapse test.
    run_cli(
        [
            data_dir / "input_realization.nc",
            "--threshold-config",
            data_dir / "threshold_config_dict.json",
            "--output",
            str(output_path),
        ]
    )
    acc.compare(output_path, data_dir / "realization_collapse_kgo.nc")


def test_json_list_input(tmp_path):
    """Test thresholds supplied through a JSON list config file"""
    data_dir = acc.kgo_root() / "threshold-interpolation"
    output_path = tmp_path / "output.nc"
    # The config is passed via --threshold-config instead of an inline list;
    # the expected output is the same KGO as the realization-collapse test.
    run_cli(
        [
            data_dir / "input_realization.nc",
            "--threshold-config",
            data_dir / "threshold_config_list.json",
            "--output",
            str(output_path),
        ]
    )
    acc.compare(output_path, data_dir / "realization_collapse_kgo.nc")
Loading