6161
6262use std:: convert:: { From , TryFrom } ;
6363use std:: ffi:: CStr ;
64+ use std:: ptr:: NonNull ;
6465use std:: sync:: Arc ;
6566
6667use arrow_array:: ffi;
@@ -77,15 +78,12 @@ use pyo3::ffi::Py_uintptr_t;
7778use pyo3:: import_exception;
7879use pyo3:: prelude:: * ;
7980use pyo3:: sync:: PyOnceLock ;
80- use pyo3:: types:: { PyCapsule , PyDict , PyList , PyTuple , PyType } ;
81+ use pyo3:: types:: { PyCapsule , PyDict , PyList , PyType } ;
8182
8283import_exception ! ( pyarrow, ArrowException ) ;
8384/// Represents an exception raised by PyArrow.
8485pub type PyArrowException = ArrowException ;
8586
86- const ARROW_ARRAY_STREAM_CAPSULE_NAME : & CStr = c"arrow_array_stream" ;
87- const ARROW_SCHEMA_CAPSULE_NAME : & CStr = c"arrow_schema" ;
88- const ARROW_ARRAY_CAPSULE_NAME : & CStr = c"arrow_array" ;
8987
9088fn to_py_err ( err : ArrowError ) -> PyErr {
9189 PyArrowException :: new_err ( err. to_string ( ) )
@@ -131,54 +129,16 @@ fn validate_class(expected: &Bound<PyType>, value: &Bound<PyAny>) -> PyResult<()
131129 Ok ( ( ) )
132130}
133131
134- fn validate_pycapsule ( capsule : & Bound < PyCapsule > , name : & str ) -> PyResult < ( ) > {
135- let capsule_name = capsule. name ( ) ?;
136- if capsule_name. is_none ( ) {
137- return Err ( PyValueError :: new_err (
138- "Expected schema PyCapsule to have name set." ,
139- ) ) ;
140- }
141-
142- let capsule_name = unsafe { capsule_name. unwrap ( ) . as_cstr ( ) . to_str ( ) ? } ;
143- if capsule_name != name {
144- return Err ( PyValueError :: new_err ( format ! (
145- "Expected name '{name}' in PyCapsule, instead got '{capsule_name}'" ,
146- ) ) ) ;
147- }
148-
149- Ok ( ( ) )
150- }
151-
152- fn extract_arrow_c_array_capsules < ' py > (
153- value : & Bound < ' py , PyAny > ,
154- ) -> PyResult < ( Bound < ' py , PyCapsule > , Bound < ' py , PyCapsule > ) > {
155- let tuple = value. call_method0 ( "__arrow_c_array__" ) ?;
156-
157- if !tuple. is_instance_of :: < PyTuple > ( ) {
158- return Err ( PyTypeError :: new_err (
159- "Expected __arrow_c_array__ to return a tuple of (schema, array) capsules." ,
160- ) ) ;
161- }
162-
163- tuple. extract ( ) . map_err ( |_| {
164- PyTypeError :: new_err (
165- "Expected __arrow_c_array__ to return a tuple of (schema, array) capsules." ,
166- )
167- } )
168- }
169-
170132impl FromPyArrow for DataType {
171133 fn from_pyarrow_bound ( value : & Bound < PyAny > ) -> PyResult < Self > {
172134 // Newer versions of PyArrow as well as other libraries with Arrow data implement this
173135 // method, so prefer it over _export_to_c.
174136 // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
175137 if value. hasattr ( "__arrow_c_schema__" ) ? {
176- let capsule = value. call_method0 ( "__arrow_c_schema__" ) ?. extract ( ) ?;
177- validate_pycapsule ( & capsule, "arrow_schema" ) ?;
178-
179- let schema_ptr = capsule
180- . pointer_checked ( Some ( ARROW_SCHEMA_CAPSULE_NAME ) ) ?
181- . cast :: < FFI_ArrowSchema > ( ) ;
138+ let schema_ptr = extract_capsule_from_method :: < FFI_ArrowSchema > (
139+ value,
140+ "__arrow_c_schema__" ,
141+ ) ?;
182142 return unsafe { DataType :: try_from ( schema_ptr. as_ref ( ) ) } . map_err ( to_py_err) ;
183143 }
184144
@@ -203,12 +163,10 @@ impl FromPyArrow for Field {
203163 // method, so prefer it over _export_to_c.
204164 // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
205165 if value. hasattr ( "__arrow_c_schema__" ) ? {
206- let capsule = value. call_method0 ( "__arrow_c_schema__" ) ?. extract ( ) ?;
207- validate_pycapsule ( & capsule, "arrow_schema" ) ?;
208-
209- let schema_ptr = capsule
210- . pointer_checked ( Some ( ARROW_SCHEMA_CAPSULE_NAME ) ) ?
211- . cast :: < FFI_ArrowSchema > ( ) ;
166+ let schema_ptr = extract_capsule_from_method :: < FFI_ArrowSchema > (
167+ value,
168+ "__arrow_c_schema__" ,
169+ ) ?;
212170 return unsafe { Field :: try_from ( schema_ptr. as_ref ( ) ) } . map_err ( to_py_err) ;
213171 }
214172
@@ -233,12 +191,10 @@ impl FromPyArrow for Schema {
233191 // method, so prefer it over _export_to_c.
234192 // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
235193 if value. hasattr ( "__arrow_c_schema__" ) ? {
236- let capsule = value. call_method0 ( "__arrow_c_schema__" ) ?. extract ( ) ?;
237- validate_pycapsule ( & capsule, "arrow_schema" ) ?;
238-
239- let schema_ptr = capsule
240- . pointer_checked ( Some ( ARROW_SCHEMA_CAPSULE_NAME ) ) ?
241- . cast :: < FFI_ArrowSchema > ( ) ;
194+ let schema_ptr = extract_capsule_from_method :: < FFI_ArrowSchema > (
195+ value,
196+ "__arrow_c_schema__" ,
197+ ) ?;
242198 return unsafe { Schema :: try_from ( schema_ptr. as_ref ( ) ) } . map_err ( to_py_err) ;
243199 }
244200
@@ -263,22 +219,12 @@ impl FromPyArrow for ArrayData {
263219 // method, so prefer it over _export_to_c.
264220 // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
265221 if value. hasattr ( "__arrow_c_array__" ) ? {
266- let ( schema_capsule, array_capsule) = extract_arrow_c_array_capsules ( value) ?;
267-
268- validate_pycapsule ( & schema_capsule, "arrow_schema" ) ?;
269- validate_pycapsule ( & array_capsule, "arrow_array" ) ?;
270-
271- let schema_ptr = schema_capsule
272- . pointer_checked ( Some ( ARROW_SCHEMA_CAPSULE_NAME ) ) ?
273- . cast :: < FFI_ArrowSchema > ( ) ;
274- let array = unsafe {
275- FFI_ArrowArray :: from_raw (
276- array_capsule
277- . pointer_checked ( Some ( ARROW_ARRAY_CAPSULE_NAME ) ) ?
278- . cast :: < FFI_ArrowArray > ( )
279- . as_ptr ( ) ,
280- )
281- } ;
222+ let ( schema_ptr, array_ptr) =
223+ extract_capsule_pair_from_method :: < FFI_ArrowSchema , FFI_ArrowArray > (
224+ value,
225+ "__arrow_c_array__" ,
226+ ) ?;
227+ let array = unsafe { FFI_ArrowArray :: from_raw ( array_ptr. as_ptr ( ) ) } ;
282228 return unsafe { ffi:: from_ffi ( array, schema_ptr. as_ref ( ) ) } . map_err ( to_py_err) ;
283229 }
284230
@@ -341,23 +287,17 @@ impl FromPyArrow for RecordBatch {
341287 // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
342288
343289 if value. hasattr ( "__arrow_c_array__" ) ? {
344- let ( schema_capsule, array_capsule) = extract_arrow_c_array_capsules ( value) ?;
345-
346- validate_pycapsule ( & schema_capsule, "arrow_schema" ) ?;
347- validate_pycapsule ( & array_capsule, "arrow_array" ) ?;
348-
349- let schema_ptr = schema_capsule
350- . pointer_checked ( Some ( ARROW_SCHEMA_CAPSULE_NAME ) ) ?
351- . cast :: < FFI_ArrowSchema > ( ) ;
352- let array_ptr = array_capsule
353- . pointer_checked ( Some ( ARROW_ARRAY_CAPSULE_NAME ) ) ?
354- . cast :: < FFI_ArrowArray > ( ) ;
290+ let ( schema_ptr, array_ptr) =
291+ extract_capsule_pair_from_method :: < FFI_ArrowSchema , FFI_ArrowArray > (
292+ value,
293+ "__arrow_c_array__" ,
294+ ) ?;
355295 let ffi_array = unsafe { FFI_ArrowArray :: from_raw ( array_ptr. as_ptr ( ) ) } ;
356296 let mut array_data =
357297 unsafe { ffi:: from_ffi ( ffi_array, schema_ptr. as_ref ( ) ) } . map_err ( to_py_err) ?;
358298 if !matches ! ( array_data. data_type( ) , DataType :: Struct ( _) ) {
359299 return Err ( PyTypeError :: new_err (
360- "Expected Struct type from __arrow_c_array." ,
300+ format ! ( "Expected Struct type from __arrow_c_array__, found {}." , array_data . data_type ( ) ) ,
361301 ) ) ;
362302 }
363303 let options = RecordBatchOptions :: default ( ) . with_row_count ( Some ( array_data. len ( ) ) ) ;
@@ -421,18 +361,11 @@ impl FromPyArrow for ArrowArrayStreamReader {
421361 // method, so prefer it over _export_to_c.
422362 // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
423363 if value. hasattr ( "__arrow_c_stream__" ) ? {
424- let capsule = value. call_method0 ( "__arrow_c_stream__" ) ?. extract ( ) ?;
425-
426- validate_pycapsule ( & capsule, "arrow_array_stream" ) ?;
427-
428- let stream = unsafe {
429- FFI_ArrowArrayStream :: from_raw (
430- capsule
431- . pointer_checked ( Some ( ARROW_ARRAY_STREAM_CAPSULE_NAME ) ) ?
432- . cast :: < FFI_ArrowArrayStream > ( )
433- . as_ptr ( ) ,
434- )
435- } ;
364+ let stream_ptr = extract_capsule_from_method :: < FFI_ArrowArrayStream > (
365+ value,
366+ "__arrow_c_stream__" ,
367+ ) ?;
368+ let stream = unsafe { FFI_ArrowArrayStream :: from_raw ( stream_ptr. as_ptr ( ) ) } ;
436369
437370 let stream_reader = ArrowArrayStreamReader :: try_new ( stream)
438371 . map_err ( |err| PyValueError :: new_err ( err. to_string ( ) ) ) ?;
@@ -448,8 +381,7 @@ impl FromPyArrow for ArrowArrayStreamReader {
448381 // make the conversion through PyArrow's private API
449382 // this changes the pointer's memory and is thus unsafe.
450383 // In particular, `_export_to_c` can go out of bounds
451- let args = PyTuple :: new ( value. py ( ) , [ & raw mut stream as Py_uintptr_t ] ) ?;
452- value. call_method1 ( "_export_to_c" , args) ?;
384+ value. call_method1 ( "_export_to_c" , ( & raw mut stream as Py_uintptr_t , ) ) ?;
453385
454386 ArrowArrayStreamReader :: try_new ( stream)
455387 . map_err ( |err| PyValueError :: new_err ( err. to_string ( ) ) )
@@ -631,3 +563,74 @@ impl<T> From<T> for PyArrowType<T> {
631563 Self ( s)
632564 }
633565}
566+
567+ trait PyCapsuleType {
568+ const NAME : & CStr ;
569+ }
570+
571+ impl PyCapsuleType for FFI_ArrowSchema {
572+ const NAME : & CStr = c"arrow_schema" ;
573+ }
574+
575+ impl PyCapsuleType for FFI_ArrowArray {
576+ const NAME : & CStr = c"arrow_array" ;
577+ }
578+
579+ impl PyCapsuleType for FFI_ArrowArrayStream {
580+ const NAME : & CStr = c"arrow_array_stream" ;
581+ }
582+
583+ fn extract_capsule_from_method < T : PyCapsuleType > (
584+ object : & Bound < ' _ , PyAny > ,
585+ method_name : & ' static str ,
586+ ) -> PyResult < NonNull < T > > {
587+ ( || {
588+ Ok ( object
589+ . call_method0 ( method_name) ?
590+ . extract :: < Bound < ' _ , PyCapsule > > ( ) ?
591+ . pointer_checked ( Some ( T :: NAME ) ) ?
592+ . cast :: < T > ( ) )
593+ } ) ( )
594+ . map_err ( |e| {
595+ wrapping_type_error (
596+ object. py ( ) ,
597+ e,
598+ format ! (
599+ "Expected {method_name} to return a {} capsule." ,
600+ T :: NAME . to_str( ) . unwrap( ) ,
601+ ) ,
602+ )
603+ } )
604+ }
605+
606+ fn extract_capsule_pair_from_method < T1 : PyCapsuleType , T2 : PyCapsuleType > (
607+ object : & Bound < ' _ , PyAny > ,
608+ method_name : & ' static str ,
609+ ) -> PyResult < ( NonNull < T1 > , NonNull < T2 > ) > {
610+ ( || {
611+ let ( c1, c2) = object
612+ . call_method0 ( method_name) ?
613+ . extract :: < ( Bound < ' _ , PyCapsule > , Bound < ' _ , PyCapsule > ) > ( ) ?;
614+ Ok ( (
615+ c1. pointer_checked ( Some ( T1 :: NAME ) ) ?. cast :: < T1 > ( ) ,
616+ c2. pointer_checked ( Some ( T2 :: NAME ) ) ?. cast :: < T2 > ( ) ,
617+ ) )
618+ } ) ( )
619+ . map_err ( |e| {
620+ wrapping_type_error (
621+ object. py ( ) ,
622+ e,
623+ format ! (
624+ "Expected {method_name} to return a tuple of ({}, {}) capsules." ,
625+ T1 :: NAME . to_str( ) . unwrap( ) ,
626+ T2 :: NAME . to_str( ) . unwrap( )
627+ ) ,
628+ )
629+ } )
630+ }
631+
632+ fn wrapping_type_error ( py : Python < ' _ > , error : PyErr , message : String ) -> PyErr {
633+ let e = PyTypeError :: new_err ( message) ;
634+ e. set_cause ( py, Some ( error) ) ;
635+ e
636+ }
0 commit comments