Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions xarray_beam/_src/range_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Beam source for arbitrary data."""
from __future__ import annotations

import dataclasses
import math
from typing import Any, Callable, Generic, Iterator, TypeVar

import apache_beam as beam
from apache_beam.io import iobase
from apache_beam.io import range_trackers


# Type of the elements returned by ``RangeSource.get_element`` and yielded by
# ``RangeSource.read``.
_T = TypeVar('_T')


@dataclasses.dataclass
class RangeSource(iobase.BoundedSource, Generic[_T]):
  """A Beam BoundedSource for a range of elements.

  This source is defined by a count, size of each element, and a function to
  retrieve an element by index.

  Attributes:
    element_count: number of elements in this source. Must be non-negative.
    element_size: size of each element in bytes. Must be positive.
    get_element: callable that given an integer index in the range
      ``[0, element_count)`` returns the corresponding element of the source.
    coder: coder reported by ``default_output_coder()`` for the elements of
      this source. Defaults to a ``PickleCoder``.
  """

  element_count: int
  element_size: int
  get_element: Callable[[int], _T]
  coder: beam.coders.Coder = beam.coders.PickleCoder()

  def __post_init__(self):
    """Validates the count and per-element size.

    Raises:
      ValueError: if ``element_count`` is negative or ``element_size`` is not
        strictly positive.
    """
    if self.element_count < 0:
      raise ValueError(
          f'element_count must be non-negative: {self.element_count}'
      )
    if self.element_size <= 0:
      raise ValueError(f'element_size must be positive: {self.element_size}')

  def estimate_size(self) -> int:
    """Estimates the size of source in bytes."""
    return self.element_count * self.element_size

  def split(
      self,
      desired_bundle_size: int,
      start_position: int | None = None,
      stop_position: int | None = None,
  ) -> Iterator[iobase.SourceBundle]:
    """Splits the source into a set of bundles.

    Args:
      desired_bundle_size: target size of each bundle in bytes. It is rounded
        up to a whole number of elements, and every bundle holds at least one
        element even if the requested size is zero or negative.
      start_position: first index to include; defaults to 0.
      stop_position: index at which to stop (exclusive); defaults to
        ``element_count``.

    Yields:
      Contiguous ``SourceBundle`` objects covering ``[start, stop)``, each
      weighted by the number of bytes it spans.
    """
    start = start_position if start_position is not None else 0
    stop = stop_position if stop_position is not None else self.element_count

    # Exact ceiling division (no float round-trip, so huge sizes stay precise),
    # clamped to one element: a step of zero would make range() raise, and a
    # negative step would silently produce no bundles at all.
    bundle_size_in_elements = max(
        1, int(-(-desired_bundle_size // self.element_size))
    )
    for bundle_start in range(start, stop, bundle_size_in_elements):
      bundle_stop = min(bundle_start + bundle_size_in_elements, stop)
      weight = (bundle_stop - bundle_start) * self.element_size
      yield iobase.SourceBundle(weight, self, bundle_start, bundle_stop)

  def get_range_tracker(
      self,
      start_position: int | None,
      stop_position: int | None,
  ) -> range_trackers.OffsetRangeTracker:
    """Returns a RangeTracker for a given position range.

    Args:
      start_position: first index of the range; ``None`` means 0.
      stop_position: end index of the range (exclusive); ``None`` means
        ``element_count``.

    Returns:
      An ``OffsetRangeTracker`` over ``[start, stop)``.
    """
    start = start_position if start_position is not None else 0
    stop = stop_position if stop_position is not None else self.element_count
    return range_trackers.OffsetRangeTracker(start, stop)

  def read(
      self, range_tracker: range_trackers.OffsetRangeTracker
  ) -> Iterator[_T]:
    """Returns an iterator that reads data from the source.

    Args:
      range_tracker: tracker whose start position is the first index to read.

    Yields:
      ``get_element(i)`` for consecutive indices ``i``, for as long as the
      tracker allows the index to be claimed.
    """
    i = range_tracker.start_position()
    # Each index must be claimed before it is read; the first failed claim
    # ends the iteration.
    while range_tracker.try_claim(i):
      yield self.get_element(i)
      i += 1

  def default_output_coder(self) -> beam.coders.Coder:
    """Coder that should be used for the records returned by the source."""
    return self.coder
64 changes: 64 additions & 0 deletions xarray_beam/_src/range_source_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for xarray_beam._src.range_source."""
from __future__ import annotations

from absl.testing import absltest
import apache_beam as beam
from xarray_beam._src import range_source
from xarray_beam._src import test_util


class RangeSourceTest(test_util.TestCase):
  """Checks reading, size estimation and splitting of ``RangeSource``."""

  def test_read(self):
    """A five-element source is read back as its five elements."""
    five_elements = range_source.RangeSource(
        element_count=5,
        element_size=1,
        get_element=lambda i: f'elem_{i}',
    )
    pipeline = test_util.EagerPipeline()
    actual = pipeline | beam.io.Read(five_elements)
    self.assertEqual(actual, [f'elem_{i}' for i in range(5)])

  def test_estimate_size(self):
    """The estimated size is the element count times the element size."""
    ten_by_eight = range_source.RangeSource(
        element_count=10, element_size=8, get_element=lambda i: i
    )
    self.assertEqual(ten_by_eight.estimate_size(), 80)

  def test_split(self):
    """Ten 1-byte elements split into 3-byte bundles give four bundles."""
    ten_elements = range_source.RangeSource(
        element_count=10, element_size=1, get_element=lambda i: i
    )
    bundles = list(ten_elements.split(desired_bundle_size=3))
    # 3 bytes per bundle at 1 byte per element means 3 elements per bundle,
    # so the expected ranges are [0,3), [3,6), [6,9) and a final [9,10).
    self.assertEqual(len(bundles), 4)
    self.assertEqual(
        [(b.start_position, b.stop_position) for b in bundles],
        [(0, 3), (3, 6), (6, 9), (9, 10)],
    )
    self.assertEqual([b.weight for b in bundles], [3, 3, 3, 1])

  def test_read_empty_source(self):
    """A source with zero elements reads as an empty result."""
    empty = range_source.RangeSource(
        element_count=0, element_size=1, get_element=lambda i: i
    )
    actual = test_util.EagerPipeline() | beam.io.Read(empty)
    self.assertEqual(actual, [])

  def test_nonsplittable_range_is_read(self):
    """Test reading a range that is not splittable."""
    single = range_source.RangeSource(
        element_count=1, element_size=1, get_element=str
    )
    pipeline = test_util.EagerPipeline()
    actual = pipeline | 'Read' >> beam.io.Read(single)
    self.assertEqual(actual, ['0'])


# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
absltest.main()