diff --git a/python/pyarrow/parquet/core.py b/python/pyarrow/parquet/core.py
index 676bc445238..a2f52b4a0e4 100644
--- a/python/pyarrow/parquet/core.py
+++ b/python/pyarrow/parquet/core.py
@@ -715,32 +715,77 @@ def _sanitized_spark_field_name(name):
     return _SPARK_DISALLOWED_CHARS.sub('_', name)
 
 
-def _sanitize_schema(schema, flavor):
-    if 'spark' in flavor:
-        sanitized_fields = []
+def _sanitize_field_recursive(field):
+    """
+    Recursively sanitize field names in nested types for Spark compatibility.
 
-        schema_changed = False
+    Returns
+    -------
+    tuple
+        (sanitized_field, changed) where changed is True if any sanitization occurred
+    """
+    sanitized_name = _sanitized_spark_field_name(field.name)
+    sanitized_type = field.type
+    type_changed = False
+
+    if pa.types.is_struct(field.type):
+        sanitized_fields = [_sanitize_field_recursive(f) for f in field.type]
+        if any(changed for _, changed in sanitized_fields):
+            sanitized_type = pa.struct([f for f, _ in sanitized_fields])
+            type_changed = True
+    elif pa.types.is_list(field.type) or pa.types.is_large_list(field.type):
+        # Sanitize the value field of list types
+        value_field = field.type.value_field
+        sanitized_value_field, value_changed = _sanitize_field_recursive(value_field)
+        if value_changed:
+            if pa.types.is_list(field.type):
+                sanitized_type = pa.list_(sanitized_value_field)
+            else:  # large_list
+                sanitized_type = pa.large_list(sanitized_value_field)
+            type_changed = True
+    elif pa.types.is_fixed_size_list(field.type):
+        # Sanitize the value field of fixed_size_list types
+        value_field = field.type.value_field
+        list_size = field.type.list_size
+        sanitized_value_field, value_changed = _sanitize_field_recursive(value_field)
+        if value_changed:
+            sanitized_type = pa.list_(sanitized_value_field, list_size)
+            type_changed = True
+    elif pa.types.is_map(field.type):
+        # Sanitize both key and item fields of map types
+        key_field = field.type.key_field
+        item_field = field.type.item_field
+        sanitized_key_field, key_changed = _sanitize_field_recursive(key_field)
+        sanitized_item_field, item_changed = _sanitize_field_recursive(item_field)
+        if key_changed or item_changed:
+            sanitized_type = pa.map_(sanitized_key_field, sanitized_item_field,
+                                     keys_sorted=field.type.keys_sorted)
+            type_changed = True
+
+    name_changed = sanitized_name != field.name
+    if name_changed or type_changed:
+        return pa.field(sanitized_name, sanitized_type, field.nullable,
+                        field.metadata), True
+    return field, False
 
-        for field in schema:
-            name = field.name
-            sanitized_name = _sanitized_spark_field_name(name)
-
-            if sanitized_name != name:
-                schema_changed = True
-                sanitized_field = pa.field(sanitized_name, field.type,
-                                           field.nullable, field.metadata)
-                sanitized_fields.append(sanitized_field)
-            else:
-                sanitized_fields.append(field)
-
-        new_schema = pa.schema(sanitized_fields, metadata=schema.metadata)
-        return new_schema, schema_changed
-    else:
+
+def _sanitize_schema(schema, flavor):
+    if 'spark' not in flavor:
         return schema, False
 
+    sanitized_fields = []
+    schema_changed = False
+
+    for field in schema:
+        sanitized_field, changed = _sanitize_field_recursive(field)
+        sanitized_fields.append(sanitized_field)
+        schema_changed = schema_changed or changed
+
+    new_schema = pa.schema(sanitized_fields, metadata=schema.metadata)
+    return new_schema, schema_changed
+
 
 def _sanitize_table(table, new_schema, flavor):
-    # TODO: This will not handle prohibited characters in nested field names
     if 'spark' in flavor:
         column_data = [table[i] for i in range(table.num_columns)]
         return pa.Table.from_arrays(column_data, schema=new_schema)
diff --git a/python/pyarrow/tests/parquet/test_basic.py b/python/pyarrow/tests/parquet/test_basic.py
index 94868741f39..74a7d88565f 100644
--- a/python/pyarrow/tests/parquet/test_basic.py
+++ b/python/pyarrow/tests/parquet/test_basic.py
@@ -613,14 +613,127 @@ def test_compression_level():
 
 
 def test_sanitized_spark_field_names():
-    a0 = pa.array([0, 1, 2, 3, 4])
-    name = 'prohib; ,\t{}'
-    table = pa.Table.from_arrays([a0], [name])
+    field_metadata = {b'key': b'value'}
+    schema_metadata = {b'schema_key': b'schema_value'}
+
+    schema = pa.schema([
+        pa.field('prohib; ,\t{}', pa.int32()),
+        pa.field('field=with\nspecial', pa.string(), metadata=field_metadata),
+        pa.field('nested_struct', pa.struct([
+            pa.field('field,comma', pa.int32()),
+            pa.field('deeply{nested}', pa.struct([
+                pa.field('field(parens)', pa.float64()),
+                pa.field('normal_field', pa.bool_())
+            ]))
+        ]))
+    ], metadata=schema_metadata)
+
+    data = [
+        pa.array([1, 2]),
+        pa.array(['a', 'b']),
+        pa.array([
+            {'field,comma': 10, 'deeply{nested}': {
+                'field(parens)': 1.5, 'normal_field': True}},
+            {'field,comma': 20, 'deeply{nested}': {
+                'field(parens)': 2.5, 'normal_field': False}}
+        ], type=schema[2].type)
+    ]
+
+    table = pa.Table.from_arrays(data, schema=schema)
+    result = _roundtrip_table(table, write_table_kwargs={'flavor': 'spark'})
+
+    assert result.schema[0].name == 'prohib______'
+    assert result.schema[1].name == 'field_with_special'
+
+    nested_type = result.schema[2].type
+    assert nested_type[0].name == 'field_comma'
+    assert nested_type[1].name == 'deeply_nested_'
+
+    deep_type = nested_type[1].type
+    assert deep_type[0].name == 'field_parens_'
+    assert deep_type[1].name == 'normal_field'
+
+    assert result.schema[1].metadata == field_metadata
+    assert result.schema.metadata == schema_metadata
+    assert len(result) == 2
+
+
+def test_sanitized_spark_field_names_nested():
+    # Test that field name sanitization works for structs nested inside
+    # lists, maps, and other complex types
+    schema = pa.schema([
+        # List containing struct with special chars
+        pa.field('list;field', pa.list_(pa.field('item', pa.struct([
+            pa.field('field,name', pa.int32()),
+            pa.field('other{field}', pa.string())
+        ])))),
+        # Large list with nested struct
+        pa.field('large=list', pa.large_list(pa.field('element', pa.struct([
+            pa.field('nested(field)', pa.float64())
+        ])))),
+        # Fixed size list with nested struct
+        pa.field('fixed\tlist', pa.list_(pa.field('item', pa.struct([
+            pa.field('special field', pa.int32())
+        ])), 2)),
+        # Map with structs in both key and value
+        pa.field('map field', pa.map_(
+            pa.field('key', pa.struct(
+                [pa.field('key;field', pa.string())]), nullable=False),
+            pa.field('value', pa.struct([pa.field('value,field', pa.int32())]))
+        ))
+    ])
+
+    list_data = pa.array([
+        [{'field,name': 1, 'other{field}': 'a'}],
+        [{'field,name': 2, 'other{field}': 'b'}]
+    ], type=schema[0].type)
+
+    large_list_data = pa.array([
+        [{'nested(field)': 1.5}],
+        [{'nested(field)': 2.5}]
+    ], type=schema[1].type)
+
+    fixed_list_data = pa.array([
+        [{'special field': 10}, {'special field': 20}],
+        [{'special field': 30}, {'special field': 40}]
+    ], type=schema[2].type)
+
+    map_data = pa.array([
+        [({'key;field': 'k1'}, {'value,field': 100})],
+        [({'key;field': 'k2'}, {'value,field': 200})]
+    ], type=schema[3].type)
+
+    table = pa.Table.from_arrays(
+        [list_data, large_list_data, fixed_list_data, map_data],
+        schema=schema
+    )
 
     result = _roundtrip_table(table, write_table_kwargs={'flavor': 'spark'})
 
-    expected_name = 'prohib______'
-    assert result.schema[0].name == expected_name
+    # Check top-level field names are sanitized
+    assert result.schema[0].name == 'list_field'
+    assert result.schema[1].name == 'large_list'
+    assert result.schema[2].name == 'fixed_list'
+    assert result.schema[3].name == 'map_field'
+
+    # Check list value field's struct has sanitized names
+    list_value_type = result.schema[0].type.value_type
+    assert list_value_type[0].name == 'field_name'
+    assert list_value_type[1].name == 'other_field_'
+
+    # Check large list value field's struct has sanitized names
+    large_list_value_type = result.schema[1].type.value_type
+    assert large_list_value_type[0].name == 'nested_field_'
+
+    # Check fixed size list value field's struct has sanitized names
+    fixed_list_value_type = result.schema[2].type.value_type
+    assert fixed_list_value_type[0].name == 'special_field'
+
+    # Check map key and item structs have sanitized names
+    map_key_type = result.schema[3].type.key_type
+    map_item_type = result.schema[3].type.item_type
+    assert map_key_type[0].name == 'key_field'
+    assert map_item_type[0].name == 'value_field'
 
 
 @pytest.mark.pandas
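
Usage sketch (reviewer note, not part of the patch): the snippet below exercises the new nested-name sanitization through the public pyarrow.parquet API. The field names and the temporary file path are made up for illustration; the expected names follow the existing _SPARK_DISALLOWED_CHARS rule, which substitutes '_' for the characters " ,;{}()\n\t=", and match what the new tests assert after a flavor='spark' roundtrip.

import tempfile

import pyarrow as pa
import pyarrow.parquet as pq

# A struct column whose nested field name contains characters Spark rejects.
schema = pa.schema([
    pa.field('outer struct', pa.struct([pa.field('inner;name', pa.int32())]))
])
table = pa.Table.from_arrays(
    [pa.array([{'inner;name': 1}, {'inner;name': 2}], type=schema[0].type)],
    schema=schema)

with tempfile.TemporaryDirectory() as tmpdir:
    path = tmpdir + '/nested.parquet'
    # flavor='spark' routes the schema through _sanitize_schema, which with
    # this patch recurses into structs, lists and maps via
    # _sanitize_field_recursive instead of renaming only top-level fields.
    pq.write_table(table, path, flavor='spark')
    result = pq.read_table(path)

# With the patch applied, both levels should come back sanitized:
# 'outer struct' -> 'outer_struct' and 'inner;name' -> 'inner_name'.
print(result.schema[0].name)
print(result.schema[0].type[0].name)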