diff --git a/arrow-cast/src/cast/union.rs b/arrow-cast/src/cast/union.rs index 7681e04356c8..89929e3c8888 100644 --- a/arrow-cast/src/cast/union.rs +++ b/arrow-cast/src/cast/union.rs @@ -21,7 +21,7 @@ use crate::cast::can_cast_types; use crate::cast_with_options; use arrow_array::{Array, ArrayRef, UnionArray}; use arrow_schema::{ArrowError, DataType, FieldRef, UnionFields}; -use arrow_select::union_extract::union_extract; +use arrow_select::union_extract::union_extract_by_id; use super::CastOptions; @@ -64,7 +64,7 @@ fn same_type_family(a: &DataType, b: &DataType) -> bool { pub(crate) fn resolve_child_array<'a>( fields: &'a UnionFields, target_type: &DataType, -) -> Option<&'a FieldRef> { +) -> Option<(i8, &'a FieldRef)> { fields .iter() .find(|(_, f)| f.data_type() == target_type) @@ -84,7 +84,6 @@ pub(crate) fn resolve_child_array<'a>( .iter() .find(|(_, f)| can_cast_types(f.data_type(), target_type)) }) - .map(|(_, f)| f) } /// Extracts the best-matching child array from a [`UnionArray`] for a given target type, @@ -137,7 +136,7 @@ pub fn union_extract_by_type( _ => unreachable!("union_extract_by_type called on non-union array"), }; - let Some(field) = resolve_child_array(fields, target_type) else { + let Some((type_id, _)) = resolve_child_array(fields, target_type) else { return Err(ArrowError::CastError(format!( "cannot cast Union with fields {} to {}", fields @@ -149,7 +148,7 @@ pub fn union_extract_by_type( ))); }; - let extracted = union_extract(union_array, field.name())?; + let extracted = union_extract_by_id(union_array, type_id)?; if extracted.data_type() == target_type { return Ok(extracted); @@ -355,6 +354,63 @@ mod tests { assert!(!arr.value(2)); } + // duplicate field names: ensure we resolve by type_id, not field name. + // Union has two children both named "val" — Int32 (type_id 0) and Utf8 (type_id 1). + // Casting to Utf8 should select the Utf8 child (type_id 1), not the Int32 child (type_id 0). 
+ #[test] + fn test_duplicate_field_names() { + let fields = UnionFields::try_new( + [0, 1], + [ + Field::new("val", DataType::Int32, true), + Field::new("val", DataType::Utf8, true), + ], + ) + .unwrap(); + + let target = DataType::Utf8; + + let sparse = UnionArray::try_new( + fields.clone(), + vec![0_i8, 1, 0, 1].into(), + None, + vec![ + Arc::new(Int32Array::from(vec![Some(42), None, Some(99), None])) as ArrayRef, + Arc::new(StringArray::from(vec![ + None, + Some("hello"), + None, + Some("world"), + ])), + ], + ) + .unwrap(); + + let result = cast::cast(&sparse, &target).unwrap(); + let arr = result.as_any().downcast_ref::<StringArray>().unwrap(); + assert!(arr.is_null(0)); + assert_eq!(arr.value(1), "hello"); + assert!(arr.is_null(2)); + assert_eq!(arr.value(3), "world"); + + let dense = UnionArray::try_new( + fields, + vec![0_i8, 1, 1].into(), + Some(vec![0_i32, 0, 1].into()), + vec![ + Arc::new(Int32Array::from(vec![Some(42)])) as ArrayRef, + Arc::new(StringArray::from(vec![Some("hello"), Some("world")])), + ], + ) + .unwrap(); + + let result = cast::cast(&dense, &target).unwrap(); + let arr = result.as_any().downcast_ref::<StringArray>().unwrap(); + assert!(arr.is_null(0)); + assert_eq!(arr.value(1), "hello"); + assert_eq!(arr.value(2), "world"); + } + // no matching child array, all three passes fail. // Union(Int32, Utf8) targeting Struct({x: Int32}). neither Int32 nor Utf8 // can be cast to a Struct, so both can_cast_types and cast return errors. 
diff --git a/arrow-select/src/union_extract.rs b/arrow-select/src/union_extract.rs index 3accecc359fa..893b13554b1d 100644 --- a/arrow-select/src/union_extract.rs +++ b/arrow-select/src/union_extract.rs @@ -89,6 +89,40 @@ pub fn union_extract(union_array: &UnionArray, target: &str) -> Result<ArrayRef, […] -> Result<ArrayRef, ArrowError> { + let fields = match union_array.data_type() { + DataType::Union(fields, _) => fields, + _ => unreachable!(), + }; + + if fields.iter().all(|(id, _)| id != target_type_id) { + return Err(ArrowError::InvalidArgumentError(format!( + "type_id {target_type_id} not found on union" + ))); + } + + union_extract_impl(union_array, fields, target_type_id) +} + +fn union_extract_impl( + union_array: &UnionArray, + fields: &UnionFields, + target_type_id: i8, +) -> Result<ArrayRef, ArrowError> { match union_array.offsets() { Some(_) => extract_dense(union_array, fields, target_type_id), None => extract_sparse(union_array, fields, target_type_id), @@ -399,7 +433,9 @@ fn is_sequential_generic(offsets: &[i32]) -> bool { #[cfg(test)] mod tests { - use super::{BoolValue, eq_scalar_inner, is_sequential_generic, union_extract}; + use super::{ + BoolValue, eq_scalar_inner, is_sequential_generic, union_extract, union_extract_by_id, + }; use arrow_array::{Array, Int32Array, NullArray, StringArray, UnionArray, new_null_array}; use arrow_buffer::{BooleanBuffer, ScalarBuffer}; use arrow_schema::{ArrowError, DataType, Field, UnionFields, UnionMode}; @@ -1236,4 +1272,106 @@ mod tests { ArrowError::InvalidArgumentError("field a not found on union".into()).to_string() ); } + + #[test] + fn extract_by_id_sparse_duplicate_names() { + // Two fields with the same name "val" but different type_ids and types + let fields = UnionFields::try_new( + [0, 1], + [ + Field::new("val", DataType::Int32, true), + Field::new("val", DataType::Utf8, true), + ], + ) + .unwrap(); + + let union = UnionArray::try_new( + fields, + vec![0_i8, 1, 0, 1].into(), + None, + vec![ + Arc::new(Int32Array::from(vec![Some(42), None, Some(99), None])) as 
_, + Arc::new(StringArray::from(vec![ + None, + Some("hello"), + None, + Some("world"), + ])), + ], + ) + .unwrap(); + + // union_extract by name always returns type_id 0 (first match) + let by_name = union_extract(&union, "val").unwrap(); + assert_eq!( + *by_name, + Int32Array::from(vec![Some(42), None, Some(99), None]) + ); + + // union_extract_by_id can select type_id 1 (the Utf8 child) + let by_id = union_extract_by_id(&union, 1).unwrap(); + assert_eq!( + *by_id, + StringArray::from(vec![None, Some("hello"), None, Some("world")]) + ); + } + + #[test] + fn extract_by_id_dense_duplicate_names() { + let fields = UnionFields::try_new( + [0, 1], + [ + Field::new("val", DataType::Int32, true), + Field::new("val", DataType::Utf8, true), + ], + ) + .unwrap(); + + let union = UnionArray::try_new( + fields, + vec![0_i8, 1, 0].into(), + Some(vec![0_i32, 0, 1].into()), + vec![ + Arc::new(Int32Array::from(vec![Some(42), Some(99)])) as _, + Arc::new(StringArray::from(vec![Some("hello")])), + ], + ) + .unwrap(); + + // by type_id 0 → Int32 child + let by_id_0 = union_extract_by_id(&union, 0).unwrap(); + assert_eq!(*by_id_0, Int32Array::from(vec![Some(42), None, Some(99)])); + + // by type_id 1 → Utf8 child + let by_id_1 = union_extract_by_id(&union, 1).unwrap(); + assert_eq!(*by_id_1, StringArray::from(vec![None, Some("hello"), None])); + } + + #[test] + fn extract_by_id_not_found() { + let fields = UnionFields::try_new( + [0, 1], + [ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + ], + ) + .unwrap(); + + let union = UnionArray::try_new( + fields, + vec![0_i8, 1].into(), + None, + vec![ + Arc::new(Int32Array::from(vec![Some(1), None])) as _, + Arc::new(StringArray::from(vec![None, Some("x")])), + ], + ) + .unwrap(); + + assert_eq!( + union_extract_by_id(&union, 5).unwrap_err().to_string(), + ArrowError::InvalidArgumentError("type_id 5 not found on union".into()).to_string() + ); + } }