feat: Support Spark expression: percentile_cont #49
martin-augment wants to merge 4 commits into main
Conversation
Walkthrough: This pull request adds support for Spark's percentile_cont expression.
Summary of Changes
Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the system's analytical capabilities by integrating full support for Spark's percentile_cont aggregate function.
Thanks for adding Spark's percentile_cont!

Spark Compatibility

NaN handling (potential bug): NaN should sort as greater than all other values, e.g.:

-- NaN should sort as greater than all other values
SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY v) FROM test_nan

Decimal precision: The serde allows DecimalType inputs, but the planner casts them to Float64, which can lose precision for high-precision decimals.

Rust Implementation

Panics via unwrap() calls.

Binary format comment overstates Spark compatibility: The module doc says the state format matches Spark's.

Interval Type Changes

Tests and Docs

The SQL test file is quite thorough — good coverage of negative values, boundary percentiles, nulls, floats, and GROUP BY. A few additions might be worth considering. The PR doesn't update the documentation; it would be worth adding percentile_cont to the list of supported expressions.

Benchmarks

The PR description doesn't include benchmark numbers. Per the project's contribution guidelines, expression PRs should include microbenchmark results comparing Comet vs Spark. Is there a benchmark in place that covers this expression?
🤖 Augment PR Summary

Summary: Adds Comet/native execution support for Spark's percentile_cont.

Changes:

Technical Notes: The native planner currently extracts the percentile as a constant value, and the native accumulator performs Spark-style linear interpolation for continuous percentiles.
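For readers unfamiliar with the algorithm, here is a minimal standalone sketch of that interpolation (illustrative only; names and structure are not taken from the PR's accumulator):

// Continuous percentile over an ascending-sorted, non-empty slice.
// p must be in [0.0, 1.0].
fn percentile_cont_sketch(sorted: &[f64], p: f64) -> f64 {
    let pos = p * (sorted.len() - 1) as f64; // fractional rank
    let lower = pos.floor() as usize;
    let higher = pos.ceil() as usize;
    let fraction = pos - lower as f64;
    // Linear interpolation between the two neighbouring values.
    (1.0 - fraction) * sorted[lower] + fraction * sorted[higher]
}

fn main() {
    let sorted = [10.0, 20.0, 30.0, 40.0];
    // position = 0.5 * 3 = 1.5, halfway between 20.0 and 30.0
    assert_eq!(percentile_cont_sketch(&sorted, 0.5), 25.0);
}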
// Cast input to Float64 for numeric types
let child =
    Arc::new(CastExpr::new(child, DataType::Float64, None)) as Arc<dyn PhysicalExpr>;
Casting the input to Float64 here can change ordering / collapse distinct values for high-precision DecimalType or large Long values (e.g., >2^53), which can make percentile_cont diverge from Spark’s behavior that orders on the original type.
Severity: medium
Other Locations
spark/src/main/scala/org/apache/comet/serde/aggregates.scala:697
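As a standalone illustration of that risk (not code from the PR): 64-bit integers above 2^53 lose precision when cast to f64, so distinct inputs can compare equal after the cast.

fn main() {
    let a: i64 = 9_007_199_254_740_993; // 2^53 + 1
    let b: i64 = 9_007_199_254_740_992; // 2^53
    assert_ne!(a, b);
    // After casting to f64 the two values are indistinguishable, so ordering
    // (and therefore the selected percentile value) can differ from ordering
    // on the original i64 type.
    assert_eq!(a as f64, b as f64);
}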
.collect();

// Sort by f64 value
entries.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
partial_cmp returns None for NaN, and treating that as Ordering::Equal can make the sort order (and thus the percentile result) non-deterministic for inputs containing NaNs; Spark uses a total ordering for sorting keys. It may be worth defining an explicit total order for f64 here to stay Spark-compatible.
Severity: low
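A comparator along those lines might look like this (a sketch only, not the PR's code). Note that std's f64::total_cmp is a built-in alternative, though it orders negative NaN before all other values rather than strictly last:

use std::cmp::Ordering;

// Total order for f64 that places NaN after every other value (ascending).
fn cmp_f64_nan_last(a: f64, b: f64) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        // Neither value is NaN, so partial_cmp cannot return None here.
        (false, false) => a.partial_cmp(&b).unwrap(),
    }
}

fn main() {
    let mut v = vec![3.0, f64::NAN, -1.0, 2.0];
    v.sort_by(|a, b| cmp_f64_nan_last(*a, *b));
    assert_eq!(&v[..3], &[-1.0, 2.0, 3.0]);
    assert!(v[3].is_nan());
}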
Code Review
This pull request introduces support for the percentile_cont aggregate function, including its protobuf definition, planning logic in the native execution engine, and Scala-side serialization. It also adds support for YearMonthInterval and DayTimeInterval types. The PercentileAccumulator is implemented to handle value-frequency mapping and continuous percentile calculation with comprehensive unit tests. Feedback indicates a correctness issue in compute_percentile_i64 where interpolation is not applied unconditionally, leading to percentile_disc behavior for non-Float64 types despite input casting. Another correctness concern is in the evaluate function, which hardcodes the return type to Float64, potentially causing type mismatches for other expected return types like Decimal. Additionally, there's a suggestion to improve robustness by replacing .unwrap() calls on optional protobuf fields in planner.rs with graceful error handling.
match &self.return_type {
    DataType::Float64 => {
        // Interpret i64 bits as f64
        let lower_f = f64::from_bits(lower_key as u64);
        let higher_f = f64::from_bits(higher_key as u64);
        let result = (1.0 - fraction) * lower_f + fraction * higher_f;
        Some(result.to_bits() as i64)
    }
    _ => Some(lower_key),
}
The current implementation for compute_percentile_i64 performs linear interpolation only for DataType::Float64. For all other types, it falls back to returning lower_key, which is the behavior of percentile_disc (discrete percentile), not percentile_cont (continuous percentile). This is a correctness issue.
Since the planner casts all numeric inputs to Float64, the interpolation logic should be applied unconditionally. The match statement on return_type is misleading and incorrect here.
// Interpret i64 bits as f64
let lower_f = f64::from_bits(lower_key as u64);
let higher_f = f64::from_bits(higher_key as u64);
let result = (1.0 - fraction) * lower_f + fraction * higher_f;
Some(result.to_bits() as i64)

fn evaluate(&mut self) -> Result<ScalarValue> {
    match self.compute_percentile_i64() {
        Some(value) => Ok(ScalarValue::Float64(Some(f64::from_bits(value as u64)))),
        None => Ok(ScalarValue::Float64(None)),
    }
The evaluate function hardcodes the return type to ScalarValue::Float64, ignoring self.return_type. This is incorrect as Spark's percentile_cont can return other types, such as Decimal. The function should respect the return_type passed during initialization and cast the final computed f64 value to the expected ScalarValue variant. Failing to do so can lead to type mismatches in the query results.
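For illustration, a minimal sketch of producing the final value in the declared return type rather than hardcoding Float64 (the function shape is an assumption, and it presumes ScalarValue::cast_to is available in the DataFusion version in use; this is not the PR's code):

use arrow::datatypes::DataType;
use datafusion::common::{Result, ScalarValue};

// Sketch: build the computed value as Float64, then cast it to the
// accumulator's declared return type instead of always returning Float64.
fn finish_in_return_type(value: Option<f64>, return_type: &DataType) -> Result<ScalarValue> {
    let as_f64 = ScalarValue::Float64(value);
    if return_type == &DataType::Float64 {
        Ok(as_f64)
    } else {
        as_f64.cast_to(return_type)
    }
}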
let return_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;

// Cast input to Float64 for numeric types
let child =
    Arc::new(CastExpr::new(child, DataType::Float64, None)) as Arc<dyn PhysicalExpr>;

// Extract the literal percentile value
let percentile_expr =
    self.create_expr(expr.percentile.as_ref().unwrap(), Arc::clone(&schema))?;
The use of .unwrap() on optional fields from the protobuf expression (expr.datatype, expr.child, expr.percentile) can lead to a panic if any of these fields are unexpectedly None. While the Scala-side serialization logic seems to ensure these fields are present, it's safer to handle the None case gracefully by returning an ExecutionError.
For example:
let datatype = expr.datatype.as_ref().ok_or_else(|| {
ExecutionError::GeneralError("Datatype for PercentileCont is missing".into())
})?;
let return_type = to_arrow_datatype(datatype);

This practice should be applied to expr.child and expr.percentile as well for robustness.
Actionable comments posted: 3
🧹 Nitpick comments (5)
native/proto/src/proto/expr.proto (1)
247-251: Encode percentile as a literal/scalar contract.

Spark's ordered-set aggregate syntax requires the percentile argument to be a constant between 0.0 and 1.0, but this message models it as an arbitrary Expr. The native planner then has to downcast it back to a literal and fail late when that assumption is violated. Narrowing this field to a literal/scalar payload, or validating that constraint before serialization, would keep the Scala serde and Rust planner aligned and make invalid plans unrepresentable. (archive.apache.org)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@native/proto/src/proto/expr.proto` around lines 247 - 251, The PercentileCont message currently models the percentile argument as an Expr which permits non-constant expressions; change the percentile field to a literal/scalar payload type (e.g. replace Expr percentile = 2 with a Literal/ScalarValue-style message) so the proto enforces a constant 0.0–1.0 contract at serialization time, and if that concrete literal type is not available, add validation in the serializer for PercentileCont.percentile to reject non-literal Expr values before emitting the proto (update any constructors/parsers that build PercentileCont accordingly).

native/spark-expr/src/agg_funcs/percentile.rs (4)
292-308: Same defensive pattern recommended for merge_batch.

The unwrap() on line 293 could similarly be replaced with explicit error handling for consistency.

♻️ Suggested defensive handling

 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
-    let binary_array = states[0].as_any().downcast_ref::<BinaryArray>().unwrap();
+    let binary_array = states[0]
+        .as_any()
+        .downcast_ref::<BinaryArray>()
+        .ok_or_else(|| {
+            datafusion::common::DataFusionError::Internal(
+                "Percentile accumulator expected BinaryArray state".to_string(),
+            )
+        })?;
Verify each finding against the current code and only fix it if needed. In `@native/spark-expr/src/agg_funcs/percentile.rs` around lines 292 - 308, In merge_batch, avoid calling unwrap() on the downcast_ref; instead check whether states[0].as_any().downcast_ref::<BinaryArray>() returns Some and return an Err if it is None (matching the defensive pattern used elsewhere), so that merge_batch returns a proper Result error instead of panicking when the array type is unexpected; update the downcast of BinaryArray in merge_batch to perform an explicit match or if-let and construct/propagate an appropriate error (using the same Result/Error type used in this module) before proceeding with deserialization and merging into self.counts.
157-184: Consider returning an error for malformed binary data.

The deserialize method silently breaks the loop when encountering malformed data (line 171-173) rather than returning an error. This could mask data corruption issues during state merging.

Additionally, the size field is read but not validated against the expected value of 16.

♻️ Suggested improvement for error handling

 fn deserialize(bytes: &[u8]) -> Result<BTreeMap<i64, i64>> {
     let mut counts = BTreeMap::new();
     let mut offset = 0;
     while offset + 4 <= bytes.len() {
         let size = i32::from_be_bytes(bytes[offset..offset + 4].try_into().unwrap());
         offset += 4;
         if size < 0 {
             // End marker
             break;
         }
+        if size != 16 {
+            return Err(datafusion::common::DataFusionError::Internal(
+                format!("Invalid percentile state record size: {}, expected 16", size),
+            ));
+        }
+
         if offset + 16 > bytes.len() {
-            break;
+            return Err(datafusion::common::DataFusionError::Internal(
+                "Truncated percentile state data".to_string(),
+            ));
         }
Verify each finding against the current code and only fix it if needed. In `@native/spark-expr/src/agg_funcs/percentile.rs` around lines 157 - 184, The deserialize function currently swallows malformed input by breaking the loop; change it to validate and return an Err instead: check that the read size equals the expected 16 for key+count and return an appropriate error if size is negative but not the end marker or if size != 16, and when remaining bytes are insufficient (the offset + 16 check) return an Err rather than silently breaking; update error paths in deserialize to propagate a descriptive failure (refer to the deserialize function, the size variable, offset checks and the counts BTreeMap) so callers can detect corrupted binary state.
277-290: Consider defensive error handling for type downcast.

The unwrap() on downcast_ref (line 279) will panic if the input array is not Float64Array. While the planner ensures type casting, a more defensive approach would return an error.

♻️ Suggested defensive handling

 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
     let array = &values[0];
-    let values = array.as_any().downcast_ref::<Float64Array>().unwrap();
+    let values = array
+        .as_any()
+        .downcast_ref::<Float64Array>()
+        .ok_or_else(|| {
+            datafusion::common::DataFusionError::Internal(
+                "Percentile accumulator expected Float64Array".to_string(),
+            )
+        })?;
Verify each finding against the current code and only fix it if needed. In `@native/spark-expr/src/agg_funcs/percentile.rs` around lines 277 - 290, In update_batch, avoid panicking by replacing the Float64Array downcast_ref().unwrap() with a defensive check (use if let Some(values) = array.as_any().downcast_ref::<Float64Array>() or match) and return an Err when the downcast fails; include a clear diagnostic message that mentions the expected Float64Array and the actual array type (use array.data_type() or debug format) so callers get a proper Result::Err instead of a panic (refer to update_batch, array, values, and Float64Array to locate the code).
322-324: The size() method underestimates actual memory usage.

The calculation only accounts for the key-value pairs but not the BTreeMap node overhead (pointers, tree structure). This is typically acceptable as an approximation, but worth noting for memory accounting accuracy.

♻️ More accurate size estimation

 fn size(&self) -> usize {
-    std::mem::size_of_val(self) + self.counts.len() * (std::mem::size_of::<i64>() * 2)
+    // BTreeMap overhead is approximately 3 pointers per node plus the entry
+    std::mem::size_of_val(self)
+        + self.counts.len() * (std::mem::size_of::<i64>() * 2 + std::mem::size_of::<usize>() * 3)
 }
Verify each finding against the current code and only fix it if needed. In `@native/spark-expr/src/agg_funcs/percentile.rs` around lines 322 - 324, The size() method currently only counts the raw key/value bytes (counts.len() * i64 sizes) plus size_of_val(self), which underestimates memory because BTreeMap nodes have per-entry overhead (pointers/parent/child/metadata). Update size() to add an estimated per-node overhead multiplied by counts.len() (e.g., a constant like ESTIMATED_BTREE_NODE_OVERHEAD representing pointer/metadata sizes) so the total becomes size_of_val(self) + counts.len() * (raw key+value sizes + ESTIMATED_BTREE_NODE_OVERHEAD); reference the size() method and the counts field when implementing this estimated per-node overhead.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 15ec97da-2198-4d7f-8568-68033dd7a6c1
📒 Files selected for processing (9)
native/core/src/execution/planner.rs
native/core/src/execution/serde.rs
native/proto/src/proto/expr.proto
native/proto/src/proto/types.proto
native/spark-expr/src/agg_funcs/mod.rs
native/spark-expr/src/agg_funcs/percentile.rs
spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
spark/src/main/scala/org/apache/comet/serde/aggregates.scala
spark/src/test/resources/sql-tests/expressions/aggregate/percentile_cont.sql
let mut entries: Vec<(f64, i64)> = self
    .counts
    .iter()
    .map(|(&bits, &count)| (f64::from_bits(bits as u64), count))
    .collect();

// Sort by f64 value
entries.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));

if self.reverse {
    entries.reverse();
}
🧩 Analysis chain
🌐 Web query:
Spark SQL percentile_cont NaN handling behavior
💡 Result:
Spark SQL's percentile_cont function, introduced in Spark 3.3.0 as an ANSI-standard ordered-set aggregate, is implemented as a runtime replacement for the underlying percentile aggregate function. The percentile_cont(percentage) WITHIN GROUP (ORDER BY sortKey) computes interpolated continuous percentiles from the ordered sortKey values.

NULL handling: NULL values in sortKey are ignored, similar to other aggregates. If all values are NULL or the partition/group is empty, it returns NULL.

NaN handling: Spark SQL treats NaN as a regular floating-point value with special semantics (NaN = NaN is true, NaNs sort last in ascending order). Since percentile_cont sorts the input values (including any NaNs) and performs linear interpolation between sorted positions, NaN values are included in the computation. If the computed percentile position falls on or interpolates to a NaN value, the result will be NaN. Official docs do not explicitly document NaN behavior for percentile_cont, but it follows the general percentile implementation and Spark's NaN semantics.

Evidence from sources:
- Official Spark docs confirm percentile_cont syntax and replacement by percentile [5,15,16].
- Percentile implementation collects all non-null values into a hashmap, sorts them, and interpolates [20].
- Spark NaN semantics: treated as values in sorting/aggregates, go last [17].
- Aggregate examples show NULL returned for empty/non-contributing groups [2].
- No sources indicate NaNs are filtered; they propagate through sorting/interpolation like regular values.
Citations:
- 1: [SPARK-38219][SQL] Support ANSI aggregation function percentile_cont as window function apache/spark#35531
- 2: https://spark.apache.org/docs/latest/sql-ref-functions-builtin.html
- 3: https://spark.apache.org/docs/4.2.0-preview2/api/sql/agg-functions/
- 4: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-aggregate.html
- 5: https://spark.apache.org/docs/3.0.0-preview2/sql-ref-nan-semantics.html
- 6: https://github.com/apache/spark/blob/5a7403623d0525c23ab8ae575e9d1383e3e10635/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
Fix NaN sorting to match Spark semantics.
The current implementation treats NaN as equal to all values using unwrap_or(std::cmp::Ordering::Equal). However, Spark's percentile_cont sorts NaN values last in ascending order, not as equal. This mismatch can cause non-deterministic sort order and incorrect percentile results when NaN values are present. Update the comparison logic to sort NaN values last, consistent with Spark's documented NaN handling for percentiles.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@native/spark-expr/src/agg_funcs/percentile.rs` around lines 196 - 207, The
current sort comparator uses unwrap_or(Ordering::Equal) which treats NaNs as
equal; update the sort used in entries.sort_by to explicitly place NaN values
last in ascending order: inside the closure check a.0.is_nan() and b.0.is_nan()
first (both NaN => Equal), if a.0.is_nan() return Greater, if b.0.is_nan()
return Less, otherwise use a.0.partial_cmp(&b.0).unwrap(); keep the existing
entries.reverse() behavior (self.reverse) unchanged so NaN-last semantics match
Spark's percentile_cont in ascending order.
#[test]
fn test_percentile_all_negative() {
    // Test all negative values
    // Values: -50, -20, 0, 10, 30
    // Sorted: -50, -20, 0, 10, 30
    // Median (50th percentile) with 5 values:
    //   position = (5-1) * 0.5 = 2
    //   result = value at index 2 = 0
    let mut acc = PercentileAccumulator::new(0.5, false, DataType::Float64);
    let values: ArrayRef =
        Arc::new(Float64Array::from(vec![-50.0, -20.0, 0.0, 10.0, 30.0]));
    acc.update_batch(&[values]).unwrap();

    let result = acc.evaluate().unwrap();
    assert_eq!(result, ScalarValue::Float64(Some(0.0)));
}
Test name doesn't match the test data.
The test is named test_percentile_all_negative but the values [-50.0, -20.0, 0.0, 10.0, 30.0] include non-negative values (0, 10, 30). Consider renaming to test_percentile_mixed_values or similar.
📝 Suggested fix
#[test]
- fn test_percentile_all_negative() {
- // Test all negative values
+ fn test_percentile_mixed_negative_positive() {
+ // Test mixed negative and positive values
  // Values: -50, -20, 0, 10, 30

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
#[test]
fn test_percentile_mixed_negative_positive() {
    // Test mixed negative and positive values
    // Values: -50, -20, 0, 10, 30
    // Sorted: -50, -20, 0, 10, 30
    // Median (50th percentile) with 5 values:
    //   position = (5-1) * 0.5 = 2
    //   result = value at index 2 = 0
    let mut acc = PercentileAccumulator::new(0.5, false, DataType::Float64);
    let values: ArrayRef =
        Arc::new(Float64Array::from(vec![-50.0, -20.0, 0.0, 10.0, 30.0]));
    acc.update_batch(&[values]).unwrap();

    let result = acc.evaluate().unwrap();
    assert_eq!(result, ScalarValue::Float64(Some(0.0)));
}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@native/spark-expr/src/agg_funcs/percentile.rs` around lines 402 - 417, The
test function name test_percentile_all_negative is misleading because the test
data in PercentileAccumulator::new and the values vector include non-negative
numbers; either rename the test to something like test_percentile_mixed_values
(and update the inline comment) or change the values to be all negative to match
the original name; locate the test function test_percentile_all_negative and
update its identifier and accompanying comment or replace the values Vec in the
Float64Array::from call to all-negative numbers accordingly.
// Support numeric types (includes DecimalType)
expr.child.dataType match {
  case _: NumericType =>
  case _ =>
    withInfo(aggExpr, s"unsupported input type: ${expr.child.dataType}")
    return None
}
🧩 Analysis chain
🏁 Script executed:
# First, let's examine the specific file and lines mentioned
cat -n spark/src/main/scala/org/apache/comet/serde/aggregates.scala | sed -n '690,710p'

Repository: martin-augment/datafusion-comet
Length of output: 1015
🏁 Script executed:
# Search for interval type references in the aggregates.scala file
rg -n "YearMonthIntervalType|DayTimeIntervalType|IntervalType" spark/src/main/scala/org/apache/comet/serde/aggregates.scalaRepository: martin-augment/datafusion-comet
Length of output: 57
🏁 Script executed:
# Search for percentile-related code and interval handling in the aggregates file
rg -n -A 5 -B 5 "Percentile" spark/src/main/scala/org/apache/comet/serde/aggregates.scala | head -80

Repository: martin-augment/datafusion-comet
Length of output: 2097
🏁 Script executed:
# Check if interval types are imported or used elsewhere in recent changes
rg -n "YearMonthIntervalType|DayTimeIntervalType" spark/src/main/scala/org/apache/comet/serde/Repository: martin-augment/datafusion-comet
Length of output: 592
🏁 Script executed:
# Search for percentile_cont in the codebase to see how it's handled
rg -n "percentile_cont|PercentileCont" --type scalaRepository: martin-augment/datafusion-comet
Length of output: 453
🏁 Script executed:
# Check the context around the interval type comment in math.scala
cat -n spark/src/main/scala/org/apache/comet/serde/math.scala | sed -n '170,190p'

Repository: martin-augment/datafusion-comet
Length of output: 946
🏁 Script executed:
# Check what operations in math.scala support intervals
rg -n -B 3 -A 3 "DayTimeIntervalType|YearMonthIntervalType" spark/src/main/scala/org/apache/comet/serde/math.scala | head -40

Repository: martin-augment/datafusion-comet
Length of output: 356
🌐 Web query:
Spark SQL percentile_cont function supported data types intervals
💡 Result:
The Spark SQL percentile_cont function supports numeric data types and ANSI interval types for the sortKey (the expression in the ORDER BY clause). It returns DOUBLE (or ARRAY) regardless of input type. Official Apache Spark documentation explicitly states: "numeric or ANSI interval column col". Databricks documentation describes sortKey as "a numeric expression", consistent with numeric support. ANSI intervals (DayTimeIntervalType and YearMonthIntervalType) are supported as per SPARK-39567 resolution and related functions like approx_percentile, percentile. Examples in docs use numeric columns (integers, decimals). The function is a runtime replacement for the exact Percentile aggregate, which requires numeric/interval inputs supporting sorting and interpolation to DOUBLE. Intervals are supported because they are orderable and can be interpolated linearly (e.g., INTERVAL '1.5' SECOND).
Citations:
- 1: https://spark.apache.org/docs/latest/sql-ref-functions-builtin.html
- 2: https://spark.apache.org/docs/4.2.0-preview2/api/sql/agg-functions/
- 3: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-aggregate.html
- 4: https://issues.apache.org/jira/browse/SPARK-39567
- 5: https://docs.databricks.com/aws/en/sql/language-manual/functions/percentile_cont
PercentileCont excludes ANSI interval inputs that Spark SQL supports.
Apache Spark's percentile_cont function accepts both numeric and ANSI interval types for the order-by column, as documented in the official Spark SQL reference. The current NumericType-only gate at lines 697–703 prevents interval-typed percentile_cont queries from using the native Comet path. While this PR adds interval type infrastructure elsewhere (type IDs in QueryPlanSerde), the aggregate serde must extend support to YearMonthIntervalType and DayTimeIntervalType to handle these cases end-to-end.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@spark/src/main/scala/org/apache/comet/serde/aggregates.scala` around lines
697 - 703, The NumericType-only gate in the PercentileCont aggregate serde
blocks ANSI interval inputs; update the pattern match on expr.child.dataType to
also accept YearMonthIntervalType and DayTimeIntervalType (treat them like
NumericType — no-op cases) so the native Comet path can handle interval-typed
order-by columns, leaving the withInfo(...) fallback unchanged; reference the
match around expr.child.dataType, PercentileCont, and withInfo(aggExpr, ...)
when making the change.
Cursor Bugbot has reviewed your changes and found 1 potential issue.
Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.
.with_ignore_nulls(false)
.with_distinct(false)
.build()
.map_err(|e| ExecutionError::DataFusionError(e.to_string()))
Inline builder duplicates existing helper function
Low Severity
The PercentileCont arm manually constructs an AggregateExprBuilder chain, but the existing Self::create_aggr_func_expr helper (used by all other aggregate expressions like BloomFilterAgg, Variance, Stddev, Correlation) does exactly the same thing. The inline version also uses a different error mapping (.map_err(|e| ExecutionError::DataFusionError(e.to_string()))) compared to the helper's .map_err(|e| e.into()), which loses error context. A single call to Self::create_aggr_func_expr("spark_percentile", schema, vec![child], func) would replace 7 lines.
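Roughly, the arm could collapse to a single call along these lines (a sketch assuming the helper signature quoted above; the surrounding bindings are whatever the arm already constructs):

// Hypothetical simplification of the PercentileCont arm: delegate to the
// shared helper instead of building the aggregate expression inline, which
// also keeps error mapping consistent with the other aggregates.
Self::create_aggr_func_expr("spark_percentile", schema, vec![child], func)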


3757: To review by AI