76 changes: 52 additions & 24 deletions native/core/src/execution/planner.rs
@@ -126,8 +126,9 @@ use datafusion_comet_proto::{
use datafusion_comet_spark_expr::monotonically_increasing_id::MonotonicallyIncreasingId;
use datafusion_comet_spark_expr::{
ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct,
GetArrayStructFields, GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, RandExpr,
RandnExpr, SparkCastOptions, Stddev, SumDecimal, ToJson, UnboundColumn, Variance,
DecimalRescaleCheckOverflow, GetArrayStructFields, GetStructField, IfExpr, ListExtract,
NormalizeNaNAndZero, RandExpr, RandnExpr, SparkCastOptions, Stddev, SumDecimal, ToJson,
UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp,
};
use itertools::Itertools;
use jni::objects::GlobalRef;
@@ -376,10 +377,37 @@ impl PhysicalPlanner {
)))
}
ExprStruct::CheckOverflow(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
let child =
self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;
let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let fail_on_error = expr.fail_on_error;

// WideDecimalBinaryExpr already handles overflow — skip redundant check
if child
.as_any()
.downcast_ref::<WideDecimalBinaryExpr>()
.is_some()
{
return Ok(child);
}
CheckOverflow skip ignores its own data_type and fail_on_error

Medium Severity

When CheckOverflow wraps a WideDecimalBinaryExpr, the code returns the child directly without verifying that CheckOverflow's data_type (precision/scale) matches the WideDecimalBinaryExpr's output type. If Spark's plan specifies a different precision/scale in CheckOverflow than the binary expression's return_type, the output type would be wrong. The fail_on_error flag from CheckOverflow is also silently discarded.


Owner Author
value:useful; category:bug; feedback: The Bugbot AI reviewer is correct! Before fusing (by using WideDecimalBinaryExpr), the logic should check that the Cast's precision/scale matches the requested output precision/scale pair; this prevents computing a wrong rescaling when they don't match.
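The review point can be sketched as a standalone guard: only drop the CheckOverflow wrapper when the child already produces exactly the type CheckOverflow would enforce. A minimal sketch with illustrative names (`DecimalType`, `can_skip_check_overflow` are not the Comet API):

```rust
// Hypothetical sketch, not the Comet API: skip the CheckOverflow wrapper
// only when the child's output type already equals the requested type;
// otherwise keep the wrapper so the output type (and the fail_on_error
// semantics attached to it) is preserved.
#[derive(Debug, Clone, Copy, PartialEq)]
struct DecimalType {
    precision: u8,
    scale: i8,
}

fn can_skip_check_overflow(child_output: DecimalType, requested: DecimalType) -> bool {
    child_output == requested
}

fn main() {
    let wide = DecimalType { precision: 38, scale: 10 };
    // Types match: the wrapper is redundant and may be dropped.
    assert!(can_skip_check_overflow(wide, DecimalType { precision: 38, scale: 10 }));
    // Mismatched scale: dropping the wrapper would change the output type.
    assert!(!can_skip_check_overflow(wide, DecimalType { precision: 38, scale: 6 }));
}
```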


// Fuse Cast(Decimal128→Decimal128) + CheckOverflow into single rescale+check
if let Some(cast) = child.as_any().downcast_ref::<Cast>() {
if let (
DataType::Decimal128(p_out, s_out),
Ok(DataType::Decimal128(_p_in, s_in)),
) = (&data_type, cast.child.data_type(&input_schema))
{
return Ok(Arc::new(DecimalRescaleCheckOverflow::new(
Arc::clone(&cast.child),
s_in,
*p_out,
*s_out,
fail_on_error,
)));
}
}
Comment on lines +395 to +409
⚠️ Potential issue | 🟡 Minor

Tighten fusion precondition by validating cast output type too.

The fusion currently validates only cast.child input type. Please also require the Cast output type to match the CheckOverflow target decimal type before replacing the pair.

🛠️ Suggested fix
                 if let Some(cast) = child.as_any().downcast_ref::<Cast>() {
                     if let (
                         DataType::Decimal128(p_out, s_out),
                         Ok(DataType::Decimal128(_p_in, s_in)),
-                    ) = (&data_type, cast.child.data_type(&input_schema))
+                        Ok(DataType::Decimal128(cast_p, cast_s)),
+                    ) = (
+                        &data_type,
+                        cast.child.data_type(&input_schema),
+                        cast.data_type(&input_schema),
+                    )
                     {
-                        return Ok(Arc::new(DecimalRescaleCheckOverflow::new(
-                            Arc::clone(&cast.child),
-                            s_in,
-                            *p_out,
-                            *s_out,
-                            fail_on_error,
-                        )));
+                        if cast_p == *p_out && cast_s == *s_out {
+                            return Ok(Arc::new(DecimalRescaleCheckOverflow::new(
+                                Arc::clone(&cast.child),
+                                s_in,
+                                *p_out,
+                                *s_out,
+                                fail_on_error,
+                            )));
+                        }
                     }
                 }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@native/core/src/execution/planner.rs` around lines 395 - 409, The fusion
replaces a Cast+CheckOverflow pair without verifying the Cast's output type;
update the precondition in the block that matches
child.as_any().downcast_ref::<Cast>() so it also checks that
cast.data_type(&input_schema) equals the target DataType::Decimal128(p_out,
s_out) (i.e., the same precision/scale as the CheckOverflow target) before
creating DecimalRescaleCheckOverflow; locate the matching logic around data_type
and cast.child.data_type(&input_schema) and add the extra equality check on
cast.data_type(&input_schema) to tighten the fusion precondition.

Owner Author
value:useful; category:bug; feedback: The CodeRabbit AI reviewer is correct! Before fusing (by using WideDecimalBinaryExpr), the logic should check that the Cast's precision/scale matches the requested output precision/scale pair; this prevents computing a wrong rescaling when they don't match.
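A hedged illustration of why the scale mismatch matters: a Decimal128 value is an unscaled i128 plus an implied scale, so rescaling multiplies or divides by a power of ten derived from the scale difference. Reading the wrong source scale shifts every value by a power of ten. The helper below is a simplified model, not Comet's DecimalRescaleCheckOverflow kernel:

```rust
// Simplified model of decimal rescaling (illustrative only).
// A Decimal128 value is an unscaled i128 plus an implied scale.
fn rescale(unscaled: i128, s_in: i8, s_out: i8) -> i128 {
    if s_out >= s_in {
        // Widening the scale multiplies by 10^(s_out - s_in).
        unscaled * 10i128.pow((s_out - s_in) as u32)
    } else {
        // Narrowing the scale divides (truncating) by 10^(s_in - s_out).
        unscaled / 10i128.pow((s_in - s_out) as u32)
    }
}

fn main() {
    // 1.23 stored at scale 2 is 123; at scale 4 it must become 12300.
    assert_eq!(rescale(123, 2, 4), 12300);
    // Feeding the wrong source scale (3 instead of 2) yields 1230, a
    // result off by a factor of ten: 123 was read as 0.123, not 1.23.
    assert_eq!(rescale(123, 3, 4), 1230);
}
```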


Ok(Arc::new(CheckOverflow::new(
child,
data_type,
@@ -674,31 +702,31 @@ impl PhysicalPlanner {
) {
(
DataFusionOperator::Plus | DataFusionOperator::Minus | DataFusionOperator::Multiply,
Ok(DataType::Decimal128(p1, s1)),
Ok(DataType::Decimal128(p2, s2)),
Ok(DataType::Decimal128(_p1, _s1)),
Ok(DataType::Decimal128(_p2, _s2)),
) if ((op == DataFusionOperator::Plus || op == DataFusionOperator::Minus)
&& max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8)
&& max(_s1, _s2) as u8 + max(_p1 - _s1 as u8, _p2 - _s2 as u8)
>= DECIMAL128_MAX_PRECISION)
|| (op == DataFusionOperator::Multiply && p1 + p2 >= DECIMAL128_MAX_PRECISION) =>
|| (op == DataFusionOperator::Multiply
&& _p1 + _p2 >= DECIMAL128_MAX_PRECISION) =>
Comment on lines +705 to +711
medium

The variables _p1, _s1, _p2, and _s2 are prefixed with an underscore, which typically indicates they are unused. However, they are used in the if guard of this match arm. This can be confusing for future readers. It would be clearer to remove the underscore prefix from these variable names to signal that they are intentionally used.

Suggested change
-                        Ok(DataType::Decimal128(_p1, _s1)),
-                        Ok(DataType::Decimal128(_p2, _s2)),
-                    ) if ((op == DataFusionOperator::Plus || op == DataFusionOperator::Minus)
-                        && max(_s1, _s2) as u8 + max(_p1 - _s1 as u8, _p2 - _s2 as u8)
-                            >= DECIMAL128_MAX_PRECISION)
-                        || (op == DataFusionOperator::Multiply
-                            && _p1 + _p2 >= DECIMAL128_MAX_PRECISION) =>
+                        Ok(DataType::Decimal128(p1, s1)),
+                        Ok(DataType::Decimal128(p2, s2)),
+                    ) if ((op == DataFusionOperator::Plus || op == DataFusionOperator::Minus)
+                        && max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8)
+                            >= DECIMAL128_MAX_PRECISION)
+                        || (op == DataFusionOperator::Multiply
+                            && p1 + p2 >= DECIMAL128_MAX_PRECISION) =>
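The reviewer's naming point in miniature: Rust happily lets an underscore-prefixed binding be used in a match guard, since the prefix only silences the unused-variable lint, so the code compiles either way; the prefix just misleads the reader. A tiny self-contained sketch:

```rust
// Plain names signal that the bindings feed the guard; an underscore
// prefix would compile too, but conventionally it means "unused".
fn classify(n: i32) -> &'static str {
    match n {
        // `v` is bound and clearly used by the guard.
        v if v >= 0 => "non-negative",
        _ => "negative",
    }
}

fn main() {
    assert_eq!(classify(5), "non-negative");
    assert_eq!(classify(-3), "negative");
}
```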

{
let data_type = return_type.map(to_arrow_datatype).unwrap();
// For some Decimal128 operations, we need wider internal digits.
// Cast left and right to Decimal256 and cast the result back to Decimal128
let left = Arc::new(Cast::new(
left,
DataType::Decimal256(p1, s1),
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
));
let right = Arc::new(Cast::new(
right,
DataType::Decimal256(p2, s2),
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
));
let child = Arc::new(BinaryExpr::new(left, op, right));
Ok(Arc::new(Cast::new(
child,
data_type,
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
let (p_out, s_out) = match &data_type {
DataType::Decimal128(p, s) => (*p, *s),
dt => {
return Err(ExecutionError::GeneralError(format!(
"Expected Decimal128 return type, got {dt:?}"
)))
}
};
let wide_op = match op {
DataFusionOperator::Plus => WideDecimalOp::Add,
DataFusionOperator::Minus => WideDecimalOp::Subtract,
DataFusionOperator::Multiply => WideDecimalOp::Multiply,
_ => unreachable!(),
};
Ok(Arc::new(WideDecimalBinaryExpr::new(
left, right, wide_op, p_out, s_out, eval_mode,
)))
}
(
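The match-arm guards in the planner diff above encode when Decimal128 arithmetic needs a wider intermediate representation. A self-contained sketch of that rule (the constant value matches Arrow's Decimal128 limit; the helper names are illustrative):

```rust
// Arrow's maximum precision for Decimal128.
const DECIMAL128_MAX_PRECISION: u8 = 38;

// Add/subtract needs the wide path when the combined result precision
// max(s1, s2) + max(p1 - s1, p2 - s2) no longer fits in Decimal128.
fn add_needs_wide(p1: u8, s1: i8, p2: u8, s2: i8) -> bool {
    let int_digits = (p1 - s1 as u8).max(p2 - s2 as u8);
    s1.max(s2) as u8 + int_digits >= DECIMAL128_MAX_PRECISION
}

// Multiply needs it when the summed operand precisions overflow.
fn mul_needs_wide(p1: u8, p2: u8) -> bool {
    p1 + p2 >= DECIMAL128_MAX_PRECISION
}

fn main() {
    // Decimal128(38,10) + Decimal128(38,10): 10 + 28 = 38, takes the wide path.
    assert!(add_needs_wide(38, 10, 38, 10));
    // Decimal128(20,10) * Decimal128(20,10): 20 + 20 = 40 >= 38.
    assert!(mul_needs_wide(20, 20));
    // Small decimals stay on the fast path: 2 + 8 = 10 < 38.
    assert!(!add_needs_wide(10, 2, 10, 2));
}
```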
4 changes: 4 additions & 0 deletions native/spark-expr/Cargo.toml
@@ -103,3 +103,7 @@ path = "tests/spark_expr_reg.rs"
[[bench]]
name = "cast_from_boolean"
harness = false

[[bench]]
name = "wide_decimal"
harness = false
162 changes: 162 additions & 0 deletions native/spark-expr/benches/wide_decimal.rs
@@ -0,0 +1,162 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Benchmarks comparing the old Cast->BinaryExpr->Cast chain vs the fused WideDecimalBinaryExpr
//! for Decimal128 arithmetic that requires wider intermediate precision.

use arrow::array::builder::Decimal128Builder;
use arrow::array::RecordBatch;
use arrow::datatypes::{DataType, Field, Schema};
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use datafusion::logical_expr::Operator;
use datafusion::physical_expr::expressions::{BinaryExpr, Column};
use datafusion::physical_expr::PhysicalExpr;
use datafusion_comet_spark_expr::{
Cast, EvalMode, SparkCastOptions, WideDecimalBinaryExpr, WideDecimalOp,
};
use std::sync::Arc;

const BATCH_SIZE: usize = 8192;

/// Build a RecordBatch with two Decimal128 columns.
fn make_decimal_batch(p1: u8, s1: i8, p2: u8, s2: i8) -> RecordBatch {
let mut left = Decimal128Builder::new();
let mut right = Decimal128Builder::new();
for i in 0..BATCH_SIZE as i128 {
left.append_value(123456789012345_i128 + i * 1000);
right.append_value(987654321098765_i128 - i * 1000);
}
let left = left.finish().with_data_type(DataType::Decimal128(p1, s1));
let right = right.finish().with_data_type(DataType::Decimal128(p2, s2));
let schema = Schema::new(vec![
Field::new("left", DataType::Decimal128(p1, s1), false),
Field::new("right", DataType::Decimal128(p2, s2), false),
]);
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(left), Arc::new(right)]).unwrap()
}

/// Old approach: Cast(Decimal128->Decimal256) both sides, BinaryExpr, Cast(Decimal256->Decimal128).
fn build_old_expr(
p1: u8,
s1: i8,
p2: u8,
s2: i8,
op: Operator,
out_type: DataType,
) -> Arc<dyn PhysicalExpr> {
let left_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left", 0));
let right_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right", 1));
let cast_opts = SparkCastOptions::new_without_timezone(EvalMode::Legacy, false);
let left_cast = Arc::new(Cast::new(
left_col,
DataType::Decimal256(p1, s1),
cast_opts.clone(),
));
let right_cast = Arc::new(Cast::new(
right_col,
DataType::Decimal256(p2, s2),
cast_opts.clone(),
));
let binary = Arc::new(BinaryExpr::new(left_cast, op, right_cast));
Arc::new(Cast::new(binary, out_type, cast_opts))
}

/// New approach: single fused WideDecimalBinaryExpr.
fn build_new_expr(op: WideDecimalOp, p_out: u8, s_out: i8) -> Arc<dyn PhysicalExpr> {
let left_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left", 0));
let right_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right", 1));
Arc::new(WideDecimalBinaryExpr::new(
left_col,
right_col,
op,
p_out,
s_out,
EvalMode::Legacy,
))
}

fn bench_case(
group: &mut criterion::BenchmarkGroup<criterion::measurement::WallTime>,
name: &str,
batch: &RecordBatch,
old_expr: &Arc<dyn PhysicalExpr>,
new_expr: &Arc<dyn PhysicalExpr>,
) {
group.bench_with_input(BenchmarkId::new("old", name), batch, |b, batch| {
b.iter(|| old_expr.evaluate(batch).unwrap());
});
group.bench_with_input(BenchmarkId::new("fused", name), batch, |b, batch| {
b.iter(|| new_expr.evaluate(batch).unwrap());
});
}

fn criterion_benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("wide_decimal");

// Case 1: Add with same scale - Decimal128(38,10) + Decimal128(38,10) -> Decimal128(38,10)
// Triggers wide path because max(s1,s2) + max(p1-s1, p2-s2) = 10 + 28 = 38 >= 38
{
let batch = make_decimal_batch(38, 10, 38, 10);
let old = build_old_expr(38, 10, 38, 10, Operator::Plus, DataType::Decimal128(38, 10));
let new = build_new_expr(WideDecimalOp::Add, 38, 10);
bench_case(&mut group, "add_same_scale", &batch, &old, &new);
}

// Case 2: Add with different scales - Decimal128(38,6) + Decimal128(38,4) -> Decimal128(38,6)
{
let batch = make_decimal_batch(38, 6, 38, 4);
let old = build_old_expr(38, 6, 38, 4, Operator::Plus, DataType::Decimal128(38, 6));
let new = build_new_expr(WideDecimalOp::Add, 38, 6);
bench_case(&mut group, "add_diff_scale", &batch, &old, &new);
}

// Case 3: Multiply - Decimal128(20,10) * Decimal128(20,10) -> Decimal128(38,6)
// Triggers wide path because p1 + p2 = 40 >= 38
{
let batch = make_decimal_batch(20, 10, 20, 10);
let old = build_old_expr(
20,
10,
20,
10,
Operator::Multiply,
DataType::Decimal128(38, 6),
);
let new = build_new_expr(WideDecimalOp::Multiply, 38, 6);
bench_case(&mut group, "multiply", &batch, &old, &new);
}

// Case 4: Subtract with same scale - Decimal128(38,18) - Decimal128(38,18) -> Decimal128(38,18)
{
let batch = make_decimal_batch(38, 18, 38, 18);
let old = build_old_expr(
38,
18,
38,
18,
Operator::Minus,
DataType::Decimal128(38, 18),
);
let new = build_new_expr(WideDecimalOp::Subtract, 38, 18);
bench_case(&mut group, "subtract", &batch, &old, &new);
}

group.finish();
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
3 changes: 2 additions & 1 deletion native/spark-expr/src/lib.rs
@@ -81,7 +81,8 @@ pub use json_funcs::{FromJson, ToJson};
pub use math_funcs::{
create_modulo_expr, create_negate_expr, spark_ceil, spark_decimal_div,
spark_decimal_integral_div, spark_floor, spark_make_decimal, spark_round, spark_unhex,
spark_unscaled_value, CheckOverflow, NegativeExpr, NormalizeNaNAndZero,
spark_unscaled_value, CheckOverflow, DecimalRescaleCheckOverflow, NegativeExpr,
NormalizeNaNAndZero, WideDecimalBinaryExpr, WideDecimalOp,
};
pub use string_funcs::*;
