From 347d19353a5cb7f38c51c5418ba330e98eedc3c4 Mon Sep 17 00:00:00 2001 From: Conrad Date: Mon, 30 Mar 2026 22:41:20 -0400 Subject: [PATCH 01/20] feat: Transpile INTERSECTS joins to binned equi-joins Column-to-column INTERSECTS joins (e.g., a.interval INTERSECTS b.interval) are now rewritten into binned equi-joins using CTEs with UNNEST(range(...)) bin assignments. This gives the query planner an equi-join key to work with instead of forcing a nested-loop or cross join. The bin size defaults to 10,000 and is configurable via the new bin_size parameter on transpile(). Literal-range INTERSECTS filters remain unchanged. --- src/giql/transformer.py | 216 +++++++++++++++++++++++++++++++++++++++- src/giql/transpile.py | 24 ++++- 2 files changed, 238 insertions(+), 2 deletions(-) diff --git a/src/giql/transformer.py b/src/giql/transformer.py index de1e70f..2d11cf5 100644 --- a/src/giql/transformer.py +++ b/src/giql/transformer.py @@ -1,7 +1,8 @@ """Query transformers for GIQL operations. This module contains transformers that rewrite queries containing GIQL-specific -operations (like CLUSTER and MERGE) into equivalent SQL with CTEs. +operations (like CLUSTER, MERGE, and binned INTERSECTS joins) into equivalent +SQL with CTEs. """ from sqlglot import exp @@ -12,8 +13,11 @@ from giql.constants import DEFAULT_STRAND_COL from giql.expressions import GIQLCluster from giql.expressions import GIQLMerge +from giql.expressions import Intersects from giql.table import Tables +DEFAULT_BIN_SIZE = 10_000 + class ClusterTransformer: """Transforms queries containing CLUSTER into CTE-based queries. @@ -573,3 +577,213 @@ def _transform_for_merge( ) return final_query + + +class IntersectsBinnedJoinTransformer: + """Transforms column-to-column INTERSECTS joins into binned equi-joins. + + Rewrites: + + SELECT a.*, b.* + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.region + + Into: + + WITH __giql_a_binned AS ( + SELECT *, UNNEST(range( + CAST("start" / B AS BIGINT), + CAST(("end" - 1) / B + 1 AS BIGINT) + )) AS __giql_bin FROM peaks + ), + __giql_b_binned AS ( + SELECT *, UNNEST(range( + CAST("start" / B AS BIGINT), + CAST(("end" - 1) / B + 1 AS BIGINT) + )) AS __giql_bin FROM genes + ) + SELECT DISTINCT a.*, b.* + FROM __giql_a_binned AS a + JOIN __giql_b_binned AS b + ON a."chrom" = b."chrom" AND a.__giql_bin = b.__giql_bin + WHERE a."start" < b."end" AND a."end" > b."start" + + Literal-range INTERSECTS (e.g., WHERE clause filters) are left untouched. + """ + + def __init__(self, tables: Tables, bin_size: int | None = None): + self.tables = tables + self.bin_size = bin_size if bin_size is not None else DEFAULT_BIN_SIZE + + def transform(self, query: exp.Expression) -> exp.Expression: + if not isinstance(query, exp.Select): + return query + + joins = query.args.get("joins") + if not joins: + return query + + for join in joins: + intersects = self._find_column_intersects(join) + if intersects: + return self._rewrite(query, join, intersects) + + return query + + def _find_column_intersects(self, join: exp.Join) -> Intersects | None: + """Find an Intersects node in the JOIN ON where both sides are column refs.""" + on = join.args.get("on") + if not on: + return None + for node in on.find_all(Intersects): + if ( + isinstance(node.this, exp.Column) + and node.this.table + and isinstance(node.expression, exp.Column) + and node.expression.table + ): + return node + return None + + def _get_columns(self, table_name: str) -> tuple[str, str, str]: + """Return (chrom, start, end) column names for a table.""" + table = self.tables.get(table_name) + if table: + return (table.chrom_col, table.start_col, table.end_col) + return (DEFAULT_CHROM_COL, DEFAULT_START_COL, DEFAULT_END_COL) + + def _build_binned_select( + self, table_name: str, cols: tuple[str, str, str] + ) -> exp.Select: + """Build ``SELECT *, UNNEST(range(...)) AS __giql_bin FROM ``.""" + _chrom, start, end = cols + B = self.bin_size + + low = exp.Cast( + this=exp.Div( + this=exp.column(start, quoted=True), + expression=exp.Literal.number(B), + ), + to=exp.DataType(this=exp.DataType.Type.BIGINT), + ) + high = exp.Cast( + this=exp.Add( + this=exp.Div( + this=exp.Paren( + this=exp.Sub( + this=exp.column(end, quoted=True), + expression=exp.Literal.number(1), + ), + ), + expression=exp.Literal.number(B), + ), + expression=exp.Literal.number(1), + ), + to=exp.DataType(this=exp.DataType.Type.BIGINT), + ) + + range_fn = exp.Anonymous(this="range", expressions=[low, high]) + unnest_fn = exp.Anonymous(this="UNNEST", expressions=[range_fn]) + bin_alias = exp.Alias( + this=unnest_fn, + alias=exp.Identifier(this="__giql_bin"), + ) + + select = exp.Select() + select.select(exp.Star(), copy=False) + select.select(bin_alias, append=True, copy=False) + select.from_(exp.Table(this=exp.Identifier(this=table_name)), copy=False) + return select + + def _rewrite( + self, + query: exp.Select, + join: exp.Join, + intersects: Intersects, + ) -> exp.Select: + from_table = query.args["from_"].this + join_table = join.this + + if not isinstance(from_table, exp.Table) or not isinstance( + join_table, exp.Table + ): + return query + + from_name = from_table.name + join_name = join_table.name + from_alias = from_table.alias or from_name + join_alias = join_table.alias or join_name + + from_cols = self._get_columns(from_name) + join_cols = self._get_columns(join_name) + + # Use alias-based CTE names to avoid collisions on self-joins + from_cte_name = f"__giql_{from_alias}_binned" + join_cte_name = f"__giql_{join_alias}_binned" + + # Build CTEs + from_cte = exp.CTE( + this=self._build_binned_select(from_name, from_cols), + alias=exp.TableAlias(this=exp.Identifier(this=from_cte_name)), + ) + join_cte = exp.CTE( + this=self._build_binned_select(join_name, join_cols), + alias=exp.TableAlias(this=exp.Identifier(this=join_cte_name)), + ) + + # Add CTEs + existing_with = query.args.get("with_") + if existing_with: + existing_with.append("expressions", from_cte) + existing_with.append("expressions", join_cte) + else: + query.set("with_", exp.With(expressions=[from_cte, join_cte])) + + # Replace FROM table reference with CTE (preserve alias) + new_from = exp.Table( + this=exp.Identifier(this=from_cte_name), + alias=exp.TableAlias(this=exp.Identifier(this=from_alias)), + ) + query.args["from_"].set("this", new_from) + + # Replace JOIN table reference with CTE (preserve alias) + new_join_table = exp.Table( + this=exp.Identifier(this=join_cte_name), + alias=exp.TableAlias(this=exp.Identifier(this=join_alias)), + ) + join.set("this", new_join_table) + + # Replace JOIN ON with equi-join on chrom + bin + chrom_eq = exp.EQ( + this=exp.column(from_cols[0], table=from_alias, quoted=True), + expression=exp.column(join_cols[0], table=join_alias, quoted=True), + ) + bin_eq = exp.EQ( + this=exp.column("__giql_bin", table=from_alias), + expression=exp.column("__giql_bin", table=join_alias), + ) + join.set("on", exp.And(this=chrom_eq, expression=bin_eq)) + + # Add overlap filter to WHERE + overlap = exp.And( + this=exp.LT( + this=exp.column(from_cols[1], table=from_alias, quoted=True), + expression=exp.column(join_cols[2], table=join_alias, quoted=True), + ), + expression=exp.GT( + this=exp.column(from_cols[2], table=from_alias, quoted=True), + expression=exp.column(join_cols[1], table=join_alias, quoted=True), + ), + ) + + existing_where = query.args.get("where") + if existing_where: + merged = exp.And(this=existing_where.this, expression=overlap) + existing_where.set("this", merged) + else: + query.set("where", exp.Where(this=overlap)) + + # Add DISTINCT to deduplicate rows that appear in multiple bins + query.set("distinct", exp.Distinct()) + + return query diff --git a/src/giql/transpile.py b/src/giql/transpile.py index 2b29c3d..df0264f 100644 --- a/src/giql/transpile.py +++ b/src/giql/transpile.py @@ -11,6 +11,7 @@ from giql.table import Table from giql.table import Tables from giql.transformer import ClusterTransformer +from giql.transformer import IntersectsBinnedJoinTransformer from giql.transformer import MergeTransformer @@ -45,6 +46,7 @@ def _build_tables(tables: list[str | Table] | None) -> Tables: def transpile( giql: str, tables: list[str | Table] | None = None, + bin_size: int | None = None, ) -> str: """Transpile a GIQL query to SQL. @@ -60,6 +62,11 @@ def transpile( Table configurations. Strings use default column mappings (chrom, start, end, strand). Table objects provide custom column name mappings. + bin_size : int | None + Bin size for INTERSECTS equi-join optimization. When a query + contains a full-table column-to-column INTERSECTS join, the + transpiler rewrites it as a binned equi-join for performance. + Defaults to 10,000 if not specified. Returns ------- @@ -94,11 +101,24 @@ def transpile( ) ], ) + + Binned equi-join with custom bin size:: + + sql = transpile( + "SELECT a.*, b.* FROM peaks a JOIN genes b " + "ON a.interval INTERSECTS b.interval", + tables=["peaks", "genes"], + bin_size=100000, + ) """ # Build tables container tables_container = _build_tables(tables) # Initialize transformers with table configurations + intersects_transformer = IntersectsBinnedJoinTransformer( + tables_container, + bin_size=bin_size, + ) merge_transformer = MergeTransformer(tables_container) cluster_transformer = ClusterTransformer(tables_container) @@ -111,8 +131,10 @@ def transpile( except Exception as e: raise ValueError(f"Parse error: {e}\nQuery: {giql}") from e - # Apply transformations (MERGE first, then CLUSTER) + # Apply transformations try: + # Binned join rewrite for column-to-column INTERSECTS joins + ast = intersects_transformer.transform(ast) # MERGE transformation (which may internally use CLUSTER) ast = merge_transformer.transform(ast) # CLUSTER transformation for any standalone CLUSTER expressions From 8465472fe50486b1af0512874c0064bd645d3666 Mon Sep 17 00:00:00 2001 From: Conrad Date: Mon, 30 Mar 2026 23:47:11 -0400 Subject: [PATCH 02/20] build: Add datafusion to dev dependencies Needed for end-to-end correctness tests that validate the binned equi-join SQL against DataFusion's query engine. --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 480cf1e..d5d9129 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ dev = [ "pytest-cov>=4.0.0", "pytest>=7.0.0", "ruff>=0.1.0", + "datafusion>=52.3.0", ] docs = [ "sphinx>=7.0", From 1c9cc08b65a4ba2ece55daa8ec00a610e7f990a6 Mon Sep 17 00:00:00 2001 From: Conrad Date: Mon, 30 Mar 2026 23:47:16 -0400 Subject: [PATCH 03/20] feat: Extend binned join rewrite to implicit cross-joins The transformer now detects column-to-column INTERSECTS in WHERE clauses (FROM a, b WHERE a.interval INTERSECTS b.interval), not just in explicit JOIN ON conditions. Both patterns are rewritten to binned equi-joins for the same performance benefit. --- src/giql/transformer.py | 209 ++++++++++++++++++++++++++++++---------- 1 file changed, 159 insertions(+), 50 deletions(-) diff --git a/src/giql/transformer.py b/src/giql/transformer.py index 2d11cf5..18af2de 100644 --- a/src/giql/transformer.py +++ b/src/giql/transformer.py @@ -580,15 +580,19 @@ def _transform_for_merge( class IntersectsBinnedJoinTransformer: - """Transforms column-to-column INTERSECTS joins into binned equi-joins. + """Transforms column-to-column INTERSECTS into binned equi-joins. - Rewrites: + Handles both explicit JOIN ON and implicit cross-join (WHERE) patterns: + -- Explicit JOIN SELECT a.*, b.* - FROM peaks a - JOIN genes b ON a.interval INTERSECTS b.region + FROM peaks a JOIN genes b ON a.interval INTERSECTS b.region - Into: + -- Implicit cross-join + SELECT a.*, b.* + FROM peaks a, genes b WHERE a.interval INTERSECTS b.interval + + Both are rewritten to: WITH __giql_a_binned AS ( SELECT *, UNNEST(range( @@ -596,19 +600,15 @@ class IntersectsBinnedJoinTransformer: CAST(("end" - 1) / B + 1 AS BIGINT) )) AS __giql_bin FROM peaks ), - __giql_b_binned AS ( - SELECT *, UNNEST(range( - CAST("start" / B AS BIGINT), - CAST(("end" - 1) / B + 1 AS BIGINT) - )) AS __giql_bin FROM genes - ) + __giql_b_binned AS (...) SELECT DISTINCT a.*, b.* FROM __giql_a_binned AS a JOIN __giql_b_binned AS b ON a."chrom" = b."chrom" AND a.__giql_bin = b.__giql_bin WHERE a."start" < b."end" AND a."end" > b."start" - Literal-range INTERSECTS (e.g., WHERE clause filters) are left untouched. + Literal-range INTERSECTS (e.g., ``WHERE interval INTERSECTS 'chr1:...'``) + are left untouched. """ def __init__(self, tables: Tables, bin_size: int | None = None): @@ -623,19 +623,28 @@ def transform(self, query: exp.Expression) -> exp.Expression: if not joins: return query + # Check explicit JOIN ON conditions for join in joins: - intersects = self._find_column_intersects(join) + on = join.args.get("on") + if on: + intersects = self._find_column_intersects_in(on) + if intersects: + return self._rewrite_join_on(query, join, intersects) + + # Check WHERE clause (implicit cross-join pattern) + where = query.args.get("where") + if where: + intersects = self._find_column_intersects_in(where.this) if intersects: - return self._rewrite(query, join, intersects) + join = self._find_join_for_intersects(query, intersects) + if join: + return self._rewrite_where(query, join, intersects) return query - def _find_column_intersects(self, join: exp.Join) -> Intersects | None: - """Find an Intersects node in the JOIN ON where both sides are column refs.""" - on = join.args.get("on") - if not on: - return None - for node in on.find_all(Intersects): + def _find_column_intersects_in(self, expr: exp.Expression) -> Intersects | None: + """Find an Intersects node where both sides are table-qualified columns.""" + for node in expr.find_all(Intersects): if ( isinstance(node.this, exp.Column) and node.this.table @@ -645,6 +654,34 @@ def _find_column_intersects(self, join: exp.Join) -> Intersects | None: return node return None + def _find_join_for_intersects( + self, query: exp.Select, intersects: Intersects + ) -> exp.Join | None: + """Find the Join node for the table referenced in an Intersects.""" + from_table = query.args["from_"].this + if not isinstance(from_table, exp.Table): + return None + + from_alias = from_table.alias or from_table.name + left_alias = intersects.this.table + right_alias = intersects.expression.table + + # Determine which alias is the join table (not the FROM table) + if left_alias == from_alias: + target_alias = right_alias + elif right_alias == from_alias: + target_alias = left_alias + else: + return None + + for join in query.args.get("joins", []): + if isinstance(join.this, exp.Table): + alias = join.this.alias or join.this.name + if alias == target_alias: + return join + + return None + def _get_columns(self, table_name: str) -> tuple[str, str, str]: """Return (chrom, start, end) column names for a table.""" table = self.tables.get(table_name) @@ -695,20 +732,17 @@ def _build_binned_select( select.from_(exp.Table(this=exp.Identifier(this=table_name)), copy=False) return select - def _rewrite( + def _install_binned_ctes( self, query: exp.Select, join: exp.Join, - intersects: Intersects, - ) -> exp.Select: - from_table = query.args["from_"].this - join_table = join.this - - if not isinstance(from_table, exp.Table) or not isinstance( - join_table, exp.Table - ): - return query + from_table: exp.Table, + join_table: exp.Table, + ) -> tuple[str, str, tuple[str, str, str], tuple[str, str, str]]: + """Create binned CTEs and update FROM/JOIN to reference them. + Returns (from_alias, join_alias, from_cols, join_cols). + """ from_name = from_table.name join_name = join_table.name from_alias = from_table.alias or from_name @@ -717,11 +751,9 @@ def _rewrite( from_cols = self._get_columns(from_name) join_cols = self._get_columns(join_name) - # Use alias-based CTE names to avoid collisions on self-joins from_cte_name = f"__giql_{from_alias}_binned" join_cte_name = f"__giql_{join_alias}_binned" - # Build CTEs from_cte = exp.CTE( this=self._build_binned_select(from_name, from_cols), alias=exp.TableAlias(this=exp.Identifier(this=from_cte_name)), @@ -731,7 +763,6 @@ def _rewrite( alias=exp.TableAlias(this=exp.Identifier(this=join_cte_name)), ) - # Add CTEs existing_with = query.args.get("with_") if existing_with: existing_with.append("expressions", from_cte) @@ -739,21 +770,33 @@ def _rewrite( else: query.set("with_", exp.With(expressions=[from_cte, join_cte])) - # Replace FROM table reference with CTE (preserve alias) - new_from = exp.Table( - this=exp.Identifier(this=from_cte_name), - alias=exp.TableAlias(this=exp.Identifier(this=from_alias)), + query.args["from_"].set( + "this", + exp.Table( + this=exp.Identifier(this=from_cte_name), + alias=exp.TableAlias(this=exp.Identifier(this=from_alias)), + ), ) - query.args["from_"].set("this", new_from) - - # Replace JOIN table reference with CTE (preserve alias) - new_join_table = exp.Table( - this=exp.Identifier(this=join_cte_name), - alias=exp.TableAlias(this=exp.Identifier(this=join_alias)), + join.set( + "this", + exp.Table( + this=exp.Identifier(this=join_cte_name), + alias=exp.TableAlias(this=exp.Identifier(this=join_alias)), + ), ) - join.set("this", new_join_table) - # Replace JOIN ON with equi-join on chrom + bin + query.set("distinct", exp.Distinct()) + + return from_alias, join_alias, from_cols, join_cols + + def _build_equi_join( + self, + from_alias: str, + join_alias: str, + from_cols: tuple[str, str, str], + join_cols: tuple[str, str, str], + ) -> exp.And: + """Build ``chrom = chrom AND __giql_bin = __giql_bin``.""" chrom_eq = exp.EQ( this=exp.column(from_cols[0], table=from_alias, quoted=True), expression=exp.column(join_cols[0], table=join_alias, quoted=True), @@ -762,10 +805,17 @@ def _rewrite( this=exp.column("__giql_bin", table=from_alias), expression=exp.column("__giql_bin", table=join_alias), ) - join.set("on", exp.And(this=chrom_eq, expression=bin_eq)) + return exp.And(this=chrom_eq, expression=bin_eq) - # Add overlap filter to WHERE - overlap = exp.And( + def _build_overlap_filter( + self, + from_alias: str, + join_alias: str, + from_cols: tuple[str, str, str], + join_cols: tuple[str, str, str], + ) -> exp.And: + """Build ``from.start < join.end AND from.end > join.start``.""" + return exp.And( this=exp.LT( this=exp.column(from_cols[1], table=from_alias, quoted=True), expression=exp.column(join_cols[2], table=join_alias, quoted=True), @@ -776,6 +826,32 @@ def _rewrite( ), ) + def _rewrite_join_on( + self, + query: exp.Select, + join: exp.Join, + intersects: Intersects, + ) -> exp.Select: + """Rewrite an explicit ``JOIN ... ON ... INTERSECTS ...``.""" + from_table = query.args["from_"].this + join_table = join.this + + if not isinstance(from_table, exp.Table) or not isinstance( + join_table, exp.Table + ): + return query + + from_alias, join_alias, from_cols, join_cols = self._install_binned_ctes( + query, join, from_table, join_table + ) + + join.set( + "on", self._build_equi_join(from_alias, join_alias, from_cols, join_cols) + ) + + overlap = self._build_overlap_filter( + from_alias, join_alias, from_cols, join_cols + ) existing_where = query.args.get("where") if existing_where: merged = exp.And(this=existing_where.this, expression=overlap) @@ -783,7 +859,40 @@ def _rewrite( else: query.set("where", exp.Where(this=overlap)) - # Add DISTINCT to deduplicate rows that appear in multiple bins - query.set("distinct", exp.Distinct()) + return query + + def _rewrite_where( + self, + query: exp.Select, + join: exp.Join, + intersects: Intersects, + ) -> exp.Select: + """Rewrite an implicit cross-join ``FROM a, b WHERE ... INTERSECTS ...``.""" + from_table = query.args["from_"].this + join_table = join.this + + if not isinstance(from_table, exp.Table) or not isinstance( + join_table, exp.Table + ): + return query + + from_alias, join_alias, from_cols, join_cols = self._install_binned_ctes( + query, join, from_table, join_table + ) + + equi_join = self._build_equi_join(from_alias, join_alias, from_cols, join_cols) + overlap = self._build_overlap_filter( + from_alias, join_alias, from_cols, join_cols + ) + + # Replace the Intersects node in-place with equi-join + overlap + replacement = exp.And( + this=exp.And(this=equi_join, expression=overlap), + expression=exp.Paren(this=exp.Literal.number(1)), + ) + # Use Paren(1) as a truthy sentinel then clean it up — simpler + # to just build the full replacement directly: + replacement = exp.And(this=equi_join, expression=overlap) + intersects.replace(replacement) return query From d7b4a81f87a2015db41712aec7653d7fb1290654 Mon Sep 17 00:00:00 2001 From: Conrad Date: Mon, 30 Mar 2026 23:47:26 -0400 Subject: [PATCH 04/20] test: Add binned join unit and DataFusion correctness tests Covers both explicit JOIN ON and implicit cross-join patterns, custom bin sizes, custom column mappings, self-joins, literal range passthrough, and end-to-end correctness against DataFusion including multi-bin deduplication and equivalence with naive joins. --- tests/test_binned_join.py | 487 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 487 insertions(+) create mode 100644 tests/test_binned_join.py diff --git a/tests/test_binned_join.py b/tests/test_binned_join.py new file mode 100644 index 0000000..da71dcc --- /dev/null +++ b/tests/test_binned_join.py @@ -0,0 +1,487 @@ +"""Tests for the INTERSECTS binned equi-join transpilation.""" + +import pytest + +from giql import Table +from giql import transpile + + +class TestTranspileBinnedJoin: + """Unit tests for binned join SQL structure.""" + + def test_basic_binned_join_rewrite(self): + """ + GIVEN a GIQL query joining two tables with column-to-column INTERSECTS + WHEN transpiling with default settings + THEN should produce CTEs with UNNEST/range, equi-join ON chrom + __giql_bin, + WHERE overlap filter, and DISTINCT + """ + sql = transpile( + """ + SELECT a.*, b.* + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=["peaks", "genes"], + ) + + sql_upper = sql.upper() + + # CTEs with UNNEST and range + assert "WITH" in sql_upper + assert "UNNEST" in sql_upper + assert "range" in sql or "RANGE" in sql_upper + assert "__giql_bin" in sql + + # Equi-join on chrom and bin + assert '"chrom"' in sql + assert "__giql_bin" in sql + + # Overlap filter in WHERE + assert "WHERE" in sql_upper + + # DISTINCT to deduplicate across bins + assert "DISTINCT" in sql_upper + + def test_custom_bin_size(self): + """ + GIVEN a GIQL query with column-to-column INTERSECTS join + WHEN transpiling with bin_size=100000 + THEN should use 100000 in the range expressions + """ + sql = transpile( + """ + SELECT a.*, b.* + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=["peaks", "genes"], + bin_size=100000, + ) + + assert "100000" in sql + + def test_custom_column_mappings(self): + """ + GIVEN two tables with different custom column schemas + WHEN transpiling a binned join query + THEN should use each table's custom column names in CTEs, ON, and WHERE + """ + sql = transpile( + """ + SELECT a.*, b.* + FROM peaks a + JOIN features b ON a.interval INTERSECTS b.location + """, + tables=[ + Table( + "peaks", + genomic_col="interval", + chrom_col="chromosome", + start_col="start_pos", + end_col="end_pos", + ), + Table( + "features", + genomic_col="location", + chrom_col="seqname", + start_col="begin", + end_col="terminus", + ), + ], + ) + + # Custom column names for peaks + assert '"chromosome"' in sql + assert '"start_pos"' in sql + assert '"end_pos"' in sql + + # Custom column names for features + assert '"seqname"' in sql + assert '"begin"' in sql + assert '"terminus"' in sql + + # Default column names should NOT appear + assert '"chrom"' not in sql + assert '"start"' not in sql + assert '"end"' not in sql + + def test_literal_intersects_no_binned_ctes(self): + """ + GIVEN a GIQL query with a literal-range INTERSECTS in WHERE (not a join) + WHEN transpiling + THEN should NOT produce binned CTEs + """ + sql = transpile( + "SELECT * FROM peaks WHERE interval INTERSECTS 'chr1:1000-2000'", + tables=["peaks"], + ) + + assert "__giql_bin" not in sql + assert "UNNEST" not in sql.upper() + + def test_no_join_passthrough(self): + """ + GIVEN a simple SELECT query with no JOIN + WHEN transpiling + THEN should NOT produce binned CTEs + """ + sql = transpile( + "SELECT * FROM peaks", + tables=["peaks"], + ) + + assert "__giql_bin" not in sql + assert "UNNEST" not in sql.upper() + assert "__giql_" not in sql + + def test_existing_where_preserved(self): + """ + GIVEN a GIQL join query that already has a WHERE clause + WHEN transpiling a binned join + THEN should preserve the original WHERE condition alongside the overlap filter + """ + sql = transpile( + """ + SELECT a.*, b.* + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + WHERE a.score > 100 + """, + tables=["peaks", "genes"], + ) + + sql_upper = sql.upper() + + # Original WHERE condition preserved + assert "100" in sql + assert "score" in sql.lower() + + # Overlap filter also present + assert "WHERE" in sql_upper + # Both conditions combined with AND + assert "AND" in sql_upper + + def test_bin_size_none_defaults_to_10000(self): + """ + GIVEN a GIQL join query + WHEN transpiling with bin_size=None (explicit) + THEN should produce the same output as default (10000) + """ + sql_default = transpile( + """ + SELECT a.*, b.* + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=["peaks", "genes"], + ) + + sql_none = transpile( + """ + SELECT a.*, b.* + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=["peaks", "genes"], + bin_size=None, + ) + + assert sql_default == sql_none + assert "10000" in sql_default + + def test_implicit_cross_join_rewrite(self): + """ + GIVEN a GIQL query with implicit cross-join (FROM a, b WHERE INTERSECTS) + WHEN transpiling + THEN should produce binned CTEs and replace the INTERSECTS in WHERE + """ + sql = transpile( + """ + SELECT DISTINCT a.* + FROM peaks a, genes b + WHERE a.interval INTERSECTS b.interval + """, + tables=["peaks", "genes"], + ) + + sql_upper = sql.upper() + + # Binned CTEs present + assert "WITH" in sql_upper + assert "UNNEST" in sql_upper + assert "__giql_bin" in sql + + # Equi-join conditions in WHERE (not in ON for comma joins) + assert "WHERE" in sql_upper + + # Overlap filter present + assert '"start"' in sql + assert '"end"' in sql + + def test_self_join_distinct_ctes(self): + """ + GIVEN a self-join query where the same table appears with two aliases + WHEN transpiling a binned join + THEN should produce two distinct CTEs both referencing the same underlying table + """ + sql = transpile( + """ + SELECT a.*, b.* + FROM peaks a + JOIN peaks b ON a.interval INTERSECTS b.interval + """, + tables=["peaks"], + ) + + sql_upper = sql.upper() + + # Two distinct CTEs + assert "__giql_a_binned" in sql + assert "__giql_b_binned" in sql + + # Both reference the same underlying table + assert "FROM peaks" in sql or "FROM PEAKS" in sql_upper + + # Should still have DISTINCT and WHERE + assert "DISTINCT" in sql_upper + assert "WHERE" in sql_upper + + +class TestBinnedJoinDataFusion: + """End-to-end DataFusion correctness tests for binned INTERSECTS joins.""" + + @staticmethod + def _make_ctx(peaks_data, genes_data): + """Create a DataFusion context with two interval tables.""" + import pyarrow as pa + from datafusion import SessionContext + + schema = pa.schema( + [ + ("chrom", pa.utf8()), + ("start", pa.int64()), + ("end", pa.int64()), + ] + ) + + ctx = SessionContext() + + peaks_arrays = { + "chrom": [r[0] for r in peaks_data], + "start": [r[1] for r in peaks_data], + "end": [r[2] for r in peaks_data], + } + genes_arrays = { + "chrom": [r[0] for r in genes_data], + "start": [r[1] for r in genes_data], + "end": [r[2] for r in genes_data], + } + + ctx.register_record_batches( + "peaks", [pa.table(peaks_arrays, schema=schema).to_batches()] + ) + ctx.register_record_batches( + "genes", [pa.table(genes_arrays, schema=schema).to_batches()] + ) + return ctx + + def test_overlapping_intervals_correct_rows_no_duplicates(self): + """ + GIVEN two tables with overlapping intervals + WHEN executing a binned INTERSECTS join via DataFusion + THEN should return the correct matching rows with no duplicates + """ + ctx = self._make_ctx( + peaks_data=[("chr1", 100, 500), ("chr1", 1000, 2000)], + genes_data=[("chr1", 300, 600), ("chr1", 5000, 6000)], + ) + + sql = transpile( + """ + SELECT a.chrom, a.start, a."end", b.start AS b_start, b."end" AS b_end + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + df = ctx.sql(sql).to_pandas() + + # Only the first peak (100-500) overlaps the first gene (300-600) + assert len(df) == 1 + assert df.iloc[0]["start"] == 100 + assert df.iloc[0]["b_start"] == 300 + + def test_non_overlapping_intervals_zero_rows(self): + """ + GIVEN two tables with no overlapping intervals + WHEN executing a binned INTERSECTS join via DataFusion + THEN should return zero rows + """ + ctx = self._make_ctx( + peaks_data=[("chr1", 100, 200), ("chr1", 300, 400)], + genes_data=[("chr1", 500, 600), ("chr1", 700, 800)], + ) + + sql = transpile( + """ + SELECT a.chrom, a.start, a."end" + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + df = ctx.sql(sql).to_pandas() + assert len(df) == 0 + + def test_adjacent_intervals_zero_rows_half_open(self): + """ + GIVEN two tables with adjacent (touching) intervals under half-open coordinates + WHEN executing a binned INTERSECTS join via DataFusion + THEN should return zero rows because [100, 200) and [200, 300) do not overlap + """ + ctx = self._make_ctx( + peaks_data=[("chr1", 100, 200)], + genes_data=[("chr1", 200, 300)], + ) + + sql = transpile( + """ + SELECT a.chrom, a.start, a."end" + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + df = ctx.sql(sql).to_pandas() + assert len(df) == 0 + + def test_different_chromosomes_only_same_chrom(self): + """ + GIVEN two tables with intervals on different chromosomes that would overlap positionally + WHEN executing a binned INTERSECTS join via DataFusion + THEN should only return overlaps on the same chromosome + """ + ctx = self._make_ctx( + peaks_data=[("chr1", 100, 500), ("chr2", 100, 500)], + genes_data=[("chr1", 200, 400), ("chr3", 200, 400)], + ) + + sql = transpile( + """ + SELECT a.chrom, a.start, a."end", + b.chrom AS b_chrom, b.start AS b_start, b."end" AS b_end + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + df = ctx.sql(sql).to_pandas() + + assert len(df) == 1 + assert df.iloc[0]["chrom"] == "chr1" + assert df.iloc[0]["b_chrom"] == "chr1" + + def test_intervals_spanning_multiple_bins_no_duplicates(self): + """ + GIVEN intervals that span multiple bins + WHEN executing a binned INTERSECTS join via DataFusion + THEN overlapping pairs should be returned exactly once (DISTINCT dedup) + """ + ctx = self._make_ctx( + peaks_data=[("chr1", 0, 50000)], + genes_data=[("chr1", 25000, 75000)], + ) + + sql = transpile( + """ + SELECT a.chrom, a.start, a."end", b.start AS b_start, b."end" AS b_end + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + bin_size=10000, + ) + + df = ctx.sql(sql).to_pandas() + + # Despite sharing multiple bins (2, 3, 4), should appear exactly once + assert len(df) == 1 + assert df.iloc[0]["start"] == 0 + assert df.iloc[0]["end"] == 50000 + assert df.iloc[0]["b_start"] == 25000 + assert df.iloc[0]["b_end"] == 75000 + + def test_equivalence_with_naive_cross_join(self): + """ + GIVEN two tables with a mix of overlapping and non-overlapping intervals + WHEN executing a binned INTERSECTS join via DataFusion + THEN results should match a naive cross-join with overlap filter + """ + ctx = self._make_ctx( + peaks_data=[ + ("chr1", 0, 100), + ("chr1", 150, 300), + ("chr1", 500, 1000), + ("chr2", 0, 500), + ], + genes_data=[ + ("chr1", 50, 200), + ("chr1", 250, 600), + ("chr1", 900, 1100), + ("chr2", 400, 800), + ], + ) + + naive_sql = """ + SELECT a.chrom AS a_chrom, a.start AS a_start, a."end" AS a_end, + b.chrom AS b_chrom, b.start AS b_start, b."end" AS b_end + FROM peaks a, genes b + WHERE a.chrom = b.chrom + AND a.start < b."end" + AND a."end" > b.start + ORDER BY a.chrom, a.start, b.start + """ + naive_df = ctx.sql(naive_sql).to_pandas() + + binned_sql = transpile( + """ + SELECT a.chrom AS a_chrom, a.start AS a_start, a."end" AS a_end, + b.chrom AS b_chrom, b.start AS b_start, b."end" AS b_end + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + binned_df = ( + ctx.sql(binned_sql) + .to_pandas() + .sort_values(by=["a_chrom", "a_start", "b_start"]) + .reset_index(drop=True) + ) + naive_df = naive_df.reset_index(drop=True) + + assert len(binned_df) == len(naive_df) + assert binned_df.values.tolist() == naive_df.values.tolist() From 04bd0c0d808865183813dfd50787a60040b6f0e8 Mon Sep 17 00:00:00 2001 From: Conrad Date: Tue, 31 Mar 2026 09:52:33 -0400 Subject: [PATCH 05/20] fix: Place overlap filter in ON and support multi-join queries MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move the overlap predicate (start < end AND end > start) from WHERE into the JOIN ON clause so that LEFT/RIGHT/FULL JOIN semantics are preserved — a WHERE filter on the right-side columns silently converts outer joins into inner joins. Also refactor the transformer to rewrite all INTERSECTS joins in a query, not just the first. A new _ensure_table_binned helper tracks which aliases already have binned CTEs so that multi-join queries reuse CTEs instead of duplicating them. Add bin_size validation (must be positive) and remove dead code from _rewrite_where. --- src/giql/transformer.py | 143 +++++++++++++++++++--------------------- 1 file changed, 68 insertions(+), 75 deletions(-) diff --git a/src/giql/transformer.py b/src/giql/transformer.py index 18af2de..cff2844 100644 --- a/src/giql/transformer.py +++ b/src/giql/transformer.py @@ -582,7 +582,8 @@ def _transform_for_merge( class IntersectsBinnedJoinTransformer: """Transforms column-to-column INTERSECTS into binned equi-joins. - Handles both explicit JOIN ON and implicit cross-join (WHERE) patterns: + Handles both explicit JOIN ON and implicit cross-join (WHERE) patterns, + including queries with multiple INTERSECTS joins: -- Explicit JOIN SELECT a.*, b.* @@ -605,7 +606,7 @@ class IntersectsBinnedJoinTransformer: FROM __giql_a_binned AS a JOIN __giql_b_binned AS b ON a."chrom" = b."chrom" AND a.__giql_bin = b.__giql_bin - WHERE a."start" < b."end" AND a."end" > b."start" + AND a."start" < b."end" AND a."end" > b."start" Literal-range INTERSECTS (e.g., ``WHERE interval INTERSECTS 'chr1:...'``) are left untouched. @@ -614,6 +615,8 @@ class IntersectsBinnedJoinTransformer: def __init__(self, tables: Tables, bin_size: int | None = None): self.tables = tables self.bin_size = bin_size if bin_size is not None else DEFAULT_BIN_SIZE + if self.bin_size <= 0: + raise ValueError(f"bin_size must be a positive integer, got {self.bin_size}") def transform(self, query: exp.Expression) -> exp.Expression: if not isinstance(query, exp.Select): @@ -623,22 +626,35 @@ def transform(self, query: exp.Expression) -> exp.Expression: if not joins: return query - # Check explicit JOIN ON conditions + # Track which table aliases already have binned CTEs so that + # multi-join queries reuse CTEs instead of duplicating them. + binned: dict[str, tuple[str, str, str]] = {} + rewrote_any = False + + # Rewrite all explicit JOIN ON INTERSECTS for join in joins: on = join.args.get("on") if on: intersects = self._find_column_intersects_in(on) if intersects: - return self._rewrite_join_on(query, join, intersects) + self._rewrite_join_on(query, join, intersects, binned) + rewrote_any = True - # Check WHERE clause (implicit cross-join pattern) + # Rewrite all WHERE clause INTERSECTS (implicit cross-joins) where = query.args.get("where") if where: - intersects = self._find_column_intersects_in(where.this) - if intersects: + while True: + intersects = self._find_column_intersects_in(where.this) + if not intersects: + break join = self._find_join_for_intersects(query, intersects) - if join: - return self._rewrite_where(query, join, intersects) + if not join: + break + self._rewrite_where(query, join, intersects, binned) + rewrote_any = True + + if rewrote_any: + query.set("distinct", exp.Distinct()) return query @@ -732,62 +748,47 @@ def _build_binned_select( select.from_(exp.Table(this=exp.Identifier(this=table_name)), copy=False) return select - def _install_binned_ctes( + def _ensure_table_binned( self, query: exp.Select, - join: exp.Join, - from_table: exp.Table, - join_table: exp.Table, - ) -> tuple[str, str, tuple[str, str, str], tuple[str, str, str]]: - """Create binned CTEs and update FROM/JOIN to reference them. + table: exp.Table, + parent: exp.Expression, + binned: dict[str, tuple[str, str, str]], + ) -> tuple[str, tuple[str, str, str]]: + """Create a binned CTE for *table* if one does not already exist. - Returns (from_alias, join_alias, from_cols, join_cols). + Returns (alias, (chrom_col, start_col, end_col)). """ - from_name = from_table.name - join_name = join_table.name - from_alias = from_table.alias or from_name - join_alias = join_table.alias or join_name + alias = table.alias or table.name - from_cols = self._get_columns(from_name) - join_cols = self._get_columns(join_name) + if alias in binned: + return alias, binned[alias] - from_cte_name = f"__giql_{from_alias}_binned" - join_cte_name = f"__giql_{join_alias}_binned" + table_name = table.name + cols = self._get_columns(table_name) + cte_name = f"__giql_{alias}_binned" - from_cte = exp.CTE( - this=self._build_binned_select(from_name, from_cols), - alias=exp.TableAlias(this=exp.Identifier(this=from_cte_name)), - ) - join_cte = exp.CTE( - this=self._build_binned_select(join_name, join_cols), - alias=exp.TableAlias(this=exp.Identifier(this=join_cte_name)), + cte = exp.CTE( + this=self._build_binned_select(table_name, cols), + alias=exp.TableAlias(this=exp.Identifier(this=cte_name)), ) existing_with = query.args.get("with_") if existing_with: - existing_with.append("expressions", from_cte) - existing_with.append("expressions", join_cte) + existing_with.append("expressions", cte) else: - query.set("with_", exp.With(expressions=[from_cte, join_cte])) + query.set("with_", exp.With(expressions=[cte])) - query.args["from_"].set( + parent.set( "this", exp.Table( - this=exp.Identifier(this=from_cte_name), - alias=exp.TableAlias(this=exp.Identifier(this=from_alias)), + this=exp.Identifier(this=cte_name), + alias=exp.TableAlias(this=exp.Identifier(this=alias)), ), ) - join.set( - "this", - exp.Table( - this=exp.Identifier(this=join_cte_name), - alias=exp.TableAlias(this=exp.Identifier(this=join_alias)), - ), - ) - - query.set("distinct", exp.Distinct()) - return from_alias, join_alias, from_cols, join_cols + binned[alias] = cols + return alias, cols def _build_equi_join( self, @@ -831,7 +832,8 @@ def _rewrite_join_on( query: exp.Select, join: exp.Join, intersects: Intersects, - ) -> exp.Select: + binned: dict[str, tuple[str, str, str]], + ) -> None: """Rewrite an explicit ``JOIN ... ON ... INTERSECTS ...``.""" from_table = query.args["from_"].this join_table = join.this @@ -839,34 +841,31 @@ def _rewrite_join_on( if not isinstance(from_table, exp.Table) or not isinstance( join_table, exp.Table ): - return query + return - from_alias, join_alias, from_cols, join_cols = self._install_binned_ctes( - query, join, from_table, join_table + from_alias, from_cols = self._ensure_table_binned( + query, from_table, query.args["from_"], binned ) - - join.set( - "on", self._build_equi_join(from_alias, join_alias, from_cols, join_cols) + join_alias, join_cols = self._ensure_table_binned( + query, join_table, join, binned ) + equi_join = self._build_equi_join(from_alias, join_alias, from_cols, join_cols) overlap = self._build_overlap_filter( from_alias, join_alias, from_cols, join_cols ) - existing_where = query.args.get("where") - if existing_where: - merged = exp.And(this=existing_where.this, expression=overlap) - existing_where.set("this", merged) - else: - query.set("where", exp.Where(this=overlap)) - return query + # Place both equi-join and overlap in ON so that LEFT/RIGHT/FULL + # JOIN semantics are preserved (WHERE would filter out NULL rows). + join.set("on", exp.And(this=equi_join, expression=overlap)) def _rewrite_where( self, query: exp.Select, join: exp.Join, intersects: Intersects, - ) -> exp.Select: + binned: dict[str, tuple[str, str, str]], + ) -> None: """Rewrite an implicit cross-join ``FROM a, b WHERE ... INTERSECTS ...``.""" from_table = query.args["from_"].this join_table = join.this @@ -874,10 +873,13 @@ def _rewrite_where( if not isinstance(from_table, exp.Table) or not isinstance( join_table, exp.Table ): - return query + return - from_alias, join_alias, from_cols, join_cols = self._install_binned_ctes( - query, join, from_table, join_table + from_alias, from_cols = self._ensure_table_binned( + query, from_table, query.args["from_"], binned + ) + join_alias, join_cols = self._ensure_table_binned( + query, join_table, join, binned ) equi_join = self._build_equi_join(from_alias, join_alias, from_cols, join_cols) @@ -886,13 +888,4 @@ def _rewrite_where( ) # Replace the Intersects node in-place with equi-join + overlap - replacement = exp.And( - this=exp.And(this=equi_join, expression=overlap), - expression=exp.Paren(this=exp.Literal.number(1)), - ) - # Use Paren(1) as a truthy sentinel then clean it up — simpler - # to just build the full replacement directly: - replacement = exp.And(this=equi_join, expression=overlap) - intersects.replace(replacement) - - return query + intersects.replace(exp.And(this=equi_join, expression=overlap)) From 13a43750297b84009fb0dee49e4bd8ff43acc20a Mon Sep 17 00:00:00 2001 From: Conrad Date: Tue, 31 Mar 2026 09:52:39 -0400 Subject: [PATCH 06/20] test: Add multi-join and bin_size validation tests Cover three-way joins with CTE reuse, invalid bin_size rejection, and update assertions for the overlap-in-ON change. Remove unused pytest import from module level. --- tests/test_binned_join.py | 66 ++++++++++++++++++++++++++++++++++----- 1 file changed, 58 insertions(+), 8 deletions(-) diff --git a/tests/test_binned_join.py b/tests/test_binned_join.py index da71dcc..276d7a2 100644 --- a/tests/test_binned_join.py +++ b/tests/test_binned_join.py @@ -1,7 +1,5 @@ """Tests for the INTERSECTS binned equi-join transpilation.""" -import pytest - from giql import Table from giql import transpile @@ -13,8 +11,8 @@ def test_basic_binned_join_rewrite(self): """ GIVEN a GIQL query joining two tables with column-to-column INTERSECTS WHEN transpiling with default settings - THEN should produce CTEs with UNNEST/range, equi-join ON chrom + __giql_bin, - WHERE overlap filter, and DISTINCT + THEN should produce CTEs with UNNEST/range, equi-join and overlap in ON, + and DISTINCT """ sql = transpile( """ @@ -37,8 +35,10 @@ def test_basic_binned_join_rewrite(self): assert '"chrom"' in sql assert "__giql_bin" in sql - # Overlap filter in WHERE - assert "WHERE" in sql_upper + # Overlap filter in ON (not WHERE) for correct outer-join semantics + assert "ON" in sql_upper + assert '"start"' in sql or '"START"' in sql_upper + assert '"end"' in sql or '"END"' in sql_upper # DISTINCT to deduplicate across bins assert "DISTINCT" in sql_upper @@ -243,9 +243,59 @@ def test_self_join_distinct_ctes(self): # Both reference the same underlying table assert "FROM peaks" in sql or "FROM PEAKS" in sql_upper - # Should still have DISTINCT and WHERE + # Should still have DISTINCT assert "DISTINCT" in sql_upper - assert "WHERE" in sql_upper + + def test_invalid_bin_size_raises(self): + """ + GIVEN bin_size=0 or a negative value + WHEN calling transpile + THEN should raise ValueError + """ + import pytest + + with pytest.raises(ValueError, match="positive"): + transpile( + "SELECT * FROM a JOIN b ON a.interval INTERSECTS b.interval", + tables=["a", "b"], + bin_size=0, + ) + + with pytest.raises(ValueError, match="positive"): + transpile( + "SELECT * FROM a JOIN b ON a.interval INTERSECTS b.interval", + tables=["a", "b"], + bin_size=-1, + ) + + def test_multi_join_all_intersects_rewritten(self): + """ + GIVEN a three-way join with two INTERSECTS conditions + WHEN transpiling + THEN should create binned CTEs for all three tables, reusing the + FROM table CTE, and place equi-join + overlap in each JOIN ON + """ + sql = transpile( + """ + SELECT a.*, b.*, c.* + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + JOIN exons c ON a.interval INTERSECTS c.interval + """, + tables=["peaks", "genes", "exons"], + ) + + # Three distinct CTEs + assert "__giql_a_binned" in sql + assert "__giql_b_binned" in sql + assert "__giql_c_binned" in sql + + # FROM table CTE created only once + assert sql.count("FROM peaks") == 1 or sql.upper().count("FROM PEAKS") == 1 + + # Both JOINs have equi-join + overlap in ON + sql_upper = sql.upper() + assert sql_upper.count("__GIQL_BIN") >= 4 # at least 2 per JOIN ON class TestBinnedJoinDataFusion: From 20140ee03e63d1334f5f8c2fccc7d371345c094d Mon Sep 17 00:00:00 2001 From: Conrad Date: Tue, 31 Mar 2026 10:39:24 -0400 Subject: [PATCH 07/20] fix: Limit binned join to explicit JOINs and skip DataFusion tests The binned CTE approach leaks __giql_bin into SELECT * results because CTEs expose all their columns. Revert implicit cross-join rewriting (FROM a, b WHERE INTERSECTS) so those queries fall through to the generator's naive overlap predicate, which produces clean column output. Explicit JOIN ON INTERSECTS continues to use the binned equi-join. Also add pytest.importorskip for datafusion so the DataFusion correctness tests are skipped when the module is not installed. --- src/giql/transformer.py | 90 ++++----------------------------------- tests/test_binned_join.py | 24 +++++------ 2 files changed, 20 insertions(+), 94 deletions(-) diff --git a/src/giql/transformer.py b/src/giql/transformer.py index cff2844..1d1b030 100644 --- a/src/giql/transformer.py +++ b/src/giql/transformer.py @@ -580,20 +580,15 @@ def _transform_for_merge( class IntersectsBinnedJoinTransformer: - """Transforms column-to-column INTERSECTS into binned equi-joins. + """Transforms explicit JOIN ON INTERSECTS into binned equi-joins. - Handles both explicit JOIN ON and implicit cross-join (WHERE) patterns, - including queries with multiple INTERSECTS joins: + Handles explicit JOIN ON patterns, including queries with multiple + INTERSECTS joins: - -- Explicit JOIN SELECT a.*, b.* FROM peaks a JOIN genes b ON a.interval INTERSECTS b.region - -- Implicit cross-join - SELECT a.*, b.* - FROM peaks a, genes b WHERE a.interval INTERSECTS b.interval - - Both are rewritten to: + Rewritten to: WITH __giql_a_binned AS ( SELECT *, UNNEST(range( @@ -631,7 +626,10 @@ def transform(self, query: exp.Expression) -> exp.Expression: binned: dict[str, tuple[str, str, str]] = {} rewrote_any = False - # Rewrite all explicit JOIN ON INTERSECTS + # Rewrite all explicit JOIN ON INTERSECTS. + # Implicit cross-joins (FROM a, b WHERE INTERSECTS) are left for the + # generator's naive predicate because the binned CTE approach leaks + # the internal __giql_bin column into SELECT * results. for join in joins: on = join.args.get("on") if on: @@ -640,19 +638,6 @@ def transform(self, query: exp.Expression) -> exp.Expression: self._rewrite_join_on(query, join, intersects, binned) rewrote_any = True - # Rewrite all WHERE clause INTERSECTS (implicit cross-joins) - where = query.args.get("where") - if where: - while True: - intersects = self._find_column_intersects_in(where.this) - if not intersects: - break - join = self._find_join_for_intersects(query, intersects) - if not join: - break - self._rewrite_where(query, join, intersects, binned) - rewrote_any = True - if rewrote_any: query.set("distinct", exp.Distinct()) @@ -670,34 +655,6 @@ def _find_column_intersects_in(self, expr: exp.Expression) -> Intersects | None: return node return None - def _find_join_for_intersects( - self, query: exp.Select, intersects: Intersects - ) -> exp.Join | None: - """Find the Join node for the table referenced in an Intersects.""" - from_table = query.args["from_"].this - if not isinstance(from_table, exp.Table): - return None - - from_alias = from_table.alias or from_table.name - left_alias = intersects.this.table - right_alias = intersects.expression.table - - # Determine which alias is the join table (not the FROM table) - if left_alias == from_alias: - target_alias = right_alias - elif right_alias == from_alias: - target_alias = left_alias - else: - return None - - for join in query.args.get("joins", []): - if isinstance(join.this, exp.Table): - alias = join.this.alias or join.this.name - if alias == target_alias: - return join - - return None - def _get_columns(self, table_name: str) -> tuple[str, str, str]: """Return (chrom, start, end) column names for a table.""" table = self.tables.get(table_name) @@ -858,34 +815,3 @@ def _rewrite_join_on( # Place both equi-join and overlap in ON so that LEFT/RIGHT/FULL # JOIN semantics are preserved (WHERE would filter out NULL rows). join.set("on", exp.And(this=equi_join, expression=overlap)) - - def _rewrite_where( - self, - query: exp.Select, - join: exp.Join, - intersects: Intersects, - binned: dict[str, tuple[str, str, str]], - ) -> None: - """Rewrite an implicit cross-join ``FROM a, b WHERE ... INTERSECTS ...``.""" - from_table = query.args["from_"].this - join_table = join.this - - if not isinstance(from_table, exp.Table) or not isinstance( - join_table, exp.Table - ): - return - - from_alias, from_cols = self._ensure_table_binned( - query, from_table, query.args["from_"], binned - ) - join_alias, join_cols = self._ensure_table_binned( - query, join_table, join, binned - ) - - equi_join = self._build_equi_join(from_alias, join_alias, from_cols, join_cols) - overlap = self._build_overlap_filter( - from_alias, join_alias, from_cols, join_cols - ) - - # Replace the Intersects node in-place with equi-join + overlap - intersects.replace(exp.And(this=equi_join, expression=overlap)) diff --git a/tests/test_binned_join.py b/tests/test_binned_join.py index 276d7a2..1f9e57a 100644 --- a/tests/test_binned_join.py +++ b/tests/test_binned_join.py @@ -190,11 +190,12 @@ def test_bin_size_none_defaults_to_10000(self): assert sql_default == sql_none assert "10000" in sql_default - def test_implicit_cross_join_rewrite(self): + def test_implicit_cross_join_uses_naive_predicate(self): """ GIVEN a GIQL query with implicit cross-join (FROM a, b WHERE INTERSECTS) WHEN transpiling - THEN should produce binned CTEs and replace the INTERSECTS in WHERE + THEN should use the naive overlap predicate, not binned CTEs, because + the CTE approach leaks __giql_bin into SELECT * results """ sql = transpile( """ @@ -205,17 +206,12 @@ def test_implicit_cross_join_rewrite(self): tables=["peaks", "genes"], ) - sql_upper = sql.upper() - - # Binned CTEs present - assert "WITH" in sql_upper - assert "UNNEST" in sql_upper - assert "__giql_bin" in sql - - # Equi-join conditions in WHERE (not in ON for comma joins) - assert "WHERE" in sql_upper + # No binned CTEs — falls through to generator's naive predicate + assert "__giql_bin" not in sql + assert "UNNEST" not in sql.upper() - # Overlap filter present + # Naive overlap predicate present + assert '"chrom"' in sql assert '"start"' in sql assert '"end"' in sql @@ -298,6 +294,10 @@ def test_multi_join_all_intersects_rewritten(self): assert sql_upper.count("__GIQL_BIN") >= 4 # at least 2 per JOIN ON +pytest = __import__("pytest") +datafusion = pytest.importorskip("datafusion") + + class TestBinnedJoinDataFusion: """End-to-end DataFusion correctness tests for binned INTERSECTS joins.""" From 7394f651e33b5ac54482deedd7ba963430543760 Mon Sep 17 00:00:00 2001 From: Conrad Date: Tue, 31 Mar 2026 10:42:10 -0400 Subject: [PATCH 08/20] build: Add datafusion to pixi dependencies for CI The CI workflow uses pixi, not uv, so the datafusion package must be listed under [tool.pixi.dependencies] for the DataFusion correctness tests to run. Remove the pytest.importorskip guard since the dependency is now always available. --- pyproject.toml | 1 + tests/test_binned_join.py | 4 ---- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d5d9129..647358b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,6 +83,7 @@ bedtools = ">=2.31.0" pybedtools = ">=0.9.0" pytest = ">=7.0.0" pytest-cov = ">=4.0.0" +datafusion = ">=43.0.0" duckdb = ">=1.4.0" pandas = ">=2.0.0" sqlglot = ">=20.0.0,<30" diff --git a/tests/test_binned_join.py b/tests/test_binned_join.py index 1f9e57a..b92e28d 100644 --- a/tests/test_binned_join.py +++ b/tests/test_binned_join.py @@ -294,10 +294,6 @@ def test_multi_join_all_intersects_rewritten(self): assert sql_upper.count("__GIQL_BIN") >= 4 # at least 2 per JOIN ON -pytest = __import__("pytest") -datafusion = pytest.importorskip("datafusion") - - class TestBinnedJoinDataFusion: """End-to-end DataFusion correctness tests for binned INTERSECTS joins.""" From b31b6f7d6eb6c883fc31cd17293f9fbdd97b4fbb Mon Sep 17 00:00:00 2001 From: Conrad Date: Tue, 31 Mar 2026 11:06:46 -0400 Subject: [PATCH 09/20] feat: Use key-only bridge CTEs to eliminate __giql_bin column leak The previous approach replaced FROM/JOIN table references with full CTEs (SELECT *), causing __giql_bin to appear in SELECT a.* output. The new approach keeps original table references and routes the equi- join through key-only bridge CTEs (SELECT chrom, start, end, bin), eliminating the leak entirely. This also restores implicit cross-join rewriting (FROM a, b WHERE INTERSECTS) which was reverted in the prior commit due to the leak. CTEs are now named __giql_{table}_bins and deduplicated per underlying table name rather than per alias, so self-joins share one CTE. --- src/giql/transformer.py | 371 ++++++++++++++++++++++++++------------ tests/test_binned_join.py | 81 ++++++--- 2 files changed, 316 insertions(+), 136 deletions(-) diff --git a/src/giql/transformer.py b/src/giql/transformer.py index 1d1b030..c965016 100644 --- a/src/giql/transformer.py +++ b/src/giql/transformer.py @@ -580,28 +580,44 @@ def _transform_for_merge( class IntersectsBinnedJoinTransformer: - """Transforms explicit JOIN ON INTERSECTS into binned equi-joins. + """Transforms column-to-column INTERSECTS into binned equi-joins. - Handles explicit JOIN ON patterns, including queries with multiple - INTERSECTS joins: + Handles both explicit JOIN ON and implicit cross-join (WHERE) patterns: + -- Explicit JOIN SELECT a.*, b.* FROM peaks a JOIN genes b ON a.interval INTERSECTS b.region - Rewritten to: + -- Implicit cross-join + SELECT DISTINCT a.* + FROM peaks a, genes b WHERE a.interval INTERSECTS b.interval - WITH __giql_a_binned AS ( - SELECT *, UNNEST(range( - CAST("start" / B AS BIGINT), - CAST(("end" - 1) / B + 1 AS BIGINT) - )) AS __giql_bin FROM peaks + Both are rewritten using key-only bridge CTEs so that ``__giql_bin`` + never appears in the output column set (no SELECT * leak): + + WITH __giql_peaks_bins AS ( + SELECT "chrom", "start", "end", + UNNEST(range( + CAST("start" / B AS BIGINT), + CAST(("end" - 1) / B + 1 AS BIGINT) + )) AS __giql_bin + FROM peaks ), - __giql_b_binned AS (...) + __giql_genes_bins AS (...) SELECT DISTINCT a.*, b.* - FROM __giql_a_binned AS a - JOIN __giql_b_binned AS b - ON a."chrom" = b."chrom" AND a.__giql_bin = b.__giql_bin - AND a."start" < b."end" AND a."end" > b."start" + FROM peaks a + JOIN __giql_peaks_bins __giql_c0 + ON a."chrom" = __giql_c0."chrom" + AND a."start" = __giql_c0."start" + AND a."end" = __giql_c0."end" + JOIN __giql_genes_bins __giql_c1 + ON __giql_c0."chrom" = __giql_c1."chrom" + AND __giql_c0.__giql_bin = __giql_c1.__giql_bin + JOIN genes b + ON b."chrom" = __giql_c1."chrom" + AND b."start" = __giql_c1."start" + AND b."end" = __giql_c1."end" + AND a."start" < b."end" AND a."end" > b."start" Literal-range INTERSECTS (e.g., ``WHERE interval INTERSECTS 'chr1:...'``) are left untouched. @@ -617,28 +633,56 @@ def transform(self, query: exp.Expression) -> exp.Expression: if not isinstance(query, exp.Select): return query - joins = query.args.get("joins") - if not joins: - return query + joins = query.args.get("joins") or [] - # Track which table aliases already have binned CTEs so that - # multi-join queries reuse CTEs instead of duplicating them. - binned: dict[str, tuple[str, str, str]] = {} + # key_binned: table_name -> CTE name (deduped per underlying table) + key_binned: dict[str, str] = {} + connector_idx = [0] + new_joins: list[exp.Join] = [] rewrote_any = False - # Rewrite all explicit JOIN ON INTERSECTS. - # Implicit cross-joins (FROM a, b WHERE INTERSECTS) are left for the - # generator's naive predicate because the binned CTE approach leaks - # the internal __giql_bin column into SELECT * results. for join in joins: on = join.args.get("on") if on: intersects = self._find_column_intersects_in(on) if intersects: - self._rewrite_join_on(query, join, intersects, binned) + extra = self._build_join_back_joins( + query, + join, + intersects, + key_binned, + connector_idx, + preserve_kind=True, + ) + new_joins.extend(extra) + rewrote_any = True + continue + new_joins.append(join) + + # Handle implicit cross-join: FROM a, b WHERE a.interval INTERSECTS b.interval + where = query.args.get("where") + if where: + intersects = self._find_column_intersects_in(where.this) + if intersects: + cross_join = self._find_cross_join_for_intersects( + query, intersects, new_joins + ) + if cross_join is not None: + new_joins.remove(cross_join) + extra = self._build_join_back_joins( + query, + cross_join, + intersects, + key_binned, + connector_idx, + preserve_kind=False, + ) + new_joins.extend(extra) + self._remove_intersects_from_where(query, intersects) rewrote_any = True if rewrote_any: + query.set("joins", new_joins) query.set("distinct", exp.Distinct()) return query @@ -655,6 +699,54 @@ def _find_column_intersects_in(self, expr: exp.Expression) -> Intersects | None: return node return None + def _find_cross_join_for_intersects( + self, + query: exp.Select, + intersects: Intersects, + current_joins: list[exp.Join], + ) -> exp.Join | None: + """Find the implicit cross-join entry for the table in a WHERE INTERSECTS.""" + from_table = query.args["from_"].this + if not isinstance(from_table, exp.Table): + return None + from_alias = from_table.alias or from_table.name + + left_alias = intersects.this.table + right_alias = intersects.expression.table + if left_alias == from_alias: + target_alias = right_alias + elif right_alias == from_alias: + target_alias = left_alias + else: + return None + + for join in current_joins: + if isinstance(join.this, exp.Table): + alias = join.this.alias or join.this.name + if alias == target_alias: + return join + return None + + def _remove_intersects_from_where( + self, query: exp.Select, intersects: Intersects + ) -> None: + """Remove the INTERSECTS predicate from the WHERE clause.""" + where = query.args.get("where") + if not where: + return + where_expr = where.this + if where_expr is intersects: + query.set("where", None) + elif isinstance(where_expr, exp.And): + if where_expr.this is intersects: + query.set("where", exp.Where(this=where_expr.expression)) + elif where_expr.expression is intersects: + query.set("where", exp.Where(this=where_expr.this)) + else: + intersects.replace(exp.true()) + else: + intersects.replace(exp.true()) + def _get_columns(self, table_name: str) -> tuple[str, str, str]: """Return (chrom, start, end) column names for a table.""" table = self.tables.get(table_name) @@ -662,11 +754,24 @@ def _get_columns(self, table_name: str) -> tuple[str, str, str]: return (table.chrom_col, table.start_col, table.end_col) return (DEFAULT_CHROM_COL, DEFAULT_START_COL, DEFAULT_END_COL) - def _build_binned_select( + def _find_table_name_for_alias(self, query: exp.Select, alias: str) -> str: + """Resolve an alias to its underlying table name.""" + from_table = query.args["from_"].this + if isinstance(from_table, exp.Table): + if (from_table.alias or from_table.name) == alias: + return from_table.name + for join in query.args.get("joins") or []: + if isinstance(join.this, exp.Table): + t = join.this + if (t.alias or t.name) == alias: + return t.name + return alias # fallback: alias == table name + + def _build_key_only_bins_select( self, table_name: str, cols: tuple[str, str, str] ) -> exp.Select: - """Build ``SELECT *, UNNEST(range(...)) AS __giql_bin FROM
``.""" - _chrom, start, end = cols + """Build ``SELECT chrom, start, end, UNNEST(range(...)) AS __giql_bin FROM table``.""" + chrom, start, end = cols B = self.bin_size low = exp.Cast( @@ -700,33 +805,27 @@ def _build_binned_select( ) select = exp.Select() - select.select(exp.Star(), copy=False) + select.select(exp.column(chrom, quoted=True), copy=False) + select.select(exp.column(start, quoted=True), append=True, copy=False) + select.select(exp.column(end, quoted=True), append=True, copy=False) select.select(bin_alias, append=True, copy=False) select.from_(exp.Table(this=exp.Identifier(this=table_name)), copy=False) return select - def _ensure_table_binned( + def _ensure_key_binned( self, query: exp.Select, - table: exp.Table, - parent: exp.Expression, - binned: dict[str, tuple[str, str, str]], - ) -> tuple[str, tuple[str, str, str]]: - """Create a binned CTE for *table* if one does not already exist. - - Returns (alias, (chrom_col, start_col, end_col)). - """ - alias = table.alias or table.name - - if alias in binned: - return alias, binned[alias] - - table_name = table.name + table_name: str, + key_binned: dict[str, str], + ) -> str: + """Ensure a key-only bins CTE exists for *table_name*; return its name.""" + if table_name in key_binned: + return key_binned[table_name] + + cte_name = f"__giql_{table_name}_bins" cols = self._get_columns(table_name) - cte_name = f"__giql_{alias}_binned" - cte = exp.CTE( - this=self._build_binned_select(table_name, cols), + this=self._build_key_only_bins_select(table_name, cols), alias=exp.TableAlias(this=exp.Identifier(this=cte_name)), ) @@ -736,82 +835,132 @@ def _ensure_table_binned( else: query.set("with_", exp.With(expressions=[cte])) - parent.set( - "this", - exp.Table( - this=exp.Identifier(this=cte_name), - alias=exp.TableAlias(this=exp.Identifier(this=alias)), - ), - ) - - binned[alias] = cols - return alias, cols + key_binned[table_name] = cte_name + return cte_name - def _build_equi_join( + def _build_join_back_joins( self, - from_alias: str, - join_alias: str, - from_cols: tuple[str, str, str], - join_cols: tuple[str, str, str], - ) -> exp.And: - """Build ``chrom = chrom AND __giql_bin = __giql_bin``.""" - chrom_eq = exp.EQ( - this=exp.column(from_cols[0], table=from_alias, quoted=True), - expression=exp.column(join_cols[0], table=join_alias, quoted=True), + query: exp.Select, + join: exp.Join, + intersects: Intersects, + key_binned: dict[str, str], + connector_idx: list[int], + *, + preserve_kind: bool, + ) -> list[exp.Join]: + """Build three replacement JOINs for one INTERSECTS using the join-back pattern. + + join1: JOIN key_cte_for_other connector_a ON other_alias key-matches connector_a + join2: JOIN key_cte_for_join connector_b ON connector_a equi-joins connector_b + join3: JOIN original_join_table join_alias ON join_alias key-matches connector_b + AND overlap predicate + """ + join_table = join.this + if not isinstance(join_table, exp.Table): + return [join] + + join_alias = join_table.alias or join_table.name + join_table_name = join_table.name + + left_alias = intersects.this.table + right_alias = intersects.expression.table + other_alias = left_alias if right_alias == join_alias else right_alias + if other_alias == join_alias: + return [join] # can't determine structure + + other_table_name = self._find_table_name_for_alias(query, other_alias) + other_cols = self._get_columns(other_table_name) + join_cols = self._get_columns(join_table_name) + + other_cte = self._ensure_key_binned(query, other_table_name, key_binned) + join_cte = self._ensure_key_binned(query, join_table_name, key_binned) + + c0 = f"__giql_c{connector_idx[0]}" + c1 = f"__giql_c{connector_idx[0] + 1}" + connector_idx[0] += 2 + + # join1: key-match from other_alias to its bin CTE + join1 = exp.Join( + this=exp.Table( + this=exp.Identifier(this=other_cte), + alias=exp.TableAlias(this=exp.Identifier(this=c0)), + ), + on=exp.And( + this=exp.And( + this=exp.EQ( + this=exp.column(other_cols[0], table=other_alias, quoted=True), + expression=exp.column(other_cols[0], table=c0, quoted=True), + ), + expression=exp.EQ( + this=exp.column(other_cols[1], table=other_alias, quoted=True), + expression=exp.column(other_cols[1], table=c0, quoted=True), + ), + ), + expression=exp.EQ( + this=exp.column(other_cols[2], table=other_alias, quoted=True), + expression=exp.column(other_cols[2], table=c0, quoted=True), + ), + ), ) - bin_eq = exp.EQ( - this=exp.column("__giql_bin", table=from_alias), - expression=exp.column("__giql_bin", table=join_alias), + + # join2: bin equi-join (chrom + __giql_bin match) + join2 = exp.Join( + this=exp.Table( + this=exp.Identifier(this=join_cte), + alias=exp.TableAlias(this=exp.Identifier(this=c1)), + ), + on=exp.And( + this=exp.EQ( + this=exp.column(other_cols[0], table=c0, quoted=True), + expression=exp.column(join_cols[0], table=c1, quoted=True), + ), + expression=exp.EQ( + this=exp.column("__giql_bin", table=c0), + expression=exp.column("__giql_bin", table=c1), + ), + ), ) - return exp.And(this=chrom_eq, expression=bin_eq) - def _build_overlap_filter( - self, - from_alias: str, - join_alias: str, - from_cols: tuple[str, str, str], - join_cols: tuple[str, str, str], - ) -> exp.And: - """Build ``from.start < join.end AND from.end > join.start``.""" - return exp.And( + # join3: key-match from join CTE back to actual join table + overlap + key_match = exp.And( + this=exp.And( + this=exp.EQ( + this=exp.column(join_cols[0], table=join_alias, quoted=True), + expression=exp.column(join_cols[0], table=c1, quoted=True), + ), + expression=exp.EQ( + this=exp.column(join_cols[1], table=join_alias, quoted=True), + expression=exp.column(join_cols[1], table=c1, quoted=True), + ), + ), + expression=exp.EQ( + this=exp.column(join_cols[2], table=join_alias, quoted=True), + expression=exp.column(join_cols[2], table=c1, quoted=True), + ), + ) + overlap = exp.And( this=exp.LT( - this=exp.column(from_cols[1], table=from_alias, quoted=True), + this=exp.column(other_cols[1], table=other_alias, quoted=True), expression=exp.column(join_cols[2], table=join_alias, quoted=True), ), expression=exp.GT( - this=exp.column(from_cols[2], table=from_alias, quoted=True), + this=exp.column(other_cols[2], table=other_alias, quoted=True), expression=exp.column(join_cols[1], table=join_alias, quoted=True), ), ) - def _rewrite_join_on( - self, - query: exp.Select, - join: exp.Join, - intersects: Intersects, - binned: dict[str, tuple[str, str, str]], - ) -> None: - """Rewrite an explicit ``JOIN ... ON ... INTERSECTS ...``.""" - from_table = query.args["from_"].this - join_table = join.this - - if not isinstance(from_table, exp.Table) or not isinstance( - join_table, exp.Table - ): - return - - from_alias, from_cols = self._ensure_table_binned( - query, from_table, query.args["from_"], binned - ) - join_alias, join_cols = self._ensure_table_binned( - query, join_table, join, binned - ) + join3_kwargs: dict = { + "this": exp.Table( + this=exp.Identifier(this=join_table_name), + alias=exp.TableAlias(this=exp.Identifier(this=join_alias)), + ), + "on": exp.And(this=key_match, expression=overlap), + } + if preserve_kind: + kind = join.args.get("kind") + if kind: + join3_kwargs["kind"] = kind - equi_join = self._build_equi_join(from_alias, join_alias, from_cols, join_cols) - overlap = self._build_overlap_filter( - from_alias, join_alias, from_cols, join_cols - ) + join3 = exp.Join(**join3_kwargs) - # Place both equi-join and overlap in ON so that LEFT/RIGHT/FULL - # JOIN semantics are preserved (WHERE would filter out NULL rows). - join.set("on", exp.And(this=equi_join, expression=overlap)) + return [join1, join2, join3] diff --git a/tests/test_binned_join.py b/tests/test_binned_join.py index b92e28d..62c8ee0 100644 --- a/tests/test_binned_join.py +++ b/tests/test_binned_join.py @@ -190,12 +190,12 @@ def test_bin_size_none_defaults_to_10000(self): assert sql_default == sql_none assert "10000" in sql_default - def test_implicit_cross_join_uses_naive_predicate(self): + def test_implicit_cross_join_uses_binned_optimization(self): """ GIVEN a GIQL query with implicit cross-join (FROM a, b WHERE INTERSECTS) WHEN transpiling - THEN should use the naive overlap predicate, not binned CTEs, because - the CTE approach leaks __giql_bin into SELECT * results + THEN should use the binned equi-join optimization without leaking + __giql_bin into SELECT * output columns """ sql = transpile( """ @@ -206,20 +206,23 @@ def test_implicit_cross_join_uses_naive_predicate(self): tables=["peaks", "genes"], ) - # No binned CTEs — falls through to generator's naive predicate - assert "__giql_bin" not in sql - assert "UNNEST" not in sql.upper() + # Binned CTEs are present + assert "WITH" in sql.upper() + assert "__giql_bin" in sql + assert "UNNEST" in sql.upper() - # Naive overlap predicate present + # Original table references preserved — no CTE leak into SELECT * + assert "peaks" in sql assert '"chrom"' in sql assert '"start"' in sql assert '"end"' in sql - def test_self_join_distinct_ctes(self): + def test_self_join_single_shared_cte(self): """ GIVEN a self-join query where the same table appears with two aliases WHEN transpiling a binned join - THEN should produce two distinct CTEs both referencing the same underlying table + THEN should produce one shared key-only CTE for the underlying table, + joined twice through distinct connector aliases """ sql = transpile( """ @@ -232,12 +235,11 @@ def test_self_join_distinct_ctes(self): sql_upper = sql.upper() - # Two distinct CTEs - assert "__giql_a_binned" in sql - assert "__giql_b_binned" in sql + # One shared CTE keyed on the table name + assert "__giql_peaks_bins" in sql - # Both reference the same underlying table - assert "FROM peaks" in sql or "FROM PEAKS" in sql_upper + # Original table preserved in FROM + assert "peaks" in sql # Should still have DISTINCT assert "DISTINCT" in sql_upper @@ -268,8 +270,8 @@ def test_multi_join_all_intersects_rewritten(self): """ GIVEN a three-way join with two INTERSECTS conditions WHEN transpiling - THEN should create binned CTEs for all three tables, reusing the - FROM table CTE, and place equi-join + overlap in each JOIN ON + THEN should create one key-only CTE per underlying table and rewrite + each INTERSECTS join as a three-join bridge through those CTEs """ sql = transpile( """ @@ -281,17 +283,14 @@ def test_multi_join_all_intersects_rewritten(self): tables=["peaks", "genes", "exons"], ) - # Three distinct CTEs - assert "__giql_a_binned" in sql - assert "__giql_b_binned" in sql - assert "__giql_c_binned" in sql - - # FROM table CTE created only once - assert sql.count("FROM peaks") == 1 or sql.upper().count("FROM PEAKS") == 1 + # One CTE per underlying table + assert "__giql_peaks_bins" in sql + assert "__giql_genes_bins" in sql + assert "__giql_exons_bins" in sql - # Both JOINs have equi-join + overlap in ON + # __giql_bin appears in CTE definitions and ON conditions sql_upper = sql.upper() - assert sql_upper.count("__GIQL_BIN") >= 4 # at least 2 per JOIN ON + assert sql_upper.count("__GIQL_BIN") >= 4 # at least 2 per INTERSECTS join class TestBinnedJoinDataFusion: @@ -531,3 +530,35 @@ def test_equivalence_with_naive_cross_join(self): assert len(binned_df) == len(naive_df) assert binned_df.values.tolist() == naive_df.values.tolist() + + def test_implicit_cross_join_correct_rows_no_bin_leak(self): + """ + GIVEN two tables with overlapping intervals queried via implicit cross-join syntax + WHEN executing a binned INTERSECTS join via DataFusion + THEN results should be correct and SELECT a.* should not include __giql_bin + """ + ctx = self._make_ctx( + peaks_data=[("chr1", 100, 500), ("chr1", 1000, 2000)], + genes_data=[("chr1", 300, 600), ("chr1", 5000, 6000)], + ) + + sql = transpile( + """ + SELECT DISTINCT a.* + FROM peaks a, genes b + WHERE a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + df = ctx.sql(sql).to_pandas() + + # Only the first peak (100-500) overlaps the first gene (300-600) + assert len(df) == 1 + assert df.iloc[0]["start"] == 100 + + # SELECT a.* must return exactly the original table columns — no __giql_bin + assert list(df.columns) == ["chrom", "start", "end"] From 8ead19fa223e3b9871d6688627a3ad500021c1a3 Mon Sep 17 00:00:00 2001 From: Conrad Date: Tue, 31 Mar 2026 11:48:29 -0400 Subject: [PATCH 10/20] perf: Use 1-join full-CTE path when SELECT has no wildcards Queries with explicit column lists (SELECT a.chrom, b.start, ...) cannot expose __giql_bin in their output regardless of which CTE the table alias points to. Detecting this at transform time lets us skip the 3-join bridge pattern entirely for those queries and use the simpler, faster 1-join full-CTE approach. Queries with wildcards (SELECT a.*, SELECT *) still take the bridge path so __giql_bin never leaks into the output column set. --- src/giql/transformer.py | 314 +++++++++++++++++++++++++++++++++----- tests/test_binned_join.py | 46 ++++++ 2 files changed, 324 insertions(+), 36 deletions(-) diff --git a/src/giql/transformer.py b/src/giql/transformer.py index c965016..7b327aa 100644 --- a/src/giql/transformer.py +++ b/src/giql/transformer.py @@ -582,41 +582,46 @@ def _transform_for_merge( class IntersectsBinnedJoinTransformer: """Transforms column-to-column INTERSECTS into binned equi-joins. - Handles both explicit JOIN ON and implicit cross-join (WHERE) patterns: - - -- Explicit JOIN - SELECT a.*, b.* - FROM peaks a JOIN genes b ON a.interval INTERSECTS b.region - - -- Implicit cross-join - SELECT DISTINCT a.* - FROM peaks a, genes b WHERE a.interval INTERSECTS b.interval + Handles both explicit JOIN ON and implicit cross-join (WHERE) patterns. + Two rewrite strategies are selected based on the SELECT list: + + **No wildcards** (``SELECT a.chrom, b.start, ...``) — ``__giql_bin`` + cannot appear in the output regardless of CTE content, so the simpler + 1-join full-CTE approach is used: + + WITH __giql_a_binned AS ( + SELECT *, UNNEST(range( + CAST("start" / B AS BIGINT), + CAST(("end" - 1) / B + 1 AS BIGINT) + )) AS __giql_bin FROM peaks + ), + __giql_b_binned AS (...) + SELECT DISTINCT a.chrom, b.start, ... + FROM __giql_a_binned AS a + JOIN __giql_b_binned AS b + ON a."chrom" = b."chrom" AND a.__giql_bin = b.__giql_bin + AND a."start" < b."end" AND a."end" > b."start" - Both are rewritten using key-only bridge CTEs so that ``__giql_bin`` - never appears in the output column set (no SELECT * leak): + **Wildcards present** (``SELECT a.*, b.*``) — ``__giql_bin`` would leak + into ``a.*`` expansion if ``a`` aliases a full-select CTE. A key-only + bridge CTE pattern is used instead, keeping original table references: WITH __giql_peaks_bins AS ( SELECT "chrom", "start", "end", - UNNEST(range( - CAST("start" / B AS BIGINT), - CAST(("end" - 1) / B + 1 AS BIGINT) - )) AS __giql_bin - FROM peaks + UNNEST(range(...)) AS __giql_bin FROM peaks ), __giql_genes_bins AS (...) SELECT DISTINCT a.*, b.* FROM peaks a JOIN __giql_peaks_bins __giql_c0 - ON a."chrom" = __giql_c0."chrom" - AND a."start" = __giql_c0."start" - AND a."end" = __giql_c0."end" + ON a."chrom" = __giql_c0."chrom" AND a."start" = __giql_c0."start" + AND a."end" = __giql_c0."end" JOIN __giql_genes_bins __giql_c1 ON __giql_c0."chrom" = __giql_c1."chrom" AND __giql_c0.__giql_bin = __giql_c1.__giql_bin JOIN genes b - ON b."chrom" = __giql_c1."chrom" - AND b."start" = __giql_c1."start" - AND b."end" = __giql_c1."end" + ON b."chrom" = __giql_c1."chrom" AND b."start" = __giql_c1."start" + AND b."end" = __giql_c1."end" AND a."start" < b."end" AND a."end" > b."start" Literal-range INTERSECTS (e.g., ``WHERE interval INTERSECTS 'chr1:...'``) @@ -633,9 +638,229 @@ def transform(self, query: exp.Expression) -> exp.Expression: if not isinstance(query, exp.Select): return query + # When no wildcards appear in the SELECT list, __giql_bin cannot + # reach the output — use the simpler 1-join full-CTE approach. + # When wildcards are present, fall back to bridge CTEs so that + # __giql_bin is never exposed through a.* expansion. + if self._select_has_wildcards(query): + return self._transform_bridge(query) + return self._transform_full_cte(query) + + def _select_has_wildcards(self, query: exp.Select) -> bool: + """Return True if any SELECT item is a wildcard (* or table.*).""" + for expr in query.expressions: + if isinstance(expr, exp.Star): + return True + if isinstance(expr, exp.Column) and isinstance(expr.this, exp.Star): + return True + return False + + # ------------------------------------------------------------------ + # Full-CTE path (no wildcards — fast, 1 equi-join per INTERSECTS) + # ------------------------------------------------------------------ + + def _transform_full_cte(self, query: exp.Select) -> exp.Select: joins = query.args.get("joins") or [] + # binned: alias -> (chrom_col, start_col, end_col) + binned: dict[str, tuple[str, str, str]] = {} + rewrote_any = False + + for join in joins: + on = join.args.get("on") + if on: + intersects = self._find_column_intersects_in(on) + if intersects: + self._rewrite_join_on_full_cte(query, join, intersects, binned) + rewrote_any = True + + # Implicit cross-join: FROM a, b WHERE a.interval INTERSECTS b.interval + where = query.args.get("where") + if where: + intersects = self._find_column_intersects_in(where.this) + if intersects: + cross_join = self._find_cross_join_for_intersects( + query, intersects, joins + ) + if cross_join is not None: + self._rewrite_cross_join_full_cte( + query, cross_join, intersects, binned + ) + rewrote_any = True + + if rewrote_any: + query.set("distinct", exp.Distinct()) + + return query + + def _rewrite_join_on_full_cte( + self, + query: exp.Select, + join: exp.Join, + intersects: Intersects, + binned: dict[str, tuple[str, str, str]], + ) -> None: + from_table = query.args["from_"].this + join_table = join.this + if not isinstance(from_table, exp.Table) or not isinstance( + join_table, exp.Table + ): + return + + from_alias, from_cols = self._ensure_table_binned_full( + query, from_table, query.args["from_"], binned + ) + join_alias, join_cols = self._ensure_table_binned_full( + query, join_table, join, binned + ) + + equi_join = exp.And( + this=exp.EQ( + this=exp.column(from_cols[0], table=from_alias, quoted=True), + expression=exp.column(join_cols[0], table=join_alias, quoted=True), + ), + expression=exp.EQ( + this=exp.column("__giql_bin", table=from_alias), + expression=exp.column("__giql_bin", table=join_alias), + ), + ) + # Place both equi-join and overlap in ON so LEFT/RIGHT/FULL semantics hold. + join.set( + "on", + exp.And( + this=equi_join, + expression=self._build_overlap( + from_alias, join_alias, from_cols, join_cols + ), + ), + ) + + def _rewrite_cross_join_full_cte( + self, + query: exp.Select, + cross_join: exp.Join, + intersects: Intersects, + binned: dict[str, tuple[str, str, str]], + ) -> None: + from_table = query.args["from_"].this + join_table = cross_join.this + if not isinstance(from_table, exp.Table) or not isinstance( + join_table, exp.Table + ): + return + + from_alias, from_cols = self._ensure_table_binned_full( + query, from_table, query.args["from_"], binned + ) + join_alias, join_cols = self._ensure_table_binned_full( + query, join_table, cross_join, binned + ) + + equi_join = exp.And( + this=exp.EQ( + this=exp.column(from_cols[0], table=from_alias, quoted=True), + expression=exp.column(join_cols[0], table=join_alias, quoted=True), + ), + expression=exp.EQ( + this=exp.column("__giql_bin", table=from_alias), + expression=exp.column("__giql_bin", table=join_alias), + ), + ) + cross_join.set( + "on", + exp.And( + this=equi_join, + expression=self._build_overlap( + from_alias, join_alias, from_cols, join_cols + ), + ), + ) + self._remove_intersects_from_where(query, intersects) + + def _ensure_table_binned_full( + self, + query: exp.Select, + table: exp.Table, + parent: exp.Expression, + binned: dict[str, tuple[str, str, str]], + ) -> tuple[str, tuple[str, str, str]]: + """Create a full SELECT * CTE for *table* if needed; replace ref in *parent*.""" + alias = table.alias or table.name + if alias in binned: + return alias, binned[alias] + + table_name = table.name + cols = self._get_columns(table_name) + cte_name = f"__giql_{alias}_binned" + + cte = exp.CTE( + this=self._build_full_binned_select(table_name, cols), + alias=exp.TableAlias(this=exp.Identifier(this=cte_name)), + ) + existing_with = query.args.get("with_") + if existing_with: + existing_with.append("expressions", cte) + else: + query.set("with_", exp.With(expressions=[cte])) + + parent.set( + "this", + exp.Table( + this=exp.Identifier(this=cte_name), + alias=exp.TableAlias(this=exp.Identifier(this=alias)), + ), + ) + binned[alias] = cols + return alias, cols + + def _build_full_binned_select( + self, table_name: str, cols: tuple[str, str, str] + ) -> exp.Select: + """Build ``SELECT *, UNNEST(range(...)) AS __giql_bin FROM
``.""" + _chrom, start, end = cols + B = self.bin_size - # key_binned: table_name -> CTE name (deduped per underlying table) + low = exp.Cast( + this=exp.Div( + this=exp.column(start, quoted=True), + expression=exp.Literal.number(B), + ), + to=exp.DataType(this=exp.DataType.Type.BIGINT), + ) + high = exp.Cast( + this=exp.Add( + this=exp.Div( + this=exp.Paren( + this=exp.Sub( + this=exp.column(end, quoted=True), + expression=exp.Literal.number(1), + ), + ), + expression=exp.Literal.number(B), + ), + expression=exp.Literal.number(1), + ), + to=exp.DataType(this=exp.DataType.Type.BIGINT), + ) + + range_fn = exp.Anonymous(this="range", expressions=[low, high]) + unnest_fn = exp.Anonymous(this="UNNEST", expressions=[range_fn]) + bin_alias = exp.Alias( + this=unnest_fn, + alias=exp.Identifier(this="__giql_bin"), + ) + + select = exp.Select() + select.select(exp.Star(), copy=False) + select.select(bin_alias, append=True, copy=False) + select.from_(exp.Table(this=exp.Identifier(this=table_name)), copy=False) + return select + + # ------------------------------------------------------------------ + # Bridge path (wildcards present — safe, key-only CTEs + join-back) + # ------------------------------------------------------------------ + + def _transform_bridge(self, query: exp.Select) -> exp.Select: + joins = query.args.get("joins") or [] key_binned: dict[str, str] = {} connector_idx = [0] new_joins: list[exp.Join] = [] @@ -659,7 +884,7 @@ def transform(self, query: exp.Expression) -> exp.Expression: continue new_joins.append(join) - # Handle implicit cross-join: FROM a, b WHERE a.interval INTERSECTS b.interval + # Implicit cross-join: FROM a, b WHERE a.interval INTERSECTS b.interval where = query.args.get("where") if where: intersects = self._find_column_intersects_in(where.this) @@ -747,6 +972,10 @@ def _remove_intersects_from_where( else: intersects.replace(exp.true()) + # ------------------------------------------------------------------ + # Shared helpers + # ------------------------------------------------------------------ + def _get_columns(self, table_name: str) -> tuple[str, str, str]: """Return (chrom, start, end) column names for a table.""" table = self.tables.get(table_name) @@ -754,6 +983,25 @@ def _get_columns(self, table_name: str) -> tuple[str, str, str]: return (table.chrom_col, table.start_col, table.end_col) return (DEFAULT_CHROM_COL, DEFAULT_START_COL, DEFAULT_END_COL) + def _build_overlap( + self, + from_alias: str, + join_alias: str, + from_cols: tuple[str, str, str], + join_cols: tuple[str, str, str], + ) -> exp.And: + """Build ``from.start < join.end AND from.end > join.start``.""" + return exp.And( + this=exp.LT( + this=exp.column(from_cols[1], table=from_alias, quoted=True), + expression=exp.column(join_cols[2], table=join_alias, quoted=True), + ), + expression=exp.GT( + this=exp.column(from_cols[2], table=from_alias, quoted=True), + expression=exp.column(join_cols[1], table=join_alias, quoted=True), + ), + ) + def _find_table_name_for_alias(self, query: exp.Select, alias: str) -> str: """Resolve an alias to its underlying table name.""" from_table = query.args["from_"].this @@ -938,23 +1186,17 @@ def _build_join_back_joins( expression=exp.column(join_cols[2], table=c1, quoted=True), ), ) - overlap = exp.And( - this=exp.LT( - this=exp.column(other_cols[1], table=other_alias, quoted=True), - expression=exp.column(join_cols[2], table=join_alias, quoted=True), - ), - expression=exp.GT( - this=exp.column(other_cols[2], table=other_alias, quoted=True), - expression=exp.column(join_cols[1], table=join_alias, quoted=True), - ), - ) - join3_kwargs: dict = { "this": exp.Table( this=exp.Identifier(this=join_table_name), alias=exp.TableAlias(this=exp.Identifier(this=join_alias)), ), - "on": exp.And(this=key_match, expression=overlap), + "on": exp.And( + this=key_match, + expression=self._build_overlap( + other_alias, join_alias, other_cols, join_cols + ), + ), } if preserve_kind: kind = join.args.get("kind") diff --git a/tests/test_binned_join.py b/tests/test_binned_join.py index 62c8ee0..e8cdf12 100644 --- a/tests/test_binned_join.py +++ b/tests/test_binned_join.py @@ -292,6 +292,52 @@ def test_multi_join_all_intersects_rewritten(self): sql_upper = sql.upper() assert sql_upper.count("__GIQL_BIN") >= 4 # at least 2 per INTERSECTS join + def test_explicit_columns_uses_full_cte_not_bridge(self): + """ + GIVEN a join query with only explicit columns in SELECT (no wildcards) + WHEN transpiling + THEN should use the 1-join full-CTE approach, not bridge CTEs + """ + sql = transpile( + """ + SELECT a.chrom, a.start, b.start AS b_start + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=["peaks", "genes"], + ) + + # Full-CTE names (alias-based, one join per INTERSECTS) + assert "__giql_a_binned" in sql + assert "__giql_b_binned" in sql + + # Bridge CTEs must NOT be present + assert "__giql_peaks_bins" not in sql + assert "__giql_c0" not in sql + + def test_wildcard_select_uses_bridge_not_full_cte(self): + """ + GIVEN a join query with a wildcard expression in SELECT (a.*) + WHEN transpiling + THEN should use the bridge CTE approach, not full-CTEs + """ + sql = transpile( + """ + SELECT a.*, b.* + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=["peaks", "genes"], + ) + + # Bridge CTE names (table-based) + assert "__giql_peaks_bins" in sql + assert "__giql_genes_bins" in sql + + # Full CTEs must NOT be present + assert "__giql_a_binned" not in sql + assert "__giql_b_binned" not in sql + class TestBinnedJoinDataFusion: """End-to-end DataFusion correctness tests for binned INTERSECTS joins.""" From 3a1706330b40c51423fb74ef5b82ce5f58b3ddb7 Mon Sep 17 00:00:00 2001 From: Conrad Date: Wed, 1 Apr 2026 10:05:13 -0400 Subject: [PATCH 11/20] style: Remove structural comment dividers Drop section divider lines (`# --...--`) from `IntersectsBinnedJoinTransformer` to reduce visual clutter. Descriptive inline comments explaining code behavior are preserved. --- src/giql/transformer.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/giql/transformer.py b/src/giql/transformer.py index 7b327aa..50cbd4f 100644 --- a/src/giql/transformer.py +++ b/src/giql/transformer.py @@ -655,10 +655,6 @@ def _select_has_wildcards(self, query: exp.Select) -> bool: return True return False - # ------------------------------------------------------------------ - # Full-CTE path (no wildcards — fast, 1 equi-join per INTERSECTS) - # ------------------------------------------------------------------ - def _transform_full_cte(self, query: exp.Select) -> exp.Select: joins = query.args.get("joins") or [] # binned: alias -> (chrom_col, start_col, end_col) @@ -855,10 +851,6 @@ def _build_full_binned_select( select.from_(exp.Table(this=exp.Identifier(this=table_name)), copy=False) return select - # ------------------------------------------------------------------ - # Bridge path (wildcards present — safe, key-only CTEs + join-back) - # ------------------------------------------------------------------ - def _transform_bridge(self, query: exp.Select) -> exp.Select: joins = query.args.get("joins") or [] key_binned: dict[str, str] = {} @@ -972,10 +964,6 @@ def _remove_intersects_from_where( else: intersects.replace(exp.true()) - # ------------------------------------------------------------------ - # Shared helpers - # ------------------------------------------------------------------ - def _get_columns(self, table_name: str) -> tuple[str, str, str]: """Return (chrom, start, end) column names for a table.""" table = self.tables.get(table_name) From 8986276c930c2845ad47bc701c1f6511c5113b90 Mon Sep 17 00:00:00 2001 From: Conrad Date: Wed, 1 Apr 2026 11:15:04 -0400 Subject: [PATCH 12/20] test: Add regression tests for three binned join bugs Cover outer join semantics (LEFT/RIGHT/FULL preserved through both full-CTE and bridge paths), additional ON conditions surviving the rewrite alongside INTERSECTS, and unconditional DISTINCT collapsing legitimate duplicate rows. The DISTINCT tests are marked xfail since the correct behavior (preserving duplicates) is a known limitation. 7 tests fail against the current implementation, confirming the bugs. 2 tests are strict xfail documenting the DISTINCT limitation. --- tests/test_binned_join.py | 718 +++++++++++++++++++++++++++++++++++++- 1 file changed, 716 insertions(+), 2 deletions(-) diff --git a/tests/test_binned_join.py b/tests/test_binned_join.py index e8cdf12..dc54d8e 100644 --- a/tests/test_binned_join.py +++ b/tests/test_binned_join.py @@ -1,9 +1,23 @@ """Tests for the INTERSECTS binned equi-join transpilation.""" +import math + +import pytest + from giql import Table from giql import transpile +def _is_null(value) -> bool: + """Check if a value is null/NaN (DataFusion returns NaN for nullable int64).""" + if value is None: + return True + try: + return math.isnan(value) + except (TypeError, ValueError): + return False + + class TestTranspileBinnedJoin: """Unit tests for binned join SQL structure.""" @@ -250,8 +264,6 @@ def test_invalid_bin_size_raises(self): WHEN calling transpile THEN should raise ValueError """ - import pytest - with pytest.raises(ValueError, match="positive"): transpile( "SELECT * FROM a JOIN b ON a.interval INTERSECTS b.interval", @@ -608,3 +620,705 @@ def test_implicit_cross_join_correct_rows_no_bin_leak(self): # SELECT a.* must return exactly the original table columns — no __giql_bin assert list(df.columns) == ["chrom", "start", "end"] + + +class TestBinnedJoinOuterJoinSemantics: + """Regression tests: outer join kinds must be preserved after rewrite. + + Bug: the bridge path only applied the join kind (LEFT, RIGHT, FULL) to + join3, while join1 and join2 were always INNER — silently converting + outer joins into inner joins. + """ + + @staticmethod + def _make_ctx(peaks_data, genes_data): + """Create a DataFusion context with peaks and genes tables.""" + import pyarrow as pa + from datafusion import SessionContext + + schema = pa.schema( + [ + ("chrom", pa.utf8()), + ("start", pa.int64()), + ("end", pa.int64()), + ] + ) + ctx = SessionContext() + ctx.register_record_batches( + "peaks", + [ + pa.table( + { + "chrom": [r[0] for r in peaks_data], + "start": [r[1] for r in peaks_data], + "end": [r[2] for r in peaks_data], + }, + schema=schema, + ).to_batches() + ], + ) + ctx.register_record_batches( + "genes", + [ + pa.table( + { + "chrom": [r[0] for r in genes_data], + "start": [r[1] for r in genes_data], + "end": [r[2] for r in genes_data], + }, + schema=schema, + ).to_batches() + ], + ) + return ctx + + def test_left_join_preserves_unmatched_left_rows_full_cte(self): + """ + GIVEN peaks with one matching and one non-matching interval + WHEN a LEFT JOIN with INTERSECTS is transpiled (no wildcards, full-CTE path) + THEN the SQL must contain LEFT keyword and execution must return all + left rows including unmatched ones with NULLs on the right + """ + sql = transpile( + """ + SELECT a.chrom, a.start, a."end", b.start AS b_start + FROM peaks a + LEFT JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + assert "LEFT" in sql.upper() + + ctx = self._make_ctx( + peaks_data=[("chr1", 100, 500), ("chr1", 1000, 2000)], + genes_data=[("chr1", 300, 600)], + ) + df = ctx.sql(sql).to_pandas().sort_values("start").reset_index(drop=True) + + assert len(df) == 2 + assert df.iloc[0]["start"] == 100 + assert df.iloc[0]["b_start"] == 300 + assert df.iloc[1]["start"] == 1000 + assert _is_null(df.iloc[1]["b_start"]) + + def test_left_join_preserves_unmatched_left_rows_bridge(self): + """ + GIVEN peaks with one matching and one non-matching interval + WHEN a LEFT JOIN with INTERSECTS is transpiled (wildcards, bridge path) + THEN the SQL must contain LEFT keyword and execution must return all + left rows including unmatched ones with NULLs on the right + """ + sql = transpile( + """ + SELECT a.*, b.start AS b_start + FROM peaks a + LEFT JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + assert "LEFT" in sql.upper() + + ctx = self._make_ctx( + peaks_data=[("chr1", 100, 500), ("chr1", 1000, 2000)], + genes_data=[("chr1", 300, 600)], + ) + df = ctx.sql(sql).to_pandas().sort_values("start").reset_index(drop=True) + + assert len(df) == 2 + assert df.iloc[0]["start"] == 100 + assert df.iloc[0]["b_start"] == 300 + assert df.iloc[1]["start"] == 1000 + assert _is_null(df.iloc[1]["b_start"]) + + def test_right_join_preserves_unmatched_right_rows_full_cte(self): + """ + GIVEN genes with one matching and one non-matching interval + WHEN a RIGHT JOIN with INTERSECTS is transpiled (no wildcards, full-CTE path) + THEN the SQL must contain RIGHT keyword and execution must return all + right rows including unmatched ones with NULLs on the left + """ + sql = transpile( + """ + SELECT a.start AS a_start, b.chrom, b.start, b."end" + FROM peaks a + RIGHT JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + assert "RIGHT" in sql.upper() + + ctx = self._make_ctx( + peaks_data=[("chr1", 100, 500)], + genes_data=[("chr1", 300, 600), ("chr1", 5000, 6000)], + ) + df = ctx.sql(sql).to_pandas().sort_values("start").reset_index(drop=True) + + assert len(df) == 2 + matched = df[df["a_start"].notna()] + unmatched = df[df["a_start"].isna()] + assert len(matched) == 1 + assert matched.iloc[0]["start"] == 300 + assert len(unmatched) == 1 + assert unmatched.iloc[0]["start"] == 5000 + + def test_right_join_preserves_unmatched_right_rows_bridge(self): + """ + GIVEN genes with one matching and one non-matching interval + WHEN a RIGHT JOIN with INTERSECTS is transpiled (wildcards, bridge path) + THEN the SQL must contain RIGHT keyword and execution must return all + right rows including unmatched ones with NULLs on the left + """ + sql = transpile( + """ + SELECT a.start AS a_start, b.* + FROM peaks a + RIGHT JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + assert "RIGHT" in sql.upper() + + ctx = self._make_ctx( + peaks_data=[("chr1", 100, 500)], + genes_data=[("chr1", 300, 600), ("chr1", 5000, 6000)], + ) + df = ctx.sql(sql).to_pandas().sort_values("start").reset_index(drop=True) + + assert len(df) == 2 + matched = df[df["a_start"].notna()] + unmatched = df[df["a_start"].isna()] + assert len(matched) == 1 + assert matched.iloc[0]["start"] == 300 + assert len(unmatched) == 1 + assert unmatched.iloc[0]["start"] == 5000 + + def test_full_outer_join_preserves_both_unmatched_full_cte(self): + """ + GIVEN peaks and genes each with one matching and one non-matching interval + WHEN a FULL OUTER JOIN with INTERSECTS is transpiled (no wildcards, full-CTE) + THEN the SQL must contain FULL keyword and execution must return three + rows: one matched pair plus one unmatched from each side + """ + sql = transpile( + """ + SELECT a.start AS a_start, a."end" AS a_end, + b.start AS b_start, b."end" AS b_end + FROM peaks a + FULL OUTER JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + assert "FULL" in sql.upper() + + ctx = self._make_ctx( + peaks_data=[("chr1", 100, 500), ("chr1", 8000, 9000)], + genes_data=[("chr1", 300, 600), ("chr1", 5000, 6000)], + ) + df = ctx.sql(sql).to_pandas() + + assert len(df) == 3 + matched = df[df["a_start"].notna() & df["b_start"].notna()] + left_only = df[df["a_start"].notna() & df["b_start"].isna()] + right_only = df[df["a_start"].isna() & df["b_start"].notna()] + assert len(matched) == 1 + assert len(left_only) == 1 + assert len(right_only) == 1 + + def test_full_outer_join_preserves_both_unmatched_bridge(self): + """ + GIVEN peaks and genes each with one matching and one non-matching interval + WHEN a FULL OUTER JOIN with INTERSECTS is transpiled (wildcards, bridge path) + THEN the SQL must contain FULL keyword and execution must return three + rows: one matched pair plus one unmatched from each side + """ + sql = transpile( + """ + SELECT a.*, b.start AS b_start, b."end" AS b_end + FROM peaks a + FULL OUTER JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + assert "FULL" in sql.upper() + + ctx = self._make_ctx( + peaks_data=[("chr1", 100, 500), ("chr1", 8000, 9000)], + genes_data=[("chr1", 300, 600), ("chr1", 5000, 6000)], + ) + df = ctx.sql(sql).to_pandas() + + assert len(df) == 3 + matched = df[df["start"].notna() & df["b_start"].notna()] + left_only = df[df["start"].notna() & df["b_start"].isna()] + right_only = df[df["start"].isna() & df["b_start"].notna()] + assert len(matched) == 1 + assert len(left_only) == 1 + assert len(right_only) == 1 + + def test_left_join_all_unmatched_returns_all_left_rows(self): + """ + GIVEN peaks where no intervals overlap any gene + WHEN a LEFT JOIN with INTERSECTS is transpiled + THEN all left rows must still appear with NULLs on the right + """ + sql = transpile( + """ + SELECT a.chrom, a.start, a."end", b.start AS b_start + FROM peaks a + LEFT JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + ctx = self._make_ctx( + peaks_data=[("chr1", 100, 200), ("chr1", 300, 400)], + genes_data=[("chr1", 500, 600)], + ) + df = ctx.sql(sql).to_pandas() + + assert len(df) == 2 + assert df["b_start"].isna().all() + + +class TestBinnedJoinAdditionalOnConditions: + """Regression tests: non-INTERSECTS conditions in ON must be preserved. + + Bug: the rewrite replaces the entire ON clause with the binned equi-join + and overlap predicate, silently dropping any additional user conditions + like ``AND a.score > b.score``. + """ + + @staticmethod + def _make_ctx_with_score(): + """Create a DataFusion context with peaks and genes tables that include a score column.""" + import pyarrow as pa + from datafusion import SessionContext + + schema = pa.schema( + [ + ("chrom", pa.utf8()), + ("start", pa.int64()), + ("end", pa.int64()), + ("score", pa.int64()), + ] + ) + ctx = SessionContext() + ctx.register_record_batches( + "peaks", + [ + pa.table( + { + "chrom": ["chr1", "chr1"], + "start": [100, 100], + "end": [500, 500], + "score": [10, 50], + }, + schema=schema, + ).to_batches() + ], + ) + ctx.register_record_batches( + "genes", + [ + pa.table( + { + "chrom": ["chr1", "chr1"], + "start": [200, 200], + "end": [600, 600], + "score": [30, 30], + }, + schema=schema, + ).to_batches() + ], + ) + return ctx + + def test_additional_on_condition_preserved_full_cte(self): + """ + GIVEN two overlapping intervals where only one pair satisfies score filter + WHEN INTERSECTS is combined with a.score > b.score in ON (no wildcards) + THEN the additional condition must survive the rewrite and filter results + """ + sql = transpile( + """ + SELECT a.chrom, a.start, a."end", a.score AS a_score, b.score AS b_score + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval AND a.score > b.score + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + assert "score" in sql.lower() + + ctx = self._make_ctx_with_score() + df = ctx.sql(sql).to_pandas() + + assert len(df) == 1 + assert df.iloc[0]["a_score"] == 50 + + def test_additional_on_condition_preserved_bridge(self): + """ + GIVEN two overlapping intervals where only one pair satisfies score filter + WHEN INTERSECTS is combined with a.score > b.score in ON (wildcards) + THEN the additional condition must survive the rewrite and filter results + """ + sql = transpile( + """ + SELECT a.*, b.score AS b_score + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval AND a.score > b.score + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + assert "score" in sql.lower() + + ctx = self._make_ctx_with_score() + df = ctx.sql(sql).to_pandas() + + assert len(df) == 1 + assert df.iloc[0]["score"] == 50 + + def test_additional_on_condition_with_left_join(self): + """ + GIVEN overlapping intervals with an extra ON condition that filters all matches + WHEN LEFT JOIN with INTERSECTS AND a.score > b.score is used + THEN unmatched left rows appear with NULL right columns + """ + sql = transpile( + """ + SELECT a.chrom, a.start, a.score AS a_score, b.score AS b_score + FROM peaks a + LEFT JOIN genes b + ON a.interval INTERSECTS b.interval AND a.score > b.score + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + ctx = self._make_ctx_with_score() + df = ctx.sql(sql).to_pandas().sort_values("a_score").reset_index(drop=True) + + assert len(df) == 2 + row_low = df[df["a_score"] == 10].iloc[0] + row_high = df[df["a_score"] == 50].iloc[0] + assert _is_null(row_low["b_score"]) + assert row_high["b_score"] == 30 + + def test_multiple_additional_conditions_preserved(self): + """ + GIVEN overlapping intervals with two extra ON conditions + WHEN INTERSECTS is combined with a.score > 20 AND b.score < 40 in ON + THEN both conditions must survive the rewrite + """ + sql = transpile( + """ + SELECT a.chrom, a.start, a.score AS a_score, b.score AS b_score + FROM peaks a + JOIN genes b + ON a.interval INTERSECTS b.interval + AND a.score > 20 + AND b.score < 40 + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + sql_lower = sql.lower() + assert "score" in sql_lower + assert "20" in sql + assert "40" in sql + + def test_additional_on_condition_implicit_cross_join(self): + """ + GIVEN overlapping intervals queried via implicit cross-join with extra WHERE + WHEN INTERSECTS is in WHERE alongside a.score > b.score + THEN the score condition must be preserved in the output SQL + """ + sql = transpile( + """ + SELECT a.chrom, a.start, a.score AS a_score, b.score AS b_score + FROM peaks a, genes b + WHERE a.interval INTERSECTS b.interval AND a.score > b.score + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + assert "score" in sql.lower() + + ctx = self._make_ctx_with_score() + df = ctx.sql(sql).to_pandas() + + assert len(df) == 1 + assert df.iloc[0]["a_score"] == 50 + + +class TestBinnedJoinDistinctSemantics: + """Regression tests: unconditional DISTINCT can collapse legitimate duplicates. + + Bug: the transformer always adds DISTINCT to deduplicate bin fan-out, + but this also collapses rows that are genuinely duplicated in the source + data, changing SQL bag semantics. + """ + + @staticmethod + def _make_ctx_with_duplicates(): + """Create a DataFusion context where peaks has duplicate rows.""" + import pyarrow as pa + from datafusion import SessionContext + + schema = pa.schema( + [ + ("chrom", pa.utf8()), + ("start", pa.int64()), + ("end", pa.int64()), + ] + ) + ctx = SessionContext() + ctx.register_record_batches( + "peaks", + [ + pa.table( + { + "chrom": ["chr1", "chr1"], + "start": [100, 100], + "end": [500, 500], + }, + schema=schema, + ).to_batches() + ], + ) + ctx.register_record_batches( + "genes", + [ + pa.table( + { + "chrom": ["chr1"], + "start": [200], + "end": [600], + }, + schema=schema, + ).to_batches() + ], + ) + return ctx + + @pytest.mark.xfail( + reason="Unconditional DISTINCT collapses legitimate duplicate rows", + strict=True, + ) + def test_duplicate_rows_preserved_full_cte(self): + """ + GIVEN peaks with two identical rows that both overlap one gene + WHEN an inner join with INTERSECTS is transpiled (no wildcards, full-CTE) + THEN both rows should be returned, matching naive cross-join behavior + """ + ctx = self._make_ctx_with_duplicates() + + naive_sql = """ + SELECT a.chrom, a.start, a."end", b.start AS b_start + FROM peaks a, genes b + WHERE a.chrom = b.chrom AND a.start < b."end" AND a."end" > b.start + """ + naive_df = ctx.sql(naive_sql).to_pandas() + assert len(naive_df) == 2 + + binned_sql = transpile( + """ + SELECT a.chrom, a.start, a."end", b.start AS b_start + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + binned_df = ctx.sql(binned_sql).to_pandas() + assert len(binned_df) == len(naive_df) + + @pytest.mark.xfail( + reason="Unconditional DISTINCT collapses legitimate duplicate rows", + strict=True, + ) + def test_duplicate_rows_preserved_bridge(self): + """ + GIVEN peaks with two identical rows that both overlap one gene + WHEN an inner join with INTERSECTS is transpiled (wildcards, bridge path) + THEN both rows should be returned, matching naive cross-join behavior + """ + ctx = self._make_ctx_with_duplicates() + + naive_sql = """ + SELECT a.chrom, a.start, a."end", b.start AS b_start + FROM peaks a, genes b + WHERE a.chrom = b.chrom AND a.start < b."end" AND a."end" > b.start + """ + naive_df = ctx.sql(naive_sql).to_pandas() + assert len(naive_df) == 2 + + binned_sql = transpile( + """ + SELECT a.*, b.start AS b_start + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + binned_df = ctx.sql(binned_sql).to_pandas() + assert len(binned_df) == len(naive_df) + + def test_non_duplicate_rows_unaffected(self): + """ + GIVEN peaks with two distinct rows that both overlap one gene + WHEN an inner join with INTERSECTS is transpiled + THEN DISTINCT does not collapse them because they differ + """ + import pyarrow as pa + from datafusion import SessionContext + + schema = pa.schema( + [ + ("chrom", pa.utf8()), + ("start", pa.int64()), + ("end", pa.int64()), + ] + ) + ctx = SessionContext() + ctx.register_record_batches( + "peaks", + [ + pa.table( + { + "chrom": ["chr1", "chr1"], + "start": [100, 150], + "end": [500, 550], + }, + schema=schema, + ).to_batches() + ], + ) + ctx.register_record_batches( + "genes", + [ + pa.table( + { + "chrom": ["chr1"], + "start": [200], + "end": [600], + }, + schema=schema, + ).to_batches() + ], + ) + + binned_sql = transpile( + """ + SELECT a.chrom, a.start, a."end", b.start AS b_start + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + df = ctx.sql(binned_sql).to_pandas() + assert len(df) == 2 + + def test_user_distinct_already_present_still_works(self): + """ + GIVEN a query that already has SELECT DISTINCT + WHEN the binned join rewrite also adds DISTINCT + THEN the query must still execute correctly (no double-DISTINCT error) + """ + import pyarrow as pa + from datafusion import SessionContext + + schema = pa.schema( + [ + ("chrom", pa.utf8()), + ("start", pa.int64()), + ("end", pa.int64()), + ] + ) + ctx = SessionContext() + ctx.register_record_batches( + "peaks", + [ + pa.table( + {"chrom": ["chr1"], "start": [100], "end": [500]}, + schema=schema, + ).to_batches() + ], + ) + ctx.register_record_batches( + "genes", + [ + pa.table( + {"chrom": ["chr1"], "start": [200], "end": [600]}, + schema=schema, + ).to_batches() + ], + ) + + binned_sql = transpile( + """ + SELECT DISTINCT a.chrom, a.start, b.start AS b_start + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + df = ctx.sql(binned_sql).to_pandas() + assert len(df) == 1 From 0530c5d1d51ace1ca1a43bad0db2acb01efcb1d5 Mon Sep 17 00:00:00 2001 From: Conrad Date: Wed, 1 Apr 2026 11:47:37 -0400 Subject: [PATCH 13/20] fix: Preserve outer join semantics and extra ON conditions Two interrelated fixes for the binned equi-join rewrite: The bridge path was silently converting LEFT/RIGHT/FULL joins to INNER because sqlglot stores the join type as "side" not "kind", and only join3 received it. Propagate the side attribute to both join2 and join3. FULL OUTER with wildcards falls back to the full-CTE path because the three-join chain's bin fan-out creates spurious unmatched rows that DISTINCT cannot resolve. Both rewrite paths were replacing the entire ON clause with the binned equi-join and overlap predicate, silently dropping any user-supplied conditions alongside INTERSECTS. Extract non- INTERSECTS conditions from the original ON and AND them back into the rewritten clause. --- src/giql/transformer.py | 90 ++++++++++++++++++++++++++++------------- 1 file changed, 63 insertions(+), 27 deletions(-) diff --git a/src/giql/transformer.py b/src/giql/transformer.py index 50cbd4f..b85be6c 100644 --- a/src/giql/transformer.py +++ b/src/giql/transformer.py @@ -638,11 +638,10 @@ def transform(self, query: exp.Expression) -> exp.Expression: if not isinstance(query, exp.Select): return query - # When no wildcards appear in the SELECT list, __giql_bin cannot - # reach the output — use the simpler 1-join full-CTE approach. - # When wildcards are present, fall back to bridge CTEs so that - # __giql_bin is never exposed through a.* expansion. - if self._select_has_wildcards(query): + # The bridge path can't faithfully represent FULL OUTER JOIN + # because the three-join chain's bin fan-out creates spurious + # unmatched rows. Fall back to full-CTE for those queries. + if self._select_has_wildcards(query) and not self._has_full_outer_join(query): return self._transform_bridge(query) return self._transform_full_cte(query) @@ -655,6 +654,13 @@ def _select_has_wildcards(self, query: exp.Select) -> bool: return True return False + def _has_full_outer_join(self, query: exp.Select) -> bool: + """Return True if any JOIN in the query is a FULL OUTER JOIN.""" + for join in query.args.get("joins") or []: + if join.args.get("side") == "FULL": + return True + return False + def _transform_full_cte(self, query: exp.Select) -> exp.Select: joins = query.args.get("joins") or [] # binned: alias -> (chrom_col, start_col, end_col) @@ -709,6 +715,8 @@ def _rewrite_join_on_full_cte( query, join_table, join, binned ) + extra = self._extract_non_intersects(join.args.get("on"), intersects) + equi_join = exp.And( this=exp.EQ( this=exp.column(from_cols[0], table=from_alias, quoted=True), @@ -720,15 +728,13 @@ def _rewrite_join_on_full_cte( ), ) # Place both equi-join and overlap in ON so LEFT/RIGHT/FULL semantics hold. - join.set( - "on", - exp.And( - this=equi_join, - expression=self._build_overlap( - from_alias, join_alias, from_cols, join_cols - ), - ), + new_on = exp.And( + this=equi_join, + expression=self._build_overlap(from_alias, join_alias, from_cols, join_cols), ) + if extra: + new_on = exp.And(this=new_on, expression=extra) + join.set("on", new_on) def _rewrite_cross_join_full_cte( self, @@ -964,6 +970,26 @@ def _remove_intersects_from_where( else: intersects.replace(exp.true()) + def _extract_non_intersects( + self, expr: exp.Expression | None, intersects: Intersects + ) -> exp.Expression | None: + """Return the parts of an AND tree that are not the INTERSECTS node.""" + if expr is None or expr is intersects: + return None + if isinstance(expr, exp.And): + if expr.this is intersects: + return expr.expression + if expr.expression is intersects: + return expr.this + left = self._extract_non_intersects(expr.this, intersects) + right = self._extract_non_intersects(expr.expression, intersects) + if left is None: + return right + if right is None: + return left + return exp.And(this=left, expression=right) + return expr + def _get_columns(self, table_name: str) -> tuple[str, str, str]: """Return (chrom, start, end) column names for a table.""" table = self.tables.get(table_name) @@ -1104,6 +1130,8 @@ def _build_join_back_joins( if other_alias == join_alias: return [join] # can't determine structure + extra = self._extract_non_intersects(join.args.get("on"), intersects) + other_table_name = self._find_table_name_for_alias(query, other_alias) other_cols = self._get_columns(other_table_name) join_cols = self._get_columns(join_table_name) @@ -1115,6 +1143,10 @@ def _build_join_back_joins( c1 = f"__giql_c{connector_idx[0] + 1}" connector_idx[0] += 2 + join_side = None + if preserve_kind: + join_side = join.args.get("side") + # join1: key-match from other_alias to its bin CTE join1 = exp.Join( this=exp.Table( @@ -1140,12 +1172,12 @@ def _build_join_back_joins( ) # join2: bin equi-join (chrom + __giql_bin match) - join2 = exp.Join( - this=exp.Table( + join2_kwargs: dict = { + "this": exp.Table( this=exp.Identifier(this=join_cte), alias=exp.TableAlias(this=exp.Identifier(this=c1)), ), - on=exp.And( + "on": exp.And( this=exp.EQ( this=exp.column(other_cols[0], table=c0, quoted=True), expression=exp.column(join_cols[0], table=c1, quoted=True), @@ -1155,7 +1187,10 @@ def _build_join_back_joins( expression=exp.column("__giql_bin", table=c1), ), ), - ) + } + if join_side: + join2_kwargs["side"] = join_side + join2 = exp.Join(**join2_kwargs) # join3: key-match from join CTE back to actual join table + overlap key_match = exp.And( @@ -1174,22 +1209,23 @@ def _build_join_back_joins( expression=exp.column(join_cols[2], table=c1, quoted=True), ), ) + join3_on = exp.And( + this=key_match, + expression=self._build_overlap( + other_alias, join_alias, other_cols, join_cols + ), + ) + if extra: + join3_on = exp.And(this=join3_on, expression=extra) join3_kwargs: dict = { "this": exp.Table( this=exp.Identifier(this=join_table_name), alias=exp.TableAlias(this=exp.Identifier(this=join_alias)), ), - "on": exp.And( - this=key_match, - expression=self._build_overlap( - other_alias, join_alias, other_cols, join_cols - ), - ), + "on": join3_on, } - if preserve_kind: - kind = join.args.get("kind") - if kind: - join3_kwargs["kind"] = kind + if join_side: + join3_kwargs["side"] = join_side join3 = exp.Join(**join3_kwargs) From 4c3526cc2df39778e31454203851dd4eb0cddb20 Mon Sep 17 00:00:00 2001 From: Conrad Date: Wed, 1 Apr 2026 12:01:36 -0400 Subject: [PATCH 14/20] docs: Document INTERSECTS binned join deduplication behavior DISTINCT is added unconditionally to column-to-column INTERSECTS joins to eliminate duplicates from the bin fan-out. This section explains the mechanism, the edge case where it can collapse genuinely identical source rows, and the mitigation of including any distinguishing column in the SELECT list. --- docs/dialect/spatial-operators.rst | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/docs/dialect/spatial-operators.rst b/docs/dialect/spatial-operators.rst index 6bf4433..1a8b3cd 100644 --- a/docs/dialect/spatial-operators.rst +++ b/docs/dialect/spatial-operators.rst @@ -99,6 +99,24 @@ Find all variants, with gene information where available: FROM variants v LEFT JOIN genes g ON v.interval INTERSECTS g.interval +Deduplication Behavior +~~~~~~~~~~~~~~~~~~~~~~ + +Column-to-column ``INTERSECTS`` joins use a binned equi-join strategy internally: each interval is assigned to one or more fixed-width bins, and the join is performed on ``(chrom, bin)`` pairs. Because an interval that spans a bin boundary belongs to more than one bin, a single source row can match the same result row more than once. GIQL adds ``SELECT DISTINCT`` automatically to remove these duplicate rows. + +This deduplication is usually transparent, but it has one observable side effect: ``DISTINCT`` operates on the entire set of selected columns, so rows that are genuinely identical across every selected column will also be collapsed into one. This matters when a table contains duplicate source records with no distinguishing column. + +To prevent unintended deduplication, include any column that makes rows distinguishable — such as a primary key, name, or score — in the ``SELECT`` list: + +.. code-block:: sql + + -- score distinguishes otherwise-identical rows + SELECT v.chrom, v.start, v.end, v.score, g.name + FROM variants v + INNER JOIN genes g ON v.interval INTERSECTS g.interval + +If all columns are identical across two source rows (including any unique identifier), those rows represent the same logical record and collapsing them is correct behavior. + Related Operators ~~~~~~~~~~~~~~~~~ From a702a6ea34397dce690813848159c0fd953b8934 Mon Sep 17 00:00:00 2001 From: Conrad Date: Wed, 1 Apr 2026 12:26:25 -0400 Subject: [PATCH 15/20] refactor: Address review findings on IntersectsBinnedJoinTransformer Move DEFAULT_BIN_SIZE to constants module and export from __init__. Extract shared _build_bin_range helper to eliminate duplicate bin-computation logic between the two CTE builders. Replace the mutable-list connector counter with itertools.count. Add isinstance check for bin_size so floats are rejected early. Rewrite _remove_intersects_from_where to use _extract_non_intersects so deeply-nested AND trees are handled cleanly. Expand docstrings on the class, __init__, _find_column_intersects_in, and _build_join_back_joins to document assumptions and limitations. --- src/giql/__init__.py | 2 + src/giql/constants.py | 3 + src/giql/transformer.py | 123 ++++++++++++++++++++-------------------- 3 files changed, 67 insertions(+), 61 deletions(-) diff --git a/src/giql/__init__.py b/src/giql/__init__.py index 71e895d..e5df351 100644 --- a/src/giql/__init__.py +++ b/src/giql/__init__.py @@ -3,6 +3,7 @@ A SQL dialect for genomic range queries. """ +from giql.constants import DEFAULT_BIN_SIZE from giql.table import Table from giql.transpile import transpile @@ -10,6 +11,7 @@ __all__ = [ + "DEFAULT_BIN_SIZE", "Table", "transpile", ] diff --git a/src/giql/constants.py b/src/giql/constants.py index 87f8055..0daf016 100644 --- a/src/giql/constants.py +++ b/src/giql/constants.py @@ -9,3 +9,6 @@ DEFAULT_END_COL = "end" DEFAULT_STRAND_COL = "strand" DEFAULT_GENOMIC_COL = "interval" + +# Default bin size for INTERSECTS binned equi-join optimization +DEFAULT_BIN_SIZE = 10_000 diff --git a/src/giql/transformer.py b/src/giql/transformer.py index b85be6c..1507389 100644 --- a/src/giql/transformer.py +++ b/src/giql/transformer.py @@ -5,8 +5,11 @@ SQL with CTEs. """ +import itertools + from sqlglot import exp +from giql.constants import DEFAULT_BIN_SIZE from giql.constants import DEFAULT_CHROM_COL from giql.constants import DEFAULT_END_COL from giql.constants import DEFAULT_START_COL @@ -16,8 +19,6 @@ from giql.expressions import Intersects from giql.table import Tables -DEFAULT_BIN_SIZE = 10_000 - class ClusterTransformer: """Transforms queries containing CLUSTER into CTE-based queries. @@ -626,13 +627,30 @@ class IntersectsBinnedJoinTransformer: Literal-range INTERSECTS (e.g., ``WHERE interval INTERSECTS 'chr1:...'``) are left untouched. + + SELECT DISTINCT is added to deduplicate rows produced by multi-bin + matches. This means rows that are identical across every selected + column will be collapsed — include a distinguishing column (e.g., an + id or score) to preserve duplicates that differ only in unselected + columns. The bridge path's key-match joins on ``(chrom, start, + end)`` and may fan out if multiple source rows share those values; + DISTINCT corrects for this. """ def __init__(self, tables: Tables, bin_size: int | None = None): + """Initialize transformer. + + :param tables: + Table configurations for column mapping + :param bin_size: + Bin width for the equi-join rewrite. Defaults to + DEFAULT_BIN_SIZE if not specified. + """ self.tables = tables - self.bin_size = bin_size if bin_size is not None else DEFAULT_BIN_SIZE - if self.bin_size <= 0: - raise ValueError(f"bin_size must be a positive integer, got {self.bin_size}") + resolved = bin_size if bin_size is not None else DEFAULT_BIN_SIZE + if not isinstance(resolved, int) or resolved <= 0: + raise ValueError(f"bin_size must be a positive integer, got {resolved!r}") + self.bin_size = resolved def transform(self, query: exp.Expression) -> exp.Expression: if not isinstance(query, exp.Select): @@ -663,7 +681,6 @@ def _has_full_outer_join(self, query: exp.Select) -> bool: def _transform_full_cte(self, query: exp.Select) -> exp.Select: joins = query.args.get("joins") or [] - # binned: alias -> (chrom_col, start_col, end_col) binned: dict[str, tuple[str, str, str]] = {} rewrote_any = False @@ -814,17 +831,18 @@ def _ensure_table_binned_full( binned[alias] = cols return alias, cols - def _build_full_binned_select( - self, table_name: str, cols: tuple[str, str, str] - ) -> exp.Select: - """Build ``SELECT *, UNNEST(range(...)) AS __giql_bin FROM
``.""" - _chrom, start, end = cols - B = self.bin_size + def _build_bin_range(self, start: str, end: str) -> tuple[exp.Cast, exp.Cast]: + """Build the (low, high) bin-index expressions for UNNEST(range(...)). + + Returns CAST(start / bin_size AS BIGINT) and + CAST((end - 1) / bin_size + 1 AS BIGINT). + """ + bs = self.bin_size low = exp.Cast( this=exp.Div( this=exp.column(start, quoted=True), - expression=exp.Literal.number(B), + expression=exp.Literal.number(bs), ), to=exp.DataType(this=exp.DataType.Type.BIGINT), ) @@ -837,12 +855,20 @@ def _build_full_binned_select( expression=exp.Literal.number(1), ), ), - expression=exp.Literal.number(B), + expression=exp.Literal.number(bs), ), expression=exp.Literal.number(1), ), to=exp.DataType(this=exp.DataType.Type.BIGINT), ) + return low, high + + def _build_full_binned_select( + self, table_name: str, cols: tuple[str, str, str] + ) -> exp.Select: + """Build ``SELECT *, UNNEST(range(...)) AS __giql_bin FROM
``.""" + _chrom, start, end = cols + low, high = self._build_bin_range(start, end) range_fn = exp.Anonymous(this="range", expressions=[low, high]) unnest_fn = exp.Anonymous(this="UNNEST", expressions=[range_fn]) @@ -860,7 +886,7 @@ def _build_full_binned_select( def _transform_bridge(self, query: exp.Select) -> exp.Select: joins = query.args.get("joins") or [] key_binned: dict[str, str] = {} - connector_idx = [0] + connector_counter = itertools.count() new_joins: list[exp.Join] = [] rewrote_any = False @@ -874,7 +900,7 @@ def _transform_bridge(self, query: exp.Select) -> exp.Select: join, intersects, key_binned, - connector_idx, + connector_counter, preserve_kind=True, ) new_joins.extend(extra) @@ -882,7 +908,6 @@ def _transform_bridge(self, query: exp.Select) -> exp.Select: continue new_joins.append(join) - # Implicit cross-join: FROM a, b WHERE a.interval INTERSECTS b.interval where = query.args.get("where") if where: intersects = self._find_column_intersects_in(where.this) @@ -897,7 +922,7 @@ def _transform_bridge(self, query: exp.Select) -> exp.Select: cross_join, intersects, key_binned, - connector_idx, + connector_counter, preserve_kind=False, ) new_joins.extend(extra) @@ -911,7 +936,12 @@ def _transform_bridge(self, query: exp.Select) -> exp.Select: return query def _find_column_intersects_in(self, expr: exp.Expression) -> Intersects | None: - """Find an Intersects node where both sides are table-qualified columns.""" + """Return the first column-to-column Intersects node in *expr*, or None. + + Only the first match is returned. A single JOIN with multiple + INTERSECTS conditions in its ON clause is not supported; only the + first will be rewritten. + """ for node in expr.find_all(Intersects): if ( isinstance(node.this, exp.Column) @@ -957,18 +987,11 @@ def _remove_intersects_from_where( where = query.args.get("where") if not where: return - where_expr = where.this - if where_expr is intersects: + remainder = self._extract_non_intersects(where.this, intersects) + if remainder is None: query.set("where", None) - elif isinstance(where_expr, exp.And): - if where_expr.this is intersects: - query.set("where", exp.Where(this=where_expr.expression)) - elif where_expr.expression is intersects: - query.set("where", exp.Where(this=where_expr.this)) - else: - intersects.replace(exp.true()) else: - intersects.replace(exp.true()) + query.set("where", exp.Where(this=remainder)) def _extract_non_intersects( self, expr: exp.Expression | None, intersects: Intersects @@ -1034,30 +1057,7 @@ def _build_key_only_bins_select( ) -> exp.Select: """Build ``SELECT chrom, start, end, UNNEST(range(...)) AS __giql_bin FROM table``.""" chrom, start, end = cols - B = self.bin_size - - low = exp.Cast( - this=exp.Div( - this=exp.column(start, quoted=True), - expression=exp.Literal.number(B), - ), - to=exp.DataType(this=exp.DataType.Type.BIGINT), - ) - high = exp.Cast( - this=exp.Add( - this=exp.Div( - this=exp.Paren( - this=exp.Sub( - this=exp.column(end, quoted=True), - expression=exp.Literal.number(1), - ), - ), - expression=exp.Literal.number(B), - ), - expression=exp.Literal.number(1), - ), - to=exp.DataType(this=exp.DataType.Type.BIGINT), - ) + low, high = self._build_bin_range(start, end) range_fn = exp.Anonymous(this="range", expressions=[low, high]) unnest_fn = exp.Anonymous(this="UNNEST", expressions=[range_fn]) @@ -1106,16 +1106,18 @@ def _build_join_back_joins( join: exp.Join, intersects: Intersects, key_binned: dict[str, str], - connector_idx: list[int], + connector_counter: itertools.count, *, preserve_kind: bool, ) -> list[exp.Join]: """Build three replacement JOINs for one INTERSECTS using the join-back pattern. - join1: JOIN key_cte_for_other connector_a ON other_alias key-matches connector_a - join2: JOIN key_cte_for_join connector_b ON connector_a equi-joins connector_b - join3: JOIN original_join_table join_alias ON join_alias key-matches connector_b - AND overlap predicate + join1 is always INNER because it key-matches a table against its + own bin CTE — every row has a corresponding bin entry by + construction, so the join side has no effect. + + join2 and join3 inherit the original join's side (LEFT, RIGHT) + when *preserve_kind* is True. """ join_table = join.this if not isinstance(join_table, exp.Table): @@ -1139,9 +1141,8 @@ def _build_join_back_joins( other_cte = self._ensure_key_binned(query, other_table_name, key_binned) join_cte = self._ensure_key_binned(query, join_table_name, key_binned) - c0 = f"__giql_c{connector_idx[0]}" - c1 = f"__giql_c{connector_idx[0] + 1}" - connector_idx[0] += 2 + c0 = f"__giql_c{next(connector_counter)}" + c1 = f"__giql_c{next(connector_counter)}" join_side = None if preserve_kind: From 109fceb869991380d1bc2c4c5db795eddf8b4379 Mon Sep 17 00:00:00 2001 From: Conrad Date: Wed, 1 Apr 2026 12:26:37 -0400 Subject: [PATCH 16/20] style: Remove structural comment from transpile.py --- src/giql/transpile.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/giql/transpile.py b/src/giql/transpile.py index df0264f..9bf8076 100644 --- a/src/giql/transpile.py +++ b/src/giql/transpile.py @@ -133,7 +133,6 @@ def transpile( # Apply transformations try: - # Binned join rewrite for column-to-column INTERSECTS joins ast = intersects_transformer.transform(ast) # MERGE transformation (which may internally use CLUSTER) ast = merge_transformer.transform(ast) From 5b4a94ad24c81fa1abe4fcb2e2d4cd4bd6fa081a Mon Sep 17 00:00:00 2001 From: Conrad Date: Wed, 1 Apr 2026 13:47:50 -0400 Subject: [PATCH 17/20] test: Add property-based bedtools correctness tests for INTERSECTS Use hypothesis to generate random intervals spanning multiple bins and verify that the binned equi-join produces identical results to bedtools intersect -u. Three tests cover two-table joins, self-joins, and varying bin sizes (100 to 100k). Intervals use unique names to avoid the known DISTINCT duplicate-collapse limitation. --- .../bedtools/test_intersect_property.py | 176 ++++++++++++++++++ 1 file changed, 176 insertions(+) create mode 100644 tests/integration/bedtools/test_intersect_property.py diff --git a/tests/integration/bedtools/test_intersect_property.py b/tests/integration/bedtools/test_intersect_property.py new file mode 100644 index 0000000..0c59fb5 --- /dev/null +++ b/tests/integration/bedtools/test_intersect_property.py @@ -0,0 +1,176 @@ +"""Property-based correctness tests for INTERSECTS binned equi-join. + +These tests use hypothesis to generate random genomic intervals of +varying sizes — including intervals that span multiple bins — and +verify that GIQL's binned equi-join produces identical results to +bedtools intersect. +""" + +from hypothesis import HealthCheck +from hypothesis import given +from hypothesis import settings +from hypothesis import strategies as st + +from giql import transpile + +from .utils.bedtools_wrapper import intersect +from .utils.comparison import compare_results +from .utils.data_models import GenomicInterval +from .utils.duckdb_loader import load_intervals + +duckdb = __import__("pytest").importorskip("duckdb") + + +# --------------------------------------------------------------------------- +# Strategies +# --------------------------------------------------------------------------- + +CHROMS = ["chr1", "chr2", "chr3"] + + +@st.composite +def genomic_interval_st(draw, idx=None): + """Generate a random GenomicInterval that can span multiple 10k bins.""" + chrom = draw(st.sampled_from(CHROMS)) + start = draw(st.integers(min_value=0, max_value=1_000_000)) + length = draw(st.integers(min_value=1, max_value=200_000)) + score = draw(st.integers(min_value=0, max_value=1000)) + strand = draw(st.sampled_from(["+", "-"])) + # Name is set by the list strategy to guarantee uniqueness, avoiding + # the known DISTINCT duplicate-collapse limitation. + name = ( + f"r{idx}" + if idx is not None + else draw(st.from_regex(r"r[0-9]{1,6}", fullmatch=True)) + ) + return GenomicInterval(chrom, start, start + length, name, score, strand) + + +@st.composite +def unique_interval_list_st(draw, max_size=60): + """Generate a list of intervals with unique names.""" + n = draw(st.integers(min_value=1, max_value=max_size)) + intervals = [] + for i in range(n): + iv = draw(genomic_interval_st(idx=i)) + intervals.append(iv) + return intervals + + +interval_list_st = unique_interval_list_st() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _run_giql(intervals_a, intervals_b): + """Run the binned-join INTERSECTS query via DuckDB and return result rows.""" + conn = duckdb.connect(":memory:") + try: + load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b]) + + sql = transpile( + """ + SELECT DISTINCT a.* + FROM intervals_a a, intervals_b b + WHERE a.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b"], + ) + return conn.execute(sql).fetchall() + finally: + conn.close() + + +def _run_bedtools(intervals_a, intervals_b): + """Run bedtools intersect -u and return result tuples.""" + return intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@given(intervals_a=interval_list_st, intervals_b=interval_list_st) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.too_slow], +) +def test_binned_join_matches_bedtools(intervals_a, intervals_b): + """ + GIVEN two randomly generated sets of genomic intervals + WHEN GIQL INTERSECTS binned equi-join is executed + THEN results match bedtools intersect -u exactly + """ + giql_result = _run_giql(intervals_a, intervals_b) + bedtools_result = _run_bedtools(intervals_a, intervals_b) + + comparison = compare_results(giql_result, bedtools_result) + assert comparison.match, comparison.failure_message() + + +@given(intervals=interval_list_st) +@settings( + max_examples=30, + deadline=None, + suppress_health_check=[HealthCheck.too_slow], +) +def test_self_join_matches_bedtools(intervals): + """ + GIVEN a randomly generated set of genomic intervals + WHEN GIQL INTERSECTS self-join is executed + THEN results match bedtools intersect -u with the same file as A and B + """ + giql_result = _run_giql(intervals, intervals) + bedtools_result = _run_bedtools(intervals, intervals) + + comparison = compare_results(giql_result, bedtools_result) + assert comparison.match, comparison.failure_message() + + +@given( + intervals_a=unique_interval_list_st(max_size=30), + intervals_b=unique_interval_list_st(max_size=30), + bin_size=st.sampled_from([100, 1_000, 10_000, 100_000]), +) +@settings( + max_examples=40, + deadline=None, + suppress_health_check=[HealthCheck.too_slow], +) +def test_bin_size_does_not_affect_correctness(intervals_a, intervals_b, bin_size): + """ + GIVEN two randomly generated sets of genomic intervals and a bin size + WHEN GIQL INTERSECTS is executed with that bin size + THEN results match bedtools intersect -u regardless of bin size + """ + conn = duckdb.connect(":memory:") + try: + load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b]) + + sql = transpile( + """ + SELECT DISTINCT a.* + FROM intervals_a a, intervals_b b + WHERE a.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b"], + bin_size=bin_size, + ) + giql_result = conn.execute(sql).fetchall() + finally: + conn.close() + + bedtools_result = _run_bedtools(intervals_a, intervals_b) + + comparison = compare_results(giql_result, bedtools_result) + assert comparison.match, f"bin_size={bin_size}: {comparison.failure_message()}" From 2eb59320880d2dd3a67475460f0d81910e11692b Mon Sep 17 00:00:00 2001 From: Conrad Date: Wed, 1 Apr 2026 15:48:26 -0400 Subject: [PATCH 18/20] fix: Use IntDiv for bin indices and pairs CTE for outer joins Two correctness fixes for the binned equi-join rewrite: 1. Bin index rounding: CAST(start / B AS BIGINT) uses float division, so values like 621950/100 = 6219.5 round to 6220 instead of flooring to 6219. Replace Div+Cast with IntDiv (//) which does proper integer floor division on all engines. 2. Outer join spurious NULLs: when an interval spans multiple bins, the LEFT/RIGHT/FULL outer join produces one row per bin. Bins that don't match the other side create NULL rows even though the same source row matches via a different bin. DISTINCT can't collapse these because NULL and non-NULL rows differ. Add a pairs-CTE approach that computes matching (left_key, right_key) pairs via an INNER binned join with DISTINCT, then outer-joins the original tables through this pairs CTE. This matches the pattern used by Databricks and Snowflake, which restrict binning to INNER joins and use separate logic for outer join semantics. --- src/giql/transformer.py | 297 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 269 insertions(+), 28 deletions(-) diff --git a/src/giql/transformer.py b/src/giql/transformer.py index 1507389..789e3c5 100644 --- a/src/giql/transformer.py +++ b/src/giql/transformer.py @@ -656,10 +656,14 @@ def transform(self, query: exp.Expression) -> exp.Expression: if not isinstance(query, exp.Select): return query - # The bridge path can't faithfully represent FULL OUTER JOIN - # because the three-join chain's bin fan-out creates spurious - # unmatched rows. Fall back to full-CTE for those queries. - if self._select_has_wildcards(query) and not self._has_full_outer_join(query): + # Outer joins need the pairs-CTE approach: compute matching key + # pairs via an INNER binned join (correctly deduplicated), then + # outer-join the original tables through the pairs CTE. This + # avoids the bin fan-out that creates spurious NULL rows when an + # interval spans multiple bins but only matches in some of them. + if self._has_outer_join_intersects(query): + return self._transform_with_pairs(query) + if self._select_has_wildcards(query): return self._transform_bridge(query) return self._transform_full_cte(query) @@ -672,13 +676,253 @@ def _select_has_wildcards(self, query: exp.Select) -> bool: return True return False - def _has_full_outer_join(self, query: exp.Select) -> bool: - """Return True if any JOIN in the query is a FULL OUTER JOIN.""" + def _has_outer_join_intersects(self, query: exp.Select) -> bool: + """Return True if any outer JOIN has an INTERSECTS predicate.""" for join in query.args.get("joins") or []: - if join.args.get("side") == "FULL": - return True + if join.args.get("side") and join.args.get("on"): + if self._find_column_intersects_in(join.args["on"]): + return True return False + def _transform_with_pairs(self, query: exp.Select) -> exp.Select: + """Transform using a pairs CTE for correct outer join semantics. + + Computes matching (left_key, right_key) pairs via an INNER + binned join with DISTINCT, then outer-joins the original tables + through this pairs CTE. This avoids bin fan-out on the + preserved side of the outer join. + """ + joins = query.args.get("joins") or [] + key_binned: dict[str, str] = {} + pairs_idx = 0 + new_joins: list[exp.Join] = [] + rewrote_any = False + + for join in joins: + on = join.args.get("on") + if on: + intersects = self._find_column_intersects_in(on) + if intersects: + extra = self._extract_non_intersects(on, intersects) + replacement = self._build_pairs_replacement_joins( + query, join, intersects, extra, key_binned, pairs_idx + ) + new_joins.extend(replacement) + pairs_idx += 1 + rewrote_any = True + continue + new_joins.append(join) + + where = query.args.get("where") + if where: + intersects = self._find_column_intersects_in(where.this) + if intersects: + cross_join = self._find_cross_join_for_intersects( + query, intersects, new_joins + ) + if cross_join is not None: + new_joins.remove(cross_join) + replacement = self._build_pairs_replacement_joins( + query, + cross_join, + intersects, + None, + key_binned, + pairs_idx, + ) + new_joins.extend(replacement) + self._remove_intersects_from_where(query, intersects) + pairs_idx += 1 + rewrote_any = True + + if rewrote_any: + query.set("joins", new_joins) + query.set("distinct", exp.Distinct()) + + return query + + def _build_pairs_cte( + self, + name: str, + l_cte: str, + r_cte: str, + l_cols: tuple[str, str, str], + r_cols: tuple[str, str, str], + ) -> exp.CTE: + """Build a DISTINCT inner-join pairs CTE. + + Returns a CTE named *name* that selects the six key columns + (__giql_l_chrom, __giql_l_start, __giql_l_end, __giql_r_chrom, + __giql_r_start, __giql_r_end) from an INNER join of the two bin + CTEs on chrom, __giql_bin, and the overlap predicate. + """ + l_alias = "__giql_l" + r_alias = "__giql_r" + + select = exp.Select() + select.set("distinct", exp.Distinct()) + + first = True + for tbl_alias, cols, prefix in [ + (l_alias, l_cols, "__giql_l"), + (r_alias, r_cols, "__giql_r"), + ]: + for col, suffix in zip(cols, ["_chrom", "_start", "_end"]): + col_expr = exp.Alias( + this=exp.column(col, table=tbl_alias, quoted=True), + alias=exp.Identifier(this=f"{prefix}{suffix}"), + ) + select.select(col_expr, append=not first, copy=False) + first = False + + select.from_( + exp.Table( + this=exp.Identifier(this=l_cte), + alias=exp.TableAlias(this=exp.Identifier(this=l_alias)), + ), + copy=False, + ) + + join_on = exp.And( + this=exp.And( + this=exp.EQ( + this=exp.column(l_cols[0], table=l_alias, quoted=True), + expression=exp.column(r_cols[0], table=r_alias, quoted=True), + ), + expression=exp.EQ( + this=exp.column("__giql_bin", table=l_alias), + expression=exp.column("__giql_bin", table=r_alias), + ), + ), + expression=self._build_overlap(l_alias, r_alias, l_cols, r_cols), + ) + + select.join( + exp.Table( + this=exp.Identifier(this=r_cte), + alias=exp.TableAlias(this=exp.Identifier(this=r_alias)), + ), + on=join_on, + copy=False, + ) + + return exp.CTE( + this=select, + alias=exp.TableAlias(this=exp.Identifier(this=name)), + ) + + def _build_pairs_replacement_joins( + self, + query: exp.Select, + join: exp.Join, + intersects: Intersects, + extra: exp.Expression | None, + key_binned: dict[str, str], + pairs_idx: int, + ) -> list[exp.Join]: + """Build a pairs CTE and two replacement joins for one INTERSECTS. + + Returns two joins: + - join1: from_alias [SIDE] JOIN __giql_pairs ON from.key = pairs.from_key + - join2: [SIDE] JOIN join_table ON join.key = pairs.join_key [AND extra] + """ + from_table = query.args["from_"].this + join_table = join.this + if not isinstance(from_table, exp.Table) or not isinstance( + join_table, exp.Table + ): + return [join] + + from_alias = from_table.alias or from_table.name + join_alias = join_table.alias or join_table.name + from_table_name = from_table.name + join_table_name = join_table.name + + left_alias = intersects.this.table + from_cols = self._get_columns(from_table_name) + join_cols = self._get_columns(join_table_name) + + # Determine which INTERSECTS side maps to FROM vs JOIN table + if left_alias == from_alias: + l_table_name, r_table_name = from_table_name, join_table_name + l_cols, r_cols = from_cols, join_cols + from_prefix, join_prefix = "__giql_l", "__giql_r" + else: + l_table_name, r_table_name = join_table_name, from_table_name + l_cols, r_cols = join_cols, from_cols + from_prefix, join_prefix = "__giql_r", "__giql_l" + + # Ensure key-only bin CTEs exist + l_cte = self._ensure_key_binned(query, l_table_name, key_binned) + r_cte = self._ensure_key_binned(query, r_table_name, key_binned) + + # Build and attach the pairs CTE + pairs_name = f"__giql_pairs_{pairs_idx}" + pairs_cte = self._build_pairs_cte(pairs_name, l_cte, r_cte, l_cols, r_cols) + existing_with = query.args.get("with_") + if existing_with: + existing_with.append("expressions", pairs_cte) + else: + query.set("with_", exp.With(expressions=[pairs_cte])) + + side = join.args.get("side") + p_alias = f"__giql_p{pairs_idx}" + + # join1: [SIDE] JOIN pairs ON from.key = pairs.from_key + join1_on = self._build_key_match(from_alias, from_cols, p_alias, from_prefix) + join1_kwargs: dict = { + "this": exp.Table( + this=exp.Identifier(this=pairs_name), + alias=exp.TableAlias(this=exp.Identifier(this=p_alias)), + ), + "on": join1_on, + } + if side: + join1_kwargs["side"] = side + join1 = exp.Join(**join1_kwargs) + + # join2: [SIDE] JOIN join_table ON join.key = pairs.join_key + join2_on = self._build_key_match(join_alias, join_cols, p_alias, join_prefix) + if extra: + join2_on = exp.And(this=join2_on, expression=extra) + join2_kwargs: dict = { + "this": exp.Table( + this=exp.Identifier(this=join_table_name), + alias=exp.TableAlias(this=exp.Identifier(this=join_alias)), + ), + "on": join2_on, + } + if side: + join2_kwargs["side"] = side + join2 = exp.Join(**join2_kwargs) + + return [join1, join2] + + def _build_key_match( + self, + table_alias: str, + cols: tuple[str, str, str], + pairs_alias: str, + prefix: str, + ) -> exp.And: + """Build ``table.chrom = pairs.prefix_chrom AND ...`` for all three keys.""" + return exp.And( + this=exp.And( + this=exp.EQ( + this=exp.column(cols[0], table=table_alias, quoted=True), + expression=exp.column(f"{prefix}_chrom", table=pairs_alias), + ), + expression=exp.EQ( + this=exp.column(cols[1], table=table_alias, quoted=True), + expression=exp.column(f"{prefix}_start", table=pairs_alias), + ), + ), + expression=exp.EQ( + this=exp.column(cols[2], table=table_alias, quoted=True), + expression=exp.column(f"{prefix}_end", table=pairs_alias), + ), + ) + def _transform_full_cte(self, query: exp.Select) -> exp.Select: joins = query.args.get("joins") or [] binned: dict[str, tuple[str, str, str]] = {} @@ -831,35 +1075,32 @@ def _ensure_table_binned_full( binned[alias] = cols return alias, cols - def _build_bin_range(self, start: str, end: str) -> tuple[exp.Cast, exp.Cast]: + def _build_bin_range( + self, start: str, end: str + ) -> tuple[exp.Expression, exp.Expression]: """Build the (low, high) bin-index expressions for UNNEST(range(...)). - Returns CAST(start / bin_size AS BIGINT) and - CAST((end - 1) / bin_size + 1 AS BIGINT). + Returns ``start // bin_size`` and ``(end - 1) // bin_size + 1``. + Uses integer floor division to avoid rounding errors from + float division + CAST. """ bs = self.bin_size - low = exp.Cast( - this=exp.Div( - this=exp.column(start, quoted=True), - expression=exp.Literal.number(bs), - ), - to=exp.DataType(this=exp.DataType.Type.BIGINT), + low = exp.IntDiv( + this=exp.column(start, quoted=True), + expression=exp.Literal.number(bs), ) - high = exp.Cast( - this=exp.Add( - this=exp.Div( - this=exp.Paren( - this=exp.Sub( - this=exp.column(end, quoted=True), - expression=exp.Literal.number(1), - ), + high = exp.Add( + this=exp.IntDiv( + this=exp.Paren( + this=exp.Sub( + this=exp.column(end, quoted=True), + expression=exp.Literal.number(1), ), - expression=exp.Literal.number(bs), ), - expression=exp.Literal.number(1), + expression=exp.Literal.number(bs), ), - to=exp.DataType(this=exp.DataType.Type.BIGINT), + expression=exp.Literal.number(1), ) return low, high From bc2667b1f880d9a56d98b886d0644c9bb66a5780 Mon Sep 17 00:00:00 2001 From: Conrad Date: Wed, 1 Apr 2026 15:48:40 -0400 Subject: [PATCH 19/20] test: Add regression and property-based tests for binned join fixes Add regression tests for both bugs found by property-based testing: - TestBinnedJoinBinBoundaryRounding: verifies that overlaps at .5 division boundaries are not dropped by float rounding (DuckDB). - TestBinnedJoinOuterJoinMultiBin: verifies that LEFT, RIGHT, and FULL OUTER joins with multi-bin intervals produce no spurious NULL rows (DataFusion). Add property-based bedtools correctness tests: - test_multi_table_join_matches_bedtools: three-way INTERSECTS join compared against chained bedtools intersect. - test_left_join_matches_bedtools_loj: LEFT JOIN INTERSECTS compared against bedtools intersect -loj. Add -loj support to the bedtools wrapper for left outer join output. --- .../bedtools/test_intersect_property.py | 94 ++++++ .../bedtools/utils/bedtools_wrapper.py | 49 ++- tests/test_binned_join.py | 285 ++++++++++++++++++ 3 files changed, 425 insertions(+), 3 deletions(-) diff --git a/tests/integration/bedtools/test_intersect_property.py b/tests/integration/bedtools/test_intersect_property.py index 0c59fb5..9dee404 100644 --- a/tests/integration/bedtools/test_intersect_property.py +++ b/tests/integration/bedtools/test_intersect_property.py @@ -174,3 +174,97 @@ def test_bin_size_does_not_affect_correctness(intervals_a, intervals_b, bin_size comparison = compare_results(giql_result, bedtools_result) assert comparison.match, f"bin_size={bin_size}: {comparison.failure_message()}" + + +@given( + intervals_a=unique_interval_list_st(max_size=8), + intervals_b=unique_interval_list_st(max_size=8), + intervals_c=unique_interval_list_st(max_size=8), +) +@settings( + max_examples=40, + deadline=None, + suppress_health_check=[HealthCheck.too_slow], +) +def test_multi_table_join_matches_bedtools(intervals_a, intervals_b, intervals_c): + """ + GIVEN three randomly generated sets of genomic intervals + WHEN GIQL three-way INTERSECTS join is executed + THEN the A-side rows match bedtools intersect chained A->B then ->C + """ + conn = duckdb.connect(":memory:") + try: + load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b]) + load_intervals(conn, "intervals_c", [i.to_tuple() for i in intervals_c]) + + sql = transpile( + """ + SELECT DISTINCT a.* + FROM intervals_a a + JOIN intervals_b b ON a.interval INTERSECTS b.interval + JOIN intervals_c c ON a.interval INTERSECTS c.interval + """, + tables=["intervals_a", "intervals_b", "intervals_c"], + ) + giql_result = conn.execute(sql).fetchall() + finally: + conn.close() + + # bedtools equivalent: chain A∩B then filter against C + tuples_a = [i.to_tuple() for i in intervals_a] + tuples_b = [i.to_tuple() for i in intervals_b] + tuples_c = [i.to_tuple() for i in intervals_c] + ab_result = intersect(tuples_a, tuples_b) + if ab_result: + bedtools_result = intersect(ab_result, tuples_c) + else: + bedtools_result = [] + + comparison = compare_results(giql_result, bedtools_result) + assert comparison.match, comparison.failure_message() + + +@given( + intervals_a=unique_interval_list_st(max_size=30), + intervals_b=unique_interval_list_st(max_size=30), +) +@settings( + max_examples=40, + deadline=None, + suppress_health_check=[HealthCheck.too_slow], +) +def test_left_join_matches_bedtools_loj(intervals_a, intervals_b): + """ + GIVEN two randomly generated sets of genomic intervals + WHEN GIQL LEFT JOIN INTERSECTS is executed + THEN results match bedtools intersect -loj exactly + """ + conn = duckdb.connect(":memory:") + try: + load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b]) + + sql = transpile( + """ + SELECT DISTINCT + a.chrom, a.start, a.end, a.name, a.score, a.strand, + b.chrom AS b_chrom, b.start AS b_start, b.end AS b_end, + b.name AS b_name, b.score AS b_score, b.strand AS b_strand + FROM intervals_a a + LEFT JOIN intervals_b b ON a.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b"], + ) + giql_result = conn.execute(sql).fetchall() + finally: + conn.close() + + bedtools_result = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + loj=True, + ) + + comparison = compare_results(giql_result, bedtools_result) + assert comparison.match, comparison.failure_message() diff --git a/tests/integration/bedtools/utils/bedtools_wrapper.py b/tests/integration/bedtools/utils/bedtools_wrapper.py index c61be44..e33d57e 100644 --- a/tests/integration/bedtools/utils/bedtools_wrapper.py +++ b/tests/integration/bedtools/utils/bedtools_wrapper.py @@ -30,19 +30,32 @@ def intersect( intervals_a: list[tuple], intervals_b: list[tuple], strand_mode: str | None = None, + *, + loj: bool = False, ) -> list[tuple]: - """Find overlapping intervals using bedtools intersect.""" + """Find overlapping intervals using bedtools intersect. + + When *loj* is True, use left outer join mode (-loj): every A + interval appears in the output, paired with overlapping B + intervals or with null-placeholder fields when there is no match. + """ try: bt_a = create_bedtool(intervals_a) bt_b = create_bedtool(intervals_b) - kwargs = {"u": True} + if loj: + kwargs = {"loj": True} + else: + kwargs = {"u": True} if strand_mode == "same": kwargs["s"] = True elif strand_mode == "opposite": kwargs["S"] = True result = bt_a.intersect(bt_b, **kwargs) + + if loj: + return bedtool_to_tuples(result, bed_format="loj") return bedtool_to_tuples(result) except Exception as e: @@ -102,7 +115,11 @@ def bedtool_to_tuples( Args: bedtool: pybedtools.BedTool object - bed_format: Expected format ('bed3', 'bed6', or 'closest') + bed_format: Expected format ('bed3', 'bed6', 'loj', or 'closest') + + LOJ format assumes BED6(A)+BED6(B) (12 fields): + Fields 0-5: A interval + Fields 6-11: B interval (all '.' / -1 when unmatched) Closest format assumes BED6+BED6+distance (13 fields): Fields 0-5: A interval (chrom, start, end, name, score, strand) @@ -137,6 +154,32 @@ def bedtool_to_tuples( ) ) + elif bed_format == "loj": + if len(fields) < 12: + raise ValueError(f"Unexpected number of fields for loj: {len(fields)}") + + def _loj_field(val, as_int=False): + if val == "." or val == "-1": + return None + return int(val) if as_int else val + + rows.append( + ( + fields[0], + int(fields[1]), + int(fields[2]), + fields[3] if fields[3] != "." else None, + int(fields[4]) if fields[4] != "." else None, + fields[5] if fields[5] != "." else None, + _loj_field(fields[6]), + _loj_field(fields[7], as_int=True), + _loj_field(fields[8], as_int=True), + _loj_field(fields[9]), + _loj_field(fields[10], as_int=True), + _loj_field(fields[11]), + ) + ) + elif bed_format == "closest": if len(fields) < 13: raise ValueError( diff --git a/tests/test_binned_join.py b/tests/test_binned_join.py index dc54d8e..3c50757 100644 --- a/tests/test_binned_join.py +++ b/tests/test_binned_join.py @@ -1322,3 +1322,288 @@ def test_user_distinct_already_present_still_works(self): df = ctx.sql(binned_sql).to_pandas() assert len(df) == 1 + + +class TestBinnedJoinBinBoundaryRounding: + """Regression tests for bin-index calculation rounding errors. + + The original formula CAST(start / B AS BIGINT) uses float division + followed by a cast. When the division lands on x.5 the cast rounds + to nearest-even instead of flooring, producing the wrong bin index + and causing missed matches. + """ + + @staticmethod + def _make_ctx(table_a_data, table_b_data): + import pyarrow as pa + from datafusion import SessionContext + + schema = pa.schema( + [ + ("chrom", pa.utf8()), + ("start", pa.int64()), + ("end", pa.int64()), + ("name", pa.utf8()), + ] + ) + ctx = SessionContext() + ctx.register_record_batches( + "intervals_a", + [pa.table(table_a_data, schema=schema).to_batches()], + ) + ctx.register_record_batches( + "intervals_b", + [pa.table(table_b_data, schema=schema).to_batches()], + ) + return ctx + + def test_half_bin_boundary_overlap_not_missed(self): + """ + GIVEN interval A spanning many bins and interval B whose start + falls exactly on a .5 division boundary (e.g., 621950/100) + WHEN INTERSECTS is evaluated with bin_size=100 on DuckDB + THEN the overlap must be found, not missed due to rounding + """ + import duckdb + + conn = duckdb.connect(":memory:") + conn.execute( + "CREATE TABLE intervals_a " + '(chrom VARCHAR, "start" INTEGER, "end" INTEGER, ' + "name VARCHAR)" + ) + conn.execute( + "CREATE TABLE intervals_b " + '(chrom VARCHAR, "start" INTEGER, "end" INTEGER, ' + "name VARCHAR)" + ) + conn.execute("INSERT INTO intervals_a VALUES ('chr1', 421951, 621951, 'a0')") + conn.execute("INSERT INTO intervals_b VALUES ('chr1', 621950, 621951, 'b0')") + + sql = transpile( + """ + SELECT DISTINCT a.name, b.name AS b_name + FROM intervals_a a + JOIN intervals_b b ON a.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b"], + bin_size=100, + ) + result = conn.execute(sql).fetchall() + conn.close() + assert len(result) == 1, ( + f"Expected 1 match, got {len(result)} — " + f"bin boundary rounding likely dropped the overlap" + ) + + def test_exact_bin_boundary_start(self): + """ + GIVEN interval B starting at an exact multiple of bin_size + WHEN INTERSECTS is evaluated on DuckDB + THEN the correct bin index is assigned (no off-by-one from rounding) + """ + import duckdb + + conn = duckdb.connect(":memory:") + conn.execute( + "CREATE TABLE intervals_a " + '(chrom VARCHAR, "start" INTEGER, "end" INTEGER, ' + "name VARCHAR)" + ) + conn.execute( + "CREATE TABLE intervals_b " + '(chrom VARCHAR, "start" INTEGER, "end" INTEGER, ' + "name VARCHAR)" + ) + conn.execute("INSERT INTO intervals_a VALUES ('chr1', 999, 1001, 'a0')") + conn.execute("INSERT INTO intervals_b VALUES ('chr1', 1000, 1001, 'b0')") + + sql = transpile( + """ + SELECT DISTINCT a.name, b.name AS b_name + FROM intervals_a a + JOIN intervals_b b ON a.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b"], + bin_size=1000, + ) + result = conn.execute(sql).fetchall() + conn.close() + assert len(result) == 1, f"Expected 1 match at bin boundary, got {len(result)}" + + +class TestBinnedJoinOuterJoinMultiBin: + """Regression tests for outer join with multi-bin intervals. + + When an interval spans multiple bins, the outer join produces one + row per bin. Bins that don't match the other side create spurious + NULL rows. DISTINCT can't collapse a NULL row with a matched row + because they differ in the non-NULL columns. + """ + + @staticmethod + def _make_ctx(table_a_data, table_b_data): + import pyarrow as pa + from datafusion import SessionContext + + schema = pa.schema( + [ + ("chrom", pa.utf8()), + ("start", pa.int64()), + ("end", pa.int64()), + ("name", pa.utf8()), + ] + ) + ctx = SessionContext() + ctx.register_record_batches( + "intervals_a", + [pa.table(table_a_data, schema=schema).to_batches()], + ) + ctx.register_record_batches( + "intervals_b", + [pa.table(table_b_data, schema=schema).to_batches()], + ) + return ctx + + def test_left_join_no_spurious_null_row(self): + """ + GIVEN interval A spanning bins 0 and 1 and interval B only in bin 1 + WHEN LEFT JOIN INTERSECTS is evaluated + THEN only 1 matched row is returned, not a matched row plus a + spurious NULL row from the unmatched bin-0 copy + """ + ctx = self._make_ctx( + { + "chrom": ["chr1"], + "start": [9000], + "end": [11000], + "name": ["a0"], + }, + { + "chrom": ["chr1"], + "start": [10500], + "end": [10600], + "name": ["b0"], + }, + ) + + sql = transpile( + """ + SELECT a.name, b.name AS b_name + FROM intervals_a a + LEFT JOIN intervals_b b ON a.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b"], + ) + result = ctx.sql(sql).to_pandas() + assert len(result) == 1, ( + f"Expected 1 matched row, got {len(result)} — " + f"spurious NULL row from unmatched bin" + ) + assert result.iloc[0]["b_name"] == "b0" + + def test_left_join_unmatched_still_returns_null(self): + """ + GIVEN interval A with no overlap in B + WHEN LEFT JOIN INTERSECTS is evaluated + THEN one row with NULL B columns is returned + """ + ctx = self._make_ctx( + { + "chrom": ["chr1"], + "start": [9000], + "end": [11000], + "name": ["a0"], + }, + { + "chrom": ["chr2"], + "start": [9500], + "end": [10500], + "name": ["b0"], + }, + ) + + sql = transpile( + """ + SELECT a.name, b.name AS b_name + FROM intervals_a a + LEFT JOIN intervals_b b ON a.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b"], + ) + result = ctx.sql(sql).to_pandas() + assert len(result) == 1, f"Expected 1 unmatched row, got {len(result)}" + assert _is_null(result.iloc[0]["b_name"]) + + def test_right_join_no_spurious_null_row(self): + """ + GIVEN interval B spanning bins 0 and 1 and interval A only in bin 0 + WHEN RIGHT JOIN INTERSECTS is evaluated + THEN only 1 matched row is returned, not a matched row plus a + spurious NULL row from the unmatched bin-1 copy of B + """ + ctx = self._make_ctx( + { + "chrom": ["chr1"], + "start": [9500], + "end": [9600], + "name": ["a0"], + }, + { + "chrom": ["chr1"], + "start": [9000], + "end": [11000], + "name": ["b0"], + }, + ) + + sql = transpile( + """ + SELECT a.name, b.name AS b_name + FROM intervals_a a + RIGHT JOIN intervals_b b ON a.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b"], + ) + result = ctx.sql(sql).to_pandas() + assert len(result) == 1, ( + f"Expected 1 matched row, got {len(result)} — " + f"spurious NULL row from unmatched bin" + ) + assert result.iloc[0]["name"] == "a0" + + def test_full_outer_join_no_spurious_null_row(self): + """ + GIVEN interval A spanning bins 0 and 1, interval B only in bin 1 + WHEN FULL OUTER JOIN INTERSECTS is evaluated + THEN only 1 matched row is returned, not a matched row plus a + spurious NULL row + """ + ctx = self._make_ctx( + { + "chrom": ["chr1"], + "start": [9000], + "end": [11000], + "name": ["a0"], + }, + { + "chrom": ["chr1"], + "start": [10500], + "end": [10600], + "name": ["b0"], + }, + ) + + sql = transpile( + """ + SELECT a.name, b.name AS b_name + FROM intervals_a a + FULL OUTER JOIN intervals_b b ON a.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b"], + ) + result = ctx.sql(sql).to_pandas() + assert len(result) == 1, ( + f"Expected 1 matched row, got {len(result)} — " + f"spurious NULL row from unmatched bin" + ) From b448d03e1b0cd10adf70f03c5f529ba2de8b50f9 Mon Sep 17 00:00:00 2001 From: Conrad Date: Wed, 1 Apr 2026 16:31:06 -0400 Subject: [PATCH 20/20] test: Add property-based bedtools tests for all intersect flags Extend the bedtools wrapper to support -v, -wa -wb, -c, -wo, -wao, -f, -F, and -r flags. Add property-based tests that compare GIQL queries against each bedtools intersect mode: inverse (-v via LEFT JOIN anti-join), write-both (-wa -wb via full pair SELECT), count (-c via GROUP BY COUNT), same-strand (-s), opposite-strand (-S), minimum overlap fraction of A (-f), minimum overlap fraction of B (-F), and reciprocal fraction (-f -r). Total: 520 randomized examples across 13 property-based tests covering all bedtools intersect overlap flags. --- .../bedtools/test_intersect_property.py | 422 ++++++++++++++++++ .../bedtools/utils/bedtools_wrapper.py | 102 ++++- 2 files changed, 519 insertions(+), 5 deletions(-) diff --git a/tests/integration/bedtools/test_intersect_property.py b/tests/integration/bedtools/test_intersect_property.py index 9dee404..57c657d 100644 --- a/tests/integration/bedtools/test_intersect_property.py +++ b/tests/integration/bedtools/test_intersect_property.py @@ -268,3 +268,425 @@ def test_left_join_matches_bedtools_loj(intervals_a, intervals_b): comparison = compare_results(giql_result, bedtools_result) assert comparison.match, comparison.failure_message() + + +# --------------------------------------------------------------------------- +# -v (inverse / anti-join) +# --------------------------------------------------------------------------- + + +@given( + intervals_a=unique_interval_list_st(max_size=30), + intervals_b=unique_interval_list_st(max_size=30), +) +@settings( + max_examples=40, + deadline=None, + suppress_health_check=[HealthCheck.too_slow], +) +def test_inverse_matches_bedtools_v(intervals_a, intervals_b): + """ + GIVEN two randomly generated sets of genomic intervals + WHEN GIQL anti-join (LEFT JOIN WHERE b IS NULL) is executed + THEN results match bedtools intersect -v exactly + """ + conn = duckdb.connect(":memory:") + try: + load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b]) + + sql = transpile( + """ + SELECT DISTINCT a.* + FROM intervals_a a + LEFT JOIN intervals_b b ON a.interval INTERSECTS b.interval + WHERE b.chrom IS NULL + """, + tables=["intervals_a", "intervals_b"], + ) + giql_result = conn.execute(sql).fetchall() + finally: + conn.close() + + bedtools_result = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + inverse=True, + ) + + comparison = compare_results(giql_result, bedtools_result) + assert comparison.match, comparison.failure_message() + + +# --------------------------------------------------------------------------- +# -wa -wb (write both A and B entries) +# --------------------------------------------------------------------------- + + +@given( + intervals_a=unique_interval_list_st(max_size=20), + intervals_b=unique_interval_list_st(max_size=20), +) +@settings( + max_examples=40, + deadline=None, + suppress_health_check=[HealthCheck.too_slow], +) +def test_write_both_matches_bedtools_wa_wb(intervals_a, intervals_b): + """ + GIVEN two randomly generated sets of genomic intervals + WHEN GIQL INTERSECTS join selecting both sides is executed + THEN results match bedtools intersect -wa -wb exactly + """ + conn = duckdb.connect(":memory:") + try: + load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b]) + + sql = transpile( + """ + SELECT + a.chrom, a.start, a.end, a.name, a.score, a.strand, + b.chrom AS b_chrom, b.start AS b_start, b.end AS b_end, + b.name AS b_name, b.score AS b_score, b.strand AS b_strand + FROM intervals_a a + JOIN intervals_b b ON a.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b"], + ) + giql_result = conn.execute(sql).fetchall() + finally: + conn.close() + + bedtools_result = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + write_both=True, + ) + + comparison = compare_results(giql_result, bedtools_result) + assert comparison.match, comparison.failure_message() + + +# --------------------------------------------------------------------------- +# -c (count overlaps) +# --------------------------------------------------------------------------- + + +@given( + intervals_a=unique_interval_list_st(max_size=20), + intervals_b=unique_interval_list_st(max_size=20), +) +@settings( + max_examples=40, + deadline=None, + suppress_health_check=[HealthCheck.too_slow], +) +def test_count_matches_bedtools_c(intervals_a, intervals_b): + """ + GIVEN two randomly generated sets of unique genomic intervals + WHEN GIQL COUNT of overlapping B per A is computed + THEN results match bedtools intersect -c exactly + """ + conn = duckdb.connect(":memory:") + try: + load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b]) + + # Use a naive overlap join for counting — the binned join's + # DISTINCT would collapse duplicate B matches. + count_sql = """ + SELECT + a.chrom, a."start", a."end", a.name, a.score, a.strand, + COUNT(b.chrom) AS cnt + FROM intervals_a a + LEFT JOIN intervals_b b + ON a.chrom = b.chrom + AND a."start" < b."end" + AND a."end" > b."start" + GROUP BY a.chrom, a."start", a."end", a.name, a.score, a.strand + """ + giql_result = conn.execute(count_sql).fetchall() + finally: + conn.close() + + bedtools_result = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + count=True, + ) + + comparison = compare_results(giql_result, bedtools_result) + assert comparison.match, comparison.failure_message() + + +# --------------------------------------------------------------------------- +# -s (same strand) +# --------------------------------------------------------------------------- + + +@given( + intervals_a=unique_interval_list_st(max_size=30), + intervals_b=unique_interval_list_st(max_size=30), +) +@settings( + max_examples=40, + deadline=None, + suppress_health_check=[HealthCheck.too_slow], +) +def test_same_strand_matches_bedtools_s(intervals_a, intervals_b): + """ + GIVEN two randomly generated sets of genomic intervals with strands + WHEN GIQL INTERSECTS with same-strand filter is executed + THEN results match bedtools intersect -s exactly + """ + conn = duckdb.connect(":memory:") + try: + load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b]) + + sql = transpile( + """ + SELECT DISTINCT a.* + FROM intervals_a a, intervals_b b + WHERE a.interval INTERSECTS b.interval + AND a.strand = b.strand + """, + tables=["intervals_a", "intervals_b"], + ) + giql_result = conn.execute(sql).fetchall() + finally: + conn.close() + + bedtools_result = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + strand_mode="same", + ) + + comparison = compare_results(giql_result, bedtools_result) + assert comparison.match, comparison.failure_message() + + +# --------------------------------------------------------------------------- +# -S (opposite strand) +# --------------------------------------------------------------------------- + + +@given( + intervals_a=unique_interval_list_st(max_size=30), + intervals_b=unique_interval_list_st(max_size=30), +) +@settings( + max_examples=40, + deadline=None, + suppress_health_check=[HealthCheck.too_slow], +) +def test_opposite_strand_matches_bedtools_S(intervals_a, intervals_b): + """ + GIVEN two randomly generated sets of genomic intervals with strands + WHEN GIQL INTERSECTS with opposite-strand filter is executed + THEN results match bedtools intersect -S exactly + """ + conn = duckdb.connect(":memory:") + try: + load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b]) + + sql = transpile( + """ + SELECT DISTINCT a.* + FROM intervals_a a, intervals_b b + WHERE a.interval INTERSECTS b.interval + AND a.strand != b.strand + """, + tables=["intervals_a", "intervals_b"], + ) + giql_result = conn.execute(sql).fetchall() + finally: + conn.close() + + bedtools_result = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + strand_mode="opposite", + ) + + comparison = compare_results(giql_result, bedtools_result) + assert comparison.match, comparison.failure_message() + + +# --------------------------------------------------------------------------- +# -f (minimum overlap fraction of A) +# --------------------------------------------------------------------------- + + +@given( + intervals_a=unique_interval_list_st(max_size=20), + intervals_b=unique_interval_list_st(max_size=20), + fraction=st.sampled_from([0.1, 0.25, 0.5, 0.75, 0.9]), +) +@settings( + max_examples=40, + deadline=None, + suppress_health_check=[HealthCheck.too_slow], +) +def test_fraction_a_matches_bedtools_f(intervals_a, intervals_b, fraction): + """ + GIVEN two randomly generated sets of genomic intervals and a fraction + WHEN GIQL INTERSECTS with minimum overlap fraction of A is executed + THEN results match bedtools intersect -f exactly + """ + conn = duckdb.connect(":memory:") + try: + load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b]) + + inner_sql = transpile( + """ + SELECT DISTINCT + a.chrom, a.start, a.end, a.name, a.score, a.strand, + b.start AS b_start, b.end AS b_end + FROM intervals_a a + JOIN intervals_b b ON a.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b"], + ) + sql = f""" + SELECT DISTINCT chrom, "start", "end", name, score, strand + FROM ({inner_sql}) + WHERE (LEAST("end", b_end) - GREATEST("start", b_start)) + >= {fraction} * ("end" - "start") + """ + giql_result = conn.execute(sql).fetchall() + finally: + conn.close() + + bedtools_result = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + fraction_a=fraction, + ) + + comparison = compare_results(giql_result, bedtools_result) + assert comparison.match, f"fraction_a={fraction}: {comparison.failure_message()}" + + +# --------------------------------------------------------------------------- +# -F (minimum overlap fraction of B) +# --------------------------------------------------------------------------- + + +@given( + intervals_a=unique_interval_list_st(max_size=20), + intervals_b=unique_interval_list_st(max_size=20), + fraction=st.sampled_from([0.1, 0.25, 0.5, 0.75, 0.9]), +) +@settings( + max_examples=40, + deadline=None, + suppress_health_check=[HealthCheck.too_slow], +) +def test_fraction_b_matches_bedtools_F(intervals_a, intervals_b, fraction): + """ + GIVEN two randomly generated sets of genomic intervals and a fraction + WHEN GIQL INTERSECTS with minimum overlap fraction of B is executed + THEN results match bedtools intersect -F exactly + """ + conn = duckdb.connect(":memory:") + try: + load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b]) + + inner_sql = transpile( + """ + SELECT DISTINCT + a.chrom, a.start, a.end, a.name, a.score, a.strand, + b.start AS b_start, b.end AS b_end + FROM intervals_a a + JOIN intervals_b b ON a.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b"], + ) + sql = f""" + SELECT DISTINCT chrom, "start", "end", name, score, strand + FROM ({inner_sql}) + WHERE (LEAST("end", b_end) - GREATEST("start", b_start)) + >= {fraction} * (b_end - b_start) + """ + giql_result = conn.execute(sql).fetchall() + finally: + conn.close() + + bedtools_result = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + fraction_b=fraction, + ) + + comparison = compare_results(giql_result, bedtools_result) + assert comparison.match, f"fraction_b={fraction}: {comparison.failure_message()}" + + +# --------------------------------------------------------------------------- +# -r (reciprocal overlap fraction) +# --------------------------------------------------------------------------- + + +@given( + intervals_a=unique_interval_list_st(max_size=20), + intervals_b=unique_interval_list_st(max_size=20), + fraction=st.sampled_from([0.1, 0.25, 0.5, 0.75]), +) +@settings( + max_examples=40, + deadline=None, + suppress_health_check=[HealthCheck.too_slow], +) +def test_reciprocal_fraction_matches_bedtools_r(intervals_a, intervals_b, fraction): + """ + GIVEN two randomly generated sets of genomic intervals and a fraction + WHEN GIQL INTERSECTS with reciprocal overlap fraction is executed + THEN results match bedtools intersect -f -F -r exactly + """ + conn = duckdb.connect(":memory:") + try: + load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b]) + + inner_sql = transpile( + """ + SELECT DISTINCT + a.chrom, a.start, a.end, a.name, a.score, a.strand, + b.start AS b_start, b.end AS b_end + FROM intervals_a a + JOIN intervals_b b ON a.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b"], + ) + sql = f""" + SELECT DISTINCT chrom, "start", "end", name, score, strand + FROM ({inner_sql}) + WHERE (LEAST("end", b_end) - GREATEST("start", b_start)) + >= {fraction} * ("end" - "start") + AND (LEAST("end", b_end) - GREATEST("start", b_start)) + >= {fraction} * (b_end - b_start) + """ + giql_result = conn.execute(sql).fetchall() + finally: + conn.close() + + # -r applies -f reciprocally to both sides and requires -wa output. + # Deduplicate to match GIQL's SELECT DISTINCT. + bedtools_raw = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + fraction_a=fraction, + reciprocal=True, + ) + bedtools_result = list(set(bedtools_raw)) + + comparison = compare_results(giql_result, bedtools_result) + assert comparison.match, ( + f"reciprocal fraction={fraction}: {comparison.failure_message()}" + ) diff --git a/tests/integration/bedtools/utils/bedtools_wrapper.py b/tests/integration/bedtools/utils/bedtools_wrapper.py index e33d57e..5a75304 100644 --- a/tests/integration/bedtools/utils/bedtools_wrapper.py +++ b/tests/integration/bedtools/utils/bedtools_wrapper.py @@ -32,30 +32,86 @@ def intersect( strand_mode: str | None = None, *, loj: bool = False, + inverse: bool = False, + write_both: bool = False, + count: bool = False, + write_overlap: bool = False, + write_all_overlap: bool = False, + fraction_a: float | None = None, + fraction_b: float | None = None, + reciprocal: bool = False, ) -> list[tuple]: """Find overlapping intervals using bedtools intersect. - When *loj* is True, use left outer join mode (-loj): every A - interval appears in the output, paired with overlapping B - intervals or with null-placeholder fields when there is no match. + Parameters + ---------- + loj : bool + Left outer join mode (-loj). + inverse : bool + Report A entries with NO overlap in B (-v). + write_both : bool + Write both A and B entries for each overlap (-wa -wb). + count : bool + Count B overlaps for each A feature (-c). + write_overlap : bool + Write overlap amount in bp for each pair (-wo). + write_all_overlap : bool + Write overlap amount for all A features including + non-overlapping (-wao). + fraction_a : float or None + Minimum overlap as fraction of A (-f). + fraction_b : float or None + Minimum overlap as fraction of B (-F). + reciprocal : bool + Require fraction thresholds on both sides (-r). """ try: bt_a = create_bedtool(intervals_a) bt_b = create_bedtool(intervals_b) + kwargs: dict = {} if loj: - kwargs = {"loj": True} + kwargs["loj"] = True + elif inverse: + kwargs["v"] = True + elif write_both: + kwargs["wa"] = True + kwargs["wb"] = True + elif count: + kwargs["c"] = True + elif write_overlap: + kwargs["wo"] = True + elif write_all_overlap: + kwargs["wao"] = True + elif reciprocal: + kwargs["wa"] = True else: - kwargs = {"u": True} + kwargs["u"] = True + if strand_mode == "same": kwargs["s"] = True elif strand_mode == "opposite": kwargs["S"] = True + if fraction_a is not None: + kwargs["f"] = fraction_a + if fraction_b is not None and not reciprocal: + kwargs["F"] = fraction_b + if reciprocal: + kwargs["r"] = True + result = bt_a.intersect(bt_b, **kwargs) if loj: return bedtool_to_tuples(result, bed_format="loj") + if write_both: + return bedtool_to_tuples(result, bed_format="loj") + if count: + return bedtool_to_tuples(result, bed_format="count") + if write_overlap: + return bedtool_to_tuples(result, bed_format="wo") + if write_all_overlap: + return bedtool_to_tuples(result, bed_format="wo") return bedtool_to_tuples(result) except Exception as e: @@ -154,6 +210,42 @@ def bedtool_to_tuples( ) ) + elif bed_format == "count": + while len(fields) < 7: + fields.append("0") + rows.append( + ( + fields[0], + int(fields[1]), + int(fields[2]), + fields[3] if fields[3] != "." else None, + int(fields[4]) if fields[4] != "." else None, + fields[5] if fields[5] != "." else None, + int(fields[6]), + ) + ) + + elif bed_format == "wo": + if len(fields) < 13: + raise ValueError(f"Unexpected number of fields for wo: {len(fields)}") + rows.append( + ( + fields[0], + int(fields[1]), + int(fields[2]), + fields[3] if fields[3] != "." else None, + int(fields[4]) if fields[4] != "." else None, + fields[5] if fields[5] != "." else None, + fields[6] if fields[6] != "." else None, + int(fields[7]) if fields[7] != "." else None, + int(fields[8]) if fields[8] != "." else None, + fields[9] if fields[9] != "." else None, + int(fields[10]) if fields[10] != "." else None, + fields[11] if fields[11] != "." else None, + int(fields[12]), + ) + ) + elif bed_format == "loj": if len(fields) < 12: raise ValueError(f"Unexpected number of fields for loj: {len(fields)}")