Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion arrow-select/src/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ fn concat_dictionaries<K: ArrowDictionaryKeyType>(
.inspect(|d| output_len += d.len())
.collect();

if !should_merge_dictionary_values::<K>(&dictionaries, output_len) {
if !should_merge_dictionary_values::<K>(&dictionaries, output_len).0 {
return concat_fallback(arrays, Capacities::Array(output_len));
}

Expand Down
17 changes: 14 additions & 3 deletions arrow-select/src/dictionary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,14 @@ type PtrEq = fn(&dyn Array, &dyn Array) -> bool;
/// some return over the naive approach used by MutableArrayData
///
/// `len` is the total length of the merged output
///
/// Returns `(should_merge, has_overflow)` where:
/// - `should_merge`: whether dictionary values should be merged
/// - `has_overflow`: whether the combined dictionary values would overflow the key type
pub(crate) fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
dictionaries: &[&DictionaryArray<K>],
len: usize,
) -> bool {
) -> (bool, bool) {
use DataType::*;
let first_values = dictionaries[0].values().as_ref();
let ptr_eq: PtrEq = match first_values.data_type() {
Expand All @@ -187,7 +191,11 @@ pub(crate) fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
LargeBinary => bytes_ptr_eq::<LargeBinaryType>,
dt => {
if !dt.is_primitive() {
return false;
return (
false,
K::Native::from_usize(dictionaries.iter().map(|d| d.values().len()).sum())
.is_none(),
);
}
|a, b| a.to_data().ptr_eq(&b.to_data())
}
Expand All @@ -206,7 +214,10 @@ pub(crate) fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
let overflow = K::Native::from_usize(total_values).is_none();
let values_exceed_length = total_values >= len;

!single_dictionary && (overflow || values_exceed_length)
(
!single_dictionary && (overflow || values_exceed_length),
overflow,
)
}

/// Given an array of dictionaries and an optional key mask compute a values array
Expand Down
155 changes: 153 additions & 2 deletions arrow-select/src/interleave.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! Interleave elements from multiple arrays

use crate::concat::concat;
use crate::dictionary::{merge_dictionary_values, should_merge_dictionary_values};
use arrow_array::builder::{BooleanBufferBuilder, PrimitiveBuilder};
use arrow_array::cast::AsArray;
Expand Down Expand Up @@ -195,8 +196,14 @@ fn interleave_dictionaries<K: ArrowDictionaryKeyType>(
indices: &[(usize, usize)],
) -> Result<ArrayRef, ArrowError> {
let dictionaries: Vec<_> = arrays.iter().map(|x| x.as_dictionary::<K>()).collect();
if !should_merge_dictionary_values::<K>(&dictionaries, indices.len()) {
return interleave_fallback(arrays, indices);
let (should_merge, has_overflow) =
should_merge_dictionary_values::<K>(&dictionaries, indices.len());
if !should_merge {
return if has_overflow {
interleave_fallback(arrays, indices)
} else {
interleave_fallback_dictionary::<K>(&dictionaries, indices)
};
}

let masks: Vec<_> = dictionaries
Expand Down Expand Up @@ -346,6 +353,76 @@ fn interleave_fallback(
Ok(make_array(array_data.freeze()))
}

/// Interleaves dictionary arrays without merging (de-duplicating) their values.
///
/// This is the path taken when it was determined that the dictionary values
/// should not be merged. The values of every input are concatenated in input
/// order, and each selected key is shifted by the position at which its own
/// dictionary's values begin inside the concatenated values array.
///
/// Each entry of `indices` is an `(array, row)` pair: the position of the source
/// dictionary in `dictionaries` and the row to take from it.
///
/// # Errors
///
/// Returns any error produced by [`concat`] while concatenating the values.
///
/// # Panics
///
/// Panics if a shifted key does not fit in the key type. Callers must use
/// [`should_merge_dictionary_values`] to verify that the combined dictionary
/// values cannot overflow the key type before calling this function.
fn interleave_fallback_dictionary<K: ArrowDictionaryKeyType>(
    dictionaries: &[&DictionaryArray<K>],
    indices: &[(usize, usize)],
) -> Result<ArrayRef, ArrowError> {
    // value_starts[i] is where the values of dictionaries[i] begin once all
    // values arrays have been laid out back to back.
    let mut value_starts: Vec<usize> = Vec::with_capacity(dictionaries.len());
    let mut running_total = 0usize;
    for dict in dictionaries {
        value_starts.push(running_total);
        running_total += dict.values().len();
    }

    let value_arrays: Vec<&dyn Array> = dictionaries
        .iter()
        .map(|dict| dict.values().as_ref())
        .collect();
    let merged_values = concat(&value_arrays)?;

    // Translates the key stored at `row` of input `source` into a key that
    // addresses the same value inside `merged_values`.
    let remap = |source: usize, row: usize| -> K::Native {
        let local_key = dictionaries[source].keys().values()[row].as_usize();
        K::Native::from_usize(value_starts[source] + local_key)
            .expect("key overflow should be checked by caller")
    };

    let mut new_keys: Vec<K::Native> = Vec::with_capacity(indices.len());
    let may_contain_nulls = dictionaries
        .iter()
        .any(|dict| dict.keys().nulls().is_some());

    let nulls = if may_contain_nulls {
        let mut saw_null = false;
        for &(source, row) in indices {
            if dictionaries[source].keys().is_valid(row) {
                new_keys.push(remap(source, row));
            } else {
                // Placeholder key; the slot is masked out by the validity buffer below.
                saw_null = true;
                new_keys.push(K::Native::ZERO);
            }
        }

        // Only materialise a validity buffer when a null row was actually selected.
        saw_null.then(|| {
            let validity = BooleanBuffer::collect_bool(indices.len(), |i| {
                let (source, row) = indices[i];
                dictionaries[source].keys().is_valid(row)
            });
            NullBuffer::new(validity)
        })
    } else {
        new_keys.extend(indices.iter().map(|&(source, row)| remap(source, row)));
        None
    };

    let keys_array = PrimitiveArray::<K>::new(new_keys.into(), nulls);
    // SAFETY: every valid key is an input key shifted by the start of that
    // input's values within `merged_values`, so it stays in bounds provided each
    // input dictionary's valid keys index into its own values. Null slots hold a
    // zero placeholder and are masked by `nulls`.
    let array = unsafe { DictionaryArray::new_unchecked(keys_array, merged_values) };
    Ok(Arc::new(array))
}

/// Interleave rows by index from multiple [`RecordBatch`] instances and return a new [`RecordBatch`].
///
/// This function will call [`interleave`] on each array of the [`RecordBatch`] instances and assemble a new [`RecordBatch`].
Expand Down Expand Up @@ -412,6 +489,7 @@ mod tests {
use super::*;
use arrow_array::Int32RunArray;
use arrow_array::builder::{Int32Builder, ListBuilder, PrimitiveRunBuilder};
use arrow_array::types::Int8Type;
use arrow_schema::Field;

#[test]
Expand Down Expand Up @@ -509,6 +587,41 @@ mod tests {
assert_eq!(actual, expected);
}

#[test]
fn test_interleave_dictionary_overflow_same_values() {
    // Three Int8-keyed dictionaries share one 50-entry values array.
    // Laying the values out back to back would give relative offsets
    // [0, 50, 100], so key 49 of the third input would become 100 + 49 = 149,
    // which does not fit in Int8 (max 127).
    // Because the inputs share the same underlying values, this case is
    // expected to go through interleave_fallback and still produce the right
    // strings.
    let shared_values: ArrayRef = Arc::new(StringArray::from_iter_values(
        (0..50).map(|i| format!("v{i}")),
    ));
    let dict_with_keys = |keys: &[i8]| {
        DictionaryArray::<Int8Type>::new(
            Int8Array::from_iter_values(keys.iter().copied()),
            shared_values.clone(),
        )
    };

    let first = dict_with_keys(&[0, 1, 2]);
    let second = dict_with_keys(&[0, 1, 2]);
    let third = dict_with_keys(&[49]);

    // Take row 0 from each input in turn.
    let picks: [(usize, usize); 3] = [(0, 0), (1, 0), (2, 0)];
    let interleaved = interleave(&[&first, &second, &third], &picks).unwrap();

    let strings: Vec<_> = interleaved
        .as_dictionary::<Int8Type>()
        .downcast_dict::<StringArray>()
        .unwrap()
        .into_iter()
        .map(|value| value.unwrap())
        .collect();
    assert_eq!(strings, vec!["v0", "v0", "v49"]);
}

#[test]
fn test_lists() {
// [[1, 2], null, [3]]
Expand Down Expand Up @@ -1182,4 +1295,42 @@ mod tests {
assert_eq!(v.len(), 1);
assert_eq!(v.data_type(), &DataType::Struct(fields));
}

#[test]
fn test_interleave_fallback_dictionary_with_nulls() {
    // First input: keys [0, null, 1] over values ["foo", "bar"].
    let first_keys = Int32Array::from_iter([Some(0), None, Some(1)]);
    let first_values = StringArray::from_iter_values(["foo", "bar"]);
    let first = DictionaryArray::new(first_keys, Arc::new(first_values));

    // Second input: keys [0, 1, null] over values ["baz", "qux"].
    let second_keys = Int32Array::from_iter([Some(0), Some(1), None]);
    let second_values = StringArray::from_iter_values(["baz", "qux"]);
    let second = DictionaryArray::new(second_keys, Arc::new(second_values));

    // Each case pairs an (array, row) selection with the value expected there.
    let cases: [((usize, usize), Option<&str>); 6] = [
        ((0, 0), Some("foo")),
        ((0, 1), None),
        ((1, 0), Some("baz")),
        ((1, 2), None),
        ((0, 2), Some("bar")),
        ((1, 1), Some("qux")),
    ];
    let (indices, expected): (Vec<(usize, usize)>, Vec<Option<&str>>) =
        cases.iter().copied().unzip();

    let result =
        interleave_fallback_dictionary::<Int32Type>(&[&first, &second], &indices).unwrap();

    let collected: Vec<_> = result
        .as_dictionary::<Int32Type>()
        .downcast_dict::<StringArray>()
        .unwrap()
        .into_iter()
        .collect();
    assert_eq!(collected, expected);
}
}
Loading