diff --git a/datapipe/sql_util.py b/datapipe/sql_util.py
index 1a7495a1..94a4dfd3 100644
--- a/datapipe/sql_util.py
+++ b/datapipe/sql_util.py
@@ -1,5 +1,6 @@
 from typing import Any, Dict, List, Optional
 
+import pandas as pd
 from sqlalchemy import Column, Integer, String, Table, tuple_
 
 from datapipe.run_config import RunConfig
@@ -21,9 +22,14 @@ def sql_apply_idx_filter_to_table(
         # Когда ключей много - сравниваем по кортежу
         keys = tuple_(*[table.c[key] for key in primary_keys])  # type: ignore
 
-        sql = sql.where(
-            keys.in_([tuple([r[key] for key in primary_keys]) for r in idx.to_dict(orient="records")])  # type: ignore
-        )
+        # Build one tuple of key values per idx row; any value that pd.isna()
+        # reports as missing (e.g. NaN or None) is replaced with None.
+        key_tuples: list[tuple] = [
+            tuple(None if pd.isna(record[key]) else record[key] for key in primary_keys)
+            for record in idx.to_dict(orient="records")
+        ]
+
+        sql = sql.where(keys.in_(key_tuples))
 
     return sql
 
diff --git a/datapipe/store/tests/abstract.py b/datapipe/store/tests/abstract.py
index d5a64d76..50faa4c1 100644
--- a/datapipe/store/tests/abstract.py
+++ b/datapipe/store/tests/abstract.py
@@ -6,7 +6,7 @@
 import cloudpickle
 import pandas as pd
 import pytest
-from sqlalchemy import Column, String
+from sqlalchemy import Column, Integer, String
 
 from datapipe.run_config import RunConfig
 from datapipe.store.table_store import TableStore
@@ -49,6 +49,42 @@ def test_get_schema(
 
         assert store.get_schema() == schema
 
+    def test_multiple_keys_with_nan(self, store_maker: TableStoreMaker) -> None:
+        """Read rows by a two-column key index that contains a missing key value.
+
+        Each lookup requests the stored row (2, "b") together with a second row
+        in which one of the key columns is None; the expected result of every
+        lookup is just the stored row.
+        """
+        inserted_df = pd.DataFrame(
+            {
+                "id1": [1, 2, 3],
+                "id2": ["a", "b", "c"],
+                "value": [10, 20, 30],
+            }
+        )
+        schema: list[Column] = [
+            Column("id1", Integer, primary_key=True),
+            Column("id2", String(100), primary_key=True),
+            Column("value", Integer),
+        ]
+
+        store = store_maker(schema)
+        store.insert_rows(inserted_df)
+
+        lookups = [
+            {"id1": [2, -1], "id2": ["b", None]},  # missing value in id2
+            {"id1": [2, None], "id2": ["b", "z"]},  # missing value in id1
+        ]
+        for lookup in lookups:
+            requested_idx = data_to_index(pd.DataFrame.from_dict(lookup), ["id1", "id2"])
+            found_df = store.read_rows(requested_idx)
+            expected_df = pd.DataFrame.from_dict({"id1": [2], "id2": ["b"], "value": [20]})
+            assert_df_equal(found_df, expected_df, index_cols=["id1", "id2"])
+
+        # Re-check the store against the frame that was inserted at the start.
+        assert_ts_contains(store, inserted_df)
+
     @pytest.mark.parametrize("data_df,schema", DATA_PARAMS)
     def test_write_read_rows(
         self,