23 changes: 22 additions & 1 deletion arrow-pyarrow-integration-testing/src/lib.rs
@@ -32,7 +32,7 @@ use arrow::compute::kernels;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::pyarrow::{FromPyArrow, PyArrowException, PyArrowType, ToPyArrow};
use arrow::pyarrow::{FromPyArrow, PyArrowException, PyArrowType, Table, ToPyArrow};
use arrow::record_batch::RecordBatch;

fn to_py_err(err: ArrowError) -> PyErr {
@@ -140,6 +140,25 @@ fn round_trip_record_batch_reader(
Ok(obj)
}

#[pyfunction]
fn round_trip_table(obj: PyArrowType<Table>) -> PyResult<PyArrowType<Table>> {
Ok(obj)
}

#[pyfunction]
pub fn build_table(
record_batches: Vec<PyArrowType<RecordBatch>>,
schema: PyArrowType<Schema>,
) -> PyResult<PyArrowType<Table>> {
Ok(PyArrowType(
Table::try_new(
record_batches.into_iter().map(|rb| rb.0).collect(),
Arc::new(schema.0),
)
.map_err(to_py_err)?,
))
}

#[pyfunction]
fn reader_return_errors(obj: PyArrowType<ArrowArrayStreamReader>) -> PyResult<()> {
// This makes sure we can correctly consume a RBR and return the error,
@@ -178,6 +197,8 @@ fn arrow_pyarrow_integration_testing(_py: Python, m: &Bound<PyModule>) -> PyResu
m.add_wrapped(wrap_pyfunction!(round_trip_array))?;
m.add_wrapped(wrap_pyfunction!(round_trip_record_batch))?;
m.add_wrapped(wrap_pyfunction!(round_trip_record_batch_reader))?;
m.add_wrapped(wrap_pyfunction!(round_trip_table))?;
m.add_wrapped(wrap_pyfunction!(build_table))?;
m.add_wrapped(wrap_pyfunction!(reader_return_errors))?;
m.add_wrapped(wrap_pyfunction!(boxed_reader_roundtrip))?;
Ok(())
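For the export direction that `build_table` and `round_trip_table` exercise, here is a minimal sketch of a downstream pyfunction (the function name and column contents are hypothetical, assuming the `Table` API introduced in this PR) that assembles batches on the Rust side and returns them as a `pyarrow.Table`:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::pyarrow::{PyArrowType, Table};
use arrow::record_batch::RecordBatch;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

// Hypothetical example: build two record batches in Rust and hand them
// to Python as a single `pyarrow.Table` via the new wrapper type.
#[pyfunction]
fn make_example_table() -> PyResult<PyArrowType<Table>> {
    let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
    let chunk1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let chunk2: ArrayRef = Arc::new(Int32Array::from(vec![4, 5]));
    let batches = vec![
        RecordBatch::try_new(schema.clone(), vec![chunk1]),
        RecordBatch::try_new(schema.clone(), vec![chunk2]),
    ]
    .into_iter()
    .collect::<Result<Vec<_>, _>>()
    .map_err(|e| PyValueError::new_err(e.to_string()))?;
    Table::try_new(batches, schema)
        .map(PyArrowType)
        .map_err(|e| PyValueError::new_err(e.to_string()))
}
```

From Python this comes back as an ordinary `pyarrow.Table` with two chunks.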
65 changes: 65 additions & 0 deletions arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -613,6 +613,71 @@ def test_table_pycapsule():
assert len(table.to_batches()) == len(new_table.to_batches())


def test_table_empty():
"""
Python -> Rust -> Python
"""
schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
table = pa.Table.from_batches([], schema=schema)
new_table = rust.build_table([], schema=schema)

assert table.schema == new_table.schema
assert table == new_table
assert len(table.to_batches()) == len(new_table.to_batches())


def test_table_roundtrip():
"""
Python -> Rust -> Python
"""
metadata = {b'key1': b'value1'}
schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata=metadata)
batches = [
pa.record_batch([[[1], [2, 42]]], schema),
pa.record_batch([[None, [], [5, 6]]], schema),
]
table = pa.Table.from_batches(batches, schema=schema)
# TODO: Remove these `assert`s as soon as the metadata issue is solved in Rust
assert table.schema.metadata == metadata
assert all(batch.schema.metadata == metadata for batch in table.to_batches())
new_table = rust.round_trip_table(table)

assert table.schema == new_table.schema
assert table == new_table
assert len(table.to_batches()) == len(new_table.to_batches())


def test_table_from_batches():
"""
Python -> Rust -> Python
"""
schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
batches = [
pa.record_batch([[[1], [2, 42]]], schema),
pa.record_batch([[None, [], [5, 6]]], schema),
]
table = pa.Table.from_batches(batches)
new_table = rust.build_table(batches, schema)

assert table.schema == new_table.schema
assert table == new_table
assert len(table.to_batches()) == len(new_table.to_batches())


def test_table_error_inconsistent_schema():
"""
Python -> Rust -> Python
"""
schema_1 = pa.schema([('ints', pa.list_(pa.int32()))])
schema_2 = pa.schema([('floats', pa.list_(pa.float32()))])
batches = [
pa.record_batch([[[1], [2, 42]]], schema_1),
pa.record_batch([[None, [], [5.6, 6.4]]], schema_2),
]
with pytest.raises(pa.ArrowException, match="Schema error: All record batches must have the same schema."):
rust.build_table(batches, schema_1)


def test_reject_other_classes():
# Arbitrary type that is not a PyArrow type
not_pyarrow = ["hello"]
133 changes: 125 additions & 8 deletions arrow-pyarrow/src/lib.rs
@@ -44,17 +44,20 @@
//! | `pyarrow.Array` | [ArrayData] |
//! | `pyarrow.RecordBatch` | [RecordBatch] |
//! | `pyarrow.RecordBatchReader` | [ArrowArrayStreamReader] / `Box<dyn RecordBatchReader + Send>` (1) |
//! | `pyarrow.Table` | [Table] (2) |
//!
//! (1) `pyarrow.RecordBatchReader` can be imported as [ArrowArrayStreamReader]. Either
//! [ArrowArrayStreamReader] or `Box<dyn RecordBatchReader + Send>` can be exported
//! as `pyarrow.RecordBatchReader`. (`Box<dyn RecordBatchReader + Send>` is typically
//! easier to create.)
//!
//! PyArrow has the notion of chunked arrays and tables, but arrow-rs doesn't
//! have these same concepts. A chunked table is instead represented with
//! `Vec<RecordBatch>`. A `pyarrow.Table` can be imported to Rust by calling
//! [pyarrow.Table.to_reader()](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_reader)
//! and then importing the reader as a [ArrowArrayStreamReader].
//! (2) Although arrow-rs offers [Table], a convenience wrapper for [pyarrow.Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table)
//! that internally holds `Vec<RecordBatch>`, it is meant primarily for use cases where you already
//! have `Vec<RecordBatch>` on the Rust side and want to export that in bulk as a `pyarrow.Table`.
//! In general, it is recommended to use streaming approaches instead of dealing with data in bulk.
//! For example, a `pyarrow.Table` (or any other object that implements the ArrayStream PyCapsule
//! interface) can be imported to Rust through `PyArrowType<ArrowArrayStreamReader>>` instead of
Review comment: Minor doc typo: the generic in PyArrowType<ArrowArrayStreamReader>> has an extra closing >; should be PyArrowType<ArrowArrayStreamReader>.
//! forcing eager reading into `Vec<RecordBatch>`.
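As a concrete sketch of that recommendation (the function name is hypothetical, relying only on the `ArrowArrayStreamReader` support that already exists), a pyfunction can take the stream reader and consume batches one at a time:

```rust
use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::pyarrow::PyArrowType;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

// Hypothetical example: accept a `pyarrow.Table` (or anything else that
// exports the ArrowArrayStream PyCapsule) and consume it batch by batch,
// without collecting everything into a `Vec<RecordBatch>` first.
#[pyfunction]
fn count_rows(reader: PyArrowType<ArrowArrayStreamReader>) -> PyResult<usize> {
    let mut rows = 0;
    for batch in reader.0 {
        let batch = batch.map_err(|e| PyValueError::new_err(e.to_string()))?;
        rows += batch.num_rows();
    }
    Ok(rows)
}
```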

use std::convert::{From, TryFrom};
use std::ptr::{addr_of, addr_of_mut};
@@ -68,13 +68,13 @@ use arrow_array::{
make_array,
};
use arrow_data::ArrayData;
use arrow_schema::{ArrowError, DataType, Field, Schema};
use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::ffi::Py_uintptr_t;
use pyo3::import_exception;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::{PyCapsule, PyList, PyTuple};
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple};
use pyo3::{import_exception, intern};

import_exception!(pyarrow, ArrowException);
/// Represents an exception raised by PyArrow.
@@ -484,6 +487,120 @@ impl IntoPyArrow for ArrowArrayStreamReader {
}
}

/// This is a convenience wrapper around `Vec<RecordBatch>` that tries to simplify conversion from
/// and to `pyarrow.Table`.
///
/// This could be used in circumstances where you either want to consume a `pyarrow.Table` directly
/// (although technically, since `pyarrow.Table` implements the ArrayStreamReader PyCapsule
/// interface, one could also consume a `PyArrowType<ArrowArrayStreamReader>` instead) or, more
/// importantly, where one wants to export a `pyarrow.Table` from a `Vec<RecordBatch>` from the Rust
/// side.
///
```ignore
#[pyfunction]
fn return_table(...) -> PyResult<PyArrowType<Table>> {
    let batches: Vec<RecordBatch>;
    let schema: SchemaRef;
    Ok(PyArrowType(
        Table::try_new(batches, schema)
            .map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))?,
    ))
}
```
#[derive(Clone)]
pub struct Table {
record_batches: Vec<RecordBatch>,
schema: SchemaRef,
}

impl Table {
pub fn try_new(
record_batches: Vec<RecordBatch>,
schema: SchemaRef,
) -> Result<Self, ArrowError> {
/// This function was copied from `pyo3_arrow/utils.rs` for now. I don't understand yet why
/// this is required instead of a "normal" `schema == record_batch.schema()` check.
///
/// TODO: Either remove this check, replace it with something already existing in `arrow-rs`
/// or move it to a central `utils` location.
fn schema_equals(left: &SchemaRef, right: &SchemaRef) -> bool {
left.fields
.iter()
.zip(right.fields.iter())
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The schema_equals helper zips field iterators without checking field counts, so schemas with differing numbers of fields could incorrectly be treated as equal; consider verifying left.fields.len() == right.fields.len() before zipping. This would prevent accepting mismatched schemas in Table::try_new.

🤖 Was this useful? React with 👍 or 👎

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

value:useful; category:bug; feedback:The Augment AI reviewer is correct that the implementation of the equals function is not correct! The finding prevents a wrong result when of the schemas has more fields than the other.

.all(|(left_field, right_field)| {
left_field.name() == right_field.name()
&& left_field
.data_type()
.equals_datatype(right_field.data_type())
})
}
Review comment (Bugbot): Bug: Incomplete Field Validation

The schema_equals function incorrectly returns true when comparing schemas with different field counts if all fields in the shorter schema match. The zip iterator stops at the shorter of the two iterators, so comparing a 3-field schema against a 2-field schema only checks the first 2 fields. Schemas with different field counts are not equal and should fail validation.

Owner Author: value:useful; category:bug; feedback: The Bugbot AI reviewer is correct that the implementation of the equals function is not correct! The finding prevents a wrong result when one of the schemas has more fields than the other.
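A minimal standalone illustration of the reported behavior (hypothetical snippet, written against a free-standing copy of the helper, since `schema_equals` is nested inside `try_new`):

```rust
use std::sync::Arc;

use arrow_schema::{DataType, Field, Schema, SchemaRef};

// Free-standing copy of the helper under discussion.
fn schema_equals(left: &SchemaRef, right: &SchemaRef) -> bool {
    left.fields
        .iter()
        .zip(right.fields.iter())
        .all(|(l, r)| l.name() == r.name() && l.data_type().equals_datatype(r.data_type()))
}

fn main() {
    let short: SchemaRef = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let long: SchemaRef = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, false),
    ]));
    // `zip` stops at the shorter iterator, so the extra field is never compared:
    assert!(schema_equals(&short, &long)); // passes, even though the schemas differ
}
```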


for record_batch in &record_batches {
if !schema_equals(&schema, &record_batch.schema()) {
return Err(ArrowError::SchemaError(
//"All record batches must have the same schema.".to_owned(),
format!(
"All record batches must have the same schema. \
Expected schema: {:?}, got schema: {:?}",
schema,
record_batch.schema()
Comment on lines +523 to +543

coderabbitai bot (Nov 10, 2025): ⚠️ Potential issue | 🟠 Major

Add field-count equality check in schema_equals

schema_equals zips the field iterators but never verifies they have the same length, so a batch with fewer columns than the provided schema still returns true and Table::try_new accepts it. That yields a Table whose stored schema disagrees with the batches and will fail later when exporting to PyArrow. Please reject mismatched field counts before the zip comparison.

         fn schema_equals(left: &SchemaRef, right: &SchemaRef) -> bool {
+            if left.fields.len() != right.fields.len() {
+                return false;
+            }
             left.fields
                 .iter()
                 .zip(right.fields.iter())

Owner Author: value:useful; category:bug; feedback: The CodeRabbit AI reviewer is correct that the implementation of the equals function is not correct! The finding prevents a wrong result when one of the schemas has more fields than the other.

coderabbitai bot: @martin-augment Thank you for confirming the issue! I'm glad the review caught this bug before it could cause problems downstream when exporting to PyArrow. The field-count check will ensure that only truly equivalent schemas are accepted by Table::try_new.

),
));
}
}
Ok(Self {
record_batches,
schema,
})
}

pub fn record_batches(&self) -> &[RecordBatch] {
&self.record_batches
}

pub fn schema(&self) -> SchemaRef {
self.schema.clone()
}

pub fn into_inner(self) -> (Vec<RecordBatch>, SchemaRef) {
(self.record_batches, self.schema)
}
}
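A small usage sketch of the accessors above (the helper name is hypothetical):

```rust
use arrow::pyarrow::Table;

// Hypothetical helper: total row count across all chunks of a `Table`.
fn total_rows(table: &Table) -> usize {
    table.record_batches().iter().map(|rb| rb.num_rows()).sum()
}
```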

impl TryFrom<Box<dyn RecordBatchReader>> for Table {
type Error = ArrowError;

fn try_from(value: Box<dyn RecordBatchReader>) -> Result<Self, ArrowError> {
let schema = value.schema();
let batches = value.collect::<Result<Vec<_>, _>>()?;
Self::try_new(batches, schema)
}
}

/// Convert a `pyarrow.Table` (or any other ArrowArrayStream compliant object) into [`Table`]
impl FromPyArrow for Table {
fn from_pyarrow_bound(ob: &Bound<PyAny>) -> PyResult<Self> {
let reader: Box<dyn RecordBatchReader> =
Box::new(ArrowArrayStreamReader::from_pyarrow_bound(ob)?);
Self::try_from(reader).map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))
}
}

/// Convert a [`Table`] into `pyarrow.Table`.
impl IntoPyArrow for Table {
fn into_pyarrow(self, py: Python) -> PyResult<Bound<PyAny>> {
let module = py.import(intern!(py, "pyarrow"))?;
let class = module.getattr(intern!(py, "Table"))?;

let py_batches = PyList::new(py, self.record_batches.into_iter().map(PyArrowType))?;
let py_schema = PyArrowType(Arc::unwrap_or_clone(self.schema));

let kwargs = PyDict::new(py);
kwargs.set_item("schema", py_schema)?;

let reader = class.call_method("from_batches", (py_batches,), Some(&kwargs))?;

Ok(reader)
}
}

/// A newtype wrapper for types implementing [`FromPyArrow`] or [`IntoPyArrow`].
///
/// When wrapped around a type `T: FromPyArrow`, it