From 1ad680841a7ecffe3b50d65e8cbf55351a5ba21c Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 27 Oct 2025 13:07:38 -0700 Subject: [PATCH] Register custom Beam coders for `xarray_beam.Key` and `xarray.Dataset`. This change adds `KeyCoder` and `DatasetCoder` to explicitly handle serialization of these types within Beam pipelines, preventing warnings about using fallback deterministic coders. PiperOrigin-RevId: 824645533 --- xarray_beam/_src/core.py | 47 +++++++++++++++++++++++++++++++++++ xarray_beam/_src/core_test.py | 13 ++++++++++ 2 files changed, 60 insertions(+) diff --git a/xarray_beam/_src/core.py b/xarray_beam/_src/core.py index 196af77..1ec3895 100644 --- a/xarray_beam/_src/core.py +++ b/xarray_beam/_src/core.py @@ -19,6 +19,7 @@ from functools import cached_property import itertools import math +import pickle import time from typing import Any, Generic, TypeVar @@ -236,6 +237,52 @@ def __setstate__(self, state): self.__init__(*state) +class _PickleCoder(beam.coders.Coder): + """Base class for Xarray-Beam coders that use pickle.""" + + def encode(self, value: Any) -> bytes: + return pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL) + + def decode(self, encoded: bytes) -> Any: + return pickle.loads(encoded) + + +class KeyCoder(_PickleCoder): + """Custom coder for Key.""" + + def is_deterministic(self) -> bool: + return True + + def estimate_size(self, value: Key) -> int: + return len(self.encode(value)) + + def to_type_hint(self) -> type[Key]: + return Key + + +# Register a coder for Key, to silence warnings about using the fallback +# deterministic coder. +beam.coders.registry.register_coder(Key, KeyCoder) + + +class DatasetCoder(_PickleCoder): + """Custom coder for xarray.Dataset.""" + + def is_deterministic(self) -> bool: + return False + + def estimate_size(self, value: xarray.Dataset) -> int: + return value.nbytes + + def to_type_hint(self) -> type[xarray.Dataset]: + return xarray.Dataset + + +# I'm not 100% sure if this is used anywhere (I don't see warnings about +# xarray.Dataset using a fall-back coder), but it can't hurt to add it. +beam.coders.registry.register_coder(xarray.Dataset, DatasetCoder) + + K = TypeVar("K") diff --git a/xarray_beam/_src/core_test.py b/xarray_beam/_src/core_test.py index 2b79179..7028761 100644 --- a/xarray_beam/_src/core_test.py +++ b/xarray_beam/_src/core_test.py @@ -299,6 +299,19 @@ def test_offsets_to_slices_base(self): self.assertEqual(slices, expected) +class CodersTest(test_util.TestCase): + + def test_no_fallback_deterministic_coder_warnings(self): + dataset = xarray.Dataset({'foo': ('x', np.arange(6))}) + inputs = [ + (xbeam.Key({'x': 0}), dataset.head(x=3)), + (xbeam.Key({'x': 3}), dataset.tail(x=3)), + ] + with self.assertNoLogs(level='WARNING'): + with beam.Pipeline(runner='DirectRunner') as p: + p | beam.Create(inputs) | beam.GroupByKey() + + class DatasetToChunksTest(test_util.TestCase): def test_iter_chunk_keys(self):