diff --git a/arrow-array/src/iterator.rs b/arrow-array/src/iterator.rs index 6708da3d5dd6..f96d0158768e 100644 --- a/arrow-array/src/iterator.rs +++ b/arrow-array/src/iterator.rs @@ -44,7 +44,7 @@ use arrow_buffer::NullBuffer; /// [`PrimitiveArray`]: crate::PrimitiveArray /// [`compute::unary`]: https://docs.rs/arrow/latest/arrow/compute/fn.unary.html /// [`compute::try_unary`]: https://docs.rs/arrow/latest/arrow/compute/fn.try_unary.html -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ArrayIter { array: T, logical_nulls: Option, @@ -98,8 +98,8 @@ impl Iterator for ArrayIter { fn size_hint(&self) -> (usize, Option) { ( - self.array.len() - self.current, - Some(self.array.len() - self.current), + self.current_end - self.current, + Some(self.current_end - self.current), ) } } @@ -147,9 +147,14 @@ pub type MapArrayIter<'a> = ArrayIter<&'a MapArray>; pub type GenericListViewArrayIter<'a, O> = ArrayIter<&'a GenericListViewArray>; #[cfg(test)] mod tests { - use std::sync::Arc; - use crate::array::{ArrayRef, BinaryArray, BooleanArray, Int32Array, StringArray}; + use crate::iterator::ArrayIter; + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; + use std::fmt::Debug; + use std::iter::Copied; + use std::slice::Iter; + use std::sync::Arc; #[test] fn test_primitive_array_iter_round_trip() { @@ -264,4 +269,924 @@ mod tests { // check if ExactSizeIterator is implemented let _ = array.iter().rposition(|opt_b| opt_b == Some(true)); } + + trait SharedBetweenArrayIterAndSliceIter: + ExactSizeIterator> + DoubleEndedIterator> + Clone + { + } + impl> + DoubleEndedIterator>> + SharedBetweenArrayIterAndSliceIter for T + { + } + + fn get_int32_iterator_cases() -> impl Iterator>)> { + let mut rng = StdRng::seed_from_u64(42); + + let no_nulls_and_no_duplicates = (0..10).map(Some).collect::>>(); + let no_nulls_random_values = (0..10) + .map(|_| rng.random::()) + .map(Some) + .collect::>>(); + + let all_nulls = (0..10).map(|_| None).collect::>>(); + let only_start_nulls = (0..10) + .map(|item| if item < 4 { None } else { Some(item) }) + .collect::>>(); + let only_end_nulls = (0..10) + .map(|item| if item > 8 { None } else { Some(item) }) + .collect::>>(); + let only_middle_nulls = (0..10) + .map(|item| { + if (4..=8).contains(&item) && rng.random_bool(0.9) { + None + } else { + Some(item) + } + }) + .collect::>>(); + let random_values_with_random_nulls = (0..10) + .map(|_| { + if rng.random_bool(0.3) { + None + } else { + Some(rng.random::()) + } + }) + .collect::>>(); + + let no_nulls_and_some_duplicates = (0..10) + .map(|item| item % 3) + .map(Some) + .collect::>>(); + let no_nulls_and_all_same_value = + (0..10).map(|_| 9).map(Some).collect::>>(); + let no_nulls_and_continues_duplicates = [0, 0, 0, 1, 1, 2, 2, 2, 2, 3] + .map(Some) + .into_iter() + .collect::>>(); + + let single_null_and_no_duplicates = (0..10) + .map(|item| if item == 4 { None } else { Some(item) }) + .collect::>>(); + let multiple_nulls_and_no_duplicates = (0..10) + .map(|item| if item % 3 == 2 { None } else { Some(item) }) + .collect::>>(); + let continues_nulls_and_no_duplicates = [ + Some(0), + Some(1), + None, + None, + Some(2), + Some(3), + None, + Some(4), + Some(5), + None, + ] + .into_iter() + .collect::>>(); + + [ + no_nulls_and_no_duplicates, + no_nulls_random_values, + no_nulls_and_some_duplicates, + no_nulls_and_all_same_value, + no_nulls_and_continues_duplicates, + all_nulls, + only_start_nulls, + only_end_nulls, + only_middle_nulls, + random_values_with_random_nulls, + single_null_and_no_duplicates, + multiple_nulls_and_no_duplicates, + continues_nulls_and_no_duplicates, + ] + .map(|case| (Int32Array::from(case.clone()), case)) + .into_iter() + } + + trait SetupIter { + fn setup(&self, iter: &mut I); + } + + struct NoSetup; + impl SetupIter for NoSetup { + fn setup(&self, _iter: &mut I) { + // none + } + } + + fn setup_and_assert_cases( + setup_iterator: impl SetupIter, + assert_fn: impl Fn(ArrayIter<&Int32Array>, Copied>>), + ) { + for (array, source) in get_int32_iterator_cases() { + let mut actual = ArrayIter::new(&array); + let mut expected = source.iter().copied(); + + setup_iterator.setup(&mut actual); + setup_iterator.setup(&mut expected); + + assert_fn(actual, expected); + } + } + + /// Trait representing an operation on a [`ArrayIter`] + /// that can be compared against a slice iterator + /// + /// this is for consuming operations (e.g. `count`, `last`, etc) + trait ConsumingArrayIteratorOp { + /// What the operation returns (e.g. Option for last, usize for count, etc) + type Output: PartialEq + Debug; + + /// The name of the operation, used for error messages + fn name(&self) -> String; + + /// Get the value of the operation for the provided iterator + /// This will be either a [`ArrayIter`] or a slice iterator to make sure they produce the same result + fn get_value(&self, iter: T) -> Self::Output; + } + + /// Trait representing an operation on a [`ArrayIter`] + /// that can be compared against a slice iterator. + /// + /// This is for mutating operations (e.g. `position`, `any`, `find`, etc) + trait MutatingArrayIteratorOp { + /// What the operation returns (e.g. Option for last, usize for count, etc) + type Output: PartialEq + Debug; + + /// The name of the operation, used for error messages + fn name(&self) -> String; + + /// Get the value of the operation for the provided iterator + /// This will be either a [`ArrayIter`] or a slice iterator to make sure they produce the same result + fn get_value(&self, iter: &mut T) -> Self::Output; + } + + /// Helper function that will assert that the provided operation + /// produces the same result for both [`ArrayIter`] and slice iterator + /// under various consumption patterns (e.g. some calls to next/next_back/consume_all/etc) + fn assert_array_iterator_cases(o: O) { + setup_and_assert_cases(NoSetup, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for new iter (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct Next; + impl SetupIter for Next { + fn setup(&self, iter: &mut I) { + iter.next(); + } + } + setup_and_assert_cases(Next, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for new iter after consuming 1 element from the start (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct NextBack; + impl SetupIter for NextBack { + fn setup(&self, iter: &mut I) { + iter.next_back(); + } + } + + setup_and_assert_cases(NextBack, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for new iter after consuming 1 element from the end (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct NextAndBack; + impl SetupIter for NextAndBack { + fn setup(&self, iter: &mut I) { + iter.next(); + iter.next_back(); + } + } + + setup_and_assert_cases(NextAndBack, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for new iter after consuming 1 element from start and end (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct NextUntilLast; + impl SetupIter for NextUntilLast { + fn setup(&self, iter: &mut I) { + let len = iter.len(); + if len > 1 { + iter.nth(len - 2); + } + } + } + setup_and_assert_cases(NextUntilLast, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for new iter after consuming all from the start but 1 (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct NextBackUntilFirst; + impl SetupIter for NextBackUntilFirst { + fn setup(&self, iter: &mut I) { + let len = iter.len(); + if len > 1 { + iter.nth_back(len - 2); + } + } + } + setup_and_assert_cases(NextBackUntilFirst, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for new iter after consuming all from the end but 1 (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct NextFinish; + impl SetupIter for NextFinish { + fn setup(&self, iter: &mut I) { + iter.nth(iter.len()); + } + } + setup_and_assert_cases(NextFinish, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for new iter after consuming all from the start (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct NextBackFinish; + impl SetupIter for NextBackFinish { + fn setup(&self, iter: &mut I) { + iter.nth_back(iter.len()); + } + } + setup_and_assert_cases(NextBackFinish, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for new iter after consuming all from the end (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct NextUntilLastNone; + impl SetupIter for NextUntilLastNone { + fn setup(&self, iter: &mut I) { + let last_null_position = iter.clone().rposition(|item| item.is_none()); + + // move the iterator to the location where there are no nulls anymore + if let Some(last_null_position) = last_null_position { + iter.nth(last_null_position); + } + } + } + setup_and_assert_cases(NextUntilLastNone, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for iter that have no nulls left (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct NextUntilLastSome; + impl SetupIter for NextUntilLastSome { + fn setup(&self, iter: &mut I) { + let last_some_position = iter.clone().rposition(|item| item.is_some()); + + // move the iterator to the location where there are only nulls + if let Some(last_some_position) = last_some_position { + iter.nth(last_some_position); + } + } + } + setup_and_assert_cases(NextUntilLastSome, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for iter that only have nulls left (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + } + + /// Helper function that will assert that the provided operation + /// produces the same result for both [`ArrayIter`] and slice iterator + /// under various consumption patterns (e.g. some calls to next/next_back/consume_all/etc) + /// + /// this is different from [`assert_array_iterator_cases`] as this also check that the state after the call is correct + /// to make sure we don't leave the iterator in incorrect state + fn assert_array_iterator_cases_mutate(o: O) { + struct Adapter { + o: O, + } + + #[derive(Debug, PartialEq)] + struct AdapterOutput { + value: Value, + /// collect on the iterator after running the operation + leftover: Vec>, + } + + impl ConsumingArrayIteratorOp for Adapter { + type Output = AdapterOutput; + + fn name(&self) -> String { + self.o.name() + } + + fn get_value( + &self, + mut iter: T, + ) -> Self::Output { + let value = self.o.get_value(&mut iter); + + let leftover: Vec<_> = iter.collect(); + + AdapterOutput { value, leftover } + } + } + + assert_array_iterator_cases(Adapter { o }) + } + + #[derive(Debug, PartialEq)] + struct CallTrackingAndResult { + result: Result, + calls: Vec, + } + type CallTrackingWithInputType = CallTrackingAndResult>; + type CallTrackingOnly = CallTrackingWithInputType<()>; + + #[test] + fn assert_position() { + struct PositionOp { + reverse: bool, + number_of_false: usize, + } + + impl MutatingArrayIteratorOp for PositionOp { + type Output = CallTrackingWithInputType>; + fn name(&self) -> String { + if self.reverse { + format!("rposition with {} false returned", self.number_of_false) + } else { + format!("position with {} false returned", self.number_of_false) + } + } + fn get_value( + &self, + iter: &mut T, + ) -> Self::Output { + let mut items = vec![]; + + let mut count = 0; + + let position_result = if self.reverse { + iter.rposition(|item| { + items.push(item); + + if count < self.number_of_false { + count += 1; + false + } else { + true + } + }) + } else { + iter.position(|item| { + items.push(item); + + if count < self.number_of_false { + count += 1; + false + } else { + true + } + }) + }; + + CallTrackingAndResult { + result: position_result, + calls: items, + } + } + } + + for reverse in [false, true] { + for number_of_false in [0, 1, 2, usize::MAX] { + assert_array_iterator_cases_mutate(PositionOp { + reverse, + number_of_false, + }); + } + } + } + + #[test] + fn assert_nth() { + setup_and_assert_cases(NoSetup, |actual, expected| { + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + #[allow(clippy::iter_nth_zero)] + let actual_val = actual.nth(0); + #[allow(clippy::iter_nth_zero)] + let expected_val = expected.nth(0); + assert_eq!(actual_val, expected_val, "Failed on nth(0)"); + } + } + + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth(1); + let expected_val = expected.nth(1); + assert_eq!(actual_val, expected_val, "Failed on nth(1)"); + } + } + + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth(2); + let expected_val = expected.nth(2); + assert_eq!(actual_val, expected_val, "Failed on nth(2)"); + } + } + }); + } + + #[test] + fn assert_nth_back() { + setup_and_assert_cases(NoSetup, |actual, expected| { + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + #[allow(clippy::iter_nth_zero)] + let actual_val = actual.nth_back(0); + #[allow(clippy::iter_nth_zero)] + let expected_val = expected.nth_back(0); + assert_eq!(actual_val, expected_val, "Failed on nth_back(0)"); + } + } + + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth_back(1); + let expected_val = expected.nth_back(1); + assert_eq!(actual_val, expected_val, "Failed on nth_back(1)"); + } + } + + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth_back(2); + let expected_val = expected.nth_back(2); + assert_eq!(actual_val, expected_val, "Failed on nth_back(2)"); + } + } + }); + } + + #[test] + fn assert_last() { + for (array, source) in get_int32_iterator_cases() { + let mut actual_forward = ArrayIter::new(&array); + let mut expected_forward = source.iter().copied(); + + for _ in 0..source.len() + 1 { + { + let actual_forward_clone = actual_forward.clone(); + let expected_forward_clone = expected_forward.clone(); + + assert_eq!(actual_forward_clone.last(), expected_forward_clone.last()); + } + + actual_forward.next(); + expected_forward.next(); + } + + let mut actual_backward = ArrayIter::new(&array); + let mut expected_backward = source.iter().copied(); + for _ in 0..source.len() + 1 { + { + assert_eq!( + actual_backward.clone().last(), + expected_backward.clone().last() + ); + } + + actual_backward.next_back(); + expected_backward.next_back(); + } + } + } + + #[test] + fn assert_for_each() { + struct ForEachOp; + + impl ConsumingArrayIteratorOp for ForEachOp { + type Output = CallTrackingOnly; + + fn name(&self) -> String { + "for_each".to_string() + } + + fn get_value(&self, iter: T) -> Self::Output { + let mut items = Vec::with_capacity(iter.len()); + + iter.for_each(|item| { + items.push(item); + }); + + CallTrackingAndResult { + calls: items, + result: (), + } + } + } + + assert_array_iterator_cases(ForEachOp) + } + + #[test] + fn assert_fold() { + struct FoldOp { + reverse: bool, + } + + #[derive(Debug, PartialEq)] + struct CallArgs { + acc: Option, + item: Option, + } + + impl ConsumingArrayIteratorOp for FoldOp { + type Output = CallTrackingAndResult, CallArgs>; + + fn name(&self) -> String { + if self.reverse { + "rfold".to_string() + } else { + "fold".to_string() + } + } + + fn get_value(&self, iter: T) -> Self::Output { + let mut items = Vec::with_capacity(iter.len()); + + let result = if self.reverse { + iter.rfold(Some(1), |acc, item| { + items.push(CallArgs { item, acc }); + + item.map(|val| val + 100) + }) + } else { + #[allow(clippy::manual_try_fold)] + iter.fold(Some(1), |acc, item| { + items.push(CallArgs { item, acc }); + + item.map(|val| val + 100) + }) + }; + + CallTrackingAndResult { + calls: items, + result, + } + } + } + + assert_array_iterator_cases(FoldOp { reverse: false }); + assert_array_iterator_cases(FoldOp { reverse: true }); + } + + #[test] + fn assert_count() { + struct CountOp; + + impl ConsumingArrayIteratorOp for CountOp { + type Output = usize; + + fn name(&self) -> String { + "count".to_string() + } + + fn get_value(&self, iter: T) -> Self::Output { + iter.count() + } + } + + assert_array_iterator_cases(CountOp) + } + + #[test] + fn assert_any() { + struct AnyOp { + false_count: usize, + } + + impl MutatingArrayIteratorOp for AnyOp { + type Output = CallTrackingWithInputType; + + fn name(&self) -> String { + format!("any with {} false returned", self.false_count) + } + + fn get_value( + &self, + iter: &mut T, + ) -> Self::Output { + let mut items = Vec::with_capacity(iter.len()); + + let mut count = 0; + let res = iter.any(|item| { + items.push(item); + + if count < self.false_count { + count += 1; + false + } else { + true + } + }); + + CallTrackingWithInputType { + calls: items, + result: res, + } + } + } + + for false_count in [0, 1, 2, usize::MAX] { + assert_array_iterator_cases_mutate(AnyOp { false_count }); + } + } + + #[test] + fn assert_all() { + struct AllOp { + true_count: usize, + } + + impl MutatingArrayIteratorOp for AllOp { + type Output = CallTrackingWithInputType; + + fn name(&self) -> String { + format!("all with {} false returned", self.true_count) + } + + fn get_value( + &self, + iter: &mut T, + ) -> Self::Output { + let mut items = Vec::with_capacity(iter.len()); + + let mut count = 0; + let res = iter.all(|item| { + items.push(item); + + if count < self.true_count { + count += 1; + true + } else { + false + } + }); + + CallTrackingWithInputType { + calls: items, + result: res, + } + } + } + + for true_count in [0, 1, 2, usize::MAX] { + assert_array_iterator_cases_mutate(AllOp { true_count }); + } + } + + #[test] + fn assert_find() { + struct FindOp { + reverse: bool, + false_count: usize, + } + + impl MutatingArrayIteratorOp for FindOp { + type Output = CallTrackingWithInputType>>; + + fn name(&self) -> String { + if self.reverse { + format!("rfind with {} false returned", self.false_count) + } else { + format!("find with {} false returned", self.false_count) + } + } + + fn get_value( + &self, + iter: &mut T, + ) -> Self::Output { + let mut items = vec![]; + + let mut count = 0; + + let position_result = if self.reverse { + iter.rfind(|item| { + items.push(*item); + + if count < self.false_count { + count += 1; + false + } else { + true + } + }) + } else { + iter.find(|item| { + items.push(*item); + + if count < self.false_count { + count += 1; + false + } else { + true + } + }) + }; + + CallTrackingWithInputType { + calls: items, + result: position_result, + } + } + } + + for reverse in [false, true] { + for false_count in [0, 1, 2, usize::MAX] { + assert_array_iterator_cases_mutate(FindOp { + reverse, + false_count, + }); + } + } + } + + #[test] + fn assert_find_map() { + struct FindMapOp { + number_of_nones: usize, + } + + impl MutatingArrayIteratorOp for FindMapOp { + type Output = CallTrackingWithInputType>; + + fn name(&self) -> String { + format!("find_map with {} None returned", self.number_of_nones) + } + + fn get_value( + &self, + iter: &mut T, + ) -> Self::Output { + let mut items = vec![]; + + let mut count = 0; + + let result = iter.find_map(|item| { + items.push(item); + + if count < self.number_of_nones { + count += 1; + None + } else { + Some("found it") + } + }); + + CallTrackingAndResult { + result, + calls: items, + } + } + } + + for number_of_nones in [0, 1, 2, usize::MAX] { + assert_array_iterator_cases_mutate(FindMapOp { number_of_nones }); + } + } + + #[test] + fn assert_partition() { + struct PartitionOp) -> bool> { + description: &'static str, + predicate: F, + } + + #[derive(Debug, PartialEq)] + struct PartitionResult { + left: Vec>, + right: Vec>, + } + + impl) -> bool> ConsumingArrayIteratorOp for PartitionOp { + type Output = CallTrackingWithInputType; + + fn name(&self) -> String { + format!("partition by {}", self.description) + } + + fn get_value(&self, iter: T) -> Self::Output { + let mut items = vec![]; + + let mut index = 0; + + let (left, right) = iter.partition(|item| { + items.push(*item); + + let res = (self.predicate)(index, item); + + index += 1; + res + }); + + CallTrackingAndResult { + result: PartitionResult { left, right }, + calls: items, + } + } + } + + assert_array_iterator_cases(PartitionOp { + description: "None on one side and Some(*) on the other", + predicate: |_, item| item.is_none(), + }); + + assert_array_iterator_cases(PartitionOp { + description: "all true", + predicate: |_, _| true, + }); + + assert_array_iterator_cases(PartitionOp { + description: "all false", + predicate: |_, _| false, + }); + + let random_values = (0..100).map(|_| rand::random_bool(0.5)).collect::>(); + assert_array_iterator_cases(PartitionOp { + description: "random", + predicate: |index, _| random_values[index % random_values.len()], + }); + } } diff --git a/arrow-avro/README.md b/arrow-avro/README.md index f89fc97d242f..85fd76094755 100644 --- a/arrow-avro/README.md +++ b/arrow-avro/README.md @@ -44,14 +44,14 @@ This crate provides: ```toml [dependencies] -arrow-avro = "56" +arrow-avro = "57.0.0" ```` Disable defaults and pick only what you need (see **Feature Flags**): ```toml [dependencies] -arrow-avro = { version = "56", default-features = false, features = ["deflate", "snappy"] } +arrow-avro = { version = "57.0.0", default-features = false, features = ["deflate", "snappy"] } ``` --- diff --git a/arrow-buffer/src/util/bit_iterator.rs b/arrow-buffer/src/util/bit_iterator.rs index c7f6f94fb869..0aa94a5d4dc1 100644 --- a/arrow-buffer/src/util/bit_iterator.rs +++ b/arrow-buffer/src/util/bit_iterator.rs @@ -23,6 +23,7 @@ use crate::bit_util::{ceil, get_bit_raw}; /// Iterator over the bits within a packed bitmask /// /// To efficiently iterate over just the set bits see [`BitIndexIterator`] and [`BitSliceIterator`] +#[derive(Clone)] pub struct BitIterator<'a> { buffer: &'a [u8], current_offset: usize, @@ -71,6 +72,71 @@ impl Iterator for BitIterator<'_> { let remaining_bits = self.end_offset - self.current_offset; (remaining_bits, Some(remaining_bits)) } + + fn count(self) -> usize + where + Self: Sized, + { + self.len() + } + + fn nth(&mut self, n: usize) -> Option { + // Check if we can advance to the desired offset. + // When n is 0 it means we want the next() value + // and when n is 1 we want the next().next() value + // so adding n to the current offset and not n - 1 + match self.current_offset.checked_add(n) { + // Yes, and still within bounds + Some(new_offset) if new_offset < self.end_offset => { + self.current_offset = new_offset; + } + + // Either overflow or would exceed end_offset + _ => { + self.current_offset = self.end_offset; + return None; + } + } + + self.next() + } + + fn last(mut self) -> Option { + // If already at the end, return None + if self.current_offset == self.end_offset { + return None; + } + + // Go to the one before the last bit + self.current_offset = self.end_offset - 1; + + // Return the last bit + self.next() + } + + fn max(self) -> Option + where + Self: Sized, + Self::Item: Ord, + { + if self.current_offset == self.end_offset { + return None; + } + + // true is greater than false so we only need to check if there's any true bit + let mut bit_index_iter = BitIndexIterator::new( + self.buffer, + self.current_offset, + self.end_offset - self.current_offset, + ); + + if bit_index_iter.next().is_some() { + return Some(true); + } + + // We know the iterator is not empty and there are no set bits so false is the max + Some(false) + } } impl ExactSizeIterator for BitIterator<'_> {} @@ -86,6 +152,27 @@ impl DoubleEndedIterator for BitIterator<'_> { let v = unsafe { get_bit_raw(self.buffer.as_ptr(), self.end_offset) }; Some(v) } + + fn nth_back(&mut self, n: usize) -> Option { + // Check if we can advance to the desired offset. + // When n is 0 it means we want the next_back() value + // and when n is 1 we want the next_back().next_back() value + // so subtracting n to the current offset and not n - 1 + match self.end_offset.checked_sub(n) { + // Yes, and still within bounds + Some(new_offset) if self.current_offset < new_offset => { + self.end_offset = new_offset; + } + + // Either underflow or would exceed current_offset + _ => { + self.current_offset = self.end_offset; + return None; + } + } + + self.next_back() + } } /// Iterator of contiguous ranges of set bits within a provided packed bitmask @@ -327,6 +414,12 @@ pub fn try_for_each_valid_idx Result<(), E>>( #[cfg(test)] mod tests { use super::*; + use crate::BooleanBuffer; + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; + use std::fmt::Debug; + use std::iter::Copied; + use std::slice::Iter; #[test] fn test_bit_iterator_size_hint() { @@ -486,4 +579,427 @@ mod tests { .collect(); assert_eq!(result, expected); } + + trait SharedBetweenBitIteratorAndSliceIter: + ExactSizeIterator + DoubleEndedIterator + { + } + impl + DoubleEndedIterator> + SharedBetweenBitIteratorAndSliceIter for T + { + } + + fn get_bit_iterator_cases() -> impl Iterator)> { + let mut rng = StdRng::seed_from_u64(42); + + [0, 1, 6, 8, 100, 164] + .map(|len| { + let source = (0..len).map(|_| rng.random_bool(0.5)).collect::>(); + + (BooleanBuffer::from(source.as_slice()), source) + }) + .into_iter() + } + + fn setup_and_assert( + setup_iters: impl Fn(&mut dyn SharedBetweenBitIteratorAndSliceIter), + assert_fn: impl Fn(BitIterator, Copied>), + ) { + for (boolean_buffer, source) in get_bit_iterator_cases() { + // Not using `boolean_buffer.iter()` in case the implementation change to not call BitIterator internally + // in which case the test would not test what it intends to test + let mut actual = BitIterator::new(boolean_buffer.values(), 0, boolean_buffer.len()); + let mut expected = source.iter().copied(); + + setup_iters(&mut actual); + setup_iters(&mut expected); + + assert_fn(actual, expected); + } + } + + /// Trait representing an operation on a BitIterator + /// that can be compared against a slice iterator + trait BitIteratorOp { + /// What the operation returns (e.g. Option for last/max, usize for count, etc) + type Output: PartialEq + Debug; + + /// The name of the operation, used for error messages + const NAME: &'static str; + + /// Get the value of the operation for the provided iterator + /// This will be either a BitIterator or a slice iterator to make sure they produce the same result + fn get_value(iter: T) -> Self::Output; + } + + /// Helper function that will assert that the provided operation + /// produces the same result for both BitIterator and slice iterator + /// under various consumption patterns (e.g. some calls to next/next_back/consume_all/etc) + fn assert_bit_iterator_cases() { + setup_and_assert( + |_iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| {}, + |actual, expected| { + let current_iterator_values: Vec = expected.clone().collect(); + assert_eq!( + O::get_value(actual), + O::get_value(expected), + "Failed on op {} for new iter (left actual, right expected) ({current_iterator_values:?})", + O::NAME + ); + }, + ); + + setup_and_assert( + |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| { + iter.next(); + }, + |actual, expected| { + let current_iterator_values: Vec = expected.clone().collect(); + + assert_eq!( + O::get_value(actual), + O::get_value(expected), + "Failed on op {} for new iter after consuming 1 element from the start (left actual, right expected) ({current_iterator_values:?})", + O::NAME + ); + }, + ); + + setup_and_assert( + |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| { + iter.next_back(); + }, + |actual, expected| { + let current_iterator_values: Vec = expected.clone().collect(); + + assert_eq!( + O::get_value(actual), + O::get_value(expected), + "Failed on op {} for new iter after consuming 1 element from the end (left actual, right expected) ({current_iterator_values:?})", + O::NAME + ); + }, + ); + + setup_and_assert( + |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| { + iter.next(); + iter.next_back(); + }, + |actual, expected| { + let current_iterator_values: Vec = expected.clone().collect(); + + assert_eq!( + O::get_value(actual), + O::get_value(expected), + "Failed on op {} for new iter after consuming 1 element from start and end (left actual, right expected) ({current_iterator_values:?})", + O::NAME + ); + }, + ); + + setup_and_assert( + |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| { + while iter.len() > 1 { + iter.next(); + } + }, + |actual, expected| { + let current_iterator_values: Vec = expected.clone().collect(); + + assert_eq!( + O::get_value(actual), + O::get_value(expected), + "Failed on op {} for new iter after consuming all from the start but 1 (left actual, right expected) ({current_iterator_values:?})", + O::NAME + ); + }, + ); + + setup_and_assert( + |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| { + while iter.len() > 1 { + iter.next_back(); + } + }, + |actual, expected| { + let current_iterator_values: Vec = expected.clone().collect(); + + assert_eq!( + O::get_value(actual), + O::get_value(expected), + "Failed on op {} for new iter after consuming all from the end but 1 (left actual, right expected) ({current_iterator_values:?})", + O::NAME + ); + }, + ); + + setup_and_assert( + |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| { + while iter.next().is_some() {} + }, + |actual, expected| { + let current_iterator_values: Vec = expected.clone().collect(); + + assert_eq!( + O::get_value(actual), + O::get_value(expected), + "Failed on op {} for new iter after consuming all from the start (left actual, right expected) ({current_iterator_values:?})", + O::NAME + ); + }, + ); + + setup_and_assert( + |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| { + while iter.next_back().is_some() {} + }, + |actual, expected| { + let current_iterator_values: Vec = expected.clone().collect(); + + assert_eq!( + O::get_value(actual), + O::get_value(expected), + "Failed on op {} for new iter after consuming all from the end (left actual, right expected) ({current_iterator_values:?})", + O::NAME + ); + }, + ); + } + + #[test] + fn assert_bit_iterator_count() { + struct CountOp; + + impl BitIteratorOp for CountOp { + type Output = usize; + const NAME: &'static str = "count"; + + fn get_value(iter: T) -> Self::Output { + iter.count() + } + } + + assert_bit_iterator_cases::() + } + + #[test] + fn assert_bit_iterator_last() { + struct LastOp; + + impl BitIteratorOp for LastOp { + type Output = Option; + const NAME: &'static str = "last"; + + fn get_value(iter: T) -> Self::Output { + iter.last() + } + } + + assert_bit_iterator_cases::() + } + + #[test] + fn assert_bit_iterator_max() { + struct MaxOp; + + impl BitIteratorOp for MaxOp { + type Output = Option; + const NAME: &'static str = "max"; + + fn get_value(iter: T) -> Self::Output { + iter.max() + } + } + + assert_bit_iterator_cases::() + } + + #[test] + fn assert_bit_iterator_nth_0() { + struct NthOp; + + impl BitIteratorOp for NthOp { + type Output = Option; + const NAME: &'static str = if BACK { "nth_back(0)" } else { "nth(0)" }; + + fn get_value(mut iter: T) -> Self::Output { + if BACK { iter.nth_back(0) } else { iter.nth(0) } + } + } + + assert_bit_iterator_cases::>(); + assert_bit_iterator_cases::>(); + } + + #[test] + fn assert_bit_iterator_nth_1() { + struct NthOp; + + impl BitIteratorOp for NthOp { + type Output = Option; + const NAME: &'static str = if BACK { "nth_back(1)" } else { "nth(1)" }; + + fn get_value(mut iter: T) -> Self::Output { + if BACK { iter.nth_back(1) } else { iter.nth(1) } + } + } + + assert_bit_iterator_cases::>(); + assert_bit_iterator_cases::>(); + } + + #[test] + fn assert_bit_iterator_nth_after_end() { + struct NthOp; + + impl BitIteratorOp for NthOp { + type Output = Option; + const NAME: &'static str = if BACK { + "nth_back(iter.len() + 1)" + } else { + "nth(iter.len() + 1)" + }; + + fn get_value(mut iter: T) -> Self::Output { + if BACK { + iter.nth_back(iter.len() + 1) + } else { + iter.nth(iter.len() + 1) + } + } + } + + assert_bit_iterator_cases::>(); + assert_bit_iterator_cases::>(); + } + + #[test] + fn assert_bit_iterator_nth_len() { + struct NthOp; + + impl BitIteratorOp for NthOp { + type Output = Option; + const NAME: &'static str = if BACK { + "nth_back(iter.len())" + } else { + "nth(iter.len())" + }; + + fn get_value(mut iter: T) -> Self::Output { + if BACK { + iter.nth_back(iter.len()) + } else { + iter.nth(iter.len()) + } + } + } + + assert_bit_iterator_cases::>(); + assert_bit_iterator_cases::>(); + } + + #[test] + fn assert_bit_iterator_nth_last() { + struct NthOp; + + impl BitIteratorOp for NthOp { + type Output = Option; + const NAME: &'static str = if BACK { + "nth_back(iter.len().saturating_sub(1))" + } else { + "nth(iter.len().saturating_sub(1))" + }; + + fn get_value(mut iter: T) -> Self::Output { + if BACK { + iter.nth_back(iter.len().saturating_sub(1)) + } else { + iter.nth(iter.len().saturating_sub(1)) + } + } + } + + assert_bit_iterator_cases::>(); + assert_bit_iterator_cases::>(); + } + + #[test] + fn assert_bit_iterator_nth_and_reuse() { + setup_and_assert( + |_| {}, + |actual, expected| { + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + #[allow(clippy::iter_nth_zero)] + let actual_val = actual.nth(0); + #[allow(clippy::iter_nth_zero)] + let expected_val = expected.nth(0); + assert_eq!(actual_val, expected_val, "Failed on nth(0)"); + } + } + + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth(1); + let expected_val = expected.nth(1); + assert_eq!(actual_val, expected_val, "Failed on nth(1)"); + } + } + + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth(2); + let expected_val = expected.nth(2); + assert_eq!(actual_val, expected_val, "Failed on nth(2)"); + } + } + }, + ); + } + + #[test] + fn assert_bit_iterator_nth_back_and_reuse() { + setup_and_assert( + |_| {}, + |actual, expected| { + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + #[allow(clippy::iter_nth_zero)] + let actual_val = actual.nth_back(0); + let expected_val = expected.nth_back(0); + assert_eq!(actual_val, expected_val, "Failed on nth_back(0)"); + } + } + + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth_back(1); + let expected_val = expected.nth_back(1); + assert_eq!(actual_val, expected_val, "Failed on nth_back(1)"); + } + } + + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth_back(2); + let expected_val = expected.nth_back(2); + assert_eq!(actual_val, expected_val, "Failed on nth_back(2)"); + } + } + }, + ); + } } diff --git a/arrow-cast/Cargo.toml b/arrow-cast/Cargo.toml index 12da1af79fe0..fb5ad1af3d3a 100644 --- a/arrow-cast/Cargo.toml +++ b/arrow-cast/Cargo.toml @@ -43,6 +43,7 @@ force_validate = [] arrow-array = { workspace = true } arrow-buffer = { workspace = true } arrow-data = { workspace = true } +arrow-ord = { workspace = true } arrow-schema = { workspace = true } arrow-select = { workspace = true } chrono = { workspace = true } @@ -74,3 +75,4 @@ harness = false [[bench]] name = "parse_decimal" harness = false + diff --git a/arrow-cast/src/cast/decimal.rs b/arrow-cast/src/cast/decimal.rs index 907e61b09f7b..71338a6921e9 100644 --- a/arrow-cast/src/cast/decimal.rs +++ b/arrow-cast/src/cast/decimal.rs @@ -145,54 +145,82 @@ impl DecimalCast for i256 { } } -pub(crate) fn cast_decimal_to_decimal_error( +/// Construct closures to upscale decimals from `(input_precision, input_scale)` to +/// `(output_precision, output_scale)`. +/// +/// Returns `(f_fallible, f_infallible)` where: +/// * `f_fallible` yields `None` when the requested cast would overflow +/// * `f_infallible` is present only when every input is guaranteed to succeed; otherwise it is `None` +/// and callers must fall back to `f_fallible` +/// +/// Returns `None` if the required scale increase `delta_scale = output_scale - input_scale` +/// exceeds the supported precomputed precision table `O::MAX_FOR_EACH_PRECISION`. +/// In that case, the caller should treat this as an overflow for the output scale +/// and handle it accordingly (e.g., return a cast error). +#[allow(clippy::type_complexity)] +fn make_upscaler( + input_precision: u8, + input_scale: i8, output_precision: u8, output_scale: i8, -) -> impl Fn(::Native) -> ArrowError +) -> Option<( + impl Fn(I::Native) -> Option, + Option O::Native>, +)> where - I: DecimalType, - O: DecimalType, I::Native: DecimalCast + ArrowNativeTypeOp, O::Native: DecimalCast + ArrowNativeTypeOp, { - move |x: I::Native| { - ArrowError::CastError(format!( - "Cannot cast to {}({}, {}). Overflowing on {:?}", - O::PREFIX, - output_precision, - output_scale, - x - )) - } + let delta_scale = output_scale - input_scale; + + // O::MAX_FOR_EACH_PRECISION[k] stores 10^k - 1 (e.g., 9, 99, 999, ...). + // Adding 1 yields exactly 10^k without computing a power at runtime. + // Using the precomputed table avoids pow(10, k) and its checked/overflow + // handling, which is faster and simpler for scaling by 10^delta_scale. + let max = O::MAX_FOR_EACH_PRECISION.get(delta_scale as usize)?; + let mul = max.add_wrapping(O::Native::ONE); + let f_fallible = move |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok()); + + // if the gain in precision (digits) is greater than the multiplication due to scaling + // every number will fit into the output type + // Example: If we are starting with any number of precision 5 [xxxxx], + // then an increase of scale by 3 will have the following effect on the representation: + // [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type + // needs to provide at least 8 digits precision + let is_infallible_cast = (input_precision as i8) + delta_scale <= (output_precision as i8); + let f_infallible = is_infallible_cast + .then_some(move |x| O::Native::from_decimal(x).unwrap().mul_wrapping(mul)); + Some((f_fallible, f_infallible)) } -pub(crate) fn convert_to_smaller_scale_decimal( - array: &PrimitiveArray, +/// Construct closures to downscale decimals from `(input_precision, input_scale)` to +/// `(output_precision, output_scale)`. +/// +/// Returns `(f_fallible, f_infallible)` where: +/// * `f_fallible` yields `None` when the requested cast would overflow +/// * `f_infallible` is present only when every input is guaranteed to succeed; otherwise it is `None` +/// and callers must fall back to `f_fallible` +/// +/// Returns `None` if the required scale reduction `delta_scale = input_scale - output_scale` +/// exceeds the supported precomputed precision table `I::MAX_FOR_EACH_PRECISION`. +/// In this scenario, any value would round to zero (e.g., dividing by 10^k where k exceeds the +/// available precision). Callers should therefore produce zero values (preserving nulls) rather +/// than returning an error. +#[allow(clippy::type_complexity)] +fn make_downscaler( input_precision: u8, input_scale: i8, output_precision: u8, output_scale: i8, - cast_options: &CastOptions, -) -> Result, ArrowError> +) -> Option<( + impl Fn(I::Native) -> Option, + Option O::Native>, +)> where - I: DecimalType, - O: DecimalType, I::Native: DecimalCast + ArrowNativeTypeOp, O::Native: DecimalCast + ArrowNativeTypeOp, { - let error = cast_decimal_to_decimal_error::(output_precision, output_scale); let delta_scale = input_scale - output_scale; - // if the reduction of the input number through scaling (dividing) is greater - // than a possible precision loss (plus potential increase via rounding) - // every input number will fit into the output type - // Example: If we are starting with any number of precision 5 [xxxxx], - // then and decrease the scale by 3 will have the following effect on the representation: - // [xxxxx] -> [xx] (+ 1 possibly, due to rounding). - // The rounding may add an additional digit, so the cast to be infallible, - // the output type needs to have at least 3 digits of precision. - // e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100: - // [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible - let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8); // delta_scale is guaranteed to be > 0, but may also be larger than I::MAX_PRECISION. If so, the // scale change divides out more digits than the input has precision and the result of the cast @@ -200,16 +228,13 @@ where // possible result is 999999999/10000000000 = 0.0999999999, which rounds to zero. Smaller values // (e.g. 1/10000000000) or larger delta_scale (e.g. 999999999/10000000000000) produce even // smaller results, which also round to zero. In that case, just return an array of zeros. - let Some(max) = I::MAX_FOR_EACH_PRECISION.get(delta_scale as usize) else { - let zeros = vec![O::Native::ZERO; array.len()]; - return Ok(PrimitiveArray::new(zeros.into(), array.nulls().cloned())); - }; + let max = I::MAX_FOR_EACH_PRECISION.get(delta_scale as usize)?; let div = max.add_wrapping(I::Native::ONE); let half = div.div_wrapping(I::Native::ONE.add_wrapping(I::Native::ONE)); let half_neg = half.neg_wrapping(); - let f = |x: I::Native| { + let f_fallible = move |x: I::Native| { // div is >= 10 and so this cannot overflow let d = x.div_wrapping(div); let r = x.mod_wrapping(div); @@ -223,24 +248,136 @@ where O::Native::from_decimal(adjusted) }; - Ok(if is_infallible_cast { - // make sure we don't perform calculations that don't make sense w/o validation - validate_decimal_precision_and_scale::(output_precision, output_scale)?; - let g = |x: I::Native| f(x).unwrap(); // unwrapping is safe since the result is guaranteed - // to fit into the target type - array.unary(g) + // if the reduction of the input number through scaling (dividing) is greater + // than a possible precision loss (plus potential increase via rounding) + // every input number will fit into the output type + // Example: If we are starting with any number of precision 5 [xxxxx], + // then and decrease the scale by 3 will have the following effect on the representation: + // [xxxxx] -> [xx] (+ 1 possibly, due to rounding). + // The rounding may add a digit, so the cast to be infallible, + // the output type needs to have at least 3 digits of precision. + // e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100: + // [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible + let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8); + let f_infallible = is_infallible_cast.then_some(move |x| f_fallible(x).unwrap()); + Some((f_fallible, f_infallible)) +} + +/// Apply the rescaler function to the value. +/// If the rescaler is infallible, use the infallible function. +/// Otherwise, use the fallible function and validate the precision. +fn apply_rescaler( + value: I::Native, + output_precision: u8, + f: impl Fn(I::Native) -> Option, + f_infallible: Option O::Native>, +) -> Option +where + I::Native: DecimalCast, + O::Native: DecimalCast, +{ + if let Some(f_infallible) = f_infallible { + Some(f_infallible(value)) + } else { + f(value).filter(|v| O::is_valid_decimal_precision(*v, output_precision)) + } +} + +/// Rescales a decimal value from `(input_precision, input_scale)` to +/// `(output_precision, output_scale)` and returns the converted number when it fits +/// within the output precision. +/// +/// The function first validates that the requested precision and scale are supported for +/// both the source and destination decimal types. It then either upscales (multiplying +/// by an appropriate power of ten) or downscales (dividing with rounding) the input value. +/// When the scaling factor exceeds the precision table of the destination type, the value +/// is treated as an overflow for upscaling, or rounded to zero for downscaling (as any +/// possible result would be zero at the requested scale). +/// +/// This mirrors the column-oriented helpers of decimal casting but operates on a single value +/// (row-level) instead of an entire array. +/// +/// Returns `None` if the value cannot be represented with the requested precision. +pub fn rescale_decimal( + value: I::Native, + input_precision: u8, + input_scale: i8, + output_precision: u8, + output_scale: i8, +) -> Option +where + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + validate_decimal_precision_and_scale::(input_precision, input_scale).ok()?; + validate_decimal_precision_and_scale::(output_precision, output_scale).ok()?; + + if input_scale <= output_scale { + let (f, f_infallible) = + make_upscaler::(input_precision, input_scale, output_precision, output_scale)?; + apply_rescaler::(value, output_precision, f, f_infallible) + } else { + let Some((f, f_infallible)) = + make_downscaler::(input_precision, input_scale, output_precision, output_scale) + else { + // Scale reduction exceeds supported precision; result mathematically rounds to zero + return Some(O::Native::ZERO); + }; + apply_rescaler::(value, output_precision, f, f_infallible) + } +} + +fn cast_decimal_to_decimal_error( + output_precision: u8, + output_scale: i8, +) -> impl Fn(::Native) -> ArrowError +where + I: DecimalType, + O: DecimalType, + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + move |x: I::Native| { + ArrowError::CastError(format!( + "Cannot cast to {}({}, {}). Overflowing on {:?}", + O::PREFIX, + output_precision, + output_scale, + x + )) + } +} + +fn apply_decimal_cast( + array: &PrimitiveArray, + output_precision: u8, + output_scale: i8, + f_fallible: impl Fn(I::Native) -> Option, + f_infallible: Option O::Native>, + cast_options: &CastOptions, +) -> Result, ArrowError> +where + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + let array = if let Some(f_infallible) = f_infallible { + array.unary(f_infallible) } else if cast_options.safe { - array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision))) + array.unary_opt(|x| { + f_fallible(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)) + }) } else { + let error = cast_decimal_to_decimal_error::(output_precision, output_scale); array.try_unary(|x| { - f(x).ok_or_else(|| error(x)).and_then(|v| { + f_fallible(x).ok_or_else(|| error(x)).and_then(|v| { O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v) }) })? - }) + }; + Ok(array) } -pub(crate) fn convert_to_bigger_or_equal_scale_decimal( +fn convert_to_smaller_scale_decimal( array: &PrimitiveArray, input_precision: u8, input_scale: i8, @@ -254,36 +391,58 @@ where I::Native: DecimalCast + ArrowNativeTypeOp, O::Native: DecimalCast + ArrowNativeTypeOp, { - let error = cast_decimal_to_decimal_error::(output_precision, output_scale); - let delta_scale = output_scale - input_scale; - let mul = O::Native::from_decimal(10_i128) - .unwrap() - .pow_checked(delta_scale as u32)?; + if let Some((f_fallible, f_infallible)) = + make_downscaler::(input_precision, input_scale, output_precision, output_scale) + { + apply_decimal_cast( + array, + output_precision, + output_scale, + f_fallible, + f_infallible, + cast_options, + ) + } else { + // Scale reduction exceeds supported precision; result mathematically rounds to zero + let zeros = vec![O::Native::ZERO; array.len()]; + Ok(PrimitiveArray::new(zeros.into(), array.nulls().cloned())) + } +} - // if the gain in precision (digits) is greater than the multiplication due to scaling - // every number will fit into the output type - // Example: If we are starting with any number of precision 5 [xxxxx], - // then an increase of scale by 3 will have the following effect on the representation: - // [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type - // needs to provide at least 8 digits precision - let is_infallible_cast = (input_precision as i8) + delta_scale <= (output_precision as i8); - let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok()); - - Ok(if is_infallible_cast { - // make sure we don't perform calculations that don't make sense w/o validation - validate_decimal_precision_and_scale::(output_precision, output_scale)?; - // unwrapping is safe since the result is guaranteed to fit into the target type - let f = |x| O::Native::from_decimal(x).unwrap().mul_wrapping(mul); - array.unary(f) - } else if cast_options.safe { - array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision))) +fn convert_to_bigger_or_equal_scale_decimal( + array: &PrimitiveArray, + input_precision: u8, + input_scale: i8, + output_precision: u8, + output_scale: i8, + cast_options: &CastOptions, +) -> Result, ArrowError> +where + I: DecimalType, + O: DecimalType, + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + if let Some((f, f_infallible)) = + make_upscaler::(input_precision, input_scale, output_precision, output_scale) + { + apply_decimal_cast( + array, + output_precision, + output_scale, + f, + f_infallible, + cast_options, + ) } else { - array.try_unary(|x| { - f(x).ok_or_else(|| error(x)).and_then(|v| { - O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v) - }) - })? - }) + // Scale increase exceeds supported precision; return overflow error + Err(ArrowError::CastError(format!( + "Cannot cast to {}({}, {}). Value overflows for output scale", + O::PREFIX, + output_precision, + output_scale + ))) + } } // Only support one type of decimal cast operations @@ -763,4 +922,58 @@ mod tests { ); Ok(()) } + + #[test] + fn test_rescale_decimal_upscale_within_precision() { + let result = rescale_decimal::( + 12_345_i128, // 123.45 with scale 2 + 5, + 2, + 8, + 5, + ); + assert_eq!(result, Some(12_345_000_i128)); + } + + #[test] + fn test_rescale_decimal_downscale_rounds_half_away_from_zero() { + let positive = rescale_decimal::( + 1_050_i128, // 1.050 with scale 3 + 5, 3, 5, 1, + ); + assert_eq!(positive, Some(11_i128)); // 1.1 with scale 1 + + let negative = rescale_decimal::( + -1_050_i128, // -1.050 with scale 3 + 5, + 3, + 5, + 1, + ); + assert_eq!(negative, Some(-11_i128)); // -1.1 with scale 1 + } + + #[test] + fn test_rescale_decimal_downscale_large_delta_returns_zero() { + let result = rescale_decimal::(12_345_i32, 9, 9, 9, 4); + assert_eq!(result, Some(0_i32)); + } + + #[test] + fn test_rescale_decimal_upscale_overflow_returns_none() { + let result = rescale_decimal::(9_999_i32, 4, 0, 5, 2); + assert_eq!(result, None); + } + + #[test] + fn test_rescale_decimal_invalid_input_precision_scale_returns_none() { + let result = rescale_decimal::(123_i128, 39, 39, 38, 38); + assert_eq!(result, None); + } + + #[test] + fn test_rescale_decimal_invalid_output_precision_scale_returns_none() { + let result = rescale_decimal::(123_i128, 38, 38, 39, 39); + assert_eq!(result, None); + } } diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index fe38298b017c..47fdb01a09f4 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -41,11 +41,13 @@ mod decimal; mod dictionary; mod list; mod map; +mod run_array; mod string; use crate::cast::decimal::*; use crate::cast::dictionary::*; use crate::cast::list::*; use crate::cast::map::*; +use crate::cast::run_array::*; use crate::cast::string::*; use arrow_buffer::IntervalMonthDayNano; @@ -67,7 +69,7 @@ use arrow_schema::*; use arrow_select::take::take; use num_traits::{NumCast, ToPrimitive, cast::AsPrimitive}; -pub use decimal::DecimalCast; +pub use decimal::{DecimalCast, rescale_decimal}; /// CastOptions provides a way to override the default cast behaviors #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -139,6 +141,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { can_cast_types(from_value_type, to_value_type) } (Dictionary(_, value_type), _) => can_cast_types(value_type, to_type), + (RunEndEncoded(_, value_type), _) => can_cast_types(value_type.data_type(), to_type), + (_, RunEndEncoded(_, value_type)) => can_cast_types(from_type, value_type.data_type()), (_, Dictionary(_, value_type)) => can_cast_types(from_type, value_type), (List(list_from) | LargeList(list_from), List(list_to) | LargeList(list_to)) => { can_cast_types(list_from.data_type(), list_to.data_type()) @@ -791,6 +795,37 @@ pub fn cast_with_options( | Map(_, _) | Dictionary(_, _), ) => Ok(new_null_array(to_type, array.len())), + (RunEndEncoded(index_type, _), _) => match index_type.data_type() { + Int16 => run_end_encoded_cast::(array, to_type, cast_options), + Int32 => run_end_encoded_cast::(array, to_type, cast_options), + Int64 => run_end_encoded_cast::(array, to_type, cast_options), + _ => Err(ArrowError::CastError(format!( + "Casting from run end encoded type {from_type:?} to {to_type:?} not supported", + ))), + }, + (_, RunEndEncoded(index_type, value_type)) => { + let array_ref = make_array(array.to_data()); + match index_type.data_type() { + Int16 => cast_to_run_end_encoded::( + &array_ref, + value_type.data_type(), + cast_options, + ), + Int32 => cast_to_run_end_encoded::( + &array_ref, + value_type.data_type(), + cast_options, + ), + Int64 => cast_to_run_end_encoded::( + &array_ref, + value_type.data_type(), + cast_options, + ), + _ => Err(ArrowError::CastError(format!( + "Casting from type {from_type:?} to run end encoded type {to_type:?} not supported", + ))), + } + } (Dictionary(index_type, _), _) => match **index_type { Int8 => dictionary_cast::(array, to_type, cast_options), Int16 => dictionary_cast::(array, to_type, cast_options), @@ -2640,10 +2675,14 @@ where #[cfg(test)] mod tests { use super::*; + use DataType::*; + use arrow_array::{Int64Array, RunArray, StringArray}; use arrow_buffer::i256; use arrow_buffer::{Buffer, IntervalDayTime, NullBuffer}; + use arrow_schema::{DataType, Field}; use chrono::NaiveDate; use half::f16; + use std::sync::Arc; #[derive(Clone)] struct DecimalCastTestConfig { @@ -7794,8 +7833,6 @@ mod tests { #[test] fn test_cast_utf8_dict() { // FROM a dictionary with of Utf8 values - use DataType::*; - let mut builder = StringDictionaryBuilder::::new(); builder.append("one").unwrap(); builder.append_null(); @@ -7850,7 +7887,6 @@ mod tests { #[test] fn test_cast_dict_to_dict_bad_index_value_primitive() { - use DataType::*; // test converting from an array that has indexes of a type // that are out of bounds for a particular other kind of // index. @@ -7878,7 +7914,6 @@ mod tests { #[test] fn test_cast_dict_to_dict_bad_index_value_utf8() { - use DataType::*; // Same test as test_cast_dict_to_dict_bad_index_value but use // string values (and encode the expected behavior here); @@ -7907,8 +7942,6 @@ mod tests { #[test] fn test_cast_primitive_dict() { // FROM a dictionary with of INT32 values - use DataType::*; - let mut builder = PrimitiveDictionaryBuilder::::new(); builder.append(1).unwrap(); builder.append_null(); @@ -7929,8 +7962,6 @@ mod tests { #[test] fn test_cast_primitive_array_to_dict() { - use DataType::*; - let mut builder = PrimitiveBuilder::::new(); builder.append_value(1); builder.append_null(); @@ -11417,4 +11448,422 @@ mod tests { "Invalid argument error: -1.0 is too small to store in a Decimal32 of precision 1. Min is -0.9" ); } + + #[test] + fn test_run_end_encoded_to_primitive() { + // Create a RunEndEncoded array: [1, 1, 2, 2, 2, 3] + let run_ends = Int32Array::from(vec![2, 5, 6]); + let values = Int32Array::from(vec![1, 2, 3]); + let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); + let array_ref = Arc::new(run_array) as ArrayRef; + // Cast to Int64 + let cast_result = cast(&array_ref, &DataType::Int64).unwrap(); + // Verify the result is a RunArray with Int64 values + let result_run_array = cast_result.as_any().downcast_ref::().unwrap(); + assert_eq!( + result_run_array.values(), + &[1i64, 1i64, 2i64, 2i64, 2i64, 3i64] + ); + } + + #[test] + fn test_run_end_encoded_to_string() { + let run_ends = Int32Array::from(vec![2, 3, 5]); + let values = Int32Array::from(vec![10, 20, 30]); + let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); + let array_ref = Arc::new(run_array) as ArrayRef; + + // Cast to String + let cast_result = cast(&array_ref, &DataType::Utf8).unwrap(); + + // Verify the result is a RunArray with String values + let result_array = cast_result.as_any().downcast_ref::().unwrap(); + // Check that values are correct + assert_eq!(result_array.value(0), "10"); + assert_eq!(result_array.value(1), "10"); + assert_eq!(result_array.value(2), "20"); + } + + #[test] + fn test_primitive_to_run_end_encoded() { + // Create an Int32 array with repeated values: [1, 1, 2, 2, 2, 3] + let source_array = Int32Array::from(vec![1, 1, 2, 2, 2, 3]); + let array_ref = Arc::new(source_array) as ArrayRef; + + // Cast to RunEndEncoded + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int32, false)), + Arc::new(Field::new("values", DataType::Int32, true)), + ); + let cast_result = cast(&array_ref, &target_type).unwrap(); + + // Verify the result is a RunArray + let result_run_array = cast_result + .as_any() + .downcast_ref::>() + .unwrap(); + + // Check run structure: runs should end at positions [2, 5, 6] + assert_eq!(result_run_array.run_ends().values(), &[2, 5, 6]); + + // Check values: should be [1, 2, 3] + let values_array = result_run_array.values().as_primitive::(); + assert_eq!(values_array.values(), &[1, 2, 3]); + } + + #[test] + fn test_primitive_to_run_end_encoded_with_nulls() { + let source_array = Int32Array::from(vec![ + Some(1), + Some(1), + None, + None, + Some(2), + Some(2), + Some(3), + Some(3), + None, + None, + Some(4), + Some(4), + Some(5), + Some(5), + None, + None, + ]); + let array_ref = Arc::new(source_array) as ArrayRef; + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int32, false)), + Arc::new(Field::new("values", DataType::Int32, true)), + ); + let cast_result = cast(&array_ref, &target_type).unwrap(); + let result_run_array = cast_result + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!( + result_run_array.run_ends().values(), + &[2, 4, 6, 8, 10, 12, 14, 16] + ); + assert_eq!( + result_run_array + .values() + .as_primitive::() + .values(), + &[1, 0, 2, 3, 0, 4, 5, 0] + ); + assert_eq!(result_run_array.values().null_count(), 3); + } + + #[test] + fn test_primitive_to_run_end_encoded_with_nulls_consecutive() { + let source_array = Int64Array::from(vec![ + Some(1), + Some(1), + None, + None, + None, + None, + None, + None, + None, + None, + Some(4), + Some(20), + Some(500), + Some(500), + None, + None, + ]); + let array_ref = Arc::new(source_array) as ArrayRef; + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int16, false)), + Arc::new(Field::new("values", DataType::Int64, true)), + ); + let cast_result = cast(&array_ref, &target_type).unwrap(); + let result_run_array = cast_result + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!( + result_run_array.run_ends().values(), + &[2, 10, 11, 12, 14, 16] + ); + assert_eq!( + result_run_array + .values() + .as_primitive::() + .values(), + &[1, 0, 4, 20, 500, 0] + ); + assert_eq!(result_run_array.values().null_count(), 2); + } + + #[test] + fn test_string_to_run_end_encoded() { + // Create a String array with repeated values: ["a", "a", "b", "c", "c"] + let source_array = StringArray::from(vec!["a", "a", "b", "c", "c"]); + let array_ref = Arc::new(source_array) as ArrayRef; + + // Cast to RunEndEncoded + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int32, false)), + Arc::new(Field::new("values", DataType::Utf8, true)), + ); + let cast_result = cast(&array_ref, &target_type).unwrap(); + + // Verify the result is a RunArray + let result_run_array = cast_result + .as_any() + .downcast_ref::>() + .unwrap(); + + // Check run structure: runs should end at positions [2, 3, 5] + assert_eq!(result_run_array.run_ends().values(), &[2, 3, 5]); + + // Check values: should be ["a", "b", "c"] + let values_array = result_run_array.values().as_string::(); + assert_eq!(values_array.value(0), "a"); + assert_eq!(values_array.value(1), "b"); + assert_eq!(values_array.value(2), "c"); + } + + #[test] + fn test_empty_array_to_run_end_encoded() { + // Create an empty Int32 array + let source_array = Int32Array::from(Vec::::new()); + let array_ref = Arc::new(source_array) as ArrayRef; + + // Cast to RunEndEncoded + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int32, false)), + Arc::new(Field::new("values", DataType::Int32, true)), + ); + let cast_result = cast(&array_ref, &target_type).unwrap(); + + // Verify the result is an empty RunArray + let result_run_array = cast_result + .as_any() + .downcast_ref::>() + .unwrap(); + + // Check that both run_ends and values are empty + assert_eq!(result_run_array.run_ends().len(), 0); + assert_eq!(result_run_array.values().len(), 0); + } + + #[test] + fn test_run_end_encoded_with_nulls() { + // Create a RunEndEncoded array with nulls: [1, 1, null, 2, 2] + let run_ends = Int32Array::from(vec![2, 3, 5]); + let values = Int32Array::from(vec![Some(1), None, Some(2)]); + let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); + let array_ref = Arc::new(run_array) as ArrayRef; + + // Cast to String + let cast_result = cast(&array_ref, &DataType::Utf8).unwrap(); + + // Verify the result preserves nulls + let result_run_array = cast_result.as_any().downcast_ref::().unwrap(); + assert_eq!(result_run_array.value(0), "1"); + assert!(result_run_array.is_null(2)); + assert_eq!(result_run_array.value(4), "2"); + } + + #[test] + fn test_different_index_types() { + // Test with Int16 index type + let source_array = Int32Array::from(vec![1, 1, 2, 3, 3]); + let array_ref = Arc::new(source_array) as ArrayRef; + + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int16, false)), + Arc::new(Field::new("values", DataType::Int32, true)), + ); + let cast_result = cast(&array_ref, &target_type).unwrap(); + assert_eq!(cast_result.data_type(), &target_type); + + // Verify the cast worked correctly: values are [1, 2, 3] + // and run-ends are [2, 3, 5] + let run_array = cast_result + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(run_array.values().as_primitive::().value(0), 1); + assert_eq!(run_array.values().as_primitive::().value(1), 2); + assert_eq!(run_array.values().as_primitive::().value(2), 3); + assert_eq!(run_array.run_ends().values(), &[2i16, 3i16, 5i16]); + + // Test again with Int64 index type + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int64, false)), + Arc::new(Field::new("values", DataType::Int32, true)), + ); + let cast_result = cast(&array_ref, &target_type).unwrap(); + assert_eq!(cast_result.data_type(), &target_type); + + // Verify the cast worked correctly: values are [1, 2, 3] + // and run-ends are [2, 3, 5] + let run_array = cast_result + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(run_array.values().as_primitive::().value(0), 1); + assert_eq!(run_array.values().as_primitive::().value(1), 2); + assert_eq!(run_array.values().as_primitive::().value(2), 3); + assert_eq!(run_array.run_ends().values(), &[2i64, 3i64, 5i64]); + } + + #[test] + fn test_unsupported_cast_to_run_end_encoded() { + // Create a Struct array - complex nested type that might not be supported + let field = Field::new("item", DataType::Int32, false); + let struct_array = StructArray::from(vec![( + Arc::new(field), + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + )]); + let array_ref = Arc::new(struct_array) as ArrayRef; + + // This should fail because: + // 1. The target type is not RunEndEncoded + // 2. The target type is not supported for casting from StructArray + let cast_result = cast(&array_ref, &DataType::FixedSizeBinary(10)); + + // Expect this to fail + assert!(cast_result.is_err()); + } + + /// Test casting RunEndEncoded to RunEndEncoded should fail + #[test] + fn test_cast_run_end_encoded_int64_to_int16_should_fail() { + // Construct a valid REE array with Int64 run-ends + let run_ends = Int64Array::from(vec![100_000, 400_000, 700_000]); // values too large for Int16 + let values = StringArray::from(vec!["a", "b", "c"]); + + let ree_array = RunArray::::try_new(&run_ends, &values).unwrap(); + let array_ref = Arc::new(ree_array) as ArrayRef; + + // Attempt to cast to RunEndEncoded + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int16, false)), + Arc::new(Field::new("values", DataType::Utf8, true)), + ); + let cast_options = CastOptions { + safe: false, // This should make it fail instead of returning nulls + format_options: FormatOptions::default(), + }; + + // This should fail due to run-end overflow + let result: Result, ArrowError> = + cast_with_options(&array_ref, &target_type, &cast_options); + + let e = result.expect_err("Cast should have failed but succeeded"); + assert!( + e.to_string() + .contains("Cast error: Can't cast value 100000 to type Int16") + ); + } + + #[test] + fn test_cast_run_end_encoded_int64_to_int16_with_safe_should_fail_with_null_invalid_error() { + // Construct a valid REE array with Int64 run-ends + let run_ends = Int64Array::from(vec![100_000, 400_000, 700_000]); // values too large for Int16 + let values = StringArray::from(vec!["a", "b", "c"]); + + let ree_array = RunArray::::try_new(&run_ends, &values).unwrap(); + let array_ref = Arc::new(ree_array) as ArrayRef; + + // Attempt to cast to RunEndEncoded + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int16, false)), + Arc::new(Field::new("values", DataType::Utf8, true)), + ); + let cast_options = CastOptions { + safe: true, + format_options: FormatOptions::default(), + }; + + // This fails even though safe is true because the run_ends array has null values + let result: Result, ArrowError> = + cast_with_options(&array_ref, &target_type, &cast_options); + let e = result.expect_err("Cast should have failed but succeeded"); + assert!( + e.to_string() + .contains("Invalid argument error: Found null values in run_ends array. The run_ends array should not have null values.") + ); + } + + /// Test casting RunEndEncoded to RunEndEncoded should succeed + #[test] + fn test_cast_run_end_encoded_int16_to_int64_should_succeed() { + // Construct a valid REE array with Int16 run-ends + let run_ends = Int16Array::from(vec![2, 5, 8]); // values that fit in Int16 + let values = StringArray::from(vec!["a", "b", "c"]); + + let ree_array = RunArray::::try_new(&run_ends, &values).unwrap(); + let array_ref = Arc::new(ree_array) as ArrayRef; + + // Attempt to cast to RunEndEncoded (upcast should succeed) + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int64, false)), + Arc::new(Field::new("values", DataType::Utf8, true)), + ); + let cast_options = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + + // This should succeed due to valid upcast + let result: Result, ArrowError> = + cast_with_options(&array_ref, &target_type, &cast_options); + + let array_ref = result.expect("Cast should have succeeded but failed"); + // Downcast to RunArray + let run_array = array_ref + .as_any() + .downcast_ref::>() + .unwrap(); + + // Verify the cast worked correctly + // Assert the values were cast correctly + assert_eq!(run_array.run_ends().values(), &[2i64, 5i64, 8i64]); + assert_eq!(run_array.values().as_string::().value(0), "a"); + assert_eq!(run_array.values().as_string::().value(1), "b"); + assert_eq!(run_array.values().as_string::().value(2), "c"); + } + + #[test] + fn test_cast_run_end_encoded_dictionary_to_run_end_encoded() { + // Construct a valid dictionary encoded array + let values = StringArray::from_iter([Some("a"), Some("b"), Some("c")]); + let keys = UInt64Array::from_iter(vec![1, 1, 1, 0, 0, 0, 2, 2, 2]); + let array_ref = Arc::new(DictionaryArray::new(keys, Arc::new(values))) as ArrayRef; + + // Attempt to cast to RunEndEncoded + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int64, false)), + Arc::new(Field::new("values", DataType::Utf8, true)), + ); + let cast_options = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + + // This should succeed + let result = cast_with_options(&array_ref, &target_type, &cast_options) + .expect("Cast should have succeeded but failed"); + + // Verify the cast worked correctly + // Assert the values were cast correctly + let run_array = result + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(run_array.values().as_string::().value(0), "b"); + assert_eq!(run_array.values().as_string::().value(1), "a"); + assert_eq!(run_array.values().as_string::().value(2), "c"); + + // Verify the run-ends were cast correctly (run ends at 3, 6, 9) + assert_eq!(run_array.run_ends().values(), &[3i64, 6i64, 9i64]); + } } diff --git a/arrow-cast/src/cast/run_array.rs b/arrow-cast/src/cast/run_array.rs new file mode 100644 index 000000000000..8d70afef3ab6 --- /dev/null +++ b/arrow-cast/src/cast/run_array.rs @@ -0,0 +1,164 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::cast::*; +use arrow_ord::partition::partition; + +/// Attempts to cast a `RunArray` with index type K into +/// `to_type` for supported types. +pub(crate) fn run_end_encoded_cast( + array: &dyn Array, + to_type: &DataType, + cast_options: &CastOptions, +) -> Result { + match array.data_type() { + DataType::RunEndEncoded(_, _) => { + let run_array = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| ArrowError::CastError("Expected RunArray".to_string()))?; + + let values = run_array.values(); + + match to_type { + // Stay as RunEndEncoded, cast only the values + DataType::RunEndEncoded(target_index_field, target_value_field) => { + let cast_values = + cast_with_options(values, target_value_field.data_type(), cast_options)?; + + let run_ends_array = PrimitiveArray::::from_iter_values( + run_array.run_ends().values().iter().copied(), + ); + let cast_run_ends = cast_with_options( + &run_ends_array, + target_index_field.data_type(), + cast_options, + )?; + let new_run_array: ArrayRef = match target_index_field.data_type() { + DataType::Int16 => { + let re = cast_run_ends.as_primitive::(); + Arc::new(RunArray::::try_new(re, cast_values.as_ref())?) + } + DataType::Int32 => { + let re = cast_run_ends.as_primitive::(); + Arc::new(RunArray::::try_new(re, cast_values.as_ref())?) + } + DataType::Int64 => { + let re = cast_run_ends.as_primitive::(); + Arc::new(RunArray::::try_new(re, cast_values.as_ref())?) + } + _ => { + return Err(ArrowError::CastError( + "Run-end type must be i16, i32, or i64".to_string(), + )); + } + }; + Ok(Arc::new(new_run_array)) + } + + // Expand to logical form + _ => { + let run_ends = run_array.run_ends().values().to_vec(); + let mut indices = Vec::with_capacity(run_array.run_ends().len()); + let mut physical_idx: usize = 0; + for logical_idx in 0..run_array.run_ends().len() { + // If the logical index is equal to the (next) run end, increment the physical index, + // since we are at the end of a run. + if logical_idx == run_ends[physical_idx].as_usize() { + physical_idx += 1; + } + indices.push(physical_idx as i32); + } + + let taken = take(&values, &Int32Array::from_iter_values(indices), None)?; + if taken.data_type() != to_type { + cast_with_options(taken.as_ref(), to_type, cast_options) + } else { + Ok(taken) + } + } + } + } + + _ => Err(ArrowError::CastError(format!( + "Cannot cast array of type {:?} to RunEndEncodedArray", + array.data_type() + ))), + } +} + +/// Attempts to encode an array into a `RunArray` with index type K +/// and value type `value_type` +pub(crate) fn cast_to_run_end_encoded( + array: &ArrayRef, + value_type: &DataType, + cast_options: &CastOptions, +) -> Result { + let mut run_ends_builder = PrimitiveBuilder::::new(); + + // Cast the input array to the target value type if necessary + let cast_array = if array.data_type() == value_type { + array + } else { + &cast_with_options(array, value_type, cast_options)? + }; + + // Return early if the array to cast is empty + if cast_array.is_empty() { + let empty_run_ends = run_ends_builder.finish(); + let empty_values = make_array(ArrayData::new_empty(value_type)); + return Ok(Arc::new(RunArray::::try_new( + &empty_run_ends, + empty_values.as_ref(), + )?)); + } + + // REE arrays are handled by run_end_encoded_cast + if let DataType::RunEndEncoded(_, _) = array.data_type() { + return Err(ArrowError::CastError( + "Source array is already a RunEndEncoded array, should have been handled by run_end_encoded_cast".to_string() + )); + } + + // Partition the array to identify runs of consecutive equal values + let partitions = partition(&[Arc::clone(cast_array)])?; + let mut run_ends = Vec::new(); + let mut values_indexes = Vec::new(); + let mut last_partition_end = 0; + for partition in partitions.ranges() { + values_indexes.push(last_partition_end); + run_ends.push(partition.end); + last_partition_end = partition.end; + } + + // Build the run_ends array + for run_end in run_ends { + run_ends_builder.append_value(K::Native::from_usize(run_end).ok_or_else(|| { + ArrowError::CastError(format!("Run end index out of range: {}", run_end)) + })?); + } + let run_ends_array = run_ends_builder.finish(); + // Build the values array by taking elements at the run start positions + let indices = PrimitiveArray::::from_iter_values( + values_indexes.iter().map(|&idx| idx as u32), + ); + let values_array = take(&cast_array, &indices, None)?; + + // Create and return the RunArray + let run_array = RunArray::::try_new(&run_ends_array, values_array.as_ref())?; + Ok(Arc::new(run_array)) +} diff --git a/arrow-data/src/data.rs b/arrow-data/src/data.rs index c31ac0c6e693..91957e14f332 100644 --- a/arrow-data/src/data.rs +++ b/arrow-data/src/data.rs @@ -980,7 +980,15 @@ impl ArrayData { ) -> Result<(), ArrowError> { let offsets: &[T] = self.typed_buffer(0, self.len)?; let sizes: &[T] = self.typed_buffer(1, self.len)?; - for i in 0..values_length { + if offsets.len() != sizes.len() { + return Err(ArrowError::ComputeError(format!( + "ListView offsets len {} does not match sizes len {}", + offsets.len(), + sizes.len() + ))); + } + + for i in 0..sizes.len() { let size = sizes[i].to_usize().ok_or_else(|| { ArrowError::InvalidArgumentError(format!( "Error converting size[{}] ({}) to usize for {}", diff --git a/arrow-data/src/equal/list_view.rs b/arrow-data/src/equal/list_view.rs new file mode 100644 index 000000000000..c7cb31db9099 --- /dev/null +++ b/arrow-data/src/equal/list_view.rs @@ -0,0 +1,129 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::ArrayData; +use crate::data::count_nulls; +use crate::equal::equal_values; +use arrow_buffer::ArrowNativeType; +use num_integer::Integer; + +pub(super) fn list_view_equal( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + let lhs_offsets = lhs.buffer::(0); + let lhs_sizes = lhs.buffer::(1); + + let rhs_offsets = rhs.buffer::(0); + let rhs_sizes = rhs.buffer::(1); + + let lhs_data = &lhs.child_data()[0]; + let rhs_data = &rhs.child_data()[0]; + + let lhs_null_count = count_nulls(lhs.nulls(), lhs_start, len); + let rhs_null_count = count_nulls(rhs.nulls(), rhs_start, len); + + if lhs_null_count != rhs_null_count { + return false; + } + + if lhs_null_count == 0 { + // non-null pathway: all sizes must be equal, and all values must be equal + let lhs_range_sizes = &lhs_sizes[lhs_start..lhs_start + len]; + let rhs_range_sizes = &rhs_sizes[rhs_start..rhs_start + len]; + + if lhs_range_sizes.len() != rhs_range_sizes.len() { + return false; + } + + if lhs_range_sizes != rhs_range_sizes { + return false; + } + + // Check values for equality + let lhs_range_offsets = &lhs_offsets[lhs_start..lhs_start + len]; + let rhs_range_offsets = &rhs_offsets[rhs_start..rhs_start + len]; + + if lhs_range_offsets.len() != rhs_range_offsets.len() { + return false; + } + + for ((&lhs_offset, &rhs_offset), &size) in lhs_range_offsets + .iter() + .zip(rhs_range_offsets) + .zip(lhs_range_sizes) + { + let lhs_offset = lhs_offset.to_usize().unwrap(); + let rhs_offset = rhs_offset.to_usize().unwrap(); + let size = size.to_usize().unwrap(); + + // Check if offsets are valid for the given range + if !equal_values(lhs_data, rhs_data, lhs_offset, rhs_offset, size) { + return false; + } + } + } else { + // Need to integrate validity check in the inner loop. + // non-null pathway: all sizes must be equal, and all values must be equal + let lhs_range_sizes = &lhs_sizes[lhs_start..lhs_start + len]; + let rhs_range_sizes = &rhs_sizes[rhs_start..rhs_start + len]; + + let lhs_nulls = lhs.nulls().unwrap().slice(lhs_start, len); + let rhs_nulls = rhs.nulls().unwrap().slice(rhs_start, len); + + // Sizes can differ if values are null + if lhs_range_sizes.len() != rhs_range_sizes.len() { + return false; + } + + // Check values for equality, with null checking + let lhs_range_offsets = &lhs_offsets[lhs_start..lhs_start + len]; + let rhs_range_offsets = &rhs_offsets[rhs_start..rhs_start + len]; + + if lhs_range_offsets.len() != rhs_range_offsets.len() { + return false; + } + + for (index, ((&lhs_offset, &rhs_offset), &size)) in lhs_range_offsets + .iter() + .zip(rhs_range_offsets) + .zip(lhs_range_sizes) + .enumerate() + { + let lhs_is_null = lhs_nulls.is_null(index); + let rhs_is_null = rhs_nulls.is_null(index); + + if lhs_is_null != rhs_is_null { + return false; + } + + let lhs_offset = lhs_offset.to_usize().unwrap(); + let rhs_offset = rhs_offset.to_usize().unwrap(); + let size = size.to_usize().unwrap(); + + // Check if values match in the range + if !lhs_is_null && !equal_values(lhs_data, rhs_data, lhs_offset, rhs_offset, size) { + return false; + } + } + } + + true +} diff --git a/arrow-data/src/equal/mod.rs b/arrow-data/src/equal/mod.rs index 1c16ee2f8a14..7a310b1240df 100644 --- a/arrow-data/src/equal/mod.rs +++ b/arrow-data/src/equal/mod.rs @@ -30,6 +30,7 @@ mod dictionary; mod fixed_binary; mod fixed_list; mod list; +mod list_view; mod null; mod primitive; mod run; @@ -41,6 +42,8 @@ mod variable_size; // these methods assume the same type, len and null count. // For this reason, they are not exposed and are instead used // to build the generic functions below (`equal_range` and `equal`). +use self::run::run_equal; +use crate::equal::list_view::list_view_equal; use boolean::boolean_equal; use byte_view::byte_view_equal; use dictionary::dictionary_equal; @@ -53,8 +56,6 @@ use structure::struct_equal; use union::union_equal; use variable_size::variable_sized_equal; -use self::run::run_equal; - /// Compares the values of two [ArrayData] starting at `lhs_start` and `rhs_start` respectively /// for `len` slots. #[inline] @@ -104,10 +105,9 @@ fn equal_values( byte_view_equal(lhs, rhs, lhs_start, rhs_start, len) } DataType::List(_) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), - DataType::ListView(_) | DataType::LargeListView(_) => { - unimplemented!("ListView/LargeListView not yet implemented") - } DataType::LargeList(_) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::ListView(_) => list_view_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::LargeListView(_) => list_view_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::FixedSizeList(_, _) => fixed_list_equal(lhs, rhs, lhs_start, rhs_start, len), DataType::Struct(_) => struct_equal(lhs, rhs, lhs_start, rhs_start, len), DataType::Union(_, _) => union_equal(lhs, rhs, lhs_start, rhs_start, len), diff --git a/arrow-data/src/transform/list_view.rs b/arrow-data/src/transform/list_view.rs new file mode 100644 index 000000000000..9b66a6a6abb1 --- /dev/null +++ b/arrow-data/src/transform/list_view.rs @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::ArrayData; +use crate::transform::_MutableArrayData; +use arrow_buffer::ArrowNativeType; +use num_integer::Integer; +use num_traits::CheckedAdd; + +pub(super) fn build_extend( + array: &ArrayData, +) -> crate::transform::Extend<'_> { + let offsets = array.buffer::(0); + let sizes = array.buffer::(1); + Box::new( + move |mutable: &mut _MutableArrayData, _index: usize, start: usize, len: usize| { + let offset_buffer = &mut mutable.buffer1; + let sizes_buffer = &mut mutable.buffer2; + + for &offset in &offsets[start..start + len] { + offset_buffer.push(offset); + } + + // sizes + for &size in &sizes[start..start + len] { + sizes_buffer.push(size); + } + + // the beauty of views is that we don't need to copy child_data, we just splat + // the offsets and sizes. + }, + ) +} + +pub(super) fn extend_nulls(mutable: &mut _MutableArrayData, len: usize) { + let offset_buffer = &mut mutable.buffer1; + let sizes_buffer = &mut mutable.buffer2; + + // We push 0 as a placeholder for NULL values in both the offsets and sizes + (0..len).for_each(|_| offset_buffer.push(T::default())); + (0..len).for_each(|_| sizes_buffer.push(T::default())); +} diff --git a/arrow-data/src/transform/mod.rs b/arrow-data/src/transform/mod.rs index 5b994046e6ca..c6052817bfb6 100644 --- a/arrow-data/src/transform/mod.rs +++ b/arrow-data/src/transform/mod.rs @@ -33,6 +33,7 @@ mod boolean; mod fixed_binary; mod fixed_size_list; mod list; +mod list_view; mod null; mod primitive; mod run; @@ -265,10 +266,9 @@ fn build_extend(array: &ArrayData) -> Extend<'_> { DataType::LargeUtf8 | DataType::LargeBinary => variable_size::build_extend::(array), DataType::BinaryView | DataType::Utf8View => unreachable!("should use build_extend_view"), DataType::Map(_, _) | DataType::List(_) => list::build_extend::(array), - DataType::ListView(_) | DataType::LargeListView(_) => { - unimplemented!("ListView/LargeListView not implemented") - } DataType::LargeList(_) => list::build_extend::(array), + DataType::ListView(_) => list_view::build_extend::(array), + DataType::LargeListView(_) => list_view::build_extend::(array), DataType::Dictionary(_, _) => unreachable!("should use build_extend_dictionary"), DataType::Struct(_) => structure::build_extend(array), DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array), @@ -313,10 +313,9 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls { DataType::LargeUtf8 | DataType::LargeBinary => variable_size::extend_nulls::, DataType::BinaryView | DataType::Utf8View => primitive::extend_nulls::, DataType::Map(_, _) | DataType::List(_) => list::extend_nulls::, - DataType::ListView(_) | DataType::LargeListView(_) => { - unimplemented!("ListView/LargeListView not implemented") - } DataType::LargeList(_) => list::extend_nulls::, + DataType::ListView(_) => list_view::extend_nulls::, + DataType::LargeListView(_) => list_view::extend_nulls::, DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() { DataType::UInt8 => primitive::extend_nulls::, DataType::UInt16 => primitive::extend_nulls::, @@ -450,7 +449,11 @@ impl<'a> MutableArrayData<'a> { new_buffers(data_type, *capacity) } ( - DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _), + DataType::List(_) + | DataType::LargeList(_) + | DataType::ListView(_) + | DataType::LargeListView(_) + | DataType::FixedSizeList(_, _), Capacities::List(capacity, _), ) => { array_capacity = *capacity; @@ -491,10 +494,11 @@ impl<'a> MutableArrayData<'a> { | DataType::Utf8View | DataType::Interval(_) | DataType::FixedSizeBinary(_) => vec![], - DataType::ListView(_) | DataType::LargeListView(_) => { - unimplemented!("ListView/LargeListView not implemented") - } - DataType::Map(_, _) | DataType::List(_) | DataType::LargeList(_) => { + DataType::Map(_, _) + | DataType::List(_) + | DataType::LargeList(_) + | DataType::ListView(_) + | DataType::LargeListView(_) => { let children = arrays .iter() .map(|array| &array.child_data()[0]) @@ -785,7 +789,12 @@ impl<'a> MutableArrayData<'a> { b.insert(0, data.buffer1.into()); b } - DataType::Utf8 | DataType::Binary | DataType::LargeUtf8 | DataType::LargeBinary => { + DataType::Utf8 + | DataType::Binary + | DataType::LargeUtf8 + | DataType::LargeBinary + | DataType::ListView(_) + | DataType::LargeListView(_) => { vec![data.buffer1.into(), data.buffer2.into()] } DataType::Union(_, mode) => { diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs index db7758047c43..5f690e9a6734 100644 --- a/arrow-row/src/lib.rs +++ b/arrow-row/src/lib.rs @@ -913,9 +913,13 @@ impl RowConverter { 0, "can't construct Rows instance from array with nulls" ); + let (offsets, values, _) = array.into_parts(); + let offsets = offsets.iter().map(|&i| i.as_usize()).collect(); + // Try zero-copy, if it does not succeed, fall back to copying the values. + let buffer = values.into_vec().unwrap_or_else(|values| values.to_vec()); Rows { - buffer: array.values().to_vec(), - offsets: array.offsets().iter().map(|&i| i.as_usize()).collect(), + buffer, + offsets, config: RowConfig { fields: Arc::clone(&self.fields), validate_utf8: true, @@ -2474,6 +2478,19 @@ mod tests { assert!(rows.row(3) < rows.row(0)); } + #[test] + fn test_from_binary_shared_buffer() { + let converter = RowConverter::new(vec![SortField::new(DataType::Binary)]).unwrap(); + let array = Arc::new(BinaryArray::from_iter_values([&[0xFF]])) as _; + let rows = converter.convert_columns(&[array]).unwrap(); + let binary_rows = rows.try_into_binary().expect("known-small rows"); + let _binary_rows_shared_buffer = binary_rows.clone(); + + let parsed = converter.from_binary(binary_rows); + + converter.convert_rows(parsed.iter()).unwrap(); + } + #[test] #[should_panic(expected = "Encountered non UTF-8 data")] fn test_invalid_utf8() { diff --git a/arrow-select/src/concat.rs b/arrow-select/src/concat.rs index 83bc5c2763d2..3bfdd31ccf2d 100644 --- a/arrow-select/src/concat.rs +++ b/arrow-select/src/concat.rs @@ -37,7 +37,9 @@ use arrow_array::builder::{ use arrow_array::cast::AsArray; use arrow_array::types::*; use arrow_array::*; -use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer, OffsetBuffer}; +use arrow_buffer::{ + ArrowNativeType, BooleanBufferBuilder, MutableBuffer, NullBuffer, OffsetBuffer, ScalarBuffer, +}; use arrow_data::ArrayDataBuilder; use arrow_data::transform::{Capacities, MutableArrayData}; use arrow_schema::{ArrowError, DataType, FieldRef, Fields, SchemaRef}; @@ -206,6 +208,63 @@ fn concat_lists( Ok(Arc::new(array)) } +fn concat_list_view( + arrays: &[&dyn Array], + field: &FieldRef, +) -> Result { + let mut output_len = 0; + let mut list_has_nulls = false; + + let lists = arrays + .iter() + .map(|x| x.as_list_view::()) + .inspect(|l| { + output_len += l.len(); + list_has_nulls |= l.null_count() != 0; + }) + .collect::>(); + + let lists_nulls = list_has_nulls.then(|| { + let mut nulls = BooleanBufferBuilder::new(output_len); + for l in &lists { + match l.nulls() { + Some(n) => nulls.append_buffer(n.inner()), + None => nulls.append_n(l.len(), true), + } + } + NullBuffer::new(nulls.finish()) + }); + + let values: Vec<&dyn Array> = lists.iter().map(|l| l.values().as_ref()).collect(); + + let concatenated_values = concat(values.as_slice())?; + + let sizes: ScalarBuffer = lists.iter().flat_map(|x| x.sizes()).copied().collect(); + + let mut offsets = MutableBuffer::with_capacity(lists.iter().map(|l| l.offsets().len()).sum()); + let mut global_offset = OffsetSize::zero(); + for l in lists.iter() { + for &offset in l.offsets() { + offsets.push(offset + global_offset); + } + + // advance the offsets + global_offset += OffsetSize::from_usize(l.values().len()).unwrap(); + } + + let offsets = ScalarBuffer::from(offsets); + + let array = GenericListViewArray::try_new( + field.clone(), + offsets, + sizes, + concatenated_values, + lists_nulls, + )?; + + Ok(Arc::new(array)) +} + fn concat_primitives(arrays: &[&dyn Array]) -> Result { let mut builder = PrimitiveBuilder::::with_capacity(arrays.iter().map(|a| a.len()).sum()) .with_data_type(arrays[0].data_type().clone()); @@ -422,6 +481,8 @@ pub fn concat(arrays: &[&dyn Array]) -> Result { } DataType::List(field) => concat_lists::(arrays, field), DataType::LargeList(field) => concat_lists::(arrays, field), + DataType::ListView(field) => concat_list_view::(arrays, field), + DataType::LargeListView(field) => concat_list_view::(arrays, field), DataType::Struct(fields) => concat_structs(arrays, fields), DataType::Utf8 => concat_bytes::(arrays), DataType::LargeUtf8 => concat_bytes::(arrays), @@ -500,7 +561,9 @@ pub fn concat_batches<'a>( #[cfg(test)] mod tests { use super::*; - use arrow_array::builder::{GenericListBuilder, StringDictionaryBuilder}; + use arrow_array::builder::{ + GenericListBuilder, Int64Builder, ListViewBuilder, StringDictionaryBuilder, + }; use arrow_schema::{Field, Schema}; use std::fmt::Debug; @@ -768,7 +831,7 @@ mod tests { #[test] fn test_concat_primitive_list_arrays() { - let list1 = vec![ + let list1 = [ Some(vec![Some(-1), Some(-1), Some(2), None, None]), Some(vec![]), None, @@ -776,14 +839,14 @@ mod tests { ]; let list1_array = ListArray::from_iter_primitive::(list1.clone()); - let list2 = vec![ + let list2 = [ None, Some(vec![Some(100), None, Some(101)]), Some(vec![Some(102)]), ]; let list2_array = ListArray::from_iter_primitive::(list2.clone()); - let list3 = vec![Some(vec![Some(1000), Some(1001)])]; + let list3 = [Some(vec![Some(1000), Some(1001)])]; let list3_array = ListArray::from_iter_primitive::(list3.clone()); let array_result = concat(&[&list1_array, &list2_array, &list3_array]).unwrap(); @@ -796,7 +859,7 @@ mod tests { #[test] fn test_concat_primitive_list_arrays_slices() { - let list1 = vec![ + let list1 = [ Some(vec![Some(-1), Some(-1), Some(2), None, None]), Some(vec![]), // In slice None, // In slice @@ -806,7 +869,7 @@ mod tests { let list1_array = list1_array.slice(1, 2); let list1_values = list1.into_iter().skip(1).take(2); - let list2 = vec![ + let list2 = [ None, Some(vec![Some(100), None, Some(101)]), Some(vec![Some(102)]), @@ -825,7 +888,7 @@ mod tests { #[test] fn test_concat_primitive_list_arrays_sliced_lengths() { - let list1 = vec![ + let list1 = [ Some(vec![Some(-1), Some(-1), Some(2), None, None]), // In slice Some(vec![]), // In slice None, // In slice @@ -835,7 +898,7 @@ mod tests { let list1_array = list1_array.slice(0, 3); // no offset, but not all values let list1_values = list1.into_iter().take(3); - let list2 = vec![ + let list2 = [ None, Some(vec![Some(100), None, Some(101)]), Some(vec![Some(102)]), @@ -856,7 +919,7 @@ mod tests { #[test] fn test_concat_primitive_fixed_size_list_arrays() { - let list1 = vec![ + let list1 = [ Some(vec![Some(-1), None]), None, Some(vec![Some(10), Some(20)]), @@ -864,7 +927,7 @@ mod tests { let list1_array = FixedSizeListArray::from_iter_primitive::(list1.clone(), 2); - let list2 = vec![ + let list2 = [ None, Some(vec![Some(100), None]), Some(vec![Some(102), Some(103)]), @@ -872,7 +935,7 @@ mod tests { let list2_array = FixedSizeListArray::from_iter_primitive::(list2.clone(), 2); - let list3 = vec![Some(vec![Some(1000), Some(1001)])]; + let list3 = [Some(vec![Some(1000), Some(1001)])]; let list3_array = FixedSizeListArray::from_iter_primitive::(list3.clone(), 2); @@ -885,6 +948,105 @@ mod tests { assert_eq!(array_result.as_ref(), &array_expected as &dyn Array); } + #[test] + fn test_concat_list_view_arrays() { + let list1 = [ + Some(vec![Some(-1), None]), + None, + Some(vec![Some(10), Some(20)]), + ]; + let mut list1_array = ListViewBuilder::new(Int64Builder::new()); + for v in list1.iter() { + list1_array.append_option(v.clone()); + } + let list1_array = list1_array.finish(); + + let list2 = [ + None, + Some(vec![Some(100), None]), + Some(vec![Some(102), Some(103)]), + ]; + let mut list2_array = ListViewBuilder::new(Int64Builder::new()); + for v in list2.iter() { + list2_array.append_option(v.clone()); + } + let list2_array = list2_array.finish(); + + let list3 = [Some(vec![Some(1000), Some(1001)])]; + let mut list3_array = ListViewBuilder::new(Int64Builder::new()); + for v in list3.iter() { + list3_array.append_option(v.clone()); + } + let list3_array = list3_array.finish(); + + let array_result = concat(&[&list1_array, &list2_array, &list3_array]).unwrap(); + + let expected: Vec<_> = list1.into_iter().chain(list2).chain(list3).collect(); + let mut array_expected = ListViewBuilder::new(Int64Builder::new()); + for v in expected.iter() { + array_expected.append_option(v.clone()); + } + let array_expected = array_expected.finish(); + + assert_eq!(array_result.as_ref(), &array_expected as &dyn Array); + } + + #[test] + fn test_concat_sliced_list_view_arrays() { + let list1 = [ + Some(vec![Some(-1), None]), + None, + Some(vec![Some(10), Some(20)]), + ]; + let mut list1_array = ListViewBuilder::new(Int64Builder::new()); + for v in list1.iter() { + list1_array.append_option(v.clone()); + } + let list1_array = list1_array.finish(); + + let list2 = [ + None, + Some(vec![Some(100), None]), + Some(vec![Some(102), Some(103)]), + ]; + let mut list2_array = ListViewBuilder::new(Int64Builder::new()); + for v in list2.iter() { + list2_array.append_option(v.clone()); + } + let list2_array = list2_array.finish(); + + let list3 = [Some(vec![Some(1000), Some(1001)])]; + let mut list3_array = ListViewBuilder::new(Int64Builder::new()); + for v in list3.iter() { + list3_array.append_option(v.clone()); + } + let list3_array = list3_array.finish(); + + // Concat sliced arrays. + // ListView slicing will slice the offset/sizes but preserve the original values child. + let array_result = concat(&[ + &list1_array.slice(1, 2), + &list2_array.slice(1, 2), + &list3_array.slice(0, 1), + ]) + .unwrap(); + + let expected: Vec<_> = vec![ + None, + Some(vec![Some(10), Some(20)]), + Some(vec![Some(100), None]), + Some(vec![Some(102), Some(103)]), + Some(vec![Some(1000), Some(1001)]), + ]; + let mut array_expected = ListViewBuilder::new(Int64Builder::new()); + for v in expected.iter() { + array_expected.append_option(v.clone()); + } + let array_expected = array_expected.finish(); + + assert_eq!(array_result.as_ref(), &array_expected as &dyn Array); + } + #[test] fn test_concat_struct_arrays() { let field = Arc::new(Field::new("field", DataType::Int64, true)); diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs index dace2bab728f..6a5ba13c950a 100644 --- a/arrow-select/src/filter.rs +++ b/arrow-select/src/filter.rs @@ -58,7 +58,13 @@ pub struct SlicesIterator<'a>(BitSliceIterator<'a>); impl<'a> SlicesIterator<'a> { /// Creates a new iterator from a [BooleanArray] pub fn new(filter: &'a BooleanArray) -> Self { - Self(filter.values().set_slices()) + filter.values().into() + } +} + +impl<'a> From<&'a BooleanBuffer> for SlicesIterator<'a> { + fn from(filter: &'a BooleanBuffer) -> Self { + Self(filter.set_slices()) } } @@ -122,6 +128,12 @@ pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray { /// Returns a filtered `values` [`Array`] where the corresponding elements of /// `predicate` are `true`. /// +/// If multiple arrays (or record batches) need to be filtered using the same predicate array, +/// consider using [FilterBuilder] to create a single [FilterPredicate] and then +/// calling [FilterPredicate::filter_record_batch]. +/// In contrast to this function, it is then the responsibility of the caller +/// to use [FilterBuilder::optimize] if appropriate. +/// /// # See also /// * [`FilterBuilder`] for more control over the filtering process. /// * [`filter_record_batch`] to filter a [`RecordBatch`] @@ -168,25 +180,28 @@ fn multiple_arrays(data_type: &DataType) -> bool { /// `predicate` are true. /// /// This is the equivalent of calling [filter] on each column of the [RecordBatch]. +/// +/// If multiple record batches (or arrays) need to be filtered using the same predicate array, +/// consider using [FilterBuilder] to create a single [FilterPredicate] and then +/// calling [FilterPredicate::filter_record_batch]. +/// In contrast to this function, it is then the responsibility of the caller +/// to use [FilterBuilder::optimize] if appropriate. pub fn filter_record_batch( record_batch: &RecordBatch, predicate: &BooleanArray, ) -> Result { let mut filter_builder = FilterBuilder::new(predicate); - if record_batch.num_columns() > 1 { - // Only optimize if filtering more than one column + let num_cols = record_batch.num_columns(); + if num_cols > 1 + || (num_cols > 0 && multiple_arrays(record_batch.schema_ref().field(0).data_type())) + { + // Only optimize if filtering more than one column or if the column contains multiple internal arrays // Otherwise, the overhead of optimization can be more than the benefit filter_builder = filter_builder.optimize(); } let filter = filter_builder.build(); - let filtered_arrays = record_batch - .columns() - .iter() - .map(|a| filter_array(a, &filter)) - .collect::, _>>()?; - let options = RecordBatchOptions::default().with_row_count(Some(filter.count())); - RecordBatch::try_new_with_options(record_batch.schema(), filtered_arrays, &options) + filter.filter_record_batch(record_batch) } /// A builder to construct [`FilterPredicate`] @@ -300,6 +315,31 @@ impl FilterPredicate { filter_array(values, self) } + /// Returns a filtered [`RecordBatch`] containing only the rows that are selected by this + /// [`FilterPredicate`]. + /// + /// This is the equivalent of calling [filter] on each column of the [`RecordBatch`]. + pub fn filter_record_batch( + &self, + record_batch: &RecordBatch, + ) -> Result { + let filtered_arrays = record_batch + .columns() + .iter() + .map(|a| filter_array(a, self)) + .collect::, _>>()?; + + // SAFETY: we know that the set of filtered arrays will match the schema of the original + // record batch + unsafe { + Ok(RecordBatch::new_unchecked( + record_batch.schema(), + filtered_arrays, + self.count, + )) + } + } + /// Number of rows being selected based on this [`FilterPredicate`] pub fn count(&self) -> usize { self.count @@ -346,6 +386,12 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result { Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate))) } + DataType::ListView(_) => { + Ok(Arc::new(filter_list_view::(values.as_list_view(), predicate))) + } + DataType::LargeListView(_) => { + Ok(Arc::new(filter_list_view::(values.as_list_view(), predicate))) + } DataType::RunEndEncoded(_, _) => { downcast_run_array!{ values => Ok(Arc::new(filter_run_end_array(values, predicate)?)), @@ -860,6 +906,34 @@ fn filter_sparse_union( }) } +/// `filter` implementation for list views +fn filter_list_view( + array: &GenericListViewArray, + predicate: &FilterPredicate, +) -> GenericListViewArray { + let filtered_offsets = filter_native::(array.offsets(), predicate); + let filtered_sizes = filter_native::(array.sizes(), predicate); + + // Filter the nulls + let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) { + let buffer = BooleanBuffer::new(nulls, 0, predicate.count); + + Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) }) + } else { + None + }; + + let list_data = ArrayDataBuilder::new(array.data_type().clone()) + .nulls(nulls) + .buffers(vec![filtered_offsets, filtered_sizes]) + .child_data(vec![array.values().to_data()]) + .len(predicate.count); + + let list_data = unsafe { list_data.build_unchecked() }; + + GenericListViewArray::from(list_data) +} + #[cfg(test)] mod tests { use super::*; @@ -1370,6 +1444,69 @@ mod tests { assert_eq!(&make_array(expected), &result); } + fn test_case_filter_list_view() { + // [[1, 2], null, [], [3,4]] + let mut list_array = GenericListViewBuilder::::new(Int32Builder::new()); + list_array.append_value([Some(1), Some(2)]); + list_array.append_null(); + list_array.append_value([]); + list_array.append_value([Some(3), Some(4)]); + + let list_array = list_array.finish(); + let predicate = BooleanArray::from_iter([true, false, true, false]); + + // Filter result: [[1, 2], []] + let filtered = filter(&list_array, &predicate) + .unwrap() + .as_list_view::() + .clone(); + + let mut expected = + GenericListViewBuilder::::with_capacity(Int32Builder::with_capacity(5), 3); + expected.append_value([Some(1), Some(2)]); + expected.append_value([]); + let expected = expected.finish(); + + assert_eq!(&filtered, &expected); + } + + fn test_case_filter_sliced_list_view() { + // [[1, 2], null, [], [3,4]] + let mut list_array = + GenericListViewBuilder::::with_capacity(Int32Builder::with_capacity(6), 4); + list_array.append_value([Some(1), Some(2)]); + list_array.append_null(); + list_array.append_value([]); + list_array.append_value([Some(3), Some(4)]); + + let list_array = list_array.finish(); + + // Sliced: [null, [], [3, 4]] + let sliced = list_array.slice(1, 3); + let predicate = BooleanArray::from_iter([false, false, true]); + + // Filter result: [[1, 2], []] + let filtered = filter(&sliced, &predicate) + .unwrap() + .as_list_view::() + .clone(); + + let mut expected = GenericListViewBuilder::::new(Int32Builder::new()); + expected.append_value([Some(3), Some(4)]); + let expected = expected.finish(); + + assert_eq!(&filtered, &expected); + } + + #[test] + fn test_filter_list_view_array() { + test_case_filter_list_view::(); + test_case_filter_list_view::(); + + test_case_filter_sliced_list_view::(); + test_case_filter_sliced_list_view::(); + } + #[test] fn test_slice_iterator_bits() { let filter_values = (0..64).map(|i| i == 1).collect::>(); diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs index dfe6903dc4e3..eec4ffa14e72 100644 --- a/arrow-select/src/take.rs +++ b/arrow-select/src/take.rs @@ -218,6 +218,12 @@ fn take_impl( DataType::LargeList(_) => { Ok(Arc::new(take_list::<_, Int64Type>(values.as_list(), indices)?)) } + DataType::ListView(_) => { + Ok(Arc::new(take_list_view::<_, Int32Type>(values.as_list_view(), indices)?)) + } + DataType::LargeListView(_) => { + Ok(Arc::new(take_list_view::<_, Int64Type>(values.as_list_view(), indices)?)) + } DataType::FixedSizeList(_, length) => { let values = values .as_any() @@ -621,6 +627,33 @@ where Ok(GenericListArray::::from(list_data)) } +fn take_list_view( + values: &GenericListViewArray, + indices: &PrimitiveArray, +) -> Result, ArrowError> +where + IndexType: ArrowPrimitiveType, + OffsetType: ArrowPrimitiveType, + OffsetType::Native: OffsetSizeTrait, +{ + let taken_offsets = take_native(values.offsets(), indices); + let taken_sizes = take_native(values.sizes(), indices); + let nulls = take_nulls(values.nulls(), indices); + + let list_view_data = ArrayDataBuilder::new(values.data_type().clone()) + .len(indices.len()) + .nulls(nulls) + .buffers(vec![taken_offsets.into(), taken_sizes.into()]) + .child_data(vec![values.values().to_data()]); + + // SAFETY: all buffers and child nodes for ListView added in constructor + let list_view_data = unsafe { list_view_data.build_unchecked() }; + + Ok(GenericListViewArray::::from( + list_view_data, + )) +} + /// `take` implementation for `FixedSizeListArray` /// /// Calculates the index and indexed offset for the inner array, @@ -980,6 +1013,7 @@ mod tests { use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use arrow_data::ArrayData; use arrow_schema::{Field, Fields, TimeUnit, UnionFields}; + use num_traits::ToPrimitive; fn test_take_decimal_arrays( data: Vec>, @@ -1821,6 +1855,55 @@ mod tests { }}; } + fn test_take_list_view_generic( + values: Vec>>>, + take_indices: Vec>, + expected: Vec>>>, + mapper: F, + ) where + F: Fn(GenericListViewArray) -> GenericListViewArray, + { + let mut list_view_array = + GenericListViewBuilder::::new(PrimitiveBuilder::::new()); + + for value in values { + list_view_array.append_option(value); + } + let list_view_array = list_view_array.finish(); + let list_view_array = mapper(list_view_array); + + let mut indices = UInt64Builder::new(); + for idx in take_indices { + indices.append_option(idx.map(|i| i.to_u64().unwrap())); + } + let indices = indices.finish(); + + let taken = take(&list_view_array, &indices, None) + .unwrap() + .as_list_view() + .clone(); + + let mut expected_array = + GenericListViewBuilder::::new(PrimitiveBuilder::::new()); + for value in expected { + expected_array.append_option(value); + } + let expected_array = expected_array.finish(); + + assert_eq!(taken, expected_array); + } + + macro_rules! list_view_test_case { + (values: $values:expr, indices: $indices:expr, expected: $expected: expr) => {{ + test_take_list_view_generic::($values, $indices, $expected, |x| x); + test_take_list_view_generic::($values, $indices, $expected, |x| x); + }}; + (values: $values:expr, transform: $fn:expr, indices: $indices:expr, expected: $expected: expr) => {{ + test_take_list_view_generic::($values, $indices, $expected, $fn); + test_take_list_view_generic::($values, $indices, $expected, $fn); + }}; + } + fn do_take_fixed_size_list_test( length: ::Native, input_data: Vec>>>, @@ -1871,6 +1954,72 @@ mod tests { test_take_list_with_nulls!(i64, LargeList, LargeListArray); } + #[test] + fn test_test_take_list_view_reversed() { + // Take reversed indices + list_view_test_case! { + values: vec![ + Some(vec![Some(1), None, Some(3)]), + None, + Some(vec![Some(7), Some(8), None]), + ], + indices: vec![Some(2), Some(1), Some(0)], + expected: vec![ + Some(vec![Some(7), Some(8), None]), + None, + Some(vec![Some(1), None, Some(3)]), + ] + } + } + + #[test] + fn test_take_list_view_null_indices() { + // Take with null indices + list_view_test_case! { + values: vec![ + Some(vec![Some(1), None, Some(3)]), + None, + Some(vec![Some(7), Some(8), None]), + ], + indices: vec![None, Some(0), None], + expected: vec![None, Some(vec![Some(1), None, Some(3)]), None] + } + } + + #[test] + fn test_take_list_view_null_values() { + // Take at null values + list_view_test_case! { + values: vec![ + Some(vec![Some(1), None, Some(3)]), + None, + Some(vec![Some(7), Some(8), None]), + ], + indices: vec![Some(1), Some(1), Some(1), None, None], + expected: vec![None; 5] + } + } + + #[test] + fn test_take_list_view_sliced() { + // Take null indices/values, with slicing. + list_view_test_case! { + values: vec![ + Some(vec![Some(1)]), + None, + None, + Some(vec![Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + None, + ], + transform: |l| l.slice(2, 4), + indices: vec![Some(0), Some(3), None, Some(1), Some(2)], + expected: vec![ + None, None, None, Some(vec![Some(2), Some(3)]), Some(vec![Some(4), Some(5)]) + ] + } + } + #[test] fn test_take_fixed_size_list() { do_take_fixed_size_list_test::( diff --git a/arrow-select/src/zip.rs b/arrow-select/src/zip.rs index 2efd2e749921..c202be6b6299 100644 --- a/arrow-select/src/zip.rs +++ b/arrow-select/src/zip.rs @@ -17,8 +17,9 @@ //! [`zip`]: Combine values from two arrays based on boolean mask -use crate::filter::SlicesIterator; +use crate::filter::{SlicesIterator, prep_null_mask_filter}; use arrow_array::*; +use arrow_buffer::BooleanBuffer; use arrow_data::transform::MutableArrayData; use arrow_schema::ArrowError; @@ -127,7 +128,8 @@ pub fn zip( // keep track of how much is filled let mut filled = 0; - SlicesIterator::new(mask).for_each(|(start, end)| { + let mask = maybe_prep_null_mask_filter(mask); + SlicesIterator::from(&mask).for_each(|(start, end)| { // the gap needs to be filled with falsy values if start > filled { if falsy_is_scalar { @@ -166,9 +168,22 @@ pub fn zip( Ok(make_array(data)) } +fn maybe_prep_null_mask_filter(predicate: &BooleanArray) -> BooleanBuffer { + // Nulls are treated as false + if predicate.null_count() == 0 { + predicate.values().clone() + } else { + let cleaned = prep_null_mask_filter(predicate); + let (boolean_buffer, _) = cleaned.into_parts(); + boolean_buffer + } +} + #[cfg(test)] mod test { use super::*; + use arrow_array::cast::AsArray; + use arrow_buffer::{BooleanBuffer, NullBuffer}; #[test] fn test_zip_kernel_one() { @@ -279,4 +294,110 @@ mod test { let expected = Int32Array::from(vec![None, None, Some(42), Some(42), None]); assert_eq!(actual, &expected); } + + #[test] + fn test_zip_primitive_array_with_nulls_is_mask_should_be_treated_as_false() { + let truthy = Int32Array::from_iter_values(vec![1, 2, 3, 4, 5, 6]); + let falsy = Int32Array::from_iter_values(vec![7, 8, 9, 10, 11, 12]); + + let mask = { + let booleans = BooleanBuffer::from(vec![true, true, false, true, false, false]); + let nulls = NullBuffer::from(vec![ + true, true, true, + false, // null treated as false even though in the original mask it was true + true, true, + ]); + BooleanArray::new(booleans, Some(nulls)) + }; + let out = zip(&mask, &truthy, &falsy).unwrap(); + let actual = out.as_any().downcast_ref::().unwrap(); + let expected = Int32Array::from(vec![ + Some(1), + Some(2), + Some(9), + Some(10), // true in mask but null + Some(11), + Some(12), + ]); + assert_eq!(actual, &expected); + } + + #[test] + fn test_zip_kernel_primitive_scalar_with_boolean_array_mask_with_nulls_should_be_treated_as_false() + { + let scalar_truthy = Scalar::new(Int32Array::from_value(42, 1)); + let scalar_falsy = Scalar::new(Int32Array::from_value(123, 1)); + + let mask = { + let booleans = BooleanBuffer::from(vec![true, true, false, true, false, false]); + let nulls = NullBuffer::from(vec![ + true, true, true, + false, // null treated as false even though in the original mask it was true + true, true, + ]); + BooleanArray::new(booleans, Some(nulls)) + }; + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); + let actual = out.as_any().downcast_ref::().unwrap(); + let expected = Int32Array::from(vec![ + Some(42), + Some(42), + Some(123), + Some(123), // true in mask but null + Some(123), + Some(123), + ]); + assert_eq!(actual, &expected); + } + + #[test] + fn test_zip_string_array_with_nulls_is_mask_should_be_treated_as_false() { + let truthy = StringArray::from_iter_values(vec!["1", "2", "3", "4", "5", "6"]); + let falsy = StringArray::from_iter_values(vec!["7", "8", "9", "10", "11", "12"]); + + let mask = { + let booleans = BooleanBuffer::from(vec![true, true, false, true, false, false]); + let nulls = NullBuffer::from(vec![ + true, true, true, + false, // null treated as false even though in the original mask it was true + true, true, + ]); + BooleanArray::new(booleans, Some(nulls)) + }; + let out = zip(&mask, &truthy, &falsy).unwrap(); + let actual = out.as_string::(); + let expected = StringArray::from_iter_values(vec![ + "1", "2", "9", "10", // true in mask but null + "11", "12", + ]); + assert_eq!(actual, &expected); + } + + #[test] + fn test_zip_kernel_large_string_scalar_with_boolean_array_mask_with_nulls_should_be_treated_as_false() + { + let scalar_truthy = Scalar::new(LargeStringArray::from_iter_values(["test"])); + let scalar_falsy = Scalar::new(LargeStringArray::from_iter_values(["something else"])); + + let mask = { + let booleans = BooleanBuffer::from(vec![true, true, false, true, false, false]); + let nulls = NullBuffer::from(vec![ + true, true, true, + false, // null treated as false even though in the original mask it was true + true, true, + ]); + BooleanArray::new(booleans, Some(nulls)) + }; + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); + let actual = out.as_any().downcast_ref::().unwrap(); + let expected = LargeStringArray::from_iter(vec![ + Some("test"), + Some("test"), + Some("something else"), + Some("something else"), // true in mask but null + Some("something else"), + Some("something else"), + ]); + assert_eq!(actual, &expected); + } } diff --git a/arrow/benches/cast_kernels.rs b/arrow/benches/cast_kernels.rs index a54529c8d108..040c118a1e83 100644 --- a/arrow/benches/cast_kernels.rs +++ b/arrow/benches/cast_kernels.rs @@ -359,6 +359,46 @@ fn add_benchmark(c: &mut Criterion) { c.bench_function("cast binary view to string view", |b| { b.iter(|| cast_array(&binary_view_array, DataType::Utf8View)) }); + + c.bench_function("cast string single run to ree", |b| { + let source_array = StringArray::from(vec!["a"; 8192]); + let array_ref = Arc::new(source_array) as ArrayRef; + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int32, false)), + Arc::new(Field::new("values", DataType::Utf8, true)), + ); + b.iter(|| cast(&array_ref, &target_type).unwrap()); + }); + + c.bench_function("cast runs of 10 string to ree", |b| { + let source_array: Int32Array = (0..8192).map(|i| i / 10).collect(); + let array_ref = Arc::new(source_array) as ArrayRef; + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int32, false)), + Arc::new(Field::new("values", DataType::Int32, true)), + ); + b.iter(|| cast(&array_ref, &target_type).unwrap()); + }); + + c.bench_function("cast runs of 1000 int32s to ree", |b| { + let source_array: Int32Array = (0..8192).map(|i| i / 1000).collect(); + let array_ref = Arc::new(source_array) as ArrayRef; + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int32, false)), + Arc::new(Field::new("values", DataType::Int32, true)), + ); + b.iter(|| cast(&array_ref, &target_type).unwrap()); + }); + + c.bench_function("cast no runs of int32s to ree", |b| { + let source_array: Int32Array = (0..8192).collect(); + let array_ref = Arc::new(source_array) as ArrayRef; + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int32, false)), + Arc::new(Field::new("values", DataType::Int32, true)), + ); + b.iter(|| cast(&array_ref, &target_type).unwrap()); + }); } criterion_group!(benches, add_benchmark); diff --git a/arrow/tests/array_equal.rs b/arrow/tests/array_equal.rs index 7fc8b0be7a3d..381054a25df5 100644 --- a/arrow/tests/array_equal.rs +++ b/arrow/tests/array_equal.rs @@ -22,11 +22,17 @@ use arrow::array::{ StringDictionaryBuilder, StructArray, UnionBuilder, make_array, }; use arrow::datatypes::{Int16Type, Int32Type}; -use arrow_array::builder::{StringBuilder, StringViewBuilder, StructBuilder}; -use arrow_array::{DictionaryArray, FixedSizeListArray, StringViewArray}; +use arrow_array::builder::{ + GenericListViewBuilder, StringBuilder, StringViewBuilder, StructBuilder, +}; +use arrow_array::cast::AsArray; +use arrow_array::{ + DictionaryArray, FixedSizeListArray, GenericListViewArray, PrimitiveArray, StringViewArray, +}; use arrow_buffer::{Buffer, ToByteSlice}; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{DataType, Field, Fields}; +use arrow_select::take::take; use std::sync::Arc; #[test] @@ -756,6 +762,125 @@ fn test_fixed_list_offsets() { test_equal(&a_slice, &b_slice, true); } +fn create_list_view_array< + O: OffsetSizeTrait, + U: IntoIterator>, + T: IntoIterator>, +>( + data: T, +) -> GenericListViewArray { + let mut builder = GenericListViewBuilder::::new(Int32Builder::new()); + for d in data { + if let Some(v) = d { + builder.append_value(v); + } else { + builder.append_null(); + } + } + + builder.finish() +} + +fn test_test_list_view_array() { + let a = create_list_view_array::([ + None, + Some(vec![Some(1), None, Some(2)]), + Some(vec![Some(3), Some(4), Some(5), None]), + ]); + let b = create_list_view_array::([ + None, + Some(vec![Some(1), None, Some(2)]), + Some(vec![Some(3), Some(4), Some(5), None]), + ]); + + test_equal(&a, &b, true); + + // Simple non-matching arrays by reordering + let b = create_list_view_array::([ + Some(vec![Some(3), Some(4), Some(5), None]), + Some(vec![Some(1), None, Some(2)]), + ]); + test_equal(&a, &b, false); + + // reorder using take yields equal values + let indices: PrimitiveArray = vec![None, Some(1), Some(0)].into(); + let b = take(&b, &indices, None) + .unwrap() + .as_list_view::() + .clone(); + + test_equal(&a, &b, true); + + // Slicing one side yields unequal again + let a = a.slice(1, 2); + + test_equal(&a, &b, false); + + // Slicing the other to match makes them equal again + let b = b.slice(1, 2); + + test_equal(&a, &b, true); +} + +// Special test for List>. +// This tests the equal_ranges kernel +fn test_sliced_list_of_list_view() { + // First list view is created using the builder, with elements not deduplicated. + let mut a = ListBuilder::new(GenericListViewBuilder::::new(Int32Builder::new())); + + a.append_value([Some(vec![Some(1), Some(2), Some(3)]), Some(vec![])]); + a.append_null(); + a.append_value([ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(6)]), + ]); + + let a = a.finish(); + // a = [[[1,2,3], []], null, [[4, null], [5], null, [6]]] + + // First list view is created using the builder, with elements not deduplicated. + let mut b = ListBuilder::new(GenericListViewBuilder::::new(Int32Builder::new())); + + // Add an extra row that we will slice off, adjust the List offsets + b.append_value([Some(vec![Some(0), Some(0), Some(0)])]); + b.append_value([Some(vec![Some(1), Some(2), Some(3)]), Some(vec![])]); + b.append_null(); + b.append_value([ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(6)]), + ]); + + let b = b.finish(); + // b = [[[0, 0, 0]], [[1,2,3], []], null, [[4, null], [5], null, [6]]] + let b = b.slice(1, 3); + // b = [[[1,2,3], []], null, [[4, null], [5], null, [6]]] but the outer ListArray + // has an offset + + test_equal(&a, &b, true); +} + +#[test] +fn test_list_view_array() { + test_test_list_view_array::(); +} + +#[test] +fn test_large_list_view_array() { + test_test_list_view_array::(); +} + +#[test] +fn test_nested_list_view_array() { + test_sliced_list_of_list_view::(); +} + +#[test] +fn test_nested_large_list_view_array() { + test_sliced_list_of_list_view::(); +} + #[test] fn test_struct_equal() { let strings: ArrayRef = Arc::new(StringArray::from(vec![ diff --git a/parquet-variant-compute/src/shred_variant.rs b/parquet-variant-compute/src/shred_variant.rs index d5635291f712..f8158b2211a2 100644 --- a/parquet-variant-compute/src/shred_variant.rs +++ b/parquet-variant-compute/src/shred_variant.rs @@ -331,22 +331,11 @@ mod tests { use parquet_variant::{ObjectBuilder, ReadOnlyMetadataBuilder, Variant, VariantBuilder}; use std::sync::Arc; - fn create_test_variant_array(values: Vec>>) -> VariantArray { - let mut builder = VariantArrayBuilder::new(values.len()); - for value in values { - match value { - Some(v) => builder.append_variant(v), - None => builder.append_null(), - } - } - builder.build() - } - #[test] fn test_already_shredded_input_error() { // Create a VariantArray that already has typed_value_field // First create a valid VariantArray, then extract its parts to construct a shredded one - let temp_array = create_test_variant_array(vec![Some(Variant::from("test"))]); + let temp_array = VariantArray::from_iter(vec![Some(Variant::from("test"))]); let metadata = temp_array.metadata_field().clone(); let value = temp_array.value_field().unwrap().clone(); let typed_value = Arc::new(Int64Array::from(vec![42])) as ArrayRef; @@ -375,7 +364,7 @@ mod tests { #[test] fn test_unsupported_list_schema() { - let input = create_test_variant_array(vec![Some(Variant::from(42))]); + let input = VariantArray::from_iter([Variant::from(42)]); let list_schema = DataType::List(Arc::new(Field::new("item", DataType::Int64, true))); shred_variant(&input, &list_schema).expect_err("unsupported"); } @@ -383,7 +372,7 @@ mod tests { #[test] fn test_primitive_shredding_comprehensive() { // Test mixed scenarios in a single array - let input = create_test_variant_array(vec![ + let input = VariantArray::from_iter(vec![ Some(Variant::from(42i64)), // successful shred Some(Variant::from("hello")), // failed shred (string) Some(Variant::from(100i64)), // successful shred @@ -448,10 +437,10 @@ mod tests { #[test] fn test_primitive_different_target_types() { - let input = create_test_variant_array(vec![ - Some(Variant::from(42i32)), - Some(Variant::from(3.15f64)), - Some(Variant::from("not_a_number")), + let input = VariantArray::from_iter(vec![ + Variant::from(42i32), + Variant::from(3.15f64), + Variant::from("not_a_number"), ]); // Test Int32 target @@ -882,10 +871,7 @@ mod tests { #[test] fn test_spec_compliance() { - let input = create_test_variant_array(vec![ - Some(Variant::from(42i64)), - Some(Variant::from("hello")), - ]); + let input = VariantArray::from_iter(vec![Variant::from(42i64), Variant::from("hello")]); let result = shred_variant(&input, &DataType::Int64).unwrap(); diff --git a/parquet-variant-compute/src/type_conversion.rs b/parquet-variant-compute/src/type_conversion.rs index 28087d7541e4..d15664f5af9e 100644 --- a/parquet-variant-compute/src/type_conversion.rs +++ b/parquet-variant-compute/src/type_conversion.rs @@ -17,8 +17,7 @@ //! Module for transforming a typed arrow `Array` to `VariantArray`. -use arrow::array::ArrowNativeTypeOp; -use arrow::compute::DecimalCast; +use arrow::compute::{DecimalCast, rescale_decimal}; use arrow::datatypes::{ self, ArrowPrimitiveType, ArrowTimestampType, Decimal32Type, Decimal64Type, Decimal128Type, DecimalType, @@ -190,90 +189,6 @@ where } } -/// Rescale a decimal from (input_precision, input_scale) to (output_precision, output_scale) -/// and return the scaled value if it fits the output precision. Similar to the implementation in -/// decimal.rs in arrow-cast. -pub(crate) fn rescale_decimal( - value: I::Native, - input_precision: u8, - input_scale: i8, - output_precision: u8, - output_scale: i8, -) -> Option -where - I::Native: DecimalCast, - O::Native: DecimalCast, -{ - let delta_scale = output_scale - input_scale; - - let (scaled, is_infallible_cast) = if delta_scale >= 0 { - // O::MAX_FOR_EACH_PRECISION[k] stores 10^k - 1 (e.g., 9, 99, 999, ...). - // Adding 1 yields exactly 10^k without computing a power at runtime. - // Using the precomputed table avoids pow(10, k) and its checked/overflow - // handling, which is faster and simpler for scaling by 10^delta_scale. - let max = O::MAX_FOR_EACH_PRECISION.get(delta_scale as usize)?; - let mul = max.add_wrapping(O::Native::ONE); - - // if the gain in precision (digits) is greater than the multiplication due to scaling - // every number will fit into the output type - // Example: If we are starting with any number of precision 5 [xxxxx], - // then an increase of scale by 3 will have the following effect on the representation: - // [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type - // needs to provide at least 8 digits precision - let is_infallible_cast = input_precision as i8 + delta_scale <= output_precision as i8; - let value = O::Native::from_decimal(value); - let scaled = if is_infallible_cast { - Some(value.unwrap().mul_wrapping(mul)) - } else { - value.and_then(|x| x.mul_checked(mul).ok()) - }; - (scaled, is_infallible_cast) - } else { - // the abs of delta_scale is guaranteed to be > 0, but may also be larger than I::MAX_PRECISION. - // If so, the scale change divides out more digits than the input has precision and the result - // of the cast is always zero. For example, if we try to apply delta_scale=10 a decimal32 value, - // the largest possible result is 999999999/10000000000 = 0.0999999999, which rounds to zero. - // Smaller values (e.g. 1/10000000000) or larger delta_scale (e.g. 999999999/10000000000000) - // produce even smaller results, which also round to zero. In that case, just return zero. - let Some(max) = I::MAX_FOR_EACH_PRECISION.get(delta_scale.unsigned_abs() as usize) else { - return Some(O::Native::ZERO); - }; - let div = max.add_wrapping(I::Native::ONE); - let half = div.div_wrapping(I::Native::ONE.add_wrapping(I::Native::ONE)); - let half_neg = half.neg_wrapping(); - - // div is >= 10 and so this cannot overflow - let d = value.div_wrapping(div); - let r = value.mod_wrapping(div); - - // Round result - let adjusted = match value >= I::Native::ZERO { - true if r >= half => d.add_wrapping(I::Native::ONE), - false if r <= half_neg => d.sub_wrapping(I::Native::ONE), - _ => d, - }; - - // if the reduction of the input number through scaling (dividing) is greater - // than a possible precision loss (plus potential increase via rounding) - // every input number will fit into the output type - // Example: If we are starting with any number of precision 5 [xxxxx], - // then and decrease the scale by 3 will have the following effect on the representation: - // [xxxxx] -> [xx] (+ 1 possibly, due to rounding). - // The rounding may add a digit, so for the cast to be infallible, - // the output type needs to have at least 3 digits of precision. - // e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100: - // [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible - let is_infallible_cast = input_precision as i8 + delta_scale < output_precision as i8; - (O::Native::from_decimal(adjusted), is_infallible_cast) - }; - - if is_infallible_cast { - scaled - } else { - scaled.filter(|v| O::is_valid_decimal_precision(*v, output_precision)) - } -} - /// Convert the value at a specific index in the given array into a `Variant`. macro_rules! non_generic_conversion_single_value { ($array:expr, $cast_fn:expr, $index:expr) => {{ diff --git a/parquet/THRIFT.md b/parquet/THRIFT.md index 56365665070a..599b33f2bce3 100644 --- a/parquet/THRIFT.md +++ b/parquet/THRIFT.md @@ -57,7 +57,7 @@ The `thrift_enum` macro can be used in this instance. ```rust thrift_enum!( - enum Type { +enum Type { BOOLEAN = 0; INT32 = 1; INT64 = 2; @@ -85,6 +85,8 @@ pub enum Type { } ``` +All Rust `enum`s produced with this macro will have `pub` visibility. + ### Unions Thrift unions are a special kind of struct in which only a single field is populated. In this @@ -175,6 +177,9 @@ pub enum ColumnCryptoMetaData { } ``` +All Rust `enum`s produced with either macro will have `pub` visibility. `thrift_union` also allows +for lifetime annotations, but this capability is not currently utilized. + ### Structs The `thrift_struct` macro is used for structs. This macro is a little more flexible than the others diff --git a/parquet/src/arrow/arrow_reader/selection.rs b/parquet/src/arrow/arrow_reader/selection.rs index 9c3caec0b4a5..adbbff1ca2df 100644 --- a/parquet/src/arrow/arrow_reader/selection.rs +++ b/parquet/src/arrow/arrow_reader/selection.rs @@ -1432,4 +1432,33 @@ mod tests { assert_eq!(selection.row_count(), 0); assert_eq!(selection.skipped_row_count(), 0); } + + #[test] + fn test_trim() { + let selection = RowSelection::from(vec![ + RowSelector::skip(34), + RowSelector::select(12), + RowSelector::skip(3), + RowSelector::select(35), + ]); + + let expected = vec![ + RowSelector::skip(34), + RowSelector::select(12), + RowSelector::skip(3), + RowSelector::select(35), + ]; + + assert_eq!(selection.trim().selectors, expected); + + let selection = RowSelection::from(vec![ + RowSelector::skip(34), + RowSelector::select(12), + RowSelector::skip(3), + ]); + + let expected = vec![RowSelector::skip(34), RowSelector::select(12)]; + + assert_eq!(selection.trim().selectors, expected); + } } diff --git a/parquet/src/basic.rs b/parquet/src/basic.rs index 7f50eada46de..7d9c1df37b3f 100644 --- a/parquet/src/basic.rs +++ b/parquet/src/basic.rs @@ -61,13 +61,10 @@ enum Type { // ---------------------------------------------------------------------- // Mirrors thrift enum `ConvertedType` -// -// Cannot use macros because of added field `None` // TODO(ets): Adding the `NONE` variant to this enum is a bit awkward. We should -// look into removing it and using `Option` instead. Then all of this -// handwritten code could go away. - +// look into removing it and using `Option` instead. +thrift_enum!( /// Common types (converted types) used by frameworks when using Parquet. /// /// This helps map between types in those frameworks to the base types in Parquet. @@ -75,142 +72,101 @@ enum Type { /// /// This struct was renamed from `LogicalType` in version 4.0.0. /// If targeting Parquet format 2.4.0 or above, please use [LogicalType] instead. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[allow(non_camel_case_types)] -pub enum ConvertedType { - /// No type conversion. - NONE, - /// A BYTE_ARRAY actually contains UTF8 encoded chars. - UTF8, - - /// A map is converted as an optional field containing a repeated key/value pair. - MAP, - - /// A key/value pair is converted into a group of two fields. - MAP_KEY_VALUE, - - /// A list is converted into an optional field containing a repeated field for its - /// values. - LIST, - - /// An enum is converted into a binary field - ENUM, - - /// A decimal value. - /// This may be used to annotate binary or fixed primitive types. The - /// underlying byte array stores the unscaled value encoded as two's - /// complement using big-endian byte order (the most significant byte is the - /// zeroth element). - /// - /// This must be accompanied by a (maximum) precision and a scale in the - /// SchemaElement. The precision specifies the number of digits in the decimal - /// and the scale stores the location of the decimal point. For example 1.23 - /// would have precision 3 (3 total digits) and scale 2 (the decimal point is - /// 2 digits over). - DECIMAL, +enum ConvertedType { + /// Not defined in the spec, used internally to indicate no type conversion + NONE = -1; - /// A date stored as days since Unix epoch, encoded as the INT32 physical type. - DATE, + /// A BYTE_ARRAY actually contains UTF8 encoded chars. + UTF8 = 0; - /// The total number of milliseconds since midnight. The value is stored as an INT32 - /// physical type. - TIME_MILLIS, + /// A map is converted as an optional field containing a repeated key/value pair. + MAP = 1; - /// The total number of microseconds since midnight. The value is stored as an INT64 - /// physical type. - TIME_MICROS, + /// A key/value pair is converted into a group of two fields. + MAP_KEY_VALUE = 2; - /// Date and time recorded as milliseconds since the Unix epoch. - /// Recorded as a physical type of INT64. - TIMESTAMP_MILLIS, + /// A list is converted into an optional field containing a repeated field for its + /// values. + LIST = 3; - /// Date and time recorded as microseconds since the Unix epoch. - /// The value is stored as an INT64 physical type. - TIMESTAMP_MICROS, + /// An enum is converted into a BYTE_ARRAY field + ENUM = 4; - /// An unsigned 8 bit integer value stored as INT32 physical type. - UINT_8, + /// A decimal value. + /// + /// This may be used to annotate BYTE_ARRAY or FIXED_LEN_BYTE_ARRAY primitive + /// types. The underlying byte array stores the unscaled value encoded as two's + /// complement using big-endian byte order (the most significant byte is the + /// zeroth element). The value of the decimal is the value * 10^{-scale}. + /// + /// This must be accompanied by a (maximum) precision and a scale in the + /// SchemaElement. The precision specifies the number of digits in the decimal + /// and the scale stores the location of the decimal point. For example 1.23 + /// would have precision 3 (3 total digits) and scale 2 (the decimal point is + /// 2 digits over). + DECIMAL = 5; - /// An unsigned 16 bit integer value stored as INT32 physical type. - UINT_16, + /// A date stored as days since Unix epoch, encoded as the INT32 physical type. + DATE = 6; - /// An unsigned 32 bit integer value stored as INT32 physical type. - UINT_32, + /// The total number of milliseconds since midnight. The value is stored as an INT32 + /// physical type. + TIME_MILLIS = 7; - /// An unsigned 64 bit integer value stored as INT64 physical type. - UINT_64, + /// The total number of microseconds since midnight. The value is stored as an INT64 + /// physical type. + TIME_MICROS = 8; - /// A signed 8 bit integer value stored as INT32 physical type. - INT_8, + /// Date and time recorded as milliseconds since the Unix epoch. + /// Recorded as a physical type of INT64. + TIMESTAMP_MILLIS = 9; - /// A signed 16 bit integer value stored as INT32 physical type. - INT_16, + /// Date and time recorded as microseconds since the Unix epoch. + /// The value is stored as an INT64 physical type. + TIMESTAMP_MICROS = 10; - /// A signed 32 bit integer value stored as INT32 physical type. - INT_32, + /// An unsigned 8 bit integer value stored as INT32 physical type. + UINT_8 = 11; - /// A signed 64 bit integer value stored as INT64 physical type. - INT_64, + /// An unsigned 16 bit integer value stored as INT32 physical type. + UINT_16 = 12; - /// A JSON document embedded within a single UTF8 column. - JSON, + /// An unsigned 32 bit integer value stored as INT32 physical type. + UINT_32 = 13; - /// A BSON document embedded within a single BINARY column. - BSON, + /// An unsigned 64 bit integer value stored as INT64 physical type. + UINT_64 = 14; - /// An interval of time. - /// - /// This type annotates data stored as a FIXED_LEN_BYTE_ARRAY of length 12. - /// This data is composed of three separate little endian unsigned integers. - /// Each stores a component of a duration of time. The first integer identifies - /// the number of months associated with the duration, the second identifies - /// the number of days associated with the duration and the third identifies - /// the number of milliseconds associated with the provided duration. - /// This duration of time is independent of any particular timezone or date. - INTERVAL, -} + /// A signed 8 bit integer value stored as INT32 physical type. + INT_8 = 15; -impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for ConvertedType { - fn read_thrift(prot: &mut R) -> Result { - let val = prot.read_i32()?; - Ok(match val { - 0 => Self::UTF8, - 1 => Self::MAP, - 2 => Self::MAP_KEY_VALUE, - 3 => Self::LIST, - 4 => Self::ENUM, - 5 => Self::DECIMAL, - 6 => Self::DATE, - 7 => Self::TIME_MILLIS, - 8 => Self::TIME_MICROS, - 9 => Self::TIMESTAMP_MILLIS, - 10 => Self::TIMESTAMP_MICROS, - 11 => Self::UINT_8, - 12 => Self::UINT_16, - 13 => Self::UINT_32, - 14 => Self::UINT_64, - 15 => Self::INT_8, - 16 => Self::INT_16, - 17 => Self::INT_32, - 18 => Self::INT_64, - 19 => Self::JSON, - 20 => Self::BSON, - 21 => Self::INTERVAL, - _ => return Err(general_err!("Unexpected ConvertedType {}", val)), - }) - } -} + /// A signed 16 bit integer value stored as INT32 physical type. + INT_16 = 16; -impl WriteThrift for ConvertedType { - const ELEMENT_TYPE: ElementType = ElementType::I32; + /// A signed 32 bit integer value stored as INT32 physical type. + INT_32 = 17; - fn write_thrift(&self, writer: &mut ThriftCompactOutputProtocol) -> Result<()> { - // because we've added NONE, the variant values are off by 1, so correct that here - writer.write_i32(*self as i32 - 1) - } -} + /// A signed 64 bit integer value stored as INT64 physical type. + INT_64 = 18; + + /// A JSON document embedded within a single UTF8 column. + JSON = 19; -write_thrift_field!(ConvertedType, FieldType::I32); + /// A BSON document embedded within a single BINARY column. + BSON = 20; + + /// An interval of time + /// + /// This type annotates data stored as a FIXED_LEN_BYTE_ARRAY of length 12. + /// This data is composed of three separate little endian unsigned integers. + /// Each stores a component of a duration of time. The first integer identifies + /// the number of months associated with the duration, the second identifies + /// the number of days associated with the duration and the third identifies + /// the number of milliseconds associated with the provided duration. + /// This duration of time is independent of any particular timezone or date. + INTERVAL = 21; +} +); // ---------------------------------------------------------------------- // Mirrors thrift union `TimeUnit` @@ -741,7 +697,7 @@ pub struct EncodingMask(i32); impl EncodingMask { /// Highest valued discriminant in the [`Encoding`] enum - const MAX_ENCODING: i32 = Encoding::BYTE_STREAM_SPLIT as i32; + const MAX_ENCODING: i32 = Encoding::MAX_DISCRIMINANT; /// A mask consisting of unused bit positions, used for validation. This includes the never /// used GROUP_VAR_INT encoding value of `1`. const ALLOWED_MASK: u32 = @@ -1327,12 +1283,6 @@ impl WriteThrift for ColumnOrder { // ---------------------------------------------------------------------- // Display handlers -impl fmt::Display for ConvertedType { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{self:?}") - } -} - impl fmt::Display for Compression { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self:?}") diff --git a/parquet/src/column/page.rs b/parquet/src/column/page.rs index 23517f05df11..f18b296c1c65 100644 --- a/parquet/src/column/page.rs +++ b/parquet/src/column/page.rs @@ -31,7 +31,7 @@ use crate::file::statistics::{Statistics, page_stats_to_thrift}; /// List of supported pages. /// These are 1-to-1 mapped from the equivalent Thrift definitions, except `buf` which /// used to store uncompressed bytes of the page. -#[derive(Clone)] +#[derive(Clone, Debug)] pub enum Page { /// Data page Parquet format v1. DataPage { diff --git a/parquet/src/column/reader.rs b/parquet/src/column/reader.rs index b8ff38efa3c4..ebde79e6a7f2 100644 --- a/parquet/src/column/reader.rs +++ b/parquet/src/column/reader.rs @@ -569,11 +569,16 @@ fn parse_v1_level( match encoding { Encoding::RLE => { let i32_size = std::mem::size_of::(); - let data_size = read_num_bytes::(i32_size, buf.as_ref()) as usize; - Ok(( - i32_size + data_size, - buf.slice(i32_size..i32_size + data_size), - )) + if i32_size <= buf.len() { + let data_size = read_num_bytes::(i32_size, buf.as_ref()) as usize; + let end = i32_size + .checked_add(data_size) + .ok_or(general_err!("invalid level length"))?; + if end <= buf.len() { + return Ok((end, buf.slice(i32_size..end))); + } + } + Err(general_err!("not enough data to read levels")) } #[allow(deprecated)] Encoding::BIT_PACKED => { @@ -597,6 +602,25 @@ mod tests { use crate::util::test_common::page_util::InMemoryPageReader; use crate::util::test_common::rand_gen::make_pages; + #[test] + fn test_parse_v1_level_invalid_length() { + // Say length is 10, but buffer is only 4 + let buf = Bytes::from(vec![10, 0, 0, 0]); + let err = parse_v1_level(1, 100, Encoding::RLE, buf).unwrap_err(); + assert_eq!( + err.to_string(), + "Parquet error: not enough data to read levels" + ); + + // Say length is 4, but buffer is only 3 + let buf = Bytes::from(vec![4, 0, 0]); + let err = parse_v1_level(1, 100, Encoding::RLE, buf).unwrap_err(); + assert_eq!( + err.to_string(), + "Parquet error: not enough data to read levels" + ); + } + const NUM_LEVELS: usize = 128; const NUM_PAGES: usize = 2; const MAX_DEF_LEVEL: i16 = 5; diff --git a/parquet/src/column/reader/decoder.rs b/parquet/src/column/reader/decoder.rs index 1d4e2f751181..e49906207577 100644 --- a/parquet/src/column/reader/decoder.rs +++ b/parquet/src/column/reader/decoder.rs @@ -138,7 +138,7 @@ pub trait ColumnValueDecoder { /// /// This replaces `HashMap` lookups with direct indexing to avoid hashing overhead in the /// hot decoding paths. -const ENCODING_SLOTS: usize = Encoding::BYTE_STREAM_SPLIT as usize + 1; +const ENCODING_SLOTS: usize = Encoding::MAX_DISCRIMINANT as usize + 1; /// An implementation of [`ColumnValueDecoder`] for `[T::T]` pub struct ColumnValueDecoderImpl { diff --git a/parquet/src/encodings/decoding.rs b/parquet/src/encodings/decoding.rs index 91b31dbdfcd2..f5336ca7c09a 100644 --- a/parquet/src/encodings/decoding.rs +++ b/parquet/src/encodings/decoding.rs @@ -381,7 +381,17 @@ impl DictDecoder { impl Decoder for DictDecoder { fn set_data(&mut self, data: Bytes, num_values: usize) -> Result<()> { // First byte in `data` is bit width + if data.is_empty() { + return Err(eof_err!("Not enough bytes to decode bit_width")); + } + let bit_width = data.as_ref()[0]; + if bit_width > 32 { + return Err(general_err!( + "Invalid or corrupted RLE bit width {}. Max allowed is 32", + bit_width + )); + } let mut rle_decoder = RleDecoder::new(bit_width); rle_decoder.set_data(data.slice(1..)); self.num_values = num_values; @@ -631,6 +641,19 @@ where self.next_block() } } + + /// Verify the bit width is smaller then the integer type that it is trying to decode. + #[inline] + fn check_bit_width(&self, bit_width: usize) -> Result<()> { + if bit_width > std::mem::size_of::() * 8 { + return Err(general_err!( + "Invalid delta bit width {} which is larger than expected {} ", + bit_width, + std::mem::size_of::() * 8 + )); + } + Ok(()) + } } impl Decoder for DeltaBitPackDecoder @@ -726,6 +749,7 @@ where } let bit_width = self.mini_block_bit_widths[self.mini_block_idx] as usize; + self.check_bit_width(bit_width)?; let batch_to_read = self.mini_block_remaining.min(to_read - read); let batch_read = self @@ -796,6 +820,7 @@ where } let bit_width = self.mini_block_bit_widths[self.mini_block_idx] as usize; + self.check_bit_width(bit_width)?; let mini_block_to_skip = self.mini_block_remaining.min(to_skip - skip); let mini_block_should_skip = mini_block_to_skip; @@ -1380,6 +1405,13 @@ mod tests { test_plain_skip::(Bytes::from(data_bytes), 3, 6, 4, &[]); } + #[test] + fn test_dict_decoder_empty_data() { + let mut decoder = DictDecoder::::new(); + let err = decoder.set_data(Bytes::new(), 10).unwrap_err(); + assert_eq!(err.to_string(), "EOF: Not enough bytes to decode bit_width"); + } + fn test_plain_decode( data: Bytes, num_values: usize, @@ -2091,4 +2123,51 @@ mod tests { v } } + + #[test] + // Allow initializing a vector and pushing to it for clarity in this test + #[allow(clippy::vec_init_then_push)] + fn test_delta_bit_packed_invalid_bit_width() { + // Manually craft a buffer with an invalid bit width + let mut buffer = vec![]; + // block_size = 128 + buffer.push(128); + buffer.push(1); + // mini_blocks_per_block = 4 + buffer.push(4); + // num_values = 32 + buffer.push(32); + // first_value = 0 + buffer.push(0); + // min_delta = 0 + buffer.push(0); + // bit_widths, one for each of the 4 mini blocks + buffer.push(33); // Invalid bit width + buffer.push(0); + buffer.push(0); + buffer.push(0); + + let corrupted_buffer = Bytes::from(buffer); + + let mut decoder = DeltaBitPackDecoder::::new(); + decoder.set_data(corrupted_buffer.clone(), 32).unwrap(); + let mut read_buffer = vec![0; 32]; + let err = decoder.get(&mut read_buffer).unwrap_err(); + assert!( + err.to_string() + .contains("Invalid delta bit width 33 which is larger than expected 32"), + "{}", + err + ); + + let mut decoder = DeltaBitPackDecoder::::new(); + decoder.set_data(corrupted_buffer, 32).unwrap(); + let err = decoder.skip(32).unwrap_err(); + assert!( + err.to_string() + .contains("Invalid delta bit width 33 which is larger than expected 32"), + "{}", + err + ); + } } diff --git a/parquet/src/encodings/rle.rs b/parquet/src/encodings/rle.rs index db8227fcac3a..41c050132064 100644 --- a/parquet/src/encodings/rle.rs +++ b/parquet/src/encodings/rle.rs @@ -513,7 +513,10 @@ impl RleDecoder { self.rle_left = (indicator_value >> 1) as u32; let value_width = bit_util::ceil(self.bit_width as usize, 8); self.current_value = bit_reader.get_aligned::(value_width); - assert!(self.current_value.is_some()); + assert!( + self.current_value.is_some(), + "parquet_data_error: not enough data for RLE decoding" + ); } true } else { diff --git a/parquet/src/encryption/ciphers.rs b/parquet/src/encryption/ciphers.rs index faff28f8acff..a94c72dcd5ec 100644 --- a/parquet/src/encryption/ciphers.rs +++ b/parquet/src/encryption/ciphers.rs @@ -18,6 +18,7 @@ use crate::errors::ParquetError; use crate::errors::ParquetError::General; use crate::errors::Result; +use crate::file::metadata::HeapSize; use ring::aead::{AES_128_GCM, Aad, LessSafeKey, NonceSequence, UnboundKey}; use ring::rand::{SecureRandom, SystemRandom}; use std::fmt::Debug; @@ -27,7 +28,7 @@ pub(crate) const NONCE_LEN: usize = 12; pub(crate) const TAG_LEN: usize = 16; pub(crate) const SIZE_LEN: usize = 4; -pub(crate) trait BlockDecryptor: Debug + Send + Sync { +pub(crate) trait BlockDecryptor: Debug + Send + Sync + HeapSize { fn decrypt(&self, length_and_ciphertext: &[u8], aad: &[u8]) -> Result>; fn compute_plaintext_tag(&self, aad: &[u8], plaintext: &[u8]) -> Result>; @@ -50,6 +51,13 @@ impl RingGcmBlockDecryptor { } } +impl HeapSize for RingGcmBlockDecryptor { + fn heap_size(&self) -> usize { + // Ring's LessSafeKey doesn't allocate on the heap + 0 + } +} + impl BlockDecryptor for RingGcmBlockDecryptor { fn decrypt(&self, length_and_ciphertext: &[u8], aad: &[u8]) -> Result> { let mut result = Vec::with_capacity(length_and_ciphertext.len() - SIZE_LEN - NONCE_LEN); diff --git a/parquet/src/encryption/decrypt.rs b/parquet/src/encryption/decrypt.rs index b5374066dfc3..0066523419de 100644 --- a/parquet/src/encryption/decrypt.rs +++ b/parquet/src/encryption/decrypt.rs @@ -21,6 +21,7 @@ use crate::encryption::ciphers::{BlockDecryptor, RingGcmBlockDecryptor, TAG_LEN} use crate::encryption::modules::{ModuleType, create_footer_aad, create_module_aad}; use crate::errors::{ParquetError, Result}; use crate::file::column_crypto_metadata::ColumnCryptoMetaData; +use crate::file::metadata::HeapSize; use std::borrow::Cow; use std::collections::HashMap; use std::fmt::Formatter; @@ -271,6 +272,12 @@ struct ExplicitDecryptionKeys { column_keys: HashMap>, } +impl HeapSize for ExplicitDecryptionKeys { + fn heap_size(&self) -> usize { + self.footer_key.heap_size() + self.column_keys.heap_size() + } +} + #[derive(Clone)] enum DecryptionKeys { Explicit(ExplicitDecryptionKeys), @@ -290,6 +297,19 @@ impl PartialEq for DecryptionKeys { } } +impl HeapSize for DecryptionKeys { + fn heap_size(&self) -> usize { + match self { + Self::Explicit(keys) => keys.heap_size(), + Self::ViaRetriever(_) => { + // The retriever is a user-defined type we don't control, + // so we can't determine the heap size. + 0 + } + } + } +} + /// `FileDecryptionProperties` hold keys and AAD data required to decrypt a Parquet file. /// /// When reading Arrow data, the `FileDecryptionProperties` should be included in the @@ -334,6 +354,11 @@ pub struct FileDecryptionProperties { footer_signature_verification: bool, } +impl HeapSize for FileDecryptionProperties { + fn heap_size(&self) -> usize { + self.keys.heap_size() + self.aad_prefix.heap_size() + } +} impl FileDecryptionProperties { /// Returns a new [`FileDecryptionProperties`] builder that will use the provided key to /// decrypt footer metadata. @@ -547,6 +572,21 @@ impl PartialEq for FileDecryptor { } } +/// Estimate the size in bytes required for the file decryptor. +/// This is important to track the memory usage of cached Parquet meta data, +/// and is used via [`crate::file::metadata::ParquetMetaData::memory_size`]. +/// Note that when a [`KeyRetriever`] is used, its heap size won't be included +/// and the result will be an underestimate. +/// If the [`FileDecryptionProperties`] are shared between multiple files then the +/// heap size may also be an overestimate. +impl HeapSize for FileDecryptor { + fn heap_size(&self) -> usize { + self.decryption_properties.heap_size() + + (Arc::clone(&self.footer_decryptor) as Arc).heap_size() + + self.file_aad.heap_size() + } +} + impl FileDecryptor { pub(crate) fn new( decryption_properties: &Arc, diff --git a/parquet/src/file/metadata/memory.rs b/parquet/src/file/metadata/memory.rs index 98ce5736ae1d..11536bbbd41e 100644 --- a/parquet/src/file/metadata/memory.rs +++ b/parquet/src/file/metadata/memory.rs @@ -28,6 +28,7 @@ use crate::file::page_index::column_index::{ }; use crate::file::page_index::offset_index::{OffsetIndexMetaData, PageLocation}; use crate::file::statistics::{Statistics, ValueStatistics}; +use std::collections::HashMap; use std::sync::Arc; /// Trait for calculating the size of various containers @@ -50,9 +51,60 @@ impl HeapSize for Vec { } } +impl HeapSize for HashMap { + fn heap_size(&self) -> usize { + let capacity = self.capacity(); + if capacity == 0 { + return 0; + } + + // HashMap doesn't provide a way to get its heap size, so this is an approximation based on + // the behavior of hashbrown::HashMap as at version 0.16.0, and may become inaccurate + // if the implementation changes. + let key_val_size = std::mem::size_of::<(K, V)>(); + // Overhead for the control tags group, which may be smaller depending on architecture + let group_size = 16; + // 1 byte of metadata stored per bucket. + let metadata_size = 1; + + // Compute the number of buckets for the capacity. Based on hashbrown's capacity_to_buckets + let buckets = if capacity < 15 { + let min_cap = match key_val_size { + 0..=1 => 14, + 2..=3 => 7, + _ => 3, + }; + let cap = min_cap.max(capacity); + if cap < 4 { + 4 + } else if cap < 8 { + 8 + } else { + 16 + } + } else { + (capacity.saturating_mul(8) / 7).next_power_of_two() + }; + + group_size + + (buckets * (key_val_size + metadata_size)) + + self.keys().map(|k| k.heap_size()).sum::() + + self.values().map(|v| v.heap_size()).sum::() + } +} + impl HeapSize for Arc { fn heap_size(&self) -> usize { - self.as_ref().heap_size() + // Arc stores weak and strong counts on the heap alongside an instance of T + 2 * std::mem::size_of::() + std::mem::size_of::() + self.as_ref().heap_size() + } +} + +impl HeapSize for Arc { + fn heap_size(&self) -> usize { + 2 * std::mem::size_of::() + + std::mem::size_of_val(self.as_ref()) + + self.as_ref().heap_size() } } diff --git a/parquet/src/file/metadata/mod.rs b/parquet/src/file/metadata/mod.rs index 763025fe142b..7022bd61c44d 100644 --- a/parquet/src/file/metadata/mod.rs +++ b/parquet/src/file/metadata/mod.rs @@ -287,11 +287,17 @@ impl ParquetMetaData { /// /// 4. Does not include any allocator overheads pub fn memory_size(&self) -> usize { + #[cfg(feature = "encryption")] + let encryption_size = self.file_decryptor.heap_size(); + #[cfg(not(feature = "encryption"))] + let encryption_size = 0usize; + std::mem::size_of::() + self.file_metadata.heap_size() + self.row_groups.heap_size() + self.column_index.heap_size() + self.offset_index.heap_size() + + encryption_size } /// Override the column index @@ -1875,10 +1881,9 @@ mod tests { .build(); #[cfg(not(feature = "encryption"))] - let base_expected_size = 2248; + let base_expected_size = 2766; #[cfg(feature = "encryption")] - // Not as accurate as it should be: https://github.com/apache/arrow-rs/issues/8472 - let base_expected_size = 2416; + let base_expected_size = 2934; assert_eq!(parquet_meta.memory_size(), base_expected_size); @@ -1907,16 +1912,90 @@ mod tests { .build(); #[cfg(not(feature = "encryption"))] - let bigger_expected_size = 2674; + let bigger_expected_size = 3192; #[cfg(feature = "encryption")] - // Not as accurate as it should be: https://github.com/apache/arrow-rs/issues/8472 - let bigger_expected_size = 2842; + let bigger_expected_size = 3360; // more set fields means more memory usage assert!(bigger_expected_size > base_expected_size); assert_eq!(parquet_meta.memory_size(), bigger_expected_size); } + #[test] + #[cfg(feature = "encryption")] + fn test_memory_size_with_decryptor() { + use crate::encryption::decrypt::FileDecryptionProperties; + use crate::file::metadata::thrift::encryption::AesGcmV1; + + let schema_descr = get_test_schema_descr(); + + let columns = schema_descr + .columns() + .iter() + .map(|column_descr| ColumnChunkMetaData::builder(column_descr.clone()).build()) + .collect::>>() + .unwrap(); + let row_group_meta = RowGroupMetaData::builder(schema_descr.clone()) + .set_num_rows(1000) + .set_column_metadata(columns) + .build() + .unwrap(); + let row_group_meta = vec![row_group_meta]; + + let version = 2; + let num_rows = 1000; + let aad_file_unique = vec![1u8; 8]; + let aad_prefix = vec![2u8; 8]; + let encryption_algorithm = EncryptionAlgorithm::AES_GCM_V1(AesGcmV1 { + aad_prefix: Some(aad_prefix.clone()), + aad_file_unique: Some(aad_file_unique.clone()), + supply_aad_prefix: Some(true), + }); + let footer_key_metadata = Some(vec![3u8; 8]); + let file_metadata = + FileMetaData::new(version, num_rows, None, None, schema_descr.clone(), None) + .with_encryption_algorithm(Some(encryption_algorithm)) + .with_footer_signing_key_metadata(footer_key_metadata.clone()); + + let parquet_meta_data = ParquetMetaDataBuilder::new(file_metadata.clone()) + .set_row_groups(row_group_meta.clone()) + .build(); + + let base_expected_size = 2058; + assert_eq!(parquet_meta_data.memory_size(), base_expected_size); + + let footer_key = "0123456789012345".as_bytes(); + let column_key = "1234567890123450".as_bytes(); + let mut decryption_properties_builder = + FileDecryptionProperties::builder(footer_key.to_vec()) + .with_aad_prefix(aad_prefix.clone()); + for column in schema_descr.columns() { + decryption_properties_builder = decryption_properties_builder + .with_column_key(&column.path().string(), column_key.to_vec()); + } + let decryption_properties = decryption_properties_builder.build().unwrap(); + let decryptor = FileDecryptor::new( + &decryption_properties, + footer_key_metadata.as_deref(), + aad_file_unique, + aad_prefix, + ) + .unwrap(); + + let parquet_meta_data = ParquetMetaDataBuilder::new(file_metadata.clone()) + .set_row_groups(row_group_meta.clone()) + .set_file_decryptor(Some(decryptor)) + .build(); + + let expected_size_with_decryptor = 3072; + assert!(expected_size_with_decryptor > base_expected_size); + + assert_eq!( + parquet_meta_data.memory_size(), + expected_size_with_decryptor + ); + } + /// Returns sample schema descriptor so we can create column metadata. fn get_test_schema_descr() -> SchemaDescPtr { let schema = SchemaType::group_type_builder("schema") diff --git a/parquet/src/file/reader.rs b/parquet/src/file/reader.rs index 61af21a68ec1..3adf10fac220 100644 --- a/parquet/src/file/reader.rs +++ b/parquet/src/file/reader.rs @@ -124,11 +124,25 @@ impl ChunkReader for Bytes { fn get_read(&self, start: u64) -> Result { let start = start as usize; + if start > self.len() { + return Err(eof_err!( + "Expected to read at offset {start}, while file has length {}", + self.len() + )); + } Ok(self.slice(start..).reader()) } fn get_bytes(&self, start: u64, length: usize) -> Result { let start = start as usize; + if start > self.len() || start + length > self.len() { + return Err(eof_err!( + "Expected to read {} bytes at offset {}, while file has length {}", + length, + start, + self.len() + )); + } Ok(self.slice(start..start + length)) } } @@ -274,3 +288,34 @@ impl Iterator for FilePageIterator { } impl PageIterator for FilePageIterator {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_bytes_chunk_reader_get_read_out_of_bounds() { + let data = Bytes::from(vec![0, 1, 2, 3]); + let err = data.get_read(5).unwrap_err(); + assert_eq!( + err.to_string(), + "EOF: Expected to read at offset 5, while file has length 4" + ); + } + + #[test] + fn test_bytes_chunk_reader_get_bytes_out_of_bounds() { + let data = Bytes::from(vec![0, 1, 2, 3]); + let err = data.get_bytes(5, 1).unwrap_err(); + assert_eq!( + err.to_string(), + "EOF: Expected to read 1 bytes at offset 5, while file has length 4" + ); + + let err = data.get_bytes(2, 3).unwrap_err(); + assert_eq!( + err.to_string(), + "EOF: Expected to read 3 bytes at offset 2, while file has length 4" + ); + } +} diff --git a/parquet/src/file/serialized_reader.rs b/parquet/src/file/serialized_reader.rs index 6da5c39d745b..3f95ea9d4982 100644 --- a/parquet/src/file/serialized_reader.rs +++ b/parquet/src/file/serialized_reader.rs @@ -392,6 +392,9 @@ pub(crate) fn decode_page( let buffer = match decompressor { Some(decompressor) if can_decompress => { let uncompressed_page_size = usize::try_from(page_header.uncompressed_page_size)?; + if offset > buffer.len() || offset > uncompressed_page_size { + return Err(general_err!("Invalid page header")); + } let decompressed_size = uncompressed_page_size - offset; let mut decompressed = Vec::with_capacity(uncompressed_page_size); decompressed.extend_from_slice(&buffer.as_ref()[..offset]); @@ -458,7 +461,10 @@ pub(crate) fn decode_page( } _ => { // For unknown page type (e.g., INDEX_PAGE), skip and read next. - unimplemented!("Page type {:?} is not supported", page_header.r#type) + return Err(general_err!( + "Page type {:?} is not supported", + page_header.r#type + )); } }; @@ -1130,6 +1136,7 @@ mod tests { use crate::column::reader::ColumnReader; use crate::data_type::private::ParquetValueType; use crate::data_type::{AsBytes, FixedLenByteArrayType, Int32Type}; + use crate::file::metadata::thrift::DataPageHeaderV2; #[allow(deprecated)] use crate::file::page_index::index_reader::{read_columns_indexes, read_offset_indexes}; use crate::file::writer::SerializedFileWriter; @@ -1139,6 +1146,72 @@ mod tests { use super::*; + #[test] + fn test_decode_page_invalid_offset() { + let page_header = PageHeader { + r#type: PageType::DATA_PAGE_V2, + uncompressed_page_size: 10, + compressed_page_size: 10, + data_page_header: None, + index_page_header: None, + dictionary_page_header: None, + crc: None, + data_page_header_v2: Some(DataPageHeaderV2 { + num_nulls: 0, + num_rows: 0, + num_values: 0, + encoding: Encoding::PLAIN, + definition_levels_byte_length: 11, + repetition_levels_byte_length: 0, + is_compressed: None, + statistics: None, + }), + }; + + let buffer = Bytes::new(); + let err = decode_page(page_header, buffer, Type::INT32, None).unwrap_err(); + assert!( + err.to_string() + .contains("DataPage v2 header contains implausible values") + ); + } + + #[test] + fn test_decode_unsupported_page() { + let mut page_header = PageHeader { + r#type: PageType::INDEX_PAGE, + uncompressed_page_size: 10, + compressed_page_size: 10, + data_page_header: None, + index_page_header: None, + dictionary_page_header: None, + crc: None, + data_page_header_v2: None, + }; + let buffer = Bytes::new(); + let err = decode_page(page_header.clone(), buffer.clone(), Type::INT32, None).unwrap_err(); + assert_eq!( + err.to_string(), + "Parquet error: Page type INDEX_PAGE is not supported" + ); + + page_header.data_page_header_v2 = Some(DataPageHeaderV2 { + num_nulls: 0, + num_rows: 0, + num_values: 0, + encoding: Encoding::PLAIN, + definition_levels_byte_length: 11, + repetition_levels_byte_length: 0, + is_compressed: None, + statistics: None, + }); + let err = decode_page(page_header, buffer, Type::INT32, None).unwrap_err(); + assert!( + err.to_string() + .contains("DataPage v2 header contains implausible values") + ); + } + #[test] fn test_cursor_and_file_has_the_same_behaviour() { let mut buf: Vec = Vec::new(); diff --git a/parquet/src/parquet_macros.rs b/parquet/src/parquet_macros.rs index eb8bc2b7f07a..714015e10e32 100644 --- a/parquet/src/parquet_macros.rs +++ b/parquet/src/parquet_macros.rs @@ -36,7 +36,9 @@ #[allow(clippy::crate_in_macro_def)] /// Macro used to generate rust enums from a Thrift `enum` definition. /// -/// When utilizing this macro the Thrift serialization traits and structs need to be in scope. +/// Note: +/// - All enums generated with this macro will have `pub` visibility. +/// - When utilizing this macro the Thrift serialization traits and structs need to be in scope. macro_rules! thrift_enum { ($(#[$($def_attrs:tt)*])* enum $identifier:ident { $($(#[$($field_attrs:tt)*])* $field_name:ident = $field_value:literal;)* }) => { $(#[$($def_attrs)*])* @@ -79,6 +81,35 @@ macro_rules! thrift_enum { Ok(field_id) } } + + impl $identifier { + #[allow(deprecated)] + #[doc = "Returns a slice containing every variant of this enum."] + #[allow(dead_code)] + pub const VARIANTS: &'static [Self] = &[ + $(Self::$field_name),* + ]; + + #[allow(deprecated)] + const fn max_discriminant_impl() -> i32 { + let values: &[i32] = &[$($field_value),*]; + let mut max = values[0]; + let mut idx = 1; + while idx < values.len() { + let candidate = values[idx]; + if candidate > max { + max = candidate; + } + idx += 1; + } + max + } + + #[allow(deprecated)] + #[doc = "Returns the largest discriminant value defined for this enum."] + #[allow(dead_code)] + pub const MAX_DISCRIMINANT: i32 = Self::max_discriminant_impl(); + } } } @@ -91,7 +122,9 @@ macro_rules! thrift_enum { /// /// The resulting Rust enum will have all unit variants. /// -/// When utilizing this macro the Thrift serialization traits and structs need to be in scope. +/// Note: +/// - All enums generated with this macro will have `pub` visibility. +/// - When utilizing this macro the Thrift serialization traits and structs need to be in scope. #[doc(hidden)] #[macro_export] #[allow(clippy::crate_in_macro_def)] @@ -162,9 +195,10 @@ macro_rules! thrift_union_all_empty { /// non-empty type, the typename must be contained within parens (e.g. `1: MyType Var1;` becomes /// `1: (MyType) Var1;`). /// -/// This macro allows for specifying lifetime annotations for the resulting `enum` and its fields. -/// -/// When utilizing this macro the Thrift serialization traits and structs need to be in scope. +/// Note: +/// - All enums generated with this macro will have `pub` visibility. +/// - This macro allows for specifying lifetime annotations for the resulting `enum` and its fields. +/// - When utilizing this macro the Thrift serialization traits and structs need to be in scope. #[doc(hidden)] #[macro_export] #[allow(clippy::crate_in_macro_def)] @@ -228,9 +262,11 @@ macro_rules! thrift_union { /// Macro used to generate Rust structs from a Thrift `struct` definition. /// -/// This macro allows for specifying lifetime annotations for the resulting `struct` and its fields. -/// -/// When utilizing this macro the Thrift serialization traits and structs need to be in scope. +/// Note: +/// - This macro allows for specifying the visibility of the resulting `struct` and its fields. +/// + The `struct` and all fields will have the same visibility. +/// - This macro allows for specifying lifetime annotations for the resulting `struct` and its fields. +/// - When utilizing this macro the Thrift serialization traits and structs need to be in scope. #[doc(hidden)] #[macro_export] macro_rules! thrift_struct { diff --git a/parquet/src/schema/types.rs b/parquet/src/schema/types.rs index 1ae37d0a462f..50ae4955380b 100644 --- a/parquet/src/schema/types.rs +++ b/parquet/src/schema/types.rs @@ -845,7 +845,9 @@ pub struct ColumnDescriptor { impl HeapSize for ColumnDescriptor { fn heap_size(&self) -> usize { - self.primitive_type.heap_size() + self.path.heap_size() + // Don't include the heap size of primitive_type, this is already + // accounted for via SchemaDescriptor::schema + self.path.heap_size() } } @@ -1348,19 +1350,23 @@ fn schema_from_array_helper<'a>( .with_logical_type(logical_type) .with_fields(fields) .with_id(field_id); - if let Some(rep) = repetition { - // Sometimes parquet-cpp and parquet-mr set repetition level REQUIRED or - // REPEATED for root node. - // - // We only set repetition for group types that are not top-level message - // type. According to parquet-format: - // Root of the schema does not have a repetition_type. - // All other types must have one. - if !is_root_node { - builder = builder.with_repetition(rep); - } + + // Sometimes parquet-cpp and parquet-mr set repetition level REQUIRED or + // REPEATED for root node. + // + // We only set repetition for group types that are not top-level message + // type. According to parquet-format: + // Root of the schema does not have a repetition_type. + // All other types must have one. + if !is_root_node { + let Some(rep) = repetition else { + return Err(general_err!( + "Repetition level must be defined for non-root types" + )); + }; + builder = builder.with_repetition(rep); } - Ok((next_index, Arc::new(builder.build().unwrap()))) + Ok((next_index, Arc::new(builder.build()?))) } } } diff --git a/parquet/tests/arrow_reader/bad_data.rs b/parquet/tests/arrow_reader/bad_data.rs index 235f81812468..54c92976e41c 100644 --- a/parquet/tests/arrow_reader/bad_data.rs +++ b/parquet/tests/arrow_reader/bad_data.rs @@ -84,10 +84,12 @@ fn test_parquet_1481() { } #[test] -#[should_panic(expected = "assertion failed: self.current_value.is_some()")] fn test_arrow_gh_41321() { let err = read_file("ARROW-GH-41321.parquet").unwrap_err(); - assert_eq!(err.to_string(), "TBD (currently panics)"); + assert_eq!( + err.to_string(), + "External: Parquet argument error: Parquet error: Invalid or corrupted RLE bit width 254. Max allowed is 32" + ); } #[test]