diff --git a/docs/dialect/spatial-operators.rst b/docs/dialect/spatial-operators.rst index 6bf4433..1a8b3cd 100644 --- a/docs/dialect/spatial-operators.rst +++ b/docs/dialect/spatial-operators.rst @@ -99,6 +99,24 @@ Find all variants, with gene information where available: FROM variants v LEFT JOIN genes g ON v.interval INTERSECTS g.interval +Deduplication Behavior +~~~~~~~~~~~~~~~~~~~~~~ + +Column-to-column ``INTERSECTS`` joins use a binned equi-join strategy internally: each interval is assigned to one or more fixed-width bins, and the join is performed on ``(chrom, bin)`` pairs. Because an interval that spans a bin boundary belongs to more than one bin, a single source row can match the same result row more than once. GIQL adds ``SELECT DISTINCT`` automatically to remove these duplicate rows. + +This deduplication is usually transparent, but it has one observable side effect: ``DISTINCT`` operates on the entire set of selected columns, so rows that are genuinely identical across every selected column will also be collapsed into one. This matters when a table contains duplicate source records with no distinguishing column. + +To prevent unintended deduplication, include any column that makes rows distinguishable — such as a primary key, name, or score — in the ``SELECT`` list: + +.. code-block:: sql + + -- score distinguishes otherwise-identical rows + SELECT v.chrom, v.start, v.end, v.score, g.name + FROM variants v + INNER JOIN genes g ON v.interval INTERSECTS g.interval + +If all columns are identical across two source rows (including any unique identifier), those rows represent the same logical record and collapsing them is correct behavior. 
+ Related Operators ~~~~~~~~~~~~~~~~~ diff --git a/pyproject.toml b/pyproject.toml index 480cf1e..647358b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ dev = [ "pytest-cov>=4.0.0", "pytest>=7.0.0", "ruff>=0.1.0", + "datafusion>=43.0.0", ] docs = [ "sphinx>=7.0", @@ -82,6 +83,7 @@ bedtools = ">=2.31.0" pybedtools = ">=0.9.0" pytest = ">=7.0.0" pytest-cov = ">=4.0.0" +datafusion = ">=43.0.0" duckdb = ">=1.4.0" pandas = ">=2.0.0" sqlglot = ">=20.0.0,<30" diff --git a/src/giql/__init__.py b/src/giql/__init__.py index 71e895d..e5df351 100644 --- a/src/giql/__init__.py +++ b/src/giql/__init__.py @@ -3,6 +3,7 @@ A SQL dialect for genomic range queries. """ +from giql.constants import DEFAULT_BIN_SIZE from giql.table import Table from giql.transpile import transpile @@ -10,6 +11,7 @@ __all__ = [ + "DEFAULT_BIN_SIZE", "Table", "transpile", ] diff --git a/src/giql/constants.py b/src/giql/constants.py index 87f8055..0daf016 100644 --- a/src/giql/constants.py +++ b/src/giql/constants.py @@ -9,3 +9,6 @@ DEFAULT_END_COL = "end" DEFAULT_STRAND_COL = "strand" DEFAULT_GENOMIC_COL = "interval" + +# Default bin size for INTERSECTS binned equi-join optimization +DEFAULT_BIN_SIZE = 10_000 diff --git a/src/giql/transformer.py b/src/giql/transformer.py index de1e70f..789e3c5 100644 --- a/src/giql/transformer.py +++ b/src/giql/transformer.py @@ -1,17 +1,22 @@ """Query transformers for GIQL operations. This module contains transformers that rewrite queries containing GIQL-specific -operations (like CLUSTER and MERGE) into equivalent SQL with CTEs. +operations (like CLUSTER, MERGE, and binned INTERSECTS joins) into equivalent +SQL with CTEs.
""" +import itertools + from sqlglot import exp +from giql.constants import DEFAULT_BIN_SIZE from giql.constants import DEFAULT_CHROM_COL from giql.constants import DEFAULT_END_COL from giql.constants import DEFAULT_START_COL from giql.constants import DEFAULT_STRAND_COL from giql.expressions import GIQLCluster from giql.expressions import GIQLMerge +from giql.expressions import Intersects from giql.table import Tables @@ -573,3 +578,897 @@ def _transform_for_merge( ) return final_query + + +class IntersectsBinnedJoinTransformer: + """Transforms column-to-column INTERSECTS into binned equi-joins. + + Handles both explicit JOIN ON and implicit cross-join (WHERE) patterns. + Two rewrite strategies are selected based on the SELECT list: + + **No wildcards** (``SELECT a.chrom, b.start, ...``) — ``__giql_bin`` + cannot appear in the output regardless of CTE content, so the simpler + 1-join full-CTE approach is used: + + WITH __giql_a_binned AS ( + SELECT *, UNNEST(range( + CAST("start" / B AS BIGINT), + CAST(("end" - 1) / B + 1 AS BIGINT) + )) AS __giql_bin FROM peaks + ), + __giql_b_binned AS (...) + SELECT DISTINCT a.chrom, b.start, ... + FROM __giql_a_binned AS a + JOIN __giql_b_binned AS b + ON a."chrom" = b."chrom" AND a.__giql_bin = b.__giql_bin + AND a."start" < b."end" AND a."end" > b."start" + + **Wildcards present** (``SELECT a.*, b.*``) — ``__giql_bin`` would leak + into ``a.*`` expansion if ``a`` aliases a full-select CTE. A key-only + bridge CTE pattern is used instead, keeping original table references: + + WITH __giql_peaks_bins AS ( + SELECT "chrom", "start", "end", + UNNEST(range(...)) AS __giql_bin FROM peaks + ), + __giql_genes_bins AS (...) 
+ SELECT DISTINCT a.*, b.* + FROM peaks a + JOIN __giql_peaks_bins __giql_c0 + ON a."chrom" = __giql_c0."chrom" AND a."start" = __giql_c0."start" + AND a."end" = __giql_c0."end" + JOIN __giql_genes_bins __giql_c1 + ON __giql_c0."chrom" = __giql_c1."chrom" + AND __giql_c0.__giql_bin = __giql_c1.__giql_bin + JOIN genes b + ON b."chrom" = __giql_c1."chrom" AND b."start" = __giql_c1."start" + AND b."end" = __giql_c1."end" + AND a."start" < b."end" AND a."end" > b."start" + + Literal-range INTERSECTS (e.g., ``WHERE interval INTERSECTS 'chr1:...'``) + are left untouched. + + SELECT DISTINCT is added to deduplicate rows produced by multi-bin + matches. This means rows that are identical across every selected + column will be collapsed — include a distinguishing column (e.g., an + id or score) to preserve duplicates that differ only in unselected + columns. The bridge path's key-match joins on ``(chrom, start, + end)`` and may fan out if multiple source rows share those values; + DISTINCT corrects for this. + """ + + def __init__(self, tables: Tables, bin_size: int | None = None): + """Initialize transformer. + + :param tables: + Table configurations for column mapping + :param bin_size: + Bin width for the equi-join rewrite. Defaults to + DEFAULT_BIN_SIZE if not specified. + """ + self.tables = tables + resolved = bin_size if bin_size is not None else DEFAULT_BIN_SIZE + if not isinstance(resolved, int) or resolved <= 0: + raise ValueError(f"bin_size must be a positive integer, got {resolved!r}") + self.bin_size = resolved + + def transform(self, query: exp.Expression) -> exp.Expression: + if not isinstance(query, exp.Select): + return query + + # Outer joins need the pairs-CTE approach: compute matching key + # pairs via an INNER binned join (correctly deduplicated), then + # outer-join the original tables through the pairs CTE. This + # avoids the bin fan-out that creates spurious NULL rows when an + # interval spans multiple bins but only matches in some of them. 
+ if self._has_outer_join_intersects(query): + return self._transform_with_pairs(query) + if self._select_has_wildcards(query): + return self._transform_bridge(query) + return self._transform_full_cte(query) + + def _select_has_wildcards(self, query: exp.Select) -> bool: + """Return True if any SELECT item is a wildcard (* or table.*).""" + for expr in query.expressions: + if isinstance(expr, exp.Star): + return True + if isinstance(expr, exp.Column) and isinstance(expr.this, exp.Star): + return True + return False + + def _has_outer_join_intersects(self, query: exp.Select) -> bool: + """Return True if any outer JOIN has an INTERSECTS predicate.""" + for join in query.args.get("joins") or []: + if join.args.get("side") and join.args.get("on"): + if self._find_column_intersects_in(join.args["on"]): + return True + return False + + def _transform_with_pairs(self, query: exp.Select) -> exp.Select: + """Transform using a pairs CTE for correct outer join semantics. + + Computes matching (left_key, right_key) pairs via an INNER + binned join with DISTINCT, then outer-joins the original tables + through this pairs CTE. This avoids bin fan-out on the + preserved side of the outer join. 
+ """ + joins = query.args.get("joins") or [] + key_binned: dict[str, str] = {} + pairs_idx = 0 + new_joins: list[exp.Join] = [] + rewrote_any = False + + for join in joins: + on = join.args.get("on") + if on: + intersects = self._find_column_intersects_in(on) + if intersects: + extra = self._extract_non_intersects(on, intersects) + replacement = self._build_pairs_replacement_joins( + query, join, intersects, extra, key_binned, pairs_idx + ) + new_joins.extend(replacement) + pairs_idx += 1 + rewrote_any = True + continue + new_joins.append(join) + + where = query.args.get("where") + if where: + intersects = self._find_column_intersects_in(where.this) + if intersects: + cross_join = self._find_cross_join_for_intersects( + query, intersects, new_joins + ) + if cross_join is not None: + new_joins.remove(cross_join) + replacement = self._build_pairs_replacement_joins( + query, + cross_join, + intersects, + None, + key_binned, + pairs_idx, + ) + new_joins.extend(replacement) + self._remove_intersects_from_where(query, intersects) + pairs_idx += 1 + rewrote_any = True + + if rewrote_any: + query.set("joins", new_joins) + query.set("distinct", exp.Distinct()) + + return query + + def _build_pairs_cte( + self, + name: str, + l_cte: str, + r_cte: str, + l_cols: tuple[str, str, str], + r_cols: tuple[str, str, str], + ) -> exp.CTE: + """Build a DISTINCT inner-join pairs CTE. + + Returns a CTE named *name* that selects the six key columns + (__giql_l_chrom, __giql_l_start, __giql_l_end, __giql_r_chrom, + __giql_r_start, __giql_r_end) from an INNER join of the two bin + CTEs on chrom, __giql_bin, and the overlap predicate. 
+ """ + l_alias = "__giql_l" + r_alias = "__giql_r" + + select = exp.Select() + select.set("distinct", exp.Distinct()) + + first = True + for tbl_alias, cols, prefix in [ + (l_alias, l_cols, "__giql_l"), + (r_alias, r_cols, "__giql_r"), + ]: + for col, suffix in zip(cols, ["_chrom", "_start", "_end"]): + col_expr = exp.Alias( + this=exp.column(col, table=tbl_alias, quoted=True), + alias=exp.Identifier(this=f"{prefix}{suffix}"), + ) + select.select(col_expr, append=not first, copy=False) + first = False + + select.from_( + exp.Table( + this=exp.Identifier(this=l_cte), + alias=exp.TableAlias(this=exp.Identifier(this=l_alias)), + ), + copy=False, + ) + + join_on = exp.And( + this=exp.And( + this=exp.EQ( + this=exp.column(l_cols[0], table=l_alias, quoted=True), + expression=exp.column(r_cols[0], table=r_alias, quoted=True), + ), + expression=exp.EQ( + this=exp.column("__giql_bin", table=l_alias), + expression=exp.column("__giql_bin", table=r_alias), + ), + ), + expression=self._build_overlap(l_alias, r_alias, l_cols, r_cols), + ) + + select.join( + exp.Table( + this=exp.Identifier(this=r_cte), + alias=exp.TableAlias(this=exp.Identifier(this=r_alias)), + ), + on=join_on, + copy=False, + ) + + return exp.CTE( + this=select, + alias=exp.TableAlias(this=exp.Identifier(this=name)), + ) + + def _build_pairs_replacement_joins( + self, + query: exp.Select, + join: exp.Join, + intersects: Intersects, + extra: exp.Expression | None, + key_binned: dict[str, str], + pairs_idx: int, + ) -> list[exp.Join]: + """Build a pairs CTE and two replacement joins for one INTERSECTS. 
+ + Returns two joins: + - join1: from_alias [SIDE] JOIN __giql_pairs ON from.key = pairs.from_key + - join2: [SIDE] JOIN join_table ON join.key = pairs.join_key [AND extra] + """ + from_table = query.args["from_"].this + join_table = join.this + if not isinstance(from_table, exp.Table) or not isinstance( + join_table, exp.Table + ): + return [join] + + from_alias = from_table.alias or from_table.name + join_alias = join_table.alias or join_table.name + from_table_name = from_table.name + join_table_name = join_table.name + + left_alias = intersects.this.table + from_cols = self._get_columns(from_table_name) + join_cols = self._get_columns(join_table_name) + + # Determine which INTERSECTS side maps to FROM vs JOIN table + if left_alias == from_alias: + l_table_name, r_table_name = from_table_name, join_table_name + l_cols, r_cols = from_cols, join_cols + from_prefix, join_prefix = "__giql_l", "__giql_r" + else: + l_table_name, r_table_name = join_table_name, from_table_name + l_cols, r_cols = join_cols, from_cols + from_prefix, join_prefix = "__giql_r", "__giql_l" + + # Ensure key-only bin CTEs exist + l_cte = self._ensure_key_binned(query, l_table_name, key_binned) + r_cte = self._ensure_key_binned(query, r_table_name, key_binned) + + # Build and attach the pairs CTE + pairs_name = f"__giql_pairs_{pairs_idx}" + pairs_cte = self._build_pairs_cte(pairs_name, l_cte, r_cte, l_cols, r_cols) + existing_with = query.args.get("with_") + if existing_with: + existing_with.append("expressions", pairs_cte) + else: + query.set("with_", exp.With(expressions=[pairs_cte])) + + side = join.args.get("side") + p_alias = f"__giql_p{pairs_idx}" + + # join1: [SIDE] JOIN pairs ON from.key = pairs.from_key + join1_on = self._build_key_match(from_alias, from_cols, p_alias, from_prefix) + join1_kwargs: dict = { + "this": exp.Table( + this=exp.Identifier(this=pairs_name), + alias=exp.TableAlias(this=exp.Identifier(this=p_alias)), + ), + "on": join1_on, + } + if side: + join1_kwargs["side"] 
= side + join1 = exp.Join(**join1_kwargs) + + # join2: [SIDE] JOIN join_table ON join.key = pairs.join_key + join2_on = self._build_key_match(join_alias, join_cols, p_alias, join_prefix) + if extra: + join2_on = exp.And(this=join2_on, expression=extra) + join2_kwargs: dict = { + "this": exp.Table( + this=exp.Identifier(this=join_table_name), + alias=exp.TableAlias(this=exp.Identifier(this=join_alias)), + ), + "on": join2_on, + } + if side: + join2_kwargs["side"] = side + join2 = exp.Join(**join2_kwargs) + + return [join1, join2] + + def _build_key_match( + self, + table_alias: str, + cols: tuple[str, str, str], + pairs_alias: str, + prefix: str, + ) -> exp.And: + """Build ``table.chrom = pairs.prefix_chrom AND ...`` for all three keys.""" + return exp.And( + this=exp.And( + this=exp.EQ( + this=exp.column(cols[0], table=table_alias, quoted=True), + expression=exp.column(f"{prefix}_chrom", table=pairs_alias), + ), + expression=exp.EQ( + this=exp.column(cols[1], table=table_alias, quoted=True), + expression=exp.column(f"{prefix}_start", table=pairs_alias), + ), + ), + expression=exp.EQ( + this=exp.column(cols[2], table=table_alias, quoted=True), + expression=exp.column(f"{prefix}_end", table=pairs_alias), + ), + ) + + def _transform_full_cte(self, query: exp.Select) -> exp.Select: + joins = query.args.get("joins") or [] + binned: dict[str, tuple[str, str, str]] = {} + rewrote_any = False + + for join in joins: + on = join.args.get("on") + if on: + intersects = self._find_column_intersects_in(on) + if intersects: + self._rewrite_join_on_full_cte(query, join, intersects, binned) + rewrote_any = True + + # Implicit cross-join: FROM a, b WHERE a.interval INTERSECTS b.interval + where = query.args.get("where") + if where: + intersects = self._find_column_intersects_in(where.this) + if intersects: + cross_join = self._find_cross_join_for_intersects( + query, intersects, joins + ) + if cross_join is not None: + self._rewrite_cross_join_full_cte( + query, cross_join, 
intersects, binned + ) + rewrote_any = True + + if rewrote_any: + query.set("distinct", exp.Distinct()) + + return query + + def _rewrite_join_on_full_cte( + self, + query: exp.Select, + join: exp.Join, + intersects: Intersects, + binned: dict[str, tuple[str, str, str]], + ) -> None: + from_table = query.args["from_"].this + join_table = join.this + if not isinstance(from_table, exp.Table) or not isinstance( + join_table, exp.Table + ): + return + + from_alias, from_cols = self._ensure_table_binned_full( + query, from_table, query.args["from_"], binned + ) + join_alias, join_cols = self._ensure_table_binned_full( + query, join_table, join, binned + ) + + extra = self._extract_non_intersects(join.args.get("on"), intersects) + + equi_join = exp.And( + this=exp.EQ( + this=exp.column(from_cols[0], table=from_alias, quoted=True), + expression=exp.column(join_cols[0], table=join_alias, quoted=True), + ), + expression=exp.EQ( + this=exp.column("__giql_bin", table=from_alias), + expression=exp.column("__giql_bin", table=join_alias), + ), + ) + # Place both equi-join and overlap in ON so LEFT/RIGHT/FULL semantics hold. 
+ new_on = exp.And( + this=equi_join, + expression=self._build_overlap(from_alias, join_alias, from_cols, join_cols), + ) + if extra: + new_on = exp.And(this=new_on, expression=extra) + join.set("on", new_on) + + def _rewrite_cross_join_full_cte( + self, + query: exp.Select, + cross_join: exp.Join, + intersects: Intersects, + binned: dict[str, tuple[str, str, str]], + ) -> None: + from_table = query.args["from_"].this + join_table = cross_join.this + if not isinstance(from_table, exp.Table) or not isinstance( + join_table, exp.Table + ): + return + + from_alias, from_cols = self._ensure_table_binned_full( + query, from_table, query.args["from_"], binned + ) + join_alias, join_cols = self._ensure_table_binned_full( + query, join_table, cross_join, binned + ) + + equi_join = exp.And( + this=exp.EQ( + this=exp.column(from_cols[0], table=from_alias, quoted=True), + expression=exp.column(join_cols[0], table=join_alias, quoted=True), + ), + expression=exp.EQ( + this=exp.column("__giql_bin", table=from_alias), + expression=exp.column("__giql_bin", table=join_alias), + ), + ) + cross_join.set( + "on", + exp.And( + this=equi_join, + expression=self._build_overlap( + from_alias, join_alias, from_cols, join_cols + ), + ), + ) + self._remove_intersects_from_where(query, intersects) + + def _ensure_table_binned_full( + self, + query: exp.Select, + table: exp.Table, + parent: exp.Expression, + binned: dict[str, tuple[str, str, str]], + ) -> tuple[str, tuple[str, str, str]]: + """Create a full SELECT * CTE for *table* if needed; replace ref in *parent*.""" + alias = table.alias or table.name + if alias in binned: + return alias, binned[alias] + + table_name = table.name + cols = self._get_columns(table_name) + cte_name = f"__giql_{alias}_binned" + + cte = exp.CTE( + this=self._build_full_binned_select(table_name, cols), + alias=exp.TableAlias(this=exp.Identifier(this=cte_name)), + ) + existing_with = query.args.get("with_") + if existing_with: + 
existing_with.append("expressions", cte) + else: + query.set("with_", exp.With(expressions=[cte])) + + parent.set( + "this", + exp.Table( + this=exp.Identifier(this=cte_name), + alias=exp.TableAlias(this=exp.Identifier(this=alias)), + ), + ) + binned[alias] = cols + return alias, cols + + def _build_bin_range( + self, start: str, end: str + ) -> tuple[exp.Expression, exp.Expression]: + """Build the (low, high) bin-index expressions for UNNEST(range(...)). + + Returns ``start // bin_size`` and ``(end - 1) // bin_size + 1``. + Uses integer floor division to avoid rounding errors from + float division + CAST. + """ + bs = self.bin_size + + low = exp.IntDiv( + this=exp.column(start, quoted=True), + expression=exp.Literal.number(bs), + ) + high = exp.Add( + this=exp.IntDiv( + this=exp.Paren( + this=exp.Sub( + this=exp.column(end, quoted=True), + expression=exp.Literal.number(1), + ), + ), + expression=exp.Literal.number(bs), + ), + expression=exp.Literal.number(1), + ) + return low, high + + def _build_full_binned_select( + self, table_name: str, cols: tuple[str, str, str] + ) -> exp.Select: + """Build ``SELECT *, UNNEST(range(...)) AS __giql_bin FROM ``.""" + _chrom, start, end = cols + low, high = self._build_bin_range(start, end) + + range_fn = exp.Anonymous(this="range", expressions=[low, high]) + unnest_fn = exp.Anonymous(this="UNNEST", expressions=[range_fn]) + bin_alias = exp.Alias( + this=unnest_fn, + alias=exp.Identifier(this="__giql_bin"), + ) + + select = exp.Select() + select.select(exp.Star(), copy=False) + select.select(bin_alias, append=True, copy=False) + select.from_(exp.Table(this=exp.Identifier(this=table_name)), copy=False) + return select + + def _transform_bridge(self, query: exp.Select) -> exp.Select: + joins = query.args.get("joins") or [] + key_binned: dict[str, str] = {} + connector_counter = itertools.count() + new_joins: list[exp.Join] = [] + rewrote_any = False + + for join in joins: + on = join.args.get("on") + if on: + intersects = 
self._find_column_intersects_in(on) + if intersects: + extra = self._build_join_back_joins( + query, + join, + intersects, + key_binned, + connector_counter, + preserve_kind=True, + ) + new_joins.extend(extra) + rewrote_any = True + continue + new_joins.append(join) + + where = query.args.get("where") + if where: + intersects = self._find_column_intersects_in(where.this) + if intersects: + cross_join = self._find_cross_join_for_intersects( + query, intersects, new_joins + ) + if cross_join is not None: + new_joins.remove(cross_join) + extra = self._build_join_back_joins( + query, + cross_join, + intersects, + key_binned, + connector_counter, + preserve_kind=False, + ) + new_joins.extend(extra) + self._remove_intersects_from_where(query, intersects) + rewrote_any = True + + if rewrote_any: + query.set("joins", new_joins) + query.set("distinct", exp.Distinct()) + + return query + + def _find_column_intersects_in(self, expr: exp.Expression) -> Intersects | None: + """Return the first column-to-column Intersects node in *expr*, or None. + + Only the first match is returned. A single JOIN with multiple + INTERSECTS conditions in its ON clause is not supported; only the + first will be rewritten. 
+ """ + for node in expr.find_all(Intersects): + if ( + isinstance(node.this, exp.Column) + and node.this.table + and isinstance(node.expression, exp.Column) + and node.expression.table + ): + return node + return None + + def _find_cross_join_for_intersects( + self, + query: exp.Select, + intersects: Intersects, + current_joins: list[exp.Join], + ) -> exp.Join | None: + """Find the implicit cross-join entry for the table in a WHERE INTERSECTS.""" + from_table = query.args["from_"].this + if not isinstance(from_table, exp.Table): + return None + from_alias = from_table.alias or from_table.name + + left_alias = intersects.this.table + right_alias = intersects.expression.table + if left_alias == from_alias: + target_alias = right_alias + elif right_alias == from_alias: + target_alias = left_alias + else: + return None + + for join in current_joins: + if isinstance(join.this, exp.Table): + alias = join.this.alias or join.this.name + if alias == target_alias: + return join + return None + + def _remove_intersects_from_where( + self, query: exp.Select, intersects: Intersects + ) -> None: + """Remove the INTERSECTS predicate from the WHERE clause.""" + where = query.args.get("where") + if not where: + return + remainder = self._extract_non_intersects(where.this, intersects) + if remainder is None: + query.set("where", None) + else: + query.set("where", exp.Where(this=remainder)) + + def _extract_non_intersects( + self, expr: exp.Expression | None, intersects: Intersects + ) -> exp.Expression | None: + """Return the parts of an AND tree that are not the INTERSECTS node.""" + if expr is None or expr is intersects: + return None + if isinstance(expr, exp.And): + if expr.this is intersects: + return expr.expression + if expr.expression is intersects: + return expr.this + left = self._extract_non_intersects(expr.this, intersects) + right = self._extract_non_intersects(expr.expression, intersects) + if left is None: + return right + if right is None: + return left + return 
exp.And(this=left, expression=right) + return expr + + def _get_columns(self, table_name: str) -> tuple[str, str, str]: + """Return (chrom, start, end) column names for a table.""" + table = self.tables.get(table_name) + if table: + return (table.chrom_col, table.start_col, table.end_col) + return (DEFAULT_CHROM_COL, DEFAULT_START_COL, DEFAULT_END_COL) + + def _build_overlap( + self, + from_alias: str, + join_alias: str, + from_cols: tuple[str, str, str], + join_cols: tuple[str, str, str], + ) -> exp.And: + """Build ``from.start < join.end AND from.end > join.start``.""" + return exp.And( + this=exp.LT( + this=exp.column(from_cols[1], table=from_alias, quoted=True), + expression=exp.column(join_cols[2], table=join_alias, quoted=True), + ), + expression=exp.GT( + this=exp.column(from_cols[2], table=from_alias, quoted=True), + expression=exp.column(join_cols[1], table=join_alias, quoted=True), + ), + ) + + def _find_table_name_for_alias(self, query: exp.Select, alias: str) -> str: + """Resolve an alias to its underlying table name.""" + from_table = query.args["from_"].this + if isinstance(from_table, exp.Table): + if (from_table.alias or from_table.name) == alias: + return from_table.name + for join in query.args.get("joins") or []: + if isinstance(join.this, exp.Table): + t = join.this + if (t.alias or t.name) == alias: + return t.name + return alias # fallback: alias == table name + + def _build_key_only_bins_select( + self, table_name: str, cols: tuple[str, str, str] + ) -> exp.Select: + """Build ``SELECT chrom, start, end, UNNEST(range(...)) AS __giql_bin FROM table``.""" + chrom, start, end = cols + low, high = self._build_bin_range(start, end) + + range_fn = exp.Anonymous(this="range", expressions=[low, high]) + unnest_fn = exp.Anonymous(this="UNNEST", expressions=[range_fn]) + bin_alias = exp.Alias( + this=unnest_fn, + alias=exp.Identifier(this="__giql_bin"), + ) + + select = exp.Select() + select.select(exp.column(chrom, quoted=True), copy=False) + 
select.select(exp.column(start, quoted=True), append=True, copy=False) + select.select(exp.column(end, quoted=True), append=True, copy=False) + select.select(bin_alias, append=True, copy=False) + select.from_(exp.Table(this=exp.Identifier(this=table_name)), copy=False) + return select + + def _ensure_key_binned( + self, + query: exp.Select, + table_name: str, + key_binned: dict[str, str], + ) -> str: + """Ensure a key-only bins CTE exists for *table_name*; return its name.""" + if table_name in key_binned: + return key_binned[table_name] + + cte_name = f"__giql_{table_name}_bins" + cols = self._get_columns(table_name) + cte = exp.CTE( + this=self._build_key_only_bins_select(table_name, cols), + alias=exp.TableAlias(this=exp.Identifier(this=cte_name)), + ) + + existing_with = query.args.get("with_") + if existing_with: + existing_with.append("expressions", cte) + else: + query.set("with_", exp.With(expressions=[cte])) + + key_binned[table_name] = cte_name + return cte_name + + def _build_join_back_joins( + self, + query: exp.Select, + join: exp.Join, + intersects: Intersects, + key_binned: dict[str, str], + connector_counter: itertools.count, + *, + preserve_kind: bool, + ) -> list[exp.Join]: + """Build three replacement JOINs for one INTERSECTS using the join-back pattern. + + join1 is always INNER because it key-matches a table against its + own bin CTE — every row has a corresponding bin entry by + construction, so the join side has no effect. + + join2 and join3 inherit the original join's side (LEFT, RIGHT) + when *preserve_kind* is True. 
+ """ + join_table = join.this + if not isinstance(join_table, exp.Table): + return [join] + + join_alias = join_table.alias or join_table.name + join_table_name = join_table.name + + left_alias = intersects.this.table + right_alias = intersects.expression.table + other_alias = left_alias if right_alias == join_alias else right_alias + if other_alias == join_alias: + return [join] # can't determine structure + + extra = self._extract_non_intersects(join.args.get("on"), intersects) + + other_table_name = self._find_table_name_for_alias(query, other_alias) + other_cols = self._get_columns(other_table_name) + join_cols = self._get_columns(join_table_name) + + other_cte = self._ensure_key_binned(query, other_table_name, key_binned) + join_cte = self._ensure_key_binned(query, join_table_name, key_binned) + + c0 = f"__giql_c{next(connector_counter)}" + c1 = f"__giql_c{next(connector_counter)}" + + join_side = None + if preserve_kind: + join_side = join.args.get("side") + + # join1: key-match from other_alias to its bin CTE + join1 = exp.Join( + this=exp.Table( + this=exp.Identifier(this=other_cte), + alias=exp.TableAlias(this=exp.Identifier(this=c0)), + ), + on=exp.And( + this=exp.And( + this=exp.EQ( + this=exp.column(other_cols[0], table=other_alias, quoted=True), + expression=exp.column(other_cols[0], table=c0, quoted=True), + ), + expression=exp.EQ( + this=exp.column(other_cols[1], table=other_alias, quoted=True), + expression=exp.column(other_cols[1], table=c0, quoted=True), + ), + ), + expression=exp.EQ( + this=exp.column(other_cols[2], table=other_alias, quoted=True), + expression=exp.column(other_cols[2], table=c0, quoted=True), + ), + ), + ) + + # join2: bin equi-join (chrom + __giql_bin match) + join2_kwargs: dict = { + "this": exp.Table( + this=exp.Identifier(this=join_cte), + alias=exp.TableAlias(this=exp.Identifier(this=c1)), + ), + "on": exp.And( + this=exp.EQ( + this=exp.column(other_cols[0], table=c0, quoted=True), + expression=exp.column(join_cols[0], 
table=c1, quoted=True), + ), + expression=exp.EQ( + this=exp.column("__giql_bin", table=c0), + expression=exp.column("__giql_bin", table=c1), + ), + ), + } + if join_side: + join2_kwargs["side"] = join_side + join2 = exp.Join(**join2_kwargs) + + # join3: key-match from join CTE back to actual join table + overlap + key_match = exp.And( + this=exp.And( + this=exp.EQ( + this=exp.column(join_cols[0], table=join_alias, quoted=True), + expression=exp.column(join_cols[0], table=c1, quoted=True), + ), + expression=exp.EQ( + this=exp.column(join_cols[1], table=join_alias, quoted=True), + expression=exp.column(join_cols[1], table=c1, quoted=True), + ), + ), + expression=exp.EQ( + this=exp.column(join_cols[2], table=join_alias, quoted=True), + expression=exp.column(join_cols[2], table=c1, quoted=True), + ), + ) + join3_on = exp.And( + this=key_match, + expression=self._build_overlap( + other_alias, join_alias, other_cols, join_cols + ), + ) + if extra: + join3_on = exp.And(this=join3_on, expression=extra) + join3_kwargs: dict = { + "this": exp.Table( + this=exp.Identifier(this=join_table_name), + alias=exp.TableAlias(this=exp.Identifier(this=join_alias)), + ), + "on": join3_on, + } + if join_side: + join3_kwargs["side"] = join_side + + join3 = exp.Join(**join3_kwargs) + + return [join1, join2, join3] diff --git a/src/giql/transpile.py b/src/giql/transpile.py index 2b29c3d..9bf8076 100644 --- a/src/giql/transpile.py +++ b/src/giql/transpile.py @@ -11,6 +11,7 @@ from giql.table import Table from giql.table import Tables from giql.transformer import ClusterTransformer +from giql.transformer import IntersectsBinnedJoinTransformer from giql.transformer import MergeTransformer @@ -45,6 +46,7 @@ def _build_tables(tables: list[str | Table] | None) -> Tables: def transpile( giql: str, tables: list[str | Table] | None = None, + bin_size: int | None = None, ) -> str: """Transpile a GIQL query to SQL. @@ -60,6 +62,11 @@ def transpile( Table configurations. 
Strings use default column mappings (chrom, start, end, strand). Table objects provide custom column name mappings. + bin_size : int | None + Bin size for INTERSECTS equi-join optimization. When a query + contains a full-table column-to-column INTERSECTS join, the + transpiler rewrites it as a binned equi-join for performance. + Defaults to 10,000 if not specified. Returns ------- @@ -94,11 +101,24 @@ def transpile( ) ], ) + + Binned equi-join with custom bin size:: + + sql = transpile( + "SELECT a.*, b.* FROM peaks a JOIN genes b " + "ON a.interval INTERSECTS b.interval", + tables=["peaks", "genes"], + bin_size=100000, + ) """ # Build tables container tables_container = _build_tables(tables) # Initialize transformers with table configurations + intersects_transformer = IntersectsBinnedJoinTransformer( + tables_container, + bin_size=bin_size, + ) merge_transformer = MergeTransformer(tables_container) cluster_transformer = ClusterTransformer(tables_container) @@ -111,8 +131,9 @@ def transpile( except Exception as e: raise ValueError(f"Parse error: {e}\nQuery: {giql}") from e - # Apply transformations (MERGE first, then CLUSTER) + # Apply transformations try: + ast = intersects_transformer.transform(ast) # MERGE transformation (which may internally use CLUSTER) ast = merge_transformer.transform(ast) # CLUSTER transformation for any standalone CLUSTER expressions diff --git a/tests/integration/bedtools/test_intersect_property.py b/tests/integration/bedtools/test_intersect_property.py new file mode 100644 index 0000000..57c657d --- /dev/null +++ b/tests/integration/bedtools/test_intersect_property.py @@ -0,0 +1,692 @@ +"""Property-based correctness tests for INTERSECTS binned equi-join. + +These tests use hypothesis to generate random genomic intervals of +varying sizes — including intervals that span multiple bins — and +verify that GIQL's binned equi-join produces identical results to +bedtools intersect. 
+""" + +from hypothesis import HealthCheck +from hypothesis import given +from hypothesis import settings +from hypothesis import strategies as st + +from giql import transpile + +from .utils.bedtools_wrapper import intersect +from .utils.comparison import compare_results +from .utils.data_models import GenomicInterval +from .utils.duckdb_loader import load_intervals + +duckdb = __import__("pytest").importorskip("duckdb") + + +# --------------------------------------------------------------------------- +# Strategies +# --------------------------------------------------------------------------- + +CHROMS = ["chr1", "chr2", "chr3"] + + +@st.composite +def genomic_interval_st(draw, idx=None): + """Generate a random GenomicInterval that can span multiple 10k bins.""" + chrom = draw(st.sampled_from(CHROMS)) + start = draw(st.integers(min_value=0, max_value=1_000_000)) + length = draw(st.integers(min_value=1, max_value=200_000)) + score = draw(st.integers(min_value=0, max_value=1000)) + strand = draw(st.sampled_from(["+", "-"])) + # Name is set by the list strategy to guarantee uniqueness, avoiding + # the known DISTINCT duplicate-collapse limitation. 
+ name = ( + f"r{idx}" + if idx is not None + else draw(st.from_regex(r"r[0-9]{1,6}", fullmatch=True)) + ) + return GenomicInterval(chrom, start, start + length, name, score, strand) + + +@st.composite +def unique_interval_list_st(draw, max_size=60): + """Generate a list of intervals with unique names.""" + n = draw(st.integers(min_value=1, max_value=max_size)) + intervals = [] + for i in range(n): + iv = draw(genomic_interval_st(idx=i)) + intervals.append(iv) + return intervals + + +interval_list_st = unique_interval_list_st() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _run_giql(intervals_a, intervals_b): + """Run the binned-join INTERSECTS query via DuckDB and return result rows.""" + conn = duckdb.connect(":memory:") + try: + load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b]) + + sql = transpile( + """ + SELECT DISTINCT a.* + FROM intervals_a a, intervals_b b + WHERE a.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b"], + ) + return conn.execute(sql).fetchall() + finally: + conn.close() + + +def _run_bedtools(intervals_a, intervals_b): + """Run bedtools intersect -u and return result tuples.""" + return intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@given(intervals_a=interval_list_st, intervals_b=interval_list_st) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.too_slow], +) +def test_binned_join_matches_bedtools(intervals_a, intervals_b): + """ + GIVEN two randomly generated sets of genomic intervals + WHEN GIQL INTERSECTS binned equi-join is 
executed + THEN results match bedtools intersect -u exactly + """ + giql_result = _run_giql(intervals_a, intervals_b) + bedtools_result = _run_bedtools(intervals_a, intervals_b) + + comparison = compare_results(giql_result, bedtools_result) + assert comparison.match, comparison.failure_message() + + +@given(intervals=interval_list_st) +@settings( + max_examples=30, + deadline=None, + suppress_health_check=[HealthCheck.too_slow], +) +def test_self_join_matches_bedtools(intervals): + """ + GIVEN a randomly generated set of genomic intervals + WHEN GIQL INTERSECTS self-join is executed + THEN results match bedtools intersect -u with the same file as A and B + """ + giql_result = _run_giql(intervals, intervals) + bedtools_result = _run_bedtools(intervals, intervals) + + comparison = compare_results(giql_result, bedtools_result) + assert comparison.match, comparison.failure_message() + + +@given( + intervals_a=unique_interval_list_st(max_size=30), + intervals_b=unique_interval_list_st(max_size=30), + bin_size=st.sampled_from([100, 1_000, 10_000, 100_000]), +) +@settings( + max_examples=40, + deadline=None, + suppress_health_check=[HealthCheck.too_slow], +) +def test_bin_size_does_not_affect_correctness(intervals_a, intervals_b, bin_size): + """ + GIVEN two randomly generated sets of genomic intervals and a bin size + WHEN GIQL INTERSECTS is executed with that bin size + THEN results match bedtools intersect -u regardless of bin size + """ + conn = duckdb.connect(":memory:") + try: + load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b]) + + sql = transpile( + """ + SELECT DISTINCT a.* + FROM intervals_a a, intervals_b b + WHERE a.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b"], + bin_size=bin_size, + ) + giql_result = conn.execute(sql).fetchall() + finally: + conn.close() + + bedtools_result = _run_bedtools(intervals_a, intervals_b) + + comparison = 
compare_results(giql_result, bedtools_result) + assert comparison.match, f"bin_size={bin_size}: {comparison.failure_message()}" + + +@given( + intervals_a=unique_interval_list_st(max_size=8), + intervals_b=unique_interval_list_st(max_size=8), + intervals_c=unique_interval_list_st(max_size=8), +) +@settings( + max_examples=40, + deadline=None, + suppress_health_check=[HealthCheck.too_slow], +) +def test_multi_table_join_matches_bedtools(intervals_a, intervals_b, intervals_c): + """ + GIVEN three randomly generated sets of genomic intervals + WHEN GIQL three-way INTERSECTS join is executed + THEN the A-side rows match bedtools intersect chained A->B then ->C + """ + conn = duckdb.connect(":memory:") + try: + load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b]) + load_intervals(conn, "intervals_c", [i.to_tuple() for i in intervals_c]) + + sql = transpile( + """ + SELECT DISTINCT a.* + FROM intervals_a a + JOIN intervals_b b ON a.interval INTERSECTS b.interval + JOIN intervals_c c ON a.interval INTERSECTS c.interval + """, + tables=["intervals_a", "intervals_b", "intervals_c"], + ) + giql_result = conn.execute(sql).fetchall() + finally: + conn.close() + + # bedtools equivalent: chain A∩B then filter against C + tuples_a = [i.to_tuple() for i in intervals_a] + tuples_b = [i.to_tuple() for i in intervals_b] + tuples_c = [i.to_tuple() for i in intervals_c] + ab_result = intersect(tuples_a, tuples_b) + if ab_result: + bedtools_result = intersect(ab_result, tuples_c) + else: + bedtools_result = [] + + comparison = compare_results(giql_result, bedtools_result) + assert comparison.match, comparison.failure_message() + + +@given( + intervals_a=unique_interval_list_st(max_size=30), + intervals_b=unique_interval_list_st(max_size=30), +) +@settings( + max_examples=40, + deadline=None, + suppress_health_check=[HealthCheck.too_slow], +) +def 
test_left_join_matches_bedtools_loj(intervals_a, intervals_b): + """ + GIVEN two randomly generated sets of genomic intervals + WHEN GIQL LEFT JOIN INTERSECTS is executed + THEN results match bedtools intersect -loj exactly + """ + conn = duckdb.connect(":memory:") + try: + load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b]) + + sql = transpile( + """ + SELECT DISTINCT + a.chrom, a.start, a.end, a.name, a.score, a.strand, + b.chrom AS b_chrom, b.start AS b_start, b.end AS b_end, + b.name AS b_name, b.score AS b_score, b.strand AS b_strand + FROM intervals_a a + LEFT JOIN intervals_b b ON a.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b"], + ) + giql_result = conn.execute(sql).fetchall() + finally: + conn.close() + + bedtools_result = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + loj=True, + ) + + comparison = compare_results(giql_result, bedtools_result) + assert comparison.match, comparison.failure_message() + + +# --------------------------------------------------------------------------- +# -v (inverse / anti-join) +# --------------------------------------------------------------------------- + + +@given( + intervals_a=unique_interval_list_st(max_size=30), + intervals_b=unique_interval_list_st(max_size=30), +) +@settings( + max_examples=40, + deadline=None, + suppress_health_check=[HealthCheck.too_slow], +) +def test_inverse_matches_bedtools_v(intervals_a, intervals_b): + """ + GIVEN two randomly generated sets of genomic intervals + WHEN GIQL anti-join (LEFT JOIN WHERE b IS NULL) is executed + THEN results match bedtools intersect -v exactly + """ + conn = duckdb.connect(":memory:") + try: + load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b]) + + sql = transpile( + """ + SELECT DISTINCT a.* + FROM 
intervals_a a + LEFT JOIN intervals_b b ON a.interval INTERSECTS b.interval + WHERE b.chrom IS NULL + """, + tables=["intervals_a", "intervals_b"], + ) + giql_result = conn.execute(sql).fetchall() + finally: + conn.close() + + bedtools_result = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + inverse=True, + ) + + comparison = compare_results(giql_result, bedtools_result) + assert comparison.match, comparison.failure_message() + + +# --------------------------------------------------------------------------- +# -wa -wb (write both A and B entries) +# --------------------------------------------------------------------------- + + +@given( + intervals_a=unique_interval_list_st(max_size=20), + intervals_b=unique_interval_list_st(max_size=20), +) +@settings( + max_examples=40, + deadline=None, + suppress_health_check=[HealthCheck.too_slow], +) +def test_write_both_matches_bedtools_wa_wb(intervals_a, intervals_b): + """ + GIVEN two randomly generated sets of genomic intervals + WHEN GIQL INTERSECTS join selecting both sides is executed + THEN results match bedtools intersect -wa -wb exactly + """ + conn = duckdb.connect(":memory:") + try: + load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b]) + + sql = transpile( + """ + SELECT + a.chrom, a.start, a.end, a.name, a.score, a.strand, + b.chrom AS b_chrom, b.start AS b_start, b.end AS b_end, + b.name AS b_name, b.score AS b_score, b.strand AS b_strand + FROM intervals_a a + JOIN intervals_b b ON a.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b"], + ) + giql_result = conn.execute(sql).fetchall() + finally: + conn.close() + + bedtools_result = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + write_both=True, + ) + + comparison = compare_results(giql_result, bedtools_result) + assert comparison.match, 
comparison.failure_message() + + +# --------------------------------------------------------------------------- +# -c (count overlaps) +# --------------------------------------------------------------------------- + + +@given( + intervals_a=unique_interval_list_st(max_size=20), + intervals_b=unique_interval_list_st(max_size=20), +) +@settings( + max_examples=40, + deadline=None, + suppress_health_check=[HealthCheck.too_slow], +) +def test_count_matches_bedtools_c(intervals_a, intervals_b): + """ + GIVEN two randomly generated sets of unique genomic intervals + WHEN GIQL COUNT of overlapping B per A is computed + THEN results match bedtools intersect -c exactly + """ + conn = duckdb.connect(":memory:") + try: + load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b]) + + # Use a naive overlap join for counting — the binned join's + # DISTINCT would collapse duplicate B matches. + count_sql = """ + SELECT + a.chrom, a."start", a."end", a.name, a.score, a.strand, + COUNT(b.chrom) AS cnt + FROM intervals_a a + LEFT JOIN intervals_b b + ON a.chrom = b.chrom + AND a."start" < b."end" + AND a."end" > b."start" + GROUP BY a.chrom, a."start", a."end", a.name, a.score, a.strand + """ + giql_result = conn.execute(count_sql).fetchall() + finally: + conn.close() + + bedtools_result = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + count=True, + ) + + comparison = compare_results(giql_result, bedtools_result) + assert comparison.match, comparison.failure_message() + + +# --------------------------------------------------------------------------- +# -s (same strand) +# --------------------------------------------------------------------------- + + +@given( + intervals_a=unique_interval_list_st(max_size=30), + intervals_b=unique_interval_list_st(max_size=30), +) +@settings( + max_examples=40, + deadline=None, + 
suppress_health_check=[HealthCheck.too_slow], +) +def test_same_strand_matches_bedtools_s(intervals_a, intervals_b): + """ + GIVEN two randomly generated sets of genomic intervals with strands + WHEN GIQL INTERSECTS with same-strand filter is executed + THEN results match bedtools intersect -s exactly + """ + conn = duckdb.connect(":memory:") + try: + load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b]) + + sql = transpile( + """ + SELECT DISTINCT a.* + FROM intervals_a a, intervals_b b + WHERE a.interval INTERSECTS b.interval + AND a.strand = b.strand + """, + tables=["intervals_a", "intervals_b"], + ) + giql_result = conn.execute(sql).fetchall() + finally: + conn.close() + + bedtools_result = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + strand_mode="same", + ) + + comparison = compare_results(giql_result, bedtools_result) + assert comparison.match, comparison.failure_message() + + +# --------------------------------------------------------------------------- +# -S (opposite strand) +# --------------------------------------------------------------------------- + + +@given( + intervals_a=unique_interval_list_st(max_size=30), + intervals_b=unique_interval_list_st(max_size=30), +) +@settings( + max_examples=40, + deadline=None, + suppress_health_check=[HealthCheck.too_slow], +) +def test_opposite_strand_matches_bedtools_S(intervals_a, intervals_b): + """ + GIVEN two randomly generated sets of genomic intervals with strands + WHEN GIQL INTERSECTS with opposite-strand filter is executed + THEN results match bedtools intersect -S exactly + """ + conn = duckdb.connect(":memory:") + try: + load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b]) + + sql = transpile( + """ + SELECT DISTINCT a.* + FROM intervals_a a, intervals_b b + WHERE 
a.interval INTERSECTS b.interval + AND a.strand != b.strand + """, + tables=["intervals_a", "intervals_b"], + ) + giql_result = conn.execute(sql).fetchall() + finally: + conn.close() + + bedtools_result = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + strand_mode="opposite", + ) + + comparison = compare_results(giql_result, bedtools_result) + assert comparison.match, comparison.failure_message() + + +# --------------------------------------------------------------------------- +# -f (minimum overlap fraction of A) +# --------------------------------------------------------------------------- + + +@given( + intervals_a=unique_interval_list_st(max_size=20), + intervals_b=unique_interval_list_st(max_size=20), + fraction=st.sampled_from([0.1, 0.25, 0.5, 0.75, 0.9]), +) +@settings( + max_examples=40, + deadline=None, + suppress_health_check=[HealthCheck.too_slow], +) +def test_fraction_a_matches_bedtools_f(intervals_a, intervals_b, fraction): + """ + GIVEN two randomly generated sets of genomic intervals and a fraction + WHEN GIQL INTERSECTS with minimum overlap fraction of A is executed + THEN results match bedtools intersect -f exactly + """ + conn = duckdb.connect(":memory:") + try: + load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b]) + + inner_sql = transpile( + """ + SELECT DISTINCT + a.chrom, a.start, a.end, a.name, a.score, a.strand, + b.start AS b_start, b.end AS b_end + FROM intervals_a a + JOIN intervals_b b ON a.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b"], + ) + sql = f""" + SELECT DISTINCT chrom, "start", "end", name, score, strand + FROM ({inner_sql}) + WHERE (LEAST("end", b_end) - GREATEST("start", b_start)) + >= {fraction} * ("end" - "start") + """ + giql_result = conn.execute(sql).fetchall() + finally: + conn.close() + + bedtools_result = intersect( + [i.to_tuple() for i in 
intervals_a], + [i.to_tuple() for i in intervals_b], + fraction_a=fraction, + ) + + comparison = compare_results(giql_result, bedtools_result) + assert comparison.match, f"fraction_a={fraction}: {comparison.failure_message()}" + + +# --------------------------------------------------------------------------- +# -F (minimum overlap fraction of B) +# --------------------------------------------------------------------------- + + +@given( + intervals_a=unique_interval_list_st(max_size=20), + intervals_b=unique_interval_list_st(max_size=20), + fraction=st.sampled_from([0.1, 0.25, 0.5, 0.75, 0.9]), +) +@settings( + max_examples=40, + deadline=None, + suppress_health_check=[HealthCheck.too_slow], +) +def test_fraction_b_matches_bedtools_F(intervals_a, intervals_b, fraction): + """ + GIVEN two randomly generated sets of genomic intervals and a fraction + WHEN GIQL INTERSECTS with minimum overlap fraction of B is executed + THEN results match bedtools intersect -F exactly + """ + conn = duckdb.connect(":memory:") + try: + load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b]) + + inner_sql = transpile( + """ + SELECT DISTINCT + a.chrom, a.start, a.end, a.name, a.score, a.strand, + b.start AS b_start, b.end AS b_end + FROM intervals_a a + JOIN intervals_b b ON a.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b"], + ) + sql = f""" + SELECT DISTINCT chrom, "start", "end", name, score, strand + FROM ({inner_sql}) + WHERE (LEAST("end", b_end) - GREATEST("start", b_start)) + >= {fraction} * (b_end - b_start) + """ + giql_result = conn.execute(sql).fetchall() + finally: + conn.close() + + bedtools_result = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + fraction_b=fraction, + ) + + comparison = compare_results(giql_result, bedtools_result) + assert comparison.match, f"fraction_b={fraction}: 
{comparison.failure_message()}" + + +# --------------------------------------------------------------------------- +# -r (reciprocal overlap fraction) +# --------------------------------------------------------------------------- + + +@given( + intervals_a=unique_interval_list_st(max_size=20), + intervals_b=unique_interval_list_st(max_size=20), + fraction=st.sampled_from([0.1, 0.25, 0.5, 0.75]), +) +@settings( + max_examples=40, + deadline=None, + suppress_health_check=[HealthCheck.too_slow], +) +def test_reciprocal_fraction_matches_bedtools_r(intervals_a, intervals_b, fraction): + """ + GIVEN two randomly generated sets of genomic intervals and a fraction + WHEN GIQL INTERSECTS with reciprocal overlap fraction is executed + THEN results match bedtools intersect -f -F -r exactly + """ + conn = duckdb.connect(":memory:") + try: + load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b]) + + inner_sql = transpile( + """ + SELECT DISTINCT + a.chrom, a.start, a.end, a.name, a.score, a.strand, + b.start AS b_start, b.end AS b_end + FROM intervals_a a + JOIN intervals_b b ON a.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b"], + ) + sql = f""" + SELECT DISTINCT chrom, "start", "end", name, score, strand + FROM ({inner_sql}) + WHERE (LEAST("end", b_end) - GREATEST("start", b_start)) + >= {fraction} * ("end" - "start") + AND (LEAST("end", b_end) - GREATEST("start", b_start)) + >= {fraction} * (b_end - b_start) + """ + giql_result = conn.execute(sql).fetchall() + finally: + conn.close() + + # -r applies -f reciprocally to both sides and requires -wa output. + # Deduplicate to match GIQL's SELECT DISTINCT. 
+ bedtools_raw = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + fraction_a=fraction, + reciprocal=True, + ) + bedtools_result = list(set(bedtools_raw)) + + comparison = compare_results(giql_result, bedtools_result) + assert comparison.match, ( + f"reciprocal fraction={fraction}: {comparison.failure_message()}" + ) diff --git a/tests/integration/bedtools/utils/bedtools_wrapper.py b/tests/integration/bedtools/utils/bedtools_wrapper.py index c61be44..5a75304 100644 --- a/tests/integration/bedtools/utils/bedtools_wrapper.py +++ b/tests/integration/bedtools/utils/bedtools_wrapper.py @@ -30,19 +30,88 @@ def intersect( intervals_a: list[tuple], intervals_b: list[tuple], strand_mode: str | None = None, + *, + loj: bool = False, + inverse: bool = False, + write_both: bool = False, + count: bool = False, + write_overlap: bool = False, + write_all_overlap: bool = False, + fraction_a: float | None = None, + fraction_b: float | None = None, + reciprocal: bool = False, ) -> list[tuple]: - """Find overlapping intervals using bedtools intersect.""" + """Find overlapping intervals using bedtools intersect. + + Parameters + ---------- + loj : bool + Left outer join mode (-loj). + inverse : bool + Report A entries with NO overlap in B (-v). + write_both : bool + Write both A and B entries for each overlap (-wa -wb). + count : bool + Count B overlaps for each A feature (-c). + write_overlap : bool + Write overlap amount in bp for each pair (-wo). + write_all_overlap : bool + Write overlap amount for all A features including + non-overlapping (-wao). + fraction_a : float or None + Minimum overlap as fraction of A (-f). + fraction_b : float or None + Minimum overlap as fraction of B (-F). + reciprocal : bool + Require fraction thresholds on both sides (-r). 
+ """ try: bt_a = create_bedtool(intervals_a) bt_b = create_bedtool(intervals_b) - kwargs = {"u": True} + kwargs: dict = {} + if loj: + kwargs["loj"] = True + elif inverse: + kwargs["v"] = True + elif write_both: + kwargs["wa"] = True + kwargs["wb"] = True + elif count: + kwargs["c"] = True + elif write_overlap: + kwargs["wo"] = True + elif write_all_overlap: + kwargs["wao"] = True + elif reciprocal: + kwargs["wa"] = True + else: + kwargs["u"] = True + if strand_mode == "same": kwargs["s"] = True elif strand_mode == "opposite": kwargs["S"] = True + if fraction_a is not None: + kwargs["f"] = fraction_a + if fraction_b is not None and not reciprocal: + kwargs["F"] = fraction_b + if reciprocal: + kwargs["r"] = True + result = bt_a.intersect(bt_b, **kwargs) + + if loj: + return bedtool_to_tuples(result, bed_format="loj") + if write_both: + return bedtool_to_tuples(result, bed_format="loj") + if count: + return bedtool_to_tuples(result, bed_format="count") + if write_overlap: + return bedtool_to_tuples(result, bed_format="wo") + if write_all_overlap: + return bedtool_to_tuples(result, bed_format="wo") return bedtool_to_tuples(result) except Exception as e: @@ -102,7 +171,11 @@ def bedtool_to_tuples( Args: bedtool: pybedtools.BedTool object - bed_format: Expected format ('bed3', 'bed6', or 'closest') + bed_format: Expected format ('bed3', 'bed6', 'loj', or 'closest') + + LOJ format assumes BED6(A)+BED6(B) (12 fields): + Fields 0-5: A interval + Fields 6-11: B interval (all '.' / -1 when unmatched) Closest format assumes BED6+BED6+distance (13 fields): Fields 0-5: A interval (chrom, start, end, name, score, strand) @@ -137,6 +210,68 @@ def bedtool_to_tuples( ) ) + elif bed_format == "count": + while len(fields) < 7: + fields.append("0") + rows.append( + ( + fields[0], + int(fields[1]), + int(fields[2]), + fields[3] if fields[3] != "." else None, + int(fields[4]) if fields[4] != "." else None, + fields[5] if fields[5] != "." 
else None, + int(fields[6]), + ) + ) + + elif bed_format == "wo": + if len(fields) < 13: + raise ValueError(f"Unexpected number of fields for wo: {len(fields)}") + rows.append( + ( + fields[0], + int(fields[1]), + int(fields[2]), + fields[3] if fields[3] != "." else None, + int(fields[4]) if fields[4] != "." else None, + fields[5] if fields[5] != "." else None, + fields[6] if fields[6] != "." else None, + int(fields[7]) if fields[7] != "." else None, + int(fields[8]) if fields[8] != "." else None, + fields[9] if fields[9] != "." else None, + int(fields[10]) if fields[10] != "." else None, + fields[11] if fields[11] != "." else None, + int(fields[12]), + ) + ) + + elif bed_format == "loj": + if len(fields) < 12: + raise ValueError(f"Unexpected number of fields for loj: {len(fields)}") + + def _loj_field(val, as_int=False): + if val == "." or val == "-1": + return None + return int(val) if as_int else val + + rows.append( + ( + fields[0], + int(fields[1]), + int(fields[2]), + fields[3] if fields[3] != "." else None, + int(fields[4]) if fields[4] != "." else None, + fields[5] if fields[5] != "." 
else None, + _loj_field(fields[6]), + _loj_field(fields[7], as_int=True), + _loj_field(fields[8], as_int=True), + _loj_field(fields[9]), + _loj_field(fields[10], as_int=True), + _loj_field(fields[11]), + ) + ) + elif bed_format == "closest": if len(fields) < 13: raise ValueError( diff --git a/tests/test_binned_join.py b/tests/test_binned_join.py new file mode 100644 index 0000000..3c50757 --- /dev/null +++ b/tests/test_binned_join.py @@ -0,0 +1,1609 @@ +"""Tests for the INTERSECTS binned equi-join transpilation.""" + +import math + +import pytest + +from giql import Table +from giql import transpile + + +def _is_null(value) -> bool: + """Check if a value is null/NaN (DataFusion returns NaN for nullable int64).""" + if value is None: + return True + try: + return math.isnan(value) + except (TypeError, ValueError): + return False + + +class TestTranspileBinnedJoin: + """Unit tests for binned join SQL structure.""" + + def test_basic_binned_join_rewrite(self): + """ + GIVEN a GIQL query joining two tables with column-to-column INTERSECTS + WHEN transpiling with default settings + THEN should produce CTEs with UNNEST/range, equi-join and overlap in ON, + and DISTINCT + """ + sql = transpile( + """ + SELECT a.*, b.* + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=["peaks", "genes"], + ) + + sql_upper = sql.upper() + + # CTEs with UNNEST and range + assert "WITH" in sql_upper + assert "UNNEST" in sql_upper + assert "range" in sql or "RANGE" in sql_upper + assert "__giql_bin" in sql + + # Equi-join on chrom and bin + assert '"chrom"' in sql + assert "__giql_bin" in sql + + # Overlap filter in ON (not WHERE) for correct outer-join semantics + assert "ON" in sql_upper + assert '"start"' in sql or '"START"' in sql_upper + assert '"end"' in sql or '"END"' in sql_upper + + # DISTINCT to deduplicate across bins + assert "DISTINCT" in sql_upper + + def test_custom_bin_size(self): + """ + GIVEN a GIQL query with column-to-column INTERSECTS join + 
WHEN transpiling with bin_size=100000 + THEN should use 100000 in the range expressions + """ + sql = transpile( + """ + SELECT a.*, b.* + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=["peaks", "genes"], + bin_size=100000, + ) + + assert "100000" in sql + + def test_custom_column_mappings(self): + """ + GIVEN two tables with different custom column schemas + WHEN transpiling a binned join query + THEN should use each table's custom column names in CTEs, ON, and WHERE + """ + sql = transpile( + """ + SELECT a.*, b.* + FROM peaks a + JOIN features b ON a.interval INTERSECTS b.location + """, + tables=[ + Table( + "peaks", + genomic_col="interval", + chrom_col="chromosome", + start_col="start_pos", + end_col="end_pos", + ), + Table( + "features", + genomic_col="location", + chrom_col="seqname", + start_col="begin", + end_col="terminus", + ), + ], + ) + + # Custom column names for peaks + assert '"chromosome"' in sql + assert '"start_pos"' in sql + assert '"end_pos"' in sql + + # Custom column names for features + assert '"seqname"' in sql + assert '"begin"' in sql + assert '"terminus"' in sql + + # Default column names should NOT appear + assert '"chrom"' not in sql + assert '"start"' not in sql + assert '"end"' not in sql + + def test_literal_intersects_no_binned_ctes(self): + """ + GIVEN a GIQL query with a literal-range INTERSECTS in WHERE (not a join) + WHEN transpiling + THEN should NOT produce binned CTEs + """ + sql = transpile( + "SELECT * FROM peaks WHERE interval INTERSECTS 'chr1:1000-2000'", + tables=["peaks"], + ) + + assert "__giql_bin" not in sql + assert "UNNEST" not in sql.upper() + + def test_no_join_passthrough(self): + """ + GIVEN a simple SELECT query with no JOIN + WHEN transpiling + THEN should NOT produce binned CTEs + """ + sql = transpile( + "SELECT * FROM peaks", + tables=["peaks"], + ) + + assert "__giql_bin" not in sql + assert "UNNEST" not in sql.upper() + assert "__giql_" not in sql + + def 
test_existing_where_preserved(self): + """ + GIVEN a GIQL join query that already has a WHERE clause + WHEN transpiling a binned join + THEN should preserve the original WHERE condition alongside the overlap filter + """ + sql = transpile( + """ + SELECT a.*, b.* + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + WHERE a.score > 100 + """, + tables=["peaks", "genes"], + ) + + sql_upper = sql.upper() + + # Original WHERE condition preserved + assert "100" in sql + assert "score" in sql.lower() + + # Overlap filter also present + assert "WHERE" in sql_upper + # Both conditions combined with AND + assert "AND" in sql_upper + + def test_bin_size_none_defaults_to_10000(self): + """ + GIVEN a GIQL join query + WHEN transpiling with bin_size=None (explicit) + THEN should produce the same output as default (10000) + """ + sql_default = transpile( + """ + SELECT a.*, b.* + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=["peaks", "genes"], + ) + + sql_none = transpile( + """ + SELECT a.*, b.* + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=["peaks", "genes"], + bin_size=None, + ) + + assert sql_default == sql_none + assert "10000" in sql_default + + def test_implicit_cross_join_uses_binned_optimization(self): + """ + GIVEN a GIQL query with implicit cross-join (FROM a, b WHERE INTERSECTS) + WHEN transpiling + THEN should use the binned equi-join optimization without leaking + __giql_bin into SELECT * output columns + """ + sql = transpile( + """ + SELECT DISTINCT a.* + FROM peaks a, genes b + WHERE a.interval INTERSECTS b.interval + """, + tables=["peaks", "genes"], + ) + + # Binned CTEs are present + assert "WITH" in sql.upper() + assert "__giql_bin" in sql + assert "UNNEST" in sql.upper() + + # Original table references preserved — no CTE leak into SELECT * + assert "peaks" in sql + assert '"chrom"' in sql + assert '"start"' in sql + assert '"end"' in sql + + def 
test_self_join_single_shared_cte(self): + """ + GIVEN a self-join query where the same table appears with two aliases + WHEN transpiling a binned join + THEN should produce one shared key-only CTE for the underlying table, + joined twice through distinct connector aliases + """ + sql = transpile( + """ + SELECT a.*, b.* + FROM peaks a + JOIN peaks b ON a.interval INTERSECTS b.interval + """, + tables=["peaks"], + ) + + sql_upper = sql.upper() + + # One shared CTE keyed on the table name + assert "__giql_peaks_bins" in sql + + # Original table preserved in FROM + assert "peaks" in sql + + # Should still have DISTINCT + assert "DISTINCT" in sql_upper + + def test_invalid_bin_size_raises(self): + """ + GIVEN bin_size=0 or a negative value + WHEN calling transpile + THEN should raise ValueError + """ + with pytest.raises(ValueError, match="positive"): + transpile( + "SELECT * FROM a JOIN b ON a.interval INTERSECTS b.interval", + tables=["a", "b"], + bin_size=0, + ) + + with pytest.raises(ValueError, match="positive"): + transpile( + "SELECT * FROM a JOIN b ON a.interval INTERSECTS b.interval", + tables=["a", "b"], + bin_size=-1, + ) + + def test_multi_join_all_intersects_rewritten(self): + """ + GIVEN a three-way join with two INTERSECTS conditions + WHEN transpiling + THEN should create one key-only CTE per underlying table and rewrite + each INTERSECTS join as a three-join bridge through those CTEs + """ + sql = transpile( + """ + SELECT a.*, b.*, c.* + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + JOIN exons c ON a.interval INTERSECTS c.interval + """, + tables=["peaks", "genes", "exons"], + ) + + # One CTE per underlying table + assert "__giql_peaks_bins" in sql + assert "__giql_genes_bins" in sql + assert "__giql_exons_bins" in sql + + # __giql_bin appears in CTE definitions and ON conditions + sql_upper = sql.upper() + assert sql_upper.count("__GIQL_BIN") >= 4 # at least 2 per INTERSECTS join + + def 
test_explicit_columns_uses_full_cte_not_bridge(self): + """ + GIVEN a join query with only explicit columns in SELECT (no wildcards) + WHEN transpiling + THEN should use the 1-join full-CTE approach, not bridge CTEs + """ + sql = transpile( + """ + SELECT a.chrom, a.start, b.start AS b_start + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=["peaks", "genes"], + ) + + # Full-CTE names (alias-based, one join per INTERSECTS) + assert "__giql_a_binned" in sql + assert "__giql_b_binned" in sql + + # Bridge CTEs must NOT be present + assert "__giql_peaks_bins" not in sql + assert "__giql_c0" not in sql + + def test_wildcard_select_uses_bridge_not_full_cte(self): + """ + GIVEN a join query with a wildcard expression in SELECT (a.*) + WHEN transpiling + THEN should use the bridge CTE approach, not full-CTEs + """ + sql = transpile( + """ + SELECT a.*, b.* + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=["peaks", "genes"], + ) + + # Bridge CTE names (table-based) + assert "__giql_peaks_bins" in sql + assert "__giql_genes_bins" in sql + + # Full CTEs must NOT be present + assert "__giql_a_binned" not in sql + assert "__giql_b_binned" not in sql + + +class TestBinnedJoinDataFusion: + """End-to-end DataFusion correctness tests for binned INTERSECTS joins.""" + + @staticmethod + def _make_ctx(peaks_data, genes_data): + """Create a DataFusion context with two interval tables.""" + import pyarrow as pa + from datafusion import SessionContext + + schema = pa.schema( + [ + ("chrom", pa.utf8()), + ("start", pa.int64()), + ("end", pa.int64()), + ] + ) + + ctx = SessionContext() + + peaks_arrays = { + "chrom": [r[0] for r in peaks_data], + "start": [r[1] for r in peaks_data], + "end": [r[2] for r in peaks_data], + } + genes_arrays = { + "chrom": [r[0] for r in genes_data], + "start": [r[1] for r in genes_data], + "end": [r[2] for r in genes_data], + } + + ctx.register_record_batches( + "peaks", [pa.table(peaks_arrays, 
schema=schema).to_batches()] + ) + ctx.register_record_batches( + "genes", [pa.table(genes_arrays, schema=schema).to_batches()] + ) + return ctx + + def test_overlapping_intervals_correct_rows_no_duplicates(self): + """ + GIVEN two tables with overlapping intervals + WHEN executing a binned INTERSECTS join via DataFusion + THEN should return the correct matching rows with no duplicates + """ + ctx = self._make_ctx( + peaks_data=[("chr1", 100, 500), ("chr1", 1000, 2000)], + genes_data=[("chr1", 300, 600), ("chr1", 5000, 6000)], + ) + + sql = transpile( + """ + SELECT a.chrom, a.start, a."end", b.start AS b_start, b."end" AS b_end + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + df = ctx.sql(sql).to_pandas() + + # Only the first peak (100-500) overlaps the first gene (300-600) + assert len(df) == 1 + assert df.iloc[0]["start"] == 100 + assert df.iloc[0]["b_start"] == 300 + + def test_non_overlapping_intervals_zero_rows(self): + """ + GIVEN two tables with no overlapping intervals + WHEN executing a binned INTERSECTS join via DataFusion + THEN should return zero rows + """ + ctx = self._make_ctx( + peaks_data=[("chr1", 100, 200), ("chr1", 300, 400)], + genes_data=[("chr1", 500, 600), ("chr1", 700, 800)], + ) + + sql = transpile( + """ + SELECT a.chrom, a.start, a."end" + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + df = ctx.sql(sql).to_pandas() + assert len(df) == 0 + + def test_adjacent_intervals_zero_rows_half_open(self): + """ + GIVEN two tables with adjacent (touching) intervals under half-open coordinates + WHEN executing a binned INTERSECTS join via DataFusion + THEN should return zero 
rows because [100, 200) and [200, 300) do not overlap + """ + ctx = self._make_ctx( + peaks_data=[("chr1", 100, 200)], + genes_data=[("chr1", 200, 300)], + ) + + sql = transpile( + """ + SELECT a.chrom, a.start, a."end" + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + df = ctx.sql(sql).to_pandas() + assert len(df) == 0 + + def test_different_chromosomes_only_same_chrom(self): + """ + GIVEN two tables with intervals on different chromosomes that would overlap positionally + WHEN executing a binned INTERSECTS join via DataFusion + THEN should only return overlaps on the same chromosome + """ + ctx = self._make_ctx( + peaks_data=[("chr1", 100, 500), ("chr2", 100, 500)], + genes_data=[("chr1", 200, 400), ("chr3", 200, 400)], + ) + + sql = transpile( + """ + SELECT a.chrom, a.start, a."end", + b.chrom AS b_chrom, b.start AS b_start, b."end" AS b_end + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + df = ctx.sql(sql).to_pandas() + + assert len(df) == 1 + assert df.iloc[0]["chrom"] == "chr1" + assert df.iloc[0]["b_chrom"] == "chr1" + + def test_intervals_spanning_multiple_bins_no_duplicates(self): + """ + GIVEN intervals that span multiple bins + WHEN executing a binned INTERSECTS join via DataFusion + THEN overlapping pairs should be returned exactly once (DISTINCT dedup) + """ + ctx = self._make_ctx( + peaks_data=[("chr1", 0, 50000)], + genes_data=[("chr1", 25000, 75000)], + ) + + sql = transpile( + """ + SELECT a.chrom, a.start, a."end", b.start AS b_start, b."end" AS b_end + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", 
start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + bin_size=10000, + ) + + df = ctx.sql(sql).to_pandas() + + # Despite sharing multiple bins (2, 3, 4), should appear exactly once + assert len(df) == 1 + assert df.iloc[0]["start"] == 0 + assert df.iloc[0]["end"] == 50000 + assert df.iloc[0]["b_start"] == 25000 + assert df.iloc[0]["b_end"] == 75000 + + def test_equivalence_with_naive_cross_join(self): + """ + GIVEN two tables with a mix of overlapping and non-overlapping intervals + WHEN executing a binned INTERSECTS join via DataFusion + THEN results should match a naive cross-join with overlap filter + """ + ctx = self._make_ctx( + peaks_data=[ + ("chr1", 0, 100), + ("chr1", 150, 300), + ("chr1", 500, 1000), + ("chr2", 0, 500), + ], + genes_data=[ + ("chr1", 50, 200), + ("chr1", 250, 600), + ("chr1", 900, 1100), + ("chr2", 400, 800), + ], + ) + + naive_sql = """ + SELECT a.chrom AS a_chrom, a.start AS a_start, a."end" AS a_end, + b.chrom AS b_chrom, b.start AS b_start, b."end" AS b_end + FROM peaks a, genes b + WHERE a.chrom = b.chrom + AND a.start < b."end" + AND a."end" > b.start + ORDER BY a.chrom, a.start, b.start + """ + naive_df = ctx.sql(naive_sql).to_pandas() + + binned_sql = transpile( + """ + SELECT a.chrom AS a_chrom, a.start AS a_start, a."end" AS a_end, + b.chrom AS b_chrom, b.start AS b_start, b."end" AS b_end + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + binned_df = ( + ctx.sql(binned_sql) + .to_pandas() + .sort_values(by=["a_chrom", "a_start", "b_start"]) + .reset_index(drop=True) + ) + naive_df = naive_df.reset_index(drop=True) + + assert len(binned_df) == len(naive_df) + assert binned_df.values.tolist() == naive_df.values.tolist() + + def test_implicit_cross_join_correct_rows_no_bin_leak(self): + 
""" + GIVEN two tables with overlapping intervals queried via implicit cross-join syntax + WHEN executing a binned INTERSECTS join via DataFusion + THEN results should be correct and SELECT a.* should not include __giql_bin + """ + ctx = self._make_ctx( + peaks_data=[("chr1", 100, 500), ("chr1", 1000, 2000)], + genes_data=[("chr1", 300, 600), ("chr1", 5000, 6000)], + ) + + sql = transpile( + """ + SELECT DISTINCT a.* + FROM peaks a, genes b + WHERE a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + df = ctx.sql(sql).to_pandas() + + # Only the first peak (100-500) overlaps the first gene (300-600) + assert len(df) == 1 + assert df.iloc[0]["start"] == 100 + + # SELECT a.* must return exactly the original table columns — no __giql_bin + assert list(df.columns) == ["chrom", "start", "end"] + + +class TestBinnedJoinOuterJoinSemantics: + """Regression tests: outer join kinds must be preserved after rewrite. + + Bug: the bridge path only applied the join kind (LEFT, RIGHT, FULL) to + join3, while join1 and join2 were always INNER — silently converting + outer joins into inner joins. 
+ """ + + @staticmethod + def _make_ctx(peaks_data, genes_data): + """Create a DataFusion context with peaks and genes tables.""" + import pyarrow as pa + from datafusion import SessionContext + + schema = pa.schema( + [ + ("chrom", pa.utf8()), + ("start", pa.int64()), + ("end", pa.int64()), + ] + ) + ctx = SessionContext() + ctx.register_record_batches( + "peaks", + [ + pa.table( + { + "chrom": [r[0] for r in peaks_data], + "start": [r[1] for r in peaks_data], + "end": [r[2] for r in peaks_data], + }, + schema=schema, + ).to_batches() + ], + ) + ctx.register_record_batches( + "genes", + [ + pa.table( + { + "chrom": [r[0] for r in genes_data], + "start": [r[1] for r in genes_data], + "end": [r[2] for r in genes_data], + }, + schema=schema, + ).to_batches() + ], + ) + return ctx + + def test_left_join_preserves_unmatched_left_rows_full_cte(self): + """ + GIVEN peaks with one matching and one non-matching interval + WHEN a LEFT JOIN with INTERSECTS is transpiled (no wildcards, full-CTE path) + THEN the SQL must contain LEFT keyword and execution must return all + left rows including unmatched ones with NULLs on the right + """ + sql = transpile( + """ + SELECT a.chrom, a.start, a."end", b.start AS b_start + FROM peaks a + LEFT JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + assert "LEFT" in sql.upper() + + ctx = self._make_ctx( + peaks_data=[("chr1", 100, 500), ("chr1", 1000, 2000)], + genes_data=[("chr1", 300, 600)], + ) + df = ctx.sql(sql).to_pandas().sort_values("start").reset_index(drop=True) + + assert len(df) == 2 + assert df.iloc[0]["start"] == 100 + assert df.iloc[0]["b_start"] == 300 + assert df.iloc[1]["start"] == 1000 + assert _is_null(df.iloc[1]["b_start"]) + + def test_left_join_preserves_unmatched_left_rows_bridge(self): + """ + GIVEN peaks with one matching and one non-matching 
interval + WHEN a LEFT JOIN with INTERSECTS is transpiled (wildcards, bridge path) + THEN the SQL must contain LEFT keyword and execution must return all + left rows including unmatched ones with NULLs on the right + """ + sql = transpile( + """ + SELECT a.*, b.start AS b_start + FROM peaks a + LEFT JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + assert "LEFT" in sql.upper() + + ctx = self._make_ctx( + peaks_data=[("chr1", 100, 500), ("chr1", 1000, 2000)], + genes_data=[("chr1", 300, 600)], + ) + df = ctx.sql(sql).to_pandas().sort_values("start").reset_index(drop=True) + + assert len(df) == 2 + assert df.iloc[0]["start"] == 100 + assert df.iloc[0]["b_start"] == 300 + assert df.iloc[1]["start"] == 1000 + assert _is_null(df.iloc[1]["b_start"]) + + def test_right_join_preserves_unmatched_right_rows_full_cte(self): + """ + GIVEN genes with one matching and one non-matching interval + WHEN a RIGHT JOIN with INTERSECTS is transpiled (no wildcards, full-CTE path) + THEN the SQL must contain RIGHT keyword and execution must return all + right rows including unmatched ones with NULLs on the left + """ + sql = transpile( + """ + SELECT a.start AS a_start, b.chrom, b.start, b."end" + FROM peaks a + RIGHT JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + assert "RIGHT" in sql.upper() + + ctx = self._make_ctx( + peaks_data=[("chr1", 100, 500)], + genes_data=[("chr1", 300, 600), ("chr1", 5000, 6000)], + ) + df = ctx.sql(sql).to_pandas().sort_values("start").reset_index(drop=True) + + assert len(df) == 2 + matched = df[df["a_start"].notna()] + unmatched = df[df["a_start"].isna()] + assert len(matched) == 1 + assert matched.iloc[0]["start"] 
== 300 + assert len(unmatched) == 1 + assert unmatched.iloc[0]["start"] == 5000 + + def test_right_join_preserves_unmatched_right_rows_bridge(self): + """ + GIVEN genes with one matching and one non-matching interval + WHEN a RIGHT JOIN with INTERSECTS is transpiled (wildcards, bridge path) + THEN the SQL must contain RIGHT keyword and execution must return all + right rows including unmatched ones with NULLs on the left + """ + sql = transpile( + """ + SELECT a.start AS a_start, b.* + FROM peaks a + RIGHT JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + assert "RIGHT" in sql.upper() + + ctx = self._make_ctx( + peaks_data=[("chr1", 100, 500)], + genes_data=[("chr1", 300, 600), ("chr1", 5000, 6000)], + ) + df = ctx.sql(sql).to_pandas().sort_values("start").reset_index(drop=True) + + assert len(df) == 2 + matched = df[df["a_start"].notna()] + unmatched = df[df["a_start"].isna()] + assert len(matched) == 1 + assert matched.iloc[0]["start"] == 300 + assert len(unmatched) == 1 + assert unmatched.iloc[0]["start"] == 5000 + + def test_full_outer_join_preserves_both_unmatched_full_cte(self): + """ + GIVEN peaks and genes each with one matching and one non-matching interval + WHEN a FULL OUTER JOIN with INTERSECTS is transpiled (no wildcards, full-CTE) + THEN the SQL must contain FULL keyword and execution must return three + rows: one matched pair plus one unmatched from each side + """ + sql = transpile( + """ + SELECT a.start AS a_start, a."end" AS a_end, + b.start AS b_start, b."end" AS b_end + FROM peaks a + FULL OUTER JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + assert "FULL" in sql.upper() + + ctx = self._make_ctx( + 
peaks_data=[("chr1", 100, 500), ("chr1", 8000, 9000)], + genes_data=[("chr1", 300, 600), ("chr1", 5000, 6000)], + ) + df = ctx.sql(sql).to_pandas() + + assert len(df) == 3 + matched = df[df["a_start"].notna() & df["b_start"].notna()] + left_only = df[df["a_start"].notna() & df["b_start"].isna()] + right_only = df[df["a_start"].isna() & df["b_start"].notna()] + assert len(matched) == 1 + assert len(left_only) == 1 + assert len(right_only) == 1 + + def test_full_outer_join_preserves_both_unmatched_bridge(self): + """ + GIVEN peaks and genes each with one matching and one non-matching interval + WHEN a FULL OUTER JOIN with INTERSECTS is transpiled (wildcards, bridge path) + THEN the SQL must contain FULL keyword and execution must return three + rows: one matched pair plus one unmatched from each side + """ + sql = transpile( + """ + SELECT a.*, b.start AS b_start, b."end" AS b_end + FROM peaks a + FULL OUTER JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + assert "FULL" in sql.upper() + + ctx = self._make_ctx( + peaks_data=[("chr1", 100, 500), ("chr1", 8000, 9000)], + genes_data=[("chr1", 300, 600), ("chr1", 5000, 6000)], + ) + df = ctx.sql(sql).to_pandas() + + assert len(df) == 3 + matched = df[df["start"].notna() & df["b_start"].notna()] + left_only = df[df["start"].notna() & df["b_start"].isna()] + right_only = df[df["start"].isna() & df["b_start"].notna()] + assert len(matched) == 1 + assert len(left_only) == 1 + assert len(right_only) == 1 + + def test_left_join_all_unmatched_returns_all_left_rows(self): + """ + GIVEN peaks where no intervals overlap any gene + WHEN a LEFT JOIN with INTERSECTS is transpiled + THEN all left rows must still appear with NULLs on the right + """ + sql = transpile( + """ + SELECT a.chrom, a.start, a."end", b.start AS b_start + FROM peaks a + LEFT JOIN genes b ON 
a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + ctx = self._make_ctx( + peaks_data=[("chr1", 100, 200), ("chr1", 300, 400)], + genes_data=[("chr1", 500, 600)], + ) + df = ctx.sql(sql).to_pandas() + + assert len(df) == 2 + assert df["b_start"].isna().all() + + +class TestBinnedJoinAdditionalOnConditions: + """Regression tests: non-INTERSECTS conditions in ON must be preserved. + + Bug: the rewrite replaces the entire ON clause with the binned equi-join + and overlap predicate, silently dropping any additional user conditions + like ``AND a.score > b.score``. + """ + + @staticmethod + def _make_ctx_with_score(): + """Create a DataFusion context with peaks and genes tables that include a score column.""" + import pyarrow as pa + from datafusion import SessionContext + + schema = pa.schema( + [ + ("chrom", pa.utf8()), + ("start", pa.int64()), + ("end", pa.int64()), + ("score", pa.int64()), + ] + ) + ctx = SessionContext() + ctx.register_record_batches( + "peaks", + [ + pa.table( + { + "chrom": ["chr1", "chr1"], + "start": [100, 100], + "end": [500, 500], + "score": [10, 50], + }, + schema=schema, + ).to_batches() + ], + ) + ctx.register_record_batches( + "genes", + [ + pa.table( + { + "chrom": ["chr1", "chr1"], + "start": [200, 200], + "end": [600, 600], + "score": [30, 30], + }, + schema=schema, + ).to_batches() + ], + ) + return ctx + + def test_additional_on_condition_preserved_full_cte(self): + """ + GIVEN two overlapping intervals where only one pair satisfies score filter + WHEN INTERSECTS is combined with a.score > b.score in ON (no wildcards) + THEN the additional condition must survive the rewrite and filter results + """ + sql = transpile( + """ + SELECT a.chrom, a.start, a."end", a.score AS a_score, b.score AS b_score + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval AND a.score > 
b.score + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + assert "score" in sql.lower() + + ctx = self._make_ctx_with_score() + df = ctx.sql(sql).to_pandas() + + assert len(df) == 1 + assert df.iloc[0]["a_score"] == 50 + + def test_additional_on_condition_preserved_bridge(self): + """ + GIVEN two overlapping intervals where only one pair satisfies score filter + WHEN INTERSECTS is combined with a.score > b.score in ON (wildcards) + THEN the additional condition must survive the rewrite and filter results + """ + sql = transpile( + """ + SELECT a.*, b.score AS b_score + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval AND a.score > b.score + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + assert "score" in sql.lower() + + ctx = self._make_ctx_with_score() + df = ctx.sql(sql).to_pandas() + + assert len(df) == 1 + assert df.iloc[0]["score"] == 50 + + def test_additional_on_condition_with_left_join(self): + """ + GIVEN overlapping intervals with an extra ON condition that filters all matches + WHEN LEFT JOIN with INTERSECTS AND a.score > b.score is used + THEN unmatched left rows appear with NULL right columns + """ + sql = transpile( + """ + SELECT a.chrom, a.start, a.score AS a_score, b.score AS b_score + FROM peaks a + LEFT JOIN genes b + ON a.interval INTERSECTS b.interval AND a.score > b.score + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + ctx = self._make_ctx_with_score() + df = ctx.sql(sql).to_pandas().sort_values("a_score").reset_index(drop=True) + + assert len(df) == 2 + row_low = df[df["a_score"] == 10].iloc[0] + row_high = df[df["a_score"] == 50].iloc[0] + assert 
_is_null(row_low["b_score"]) + assert row_high["b_score"] == 30 + + def test_multiple_additional_conditions_preserved(self): + """ + GIVEN overlapping intervals with two extra ON conditions + WHEN INTERSECTS is combined with a.score > 20 AND b.score < 40 in ON + THEN both conditions must survive the rewrite + """ + sql = transpile( + """ + SELECT a.chrom, a.start, a.score AS a_score, b.score AS b_score + FROM peaks a + JOIN genes b + ON a.interval INTERSECTS b.interval + AND a.score > 20 + AND b.score < 40 + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + sql_lower = sql.lower() + assert "score" in sql_lower + assert "20" in sql + assert "40" in sql + + def test_additional_on_condition_implicit_cross_join(self): + """ + GIVEN overlapping intervals queried via implicit cross-join with extra WHERE + WHEN INTERSECTS is in WHERE alongside a.score > b.score + THEN the score condition must be preserved in the output SQL + """ + sql = transpile( + """ + SELECT a.chrom, a.start, a.score AS a_score, b.score AS b_score + FROM peaks a, genes b + WHERE a.interval INTERSECTS b.interval AND a.score > b.score + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + assert "score" in sql.lower() + + ctx = self._make_ctx_with_score() + df = ctx.sql(sql).to_pandas() + + assert len(df) == 1 + assert df.iloc[0]["a_score"] == 50 + + +class TestBinnedJoinDistinctSemantics: + """Regression tests: unconditional DISTINCT can collapse legitimate duplicates. + + Bug: the transformer always adds DISTINCT to deduplicate bin fan-out, + but this also collapses rows that are genuinely duplicated in the source + data, changing SQL bag semantics. 
+ """ + + @staticmethod + def _make_ctx_with_duplicates(): + """Create a DataFusion context where peaks has duplicate rows.""" + import pyarrow as pa + from datafusion import SessionContext + + schema = pa.schema( + [ + ("chrom", pa.utf8()), + ("start", pa.int64()), + ("end", pa.int64()), + ] + ) + ctx = SessionContext() + ctx.register_record_batches( + "peaks", + [ + pa.table( + { + "chrom": ["chr1", "chr1"], + "start": [100, 100], + "end": [500, 500], + }, + schema=schema, + ).to_batches() + ], + ) + ctx.register_record_batches( + "genes", + [ + pa.table( + { + "chrom": ["chr1"], + "start": [200], + "end": [600], + }, + schema=schema, + ).to_batches() + ], + ) + return ctx + + @pytest.mark.xfail( + reason="Unconditional DISTINCT collapses legitimate duplicate rows", + strict=True, + ) + def test_duplicate_rows_preserved_full_cte(self): + """ + GIVEN peaks with two identical rows that both overlap one gene + WHEN an inner join with INTERSECTS is transpiled (no wildcards, full-CTE) + THEN both rows should be returned, matching naive cross-join behavior + """ + ctx = self._make_ctx_with_duplicates() + + naive_sql = """ + SELECT a.chrom, a.start, a."end", b.start AS b_start + FROM peaks a, genes b + WHERE a.chrom = b.chrom AND a.start < b."end" AND a."end" > b.start + """ + naive_df = ctx.sql(naive_sql).to_pandas() + assert len(naive_df) == 2 + + binned_sql = transpile( + """ + SELECT a.chrom, a.start, a."end", b.start AS b_start + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + binned_df = ctx.sql(binned_sql).to_pandas() + assert len(binned_df) == len(naive_df) + + @pytest.mark.xfail( + reason="Unconditional DISTINCT collapses legitimate duplicate rows", + strict=True, + ) + def test_duplicate_rows_preserved_bridge(self): + """ + GIVEN peaks with two identical rows that both 
overlap one gene + WHEN an inner join with INTERSECTS is transpiled (wildcards, bridge path) + THEN both rows should be returned, matching naive cross-join behavior + """ + ctx = self._make_ctx_with_duplicates() + + naive_sql = """ + SELECT a.chrom, a.start, a."end", b.start AS b_start + FROM peaks a, genes b + WHERE a.chrom = b.chrom AND a.start < b."end" AND a."end" > b.start + """ + naive_df = ctx.sql(naive_sql).to_pandas() + assert len(naive_df) == 2 + + binned_sql = transpile( + """ + SELECT a.*, b.start AS b_start + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + binned_df = ctx.sql(binned_sql).to_pandas() + assert len(binned_df) == len(naive_df) + + def test_non_duplicate_rows_unaffected(self): + """ + GIVEN peaks with two distinct rows that both overlap one gene + WHEN an inner join with INTERSECTS is transpiled + THEN DISTINCT does not collapse them because they differ + """ + import pyarrow as pa + from datafusion import SessionContext + + schema = pa.schema( + [ + ("chrom", pa.utf8()), + ("start", pa.int64()), + ("end", pa.int64()), + ] + ) + ctx = SessionContext() + ctx.register_record_batches( + "peaks", + [ + pa.table( + { + "chrom": ["chr1", "chr1"], + "start": [100, 150], + "end": [500, 550], + }, + schema=schema, + ).to_batches() + ], + ) + ctx.register_record_batches( + "genes", + [ + pa.table( + { + "chrom": ["chr1"], + "start": [200], + "end": [600], + }, + schema=schema, + ).to_batches() + ], + ) + + binned_sql = transpile( + """ + SELECT a.chrom, a.start, a."end", b.start AS b_start + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.interval + """, + tables=[ + Table("peaks", chrom_col="chrom", start_col="start", end_col="end"), + Table("genes", chrom_col="chrom", start_col="start", end_col="end"), + ], + ) + + df = ctx.sql(binned_sql).to_pandas() + 
        assert len(df) == 2
+
+    def test_user_distinct_already_present_still_works(self):
+        """
+        GIVEN a query that already has SELECT DISTINCT
+        WHEN the binned join rewrite also adds DISTINCT
+        THEN the query must still execute correctly (no double-DISTINCT error)
+        """
+        import pyarrow as pa
+        from datafusion import SessionContext
+
+        schema = pa.schema(
+            [
+                ("chrom", pa.utf8()),
+                ("start", pa.int64()),
+                ("end", pa.int64()),
+            ]
+        )
+        ctx = SessionContext()
+        ctx.register_record_batches(
+            "peaks",
+            [
+                pa.table(
+                    {"chrom": ["chr1"], "start": [100], "end": [500]},
+                    schema=schema,
+                ).to_batches()
+            ],
+        )
+        ctx.register_record_batches(
+            "genes",
+            [
+                pa.table(
+                    {"chrom": ["chr1"], "start": [200], "end": [600]},
+                    schema=schema,
+                ).to_batches()
+            ],
+        )
+
+        binned_sql = transpile(
+            """
+            SELECT DISTINCT a.chrom, a.start, b.start AS b_start
+            FROM peaks a
+            JOIN genes b ON a.interval INTERSECTS b.interval
+            """,
+            tables=[
+                Table("peaks", chrom_col="chrom", start_col="start", end_col="end"),
+                Table("genes", chrom_col="chrom", start_col="start", end_col="end"),
+            ],
+        )
+
+        df = ctx.sql(binned_sql).to_pandas()
+        assert len(df) == 1
+
+
+class TestBinnedJoinBinBoundaryRounding:
+    """Regression tests for bin-index calculation rounding errors.
+
+    The original formula CAST(start / B AS BIGINT) uses float division
+    followed by a cast. When the division lands on x.5 the cast rounds
+    to nearest instead of flooring, producing the wrong bin index
+    and causing missed matches.
+ """ + + @staticmethod + def _make_ctx(table_a_data, table_b_data): + import pyarrow as pa + from datafusion import SessionContext + + schema = pa.schema( + [ + ("chrom", pa.utf8()), + ("start", pa.int64()), + ("end", pa.int64()), + ("name", pa.utf8()), + ] + ) + ctx = SessionContext() + ctx.register_record_batches( + "intervals_a", + [pa.table(table_a_data, schema=schema).to_batches()], + ) + ctx.register_record_batches( + "intervals_b", + [pa.table(table_b_data, schema=schema).to_batches()], + ) + return ctx + + def test_half_bin_boundary_overlap_not_missed(self): + """ + GIVEN interval A spanning many bins and interval B whose start + falls exactly on a .5 division boundary (e.g., 621950/100) + WHEN INTERSECTS is evaluated with bin_size=100 on DuckDB + THEN the overlap must be found, not missed due to rounding + """ + import duckdb + + conn = duckdb.connect(":memory:") + conn.execute( + "CREATE TABLE intervals_a " + '(chrom VARCHAR, "start" INTEGER, "end" INTEGER, ' + "name VARCHAR)" + ) + conn.execute( + "CREATE TABLE intervals_b " + '(chrom VARCHAR, "start" INTEGER, "end" INTEGER, ' + "name VARCHAR)" + ) + conn.execute("INSERT INTO intervals_a VALUES ('chr1', 421951, 621951, 'a0')") + conn.execute("INSERT INTO intervals_b VALUES ('chr1', 621950, 621951, 'b0')") + + sql = transpile( + """ + SELECT DISTINCT a.name, b.name AS b_name + FROM intervals_a a + JOIN intervals_b b ON a.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b"], + bin_size=100, + ) + result = conn.execute(sql).fetchall() + conn.close() + assert len(result) == 1, ( + f"Expected 1 match, got {len(result)} — " + f"bin boundary rounding likely dropped the overlap" + ) + + def test_exact_bin_boundary_start(self): + """ + GIVEN interval B starting at an exact multiple of bin_size + WHEN INTERSECTS is evaluated on DuckDB + THEN the correct bin index is assigned (no off-by-one from rounding) + """ + import duckdb + + conn = duckdb.connect(":memory:") + conn.execute( + "CREATE 
TABLE intervals_a " + '(chrom VARCHAR, "start" INTEGER, "end" INTEGER, ' + "name VARCHAR)" + ) + conn.execute( + "CREATE TABLE intervals_b " + '(chrom VARCHAR, "start" INTEGER, "end" INTEGER, ' + "name VARCHAR)" + ) + conn.execute("INSERT INTO intervals_a VALUES ('chr1', 999, 1001, 'a0')") + conn.execute("INSERT INTO intervals_b VALUES ('chr1', 1000, 1001, 'b0')") + + sql = transpile( + """ + SELECT DISTINCT a.name, b.name AS b_name + FROM intervals_a a + JOIN intervals_b b ON a.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b"], + bin_size=1000, + ) + result = conn.execute(sql).fetchall() + conn.close() + assert len(result) == 1, f"Expected 1 match at bin boundary, got {len(result)}" + + +class TestBinnedJoinOuterJoinMultiBin: + """Regression tests for outer join with multi-bin intervals. + + When an interval spans multiple bins, the outer join produces one + row per bin. Bins that don't match the other side create spurious + NULL rows. DISTINCT can't collapse a NULL row with a matched row + because they differ in the non-NULL columns. 
+ """ + + @staticmethod + def _make_ctx(table_a_data, table_b_data): + import pyarrow as pa + from datafusion import SessionContext + + schema = pa.schema( + [ + ("chrom", pa.utf8()), + ("start", pa.int64()), + ("end", pa.int64()), + ("name", pa.utf8()), + ] + ) + ctx = SessionContext() + ctx.register_record_batches( + "intervals_a", + [pa.table(table_a_data, schema=schema).to_batches()], + ) + ctx.register_record_batches( + "intervals_b", + [pa.table(table_b_data, schema=schema).to_batches()], + ) + return ctx + + def test_left_join_no_spurious_null_row(self): + """ + GIVEN interval A spanning bins 0 and 1 and interval B only in bin 1 + WHEN LEFT JOIN INTERSECTS is evaluated + THEN only 1 matched row is returned, not a matched row plus a + spurious NULL row from the unmatched bin-0 copy + """ + ctx = self._make_ctx( + { + "chrom": ["chr1"], + "start": [9000], + "end": [11000], + "name": ["a0"], + }, + { + "chrom": ["chr1"], + "start": [10500], + "end": [10600], + "name": ["b0"], + }, + ) + + sql = transpile( + """ + SELECT a.name, b.name AS b_name + FROM intervals_a a + LEFT JOIN intervals_b b ON a.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b"], + ) + result = ctx.sql(sql).to_pandas() + assert len(result) == 1, ( + f"Expected 1 matched row, got {len(result)} — " + f"spurious NULL row from unmatched bin" + ) + assert result.iloc[0]["b_name"] == "b0" + + def test_left_join_unmatched_still_returns_null(self): + """ + GIVEN interval A with no overlap in B + WHEN LEFT JOIN INTERSECTS is evaluated + THEN one row with NULL B columns is returned + """ + ctx = self._make_ctx( + { + "chrom": ["chr1"], + "start": [9000], + "end": [11000], + "name": ["a0"], + }, + { + "chrom": ["chr2"], + "start": [9500], + "end": [10500], + "name": ["b0"], + }, + ) + + sql = transpile( + """ + SELECT a.name, b.name AS b_name + FROM intervals_a a + LEFT JOIN intervals_b b ON a.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b"], + ) + 
        result = ctx.sql(sql).to_pandas()
+        assert len(result) == 1, f"Expected 1 unmatched row, got {len(result)}"
+        assert _is_null(result.iloc[0]["b_name"])
+
+    def test_right_join_no_spurious_null_row(self):
+        """
+        GIVEN interval B spanning bins 0 and 1 and interval A only in bin 0
+        WHEN RIGHT JOIN INTERSECTS is evaluated
+        THEN only 1 matched row is returned, not a matched row plus a
+        spurious NULL row from the unmatched bin-1 copy of B
+        """
+        ctx = self._make_ctx(
+            {
+                "chrom": ["chr1"],
+                "start": [9500],
+                "end": [9600],
+                "name": ["a0"],
+            },
+            {
+                "chrom": ["chr1"],
+                "start": [9000],
+                "end": [11000],
+                "name": ["b0"],
+            },
+        )
+
+        sql = transpile(
+            """
+            SELECT a.name, b.name AS b_name
+            FROM intervals_a a
+            RIGHT JOIN intervals_b b ON a.interval INTERSECTS b.interval
+            """,
+            tables=["intervals_a", "intervals_b"],
+        )
+        result = ctx.sql(sql).to_pandas()
+        assert len(result) == 1, (
+            f"Expected 1 matched row, got {len(result)} — "
+            f"spurious NULL row from unmatched bin"
+        )
+        assert result.iloc[0]["name"] == "a0"
+
+    def test_full_outer_join_no_spurious_null_row(self):
+        """
+        GIVEN interval A spanning bins 0 and 1, interval B only in bin 1
+        WHEN FULL OUTER JOIN INTERSECTS is evaluated
+        THEN only 1 matched row is returned, not a matched row plus a
+        spurious NULL row
+        """
+        ctx = self._make_ctx(
+            {
+                "chrom": ["chr1"],
+                "start": [9000],
+                "end": [11000],
+                "name": ["a0"],
+            },
+            {
+                "chrom": ["chr1"],
+                "start": [10500],
+                "end": [10600],
+                "name": ["b0"],
+            },
+        )
+
+        sql = transpile(
+            """
+            SELECT a.name, b.name AS b_name
+            FROM intervals_a a
+            FULL OUTER JOIN intervals_b b ON a.interval INTERSECTS b.interval
+            """,
+            tables=["intervals_a", "intervals_b"],
+        )
+        result = ctx.sql(sql).to_pandas()
+        assert len(result) == 1, (
+            f"Expected 1 matched row, got {len(result)} — "
+            f"spurious NULL row from unmatched bin"
+        )