Skip to content

Commit 7ecba54

Browse files
committed
Fix use after free error
1 parent fb9fa2a commit 7ecba54

File tree

1 file changed

+65
-76
lines changed

1 file changed

+65
-76
lines changed

arrow-pyarrow/src/lib.rs

Lines changed: 65 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -75,16 +75,15 @@ use arrow_data::ArrayData;
7575
use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
7676
use pyo3::exceptions::{PyTypeError, PyValueError};
7777
use pyo3::ffi::Py_uintptr_t;
78-
use pyo3::import_exception;
7978
use pyo3::prelude::*;
8079
use pyo3::sync::PyOnceLock;
8180
use pyo3::types::{PyCapsule, PyDict, PyList, PyType};
81+
use pyo3::{CastError, import_exception};
8282

8383
import_exception!(pyarrow, ArrowException);
8484
/// Represents an exception raised by PyArrow.
8585
pub type PyArrowException = ArrowException;
8686

87-
8887
fn to_py_err(err: ArrowError) -> PyErr {
8988
PyArrowException::new_err(err.to_string())
9089
}
@@ -135,10 +134,8 @@ impl FromPyArrow for DataType {
135134
// method, so prefer it over _export_to_c.
136135
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
137136
if value.hasattr("__arrow_c_schema__")? {
138-
let schema_ptr = extract_capsule_from_method::<FFI_ArrowSchema>(
139-
value,
140-
"__arrow_c_schema__",
141-
)?;
137+
let capsule = call_capsule_method(value, "__arrow_c_schema__")?;
138+
let schema_ptr = extract_capsule::<FFI_ArrowSchema>(&capsule, "__arrow_c_schema__")?;
142139
return unsafe { DataType::try_from(schema_ptr.as_ref()) }.map_err(to_py_err);
143140
}
144141

@@ -163,10 +160,8 @@ impl FromPyArrow for Field {
163160
// method, so prefer it over _export_to_c.
164161
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
165162
if value.hasattr("__arrow_c_schema__")? {
166-
let schema_ptr = extract_capsule_from_method::<FFI_ArrowSchema>(
167-
value,
168-
"__arrow_c_schema__",
169-
)?;
163+
let capsule = call_capsule_method(value, "__arrow_c_schema__")?;
164+
let schema_ptr = extract_capsule::<FFI_ArrowSchema>(&capsule, "__arrow_c_schema__")?;
170165
return unsafe { Field::try_from(schema_ptr.as_ref()) }.map_err(to_py_err);
171166
}
172167

@@ -191,10 +186,8 @@ impl FromPyArrow for Schema {
191186
// method, so prefer it over _export_to_c.
192187
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
193188
if value.hasattr("__arrow_c_schema__")? {
194-
let schema_ptr = extract_capsule_from_method::<FFI_ArrowSchema>(
195-
value,
196-
"__arrow_c_schema__",
197-
)?;
189+
let capsule = call_capsule_method(value, "__arrow_c_schema__")?;
190+
let schema_ptr = extract_capsule::<FFI_ArrowSchema>(&capsule, "__arrow_c_schema__")?;
198191
return unsafe { Schema::try_from(schema_ptr.as_ref()) }.map_err(to_py_err);
199192
}
200193

@@ -219,11 +212,10 @@ impl FromPyArrow for ArrayData {
219212
// method, so prefer it over _export_to_c.
220213
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
221214
if value.hasattr("__arrow_c_array__")? {
222-
let (schema_ptr, array_ptr) =
223-
extract_capsule_pair_from_method::<FFI_ArrowSchema, FFI_ArrowArray>(
224-
value,
225-
"__arrow_c_array__",
226-
)?;
215+
let (schema_capsule, array_capsule) =
216+
call_capsule_pair_method(value, "__arrow_c_array__")?;
217+
let schema_ptr = extract_capsule(&schema_capsule, "__arrow_c_array__")?;
218+
let array_ptr = extract_capsule(&array_capsule, "__arrow_c_array__")?;
227219
let array = unsafe { FFI_ArrowArray::from_raw(array_ptr.as_ptr()) };
228220
return unsafe { ffi::from_ffi(array, schema_ptr.as_ref()) }.map_err(to_py_err);
229221
}
@@ -287,18 +279,18 @@ impl FromPyArrow for RecordBatch {
287279
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
288280

289281
if value.hasattr("__arrow_c_array__")? {
290-
let (schema_ptr, array_ptr) =
291-
extract_capsule_pair_from_method::<FFI_ArrowSchema, FFI_ArrowArray>(
292-
value,
293-
"__arrow_c_array__",
294-
)?;
282+
let (schema_capsule, array_capsule) =
283+
call_capsule_pair_method(value, "__arrow_c_array__")?;
284+
let schema_ptr = extract_capsule(&schema_capsule, "__arrow_c_array__")?;
285+
let array_ptr = extract_capsule(&array_capsule, "__arrow_c_array__")?;
295286
let ffi_array = unsafe { FFI_ArrowArray::from_raw(array_ptr.as_ptr()) };
296287
let mut array_data =
297288
unsafe { ffi::from_ffi(ffi_array, schema_ptr.as_ref()) }.map_err(to_py_err)?;
298289
if !matches!(array_data.data_type(), DataType::Struct(_)) {
299-
return Err(PyTypeError::new_err(
300-
format!("Expected Struct type from __arrow_c_array__, found {}.", array_data.data_type()),
301-
));
290+
return Err(PyTypeError::new_err(format!(
291+
"Expected Struct type from __arrow_c_array__, found {}.",
292+
array_data.data_type()
293+
)));
302294
}
303295
let options = RecordBatchOptions::default().with_row_count(Some(array_data.len()));
304296
// Ensure data is aligned (by potentially copying the buffers).
@@ -361,10 +353,8 @@ impl FromPyArrow for ArrowArrayStreamReader {
361353
// method, so prefer it over _export_to_c.
362354
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
363355
if value.hasattr("__arrow_c_stream__")? {
364-
let stream_ptr = extract_capsule_from_method::<FFI_ArrowArrayStream>(
365-
value,
366-
"__arrow_c_stream__",
367-
)?;
356+
let capsule = call_capsule_method(value, "__arrow_c_stream__")?;
357+
let stream_ptr = extract_capsule(&capsule, "__arrow_c_stream__")?;
368358
let stream = unsafe { FFI_ArrowArrayStream::from_raw(stream_ptr.as_ptr()) };
369359

370360
let stream_reader = ArrowArrayStreamReader::try_new(stream)
@@ -564,6 +554,35 @@ impl<T> From<T> for PyArrowType<T> {
564554
}
565555
}
566556

557+
fn call_capsule_method<'py>(
558+
object: &Bound<'py, PyAny>,
559+
method_name: &'static str,
560+
) -> PyResult<Bound<'py, PyCapsule>> {
561+
object
562+
.call_method0(method_name)?
563+
.extract()
564+
.map_err(|e: CastError| {
565+
wrapping_type_error(
566+
object.py(),
567+
e.into(),
568+
format!("Expected {method_name} to return a capsule."),
569+
)
570+
})
571+
}
572+
573+
fn call_capsule_pair_method<'py>(
574+
object: &Bound<'py, PyAny>,
575+
method_name: &'static str,
576+
) -> PyResult<(Bound<'py, PyCapsule>, Bound<'py, PyCapsule>)> {
577+
object.call_method0(method_name)?.extract().map_err(|e| {
578+
wrapping_type_error(
579+
object.py(),
580+
e,
581+
format!("Expected {method_name} to return a pair of capsule."),
582+
)
583+
})
584+
}
585+
567586
trait PyCapsuleType {
568587
const NAME: &CStr;
569588
}
@@ -580,53 +599,23 @@ impl PyCapsuleType for FFI_ArrowArrayStream {
580599
const NAME: &CStr = c"arrow_array_stream";
581600
}
582601

583-
fn extract_capsule_from_method<T: PyCapsuleType>(
584-
object: &Bound<'_, PyAny>,
602+
fn extract_capsule<T: PyCapsuleType>(
603+
capsule: &Bound<PyCapsule>,
585604
method_name: &'static str,
586605
) -> PyResult<NonNull<T>> {
587-
(|| {
588-
Ok(object
589-
.call_method0(method_name)?
590-
.extract::<Bound<'_, PyCapsule>>()?
591-
.pointer_checked(Some(T::NAME))?
592-
.cast::<T>())
593-
})()
594-
.map_err(|e| {
595-
wrapping_type_error(
596-
object.py(),
597-
e,
598-
format!(
599-
"Expected {method_name} to return a {} capsule.",
600-
T::NAME.to_str().unwrap(),
601-
),
602-
)
603-
})
604-
}
605-
606-
fn extract_capsule_pair_from_method<T1: PyCapsuleType, T2: PyCapsuleType>(
607-
object: &Bound<'_, PyAny>,
608-
method_name: &'static str,
609-
) -> PyResult<(NonNull<T1>, NonNull<T2>)> {
610-
(|| {
611-
let (c1, c2) = object
612-
.call_method0(method_name)?
613-
.extract::<(Bound<'_, PyCapsule>, Bound<'_, PyCapsule>)>()?;
614-
Ok((
615-
c1.pointer_checked(Some(T1::NAME))?.cast::<T1>(),
616-
c2.pointer_checked(Some(T2::NAME))?.cast::<T2>(),
617-
))
618-
})()
619-
.map_err(|e| {
620-
wrapping_type_error(
621-
object.py(),
622-
e,
623-
format!(
624-
"Expected {method_name} to return a tuple of ({}, {}) capsules.",
625-
T1::NAME.to_str().unwrap(),
626-
T2::NAME.to_str().unwrap()
627-
),
628-
)
629-
})
606+
Ok(capsule
607+
.pointer_checked(Some(T::NAME))
608+
.map_err(|e| {
609+
wrapping_type_error(
610+
capsule.py(),
611+
e,
612+
format!(
613+
"Expected {method_name} to return a {} capsule.",
614+
T::NAME.to_str().unwrap(),
615+
),
616+
)
617+
})?
618+
.cast::<T>())
630619
}
631620

632621
fn wrapping_type_error(py: Python<'_>, error: PyErr, message: String) -> PyErr {

0 commit comments

Comments
 (0)