Skip to content

Commit ec614b0

Browse files
avoid ambiguity when extracting union children
1 parent 9c6b42e commit ec614b0

File tree

2 files changed

+195
-6
lines changed

2 files changed

+195
-6
lines changed

arrow-cast/src/cast/union.rs

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use crate::cast::can_cast_types;
2121
use crate::cast_with_options;
2222
use arrow_array::{Array, ArrayRef, UnionArray};
2323
use arrow_schema::{ArrowError, DataType, FieldRef, UnionFields};
24-
use arrow_select::union_extract::union_extract;
24+
use arrow_select::union_extract::union_extract_by_id;
2525

2626
use super::CastOptions;
2727

@@ -64,7 +64,7 @@ fn same_type_family(a: &DataType, b: &DataType) -> bool {
6464
pub(crate) fn resolve_child_array<'a>(
6565
fields: &'a UnionFields,
6666
target_type: &DataType,
67-
) -> Option<&'a FieldRef> {
67+
) -> Option<(i8, &'a FieldRef)> {
6868
fields
6969
.iter()
7070
.find(|(_, f)| f.data_type() == target_type)
@@ -84,7 +84,6 @@ pub(crate) fn resolve_child_array<'a>(
8484
.iter()
8585
.find(|(_, f)| can_cast_types(f.data_type(), target_type))
8686
})
87-
.map(|(_, f)| f)
8887
}
8988

9089
/// Extracts the best-matching child array from a [`UnionArray`] for a given target type,
@@ -137,7 +136,7 @@ pub fn union_extract_by_type(
137136
_ => unreachable!("union_extract_by_type called on non-union array"),
138137
};
139138

