From 12eaa17cca219e76e7076ecfbf31d72c8bfb2689 Mon Sep 17 00:00:00 2001 From: Ed Seidl Date: Fri, 24 Oct 2025 07:23:12 -0700 Subject: [PATCH 01/19] [Minor]: Document visibility for enums produced by Thrift macros (#8706) # Which issue does this PR close? N/A # Rationale for this change It is not obvious that the thrift macros produce public enums only (e.g. see https://github.com/apache/arrow-rs/pull/8680#discussion_r2460381795). This should be made clear in the documentation. # What changes are included in this PR? Add said clarification. # Are these changes tested? Documentation only, so no tests required. # Are there any user-facing changes? No, only changes to private documentation --- parquet/THRIFT.md | 7 ++++++- parquet/src/parquet_macros.rs | 23 +++++++++++++++-------- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/parquet/THRIFT.md b/parquet/THRIFT.md index 56365665070a..599b33f2bce3 100644 --- a/parquet/THRIFT.md +++ b/parquet/THRIFT.md @@ -57,7 +57,7 @@ The `thrift_enum` macro can be used in this instance. ```rust thrift_enum!( - enum Type { +enum Type { BOOLEAN = 0; INT32 = 1; INT64 = 2; @@ -85,6 +85,8 @@ pub enum Type { } ``` +All Rust `enum`s produced with this macro will have `pub` visibility. + ### Unions Thrift unions are a special kind of struct in which only a single field is populated. In this @@ -175,6 +177,9 @@ pub enum ColumnCryptoMetaData { } ``` +All Rust `enum`s produced with either macro will have `pub` visibility. `thrift_union` also allows +for lifetime annotations, but this capability is not currently utilized. + ### Structs The `thrift_struct` macro is used for structs. 
This macro is a little more flexible than the others diff --git a/parquet/src/parquet_macros.rs b/parquet/src/parquet_macros.rs index eb8bc2b7f07a..d23bc9869945 100644 --- a/parquet/src/parquet_macros.rs +++ b/parquet/src/parquet_macros.rs @@ -36,7 +36,9 @@ #[allow(clippy::crate_in_macro_def)] /// Macro used to generate rust enums from a Thrift `enum` definition. /// -/// When utilizing this macro the Thrift serialization traits and structs need to be in scope. +/// Note: +/// - All enums generated with this macro will have `pub` visibility. +/// - When utilizing this macro the Thrift serialization traits and structs need to be in scope. macro_rules! thrift_enum { ($(#[$($def_attrs:tt)*])* enum $identifier:ident { $($(#[$($field_attrs:tt)*])* $field_name:ident = $field_value:literal;)* }) => { $(#[$($def_attrs)*])* @@ -91,7 +93,9 @@ macro_rules! thrift_enum { /// /// The resulting Rust enum will have all unit variants. /// -/// When utilizing this macro the Thrift serialization traits and structs need to be in scope. +/// Note: +/// - All enums generated with this macro will have `pub` visibility. +/// - When utilizing this macro the Thrift serialization traits and structs need to be in scope. #[doc(hidden)] #[macro_export] #[allow(clippy::crate_in_macro_def)] @@ -162,9 +166,10 @@ macro_rules! thrift_union_all_empty { /// non-empty type, the typename must be contained within parens (e.g. `1: MyType Var1;` becomes /// `1: (MyType) Var1;`). /// -/// This macro allows for specifying lifetime annotations for the resulting `enum` and its fields. -/// -/// When utilizing this macro the Thrift serialization traits and structs need to be in scope. +/// Note: +/// - All enums generated with this macro will have `pub` visibility. +/// - This macro allows for specifying lifetime annotations for the resulting `enum` and its fields. +/// - When utilizing this macro the Thrift serialization traits and structs need to be in scope. 
#[doc(hidden)] #[macro_export] #[allow(clippy::crate_in_macro_def)] @@ -228,9 +233,11 @@ macro_rules! thrift_union { /// Macro used to generate Rust structs from a Thrift `struct` definition. /// -/// This macro allows for specifying lifetime annotations for the resulting `struct` and its fields. -/// -/// When utilizing this macro the Thrift serialization traits and structs need to be in scope. +/// Note: +/// - This macro allows for specifying the visibility of the resulting `struct` and its fields. +/// + The `struct` and all fields will have the same visibility. +/// - This macro allows for specifying lifetime annotations for the resulting `struct` and its fields. +/// - When utilizing this macro the Thrift serialization traits and structs need to be in scope. #[doc(hidden)] #[macro_export] macro_rules! thrift_struct { From 9c1b03b300eb94fa9b21650ec8b971e9fb7ad6a4 Mon Sep 17 00:00:00 2001 From: Connor Sanders <170039284+jecsand838@users.noreply.github.com> Date: Fri, 24 Oct 2025 09:33:03 -0500 Subject: [PATCH 02/19] Update `arrow-avro` `README.md` version to 57 (#8695) # Which issue does this PR close? - Closes #8691 # Rationale for this change The `README.md` file in `arrow-avro` instructs users to install version 56. This is invalid and should be changed to version 57. # What changes are included in this PR? Updated the `README.md` file to reference version 57. # Are these changes tested? N/A since this a small `README.md` file change. # Are there any user-facing changes? The `README.md` file in `arrow-avro` now instructs users to install version 57. 
--------- Co-authored-by: Andrew Lamb --- arrow-avro/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arrow-avro/README.md b/arrow-avro/README.md index f89fc97d242f..85fd76094755 100644 --- a/arrow-avro/README.md +++ b/arrow-avro/README.md @@ -44,14 +44,14 @@ This crate provides: ```toml [dependencies] -arrow-avro = "56" +arrow-avro = "57.0.0" ```` Disable defaults and pick only what you need (see **Feature Flags**): ```toml [dependencies] -arrow-avro = { version = "56", default-features = false, features = ["deflate", "snappy"] } +arrow-avro = { version = "57.0.0", default-features = false, features = ["deflate", "snappy"] } ``` --- From 928daa469a9191992e32c8b7ed4f5c688624dc08 Mon Sep 17 00:00:00 2001 From: lichuang Date: Fri, 24 Oct 2025 23:25:20 +0800 Subject: [PATCH 03/19] chore: add test case of RowSelection::trim (#8660) # Which issue does this PR close? chore: add test case of `RowSelection::trim` # Rationale for this change Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. # What changes are included in this PR? There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. # Are these changes tested? We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? # Are there any user-facing changes? If there are user-facing changes then we may require documentation to be updated before approving the PR. If there are any breaking changes to public APIs, please call them out. 
--- parquet/src/arrow/arrow_reader/selection.rs | 29 +++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/parquet/src/arrow/arrow_reader/selection.rs b/parquet/src/arrow/arrow_reader/selection.rs index 9c3caec0b4a5..adbbff1ca2df 100644 --- a/parquet/src/arrow/arrow_reader/selection.rs +++ b/parquet/src/arrow/arrow_reader/selection.rs @@ -1432,4 +1432,33 @@ mod tests { assert_eq!(selection.row_count(), 0); assert_eq!(selection.skipped_row_count(), 0); } + + #[test] + fn test_trim() { + let selection = RowSelection::from(vec![ + RowSelector::skip(34), + RowSelector::select(12), + RowSelector::skip(3), + RowSelector::select(35), + ]); + + let expected = vec![ + RowSelector::skip(34), + RowSelector::select(12), + RowSelector::skip(3), + RowSelector::select(35), + ]; + + assert_eq!(selection.trim().selectors, expected); + + let selection = RowSelection::from(vec![ + RowSelector::skip(34), + RowSelector::select(12), + RowSelector::skip(3), + ]); + + let expected = vec![RowSelector::skip(34), RowSelector::select(12)]; + + assert_eq!(selection.trim().selectors, expected); + } } From 729b25817e5fa83ac4ecf4586e1649289323d8b9 Mon Sep 17 00:00:00 2001 From: huang qiwei Date: Fri, 24 Oct 2025 23:27:22 +0800 Subject: [PATCH 04/19] [parquet] Adding counting method in thrift_enum macro to support ENCODING_SLOTS (#8663) # Which issue does this PR close? - Closes [#8662] # Rationale for this change Related to https://github.com/apache/arrow-rs/pull/8607 We need to know how many encoding are support to create a decoder slot. # What changes are included in this PR? Update the `thrift_enum` to know the fields count of enum `Encoding`, and the value is passed to `EncodingMask` And the `ENCODING_SLOTS` # Are these changes tested? 1. Originally I think add a UT can prevent failure after the new encoding are introduced, then I realized the counts are already transferred, the UT is not required, the original tests can already cover the code. 
# Are there any user-facing changes? No --- parquet/src/basic.rs | 2 +- parquet/src/column/reader/decoder.rs | 2 +- parquet/src/parquet_macros.rs | 29 ++++++++++++++++++++++++++++ 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/parquet/src/basic.rs b/parquet/src/basic.rs index 7f50eada46de..def69f251581 100644 --- a/parquet/src/basic.rs +++ b/parquet/src/basic.rs @@ -741,7 +741,7 @@ pub struct EncodingMask(i32); impl EncodingMask { /// Highest valued discriminant in the [`Encoding`] enum - const MAX_ENCODING: i32 = Encoding::BYTE_STREAM_SPLIT as i32; + const MAX_ENCODING: i32 = Encoding::MAX_DISCRIMINANT; /// A mask consisting of unused bit positions, used for validation. This includes the never /// used GROUP_VAR_INT encoding value of `1`. const ALLOWED_MASK: u32 = diff --git a/parquet/src/column/reader/decoder.rs b/parquet/src/column/reader/decoder.rs index 1d4e2f751181..e49906207577 100644 --- a/parquet/src/column/reader/decoder.rs +++ b/parquet/src/column/reader/decoder.rs @@ -138,7 +138,7 @@ pub trait ColumnValueDecoder { /// /// This replaces `HashMap` lookups with direct indexing to avoid hashing overhead in the /// hot decoding paths. -const ENCODING_SLOTS: usize = Encoding::BYTE_STREAM_SPLIT as usize + 1; +const ENCODING_SLOTS: usize = Encoding::MAX_DISCRIMINANT as usize + 1; /// An implementation of [`ColumnValueDecoder`] for `[T::T]` pub struct ColumnValueDecoderImpl { diff --git a/parquet/src/parquet_macros.rs b/parquet/src/parquet_macros.rs index d23bc9869945..714015e10e32 100644 --- a/parquet/src/parquet_macros.rs +++ b/parquet/src/parquet_macros.rs @@ -81,6 +81,35 @@ macro_rules! 
thrift_enum { Ok(field_id) } } + + impl $identifier { + #[allow(deprecated)] + #[doc = "Returns a slice containing every variant of this enum."] + #[allow(dead_code)] + pub const VARIANTS: &'static [Self] = &[ + $(Self::$field_name),* + ]; + + #[allow(deprecated)] + const fn max_discriminant_impl() -> i32 { + let values: &[i32] = &[$($field_value),*]; + let mut max = values[0]; + let mut idx = 1; + while idx < values.len() { + let candidate = values[idx]; + if candidate > max { + max = candidate; + } + idx += 1; + } + max + } + + #[allow(deprecated)] + #[doc = "Returns the largest discriminant value defined for this enum."] + #[allow(dead_code)] + pub const MAX_DISCRIMINANT: i32 = Self::max_discriminant_impl(); + } } } From 021090f29d6ff458bb208c576876a869cc224f59 Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Fri, 24 Oct 2025 11:27:44 -0400 Subject: [PATCH 05/19] [Variant] Remove `create_test_variant_array` helper method (#8664) - A follow up from https://github.com/apache/arrow-rs/pull/8625 # Rationale for this change While working on a separate task, I noticed `create_test_variant_array` was redundant. Since `VariantArray` can already be constructed directly from an iterator of Variants, this PR removes the now-unnecessary test helper. 
--- parquet-variant-compute/src/shred_variant.rs | 30 ++++++-------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/parquet-variant-compute/src/shred_variant.rs b/parquet-variant-compute/src/shred_variant.rs index d5635291f712..f8158b2211a2 100644 --- a/parquet-variant-compute/src/shred_variant.rs +++ b/parquet-variant-compute/src/shred_variant.rs @@ -331,22 +331,11 @@ mod tests { use parquet_variant::{ObjectBuilder, ReadOnlyMetadataBuilder, Variant, VariantBuilder}; use std::sync::Arc; - fn create_test_variant_array(values: Vec>>) -> VariantArray { - let mut builder = VariantArrayBuilder::new(values.len()); - for value in values { - match value { - Some(v) => builder.append_variant(v), - None => builder.append_null(), - } - } - builder.build() - } - #[test] fn test_already_shredded_input_error() { // Create a VariantArray that already has typed_value_field // First create a valid VariantArray, then extract its parts to construct a shredded one - let temp_array = create_test_variant_array(vec![Some(Variant::from("test"))]); + let temp_array = VariantArray::from_iter(vec![Some(Variant::from("test"))]); let metadata = temp_array.metadata_field().clone(); let value = temp_array.value_field().unwrap().clone(); let typed_value = Arc::new(Int64Array::from(vec![42])) as ArrayRef; @@ -375,7 +364,7 @@ mod tests { #[test] fn test_unsupported_list_schema() { - let input = create_test_variant_array(vec![Some(Variant::from(42))]); + let input = VariantArray::from_iter([Variant::from(42)]); let list_schema = DataType::List(Arc::new(Field::new("item", DataType::Int64, true))); shred_variant(&input, &list_schema).expect_err("unsupported"); } @@ -383,7 +372,7 @@ mod tests { #[test] fn test_primitive_shredding_comprehensive() { // Test mixed scenarios in a single array - let input = create_test_variant_array(vec![ + let input = VariantArray::from_iter(vec![ Some(Variant::from(42i64)), // successful shred Some(Variant::from("hello")), // failed shred (string) 
Some(Variant::from(100i64)), // successful shred @@ -448,10 +437,10 @@ mod tests { #[test] fn test_primitive_different_target_types() { - let input = create_test_variant_array(vec![ - Some(Variant::from(42i32)), - Some(Variant::from(3.15f64)), - Some(Variant::from("not_a_number")), + let input = VariantArray::from_iter(vec![ + Variant::from(42i32), + Variant::from(3.15f64), + Variant::from("not_a_number"), ]); // Test Int32 target @@ -882,10 +871,7 @@ mod tests { #[test] fn test_spec_compliance() { - let input = create_test_variant_array(vec![ - Some(Variant::from(42i64)), - Some(Variant::from("hello")), - ]); + let input = VariantArray::from_iter(vec![Variant::from(42i64), Variant::from("hello")]); let result = shred_variant(&input, &DataType::Int64).unwrap(); From 79575aa343c2a2aa4bf226554ec4e4f7ff1cb37e Mon Sep 17 00:00:00 2001 From: Mikhail Zabaluev Date: Fri, 24 Oct 2025 18:28:03 +0300 Subject: [PATCH 06/19] perf: zero-copy path in `RowConverter::from_binary` (#8686) # Which issue does this PR close? - Closes #8685. # What changes are included in this PR? In the implementation of `RowConverter::from_binary`, the `BinaryArray` is broken into parts and an attempt is made to convert the data buffer into `Vec` at no copying cost with `Buffer::into_vec`. Only if this fails, the data is copied out for a newly allocated `Vec`. # Are these changes tested? Passes existing tests using `RowConverter::from_binary`, which all convert a non-shared buffer taking advantage of the optimization. Another test is added to cover the copying path. # Are there any user-facing changes? 
No --- arrow-row/src/lib.rs | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs index db7758047c43..5f690e9a6734 100644 --- a/arrow-row/src/lib.rs +++ b/arrow-row/src/lib.rs @@ -913,9 +913,13 @@ impl RowConverter { 0, "can't construct Rows instance from array with nulls" ); + let (offsets, values, _) = array.into_parts(); + let offsets = offsets.iter().map(|&i| i.as_usize()).collect(); + // Try zero-copy, if it does not succeed, fall back to copying the values. + let buffer = values.into_vec().unwrap_or_else(|values| values.to_vec()); Rows { - buffer: array.values().to_vec(), - offsets: array.offsets().iter().map(|&i| i.as_usize()).collect(), + buffer, + offsets, config: RowConfig { fields: Arc::clone(&self.fields), validate_utf8: true, @@ -2474,6 +2478,19 @@ mod tests { assert!(rows.row(3) < rows.row(0)); } + #[test] + fn test_from_binary_shared_buffer() { + let converter = RowConverter::new(vec![SortField::new(DataType::Binary)]).unwrap(); + let array = Arc::new(BinaryArray::from_iter_values([&[0xFF]])) as _; + let rows = converter.convert_columns(&[array]).unwrap(); + let binary_rows = rows.try_into_binary().expect("known-small rows"); + let _binary_rows_shared_buffer = binary_rows.clone(); + + let parsed = converter.from_binary(binary_rows); + + converter.convert_rows(parsed.iter()).unwrap(); + } + #[test] #[should_panic(expected = "Encountered non UTF-8 data")] fn test_invalid_utf8() { From e9a7fe576449e86541103423ca40c8c47ff3ec39 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Fri, 24 Oct 2025 19:47:15 +0200 Subject: [PATCH 07/19] Add `FilterPredicate::filter_record_batch` (#8693) # Which issue does this PR close? - Closes #8692. # Rationale for this change Explained in issue. # What changes are included in this PR? 
- Adds `FilterPredicate::filter_record_batch` - Adapts the free function `filter_record_batch` to use the new function - Uses `new_unchecked` to create the filtered result. The rationale for this is identical to #8583 # Are these changes tested? Covered by existing tests for `filter_record_batch` # Are there any user-facing changes? No --------- Co-authored-by: Martin Grigorov --- arrow-select/src/filter.rs | 52 +++++++++++++++++++++++++++++++------- 1 file changed, 43 insertions(+), 9 deletions(-) diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs index dace2bab728f..5c21a4adcab7 100644 --- a/arrow-select/src/filter.rs +++ b/arrow-select/src/filter.rs @@ -122,6 +122,12 @@ pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray { /// Returns a filtered `values` [`Array`] where the corresponding elements of /// `predicate` are `true`. /// +/// If multiple arrays (or record batches) need to be filtered using the same predicate array, +/// consider using [FilterBuilder] to create a single [FilterPredicate] and then +/// calling [FilterPredicate::filter_record_batch]. +/// In contrast to this function, it is then the responsibility of the caller +/// to use [FilterBuilder::optimize] if appropriate. +/// /// # See also /// * [`FilterBuilder`] for more control over the filtering process. /// * [`filter_record_batch`] to filter a [`RecordBatch`] @@ -168,25 +174,28 @@ fn multiple_arrays(data_type: &DataType) -> bool { /// `predicate` are true. /// /// This is the equivalent of calling [filter] on each column of the [RecordBatch]. +/// +/// If multiple record batches (or arrays) need to be filtered using the same predicate array, +/// consider using [FilterBuilder] to create a single [FilterPredicate] and then +/// calling [FilterPredicate::filter_record_batch]. +/// In contrast to this function, it is then the responsibility of the caller +/// to use [FilterBuilder::optimize] if appropriate. 
pub fn filter_record_batch( record_batch: &RecordBatch, predicate: &BooleanArray, ) -> Result { let mut filter_builder = FilterBuilder::new(predicate); - if record_batch.num_columns() > 1 { - // Only optimize if filtering more than one column + let num_cols = record_batch.num_columns(); + if num_cols > 1 + || (num_cols > 0 && multiple_arrays(record_batch.schema_ref().field(0).data_type())) + { + // Only optimize if filtering more than one column or if the column contains multiple internal arrays // Otherwise, the overhead of optimization can be more than the benefit filter_builder = filter_builder.optimize(); } let filter = filter_builder.build(); - let filtered_arrays = record_batch - .columns() - .iter() - .map(|a| filter_array(a, &filter)) - .collect::, _>>()?; - let options = RecordBatchOptions::default().with_row_count(Some(filter.count())); - RecordBatch::try_new_with_options(record_batch.schema(), filtered_arrays, &options) + filter.filter_record_batch(record_batch) } /// A builder to construct [`FilterPredicate`] @@ -300,6 +309,31 @@ impl FilterPredicate { filter_array(values, self) } + /// Returns a filtered [`RecordBatch`] containing only the rows that are selected by this + /// [`FilterPredicate`]. + /// + /// This is the equivalent of calling [filter] on each column of the [`RecordBatch`]. 
+ pub fn filter_record_batch( + &self, + record_batch: &RecordBatch, + ) -> Result { + let filtered_arrays = record_batch + .columns() + .iter() + .map(|a| filter_array(a, self)) + .collect::, _>>()?; + + // SAFETY: we know that the set of filtered arrays will match the schema of the original + // record batch + unsafe { + Ok(RecordBatch::new_unchecked( + record_batch.schema(), + filtered_arrays, + self.count, + )) + } + } + /// Number of rows being selected based on this [`FilterPredicate`] pub fn count(&self) -> usize { self.count From 99811f822164ad408c70df9561c1224cfa7f1cfe Mon Sep 17 00:00:00 2001 From: Alex Stephen <1325798+rambleraptor@users.noreply.github.com> Date: Sun, 26 Oct 2025 03:45:03 -0700 Subject: [PATCH 08/19] check bit width to avoid panic in DeltaBitPackDecoder (#8688) # Which issue does this PR close? We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. - Part of #7806 # Rationale for this change The `DeltaBitPackDecoder` can panic if it encounters a bit width in the encoded data that is larger than the bit width of the data type being decoded. --------- Co-authored-by: Andrew Lamb --- parquet/src/encodings/decoding.rs | 62 +++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/parquet/src/encodings/decoding.rs b/parquet/src/encodings/decoding.rs index 91b31dbdfcd2..de8738cf09f9 100644 --- a/parquet/src/encodings/decoding.rs +++ b/parquet/src/encodings/decoding.rs @@ -631,6 +631,19 @@ where self.next_block() } } + + /// Verify the bit width is smaller then the integer type that it is trying to decode. 
+ #[inline] + fn check_bit_width(&self, bit_width: usize) -> Result<()> { + if bit_width > std::mem::size_of::() * 8 { + return Err(general_err!( + "Invalid delta bit width {} which is larger than expected {} ", + bit_width, + std::mem::size_of::() * 8 + )); + } + Ok(()) + } } impl Decoder for DeltaBitPackDecoder @@ -726,6 +739,7 @@ where } let bit_width = self.mini_block_bit_widths[self.mini_block_idx] as usize; + self.check_bit_width(bit_width)?; let batch_to_read = self.mini_block_remaining.min(to_read - read); let batch_read = self @@ -796,6 +810,7 @@ where } let bit_width = self.mini_block_bit_widths[self.mini_block_idx] as usize; + self.check_bit_width(bit_width)?; let mini_block_to_skip = self.mini_block_remaining.min(to_skip - skip); let mini_block_should_skip = mini_block_to_skip; @@ -2091,4 +2106,51 @@ mod tests { v } } + + #[test] + // Allow initializing a vector and pushing to it for clarity in this test + #[allow(clippy::vec_init_then_push)] + fn test_delta_bit_packed_invalid_bit_width() { + // Manually craft a buffer with an invalid bit width + let mut buffer = vec![]; + // block_size = 128 + buffer.push(128); + buffer.push(1); + // mini_blocks_per_block = 4 + buffer.push(4); + // num_values = 32 + buffer.push(32); + // first_value = 0 + buffer.push(0); + // min_delta = 0 + buffer.push(0); + // bit_widths, one for each of the 4 mini blocks + buffer.push(33); // Invalid bit width + buffer.push(0); + buffer.push(0); + buffer.push(0); + + let corrupted_buffer = Bytes::from(buffer); + + let mut decoder = DeltaBitPackDecoder::::new(); + decoder.set_data(corrupted_buffer.clone(), 32).unwrap(); + let mut read_buffer = vec![0; 32]; + let err = decoder.get(&mut read_buffer).unwrap_err(); + assert!( + err.to_string() + .contains("Invalid delta bit width 33 which is larger than expected 32"), + "{}", + err + ); + + let mut decoder = DeltaBitPackDecoder::::new(); + decoder.set_data(corrupted_buffer, 32).unwrap(); + let err = decoder.skip(32).unwrap_err(); + 
assert!( + err.to_string() + .contains("Invalid delta bit width 33 which is larger than expected 32"), + "{}", + err + ); + } } From 78bd20446c1cc54b398e925021933c14cb7784be Mon Sep 17 00:00:00 2001 From: Ed Seidl Date: Sun, 26 Oct 2025 03:59:22 -0700 Subject: [PATCH 09/19] [thrift-remodel] Use `thrift_enum` macro for `ConvertedType` (#8680) # Which issue does this PR close? - Part of #5853. # Rationale for this change While converting to the new Thrift model, the `ConvertedType` enum was done manually due to the `NONE` variant, which used the discriminant of `0`. This PR changes that to `-1` which allows the `thrift_enum` macro to be used instead. This improves code maintainability. # What changes are included in this PR? See above. # Are these changes tested? Covered by existing tests # Are there any user-facing changes? No, this only changes the discriminant value for a unit variant enum. --- parquet/src/basic.rs | 204 ++++++++++++++++--------------------------- 1 file changed, 77 insertions(+), 127 deletions(-) diff --git a/parquet/src/basic.rs b/parquet/src/basic.rs index def69f251581..7d9c1df37b3f 100644 --- a/parquet/src/basic.rs +++ b/parquet/src/basic.rs @@ -61,13 +61,10 @@ enum Type { // ---------------------------------------------------------------------- // Mirrors thrift enum `ConvertedType` -// -// Cannot use macros because of added field `None` // TODO(ets): Adding the `NONE` variant to this enum is a bit awkward. We should -// look into removing it and using `Option` instead. Then all of this -// handwritten code could go away. - +// look into removing it and using `Option` instead. +thrift_enum!( /// Common types (converted types) used by frameworks when using Parquet. /// /// This helps map between types in those frameworks to the base types in Parquet. @@ -75,142 +72,101 @@ enum Type { /// /// This struct was renamed from `LogicalType` in version 4.0.0. /// If targeting Parquet format 2.4.0 or above, please use [LogicalType] instead. 
-#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[allow(non_camel_case_types)] -pub enum ConvertedType { - /// No type conversion. - NONE, - /// A BYTE_ARRAY actually contains UTF8 encoded chars. - UTF8, - - /// A map is converted as an optional field containing a repeated key/value pair. - MAP, - - /// A key/value pair is converted into a group of two fields. - MAP_KEY_VALUE, - - /// A list is converted into an optional field containing a repeated field for its - /// values. - LIST, - - /// An enum is converted into a binary field - ENUM, - - /// A decimal value. - /// This may be used to annotate binary or fixed primitive types. The - /// underlying byte array stores the unscaled value encoded as two's - /// complement using big-endian byte order (the most significant byte is the - /// zeroth element). - /// - /// This must be accompanied by a (maximum) precision and a scale in the - /// SchemaElement. The precision specifies the number of digits in the decimal - /// and the scale stores the location of the decimal point. For example 1.23 - /// would have precision 3 (3 total digits) and scale 2 (the decimal point is - /// 2 digits over). - DECIMAL, +enum ConvertedType { + /// Not defined in the spec, used internally to indicate no type conversion + NONE = -1; - /// A date stored as days since Unix epoch, encoded as the INT32 physical type. - DATE, + /// A BYTE_ARRAY actually contains UTF8 encoded chars. + UTF8 = 0; - /// The total number of milliseconds since midnight. The value is stored as an INT32 - /// physical type. - TIME_MILLIS, + /// A map is converted as an optional field containing a repeated key/value pair. + MAP = 1; - /// The total number of microseconds since midnight. The value is stored as an INT64 - /// physical type. - TIME_MICROS, + /// A key/value pair is converted into a group of two fields. + MAP_KEY_VALUE = 2; - /// Date and time recorded as milliseconds since the Unix epoch. - /// Recorded as a physical type of INT64. 
- TIMESTAMP_MILLIS, + /// A list is converted into an optional field containing a repeated field for its + /// values. + LIST = 3; - /// Date and time recorded as microseconds since the Unix epoch. - /// The value is stored as an INT64 physical type. - TIMESTAMP_MICROS, + /// An enum is converted into a BYTE_ARRAY field + ENUM = 4; - /// An unsigned 8 bit integer value stored as INT32 physical type. - UINT_8, + /// A decimal value. + /// + /// This may be used to annotate BYTE_ARRAY or FIXED_LEN_BYTE_ARRAY primitive + /// types. The underlying byte array stores the unscaled value encoded as two's + /// complement using big-endian byte order (the most significant byte is the + /// zeroth element). The value of the decimal is the value * 10^{-scale}. + /// + /// This must be accompanied by a (maximum) precision and a scale in the + /// SchemaElement. The precision specifies the number of digits in the decimal + /// and the scale stores the location of the decimal point. For example 1.23 + /// would have precision 3 (3 total digits) and scale 2 (the decimal point is + /// 2 digits over). + DECIMAL = 5; - /// An unsigned 16 bit integer value stored as INT32 physical type. - UINT_16, + /// A date stored as days since Unix epoch, encoded as the INT32 physical type. + DATE = 6; - /// An unsigned 32 bit integer value stored as INT32 physical type. - UINT_32, + /// The total number of milliseconds since midnight. The value is stored as an INT32 + /// physical type. + TIME_MILLIS = 7; - /// An unsigned 64 bit integer value stored as INT64 physical type. - UINT_64, + /// The total number of microseconds since midnight. The value is stored as an INT64 + /// physical type. + TIME_MICROS = 8; - /// A signed 8 bit integer value stored as INT32 physical type. - INT_8, + /// Date and time recorded as milliseconds since the Unix epoch. + /// Recorded as a physical type of INT64. + TIMESTAMP_MILLIS = 9; - /// A signed 16 bit integer value stored as INT32 physical type. 
- INT_16, + /// Date and time recorded as microseconds since the Unix epoch. + /// The value is stored as an INT64 physical type. + TIMESTAMP_MICROS = 10; - /// A signed 32 bit integer value stored as INT32 physical type. - INT_32, + /// An unsigned 8 bit integer value stored as INT32 physical type. + UINT_8 = 11; - /// A signed 64 bit integer value stored as INT64 physical type. - INT_64, + /// An unsigned 16 bit integer value stored as INT32 physical type. + UINT_16 = 12; - /// A JSON document embedded within a single UTF8 column. - JSON, + /// An unsigned 32 bit integer value stored as INT32 physical type. + UINT_32 = 13; - /// A BSON document embedded within a single BINARY column. - BSON, + /// An unsigned 64 bit integer value stored as INT64 physical type. + UINT_64 = 14; - /// An interval of time. - /// - /// This type annotates data stored as a FIXED_LEN_BYTE_ARRAY of length 12. - /// This data is composed of three separate little endian unsigned integers. - /// Each stores a component of a duration of time. The first integer identifies - /// the number of months associated with the duration, the second identifies - /// the number of days associated with the duration and the third identifies - /// the number of milliseconds associated with the provided duration. - /// This duration of time is independent of any particular timezone or date. - INTERVAL, -} + /// A signed 8 bit integer value stored as INT32 physical type. 
+ INT_8 = 15; -impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for ConvertedType { - fn read_thrift(prot: &mut R) -> Result { - let val = prot.read_i32()?; - Ok(match val { - 0 => Self::UTF8, - 1 => Self::MAP, - 2 => Self::MAP_KEY_VALUE, - 3 => Self::LIST, - 4 => Self::ENUM, - 5 => Self::DECIMAL, - 6 => Self::DATE, - 7 => Self::TIME_MILLIS, - 8 => Self::TIME_MICROS, - 9 => Self::TIMESTAMP_MILLIS, - 10 => Self::TIMESTAMP_MICROS, - 11 => Self::UINT_8, - 12 => Self::UINT_16, - 13 => Self::UINT_32, - 14 => Self::UINT_64, - 15 => Self::INT_8, - 16 => Self::INT_16, - 17 => Self::INT_32, - 18 => Self::INT_64, - 19 => Self::JSON, - 20 => Self::BSON, - 21 => Self::INTERVAL, - _ => return Err(general_err!("Unexpected ConvertedType {}", val)), - }) - } -} + /// A signed 16 bit integer value stored as INT32 physical type. + INT_16 = 16; -impl WriteThrift for ConvertedType { - const ELEMENT_TYPE: ElementType = ElementType::I32; + /// A signed 32 bit integer value stored as INT32 physical type. + INT_32 = 17; - fn write_thrift(&self, writer: &mut ThriftCompactOutputProtocol) -> Result<()> { - // because we've added NONE, the variant values are off by 1, so correct that here - writer.write_i32(*self as i32 - 1) - } -} + /// A signed 64 bit integer value stored as INT64 physical type. + INT_64 = 18; + + /// A JSON document embedded within a single UTF8 column. + JSON = 19; -write_thrift_field!(ConvertedType, FieldType::I32); + /// A BSON document embedded within a single BINARY column. + BSON = 20; + + /// An interval of time + /// + /// This type annotates data stored as a FIXED_LEN_BYTE_ARRAY of length 12. + /// This data is composed of three separate little endian unsigned integers. + /// Each stores a component of a duration of time. 
The first integer identifies + /// the number of months associated with the duration, the second identifies + /// the number of days associated with the duration and the third identifies + /// the number of milliseconds associated with the provided duration. + /// This duration of time is independent of any particular timezone or date. + INTERVAL = 21; +} +); // ---------------------------------------------------------------------- // Mirrors thrift union `TimeUnit` @@ -1327,12 +1283,6 @@ impl WriteThrift for ColumnOrder { // ---------------------------------------------------------------------- // Display handlers -impl fmt::Display for ConvertedType { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{self:?}") - } -} - impl fmt::Display for Compression { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self:?}") From c14902727cf41fa73ed23acebca82c4eefed5d5f Mon Sep 17 00:00:00 2001 From: Vegard Stikbakke Date: Sun, 26 Oct 2025 12:44:58 +0100 Subject: [PATCH 10/19] Cast support for RunEndEncoded arrays (#8589) # Which issue does this PR close? - Contributes towards the RunEndEncoded (REE) epic #3520, but there is no specific issue for casting. - Replaces PRs https://github.com/apache/arrow-rs/pull/7713 and https://github.com/apache/arrow-rs/pull/8384. # Rationale for this change This PR implements casting support for RunEndEncoded arrays in Apache Arrow. # What changes are included in this PR? - `run_end_encoded_cast` in `arrow-cast/src/cast/run_array.rs` - `cast_to_run_end_encoded` in `arrow-cast/src/cast/run_array.rs` - Tests in `arrow-cast/src/cast/mod.rs` # Are these changes tested? Yes! # Are there any user-facing changes? 
No breaking changes, just new functionality --------- Co-authored-by: Richard Baah Co-authored-by: Andrew Lamb --- arrow-cast/Cargo.toml | 1 + arrow-cast/src/cast/mod.rs | 465 ++++++++++++++++++++++++++++++- arrow-cast/src/cast/run_array.rs | 164 +++++++++++ 3 files changed, 622 insertions(+), 8 deletions(-) create mode 100644 arrow-cast/src/cast/run_array.rs diff --git a/arrow-cast/Cargo.toml b/arrow-cast/Cargo.toml index 12da1af79fe0..f3309783fb38 100644 --- a/arrow-cast/Cargo.toml +++ b/arrow-cast/Cargo.toml @@ -43,6 +43,7 @@ force_validate = [] arrow-array = { workspace = true } arrow-buffer = { workspace = true } arrow-data = { workspace = true } +arrow-ord = { workspace = true } arrow-schema = { workspace = true } arrow-select = { workspace = true } chrono = { workspace = true } diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index fe38298b017c..bb3247ca3c3c 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -41,11 +41,13 @@ mod decimal; mod dictionary; mod list; mod map; +mod run_array; mod string; use crate::cast::decimal::*; use crate::cast::dictionary::*; use crate::cast::list::*; use crate::cast::map::*; +use crate::cast::run_array::*; use crate::cast::string::*; use arrow_buffer::IntervalMonthDayNano; @@ -139,6 +141,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { can_cast_types(from_value_type, to_value_type) } (Dictionary(_, value_type), _) => can_cast_types(value_type, to_type), + (RunEndEncoded(_, value_type), _) => can_cast_types(value_type.data_type(), to_type), + (_, RunEndEncoded(_, value_type)) => can_cast_types(from_type, value_type.data_type()), (_, Dictionary(_, value_type)) => can_cast_types(from_type, value_type), (List(list_from) | LargeList(list_from), List(list_to) | LargeList(list_to)) => { can_cast_types(list_from.data_type(), list_to.data_type()) @@ -791,6 +795,37 @@ pub fn cast_with_options( | Map(_, _) | Dictionary(_, _), ) => Ok(new_null_array(to_type, 
array.len())), + (RunEndEncoded(index_type, _), _) => match index_type.data_type() { + Int16 => run_end_encoded_cast::(array, to_type, cast_options), + Int32 => run_end_encoded_cast::(array, to_type, cast_options), + Int64 => run_end_encoded_cast::(array, to_type, cast_options), + _ => Err(ArrowError::CastError(format!( + "Casting from run end encoded type {from_type:?} to {to_type:?} not supported", + ))), + }, + (_, RunEndEncoded(index_type, value_type)) => { + let array_ref = make_array(array.to_data()); + match index_type.data_type() { + Int16 => cast_to_run_end_encoded::( + &array_ref, + value_type.data_type(), + cast_options, + ), + Int32 => cast_to_run_end_encoded::( + &array_ref, + value_type.data_type(), + cast_options, + ), + Int64 => cast_to_run_end_encoded::( + &array_ref, + value_type.data_type(), + cast_options, + ), + _ => Err(ArrowError::CastError(format!( + "Casting from type {from_type:?} to run end encoded type {to_type:?} not supported", + ))), + } + } (Dictionary(index_type, _), _) => match **index_type { Int8 => dictionary_cast::(array, to_type, cast_options), Int16 => dictionary_cast::(array, to_type, cast_options), @@ -2640,10 +2675,14 @@ where #[cfg(test)] mod tests { use super::*; + use DataType::*; + use arrow_array::{Int64Array, RunArray, StringArray}; use arrow_buffer::i256; use arrow_buffer::{Buffer, IntervalDayTime, NullBuffer}; + use arrow_schema::{DataType, Field}; use chrono::NaiveDate; use half::f16; + use std::sync::Arc; #[derive(Clone)] struct DecimalCastTestConfig { @@ -7794,8 +7833,6 @@ mod tests { #[test] fn test_cast_utf8_dict() { // FROM a dictionary with of Utf8 values - use DataType::*; - let mut builder = StringDictionaryBuilder::::new(); builder.append("one").unwrap(); builder.append_null(); @@ -7850,7 +7887,6 @@ mod tests { #[test] fn test_cast_dict_to_dict_bad_index_value_primitive() { - use DataType::*; // test converting from an array that has indexes of a type // that are out of bounds for a particular other kind 
of // index. @@ -7878,7 +7914,6 @@ mod tests { #[test] fn test_cast_dict_to_dict_bad_index_value_utf8() { - use DataType::*; // Same test as test_cast_dict_to_dict_bad_index_value but use // string values (and encode the expected behavior here); @@ -7907,8 +7942,6 @@ mod tests { #[test] fn test_cast_primitive_dict() { // FROM a dictionary with of INT32 values - use DataType::*; - let mut builder = PrimitiveDictionaryBuilder::::new(); builder.append(1).unwrap(); builder.append_null(); @@ -7929,8 +7962,6 @@ mod tests { #[test] fn test_cast_primitive_array_to_dict() { - use DataType::*; - let mut builder = PrimitiveBuilder::::new(); builder.append_value(1); builder.append_null(); @@ -11417,4 +11448,422 @@ mod tests { "Invalid argument error: -1.0 is too small to store in a Decimal32 of precision 1. Min is -0.9" ); } + + #[test] + fn test_run_end_encoded_to_primitive() { + // Create a RunEndEncoded array: [1, 1, 2, 2, 2, 3] + let run_ends = Int32Array::from(vec![2, 5, 6]); + let values = Int32Array::from(vec![1, 2, 3]); + let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); + let array_ref = Arc::new(run_array) as ArrayRef; + // Cast to Int64 + let cast_result = cast(&array_ref, &DataType::Int64).unwrap(); + // Verify the result is a RunArray with Int64 values + let result_run_array = cast_result.as_any().downcast_ref::().unwrap(); + assert_eq!( + result_run_array.values(), + &[1i64, 1i64, 2i64, 2i64, 2i64, 3i64] + ); + } + + #[test] + fn test_run_end_encoded_to_string() { + let run_ends = Int32Array::from(vec![2, 3, 5]); + let values = Int32Array::from(vec![10, 20, 30]); + let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); + let array_ref = Arc::new(run_array) as ArrayRef; + + // Cast to String + let cast_result = cast(&array_ref, &DataType::Utf8).unwrap(); + + // Verify the result is a RunArray with String values + let result_array = cast_result.as_any().downcast_ref::().unwrap(); + // Check that values are correct + 
assert_eq!(result_array.value(0), "10"); + assert_eq!(result_array.value(1), "10"); + assert_eq!(result_array.value(2), "20"); + } + + #[test] + fn test_primitive_to_run_end_encoded() { + // Create an Int32 array with repeated values: [1, 1, 2, 2, 2, 3] + let source_array = Int32Array::from(vec![1, 1, 2, 2, 2, 3]); + let array_ref = Arc::new(source_array) as ArrayRef; + + // Cast to RunEndEncoded + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int32, false)), + Arc::new(Field::new("values", DataType::Int32, true)), + ); + let cast_result = cast(&array_ref, &target_type).unwrap(); + + // Verify the result is a RunArray + let result_run_array = cast_result + .as_any() + .downcast_ref::>() + .unwrap(); + + // Check run structure: runs should end at positions [2, 5, 6] + assert_eq!(result_run_array.run_ends().values(), &[2, 5, 6]); + + // Check values: should be [1, 2, 3] + let values_array = result_run_array.values().as_primitive::(); + assert_eq!(values_array.values(), &[1, 2, 3]); + } + + #[test] + fn test_primitive_to_run_end_encoded_with_nulls() { + let source_array = Int32Array::from(vec![ + Some(1), + Some(1), + None, + None, + Some(2), + Some(2), + Some(3), + Some(3), + None, + None, + Some(4), + Some(4), + Some(5), + Some(5), + None, + None, + ]); + let array_ref = Arc::new(source_array) as ArrayRef; + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int32, false)), + Arc::new(Field::new("values", DataType::Int32, true)), + ); + let cast_result = cast(&array_ref, &target_type).unwrap(); + let result_run_array = cast_result + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!( + result_run_array.run_ends().values(), + &[2, 4, 6, 8, 10, 12, 14, 16] + ); + assert_eq!( + result_run_array + .values() + .as_primitive::() + .values(), + &[1, 0, 2, 3, 0, 4, 5, 0] + ); + assert_eq!(result_run_array.values().null_count(), 3); + } + + #[test] + fn 
test_primitive_to_run_end_encoded_with_nulls_consecutive() { + let source_array = Int64Array::from(vec![ + Some(1), + Some(1), + None, + None, + None, + None, + None, + None, + None, + None, + Some(4), + Some(20), + Some(500), + Some(500), + None, + None, + ]); + let array_ref = Arc::new(source_array) as ArrayRef; + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int16, false)), + Arc::new(Field::new("values", DataType::Int64, true)), + ); + let cast_result = cast(&array_ref, &target_type).unwrap(); + let result_run_array = cast_result + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!( + result_run_array.run_ends().values(), + &[2, 10, 11, 12, 14, 16] + ); + assert_eq!( + result_run_array + .values() + .as_primitive::() + .values(), + &[1, 0, 4, 20, 500, 0] + ); + assert_eq!(result_run_array.values().null_count(), 2); + } + + #[test] + fn test_string_to_run_end_encoded() { + // Create a String array with repeated values: ["a", "a", "b", "c", "c"] + let source_array = StringArray::from(vec!["a", "a", "b", "c", "c"]); + let array_ref = Arc::new(source_array) as ArrayRef; + + // Cast to RunEndEncoded + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int32, false)), + Arc::new(Field::new("values", DataType::Utf8, true)), + ); + let cast_result = cast(&array_ref, &target_type).unwrap(); + + // Verify the result is a RunArray + let result_run_array = cast_result + .as_any() + .downcast_ref::>() + .unwrap(); + + // Check run structure: runs should end at positions [2, 3, 5] + assert_eq!(result_run_array.run_ends().values(), &[2, 3, 5]); + + // Check values: should be ["a", "b", "c"] + let values_array = result_run_array.values().as_string::(); + assert_eq!(values_array.value(0), "a"); + assert_eq!(values_array.value(1), "b"); + assert_eq!(values_array.value(2), "c"); + } + + #[test] + fn test_empty_array_to_run_end_encoded() { + // Create an empty Int32 array + let source_array = 
Int32Array::from(Vec::::new()); + let array_ref = Arc::new(source_array) as ArrayRef; + + // Cast to RunEndEncoded + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int32, false)), + Arc::new(Field::new("values", DataType::Int32, true)), + ); + let cast_result = cast(&array_ref, &target_type).unwrap(); + + // Verify the result is an empty RunArray + let result_run_array = cast_result + .as_any() + .downcast_ref::>() + .unwrap(); + + // Check that both run_ends and values are empty + assert_eq!(result_run_array.run_ends().len(), 0); + assert_eq!(result_run_array.values().len(), 0); + } + + #[test] + fn test_run_end_encoded_with_nulls() { + // Create a RunEndEncoded array with nulls: [1, 1, null, 2, 2] + let run_ends = Int32Array::from(vec![2, 3, 5]); + let values = Int32Array::from(vec![Some(1), None, Some(2)]); + let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); + let array_ref = Arc::new(run_array) as ArrayRef; + + // Cast to String + let cast_result = cast(&array_ref, &DataType::Utf8).unwrap(); + + // Verify the result preserves nulls + let result_run_array = cast_result.as_any().downcast_ref::().unwrap(); + assert_eq!(result_run_array.value(0), "1"); + assert!(result_run_array.is_null(2)); + assert_eq!(result_run_array.value(4), "2"); + } + + #[test] + fn test_different_index_types() { + // Test with Int16 index type + let source_array = Int32Array::from(vec![1, 1, 2, 3, 3]); + let array_ref = Arc::new(source_array) as ArrayRef; + + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int16, false)), + Arc::new(Field::new("values", DataType::Int32, true)), + ); + let cast_result = cast(&array_ref, &target_type).unwrap(); + assert_eq!(cast_result.data_type(), &target_type); + + // Verify the cast worked correctly: values are [1, 2, 3] + // and run-ends are [2, 3, 5] + let run_array = cast_result + .as_any() + .downcast_ref::>() + .unwrap(); + 
assert_eq!(run_array.values().as_primitive::().value(0), 1); + assert_eq!(run_array.values().as_primitive::().value(1), 2); + assert_eq!(run_array.values().as_primitive::().value(2), 3); + assert_eq!(run_array.run_ends().values(), &[2i16, 3i16, 5i16]); + + // Test again with Int64 index type + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int64, false)), + Arc::new(Field::new("values", DataType::Int32, true)), + ); + let cast_result = cast(&array_ref, &target_type).unwrap(); + assert_eq!(cast_result.data_type(), &target_type); + + // Verify the cast worked correctly: values are [1, 2, 3] + // and run-ends are [2, 3, 5] + let run_array = cast_result + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(run_array.values().as_primitive::().value(0), 1); + assert_eq!(run_array.values().as_primitive::().value(1), 2); + assert_eq!(run_array.values().as_primitive::().value(2), 3); + assert_eq!(run_array.run_ends().values(), &[2i64, 3i64, 5i64]); + } + + #[test] + fn test_unsupported_cast_to_run_end_encoded() { + // Create a Struct array - complex nested type that might not be supported + let field = Field::new("item", DataType::Int32, false); + let struct_array = StructArray::from(vec![( + Arc::new(field), + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + )]); + let array_ref = Arc::new(struct_array) as ArrayRef; + + // This should fail because: + // 1. The target type is not RunEndEncoded + // 2. 
The target type is not supported for casting from StructArray + let cast_result = cast(&array_ref, &DataType::FixedSizeBinary(10)); + + // Expect this to fail + assert!(cast_result.is_err()); + } + + /// Test casting RunEndEncoded to RunEndEncoded should fail + #[test] + fn test_cast_run_end_encoded_int64_to_int16_should_fail() { + // Construct a valid REE array with Int64 run-ends + let run_ends = Int64Array::from(vec![100_000, 400_000, 700_000]); // values too large for Int16 + let values = StringArray::from(vec!["a", "b", "c"]); + + let ree_array = RunArray::::try_new(&run_ends, &values).unwrap(); + let array_ref = Arc::new(ree_array) as ArrayRef; + + // Attempt to cast to RunEndEncoded + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int16, false)), + Arc::new(Field::new("values", DataType::Utf8, true)), + ); + let cast_options = CastOptions { + safe: false, // This should make it fail instead of returning nulls + format_options: FormatOptions::default(), + }; + + // This should fail due to run-end overflow + let result: Result, ArrowError> = + cast_with_options(&array_ref, &target_type, &cast_options); + + let e = result.expect_err("Cast should have failed but succeeded"); + assert!( + e.to_string() + .contains("Cast error: Can't cast value 100000 to type Int16") + ); + } + + #[test] + fn test_cast_run_end_encoded_int64_to_int16_with_safe_should_fail_with_null_invalid_error() { + // Construct a valid REE array with Int64 run-ends + let run_ends = Int64Array::from(vec![100_000, 400_000, 700_000]); // values too large for Int16 + let values = StringArray::from(vec!["a", "b", "c"]); + + let ree_array = RunArray::::try_new(&run_ends, &values).unwrap(); + let array_ref = Arc::new(ree_array) as ArrayRef; + + // Attempt to cast to RunEndEncoded + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int16, false)), + Arc::new(Field::new("values", DataType::Utf8, true)), + ); + let cast_options = 
CastOptions { + safe: true, + format_options: FormatOptions::default(), + }; + + // This fails even though safe is true because the run_ends array has null values + let result: Result, ArrowError> = + cast_with_options(&array_ref, &target_type, &cast_options); + let e = result.expect_err("Cast should have failed but succeeded"); + assert!( + e.to_string() + .contains("Invalid argument error: Found null values in run_ends array. The run_ends array should not have null values.") + ); + } + + /// Test casting RunEndEncoded to RunEndEncoded should succeed + #[test] + fn test_cast_run_end_encoded_int16_to_int64_should_succeed() { + // Construct a valid REE array with Int16 run-ends + let run_ends = Int16Array::from(vec![2, 5, 8]); // values that fit in Int16 + let values = StringArray::from(vec!["a", "b", "c"]); + + let ree_array = RunArray::::try_new(&run_ends, &values).unwrap(); + let array_ref = Arc::new(ree_array) as ArrayRef; + + // Attempt to cast to RunEndEncoded (upcast should succeed) + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int64, false)), + Arc::new(Field::new("values", DataType::Utf8, true)), + ); + let cast_options = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + + // This should succeed due to valid upcast + let result: Result, ArrowError> = + cast_with_options(&array_ref, &target_type, &cast_options); + + let array_ref = result.expect("Cast should have succeeded but failed"); + // Downcast to RunArray + let run_array = array_ref + .as_any() + .downcast_ref::>() + .unwrap(); + + // Verify the cast worked correctly + // Assert the values were cast correctly + assert_eq!(run_array.run_ends().values(), &[2i64, 5i64, 8i64]); + assert_eq!(run_array.values().as_string::().value(0), "a"); + assert_eq!(run_array.values().as_string::().value(1), "b"); + assert_eq!(run_array.values().as_string::().value(2), "c"); + } + + #[test] + fn 
test_cast_run_end_encoded_dictionary_to_run_end_encoded() { + // Construct a valid dictionary encoded array + let values = StringArray::from_iter([Some("a"), Some("b"), Some("c")]); + let keys = UInt64Array::from_iter(vec![1, 1, 1, 0, 0, 0, 2, 2, 2]); + let array_ref = Arc::new(DictionaryArray::new(keys, Arc::new(values))) as ArrayRef; + + // Attempt to cast to RunEndEncoded + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int64, false)), + Arc::new(Field::new("values", DataType::Utf8, true)), + ); + let cast_options = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + + // This should succeed + let result = cast_with_options(&array_ref, &target_type, &cast_options) + .expect("Cast should have succeeded but failed"); + + // Verify the cast worked correctly + // Assert the values were cast correctly + let run_array = result + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(run_array.values().as_string::().value(0), "b"); + assert_eq!(run_array.values().as_string::().value(1), "a"); + assert_eq!(run_array.values().as_string::().value(2), "c"); + + // Verify the run-ends were cast correctly (run ends at 3, 6, 9) + assert_eq!(run_array.run_ends().values(), &[3i64, 6i64, 9i64]); + } } diff --git a/arrow-cast/src/cast/run_array.rs b/arrow-cast/src/cast/run_array.rs new file mode 100644 index 000000000000..8d70afef3ab6 --- /dev/null +++ b/arrow-cast/src/cast/run_array.rs @@ -0,0 +1,164 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::cast::*; +use arrow_ord::partition::partition; + +/// Attempts to cast a `RunArray` with index type K into +/// `to_type` for supported types. +pub(crate) fn run_end_encoded_cast( + array: &dyn Array, + to_type: &DataType, + cast_options: &CastOptions, +) -> Result { + match array.data_type() { + DataType::RunEndEncoded(_, _) => { + let run_array = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| ArrowError::CastError("Expected RunArray".to_string()))?; + + let values = run_array.values(); + + match to_type { + // Stay as RunEndEncoded, cast only the values + DataType::RunEndEncoded(target_index_field, target_value_field) => { + let cast_values = + cast_with_options(values, target_value_field.data_type(), cast_options)?; + + let run_ends_array = PrimitiveArray::::from_iter_values( + run_array.run_ends().values().iter().copied(), + ); + let cast_run_ends = cast_with_options( + &run_ends_array, + target_index_field.data_type(), + cast_options, + )?; + let new_run_array: ArrayRef = match target_index_field.data_type() { + DataType::Int16 => { + let re = cast_run_ends.as_primitive::(); + Arc::new(RunArray::::try_new(re, cast_values.as_ref())?) + } + DataType::Int32 => { + let re = cast_run_ends.as_primitive::(); + Arc::new(RunArray::::try_new(re, cast_values.as_ref())?) + } + DataType::Int64 => { + let re = cast_run_ends.as_primitive::(); + Arc::new(RunArray::::try_new(re, cast_values.as_ref())?) 
+ } + _ => { + return Err(ArrowError::CastError( + "Run-end type must be i16, i32, or i64".to_string(), + )); + } + }; + Ok(Arc::new(new_run_array)) + } + + // Expand to logical form + _ => { + let run_ends = run_array.run_ends().values().to_vec(); + let mut indices = Vec::with_capacity(run_array.run_ends().len()); + let mut physical_idx: usize = 0; + for logical_idx in 0..run_array.run_ends().len() { + // If the logical index is equal to the (next) run end, increment the physical index, + // since we are at the end of a run. + if logical_idx == run_ends[physical_idx].as_usize() { + physical_idx += 1; + } + indices.push(physical_idx as i32); + } + + let taken = take(&values, &Int32Array::from_iter_values(indices), None)?; + if taken.data_type() != to_type { + cast_with_options(taken.as_ref(), to_type, cast_options) + } else { + Ok(taken) + } + } + } + } + + _ => Err(ArrowError::CastError(format!( + "Cannot cast array of type {:?} to RunEndEncodedArray", + array.data_type() + ))), + } +} + +/// Attempts to encode an array into a `RunArray` with index type K +/// and value type `value_type` +pub(crate) fn cast_to_run_end_encoded( + array: &ArrayRef, + value_type: &DataType, + cast_options: &CastOptions, +) -> Result { + let mut run_ends_builder = PrimitiveBuilder::::new(); + + // Cast the input array to the target value type if necessary + let cast_array = if array.data_type() == value_type { + array + } else { + &cast_with_options(array, value_type, cast_options)? 
+ }; + + // Return early if the array to cast is empty + if cast_array.is_empty() { + let empty_run_ends = run_ends_builder.finish(); + let empty_values = make_array(ArrayData::new_empty(value_type)); + return Ok(Arc::new(RunArray::::try_new( + &empty_run_ends, + empty_values.as_ref(), + )?)); + } + + // REE arrays are handled by run_end_encoded_cast + if let DataType::RunEndEncoded(_, _) = array.data_type() { + return Err(ArrowError::CastError( + "Source array is already a RunEndEncoded array, should have been handled by run_end_encoded_cast".to_string() + )); + } + + // Partition the array to identify runs of consecutive equal values + let partitions = partition(&[Arc::clone(cast_array)])?; + let mut run_ends = Vec::new(); + let mut values_indexes = Vec::new(); + let mut last_partition_end = 0; + for partition in partitions.ranges() { + values_indexes.push(last_partition_end); + run_ends.push(partition.end); + last_partition_end = partition.end; + } + + // Build the run_ends array + for run_end in run_ends { + run_ends_builder.append_value(K::Native::from_usize(run_end).ok_or_else(|| { + ArrowError::CastError(format!("Run end index out of range: {}", run_end)) + })?); + } + let run_ends_array = run_ends_builder.finish(); + // Build the values array by taking elements at the run start positions + let indices = PrimitiveArray::::from_iter_values( + values_indexes.iter().map(|&idx| idx as u32), + ); + let values_array = take(&cast_array, &indices, None)?; + + // Create and return the RunArray + let run_array = RunArray::::try_new(&run_ends_array, values_array.as_ref())?; + Ok(Arc::new(run_array)) +} From 7f3d3aef30613b039939efec8780ab4eaa8cb98d Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Mon, 27 Oct 2025 20:21:28 +0200 Subject: [PATCH 11/19] perf: override `count`, `nth`, `nth_back`, `last` and `max` for BitIterator (#8696) # Which issue does this PR close? 
N/A # Rationale for this change overriding these functions improves performance over the fallback implementation # What changes are included in this PR? Override implementation of: - `count` which is not optimized away even when `ExactSizeIterator` is implemented - `nth` to avoid calling `next` `n + 1` times (which is also used when doing `.skip`) - `nth_back` - `last` - `max` # Are these changes tested? Yes, I've added a lot of tests # Are there any user-facing changes? Nope --- arrow-buffer/src/util/bit_iterator.rs | 516 ++++++++++++++++++++++++++ 1 file changed, 516 insertions(+) diff --git a/arrow-buffer/src/util/bit_iterator.rs b/arrow-buffer/src/util/bit_iterator.rs index c7f6f94fb869..0aa94a5d4dc1 100644 --- a/arrow-buffer/src/util/bit_iterator.rs +++ b/arrow-buffer/src/util/bit_iterator.rs @@ -23,6 +23,7 @@ use crate::bit_util::{ceil, get_bit_raw}; /// Iterator over the bits within a packed bitmask /// /// To efficiently iterate over just the set bits see [`BitIndexIterator`] and [`BitSliceIterator`] +#[derive(Clone)] pub struct BitIterator<'a> { buffer: &'a [u8], current_offset: usize, @@ -71,6 +72,71 @@ impl Iterator for BitIterator<'_> { let remaining_bits = self.end_offset - self.current_offset; (remaining_bits, Some(remaining_bits)) } + + fn count(self) -> usize + where + Self: Sized, + { + self.len() + } + + fn nth(&mut self, n: usize) -> Option { + // Check if we can advance to the desired offset. 
+ // When n is 0 it means we want the next() value + // and when n is 1 we want the next().next() value + // so adding n to the current offset and not n - 1 + match self.current_offset.checked_add(n) { + // Yes, and still within bounds + Some(new_offset) if new_offset < self.end_offset => { + self.current_offset = new_offset; + } + + // Either overflow or would exceed end_offset + _ => { + self.current_offset = self.end_offset; + return None; + } + } + + self.next() + } + + fn last(mut self) -> Option { + // If already at the end, return None + if self.current_offset == self.end_offset { + return None; + } + + // Go to the one before the last bit + self.current_offset = self.end_offset - 1; + + // Return the last bit + self.next() + } + + fn max(self) -> Option + where + Self: Sized, + Self::Item: Ord, + { + if self.current_offset == self.end_offset { + return None; + } + + // true is greater than false so we only need to check if there's any true bit + let mut bit_index_iter = BitIndexIterator::new( + self.buffer, + self.current_offset, + self.end_offset - self.current_offset, + ); + + if bit_index_iter.next().is_some() { + return Some(true); + } + + // We know the iterator is not empty and there are no set bits so false is the max + Some(false) + } } impl ExactSizeIterator for BitIterator<'_> {} @@ -86,6 +152,27 @@ impl DoubleEndedIterator for BitIterator<'_> { let v = unsafe { get_bit_raw(self.buffer.as_ptr(), self.end_offset) }; Some(v) } + + fn nth_back(&mut self, n: usize) -> Option { + // Check if we can advance to the desired offset. 
+ // When n is 0 it means we want the next_back() value + // and when n is 1 we want the next_back().next_back() value + // so subtracting n to the current offset and not n - 1 + match self.end_offset.checked_sub(n) { + // Yes, and still within bounds + Some(new_offset) if self.current_offset < new_offset => { + self.end_offset = new_offset; + } + + // Either underflow or would exceed current_offset + _ => { + self.current_offset = self.end_offset; + return None; + } + } + + self.next_back() + } } /// Iterator of contiguous ranges of set bits within a provided packed bitmask @@ -327,6 +414,12 @@ pub fn try_for_each_valid_idx Result<(), E>>( #[cfg(test)] mod tests { use super::*; + use crate::BooleanBuffer; + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; + use std::fmt::Debug; + use std::iter::Copied; + use std::slice::Iter; #[test] fn test_bit_iterator_size_hint() { @@ -486,4 +579,427 @@ mod tests { .collect(); assert_eq!(result, expected); } + + trait SharedBetweenBitIteratorAndSliceIter: + ExactSizeIterator + DoubleEndedIterator + { + } + impl + DoubleEndedIterator> + SharedBetweenBitIteratorAndSliceIter for T + { + } + + fn get_bit_iterator_cases() -> impl Iterator)> { + let mut rng = StdRng::seed_from_u64(42); + + [0, 1, 6, 8, 100, 164] + .map(|len| { + let source = (0..len).map(|_| rng.random_bool(0.5)).collect::>(); + + (BooleanBuffer::from(source.as_slice()), source) + }) + .into_iter() + } + + fn setup_and_assert( + setup_iters: impl Fn(&mut dyn SharedBetweenBitIteratorAndSliceIter), + assert_fn: impl Fn(BitIterator, Copied>), + ) { + for (boolean_buffer, source) in get_bit_iterator_cases() { + // Not using `boolean_buffer.iter()` in case the implementation change to not call BitIterator internally + // in which case the test would not test what it intends to test + let mut actual = BitIterator::new(boolean_buffer.values(), 0, boolean_buffer.len()); + let mut expected = source.iter().copied(); + + setup_iters(&mut actual); + setup_iters(&mut 
expected); + + assert_fn(actual, expected); + } + } + + /// Trait representing an operation on a BitIterator + /// that can be compared against a slice iterator + trait BitIteratorOp { + /// What the operation returns (e.g. Option for last/max, usize for count, etc) + type Output: PartialEq + Debug; + + /// The name of the operation, used for error messages + const NAME: &'static str; + + /// Get the value of the operation for the provided iterator + /// This will be either a BitIterator or a slice iterator to make sure they produce the same result + fn get_value(iter: T) -> Self::Output; + } + + /// Helper function that will assert that the provided operation + /// produces the same result for both BitIterator and slice iterator + /// under various consumption patterns (e.g. some calls to next/next_back/consume_all/etc) + fn assert_bit_iterator_cases() { + setup_and_assert( + |_iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| {}, + |actual, expected| { + let current_iterator_values: Vec = expected.clone().collect(); + assert_eq!( + O::get_value(actual), + O::get_value(expected), + "Failed on op {} for new iter (left actual, right expected) ({current_iterator_values:?})", + O::NAME + ); + }, + ); + + setup_and_assert( + |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| { + iter.next(); + }, + |actual, expected| { + let current_iterator_values: Vec = expected.clone().collect(); + + assert_eq!( + O::get_value(actual), + O::get_value(expected), + "Failed on op {} for new iter after consuming 1 element from the start (left actual, right expected) ({current_iterator_values:?})", + O::NAME + ); + }, + ); + + setup_and_assert( + |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| { + iter.next_back(); + }, + |actual, expected| { + let current_iterator_values: Vec = expected.clone().collect(); + + assert_eq!( + O::get_value(actual), + O::get_value(expected), + "Failed on op {} for new iter after consuming 1 element from the end (left actual, right expected) 
({current_iterator_values:?})", + O::NAME + ); + }, + ); + + setup_and_assert( + |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| { + iter.next(); + iter.next_back(); + }, + |actual, expected| { + let current_iterator_values: Vec = expected.clone().collect(); + + assert_eq!( + O::get_value(actual), + O::get_value(expected), + "Failed on op {} for new iter after consuming 1 element from start and end (left actual, right expected) ({current_iterator_values:?})", + O::NAME + ); + }, + ); + + setup_and_assert( + |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| { + while iter.len() > 1 { + iter.next(); + } + }, + |actual, expected| { + let current_iterator_values: Vec = expected.clone().collect(); + + assert_eq!( + O::get_value(actual), + O::get_value(expected), + "Failed on op {} for new iter after consuming all from the start but 1 (left actual, right expected) ({current_iterator_values:?})", + O::NAME + ); + }, + ); + + setup_and_assert( + |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| { + while iter.len() > 1 { + iter.next_back(); + } + }, + |actual, expected| { + let current_iterator_values: Vec = expected.clone().collect(); + + assert_eq!( + O::get_value(actual), + O::get_value(expected), + "Failed on op {} for new iter after consuming all from the end but 1 (left actual, right expected) ({current_iterator_values:?})", + O::NAME + ); + }, + ); + + setup_and_assert( + |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| { + while iter.next().is_some() {} + }, + |actual, expected| { + let current_iterator_values: Vec = expected.clone().collect(); + + assert_eq!( + O::get_value(actual), + O::get_value(expected), + "Failed on op {} for new iter after consuming all from the start (left actual, right expected) ({current_iterator_values:?})", + O::NAME + ); + }, + ); + + setup_and_assert( + |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| { + while iter.next_back().is_some() {} + }, + |actual, expected| { + let current_iterator_values: Vec = 
expected.clone().collect(); + + assert_eq!( + O::get_value(actual), + O::get_value(expected), + "Failed on op {} for new iter after consuming all from the end (left actual, right expected) ({current_iterator_values:?})", + O::NAME + ); + }, + ); + } + + #[test] + fn assert_bit_iterator_count() { + struct CountOp; + + impl BitIteratorOp for CountOp { + type Output = usize; + const NAME: &'static str = "count"; + + fn get_value(iter: T) -> Self::Output { + iter.count() + } + } + + assert_bit_iterator_cases::() + } + + #[test] + fn assert_bit_iterator_last() { + struct LastOp; + + impl BitIteratorOp for LastOp { + type Output = Option; + const NAME: &'static str = "last"; + + fn get_value(iter: T) -> Self::Output { + iter.last() + } + } + + assert_bit_iterator_cases::() + } + + #[test] + fn assert_bit_iterator_max() { + struct MaxOp; + + impl BitIteratorOp for MaxOp { + type Output = Option; + const NAME: &'static str = "max"; + + fn get_value(iter: T) -> Self::Output { + iter.max() + } + } + + assert_bit_iterator_cases::() + } + + #[test] + fn assert_bit_iterator_nth_0() { + struct NthOp; + + impl BitIteratorOp for NthOp { + type Output = Option; + const NAME: &'static str = if BACK { "nth_back(0)" } else { "nth(0)" }; + + fn get_value(mut iter: T) -> Self::Output { + if BACK { iter.nth_back(0) } else { iter.nth(0) } + } + } + + assert_bit_iterator_cases::>(); + assert_bit_iterator_cases::>(); + } + + #[test] + fn assert_bit_iterator_nth_1() { + struct NthOp; + + impl BitIteratorOp for NthOp { + type Output = Option; + const NAME: &'static str = if BACK { "nth_back(1)" } else { "nth(1)" }; + + fn get_value(mut iter: T) -> Self::Output { + if BACK { iter.nth_back(1) } else { iter.nth(1) } + } + } + + assert_bit_iterator_cases::>(); + assert_bit_iterator_cases::>(); + } + + #[test] + fn assert_bit_iterator_nth_after_end() { + struct NthOp; + + impl BitIteratorOp for NthOp { + type Output = Option; + const NAME: &'static str = if BACK { + "nth_back(iter.len() + 1)" + } 
else { + "nth(iter.len() + 1)" + }; + + fn get_value(mut iter: T) -> Self::Output { + if BACK { + iter.nth_back(iter.len() + 1) + } else { + iter.nth(iter.len() + 1) + } + } + } + + assert_bit_iterator_cases::>(); + assert_bit_iterator_cases::>(); + } + + #[test] + fn assert_bit_iterator_nth_len() { + struct NthOp; + + impl BitIteratorOp for NthOp { + type Output = Option; + const NAME: &'static str = if BACK { + "nth_back(iter.len())" + } else { + "nth(iter.len())" + }; + + fn get_value(mut iter: T) -> Self::Output { + if BACK { + iter.nth_back(iter.len()) + } else { + iter.nth(iter.len()) + } + } + } + + assert_bit_iterator_cases::>(); + assert_bit_iterator_cases::>(); + } + + #[test] + fn assert_bit_iterator_nth_last() { + struct NthOp; + + impl BitIteratorOp for NthOp { + type Output = Option; + const NAME: &'static str = if BACK { + "nth_back(iter.len().saturating_sub(1))" + } else { + "nth(iter.len().saturating_sub(1))" + }; + + fn get_value(mut iter: T) -> Self::Output { + if BACK { + iter.nth_back(iter.len().saturating_sub(1)) + } else { + iter.nth(iter.len().saturating_sub(1)) + } + } + } + + assert_bit_iterator_cases::>(); + assert_bit_iterator_cases::>(); + } + + #[test] + fn assert_bit_iterator_nth_and_reuse() { + setup_and_assert( + |_| {}, + |actual, expected| { + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + #[allow(clippy::iter_nth_zero)] + let actual_val = actual.nth(0); + #[allow(clippy::iter_nth_zero)] + let expected_val = expected.nth(0); + assert_eq!(actual_val, expected_val, "Failed on nth(0)"); + } + } + + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth(1); + let expected_val = expected.nth(1); + assert_eq!(actual_val, expected_val, "Failed on nth(1)"); + } + } + + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = 
actual.nth(2); + let expected_val = expected.nth(2); + assert_eq!(actual_val, expected_val, "Failed on nth(2)"); + } + } + }, + ); + } + + #[test] + fn assert_bit_iterator_nth_back_and_reuse() { + setup_and_assert( + |_| {}, + |actual, expected| { + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + #[allow(clippy::iter_nth_zero)] + let actual_val = actual.nth_back(0); + let expected_val = expected.nth_back(0); + assert_eq!(actual_val, expected_val, "Failed on nth_back(0)"); + } + } + + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth_back(1); + let expected_val = expected.nth_back(1); + assert_eq!(actual_val, expected_val, "Failed on nth_back(1)"); + } + } + + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth_back(2); + let expected_val = expected.nth_back(2); + assert_eq!(actual_val, expected_val, "Failed on nth_back(2)"); + } + } + }, + ); + } } From 29cd6850619216f862e517ad383cf5258133ba41 Mon Sep 17 00:00:00 2001 From: Alex Stephen <1325798+rambleraptor@users.noreply.github.com> Date: Mon, 27 Oct 2025 11:23:05 -0700 Subject: [PATCH 12/19] Change some panics to errors in parquet decoder (#8602) # Rationale for this change We've caused some unexpected panics from our internal testing. We've put in error checks for all of these so that they don't affect other users. # What changes are included in this PR? Various error checks to ensure panics don't occur. # Are these changes tested? Tests should continue to pass. If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? Existing tests should cover these changes. # Are there any user-facing changes? None. 
--------- Co-authored-by: Ed Seidl Co-authored-by: Ryan Johnson --- parquet/src/column/page.rs | 2 +- parquet/src/column/reader.rs | 34 ++++++++++-- parquet/src/encodings/decoding.rs | 17 ++++++ parquet/src/encodings/rle.rs | 5 +- parquet/src/file/reader.rs | 45 ++++++++++++++++ parquet/src/file/serialized_reader.rs | 75 +++++++++++++++++++++++++- parquet/src/schema/types.rs | 28 +++++----- parquet/tests/arrow_reader/bad_data.rs | 6 ++- 8 files changed, 190 insertions(+), 22 deletions(-) diff --git a/parquet/src/column/page.rs b/parquet/src/column/page.rs index 23517f05df11..f18b296c1c65 100644 --- a/parquet/src/column/page.rs +++ b/parquet/src/column/page.rs @@ -31,7 +31,7 @@ use crate::file::statistics::{Statistics, page_stats_to_thrift}; /// List of supported pages. /// These are 1-to-1 mapped from the equivalent Thrift definitions, except `buf` which /// used to store uncompressed bytes of the page. -#[derive(Clone)] +#[derive(Clone, Debug)] pub enum Page { /// Data page Parquet format v1. 
DataPage { diff --git a/parquet/src/column/reader.rs b/parquet/src/column/reader.rs index b8ff38efa3c4..ebde79e6a7f2 100644 --- a/parquet/src/column/reader.rs +++ b/parquet/src/column/reader.rs @@ -569,11 +569,16 @@ fn parse_v1_level( match encoding { Encoding::RLE => { let i32_size = std::mem::size_of::(); - let data_size = read_num_bytes::(i32_size, buf.as_ref()) as usize; - Ok(( - i32_size + data_size, - buf.slice(i32_size..i32_size + data_size), - )) + if i32_size <= buf.len() { + let data_size = read_num_bytes::(i32_size, buf.as_ref()) as usize; + let end = i32_size + .checked_add(data_size) + .ok_or(general_err!("invalid level length"))?; + if end <= buf.len() { + return Ok((end, buf.slice(i32_size..end))); + } + } + Err(general_err!("not enough data to read levels")) } #[allow(deprecated)] Encoding::BIT_PACKED => { @@ -597,6 +602,25 @@ mod tests { use crate::util::test_common::page_util::InMemoryPageReader; use crate::util::test_common::rand_gen::make_pages; + #[test] + fn test_parse_v1_level_invalid_length() { + // Say length is 10, but buffer is only 4 + let buf = Bytes::from(vec![10, 0, 0, 0]); + let err = parse_v1_level(1, 100, Encoding::RLE, buf).unwrap_err(); + assert_eq!( + err.to_string(), + "Parquet error: not enough data to read levels" + ); + + // Say length is 4, but buffer is only 3 + let buf = Bytes::from(vec![4, 0, 0]); + let err = parse_v1_level(1, 100, Encoding::RLE, buf).unwrap_err(); + assert_eq!( + err.to_string(), + "Parquet error: not enough data to read levels" + ); + } + const NUM_LEVELS: usize = 128; const NUM_PAGES: usize = 2; const MAX_DEF_LEVEL: i16 = 5; diff --git a/parquet/src/encodings/decoding.rs b/parquet/src/encodings/decoding.rs index de8738cf09f9..f5336ca7c09a 100644 --- a/parquet/src/encodings/decoding.rs +++ b/parquet/src/encodings/decoding.rs @@ -381,7 +381,17 @@ impl DictDecoder { impl Decoder for DictDecoder { fn set_data(&mut self, data: Bytes, num_values: usize) -> Result<()> { // First byte in `data` is bit width + 
if data.is_empty() { + return Err(eof_err!("Not enough bytes to decode bit_width")); + } + let bit_width = data.as_ref()[0]; + if bit_width > 32 { + return Err(general_err!( + "Invalid or corrupted RLE bit width {}. Max allowed is 32", + bit_width + )); + } let mut rle_decoder = RleDecoder::new(bit_width); rle_decoder.set_data(data.slice(1..)); self.num_values = num_values; @@ -1395,6 +1405,13 @@ mod tests { test_plain_skip::(Bytes::from(data_bytes), 3, 6, 4, &[]); } + #[test] + fn test_dict_decoder_empty_data() { + let mut decoder = DictDecoder::::new(); + let err = decoder.set_data(Bytes::new(), 10).unwrap_err(); + assert_eq!(err.to_string(), "EOF: Not enough bytes to decode bit_width"); + } + fn test_plain_decode( data: Bytes, num_values: usize, diff --git a/parquet/src/encodings/rle.rs b/parquet/src/encodings/rle.rs index db8227fcac3a..41c050132064 100644 --- a/parquet/src/encodings/rle.rs +++ b/parquet/src/encodings/rle.rs @@ -513,7 +513,10 @@ impl RleDecoder { self.rle_left = (indicator_value >> 1) as u32; let value_width = bit_util::ceil(self.bit_width as usize, 8); self.current_value = bit_reader.get_aligned::(value_width); - assert!(self.current_value.is_some()); + assert!( + self.current_value.is_some(), + "parquet_data_error: not enough data for RLE decoding" + ); } true } else { diff --git a/parquet/src/file/reader.rs b/parquet/src/file/reader.rs index 61af21a68ec1..3adf10fac220 100644 --- a/parquet/src/file/reader.rs +++ b/parquet/src/file/reader.rs @@ -124,11 +124,25 @@ impl ChunkReader for Bytes { fn get_read(&self, start: u64) -> Result { let start = start as usize; + if start > self.len() { + return Err(eof_err!( + "Expected to read at offset {start}, while file has length {}", + self.len() + )); + } Ok(self.slice(start..).reader()) } fn get_bytes(&self, start: u64, length: usize) -> Result { let start = start as usize; + if start > self.len() || start + length > self.len() { + return Err(eof_err!( + "Expected to read {} bytes at offset {}, while 
file has length {}", + length, + start, + self.len() + )); + } Ok(self.slice(start..start + length)) } } @@ -274,3 +288,34 @@ impl Iterator for FilePageIterator { } impl PageIterator for FilePageIterator {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_bytes_chunk_reader_get_read_out_of_bounds() { + let data = Bytes::from(vec![0, 1, 2, 3]); + let err = data.get_read(5).unwrap_err(); + assert_eq!( + err.to_string(), + "EOF: Expected to read at offset 5, while file has length 4" + ); + } + + #[test] + fn test_bytes_chunk_reader_get_bytes_out_of_bounds() { + let data = Bytes::from(vec![0, 1, 2, 3]); + let err = data.get_bytes(5, 1).unwrap_err(); + assert_eq!( + err.to_string(), + "EOF: Expected to read 1 bytes at offset 5, while file has length 4" + ); + + let err = data.get_bytes(2, 3).unwrap_err(); + assert_eq!( + err.to_string(), + "EOF: Expected to read 3 bytes at offset 2, while file has length 4" + ); + } +} diff --git a/parquet/src/file/serialized_reader.rs b/parquet/src/file/serialized_reader.rs index 6da5c39d745b..3f95ea9d4982 100644 --- a/parquet/src/file/serialized_reader.rs +++ b/parquet/src/file/serialized_reader.rs @@ -392,6 +392,9 @@ pub(crate) fn decode_page( let buffer = match decompressor { Some(decompressor) if can_decompress => { let uncompressed_page_size = usize::try_from(page_header.uncompressed_page_size)?; + if offset > buffer.len() || offset > uncompressed_page_size { + return Err(general_err!("Invalid page header")); + } let decompressed_size = uncompressed_page_size - offset; let mut decompressed = Vec::with_capacity(uncompressed_page_size); decompressed.extend_from_slice(&buffer.as_ref()[..offset]); @@ -458,7 +461,10 @@ pub(crate) fn decode_page( } _ => { // For unknown page type (e.g., INDEX_PAGE), skip and read next. 
- unimplemented!("Page type {:?} is not supported", page_header.r#type) + return Err(general_err!( + "Page type {:?} is not supported", + page_header.r#type + )); } }; @@ -1130,6 +1136,7 @@ mod tests { use crate::column::reader::ColumnReader; use crate::data_type::private::ParquetValueType; use crate::data_type::{AsBytes, FixedLenByteArrayType, Int32Type}; + use crate::file::metadata::thrift::DataPageHeaderV2; #[allow(deprecated)] use crate::file::page_index::index_reader::{read_columns_indexes, read_offset_indexes}; use crate::file::writer::SerializedFileWriter; @@ -1139,6 +1146,72 @@ mod tests { use super::*; + #[test] + fn test_decode_page_invalid_offset() { + let page_header = PageHeader { + r#type: PageType::DATA_PAGE_V2, + uncompressed_page_size: 10, + compressed_page_size: 10, + data_page_header: None, + index_page_header: None, + dictionary_page_header: None, + crc: None, + data_page_header_v2: Some(DataPageHeaderV2 { + num_nulls: 0, + num_rows: 0, + num_values: 0, + encoding: Encoding::PLAIN, + definition_levels_byte_length: 11, + repetition_levels_byte_length: 0, + is_compressed: None, + statistics: None, + }), + }; + + let buffer = Bytes::new(); + let err = decode_page(page_header, buffer, Type::INT32, None).unwrap_err(); + assert!( + err.to_string() + .contains("DataPage v2 header contains implausible values") + ); + } + + #[test] + fn test_decode_unsupported_page() { + let mut page_header = PageHeader { + r#type: PageType::INDEX_PAGE, + uncompressed_page_size: 10, + compressed_page_size: 10, + data_page_header: None, + index_page_header: None, + dictionary_page_header: None, + crc: None, + data_page_header_v2: None, + }; + let buffer = Bytes::new(); + let err = decode_page(page_header.clone(), buffer.clone(), Type::INT32, None).unwrap_err(); + assert_eq!( + err.to_string(), + "Parquet error: Page type INDEX_PAGE is not supported" + ); + + page_header.data_page_header_v2 = Some(DataPageHeaderV2 { + num_nulls: 0, + num_rows: 0, + num_values: 0, + 
encoding: Encoding::PLAIN, + definition_levels_byte_length: 11, + repetition_levels_byte_length: 0, + is_compressed: None, + statistics: None, + }); + let err = decode_page(page_header, buffer, Type::INT32, None).unwrap_err(); + assert!( + err.to_string() + .contains("DataPage v2 header contains implausible values") + ); + } + #[test] fn test_cursor_and_file_has_the_same_behaviour() { let mut buf: Vec = Vec::new(); diff --git a/parquet/src/schema/types.rs b/parquet/src/schema/types.rs index 1ae37d0a462f..de6f855685a6 100644 --- a/parquet/src/schema/types.rs +++ b/parquet/src/schema/types.rs @@ -1348,19 +1348,23 @@ fn schema_from_array_helper<'a>( .with_logical_type(logical_type) .with_fields(fields) .with_id(field_id); - if let Some(rep) = repetition { - // Sometimes parquet-cpp and parquet-mr set repetition level REQUIRED or - // REPEATED for root node. - // - // We only set repetition for group types that are not top-level message - // type. According to parquet-format: - // Root of the schema does not have a repetition_type. - // All other types must have one. - if !is_root_node { - builder = builder.with_repetition(rep); - } + + // Sometimes parquet-cpp and parquet-mr set repetition level REQUIRED or + // REPEATED for root node. + // + // We only set repetition for group types that are not top-level message + // type. According to parquet-format: + // Root of the schema does not have a repetition_type. + // All other types must have one. 
+ if !is_root_node { + let Some(rep) = repetition else { + return Err(general_err!( + "Repetition level must be defined for non-root types" + )); + }; + builder = builder.with_repetition(rep); } - Ok((next_index, Arc::new(builder.build().unwrap()))) + Ok((next_index, Arc::new(builder.build()?))) } } } diff --git a/parquet/tests/arrow_reader/bad_data.rs b/parquet/tests/arrow_reader/bad_data.rs index 235f81812468..54c92976e41c 100644 --- a/parquet/tests/arrow_reader/bad_data.rs +++ b/parquet/tests/arrow_reader/bad_data.rs @@ -84,10 +84,12 @@ fn test_parquet_1481() { } #[test] -#[should_panic(expected = "assertion failed: self.current_value.is_some()")] fn test_arrow_gh_41321() { let err = read_file("ARROW-GH-41321.parquet").unwrap_err(); - assert_eq!(err.to_string(), "TBD (currently panics)"); + assert_eq!( + err.to_string(), + "External: Parquet argument error: Parquet error: Invalid or corrupted RLE bit width 254. Max allowed is 32" + ); } #[test] From 5e32cc60e5e2f3119dc1d56ea925c4daf17f13df Mon Sep 17 00:00:00 2001 From: Andrew Duffy Date: Mon, 27 Oct 2025 14:23:39 -0400 Subject: [PATCH 13/19] Support more operations on ListView (#8645) # Which issue does this PR close? Part of #5375 Vortex was encountering some issues after we switched our preferred List type to `ListView`, the first thing we noticed was that `arrow_select::filter_array` would fail on ListView (and LargeListView, though we don't use that). This PR addresses some missing select kernel implementations for ListView and LargeListView. This also fixes an existing bug in the ArrayData validation for ListView arrays that would trigger an out of bounds index panic. # Are these changes tested? - [x] filter_array - [x] concat - [x] take # Are there any user-facing changes? ListView/LargeListView can now be used with the `take`, `concat` and `filter_array` kernels You can now use the `PartialEq` to compare ListView arrays. 
--------- Signed-off-by: Andrew Duffy --- arrow-data/src/data.rs | 10 +- arrow-data/src/equal/list_view.rs | 129 ++++++++++++++++++ arrow-data/src/equal/mod.rs | 10 +- arrow-data/src/transform/list_view.rs | 56 ++++++++ arrow-data/src/transform/mod.rs | 33 +++-- arrow-select/src/concat.rs | 186 ++++++++++++++++++++++++-- arrow-select/src/filter.rs | 97 ++++++++++++++ arrow-select/src/take.rs | 149 +++++++++++++++++++++ arrow/tests/array_equal.rs | 129 +++++++++++++++++- 9 files changed, 767 insertions(+), 32 deletions(-) create mode 100644 arrow-data/src/equal/list_view.rs create mode 100644 arrow-data/src/transform/list_view.rs diff --git a/arrow-data/src/data.rs b/arrow-data/src/data.rs index c31ac0c6e693..91957e14f332 100644 --- a/arrow-data/src/data.rs +++ b/arrow-data/src/data.rs @@ -980,7 +980,15 @@ impl ArrayData { ) -> Result<(), ArrowError> { let offsets: &[T] = self.typed_buffer(0, self.len)?; let sizes: &[T] = self.typed_buffer(1, self.len)?; - for i in 0..values_length { + if offsets.len() != sizes.len() { + return Err(ArrowError::ComputeError(format!( + "ListView offsets len {} does not match sizes len {}", + offsets.len(), + sizes.len() + ))); + } + + for i in 0..sizes.len() { let size = sizes[i].to_usize().ok_or_else(|| { ArrowError::InvalidArgumentError(format!( "Error converting size[{}] ({}) to usize for {}", diff --git a/arrow-data/src/equal/list_view.rs b/arrow-data/src/equal/list_view.rs new file mode 100644 index 000000000000..c7cb31db9099 --- /dev/null +++ b/arrow-data/src/equal/list_view.rs @@ -0,0 +1,129 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::ArrayData; +use crate::data::count_nulls; +use crate::equal::equal_values; +use arrow_buffer::ArrowNativeType; +use num_integer::Integer; + +pub(super) fn list_view_equal( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + let lhs_offsets = lhs.buffer::(0); + let lhs_sizes = lhs.buffer::(1); + + let rhs_offsets = rhs.buffer::(0); + let rhs_sizes = rhs.buffer::(1); + + let lhs_data = &lhs.child_data()[0]; + let rhs_data = &rhs.child_data()[0]; + + let lhs_null_count = count_nulls(lhs.nulls(), lhs_start, len); + let rhs_null_count = count_nulls(rhs.nulls(), rhs_start, len); + + if lhs_null_count != rhs_null_count { + return false; + } + + if lhs_null_count == 0 { + // non-null pathway: all sizes must be equal, and all values must be equal + let lhs_range_sizes = &lhs_sizes[lhs_start..lhs_start + len]; + let rhs_range_sizes = &rhs_sizes[rhs_start..rhs_start + len]; + + if lhs_range_sizes.len() != rhs_range_sizes.len() { + return false; + } + + if lhs_range_sizes != rhs_range_sizes { + return false; + } + + // Check values for equality + let lhs_range_offsets = &lhs_offsets[lhs_start..lhs_start + len]; + let rhs_range_offsets = &rhs_offsets[rhs_start..rhs_start + len]; + + if lhs_range_offsets.len() != rhs_range_offsets.len() { + return false; + } + + for ((&lhs_offset, &rhs_offset), &size) in lhs_range_offsets + .iter() + .zip(rhs_range_offsets) + .zip(lhs_range_sizes) + { + let lhs_offset = lhs_offset.to_usize().unwrap(); + let rhs_offset = rhs_offset.to_usize().unwrap(); + let 
size = size.to_usize().unwrap(); + + // Check if offsets are valid for the given range + if !equal_values(lhs_data, rhs_data, lhs_offset, rhs_offset, size) { + return false; + } + } + } else { + // Need to integrate validity check in the inner loop. + // non-null pathway: all sizes must be equal, and all values must be equal + let lhs_range_sizes = &lhs_sizes[lhs_start..lhs_start + len]; + let rhs_range_sizes = &rhs_sizes[rhs_start..rhs_start + len]; + + let lhs_nulls = lhs.nulls().unwrap().slice(lhs_start, len); + let rhs_nulls = rhs.nulls().unwrap().slice(rhs_start, len); + + // Sizes can differ if values are null + if lhs_range_sizes.len() != rhs_range_sizes.len() { + return false; + } + + // Check values for equality, with null checking + let lhs_range_offsets = &lhs_offsets[lhs_start..lhs_start + len]; + let rhs_range_offsets = &rhs_offsets[rhs_start..rhs_start + len]; + + if lhs_range_offsets.len() != rhs_range_offsets.len() { + return false; + } + + for (index, ((&lhs_offset, &rhs_offset), &size)) in lhs_range_offsets + .iter() + .zip(rhs_range_offsets) + .zip(lhs_range_sizes) + .enumerate() + { + let lhs_is_null = lhs_nulls.is_null(index); + let rhs_is_null = rhs_nulls.is_null(index); + + if lhs_is_null != rhs_is_null { + return false; + } + + let lhs_offset = lhs_offset.to_usize().unwrap(); + let rhs_offset = rhs_offset.to_usize().unwrap(); + let size = size.to_usize().unwrap(); + + // Check if values match in the range + if !lhs_is_null && !equal_values(lhs_data, rhs_data, lhs_offset, rhs_offset, size) { + return false; + } + } + } + + true +} diff --git a/arrow-data/src/equal/mod.rs b/arrow-data/src/equal/mod.rs index 1c16ee2f8a14..7a310b1240df 100644 --- a/arrow-data/src/equal/mod.rs +++ b/arrow-data/src/equal/mod.rs @@ -30,6 +30,7 @@ mod dictionary; mod fixed_binary; mod fixed_list; mod list; +mod list_view; mod null; mod primitive; mod run; @@ -41,6 +42,8 @@ mod variable_size; // these methods assume the same type, len and null count. 
// For this reason, they are not exposed and are instead used // to build the generic functions below (`equal_range` and `equal`). +use self::run::run_equal; +use crate::equal::list_view::list_view_equal; use boolean::boolean_equal; use byte_view::byte_view_equal; use dictionary::dictionary_equal; @@ -53,8 +56,6 @@ use structure::struct_equal; use union::union_equal; use variable_size::variable_sized_equal; -use self::run::run_equal; - /// Compares the values of two [ArrayData] starting at `lhs_start` and `rhs_start` respectively /// for `len` slots. #[inline] @@ -104,10 +105,9 @@ fn equal_values( byte_view_equal(lhs, rhs, lhs_start, rhs_start, len) } DataType::List(_) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), - DataType::ListView(_) | DataType::LargeListView(_) => { - unimplemented!("ListView/LargeListView not yet implemented") - } DataType::LargeList(_) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::ListView(_) => list_view_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::LargeListView(_) => list_view_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::FixedSizeList(_, _) => fixed_list_equal(lhs, rhs, lhs_start, rhs_start, len), DataType::Struct(_) => struct_equal(lhs, rhs, lhs_start, rhs_start, len), DataType::Union(_, _) => union_equal(lhs, rhs, lhs_start, rhs_start, len), diff --git a/arrow-data/src/transform/list_view.rs b/arrow-data/src/transform/list_view.rs new file mode 100644 index 000000000000..9b66a6a6abb1 --- /dev/null +++ b/arrow-data/src/transform/list_view.rs @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::ArrayData; +use crate::transform::_MutableArrayData; +use arrow_buffer::ArrowNativeType; +use num_integer::Integer; +use num_traits::CheckedAdd; + +pub(super) fn build_extend( + array: &ArrayData, +) -> crate::transform::Extend<'_> { + let offsets = array.buffer::(0); + let sizes = array.buffer::(1); + Box::new( + move |mutable: &mut _MutableArrayData, _index: usize, start: usize, len: usize| { + let offset_buffer = &mut mutable.buffer1; + let sizes_buffer = &mut mutable.buffer2; + + for &offset in &offsets[start..start + len] { + offset_buffer.push(offset); + } + + // sizes + for &size in &sizes[start..start + len] { + sizes_buffer.push(size); + } + + // the beauty of views is that we don't need to copy child_data, we just splat + // the offsets and sizes. 
+ }, + ) +} + +pub(super) fn extend_nulls(mutable: &mut _MutableArrayData, len: usize) { + let offset_buffer = &mut mutable.buffer1; + let sizes_buffer = &mut mutable.buffer2; + + // We push 0 as a placeholder for NULL values in both the offsets and sizes + (0..len).for_each(|_| offset_buffer.push(T::default())); + (0..len).for_each(|_| sizes_buffer.push(T::default())); +} diff --git a/arrow-data/src/transform/mod.rs b/arrow-data/src/transform/mod.rs index 5b994046e6ca..c6052817bfb6 100644 --- a/arrow-data/src/transform/mod.rs +++ b/arrow-data/src/transform/mod.rs @@ -33,6 +33,7 @@ mod boolean; mod fixed_binary; mod fixed_size_list; mod list; +mod list_view; mod null; mod primitive; mod run; @@ -265,10 +266,9 @@ fn build_extend(array: &ArrayData) -> Extend<'_> { DataType::LargeUtf8 | DataType::LargeBinary => variable_size::build_extend::(array), DataType::BinaryView | DataType::Utf8View => unreachable!("should use build_extend_view"), DataType::Map(_, _) | DataType::List(_) => list::build_extend::(array), - DataType::ListView(_) | DataType::LargeListView(_) => { - unimplemented!("ListView/LargeListView not implemented") - } DataType::LargeList(_) => list::build_extend::(array), + DataType::ListView(_) => list_view::build_extend::(array), + DataType::LargeListView(_) => list_view::build_extend::(array), DataType::Dictionary(_, _) => unreachable!("should use build_extend_dictionary"), DataType::Struct(_) => structure::build_extend(array), DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array), @@ -313,10 +313,9 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls { DataType::LargeUtf8 | DataType::LargeBinary => variable_size::extend_nulls::, DataType::BinaryView | DataType::Utf8View => primitive::extend_nulls::, DataType::Map(_, _) | DataType::List(_) => list::extend_nulls::, - DataType::ListView(_) | DataType::LargeListView(_) => { - unimplemented!("ListView/LargeListView not implemented") - } DataType::LargeList(_) => list::extend_nulls::, + 
DataType::ListView(_) => list_view::extend_nulls::, + DataType::LargeListView(_) => list_view::extend_nulls::, DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() { DataType::UInt8 => primitive::extend_nulls::, DataType::UInt16 => primitive::extend_nulls::, @@ -450,7 +449,11 @@ impl<'a> MutableArrayData<'a> { new_buffers(data_type, *capacity) } ( - DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _), + DataType::List(_) + | DataType::LargeList(_) + | DataType::ListView(_) + | DataType::LargeListView(_) + | DataType::FixedSizeList(_, _), Capacities::List(capacity, _), ) => { array_capacity = *capacity; @@ -491,10 +494,11 @@ impl<'a> MutableArrayData<'a> { | DataType::Utf8View | DataType::Interval(_) | DataType::FixedSizeBinary(_) => vec![], - DataType::ListView(_) | DataType::LargeListView(_) => { - unimplemented!("ListView/LargeListView not implemented") - } - DataType::Map(_, _) | DataType::List(_) | DataType::LargeList(_) => { + DataType::Map(_, _) + | DataType::List(_) + | DataType::LargeList(_) + | DataType::ListView(_) + | DataType::LargeListView(_) => { let children = arrays .iter() .map(|array| &array.child_data()[0]) @@ -785,7 +789,12 @@ impl<'a> MutableArrayData<'a> { b.insert(0, data.buffer1.into()); b } - DataType::Utf8 | DataType::Binary | DataType::LargeUtf8 | DataType::LargeBinary => { + DataType::Utf8 + | DataType::Binary + | DataType::LargeUtf8 + | DataType::LargeBinary + | DataType::ListView(_) + | DataType::LargeListView(_) => { vec![data.buffer1.into(), data.buffer2.into()] } DataType::Union(_, mode) => { diff --git a/arrow-select/src/concat.rs b/arrow-select/src/concat.rs index 83bc5c2763d2..3bfdd31ccf2d 100644 --- a/arrow-select/src/concat.rs +++ b/arrow-select/src/concat.rs @@ -37,7 +37,9 @@ use arrow_array::builder::{ use arrow_array::cast::AsArray; use arrow_array::types::*; use arrow_array::*; -use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer, OffsetBuffer}; +use 
arrow_buffer::{ + ArrowNativeType, BooleanBufferBuilder, MutableBuffer, NullBuffer, OffsetBuffer, ScalarBuffer, +}; use arrow_data::ArrayDataBuilder; use arrow_data::transform::{Capacities, MutableArrayData}; use arrow_schema::{ArrowError, DataType, FieldRef, Fields, SchemaRef}; @@ -206,6 +208,63 @@ fn concat_lists( Ok(Arc::new(array)) } +fn concat_list_view( + arrays: &[&dyn Array], + field: &FieldRef, +) -> Result { + let mut output_len = 0; + let mut list_has_nulls = false; + + let lists = arrays + .iter() + .map(|x| x.as_list_view::()) + .inspect(|l| { + output_len += l.len(); + list_has_nulls |= l.null_count() != 0; + }) + .collect::>(); + + let lists_nulls = list_has_nulls.then(|| { + let mut nulls = BooleanBufferBuilder::new(output_len); + for l in &lists { + match l.nulls() { + Some(n) => nulls.append_buffer(n.inner()), + None => nulls.append_n(l.len(), true), + } + } + NullBuffer::new(nulls.finish()) + }); + + let values: Vec<&dyn Array> = lists.iter().map(|l| l.values().as_ref()).collect(); + + let concatenated_values = concat(values.as_slice())?; + + let sizes: ScalarBuffer = lists.iter().flat_map(|x| x.sizes()).copied().collect(); + + let mut offsets = MutableBuffer::with_capacity(lists.iter().map(|l| l.offsets().len()).sum()); + let mut global_offset = OffsetSize::zero(); + for l in lists.iter() { + for &offset in l.offsets() { + offsets.push(offset + global_offset); + } + + // advance the offsets + global_offset += OffsetSize::from_usize(l.values().len()).unwrap(); + } + + let offsets = ScalarBuffer::from(offsets); + + let array = GenericListViewArray::try_new( + field.clone(), + offsets, + sizes, + concatenated_values, + lists_nulls, + )?; + + Ok(Arc::new(array)) +} + fn concat_primitives(arrays: &[&dyn Array]) -> Result { let mut builder = PrimitiveBuilder::::with_capacity(arrays.iter().map(|a| a.len()).sum()) .with_data_type(arrays[0].data_type().clone()); @@ -422,6 +481,8 @@ pub fn concat(arrays: &[&dyn Array]) -> Result { } DataType::List(field) 
=> concat_lists::(arrays, field), DataType::LargeList(field) => concat_lists::(arrays, field), + DataType::ListView(field) => concat_list_view::(arrays, field), + DataType::LargeListView(field) => concat_list_view::(arrays, field), DataType::Struct(fields) => concat_structs(arrays, fields), DataType::Utf8 => concat_bytes::(arrays), DataType::LargeUtf8 => concat_bytes::(arrays), @@ -500,7 +561,9 @@ pub fn concat_batches<'a>( #[cfg(test)] mod tests { use super::*; - use arrow_array::builder::{GenericListBuilder, StringDictionaryBuilder}; + use arrow_array::builder::{ + GenericListBuilder, Int64Builder, ListViewBuilder, StringDictionaryBuilder, + }; use arrow_schema::{Field, Schema}; use std::fmt::Debug; @@ -768,7 +831,7 @@ mod tests { #[test] fn test_concat_primitive_list_arrays() { - let list1 = vec![ + let list1 = [ Some(vec![Some(-1), Some(-1), Some(2), None, None]), Some(vec![]), None, @@ -776,14 +839,14 @@ mod tests { ]; let list1_array = ListArray::from_iter_primitive::(list1.clone()); - let list2 = vec![ + let list2 = [ None, Some(vec![Some(100), None, Some(101)]), Some(vec![Some(102)]), ]; let list2_array = ListArray::from_iter_primitive::(list2.clone()); - let list3 = vec![Some(vec![Some(1000), Some(1001)])]; + let list3 = [Some(vec![Some(1000), Some(1001)])]; let list3_array = ListArray::from_iter_primitive::(list3.clone()); let array_result = concat(&[&list1_array, &list2_array, &list3_array]).unwrap(); @@ -796,7 +859,7 @@ mod tests { #[test] fn test_concat_primitive_list_arrays_slices() { - let list1 = vec![ + let list1 = [ Some(vec![Some(-1), Some(-1), Some(2), None, None]), Some(vec![]), // In slice None, // In slice @@ -806,7 +869,7 @@ mod tests { let list1_array = list1_array.slice(1, 2); let list1_values = list1.into_iter().skip(1).take(2); - let list2 = vec![ + let list2 = [ None, Some(vec![Some(100), None, Some(101)]), Some(vec![Some(102)]), @@ -825,7 +888,7 @@ mod tests { #[test] fn test_concat_primitive_list_arrays_sliced_lengths() { - let list1 
= vec![ + let list1 = [ Some(vec![Some(-1), Some(-1), Some(2), None, None]), // In slice Some(vec![]), // In slice None, // In slice @@ -835,7 +898,7 @@ mod tests { let list1_array = list1_array.slice(0, 3); // no offset, but not all values let list1_values = list1.into_iter().take(3); - let list2 = vec![ + let list2 = [ None, Some(vec![Some(100), None, Some(101)]), Some(vec![Some(102)]), @@ -856,7 +919,7 @@ mod tests { #[test] fn test_concat_primitive_fixed_size_list_arrays() { - let list1 = vec![ + let list1 = [ Some(vec![Some(-1), None]), None, Some(vec![Some(10), Some(20)]), @@ -864,7 +927,7 @@ mod tests { let list1_array = FixedSizeListArray::from_iter_primitive::(list1.clone(), 2); - let list2 = vec![ + let list2 = [ None, Some(vec![Some(100), None]), Some(vec![Some(102), Some(103)]), @@ -872,7 +935,7 @@ mod tests { let list2_array = FixedSizeListArray::from_iter_primitive::(list2.clone(), 2); - let list3 = vec![Some(vec![Some(1000), Some(1001)])]; + let list3 = [Some(vec![Some(1000), Some(1001)])]; let list3_array = FixedSizeListArray::from_iter_primitive::(list3.clone(), 2); @@ -885,6 +948,105 @@ mod tests { assert_eq!(array_result.as_ref(), &array_expected as &dyn Array); } + #[test] + fn test_concat_list_view_arrays() { + let list1 = [ + Some(vec![Some(-1), None]), + None, + Some(vec![Some(10), Some(20)]), + ]; + let mut list1_array = ListViewBuilder::new(Int64Builder::new()); + for v in list1.iter() { + list1_array.append_option(v.clone()); + } + let list1_array = list1_array.finish(); + + let list2 = [ + None, + Some(vec![Some(100), None]), + Some(vec![Some(102), Some(103)]), + ]; + let mut list2_array = ListViewBuilder::new(Int64Builder::new()); + for v in list2.iter() { + list2_array.append_option(v.clone()); + } + let list2_array = list2_array.finish(); + + let list3 = [Some(vec![Some(1000), Some(1001)])]; + let mut list3_array = ListViewBuilder::new(Int64Builder::new()); + for v in list3.iter() { + list3_array.append_option(v.clone()); + } + let 
list3_array = list3_array.finish(); + + let array_result = concat(&[&list1_array, &list2_array, &list3_array]).unwrap(); + + let expected: Vec<_> = list1.into_iter().chain(list2).chain(list3).collect(); + let mut array_expected = ListViewBuilder::new(Int64Builder::new()); + for v in expected.iter() { + array_expected.append_option(v.clone()); + } + let array_expected = array_expected.finish(); + + assert_eq!(array_result.as_ref(), &array_expected as &dyn Array); + } + + #[test] + fn test_concat_sliced_list_view_arrays() { + let list1 = [ + Some(vec![Some(-1), None]), + None, + Some(vec![Some(10), Some(20)]), + ]; + let mut list1_array = ListViewBuilder::new(Int64Builder::new()); + for v in list1.iter() { + list1_array.append_option(v.clone()); + } + let list1_array = list1_array.finish(); + + let list2 = [ + None, + Some(vec![Some(100), None]), + Some(vec![Some(102), Some(103)]), + ]; + let mut list2_array = ListViewBuilder::new(Int64Builder::new()); + for v in list2.iter() { + list2_array.append_option(v.clone()); + } + let list2_array = list2_array.finish(); + + let list3 = [Some(vec![Some(1000), Some(1001)])]; + let mut list3_array = ListViewBuilder::new(Int64Builder::new()); + for v in list3.iter() { + list3_array.append_option(v.clone()); + } + let list3_array = list3_array.finish(); + + // Concat sliced arrays. + // ListView slicing will slice the offset/sizes but preserve the original values child. 
+ let array_result = concat(&[ + &list1_array.slice(1, 2), + &list2_array.slice(1, 2), + &list3_array.slice(0, 1), + ]) + .unwrap(); + + let expected: Vec<_> = vec![ + None, + Some(vec![Some(10), Some(20)]), + Some(vec![Some(100), None]), + Some(vec![Some(102), Some(103)]), + Some(vec![Some(1000), Some(1001)]), + ]; + let mut array_expected = ListViewBuilder::new(Int64Builder::new()); + for v in expected.iter() { + array_expected.append_option(v.clone()); + } + let array_expected = array_expected.finish(); + + assert_eq!(array_result.as_ref(), &array_expected as &dyn Array); + } + #[test] fn test_concat_struct_arrays() { let field = Arc::new(Field::new("field", DataType::Int64, true)); diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs index 5c21a4adcab7..cd132de39d1c 100644 --- a/arrow-select/src/filter.rs +++ b/arrow-select/src/filter.rs @@ -380,6 +380,12 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result { Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate))) } + DataType::ListView(_) => { + Ok(Arc::new(filter_list_view::(values.as_list_view(), predicate))) + } + DataType::LargeListView(_) => { + Ok(Arc::new(filter_list_view::(values.as_list_view(), predicate))) + } DataType::RunEndEncoded(_, _) => { downcast_run_array!{ values => Ok(Arc::new(filter_run_end_array(values, predicate)?)), @@ -894,6 +900,34 @@ fn filter_sparse_union( }) } +/// `filter` implementation for list views +fn filter_list_view( + array: &GenericListViewArray, + predicate: &FilterPredicate, +) -> GenericListViewArray { + let filtered_offsets = filter_native::(array.offsets(), predicate); + let filtered_sizes = filter_native::(array.sizes(), predicate); + + // Filter the nulls + let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) { + let buffer = BooleanBuffer::new(nulls, 0, predicate.count); + + Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) }) + } else { + None + }; + + 
let list_data = ArrayDataBuilder::new(array.data_type().clone()) + .nulls(nulls) + .buffers(vec![filtered_offsets, filtered_sizes]) + .child_data(vec![array.values().to_data()]) + .len(predicate.count); + + let list_data = unsafe { list_data.build_unchecked() }; + + GenericListViewArray::from(list_data) +} + #[cfg(test)] mod tests { use super::*; @@ -1404,6 +1438,69 @@ mod tests { assert_eq!(&make_array(expected), &result); } + fn test_case_filter_list_view() { + // [[1, 2], null, [], [3,4]] + let mut list_array = GenericListViewBuilder::::new(Int32Builder::new()); + list_array.append_value([Some(1), Some(2)]); + list_array.append_null(); + list_array.append_value([]); + list_array.append_value([Some(3), Some(4)]); + + let list_array = list_array.finish(); + let predicate = BooleanArray::from_iter([true, false, true, false]); + + // Filter result: [[1, 2], []] + let filtered = filter(&list_array, &predicate) + .unwrap() + .as_list_view::() + .clone(); + + let mut expected = + GenericListViewBuilder::::with_capacity(Int32Builder::with_capacity(5), 3); + expected.append_value([Some(1), Some(2)]); + expected.append_value([]); + let expected = expected.finish(); + + assert_eq!(&filtered, &expected); + } + + fn test_case_filter_sliced_list_view() { + // [[1, 2], null, [], [3,4]] + let mut list_array = + GenericListViewBuilder::::with_capacity(Int32Builder::with_capacity(6), 4); + list_array.append_value([Some(1), Some(2)]); + list_array.append_null(); + list_array.append_value([]); + list_array.append_value([Some(3), Some(4)]); + + let list_array = list_array.finish(); + + // Sliced: [null, [], [3, 4]] + let sliced = list_array.slice(1, 3); + let predicate = BooleanArray::from_iter([false, false, true]); + + // Filter result: [[1, 2], []] + let filtered = filter(&sliced, &predicate) + .unwrap() + .as_list_view::() + .clone(); + + let mut expected = GenericListViewBuilder::::new(Int32Builder::new()); + expected.append_value([Some(3), Some(4)]); + let expected = 
expected.finish(); + + assert_eq!(&filtered, &expected); + } + + #[test] + fn test_filter_list_view_array() { + test_case_filter_list_view::(); + test_case_filter_list_view::(); + + test_case_filter_sliced_list_view::(); + test_case_filter_sliced_list_view::(); + } + #[test] fn test_slice_iterator_bits() { let filter_values = (0..64).map(|i| i == 1).collect::>(); diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs index dfe6903dc4e3..eec4ffa14e72 100644 --- a/arrow-select/src/take.rs +++ b/arrow-select/src/take.rs @@ -218,6 +218,12 @@ fn take_impl( DataType::LargeList(_) => { Ok(Arc::new(take_list::<_, Int64Type>(values.as_list(), indices)?)) } + DataType::ListView(_) => { + Ok(Arc::new(take_list_view::<_, Int32Type>(values.as_list_view(), indices)?)) + } + DataType::LargeListView(_) => { + Ok(Arc::new(take_list_view::<_, Int64Type>(values.as_list_view(), indices)?)) + } DataType::FixedSizeList(_, length) => { let values = values .as_any() @@ -621,6 +627,33 @@ where Ok(GenericListArray::::from(list_data)) } +fn take_list_view( + values: &GenericListViewArray, + indices: &PrimitiveArray, +) -> Result, ArrowError> +where + IndexType: ArrowPrimitiveType, + OffsetType: ArrowPrimitiveType, + OffsetType::Native: OffsetSizeTrait, +{ + let taken_offsets = take_native(values.offsets(), indices); + let taken_sizes = take_native(values.sizes(), indices); + let nulls = take_nulls(values.nulls(), indices); + + let list_view_data = ArrayDataBuilder::new(values.data_type().clone()) + .len(indices.len()) + .nulls(nulls) + .buffers(vec![taken_offsets.into(), taken_sizes.into()]) + .child_data(vec![values.values().to_data()]); + + // SAFETY: all buffers and child nodes for ListView added in constructor + let list_view_data = unsafe { list_view_data.build_unchecked() }; + + Ok(GenericListViewArray::::from( + list_view_data, + )) +} + /// `take` implementation for `FixedSizeListArray` /// /// Calculates the index and indexed offset for the inner array, @@ -980,6 +1013,7 
@@ mod tests { use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use arrow_data::ArrayData; use arrow_schema::{Field, Fields, TimeUnit, UnionFields}; + use num_traits::ToPrimitive; fn test_take_decimal_arrays( data: Vec>, @@ -1821,6 +1855,55 @@ mod tests { }}; } + fn test_take_list_view_generic( + values: Vec>>>, + take_indices: Vec>, + expected: Vec>>>, + mapper: F, + ) where + F: Fn(GenericListViewArray) -> GenericListViewArray, + { + let mut list_view_array = + GenericListViewBuilder::::new(PrimitiveBuilder::::new()); + + for value in values { + list_view_array.append_option(value); + } + let list_view_array = list_view_array.finish(); + let list_view_array = mapper(list_view_array); + + let mut indices = UInt64Builder::new(); + for idx in take_indices { + indices.append_option(idx.map(|i| i.to_u64().unwrap())); + } + let indices = indices.finish(); + + let taken = take(&list_view_array, &indices, None) + .unwrap() + .as_list_view() + .clone(); + + let mut expected_array = + GenericListViewBuilder::::new(PrimitiveBuilder::::new()); + for value in expected { + expected_array.append_option(value); + } + let expected_array = expected_array.finish(); + + assert_eq!(taken, expected_array); + } + + macro_rules! 
list_view_test_case { + (values: $values:expr, indices: $indices:expr, expected: $expected: expr) => {{ + test_take_list_view_generic::($values, $indices, $expected, |x| x); + test_take_list_view_generic::($values, $indices, $expected, |x| x); + }}; + (values: $values:expr, transform: $fn:expr, indices: $indices:expr, expected: $expected: expr) => {{ + test_take_list_view_generic::($values, $indices, $expected, $fn); + test_take_list_view_generic::($values, $indices, $expected, $fn); + }}; + } + fn do_take_fixed_size_list_test( length: ::Native, input_data: Vec>>>, @@ -1871,6 +1954,72 @@ mod tests { test_take_list_with_nulls!(i64, LargeList, LargeListArray); } + #[test] + fn test_test_take_list_view_reversed() { + // Take reversed indices + list_view_test_case! { + values: vec![ + Some(vec![Some(1), None, Some(3)]), + None, + Some(vec![Some(7), Some(8), None]), + ], + indices: vec![Some(2), Some(1), Some(0)], + expected: vec![ + Some(vec![Some(7), Some(8), None]), + None, + Some(vec![Some(1), None, Some(3)]), + ] + } + } + + #[test] + fn test_take_list_view_null_indices() { + // Take with null indices + list_view_test_case! { + values: vec![ + Some(vec![Some(1), None, Some(3)]), + None, + Some(vec![Some(7), Some(8), None]), + ], + indices: vec![None, Some(0), None], + expected: vec![None, Some(vec![Some(1), None, Some(3)]), None] + } + } + + #[test] + fn test_take_list_view_null_values() { + // Take at null values + list_view_test_case! { + values: vec![ + Some(vec![Some(1), None, Some(3)]), + None, + Some(vec![Some(7), Some(8), None]), + ], + indices: vec![Some(1), Some(1), Some(1), None, None], + expected: vec![None; 5] + } + } + + #[test] + fn test_take_list_view_sliced() { + // Take null indices/values, with slicing. + list_view_test_case! 
{ + values: vec![ + Some(vec![Some(1)]), + None, + None, + Some(vec![Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + None, + ], + transform: |l| l.slice(2, 4), + indices: vec![Some(0), Some(3), None, Some(1), Some(2)], + expected: vec![ + None, None, None, Some(vec![Some(2), Some(3)]), Some(vec![Some(4), Some(5)]) + ] + } + } + #[test] fn test_take_fixed_size_list() { do_take_fixed_size_list_test::( diff --git a/arrow/tests/array_equal.rs b/arrow/tests/array_equal.rs index 7fc8b0be7a3d..381054a25df5 100644 --- a/arrow/tests/array_equal.rs +++ b/arrow/tests/array_equal.rs @@ -22,11 +22,17 @@ use arrow::array::{ StringDictionaryBuilder, StructArray, UnionBuilder, make_array, }; use arrow::datatypes::{Int16Type, Int32Type}; -use arrow_array::builder::{StringBuilder, StringViewBuilder, StructBuilder}; -use arrow_array::{DictionaryArray, FixedSizeListArray, StringViewArray}; +use arrow_array::builder::{ + GenericListViewBuilder, StringBuilder, StringViewBuilder, StructBuilder, +}; +use arrow_array::cast::AsArray; +use arrow_array::{ + DictionaryArray, FixedSizeListArray, GenericListViewArray, PrimitiveArray, StringViewArray, +}; use arrow_buffer::{Buffer, ToByteSlice}; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{DataType, Field, Fields}; +use arrow_select::take::take; use std::sync::Arc; #[test] @@ -756,6 +762,125 @@ fn test_fixed_list_offsets() { test_equal(&a_slice, &b_slice, true); } +fn create_list_view_array< + O: OffsetSizeTrait, + U: IntoIterator>, + T: IntoIterator>, +>( + data: T, +) -> GenericListViewArray { + let mut builder = GenericListViewBuilder::::new(Int32Builder::new()); + for d in data { + if let Some(v) = d { + builder.append_value(v); + } else { + builder.append_null(); + } + } + + builder.finish() +} + +fn test_test_list_view_array() { + let a = create_list_view_array::([ + None, + Some(vec![Some(1), None, Some(2)]), + Some(vec![Some(3), Some(4), Some(5), None]), + ]); + let b = create_list_view_array::([ + None, + 
Some(vec![Some(1), None, Some(2)]), + Some(vec![Some(3), Some(4), Some(5), None]), + ]); + + test_equal(&a, &b, true); + + // Simple non-matching arrays by reordering + let b = create_list_view_array::([ + Some(vec![Some(3), Some(4), Some(5), None]), + Some(vec![Some(1), None, Some(2)]), + ]); + test_equal(&a, &b, false); + + // reorder using take yields equal values + let indices: PrimitiveArray = vec![None, Some(1), Some(0)].into(); + let b = take(&b, &indices, None) + .unwrap() + .as_list_view::() + .clone(); + + test_equal(&a, &b, true); + + // Slicing one side yields unequal again + let a = a.slice(1, 2); + + test_equal(&a, &b, false); + + // Slicing the other to match makes them equal again + let b = b.slice(1, 2); + + test_equal(&a, &b, true); +} + +// Special test for List>. +// This tests the equal_ranges kernel +fn test_sliced_list_of_list_view() { + // First list view is created using the builder, with elements not deduplicated. + let mut a = ListBuilder::new(GenericListViewBuilder::::new(Int32Builder::new())); + + a.append_value([Some(vec![Some(1), Some(2), Some(3)]), Some(vec![])]); + a.append_null(); + a.append_value([ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(6)]), + ]); + + let a = a.finish(); + // a = [[[1,2,3], []], null, [[4, null], [5], null, [6]]] + + // First list view is created using the builder, with elements not deduplicated. 
+ let mut b = ListBuilder::new(GenericListViewBuilder::::new(Int32Builder::new())); + + // Add an extra row that we will slice off, adjust the List offsets + b.append_value([Some(vec![Some(0), Some(0), Some(0)])]); + b.append_value([Some(vec![Some(1), Some(2), Some(3)]), Some(vec![])]); + b.append_null(); + b.append_value([ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(6)]), + ]); + + let b = b.finish(); + // b = [[[0, 0, 0]], [[1,2,3], []], null, [[4, null], [5], null, [6]]] + let b = b.slice(1, 3); + // b = [[[1,2,3], []], null, [[4, null], [5], null, [6]]] but the outer ListArray + // has an offset + + test_equal(&a, &b, true); +} + +#[test] +fn test_list_view_array() { + test_test_list_view_array::(); +} + +#[test] +fn test_large_list_view_array() { + test_test_list_view_array::(); +} + +#[test] +fn test_nested_list_view_array() { + test_sliced_list_of_list_view::(); +} + +#[test] +fn test_nested_large_list_view_array() { + test_sliced_list_of_list_view::(); +} + #[test] fn test_struct_equal() { let strings: ArrayRef = Arc::new(StringArray::from(vec![ From 06c49db3e736aa6990e56f7099a5fba9dc5c3c8d Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Tue, 28 Oct 2025 07:32:02 +1300 Subject: [PATCH 14/19] [Parquet] Account for FileDecryptor in ParquetMetaData heap size calculation (#8671) # Which issue does this PR close? - Closes #8472. # Rationale for this change Makes the metadata heap size calculation more accurate when reading encrypted Parquet files, which helps to better manage caches of Parquet metadata. # What changes are included in this PR? 
* Accounts for heap allocations related to the `FileDecryptor` in `ParquetMetaData` * Fixes the `HeapSize` implementation for `Arc` so the size of `T` is included, as well as the reference counts that are stored on the heap * Fixes the heap size of type pointers within `ColumnDescriptor` being included twice ## Not included * Accounting for any heap allocations in a user-provided `KeyRetriever` # Are these changes tested? Yes, there's a new unit test added that computes the heap size with a decryptor. I also did a manual test that created a test Parquet file with 100 columns using per-column encryption keys, and loaded 10,000 copies of the `ParquetMetaData` into a vector. `heaptrack` reported 1.136 GB memory heap allocated in this test program. Prior to this change, the sum of the metadata was reported as 879.2 MB, and afterwards it was 1.136 GB. # Are there any user-facing changes? No This was co-authored by @etseidl --------- Co-authored-by: Ed Seidl --- parquet/src/encryption/ciphers.rs | 10 +++- parquet/src/encryption/decrypt.rs | 40 +++++++++++++ parquet/src/file/metadata/memory.rs | 54 ++++++++++++++++- parquet/src/file/metadata/mod.rs | 91 +++++++++++++++++++++++++++-- parquet/src/schema/types.rs | 4 +- 5 files changed, 190 insertions(+), 9 deletions(-) diff --git a/parquet/src/encryption/ciphers.rs b/parquet/src/encryption/ciphers.rs index faff28f8acff..a94c72dcd5ec 100644 --- a/parquet/src/encryption/ciphers.rs +++ b/parquet/src/encryption/ciphers.rs @@ -18,6 +18,7 @@ use crate::errors::ParquetError; use crate::errors::ParquetError::General; use crate::errors::Result; +use crate::file::metadata::HeapSize; use ring::aead::{AES_128_GCM, Aad, LessSafeKey, NonceSequence, UnboundKey}; use ring::rand::{SecureRandom, SystemRandom}; use std::fmt::Debug; @@ -27,7 +28,7 @@ pub(crate) const NONCE_LEN: usize = 12; pub(crate) const TAG_LEN: usize = 16; pub(crate) const SIZE_LEN: usize = 4; -pub(crate) trait BlockDecryptor: Debug + Send + Sync { +pub(crate) trait 
BlockDecryptor: Debug + Send + Sync + HeapSize { fn decrypt(&self, length_and_ciphertext: &[u8], aad: &[u8]) -> Result>; fn compute_plaintext_tag(&self, aad: &[u8], plaintext: &[u8]) -> Result>; @@ -50,6 +51,13 @@ impl RingGcmBlockDecryptor { } } +impl HeapSize for RingGcmBlockDecryptor { + fn heap_size(&self) -> usize { + // Ring's LessSafeKey doesn't allocate on the heap + 0 + } +} + impl BlockDecryptor for RingGcmBlockDecryptor { fn decrypt(&self, length_and_ciphertext: &[u8], aad: &[u8]) -> Result> { let mut result = Vec::with_capacity(length_and_ciphertext.len() - SIZE_LEN - NONCE_LEN); diff --git a/parquet/src/encryption/decrypt.rs b/parquet/src/encryption/decrypt.rs index b5374066dfc3..0066523419de 100644 --- a/parquet/src/encryption/decrypt.rs +++ b/parquet/src/encryption/decrypt.rs @@ -21,6 +21,7 @@ use crate::encryption::ciphers::{BlockDecryptor, RingGcmBlockDecryptor, TAG_LEN} use crate::encryption::modules::{ModuleType, create_footer_aad, create_module_aad}; use crate::errors::{ParquetError, Result}; use crate::file::column_crypto_metadata::ColumnCryptoMetaData; +use crate::file::metadata::HeapSize; use std::borrow::Cow; use std::collections::HashMap; use std::fmt::Formatter; @@ -271,6 +272,12 @@ struct ExplicitDecryptionKeys { column_keys: HashMap>, } +impl HeapSize for ExplicitDecryptionKeys { + fn heap_size(&self) -> usize { + self.footer_key.heap_size() + self.column_keys.heap_size() + } +} + #[derive(Clone)] enum DecryptionKeys { Explicit(ExplicitDecryptionKeys), @@ -290,6 +297,19 @@ impl PartialEq for DecryptionKeys { } } +impl HeapSize for DecryptionKeys { + fn heap_size(&self) -> usize { + match self { + Self::Explicit(keys) => keys.heap_size(), + Self::ViaRetriever(_) => { + // The retriever is a user-defined type we don't control, + // so we can't determine the heap size. + 0 + } + } + } +} + /// `FileDecryptionProperties` hold keys and AAD data required to decrypt a Parquet file. 
/// /// When reading Arrow data, the `FileDecryptionProperties` should be included in the @@ -334,6 +354,11 @@ pub struct FileDecryptionProperties { footer_signature_verification: bool, } +impl HeapSize for FileDecryptionProperties { + fn heap_size(&self) -> usize { + self.keys.heap_size() + self.aad_prefix.heap_size() + } +} impl FileDecryptionProperties { /// Returns a new [`FileDecryptionProperties`] builder that will use the provided key to /// decrypt footer metadata. @@ -547,6 +572,21 @@ impl PartialEq for FileDecryptor { } } +/// Estimate the size in bytes required for the file decryptor. +/// This is important to track the memory usage of cached Parquet meta data, +/// and is used via [`crate::file::metadata::ParquetMetaData::memory_size`]. +/// Note that when a [`KeyRetriever`] is used, its heap size won't be included +/// and the result will be an underestimate. +/// If the [`FileDecryptionProperties`] are shared between multiple files then the +/// heap size may also be an overestimate. 
+impl HeapSize for FileDecryptor { + fn heap_size(&self) -> usize { + self.decryption_properties.heap_size() + + (Arc::clone(&self.footer_decryptor) as Arc).heap_size() + + self.file_aad.heap_size() + } +} + impl FileDecryptor { pub(crate) fn new( decryption_properties: &Arc, diff --git a/parquet/src/file/metadata/memory.rs b/parquet/src/file/metadata/memory.rs index 98ce5736ae1d..11536bbbd41e 100644 --- a/parquet/src/file/metadata/memory.rs +++ b/parquet/src/file/metadata/memory.rs @@ -28,6 +28,7 @@ use crate::file::page_index::column_index::{ }; use crate::file::page_index::offset_index::{OffsetIndexMetaData, PageLocation}; use crate::file::statistics::{Statistics, ValueStatistics}; +use std::collections::HashMap; use std::sync::Arc; /// Trait for calculating the size of various containers @@ -50,9 +51,60 @@ impl HeapSize for Vec { } } +impl HeapSize for HashMap { + fn heap_size(&self) -> usize { + let capacity = self.capacity(); + if capacity == 0 { + return 0; + } + + // HashMap doesn't provide a way to get its heap size, so this is an approximation based on + // the behavior of hashbrown::HashMap as at version 0.16.0, and may become inaccurate + // if the implementation changes. + let key_val_size = std::mem::size_of::<(K, V)>(); + // Overhead for the control tags group, which may be smaller depending on architecture + let group_size = 16; + // 1 byte of metadata stored per bucket. + let metadata_size = 1; + + // Compute the number of buckets for the capacity. 
Based on hashbrown's capacity_to_buckets + let buckets = if capacity < 15 { + let min_cap = match key_val_size { + 0..=1 => 14, + 2..=3 => 7, + _ => 3, + }; + let cap = min_cap.max(capacity); + if cap < 4 { + 4 + } else if cap < 8 { + 8 + } else { + 16 + } + } else { + (capacity.saturating_mul(8) / 7).next_power_of_two() + }; + + group_size + + (buckets * (key_val_size + metadata_size)) + + self.keys().map(|k| k.heap_size()).sum::() + + self.values().map(|v| v.heap_size()).sum::() + } +} + impl HeapSize for Arc { fn heap_size(&self) -> usize { - self.as_ref().heap_size() + // Arc stores weak and strong counts on the heap alongside an instance of T + 2 * std::mem::size_of::() + std::mem::size_of::() + self.as_ref().heap_size() + } +} + +impl HeapSize for Arc { + fn heap_size(&self) -> usize { + 2 * std::mem::size_of::() + + std::mem::size_of_val(self.as_ref()) + + self.as_ref().heap_size() } } diff --git a/parquet/src/file/metadata/mod.rs b/parquet/src/file/metadata/mod.rs index 763025fe142b..7022bd61c44d 100644 --- a/parquet/src/file/metadata/mod.rs +++ b/parquet/src/file/metadata/mod.rs @@ -287,11 +287,17 @@ impl ParquetMetaData { /// /// 4. 
Does not include any allocator overheads pub fn memory_size(&self) -> usize { + #[cfg(feature = "encryption")] + let encryption_size = self.file_decryptor.heap_size(); + #[cfg(not(feature = "encryption"))] + let encryption_size = 0usize; + std::mem::size_of::() + self.file_metadata.heap_size() + self.row_groups.heap_size() + self.column_index.heap_size() + self.offset_index.heap_size() + + encryption_size } /// Override the column index @@ -1875,10 +1881,9 @@ mod tests { .build(); #[cfg(not(feature = "encryption"))] - let base_expected_size = 2248; + let base_expected_size = 2766; #[cfg(feature = "encryption")] - // Not as accurate as it should be: https://github.com/apache/arrow-rs/issues/8472 - let base_expected_size = 2416; + let base_expected_size = 2934; assert_eq!(parquet_meta.memory_size(), base_expected_size); @@ -1907,16 +1912,90 @@ mod tests { .build(); #[cfg(not(feature = "encryption"))] - let bigger_expected_size = 2674; + let bigger_expected_size = 3192; #[cfg(feature = "encryption")] - // Not as accurate as it should be: https://github.com/apache/arrow-rs/issues/8472 - let bigger_expected_size = 2842; + let bigger_expected_size = 3360; // more set fields means more memory usage assert!(bigger_expected_size > base_expected_size); assert_eq!(parquet_meta.memory_size(), bigger_expected_size); } + #[test] + #[cfg(feature = "encryption")] + fn test_memory_size_with_decryptor() { + use crate::encryption::decrypt::FileDecryptionProperties; + use crate::file::metadata::thrift::encryption::AesGcmV1; + + let schema_descr = get_test_schema_descr(); + + let columns = schema_descr + .columns() + .iter() + .map(|column_descr| ColumnChunkMetaData::builder(column_descr.clone()).build()) + .collect::>>() + .unwrap(); + let row_group_meta = RowGroupMetaData::builder(schema_descr.clone()) + .set_num_rows(1000) + .set_column_metadata(columns) + .build() + .unwrap(); + let row_group_meta = vec![row_group_meta]; + + let version = 2; + let num_rows = 1000; + let 
aad_file_unique = vec![1u8; 8]; + let aad_prefix = vec![2u8; 8]; + let encryption_algorithm = EncryptionAlgorithm::AES_GCM_V1(AesGcmV1 { + aad_prefix: Some(aad_prefix.clone()), + aad_file_unique: Some(aad_file_unique.clone()), + supply_aad_prefix: Some(true), + }); + let footer_key_metadata = Some(vec![3u8; 8]); + let file_metadata = + FileMetaData::new(version, num_rows, None, None, schema_descr.clone(), None) + .with_encryption_algorithm(Some(encryption_algorithm)) + .with_footer_signing_key_metadata(footer_key_metadata.clone()); + + let parquet_meta_data = ParquetMetaDataBuilder::new(file_metadata.clone()) + .set_row_groups(row_group_meta.clone()) + .build(); + + let base_expected_size = 2058; + assert_eq!(parquet_meta_data.memory_size(), base_expected_size); + + let footer_key = "0123456789012345".as_bytes(); + let column_key = "1234567890123450".as_bytes(); + let mut decryption_properties_builder = + FileDecryptionProperties::builder(footer_key.to_vec()) + .with_aad_prefix(aad_prefix.clone()); + for column in schema_descr.columns() { + decryption_properties_builder = decryption_properties_builder + .with_column_key(&column.path().string(), column_key.to_vec()); + } + let decryption_properties = decryption_properties_builder.build().unwrap(); + let decryptor = FileDecryptor::new( + &decryption_properties, + footer_key_metadata.as_deref(), + aad_file_unique, + aad_prefix, + ) + .unwrap(); + + let parquet_meta_data = ParquetMetaDataBuilder::new(file_metadata.clone()) + .set_row_groups(row_group_meta.clone()) + .set_file_decryptor(Some(decryptor)) + .build(); + + let expected_size_with_decryptor = 3072; + assert!(expected_size_with_decryptor > base_expected_size); + + assert_eq!( + parquet_meta_data.memory_size(), + expected_size_with_decryptor + ); + } + /// Returns sample schema descriptor so we can create column metadata. 
fn get_test_schema_descr() -> SchemaDescPtr { let schema = SchemaType::group_type_builder("schema") diff --git a/parquet/src/schema/types.rs b/parquet/src/schema/types.rs index de6f855685a6..50ae4955380b 100644 --- a/parquet/src/schema/types.rs +++ b/parquet/src/schema/types.rs @@ -845,7 +845,9 @@ pub struct ColumnDescriptor { impl HeapSize for ColumnDescriptor { fn heap_size(&self) -> usize { - self.primitive_type.heap_size() + self.path.heap_size() + // Don't include the heap size of primitive_type, this is already + // accounted for via SchemaDescriptor::schema + self.path.heap_size() } } From 7a92be56d1a77696fa6c185ce25bfaf7c113dbc8 Mon Sep 17 00:00:00 2001 From: Liam Bao Date: Mon, 27 Oct 2025 14:32:25 -0400 Subject: [PATCH 15/19] Refactor arrow-cast decimal casting to unify the rescale logic used in Parquet variant casts (#8689) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Which issue does this PR close? - Closes #8670. # Rationale for this change We currently have two separate code paths that both handle decimal casting between different (precision, scale) pairs. Without unifying the logic, a fix in one place often needs to be duplicated in the other (e.g., https://github.com/apache/arrow-rs/issues/8579 fixed the `arrow-cast` and #8552 fixed the `parquet-variant-compute`), which can easily lead to divergence when contributors lack full context. This PR consolidates the decimal rescale logic for both `arrow-cast` and `parquet-variant-compute`. # What changes are included in this PR? 1. Extract the shared array-unary logic from `convert_to_smaller_scale_decimal` and `convert_to_bigger_or_equal_scale_decimal` into `apply_decimal_cast` 2. Move the rescale-closure creation into `make_upscaler` and `make_downscaler` so that they can be used in `parquet-variant-compute` 3. Rework `rescale_decimal` in `parquet-variant-compute` to use the new `make_upscaler` and `make_downscaler` utilities.
One challenge is incorporating the large-scale reduction path (i.e., the case where `delta_scale` cannot fit into `I::MAX_PRECISION`) into `make_downscaler` without hurting performance. Returning 0 directly is usually cheaper than applying a unary operation to return zero. Therefore, `make_downscaler` may return None, and it is the caller’s responsibility to handle this case appropriately based on the documented behavior. # Are these changes tested? Covered by existing tests # Are there any user-facing changes? No --- arrow-cast/src/cast/decimal.rs | 361 ++++++++++++++---- arrow-cast/src/cast/mod.rs | 2 +- .../src/type_conversion.rs | 87 +---- 3 files changed, 289 insertions(+), 161 deletions(-) diff --git a/arrow-cast/src/cast/decimal.rs b/arrow-cast/src/cast/decimal.rs index 907e61b09f7b..71338a6921e9 100644 --- a/arrow-cast/src/cast/decimal.rs +++ b/arrow-cast/src/cast/decimal.rs @@ -145,54 +145,82 @@ impl DecimalCast for i256 { } } -pub(crate) fn cast_decimal_to_decimal_error( +/// Construct closures to upscale decimals from `(input_precision, input_scale)` to +/// `(output_precision, output_scale)`. +/// +/// Returns `(f_fallible, f_infallible)` where: +/// * `f_fallible` yields `None` when the requested cast would overflow +/// * `f_infallible` is present only when every input is guaranteed to succeed; otherwise it is `None` +/// and callers must fall back to `f_fallible` +/// +/// Returns `None` if the required scale increase `delta_scale = output_scale - input_scale` +/// exceeds the supported precomputed precision table `O::MAX_FOR_EACH_PRECISION`. +/// In that case, the caller should treat this as an overflow for the output scale +/// and handle it accordingly (e.g., return a cast error).
+#[allow(clippy::type_complexity)] +fn make_upscaler( + input_precision: u8, + input_scale: i8, output_precision: u8, output_scale: i8, -) -> impl Fn(::Native) -> ArrowError +) -> Option<( + impl Fn(I::Native) -> Option, + Option O::Native>, +)> where - I: DecimalType, - O: DecimalType, I::Native: DecimalCast + ArrowNativeTypeOp, O::Native: DecimalCast + ArrowNativeTypeOp, { - move |x: I::Native| { - ArrowError::CastError(format!( - "Cannot cast to {}({}, {}). Overflowing on {:?}", - O::PREFIX, - output_precision, - output_scale, - x - )) - } + let delta_scale = output_scale - input_scale; + + // O::MAX_FOR_EACH_PRECISION[k] stores 10^k - 1 (e.g., 9, 99, 999, ...). + // Adding 1 yields exactly 10^k without computing a power at runtime. + // Using the precomputed table avoids pow(10, k) and its checked/overflow + // handling, which is faster and simpler for scaling by 10^delta_scale. + let max = O::MAX_FOR_EACH_PRECISION.get(delta_scale as usize)?; + let mul = max.add_wrapping(O::Native::ONE); + let f_fallible = move |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok()); + + // if the gain in precision (digits) is greater than the multiplication due to scaling + // every number will fit into the output type + // Example: If we are starting with any number of precision 5 [xxxxx], + // then an increase of scale by 3 will have the following effect on the representation: + // [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type + // needs to provide at least 8 digits precision + let is_infallible_cast = (input_precision as i8) + delta_scale <= (output_precision as i8); + let f_infallible = is_infallible_cast + .then_some(move |x| O::Native::from_decimal(x).unwrap().mul_wrapping(mul)); + Some((f_fallible, f_infallible)) } -pub(crate) fn convert_to_smaller_scale_decimal( - array: &PrimitiveArray, +/// Construct closures to downscale decimals from `(input_precision, input_scale)` to +/// `(output_precision, output_scale)`. 
+/// +/// Returns `(f_fallible, f_infallible)` where: +/// * `f_fallible` yields `None` when the requested cast would overflow +/// * `f_infallible` is present only when every input is guaranteed to succeed; otherwise it is `None` +/// and callers must fall back to `f_fallible` +/// +/// Returns `None` if the required scale reduction `delta_scale = input_scale - output_scale` +/// exceeds the supported precomputed precision table `I::MAX_FOR_EACH_PRECISION`. +/// In this scenario, any value would round to zero (e.g., dividing by 10^k where k exceeds the +/// available precision). Callers should therefore produce zero values (preserving nulls) rather +/// than returning an error. +#[allow(clippy::type_complexity)] +fn make_downscaler( input_precision: u8, input_scale: i8, output_precision: u8, output_scale: i8, - cast_options: &CastOptions, -) -> Result, ArrowError> +) -> Option<( + impl Fn(I::Native) -> Option, + Option O::Native>, +)> where - I: DecimalType, - O: DecimalType, I::Native: DecimalCast + ArrowNativeTypeOp, O::Native: DecimalCast + ArrowNativeTypeOp, { - let error = cast_decimal_to_decimal_error::(output_precision, output_scale); let delta_scale = input_scale - output_scale; - // if the reduction of the input number through scaling (dividing) is greater - // than a possible precision loss (plus potential increase via rounding) - // every input number will fit into the output type - // Example: If we are starting with any number of precision 5 [xxxxx], - // then and decrease the scale by 3 will have the following effect on the representation: - // [xxxxx] -> [xx] (+ 1 possibly, due to rounding). - // The rounding may add an additional digit, so the cast to be infallible, - // the output type needs to have at least 3 digits of precision. - // e.g. 
Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100: - // [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible - let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8); // delta_scale is guaranteed to be > 0, but may also be larger than I::MAX_PRECISION. If so, the // scale change divides out more digits than the input has precision and the result of the cast @@ -200,16 +228,13 @@ where // possible result is 999999999/10000000000 = 0.0999999999, which rounds to zero. Smaller values // (e.g. 1/10000000000) or larger delta_scale (e.g. 999999999/10000000000000) produce even // smaller results, which also round to zero. In that case, just return an array of zeros. - let Some(max) = I::MAX_FOR_EACH_PRECISION.get(delta_scale as usize) else { - let zeros = vec![O::Native::ZERO; array.len()]; - return Ok(PrimitiveArray::new(zeros.into(), array.nulls().cloned())); - }; + let max = I::MAX_FOR_EACH_PRECISION.get(delta_scale as usize)?; let div = max.add_wrapping(I::Native::ONE); let half = div.div_wrapping(I::Native::ONE.add_wrapping(I::Native::ONE)); let half_neg = half.neg_wrapping(); - let f = |x: I::Native| { + let f_fallible = move |x: I::Native| { // div is >= 10 and so this cannot overflow let d = x.div_wrapping(div); let r = x.mod_wrapping(div); @@ -223,24 +248,136 @@ where O::Native::from_decimal(adjusted) }; - Ok(if is_infallible_cast { - // make sure we don't perform calculations that don't make sense w/o validation - validate_decimal_precision_and_scale::(output_precision, output_scale)?; - let g = |x: I::Native| f(x).unwrap(); // unwrapping is safe since the result is guaranteed - // to fit into the target type - array.unary(g) + // if the reduction of the input number through scaling (dividing) is greater + // than a possible precision loss (plus potential increase via rounding) + // every input number will fit into the output type + // Example: If we are starting with any number of precision 5 
[xxxxx], + // then and decrease the scale by 3 will have the following effect on the representation: + // [xxxxx] -> [xx] (+ 1 possibly, due to rounding). + // The rounding may add a digit, so the cast to be infallible, + // the output type needs to have at least 3 digits of precision. + // e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100: + // [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible + let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8); + let f_infallible = is_infallible_cast.then_some(move |x| f_fallible(x).unwrap()); + Some((f_fallible, f_infallible)) +} + +/// Apply the rescaler function to the value. +/// If the rescaler is infallible, use the infallible function. +/// Otherwise, use the fallible function and validate the precision. +fn apply_rescaler( + value: I::Native, + output_precision: u8, + f: impl Fn(I::Native) -> Option, + f_infallible: Option O::Native>, +) -> Option +where + I::Native: DecimalCast, + O::Native: DecimalCast, +{ + if let Some(f_infallible) = f_infallible { + Some(f_infallible(value)) + } else { + f(value).filter(|v| O::is_valid_decimal_precision(*v, output_precision)) + } +} + +/// Rescales a decimal value from `(input_precision, input_scale)` to +/// `(output_precision, output_scale)` and returns the converted number when it fits +/// within the output precision. +/// +/// The function first validates that the requested precision and scale are supported for +/// both the source and destination decimal types. It then either upscales (multiplying +/// by an appropriate power of ten) or downscales (dividing with rounding) the input value. +/// When the scaling factor exceeds the precision table of the destination type, the value +/// is treated as an overflow for upscaling, or rounded to zero for downscaling (as any +/// possible result would be zero at the requested scale). 
+/// +/// This mirrors the column-oriented helpers of decimal casting but operates on a single value +/// (row-level) instead of an entire array. +/// +/// Returns `None` if the value cannot be represented with the requested precision. +pub fn rescale_decimal( + value: I::Native, + input_precision: u8, + input_scale: i8, + output_precision: u8, + output_scale: i8, +) -> Option +where + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + validate_decimal_precision_and_scale::(input_precision, input_scale).ok()?; + validate_decimal_precision_and_scale::(output_precision, output_scale).ok()?; + + if input_scale <= output_scale { + let (f, f_infallible) = + make_upscaler::(input_precision, input_scale, output_precision, output_scale)?; + apply_rescaler::(value, output_precision, f, f_infallible) + } else { + let Some((f, f_infallible)) = + make_downscaler::(input_precision, input_scale, output_precision, output_scale) + else { + // Scale reduction exceeds supported precision; result mathematically rounds to zero + return Some(O::Native::ZERO); + }; + apply_rescaler::(value, output_precision, f, f_infallible) + } +} + +fn cast_decimal_to_decimal_error( + output_precision: u8, + output_scale: i8, +) -> impl Fn(::Native) -> ArrowError +where + I: DecimalType, + O: DecimalType, + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + move |x: I::Native| { + ArrowError::CastError(format!( + "Cannot cast to {}({}, {}). 
Overflowing on {:?}", + O::PREFIX, + output_precision, + output_scale, + x + )) + } +} + +fn apply_decimal_cast( + array: &PrimitiveArray, + output_precision: u8, + output_scale: i8, + f_fallible: impl Fn(I::Native) -> Option, + f_infallible: Option O::Native>, + cast_options: &CastOptions, +) -> Result, ArrowError> +where + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + let array = if let Some(f_infallible) = f_infallible { + array.unary(f_infallible) } else if cast_options.safe { - array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision))) + array.unary_opt(|x| { + f_fallible(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)) + }) } else { + let error = cast_decimal_to_decimal_error::(output_precision, output_scale); array.try_unary(|x| { - f(x).ok_or_else(|| error(x)).and_then(|v| { + f_fallible(x).ok_or_else(|| error(x)).and_then(|v| { O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v) }) })? 
- }) + }; + Ok(array) } -pub(crate) fn convert_to_bigger_or_equal_scale_decimal( +fn convert_to_smaller_scale_decimal( array: &PrimitiveArray, input_precision: u8, input_scale: i8, @@ -254,36 +391,58 @@ where I::Native: DecimalCast + ArrowNativeTypeOp, O::Native: DecimalCast + ArrowNativeTypeOp, { - let error = cast_decimal_to_decimal_error::(output_precision, output_scale); - let delta_scale = output_scale - input_scale; - let mul = O::Native::from_decimal(10_i128) - .unwrap() - .pow_checked(delta_scale as u32)?; + if let Some((f_fallible, f_infallible)) = + make_downscaler::(input_precision, input_scale, output_precision, output_scale) + { + apply_decimal_cast( + array, + output_precision, + output_scale, + f_fallible, + f_infallible, + cast_options, + ) + } else { + // Scale reduction exceeds supported precision; result mathematically rounds to zero + let zeros = vec![O::Native::ZERO; array.len()]; + Ok(PrimitiveArray::new(zeros.into(), array.nulls().cloned())) + } +} - // if the gain in precision (digits) is greater than the multiplication due to scaling - // every number will fit into the output type - // Example: If we are starting with any number of precision 5 [xxxxx], - // then an increase of scale by 3 will have the following effect on the representation: - // [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type - // needs to provide at least 8 digits precision - let is_infallible_cast = (input_precision as i8) + delta_scale <= (output_precision as i8); - let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok()); - - Ok(if is_infallible_cast { - // make sure we don't perform calculations that don't make sense w/o validation - validate_decimal_precision_and_scale::(output_precision, output_scale)?; - // unwrapping is safe since the result is guaranteed to fit into the target type - let f = |x| O::Native::from_decimal(x).unwrap().mul_wrapping(mul); - array.unary(f) - } else if cast_options.safe { - array.unary_opt(|x| 
f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision))) +fn convert_to_bigger_or_equal_scale_decimal( + array: &PrimitiveArray, + input_precision: u8, + input_scale: i8, + output_precision: u8, + output_scale: i8, + cast_options: &CastOptions, +) -> Result, ArrowError> +where + I: DecimalType, + O: DecimalType, + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + if let Some((f, f_infallible)) = + make_upscaler::(input_precision, input_scale, output_precision, output_scale) + { + apply_decimal_cast( + array, + output_precision, + output_scale, + f, + f_infallible, + cast_options, + ) } else { - array.try_unary(|x| { - f(x).ok_or_else(|| error(x)).and_then(|v| { - O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v) - }) - })? - }) + // Scale increase exceeds supported precision; return overflow error + Err(ArrowError::CastError(format!( + "Cannot cast to {}({}, {}). Value overflows for output scale", + O::PREFIX, + output_precision, + output_scale + ))) + } } // Only support one type of decimal cast operations @@ -763,4 +922,58 @@ mod tests { ); Ok(()) } + + #[test] + fn test_rescale_decimal_upscale_within_precision() { + let result = rescale_decimal::( + 12_345_i128, // 123.45 with scale 2 + 5, + 2, + 8, + 5, + ); + assert_eq!(result, Some(12_345_000_i128)); + } + + #[test] + fn test_rescale_decimal_downscale_rounds_half_away_from_zero() { + let positive = rescale_decimal::( + 1_050_i128, // 1.050 with scale 3 + 5, 3, 5, 1, + ); + assert_eq!(positive, Some(11_i128)); // 1.1 with scale 1 + + let negative = rescale_decimal::( + -1_050_i128, // -1.050 with scale 3 + 5, + 3, + 5, + 1, + ); + assert_eq!(negative, Some(-11_i128)); // -1.1 with scale 1 + } + + #[test] + fn test_rescale_decimal_downscale_large_delta_returns_zero() { + let result = rescale_decimal::(12_345_i32, 9, 9, 9, 4); + assert_eq!(result, Some(0_i32)); + } + + #[test] + fn 
test_rescale_decimal_upscale_overflow_returns_none() { + let result = rescale_decimal::(9_999_i32, 4, 0, 5, 2); + assert_eq!(result, None); + } + + #[test] + fn test_rescale_decimal_invalid_input_precision_scale_returns_none() { + let result = rescale_decimal::(123_i128, 39, 39, 38, 38); + assert_eq!(result, None); + } + + #[test] + fn test_rescale_decimal_invalid_output_precision_scale_returns_none() { + let result = rescale_decimal::(123_i128, 38, 38, 39, 39); + assert_eq!(result, None); + } } diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index bb3247ca3c3c..47fdb01a09f4 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -69,7 +69,7 @@ use arrow_schema::*; use arrow_select::take::take; use num_traits::{NumCast, ToPrimitive, cast::AsPrimitive}; -pub use decimal::DecimalCast; +pub use decimal::{DecimalCast, rescale_decimal}; /// CastOptions provides a way to override the default cast behaviors #[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/parquet-variant-compute/src/type_conversion.rs b/parquet-variant-compute/src/type_conversion.rs index 28087d7541e4..d15664f5af9e 100644 --- a/parquet-variant-compute/src/type_conversion.rs +++ b/parquet-variant-compute/src/type_conversion.rs @@ -17,8 +17,7 @@ //! Module for transforming a typed arrow `Array` to `VariantArray`. -use arrow::array::ArrowNativeTypeOp; -use arrow::compute::DecimalCast; +use arrow::compute::{DecimalCast, rescale_decimal}; use arrow::datatypes::{ self, ArrowPrimitiveType, ArrowTimestampType, Decimal32Type, Decimal64Type, Decimal128Type, DecimalType, @@ -190,90 +189,6 @@ where } } -/// Rescale a decimal from (input_precision, input_scale) to (output_precision, output_scale) -/// and return the scaled value if it fits the output precision. Similar to the implementation in -/// decimal.rs in arrow-cast. 
-pub(crate) fn rescale_decimal( - value: I::Native, - input_precision: u8, - input_scale: i8, - output_precision: u8, - output_scale: i8, -) -> Option -where - I::Native: DecimalCast, - O::Native: DecimalCast, -{ - let delta_scale = output_scale - input_scale; - - let (scaled, is_infallible_cast) = if delta_scale >= 0 { - // O::MAX_FOR_EACH_PRECISION[k] stores 10^k - 1 (e.g., 9, 99, 999, ...). - // Adding 1 yields exactly 10^k without computing a power at runtime. - // Using the precomputed table avoids pow(10, k) and its checked/overflow - // handling, which is faster and simpler for scaling by 10^delta_scale. - let max = O::MAX_FOR_EACH_PRECISION.get(delta_scale as usize)?; - let mul = max.add_wrapping(O::Native::ONE); - - // if the gain in precision (digits) is greater than the multiplication due to scaling - // every number will fit into the output type - // Example: If we are starting with any number of precision 5 [xxxxx], - // then an increase of scale by 3 will have the following effect on the representation: - // [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type - // needs to provide at least 8 digits precision - let is_infallible_cast = input_precision as i8 + delta_scale <= output_precision as i8; - let value = O::Native::from_decimal(value); - let scaled = if is_infallible_cast { - Some(value.unwrap().mul_wrapping(mul)) - } else { - value.and_then(|x| x.mul_checked(mul).ok()) - }; - (scaled, is_infallible_cast) - } else { - // the abs of delta_scale is guaranteed to be > 0, but may also be larger than I::MAX_PRECISION. - // If so, the scale change divides out more digits than the input has precision and the result - // of the cast is always zero. For example, if we try to apply delta_scale=10 a decimal32 value, - // the largest possible result is 999999999/10000000000 = 0.0999999999, which rounds to zero. - // Smaller values (e.g. 1/10000000000) or larger delta_scale (e.g. 
999999999/10000000000000) - // produce even smaller results, which also round to zero. In that case, just return zero. - let Some(max) = I::MAX_FOR_EACH_PRECISION.get(delta_scale.unsigned_abs() as usize) else { - return Some(O::Native::ZERO); - }; - let div = max.add_wrapping(I::Native::ONE); - let half = div.div_wrapping(I::Native::ONE.add_wrapping(I::Native::ONE)); - let half_neg = half.neg_wrapping(); - - // div is >= 10 and so this cannot overflow - let d = value.div_wrapping(div); - let r = value.mod_wrapping(div); - - // Round result - let adjusted = match value >= I::Native::ZERO { - true if r >= half => d.add_wrapping(I::Native::ONE), - false if r <= half_neg => d.sub_wrapping(I::Native::ONE), - _ => d, - }; - - // if the reduction of the input number through scaling (dividing) is greater - // than a possible precision loss (plus potential increase via rounding) - // every input number will fit into the output type - // Example: If we are starting with any number of precision 5 [xxxxx], - // then and decrease the scale by 3 will have the following effect on the representation: - // [xxxxx] -> [xx] (+ 1 possibly, due to rounding). - // The rounding may add a digit, so for the cast to be infallible, - // the output type needs to have at least 3 digits of precision. - // e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100: - // [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible - let is_infallible_cast = input_precision as i8 + delta_scale < output_precision as i8; - (O::Native::from_decimal(adjusted), is_infallible_cast) - }; - - if is_infallible_cast { - scaled - } else { - scaled.filter(|v| O::is_valid_decimal_precision(*v, output_precision)) - } -} - /// Convert the value at a specific index in the given array into a `Variant`. macro_rules! 
non_generic_conversion_single_value { ($array:expr, $cast_fn:expr, $index:expr) => {{ From 62df32e29d03b69c3eeeb12268fdd539ebd00098 Mon Sep 17 00:00:00 2001 From: Vegard Stikbakke Date: Mon, 27 Oct 2025 20:24:02 +0100 Subject: [PATCH 16/19] Add benchmark for casting to RunEndEncoded (REE) (#8710) Closes #8709. Adds bench `cast_ree` which can be run with `cargo bench --bench cast_ree`. --------- Co-authored-by: Andrew Lamb --- arrow-cast/Cargo.toml | 1 + arrow/benches/cast_kernels.rs | 40 +++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/arrow-cast/Cargo.toml b/arrow-cast/Cargo.toml index f3309783fb38..fb5ad1af3d3a 100644 --- a/arrow-cast/Cargo.toml +++ b/arrow-cast/Cargo.toml @@ -75,3 +75,4 @@ harness = false [[bench]] name = "parse_decimal" harness = false + diff --git a/arrow/benches/cast_kernels.rs b/arrow/benches/cast_kernels.rs index a54529c8d108..040c118a1e83 100644 --- a/arrow/benches/cast_kernels.rs +++ b/arrow/benches/cast_kernels.rs @@ -359,6 +359,46 @@ fn add_benchmark(c: &mut Criterion) { c.bench_function("cast binary view to string view", |b| { b.iter(|| cast_array(&binary_view_array, DataType::Utf8View)) }); + + c.bench_function("cast string single run to ree", |b| { + let source_array = StringArray::from(vec!["a"; 8192]); + let array_ref = Arc::new(source_array) as ArrayRef; + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int32, false)), + Arc::new(Field::new("values", DataType::Utf8, true)), + ); + b.iter(|| cast(&array_ref, &target_type).unwrap()); + }); + + c.bench_function("cast runs of 10 string to ree", |b| { + let source_array: Int32Array = (0..8192).map(|i| i / 10).collect(); + let array_ref = Arc::new(source_array) as ArrayRef; + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int32, false)), + Arc::new(Field::new("values", DataType::Int32, true)), + ); + b.iter(|| cast(&array_ref, &target_type).unwrap()); + }); + + 
c.bench_function("cast runs of 1000 int32s to ree", |b| { + let source_array: Int32Array = (0..8192).map(|i| i / 1000).collect(); + let array_ref = Arc::new(source_array) as ArrayRef; + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int32, false)), + Arc::new(Field::new("values", DataType::Int32, true)), + ); + b.iter(|| cast(&array_ref, &target_type).unwrap()); + }); + + c.bench_function("cast no runs of int32s to ree", |b| { + let source_array: Int32Array = (0..8192).collect(); + let array_ref = Arc::new(source_array) as ArrayRef; + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int32, false)), + Arc::new(Field::new("values", DataType::Int32, true)), + ); + b.iter(|| cast(&array_ref, &target_type).unwrap()); + }); } criterion_group!(benches, add_benchmark); From 6c3e5881185bf74b56ce96233135bb540826554b Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Mon, 27 Oct 2025 22:34:29 +0200 Subject: [PATCH 17/19] fix: `zip` now treats nulls as false in provided mask regardless of the underlying bit value (#8711) # Which issue does this PR close? - closes https://github.com/apache/arrow-rs/issues/8721 # Rationale for this change mask `nulls` should be treated as `false` (even if the underlying values are not 0) as described in the docs for zip # What changes are included in this PR? used `prep_null_mask_filter` before iterating over the mask, added tests for both scalar and non scalar (to prepare for #8653) # Are these changes tested? Yes # Are there any user-facing changes? 
Kinda --- arrow-select/src/filter.rs | 8 ++- arrow-select/src/zip.rs | 125 ++++++++++++++++++++++++++++++++++++- 2 files changed, 130 insertions(+), 3 deletions(-) diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs index cd132de39d1c..6a5ba13c950a 100644 --- a/arrow-select/src/filter.rs +++ b/arrow-select/src/filter.rs @@ -58,7 +58,13 @@ pub struct SlicesIterator<'a>(BitSliceIterator<'a>); impl<'a> SlicesIterator<'a> { /// Creates a new iterator from a [BooleanArray] pub fn new(filter: &'a BooleanArray) -> Self { - Self(filter.values().set_slices()) + filter.values().into() + } +} + +impl<'a> From<&'a BooleanBuffer> for SlicesIterator<'a> { + fn from(filter: &'a BooleanBuffer) -> Self { + Self(filter.set_slices()) } } diff --git a/arrow-select/src/zip.rs b/arrow-select/src/zip.rs index 2efd2e749921..c202be6b6299 100644 --- a/arrow-select/src/zip.rs +++ b/arrow-select/src/zip.rs @@ -17,8 +17,9 @@ //! [`zip`]: Combine values from two arrays based on boolean mask -use crate::filter::SlicesIterator; +use crate::filter::{SlicesIterator, prep_null_mask_filter}; use arrow_array::*; +use arrow_buffer::BooleanBuffer; use arrow_data::transform::MutableArrayData; use arrow_schema::ArrowError; @@ -127,7 +128,8 @@ pub fn zip( // keep track of how much is filled let mut filled = 0; - SlicesIterator::new(mask).for_each(|(start, end)| { + let mask = maybe_prep_null_mask_filter(mask); + SlicesIterator::from(&mask).for_each(|(start, end)| { // the gap needs to be filled with falsy values if start > filled { if falsy_is_scalar { @@ -166,9 +168,22 @@ pub fn zip( Ok(make_array(data)) } +fn maybe_prep_null_mask_filter(predicate: &BooleanArray) -> BooleanBuffer { + // Nulls are treated as false + if predicate.null_count() == 0 { + predicate.values().clone() + } else { + let cleaned = prep_null_mask_filter(predicate); + let (boolean_buffer, _) = cleaned.into_parts(); + boolean_buffer + } +} + #[cfg(test)] mod test { use super::*; + use arrow_array::cast::AsArray; + use 
arrow_buffer::{BooleanBuffer, NullBuffer}; #[test] fn test_zip_kernel_one() { @@ -279,4 +294,110 @@ mod test { let expected = Int32Array::from(vec![None, None, Some(42), Some(42), None]); assert_eq!(actual, &expected); } + + #[test] + fn test_zip_primitive_array_with_nulls_is_mask_should_be_treated_as_false() { + let truthy = Int32Array::from_iter_values(vec![1, 2, 3, 4, 5, 6]); + let falsy = Int32Array::from_iter_values(vec![7, 8, 9, 10, 11, 12]); + + let mask = { + let booleans = BooleanBuffer::from(vec![true, true, false, true, false, false]); + let nulls = NullBuffer::from(vec![ + true, true, true, + false, // null treated as false even though in the original mask it was true + true, true, + ]); + BooleanArray::new(booleans, Some(nulls)) + }; + let out = zip(&mask, &truthy, &falsy).unwrap(); + let actual = out.as_any().downcast_ref::().unwrap(); + let expected = Int32Array::from(vec![ + Some(1), + Some(2), + Some(9), + Some(10), // true in mask but null + Some(11), + Some(12), + ]); + assert_eq!(actual, &expected); + } + + #[test] + fn test_zip_kernel_primitive_scalar_with_boolean_array_mask_with_nulls_should_be_treated_as_false() + { + let scalar_truthy = Scalar::new(Int32Array::from_value(42, 1)); + let scalar_falsy = Scalar::new(Int32Array::from_value(123, 1)); + + let mask = { + let booleans = BooleanBuffer::from(vec![true, true, false, true, false, false]); + let nulls = NullBuffer::from(vec![ + true, true, true, + false, // null treated as false even though in the original mask it was true + true, true, + ]); + BooleanArray::new(booleans, Some(nulls)) + }; + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); + let actual = out.as_any().downcast_ref::().unwrap(); + let expected = Int32Array::from(vec![ + Some(42), + Some(42), + Some(123), + Some(123), // true in mask but null + Some(123), + Some(123), + ]); + assert_eq!(actual, &expected); + } + + #[test] + fn test_zip_string_array_with_nulls_is_mask_should_be_treated_as_false() { + let truthy = 
StringArray::from_iter_values(vec!["1", "2", "3", "4", "5", "6"]); + let falsy = StringArray::from_iter_values(vec!["7", "8", "9", "10", "11", "12"]); + + let mask = { + let booleans = BooleanBuffer::from(vec![true, true, false, true, false, false]); + let nulls = NullBuffer::from(vec![ + true, true, true, + false, // null treated as false even though in the original mask it was true + true, true, + ]); + BooleanArray::new(booleans, Some(nulls)) + }; + let out = zip(&mask, &truthy, &falsy).unwrap(); + let actual = out.as_string::(); + let expected = StringArray::from_iter_values(vec![ + "1", "2", "9", "10", // true in mask but null + "11", "12", + ]); + assert_eq!(actual, &expected); + } + + #[test] + fn test_zip_kernel_large_string_scalar_with_boolean_array_mask_with_nulls_should_be_treated_as_false() + { + let scalar_truthy = Scalar::new(LargeStringArray::from_iter_values(["test"])); + let scalar_falsy = Scalar::new(LargeStringArray::from_iter_values(["something else"])); + + let mask = { + let booleans = BooleanBuffer::from(vec![true, true, false, true, false, false]); + let nulls = NullBuffer::from(vec![ + true, true, true, + false, // null treated as false even though in the original mask it was true + true, true, + ]); + BooleanArray::new(booleans, Some(nulls)) + }; + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); + let actual = out.as_any().downcast_ref::().unwrap(); + let expected = LargeStringArray::from_iter(vec![ + Some("test"), + Some("test"), + Some("something else"), + Some("something else"), // true in mask but null + Some("something else"), + Some("something else"), + ]); + assert_eq!(actual, &expected); + } } From dd98964bf67043a324205c7da51f71ff277fbc6c Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Mon, 27 Oct 2025 23:56:57 +0200 Subject: [PATCH 18/19] fix: `ArrayIter` does not report size hint correctly after advancing from the iterator back this also adds a LOT of tests extracted 
from (which is how I found that bug): - #8697 --- arrow-array/src/iterator.rs | 930 +++++++++++++++++++++++++++++++++++- 1 file changed, 925 insertions(+), 5 deletions(-) diff --git a/arrow-array/src/iterator.rs b/arrow-array/src/iterator.rs index 6708da3d5dd6..c1026c3ad561 100644 --- a/arrow-array/src/iterator.rs +++ b/arrow-array/src/iterator.rs @@ -44,7 +44,7 @@ use arrow_buffer::NullBuffer; /// [`PrimitiveArray`]: crate::PrimitiveArray /// [`compute::unary`]: https://docs.rs/arrow/latest/arrow/compute/fn.unary.html /// [`compute::try_unary`]: https://docs.rs/arrow/latest/arrow/compute/fn.try_unary.html -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ArrayIter { array: T, logical_nulls: Option, @@ -98,8 +98,8 @@ impl Iterator for ArrayIter { fn size_hint(&self) -> (usize, Option) { ( - self.array.len() - self.current, - Some(self.array.len() - self.current), + self.current_end - self.current, + Some(self.current_end - self.current), ) } } @@ -147,9 +147,14 @@ pub type MapArrayIter<'a> = ArrayIter<&'a MapArray>; pub type GenericListViewArrayIter<'a, O> = ArrayIter<&'a GenericListViewArray>; #[cfg(test)] mod tests { - use std::sync::Arc; - use crate::array::{ArrayRef, BinaryArray, BooleanArray, Int32Array, StringArray}; + use crate::iterator::ArrayIter; + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; + use std::fmt::Debug; + use std::iter::Copied; + use std::slice::Iter; + use std::sync::Arc; #[test] fn test_primitive_array_iter_round_trip() { @@ -264,4 +269,919 @@ mod tests { // check if ExactSizeIterator is implemented let _ = array.iter().rposition(|opt_b| opt_b == Some(true)); } + + trait SharedBetweenArrayIterAndSliceIter: + ExactSizeIterator> + DoubleEndedIterator> + Clone + { + } + impl> + DoubleEndedIterator>> + SharedBetweenArrayIterAndSliceIter for T + { + } + + fn get_int32_iterator_cases() -> impl Iterator>)> { + let mut rng = StdRng::seed_from_u64(42); + + let no_nulls_and_no_duplicates = (0..10).map(Some).collect::>>(); + let 
no_nulls_random_values = (0..10) + .map(|_| rng.random::()) + .map(Some) + .collect::>>(); + + let all_nulls = (0..10).map(|_| None).collect::>>(); + let only_start_nulls = (0..10) + .map(|item| if item < 4 { None } else { Some(item) }) + .collect::>>(); + let only_end_nulls = (0..10) + .map(|item| if item > 8 { None } else { Some(item) }) + .collect::>>(); + let only_middle_nulls = (0..10) + .map(|item| { + if (4..=8).contains(&item) && rng.random_bool(0.9) { + None + } else { + Some(item) + } + }) + .collect::>>(); + let random_values_with_random_nulls = (0..10) + .map(|_| { + if rng.random_bool(0.3) { + None + } else { + Some(rng.random::()) + } + }) + .collect::>>(); + + let no_nulls_and_some_duplicates = (0..10) + .map(|item| item % 3) + .map(Some) + .collect::>>(); + let no_nulls_and_all_same_value = + (0..10).map(|_| 9).map(Some).collect::>>(); + let no_nulls_and_continues_duplicates = [0, 0, 0, 1, 1, 2, 2, 2, 2, 3] + .map(Some) + .into_iter() + .collect::>>(); + + let single_null_and_no_duplicates = (0..10) + .map(|item| if item == 4 { None } else { Some(item) }) + .collect::>>(); + let multiple_nulls_and_no_duplicates = (0..10) + .map(|item| if item % 3 == 2 { None } else { Some(item) }) + .collect::>>(); + let continues_nulls_and_no_duplicates = [ + Some(0), + Some(1), + None, + None, + Some(2), + Some(3), + None, + Some(4), + Some(5), + None, + ] + .into_iter() + .collect::>>(); + + [ + no_nulls_and_no_duplicates, + no_nulls_random_values, + no_nulls_and_some_duplicates, + no_nulls_and_all_same_value, + no_nulls_and_continues_duplicates, + all_nulls, + only_start_nulls, + only_end_nulls, + only_middle_nulls, + random_values_with_random_nulls, + single_null_and_no_duplicates, + multiple_nulls_and_no_duplicates, + continues_nulls_and_no_duplicates, + ] + .map(|case| (Int32Array::from(case.clone()), case)) + .into_iter() + } + + trait SetupIter { + fn setup(&self, iter: &mut I); + } + + struct NoSetup; + impl SetupIter for NoSetup { + fn setup(&self, _iter: 
&mut I) { + // none + } + } + + fn setup_and_assert_cases( + setup_iterator: impl SetupIter, + assert_fn: impl Fn(ArrayIter<&Int32Array>, Copied>>), + ) { + for (array, source) in get_int32_iterator_cases() { + let mut actual = ArrayIter::new(&array); + let mut expected = source.iter().copied(); + + setup_iterator.setup(&mut actual); + setup_iterator.setup(&mut expected); + + assert_fn(actual, expected); + } + } + + /// Trait representing an operation on a BitIterator + /// that can be compared against a slice iterator + trait ArrayIteratorOp { + /// What the operation returns (e.g. Option for last/max, usize for count, etc) + type Output: PartialEq + Debug; + + /// The name of the operation, used for error messages + fn name(&self) -> String; + + /// Get the value of the operation for the provided iterator + /// This will be either a BitIterator or a slice iterator to make sure they produce the same result + fn get_value(&self, iter: T) -> Self::Output; + } + + /// Trait representing an operation on a BitIterator + /// that can be compared against a slice iterator + trait ArrayIteratorMutateOp { + /// What the operation returns (e.g. Option for last/max, usize for count, etc) + type Output: PartialEq + Debug; + + /// The name of the operation, used for error messages + fn name(&self) -> String; + + /// Get the value of the operation for the provided iterator + /// This will be either a BitIterator or a slice iterator to make sure they produce the same result + fn get_value(&self, iter: &mut T) -> Self::Output; + } + + /// Helper function that will assert that the provided operation + /// produces the same result for both BitIterator and slice iterator + /// under various consumption patterns (e.g. 
some calls to next/next_back/consume_all/etc) + fn assert_array_iterator_cases(o: O) { + setup_and_assert_cases(NoSetup, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for new iter (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct Next; + impl SetupIter for Next { + fn setup(&self, iter: &mut I) { + iter.next(); + } + } + setup_and_assert_cases(Next, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for new iter after consuming 1 element from the start (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct NextBack; + impl SetupIter for NextBack { + fn setup(&self, iter: &mut I) { + iter.next_back(); + } + } + + setup_and_assert_cases(NextBack, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for new iter after consuming 1 element from the end (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct NextAndBack; + impl SetupIter for NextAndBack { + fn setup(&self, iter: &mut I) { + iter.next(); + iter.next_back(); + } + } + + setup_and_assert_cases(NextAndBack, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for new iter after consuming 1 element from start and end (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct NextUntilLast; + impl SetupIter for NextUntilLast { + fn setup(&self, iter: &mut I) { + let len = iter.len(); + if len > 1 { + iter.nth(len - 2); + } + } + } + setup_and_assert_cases(NextUntilLast, |actual, 
expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for new iter after consuming all from the start but 1 (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct NextBackUntilFirst; + impl SetupIter for NextBackUntilFirst { + fn setup(&self, iter: &mut I) { + let len = iter.len(); + if len > 1 { + iter.nth_back(len - 2); + } + } + } + setup_and_assert_cases(NextBackUntilFirst, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for new iter after consuming all from the end but 1 (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct NextFinish; + impl SetupIter for NextFinish { + fn setup(&self, iter: &mut I) { + iter.nth(iter.len()); + } + } + setup_and_assert_cases(NextFinish, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for new iter after consuming all from the start (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct NextBackFinish; + impl SetupIter for NextBackFinish { + fn setup(&self, iter: &mut I) { + iter.nth_back(iter.len()); + } + } + setup_and_assert_cases(NextBackFinish, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for new iter after consuming all from the end (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct NextUntilLastNone; + impl SetupIter for NextUntilLastNone { + fn setup(&self, iter: &mut I) { + let last_null_position = iter.clone().rposition(|item| item.is_none()); + + // move the iterator to the location where 
there are no nulls anymore + if let Some(last_null_position) = last_null_position { + iter.nth(last_null_position); + } + } + } + setup_and_assert_cases(NextUntilLastNone, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for iter that have no nulls left (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + + struct NextUntilLastSome; + impl SetupIter for NextUntilLastSome { + fn setup(&self, iter: &mut I) { + let last_some_position = iter.clone().rposition(|item| item.is_some()); + + // move the iterator to the location where there are only nulls + if let Some(last_some_position) = last_some_position { + iter.nth(last_some_position); + } + } + } + setup_and_assert_cases(NextUntilLastSome, |actual, expected| { + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for iter that only have nulls left (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + }); + } + + /// Helper function that will assert that the provided operation + /// produces the same result for both BitIterator and slice iterator + /// under various consumption patterns (e.g. 
some calls to next/next_back/consume_all/etc) + fn assert_array_iterator_cases_mutate(o: O) { + for (array, source) in get_int32_iterator_cases() { + for i in 0..source.len() { + let mut actual = ArrayIter::new(&array); + let mut expected = source.iter().copied(); + + // calling nth(0) is the same as calling next() + // but we want to get to the ith position so we call nth(i - 1) + if i > 0 { + actual.nth(i - 1); + expected.nth(i - 1); + } + + let current_iterator_values: Vec> = expected.clone().collect(); + + let actual_value = o.get_value(&mut actual); + let expected_value = o.get_value(&mut expected); + + assert_eq!( + actual_value, + expected_value, + "Failed on op {} for iter that advanced to i {i} (left actual, right expected) ({current_iterator_values:?})", + o.name() + ); + + let left_over_actual: Vec<_> = actual.clone().collect(); + let left_over_expected: Vec<_> = expected.clone().collect(); + + assert_eq!( + left_over_actual, left_over_expected, + "state after mutable should be the same" + ); + } + } + } + + #[derive(Debug, PartialEq)] + struct CallTrackingAndResult { + result: Result, + calls: Vec, + } + type CallTrackingWithInputType = CallTrackingAndResult>; + type CallTrackingOnly = CallTrackingWithInputType<()>; + + #[test] + fn assert_position() { + struct PositionOp { + reverse: bool, + number_of_false: usize, + } + + impl ArrayIteratorMutateOp for PositionOp { + type Output = CallTrackingWithInputType>; + fn name(&self) -> String { + if self.reverse { + format!("rposition with {} false returned", self.number_of_false) + } else { + format!("position with {} false returned", self.number_of_false) + } + } + fn get_value( + &self, + iter: &mut T, + ) -> Self::Output { + let mut items = vec![]; + + let mut count = 0; + + let position_result = if self.reverse { + iter.rposition(|item| { + items.push(item); + + if count < self.number_of_false { + count += 1; + false + } else { + true + } + }) + } else { + iter.position(|item| { + items.push(item); + + 
if count < self.number_of_false { + count += 1; + false + } else { + true + } + }) + }; + + CallTrackingAndResult { + result: position_result, + calls: items, + } + } + } + + for reverse in [false, true] { + for number_of_false in [0, 1, 2, usize::MAX] { + assert_array_iterator_cases_mutate(PositionOp { + reverse, + number_of_false, + }); + } + } + } + + #[test] + fn assert_nth() { + setup_and_assert_cases(NoSetup, |actual, expected| { + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + #[allow(clippy::iter_nth_zero)] + let actual_val = actual.nth(0); + #[allow(clippy::iter_nth_zero)] + let expected_val = expected.nth(0); + assert_eq!(actual_val, expected_val, "Failed on nth(0)"); + } + } + + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth(1); + let expected_val = expected.nth(1); + assert_eq!(actual_val, expected_val, "Failed on nth(1)"); + } + } + + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth(2); + let expected_val = expected.nth(2); + assert_eq!(actual_val, expected_val, "Failed on nth(2)"); + } + } + }); + } + + #[test] + fn assert_nth_back() { + setup_and_assert_cases(NoSetup, |actual, expected| { + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + #[allow(clippy::iter_nth_zero)] + let actual_val = actual.nth_back(0); + #[allow(clippy::iter_nth_zero)] + let expected_val = expected.nth_back(0); + assert_eq!(actual_val, expected_val, "Failed on nth_back(0)"); + } + } + + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth_back(1); + let expected_val = expected.nth_back(1); + assert_eq!(actual_val, expected_val, "Failed on nth_back(1)"); + } + } + + { + let mut actual = actual.clone(); 
+ let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth_back(2); + let expected_val = expected.nth_back(2); + assert_eq!(actual_val, expected_val, "Failed on nth_back(2)"); + } + } + }); + } + + #[test] + fn assert_last() { + for (array, source) in get_int32_iterator_cases() { + let mut actual_forward = ArrayIter::new(&array); + let mut expected_forward = source.iter().copied(); + + for _ in 0..source.len() + 1 { + { + let actual_forward_clone = actual_forward.clone(); + let expected_forward_clone = expected_forward.clone(); + + assert_eq!(actual_forward_clone.last(), expected_forward_clone.last()); + } + + actual_forward.next(); + expected_forward.next(); + } + + let mut actual_backward = ArrayIter::new(&array); + let mut expected_backward = source.iter().copied(); + for _ in 0..source.len() + 1 { + { + assert_eq!( + actual_backward.clone().last(), + expected_backward.clone().last() + ); + } + + actual_backward.next_back(); + expected_backward.next_back(); + } + } + } + + #[test] + fn assert_for_each() { + struct ForEachOp; + + impl ArrayIteratorOp for ForEachOp { + type Output = CallTrackingOnly; + + fn name(&self) -> String { + "for_each".to_string() + } + + fn get_value(&self, iter: T) -> Self::Output { + let mut items = Vec::with_capacity(iter.len()); + + iter.for_each(|item| { + items.push(item); + }); + + CallTrackingAndResult { + calls: items, + result: (), + } + } + } + + assert_array_iterator_cases(ForEachOp) + } + + #[test] + fn assert_fold() { + struct FoldOp { + reverse: bool, + } + + #[derive(Debug, PartialEq)] + struct CallArgs { + acc: Option, + item: Option, + } + + impl ArrayIteratorOp for FoldOp { + type Output = CallTrackingAndResult, CallArgs>; + + fn name(&self) -> String { + if self.reverse { + "rfold".to_string() + } else { + "fold".to_string() + } + } + + fn get_value(&self, iter: T) -> Self::Output { + let mut items = Vec::with_capacity(iter.len()); + + let result = if self.reverse { + 
iter.rfold(Some(1), |acc, item| { + items.push(CallArgs { item, acc }); + + item.map(|val| val + 100) + }) + } else { + #[allow(clippy::manual_try_fold)] + iter.fold(Some(1), |acc, item| { + items.push(CallArgs { item, acc }); + + item.map(|val| val + 100) + }) + }; + + CallTrackingAndResult { + calls: items, + result, + } + } + } + + assert_array_iterator_cases(FoldOp { reverse: false }); + assert_array_iterator_cases(FoldOp { reverse: true }); + } + + #[test] + fn assert_count() { + struct CountOp; + + impl ArrayIteratorOp for CountOp { + type Output = usize; + + fn name(&self) -> String { + "count".to_string() + } + + fn get_value(&self, iter: T) -> Self::Output { + iter.count() + } + } + + assert_array_iterator_cases(CountOp) + } + + #[test] + fn assert_any() { + struct AnyOp { + false_count: usize, + } + + impl ArrayIteratorMutateOp for AnyOp { + type Output = CallTrackingWithInputType; + + fn name(&self) -> String { + format!("any with {} false returned", self.false_count) + } + + fn get_value( + &self, + iter: &mut T, + ) -> Self::Output { + let mut items = Vec::with_capacity(iter.len()); + + let mut count = 0; + let res = iter.any(|item| { + items.push(item); + + if count < self.false_count { + count += 1; + false + } else { + true + } + }); + + CallTrackingWithInputType { + calls: items, + result: res, + } + } + } + + for false_count in [0, 1, 2, usize::MAX] { + assert_array_iterator_cases_mutate(AnyOp { false_count }); + } + } + + #[test] + fn assert_all() { + struct AllOp { + true_count: usize, + } + + impl ArrayIteratorMutateOp for AllOp { + type Output = CallTrackingWithInputType; + + fn name(&self) -> String { + format!("all with {} false returned", self.true_count) + } + + fn get_value( + &self, + iter: &mut T, + ) -> Self::Output { + let mut items = Vec::with_capacity(iter.len()); + + let mut count = 0; + let res = iter.all(|item| { + items.push(item); + + if count < self.true_count { + count += 1; + true + } else { + false + } + }); + + 
CallTrackingWithInputType { + calls: items, + result: res, + } + } + } + + for true_count in [0, 1, 2, usize::MAX] { + assert_array_iterator_cases_mutate(AllOp { true_count }); + } + } + + #[test] + fn assert_find() { + struct FindOp { + reverse: bool, + false_count: usize, + } + + impl ArrayIteratorMutateOp for FindOp { + type Output = CallTrackingWithInputType>>; + + fn name(&self) -> String { + if self.reverse { + format!("rfind with {} false returned", self.false_count) + } else { + format!("find with {} false returned", self.false_count) + } + } + + fn get_value( + &self, + iter: &mut T, + ) -> Self::Output { + let mut items = vec![]; + + let mut count = 0; + + let position_result = if self.reverse { + iter.rfind(|item| { + items.push(*item); + + if count < self.false_count { + count += 1; + false + } else { + true + } + }) + } else { + iter.find(|item| { + items.push(*item); + + if count < self.false_count { + count += 1; + false + } else { + true + } + }) + }; + + CallTrackingWithInputType { + calls: items, + result: position_result, + } + } + } + + for reverse in [false, true] { + for false_count in [0, 1, 2, usize::MAX] { + assert_array_iterator_cases_mutate(FindOp { + reverse, + false_count, + }); + } + } + } + + #[test] + fn assert_find_map() { + struct FindMapOp { + number_of_nones: usize, + } + + impl ArrayIteratorMutateOp for FindMapOp { + type Output = CallTrackingWithInputType>; + + fn name(&self) -> String { + format!("find_map with {} None returned", self.number_of_nones) + } + + fn get_value( + &self, + iter: &mut T, + ) -> Self::Output { + let mut items = vec![]; + + let mut count = 0; + + let result = iter.find_map(|item| { + items.push(item); + + if count < self.number_of_nones { + count += 1; + None + } else { + Some("found it") + } + }); + + CallTrackingAndResult { + result, + calls: items, + } + } + } + + for number_of_nones in [0, 1, 2, usize::MAX] { + assert_array_iterator_cases_mutate(FindMapOp { number_of_nones }); + } + } + + #[test] + 
fn assert_partition() { + struct PartitionOp) -> bool> { + description: &'static str, + predicate: F, + } + + #[derive(Debug, PartialEq)] + struct PartitionResult { + left: Vec>, + right: Vec>, + } + + impl) -> bool> ArrayIteratorOp for PartitionOp { + type Output = CallTrackingWithInputType; + + fn name(&self) -> String { + format!("partition by {}", self.description) + } + + fn get_value(&self, iter: T) -> Self::Output { + let mut items = vec![]; + + let mut index = 0; + + let (left, right) = iter.partition(|item| { + items.push(*item); + + let res = (self.predicate)(index, item); + + index += 1; + res + }); + + CallTrackingAndResult { + result: PartitionResult { left, right }, + calls: items, + } + } + } + + assert_array_iterator_cases(PartitionOp { + description: "None on one side and Some(*) on the other", + predicate: |_, item| item.is_none(), + }); + + assert_array_iterator_cases(PartitionOp { + description: "all true", + predicate: |_, _| true, + }); + + assert_array_iterator_cases(PartitionOp { + description: "all false", + predicate: |_, _| false, + }); + + let random_values = (0..100).map(|_| rand::random_bool(0.5)).collect::>(); + assert_array_iterator_cases(PartitionOp { + description: "random", + predicate: |index, _| random_values[index % random_values.len()], + }); + } } From ae5e9618619a91d51a03ed51571fab6e912a49ca Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 28 Oct 2025 00:19:05 +0200 Subject: [PATCH 19/19] cleanup tests --- arrow-array/src/iterator.rs | 101 +++++++++++++++++++----------------- 1 file changed, 53 insertions(+), 48 deletions(-) diff --git a/arrow-array/src/iterator.rs b/arrow-array/src/iterator.rs index c1026c3ad561..f96d0158768e 100644 --- a/arrow-array/src/iterator.rs +++ b/arrow-array/src/iterator.rs @@ -391,38 +391,42 @@ mod tests { } } - /// Trait representing an operation on a BitIterator + /// Trait representing an operation on a [`ArrayIter`] /// that can be compared 
against a slice iterator - trait ArrayIteratorOp { - /// What the operation returns (e.g. Option for last/max, usize for count, etc) + /// + /// this is for consuming operations (e.g. `count`, `last`, etc) + trait ConsumingArrayIteratorOp { + /// What the operation returns (e.g. Option for last, usize for count, etc) type Output: PartialEq + Debug; /// The name of the operation, used for error messages fn name(&self) -> String; /// Get the value of the operation for the provided iterator - /// This will be either a BitIterator or a slice iterator to make sure they produce the same result + /// This will be either a [`ArrayIter`] or a slice iterator to make sure they produce the same result fn get_value(&self, iter: T) -> Self::Output; } - /// Trait representing an operation on a BitIterator - /// that can be compared against a slice iterator - trait ArrayIteratorMutateOp { - /// What the operation returns (e.g. Option for last/max, usize for count, etc) + /// Trait representing an operation on a [`ArrayIter`] + /// that can be compared against a slice iterator. + /// + /// This is for mutating operations (e.g. `position`, `any`, `find`, etc) + trait MutatingArrayIteratorOp { + /// What the operation returns (e.g. Option for last, usize for count, etc) type Output: PartialEq + Debug; /// The name of the operation, used for error messages fn name(&self) -> String; /// Get the value of the operation for the provided iterator - /// This will be either a BitIterator or a slice iterator to make sure they produce the same result + /// This will be either a [`ArrayIter`] or a slice iterator to make sure they produce the same result fn get_value(&self, iter: &mut T) -> Self::Output; } /// Helper function that will assert that the provided operation - /// produces the same result for both BitIterator and slice iterator + /// produces the same result for both [`ArrayIter`] and slice iterator /// under various consumption patterns (e.g. 
some calls to next/next_back/consume_all/etc) - fn assert_array_iterator_cases(o: O) { + fn assert_array_iterator_cases(o: O) { setup_and_assert_cases(NoSetup, |actual, expected| { let current_iterator_values: Vec> = expected.clone().collect(); assert_eq!( @@ -607,42 +611,43 @@ mod tests { } /// Helper function that will assert that the provided operation - /// produces the same result for both BitIterator and slice iterator + /// produces the same result for both [`ArrayIter`] and slice iterator /// under various consumption patterns (e.g. some calls to next/next_back/consume_all/etc) - fn assert_array_iterator_cases_mutate(o: O) { - for (array, source) in get_int32_iterator_cases() { - for i in 0..source.len() { - let mut actual = ArrayIter::new(&array); - let mut expected = source.iter().copied(); - - // calling nth(0) is the same as calling next() - // but we want to get to the ith position so we call nth(i - 1) - if i > 0 { - actual.nth(i - 1); - expected.nth(i - 1); - } + /// + /// this is different from [`assert_array_iterator_cases`] as this also check that the state after the call is correct + /// to make sure we don't leave the iterator in incorrect state + fn assert_array_iterator_cases_mutate(o: O) { + struct Adapter { + o: O, + } - let current_iterator_values: Vec> = expected.clone().collect(); + #[derive(Debug, PartialEq)] + struct AdapterOutput { + value: Value, + /// collect on the iterator after running the operation + leftover: Vec>, + } - let actual_value = o.get_value(&mut actual); - let expected_value = o.get_value(&mut expected); + impl ConsumingArrayIteratorOp for Adapter { + type Output = AdapterOutput; - assert_eq!( - actual_value, - expected_value, - "Failed on op {} for iter that advanced to i {i} (left actual, right expected) ({current_iterator_values:?})", - o.name() - ); + fn name(&self) -> String { + self.o.name() + } + + fn get_value( + &self, + mut iter: T, + ) -> Self::Output { + let value = self.o.get_value(&mut iter); - let 
left_over_actual: Vec<_> = actual.clone().collect(); - let left_over_expected: Vec<_> = expected.clone().collect(); + let leftover: Vec<_> = iter.collect(); - assert_eq!( - left_over_actual, left_over_expected, - "state after mutable should be the same" - ); + AdapterOutput { value, leftover } } } + + assert_array_iterator_cases(Adapter { o }) } #[derive(Debug, PartialEq)] @@ -660,7 +665,7 @@ mod tests { number_of_false: usize, } - impl ArrayIteratorMutateOp for PositionOp { + impl MutatingArrayIteratorOp for PositionOp { type Output = CallTrackingWithInputType>; fn name(&self) -> String { if self.reverse { @@ -830,7 +835,7 @@ mod tests { fn assert_for_each() { struct ForEachOp; - impl ArrayIteratorOp for ForEachOp { + impl ConsumingArrayIteratorOp for ForEachOp { type Output = CallTrackingOnly; fn name(&self) -> String { @@ -866,7 +871,7 @@ mod tests { item: Option, } - impl ArrayIteratorOp for FoldOp { + impl ConsumingArrayIteratorOp for FoldOp { type Output = CallTrackingAndResult, CallArgs>; fn name(&self) -> String { @@ -910,7 +915,7 @@ mod tests { fn assert_count() { struct CountOp; - impl ArrayIteratorOp for CountOp { + impl ConsumingArrayIteratorOp for CountOp { type Output = usize; fn name(&self) -> String { @@ -931,7 +936,7 @@ mod tests { false_count: usize, } - impl ArrayIteratorMutateOp for AnyOp { + impl MutatingArrayIteratorOp for AnyOp { type Output = CallTrackingWithInputType; fn name(&self) -> String { @@ -974,7 +979,7 @@ mod tests { true_count: usize, } - impl ArrayIteratorMutateOp for AllOp { + impl MutatingArrayIteratorOp for AllOp { type Output = CallTrackingWithInputType; fn name(&self) -> String { @@ -1018,7 +1023,7 @@ mod tests { false_count: usize, } - impl ArrayIteratorMutateOp for FindOp { + impl MutatingArrayIteratorOp for FindOp { type Output = CallTrackingWithInputType>>; fn name(&self) -> String { @@ -1084,7 +1089,7 @@ mod tests { number_of_nones: usize, } - impl ArrayIteratorMutateOp for FindMapOp { + impl MutatingArrayIteratorOp for 
FindMapOp { type Output = CallTrackingWithInputType>; fn name(&self) -> String { @@ -1135,7 +1140,7 @@ mod tests { right: Vec>, } - impl) -> bool> ArrayIteratorOp for PartitionOp { + impl) -> bool> ConsumingArrayIteratorOp for PartitionOp { type Output = CallTrackingWithInputType; fn name(&self) -> String {