Skip to content

Commit e70bc45

Browse files
committed
feat: impl BatchCoalescer::push_batch_with_indices
MVP for #8957 awaits for #8951 very first version for behaviour review, optimizations TBD Signed-off-by: 蔡略 <cailue@apache.org>
1 parent c6cc7f8 commit e70bc45

File tree

1 file changed

+74
-3
lines changed

1 file changed

+74
-3
lines changed

arrow-select/src/coalesce.rs

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
//! [`filter`]: crate::filter::filter
2222
//! [`take`]: crate::take::take
2323
use crate::filter::filter_record_batch;
24+
use crate::take::take_record_batch;
2425
use arrow_array::types::{BinaryViewType, StringViewType};
2526
use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch, downcast_primitive};
2627
use arrow_schema::{ArrowError, DataType, SchemaRef};
@@ -243,6 +244,39 @@ impl BatchCoalescer {
243244
self.push_batch(filtered_batch)
244245
}
245246

247+
/// Push a batch into the Coalescer after applying a set of indices.
248+
/// This is semantically equivalent to calling [`Self::push_batch`]
249+
/// with the results from [`take_record_batch`]
250+
///
251+
/// # Example
252+
/// ```
253+
/// # use arrow_array::{record_batch, UInt64Array};
254+
/// # use arrow_select::coalesce::BatchCoalescer;
255+
/// let batch1 = record_batch!(("a", Int32, [0, 0, 0])).unwrap();
256+
/// let batch2 = record_batch!(("a", Int32, [1, 1, 4, 5, 1, 4])).unwrap();
257+
/// // Indices that produce a sorted output; these can be obtained with
258+
/// // `arrow-ord`'s sort_to_indices operation
259+
/// let indices = UInt64Array::from(vec![0, 1, 4, 2, 5, 3]);
260+
/// // create a new Coalescer that targets creating 1000 row batches
261+
/// let mut coalescer = BatchCoalescer::new(batch1.schema(), 1000);
262+
/// coalescer.push_batch(batch1);
263+
/// coalescer.push_batch_with_indices(batch2, &indices);
264+
/// // finish and retrieve the created batch
265+
/// coalescer.finish_buffered_batch().unwrap();
266+
/// let completed_batch = coalescer.next_completed_batch().unwrap();
267+
/// let expected_batch = record_batch!(("a", Int32, [0, 0, 0, 1, 1, 1, 4, 4, 5])).unwrap();
268+
/// assert_eq!(completed_batch, expected_batch);
269+
/// ```
270+
pub fn push_batch_with_indices(
271+
&mut self,
272+
batch: RecordBatch,
273+
indices: &dyn Array,
274+
) -> Result<(), ArrowError> {
275+
// todo: optimize this to avoid materializing (copying the results of take indices to a new batch)
276+
let taken_batch = take_record_batch(&batch, indices)?;
277+
self.push_batch(taken_batch)
278+
}
279+
246280
/// Push all the rows from `batch` into the Coalescer
247281
///
248282
/// When buffered data plus incoming rows reach `target_batch_size` ,
@@ -583,7 +617,7 @@ mod tests {
583617
use arrow_array::cast::AsArray;
584618
use arrow_array::{
585619
BinaryViewArray, Int32Array, Int64Array, RecordBatchOptions, StringArray, StringViewArray,
586-
TimestampNanosecondArray, UInt32Array,
620+
TimestampNanosecondArray, UInt32Array, UInt64Array,
587621
};
588622
use arrow_schema::{DataType, Field, Schema};
589623
use rand::{Rng, SeedableRng};
@@ -1327,21 +1361,29 @@ mod tests {
13271361

13281362
/// Return a RecordBatch with a UInt32Array with the specified range and
13291363
/// every third value is null.
1330-
fn uint32_batch(range: Range<u32>) -> RecordBatch {
1364+
fn uint32_batch<T: std::iter::Iterator<Item = u32>>(range: T) -> RecordBatch {
13311365
let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, true)]));
13321366

13331367
let array = UInt32Array::from_iter(range.map(|i| if i % 3 == 0 { None } else { Some(i) }));
13341368
RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
13351369
}
13361370

13371371
/// Return a RecordBatch with a non-nullable UInt32Array holding the specified values
1338-
fn uint32_batch_non_null(range: Range<u32>) -> RecordBatch {
1372+
fn uint32_batch_non_null<T: std::iter::Iterator<Item = u32>>(range: T) -> RecordBatch {
13391373
let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));
13401374

13411375
let array = UInt32Array::from_iter_values(range);
13421376
RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
13431377
}
13441378

1379+
/// Return a RecordBatch with a non-nullable UInt64Array holding the specified values
1380+
fn uint64_batch_non_null<T: std::iter::Iterator<Item = u64>>(range: T) -> RecordBatch {
1381+
let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt64, false)]));
1382+
1383+
let array = UInt64Array::from_iter_values(range);
1384+
RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
1385+
}
1386+
13451387
/// Return a RecordBatch with a StringArray with values `value0`, `value1`, ...
13461388
/// and every third value is `None`.
13471389
fn utf8_batch(range: Range<u32>) -> RecordBatch {
@@ -1932,4 +1974,33 @@ mod tests {
19321974
}
19331975
assert_eq!(coalescer.get_buffered_rows(), 0);
19341976
}
1977+
1978+
#[test]
1979+
fn test_coalasce_push_batch_with_indices() {
1980+
const MID_POINT: u32 = 2333;
1981+
const TOTAL_ROWS: u32 = 23333;
1982+
let batch1 = uint32_batch_non_null(0..MID_POINT);
1983+
let batch2 = uint32_batch_non_null((MID_POINT..TOTAL_ROWS).rev());
1984+
1985+
let mut coalescer = BatchCoalescer::new(
1986+
Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])),
1987+
TOTAL_ROWS as usize,
1988+
);
1989+
coalescer.push_batch(batch1).unwrap();
1990+
1991+
let rev_indices = (0..((TOTAL_ROWS - MID_POINT) as u64)).rev();
1992+
let reversed_indices_batch = uint64_batch_non_null(rev_indices);
1993+
1994+
let reverse_indices = UInt64Array::from(reversed_indices_batch.column(0).to_data());
1995+
coalescer
1996+
.push_batch_with_indices(batch2, &reverse_indices)
1997+
.unwrap();
1998+
1999+
coalescer.finish_buffered_batch().unwrap();
2000+
let actual = coalescer.next_completed_batch().unwrap();
2001+
2002+
let expected = uint32_batch_non_null(0..TOTAL_ROWS);
2003+
2004+
assert_eq!(expected, actual);
2005+
}
19352006
}

0 commit comments

Comments
 (0)