diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs index 28c65c5994bf..f25b51f3655d 100644 --- a/arrow-row/src/lib.rs +++ b/arrow-row/src/lib.rs @@ -161,6 +161,8 @@ #![warn(missing_docs)] use std::cmp::Ordering; use std::hash::{Hash, Hasher}; +use std::iter::Map; +use std::slice::Windows; use std::sync::Arc; use arrow_array::cast::*; @@ -1118,6 +1120,9 @@ pub struct Rows { config: RowConfig, } +/// The iterator type for [`Rows::lengths`] +pub type RowLengthIter<'a> = Map, fn(&'a [usize]) -> usize>; + impl Rows { /// Append a [`Row`] to this [`Rows`] pub fn push(&mut self, row: Row<'_>) { @@ -1156,6 +1161,19 @@ impl Rows { } } + /// Returns the number of bytes the row at index `row` is occupying, + /// that is, what is the length of the returned [`Row::data`] will be. + pub fn row_len(&self, row: usize) -> usize { + assert!(row + 1 < self.offsets.len()); + + self.offsets[row + 1] - self.offsets[row] + } + + /// Get an iterator over the lengths of each row in this [`Rows`] + pub fn lengths(&self) -> RowLengthIter<'_> { + self.offsets.windows(2).map(|w| w[1] - w[0]) + } + /// Sets the length of this [`Rows`] to 0 pub fn clear(&mut self) { self.offsets.truncate(1); @@ -1579,7 +1597,7 @@ fn row_lengths(cols: &[ArrayRef], encoders: &[Encoder]) -> LengthTracker { array => { tracker.push_variable( array.keys().iter().map(|v| match v { - Some(k) => values.row(k.as_usize()).data.len(), + Some(k) => values.row_len(k.as_usize()), None => null.data.len(), }) ) @@ -1590,7 +1608,7 @@ fn row_lengths(cols: &[ArrayRef], encoders: &[Encoder]) -> LengthTracker { Encoder::Struct(rows, null) => { let array = as_struct_array(array); tracker.push_variable((0..array.len()).map(|idx| match array.is_valid(idx) { - true => 1 + rows.row(idx).as_ref().len(), + true => 1 + rows.row_len(idx), false => 1 + null.data.len(), })); } @@ -1642,10 +1660,10 @@ fn row_lengths(cols: &[ArrayRef], encoders: &[Encoder]) -> LengthTracker { let lengths = (0..union_array.len()).map(|i| { let type_id = type_ids[i]; let child_row_i = offsets.as_ref().map(|o| o[i] as usize).unwrap_or(i); - let child_row = child_rows[type_id as usize].row(child_row_i); + let child_row_len = child_rows[type_id as usize].row_len(child_row_i); // length: 1 byte type_id + child row bytes - 1 + child_row.as_ref().len() + 1 + child_row_len }); tracker.push_variable(lengths); @@ -3691,6 +3709,38 @@ mod tests { } } + // Validate rows length iterator + { + let mut rows_iter = rows.iter(); + let mut rows_lengths_iter = rows.lengths(); + for (index, row) in rows_iter.by_ref().enumerate() { + let len = rows_lengths_iter + .next() + .expect("Reached end of length iterator while still have rows"); + assert_eq!( + row.data.len(), + len, + "Row length mismatch: {} vs {}", + row.data.len(), + len + ); + assert_eq!( + len, + rows.row_len(index), + "Row length mismatch at index {}: {} vs {}", + index, + len, + rows.row_len(index) + ); + } + + assert_eq!( + rows_lengths_iter.next(), + None, + "Length iterator did not reach end" + ); + } + // Convert rows produced from convert_columns(). // Note: validate_utf8 is set to false since Row is initialized through empty_rows() let back = converter.convert_rows(&rows).unwrap(); @@ -4343,4 +4393,13 @@ mod tests { "Size should increase when reserving more space than previously reserved" ); } + + #[test] + fn empty_rows_should_return_empty_lengths_iterator() { + let rows = RowConverter::new(vec![SortField::new(DataType::UInt8)]) + .unwrap() + .empty_rows(0, 0); + let mut lengths_iter = rows.lengths(); + assert_eq!(lengths_iter.next(), None); + } } diff --git a/arrow-row/src/list.rs b/arrow-row/src/list.rs index e04aa70c528f..6e552b0a93b9 100644 --- a/arrow-row/src/list.rs +++ b/arrow-row/src/list.rs @@ -27,32 +27,32 @@ pub fn compute_lengths( rows: &Rows, array: &GenericListArray, ) { - let shift = array.value_offsets()[0].as_usize(); - let offsets = array.value_offsets().windows(2); + let mut rows_length_iter = rows.lengths(); + lengths .iter_mut() .zip(offsets) .enumerate() .for_each(|(idx, (length, offsets))| { - let start = offsets[0].as_usize() - shift; - let end = offsets[1].as_usize() - shift; - let range = array.is_valid(idx).then_some(start..end); - *length += encoded_len(rows, range); + let len = offsets[1].as_usize() - offsets[0].as_usize(); + if array.is_valid(idx) { + *length += 1 + rows_length_iter + .by_ref() + .take(len) + .map(Some) + .map(super::variable::padded_length) + .sum::() + } else { + // Advance rows iterator by len + if len > 0 { + rows_length_iter.nth(len - 1); + } + *length += 1; + } }); } -fn encoded_len(rows: &Rows, range: Option>) -> usize { - match range { - None => 1, - Some(range) => { - 1 + range - .map(|i| super::variable::padded_length(Some(rows.row(i).as_ref().len()))) - .sum::() - } - } -} - /// Encodes the provided `GenericListArray` to `out` with the provided `SortOptions` /// /// `rows` should contain the encoded child elements diff --git a/arrow-row/src/run.rs b/arrow-row/src/run.rs index 24eaaa18e018..e12fa87dce4b 100644 --- a/arrow-row/src/run.rs +++ b/arrow-row/src/run.rs @@ -33,8 +33,8 @@ pub fn compute_lengths( // Iterate over each run and apply the same length to all logical positions in the run for (physical_idx, &run_end) in run_ends.iter().enumerate() { let logical_end = run_end.as_usize(); - let row = rows.row(physical_idx); - let encoded_len = variable::encoded_len(Some(row.data)); + let row_len = rows.row_len(physical_idx); + let encoded_len = variable::padded_length(Some(row_len)); // Add the same length for all logical positions in this run for length in &mut lengths[logical_start..logical_end] {