Merged
xarray_beam/_src/dataset.py: 32 changes (22 additions & 10 deletions)
@@ -392,7 +392,6 @@ def method(self: Dataset, *args, **kwargs) -> Dataset:
   return method
 
 
-
 class _CountNamer:
 
   def __init__(self):
@@ -408,6 +407,7 @@ def apply(self, name: str) -> str:
 @dataclasses.dataclass(frozen=True)
 class _LazyPCollection:
   """A Pipeline and PTransform not yet combined into a PCollection."""
+
   # Beam does not provide a public API for manipulating Pipeline objects, so
   # instead of applying pipelines eagerly, we store them in this wrapper. This
   # allows for performance optimizations specialized to Xarray-Beam PTransforms,
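
The deferred-application pattern this comment describes can be sketched as follows. This is an illustrative guess at the shape of the wrapper, grounded only in the `_LazyPCollection(pipeline, ptransform)` construction visible in rechunk() below; the field names and the evaluate() method are not the module's actual API:

import dataclasses

import apache_beam as beam


@dataclasses.dataclass(frozen=True)
class LazyPCollection:  # hypothetical stand-in for _LazyPCollection
  """A Pipeline and a root PTransform, held apart until evaluation."""

  pipeline: beam.Pipeline
  ptransform: beam.PTransform

  def evaluate(self):
    # Deferring this application lets callers rewrite ptransform first,
    # e.g. rechunk() below swaps in a new DatasetToChunks transform before
    # anything is applied to the pipeline.
    return self.pipeline | self.ptransform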
@@ -715,12 +715,16 @@ def _check_shards_or_chunks(
       zarr_chunks: Mapping[str, int],
       chunks_name: Literal['shards', 'chunks'],
   ) -> None:
-    if any(self.chunks[k] % zarr_chunks[k] for k in self.chunks):
-      raise ValueError(
-          f'cannot write a dataset with chunks {self.chunks} to Zarr with '
-          f'{chunks_name} {zarr_chunks}, which do not divide evenly into '
-          f'{chunks_name}'
-      )
+    for k in self.chunks:
+      if (
+          self.chunks[k] % zarr_chunks[k]
+          and self.chunks[k] != self.template.sizes[k]
+      ):
+        raise ValueError(
+            f'cannot write a dataset with chunks {self.chunks} to Zarr with '
+            f'{chunks_name} {zarr_chunks}, which do not divide evenly into '
+            f'{chunks_name}'
+        )
 
   def to_zarr(
       self,
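
In plain terms, the relaxed check now tolerates a chunk that fails the divisibility test as long as it spans its entire dimension. A minimal standalone sketch with made-up sizes (template_sizes stands in for self.template.sizes):

# Hypothetical sizes: one in-memory chunk of 19 covering all of 'x',
# written with on-disk Zarr chunks of 10.
chunks = {'x': 19}
zarr_chunks = {'x': 10}
template_sizes = {'x': 19}  # stand-in for self.template.sizes

for k in chunks:
  divides_evenly = chunks[k] % zarr_chunks[k] == 0
  spans_dimension = chunks[k] == template_sizes[k]
  # Previously only divides_evenly was accepted, so 19 % 10 != 0 raised.
  # Now a chunk spanning the full dimension also passes.
  assert divides_evenly or spans_dimension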
@@ -804,6 +808,16 @@ def to_zarr(
         previous_chunks=self.chunks,
     )
     if zarr_shards is not None:
+      # Zarr shards are currently constrained to be an integer multiple of
+      # the chunk size, so a shard covering a full dimension must sometimes
+      # be rounded up beyond the dimension size. This will likely be relaxed:
+      # https://github.com/zarr-developers/zarr-extensions/issues/34
+      zarr_shards = dict(zarr_shards)
+      for k in zarr_shards:
+        if zarr_shards[k] == self.sizes[k]:
+          zarr_shards[k] = (
+              math.ceil(zarr_shards[k] / zarr_chunks[k]) * zarr_chunks[k]
+          )
       self._check_shards_or_chunks(zarr_shards, 'shards')
     else:
       self._check_shards_or_chunks(zarr_chunks, 'chunks')
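
The rounding is just a ceiling to the next multiple of the chunk size. With the sizes used in the new test below, a full-dimension shard of 19 over chunks of 10 becomes 20:

import math

size, chunk = 19, 10  # dimension size and Zarr chunk size
shard = math.ceil(size / chunk) * chunk  # ceil(19 / 10) * 10 == 20
assert shard == 20  # matches encoding['shards'] == (20,) in the test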
@@ -956,9 +970,7 @@ def rechunk(
     ):
       # Rechunking can be performed by re-reading the source dataset with new
       # chunks, rather than using a separate rechunking transform.
-      ptransform = core.DatasetToChunks(
-          ptransform.dataset, chunks, split_vars
-      )
+      ptransform = core.DatasetToChunks(ptransform.dataset, chunks, split_vars)
     ptransform.label = _concat_labels(ptransform.label, label)
     if pipeline is not None:
       ptransform = _LazyPCollection(pipeline, ptransform)
xarray_beam/_src/dataset_test.py: 22 changes (22 additions & 0 deletions)
@@ -715,6 +715,28 @@ def test_to_zarr_shards(self):
           zarr_shards={'x': 9},
       )
 
+  @parameterized.named_parameters(
+      dict(testcase_name='empty', zarr_shards={}),
+      dict(testcase_name='minus_one', zarr_shards=-1),
+      dict(testcase_name='explicit_19', zarr_shards={'x': 19}),
+      dict(testcase_name='explicit_20', zarr_shards={'x': 20}),
+  )
+  def test_to_zarr_shards_round_up(self, zarr_shards):
+    temp_dir = self.create_tempdir().full_path
+    ds = xarray.Dataset({'foo': ('x', np.arange(19, dtype='int64'))})
+    beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 19})
+
+    with beam.Pipeline() as p:
+      p |= beam_ds.to_zarr(
+          temp_dir,
+          zarr_chunks={'x': 10},
+          zarr_shards=zarr_shards,
+      )
+    opened, chunks = xbeam.open_zarr(temp_dir)
+    xarray.testing.assert_identical(ds, opened)
+    self.assertEqual(chunks, {'x': 10})
+    self.assertEqual(opened['foo'].encoding['shards'], (20,))
+
   def test_to_zarr_chunks_per_shard(self):
     temp_dir = self.create_tempdir().full_path
     ds = xarray.Dataset({'foo': ('x', np.arange(12))})