From 7ea8f6937146204c5ccd6128fa597371a9684aae Mon Sep 17 00:00:00 2001
From: Stephan Hoyer
Date: Sun, 2 Nov 2025 21:49:52 -0800
Subject: [PATCH] Round up zarr shards when they are equal to the full
 dimension size

Otherwise, it is extremely hard to use shards for dimensions with
irregular sizes (e.g., "time" in a reanalysis dataset).

With this change, a dataset with `sizes={'x': 19}` and `chunks={'x': 10}`
can be sharded with `shards={'x': 20}`.

PiperOrigin-RevId: 827316686
---
 xarray_beam/_src/dataset.py      | 32 ++++++++++++++++++++++----------
 xarray_beam/_src/dataset_test.py | 22 ++++++++++++++++++++++
 2 files changed, 44 insertions(+), 10 deletions(-)

diff --git a/xarray_beam/_src/dataset.py b/xarray_beam/_src/dataset.py
index 19aede1..fe10c91 100644
--- a/xarray_beam/_src/dataset.py
+++ b/xarray_beam/_src/dataset.py
@@ -392,7 +392,6 @@ def method(self: Dataset, *args, **kwargs) -> Dataset:
   return method
 
 
-
 class _CountNamer:
 
   def __init__(self):
@@ -408,6 +407,7 @@ def apply(self, name: str) -> str:
 @dataclasses.dataclass(frozen=True)
 class _LazyPCollection:
   """Pipeline and PTransform not yet been combined into a PCollection."""
+
   # Beam does not provide a public API for manipulating Pipeline objects, so
   # instead of applying pipelines eagerly, we store them in this wrapper. This
   # allows for performance optimizations specialized to Xarray-Beam PTransforms,
@@ -715,12 +715,16 @@ def _check_shards_or_chunks(
       zarr_chunks: Mapping[str, int],
       chunks_name: Literal['shards', 'chunks'],
   ) -> None:
-    if any(self.chunks[k] % zarr_chunks[k] for k in self.chunks):
-      raise ValueError(
-          f'cannot write a dataset with chunks {self.chunks} to Zarr with '
-          f'{chunks_name} {zarr_chunks}, which do not divide evenly into '
-          f'{chunks_name}'
-      )
+    for k in self.chunks:
+      if (
+          self.chunks[k] % zarr_chunks[k]
+          and self.chunks[k] != self.template.sizes[k]
+      ):
+        raise ValueError(
+            f'cannot write a dataset with chunks {self.chunks} to Zarr with '
+            f'{chunks_name} {zarr_chunks}, which do not divide evenly into '
+            f'{chunks_name}'
+        )
 
   def to_zarr(
       self,
@@ -804,6 +808,16 @@
           previous_chunks=self.chunks,
       )
     if zarr_shards is not None:
+      # Zarr shards are currently constrained to be an integer multiple of
+      # chunk sizes, which means shard sizes must be rounded up to be larger
+      # than the full dimension size. This will likely be relaxed in the future:
+      # https://github.com/zarr-developers/zarr-extensions/issues/34
+      zarr_shards = dict(zarr_shards)
+      for k in zarr_shards:
+        if zarr_shards[k] == self.sizes[k]:
+          zarr_shards[k] = (
+              math.ceil(zarr_shards[k] / zarr_chunks[k]) * zarr_chunks[k]
+          )
       self._check_shards_or_chunks(zarr_shards, 'shards')
     else:
       self._check_shards_or_chunks(zarr_chunks, 'chunks')
@@ -956,9 +970,7 @@ def rechunk(
     ):
       # Rechunking can be performed by re-reading the source dataset with new
      # chunks, rather than using a separate rechunking transform.
-      ptransform = core.DatasetToChunks(
-          ptransform.dataset, chunks, split_vars
-      )
+      ptransform = core.DatasetToChunks(ptransform.dataset, chunks, split_vars)
     ptransform.label = _concat_labels(ptransform.label, label)
     if pipeline is not None:
       ptransform = _LazyPCollection(pipeline, ptransform)
diff --git a/xarray_beam/_src/dataset_test.py b/xarray_beam/_src/dataset_test.py
index a2b813a..5d60685 100644
--- a/xarray_beam/_src/dataset_test.py
+++ b/xarray_beam/_src/dataset_test.py
@@ -715,6 +715,28 @@ def test_to_zarr_shards(self):
         zarr_shards={'x': 9},
     )
 
+  @parameterized.named_parameters(
+      dict(testcase_name='empty', zarr_shards={}),
+      dict(testcase_name='minus_one', zarr_shards=-1),
+      dict(testcase_name='explicit_19', zarr_shards={'x': 19}),
+      dict(testcase_name='explicit_20', zarr_shards={'x': 20}),
+  )
+  def test_to_zarr_shards_round_up(self, zarr_shards):
+    temp_dir = self.create_tempdir().full_path
+    ds = xarray.Dataset({'foo': ('x', np.arange(19, dtype='int64'))})
+    beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 19})
+
+    with beam.Pipeline() as p:
+      p |= beam_ds.to_zarr(
+          temp_dir,
+          zarr_chunks={'x': 10},
+          zarr_shards=zarr_shards,
+      )
+    opened, chunks = xbeam.open_zarr(temp_dir)
+    xarray.testing.assert_identical(ds, opened)
+    self.assertEqual(chunks, {'x': 10})
+    self.assertEqual(opened['foo'].encoding['shards'], (20,))
+
   def test_to_zarr_chunks_per_shard(self):
     temp_dir = self.create_tempdir().full_path
     ds = xarray.Dataset({'foo': ('x', np.arange(12))})
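For reference, here is a minimal end-to-end sketch of the behavior this patch
enables, mirroring test_to_zarr_shards_round_up above. The store path
'example.zarr' is illustrative; everything else follows the test.

# Sketch of the new shard round-up behavior (path 'example.zarr' is
# illustrative, not from the patch). A 19-element dimension chunked by 10
# can now request shards equal to the full dimension size; xarray-beam
# rounds the shard size up to 20, the next multiple of the chunk size.
import apache_beam as beam
import numpy as np
import xarray
import xarray_beam as xbeam

ds = xarray.Dataset({'foo': ('x', np.arange(19, dtype='int64'))})
beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 19})

with beam.Pipeline() as p:
  p |= beam_ds.to_zarr(
      'example.zarr',
      zarr_chunks={'x': 10},
      zarr_shards={'x': 19},  # equal to sizes['x']; rounded up to 20
  )

opened, chunks = xbeam.open_zarr('example.zarr')
print(chunks)                           # {'x': 10}
print(opened['foo'].encoding['shards'])  # (20,)

Per the parameterized test, passing zarr_shards=-1 or omitting 'x' from the
shards mapping resolves to the full dimension size and is rounded up the
same way, so this previously failing case now writes successfully.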