From d28e3e0187da8b59d8b7db59b65d8fc45356da41 Mon Sep 17 00:00:00 2001 From: Mark Gordon Date: Thu, 10 Apr 2025 18:40:28 -0700 Subject: [PATCH] Use group by instead of distinct --- subsetter/__main__.py | 2 +- subsetter/plan_model.py | 7 +++++-- subsetter/sampler.py | 13 ++++++++++++- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/subsetter/__main__.py b/subsetter/__main__.py index a04e300..d18eb16 100644 --- a/subsetter/__main__.py +++ b/subsetter/__main__.py @@ -158,7 +158,7 @@ def _main_plan(args): ctx = open(args.plan_output, "w", encoding="utf-8") with ctx as fplan: yaml.dump( - plan.dict(exclude_unset=True, by_alias=True), + plan.model_dump(exclude_unset=True, by_alias=True), stream=fplan, default_flow_style=False, width=2**20, diff --git a/subsetter/plan_model.py b/subsetter/plan_model.py index 1f090d6..eaa86ff 100644 --- a/subsetter/plan_model.py +++ b/subsetter/plan_model.py @@ -266,7 +266,7 @@ def build(self, context: SQLBuildContext): for join in self.joins: # pylint: disable=not-an-iterable right = join.right.build(context).alias() - if join.half_unique: + if join.half_unique and table_obj.primary_key: joined = joined.join( right, onclause=sa.and_( @@ -294,7 +294,10 @@ def build(self, context: SQLBuildContext): ) ) - stmt = stmt.select_from(joined).distinct() + stmt = stmt.select_from(joined) + if joined is not table_obj: + stmt = stmt.group_by(*table_obj.primary_key.columns) + if self.joins_outer: exists_constraints.extend(col.is_not(None) for col in joined_cols) stmt = stmt.where(sa.or_(*exists_constraints)) diff --git a/subsetter/sampler.py b/subsetter/sampler.py index 35fd5ab..6333ea8 100644 --- a/subsetter/sampler.py +++ b/subsetter/sampler.py @@ -70,6 +70,7 @@ def create( select: sa.Select, *, name: str = "", + primary_key: Tuple[str, ...] = (), indexes: Iterable[Tuple[str, ...]] = (), ) -> Tuple[sa.Table, int]: """ @@ -83,6 +84,9 @@ def create( schema: The schema to create the temporary table within. For some dialects temporary tables always exist in their own schema and this parameter will be ignored. + primary_key: If set will mark the set of columns passed as primary keys in + the temporary table. This tuple should match a subset of the + column names in the select query. indexes: creates an index on each tuple of columns listed. This is useful if future queries are likely to reference these columns. @@ -106,7 +110,10 @@ def create( metadata, schema=temp_schema, prefixes=["TEMPORARY"], - *(sa.Column(col.name, col.type) for col in select.selected_columns), + *( + sa.Column(col.name, col.type, primary_key=col.name in primary_key) + for col in select.selected_columns + ), ) try: metadata.create_all(conn) @@ -120,6 +127,8 @@ def create( raise for idx, index_cols in enumerate(indexes): + if index_cols == primary_key: + continue # For some dialects/data types we may not be able to construct an index. We just do our # best here instead of hard failing. try: @@ -891,6 +900,7 @@ def _materialize_tables( schema, table_q, name=table_name, + primary_key=table.primary_key, indexes=joined_columns[(schema, table_name)], ) ) @@ -914,6 +924,7 @@ def _materialize_tables( schema, meta.temp_tables[(schema, table_name, 0)].select(), name=table_name, + primary_key=table.primary_key, indexes=joined_columns[(schema, table_name)], ) )