Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 74 additions & 3 deletions arrow-select/src/coalesce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
//! [`filter`]: crate::filter::filter
//! [`take`]: crate::take::take
use crate::filter::filter_record_batch;
use crate::take::take_record_batch;
use arrow_array::types::{BinaryViewType, StringViewType};
use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch, downcast_primitive};
use arrow_schema::{ArrowError, DataType, SchemaRef};
Expand Down Expand Up @@ -243,6 +244,39 @@ impl BatchCoalescer {
self.push_batch(filtered_batch)
}

/// Push a batch into the Coalescer after applying a set of indices.
///
/// This is semantically equivalent to calling [`Self::push_batch`]
/// with the results from [`take_record_batch`]
///
/// # Errors
/// Returns any error produced by [`take_record_batch`] for the given
/// `batch` and `indices`, or by [`Self::push_batch`] for the taken batch.
///
/// # Example
/// ```
/// # use arrow_array::{record_batch, UInt64Array};
/// # use arrow_select::coalesce::BatchCoalescer;
/// let batch1 = record_batch!(("a", Int32, [0, 0, 0])).unwrap();
/// let batch2 = record_batch!(("a", Int32, [1, 1, 4, 5, 1, 4])).unwrap();
/// // Sorted indices to create a sorted output, this can be obtained with
/// // `arrow-ord`'s sort_to_indices operation
/// let indices = UInt64Array::from(vec![0, 1, 4, 2, 5, 3]);
/// // create a new Coalescer that targets creating 1000 row batches
/// let mut coalescer = BatchCoalescer::new(batch1.schema(), 1000);
/// coalescer.push_batch(batch1);
/// coalescer.push_batch_with_indices(batch2, &indices);
/// // finish and retrieve the created batch
/// coalescer.finish_buffered_batch().unwrap();
/// let completed_batch = coalescer.next_completed_batch().unwrap();
/// let expected_batch = record_batch!(("a", Int32, [0, 0, 0, 1, 1, 1, 4, 4, 5])).unwrap();
/// assert_eq!(completed_batch, expected_batch);
/// ```
pub fn push_batch_with_indices(
&mut self,
batch: RecordBatch,
indices: &dyn Array,
) -> Result<(), ArrowError> {
// todo: optimize this to avoid materializing (copying the results of take indices to a new batch)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

let taken_batch = take_record_batch(&batch, indices)?;
self.push_batch(taken_batch)
}

/// Push all the rows from `batch` into the Coalescer
///
/// When buffered data plus incoming rows reach `target_batch_size`,
Expand Down Expand Up @@ -583,7 +617,7 @@ mod tests {
use arrow_array::cast::AsArray;
use arrow_array::{
BinaryViewArray, Int32Array, Int64Array, RecordBatchOptions, StringArray, StringViewArray,
TimestampNanosecondArray, UInt32Array,
TimestampNanosecondArray, UInt32Array, UInt64Array,
};
use arrow_schema::{DataType, Field, Schema};
use rand::{Rng, SeedableRng};
Expand Down Expand Up @@ -1327,21 +1361,29 @@ mod tests {

/// Return a RecordBatch with a single nullable UInt32 column named `c0`
/// built from the values yielded by `range`; every value divisible by 3
/// is replaced with null.
fn uint32_batch(range: Range<u32>) -> RecordBatch {
fn uint32_batch<T: std::iter::Iterator<Item = u32>>(range: T) -> RecordBatch {
let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, true)]));

let array = UInt32Array::from_iter(range.map(|i| if i % 3 == 0 { None } else { Some(i) }));
RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
}

/// Return a RecordBatch with a single non-nullable UInt32 column named `c0`
/// holding the values yielded by `range`, in order and without nulls.
fn uint32_batch_non_null(range: Range<u32>) -> RecordBatch {
fn uint32_batch_non_null<T: std::iter::Iterator<Item = u32>>(range: T) -> RecordBatch {
let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));

let array = UInt32Array::from_iter_values(range);
RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
}

/// Return a RecordBatch with a single non-nullable UInt64 column named `c0`
/// holding the values yielded by `range`, in order and without nulls.
fn uint64_batch_non_null<T: std::iter::Iterator<Item = u64>>(range: T) -> RecordBatch {
    // Materialize the column first; the schema below only describes it.
    let values = UInt64Array::from_iter_values(range);

    let field = Field::new("c0", DataType::UInt64, false);
    let schema = Arc::new(Schema::new(vec![field]));

    // Panics if the column does not match the schema, which would be a bug
    // in this test helper.
    RecordBatch::try_new(schema, vec![Arc::new(values)]).unwrap()
}

/// Return a RecordBatch with a StringArray with values `value0`, `value1`, ...
/// and every third value is `None`.
fn utf8_batch(range: Range<u32>) -> RecordBatch {
Expand Down Expand Up @@ -1932,4 +1974,33 @@ mod tests {
}
assert_eq!(coalescer.get_buffered_rows(), 0);
}

#[test]
fn test_coalasce_push_batch_with_indices() {
    // Pushes an ascending head batch as-is, then a descending tail batch
    // through reversed indices, and expects one fully ascending batch out.
    const MID_POINT: u32 = 2333;
    const TOTAL_ROWS: u32 = 23333;

    let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));
    let mut coalescer = BatchCoalescer::new(schema, TOTAL_ROWS as usize);

    // Head: values 0..MID_POINT, already in ascending order.
    let ascending_head = uint32_batch_non_null(0..MID_POINT);
    coalescer.push_batch(ascending_head).unwrap();

    // Tail: values MID_POINT..TOTAL_ROWS stored in descending order.
    let descending_tail = uint32_batch_non_null((MID_POINT..TOTAL_ROWS).rev());

    // Indices that walk the tail back-to-front, restoring ascending order.
    let tail_len = (TOTAL_ROWS - MID_POINT) as u64;
    let index_batch = uint64_batch_non_null((0..tail_len).rev());
    let reverse_indices = UInt64Array::from(index_batch.column(0).to_data());

    coalescer
        .push_batch_with_indices(descending_tail, &reverse_indices)
        .unwrap();

    coalescer.finish_buffered_batch().unwrap();
    let actual = coalescer.next_completed_batch().unwrap();

    let expected = uint32_batch_non_null(0..TOTAL_ROWS);
    assert_eq!(expected, actual);
}
}
Loading