@@ -89,6 +89,40 @@ pub fn union_extract(union_array: &UnionArray, target: &str) -> Result<ArrayRef,
8989 ArrowError :: InvalidArgumentError ( format ! ( "field {target} not found on union" ) )
9090 } ) ?;
9191
92+ union_extract_impl ( union_array, fields, target_type_id)
93+ }
94+
95+ /// Like [`union_extract`], but selects the child by `type_id` rather than by
96+ /// field name.
97+ ///
98+ /// This avoids ambiguity when the union contains duplicate field names.
99+ ///
100+ /// # Errors
101+ ///
102+ /// Returns error if `target_type_id` does not correspond to a field in the union.
103+ pub fn union_extract_by_id (
104+ union_array : & UnionArray ,
105+ target_type_id : i8 ,
106+ ) -> Result < ArrayRef , ArrowError > {
107+ let fields = match union_array. data_type ( ) {
108+ DataType :: Union ( fields, _) => fields,
109+ _ => unreachable ! ( ) ,
110+ } ;
111+
112+ if fields. iter ( ) . all ( |( id, _) | id != target_type_id) {
113+ return Err ( ArrowError :: InvalidArgumentError ( format ! (
114+ "type_id {target_type_id} not found on union"
115+ ) ) ) ;
116+ }
117+
118+ union_extract_impl ( union_array, fields, target_type_id)
119+ }
120+
121+ fn union_extract_impl (
122+ union_array : & UnionArray ,
123+ fields : & UnionFields ,
124+ target_type_id : i8 ,
125+ ) -> Result < ArrayRef , ArrowError > {
92126 match union_array. offsets ( ) {
93127 Some ( _) => extract_dense ( union_array, fields, target_type_id) ,
94128 None => extract_sparse ( union_array, fields, target_type_id) ,
@@ -399,7 +433,9 @@ fn is_sequential_generic<const N: usize>(offsets: &[i32]) -> bool {
399433
400434#[ cfg( test) ]
401435mod tests {
402- use super :: { BoolValue , eq_scalar_inner, is_sequential_generic, union_extract} ;
436+ use super :: {
437+ BoolValue , eq_scalar_inner, is_sequential_generic, union_extract, union_extract_by_id,
438+ } ;
403439 use arrow_array:: { Array , Int32Array , NullArray , StringArray , UnionArray , new_null_array} ;
404440 use arrow_buffer:: { BooleanBuffer , ScalarBuffer } ;
405441 use arrow_schema:: { ArrowError , DataType , Field , UnionFields , UnionMode } ;
@@ -1236,4 +1272,101 @@ mod tests {
12361272 ArrowError :: InvalidArgumentError ( "field a not found on union" . into( ) ) . to_string( )
12371273 ) ;
12381274 }
1275+
1276+ #[ test]
1277+ fn extract_by_id_sparse_duplicate_names ( ) {
1278+ // Two fields with the same name "val" but different type_ids and types
1279+ let fields = UnionFields :: try_new (
1280+ [ 0 , 1 ] ,
1281+ [
1282+ Field :: new ( "val" , DataType :: Int32 , true ) ,
1283+ Field :: new ( "val" , DataType :: Utf8 , true ) ,
1284+ ] ,
1285+ )
1286+ . unwrap ( ) ;
1287+
1288+ let union = UnionArray :: try_new (
1289+ fields,
1290+ vec ! [ 0_i8 , 1 , 0 , 1 ] . into ( ) ,
1291+ None ,
1292+ vec ! [
1293+ Arc :: new( Int32Array :: from( vec![ Some ( 42 ) , None , Some ( 99 ) , None ] ) ) as _,
1294+ Arc :: new( StringArray :: from( vec![ None , Some ( "hello" ) , None , Some ( "world" ) ] ) ) ,
1295+ ] ,
1296+ )
1297+ . unwrap ( ) ;
1298+
1299+ // union_extract by name always returns type_id 0 (first match)
1300+ let by_name = union_extract ( & union, "val" ) . unwrap ( ) ;
1301+ assert_eq ! ( * by_name, Int32Array :: from( vec![ Some ( 42 ) , None , Some ( 99 ) , None ] ) ) ;
1302+
1303+ // union_extract_by_id can select type_id 1 (the Utf8 child)
1304+ let by_id = union_extract_by_id ( & union, 1 ) . unwrap ( ) ;
1305+ assert_eq ! (
1306+ * by_id,
1307+ StringArray :: from( vec![ None , Some ( "hello" ) , None , Some ( "world" ) ] )
1308+ ) ;
1309+ }
1310+
1311+ #[ test]
1312+ fn extract_by_id_dense_duplicate_names ( ) {
1313+ let fields = UnionFields :: try_new (
1314+ [ 0 , 1 ] ,
1315+ [
1316+ Field :: new ( "val" , DataType :: Int32 , true ) ,
1317+ Field :: new ( "val" , DataType :: Utf8 , true ) ,
1318+ ] ,
1319+ )
1320+ . unwrap ( ) ;
1321+
1322+ let union = UnionArray :: try_new (
1323+ fields,
1324+ vec ! [ 0_i8 , 1 , 0 ] . into ( ) ,
1325+ Some ( vec ! [ 0_i32 , 0 , 1 ] . into ( ) ) ,
1326+ vec ! [
1327+ Arc :: new( Int32Array :: from( vec![ Some ( 42 ) , Some ( 99 ) ] ) ) as _,
1328+ Arc :: new( StringArray :: from( vec![ Some ( "hello" ) ] ) ) ,
1329+ ] ,
1330+ )
1331+ . unwrap ( ) ;
1332+
1333+ // by type_id 0 → Int32 child
1334+ let by_id_0 = union_extract_by_id ( & union, 0 ) . unwrap ( ) ;
1335+ assert_eq ! ( * by_id_0, Int32Array :: from( vec![ Some ( 42 ) , None , Some ( 99 ) ] ) ) ;
1336+
1337+ // by type_id 1 → Utf8 child
1338+ let by_id_1 = union_extract_by_id ( & union, 1 ) . unwrap ( ) ;
1339+ assert_eq ! (
1340+ * by_id_1,
1341+ StringArray :: from( vec![ None , Some ( "hello" ) , None ] )
1342+ ) ;
1343+ }
1344+
1345+ #[ test]
1346+ fn extract_by_id_not_found ( ) {
1347+ let fields = UnionFields :: try_new (
1348+ [ 0 , 1 ] ,
1349+ [
1350+ Field :: new ( "a" , DataType :: Int32 , true ) ,
1351+ Field :: new ( "b" , DataType :: Utf8 , true ) ,
1352+ ] ,
1353+ )
1354+ . unwrap ( ) ;
1355+
1356+ let union = UnionArray :: try_new (
1357+ fields,
1358+ vec ! [ 0_i8 , 1 ] . into ( ) ,
1359+ None ,
1360+ vec ! [
1361+ Arc :: new( Int32Array :: from( vec![ Some ( 1 ) , None ] ) ) as _,
1362+ Arc :: new( StringArray :: from( vec![ None , Some ( "x" ) ] ) ) ,
1363+ ] ,
1364+ )
1365+ . unwrap ( ) ;
1366+
1367+ assert_eq ! (
1368+ union_extract_by_id( & union , 5 ) . unwrap_err( ) . to_string( ) ,
1369+ ArrowError :: InvalidArgumentError ( "type_id 5 not found on union" . into( ) ) . to_string( )
1370+ ) ;
1371+ }
12391372}
0 commit comments