-
Notifications
You must be signed in to change notification settings - Fork 0
3619: perf: Optimize some decimal expressions #42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
4b3fd48
d7495bd
91092a6
486eda2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -126,8 +126,9 @@ use datafusion_comet_proto::{ | |||||||||||||||||||||||||||||||||
| use datafusion_comet_spark_expr::monotonically_increasing_id::MonotonicallyIncreasingId; | ||||||||||||||||||||||||||||||||||
| use datafusion_comet_spark_expr::{ | ||||||||||||||||||||||||||||||||||
| ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct, | ||||||||||||||||||||||||||||||||||
| GetArrayStructFields, GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, RandExpr, | ||||||||||||||||||||||||||||||||||
| RandnExpr, SparkCastOptions, Stddev, SumDecimal, ToJson, UnboundColumn, Variance, | ||||||||||||||||||||||||||||||||||
| DecimalRescaleCheckOverflow, GetArrayStructFields, GetStructField, IfExpr, ListExtract, | ||||||||||||||||||||||||||||||||||
| NormalizeNaNAndZero, RandExpr, RandnExpr, SparkCastOptions, Stddev, SumDecimal, ToJson, | ||||||||||||||||||||||||||||||||||
| UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp, | ||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||
| use itertools::Itertools; | ||||||||||||||||||||||||||||||||||
| use jni::objects::GlobalRef; | ||||||||||||||||||||||||||||||||||
|
|
@@ -376,10 +377,37 @@ impl PhysicalPlanner { | |||||||||||||||||||||||||||||||||
| ))) | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| ExprStruct::CheckOverflow(expr) => { | ||||||||||||||||||||||||||||||||||
| let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; | ||||||||||||||||||||||||||||||||||
| let child = | ||||||||||||||||||||||||||||||||||
| self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?; | ||||||||||||||||||||||||||||||||||
| let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap()); | ||||||||||||||||||||||||||||||||||
| let fail_on_error = expr.fail_on_error; | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| // WideDecimalBinaryExpr already handles overflow — skip redundant check | ||||||||||||||||||||||||||||||||||
| if child | ||||||||||||||||||||||||||||||||||
| .as_any() | ||||||||||||||||||||||||||||||||||
| .downcast_ref::<WideDecimalBinaryExpr>() | ||||||||||||||||||||||||||||||||||
| .is_some() | ||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||
| return Ok(child); | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| // Fuse Cast(Decimal128→Decimal128) + CheckOverflow into single rescale+check | ||||||||||||||||||||||||||||||||||
| if let Some(cast) = child.as_any().downcast_ref::<Cast>() { | ||||||||||||||||||||||||||||||||||
| if let ( | ||||||||||||||||||||||||||||||||||
| DataType::Decimal128(p_out, s_out), | ||||||||||||||||||||||||||||||||||
| Ok(DataType::Decimal128(_p_in, s_in)), | ||||||||||||||||||||||||||||||||||
| ) = (&data_type, cast.child.data_type(&input_schema)) | ||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||
| return Ok(Arc::new(DecimalRescaleCheckOverflow::new( | ||||||||||||||||||||||||||||||||||
| Arc::clone(&cast.child), | ||||||||||||||||||||||||||||||||||
| s_in, | ||||||||||||||||||||||||||||||||||
| *p_out, | ||||||||||||||||||||||||||||||||||
| *s_out, | ||||||||||||||||||||||||||||||||||
| fail_on_error, | ||||||||||||||||||||||||||||||||||
| ))); | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+395
to
+409
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tighten fusion precondition by validating cast output type too. The fusion currently validates only 🛠️ Suggested fix if let Some(cast) = child.as_any().downcast_ref::<Cast>() {
if let (
DataType::Decimal128(p_out, s_out),
Ok(DataType::Decimal128(_p_in, s_in)),
- ) = (&data_type, cast.child.data_type(&input_schema))
+ Ok(DataType::Decimal128(cast_p, cast_s)),
+ ) = (
+ &data_type,
+ cast.child.data_type(&input_schema),
+ cast.data_type(&input_schema),
+ )
{
- return Ok(Arc::new(DecimalRescaleCheckOverflow::new(
- Arc::clone(&cast.child),
- s_in,
- *p_out,
- *s_out,
- fail_on_error,
- )));
+ if cast_p == *p_out && cast_s == *s_out {
+ return Ok(Arc::new(DecimalRescaleCheckOverflow::new(
+ Arc::clone(&cast.child),
+ s_in,
+ *p_out,
+ *s_out,
+ fail_on_error,
+ )));
+ }
}
}🤖 Prompt for AI Agents
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. value:useful; category:bug; feedback: The CodeRabbit AI reviewer is correct! Before fusing (by using WideDecimalBinaryExpr) the logic should check that the Cast's precision/scale match the requested output precision/scale pair. Prevents calculating wrong rescaling if they don't match |
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| Ok(Arc::new(CheckOverflow::new( | ||||||||||||||||||||||||||||||||||
| child, | ||||||||||||||||||||||||||||||||||
| data_type, | ||||||||||||||||||||||||||||||||||
|
|
@@ -674,31 +702,31 @@ impl PhysicalPlanner { | |||||||||||||||||||||||||||||||||
| ) { | ||||||||||||||||||||||||||||||||||
| ( | ||||||||||||||||||||||||||||||||||
| DataFusionOperator::Plus | DataFusionOperator::Minus | DataFusionOperator::Multiply, | ||||||||||||||||||||||||||||||||||
| Ok(DataType::Decimal128(p1, s1)), | ||||||||||||||||||||||||||||||||||
| Ok(DataType::Decimal128(p2, s2)), | ||||||||||||||||||||||||||||||||||
| Ok(DataType::Decimal128(_p1, _s1)), | ||||||||||||||||||||||||||||||||||
| Ok(DataType::Decimal128(_p2, _s2)), | ||||||||||||||||||||||||||||||||||
| ) if ((op == DataFusionOperator::Plus || op == DataFusionOperator::Minus) | ||||||||||||||||||||||||||||||||||
| && max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8) | ||||||||||||||||||||||||||||||||||
| && max(_s1, _s2) as u8 + max(_p1 - _s1 as u8, _p2 - _s2 as u8) | ||||||||||||||||||||||||||||||||||
| >= DECIMAL128_MAX_PRECISION) | ||||||||||||||||||||||||||||||||||
| || (op == DataFusionOperator::Multiply && p1 + p2 >= DECIMAL128_MAX_PRECISION) => | ||||||||||||||||||||||||||||||||||
| || (op == DataFusionOperator::Multiply | ||||||||||||||||||||||||||||||||||
| && _p1 + _p2 >= DECIMAL128_MAX_PRECISION) => | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+705
to
+711
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The variables
Suggested change
|
||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||
| let data_type = return_type.map(to_arrow_datatype).unwrap(); | ||||||||||||||||||||||||||||||||||
| // For some Decimal128 operations, we need wider internal digits. | ||||||||||||||||||||||||||||||||||
| // Cast left and right to Decimal256 and cast the result back to Decimal128 | ||||||||||||||||||||||||||||||||||
| let left = Arc::new(Cast::new( | ||||||||||||||||||||||||||||||||||
| left, | ||||||||||||||||||||||||||||||||||
| DataType::Decimal256(p1, s1), | ||||||||||||||||||||||||||||||||||
| SparkCastOptions::new_without_timezone(EvalMode::Legacy, false), | ||||||||||||||||||||||||||||||||||
| )); | ||||||||||||||||||||||||||||||||||
| let right = Arc::new(Cast::new( | ||||||||||||||||||||||||||||||||||
| right, | ||||||||||||||||||||||||||||||||||
| DataType::Decimal256(p2, s2), | ||||||||||||||||||||||||||||||||||
| SparkCastOptions::new_without_timezone(EvalMode::Legacy, false), | ||||||||||||||||||||||||||||||||||
| )); | ||||||||||||||||||||||||||||||||||
| let child = Arc::new(BinaryExpr::new(left, op, right)); | ||||||||||||||||||||||||||||||||||
| Ok(Arc::new(Cast::new( | ||||||||||||||||||||||||||||||||||
| child, | ||||||||||||||||||||||||||||||||||
| data_type, | ||||||||||||||||||||||||||||||||||
| SparkCastOptions::new_without_timezone(EvalMode::Legacy, false), | ||||||||||||||||||||||||||||||||||
| let (p_out, s_out) = match &data_type { | ||||||||||||||||||||||||||||||||||
| DataType::Decimal128(p, s) => (*p, *s), | ||||||||||||||||||||||||||||||||||
| dt => { | ||||||||||||||||||||||||||||||||||
| return Err(ExecutionError::GeneralError(format!( | ||||||||||||||||||||||||||||||||||
| "Expected Decimal128 return type, got {dt:?}" | ||||||||||||||||||||||||||||||||||
| ))) | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||
| let wide_op = match op { | ||||||||||||||||||||||||||||||||||
| DataFusionOperator::Plus => WideDecimalOp::Add, | ||||||||||||||||||||||||||||||||||
| DataFusionOperator::Minus => WideDecimalOp::Subtract, | ||||||||||||||||||||||||||||||||||
| DataFusionOperator::Multiply => WideDecimalOp::Multiply, | ||||||||||||||||||||||||||||||||||
| _ => unreachable!(), | ||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||
| Ok(Arc::new(WideDecimalBinaryExpr::new( | ||||||||||||||||||||||||||||||||||
| left, right, wide_op, p_out, s_out, eval_mode, | ||||||||||||||||||||||||||||||||||
| ))) | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| ( | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,162 @@ | ||
| // Licensed to the Apache Software Foundation (ASF) under one | ||
| // or more contributor license agreements. See the NOTICE file | ||
| // distributed with this work for additional information | ||
| // regarding copyright ownership. The ASF licenses this file | ||
| // to you under the Apache License, Version 2.0 (the | ||
| // "License"); you may not use this file except in compliance | ||
| // with the License. You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, | ||
| // software distributed under the License is distributed on an | ||
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| // KIND, either express or implied. See the License for the | ||
| // specific language governing permissions and limitations | ||
| // under the License. | ||
|
|
||
| //! Benchmarks comparing the old Cast->BinaryExpr->Cast chain vs the fused WideDecimalBinaryExpr | ||
| //! for Decimal128 arithmetic that requires wider intermediate precision. | ||
|
|
||
| use arrow::array::builder::Decimal128Builder; | ||
| use arrow::array::RecordBatch; | ||
| use arrow::datatypes::{DataType, Field, Schema}; | ||
| use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; | ||
| use datafusion::logical_expr::Operator; | ||
| use datafusion::physical_expr::expressions::{BinaryExpr, Column}; | ||
| use datafusion::physical_expr::PhysicalExpr; | ||
| use datafusion_comet_spark_expr::{ | ||
| Cast, EvalMode, SparkCastOptions, WideDecimalBinaryExpr, WideDecimalOp, | ||
| }; | ||
| use std::sync::Arc; | ||
|
|
||
| const BATCH_SIZE: usize = 8192; | ||
|
|
||
| /// Build a RecordBatch with two Decimal128 columns. | ||
| fn make_decimal_batch(p1: u8, s1: i8, p2: u8, s2: i8) -> RecordBatch { | ||
| let mut left = Decimal128Builder::new(); | ||
| let mut right = Decimal128Builder::new(); | ||
| for i in 0..BATCH_SIZE as i128 { | ||
| left.append_value(123456789012345_i128 + i * 1000); | ||
| right.append_value(987654321098765_i128 - i * 1000); | ||
| } | ||
| let left = left.finish().with_data_type(DataType::Decimal128(p1, s1)); | ||
| let right = right.finish().with_data_type(DataType::Decimal128(p2, s2)); | ||
| let schema = Schema::new(vec![ | ||
| Field::new("left", DataType::Decimal128(p1, s1), false), | ||
| Field::new("right", DataType::Decimal128(p2, s2), false), | ||
| ]); | ||
| RecordBatch::try_new(Arc::new(schema), vec![Arc::new(left), Arc::new(right)]).unwrap() | ||
| } | ||
|
|
||
| /// Old approach: Cast(Decimal128->Decimal256) both sides, BinaryExpr, Cast(Decimal256->Decimal128). | ||
| fn build_old_expr( | ||
| p1: u8, | ||
| s1: i8, | ||
| p2: u8, | ||
| s2: i8, | ||
| op: Operator, | ||
| out_type: DataType, | ||
| ) -> Arc<dyn PhysicalExpr> { | ||
| let left_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left", 0)); | ||
| let right_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right", 1)); | ||
| let cast_opts = SparkCastOptions::new_without_timezone(EvalMode::Legacy, false); | ||
| let left_cast = Arc::new(Cast::new( | ||
| left_col, | ||
| DataType::Decimal256(p1, s1), | ||
| cast_opts.clone(), | ||
| )); | ||
| let right_cast = Arc::new(Cast::new( | ||
| right_col, | ||
| DataType::Decimal256(p2, s2), | ||
| cast_opts.clone(), | ||
| )); | ||
| let binary = Arc::new(BinaryExpr::new(left_cast, op, right_cast)); | ||
| Arc::new(Cast::new(binary, out_type, cast_opts)) | ||
| } | ||
|
|
||
| /// New approach: single fused WideDecimalBinaryExpr. | ||
| fn build_new_expr(op: WideDecimalOp, p_out: u8, s_out: i8) -> Arc<dyn PhysicalExpr> { | ||
| let left_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left", 0)); | ||
| let right_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right", 1)); | ||
| Arc::new(WideDecimalBinaryExpr::new( | ||
| left_col, | ||
| right_col, | ||
| op, | ||
| p_out, | ||
| s_out, | ||
| EvalMode::Legacy, | ||
| )) | ||
| } | ||
|
|
||
| fn bench_case( | ||
| group: &mut criterion::BenchmarkGroup<criterion::measurement::WallTime>, | ||
| name: &str, | ||
| batch: &RecordBatch, | ||
| old_expr: &Arc<dyn PhysicalExpr>, | ||
| new_expr: &Arc<dyn PhysicalExpr>, | ||
| ) { | ||
| group.bench_with_input(BenchmarkId::new("old", name), batch, |b, batch| { | ||
| b.iter(|| old_expr.evaluate(batch).unwrap()); | ||
| }); | ||
| group.bench_with_input(BenchmarkId::new("fused", name), batch, |b, batch| { | ||
| b.iter(|| new_expr.evaluate(batch).unwrap()); | ||
| }); | ||
| } | ||
|
|
||
| fn criterion_benchmark(c: &mut Criterion) { | ||
| let mut group = c.benchmark_group("wide_decimal"); | ||
|
|
||
| // Case 1: Add with same scale - Decimal128(38,10) + Decimal128(38,10) -> Decimal128(38,10) | ||
| // Triggers wide path because max(s1,s2) + max(p1-s1, p2-s2) = 10 + 28 = 38 >= 38 | ||
| { | ||
| let batch = make_decimal_batch(38, 10, 38, 10); | ||
| let old = build_old_expr(38, 10, 38, 10, Operator::Plus, DataType::Decimal128(38, 10)); | ||
| let new = build_new_expr(WideDecimalOp::Add, 38, 10); | ||
| bench_case(&mut group, "add_same_scale", &batch, &old, &new); | ||
| } | ||
|
|
||
| // Case 2: Add with different scales - Decimal128(38,6) + Decimal128(38,4) -> Decimal128(38,6) | ||
| { | ||
| let batch = make_decimal_batch(38, 6, 38, 4); | ||
| let old = build_old_expr(38, 6, 38, 4, Operator::Plus, DataType::Decimal128(38, 6)); | ||
| let new = build_new_expr(WideDecimalOp::Add, 38, 6); | ||
| bench_case(&mut group, "add_diff_scale", &batch, &old, &new); | ||
| } | ||
|
|
||
| // Case 3: Multiply - Decimal128(20,10) * Decimal128(20,10) -> Decimal128(38,6) | ||
| // Triggers wide path because p1 + p2 = 40 >= 38 | ||
| { | ||
| let batch = make_decimal_batch(20, 10, 20, 10); | ||
| let old = build_old_expr( | ||
| 20, | ||
| 10, | ||
| 20, | ||
| 10, | ||
| Operator::Multiply, | ||
| DataType::Decimal128(38, 6), | ||
| ); | ||
| let new = build_new_expr(WideDecimalOp::Multiply, 38, 6); | ||
| bench_case(&mut group, "multiply", &batch, &old, &new); | ||
| } | ||
|
|
||
| // Case 4: Subtract with same scale - Decimal128(38,18) - Decimal128(38,18) -> Decimal128(38,18) | ||
| { | ||
| let batch = make_decimal_batch(38, 18, 38, 18); | ||
| let old = build_old_expr( | ||
| 38, | ||
| 18, | ||
| 38, | ||
| 18, | ||
| Operator::Minus, | ||
| DataType::Decimal128(38, 18), | ||
| ); | ||
| let new = build_new_expr(WideDecimalOp::Subtract, 38, 18); | ||
| bench_case(&mut group, "subtract", &batch, &old, &new); | ||
| } | ||
|
|
||
| group.finish(); | ||
| } | ||
|
|
||
| criterion_group!(benches, criterion_benchmark); | ||
| criterion_main!(benches); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CheckOverflow skip ignores its own data_type and fail_on_error
Medium Severity
When
`CheckOverflow` wraps a `WideDecimalBinaryExpr`, the code returns the child directly without verifying that `CheckOverflow`'s `data_type` (precision/scale) matches the `WideDecimalBinaryExpr`'s output type. If Spark's plan specifies a different precision/scale in `CheckOverflow` than the binary expression's `return_type`, the output type would be wrong. The `fail_on_error` flag from `CheckOverflow` is also silently discarded. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
value:useful; category:bug; feedback: The Bugbot AI reviewer is correct! Before fusing (by using WideDecimalBinaryExpr) the logic should check that the Cast's precision/scale match the requested output precision/scale pair. Prevents calculating wrong rescaling if they don't match