Skip to content

Commit a11d018

Browse files
committed
pyarrow: Wrap capsule extraction errors into an explicit error
Remove validate_pycapsule: pointer_checked already raises an exception if the capsule type is wrong, and we wrap it into a nicer error
1 parent 2b851d9 commit a11d018

File tree

1 file changed

+99
-98
lines changed

1 file changed

+99
-98
lines changed

arrow-pyarrow/src/lib.rs

Lines changed: 99 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
6262
use std::convert::{From, TryFrom};
6363
use std::ffi::CStr;
64+
use std::ptr::NonNull;
6465
use std::sync::Arc;
6566

6667
use arrow_array::ffi;
@@ -77,7 +78,7 @@ use pyo3::ffi::Py_uintptr_t;
7778
use pyo3::import_exception;
7879
use pyo3::prelude::*;
7980
use pyo3::sync::PyOnceLock;
80-
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
81+
use pyo3::types::{PyCapsule, PyDict, PyList, PyType};
8182

8283
import_exception!(pyarrow, ArrowException);
8384
/// Represents an exception raised by PyArrow.
@@ -131,54 +132,17 @@ fn validate_class(expected: &Bound<PyType>, value: &Bound<PyAny>) -> PyResult<()
131132
Ok(())
132133
}
133134

134-
fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
135-
let capsule_name = capsule.name()?;
136-
if capsule_name.is_none() {
137-
return Err(PyValueError::new_err(
138-
"Expected schema PyCapsule to have name set.",
139-
));
140-
}
141-
142-
let capsule_name = unsafe { capsule_name.unwrap().as_cstr().to_str()? };
143-
if capsule_name != name {
144-
return Err(PyValueError::new_err(format!(
145-
"Expected name '{name}' in PyCapsule, instead got '{capsule_name}'",
146-
)));
147-
}
148-
149-
Ok(())
150-
}
151-
152-
fn extract_arrow_c_array_capsules<'py>(
153-
value: &Bound<'py, PyAny>,
154-
) -> PyResult<(Bound<'py, PyCapsule>, Bound<'py, PyCapsule>)> {
155-
let tuple = value.call_method0("__arrow_c_array__")?;
156-
157-
if !tuple.is_instance_of::<PyTuple>() {
158-
return Err(PyTypeError::new_err(
159-
"Expected __arrow_c_array__ to return a tuple of (schema, array) capsules.",
160-
));
161-
}
162-
163-
tuple.extract().map_err(|_| {
164-
PyTypeError::new_err(
165-
"Expected __arrow_c_array__ to return a tuple of (schema, array) capsules.",
166-
)
167-
})
168-
}
169-
170135
impl FromPyArrow for DataType {
171136
fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
172137
// Newer versions of PyArrow as well as other libraries with Arrow data implement this
173138
// method, so prefer it over _export_to_c.
174139
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
175140
if value.hasattr("__arrow_c_schema__")? {
176-
let capsule = value.call_method0("__arrow_c_schema__")?.extract()?;
177-
validate_pycapsule(&capsule, "arrow_schema")?;
178-
179-
let schema_ptr = capsule
180-
.pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
181-
.cast::<FFI_ArrowSchema>();
141+
let schema_ptr = extract_capsule_from_method::<FFI_ArrowSchema>(
142+
value,
143+
"__arrow_c_schema__",
144+
ARROW_SCHEMA_CAPSULE_NAME,
145+
)?;
182146
return unsafe { DataType::try_from(schema_ptr.as_ref()) }.map_err(to_py_err);
183147
}
184148

@@ -203,12 +167,11 @@ impl FromPyArrow for Field {
203167
// method, so prefer it over _export_to_c.
204168
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
205169
if value.hasattr("__arrow_c_schema__")? {
206-
let capsule = value.call_method0("__arrow_c_schema__")?.extract()?;
207-
validate_pycapsule(&capsule, "arrow_schema")?;
208-
209-
let schema_ptr = capsule
210-
.pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
211-
.cast::<FFI_ArrowSchema>();
170+
let schema_ptr = extract_capsule_from_method::<FFI_ArrowSchema>(
171+
value,
172+
"__arrow_c_schema__",
173+
ARROW_SCHEMA_CAPSULE_NAME,
174+
)?;
212175
return unsafe { Field::try_from(schema_ptr.as_ref()) }.map_err(to_py_err);
213176
}
214177

@@ -233,12 +196,11 @@ impl FromPyArrow for Schema {
233196
// method, so prefer it over _export_to_c.
234197
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
235198
if value.hasattr("__arrow_c_schema__")? {
236-
let capsule = value.call_method0("__arrow_c_schema__")?.extract()?;
237-
validate_pycapsule(&capsule, "arrow_schema")?;
238-
239-
let schema_ptr = capsule
240-
.pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
241-
.cast::<FFI_ArrowSchema>();
199+
let schema_ptr = extract_capsule_from_method::<FFI_ArrowSchema>(
200+
value,
201+
"__arrow_c_schema__",
202+
ARROW_SCHEMA_CAPSULE_NAME,
203+
)?;
242204
return unsafe { Schema::try_from(schema_ptr.as_ref()) }.map_err(to_py_err);
243205
}
244206

@@ -263,22 +225,14 @@ impl FromPyArrow for ArrayData {
263225
// method, so prefer it over _export_to_c.
264226
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
265227
if value.hasattr("__arrow_c_array__")? {
266-
let (schema_capsule, array_capsule) = extract_arrow_c_array_capsules(value)?;
267-
268-
validate_pycapsule(&schema_capsule, "arrow_schema")?;
269-
validate_pycapsule(&array_capsule, "arrow_array")?;
270-
271-
let schema_ptr = schema_capsule
272-
.pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
273-
.cast::<FFI_ArrowSchema>();
274-
let array = unsafe {
275-
FFI_ArrowArray::from_raw(
276-
array_capsule
277-
.pointer_checked(Some(ARROW_ARRAY_CAPSULE_NAME))?
278-
.cast::<FFI_ArrowArray>()
279-
.as_ptr(),
280-
)
281-
};
228+
let (schema_ptr, array_ptr) =
229+
extract_capsule_pair_from_method::<FFI_ArrowSchema, FFI_ArrowArray>(
230+
value,
231+
"__arrow_c_array__",
232+
ARROW_SCHEMA_CAPSULE_NAME,
233+
ARROW_ARRAY_CAPSULE_NAME,
234+
)?;
235+
let array = unsafe { FFI_ArrowArray::from_raw(array_ptr.as_ptr()) };
282236
return unsafe { ffi::from_ffi(array, schema_ptr.as_ref()) }.map_err(to_py_err);
283237
}
284238

@@ -341,20 +295,16 @@ impl FromPyArrow for RecordBatch {
341295
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
342296

343297
if value.hasattr("__arrow_c_array__")? {
344-
let (schema_capsule, array_capsule) = extract_arrow_c_array_capsules(value)?;
345-
346-
validate_pycapsule(&schema_capsule, "arrow_schema")?;
347-
validate_pycapsule(&array_capsule, "arrow_array")?;
348-
349-
let schema_ptr = schema_capsule
350-
.pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
351-
.cast::<FFI_ArrowSchema>();
352-
let array_ptr = array_capsule
353-
.pointer_checked(Some(ARROW_ARRAY_CAPSULE_NAME))?
354-
.cast::<FFI_ArrowArray>();
355-
let ffi_array = unsafe { FFI_ArrowArray::from_raw(array_ptr.as_ptr()) };
298+
let (schema_ptr, array_ptr) =
299+
extract_capsule_pair_from_method::<FFI_ArrowSchema, FFI_ArrowArray>(
300+
value,
301+
"__arrow_c_array__",
302+
ARROW_SCHEMA_CAPSULE_NAME,
303+
ARROW_ARRAY_CAPSULE_NAME,
304+
)?;
305+
let array = unsafe { FFI_ArrowArray::from_raw(array_ptr.as_ptr()) };
356306
let mut array_data =
357-
unsafe { ffi::from_ffi(ffi_array, schema_ptr.as_ref()) }.map_err(to_py_err)?;
307+
unsafe { ffi::from_ffi(array, schema_ptr.as_ref()) }.map_err(to_py_err)?;
358308
if !matches!(array_data.data_type(), DataType::Struct(_)) {
359309
return Err(PyTypeError::new_err(
360310
"Expected Struct type from __arrow_c_array.",
@@ -421,18 +371,12 @@ impl FromPyArrow for ArrowArrayStreamReader {
421371
// method, so prefer it over _export_to_c.
422372
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
423373
if value.hasattr("__arrow_c_stream__")? {
424-
let capsule = value.call_method0("__arrow_c_stream__")?.extract()?;
425-
426-
validate_pycapsule(&capsule, "arrow_array_stream")?;
427-
428-
let stream = unsafe {
429-
FFI_ArrowArrayStream::from_raw(
430-
capsule
431-
.pointer_checked(Some(ARROW_ARRAY_STREAM_CAPSULE_NAME))?
432-
.cast::<FFI_ArrowArrayStream>()
433-
.as_ptr(),
434-
)
435-
};
374+
let stream_ptr = extract_capsule_from_method::<FFI_ArrowArrayStream>(
375+
value,
376+
"__arrow_c_stream__",
377+
ARROW_ARRAY_STREAM_CAPSULE_NAME,
378+
)?;
379+
let stream = unsafe { FFI_ArrowArrayStream::from_raw(stream_ptr.as_ptr()) };
436380

437381
let stream_reader = ArrowArrayStreamReader::try_new(stream)
438382
.map_err(|err| PyValueError::new_err(err.to_string()))?;
@@ -448,8 +392,7 @@ impl FromPyArrow for ArrowArrayStreamReader {
448392
// make the conversion through PyArrow's private API
449393
// this changes the pointer's memory and is thus unsafe.
450394
// In particular, `_export_to_c` can go out of bounds
451-
let args = PyTuple::new(value.py(), [&raw mut stream as Py_uintptr_t])?;
452-
value.call_method1("_export_to_c", args)?;
395+
value.call_method1("_export_to_c", (&raw mut stream as Py_uintptr_t,))?;
453396

454397
ArrowArrayStreamReader::try_new(stream)
455398
.map_err(|err| PyValueError::new_err(err.to_string()))
@@ -631,3 +574,61 @@ impl<T> From<T> for PyArrowType<T> {
631574
Self(s)
632575
}
633576
}
577+
578+
/// Calls the zero-argument method `method_name` on `object`, interprets the
/// return value as a `PyCapsule`, and returns the capsule's pointer cast to `T`.
///
/// The pointer is fetched with `pointer_checked(Some(capsule_name))`, so the
/// capsule name is checked as part of the lookup.
///
/// # Errors
///
/// Any failure in the chain (the method call itself, the extraction as a
/// `PyCapsule`, or the checked pointer lookup) is replaced by the error built
/// by `wrapping_type_error`, with the message
/// `"Expected {method_name} to return a {capsule_name} capsule."` and the
/// original error passed along as its cause.
///
/// # Panics
///
/// Only on the error path: `capsule_name.to_str().unwrap()` panics if
/// `capsule_name` is not valid UTF-8.
///
/// NOTE(review): the `Bound` temporaries produced by `call_method0` / `extract`
/// are dropped when the inner closure returns, so nothing visible here keeps
/// the capsule alive while the caller later dereferences the returned
/// `NonNull<T>`. Confirm that the pointee cannot be released in between.
fn extract_capsule_from_method<T>(
579+
object: &Bound<'_, PyAny>,
580+
method_name: &'static str,
581+
capsule_name: &'static CStr,
582+
) -> PyResult<NonNull<T>> {
583+
// Immediately-invoked closure: every `?` below exits the closure only, so
// all failures funnel into the single `map_err` that follows.
(|| {
584+
Ok(object
585+
.call_method0(method_name)?
586+
.extract::<Bound<'_, PyCapsule>>()?
587+
.pointer_checked(Some(capsule_name))?
588+
.cast::<T>())
589+
})()
590+
.map_err(|e| {
591+
wrapping_type_error(
592+
object.py(),
593+
e,
594+
format!(
595+
"Expected {method_name} to return a {} capsule.",
596+
capsule_name.to_str().unwrap(),
597+
),
598+
)
599+
})
600+
}
601+
602+
/// Calls the zero-argument method `method_name` on `object`, interprets the
/// return value as a 2-tuple of `PyCapsule`s, and returns both capsule
/// pointers cast to `T1` and `T2` respectively.
///
/// The first capsule is looked up with `pointer_checked(Some(capsule1_name))`
/// and the second with `pointer_checked(Some(capsule2_name))`.
///
/// # Errors
///
/// Any failure in the chain (the method call, the extraction as a pair of
/// capsules, or either checked pointer lookup) is replaced by the error built
/// by `wrapping_type_error`, with the message
/// `"Expected {method_name} to return a tuple of ({capsule1_name}, {capsule2_name}) capsules."`
/// and the original error passed along as its cause.
///
/// # Panics
///
/// Only on the error path: the `to_str().unwrap()` calls panic if either
/// capsule name is not valid UTF-8.
///
/// NOTE(review): `c1` and `c2` are locals of the inner closure and are dropped
/// when it returns, so nothing visible here keeps the capsules alive while the
/// caller later uses the returned pointers. Confirm that the pointees cannot
/// be released in between.
fn extract_capsule_pair_from_method<T1, T2>(
603+
object: &Bound<'_, PyAny>,
604+
method_name: &'static str,
605+
capsule1_name: &'static CStr,
606+
capsule2_name: &'static CStr,
607+
) -> PyResult<(NonNull<T1>, NonNull<T2>)> {
608+
// Immediately-invoked closure: every `?` below exits the closure only, so
// all failures funnel into the single `map_err` that follows.
(|| {
609+
let (c1, c2) = object
610+
.call_method0(method_name)?
611+
.extract::<(Bound<'_, PyCapsule>, Bound<'_, PyCapsule>)>()?;
612+
Ok((
613+
c1.pointer_checked(Some(capsule1_name))?.cast::<T1>(),
614+
c2.pointer_checked(Some(capsule2_name))?.cast::<T2>(),
615+
))
616+
})()
617+
.map_err(|e| {
618+
wrapping_type_error(
619+
object.py(),
620+
e,
621+
format!(
622+
"Expected {method_name} to return a tuple of ({}, {}) capsules.",
623+
capsule1_name.to_str().unwrap(),
624+
capsule2_name.to_str().unwrap()
625+
),
626+
)
627+
})
628+
}
629+
630+
/// Builds a `PyTypeError` carrying `message` and records `error` as its cause,
/// so the original failure is kept alongside the more descriptive message.
///
/// Returns the newly created error; `error` is consumed.
fn wrapping_type_error(py: Python<'_>, error: PyErr, message: String) -> PyErr {
631+
let e = PyTypeError::new_err(message);
632+
// Attach the original error as the cause of the new one.
e.set_cause(py, Some(error));
633+
e
634+
}

0 commit comments

Comments
 (0)