Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 0 additions & 36 deletions datafusion/functions-aggregate-common/src/tdigest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
use arrow::datatypes::DataType;
use arrow::datatypes::Float64Type;
use datafusion_common::cast::as_primitive_array;
use datafusion_common::Result;
use datafusion_common::ScalarValue;
use std::cmp::Ordering;
use std::mem::{size_of, size_of_val};
Expand Down Expand Up @@ -61,41 +60,6 @@ macro_rules! cast_scalar_u64 {
};
}

/// Conversion of a value into an optional [`f64`].
///
/// This trait is implemented for each type a [`TDigest`] can operate on,
/// allowing it to support both numerical Rust types (obtained from
/// `PrimitiveArray` instances), and [`ScalarValue`] instances.
pub trait TryIntoF64 {
/// A fallible conversion of a possibly null `self` into a [`f64`].
///
/// If `self` is null, this method must return `Ok(None)`.
///
/// # Errors
///
/// If `self` cannot be coerced to the desired type, this method must return
/// an `Err` variant.
fn try_as_f64(&self) -> Result<Option<f64>>;
}

/// Implements [`TryIntoF64`] for the primitive numeric type `$type`.
///
/// The generated `try_as_f64` cannot fail and never reports a null: it always
/// returns `Ok(Some(..))` wrapping the value converted to [`f64`] with an `as`
/// cast.
macro_rules! impl_try_ordered_f64 {
    ($type:ty) => {
        impl TryIntoF64 for $type {
            fn try_as_f64(&self) -> Result<Option<f64>> {
                // Plain numeric cast; there is no failure path for these types.
                let converted = *self as f64;
                Ok(Some(converted))
            }
        }
    };
}

// Generate a `TryIntoF64` impl for each primitive float, signed integer and
// unsigned integer type listed below.
impl_try_ordered_f64!(f64);
impl_try_ordered_f64!(f32);
impl_try_ordered_f64!(i64);
impl_try_ordered_f64!(i32);
impl_try_ordered_f64!(i16);
impl_try_ordered_f64!(i8);
impl_try_ordered_f64!(u64);
impl_try_ordered_f64!(u32);
impl_try_ordered_f64!(u16);
impl_try_ordered_f64!(u8);

