Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 18 additions & 19 deletions xarray_beam/_src/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,26 @@
# limitations under the License.
"""IO with Zarr via Xarray."""
from __future__ import annotations

import collections
import concurrent.futures
import dataclasses
import logging
from typing import (
Any,
AbstractSet,
Any,
Dict,
Optional,
Mapping,
MutableMapping,
Optional,
Set,
Tuple,
Union,
MutableMapping,
)

import apache_beam as beam
import dask
import dask.array
import xarray

from xarray_beam._src import core
from xarray_beam._src import rechunk
from xarray_beam._src import threadmap
Expand Down Expand Up @@ -284,9 +283,6 @@ def validate_zarr_chunk(
key: the Key corresponding to the position of the chunk to write in the
template.
chunk: the chunk to write.
by `xarray_beam.make_template`). One or more variables are expected to be
"chunked" with Dask, and will only have their metadata written to Zarr
without array values.
template: a lazy xarray.Dataset already chunked using Dask (e.g., as created
by `xarray_beam.make_template`). One or more variables are expected to be
"chunked" with Dask, and will only have their metadata written to Zarr
Expand Down Expand Up @@ -383,8 +379,9 @@ def __init__(
store: WritableStore,
template: Union[xarray.Dataset, beam.pvalue.AsSingleton, None] = None,
zarr_chunks: Optional[Mapping[str, int]] = None,
*,
num_threads: Optional[int] = None,
setup_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None,
needs_setup: bool = True,
):
# pyformat: disable
"""Initialize ChunksToZarr.
Expand Down Expand Up @@ -415,24 +412,26 @@ def __init__(
and makes it harder for Beam runners to shard work. Note that each
variable in a Dataset is already written in parallel, so this is most
useful for Datasets with a small number of variables.
setup_executor: an optional thread pool executor to use for setting up
the Zarr store when creating ChunksToZarr() objects in a non-blocking
fashion. Only used if template is provided as an xarray.Dataset. If not
provided, setup is performed eagerly.
needs_setup: if False, then the Zarr store is already setup and does not
need to be set up as part of this PTransform.
"""
# pyformat: enable
if isinstance(template, xarray.Dataset):
if setup_executor is not None:
setup_executor.submit(setup_zarr, template, store, zarr_chunks)
else:
if needs_setup:
setup_zarr(template, store, zarr_chunks)
if zarr_chunks is None:
zarr_chunks = _infer_zarr_chunks(template)
template = _make_template_from_chunked(template)
elif isinstance(template, beam.pvalue.AsSingleton):
pass
if not needs_setup:
raise ValueError(
'setup required if template is a beam.pvalue.AsSingleton object'
)
# Setup happens later, in expand().
elif template is None:
pass
if not needs_setup:
raise ValueError('setup required if template is not supplied')
# Setup happens later, in expand().
else:
raise TypeError(
'template must be None, an xarray.Dataset, or a '
Expand All @@ -458,7 +457,7 @@ def _write_chunk_to_zarr(self, key, chunk, template=None):
def expand(self, pcoll):
if isinstance(self.template, xarray.Dataset):
template = self.template
setup_result = None # already setup in __init__
setup_result = None # already setup
else:
if isinstance(self.template, beam.pvalue.AsSingleton):
template = self.template
Expand Down
9 changes: 3 additions & 6 deletions xarray_beam/_src/zarr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for xarray_beam._src.core."""
import re

from absl.testing import absltest
from absl.testing import parameterized
from concurrent import futures
import dask.array as da
import numpy as np
import xarray
Expand Down Expand Up @@ -124,11 +122,10 @@ def test_chunks_to_zarr(self):
inputs | xbeam.ChunksToZarr(temp_dir, chunked)
result = xarray.open_zarr(temp_dir, consolidated=True)
xarray.testing.assert_identical(dataset, result)
with self.subTest('with template and setup_executor'):
with self.subTest('with template and needs_setup=False'):
temp_dir = self.create_tempdir().full_path
with futures.ThreadPoolExecutor() as executor:
to_zarr = xbeam.ChunksToZarr(temp_dir, chunked, setup_executor=executor)
inputs | to_zarr
xbeam.setup_zarr(chunked, temp_dir)
inputs | xbeam.ChunksToZarr(temp_dir, chunked, needs_setup=False)
result = xarray.open_zarr(temp_dir, consolidated=True)
xarray.testing.assert_identical(dataset, result)
with self.subTest('with zarr_chunks and with template'):
Expand Down
Loading