140-
let Some(field) = resolve_child_array(fields, target_type) else {
139+
let Some((type_id, _)) = resolve_child_array(fields, target_type) else {
141140
return Err(ArrowError::CastError(format!(
142141
"cannot cast Union with fields {} to {}",
143142
fields
@@ -149,7 +148,7 @@ pub fn union_extract_by_type(
149148
)));
150149
};
151150

152-
let extracted = union_extract(union_array, field.name())?;
151+
let extracted = union_extract_by_id(union_array, type_id)?;
153152

154153
if extracted.data_type() == target_type {
155154
return Ok(extracted);
@@ -355,6 +354,63 @@ mod tests {
355354
assert!(!arr.value(2));
356355
}
357356

357+
// duplicate field names: ensure we resolve by type_id, not field name.
358+
// Union has two children both named "val" — Int32 (type_id 0) and Utf8 (type_id 1).
359+
// Casting to Utf8 should select the Utf8 child (type_id 1), not the Int32 child (type_id 0).
360+
#[test]
361+
fn test_duplicate_field_names() {
362+
let fields = UnionFields::try_new(
363+
[0, 1],
364+
[
365+
Field::new("val", DataType::Int32, true),
366+
Field::new("val", DataType::Utf8, true),
367+
],
368+
)
369+
.unwrap();
370+
371+
let target = DataType::Utf8;
372+
373+
let sparse = UnionArray::try_new(
374+
fields.clone(),
375+
vec![0_i8, 1, 0, 1].into(),
376+
None,
377+
vec![
378+
Arc::new(Int32Array::from(vec![Some(42), None, Some(99), None])) as ArrayRef,
379+
Arc::new(StringArray::from(vec![
380+
None,
381+
Some("hello"),
382+
None,
383+
Some("world"),
384+
])),
385+
],
386+
)
387+
.unwrap();
388+
389+
let result = cast::cast(&sparse, &target).unwrap();
390+
let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
391+
assert!(arr.is_null(0));
392+
assert_eq!(arr.value(1), "hello");
393+
assert!(arr.is_null(2));
394+
assert_eq!(arr.value(3), "world");
395+
396+
let dense = UnionArray::try_new(
397+
fields,
398+
vec![0_i8, 1, 1].into(),
399+
Some(vec![0_i32, 0, 1].into()),
400+
vec![
401+
Arc::new(Int32Array::from(vec![Some(42)])) as ArrayRef,
402+
Arc::new(StringArray::from(vec![Some("hello"), Some("world")])),
403+
],
404+
)
405+
.unwrap();
406+
407+
let result = cast::cast(&dense, &target).unwrap();
408+
let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
409+
assert!(arr.is_null(0));
410+
assert_eq!(arr.value(1), "hello");
411+
assert_eq!(arr.value(2), "world");
412+
}
413+
358414
// no matching child array, all three passes fail.
359415
// Union(Int32, Utf8) targeting Struct({x: Int32}). neither Int32 nor Utf8
360416
// can be cast to a Struct, so both can_cast_types and cast return errors.

arrow-select/src/union_extract.rs

Lines changed: 134 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,40 @@ pub fn union_extract(union_array: &UnionArray, target: &str) -> Result<ArrayRef,
8989
ArrowError::InvalidArgumentError(format!("field {target} not found on union"))
9090
})?;
9191

92+
union_extract_impl(union_array, fields, target_type_id)
93+
}
94+
95+
/// Like [`union_extract`], but selects the child by `type_id` rather than by
96+
/// field name.
97+
///
98+
/// This avoids ambiguity when the union contains duplicate field names.
99+
///
100+
/// # Errors
101+
///
102+
/// Returns error if `target_type_id` does not correspond to a field in the union.
103+
pub fn union_extract_by_id(
104+
union_array: &UnionArray,
105+
target_type_id: i8,
106+
) -> Result<ArrayRef, ArrowError> {
107+
let fields = match union_array.data_type() {
108+
DataType::Union(fields, _) => fields,
109+
_ => unreachable!(),
110+
};
111+
112+
if fields.iter().all(|(id, _)| id != target_type_id) {
113+
return Err(ArrowError::InvalidArgumentError(format!(
114+
"type_id {target_type_id} not found on union"
115+
)));
116+
}
117+
118+
union_extract_impl(union_array, fields, target_type_id)
119+
}
120+
121+
fn union_extract_impl(
122+
union_array: &UnionArray,
123+
fields: &UnionFields,
124+
target_type_id: i8,
125+
) -> Result<ArrayRef, ArrowError> {
92126
match union_array.offsets() {
93127
Some(_) => extract_dense(union_array, fields, target_type_id),
94128
None => extract_sparse(union_array, fields, target_type_id),
@@ -399,7 +433,9 @@ fn is_sequential_generic<const N: usize>(offsets: &[i32]) -> bool {
399433

400434
#[cfg(test)]
401435
mod tests {
402-
use super::{BoolValue, eq_scalar_inner, is_sequential_generic, union_extract};
436+
use super::{
437+
BoolValue, eq_scalar_inner, is_sequential_generic, union_extract, union_extract_by_id,
438+
};
403439
use arrow_array::{Array, Int32Array, NullArray, StringArray, UnionArray, new_null_array};
404440
use arrow_buffer::{BooleanBuffer, ScalarBuffer};
405441
use arrow_schema::{ArrowError, DataType, Field, UnionFields, UnionMode};
@@ -1236,4 +1272,101 @@ mod tests {
12361272
ArrowError::InvalidArgumentError("field a not found on union".into()).to_string()
12371273
);
12381274
}
1275+
1276+
#[test]
1277+
fn extract_by_id_sparse_duplicate_names() {
1278+
// Two fields with the same name "val" but different type_ids and types
1279+
let fields = UnionFields::try_new(
1280+
[0, 1],
1281+
[
1282+
Field::new("val", DataType::Int32, true),
1283+
Field::new("val", DataType::Utf8, true),
1284+
],
1285+
)
1286+
.unwrap();
1287+
1288+
let union = UnionArray::try_new(
1289+
fields,
1290+
vec![0_i8, 1, 0, 1].into(),
1291+
None,
1292+
vec![
1293+
Arc::new(Int32Array::from(vec![Some(42), None, Some(99), None])) as _,
1294+
Arc::new(StringArray::from(vec![None, Some("hello"), None, Some("world")])),
1295+
],
1296+
)
1297+
.unwrap();
1298+
1299+
// union_extract by name always returns type_id 0 (first match)
1300+
let by_name = union_extract(&union, "val").unwrap();
1301+
assert_eq!(*by_name, Int32Array::from(vec![Some(42), None, Some(99), None]));
1302+
1303+
// union_extract_by_id can select type_id 1 (the Utf8 child)
1304+
let by_id = union_extract_by_id(&union, 1).unwrap();
1305+
assert_eq!(
1306+
*by_id,
1307+
StringArray::from(vec![None, Some("hello"), None, Some("world")])
1308+
);
1309+
}
1310+
1311+
#[test]
1312+
fn extract_by_id_dense_duplicate_names() {
1313+
let fields = UnionFields::try_new(
1314+
[0, 1],
1315+
[
1316+
Field::new("val", DataType::Int32, true),
1317+
Field::new("val", DataType::Utf8, true),
1318+
],
1319+
)
1320+
.unwrap();
1321+
1322+
let union = UnionArray::try_new(
1323+
fields,
1324+
vec![0_i8, 1, 0].into(),
1325+
Some(vec![0_i32, 0, 1].into()),
1326+
vec![
1327+
Arc::new(Int32Array::from(vec![Some(42), Some(99)])) as _,
1328+
Arc::new(StringArray::from(vec![Some("hello")])),
1329+
],
1330+
)
1331+
.unwrap();
1332+
1333+
// by type_id 0 → Int32 child
1334+
let by_id_0 = union_extract_by_id(&union, 0).unwrap();
1335+
assert_eq!(*by_id_0, Int32Array::from(vec![Some(42), None, Some(99)]));
1336+
1337+
// by type_id 1 → Utf8 child
1338+
let by_id_1 = union_extract_by_id(&union, 1).unwrap();
1339+
assert_eq!(
1340+
*by_id_1,
1341+
StringArray::from(vec![None, Some("hello"), None])
1342+
);
1343+
}
1344+
1345+
#[test]
1346+
fn extract_by_id_not_found() {
1347+
let fields = UnionFields::try_new(
1348+
[0, 1],
1349+
[
1350+
Field::new("a", DataType::Int32, true),
1351+
Field::new("b", DataType::Utf8, true),
1352+
],
1353+
)
1354+
.unwrap();
1355+
1356+
let union = UnionArray::try_new(
1357+
fields,
1358+
vec![0_i8, 1].into(),
1359+
None,
1360+
vec![
1361+
Arc::new(Int32Array::from(vec![Some(1), None])) as _,
1362+
Arc::new(StringArray::from(vec![None, Some("x")])),
1363+
],
1364+
)
1365+
.unwrap();
1366+
1367+
assert_eq!(
1368+
union_extract_by_id(&union, 5).unwrap_err().to_string(),
1369+
ArrowError::InvalidArgumentError("type_id 5 not found on union".into()).to_string()
1370+
);
1371+
}
12391372
}

0 commit comments

Comments
 (0)