53 changes: 51 additions & 2 deletions native/core/src/execution/planner.rs
@@ -123,8 +123,8 @@ use datafusion_comet_proto::{
use datafusion_comet_spark_expr::{
ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct,
DecimalRescaleCheckOverflow, GetArrayStructFields, GetStructField, IfExpr, ListExtract,
-    NormalizeNaNAndZero, SparkCastOptions, Stddev, SumDecimal, ToJson, UnboundColumn, Variance,
-    WideDecimalBinaryExpr, WideDecimalOp,
+    NormalizeNaNAndZero, Percentile, SparkCastOptions, Stddev, SumDecimal, ToJson, UnboundColumn,
+    Variance, WideDecimalBinaryExpr, WideDecimalOp,
};
use itertools::Itertools;
use jni::objects::GlobalRef;
@@ -2267,6 +2267,55 @@ impl PhysicalPlanner {
));
Self::create_aggr_func_expr("bloom_filter_agg", schema, vec![child], func)
}
AggExprStruct::PercentileCont(expr) => {
let return_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;

// Cast input to Float64 for numeric types
let child =
Arc::new(CastExpr::new(child, DataType::Float64, None)) as Arc<dyn PhysicalExpr>;

Casting the input to Float64 here can change ordering and collapse distinct values for high-precision DecimalType or large Long values (e.g., >2^53), so percentile_cont can diverge from Spark, which orders on the original type.

Severity: medium

Other Locations
  • spark/src/main/scala/org/apache/comet/serde/aggregates.scala:697
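The precision loss the reviewer describes is easy to reproduce in plain Rust: `f64` has a 53-bit mantissa, so distinct `i64` values above 2^53 collapse after the cast.

```rust
// Demonstrates the review concern: an i64-to-f64 cast cannot distinguish
// neighboring integers above 2^53, so a Float64 cast before aggregation
// can merge values that Spark would treat as distinct.
fn main() {
    let a: i64 = (1i64 << 53) + 1; // 9007199254740993
    let b: i64 = 1i64 << 53;       // 9007199254740992
    assert_ne!(a, b);
    // After the cast, both map to the same f64 value.
    assert_eq!(a as f64, b as f64);
}
```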



// Extract the literal percentile value
let percentile_expr =
self.create_expr(expr.percentile.as_ref().unwrap(), Arc::clone(&schema))?;
Comment on lines +2271 to +2280

Severity: medium

The use of .unwrap() on optional fields from the protobuf expression (expr.datatype, expr.child, expr.percentile) can lead to a panic if any of these fields are unexpectedly None. While the Scala-side serialization logic seems to ensure these fields are present, it's safer to handle the None case gracefully by returning an ExecutionError.

For example:

let datatype = expr.datatype.as_ref().ok_or_else(|| {
    ExecutionError::GeneralError("Datatype for PercentileCont is missing".into())
})?;
let return_type = to_arrow_datatype(datatype);

This practice should be applied to expr.child and expr.percentile as well for robustness.
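The suggested `ok_or_else` pattern can be exercised standalone; here a plain `String` stands in for Comet's `ExecutionError` so the sketch compiles on its own.

```rust
// Converts a missing protobuf field (Option) into a descriptive error
// instead of panicking via unwrap(). String is a stand-in error type.
fn require<T>(field: Option<T>, name: &str) -> Result<T, String> {
    field.ok_or_else(|| format!("{name} for PercentileCont is missing"))
}

fn main() {
    // Present field passes through untouched.
    assert_eq!(require(Some(42), "datatype"), Ok(42));
    // Missing field yields an error rather than a panic.
    assert!(require::<i32>(None, "datatype").is_err());
}
```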

let percentile_value = percentile_expr
.as_any()
.downcast_ref::<DataFusionLiteral>()
.ok_or_else(|| {
ExecutionError::GeneralError("percentile must be a literal".into())
})?
.value()
.clone();

let percentile = match percentile_value {
ScalarValue::Float64(Some(p)) => p,
ScalarValue::Float32(Some(p)) => p as f64,
ScalarValue::Int64(Some(p)) => p as f64,
ScalarValue::Int32(Some(p)) => p as f64,
_ => {
return Err(ExecutionError::GeneralError(format!(
"percentile must be a numeric literal, got {:?}",
percentile_value
)))
}
};

// Custom Spark-compatible Percentile implementation
let func = AggregateUDF::new_from_impl(Percentile::new(
"spark_percentile",
percentile,
expr.reverse,
return_type,
));

AggregateExprBuilder::new(Arc::new(func), vec![child])
.schema(schema)
.alias("spark_percentile")
.with_ignore_nulls(false)
.with_distinct(false)
.build()
.map_err(|e| ExecutionError::DataFusionError(e.to_string()))

Inline builder duplicates existing helper function

Severity: low

The PercentileCont arm manually constructs an AggregateExprBuilder chain, but the existing Self::create_aggr_func_expr helper (used by all other aggregate expressions like BloomFilterAgg, Variance, Stddev, Correlation) does exactly the same thing. The inline version also uses a different error mapping (.map_err(|e| ExecutionError::DataFusionError(e.to_string()))) compared to the helper's .map_err(|e| e.into()), which loses error context. A single call to Self::create_aggr_func_expr("spark_percentile", schema, vec![child], func) would replace 7 lines.


}
}
}
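For reference, `percentile_cont` computes a linearly interpolated percentile over the sorted input. A standalone sketch of those semantics (independent of the `Percentile` UDAF this PR adds, whose internals are not shown in this diff; NaN handling and type widening are simplified):

```rust
// Linear-interpolation percentile over f64 values: sort, take rank
// p * (n - 1), and interpolate between the two surrounding values.
fn percentile_cont(values: &mut Vec<f64>, p: f64) -> Option<f64> {
    if values.is_empty() || !(0.0..=1.0).contains(&p) {
        return None;
    }
    values.sort_by(|a, b| a.partial_cmp(b).unwrap()); // assumes no NaNs
    let rank = p * (values.len() - 1) as f64;
    let lo = rank.floor() as usize;
    let hi = rank.ceil() as usize;
    let frac = rank - lo as f64;
    Some(values[lo] + (values[hi] - values[lo]) * frac)
}

fn main() {
    let mut v = vec![1.0, 2.0, 3.0, 4.0];
    // Median of four values falls between the 2nd and 3rd: 2.5.
    assert_eq!(percentile_cont(&mut v, 0.5), Some(2.5));
    assert_eq!(percentile_cont(&mut v, 0.0), Some(1.0));
    assert_eq!(percentile_cont(&mut v, 1.0), Some(4.0));
}
```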

6 changes: 6 additions & 0 deletions native/core/src/execution/serde.rs
@@ -168,5 +168,11 @@ pub fn to_arrow_datatype(dt_value: &DataType) -> ArrowDataType {
}
_ => unreachable!(),
},
DataTypeId::YearMonthInterval => {
ArrowDataType::Interval(arrow::datatypes::IntervalUnit::YearMonth)
}
DataTypeId::DayTimeInterval => {
ArrowDataType::Interval(arrow::datatypes::IntervalUnit::DayTime)
}
}
}
8 changes: 8 additions & 0 deletions native/proto/src/proto/expr.proto
@@ -139,6 +139,7 @@ message AggExpr {
Stddev stddev = 14;
Correlation correlation = 15;
BloomFilterAgg bloomFilterAgg = 16;
PercentileCont percentileCont = 17;
}

// Optional QueryContext for error reporting (contains SQL text and position)
@@ -243,6 +244,13 @@ message BloomFilterAgg {
DataType datatype = 4;
}

message PercentileCont {
Expr child = 1; // The column to compute percentile on
Expr percentile = 2; // The percentile value (0.0-1.0)
DataType datatype = 3; // Return type
bool reverse = 4; // True if ORDER BY DESC
}

enum EvalMode {
LEGACY = 0;
TRY = 1;
2 changes: 2 additions & 0 deletions native/proto/src/proto/types.proto
@@ -59,6 +59,8 @@ message DataType {
LIST = 14;
MAP = 15;
STRUCT = 16;
YEAR_MONTH_INTERVAL = 17;
DAY_TIME_INTERVAL = 18;
}
DataTypeId type_id = 1;

2 changes: 2 additions & 0 deletions native/spark-expr/src/agg_funcs/mod.rs
@@ -19,6 +19,7 @@ mod avg;
mod avg_decimal;
mod correlation;
mod covariance;
mod percentile;
mod stddev;
mod sum_decimal;
mod sum_int;
@@ -28,6 +29,7 @@ pub use avg::Avg;
pub use avg_decimal::AvgDecimal;
pub use correlation::Correlation;
pub use covariance::Covariance;
pub use percentile::Percentile;
pub use stddev::Stddev;
pub use sum_decimal::SumDecimal;
pub use sum_int::SumInteger;