diff --git a/arrow-array/src/ffi_stream.rs b/arrow-array/src/ffi_stream.rs
index 27c020e5c08b..c46943682914 100644
--- a/arrow-array/src/ffi_stream.rs
+++ b/arrow-array/src/ffi_stream.rs
@@ -364,7 +364,9 @@ impl Iterator for ArrowArrayStreamReader {
             let result = unsafe {
                 from_ffi_and_data_type(array, DataType::Struct(self.schema().fields().clone()))
             };
-            Some(result.map(|data| RecordBatch::from(StructArray::from(data))))
+            Some(result.and_then(|data| {
+                RecordBatch::try_new(self.schema.clone(), StructArray::from(data).into_parts().1)
+            }))
         } else {
             let last_error = self.get_stream_last_error();
             let err = ArrowError::CDataInterface(last_error.unwrap());
@@ -382,6 +384,7 @@ impl RecordBatchReader for ArrowArrayStreamReader {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use std::collections::HashMap;
 
     use arrow_schema::Field;
 
@@ -417,11 +420,18 @@ mod tests {
     }
 
     fn _test_round_trip_export(arrays: Vec<Arc<dyn Array>>) -> Result<()> {
-        let schema = Arc::new(Schema::new(vec![
-            Field::new("a", arrays[0].data_type().clone(), true),
-            Field::new("b", arrays[1].data_type().clone(), true),
-            Field::new("c", arrays[2].data_type().clone(), true),
-        ]));
+        let metadata = HashMap::from([("foo".to_owned(), "bar".to_owned())]);
+        let schema = Arc::new(Schema::new_with_metadata(
+            vec![
+                Field::new("a", arrays[0].data_type().clone(), true)
+                    .with_metadata(metadata.clone()),
+                Field::new("b", arrays[1].data_type().clone(), true)
+                    .with_metadata(metadata.clone()),
+                Field::new("c", arrays[2].data_type().clone(), true)
+                    .with_metadata(metadata.clone()),
+            ],
+            metadata,
+        ));
         let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap();
         let iter = Box::new(vec![batch.clone(), batch.clone()].into_iter().map(Ok)) as _;
 
@@ -452,7 +462,11 @@ mod tests {
 
             let array = unsafe { from_ffi(ffi_array, &ffi_schema) }.unwrap();
 
-            let record_batch = RecordBatch::from(StructArray::from(array));
+            let record_batch = RecordBatch::try_new(
+                SchemaRef::from(exported_schema.clone()),
+                StructArray::from(array).into_parts().1,
+            )
+            .unwrap();
             produced_batches.push(record_batch);
         }
 
@@ -462,11 +476,18 @@ mod tests {
     }
 
     fn _test_round_trip_import(arrays: Vec<Arc<dyn Array>>) -> Result<()> {
-        let schema = Arc::new(Schema::new(vec![
-            Field::new("a", arrays[0].data_type().clone(), true),
-            Field::new("b", arrays[1].data_type().clone(), true),
-            Field::new("c", arrays[2].data_type().clone(), true),
-        ]));
+        let metadata = HashMap::from([("foo".to_owned(), "bar".to_owned())]);
+        let schema = Arc::new(Schema::new_with_metadata(
+            vec![
+                Field::new("a", arrays[0].data_type().clone(), true)
+                    .with_metadata(metadata.clone()),
+                Field::new("b", arrays[1].data_type().clone(), true)
+                    .with_metadata(metadata.clone()),
+                Field::new("c", arrays[2].data_type().clone(), true)
+                    .with_metadata(metadata.clone()),
+            ],
+            metadata,
+        ));
         let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap();
         let iter = Box::new(vec![batch.clone(), batch.clone()].into_iter().map(Ok)) as _;
 
diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py
index 79220fb6a69f..b9b04ddee509 100644
--- a/arrow-pyarrow-integration-testing/tests/test_sql.py
+++ b/arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -527,7 +527,7 @@ def test_empty_recordbatch_with_row_count():
     """
 
     # Create an empty schema with no fields
-    batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3, 4]}).select([])
+    batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3, 4]}, metadata={b'key1': b'value1'}).select([])
     num_rows = 4
     assert batch.num_rows == num_rows
     assert batch.num_columns == 0
@@ -545,7 +545,7 @@ def test_record_batch_reader():
     """
     Python -> Rust -> Python
     """
-    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
     batches = [
         pa.record_batch([[[1], [2, 42]]], schema),
         pa.record_batch([[None, [], [5, 6]]], schema),
@@ -571,7 +571,7 @@ def test_record_batch_reader_pycapsule():
     """
     Python -> Rust -> Python
     """
-    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
     batches = [
         pa.record_batch([[[1], [2, 42]]], schema),
         pa.record_batch([[None, [], [5, 6]]], schema),
@@ -621,7 +621,7 @@ def test_record_batch_pycapsule():
     """
     Python -> Rust -> Python
     """
-    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
     batch = pa.record_batch([[[1], [2, 42]]], schema)
     wrapped = StreamWrapper(batch)
     b = rust.round_trip_record_batch_reader(wrapped)
@@ -640,7 +640,7 @@ def test_table_pycapsule():
     """
     Python -> Rust -> Python
     """
-    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
     batches = [
         pa.record_batch([[[1], [2, 42]]], schema),
         pa.record_batch([[None, [], [5, 6]]], schema),
@@ -650,8 +650,9 @@ def test_table_pycapsule():
     b = rust.round_trip_record_batch_reader(wrapped)
     new_table = b.read_all()
 
-    assert table.schema == new_table.schema
     assert table == new_table
+    assert table.schema == new_table.schema
+    assert table.schema.metadata == new_table.schema.metadata
     assert len(table.to_batches()) == len(new_table.to_batches())
 
 
@@ -659,12 +660,13 @@ def test_table_empty():
     """
     Python -> Rust -> Python
     """
-    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
     table = pa.Table.from_batches([], schema=schema)
     new_table = rust.build_table([], schema=schema)
 
-    assert table.schema == new_table.schema
     assert table == new_table
+    assert table.schema == new_table.schema
+    assert table.schema.metadata == new_table.schema.metadata
     assert len(table.to_batches()) == len(new_table.to_batches())
 
 
@@ -672,7 +674,7 @@ def test_table_roundtrip():
     """
     Python -> Rust -> Python
     """
-    schema = pa.schema([('ints', pa.list_(pa.int32()))])
+    schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
     batches = [
         pa.record_batch([[[1], [2, 42]]], schema),
         pa.record_batch([[None, [], [5, 6]]], schema),
@@ -680,8 +682,9 @@ def test_table_roundtrip():
     table = pa.Table.from_batches(batches, schema=schema)
     new_table = rust.round_trip_table(table)
 
-    assert table.schema == new_table.schema
     assert table == new_table
+    assert table.schema == new_table.schema
+    assert table.schema.metadata == new_table.schema.metadata
     assert len(table.to_batches()) == len(new_table.to_batches())
 
 
@@ -689,7 +692,7 @@ def test_table_from_batches():
     """
     Python -> Rust -> Python
     """
-    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
     batches = [
         pa.record_batch([[[1], [2, 42]]], schema),
         pa.record_batch([[None, [], [5, 6]]], schema),
@@ -697,8 +700,9 @@ def test_table_from_batches():
     table = pa.Table.from_batches(batches)
     new_table = rust.build_table(batches, schema)
 
-    assert table.schema == new_table.schema
     assert table == new_table
+    assert table.schema == new_table.schema
+    assert table.schema.metadata == new_table.schema.metadata
     assert len(table.to_batches()) == len(new_table.to_batches())
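
For context: the core change above replaces RecordBatch::from(StructArray::from(data)) with RecordBatch::try_new(self.schema.clone(), ...) because DataType::Struct carries only the fields, so schema-level metadata is dropped whenever a batch is rebuilt from a StructArray. Below is a minimal, self-contained sketch (not part of the patch) contrasting the two paths; it uses only public arrow-array/arrow-schema APIs, and the "foo"/"bar" metadata simply mirrors the tests above:

use std::collections::HashMap;
use std::sync::Arc;

use arrow_array::{ArrayRef, Int32Array, RecordBatch, StructArray};
use arrow_schema::{DataType, Field, Schema};

fn main() {
    // Schema-level metadata that should survive an FFI round trip.
    let metadata = HashMap::from([("foo".to_owned(), "bar".to_owned())]);
    let schema = Arc::new(Schema::new_with_metadata(
        vec![Field::new("a", DataType::Int32, true)],
        metadata,
    ));
    let array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let batch = RecordBatch::try_new(schema.clone(), vec![array]).unwrap();

    // Old path: round-tripping through a StructArray rebuilds the schema
    // from DataType::Struct, which has no slot for schema-level metadata.
    let lossy = RecordBatch::from(StructArray::from(batch.clone()));
    assert!(lossy.schema().metadata().is_empty());

    // New path: reuse the StructArray's columns but rebuild the batch
    // against the original schema, so the metadata is preserved.
    let columns = StructArray::from(batch.clone()).into_parts().1;
    let preserved = RecordBatch::try_new(schema.clone(), columns).unwrap();
    assert_eq!(preserved.schema().metadata(), schema.metadata());
}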