diff --git a/datafusion/functions-aggregate-common/src/tdigest.rs b/datafusion/functions-aggregate-common/src/tdigest.rs
index 320157fb7bd83..85bf2e1b4cabb 100644
--- a/datafusion/functions-aggregate-common/src/tdigest.rs
+++ b/datafusion/functions-aggregate-common/src/tdigest.rs
@@ -32,7 +32,6 @@
use arrow::datatypes::DataType;
use arrow::datatypes::Float64Type;
use datafusion_common::cast::as_primitive_array;
-use datafusion_common::Result;
use datafusion_common::ScalarValue;
use std::cmp::Ordering;
use std::mem::{size_of, size_of_val};
@@ -61,41 +60,6 @@ macro_rules! cast_scalar_u64 {
};
}
-/// This trait is implemented for each type a [`TDigest`] can operate on,
-/// allowing it to support both numerical rust types (obtained from
-/// `PrimitiveArray` instances), and [`ScalarValue`] instances.
-pub trait TryIntoF64 {
- /// A fallible conversion of a possibly null `self` into a [`f64`].
- ///
- /// If `self` is null, this method must return `Ok(None)`.
- ///
- /// If `self` cannot be coerced to the desired type, this method must return
- /// an `Err` variant.
- fn try_as_f64(&self) -> Result>;
-}
-
-/// Generate an infallible conversion from `type` to an [`f64`].
-macro_rules! impl_try_ordered_f64 {
- ($type:ty) => {
- impl TryIntoF64 for $type {
- fn try_as_f64(&self) -> Result > {
- Ok(Some(*self as f64))
- }
- }
- };
-}
-
-impl_try_ordered_f64!(f64);
-impl_try_ordered_f64!(f32);
-impl_try_ordered_f64!(i64);
-impl_try_ordered_f64!(i32);
-impl_try_ordered_f64!(i16);
-impl_try_ordered_f64!(i8);
-impl_try_ordered_f64!(u64);
-impl_try_ordered_f64!(u32);
-impl_try_ordered_f64!(u16);
-impl_try_ordered_f64!(u8);
-
/// Centroid implementation to the cluster mentioned in the paper.
#[derive(Debug, PartialEq, Clone)]
pub struct Centroid {
diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs
index 530dbf3e43c79..1d0d7a318ae13 100644
--- a/datafusion/functions-aggregate/src/approx_median.rs
+++ b/datafusion/functions-aggregate/src/approx_median.rs
@@ -19,16 +19,18 @@
use arrow::datatypes::DataType::{Float64, UInt64};
use arrow::datatypes::{DataType, Field, FieldRef};
+use datafusion_common::types::NativeType;
+use datafusion_functions_aggregate_common::noop_accumulator::NoopAccumulator;
use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;
-use datafusion_common::{not_impl_err, plan_err, Result};
+use datafusion_common::{not_impl_err, Result};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
-use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
- Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
+ Accumulator, AggregateUDFImpl, Coercion, Documentation, Signature, TypeSignature,
+ TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;
@@ -57,20 +59,11 @@ make_udaf_expr_and_func!(
```"#,
standard_argument(name = "expression",)
)]
-#[derive(PartialEq, Eq, Hash)]
+#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ApproxMedian {
signature: Signature,
}
-impl Debug for ApproxMedian {
- fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
- f.debug_struct("ApproxMedian")
- .field("name", &self.name())
- .field("signature", &self.signature)
- .finish()
- }
-}
-
impl Default for ApproxMedian {
fn default() -> Self {
Self::new()
@@ -81,33 +74,53 @@ impl ApproxMedian {
/// Create a new APPROX_MEDIAN aggregate function
pub fn new() -> Self {
Self {
- signature: Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable),
+ signature: Signature::one_of(
+ vec![
+ TypeSignature::Coercible(vec![Coercion::new_exact(
+ TypeSignatureClass::Integer,
+ )]),
+ TypeSignature::Coercible(vec![Coercion::new_implicit(
+ TypeSignatureClass::Float,
+ vec![TypeSignatureClass::Decimal],
+ NativeType::Float64,
+ )]),
+ ],
+ Volatility::Immutable,
+ ),
}
}
}
impl AggregateUDFImpl for ApproxMedian {
- /// Return a reference to Any that can be used for downcasting
fn as_any(&self) -> &dyn Any {
self
}
fn state_fields(&self, args: StateFieldsArgs) -> Result> {
- Ok(vec![
- Field::new(format_state_name(args.name, "max_size"), UInt64, false),
- Field::new(format_state_name(args.name, "sum"), Float64, false),
- Field::new(format_state_name(args.name, "count"), UInt64, false),
- Field::new(format_state_name(args.name, "max"), Float64, false),
- Field::new(format_state_name(args.name, "min"), Float64, false),
- Field::new_list(
- format_state_name(args.name, "centroids"),
- Field::new_list_field(Float64, true),
- false,
- ),
- ]
- .into_iter()
- .map(Arc::new)
- .collect())
+ if args.input_fields[0].data_type().is_null() {
+ Ok(vec![Field::new(
+ format_state_name(args.name, self.name()),
+ DataType::Null,
+ true,
+ )
+ .into()])
+ } else {
+ Ok(vec![
+ Field::new(format_state_name(args.name, "max_size"), UInt64, false),
+ Field::new(format_state_name(args.name, "sum"), Float64, false),
+ Field::new(format_state_name(args.name, "count"), UInt64, false),
+ Field::new(format_state_name(args.name, "max"), Float64, false),
+ Field::new(format_state_name(args.name, "min"), Float64, false),
+ Field::new_list(
+ format_state_name(args.name, "centroids"),
+ Field::new_list_field(Float64, true),
+ false,
+ ),
+ ]
+ .into_iter()
+ .map(Arc::new)
+ .collect())
+ }
}
fn name(&self) -> &str {
@@ -119,9 +132,6 @@ impl AggregateUDFImpl for ApproxMedian {
}
fn return_type(&self, arg_types: &[DataType]) -> Result {
- if !arg_types[0].is_numeric() {
- return plan_err!("ApproxMedian requires numeric input types");
- }
Ok(arg_types[0].clone())
}
@@ -132,10 +142,14 @@ impl AggregateUDFImpl for ApproxMedian {
);
}
- Ok(Box::new(ApproxPercentileAccumulator::new(
- 0.5_f64,
- acc_args.expr_fields[0].data_type().clone(),
- )))
+ if acc_args.expr_fields[0].data_type().is_null() {
+ Ok(Box::new(NoopAccumulator::default()))
+ } else {
+ Ok(Box::new(ApproxPercentileAccumulator::new(
+ 0.5_f64,
+ acc_args.expr_fields[0].data_type().clone(),
+ )))
+ }
}
fn documentation(&self) -> Option<&Documentation> {
diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs
index 4015abc6adf70..0a12c52eb1256 100644
--- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs
+++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs
@@ -16,11 +16,11 @@
// under the License.
use std::any::Any;
-use std::fmt::{Debug, Formatter};
+use std::fmt::Debug;
use std::mem::size_of_val;
use std::sync::Arc;
-use arrow::array::Array;
+use arrow::array::{Array, Float16Array};
use arrow::compute::{filter, is_not_null};
use arrow::datatypes::FieldRef;
use arrow::{
@@ -42,9 +42,7 @@ use datafusion_expr::{
Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature,
Volatility,
};
-use datafusion_functions_aggregate_common::tdigest::{
- TDigest, TryIntoF64, DEFAULT_MAX_SIZE,
-};
+use datafusion_functions_aggregate_common::tdigest::{TDigest, DEFAULT_MAX_SIZE};
use datafusion_macros::user_doc;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
@@ -121,20 +119,11 @@ An alternate syntax is also supported:
description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory."
)
)]
-#[derive(PartialEq, Eq, Hash)]
+#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ApproxPercentileCont {
signature: Signature,
}
-impl Debug for ApproxPercentileCont {
- fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
- f.debug_struct("ApproxPercentileCont")
- .field("name", &self.name())
- .field("signature", &self.signature)
- .finish()
- }
-}
-
impl Default for ApproxPercentileCont {
fn default() -> Self {
Self::new()
@@ -197,6 +186,7 @@ impl ApproxPercentileCont {
| DataType::Int16
| DataType::Int32
| DataType::Int64
+ | DataType::Float16
| DataType::Float32
| DataType::Float64 => {
if let Some(max_size) = tdigest_max_size {
@@ -376,83 +366,51 @@ impl ApproxPercentileAccumulator {
match values.data_type() {
DataType::Float64 => {
let array = downcast_value!(values, Float64Array);
- Ok(array
- .values()
- .iter()
- .filter_map(|v| v.try_as_f64().transpose())
- .collect::>>()?)
+ Ok(array.values().iter().copied().collect::>())
}
DataType::Float32 => {
let array = downcast_value!(values, Float32Array);
+ Ok(array.values().iter().map(|v| *v as f64).collect::>())
+ }
+ DataType::Float16 => {
+ let array = downcast_value!(values, Float16Array);
Ok(array
.values()
.iter()
- .filter_map(|v| v.try_as_f64().transpose())
- .collect::>>()?)
+ .map(|v| v.to_f64())
+ .collect::>())
}
DataType::Int64 => {
let array = downcast_value!(values, Int64Array);
- Ok(array
- .values()
- .iter()
- .filter_map(|v| v.try_as_f64().transpose())
- .collect::>>()?)
+ Ok(array.values().iter().map(|v| *v as f64).collect::>())
}
DataType::Int32 => {
let array = downcast_value!(values, Int32Array);
- Ok(array
- .values()
- .iter()
- .filter_map(|v| v.try_as_f64().transpose())
- .collect::>>()?)
+ Ok(array.values().iter().map(|v| *v as f64).collect::>())
}
DataType::Int16 => {
let array = downcast_value!(values, Int16Array);
- Ok(array
- .values()
- .iter()
- .filter_map(|v| v.try_as_f64().transpose())
- .collect::>>()?)
+ Ok(array.values().iter().map(|v| *v as f64).collect::>())
}
DataType::Int8 => {
let array = downcast_value!(values, Int8Array);
- Ok(array
- .values()
- .iter()
- .filter_map(|v| v.try_as_f64().transpose())
- .collect::>>()?)
+ Ok(array.values().iter().map(|v| *v as f64).collect::>())
}
DataType::UInt64 => {
let array = downcast_value!(values, UInt64Array);
- Ok(array
- .values()
- .iter()
- .filter_map(|v| v.try_as_f64().transpose())
- .collect::>>()?)
+ Ok(array.values().iter().map(|v| *v as f64).collect::>())
}
DataType::UInt32 => {
let array = downcast_value!(values, UInt32Array);
- Ok(array
- .values()
- .iter()
- .filter_map(|v| v.try_as_f64().transpose())
- .collect::>>()?)
+ Ok(array.values().iter().map(|v| *v as f64).collect::>())
}
DataType::UInt16 => {
let array = downcast_value!(values, UInt16Array);
- Ok(array
- .values()
- .iter()
- .filter_map(|v| v.try_as_f64().transpose())
- .collect::>>()?)
+ Ok(array.values().iter().map(|v| *v as f64).collect::>())
}
DataType::UInt8 => {
let array = downcast_value!(values, UInt8Array);
- Ok(array
- .values()
- .iter()
- .filter_map(|v| v.try_as_f64().transpose())
- .collect::>>()?)
+ Ok(array.values().iter().map(|v| *v as f64).collect::>())
}
e => internal_err!(
"APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}"
@@ -495,6 +453,7 @@ impl Accumulator for ApproxPercentileAccumulator {
DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)),
DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)),
DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)),
+ DataType::Float16 => ScalarValue::Float16(Some(half::f16::from_f64(q))),
DataType::Float32 => ScalarValue::Float32(Some(q as f32)),
DataType::Float64 => ScalarValue::Float64(Some(q)),
v => unreachable!("unexpected return type {}", v),
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index a593f66c7d938..6de2a7b2fbe29 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -910,6 +910,16 @@ SELECT approx_median(col_f64_nan) FROM median_table
----
NaN
+query RT
+select approx_median(arrow_cast(col_f32, 'Float16')), arrow_typeof(approx_median(arrow_cast(col_f32, 'Float16'))) from median_table;
+----
+2.75 Float16
+
+query ?T
+select approx_median(NULL), arrow_typeof(approx_median(NULL)) from median_table;
+----
+NULL Null
+
# median decimal
statement ok
create table t(c decimal(10, 4)) as values (0.0001), (0.0002), (0.0003), (0.0004), (0.0005), (0.0006);