Skip to content

Commit e1e06b2

Browse files
committed
Move supports_distinct next to distinct
1 parent 9ab87c6 commit e1e06b2

File tree

2 files changed

+49
-20
lines changed

2 files changed

+49
-20
lines changed

arrow-ord/src/cmp.rs

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ use arrow_array::{
3131
};
3232
use arrow_buffer::bit_util::ceil;
3333
use arrow_buffer::{BooleanBuffer, NullBuffer};
34-
use arrow_schema::ArrowError;
34+
use arrow_schema::{ArrowError, DataType};
3535
use arrow_select::take::take;
3636
use std::cmp::Ordering;
3737
use std::ops::Not;
@@ -201,6 +201,20 @@ pub fn not_distinct(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, Ar
201201
compare_op(Op::NotDistinct, lhs, rhs)
202202
}
203203

204+
/// Returns `true` if `distinct` (via `compare_op`) can handle this data type.
205+
///
206+
/// `compare_op` unwraps at most one level of dictionary, then dispatches on
207+
/// the leaf type. Anything else (run-end encoded, nested dictionary, nested/complex types)
208+
/// must go through `make_comparator` instead.
209+
pub(crate) fn supports_distinct(dt: &DataType) -> bool {
210+
use arrow_schema::DataType::*;
211+
// Peel off a single dictionary layer, if present, and inspect its value type.
let leaf = match dt {
212+
Dictionary(_, v) => v.as_ref(),
213+
dt => dt,
214+
};
215+
// The leaf qualifies only if it is not nested and is not itself a
// dictionary or run-end encoded type.
!leaf.is_nested() && !matches!(leaf, Dictionary(_, _) | RunEndEncoded(_, _))
216+
}
217+
204218
/// Perform `op` on the provided `Datum`
205219
#[inline(never)]
206220
fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
@@ -832,6 +846,38 @@ mod tests {
832846
assert_eq!(not_distinct(&b, &a).unwrap(), expected);
833847
}
834848

849+
// Exercises `supports_distinct` on plain leaf types, a single-level
// dictionary, and the rejected run-end encoded / nested-dictionary / list /
// struct cases.
#[test]
850+
fn test_supports_distinct() {
851+
use arrow_schema::{DataType::*, Field};
852+
853+
// Plain, non-nested leaf types are accepted as-is.
assert!(supports_distinct(&Int32));
854+
assert!(supports_distinct(&Float64));
855+
assert!(supports_distinct(&Utf8));
856+
assert!(supports_distinct(&Boolean));
857+
858+
// One level of dictionary unwrap is supported.
859+
assert!(supports_distinct(&Dictionary(
860+
Box::new(Int16),
861+
Box::new(Utf8),
862+
)));
863+
864+
// REE, nested dictionary, and complex types are not supported.
865+
assert!(!supports_distinct(&RunEndEncoded(
866+
Arc::new(Field::new("run_ends", Int32, false)),
867+
Arc::new(Field::new("values", Int32, true)),
868+
)));
869+
assert!(!supports_distinct(&Dictionary(
870+
Box::new(Int16),
871+
Box::new(Dictionary(Box::new(Int8), Box::new(Utf8))),
872+
)));
873+
assert!(!supports_distinct(&List(Arc::new(Field::new(
874+
"item", Int32, true,
875+
)))));
876+
assert!(!supports_distinct(&Struct(
877+
vec![Field::new("a", Int32, true)].into()
878+
)));
879+
}
880+
835881
#[test]
836882
fn test_scalar_negation() {
837883
let a = Int32Array::new_scalar(54);

arrow-ord/src/partition.rs

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ use std::ops::Range;
2121

2222
use arrow_array::{Array, ArrayRef};
2323
use arrow_buffer::BooleanBuffer;
24-
use arrow_schema::{ArrowError, DataType, SortOptions};
24+
use arrow_schema::{ArrowError, SortOptions};
2525

26-
use crate::cmp::distinct;
26+
use crate::cmp::{distinct, supports_distinct};
2727
use crate::ord::make_comparator;
2828

2929
/// A computed set of partitions, see [`partition`]
@@ -152,23 +152,6 @@ pub fn partition(columns: &[ArrayRef]) -> Result<Partitions, ArrowError> {
152152
Ok(Partitions(Some(acc)))
153153
}
154154

155-
/// Returns true if `distinct` (via `compare_op`) can handle this data type.
156-
///
157-
/// `compare_op` unwraps at most one level of dictionary, then dispatches on
158-
/// the leaf type. Anything else (REE, nested dictionary, nested/complex types)
159-
/// must go through `make_comparator` instead.
160-
fn supports_distinct(dt: &DataType) -> bool {
161-
let leaf = match dt {
162-
DataType::Dictionary(_, v) => v.as_ref(),
163-
dt => dt,
164-
};
165-
!leaf.is_nested()
166-
&& !matches!(
167-
leaf,
168-
DataType::Dictionary(_, _) | DataType::RunEndEncoded(_, _)
169-
)
170-
}
171-
172155
/// Returns a mask with bits set whenever the value or nullability changes
173156
fn find_boundaries(v: &dyn Array) -> Result<BooleanBuffer, ArrowError> {
174157
let slice_len = v.len() - 1;

0 commit comments

Comments
 (0)