diff --git a/.github/trigger_files/beam_PostCommit_Python.json b/.github/trigger_files/beam_PostCommit_Python.json index 30ee463ad4e9..9e1d1e1b80dd 100644 --- a/.github/trigger_files/beam_PostCommit_Python.json +++ b/.github/trigger_files/beam_PostCommit_Python.json @@ -1,5 +1,5 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run.", - "modification": 2 + "modification": 4 } diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json index 0c1dae5766f3..afdc7f7012a8 100644 --- a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json +++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 9 + "modification": 11 } diff --git a/sdks/python/apache_beam/io/external/xlang_jdbcio_it_test.py b/sdks/python/apache_beam/io/external/xlang_jdbcio_it_test.py index 094161afbf93..26fa2f400d83 100644 --- a/sdks/python/apache_beam/io/external/xlang_jdbcio_it_test.py +++ b/sdks/python/apache_beam/io/external/xlang_jdbcio_it_test.py @@ -36,8 +36,6 @@ from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to -from apache_beam.typehints.schemas import LogicalType -from apache_beam.typehints.schemas import MillisInstant from apache_beam.utils.timestamp import Timestamp # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports @@ -242,10 +240,6 @@ def test_xlang_jdbc_write_read(self, database): config = self.jdbc_configs[database] - # Register MillisInstant logical type to override the mapping from Timestamp - # originally handled by MicrosInstant. - LogicalType.register_logical_type(MillisInstant) - with TestPipeline() as p: p.not_use_test_runner_api = True _ = ( @@ -356,10 +350,6 @@ def custom_row_equals(expected, actual): classpath=config['classpath'], )) - # Register MillisInstant logical type to override the mapping from Timestamp - # originally handled by MicrosInstant. - LogicalType.register_logical_type(MillisInstant) - # Run read pipeline with custom schema with TestPipeline() as p: p.not_use_test_runner_api = True diff --git a/sdks/python/apache_beam/io/jdbc.py b/sdks/python/apache_beam/io/jdbc.py index 604b95f6eebe..79e6b3ce315e 100644 --- a/sdks/python/apache_beam/io/jdbc.py +++ b/sdks/python/apache_beam/io/jdbc.py @@ -86,6 +86,7 @@ # pytype: skip-file +import contextlib import datetime import typing @@ -257,6 +258,17 @@ def __init__( ) +@contextlib.contextmanager +def enforce_millis_instant_for_timestamp(): + old_registry = LogicalType._known_logical_types + LogicalType._known_logical_types = old_registry.copy() + try: + LogicalType.register_logical_type(MillisInstant) + yield + finally: + LogicalType._known_logical_types = old_registry + + class ReadFromJdbc(ExternalTransform): """A PTransform which reads Rows from the specified database via JDBC. @@ -352,8 +364,9 @@ def __init__( dataSchema = None if schema is not None: - # Convert Python schema to Beam Schema proto - schema_proto = typing_to_runner_api(schema).row_type.schema + with enforce_millis_instant_for_timestamp(): + # Convert Python schema to Beam Schema proto + schema_proto = typing_to_runner_api(schema).row_type.schema # Serialize the proto to bytes for transmission dataSchema = schema_proto.SerializeToString() diff --git a/sdks/python/apache_beam/typehints/row_type.py b/sdks/python/apache_beam/typehints/row_type.py index 038eb50d0606..08838c84a050 100644 --- a/sdks/python/apache_beam/typehints/row_type.py +++ b/sdks/python/apache_beam/typehints/row_type.py @@ -228,6 +228,7 @@ def __init__( schema_id=schema_id, schema_options=schema_options, field_options=field_options, + field_descriptions=field_descriptions, **kwargs) user_type = named_tuple_from_schema(schema, **kwargs) diff --git a/sdks/python/apache_beam/typehints/schema_registry.py b/sdks/python/apache_beam/typehints/schema_registry.py index a73e97f43f70..7d8cdcf57d3f 100644 --- a/sdks/python/apache_beam/typehints/schema_registry.py +++ b/sdks/python/apache_beam/typehints/schema_registry.py @@ -40,7 +40,7 @@ def generate_new_id(self): "schemas.") def add(self, typing, schema): - if not schema.id: + if schema.id: self.by_id[schema.id] = (typing, schema) def get_typing_by_id(self, unique_id): diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py index 90a692e21125..a3d9e4d8bf73 100644 --- a/sdks/python/apache_beam/typehints/schemas.py +++ b/sdks/python/apache_beam/typehints/schemas.py @@ -142,9 +142,11 @@ def named_fields_to_schema( schema_options: Optional[Sequence[Tuple[str, Any]]] = None, field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] = None, schema_registry: SchemaTypeRegistry = SCHEMA_REGISTRY, + field_descriptions: Optional[Dict[str, str]] = None, ): schema_options = schema_options or [] field_options = field_options or {} + field_descriptions = field_descriptions or {} if isinstance(names_and_types, dict): names_and_types = names_and_types.items() @@ -158,7 +160,8 @@ def named_fields_to_schema( option_to_runner_api(option_tuple) for option_tuple in field_options.get(name, []) ], - ) for (name, type) in names_and_types + description=field_descriptions.get(name, None)) + for (name, type) in names_and_types ], options=[ option_to_runner_api(option_tuple) for option_tuple in schema_options @@ -616,6 +619,13 @@ def schema_from_element_type(element_type: type) -> schema_pb2.Schema: if isinstance(element_type, row_type.RowTypeConstraint): return named_fields_to_schema(element_type._fields) elif match_is_named_tuple(element_type): + if hasattr(element_type, row_type._BEAM_SCHEMA_ID): + # if the named tuple's schema is in registry, we just use it instead of + # regenerating one. + schema_id = getattr(element_type, row_type._BEAM_SCHEMA_ID) + schema = SCHEMA_REGISTRY.get_schema_by_id(schema_id) + if schema is not None: + return schema return named_tuple_to_schema(element_type) else: raise TypeError( @@ -1017,15 +1027,15 @@ def representation_type(cls): def language_type(cls): return decimal.Decimal - def to_representation_type(self, value): - # type: (decimal.Decimal) -> bytes - - return DecimalLogicalType().to_representation_type(value) - - def to_language_type(self, value): - # type: (bytes) -> decimal.Decimal + # from language type (decimal.Decimal) to representation type + # (the type corresponding to the coder used in DecimalLogicalType) + def to_representation_type(self, value: decimal.Decimal) -> decimal.Decimal: + return value - return DecimalLogicalType().to_language_type(value) + # from representation type (the type corresponding to the coder used in + # DecimalLogicalType) to language type + def to_language_type(self, value: decimal.Decimal) -> decimal.Decimal: + return value @classmethod def argument_type(cls):