Skip to content

Commit 573c6f8

Browse files
committed
Merge branch 'main' into issue-9645-ree-optimization
2 parents ffd5849 + 65ad652 commit 573c6f8

File tree

17 files changed

+2295
-497
lines changed

17 files changed

+2295
-497
lines changed

arrow-arith/src/aggregate.rs

Lines changed: 18 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
use arrow_array::cast::*;
2121
use arrow_array::iterator::ArrayIter;
2222
use arrow_array::*;
23-
use arrow_buffer::{ArrowNativeType, NullBuffer};
23+
use arrow_buffer::NullBuffer;
2424
use arrow_data::bit_iterator::try_for_each_valid_idx;
2525
use arrow_schema::*;
2626
use std::borrow::BorrowMut;
@@ -541,11 +541,9 @@ pub fn min_string_view(array: &StringViewArray) -> Option<&str> {
541541
///
542542
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
543543
/// For an overflow-checking variant, use [`sum_array_checked`] instead.
544-
pub fn sum_array<T, A: ArrayAccessor<Item = T::Native>>(array: A) -> Option<T::Native>
545-
where
546-
T: ArrowNumericType,
547-
T::Native: ArrowNativeTypeOp,
548-
{
544+
pub fn sum_array<T: ArrowNumericType, A: ArrayAccessor<Item = T::Native>>(
545+
array: A,
546+
) -> Option<T::Native> {
549547
match array.data_type() {
550548
DataType::Dictionary(_, _) => {
551549
let null_count = array.null_count();
@@ -583,13 +581,9 @@ where
583581
/// use [`sum_array`] instead.
584582
/// Additionally returns an `Err` on run-end-encoded arrays with a provided
585583
/// values type parameter that is incorrect.
586-
pub fn sum_array_checked<T, A: ArrayAccessor<Item = T::Native>>(
584+
pub fn sum_array_checked<T: ArrowNumericType, A: ArrayAccessor<Item = T::Native>>(
587585
array: A,
588-
) -> Result<Option<T::Native>, ArrowError>
589-
where
590-
T: ArrowNumericType,
591-
T::Native: ArrowNativeTypeOp,
592-
{
586+
) -> Result<Option<T::Native>, ArrowError> {
593587
match array.data_type() {
594588
DataType::Dictionary(_, _) => {
595589
let null_count = array.null_count();
@@ -717,21 +711,17 @@ mod ree {
717711

718712
/// Returns the min of values in the array of `ArrowNumericType` type, or dictionary
719713
/// array with value of `ArrowNumericType` type.
720-
pub fn min_array<T, A: ArrayAccessor<Item = T::Native>>(array: A) -> Option<T::Native>
721-
where
722-
T: ArrowNumericType,
723-
T::Native: ArrowNativeType,
724-
{
714+
pub fn min_array<T: ArrowNumericType, A: ArrayAccessor<Item = T::Native>>(
715+
array: A,
716+
) -> Option<T::Native> {
725717
min_max_array_helper::<T, A, _, _>(array, |a, b| a.is_gt(*b), min)
726718
}
727719

728720
/// Returns the max of values in the array of `ArrowNumericType` type, or dictionary
729721
/// array with value of `ArrowNumericType` type.
730-
pub fn max_array<T, A: ArrayAccessor<Item = T::Native>>(array: A) -> Option<T::Native>
731-
where
732-
T: ArrowNumericType,
733-
T::Native: ArrowNativeTypeOp,
734-
{
722+
pub fn max_array<T: ArrowNumericType, A: ArrayAccessor<Item = T::Native>>(
723+
array: A,
724+
) -> Option<T::Native> {
735725
min_max_array_helper::<T, A, _, _>(array, |a, b| a.is_lt(*b), max)
736726
}
737727

@@ -874,11 +864,9 @@ pub fn bool_or(array: &BooleanArray) -> Option<bool> {
874864
///
875865
/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant,
876866
/// use [`sum`] instead.
877-
pub fn sum_checked<T>(array: &PrimitiveArray<T>) -> Result<Option<T::Native>, ArrowError>
878-
where
879-
T: ArrowNumericType,
880-
T::Native: ArrowNativeTypeOp,
881-
{
867+
pub fn sum_checked<T: ArrowNumericType>(
868+
array: &PrimitiveArray<T>,
869+
) -> Result<Option<T::Native>, ArrowError> {
882870
let null_count = array.null_count();
883871

884872
if null_count == array.len() {
@@ -922,10 +910,7 @@ where
922910
///
923911
/// This doesn't detect overflow in release mode by default. Once overflowing, the result will
924912
/// wrap around. For an overflow-checking variant, use [`sum_checked`] instead.
925-
pub fn sum<T: ArrowNumericType>(array: &PrimitiveArray<T>) -> Option<T::Native>
926-
where
927-
T::Native: ArrowNativeTypeOp,
928-
{
913+
pub fn sum<T: ArrowNumericType>(array: &PrimitiveArray<T>) -> Option<T::Native> {
929914
aggregate::<T::Native, T, SumAccumulator<T::Native>>(array)
930915
}
931916

@@ -940,10 +925,7 @@ where
940925
/// let result = min(&array);
941926
/// assert_eq!(result, Some(2));
942927
/// ```
943-
pub fn min<T: ArrowNumericType>(array: &PrimitiveArray<T>) -> Option<T::Native>
944-
where
945-
T::Native: PartialOrd,
946-
{
928+
pub fn min<T: ArrowNumericType>(array: &PrimitiveArray<T>) -> Option<T::Native> {
947929
aggregate::<T::Native, T, MinAccumulator<T::Native>>(array)
948930
}
949931

@@ -958,10 +940,7 @@ where
958940
/// let result = max(&array);
959941
/// assert_eq!(result, Some(8));
960942
/// ```
961-
pub fn max<T: ArrowNumericType>(array: &PrimitiveArray<T>) -> Option<T::Native>
962-
where
963-
T::Native: PartialOrd,
964-
{
943+
pub fn max<T: ArrowNumericType>(array: &PrimitiveArray<T>) -> Option<T::Native> {
965944
aggregate::<T::Native, T, MaxAccumulator<T::Native>>(array)
966945
}
967946

arrow-arith/src/arithmetic.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -170,11 +170,7 @@ pub fn multiply_fixed_point(
170170
}
171171

172172
/// Divide a decimal native value by given divisor and round the result.
173-
fn divide_and_round<I>(input: I::Native, div: I::Native) -> I::Native
174-
where
175-
I: DecimalType,
176-
I::Native: ArrowNativeTypeOp,
177-
{
173+
fn divide_and_round<I: DecimalType>(input: I::Native, div: I::Native) -> I::Native {
178174
let d = input.div_wrapping(div);
179175
let r = input.mod_wrapping(div);
180176

arrow-array/src/array/byte_view_array.rs

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,37 @@ impl<T: ByteViewType + ?Sized> GenericByteViewArray<T> {
670670
}
671671
}
672672

673+
/// Returns the total number of bytes of all non-null values in this array.
674+
///
675+
/// Unlike [`Self::total_buffer_bytes_used`], this method includes inlined strings
676+
/// (those with length ≤ [`MAX_INLINE_VIEW_LEN`]), making it suitable as a
677+
/// capacity hint when pre-allocating output buffers.
678+
///
679+
/// Null values are excluded from the sum.
680+
///
681+
/// # Example
682+
///
683+
/// ```rust
684+
/// # use arrow_array::StringViewArray;
685+
/// let array = StringViewArray::from_iter(vec![
686+
/// Some("hello"), // 5 bytes, inlined
687+
/// None, // excluded
688+
/// Some("large payload over 12 bytes"), // 27 bytes, non-inlined
689+
/// ]);
690+
/// assert_eq!(array.total_bytes_len(), 5 + 27);
691+
/// ```
692+
pub fn total_bytes_len(&self) -> usize {
693+
match self.nulls() {
694+
None => self.views().iter().map(|v| (*v as u32) as usize).sum(),
695+
Some(nulls) => self
696+
.views()
697+
.iter()
698+
.zip(nulls.iter())
699+
.map(|(v, is_valid)| if is_valid { (*v as u32) as usize } else { 0 })
700+
.sum(),
701+
}
702+
}
703+
673704
/// Returns the total number of bytes used by all non inlined views in all
674705
/// buffers.
675706
///
@@ -1809,4 +1840,41 @@ mod tests {
18091840
assert!(from_utf8(array.value(2)).is_ok());
18101841
array
18111842
}
1843+
1844+
#[test]
1845+
fn test_total_bytes_len() {
1846+
// inlined: "hello"=5, "world"=5, "lulu"=4 → 14
1847+
// non-inlined: "large payload over 12 bytes"=27
1848+
// null: should not count
1849+
let mut builder = StringViewBuilder::new();
1850+
builder.append_value("hello");
1851+
builder.append_value("world");
1852+
builder.append_value("lulu");
1853+
builder.append_null();
1854+
builder.append_value("large payload over 12 bytes");
1855+
let array = builder.finish();
1856+
assert_eq!(array.total_bytes_len(), 5 + 5 + 4 + 27);
1857+
}
1858+
1859+
#[test]
1860+
fn test_total_bytes_len_empty() {
1861+
let array = StringViewArray::from_iter::<Vec<Option<&str>>>(vec![]);
1862+
assert_eq!(array.total_bytes_len(), 0);
1863+
}
1864+
1865+
#[test]
1866+
fn test_total_bytes_len_all_nulls() {
1867+
let array = StringViewArray::new_null(5);
1868+
assert_eq!(array.total_bytes_len(), 0);
1869+
}
1870+
1871+
#[test]
1872+
fn test_total_bytes_len_binary_view() {
1873+
let array = BinaryViewArray::from_iter(vec![
1874+
Some(b"hi".as_ref()),
1875+
None,
1876+
Some(b"large payload over 12 bytes".as_ref()),
1877+
]);
1878+
assert_eq!(array.total_bytes_len(), 2 + 27);
1879+
}
18121880
}

arrow-cast/src/cast/decimal.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -812,7 +812,7 @@ where
812812
T: ArrowPrimitiveType,
813813
<T as ArrowPrimitiveType>::Native: NumCast,
814814
D: DecimalType + ArrowPrimitiveType,
815-
<D as ArrowPrimitiveType>::Native: ArrowNativeTypeOp + ToPrimitive,
815+
<D as ArrowPrimitiveType>::Native: ToPrimitive,
816816
{
817817
let array = array.as_primitive::<D>();
818818

arrow-cast/src/cast/mod.rs

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,15 @@ mod list;
4343
mod map;
4444
mod run_array;
4545
mod string;
46+
mod union;
4647

4748
use crate::cast::decimal::*;
4849
use crate::cast::dictionary::*;
4950
use crate::cast::list::*;
5051
use crate::cast::map::*;
5152
use crate::cast::run_array::*;
5253
use crate::cast::string::*;
54+
pub use crate::cast::union::*;
5355

5456
use arrow_buffer::IntervalMonthDayNano;
5557
use arrow_data::ByteView;
@@ -71,6 +73,7 @@ use arrow_select::take::take;
7173
use num_traits::{NumCast, ToPrimitive, cast::AsPrimitive};
7274

7375
pub use decimal::{DecimalCast, rescale_decimal};
76+
pub use string::cast_single_string_to_boolean_default;
7477

7578
/// CastOptions provides a way to override the default cast behaviors
7679
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
@@ -108,6 +111,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
108111
can_cast_types(from_value_type, to_value_type)
109112
}
110113
(Dictionary(_, value_type), _) => can_cast_types(value_type, to_type),
114+
(Union(fields, _), _) => union::resolve_child_array(fields, to_type).is_some(),
115+
(_, Union(_, _)) => false,
111116
(RunEndEncoded(_, value_type), _) => can_cast_types(value_type.data_type(), to_type),
112117
(_, RunEndEncoded(_, value_type)) => can_cast_types(from_type, value_type.data_type()),
113118
(_, Dictionary(_, value_type)) => can_cast_types(from_type, value_type),
@@ -230,7 +235,6 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
230235
}
231236
(Struct(_), _) => false,
232237
(_, Struct(_)) => false,
233-
234238
(_, Boolean) => from_type.is_integer() || from_type.is_floating() || from_type.is_string(),
235239
(Boolean, _) => to_type.is_integer() || to_type.is_floating() || to_type.is_string(),
236240

@@ -781,6 +785,14 @@ pub fn cast_with_options(
781785
))),
782786
}
783787
}
788+
(Union(_, _), _) => union_extract_by_type(
789+
array.as_any().downcast_ref::<UnionArray>().unwrap(),
790+
to_type,
791+
cast_options,
792+
),
793+
(_, Union(_, _)) => Err(ArrowError::CastError(format!(
794+
"Casting from {from_type} to {to_type} not supported"
795+
))),
784796
(Dictionary(index_type, _), _) => match **index_type {
785797
Int8 => dictionary_cast::<Int8Type>(array, to_type, cast_options),
786798
Int16 => dictionary_cast::<Int16Type>(array, to_type, cast_options),
@@ -2287,7 +2299,7 @@ fn cast_from_decimal<D, F>(
22872299
) -> Result<ArrayRef, ArrowError>
22882300
where
22892301
D: DecimalType + ArrowPrimitiveType,
2290-
<D as ArrowPrimitiveType>::Native: ArrowNativeTypeOp + ToPrimitive,
2302+
<D as ArrowPrimitiveType>::Native: ToPrimitive,
22912303
F: Fn(D::Native) -> f64,
22922304
{
22932305
use DataType::*;
@@ -2464,7 +2476,7 @@ where
24642476
R::Native: NumCast,
24652477
{
24662478
from.try_unary(|value| {
2467-
num_traits::cast::cast::<T::Native, R::Native>(value).ok_or_else(|| {
2479+
num_cast::<T::Native, R::Native>(value).ok_or_else(|| {
24682480
ArrowError::CastError(format!(
24692481
"Can't cast value {:?} to type {}",
24702482
value,
@@ -2474,6 +2486,17 @@ where
24742486
})
24752487
}
24762488

2489+
/// Natural cast between numeric types
2490+
/// Return None if the input `value` can't be casted to type `O`.
2491+
#[inline]
2492+
pub fn num_cast<I, O>(value: I) -> Option<O>
2493+
where
2494+
I: NumCast,
2495+
O: NumCast,
2496+
{
2497+
num_traits::cast::cast::<I, O>(value)
2498+
}
2499+
24772500
// Natural cast between numeric types
24782501
// If the value of T can't be casted to R, it will be converted to null
24792502
fn numeric_cast<T, R>(from: &PrimitiveArray<T>) -> PrimitiveArray<R>
@@ -2483,7 +2506,7 @@ where
24832506
T::Native: NumCast,
24842507
R::Native: NumCast,
24852508
{
2486-
from.unary_opt::<_, R>(num_traits::cast::cast::<T::Native, R::Native>)
2509+
from.unary_opt::<_, R>(num_cast::<T::Native, R::Native>)
24872510
}
24882511

24892512
fn cast_numeric_to_binary<FROM: ArrowPrimitiveType, O: OffsetSizeTrait>(
@@ -2540,16 +2563,23 @@ where
25402563
for i in 0..from.len() {
25412564
if from.is_null(i) {
25422565
b.append_null();
2543-
} else if from.value(i) != T::default_value() {
2544-
b.append_value(true);
25452566
} else {
2546-
b.append_value(false);
2567+
b.append_value(cast_num_to_bool::<T::Native>(from.value(i)));
25472568
}
25482569
}
25492570

25502571
Ok(b.finish())
25512572
}
25522573

2574+
/// Cast numeric types to boolean
2575+
#[inline]
2576+
pub fn cast_num_to_bool<I>(value: I) -> bool
2577+
where
2578+
I: Default + PartialEq,
2579+
{
2580+
value != I::default()
2581+
}
2582+
25532583
/// Cast Boolean types to numeric
25542584
///
25552585
/// `false` returns 0 while `true` returns 1
@@ -2575,11 +2605,8 @@ where
25752605
let iter = (0..from.len()).map(|i| {
25762606
if from.is_null(i) {
25772607
None
2578-
} else if from.value(i) {
2579-
// a workaround to cast a primitive to T::Native, infallible
2580-
num_traits::cast::cast(1)
25812608
} else {
2582-
Some(T::default_value())
2609+
single_bool_to_numeric::<T::Native>(from.value(i))
25832610
}
25842611
});
25852612
// Benefit:
@@ -2589,6 +2616,20 @@ where
25892616
unsafe { PrimitiveArray::<T>::from_trusted_len_iter(iter) }
25902617
}
25912618

2619+
/// Cast single bool value to numeric value.
2620+
#[inline]
2621+
pub fn single_bool_to_numeric<O>(value: bool) -> Option<O>
2622+
where
2623+
O: num_traits::NumCast + Default,
2624+
{
2625+
if value {
2626+
// a workaround to cast a primitive to type O, infallible
2627+
num_traits::cast::cast(1)
2628+
} else {
2629+
Some(O::default())
2630+
}
2631+
}
2632+
25922633
/// Helper function to cast from one `BinaryArray` or 'LargeBinaryArray' to 'FixedSizeBinaryArray'.
25932634
fn cast_binary_to_fixed_size_binary<O: OffsetSizeTrait>(
25942635
array: &dyn Array,

0 commit comments

Comments
 (0)