diff --git a/CHANGES.md b/CHANGES.md index f4a04320d66c..072b0efabbf3 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -71,6 +71,7 @@ * (Python) Added exception chaining to preserve error context in CloudSQLEnrichmentHandler, processes utilities, and core transforms ([#37422](https://github.com/apache/beam/issues/37422)). * (Python) Added `take(n)` convenience for PCollection: `beam.take(n)` and `pcoll.take(n)` to get the first N elements deterministically without Top.Of + FlatMap ([#X](https://github.com/apache/beam/issues/37429)). +* (Python) Added `type_overrides` parameter to `WriteToBigQuery` allowing users to specify custom BigQuery to Python type mappings when using Storage Write API. This enables support for types like DATE, DATETIME, and JSON ([#25946](https://github.com/apache/beam/issues/25946)). * X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). ## Breaking Changes diff --git a/sdks/python/apache_beam/io/gcp/bigquery.py b/sdks/python/apache_beam/io/gcp/bigquery.py index 181c891c1b65..949e307baf01 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery.py +++ b/sdks/python/apache_beam/io/gcp/bigquery.py @@ -2009,7 +2009,8 @@ def __init__( use_cdc_writes: bool = False, primary_key: List[str] = None, expansion_service=None, - big_lake_configuration=None): + big_lake_configuration=None, + type_overrides=None): """Initialize a WriteToBigQuery transform. Args: @@ -2186,6 +2187,11 @@ def __init__( CREATE_IF_NEEDED mode for the underlying tables a list of column names is required to be configured as the primary key. Used for STORAGE_WRITE_API, working on 'at least once' mode. + type_overrides (dict): Optional mapping of BigQuery type names (uppercase) + to Python types. These override the default type mappings when + converting BigQuery schemas to Python types for STORAGE_WRITE_API. + For example: ``{'DATE': datetime.date, 'JSON': dict}``. + Default mappings include STRING->str, INT64->np.int64, etc. 
""" self._table = table self._dataset = dataset @@ -2231,6 +2237,7 @@ def __init__( self._use_cdc_writes = use_cdc_writes self._primary_key = primary_key self._big_lake_configuration = big_lake_configuration + self._type_overrides = type_overrides # Dict/schema methods were moved to bigquery_tools, but keep references # here for backward compatibility. @@ -2395,7 +2402,8 @@ def find_in_nested_dict(schema): use_cdc_writes=self._use_cdc_writes, primary_key=self._primary_key, big_lake_configuration=self._big_lake_configuration, - expansion_service=self.expansion_service) + expansion_service=self.expansion_service, + type_overrides=self._type_overrides) else: raise ValueError(f"Unsupported method {method_to_use}") @@ -2644,7 +2652,8 @@ def __init__( use_cdc_writes: bool = False, primary_key: List[str] = None, big_lake_configuration=None, - expansion_service=None): + expansion_service=None, + type_overrides=None): self._table = table self._table_side_inputs = table_side_inputs self._schema = schema @@ -2658,6 +2667,7 @@ def __init__( self._use_cdc_writes = use_cdc_writes self._primary_key = primary_key self._big_lake_configuration = big_lake_configuration + self._type_overrides = type_overrides self._expansion_service = expansion_service or BeamJarExpansionService( 'sdks:java:io:google-cloud-platform:expansion-service:build') @@ -2691,7 +2701,7 @@ def expand(self, input): input_beam_rows = ( input | "Convert dict to Beam Row" >> self.ConvertToBeamRows( - schema, False).with_output_types()) + schema, False, self._type_overrides).with_output_types()) # For dynamic destinations, we first figure out where each row is going. # Then we send (destination, record) rows over to Java SchemaTransform. 
@@ -2723,7 +2733,7 @@ def expand(self, input): input_beam_rows = ( input_rows | "Convert dict to Beam Row" >> self.ConvertToBeamRows( - schema, True).with_output_types()) + schema, True, self._type_overrides).with_output_types()) # communicate to Java that this write should use dynamic destinations table = StorageWriteToBigQuery.DYNAMIC_DESTINATIONS @@ -2791,9 +2801,10 @@ def __exit__(self, *args): pass class ConvertToBeamRows(PTransform): - def __init__(self, schema, dynamic_destinations): + def __init__(self, schema, dynamic_destinations, type_overrides=None): self.schema = schema self.dynamic_destinations = dynamic_destinations + self.type_overrides = type_overrides def expand(self, input_dicts): if self.dynamic_destinations: @@ -2819,7 +2830,7 @@ def expand(self, input_dicts): def with_output_types(self): row_type_hints = bigquery_tools.get_beam_typehints_from_tableschema( - self.schema) + self.schema, self.type_overrides) if self.dynamic_destinations: type_hint = RowTypeConstraint.from_fields([ (StorageWriteToBigQuery.DESTINATION, str), diff --git a/sdks/python/apache_beam/io/gcp/bigquery_schema_tools.py b/sdks/python/apache_beam/io/gcp/bigquery_schema_tools.py index 54c7ca90f011..792d6ffbd90c 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_schema_tools.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_schema_tools.py @@ -55,15 +55,21 @@ def generate_user_type_from_bq_schema( - the_table_schema, selected_fields: 'bigquery.TableSchema' = None) -> type: + the_table_schema, + selected_fields: 'bigquery.TableSchema' = None, + type_overrides=None) -> type: """Convert a schema of type TableSchema into a pcollection element. Args: the_table_schema: A BQ schema of type TableSchema selected_fields: if not None, the subset of fields to consider + type_overrides: Optional mapping of BigQuery type names (uppercase) + to Python types. These override the default mappings in + BIG_QUERY_TO_PYTHON_TYPES. 
For example: + ``{'DATE': datetime.date, 'JSON': dict}`` Returns: type: type that can be used to work with pCollections. """ - + effective_types = {**BIG_QUERY_TO_PYTHON_TYPES, **(type_overrides or {})} the_schema = beam.io.gcp.bigquery_tools.get_dict_table_schema( the_table_schema) if the_schema == {}: @@ -72,8 +78,8 @@ def generate_user_type_from_bq_schema( for field in the_schema['fields']: if selected_fields is not None and field['name'] not in selected_fields: continue - if field['type'] in BIG_QUERY_TO_PYTHON_TYPES: - typ = bq_field_to_type(field['type'], field['mode']) + if field['type'] in effective_types: + typ = bq_field_to_type(field['type'], field['mode'], type_overrides) else: raise ValueError( f"Encountered " @@ -85,19 +91,44 @@ def generate_user_type_from_bq_schema( return usertype -def bq_field_to_type(field, mode): +def bq_field_to_type(field, mode, type_overrides=None): + """Convert a BigQuery field type and mode to a Python type hint. + + Args: + field: The BigQuery type name (e.g., 'STRING', 'DATE'). + mode: The field mode ('NULLABLE', 'REPEATED', 'REQUIRED'). + type_overrides: Optional mapping of BigQuery type names (uppercase) + to Python types. These override the default mappings. + + Returns: + The corresponding Python type hint. 
+ """ + effective_types = {**BIG_QUERY_TO_PYTHON_TYPES, **(type_overrides or {})} if mode == 'NULLABLE' or mode is None or mode == '': - return Optional[BIG_QUERY_TO_PYTHON_TYPES[field]] + return Optional[effective_types[field]] elif mode == 'REPEATED': - return Sequence[BIG_QUERY_TO_PYTHON_TYPES[field]] + return Sequence[effective_types[field]] elif mode == 'REQUIRED': - return BIG_QUERY_TO_PYTHON_TYPES[field] + return effective_types[field] else: raise ValueError(f"Encountered an unsupported mode: {mode!r}") -def convert_to_usertype(table_schema, selected_fields=None): - usertype = generate_user_type_from_bq_schema(table_schema, selected_fields) +def convert_to_usertype( + table_schema, selected_fields=None, type_overrides=None): + """Convert a BigQuery table schema to a user type. + + Args: + table_schema: A BQ schema of type TableSchema + selected_fields: if not None, the subset of fields to consider + type_overrides: Optional mapping of BigQuery type names (uppercase) + to Python types. + + Returns: + A ParDo transform that converts dictionaries to the user type. 
+ """ + usertype = generate_user_type_from_bq_schema( + table_schema, selected_fields, type_overrides) return beam.ParDo(BeamSchemaConversionDoFn(usertype)) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_schema_tools_test.py b/sdks/python/apache_beam/io/gcp/bigquery_schema_tools_test.py index 0eb3351ee84c..3cf641a2fb04 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_schema_tools_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_schema_tools_test.py @@ -337,6 +337,115 @@ def test_geography_with_complex_wkt(self): self.assertEqual(usertype.__annotations__, expected_annotations) +@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') +class TestTypeOverridesSchemaTools(unittest.TestCase): + """Tests for type_overrides parameter in bigquery_schema_tools.""" + def test_bq_field_to_type_with_overrides(self): + """Test bq_field_to_type function with type_overrides.""" + import datetime + + from apache_beam.io.gcp.bigquery_schema_tools import bq_field_to_type + + # Without overrides, DATE is not supported + with self.assertRaises(KeyError): + bq_field_to_type("DATE", "REQUIRED") + + # With overrides, DATE works + overrides = {"DATE": datetime.date} + self.assertEqual( + bq_field_to_type("DATE", "REQUIRED", overrides), datetime.date) + self.assertEqual( + bq_field_to_type("DATE", "NULLABLE", overrides), + typing.Optional[datetime.date]) + self.assertEqual( + bq_field_to_type("DATE", "REPEATED", overrides), + typing.Sequence[datetime.date]) + + def test_bq_field_to_type_overrides_can_use_str(self): + """Test that type_overrides can map DATE/DATETIME/JSON to str.""" + from apache_beam.io.gcp.bigquery_schema_tools import bq_field_to_type + + overrides = {"DATE": str, "DATETIME": str, "JSON": str} + self.assertEqual(bq_field_to_type("DATE", "REQUIRED", overrides), str) + self.assertEqual(bq_field_to_type("DATETIME", "REQUIRED", overrides), str) + self.assertEqual(bq_field_to_type("JSON", "REQUIRED", overrides), str) + + def 
test_generate_user_type_with_overrides(self): + """Test generate_user_type_from_bq_schema with type_overrides.""" + import datetime + + schema = bigquery.TableSchema( + fields=[ + bigquery.TableFieldSchema( + name='id', type='INTEGER', mode="REQUIRED"), + bigquery.TableFieldSchema( + name='event_date', type='DATE', mode="NULLABLE") + ]) + + # Without overrides, DATE is not supported + with self.assertRaises(ValueError): + bigquery_schema_tools.generate_user_type_from_bq_schema(schema) + + # With overrides, DATE works + overrides = {"DATE": datetime.date} + usertype = bigquery_schema_tools.generate_user_type_from_bq_schema( + schema, type_overrides=overrides) + self.assertEqual( + usertype.__annotations__, { + 'id': np.int64, 'event_date': typing.Optional[datetime.date] + }) + + def test_generate_user_type_overrides_with_str(self): + """Test that type_overrides can map DATE to str.""" + schema = bigquery.TableSchema( + fields=[ + bigquery.TableFieldSchema( + name='id', type='INTEGER', mode="REQUIRED"), + bigquery.TableFieldSchema( + name='event_date', type='DATE', mode="NULLABLE") + ]) + + overrides = {"DATE": str} + usertype = bigquery_schema_tools.generate_user_type_from_bq_schema( + schema, type_overrides=overrides) + self.assertEqual( + usertype.__annotations__, { + 'id': np.int64, 'event_date': typing.Optional[str] + }) + + def test_convert_to_usertype_with_overrides(self): + """Test convert_to_usertype function with type_overrides.""" + import datetime + + schema = bigquery.TableSchema( + fields=[ + bigquery.TableFieldSchema( + name='id', type='INTEGER', mode="REQUIRED"), + bigquery.TableFieldSchema( + name='event_date', type='DATE', mode="NULLABLE") + ]) + + overrides = {"DATE": datetime.date} + transform = bigquery_schema_tools.convert_to_usertype( + schema, type_overrides=overrides) + + # The transform should be created successfully + self.assertIsNotNone(transform) + self.assertIsInstance(transform, beam.ParDo) + + def 
test_type_overrides_can_override_default_types(self): + """Test that type_overrides can override default type mappings.""" + from apache_beam.io.gcp.bigquery_schema_tools import bq_field_to_type + + # GEOGRAPHY is in the default mapping as str + self.assertEqual(bq_field_to_type("GEOGRAPHY", "REQUIRED"), str) + + # We can override it + overrides = {"GEOGRAPHY": bytes} + self.assertEqual( + bq_field_to_type("GEOGRAPHY", "REQUIRED", overrides), bytes) + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/io/gcp/bigquery_tools.py b/sdks/python/apache_beam/io/gcp/bigquery_tools.py index ddab941f9278..b254ee2fa1f2 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_tools.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_tools.py @@ -1774,18 +1774,23 @@ def get_avro_schema_from_table_schema(schema): "root", dict_table_schema) -def get_beam_typehints_from_tableschema(schema): +def get_beam_typehints_from_tableschema(schema, type_overrides=None): """Extracts Beam Python type hints from the schema. Args: schema (~apache_beam.io.gcp.internal.clients.bigquery.\ bigquery_v2_messages.TableSchema): The TableSchema to extract type hints from. + type_overrides (dict): Optional mapping of BigQuery type names (uppercase) + to Python types. These override the default mappings in + BIGQUERY_TYPE_TO_PYTHON_TYPE. For example: + ``{'DATE': datetime.date, 'JSON': dict}`` Returns: List[Tuple[str, Any]]: A list of type hints that describe the input schema. Nested and repeated fields are supported. """ + effective_types = {**BIGQUERY_TYPE_TO_PYTHON_TYPE, **(type_overrides or {})} if not isinstance(schema, (bigquery.TableSchema, bigquery.TableFieldSchema)): schema = get_bq_tableschema(schema) typehints = [] @@ -1795,9 +1800,9 @@ def get_beam_typehints_from_tableschema(schema): if field_type in ["STRUCT", "RECORD"]: # Structs can be represented as Beam Rows. 
typehint = RowTypeConstraint.from_fields( - get_beam_typehints_from_tableschema(field)) - elif field_type in BIGQUERY_TYPE_TO_PYTHON_TYPE: - typehint = BIGQUERY_TYPE_TO_PYTHON_TYPE[field_type] + get_beam_typehints_from_tableschema(field, type_overrides)) + elif field_type in effective_types: + typehint = effective_types[field_type] else: raise ValueError( f"Converting BigQuery type [{field_type}] to " diff --git a/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py b/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py index 2594e6728e0e..078c42160941 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py @@ -1248,6 +1248,167 @@ def test_geography_with_special_characters(self): self.assertIsInstance(result, str) +@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') +class TestTypeOverrides(unittest.TestCase): + """Tests for type_overrides parameter in BigQuery type mappings.""" + def test_type_overrides_enables_unsupported_types(self): + """Test that type_overrides enables support for DATE/DATETIME/JSON.""" + import datetime + schema = { + "fields": [{ + "name": "date_field", "type": "DATE", "mode": "REQUIRED" + }, + { + "name": "datetime_field", + "type": "DATETIME", + "mode": "REQUIRED" + }, { + "name": "json_field", "type": "JSON", "mode": "REQUIRED" + }] + } + + # Without overrides, these types are not supported + with self.assertRaises(ValueError): + get_beam_typehints_from_tableschema(schema) + + # With overrides, they work + type_overrides = {"DATE": str, "DATETIME": str, "JSON": str} + typehints = get_beam_typehints_from_tableschema(schema, type_overrides) + self.assertEqual( + typehints, [("date_field", str), ("datetime_field", str), + ("json_field", str)]) + + def test_type_overrides_with_custom_types(self): + """Test type_overrides with custom Python types.""" + import datetime + schema = { + "fields": [{ + "name": "date_field", "type": "DATE", "mode": 
"REQUIRED" + }, + { + "name": "datetime_field", + "type": "DATETIME", + "mode": "REQUIRED" + }] + } + + type_overrides = {"DATE": datetime.date, "DATETIME": datetime.datetime} + typehints = get_beam_typehints_from_tableschema(schema, type_overrides) + self.assertEqual( + typehints, [("date_field", datetime.date), + ("datetime_field", datetime.datetime)]) + + def test_type_overrides_with_modes(self): + """Test that type_overrides works with NULLABLE and REPEATED modes.""" + import datetime + schema = { + "fields": [{ + "name": "required_date", "type": "DATE", "mode": "REQUIRED" + }, { + "name": "optional_date", "type": "DATE", "mode": "NULLABLE" + }, { + "name": "repeated_dates", "type": "DATE", "mode": "REPEATED" + }] + } + + type_overrides = {"DATE": datetime.date} + typehints = get_beam_typehints_from_tableschema(schema, type_overrides) + + expected = [("required_date", datetime.date), + ("optional_date", Optional[datetime.date]), + ("repeated_dates", Sequence[datetime.date])] + self.assertEqual(typehints, expected) + + def test_type_overrides_mixed_with_default_types(self): + """Test type_overrides alongside default type mappings.""" + import datetime + schema = { + "fields": [{ + "name": "date_field", "type": "DATE", "mode": "REQUIRED" + }, { + "name": "string_field", "type": "STRING", "mode": "REQUIRED" + }, { + "name": "int_field", "type": "INTEGER", "mode": "REQUIRED" + }] + } + + type_overrides = {"DATE": datetime.date} + typehints = get_beam_typehints_from_tableschema(schema, type_overrides) + + expected = [("date_field", datetime.date), ("string_field", str), + ("int_field", np.int64)] + self.assertEqual(typehints, expected) + + def test_type_overrides_with_nested_struct(self): + """Test that type_overrides is propagated to nested STRUCT fields.""" + import datetime + schema = bigquery.TableSchema() + + # Root field + date_field = bigquery.TableFieldSchema() + date_field.name = "date_field" + date_field.type = "DATE" + date_field.mode = "REQUIRED" + + # 
Nested struct with DATE field + struct_field = bigquery.TableFieldSchema() + struct_field.name = "nested" + struct_field.type = "RECORD" + struct_field.mode = "REQUIRED" + + nested_date = bigquery.TableFieldSchema() + nested_date.name = "nested_date" + nested_date.type = "DATE" + nested_date.mode = "REQUIRED" + struct_field.fields.append(nested_date) + + schema.fields.append(date_field) + schema.fields.append(struct_field) + + type_overrides = {"DATE": datetime.date} + typehints = get_beam_typehints_from_tableschema(schema, type_overrides) + + self.assertEqual(len(typehints), 2) + self.assertEqual(typehints[0], ("date_field", datetime.date)) + # The nested field's DATE should also be overridden + nested_constraint = typehints[1][1] + nested_fields = nested_constraint._fields + self.assertEqual(nested_fields[0], ("nested_date", datetime.date)) + + def test_type_overrides_can_override_default_types(self): + """Test that type_overrides can override default type mappings.""" + schema = { + "fields": [{ + "name": "geo_field", "type": "GEOGRAPHY", "mode": "REQUIRED" + }] + } + + # Without overrides, GEOGRAPHY maps to str (default) + typehints = get_beam_typehints_from_tableschema(schema, None) + self.assertEqual(typehints, [("geo_field", str)]) + + # With overrides, we can change it + typehints_override = get_beam_typehints_from_tableschema( + schema, {"GEOGRAPHY": bytes}) + self.assertEqual(typehints_override, [("geo_field", bytes)]) + + def test_type_overrides_json_to_dict(self): + """Test using type_overrides to map JSON to dict.""" + schema = {"fields": [{"name": "data", "type": "JSON", "mode": "NULLABLE"}]} + + # Without overrides, JSON is not supported + with self.assertRaises(ValueError): + get_beam_typehints_from_tableschema(schema) + + # With overrides, can map to str + typehints_str = get_beam_typehints_from_tableschema(schema, {"JSON": str}) + self.assertEqual(typehints_str, [("data", Optional[str])]) + + # Or map to dict + typehints_dict = 
get_beam_typehints_from_tableschema(schema, {"JSON": dict}) + self.assertEqual(typehints_dict, [("data", Optional[dict])]) + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main()