Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion xarray_beam/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,4 @@
DatasetToZarr as DatasetToZarr,
)

__version__ = '0.11.3' # automatically synchronized to pyproject.toml
__version__ = '0.11.4' # automatically synchronized to pyproject.toml
40 changes: 19 additions & 21 deletions xarray_beam/_src/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,13 @@ def replace_template_dims(
template: xarray.Dataset,
**dim_replacements: int | np.ndarray | pd.Index | xarray.DataArray,
) -> xarray.Dataset:
# pyformat: disable
"""Replaces dimension(s) in a template with updated coordinates and/or sizes.

This is convenient for creating templates from evaluated results for a
single chunk.

Example usage:
Example usage::

import numpy as np
import pandas as pd
Expand All @@ -178,27 +179,21 @@ def replace_template_dims(
# Dimensions: (time: 1, longitude: 1440, latitude: 721)
# Coordinates:
# * time (time) datetime64[ns] 8B 1940-01-01
# * longitude (longitude) float64 12kB 0.0 0.25 0.5 0.75 ... 359.2 359.5
359.8
# * latitude (latitude) float64 6kB -90.0 -89.75 -89.5 ... 89.5 89.75
90.0
# * longitude (longitude) float64 12kB 0.0 0.25 0.5 0.75 ... 359.2 359.5 359.8
# * latitude (latitude) float64 6kB -90.0 -89.75 -89.5 ... 89.5 89.75 90.0
# Data variables:
# foo (time, longitude, latitude) float64 8MB
dask.array<chunksize=(1, 1440, 721), meta=np.ndarray>
# foo (time, longitude, latitude) float64 8MB dask.array<chunksize=(1, 1440, 721), meta=np.ndarray>

template = xbeam.replace_template_dims(template, time=times)
print(template)
# <xarray.Dataset> Size: 6TB
# Dimensions: (time: 747769, longitude: 1440, latitude: 721)
# Coordinates:
# * longitude (longitude) float64 12kB 0.0 0.25 0.5 0.75 ... 359.2 359.5
359.8
# * latitude (latitude) float64 6kB -90.0 -89.75 -89.5 ... 89.5 89.75
90.0
# * longitude (longitude) float64 12kB 0.0 0.25 0.5 0.75 ... 359.2 359.5 359.8
# * latitude (latitude) float64 6kB -90.0 -89.75 -89.5 ... 89.5 89.75 90.0
# * time (time) datetime64[ns] 6MB 1940-01-01 ... 2025-04-21
# Data variables:
# foo (time, longitude, latitude) float64 6TB
dask.array<chunksize=(747769, 1440, 721), meta=np.ndarray>
# foo (time, longitude, latitude) float64 6TB dask.array<chunksize=(747769, 1440, 721), meta=np.ndarray>

Args:
template: The template to replace dimensions in.
Expand All @@ -209,22 +204,25 @@ def replace_template_dims(
Returns:
Template with the replaced dimensions.
"""
expansions = {}
# pyformat: enable
expansions_with_axes = {}
for name, variable in template.items():
if variable.chunks is None:
raise ValueError(
f'Data variable {name} is not chunked with Dask. Please call'
' xarray_beam.make_template() to create a valid template before '
f' calling replace_template_dims(): {template}'
)
expansions[name] = {
dim: replacement
for dim, replacement in dim_replacements.items()
if dim in variable.dims
}
# Identify which dimensions of this variable need to be replaced, keeping the order in which they appear in variable.dims
dims_to_replace = [dim for dim in variable.dims if dim in dim_replacements]
if dims_to_replace:
expansions = {dim: dim_replacements[dim] for dim in dims_to_replace}
axes = [variable.dims.index(dim) for dim in dims_to_replace]
expansions_with_axes[name] = (expansions, axes)

template = template.isel({dim: 0 for dim in dim_replacements}, drop=True)
for name, variable in template.items():
template[name] = variable.expand_dims(expansions[name])
for name, (expansions, axes) in expansions_with_axes.items():
template[name] = template[name].expand_dims(expansions, axis=axes)
return template


Expand Down
11 changes: 11 additions & 0 deletions xarray_beam/_src/zarr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,17 @@ def test_replace_template_dims_multiple_vars(self):
self.assertIsInstance(new_template.bar.data, da.Array)
self.assertIsInstance(new_template.baz.data, da.Array)

# Checks that replacing several dimensions at once keeps the variable's
# original dimension order, even when the replacements are passed as keyword
# arguments in a different order than the dims appear on the variable.
def test_replace_template_dims_multiple_dims_unordered(self):
# One variable 'foo' with dims (x, y, z) and shape (1, 2, 3).
source = xarray.Dataset(
{'foo': (('x', 'y', 'z'), np.zeros((1, 2, 3)))},
coords={'x': [0], 'y': [10, 20], 'z': [1, 2, 3]},
)
template = xbeam.make_template(source)
# Keyword order here is (z, x), the reverse of their order in foo.dims.
new_template = xbeam.replace_template_dims(template, z=4, x=5)

# x and z take the requested sizes; y is not replaced and keeps size 2.
self.assertEqual(new_template.sizes, {'x': 5, 'y': 2, 'z': 4})
# The dimension order on foo must still be (x, y, z).
self.assertEqual(new_template.foo.dims, ('x', 'y', 'z'))

def test_replace_template_dims_error_on_non_template(self):
source = xarray.Dataset({'foo': ('x', np.zeros(1))}) # Not a template
with self.assertRaisesRegex(ValueError, 'is not chunked with Dask'):
Expand Down