From 097389cdb0d983c9df776e3efa21e537959c8131 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Mon, 10 Nov 2025 14:26:19 +1100 Subject: [PATCH 1/2] add optional undo to dataset aligner --- .../pipeline/operations/xarray/_sort.py | 45 ++++++++++-- .../tests/operations/xarray/test_sort.py | 71 +++++++++---------- 2 files changed, 73 insertions(+), 43 deletions(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/_sort.py b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/_sort.py index 1eb0c59a..553ad71f 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/_sort.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/_sort.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import TypeVar, Optional +from typing import TypeVar, Optional, Literal, Union, Type import xarray as xr @@ -34,7 +34,40 @@ class AlignDataVariableDimensionsToDatasetCoords(Operation): This operator will align all of the data variables to the same ordering as present in the Dataset. """ + def __init__( + self, + undo: bool = False, + *, + split_tuples: Literal["apply", "undo", True, False] = False, + recursively_split_tuples: bool = False, + operation: Literal["apply", "undo", "both"] = "both", + recognised_types: Optional[ + Union[ + tuple[Type, ...], + Type, + dict[Literal["apply", "undo"], Union[tuple[Type, ...], Type]], + ] + ] = None, + response_on_type: Literal["warn", "exception", "ignore", "filter"] = "exception", + **kwargs, + ): + super().__init__( + split_tuples=split_tuples, + recursively_split_tuples=recursively_split_tuples, + operation=operation, + recognised_types=recognised_types, + response_on_type=response_on_type, + **kwargs, + ) + self.undo = undo + self.recorded_ordering = {} + def apply_func(self, data: xr.Dataset) -> xr.Dataset: + + if self.undo: + # record the original dimensions of the array + self.recorded_ordering = {array_name: data[array_name].dims for array_name in data} + # use coords.dim for when coordinates don't have the same name as dimensions dataset_ordering = list(data.coords.dims) @@ -42,9 +75,13 @@ def apply_func(self, data: xr.Dataset) -> xr.Dataset: return data def undo_func(self, data: xr.Dataset) -> xr.Dataset: - # TODO: Record all the original orderings and transpose them back, I guess - - raise NotImplementedError("Don't yet know how to undo data variable alignment.") + if self.undo: + # new dataset so input data isn't augmented. + return xr.Dataset( + {array_name: data[array_name].transpose(*self.recorded_ordering[array_name]) for array_name in data} + ) + else: + return data class Sort(Operation): diff --git a/packages/pipeline/tests/operations/xarray/test_sort.py b/packages/pipeline/tests/operations/xarray/test_sort.py index cf5f72f9..18534aec 100644 --- a/packages/pipeline/tests/operations/xarray/test_sort.py +++ b/packages/pipeline/tests/operations/xarray/test_sort.py @@ -35,55 +35,48 @@ ) SIMPLE_DS1 = xr.Dataset({"Temperature": SIMPLE_DA1}) SIMPLE_DS2 = xr.Dataset({"Humidity": SIMPLE_DA1, "Temperature": SIMPLE_DA1, "WombatsPerKm2": SIMPLE_DA1}) +SIMPLE_DA_WITH_NAMED_COORDS = xr.DataArray( + SIMPLE_DA1.data, + coords={"h": ("height", [10, 20]), "x": ("lat", [0, 1, 2]), "y": ("lon", [5, 6, 7])}, + dims=["height", "lat", "lon"], +) -def test_align(): +@pytest.mark.parametrize("undo,unaligned_dataarray", [(True, SIMPLE_DA1), (False, SIMPLE_DA_WITH_NAMED_COORDS)]) +def test_align(undo, unaligned_dataarray): """Tests that the dataset dimension alignment operation works.""" - align_op = AlignDataVariableDimensionsToDatasetCoords() - # create dataset with arrays that are not consistently ordered - ds = xr.Dataset( - { - "Temperature": SIMPLE_DA1.transpose("lat", "height", "lon"), - "Humidity": SIMPLE_DA1, - "WombatsPerKm2": SIMPLE_DA1.transpose("lon", "height", "lat"), - } - ) + # instantiate align op + align_op = AlignDataVariableDimensionsToDatasetCoords(undo=undo) - # check that dataset dims are indeed unaligned - assert ds["Temperature"].dims != ds["Humidity"].dims - assert ds["Temperature"].dims != ds["WombatsPerKm2"].dims - - # apply aligner to dataset and check that dataset dims now align - ds_aligned = align_op.apply_func(ds) - assert ds_aligned["Temperature"].dims == ds_aligned["Humidity"].dims - assert ds_aligned["Temperature"].dims == ds_aligned["WombatsPerKm2"].dims - - ## Test that alignment works even when coordinate names don't match dims - da_with_named_coords = xr.DataArray( - SIMPLE_DA1.data, - coords={"h": ("height", [10, 20]), "x": ("lat", [0, 1, 2]), "y": ("lon", [5, 6, 7])}, - dims=["height", "lat", "lon"], - ) - ds = xr.Dataset( + # create dataset with dims that are unaligned + unaligned_dataset = xr.Dataset( { - "Temperature": da_with_named_coords.transpose("lat", "height", "lon"), - "Humidity": da_with_named_coords, - "WombatsPerKm2": da_with_named_coords.transpose("lon", "height", "lat"), + "Temperature": unaligned_dataarray.transpose("lat", "lon", "height"), + "Humidity": unaligned_dataarray, + "WombatsPerKm2": unaligned_dataarray.transpose("lon", "height", "lat"), } ) - # check that dataset dims are indeed unaligned - assert ds["Temperature"].dims != ds["Humidity"].dims - assert ds["Temperature"].dims != ds["WombatsPerKm2"].dims # apply aligner to dataset and check that dataset dims now align - ds_aligned = align_op.apply_func(ds) - assert ds_aligned["Temperature"].dims == ds_aligned["Humidity"].dims - assert ds_aligned["Temperature"].dims == ds_aligned["WombatsPerKm2"].dims - - # placeholder test for undo method - with pytest.raises(NotImplementedError): - align_op.undo_func(ds) + ds_aligned = align_op.apply_func(unaligned_dataset) + assert ( + ds_aligned["Temperature"].dims == ds_aligned["Humidity"].dims + ), "Humidity DataArray dims not aligned to Temperature DataArray dims." + assert ( + ds_aligned["Temperature"].dims == ds_aligned["WombatsPerKm2"].dims + ), "WombatsPerKm2 DataArray dims not aligned to Temperature DataArray dims." + + # test undo alignment + ds_aligned_undone = align_op.undo_func(ds_aligned) + assert (ds_aligned_undone == ds_aligned).all(), "Underlying data was modified in undoing the dimension alignment." + for array_name in unaligned_dataset: + if undo: + assert ( + ds_aligned_undone[array_name].dims == unaligned_dataset[array_name].dims + ), "Undoing dimensions failed." + else: + assert ds_aligned_undone[array_name].dims == ds_aligned[array_name].dims, "Dimensions changed." def test_Sort(): From 16853fbfcb8ef916ef2cc11e73022d0a8b8d1fa5 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Tue, 11 Nov 2025 10:29:50 +1100 Subject: [PATCH 2/2] add doc string to init method --- .../pipeline/operations/xarray/_sort.py | 25 ++++++++++++++++--- .../tests/operations/xarray/test_sort.py | 2 +- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/_sort.py b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/_sort.py index 553ad71f..00d9eac8 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/_sort.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/_sort.py @@ -36,7 +36,7 @@ class AlignDataVariableDimensionsToDatasetCoords(Operation): def __init__( self, - undo: bool = False, + restore_dim_order_on_undo: bool = False, *, split_tuples: Literal["apply", "undo", True, False] = False, recursively_split_tuples: bool = False, @@ -51,6 +51,23 @@ def __init__( response_on_type: Literal["warn", "exception", "ignore", "filter"] = "exception", **kwargs, ): + """ + Constructs the AlignDataVariableDimensionsToDatasetCoords operation. + + Args: + restore_dim_order_on_undo (bool): Whether to undo the dimension ordering in the class' undo operation. Defaults to False. + split_tuples (Literal['apply', 'undo', True, False], optional): + Split tuples on associated actions, if bool, apply to all functions. Defaults to False. + recursively_split_tuples (bool, optional): + Recursively split tuples. Defaults to False. + operation (Literal['apply', 'undo', 'both'], optional): + Which functions to apply operation to. + If not 'apply' apply does nothing, same for `undo`. Defaults to "both". + recognised_types (Optional[Union[tuple[Type, ...], Type, dict[str, Union[tuple[Type, ...], Type]]] ], optional): + Types recognised, can be dictionary to reference different types per function Defaults to None. + response_on_type (Literal['warn', 'exception', 'ignore', 'filter'], optional): + Response when invalid type found. Defaults to "exception". + """ super().__init__( split_tuples=split_tuples, recursively_split_tuples=recursively_split_tuples, @@ -59,12 +76,12 @@ def __init__( response_on_type=response_on_type, **kwargs, ) - self.undo = undo + self.restore_dim_order_on_undo = restore_dim_order_on_undo self.recorded_ordering = {} def apply_func(self, data: xr.Dataset) -> xr.Dataset: - if self.undo: + if self.restore_dim_order_on_undo: # record the original dimensions of the array self.recorded_ordering = {array_name: data[array_name].dims for array_name in data} @@ -75,7 +92,7 @@ def apply_func(self, data: xr.Dataset) -> xr.Dataset: return data def undo_func(self, data: xr.Dataset) -> xr.Dataset: - if self.undo: + if self.restore_dim_order_on_undo: # new dataset so input data isn't augmented. return xr.Dataset( {array_name: data[array_name].transpose(*self.recorded_ordering[array_name]) for array_name in data} diff --git a/packages/pipeline/tests/operations/xarray/test_sort.py b/packages/pipeline/tests/operations/xarray/test_sort.py index 18534aec..18cf517a 100644 --- a/packages/pipeline/tests/operations/xarray/test_sort.py +++ b/packages/pipeline/tests/operations/xarray/test_sort.py @@ -47,7 +47,7 @@ def test_align(undo, unaligned_dataarray): """Tests that the dataset dimension alignment operation works.""" # instantiate align op - align_op = AlignDataVariableDimensionsToDatasetCoords(undo=undo) + align_op = AlignDataVariableDimensionsToDatasetCoords(restore_dim_order_on_undo=undo) # create dataset with dims that are unaligned unaligned_dataset = xr.Dataset(