@@ -75,16 +75,15 @@ use arrow_data::ArrayData;
7575use arrow_schema:: { ArrowError , DataType , Field , Schema , SchemaRef } ;
7676use pyo3:: exceptions:: { PyTypeError , PyValueError } ;
7777use pyo3:: ffi:: Py_uintptr_t ;
78- use pyo3:: import_exception;
7978use pyo3:: prelude:: * ;
8079use pyo3:: sync:: PyOnceLock ;
8180use pyo3:: types:: { PyCapsule , PyDict , PyList , PyType } ;
81+ use pyo3:: { CastError , import_exception} ;
8282
8383import_exception ! ( pyarrow, ArrowException ) ;
8484/// Represents an exception raised by PyArrow.
8585pub type PyArrowException = ArrowException ;
8686
87-
8887fn to_py_err ( err : ArrowError ) -> PyErr {
8988 PyArrowException :: new_err ( err. to_string ( ) )
9089}
@@ -135,10 +134,8 @@ impl FromPyArrow for DataType {
135134 // method, so prefer it over _export_to_c.
136135 // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
137136 if value. hasattr ( "__arrow_c_schema__" ) ? {
138- let schema_ptr = extract_capsule_from_method :: < FFI_ArrowSchema > (
139- value,
140- "__arrow_c_schema__" ,
141- ) ?;
137+ let capsule = call_capsule_method ( value, "__arrow_c_schema__" ) ?;
138+ let schema_ptr = extract_capsule :: < FFI_ArrowSchema > ( & capsule, "__arrow_c_schema__" ) ?;
142139 return unsafe { DataType :: try_from ( schema_ptr. as_ref ( ) ) } . map_err ( to_py_err) ;
143140 }
144141
@@ -163,10 +160,8 @@ impl FromPyArrow for Field {
163160 // method, so prefer it over _export_to_c.
164161 // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
165162 if value. hasattr ( "__arrow_c_schema__" ) ? {
166- let schema_ptr = extract_capsule_from_method :: < FFI_ArrowSchema > (
167- value,
168- "__arrow_c_schema__" ,
169- ) ?;
163+ let capsule = call_capsule_method ( value, "__arrow_c_schema__" ) ?;
164+ let schema_ptr = extract_capsule :: < FFI_ArrowSchema > ( & capsule, "__arrow_c_schema__" ) ?;
170165 return unsafe { Field :: try_from ( schema_ptr. as_ref ( ) ) } . map_err ( to_py_err) ;
171166 }
172167
@@ -191,10 +186,8 @@ impl FromPyArrow for Schema {
191186 // method, so prefer it over _export_to_c.
192187 // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
193188 if value. hasattr ( "__arrow_c_schema__" ) ? {
194- let schema_ptr = extract_capsule_from_method :: < FFI_ArrowSchema > (
195- value,
196- "__arrow_c_schema__" ,
197- ) ?;
189+ let capsule = call_capsule_method ( value, "__arrow_c_schema__" ) ?;
190+ let schema_ptr = extract_capsule :: < FFI_ArrowSchema > ( & capsule, "__arrow_c_schema__" ) ?;
198191 return unsafe { Schema :: try_from ( schema_ptr. as_ref ( ) ) } . map_err ( to_py_err) ;
199192 }
200193
@@ -219,11 +212,10 @@ impl FromPyArrow for ArrayData {
219212 // method, so prefer it over _export_to_c.
220213 // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
221214 if value. hasattr ( "__arrow_c_array__" ) ? {
222- let ( schema_ptr, array_ptr) =
223- extract_capsule_pair_from_method :: < FFI_ArrowSchema , FFI_ArrowArray > (
224- value,
225- "__arrow_c_array__" ,
226- ) ?;
215+ let ( schema_capsule, array_capsule) =
216+ call_capsule_pair_method ( value, "__arrow_c_array__" ) ?;
217+ let schema_ptr = extract_capsule ( & schema_capsule, "__arrow_c_array__" ) ?;
218+ let array_ptr = extract_capsule ( & array_capsule, "__arrow_c_array__" ) ?;
227219 let array = unsafe { FFI_ArrowArray :: from_raw ( array_ptr. as_ptr ( ) ) } ;
228220 return unsafe { ffi:: from_ffi ( array, schema_ptr. as_ref ( ) ) } . map_err ( to_py_err) ;
229221 }
@@ -287,18 +279,18 @@ impl FromPyArrow for RecordBatch {
287279 // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
288280
289281 if value. hasattr ( "__arrow_c_array__" ) ? {
290- let ( schema_ptr, array_ptr) =
291- extract_capsule_pair_from_method :: < FFI_ArrowSchema , FFI_ArrowArray > (
292- value,
293- "__arrow_c_array__" ,
294- ) ?;
282+ let ( schema_capsule, array_capsule) =
283+ call_capsule_pair_method ( value, "__arrow_c_array__" ) ?;
284+ let schema_ptr = extract_capsule ( & schema_capsule, "__arrow_c_array__" ) ?;
285+ let array_ptr = extract_capsule ( & array_capsule, "__arrow_c_array__" ) ?;
295286 let ffi_array = unsafe { FFI_ArrowArray :: from_raw ( array_ptr. as_ptr ( ) ) } ;
296287 let mut array_data =
297288 unsafe { ffi:: from_ffi ( ffi_array, schema_ptr. as_ref ( ) ) } . map_err ( to_py_err) ?;
298289 if !matches ! ( array_data. data_type( ) , DataType :: Struct ( _) ) {
299- return Err ( PyTypeError :: new_err (
300- format ! ( "Expected Struct type from __arrow_c_array__, found {}." , array_data. data_type( ) ) ,
301- ) ) ;
290+ return Err ( PyTypeError :: new_err ( format ! (
291+ "Expected Struct type from __arrow_c_array__, found {}." ,
292+ array_data. data_type( )
293+ ) ) ) ;
302294 }
303295 let options = RecordBatchOptions :: default ( ) . with_row_count ( Some ( array_data. len ( ) ) ) ;
304296 // Ensure data is aligned (by potentially copying the buffers).
@@ -361,10 +353,8 @@ impl FromPyArrow for ArrowArrayStreamReader {
361353 // method, so prefer it over _export_to_c.
362354 // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
363355 if value. hasattr ( "__arrow_c_stream__" ) ? {
364- let stream_ptr = extract_capsule_from_method :: < FFI_ArrowArrayStream > (
365- value,
366- "__arrow_c_stream__" ,
367- ) ?;
356+ let capsule = call_capsule_method ( value, "__arrow_c_stream__" ) ?;
357+ let stream_ptr = extract_capsule ( & capsule, "__arrow_c_stream__" ) ?;
368358 let stream = unsafe { FFI_ArrowArrayStream :: from_raw ( stream_ptr. as_ptr ( ) ) } ;
369359
370360 let stream_reader = ArrowArrayStreamReader :: try_new ( stream)
@@ -564,6 +554,35 @@ impl<T> From<T> for PyArrowType<T> {
564554 }
565555}
566556
557+ fn call_capsule_method < ' py > (
558+ object : & Bound < ' py , PyAny > ,
559+ method_name : & ' static str ,
560+ ) -> PyResult < Bound < ' py , PyCapsule > > {
561+ object
562+ . call_method0 ( method_name) ?
563+ . extract ( )
564+ . map_err ( |e : CastError | {
565+ wrapping_type_error (
566+ object. py ( ) ,
567+ e. into ( ) ,
568+ format ! ( "Expected {method_name} to return a capsule." ) ,
569+ )
570+ } )
571+ }
572+
573+ fn call_capsule_pair_method < ' py > (
574+ object : & Bound < ' py , PyAny > ,
575+ method_name : & ' static str ,
576+ ) -> PyResult < ( Bound < ' py , PyCapsule > , Bound < ' py , PyCapsule > ) > {
577+ object. call_method0 ( method_name) ?. extract ( ) . map_err ( |e| {
578+ wrapping_type_error (
579+ object. py ( ) ,
580+ e,
581+ format ! ( "Expected {method_name} to return a pair of capsule." ) ,
582+ )
583+ } )
584+ }
585+
/// Associates an FFI struct type with the capsule name under which it is
/// exchanged.
trait PyCapsuleType {
    /// Capsule name for this type, as a C string.
    const NAME: &'static CStr;
}
@@ -580,53 +599,23 @@ impl PyCapsuleType for FFI_ArrowArrayStream {
580599 const NAME : & CStr = c"arrow_array_stream" ;
581600}
582601
583- fn extract_capsule_from_method < T : PyCapsuleType > (
584- object : & Bound < ' _ , PyAny > ,
602+ fn extract_capsule < T : PyCapsuleType > (
603+ capsule : & Bound < PyCapsule > ,
585604 method_name : & ' static str ,
586605) -> PyResult < NonNull < T > > {
587- ( || {
588- Ok ( object
589- . call_method0 ( method_name) ?
590- . extract :: < Bound < ' _ , PyCapsule > > ( ) ?
591- . pointer_checked ( Some ( T :: NAME ) ) ?
592- . cast :: < T > ( ) )
593- } ) ( )
594- . map_err ( |e| {
595- wrapping_type_error (
596- object. py ( ) ,
597- e,
598- format ! (
599- "Expected {method_name} to return a {} capsule." ,
600- T :: NAME . to_str( ) . unwrap( ) ,
601- ) ,
602- )
603- } )
604- }
605-
606- fn extract_capsule_pair_from_method < T1 : PyCapsuleType , T2 : PyCapsuleType > (
607- object : & Bound < ' _ , PyAny > ,
608- method_name : & ' static str ,
609- ) -> PyResult < ( NonNull < T1 > , NonNull < T2 > ) > {
610- ( || {
611- let ( c1, c2) = object
612- . call_method0 ( method_name) ?
613- . extract :: < ( Bound < ' _ , PyCapsule > , Bound < ' _ , PyCapsule > ) > ( ) ?;
614- Ok ( (
615- c1. pointer_checked ( Some ( T1 :: NAME ) ) ?. cast :: < T1 > ( ) ,
616- c2. pointer_checked ( Some ( T2 :: NAME ) ) ?. cast :: < T2 > ( ) ,
617- ) )
618- } ) ( )
619- . map_err ( |e| {
620- wrapping_type_error (
621- object. py ( ) ,
622- e,
623- format ! (
624- "Expected {method_name} to return a tuple of ({}, {}) capsules." ,
625- T1 :: NAME . to_str( ) . unwrap( ) ,
626- T2 :: NAME . to_str( ) . unwrap( )
627- ) ,
628- )
629- } )
606+ Ok ( capsule
607+ . pointer_checked ( Some ( T :: NAME ) )
608+ . map_err ( |e| {
609+ wrapping_type_error (
610+ capsule. py ( ) ,
611+ e,
612+ format ! (
613+ "Expected {method_name} to return a {} capsule." ,
614+ T :: NAME . to_str( ) . unwrap( ) ,
615+ ) ,
616+ )
617+ } ) ?
618+ . cast :: < T > ( ) )
630619}
631620
632621fn wrapping_type_error ( py : Python < ' _ > , error : PyErr , message : String ) -> PyErr {
0 commit comments