Skip to content

Commit 61391e3

Browse files
committed
pyarrow: Wrap capsule extraction errors into an explicit error
Remove validate_pycapsule: pointer_checked already raises an exception if the capsule type is wrong, and we wrap it into a nicer error
1 parent 2b851d9 commit 61391e3

File tree

2 files changed

+104
-102
lines changed

2 files changed

+104
-102
lines changed

arrow-pyarrow-testing/tests/pyarrow.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ fn test_to_pyarrow_byte_view() {
8484
])
8585
.unwrap();
8686

87-
println!("input: {input:?}");
8887
let res = Python::attach(|py| {
8988
let py_input = input.to_pyarrow(py)?;
9089
let records = RecordBatch::from_pyarrow_bound(&py_input)?;
@@ -120,7 +119,7 @@ value = NotATuple()
120119
assert!(err.is_instance_of::<PyTypeError>(py));
121120
assert_eq!(
122121
err.to_string(),
123-
"TypeError: Expected __arrow_c_array__ to return a tuple of (schema, array) capsules."
122+
"TypeError: Expected __arrow_c_array__ to return a tuple of (arrow_schema, arrow_array) capsules."
124123
);
125124
});
126125
}

arrow-pyarrow/src/lib.rs

Lines changed: 103 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
6262
use std::convert::{From, TryFrom};
6363
use std::ffi::CStr;
64+
use std::ptr::NonNull;
6465
use std::sync::Arc;
6566

6667
use arrow_array::ffi;
@@ -77,15 +78,12 @@ use pyo3::ffi::Py_uintptr_t;
7778
use pyo3::import_exception;
7879
use pyo3::prelude::*;
7980
use pyo3::sync::PyOnceLock;
80-
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
81+
use pyo3::types::{PyCapsule, PyDict, PyList, PyType};
8182

8283
import_exception!(pyarrow, ArrowException);
8384
/// Represents an exception raised by PyArrow.
8485
pub type PyArrowException = ArrowException;
8586

86-
const ARROW_ARRAY_STREAM_CAPSULE_NAME: &CStr = c"arrow_array_stream";
87-
const ARROW_SCHEMA_CAPSULE_NAME: &CStr = c"arrow_schema";
88-
const ARROW_ARRAY_CAPSULE_NAME: &CStr = c"arrow_array";
8987

9088
fn to_py_err(err: ArrowError) -> PyErr {
9189
PyArrowException::new_err(err.to_string())
@@ -131,54 +129,16 @@ fn validate_class(expected: &Bound<PyType>, value: &Bound<PyAny>) -> PyResult<()
131129
Ok(())
132130
}
133131

