6161
6262use std:: convert:: { From , TryFrom } ;
6363use std:: ffi:: CStr ;
64+ use std:: ptr:: NonNull ;
6465use std:: sync:: Arc ;
6566
6667use arrow_array:: ffi;
@@ -77,7 +78,7 @@ use pyo3::ffi::Py_uintptr_t;
7778use pyo3:: import_exception;
7879use pyo3:: prelude:: * ;
7980use pyo3:: sync:: PyOnceLock ;
80- use pyo3:: types:: { PyCapsule , PyDict , PyList , PyTuple , PyType } ;
81+ use pyo3:: types:: { PyCapsule , PyDict , PyList , PyType } ;
8182
8283import_exception ! ( pyarrow, ArrowException ) ;
8384/// Represents an exception raised by PyArrow.
@@ -131,54 +132,17 @@ fn validate_class(expected: &Bound<PyType>, value: &Bound<PyAny>) -> PyResult<()
131132 Ok ( ( ) )
132133}
133134
134- fn validate_pycapsule ( capsule : & Bound < PyCapsule > , name : & str ) -> PyResult < ( ) > {
135- let capsule_name = capsule. name ( ) ?;
136- if capsule_name. is_none ( ) {
137- return Err ( PyValueError :: new_err (
138- "Expected schema PyCapsule to have name set." ,
139- ) ) ;
140- }
141-
142- let capsule_name = unsafe { capsule_name. unwrap ( ) . as_cstr ( ) . to_str ( ) ? } ;
143- if capsule_name != name {
144- return Err ( PyValueError :: new_err ( format ! (
145- "Expected name '{name}' in PyCapsule, instead got '{capsule_name}'" ,
146- ) ) ) ;
147- }
148-
149- Ok ( ( ) )
150- }
151-
152- fn extract_arrow_c_array_capsules < ' py > (
153- value : & Bound < ' py , PyAny > ,
154- ) -> PyResult < ( Bound < ' py , PyCapsule > , Bound < ' py , PyCapsule > ) > {
155- let tuple = value. call_method0 ( "__arrow_c_array__" ) ?;
156-
157- if !tuple. is_instance_of :: < PyTuple > ( ) {
158- return Err ( PyTypeError :: new_err (
159- "Expected __arrow_c_array__ to return a tuple of (schema, array) capsules." ,
160- ) ) ;
161- }
162-
163- tuple. extract ( ) . map_err ( |_| {
164- PyTypeError :: new_err (
165- "Expected __arrow_c_array__ to return a tuple of (schema, array) capsules." ,
166- )
167- } )
168- }
169-
170135impl FromPyArrow for DataType {
171136 fn from_pyarrow_bound ( value : & Bound < PyAny > ) -> PyResult < Self > {
172137 // Newer versions of PyArrow as well as other libraries with Arrow data implement this
173138 // method, so prefer it over _export_to_c.
174139 // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
175140 if value. hasattr ( "__arrow_c_schema__" ) ? {
176- let capsule = value. call_method0 ( "__arrow_c_schema__" ) ?. extract ( ) ?;
177- validate_pycapsule ( & capsule, "arrow_schema" ) ?;
178-
179- let schema_ptr = capsule
180- . pointer_checked ( Some ( ARROW_SCHEMA_CAPSULE_NAME ) ) ?
181- . cast :: < FFI_ArrowSchema > ( ) ;
141+ let schema_ptr = extract_capsule_from_method :: < FFI_ArrowSchema > (
142+ value,
143+ "__arrow_c_schema__" ,
144+ ARROW_SCHEMA_CAPSULE_NAME ,
145+ ) ?;
182146 return unsafe { DataType :: try_from ( schema_ptr. as_ref ( ) ) } . map_err ( to_py_err) ;
183147 }
184148
@@ -203,12 +167,11 @@ impl FromPyArrow for Field {
203167 // method, so prefer it over _export_to_c.
204168 // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
205169 if value. hasattr ( "__arrow_c_schema__" ) ? {
206- let capsule = value. call_method0 ( "__arrow_c_schema__" ) ?. extract ( ) ?;
207- validate_pycapsule ( & capsule, "arrow_schema" ) ?;
208-
209- let schema_ptr = capsule
210- . pointer_checked ( Some ( ARROW_SCHEMA_CAPSULE_NAME ) ) ?
211- . cast :: < FFI_ArrowSchema > ( ) ;
170+ let schema_ptr = extract_capsule_from_method :: < FFI_ArrowSchema > (
171+ value,
172+ "__arrow_c_schema__" ,
173+ ARROW_SCHEMA_CAPSULE_NAME ,
174+ ) ?;
212175 return unsafe { Field :: try_from ( schema_ptr. as_ref ( ) ) } . map_err ( to_py_err) ;
213176 }
214177
@@ -233,12 +196,11 @@ impl FromPyArrow for Schema {
233196 // method, so prefer it over _export_to_c.
234197 // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
235198 if value. hasattr ( "__arrow_c_schema__" ) ? {
236- let capsule = value. call_method0 ( "__arrow_c_schema__" ) ?. extract ( ) ?;
237- validate_pycapsule ( & capsule, "arrow_schema" ) ?;
238-
239- let schema_ptr = capsule
240- . pointer_checked ( Some ( ARROW_SCHEMA_CAPSULE_NAME ) ) ?
241- . cast :: < FFI_ArrowSchema > ( ) ;
199+ let schema_ptr = extract_capsule_from_method :: < FFI_ArrowSchema > (
200+ value,
201+ "__arrow_c_schema__" ,
202+ ARROW_SCHEMA_CAPSULE_NAME ,
203+ ) ?;
242204 return unsafe { Schema :: try_from ( schema_ptr. as_ref ( ) ) } . map_err ( to_py_err) ;
243205 }
244206
@@ -263,22 +225,14 @@ impl FromPyArrow for ArrayData {
263225 // method, so prefer it over _export_to_c.
264226 // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
265227 if value. hasattr ( "__arrow_c_array__" ) ? {
266- let ( schema_capsule, array_capsule) = extract_arrow_c_array_capsules ( value) ?;
267-
268- validate_pycapsule ( & schema_capsule, "arrow_schema" ) ?;
269- validate_pycapsule ( & array_capsule, "arrow_array" ) ?;
270-
271- let schema_ptr = schema_capsule
272- . pointer_checked ( Some ( ARROW_SCHEMA_CAPSULE_NAME ) ) ?
273- . cast :: < FFI_ArrowSchema > ( ) ;
274- let array = unsafe {
275- FFI_ArrowArray :: from_raw (
276- array_capsule
277- . pointer_checked ( Some ( ARROW_ARRAY_CAPSULE_NAME ) ) ?
278- . cast :: < FFI_ArrowArray > ( )
279- . as_ptr ( ) ,
280- )
281- } ;
228+ let ( schema_ptr, array_ptr) =
229+ extract_capsule_pair_from_method :: < FFI_ArrowSchema , FFI_ArrowArray > (
230+ value,
231+ "__arrow_c_array__" ,
232+ ARROW_SCHEMA_CAPSULE_NAME ,
233+ ARROW_ARRAY_CAPSULE_NAME ,
234+ ) ?;
235+ let array = unsafe { FFI_ArrowArray :: from_raw ( array_ptr. as_ptr ( ) ) } ;
282236 return unsafe { ffi:: from_ffi ( array, schema_ptr. as_ref ( ) ) } . map_err ( to_py_err) ;
283237 }
284238
@@ -341,20 +295,16 @@ impl FromPyArrow for RecordBatch {
341295 // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
342296
343297 if value. hasattr ( "__arrow_c_array__" ) ? {
344- let ( schema_capsule, array_capsule) = extract_arrow_c_array_capsules ( value) ?;
345-
346- validate_pycapsule ( & schema_capsule, "arrow_schema" ) ?;
347- validate_pycapsule ( & array_capsule, "arrow_array" ) ?;
348-
349- let schema_ptr = schema_capsule
350- . pointer_checked ( Some ( ARROW_SCHEMA_CAPSULE_NAME ) ) ?
351- . cast :: < FFI_ArrowSchema > ( ) ;
352- let array_ptr = array_capsule
353- . pointer_checked ( Some ( ARROW_ARRAY_CAPSULE_NAME ) ) ?
354- . cast :: < FFI_ArrowArray > ( ) ;
355- let ffi_array = unsafe { FFI_ArrowArray :: from_raw ( array_ptr. as_ptr ( ) ) } ;
298+ let ( schema_ptr, array_ptr) =
299+ extract_capsule_pair_from_method :: < FFI_ArrowSchema , FFI_ArrowArray > (
300+ value,
301+ "__arrow_c_array__" ,
302+ ARROW_SCHEMA_CAPSULE_NAME ,
303+ ARROW_ARRAY_CAPSULE_NAME ,
304+ ) ?;
305+ let array = unsafe { FFI_ArrowArray :: from_raw ( array_ptr. as_ptr ( ) ) } ;
356306 let mut array_data =
357- unsafe { ffi:: from_ffi ( ffi_array , schema_ptr. as_ref ( ) ) } . map_err ( to_py_err) ?;
307+ unsafe { ffi:: from_ffi ( array , schema_ptr. as_ref ( ) ) } . map_err ( to_py_err) ?;
358308 if !matches ! ( array_data. data_type( ) , DataType :: Struct ( _) ) {
359309 return Err ( PyTypeError :: new_err (
360310 "Expected Struct type from __arrow_c_array." ,
@@ -421,18 +371,12 @@ impl FromPyArrow for ArrowArrayStreamReader {
421371 // method, so prefer it over _export_to_c.
422372 // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
423373 if value. hasattr ( "__arrow_c_stream__" ) ? {
424- let capsule = value. call_method0 ( "__arrow_c_stream__" ) ?. extract ( ) ?;
425-
426- validate_pycapsule ( & capsule, "arrow_array_stream" ) ?;
427-
428- let stream = unsafe {
429- FFI_ArrowArrayStream :: from_raw (
430- capsule
431- . pointer_checked ( Some ( ARROW_ARRAY_STREAM_CAPSULE_NAME ) ) ?
432- . cast :: < FFI_ArrowArrayStream > ( )
433- . as_ptr ( ) ,
434- )
435- } ;
374+ let stream_ptr = extract_capsule_from_method :: < FFI_ArrowArrayStream > (
375+ value,
376+ "__arrow_c_stream__" ,
377+ ARROW_ARRAY_STREAM_CAPSULE_NAME ,
378+ ) ?;
379+ let stream = unsafe { FFI_ArrowArrayStream :: from_raw ( stream_ptr. as_ptr ( ) ) } ;
436380
437381 let stream_reader = ArrowArrayStreamReader :: try_new ( stream)
438382 . map_err ( |err| PyValueError :: new_err ( err. to_string ( ) ) ) ?;
@@ -448,8 +392,7 @@ impl FromPyArrow for ArrowArrayStreamReader {
448392 // make the conversion through PyArrow's private API
449393 // this changes the pointer's memory and is thus unsafe.
450394 // In particular, `_export_to_c` can go out of bounds
451- let args = PyTuple :: new ( value. py ( ) , [ & raw mut stream as Py_uintptr_t ] ) ?;
452- value. call_method1 ( "_export_to_c" , args) ?;
395+ value. call_method1 ( "_export_to_c" , ( & raw mut stream as Py_uintptr_t , ) ) ?;
453396
454397 ArrowArrayStreamReader :: try_new ( stream)
455398 . map_err ( |err| PyValueError :: new_err ( err. to_string ( ) ) )
@@ -631,3 +574,61 @@ impl<T> From<T> for PyArrowType<T> {
631574 Self ( s)
632575 }
633576}
577+
578+ fn extract_capsule_from_method < T > (
579+ object : & Bound < ' _ , PyAny > ,
580+ method_name : & ' static str ,
581+ capsule_name : & ' static CStr ,
582+ ) -> PyResult < NonNull < T > > {
583+ ( || {
584+ Ok ( object
585+ . call_method0 ( method_name) ?
586+ . extract :: < Bound < ' _ , PyCapsule > > ( ) ?
587+ . pointer_checked ( Some ( capsule_name) ) ?
588+ . cast :: < T > ( ) )
589+ } ) ( )
590+ . map_err ( |e| {
591+ wrapping_type_error (
592+ object. py ( ) ,
593+ e,
594+ format ! (
595+ "Expected {method_name} to return a {} capsule." ,
596+ capsule_name. to_str( ) . unwrap( ) ,
597+ ) ,
598+ )
599+ } )
600+ }
601+
602+ fn extract_capsule_pair_from_method < T1 , T2 > (
603+ object : & Bound < ' _ , PyAny > ,
604+ method_name : & ' static str ,
605+ capsule1_name : & ' static CStr ,
606+ capsule2_name : & ' static CStr ,
607+ ) -> PyResult < ( NonNull < T1 > , NonNull < T2 > ) > {
608+ ( || {
609+ let ( c1, c2) = object
610+ . call_method0 ( method_name) ?
611+ . extract :: < ( Bound < ' _ , PyCapsule > , Bound < ' _ , PyCapsule > ) > ( ) ?;
612+ Ok ( (
613+ c1. pointer_checked ( Some ( capsule1_name) ) ?. cast :: < T1 > ( ) ,
614+ c2. pointer_checked ( Some ( capsule2_name) ) ?. cast :: < T2 > ( ) ,
615+ ) )
616+ } ) ( )
617+ . map_err ( |e| {
618+ wrapping_type_error (
619+ object. py ( ) ,
620+ e,
621+ format ! (
622+ "Expected {method_name} to return a tuple of ({}, {}) capsules." ,
623+ capsule1_name. to_str( ) . unwrap( ) ,
624+ capsule2_name. to_str( ) . unwrap( )
625+ ) ,
626+ )
627+ } )
628+ }
629+
630+ fn wrapping_type_error ( py : Python < ' _ > , error : PyErr , message : String ) -> PyErr {
631+ let e = PyTypeError :: new_err ( message) ;
632+ e. set_cause ( py, Some ( error) ) ;
633+ e
634+ }
0 commit comments