diff --git a/xarray_beam/__init__.py b/xarray_beam/__init__.py index f2bad3f..24dbeb7 100644 --- a/xarray_beam/__init__.py +++ b/xarray_beam/__init__.py @@ -55,4 +55,4 @@ DatasetToZarr as DatasetToZarr, ) -__version__ = '0.11.3' # automatically synchronized to pyproject.toml +__version__ = '0.11.4' # automatically synchronized to pyproject.toml diff --git a/xarray_beam/_src/zarr.py b/xarray_beam/_src/zarr.py index 0500f2a..60775c2 100644 --- a/xarray_beam/_src/zarr.py +++ b/xarray_beam/_src/zarr.py @@ -151,12 +151,13 @@ def replace_template_dims( template: xarray.Dataset, **dim_replacements: int | np.ndarray | pd.Index | xarray.DataArray, ) -> xarray.Dataset: + # pyformat: disable """Replaces dimension(s) in a template with updates coordinates and/or sizes. This is convenient for creating templates from evaluated results for a single chunk. - Example usage: + Example usage:: import numpy as np import pandas as pd @@ -178,27 +179,21 @@ def replace_template_dims( # Dimensions: (time: 1, longitude: 1440, latitude: 721) # Coordinates: # * time (time) datetime64[ns] 8B 1940-01-01 - # * longitude (longitude) float64 12kB 0.0 0.25 0.5 0.75 ... 359.2 359.5 - 359.8 - # * latitude (latitude) float64 6kB -90.0 -89.75 -89.5 ... 89.5 89.75 - 90.0 + # * longitude (longitude) float64 12kB 0.0 0.25 0.5 0.75 ... 359.2 359.5 359.8 + # * latitude (latitude) float64 6kB -90.0 -89.75 -89.5 ... 89.5 89.75 90.0 # Data variables: - # foo (time, longitude, latitude) float64 8MB - dask.array + # foo (time, longitude, latitude) float64 8MB dask.array template = xbeam.replace_template_dims(template, time=times) print(template) # Size: 6TB # Dimensions: (time: 747769, longitude: 1440, latitude: 721) # Coordinates: - # * longitude (longitude) float64 12kB 0.0 0.25 0.5 0.75 ... 359.2 359.5 - 359.8 - # * latitude (latitude) float64 6kB -90.0 -89.75 -89.5 ... 89.5 89.75 - 90.0 + # * longitude (longitude) float64 12kB 0.0 0.25 0.5 0.75 ... 359.2 359.5 359.8 + # * latitude (latitude) float64 6kB -90.0 -89.75 -89.5 ... 89.5 89.75 90.0 # * time (time) datetime64[ns] 6MB 1940-01-01 ... 2025-04-21 # Data variables: - # foo (time, longitude, latitude) float64 6TB - dask.array + # foo (time, longitude, latitude) float64 6TB dask.array Args: template: The template to replace dimensions in. @@ -209,7 +204,8 @@ def replace_template_dims( Returns: Template with the replaced dimensions. """ - expansions = {} + # pyformat: enable + expansions_with_axes = {} for name, variable in template.items(): if variable.chunks is None: raise ValueError( @@ -217,14 +213,16 @@ def replace_template_dims( ' xarray_beam.make_template() to create a valid template before ' f' calling replace_template_dims(): {template}' ) - expansions[name] = { - dim: replacement - for dim, replacement in dim_replacements.items() - if dim in variable.dims - } + # identify which dimensions of this variable need to be replaced, in order + dims_to_replace = [dim for dim in variable.dims if dim in dim_replacements] + if dims_to_replace: + expansions = {dim: dim_replacements[dim] for dim in dims_to_replace} + axes = [variable.dims.index(dim) for dim in dims_to_replace] + expansions_with_axes[name] = (expansions, axes) + template = template.isel({dim: 0 for dim in dim_replacements}, drop=True) - for name, variable in template.items(): - template[name] = variable.expand_dims(expansions[name]) + for name, (expansions, axes) in expansions_with_axes.items(): + template[name] = template[name].expand_dims(expansions, axis=axes) return template diff --git a/xarray_beam/_src/zarr_test.py b/xarray_beam/_src/zarr_test.py index 05c0852..a7d9cf9 100644 --- a/xarray_beam/_src/zarr_test.py +++ b/xarray_beam/_src/zarr_test.py @@ -182,6 +182,17 @@ def test_replace_template_dims_multiple_vars(self): self.assertIsInstance(new_template.bar.data, da.Array) self.assertIsInstance(new_template.baz.data, da.Array) + def test_replace_template_dims_multiple_dims_unordered(self): + source = xarray.Dataset( + {'foo': (('x', 'y', 'z'), np.zeros((1, 2, 3)))}, + coords={'x': [0], 'y': [10, 20], 'z': [1, 2, 3]}, + ) + template = xbeam.make_template(source) + new_template = xbeam.replace_template_dims(template, z=4, x=5) + + self.assertEqual(new_template.sizes, {'x': 5, 'y': 2, 'z': 4}) + self.assertEqual(new_template.foo.dims, ('x', 'y', 'z')) + def test_replace_template_dims_error_on_non_template(self): source = xarray.Dataset({'foo': ('x', np.zeros(1))}) # Not a template with self.assertRaisesRegex(ValueError, 'is not chunked with Dask'):