Skip to content

Commit 989ac03

Browse files
authored
[Variant] Align cast logic for variant_get to cast kernel for numeric/bool types (#9563)
# Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. --> - Closes #9564 . # What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> Align tests with cast kernel # Are these changes tested? Covered by the existing tests <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> # Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. If there are any breaking changes to public APIs, please call them out. --> Yes, changed the logic for `Variant::asboolean/as_int8/as_int16/as_int32/as_int64/as_u8/as_u16/as_u32/as_u64/as_f16/as_f32/as_f64
1 parent 0936b38 commit 989ac03

File tree

7 files changed

+289
-228
lines changed

7 files changed

+289
-228
lines changed

arrow-cast/src/cast/mod.rs

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ use arrow_select::take::take;
7373
use num_traits::{NumCast, ToPrimitive, cast::AsPrimitive};
7474

7575
pub use decimal::{DecimalCast, rescale_decimal};
76+
pub use string::cast_single_string_to_boolean_default;
7677

7778
/// CastOptions provides a way to override the default cast behaviors
7879
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
@@ -2475,7 +2476,7 @@ where
24752476
R::Native: NumCast,
24762477
{
24772478
from.try_unary(|value| {
2478-
num_traits::cast::cast::<T::Native, R::Native>(value).ok_or_else(|| {
2479+
num_cast::<T::Native, R::Native>(value).ok_or_else(|| {
24792480
ArrowError::CastError(format!(
24802481
"Can't cast value {:?} to type {}",
24812482
value,
@@ -2485,6 +2486,17 @@ where
24852486
})
24862487
}
24872488

2489+
/// Natural cast between numeric types
2490+
/// Return None if the input `value` can't be casted to type `O`.
2491+
#[inline]
2492+
pub fn num_cast<I, O>(value: I) -> Option<O>
2493+
where
2494+
I: NumCast,
2495+
O: NumCast,
2496+
{
2497+
num_traits::cast::cast::<I, O>(value)
2498+
}
2499+
24882500
// Natural cast between numeric types
24892501
// If the value of T can't be casted to R, it will be converted to null
24902502
fn numeric_cast<T, R>(from: &PrimitiveArray<T>) -> PrimitiveArray<R>
@@ -2494,7 +2506,7 @@ where
24942506
T::Native: NumCast,
24952507
R::Native: NumCast,
24962508
{
2497-
from.unary_opt::<_, R>(num_traits::cast::cast::<T::Native, R::Native>)
2509+
from.unary_opt::<_, R>(num_cast::<T::Native, R::Native>)
24982510
}
24992511

25002512
fn cast_numeric_to_binary<FROM: ArrowPrimitiveType, O: OffsetSizeTrait>(
@@ -2551,16 +2563,23 @@ where
25512563
for i in 0..from.len() {
25522564
if from.is_null(i) {
25532565
b.append_null();
2554-
} else if from.value(i) != T::default_value() {
2555-
b.append_value(true);
25562566
} else {
2557-
b.append_value(false);
2567+
b.append_value(cast_num_to_bool::<T::Native>(from.value(i)));
25582568
}
25592569
}
25602570

25612571
Ok(b.finish())
25622572
}
25632573

2574+
/// Cast numeric types to boolean
2575+
#[inline]
2576+
pub fn cast_num_to_bool<I>(value: I) -> bool
2577+
where
2578+
I: Default + PartialEq,
2579+
{
2580+
value != I::default()
2581+
}
2582+
25642583
/// Cast Boolean types to numeric
25652584
///
25662585
/// `false` returns 0 while `true` returns 1
@@ -2586,11 +2605,8 @@ where
25862605
let iter = (0..from.len()).map(|i| {
25872606
if from.is_null(i) {
25882607
None
2589-
} else if from.value(i) {
2590-
// a workaround to cast a primitive to T::Native, infallible
2591-
num_traits::cast::cast(1)
25922608
} else {
2593-
Some(T::default_value())
2609+
single_bool_to_numeric::<T::Native>(from.value(i))
25942610
}
25952611
});
25962612
// Benefit:
@@ -2600,6 +2616,20 @@ where
26002616
unsafe { PrimitiveArray::<T>::from_trusted_len_iter(iter) }
26012617
}
26022618

2619+
/// Cast single bool value to numeric value.
2620+
#[inline]
2621+
pub fn single_bool_to_numeric<O>(value: bool) -> Option<O>
2622+
where
2623+
O: num_traits::NumCast + Default,
2624+
{
2625+
if value {
2626+
// a workaround to cast a primitive to type O, infallible
2627+
num_traits::cast::cast(1)
2628+
} else {
2629+
Some(O::default())
2630+
}
2631+
}
2632+
26032633
/// Helper function to cast from one `BinaryArray` or 'LargeBinaryArray' to 'FixedSizeBinaryArray'.
26042634
fn cast_binary_to_fixed_size_binary<O: OffsetSizeTrait>(
26052635
array: &dyn Array,

arrow-cast/src/cast/string.rs

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -401,25 +401,38 @@ where
401401
let output_array = array
402402
.iter()
403403
.map(|value| match value {
404-
Some(value) => match value.to_ascii_lowercase().trim() {
405-
"t" | "tr" | "tru" | "true" | "y" | "ye" | "yes" | "on" | "1" => Ok(Some(true)),
406-
"f" | "fa" | "fal" | "fals" | "false" | "n" | "no" | "of" | "off" | "0" => {
407-
Ok(Some(false))
408-
}
409-
invalid_value => match cast_options.safe {
410-
true => Ok(None),
411-
false => Err(ArrowError::CastError(format!(
412-
"Cannot cast value '{invalid_value}' to value of Boolean type",
413-
))),
414-
},
415-
},
404+
Some(value) => cast_single_string_to_boolean(value, cast_options),
416405
None => Ok(None),
417406
})
418407
.collect::<Result<BooleanArray, _>>()?;
419408

420409
Ok(Arc::new(output_array))
421410
}
422411

412+
#[inline]
413+
fn cast_single_string_to_boolean(
414+
value: &str,
415+
cast_options: &CastOptions,
416+
) -> Result<Option<bool>, ArrowError> {
417+
match value.to_ascii_lowercase().trim() {
418+
"t" | "tr" | "tru" | "true" | "y" | "ye" | "yes" | "on" | "1" => Ok(Some(true)),
419+
"f" | "fa" | "fal" | "fals" | "false" | "n" | "no" | "of" | "off" | "0" => Ok(Some(false)),
420+
invalid_value => match cast_options.safe {
421+
true => Ok(None),
422+
false => Err(ArrowError::CastError(format!(
423+
"Cannot cast value '{invalid_value}' to value of Boolean type",
424+
))),
425+
},
426+
}
427+
}
428+
429+
/// Cast a single string to boolean with default cast option(safe=true).
430+
pub fn cast_single_string_to_boolean_default(value: &str) -> Option<bool> {
431+
cast_single_string_to_boolean(value, &CastOptions::default())
432+
.ok()
433+
.flatten()
434+
}
435+
423436
pub(crate) fn cast_utf8_to_boolean<OffsetSize>(
424437
from: &dyn Array,
425438
cast_options: &CastOptions,

parquet-variant-compute/src/shred_variant.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1288,7 +1288,7 @@ mod tests {
12881288
.downcast_ref::<arrow::array::Int32Array>()
12891289
.unwrap();
12901290
assert_eq!(typed_value_int32.value(0), 42);
1291-
assert!(typed_value_int32.is_null(1)); // float doesn't convert to int32
1291+
assert_eq!(typed_value_int32.value(1), 3);
12921292
assert!(typed_value_int32.is_null(2)); // string doesn't convert to int32
12931293

12941294
// Test Float64 target

parquet-variant-compute/src/variant_get.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2508,7 +2508,7 @@ mod test {
25082508
#[test]
25092509
fn test_error_message_boolean_type_display() {
25102510
let mut builder = VariantArrayBuilder::new(1);
2511-
builder.append_variant(Variant::Int32(123));
2511+
builder.append_variant(Variant::from("abcd"));
25122512
let variant_array: ArrayRef = ArrayRef::from(builder.build());
25132513

25142514
// Request Boolean with strict casting to force an error
@@ -2529,10 +2529,10 @@ mod test {
25292529
#[test]
25302530
fn test_error_message_numeric_type_display() {
25312531
let mut builder = VariantArrayBuilder::new(1);
2532-
builder.append_variant(Variant::BooleanTrue);
2532+
builder.append_variant(Variant::from("abcd"));
25332533
let variant_array: ArrayRef = ArrayRef::from(builder.build());
25342534

2535-
// Request Boolean with strict casting to force an error
2535+
// Request Float32 with strict casting to force an error
25362536
let options = GetOptions {
25372537
path: VariantPath::default(),
25382538
as_type: Some(Arc::new(Field::new("result", DataType::Float32, true))),
@@ -2553,7 +2553,7 @@ mod test {
25532553
builder.append_variant(Variant::BooleanFalse);
25542554
let variant_array: ArrayRef = ArrayRef::from(builder.build());
25552555

2556-
// Request Boolean with strict casting to force an error
2556+
// Request Timestamp with strict casting to force an error
25572557
let options = GetOptions {
25582558
path: VariantPath::default(),
25592559
as_type: Some(Arc::new(Field::new(

parquet-variant/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@ edition = { workspace = true }
2929
rust-version = { workspace = true }
3030

3131
[dependencies]
32+
arrow = { workspace = true , features = ["canonical_extension_types"] }
3233
arrow-schema = { workspace = true }
3334
chrono = { workspace = true }
3435
half = { version = "2.1", default-features = false }
3536
indexmap = "2.10.0"
37+
num-traits = { version = "0.2", default-features = false }
3638
uuid = { version = "1.18.0", features = ["v4"]}
3739

3840
simdutf8 = { workspace = true , optional = true }

parquet-variant/src/utils.rs

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,6 @@ pub(crate) const fn expect_size_of<T>(expected: usize) {
146146
}
147147
}
148148

149-
pub(crate) fn fits_precision<const N: u32>(n: impl Into<i64>) -> bool {
150-
n.into().unsigned_abs().leading_zeros() >= (i64::BITS - N)
151-
}
152-
153149
/// Parse a path string into a vector of [`VariantPathElement`].
154150
///
155151
/// # Syntax
@@ -289,16 +285,3 @@ fn parse_in_bracket(s: &str, i: usize) -> Result<(VariantPathElement<'_>, usize)
289285

290286
Ok((element, end + 1))
291287
}
292-
293-
#[cfg(test)]
294-
mod test {
295-
use super::*;
296-
297-
#[test]
298-
fn test_fits_precision() {
299-
assert!(fits_precision::<10>(1023));
300-
assert!(!fits_precision::<10>(1024));
301-
assert!(fits_precision::<10>(-1023));
302-
assert!(!fits_precision::<10>(-1024));
303-
}
304-
}

0 commit comments

Comments
 (0)