diff --git a/arrow-select/src/interleave.rs b/arrow-select/src/interleave.rs
index cb3ca655dc67..f870bc8c2fe3 100644
--- a/arrow-select/src/interleave.rs
+++ b/arrow-select/src/interleave.rs
@@ -26,7 +26,7 @@ use arrow_array::*;
 use arrow_buffer::{ArrowNativeType, BooleanBuffer, MutableBuffer, NullBuffer, OffsetBuffer};
 use arrow_data::ByteView;
 use arrow_data::transform::MutableArrayData;
-use arrow_schema::{ArrowError, DataType, Fields};
+use arrow_schema::{ArrowError, DataType, FieldRef, Fields};
 use std::sync::Arc;
 
 macro_rules! primitive_helper {
@@ -106,6 +106,8 @@ pub fn interleave(
             _ => unreachable!("illegal dictionary key type {k}")
         },
         DataType::Struct(fields) => interleave_struct(fields, values, indices),
+        DataType::List(field) => interleave_list::<i32>(values, indices, field),
+        DataType::LargeList(field) => interleave_list::<i64>(values, indices, field),
         _ => interleave_fallback(values, indices)
     }
 }
@@ -319,6 +321,62 @@ fn interleave_struct(
     Ok(Arc::new(struct_array))
 }
 
+/// Interleaves list arrays by interleaving their child values directly.
+///
+/// The output offsets are rebuilt from the length of every selected list
+/// element, and the child values are gathered with a recursive call to
+/// [`interleave`] using one `(array, child_row)` pair per child value.
+///
+/// # Panics
+///
+/// Panics with "offset overflow" if the accumulated child length cannot be
+/// represented by the offset type `O`.
+fn interleave_list<O: OffsetSizeTrait>(
+    values: &[&dyn Array],
+    indices: &[(usize, usize)],
+    field: &FieldRef,
+) -> Result<ArrayRef, ArrowError> {
+    let interleaved = Interleave::<'_, GenericListArray<O>>::new(values, indices);
+
+    // Running total of selected child values; also the end offset of each output row.
+    let mut capacity = 0usize;
+    let mut offsets = Vec::with_capacity(indices.len() + 1);
+    offsets.push(O::from_usize(0).unwrap());
+    offsets.extend(indices.iter().map(|(array, row)| {
+        let o = interleaved.arrays[*array].value_offsets();
+        let element_len = o[*row + 1].as_usize() - o[*row].as_usize();
+        capacity += element_len;
+        O::from_usize(capacity).expect("offset overflow")
+    }));
+
+    // Expand every selected list row into one `(array, child_row)` pair per child value.
+    let mut child_indices = Vec::with_capacity(capacity);
+    for (array, row) in indices {
+        let list = interleaved.arrays[*array];
+        let start = list.value_offsets()[*row].as_usize();
+        let end = list.value_offsets()[*row + 1].as_usize();
+        child_indices.extend((start..end).map(|i| (*array, i)));
+    }
+
+    let child_arrays: Vec<&dyn Array> = interleaved
+        .arrays
+        .iter()
+        .map(|list| list.values().as_ref())
+        .collect();
+
+    let interleaved_values = interleave(&child_arrays, &child_indices)?;
+
+    let offsets = OffsetBuffer::new(offsets.into());
+    let list_array = GenericListArray::<O>::new(
+        field.clone(),
+        offsets,
+        interleaved_values,
+        interleaved.nulls,
+    );
+
+    Ok(Arc::new(list_array))
+}
+
 /// Fallback implementation of interleave using [`MutableArrayData`]
 fn interleave_fallback(
     values: &[&dyn Array],
@@ -488,7 +546,7 @@ pub fn interleave_record_batch(
 mod tests {
     use super::*;
     use arrow_array::Int32RunArray;
-    use arrow_array::builder::{Int32Builder, ListBuilder, PrimitiveRunBuilder};
+    use arrow_array::builder::{GenericListBuilder, Int32Builder, PrimitiveRunBuilder};
     use arrow_array::types::Int8Type;
     use arrow_schema::Field;
 
@@ -622,10 +680,10 @@ mod tests {
         assert_eq!(string_result, vec!["v0", "v0", "v49"]);
     }
 
-    #[test]
-    fn test_lists() {
+    /// Shared body of `test_lists` and `test_large_lists`, generic over the offset type.
+    fn test_interleave_lists<O: OffsetSizeTrait>() {
         // [[1, 2], null, [3]]
-        let mut a = ListBuilder::new(Int32Builder::new());
+        let mut a = GenericListBuilder::<O, _>::new(Int32Builder::new());
         a.values().append_value(1);
         a.values().append_value(2);
         a.append(true);
@@ -635,7 +693,7 @@ mod tests {
         let a = a.finish();
 
         // [[4], null, [5, 6, null]]
-        let mut b = ListBuilder::new(Int32Builder::new());
+        let mut b = GenericListBuilder::<O, _>::new(Int32Builder::new());
         b.values().append_value(4);
         b.append(true);
         b.append(false);
@@ -646,10 +704,13 @@ mod tests {
         let b = b.finish();
 
         let values = interleave(&[&a, &b], &[(0, 2), (0, 1), (1, 0), (1, 2), (1, 1)]).unwrap();
-        let v = values.as_any().downcast_ref::<ListArray>().unwrap();
+        let v = values
+            .as_any()
+            .downcast_ref::<GenericListArray<O>>()
+            .unwrap();
 
         // [[3], null, [4], [5, 6, null], null]
-        let mut expected = ListBuilder::new(Int32Builder::new());
+        let mut expected = GenericListBuilder::<O, _>::new(Int32Builder::new());
         expected.values().append_value(3);
         expected.append(true);
         expected.append(false);
@@ -665,6 +726,16 @@ mod tests {
         assert_eq!(v, &expected);
     }
 
+    #[test]
+    fn test_lists() {
+        test_interleave_lists::<i32>();
+    }
+
+    #[test]
+    fn test_large_lists() {
+        test_interleave_lists::<i64>();
+    }
+
     #[test]
     fn test_struct_without_nulls() {
         let fields = Fields::from(vec![