diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py index d0dfaec23afc..f429935c3a0e 100644 --- a/sdks/python/apache_beam/typehints/typehints.py +++ b/sdks/python/apache_beam/typehints/typehints.py @@ -1462,9 +1462,10 @@ def normalize(x, none_as_type=False): # Convert bare builtin types to correct type hints directly elif x in _KNOWN_PRIMITIVE_TYPES: return _KNOWN_PRIMITIVE_TYPES[x] - elif getattr(x, '__module__', - None) in ('typing', 'collections', 'collections.abc') or getattr( - x, '__origin__', None) in _KNOWN_PRIMITIVE_TYPES: + elif isinstance(x, types.UnionType) or getattr( + x, '__module__', + None) in ('typing', 'collections', 'collections.abc') or getattr( + x, '__origin__', None) in _KNOWN_PRIMITIVE_TYPES: beam_type = native_type_compatibility.convert_to_beam_type(x) if beam_type != x: # We were able to do the conversion. diff --git a/sdks/python/apache_beam/typehints/typehints_test.py b/sdks/python/apache_beam/typehints/typehints_test.py index cec830380087..1377bea6d56d 100644 --- a/sdks/python/apache_beam/typehints/typehints_test.py +++ b/sdks/python/apache_beam/typehints/typehints_test.py @@ -1596,6 +1596,22 @@ def test_hint_helper(self): self.assertFalse(is_consistent_with(Union[str, int], str)) self.assertFalse(is_consistent_with(str, NonBuiltInGeneric[str])) + def test_hint_helper_pipe_union(self): + pipe_union = int | None # pylint: disable=unsupported-binary-operation + typing_union = Union[int, None] + self.assertTrue(is_consistent_with(int, pipe_union)) + self.assertTrue(is_consistent_with(type(None), pipe_union)) + self.assertFalse(is_consistent_with(str, pipe_union)) + self.assertTrue( + is_consistent_with(int, pipe_union) == is_consistent_with( + int, typing_union)) + self.assertTrue( + is_consistent_with(str, pipe_union) == is_consistent_with( + str, typing_union)) + pipe_union_2 = int | float # pylint: disable=unsupported-binary-operation + self.assertTrue(is_consistent_with(int, pipe_union_2)) + self.assertTrue(is_consistent_with(float, pipe_union_2)) + def test_positional_arg_hints(self): self.assertEqual(typehints.Any, _positional_arg_hints('x', {})) self.assertEqual(int, _positional_arg_hints('x', {'x': int})) @@ -1934,6 +1950,14 @@ def test_pipe_operator_as_union(self): native_type_compatibility.convert_to_beam_type(type_a), native_type_compatibility.convert_to_beam_type(type_b)) + def test_normalize_pipe_union(self): + pipe_union = int | None # pylint: disable=unsupported-binary-operation + normalized = typehints.normalize(pipe_union) + self.assertIsInstance(normalized, typehints.UnionConstraint) + pipe_union_2 = int | float # pylint: disable=unsupported-binary-operation + normalized_2 = typehints.normalize(pipe_union_2) + self.assertIsInstance(normalized_2, typehints.UnionConstraint) + class TestNonBuiltInGenerics(unittest.TestCase): def test_no_error_thrown(self):