29 changes: 20 additions & 9 deletions docs/high-level.ipynb
@@ -176,9 +176,9 @@
},
"cell_type": "code",
"source": [
"with beam.Pipeline() as p:\n",
" p | (\n",
" xbeam.Dataset.from_zarr('example_data.zarr')\n",
"with beam.Pipeline() as pipeline:\n",
" (\n",
" xbeam.Dataset.from_zarr('example_data.zarr', pipeline=pipeline)\n",
" .rechunk({'time': -1, ...: '100 MB'})\n",
" .map_blocks(lambda ds: ds.groupby('time.month').mean())\n",
" .rechunk('10 MB') # ensure a reasonable min chunk-size for Zarr\n",
@@ -206,9 +206,9 @@
},
"cell_type": "code",
"source": [
"with beam.Pipeline() as p:\n",
" p | (\n",
" xbeam.Dataset.from_zarr('example_data.zarr')\n",
"with beam.Pipeline() as pipeline:\n",
" (\n",
" xbeam.Dataset.from_zarr('example_data.zarr', pipeline=pipeline)\n",
" .rechunk({'time': '30MB', 'latitude': -1, 'longitude': -1})\n",
" .map_blocks(lambda ds: ds.coarsen(latitude=2, longitude=2).mean())\n",
" .to_zarr('example_regrid.zarr')\n",
@@ -254,7 +254,7 @@
},
"cell_type": "markdown",
"source": [
"You can avoid these errors by explicitly supplying a template, either from {py:attr}`Dataset.template \u003cxarray_beam.Dataset.template\u003e` or produced by {py:func}`~xarray_beam.make_template`:"
"You can avoid these errors by explicitly supplying a template, either from {py:attr}`Dataset.template <xarray_beam.Dataset.template>` or produced by {py:func}`~xarray_beam.make_template`:"
]
},
{
@@ -270,7 +270,18 @@
")"
],
"outputs": [],
"execution_count": 6
"execution_count": 5
},
{
"metadata": {
"id": "7l8Cw8xTURea"
},
"cell_type": "markdown",
"source": [
"```{tip}\n",
"Notice that supplying `pipeline` to {py:func}`~xarray_beam.Dataset.from_zarr` is _optional_. You'll need to eventually apply a Beam pipeline to a `PTransform` produced by Xarray-Beam to compute it, but it can be convenient to omit when building pipelines interactively.\n",
"```"
]
},
{
"metadata": {
@@ -301,7 +312,7 @@
"all_times = pd.date_range('2025-01-01', freq='1D', periods=365)\n",
"source_dataset = xarray.open_zarr('example_data.zarr', chunks=None)\n",
"\n",
"def load_chunk(time: pd.Timestamp) -\u003e tuple[xbeam.Key, xarray.Dataset]:\n",
"def load_chunk(time: pd.Timestamp) -> tuple[xbeam.Key, xarray.Dataset]:\n",
" key = xbeam.Key({'time': (time - all_times[0]).days})\n",
" dataset = source_dataset.sel(time=[time])\n",
" return key, dataset\n",
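As a rough sketch of the two usage patterns behind these notebook changes (the input path follows the docs example; the output path `example_monthly.zarr` is a hypothetical name): pass `pipeline=` to `from_zarr` so the chain runs when the `with` block exits, or omit it and apply a Beam pipeline to the resulting transform later, as the tip above describes.

```python
# Sketch only; 'example_data.zarr' is the docs example store and
# 'example_monthly.zarr' is a made-up output path.
import apache_beam as beam
import xarray_beam as xbeam

# Pattern 1: supply the pipeline up front, as in the updated notebook cells.
with beam.Pipeline() as pipeline:
  (
      xbeam.Dataset.from_zarr('example_data.zarr', pipeline=pipeline)
      .rechunk({'time': -1, ...: '100 MB'})
      .map_blocks(lambda ds: ds.groupby('time.month').mean())
      .to_zarr('example_monthly.zarr')
  )

# Pattern 2: omit the pipeline while building the transform, then apply one.
transform = (
    xbeam.Dataset.from_zarr('example_data.zarr')
    .rechunk({'time': -1, ...: '100 MB'})
    .map_blocks(lambda ds: ds.groupby('time.month').mean())
    .to_zarr('example_monthly.zarr')
)
with beam.Pipeline() as pipeline:
  pipeline | transform
```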
2 changes: 1 addition & 1 deletion xarray_beam/__init__.py
@@ -55,4 +55,4 @@
DatasetToZarr as DatasetToZarr,
)

__version__ = '0.11.2' # automatically synchronized to pyproject.toml
__version__ = '0.11.3' # automatically synchronized to pyproject.toml
122 changes: 88 additions & 34 deletions xarray_beam/_src/dataset.py
@@ -17,14 +17,13 @@

import xarray_beam as xbeam

transform = (
xbeam.Dataset.from_zarr(input_path)
.rechunk({'time': -1, 'latitude': 10, 'longitude': 10})
.map_blocks(lambda x: x.median('time'))
.to_zarr(output_path)
)
with beam.Pipeline() as p:
p | transform
(
xbeam.Dataset.from_zarr(input_path, pipeline=p)
.rechunk({'time': -1, 'latitude': 10, 'longitude': 10})
.map_blocks(lambda x: x.median('time'))
.to_zarr(output_path)
)
"""
from __future__ import annotations

@@ -371,21 +370,21 @@ def method(self: Dataset, *args, **kwargs) -> Dataset:
chunks = {k: v for k, v in self.chunks.items() if k in template.dims}

label = _get_label(method_name)
if isinstance(self.ptransform, core.DatasetToChunks):

pipeline, ptransform = _split_lazy_pcollection(self._ptransform)
if isinstance(ptransform, core.DatasetToChunks):
# Some transformations (e.g., indexing) can be applied much less
# expensively to xarray.Dataset objects rather than via Xarray-Beam. Try
# to preserve this option for downstream transformations if possible.
dataset = func(self.ptransform.dataset)
dataset = func(ptransform.dataset)
ptransform = core.DatasetToChunks(dataset, chunks, self.split_vars)
ptransform.label = _concat_labels(self.ptransform.label, label)
ptransform.label = _concat_labels(ptransform.label, label)
if pipeline is not None:
ptransform = _LazyPCollection(pipeline, ptransform)
else:
ptransform = self.ptransform | label >> beam.MapTuple(
functools.partial(
_apply_to_each_chunk,
func,
method_name,
self.chunks,
chunks
_apply_to_each_chunk, func, method_name, self.chunks, chunks
)
)
return Dataset(template, chunks, self.split_vars, ptransform)
@@ -405,6 +404,43 @@ def apply(self, name: str) -> str:
_get_label = _CountNamer().apply


@dataclasses.dataclass(frozen=True)
class _LazyPCollection:
"""Pipeline and PTransform not yet been combined into a PCollection."""
# Beam does not provide a public API for manipulating Pipeline objects, so
# instead of applying pipelines eagerly, we store them in this wrapper. This
# allows for performance optimizations specialized to Xarray-Beam PTransforms
# (in particular for DatasetToChunks), even in the case where a pipeline is
# supplied.
pipeline: beam.Pipeline
ptransform: beam.PTransform

# Cache the evaluated PCollection, so we apply the transform to the pipeline
# at most once. Otherwise, reapplying the same transform results in reused
# labels in the same pipeline, which is an error in Beam.
@functools.cached_property
def evaluated(self) -> beam.PCollection:
return self.pipeline | self.ptransform


def _split_lazy_pcollection(
value: beam.PTransform | beam.PCollection | _LazyPCollection,
) -> tuple[beam.Pipeline | None, beam.PTransform | beam.PCollection]:
if isinstance(value, _LazyPCollection):
return value.pipeline, value.ptransform
else:
return None, value


def _as_eager_pcollection_or_ptransform(
value: beam.PTransform | beam.PCollection | _LazyPCollection,
) -> beam.PTransform | beam.PCollection:
if isinstance(value, _LazyPCollection):
return value.evaluated
else:
return value


@core.export
@dataclasses.dataclass
class Dataset:
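A minimal standalone sketch of the deferral idea behind `_LazyPCollection` above (not the library code itself): `functools.cached_property` writes straight to the instance `__dict__`, so the transform is applied to the pipeline at most once per wrapper, which avoids reapplying a transform under an already-used label.

```python
# Toy version of the lazy wrapper, using only public Beam and stdlib APIs.
import dataclasses
import functools

import apache_beam as beam


@dataclasses.dataclass(frozen=True)
class LazyApply:
  pipeline: beam.Pipeline
  ptransform: beam.PTransform

  @functools.cached_property
  def evaluated(self) -> beam.PCollection:
    # cached_property stores the result in __dict__, bypassing the frozen
    # dataclass __setattr__, so this body runs at most once per instance.
    return self.pipeline | self.ptransform


pipeline = beam.Pipeline()
lazy = LazyApply(pipeline, 'create' >> beam.Create([1, 2, 3]))
assert lazy.evaluated is lazy.evaluated  # applied to the pipeline only once
```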
@@ -415,7 +451,7 @@ def __init__(
template: xarray.Dataset,
chunks: Mapping[str, int],
split_vars: bool,
ptransform: beam.PTransform,
ptransform: beam.PTransform | beam.PCollection | _LazyPCollection,
):
"""Low level interface for creating a new Dataset, without validation.

@@ -429,7 +465,7 @@ def __init__(
use :py:func:`xarray_beam.normalize_chunks`.
split_vars: whether variables are split between separate elements in the
ptransform, or all stored in the same element.
ptransform: Beam PTransform of ``(xbeam.Key, xarray.Dataset)`` tuples with
ptransform: Beam collection of ``(xbeam.Key, xarray.Dataset)`` tuples with
this dataset's data.
"""
self._template = template
@@ -453,9 +489,9 @@ def split_vars(self) -> bool:
return self._split_vars

@property
def ptransform(self) -> beam.PTransform:
def ptransform(self) -> beam.PTransform | beam.PCollection:
"""Beam PTransform of (xbeam.Key, xarray.Dataset) with this dataset's data."""
return self._ptransform
return _as_eager_pcollection_or_ptransform(self._ptransform)

@property
def sizes(self) -> Mapping[str, int]:
@@ -507,7 +543,7 @@ def __repr__(self):
plural = 's' if chunk_count != 1 else ''
return (
'<xarray_beam.Dataset>\n'
f'PTransform: {self.ptransform}\n'
f'PTransform: {self._ptransform}\n'
f'Chunks: {chunk_size} ({chunks_str})\n'
f'Template: {total_size} ({chunk_count} chunk{plural})\n'
+ textwrap.indent('\n'.join(base.split('\n')[1:]), ' ' * 4)
@@ -516,7 +552,7 @@ def __repr__(self):
@classmethod
def from_ptransform(
cls,
ptransform: beam.PTransform,
ptransform: beam.PTransform | beam.PCollection,
*,
template: xarray.Dataset,
chunks: Mapping[str | types.EllipsisType, int],
@@ -535,9 +571,9 @@ def from_ptransform(
outputs are valid.

Args:
ptransform: A Beam PTransform that yields ``(Key, xarray.Dataset)`` pairs.
You only need to set ``offsets`` on these keys, ``vars`` will be
automatically set based on the dataset if ``split_vars`` is True.
ptransform: A Beam collection of ``(Key, xarray.Dataset)`` pairs. You only
need to set ``offsets`` on these keys, ``vars`` will be automatically
set based on the dataset if ``split_vars`` is True.
template: An ``xarray.Dataset`` object representing the schema
(coordinates, dimensions, data variables, and attributes) of the full
dataset, as produced by :py:func:`xarray_beam.make_template`, with data
@@ -577,6 +613,7 @@ def from_xarray(
*,
split_vars: bool = False,
previous_chunks: Mapping[str, int] | None = None,
pipeline: beam.Pipeline | None = None,
) -> Dataset:
"""Create an xarray_beam.Dataset from an xarray.Dataset.

@@ -588,13 +625,17 @@
ptransform, or all stored in the same element.
previous_chunks: chunks hint used for parsing string values in ``chunks``
with ``normalize_chunks()``.
pipeline: Beam pipeline to use for this dataset. If not provided, you will
need to apply a pipeline later to compute this dataset.
"""
template = zarr.make_template(source)
if previous_chunks is None:
previous_chunks = source.sizes
chunks = normalize_chunks(chunks, template, split_vars, previous_chunks)
ptransform = core.DatasetToChunks(source, chunks, split_vars)
ptransform.label = _get_label('from_xarray')
if pipeline is not None:
ptransform = _LazyPCollection(pipeline, ptransform)
return cls(template, dict(chunks), split_vars, ptransform)

@classmethod
@@ -604,6 +645,7 @@ def from_zarr(
*,
chunks: UnnormalizedChunks | None = None,
split_vars: bool = False,
pipeline: beam.Pipeline | None = None,
) -> Dataset:
"""Create an xarray_beam.Dataset from a Zarr store.

@@ -614,6 +656,8 @@ def from_zarr(
provided, the chunk sizes will be inferred from the Zarr file.
split_vars: whether variables are split between separate elements in the
ptransform, or all stored in the same element.
pipeline: Beam pipeline to use for this dataset. If not provided, you will
need to apply a pipeline later to compute this dataset.

Returns:
New Dataset created from the Zarr store.
@@ -622,9 +666,14 @@
if chunks is None:
chunks = previous_chunks
result = cls.from_xarray(
source, chunks, split_vars=split_vars, previous_chunks=previous_chunks
source,
chunks,
split_vars=split_vars,
previous_chunks=previous_chunks,
)
result.ptransform.label = _get_label('from_zarr')
if pipeline is not None:
result._ptransform = _LazyPCollection(pipeline, result.ptransform)
return result

def _check_shards_or_chunks(
Expand All @@ -650,7 +699,7 @@ def to_zarr(
zarr_shards: UnnormalizedChunks | None = None,
zarr_format: int | None = None,
stage_locally: bool | None = None,
) -> beam.PTransform:
) -> beam.PTransform | beam.PCollection:
"""Write this dataset to a Zarr file.

The extensive options for controlling chunking and sharding are intended for
Expand Down Expand Up @@ -688,7 +737,7 @@ def to_zarr(
path.

Returns:
Beam PTransform that writes the dataset to a Zarr file.
Beam transform that writes the dataset to a Zarr file.
"""
if zarr_shards is not None:
zarr_shards = normalize_chunks(
@@ -889,15 +938,18 @@ def rechunk(
)
label = _get_label('rechunk')

if isinstance(self.ptransform, core.DatasetToChunks) and all(
pipeline, ptransform = _split_lazy_pcollection(self._ptransform)
if isinstance(ptransform, core.DatasetToChunks) and all(
chunks[k] % self.chunks[k] == 0 for k in chunks
):
# Rechunking can be performed by re-reading the source dataset with new
# chunks, rather than using a separate rechunking transform.
ptransform = core.DatasetToChunks(
self.ptransform.dataset, chunks, split_vars
ptransform.dataset, chunks, split_vars
)
ptransform.label = _concat_labels(self.ptransform.label, label)
ptransform.label = _concat_labels(ptransform.label, label)
if pipeline is not None:
ptransform = _LazyPCollection(pipeline, ptransform)
return type(self)(self.template, chunks, split_vars, ptransform)

# Need to do a full rechunking.
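A rough illustration of the fast path above, assuming `example_data.zarr` is stored with `time` chunks of 10 (a made-up size): when every requested chunk size is an exact multiple of the current one and the dataset is still an untransformed `DatasetToChunks`, rechunking simply re-reads the source with the new chunks; otherwise the full rechunking transform is inserted.

```python
import xarray_beam as xbeam

ds = xbeam.Dataset.from_zarr('example_data.zarr')  # assume chunks={'time': 10}

fast = ds.rechunk({'time': 30})  # 30 % 10 == 0: re-read source with new chunks
slow = ds.rechunk({'time': 25})  # not an exact multiple: full rechunk transform
```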
@@ -982,23 +1034,25 @@ def mean(

def head(self, **indexers_kwargs: int) -> Dataset:
"""Return a Dataset with the first N elements of each dimension."""
if not isinstance(self.ptransform, core.DatasetToChunks):
_, ptransform = _split_lazy_pcollection(self._ptransform)
if not isinstance(ptransform, core.DatasetToChunks):
raise ValueError(
'head() is only supported on untransformed datasets, with '
'ptransform=DatasetToChunks. This dataset has '
f'ptransform={self.ptransform}'
f'ptransform={ptransform}'
)
return self._head(**indexers_kwargs)

_tail = _whole_dataset_method('tail')

def tail(self, **indexers_kwargs: int) -> Dataset:
"""Return a Dataset with the last N elements of each dimension."""
if not isinstance(self.ptransform, core.DatasetToChunks):
_, ptransform = _split_lazy_pcollection(self._ptransform)
if not isinstance(ptransform, core.DatasetToChunks):
raise ValueError(
'tail() is only supported on untransformed datasets, with '
'ptransform=DatasetToChunks. This dataset has '
f'ptransform={self.ptransform}'
f'ptransform={ptransform}'
)
return self._tail(**indexers_kwargs)

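Finally, a short sketch of the `head()`/`tail()` restriction enforced by the error messages above: both only work while the dataset is still backed by `DatasetToChunks`, i.e. before any chunk-wise transformation has been applied (the path reuses the docs example; the preview size is arbitrary).

```python
import xarray_beam as xbeam

ds = xbeam.Dataset.from_zarr('example_data.zarr')
preview = ds.head(time=10)  # fine: ptransform is still DatasetToChunks

mapped = ds.map_blocks(lambda chunk: chunk + 1)
# mapped.head(time=10)  # raises ValueError: head() is only supported on
#                       # untransformed datasets ...
```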