diff --git a/arrow-pyarrow-integration-testing/src/lib.rs b/arrow-pyarrow-integration-testing/src/lib.rs
index 7d5d63c1d50d..a5690b307040 100644
--- a/arrow-pyarrow-integration-testing/src/lib.rs
+++ b/arrow-pyarrow-integration-testing/src/lib.rs
@@ -32,7 +32,7 @@ use arrow::compute::kernels;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::ffi_stream::ArrowArrayStreamReader;
-use arrow::pyarrow::{FromPyArrow, PyArrowException, PyArrowType, ToPyArrow};
+use arrow::pyarrow::{FromPyArrow, PyArrowException, PyArrowType, Table, ToPyArrow};
use arrow::record_batch::RecordBatch;
fn to_py_err(err: ArrowError) -> PyErr {
@@ -140,6 +140,26 @@ fn round_trip_record_batch_reader(
Ok(obj)
}
+#[pyfunction]
+fn round_trip_table(obj: PyArrowType<Table>) -> PyResult<PyArrowType<Table>> {
+ Ok(obj)
+}
+
+/// Builds a Table from a list of RecordBatches and a Schema.
+#[pyfunction]
+pub fn build_table(
+ record_batches: Vec<PyArrowType<RecordBatch>>,
+ schema: PyArrowType<Schema>,
+) -> PyResult<PyArrowType<Table>> {
+ Ok(PyArrowType(
+ Table::try_new(
+ record_batches.into_iter().map(|rb| rb.0).collect(),
+ Arc::new(schema.0),
+ )
+ .map_err(to_py_err)?,
+ ))
+}
+
#[pyfunction]
fn reader_return_errors(obj: PyArrowType<ArrowArrayStreamReader>) -> PyResult<()> {
// This makes sure we can correctly consume a RBR and return the error,
@@ -178,6 +198,8 @@ fn arrow_pyarrow_integration_testing(_py: Python, m: &Bound<PyModule>) -> PyResu
m.add_wrapped(wrap_pyfunction!(round_trip_array))?;
m.add_wrapped(wrap_pyfunction!(round_trip_record_batch))?;
m.add_wrapped(wrap_pyfunction!(round_trip_record_batch_reader))?;
+ m.add_wrapped(wrap_pyfunction!(round_trip_table))?;
+ m.add_wrapped(wrap_pyfunction!(build_table))?;
m.add_wrapped(wrap_pyfunction!(reader_return_errors))?;
m.add_wrapped(wrap_pyfunction!(boxed_reader_roundtrip))?;
Ok(())
diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py
index 3b46d5729a1f..73d243f01372 100644
--- a/arrow-pyarrow-integration-testing/tests/test_sql.py
+++ b/arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -20,6 +20,7 @@
import datetime
import decimal
import string
+from typing import Union, Tuple, Protocol
import pytest
import pyarrow as pa
@@ -120,28 +121,50 @@ def assert_pyarrow_leak():
# This defines that Arrow consumers should allow any object that has specific "dunder"
# methods, `__arrow_c_*_`. These wrapper classes ensure that arrow-rs is able to handle
# _any_ class, without pyarrow-specific handling.
-class SchemaWrapper:
- def __init__(self, schema):
+
+
+class ArrowSchemaExportable(Protocol):
+ def __arrow_c_schema__(self) -> object: ...
+
+
+class ArrowArrayExportable(Protocol):
+ def __arrow_c_array__(
+ self,
+ requested_schema: Union[object, None] = None
+ ) -> Tuple[object, object]:
+ ...
+
+
+class ArrowStreamExportable(Protocol):
+ def __arrow_c_stream__(
+ self,
+ requested_schema: Union[object, None] = None
+ ) -> object:
+ ...
+
+
+class SchemaWrapper(ArrowSchemaExportable):
+ def __init__(self, schema: ArrowSchemaExportable) -> None:
self.schema = schema
- def __arrow_c_schema__(self):
+ def __arrow_c_schema__(self) -> object:
return self.schema.__arrow_c_schema__()
-class ArrayWrapper:
- def __init__(self, array):
+class ArrayWrapper(ArrowArrayExportable):
+ def __init__(self, array: ArrowArrayExportable) -> None:
self.array = array
- def __arrow_c_array__(self):
- return self.array.__arrow_c_array__()
+ def __arrow_c_array__(self, requested_schema: Union[object, None] = None) -> Tuple[object, object]:
+ return self.array.__arrow_c_array__(requested_schema=requested_schema)
-class StreamWrapper:
- def __init__(self, stream):
+class StreamWrapper(ArrowStreamExportable):
+ def __init__(self, stream: ArrowStreamExportable) -> None:
self.stream = stream
- def __arrow_c_stream__(self):
- return self.stream.__arrow_c_stream__()
+ def __arrow_c_stream__(self, requested_schema: Union[object, None] = None) -> object:
+ return self.stream.__arrow_c_stream__(requested_schema=requested_schema)
@pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str)
@@ -613,6 +636,67 @@ def test_table_pycapsule():
assert len(table.to_batches()) == len(new_table.to_batches())
+def test_table_empty():
+ """
+ Python -> Rust -> Python
+ """
+ schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+ table = pa.Table.from_batches([], schema=schema)
+ new_table = rust.build_table([], schema=schema)
+
+ assert table.schema == new_table.schema
+ assert table == new_table
+ assert len(table.to_batches()) == len(new_table.to_batches())
+
+
+def test_table_roundtrip():
+ """
+ Python -> Rust -> Python
+ """
+ schema = pa.schema([('ints', pa.list_(pa.int32()))])
+ batches = [
+ pa.record_batch([[[1], [2, 42]]], schema),
+ pa.record_batch([[None, [], [5, 6]]], schema),
+ ]
+ table = pa.Table.from_batches(batches, schema=schema)
+ new_table = rust.round_trip_table(table)
+
+ assert table.schema == new_table.schema
+ assert table == new_table
+ assert len(table.to_batches()) == len(new_table.to_batches())
+
+
+def test_table_from_batches():
+ """
+ Python -> Rust -> Python
+ """
+ schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+ batches = [
+ pa.record_batch([[[1], [2, 42]]], schema),
+ pa.record_batch([[None, [], [5, 6]]], schema),
+ ]
+ table = pa.Table.from_batches(batches)
+ new_table = rust.build_table(batches, schema)
+
+ assert table.schema == new_table.schema
+ assert table == new_table
+ assert len(table.to_batches()) == len(new_table.to_batches())
+
+
+def test_table_error_inconsistent_schema():
+ """
+ Python -> Rust -> Python
+ """
+ schema_1 = pa.schema([('ints', pa.list_(pa.int32()))])
+ schema_2 = pa.schema([('floats', pa.list_(pa.float32()))])
+ batches = [
+ pa.record_batch([[[1], [2, 42]]], schema_1),
+ pa.record_batch([[None, [], [5.6, 6.4]]], schema_2),
+ ]
+ with pytest.raises(pa.ArrowException, match="Schema error: All record batches must have the same schema."):
+ rust.build_table(batches, schema_1)
+
+
def test_reject_other_classes():
# Arbitrary type that is not a PyArrow type
not_pyarrow = ["hello"]
diff --git a/arrow-pyarrow/src/lib.rs b/arrow-pyarrow/src/lib.rs
index d4bbb201f027..1f8941ef1cf5 100644
--- a/arrow-pyarrow/src/lib.rs
+++ b/arrow-pyarrow/src/lib.rs
@@ -44,17 +44,20 @@
//! | `pyarrow.Array` | [ArrayData] |
//! | `pyarrow.RecordBatch` | [RecordBatch] |
//! | `pyarrow.RecordBatchReader` | [ArrowArrayStreamReader] / `Box<dyn RecordBatchReader + Send>` (1) |
+//! | `pyarrow.Table` | [Table] (2) |
//!
//! (1) `pyarrow.RecordBatchReader` can be imported as [ArrowArrayStreamReader]. Either
//! [ArrowArrayStreamReader] or `Box<dyn RecordBatchReader + Send>` can be exported
//! as `pyarrow.RecordBatchReader`. (`Box<dyn RecordBatchReader + Send>` is typically
//! easier to create.)
//!
-//! PyArrow has the notion of chunked arrays and tables, but arrow-rs doesn't
-//! have these same concepts. A chunked table is instead represented with
-//! `Vec<RecordBatch>`. A `pyarrow.Table` can be imported to Rust by calling
-//! [pyarrow.Table.to_reader()](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_reader)
-//! and then importing the reader as a [ArrowArrayStreamReader].
+//! (2) Although arrow-rs offers [Table], a convenience wrapper for [pyarrow.Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table)
+//! that internally holds `Vec<RecordBatch>`, it is meant primarily for use cases where you already
+//! have `Vec<RecordBatch>` on the Rust side and want to export that in bulk as a `pyarrow.Table`.
+//! In general, it is recommended to use streaming approaches instead of dealing with data in bulk.
+//! For example, a `pyarrow.Table` (or any other object that implements the ArrayStream PyCapsule
+//! interface) can be imported to Rust through `PyArrowType<ArrowArrayStreamReader>` instead of
+//! forcing eager reading into `Vec<RecordBatch>`.
use std::convert::{From, TryFrom};
use std::ptr::{addr_of, addr_of_mut};
@@ -68,13 +71,13 @@ use arrow_array::{
make_array,
};
use arrow_data::ArrayData;
-use arrow_schema::{ArrowError, DataType, Field, Schema};
+use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::ffi::Py_uintptr_t;
-use pyo3::import_exception;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
-use pyo3::types::{PyCapsule, PyList, PyTuple};
+use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple};
+use pyo3::{import_exception, intern};
import_exception!(pyarrow, ArrowException);
/// Represents an exception raised by PyArrow.
@@ -484,6 +487,100 @@ impl IntoPyArrow for ArrowArrayStreamReader {
}
}
+/// This is a convenience wrapper around `Vec<RecordBatch>` that tries to simplify conversion from
+/// and to `pyarrow.Table`.
+///
+/// This could be used in circumstances where you either want to consume a `pyarrow.Table` directly
+/// (although technically, since `pyarrow.Table` implements the ArrayStreamReader PyCapsule
+/// interface, one could also consume a `PyArrowType<ArrowArrayStreamReader>` instead) or, more
+/// importantly, where one wants to export a `pyarrow.Table` from a `Vec<RecordBatch>` from the Rust
+/// side.
+///
+/// ```ignore
+/// #[pyfunction]
+/// fn return_table(...) -> PyResult<PyArrowType<Table>> {
+/// let batches: Vec<RecordBatch>;
+/// let schema: SchemaRef;
+/// PyArrowType(Table::try_new(batches, schema).map_err(|err| err.into_py_err(py))?)
+/// }
+/// ```
+#[derive(Clone)]
+pub struct Table {
 record_batches: Vec<RecordBatch>,
+ schema: SchemaRef,
+}
+
+impl Table {
+ pub fn try_new(
+ record_batches: Vec<RecordBatch>,
+ schema: SchemaRef,
+ ) -> Result<Self, ArrowError> {
+ for record_batch in &record_batches {
+ if schema != record_batch.schema() {
+ return Err(ArrowError::SchemaError(format!(
+ "All record batches must have the same schema. \
+ Expected schema: {:?}, got schema: {:?}",
+ schema,
+ record_batch.schema()
+ )));
+ }
+ }
+ Ok(Self {
+ record_batches,
+ schema,
+ })
+ }
+
+ pub fn record_batches(&self) -> &[RecordBatch] {
+ &self.record_batches
+ }
+
+ pub fn schema(&self) -> SchemaRef {
+ self.schema.clone()
+ }
+
+ pub fn into_inner(self) -> (Vec<RecordBatch>, SchemaRef) {
+ (self.record_batches, self.schema)
+ }
+}
+
+impl TryFrom<Box<dyn RecordBatchReader + Send>> for Table {
+ type Error = ArrowError;
+
+ fn try_from(value: Box<dyn RecordBatchReader + Send>) -> Result<Self, Self::Error> {
+ let schema = value.schema();
+ let batches = value.collect::<Result<Vec<_>, _>>()?;
+ Self::try_new(batches, schema)
+ }
+}
+
+/// Convert a `pyarrow.Table` (or any other ArrowArrayStream compliant object) into [`Table`]
+impl FromPyArrow for Table {
+ fn from_pyarrow_bound(ob: &Bound<PyAny>) -> PyResult<Self> {
+ let reader: Box<dyn RecordBatchReader + Send> =
+ Box::new(ArrowArrayStreamReader::from_pyarrow_bound(ob)?);
+ Self::try_from(reader).map_err(|err| PyErr::new::<PyArrowException, _>(err.to_string()))
+ }
+}
+
+/// Convert a [`Table`] into `pyarrow.Table`.
+impl IntoPyArrow for Table {
+ fn into_pyarrow(self, py: Python) -> PyResult<Py<PyAny>> {
+ let module = py.import(intern!(py, "pyarrow"))?;
+ let class = module.getattr(intern!(py, "Table"))?;
+
+ let py_batches = PyList::new(py, self.record_batches.into_iter().map(PyArrowType))?;
+ let py_schema = PyArrowType(Arc::unwrap_or_clone(self.schema));
+
+ let kwargs = PyDict::new(py);
+ kwargs.set_item("schema", py_schema)?;
+
+ let reader = class.call_method("from_batches", (py_batches,), Some(&kwargs))?;
+
+ Ok(reader)
+ }
+}
+
/// A newtype wrapper for types implementing [`FromPyArrow`] or [`IntoPyArrow`].
///
/// When wrapped around a type `T: FromPyArrow`, it