Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions datapipe/sql_util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, Dict, List, Optional

import pandas as pd
from sqlalchemy import Column, Integer, String, Table, tuple_

from datapipe.run_config import RunConfig
Expand All @@ -21,9 +22,17 @@ def sql_apply_idx_filter_to_table(
# Когда ключей много - сравниваем по кортежу
keys = tuple_(*[table.c[key] for key in primary_keys]) # type: ignore

sql = sql.where(
keys.in_([tuple([r[key] for key in primary_keys]) for r in idx.to_dict(orient="records")]) # type: ignore
)
where_values: list[tuple] = []
for r in idx.to_dict(orient="records"):
this_row: list[Any] = []
for key in primary_keys:
if pd.isna(r[key]):
this_row.append(None)
else:
this_row.append(r[key])
where_values.append(tuple(this_row))

sql = sql.where(keys.in_(where_values))

return sql

Expand Down
43 changes: 42 additions & 1 deletion datapipe/store/tests/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import cloudpickle
import pandas as pd
import pytest
from sqlalchemy import Column, String
from sqlalchemy import Column, Integer, String

from datapipe.run_config import RunConfig
from datapipe.store.table_store import TableStore
Expand Down Expand Up @@ -49,6 +49,47 @@ def test_get_schema(

assert store.get_schema() == schema

def test_multiple_keys_with_nan(self, store_maker: TableStoreMaker) -> None:
data_df = pd.DataFrame(
{
"id1": [1, 2, 3],
"id2": ["a", "b", "c"],
"value": [10, 20, 30],
}
)
schema: list[Column] = [
Column("id1", Integer, primary_key=True),
Column("id2", String(100), primary_key=True),
Column("value", Integer),
]

store = store_maker(schema)
store.insert_rows(data_df)

assert_df_equal(
store.read_rows(
data_to_index(
pd.DataFrame.from_dict({"id1": [2, -1], "id2": ["b", None]}),
["id1", "id2"],
)
),
pd.DataFrame.from_dict({"id1": [2], "id2": ["b"], "value": [20]}),
index_cols=["id1", "id2"],
)

assert_df_equal(
store.read_rows(
data_to_index(
pd.DataFrame.from_dict({"id1": [2, None], "id2": ["b", "z"]}),
["id1", "id2"],
)
),
pd.DataFrame.from_dict({"id1": [2], "id2": ["b"], "value": [20]}),
index_cols=["id1", "id2"],
)

assert_ts_contains(store, data_df)

@pytest.mark.parametrize("data_df,schema", DATA_PARAMS)
def test_write_read_rows(
self,
Expand Down
Loading