Skip to content

Commit 9d00687

Browse files
committed
Add ree_distinct function that compares Run-End Encoded arrays directly without expanding to logical form.
1 parent 72bd81a commit 9d00687

5 files changed

Lines changed: 444 additions & 91 deletions

File tree

arrow-arith/src/aggregate.rs

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
//! Defines aggregations over Arrow arrays.
1919
20-
use arrow_array::cast::{*};
20+
use arrow_array::cast::*;
2121
use arrow_array::iterator::ArrayIter;
2222
use arrow_array::*;
2323
use arrow_buffer::{ArrowNativeType, NullBuffer};
@@ -653,7 +653,7 @@ where
653653
DataType::Int16 => AnyRunArray::new(&array, DataType::Int16),
654654
_ => return Ok(None),
655655
};
656-
656+
657657
if let Some(ree) = ree {
658658
let mut sum = T::default_value();
659659

@@ -666,15 +666,14 @@ where
666666
let end = ree.run_ends_value(i);
667667
let run_length = end - prev_end;
668668
let run_length_native = T::Native::from_usize(run_length).unwrap();
669-
sum = sum.add_checked(values_data[i].mul_checked(run_length_native)?)?;
669+
sum = sum.add_checked(values_data[i].mul_checked(run_length_native)?)?;
670670
prev_end = end;
671671
}
672672

