15 changes: 10 additions & 5 deletions native/core/src/execution/planner.rs
@@ -1893,19 +1893,24 @@ impl PhysicalPlanner {
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let input_datatype = to_arrow_datatype(expr.sum_datatype.as_ref().unwrap());
+let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;

let builder = match datatype {
DataType::Decimal128(_, _) => {
let func =
AggregateUDF::new_from_impl(AvgDecimal::new(datatype, input_datatype));

eval_mode is computed for AVG but never used in the Decimal branch (AvgDecimal::new), so ANSI and TRY modes won’t affect Decimal AVG (e.g., its overflow behavior). Consider propagating eval_mode to the decimal implementation so its semantics match Spark’s modes. This matters more now that this PR removes the Scala getSupportLevel override, which previously marked ANSI and TRY as incompatible.

value:useful; category:bug; feedback: The Augment AI reviewer is correct that eval_mode is ignored by the AvgDecimal implementation. The Spark implementation does use it, as shown at https://github.com/apache/spark/blob/211dd995b221f135340375159672dcb77ef90ef4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L113. Passing it through prevents DataFusion Comet from behaving differently from Spark.
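
A minimal sketch of what mode-aware overflow handling for a decimal sum could look like, assuming overflow should raise an error in ANSI mode and yield NULL in Legacy and TRY modes. The `EvalMode` enum below is a local stand-in, and `max_unscaled` and `add_decimal` are hypothetical helpers; none of this is the actual `AvgDecimal` code.

```rust
/// Local stand-in for Comet's `EvalMode` (assumption: the real enum
/// lives in the spark-expr crate and has these three variants).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EvalMode {
    Legacy,
    Ansi,
    Try,
}

/// Largest unscaled value for a decimal of the given precision,
/// i.e. 10^precision - 1. Assumes precision <= 38 so it fits in i128.
fn max_unscaled(precision: u32) -> i128 {
    10i128.pow(precision) - 1
}

/// Hypothetical helper: adds `value` to `sum` and checks the result
/// against the decimal precision. On overflow it returns an error in
/// ANSI mode and `Ok(None)` (SQL NULL) in Legacy and TRY modes.
fn add_decimal(
    sum: i128,
    value: i128,
    precision: u32,
    mode: EvalMode,
) -> Result<Option<i128>, String> {
    let limit = max_unscaled(precision);
    match sum.checked_add(value) {
        // In range for both i128 and the declared precision.
        Some(v) if (-limit..=limit).contains(&v) => Ok(Some(v)),
        // Either i128 overflow or a value too wide for the precision.
        _ => match mode {
            EvalMode::Ansi => Err(format!(
                "decimal sum overflows precision {}",
                precision
            )),
            EvalMode::Legacy | EvalMode::Try => Ok(None),
        },
    }
}
```

With precision 2, `add_decimal(90, 10, 2, EvalMode::Ansi)` returns an error, while the same call with `EvalMode::Try` returns `Ok(None)`.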

AggregateExprBuilder::new(Arc::new(func), vec![child])
}
_ => {
-// cast to the result data type of AVG if the result data type is different
-// from the input type, e.g. AVG(Int32). We should not expect a cast
-// failure since it should have already been checked at Spark side.
+// For all other numeric types (Int8/16/32/64, Float32/64):

Bug: Decimal AVG ignores eval_mode parameter

The eval_mode is extracted from the protobuf at line 1896 but not passed to AvgDecimal::new for Decimal128 types. This means decimal averages won't respect ANSI mode settings for overflow handling, while non-decimal averages correctly receive the eval_mode parameter. The AvgDecimal implementation needs to be updated to accept and use eval_mode for consistent behavior across all numeric types.

value:useful; category:bug; feedback: The Bugbot AI reviewer is correct that eval_mode is ignored by the AvgDecimal implementation. The Spark implementation does use it, as shown at https://github.com/apache/spark/blob/211dd995b221f135340375159672dcb77ef90ef4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L113. Passing it through prevents DataFusion Comet from behaving differently from Spark.
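
For contrast with the decimal case, a small sketch of why the Float64 path has nothing to check: IEEE 754 addition saturates to infinity instead of failing, so a sum/count accumulator can store an eval mode without ever consulting it. `FloatAvg` is an illustrative type written for this example, not the PR's `AvgAccumulator`.

```rust
/// Illustrative sum/count accumulator for f64 inputs (hypothetical type).
#[derive(Debug, Default)]
struct FloatAvg {
    /// Running sum; stays `None` until the first non-null value arrives.
    sum: Option<f64>,
    /// Number of non-null values seen so far.
    count: i64,
}

impl FloatAvg {
    /// Folds one input value into the state; `None` models SQL NULL
    /// and is skipped.
    fn update(&mut self, value: Option<f64>) {
        if let Some(v) = value {
            // Plain f64 addition: an out-of-range sum becomes infinity
            // rather than an error, so there is no overflow to report.
            *self.sum.get_or_insert(0.0) += v;
            self.count += 1;
        }
    }

    /// Returns the average, or `None` when every input was NULL.
    fn evaluate(&self) -> Option<f64> {
        if self.count == 0 {
            None
        } else {
            self.sum.map(|s| s / self.count as f64)
        }
    }
}
```

Feeding two `f64::MAX` values into `update` and calling `evaluate` gives `Some(f64::INFINITY)`, and an accumulator that only saw NULLs gives `None`.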

+// Cast to Float64 for accumulation
let child: Arc<dyn PhysicalExpr> =
-Arc::new(CastExpr::new(Arc::clone(&child), datatype.clone(), None));
-let func = AggregateUDF::new_from_impl(Avg::new("avg", datatype));
+Arc::new(CastExpr::new(Arc::clone(&child), DataType::Float64, None));
+let func = AggregateUDF::new_from_impl(Avg::new(
+"avg",
+DataType::Float64,
+eval_mode,
+));
AggregateExprBuilder::new(Arc::new(func), vec![child])
}
};
2 changes: 1 addition & 1 deletion native/proto/src/proto/expr.proto
@@ -137,7 +137,7 @@ message Avg {
Expr child = 1;
DataType datatype = 2;
DataType sum_datatype = 3;
-bool fail_on_error = 4; // currently unused (useful for deciding Ansi vs Legacy mode)
+EvalMode eval_mode = 4;
}

message First {
52 changes: 32 additions & 20 deletions native/spark-expr/src/agg_funcs/avg.rs
@@ -15,11 +15,12 @@
// specific language governing permissions and limitations
// under the License.

+use crate::EvalMode;
use arrow::array::{
builder::PrimitiveBuilder,
cast::AsArray,
types::{Float64Type, Int64Type},
-Array, ArrayRef, ArrowNumericType, Int64Array, PrimitiveArray,
+Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, Int64Array, PrimitiveArray,
};
use arrow::compute::sum;
use arrow::datatypes::{DataType, Field, FieldRef};
@@ -31,45 +32,43 @@ use datafusion::logical_expr::{
use datafusion::physical_expr::expressions::format_state_name;
use std::{any::Any, sync::Arc};

-use arrow::array::ArrowNativeTypeOp;
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion::logical_expr::Volatility::Immutable;
use DataType::*;

/// AVG aggregate expression
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Avg {
name: String,
signature: Signature,
-// expr: Arc<dyn PhysicalExpr>,
input_data_type: DataType,
result_data_type: DataType,
+eval_mode: EvalMode,
}

impl Avg {
-/// Create a new AVG aggregate function
-pub fn new(name: impl Into<String>, data_type: DataType) -> Self {
+pub fn new(name: impl Into<String>, data_type: DataType, eval_mode: EvalMode) -> Self {
let result_data_type = avg_return_type("avg", &data_type).unwrap();

Self {
name: name.into(),
signature: Signature::user_defined(Immutable),
input_data_type: data_type,
result_data_type,
+eval_mode,
}
}
}

impl AggregateUDFImpl for Avg {
-/// Return a reference to Any that can be used for downcasting
fn as_any(&self) -> &dyn Any {
self
}

fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
-// instantiate specialized accumulator based for the type
+// All numeric types use Float64 accumulation after casting
match (&self.input_data_type, &self.result_data_type) {
-(Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
+(Float64, Float64) => Ok(Box::new(AvgAccumulator::new(self.eval_mode))),
_ => not_impl_err!(
"AvgAccumulator for ({} --> {})",
self.input_data_type,
@@ -109,10 +108,10 @@ impl AggregateUDFImpl for Avg {
&self,
_args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
-// instantiate specialized accumulator based for the type
match (&self.input_data_type, &self.result_data_type) {
(Float64, Float64) => Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
&self.input_data_type,
+self.eval_mode,
|sum: f64, count: i64| Ok(sum / count as f64),
))),

@@ -137,11 +136,22 @@ impl AggregateUDFImpl for Avg {
}
}

-/// An accumulator to compute the average
-#[derive(Debug, Default)]
+#[derive(Debug)]
pub struct AvgAccumulator {
sum: Option<f64>,
count: i64,
+#[allow(dead_code)]
+eval_mode: EvalMode,
}

+impl AvgAccumulator {
+pub fn new(eval_mode: EvalMode) -> Self {
+Self {
+sum: None,
+count: 0,
+eval_mode,
+}
+}
+}

impl Accumulator for AvgAccumulator {
@@ -166,7 +176,7 @@ impl Accumulator for AvgAccumulator {
// counts are summed
self.count += sum(states[1].as_primitive::<Int64Type>()).unwrap_or_default();

-// sums are summed
+// sums are summed - no overflow checking
if let Some(x) = sum(states[0].as_primitive::<Float64Type>()) {
let v = self.sum.get_or_insert(0.);
*v += x;
@@ -176,8 +186,6 @@ impl Accumulator for AvgAccumulator {

fn evaluate(&mut self) -> Result<ScalarValue> {
if self.count == 0 {
-// If all input are nulls, count will be 0 and we will get null after the division.
-// This is consistent with Spark Average implementation.
Ok(ScalarValue::Float64(None))
} else {
Ok(ScalarValue::Float64(
@@ -192,7 +200,7 @@ impl Accumulator for AvgAccumulator {
}

/// An accumulator to compute the average of `[PrimitiveArray<T>]`.
-/// Stores values as native types, and does overflow checking
+/// Stores values as native types.
///
/// F: Function that calculates the average value from a sum of
/// T::Native and a total count
@@ -211,6 +219,10 @@ where
/// Sums per group, stored as the native type
sums: Vec<T::Native>,

+/// Evaluation mode (stored but not used for Float64)
+#[allow(dead_code)]
+eval_mode: EvalMode,

/// Function that computes the final average (value / count)
avg_fn: F,
}
@@ -220,11 +232,12 @@ where
T: ArrowNumericType + Send,
F: Fn(T::Native, i64) -> Result<T::Native> + Send,
{
-pub fn new(return_data_type: &DataType, avg_fn: F) -> Self {
+pub fn new(return_data_type: &DataType, eval_mode: EvalMode, avg_fn: F) -> Self {
Self {
return_data_type: return_data_type.clone(),
counts: vec![],
sums: vec![],
+eval_mode,
avg_fn,
}
}
@@ -254,6 +267,7 @@ where
if values.null_count() == 0 {
for (&group_index, &value) in iter {
let sum = &mut self.sums[group_index];
+// No overflow checking - INFINITY is a valid result
*sum = (*sum).add_wrapping(value);
self.counts[group_index] += 1;
}
@@ -264,7 +278,6 @@ where
}
let sum = &mut self.sums[group_index];
*sum = (*sum).add_wrapping(value);

self.counts[group_index] += 1;
}
}
@@ -280,17 +293,17 @@ where
total_num_groups: usize,
) -> Result<()> {
assert_eq!(values.len(), 2, "two arguments to merge_batch");
// first batch is partial sums, second is counts
let partial_sums = values[0].as_primitive::<T>();
let partial_counts = values[1].as_primitive::<Int64Type>();

// update counts with partial counts
self.counts.resize(total_num_groups, 0);
let iter1 = group_indices.iter().zip(partial_counts.values().iter());
for (&group_index, &partial_count) in iter1 {
self.counts[group_index] += partial_count;
}

-// update sums
+// update sums - no overflow checking
self.sums.resize(total_num_groups, T::default_value());
let iter2 = group_indices.iter().zip(partial_sums.values().iter());
for (&group_index, &new_value) in iter2 {
@@ -319,7 +332,6 @@ where
Ok(Arc::new(array))
}

-// return arrays for sums and counts
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
let counts = emit_to.take_needed(&mut self.counts);
let counts = Int64Array::new(counts.into(), None);
16 changes: 3 additions & 13 deletions spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -29,7 +29,8 @@ import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType
import org.apache.comet.CometConf
import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT
import org.apache.comet.CometSparkSessionExtensions.withInfo
-import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}
+import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProto, serializeDataType}
+import org.apache.comet.shims.CometEvalModeUtil

object CometMin extends CometAggregateExpressionSerde[Min] {

@@ -150,17 +151,6 @@ object CometCount extends CometAggregateExpressionSerde[Count] {

object CometAverage extends CometAggregateExpressionSerde[Average] {

-override def getSupportLevel(avg: Average): SupportLevel = {
-avg.evalMode match {
-case EvalMode.ANSI =>
-Incompatible(Some("ANSI mode is not supported"))
-case EvalMode.TRY =>
-Incompatible(Some("TRY mode is not supported"))
-case _ =>
-Compatible()
-}
-}

override def convert(
aggExpr: AggregateExpression,
avg: Average,
@@ -192,7 +182,7 @@ object CometAverage extends CometAggregateExpressionSerde[Average] {
val builder = ExprOuterClass.Avg.newBuilder()
builder.setChild(childExpr.get)
builder.setDatatype(dataType.get)
-builder.setFailOnError(avg.evalMode == EvalMode.ANSI)
+builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(avg.evalMode)))
builder.setSumDatatype(sumDataType.get)

Some(
@@ -1471,6 +1471,42 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

+test("AVG and try_avg - basic functionality") {
+withParquetTable(
+Seq(
+(10L, 1),
+(20L, 1),
+(null.asInstanceOf[Long], 1),
+(100L, 2),
+(200L, 2),
+(null.asInstanceOf[Long], 3)),
+"tbl") {

+Seq(true, false).foreach({ k =>
+// without GROUP BY
+withSQLConf(SQLConf.ANSI_ENABLED.key -> k.toString) {
+val res = sql("SELECT avg(_1) FROM tbl")
+checkSparkAnswerAndOperator(res)
+}

+// with GROUP BY
+withSQLConf(SQLConf.ANSI_ENABLED.key -> k.toString) {
+val res = sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2")
+checkSparkAnswerAndOperator(res)
+}

+})

+// try_avg without GROUP BY
+val resTry = sql("SELECT try_avg(_1) FROM tbl")
+checkSparkAnswerAndOperator(resTry)

+// try_avg with GROUP BY
+val resTryGroup = sql("SELECT _2, try_avg(_1) FROM tbl GROUP BY _2")
+checkSparkAnswerAndOperator(resTryGroup)
+}
+}

protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = {
val df = sql(query)
checkSparkAnswer(df)