Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 131 additions & 109 deletions arrow-select/src/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ use arrow_buffer::{
ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer, ScalarBuffer,
bit_util,
};
use arrow_data::ArrayDataBuilder;
use arrow_data::{ArrayDataBuilder, transform::MutableArrayData};
use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};

use num_traits::{One, Zero};
use num_traits::Zero;

/// Take elements by index from [Array], creating a new [Array] from those indexes.
///
Expand Down Expand Up @@ -611,9 +611,8 @@ fn take_byte_view<T: ByteViewType, IndexType: ArrowPrimitiveType>(

/// `take` implementation for list arrays
///
/// Calculates the index and indexed offset for the inner array,
/// applying `take` on the inner array, then reconstructing a list array
/// with the indexed offsets
/// Copies the selected list entries' child slices into a new child array
/// via `MutableArrayData`, then reconstructs a list array with new offsets
fn take_list<IndexType, OffsetType>(
values: &GenericListArray<OffsetType::Native>,
indices: &PrimitiveArray<IndexType>,
Expand All @@ -624,23 +623,79 @@ where
OffsetType::Native: OffsetSizeTrait,
PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
{
// TODO: Some optimizations can be done here such as if it is
// taking the whole list or a contiguous sublist
let (list_indices, offsets, null_buf) =
take_value_indices_from_list::<IndexType, OffsetType>(values, indices)?;

let taken = take_impl::<OffsetType>(values.values().as_ref(), &list_indices)?;
let value_offsets = Buffer::from_vec(offsets);
// create a new list with taken data and computed null information
let list_offsets = values.value_offsets();
let child_data = values.values().to_data();
let nulls = take_nulls(values.nulls(), indices);

let mut new_offsets = Vec::with_capacity(indices.len() + 1);
new_offsets.push(OffsetType::Native::zero());

let use_nulls = child_data.null_count() > 0;

let capacity = child_data
.len()
.checked_div(values.len())
.map(|v| v * indices.len())
.unwrap_or_default();

let mut array_data = MutableArrayData::new(vec![&child_data], use_nulls, capacity);

match nulls.as_ref().filter(|n| n.null_count() > 0) {
None => {
for index in indices.values() {
let ix = index.as_usize();
let start = list_offsets[ix].as_usize();
let end = list_offsets[ix + 1].as_usize();
array_data.extend(0, start, end);
new_offsets.push(OffsetType::Native::from_usize(array_data.len()).unwrap());
}
}
Some(output_nulls) => {
assert_eq!(output_nulls.len(), indices.len());

let mut last_filled = 0;
for i in output_nulls.valid_indices() {
let current = OffsetType::Native::from_usize(array_data.len()).unwrap();
// Filling offsets for the null values between the two valid indices
if last_filled < i {
Comment thread
alamb marked this conversation as resolved.
new_offsets.extend(std::iter::repeat_n(current, i - last_filled));
}

// SAFETY: `i` comes from validity bitmap over `indices`, so in-bounds.
Comment thread
alamb marked this conversation as resolved.
let ix = unsafe { indices.value_unchecked(i) }.as_usize();
let start = list_offsets[ix].as_usize();
let end = list_offsets[ix + 1].as_usize();
array_data.extend(0, start, end);
new_offsets.push(OffsetType::Native::from_usize(array_data.len()).unwrap());
last_filled = i + 1;
}

// Filling offsets for null values at the end
let final_offset = OffsetType::Native::from_usize(array_data.len()).unwrap();
new_offsets.extend(std::iter::repeat_n(
final_offset,
indices.len() - last_filled,
));
}
};

assert_eq!(
new_offsets.len(),
indices.len() + 1,
"New offsets was filled under/over the expected capacity"
);

let child_data = array_data.freeze();
let value_offsets = Buffer::from_vec(new_offsets);

let list_data = ArrayDataBuilder::new(values.data_type().clone())
.len(indices.len())
.null_bit_buffer(Some(null_buf.into()))
.nulls(nulls)
.offset(0)
.add_child_data(taken.into_data())
.add_child_data(child_data)
.add_buffer(value_offsets);

let list_data = unsafe { list_data.build_unchecked() };

Ok(GenericListArray::<OffsetType::Native>::from(list_data))
}

Expand Down Expand Up @@ -925,78 +980,6 @@ fn take_run<T: RunEndIndexType, I: ArrowPrimitiveType>(
Ok(array_data.into())
}

/// Takes/filters a list array's inner data using the offsets of the list array.
///
/// Where a list array has indices `[0,2,5,10]`, taking indices of `[2,0]` returns
/// an array of the indices `[5..10, 0..2]` and offsets `[0,5,7]` (5 elements and 2
/// elements)
#[allow(clippy::type_complexity)]
fn take_value_indices_from_list<IndexType, OffsetType>(
list: &GenericListArray<OffsetType::Native>,
indices: &PrimitiveArray<IndexType>,
) -> Result<
(
PrimitiveArray<OffsetType>,
Vec<OffsetType::Native>,
MutableBuffer,
),
ArrowError,
>
where
IndexType: ArrowPrimitiveType,
OffsetType: ArrowPrimitiveType,
OffsetType::Native: OffsetSizeTrait + std::ops::Add + Zero + One,
PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
{
// TODO: benchmark this function, there might be a faster unsafe alternative
let offsets: &[OffsetType::Native] = list.value_offsets();

let mut new_offsets = Vec::with_capacity(indices.len());
let mut values = Vec::new();
let mut current_offset = OffsetType::Native::zero();
// add first offset
new_offsets.push(OffsetType::Native::zero());

// Initialize null buffer
let num_bytes = bit_util::ceil(indices.len(), 8);
let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
let null_slice = null_buf.as_slice_mut();

// compute the value indices, and set offsets accordingly
for i in 0..indices.len() {
if indices.is_valid(i) {
let ix = indices
.value(i)
.to_usize()
.ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
let start = offsets[ix];
let end = offsets[ix + 1];
current_offset += end - start;
new_offsets.push(current_offset);

let mut curr = start;

// if start == end, this slot is empty
while curr < end {
values.push(curr);
curr += One::one();
}
if !list.is_valid(ix) {
bit_util::unset_bit(null_slice, i);
}
} else {
bit_util::unset_bit(null_slice, i);
new_offsets.push(current_offset);
}
}

Ok((
PrimitiveArray::<OffsetType>::from(values),
new_offsets,
null_buf,
))
}

/// Takes/filters a fixed size list array's inner data using the offsets of the list array.
fn take_value_indices_from_fixed_size_list<IndexType>(
list: &FixedSizeListArray,
Expand Down Expand Up @@ -2497,37 +2480,76 @@ mod tests {
)
}

#[test]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would make it easier to see what you have changed if you didn't also move the tests around

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no idea why I did that

fn test_take_value_index_from_list() {
let list = build_generic_list::<i32, Int32Type>(vec![
fn test_take_sliced_list_generic<S: OffsetSizeTrait + 'static>() {
let list = build_generic_list::<S, Int32Type>(vec![
Some(vec![0, 1]),
Some(vec![2, 3, 4]),
Some(vec![5, 6, 7, 8, 9]),
None,
Some(vec![]),
Some(vec![5, 6]),
Some(vec![7]),
]);
let sliced = list.slice(1, 4);
let indices = UInt32Array::from(vec![Some(3), Some(0), None, Some(2), Some(1)]);

let taken = take(&sliced, &indices, None).unwrap();
let taken = taken.as_list::<S>();

let expected = build_generic_list::<S, Int32Type>(vec![
Some(vec![5, 6]),
Some(vec![2, 3, 4]),
None,
Some(vec![]),
None,
]);
let indices = UInt32Array::from(vec![2, 0]);

let (indexed, offsets, null_buf) = take_value_indices_from_list(&list, &indices).unwrap();
assert_eq!(taken, &expected);
}

fn test_take_sliced_list_with_value_nulls_generic<S: OffsetSizeTrait + 'static>() {
let list = GenericListArray::<S>::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(10)]),
Some(vec![None, Some(1)]),
None,
Some(vec![Some(2), None]),
Some(vec![]),
Some(vec![Some(3)]),
]);
let sliced = list.slice(1, 4);
let indices = UInt32Array::from(vec![Some(2), Some(0), None, Some(3), Some(1)]);