673673
Ok(Some(sum))
674674
} else {
675675
Ok(None)
676676
}
677-
678677
}
679678
_ => sum_checked::<T>(as_primitive_array(&array)),
680679
}
@@ -712,9 +711,7 @@ where
712711
{
713712
match array.data_type() {
714713
DataType::Dictionary(_, _) => min_max_helper::<T::Native, _, _>(array, cmp),
715-
DataType::RunEndEncoded(_, _) => {
716-
min_max_helper::<T::Native, _, _>(array, cmp)
717-
}
714+
DataType::RunEndEncoded(_, _) => min_max_helper::<T::Native, _, _>(array, cmp),
718715
_ => m(as_primitive_array(&array)),
719716
}
720717
}
@@ -1773,8 +1770,8 @@ mod tests {
17731770
}
17741771
mod ree_aggregation {
17751772
use super::*;
1776-
use arrow_array::{RunArray, Int32Array, Int64Array, Float64Array};
1777-
use arrow_array::types::{Int32Type, Int64Type, Float64Type};
1773+
use arrow_array::types::{Float64Type, Int32Type, Int64Type};
1774+
use arrow_array::{Float64Array, Int32Array, Int64Array, RunArray};
17781775

17791776
#[test]
17801777
fn test_ree_sum_array_basic() {
@@ -1783,7 +1780,6 @@ mod tests {
17831780
let values = Int32Array::from(vec![10, 20, 30]);
17841781
let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap();
17851782

1786-
17871783
let typed_array = run_array.downcast::<Int32Array>().unwrap();
17881784

17891785
let result = sum_array::<Int32Type, _>(typed_array);
@@ -1794,14 +1790,26 @@ mod tests {
17941790
fn test_ree_sum_array_with_nulls() {
17951791
// REE array with nulls: [10, NULL, 20, NULL, 30]
17961792
let run_ends = Int32Array::from(vec![1, 2, 3, 4, 5]);
1797-
let values = Int32Array::from(vec![10, 0, 20, 0, 30]); // 0 represents null
1793+
let values = Int32Array::from(vec![Some(10), None, Some(20), None, Some(30)]);
17981794
let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap();
17991795

18001796
let typed_array = run_array.downcast::<Int32Array>().unwrap();
18011797
let result = sum_array::<Int32Type, _>(typed_array);
18021798
assert_eq!(result, Some(60)); // 10+20+30 = 60 (nulls ignored)
18031799
}
18041800

1801+
#[test]
1802+
fn test_ree_sum_array_with_only_nulls() {
1803+
// REE array: [None, None, None, None, None] (logical length 5)
1804+
let run_ends = Int32Array::from(vec![1, 2, 3, 4, 5]);
1805+
let values = Int32Array::from(vec![None, None, None, None, None]);
1806+
let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap();
1807+
1808+
let typed_array = run_array.downcast::<Int32Array>().unwrap();
1809+
let result = sum_array::<Int32Type, _>(typed_array);
1810+
assert_eq!(result, Some(0)); // 0
1811+
}
1812+
18051813
#[test]
18061814
fn test_ree_sum_array_large_values() {
18071815
// REE array with large values: [1000, 1000, 2000, 3000, 3000]

arrow-array/src/array/mod.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,6 @@ mod run_array;
6868

6969
pub use run_array::*;
7070

71-
// Re-export the unwrap_ree_array function for public use
72-
pub use run_array::unwrap_ree_array;
73-
7471
mod byte_view_array;
7572

7673
pub use byte_view_array::*;

arrow-array/src/array/run_array.rs

Lines changed: 155 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,24 @@
1818
use std::any::Any;
1919
use std::sync::Arc;
2020

21+
use crate::{
22+
builder::StringRunBuilder,
23+
cast::AsArray,
24+
make_array,
25+
run_iterator::RunArrayIter,
26+
types::{
27+
Date32Type, Date64Type, Decimal128Type, Decimal256Type, DurationNanosecondType,
28+
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType,
29+
RunEndIndexType, Time32MillisecondType, Time64NanosecondType, TimestampMicrosecondType,
30+
UInt16Type, UInt32Type, UInt64Type, UInt8Type,
31+
},
32+
Array, ArrayAccessor, ArrayRef, ArrowPrimitiveType, BooleanArray, Int16Array, Int32Array,
33+
Int64Array, PrimitiveArray, StringArray,
34+
};
2135
use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer, RunEndBuffer};
2236
use arrow_data::{ArrayData, ArrayDataBuilder};
2337
use arrow_schema::{ArrowError, DataType, Field};
2438

25-
use crate::{
26-
builder::StringRunBuilder, cast::AsArray, make_array, run_iterator::RunArrayIter, types::{Int16Type, Int32Type, Int64Type, RunEndIndexType}, Array, ArrayAccessor, ArrayRef, ArrowPrimitiveType, PrimitiveArray, Int16Array, Int32Array, Int64Array
27-
};
28-
2939
/// An array of [run-end encoded values](https://arrow.apache.org/docs/format/Columnar.html#run-end-encoded-layout)
3040
///
3141
/// This encoding is a variation on [run-length encoding (RLE)](https://en.wikipedia.org/wiki/Run-length_encoding)
@@ -247,14 +257,16 @@ impl<R: RunEndIndexType> RunArray<R> {
247257
values: self.values.clone(),
248258
}
249259
}
260+
250261
/// Expands the REE array to its logical form
251-
pub fn expand_to_logical<T: ArrowPrimitiveType>(&self) -> Result<Box<dyn Array>, ArrowError>
262+
pub fn expand_to_logical<T: ArrowPrimitiveType>(&self) -> Result<Box<dyn Array>, ArrowError>
252263
where
253264
T::Native: Default,
254265
{
255-
let typed_ree = self.downcast::<PrimitiveArray<T>>()
256-
.ok_or_else(|| ArrowError::InvalidArgumentError("Failed to downcast to typed REE".to_string()))?;
257-
266+
let typed_ree = self.downcast::<PrimitiveArray<T>>().ok_or_else(|| {
267+
ArrowError::InvalidArgumentError("Failed to downcast to typed REE".to_string())
268+
})?;
269+
258270
let mut builder = PrimitiveArray::<T>::builder(typed_ree.len());
259271
for i in 0..typed_ree.len() {
260272
if typed_ree.is_null(i) {
@@ -265,29 +277,6 @@ impl<R: RunEndIndexType> RunArray<R> {
265277
}
266278
Ok(Box::new(builder.finish()))
267279
}
268-
/// Unwraps a REE array into a logical array
269-
pub fn unwrap_ree_array(array: &dyn Array) -> Option<Box<dyn Array>> {
270-
match array.data_type() {
271-
arrow_schema::DataType::RunEndEncoded(run_ends_field, _) => {
272-
match run_ends_field.data_type() {
273-
arrow_schema::DataType::Int16 => {
274-
array.as_run_opt::<Int16Type>()
275-
.and_then(|ree| ree.expand_to_logical::<Int16Type>().ok())
276-
}
277-
arrow_schema::DataType::Int32 => {
278-
array.as_run_opt::<Int32Type>()
279-
.and_then(|ree| ree.expand_to_logical::<Int32Type>().ok())
280-
}
281-
arrow_schema::DataType::Int64 => {
282-
array.as_run_opt::<Int64Type>()
283-
.and_then(|ree| ree.expand_to_logical::<Int64Type>().ok())
284-
}
285-
_ => None,
286-
}
287-
}
288-
_ => None,
289-
}
290-
}
291280
}
292281

293282
impl<R: RunEndIndexType> From<ArrayData> for RunArray<R> {
@@ -566,28 +555,140 @@ pub struct TypedRunArray<'a, R: RunEndIndexType, V> {
566555
}
567556

568557
/// Unwraps a REE array into a logical array
569-
pub fn unwrap_ree_array(array: &dyn Array) -> Option<Box<dyn Array>> {
558+
pub fn ree_to_expanded_array(array: &dyn Array) -> Option<Box<dyn Array>> {
570559
match array.data_type() {
571560
arrow_schema::DataType::RunEndEncoded(run_ends_field, _) => {
572561
match run_ends_field.data_type() {
573-
arrow_schema::DataType::Int16 => {
574-
array.as_run_opt::<Int16Type>()
575-
.and_then(|ree| ree.expand_to_logical::<Int16Type>().ok())
576-
}
577-
arrow_schema::DataType::Int32 => {
578-
array.as_run_opt::<Int32Type>()
579-
.and_then(|ree| ree.expand_to_logical::<Int32Type>().ok())
580-
}
581-
arrow_schema::DataType::Int64 => {
582-
array.as_run_opt::<Int64Type>()
583-
.and_then(|ree| ree.expand_to_logical::<Int64Type>().ok())
584-
}
562+
arrow_schema::DataType::Int16 => array
563+
.as_run_opt::<Int16Type>()
564+
.and_then(|ree| ree.expand_to_logical::<Int16Type>().ok()),
565+
arrow_schema::DataType::Int32 => array
566+
.as_run_opt::<Int32Type>()
567+
.and_then(|ree| ree.expand_to_logical::<Int32Type>().ok()),
568+
arrow_schema::DataType::Int64 => array
569+
.as_run_opt::<Int64Type>()
570+
.and_then(|ree| ree.expand_to_logical::<Int64Type>().ok()),
585571
_ => None,
586572
}
587573
}
588574
_ => None,
589575
}
590576
}
577+
578+
/// Generates a boolean array marking, for each logical index, whether two run arrays hold equal values (or differing values when `flag` is true)
579+
pub fn ree_distinct(
580+
lhs: &AnyRunArray,
581+
rhs: &AnyRunArray,
582+
size: usize,
583+
flag: bool,
584+
) -> Option<BooleanArray> {
585+
// Iterate through both run arrays and compare their logical indices
586+
// we know that the run arrays are of the exact same size.
587+
let lhs_vals = lhs.values();
588+
let rhs_vals = rhs.values();
589+
if lhs_vals.data_type() != rhs_vals.data_type() {
590+
return None;
591+
}
592+
match lhs_vals.data_type() {
593+
// Integer types
594+
DataType::Int8 => ree_distinct_primitive::<Int8Type>(lhs, rhs, size, flag),
595+
DataType::Int16 => ree_distinct_primitive::<Int16Type>(lhs, rhs, size, flag),
596+
DataType::Int32 => ree_distinct_primitive::<Int32Type>(lhs, rhs, size, flag),
597+
DataType::Int64 => ree_distinct_primitive::<Int64Type>(lhs, rhs, size, flag),
598+
DataType::UInt8 => ree_distinct_primitive::<UInt8Type>(lhs, rhs, size, flag),
599+
DataType::UInt16 => ree_distinct_primitive::<UInt16Type>(lhs, rhs, size, flag),
600+
DataType::UInt32 => ree_distinct_primitive::<UInt32Type>(lhs, rhs, size, flag),
601+
DataType::UInt64 => ree_distinct_primitive::<UInt64Type>(lhs, rhs, size, flag),
602+
603+
// Floating point
604+
DataType::Float32 => ree_distinct_primitive::<Float32Type>(lhs, rhs, size, flag),
605+
DataType::Float64 => ree_distinct_primitive::<Float64Type>(lhs, rhs, size, flag),
606+
607+
// Temporal
608+
DataType::Date32 => ree_distinct_primitive::<Date32Type>(lhs, rhs, size, flag),
609+
DataType::Date64 => ree_distinct_primitive::<Date64Type>(lhs, rhs, size, flag),
610+
DataType::Timestamp(_, _) => {
611+
ree_distinct_primitive::<TimestampMicrosecondType>(lhs, rhs, size, flag)
612+
}
613+
DataType::Time32(_) => {
614+
ree_distinct_primitive::<Time32MillisecondType>(lhs, rhs, size, flag)
615+
}
616+
DataType::Time64(_) => ree_distinct_primitive::<Time64NanosecondType>(lhs, rhs, size, flag),
617+
DataType::Duration(_) => {
618+
ree_distinct_primitive::<DurationNanosecondType>(lhs, rhs, size, flag)
619+
}
620+
DataType::Interval(_) => {
621+
ree_distinct_primitive::<IntervalDayTimeType>(lhs, rhs, size, flag)
622+
}
623+
624+
// Decimals
625+
DataType::Decimal128(_, _) => {
626+
ree_distinct_primitive::<Decimal128Type>(lhs, rhs, size, flag)
627+
}
628+
DataType::Decimal256(_, _) => {
629+
ree_distinct_primitive::<Decimal256Type>(lhs, rhs, size, flag)
630+
}
631+
// Strings aren't a primitive type, so we need to handle them separately
632+
DataType::Utf8 => ree_distinct_string(lhs, rhs, size, flag),
633+
634+
// Not yet supported or complex
635+
_ => None,
636+
}
637+
}
638+
639+
fn ree_distinct_primitive<T: ArrowPrimitiveType>(
640+
lhs: &AnyRunArray,
641+
rhs: &AnyRunArray,
642+
size: usize,
643+
flag: bool,
644+
) -> Option<BooleanArray> {
645+
let lhs_vals = lhs.values().as_any().downcast_ref::<PrimitiveArray<T>>()?;
646+
let rhs_vals = rhs.values().as_any().downcast_ref::<PrimitiveArray<T>>()?;
647+
let mut builder = BooleanBufferBuilder::new(size);
648+
for i in 0..size {
649+
let li = lhs.get_physical_index(i);
650+
let ri = rhs.get_physical_index(i);
651+
652+
let mut is_same = match (lhs_vals.is_null(li), rhs_vals.is_null(ri)) {
653+
(true, true) => true,
654+
(true, false) | (false, true) => false, // If one is null, result depends on flag
655+
(false, false) => lhs_vals.value(li) == rhs_vals.value(ri),
656+
};
657+
if flag {
658+
is_same = !is_same;
659+
}
660+
builder.append(is_same);
661+
}
662+
Some(BooleanArray::from(builder.finish()))
663+
}
664+
665+
fn ree_distinct_string(
666+
lhs: &AnyRunArray,
667+
rhs: &AnyRunArray,
668+
size: usize,
669+
flag: bool,
670+
) -> Option<BooleanArray> {
671+
let lhs_vals = lhs.values().as_any().downcast_ref::<StringArray>()?;
672+
let rhs_vals = rhs.values().as_any().downcast_ref::<StringArray>()?;
673+
674+
let mut builder = BooleanBufferBuilder::new(size);
675+
for i in 0..size {
676+
let li = lhs.get_physical_index(i);
677+
let ri = rhs.get_physical_index(i);
678+
679+
let mut is_same = match (lhs_vals.is_null(li), rhs_vals.is_null(ri)) {
680+
(true, true) => true,
681+
(true, false) | (false, true) => false,
682+
(false, false) => lhs_vals.value(li) == rhs_vals.value(ri),
683+
};
684+
if flag {
685+
is_same = !is_same;
686+
}
687+
builder.append(is_same);
688+
}
689+
Some(BooleanArray::from(builder.finish()))
690+
}
691+
591692
// Manually implement `Clone` to avoid `V: Clone` type constraint
592693
impl<R: RunEndIndexType, V> Clone for TypedRunArray<'_, R, V> {
593694
fn clone(&self) -> Self {
@@ -720,7 +821,6 @@ where
720821
}
721822
}
722823

723-
724824
/// An AnyRunArray is a wrapper around a RunArray that can be used to aggregate over a RunEndEncodedArray
725825
/// This is used to avoid the need to downcast the RunEndEncodedArray to a specific type
726826
pub enum AnyRunArray<'a> {
@@ -777,7 +877,7 @@ impl<'a> AnyRunArray<'a> {
777877
AnyRunArray::Int16(array) => array.run_ends().values()[i].as_usize(),
778878
}
779879
}
780-
880+
781881
/// Returns the length of run ends array
782882
pub fn run_ends_len(&self) -> usize {
783883
match self {
@@ -786,9 +886,16 @@ impl<'a> AnyRunArray<'a> {
786886
AnyRunArray::Int16(array) => array.values().len(),
787887
}
788888
}
789-
790-
}
791889

890+
/// Returns the physical index for the given logical index
891+
pub fn get_physical_index(&self, logical_index: usize) -> usize {
892+
match self {
893+
AnyRunArray::Int64(array) => array.get_physical_index(logical_index),
894+
AnyRunArray::Int32(array) => array.get_physical_index(logical_index),
895+
AnyRunArray::Int16(array) => array.get_physical_index(logical_index),
896+
}
897+
}
898+
}
792899

793900
#[cfg(test)]
794901
mod tests {

0 commit comments

Comments
 (0)