134-
fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
135-
let capsule_name = capsule.name()?;
136-
if capsule_name.is_none() {
137-
return Err(PyValueError::new_err(
138-
"Expected schema PyCapsule to have name set.",
139-
));
140-
}
141-
142-
let capsule_name = unsafe { capsule_name.unwrap().as_cstr().to_str()? };
143-
if capsule_name != name {
144-
return Err(PyValueError::new_err(format!(
145-
"Expected name '{name}' in PyCapsule, instead got '{capsule_name}'",
146-
)));
147-
}
148-
149-
Ok(())
150-
}
151-
152-
fn extract_arrow_c_array_capsules<'py>(
153-
value: &Bound<'py, PyAny>,
154-
) -> PyResult<(Bound<'py, PyCapsule>, Bound<'py, PyCapsule>)> {
155-
let tuple = value.call_method0("__arrow_c_array__")?;
156-
157-
if !tuple.is_instance_of::<PyTuple>() {
158-
return Err(PyTypeError::new_err(
159-
"Expected __arrow_c_array__ to return a tuple of (schema, array) capsules.",
160-
));
161-
}
162-
163-
tuple.extract().map_err(|_| {
164-
PyTypeError::new_err(
165-
"Expected __arrow_c_array__ to return a tuple of (schema, array) capsules.",
166-
)
167-
})
168-
}
169-
170132
impl FromPyArrow for DataType {
171133
fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
172134
// Newer versions of PyArrow as well as other libraries with Arrow data implement this
173135
// method, so prefer it over _export_to_c.
174136
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
175137
if value.hasattr("__arrow_c_schema__")? {
176-
let capsule = value.call_method0("__arrow_c_schema__")?.extract()?;
177-
validate_pycapsule(&capsule, "arrow_schema")?;
178-
179-
let schema_ptr = capsule
180-
.pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
181-
.cast::<FFI_ArrowSchema>();
138+
let schema_ptr = extract_capsule_from_method::<FFI_ArrowSchema>(
139+
value,
140+
"__arrow_c_schema__",
141+
)?;
182142
return unsafe { DataType::try_from(schema_ptr.as_ref()) }.map_err(to_py_err);
183143
}
184144

@@ -203,12 +163,10 @@ impl FromPyArrow for Field {
203163
// method, so prefer it over _export_to_c.
204164
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
205165
if value.hasattr("__arrow_c_schema__")? {
206-
let capsule = value.call_method0("__arrow_c_schema__")?.extract()?;
207-
validate_pycapsule(&capsule, "arrow_schema")?;
208-
209-
let schema_ptr = capsule
210-
.pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
211-
.cast::<FFI_ArrowSchema>();
166+
let schema_ptr = extract_capsule_from_method::<FFI_ArrowSchema>(
167+
value,
168+
"__arrow_c_schema__",
169+
)?;
212170
return unsafe { Field::try_from(schema_ptr.as_ref()) }.map_err(to_py_err);
213171
}
214172

@@ -233,12 +191,10 @@ impl FromPyArrow for Schema {
233191
// method, so prefer it over _export_to_c.
234192
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
235193
if value.hasattr("__arrow_c_schema__")? {
236-
let capsule = value.call_method0("__arrow_c_schema__")?.extract()?;
237-
validate_pycapsule(&capsule, "arrow_schema")?;
238-
239-
let schema_ptr = capsule
240-
.pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
241-
.cast::<FFI_ArrowSchema>();
194+
let schema_ptr = extract_capsule_from_method::<FFI_ArrowSchema>(
195+
value,
196+
"__arrow_c_schema__",
197+
)?;
242198
return unsafe { Schema::try_from(schema_ptr.as_ref()) }.map_err(to_py_err);
243199
}
244200

@@ -263,22 +219,12 @@ impl FromPyArrow for ArrayData {
263219
// method, so prefer it over _export_to_c.
264220
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
265221
if value.hasattr("__arrow_c_array__")? {
266-
let (schema_capsule, array_capsule) = extract_arrow_c_array_capsules(value)?;
267-
268-
validate_pycapsule(&schema_capsule, "arrow_schema")?;
269-
validate_pycapsule(&array_capsule, "arrow_array")?;
270-
271-
let schema_ptr = schema_capsule
272-
.pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
273-
.cast::<FFI_ArrowSchema>();
274-
let array = unsafe {
275-
FFI_ArrowArray::from_raw(
276-
array_capsule
277-
.pointer_checked(Some(ARROW_ARRAY_CAPSULE_NAME))?
278-
.cast::<FFI_ArrowArray>()
279-
.as_ptr(),
280-
)
281-
};
222+
let (schema_ptr, array_ptr) =
223+
extract_capsule_pair_from_method::<FFI_ArrowSchema, FFI_ArrowArray>(
224+
value,
225+
"__arrow_c_array__",
226+
)?;
227+
let array = unsafe { FFI_ArrowArray::from_raw(array_ptr.as_ptr()) };
282228
return unsafe { ffi::from_ffi(array, schema_ptr.as_ref()) }.map_err(to_py_err);
283229
}
284230

@@ -341,23 +287,17 @@ impl FromPyArrow for RecordBatch {
341287
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
342288

343289
if value.hasattr("__arrow_c_array__")? {
344-
let (schema_capsule, array_capsule) = extract_arrow_c_array_capsules(value)?;
345-
346-
validate_pycapsule(&schema_capsule, "arrow_schema")?;
347-
validate_pycapsule(&array_capsule, "arrow_array")?;
348-
349-
let schema_ptr = schema_capsule
350-
.pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
351-
.cast::<FFI_ArrowSchema>();
352-
let array_ptr = array_capsule
353-
.pointer_checked(Some(ARROW_ARRAY_CAPSULE_NAME))?
354-
.cast::<FFI_ArrowArray>();
290+
let (schema_ptr, array_ptr) =
291+
extract_capsule_pair_from_method::<FFI_ArrowSchema, FFI_ArrowArray>(
292+
value,
293+
"__arrow_c_array__",
294+
)?;
355295
let ffi_array = unsafe { FFI_ArrowArray::from_raw(array_ptr.as_ptr()) };
356296
let mut array_data =
357297
unsafe { ffi::from_ffi(ffi_array, schema_ptr.as_ref()) }.map_err(to_py_err)?;
358298
if !matches!(array_data.data_type(), DataType::Struct(_)) {
359299
return Err(PyTypeError::new_err(
360-
"Expected Struct type from __arrow_c_array.",
300+
format!("Expected Struct type from __arrow_c_array__, found {}.", array_data.data_type()),
361301
));
362302
}
363303
let options = RecordBatchOptions::default().with_row_count(Some(array_data.len()));
@@ -421,18 +361,11 @@ impl FromPyArrow for ArrowArrayStreamReader {
421361
// method, so prefer it over _export_to_c.
422362
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
423363
if value.hasattr("__arrow_c_stream__")? {
424-
let capsule = value.call_method0("__arrow_c_stream__")?.extract()?;
425-
426-
validate_pycapsule(&capsule, "arrow_array_stream")?;
427-
428-
let stream = unsafe {
429-
FFI_ArrowArrayStream::from_raw(
430-
capsule
431-
.pointer_checked(Some(ARROW_ARRAY_STREAM_CAPSULE_NAME))?
432-
.cast::<FFI_ArrowArrayStream>()
433-
.as_ptr(),
434-
)
435-
};
364+
let stream_ptr = extract_capsule_from_method::<FFI_ArrowArrayStream>(
365+
value,
366+
"__arrow_c_stream__",
367+
)?;
368+
let stream = unsafe { FFI_ArrowArrayStream::from_raw(stream_ptr.as_ptr()) };
436369

437370
let stream_reader = ArrowArrayStreamReader::try_new(stream)
438371
.map_err(|err| PyValueError::new_err(err.to_string()))?;
@@ -448,8 +381,7 @@ impl FromPyArrow for ArrowArrayStreamReader {
448381
// make the conversion through PyArrow's private API
449382
// this changes the pointer's memory and is thus unsafe.
450383
// In particular, `_export_to_c` can go out of bounds
451-
let args = PyTuple::new(value.py(), [&raw mut stream as Py_uintptr_t])?;
452-
value.call_method1("_export_to_c", args)?;
384+
value.call_method1("_export_to_c", (&raw mut stream as Py_uintptr_t,))?;
453385

454386
ArrowArrayStreamReader::try_new(stream)
455387
.map_err(|err| PyValueError::new_err(err.to_string()))
@@ -631,3 +563,74 @@ impl<T> From<T> for PyArrowType<T> {
631563
Self(s)
632564
}
633565
}
566+
567+
/// Associates an Arrow C interface struct with the name its `PyCapsule` is
/// expected to carry.
///
/// `extract_capsule_from_method` and `extract_capsule_pair_from_method` pass
/// `NAME` to `PyCapsule::pointer_checked` and also use it to build their
/// error messages.
trait PyCapsuleType {
568+
/// Capsule name checked when the pointer is extracted.
const NAME: &CStr;
569+
}
570+
571+
/// Schema capsules are looked up under the name `arrow_schema`.
impl PyCapsuleType for FFI_ArrowSchema {
572+
const NAME: &CStr = c"arrow_schema";
573+
}
574+
575+
/// Array capsules are looked up under the name `arrow_array`.
impl PyCapsuleType for FFI_ArrowArray {
576+
const NAME: &CStr = c"arrow_array";
577+
}
578+
579+
/// Stream capsules are looked up under the name `arrow_array_stream`.
impl PyCapsuleType for FFI_ArrowArrayStream {
580+
const NAME: &CStr = c"arrow_array_stream";
581+
}
582+
583+
/// Calls the zero-argument method `method_name` on `object`, extracts the
/// result as a `PyCapsule`, and returns the capsule's pointer cast to `T`.
///
/// The pointer is obtained with `pointer_checked(Some(T::NAME))`, so the
/// capsule name is checked against `T::NAME` as part of the extraction.
///
/// # Errors
///
/// Any failure in the chain (the method call itself, the extraction to
/// `PyCapsule`, or the checked pointer lookup) is replaced by a `PyTypeError`
/// reading "Expected {method_name} to return a {name} capsule.", with the
/// original error attached as its cause via `wrapping_type_error`.
///
/// NOTE(review): the `Bound<PyCapsule>` produced inside the closure is a
/// temporary and is dropped before this function returns. If that was the
/// last reference to the capsule, its destructor may run and leave the
/// returned pointer dangling — confirm the producer keeps the capsule alive,
/// or return the capsule together with the pointer.
fn extract_capsule_from_method<T: PyCapsuleType>(
584+
object: &Bound<'_, PyAny>,
585+
method_name: &'static str,
586+
) -> PyResult<NonNull<T>> {
587+
// Immediately-invoked closure: lets every step use `?` while funnelling all
// failures into the single `map_err` below.
(|| {
588+
Ok(object
589+
.call_method0(method_name)?
590+
.extract::<Bound<'_, PyCapsule>>()?
591+
.pointer_checked(Some(T::NAME))?
592+
.cast::<T>())
593+
})()
594+
.map_err(|e| {
595+
wrapping_type_error(
596+
object.py(),
597+
e,
598+
format!(
599+
"Expected {method_name} to return a {} capsule.",
600+
// The `NAME` values defined in this file are ASCII literals, so `to_str`
// does not fail for them.
T::NAME.to_str().unwrap(),
601+
),
602+
)
603+
})
604+
}
605+
606+
/// Calls the zero-argument method `method_name` on `object`, extracts the
/// result as a 2-tuple of `PyCapsule`s, and returns both capsule pointers
/// cast to `T1` and `T2` respectively.
///
/// Each pointer is obtained with `pointer_checked`, using `T1::NAME` for the
/// first capsule and `T2::NAME` for the second.
///
/// # Errors
///
/// Any failure in the chain (the method call itself, the tuple/capsule
/// extraction, or either checked pointer lookup) is replaced by a
/// `PyTypeError` reading "Expected {method_name} to return a tuple of
/// ({name1}, {name2}) capsules.", with the original error attached as its
/// cause via `wrapping_type_error`.
///
/// NOTE(review): `c1` and `c2` are dropped when the closure returns, before
/// the pointers reach the caller. If those were the last references, the
/// capsule destructors may run and leave the returned pointers dangling —
/// confirm the producer keeps the capsules alive, or return the capsules
/// together with the pointers.
fn extract_capsule_pair_from_method<T1: PyCapsuleType, T2: PyCapsuleType>(
607+
object: &Bound<'_, PyAny>,
608+
method_name: &'static str,
609+
) -> PyResult<(NonNull<T1>, NonNull<T2>)> {
610+
// Immediately-invoked closure: lets every step use `?` while funnelling all
// failures into the single `map_err` below.
(|| {
611+
let (c1, c2) = object
612+
.call_method0(method_name)?
613+
.extract::<(Bound<'_, PyCapsule>, Bound<'_, PyCapsule>)>()?;
614+
Ok((
615+
c1.pointer_checked(Some(T1::NAME))?.cast::<T1>(),
616+
c2.pointer_checked(Some(T2::NAME))?.cast::<T2>(),
617+
))
618+
})()
619+
.map_err(|e| {
620+
wrapping_type_error(
621+
object.py(),
622+
e,
623+
format!(
624+
"Expected {method_name} to return a tuple of ({}, {}) capsules.",
625+
// The `NAME` values defined in this file are ASCII literals, so `to_str`
// does not fail for them.
T1::NAME.to_str().unwrap(),
626+
T2::NAME.to_str().unwrap()
627+
),
628+
)
629+
})
630+
}
631+
632+
/// Builds a `PyTypeError` carrying `message` and records `error` as its cause,
/// so the underlying failure is kept rather than discarded.
///
/// Returns the new `PyTypeError`; `error` is consumed.
fn wrapping_type_error(py: Python<'_>, error: PyErr, message: String) -> PyErr {
633+
let e = PyTypeError::new_err(message);
634+
e.set_cause(py, Some(error));
635+
e
636+
}

0 commit comments

Comments
 (0)