@@ -31,7 +31,7 @@ use arrow_array::{
3131} ;
3232use arrow_buffer:: bit_util:: ceil;
3333use arrow_buffer:: { BooleanBuffer , NullBuffer } ;
34- use arrow_schema:: ArrowError ;
34+ use arrow_schema:: { ArrowError , DataType } ;
3535use arrow_select:: take:: take;
3636use std:: cmp:: Ordering ;
3737use std:: ops:: Not ;
@@ -201,6 +201,20 @@ pub fn not_distinct(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, Ar
201201 compare_op ( Op :: NotDistinct , lhs, rhs)
202202}
203203
204+ /// Returns true if `distinct` (via `compare_op`) can handle this data type.
205+ ///
206+ /// `compare_op` unwraps at most one level of dictionary, then dispatches on
207+ /// the leaf type. Anything else (REE, nested dictionary, nested/complex types)
208+ /// must go through `make_comparator` instead.
209+ pub ( crate ) fn supports_distinct ( dt : & DataType ) -> bool {
210+ use arrow_schema:: DataType :: * ;
211+ let leaf = match dt {
212+ Dictionary ( _, v) => v. as_ref ( ) ,
213+ dt => dt,
214+ } ;
215+ !leaf. is_nested ( ) && !matches ! ( leaf, Dictionary ( _, _) | RunEndEncoded ( _, _) )
216+ }
217+
204218/// Perform `op` on the provided `Datum`
205219#[ inline( never) ]
206220fn compare_op ( op : Op , lhs : & dyn Datum , rhs : & dyn Datum ) -> Result < BooleanArray , ArrowError > {
@@ -832,6 +846,38 @@ mod tests {
832846 assert_eq ! ( not_distinct( & b, & a) . unwrap( ) , expected) ;
833847 }
834848
#[test]
fn test_supports_distinct() {
    use arrow_schema::{DataType::*, Field};

    // Flat leaf types are accepted as-is.
    for flat in [Int32, Float64, Utf8, Boolean] {
        assert!(supports_distinct(&flat));
    }

    // A single dictionary layer over a flat value type is accepted.
    let dict_of_utf8 = Dictionary(Box::new(Int16), Box::new(Utf8));
    assert!(supports_distinct(&dict_of_utf8));

    // Run-end encoded types are rejected.
    let run_ends = Arc::new(Field::new("run_ends", Int32, false));
    let values = Arc::new(Field::new("values", Int32, true));
    let ree = RunEndEncoded(run_ends, values);
    assert!(!supports_distinct(&ree));

    // A dictionary whose values are another dictionary is rejected.
    let inner_dict = Dictionary(Box::new(Int8), Box::new(Utf8));
    let dict_of_dict = Dictionary(Box::new(Int16), Box::new(inner_dict));
    assert!(!supports_distinct(&dict_of_dict));

    // Nested/complex types are rejected.
    let list_type = List(Arc::new(Field::new("item", Int32, true)));
    assert!(!supports_distinct(&list_type));

    let struct_type = Struct(vec![Field::new("a", Int32, true)].into());
    assert!(!supports_distinct(&struct_type));
}
880+
835881 #[ test]
836882 fn test_scalar_negation ( ) {
837883 let a = Int32Array :: new_scalar ( 54 ) ;
0 commit comments