Skip to content

Commit 55766c4

Browse files
add test
1 parent 83e79b4 commit 55766c4

1 file changed

Lines changed: 70 additions & 0 deletions

File tree

tests/test_conditions.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,3 +577,73 @@ def test_can_compare_dtypes(
577577
dtype_left=dtype_left, dtype_right=dtype_right
578578
)
579579
assert can_compare_dtypes_actual == can_compare_dtypes
580+
581+
582+
@pytest.mark.parametrize(
583+
("dtype_left", "dtype_right", "expected"),
584+
[
585+
# Primitives that don't need element-wise comparison
586+
(pl.Int64, pl.Int64, False),
587+
(pl.String, pl.String, False),
588+
(pl.Boolean, pl.Boolean, False),
589+
# Float/numeric pairs
590+
(pl.Float64, pl.Float64, True),
591+
(pl.Int64, pl.Float64, True),
592+
(pl.Float32, pl.Int32, True),
593+
# Temporal pairs
594+
(pl.Datetime, pl.Datetime, True),
595+
(pl.Date, pl.Date, True),
596+
(pl.Datetime, pl.Date, True),
597+
# Enum/categorical
598+
(pl.Enum(["a", "b"]), pl.Enum(["a", "b"]), False),
599+
(pl.Enum(["a", "b"]), pl.Enum(["a", "b", "c"]), True),
600+
(pl.Enum(["a"]), pl.Categorical(), True),
601+
(pl.Categorical(), pl.Enum(["a"]), True),
602+
# Struct with no tolerance-requiring fields
603+
(
604+
pl.Struct({"x": pl.Int64, "y": pl.String}),
605+
pl.Struct({"x": pl.Int64, "y": pl.String}),
606+
False,
607+
),
608+
# Struct with a float field
609+
(
610+
pl.Struct({"x": pl.Int64, "y": pl.Float64}),
611+
pl.Struct({"x": pl.Int64, "y": pl.Float64}),
612+
True,
613+
),
614+
# Struct with different-category enums
615+
(
616+
pl.Struct({"x": pl.Enum(["a"])}),
617+
pl.Struct({"x": pl.Enum(["b"])}),
618+
True,
619+
),
620+
# List/Array with non-tolerance inner type
621+
(pl.List(pl.Int64), pl.List(pl.Int64), False),
622+
(pl.Array(pl.String, shape=3), pl.Array(pl.String, shape=3), False),
623+
# List/Array with tolerance-requiring inner type
624+
(pl.List(pl.Float64), pl.List(pl.Float64), True),
625+
(pl.Array(pl.Datetime, shape=2), pl.Array(pl.Datetime, shape=2), True),
626+
# Nested: list of structs with a float field
627+
(
628+
pl.List(pl.Struct({"x": pl.Float64})),
629+
pl.List(pl.Struct({"x": pl.Float64})),
630+
True,
631+
),
632+
# Nested: list of structs without tolerance-requiring fields
633+
(
634+
pl.List(pl.Struct({"x": pl.Int64})),
635+
pl.List(pl.Struct({"x": pl.Int64})),
636+
False,
637+
),
638+
# Deeply nested: struct with a list of structs with a float field
639+
(
640+
pl.List(pl.Struct({"x": pl.String, "y": pl.List(pl.Float64)})),
641+
pl.List(pl.Struct({"x": pl.String, "y": pl.List(pl.Float64)})),
642+
True,
643+
),
644+
],
645+
)
646+
def test_needs_element_wise_comparison(
647+
dtype_left: pl.DataType, dtype_right: pl.DataType, expected: bool
648+
) -> None:
649+
assert _needs_element_wise_comparison(dtype_left, dtype_right) == expected

0 commit comments

Comments
 (0)