Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 181 additions & 3 deletions arrow-ord/src/ord.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use arrow_array::cast::AsArray;
use arrow_array::types::*;
use arrow_array::*;
use arrow_buffer::{ArrowNativeType, NullBuffer};
use arrow_schema::{ArrowError, SortOptions};
use arrow_schema::{ArrowError, DataType, SortOptions};
use std::cmp::Ordering;

/// Compare the values at two arbitrary indices in two arrays.
Expand Down Expand Up @@ -296,6 +296,72 @@ fn compare_struct(
Ok(f)
}

fn compare_union(
left: &dyn Array,
right: &dyn Array,
opts: SortOptions,
) -> Result<DynComparator, ArrowError> {
let left = left.as_union();
let right = right.as_union();

let (left_fields, left_mode) = match left.data_type() {
DataType::Union(fields, mode) => (fields, mode),
_ => unreachable!(),
};
let (right_fields, right_mode) = match right.data_type() {
DataType::Union(fields, mode) => (fields, mode),
_ => unreachable!(),
};

if left_fields != right_fields || left_mode != right_mode {
return Err(ArrowError::InvalidArgumentError(
"Cannot compare UnionArrays with different fields or modes".to_string(),
));
}

let c_opts = child_opts(opts);

let max_type_id = left_fields.iter().map(|(id, _)| id).max().unwrap_or(0);
let mut field_comparators: Vec<Option<DynComparator>> =
Vec::with_capacity((max_type_id + 1) as usize);
field_comparators.resize_with((max_type_id + 1) as usize, || None);

for (type_id, _field) in left_fields.iter() {
let left_child = left.child(type_id);
let right_child = right.child(type_id);
let cmp = make_comparator(left_child.as_ref(), right_child.as_ref(), c_opts)?;
field_comparators[type_id as usize] = Some(cmp);
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indexing field_comparators with type_id as usize assumes non-negative, dense type IDs and can panic or cause huge allocations if a union uses negative or sparse type IDs (Arrow allows arbitrary i8 type IDs). Consider avoiding direct index-based lookup by type ID here (also applies to the later access in the comparator closure).

🤖 Was this useful? React with 👍 or 👎

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

value:useful; category:bug; feedback:The Augment AI reviewer is correct! If the fields ids are something like "[1, 100]" then it will use 100 (the max) and create a Vector of 100 items, while it really needs just two. Using a HashMap would be better

}

let left_type_ids = left.type_ids().clone();
let right_type_ids = right.type_ids().clone();

let left_offsets = left.offsets().cloned();
let right_offsets = right.offsets().cloned();

let f = compare(left, right, opts, move |i, j| {
let left_type_id = left_type_ids[i];
let right_type_id = right_type_ids[j];

// first, compare by type_id
match left_type_id.cmp(&right_type_id) {
Ordering::Equal => {
// second, compare by values
let left_offset = left_offsets.as_ref().map(|o| o[i] as usize).unwrap_or(i);
let right_offset = right_offsets.as_ref().map(|o| o[j] as usize).unwrap_or(j);

let cmp = field_comparators[left_type_id as usize]
.as_ref()
.expect("type_id not found in field_comparators");

cmp(left_offset, right_offset)
}
other => other,
}
});
Ok(f)
}

