Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ See [key-mapping.md](design-docs/2025-12-key-mapping.md) for motivation
* Renamed `JoinSpec` to `InputSpec`
* Added `keys` parameter to `InputSpec` and `ComputeInput` to support
joining tables with different key names
* Added `DataField` accessor for `InputSpec.keys`

### Python3.9 support is deprecated

Expand Down
23 changes: 5 additions & 18 deletions datapipe/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from datapipe.run_config import RunConfig
from datapipe.store.database import TableStoreDB
from datapipe.store.table_store import TableStore
from datapipe.types import ChangeList, DataField, FieldAccessor, IndexDF, Labels, MetaSchema, TableOrName
from datapipe.types import ChangeList, IndexDF, Labels, MetaSchema, TableOrName

logger = logging.getLogger("datapipe.compute")
tracer = trace.get_tracer("datapipe.compute")
Expand Down Expand Up @@ -87,14 +87,11 @@ class ComputeInput:
dt: DataTable
join_type: Literal["inner", "full"] = "full"

# If provided, this dict tells how to get key columns from meta and data tables
#
# Example: {"idx_col": DataField("data_col")} means that to get idx_col value
# we should read data_col from data table
# If provided, this dict tells how to get key columns from meta table
#
# Example: {"idx_col": "meta_col"} means that to get idx_col value
# we should read meta_col from meta table
keys: dict[str, FieldAccessor] | None = None
keys: dict[str, str] | None = None

@property
def primary_keys(self) -> list[str]:
Expand All @@ -107,23 +104,13 @@ def primary_keys(self) -> list[str]:
def primary_schema(self) -> MetaSchema:
if self.keys:
primary_schema_dict = {col.name: col for col in self.dt.primary_schema}
data_schema_dict = {col.name: col for col in self.dt.table_store.get_schema()}

schema = []
for k, accessor in self.keys.items():
if isinstance(accessor, str):
source_column = primary_schema_dict[accessor]
column_alias = k
elif isinstance(accessor, DataField):
source_column = data_schema_dict[accessor.field_name]
column_alias = k
schema.append(data_schema_dict[accessor.field_name])
else:
raise ValueError(f"Unknown accessor type: {type(accessor)}")