/// Centroid implementation of the cluster described in the paper.
#[derive(Debug, PartialEq, Clone)]
pub struct Centroid {
Expand Down
88 changes: 51 additions & 37 deletions datafusion/functions-aggregate/src/approx_median.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,18 @@

use arrow::datatypes::DataType::{Float64, UInt64};
use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::types::NativeType;
use datafusion_functions_aggregate_common::noop_accumulator::NoopAccumulator;
use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;

use datafusion_common::{not_impl_err, plan_err, Result};
use datafusion_common::{not_impl_err, Result};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
Accumulator, AggregateUDFImpl, Coercion, Documentation, Signature, TypeSignature,
TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;

Expand Down Expand Up @@ -57,20 +59,11 @@ make_udaf_expr_and_func!(
```"#,
standard_argument(name = "expression",)
)]
#[derive(PartialEq, Eq, Hash)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ApproxMedian {
signature: Signature,
}

/// Hand-written [`Debug`] output: a struct-style rendering named
/// `ApproxMedian` that shows the function's `name` followed by its `signature`.
impl Debug for ApproxMedian {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let mut rendered = f.debug_struct("ApproxMedian");
        // `name` is not a stored field; it is obtained from `self.name()`.
        rendered.field("name", &self.name());
        rendered.field("signature", &self.signature);
        rendered.finish()
    }
}

impl Default for ApproxMedian {
fn default() -> Self {
Self::new()
Expand All @@ -81,33 +74,53 @@ impl ApproxMedian {
/// Create a new APPROX_MEDIAN aggregate function
pub fn new() -> Self {
Self {
signature: Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable),
signature: Signature::one_of(
vec![
TypeSignature::Coercible(vec![Coercion::new_exact(
TypeSignatureClass::Integer,
)]),
TypeSignature::Coercible(vec![Coercion::new_implicit(
TypeSignatureClass::Float,
vec![TypeSignatureClass::Decimal],
NativeType::Float64,
)]),
],
Volatility::Immutable,
),
}
}
}

impl AggregateUDFImpl for ApproxMedian {
/// Returns `self` as a `&dyn Any`, which callers can use for downcasting
/// to the concrete type.
fn as_any(&self) -> &dyn Any {
self
}

fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
Ok(vec![
Field::new(format_state_name(args.name, "max_size"), UInt64, false),
Field::new(format_state_name(args.name, "sum"), Float64, false),
Field::new(format_state_name(args.name, "count"), UInt64, false),
Field::new(format_state_name(args.name, "max"), Float64, false),
Field::new(format_state_name(args.name, "min"), Float64, false),
Field::new_list(
format_state_name(args.name, "centroids"),
Field::new_list_field(Float64, true),
false,
),
]
.into_iter()
.map(Arc::new)
.collect())
if args.input_fields[0].data_type().is_null() {
Ok(vec![Field::new(
format_state_name(args.name, self.name()),
DataType::Null,
true,
)
.into()])
} else {
Ok(vec![
Field::new(format_state_name(args.name, "max_size"), UInt64, false),
Field::new(format_state_name(args.name, "sum"), Float64, false),
Field::new(format_state_name(args.name, "count"), UInt64, false),
Field::new(format_state_name(args.name, "max"), Float64, false),
Field::new(format_state_name(args.name, "min"), Float64, false),
Field::new_list(
format_state_name(args.name, "centroids"),
Field::new_list_field(Float64, true),
false,
),
]
.into_iter()
.map(Arc::new)
.collect())
}
}

fn name(&self) -> &str {
Expand All @@ -119,9 +132,6 @@ impl AggregateUDFImpl for ApproxMedian {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if !arg_types[0].is_numeric() {
return plan_err!("ApproxMedian requires numeric input types");
}
Ok(arg_types[0].clone())
}

Expand All @@ -132,10 +142,14 @@ impl AggregateUDFImpl for ApproxMedian {
);
}

Ok(Box::new(ApproxPercentileAccumulator::new(
0.5_f64,
acc_args.expr_fields[0].data_type().clone(),
)))
if acc_args.expr_fields[0].data_type().is_null() {
Ok(Box::new(NoopAccumulator::default()))
} else {
Ok(Box::new(ApproxPercentileAccumulator::new(
0.5_f64,
acc_args.expr_fields[0].data_type().clone(),
)))
}
}

fn documentation(&self) -> Option<&Documentation> {
Expand Down
83 changes: 21 additions & 62 deletions datafusion/functions-aggregate/src/approx_percentile_cont.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
// under the License.

use std::any::Any;
use std::fmt::{Debug, Formatter};
use std::fmt::Debug;
use std::mem::size_of_val;
use std::sync::Arc;

use arrow::array::Array;
use arrow::array::{Array, Float16Array};
use arrow::compute::{filter, is_not_null};
use arrow::datatypes::FieldRef;
use arrow::{
Expand All @@ -42,9 +42,7 @@ use datafusion_expr::{
Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature,
Volatility,
};
use datafusion_functions_aggregate_common::tdigest::{
TDigest, TryIntoF64, DEFAULT_MAX_SIZE,
};
use datafusion_functions_aggregate_common::tdigest::{TDigest, DEFAULT_MAX_SIZE};
use datafusion_macros::user_doc;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

Expand Down Expand Up @@ -121,20 +119,11 @@ An alternate syntax is also supported:
description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory."
)
)]
#[derive(PartialEq, Eq, Hash)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ApproxPercentileCont {
signature: Signature,
}

/// Hand-written [`Debug`] output: a struct-style rendering named
/// `ApproxPercentileCont` that shows the function's `name` followed by its
/// `signature`.
impl Debug for ApproxPercentileCont {
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        let mut rendered = f.debug_struct("ApproxPercentileCont");
        // `name` is not a stored field; it is obtained from `self.name()`.
        rendered.field("name", &self.name());
        rendered.field("signature", &self.signature);
        rendered.finish()
    }
}

impl Default for ApproxPercentileCont {
fn default() -> Self {
Self::new()
Expand Down Expand Up @@ -197,6 +186,7 @@ impl ApproxPercentileCont {
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Float16
| DataType::Float32
| DataType::Float64 => {
if let Some(max_size) = tdigest_max_size {
Expand Down Expand Up @@ -376,83 +366,51 @@ impl ApproxPercentileAccumulator {
match values.data_type() {
DataType::Float64 => {
let array = downcast_value!(values, Float64Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
Ok(array.values().iter().copied().collect::<Vec<_>>())
}
DataType::Float32 => {
let array = downcast_value!(values, Float32Array);
Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
}
DataType::Float16 => {
let array = downcast_value!(values, Float16Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
.map(|v| v.to_f64())
.collect::<Vec<_>>())
}
DataType::Int64 => {
let array = downcast_value!(values, Int64Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
}
DataType::Int32 => {
let array = downcast_value!(values, Int32Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
}
DataType::Int16 => {
let array = downcast_value!(values, Int16Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
}
DataType::Int8 => {
let array = downcast_value!(values, Int8Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
}
DataType::UInt64 => {
let array = downcast_value!(values, UInt64Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
}
DataType::UInt32 => {
let array = downcast_value!(values, UInt32Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
}
DataType::UInt16 => {
let array = downcast_value!(values, UInt16Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
}
DataType::UInt8 => {
let array = downcast_value!(values, UInt8Array);
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
}
e => internal_err!(
"APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}"
Expand Down Expand Up @@ -495,6 +453,7 @@ impl Accumulator for ApproxPercentileAccumulator {
DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)),
DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)),
DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)),
DataType::Float16 => ScalarValue::Float16(Some(half::f16::from_f64(q))),
DataType::Float32 => ScalarValue::Float32(Some(q as f32)),
DataType::Float64 => ScalarValue::Float64(Some(q)),
v => unreachable!("unexpected return type {}", v),
Expand Down
10 changes: 10 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,16 @@ SELECT approx_median(col_f64_nan) FROM median_table
----
NaN

query RT
select approx_median(arrow_cast(col_f32, 'Float16')), arrow_typeof(approx_median(arrow_cast(col_f32, 'Float16'))) from median_table;
----
2.75 Float16

query ?T
select approx_median(NULL), arrow_typeof(approx_median(NULL)) from median_table;
----
NULL Null

# median decimal
statement ok
create table t(c decimal(10, 4)) as values (0.0001), (0.0002), (0.0003), (0.0004), (0.0005), (0.0006);
Expand Down
Loading