/// Returns a comparison function that compares two values at two different positions
/// between the two arrays.
///
Expand Down Expand Up @@ -412,6 +478,7 @@ pub fn make_comparator(
}
},
(Map(_, _), Map(_, _)) => compare_map(left, right, opts),
(Union(_, _), Union(_, _)) => compare_union(left, right, opts),
(lhs, rhs) => Err(ArrowError::InvalidArgumentError(match lhs == rhs {
true => format!("The data type type {lhs:?} has no natural order"),
false => "Can't compare arrays of different types".to_string(),
Expand All @@ -423,8 +490,8 @@ pub fn make_comparator(
mod tests {
use super::*;
use arrow_array::builder::{Int32Builder, ListBuilder, MapBuilder, StringBuilder};
use arrow_buffer::{IntervalDayTime, OffsetBuffer, i256};
use arrow_schema::{DataType, Field, Fields};
use arrow_buffer::{IntervalDayTime, OffsetBuffer, ScalarBuffer, i256};
use arrow_schema::{DataType, Field, Fields, UnionFields};
use half::f16;
use std::sync::Arc;

Expand Down Expand Up @@ -1189,4 +1256,115 @@ mod tests {
}
}
}

#[test]
fn test_dense_union() {
// create a dense union array with Int32 (type_id = 0) and Utf8 (type_id=1)
// the values are: [1, "b", 2, "a", 3]
// type_ids are: [0, 1, 0, 1, 0]
// offsets are: [0, 0, 1, 1, 2] from [1, 2, 3] and ["b", "a"]
let int_array = Int32Array::from(vec![1, 2, 3]);
let str_array = StringArray::from(vec!["b", "a"]);

let type_ids = [0, 1, 0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
let offsets = [0, 0, 1, 1, 2].into_iter().collect::<ScalarBuffer<i32>>();

let union_fields = [
(0, Arc::new(Field::new("A", DataType::Int32, false))),
(1, Arc::new(Field::new("B", DataType::Utf8, false))),
]
.into_iter()
.collect::<UnionFields>();

let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)];

let array1 =
UnionArray::try_new(union_fields.clone(), type_ids, Some(offsets), children).unwrap();

// create a second array: [2, "a", 1, "c"]
// type ids are: [0, 1, 0, 1]
// offsets are: [0, 0, 1, 1] from [2, 1] and ["a", "c"]
let int_array2 = Int32Array::from(vec![2, 1]);
let str_array2 = StringArray::from(vec!["a", "c"]);
let type_ids2 = [0, 1, 0, 1].into_iter().collect::<ScalarBuffer<i8>>();
let offsets2 = [0, 0, 1, 1].into_iter().collect::<ScalarBuffer<i32>>();

let children2 = vec![Arc::new(int_array2) as ArrayRef, Arc::new(str_array2)];

let array2 =
UnionArray::try_new(union_fields, type_ids2, Some(offsets2), children2).unwrap();

let opts = SortOptions {
descending: false,
nulls_first: true,
};

// comparing
// [1, "b", 2, "a", 3]
// [2, "a", 1, "c"]
let cmp = make_comparator(&array1, &array2, opts).unwrap();

// array1[0] = (type_id=0, value=1)
// array2[0] = (type_id=0, value=2)
assert_eq!(cmp(0, 0), Ordering::Less); // 1 < 2

// array1[0] = (type_id=0, value=1)
// array2[1] = (type_id=1, value="a")
assert_eq!(cmp(0, 1), Ordering::Less); // type_id 0 < 1

// array1[1] = (type_id=1, value="b")
// array2[1] = (type_id=1, value="a")
assert_eq!(cmp(1, 1), Ordering::Greater); // "b" > "a"

// array1[2] = (type_id=0, value=2)
// array2[0] = (type_id=0, value=2)
assert_eq!(cmp(2, 0), Ordering::Equal); // 2 == 2

// array1[3] = (type_id=1, value="a")
// array2[1] = (type_id=1, value="a")
assert_eq!(cmp(3, 1), Ordering::Equal); // "a" == "a"

// array1[1] = (type_id=1, value="b")
// array2[3] = (type_id=1, value="c")
assert_eq!(cmp(1, 3), Ordering::Less); // "b" < "c"

let opts_desc = SortOptions {
descending: true,
nulls_first: true,
};
let cmp_desc = make_comparator(&array1, &array2, opts_desc).unwrap();

assert_eq!(cmp_desc(0, 0), Ordering::Greater); // 1 > 2 (reversed)
assert_eq!(cmp_desc(0, 1), Ordering::Greater); // type_id 0 < 1, reversed to Greater
assert_eq!(cmp_desc(1, 1), Ordering::Less); // "b" < "a" (reversed)
}

#[test]
fn test_sparse_union() {
// create a sparse union array with Int32 (type_id=0) and Utf8 (type_id=1)
// values: [1, "b", 3]
// note, in sparse unions, child arrays have the same length as the union
let int_array = Int32Array::from(vec![Some(1), None, Some(3)]);
let str_array = StringArray::from(vec![None, Some("b"), None]);
let type_ids = [0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();

let union_fields = [
(0, Arc::new(Field::new("a", DataType::Int32, false))),
(1, Arc::new(Field::new("b", DataType::Utf8, false))),
]
.into_iter()
.collect::<UnionFields>();

let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)];

let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();

let opts = SortOptions::default();
let cmp = make_comparator(&array, &array, opts).unwrap();

// array[0] = (type_id=0, value=1), array[2] = (type_id=0, value=3)
assert_eq!(cmp(0, 2), Ordering::Less); // 1 < 3
// array[0] = (type_id=0, value=1), array[1] = (type_id=1, value="b")
assert_eq!(cmp(0, 1), Ordering::Less); // type_id 0 < 1
}
}