Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 61 additions & 5 deletions arrow-cast/src/cast/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::cast::can_cast_types;
use crate::cast_with_options;
use arrow_array::{Array, ArrayRef, UnionArray};
use arrow_schema::{ArrowError, DataType, FieldRef, UnionFields};
use arrow_select::union_extract::union_extract;
use arrow_select::union_extract::union_extract_by_id;

use super::CastOptions;

Expand Down Expand Up @@ -64,7 +64,7 @@ fn same_type_family(a: &DataType, b: &DataType) -> bool {
pub(crate) fn resolve_child_array<'a>(
fields: &'a UnionFields,
target_type: &DataType,
) -> Option<&'a FieldRef> {
) -> Option<(i8, &'a FieldRef)> {
fields
.iter()
.find(|(_, f)| f.data_type() == target_type)
Expand All @@ -84,7 +84,6 @@ pub(crate) fn resolve_child_array<'a>(
.iter()
.find(|(_, f)| can_cast_types(f.data_type(), target_type))
})
.map(|(_, f)| f)
}

/// Extracts the best-matching child array from a [`UnionArray`] for a given target type,
Expand Down Expand Up @@ -137,7 +136,7 @@ pub fn union_extract_by_type(
_ => unreachable!("union_extract_by_type called on non-union array"),
};

let Some(field) = resolve_child_array(fields, target_type) else {
let Some((type_id, _)) = resolve_child_array(fields, target_type) else {
return Err(ArrowError::CastError(format!(
"cannot cast Union with fields {} to {}",
fields
Expand All @@ -149,7 +148,7 @@ pub fn union_extract_by_type(
)));
};

let extracted = union_extract(union_array, field.name())?;
let extracted = union_extract_by_id(union_array, type_id)?;

if extracted.data_type() == target_type {
return Ok(extracted);
Expand Down Expand Up @@ -355,6 +354,63 @@ mod tests {
assert!(!arr.value(2));
}

// Regression test for unions whose children share a field name: the cast has to
// resolve the child through its type_id, never through the (ambiguous) name.
// Both children are called "val" — Int32 sits at type_id 0 and Utf8 at type_id 1.
// A cast to Utf8 must therefore pick the Utf8 child (type_id 1), not the Int32 child (type_id 0).
#[test]
fn test_duplicate_field_names() {
    let children = [
        Field::new("val", DataType::Int32, true),
        Field::new("val", DataType::Utf8, true),
    ];
    let fields = UnionFields::try_new([0, 1], children).unwrap();
    let target = DataType::Utf8;

    // Sparse layout: no offsets buffer, each child has one slot per row.
    let sparse_ints: ArrayRef = Arc::new(Int32Array::from(vec![Some(42), None, Some(99), None]));
    let sparse_strs: ArrayRef = Arc::new(StringArray::from(vec![
        None,
        Some("hello"),
        None,
        Some("world"),
    ]));
    let sparse = UnionArray::try_new(
        fields.clone(),
        vec![0_i8, 1, 0, 1].into(),
        None,
        vec![sparse_ints, sparse_strs],
    )
    .unwrap();

    let sparse_cast = cast::cast(&sparse, &target).unwrap();
    let sparse_strings = sparse_cast
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    // Rows tagged with type_id 0 (the Int32 child) come back as null.
    assert!(sparse_strings.is_null(0));
    assert_eq!(sparse_strings.value(1), "hello");
    assert!(sparse_strings.is_null(2));
    assert_eq!(sparse_strings.value(3), "world");

    // Dense layout: an offsets buffer points each row into its own child.
    let dense_ints: ArrayRef = Arc::new(Int32Array::from(vec![Some(42)]));
    let dense_strs: ArrayRef = Arc::new(StringArray::from(vec![Some("hello"), Some("world")]));
    let dense = UnionArray::try_new(
        fields,
        vec![0_i8, 1, 1].into(),
        Some(vec![0_i32, 0, 1].into()),
        vec![dense_ints, dense_strs],
    )
    .unwrap();

    let dense_cast = cast::cast(&dense, &target).unwrap();
    let dense_strings = dense_cast
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    assert!(dense_strings.is_null(0));
    assert_eq!(dense_strings.value(1), "hello");
    assert_eq!(dense_strings.value(2), "world");
}

// no matching child array, all three passes fail.
// Union(Int32, Utf8) targeting Struct({x: Int32}). neither Int32 nor Utf8
// can be cast to a Struct, so both can_cast_types and cast return errors.
Expand Down
140 changes: 139 additions & 1 deletion arrow-select/src/union_extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,40 @@ pub fn union_extract(union_array: &UnionArray, target: &str) -> Result<ArrayRef,
ArrowError::InvalidArgumentError(format!("field {target} not found on union"))
})?;

union_extract_impl(union_array, fields, target_type_id)
}

