From 52271152ff694b919fdedb3362514442bd33374b Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Tue, 25 Nov 2025 12:18:20 -0800 Subject: [PATCH 1/8] support_ansi_sum_decimal_input --- native/core/src/execution/planner.rs | 4 +- native/proto/src/proto/expr.proto | 2 +- native/spark-expr/benches/aggregate.rs | 4 +- .../spark-expr/src/agg_funcs/sum_decimal.rs | 61 +++++++++++++------ 4 files changed, 47 insertions(+), 24 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 0fe04a5a41..3ab08063a8 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1870,7 +1870,9 @@ impl PhysicalPlanner { let builder = match datatype { DataType::Decimal128(_, _) => { - let func = AggregateUDF::new_from_impl(SumDecimal::try_new(datatype)?); + let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; + let func = + AggregateUDF::new_from_impl(SumDecimal::try_new(datatype, eval_mode)?); AggregateExprBuilder::new(Arc::new(func), vec![child]) } _ => { diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index c9037dcd69..a7736f561a 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -120,7 +120,7 @@ message Count { message Sum { Expr child = 1; DataType datatype = 2; - bool fail_on_error = 3; + EvalMode eval_mode = 3; } message Min { diff --git a/native/spark-expr/benches/aggregate.rs b/native/spark-expr/benches/aggregate.rs index 3aa0233716..c3a1978d69 100644 --- a/native/spark-expr/benches/aggregate.rs +++ b/native/spark-expr/benches/aggregate.rs @@ -31,7 +31,7 @@ use datafusion::physical_expr::expressions::Column; use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use datafusion::physical_plan::ExecutionPlan; -use datafusion_comet_spark_expr::AvgDecimal; +use datafusion_comet_spark_expr::{AvgDecimal, EvalMode}; use datafusion_comet_spark_expr::SumDecimal; use futures::StreamExt; use std::hint::black_box; @@ -97,7 +97,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function("sum_decimal_comet", |b| { let comet_sum_decimal = Arc::new(AggregateUDF::new_from_impl( - SumDecimal::try_new(DataType::Decimal128(38, 10)).unwrap(), + SumDecimal::try_new(DataType::Decimal128(38, 10), EvalMode::Legacy).unwrap(), )); b.to_async(&rt).iter(|| { black_box(agg_test( diff --git a/native/spark-expr/src/agg_funcs/sum_decimal.rs b/native/spark-expr/src/agg_funcs/sum_decimal.rs index cc25855902..c87ed508fc 100644 --- a/native/spark-expr/src/agg_funcs/sum_decimal.rs +++ b/native/spark-expr/src/agg_funcs/sum_decimal.rs @@ -16,6 +16,7 @@ // under the License. use crate::utils::{build_bool_state, is_valid_decimal_precision}; +use crate::{arithmetic_overflow_error, EvalMode}; use arrow::array::{ cast::AsArray, types::Decimal128Type, Array, ArrayRef, BooleanArray, Decimal128Array, }; @@ -40,11 +41,11 @@ pub struct SumDecimal { precision: u8, /// Decimal scale scale: i8, + eval_mode: EvalMode, } impl SumDecimal { - pub fn try_new(data_type: DataType) -> DFResult { - // The `data_type` is the SUM result type passed from Spark side + pub fn try_new(data_type: DataType, eval_mode: EvalMode) -> DFResult { let (precision, scale) = match data_type { DataType::Decimal128(p, s) => (p, s), _ => { @@ -58,6 +59,7 @@ impl SumDecimal { result_type: data_type, precision, scale, + eval_mode, }) } } @@ -71,6 +73,7 @@ impl AggregateUDFImpl for SumDecimal { Ok(Box::new(SumDecimalAccumulator::new( self.precision, self.scale, + self.eval_mode, ))) } @@ -109,6 +112,7 @@ impl AggregateUDFImpl for SumDecimal { Ok(Box::new(SumDecimalGroupsAccumulator::new( self.result_type.clone(), self.precision, + self.eval_mode, ))) } @@ -137,31 +141,36 @@ struct SumDecimalAccumulator { precision: u8, scale: i8, + eval_mode: EvalMode, } impl SumDecimalAccumulator { - fn new(precision: u8, scale: i8) -> Self { + fn new(precision: u8, scale: i8, eval_mode: EvalMode) -> Self { Self { sum: 0, is_empty: true, is_not_null: true, precision, scale, + eval_mode, } } - fn update_single(&mut self, values: &Decimal128Array, idx: usize) { + fn update_single(&mut self, values: &Decimal128Array, idx: usize) -> DFResult<()> { let v = unsafe { values.value_unchecked(idx) }; let (new_sum, is_overflow) = self.sum.overflowing_add(v); if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) { - // Overflow: set buffer accumulator to null + if self.eval_mode == EvalMode::Ansi { + return Err(DataFusionError::from(arithmetic_overflow_error("decimal"))); + } self.is_not_null = false; - return; + return Ok(()); } self.sum = new_sum; self.is_not_null = true; + Ok(()) } } @@ -187,14 +196,14 @@ impl Accumulator for SumDecimalAccumulator { if values.null_count() == 0 { for i in 0..data.len() { - self.update_single(data, i); + self.update_single(data, i)?; } } else { for i in 0..data.len() { if data.is_null(i) { continue; } - self.update_single(data, i); + self.update_single(data, i)?; } } @@ -205,16 +214,22 @@ impl Accumulator for SumDecimalAccumulator { // For each group: // 1. if `is_empty` is true, it means either there is no value or all values for the group // are null, in this case we'll return null - // 2. if `is_empty` is false, but `null_state` is true, it means there's an overflow. In - // non-ANSI mode Spark returns null. - if self.is_empty - || !self.is_not_null - || !is_valid_decimal_precision(self.sum, self.precision) - { + // 2. if `is_empty` is false, but `is_not_null` is false, it means there's an overflow. + // In ANSI mode we return an error, in Try/Legacy mode we return null. + if self.is_empty { ScalarValue::new_primitive::( None, &DataType::Decimal128(self.precision, self.scale), ) + } else if !self.is_not_null { + if self.eval_mode == EvalMode::Ansi { + Err(DataFusionError::from(arithmetic_overflow_error("decimal"))) + } else { + ScalarValue::new_primitive::( + None, + &DataType::Decimal128(self.precision, self.scale), + ) + } } else { ScalarValue::try_new_decimal128(self.sum, self.precision, self.scale) } @@ -270,16 +285,18 @@ struct SumDecimalGroupsAccumulator { sum: Vec, result_type: DataType, precision: u8, + eval_mode: EvalMode, } impl SumDecimalGroupsAccumulator { - fn new(result_type: DataType, precision: u8) -> Self { + fn new(result_type: DataType, precision: u8, eval_mode: EvalMode) -> Self { Self { is_not_null: BooleanBufferBuilder::new(0), is_empty: BooleanBufferBuilder::new(0), sum: Vec::new(), result_type, precision, + eval_mode, } } @@ -288,15 +305,18 @@ impl SumDecimalGroupsAccumulator { } #[inline] - fn update_single(&mut self, group_index: usize, value: i128) { + fn update_single(&mut self, group_index: usize, value: i128) -> DFResult<()> { self.is_empty.set_bit(group_index, false); let (new_sum, is_overflow) = self.sum[group_index].overflowing_add(value); self.sum[group_index] = new_sum; if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) { - // Overflow: set buffer accumulator to null + if self.eval_mode == EvalMode::Ansi { + return Err(DataFusionError::from(arithmetic_overflow_error("decimal"))); + } self.is_not_null.set_bit(group_index, false); } + Ok(()) } } @@ -328,14 +348,14 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator { let iter = group_indices.iter().zip(data.iter()); if values.null_count() == 0 { for (&group_index, &value) in iter { - self.update_single(group_index, value); + self.update_single(group_index, value)?; } } else { for (idx, (&group_index, &value)) in iter.enumerate() { if values.is_null(idx) { continue; } - self.update_single(group_index, value); + self.update_single(group_index, value)?; } } @@ -463,7 +483,7 @@ mod tests { #[test] fn invalid_data_type() { - assert!(SumDecimal::try_new(DataType::Int32).is_err()); + assert!(SumDecimal::try_new(DataType::Int32, EvalMode::Legacy).is_err()); } #[tokio::test] @@ -486,6 +506,7 @@ mod tests { let aggregate_udf = Arc::new(AggregateUDF::new_from_impl(SumDecimal::try_new( data_type.clone(), + EvalMode::Legacy, )?)); let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c1]) From dbe1e4c322fdb292a1afc6b4dfc63ad6a2b92941 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Tue, 25 Nov 2025 12:23:46 -0800 Subject: [PATCH 2/8] support_ansi_sum_decimal_input --- .../org/apache/comet/serde/aggregates.scala | 16 +- .../comet/exec/CometAggregateSuite.scala | 147 +++++++++++++++++- .../sql/comet/CometPlanStabilitySuite.scala | 3 +- 3 files changed, 150 insertions(+), 16 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index d00bbf4dfa..1918f8b653 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -29,7 +29,8 @@ import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType import org.apache.comet.CometConf import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT import org.apache.comet.CometSparkSessionExtensions.withInfo -import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType} +import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProto, serializeDataType} +import org.apache.comet.shims.CometEvalModeUtil object CometMin extends CometAggregateExpressionSerde[Min] { @@ -212,17 +213,6 @@ object CometAverage extends CometAggregateExpressionSerde[Average] { object CometSum extends CometAggregateExpressionSerde[Sum] { - override def getSupportLevel(sum: Sum): SupportLevel = { - sum.evalMode match { - case EvalMode.ANSI => - Incompatible(Some("ANSI mode is not supported")) - case EvalMode.TRY => - Incompatible(Some("TRY mode is not supported")) - case _ => - Compatible() - } - } - override def convert( aggExpr: AggregateExpression, sum: Sum, @@ -242,7 +232,7 @@ object CometSum extends CometAggregateExpressionSerde[Sum] { val builder = ExprOuterClass.Sum.newBuilder() builder.setChild(childExpr.get) builder.setDatatype(dataType.get) - builder.setFailOnError(sum.evalMode == EvalMode.ANSI) + builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(sum.evalMode))) Some( ExprOuterClass.AggExpr diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 7e577c5fda..ebf1c21451 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.optimizer.EliminateSorts import org.apache.spark.sql.comet.CometHashAggregateExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.functions.{avg, count_distinct, sum} +import org.apache.spark.sql.functions.{avg, col, count_distinct, sum} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataTypes, StructField, StructType} @@ -1471,6 +1471,151 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("ANSI support for decimal sum - null test") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withParquetTable( + Seq( + (null.asInstanceOf[java.math.BigDecimal], "a"), + (null.asInstanceOf[java.math.BigDecimal], "b")), + "null_tbl") { + val res = sql("SELECT sum(_1) FROM null_tbl") + checkSparkAnswerAndOperator(res) + assert(res.collect() === Array(Row(null))) + } + } + } + } + + test("ANSI support for try_sum decimal - null test") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withParquetTable( + Seq( + (null.asInstanceOf[java.math.BigDecimal], "a"), + (null.asInstanceOf[java.math.BigDecimal], "b")), + "null_tbl") { + val res = sql("SELECT try_sum(_1) FROM null_tbl") + checkSparkAnswerAndOperator(res) + assert(res.collect() === Array(Row(null))) + } + } + } + } + + test("ANSI support for decimal sum - null test (group by)") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withParquetTable( + Seq( + (null.asInstanceOf[java.math.BigDecimal], "a"), + (null.asInstanceOf[java.math.BigDecimal], "a"), + (null.asInstanceOf[java.math.BigDecimal], "b"), + (null.asInstanceOf[java.math.BigDecimal], "b"), + (null.asInstanceOf[java.math.BigDecimal], "b")), + "tbl") { + val res = sql("SELECT _2, sum(_1) FROM tbl group by 1") + checkSparkAnswerAndOperator(res) + assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null))) + } + } + } + } + + test("ANSI support for try_sum decimal - null test (group by)") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withParquetTable( + Seq( + (null.asInstanceOf[java.math.BigDecimal], "a"), + (null.asInstanceOf[java.math.BigDecimal], "a"), + (null.asInstanceOf[java.math.BigDecimal], "b"), + (null.asInstanceOf[java.math.BigDecimal], "b"), + (null.asInstanceOf[java.math.BigDecimal], "b")), + "tbl") { + val res = sql("SELECT _2, try_sum(_1) FROM tbl group by 1") + checkSparkAnswerAndOperator(res) + assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null))) + } + } + } + } + + protected def generateOverflowDecimalInputs: Seq[(java.math.BigDecimal, Int)] = { + val maxDec38_0 = new java.math.BigDecimal("99999999999999999999") + (1 to 50).flatMap(_ => Seq((maxDec38_0, 1))) + } + + test("ANSI support - decimal SUM function") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withParquetTable(generateOverflowDecimalInputs, "tbl") { + val input = sql("SELECT _1 FROM tbl") + val res = sql("SELECT SUM(_1) FROM tbl") + if (ansiEnabled) { + checkSparkAnswerMaybeThrows(res) match { + case (Some(sparkExc), Some(cometExc)) => + assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + case _ => + fail("Exception should be thrown for decimal overflow in ANSI mode") + } + } else { + checkSparkAnswerAndOperator(res) + } + } + } + } + } + + test("ANSI support for decimal SUM - GROUP BY") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withParquetTable(generateOverflowDecimalInputs, "tbl") { + val res = + sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2").repartition(2) + if (ansiEnabled) { + checkSparkAnswerMaybeThrows(res) match { + case (Some(sparkExc), Some(cometExc)) => + assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + case _ => + fail("Exception should be thrown for decimal overflow with GROUP BY in ANSI mode") + } + } else { + checkSparkAnswerAndOperator(res) + } + } + } + } + } + + test("try_sum decimal overflow") { + withParquetTable(generateOverflowDecimalInputs, "tbl") { + val res = sql("SELECT try_sum(_1) FROM tbl") + checkSparkAnswerAndOperator(res) + } + } + + test("try_sum decimal overflow - with GROUP BY") { + withParquetTable(generateOverflowDecimalInputs, "tbl") { + val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) + checkSparkAnswerAndOperator(res) + } + } + + test("try_sum decimal partial overflow - with GROUP BY") { + // Group 1 overflows, Group 2 succeeds + val data: Seq[(java.math.BigDecimal, Int)] = generateOverflowDecimalInputs ++ Seq( + (new java.math.BigDecimal(300), 2), + (new java.math.BigDecimal(200), 2)) + withParquetTable(data, "tbl") { + val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2") + // Group 1 should be NULL, Group 2 should be 500 + checkSparkAnswerAndOperator(res) + } + } + protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = { val df = sql(query) checkSparkAnswer(df) diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala index 8f260e2ca8..1728ce5b27 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.SparkContext import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE} import org.apache.spark.sql.TPCDSBase import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Cast} -import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum} +import org.apache.spark.sql.catalyst.expressions.aggregate.Average import org.apache.spark.sql.catalyst.util.resourceToString import org.apache.spark.sql.execution.{FormattedMode, ReusedSubqueryExec, SparkPlan, SubqueryBroadcastExec, SubqueryExec} import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite @@ -226,7 +226,6 @@ trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBa CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key -> "true", // Allow Incompatible is needed for Sum + Average for Spark 4.0.0 / ANSI support CometConf.getExprAllowIncompatConfigKey(classOf[Average]) -> "true", - CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true", // as well as for v1.4/q9, v1.4/q44, v2.7.0/q6, v2.7.0/q64 CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") { From 28c260e680ccc41c89f374947d4f89887294f023 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Tue, 25 Nov 2025 18:54:01 -0800 Subject: [PATCH 3/8] support_ansi_sum_decimal_input --- native/spark-expr/benches/aggregate.rs | 2 +- .../spark-expr/src/agg_funcs/sum_decimal.rs | 314 ++++++++++-------- .../comet/exec/CometAggregateSuite.scala | 1 - 3 files changed, 177 insertions(+), 140 deletions(-) diff --git a/native/spark-expr/benches/aggregate.rs b/native/spark-expr/benches/aggregate.rs index c3a1978d69..72628975b3 100644 --- a/native/spark-expr/benches/aggregate.rs +++ b/native/spark-expr/benches/aggregate.rs @@ -31,8 +31,8 @@ use datafusion::physical_expr::expressions::Column; use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use datafusion::physical_plan::ExecutionPlan; -use datafusion_comet_spark_expr::{AvgDecimal, EvalMode}; use datafusion_comet_spark_expr::SumDecimal; +use datafusion_comet_spark_expr::{AvgDecimal, EvalMode}; use futures::StreamExt; use std::hint::black_box; use std::sync::Arc; diff --git a/native/spark-expr/src/agg_funcs/sum_decimal.rs b/native/spark-expr/src/agg_funcs/sum_decimal.rs index c87ed508fc..5cdbb5e845 100644 --- a/native/spark-expr/src/agg_funcs/sum_decimal.rs +++ b/native/spark-expr/src/agg_funcs/sum_decimal.rs @@ -15,20 +15,19 @@ // specific language governing permissions and limitations // under the License. -use crate::utils::{build_bool_state, is_valid_decimal_precision}; +use crate::utils::is_valid_decimal_precision; use crate::{arithmetic_overflow_error, EvalMode}; use arrow::array::{ cast::AsArray, types::Decimal128Type, Array, ArrayRef, BooleanArray, Decimal128Array, }; use arrow::datatypes::{DataType, Field, FieldRef}; -use arrow::{array::BooleanBufferBuilder, buffer::NullBuffer}; use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue}; use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion::logical_expr::Volatility::Immutable; use datafusion::logical_expr::{ Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature, }; -use std::{any::Any, ops::BitAnd, sync::Arc}; +use std::{any::Any, sync::Arc}; #[derive(Debug, PartialEq, Eq, Hash)] pub struct SumDecimal { @@ -78,15 +77,13 @@ impl AggregateUDFImpl for SumDecimal { } fn state_fields(&self, _args: StateFieldsArgs) -> DFResult> { - let fields = vec![ - Arc::new(Field::new( - self.name(), - self.result_type.clone(), - self.is_nullable(), - )), + // For decimal sum, we always track is_empty regardless of eval_mode + // This matches Spark's behavior where DecimalType always uses shouldTrackIsEmpty = true + let data_type = self.result_type.clone(); + Ok(vec![ + Arc::new(Field::new("sum", data_type, true)), Arc::new(Field::new("is_empty", DataType::Boolean, false)), - ]; - Ok(fields) + ]) } fn name(&self) -> &str { @@ -135,10 +132,8 @@ impl AggregateUDFImpl for SumDecimal { #[derive(Debug)] struct SumDecimalAccumulator { - sum: i128, + sum: Option, is_empty: bool, - is_not_null: bool, - precision: u8, scale: i8, eval_mode: EvalMode, @@ -146,10 +141,11 @@ struct SumDecimalAccumulator { impl SumDecimalAccumulator { fn new(precision: u8, scale: i8, eval_mode: EvalMode) -> Self { + // For decimal sum, always track is_empty regardless of eval_mode + // This matches Spark's behavior where DecimalType always uses shouldTrackIsEmpty = true Self { - sum: 0, + sum: Some(0), is_empty: true, - is_not_null: true, precision, scale, eval_mode, @@ -158,18 +154,20 @@ impl SumDecimalAccumulator { fn update_single(&mut self, values: &Decimal128Array, idx: usize) -> DFResult<()> { let v = unsafe { values.value_unchecked(idx) }; - let (new_sum, is_overflow) = self.sum.overflowing_add(v); + let running_sum = self.sum.unwrap_or(0); + let (new_sum, is_overflow) = running_sum.overflowing_add(v); if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) { if self.eval_mode == EvalMode::Ansi { return Err(DataFusionError::from(arithmetic_overflow_error("decimal"))); } - self.is_not_null = false; + self.sum = None; + self.is_empty = false; return Ok(()); } - self.sum = new_sum; - self.is_not_null = true; + self.sum = Some(new_sum); + self.is_empty = false; Ok(()) } } @@ -183,16 +181,17 @@ impl Accumulator for SumDecimalAccumulator { values.len() ); - if !self.is_empty && !self.is_not_null { - // This means there's a overflow in decimal, so we will just skip the rest - // of the computation + // For decimal sum, always check for overflow regardless of eval_mode + if !self.is_empty && self.sum.is_none() { return Ok(()); } let values = &values[0]; let data = values.as_primitive::(); - self.is_empty = self.is_empty && values.len() == values.null_count(); + if values.len() == values.null_count() { + return Ok(()); + } if values.null_count() == 0 { for i in 0..data.len() { @@ -211,27 +210,21 @@ impl Accumulator for SumDecimalAccumulator { } fn evaluate(&mut self) -> DFResult { - // For each group: - // 1. if `is_empty` is true, it means either there is no value or all values for the group - // are null, in this case we'll return null - // 2. if `is_empty` is false, but `is_not_null` is false, it means there's an overflow. - // In ANSI mode we return an error, in Try/Legacy mode we return null. if self.is_empty { ScalarValue::new_primitive::( None, &DataType::Decimal128(self.precision, self.scale), ) - } else if !self.is_not_null { - if self.eval_mode == EvalMode::Ansi { - Err(DataFusionError::from(arithmetic_overflow_error("decimal"))) - } else { - ScalarValue::new_primitive::( + } else { + match self.sum { + Some(sum_value) => { + ScalarValue::try_new_decimal128(sum_value, self.precision, self.scale) + } + None => ScalarValue::new_primitive::( None, &DataType::Decimal128(self.precision, self.scale), - ) + ), } - } else { - ScalarValue::try_new_decimal128(self.sum, self.precision, self.scale) } } @@ -240,38 +233,71 @@ impl Accumulator for SumDecimalAccumulator { } fn state(&mut self) -> DFResult> { - let sum = if self.is_not_null { - ScalarValue::try_new_decimal128(self.sum, self.precision, self.scale)? - } else { - ScalarValue::new_primitive::( + let sum = match self.sum { + Some(sum_value) => { + ScalarValue::try_new_decimal128(sum_value, self.precision, self.scale)? + } + None => ScalarValue::new_primitive::( None, &DataType::Decimal128(self.precision, self.scale), - )? + )?, }; + + // For decimal sum, always return 2 state values regardless of eval_mode Ok(vec![sum, ScalarValue::from(self.is_empty)]) } fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> { + // For decimal sum, always expect 2 state arrays regardless of eval_mode assert_eq!( states.len(), 2, - "Expect two element in 'states' but found {}", + "Expect two elements in 'states' but found {}", states.len() ); assert_eq!(states[0].len(), 1); assert_eq!(states[1].len(), 1); - let that_sum = states[0].as_primitive::(); - let that_is_empty = states[1].as_any().downcast_ref::().unwrap(); + let that_sum_array = states[0].as_primitive::(); + let that_sum = if that_sum_array.is_null(0) { + None + } else { + Some(that_sum_array.value(0)) + }; + + let that_is_empty = states[1].as_boolean().value(0); + let that_overflowed = !that_is_empty && that_sum.is_none(); + let this_overflowed = !self.is_empty && self.sum.is_none(); - let this_overflow = !self.is_empty && !self.is_not_null; - let that_overflow = !that_is_empty.value(0) && that_sum.is_null(0); + if that_overflowed || this_overflowed { + self.sum = None; + self.is_empty = false; + return Ok(()); + } - self.is_not_null = !this_overflow && !that_overflow; - self.is_empty = self.is_empty && that_is_empty.value(0); + if that_is_empty { + return Ok(()); + } - if self.is_not_null { - self.sum += that_sum.value(0); + if self.is_empty { + self.sum = that_sum; + self.is_empty = false; + return Ok(()); + } + + let left = self.sum.unwrap(); + let right = that_sum.unwrap(); + let (new_sum, is_overflow) = left.overflowing_add(right); + + if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) { + if self.eval_mode == EvalMode::Ansi { + return Err(DataFusionError::from(arithmetic_overflow_error("decimal"))); + } else { + self.sum = None; + self.is_empty = false; + } + } else { + self.sum = Some(new_sum); } Ok(()) @@ -279,10 +305,8 @@ impl Accumulator for SumDecimalAccumulator { } struct SumDecimalGroupsAccumulator { - // Whether aggregate buffer for a particular group is null. True indicates it is not null. - is_not_null: BooleanBufferBuilder, - is_empty: BooleanBufferBuilder, - sum: Vec, + sum: Vec>, + is_empty: Vec, result_type: DataType, precision: u8, eval_mode: EvalMode, @@ -291,42 +315,43 @@ struct SumDecimalGroupsAccumulator { impl SumDecimalGroupsAccumulator { fn new(result_type: DataType, precision: u8, eval_mode: EvalMode) -> Self { Self { - is_not_null: BooleanBufferBuilder::new(0), - is_empty: BooleanBufferBuilder::new(0), sum: Vec::new(), + is_empty: Vec::new(), result_type, precision, eval_mode, } } - fn is_overflow(&self, index: usize) -> bool { - !self.is_empty.get_bit(index) && !self.is_not_null.get_bit(index) + fn resize_helper(&mut self, total_num_groups: usize) { + // For decimal sum, always initialize properly regardless of eval_mode + self.sum.resize(total_num_groups, Some(0)); + self.is_empty.resize(total_num_groups, true); } #[inline] fn update_single(&mut self, group_index: usize, value: i128) -> DFResult<()> { - self.is_empty.set_bit(group_index, false); - let (new_sum, is_overflow) = self.sum[group_index].overflowing_add(value); - self.sum[group_index] = new_sum; + // For decimal sum, always check for overflow regardless of eval_mode + if !self.is_empty[group_index] && self.sum[group_index].is_none() { + return Ok(()); + } + + let running_sum = self.sum[group_index].unwrap_or(0); + let (new_sum, is_overflow) = running_sum.overflowing_add(value); if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) { if self.eval_mode == EvalMode::Ansi { return Err(DataFusionError::from(arithmetic_overflow_error("decimal"))); } - self.is_not_null.set_bit(group_index, false); + self.sum[group_index] = None; + } else { + self.sum[group_index] = Some(new_sum); } + self.is_empty[group_index] = false; Ok(()) } } -fn ensure_bit_capacity(builder: &mut BooleanBufferBuilder, capacity: usize) { - if builder.len() < capacity { - let additional = capacity - builder.len(); - builder.append_n(additional, true); - } -} - impl GroupsAccumulator for SumDecimalGroupsAccumulator { fn update_batch( &mut self, @@ -340,10 +365,7 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator { let values = values[0].as_primitive::(); let data = values.values(); - // Update size for the accumulate states - self.sum.resize(total_num_groups, 0); - ensure_bit_capacity(&mut self.is_empty, total_num_groups); - ensure_bit_capacity(&mut self.is_not_null, total_num_groups); + self.resize_helper(total_num_groups); let iter = group_indices.iter().zip(data.iter()); if values.null_count() == 0 { @@ -363,42 +385,45 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator { } fn evaluate(&mut self, emit_to: EmitTo) -> DFResult { - // For each group: - // 1. if `is_empty` is true, it means either there is no value or all values for the group - // are null, in this case we'll return null - // 2. if `is_empty` is false, but `null_state` is true, it means there's an overflow. In - // non-ANSI mode Spark returns null. - let result = emit_to.take_needed(&mut self.sum); - result.iter().enumerate().for_each(|(i, &v)| { - if !is_valid_decimal_precision(v, self.precision) { - self.is_not_null.set_bit(i, false); - } - }); - - let nulls = build_bool_state(&mut self.is_not_null, &emit_to); - let is_empty = build_bool_state(&mut self.is_empty, &emit_to); - let x = (!&is_empty).bitand(&nulls); + match emit_to { + EmitTo::All => { + let result = Decimal128Array::from_iter( + self.sum + .iter() + .zip(self.is_empty.iter()) + .map(|(&sum, &empty)| if empty { None } else { sum }), + ) + .with_data_type(self.result_type.clone()); - let result = Decimal128Array::new(result.into(), Some(NullBuffer::new(x))) - .with_data_type(self.result_type.clone()); + self.sum.clear(); + self.is_empty.clear(); + Ok(Arc::new(result)) + } + EmitTo::First(n) => { + let result = Decimal128Array::from_iter( + self.sum + .drain(..n) + .zip(self.is_empty.drain(..n)) + .map(|(sum, empty)| if empty { None } else { sum }), + ) + .with_data_type(self.result_type.clone()); - Ok(Arc::new(result)) + Ok(Arc::new(result)) + } + } } fn state(&mut self, emit_to: EmitTo) -> DFResult> { - let nulls = build_bool_state(&mut self.is_not_null, &emit_to); - let nulls = Some(NullBuffer::new(nulls)); + let sums = emit_to.take_needed(&mut self.sum); - let sum = emit_to.take_needed(&mut self.sum); - let sum = Decimal128Array::new(sum.into(), nulls.clone()) + let sum_array = Decimal128Array::from_iter(sums.iter().copied()) .with_data_type(self.result_type.clone()); - let is_empty = build_bool_state(&mut self.is_empty, &emit_to); - let is_empty = BooleanArray::new(is_empty, None); - + // For decimal sum, always return 2 state arrays regardless of eval_mode + let is_empty = emit_to.take_needed(&mut self.is_empty); Ok(vec![ - Arc::new(sum) as ArrayRef, - Arc::new(is_empty) as ArrayRef, + Arc::new(sum_array), + Arc::new(BooleanArray::from(is_empty)), ]) } @@ -409,57 +434,70 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator { opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> DFResult<()> { + assert!(opt_filter.is_none(), "opt_filter is not supported yet"); + + self.resize_helper(total_num_groups); + + // For decimal sum, always expect 2 arrays regardless of eval_mode assert_eq!( values.len(), 2, "Expected two arrays: 'sum' and 'is_empty', but found {}", values.len() ); - assert!(opt_filter.is_none(), "opt_filter is not supported yet"); - // Make sure we have enough capacity for the additional groups - self.sum.resize(total_num_groups, 0); - ensure_bit_capacity(&mut self.is_empty, total_num_groups); - ensure_bit_capacity(&mut self.is_not_null, total_num_groups); - - let that_sum = &values[0]; - let that_sum = that_sum.as_primitive::(); - let that_is_empty = &values[1]; - let that_is_empty = that_is_empty - .as_any() - .downcast_ref::() - .unwrap(); + let that_sum = values[0].as_primitive::(); + let that_is_empty = values[1].as_boolean(); + + for (idx, &group_index) in group_indices.iter().enumerate() { + let that_sum_val = if that_sum.is_null(idx) { + None + } else { + Some(that_sum.value(idx)) + }; + + let that_is_empty_val = that_is_empty.value(idx); + let that_overflowed = !that_is_empty_val && that_sum_val.is_none(); + let this_overflowed = !self.is_empty[group_index] && self.sum[group_index].is_none(); + + if that_overflowed || this_overflowed { + self.sum[group_index] = None; + self.is_empty[group_index] = false; + continue; + } - group_indices - .iter() - .enumerate() - .for_each(|(idx, &group_index)| unsafe { - let this_overflow = self.is_overflow(group_index); - let that_is_empty = that_is_empty.value_unchecked(idx); - let that_overflow = !that_is_empty && that_sum.is_null(idx); - let is_overflow = this_overflow || that_overflow; - - // This part follows the logic in Spark: - // `org.apache.spark.sql.catalyst.expressions.aggregate.Sum` - self.is_not_null.set_bit(group_index, !is_overflow); - self.is_empty.set_bit( - group_index, - self.is_empty.get_bit(group_index) && that_is_empty, - ); - if !is_overflow { - // .. otherwise, the sum value for this particular index must not be null, - // and thus we merge both values and update this sum. - self.sum[group_index] += that_sum.value_unchecked(idx); + if that_is_empty_val { + continue; + } + + if self.is_empty[group_index] { + self.sum[group_index] = that_sum_val; + self.is_empty[group_index] = false; + continue; + } + + let left = self.sum[group_index].unwrap(); + let right = that_sum_val.unwrap(); + let (new_sum, is_overflow) = left.overflowing_add(right); + + if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) { + if self.eval_mode == EvalMode::Ansi { + return Err(DataFusionError::from(arithmetic_overflow_error("decimal"))); + } else { + self.sum[group_index] = None; + self.is_empty[group_index] = false; } - }); + } else { + self.sum[group_index] = Some(new_sum); + } + } Ok(()) } fn size(&self) -> usize { - self.sum.capacity() * std::mem::size_of::() - + self.is_empty.capacity() / 8 - + self.is_not_null.capacity() / 8 + self.sum.capacity() * std::mem::size_of::>() + + self.is_empty.capacity() * std::mem::size_of::() } } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index ebf1c21451..5c2a94b6cd 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -1550,7 +1550,6 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { Seq(true, false).foreach { ansiEnabled => withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { withParquetTable(generateOverflowDecimalInputs, "tbl") { - val input = sql("SELECT _1 FROM tbl") val res = sql("SELECT SUM(_1) FROM tbl") if (ansiEnabled) { checkSparkAnswerMaybeThrows(res) match { From 5df4821ced07adce0d9639396dd27aa29b79f31a Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Fri, 28 Nov 2025 08:06:17 -0800 Subject: [PATCH 4/8] support_ansi_sum_decimal_input --- .../spark-expr/src/agg_funcs/sum_decimal.rs | 81 +++++++++++-------- .../org/apache/comet/serde/aggregates.scala | 11 +++ .../apache/comet/exec/CometExecSuite.scala | 2 +- 3 files changed, 59 insertions(+), 35 deletions(-) diff --git a/native/spark-expr/src/agg_funcs/sum_decimal.rs b/native/spark-expr/src/agg_funcs/sum_decimal.rs index 5cdbb5e845..a818a23710 100644 --- a/native/spark-expr/src/agg_funcs/sum_decimal.rs +++ b/native/spark-expr/src/agg_funcs/sum_decimal.rs @@ -153,6 +153,11 @@ impl SumDecimalAccumulator { } fn update_single(&mut self, values: &Decimal128Array, idx: usize) -> DFResult<()> { + // If already overflowed (sum is None but not empty), stay in overflow state + if !self.is_empty && self.sum.is_none() { + return Ok(()); + } + let v = unsafe { values.value_unchecked(idx) }; let running_sum = self.sum.unwrap_or(0); let (new_sum, is_overflow) = running_sum.overflowing_add(v); @@ -181,7 +186,7 @@ impl Accumulator for SumDecimalAccumulator { values.len() ); - // For decimal sum, always check for overflow regardless of eval_mode + // For decimal sum, always check for overflow regardless of eval_mode (per Spark's expectation) if !self.is_empty && self.sum.is_none() { return Ok(()); } @@ -189,23 +194,19 @@ impl Accumulator for SumDecimalAccumulator { let values = &values[0]; let data = values.as_primitive::(); - if values.len() == values.null_count() { + // Update is_empty: it remains true only if it was true AND all values are null + self.is_empty = self.is_empty && values.len() == values.null_count(); + + if self.is_empty { return Ok(()); } - if values.null_count() == 0 { - for i in 0..data.len() { - self.update_single(data, i)?; - } - } else { - for i in 0..data.len() { - if data.is_null(i) { - continue; - } - self.update_single(data, i)?; + for i in 0..data.len() { + if data.is_null(i) { + continue; } + self.update_single(data, i)?; } - Ok(()) } @@ -217,13 +218,15 @@ impl Accumulator for SumDecimalAccumulator { ) } else { match self.sum { - Some(sum_value) => { + Some(sum_value) if is_valid_decimal_precision(sum_value, self.precision) => { ScalarValue::try_new_decimal128(sum_value, self.precision, self.scale) } - None => ScalarValue::new_primitive::( - None, - &DataType::Decimal128(self.precision, self.scale), - ), + _ => { + ScalarValue::new_primitive::( + None, + &DataType::Decimal128(self.precision, self.scale), + ) + } } } } @@ -237,10 +240,12 @@ impl Accumulator for SumDecimalAccumulator { Some(sum_value) => { ScalarValue::try_new_decimal128(sum_value, self.precision, self.scale)? } - None => ScalarValue::new_primitive::( - None, - &DataType::Decimal128(self.precision, self.scale), - )?, + None => { + ScalarValue::new_primitive::( + None, + &DataType::Decimal128(self.precision, self.scale), + )? + } }; // For decimal sum, always return 2 state values regardless of eval_mode @@ -387,12 +392,16 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator { fn evaluate(&mut self, emit_to: EmitTo) -> DFResult { match emit_to { EmitTo::All => { - let result = Decimal128Array::from_iter( - self.sum - .iter() - .zip(self.is_empty.iter()) - .map(|(&sum, &empty)| if empty { None } else { sum }), - ) + let result = Decimal128Array::from_iter(self.sum.iter().zip(self.is_empty.iter()).map(|(&sum, &empty)| { + if empty { + None + } else { + match sum { + Some(v) if is_valid_decimal_precision(v, self.precision) => Some(v), + _ => None, + } + } + })) .with_data_type(self.result_type.clone()); self.sum.clear(); @@ -400,12 +409,16 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator { Ok(Arc::new(result)) } EmitTo::First(n) => { - let result = Decimal128Array::from_iter( - self.sum - .drain(..n) - .zip(self.is_empty.drain(..n)) - .map(|(sum, empty)| if empty { None } else { sum }), - ) + let result = Decimal128Array::from_iter(self.sum.drain(..n).zip(self.is_empty.drain(..n)).map(|(sum, empty)| { + if empty { + None + } else { + match sum { + Some(v) if is_valid_decimal_precision(v, self.precision) => Some(v), + _ => None, + } + } + })) .with_data_type(self.result_type.clone()); Ok(Arc::new(result)) diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 1918f8b653..8ab568dc83 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -213,6 +213,17 @@ object CometAverage extends CometAggregateExpressionSerde[Average] { object CometSum extends CometAggregateExpressionSerde[Sum] { + override def getSupportLevel(sum: Sum): SupportLevel = { + sum.evalMode match { + case EvalMode.ANSI if !sum.dataType.isInstanceOf[DecimalType] => + Incompatible(Some("ANSI mode for non decimal inputs is not supported")) + case EvalMode.TRY if !sum.dataType.isInstanceOf[DecimalType] => + Incompatible(Some("TRY mode for non decimal inputs is not supported")) + case _ => + Compatible() + } + } + override def convert( aggExpr: AggregateExpression, sum: Sum, diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 9f9df73a91..edb6ccc6ef 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -469,7 +469,7 @@ class CometExecSuite extends CometTestBase { val dayTimeDf = Seq(106751991L, 106751991L, 2L) .map(Duration.ofDays) .toDF("v") - Seq(longDf, yearMonthDf, dayTimeDf).foreach { df => + Seq(longDf).foreach { df => checkSparkAnswer(df.repartitionByRange(2, col("v")).selectExpr("try_sum(v)")) } } From 272b661b0cd6a7c5bc72ad4b026913d56142d5cf Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Fri, 28 Nov 2025 08:41:30 -0800 Subject: [PATCH 5/8] support_ansi_sum_decimal_input --- .../spark-expr/src/agg_funcs/sum_decimal.rs | 74 ++++++++++--------- 1 file changed, 41 insertions(+), 33 deletions(-) diff --git a/native/spark-expr/src/agg_funcs/sum_decimal.rs b/native/spark-expr/src/agg_funcs/sum_decimal.rs index a818a23710..50645391fd 100644 --- a/native/spark-expr/src/agg_funcs/sum_decimal.rs +++ b/native/spark-expr/src/agg_funcs/sum_decimal.rs @@ -221,12 +221,10 @@ impl Accumulator for SumDecimalAccumulator { Some(sum_value) if is_valid_decimal_precision(sum_value, self.precision) => { ScalarValue::try_new_decimal128(sum_value, self.precision, self.scale) } - _ => { - ScalarValue::new_primitive::( - None, - &DataType::Decimal128(self.precision, self.scale), - ) - } + _ => ScalarValue::new_primitive::( + None, + &DataType::Decimal128(self.precision, self.scale), + ), } } } @@ -240,12 +238,10 @@ impl Accumulator for SumDecimalAccumulator { Some(sum_value) => { ScalarValue::try_new_decimal128(sum_value, self.precision, self.scale)? } - None => { - ScalarValue::new_primitive::( - None, - &DataType::Decimal128(self.precision, self.scale), - )? - } + None => ScalarValue::new_primitive::( + None, + &DataType::Decimal128(self.precision, self.scale), + )?, }; // For decimal sum, always return 2 state values regardless of eval_mode @@ -392,33 +388,45 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator { fn evaluate(&mut self, emit_to: EmitTo) -> DFResult { match emit_to { EmitTo::All => { - let result = Decimal128Array::from_iter(self.sum.iter().zip(self.is_empty.iter()).map(|(&sum, &empty)| { - if empty { - None - } else { - match sum { - Some(v) if is_valid_decimal_precision(v, self.precision) => Some(v), - _ => None, - } - } - })) - .with_data_type(self.result_type.clone()); + let result = + Decimal128Array::from_iter(self.sum.iter().zip(self.is_empty.iter()).map( + |(&sum, &empty)| { + if empty { + None + } else { + match sum { + Some(v) if is_valid_decimal_precision(v, self.precision) => { + Some(v) + } + _ => None, + } + } + }, + )) + .with_data_type(self.result_type.clone()); self.sum.clear(); self.is_empty.clear(); Ok(Arc::new(result)) } EmitTo::First(n) => { - let result = Decimal128Array::from_iter(self.sum.drain(..n).zip(self.is_empty.drain(..n)).map(|(sum, empty)| { - if empty { - None - } else { - match sum { - Some(v) if is_valid_decimal_precision(v, self.precision) => Some(v), - _ => None, - } - } - })) + let result = Decimal128Array::from_iter( + self.sum + .drain(..n) + .zip(self.is_empty.drain(..n)) + .map(|(sum, empty)| { + if empty { + None + } else { + match sum { + Some(v) if is_valid_decimal_precision(v, self.precision) => { + Some(v) + } + _ => None, + } + } + }), + ) .with_data_type(self.result_type.clone()); Ok(Arc::new(result)) From ac6b493728633649a44d5ff1fabd592edf79d8a4 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Fri, 28 Nov 2025 18:12:39 -0800 Subject: [PATCH 6/8] support_ansi_sum_decimal_input_fix_scala --- spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index edb6ccc6ef..9f9df73a91 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -469,7 +469,7 @@ class CometExecSuite extends CometTestBase { val dayTimeDf = Seq(106751991L, 106751991L, 2L) .map(Duration.ofDays) .toDF("v") - Seq(longDf).foreach { df => + Seq(longDf, yearMonthDf, dayTimeDf).foreach { df => checkSparkAnswer(df.repartitionByRange(2, col("v")).selectExpr("try_sum(v)")) } } From 25aeb33e619f543f3033369523f025d90bf8ad0a Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Mon, 1 Dec 2025 14:08:14 -0800 Subject: [PATCH 7/8] support_ansi_sum_decimal_input_fix_plan_failure_tests --- .../comet/exec/CometAggregateSuite.scala | 61 ++++++++++++------- .../sql/comet/CometPlanStabilitySuite.scala | 3 +- 2 files changed, 42 insertions(+), 22 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 5c2a94b6cd..060579b2ba 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -24,6 +24,7 @@ import scala.util.Random import org.apache.hadoop.fs.Path import org.apache.spark.sql.{CometTestBase, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.expressions.aggregate.Sum import org.apache.spark.sql.catalyst.optimizer.EliminateSorts import org.apache.spark.sql.comet.CometHashAggregateExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -1473,7 +1474,9 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("ANSI support for decimal sum - null test") { Seq(true, false).foreach { ansiEnabled => - withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withSQLConf( + SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString, + CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") { withParquetTable( Seq( (null.asInstanceOf[java.math.BigDecimal], "a"), @@ -1489,7 +1492,9 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("ANSI support for try_sum decimal - null test") { Seq(true, false).foreach { ansiEnabled => - withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withSQLConf( + SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString, + CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") { withParquetTable( Seq( (null.asInstanceOf[java.math.BigDecimal], "a"), @@ -1505,7 +1510,9 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("ANSI support for decimal sum - null test (group by)") { Seq(true, false).foreach { ansiEnabled => - withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withSQLConf( + SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString, + CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") { withParquetTable( Seq( (null.asInstanceOf[java.math.BigDecimal], "a"), @@ -1524,7 +1531,9 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("ANSI support for try_sum decimal - null test (group by)") { Seq(true, false).foreach { ansiEnabled => - withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withSQLConf( + SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString, + CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") { withParquetTable( Seq( (null.asInstanceOf[java.math.BigDecimal], "a"), @@ -1546,9 +1555,11 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { (1 to 50).flatMap(_ => Seq((maxDec38_0, 1))) } - test("ANSI support - decimal SUM function") { + test("ANSI support for decimal SUM function") { Seq(true, false).foreach { ansiEnabled => - withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withSQLConf( + SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString, + CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") { withParquetTable(generateOverflowDecimalInputs, "tbl") { val res = sql("SELECT SUM(_1) FROM tbl") if (ansiEnabled) { @@ -1569,7 +1580,9 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("ANSI support for decimal SUM - GROUP BY") { Seq(true, false).foreach { ansiEnabled => - withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withSQLConf( + SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString, + CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") { withParquetTable(generateOverflowDecimalInputs, "tbl") { val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2").repartition(2) @@ -1590,28 +1603,34 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } test("try_sum decimal overflow") { - withParquetTable(generateOverflowDecimalInputs, "tbl") { - val res = sql("SELECT try_sum(_1) FROM tbl") - checkSparkAnswerAndOperator(res) + withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") { + withParquetTable(generateOverflowDecimalInputs, "tbl") { + val res = sql("SELECT try_sum(_1) FROM tbl") + checkSparkAnswerAndOperator(res) + } } } test("try_sum decimal overflow - with GROUP BY") { - withParquetTable(generateOverflowDecimalInputs, "tbl") { - val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) - checkSparkAnswerAndOperator(res) + withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") { + withParquetTable(generateOverflowDecimalInputs, "tbl") { + val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) + checkSparkAnswerAndOperator(res) + } } } test("try_sum decimal partial overflow - with GROUP BY") { - // Group 1 overflows, Group 2 succeeds - val data: Seq[(java.math.BigDecimal, Int)] = generateOverflowDecimalInputs ++ Seq( - (new java.math.BigDecimal(300), 2), - (new java.math.BigDecimal(200), 2)) - withParquetTable(data, "tbl") { - val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2") - // Group 1 should be NULL, Group 2 should be 500 - checkSparkAnswerAndOperator(res) + withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") { + // Group 1 overflows, Group 2 succeeds + val data: Seq[(java.math.BigDecimal, Int)] = generateOverflowDecimalInputs ++ Seq( + (new java.math.BigDecimal(300), 2), + (new java.math.BigDecimal(200), 2)) + withParquetTable(data, "tbl") { + val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2") + // Group 1 should be NULL, Group 2 should be 500 + checkSparkAnswerAndOperator(res) + } } } diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala index 1728ce5b27..376e35bdeb 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.SparkContext import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE} import org.apache.spark.sql.TPCDSBase import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Cast} -import org.apache.spark.sql.catalyst.expressions.aggregate.Average +import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum} import org.apache.spark.sql.catalyst.util.resourceToString import org.apache.spark.sql.execution.{FormattedMode, ReusedSubqueryExec, SparkPlan, SubqueryBroadcastExec, SubqueryExec} import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite @@ -225,6 +225,7 @@ trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBa CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key -> "true", // Allow Incompatible is needed for Sum + Average for Spark 4.0.0 / ANSI support + CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true", CometConf.getExprAllowIncompatConfigKey(classOf[Average]) -> "true", // as well as for v1.4/q9, v1.4/q44, v2.7.0/q6, v2.7.0/q64 CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true", From ab437f57a44ee89222d0fee73475ce24c178424e Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Mon, 1 Dec 2025 14:13:24 -0800 Subject: [PATCH 8/8] support_ansi_sum_decimal_input_fix_plan_failure_tests --- .../org/apache/spark/sql/comet/CometPlanStabilitySuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala index 376e35bdeb..8f260e2ca8 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala @@ -225,8 +225,8 @@ trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBa CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key -> "true", // Allow Incompatible is needed for Sum + Average for Spark 4.0.0 / ANSI support - CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true", CometConf.getExprAllowIncompatConfigKey(classOf[Average]) -> "true", + CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true", // as well as for v1.4/q9, v1.4/q44, v2.7.0/q6, v2.7.0/q64 CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") {