From afd3f80f8e379e1f437bc665d7cb479c96ac383b Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Fri, 24 Oct 2025 14:53:32 -0700 Subject: [PATCH] Add more validation to xbeam.Dataset.map_blocks PiperOrigin-RevId: 823674644 --- xarray_beam/__init__.py | 2 +- xarray_beam/_src/dataset.py | 16 ++++++++++++++++ xarray_beam/_src/dataset_test.py | 28 ++++++++++++++++++++++++---- 3 files changed, 41 insertions(+), 5 deletions(-) diff --git a/xarray_beam/__init__.py b/xarray_beam/__init__.py index 39d9870..61aacc8 100644 --- a/xarray_beam/__init__.py +++ b/xarray_beam/__init__.py @@ -55,4 +55,4 @@ DatasetToZarr as DatasetToZarr, ) -__version__ = '0.11.1' # automatically synchronized to pyproject.toml +__version__ = '0.11.2' # automatically synchronized to pyproject.toml diff --git a/xarray_beam/_src/dataset.py b/xarray_beam/_src/dataset.py index af947e4..0a0d1b9 100644 --- a/xarray_beam/_src/dataset.py +++ b/xarray_beam/_src/dataset.py @@ -828,6 +828,22 @@ def map_blocks( new_sizes=template.sizes, ) # pytype: disable=wrong-arg-types + for dim, old_chunks in self.chunks.items(): + if old_chunks < self.sizes[dim]: + if dim not in template.dims: + raise ValueError( + f'dimension {dim!r} has multiple chunks on the source dataset, ' + 'and therefore must be included in the result of map_blocks, but ' + f'is not in the new template: {template}' + ) + old_chunk_count = math.ceil(self.sizes[dim] / old_chunks) + new_chunk_count = math.ceil(template.sizes[dim] / chunks[dim]) + if old_chunk_count != new_chunk_count: + raise ValueError( + f'dimension {dim!r} has {old_chunk_count} chunks on the source ' + f'dataset and {new_chunk_count} in the result of map_blocks' + ) + label = _get_label('map_blocks') func_name = getattr(func, '__name__', None) name = f'map-blocks-{func_name}' if func_name else 'map-blocks' diff --git a/xarray_beam/_src/dataset_test.py b/xarray_beam/_src/dataset_test.py index 80bfe85..0a1aefb 100644 --- a/xarray_beam/_src/dataset_test.py +++ b/xarray_beam/_src/dataset_test.py @@ -736,9 +736,7 @@ def test_to_zarr_chunks_per_shard(self): ds2 = xarray.Dataset({'foo': (('x', 'y'), np.zeros((12, 10)))}) beam_ds2 = xbeam.Dataset.from_xarray(ds2, {'x': 6, 'y': 5}) with beam.Pipeline() as p: - p |= beam_ds2.to_zarr( - temp_dir, zarr_chunks_per_shard={'x': 3, ...: 1} - ) + p |= beam_ds2.to_zarr(temp_dir, zarr_chunks_per_shard={'x': 3, ...: 1}) opened, chunks = xbeam.open_zarr(temp_dir) xarray.testing.assert_identical(ds2, opened) self.assertEqual(chunks, {'x': 2, 'y': 5}) @@ -786,7 +784,8 @@ def test_to_zarr_chunks_per_shard(self): beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 6}) with self.assertRaisesRegex( ValueError, - r'cannot write a dataset with chunks .*zarr_chunks_per_shard=.* which do not evenly divide', + r'cannot write a dataset with chunks .*zarr_chunks_per_shard=.* which' + r' do not evenly divide', ): beam_ds.to_zarr(temp_dir, zarr_chunks_per_shard={'x': 5}) @@ -1002,6 +1001,27 @@ def test_map_blocks_new_split_vars_fails(self): ): source_ds.map_blocks(func) + def test_map_blocks_non_unique(self): + source = xarray.Dataset({'foo': ('x', np.arange(8))}) + source_ds = xbeam.Dataset.from_xarray(source, {'x': 4}) + with self.assertRaisesRegex( + ValueError, + "dimension 'x' has multiple chunks on the source dataset, and " + 'therefore must be included in the result of map_blocks, but is not ' + 'in the new template:', + ): + source_ds.map_blocks(lambda ds: ds.mean('x')) + + def test_map_blocks_inconsistent_chunks_error(self): + source = xarray.Dataset({'foo': ('x', np.arange(8))}) + source_ds = xbeam.Dataset.from_xarray(source, {'x': 4}) + with self.assertRaisesWithLiteralMatch( + ValueError, + "dimension 'x' has 2 chunks on the source dataset and 8 in the result " + 'of map_blocks', + ): + source_ds.map_blocks(lambda ds: ds, chunks={'x': 1}) + class RechunkingTest(test_util.TestCase):