From c6368eaf0e24c63a8600984711bf3038f29ebbbd Mon Sep 17 00:00:00 2001 From: Bolin Lin Date: Sat, 21 Mar 2026 16:35:30 -0400 Subject: [PATCH 1/3] feat: support percentile cont --- native/core/src/execution/planner.rs | 56 +- native/core/src/execution/serde.rs | 6 + native/proto/src/proto/expr.proto | 8 + native/proto/src/proto/types.proto | 2 + native/spark-expr/src/agg_funcs/mod.rs | 2 + native/spark-expr/src/agg_funcs/percentile.rs | 481 ++++++++++++++++++ .../apache/comet/serde/QueryPlanSerde.scala | 3 + .../org/apache/comet/serde/aggregates.scala | 61 ++- .../expressions/aggregate/percentile_cont.sql | 77 +++ 9 files changed, 692 insertions(+), 4 deletions(-) create mode 100644 native/spark-expr/src/agg_funcs/percentile.rs create mode 100644 spark/src/test/resources/sql-tests/expressions/aggregate/percentile_cont.sql diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index e730fd0c89..a22acb1278 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -123,8 +123,8 @@ use datafusion_comet_proto::{ use datafusion_comet_spark_expr::{ ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct, DecimalRescaleCheckOverflow, GetArrayStructFields, GetStructField, IfExpr, ListExtract, - NormalizeNaNAndZero, SparkCastOptions, Stddev, SumDecimal, ToJson, UnboundColumn, Variance, - WideDecimalBinaryExpr, WideDecimalOp, + NormalizeNaNAndZero, Percentile, SparkCastOptions, Stddev, SumDecimal, ToJson, UnboundColumn, + Variance, WideDecimalBinaryExpr, WideDecimalOp, }; use itertools::Itertools; use jni::objects::GlobalRef; @@ -2206,6 +2206,58 @@ impl PhysicalPlanner { )); Self::create_aggr_func_expr("bloom_filter_agg", schema, vec![child], func) } + AggExprStruct::PercentileCont(expr) => { + let return_type = to_arrow_datatype(expr.datatype.as_ref().unwrap()); + let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; + + // Cast input to 
appropriate type based on return type + // For interval types, we preserve the type; for numeric types, cast to Float64 + let child = match &return_type { + DataType::Interval(_) => child, + _ => Arc::new(CastExpr::new(child, DataType::Float64, None)) as Arc, + }; + + // Extract the literal percentile value + let percentile_expr = + self.create_expr(expr.percentile.as_ref().unwrap(), Arc::clone(&schema))?; + let percentile_value = percentile_expr + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ExecutionError::GeneralError("percentile must be a literal".into()) + })? + .value() + .clone(); + + let percentile = match percentile_value { + ScalarValue::Float64(Some(p)) => p, + ScalarValue::Float32(Some(p)) => p as f64, + ScalarValue::Int64(Some(p)) => p as f64, + ScalarValue::Int32(Some(p)) => p as f64, + _ => { + return Err(ExecutionError::GeneralError(format!( + "percentile must be a numeric literal, got {:?}", + percentile_value + ))) + } + }; + + // Custom Spark-compatible Percentile implementation + let func = AggregateUDF::new_from_impl(Percentile::new( + "spark_percentile", + percentile, + expr.reverse, + return_type, + )); + + AggregateExprBuilder::new(Arc::new(func), vec![child]) + .schema(schema) + .alias("spark_percentile") + .with_ignore_nulls(false) + .with_distinct(false) + .build() + .map_err(|e| ExecutionError::DataFusionError(e.to_string())) + } } } diff --git a/native/core/src/execution/serde.rs b/native/core/src/execution/serde.rs index ae0554ee76..e11afb33e4 100644 --- a/native/core/src/execution/serde.rs +++ b/native/core/src/execution/serde.rs @@ -168,5 +168,11 @@ pub fn to_arrow_datatype(dt_value: &DataType) -> ArrowDataType { } _ => unreachable!(), }, + DataTypeId::YearMonthInterval => { + ArrowDataType::Interval(arrow::datatypes::IntervalUnit::YearMonth) + } + DataTypeId::DayTimeInterval => { + ArrowDataType::Interval(arrow::datatypes::IntervalUnit::DayTime) + } } } diff --git a/native/proto/src/proto/expr.proto 
b/native/proto/src/proto/expr.proto index 32cbc0ce13..1301b17957 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -139,6 +139,7 @@ message AggExpr { Stddev stddev = 14; Correlation correlation = 15; BloomFilterAgg bloomFilterAgg = 16; + PercentileCont percentileCont = 17; } // Optional QueryContext for error reporting (contains SQL text and position) @@ -243,6 +244,13 @@ message BloomFilterAgg { DataType datatype = 4; } +message PercentileCont { + Expr child = 1; // The column to compute percentile on + Expr percentile = 2; // The percentile value (0.0-1.0) + DataType datatype = 3; // Return type + bool reverse = 4; // True if ORDER BY DESC +} + enum EvalMode { LEGACY = 0; TRY = 1; diff --git a/native/proto/src/proto/types.proto b/native/proto/src/proto/types.proto index 2fd3d59a73..361607f14d 100644 --- a/native/proto/src/proto/types.proto +++ b/native/proto/src/proto/types.proto @@ -59,6 +59,8 @@ message DataType { LIST = 14; MAP = 15; STRUCT = 16; + YEAR_MONTH_INTERVAL = 17; + DAY_TIME_INTERVAL = 18; } DataTypeId type_id = 1; diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index b1027153e8..e71c49c17b 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -19,6 +19,7 @@ mod avg; mod avg_decimal; mod correlation; mod covariance; +mod percentile; mod stddev; mod sum_decimal; mod sum_int; @@ -28,6 +29,7 @@ pub use avg::Avg; pub use avg_decimal::AvgDecimal; pub use correlation::Correlation; pub use covariance::Covariance; +pub use percentile::Percentile; pub use stddev::Stddev; pub use sum_decimal::SumDecimal; pub use sum_int::SumInteger; diff --git a/native/spark-expr/src/agg_funcs/percentile.rs b/native/spark-expr/src/agg_funcs/percentile.rs new file mode 100644 index 0000000000..3cadbe3392 --- /dev/null +++ b/native/spark-expr/src/agg_funcs/percentile.rs @@ -0,0 +1,481 @@ +// Licensed to the Apache Software Foundation (ASF) under one 
+// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Spark-compatible Percentile aggregate function. +//! +//! This implementation matches Spark's `Percentile` class intermediate state format, +//! which uses a serialized map of (value -> frequency) stored as BinaryType. + +use arrow::array::{Array, ArrayRef, BinaryArray, Float64Array, IntervalDayTimeArray, IntervalYearMonthArray}; +use arrow::datatypes::{DataType, Field, FieldRef, IntervalDayTimeType, IntervalUnit}; +use datafusion::common::{Result, ScalarValue}; +use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, Signature}; +use datafusion::physical_expr::expressions::format_state_name; +use std::any::Any; +use std::collections::BTreeMap; +use std::sync::Arc; + +use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion::logical_expr::Volatility::Immutable; + +/// Spark-compatible Percentile aggregate function. +/// +/// Stores intermediate state as BinaryType containing serialized (value, count) pairs, +/// matching Spark's `TypedAggregateWithHashMapAsBuffer` format. 
+#[derive(Debug, Clone, PartialEq)] +pub struct Percentile { + name: String, + signature: Signature, + /// Percentile value stored as bits for Hash/Eq + percentile_bits: u64, + reverse: bool, + /// The return data type + return_type: DataType, +} + +impl std::hash::Hash for Percentile { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.percentile_bits.hash(state); + self.reverse.hash(state); + } +} + +impl Eq for Percentile {} + +impl Percentile { + pub fn new(name: impl Into, percentile: f64, reverse: bool, return_type: DataType) -> Self { + Self { + name: name.into(), + signature: Signature::any(1, Immutable), + percentile_bits: percentile.to_bits(), + reverse, + return_type, + } + } + + fn percentile(&self) -> f64 { + f64::from_bits(self.percentile_bits) + } +} + +impl AggregateUDFImpl for Percentile { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + // Match Spark's BinaryType state format + Ok(vec![Arc::new(Field::new( + format_state_name(args.name, "counts"), + DataType::Binary, + true, + ))]) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(PercentileAccumulator::new( + self.percentile(), + self.reverse, + self.return_type.clone(), + ))) + } + + fn default_value(&self, _data_type: &DataType) -> Result { + match &self.return_type { + DataType::Float64 => Ok(ScalarValue::Float64(None)), + DataType::Interval(IntervalUnit::YearMonth) => Ok(ScalarValue::IntervalYearMonth(None)), + DataType::Interval(IntervalUnit::DayTime) => Ok(ScalarValue::IntervalDayTime(None)), + _ => Ok(ScalarValue::Float64(None)), + } + } +} + +/// Accumulator for Percentile that stores (value -> count) map. 
+/// Values are stored as i64 regardless of input type to simplify the implementation. +#[derive(Debug)] +pub struct PercentileAccumulator { + /// Map of value (as i64 bits) -> frequency count (using BTreeMap for sorted iteration) + counts: BTreeMap, + /// The percentile to compute (0.0 to 1.0) + percentile: f64, + /// Whether to reverse the order (for DESC) + reverse: bool, + /// The return data type + return_type: DataType, +} + +impl PercentileAccumulator { + pub fn new(percentile: f64, reverse: bool, return_type: DataType) -> Self { + Self { + counts: BTreeMap::new(), + percentile, + reverse, + return_type, + } + } + + /// Serialize the counts map to Spark's binary format. + fn serialize(&self) -> Vec { + let mut buf = Vec::new(); + + for (&key, &count) in &self.counts { + // Each entry: [size: i32][key: i64][count: i64] + // Size = 8 (i64) + 8 (i64) = 16 bytes + let size: i32 = 16; + buf.extend_from_slice(&size.to_be_bytes()); + buf.extend_from_slice(&key.to_be_bytes()); + buf.extend_from_slice(&count.to_be_bytes()); + } + + // End marker + buf.extend_from_slice(&(-1i32).to_be_bytes()); + buf + } + + /// Deserialize counts map from Spark's binary format. + fn deserialize(bytes: &[u8]) -> Result> { + let mut counts = BTreeMap::new(); + let mut offset = 0; + + while offset + 4 <= bytes.len() { + let size = i32::from_be_bytes(bytes[offset..offset + 4].try_into().unwrap()); + offset += 4; + + if size < 0 { + // End marker + break; + } + + if offset + 16 > bytes.len() { + break; + } + + let key = i64::from_be_bytes(bytes[offset..offset + 8].try_into().unwrap()); + offset += 8; + let count = i64::from_be_bytes(bytes[offset..offset + 8].try_into().unwrap()); + offset += 8; + + counts.insert(key, count); + } + + Ok(counts) + } + + /// Compute the percentile from the accumulated counts. + /// Returns the result as i64 bits (can be interpreted as f64 bits or interval value). 
+ fn compute_percentile_i64(&self) -> Option { + if self.counts.is_empty() { + return None; + } + + // Get sorted (value, accumulated_count) pairs + let sorted_counts: Vec<(i64, i64)> = if self.reverse { + self.counts.iter().rev().map(|(&k, &v)| (k, v)).collect() + } else { + self.counts.iter().map(|(&k, &v)| (k, v)).collect() + }; + + // Compute accumulated counts + let mut accumulated: Vec<(i64, i64)> = Vec::with_capacity(sorted_counts.len()); + let mut total: i64 = 0; + for (value, count) in sorted_counts { + total += count; + accumulated.push((value, total)); + } + + let total_count = total; + if total_count == 0 { + return None; + } + + // Position in the distribution (0-indexed) + let position = (total_count - 1) as f64 * self.percentile; + let lower = position.floor() as i64; + let higher = position.ceil() as i64; + + // Binary search for lower and higher indices + let lower_idx = Self::binary_search_count(&accumulated, lower + 1); + let higher_idx = Self::binary_search_count(&accumulated, higher + 1); + + let lower_key = accumulated[lower_idx].0; + + if higher == lower { + // No interpolation needed + return Some(lower_key); + } + + let higher_key = accumulated[higher_idx].0; + + if lower_key == higher_key { + // Same key, no interpolation needed + return Some(lower_key); + } + + // Linear interpolation + let fraction = position - lower as f64; + + // Handle interpolation based on return type + match &self.return_type { + DataType::Float64 => { + // Interpret i64 bits as f64 + let lower_f = f64::from_bits(lower_key as u64); + let higher_f = f64::from_bits(higher_key as u64); + let result = (1.0 - fraction) * lower_f + fraction * higher_f; + Some(result.to_bits() as i64) + } + DataType::Interval(IntervalUnit::YearMonth) => { + // Values are i32 months stored as i64 + let lower_months = lower_key as i32; + let higher_months = higher_key as i32; + let result = (1.0 - fraction) * (lower_months as f64) + fraction * (higher_months as f64); + Some(result.round() 
as i64) + } + DataType::Interval(IntervalUnit::DayTime) => { + // Values are packed as (days << 32) | milliseconds + let lower_days = (lower_key >> 32) as i32; + let lower_ms = lower_key as i32; + let higher_days = (higher_key >> 32) as i32; + let higher_ms = higher_key as i32; + + // Convert to total milliseconds for interpolation + let lower_total_ms = (lower_days as i64) * 86_400_000 + (lower_ms as i64); + let higher_total_ms = (higher_days as i64) * 86_400_000 + (higher_ms as i64); + let result_ms = ((1.0 - fraction) * (lower_total_ms as f64) + fraction * (higher_total_ms as f64)).round() as i64; + + // Convert back to days and milliseconds + let result_days = (result_ms / 86_400_000) as i32; + let result_remaining_ms = (result_ms % 86_400_000) as i32; + + Some(((result_days as i64) << 32) | (result_remaining_ms as i64 & 0xFFFFFFFF)) + } + _ => Some(lower_key), + } + } + + /// Binary search to find the index where accumulated count >= target + fn binary_search_count(accumulated: &[(i64, i64)], target: i64) -> usize { + match accumulated.binary_search_by(|(_, count)| count.cmp(&target)) { + Ok(idx) => idx, + Err(idx) => idx.min(accumulated.len() - 1), + } + } +} + +impl Accumulator for PercentileAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array = &values[0]; + + match array.data_type() { + DataType::Float64 => { + let values = array.as_any().downcast_ref::().unwrap(); + for i in 0..values.len() { + if values.is_null(i) { + continue; + } + let key = values.value(i).to_bits() as i64; + *self.counts.entry(key).or_insert(0) += 1; + } + } + DataType::Interval(IntervalUnit::YearMonth) => { + let values = array.as_any().downcast_ref::().unwrap(); + for i in 0..values.len() { + if values.is_null(i) { + continue; + } + let key = values.value(i) as i64; + *self.counts.entry(key).or_insert(0) += 1; + } + } + DataType::Interval(IntervalUnit::DayTime) => { + let values = array.as_any().downcast_ref::().unwrap(); + for i in 
0..values.len() { + if values.is_null(i) { + continue; + } + // Convert IntervalDayTime struct to packed i64: (days << 32) | milliseconds + let (days, ms) = IntervalDayTimeType::to_parts(values.value(i)); + let key = ((days as i64) << 32) | (ms as i64 & 0xFFFFFFFF); + *self.counts.entry(key).or_insert(0) += 1; + } + } + _ => { + // Fallback: try to treat as Float64 + if let Some(values) = array.as_any().downcast_ref::() { + for i in 0..values.len() { + if values.is_null(i) { + continue; + } + let key = values.value(i).to_bits() as i64; + *self.counts.entry(key).or_insert(0) += 1; + } + } + } + } + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let binary_array = states[0].as_any().downcast_ref::().unwrap(); + + for i in 0..binary_array.len() { + if binary_array.is_null(i) { + continue; + } + let bytes = binary_array.value(i); + let other_counts = Self::deserialize(bytes)?; + + for (key, count) in other_counts { + *self.counts.entry(key).or_insert(0) += count; + } + } + + Ok(()) + } + + fn state(&mut self) -> Result> { + let bytes = self.serialize(); + Ok(vec![ScalarValue::Binary(Some(bytes))]) + } + + fn evaluate(&mut self) -> Result { + match self.compute_percentile_i64() { + Some(value) => match &self.return_type { + DataType::Float64 => Ok(ScalarValue::Float64(Some(f64::from_bits(value as u64)))), + DataType::Interval(IntervalUnit::YearMonth) => { + Ok(ScalarValue::IntervalYearMonth(Some(value as i32))) + } + DataType::Interval(IntervalUnit::DayTime) => { + // Unpack i64 to (days, milliseconds) and create IntervalDayTime struct + let days = (value >> 32) as i32; + let ms = value as i32; + Ok(ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(days, ms)))) + } + _ => Ok(ScalarValue::Float64(Some(f64::from_bits(value as u64)))), + }, + None => match &self.return_type { + DataType::Float64 => Ok(ScalarValue::Float64(None)), + DataType::Interval(IntervalUnit::YearMonth) => Ok(ScalarValue::IntervalYearMonth(None)), + 
DataType::Interval(IntervalUnit::DayTime) => Ok(ScalarValue::IntervalDayTime(None)), + _ => Ok(ScalarValue::Float64(None)), + }, + } + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + self.counts.len() * (std::mem::size_of::() * 2) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Float64Array; + use std::sync::Arc; + + #[test] + fn test_percentile_median() { + let mut acc = PercentileAccumulator::new(0.5, false, DataType::Float64); + let values: ArrayRef = Arc::new(Float64Array::from(vec![0.0, 10.0, 20.0, 30.0, 40.0])); + acc.update_batch(&[values]).unwrap(); + + let result = acc.evaluate().unwrap(); + assert_eq!(result, ScalarValue::Float64(Some(20.0))); + } + + #[test] + fn test_percentile_25th() { + let mut acc = PercentileAccumulator::new(0.25, false, DataType::Float64); + let values: ArrayRef = Arc::new(Float64Array::from(vec![0.0, 10.0, 20.0, 30.0, 40.0])); + acc.update_batch(&[values]).unwrap(); + + let result = acc.evaluate().unwrap(); + assert_eq!(result, ScalarValue::Float64(Some(10.0))); + } + + #[test] + fn test_percentile_serialize_deserialize() { + let mut acc = PercentileAccumulator::new(0.5, false, DataType::Float64); + let values: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])); + acc.update_batch(&[values]).unwrap(); + + let state = acc.state().unwrap(); + let bytes = match &state[0] { + ScalarValue::Binary(Some(b)) => b.clone(), + _ => panic!("Expected Binary state"), + }; + + let deserialized = PercentileAccumulator::deserialize(&bytes).unwrap(); + assert_eq!(deserialized.len(), 3); + } + + #[test] + fn test_percentile_reverse() { + // With DESC ordering, 25th percentile should equal 75th percentile of ASC + let mut acc_asc = PercentileAccumulator::new(0.75, false, DataType::Float64); + let mut acc_desc = PercentileAccumulator::new(0.25, true, DataType::Float64); + + let values: ArrayRef = Arc::new(Float64Array::from(vec![0.0, 10.0, 20.0, 30.0, 40.0])); + 
acc_asc.update_batch(&[values.clone()]).unwrap(); + acc_desc.update_batch(&[values]).unwrap(); + + let result_asc = acc_asc.evaluate().unwrap(); + let result_desc = acc_desc.evaluate().unwrap(); + assert_eq!(result_asc, result_desc); + } + + #[test] + fn test_percentile_year_month_interval() { + let mut acc = PercentileAccumulator::new(0.5, false, DataType::Interval(IntervalUnit::YearMonth)); + let values: ArrayRef = Arc::new(IntervalYearMonthArray::from(vec![0, 10, 20, 30, 40])); + acc.update_batch(&[values]).unwrap(); + + let result = acc.evaluate().unwrap(); + assert_eq!(result, ScalarValue::IntervalYearMonth(Some(20))); + } + + #[test] + fn test_percentile_day_time_interval() { + let mut acc = PercentileAccumulator::new(0.5, false, DataType::Interval(IntervalUnit::DayTime)); + // Create intervals: 1 day, 2 days, 3 days, 4 days, 5 days (no milliseconds) + let values: ArrayRef = Arc::new(IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(1, 0), + IntervalDayTimeType::make_value(2, 0), + IntervalDayTimeType::make_value(3, 0), + IntervalDayTimeType::make_value(4, 0), + IntervalDayTimeType::make_value(5, 0), + ])); + acc.update_batch(&[values]).unwrap(); + + let result = acc.evaluate().unwrap(); + // Median should be 3 days + assert_eq!(result, ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(3, 0)))); + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 8c39ba779d..74e4fdb567 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -263,6 +263,7 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[Last] -> CometLast, classOf[Max] -> CometMax, classOf[Min] -> CometMin, + classOf[Percentile] -> CometPercentile, classOf[StddevPop] -> CometStddevPop, classOf[StddevSamp] -> CometStddevSamp, classOf[Sum] -> CometSum, @@ 
-370,6 +371,8 @@ object QueryPlanSerde extends Logging with CometExprShim { case _: ArrayType => 14 case _: MapType => 15 case _: StructType => 16 + case _: YearMonthIntervalType => 17 + case _: DayTimeIntervalType => 18 case dt => logWarning(s"Cannot serialize Spark data type: $dt") return None diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 1485589b46..6746a5f825 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -22,9 +22,10 @@ package org.apache.comet.serde import scala.jdk.CollectionConverters._ import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, Percentile, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType} +import org.apache.spark.sql.types.{ArrayType, ByteType, DataTypes, DayTimeIntervalType, DecimalType, IntegerType, LongType, NumericType, , StringType, YearMonthIntervalType} import org.apache.comet.CometConf import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT @@ -671,6 +672,62 @@ object CometBloomFilterAggregate extends CometAggregateExpressionSerde[BloomFilt } } +object CometPercentile extends 
CometAggregateExpressionSerde[Percentile] { + override def convert( + aggExpr: AggregateExpression, + expr: Percentile, + inputs: Seq[Attribute], + binding: Boolean, + conf: SQLConf): Option[ExprOuterClass.AggExpr] = { + + // Only support when frequency is Literal(1L) - i.e., percentile_cont behavior + expr.frequencyExpression match { + case Literal(1L, LongType) => + case _ => + withInfo(aggExpr, "weighted percentile not supported") + return None + } + + // Only support scalar percentile, not array of percentiles + if (expr.percentageExpression.dataType.isInstanceOf[ArrayType]) { + withInfo(aggExpr, "array of percentiles not supported") + return None + } + + // Support numeric types and interval types + expr.child.dataType match { + case _: NumericType => + case _: DecimalType => + case _: YearMonthIntervalType => + case _: DayTimeIntervalType => + case _ => + withInfo(aggExpr, s"unsupported input type: ${expr.child.dataType}") + return None + } + + val childExpr = exprToProto(expr.child, inputs, binding) + val percentileExpr = exprToProto(expr.percentageExpression, inputs, binding) + val dataType = serializeDataType(expr.dataType) + + if (childExpr.isDefined && percentileExpr.isDefined && dataType.isDefined) { + val builder = ExprOuterClass.PercentileCont.newBuilder() + builder.setChild(childExpr.get) + builder.setPercentile(percentileExpr.get) + builder.setDatatype(dataType.get) + builder.setReverse(expr.reverse) + + Some( + ExprOuterClass.AggExpr + .newBuilder() + .setPercentileCont(builder) + .build()) + } else { + withInfo(aggExpr, expr.child, expr.percentageExpression) + None + } + } +} + object AggSerde { import org.apache.spark.sql.types._ diff --git a/spark/src/test/resources/sql-tests/expressions/aggregate/percentile_cont.sql b/spark/src/test/resources/sql-tests/expressions/aggregate/percentile_cont.sql new file mode 100644 index 0000000000..5cb61c6610 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/aggregate/percentile_cont.sql @@ 
-0,0 +1,77 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Tests for percentile_cont aggregate function +-- Uses similar test data as Spark's percentiles.sql + +statement +CREATE TABLE test_percentile(k int, v int) USING parquet + +statement +INSERT INTO test_percentile VALUES (0, 0), (0, 10), (0, 20), (0, 30), (0, 40), (1, 10), (1, 20), (2, 10), (2, 20), (2, 25), (2, 30), (3, 60), (4, NULL) + +-- Basic percentile_cont (25th percentile) - should match Spark result: 10.0 +query +SELECT percentile_cont(0.25) WITHIN GROUP (ORDER BY v) FROM test_percentile + +-- percentile_cont with DESC ordering - should match Spark result: 30.0 +query +SELECT percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC) FROM test_percentile + +-- percentile_cont with GROUP BY - should match Spark results +query +SELECT k, percentile_cont(0.25) WITHIN GROUP (ORDER BY v) FROM test_percentile GROUP BY k ORDER BY k + +-- percentile_cont with GROUP BY and DESC - should match Spark results +query +SELECT k, percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC) FROM test_percentile GROUP BY k ORDER BY k + +-- median (50th percentile) +query +SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY v) FROM test_percentile + +-- Multiple 
percentile_cont in same query +query +SELECT percentile_cont(0.25) WITHIN GROUP (ORDER BY v), percentile_cont(0.75) WITHIN GROUP (ORDER BY v) FROM test_percentile + +-- Tests for interval types +statement +CREATE TABLE test_interval ( + id INT, + ym INTERVAL YEAR TO MONTH, + dt INTERVAL DAY TO SECOND +) USING parquet + +statement +INSERT INTO test_interval VALUES + (1, INTERVAL '1' YEAR, INTERVAL '1' DAY), + (2, INTERVAL '2' YEAR, INTERVAL '2' DAY), + (3, INTERVAL '3' YEAR, INTERVAL '3' DAY), + (4, INTERVAL '4' YEAR, INTERVAL '4' DAY), + (5, INTERVAL '5' YEAR, INTERVAL '5' DAY) + +-- percentile_cont with YearMonthIntervalType +query +SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY ym) FROM test_interval + +-- percentile_cont with DayTimeIntervalType +query +SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY dt) FROM test_interval + +-- percentile_cont with interval types and GROUP BY +query +SELECT id % 2 AS grp, percentile_cont(0.5) WITHIN GROUP (ORDER BY ym) FROM test_interval GROUP BY id % 2 ORDER BY grp From a9321710e5c2c240968a2978f64a7bd9d6248846 Mon Sep 17 00:00:00 2001 From: Bolin Lin Date: Mon, 23 Mar 2026 09:15:13 -0400 Subject: [PATCH 2/3] fix: remove decimal type --- spark/src/main/scala/org/apache/comet/serde/aggregates.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 6746a5f825..a57159577c 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, 
Percentile, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ArrayType, ByteType, DataTypes, DayTimeIntervalType, DecimalType, IntegerType, LongType, NumericType, , StringType, YearMonthIntervalType} +import org.apache.spark.sql.types.{ArrayType, ByteType, DataTypes, DayTimeIntervalType, DecimalType, IntegerType, LongType, NumericType, ShortType, StringType, YearMonthIntervalType} import org.apache.comet.CometConf import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT @@ -697,7 +697,6 @@ object CometPercentile extends CometAggregateExpressionSerde[Percentile] { // Support numeric types and interval types expr.child.dataType match { case _: NumericType => - case _: DecimalType => case _: YearMonthIntervalType => case _: DayTimeIntervalType => case _ => From ca3fa2e96f013d70ca9ca3dc8aa4587b422bd28c Mon Sep 17 00:00:00 2001 From: Bolin Lin Date: Mon, 23 Mar 2026 18:47:44 -0400 Subject: [PATCH 3/3] fix: negative value sorting --- native/core/src/execution/planner.rs | 9 +- native/spark-expr/src/agg_funcs/percentile.rs | 190 +++++++----------- .../org/apache/comet/serde/aggregates.scala | 6 +- .../expressions/aggregate/percentile_cont.sql | 132 ++++++++++-- 4 files changed, 191 insertions(+), 146 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 2175670361..4c623a7173 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -2271,12 +2271,9 @@ impl PhysicalPlanner { let return_type = to_arrow_datatype(expr.datatype.as_ref().unwrap()); let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; - // Cast input to appropriate type based on return type - // For interval types, we preserve the type; for numeric types, cast to Float64 - let child = match &return_type { - DataType::Interval(_) => child, - _ => Arc::new(CastExpr::new(child, DataType::Float64, 
None)) as Arc, - }; + // Cast input to Float64 for numeric types + let child = + Arc::new(CastExpr::new(child, DataType::Float64, None)) as Arc; // Extract the literal percentile value let percentile_expr = diff --git a/native/spark-expr/src/agg_funcs/percentile.rs b/native/spark-expr/src/agg_funcs/percentile.rs index 3cadbe3392..d8de36a137 100644 --- a/native/spark-expr/src/agg_funcs/percentile.rs +++ b/native/spark-expr/src/agg_funcs/percentile.rs @@ -20,8 +20,8 @@ //! This implementation matches Spark's `Percentile` class intermediate state format, //! which uses a serialized map of (value -> frequency) stored as BinaryType. -use arrow::array::{Array, ArrayRef, BinaryArray, Float64Array, IntervalDayTimeArray, IntervalYearMonthArray}; -use arrow::datatypes::{DataType, Field, FieldRef, IntervalDayTimeType, IntervalUnit}; +use arrow::array::{Array, ArrayRef, BinaryArray, Float64Array}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion::common::{Result, ScalarValue}; use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, Signature}; use datafusion::physical_expr::expressions::format_state_name; @@ -108,12 +108,7 @@ impl AggregateUDFImpl for Percentile { } fn default_value(&self, _data_type: &DataType) -> Result { - match &self.return_type { - DataType::Float64 => Ok(ScalarValue::Float64(None)), - DataType::Interval(IntervalUnit::YearMonth) => Ok(ScalarValue::IntervalYearMonth(None)), - DataType::Interval(IntervalUnit::DayTime) => Ok(ScalarValue::IntervalDayTime(None)), - _ => Ok(ScalarValue::Float64(None)), - } + Ok(ScalarValue::Float64(None)) } } @@ -195,12 +190,27 @@ impl PercentileAccumulator { return None; } - // Get sorted (value, accumulated_count) pairs - let sorted_counts: Vec<(i64, i64)> = if self.reverse { - self.counts.iter().rev().map(|(&k, &v)| (k, v)).collect() - } else { - self.counts.iter().map(|(&k, &v)| (k, v)).collect() - }; + // Collect entries and sort by actual f64 value (not bit pattern). 
+ // We can't rely on BTreeMap's i64 ordering because f64 bit patterns + // don't preserve numeric ordering for negative values. + let mut entries: Vec<(f64, i64)> = self + .counts + .iter() + .map(|(&bits, &count)| (f64::from_bits(bits as u64), count)) + .collect(); + + // Sort by f64 value + entries.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)); + + if self.reverse { + entries.reverse(); + } + + // Convert back to (bits, count) for the rest of the computation + let sorted_counts: Vec<(i64, i64)> = entries + .into_iter() + .map(|(f, count)| (f.to_bits() as i64, count)) + .collect(); // Compute accumulated counts let mut accumulated: Vec<(i64, i64)> = Vec::with_capacity(sorted_counts.len()); @@ -250,31 +260,6 @@ impl PercentileAccumulator { let result = (1.0 - fraction) * lower_f + fraction * higher_f; Some(result.to_bits() as i64) } - DataType::Interval(IntervalUnit::YearMonth) => { - // Values are i32 months stored as i64 - let lower_months = lower_key as i32; - let higher_months = higher_key as i32; - let result = (1.0 - fraction) * (lower_months as f64) + fraction * (higher_months as f64); - Some(result.round() as i64) - } - DataType::Interval(IntervalUnit::DayTime) => { - // Values are packed as (days << 32) | milliseconds - let lower_days = (lower_key >> 32) as i32; - let lower_ms = lower_key as i32; - let higher_days = (higher_key >> 32) as i32; - let higher_ms = higher_key as i32; - - // Convert to total milliseconds for interpolation - let lower_total_ms = (lower_days as i64) * 86_400_000 + (lower_ms as i64); - let higher_total_ms = (higher_days as i64) * 86_400_000 + (higher_ms as i64); - let result_ms = ((1.0 - fraction) * (lower_total_ms as f64) + fraction * (higher_total_ms as f64)).round() as i64; - - // Convert back to days and milliseconds - let result_days = (result_ms / 86_400_000) as i32; - let result_remaining_ms = (result_ms % 86_400_000) as i32; - - Some(((result_days as i64) << 32) | (result_remaining_ms as i64 & 
0xFFFFFFFF)) - } _ => Some(lower_key), } } @@ -291,52 +276,14 @@ impl PercentileAccumulator { impl Accumulator for PercentileAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let array = &values[0]; + let values = array.as_any().downcast_ref::().unwrap(); - match array.data_type() { - DataType::Float64 => { - let values = array.as_any().downcast_ref::().unwrap(); - for i in 0..values.len() { - if values.is_null(i) { - continue; - } - let key = values.value(i).to_bits() as i64; - *self.counts.entry(key).or_insert(0) += 1; - } - } - DataType::Interval(IntervalUnit::YearMonth) => { - let values = array.as_any().downcast_ref::().unwrap(); - for i in 0..values.len() { - if values.is_null(i) { - continue; - } - let key = values.value(i) as i64; - *self.counts.entry(key).or_insert(0) += 1; - } - } - DataType::Interval(IntervalUnit::DayTime) => { - let values = array.as_any().downcast_ref::().unwrap(); - for i in 0..values.len() { - if values.is_null(i) { - continue; - } - // Convert IntervalDayTime struct to packed i64: (days << 32) | milliseconds - let (days, ms) = IntervalDayTimeType::to_parts(values.value(i)); - let key = ((days as i64) << 32) | (ms as i64 & 0xFFFFFFFF); - *self.counts.entry(key).or_insert(0) += 1; - } - } - _ => { - // Fallback: try to treat as Float64 - if let Some(values) = array.as_any().downcast_ref::() { - for i in 0..values.len() { - if values.is_null(i) { - continue; - } - let key = values.value(i).to_bits() as i64; - *self.counts.entry(key).or_insert(0) += 1; - } - } + for i in 0..values.len() { + if values.is_null(i) { + continue; } + let key = values.value(i).to_bits() as i64; + *self.counts.entry(key).or_insert(0) += 1; } Ok(()) @@ -367,25 +314,8 @@ impl Accumulator for PercentileAccumulator { fn evaluate(&mut self) -> Result { match self.compute_percentile_i64() { - Some(value) => match &self.return_type { - DataType::Float64 => Ok(ScalarValue::Float64(Some(f64::from_bits(value as u64)))), - 
DataType::Interval(IntervalUnit::YearMonth) => { - Ok(ScalarValue::IntervalYearMonth(Some(value as i32))) - } - DataType::Interval(IntervalUnit::DayTime) => { - // Unpack i64 to (days, milliseconds) and create IntervalDayTime struct - let days = (value >> 32) as i32; - let ms = value as i32; - Ok(ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(days, ms)))) - } - _ => Ok(ScalarValue::Float64(Some(f64::from_bits(value as u64)))), - }, - None => match &self.return_type { - DataType::Float64 => Ok(ScalarValue::Float64(None)), - DataType::Interval(IntervalUnit::YearMonth) => Ok(ScalarValue::IntervalYearMonth(None)), - DataType::Interval(IntervalUnit::DayTime) => Ok(ScalarValue::IntervalDayTime(None)), - _ => Ok(ScalarValue::Float64(None)), - }, + Some(value) => Ok(ScalarValue::Float64(Some(f64::from_bits(value as u64)))), + None => Ok(ScalarValue::Float64(None)), } } @@ -452,30 +382,54 @@ mod tests { } #[test] - fn test_percentile_year_month_interval() { - let mut acc = PercentileAccumulator::new(0.5, false, DataType::Interval(IntervalUnit::YearMonth)); - let values: ArrayRef = Arc::new(IntervalYearMonthArray::from(vec![0, 10, 20, 30, 40])); + fn test_percentile_negative_values() { + // Test that negative values are sorted correctly + // Values: -100, -50, 50, 100 + // Sorted: -100, -50, 50, 100 + // Median (50th percentile) with 4 values: + // position = (4-1) * 0.5 = 1.5 + // lower_idx = 1 (-50), upper_idx = 2 (50) + // result = 0.5 * (-50) + 0.5 * 50 = 0 + let mut acc = PercentileAccumulator::new(0.5, false, DataType::Float64); + let values: ArrayRef = + Arc::new(Float64Array::from(vec![-100.0, -50.0, 50.0, 100.0])); acc.update_batch(&[values]).unwrap(); let result = acc.evaluate().unwrap(); - assert_eq!(result, ScalarValue::IntervalYearMonth(Some(20))); + assert_eq!(result, ScalarValue::Float64(Some(0.0))); } #[test] - fn test_percentile_day_time_interval() { - let mut acc = PercentileAccumulator::new(0.5, false, 
DataType::Interval(IntervalUnit::DayTime));
-        // Create intervals: 1 day, 2 days, 3 days, 4 days, 5 days (no milliseconds)
-        let values: ArrayRef = Arc::new(IntervalDayTimeArray::from(vec![
-            IntervalDayTimeType::make_value(1, 0),
-            IntervalDayTimeType::make_value(2, 0),
-            IntervalDayTimeType::make_value(3, 0),
-            IntervalDayTimeType::make_value(4, 0),
-            IntervalDayTimeType::make_value(5, 0),
-        ]));
+    fn test_percentile_all_negative() {
+        // Test negative values mixed with zero and positive values
+        // Values: -50, -20, 0, 10, 30
+        // Sorted: -50, -20, 0, 10, 30
+        // Median (50th percentile) with 5 values:
+        //   position = (5-1) * 0.5 = 2
+        //   result = value at index 2 = 0
+        let mut acc = PercentileAccumulator::new(0.5, false, DataType::Float64);
+        let values: ArrayRef =
+            Arc::new(Float64Array::from(vec![-50.0, -20.0, 0.0, 10.0, 30.0]));
+        acc.update_batch(&[values]).unwrap();
+
+        let result = acc.evaluate().unwrap();
+        assert_eq!(result, ScalarValue::Float64(Some(0.0)));
+    }
+
+    #[test]
+    fn test_percentile_negative_25th() {
+        // Test 25th percentile with negative values
+        // Values: -100, -50, 50, 100
+        // position = (4-1) * 0.25 = 0.75
+        // lower_idx = 0 (-100), upper_idx = 1 (-50)
+        // result = 0.25 * (-100) + 0.75 * (-50) = -25 + -37.5 = -62.5
+        // Actually: (1-0.75)*(-100) + 0.75*(-50) = 0.25*(-100) + 0.75*(-50) = -25 - 37.5 = -62.5
+        let mut acc = PercentileAccumulator::new(0.25, false, DataType::Float64);
+        let values: ArrayRef =
+            Arc::new(Float64Array::from(vec![-100.0, -50.0, 50.0, 100.0]));
         acc.update_batch(&[values]).unwrap();
 
         let result = acc.evaluate().unwrap();
-        // Median should be 3 days
-        assert_eq!(result, ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(3, 0))));
+        assert_eq!(result, ScalarValue::Float64(Some(-62.5)));
     }
 }
diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
index a57159577c..145027f229 100644
--- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, Percentile, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ArrayType, ByteType, DataTypes, DayTimeIntervalType, DecimalType, IntegerType, LongType, NumericType, ShortType, StringType, YearMonthIntervalType} +import org.apache.spark.sql.types.{ArrayType, ByteType, DataTypes, DecimalType, IntegerType, LongType, NumericType, ShortType, StringType} import org.apache.comet.CometConf import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT @@ -694,11 +694,9 @@ object CometPercentile extends CometAggregateExpressionSerde[Percentile] { return None } - // Support numeric types and interval types + // Support numeric types (includes DecimalType) expr.child.dataType match { case _: NumericType => - case _: YearMonthIntervalType => - case _: DayTimeIntervalType => case _ => withInfo(aggExpr, s"unsupported input type: ${expr.child.dataType}") return None diff --git a/spark/src/test/resources/sql-tests/expressions/aggregate/percentile_cont.sql b/spark/src/test/resources/sql-tests/expressions/aggregate/percentile_cont.sql index 5cb61c6610..ce802a9c1c 100644 --- a/spark/src/test/resources/sql-tests/expressions/aggregate/percentile_cont.sql +++ b/spark/src/test/resources/sql-tests/expressions/aggregate/percentile_cont.sql @@ -48,30 +48,126 @@ SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY v) FROM test_percentile query SELECT percentile_cont(0.25) WITHIN GROUP (ORDER BY v), percentile_cont(0.75) WITHIN GROUP (ORDER BY v) FROM test_percentile --- Tests 
for interval types +-- ============================================================ +-- Tests for negative values (sort order correctness is critical) +-- ============================================================ + +statement +CREATE TABLE test_negative(k int, v int) USING parquet + +statement +INSERT INTO test_negative VALUES (0, -50), (0, -20), (0, 0), (0, 10), (0, 30), (1, -100), (1, -50), (1, 50), (1, 100) + +-- Negative values with ASC ordering +query +SELECT percentile_cont(0.25) WITHIN GROUP (ORDER BY v) FROM test_negative + +-- Negative values with DESC ordering +query +SELECT percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC) FROM test_negative + +-- Negative values with GROUP BY +query +SELECT k, percentile_cont(0.5) WITHIN GROUP (ORDER BY v) FROM test_negative GROUP BY k ORDER BY k + +-- Negative values median +query +SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY v) FROM test_negative + +-- ============================================================ +-- Tests for boundary percentiles (0.0 and 1.0) +-- ============================================================ + +-- 0th percentile (minimum value) +query +SELECT percentile_cont(0.0) WITHIN GROUP (ORDER BY v) FROM test_percentile + +-- 100th percentile (maximum value) +query +SELECT percentile_cont(1.0) WITHIN GROUP (ORDER BY v) FROM test_percentile + +-- Boundary percentiles with negative values +query +SELECT percentile_cont(0.0) WITHIN GROUP (ORDER BY v), percentile_cont(1.0) WITHIN GROUP (ORDER BY v) FROM test_negative + +-- Boundary percentiles with GROUP BY +query +SELECT k, percentile_cont(0.0) WITHIN GROUP (ORDER BY v), percentile_cont(1.0) WITHIN GROUP (ORDER BY v) FROM test_negative GROUP BY k ORDER BY k + +-- ============================================================ +-- Tests for all-null groups and single-value groups +-- ============================================================ + statement -CREATE TABLE test_interval ( - id INT, - ym INTERVAL YEAR TO MONTH, - dt INTERVAL DAY 
TO SECOND -) USING parquet +CREATE TABLE test_edge_cases(k int, v int) USING parquet statement -INSERT INTO test_interval VALUES - (1, INTERVAL '1' YEAR, INTERVAL '1' DAY), - (2, INTERVAL '2' YEAR, INTERVAL '2' DAY), - (3, INTERVAL '3' YEAR, INTERVAL '3' DAY), - (4, INTERVAL '4' YEAR, INTERVAL '4' DAY), - (5, INTERVAL '5' YEAR, INTERVAL '5' DAY) +INSERT INTO test_edge_cases VALUES (0, NULL), (0, NULL), (1, 42), (2, 10), (2, 10), (2, 10) + +-- All-null group should return NULL +query +SELECT k, percentile_cont(0.5) WITHIN GROUP (ORDER BY v) FROM test_edge_cases GROUP BY k ORDER BY k + +-- Single value group +query +SELECT k, percentile_cont(0.25) WITHIN GROUP (ORDER BY v) FROM test_edge_cases WHERE k = 1 GROUP BY k + +-- All same values in group +query +SELECT k, percentile_cont(0.5) WITHIN GROUP (ORDER BY v) FROM test_edge_cases WHERE k = 2 GROUP BY k + +-- Empty result (no rows match) +query +SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY v) FROM test_edge_cases WHERE k = 999 + +-- ============================================================ +-- Tests for DOUBLE column type +-- ============================================================ + +statement +CREATE TABLE test_double(k int, v double) USING parquet + +statement +INSERT INTO test_double VALUES (0, -1.5), (0, 0.0), (0, 1.5), (0, 3.0), (1, -100.25), (1, 0.5), (1, 100.75), (2, NULL) + +-- Double values basic +query +SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY v) FROM test_double + +-- Double values with GROUP BY +query +SELECT k, percentile_cont(0.25) WITHIN GROUP (ORDER BY v) FROM test_double GROUP BY k ORDER BY k + +-- Double boundary percentiles +query +SELECT percentile_cont(0.0) WITHIN GROUP (ORDER BY v), percentile_cont(1.0) WITHIN GROUP (ORDER BY v) FROM test_double WHERE k = 0 + +-- Double with DESC ordering +query +SELECT percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC) FROM test_double WHERE k = 1 + +-- ============================================================ +-- Tests for FLOAT 
column type +-- ============================================================ + +statement +CREATE TABLE test_float(k int, v float) USING parquet + +statement +INSERT INTO test_float VALUES (0, -2.5), (0, -0.5), (0, 0.5), (0, 2.5), (1, -50.0), (1, 0.0), (1, 50.0) + +-- Float values basic +query +SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY v) FROM test_float --- percentile_cont with YearMonthIntervalType +-- Float values with GROUP BY query -SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY ym) FROM test_interval +SELECT k, percentile_cont(0.5) WITHIN GROUP (ORDER BY v) FROM test_float GROUP BY k ORDER BY k --- percentile_cont with DayTimeIntervalType +-- Float boundary percentiles query -SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY dt) FROM test_interval +SELECT percentile_cont(0.0) WITHIN GROUP (ORDER BY v), percentile_cont(1.0) WITHIN GROUP (ORDER BY v) FROM test_float --- percentile_cont with interval types and GROUP BY +-- Float with negative values and DESC query -SELECT id % 2 AS grp, percentile_cont(0.5) WITHIN GROUP (ORDER BY ym) FROM test_interval GROUP BY id % 2 ORDER BY grp +SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY v DESC) FROM test_float WHERE k = 0