source_column = primary_schema_dict[accessor]
schema.append(
Column(
column_alias,
k,
source_column.type,
primary_key=source_column.primary_key,
)
Expand Down
14 changes: 6 additions & 8 deletions datapipe/meta/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sqlalchemy import Column

from datapipe.run_config import RunConfig
from datapipe.types import ChangeList, DataSchema, FieldAccessor, HashDF, IndexDF, MetadataDF, MetaSchema
from datapipe.types import ChangeList, DataSchema, HashDF, IndexDF, MetadataDF, MetaSchema

if TYPE_CHECKING:
from datapipe.compute import ComputeInput
Expand Down Expand Up @@ -172,25 +172,23 @@ def reset_metadata(
def transform_idx_to_table_idx(
self,
transform_idx: IndexDF,
keys: dict[str, FieldAccessor] | None = None,
keys: dict[str, str] | None = None,
) -> IndexDF:
"""
Given an index dataframe with transform keys, return an index dataframe
with table keys, applying `keys` aliasing if provided.

* `keys` is a mapping from table key to transform key
* `keys` is a mapping from transform key to table key
"""

if keys is None:
return transform_idx

table_key_cols: dict[str, pd.Series] = {}
for transform_col in transform_idx.columns:
accessor = keys.get(transform_col) if keys is not None else transform_col
if isinstance(accessor, str):
table_key_cols[accessor] = transform_idx[transform_col]
else:
pass # skip non-meta fields
table_col = keys.get(transform_col)
if table_col is not None:
table_key_cols[table_col] = transform_idx[transform_col]

return IndexDF(pd.DataFrame(table_key_cols))

Expand Down
31 changes: 4 additions & 27 deletions datapipe/meta/sql_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@
from datapipe.types import (
ChangeList,
DataDF,
DataField,
DataSchema,
FieldAccessor,
HashDF,
IndexDF,
MetadataDF,
Expand Down Expand Up @@ -357,17 +355,15 @@ def get_agg_cte(
self,
transform_keys: list[str],
table_store: TableStore,
keys: dict[str, FieldAccessor],
keys: dict[str, str],
filters_idx: IndexDF | None = None,
run_config: RunConfig | None = None,
) -> tuple[list[str], Any]:
"""
Create a CTE that aggregates the table by transform keys, applies keys
aliasing and returns the maximum update_ts for each group.

* `keys` is a mapping from transform key to table key accessor
(can be string for meta table column or DataField for data table
column)
* `keys` is a mapping from transform key to meta table column name
* `transform_keys` is a list of keys used in the transformation

CTE has the following columns:
Expand All @@ -379,41 +375,22 @@ def get_agg_cte(
present in primary keys of this CTE
"""

from datapipe.store.database import TableStoreDB

meta_table = self.sql_table
data_table = None

key_cols: list[Any] = []
cte_transform_keys: list[str] = []
should_join_data_table = False

for transform_key in transform_keys:
# TODO convert to match when we deprecate Python 3.9
accessor = keys.get(transform_key, transform_key)
if isinstance(accessor, str):
if accessor in self.primary_keys:
key_cols.append(meta_table.c[accessor].label(transform_key))
cte_transform_keys.append(transform_key)
elif isinstance(accessor, DataField):
should_join_data_table = True
assert isinstance(table_store, TableStoreDB)
data_table = table_store.data_table

key_cols.append(data_table.c[accessor.field_name].label(transform_key))
if accessor in self.primary_keys:
key_cols.append(meta_table.c[accessor].label(transform_key))
cte_transform_keys.append(transform_key)

sql: Any = sa.select(*key_cols + [sa.func.max(meta_table.c["update_ts"]).label("update_ts")]).select_from(
meta_table
)

if should_join_data_table:
assert data_table is not None
sql = sql.join(
data_table,
sa.and_(*[meta_table.c[pk] == data_table.c[pk] for pk in self.primary_keys]),
)

if len(key_cols) > 0:
sql = sql.group_by(*key_cols)

Expand Down
15 changes: 2 additions & 13 deletions datapipe/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,26 +56,15 @@
TableOrName = Union[str, OrmTable, "Table"]


@dataclass
class DataField:
field_name: str


FieldAccessor = Union[str, DataField]


@dataclass
class InputSpec:
table: TableOrName

# If provided, this dict tells how to get key columns from meta and data tables
#
# Example: {"idx_col": DataField("data_col")} means that to get idx_col value
# we should read data_col from data table
# If provided, this dict tells how to get key columns from meta table
#
# Example: {"idx_col": "meta_col"} means that to get idx_col value
# we should read meta_col from meta table
keys: dict[str, FieldAccessor] | None = None
keys: dict[str, str] | None = None


@dataclass
Expand Down
28 changes: 8 additions & 20 deletions design-docs/2025-12-key-mapping.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ table do not match by name.

# Use case

You have tables `User (id: PK)` and `Subscription (id: PK, user_id: DATA, sub_user_id: DATA)`
You have tables `User (id: PK)` and `Subscription (id: PK, user_id: PK, sub_user_id: PK)`
You need to enrich both sides of `Subscription` with information

You might write:
Expand All @@ -27,8 +27,8 @@ BatchTransform(
# every table should have a way to join to these keys
transform_keys=["user_id", "sub_user_id"],
inputs=[
# Subscription has needed columns in data table, we fetch them from there
InputSpec(Subscription, keys={"user_id": DataField("user_id"), "sub_user_id": DataField("sub_user_id")}),
# Subscription has user_id and sub_user_id as primary keys
InputSpec(Subscription, keys={"user_id": "user_id", "sub_user_id": "sub_user_id"}),

# matches tr.user_id = User.id
InputSpec(User, keys={"user_id": "id"}),
Expand All @@ -54,26 +54,15 @@ without renamings, it is up to end user to interpret the data.
We introduce `InputSpec` qualifier for `BatchTransform` inputs.

`keys` parameter defines which columns to use for this input table and where to
get them from. `keys` is a dict in a form `{"{transform_key}": key_accessor}`,
where `key_accessor` might be:
* a string, then a column from meta-table is used with possible renaming
* `DataField("data_col")` then a `data_col` from data-table is used instead of
meta-table
get them from. `keys` is a dict in a form `{"{transform_key}": "meta_table_col"}`,
where `meta_table_col` is a string referencing a column from the meta table
(i.e. a primary key column).

If table is provided as is without `InputSpec` wrapper, then it is equivalent to
`InputSpec(Table, join_type="outer", keys={"id1": "id1", ...})`, join type is
outer join and all keys are mapped to themselves.

## DataField limitations

`DataField` accessor serves as an ad-hoc solution for a situation when for some
reason a data field can not be promoted to a meta-field.

Data fields are not used when retrieving a chunk of data, so it is possible to
over-fetch data.

Data fields are not enforced to have indices in DB, so their usage might be very
heavy for database.
Note: transform key columns must always be primary keys in the table schema.


# Implementation
Expand All @@ -97,8 +86,7 @@ heavy for database.

`BatchTransform`:
* [x] correctly converts transform idx to table idx in `get_batch_input_dfs`
* [x] inputs and outputs are stored as `ComputeInput` lists, because we need
data table for `DataField`
* [x] inputs and outputs are stored as `ComputeInput` lists

`DataTable`:
* [x] `DataTable.get_data` accepts `table_idx` which is acquired by applying
Expand Down
5 changes: 2 additions & 3 deletions tests/test_meta_transform_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from datapipe.step.batch_transform import BatchTransformStep
from datapipe.store.database import DBConn, TableStoreDB
from datapipe.tests.util import assert_datatable_equal
from datapipe.types import DataField


def test_transform_keys(dbconn: DBConn):
Expand All @@ -26,7 +25,7 @@ def test_transform_keys(dbconn: DBConn):
"posts",
[
Column("id", String, primary_key=True),
Column("user_id", String),
Column("user_id", String, primary_key=True),
Column("content", String),
],
create_table=True,
Expand Down Expand Up @@ -97,7 +96,7 @@ def transform_func(posts_df, profiles_df):
join_type="full",
keys={
"post_id": "id",
"user_id": DataField("user_id"),
"user_id": "user_id",
},
),
ComputeInput(
Expand Down
Loading