diff --git a/docs/high-level.ipynb b/docs/high-level.ipynb
index f5be474..965eab2 100644
--- a/docs/high-level.ipynb
+++ b/docs/high-level.ipynb
@@ -176,9 +176,9 @@
   },
   "cell_type": "code",
   "source": [
-    "with beam.Pipeline() as p:\n",
-    "  p | (\n",
-    "      xbeam.Dataset.from_zarr('example_data.zarr')\n",
+    "with beam.Pipeline() as pipeline:\n",
+    "  (\n",
+    "      xbeam.Dataset.from_zarr('example_data.zarr', pipeline=pipeline)\n",
     "      .rechunk({'time': -1, ...: '100 MB'})\n",
     "      .map_blocks(lambda ds: ds.groupby('time.month').mean())\n",
     "      .rechunk('10 MB')  # ensure a reasonable min chunk-size for Zarr\n",
@@ -206,9 +206,9 @@
   },
   "cell_type": "code",
   "source": [
-    "with beam.Pipeline() as p:\n",
-    "  p | (\n",
-    "      xbeam.Dataset.from_zarr('example_data.zarr')\n",
+    "with beam.Pipeline() as pipeline:\n",
+    "  (\n",
+    "      xbeam.Dataset.from_zarr('example_data.zarr', pipeline=pipeline)\n",
     "      .rechunk({'time': '30MB', 'latitude': -1, 'longitude': -1})\n",
     "      .map_blocks(lambda ds: ds.coarsen(latitude=2, longitude=2).mean())\n",
     "      .to_zarr('example_regrid.zarr')\n",
@@ -254,7 +254,7 @@
   },
   "cell_type": "markdown",
   "source": [
-    "You can avoid these errors by explicitly supplying a template, either from {py:attr}`Dataset.template \u003cxarray_beam.Dataset.template\u003e` or produced by {py:func}`~xarray_beam.make_template`:"
+    "You can avoid these errors by explicitly supplying a template, either from {py:attr}`Dataset.template <xarray_beam.Dataset.template>` or produced by {py:func}`~xarray_beam.make_template`:"
   ]
  },
  {
@@ -270,7 +270,18 @@
     ")"
   ],
   "outputs": [],
-  "execution_count": 6
+  "execution_count": 5
+ },
+ {
+  "metadata": {
+   "id": "7l8Cw8xTURea"
+  },
+  "cell_type": "markdown",
+  "source": [
+    "```{tip}\n",
+    "Notice that supplying `pipeline` to {py:func}`~xarray_beam.Dataset.from_zarr` is _optional_. You'll eventually need to apply a Beam pipeline to a `PTransform` produced by Xarray-Beam to compute it, but omitting `pipeline` can be convenient when building pipelines interactively.\n",
+    "```"
+  ]
 },
 {
  "metadata": {
@@ -301,7 +312,7 @@
     "all_times = pd.date_range('2025-01-01', freq='1D', periods=365)\n",
     "source_dataset = xarray.open_zarr('example_data.zarr', chunks=None)\n",
     "\n",
-    "def load_chunk(time: pd.Timestamp) -\u003e tuple[xbeam.Key, xarray.Dataset]:\n",
+    "def load_chunk(time: pd.Timestamp) -> tuple[xbeam.Key, xarray.Dataset]:\n",
     "  key = xbeam.Key({'time': (time - all_times[0]).days})\n",
     "  dataset = source_dataset.sel(time=[time])\n",
     "  return key, dataset\n",
diff --git a/xarray_beam/__init__.py b/xarray_beam/__init__.py
index 61aacc8..f2bad3f 100644
--- a/xarray_beam/__init__.py
+++ b/xarray_beam/__init__.py
@@ -55,4 +55,4 @@
   DatasetToZarr as DatasetToZarr,
 )
 
-__version__ = '0.11.2'  # automatically synchronized to pyproject.toml
+__version__ = '0.11.3'  # automatically synchronized to pyproject.toml
diff --git a/xarray_beam/_src/dataset.py b/xarray_beam/_src/dataset.py
index 0a0d1b9..6dedcea 100644
--- a/xarray_beam/_src/dataset.py
+++ b/xarray_beam/_src/dataset.py
@@ -17,14 +17,13 @@
 
   import xarray_beam as xbeam
 
-  transform = (
-      xbeam.Dataset.from_zarr(input_path)
-      .rechunk({'time': -1, 'latitude': 10, 'longitude': 10})
-      .map_blocks(lambda x: x.median('time'))
-      .to_zarr(output_path)
-  )
   with beam.Pipeline() as p:
-    p | transform
+    (
+        xbeam.Dataset.from_zarr(input_path, pipeline=p)
+        .rechunk({'time': -1, 'latitude': 10, 'longitude': 10})
+        .map_blocks(lambda x: x.median('time'))
+        .to_zarr(output_path)
+    )
 """
 
 from __future__ import annotations
@@ -371,21 +370,21 @@ def method(self: Dataset, *args, **kwargs) -> Dataset:
     chunks = {k: v for k, v in self.chunks.items() if k in template.dims}
     label = _get_label(method_name)
-    if isinstance(self.ptransform, core.DatasetToChunks):
+
+    pipeline, ptransform = _split_lazy_pcollection(self._ptransform)
+    if isinstance(ptransform, core.DatasetToChunks):
       # Some transformations (e.g., indexing) can be applied much less
       # expensively to xarray.Dataset objects rather than via Xarray-Beam. Try
       # to preserve this option for downstream transformations if possible.
-      dataset = func(self.ptransform.dataset)
+      dataset = func(ptransform.dataset)
       ptransform = core.DatasetToChunks(dataset, chunks, self.split_vars)
-      ptransform.label = _concat_labels(self.ptransform.label, label)
+      ptransform.label = _concat_labels(ptransform.label, label)
+      if pipeline is not None:
+        ptransform = _LazyPCollection(pipeline, ptransform)
     else:
       ptransform = self.ptransform | label >> beam.MapTuple(
           functools.partial(
-              _apply_to_each_chunk,
-              func,
-              method_name,
-              self.chunks,
-              chunks
+              _apply_to_each_chunk, func, method_name, self.chunks, chunks
           )
       )
     return Dataset(template, chunks, self.split_vars, ptransform)
@@ -405,6 +404,43 @@ def apply(self, name: str) -> str:
 _get_label = _CountNamer().apply
 
 
+@dataclasses.dataclass(frozen=True)
+class _LazyPCollection:
+  """Pipeline and PTransform not yet combined into a PCollection."""
+  # Beam does not provide a public API for manipulating Pipeline objects, so
+  # instead of applying pipelines eagerly, we store them in this wrapper. This
+  # allows for performance optimizations specialized to Xarray-Beam PTransforms
+  # (in particular, DatasetToChunks), even in the case where a pipeline is
+  # supplied.
+ pipeline: beam.Pipeline + ptransform: beam.PTransform + + # Cache the evaluated PCollection, so we apply the transform to the pipeline + # at most once. Otherwise, reapplying the same transform results in reused + # labels in the same pipeline, which is an error in Beam. + @functools.cached_property + def evaluated(self) -> beam.PCollection: + return self.pipeline | self.ptransform + + +def _split_lazy_pcollection( + value: beam.PTransform | beam.PCollection | _LazyPCollection, +) -> tuple[beam.Pipeline | None, beam.PTransform | beam.PCollection]: + if isinstance(value, _LazyPCollection): + return value.pipeline, value.ptransform + else: + return None, value + + +def _as_eager_pcollection_or_ptransform( + value: beam.PTransform | beam.PCollection | _LazyPCollection, +) -> beam.PTransform | beam.PCollection: + if isinstance(value, _LazyPCollection): + return value.evaluated + else: + return value + + @core.export @dataclasses.dataclass class Dataset: @@ -415,7 +451,7 @@ def __init__( template: xarray.Dataset, chunks: Mapping[str, int], split_vars: bool, - ptransform: beam.PTransform, + ptransform: beam.PTransform | beam.PCollection | _LazyPCollection, ): """Low level interface for creating a new Dataset, without validation. @@ -429,7 +465,7 @@ def __init__( use :py:func:`xarray_beam.normalize_chunks`. split_vars: whether variables are split between separate elements in the ptransform, or all stored in the same element. - ptransform: Beam PTransform of ``(xbeam.Key, xarray.Dataset)`` tuples with + ptransform: Beam collection of ``(xbeam.Key, xarray.Dataset)`` tuples with this dataset's data. """ self._template = template @@ -453,9 +489,9 @@ def split_vars(self) -> bool: return self._split_vars @property - def ptransform(self) -> beam.PTransform: + def ptransform(self) -> beam.PTransform | beam.PCollection: """Beam PTransform of (xbeam.Key, xarray.Dataset) with this dataset's data.""" - return self._ptransform + return _as_eager_pcollection_or_ptransform(self._ptransform) @property def sizes(self) -> Mapping[str, int]: @@ -507,7 +543,7 @@ def __repr__(self): plural = 's' if chunk_count != 1 else '' return ( '\n' - f'PTransform: {self.ptransform}\n' + f'PTransform: {self._ptransform}\n' f'Chunks: {chunk_size} ({chunks_str})\n' f'Template: {total_size} ({chunk_count} chunk{plural})\n' + textwrap.indent('\n'.join(base.split('\n')[1:]), ' ' * 4) @@ -516,7 +552,7 @@ def __repr__(self): @classmethod def from_ptransform( cls, - ptransform: beam.PTransform, + ptransform: beam.PTransform | beam.PCollection, *, template: xarray.Dataset, chunks: Mapping[str | types.EllipsisType, int], @@ -535,9 +571,9 @@ def from_ptransform( outputs are valid. Args: - ptransform: A Beam PTransform that yields ``(Key, xarray.Dataset)`` pairs. - You only need to set ``offsets`` on these keys, ``vars`` will be - automatically set based on the dataset if ``split_vars`` is True. + ptransform: A Beam collection of ``(Key, xarray.Dataset)`` pairs. You only + need to set ``offsets`` on these keys, ``vars`` will be automatically + set based on the dataset if ``split_vars`` is True. template: An ``xarray.Dataset`` object representing the schema (coordinates, dimensions, data variables, and attributes) of the full dataset, as produced by :py:func:`xarray_beam.make_template`, with data @@ -577,6 +613,7 @@ def from_xarray( *, split_vars: bool = False, previous_chunks: Mapping[str, int] | None = None, + pipeline: beam.Pipeline | None = None, ) -> Dataset: """Create an xarray_beam.Dataset from an xarray.Dataset. 
@@ -588,6 +625,8 @@ def from_xarray( ptransform, or all stored in the same element. previous_chunks: chunks hint used for parsing string values in ``chunks`` with ``normalize_chunks()``. + pipeline: Beam pipeline to use for this dataset. If not provided, you will + need apply a pipeline later to compute this dataset. """ template = zarr.make_template(source) if previous_chunks is None: @@ -595,6 +634,8 @@ def from_xarray( chunks = normalize_chunks(chunks, template, split_vars, previous_chunks) ptransform = core.DatasetToChunks(source, chunks, split_vars) ptransform.label = _get_label('from_xarray') + if pipeline is not None: + ptransform = _LazyPCollection(pipeline, ptransform) return cls(template, dict(chunks), split_vars, ptransform) @classmethod @@ -604,6 +645,7 @@ def from_zarr( *, chunks: UnnormalizedChunks | None = None, split_vars: bool = False, + pipeline: beam.Pipeline | None = None, ) -> Dataset: """Create an xarray_beam.Dataset from a Zarr store. @@ -614,6 +656,8 @@ def from_zarr( provided, the chunk sizes will be inferred from the Zarr file. split_vars: whether variables are split between separate elements in the ptransform, or all stored in the same element. + pipeline: Beam pipeline to use for this dataset. If not provided, you will + need apply a pipeline later to compute this dataset. Returns: New Dataset created from the Zarr store. @@ -622,9 +666,14 @@ def from_zarr( if chunks is None: chunks = previous_chunks result = cls.from_xarray( - source, chunks, split_vars=split_vars, previous_chunks=previous_chunks + source, + chunks, + split_vars=split_vars, + previous_chunks=previous_chunks, ) result.ptransform.label = _get_label('from_zarr') + if pipeline is not None: + result._ptransform = _LazyPCollection(pipeline, result.ptransform) return result def _check_shards_or_chunks( @@ -650,7 +699,7 @@ def to_zarr( zarr_shards: UnnormalizedChunks | None = None, zarr_format: int | None = None, stage_locally: bool | None = None, - ) -> beam.PTransform: + ) -> beam.PTransform | beam.PCollection: """Write this dataset to a Zarr file. The extensive options for controlling chunking and sharding are intended for @@ -688,7 +737,7 @@ def to_zarr( path. Returns: - Beam PTransform that writes the dataset to a Zarr file. + Beam transform that writes the dataset to a Zarr file. """ if zarr_shards is not None: zarr_shards = normalize_chunks( @@ -889,15 +938,18 @@ def rechunk( ) label = _get_label('rechunk') - if isinstance(self.ptransform, core.DatasetToChunks) and all( + pipeline, ptransform = _split_lazy_pcollection(self._ptransform) + if isinstance(ptransform, core.DatasetToChunks) and all( chunks[k] % self.chunks[k] == 0 for k in chunks ): # Rechunking can be performed by re-reading the source dataset with new # chunks, rather than using a separate rechunking transform. ptransform = core.DatasetToChunks( - self.ptransform.dataset, chunks, split_vars + ptransform.dataset, chunks, split_vars ) - ptransform.label = _concat_labels(self.ptransform.label, label) + ptransform.label = _concat_labels(ptransform.label, label) + if pipeline is not None: + ptransform = _LazyPCollection(pipeline, ptransform) return type(self)(self.template, chunks, split_vars, ptransform) # Need to do a full rechunking. 
@@ -982,11 +1034,12 @@ def mean( def head(self, **indexers_kwargs: int) -> Dataset: """Return a Dataset with the first N elements of each dimension.""" - if not isinstance(self.ptransform, core.DatasetToChunks): + _, ptransform = _split_lazy_pcollection(self._ptransform) + if not isinstance(ptransform, core.DatasetToChunks): raise ValueError( 'head() is only supported on untransformed datasets, with ' 'ptransform=DatasetToChunks. This dataset has ' - f'ptransform={self.ptransform}' + f'ptransform={ptransform}' ) return self._head(**indexers_kwargs) @@ -994,11 +1047,12 @@ def head(self, **indexers_kwargs: int) -> Dataset: def tail(self, **indexers_kwargs: int) -> Dataset: """Return a Dataset with the last N elements of each dimension.""" - if not isinstance(self.ptransform, core.DatasetToChunks): + _, ptransform = _split_lazy_pcollection(self._ptransform) + if not isinstance(ptransform, core.DatasetToChunks): raise ValueError( 'tail() is only supported on untransformed datasets, with ' 'ptransform=DatasetToChunks. This dataset has ' - f'ptransform={self.ptransform}' + f'ptransform={ptransform}' ) return self._tail(**indexers_kwargs) diff --git a/xarray_beam/_src/dataset_test.py b/xarray_beam/_src/dataset_test.py index 0a1aefb..a2b813a 100644 --- a/xarray_beam/_src/dataset_test.py +++ b/xarray_beam/_src/dataset_test.py @@ -602,6 +602,24 @@ def test_from_zarr(self, split_vars): collected = beam_ds.collect_with_direct_runner() xarray.testing.assert_identical(ds, collected) + def test_from_zarr_with_pipeline(self): + temp_dir = self.create_tempdir().full_path + output_path = self.create_tempdir().full_path + ds = xarray.Dataset({'foo': ('x', np.arange(10))}) + ds.chunk({'x': 5}).to_zarr(temp_dir) + + with beam.Pipeline() as pipeline: + beam_ds = xbeam.Dataset.from_zarr(temp_dir, pipeline=pipeline) + self.assertIsInstance(beam_ds._ptransform, xbeam_dataset._LazyPCollection) + self.assertRegex( + repr(beam_ds), + r'PTransform: _LazyPCollection\(pipeline=.+, ptransform=.+\)', + ) + beam_ds.to_zarr(output_path) + + opened = xarray.open_zarr(output_path) + xarray.testing.assert_identical(ds, opened) + def test_from_zarr_with_chunks(self): temp_dir = self.create_tempdir().full_path ds = xarray.Dataset({'foo': (('x', 'y'), np.zeros((100, 100)))}) @@ -802,6 +820,16 @@ def test_to_zarr_default_chunks(self): xarray.testing.assert_identical(ds, opened) self.assertEqual(chunks, {'x': 2, 'y': 2}) + def test_from_xarray_to_zarr_with_pipeline(self): + path = self.create_tempdir().full_path + ds = xarray.Dataset({'foo': ('x', np.arange(10))}) + with beam.Pipeline() as pipeline: + beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 5}, pipeline=pipeline) + self.assertIsInstance(beam_ds.ptransform, beam.PCollection) + beam_ds.to_zarr(path) + actual = xarray.open_zarr(path) + xarray.testing.assert_identical(ds, actual) + @parameterized.named_parameters( dict(testcase_name='getitem', call=lambda x: x[['foo']]), dict(testcase_name='transpose', call=lambda x: x.transpose()), @@ -848,6 +876,18 @@ def test_head(self): ): beam_ds.map_blocks(lambda x: x).head(x=2) + def test_head_with_pipeline(self): + ds = xarray.Dataset({'foo': ('x', np.arange(10))}) + path = self.create_tempdir().full_path + with beam.Pipeline() as pipeline: + beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 5}, pipeline=pipeline) + head_ds = beam_ds.head(x=2) + self.assertIsInstance(head_ds._ptransform, xbeam_dataset._LazyPCollection) + head_ds.to_zarr(path) + actual = xarray.open_zarr(path) + expected = ds.head(x=2) + 
xarray.testing.assert_identical(expected, actual) + def test_tail(self): ds = xarray.Dataset({'foo': ('x', np.arange(10))}) beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 5}) @@ -1078,6 +1118,22 @@ def test_rechunk_from_zarr_without_ptransform(self): actual = rechunked_ds.collect_with_direct_runner() xarray.testing.assert_identical(actual, source) + def test_rechunk_from_xarray_with_pipeline(self): + source = xarray.Dataset({'foo': (('x', 'y'), np.zeros((100, 100)))}) + path = self.create_tempdir().full_path + with beam.Pipeline() as pipeline: + beam_ds = xbeam.Dataset.from_xarray( + source, {'x': 10, 'y': 10}, pipeline=pipeline + ) + rechunked_ds = beam_ds.rechunk({'x': 20, 'y': 20}) + self.assertEqual(rechunked_ds.chunks, {'x': 20, 'y': 20}) + self.assertIsInstance( + rechunked_ds._ptransform, xbeam_dataset._LazyPCollection + ) + rechunked_ds.to_zarr(path) + actual = xarray.open_zarr(path) + xarray.testing.assert_identical(actual, source) + def test_rechunk_with_existing_split_vars(self): source = xarray.Dataset({ 'foo': (('x', 'y'), np.arange(20).reshape(10, 2)), @@ -1115,6 +1171,50 @@ def test_rechunk_and_split( xarray.testing.assert_identical(actual, source) +class MeanTest(test_util.TestCase): + + @parameterized.named_parameters( + dict(testcase_name='x', dim='x', skipna=True), + dict(testcase_name='y', dim='y', skipna=True), + dict(testcase_name='two_dims', dim=['x', 'y'], skipna=True), + dict(testcase_name='all_dims', dim=None, skipna=True), + dict(testcase_name='skipna_false', dim='y', skipna=False), + ) + def test_mean(self, dim, skipna): + source_ds = xarray.Dataset( + {'foo': (('x', 'y'), np.array([[1, 2, np.nan], [4, np.nan, 6]]))} + ) + beam_ds = xbeam.Dataset.from_xarray(source_ds, chunks={'x': 1}) + actual = beam_ds.mean(dim=dim, skipna=skipna) + expected = source_ds.mean(dim=dim, skipna=skipna) + actual_collected = actual.collect_with_direct_runner() + xarray.testing.assert_allclose(expected, actual_collected) + + def test_mean_large_array_cases(self): + source_ds = xarray.Dataset( + {'foo': (('x', 'y'), np.arange(1000_000).reshape(1000, 1000))} + ) + beam_ds = xbeam.Dataset.from_xarray(source_ds, chunks={'x': 100, 'y': 1000}) + + with self.subTest('dim=y'): + actual = beam_ds.mean(dim='y') + expected = source_ds.mean(dim='y') + actual_collected = actual.collect_with_direct_runner() + xarray.testing.assert_allclose(expected, actual_collected) + + with self.subTest('dim=x'): + actual = beam_ds.mean(dim='x') + expected = source_ds.mean(dim='x') + actual_collected = actual.collect_with_direct_runner() + xarray.testing.assert_allclose(expected, actual_collected) + + with self.subTest('dim=[x,y]'): + actual = beam_ds.mean(dim=['x', 'y']) + expected = source_ds.mean(dim=['x', 'y']) + actual_collected = actual.collect_with_direct_runner() + xarray.testing.assert_allclose(expected, actual_collected) + + class EndToEndTest(test_util.TestCase): def test_bytes_per_chunk_and_chunk_count(self): @@ -1210,6 +1310,36 @@ def test_resample(self): xarray.testing.assert_identical(expected, actual) self.assertEqual(chunks, {'time': 10, 'latitude': 73, 'longitude': 144}) + def test_multiple_analysis_ready_outputs(self): + input_path = self.create_tempdir('source').full_path + temporal_path = self.create_tempdir('temporal').full_path + climatology_path = self.create_tempdir('climatology').full_path + + source_ds = test_util.dummy_era5_surface_dataset( + latitudes=73, longitudes=144, times=365, freq='24H' + ) + source_ds.chunk({'time': 30}).to_zarr(input_path) + + with beam.Pipeline() as p: + 
ds_spatial = xbeam.Dataset.from_zarr(input_path, pipeline=p) + + ds_temporal = ds_spatial.rechunk({'time': -1, ...: '10MB'}) + out_temporal = ds_temporal.to_zarr(temporal_path) + + out_climatology = ( + ds_temporal.map_blocks(lambda x: x.groupby('time.month').mean()) + .rechunk(-1) + .to_zarr(climatology_path) + ) + self.assertIs(out_temporal.pipeline, out_climatology.pipeline) + + actual_temporal = xarray.open_zarr(temporal_path) + xarray.testing.assert_identical(source_ds, actual_temporal) + + expected_climatology = source_ds.groupby('time.month').mean() + actual_climatology = xarray.open_zarr(climatology_path) + xarray.testing.assert_identical(expected_climatology, actual_climatology) + def test_from_ptransform_docs_example(self): source_ds = test_util.dummy_era5_surface_dataset( times=5, freq='1D', latitudes=3, longitudes=4 @@ -1232,49 +1362,5 @@ def load_chunk(time_val: np.datetime64) -> tuple[xbeam.Key, xarray.Dataset]: xarray.testing.assert_identical(source_ds, actual) -class MeanTest(test_util.TestCase): - - @parameterized.named_parameters( - dict(testcase_name='x', dim='x', skipna=True), - dict(testcase_name='y', dim='y', skipna=True), - dict(testcase_name='two_dims', dim=['x', 'y'], skipna=True), - dict(testcase_name='all_dims', dim=None, skipna=True), - dict(testcase_name='skipna_false', dim='y', skipna=False), - ) - def test_mean(self, dim, skipna): - source_ds = xarray.Dataset( - {'foo': (('x', 'y'), np.array([[1, 2, np.nan], [4, np.nan, 6]]))} - ) - beam_ds = xbeam.Dataset.from_xarray(source_ds, chunks={'x': 1}) - actual = beam_ds.mean(dim=dim, skipna=skipna) - expected = source_ds.mean(dim=dim, skipna=skipna) - actual_collected = actual.collect_with_direct_runner() - xarray.testing.assert_allclose(expected, actual_collected) - - def test_mean_large_array_cases(self): - source_ds = xarray.Dataset( - {'foo': (('x', 'y'), np.arange(1000_000).reshape(1000, 1000))} - ) - beam_ds = xbeam.Dataset.from_xarray(source_ds, chunks={'x': 100, 'y': 1000}) - - with self.subTest('dim=y'): - actual = beam_ds.mean(dim='y') - expected = source_ds.mean(dim='y') - actual_collected = actual.collect_with_direct_runner() - xarray.testing.assert_allclose(expected, actual_collected) - - with self.subTest('dim=x'): - actual = beam_ds.mean(dim='x') - expected = source_ds.mean(dim='x') - actual_collected = actual.collect_with_direct_runner() - xarray.testing.assert_allclose(expected, actual_collected) - - with self.subTest('dim=[x,y]'): - actual = beam_ds.mean(dim=['x', 'y']) - expected = source_ds.mean(dim=['x', 'y']) - actual_collected = actual.collect_with_direct_runner() - xarray.testing.assert_allclose(expected, actual_collected) - - if __name__ == '__main__': absltest.main()
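
As a quick reference for reviewers, below is a minimal sketch of the two usage patterns this change supports, adapted from the module docstring and the notebook tip in the diff above. The `gs://` paths and the `median('time')` computation are placeholders, not part of this PR.

```python
import apache_beam as beam
import xarray_beam as xbeam

# Hypothetical paths; substitute your own Zarr stores.
INPUT_PATH = 'gs://my-bucket/input.zarr'
OUTPUT_PATH = 'gs://my-bucket/output.zarr'

# Pattern 1: supply the pipeline up front. The transform is held in a
# _LazyPCollection internally, so DatasetToChunks-based shortcuts
# (rechunk, head, tail) remain available even with a pipeline attached.
with beam.Pipeline() as p:
  (
      xbeam.Dataset.from_zarr(INPUT_PATH, pipeline=p)
      .rechunk({'time': -1, 'latitude': 10, 'longitude': 10})
      .map_blocks(lambda ds: ds.median('time'))
      .to_zarr(OUTPUT_PATH)
  )

# Pattern 2: omit the pipeline while building interactively, then apply the
# resulting PTransform to a pipeline later. (Alternative to Pattern 1 --
# run one or the other, since both write to the same output store.)
transform = (
    xbeam.Dataset.from_zarr(INPUT_PATH)
    .rechunk({'time': -1, 'latitude': 10, 'longitude': 10})
    .map_blocks(lambda ds: ds.median('time'))
    .to_zarr(OUTPUT_PATH)
)
with beam.Pipeline() as p:
  p | transform
```

With `pipeline=` supplied, `to_zarr()` returns a `PCollection` bound to that pipeline and the write runs when the `with` block exits; without it, `to_zarr()` returns a `PTransform` that must be applied to a pipeline explicitly, as in Pattern 2.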