Skip to content

Commit 72bd81a

Browse files
committed
Implementing optimizations to remove unneeded RunEnd unpacking. Summation is the only function left
1 parent acb5d08 commit 72bd81a

File tree

3 files changed

+176
-154
lines changed

3 files changed

+176
-154
lines changed

arrow-arith/src/aggregate.rs

Lines changed: 102 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,12 @@
1717

1818
//! Defines aggregations over Arrow arrays.
1919
20-
use arrow_array::cast::*;
20+
use arrow_array::cast::{*};
2121
use arrow_array::iterator::ArrayIter;
2222
use arrow_array::*;
2323
use arrow_buffer::{ArrowNativeType, NullBuffer};
2424
use arrow_data::bit_iterator::try_for_each_valid_idx;
2525
use arrow_schema::*;
26-
use num::cast;
2726
use std::borrow::BorrowMut;
2827
use std::cmp::{self, Ordering};
2928
use std::ops::{BitAnd, BitOr, BitXor};
@@ -574,22 +573,33 @@ where
574573

575574
Some(sum)
576575
}
577-
DataType::RunEndEncoded(_, _) => {
576+
DataType::RunEndEncoded(run_field, _) => {
578577
let null_count = array.null_count();
579578

580579
if null_count == array.len() {
581580
return None;
582581
}
583-
584-
// Expand REE array to its logical form and recursively call sum_array
585-
if let Some(expanded_array) = arrow_array::unwrap_ree_array(&array) {
586-
// Cast the expanded array to the appropriate type and call sum_array recursively
587-
if let Some(primitive_array) = expanded_array.as_any().downcast_ref::<PrimitiveArray<T>>() {
588-
sum::<T>(primitive_array)
589-
} else {
590-
// If we can't downcast, return None
591-
None
582+
let ree = match run_field.data_type() {
583+
DataType::Int64 => AnyRunArray::new(&array, DataType::Int64),
584+
DataType::Int32 => AnyRunArray::new(&array, DataType::Int32),
585+
DataType::Int16 => AnyRunArray::new(&array, DataType::Int16),
586+
_ => return None,
587+
};
588+
if let Some(ree) = ree {
589+
let mut sum = T::default_value();
590+
591+
let values = ree.values();
592+
let values_array = values.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
593+
let values_data = values_array.values();
594+
let mut prev_end = 0;
595+
for i in 0..ree.run_ends_len() {
596+
let end = ree.run_ends_value(i);
597+
let run_length = end - prev_end;
598+
let run_length_native = T::Native::from_usize(run_length).unwrap();
599+
sum = sum.add_wrapping(values_data[i].mul_wrapping(run_length_native));
600+
prev_end = end;
592601
}
602+
Some(sum)
593603
} else {
594604
None
595605
}
@@ -630,25 +640,41 @@ where
630640

631641
Ok(Some(sum))
632642
}
633-
DataType::RunEndEncoded(_, _) => {
643+
DataType::RunEndEncoded(run_field, _) => {
634644
let null_count = array.null_count();
635645

636646
if null_count == array.len() {
637647
return Ok(None);
638648
}
639649

640-
// Expand REE array to its logical form and recursively call sum_array_checked
641-
if let Some(expanded_array) = arrow_array::unwrap_ree_array(&array) {
642-
// Cast the expanded array to the appropriate type and call sum_checked recursively
643-
if let Some(primitive_array) = expanded_array.as_any().downcast_ref::<PrimitiveArray<T>>() {
644-
sum_checked::<T>(primitive_array)
645-
} else {
646-
// If we can't downcast, return None
647-
Ok(None)
650+
let ree = match run_field.data_type() {
651+
DataType::Int64 => AnyRunArray::new(&array, DataType::Int64),
652+
DataType::Int32 => AnyRunArray::new(&array, DataType::Int32),
653+
DataType::Int16 => AnyRunArray::new(&array, DataType::Int16),
654+
_ => return Ok(None),
655+
};
656+
657+
if let Some(ree) = ree {
658+
let mut sum = T::default_value();
659+
660+
let values = ree.values();
661+
let values_array = values.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
662+
let values_data = values_array.values();
663+
664+
let mut prev_end = 0;
665+
for i in 0..ree.run_ends_len() {
666+
let end = ree.run_ends_value(i);
667+
let run_length = end - prev_end;
668+
let run_length_native = T::Native::from_usize(run_length).unwrap();
669+
sum = sum.add_checked(values_data[i].mul_checked(run_length_native)?)?;
670+
prev_end = end;
648671
}
672+
673+
Ok(Some(sum))
649674
} else {
650675
Ok(None)
651676
}
677+
652678
}
653679
_ => sum_checked::<T>(as_primitive_array(&array)),
654680
}
@@ -687,24 +713,7 @@ where
687713
match array.data_type() {
688714
DataType::Dictionary(_, _) => min_max_helper::<T::Native, _, _>(array, cmp),
689715
DataType::RunEndEncoded(_, _) => {
690-
let null_count = array.null_count();
691-
692-
if null_count == array.len() {
693-
return None;
694-
}
695-
696-
// Expand REE array to its logical form and recursively call min_max_array_helper
697-
if let Some(expanded_array) = arrow_array::unwrap_ree_array(&array) {
698-
// Cast the expanded array to the appropriate type and call min_max_helper recursively
699-
if let Some(primitive_array) = expanded_array.as_any().downcast_ref::<PrimitiveArray<T>>() {
700-
min_max_helper::<T::Native, _, _>(primitive_array, cmp)
701-
} else {
702-
// If we can't downcast, return None
703-
None
704-
}
705-
} else {
706-
None
707-
}
716+
min_max_helper::<T::Native, _, _>(array, cmp)
708717
}
709718
_ => m(as_primitive_array(&array)),
710719
}
@@ -1762,151 +1771,119 @@ mod tests {
17621771
sum_checked(&a).expect_err("overflow should be detected");
17631772
sum_array_checked::<Int32Type, _>(&a).expect_err("overflow should be detected");
17641773
}
1765-
// ... existing code ...
1766-
// REE (RunEndEncodedArray) Tests
17671774
mod ree_aggregation {
17681775
use super::*;
17691776
use arrow_array::{RunArray, Int32Array, Int64Array, Float64Array};
17701777
use arrow_array::types::{Int32Type, Int64Type, Float64Type};
17711778

17721779
#[test]
17731780
fn test_ree_sum_array_basic() {
1774-
// REE array: [10, 10, 20, 30, 30] (logical length 5)
1775-
let run_ends = Int32Array::from(vec![2, 3, 5]);
1781+
// REE array: [10, 10, 20, 30, 30, 30] (logical length 6)
1782+
let run_ends = Int32Array::from(vec![2, 3, 6]);
17761783
let values = Int32Array::from(vec![10, 20, 30]);
17771784
let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap();
17781785

1779-
// Expand to logical form and test
1780-
let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap();
1781-
let primitive_array = expanded.as_any().downcast_ref::<Int32Array>().unwrap();
1782-
let result = sum_array::<Int32Type, _>(primitive_array);
1783-
assert_eq!(result, Some(100)); // 10+10+20+30+30 = 100
1786+
1787+
let typed_array = run_array.downcast::<Int32Array>().unwrap();
1788+
1789+
let result = sum_array::<Int32Type, _>(typed_array);
1790+
assert_eq!(result, Some(130)); // 10+10+20+30+30+30 = 130
17841791
}
17851792

17861793
#[test]
17871794
fn test_ree_sum_array_with_nulls() {
17881795
// REE array with nulls: [10, NULL, 20, NULL, 30]
17891796
let run_ends = Int32Array::from(vec![1, 2, 3, 4, 5]);
1790-
let values = Int32Array::from(vec![10, -1, 20, -1, 30]); // -1 represents null
1797+
let values = Int32Array::from(vec![10, 0, 20, 0, 30]); // 0 represents null
17911798
let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap();
17921799

1793-
// Expand to logical form and test
1794-
let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap();
1795-
let primitive_array = expanded.as_any().downcast_ref::<Int32Array>().unwrap();
1796-
let result = sum_array::<Int32Type, _>(primitive_array);
1800+
let typed_array = run_array.downcast::<Int32Array>().unwrap();
1801+
let result = sum_array::<Int32Type, _>(typed_array);
17971802
assert_eq!(result, Some(60)); // 10+20+30 = 60 (nulls ignored)
17981803
}
17991804

18001805
#[test]
1801-
fn test_ree_sum_array_checked_basic() {
1806+
fn test_ree_sum_array_large_values() {
1807+
// REE array with large values: [1000, 1000, 2000, 3000, 3000]
1808+
let run_ends = Int64Array::from(vec![2, 3, 5]);
1809+
let values = Int64Array::from(vec![1000, 2000, 3000]);
1810+
let run_array = RunArray::<Int64Type>::try_new(&run_ends, &values).unwrap();
1811+
1812+
let typed_array = run_array.downcast::<Int64Array>().unwrap();
1813+
let result = sum_array::<Int64Type, _>(typed_array);
1814+
assert_eq!(result, Some(10000)); // 1000+1000+2000+3000+3000 = 10000
1815+
}
1816+
1817+
#[test]
1818+
fn test_ree_sum_checked_array_basic() {
18021819
// REE array: [5, 5, 10, 15, 15] (logical length 5)
18031820
let run_ends = Int32Array::from(vec![2, 3, 5]);
18041821
let values = Int32Array::from(vec![5, 10, 15]);
18051822
let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap();
18061823

1807-
// Expand to logical form and test
1808-
let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap();
1809-
let primitive_array = expanded.as_any().downcast_ref::<Int32Array>().unwrap();
1810-
let result = sum_array_checked::<Int32Type, _>(primitive_array).unwrap();
1811-
assert_eq!(result, Some(50)); // 5+5+10+15+15 = 50
1824+
let typed_array = run_array.downcast::<Int32Array>().unwrap();
1825+
let result = sum_array_checked::<Int32Type, _>(typed_array);
1826+
assert_eq!(result.unwrap(), Some(50)); // 5+5+10+15+15 = 50
18121827
}
18131828

18141829
#[test]
1815-
fn test_ree_sum_array_checked_overflow() {
1816-
// REE array that will overflow: [i32::MAX, i32::MAX, 1]
1830+
fn test_ree_sum_checked_array_overflow() {
1831+
// REE array that will cause overflow: [i32::MAX, i32::MAX, 1]
18171832
let run_ends = Int32Array::from(vec![2, 3]);
18181833
let values = Int32Array::from(vec![i32::MAX, 1]);
18191834
let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap();
18201835

1821-
// Expand to logical form and test
1822-
let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap();
1823-
let primitive_array = expanded.as_any().downcast_ref::<Int32Array>().unwrap();
1824-
let result = sum_array_checked::<Int32Type, _>(primitive_array);
1825-
assert!(result.is_err()); // Should overflow
1836+
let typed_array = run_array.downcast::<Int32Array>().unwrap();
1837+
let result = sum_array_checked::<Int32Type, _>(typed_array);
1838+
assert!(result.is_err()); // Should detect overflow
18261839
}
18271840

18281841
#[test]
18291842
fn test_ree_min_array_basic() {
1830-
// REE array: [50, 50, 10, 30, 30] (logical length 5)
1843+
// REE array: [30, 30, 10, 20, 20] (logical length 5)
18311844
let run_ends = Int32Array::from(vec![2, 3, 5]);
1832-
let values = Int32Array::from(vec![50, 10, 30]);
1845+
let values = Int32Array::from(vec![30, 10, 20]);
18331846
let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap();
18341847

1835-
// Expand to logical form and test
1836-
let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap();
1837-
let primitive_array = expanded.as_any().downcast_ref::<Int32Array>().unwrap();
1838-
let result = min_array::<Int32Type, _>(primitive_array);
1839-
assert_eq!(result, Some(10)); // Minimum value is 10
1848+
let typed_array = run_array.downcast::<Int32Array>().unwrap();
1849+
let result = min_array::<Int32Type, _>(typed_array);
1850+
assert_eq!(result, Some(10)); // min(30, 30, 10, 20, 20) = 10
18401851
}
18411852

18421853
#[test]
1843-
fn test_ree_min_array_with_nulls() {
1844-
// REE array with nulls: [100, NULL, 5, NULL, 200]
1845-
let run_ends = Int32Array::from(vec![1, 2, 3, 4, 5]);
1846-
let values = Int32Array::from(vec![100, -1, 5, -1, 200]); // -1 represents null
1854+
fn test_ree_min_array_float() {
1855+
// REE array with floats: [5.5, 5.5, 2.1, 8.9, 8.9]
1856+
let run_ends = Int32Array::from(vec![2, 3, 5]);
1857+
let values = Float64Array::from(vec![5.5, 2.1, 8.9]);
18471858
let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap();
18481859

1849-
// Expand to logical form and test
1850-
let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap();
1851-
let primitive_array = expanded.as_any().downcast_ref::<Int32Array>().unwrap();
1852-
let result = min_array::<Int32Type, _>(primitive_array);
1853-
assert_eq!(result, Some(5)); // Minimum non-null value is 5
1860+
let typed_array = run_array.downcast::<Float64Array>().unwrap();
1861+
let result = min_array::<Float64Type, _>(typed_array);
1862+
assert_eq!(result, Some(2.1)); // min(5.5, 5.5, 2.1, 8.9, 8.9) = 2.1
18541863
}
18551864

18561865
#[test]
18571866
fn test_ree_max_array_basic() {
1858-
// REE array: [10, 10, 50, 20, 20] (logical length 5)
1867+
// REE array: [10, 10, 30, 20, 20] (logical length 5)
18591868
let run_ends = Int32Array::from(vec![2, 3, 5]);
1860-
let values = Int32Array::from(vec![10, 50, 20]);
1869+
let values = Int32Array::from(vec![10, 30, 20]);
18611870
let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap();
18621871

1863-
// Expand to logical form and test
1864-
let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap();
1865-
let primitive_array = expanded.as_any().downcast_ref::<Int32Array>().unwrap();
1866-
let result = max_array::<Int32Type, _>(primitive_array);
1867-
assert_eq!(result, Some(50)); // Maximum value is 50
1868-
}
1869-
1870-
#[test]
1871-
fn test_ree_max_array_with_nulls() {
1872-
// REE array with nulls: [5, NULL, 500, NULL, 10]
1873-
let run_ends = Int32Array::from(vec![1, 2, 3, 4, 5]);
1874-
let values = Int32Array::from(vec![5, -1, 500, -1, 10]); // -1 represents null
1875-
let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap();
1876-
1877-
// Expand to logical form and test
1878-
let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap();
1879-
let primitive_array = expanded.as_any().downcast_ref::<Int32Array>().unwrap();
1880-
let result = max_array::<Int32Type, _>(primitive_array);
1881-
assert_eq!(result, Some(500)); // Maximum non-null value is 500
1882-
}
1883-
1884-
#[test]
1885-
fn test_ree_sum_array_large_values() {
1886-
// REE array with large values: [1000000, 1000000, 2000000, 3000000, 3000000]
1887-
let run_ends = Int64Array::from(vec![2, 3, 5]);
1888-
let values = Int64Array::from(vec![1000000, 2000000, 3000000]);
1889-
let run_array = RunArray::<Int64Type>::try_new(&run_ends, &values).unwrap();
1890-
1891-
// Expand to logical form and test
1892-
let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap();
1893-
let primitive_array = expanded.as_any().downcast_ref::<Int64Array>().unwrap();
1894-
let result = sum_array::<Int64Type, _>(primitive_array);
1895-
assert_eq!(result, Some(10000000)); // 1M+1M+2M+3M+3M = 10M
1872+
let typed_array = run_array.downcast::<Int32Array>().unwrap();
1873+
let result = max_array::<Int32Type, _>(typed_array);
1874+
assert_eq!(result, Some(30)); // max(10, 10, 30, 20, 20) = 30
18961875
}
18971876

18981877
#[test]
1899-
fn test_ree_max_array_float_values() {
1900-
// REE array with float values: [1.5, 1.5, 3.7, 2.1, 2.1]
1878+
fn test_ree_max_array_float() {
1879+
// REE array with floats: [2.1, 2.1, 8.9, 5.5, 5.5]
19011880
let run_ends = Int32Array::from(vec![2, 3, 5]);
1902-
let values = Float64Array::from(vec![1.5, 3.7, 2.1]);
1881+
let values = Float64Array::from(vec![2.1, 8.9, 5.5]);
19031882
let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap();
19041883

1905-
// Expand to logical form and test
1906-
let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap();
1907-
let primitive_array = expanded.as_any().downcast_ref::<Float64Array>().unwrap();
1908-
let result = max_array::<Float64Type, _>(primitive_array);
1909-
assert_eq!(result, Some(3.7)); // Maximum value is 3.7
1884+
let typed_array = run_array.downcast::<Float64Array>().unwrap();
1885+
let result = max_array::<Float64Type, _>(typed_array);
1886+
assert_eq!(result, Some(8.9)); // max(2.1, 2.1, 8.9, 5.5, 5.5) = 8.9
19101887
}
19111888
}
19121889
}

0 commit comments

Comments
 (0)