-
Notifications
You must be signed in to change notification settings - Fork 0
8790: Implement a Vec<RecordBatch> wrapper for pyarrow.Table convenience
#14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,17 +44,20 @@ | |
| //! | `pyarrow.Array` | [ArrayData] | | ||
| //! | `pyarrow.RecordBatch` | [RecordBatch] | | ||
| //! | `pyarrow.RecordBatchReader` | [ArrowArrayStreamReader] / `Box<dyn RecordBatchReader + Send>` (1) | | ||
| //! | `pyarrow.Table` | [Table] (2) | | ||
| //! | ||
| //! (1) `pyarrow.RecordBatchReader` can be imported as [ArrowArrayStreamReader]. Either | ||
| //! [ArrowArrayStreamReader] or `Box<dyn RecordBatchReader + Send>` can be exported | ||
| //! as `pyarrow.RecordBatchReader`. (`Box<dyn RecordBatchReader + Send>` is typically | ||
| //! easier to create.) | ||
| //! | ||
| //! PyArrow has the notion of chunked arrays and tables, but arrow-rs doesn't | ||
| //! have these same concepts. A chunked table is instead represented with | ||
| //! `Vec<RecordBatch>`. A `pyarrow.Table` can be imported to Rust by calling | ||
| //! [pyarrow.Table.to_reader()](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_reader) | ||
| //! and then importing the reader as a [ArrowArrayStreamReader]. | ||
| //! (2) Although arrow-rs offers [Table], a convenience wrapper for [pyarrow.Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table) | ||
| //! that internally holds `Vec<RecordBatch>`, it is meant primarily for use cases where you already | ||
| //! have `Vec<RecordBatch>` on the Rust side and want to export that in bulk as a `pyarrow.Table`. | ||
| //! In general, it is recommended to use streaming approaches instead of dealing with data in bulk. | ||
| //! For example, a `pyarrow.Table` (or any other object that implements the ArrayStream PyCapsule | ||
| //! interface) can be imported to Rust through `PyArrowType<ArrowArrayStreamReader>` instead of | ||
| //! forcing eager reading into `Vec<RecordBatch>`. | ||
|
|
||
| use std::convert::{From, TryFrom}; | ||
| use std::ptr::{addr_of, addr_of_mut}; | ||
|
|
@@ -68,13 +71,13 @@ use arrow_array::{ | |
| make_array, | ||
| }; | ||
| use arrow_data::ArrayData; | ||
| use arrow_schema::{ArrowError, DataType, Field, Schema}; | ||
| use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef}; | ||
| use pyo3::exceptions::{PyTypeError, PyValueError}; | ||
| use pyo3::ffi::Py_uintptr_t; | ||
| use pyo3::import_exception; | ||
| use pyo3::prelude::*; | ||
| use pyo3::pybacked::PyBackedStr; | ||
| use pyo3::types::{PyCapsule, PyList, PyTuple}; | ||
| use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple}; | ||
| use pyo3::{import_exception, intern}; | ||
|
|
||
| import_exception!(pyarrow, ArrowException); | ||
| /// Represents an exception raised by PyArrow. | ||
|
|
@@ -484,6 +487,120 @@ impl IntoPyArrow for ArrowArrayStreamReader { | |
| } | ||
| } | ||
|
|
||
| /// This is a convenience wrapper around `Vec<RecordBatch>` that tries to simplify conversion from | ||
| /// and to `pyarrow.Table`. | ||
| /// | ||
| /// This could be used in circumstances where you either want to consume a `pyarrow.Table` directly | ||
| /// (although technically, since `pyarrow.Table` implements the ArrowArrayStream PyCapsule | ||
| /// interface, one could also consume a `PyArrowType<ArrowArrayStreamReader>` instead) or, more | ||
| /// importantly, where one wants to export a `pyarrow.Table` from a `Vec<RecordBatch>` from the Rust | ||
| /// side. | ||
| /// | ||
| /// ```ignore | ||
| /// #[pyfunction] | ||
| /// fn return_table(...) -> PyResult<PyArrowType<Table>> { | ||
| /// let batches: Vec<RecordBatch>; | ||
| /// let schema: SchemaRef; | ||
| /// PyArrowType(Table::try_new(batches, schema).map_err(|err| err.into_py_err(py))?) | ||
| /// } | ||
| /// ``` | ||
| #[derive(Clone)] | ||
| pub struct Table { | ||
| record_batches: Vec<RecordBatch>, | ||
| schema: SchemaRef, | ||
| } | ||
|
|
||
| impl Table { | ||
| pub fn try_new( | ||
| record_batches: Vec<RecordBatch>, | ||
| schema: SchemaRef, | ||
| ) -> Result<Self, ArrowError> { | ||
| /// This function was copied from `pyo3_arrow/utils.rs` for now. I don't understand yet why | ||
| /// this is required instead of a "normal" `schema == record_batch.schema()` check. | ||
| /// | ||
| /// TODO: Either remove this check, replace it with something already existing in `arrow-rs` | ||
| /// or move it to a central `utils` location. | ||
| fn schema_equals(left: &SchemaRef, right: &SchemaRef) -> bool { | ||
| left.fields | ||
| .iter() | ||
| .zip(right.fields.iter()) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The 🤖 Was this useful? React with 👍 or 👎
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. value:useful; category:bug; feedback:The Augment AI reviewer is correct that the implementation of the equals function is not correct! The finding prevents a wrong result when one of the schemas has more fields than the other. |
||
| .all(|(left_field, right_field)| { | ||
| left_field.name() == right_field.name() | ||
| && left_field | ||
| .data_type() | ||
| .equals_datatype(right_field.data_type()) | ||
| }) | ||
| } | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bug: Incomplete Field ValidationThe
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. value:useful; category:bug; feedback:The Bugbot AI reviewer is correct that the implementation of the equals function is not correct! The finding prevents a wrong result when one of the schemas has more fields than the other. |
||
|
|
||
| for record_batch in &record_batches { | ||
| if !schema_equals(&schema, &record_batch.schema()) { | ||
| return Err(ArrowError::SchemaError( | ||
| //"All record batches must have the same schema.".to_owned(), | ||
| format!( | ||
| "All record batches must have the same schema. \ | ||
| Expected schema: {:?}, got schema: {:?}", | ||
| schema, | ||
| record_batch.schema() | ||
|
Comment on lines
+523
to
+543
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add field-count equality check in
fn schema_equals(left: &SchemaRef, right: &SchemaRef) -> bool {
+ if left.fields.len() != right.fields.len() {
+ return false;
+ }
left.fields
.iter()
.zip(right.fields.iter())🤖 Prompt for AI Agents
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. value:useful; category:bug; feedback:The CodeRabbit AI reviewer is correct that the implementation of the equals function is not correct! The finding prevents a wrong result when one of the schemas has more fields than the other. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @martin-augment Thank you for confirming the issue! I'm glad the review caught this bug before it could cause problems downstream when exporting to PyArrow. The field-count check will ensure that only truly equivalent schemas are accepted by |
||
| ), | ||
| )); | ||
| } | ||
| } | ||
| Ok(Self { | ||
| record_batches, | ||
| schema, | ||
| }) | ||
| } | ||
|
|
||
| pub fn record_batches(&self) -> &[RecordBatch] { | ||
| &self.record_batches | ||
| } | ||
|
|
||
| pub fn schema(&self) -> SchemaRef { | ||
| self.schema.clone() | ||
| } | ||
|
|
||
| pub fn into_inner(self) -> (Vec<RecordBatch>, SchemaRef) { | ||
| (self.record_batches, self.schema) | ||
| } | ||
| } | ||
|
|
||
| impl TryFrom<Box<dyn RecordBatchReader>> for Table { | ||
| type Error = ArrowError; | ||
|
|
||
| fn try_from(value: Box<dyn RecordBatchReader>) -> Result<Self, ArrowError> { | ||
| let schema = value.schema(); | ||
| let batches = value.collect::<Result<Vec<_>, _>>()?; | ||
| Self::try_new(batches, schema) | ||
| } | ||
| } | ||
|
|
||
| /// Convert a `pyarrow.Table` (or any other ArrowArrayStream compliant object) into [`Table`] | ||
| impl FromPyArrow for Table { | ||
| fn from_pyarrow_bound(ob: &Bound<PyAny>) -> PyResult<Self> { | ||
| let reader: Box<dyn RecordBatchReader> = | ||
| Box::new(ArrowArrayStreamReader::from_pyarrow_bound(ob)?); | ||
| Self::try_from(reader).map_err(|err| PyErr::new::<PyValueError, _>(err.to_string())) | ||
| } | ||
| } | ||
|
|
||
| /// Convert a [`Table`] into `pyarrow.Table`. | ||
| impl IntoPyArrow for Table { | ||
| fn into_pyarrow(self, py: Python) -> PyResult<Bound<PyAny>> { | ||
| let module = py.import(intern!(py, "pyarrow"))?; | ||
| let class = module.getattr(intern!(py, "Table"))?; | ||
|
|
||
| let py_batches = PyList::new(py, self.record_batches.into_iter().map(PyArrowType))?; | ||
| let py_schema = PyArrowType(Arc::unwrap_or_clone(self.schema)); | ||
|
|
||
| let kwargs = PyDict::new(py); | ||
| kwargs.set_item("schema", py_schema)?; | ||
|
|
||
| let reader = class.call_method("from_batches", (py_batches,), Some(&kwargs))?; | ||
|
|
||
| Ok(reader) | ||
| } | ||
| } | ||
|
|
||
| /// A newtype wrapper for types implementing [`FromPyArrow`] or [`IntoPyArrow`]. | ||
| /// | ||
| /// When wrapped around a type `T: FromPyArrow`, it | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor doc typo: the generic in
PyArrowType<ArrowArrayStreamReader>>has an extra closing>; should bePyArrowType<ArrowArrayStreamReader>.🤖 Was this useful? React with 👍 or 👎