From fc1d602f5bc395e21dda708944a0db0cb43b8a12 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 3 Nov 2025 10:27:26 -0800 Subject: [PATCH] Create `RangeSource` for Xarray-Beam Beam's splittable DoFn API is incredibly obtuse, but Gemini was able to power through a basic implementation with only a small amount of guidance. This will hopefully be useful creating custom IO source(s) for xarray-beam in the future. PiperOrigin-RevId: 827555460 --- xarray_beam/_src/range_source.py | 99 +++++++++++++++++++++++++++ xarray_beam/_src/range_source_test.py | 64 +++++++++++++++++ 2 files changed, 163 insertions(+) create mode 100644 xarray_beam/_src/range_source.py create mode 100644 xarray_beam/_src/range_source_test.py diff --git a/xarray_beam/_src/range_source.py b/xarray_beam/_src/range_source.py new file mode 100644 index 0000000..20ad3e9 --- /dev/null +++ b/xarray_beam/_src/range_source.py @@ -0,0 +1,99 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Beam source for arbitrary data.""" +from __future__ import annotations + +import dataclasses +import math +from typing import Any, Callable, Generic, Iterator, TypeVar + +import apache_beam as beam +from apache_beam.io import iobase +from apache_beam.io import range_trackers + + +_T = TypeVar('_T') + + +@dataclasses.dataclass +class RangeSource(iobase.BoundedSource, Generic[_T]): + """A Beam BoundedSource for a range of elements. + + This source is defined by a count, size of each element, and a function to + retrieve an element by index. + + Attributes: + element_count: number of elements in this source. + element_size: size of each element in bytes. + get_element: callable that given an integer index in the range + ``[0, element_count)`` returns the corresponding element of the source. + """ + + element_count: int + element_size: int + get_element: Callable[[int], _T] + coder: beam.coders.Coder = beam.coders.PickleCoder() + + def __post_init__(self): + if self.element_count < 0: + raise ValueError( + f'element_count must be non-negative: {self.element_count}' + ) + if self.element_size <= 0: + raise ValueError(f'element_size must be positive: {self.element_size}') + + def estimate_size(self) -> int: + """Estimates the size of source in bytes.""" + return self.element_count * self.element_size + + def split( + self, + desired_bundle_size: int, + start_position: int | None = None, + stop_position: int | None = None, + ) -> Iterator[iobase.SourceBundle]: + """Splits the source into a set of bundles.""" + start = start_position if start_position is not None else 0 + stop = stop_position if stop_position is not None else self.element_count + + bundle_size_in_elements = int( + math.ceil(desired_bundle_size / self.element_size) + ) + for bundle_start in range(start, stop, bundle_size_in_elements): + bundle_stop = min(bundle_start + bundle_size_in_elements, stop) + weight = (bundle_stop - bundle_start) * self.element_size + yield iobase.SourceBundle(weight, self, bundle_start, bundle_stop) + + def get_range_tracker( + self, + start_position: int | None, + stop_position: int | None, + ) -> range_trackers.OffsetRangeTracker: + """Returns a RangeTracker for a given position range.""" + start = start_position if start_position is not None else 0 + stop = stop_position if stop_position is not None else self.element_count + return range_trackers.OffsetRangeTracker(start, stop) + + def read( + self, range_tracker: range_trackers.OffsetRangeTracker + ) -> Iterator[_T]: + """Returns an iterator that reads data from the source.""" + i = range_tracker.start_position() + while range_tracker.try_claim(i): + yield self.get_element(i) + i += 1 + + def default_output_coder(self) -> beam.coders.Coder: + """Coder that should be used for the records returned by the source.""" + return self.coder diff --git a/xarray_beam/_src/range_source_test.py b/xarray_beam/_src/range_source_test.py new file mode 100644 index 0000000..3ee43d3 --- /dev/null +++ b/xarray_beam/_src/range_source_test.py @@ -0,0 +1,64 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for xarray_beam._src.range_source.""" +from __future__ import annotations + +from absl.testing import absltest +import apache_beam as beam +from xarray_beam._src import range_source +from xarray_beam._src import test_util + + +class RangeSourceTest(test_util.TestCase): + + def test_read(self): + source = range_source.RangeSource( + element_count=5, + element_size=1, + get_element=lambda i: f'elem_{i}', + ) + result = test_util.EagerPipeline() | beam.io.Read(source) + self.assertEqual(result, ['elem_0', 'elem_1', 'elem_2', 'elem_3', 'elem_4']) + + def test_estimate_size(self): + source = range_source.RangeSource(10, 8, lambda i: i) + self.assertEqual(source.estimate_size(), 80) + + def test_split(self): + source = range_source.RangeSource(10, 1, lambda i: i) + splits = list(source.split(desired_bundle_size=3)) + # 10 elements, size 1, bundle size 3 bytes -> 3 elements/bundle + # bundles: [0,3), [3,6), [6,9), [9,10) + self.assertEqual(len(splits), 4) + positions = [(s.start_position, s.stop_position) for s in splits] + self.assertEqual(positions, [(0, 3), (3, 6), (6, 9), (9, 10)]) + weights = [s.weight for s in splits] + self.assertEqual(weights, [3, 3, 3, 1]) + + def test_read_empty_source(self): + source = range_source.RangeSource(0, 1, lambda i: i) + result = test_util.EagerPipeline() | beam.io.Read(source) + self.assertEqual(result, []) + + def test_nonsplittable_range_is_read(self): + """Test reading a range that is not splittable.""" + source = range_source.RangeSource( + element_count=1, get_element=str, element_size=1 + ) + result = test_util.EagerPipeline() | 'Read' >> beam.io.Read(source) + self.assertEqual(result, ['0']) + + +if __name__ == '__main__': + absltest.main()