From f04a0d6191f97ab072cd0e56745e427f36449e57 Mon Sep 17 00:00:00 2001
From: Jonas Dedden
Date: Wed, 3 Dec 2025 14:56:33 +0100
Subject: [PATCH] Implement `pyarrow.Table` convenience wrapper

Remove metadata tests
---
 arrow-pyarrow-integration-testing/src/lib.rs  |  24 +++-
 .../tests/test_sql.py                         | 106 ++++++++++++++--
 arrow-pyarrow/src/lib.rs                      | 113 ++++++++++++++++--
 3 files changed, 223 insertions(+), 20 deletions(-)

diff --git a/arrow-pyarrow-integration-testing/src/lib.rs b/arrow-pyarrow-integration-testing/src/lib.rs
index 7d5d63c1d50d..a5690b307040 100644
--- a/arrow-pyarrow-integration-testing/src/lib.rs
+++ b/arrow-pyarrow-integration-testing/src/lib.rs
@@ -32,7 +32,7 @@ use arrow::compute::kernels;
 use arrow::datatypes::{DataType, Field, Schema};
 use arrow::error::ArrowError;
 use arrow::ffi_stream::ArrowArrayStreamReader;
-use arrow::pyarrow::{FromPyArrow, PyArrowException, PyArrowType, ToPyArrow};
+use arrow::pyarrow::{FromPyArrow, PyArrowException, PyArrowType, Table, ToPyArrow};
 use arrow::record_batch::RecordBatch;
 
 fn to_py_err(err: ArrowError) -> PyErr {
@@ -140,6 +140,26 @@ fn round_trip_record_batch_reader(
     Ok(obj)
 }
 
+#[pyfunction]
+fn round_trip_table(obj: PyArrowType<Table>) -> PyResult<PyArrowType<Table>> {
+    Ok(obj)
+}
+
+/// Builds a Table from a list of RecordBatches and a Schema.
+#[pyfunction]
+pub fn build_table(
+    record_batches: Vec<PyArrowType<RecordBatch>>,
+    schema: PyArrowType<Schema>,
+) -> PyResult<PyArrowType<Table>> {
+    Ok(PyArrowType(
+        Table::try_new(
+            record_batches.into_iter().map(|rb| rb.0).collect(),
+            Arc::new(schema.0),
+        )
+        .map_err(to_py_err)?,
+    ))
+}
+
 #[pyfunction]
 fn reader_return_errors(obj: PyArrowType<ArrowArrayStreamReader>) -> PyResult<()> {
     // This makes sure we can correctly consume a RBR and return the error,
@@ -178,6 +198,8 @@ fn arrow_pyarrow_integration_testing(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(round_trip_array))?;
     m.add_wrapped(wrap_pyfunction!(round_trip_record_batch))?;
     m.add_wrapped(wrap_pyfunction!(round_trip_record_batch_reader))?;
+    m.add_wrapped(wrap_pyfunction!(round_trip_table))?;
+    m.add_wrapped(wrap_pyfunction!(build_table))?;
     m.add_wrapped(wrap_pyfunction!(reader_return_errors))?;
     m.add_wrapped(wrap_pyfunction!(boxed_reader_roundtrip))?;
     Ok(())
diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py
index 3b46d5729a1f..73d243f01372 100644
--- a/arrow-pyarrow-integration-testing/tests/test_sql.py
+++ b/arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -20,6 +20,7 @@
 import datetime
 import decimal
 import string
+from typing import Union, Tuple, Protocol
 
 import pytest
 import pyarrow as pa
@@ -120,28 +121,50 @@ def assert_pyarrow_leak():
 # This defines that Arrow consumers should allow any object that has specific "dunder"
 # methods, `__arrow_c_*__`. These wrapper classes ensure that arrow-rs is able to handle
 # _any_ class, without pyarrow-specific handling.
-class SchemaWrapper:
-    def __init__(self, schema):
+
+
+class ArrowSchemaExportable(Protocol):
+    def __arrow_c_schema__(self) -> object: ...
+
+
+class ArrowArrayExportable(Protocol):
+    def __arrow_c_array__(
+        self,
+        requested_schema: Union[object, None] = None
+    ) -> Tuple[object, object]:
+        ...
+
+
+class ArrowStreamExportable(Protocol):
+    def __arrow_c_stream__(
+        self,
+        requested_schema: Union[object, None] = None
+    ) -> object:
+        ...
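+
+
+# Sketch for illustration only (nothing below asserts this): the Rust side imports
+# tables via the ArrowArrayStream PyCapsule interface, so any object with a conforming
+# `__arrow_c_stream__` can be passed to the table entry points, e.g.:
+#
+#     class MyStreamSource:  # hypothetical user-defined class
+#         def __arrow_c_stream__(self, requested_schema=None):
+#             return a_stream_capsule  # hypothetical PyCapsule object
+#
+#     rust.round_trip_table(MyStreamSource())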
+
+
+class SchemaWrapper(ArrowSchemaExportable):
+    def __init__(self, schema: ArrowSchemaExportable) -> None:
         self.schema = schema
 
-    def __arrow_c_schema__(self):
+    def __arrow_c_schema__(self) -> object:
         return self.schema.__arrow_c_schema__()
 
 
-class ArrayWrapper:
-    def __init__(self, array):
+class ArrayWrapper(ArrowArrayExportable):
+    def __init__(self, array: ArrowArrayExportable) -> None:
         self.array = array
 
-    def __arrow_c_array__(self):
-        return self.array.__arrow_c_array__()
+    def __arrow_c_array__(self, requested_schema: Union[object, None] = None) -> Tuple[object, object]:
+        return self.array.__arrow_c_array__(requested_schema=requested_schema)
 
 
-class StreamWrapper:
-    def __init__(self, stream):
+class StreamWrapper(ArrowStreamExportable):
+    def __init__(self, stream: ArrowStreamExportable) -> None:
         self.stream = stream
 
-    def __arrow_c_stream__(self):
-        return self.stream.__arrow_c_stream__()
+    def __arrow_c_stream__(self, requested_schema: Union[object, None] = None) -> object:
+        return self.stream.__arrow_c_stream__(requested_schema=requested_schema)
 
 
 @pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str)
@@ -613,6 +636,67 @@ def test_table_pycapsule():
     assert len(table.to_batches()) == len(new_table.to_batches())
 
 
+def test_table_empty():
+    """
+    Python -> Rust -> Python
+    """
+    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    table = pa.Table.from_batches([], schema=schema)
+    new_table = rust.build_table([], schema=schema)
+
+    assert table.schema == new_table.schema
+    assert table == new_table
+    assert len(table.to_batches()) == len(new_table.to_batches())
+
+
+def test_table_roundtrip():
+    """
+    Python -> Rust -> Python
+    """
+    schema = pa.schema([('ints', pa.list_(pa.int32()))])
+    batches = [
+        pa.record_batch([[[1], [2, 42]]], schema),
+        pa.record_batch([[None, [], [5, 6]]], schema),
+    ]
+    table = pa.Table.from_batches(batches, schema=schema)
+    new_table = rust.round_trip_table(table)
+
+    assert table.schema == new_table.schema
+    assert table == new_table
+    assert len(table.to_batches()) == len(new_table.to_batches())
+
+
+def test_table_from_batches():
+    """
+    Python -> Rust -> Python
+    """
+    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    batches = [
+        pa.record_batch([[[1], [2, 42]]], schema),
+        pa.record_batch([[None, [], [5, 6]]], schema),
+    ]
+    table = pa.Table.from_batches(batches)
+    new_table = rust.build_table(batches, schema)
+
+    assert table.schema == new_table.schema
+    assert table == new_table
+    assert len(table.to_batches()) == len(new_table.to_batches())
+
+
+def test_table_error_inconsistent_schema():
+    """
+    Building a Table from batches with mismatched schemas should raise.
+    """
+    schema_1 = pa.schema([('ints', pa.list_(pa.int32()))])
+    schema_2 = pa.schema([('floats', pa.list_(pa.float32()))])
+    batches = [
+        pa.record_batch([[[1], [2, 42]]], schema_1),
+        pa.record_batch([[None, [], [5.6, 6.4]]], schema_2),
+    ]
+    with pytest.raises(pa.ArrowException, match="Schema error: All record batches must have the same schema."):
+        rust.build_table(batches, schema_1)
+
+
 def test_reject_other_classes():
     # Arbitrary type that is not a PyArrow type
     not_pyarrow = ["hello"]
diff --git a/arrow-pyarrow/src/lib.rs b/arrow-pyarrow/src/lib.rs
index d4bbb201f027..1f8941ef1cf5 100644
--- a/arrow-pyarrow/src/lib.rs
+++ b/arrow-pyarrow/src/lib.rs
@@ -44,17 +44,20 @@
 //! | `pyarrow.Array` | [ArrayData] |
 //! | `pyarrow.RecordBatch` | [RecordBatch] |
 //! | `pyarrow.RecordBatchReader` | [ArrowArrayStreamReader] / `Box<dyn RecordBatchReader + Send>` (1) |
+//! | `pyarrow.Table` | [Table] (2) |
 //!
 //! (1) `pyarrow.RecordBatchReader` can be imported as [ArrowArrayStreamReader]. Either
 //! [ArrowArrayStreamReader] or `Box<dyn RecordBatchReader + Send>` can be exported
 //! as `pyarrow.RecordBatchReader`. (`Box<dyn RecordBatchReader + Send>` is typically
 //! easier to create.)
 //!
-//! PyArrow has the notion of chunked arrays and tables, but arrow-rs doesn't
-//! have these same concepts. A chunked table is instead represented with
-//! `Vec<RecordBatch>`. A `pyarrow.Table` can be imported to Rust by calling
-//! [pyarrow.Table.to_reader()](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_reader)
-//! and then importing the reader as a [ArrowArrayStreamReader].
+//! (2) [Table] is a convenience wrapper around `Vec<RecordBatch>` for converting to and from
+//! [pyarrow.Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table). It is meant
+//! primarily for cases where you already have a `Vec<RecordBatch>` on the Rust side and want
+//! to export it in bulk as a `pyarrow.Table`. In general, prefer streaming over handling data
+//! in bulk: a `pyarrow.Table` (or any other object implementing the ArrowArrayStream PyCapsule
+//! interface) can be imported through `PyArrowType<ArrowArrayStreamReader>` instead of forcing
+//! an eager read into `Vec<RecordBatch>`.
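+//!
+//! A sketch of that streaming approach (`process_stream` is illustrative, not an API of
+//! this crate):
+//!
+//! ```ignore
+//! #[pyfunction]
+//! fn process_stream(stream: PyArrowType<ArrowArrayStreamReader>) -> PyResult<()> {
+//!     // Handle each RecordBatch as it arrives instead of buffering the whole table.
+//!     for batch in stream.0 {
+//!         let batch = batch.map_err(|err| PyErr::new::<PyArrowException, _>(err.to_string()))?;
+//!         // ... process `batch` ...
+//!     }
+//!     Ok(())
+//! }
+//! ```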
 
 use std::convert::{From, TryFrom};
 use std::ptr::{addr_of, addr_of_mut};
@@ -68,13 +71,13 @@ use arrow_array::{
     make_array,
 };
 use arrow_data::ArrayData;
-use arrow_schema::{ArrowError, DataType, Field, Schema};
+use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
 use pyo3::exceptions::{PyTypeError, PyValueError};
 use pyo3::ffi::Py_uintptr_t;
-use pyo3::import_exception;
 use pyo3::prelude::*;
 use pyo3::pybacked::PyBackedStr;
-use pyo3::types::{PyCapsule, PyList, PyTuple};
+use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple};
+use pyo3::{import_exception, intern};
 
 import_exception!(pyarrow, ArrowException);
 /// Represents an exception raised by PyArrow.
@@ -484,6 +487,100 @@ impl IntoPyArrow for ArrowArrayStreamReader {
     }
 }
 
+/// A convenience wrapper around `Vec<RecordBatch>` that simplifies conversion to and from
+/// `pyarrow.Table`.
+///
+/// Use this when you want to consume a `pyarrow.Table` directly (although, since
+/// `pyarrow.Table` implements the ArrowArrayStream PyCapsule interface, consuming a
+/// `PyArrowType<ArrowArrayStreamReader>` also works) or, more importantly, when you want to
+/// export a `Vec<RecordBatch>` from the Rust side as a `pyarrow.Table`.
+///
+/// ```ignore
+/// #[pyfunction]
+/// fn return_table(...) -> PyResult<PyArrowType<Table>> {
+///     let batches: Vec<RecordBatch> = ...;
+///     let schema: SchemaRef = ...;
+///     Ok(PyArrowType(
+///         Table::try_new(batches, schema)
+///             .map_err(|err| PyErr::new::<PyArrowException, _>(err.to_string()))?,
+///     ))
+/// }
+/// ```
+#[derive(Clone)]
+pub struct Table {
+    record_batches: Vec<RecordBatch>,
+    schema: SchemaRef,
+}
+
+impl Table {
+    pub fn try_new(
+        record_batches: Vec<RecordBatch>,
+        schema: SchemaRef,
+    ) -> Result<Self, ArrowError> {
+        for record_batch in &record_batches {
+            if schema != record_batch.schema() {
+                return Err(ArrowError::SchemaError(format!(
+                    "All record batches must have the same schema. \
+                     Expected schema: {:?}, got schema: {:?}",
+                    schema,
+                    record_batch.schema()
+                )));
+            }
+        }
+        Ok(Self {
+            record_batches,
+            schema,
+        })
+    }
+
+    pub fn record_batches(&self) -> &[RecordBatch] {
+        &self.record_batches
+    }
+
+    pub fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+
+    pub fn into_inner(self) -> (Vec<RecordBatch>, SchemaRef) {
+        (self.record_batches, self.schema)
+    }
+}
+
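+/// Collects all batches from a [`RecordBatchReader`] into a [`Table`].
+///
+/// A usage sketch (the `reader` binding is illustrative):
+///
+/// ```ignore
+/// let reader: Box<dyn RecordBatchReader + Send> = ...;
+/// let table = Table::try_from(reader)?;
+/// ```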
+impl TryFrom<Box<dyn RecordBatchReader + Send>> for Table {
+    type Error = ArrowError;
+
+    fn try_from(value: Box<dyn RecordBatchReader + Send>) -> Result<Self, Self::Error> {
+        let schema = value.schema();
+        let batches = value.collect::<Result<Vec<_>, _>>()?;
+        Self::try_new(batches, schema)
+    }
+}
+
+/// Convert a `pyarrow.Table` (or any other ArrowArrayStream compliant object) into a [`Table`].
+impl FromPyArrow for Table {
+    fn from_pyarrow_bound(ob: &Bound<PyAny>) -> PyResult<Self> {
+        let reader: Box<dyn RecordBatchReader + Send> =
+            Box::new(ArrowArrayStreamReader::from_pyarrow_bound(ob)?);
+        Self::try_from(reader).map_err(|err| PyErr::new::<PyArrowException, _>(err.to_string()))
+    }
+}
+
+/// Convert a [`Table`] into a `pyarrow.Table`.
+impl IntoPyArrow for Table {
+    fn into_pyarrow(self, py: Python) -> PyResult<PyObject> {
+        let module = py.import(intern!(py, "pyarrow"))?;
+        let class = module.getattr(intern!(py, "Table"))?;
+
+        let py_batches = PyList::new(py, self.record_batches.into_iter().map(PyArrowType))?;
+        let py_schema = PyArrowType(Arc::unwrap_or_clone(self.schema));
+
+        let kwargs = PyDict::new(py);
+        kwargs.set_item("schema", py_schema)?;
+
+        let table = class.call_method("from_batches", (py_batches,), Some(&kwargs))?;
+
+        Ok(table.unbind())
+    }
+}
+
 /// A newtype wrapper for types implementing [`FromPyArrow`] or [`IntoPyArrow`].
 ///
 /// When wrapped around a type `T: FromPyArrow`, it