Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.


from typing import TypeVar, Optional
from typing import TypeVar, Optional, Literal, Union, Type

import xarray as xr

Expand All @@ -34,17 +34,71 @@ class AlignDataVariableDimensionsToDatasetCoords(Operation):
This operator will align all of the data variables to the same ordering as present in the Dataset.
"""

def __init__(
    self,
    restore_dim_order_on_undo: bool = False,
    *,
    split_tuples: Literal["apply", "undo", True, False] = False,
    recursively_split_tuples: bool = False,
    operation: Literal["apply", "undo", "both"] = "both",
    recognised_types: Optional[
        Union[
            tuple[Type, ...],
            Type,
            dict[Literal["apply", "undo"], Union[tuple[Type, ...], Type]],
        ]
    ] = None,
    response_on_type: Literal["warn", "exception", "ignore", "filter"] = "exception",
    **kwargs,
):
    """
    Set up the dimension-alignment operation.

    Every keyword-only argument (and any extra ``kwargs``) is handed on unchanged to the
    parent ``Operation`` constructor; only ``restore_dim_order_on_undo`` is handled here.

    Args:
        restore_dim_order_on_undo (bool, optional):
            If True, `apply_func` remembers each data variable's dimension order so that
            `undo_func` can transpose the variables back. Defaults to False.
        split_tuples (Literal['apply', 'undo', True, False], optional):
            Split tuples on the named action; a bool applies to all functions.
            Defaults to False.
        recursively_split_tuples (bool, optional):
            Recursively split tuples. Defaults to False.
        operation (Literal['apply', 'undo', 'both'], optional):
            Which functions the operation is active for. If 'apply' is not selected,
            apply does nothing; likewise for 'undo'. Defaults to "both".
        recognised_types (Optional[Union[tuple[Type, ...], Type, dict]], optional):
            Types recognised. May be a dictionary keyed by 'apply' / 'undo' to give
            different types per function. Defaults to None.
        response_on_type (Literal['warn', 'exception', 'ignore', 'filter'], optional):
            Response when an invalid type is found. Defaults to "exception".
        **kwargs:
            Forwarded to the parent constructor.
    """
    super().__init__(
        split_tuples=split_tuples,
        recursively_split_tuples=recursively_split_tuples,
        operation=operation,
        recognised_types=recognised_types,
        response_on_type=response_on_type,
        **kwargs,
    )

    # Whether undo_func should transpose variables back to their pre-apply order.
    self.restore_dim_order_on_undo = restore_dim_order_on_undo
    # Data variable name -> original dims tuple; filled in by apply_func when restoring is enabled.
    self.recorded_ordering = {}

def apply_func(self, data: xr.Dataset) -> xr.Dataset:
    """
    Transpose every data variable to the dimension ordering of the dataset.

    When `restore_dim_order_on_undo` is set, the dimension order of each data variable
    is stored in `self.recorded_ordering` before the transpose so it can be restored later.

    Args:
        data (xr.Dataset): Dataset whose data variables should be aligned.

    Returns:
        xr.Dataset: The result of transposing `data` to the dataset's dimension order.
    """
    if self.restore_dim_order_on_undo:
        # remember how each array was laid out before alignment
        original_dims = {}
        for name in data:
            original_dims[name] = data[name].dims
        self.recorded_ordering = original_dims

    # coords.dims is used (rather than the coordinate names) because coordinates
    # need not be named after the dimensions they index
    return data.transpose(*data.coords.dims)

def undo_func(self, data: xr.Dataset) -> xr.Dataset:
    """
    Undo the dimension alignment performed by `apply_func`.

    If `restore_dim_order_on_undo` is False the data is returned untouched. Otherwise
    each data variable is transposed back to the dimension order recorded for it by
    `apply_func`, and a new dataset is returned.

    Args:
        data (xr.Dataset): Dataset to restore.

    Returns:
        xr.Dataset: `data` itself when restoring is disabled, otherwise a new dataset
        holding the re-transposed data variables and the attrs of `data`.
    """
    if not self.restore_dim_order_on_undo:
        return data

    restored = {}
    for array_name in data:
        array = data[array_name]
        # A variable with no recorded ordering (apply_func not run, or the variable
        # was added afterwards) keeps its current dimension order instead of raising KeyError.
        original_dims = self.recorded_ordering.get(array_name, array.dims)
        restored[array_name] = array.transpose(*original_dims)

    # Build a new dataset so the input data isn't modified; pass attrs explicitly,
    # because constructing from a dict of arrays would otherwise drop dataset-level attrs.
    return xr.Dataset(restored, attrs=data.attrs)


class Sort(Operation):
Expand Down
71 changes: 32 additions & 39 deletions packages/pipeline/tests/operations/xarray/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,55 +35,48 @@
)
# Dataset with a single data variable backed by SIMPLE_DA1.
SIMPLE_DS1 = xr.Dataset({"Temperature": SIMPLE_DA1})
# Dataset with three data variables that all reuse SIMPLE_DA1.
SIMPLE_DS2 = xr.Dataset({"Humidity": SIMPLE_DA1, "Temperature": SIMPLE_DA1, "WombatsPerKm2": SIMPLE_DA1})
# Same values as SIMPLE_DA1, but the coordinates ("h", "x", "y") are deliberately named
# differently from the dimensions they index ("height", "lat", "lon").
SIMPLE_DA_WITH_NAMED_COORDS = xr.DataArray(
SIMPLE_DA1.data,
coords={"h": ("height", [10, 20]), "x": ("lat", [0, 1, 2]), "y": ("lon", [5, 6, 7])},
dims=["height", "lat", "lon"],
)


@pytest.mark.parametrize("undo,unaligned_dataarray", [(True, SIMPLE_DA1), (False, SIMPLE_DA_WITH_NAMED_COORDS)])
def test_align(undo, unaligned_dataarray):
    """
    Tests that the dataset dimension alignment operation works, and that it can be undone.

    Args:
        undo: Value passed as `restore_dim_order_on_undo`; when True the undo step is
            expected to restore the original dimension order, otherwise to leave it alone.
        unaligned_dataarray: Source array used (in several transpositions) to build a
            dataset whose variables do not share a dimension order.
    """
    # instantiate align op
    align_op = AlignDataVariableDimensionsToDatasetCoords(restore_dim_order_on_undo=undo)

    # create dataset with dims that are unaligned
    unaligned_dataset = xr.Dataset(
        {
            "Temperature": unaligned_dataarray.transpose("lat", "lon", "height"),
            "Humidity": unaligned_dataarray,
            "WombatsPerKm2": unaligned_dataarray.transpose("lon", "height", "lat"),
        }
    )

    # check that dataset dims are indeed unaligned
    assert unaligned_dataset["Temperature"].dims != unaligned_dataset["Humidity"].dims
    assert unaligned_dataset["Temperature"].dims != unaligned_dataset["WombatsPerKm2"].dims

    # apply aligner to dataset and check that dataset dims now align
    ds_aligned = align_op.apply_func(unaligned_dataset)
    assert (
        ds_aligned["Temperature"].dims == ds_aligned["Humidity"].dims
    ), "Humidity DataArray dims not aligned to Temperature DataArray dims."
    assert (
        ds_aligned["Temperature"].dims == ds_aligned["WombatsPerKm2"].dims
    ), "WombatsPerKm2 DataArray dims not aligned to Temperature DataArray dims."

    # test undo alignment
    ds_aligned_undone = align_op.undo_func(ds_aligned)
    for array_name in unaligned_dataset:
        # Compare values per variable: the truth value of a whole Dataset only reflects
        # whether it has data variables, so asserting on `(ds_a == ds_b).all()` checks nothing.
        assert bool(
            (ds_aligned_undone[array_name] == ds_aligned[array_name]).all()
        ), "Underlying data was modified in undoing the dimension alignment."
        if undo:
            assert (
                ds_aligned_undone[array_name].dims == unaligned_dataset[array_name].dims
            ), "Undoing dimensions failed."
        else:
            assert ds_aligned_undone[array_name].dims == ds_aligned[array_name].dims, "Dimensions changed."


def test_Sort():
Expand Down