|
21 | 21 | //! [`filter`]: crate::filter::filter |
22 | 22 | //! [`take`]: crate::take::take |
23 | 23 | use crate::filter::filter_record_batch; |
| 24 | +use crate::take::take_record_batch; |
24 | 25 | use arrow_array::types::{BinaryViewType, StringViewType}; |
25 | 26 | use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch, downcast_primitive}; |
26 | 27 | use arrow_schema::{ArrowError, DataType, SchemaRef}; |
@@ -243,6 +244,39 @@ impl BatchCoalescer { |
243 | 244 | self.push_batch(filtered_batch) |
244 | 245 | } |
245 | 246 |
|
| 247 | + /// Push a batch into the Coalescer after applying a set of indices |
| 248 | + /// This is semantically equivalent of calling [`Self::push_batch`] |
| 249 | + /// with the results from [`take_record_batch`] |
| 250 | + /// |
| 251 | + /// # Example |
| 252 | + /// ``` |
| 253 | + /// # use arrow_array::{record_batch, UInt64Array}; |
| 254 | + /// # use arrow_select::coalesce::BatchCoalescer; |
| 255 | + /// let batch1 = record_batch!(("a", Int32, [0, 0, 0])).unwrap(); |
| 256 | + /// let batch2 = record_batch!(("a", Int32, [1, 1, 4, 5, 1, 4])).unwrap(); |
| 257 | + /// // Sorted indices to create a sorted output, this can be obtained with |
| 258 | + /// // `arrow-ord`'s sort_to_indices operation |
| 259 | + /// let indices = UInt64Array::from(vec![0, 1, 4, 2, 5, 3]); |
| 260 | + /// // create a new Coalescer that targets creating 1000 row batches |
| 261 | + /// let mut coalescer = BatchCoalescer::new(batch1.schema(), 1000); |
| 262 | + /// coalescer.push_batch(batch1); |
| 263 | + /// coalescer.push_batch_with_indices(batch2, &indices); |
| 264 | + /// // finsh and retrieve the created batch |
| 265 | + /// coalescer.finish_buffered_batch().unwrap(); |
| 266 | + /// let completed_batch = coalescer.next_completed_batch().unwrap(); |
| 267 | + /// let expected_batch = record_batch!(("a", Int32, [0, 0, 0, 1, 1, 1, 4, 4, 5])).unwrap(); |
| 268 | + /// assert_eq!(completed_batch, expected_batch); |
| 269 | + /// ``` |
| 270 | + pub fn push_batch_with_indices( |
| 271 | + &mut self, |
| 272 | + batch: RecordBatch, |
| 273 | + indices: &dyn Array, |
| 274 | + ) -> Result<(), ArrowError> { |
| 275 | + // todo: optimize this to avoid materializing (copying the results of take indices to a new batch) |
| 276 | + let taken_batch = take_record_batch(&batch, indices)?; |
| 277 | + self.push_batch(taken_batch) |
| 278 | + } |
| 279 | + |
246 | 280 | /// Push all the rows from `batch` into the Coalescer |
247 | 281 | /// |
248 | 282 | /// When buffered data plus incoming rows reach `target_batch_size` , |
@@ -583,7 +617,7 @@ mod tests { |
583 | 617 | use arrow_array::cast::AsArray; |
584 | 618 | use arrow_array::{ |
585 | 619 | BinaryViewArray, Int32Array, Int64Array, RecordBatchOptions, StringArray, StringViewArray, |
586 | | - TimestampNanosecondArray, UInt32Array, |
| 620 | + TimestampNanosecondArray, UInt32Array, UInt64Array, |
587 | 621 | }; |
588 | 622 | use arrow_schema::{DataType, Field, Schema}; |
589 | 623 | use rand::{Rng, SeedableRng}; |
@@ -1327,21 +1361,29 @@ mod tests { |
1327 | 1361 |
|
1328 | 1362 | /// Return a RecordBatch with a UInt32Array with the specified range and |
1329 | 1363 | /// every third value is null. |
1330 | | - fn uint32_batch(range: Range<u32>) -> RecordBatch { |
| 1364 | + fn uint32_batch<T: std::iter::Iterator<Item = u32>>(range: T) -> RecordBatch { |
1331 | 1365 | let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, true)])); |
1332 | 1366 |
|
1333 | 1367 | let array = UInt32Array::from_iter(range.map(|i| if i % 3 == 0 { None } else { Some(i) })); |
1334 | 1368 | RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap() |
1335 | 1369 | } |
1336 | 1370 |
|
1337 | 1371 | /// Return a RecordBatch with a UInt32Array with no nulls specified range |
1338 | | - fn uint32_batch_non_null(range: Range<u32>) -> RecordBatch { |
| 1372 | + fn uint32_batch_non_null<T: std::iter::Iterator<Item = u32>>(range: T) -> RecordBatch { |
1339 | 1373 | let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])); |
1340 | 1374 |
|
1341 | 1375 | let array = UInt32Array::from_iter_values(range); |
1342 | 1376 | RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap() |
1343 | 1377 | } |
1344 | 1378 |
|
| 1379 | + /// Return a RecordBatch with a UInt64Array with no nulls specified range |
| 1380 | + fn uint64_batch_non_null<T: std::iter::Iterator<Item = u64>>(range: T) -> RecordBatch { |
| 1381 | + let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt64, false)])); |
| 1382 | + |
| 1383 | + let array = UInt64Array::from_iter_values(range); |
| 1384 | + RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap() |
| 1385 | + } |
| 1386 | + |
1345 | 1387 | /// Return a RecordBatch with a StringArrary with values `value0`, `value1`, ... |
1346 | 1388 | /// and every third value is `None`. |
1347 | 1389 | fn utf8_batch(range: Range<u32>) -> RecordBatch { |
@@ -1932,4 +1974,33 @@ mod tests { |
1932 | 1974 | } |
1933 | 1975 | assert_eq!(coalescer.get_buffered_rows(), 0); |
1934 | 1976 | } |
| 1977 | + |
| 1978 | + #[test] |
| 1979 | + fn test_coalasce_push_batch_with_indices() { |
| 1980 | + const MID_POINT: u32 = 2333; |
| 1981 | + const TOTAL_ROWS: u32 = 23333; |
| 1982 | + let batch1 = uint32_batch_non_null(0..MID_POINT); |
| 1983 | + let batch2 = uint32_batch_non_null((MID_POINT..TOTAL_ROWS).rev()); |
| 1984 | + |
| 1985 | + let mut coalescer = BatchCoalescer::new( |
| 1986 | + Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])), |
| 1987 | + TOTAL_ROWS as usize, |
| 1988 | + ); |
| 1989 | + coalescer.push_batch(batch1).unwrap(); |
| 1990 | + |
| 1991 | + let rev_indices = (0..((TOTAL_ROWS - MID_POINT) as u64)).rev(); |
| 1992 | + let reversed_indices_batch = uint64_batch_non_null(rev_indices); |
| 1993 | + |
| 1994 | + let reverse_indices = UInt64Array::from(reversed_indices_batch.column(0).to_data()); |
| 1995 | + coalescer |
| 1996 | + .push_batch_with_indices(batch2, &reverse_indices) |
| 1997 | + .unwrap(); |
| 1998 | + |
| 1999 | + coalescer.finish_buffered_batch().unwrap(); |
| 2000 | + let actual = coalescer.next_completed_batch().unwrap(); |
| 2001 | + |
| 2002 | + let expected = uint32_batch_non_null(0..TOTAL_ROWS); |
| 2003 | + |
| 2004 | + assert_eq!(expected, actual); |
| 2005 | + } |
1935 | 2006 | } |
0 commit comments