Skip to content

Commit 5ec11f7

Browse files
feat: Perform tolerance-based comparison for lists and arrays
1 parent 2ae4c11 commit 5ec11f7

4 files changed

Lines changed: 257 additions & 6 deletions

File tree

diffly/_conditions.py

Lines changed: 89 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,14 @@ def condition_equal_rows(
2222
abs_tol_by_column: Mapping[str, float],
2323
rel_tol_by_column: Mapping[str, float],
2424
abs_tol_temporal_by_column: Mapping[str, dt.timedelta],
25+
max_list_lengths_by_column: Mapping[str, int] | None = None,
2526
) -> pl.Expr:
2627
"""Build an expression whether two rows are equal, based on all columns' data
2728
types."""
2829
if not columns:
2930
return pl.lit(True)
3031

32+
_max_list_lengths = max_list_lengths_by_column or {}
3133
return pl.all_horizontal(
3234
[
3335
condition_equal_columns(
@@ -37,6 +39,7 @@ def condition_equal_rows(
3739
abs_tol=abs_tol_by_column[column],
3840
rel_tol=rel_tol_by_column[column],
3941
abs_tol_temporal=abs_tol_temporal_by_column[column],
42+
max_list_length=_max_list_lengths.get(column, 0),
4043
)
4144
for column in columns
4245
]
@@ -50,6 +53,7 @@ def condition_equal_columns(
5053
abs_tol: float = ABS_TOL_DEFAULT,
5154
rel_tol: float = REL_TOL_DEFAULT,
5255
abs_tol_temporal: dt.timedelta = ABS_TOL_TEMPORAL_DEFAULT,
56+
max_list_length: int = 0,
5357
) -> pl.Expr:
5458
"""Build an expression whether two columns are equal, depending on the columns' data
5559
types."""
@@ -61,6 +65,7 @@ def condition_equal_columns(
6165
abs_tol=abs_tol,
6266
rel_tol=rel_tol,
6367
abs_tol_temporal=abs_tol_temporal,
68+
max_list_length=max_list_length,
6469
)
6570

6671

@@ -95,6 +100,7 @@ def _compare_columns(
95100
abs_tol: float,
96101
rel_tol: float,
97102
abs_tol_temporal: dt.timedelta,
103+
max_list_length: int = 0,
98104
) -> pl.Expr:
99105
"""Build an expression whether two expressions yield the same value.
100106
@@ -133,10 +139,18 @@ def _compare_columns(
133139
elif isinstance(dtype_left, pl.List | pl.Array) and isinstance(
134140
dtype_right, pl.List | pl.Array
135141
):
136-
# As of polars 1.28, there is no way to access another column within
137-
# `list.eval`. Hence, we necessarily need to resort to a primitive
138-
# comparison in this case.
139-
pass
142+
result = _compare_sequence_columns(
143+
col_left=col_left,
144+
col_right=col_right,
145+
dtype_left=dtype_left,
146+
dtype_right=dtype_right,
147+
max_list_length=max_list_length,
148+
abs_tol=abs_tol,
149+
rel_tol=rel_tol,
150+
abs_tol_temporal=abs_tol_temporal,
151+
)
152+
if result is not None:
153+
return result
140154

141155
if (
142156
isinstance(dtype_left, pl.Enum)
@@ -167,6 +181,77 @@ def _compare_columns(
167181
)
168182

169183

184+
def _compare_sequence_columns(
185+
col_left: pl.Expr,
186+
col_right: pl.Expr,
187+
dtype_left: DataType | DataTypeClass,
188+
dtype_right: DataType | DataTypeClass,
189+
max_list_length: int,
190+
abs_tol: float,
191+
rel_tol: float,
192+
abs_tol_temporal: dt.timedelta,
193+
) -> pl.Expr | None:
194+
"""Compare Array/List columns element-wise with tolerance.
195+
196+
Returns ``None`` if the comparison cannot be performed element-wise (e.g. List vs
197+
List without a known ``max_list_length``), signalling to the caller that it should
198+
fall back to primitive comparison.
199+
"""
200+
assert isinstance(dtype_left, pl.List | pl.Array)
201+
assert isinstance(dtype_right, pl.List | pl.Array)
202+
inner_left = dtype_left.inner
203+
inner_right = dtype_right.inner
204+
205+
def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Expr:
206+
if isinstance(dtype, pl.Array):
207+
return col.arr.get(i)
208+
return col.list.get(i, null_on_oob=True)
209+
210+
n: int | None = None
211+
length_check: pl.Expr | None = None
212+
213+
if isinstance(dtype_left, pl.Array) and isinstance(dtype_right, pl.Array):
214+
if dtype_left.shape != dtype_right.shape:
215+
return pl.repeat(pl.lit(False), pl.len())
216+
n = dtype_left.shape[0]
217+
elif isinstance(dtype_left, pl.Array) and isinstance(dtype_right, pl.List):
218+
n = dtype_left.shape[0]
219+
length_check = col_right.list.len().eq(pl.lit(n))
220+
elif isinstance(dtype_left, pl.List) and isinstance(dtype_right, pl.Array):
221+
n = dtype_right.shape[0]
222+
length_check = col_left.list.len().eq(pl.lit(n))
223+
else:
224+
# List vs List
225+
if max_list_length == 0:
226+
return None
227+
n = max_list_length
228+
length_check = col_left.list.len().eq_missing(col_right.list.len())
229+
230+
if n == 0:
231+
if length_check is not None:
232+
return _eq_missing(length_check, col_left, col_right)
233+
return _eq_missing(pl.lit(True), col_left, col_right)
234+
235+
elements_match = pl.all_horizontal(
236+
[
237+
_compare_columns(
238+
col_left=_get_element(col_left, dtype_left, i),
239+
col_right=_get_element(col_right, dtype_right, i),
240+
dtype_left=inner_left,
241+
dtype_right=inner_right,
242+
abs_tol=abs_tol,
243+
rel_tol=rel_tol,
244+
abs_tol_temporal=abs_tol_temporal,
245+
)
246+
for i in range(n)
247+
]
248+
)
249+
250+
if length_check is not None:
251+
return _eq_missing(length_check & elements_match, col_left, col_right)
252+
return elements_match
253+
254+
170255
def _compare_primitive_columns(
171256
col_left: pl.Expr,
172257
col_right: pl.Expr,

diffly/comparison.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,7 @@ def equal(self, *, check_dtypes: bool = True) -> bool:
511511
abs_tol_by_column=self.abs_tol_by_column,
512512
rel_tol_by_column=self.rel_tol_by_column,
513513
abs_tol_temporal_by_column=self.abs_tol_temporal_by_column,
514+
max_list_lengths_by_column=self._max_list_lengths,
514515
).all()
515516
)
516517
.item()
@@ -708,6 +709,26 @@ def _validate_subset_of_common_columns(self, subset: Iterable[str]) -> list[str]
708709
raise ValueError(f"{difference} are not common columns.")
709710
return list(subset)
710711

712+
@cached_property
713+
def _max_list_lengths(self) -> dict[str, int]:
714+
list_columns = [
715+
col
716+
for col in self._other_common_columns
717+
if isinstance(self.left_schema[col], pl.List)
718+
and isinstance(self.right_schema[col], pl.List)
719+
]
720+
if not list_columns:
721+
return {}
722+
723+
exprs = [pl.col(col).list.len().max().alias(col) for col in list_columns]
724+
[left_max, right_max] = pl.collect_all(
725+
[self.left.select(exprs), self.right.select(exprs)]
726+
)
727+
return {
728+
col: max(int(left_max[col].item() or 0), int(right_max[col].item() or 0))
729+
for col in list_columns
730+
}
731+
711732
def _condition_equal_rows(self, columns: list[str]) -> pl.Expr:
712733
return condition_equal_rows(
713734
columns=columns,
@@ -716,6 +737,7 @@ def _condition_equal_rows(self, columns: list[str]) -> pl.Expr:
716737
abs_tol_by_column=self.abs_tol_by_column,
717738
rel_tol_by_column=self.rel_tol_by_column,
718739
abs_tol_temporal_by_column=self.abs_tol_temporal_by_column,
740+
max_list_lengths_by_column=self._max_list_lengths,
719741
)
720742

721743
def _condition_equal_columns(self, column: str) -> pl.Expr:
@@ -726,6 +748,7 @@ def _condition_equal_columns(self, column: str) -> pl.Expr:
726748
abs_tol=self.abs_tol_by_column[column],
727749
rel_tol=self.rel_tol_by_column[column],
728750
abs_tol_temporal=self.abs_tol_temporal_by_column[column],
751+
max_list_length=self._max_list_lengths.get(column, 0),
729752
)
730753

731754
def _equal_rows(self) -> bool:

multi_array.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# %%
2+
import polars as pl
3+
# %%
4+
df = pl.DataFrame({"a": [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]}, schema={"a": pl.Array(inner=pl.UInt8, shape=(2, 2))})
5+
# %%
6+
df.select(pl.col("a").arr.get(1))

tests/test_conditions.py

Lines changed: 139 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def test_condition_equal_columns_different_struct_fields() -> None:
8181
@pytest.mark.parametrize(
8282
"rhs_type", [pl.Array(pl.Float64, shape=2), pl.List(pl.Float64)]
8383
)
84-
def test_condition_equal_columns_list_array_equal_exact(
84+
def test_condition_equal_columns_list_array_with_tolerance(
8585
lhs_type: pl.DataType, rhs_type: pl.DataType
8686
) -> None:
8787
# Arrange
@@ -110,12 +110,58 @@ def test_condition_equal_columns_list_array_equal_exact(
110110
dtype_right=rhs.schema["a_right"],
111111
abs_tol=0.5,
112112
rel_tol=0,
113+
max_list_length=2,
113114
)
114115
)
115116
.to_series()
116117
)
117118

118-
# Assert
119+
# Assert: diff is 0.1, within abs_tol=0.5
120+
assert actual.to_list() == [True, True]
121+
122+
123+
@pytest.mark.parametrize(
124+
"lhs_type", [pl.Array(pl.Float64, shape=2), pl.List(pl.Float64)]
125+
)
126+
@pytest.mark.parametrize(
127+
"rhs_type", [pl.Array(pl.Float64, shape=2), pl.List(pl.Float64)]
128+
)
129+
def test_condition_equal_columns_list_array_exceeds_tolerance(
130+
lhs_type: pl.DataType, rhs_type: pl.DataType
131+
) -> None:
132+
# Arrange
133+
lhs = pl.DataFrame(
134+
{
135+
"pk": [1, 2],
136+
"a_left": [[1.0, 1.1], [2.0, 2.1]],
137+
},
138+
schema={"pk": pl.Int64, "a_left": lhs_type},
139+
)
140+
rhs = pl.DataFrame(
141+
{
142+
"pk": [1, 2],
143+
"a_right": [[1.0, 1.1], [2.0, 2.8]],
144+
},
145+
schema={"pk": pl.Int64, "a_right": rhs_type},
146+
)
147+
148+
# Act
149+
actual = (
150+
lhs.join(rhs, on="pk", maintain_order="left")
151+
.select(
152+
condition_equal_columns(
153+
"a",
154+
dtype_left=lhs.schema["a_left"],
155+
dtype_right=rhs.schema["a_right"],
156+
abs_tol=0.5,
157+
rel_tol=0,
158+
max_list_length=2,
159+
)
160+
)
161+
.to_series()
162+
)
163+
164+
# Assert: diff is 0.7, exceeds abs_tol=0.5
119165
assert actual.to_list() == [True, False]
120166

121167

@@ -226,6 +272,97 @@ def test_condition_equal_columns_temporal_tolerance() -> None:
226272
assert actual.to_list() == [True, False, False, True]
227273

228274

275+
def test_condition_equal_columns_list_different_lengths() -> None:
276+
lhs = pl.DataFrame(
277+
{
278+
"pk": [1, 2],
279+
"a_left": [[1.0, 2.0], [3.0]],
280+
},
281+
)
282+
rhs = pl.DataFrame(
283+
{
284+
"pk": [1, 2],
285+
"a_right": [[1.0, 2.0], [3.0, 4.0]],
286+
},
287+
)
288+
289+
actual = (
290+
lhs.join(rhs, on="pk", maintain_order="left")
291+
.select(
292+
condition_equal_columns(
293+
"a",
294+
dtype_left=lhs.schema["a_left"],
295+
dtype_right=rhs.schema["a_right"],
296+
abs_tol=0.5,
297+
rel_tol=0,
298+
max_list_length=2,
299+
)
300+
)
301+
.to_series()
302+
)
303+
assert actual.to_list() == [True, False]
304+
305+
306+
def test_condition_equal_columns_list_nulls() -> None:
307+
lhs = pl.DataFrame(
308+
{
309+
"pk": [1, 2, 3],
310+
"a_left": [[1.0, 2.0], None, None],
311+
},
312+
)
313+
rhs = pl.DataFrame(
314+
{
315+
"pk": [1, 2, 3],
316+
"a_right": [[1.0, 2.0], [3.0], None],
317+
},
318+
)
319+
320+
actual = (
321+
lhs.join(rhs, on="pk", maintain_order="left")
322+
.select(
323+
condition_equal_columns(
324+
"a",
325+
dtype_left=lhs.schema["a_left"],
326+
dtype_right=rhs.schema["a_right"],
327+
max_list_length=2,
328+
)
329+
)
330+
.to_series()
331+
)
332+
assert actual.to_list() == [True, False, True]
333+
334+
335+
def test_condition_equal_columns_array_vs_list_length_mismatch() -> None:
336+
lhs = pl.DataFrame(
337+
{
338+
"pk": [1, 2],
339+
"a_left": [[1.0, 2.0], [3.0, 4.0]],
340+
},
341+
schema={"pk": pl.Int64, "a_left": pl.Array(pl.Float64, shape=2)},
342+
)
343+
rhs = pl.DataFrame(
344+
{
345+
"pk": [1, 2],
346+
"a_right": [[1.0, 2.0], [3.0]],
347+
},
348+
)
349+
350+
actual = (
351+
lhs.join(rhs, on="pk", maintain_order="left")
352+
.select(
353+
condition_equal_columns(
354+
"a",
355+
dtype_left=lhs.schema["a_left"],
356+
dtype_right=rhs.schema["a_right"],
357+
abs_tol=0.5,
358+
rel_tol=0,
359+
)
360+
)
361+
.to_series()
362+
)
363+
assert actual.to_list() == [True, False]
364+
365+
229366
@pytest.mark.parametrize(
230367
("dtype_left", "dtype_right", "can_compare_dtypes"),
231368
[

0 commit comments

Comments
 (0)