@@ -577,3 +577,73 @@ def test_can_compare_dtypes(
577577 dtype_left = dtype_left , dtype_right = dtype_right
578578 )
579579 assert can_compare_dtypes_actual == can_compare_dtypes
580+
581+
582+ @pytest .mark .parametrize (
583+ ("dtype_left" , "dtype_right" , "expected" ),
584+ [
585+ # Primitives that don't need element-wise comparison
586+ (pl .Int64 , pl .Int64 , False ),
587+ (pl .String , pl .String , False ),
588+ (pl .Boolean , pl .Boolean , False ),
589+ # Float/numeric pairs
590+ (pl .Float64 , pl .Float64 , True ),
591+ (pl .Int64 , pl .Float64 , True ),
592+ (pl .Float32 , pl .Int32 , True ),
593+ # Temporal pairs
594+ (pl .Datetime , pl .Datetime , True ),
595+ (pl .Date , pl .Date , True ),
596+ (pl .Datetime , pl .Date , True ),
597+ # Enum/categorical
598+ (pl .Enum (["a" , "b" ]), pl .Enum (["a" , "b" ]), False ),
599+ (pl .Enum (["a" , "b" ]), pl .Enum (["a" , "b" , "c" ]), True ),
600+ (pl .Enum (["a" ]), pl .Categorical (), True ),
601+ (pl .Categorical (), pl .Enum (["a" ]), True ),
602+ # Struct with no tolerance-requiring fields
603+ (
604+ pl .Struct ({"x" : pl .Int64 , "y" : pl .String }),
605+ pl .Struct ({"x" : pl .Int64 , "y" : pl .String }),
606+ False ,
607+ ),
608+ # Struct with a float field
609+ (
610+ pl .Struct ({"x" : pl .Int64 , "y" : pl .Float64 }),
611+ pl .Struct ({"x" : pl .Int64 , "y" : pl .Float64 }),
612+ True ,
613+ ),
614+ # Struct with different-category enums
615+ (
616+ pl .Struct ({"x" : pl .Enum (["a" ])}),
617+ pl .Struct ({"x" : pl .Enum (["b" ])}),
618+ True ,
619+ ),
620+ # List/Array with non-tolerance inner type
621+ (pl .List (pl .Int64 ), pl .List (pl .Int64 ), False ),
622+ (pl .Array (pl .String , shape = 3 ), pl .Array (pl .String , shape = 3 ), False ),
623+ # List/Array with tolerance-requiring inner type
624+ (pl .List (pl .Float64 ), pl .List (pl .Float64 ), True ),
625+ (pl .Array (pl .Datetime , shape = 2 ), pl .Array (pl .Datetime , shape = 2 ), True ),
626+ # Nested: list of structs with a float field
627+ (
628+ pl .List (pl .Struct ({"x" : pl .Float64 })),
629+ pl .List (pl .Struct ({"x" : pl .Float64 })),
630+ True ,
631+ ),
632+ # Nested: list of structs without tolerance-requiring fields
633+ (
634+ pl .List (pl .Struct ({"x" : pl .Int64 })),
635+ pl .List (pl .Struct ({"x" : pl .Int64 })),
636+ False ,
637+ ),
638+ # Deeply nested: struct with a list of structs with a float field
639+ (
640+ pl .List (pl .Struct ({"x" : pl .String , "y" : pl .List (pl .Float64 )})),
641+ pl .List (pl .Struct ({"x" : pl .String , "y" : pl .List (pl .Float64 )})),
642+ True ,
643+ ),
644+ ],
645+ )
646+ def test_needs_element_wise_comparison (
647+ dtype_left : pl .DataType , dtype_right : pl .DataType , expected : bool
648+ ) -> None :
649+ assert _needs_element_wise_comparison (dtype_left , dtype_right ) == expected
0 commit comments