diff --git a/arrow-pyarrow-testing/tests/pyarrow.rs b/arrow-pyarrow-testing/tests/pyarrow.rs index 6f3606478c72..1c009b5f2fb6 100644 --- a/arrow-pyarrow-testing/tests/pyarrow.rs +++ b/arrow-pyarrow-testing/tests/pyarrow.rs @@ -84,7 +84,6 @@ fn test_to_pyarrow_byte_view() { ]) .unwrap(); - println!("input: {input:?}"); let res = Python::attach(|py| { let py_input = input.to_pyarrow(py)?; let records = RecordBatch::from_pyarrow_bound(&py_input)?; @@ -120,7 +119,7 @@ value = NotATuple() assert!(err.is_instance_of::(py)); assert_eq!( err.to_string(), - "TypeError: Expected __arrow_c_array__ to return a tuple of (schema, array) capsules." + "TypeError: Expected __arrow_c_array__ to return a pair of capsule." ); }); } diff --git a/arrow-pyarrow/src/lib.rs b/arrow-pyarrow/src/lib.rs index d8f584e396d3..2c2e82ea2d18 100644 --- a/arrow-pyarrow/src/lib.rs +++ b/arrow-pyarrow/src/lib.rs @@ -61,6 +61,7 @@ use std::convert::{From, TryFrom}; use std::ffi::CStr; +use std::ptr::NonNull; use std::sync::Arc; use arrow_array::ffi; @@ -74,19 +75,15 @@ use arrow_data::ArrayData; use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef}; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::ffi::Py_uintptr_t; -use pyo3::import_exception; use pyo3::prelude::*; use pyo3::sync::PyOnceLock; -use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType}; +use pyo3::types::{PyCapsule, PyDict, PyList, PyType}; +use pyo3::{CastError, import_exception}; import_exception!(pyarrow, ArrowException); /// Represents an exception raised by PyArrow. pub type PyArrowException = ArrowException; -const ARROW_ARRAY_STREAM_CAPSULE_NAME: &CStr = c"arrow_array_stream"; -const ARROW_SCHEMA_CAPSULE_NAME: &CStr = c"arrow_schema"; -const ARROW_ARRAY_CAPSULE_NAME: &CStr = c"arrow_array"; - fn to_py_err(err: ArrowError) -> PyErr { PyArrowException::new_err(err.to_string()) } @@ -131,54 +128,14 @@ fn validate_class(expected: &Bound, value: &Bound) -> PyResult<() Ok(()) } -fn validate_pycapsule(capsule: &Bound, name: &str) -> PyResult<()> { - let capsule_name = capsule.name()?; - if capsule_name.is_none() { - return Err(PyValueError::new_err( - "Expected schema PyCapsule to have name set.", - )); - } - - let capsule_name = unsafe { capsule_name.unwrap().as_cstr().to_str()? }; - if capsule_name != name { - return Err(PyValueError::new_err(format!( - "Expected name '{name}' in PyCapsule, instead got '{capsule_name}'", - ))); - } - - Ok(()) -} - -fn extract_arrow_c_array_capsules<'py>( - value: &Bound<'py, PyAny>, -) -> PyResult<(Bound<'py, PyCapsule>, Bound<'py, PyCapsule>)> { - let tuple = value.call_method0("__arrow_c_array__")?; - - if !tuple.is_instance_of::() { - return Err(PyTypeError::new_err( - "Expected __arrow_c_array__ to return a tuple of (schema, array) capsules.", - )); - } - - tuple.extract().map_err(|_| { - PyTypeError::new_err( - "Expected __arrow_c_array__ to return a tuple of (schema, array) capsules.", - ) - }) -} - impl FromPyArrow for DataType { fn from_pyarrow_bound(value: &Bound) -> PyResult { // Newer versions of PyArrow as well as other libraries with Arrow data implement this // method, so prefer it over _export_to_c. // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html if value.hasattr("__arrow_c_schema__")? { - let capsule = value.call_method0("__arrow_c_schema__")?.extract()?; - validate_pycapsule(&capsule, "arrow_schema")?; - - let schema_ptr = capsule - .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))? - .cast::(); + let capsule = call_capsule_method(value, "__arrow_c_schema__")?; + let schema_ptr = extract_capsule::(&capsule, "__arrow_c_schema__")?; return unsafe { DataType::try_from(schema_ptr.as_ref()) }.map_err(to_py_err); } @@ -203,12 +160,8 @@ impl FromPyArrow for Field { // method, so prefer it over _export_to_c. // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html if value.hasattr("__arrow_c_schema__")? { - let capsule = value.call_method0("__arrow_c_schema__")?.extract()?; - validate_pycapsule(&capsule, "arrow_schema")?; - - let schema_ptr = capsule - .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))? - .cast::(); + let capsule = call_capsule_method(value, "__arrow_c_schema__")?; + let schema_ptr = extract_capsule::(&capsule, "__arrow_c_schema__")?; return unsafe { Field::try_from(schema_ptr.as_ref()) }.map_err(to_py_err); } @@ -233,12 +186,8 @@ impl FromPyArrow for Schema { // method, so prefer it over _export_to_c. // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html if value.hasattr("__arrow_c_schema__")? { - let capsule = value.call_method0("__arrow_c_schema__")?.extract()?; - validate_pycapsule(&capsule, "arrow_schema")?; - - let schema_ptr = capsule - .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))? - .cast::(); + let capsule = call_capsule_method(value, "__arrow_c_schema__")?; + let schema_ptr = extract_capsule::(&capsule, "__arrow_c_schema__")?; return unsafe { Schema::try_from(schema_ptr.as_ref()) }.map_err(to_py_err); } @@ -263,22 +212,11 @@ impl FromPyArrow for ArrayData { // method, so prefer it over _export_to_c. // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html if value.hasattr("__arrow_c_array__")? { - let (schema_capsule, array_capsule) = extract_arrow_c_array_capsules(value)?; - - validate_pycapsule(&schema_capsule, "arrow_schema")?; - validate_pycapsule(&array_capsule, "arrow_array")?; - - let schema_ptr = schema_capsule - .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))? - .cast::(); - let array = unsafe { - FFI_ArrowArray::from_raw( - array_capsule - .pointer_checked(Some(ARROW_ARRAY_CAPSULE_NAME))? - .cast::() - .as_ptr(), - ) - }; + let (schema_capsule, array_capsule) = + call_capsule_pair_method(value, "__arrow_c_array__")?; + let schema_ptr = extract_capsule(&schema_capsule, "__arrow_c_array__")?; + let array_ptr = extract_capsule(&array_capsule, "__arrow_c_array__")?; + let array = unsafe { FFI_ArrowArray::from_raw(array_ptr.as_ptr()) }; return unsafe { ffi::from_ffi(array, schema_ptr.as_ref()) }.map_err(to_py_err); } @@ -341,24 +279,18 @@ impl FromPyArrow for RecordBatch { // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html if value.hasattr("__arrow_c_array__")? { - let (schema_capsule, array_capsule) = extract_arrow_c_array_capsules(value)?; - - validate_pycapsule(&schema_capsule, "arrow_schema")?; - validate_pycapsule(&array_capsule, "arrow_array")?; - - let schema_ptr = schema_capsule - .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))? - .cast::(); - let array_ptr = array_capsule - .pointer_checked(Some(ARROW_ARRAY_CAPSULE_NAME))? - .cast::(); + let (schema_capsule, array_capsule) = + call_capsule_pair_method(value, "__arrow_c_array__")?; + let schema_ptr = extract_capsule(&schema_capsule, "__arrow_c_array__")?; + let array_ptr = extract_capsule(&array_capsule, "__arrow_c_array__")?; let ffi_array = unsafe { FFI_ArrowArray::from_raw(array_ptr.as_ptr()) }; let mut array_data = unsafe { ffi::from_ffi(ffi_array, schema_ptr.as_ref()) }.map_err(to_py_err)?; if !matches!(array_data.data_type(), DataType::Struct(_)) { - return Err(PyTypeError::new_err( - "Expected Struct type from __arrow_c_array.", - )); + return Err(PyTypeError::new_err(format!( + "Expected Struct type from __arrow_c_array__, found {}.", + array_data.data_type() + ))); } let options = RecordBatchOptions::default().with_row_count(Some(array_data.len())); // Ensure data is aligned (by potentially copying the buffers). @@ -421,18 +353,9 @@ impl FromPyArrow for ArrowArrayStreamReader { // method, so prefer it over _export_to_c. // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html if value.hasattr("__arrow_c_stream__")? { - let capsule = value.call_method0("__arrow_c_stream__")?.extract()?; - - validate_pycapsule(&capsule, "arrow_array_stream")?; - - let stream = unsafe { - FFI_ArrowArrayStream::from_raw( - capsule - .pointer_checked(Some(ARROW_ARRAY_STREAM_CAPSULE_NAME))? - .cast::() - .as_ptr(), - ) - }; + let capsule = call_capsule_method(value, "__arrow_c_stream__")?; + let stream_ptr = extract_capsule(&capsule, "__arrow_c_stream__")?; + let stream = unsafe { FFI_ArrowArrayStream::from_raw(stream_ptr.as_ptr()) }; let stream_reader = ArrowArrayStreamReader::try_new(stream) .map_err(|err| PyValueError::new_err(err.to_string()))?; @@ -448,8 +371,7 @@ impl FromPyArrow for ArrowArrayStreamReader { // make the conversion through PyArrow's private API // this changes the pointer's memory and is thus unsafe. // In particular, `_export_to_c` can go out of bounds - let args = PyTuple::new(value.py(), [&raw mut stream as Py_uintptr_t])?; - value.call_method1("_export_to_c", args)?; + value.call_method1("_export_to_c", (&raw mut stream as Py_uintptr_t,))?; ArrowArrayStreamReader::try_new(stream) .map_err(|err| PyValueError::new_err(err.to_string())) @@ -631,3 +553,73 @@ impl From for PyArrowType { Self(s) } } + +fn call_capsule_method<'py>( + object: &Bound<'py, PyAny>, + method_name: &'static str, +) -> PyResult> { + object + .call_method0(method_name)? + .extract() + .map_err(|e: CastError| { + wrapping_type_error( + object.py(), + e.into(), + format!("Expected {method_name} to return a capsule."), + ) + }) +} + +fn call_capsule_pair_method<'py>( + object: &Bound<'py, PyAny>, + method_name: &'static str, +) -> PyResult<(Bound<'py, PyCapsule>, Bound<'py, PyCapsule>)> { + object.call_method0(method_name)?.extract().map_err(|e| { + wrapping_type_error( + object.py(), + e, + format!("Expected {method_name} to return a pair of capsule."), + ) + }) +} + +trait PyCapsuleType { + const NAME: &CStr; +} + +impl PyCapsuleType for FFI_ArrowSchema { + const NAME: &CStr = c"arrow_schema"; +} + +impl PyCapsuleType for FFI_ArrowArray { + const NAME: &CStr = c"arrow_array"; +} + +impl PyCapsuleType for FFI_ArrowArrayStream { + const NAME: &CStr = c"arrow_array_stream"; +} + +fn extract_capsule( + capsule: &Bound, + method_name: &'static str, +) -> PyResult> { + Ok(capsule + .pointer_checked(Some(T::NAME)) + .map_err(|e| { + wrapping_type_error( + capsule.py(), + e, + format!( + "Expected {method_name} to return a {} capsule.", + T::NAME.to_str().unwrap(), + ), + ) + })? + .cast::()) +} + +fn wrapping_type_error(py: Python<'_>, error: PyErr, message: String) -> PyErr { + let e = PyTypeError::new_err(message); + e.set_cause(py, Some(error)); + e +}