From 1cf208502623d6ac647154a8677a0375ce961bea Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Mon, 23 Mar 2026 11:42:19 -0400 Subject: [PATCH 1/5] rework union casting to scalar --- arrow-cast/src/cast/mod.rs | 13 +- arrow-cast/src/cast/union.rs | 470 +++++++++++++++++++++++++++++++++++ 2 files changed, 482 insertions(+), 1 deletion(-) create mode 100644 arrow-cast/src/cast/union.rs diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index 9f1eba1057fd..b07b06ff2331 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -43,6 +43,7 @@ mod list; mod map; mod run_array; mod string; +mod union; use crate::cast::decimal::*; use crate::cast::dictionary::*; @@ -50,6 +51,7 @@ use crate::cast::list::*; use crate::cast::map::*; use crate::cast::run_array::*; use crate::cast::string::*; +pub use crate::cast::union::*; use arrow_buffer::IntervalMonthDayNano; use arrow_data::ByteView; @@ -230,7 +232,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { } (Struct(_), _) => false, (_, Struct(_)) => false, - + (Union(fields, _), _) => union::resolve_variant(fields, to_type).is_some(), + (_, Union(_, _)) => false, (_, Boolean) => from_type.is_integer() || from_type.is_floating() || from_type.is_string(), (Boolean, _) => to_type.is_integer() || to_type.is_floating() || to_type.is_string(), @@ -1180,6 +1183,14 @@ pub fn cast_with_options( (_, Struct(_)) => Err(ArrowError::CastError(format!( "Casting from {from_type} to {to_type} not supported" ))), + (Union(_, _), _) => union_extract_by_type( + array.as_any().downcast_ref::().unwrap(), + to_type, + cast_options, + ), + (_, Union(_, _)) => Err(ArrowError::CastError(format!( + "Casting from {from_type} to {to_type} not supported" + ))), (_, Boolean) => match from_type { UInt8 => cast_numeric_to_bool::(array), UInt16 => cast_numeric_to_bool::(array), diff --git a/arrow-cast/src/cast/union.rs b/arrow-cast/src/cast/union.rs new file mode 100644 index 000000000000..bd5a84e916dd --- /dev/null +++ b/arrow-cast/src/cast/union.rs @@ -0,0 +1,470 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Cast support for union arrays. + +use crate::cast::can_cast_types; +use crate::cast_with_options; +use arrow_array::{Array, ArrayRef, UnionArray, new_null_array}; +use arrow_schema::{ArrowError, DataType, FieldRef, UnionFields}; +use arrow_select::union_extract::union_extract; + +use super::CastOptions; + +// this is used during variant selection to prefer a "close" type over a distant cast +// for example: when targeting Utf8View, a Utf8 variant is preferred over Int32 despite both being castable +fn same_type_family(a: &DataType, b: &DataType) -> bool { + use DataType::*; + matches!( + (a, b), + (Utf8 | LargeUtf8 | Utf8View, Utf8 | LargeUtf8 | Utf8View) + | ( + Binary | LargeBinary | BinaryView, + Binary | LargeBinary | BinaryView + ) + | (Int8 | Int16 | Int32 | Int64, Int8 | Int16 | Int32 | Int64) + | ( + UInt8 | UInt16 | UInt32 | UInt64, + UInt8 | UInt16 | UInt32 | UInt64 + ) + | (Float16 | Float32 | Float64, Float16 | Float32 | Float64) + ) +} + +// variant selection heuristic — 3 passes with decreasing specificity: +// +// first pass: field type == target type +// second pass: field and target are in the same equivalence class +// (e.g., Utf8 and Utf8View are both strings) +// third pass: field can be cast to target +// note: this is the most permissive and may lose information +// also, the matching logic is greedy so it will pick the first 'castable' variant +// +// each pass picks the first matching variant by type_id order. +pub(crate) fn resolve_variant<'a>( + fields: &'a UnionFields, + target_type: &DataType, +) -> Option<&'a FieldRef> { + fields + .iter() + .find(|(_, f)| f.data_type() == target_type) + .or_else(|| { + fields + .iter() + .find(|(_, f)| same_type_family(f.data_type(), target_type)) + }) + .or_else(|| { + fields + .iter() + .find(|(_, f)| can_cast_types(f.data_type(), target_type)) + }) + .map(|(_, f)| f) +} + +/// Extracts the best-matching variant from a union array for a given target type, +/// and casts it to that type. +/// +/// Rows where a different variant is active become NULL. +/// If no variant matches, returns a null array. +/// +/// # Example +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_schema::{DataType, Field, UnionFields}; +/// # use arrow_array::{UnionArray, StringArray, Int32Array, Array}; +/// # use arrow_cast::cast::union_extract_by_type; +/// # use arrow_cast::CastOptions; +/// let fields = UnionFields::try_new( +/// [0, 1], +/// [ +/// Field::new("int", DataType::Int32, true), +/// Field::new("str", DataType::Utf8, true), +/// ], +/// ).unwrap(); +/// +/// let union = UnionArray::try_new( +/// fields, +/// vec![0, 1, 0].into(), +/// None, +/// vec![ +/// Arc::new(Int32Array::from(vec![Some(42), None, Some(99)])), +/// Arc::new(StringArray::from(vec![None, Some("hello"), None])), +/// ], +/// ) +/// .unwrap(); +/// +/// // extract the Utf8 variant and cast to Utf8View +/// let result = union_extract_by_type(&union, &DataType::Utf8View, &CastOptions::default()).unwrap(); +/// assert_eq!(result.data_type(), &DataType::Utf8View); +/// assert!(result.is_null(0)); // Int32 row -> NULL +/// assert!(!result.is_null(1)); // Utf8 row -> "hello" +/// assert!(result.is_null(2)); // Int32 row -> NULL +/// ``` +pub fn union_extract_by_type( + union_array: &UnionArray, + target_type: &DataType, + cast_options: &CastOptions, +) -> Result { + let fields = match union_array.data_type() { + DataType::Union(fields, _) => fields, + _ => unreachable!("union_extract_by_type called on non-union array"), + }; + + let Some(field) = resolve_variant(fields, target_type) else { + return Ok(new_null_array(target_type, union_array.len())); + }; + + let extracted = union_extract(union_array, field.name())?; + + if extracted.data_type() == target_type { + return Ok(extracted); + } + + cast_with_options(&extracted, target_type, cast_options) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::cast; + use arrow_array::*; + use arrow_schema::{Field, UnionFields, UnionMode}; + use std::sync::Arc; + + fn int_str_fields() -> UnionFields { + UnionFields::try_new( + [0, 1], + [ + Field::new("int", DataType::Int32, true), + Field::new("str", DataType::Utf8, true), + ], + ) + .unwrap() + } + + fn int_str_union_type(mode: UnionMode) -> DataType { + DataType::Union(int_str_fields(), mode) + } + + // pass 1: exact type match. + // Union(Int32, Utf8) targeting Utf8 — the Utf8 variant matches exactly. + // Int32 rows become NULL. tested for both sparse and dense. + #[test] + fn test_exact_type_match() { + let target = DataType::Utf8; + + // sparse + assert!(can_cast_types( + &int_str_union_type(UnionMode::Sparse), + &target + )); + + let sparse = UnionArray::try_new( + int_str_fields(), + vec![1_i8, 0, 1].into(), + None, + vec![ + Arc::new(Int32Array::from(vec![None, Some(42), None])) as ArrayRef, + Arc::new(StringArray::from(vec![Some("hello"), None, Some("world")])), + ], + ) + .unwrap(); + + let result = cast::cast(&sparse, &target).unwrap(); + assert_eq!(result.data_type(), &target); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), "hello"); + assert!(arr.is_null(1)); + assert_eq!(arr.value(2), "world"); + + // dense + assert!(can_cast_types( + &int_str_union_type(UnionMode::Dense), + &target + )); + + let dense = UnionArray::try_new( + int_str_fields(), + vec![1_i8, 0, 1].into(), + Some(vec![0_i32, 0, 1].into()), + vec![ + Arc::new(Int32Array::from(vec![Some(42)])) as ArrayRef, + Arc::new(StringArray::from(vec![Some("hello"), Some("world")])), + ], + ) + .unwrap(); + + let result = cast::cast(&dense, &target).unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), "hello"); + assert!(arr.is_null(1)); + assert_eq!(arr.value(2), "world"); + } + + // pass 2: same type family match. + // Union(Int32, Utf8) targeting Utf8View — no exact match, but Utf8 and Utf8View + // are in the same family. picks the Utf8 variant and casts to Utf8View. + // this is the bug that motivated this work: without pass 2, pass 3 would + // greedily pick Int32 (since can_cast_types(Int32, Utf8View) is true). + #[test] + fn test_same_family_utf8_to_utf8view() { + let target = DataType::Utf8View; + + // sparse + assert!(can_cast_types( + &int_str_union_type(UnionMode::Sparse), + &target + )); + + let sparse = UnionArray::try_new( + int_str_fields(), + vec![1_i8, 0, 1, 1].into(), + None, + vec![ + Arc::new(Int32Array::from(vec![None, Some(42), None, None])) as ArrayRef, + Arc::new(StringArray::from(vec![ + Some("agent_alpha"), + None, + Some("agent_beta"), + None, + ])), + ], + ) + .unwrap(); + + let result = cast::cast(&sparse, &target).unwrap(); + assert_eq!(result.data_type(), &target); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), "agent_alpha"); + assert!(arr.is_null(1)); + assert_eq!(arr.value(2), "agent_beta"); + assert!(arr.is_null(3)); + + // dense + assert!(can_cast_types( + &int_str_union_type(UnionMode::Dense), + &target + )); + + let dense = UnionArray::try_new( + int_str_fields(), + vec![1_i8, 0, 1].into(), + Some(vec![0_i32, 0, 1].into()), + vec![ + Arc::new(Int32Array::from(vec![Some(42)])) as ArrayRef, + Arc::new(StringArray::from(vec![Some("alpha"), Some("beta")])), + ], + ) + .unwrap(); + + let result = cast::cast(&dense, &target).unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), "alpha"); + assert!(arr.is_null(1)); + assert_eq!(arr.value(2), "beta"); + } + + // pass 3: one-directional cast across type families. + // Union(Int32, Utf8) targeting Boolean — no exact match, no family match. + // pass 3 picks Int32 (first variant where can_cast_types is true) and + // casts to Boolean (0 → false, nonzero → true). Utf8 rows become NULL. + #[test] + fn test_one_directional_cast() { + let target = DataType::Boolean; + + // sparse + assert!(can_cast_types( + &int_str_union_type(UnionMode::Sparse), + &target + )); + + let sparse = UnionArray::try_new( + int_str_fields(), + vec![0_i8, 1, 0].into(), + None, + vec![ + Arc::new(Int32Array::from(vec![Some(42), None, Some(0)])) as ArrayRef, + Arc::new(StringArray::from(vec![None, Some("hello"), None])), + ], + ) + .unwrap(); + + let result = cast::cast(&sparse, &target).unwrap(); + assert_eq!(result.data_type(), &target); + let arr = result.as_any().downcast_ref::().unwrap(); + assert!(arr.value(0)); + assert!(arr.is_null(1)); + assert!(!arr.value(2)); + + // dense + assert!(can_cast_types( + &int_str_union_type(UnionMode::Dense), + &target + )); + + let dense = UnionArray::try_new( + int_str_fields(), + vec![0_i8, 1, 0].into(), + Some(vec![0_i32, 0, 1].into()), + vec![ + Arc::new(Int32Array::from(vec![Some(42), Some(0)])) as ArrayRef, + Arc::new(StringArray::from(vec![Some("hello")])), + ], + ) + .unwrap(); + + let result = cast::cast(&dense, &target).unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert!(arr.value(0)); + assert!(arr.is_null(1)); + assert!(!arr.value(2)); + } + + // no matching variant — all three passes fail. + // Union(Int32, Utf8) targeting Struct({x: Int32}). neither Int32 nor Utf8 + // can be cast to a Struct, so can_cast_types returns false and + // union_extract_by_type returns an all-null array. + #[test] + fn test_no_match_returns_nulls() { + let target = DataType::Struct(vec![Field::new("x", DataType::Int32, true)].into()); + + assert!(!can_cast_types( + &int_str_union_type(UnionMode::Sparse), + &target + )); + + let union = UnionArray::try_new( + int_str_fields(), + vec![0_i8, 1].into(), + None, + vec![ + Arc::new(Int32Array::from(vec![Some(42), None])) as ArrayRef, + Arc::new(StringArray::from(vec![None, Some("hello")])), + ], + ) + .unwrap(); + + let result = union_extract_by_type(&union, &target, &CastOptions::default()).unwrap(); + assert_eq!(result.data_type(), &target); + assert_eq!(result.null_count(), 2); + } + + // priority: exact match (pass 1) wins over family match (pass 2). + // Union(Utf8, Utf8View) targeting Utf8View — both variants are in the string + // family, but Utf8View is an exact match. pass 1 should pick it, not Utf8. + #[test] + fn test_exact_match_preferred_over_family() { + let fields = UnionFields::try_new( + [0, 1], + [ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::Utf8View, true), + ], + ) + .unwrap(); + let target = DataType::Utf8View; + + assert!(can_cast_types( + &DataType::Union(fields.clone(), UnionMode::Sparse), + &target, + )); + + // [Utf8("from_a"), Utf8View("from_b"), Utf8("also_a")] + let union = UnionArray::try_new( + fields, + vec![0_i8, 1, 0].into(), + None, + vec![ + Arc::new(StringArray::from(vec![ + Some("from_a"), + None, + Some("also_a"), + ])) as ArrayRef, + Arc::new(StringViewArray::from(vec![None, Some("from_b"), None])), + ], + ) + .unwrap(); + + let result = cast::cast(&union, &target).unwrap(); + assert_eq!(result.data_type(), &target); + let arr = result.as_any().downcast_ref::().unwrap(); + + // pass 1 picks variant "b" (Utf8View), so variant "a" rows become NULL + assert!(arr.is_null(0)); + assert_eq!(arr.value(1), "from_b"); + assert!(arr.is_null(2)); + } + + // null values within the selected variant stay null. + // this is distinct from "wrong variant → NULL": here the correct variant + // is active but its value is null. + #[test] + fn test_null_in_selected_variant() { + let target = DataType::Utf8; + + assert!(can_cast_types( + &int_str_union_type(UnionMode::Sparse), + &target + )); + + // ["hello", NULL(str), "world"] + // all rows are the Utf8 variant, but row 1 has a null value + let union = UnionArray::try_new( + int_str_fields(), + vec![1_i8, 1, 1].into(), + None, + vec![ + Arc::new(Int32Array::from(vec![None, None, None])) as ArrayRef, + Arc::new(StringArray::from(vec![Some("hello"), None, Some("world")])), + ], + ) + .unwrap(); + + let result = cast::cast(&union, &target).unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), "hello"); + assert!(arr.is_null(1)); + assert_eq!(arr.value(2), "world"); + } + + // empty union array returns a zero-length result of the target type. + #[test] + fn test_empty_union() { + let target = DataType::Utf8View; + + assert!(can_cast_types( + &int_str_union_type(UnionMode::Sparse), + &target + )); + + let union = UnionArray::try_new( + int_str_fields(), + Vec::::new().into(), + None, + vec![ + Arc::new(Int32Array::from(Vec::>::new())) as ArrayRef, + Arc::new(StringArray::from(Vec::>::new())), + ], + ) + .unwrap(); + + let result = cast::cast(&union, &target).unwrap(); + assert_eq!(result.data_type(), &target); + assert_eq!(result.len(), 0); + } +} From b47a20adf8280cbe5beb02aceab8c8e836b69357 Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Mon, 23 Mar 2026 12:24:01 -0400 Subject: [PATCH 2/5] fix up a test behavior --- arrow-cast/src/cast/mod.rs | 20 +++++++++--------- arrow-cast/src/cast/union.rs | 41 +++++++++++++++++++++++++++++------- 2 files changed, 43 insertions(+), 18 deletions(-) diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index b07b06ff2331..a77b958ce2d3 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -110,6 +110,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { can_cast_types(from_value_type, to_value_type) } (Dictionary(_, value_type), _) => can_cast_types(value_type, to_type), + (Union(fields, _), _) => union::resolve_variant(fields, to_type).is_some(), + (_, Union(_, _)) => false, (RunEndEncoded(_, value_type), _) => can_cast_types(value_type.data_type(), to_type), (_, RunEndEncoded(_, value_type)) => can_cast_types(from_type, value_type.data_type()), (_, Dictionary(_, value_type)) => can_cast_types(from_type, value_type), @@ -232,8 +234,6 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { } (Struct(_), _) => false, (_, Struct(_)) => false, - (Union(fields, _), _) => union::resolve_variant(fields, to_type).is_some(), - (_, Union(_, _)) => false, (_, Boolean) => from_type.is_integer() || from_type.is_floating() || from_type.is_string(), (Boolean, _) => to_type.is_integer() || to_type.is_floating() || to_type.is_string(), @@ -784,6 +784,14 @@ pub fn cast_with_options( ))), } } + (Union(_, _), _) => union_extract_by_type( + array.as_any().downcast_ref::().unwrap(), + to_type, + cast_options, + ), + (_, Union(_, _)) => Err(ArrowError::CastError(format!( + "Casting from {from_type} to {to_type} not supported" + ))), (Dictionary(index_type, _), _) => match **index_type { Int8 => dictionary_cast::(array, to_type, cast_options), Int16 => dictionary_cast::(array, to_type, cast_options), @@ -1183,14 +1191,6 @@ pub fn cast_with_options( (_, Struct(_)) => Err(ArrowError::CastError(format!( "Casting from {from_type} to {to_type} not supported" ))), - (Union(_, _), _) => union_extract_by_type( - array.as_any().downcast_ref::().unwrap(), - to_type, - cast_options, - ), - (_, Union(_, _)) => Err(ArrowError::CastError(format!( - "Casting from {from_type} to {to_type} not supported" - ))), (_, Boolean) => match from_type { UInt8 => cast_numeric_to_bool::(array), UInt16 => cast_numeric_to_bool::(array), diff --git a/arrow-cast/src/cast/union.rs b/arrow-cast/src/cast/union.rs index bd5a84e916dd..7946338ae439 100644 --- a/arrow-cast/src/cast/union.rs +++ b/arrow-cast/src/cast/union.rs @@ -19,7 +19,7 @@ use crate::cast::can_cast_types; use crate::cast_with_options; -use arrow_array::{Array, ArrayRef, UnionArray, new_null_array}; +use arrow_array::{Array, ArrayRef, UnionArray}; use arrow_schema::{ArrowError, DataType, FieldRef, UnionFields}; use arrow_select::union_extract::union_extract; @@ -45,6 +45,20 @@ fn same_type_family(a: &DataType, b: &DataType) -> bool { ) } +fn is_complex_container(dt: &DataType) -> bool { + use DataType::*; + matches!( + dt, + List(_) + | LargeList(_) + | ListView(_) + | LargeListView(_) + | FixedSizeList(_, _) + | Struct(_) + | Map(_, _) + ) +} + // variant selection heuristic — 3 passes with decreasing specificity: // // first pass: field type == target type @@ -68,6 +82,12 @@ pub(crate) fn resolve_variant<'a>( .find(|(_, f)| same_type_family(f.data_type(), target_type)) }) .or_else(|| { + // skip complex container types in pass 3 — union extraction introduces nulls, + // and casting nullable arrays to containers like List/Struct/Map can fail when + // inner fields are non-nullable. + if is_complex_container(target_type) { + return None; + } fields .iter() .find(|(_, f)| can_cast_types(f.data_type(), target_type)) @@ -126,7 +146,15 @@ pub fn union_extract_by_type( }; let Some(field) = resolve_variant(fields, target_type) else { - return Ok(new_null_array(target_type, union_array.len())); + return Err(ArrowError::CastError(format!( + "cannot cast Union with fields {} to {}", + fields + .iter() + .map(|(_, f)| f.data_type().to_string()) + .collect::>() + .join(", "), + target_type + ))); }; let extracted = union_extract(union_array, field.name())?; @@ -337,10 +365,9 @@ mod tests { // no matching variant — all three passes fail. // Union(Int32, Utf8) targeting Struct({x: Int32}). neither Int32 nor Utf8 - // can be cast to a Struct, so can_cast_types returns false and - // union_extract_by_type returns an all-null array. + // can be cast to a Struct, so both can_cast_types and cast return errors. #[test] - fn test_no_match_returns_nulls() { + fn test_no_match_errors() { let target = DataType::Struct(vec![Field::new("x", DataType::Int32, true)].into()); assert!(!can_cast_types( @@ -359,9 +386,7 @@ mod tests { ) .unwrap(); - let result = union_extract_by_type(&union, &target, &CastOptions::default()).unwrap(); - assert_eq!(result.data_type(), &target); - assert_eq!(result.null_count(), 2); + assert!(cast::cast(&union, &target).is_err()); } // priority: exact match (pass 1) wins over family match (pass 2). From 1f2aad70fa98ebcf529c5ea9e359e3882c1829a5 Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Mon, 6 Apr 2026 09:19:14 -0400 Subject: [PATCH 3/5] use is_nested --- arrow-cast/src/cast/union.rs | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/arrow-cast/src/cast/union.rs b/arrow-cast/src/cast/union.rs index 7946338ae439..fb8ab8c8abca 100644 --- a/arrow-cast/src/cast/union.rs +++ b/arrow-cast/src/cast/union.rs @@ -45,20 +45,6 @@ fn same_type_family(a: &DataType, b: &DataType) -> bool { ) } -fn is_complex_container(dt: &DataType) -> bool { - use DataType::*; - matches!( - dt, - List(_) - | LargeList(_) - | ListView(_) - | LargeListView(_) - | FixedSizeList(_, _) - | Struct(_) - | Map(_, _) - ) -} - // variant selection heuristic — 3 passes with decreasing specificity: // // first pass: field type == target type @@ -82,10 +68,10 @@ pub(crate) fn resolve_variant<'a>( .find(|(_, f)| same_type_family(f.data_type(), target_type)) }) .or_else(|| { - // skip complex container types in pass 3 — union extraction introduces nulls, - // and casting nullable arrays to containers like List/Struct/Map can fail when - // inner fields are non-nullable. - if is_complex_container(target_type) { + // skip nested types in pass 3 — union extraction introduces nulls, + // and casting nullable arrays to nested types like List/Struct/Map can fail + // when inner fields are non-nullable. + if target_type.is_nested() { return None; } fields From cb9ef0751fd51021dc187f0523cfcb826066375d Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Mon, 6 Apr 2026 09:23:13 -0400 Subject: [PATCH 4/5] explain heuristic --- arrow-cast/src/cast/union.rs | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/arrow-cast/src/cast/union.rs b/arrow-cast/src/cast/union.rs index fb8ab8c8abca..03e5611c67f4 100644 --- a/arrow-cast/src/cast/union.rs +++ b/arrow-cast/src/cast/union.rs @@ -45,16 +45,22 @@ fn same_type_family(a: &DataType, b: &DataType) -> bool { ) } -// variant selection heuristic — 3 passes with decreasing specificity: -// -// first pass: field type == target type -// second pass: field and target are in the same equivalence class -// (e.g., Utf8 and Utf8View are both strings) -// third pass: field can be cast to target -// note: this is the most permissive and may lose information -// also, the matching logic is greedy so it will pick the first 'castable' variant -// -// each pass picks the first matching variant by type_id order. +/// Selects the best-matching child array from a [`UnionArray`] for a given target type +/// +/// The goal is to find the source field whose type is closest to the target, +/// so that the subsequent cast is as lossless as possible. The heuristic uses +/// three passes with decreasing specificity: +/// +/// 1. **Exact match**: field type equals the target type. +/// 2. **Same type family**: field and target belong to the same logical family +/// (e.g. `Utf8` and `Utf8View` are both strings). This avoids a greedy +/// cross-family cast in pass 3 (e.g. picking `Int32` over `Utf8` when the +/// target is `Utf8View`, since `can_cast_types(Int32, Utf8View)` is true) +/// 3. **Castable**:`can_cast_types` reports the field can be cast to the target +/// Nested target types are skipped here because union extraction introduces +/// nulls, which can conflict with non-nullable inner fields +/// +/// Each pass greedily picks the first matching field by type_id order pub(crate) fn resolve_variant<'a>( fields: &'a UnionFields, target_type: &DataType, @@ -81,7 +87,7 @@ pub(crate) fn resolve_variant<'a>( .map(|(_, f)| f) } -/// Extracts the best-matching variant from a union array for a given target type, +/// Extracts the best-matching child array from a [`UnionArray`] for a given target type, /// and casts it to that type. /// /// Rows where a different variant is active become NULL. From ecdefa8d5d568a6bb774935ab3cf07c8905cc51a Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Mon, 6 Apr 2026 09:32:21 -0400 Subject: [PATCH 5/5] rename variant -> child array --- arrow-cast/src/cast/mod.rs | 2 +- arrow-cast/src/cast/union.rs | 38 ++++++++++++++++++------------------ 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index a77b958ce2d3..a584c39fa635 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -110,7 +110,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { can_cast_types(from_value_type, to_value_type) } (Dictionary(_, value_type), _) => can_cast_types(value_type, to_type), - (Union(fields, _), _) => union::resolve_variant(fields, to_type).is_some(), + (Union(fields, _), _) => union::resolve_child_array(fields, to_type).is_some(), (_, Union(_, _)) => false, (RunEndEncoded(_, value_type), _) => can_cast_types(value_type.data_type(), to_type), (_, RunEndEncoded(_, value_type)) => can_cast_types(from_type, value_type.data_type()), diff --git a/arrow-cast/src/cast/union.rs b/arrow-cast/src/cast/union.rs index 03e5611c67f4..7681e04356c8 100644 --- a/arrow-cast/src/cast/union.rs +++ b/arrow-cast/src/cast/union.rs @@ -25,8 +25,8 @@ use arrow_select::union_extract::union_extract; use super::CastOptions; -// this is used during variant selection to prefer a "close" type over a distant cast -// for example: when targeting Utf8View, a Utf8 variant is preferred over Int32 despite both being castable +// this is used during child array selection to prefer a "close" type over a distant cast +// for example: when targeting Utf8View, a Utf8 child is preferred over Int32 despite both being castable fn same_type_family(a: &DataType, b: &DataType) -> bool { use DataType::*; matches!( @@ -61,7 +61,7 @@ fn same_type_family(a: &DataType, b: &DataType) -> bool { /// nulls, which can conflict with non-nullable inner fields /// /// Each pass greedily picks the first matching field by type_id order -pub(crate) fn resolve_variant<'a>( +pub(crate) fn resolve_child_array<'a>( fields: &'a UnionFields, target_type: &DataType, ) -> Option<&'a FieldRef> { @@ -74,7 +74,7 @@ pub(crate) fn resolve_variant<'a>( .find(|(_, f)| same_type_family(f.data_type(), target_type)) }) .or_else(|| { - // skip nested types in pass 3 — union extraction introduces nulls, + // skip nested types in pass 3 because union extraction introduces nulls, // and casting nullable arrays to nested types like List/Struct/Map can fail // when inner fields are non-nullable. if target_type.is_nested() { @@ -90,8 +90,8 @@ pub(crate) fn resolve_variant<'a>( /// Extracts the best-matching child array from a [`UnionArray`] for a given target type, /// and casts it to that type. /// -/// Rows where a different variant is active become NULL. -/// If no variant matches, returns a null array. +/// Rows where a different child array is active become NULL. +/// If no child array matches, returns an error. /// /// # Example /// @@ -120,7 +120,7 @@ pub(crate) fn resolve_variant<'a>( /// ) /// .unwrap(); /// -/// // extract the Utf8 variant and cast to Utf8View +/// // extract the Utf8 child array and cast to Utf8View /// let result = union_extract_by_type(&union, &DataType::Utf8View, &CastOptions::default()).unwrap(); /// assert_eq!(result.data_type(), &DataType::Utf8View); /// assert!(result.is_null(0)); // Int32 row -> NULL @@ -137,7 +137,7 @@ pub fn union_extract_by_type( _ => unreachable!("union_extract_by_type called on non-union array"), }; - let Some(field) = resolve_variant(fields, target_type) else { + let Some(field) = resolve_child_array(fields, target_type) else { return Err(ArrowError::CastError(format!( "cannot cast Union with fields {} to {}", fields @@ -182,7 +182,7 @@ mod tests { } // pass 1: exact type match. - // Union(Int32, Utf8) targeting Utf8 — the Utf8 variant matches exactly. + // Union(Int32, Utf8) targeting Utf8. The Utf8 child matches exactly. // Int32 rows become NULL. tested for both sparse and dense. #[test] fn test_exact_type_match() { @@ -237,8 +237,8 @@ mod tests { } // pass 2: same type family match. - // Union(Int32, Utf8) targeting Utf8View — no exact match, but Utf8 and Utf8View - // are in the same family. picks the Utf8 variant and casts to Utf8View. + // Union(Int32, Utf8) targeting Utf8View. No exact match, but Utf8 and Utf8View + // are in the same family. picks the Utf8 child array and casts to Utf8View. // this is the bug that motivated this work: without pass 2, pass 3 would // greedily pick Int32 (since can_cast_types(Int32, Utf8View) is true). #[test] @@ -301,7 +301,7 @@ mod tests { // pass 3: one-directional cast across type families. // Union(Int32, Utf8) targeting Boolean — no exact match, no family match. - // pass 3 picks Int32 (first variant where can_cast_types is true) and + // pass 3 picks Int32 (first child array where can_cast_types is true) and // casts to Boolean (0 → false, nonzero → true). Utf8 rows become NULL. #[test] fn test_one_directional_cast() { @@ -355,7 +355,7 @@ mod tests { assert!(!arr.value(2)); } - // no matching variant — all three passes fail. + // no matching child array, all three passes fail. // Union(Int32, Utf8) targeting Struct({x: Int32}). neither Int32 nor Utf8 // can be cast to a Struct, so both can_cast_types and cast return errors. #[test] @@ -382,7 +382,7 @@ mod tests { } // priority: exact match (pass 1) wins over family match (pass 2). - // Union(Utf8, Utf8View) targeting Utf8View — both variants are in the string + // Union(Utf8, Utf8View) targeting Utf8View. Both child arrays are in the string // family, but Utf8View is an exact match. pass 1 should pick it, not Utf8. #[test] fn test_exact_match_preferred_over_family() { @@ -421,17 +421,17 @@ mod tests { assert_eq!(result.data_type(), &target); let arr = result.as_any().downcast_ref::().unwrap(); - // pass 1 picks variant "b" (Utf8View), so variant "a" rows become NULL + // pass 1 picks child "b" (Utf8View), so child "a" rows become NULL assert!(arr.is_null(0)); assert_eq!(arr.value(1), "from_b"); assert!(arr.is_null(2)); } - // null values within the selected variant stay null. - // this is distinct from "wrong variant → NULL": here the correct variant + // null values within the selected child array stay null. + // this is distinct from "wrong child array -> NULL": here the correct child array // is active but its value is null. #[test] - fn test_null_in_selected_variant() { + fn test_null_in_selected_child_array() { let target = DataType::Utf8; assert!(can_cast_types( @@ -440,7 +440,7 @@ mod tests { )); // ["hello", NULL(str), "world"] - // all rows are the Utf8 variant, but row 1 has a null value + // all rows are the Utf8 child array, but row 1 has a null value let union = UnionArray::try_new( int_str_fields(), vec![1_i8, 1, 1].into(),