diff --git a/arrow-select/src/coalesce.rs b/arrow-select/src/coalesce.rs index 5ef9a86b9206..ddb1c41c8c79 100644 --- a/arrow-select/src/coalesce.rs +++ b/arrow-select/src/coalesce.rs @@ -21,6 +21,7 @@ //! [`filter`]: crate::filter::filter //! [`take`]: crate::take::take use crate::filter::filter_record_batch; +use crate::take::take_record_batch; use arrow_array::types::{BinaryViewType, StringViewType}; use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch, downcast_primitive}; use arrow_schema::{ArrowError, DataType, SchemaRef}; @@ -243,6 +244,39 @@ impl BatchCoalescer { self.push_batch(filtered_batch) } + /// Push a batch into the Coalescer after applying a set of indices + /// This is semantically equivalent to calling [`Self::push_batch`] + /// with the results from [`take_record_batch`] + /// + /// # Example + /// ``` + /// # use arrow_array::{record_batch, UInt64Array}; + /// # use arrow_select::coalesce::BatchCoalescer; + /// let batch1 = record_batch!(("a", Int32, [0, 0, 0])).unwrap(); + /// let batch2 = record_batch!(("a", Int32, [1, 1, 4, 5, 1, 4])).unwrap(); + /// // Sorted indices to create a sorted output, this can be obtained with + /// // `arrow-ord`'s sort_to_indices operation + /// let indices = UInt64Array::from(vec![0, 1, 4, 2, 5, 3]); + /// // create a new Coalescer that targets creating 1000 row batches + /// let mut coalescer = BatchCoalescer::new(batch1.schema(), 1000); + /// coalescer.push_batch(batch1); + /// coalescer.push_batch_with_indices(batch2, &indices); + /// // finish and retrieve the created batch + /// coalescer.finish_buffered_batch().unwrap(); + /// let completed_batch = coalescer.next_completed_batch().unwrap(); + /// let expected_batch = record_batch!(("a", Int32, [0, 0, 0, 1, 1, 1, 4, 4, 5])).unwrap(); + /// assert_eq!(completed_batch, expected_batch); + /// ``` + pub fn push_batch_with_indices( + &mut self, + batch: RecordBatch, + indices: &dyn Array, + ) -> Result<(), ArrowError> { + // todo: optimize this to avoid 
materializing (copying the results of take indices to a new batch) + let taken_batch = take_record_batch(&batch, indices)?; + self.push_batch(taken_batch) + } + /// Push all the rows from `batch` into the Coalescer /// /// When buffered data plus incoming rows reach `target_batch_size` , @@ -583,7 +617,7 @@ mod tests { use arrow_array::cast::AsArray; use arrow_array::{ BinaryViewArray, Int32Array, Int64Array, RecordBatchOptions, StringArray, StringViewArray, - TimestampNanosecondArray, UInt32Array, + TimestampNanosecondArray, UInt32Array, UInt64Array, }; use arrow_schema::{DataType, Field, Schema}; use rand::{Rng, SeedableRng}; @@ -1327,7 +1361,7 @@ mod tests { /// Return a RecordBatch with a UInt32Array with the specified range and /// every third value is null. - fn uint32_batch(range: Range) -> RecordBatch { + fn uint32_batch>(range: T) -> RecordBatch { let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, true)])); let array = UInt32Array::from_iter(range.map(|i| if i % 3 == 0 { None } else { Some(i) })); @@ -1335,13 +1369,21 @@ mod tests { } /// Return a RecordBatch with a UInt32Array with no nulls specified range - fn uint32_batch_non_null(range: Range) -> RecordBatch { + fn uint32_batch_non_null>(range: T) -> RecordBatch { let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])); let array = UInt32Array::from_iter_values(range); RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap() } + /// Return a RecordBatch with a UInt64Array with no nulls specified range + fn uint64_batch_non_null>(range: T) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt64, false)])); + + let array = UInt64Array::from_iter_values(range); + RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap() + } + /// Return a RecordBatch with a StringArrary with values `value0`, `value1`, ... /// and every third value is `None`. 
fn utf8_batch(range: Range) -> RecordBatch { @@ -1932,4 +1974,33 @@ mod tests { } assert_eq!(coalescer.get_buffered_rows(), 0); } + + #[test] + fn test_coalasce_push_batch_with_indices() { + const MID_POINT: u32 = 2333; + const TOTAL_ROWS: u32 = 23333; + let batch1 = uint32_batch_non_null(0..MID_POINT); + let batch2 = uint32_batch_non_null((MID_POINT..TOTAL_ROWS).rev()); + + let mut coalescer = BatchCoalescer::new( + Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])), + TOTAL_ROWS as usize, + ); + coalescer.push_batch(batch1).unwrap(); + + let rev_indices = (0..((TOTAL_ROWS - MID_POINT) as u64)).rev(); + let reversed_indices_batch = uint64_batch_non_null(rev_indices); + + let reverse_indices = UInt64Array::from(reversed_indices_batch.column(0).to_data()); + coalescer + .push_batch_with_indices(batch2, &reverse_indices) + .unwrap(); + + coalescer.finish_buffered_batch().unwrap(); + let actual = coalescer.next_completed_batch().unwrap(); + + let expected = uint32_batch_non_null(0..TOTAL_ROWS); + + assert_eq!(expected, actual); + } }