let taken = take(&sliced, &indices, None).unwrap();
let taken = taken.as_list::<S>();

let expected = GenericListArray::<S>::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(2), None]),
Some(vec![None, Some(1)]),
None,
Some(vec![]),
None,
]);

assert_eq!(indexed, Int32Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
assert_eq!(offsets, vec![0, 5, 7]);
assert_eq!(null_buf.as_slice(), &[0b11111111]);
assert_eq!(taken, &expected);
}

#[test]
fn test_take_value_index_from_large_list() {
let list = build_generic_list::<i64, Int32Type>(vec![
Some(vec![0, 1]),
Some(vec![2, 3, 4]),
Some(vec![5, 6, 7, 8, 9]),
]);
let indices = UInt32Array::from(vec![2, 0]);
fn test_take_sliced_list() {
test_take_sliced_list_generic::<i32>();
}

#[test]
fn test_take_sliced_large_list() {
test_take_sliced_list_generic::<i64>();
}

let (indexed, offsets, null_buf) =
take_value_indices_from_list::<_, Int64Type>(&list, &indices).unwrap();
#[test]
fn test_take_sliced_list_with_value_nulls() {
test_take_sliced_list_with_value_nulls_generic::<i32>();
}

assert_eq!(indexed, Int64Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
assert_eq!(offsets, vec![0, 5, 7]);
assert_eq!(null_buf.as_slice(), &[0b11111111]);
#[test]
fn test_take_sliced_large_list_with_value_nulls() {
test_take_sliced_list_with_value_nulls_generic::<i64>();
}

#[test]
Expand Down
Loading