diff --git a/CHANGELOG.md b/CHANGELOG.md index c6b8fdb5..07c1d5f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,7 +21,6 @@ See [key-mapping.md](design-docs/2025-12-key-mapping.md) for motivation * Renamed `JoinSpec` to `InputSpec` * Added `keys` parameter to `InputSpec` and `ComputeInput` to support joining tables with different key names -* Added `DataField` accessor for `InputSpec.keys` ### Python3.9 support is deprecated diff --git a/datapipe/compute.py b/datapipe/compute.py index 5cfe83fd..b696dd9b 100644 --- a/datapipe/compute.py +++ b/datapipe/compute.py @@ -12,7 +12,7 @@ from datapipe.run_config import RunConfig from datapipe.store.database import TableStoreDB from datapipe.store.table_store import TableStore -from datapipe.types import ChangeList, DataField, FieldAccessor, IndexDF, Labels, MetaSchema, TableOrName +from datapipe.types import ChangeList, IndexDF, Labels, MetaSchema, TableOrName logger = logging.getLogger("datapipe.compute") tracer = trace.get_tracer("datapipe.compute") @@ -87,14 +87,11 @@ class ComputeInput: dt: DataTable join_type: Literal["inner", "full"] = "full" - # If provided, this dict tells how to get key columns from meta and data tables - # - # Example: {"idx_col": DataField("data_col")} means that to get idx_col value - # we should read data_col from data table + # If provided, this dict tells how to get key columns from meta table # # Example: {"idx_col": "meta_col"} means that to get idx_col value # we should read meta_col from meta table - keys: dict[str, FieldAccessor] | None = None + keys: dict[str, str] | None = None @property def primary_keys(self) -> list[str]: @@ -107,23 +104,13 @@ def primary_keys(self) -> list[str]: def primary_schema(self) -> MetaSchema: if self.keys: primary_schema_dict = {col.name: col for col in self.dt.primary_schema} - data_schema_dict = {col.name: col for col in self.dt.table_store.get_schema()} schema = [] for k, accessor in self.keys.items(): - if isinstance(accessor, str): - source_column 
= primary_schema_dict[accessor] - column_alias = k - elif isinstance(accessor, DataField): - source_column = data_schema_dict[accessor.field_name] - column_alias = k - schema.append(data_schema_dict[accessor.field_name]) - else: - raise ValueError(f"Unknown accessor type: {type(accessor)}") - + source_column = primary_schema_dict[accessor] schema.append( Column( - column_alias, + k, source_column.type, primary_key=source_column.primary_key, ) diff --git a/datapipe/meta/base.py b/datapipe/meta/base.py index 1060cb01..7ab60cfd 100644 --- a/datapipe/meta/base.py +++ b/datapipe/meta/base.py @@ -5,7 +5,7 @@ from sqlalchemy import Column from datapipe.run_config import RunConfig -from datapipe.types import ChangeList, DataSchema, FieldAccessor, HashDF, IndexDF, MetadataDF, MetaSchema +from datapipe.types import ChangeList, DataSchema, HashDF, IndexDF, MetadataDF, MetaSchema if TYPE_CHECKING: from datapipe.compute import ComputeInput @@ -172,13 +172,13 @@ def reset_metadata( def transform_idx_to_table_idx( self, transform_idx: IndexDF, - keys: dict[str, FieldAccessor] | None = None, + keys: dict[str, str] | None = None, ) -> IndexDF: """ Given an index dataframe with transform keys, return an index dataframe with table keys, applying `keys` aliasing if provided. 
- * `keys` is a mapping from table key to transform key + * `keys` is a mapping from transform key to table key """ if keys is None: @@ -186,11 +186,9 @@ def transform_idx_to_table_idx( table_key_cols: dict[str, pd.Series] = {} for transform_col in transform_idx.columns: - accessor = keys.get(transform_col) if keys is not None else transform_col - if isinstance(accessor, str): - table_key_cols[accessor] = transform_idx[transform_col] - else: - pass # skip non-meta fields + table_col = keys.get(transform_col) + if table_col is not None: + table_key_cols[table_col] = transform_idx[transform_col] return IndexDF(pd.DataFrame(table_key_cols)) diff --git a/datapipe/meta/sql_meta.py b/datapipe/meta/sql_meta.py index 6e26f628..dc0c94fb 100644 --- a/datapipe/meta/sql_meta.py +++ b/datapipe/meta/sql_meta.py @@ -26,9 +26,7 @@ from datapipe.types import ( ChangeList, DataDF, - DataField, DataSchema, - FieldAccessor, HashDF, IndexDF, MetadataDF, @@ -357,7 +355,7 @@ def get_agg_cte( self, transform_keys: list[str], table_store: TableStore, - keys: dict[str, FieldAccessor], + keys: dict[str, str], filters_idx: IndexDF | None = None, run_config: RunConfig | None = None, ) -> tuple[list[str], Any]: @@ -365,9 +363,7 @@ def get_agg_cte( Create a CTE that aggregates the table by transform keys, applies keys aliasing and returns the maximum update_ts for each group. 
- * `keys` is a mapping from transform key to table key accessor - (can be string for meta table column or DataField for data table - column) + * `keys` is a mapping from transform key to meta table column name * `transform_keys` is a list of keys used in the transformation CTE has the following columns: @@ -379,41 +375,22 @@ def get_agg_cte( present in primary keys of this CTE """ - from datapipe.store.database import TableStoreDB - meta_table = self.sql_table - data_table = None key_cols: list[Any] = [] cte_transform_keys: list[str] = [] - should_join_data_table = False for transform_key in transform_keys: # TODO convert to match when we deprecate Python 3.9 accessor = keys.get(transform_key, transform_key) - if isinstance(accessor, str): - if accessor in self.primary_keys: - key_cols.append(meta_table.c[accessor].label(transform_key)) - cte_transform_keys.append(transform_key) - elif isinstance(accessor, DataField): - should_join_data_table = True - assert isinstance(table_store, TableStoreDB) - data_table = table_store.data_table - - key_cols.append(data_table.c[accessor.field_name].label(transform_key)) + if accessor in self.primary_keys: + key_cols.append(meta_table.c[accessor].label(transform_key)) cte_transform_keys.append(transform_key) sql: Any = sa.select(*key_cols + [sa.func.max(meta_table.c["update_ts"]).label("update_ts")]).select_from( meta_table ) - if should_join_data_table: - assert data_table is not None - sql = sql.join( - data_table, - sa.and_(*[meta_table.c[pk] == data_table.c[pk] for pk in self.primary_keys]), - ) - if len(key_cols) > 0: sql = sql.group_by(*key_cols) diff --git a/datapipe/types.py b/datapipe/types.py index fe553c3a..55a32a96 100644 --- a/datapipe/types.py +++ b/datapipe/types.py @@ -56,26 +56,15 @@ TableOrName = Union[str, OrmTable, "Table"] -@dataclass -class DataField: - field_name: str - - -FieldAccessor = Union[str, DataField] - - @dataclass class InputSpec: table: TableOrName - # If provided, this dict tells how to get 
key columns from meta and data tables - # - # Example: {"idx_col": DataField("data_col")} means that to get idx_col value - # we should read data_col from data table + # If provided, this dict tells how to get key columns from meta table # # Example: {"idx_col": "meta_col"} means that to get idx_col value # we should read meta_col from meta table - keys: dict[str, FieldAccessor] | None = None + keys: dict[str, str] | None = None @dataclass diff --git a/design-docs/2025-12-key-mapping.md b/design-docs/2025-12-key-mapping.md index 3bfc1a54..c7ff03c2 100644 --- a/design-docs/2025-12-key-mapping.md +++ b/design-docs/2025-12-key-mapping.md @@ -14,7 +14,7 @@ table do not match by name. # Use case -You have tables `User (id: PK)` and `Subscription (id: PK, user_id: DATA, sub_user_id: DATA)` +You have tables `User (id: PK)` and `Subscription (id: PK, user_id: PK, sub_user_id: PK)` You need to enrich both sides of `Subscription` with information You might write: @@ -27,8 +27,8 @@ BatchTransform( # every table should have a way to join to these keys transform_keys=["user_id", "sub_user_id"], inputs=[ - # Subscription has needed columns in data table, we fetch them from there - InputSpec(Subscription, keys={"user_id": DataField("user_id"), "sub_user_id": DataField("sub_user_id")}), + # Subscription has user_id and sub_user_id as primary keys + InputSpec(Subscription, keys={"user_id": "user_id", "sub_user_id": "sub_user_id"}), # matches tr.user_id = User.id InputSpec(User, keys={"user_id": "id"}), @@ -54,26 +54,15 @@ without renamings, it is up to end user to interpret the data. We introduce `InputSpec` qualifier for `BatchTransform` inputs. `keys` parameter defines which columns to use for this input table and where to -get them from. 
`keys` is a dict in a form `{"{transform_key}": key_accessor}`, -where `key_accessor` might be: -* a string, then a column from meta-table is used with possible renaming -* `DataField("data_col")` then a `data_col` from data-table is used instead of - meta-table +get them from. `keys` is a dict of the form `{"{transform_key}": "meta_table_col"}`, +where `meta_table_col` is a string referencing a column from the meta table +(i.e. a primary key column). If table is provided as is without `InputSpec` wrapper, then it is equivalent to `InputSpec(Table, join_type="outer", keys={"id1": "id1", ...})`, join type is outer join and all keys are mapped to themselves. -## DataField limitations - -`DataField` accessor serves as an ad-hoc solution for a situation when for some -reason a data field can not be promoted to a meta-field. - -Data fields are not used when retreiving a chunk of data, so it is possible to -over-fetch data. - -Data fields are not enforced to have indices in DB, so their usage might be very -heavy for database. +Note: every column that `keys` maps a transform key to must be a primary key column of that input table; columns stored only in the data table cannot be used as transform keys. # Implementation @@ -97,8 +86,7 @@ heavy for database.
`BatchTransform`: * [x] correctly converts transform idx to table idx in `get_batch_input_dfs` -* [x] inputs and outputs are stored as `ComputeInput` lists, because we need - data table for `DataField` +* [x] inputs and outputs are stored as `ComputeInput` lists `DataTable`: * [x] `DataTable.get_data` accepts `table_idx` which is acquired by applying diff --git a/tests/test_meta_transform_keys.py b/tests/test_meta_transform_keys.py index 6b25ca10..9ddd1c68 100644 --- a/tests/test_meta_transform_keys.py +++ b/tests/test_meta_transform_keys.py @@ -8,7 +8,6 @@ from datapipe.step.batch_transform import BatchTransformStep from datapipe.store.database import DBConn, TableStoreDB from datapipe.tests.util import assert_datatable_equal -from datapipe.types import DataField def test_transform_keys(dbconn: DBConn): @@ -26,7 +25,7 @@ def test_transform_keys(dbconn: DBConn): "posts", [ Column("id", String, primary_key=True), - Column("user_id", String), + Column("user_id", String, primary_key=True), Column("content", String), ], create_table=True, @@ -97,7 +96,7 @@ def transform_func(posts_df, profiles_df): join_type="full", keys={ "post_id": "id", - "user_id": DataField("user_id"), + "user_id": "user_id", }, ), ComputeInput(