diff --git a/arrow-ord/src/ord.rs b/arrow-ord/src/ord.rs index 6e3025576c69..95566e95fea7 100644 --- a/arrow-ord/src/ord.rs +++ b/arrow-ord/src/ord.rs @@ -21,7 +21,7 @@ use arrow_array::cast::AsArray; use arrow_array::types::*; use arrow_array::*; use arrow_buffer::{ArrowNativeType, NullBuffer}; -use arrow_schema::{ArrowError, SortOptions}; +use arrow_schema::{ArrowError, DataType, SortOptions}; use std::cmp::Ordering; /// Compare the values at two arbitrary indices in two arrays. @@ -296,6 +296,72 @@ fn compare_struct( Ok(f) } +fn compare_union( + left: &dyn Array, + right: &dyn Array, + opts: SortOptions, +) -> Result { + let left = left.as_union(); + let right = right.as_union(); + + let (left_fields, left_mode) = match left.data_type() { + DataType::Union(fields, mode) => (fields, mode), + _ => unreachable!(), + }; + let (right_fields, right_mode) = match right.data_type() { + DataType::Union(fields, mode) => (fields, mode), + _ => unreachable!(), + }; + + if left_fields != right_fields || left_mode != right_mode { + return Err(ArrowError::InvalidArgumentError( + "Cannot compare UnionArrays with different fields or modes".to_string(), + )); + } + + let c_opts = child_opts(opts); + + let max_type_id = left_fields.iter().map(|(id, _)| id).max().unwrap_or(0); + let mut field_comparators: Vec> = + Vec::with_capacity((max_type_id + 1) as usize); + field_comparators.resize_with((max_type_id + 1) as usize, || None); + + for (type_id, _field) in left_fields.iter() { + let left_child = left.child(type_id); + let right_child = right.child(type_id); + let cmp = make_comparator(left_child.as_ref(), right_child.as_ref(), c_opts)?; + field_comparators[type_id as usize] = Some(cmp); + } + + let left_type_ids = left.type_ids().clone(); + let right_type_ids = right.type_ids().clone(); + + let left_offsets = left.offsets().cloned(); + let right_offsets = right.offsets().cloned(); + + let f = compare(left, right, opts, move |i, j| { + let left_type_id = left_type_ids[i]; + let right_type_id = right_type_ids[j]; + + // first, compare by type_id + match left_type_id.cmp(&right_type_id) { + Ordering::Equal => { + // second, compare by values + let left_offset = left_offsets.as_ref().map(|o| o[i] as usize).unwrap_or(i); + let right_offset = right_offsets.as_ref().map(|o| o[j] as usize).unwrap_or(j); + + let cmp = field_comparators[left_type_id as usize] + .as_ref() + .expect("type_id not found in field_comparators"); + + cmp(left_offset, right_offset) + } + other => other, + } + }); + Ok(f) +} + /// Returns a comparison function that compares two values at two different positions /// between the two arrays. /// @@ -412,6 +478,7 @@ pub fn make_comparator( } }, (Map(_, _), Map(_, _)) => compare_map(left, right, opts), + (Union(_, _), Union(_, _)) => compare_union(left, right, opts), (lhs, rhs) => Err(ArrowError::InvalidArgumentError(match lhs == rhs { true => format!("The data type type {lhs:?} has no natural order"), false => "Can't compare arrays of different types".to_string(), @@ -423,8 +490,8 @@ pub fn make_comparator( mod tests { use super::*; use arrow_array::builder::{Int32Builder, ListBuilder, MapBuilder, StringBuilder}; - use arrow_buffer::{IntervalDayTime, OffsetBuffer, i256}; - use arrow_schema::{DataType, Field, Fields}; + use arrow_buffer::{IntervalDayTime, OffsetBuffer, ScalarBuffer, i256}; + use arrow_schema::{DataType, Field, Fields, UnionFields}; use half::f16; use std::sync::Arc; @@ -1189,4 +1256,115 @@ mod tests { } } } + + #[test] + fn test_dense_union() { + // create a dense union array with Int32 (type_id = 0) and Utf8 (type_id=1) + // the values are: [1, "b", 2, "a", 3] + // type_ids are: [0, 1, 0, 1, 0] + // offsets are: [0, 0, 1, 1, 2] from [1, 2, 3] and ["b", "a"] + let int_array = Int32Array::from(vec![1, 2, 3]); + let str_array = StringArray::from(vec!["b", "a"]); + + let type_ids = [0, 1, 0, 1, 0].into_iter().collect::>(); + let offsets = [0, 0, 1, 1, 2].into_iter().collect::>(); + + let union_fields = [ + (0, Arc::new(Field::new("A", DataType::Int32, false))), + (1, Arc::new(Field::new("B", DataType::Utf8, false))), + ] + .into_iter() + .collect::(); + + let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)]; + + let array1 = + UnionArray::try_new(union_fields.clone(), type_ids, Some(offsets), children).unwrap(); + + // create a second array: [2, "a", 1, "c"] + // type ids are: [0, 1, 0, 1] + // offsets are: [0, 0, 1, 1] from [2, 1] and ["a", "c"] + let int_array2 = Int32Array::from(vec![2, 1]); + let str_array2 = StringArray::from(vec!["a", "c"]); + let type_ids2 = [0, 1, 0, 1].into_iter().collect::>(); + let offsets2 = [0, 0, 1, 1].into_iter().collect::>(); + + let children2 = vec![Arc::new(int_array2) as ArrayRef, Arc::new(str_array2)]; + + let array2 = + UnionArray::try_new(union_fields, type_ids2, Some(offsets2), children2).unwrap(); + + let opts = SortOptions { + descending: false, + nulls_first: true, + }; + + // comparing + // [1, "b", 2, "a", 3] + // [2, "a", 1, "c"] + let cmp = make_comparator(&array1, &array2, opts).unwrap(); + + // array1[0] = (type_id=0, value=1) + // array2[0] = (type_id=0, value=2) + assert_eq!(cmp(0, 0), Ordering::Less); // 1 < 2 + + // array1[0] = (type_id=0, value=1) + // array2[1] = (type_id=1, value="a") + assert_eq!(cmp(0, 1), Ordering::Less); // type_id 0 < 1 + + // array1[1] = (type_id=1, value="b") + // array2[1] = (type_id=1, value="a") + assert_eq!(cmp(1, 1), Ordering::Greater); // "b" > "a" + + // array1[2] = (type_id=0, value=2) + // array2[0] = (type_id=0, value=2) + assert_eq!(cmp(2, 0), Ordering::Equal); // 2 == 2 + + // array1[3] = (type_id=1, value="a") + // array2[1] = (type_id=1, value="a") + assert_eq!(cmp(3, 1), Ordering::Equal); // "a" == "a" + + // array1[1] = (type_id=1, value="b") + // array2[3] = (type_id=1, value="c") + assert_eq!(cmp(1, 3), Ordering::Less); // "b" < "c" + + let opts_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let cmp_desc = make_comparator(&array1, &array2, opts_desc).unwrap(); + + assert_eq!(cmp_desc(0, 0), Ordering::Greater); // 1 > 2 (reversed) + assert_eq!(cmp_desc(0, 1), Ordering::Greater); // type_id 0 < 1, reversed to Greater + assert_eq!(cmp_desc(1, 1), Ordering::Less); // "b" < "a" (reversed) + } + + #[test] + fn test_sparse_union() { + // create a sparse union array with Int32 (type_id=0) and Utf8 (type_id=1) + // values: [1, "b", 3] + // note, in sparse unions, child arrays have the same length as the union + let int_array = Int32Array::from(vec![Some(1), None, Some(3)]); + let str_array = StringArray::from(vec![None, Some("b"), None]); + let type_ids = [0, 1, 0].into_iter().collect::>(); + + let union_fields = [ + (0, Arc::new(Field::new("a", DataType::Int32, false))), + (1, Arc::new(Field::new("b", DataType::Utf8, false))), + ] + .into_iter() + .collect::(); + + let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)]; + + let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap(); + + let opts = SortOptions::default(); + let cmp = make_comparator(&array, &array, opts).unwrap(); + + // array[0] = (type_id=0, value=1), array[2] = (type_id=0, value=3) + assert_eq!(cmp(0, 2), Ordering::Less); // 1 < 3 + // array[0] = (type_id=0, value=1), array[1] = (type_id=1, value="b") + assert_eq!(cmp(0, 1), Ordering::Less); // type_id 0 < 1 + } }