/// Extracts a child of `union_array`, choosing it by `type_id` instead of by
/// field name.
///
/// This is the counterpart of [`union_extract`] for unions in which several
/// fields share a name: a `type_id` identifies the child without ambiguity.
///
/// # Errors
///
/// Returns an [`ArrowError::InvalidArgumentError`] if no field of the union is
/// registered under `target_type_id`.
pub fn union_extract_by_id(
    union_array: &UnionArray,
    target_type_id: i8,
) -> Result<ArrayRef, ArrowError> {
    let DataType::Union(fields, _) = union_array.data_type() else {
        unreachable!()
    };

    // Reject ids that the union does not declare before doing any extraction work.
    let id_is_declared = fields.iter().any(|(id, _)| id == target_type_id);
    if !id_is_declared {
        return Err(ArrowError::InvalidArgumentError(format!(
            "type_id {target_type_id} not found on union"
        )));
    }

    union_extract_impl(union_array, fields, target_type_id)
}

fn union_extract_impl(
union_array: &UnionArray,
fields: &UnionFields,
target_type_id: i8,
) -> Result<ArrayRef, ArrowError> {
match union_array.offsets() {
Some(_) => extract_dense(union_array, fields, target_type_id),
None => extract_sparse(union_array, fields, target_type_id),
Expand Down Expand Up @@ -399,7 +433,9 @@ fn is_sequential_generic<const N: usize>(offsets: &[i32]) -> bool {

#[cfg(test)]
mod tests {
use super::{BoolValue, eq_scalar_inner, is_sequential_generic, union_extract};
use super::{
BoolValue, eq_scalar_inner, is_sequential_generic, union_extract, union_extract_by_id,
};
use arrow_array::{Array, Int32Array, NullArray, StringArray, UnionArray, new_null_array};
use arrow_buffer::{BooleanBuffer, ScalarBuffer};
use arrow_schema::{ArrowError, DataType, Field, UnionFields, UnionMode};
Expand Down Expand Up @@ -1236,4 +1272,106 @@ mod tests {
ArrowError::InvalidArgumentError("field a not found on union".into()).to_string()
);
}

#[test]
fn extract_by_id_sparse_duplicate_names() {
    // Contents of the two children; both are registered under the name "val"
    // but carry different type_ids and data types.
    let int_values: Vec<Option<i32>> = vec![Some(42), None, Some(99), None];
    let str_values: Vec<Option<&str>> = vec![None, Some("hello"), None, Some("world")];

    let same_name_fields = UnionFields::try_new(
        [0, 1],
        [
            Field::new("val", DataType::Int32, true),
            Field::new("val", DataType::Utf8, true),
        ],
    )
    .unwrap();

    // No offsets buffer => sparse union.
    let union = UnionArray::try_new(
        same_name_fields,
        vec![0_i8, 1, 0, 1].into(),
        None,
        vec![
            Arc::new(Int32Array::from(int_values.clone())) as _,
            Arc::new(StringArray::from(str_values.clone())),
        ],
    )
    .unwrap();

    // Lookup by name stops at the first match, i.e. the child with type_id 0.
    let by_name = union_extract(&union, "val").unwrap();
    assert_eq!(*by_name, Int32Array::from(int_values));

    // Lookup by id is able to reach type_id 1, the Utf8 child.
    let by_id = union_extract_by_id(&union, 1).unwrap();
    assert_eq!(*by_id, StringArray::from(str_values));
}

#[test]
fn extract_by_id_dense_duplicate_names() {
    // Two children that share the name "val" and differ only in type_id / data type.
    let same_name_fields = UnionFields::try_new(
        [0, 1],
        [
            Field::new("val", DataType::Int32, true),
            Field::new("val", DataType::Utf8, true),
        ],
    )
    .unwrap();

    // An offsets buffer is supplied => dense union.
    let type_ids = vec![0_i8, 1, 0];
    let offsets = vec![0_i32, 0, 1];
    let union = UnionArray::try_new(
        same_name_fields,
        type_ids.into(),
        Some(offsets.into()),
        vec![
            Arc::new(Int32Array::from(vec![Some(42), Some(99)])) as _,
            Arc::new(StringArray::from(vec![Some("hello")])),
        ],
    )
    .unwrap();

    // type_id 0 selects the Int32 child; the row tagged with type_id 1 is null.
    let expected_ints = Int32Array::from(vec![Some(42), None, Some(99)]);
    let ints = union_extract_by_id(&union, 0).unwrap();
    assert_eq!(*ints, expected_ints);

    // type_id 1 selects the Utf8 child; the rows tagged with type_id 0 are null.
    let expected_strs = StringArray::from(vec![None, Some("hello"), None]);
    let strs = union_extract_by_id(&union, 1).unwrap();
    assert_eq!(*strs, expected_strs);
}

#[test]
fn extract_by_id_not_found() {
    let fields = UnionFields::try_new(
        [0, 1],
        [
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Utf8, true),
        ],
    )
    .unwrap();

    let int_child = Arc::new(Int32Array::from(vec![Some(1), None])) as _;
    let str_child = Arc::new(StringArray::from(vec![None, Some("x")]));
    let union = UnionArray::try_new(
        fields,
        vec![0_i8, 1].into(),
        None,
        vec![int_child, str_child],
    )
    .unwrap();

    // The union only declares type_ids 0 and 1, so asking for 5 must fail.
    let actual = union_extract_by_id(&union, 5).unwrap_err();
    let expected = ArrowError::InvalidArgumentError("type_id 5 not found on union".into());
    assert_eq!(actual.to_string(), expected.to_string());
}
}
Loading