|
17 | 17 |
|
18 | 18 | //! Defines aggregations over Arrow arrays. |
19 | 19 |
|
20 | | -use arrow_array::cast::*; |
| 20 | +use arrow_array::cast::{*}; |
21 | 21 | use arrow_array::iterator::ArrayIter; |
22 | 22 | use arrow_array::*; |
23 | 23 | use arrow_buffer::{ArrowNativeType, NullBuffer}; |
24 | 24 | use arrow_data::bit_iterator::try_for_each_valid_idx; |
25 | 25 | use arrow_schema::*; |
26 | | -use num::cast; |
27 | 26 | use std::borrow::BorrowMut; |
28 | 27 | use std::cmp::{self, Ordering}; |
29 | 28 | use std::ops::{BitAnd, BitOr, BitXor}; |
@@ -574,22 +573,33 @@ where |
574 | 573 |
|
575 | 574 | Some(sum) |
576 | 575 | } |
577 | | - DataType::RunEndEncoded(_, _) => { |
| 576 | + DataType::RunEndEncoded(run_field, _) => { |
578 | 577 | let null_count = array.null_count(); |
579 | 578 |
|
580 | 579 | if null_count == array.len() { |
581 | 580 | return None; |
582 | 581 | } |
583 | | - |
584 | | - // Expand REE array to its logical form and recursively call sum_array |
585 | | - if let Some(expanded_array) = arrow_array::unwrap_ree_array(&array) { |
586 | | - // Cast the expanded array to the appropriate type and call sum_array recursively |
587 | | - if let Some(primitive_array) = expanded_array.as_any().downcast_ref::<PrimitiveArray<T>>() { |
588 | | - sum::<T>(primitive_array) |
589 | | - } else { |
590 | | - // If we can't downcast, return None |
591 | | - None |
| 582 | + let ree = match run_field.data_type() { |
| 583 | + DataType::Int64 => AnyRunArray::new(&array, DataType::Int64), |
| 584 | + DataType::Int32 => AnyRunArray::new(&array, DataType::Int32), |
| 585 | + DataType::Int16 => AnyRunArray::new(&array, DataType::Int16), |
| 586 | + _ => return None, |
| 587 | + }; |
| 588 | + if let Some(ree) = ree { |
| 589 | + let mut sum = T::default_value(); |
| 590 | + |
| 591 | + let values = ree.values(); |
| 592 | + let values_array = values.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(); |
| 593 | + let values_data = values_array.values(); |
| 594 | + let mut prev_end = 0; |
| 595 | + for i in 0..ree.run_ends_len() { |
| 596 | + let end = ree.run_ends_value(i); |
| 597 | + let run_length = end - prev_end; |
| 598 | + let run_length_native = T::Native::from_usize(run_length).unwrap(); |
| 599 | + sum = sum.add_wrapping(values_data[i].mul_wrapping(run_length_native)); |
| 600 | + prev_end = end; |
592 | 601 | } |
| 602 | + Some(sum) |
593 | 603 | } else { |
594 | 604 | None |
595 | 605 | } |
@@ -630,25 +640,41 @@ where |
630 | 640 |
|
631 | 641 | Ok(Some(sum)) |
632 | 642 | } |
633 | | - DataType::RunEndEncoded(_, _) => { |
| 643 | + DataType::RunEndEncoded(run_field, _) => { |
634 | 644 | let null_count = array.null_count(); |
635 | 645 |
|
636 | 646 | if null_count == array.len() { |
637 | 647 | return Ok(None); |
638 | 648 | } |
639 | 649 |
|
640 | | - // Expand REE array to its logical form and recursively call sum_array_checked |
641 | | - if let Some(expanded_array) = arrow_array::unwrap_ree_array(&array) { |
642 | | - // Cast the expanded array to the appropriate type and call sum_checked recursively |
643 | | - if let Some(primitive_array) = expanded_array.as_any().downcast_ref::<PrimitiveArray<T>>() { |
644 | | - sum_checked::<T>(primitive_array) |
645 | | - } else { |
646 | | - // If we can't downcast, return None |
647 | | - Ok(None) |
| 650 | + let ree = match run_field.data_type() { |
| 651 | + DataType::Int64 => AnyRunArray::new(&array, DataType::Int64), |
| 652 | + DataType::Int32 => AnyRunArray::new(&array, DataType::Int32), |
| 653 | + DataType::Int16 => AnyRunArray::new(&array, DataType::Int16), |
| 654 | + _ => return Ok(None), |
| 655 | + }; |
| 656 | + |
| 657 | + if let Some(ree) = ree { |
| 658 | + let mut sum = T::default_value(); |
| 659 | + |
| 660 | + let values = ree.values(); |
| 661 | + let values_array = values.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(); |
| 662 | + let values_data = values_array.values(); |
| 663 | + |
| 664 | + let mut prev_end = 0; |
| 665 | + for i in 0..ree.run_ends_len() { |
| 666 | + let end = ree.run_ends_value(i); |
| 667 | + let run_length = end - prev_end; |
| 668 | + let run_length_native = T::Native::from_usize(run_length).unwrap(); |
| 669 | + sum = sum.add_checked(values_data[i].mul_checked(run_length_native)?)?; |
| 670 | + prev_end = end; |
648 | 671 | } |
| 672 | + |
| 673 | + Ok(Some(sum)) |
649 | 674 | } else { |
650 | 675 | Ok(None) |
651 | 676 | } |
| 677 | + |
652 | 678 | } |
653 | 679 | _ => sum_checked::<T>(as_primitive_array(&array)), |
654 | 680 | } |
@@ -687,24 +713,7 @@ where |
687 | 713 | match array.data_type() { |
688 | 714 | DataType::Dictionary(_, _) => min_max_helper::<T::Native, _, _>(array, cmp), |
689 | 715 | DataType::RunEndEncoded(_, _) => { |
690 | | - let null_count = array.null_count(); |
691 | | - |
692 | | - if null_count == array.len() { |
693 | | - return None; |
694 | | - } |
695 | | - |
696 | | - // Expand REE array to its logical form and recursively call min_max_array_helper |
697 | | - if let Some(expanded_array) = arrow_array::unwrap_ree_array(&array) { |
698 | | - // Cast the expanded array to the appropriate type and call min_max_helper recursively |
699 | | - if let Some(primitive_array) = expanded_array.as_any().downcast_ref::<PrimitiveArray<T>>() { |
700 | | - min_max_helper::<T::Native, _, _>(primitive_array, cmp) |
701 | | - } else { |
702 | | - // If we can't downcast, return None |
703 | | - None |
704 | | - } |
705 | | - } else { |
706 | | - None |
707 | | - } |
| 716 | + min_max_helper::<T::Native, _, _>(array, cmp) |
708 | 717 | } |
709 | 718 | _ => m(as_primitive_array(&array)), |
710 | 719 | } |
@@ -1762,151 +1771,119 @@ mod tests { |
1762 | 1771 | sum_checked(&a).expect_err("overflow should be detected"); |
1763 | 1772 | sum_array_checked::<Int32Type, _>(&a).expect_err("overflow should be detected"); |
1764 | 1773 | } |
1765 | | - // ... existing code ... |
1766 | | - // REE (RunEndEncodedArray) Tests |
1767 | 1774 | mod ree_aggregation { |
1768 | 1775 | use super::*; |
1769 | 1776 | use arrow_array::{RunArray, Int32Array, Int64Array, Float64Array}; |
1770 | 1777 | use arrow_array::types::{Int32Type, Int64Type, Float64Type}; |
1771 | 1778 |
|
1772 | 1779 | #[test] |
1773 | 1780 | fn test_ree_sum_array_basic() { |
1774 | | - // REE array: [10, 10, 20, 30, 30] (logical length 5) |
1775 | | - let run_ends = Int32Array::from(vec![2, 3, 5]); |
| 1781 | + // REE array: [10, 10, 20, 30, 30,30] (logical length 6) |
| 1782 | + let run_ends = Int32Array::from(vec![2, 3, 6]); |
1776 | 1783 | let values = Int32Array::from(vec![10, 20, 30]); |
1777 | 1784 | let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap(); |
1778 | 1785 |
|
1779 | | - // Expand to logical form and test |
1780 | | - let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); |
1781 | | - let primitive_array = expanded.as_any().downcast_ref::<Int32Array>().unwrap(); |
1782 | | - let result = sum_array::<Int32Type, _>(primitive_array); |
1783 | | - assert_eq!(result, Some(100)); // 10+10+20+30+30 = 100 |
| 1786 | + |
| 1787 | + let typed_array = run_array.downcast::<Int32Array>().unwrap(); |
| 1788 | + |
| 1789 | + let result = sum_array::<Int32Type, _>(typed_array); |
| 1790 | + assert_eq!(result, Some(130)); // 10+10+20+30+30+30 = 130 |
1784 | 1791 | } |
1785 | 1792 |
|
1786 | 1793 | #[test] |
1787 | 1794 | fn test_ree_sum_array_with_nulls() { |
1788 | 1795 | // REE array with nulls: [10, NULL, 20, NULL, 30] |
1789 | 1796 | let run_ends = Int32Array::from(vec![1, 2, 3, 4, 5]); |
1790 | | - let values = Int32Array::from(vec![10, -1, 20, -1, 30]); // -1 represents null |
| 1797 | + let values = Int32Array::from(vec![10, 0, 20, 0, 30]); // 0 represents null |
1791 | 1798 | let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap(); |
1792 | 1799 |
|
1793 | | - // Expand to logical form and test |
1794 | | - let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); |
1795 | | - let primitive_array = expanded.as_any().downcast_ref::<Int32Array>().unwrap(); |
1796 | | - let result = sum_array::<Int32Type, _>(primitive_array); |
| 1800 | + let typed_array = run_array.downcast::<Int32Array>().unwrap(); |
| 1801 | + let result = sum_array::<Int32Type, _>(typed_array); |
1797 | 1802 | assert_eq!(result, Some(60)); // 10+20+30 = 60 (nulls ignored) |
1798 | 1803 | } |
1799 | 1804 |
|
1800 | 1805 | #[test] |
1801 | | - fn test_ree_sum_array_checked_basic() { |
| 1806 | + fn test_ree_sum_array_large_values() { |
| 1807 | + // REE array with large values: [1000, 1000, 2000, 3000, 3000] |
| 1808 | + let run_ends = Int64Array::from(vec![2, 3, 5]); |
| 1809 | + let values = Int64Array::from(vec![1000, 2000, 3000]); |
| 1810 | + let run_array = RunArray::<Int64Type>::try_new(&run_ends, &values).unwrap(); |
| 1811 | + |
| 1812 | + let typed_array = run_array.downcast::<Int64Array>().unwrap(); |
| 1813 | + let result = sum_array::<Int64Type, _>(typed_array); |
| 1814 | + assert_eq!(result, Some(10000)); // 1000+1000+2000+3000+3000 = 10000 |
| 1815 | + } |
| 1816 | + |
| 1817 | + #[test] |
| 1818 | + fn test_ree_sum_checked_array_basic() { |
1802 | 1819 | // REE array: [5, 5, 10, 15, 15] (logical length 5) |
1803 | 1820 | let run_ends = Int32Array::from(vec![2, 3, 5]); |
1804 | 1821 | let values = Int32Array::from(vec![5, 10, 15]); |
1805 | 1822 | let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap(); |
1806 | 1823 |
|
1807 | | - // Expand to logical form and test |
1808 | | - let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); |
1809 | | - let primitive_array = expanded.as_any().downcast_ref::<Int32Array>().unwrap(); |
1810 | | - let result = sum_array_checked::<Int32Type, _>(primitive_array).unwrap(); |
1811 | | - assert_eq!(result, Some(50)); // 5+5+10+15+15 = 50 |
| 1824 | + let typed_array = run_array.downcast::<Int32Array>().unwrap(); |
| 1825 | + let result = sum_array_checked::<Int32Type, _>(typed_array); |
| 1826 | + assert_eq!(result.unwrap(), Some(50)); // 5+5+10+15+15 = 50 |
1812 | 1827 | } |
1813 | 1828 |
|
1814 | 1829 | #[test] |
1815 | | - fn test_ree_sum_array_checked_overflow() { |
1816 | | - // REE array that will overflow: [i32::MAX, i32::MAX, 1] |
| 1830 | + fn test_ree_sum_checked_array_overflow() { |
| 1831 | + // REE array that will cause overflow: [i32::MAX, i32::MAX, 1] |
1817 | 1832 | let run_ends = Int32Array::from(vec![2, 3]); |
1818 | 1833 | let values = Int32Array::from(vec![i32::MAX, 1]); |
1819 | 1834 | let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap(); |
1820 | 1835 |
|
1821 | | - // Expand to logical form and test |
1822 | | - let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); |
1823 | | - let primitive_array = expanded.as_any().downcast_ref::<Int32Array>().unwrap(); |
1824 | | - let result = sum_array_checked::<Int32Type, _>(primitive_array); |
1825 | | - assert!(result.is_err()); // Should overflow |
| 1836 | + let typed_array = run_array.downcast::<Int32Array>().unwrap(); |
| 1837 | + let result = sum_array_checked::<Int32Type, _>(typed_array); |
| 1838 | + assert!(result.is_err()); // Should detect overflow |
1826 | 1839 | } |
1827 | 1840 |
|
1828 | 1841 | #[test] |
1829 | 1842 | fn test_ree_min_array_basic() { |
1830 | | - // REE array: [50, 50, 10, 30, 30] (logical length 5) |
| 1843 | + // REE array: [30, 30, 10, 20, 20] (logical length 5) |
1831 | 1844 | let run_ends = Int32Array::from(vec![2, 3, 5]); |
1832 | | - let values = Int32Array::from(vec![50, 10, 30]); |
| 1845 | + let values = Int32Array::from(vec![30, 10, 20]); |
1833 | 1846 | let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap(); |
1834 | 1847 |
|
1835 | | - // Expand to logical form and test |
1836 | | - let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); |
1837 | | - let primitive_array = expanded.as_any().downcast_ref::<Int32Array>().unwrap(); |
1838 | | - let result = min_array::<Int32Type, _>(primitive_array); |
1839 | | - assert_eq!(result, Some(10)); // Minimum value is 10 |
| 1848 | + let typed_array = run_array.downcast::<Int32Array>().unwrap(); |
| 1849 | + let result = min_array::<Int32Type, _>(typed_array); |
| 1850 | + assert_eq!(result, Some(10)); // min(30, 30, 10, 20, 20) = 10 |
1840 | 1851 | } |
1841 | 1852 |
|
1842 | 1853 | #[test] |
1843 | | - fn test_ree_min_array_with_nulls() { |
1844 | | - // REE array with nulls: [100, NULL, 5, NULL, 200] |
1845 | | - let run_ends = Int32Array::from(vec![1, 2, 3, 4, 5]); |
1846 | | - let values = Int32Array::from(vec![100, -1, 5, -1, 200]); // -1 represents null |
| 1854 | + fn test_ree_min_array_float() { |
| 1855 | + // REE array with floats: [5.5, 5.5, 2.1, 8.9, 8.9] |
| 1856 | + let run_ends = Int32Array::from(vec![2, 3, 5]); |
| 1857 | + let values = Float64Array::from(vec![5.5, 2.1, 8.9]); |
1847 | 1858 | let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap(); |
1848 | 1859 |
|
1849 | | - // Expand to logical form and test |
1850 | | - let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); |
1851 | | - let primitive_array = expanded.as_any().downcast_ref::<Int32Array>().unwrap(); |
1852 | | - let result = min_array::<Int32Type, _>(primitive_array); |
1853 | | - assert_eq!(result, Some(5)); // Minimum non-null value is 5 |
| 1860 | + let typed_array = run_array.downcast::<Float64Array>().unwrap(); |
| 1861 | + let result = min_array::<Float64Type, _>(typed_array); |
| 1862 | + assert_eq!(result, Some(2.1)); // min(5.5, 5.5, 2.1, 8.9, 8.9) = 2.1 |
1854 | 1863 | } |
1855 | 1864 |
|
1856 | 1865 | #[test] |
1857 | 1866 | fn test_ree_max_array_basic() { |
1858 | | - // REE array: [10, 10, 50, 20, 20] (logical length 5) |
| 1867 | + // REE array: [10, 10, 30, 20, 20] (logical length 5) |
1859 | 1868 | let run_ends = Int32Array::from(vec![2, 3, 5]); |
1860 | | - let values = Int32Array::from(vec![10, 50, 20]); |
| 1869 | + let values = Int32Array::from(vec![10, 30, 20]); |
1861 | 1870 | let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap(); |
1862 | 1871 |
|
1863 | | - // Expand to logical form and test |
1864 | | - let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); |
1865 | | - let primitive_array = expanded.as_any().downcast_ref::<Int32Array>().unwrap(); |
1866 | | - let result = max_array::<Int32Type, _>(primitive_array); |
1867 | | - assert_eq!(result, Some(50)); // Maximum value is 50 |
1868 | | - } |
1869 | | - |
1870 | | - #[test] |
1871 | | - fn test_ree_max_array_with_nulls() { |
1872 | | - // REE array with nulls: [5, NULL, 500, NULL, 10] |
1873 | | - let run_ends = Int32Array::from(vec![1, 2, 3, 4, 5]); |
1874 | | - let values = Int32Array::from(vec![5, -1, 500, -1, 10]); // -1 represents null |
1875 | | - let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap(); |
1876 | | - |
1877 | | - // Expand to logical form and test |
1878 | | - let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); |
1879 | | - let primitive_array = expanded.as_any().downcast_ref::<Int32Array>().unwrap(); |
1880 | | - let result = max_array::<Int32Type, _>(primitive_array); |
1881 | | - assert_eq!(result, Some(500)); // Maximum non-null value is 500 |
1882 | | - } |
1883 | | - |
1884 | | - #[test] |
1885 | | - fn test_ree_sum_array_large_values() { |
1886 | | - // REE array with large values: [1000000, 1000000, 2000000, 3000000, 3000000] |
1887 | | - let run_ends = Int64Array::from(vec![2, 3, 5]); |
1888 | | - let values = Int64Array::from(vec![1000000, 2000000, 3000000]); |
1889 | | - let run_array = RunArray::<Int64Type>::try_new(&run_ends, &values).unwrap(); |
1890 | | - |
1891 | | - // Expand to logical form and test |
1892 | | - let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); |
1893 | | - let primitive_array = expanded.as_any().downcast_ref::<Int64Array>().unwrap(); |
1894 | | - let result = sum_array::<Int64Type, _>(primitive_array); |
1895 | | - assert_eq!(result, Some(10000000)); // 1M+1M+2M+3M+3M = 10M |
| 1872 | + let typed_array = run_array.downcast::<Int32Array>().unwrap(); |
| 1873 | + let result = max_array::<Int32Type, _>(typed_array); |
| 1874 | + assert_eq!(result, Some(30)); // max(10, 10, 30, 20, 20) = 30 |
1896 | 1875 | } |
1897 | 1876 |
|
1898 | 1877 | #[test] |
1899 | | - fn test_ree_max_array_float_values() { |
1900 | | - // REE array with float values: [1.5, 1.5, 3.7, 2.1, 2.1] |
| 1878 | + fn test_ree_max_array_float() { |
| 1879 | + // REE array with floats: [2.1, 2.1, 8.9, 5.5, 5.5] |
1901 | 1880 | let run_ends = Int32Array::from(vec![2, 3, 5]); |
1902 | | - let values = Float64Array::from(vec![1.5, 3.7, 2.1]); |
| 1881 | + let values = Float64Array::from(vec![2.1, 8.9, 5.5]); |
1903 | 1882 | let run_array = RunArray::<Int32Type>::try_new(&run_ends, &values).unwrap(); |
1904 | 1883 |
|
1905 | | - // Expand to logical form and test |
1906 | | - let expanded = arrow_array::unwrap_ree_array(&run_array).unwrap(); |
1907 | | - let primitive_array = expanded.as_any().downcast_ref::<Float64Array>().unwrap(); |
1908 | | - let result = max_array::<Float64Type, _>(primitive_array); |
1909 | | - assert_eq!(result, Some(3.7)); // Maximum value is 3.7 |
| 1884 | + let typed_array = run_array.downcast::<Float64Array>().unwrap(); |
| 1885 | + let result = max_array::<Float64Type, _>(typed_array); |
| 1886 | + assert_eq!(result, Some(8.9)); // max(2.1, 2.1, 8.9, 5.5, 5.5) = 8.9 |
1910 | 1887 | } |
1911 | 1888 | } |
1912 | 1889 | } |
0 commit comments