diff --git a/sdks/python/apache_beam/transforms/sideinputs.py b/sdks/python/apache_beam/transforms/sideinputs.py index 7d72a02f8874..a38e05d66cbe 100644 --- a/sdks/python/apache_beam/transforms/sideinputs.py +++ b/sdks/python/apache_beam/transforms/sideinputs.py @@ -60,7 +60,8 @@ def default_window_mapping_fn( def map_via_end(source_window: window.BoundedWindow) -> window.BoundedWindow: return list( target_window_fn.assign( - window.WindowFn.AssignContext(source_window.max_timestamp())))[-1] + window.WindowFn.AssignContext( + source_window.max_timestamp(), window=source_window)))[-1] return map_via_end diff --git a/sdks/python/apache_beam/transforms/sideinputs_test.py b/sdks/python/apache_beam/transforms/sideinputs_test.py index 5f3cf761e1eb..2794e2ac31bb 100644 --- a/sdks/python/apache_beam/transforms/sideinputs_test.py +++ b/sdks/python/apache_beam/transforms/sideinputs_test.py @@ -30,6 +30,7 @@ from typing import Union import pytest +from unittest import mock import apache_beam as beam from apache_beam.testing.synthetic_pipeline import SyntheticSDFAsSource @@ -41,6 +42,7 @@ from apache_beam.transforms import Map from apache_beam.transforms import trigger from apache_beam.transforms import window +from apache_beam.transforms import sideinputs from apache_beam.utils.timestamp import Timestamp @@ -489,6 +491,39 @@ def process( assert_that(results, equal_to([(num_records, expected_fingerprint)])) pipeline.run() + def test_default_window_mapping_fn_source_window(self): + """Test that the default window mapping function will propagate the + source window when attempting to assign context. + """ + class StringIDWindow(window.BoundedWindow): + """A window defined by an arbitrary string ID.""" + def __init__(self, window_id: str): + super().__init__(self._getTimestampFromProto()) + self.id = window_id + + @staticmethod + def _getTimestampFromProto() -> Timestamp: + return Timestamp(micros=0) + + class StringIDWindows(window.NonMergingWindowFn): + """ A windowing function that assigns each element a window with ID.""" + def assign( + self, assign_context: window.WindowFn.AssignContext + ) -> Iterable[BoundedWindow | None]: + if assign_context.element is None: + return [assign_context.window] + return [StringIDWindow(str(assign_context.element))] + + def get_window_coder(self): + return None + + mapping_fn = sideinputs.default_window_mapping_fn(StringIDWindows()) + source_window = StringIDWindows().assign( + window.WindowFn.AssignContext(Timestamp(10), element='element'))[0] + bounded_window = mapping_fn(source_window) + assert bounded_window is not None + assert bounded_window.id == 'element' + if __name__ == '__main__': logging.getLogger().setLevel(logging.DEBUG)