diff --git a/sdks/python/apache_beam/typehints/decorators.py b/sdks/python/apache_beam/typehints/decorators.py index e393113c002e..f45d8aca5a80 100644 --- a/sdks/python/apache_beam/typehints/decorators.py +++ b/sdks/python/apache_beam/typehints/decorators.py @@ -98,6 +98,7 @@ def foo((a, b)): from typing import get_args from typing import get_origin +from apache_beam.options.pipeline_options_context import get_pipeline_options from apache_beam.pvalue import TaggedOutput from apache_beam.typehints import native_type_compatibility from apache_beam.typehints import typehints @@ -546,12 +547,22 @@ def extract_tagged_outputs(self): A copy of this instance with TaggedOutput members moved from the main output type into the output kwargs dict. """ + opts = get_pipeline_options() + if opts and opts.is_compat_version_prior_to("2.72.0"): + return self if self.output_types is None or not self.has_simple_output_type(): return self - output_type = self.output_types[0][0] + # Tags already set via decorator/chain style — nothing to extract. + if self.output_types[1]: + return self + output_type = self.output_types[0][0] clean_type, extracted_tags = _extract_tagged_from_type(output_type) - if not extracted_tags: + + # If no tags were extracted, only return if the type is also unchanged. + # A bare `TaggedOutput` (e.g. in `int | TaggedOutput`) results in no + # extracted tags, but `clean_type` is modified, so we should proceed. + if not extracted_tags and clean_type == output_type: return self if clean_type is _NO_MAIN_TYPE: clean_type = typehints.Any diff --git a/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py b/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py index c06f68fb88a4..3147c3f63450 100644 --- a/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py +++ b/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py @@ -39,8 +39,13 @@ def fn(element) -> int | TaggedOutput[Literal['errors'], str]: from typing import Literal from typing import Union +from parameterized import param +from parameterized import parameterized + import apache_beam as beam +from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.pvalue import TaggedOutput +from apache_beam.typehints import Any from apache_beam.typehints import with_output_types from apache_beam.typehints.decorators import IOTypeHints @@ -351,6 +356,36 @@ def process( self.assertEqual(results.main.element_type, int) self.assertEqual(results.errors.element_type, str) + @parameterized.expand([ + param(compat_version=None), + param(compat_version="2.71.0"), + ]) + def test_pardo_annotation_process_method_update_compatible( + self, compat_version): + """Test DoFn with process method annotation preserves update compatibility. + Pre 2.72.0 the main output is the union of all types, and tagged pcols + Any""" + class AnnotatedDoFn(beam.DoFn): + def process(self, element: int) -> Iterable[int | TaggedOutput]: + if element < 0: + yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}') + else: + yield element * 2 + + with beam.Pipeline(options=PipelineOptions( + update_compatibility_version=compat_version)) as p: + results = ( + p + | beam.Create([-1, 0, 1, 2]) + | beam.ParDo(AnnotatedDoFn()).with_outputs('errors', main='main')) + if compat_version: + self.assertEqual( + results.main.element_type.union_types, [TaggedOutput, int]) + self.assertEqual(results.errors.element_type, Any) + else: + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, Any) + if __name__ == '__main__': unittest.main()