diff --git a/Cargo.lock b/Cargo.lock index 4803d13bca9c5..895b3059f50c1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2603,6 +2603,7 @@ dependencies = [ "datafusion-functions-aggregate", "datafusion-functions-nested", "log", + "num-traits", "percent-encoding", "rand 0.9.2", "serde_json", diff --git a/datafusion-cli/tests/cli_integration.rs b/datafusion-cli/tests/cli_integration.rs index 7bc45693a8b0d..7aa405338bffe 100644 --- a/datafusion-cli/tests/cli_integration.rs +++ b/datafusion-cli/tests/cli_integration.rs @@ -414,14 +414,12 @@ fn test_backtrace_output(#[case] query: &str) { let output = cmd.output().expect("Failed to execute command"); let stdout = String::from_utf8_lossy(&output.stdout); let stderr = String::from_utf8_lossy(&output.stderr); - let combined_output = format!("{}{}", stdout, stderr); + let combined_output = format!("{stdout}{stderr}"); // Assert that the output includes literal 'backtrace' assert!( combined_output.to_lowercase().contains("backtrace"), - "Expected output to contain 'backtrace', but got stdout: '{}' stderr: '{}'", - stdout, - stderr + "Expected output to contain 'backtrace', but got stdout: '{stdout}' stderr: '{stderr}'" ); } diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index fbc1421727a7f..c6c50371c26c1 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -1274,7 +1274,6 @@ mod test { // To pass the test the environment variable RUST_BACKTRACE should be set to 1 to enforce backtrace #[cfg(feature = "backtrace")] #[test] - #[expect(clippy::unnecessary_literal_unwrap)] fn test_enabled_backtrace() { match std::env::var("RUST_BACKTRACE") { Ok(val) if val == "1" => {} diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index 255525b92e0c0..fcc2e919b6cc2 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -19,12 +19,15 @@ use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; use 
arrow::array::*; +#[cfg(not(feature = "force_hash_collisions"))] use arrow::compute::take; use arrow::datatypes::*; #[cfg(not(feature = "force_hash_collisions"))] use arrow::{downcast_dictionary_array, downcast_primitive_array}; use foldhash::fast::FixedState; +#[cfg(not(feature = "force_hash_collisions"))] use itertools::Itertools; +#[cfg(not(feature = "force_hash_collisions"))] use std::collections::HashMap; use std::hash::{BuildHasher, Hash, Hasher}; @@ -198,6 +201,7 @@ hash_float_value!((half::f16, u16), (f32, u32), (f64, u64)); /// Create a `SeedableRandomState` whose per-hasher seed incorporates `seed`. /// This folds the previous hash into the hasher's initial state so only the /// new value needs to pass through the hash function — same cost as `hash_one`. +#[cfg(not(feature = "force_hash_collisions"))] #[inline] fn seeded_state(seed: u64) -> foldhash::fast::SeedableRandomState { foldhash::fast::SeedableRandomState::with_seed( @@ -303,6 +307,7 @@ fn hash_array( /// HAS_NULLS: do we have to check null in the inner loop /// HAS_BUFFERS: if true, array has external buffers; if false, all strings are inlined/ less then 12 bytes /// REHASH: if true, combining with existing hash, otherwise initializing +#[cfg(not(feature = "force_hash_collisions"))] #[inline(never)] fn hash_string_view_array_inner< T: ByteViewType, @@ -429,6 +434,7 @@ fn hash_generic_byte_view_array( /// - `HAS_NULL_KEYS`: Whether to check for null dictionary keys /// - `HAS_NULL_VALUES`: Whether to check for null dictionary values /// - `MULTI_COL`: Whether to combine with existing hash (true) or initialize (false) +#[cfg(not(feature = "force_hash_collisions"))] #[inline(never)] fn hash_dictionary_inner< K: ArrowDictionaryKeyType, diff --git a/datafusion/common/src/pruning.rs b/datafusion/common/src/pruning.rs index 5a7598ea1f299..27148de59a544 100644 --- a/datafusion/common/src/pruning.rs +++ b/datafusion/common/src/pruning.rs @@ -121,6 +121,7 @@ pub trait PruningStatistics { /// container, 
return `None` (the default). /// /// Note: the returned array must contain [`Self::num_containers`] rows + #[allow(clippy::allow_attributes, clippy::mutable_key_type)] // ScalarValue has interior mutability but is intentionally used as hash key fn contained( &self, column: &Column, @@ -526,6 +527,7 @@ impl PruningStatistics for CompositePruningStatistics { #[cfg(test)] #[expect(deprecated)] +#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // ScalarValue has interior mutability but is intentionally used as hash key mod tests { use crate::{ ColumnStatistics, diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index ad76442562985..4511d8db90075 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -4601,6 +4601,7 @@ impl ScalarValue { /// Estimates [size](Self::size) of [`HashSet`] in bytes. /// /// Includes the size of the [`HashSet`] container itself. + #[allow(clippy::allow_attributes, clippy::mutable_key_type)] // ScalarValue has interior mutability but is intentionally used as hash key pub fn size_of_hashset(set: &HashSet) -> usize { size_of_val(set) + (size_of::() * set.capacity()) @@ -7263,6 +7264,8 @@ mod tests { size_of::>() + (9 * size_of::()) + sv_size, ); + #[allow(clippy::allow_attributes, clippy::mutable_key_type)] + // ScalarValue has interior mutability but is intentionally used as hash key let mut s = HashSet::with_capacity(0); // do NOT clone `sv` here because this may shrink the vector capacity s.insert(v.pop().unwrap()); diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index 4cf5cc366158b..a212122401f98 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -669,58 +669,41 @@ impl Statistics { where I: IntoIterator, { - let items: Vec<&Statistics> = items.into_iter().collect(); - - if items.is_empty() { + let mut items = items.into_iter(); + let Some(first) = items.next() else { return 
Ok(Statistics::new_unknown(schema)); - } - if items.len() == 1 { - return Ok(items[0].clone()); + }; + let Some(second) = items.next() else { + return Ok(first.clone()); + }; + + let num_cols = first.column_statistics.len(); + let mut num_rows = first.num_rows; + let mut total_byte_size = first.total_byte_size; + let mut column_statistics = first.column_statistics.clone(); + for col_stats in &mut column_statistics { + cast_sum_value_to_sum_type_in_place(&mut col_stats.sum_value); } - let num_cols = items[0].column_statistics.len(); - // Validate all items have the same number of columns - for (i, stat) in items.iter().enumerate().skip(1) { + // Merge the remaining items in a single pass. + for (i, stat) in std::iter::once(second).chain(items).enumerate() { if stat.column_statistics.len() != num_cols { return _plan_err!( "Cannot merge statistics with different number of columns: {} vs {} (item {})", num_cols, stat.column_statistics.len(), - i + i + 1 ); } - } - - // Aggregate usize fields (cheap arithmetic) - let mut num_rows = Precision::Exact(0usize); - let mut total_byte_size = Precision::Exact(0usize); - for stat in &items { num_rows = num_rows.add(&stat.num_rows); total_byte_size = total_byte_size.add(&stat.total_byte_size); - } - - let first = items[0]; - let mut column_statistics: Vec = first - .column_statistics - .iter() - .map(|cs| ColumnStatistics { - null_count: cs.null_count, - max_value: cs.max_value.clone(), - min_value: cs.min_value.clone(), - sum_value: cs.sum_value.cast_to_sum_type(), - distinct_count: cs.distinct_count, - byte_size: cs.byte_size, - }) - .collect(); - - // Accumulate all statistics in a single pass. - // Uses precision_add for sum (reuses the lhs accumulator for - // direct numeric addition), while preserving the NDV update - // ordering required by estimate_ndv_with_overlap. 
- for stat in items.iter().skip(1) { - for (col_idx, col_stats) in column_statistics.iter_mut().enumerate() { - let item_cs = &stat.column_statistics[col_idx]; + // Uses precision_add for sum (reuses the lhs accumulator for + // direct numeric addition), while preserving the NDV update + // ordering required by estimate_ndv_with_overlap. + for (col_stats, item_cs) in + column_statistics.iter_mut().zip(&stat.column_statistics) + { col_stats.null_count = col_stats.null_count.add(&item_cs.null_count); // NDV must be computed before min/max update (needs pre-merge ranges) @@ -734,10 +717,12 @@ impl Statistics { ), _ => Precision::Absent, }; - col_stats.min_value = col_stats.min_value.min(&item_cs.min_value); - col_stats.max_value = col_stats.max_value.max(&item_cs.max_value); - let item_sum_value = item_cs.sum_value.cast_to_sum_type(); - precision_add(&mut col_stats.sum_value, &item_sum_value); + precision_min(&mut col_stats.min_value, &item_cs.min_value); + precision_max(&mut col_stats.max_value, &item_cs.max_value); + precision_add_for_sum_in_place( + &mut col_stats.sum_value, + &item_cs.sum_value, + ); col_stats.byte_size = col_stats.byte_size.add(&item_cs.byte_size); } } @@ -840,6 +825,115 @@ pub fn estimate_ndv_with_overlap( Some((intersection + only_left + only_right).round() as usize) } +/// Returns the minimum precision while not allocating a new value, +/// mirrors the semantics of `PartialOrd`. 
+#[inline] +fn precision_min(lhs: &mut Precision, rhs: &Precision) +where + T: Debug + Clone + PartialEq + Eq + PartialOrd, +{ + *lhs = match (std::mem::take(lhs), rhs) { + (Precision::Exact(left), Precision::Exact(right)) => { + if left <= *right { + Precision::Exact(left) + } else { + Precision::Exact(right.clone()) + } + } + (Precision::Exact(left), Precision::Inexact(right)) + | (Precision::Inexact(left), Precision::Exact(right)) + | (Precision::Inexact(left), Precision::Inexact(right)) => { + if left <= *right { + Precision::Inexact(left) + } else { + Precision::Inexact(right.clone()) + } + } + (_, _) => Precision::Absent, + }; +} + +/// Returns the maximum precision while not allocating a new value, +/// mirrors the semantics of `PartialOrd`. +#[inline] +fn precision_max(lhs: &mut Precision, rhs: &Precision) +where + T: Debug + Clone + PartialEq + Eq + PartialOrd, +{ + *lhs = match (std::mem::take(lhs), rhs) { + (Precision::Exact(left), Precision::Exact(right)) => { + if left >= *right { + Precision::Exact(left) + } else { + Precision::Exact(right.clone()) + } + } + (Precision::Exact(left), Precision::Inexact(right)) + | (Precision::Inexact(left), Precision::Exact(right)) + | (Precision::Inexact(left), Precision::Inexact(right)) => { + if left >= *right { + Precision::Inexact(left) + } else { + Precision::Inexact(right.clone()) + } + } + (_, _) => Precision::Absent, + }; +} + +#[inline] +fn cast_sum_value_to_sum_type_in_place(value: &mut Precision) { + let (is_exact, inner) = match std::mem::take(value) { + Precision::Exact(v) => (true, v), + Precision::Inexact(v) => (false, v), + Precision::Absent => return, + }; + let source_type = inner.data_type(); + let target_type = Precision::::sum_data_type(&source_type); + + let wrap_precision_fn: fn(ScalarValue) -> Precision = if is_exact { + Precision::Exact + } else { + Precision::Inexact + }; + + *value = if source_type == target_type { + wrap_precision_fn(inner) + } else { + inner + .cast_to(&target_type) + 
.map(wrap_precision_fn) + .unwrap_or(Precision::Absent) + }; +} + +#[inline] +fn precision_add_for_sum_in_place( + lhs: &mut Precision, + rhs: &Precision, +) { + let (value, wrap_fn): (&ScalarValue, fn(ScalarValue) -> Precision) = + match rhs { + Precision::Exact(v) => (v, Precision::Exact), + Precision::Inexact(v) => (v, Precision::Inexact), + Precision::Absent => { + *lhs = Precision::Absent; + return; + } + }; + let source_type = value.data_type(); + let target_type = Precision::::sum_data_type(&source_type); + if source_type == target_type { + precision_add(lhs, rhs); + } else { + let rhs = value + .cast_to(&target_type) + .map(wrap_fn) + .unwrap_or(Precision::Absent); + precision_add(lhs, &rhs); + } +} + /// Creates an estimate of the number of rows in the output using the given /// optional value and exactness flag. fn check_num_rows(value: Option, is_exact: bool) -> Precision { @@ -2624,4 +2718,146 @@ mod tests { Precision::Inexact(ScalarValue::Int64(Some(1500))) ); } + + #[test] + fn test_precision_min_in_place() { + // Exact vs Exact: keeps the smaller + let mut lhs = Precision::Exact(10); + precision_min(&mut lhs, &Precision::Exact(20)); + assert_eq!(lhs, Precision::Exact(10)); + + let mut lhs = Precision::Exact(20); + precision_min(&mut lhs, &Precision::Exact(10)); + assert_eq!(lhs, Precision::Exact(10)); + + // Equal exact values + let mut lhs = Precision::Exact(5); + precision_min(&mut lhs, &Precision::Exact(5)); + assert_eq!(lhs, Precision::Exact(5)); + + // Mixed exact/inexact: result is Inexact with smaller value + let mut lhs = Precision::Exact(10); + precision_min(&mut lhs, &Precision::Inexact(20)); + assert_eq!(lhs, Precision::Inexact(10)); + + let mut lhs = Precision::Inexact(10); + precision_min(&mut lhs, &Precision::Exact(5)); + assert_eq!(lhs, Precision::Inexact(5)); + + // Inexact vs Inexact + let mut lhs = Precision::Inexact(30); + precision_min(&mut lhs, &Precision::Inexact(20)); + assert_eq!(lhs, Precision::Inexact(20)); + + // Absent 
makes result Absent + let mut lhs = Precision::Exact(10); + precision_min(&mut lhs, &Precision::Absent); + assert_eq!(lhs, Precision::Absent); + + let mut lhs = Precision::::Absent; + precision_min(&mut lhs, &Precision::Exact(10)); + assert_eq!(lhs, Precision::Absent); + } + + #[test] + fn test_precision_max_in_place() { + // Exact vs Exact: keeps the larger + let mut lhs = Precision::Exact(10); + precision_max(&mut lhs, &Precision::Exact(20)); + assert_eq!(lhs, Precision::Exact(20)); + + let mut lhs = Precision::Exact(20); + precision_max(&mut lhs, &Precision::Exact(10)); + assert_eq!(lhs, Precision::Exact(20)); + + // Equal exact values + let mut lhs = Precision::Exact(5); + precision_max(&mut lhs, &Precision::Exact(5)); + assert_eq!(lhs, Precision::Exact(5)); + + // Mixed exact/inexact: result is Inexact with larger value + let mut lhs = Precision::Exact(10); + precision_max(&mut lhs, &Precision::Inexact(20)); + assert_eq!(lhs, Precision::Inexact(20)); + + let mut lhs = Precision::Inexact(10); + precision_max(&mut lhs, &Precision::Exact(5)); + assert_eq!(lhs, Precision::Inexact(10)); + + // Inexact vs Inexact + let mut lhs = Precision::Inexact(20); + precision_max(&mut lhs, &Precision::Inexact(30)); + assert_eq!(lhs, Precision::Inexact(30)); + + // Absent makes result Absent + let mut lhs = Precision::Exact(10); + precision_max(&mut lhs, &Precision::Absent); + assert_eq!(lhs, Precision::Absent); + + let mut lhs = Precision::::Absent; + precision_max(&mut lhs, &Precision::Exact(10)); + assert_eq!(lhs, Precision::Absent); + } + + #[test] + fn test_cast_sum_value_to_sum_type_in_place_widens_int32() { + let mut value = Precision::Exact(ScalarValue::Int32(Some(42))); + cast_sum_value_to_sum_type_in_place(&mut value); + assert_eq!(value, Precision::Exact(ScalarValue::Int64(Some(42)))); + } + + #[test] + fn test_cast_sum_value_to_sum_type_in_place_preserves_int64() { + // Int64 is already the sum type for Int64, no widening needed + let mut value = 
Precision::Exact(ScalarValue::Int64(Some(100))); + cast_sum_value_to_sum_type_in_place(&mut value); + assert_eq!(value, Precision::Exact(ScalarValue::Int64(Some(100)))); + } + + #[test] + fn test_cast_sum_value_to_sum_type_in_place_inexact() { + let mut value = Precision::Inexact(ScalarValue::Int32(Some(42))); + cast_sum_value_to_sum_type_in_place(&mut value); + assert_eq!(value, Precision::Inexact(ScalarValue::Int64(Some(42)))); + } + + #[test] + fn test_cast_sum_value_to_sum_type_in_place_absent() { + let mut value = Precision::::Absent; + cast_sum_value_to_sum_type_in_place(&mut value); + assert_eq!(value, Precision::Absent); + } + + #[test] + fn test_precision_add_for_sum_in_place_same_type() { + // Int64 + Int64: no widening needed, straight add + let mut lhs = Precision::Exact(ScalarValue::Int64(Some(10))); + let rhs = Precision::Exact(ScalarValue::Int64(Some(20))); + precision_add_for_sum_in_place(&mut lhs, &rhs); + assert_eq!(lhs, Precision::Exact(ScalarValue::Int64(Some(30)))); + } + + #[test] + fn test_precision_add_for_sum_in_place_widens_rhs() { + // lhs is already Int64 (widened), rhs is Int32 -> gets cast to Int64 + let mut lhs = Precision::Exact(ScalarValue::Int64(Some(10))); + let rhs = Precision::Exact(ScalarValue::Int32(Some(5))); + precision_add_for_sum_in_place(&mut lhs, &rhs); + assert_eq!(lhs, Precision::Exact(ScalarValue::Int64(Some(15)))); + } + + #[test] + fn test_precision_add_for_sum_in_place_inexact() { + let mut lhs = Precision::Inexact(ScalarValue::Int64(Some(10))); + let rhs = Precision::Inexact(ScalarValue::Int32(Some(5))); + precision_add_for_sum_in_place(&mut lhs, &rhs); + assert_eq!(lhs, Precision::Inexact(ScalarValue::Int64(Some(15)))); + } + + #[test] + fn test_precision_add_for_sum_in_place_absent_rhs() { + let mut lhs = Precision::Exact(ScalarValue::Int64(Some(10))); + precision_add_for_sum_in_place(&mut lhs, &Precision::Absent); + assert_eq!(lhs, Precision::Absent); + } } diff --git a/datafusion/core/src/physical_planner.rs 
b/datafusion/core/src/physical_planner.rs index 82f3c4d80c9ec..bf84fcc53e957 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -2093,6 +2093,7 @@ fn get_physical_expr_pair( /// A vector of unqualified filter expressions that can be passed to the TableProvider for execution. /// Returns an empty vector if no applicable filters are found. /// +#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // Expr contains Arc with interior mutability but is intentionally used as hash key fn extract_dml_filters( input: &Arc, target: &TableReference, @@ -3573,21 +3574,17 @@ mod tests { } #[tokio::test] - async fn in_list_types() -> Result<()> { - // expression: "a in ('a', 1)" + async fn in_list_types_mixed_string_int_error() -> Result<()> { + // expression: "c1 in ('a', 1)" where c1 is Utf8 let list = vec![lit("a"), lit(1i64)]; let logical_plan = test_csv_scan() .await? - // filter clause needs the type coercion rule applied .filter(col("c12").lt(lit(0.05)))? .project(vec![col("c1").in_list(list, false)])? .build()?; - let execution_plan = plan(&logical_plan).await?; - // verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated. 
- - let expected = r#"expr: BinaryExpr { left: BinaryExpr { left: Column { name: "c1", index: 0 }, op: Eq, right: Literal { value: Utf8("a"), field: Field { name: "lit", data_type: Utf8 } }, fail_on_overflow: false }"#; + let e = plan(&logical_plan).await.unwrap_err().to_string(); - assert_contains!(format!("{execution_plan:?}"), expected); + assert_contains!(&e, "Cannot cast string 'a' to value of Int64 type"); Ok(()) } diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index 8726c5aec9057..19ff3933193de 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -342,20 +342,26 @@ fn test_create_physical_expr_nvl2() { #[tokio::test] async fn test_create_physical_expr_coercion() { - // create_physical_expr does apply type coercion and unwrapping in cast + // create_physical_expr applies type coercion (and can unwrap/fold + // literal casts). Comparison coercion prefers numeric types, so + // string/int comparisons cast the string side to the numeric type. 
// - // expect the cast on the literals - // compare string function to int `id = 1` - create_expr_test(col("id").eq(lit(1i32)), "id@0 = CAST(1 AS Utf8)"); - create_expr_test(lit(1i32).eq(col("id")), "CAST(1 AS Utf8) = id@0"); - // compare int col to string literal `i = '202410'` - // Note this casts the column (not the field) - create_expr_test(col("i").eq(lit("202410")), "CAST(i@1 AS Utf8) = 202410"); - create_expr_test(lit("202410").eq(col("i")), "202410 = CAST(i@1 AS Utf8)"); - // however, when simplified the casts on i should removed - // https://github.com/apache/datafusion/issues/14944 - create_simplified_expr_test(col("i").eq(lit("202410")), "CAST(i@1 AS Utf8) = 202410"); - create_simplified_expr_test(lit("202410").eq(col("i")), "CAST(i@1 AS Utf8) = 202410"); + // string column vs int literal: id (Utf8) is cast to Int32 + create_expr_test(col("id").eq(lit(1i32)), "CAST(id@0 AS Int32) = 1"); + create_expr_test(lit(1i32).eq(col("id")), "1 = CAST(id@0 AS Int32)"); + // int column vs string literal: the string literal is cast to Int64 + create_expr_test(col("i").eq(lit("202410")), "i@1 = CAST(202410 AS Int64)"); + create_expr_test(lit("202410").eq(col("i")), "CAST(202410 AS Int64) = i@1"); + // The simplifier operates on the logical expression before type + // coercion adds the CAST, so the output is unchanged. 
+ create_simplified_expr_test( + col("i").eq(lit("202410")), + "i@1 = CAST(202410 AS Int64)", + ); + create_simplified_expr_test( + lit("202410").eq(col("i")), + "i@1 = CAST(202410 AS Int64)", + ); } /// Evaluates the specified expr as an aggregate and compares the result to the diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index ff8c512cbd22e..7075fbc2443d2 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -24,6 +24,7 @@ use std::sync::{Arc, LazyLock}; #[cfg(feature = "extended_tests")] mod memory_limit_validation; mod repartition_mem_limit; +mod union_nullable_spill; use arrow::array::{ArrayRef, DictionaryArray, Int32Array, RecordBatch, StringViewArray}; use arrow::compute::SortOptions; use arrow::datatypes::{Int32Type, SchemaRef}; diff --git a/datafusion/core/tests/memory_limit/union_nullable_spill.rs b/datafusion/core/tests/memory_limit/union_nullable_spill.rs new file mode 100644 index 0000000000000..c5ef2387d3cdc --- /dev/null +++ b/datafusion/core/tests/memory_limit/union_nullable_spill.rs @@ -0,0 +1,162 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::sync::Arc; + +use arrow::array::{Array, Int64Array, RecordBatch}; +use arrow::compute::SortOptions; +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion::datasource::memory::MemorySourceConfig; +use datafusion_execution::config::SessionConfig; +use datafusion_execution::memory_pool::FairSpillPool; +use datafusion_execution::runtime_env::RuntimeEnvBuilder; +use datafusion_physical_expr::expressions::col; +use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::sorts::sort::sort_batch; +use datafusion_physical_plan::union::UnionExec; +use datafusion_physical_plan::{ExecutionPlan, Partitioning}; +use futures::StreamExt; + +const NUM_BATCHES: usize = 200; +const ROWS_PER_BATCH: usize = 10; + +fn non_nullable_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("key", DataType::Int64, false), + Field::new("val", DataType::Int64, false), + ])) +} + +fn nullable_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("key", DataType::Int64, false), + Field::new("val", DataType::Int64, true), + ])) +} + +fn non_nullable_batches() -> Vec { + (0..NUM_BATCHES) + .map(|i| { + let start = (i * ROWS_PER_BATCH) as i64; + let keys: Vec = (start..start + ROWS_PER_BATCH as i64).collect(); + RecordBatch::try_new( + non_nullable_schema(), + vec![ + Arc::new(Int64Array::from(keys)), + Arc::new(Int64Array::from(vec![0i64; ROWS_PER_BATCH])), + ], + ) + .unwrap() + }) + .collect() +} + +fn nullable_batches() -> Vec { + (0..NUM_BATCHES) + .map(|i| { + let start = (i * ROWS_PER_BATCH) as i64; + let keys: Vec = (start..start + ROWS_PER_BATCH as i64).collect(); + let vals: Vec> = (0..ROWS_PER_BATCH) + .map(|j| if j % 3 == 1 { None } else { Some(j as i64) }) + .collect(); + RecordBatch::try_new( + nullable_schema(), + vec![ + Arc::new(Int64Array::from(keys)), + Arc::new(Int64Array::from(vals)), + ], + ) + .unwrap() + }) + .collect() +} + +fn 
build_task_ctx(pool_size: usize) -> Arc { + let session_config = SessionConfig::new().with_batch_size(2); + let runtime = RuntimeEnvBuilder::new() + .with_memory_pool(Arc::new(FairSpillPool::new(pool_size))) + .build_arc() + .unwrap(); + Arc::new( + datafusion_execution::TaskContext::default() + .with_session_config(session_config) + .with_runtime(runtime), + ) +} + +/// Exercises spilling through UnionExec -> RepartitionExec where union children +/// have mismatched nullability (one child's `val` is non-nullable, the other's +/// is nullable with NULLs). A tiny FairSpillPool forces all batches to spill. +/// +/// UnionExec returns child streams without schema coercion, so batches from +/// different children carry different per-field nullability into the shared +/// SpillPool. The IPC writer must use the SpillManager's canonical (nullable) +/// schema — not the first batch's schema — so readback batches are valid. +/// +/// Otherwise, sort_batch will panic with +/// `Column 'val' is declared as non-nullable but contains null values` +#[tokio::test] +async fn test_sort_union_repartition_spill_mixed_nullability() { + let non_nullable_exec = MemorySourceConfig::try_new_exec( + &[non_nullable_batches()], + non_nullable_schema(), + None, + ) + .unwrap(); + + let nullable_exec = + MemorySourceConfig::try_new_exec(&[nullable_batches()], nullable_schema(), None) + .unwrap(); + + let union_exec = UnionExec::try_new(vec![non_nullable_exec, nullable_exec]).unwrap(); + assert!(union_exec.schema().field(1).is_nullable()); + + let repartition = Arc::new( + RepartitionExec::try_new(union_exec, Partitioning::RoundRobinBatch(1)).unwrap(), + ); + + let task_ctx = build_task_ctx(200); + let mut stream = repartition.execute(0, task_ctx).unwrap(); + + let sort_expr = LexOrdering::new(vec![PhysicalSortExpr { + expr: col("key", &nullable_schema()).unwrap(), + options: SortOptions::default(), + }]) + .unwrap(); + + let mut total_rows = 0usize; + let mut total_nulls = 0usize; + while let 
Some(result) = stream.next().await { + let batch = result.unwrap(); + + let batch = sort_batch(&batch, &sort_expr, None).unwrap(); + + total_rows += batch.num_rows(); + total_nulls += batch.column(1).null_count(); + } + + assert_eq!( + total_rows, + NUM_BATCHES * ROWS_PER_BATCH * 2, + "All rows from both UNION branches should be present" + ); + assert!( + total_nulls > 0, + "Expected some null values in output (i.e. nullable batches were processed)" + ); +} diff --git a/datafusion/core/tests/sql/unparser.rs b/datafusion/core/tests/sql/unparser.rs index e9bad71843ff2..d6ca872e198c3 100644 --- a/datafusion/core/tests/sql/unparser.rs +++ b/datafusion/core/tests/sql/unparser.rs @@ -142,16 +142,26 @@ fn tpch_queries() -> Vec { } /// Create a new SessionContext for testing that has all Clickbench tables registered. +/// +/// Registers the raw Parquet as `hits_raw`, then creates a `hits` view that +/// casts `EventDate` from UInt16 (day-offset) to DATE. This mirrors the +/// approach used by the benchmark runner in `benchmarks/src/clickbench.rs`. 
async fn clickbench_test_context() -> Result { let ctx = SessionContext::new(); ctx.register_parquet( - "hits", + "hits_raw", "tests/data/clickbench_hits_10.parquet", ParquetReadOptions::default(), ) .await?; - // Sanity check we found the table by querying it's schema, it should not be empty - // Otherwise if the path is wrong the tests will all fail in confusing ways + ctx.sql( + r#"CREATE VIEW hits AS + SELECT * EXCEPT ("EventDate"), + CAST(CAST("EventDate" AS INTEGER) AS DATE) AS "EventDate" + FROM hits_raw"#, + ) + .await?; + // Sanity check we found the table by querying its schema let df = ctx.sql("SELECT * FROM hits LIMIT 1").await?; assert!( !df.schema().fields().is_empty(), diff --git a/datafusion/datasource/src/file_groups.rs b/datafusion/datasource/src/file_groups.rs index 28a403ab92ad8..84594be54b504 100644 --- a/datafusion/datasource/src/file_groups.rs +++ b/datafusion/datasource/src/file_groups.rs @@ -488,6 +488,7 @@ impl FileGroup { /// /// Note: May return fewer groups than `max_target_partitions` when the /// number of unique partition values is less than the target. 
+ #[allow(clippy::allow_attributes, clippy::mutable_key_type)] // ScalarValue has interior mutability but is intentionally used as hash key pub fn group_by_partition_values( self, max_target_partitions: usize, diff --git a/datafusion/datasource/src/file_stream/builder.rs b/datafusion/datasource/src/file_stream/builder.rs index 7f21ace92c46b..6d99f4b56a8ee 100644 --- a/datafusion/datasource/src/file_stream/builder.rs +++ b/datafusion/datasource/src/file_stream/builder.rs @@ -21,7 +21,8 @@ use crate::file_scan_config::FileScanConfig; use datafusion_common::{Result, internal_err}; use datafusion_physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet}; -use super::{FileOpener, FileStream, FileStreamMetrics, FileStreamState, OnError}; +use super::metrics::FileStreamMetrics; +use super::{FileOpener, FileStream, FileStreamState, OnError}; /// Builder for constructing a [`FileStream`]. pub struct FileStreamBuilder<'a> { diff --git a/datafusion/datasource/src/file_stream/metrics.rs b/datafusion/datasource/src/file_stream/metrics.rs new file mode 100644 index 0000000000000..f4dddeaee8d0e --- /dev/null +++ b/datafusion/datasource/src/file_stream/metrics.rs @@ -0,0 +1,162 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::instant::Instant; +use datafusion_physical_plan::metrics::{ + Count, ExecutionPlanMetricsSet, MetricBuilder, MetricCategory, Time, +}; + +/// A timer that can be started and stopped. +pub struct StartableTime { + pub metrics: Time, + // Records the elapsed time of each timed section; folded into `metrics` when stopped. + pub start: Option, +} + +impl StartableTime { + pub fn start(&mut self) { + assert!(self.start.is_none()); + self.start = Some(Instant::now()); + } + + pub fn stop(&mut self) { + if let Some(start) = self.start.take() { + self.metrics.add_elapsed(start); + } + } +} + +/// Metrics for [`FileStream`] +/// +/// Note that all of these metrics are in terms of wall clock time +/// (not cpu time) so they include time spent waiting on I/O as well +/// as other operators. +/// +/// [`FileStream`]: crate::file_stream::FileStream +pub struct FileStreamMetrics { + /// Wall clock time elapsed for file opening. + /// + /// Time between when [`FileOpener::open`] is called and when the + /// [`FileStream`] receives a stream for reading. + /// + /// [`FileStream`]: crate::file_stream::FileStream + /// [`FileOpener::open`]: crate::file_stream::FileOpener::open + pub time_opening: StartableTime, + /// Wall clock time elapsed for file scanning + first record batch of decompression + decoding + /// + /// Time between when the [`FileStream`] requests data from the + /// stream and when the first [`RecordBatch`] is produced. + /// + /// [`FileStream`]: crate::file_stream::FileStream + /// [`RecordBatch`]: arrow::record_batch::RecordBatch + pub time_scanning_until_data: StartableTime, + /// Total elapsed wall clock time for scanning + record batch decompression / decoding + /// + /// Sum of time between when the [`FileStream`] requests data from + /// the stream and when a [`RecordBatch`] is produced for all + /// record batches in the stream.
Note that this metric also + /// includes the time of the parent operator's execution. + /// + /// [`FileStream`]: crate::file_stream::FileStream + /// [`RecordBatch`]: arrow::record_batch::RecordBatch + pub time_scanning_total: StartableTime, + /// Wall clock time elapsed for data decompression + decoding + /// + /// Time spent waiting for the FileStream's input. + pub time_processing: StartableTime, + /// Count of errors opening file. + /// + /// If using `OnError::Skip` this will provide a count of the number of files + /// which were skipped and will not be included in the scan results. + pub file_open_errors: Count, + /// Count of errors scanning file + /// + /// If using `OnError::Skip` this will provide a count of the number of files + /// which were skipped and will not be included in the scan results. + pub file_scan_errors: Count, + /// Count of files successfully opened or evaluated for processing. + /// At t=end (completion of a query) this is equal to `files_processed`, and both values are equal + /// to the total number of files in the query; unless the query itself fails. + /// This value will always be greater than or equal to `files_processed`. + /// Note that this value does *not* mean the file was actually scanned. + /// We increment this value for any processing of a file, even if that processing is + /// discarding it because we hit a `LIMIT` (in this case `files_opened` and `files_processed` are both incremented at the same time). + pub files_opened: Count, + /// Count of files completely processed / closed (opened, pruned, or skipped due to limit). + /// At t=0 (the beginning of a query) this is 0. + /// At t=end (completion of a query) this is equal to `files_opened`, and both values are equal + /// to the total number of files in the query; unless the query itself fails. + /// This value will always be less than or equal to `files_opened`.
+ /// We increment this value for any processing of a file, even if that processing is + /// discarding it because we hit a `LIMIT` (in this case `files_opened` and `files_processed` are both incremented at the same time). + pub files_processed: Count, +} + +impl FileStreamMetrics { + pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self { + let time_opening = StartableTime { + metrics: MetricBuilder::new(metrics) + .subset_time("time_elapsed_opening", partition), + start: None, + }; + + let time_scanning_until_data = StartableTime { + metrics: MetricBuilder::new(metrics) + .subset_time("time_elapsed_scanning_until_data", partition), + start: None, + }; + + let time_scanning_total = StartableTime { + metrics: MetricBuilder::new(metrics) + .subset_time("time_elapsed_scanning_total", partition), + start: None, + }; + + let time_processing = StartableTime { + metrics: MetricBuilder::new(metrics) + .subset_time("time_elapsed_processing", partition), + start: None, + }; + + let file_open_errors = MetricBuilder::new(metrics) + .with_category(MetricCategory::Rows) + .counter("file_open_errors", partition); + + let file_scan_errors = MetricBuilder::new(metrics) + .with_category(MetricCategory::Rows) + .counter("file_scan_errors", partition); + + let files_opened = MetricBuilder::new(metrics) + .with_category(MetricCategory::Rows) + .counter("files_opened", partition); + + let files_processed = MetricBuilder::new(metrics) + .with_category(MetricCategory::Rows) + .counter("files_processed", partition); + + Self { + time_opening, + time_scanning_until_data, + time_scanning_total, + time_processing, + file_open_errors, + file_scan_errors, + files_opened, + files_processed, + } + } +} diff --git a/datafusion/datasource/src/file_stream/mod.rs b/datafusion/datasource/src/file_stream/mod.rs index a423552917408..33e5065cb5a3f 100644 --- a/datafusion/datasource/src/file_stream/mod.rs +++ b/datafusion/datasource/src/file_stream/mod.rs @@ -22,6 +22,7 @@ //! 
compliant with the `SendableRecordBatchStream` trait. mod builder; +mod metrics; use std::collections::VecDeque; use std::pin::Pin; @@ -33,18 +34,16 @@ use crate::file_scan_config::FileScanConfig; use arrow::datatypes::SchemaRef; use datafusion_common::Result; use datafusion_execution::RecordBatchStream; -use datafusion_physical_plan::metrics::{ - BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricCategory, Time, -}; +use datafusion_physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet}; use arrow::record_batch::RecordBatch; -use datafusion_common::instant::Instant; use futures::future::BoxFuture; use futures::stream::BoxStream; use futures::{FutureExt as _, Stream, StreamExt as _, ready}; pub use builder::FileStreamBuilder; +pub use metrics::{FileStreamMetrics, StartableTime}; /// A stream that iterates record batch by record batch, file over file. pub struct FileStream { @@ -261,139 +260,6 @@ pub enum FileStreamState { } /// A timer that can be started and stopped. -pub struct StartableTime { - pub metrics: Time, - // use for record each part cost time, will eventually add into 'metrics'. - pub start: Option, -} - -impl StartableTime { - pub fn start(&mut self) { - assert!(self.start.is_none()); - self.start = Some(Instant::now()); - } - - pub fn stop(&mut self) { - if let Some(start) = self.start.take() { - self.metrics.add_elapsed(start); - } - } -} - -/// Metrics for [`FileStream`] -/// -/// Note that all of these metrics are in terms of wall clock time -/// (not cpu time) so they include time spent waiting on I/O as well -/// as other operators. -/// -/// [`FileStream`]: -pub struct FileStreamMetrics { - /// Wall clock time elapsed for file opening. - /// - /// Time between when [`FileOpener::open`] is called and when the - /// [`FileStream`] receives a stream for reading. 
- /// [`FileStream`]: - pub time_opening: StartableTime, - /// Wall clock time elapsed for file scanning + first record batch of decompression + decoding - /// - /// Time between when the [`FileStream`] requests data from the - /// stream and when the first [`RecordBatch`] is produced. - /// [`FileStream`]: - pub time_scanning_until_data: StartableTime, - /// Total elapsed wall clock time for scanning + record batch decompression / decoding - /// - /// Sum of time between when the [`FileStream`] requests data from - /// the stream and when a [`RecordBatch`] is produced for all - /// record batches in the stream. Note that this metric also - /// includes the time of the parent operator's execution. - pub time_scanning_total: StartableTime, - /// Wall clock time elapsed for data decompression + decoding - /// - /// Time spent waiting for the FileStream's input. - pub time_processing: StartableTime, - /// Count of errors opening file. - /// - /// If using `OnError::Skip` this will provide a count of the number of files - /// which were skipped and will not be included in the scan results. - pub file_open_errors: Count, - /// Count of errors scanning file - /// - /// If using `OnError::Skip` this will provide a count of the number of files - /// which were skipped and will not be included in the scan results. - pub file_scan_errors: Count, - /// Count of files successfully opened or evaluated for processing. - /// At t=end (completion of a query) this is equal to `files_opened`, and both values are equal - /// to the total number of files in the query; unless the query itself fails. - /// This value will always be greater than or equal to `files_open`. - /// Note that this value does *not* mean the file was actually scanned. - /// We increment this value for any processing of a file, even if that processing is - /// discarding it because we hit a `LIMIT` (in this case `files_opened` and `files_processed` are both incremented at the same time). 
- pub files_opened: Count, - /// Count of files completely processed / closed (opened, pruned, or skipped due to limit). - /// At t=0 (the beginning of a query) this is 0. - /// At t=end (completion of a query) this is equal to `files_opened`, and both values are equal - /// to the total number of files in the query; unless the query itself fails. - /// This value will always be less than or equal to `files_open`. - /// We increment this value for any processing of a file, even if that processing is - /// discarding it because we hit a `LIMIT` (in this case `files_opened` and `files_processed` are both incremented at the same time). - pub files_processed: Count, -} - -impl FileStreamMetrics { - pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self { - let time_opening = StartableTime { - metrics: MetricBuilder::new(metrics) - .subset_time("time_elapsed_opening", partition), - start: None, - }; - - let time_scanning_until_data = StartableTime { - metrics: MetricBuilder::new(metrics) - .subset_time("time_elapsed_scanning_until_data", partition), - start: None, - }; - - let time_scanning_total = StartableTime { - metrics: MetricBuilder::new(metrics) - .subset_time("time_elapsed_scanning_total", partition), - start: None, - }; - - let time_processing = StartableTime { - metrics: MetricBuilder::new(metrics) - .subset_time("time_elapsed_processing", partition), - start: None, - }; - - let file_open_errors = MetricBuilder::new(metrics) - .with_category(MetricCategory::Rows) - .counter("file_open_errors", partition); - - let file_scan_errors = MetricBuilder::new(metrics) - .with_category(MetricCategory::Rows) - .counter("file_scan_errors", partition); - - let files_opened = MetricBuilder::new(metrics) - .with_category(MetricCategory::Rows) - .counter("files_opened", partition); - - let files_processed = MetricBuilder::new(metrics) - .with_category(MetricCategory::Rows) - .counter("files_processed", partition); - - Self { - time_opening, - 
time_scanning_until_data, - time_scanning_total, - time_processing, - file_open_errors, - file_scan_errors, - files_opened, - files_processed, - } - } -} - #[cfg(test)] mod tests { use crate::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; diff --git a/datafusion/execution/src/memory_pool/arrow.rs b/datafusion/execution/src/memory_pool/arrow.rs index 4e8d986f1f5e3..929e3b7bd27e5 100644 --- a/datafusion/execution/src/memory_pool/arrow.rs +++ b/datafusion/execution/src/memory_pool/arrow.rs @@ -59,7 +59,7 @@ impl arrow_buffer::MemoryReservation for MemoryReservation { impl arrow_buffer::MemoryPool for ArrowMemoryPool { fn reserve(&self, size: usize) -> Box { let consumer = self.consumer.clone_with_new_id(); - let mut reservation = consumer.register(&self.inner); + let reservation = consumer.register(&self.inner); reservation.grow(size); Box::new(reservation) diff --git a/datafusion/expr-common/src/interval_arithmetic.rs b/datafusion/expr-common/src/interval_arithmetic.rs index 883c721080611..c0cecad4a35c9 100644 --- a/datafusion/expr-common/src/interval_arithmetic.rs +++ b/datafusion/expr-common/src/interval_arithmetic.rs @@ -22,7 +22,7 @@ use std::fmt::{self, Display, Formatter}; use std::ops::{AddAssign, SubAssign}; use crate::operator::Operator; -use crate::type_coercion::binary::{BinaryTypeCoercer, comparison_coercion_numeric}; +use crate::type_coercion::binary::{BinaryTypeCoercer, comparison_coercion}; use arrow::compute::{CastOptions, cast_with_options}; use arrow::datatypes::{ @@ -730,7 +730,7 @@ impl Interval { (self.lower.clone(), self.upper.clone(), rhs.clone()) } else { let maybe_common_type = - comparison_coercion_numeric(&self.data_type(), &rhs.data_type()); + comparison_coercion(&self.data_type(), &rhs.data_type()); assert_or_internal_err!( maybe_common_type.is_some(), "Data types must be compatible for containment checks, lhs:{}, rhs:{}", diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 
82759be9f75e8..42d4de939a8e8 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -158,7 +158,7 @@ pub enum Arity { pub enum TypeSignature { /// One or more arguments of a common type out of a list of valid types. /// - /// For functions that take no arguments (e.g. `random()` see [`TypeSignature::Nullary`]). + /// For functions that take no arguments (e.g. `random()`), see [`TypeSignature::Nullary`]. /// /// # Examples /// @@ -184,7 +184,7 @@ pub enum TypeSignature { Uniform(usize, Vec), /// One or more arguments with exactly the specified types in order. /// - /// For functions that take no arguments (e.g. `random()`) use [`TypeSignature::Nullary`]. + /// For functions that take no arguments (e.g. `random()`), use [`TypeSignature::Nullary`]. Exact(Vec), /// One or more arguments belonging to the [`TypeSignatureClass`], in order. /// @@ -192,12 +192,12 @@ pub enum TypeSignature { /// casts. For example, if you expect a function has string type, but you /// also allow it to be casted from binary type. /// - /// For functions that take no arguments (e.g. `random()`) see [`TypeSignature::Nullary`]. + /// For functions that take no arguments (e.g. `random()`), see [`TypeSignature::Nullary`]. Coercible(Vec), /// One or more arguments coercible to a single, comparable type. /// /// Each argument will be coerced to a single type using the - /// coercion rules described in [`comparison_coercion_numeric`]. + /// coercion rules described in [`comparison_coercion`]. /// /// # Examples /// @@ -205,17 +205,18 @@ pub enum TypeSignature { /// the types will both be coerced to `i64` before the function is invoked. /// /// If the `nullif('1', 2)` function is called with `Utf8` and `i64` arguments - /// the types will both be coerced to `Utf8` before the function is invoked. + /// the types will both be coerced to `Int64` before the function is invoked + /// (numeric is preferred over string). 
/// /// Note: - /// - For functions that take no arguments (e.g. `random()` see [`TypeSignature::Nullary`]). + /// - For functions that take no arguments (e.g. `random()`), see [`TypeSignature::Nullary`]. /// - If all arguments have type [`DataType::Null`], they are coerced to `Utf8` /// - /// [`comparison_coercion_numeric`]: crate::type_coercion::binary::comparison_coercion_numeric + /// [`comparison_coercion`]: crate::type_coercion::binary::comparison_coercion Comparable(usize), /// One or more arguments of arbitrary types. /// - /// For functions that take no arguments (e.g. `random()`) use [`TypeSignature::Nullary`]. + /// For functions that take no arguments (e.g. `random()`), use [`TypeSignature::Nullary`]. Any(usize), /// Matches exactly one of a list of [`TypeSignature`]s. /// @@ -233,7 +234,7 @@ pub enum TypeSignature { /// /// See [`NativeType::is_numeric`] to know which type is considered numeric /// - /// For functions that take no arguments (e.g. `random()`) use [`TypeSignature::Nullary`]. + /// For functions that take no arguments (e.g. `random()`), use [`TypeSignature::Nullary`]. /// /// [`NativeType::is_numeric`]: datafusion_common::types::NativeType::is_numeric Numeric(usize), @@ -246,7 +247,7 @@ pub enum TypeSignature { /// For example, if a function is called with (utf8, large_utf8), all /// arguments will be coerced to `LargeUtf8` /// - /// For functions that take no arguments (e.g. `random()` use [`TypeSignature::Nullary`]). + /// For functions that take no arguments (e.g. `random()`), use [`TypeSignature::Nullary`]. 
String(usize), /// No arguments Nullary, diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index fa109e38a4382..eda99dcc32075 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -33,7 +33,6 @@ use arrow::datatypes::{ DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DataType, Field, FieldRef, Fields, TimeUnit, }; -use datafusion_common::types::NativeType; use datafusion_common::{ Diagnostic, Result, Span, Spans, exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, @@ -580,8 +579,8 @@ impl From<&DataType> for TypeCategory { } /// Coerce dissimilar data types to a single data type. -/// UNION, INTERSECT, EXCEPT, CASE, ARRAY, VALUES, and the GREATEST and LEAST functions are -/// examples that has the similar resolution rules. +/// ARRAY literals, VALUES, COALESCE, and array concatenation are examples +/// of contexts that use this function. /// See for more information. /// The rules in the document provide a clue, but adhering strictly to them doesn't precisely /// align with the behavior of Postgres. 
Therefore, we've made slight adjustments to the rules @@ -741,16 +740,13 @@ fn type_union_resolution_coercion( .collect(); Some(DataType::Struct(fields.into())) } - _ => { - // Numeric coercion is the same as comparison coercion, both find the narrowest type - // that can accommodate both types - binary_numeric_coercion(lhs_type, rhs_type) - .or_else(|| list_coercion(lhs_type, rhs_type)) - .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type)) - .or_else(|| string_coercion(lhs_type, rhs_type)) - .or_else(|| numeric_string_coercion(lhs_type, rhs_type)) - .or_else(|| binary_coercion(lhs_type, rhs_type)) - } + _ => binary_numeric_coercion(lhs_type, rhs_type) + .or_else(|| list_coercion(lhs_type, rhs_type, type_union_resolution_coercion)) + .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type)) + .or_else(|| string_coercion(lhs_type, rhs_type)) + .or_else(|| null_coercion(lhs_type, rhs_type)) + .or_else(|| string_numeric_coercion(lhs_type, rhs_type)) + .or_else(|| binary_coercion(lhs_type, rhs_type)), } } @@ -843,102 +839,104 @@ pub fn try_type_union_resolution_with_struct( Ok(final_struct_types) } -/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a -/// comparison operation -/// -/// Example comparison operations are `lhs = rhs` and `lhs > rhs` +/// Coerce `lhs_type` and `rhs_type` to a common type for type unification +/// contexts — where two values must be brought to a common type but are not +/// being compared. Examples: UNION, CASE THEN/ELSE branches, NVL2. For other +/// contexts, [`comparison_coercion`] should typically be used instead. /// -/// Binary comparison kernels require the two arguments to be the (exact) same -/// data type. However, users can write queries where the two arguments are -/// different data types. In such cases, the data types are automatically cast -/// (coerced) to a single data type to pass to the kernels. 
-/// -/// # Numeric comparisons -/// -/// When comparing numeric values, the lower precision type is coerced to the -/// higher precision type to avoid losing data. For example when comparing -/// `Int32` to `Int64` the coerced type is `Int64` so the `Int32` argument will -/// be cast. -/// -/// # Numeric / String comparisons -/// -/// When comparing numeric values and strings, both values will be coerced to -/// strings. For example when comparing `'2' > 1`, the arguments will be -/// coerced to `Utf8` for comparison -pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { +/// The intuition is that we try to find the "widest" type that can represent +/// all values from both sides. When one side is a string and the other is +/// numeric, this prefers strings because every number has a textual +/// representation but not every string can be parsed as a number (e.g., `SELECT +/// 1 UNION SELECT 'a'` coerces both sides to a string). This is in contrast to +/// [`comparison_coercion`], which prefers numeric types so that ordering and +/// equality follow numeric semantics. 
+pub fn type_union_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { if lhs_type.equals_datatype(rhs_type) { - // same type => equality is possible return Some(lhs_type.clone()); } binary_numeric_coercion(lhs_type, rhs_type) - .or_else(|| dictionary_comparison_coercion(lhs_type, rhs_type, true)) - .or_else(|| ree_comparison_coercion(lhs_type, rhs_type, true)) + .or_else(|| dictionary_coercion(lhs_type, rhs_type, true, type_union_coercion)) + .or_else(|| ree_coercion(lhs_type, rhs_type, true, type_union_coercion)) .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type)) .or_else(|| string_coercion(lhs_type, rhs_type)) - .or_else(|| list_coercion(lhs_type, rhs_type)) + .or_else(|| list_coercion(lhs_type, rhs_type, type_union_coercion)) .or_else(|| null_coercion(lhs_type, rhs_type)) - .or_else(|| string_numeric_coercion(lhs_type, rhs_type)) + .or_else(|| string_numeric_union_coercion(lhs_type, rhs_type)) .or_else(|| string_temporal_coercion(lhs_type, rhs_type)) .or_else(|| binary_coercion(lhs_type, rhs_type)) - .or_else(|| struct_coercion(lhs_type, rhs_type)) - .or_else(|| map_coercion(lhs_type, rhs_type)) + .or_else(|| struct_coercion(lhs_type, rhs_type, type_union_coercion)) + .or_else(|| map_coercion(lhs_type, rhs_type, type_union_coercion)) } -/// Similar to [`comparison_coercion`] but prefers numeric if compares with -/// numeric and string +/// Coerce `lhs_type` and `rhs_type` to a common type for comparison +/// contexts — any context where two values are compared rather than +/// unified. This includes binary comparison operators, IN lists, +/// CASE/WHEN conditions, and BETWEEN. +/// +/// When the two types differ, this function determines the common type +/// to cast to. /// /// # Numeric comparisons /// -/// When comparing numeric values and strings, the values will be coerced to the -/// numeric type. For example, `'2' > 1` if `1` is an `Int32`, the arguments -/// will be coerced to `Int32`. 
-pub fn comparison_coercion_numeric( - lhs_type: &DataType, - rhs_type: &DataType, -) -> Option { - if lhs_type == rhs_type { +/// The lower precision type is widened to the higher precision type +/// (e.g., `Int32` vs `Int64` → `Int64`). +/// +/// # Numeric / String comparisons +/// +/// Prefers the numeric type (e.g., `'2' > 1` where `1` is `Int32` coerces +/// `'2'` to `Int32`). +/// +/// For type unification contexts (UNION, CASE THEN/ELSE), use +/// [`type_union_coercion`] instead. +pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { + if lhs_type.equals_datatype(rhs_type) { // same type => equality is possible return Some(lhs_type.clone()); } binary_numeric_coercion(lhs_type, rhs_type) - .or_else(|| dictionary_comparison_coercion_numeric(lhs_type, rhs_type, true)) - .or_else(|| ree_comparison_coercion_numeric(lhs_type, rhs_type, true)) + .or_else(|| dictionary_coercion(lhs_type, rhs_type, true, comparison_coercion)) + .or_else(|| ree_coercion(lhs_type, rhs_type, true, comparison_coercion)) + .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type)) .or_else(|| string_coercion(lhs_type, rhs_type)) + .or_else(|| list_coercion(lhs_type, rhs_type, comparison_coercion)) .or_else(|| null_coercion(lhs_type, rhs_type)) - .or_else(|| string_numeric_coercion_as_numeric(lhs_type, rhs_type)) + .or_else(|| string_numeric_coercion(lhs_type, rhs_type)) + .or_else(|| string_temporal_coercion(lhs_type, rhs_type)) + .or_else(|| binary_coercion(lhs_type, rhs_type)) + .or_else(|| struct_coercion(lhs_type, rhs_type, comparison_coercion)) + .or_else(|| map_coercion(lhs_type, rhs_type, comparison_coercion)) } -/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation -/// where one is numeric and one is `Utf8`/`LargeUtf8`. +/// Coerce a numeric/string pair to the numeric type. +/// +/// Used by [`comparison_coercion`]. 
fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { - (Utf8, _) if rhs_type.is_numeric() => Some(Utf8), - (LargeUtf8, _) if rhs_type.is_numeric() => Some(LargeUtf8), - (Utf8View, _) if rhs_type.is_numeric() => Some(Utf8View), - (_, Utf8) if lhs_type.is_numeric() => Some(Utf8), - (_, LargeUtf8) if lhs_type.is_numeric() => Some(LargeUtf8), - (_, Utf8View) if lhs_type.is_numeric() => Some(Utf8View), + (lhs, Utf8 | LargeUtf8 | Utf8View) if lhs.is_numeric() => Some(lhs.clone()), + (Utf8 | LargeUtf8 | Utf8View, rhs) if rhs.is_numeric() => Some(rhs.clone()), _ => None, } } -/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation -/// where one is numeric and one is `Utf8`/`LargeUtf8`. -fn string_numeric_coercion_as_numeric( +/// Coerce a numeric/string pair to the string type. +/// +/// Used by [`type_union_coercion`]. +fn string_numeric_union_coercion( lhs_type: &DataType, rhs_type: &DataType, ) -> Option { - let lhs_logical_type = NativeType::from(lhs_type); - let rhs_logical_type = NativeType::from(rhs_type); - if lhs_logical_type.is_numeric() && rhs_logical_type == NativeType::String { - return Some(lhs_type.to_owned()); - } - if rhs_logical_type.is_numeric() && lhs_logical_type == NativeType::String { - return Some(rhs_type.to_owned()); + use arrow::datatypes::DataType::*; + match (lhs_type, rhs_type) { + (lhs @ (Utf8 | LargeUtf8 | Utf8View), _) if rhs_type.is_numeric() => { + Some(lhs.clone()) + } + (_, rhs @ (Utf8 | LargeUtf8 | Utf8View)) if lhs_type.is_numeric() => { + Some(rhs.clone()) + } + _ => None, } - - None } /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation @@ -953,7 +951,7 @@ fn string_numeric_coercion_as_numeric( /// ``` /// /// In the absence of a full type inference system, we can't determine the correct type -/// to parse the string argument +/// to parse the string 
argument. fn string_temporal_coercion( lhs_type: &DataType, rhs_type: &DataType, @@ -1230,7 +1228,13 @@ fn coerce_numeric_type_to_decimal256(numeric_type: &DataType) -> Option Option { +/// Coerce two struct types by recursively coercing their fields using +/// `coerce_fn` (either [`comparison_coercion`] or [`type_union_coercion`]). +fn struct_coercion( + lhs_type: &DataType, + rhs_type: &DataType, + coerce_fn: fn(&DataType, &DataType) -> Option, +) -> Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { @@ -1254,10 +1258,10 @@ fn struct_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option // // See docs/source/user-guide/sql/struct_coercion.md for detailed examples. if fields_have_same_names(lhs_fields, rhs_fields) { - return coerce_struct_by_name(lhs_fields, rhs_fields); + return coerce_struct_by_name(lhs_fields, rhs_fields, coerce_fn); } - coerce_struct_by_position(lhs_fields, rhs_fields) + coerce_struct_by_position(lhs_fields, rhs_fields, coerce_fn) } _ => None, } @@ -1300,8 +1304,13 @@ fn fields_have_same_names(lhs_fields: &Fields, rhs_fields: &Fields) -> bool { .all(|lf| rhs_names.contains(lf.name().as_str())) } -/// Coerce two structs by matching fields by name. Assumes the name-sets match. -fn coerce_struct_by_name(lhs_fields: &Fields, rhs_fields: &Fields) -> Option { +/// Coerce two structs by matching fields by name using `coerce_fn`. +/// Assumes the name-sets match. 
+fn coerce_struct_by_name( + lhs_fields: &Fields, + rhs_fields: &Fields, + coerce_fn: fn(&DataType, &DataType) -> Option, +) -> Option { use arrow::datatypes::DataType::*; let rhs_by_name: HashMap<&str, &FieldRef> = @@ -1311,7 +1320,7 @@ fn coerce_struct_by_name(lhs_fields: &Fields, rhs_fields: &Fields) -> Option Option Option, ) -> Option { use arrow::datatypes::DataType::*; @@ -1335,7 +1345,7 @@ fn coerce_struct_by_position( let coerced_types: Vec = lhs_fields .iter() .zip(rhs_fields.iter()) - .map(|(l, r)| comparison_coercion(l.data_type(), r.data_type())) + .map(|(l, r)| coerce_fn(l.data_type(), r.data_type())) .collect::>>()?; // Build final fields preserving left-side names and combined nullability. @@ -1356,13 +1366,17 @@ fn coerce_fields(common_type: DataType, lhs: &FieldRef, rhs: &FieldRef) -> Field Arc::new(Field::new(name, common_type, is_nullable)) } -/// coerce two types if they are Maps by coercing their inner 'entries' fields' types -/// using struct coercion -fn map_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { +/// Coerce two Map types by coercing their inner entry fields using +/// `coerce_fn` (either [`comparison_coercion`] or [`type_union_coercion`]). +fn map_coercion( + lhs_type: &DataType, + rhs_type: &DataType, + coerce_fn: fn(&DataType, &DataType) -> Option, +) -> Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { (Map(lhs_field, lhs_ordered), Map(rhs_field, rhs_ordered)) => { - struct_coercion(lhs_field.data_type(), rhs_field.data_type()).map( + struct_coercion(lhs_field.data_type(), rhs_field.data_type(), coerce_fn).map( |key_value_type| { Map( Arc::new((**lhs_field).clone().with_data_type(key_value_type)), @@ -1483,15 +1497,12 @@ fn both_numeric_or_null_and_numeric(lhs_type: &DataType, rhs_type: &DataType) -> } } -/// Generic coercion rules for Dictionaries: the type that both lhs and rhs -/// can be casted to for the purpose of a computation. 
-/// -/// Not all operators support dictionaries, if `preserve_dictionaries` is true -/// dictionaries will be preserved if possible. +/// Coerce two Dictionary types by coercing their value types using +/// `coerce_fn` (either [`comparison_coercion`] or [`type_union_coercion`]). /// -/// The `coerce_fn` parameter determines which comparison coercion function to use -/// for comparing the dictionary value types. -fn dictionary_comparison_coercion_generic( +/// If `preserve_dictionaries` is true, dictionaries will be preserved +/// when possible. +fn dictionary_coercion( lhs_type: &DataType, rhs_type: &DataType, preserve_dictionaries: bool, @@ -1515,52 +1526,11 @@ fn dictionary_comparison_coercion_generic( } } -/// Coercion rules for Dictionaries: the type that both lhs and rhs -/// can be casted to for the purpose of a computation. -/// -/// Not all operators support dictionaries, if `preserve_dictionaries` is true -/// dictionaries will be preserved if possible -fn dictionary_comparison_coercion( - lhs_type: &DataType, - rhs_type: &DataType, - preserve_dictionaries: bool, -) -> Option { - dictionary_comparison_coercion_generic( - lhs_type, - rhs_type, - preserve_dictionaries, - comparison_coercion, - ) -} - -/// Coercion rules for Dictionaries with numeric preference: similar to -/// [`dictionary_comparison_coercion`] but uses [`comparison_coercion_numeric`] -/// which prefers numeric types over strings when both are present. -/// -/// This is used by [`comparison_coercion_numeric`] to maintain consistent -/// numeric-preferring semantics when dealing with dictionary types. -fn dictionary_comparison_coercion_numeric( - lhs_type: &DataType, - rhs_type: &DataType, - preserve_dictionaries: bool, -) -> Option { - dictionary_comparison_coercion_generic( - lhs_type, - rhs_type, - preserve_dictionaries, - comparison_coercion_numeric, - ) -} - -/// Coercion rules for RunEndEncoded: the type that both lhs and rhs -/// can be casted to for the purpose of a computation. 
+/// Coerce two RunEndEncoded types using `coerce_fn` +/// (either [`comparison_coercion`] or [`type_union_coercion`]). /// -/// Not all operators support REE, if `preserve_ree` is true -/// REE will be preserved if possible -/// -/// The `coerce_fn` parameter determines which comparison coercion function to use -/// for comparing the REE value types. -fn ree_comparison_coercion_generic( +/// If `preserve_ree` is true, REE will be preserved when possible. +fn ree_coercion( lhs_type: &DataType, rhs_type: &DataType, preserve_ree: bool, @@ -1587,38 +1557,6 @@ fn ree_comparison_coercion_generic( } } -/// Coercion rules for RunEndEncoded: the type that both lhs and rhs -/// can be casted to for the purpose of a computation. -/// -/// Not all operators support REE, if `preserve_ree` is true -/// REE will be preserved if possible -fn ree_comparison_coercion( - lhs_type: &DataType, - rhs_type: &DataType, - preserve_ree: bool, -) -> Option { - ree_comparison_coercion_generic(lhs_type, rhs_type, preserve_ree, comparison_coercion) -} - -/// Coercion rules for RunEndEncoded with numeric preference: similar to -/// [`ree_comparison_coercion`] but uses [`comparison_coercion_numeric`] -/// which prefers numeric types over strings when both are present. -/// -/// This is used by [`comparison_coercion_numeric`] to maintain consistent -/// numeric-preferring semantics when dealing with REE types. -fn ree_comparison_coercion_numeric( - lhs_type: &DataType, - rhs_type: &DataType, - preserve_ree: bool, -) -> Option { - ree_comparison_coercion_generic( - lhs_type, - rhs_type, - preserve_ree, - comparison_coercion_numeric, - ) -} - /// Coercion rules for string concat. /// This is a union of string coercion rules and specified rules: /// 1. 
At least one side of lhs and rhs should be string type (Utf8 / LargeUtf8) @@ -1683,32 +1621,28 @@ pub fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { - use arrow::datatypes::DataType::*; - match (lhs_type, rhs_type) { - (Utf8 | LargeUtf8 | Utf8View, other_type) - | (other_type, Utf8 | LargeUtf8 | Utf8View) - if other_type.is_numeric() => - { - Some(other_type.clone()) - } - _ => None, - } -} - -/// Coerces two fields together, ensuring the field data (name and nullability) is correctly set. -fn coerce_list_children(lhs_field: &FieldRef, rhs_field: &FieldRef) -> Option { - let data_types = vec![lhs_field.data_type().clone(), rhs_field.data_type().clone()]; +/// Coerce two list element fields to a common type using the provided +/// coercion function for element types. +fn coerce_list_children( + lhs_field: &FieldRef, + rhs_field: &FieldRef, + coerce_fn: fn(&DataType, &DataType) -> Option, +) -> Option { Some(Arc::new( (**lhs_field) .clone() - .with_data_type(type_union_resolution(&data_types)?) + .with_data_type(coerce_fn(lhs_field.data_type(), rhs_field.data_type())?) .with_nullable(lhs_field.is_nullable() || rhs_field.is_nullable()), )) } -/// Coercion rules for list types. -fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { +/// Coerce two list types by coercing their element types via `coerce_fn` +/// (either [`comparison_coercion`] or [`type_union_coercion`]). 
+fn list_coercion( + lhs_type: &DataType, + rhs_type: &DataType, + coerce_fn: fn(&DataType, &DataType) -> Option, +) -> Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { // Coerce to the left side FixedSizeList type if the list lengths are the same, @@ -1716,11 +1650,11 @@ fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { (FixedSizeList(lhs_field, ls), FixedSizeList(rhs_field, rs)) => { if ls == rs { Some(FixedSizeList( - coerce_list_children(lhs_field, rhs_field)?, + coerce_list_children(lhs_field, rhs_field, coerce_fn)?, *rs, )) } else { - Some(List(coerce_list_children(lhs_field, rhs_field)?)) + Some(List(coerce_list_children(lhs_field, rhs_field, coerce_fn)?)) } } // LargeList on any side @@ -1728,13 +1662,13 @@ fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { LargeList(lhs_field), List(rhs_field) | LargeList(rhs_field) | FixedSizeList(rhs_field, _), ) - | (List(lhs_field) | FixedSizeList(lhs_field, _), LargeList(rhs_field)) => { - Some(LargeList(coerce_list_children(lhs_field, rhs_field)?)) - } + | (List(lhs_field) | FixedSizeList(lhs_field, _), LargeList(rhs_field)) => Some( + LargeList(coerce_list_children(lhs_field, rhs_field, coerce_fn)?), + ), // Lists on both sides (List(lhs_field), List(rhs_field) | FixedSizeList(rhs_field, _)) | (FixedSizeList(lhs_field, _), List(rhs_field)) => { - Some(List(coerce_list_children(lhs_field, rhs_field)?)) + Some(List(coerce_list_children(lhs_field, rhs_field, coerce_fn)?)) } _ => None, } @@ -1803,8 +1737,8 @@ fn binary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option pub fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { string_coercion(lhs_type, rhs_type) .or_else(|| binary_to_string_coercion(lhs_type, rhs_type)) - .or_else(|| dictionary_comparison_coercion(lhs_type, rhs_type, false)) - .or_else(|| ree_comparison_coercion(lhs_type, rhs_type, false)) + .or_else(|| dictionary_coercion(lhs_type, rhs_type, false, 
string_coercion)) + .or_else(|| ree_coercion(lhs_type, rhs_type, false, string_coercion)) .or_else(|| regex_null_coercion(lhs_type, rhs_type)) .or_else(|| null_coercion(lhs_type, rhs_type)) } @@ -1821,10 +1755,11 @@ fn regex_null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { string_coercion(lhs_type, rhs_type) - .or_else(|| dictionary_comparison_coercion(lhs_type, rhs_type, false)) + .or_else(|| dictionary_coercion(lhs_type, rhs_type, false, string_coercion)) + .or_else(|| ree_coercion(lhs_type, rhs_type, false, string_coercion)) .or_else(|| regex_null_coercion(lhs_type, rhs_type)) } diff --git a/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs b/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs index 5d1b3bea75b0a..317b022238f67 100644 --- a/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs +++ b/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs @@ -654,7 +654,7 @@ fn test_list_coercion() { let rhs_type = DataType::List(Arc::new(Field::new("rhs", DataType::Int64, true))); - let coerced_type = list_coercion(&lhs_type, &rhs_type).unwrap(); + let coerced_type = list_coercion(&lhs_type, &rhs_type, comparison_coercion).unwrap(); assert_eq!( coerced_type, DataType::List(Arc::new(Field::new("lhs", DataType::Int64, true))) @@ -791,3 +791,109 @@ fn test_decimal_cross_variant_comparison_coercion() -> Result<()> { Ok(()) } + +/// Tests that `comparison_coercion` prefers the numeric type when one side is +/// numeric and the other is a string (e.g., `numeric_col < '123'`). 
+#[test] +fn test_comparison_coercion_prefers_numeric() { + assert_eq!( + comparison_coercion(&DataType::Int32, &DataType::Utf8), + Some(DataType::Int32) + ); + assert_eq!( + comparison_coercion(&DataType::Utf8, &DataType::Int32), + Some(DataType::Int32) + ); + assert_eq!( + comparison_coercion(&DataType::Utf8, &DataType::Float64), + Some(DataType::Float64) + ); + assert_eq!( + comparison_coercion(&DataType::Float64, &DataType::Utf8), + Some(DataType::Float64) + ); + assert_eq!( + comparison_coercion(&DataType::Int64, &DataType::LargeUtf8), + Some(DataType::Int64) + ); + assert_eq!( + comparison_coercion(&DataType::Utf8View, &DataType::Int16), + Some(DataType::Int16) + ); + // String-string stays string + assert_eq!( + comparison_coercion(&DataType::Utf8, &DataType::Utf8), + Some(DataType::Utf8) + ); + // Numeric-numeric stays numeric + assert_eq!( + comparison_coercion(&DataType::Int32, &DataType::Int64), + Some(DataType::Int64) + ); +} + +/// Tests that `type_union_coercion` prefers the string type when unifying +/// numeric and string types (for UNION, CASE THEN/ELSE, etc.). 
+#[test] +fn test_type_union_coercion_prefers_string() { + assert_eq!( + type_union_coercion(&DataType::Int32, &DataType::Utf8), + Some(DataType::Utf8) + ); + assert_eq!( + type_union_coercion(&DataType::Utf8, &DataType::Int32), + Some(DataType::Utf8) + ); + assert_eq!( + type_union_coercion(&DataType::Float64, &DataType::Utf8), + Some(DataType::Utf8) + ); + assert_eq!( + type_union_coercion(&DataType::Utf8, &DataType::Float64), + Some(DataType::Utf8) + ); + assert_eq!( + type_union_coercion(&DataType::Int64, &DataType::LargeUtf8), + Some(DataType::LargeUtf8) + ); + assert_eq!( + type_union_coercion(&DataType::Utf8View, &DataType::Int16), + Some(DataType::Utf8View) + ); + // String-string stays string + assert_eq!( + type_union_coercion(&DataType::Utf8, &DataType::Utf8), + Some(DataType::Utf8) + ); + // Numeric-numeric stays numeric + assert_eq!( + type_union_coercion(&DataType::Int32, &DataType::Int64), + Some(DataType::Int64) + ); +} + +/// Tests that comparison operators coerce to numeric when comparing +/// numeric and string types. 
+#[test] +fn test_binary_comparison_string_numeric_coercion() -> Result<()> { + let comparison_ops = [ + Operator::Eq, + Operator::NotEq, + Operator::Lt, + Operator::LtEq, + Operator::Gt, + Operator::GtEq, + ]; + for op in &comparison_ops { + let (lhs, rhs) = BinaryTypeCoercer::new(&DataType::Int64, op, &DataType::Utf8) + .get_input_types()?; + assert_eq!(lhs, DataType::Int64, "Op {op}: Int64 vs Utf8 -> lhs"); + assert_eq!(rhs, DataType::Int64, "Op {op}: Int64 vs Utf8 -> rhs"); + + let (lhs, rhs) = BinaryTypeCoercer::new(&DataType::Utf8, op, &DataType::Float64) + .get_input_types()?; + assert_eq!(lhs, DataType::Float64, "Op {op}: Utf8 vs Float64 -> lhs"); + assert_eq!(rhs, DataType::Float64, "Op {op}: Utf8 vs Float64 -> rhs"); + } + Ok(()) +} diff --git a/datafusion/expr-common/src/type_coercion/binary/tests/dictionary.rs b/datafusion/expr-common/src/type_coercion/binary/tests/dictionary.rs index 0fb56a4a2c536..f0aadfd3ce3a5 100644 --- a/datafusion/expr-common/src/type_coercion/binary/tests/dictionary.rs +++ b/datafusion/expr-common/src/type_coercion/binary/tests/dictionary.rs @@ -24,49 +24,49 @@ fn test_dictionary_type_coercion() { let lhs_type = Dictionary(Box::new(Int8), Box::new(Int32)); let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16)); assert_eq!( - dictionary_comparison_coercion(&lhs_type, &rhs_type, true), + dictionary_coercion(&lhs_type, &rhs_type, true, comparison_coercion), Some(Int32) ); assert_eq!( - dictionary_comparison_coercion(&lhs_type, &rhs_type, false), + dictionary_coercion(&lhs_type, &rhs_type, false, comparison_coercion), Some(Int32) ); - // Since we can coerce values of Int16 to Utf8 can support this + // In comparison context, numeric is preferred over string let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16)); assert_eq!( - dictionary_comparison_coercion(&lhs_type, &rhs_type, true), - Some(Utf8) + dictionary_coercion(&lhs_type, &rhs_type, true, comparison_coercion), 
+ Some(Int16) ); // Since we can coerce values of Utf8 to Binary can support this let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); let rhs_type = Dictionary(Box::new(Int8), Box::new(Binary)); assert_eq!( - dictionary_comparison_coercion(&lhs_type, &rhs_type, true), + dictionary_coercion(&lhs_type, &rhs_type, true, comparison_coercion), Some(Binary) ); let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); let rhs_type = Utf8; assert_eq!( - dictionary_comparison_coercion(&lhs_type, &rhs_type, false), + dictionary_coercion(&lhs_type, &rhs_type, false, comparison_coercion), Some(Utf8) ); assert_eq!( - dictionary_comparison_coercion(&lhs_type, &rhs_type, true), + dictionary_coercion(&lhs_type, &rhs_type, true, comparison_coercion), Some(lhs_type.clone()) ); let lhs_type = Utf8; let rhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); assert_eq!( - dictionary_comparison_coercion(&lhs_type, &rhs_type, false), + dictionary_coercion(&lhs_type, &rhs_type, false, comparison_coercion), Some(Utf8) ); assert_eq!( - dictionary_comparison_coercion(&lhs_type, &rhs_type, true), + dictionary_coercion(&lhs_type, &rhs_type, true, comparison_coercion), Some(rhs_type.clone()) ); } diff --git a/datafusion/expr-common/src/type_coercion/binary/tests/run_end_encoded.rs b/datafusion/expr-common/src/type_coercion/binary/tests/run_end_encoded.rs index 9997db7a82688..dab42214d755c 100644 --- a/datafusion/expr-common/src/type_coercion/binary/tests/run_end_encoded.rs +++ b/datafusion/expr-common/src/type_coercion/binary/tests/run_end_encoded.rs @@ -30,15 +30,15 @@ fn test_ree_type_coercion() { Arc::new(Field::new("values", Int16, false)), ); assert_eq!( - ree_comparison_coercion(&lhs_type, &rhs_type, true), + ree_coercion(&lhs_type, &rhs_type, true, comparison_coercion), Some(Int32) ); assert_eq!( - ree_comparison_coercion(&lhs_type, &rhs_type, false), + ree_coercion(&lhs_type, &rhs_type, false, comparison_coercion), Some(Int32) ); - // Since we can coerce values of Int16 to Utf8 
can support this: Coercion of Int16 to Utf8 + // In comparison context, numeric is preferred over string let lhs_type = RunEndEncoded( Arc::new(Field::new("run_ends", Int8, false)), Arc::new(Field::new("values", Utf8, false)), @@ -48,8 +48,8 @@ fn test_ree_type_coercion() { Arc::new(Field::new("values", Int16, false)), ); assert_eq!( - ree_comparison_coercion(&lhs_type, &rhs_type, true), - Some(Utf8) + ree_coercion(&lhs_type, &rhs_type, true, comparison_coercion), + Some(Int16) ); // Since we can coerce values of Utf8 to Binary can support this @@ -62,7 +62,7 @@ fn test_ree_type_coercion() { Arc::new(Field::new("values", Binary, false)), ); assert_eq!( - ree_comparison_coercion(&lhs_type, &rhs_type, true), + ree_coercion(&lhs_type, &rhs_type, true, comparison_coercion), Some(Binary) ); let lhs_type = RunEndEncoded( @@ -72,12 +72,12 @@ fn test_ree_type_coercion() { let rhs_type = Utf8; // Don't preserve REE assert_eq!( - ree_comparison_coercion(&lhs_type, &rhs_type, false), + ree_coercion(&lhs_type, &rhs_type, false, comparison_coercion), Some(Utf8) ); // Preserve REE assert_eq!( - ree_comparison_coercion(&lhs_type, &rhs_type, true), + ree_coercion(&lhs_type, &rhs_type, true, comparison_coercion), Some(lhs_type.clone()) ); @@ -88,12 +88,12 @@ fn test_ree_type_coercion() { ); // Don't preserve REE assert_eq!( - ree_comparison_coercion(&lhs_type, &rhs_type, false), + ree_coercion(&lhs_type, &rhs_type, false, comparison_coercion), Some(Utf8) ); // Preserve REE assert_eq!( - ree_comparison_coercion(&lhs_type, &rhs_type, true), + ree_coercion(&lhs_type, &rhs_type, true, comparison_coercion), Some(rhs_type.clone()) ); } diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index e9f41670bd7a9..5381313e2ee9b 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -2132,6 +2132,8 @@ pub fn wrap_projection_for_join_if_necessary( .into_iter() .map(Expr::Column) 
.collect::>(); + #[allow(clippy::allow_attributes, clippy::mutable_key_type)] + // Expr contains Arc with interior mutability but is intentionally used as hash key let join_key_items = alias_join_keys .iter() .flat_map(|expr| expr.try_as_col().is_none().then_some(expr)) diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 4af2d5ba78c55..654e790667ead 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -33,7 +33,7 @@ use datafusion_expr_common::signature::ArrayFunctionArgument; use datafusion_expr_common::type_coercion::binary::type_union_resolution; use datafusion_expr_common::{ signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD}, - type_coercion::binary::comparison_coercion_numeric, + type_coercion::binary::comparison_coercion, type_coercion::binary::string_coercion, }; use itertools::Itertools as _; @@ -593,7 +593,7 @@ fn get_valid_types( function_length_check(function_name, current_types.len(), *num)?; let mut target_type = current_types[0].to_owned(); for data_type in current_types.iter().skip(1) { - if let Some(dt) = comparison_coercion_numeric(&target_type, data_type) { + if let Some(dt) = comparison_coercion(&target_type, data_type) { target_type = dt; } else { return plan_err!( diff --git a/datafusion/expr/src/type_coercion/other.rs b/datafusion/expr/src/type_coercion/other.rs index 634558094ae79..48125b661e2ca 100644 --- a/datafusion/expr/src/type_coercion/other.rs +++ b/datafusion/expr/src/type_coercion/other.rs @@ -17,38 +17,58 @@ use arrow::datatypes::DataType; -use super::binary::comparison_coercion; +use super::binary::{comparison_coercion, type_union_coercion}; + +/// Fold `coerce_fn` over `types`, starting from `initial_type`. 
+fn fold_coerce( + initial_type: &DataType, + types: &[DataType], + coerce_fn: fn(&DataType, &DataType) -> Option, +) -> Option { + types + .iter() + .try_fold(initial_type.clone(), |left_type, right_type| { + coerce_fn(&left_type, right_type) + }) +} /// Attempts to coerce the types of `list_types` to be comparable with the -/// `expr_type`. -/// Returns the common data type for `expr_type` and `list_types` +/// `expr_type` for IN list predicates. +/// Returns the common data type for `expr_type` and `list_types`. +/// +/// Uses comparison coercion because `x IN (a, b)` is semantically equivalent +/// to `x = a OR x = b`. pub fn get_coerce_type_for_list( expr_type: &DataType, list_types: &[DataType], ) -> Option { - list_types - .iter() - .try_fold(expr_type.clone(), |left_type, right_type| { - comparison_coercion(&left_type, right_type) - }) + fold_coerce(expr_type, list_types, comparison_coercion) +} + +/// Find a common coerceable type for `CASE expr WHEN val1 WHEN val2 ...` +/// conditions. Returns the common type for `case_type` and all `when_types`. +/// +/// Uses comparison coercion because `CASE expr WHEN val` is semantically +/// equivalent to `expr = val`. +pub fn get_coerce_type_for_case_when( + when_types: &[DataType], + case_type: &DataType, +) -> Option { + fold_coerce(case_type, when_types, comparison_coercion) } -/// Find a common coerceable type for all `when_or_then_types` as well -/// and the `case_or_else_type`, if specified. -/// Returns the common data type for `when_or_then_types` and `case_or_else_type` +/// Find a common coerceable type for CASE THEN/ELSE result expressions. +/// Returns the common data type for `then_types` and `else_type`. +/// +/// Uses type union coercion because the result branches must be brought to a +/// common type (like UNION), not compared. 
pub fn get_coerce_type_for_case_expression( - when_or_then_types: &[DataType], - case_or_else_type: Option<&DataType>, + then_types: &[DataType], + else_type: Option<&DataType>, ) -> Option { - let case_or_else_type = match case_or_else_type { - None => when_or_then_types[0].clone(), - Some(data_type) => data_type.clone(), + let (initial_type, remaining) = match else_type { + None => then_types.split_first()?, + Some(data_type) => (data_type, then_types), }; - when_or_then_types - .iter() - .try_fold(case_or_else_type, |left_type, right_type| { - // TODO: now just use the `equal` coercion rule for case when. If find the issue, and - // refactor again. - comparison_coercion(&left_type, right_type) - }) + fold_coerce(initial_type, remaining, type_union_coercion) } diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index 322932dcbfb8c..00c0ee3cd47b0 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -285,6 +285,7 @@ impl Accumulator for MedianAccumulator { size_of_val(self) + self.all_values.capacity() * size_of::() } + #[allow(clippy::allow_attributes, clippy::mutable_key_type)] // ScalarValue has interior mutability but is intentionally used as hash key fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let mut to_remove: HashMap = HashMap::new(); diff --git a/datafusion/functions-aggregate/src/percentile_cont.rs b/datafusion/functions-aggregate/src/percentile_cont.rs index 8a37ceef511b2..bbf6821ba09dd 100644 --- a/datafusion/functions-aggregate/src/percentile_cont.rs +++ b/datafusion/functions-aggregate/src/percentile_cont.rs @@ -440,6 +440,7 @@ where size_of_val(self) + self.all_values.capacity() * size_of::() } + #[allow(clippy::allow_attributes, clippy::mutable_key_type)] // ScalarValue has interior mutability but is intentionally used as hash key fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let mut to_remove: 
HashMap = HashMap::new(); for i in 0..values[0].len() { diff --git a/datafusion/functions/src/core/greatest_least_utils.rs b/datafusion/functions/src/core/greatest_least_utils.rs index 5f8b4a51186fe..2714a01832175 100644 --- a/datafusion/functions/src/core/greatest_least_utils.rs +++ b/datafusion/functions/src/core/greatest_least_utils.rs @@ -20,7 +20,7 @@ use arrow::compute::kernels::zip::zip; use arrow::datatypes::DataType; use datafusion_common::{Result, ScalarValue, assert_or_internal_err, plan_err}; use datafusion_expr_common::columnar_value::ColumnarValue; -use datafusion_expr_common::type_coercion::binary::type_union_resolution; +use datafusion_expr_common::type_coercion::binary::comparison_coercion; use std::sync::Arc; pub(super) trait GreatestLeastOperator { @@ -120,13 +120,17 @@ pub(super) fn find_coerced_type( data_types: &[DataType], ) -> Result { if data_types.is_empty() { - plan_err!( + return plan_err!( "{} was called without any arguments. It requires at least 1.", Op::NAME - ) - } else if let Some(coerced_type) = type_union_resolution(data_types) { - Ok(coerced_type) - } else { - plan_err!("Cannot find a common type for arguments") + ); + } + let mut coerced = data_types[0].clone(); + for dt in &data_types[1..] 
{ + let Some(next) = comparison_coercion(&coerced, dt) else { + return plan_err!("Cannot find a common type for arguments to {}", Op::NAME); + }; + coerced = next; } + Ok(coerced) } diff --git a/datafusion/functions/src/core/nvl2.rs b/datafusion/functions/src/core/nvl2.rs index f5c1e9310ff0d..d68296d9b862b 100644 --- a/datafusion/functions/src/core/nvl2.rs +++ b/datafusion/functions/src/core/nvl2.rs @@ -22,7 +22,7 @@ use datafusion_expr::{ ScalarUDFImpl, Signature, Volatility, conditional_expressions::CaseBuilder, simplify::{ExprSimplifyResult, SimplifyContext}, - type_coercion::binary::comparison_coercion, + type_coercion::binary::type_union_coercion, }; use datafusion_macros::user_doc; @@ -129,11 +129,9 @@ impl ScalarUDFImpl for NVL2Func { [if_non_null, if_null] .iter() .try_fold(tested.clone(), |acc, x| { - // The coerced types found by `comparison_coercion` are not guaranteed to be - // coercible for the arguments. `comparison_coercion` returns more loose - // types that can be coerced to both `acc` and `x` for comparison purpose. - // See `maybe_data_types` for the actual coercion. - let coerced_type = comparison_coercion(&acc, x); + // `type_union_coercion` finds a loose common type; the actual + // coercion is done by `maybe_data_types`. 
+ let coerced_type = type_union_coercion(&acc, x); if let Some(coerced_type) = coerced_type { Ok(coerced_type) } else { diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs index 747c54e2cd26d..6b8ae3e8531bc 100644 --- a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs +++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs @@ -72,6 +72,7 @@ fn group_expr_to_bitmap_index(group_expr: &[Expr]) -> Result>()) } +#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // Expr contains Arc with interior mutability but is intentionally used as hash key fn replace_grouping_exprs( input: Arc, schema: &DFSchema, @@ -158,6 +159,7 @@ fn contains_grouping_function(exprs: &[Expr]) -> bool { } /// Validate that the arguments to the grouping function are in the group by clause. +#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // Expr contains Arc with interior mutability but is intentionally used as hash key fn validate_args( function: &AggregateFunction, group_by_expr: &HashMap<&Expr, usize>, @@ -178,6 +180,7 @@ fn validate_args( } } +#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // Expr contains Arc with interior mutability but is intentionally used as hash key fn grouping_function_on_id( function: &AggregateFunction, group_by_expr: &HashMap<&Expr, usize>, diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index bf91595e956c3..253428288ff49 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -41,10 +41,13 @@ use datafusion_expr::expr::{ use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; use datafusion_expr::expr_schema::cast_subquery; use datafusion_expr::logical_plan::Subquery; -use datafusion_expr::type_coercion::binary::{comparison_coercion, like_coercion}; +use 
datafusion_expr::type_coercion::binary::{ + comparison_coercion, like_coercion, type_union_coercion, +}; use datafusion_expr::type_coercion::functions::{UDFCoercionExt, fields_with_udf}; use datafusion_expr::type_coercion::other::{ - get_coerce_type_for_case_expression, get_coerce_type_for_list, + get_coerce_type_for_case_expression, get_coerce_type_for_case_when, + get_coerce_type_for_list, }; use datafusion_expr::type_coercion::{ is_datetime, is_interval, is_signed_numeric, is_timestamp, @@ -1043,8 +1046,7 @@ fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result { .iter() .map(|(when, _then)| when.get_type(schema)) .collect::>>()?; - let coerced_type = - get_coerce_type_for_case_expression(&when_types, Some(case_type)); + let coerced_type = get_coerce_type_for_case_when(&when_types, case_type); coerced_type.ok_or_else(|| { plan_datafusion_err!( "Failed to coerce case ({case_type}) and when ({}) \ @@ -1122,7 +1124,7 @@ fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result { /// **Field-level metadata merging**: Later fields take precedence for duplicate metadata keys. /// /// **Type coercion precedence**: The coerced type is determined by iteratively applying -/// `comparison_coercion()` between the accumulated type and each new input's type. The +/// `type_union_coercion()` between the accumulated type and each new input's type. The /// result depends on type coercion rules, not input order. /// /// **Nullability merging**: Nullability is accumulated using logical OR (`||`). 
@@ -1145,7 +1147,7 @@ fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result { /// ``` /// /// **Precedence Summary**: -/// - **Datatypes**: Determined by `comparison_coercion()` rules, not input order +/// - **Datatypes**: Determined by `type_union_coercion()` rules, not input order /// - **Nullability**: Later inputs can add nullability but cannot remove it (logical OR) /// - **Metadata**: Later inputs take precedence for same keys (HashMap::extend semantics) pub fn coerce_union_schema(inputs: &[Arc]) -> Result { @@ -1195,7 +1197,7 @@ fn coerce_union_schema_with_schema( plan_schema.fields().iter() ) { let coerced_type = - comparison_coercion(union_datatype, plan_field.data_type()).ok_or_else( + type_union_coercion(union_datatype, plan_field.data_type()).ok_or_else( || { plan_datafusion_err!( "Incompatible inputs for Union: Previous inputs were \ @@ -2367,6 +2369,9 @@ mod test { let actual = coerce_case_expression(case, &schema)?; assert_eq!(expected, actual); + // CASE string WHEN float/integer/string: comparison coercion + // prefers numeric, so the common type for the CASE expr and + // WHEN values is Float32. 
let case = Case { expr: Some(Box::new(col("string"))), when_then_expr: vec![ @@ -2376,7 +2381,7 @@ mod test { ], else_expr: Some(Box::new(col("string"))), }; - let case_when_common_type = Utf8; + let case_when_common_type = DataType::Float32; let then_else_common_type = Utf8; let expected = cast_helper( case.clone(), diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 590b00098bd46..c54cd287dbb46 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -144,7 +144,11 @@ impl OptimizerRule for ScalarSubqueryToJoin { } let mut all_subqueries = vec![]; + #[allow(clippy::allow_attributes, clippy::mutable_key_type)] + // Expr contains Arc with interior mutability but is intentionally used as hash key let mut expr_to_rewrite_expr_map = HashMap::new(); + #[allow(clippy::allow_attributes, clippy::mutable_key_type)] + // Expr contains Arc with interior mutability but is intentionally used as hash key let mut subquery_to_expr_map = HashMap::new(); for expr in projection.expr.iter() { let (subqueries, rewrite_exprs) = diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index e4455b8c82fc1..1af08c91c1109 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1784,6 +1784,8 @@ impl TreeNodeRewriter for Simplifier<'_> { }) if are_inlist_and_eq(left.as_ref(), right.as_ref()) => { let lhs = to_inlist(*left).unwrap(); let rhs = to_inlist(*right).unwrap(); + #[allow(clippy::allow_attributes, clippy::mutable_key_type)] + // Expr contains Arc with interior mutability but is intentionally used as hash key let mut seen: HashSet = HashSet::new(); let list = lhs .list @@ -2174,6 +2176,7 @@ impl<'a> StringScalar<'a> { } } +#[allow(clippy::allow_attributes, 
clippy::mutable_key_type)] // Expr contains Arc with interior mutability but is intentionally used as hash key fn has_common_conjunction(lhs: &Expr, rhs: &Expr) -> bool { let lhs_set: HashSet<&Expr> = iter_conjunction(lhs).collect(); iter_conjunction(rhs).any(|e| lhs_set.contains(&e) && !e.is_volatile()) @@ -2258,6 +2261,7 @@ fn to_inlist(expr: Expr) -> Option { /// Return the union of two inlist expressions /// maintaining the order of the elements in the two lists +#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // Expr contains Arc with interior mutability but is intentionally used as hash key fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result { // extend the list in l1 with the elements in l2 that are not already in l1 let l1_items: HashSet<_> = l1.list.iter().collect(); @@ -2276,6 +2280,7 @@ fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result { /// Return the intersection of two inlist expressions /// maintaining the order of the elements in the two lists +#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // Expr contains Arc with interior mutability but is intentionally used as hash key fn inlist_intersection(mut l1: InList, l2: &InList, negated: bool) -> Result { let l2_items = l2.list.iter().collect::>(); @@ -2292,6 +2297,7 @@ fn inlist_intersection(mut l1: InList, l2: &InList, negated: bool) -> Result Result { let l2_items = l2.list.iter().collect::>(); diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 35cfed228121e..4ac40df2201e5 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -1430,8 +1430,8 @@ mod tests { use arrow::datatypes::Field; use datafusion_common::cast::{as_float64_array, as_int32_array}; use datafusion_common::plan_err; - use datafusion_common::tree_node::TransformedResult; - use datafusion_expr::type_coercion::binary::comparison_coercion; + use 
datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; + use datafusion_expr::type_coercion::binary::type_union_coercion; use datafusion_expr_common::operator::Operator; use datafusion_physical_expr_common::physical_expr::fmt_sql; use half::f16; @@ -2381,9 +2381,7 @@ mod tests { thens_type .iter() .try_fold(else_type, |left_type, right_type| { - // TODO: now just use the `equal` coercion rule for case when. If find the issue, and - // refactor again. - comparison_coercion(&left_type, right_type) + type_union_coercion(&left_type, right_type) }) } diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs index c4ce74fd3a573..00195a4df6307 100644 --- a/datafusion/physical-expr/src/utils/guarantee.rs +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -93,6 +93,7 @@ impl LiteralGuarantee { /// Create a new instance of the guarantee if the provided operator is /// supported. Returns None otherwise. See [`LiteralGuarantee::analyze`] to /// create these structures from an predicate (boolean expression). 
+ #[allow(clippy::allow_attributes, clippy::mutable_key_type)] // ScalarValue has interior mutability but is intentionally used as hash key fn new<'a>( column_name: impl Into, guarantee: Guarantee, @@ -309,6 +310,7 @@ impl<'a> GuaranteeBuilder<'a> { /// * `AND (a IN (1,2,3))`: a is in (1, 2, or 3) /// * `AND (a != 1 OR a != 2 OR a != 3)`: a is not in (1, 2, or 3) /// * `AND (a NOT IN (1,2,3))`: a is not in (1, 2, or 3) + #[allow(clippy::allow_attributes, clippy::mutable_key_type)] // ScalarValue has interior mutability but is intentionally used as hash key fn aggregate_multi_conjunct( mut self, col: &'a crate::expressions::Column, diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs index 4686648fb1e3d..efaf7eba0f1b5 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs @@ -28,6 +28,7 @@ use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_expr::EmitTo; use half::f16; use hashbrown::hash_table::HashTable; +#[cfg(not(feature = "force_hash_collisions"))] use std::hash::BuildHasher; use std::mem::size_of; use std::sync::Arc; diff --git a/datafusion/physical-plan/src/spill/in_progress_spill_file.rs b/datafusion/physical-plan/src/spill/in_progress_spill_file.rs index 2666ab8822ed9..9084ea449d6b9 100644 --- a/datafusion/physical-plan/src/spill/in_progress_spill_file.rs +++ b/datafusion/physical-plan/src/spill/in_progress_spill_file.rs @@ -62,7 +62,11 @@ impl InProgressSpillFile { )); } if self.writer.is_none() { - let schema = batch.schema(); + // Use the SpillManager's declared schema rather than the batch's schema. + // Individual batches may have different schemas (e.g., different nullability) + // when they come from different branches of a UnionExec. 
The SpillManager's + // schema represents the canonical schema that all batches should conform to. + let schema = self.spill_writer.schema(); if let Some(in_progress_file) = &mut self.in_progress_file { self.writer = Some(IPCStreamWriter::new( in_progress_file.path(), @@ -138,3 +142,77 @@ impl InProgressSpillFile { Ok(self.in_progress_file.take()) } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Int64Array; + use arrow_schema::{DataType, Field, Schema}; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; + use datafusion_physical_expr_common::metrics::{ + ExecutionPlanMetricsSet, SpillMetrics, + }; + use futures::TryStreamExt; + + #[tokio::test] + async fn test_spill_file_uses_spill_manager_schema() -> Result<()> { + let nullable_schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::Int64, false), + Field::new("val", DataType::Int64, true), + ])); + let non_nullable_schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::Int64, false), + Field::new("val", DataType::Int64, false), + ])); + + let runtime = Arc::new(RuntimeEnvBuilder::new().build()?); + let metrics_set = ExecutionPlanMetricsSet::new(); + let spill_metrics = SpillMetrics::new(&metrics_set, 0); + let spill_manager = Arc::new(SpillManager::new( + runtime, + spill_metrics, + Arc::clone(&nullable_schema), + )); + + let mut in_progress = spill_manager.create_in_progress_file("test")?; + + // First batch: non-nullable val (simulates literal-0 UNION branch) + let non_nullable_batch = RecordBatch::try_new( + Arc::clone(&non_nullable_schema), + vec![ + Arc::new(Int64Array::from(vec![1, 2, 3])), + Arc::new(Int64Array::from(vec![0, 0, 0])), + ], + )?; + in_progress.append_batch(&non_nullable_batch)?; + + // Second batch: nullable val with NULLs (simulates table UNION branch) + let nullable_batch = RecordBatch::try_new( + Arc::clone(&nullable_schema), + vec![ + Arc::new(Int64Array::from(vec![4, 5, 6])), + Arc::new(Int64Array::from(vec![Some(10), None, Some(30)])), 
+ ], + )?; + in_progress.append_batch(&nullable_batch)?; + + let spill_file = in_progress.finish()?.unwrap(); + + let stream = spill_manager.read_spill_as_stream(spill_file, None)?; + + // Stream schema should be nullable + assert_eq!(stream.schema(), nullable_schema); + + let batches = stream.try_collect::>().await?; + assert_eq!(batches.len(), 2); + + // Both batches must have the SpillManager's nullable schema + assert_eq!( + batches[0], + non_nullable_batch.with_schema(Arc::clone(&nullable_schema))? + ); + assert_eq!(batches[1], nullable_batch); + + Ok(()) + } +} diff --git a/datafusion/pruning/src/pruning_predicate.rs b/datafusion/pruning/src/pruning_predicate.rs index 6f6b00e80abc2..978d79b1f2fb0 100644 --- a/datafusion/pruning/src/pruning_predicate.rs +++ b/datafusion/pruning/src/pruning_predicate.rs @@ -2158,6 +2158,7 @@ mod tests { } /// Add contained information. + #[allow(clippy::allow_attributes, clippy::mutable_key_type)] // ScalarValue has interior mutability but is intentionally used as hash key pub fn with_contained( mut self, values: impl IntoIterator, @@ -2172,6 +2173,7 @@ mod tests { } /// get any contained information for the specified values + #[allow(clippy::allow_attributes, clippy::mutable_key_type)] // ScalarValue has interior mutability but is intentionally used as hash key fn contained(&self, find_values: &HashSet) -> Option { // find the one with the matching values self.contained diff --git a/datafusion/spark/Cargo.toml b/datafusion/spark/Cargo.toml index bd5f2bb18aaec..0651fe5651de9 100644 --- a/datafusion/spark/Cargo.toml +++ b/datafusion/spark/Cargo.toml @@ -57,6 +57,7 @@ datafusion-functions = { workspace = true, features = ["crypto_expressions"] } datafusion-functions-aggregate = { workspace = true } datafusion-functions-nested = { workspace = true } log = { workspace = true } +num-traits = { workspace = true } percent-encoding = "2.3.2" rand = { workspace = true } serde_json = { workspace = true } diff --git 
a/datafusion/spark/src/function/conversion/cast.rs b/datafusion/spark/src/function/conversion/cast.rs index f5e6e36c84041..45d1b336261d7 100644 --- a/datafusion/spark/src/function/conversion/cast.rs +++ b/datafusion/spark/src/function/conversion/cast.rs @@ -17,12 +17,13 @@ use arrow::array::{Array, ArrayRef, AsArray, TimestampMicrosecondBuilder}; use arrow::datatypes::{ - ArrowPrimitiveType, DataType, Field, FieldRef, Int8Type, Int16Type, Int32Type, - Int64Type, TimeUnit, + ArrowPrimitiveType, DataType, Field, FieldRef, Float32Type, Float64Type, Int8Type, + Int16Type, Int32Type, Int64Type, TimeUnit, }; use datafusion_common::config::ConfigOptions; use datafusion_common::types::{ - logical_int8, logical_int16, logical_int32, logical_int64, logical_string, + logical_float32, logical_float64, logical_int8, logical_int16, logical_int32, + logical_int64, logical_string, }; use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; use datafusion_expr::{Coercion, TypeSignatureClass}; @@ -34,12 +35,52 @@ use std::sync::Arc; const MICROS_PER_SECOND: i64 = 1_000_000; -/// Convert seconds to microseconds with saturating overflow behavior (matches spark spec) +/// Convert integer seconds to microseconds with saturating overflow behavior #[inline] fn secs_to_micros(secs: i64) -> i64 { secs.saturating_mul(MICROS_PER_SECOND) } +/// Convert float seconds to microseconds +/// Returns None for NaN/Infinity in non-ANSI mode, error in ANSI mode +/// Saturates to i64::MAX/MIN for overflow +#[inline] +fn float_secs_to_micros(val: f64, enable_ansi_mode: bool) -> Result> { + if val.is_nan() || val.is_infinite() { + if enable_ansi_mode { + let display_val = if val.is_nan() { + "NaN" + } else if val.is_sign_positive() { + "Infinity" + } else { + "-Infinity" + }; + return exec_err!("Cannot cast {} to TIMESTAMP", display_val); + } + return Ok(None); + } + let micros = val * MICROS_PER_SECOND as f64; + + // Bounds check for i64 range. 
+ // Note on precision: i64::MIN (-2^63) is exactly representable in f64, + // but i64::MAX (2^63 - 1) is not - it rounds up to 2^63 (i64::MAX + 1). + // We use strict `<` for the upper bound to reject values >= 2^63, + // which correctly handles the precision loss edge case. + if micros >= i64::MIN as f64 && micros < i64::MAX as f64 { + Ok(Some(micros as i64)) + } else { + if enable_ansi_mode { + return exec_err!("Overflow casting {} to TIMESTAMP", val); + } + // Saturate to i64::MAX or i64::MIN like Spark does for overflow + if micros.is_sign_negative() { + Ok(Some(i64::MIN)) + } else { + Ok(Some(i64::MAX)) + } + } +} + /// Spark-compatible `cast` function for type conversions /// /// This implements Spark's CAST expression with a target type parameter @@ -50,10 +91,11 @@ fn secs_to_micros(secs: i64) -> i64 { /// ``` /// /// # Currently supported conversions -/// - Int8/Int16/Int32/Int64 -> Timestamp (target_type = 'timestamp') +/// - Int8/Int16/Int32/Int64/Float32/Float64 -> Timestamp (target_type = 'timestamp') /// /// The integer value is interpreted as seconds since the Unix epoch (1970-01-01 00:00:00 UTC) -/// and converted to a timestamp with microsecond precision (matches spark's spec) +/// and converted to a timestamp with microsecond precision (matches spark's spec). Same is the case +/// with Float but with higher precision to support micro / nanoseconds. 
/// /// # Overflow behavior /// Uses saturating multiplication to handle overflow - values that would overflow @@ -79,28 +121,30 @@ impl SparkCast { } pub fn new_with_config(config: &ConfigOptions) -> Self { - // First arg: value to cast (only signed ints - Spark doesn't have unsigned integers) + // First arg: value to cast // Second arg: target datatype as Utf8 string literal (ex : 'timestamp') let string_arg = Coercion::new_exact(TypeSignatureClass::Native(logical_string())); - // Spark only supports signed integers, so we explicitly list them - let signed_int_signatures = [ + // Supported input types: signed integers and floats + let input_type_signatures = [ logical_int8(), logical_int16(), logical_int32(), logical_int64(), + logical_float32(), + logical_float64(), ] - .map(|int_type| { + .map(|input_type| { TypeSignature::Coercible(vec![ - Coercion::new_exact(TypeSignatureClass::Native(int_type)), + Coercion::new_exact(TypeSignatureClass::Native(input_type)), string_arg.clone(), ]) }); Self { signature: Signature::new( - TypeSignature::OneOf(Vec::from(signed_int_signatures)), + TypeSignature::OneOf(Vec::from(input_type_signatures)), Volatility::Stable, ), timezone: config @@ -165,6 +209,35 @@ where Ok(Arc::new(builder.finish().with_timezone_opt(timezone))) } +/// Cast float to timestamp +/// Float value represents seconds (with fractional part) since Unix epoch +/// NaN and Infinity: error in ANSI mode, NULL in non-ANSI mode +fn cast_float_to_timestamp( + array: &ArrayRef, + timezone: Option>, + enable_ansi_mode: bool, +) -> Result +where + T::Native: Into, +{ + let arr = array.as_primitive::(); + let mut builder = TimestampMicrosecondBuilder::with_capacity(arr.len()); + + for i in 0..arr.len() { + if arr.is_null(i) { + builder.append_null(); + } else { + let val: f64 = arr.value(i).into(); + match float_secs_to_micros(val, enable_ansi_mode)? 
{ + Some(micros) => builder.append_value(micros), + None => builder.append_null(), + } + } + } + + Ok(Arc::new(builder.finish().with_timezone_opt(timezone))) +} + impl ScalarUDFImpl for SparkCast { fn name(&self) -> &str { "spark_cast" @@ -183,19 +256,19 @@ impl ScalarUDFImpl for SparkCast { } fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { - let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); let return_type = get_target_type_from_scalar_args( args.scalar_arguments, self.timezone.clone(), )?; - Ok(Arc::new(Field::new(self.name(), return_type, nullable))) + Ok(Arc::new(Field::new(self.name(), return_type, true))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let enable_ansi_mode = args.config_options.execution.enable_ansi_mode; let target_type = args.return_field.data_type(); match target_type { DataType::Timestamp(TimeUnit::Microsecond, tz) => { - cast_to_timestamp(&args.args[0], tz.clone()) + cast_to_timestamp(&args.args[0], tz.clone(), enable_ansi_mode) } other => exec_err!("Unsupported spark_cast target type: {:?}", other), } @@ -206,6 +279,7 @@ impl ScalarUDFImpl for SparkCast { fn cast_to_timestamp( input: &ColumnarValue, timezone: Option>, + enable_ansi_mode: bool, ) -> Result { match input { ColumnarValue::Array(array) => match array.data_type() { @@ -225,6 +299,20 @@ fn cast_to_timestamp( DataType::Int64 => Ok(ColumnarValue::Array(cast_int_to_timestamp::< Int64Type, >(array, timezone)?)), + DataType::Float32 => Ok(ColumnarValue::Array(cast_float_to_timestamp::< + Float32Type, + >( + array, + timezone, + enable_ansi_mode, + )?)), + DataType::Float64 => Ok(ColumnarValue::Array(cast_float_to_timestamp::< + Float64Type, + >( + array, + timezone, + enable_ansi_mode, + )?)), other => exec_err!("Unsupported cast from {:?} to timestamp", other), }, ColumnarValue::Scalar(scalar) => { @@ -233,11 +321,19 @@ fn cast_to_timestamp( | ScalarValue::Int8(None) | ScalarValue::Int16(None) | ScalarValue::Int32(None) - | 
ScalarValue::Int64(None) => None, + | ScalarValue::Int64(None) + | ScalarValue::Float32(None) + | ScalarValue::Float64(None) => None, ScalarValue::Int8(Some(v)) => Some(secs_to_micros((*v).into())), ScalarValue::Int16(Some(v)) => Some(secs_to_micros((*v).into())), ScalarValue::Int32(Some(v)) => Some(secs_to_micros((*v).into())), ScalarValue::Int64(Some(v)) => Some(secs_to_micros(*v)), + ScalarValue::Float32(Some(v)) => { + float_secs_to_micros(*v as f64, enable_ansi_mode)? + } + ScalarValue::Float64(Some(v)) => { + float_secs_to_micros(*v, enable_ansi_mode)? + } other => { return exec_err!("Unsupported cast from {:?} to timestamp", other); } @@ -252,7 +348,9 @@ fn cast_to_timestamp( #[cfg(test)] mod tests { use super::*; - use arrow::array::{Int8Array, Int16Array, Int32Array, Int64Array}; + use arrow::array::{ + Float32Array, Float64Array, Int8Array, Int16Array, Int32Array, Int64Array, + }; use arrow::datatypes::TimestampMicrosecondType; // helpers to make testing easier @@ -651,4 +749,259 @@ mod tests { // Defaults to UTC assert_scalar_timestamp_with_tz(result, 0, "UTC"); } + + fn make_args_with_ansi_mode( + input: ColumnarValue, + target_type: &str, + enable_ansi_mode: bool, + ) -> ScalarFunctionArgs { + let return_field = Arc::new(Field::new( + "result", + DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from("UTC"))), + true, + )); + let mut config = ConfigOptions::default(); + config.execution.time_zone = Some("UTC".to_string()); + config.execution.enable_ansi_mode = enable_ansi_mode; + ScalarFunctionArgs { + args: vec![ + input, + ColumnarValue::Scalar(ScalarValue::Utf8(Some(target_type.to_string()))), + ], + arg_fields: vec![], + number_rows: 0, + return_field, + config_options: Arc::new(config), + } + } + + #[test] + fn test_cast_float64_array_to_timestamp() { + let array: ArrayRef = Arc::new(Float64Array::from(vec![ + Some(0.0), + Some(1.5), + Some(-1.5), + Some(1704067200.123456), + None, + ])); + + let cast = SparkCast::new(); + let args = 
make_args(ColumnarValue::Array(array), "timestamp"); + let result = cast.invoke_with_args(args).unwrap(); + + match result { + ColumnarValue::Array(result_array) => { + let ts_array = result_array.as_primitive::(); + assert_eq!(ts_array.value(0), 0); + assert_eq!(ts_array.value(1), 1_500_000); // 1.5 seconds + assert_eq!(ts_array.value(2), -1_500_000); // -1.5 seconds + assert_eq!(ts_array.value(3), 1_704_067_200_123_456); // with fractional + assert!(ts_array.is_null(4)); + } + _ => panic!("Expected array result"), + } + } + + #[test] + fn test_cast_float32_array_to_timestamp() { + let array: ArrayRef = Arc::new(Float32Array::from(vec![ + Some(0.0f32), + Some(1.5f32), + Some(-1.5f32), + None, + ])); + + let cast = SparkCast::new(); + let args = make_args(ColumnarValue::Array(array), "timestamp"); + let result = cast.invoke_with_args(args).unwrap(); + + match result { + ColumnarValue::Array(result_array) => { + let ts_array = result_array.as_primitive::(); + assert_eq!(ts_array.value(0), 0); + assert_eq!(ts_array.value(1), 1_500_000); // 1.5 seconds + assert_eq!(ts_array.value(2), -1_500_000); // -1.5 seconds + assert!(ts_array.is_null(3)); + } + _ => panic!("Expected array result"), + } + } + + #[test] + fn test_cast_scalar_float64() { + let cast = SparkCast::new(); + let args = make_args( + ColumnarValue::Scalar(ScalarValue::Float64(Some(1.5))), + "timestamp", + ); + let result = cast.invoke_with_args(args).unwrap(); + assert_scalar_timestamp(result, 1_500_000); + } + + #[test] + fn test_cast_scalar_float32() { + let cast = SparkCast::new(); + let args = make_args( + ColumnarValue::Scalar(ScalarValue::Float32(Some(1.5f32))), + "timestamp", + ); + let result = cast.invoke_with_args(args).unwrap(); + assert_scalar_timestamp(result, 1_500_000); + } + + #[test] + fn test_cast_float_nan_non_ansi_mode() { + // In non-ANSI mode, NaN should return NULL + let cast = SparkCast::new(); + let args = make_args_with_ansi_mode( + 
ColumnarValue::Scalar(ScalarValue::Float64(Some(f64::NAN))), + "timestamp", + false, + ); + let result = cast.invoke_with_args(args).unwrap(); + assert_scalar_null(result); + } + + #[test] + fn test_cast_float_infinity_non_ansi_mode() { + // In non-ANSI mode, Infinity should return NULL + let cast = SparkCast::new(); + + // Positive infinity + let args = make_args_with_ansi_mode( + ColumnarValue::Scalar(ScalarValue::Float64(Some(f64::INFINITY))), + "timestamp", + false, + ); + let result = cast.invoke_with_args(args).unwrap(); + assert_scalar_null(result); + + // Negative infinity + let args = make_args_with_ansi_mode( + ColumnarValue::Scalar(ScalarValue::Float64(Some(f64::NEG_INFINITY))), + "timestamp", + false, + ); + let result = cast.invoke_with_args(args).unwrap(); + assert_scalar_null(result); + } + + #[test] + fn test_cast_float_nan_ansi_mode() { + // In ANSI mode, NaN should error + let cast = SparkCast::new(); + let args = make_args_with_ansi_mode( + ColumnarValue::Scalar(ScalarValue::Float64(Some(f64::NAN))), + "timestamp", + true, + ); + let result = cast.invoke_with_args(args); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Cannot cast NaN")); + } + + #[test] + fn test_cast_float_infinity_ansi_mode() { + // In ANSI mode, Infinity should error + let cast = SparkCast::new(); + let args = make_args_with_ansi_mode( + ColumnarValue::Scalar(ScalarValue::Float64(Some(f64::INFINITY))), + "timestamp", + true, + ); + let result = cast.invoke_with_args(args); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Cannot cast Infinity") + ); + } + + #[test] + fn test_cast_float_overflow_non_ansi_mode() { + // Value too large to fit in i64 microseconds - should saturate to i64::MAX like Spark + let cast = SparkCast::new(); + let large_value = 1e19; // Way too large for i64 microseconds + let args = make_args_with_ansi_mode( + ColumnarValue::Scalar(ScalarValue::Float64(Some(large_value))), + 
"timestamp", + false, + ); + let result = cast.invoke_with_args(args).unwrap(); + // Spark saturates overflow to i64::MAX + assert_scalar_timestamp(result, i64::MAX); + } + + #[test] + fn test_cast_float_negative_overflow_non_ansi_mode() { + // Large negative value - should saturate to i64::MIN like Spark + let cast = SparkCast::new(); + let large_value = -1e19; // Way too large negative for i64 microseconds + let args = make_args_with_ansi_mode( + ColumnarValue::Scalar(ScalarValue::Float64(Some(large_value))), + "timestamp", + false, + ); + let result = cast.invoke_with_args(args).unwrap(); + // Spark saturates negative overflow to i64::MIN + assert_scalar_timestamp(result, i64::MIN); + } + + #[test] + fn test_cast_float_overflow_ansi_mode() { + // Value too large to fit in i64 microseconds - should error in ANSI mode + let cast = SparkCast::new(); + let large_value = 1e19; // Way too large for i64 microseconds + let args = make_args_with_ansi_mode( + ColumnarValue::Scalar(ScalarValue::Float64(Some(large_value))), + "timestamp", + true, + ); + let result = cast.invoke_with_args(args); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Overflow")); + } + + #[test] + fn test_cast_float_array_with_nan_and_infinity() { + // Array with NaN and Infinity in non-ANSI mode + let array: ArrayRef = Arc::new(Float64Array::from(vec![ + Some(1.0), + Some(f64::NAN), + Some(f64::INFINITY), + Some(f64::NEG_INFINITY), + Some(2.0), + ])); + + let cast = SparkCast::new(); + let args = + make_args_with_ansi_mode(ColumnarValue::Array(array), "timestamp", false); + let result = cast.invoke_with_args(args).unwrap(); + + match result { + ColumnarValue::Array(result_array) => { + let ts_array = result_array.as_primitive::(); + assert_eq!(ts_array.value(0), 1_000_000); + assert!(ts_array.is_null(1)); // NaN -> NULL + assert!(ts_array.is_null(2)); // Infinity -> NULL + assert!(ts_array.is_null(3)); // -Infinity -> NULL + assert_eq!(ts_array.value(4), 2_000_000); 
+ } + _ => panic!("Expected array result"), + } + } + + #[test] + fn test_cast_float_negative_values() { + let cast = SparkCast::new(); + let args = make_args( + ColumnarValue::Scalar(ScalarValue::Float64(Some(-86400.5))), + "timestamp", + ); + let result = cast.invoke_with_args(args).unwrap(); + // -86400.5 seconds = -86400500000 microseconds (1 day and 0.5 seconds before epoch) + assert_scalar_timestamp(result, -86_400_500_000); + } } diff --git a/datafusion/spark/src/function/map/utils.rs b/datafusion/spark/src/function/map/utils.rs index 28fa3227fd628..f5fff0c4b4c46 100644 --- a/datafusion/spark/src/function/map/utils.rs +++ b/datafusion/spark/src/function/map/utils.rs @@ -147,6 +147,7 @@ pub fn map_from_keys_values_offsets_nulls( )?)) } +#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // ScalarValue has interior mutability but is intentionally used as hash key fn map_deduplicate_keys( flat_keys: &ArrayRef, flat_values: &ArrayRef, diff --git a/datafusion/spark/src/function/math/mod.rs b/datafusion/spark/src/function/math/mod.rs index 7f7d04e06b0be..be3f34d3323fb 100644 --- a/datafusion/spark/src/function/math/mod.rs +++ b/datafusion/spark/src/function/math/mod.rs @@ -23,6 +23,7 @@ pub mod hex; pub mod modulus; pub mod negative; pub mod rint; +pub mod round; pub mod trigonometry; pub mod unhex; pub mod width_bucket; @@ -38,6 +39,7 @@ make_udf_function!(hex::SparkHex, hex); make_udf_function!(modulus::SparkMod, modulus); make_udf_function!(modulus::SparkPmod, pmod); make_udf_function!(rint::SparkRint, rint); +make_udf_function!(round::SparkRound, round); make_udf_function!(unhex::SparkUnhex, unhex); make_udf_function!(width_bucket::SparkWidthBucket, width_bucket); make_udf_function!(trigonometry::SparkCsc, csc); @@ -63,6 +65,11 @@ pub mod expr_fn { "Returns the double value that is closest in value to the argument and is equal to a mathematical integer.", arg1 )); + export_functions!(( + round, + "Rounds the value of expr to scale decimal places 
using HALF_UP rounding mode.", + arg1 arg2 + )); export_functions!((unhex, "Converts hexadecimal string to binary.", arg1)); export_functions!((width_bucket, "Returns the bucket number into which the value of this expression would fall after being evaluated.", arg1 arg2 arg3 arg4)); export_functions!((csc, "Returns the cosecant of expr.", arg1)); @@ -88,6 +95,7 @@ pub fn functions() -> Vec> { modulus(), pmod(), rint(), + round(), unhex(), width_bucket(), csc(), diff --git a/datafusion/spark/src/function/math/round.rs b/datafusion/spark/src/function/math/round.rs new file mode 100644 index 0000000000000..05745666183d3 --- /dev/null +++ b/datafusion/spark/src/function/math/round.rs @@ -0,0 +1,654 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::sync::Arc; + +use arrow::array::*; +use arrow::datatypes::{ + ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type, + Decimal256Type, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, + Int32Type, Int64Type, UInt8Type, UInt16Type, UInt32Type, UInt64Type, +}; +use datafusion_common::types::{ + NativeType, logical_float32, logical_float64, logical_int32, +}; +use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + TypeSignatureClass, Volatility, +}; + +/// Spark-compatible `round` expression +/// +/// +/// Rounds the value of `expr` to `scale` decimal places using HALF_UP rounding mode. +/// Returns the same type as the input expression. +/// +/// - `round(expr)` rounds to 0 decimal places (default scale = 0) +/// - `round(expr, scale)` rounds to `scale` decimal places +/// - For integer types with negative scale: `round(25, -1)` → `30` +/// - Uses HALF_UP rounding: 2.5 → 3, -2.5 → -3 (away from zero) +/// +/// Supported types: Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, +/// Float16, Float32, Float64, Decimal32, Decimal64, Decimal128, Decimal256 +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkRound { + signature: Signature, +} + +impl Default for SparkRound { + fn default() -> Self { + Self::new() + } +} + +impl SparkRound { + pub fn new() -> Self { + let decimal = Coercion::new_exact(TypeSignatureClass::Decimal); + let integer = Coercion::new_exact(TypeSignatureClass::Integer); + let decimal_places = Coercion::new_implicit( + TypeSignatureClass::Native(logical_int32()), + vec![TypeSignatureClass::Integer], + NativeType::Int32, + ); + let float32 = Coercion::new_exact(TypeSignatureClass::Native(logical_float32())); + let float64 = Coercion::new_implicit( + TypeSignatureClass::Native(logical_float64()), + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ); + Self { + 
signature: Signature::one_of( + vec![ + // round(decimal, scale) + TypeSignature::Coercible(vec![ + decimal.clone(), + decimal_places.clone(), + ]), + // round(decimal) + TypeSignature::Coercible(vec![decimal]), + // round(integer, scale) + TypeSignature::Coercible(vec![ + integer.clone(), + decimal_places.clone(), + ]), + // round(integer) + TypeSignature::Coercible(vec![integer]), + // round(float32, scale) + TypeSignature::Coercible(vec![ + float32.clone(), + decimal_places.clone(), + ]), + // round(float32) + TypeSignature::Coercible(vec![float32]), + // round(float64, scale) + TypeSignature::Coercible(vec![float64.clone(), decimal_places]), + // round(float64) + TypeSignature::Coercible(vec![float64]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkRound { + fn name(&self) -> &str { + "round" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + spark_round(&args.args, args.config_options.execution.enable_ansi_mode) + } +} + +/// Extract the scale (decimal places) from the second argument. +/// Returns `Some(0)` if no second argument is provided. +/// Returns `None` if the scale argument is NULL (Spark returns NULL for `round(expr, NULL)`). 
+fn get_scale(args: &[ColumnarValue]) -> Result> { + if args.len() < 2 { + return Ok(Some(0)); + } + + match &args[1] { + ColumnarValue::Scalar(ScalarValue::Int8(Some(v))) => Ok(Some(i32::from(*v))), + ColumnarValue::Scalar(ScalarValue::Int16(Some(v))) => Ok(Some(i32::from(*v))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(v))) => Ok(Some(*v)), + ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) => { + i32::try_from(*v).map(Some).map_err(|_| { + (exec_err!("round scale {v} is out of supported i32 range") + as Result<(), _>) + .unwrap_err() + }) + } + ColumnarValue::Scalar(ScalarValue::UInt8(Some(v))) => Ok(Some(i32::from(*v))), + ColumnarValue::Scalar(ScalarValue::UInt16(Some(v))) => Ok(Some(i32::from(*v))), + ColumnarValue::Scalar(ScalarValue::UInt32(Some(v))) => { + i32::try_from(*v).map(Some).map_err(|_| { + (exec_err!("round scale {v} is out of supported i32 range") + as Result<(), _>) + .unwrap_err() + }) + } + ColumnarValue::Scalar(ScalarValue::UInt64(Some(v))) => { + i32::try_from(*v).map(Some).map_err(|_| { + (exec_err!("round scale {v} is out of supported i32 range") + as Result<(), _>) + .unwrap_err() + }) + } + ColumnarValue::Scalar(sv) if sv.is_null() => Ok(None), + other => exec_err!("Unsupported type for round scale: {}", other.data_type()), + } +} + +/// Round a floating-point value to the given number of decimal places using +/// HALF_UP rounding mode (ties round away from zero). +/// +/// This matches Spark's `RoundBase` behaviour for `FloatType` / `DoubleType`, +/// which internally converts the value to `BigDecimal` and rounds with +/// `RoundingMode.HALF_UP`. +/// +/// # Arguments +/// * `value` – the floating-point number to round +/// * `scale` – number of decimal places to keep. +/// - `scale >= 0`: rounds to that many fractional digits +/// (e.g. `round_float(2.345, 2) == 2.35`) +/// - `scale < 0`: rounds to the left of the decimal point +/// (e.g. 
`round_float(125.0, -1) == 130.0`) +/// +/// # Examples +/// ```text +/// round_float(2.5, 0) → 3.0 // half rounds up +/// round_float(-2.5, 0) → -3.0 // half rounds away from zero +/// round_float(1.4, 0) → 1.0 +/// round_float(125.0, -1) → 130.0 +/// ``` +fn round_float(value: T, scale: i32) -> T { + if scale >= 0 { + let factor = T::from(10.0f64.powi(scale)).unwrap_or_else(T::infinity); + if factor.is_infinite() { + // Very large positive scale — value is already precise enough, return as-is + return value; + } + (value * factor).round() / factor + } else { + let factor = T::from(10.0f64.powi(-scale)).unwrap_or_else(T::infinity); + if factor.is_infinite() { + // Very large negative scale — any finite value rounds to 0 + return T::zero(); + } + (value / factor).round() * factor + } +} + +/// Round an integer value to the given scale using HALF_UP rounding mode. +/// +/// Only meaningful when `scale` is negative — a non-negative scale leaves +/// the integer unchanged because integers have no fractional part. +/// +/// This matches Spark's `RoundBase` behaviour for `ByteType`, `ShortType`, +/// `IntegerType`, and `LongType`, which round to the nearest power-of-ten +/// boundary and return the same integer type. +/// +/// In ANSI mode, overflow conditions return an error instead of wrapping. +/// +/// # Arguments +/// * `value` – the integer to round (widened to `i64` by callers) +/// * `scale` – rounding position relative to the ones digit. 
+/// - `scale >= 0`: returns `value` as-is +/// - `scale == -1`: rounds to the nearest 10 +/// - `scale == -2`: rounds to the nearest 100 +/// - If `10^|scale|` overflows `i64`, returns `0` +/// * `enable_ansi_mode` – when true, overflow returns an error +/// +/// # Examples +/// ```text +/// round_integer(25, -1, false) → Ok(30) +/// round_integer(-25, -1, false) → Ok(-30) +/// round_integer(123, -1, false) → Ok(120) +/// round_integer(150, -2, false) → Ok(200) +/// round_integer(42, 2, false) → Ok(42) // no-op for positive scale +/// round_integer(42, -10, false) → Ok(0) // factor overflows → 0 +/// ``` +fn round_integer(value: i64, scale: i32, enable_ansi_mode: bool) -> Result { + if scale >= 0 { + return Ok(value); + } + let abs_scale = (-scale) as u32; + let Some(factor) = 10_i64.checked_pow(abs_scale) else { + return Ok(0); + }; + let remainder = value % factor; + let threshold = factor / 2; + let result = if remainder >= threshold { + if enable_ansi_mode { + value + .checked_sub(remainder) + .and_then(|v| v.checked_add(factor)) + .ok_or_else(|| { + (exec_err!("Int64 overflow on round({value}, {scale})") + as Result<(), _>) + .unwrap_err() + })? + } else { + value.wrapping_sub(remainder).wrapping_add(factor) + } + } else if remainder <= -threshold { + if enable_ansi_mode { + value + .checked_sub(remainder) + .and_then(|v| v.checked_sub(factor)) + .ok_or_else(|| { + (exec_err!("Int64 overflow on round({value}, {scale})") + as Result<(), _>) + .unwrap_err() + })? + } else { + value.wrapping_sub(remainder).wrapping_sub(factor) + } + } else { + value - remainder + }; + Ok(result) +} + +// --------------------------------------------------------------------------- +// Decimal rounding using ArrowNativeTypeOp (HALF_UP) +// --------------------------------------------------------------------------- + +/// Round a decimal value represented as its unscaled integer using HALF_UP +/// rounding mode (ties round away from zero). 
+/// +/// This matches Spark's `RoundBase` behaviour for `DecimalType`, which calls +/// `BigDecimal.setScale(scale, RoundingMode.HALF_UP)`. +/// +/// Decimals are stored as `(unscaled_value, precision, scale)` where the real +/// value equals `unscaled_value * 10^(-scale)`. This function operates on the +/// unscaled integer directly: +/// +/// 1. Compute `diff = input_scale - decimal_places`. +/// If `diff <= 0` the requested precision is finer than (or equal to) the +/// stored scale, so nothing needs to be rounded — return as-is. +/// 2. Divide by `10^diff` to shift the rounding boundary into the ones digit. +/// 3. Inspect the remainder to decide whether to round up or down (HALF_UP). +/// 4. Multiply back by `10^diff` so the result is expressed at the original +/// `input_scale`. +/// +/// # Arguments +/// * `value` – unscaled decimal value +/// * `input_scale` – scale of the incoming decimal +/// * `decimal_places` – number of fractional digits to keep (may be negative) +/// +/// # Returns +/// The rounded unscaled value at the same `input_scale`, or an error +/// on overflow. +/// +/// # Examples +/// ```text +/// // 2.5 (unscaled 25, scale 1) rounded to 0 places → 3.0 (unscaled 30) +/// round_decimal(25_i128, 1, 0) → Ok(30) +/// +/// // 2.345 (unscaled 2345, scale 3) rounded to 2 places → 2.350 (unscaled 2350) +/// round_decimal(2345_i128, 3, 2) → Ok(2350) +/// ``` +fn round_decimal( + value: V, + input_scale: i8, + decimal_places: i32, +) -> Result { + let diff = i64::from(input_scale) - i64::from(decimal_places); + if diff <= 0 { + // Nothing to round – the requested precision is finer than (or equal to) the + // stored scale. 
+ return Ok(value); + } + + let diff = diff as u32; + + let one = V::ONE; + let two = V::from_usize(2).ok_or_else(|| { + (exec_err!("Internal error: could not create constant 2") as Result<(), _>) + .unwrap_err() + })?; + let ten = V::from_usize(10).ok_or_else(|| { + (exec_err!("Internal error: could not create constant 10") as Result<(), _>) + .unwrap_err() + })?; + + let Ok(factor) = ten.pow_checked(diff) else { + // 10^diff overflows the decimal type — the rounding position is beyond + // the representable range, so any value rounds to 0. + // This matches Spark's BigDecimal.setScale behavior where rounding to a + // scale far beyond the number's magnitude yields 0. + return Ok(V::ZERO); + }; + + let mut quotient = value.div_wrapping(factor); + let remainder = value.mod_wrapping(factor); + + // HALF_UP: round away from zero when remainder is exactly half + let threshold = factor.div_wrapping(two); + if remainder >= threshold { + quotient = quotient.add_checked(one).map_err(|_| { + (exec_err!("Overflow while rounding decimal") as Result<(), _>).unwrap_err() + })?; + } else if remainder <= threshold.neg_wrapping() { + quotient = quotient.sub_checked(one).map_err(|_| { + (exec_err!("Overflow while rounding decimal") as Result<(), _>).unwrap_err() + })?; + } + + // Re-scale the quotient back to `input_scale` so the returned unscaled integer is + // at the original scale. `factor` is already `10^diff` which is exactly the shift + // we need. + quotient.mul_checked(factor).map_err(|_| { + (exec_err!("Overflow while rounding decimal") as Result<(), _>).unwrap_err() + }) +} + +// --------------------------------------------------------------------------- +// Macros for array dispatch +// --------------------------------------------------------------------------- + +macro_rules! 
impl_integer_array_round { + ($array:expr, $arrow_type:ty, $scale:expr, $enable_ansi_mode:expr) => {{ + let array = $array.as_primitive::<$arrow_type>(); + type Native = <$arrow_type as arrow::datatypes::ArrowPrimitiveType>::Native; + let result: PrimitiveArray<$arrow_type> = if $enable_ansi_mode { + array.try_unary(|x| { + let v = round_integer(x as i64, $scale, true)?; + Native::try_from(v).map_err(|_| { + (exec_err!( + "{} overflow on round({x}, {})", + stringify!($arrow_type), + $scale + ) as Result<(), _>) + .unwrap_err() + }) + })? + } else { + array.unary(|x| round_integer(x as i64, $scale, false).unwrap() as Native) + }; + Ok(ColumnarValue::Array(Arc::new(result))) + }}; +} + +macro_rules! impl_float_array_round { + ($array:expr, $arrow_type:ty, $scale:expr) => {{ + let array = $array.as_primitive::<$arrow_type>(); + let result: PrimitiveArray<$arrow_type> = array.unary(|x| round_float(x, $scale)); + Ok(ColumnarValue::Array(Arc::new(result))) + }}; +} + +macro_rules! impl_decimal_array_round { + ($array:expr, $arrow_type:ty, $input_scale:expr, $scale:expr) => {{ + let array = $array.as_primitive::<$arrow_type>(); + let result: PrimitiveArray<$arrow_type> = array + .try_unary(|x| round_decimal(x, $input_scale, $scale))? + .with_data_type($array.data_type().clone()); + Ok(ColumnarValue::Array(Arc::new(result))) + }}; +} + +// --------------------------------------------------------------------------- +// Core dispatch +// --------------------------------------------------------------------------- + +fn spark_round(args: &[ColumnarValue], enable_ansi_mode: bool) -> Result { + if args.is_empty() || args.len() > 2 { + return exec_err!("round requires 1 or 2 arguments, got {}", args.len()); + } + + let scale = match get_scale(args)? 
{ + Some(s) => s, + None => { + // NULL scale → return NULL with the same data type as the first argument + return Ok(ColumnarValue::Scalar(ScalarValue::try_from( + args[0].data_type(), + )?)); + } + }; + + match &args[0] { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Null => Ok(args[0].clone()), + + // Integer types + DataType::Int8 => { + impl_integer_array_round!(array, Int8Type, scale, enable_ansi_mode) + } + DataType::Int16 => { + impl_integer_array_round!(array, Int16Type, scale, enable_ansi_mode) + } + DataType::Int32 => { + impl_integer_array_round!(array, Int32Type, scale, enable_ansi_mode) + } + DataType::Int64 => { + impl_integer_array_round!(array, Int64Type, scale, enable_ansi_mode) + } + + // Unsigned integer types + DataType::UInt8 => { + impl_integer_array_round!(array, UInt8Type, scale, enable_ansi_mode) + } + DataType::UInt16 => { + impl_integer_array_round!(array, UInt16Type, scale, enable_ansi_mode) + } + DataType::UInt32 => { + impl_integer_array_round!(array, UInt32Type, scale, enable_ansi_mode) + } + DataType::UInt64 => { + let array = array.as_primitive::(); + let result: PrimitiveArray = array.try_unary(|x| { + let v_i64 = i64::try_from(x).map_err(|_| { + (exec_err!( + "round: UInt64 value {x} exceeds i64::MAX and cannot be rounded" + ) as Result<(), _>) + .unwrap_err() + })?; + round_integer(v_i64, scale, enable_ansi_mode) + .map(|v| v as u64) + })?; + Ok(ColumnarValue::Array(Arc::new(result))) + } + + // Float types + DataType::Float16 => impl_float_array_round!(array, Float16Type, scale), + DataType::Float32 => impl_float_array_round!(array, Float32Type, scale), + DataType::Float64 => impl_float_array_round!(array, Float64Type, scale), + + // Decimal types + DataType::Decimal32(_, input_scale) => { + impl_decimal_array_round!(array, Decimal32Type, *input_scale, scale) + } + DataType::Decimal64(_, input_scale) => { + impl_decimal_array_round!(array, Decimal64Type, *input_scale, scale) + } + 
DataType::Decimal128(_, input_scale) => { + impl_decimal_array_round!(array, Decimal128Type, *input_scale, scale) + } + DataType::Decimal256(_, input_scale) => { + impl_decimal_array_round!(array, Decimal256Type, *input_scale, scale) + } + + dt => not_impl_err!("Unsupported data type for Spark round(): {dt}"), + }, + + ColumnarValue::Scalar(sv) => match sv { + ScalarValue::Null => Ok(args[0].clone()), + _ if sv.is_null() => Ok(args[0].clone()), + + // Integer scalars + ScalarValue::Int8(Some(v)) => { + let r = round_integer(i64::from(*v), scale, enable_ansi_mode)?; + let result = if enable_ansi_mode { + i8::try_from(r).map_err(|_| { + (exec_err!("Int8 overflow on round({v}, {scale})") + as Result<(), _>) + .unwrap_err() + })? + } else { + r as i8 + }; + Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(result)))) + } + ScalarValue::Int16(Some(v)) => { + let r = round_integer(i64::from(*v), scale, enable_ansi_mode)?; + let result = if enable_ansi_mode { + i16::try_from(r).map_err(|_| { + (exec_err!("Int16 overflow on round({v}, {scale})") + as Result<(), _>) + .unwrap_err() + })? + } else { + r as i16 + }; + Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(result)))) + } + ScalarValue::Int32(Some(v)) => { + let r = round_integer(i64::from(*v), scale, enable_ansi_mode)?; + let result = if enable_ansi_mode { + i32::try_from(r).map_err(|_| { + (exec_err!("Int32 overflow on round({v}, {scale})") + as Result<(), _>) + .unwrap_err() + })? 
+ } else { + r as i32 + }; + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(result)))) + } + ScalarValue::Int64(Some(v)) => { + let result = round_integer(*v, scale, enable_ansi_mode)?; + Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(result)))) + } + + // Unsigned integer scalars + ScalarValue::UInt8(Some(v)) => { + let r = round_integer(i64::from(*v), scale, enable_ansi_mode)?; + let result = if enable_ansi_mode { + u8::try_from(r).map_err(|_| { + (exec_err!("UInt8 overflow on round({v}, {scale})") + as Result<(), _>) + .unwrap_err() + })? + } else { + r as u8 + }; + Ok(ColumnarValue::Scalar(ScalarValue::UInt8(Some(result)))) + } + ScalarValue::UInt16(Some(v)) => { + let r = round_integer(i64::from(*v), scale, enable_ansi_mode)?; + let result = if enable_ansi_mode { + u16::try_from(r).map_err(|_| { + (exec_err!("UInt16 overflow on round({v}, {scale})") + as Result<(), _>) + .unwrap_err() + })? + } else { + r as u16 + }; + Ok(ColumnarValue::Scalar(ScalarValue::UInt16(Some(result)))) + } + ScalarValue::UInt32(Some(v)) => { + let r = round_integer(i64::from(*v), scale, enable_ansi_mode)?; + let result = if enable_ansi_mode { + u32::try_from(r).map_err(|_| { + (exec_err!("UInt32 overflow on round({v}, {scale})") + as Result<(), _>) + .unwrap_err() + })? 
+ } else { + r as u32 + }; + Ok(ColumnarValue::Scalar(ScalarValue::UInt32(Some(result)))) + } + ScalarValue::UInt64(Some(v)) => { + let v_i64 = i64::try_from(*v).map_err(|_| { + (exec_err!( + "round: UInt64 value {v} exceeds i64::MAX and cannot be rounded" + ) as Result<(), _>) + .unwrap_err() + })?; + let result = round_integer(v_i64, scale, enable_ansi_mode)?; + Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some( + result as u64, + )))) + } + + // Float scalars + ScalarValue::Float16(Some(v)) => { + let result = round_float(*v, scale); + Ok(ColumnarValue::Scalar(ScalarValue::Float16(Some(result)))) + } + ScalarValue::Float32(Some(v)) => { + let result = round_float(*v, scale); + Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(result)))) + } + ScalarValue::Float64(Some(v)) => { + let result = round_float(*v, scale); + Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(result)))) + } + + // Decimal scalars + ScalarValue::Decimal32(Some(v), precision, input_scale) => { + let rounded = round_decimal(*v, *input_scale, scale)?; + Ok(ColumnarValue::Scalar(ScalarValue::Decimal32( + Some(rounded), + *precision, + *input_scale, + ))) + } + ScalarValue::Decimal64(Some(v), precision, input_scale) => { + let rounded = round_decimal(*v, *input_scale, scale)?; + Ok(ColumnarValue::Scalar(ScalarValue::Decimal64( + Some(rounded), + *precision, + *input_scale, + ))) + } + ScalarValue::Decimal128(Some(v), precision, input_scale) => { + let rounded = round_decimal(*v, *input_scale, scale)?; + Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( + Some(rounded), + *precision, + *input_scale, + ))) + } + ScalarValue::Decimal256(Some(v), precision, input_scale) => { + let rounded = round_decimal(*v, *input_scale, scale)?; + Ok(ColumnarValue::Scalar(ScalarValue::Decimal256( + Some(rounded), + *precision, + *input_scale, + ))) + } + + dt => not_impl_err!("Unsupported data type for Spark round(): {dt}"), + }, + } +} diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs 
index 6aa7ff599d0c5..1b35070b29f44 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -592,9 +592,12 @@ impl SqlToRel<'_, S> { } else { let mut unnest_options = UnnestOptions::new().with_preserve_nulls(false); + #[allow(clippy::allow_attributes, clippy::mutable_key_type)] + // Expr contains Arc with interior mutability but is intentionally used as hash key let mut projection_exprs = match &aggr_expr_using_columns { Some(exprs) => (*exprs).clone(), None => { + #[allow(clippy::allow_attributes, clippy::mutable_key_type)] let mut columns = HashSet::new(); for expr in &aggr_expr { expr.apply(|expr| { diff --git a/datafusion/sqllogictest/test_files/delete.slt b/datafusion/sqllogictest/test_files/delete.slt index 76db46f138bad..6131d6db3d5f7 100644 --- a/datafusion/sqllogictest/test_files/delete.slt +++ b/datafusion/sqllogictest/test_files/delete.slt @@ -45,7 +45,7 @@ explain delete from t1 where a = 1 and b = 2 and c > 3 and d != 4; ---- logical_plan 01)Dml: op=[Delete] table=[t1] -02)--Filter: CAST(t1.a AS Int64) = Int64(1) AND t1.b = CAST(Int64(2) AS Utf8View) AND t1.c > CAST(Int64(3) AS Float64) AND CAST(t1.d AS Int64) != Int64(4) +02)--Filter: CAST(t1.a AS Int64) = Int64(1) AND CAST(t1.b AS Int64) = Int64(2) AND t1.c > CAST(Int64(3) AS Float64) AND CAST(t1.d AS Int64) != Int64(4) 03)----TableScan: t1 physical_plan 01)CooperativeExec @@ -58,7 +58,7 @@ explain delete from t1 where t1.a = 1 and b = 2 and t1.c > 3 and d != 4; ---- logical_plan 01)Dml: op=[Delete] table=[t1] -02)--Filter: CAST(t1.a AS Int64) = Int64(1) AND t1.b = CAST(Int64(2) AS Utf8View) AND t1.c > CAST(Int64(3) AS Float64) AND CAST(t1.d AS Int64) != Int64(4) +02)--Filter: CAST(t1.a AS Int64) = Int64(1) AND CAST(t1.b AS Int64) = Int64(2) AND t1.c > CAST(Int64(3) AS Float64) AND CAST(t1.d AS Int64) != Int64(4) 03)----TableScan: t1 physical_plan 01)CooperativeExec diff --git a/datafusion/sqllogictest/test_files/dictionary.slt 
b/datafusion/sqllogictest/test_files/dictionary.slt index 511061cf82f06..5ec95260e8357 100644 --- a/datafusion/sqllogictest/test_files/dictionary.slt +++ b/datafusion/sqllogictest/test_files/dictionary.slt @@ -426,7 +426,8 @@ physical_plan 02)--DataSourceExec: partitions=1, partition_sizes=[1] -# Now query using an integer which must be coerced into a dictionary string +# Query using an integer literal: comparison coercion prefers numeric types, +# so the dictionary string column is cast to Int64 query TT SELECT * from test where column2 = 1; ---- @@ -436,10 +437,10 @@ query TT explain SELECT * from test where column2 = 1; ---- logical_plan -01)Filter: test.column2 = Dictionary(Int32, Utf8("1")) +01)Filter: CAST(test.column2 AS Int64) = Int64(1) 02)--TableScan: test projection=[column1, column2] physical_plan -01)FilterExec: column2@1 = 1 +01)FilterExec: CAST(column2@1 AS Int64) = 1 02)--DataSourceExec: partitions=1, partition_sizes=[1] # Window Functions diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index a6341bc686f74..9365f3896b618 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -1086,55 +1086,37 @@ SELECT 0.3 NOT IN (0.0,0.1,0.2,NULL) ---- NULL -query B +# Mixed string/integer IN lists: comparison coercion picks the numeric +# type, so non-numeric strings like 'a' fail to cast to Int64. 
+query error Cannot cast string 'a' to value of Int64 type SELECT '1' IN ('a','b',1) ----- -true -query B +query error Cannot cast string 'a' to value of Int64 type SELECT '2' IN ('a','b',1) ----- -false -query B +query error Cannot cast string 'a' to value of Int64 type SELECT '2' NOT IN ('a','b',1) ----- -true -query B +query error Cannot cast string 'a' to value of Int64 type SELECT '1' NOT IN ('a','b',1) ----- -false -query B +query error Cannot cast string 'a' to value of Int64 type SELECT NULL IN ('a','b',1) ----- -NULL -query B +query error Cannot cast string 'a' to value of Int64 type SELECT NULL NOT IN ('a','b',1) ----- -NULL -query B +query error Cannot cast string 'a' to value of Int64 type SELECT '1' IN ('a','b',NULL,1) ----- -true -query B +query error Cannot cast string 'a' to value of Int64 type SELECT '2' IN ('a','b',NULL,1) ----- -NULL -query B +query error Cannot cast string 'a' to value of Int64 type SELECT '1' NOT IN ('a','b',NULL,1) ----- -false -query B +query error Cannot cast string 'a' to value of Int64 type SELECT '2' NOT IN ('a','b',NULL,1) ----- -NULL # ======================================================================== # Comprehensive IN LIST tests with NULL handling diff --git a/datafusion/sqllogictest/test_files/push_down_filter_parquet.slt b/datafusion/sqllogictest/test_files/push_down_filter_parquet.slt index cca6384360460..ab6847e1c4834 100644 --- a/datafusion/sqllogictest/test_files/push_down_filter_parquet.slt +++ b/datafusion/sqllogictest/test_files/push_down_filter_parquet.slt @@ -122,24 +122,6 @@ explain select a from t where a != '100'; ---- physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter_parquet/t.parquet]]}, projection=[a], file_type=parquet, predicate=a@0 != 100, pruning_predicate=a_null_count@2 != row_count@3 AND (a_min@0 != 100 OR 100 != a_max@1), required_guarantees=[a not in (100)] -# The predicate should still have the column cast 
when the value is a NOT valid i32 -query TT -explain select a from t where a = '99999999999'; ----- -physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter_parquet/t.parquet]]}, projection=[a], file_type=parquet, predicate=CAST(a@0 AS Utf8) = 99999999999 - -# The predicate should still have the column cast when the value is a NOT valid i32 -query TT -explain select a from t where a = '99.99'; ----- -physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter_parquet/t.parquet]]}, projection=[a], file_type=parquet, predicate=CAST(a@0 AS Utf8) = 99.99 - -# The predicate should still have the column cast when the value is a NOT valid i32 -query TT -explain select a from t where a = ''; ----- -physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter_parquet/t.parquet]]}, projection=[a], file_type=parquet, predicate=CAST(a@0 AS Utf8) = - # The predicate should not have a column cast when the operator is = or != and the literal can be round-trip casted without losing information. query TT explain select a from t where cast(a as string) = '100'; diff --git a/datafusion/sqllogictest/test_files/spark/conversion/cast_float_to_timestamp.slt b/datafusion/sqllogictest/test_files/spark/conversion/cast_float_to_timestamp.slt new file mode 100644 index 0000000000000..68bf340a8fc4f --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/conversion/cast_float_to_timestamp.slt @@ -0,0 +1,228 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Tests for casting float types to timestamp using spark_cast + +# Test spark_cast from float64 to timestamp +query P +SELECT spark_cast(0.0, 'timestamp'); +---- +1970-01-01T00:00:00Z + +query P +SELECT spark_cast(1.0, 'timestamp'); +---- +1970-01-01T00:00:01Z + +query P +SELECT spark_cast(-1.0, 'timestamp'); +---- +1969-12-31T23:59:59Z + +# Test fractional seconds +query P +SELECT spark_cast(1.5, 'timestamp'); +---- +1970-01-01T00:00:01.500Z + +query P +SELECT spark_cast(-1.5, 'timestamp'); +---- +1969-12-31T23:59:58.500Z + +query P +SELECT spark_cast(1.123456, 'timestamp'); +---- +1970-01-01T00:00:01.123456Z + +# Test larger values +query P +SELECT spark_cast(1704067200.0, 'timestamp'); +---- +2024-01-01T00:00:00Z + +query P +SELECT spark_cast(1704067200.123456, 'timestamp'); +---- +2024-01-01T00:00:00.123456Z + +# Test spark_cast from float32 to timestamp +query P +SELECT spark_cast(arrow_cast(0.0, 'Float32'), 'timestamp'); +---- +1970-01-01T00:00:00Z + +query P +SELECT spark_cast(arrow_cast(1.5, 'Float32'), 'timestamp'); +---- +1970-01-01T00:00:01.500Z + +query P +SELECT spark_cast(arrow_cast(-1.5, 'Float32'), 'timestamp'); +---- +1969-12-31T23:59:58.500Z + +# Test NULL handling +query P +SELECT spark_cast(arrow_cast(NULL, 'Float32'), 'timestamp'); +---- +NULL + +query P +SELECT spark_cast(arrow_cast(NULL, 'Float64'), 'timestamp'); +---- +NULL + +# Test NaN and Infinity in non-ANSI mode (default) - should return NULL +query P +SELECT spark_cast(arrow_cast('NaN', 'Float64'), 'timestamp'); +---- +NULL + +query P +SELECT spark_cast(arrow_cast('Infinity', 
'Float64'), 'timestamp'); +---- +NULL + +query P +SELECT spark_cast(arrow_cast('-Infinity', 'Float64'), 'timestamp'); +---- +NULL + +# Test negative values +query P +SELECT spark_cast(-86400.0, 'timestamp'); +---- +1969-12-31T00:00:00Z + +query P +SELECT spark_cast(-86400.5, 'timestamp'); +---- +1969-12-30T23:59:59.500Z + +# Test with timezone America/Los_Angeles +statement ok +SET datafusion.execution.time_zone = 'America/Los_Angeles'; + +query P +SELECT spark_cast(0.0, 'timestamp'); +---- +1969-12-31T16:00:00-08:00 + +query P +SELECT spark_cast(1704067200.0, 'timestamp'); +---- +2023-12-31T16:00:00-08:00 + +# Reset to UTC +statement ok +SET datafusion.execution.time_zone = 'UTC'; + +############################# +# Array Tests +############################# + +# Create test table with float columns +statement ok +CREATE TABLE float_test AS SELECT + arrow_cast(column1, 'Float32') as f32_col, + column2 as f64_col +FROM (VALUES + (NULL, NULL), + (0.0, 0.0), + (1.5, 1.5), + (-1.5, -1.5), + (1704067200.0, 1704067200.123456) +); + +# Test in UTC +query PP +SELECT spark_cast(f32_col, 'timestamp'), spark_cast(f64_col, 'timestamp') FROM float_test; +---- +NULL NULL +1970-01-01T00:00:00Z 1970-01-01T00:00:00Z +1970-01-01T00:00:01.500Z 1970-01-01T00:00:01.500Z +1969-12-31T23:59:58.500Z 1969-12-31T23:59:58.500Z +2024-01-01T00:00:00Z 2024-01-01T00:00:00.123456Z + +# Test with NaN and Infinity in array +statement ok +CREATE TABLE float_special AS SELECT + column1 as f64_col +FROM (VALUES + (1.0), + (arrow_cast('NaN', 'Float64')), + (arrow_cast('Infinity', 'Float64')), + (arrow_cast('-Infinity', 'Float64')), + (2.0) +); + +# NaN and Infinity should return NULL in non-ANSI mode +query P +SELECT spark_cast(f64_col, 'timestamp') FROM float_special; +---- +1970-01-01T00:00:01Z +NULL +NULL +NULL +1970-01-01T00:00:02Z + +# Cleanup +statement ok +DROP TABLE float_test; + +statement ok +DROP TABLE float_special; + +# Note: Overflow saturation tests (1e19, -1e19) are not included here 
because +# DataFusion's timestamp formatter cannot display i64::MAX/MIN microsecond values. +# The saturation behavior (matching Spark) is verified in unit tests: +# test_cast_float_overflow_non_ansi_mode and test_cast_float_negative_overflow_non_ansi_mode + +############################# +# ANSI Mode Tests +############################# + +# Enable ANSI mode +statement ok +SET datafusion.execution.enable_ansi_mode = true; + +# NaN should error in ANSI mode +statement error +SELECT spark_cast(arrow_cast('NaN', 'Float64'), 'timestamp'); + +# Infinity should error in ANSI mode +statement error +SELECT spark_cast(arrow_cast('Infinity', 'Float64'), 'timestamp'); + +# Very large value should error due to overflow +statement error +SELECT spark_cast(1e19, 'timestamp'); + +# Normal values should still work in ANSI mode +query P +SELECT spark_cast(1.5, 'timestamp'); +---- +1970-01-01T00:00:01.500Z + +# Reset ANSI mode +statement ok +SET datafusion.execution.enable_ansi_mode = false; + +# Reset time_zone to NULL (default) +statement ok +RESET datafusion.execution.time_zone; diff --git a/datafusion/sqllogictest/test_files/spark/math/round.slt b/datafusion/sqllogictest/test_files/spark/math/round.slt index bc1f6b72247a0..91c5bdf0506f5 100644 --- a/datafusion/sqllogictest/test_files/spark/math/round.slt +++ b/datafusion/sqllogictest/test_files/spark/math/round.slt @@ -15,13 +15,567 @@ # specific language governing permissions and limitations # under the License. -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function # This file is part of the implementation of the datafusion-spark function library. 
# For more information, please see: # https://github.com/apache/datafusion/issues/15914 -## Original Query: SELECT round(2.5, 0); -## PySpark 3.5.5 Result: {'round(2.5, 0)': Decimal('3'), 'typeof(round(2.5, 0))': 'decimal(2,0)', 'typeof(2.5)': 'decimal(2,1)', 'typeof(0)': 'int'} -#query -#SELECT round(2.5::decimal(2,1), 0::int); +# ------------------------------------------------------------------- +# Float / Double tests (HALF_UP rounding: .5 rounds away from zero) +# ------------------------------------------------------------------- + +# round(double) default scale = 0 +query R +SELECT round(2.5::double); +---- +3 + +query R +SELECT round(3.5::double); +---- +4 + +query R +SELECT round(-2.5::double); +---- +-3 + +query R +SELECT round(-3.5::double); +---- +-4 + +query R +SELECT round(1.4::double); +---- +1 + +query R +SELECT round(1.6::double); +---- +2 + +# round(double, scale) +query R +SELECT round(2.345::double, 2::int); +---- +2.35 + +query R +SELECT round(2.355::double, 2::int); +---- +2.36 + +query R +SELECT round(123.456::double, 1::int); +---- +123.5 + +# round(float) +query R +SELECT round(arrow_cast(2.5, 'Float32')); +---- +3 + +query R +SELECT round(arrow_cast(-2.5, 'Float32')); +---- +-3 + +# round(float) with scale +query R +SELECT round(arrow_cast(2.345, 'Float32'), 2::int); +---- +2.35 + +# round(float32) with negative scale +query R +SELECT round(arrow_cast(125.0, 'Float32'), -1::int); +---- +130 + +# ------------------------------------------------------------------- +# Float16 tests +# ------------------------------------------------------------------- + +# round(float16) default scale = 0 +query R +SELECT round(arrow_cast(2.5, 'Float16')); +---- +3 + +query R +SELECT round(arrow_cast(-2.5, 'Float16')); +---- +-3 + +# round(float16) with negative scale +query R +SELECT round(arrow_cast(125, 'Float16'), -1::int); +---- +130 + +# round(double) with negative scale +query R +SELECT round(125.0::double, -1::int); +---- +130 + +query R +SELECT 
round(125.0::double, -2::int); +---- +100 + +query R +SELECT round(0.0::double); +---- +0 + +# round(double) with Infinity and NaN +query R +SELECT round('Infinity'::double); +---- +Infinity + +query R +SELECT round('-Infinity'::double); +---- +-Infinity + +query R +SELECT round('NaN'::double); +---- +NaN + +query R +SELECT round('Infinity'::double, 2::int); +---- +Infinity + +query R +SELECT round('NaN'::double, 2::int); +---- +NaN + +# round(double) with extreme negative scale — should return 0, not NaN +query R +SELECT round(42.0::double, -400::int); +---- +0 + +# round(double) with extreme positive scale — should return value as-is +query R +SELECT round(2.5::double, 400::int); +---- +2.5 + +# ------------------------------------------------------------------- +# Integer tests (negative scale rounds to tens, hundreds, etc.) +# ------------------------------------------------------------------- + +# round(int, -1) — round to nearest 10 +query I +SELECT round(25::int, -1::int); +---- +30 + +query I +SELECT round(24::int, -1::int); +---- +20 + +query I +SELECT round(-25::int, -1::int); +---- +-30 + +query I +SELECT round(123::int, -1::int); +---- +120 + +# round(int, -2) — round to nearest 100 +query I +SELECT round(150::int, -2::int); +---- +200 + +query I +SELECT round(149::int, -2::int); +---- +100 + +# round(int, positive scale) — no-op for integers +query I +SELECT round(42::int, 2::int); +---- +42 + +# round(int) default scale = 0 — returns unchanged +query I +SELECT round(42::int); +---- +42 + +# round(bigint, -1) +query I +SELECT round(25::bigint, -1::int); +---- +30 + +# round(smallint, -1) +query I +SELECT round(25::smallint, -1::int); +---- +30 + +# round(tinyint, -1) +query I +SELECT round(arrow_cast(25, 'Int8'), -1::int); +---- +30 + +# round(int) with very large negative scale — should return 0 +query I +SELECT round(42::int, -10::int); +---- +0 + +# ------------------------------------------------------------------- +# Unsigned integer tests +# 
------------------------------------------------------------------- + +# round(uint8, -1) +query I +SELECT round(arrow_cast(25, 'UInt8'), -1::int); +---- +30 + +# round(uint16, -1) +query I +SELECT round(arrow_cast(25, 'UInt16'), -1::int); +---- +30 + +# round(uint32, -2) +query I +SELECT round(arrow_cast(150, 'UInt32'), -2::int); +---- +200 + +# round(uint64, -1) +query I +SELECT round(arrow_cast(25, 'UInt64'), -1::int); +---- +30 + +# round(uint32, positive scale) — no-op for integers +query I +SELECT round(arrow_cast(42, 'UInt32'), 2::int); +---- +42 + +# ------------------------------------------------------------------- +# Decimal tests (HALF_UP rounding) +# ------------------------------------------------------------------- + +# --- Decimal32 --- + +# round(decimal32, 0) — round to integer +query ? +SELECT round(arrow_cast(2.5, 'Decimal32(9, 1)'), 0::int); +---- +3.0 + +query ? +SELECT round(arrow_cast(-2.5, 'Decimal32(9, 1)'), 0::int); +---- +-3.0 + +# round(decimal32, 2) +query ? +SELECT round(arrow_cast(2.345, 'Decimal32(9, 3)'), 2::int); +---- +2.350 + +# round(decimal32) default scale = 0 +query ? +SELECT round(arrow_cast(3.5, 'Decimal32(9, 1)')); +---- +4.0 + +# --- Decimal64 --- + +# round(decimal64, 0) — round to integer +query ? +SELECT round(arrow_cast(2.5, 'Decimal64(18, 1)'), 0::int); +---- +3.0 + +query ? +SELECT round(arrow_cast(-2.5, 'Decimal64(18, 1)'), 0::int); +---- +-3.0 + +# round(decimal64, 2) +query ? +SELECT round(arrow_cast(2.345, 'Decimal64(18, 3)'), 2::int); +---- +2.350 + +# round(decimal64) default scale = 0 +query ? 
+SELECT round(arrow_cast(3.5, 'Decimal64(18, 1)')); +---- +4.0 + +# --- Decimal128 --- + +# round(decimal, 0) — round to integer +query R +SELECT round(2.5::decimal(2,1), 0::int); +---- +3 + +query R +SELECT round(3.5::decimal(2,1), 0::int); +---- +4 + +query R +SELECT round(-2.5::decimal(2,1), 0::int); +---- +-3 + +# round(decimal) default scale = 0 +query R +SELECT round(2.5::decimal(2,1)); +---- +3 + +# round(decimal, 2) — keep 2 decimal places +query R +SELECT round(2.345::decimal(10,3), 2::int); +---- +2.35 + +query R +SELECT round(2.355::decimal(10,3), 2::int); +---- +2.36 + +# round(decimal, scale larger than input scale) — no change +query R +SELECT round(2.5::decimal(2,1), 5::int); +---- +2.5 + +# round(decimal, 1) +query R +SELECT round(123.456::decimal(10,3), 1::int); +---- +123.5 + +# round(decimal, negative scale) — round to tens +query R +SELECT round(125.0::decimal(10,1), -1::int); +---- +130 + +# round(decimal, extreme negative scale) — should return 0, not error +query R +SELECT round(2.5::decimal(10,1), -400::int); +---- +0 + +# --- Decimal256 --- + +# round(decimal256, 0) — round to integer +query R +SELECT round(arrow_cast(2.5, 'Decimal256(38, 1)'), 0::int); +---- +3 + +query R +SELECT round(arrow_cast(-2.5, 'Decimal256(38, 1)'), 0::int); +---- +-3 + +# round(decimal256, 2) +query R +SELECT round(arrow_cast(2.345, 'Decimal256(38, 3)'), 2::int); +---- +2.35 + +# round(decimal256) default scale = 0 +query R +SELECT round(arrow_cast(3.5, 'Decimal256(38, 1)')); +---- +4 + +# ------------------------------------------------------------------- +# NULL handling +# ------------------------------------------------------------------- + +query I +SELECT round(NULL::int); +---- +NULL + +query R +SELECT round(NULL::double); +---- +NULL + +query R +SELECT round(NULL::decimal(10,2)); +---- +NULL + +# round with NULL scale — Spark returns NULL +query I +SELECT round(42::int, NULL::int); +---- +NULL + +query R +SELECT round(2.5::double, NULL::int); +---- +NULL + 
+# ------------------------------------------------------------------- +# Column-based tests +# ------------------------------------------------------------------- + +statement ok +CREATE TABLE test_round (id int, int_val int, float_val double, dec_val decimal(10,3)) AS VALUES + (1, 25, 2.5, 2.345), + (2, 35, 3.5, 3.555), + (3, -25, -2.5, -2.345), + (4, 123, 1.4, 1.005), + (5, NULL, NULL, NULL); + +query IIRR rowsort +SELECT id, round(int_val, -1::int), round(float_val), round(dec_val, 2::int) FROM test_round; +---- +1 30 3 2.35 +2 40 4 3.56 +3 -30 -3 -2.35 +4 120 1 1.01 +5 NULL NULL NULL + +statement ok +DROP TABLE test_round; + +# ------------------------------------------------------------------- +# Expression tests +# ------------------------------------------------------------------- + +query R +SELECT round(3.14159::double, 2::int) + 1.0; +---- +4.140000000000001 + +# ------------------------------------------------------------------- +# Non-ANSI wrapping behavior +# When ANSI mode is off, integer overflow wraps silently. 
+# ------------------------------------------------------------------- + +# round(127::tinyint, -1) → 130, wraps as i8 → -126 +query I +SELECT round(arrow_cast(127, 'Int8'), -1::int); +---- +-126 + +# round(32767::smallint, -1) → 32770, wraps as i16 → -32766 +query I +SELECT round(arrow_cast(32767, 'Int16'), -1::int); +---- +-32766 + +# round(2147483647::int, -1) → 2147483650, wraps as i32 → -2147483646 +query I +SELECT round(2147483647::int, -1::int); +---- +-2147483646 + +# round(i64::MAX, -1) wraps as i64 → -9223372036854775806 +query I +SELECT round(9223372036854775807::bigint, -1::int); +---- +-9223372036854775806 + +# ------------------------------------------------------------------- +# ANSI mode tests: overflow detection for integer rounding +# ------------------------------------------------------------------- + +statement ok +set datafusion.execution.enable_ansi_mode = true; + +# ANSI mode: normal rounding should still work +query I +SELECT round(25::int, -1::int); +---- +30 + +query I +SELECT round(-25::int, -1::int); +---- +-30 + +query I +SELECT round(150::int, -2::int); +---- +200 + +# ANSI mode: positive scale on integers — no-op, no overflow +query I +SELECT round(42::int, 2::int); +---- +42 + +# ANSI mode: floats and decimals should work normally +query R +SELECT round(2.5::double); +---- +3 + +query R +SELECT round(2.5::decimal(2,1), 0::int); +---- +3 + +# ANSI mode: integer overflow should error +query error DataFusion error: Execution error: Int64 overflow on round +SELECT round(9223372036854775807::bigint, -1::int); + +# ANSI mode: Int32 overflow should error +query error DataFusion error: Execution error: Int32 overflow on round +SELECT round(2147483647::int, -1::int); + +# ANSI mode: Int16 overflow should error +query error DataFusion error: Execution error: Int16 overflow on round +SELECT round(arrow_cast(32767, 'Int16'), -1::int); + +# ANSI mode: Int8 overflow should error +query error DataFusion error: Execution error: Int8 overflow on 
round +SELECT round(arrow_cast(127, 'Int8'), -1::int); + +# Reset ANSI mode +statement ok +set datafusion.execution.enable_ansi_mode = false; + +# ------------------------------------------------------------------- +# Negative tests: unsupported data types +# ------------------------------------------------------------------- + +# round(string) should fail +query error Error during planning: Internal error: Function 'round' failed to match any signature +SELECT round('hello'::varchar); + +# round(boolean) should fail +query error Error during planning: Internal error: Function 'round' failed to match any signature +SELECT round(true); + +# round(timestamp) should fail +query error Error during planning: Internal error: Function 'round' failed to match any signature +SELECT round('2023-01-01T00:00:00'::timestamp); diff --git a/datafusion/sqllogictest/test_files/string/string_query.slt.part b/datafusion/sqllogictest/test_files/string/string_query.slt.part index 2884c3518610d..679ba0aa8a888 100644 --- a/datafusion/sqllogictest/test_files/string/string_query.slt.part +++ b/datafusion/sqllogictest/test_files/string/string_query.slt.part @@ -41,38 +41,17 @@ NULL R NULL 🔥 # -------------------------------------- # test type coercion (compare to int) -# queries should not error +# +# Comparing a string column to an integer literal is allowed but will fail +# at runtime if the string column contains any values that can't be cast +# to integers. 
# -------------------------------------- -query BB +statement error Arrow error: Cast error: Cannot cast string 'Andrew' to value of Int64 type select ascii_1 = 1 as col1, 1 = ascii_1 as col2 from test_basic_operator; ----- -false false -false false -false false -false false -false false -false false -false false -false false -false false -NULL NULL -NULL NULL -query BB +statement error Arrow error: Cast error: Cannot cast string 'Andrew' to value of Int64 type select ascii_1 <> 1 as col1, 1 <> ascii_1 as col2 from test_basic_operator; ----- -true true -true true -true true -true true -true true -true true -true true -true true -true true -NULL NULL -NULL NULL # Coercion to date/time query BBB diff --git a/datafusion/sqllogictest/test_files/string_numeric_coercion.slt b/datafusion/sqllogictest/test_files/string_numeric_coercion.slt new file mode 100644 index 0000000000000..1567a149bcdf4 --- /dev/null +++ b/datafusion/sqllogictest/test_files/string_numeric_coercion.slt @@ -0,0 +1,584 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## Tests for string-numeric comparison coercion +## Verifies that when comparing a numeric column to a string literal, +## the comparison is performed numerically (not lexicographically). 
+## See: https://github.com/apache/datafusion/issues/15161 +########## + +# Setup test data +statement ok +CREATE TABLE t_int AS VALUES (1), (5), (325), (499), (1000); + +statement ok +CREATE TABLE t_float AS VALUES (1.5), (5.0), (325.7), (499.9), (1000.1); + +# ------------------------------------------------- +# Integer column with comparison operators vs string literals. +# Ensure that the comparison is done with numeric semantics, +# not lexicographically. +# ------------------------------------------------- + +query I rowsort +SELECT * FROM t_int WHERE column1 < '5'; +---- +1 + +query I rowsort +SELECT * FROM t_int WHERE column1 > '5'; +---- +1000 +325 +499 + +query I rowsort +SELECT * FROM t_int WHERE column1 <= '5'; +---- +1 +5 + +query I rowsort +SELECT * FROM t_int WHERE column1 >= '5'; +---- +1000 +325 +499 +5 + +query I rowsort +SELECT * FROM t_int WHERE column1 = '5'; +---- +5 + +query I rowsort +SELECT * FROM t_int WHERE column1 != '5'; +---- +1 +1000 +325 +499 + +query I rowsort +SELECT * FROM t_int WHERE column1 < '10'; +---- +1 +5 + +query I rowsort +SELECT * FROM t_int WHERE column1 <= '100'; +---- +1 +5 + +query I rowsort +SELECT * FROM t_int WHERE column1 > '100'; +---- +1000 +325 +499 + +# ------------------------------------------------- +# Float column with comparison operators vs string literals +# ------------------------------------------------- + +query R rowsort +SELECT * FROM t_float WHERE column1 < '5'; +---- +1.5 + +query R rowsort +SELECT * FROM t_float WHERE column1 > '5'; +---- +1000.1 +325.7 +499.9 + +query R rowsort +SELECT * FROM t_float WHERE column1 = '5'; +---- +5 + +query R rowsort +SELECT * FROM t_float WHERE column1 = '5.0'; +---- +5 + +# ------------------------------------------------- +# Error on strings that cannot be cast to the numeric column type +# ------------------------------------------------- + +# Non-numeric string against integer column +statement error Arrow error: Cast error: Cannot cast string 'hello' to 
value of Int64 type +SELECT * FROM t_int WHERE column1 < 'hello'; + +# Non-numeric string against float column +statement error Arrow error: Cast error: Cannot cast string 'hello' to value of Float64 type +SELECT * FROM t_float WHERE column1 < 'hello'; + +# Float string against integer column +statement error Arrow error: Cast error: Cannot cast string '99.99' to value of Int64 type +SELECT * FROM t_int WHERE column1 = '99.99'; + +# Empty string against integer column +statement error Arrow error: Cast error: Cannot cast string '' to value of Int64 type +SELECT * FROM t_int WHERE column1 = ''; + +# Empty string against float column +statement error Arrow error: Cast error: Cannot cast string '' to value of Float64 type +SELECT * FROM t_float WHERE column1 = ''; + +# Overflow +statement error Arrow error: Cast error: Cannot cast string '99999999999999999999' to value of Int64 type +SELECT * FROM t_int WHERE column1 = '99999999999999999999'; + + +# ------------------------------------------------- +# UNION still uses string coercion (type unification context) +# ------------------------------------------------- + +statement ok +CREATE TABLE t_str AS VALUES ('one'), ('two'), ('three'); + +query T rowsort +SELECT column1 FROM t_int UNION ALL SELECT column1 FROM t_str; +---- +1 +1000 +325 +499 +5 +one +three +two + +# Verify the UNION coerces to Utf8 (not numeric) +query TT +EXPLAIN SELECT column1 FROM t_int UNION ALL SELECT column1 FROM t_str; +---- +logical_plan +01)Union +02)--Projection: CAST(t_int.column1 AS Utf8) AS column1 +03)----TableScan: t_int projection=[column1] +04)--TableScan: t_str projection=[column1] +physical_plan +01)UnionExec +02)--ProjectionExec: expr=[CAST(column1@0 AS Utf8) as column1] +03)----DataSourceExec: partitions=1, partition_sizes=[1] +04)--DataSourceExec: partitions=1, partition_sizes=[1] + +# ------------------------------------------------- +# BETWEEN uses comparison coercion (numeric preferred) +# 
------------------------------------------------- + +query I rowsort +SELECT * FROM t_int WHERE column1 BETWEEN '5' AND '100'; +---- +5 + +# ------------------------------------------------- +# IN list uses comparison coercion (numeric preferred) +# `x IN (a, b)` is semantically equivalent to `x = a OR x = b` +# ------------------------------------------------- + +# Basic IN list with string literals against integer column +query I rowsort +SELECT * FROM t_int WHERE column1 IN ('5', '325'); +---- +325 +5 + +# IN list with a value where numeric coercion matters +query I rowsort +SELECT * FROM t_int WHERE column1 IN ('1000'); +---- +1000 + +# IN list with NOT +query I rowsort +SELECT * FROM t_int WHERE column1 NOT IN ('1', '5'); +---- +1000 +325 +499 + +# IN list with float column +query R rowsort +SELECT * FROM t_float WHERE column1 IN ('5.0', '325.7'); +---- +325.7 +5 + +# Verify the plan shows numeric coercion (not CAST to Utf8) +query TT +EXPLAIN SELECT * FROM t_int WHERE column1 IN ('5', '325'); +---- +logical_plan +01)Filter: t_int.column1 = Int64(5) OR t_int.column1 = Int64(325) +02)--TableScan: t_int projection=[column1] +physical_plan +01)FilterExec: column1@0 = 5 OR column1@0 = 325 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +# Error on invalid string in IN list +statement error Arrow error: Cast error: Cannot cast string 'hello' to value of Int64 type +SELECT * FROM t_int WHERE column1 IN ('5', 'hello'); + +# Mixed numeric literal and string literal in IN list against integer column +query I rowsort +SELECT * FROM t_int WHERE column1 IN (5, '325'); +---- +325 +5 + +# Mixed numeric literal and string literal in IN list against float column +query R rowsort +SELECT * FROM t_float WHERE column1 IN (5, '325.7'); +---- +325.7 +5 + +# String and numeric literal order reversed +query R rowsort +SELECT * FROM t_float WHERE column1 IN ('5', 325.7); +---- +325.7 +5 + +# Float string literal against integer column errors (cannot cast '10.0' to Int64) 
+statement error Arrow error: Cast error: Cannot cast string '10.0' to value of Int64 type +SELECT * FROM t_int WHERE column1 IN (5, '10.0'); + +# Non-numeric string in mixed IN list still errors +statement error Arrow error: Cast error: Cannot cast string 'hello' to value of Int64 type +SELECT * FROM t_int WHERE column1 IN ('hello', 5); + +# ------------------------------------------------- +# CASE WHEN uses comparison coercion for conditions (numeric preferred) +# `CASE expr WHEN val` is semantically equivalent to `expr = val` +# ------------------------------------------------- + +# Basic CASE with integer column and string WHEN values +query T rowsort +SELECT CASE column1 WHEN '5' THEN 'five' WHEN '1000' THEN 'thousand' ELSE 'other' END FROM t_int; +---- +five +other +other +other +thousand + +# CASE with float column: '5' is cast to 5.0 numerically, matching the row. +# (Under string comparison, '5.0' != '5' would fail to match.) +query T rowsort +SELECT CASE column1 WHEN '5' THEN 'matched' ELSE 'no match' END FROM t_float; +---- +matched +no match +no match +no match +no match + +# THEN/ELSE results still use type union coercion (string preferred), +# so mixing numeric and string coerces to string +query T rowsort +SELECT CASE WHEN column1 > 500 THEN column1 ELSE 'small' END FROM t_int; +---- +1000 +small +small +small +small + +# ------------------------------------------------- +# GREATEST / LEAST use comparison coercion (numeric preferred) +# ------------------------------------------------- + +# GREATEST with mixed int and string: numeric comparison, not lexicographic. +query I +SELECT GREATEST(10, '9'); +---- +10 + +query T +SELECT arrow_typeof(GREATEST(10, '9')); +---- +Int64 + +# LEAST with mixed int and string: numeric comparison. 
+query I +SELECT LEAST(10, '9'); +---- +9 + +query T +SELECT arrow_typeof(LEAST(10, '9')); +---- +Int64 + +# GREATEST with multiple mixed args +query I +SELECT GREATEST(1, '20', 3); +---- +20 + +# Non-numeric string in GREATEST errors +statement error Arrow error: Cast error: Cannot cast string 'hello' to value of Int64 type +SELECT GREATEST(1, 'hello'); + +# ------------------------------------------------- +# NULLIF uses comparison coercion (numeric preferred) +# ------------------------------------------------- + +# NULLIF with mixed int and string: numeric comparison. +# 10 != 9 numerically, so returns 10. +query I +SELECT NULLIF(10, '9'); +---- +10 + +query T +SELECT arrow_typeof(NULLIF(10, '9')); +---- +Int64 + +# NULLIF with matching values: 10 = 10 numerically, so returns NULL. +query I +SELECT NULLIF(10, '10'); +---- +NULL + +# ------------------------------------------------- +# Nested struct/map/list comparisons use comparison coercion +# (numeric preferred) for their field/element types +# ------------------------------------------------- + +statement ok +CREATE TABLE t_struct_int AS SELECT named_struct('val', column1) as s FROM (VALUES (1), (5), (10)); + +statement ok +CREATE TABLE t_struct_str AS SELECT named_struct('val', column1) as s FROM (VALUES ('5'), ('10')); + +# Struct comparison: the string field is cast to Int64 (numeric preferred). +query ? rowsort +SELECT t1.s FROM t_struct_int t1, t_struct_str t2 WHERE t1.s = t2.s; +---- +{val: 10} +{val: 5} + +# Struct in UNION uses type union coercion (string preferred). +# The integer struct field is cast to Utf8. +query ? rowsort +SELECT s FROM t_struct_int UNION ALL SELECT s FROM t_struct_str; +---- +{val: 10} +{val: 10} +{val: 1} +{val: 5} +{val: 5} + +statement ok +DROP TABLE t_struct_int; + +statement ok +DROP TABLE t_struct_str; + +# List comparison: string elements are cast to Int64 (numeric preferred). 
+statement ok +CREATE TABLE t_list_int AS SELECT column1 as l FROM (VALUES ([1, 5, 10]), ([20, 30])); + +statement ok +CREATE TABLE t_list_str AS SELECT column1 as l FROM (VALUES (['5', '10']), (['20', '30'])); + +# Verify the element types are Int64 and Utf8 respectively +query T +SELECT arrow_typeof(l) FROM t_list_int LIMIT 1; +---- +List(Int64) + +query T +SELECT arrow_typeof(l) FROM t_list_str LIMIT 1; +---- +List(Utf8) + +query ? rowsort +SELECT t1.l FROM t_list_int t1, t_list_str t2 WHERE t1.l = t2.l; +---- +[20, 30] + +# List comparison casts the string list to the numeric element type (Int64), +# not the other way around. +query TT +EXPLAIN SELECT t1.l FROM t_list_int t1, t_list_str t2 WHERE t1.l = t2.l; +---- +logical_plan +01)Projection: t1.l +02)--Inner Join: t1.l = CAST(t2.l AS List(Int64)) +03)----SubqueryAlias: t1 +04)------TableScan: t_list_int projection=[l] +05)----SubqueryAlias: t2 +06)------TableScan: t_list_str projection=[l] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(l@0, CAST(t2.l AS List(Int64))@1)], projection=[l@0] +02)--DataSourceExec: partitions=1, partition_sizes=[1] +03)--ProjectionExec: expr=[l@0 as l, CAST(l@0 AS List(Int64)) as CAST(t2.l AS List(Int64))] +04)----DataSourceExec: partitions=1, partition_sizes=[1] + +# List in UNION uses type union coercion (string preferred). +# The integer list elements are cast to Utf8. +query ? rowsort +SELECT l FROM t_list_int UNION ALL SELECT l FROM t_list_str; +---- +[1, 5, 10] +[20, 30] +[20, 30] +[5, 10] + +# Verify the UNION result type has Utf8 elements (not Int64). +query T +SELECT arrow_typeof(l) FROM (SELECT l FROM t_list_int UNION ALL SELECT l FROM t_list_str) LIMIT 1; +---- +List(Utf8) + +statement ok +DROP TABLE t_list_int; + +statement ok +DROP TABLE t_list_str; + +# Map comparison: string values are cast to Int64 (numeric preferred). 
+statement ok +CREATE TABLE t_map_int AS SELECT MAP {'a': 1, 'b': 5} as m; + +statement ok +CREATE TABLE t_map_str AS SELECT MAP {'a': '1', 'b': '5'} as m; + +# Verify the value types are Int64 and Utf8 respectively +query T +SELECT arrow_typeof(m) FROM t_map_int LIMIT 1; +---- +Map("entries": non-null Struct("key": non-null Utf8, "value": Int64), unsorted) + +query T +SELECT arrow_typeof(m) FROM t_map_str LIMIT 1; +---- +Map("entries": non-null Struct("key": non-null Utf8, "value": Utf8), unsorted) + +query ? rowsort +SELECT t1.m FROM t_map_int t1, t_map_str t2 WHERE t1.m = t2.m; +---- +{a: 1, b: 5} + +# Map in UNION uses type union coercion (string preferred). +# The integer map values are cast to Utf8. +query ? rowsort +SELECT m FROM t_map_int UNION ALL SELECT m FROM t_map_str; +---- +{a: 1, b: 5} +{a: 1, b: 5} + +# Verify the UNION result type has Utf8 values (not Int64). +query T +SELECT arrow_typeof(m) FROM (SELECT m FROM t_map_int UNION ALL SELECT m FROM t_map_str) LIMIT 1; +---- +Map("entries": non-null Struct("key": non-null Utf8, "value": Utf8), unsorted) + +statement ok +DROP TABLE t_map_int; + +statement ok +DROP TABLE t_map_str; + +# ------------------------------------------------- +# LIKE / regex on dictionary-encoded numeric columns should error, +# consistent with LIKE on plain numeric columns +# ------------------------------------------------- + +# Plain integer column: LIKE is not supported +statement error There isn't a common type to coerce Int64 and Utf8 in LIKE expression +SELECT * FROM t_int WHERE column1 LIKE '%5%'; + +# Dictionary-encoded integer column: should also error +statement error There isn't a common type to coerce Dictionary\(Int32, Int64\) and Utf8 in LIKE expression +SELECT arrow_cast(column1, 'Dictionary(Int32, Int64)') LIKE '%5%' FROM t_int; + +# Dictionary-encoded string column: LIKE works as normal +query B rowsort +SELECT arrow_cast('hello', 'Dictionary(Int32, Utf8)') LIKE '%ell%'; +---- +true + +# REE-encoded integer 
column: LIKE should also error +statement error There isn't a common type to coerce RunEndEncoded.* and Utf8 in LIKE expression +SELECT arrow_cast(column1, 'RunEndEncoded("run_ends": non-null Int32, "values": Int64)') LIKE '%5%' FROM t_int; + +# REE-encoded string column: LIKE works as normal +query B rowsort +SELECT arrow_cast('hello', 'RunEndEncoded("run_ends": non-null Int32, "values": Utf8)') LIKE '%ell%'; +---- +true + +# Dictionary-encoded integer column: regex should error +statement error Cannot infer common argument type for regex operation +SELECT arrow_cast(column1, 'Dictionary(Int32, Int64)') ~ '5' FROM t_int; + +# Dictionary-encoded string column: regex works as normal +query B rowsort +SELECT arrow_cast('hello', 'Dictionary(Int32, Utf8)') ~ 'ell'; +---- +true + +# REE-encoded integer column: regex should error +statement error Cannot infer common argument type for regex operation +SELECT arrow_cast(column1, 'RunEndEncoded("run_ends": non-null Int32, "values": Int64)') ~ '5' FROM t_int; + +# REE-encoded string column: regex works as normal +query B rowsort +SELECT arrow_cast('hello', 'RunEndEncoded("run_ends": non-null Int32, "values": Utf8)') ~ 'ell'; +---- +true + +# ------------------------------------------------- +# Cleanup +# ------------------------------------------------- + +statement ok +DROP TABLE t_int; + +statement ok +DROP TABLE t_float; + +statement ok +DROP TABLE t_str; + +# ------------------------------------------------- +# List element coercion should reject mixed +# numeric/string categories (same as array literals) +# ------------------------------------------------- + +# Array literal with mixed numeric/string elements errors +query error Cannot cast string 'a' to value of Int64 type +SELECT [1, 'a']; + +# MAP with mixed-category list keys should also error +query error Cannot cast string 'a' to value of Int64 type +SELECT MAP {[1,2,3]:1, ['a', 'b']:2}; + +# MAP with mixed-category list values should also error +query error 
Cannot cast string 'a' to value of Int64 type +SELECT MAP {'a':[1,2,3], 'b':['a', 'b']}; diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt index fcc7805372674..5cf6e4817d475 100644 --- a/datafusion/sqllogictest/test_files/struct.slt +++ b/datafusion/sqllogictest/test_files/struct.slt @@ -1116,7 +1116,7 @@ select [ query ? select [[{x: 1, y: 2}], [{y: 4, x: 3}]]; ---- -[[{x: 1, y: 2}], [{x: 3, y: 4}]] +[[{y: 2, x: 1}], [{y: 4, x: 3}]] # Test array literal with float type coercion across elements query ? diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/project_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/project_rel.rs index d216d4ecf3188..0a4048650fa2b 100644 --- a/datafusion/substrait/src/logical_plan/consumer/rel/project_rel.rs +++ b/datafusion/substrait/src/logical_plan/consumer/rel/project_rel.rs @@ -50,6 +50,8 @@ pub async fn from_project_rel( // For WindowFunctions, we need to wrap them in a Window relation. If there are duplicates, // we can do the window'ing only once, then the project will duplicate the result. // Order here doesn't matter since LPB::window_plan sorts the expressions. 
+ #[allow(clippy::allow_attributes, clippy::mutable_key_type)] + // Expr contains Arc with interior mutability but is intentionally used as hash key let mut window_exprs: HashSet<Expr> = HashSet::new(); for expr in &p.expressions { let e = consumer diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 5dd4aa4e2be91..db2f4b587d923 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -568,7 +568,7 @@ async fn try_cast_decimal_to_int() -> Result<()> { #[tokio::test] async fn try_cast_decimal_to_string() -> Result<()> { - roundtrip("SELECT * FROM data WHERE a = TRY_CAST(b AS string)").await + roundtrip("SELECT * FROM data WHERE f = TRY_CAST(b AS string)").await } #[tokio::test] diff --git a/docs/source/library-user-guide/upgrading/54.0.0.md b/docs/source/library-user-guide/upgrading/54.0.0.md index 7fb1c4ece79ef..6dc08cc344e5f 100644 --- a/docs/source/library-user-guide/upgrading/54.0.0.md +++ b/docs/source/library-user-guide/upgrading/54.0.0.md @@ -25,6 +25,71 @@ in this section pertains to features and changes that have already been merged to the main branch and are awaiting release in this version. +### String/numeric comparison coercion now prefers numeric types + +Previously, comparing a numeric column with a string value (e.g., +`WHERE int_col > '100'`) coerced both sides to strings and performed a +lexicographic comparison. This produced surprising results — for example, +`5 > '100'` yielded `true` because `'5' > '1'` lexicographically, even +though `5 > 100` is `false` numerically. + +DataFusion now coerces the string side to the numeric type in comparison +contexts (`=`, `<`, `>`, `<=`, `>=`, `<>`, `IN`, `BETWEEN`, `CASE .. WHEN`, +`GREATEST`, `LEAST`). For example, `5 > '100'` will now yield `false`. 
+ +**Who is affected:** + +- Queries that compare numeric values with string values +- Queries that use `IN` lists with mixed string and numeric types +- Queries that use `CASE expr WHEN` with mixed string and numeric types +- Queries that use `GREATEST` or `LEAST` with mixed string and numeric types + +**Behavioral changes:** + +| Expression | Old behavior | New behavior | +| ----------------------- | ------------------------------- | ------------------------------------------ | +| `int_col > '100'` | Lexicographic | Numeric | +| `float_col = '5'` | String `'5' != '5.0'` | Numeric `5.0 = 5.0` | +| `int_col = 'hello'` | String comparison, always false | Cast error | +| `str_col IN ('a', 1)` | Coerce to Utf8 | Cast error (`'a'` cannot be cast to Int64) | +| `float_col IN ('1.0')` | String `'1.0' != '1'` | Numeric `1.0 = 1.0` (correct) | +| `CASE str_col WHEN 1.0` | Coerce to Utf8 | Coerce to Float64 | +| `GREATEST(10, '9')` | Utf8 `'9'` (lexicographic) | Int64 `10` (numeric) | +| `LEAST(10, '9')` | Utf8 `'10'` (lexicographic) | Int64 `9` (numeric) | + +**Migration guide:** + +Most queries will produce more correct results with no changes needed. +However, queries that relied on the old string-comparison behavior may need +adjustment: + +- **Queries comparing numeric columns with non-numeric strings** (e.g., + `int_col = 'hello'` or `int_col > text_col` where `text_col` contains + non-numeric values) will now produce a cast error instead of silently + returning no rows. +- **Mixed-type `IN` lists** (e.g., `str_col IN ('a', 1)`) are now rejected. Use + consistent types for the `IN` list or add an explicit `CAST`. +- **Queries comparing integer columns with non-integer numeric string literals** (e.g., + `int_col = '99.99'`) will now produce a cast error because `'99.99'` + cannot be cast to an integer. Use a float column or adjust the literal. 
+ +See [#15161](https://github.com/apache/datafusion/issues/15161) and +[PR #20426](https://github.com/apache/datafusion/pull/20426) for details. + +### `comparison_coercion_numeric` removed, replaced by `comparison_coercion` + +The `comparison_coercion_numeric` function has been removed. Its behavior +(preferring numeric types for string/numeric comparisons) is now the default in +`comparison_coercion`. A new function, `type_union_coercion`, handles contexts +where string types are preferred (`UNION`, `CASE THEN/ELSE`, `NVL2`). + +**Who is affected:** + +- Crates that call `comparison_coercion_numeric` directly +- Crates that call `comparison_coercion` and relied on its old + string-preferring behavior +- Crates that call `get_coerce_type_for_case_expression` + ### `ExecutionPlan::apply_expressions` is now a required method `apply_expressions` has been added as a **required** method on the `ExecutionPlan` trait (no default implementation). The same applies to the `FileSource` and `DataSource` traits. Any custom implementation of these traits must now implement `apply_expressions`.