diff --git a/docs/dialect/spatial-operators.rst b/docs/dialect/spatial-operators.rst
index 6bf4433..1a8b3cd 100644
--- a/docs/dialect/spatial-operators.rst
+++ b/docs/dialect/spatial-operators.rst
@@ -99,6 +99,24 @@ Find all variants, with gene information where available:
FROM variants v
LEFT JOIN genes g ON v.interval INTERSECTS g.interval
+Deduplication Behavior
+~~~~~~~~~~~~~~~~~~~~~~
+
+Column-to-column ``INTERSECTS`` joins use a binned equi-join strategy internally: each interval is assigned to one or more fixed-width bins, and the join is performed on ``(chrom, bin)`` pairs. Because an interval that spans a bin boundary belongs to more than one bin, a single source row can match the same result row more than once. GIQL adds ``SELECT DISTINCT`` automatically to remove these duplicate rows.
+
+This deduplication is usually transparent, but it has one observable side effect: ``DISTINCT`` operates on the entire set of selected columns, so rows that are genuinely identical across every selected column will also be collapsed into one. This matters when a table contains duplicate source records with no distinguishing column.
+
+To prevent unintended deduplication, include any column that makes rows distinguishable — such as a primary key, name, or score — in the ``SELECT`` list:
+
+.. code-block:: sql
+
+ -- score distinguishes otherwise-identical rows
+ SELECT v.chrom, v.start, v.end, v.score, g.name
+ FROM variants v
+ INNER JOIN genes g ON v.interval INTERSECTS g.interval
+
+If all columns are identical across two source rows (including any unique identifier), those rows represent the same logical record and collapsing them is correct behavior.
+
Related Operators
~~~~~~~~~~~~~~~~~
diff --git a/pyproject.toml b/pyproject.toml
index 480cf1e..647358b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -43,6 +43,7 @@ dev = [
"pytest-cov>=4.0.0",
"pytest>=7.0.0",
"ruff>=0.1.0",
+ "datafusion>=43.0.0",
]
docs = [
"sphinx>=7.0",
@@ -82,6 +83,7 @@ bedtools = ">=2.31.0"
pybedtools = ">=0.9.0"
pytest = ">=7.0.0"
pytest-cov = ">=4.0.0"
+datafusion = ">=43.0.0"
duckdb = ">=1.4.0"
pandas = ">=2.0.0"
sqlglot = ">=20.0.0,<30"
diff --git a/src/giql/__init__.py b/src/giql/__init__.py
index 71e895d..e5df351 100644
--- a/src/giql/__init__.py
+++ b/src/giql/__init__.py
@@ -3,6 +3,7 @@
A SQL dialect for genomic range queries.
"""
+from giql.constants import DEFAULT_BIN_SIZE
from giql.table import Table
from giql.transpile import transpile
@@ -10,6 +11,7 @@
__all__ = [
+ "DEFAULT_BIN_SIZE",
"Table",
"transpile",
]
diff --git a/src/giql/constants.py b/src/giql/constants.py
index 87f8055..0daf016 100644
--- a/src/giql/constants.py
+++ b/src/giql/constants.py
@@ -9,3 +9,6 @@
DEFAULT_END_COL = "end"
DEFAULT_STRAND_COL = "strand"
DEFAULT_GENOMIC_COL = "interval"
+
+# Default bin size for INTERSECTS binned equi-join optimization
+DEFAULT_BIN_SIZE = 10_000
diff --git a/src/giql/transformer.py b/src/giql/transformer.py
index de1e70f..789e3c5 100644
--- a/src/giql/transformer.py
+++ b/src/giql/transformer.py
@@ -1,17 +1,22 @@
"""Query transformers for GIQL operations.
This module contains transformers that rewrite queries containing GIQL-specific
-operations (like CLUSTER and MERGE) into equivalent SQL with CTEs.
+operations (like CLUSTER, MERGE, and binned INTERSECTS joins) into equivalent
+SQL with CTEs.
"""
+import itertools
+
from sqlglot import exp
+from giql.constants import DEFAULT_BIN_SIZE
from giql.constants import DEFAULT_CHROM_COL
from giql.constants import DEFAULT_END_COL
from giql.constants import DEFAULT_START_COL
from giql.constants import DEFAULT_STRAND_COL
from giql.expressions import GIQLCluster
from giql.expressions import GIQLMerge
+from giql.expressions import Intersects
from giql.table import Tables
@@ -573,3 +578,897 @@ def _transform_for_merge(
)
return final_query
+
+
+class IntersectsBinnedJoinTransformer:
+ """Transforms column-to-column INTERSECTS into binned equi-joins.
+
+ Handles both explicit JOIN ON and implicit cross-join (WHERE) patterns.
+ Two rewrite strategies are selected based on the SELECT list:
+
+ **No wildcards** (``SELECT a.chrom, b.start, ...``) — ``__giql_bin``
+ cannot appear in the output regardless of CTE content, so the simpler
+ 1-join full-CTE approach is used:
+
+ WITH __giql_a_binned AS (
+ SELECT *, UNNEST(range(
+ CAST("start" / B AS BIGINT),
+ CAST(("end" - 1) / B + 1 AS BIGINT)
+ )) AS __giql_bin FROM peaks
+ ),
+ __giql_b_binned AS (...)
+ SELECT DISTINCT a.chrom, b.start, ...
+ FROM __giql_a_binned AS a
+ JOIN __giql_b_binned AS b
+ ON a."chrom" = b."chrom" AND a.__giql_bin = b.__giql_bin
+ AND a."start" < b."end" AND a."end" > b."start"
+
+ **Wildcards present** (``SELECT a.*, b.*``) — ``__giql_bin`` would leak
+ into ``a.*`` expansion if ``a`` aliases a full-select CTE. A key-only
+ bridge CTE pattern is used instead, keeping original table references:
+
+ WITH __giql_peaks_bins AS (
+ SELECT "chrom", "start", "end",
+ UNNEST(range(...)) AS __giql_bin FROM peaks
+ ),
+ __giql_genes_bins AS (...)
+ SELECT DISTINCT a.*, b.*
+ FROM peaks a
+ JOIN __giql_peaks_bins __giql_c0
+ ON a."chrom" = __giql_c0."chrom" AND a."start" = __giql_c0."start"
+ AND a."end" = __giql_c0."end"
+ JOIN __giql_genes_bins __giql_c1
+ ON __giql_c0."chrom" = __giql_c1."chrom"
+ AND __giql_c0.__giql_bin = __giql_c1.__giql_bin
+ JOIN genes b
+ ON b."chrom" = __giql_c1."chrom" AND b."start" = __giql_c1."start"
+ AND b."end" = __giql_c1."end"
+ AND a."start" < b."end" AND a."end" > b."start"
+
+ Literal-range INTERSECTS (e.g., ``WHERE interval INTERSECTS 'chr1:...'``)
+ are left untouched.
+
+ SELECT DISTINCT is added to deduplicate rows produced by multi-bin
+ matches. This means rows that are identical across every selected
+ column will be collapsed — include a distinguishing column (e.g., an
+ id or score) to preserve duplicates that differ only in unselected
+ columns. The bridge path's key-match joins on ``(chrom, start,
+ end)`` and may fan out if multiple source rows share those values;
+ DISTINCT corrects for this.
+ """
+
+ def __init__(self, tables: Tables, bin_size: int | None = None):
+ """Initialize transformer.
+
+ :param tables:
+ Table configurations for column mapping
+ :param bin_size:
+ Bin width for the equi-join rewrite. Defaults to
+ DEFAULT_BIN_SIZE if not specified.
+ """
+ self.tables = tables
+ resolved = bin_size if bin_size is not None else DEFAULT_BIN_SIZE
+ if not isinstance(resolved, int) or resolved <= 0:
+ raise ValueError(f"bin_size must be a positive integer, got {resolved!r}")
+ self.bin_size = resolved
+
+ def transform(self, query: exp.Expression) -> exp.Expression:
+ if not isinstance(query, exp.Select):
+ return query
+
+ # Outer joins need the pairs-CTE approach: compute matching key
+ # pairs via an INNER binned join (correctly deduplicated), then
+ # outer-join the original tables through the pairs CTE. This
+ # avoids the bin fan-out that creates spurious NULL rows when an
+ # interval spans multiple bins but only matches in some of them.
+ if self._has_outer_join_intersects(query):
+ return self._transform_with_pairs(query)
+ if self._select_has_wildcards(query):
+ return self._transform_bridge(query)
+ return self._transform_full_cte(query)
+
+ def _select_has_wildcards(self, query: exp.Select) -> bool:
+ """Return True if any SELECT item is a wildcard (* or table.*)."""
+ for expr in query.expressions:
+ if isinstance(expr, exp.Star):
+ return True
+ if isinstance(expr, exp.Column) and isinstance(expr.this, exp.Star):
+ return True
+ return False
+
+ def _has_outer_join_intersects(self, query: exp.Select) -> bool:
+ """Return True if any outer JOIN has an INTERSECTS predicate."""
+ for join in query.args.get("joins") or []:
+ if join.args.get("side") and join.args.get("on"):
+ if self._find_column_intersects_in(join.args["on"]):
+ return True
+ return False
+
+ def _transform_with_pairs(self, query: exp.Select) -> exp.Select:
+ """Transform using a pairs CTE for correct outer join semantics.
+
+ Computes matching (left_key, right_key) pairs via an INNER
+ binned join with DISTINCT, then outer-joins the original tables
+ through this pairs CTE. This avoids bin fan-out on the
+ preserved side of the outer join.
+ """
+ joins = query.args.get("joins") or []
+ key_binned: dict[str, str] = {}
+ pairs_idx = 0
+ new_joins: list[exp.Join] = []
+ rewrote_any = False
+
+ for join in joins:
+ on = join.args.get("on")
+ if on:
+ intersects = self._find_column_intersects_in(on)
+ if intersects:
+ extra = self._extract_non_intersects(on, intersects)
+ replacement = self._build_pairs_replacement_joins(
+ query, join, intersects, extra, key_binned, pairs_idx
+ )
+ new_joins.extend(replacement)
+ pairs_idx += 1
+ rewrote_any = True
+ continue
+ new_joins.append(join)
+
+ where = query.args.get("where")
+ if where:
+ intersects = self._find_column_intersects_in(where.this)
+ if intersects:
+ cross_join = self._find_cross_join_for_intersects(
+ query, intersects, new_joins
+ )
+ if cross_join is not None:
+ new_joins.remove(cross_join)
+ replacement = self._build_pairs_replacement_joins(
+ query,
+ cross_join,
+ intersects,
+ None,
+ key_binned,
+ pairs_idx,
+ )
+ new_joins.extend(replacement)
+ self._remove_intersects_from_where(query, intersects)
+ pairs_idx += 1
+ rewrote_any = True
+
+ if rewrote_any:
+ query.set("joins", new_joins)
+ query.set("distinct", exp.Distinct())
+
+ return query
+
+ def _build_pairs_cte(
+ self,
+ name: str,
+ l_cte: str,
+ r_cte: str,
+ l_cols: tuple[str, str, str],
+ r_cols: tuple[str, str, str],
+ ) -> exp.CTE:
+ """Build a DISTINCT inner-join pairs CTE.
+
+ Returns a CTE named *name* that selects the six key columns
+ (__giql_l_chrom, __giql_l_start, __giql_l_end, __giql_r_chrom,
+ __giql_r_start, __giql_r_end) from an INNER join of the two bin
+ CTEs on chrom, __giql_bin, and the overlap predicate.
+ """
+ l_alias = "__giql_l"
+ r_alias = "__giql_r"
+
+ select = exp.Select()
+ select.set("distinct", exp.Distinct())
+
+ first = True
+ for tbl_alias, cols, prefix in [
+ (l_alias, l_cols, "__giql_l"),
+ (r_alias, r_cols, "__giql_r"),
+ ]:
+ for col, suffix in zip(cols, ["_chrom", "_start", "_end"]):
+ col_expr = exp.Alias(
+ this=exp.column(col, table=tbl_alias, quoted=True),
+ alias=exp.Identifier(this=f"{prefix}{suffix}"),
+ )
+ select.select(col_expr, append=not first, copy=False)
+ first = False
+
+ select.from_(
+ exp.Table(
+ this=exp.Identifier(this=l_cte),
+ alias=exp.TableAlias(this=exp.Identifier(this=l_alias)),
+ ),
+ copy=False,
+ )
+
+ join_on = exp.And(
+ this=exp.And(
+ this=exp.EQ(
+ this=exp.column(l_cols[0], table=l_alias, quoted=True),
+ expression=exp.column(r_cols[0], table=r_alias, quoted=True),
+ ),
+ expression=exp.EQ(
+ this=exp.column("__giql_bin", table=l_alias),
+ expression=exp.column("__giql_bin", table=r_alias),
+ ),
+ ),
+ expression=self._build_overlap(l_alias, r_alias, l_cols, r_cols),
+ )
+
+ select.join(
+ exp.Table(
+ this=exp.Identifier(this=r_cte),
+ alias=exp.TableAlias(this=exp.Identifier(this=r_alias)),
+ ),
+ on=join_on,
+ copy=False,
+ )
+
+ return exp.CTE(
+ this=select,
+ alias=exp.TableAlias(this=exp.Identifier(this=name)),
+ )
+
+ def _build_pairs_replacement_joins(
+ self,
+ query: exp.Select,
+ join: exp.Join,
+ intersects: Intersects,
+ extra: exp.Expression | None,
+ key_binned: dict[str, str],
+ pairs_idx: int,
+ ) -> list[exp.Join]:
+ """Build a pairs CTE and two replacement joins for one INTERSECTS.
+
+ Returns two joins:
+ - join1: from_alias [SIDE] JOIN __giql_pairs ON from.key = pairs.from_key
+ - join2: [SIDE] JOIN join_table ON join.key = pairs.join_key [AND extra]
+ """
+ from_table = query.args["from_"].this
+ join_table = join.this
+ if not isinstance(from_table, exp.Table) or not isinstance(
+ join_table, exp.Table
+ ):
+ return [join]
+
+ from_alias = from_table.alias or from_table.name
+ join_alias = join_table.alias or join_table.name
+ from_table_name = from_table.name
+ join_table_name = join_table.name
+
+ left_alias = intersects.this.table
+ from_cols = self._get_columns(from_table_name)
+ join_cols = self._get_columns(join_table_name)
+
+ # Determine which INTERSECTS side maps to FROM vs JOIN table
+ if left_alias == from_alias:
+ l_table_name, r_table_name = from_table_name, join_table_name
+ l_cols, r_cols = from_cols, join_cols
+ from_prefix, join_prefix = "__giql_l", "__giql_r"
+ else:
+ l_table_name, r_table_name = join_table_name, from_table_name
+ l_cols, r_cols = join_cols, from_cols
+ from_prefix, join_prefix = "__giql_r", "__giql_l"
+
+ # Ensure key-only bin CTEs exist
+ l_cte = self._ensure_key_binned(query, l_table_name, key_binned)
+ r_cte = self._ensure_key_binned(query, r_table_name, key_binned)
+
+ # Build and attach the pairs CTE
+ pairs_name = f"__giql_pairs_{pairs_idx}"
+ pairs_cte = self._build_pairs_cte(pairs_name, l_cte, r_cte, l_cols, r_cols)
+ existing_with = query.args.get("with_")
+ if existing_with:
+ existing_with.append("expressions", pairs_cte)
+ else:
+ query.set("with_", exp.With(expressions=[pairs_cte]))
+
+ side = join.args.get("side")
+ p_alias = f"__giql_p{pairs_idx}"
+
+ # join1: [SIDE] JOIN pairs ON from.key = pairs.from_key
+ join1_on = self._build_key_match(from_alias, from_cols, p_alias, from_prefix)
+ join1_kwargs: dict = {
+ "this": exp.Table(
+ this=exp.Identifier(this=pairs_name),
+ alias=exp.TableAlias(this=exp.Identifier(this=p_alias)),
+ ),
+ "on": join1_on,
+ }
+ if side:
+ join1_kwargs["side"] = side
+ join1 = exp.Join(**join1_kwargs)
+
+ # join2: [SIDE] JOIN join_table ON join.key = pairs.join_key
+ join2_on = self._build_key_match(join_alias, join_cols, p_alias, join_prefix)
+ if extra:
+ join2_on = exp.And(this=join2_on, expression=extra)
+ join2_kwargs: dict = {
+ "this": exp.Table(
+ this=exp.Identifier(this=join_table_name),
+ alias=exp.TableAlias(this=exp.Identifier(this=join_alias)),
+ ),
+ "on": join2_on,
+ }
+ if side:
+ join2_kwargs["side"] = side
+ join2 = exp.Join(**join2_kwargs)
+
+ return [join1, join2]
+
+ def _build_key_match(
+ self,
+ table_alias: str,
+ cols: tuple[str, str, str],
+ pairs_alias: str,
+ prefix: str,
+ ) -> exp.And:
+ """Build ``table.chrom = pairs.prefix_chrom AND ...`` for all three keys."""
+ return exp.And(
+ this=exp.And(
+ this=exp.EQ(
+ this=exp.column(cols[0], table=table_alias, quoted=True),
+ expression=exp.column(f"{prefix}_chrom", table=pairs_alias),
+ ),
+ expression=exp.EQ(
+ this=exp.column(cols[1], table=table_alias, quoted=True),
+ expression=exp.column(f"{prefix}_start", table=pairs_alias),
+ ),
+ ),
+ expression=exp.EQ(
+ this=exp.column(cols[2], table=table_alias, quoted=True),
+ expression=exp.column(f"{prefix}_end", table=pairs_alias),
+ ),
+ )
+
+ def _transform_full_cte(self, query: exp.Select) -> exp.Select:
+ joins = query.args.get("joins") or []
+ binned: dict[str, tuple[str, str, str]] = {}
+ rewrote_any = False
+
+ for join in joins:
+ on = join.args.get("on")
+ if on:
+ intersects = self._find_column_intersects_in(on)
+ if intersects:
+ self._rewrite_join_on_full_cte(query, join, intersects, binned)
+ rewrote_any = True
+
+ # Implicit cross-join: FROM a, b WHERE a.interval INTERSECTS b.interval
+ where = query.args.get("where")
+ if where:
+ intersects = self._find_column_intersects_in(where.this)
+ if intersects:
+ cross_join = self._find_cross_join_for_intersects(
+ query, intersects, joins
+ )
+ if cross_join is not None:
+ self._rewrite_cross_join_full_cte(
+ query, cross_join, intersects, binned
+ )
+ rewrote_any = True
+
+ if rewrote_any:
+ query.set("distinct", exp.Distinct())
+
+ return query
+
+ def _rewrite_join_on_full_cte(
+ self,
+ query: exp.Select,
+ join: exp.Join,
+ intersects: Intersects,
+ binned: dict[str, tuple[str, str, str]],
+ ) -> None:
+ from_table = query.args["from_"].this
+ join_table = join.this
+ if not isinstance(from_table, exp.Table) or not isinstance(
+ join_table, exp.Table
+ ):
+ return
+
+ from_alias, from_cols = self._ensure_table_binned_full(
+ query, from_table, query.args["from_"], binned
+ )
+ join_alias, join_cols = self._ensure_table_binned_full(
+ query, join_table, join, binned
+ )
+
+ extra = self._extract_non_intersects(join.args.get("on"), intersects)
+
+ equi_join = exp.And(
+ this=exp.EQ(
+ this=exp.column(from_cols[0], table=from_alias, quoted=True),
+ expression=exp.column(join_cols[0], table=join_alias, quoted=True),
+ ),
+ expression=exp.EQ(
+ this=exp.column("__giql_bin", table=from_alias),
+ expression=exp.column("__giql_bin", table=join_alias),
+ ),
+ )
+ # Place both equi-join and overlap in ON so LEFT/RIGHT/FULL semantics hold.
+ new_on = exp.And(
+ this=equi_join,
+ expression=self._build_overlap(from_alias, join_alias, from_cols, join_cols),
+ )
+ if extra:
+ new_on = exp.And(this=new_on, expression=extra)
+ join.set("on", new_on)
+
+ def _rewrite_cross_join_full_cte(
+ self,
+ query: exp.Select,
+ cross_join: exp.Join,
+ intersects: Intersects,
+ binned: dict[str, tuple[str, str, str]],
+ ) -> None:
+ from_table = query.args["from_"].this
+ join_table = cross_join.this
+ if not isinstance(from_table, exp.Table) or not isinstance(
+ join_table, exp.Table
+ ):
+ return
+
+ from_alias, from_cols = self._ensure_table_binned_full(
+ query, from_table, query.args["from_"], binned
+ )
+ join_alias, join_cols = self._ensure_table_binned_full(
+ query, join_table, cross_join, binned
+ )
+
+ equi_join = exp.And(
+ this=exp.EQ(
+ this=exp.column(from_cols[0], table=from_alias, quoted=True),
+ expression=exp.column(join_cols[0], table=join_alias, quoted=True),
+ ),
+ expression=exp.EQ(
+ this=exp.column("__giql_bin", table=from_alias),
+ expression=exp.column("__giql_bin", table=join_alias),
+ ),
+ )
+ cross_join.set(
+ "on",
+ exp.And(
+ this=equi_join,
+ expression=self._build_overlap(
+ from_alias, join_alias, from_cols, join_cols
+ ),
+ ),
+ )
+ self._remove_intersects_from_where(query, intersects)
+
+ def _ensure_table_binned_full(
+ self,
+ query: exp.Select,
+ table: exp.Table,
+ parent: exp.Expression,
+ binned: dict[str, tuple[str, str, str]],
+ ) -> tuple[str, tuple[str, str, str]]:
+ """Create a full SELECT * CTE for *table* if needed; replace ref in *parent*."""
+ alias = table.alias or table.name
+ if alias in binned:
+ return alias, binned[alias]
+
+ table_name = table.name
+ cols = self._get_columns(table_name)
+ cte_name = f"__giql_{alias}_binned"
+
+ cte = exp.CTE(
+ this=self._build_full_binned_select(table_name, cols),
+ alias=exp.TableAlias(this=exp.Identifier(this=cte_name)),
+ )
+ existing_with = query.args.get("with_")
+ if existing_with:
+ existing_with.append("expressions", cte)
+ else:
+ query.set("with_", exp.With(expressions=[cte]))
+
+ parent.set(
+ "this",
+ exp.Table(
+ this=exp.Identifier(this=cte_name),
+ alias=exp.TableAlias(this=exp.Identifier(this=alias)),
+ ),
+ )
+ binned[alias] = cols
+ return alias, cols
+
+ def _build_bin_range(
+ self, start: str, end: str
+ ) -> tuple[exp.Expression, exp.Expression]:
+ """Build the (low, high) bin-index expressions for UNNEST(range(...)).
+
+ Returns ``start // bin_size`` and ``(end - 1) // bin_size + 1``.
+ Uses integer floor division to avoid rounding errors from
+ float division + CAST.
+ """
+ bs = self.bin_size
+
+ low = exp.IntDiv(
+ this=exp.column(start, quoted=True),
+ expression=exp.Literal.number(bs),
+ )
+ high = exp.Add(
+ this=exp.IntDiv(
+ this=exp.Paren(
+ this=exp.Sub(
+ this=exp.column(end, quoted=True),
+ expression=exp.Literal.number(1),
+ ),
+ ),
+ expression=exp.Literal.number(bs),
+ ),
+ expression=exp.Literal.number(1),
+ )
+ return low, high
+
+ def _build_full_binned_select(
+ self, table_name: str, cols: tuple[str, str, str]
+ ) -> exp.Select:
+ """Build ``SELECT *, UNNEST(range(...)) AS __giql_bin
+ FROM table``."""
+ _chrom, start, end = cols
+ low, high = self._build_bin_range(start, end)
+
+ range_fn = exp.Anonymous(this="range", expressions=[low, high])
+ unnest_fn = exp.Anonymous(this="UNNEST", expressions=[range_fn])
+ bin_alias = exp.Alias(
+ this=unnest_fn,
+ alias=exp.Identifier(this="__giql_bin"),
+ )
+
+ select = exp.Select()
+ select.select(exp.Star(), copy=False)
+ select.select(bin_alias, append=True, copy=False)
+ select.from_(exp.Table(this=exp.Identifier(this=table_name)), copy=False)
+ return select
+
+ def _transform_bridge(self, query: exp.Select) -> exp.Select:
+ joins = query.args.get("joins") or []
+ key_binned: dict[str, str] = {}
+ connector_counter = itertools.count()
+ new_joins: list[exp.Join] = []
+ rewrote_any = False
+
+ for join in joins:
+ on = join.args.get("on")
+ if on:
+ intersects = self._find_column_intersects_in(on)
+ if intersects:
+ extra = self._build_join_back_joins(
+ query,
+ join,
+ intersects,
+ key_binned,
+ connector_counter,
+ preserve_kind=True,
+ )
+ new_joins.extend(extra)
+ rewrote_any = True
+ continue
+ new_joins.append(join)
+
+ where = query.args.get("where")
+ if where:
+ intersects = self._find_column_intersects_in(where.this)
+ if intersects:
+ cross_join = self._find_cross_join_for_intersects(
+ query, intersects, new_joins
+ )
+ if cross_join is not None:
+ new_joins.remove(cross_join)
+ extra = self._build_join_back_joins(
+ query,
+ cross_join,
+ intersects,
+ key_binned,
+ connector_counter,
+ preserve_kind=False,
+ )
+ new_joins.extend(extra)
+ self._remove_intersects_from_where(query, intersects)
+ rewrote_any = True
+
+ if rewrote_any:
+ query.set("joins", new_joins)
+ query.set("distinct", exp.Distinct())
+
+ return query
+
+ def _find_column_intersects_in(self, expr: exp.Expression) -> Intersects | None:
+ """Return the first column-to-column Intersects node in *expr*, or None.
+
+ Only the first match is returned. A single JOIN with multiple
+ INTERSECTS conditions in its ON clause is not supported; only the
+ first will be rewritten.
+ """
+ for node in expr.find_all(Intersects):
+ if (
+ isinstance(node.this, exp.Column)
+ and node.this.table
+ and isinstance(node.expression, exp.Column)
+ and node.expression.table
+ ):
+ return node
+ return None
+
+ def _find_cross_join_for_intersects(
+ self,
+ query: exp.Select,
+ intersects: Intersects,
+ current_joins: list[exp.Join],
+ ) -> exp.Join | None:
+ """Find the implicit cross-join entry for the table in a WHERE INTERSECTS."""
+ from_table = query.args["from_"].this
+ if not isinstance(from_table, exp.Table):
+ return None
+ from_alias = from_table.alias or from_table.name
+
+ left_alias = intersects.this.table
+ right_alias = intersects.expression.table
+ if left_alias == from_alias:
+ target_alias = right_alias
+ elif right_alias == from_alias:
+ target_alias = left_alias
+ else:
+ return None
+
+ for join in current_joins:
+ if isinstance(join.this, exp.Table):
+ alias = join.this.alias or join.this.name
+ if alias == target_alias:
+ return join
+ return None
+
+ def _remove_intersects_from_where(
+ self, query: exp.Select, intersects: Intersects
+ ) -> None:
+ """Remove the INTERSECTS predicate from the WHERE clause."""
+ where = query.args.get("where")
+ if not where:
+ return
+ remainder = self._extract_non_intersects(where.this, intersects)
+ if remainder is None:
+ query.set("where", None)
+ else:
+ query.set("where", exp.Where(this=remainder))
+
+ def _extract_non_intersects(
+ self, expr: exp.Expression | None, intersects: Intersects
+ ) -> exp.Expression | None:
+ """Return the parts of an AND tree that are not the INTERSECTS node."""
+ if expr is None or expr is intersects:
+ return None
+ if isinstance(expr, exp.And):
+ if expr.this is intersects:
+ return expr.expression
+ if expr.expression is intersects:
+ return expr.this
+ left = self._extract_non_intersects(expr.this, intersects)
+ right = self._extract_non_intersects(expr.expression, intersects)
+ if left is None:
+ return right
+ if right is None:
+ return left
+ return exp.And(this=left, expression=right)
+ return expr
+
+ def _get_columns(self, table_name: str) -> tuple[str, str, str]:
+ """Return (chrom, start, end) column names for a table."""
+ table = self.tables.get(table_name)
+ if table:
+ return (table.chrom_col, table.start_col, table.end_col)
+ return (DEFAULT_CHROM_COL, DEFAULT_START_COL, DEFAULT_END_COL)
+
+ def _build_overlap(
+ self,
+ from_alias: str,
+ join_alias: str,
+ from_cols: tuple[str, str, str],
+ join_cols: tuple[str, str, str],
+ ) -> exp.And:
+ """Build ``from.start < join.end AND from.end > join.start``."""
+ return exp.And(
+ this=exp.LT(
+ this=exp.column(from_cols[1], table=from_alias, quoted=True),
+ expression=exp.column(join_cols[2], table=join_alias, quoted=True),
+ ),
+ expression=exp.GT(
+ this=exp.column(from_cols[2], table=from_alias, quoted=True),
+ expression=exp.column(join_cols[1], table=join_alias, quoted=True),
+ ),
+ )
+
+ def _find_table_name_for_alias(self, query: exp.Select, alias: str) -> str:
+ """Resolve an alias to its underlying table name."""
+ from_table = query.args["from_"].this
+ if isinstance(from_table, exp.Table):
+ if (from_table.alias or from_table.name) == alias:
+ return from_table.name
+ for join in query.args.get("joins") or []:
+ if isinstance(join.this, exp.Table):
+ t = join.this
+ if (t.alias or t.name) == alias:
+ return t.name
+ return alias # fallback: alias == table name
+
+ def _build_key_only_bins_select(
+ self, table_name: str, cols: tuple[str, str, str]
+ ) -> exp.Select:
+ """Build ``SELECT chrom, start, end, UNNEST(range(...)) AS __giql_bin FROM table``."""
+ chrom, start, end = cols
+ low, high = self._build_bin_range(start, end)
+
+ range_fn = exp.Anonymous(this="range", expressions=[low, high])
+ unnest_fn = exp.Anonymous(this="UNNEST", expressions=[range_fn])
+ bin_alias = exp.Alias(
+ this=unnest_fn,
+ alias=exp.Identifier(this="__giql_bin"),
+ )
+
+ select = exp.Select()
+ select.select(exp.column(chrom, quoted=True), copy=False)
+ select.select(exp.column(start, quoted=True), append=True, copy=False)
+ select.select(exp.column(end, quoted=True), append=True, copy=False)
+ select.select(bin_alias, append=True, copy=False)
+ select.from_(exp.Table(this=exp.Identifier(this=table_name)), copy=False)
+ return select
+
+ def _ensure_key_binned(
+ self,
+ query: exp.Select,
+ table_name: str,
+ key_binned: dict[str, str],
+ ) -> str:
+ """Ensure a key-only bins CTE exists for *table_name*; return its name."""
+ if table_name in key_binned:
+ return key_binned[table_name]
+
+ cte_name = f"__giql_{table_name}_bins"
+ cols = self._get_columns(table_name)
+ cte = exp.CTE(
+ this=self._build_key_only_bins_select(table_name, cols),
+ alias=exp.TableAlias(this=exp.Identifier(this=cte_name)),
+ )
+
+ existing_with = query.args.get("with_")
+ if existing_with:
+ existing_with.append("expressions", cte)
+ else:
+ query.set("with_", exp.With(expressions=[cte]))
+
+ key_binned[table_name] = cte_name
+ return cte_name
+
+ def _build_join_back_joins(
+ self,
+ query: exp.Select,
+ join: exp.Join,
+ intersects: Intersects,
+ key_binned: dict[str, str],
+ connector_counter: itertools.count,
+ *,
+ preserve_kind: bool,
+ ) -> list[exp.Join]:
+ """Build three replacement JOINs for one INTERSECTS using the join-back pattern.
+
+ join1 is always INNER because it key-matches a table against its
+ own bin CTE — every row has a corresponding bin entry by
+ construction, so the join side has no effect.
+
+ join2 and join3 inherit the original join's side (LEFT, RIGHT)
+ when *preserve_kind* is True.
+ """
+ join_table = join.this
+ if not isinstance(join_table, exp.Table):
+ return [join]
+
+ join_alias = join_table.alias or join_table.name
+ join_table_name = join_table.name
+
+ left_alias = intersects.this.table
+ right_alias = intersects.expression.table
+ other_alias = left_alias if right_alias == join_alias else right_alias
+ if other_alias == join_alias:
+ return [join] # can't determine structure
+
+ extra = self._extract_non_intersects(join.args.get("on"), intersects)
+
+ other_table_name = self._find_table_name_for_alias(query, other_alias)
+ other_cols = self._get_columns(other_table_name)
+ join_cols = self._get_columns(join_table_name)
+
+ other_cte = self._ensure_key_binned(query, other_table_name, key_binned)
+ join_cte = self._ensure_key_binned(query, join_table_name, key_binned)
+
+ c0 = f"__giql_c{next(connector_counter)}"
+ c1 = f"__giql_c{next(connector_counter)}"
+
+ join_side = None
+ if preserve_kind:
+ join_side = join.args.get("side")
+
+ # join1: key-match from other_alias to its bin CTE
+ join1 = exp.Join(
+ this=exp.Table(
+ this=exp.Identifier(this=other_cte),
+ alias=exp.TableAlias(this=exp.Identifier(this=c0)),
+ ),
+ on=exp.And(
+ this=exp.And(
+ this=exp.EQ(
+ this=exp.column(other_cols[0], table=other_alias, quoted=True),
+ expression=exp.column(other_cols[0], table=c0, quoted=True),
+ ),
+ expression=exp.EQ(
+ this=exp.column(other_cols[1], table=other_alias, quoted=True),
+ expression=exp.column(other_cols[1], table=c0, quoted=True),
+ ),
+ ),
+ expression=exp.EQ(
+ this=exp.column(other_cols[2], table=other_alias, quoted=True),
+ expression=exp.column(other_cols[2], table=c0, quoted=True),
+ ),
+ ),
+ )
+
+ # join2: bin equi-join (chrom + __giql_bin match)
+ join2_kwargs: dict = {
+ "this": exp.Table(
+ this=exp.Identifier(this=join_cte),
+ alias=exp.TableAlias(this=exp.Identifier(this=c1)),
+ ),
+ "on": exp.And(
+ this=exp.EQ(
+ this=exp.column(other_cols[0], table=c0, quoted=True),
+ expression=exp.column(join_cols[0], table=c1, quoted=True),
+ ),
+ expression=exp.EQ(
+ this=exp.column("__giql_bin", table=c0),
+ expression=exp.column("__giql_bin", table=c1),
+ ),
+ ),
+ }
+ if join_side:
+ join2_kwargs["side"] = join_side
+ join2 = exp.Join(**join2_kwargs)
+
+ # join3: key-match from join CTE back to actual join table + overlap
+ key_match = exp.And(
+ this=exp.And(
+ this=exp.EQ(
+ this=exp.column(join_cols[0], table=join_alias, quoted=True),
+ expression=exp.column(join_cols[0], table=c1, quoted=True),
+ ),
+ expression=exp.EQ(
+ this=exp.column(join_cols[1], table=join_alias, quoted=True),
+ expression=exp.column(join_cols[1], table=c1, quoted=True),
+ ),
+ ),
+ expression=exp.EQ(
+ this=exp.column(join_cols[2], table=join_alias, quoted=True),
+ expression=exp.column(join_cols[2], table=c1, quoted=True),
+ ),
+ )
+ join3_on = exp.And(
+ this=key_match,
+ expression=self._build_overlap(
+ other_alias, join_alias, other_cols, join_cols
+ ),
+ )
+ if extra:
+ join3_on = exp.And(this=join3_on, expression=extra)
+ join3_kwargs: dict = {
+ "this": exp.Table(
+ this=exp.Identifier(this=join_table_name),
+ alias=exp.TableAlias(this=exp.Identifier(this=join_alias)),
+ ),
+ "on": join3_on,
+ }
+ if join_side:
+ join3_kwargs["side"] = join_side
+
+ join3 = exp.Join(**join3_kwargs)
+
+ return [join1, join2, join3]
diff --git a/src/giql/transpile.py b/src/giql/transpile.py
index 2b29c3d..9bf8076 100644
--- a/src/giql/transpile.py
+++ b/src/giql/transpile.py
@@ -11,6 +11,7 @@
from giql.table import Table
from giql.table import Tables
from giql.transformer import ClusterTransformer
+from giql.transformer import IntersectsBinnedJoinTransformer
from giql.transformer import MergeTransformer
@@ -45,6 +46,7 @@ def _build_tables(tables: list[str | Table] | None) -> Tables:
def transpile(
giql: str,
tables: list[str | Table] | None = None,
+ bin_size: int | None = None,
) -> str:
"""Transpile a GIQL query to SQL.
@@ -60,6 +62,11 @@ def transpile(
Table configurations. Strings use default column mappings
(chrom, start, end, strand). Table objects provide custom
column name mappings.
+ bin_size : int | None
+ Bin size for INTERSECTS equi-join optimization. When a query
+ contains a full-table column-to-column INTERSECTS join, the
+ transpiler rewrites it as a binned equi-join for performance.
+ Defaults to 10,000 if not specified.
Returns
-------
@@ -94,11 +101,24 @@ def transpile(
)
],
)
+
+ Binned equi-join with custom bin size::
+
+ sql = transpile(
+ "SELECT a.*, b.* FROM peaks a JOIN genes b "
+ "ON a.interval INTERSECTS b.interval",
+ tables=["peaks", "genes"],
+ bin_size=100000,
+ )
"""
# Build tables container
tables_container = _build_tables(tables)
# Initialize transformers with table configurations
+ intersects_transformer = IntersectsBinnedJoinTransformer(
+ tables_container,
+ bin_size=bin_size,
+ )
merge_transformer = MergeTransformer(tables_container)
cluster_transformer = ClusterTransformer(tables_container)
@@ -111,8 +131,9 @@ def transpile(
except Exception as e:
raise ValueError(f"Parse error: {e}\nQuery: {giql}") from e
- # Apply transformations (MERGE first, then CLUSTER)
+ # Apply transformations
try:
+ ast = intersects_transformer.transform(ast)
# MERGE transformation (which may internally use CLUSTER)
ast = merge_transformer.transform(ast)
# CLUSTER transformation for any standalone CLUSTER expressions
diff --git a/tests/integration/bedtools/test_intersect_property.py b/tests/integration/bedtools/test_intersect_property.py
new file mode 100644
index 0000000..57c657d
--- /dev/null
+++ b/tests/integration/bedtools/test_intersect_property.py
@@ -0,0 +1,692 @@
+"""Property-based correctness tests for INTERSECTS binned equi-join.
+
+These tests use hypothesis to generate random genomic intervals of
+varying sizes — including intervals that span multiple bins — and
+verify that GIQL's binned equi-join produces identical results to
+bedtools intersect.
+"""
+
+from hypothesis import HealthCheck
+from hypothesis import given
+from hypothesis import settings
+from hypothesis import strategies as st
+
+from giql import transpile
+
+from .utils.bedtools_wrapper import intersect
+from .utils.comparison import compare_results
+from .utils.data_models import GenomicInterval
+from .utils.duckdb_loader import load_intervals
+
+duckdb = __import__("pytest").importorskip("duckdb")
+
+
+# ---------------------------------------------------------------------------
+# Strategies
+# ---------------------------------------------------------------------------
+
+CHROMS = ["chr1", "chr2", "chr3"]
+
+
+@st.composite
+def genomic_interval_st(draw, idx=None):
+ """Generate a random GenomicInterval that can span multiple 10k bins."""
+ chrom = draw(st.sampled_from(CHROMS))
+ start = draw(st.integers(min_value=0, max_value=1_000_000))
+ length = draw(st.integers(min_value=1, max_value=200_000))
+ score = draw(st.integers(min_value=0, max_value=1000))
+ strand = draw(st.sampled_from(["+", "-"]))
+    # A caller-supplied idx guarantees a unique name, avoiding the known
+    # DISTINCT duplicate-collapse limitation; the regex fallback may collide.
+ name = (
+ f"r{idx}"
+ if idx is not None
+ else draw(st.from_regex(r"r[0-9]{1,6}", fullmatch=True))
+ )
+ return GenomicInterval(chrom, start, start + length, name, score, strand)
+
+
+@st.composite
+def unique_interval_list_st(draw, max_size=60):
+ """Generate a list of intervals with unique names."""
+ n = draw(st.integers(min_value=1, max_value=max_size))
+ intervals = []
+ for i in range(n):
+ iv = draw(genomic_interval_st(idx=i))
+ intervals.append(iv)
+ return intervals
+
+
+interval_list_st = unique_interval_list_st()
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _run_giql(intervals_a, intervals_b):
+ """Run the binned-join INTERSECTS query via DuckDB and return result rows."""
+ conn = duckdb.connect(":memory:")
+ try:
+ load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a])
+ load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b])
+
+ sql = transpile(
+ """
+ SELECT DISTINCT a.*
+ FROM intervals_a a, intervals_b b
+ WHERE a.interval INTERSECTS b.interval
+ """,
+ tables=["intervals_a", "intervals_b"],
+ )
+ return conn.execute(sql).fetchall()
+ finally:
+ conn.close()
+
+
+def _run_bedtools(intervals_a, intervals_b):
+ """Run bedtools intersect -u and return result tuples."""
+ return intersect(
+ [i.to_tuple() for i in intervals_a],
+ [i.to_tuple() for i in intervals_b],
+ )
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+@given(intervals_a=interval_list_st, intervals_b=interval_list_st)
+@settings(
+ max_examples=50,
+ deadline=None,
+ suppress_health_check=[HealthCheck.too_slow],
+)
+def test_binned_join_matches_bedtools(intervals_a, intervals_b):
+ """
+ GIVEN two randomly generated sets of genomic intervals
+ WHEN GIQL INTERSECTS binned equi-join is executed
+ THEN results match bedtools intersect -u exactly
+ """
+ giql_result = _run_giql(intervals_a, intervals_b)
+ bedtools_result = _run_bedtools(intervals_a, intervals_b)
+
+ comparison = compare_results(giql_result, bedtools_result)
+ assert comparison.match, comparison.failure_message()
+
+
+@given(intervals=interval_list_st)
+@settings(
+ max_examples=30,
+ deadline=None,
+ suppress_health_check=[HealthCheck.too_slow],
+)
+def test_self_join_matches_bedtools(intervals):
+ """
+ GIVEN a randomly generated set of genomic intervals
+ WHEN GIQL INTERSECTS self-join is executed
+ THEN results match bedtools intersect -u with the same file as A and B
+ """
+ giql_result = _run_giql(intervals, intervals)
+ bedtools_result = _run_bedtools(intervals, intervals)
+
+ comparison = compare_results(giql_result, bedtools_result)
+ assert comparison.match, comparison.failure_message()
+
+
+@given(
+ intervals_a=unique_interval_list_st(max_size=30),
+ intervals_b=unique_interval_list_st(max_size=30),
+ bin_size=st.sampled_from([100, 1_000, 10_000, 100_000]),
+)
+@settings(
+ max_examples=40,
+ deadline=None,
+ suppress_health_check=[HealthCheck.too_slow],
+)
+def test_bin_size_does_not_affect_correctness(intervals_a, intervals_b, bin_size):
+ """
+ GIVEN two randomly generated sets of genomic intervals and a bin size
+ WHEN GIQL INTERSECTS is executed with that bin size
+ THEN results match bedtools intersect -u regardless of bin size
+ """
+ conn = duckdb.connect(":memory:")
+ try:
+ load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a])
+ load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b])
+
+ sql = transpile(
+ """
+ SELECT DISTINCT a.*
+ FROM intervals_a a, intervals_b b
+ WHERE a.interval INTERSECTS b.interval
+ """,
+ tables=["intervals_a", "intervals_b"],
+ bin_size=bin_size,
+ )
+ giql_result = conn.execute(sql).fetchall()
+ finally:
+ conn.close()
+
+ bedtools_result = _run_bedtools(intervals_a, intervals_b)
+
+ comparison = compare_results(giql_result, bedtools_result)
+ assert comparison.match, f"bin_size={bin_size}: {comparison.failure_message()}"
+
+
+@given(
+ intervals_a=unique_interval_list_st(max_size=8),
+ intervals_b=unique_interval_list_st(max_size=8),
+ intervals_c=unique_interval_list_st(max_size=8),
+)
+@settings(
+ max_examples=40,
+ deadline=None,
+ suppress_health_check=[HealthCheck.too_slow],
+)
+def test_multi_table_join_matches_bedtools(intervals_a, intervals_b, intervals_c):
+ """
+ GIVEN three randomly generated sets of genomic intervals
+ WHEN GIQL three-way INTERSECTS join is executed
+ THEN the A-side rows match bedtools intersect chained A->B then ->C
+ """
+ conn = duckdb.connect(":memory:")
+ try:
+ load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a])
+ load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b])
+ load_intervals(conn, "intervals_c", [i.to_tuple() for i in intervals_c])
+
+ sql = transpile(
+ """
+ SELECT DISTINCT a.*
+ FROM intervals_a a
+ JOIN intervals_b b ON a.interval INTERSECTS b.interval
+ JOIN intervals_c c ON a.interval INTERSECTS c.interval
+ """,
+ tables=["intervals_a", "intervals_b", "intervals_c"],
+ )
+ giql_result = conn.execute(sql).fetchall()
+ finally:
+ conn.close()
+
+ # bedtools equivalent: chain A∩B then filter against C
+ tuples_a = [i.to_tuple() for i in intervals_a]
+ tuples_b = [i.to_tuple() for i in intervals_b]
+ tuples_c = [i.to_tuple() for i in intervals_c]
+ ab_result = intersect(tuples_a, tuples_b)
+ if ab_result:
+ bedtools_result = intersect(ab_result, tuples_c)
+ else:
+ bedtools_result = []
+
+ comparison = compare_results(giql_result, bedtools_result)
+ assert comparison.match, comparison.failure_message()
+
+
+@given(
+ intervals_a=unique_interval_list_st(max_size=30),
+ intervals_b=unique_interval_list_st(max_size=30),
+)
+@settings(
+ max_examples=40,
+ deadline=None,
+ suppress_health_check=[HealthCheck.too_slow],
+)
+def test_left_join_matches_bedtools_loj(intervals_a, intervals_b):
+ """
+ GIVEN two randomly generated sets of genomic intervals
+ WHEN GIQL LEFT JOIN INTERSECTS is executed
+ THEN results match bedtools intersect -loj exactly
+ """
+ conn = duckdb.connect(":memory:")
+ try:
+ load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a])
+ load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b])
+
+ sql = transpile(
+ """
+ SELECT DISTINCT
+ a.chrom, a.start, a.end, a.name, a.score, a.strand,
+ b.chrom AS b_chrom, b.start AS b_start, b.end AS b_end,
+ b.name AS b_name, b.score AS b_score, b.strand AS b_strand
+ FROM intervals_a a
+ LEFT JOIN intervals_b b ON a.interval INTERSECTS b.interval
+ """,
+ tables=["intervals_a", "intervals_b"],
+ )
+ giql_result = conn.execute(sql).fetchall()
+ finally:
+ conn.close()
+
+ bedtools_result = intersect(
+ [i.to_tuple() for i in intervals_a],
+ [i.to_tuple() for i in intervals_b],
+ loj=True,
+ )
+
+ comparison = compare_results(giql_result, bedtools_result)
+ assert comparison.match, comparison.failure_message()
+
+
+# ---------------------------------------------------------------------------
+# -v (inverse / anti-join)
+# ---------------------------------------------------------------------------
+
+
+@given(
+ intervals_a=unique_interval_list_st(max_size=30),
+ intervals_b=unique_interval_list_st(max_size=30),
+)
+@settings(
+ max_examples=40,
+ deadline=None,
+ suppress_health_check=[HealthCheck.too_slow],
+)
+def test_inverse_matches_bedtools_v(intervals_a, intervals_b):
+ """
+ GIVEN two randomly generated sets of genomic intervals
+ WHEN GIQL anti-join (LEFT JOIN WHERE b IS NULL) is executed
+ THEN results match bedtools intersect -v exactly
+ """
+ conn = duckdb.connect(":memory:")
+ try:
+ load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a])
+ load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b])
+
+ sql = transpile(
+ """
+ SELECT DISTINCT a.*
+ FROM intervals_a a
+ LEFT JOIN intervals_b b ON a.interval INTERSECTS b.interval
+ WHERE b.chrom IS NULL
+ """,
+ tables=["intervals_a", "intervals_b"],
+ )
+ giql_result = conn.execute(sql).fetchall()
+ finally:
+ conn.close()
+
+ bedtools_result = intersect(
+ [i.to_tuple() for i in intervals_a],
+ [i.to_tuple() for i in intervals_b],
+ inverse=True,
+ )
+
+ comparison = compare_results(giql_result, bedtools_result)
+ assert comparison.match, comparison.failure_message()
+
+
+# ---------------------------------------------------------------------------
+# -wa -wb (write both A and B entries)
+# ---------------------------------------------------------------------------
+
+
+@given(
+ intervals_a=unique_interval_list_st(max_size=20),
+ intervals_b=unique_interval_list_st(max_size=20),
+)
+@settings(
+ max_examples=40,
+ deadline=None,
+ suppress_health_check=[HealthCheck.too_slow],
+)
+def test_write_both_matches_bedtools_wa_wb(intervals_a, intervals_b):
+ """
+ GIVEN two randomly generated sets of genomic intervals
+ WHEN GIQL INTERSECTS join selecting both sides is executed
+ THEN results match bedtools intersect -wa -wb exactly
+ """
+ conn = duckdb.connect(":memory:")
+ try:
+ load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a])
+ load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b])
+
+ sql = transpile(
+ """
+ SELECT
+ a.chrom, a.start, a.end, a.name, a.score, a.strand,
+ b.chrom AS b_chrom, b.start AS b_start, b.end AS b_end,
+ b.name AS b_name, b.score AS b_score, b.strand AS b_strand
+ FROM intervals_a a
+ JOIN intervals_b b ON a.interval INTERSECTS b.interval
+ """,
+ tables=["intervals_a", "intervals_b"],
+ )
+ giql_result = conn.execute(sql).fetchall()
+ finally:
+ conn.close()
+
+ bedtools_result = intersect(
+ [i.to_tuple() for i in intervals_a],
+ [i.to_tuple() for i in intervals_b],
+ write_both=True,
+ )
+
+ comparison = compare_results(giql_result, bedtools_result)
+ assert comparison.match, comparison.failure_message()
+
+
+# ---------------------------------------------------------------------------
+# -c (count overlaps)
+# ---------------------------------------------------------------------------
+
+
+@given(
+ intervals_a=unique_interval_list_st(max_size=20),
+ intervals_b=unique_interval_list_st(max_size=20),
+)
+@settings(
+ max_examples=40,
+ deadline=None,
+ suppress_health_check=[HealthCheck.too_slow],
+)
+def test_count_matches_bedtools_c(intervals_a, intervals_b):
+ """
+ GIVEN two randomly generated sets of unique genomic intervals
+ WHEN GIQL COUNT of overlapping B per A is computed
+ THEN results match bedtools intersect -c exactly
+ """
+ conn = duckdb.connect(":memory:")
+ try:
+ load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a])
+ load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b])
+
+ # Use a naive overlap join for counting — the binned join's
+ # DISTINCT would collapse duplicate B matches.
+ count_sql = """
+ SELECT
+ a.chrom, a."start", a."end", a.name, a.score, a.strand,
+ COUNT(b.chrom) AS cnt
+ FROM intervals_a a
+ LEFT JOIN intervals_b b
+ ON a.chrom = b.chrom
+ AND a."start" < b."end"
+ AND a."end" > b."start"
+ GROUP BY a.chrom, a."start", a."end", a.name, a.score, a.strand
+ """
+ giql_result = conn.execute(count_sql).fetchall()
+ finally:
+ conn.close()
+
+ bedtools_result = intersect(
+ [i.to_tuple() for i in intervals_a],
+ [i.to_tuple() for i in intervals_b],
+ count=True,
+ )
+
+ comparison = compare_results(giql_result, bedtools_result)
+ assert comparison.match, comparison.failure_message()
+
+
+# ---------------------------------------------------------------------------
+# -s (same strand)
+# ---------------------------------------------------------------------------
+
+
+@given(
+ intervals_a=unique_interval_list_st(max_size=30),
+ intervals_b=unique_interval_list_st(max_size=30),
+)
+@settings(
+ max_examples=40,
+ deadline=None,
+ suppress_health_check=[HealthCheck.too_slow],
+)
+def test_same_strand_matches_bedtools_s(intervals_a, intervals_b):
+ """
+ GIVEN two randomly generated sets of genomic intervals with strands
+ WHEN GIQL INTERSECTS with same-strand filter is executed
+ THEN results match bedtools intersect -s exactly
+ """
+ conn = duckdb.connect(":memory:")
+ try:
+ load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a])
+ load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b])
+
+ sql = transpile(
+ """
+ SELECT DISTINCT a.*
+ FROM intervals_a a, intervals_b b
+ WHERE a.interval INTERSECTS b.interval
+ AND a.strand = b.strand
+ """,
+ tables=["intervals_a", "intervals_b"],
+ )
+ giql_result = conn.execute(sql).fetchall()
+ finally:
+ conn.close()
+
+ bedtools_result = intersect(
+ [i.to_tuple() for i in intervals_a],
+ [i.to_tuple() for i in intervals_b],
+ strand_mode="same",
+ )
+
+ comparison = compare_results(giql_result, bedtools_result)
+ assert comparison.match, comparison.failure_message()
+
+
+# ---------------------------------------------------------------------------
+# -S (opposite strand)
+# ---------------------------------------------------------------------------
+
+
+@given(
+ intervals_a=unique_interval_list_st(max_size=30),
+ intervals_b=unique_interval_list_st(max_size=30),
+)
+@settings(
+ max_examples=40,
+ deadline=None,
+ suppress_health_check=[HealthCheck.too_slow],
+)
+def test_opposite_strand_matches_bedtools_S(intervals_a, intervals_b):
+ """
+ GIVEN two randomly generated sets of genomic intervals with strands
+ WHEN GIQL INTERSECTS with opposite-strand filter is executed
+ THEN results match bedtools intersect -S exactly
+ """
+ conn = duckdb.connect(":memory:")
+ try:
+ load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a])
+ load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b])
+
+ sql = transpile(
+ """
+ SELECT DISTINCT a.*
+ FROM intervals_a a, intervals_b b
+ WHERE a.interval INTERSECTS b.interval
+ AND a.strand != b.strand
+ """,
+ tables=["intervals_a", "intervals_b"],
+ )
+ giql_result = conn.execute(sql).fetchall()
+ finally:
+ conn.close()
+
+ bedtools_result = intersect(
+ [i.to_tuple() for i in intervals_a],
+ [i.to_tuple() for i in intervals_b],
+ strand_mode="opposite",
+ )
+
+ comparison = compare_results(giql_result, bedtools_result)
+ assert comparison.match, comparison.failure_message()
+
+
+# ---------------------------------------------------------------------------
+# -f (minimum overlap fraction of A)
+# ---------------------------------------------------------------------------
+
+
+@given(
+ intervals_a=unique_interval_list_st(max_size=20),
+ intervals_b=unique_interval_list_st(max_size=20),
+ fraction=st.sampled_from([0.1, 0.25, 0.5, 0.75, 0.9]),
+)
+@settings(
+ max_examples=40,
+ deadline=None,
+ suppress_health_check=[HealthCheck.too_slow],
+)
+def test_fraction_a_matches_bedtools_f(intervals_a, intervals_b, fraction):
+ """
+ GIVEN two randomly generated sets of genomic intervals and a fraction
+ WHEN GIQL INTERSECTS with minimum overlap fraction of A is executed
+ THEN results match bedtools intersect -f exactly
+ """
+ conn = duckdb.connect(":memory:")
+ try:
+ load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a])
+ load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b])
+
+ inner_sql = transpile(
+ """
+ SELECT DISTINCT
+ a.chrom, a.start, a.end, a.name, a.score, a.strand,
+ b.start AS b_start, b.end AS b_end
+ FROM intervals_a a
+ JOIN intervals_b b ON a.interval INTERSECTS b.interval
+ """,
+ tables=["intervals_a", "intervals_b"],
+ )
+ sql = f"""
+ SELECT DISTINCT chrom, "start", "end", name, score, strand
+ FROM ({inner_sql})
+ WHERE (LEAST("end", b_end) - GREATEST("start", b_start))
+ >= {fraction} * ("end" - "start")
+ """
+ giql_result = conn.execute(sql).fetchall()
+ finally:
+ conn.close()
+
+ bedtools_result = intersect(
+ [i.to_tuple() for i in intervals_a],
+ [i.to_tuple() for i in intervals_b],
+ fraction_a=fraction,
+ )
+
+ comparison = compare_results(giql_result, bedtools_result)
+ assert comparison.match, f"fraction_a={fraction}: {comparison.failure_message()}"
+
+
+# ---------------------------------------------------------------------------
+# -F (minimum overlap fraction of B)
+# ---------------------------------------------------------------------------
+
+
+@given(
+ intervals_a=unique_interval_list_st(max_size=20),
+ intervals_b=unique_interval_list_st(max_size=20),
+ fraction=st.sampled_from([0.1, 0.25, 0.5, 0.75, 0.9]),
+)
+@settings(
+ max_examples=40,
+ deadline=None,
+ suppress_health_check=[HealthCheck.too_slow],
+)
+def test_fraction_b_matches_bedtools_F(intervals_a, intervals_b, fraction):
+ """
+ GIVEN two randomly generated sets of genomic intervals and a fraction
+ WHEN GIQL INTERSECTS with minimum overlap fraction of B is executed
+ THEN results match bedtools intersect -F exactly
+ """
+ conn = duckdb.connect(":memory:")
+ try:
+ load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a])
+ load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b])
+
+ inner_sql = transpile(
+ """
+ SELECT DISTINCT
+ a.chrom, a.start, a.end, a.name, a.score, a.strand,
+ b.start AS b_start, b.end AS b_end
+ FROM intervals_a a
+ JOIN intervals_b b ON a.interval INTERSECTS b.interval
+ """,
+ tables=["intervals_a", "intervals_b"],
+ )
+ sql = f"""
+ SELECT DISTINCT chrom, "start", "end", name, score, strand
+ FROM ({inner_sql})
+ WHERE (LEAST("end", b_end) - GREATEST("start", b_start))
+ >= {fraction} * (b_end - b_start)
+ """
+ giql_result = conn.execute(sql).fetchall()
+ finally:
+ conn.close()
+
+ bedtools_result = intersect(
+ [i.to_tuple() for i in intervals_a],
+ [i.to_tuple() for i in intervals_b],
+ fraction_b=fraction,
+ )
+
+ comparison = compare_results(giql_result, bedtools_result)
+ assert comparison.match, f"fraction_b={fraction}: {comparison.failure_message()}"
+
+
+# ---------------------------------------------------------------------------
+# -r (reciprocal overlap fraction)
+# ---------------------------------------------------------------------------
+
+
+@given(
+ intervals_a=unique_interval_list_st(max_size=20),
+ intervals_b=unique_interval_list_st(max_size=20),
+ fraction=st.sampled_from([0.1, 0.25, 0.5, 0.75]),
+)
+@settings(
+ max_examples=40,
+ deadline=None,
+ suppress_health_check=[HealthCheck.too_slow],
+)
+def test_reciprocal_fraction_matches_bedtools_r(intervals_a, intervals_b, fraction):
+ """
+ GIVEN two randomly generated sets of genomic intervals and a fraction
+ WHEN GIQL INTERSECTS with reciprocal overlap fraction is executed
+ THEN results match bedtools intersect -f -F -r exactly
+ """
+ conn = duckdb.connect(":memory:")
+ try:
+ load_intervals(conn, "intervals_a", [i.to_tuple() for i in intervals_a])
+ load_intervals(conn, "intervals_b", [i.to_tuple() for i in intervals_b])
+
+ inner_sql = transpile(
+ """
+ SELECT DISTINCT
+ a.chrom, a.start, a.end, a.name, a.score, a.strand,
+ b.start AS b_start, b.end AS b_end
+ FROM intervals_a a
+ JOIN intervals_b b ON a.interval INTERSECTS b.interval
+ """,
+ tables=["intervals_a", "intervals_b"],
+ )
+ sql = f"""
+ SELECT DISTINCT chrom, "start", "end", name, score, strand
+ FROM ({inner_sql})
+ WHERE (LEAST("end", b_end) - GREATEST("start", b_start))
+ >= {fraction} * ("end" - "start")
+ AND (LEAST("end", b_end) - GREATEST("start", b_start))
+ >= {fraction} * (b_end - b_start)
+ """
+ giql_result = conn.execute(sql).fetchall()
+ finally:
+ conn.close()
+
+    # -r applies the -f threshold reciprocally to both sides; the wrapper
+    # emits -wa rows, so deduplicate to match GIQL's SELECT DISTINCT.
+ bedtools_raw = intersect(
+ [i.to_tuple() for i in intervals_a],
+ [i.to_tuple() for i in intervals_b],
+ fraction_a=fraction,
+ reciprocal=True,
+ )
+ bedtools_result = list(set(bedtools_raw))
+
+ comparison = compare_results(giql_result, bedtools_result)
+ assert comparison.match, (
+ f"reciprocal fraction={fraction}: {comparison.failure_message()}"
+ )
diff --git a/tests/integration/bedtools/utils/bedtools_wrapper.py b/tests/integration/bedtools/utils/bedtools_wrapper.py
index c61be44..5a75304 100644
--- a/tests/integration/bedtools/utils/bedtools_wrapper.py
+++ b/tests/integration/bedtools/utils/bedtools_wrapper.py
@@ -30,19 +30,88 @@ def intersect(
intervals_a: list[tuple],
intervals_b: list[tuple],
strand_mode: str | None = None,
+ *,
+ loj: bool = False,
+ inverse: bool = False,
+ write_both: bool = False,
+ count: bool = False,
+ write_overlap: bool = False,
+ write_all_overlap: bool = False,
+ fraction_a: float | None = None,
+ fraction_b: float | None = None,
+ reciprocal: bool = False,
) -> list[tuple]:
- """Find overlapping intervals using bedtools intersect."""
+ """Find overlapping intervals using bedtools intersect.
+
+ Parameters
+ ----------
+ loj : bool
+ Left outer join mode (-loj).
+ inverse : bool
+ Report A entries with NO overlap in B (-v).
+ write_both : bool
+ Write both A and B entries for each overlap (-wa -wb).
+ count : bool
+ Count B overlaps for each A feature (-c).
+ write_overlap : bool
+ Write overlap amount in bp for each pair (-wo).
+ write_all_overlap : bool
+ Write overlap amount for all A features including
+ non-overlapping (-wao).
+ fraction_a : float or None
+ Minimum overlap as fraction of A (-f).
+ fraction_b : float or None
+ Minimum overlap as fraction of B (-F).
+ reciprocal : bool
+ Require fraction thresholds on both sides (-r).
+ """
try:
bt_a = create_bedtool(intervals_a)
bt_b = create_bedtool(intervals_b)
- kwargs = {"u": True}
+ kwargs: dict = {}
+ if loj:
+ kwargs["loj"] = True
+ elif inverse:
+ kwargs["v"] = True
+ elif write_both:
+ kwargs["wa"] = True
+ kwargs["wb"] = True
+ elif count:
+ kwargs["c"] = True
+ elif write_overlap:
+ kwargs["wo"] = True
+ elif write_all_overlap:
+ kwargs["wao"] = True
+ elif reciprocal:
+ kwargs["wa"] = True
+ else:
+ kwargs["u"] = True
+
if strand_mode == "same":
kwargs["s"] = True
elif strand_mode == "opposite":
kwargs["S"] = True
+ if fraction_a is not None:
+ kwargs["f"] = fraction_a
+ if fraction_b is not None and not reciprocal:
+ kwargs["F"] = fraction_b
+ if reciprocal:
+ kwargs["r"] = True
+
result = bt_a.intersect(bt_b, **kwargs)
+
+ if loj:
+ return bedtool_to_tuples(result, bed_format="loj")
+ if write_both:
+ return bedtool_to_tuples(result, bed_format="loj")
+ if count:
+ return bedtool_to_tuples(result, bed_format="count")
+ if write_overlap:
+ return bedtool_to_tuples(result, bed_format="wo")
+ if write_all_overlap:
+ return bedtool_to_tuples(result, bed_format="wo")
return bedtool_to_tuples(result)
except Exception as e:
@@ -102,7 +171,11 @@ def bedtool_to_tuples(
Args:
bedtool: pybedtools.BedTool object
- bed_format: Expected format ('bed3', 'bed6', or 'closest')
+ bed_format: Expected format ('bed3', 'bed6', 'loj', or 'closest')
+
+ LOJ format assumes BED6(A)+BED6(B) (12 fields):
+ Fields 0-5: A interval
+ Fields 6-11: B interval (all '.' / -1 when unmatched)
Closest format assumes BED6+BED6+distance (13 fields):
Fields 0-5: A interval (chrom, start, end, name, score, strand)
@@ -137,6 +210,68 @@ def bedtool_to_tuples(
)
)
+ elif bed_format == "count":
+ while len(fields) < 7:
+ fields.append("0")
+ rows.append(
+ (
+ fields[0],
+ int(fields[1]),
+ int(fields[2]),
+ fields[3] if fields[3] != "." else None,
+ int(fields[4]) if fields[4] != "." else None,
+ fields[5] if fields[5] != "." else None,
+ int(fields[6]),
+ )
+ )
+
+ elif bed_format == "wo":
+ if len(fields) < 13:
+ raise ValueError(f"Unexpected number of fields for wo: {len(fields)}")
+ rows.append(
+ (
+ fields[0],
+ int(fields[1]),
+ int(fields[2]),
+ fields[3] if fields[3] != "." else None,
+ int(fields[4]) if fields[4] != "." else None,
+ fields[5] if fields[5] != "." else None,
+ fields[6] if fields[6] != "." else None,
+ int(fields[7]) if fields[7] != "." else None,
+ int(fields[8]) if fields[8] != "." else None,
+ fields[9] if fields[9] != "." else None,
+ int(fields[10]) if fields[10] != "." else None,
+ fields[11] if fields[11] != "." else None,
+ int(fields[12]),
+ )
+ )
+
+ elif bed_format == "loj":
+ if len(fields) < 12:
+ raise ValueError(f"Unexpected number of fields for loj: {len(fields)}")
+
+ def _loj_field(val, as_int=False):
+ if val == "." or val == "-1":
+ return None
+ return int(val) if as_int else val
+
+ rows.append(
+ (
+ fields[0],
+ int(fields[1]),
+ int(fields[2]),
+ fields[3] if fields[3] != "." else None,
+ int(fields[4]) if fields[4] != "." else None,
+ fields[5] if fields[5] != "." else None,
+ _loj_field(fields[6]),
+ _loj_field(fields[7], as_int=True),
+ _loj_field(fields[8], as_int=True),
+ _loj_field(fields[9]),
+ _loj_field(fields[10], as_int=True),
+ _loj_field(fields[11]),
+ )
+ )
+
elif bed_format == "closest":
if len(fields) < 13:
raise ValueError(
diff --git a/tests/test_binned_join.py b/tests/test_binned_join.py
new file mode 100644
index 0000000..3c50757
--- /dev/null
+++ b/tests/test_binned_join.py
@@ -0,0 +1,1609 @@
+"""Tests for the INTERSECTS binned equi-join transpilation."""
+
+import math
+
+import pytest
+
+from giql import Table
+from giql import transpile
+
+
+def _is_null(value) -> bool:
+ """Check if a value is null/NaN (nullable int64 becomes NaN after to_pandas)."""
+ if value is None:
+ return True
+ try:
+ return math.isnan(value)
+ except (TypeError, ValueError):
+ return False
+
+
+class TestTranspileBinnedJoin:
+ """Unit tests for binned join SQL structure."""
+
+ def test_basic_binned_join_rewrite(self):
+ """
+ GIVEN a GIQL query joining two tables with column-to-column INTERSECTS
+ WHEN transpiling with default settings
+ THEN should produce CTEs with UNNEST/range, equi-join and overlap in ON,
+ and DISTINCT
+ """
+ sql = transpile(
+ """
+ SELECT a.*, b.*
+ FROM peaks a
+ JOIN genes b ON a.interval INTERSECTS b.interval
+ """,
+ tables=["peaks", "genes"],
+ )
+
+ sql_upper = sql.upper()
+
+ # CTEs with UNNEST and range
+ assert "WITH" in sql_upper
+ assert "UNNEST" in sql_upper
+ assert "range" in sql or "RANGE" in sql_upper
+ assert "__giql_bin" in sql
+
+ # Equi-join on chrom and bin
+ assert '"chrom"' in sql
+ assert "__giql_bin" in sql
+
+ # Overlap filter in ON (not WHERE) for correct outer-join semantics
+ assert "ON" in sql_upper
+ assert '"start"' in sql or '"START"' in sql_upper
+ assert '"end"' in sql or '"END"' in sql_upper
+
+ # DISTINCT to deduplicate across bins
+ assert "DISTINCT" in sql_upper
+
+ def test_custom_bin_size(self):
+ """
+ GIVEN a GIQL query with column-to-column INTERSECTS join
+ WHEN transpiling with bin_size=100000
+ THEN should use 100000 in the range expressions
+ """
+ sql = transpile(
+ """
+ SELECT a.*, b.*
+ FROM peaks a
+ JOIN genes b ON a.interval INTERSECTS b.interval
+ """,
+ tables=["peaks", "genes"],
+ bin_size=100000,
+ )
+
+ assert "100000" in sql
+
+ def test_custom_column_mappings(self):
+ """
+ GIVEN two tables with different custom column schemas
+ WHEN transpiling a binned join query
+ THEN should use each table's custom column names in CTEs, ON, and WHERE
+ """
+ sql = transpile(
+ """
+ SELECT a.*, b.*
+ FROM peaks a
+ JOIN features b ON a.interval INTERSECTS b.location
+ """,
+ tables=[
+ Table(
+ "peaks",
+ genomic_col="interval",
+ chrom_col="chromosome",
+ start_col="start_pos",
+ end_col="end_pos",
+ ),
+ Table(
+ "features",
+ genomic_col="location",
+ chrom_col="seqname",
+ start_col="begin",
+ end_col="terminus",
+ ),
+ ],
+ )
+
+ # Custom column names for peaks
+ assert '"chromosome"' in sql
+ assert '"start_pos"' in sql
+ assert '"end_pos"' in sql
+
+ # Custom column names for features
+ assert '"seqname"' in sql
+ assert '"begin"' in sql
+ assert '"terminus"' in sql
+
+ # Default column names should NOT appear
+ assert '"chrom"' not in sql
+ assert '"start"' not in sql
+ assert '"end"' not in sql
+
+ def test_literal_intersects_no_binned_ctes(self):
+ """
+ GIVEN a GIQL query with a literal-range INTERSECTS in WHERE (not a join)
+ WHEN transpiling
+ THEN should NOT produce binned CTEs
+ """
+ sql = transpile(
+ "SELECT * FROM peaks WHERE interval INTERSECTS 'chr1:1000-2000'",
+ tables=["peaks"],
+ )
+
+ assert "__giql_bin" not in sql
+ assert "UNNEST" not in sql.upper()
+
+ def test_no_join_passthrough(self):
+ """
+ GIVEN a simple SELECT query with no JOIN
+ WHEN transpiling
+ THEN should NOT produce binned CTEs
+ """
+ sql = transpile(
+ "SELECT * FROM peaks",
+ tables=["peaks"],
+ )
+
+ assert "__giql_bin" not in sql
+ assert "UNNEST" not in sql.upper()
+ assert "__giql_" not in sql
+
+ def test_existing_where_preserved(self):
+ """
+ GIVEN a GIQL join query that already has a WHERE clause
+ WHEN transpiling a binned join
+ THEN should preserve the original WHERE condition alongside the overlap filter
+ """
+ sql = transpile(
+ """
+ SELECT a.*, b.*
+ FROM peaks a
+ JOIN genes b ON a.interval INTERSECTS b.interval
+ WHERE a.score > 100
+ """,
+ tables=["peaks", "genes"],
+ )
+
+ sql_upper = sql.upper()
+
+ # Original WHERE condition preserved
+ assert "100" in sql
+ assert "score" in sql.lower()
+
+ # Overlap filter also present
+ assert "WHERE" in sql_upper
+ # Both conditions combined with AND
+ assert "AND" in sql_upper
+
+ def test_bin_size_none_defaults_to_10000(self):
+ """
+ GIVEN a GIQL join query
+ WHEN transpiling with bin_size=None (explicit)
+ THEN should produce the same output as default (10000)
+ """
+ sql_default = transpile(
+ """
+ SELECT a.*, b.*
+ FROM peaks a
+ JOIN genes b ON a.interval INTERSECTS b.interval
+ """,
+ tables=["peaks", "genes"],
+ )
+
+ sql_none = transpile(
+ """
+ SELECT a.*, b.*
+ FROM peaks a
+ JOIN genes b ON a.interval INTERSECTS b.interval
+ """,
+ tables=["peaks", "genes"],
+ bin_size=None,
+ )
+
+ assert sql_default == sql_none
+ assert "10000" in sql_default
+
+ def test_implicit_cross_join_uses_binned_optimization(self):
+ """
+ GIVEN a GIQL query with implicit cross-join (FROM a, b WHERE INTERSECTS)
+ WHEN transpiling
+ THEN should use the binned equi-join optimization without leaking
+ __giql_bin into SELECT * output columns
+ """
+ sql = transpile(
+ """
+ SELECT DISTINCT a.*
+ FROM peaks a, genes b
+ WHERE a.interval INTERSECTS b.interval
+ """,
+ tables=["peaks", "genes"],
+ )
+
+ # Binned CTEs are present
+ assert "WITH" in sql.upper()
+ assert "__giql_bin" in sql
+ assert "UNNEST" in sql.upper()
+
+ # Original table references preserved — no CTE leak into SELECT *
+ assert "peaks" in sql
+ assert '"chrom"' in sql
+ assert '"start"' in sql
+ assert '"end"' in sql
+
+ def test_self_join_single_shared_cte(self):
+ """
+ GIVEN a self-join query where the same table appears with two aliases
+ WHEN transpiling a binned join
+ THEN should produce one shared key-only CTE for the underlying table,
+ joined twice through distinct connector aliases
+ """
+ sql = transpile(
+ """
+ SELECT a.*, b.*
+ FROM peaks a
+ JOIN peaks b ON a.interval INTERSECTS b.interval
+ """,
+ tables=["peaks"],
+ )
+
+ sql_upper = sql.upper()
+
+ # One shared CTE keyed on the table name
+ assert "__giql_peaks_bins" in sql
+
+ # Original table preserved in FROM
+ assert "peaks" in sql
+
+ # Should still have DISTINCT
+ assert "DISTINCT" in sql_upper
+
+ def test_invalid_bin_size_raises(self):
+ """
+ GIVEN bin_size=0 or a negative value
+ WHEN calling transpile
+ THEN should raise ValueError
+ """
+ with pytest.raises(ValueError, match="positive"):
+ transpile(
+ "SELECT * FROM a JOIN b ON a.interval INTERSECTS b.interval",
+ tables=["a", "b"],
+ bin_size=0,
+ )
+
+ with pytest.raises(ValueError, match="positive"):
+ transpile(
+ "SELECT * FROM a JOIN b ON a.interval INTERSECTS b.interval",
+ tables=["a", "b"],
+ bin_size=-1,
+ )
+
+ def test_multi_join_all_intersects_rewritten(self):
+ """
+ GIVEN a three-way join with two INTERSECTS conditions
+ WHEN transpiling
+ THEN should create one key-only CTE per underlying table and rewrite
+ each INTERSECTS join as a three-join bridge through those CTEs
+ """
+ sql = transpile(
+ """
+ SELECT a.*, b.*, c.*
+ FROM peaks a
+ JOIN genes b ON a.interval INTERSECTS b.interval
+ JOIN exons c ON a.interval INTERSECTS c.interval
+ """,
+ tables=["peaks", "genes", "exons"],
+ )
+
+ # One CTE per underlying table
+ assert "__giql_peaks_bins" in sql
+ assert "__giql_genes_bins" in sql
+ assert "__giql_exons_bins" in sql
+
+ # __giql_bin appears in CTE definitions and ON conditions
+ sql_upper = sql.upper()
+ assert sql_upper.count("__GIQL_BIN") >= 4 # at least 2 per INTERSECTS join
+
+ def test_explicit_columns_uses_full_cte_not_bridge(self):
+ """
+ GIVEN a join query with only explicit columns in SELECT (no wildcards)
+ WHEN transpiling
+ THEN should use the 1-join full-CTE approach, not bridge CTEs
+ """
+ sql = transpile(
+ """
+ SELECT a.chrom, a.start, b.start AS b_start
+ FROM peaks a
+ JOIN genes b ON a.interval INTERSECTS b.interval
+ """,
+ tables=["peaks", "genes"],
+ )
+
+ # Full-CTE names (alias-based, one join per INTERSECTS)
+ assert "__giql_a_binned" in sql
+ assert "__giql_b_binned" in sql
+
+ # Bridge CTEs must NOT be present
+ assert "__giql_peaks_bins" not in sql
+ assert "__giql_c0" not in sql
+
+ def test_wildcard_select_uses_bridge_not_full_cte(self):
+ """
+ GIVEN a join query with a wildcard expression in SELECT (a.*)
+ WHEN transpiling
+ THEN should use the bridge CTE approach, not full-CTEs
+ """
+ sql = transpile(
+ """
+ SELECT a.*, b.*
+ FROM peaks a
+ JOIN genes b ON a.interval INTERSECTS b.interval
+ """,
+ tables=["peaks", "genes"],
+ )
+
+ # Bridge CTE names (table-based)
+ assert "__giql_peaks_bins" in sql
+ assert "__giql_genes_bins" in sql
+
+ # Full CTEs must NOT be present
+ assert "__giql_a_binned" not in sql
+ assert "__giql_b_binned" not in sql
+
+
+class TestBinnedJoinDataFusion:
+ """End-to-end DataFusion correctness tests for binned INTERSECTS joins."""
+
+ @staticmethod
+ def _make_ctx(peaks_data, genes_data):
+ """Create a DataFusion context with two interval tables."""
+ import pyarrow as pa
+ from datafusion import SessionContext
+
+ schema = pa.schema(
+ [
+ ("chrom", pa.utf8()),
+ ("start", pa.int64()),
+ ("end", pa.int64()),
+ ]
+ )
+
+ ctx = SessionContext()
+
+ peaks_arrays = {
+ "chrom": [r[0] for r in peaks_data],
+ "start": [r[1] for r in peaks_data],
+ "end": [r[2] for r in peaks_data],
+ }
+ genes_arrays = {
+ "chrom": [r[0] for r in genes_data],
+ "start": [r[1] for r in genes_data],
+ "end": [r[2] for r in genes_data],
+ }
+
+ ctx.register_record_batches(
+ "peaks", [pa.table(peaks_arrays, schema=schema).to_batches()]
+ )
+ ctx.register_record_batches(
+ "genes", [pa.table(genes_arrays, schema=schema).to_batches()]
+ )
+ return ctx
+
+ def test_overlapping_intervals_correct_rows_no_duplicates(self):
+ """
+ GIVEN two tables with overlapping intervals
+ WHEN executing a binned INTERSECTS join via DataFusion
+ THEN should return the correct matching rows with no duplicates
+ """
+ ctx = self._make_ctx(
+ peaks_data=[("chr1", 100, 500), ("chr1", 1000, 2000)],
+ genes_data=[("chr1", 300, 600), ("chr1", 5000, 6000)],
+ )
+
+ sql = transpile(
+ """
+ SELECT a.chrom, a.start, a."end", b.start AS b_start, b."end" AS b_end
+ FROM peaks a
+ JOIN genes b ON a.interval INTERSECTS b.interval
+ """,
+ tables=[
+ Table("peaks", chrom_col="chrom", start_col="start", end_col="end"),
+ Table("genes", chrom_col="chrom", start_col="start", end_col="end"),
+ ],
+ )
+
+ df = ctx.sql(sql).to_pandas()
+
+ # Only the first peak (100-500) overlaps the first gene (300-600)
+ assert len(df) == 1
+ assert df.iloc[0]["start"] == 100
+ assert df.iloc[0]["b_start"] == 300
+
+ def test_non_overlapping_intervals_zero_rows(self):
+ """
+ GIVEN two tables with no overlapping intervals
+ WHEN executing a binned INTERSECTS join via DataFusion
+ THEN should return zero rows
+ """
+ ctx = self._make_ctx(
+ peaks_data=[("chr1", 100, 200), ("chr1", 300, 400)],
+ genes_data=[("chr1", 500, 600), ("chr1", 700, 800)],
+ )
+
+ sql = transpile(
+ """
+ SELECT a.chrom, a.start, a."end"
+ FROM peaks a
+ JOIN genes b ON a.interval INTERSECTS b.interval
+ """,
+ tables=[
+ Table("peaks", chrom_col="chrom", start_col="start", end_col="end"),
+ Table("genes", chrom_col="chrom", start_col="start", end_col="end"),
+ ],
+ )
+
+ df = ctx.sql(sql).to_pandas()
+ assert len(df) == 0
+
+ def test_adjacent_intervals_zero_rows_half_open(self):
+ """
+ GIVEN two tables with adjacent (touching) intervals under half-open coordinates
+ WHEN executing a binned INTERSECTS join via DataFusion
+ THEN should return zero rows because [100, 200) and [200, 300) do not overlap
+ """
+ ctx = self._make_ctx(
+ peaks_data=[("chr1", 100, 200)],
+ genes_data=[("chr1", 200, 300)],
+ )
+
+ sql = transpile(
+ """
+ SELECT a.chrom, a.start, a."end"
+ FROM peaks a
+ JOIN genes b ON a.interval INTERSECTS b.interval
+ """,
+ tables=[
+ Table("peaks", chrom_col="chrom", start_col="start", end_col="end"),
+ Table("genes", chrom_col="chrom", start_col="start", end_col="end"),
+ ],
+ )
+
+ df = ctx.sql(sql).to_pandas()
+ assert len(df) == 0
+
+ def test_different_chromosomes_only_same_chrom(self):
+ """
+ GIVEN two tables with intervals on different chromosomes that would overlap positionally
+ WHEN executing a binned INTERSECTS join via DataFusion
+ THEN should only return overlaps on the same chromosome
+ """
+ ctx = self._make_ctx(
+ peaks_data=[("chr1", 100, 500), ("chr2", 100, 500)],
+ genes_data=[("chr1", 200, 400), ("chr3", 200, 400)],
+ )
+
+ sql = transpile(
+ """
+ SELECT a.chrom, a.start, a."end",
+ b.chrom AS b_chrom, b.start AS b_start, b."end" AS b_end
+ FROM peaks a
+ JOIN genes b ON a.interval INTERSECTS b.interval
+ """,
+ tables=[
+ Table("peaks", chrom_col="chrom", start_col="start", end_col="end"),
+ Table("genes", chrom_col="chrom", start_col="start", end_col="end"),
+ ],
+ )
+
+ df = ctx.sql(sql).to_pandas()
+
+ assert len(df) == 1
+ assert df.iloc[0]["chrom"] == "chr1"
+ assert df.iloc[0]["b_chrom"] == "chr1"
+
+ def test_intervals_spanning_multiple_bins_no_duplicates(self):
+ """
+ GIVEN intervals that span multiple bins
+ WHEN executing a binned INTERSECTS join via DataFusion
+ THEN overlapping pairs should be returned exactly once (DISTINCT dedup)
+ """
+ ctx = self._make_ctx(
+ peaks_data=[("chr1", 0, 50000)],
+ genes_data=[("chr1", 25000, 75000)],
+ )
+
+ sql = transpile(
+ """
+ SELECT a.chrom, a.start, a."end", b.start AS b_start, b."end" AS b_end
+ FROM peaks a
+ JOIN genes b ON a.interval INTERSECTS b.interval
+ """,
+ tables=[
+ Table("peaks", chrom_col="chrom", start_col="start", end_col="end"),
+ Table("genes", chrom_col="chrom", start_col="start", end_col="end"),
+ ],
+ bin_size=10000,
+ )
+
+ df = ctx.sql(sql).to_pandas()
+
+ # Despite sharing bins 2, 3, 4 (bin_size=10000), the pair appears exactly once
+ assert len(df) == 1
+ assert df.iloc[0]["start"] == 0
+ assert df.iloc[0]["end"] == 50000
+ assert df.iloc[0]["b_start"] == 25000
+ assert df.iloc[0]["b_end"] == 75000
+
+ def test_equivalence_with_naive_cross_join(self):
+ """
+ GIVEN two tables with a mix of overlapping and non-overlapping intervals
+ WHEN executing a binned INTERSECTS join via DataFusion
+ THEN results should match a naive cross-join with overlap filter
+ """
+ ctx = self._make_ctx(
+ peaks_data=[
+ ("chr1", 0, 100),
+ ("chr1", 150, 300),
+ ("chr1", 500, 1000),
+ ("chr2", 0, 500),
+ ],
+ genes_data=[
+ ("chr1", 50, 200),
+ ("chr1", 250, 600),
+ ("chr1", 900, 1100),
+ ("chr2", 400, 800),
+ ],
+ )
+
+ naive_sql = """
+ SELECT a.chrom AS a_chrom, a.start AS a_start, a."end" AS a_end,
+ b.chrom AS b_chrom, b.start AS b_start, b."end" AS b_end
+ FROM peaks a, genes b
+ WHERE a.chrom = b.chrom
+ AND a.start < b."end"
+ AND a."end" > b.start
+ ORDER BY a.chrom, a.start, b.start
+ """
+ naive_df = ctx.sql(naive_sql).to_pandas()
+
+ binned_sql = transpile(
+ """
+ SELECT a.chrom AS a_chrom, a.start AS a_start, a."end" AS a_end,
+ b.chrom AS b_chrom, b.start AS b_start, b."end" AS b_end
+ FROM peaks a
+ JOIN genes b ON a.interval INTERSECTS b.interval
+ """,
+ tables=[
+ Table("peaks", chrom_col="chrom", start_col="start", end_col="end"),
+ Table("genes", chrom_col="chrom", start_col="start", end_col="end"),
+ ],
+ )
+ binned_df = (
+ ctx.sql(binned_sql)
+ .to_pandas()
+ .sort_values(by=["a_chrom", "a_start", "b_start"])
+ .reset_index(drop=True)
+ )
+ naive_df = naive_df.reset_index(drop=True)
+
+ assert len(binned_df) == len(naive_df)
+ assert binned_df.values.tolist() == naive_df.values.tolist()
+
+ def test_implicit_cross_join_correct_rows_no_bin_leak(self):
+ """
+ GIVEN two tables with overlapping intervals queried via implicit cross-join syntax
+ WHEN executing a binned INTERSECTS join via DataFusion
+ THEN results should be correct and SELECT a.* should not include __giql_bin
+ """
+ ctx = self._make_ctx(
+ peaks_data=[("chr1", 100, 500), ("chr1", 1000, 2000)],
+ genes_data=[("chr1", 300, 600), ("chr1", 5000, 6000)],
+ )
+
+ sql = transpile(
+ """
+ SELECT DISTINCT a.*
+ FROM peaks a, genes b
+ WHERE a.interval INTERSECTS b.interval
+ """,
+ tables=[
+ Table("peaks", chrom_col="chrom", start_col="start", end_col="end"),
+ Table("genes", chrom_col="chrom", start_col="start", end_col="end"),
+ ],
+ )
+
+ df = ctx.sql(sql).to_pandas()
+
+ # Only the first peak (100-500) overlaps the first gene (300-600)
+ assert len(df) == 1
+ assert df.iloc[0]["start"] == 100
+
+ # SELECT a.* must return exactly the original table columns — no __giql_bin
+ assert list(df.columns) == ["chrom", "start", "end"]
+
+
+class TestBinnedJoinOuterJoinSemantics:
+ """Regression tests: outer join kinds must be preserved after rewrite.
+
+ Bug: the bridge path only applied the join kind (LEFT, RIGHT, FULL) to
+ join3, while join1 and join2 were always INNER — silently converting
+ outer joins into inner joins.
+ """
+
+ @staticmethod
+ def _make_ctx(peaks_data, genes_data):
+ """Create a DataFusion context with peaks and genes tables."""
+ import pyarrow as pa
+ from datafusion import SessionContext
+
+ schema = pa.schema(
+ [
+ ("chrom", pa.utf8()),
+ ("start", pa.int64()),
+ ("end", pa.int64()),
+ ]
+ )
+ ctx = SessionContext()
+ ctx.register_record_batches(
+ "peaks",
+ [
+ pa.table(
+ {
+ "chrom": [r[0] for r in peaks_data],
+ "start": [r[1] for r in peaks_data],
+ "end": [r[2] for r in peaks_data],
+ },
+ schema=schema,
+ ).to_batches()
+ ],
+ )
+ ctx.register_record_batches(
+ "genes",
+ [
+ pa.table(
+ {
+ "chrom": [r[0] for r in genes_data],
+ "start": [r[1] for r in genes_data],
+ "end": [r[2] for r in genes_data],
+ },
+ schema=schema,
+ ).to_batches()
+ ],
+ )
+ return ctx
+
+ def test_left_join_preserves_unmatched_left_rows_full_cte(self):
+ """
+ GIVEN peaks with one matching and one non-matching interval
+ WHEN a LEFT JOIN with INTERSECTS is transpiled (no wildcards, full-CTE path)
+ THEN the SQL must contain LEFT keyword and execution must return all
+ left rows including unmatched ones with NULLs on the right
+ """
+ sql = transpile(
+ """
+ SELECT a.chrom, a.start, a."end", b.start AS b_start
+ FROM peaks a
+ LEFT JOIN genes b ON a.interval INTERSECTS b.interval
+ """,
+ tables=[
+ Table("peaks", chrom_col="chrom", start_col="start", end_col="end"),
+ Table("genes", chrom_col="chrom", start_col="start", end_col="end"),
+ ],
+ )
+
+ assert "LEFT" in sql.upper()
+
+ ctx = self._make_ctx(
+ peaks_data=[("chr1", 100, 500), ("chr1", 1000, 2000)],
+ genes_data=[("chr1", 300, 600)],
+ )
+ df = ctx.sql(sql).to_pandas().sort_values("start").reset_index(drop=True)
+
+ assert len(df) == 2
+ assert df.iloc[0]["start"] == 100
+ assert df.iloc[0]["b_start"] == 300
+ assert df.iloc[1]["start"] == 1000
+ assert _is_null(df.iloc[1]["b_start"])
+
+ def test_left_join_preserves_unmatched_left_rows_bridge(self):
+ """
+ GIVEN peaks with one matching and one non-matching interval
+ WHEN a LEFT JOIN with INTERSECTS is transpiled (wildcards, bridge path)
+ THEN the SQL must contain LEFT keyword and execution must return all
+ left rows including unmatched ones with NULLs on the right
+ """
+ sql = transpile(
+ """
+ SELECT a.*, b.start AS b_start
+ FROM peaks a
+ LEFT JOIN genes b ON a.interval INTERSECTS b.interval
+ """,
+ tables=[
+ Table("peaks", chrom_col="chrom", start_col="start", end_col="end"),
+ Table("genes", chrom_col="chrom", start_col="start", end_col="end"),
+ ],
+ )
+
+ assert "LEFT" in sql.upper()
+
+ ctx = self._make_ctx(
+ peaks_data=[("chr1", 100, 500), ("chr1", 1000, 2000)],
+ genes_data=[("chr1", 300, 600)],
+ )
+ df = ctx.sql(sql).to_pandas().sort_values("start").reset_index(drop=True)
+
+ assert len(df) == 2
+ assert df.iloc[0]["start"] == 100
+ assert df.iloc[0]["b_start"] == 300
+ assert df.iloc[1]["start"] == 1000
+ assert _is_null(df.iloc[1]["b_start"])
+
+ def test_right_join_preserves_unmatched_right_rows_full_cte(self):
+ """
+ GIVEN genes with one matching and one non-matching interval
+ WHEN a RIGHT JOIN with INTERSECTS is transpiled (no wildcards, full-CTE path)
+ THEN the SQL must contain RIGHT keyword and execution must return all
+ right rows including unmatched ones with NULLs on the left
+ """
+ sql = transpile(
+ """
+ SELECT a.start AS a_start, b.chrom, b.start, b."end"
+ FROM peaks a
+ RIGHT JOIN genes b ON a.interval INTERSECTS b.interval
+ """,
+ tables=[
+ Table("peaks", chrom_col="chrom", start_col="start", end_col="end"),
+ Table("genes", chrom_col="chrom", start_col="start", end_col="end"),
+ ],
+ )
+
+ assert "RIGHT" in sql.upper()
+
+ ctx = self._make_ctx(
+ peaks_data=[("chr1", 100, 500)],
+ genes_data=[("chr1", 300, 600), ("chr1", 5000, 6000)],
+ )
+ df = ctx.sql(sql).to_pandas().sort_values("start").reset_index(drop=True)
+
+ assert len(df) == 2
+ matched = df[df["a_start"].notna()]
+ unmatched = df[df["a_start"].isna()]
+ assert len(matched) == 1
+ assert matched.iloc[0]["start"] == 300
+ assert len(unmatched) == 1
+ assert unmatched.iloc[0]["start"] == 5000
+
+ def test_right_join_preserves_unmatched_right_rows_bridge(self):
+ """
+ GIVEN genes with one matching and one non-matching interval
+ WHEN a RIGHT JOIN with INTERSECTS is transpiled (wildcards, bridge path)
+ THEN the SQL must contain RIGHT keyword and execution must return all
+ right rows including unmatched ones with NULLs on the left
+ """
+ sql = transpile(
+ """
+ SELECT a.start AS a_start, b.*
+ FROM peaks a
+ RIGHT JOIN genes b ON a.interval INTERSECTS b.interval
+ """,
+ tables=[
+ Table("peaks", chrom_col="chrom", start_col="start", end_col="end"),
+ Table("genes", chrom_col="chrom", start_col="start", end_col="end"),
+ ],
+ )
+
+ assert "RIGHT" in sql.upper()
+
+ ctx = self._make_ctx(
+ peaks_data=[("chr1", 100, 500)],
+ genes_data=[("chr1", 300, 600), ("chr1", 5000, 6000)],
+ )
+ df = ctx.sql(sql).to_pandas().sort_values("start").reset_index(drop=True)
+
+ assert len(df) == 2
+ matched = df[df["a_start"].notna()]
+ unmatched = df[df["a_start"].isna()]
+ assert len(matched) == 1
+ assert matched.iloc[0]["start"] == 300
+ assert len(unmatched) == 1
+ assert unmatched.iloc[0]["start"] == 5000
+
+ def test_full_outer_join_preserves_both_unmatched_full_cte(self):
+ """
+ GIVEN peaks and genes each with one matching and one non-matching interval
+ WHEN a FULL OUTER JOIN with INTERSECTS is transpiled (no wildcards, full-CTE)
+ THEN the SQL must contain FULL keyword and execution must return three
+ rows: one matched pair plus one unmatched from each side
+ """
+ sql = transpile(
+ """
+ SELECT a.start AS a_start, a."end" AS a_end,
+ b.start AS b_start, b."end" AS b_end
+ FROM peaks a
+ FULL OUTER JOIN genes b ON a.interval INTERSECTS b.interval
+ """,
+ tables=[
+ Table("peaks", chrom_col="chrom", start_col="start", end_col="end"),
+ Table("genes", chrom_col="chrom", start_col="start", end_col="end"),
+ ],
+ )
+
+ assert "FULL" in sql.upper()
+
+ ctx = self._make_ctx(
+ peaks_data=[("chr1", 100, 500), ("chr1", 8000, 9000)],
+ genes_data=[("chr1", 300, 600), ("chr1", 5000, 6000)],
+ )
+ df = ctx.sql(sql).to_pandas()
+
+ assert len(df) == 3
+ matched = df[df["a_start"].notna() & df["b_start"].notna()]
+ left_only = df[df["a_start"].notna() & df["b_start"].isna()]
+ right_only = df[df["a_start"].isna() & df["b_start"].notna()]
+ assert len(matched) == 1
+ assert len(left_only) == 1
+ assert len(right_only) == 1
+
+ def test_full_outer_join_preserves_both_unmatched_bridge(self):
+ """
+ GIVEN peaks and genes each with one matching and one non-matching interval
+ WHEN a FULL OUTER JOIN with INTERSECTS is transpiled (wildcards, bridge path)
+ THEN the SQL must contain FULL keyword and execution must return three
+ rows: one matched pair plus one unmatched from each side
+ """
+ sql = transpile(
+ """
+ SELECT a.*, b.start AS b_start, b."end" AS b_end
+ FROM peaks a
+ FULL OUTER JOIN genes b ON a.interval INTERSECTS b.interval
+ """,
+ tables=[
+ Table("peaks", chrom_col="chrom", start_col="start", end_col="end"),
+ Table("genes", chrom_col="chrom", start_col="start", end_col="end"),
+ ],
+ )
+
+ assert "FULL" in sql.upper()
+
+ ctx = self._make_ctx(
+ peaks_data=[("chr1", 100, 500), ("chr1", 8000, 9000)],
+ genes_data=[("chr1", 300, 600), ("chr1", 5000, 6000)],
+ )
+ df = ctx.sql(sql).to_pandas()
+
+ assert len(df) == 3
+ matched = df[df["start"].notna() & df["b_start"].notna()]
+ left_only = df[df["start"].notna() & df["b_start"].isna()]
+ right_only = df[df["start"].isna() & df["b_start"].notna()]
+ assert len(matched) == 1
+ assert len(left_only) == 1
+ assert len(right_only) == 1
+
+ def test_left_join_all_unmatched_returns_all_left_rows(self):
+ """
+ GIVEN peaks where no intervals overlap any gene
+ WHEN a LEFT JOIN with INTERSECTS is transpiled
+ THEN all left rows must still appear with NULLs on the right
+ """
+ sql = transpile(
+ """
+ SELECT a.chrom, a.start, a."end", b.start AS b_start
+ FROM peaks a
+ LEFT JOIN genes b ON a.interval INTERSECTS b.interval
+ """,
+ tables=[
+ Table("peaks", chrom_col="chrom", start_col="start", end_col="end"),
+ Table("genes", chrom_col="chrom", start_col="start", end_col="end"),
+ ],
+ )
+
+ ctx = self._make_ctx(
+ peaks_data=[("chr1", 100, 200), ("chr1", 300, 400)],
+ genes_data=[("chr1", 500, 600)],
+ )
+ df = ctx.sql(sql).to_pandas()
+
+ assert len(df) == 2
+ assert df["b_start"].isna().all()
+
+
+class TestBinnedJoinAdditionalOnConditions:
+ """Regression tests: non-INTERSECTS conditions in ON must be preserved.
+
+ Bug: the rewrite replaces the entire ON clause with the binned equi-join
+ and overlap predicate, silently dropping any additional user conditions
+ like ``AND a.score > b.score``.
+ """
+
+ @staticmethod
+ def _make_ctx_with_score():
+ """Create a DataFusion context with peaks and genes tables that include a score column."""
+ import pyarrow as pa
+ from datafusion import SessionContext
+
+ schema = pa.schema(
+ [
+ ("chrom", pa.utf8()),
+ ("start", pa.int64()),
+ ("end", pa.int64()),
+ ("score", pa.int64()),
+ ]
+ )
+ ctx = SessionContext()
+ ctx.register_record_batches(
+ "peaks",
+ [
+ pa.table(
+ {
+ "chrom": ["chr1", "chr1"],
+ "start": [100, 100],
+ "end": [500, 500],
+ "score": [10, 50],
+ },
+ schema=schema,
+ ).to_batches()
+ ],
+ )
+ ctx.register_record_batches(
+ "genes",
+ [
+ pa.table(
+ {
+ "chrom": ["chr1", "chr1"],
+ "start": [200, 200],
+ "end": [600, 600],
+ "score": [30, 30],
+ },
+ schema=schema,
+ ).to_batches()
+ ],
+ )
+ return ctx
+
+ def test_additional_on_condition_preserved_full_cte(self):
+ """
+ GIVEN two overlapping intervals where only one pair satisfies score filter
+ WHEN INTERSECTS is combined with a.score > b.score in ON (no wildcards)
+ THEN the additional condition must survive the rewrite and filter results
+ """
+ sql = transpile(
+ """
+ SELECT a.chrom, a.start, a."end", a.score AS a_score, b.score AS b_score
+ FROM peaks a
+ JOIN genes b ON a.interval INTERSECTS b.interval AND a.score > b.score
+ """,
+ tables=[
+ Table("peaks", chrom_col="chrom", start_col="start", end_col="end"),
+ Table("genes", chrom_col="chrom", start_col="start", end_col="end"),
+ ],
+ )
+
+ assert "score" in sql.lower()
+
+ ctx = self._make_ctx_with_score()
+ df = ctx.sql(sql).to_pandas()
+
+ assert len(df) == 1
+ assert df.iloc[0]["a_score"] == 50
+
+ def test_additional_on_condition_preserved_bridge(self):
+ """
+ GIVEN two overlapping intervals where only one pair satisfies score filter
+ WHEN INTERSECTS is combined with a.score > b.score in ON (wildcards)
+ THEN the additional condition must survive the rewrite and filter results
+ """
+ sql = transpile(
+ """
+ SELECT a.*, b.score AS b_score
+ FROM peaks a
+ JOIN genes b ON a.interval INTERSECTS b.interval AND a.score > b.score
+ """,
+ tables=[
+ Table("peaks", chrom_col="chrom", start_col="start", end_col="end"),
+ Table("genes", chrom_col="chrom", start_col="start", end_col="end"),
+ ],
+ )
+
+ assert "score" in sql.lower()
+
+ ctx = self._make_ctx_with_score()
+ df = ctx.sql(sql).to_pandas()
+
+ assert len(df) == 1
+ assert df.iloc[0]["score"] == 50
+
+ def test_additional_on_condition_with_left_join(self):
+ """
+ GIVEN overlapping intervals with an extra ON condition that filters all matches
+ WHEN LEFT JOIN with INTERSECTS AND a.score > b.score is used
+ THEN unmatched left rows appear with NULL right columns
+ """
+ sql = transpile(
+ """
+ SELECT a.chrom, a.start, a.score AS a_score, b.score AS b_score
+ FROM peaks a
+ LEFT JOIN genes b
+ ON a.interval INTERSECTS b.interval AND a.score > b.score
+ """,
+ tables=[
+ Table("peaks", chrom_col="chrom", start_col="start", end_col="end"),
+ Table("genes", chrom_col="chrom", start_col="start", end_col="end"),
+ ],
+ )
+
+ ctx = self._make_ctx_with_score()
+ df = ctx.sql(sql).to_pandas().sort_values("a_score").reset_index(drop=True)
+
+ assert len(df) == 2
+ row_low = df[df["a_score"] == 10].iloc[0]
+ row_high = df[df["a_score"] == 50].iloc[0]
+ assert _is_null(row_low["b_score"])
+ assert row_high["b_score"] == 30
+
+ def test_multiple_additional_conditions_preserved(self):
+ """
+ GIVEN overlapping intervals with two extra ON conditions
+ WHEN INTERSECTS is combined with a.score > 20 AND b.score < 40 in ON
+ THEN both conditions must survive the rewrite
+ """
+ sql = transpile(
+ """
+ SELECT a.chrom, a.start, a.score AS a_score, b.score AS b_score
+ FROM peaks a
+ JOIN genes b
+ ON a.interval INTERSECTS b.interval
+ AND a.score > 20
+ AND b.score < 40
+ """,
+ tables=[
+ Table("peaks", chrom_col="chrom", start_col="start", end_col="end"),
+ Table("genes", chrom_col="chrom", start_col="start", end_col="end"),
+ ],
+ )
+
+ sql_lower = sql.lower()
+ assert "score" in sql_lower
+ assert "20" in sql
+ assert "40" in sql
+
+ def test_additional_on_condition_implicit_cross_join(self):
+ """
+ GIVEN overlapping intervals queried via implicit cross-join with extra WHERE
+ WHEN INTERSECTS is in WHERE alongside a.score > b.score
+ THEN the score condition must be preserved in the output SQL
+ """
+ sql = transpile(
+ """
+ SELECT a.chrom, a.start, a.score AS a_score, b.score AS b_score
+ FROM peaks a, genes b
+ WHERE a.interval INTERSECTS b.interval AND a.score > b.score
+ """,
+ tables=[
+ Table("peaks", chrom_col="chrom", start_col="start", end_col="end"),
+ Table("genes", chrom_col="chrom", start_col="start", end_col="end"),
+ ],
+ )
+
+ assert "score" in sql.lower()
+
+ ctx = self._make_ctx_with_score()
+ df = ctx.sql(sql).to_pandas()
+
+ assert len(df) == 1
+ assert df.iloc[0]["a_score"] == 50
+
+
+class TestBinnedJoinDistinctSemantics:
+ """Regression tests: unconditional DISTINCT can collapse legitimate duplicates.
+
+ Bug: the transformer always adds DISTINCT to deduplicate bin fan-out,
+ but this also collapses rows that are genuinely duplicated in the source
+ data, changing SQL bag semantics.
+ """
+
+ @staticmethod
+ def _make_ctx_with_duplicates():
+ """Create a DataFusion context where peaks has duplicate rows."""
+ import pyarrow as pa
+ from datafusion import SessionContext
+
+ schema = pa.schema(
+ [
+ ("chrom", pa.utf8()),
+ ("start", pa.int64()),
+ ("end", pa.int64()),
+ ]
+ )
+ ctx = SessionContext()
+ ctx.register_record_batches(
+ "peaks",
+ [
+ pa.table(
+ {
+ "chrom": ["chr1", "chr1"],
+ "start": [100, 100],
+ "end": [500, 500],
+ },
+ schema=schema,
+ ).to_batches()
+ ],
+ )
+ ctx.register_record_batches(
+ "genes",
+ [
+ pa.table(
+ {
+ "chrom": ["chr1"],
+ "start": [200],
+ "end": [600],
+ },
+ schema=schema,
+ ).to_batches()
+ ],
+ )
+ return ctx
+
+ @pytest.mark.xfail(
+ reason="Unconditional DISTINCT collapses legitimate duplicate rows",
+ strict=True,
+ )
+ def test_duplicate_rows_preserved_full_cte(self):
+ """
+ GIVEN peaks with two identical rows that both overlap one gene
+ WHEN an inner join with INTERSECTS is transpiled (no wildcards, full-CTE)
+ THEN both rows should be returned, matching naive cross-join behavior
+ """
+ ctx = self._make_ctx_with_duplicates()
+
+ naive_sql = """
+ SELECT a.chrom, a.start, a."end", b.start AS b_start
+ FROM peaks a, genes b
+ WHERE a.chrom = b.chrom AND a.start < b."end" AND a."end" > b.start
+ """
+ naive_df = ctx.sql(naive_sql).to_pandas()
+ assert len(naive_df) == 2
+
+ binned_sql = transpile(
+ """
+ SELECT a.chrom, a.start, a."end", b.start AS b_start
+ FROM peaks a
+ JOIN genes b ON a.interval INTERSECTS b.interval
+ """,
+ tables=[
+ Table("peaks", chrom_col="chrom", start_col="start", end_col="end"),
+ Table("genes", chrom_col="chrom", start_col="start", end_col="end"),
+ ],
+ )
+
+ binned_df = ctx.sql(binned_sql).to_pandas()
+ assert len(binned_df) == len(naive_df)
+
+ @pytest.mark.xfail(
+ reason="Unconditional DISTINCT collapses legitimate duplicate rows",
+ strict=True,
+ )
+ def test_duplicate_rows_preserved_bridge(self):
+ """
+ GIVEN peaks with two identical rows that both overlap one gene
+ WHEN an inner join with INTERSECTS is transpiled (wildcards, bridge path)
+ THEN both rows should be returned, matching naive cross-join behavior
+ """
+ ctx = self._make_ctx_with_duplicates()
+
+ naive_sql = """
+ SELECT a.chrom, a.start, a."end", b.start AS b_start
+ FROM peaks a, genes b
+ WHERE a.chrom = b.chrom AND a.start < b."end" AND a."end" > b.start
+ """
+ naive_df = ctx.sql(naive_sql).to_pandas()
+ assert len(naive_df) == 2
+
+ binned_sql = transpile(
+ """
+ SELECT a.*, b.start AS b_start
+ FROM peaks a
+ JOIN genes b ON a.interval INTERSECTS b.interval
+ """,
+ tables=[
+ Table("peaks", chrom_col="chrom", start_col="start", end_col="end"),
+ Table("genes", chrom_col="chrom", start_col="start", end_col="end"),
+ ],
+ )
+
+ binned_df = ctx.sql(binned_sql).to_pandas()
+ assert len(binned_df) == len(naive_df)
+
+ def test_non_duplicate_rows_unaffected(self):
+ """
+ GIVEN peaks with two distinct rows that both overlap one gene
+ WHEN an inner join with INTERSECTS is transpiled
+ THEN DISTINCT does not collapse them because they differ
+ """
+ import pyarrow as pa
+ from datafusion import SessionContext
+
+ schema = pa.schema(
+ [
+ ("chrom", pa.utf8()),
+ ("start", pa.int64()),
+ ("end", pa.int64()),
+ ]
+ )
+ ctx = SessionContext()
+ ctx.register_record_batches(
+ "peaks",
+ [
+ pa.table(
+ {
+ "chrom": ["chr1", "chr1"],
+ "start": [100, 150],
+ "end": [500, 550],
+ },
+ schema=schema,
+ ).to_batches()
+ ],
+ )
+ ctx.register_record_batches(
+ "genes",
+ [
+ pa.table(
+ {
+ "chrom": ["chr1"],
+ "start": [200],
+ "end": [600],
+ },
+ schema=schema,
+ ).to_batches()
+ ],
+ )
+
+ binned_sql = transpile(
+ """
+ SELECT a.chrom, a.start, a."end", b.start AS b_start
+ FROM peaks a
+ JOIN genes b ON a.interval INTERSECTS b.interval
+ """,
+ tables=[
+ Table("peaks", chrom_col="chrom", start_col="start", end_col="end"),
+ Table("genes", chrom_col="chrom", start_col="start", end_col="end"),
+ ],
+ )
+
+ df = ctx.sql(binned_sql).to_pandas()
+ assert len(df) == 2
+
+ def test_user_distinct_already_present_still_works(self):
+ """
+ GIVEN a query that already has SELECT DISTINCT
+ WHEN the binned join rewrite also adds DISTINCT
+ THEN the query must still execute correctly (no double-DISTINCT error)
+ """
+ import pyarrow as pa
+ from datafusion import SessionContext
+
+ schema = pa.schema(
+ [
+ ("chrom", pa.utf8()),
+ ("start", pa.int64()),
+ ("end", pa.int64()),
+ ]
+ )
+ ctx = SessionContext()
+ ctx.register_record_batches(
+ "peaks",
+ [
+ pa.table(
+ {"chrom": ["chr1"], "start": [100], "end": [500]},
+ schema=schema,
+ ).to_batches()
+ ],
+ )
+ ctx.register_record_batches(
+ "genes",
+ [
+ pa.table(
+ {"chrom": ["chr1"], "start": [200], "end": [600]},
+ schema=schema,
+ ).to_batches()
+ ],
+ )
+
+ binned_sql = transpile(
+ """
+ SELECT DISTINCT a.chrom, a.start, b.start AS b_start
+ FROM peaks a
+ JOIN genes b ON a.interval INTERSECTS b.interval
+ """,
+ tables=[
+ Table("peaks", chrom_col="chrom", start_col="start", end_col="end"),
+ Table("genes", chrom_col="chrom", start_col="start", end_col="end"),
+ ],
+ )
+
+ df = ctx.sql(binned_sql).to_pandas()
+ assert len(df) == 1
+
+
+class TestBinnedJoinBinBoundaryRounding:
+ """Regression tests for bin-index calculation rounding errors.
+
+ The original formula CAST(start / B AS BIGINT) uses float division
+ followed by a cast. When the division lands on x.5 the cast rounds
+ to nearest-even instead of flooring, producing the wrong bin index
+ and causing missed matches.
+ """
+
+ @staticmethod
+ def _make_ctx(table_a_data, table_b_data):
+ import pyarrow as pa
+ from datafusion import SessionContext
+
+ schema = pa.schema(
+ [
+ ("chrom", pa.utf8()),
+ ("start", pa.int64()),
+ ("end", pa.int64()),
+ ("name", pa.utf8()),
+ ]
+ )
+ ctx = SessionContext()
+ ctx.register_record_batches(
+ "intervals_a",
+ [pa.table(table_a_data, schema=schema).to_batches()],
+ )
+ ctx.register_record_batches(
+ "intervals_b",
+ [pa.table(table_b_data, schema=schema).to_batches()],
+ )
+ return ctx
+
+ def test_half_bin_boundary_overlap_not_missed(self):
+ """
+ GIVEN interval A spanning many bins and interval B whose start
+ falls exactly on a .5 division boundary (e.g., 621950/100)
+ WHEN INTERSECTS is evaluated with bin_size=100 on DuckDB
+ THEN the overlap must be found, not missed due to rounding
+ """
+ import duckdb
+
+ conn = duckdb.connect(":memory:")
+ conn.execute(
+ "CREATE TABLE intervals_a "
+ '(chrom VARCHAR, "start" INTEGER, "end" INTEGER, '
+ "name VARCHAR)"
+ )
+ conn.execute(
+ "CREATE TABLE intervals_b "
+ '(chrom VARCHAR, "start" INTEGER, "end" INTEGER, '
+ "name VARCHAR)"
+ )
+ conn.execute("INSERT INTO intervals_a VALUES ('chr1', 421951, 621951, 'a0')")
+ conn.execute("INSERT INTO intervals_b VALUES ('chr1', 621950, 621951, 'b0')")
+
+ sql = transpile(
+ """
+ SELECT DISTINCT a.name, b.name AS b_name
+ FROM intervals_a a
+ JOIN intervals_b b ON a.interval INTERSECTS b.interval
+ """,
+ tables=["intervals_a", "intervals_b"],
+ bin_size=100,
+ )
+ result = conn.execute(sql).fetchall()
+ conn.close()
+ assert len(result) == 1, (
+ f"Expected 1 match, got {len(result)} — "
+ f"bin boundary rounding likely dropped the overlap"
+ )
+
+ def test_exact_bin_boundary_start(self):
+ """
+ GIVEN interval B starting at an exact multiple of bin_size
+ WHEN INTERSECTS is evaluated on DuckDB
+ THEN the correct bin index is assigned (no off-by-one from rounding)
+ """
+ import duckdb
+
+ conn = duckdb.connect(":memory:")
+ conn.execute(
+ "CREATE TABLE intervals_a "
+ '(chrom VARCHAR, "start" INTEGER, "end" INTEGER, '
+ "name VARCHAR)"
+ )
+ conn.execute(
+ "CREATE TABLE intervals_b "
+ '(chrom VARCHAR, "start" INTEGER, "end" INTEGER, '
+ "name VARCHAR)"
+ )
+ conn.execute("INSERT INTO intervals_a VALUES ('chr1', 999, 1001, 'a0')")
+ conn.execute("INSERT INTO intervals_b VALUES ('chr1', 1000, 1001, 'b0')")
+
+ sql = transpile(
+ """
+ SELECT DISTINCT a.name, b.name AS b_name
+ FROM intervals_a a
+ JOIN intervals_b b ON a.interval INTERSECTS b.interval
+ """,
+ tables=["intervals_a", "intervals_b"],
+ bin_size=1000,
+ )
+ result = conn.execute(sql).fetchall()
+ conn.close()
+ assert len(result) == 1, f"Expected 1 match at bin boundary, got {len(result)}"
+
+
+class TestBinnedJoinOuterJoinMultiBin:
+ """Regression tests for outer join with multi-bin intervals.
+
+ When an interval spans multiple bins, the outer join produces one
+ row per bin. Bins that don't match the other side create spurious
+ NULL rows. DISTINCT can't collapse a NULL row with a matched row
+ because they differ in the non-NULL columns.
+ """
+
+ @staticmethod
+ def _make_ctx(table_a_data, table_b_data):
+ import pyarrow as pa
+ from datafusion import SessionContext
+
+ schema = pa.schema(
+ [
+ ("chrom", pa.utf8()),
+ ("start", pa.int64()),
+ ("end", pa.int64()),
+ ("name", pa.utf8()),
+ ]
+ )
+ ctx = SessionContext()
+ ctx.register_record_batches(
+ "intervals_a",
+ [pa.table(table_a_data, schema=schema).to_batches()],
+ )
+ ctx.register_record_batches(
+ "intervals_b",
+ [pa.table(table_b_data, schema=schema).to_batches()],
+ )
+ return ctx
+
+ def test_left_join_no_spurious_null_row(self):
+ """
+ GIVEN interval A spanning bins 0 and 1 and interval B only in bin 1
+ WHEN LEFT JOIN INTERSECTS is evaluated
+ THEN only 1 matched row is returned, not a matched row plus a
+ spurious NULL row from the unmatched bin-0 copy
+ """
+ ctx = self._make_ctx(
+ {
+ "chrom": ["chr1"],
+ "start": [9000],
+ "end": [11000],
+ "name": ["a0"],
+ },
+ {
+ "chrom": ["chr1"],
+ "start": [10500],
+ "end": [10600],
+ "name": ["b0"],
+ },
+ )
+
+ sql = transpile(
+ """
+ SELECT a.name, b.name AS b_name
+ FROM intervals_a a
+ LEFT JOIN intervals_b b ON a.interval INTERSECTS b.interval
+ """,
+ tables=["intervals_a", "intervals_b"],
+ )
+ result = ctx.sql(sql).to_pandas()
+ assert len(result) == 1, (
+ f"Expected 1 matched row, got {len(result)} — "
+ f"spurious NULL row from unmatched bin"
+ )
+ assert result.iloc[0]["b_name"] == "b0"
+
+ def test_left_join_unmatched_still_returns_null(self):
+ """
+ GIVEN interval A with no overlap in B
+ WHEN LEFT JOIN INTERSECTS is evaluated
+ THEN one row with NULL B columns is returned
+ """
+ ctx = self._make_ctx(
+ {
+ "chrom": ["chr1"],
+ "start": [9000],
+ "end": [11000],
+ "name": ["a0"],
+ },
+ {
+ "chrom": ["chr2"],
+ "start": [9500],
+ "end": [10500],
+ "name": ["b0"],
+ },
+ )
+
+ sql = transpile(
+ """
+ SELECT a.name, b.name AS b_name
+ FROM intervals_a a
+ LEFT JOIN intervals_b b ON a.interval INTERSECTS b.interval
+ """,
+ tables=["intervals_a", "intervals_b"],
+ )
+ result = ctx.sql(sql).to_pandas()
+ assert len(result) == 1, f"Expected 1 unmatched row, got {len(result)}"
+ assert _is_null(result.iloc[0]["b_name"])
+
+ def test_right_join_no_spurious_null_row(self):
+ """
+ GIVEN interval B spanning bins 0 and 1 and interval A only in bin 0
+ WHEN RIGHT JOIN INTERSECTS is evaluated
+ THEN only 1 matched row is returned, not a matched row plus a
+ spurious NULL row from the unmatched bin-1 copy of B
+ """
+ ctx = self._make_ctx(
+ {
+ "chrom": ["chr1"],
+ "start": [9500],
+ "end": [9600],
+ "name": ["a0"],
+ },
+ {
+ "chrom": ["chr1"],
+ "start": [9000],
+ "end": [11000],
+ "name": ["b0"],
+ },
+ )
+
+ sql = transpile(
+ """
+ SELECT a.name, b.name AS b_name
+ FROM intervals_a a
+ RIGHT JOIN intervals_b b ON a.interval INTERSECTS b.interval
+ """,
+ tables=["intervals_a", "intervals_b"],
+ )
+ result = ctx.sql(sql).to_pandas()
+ assert len(result) == 1, (
+ f"Expected 1 matched row, got {len(result)} — "
+ f"spurious NULL row from unmatched bin"
+ )
+ assert result.iloc[0]["name"] == "a0"
+
+    def test_full_outer_join_no_spurious_null_row(self):
+        """
+        GIVEN interval A spanning bins 0 and 1, interval B only in bin 1
+        WHEN FULL OUTER JOIN INTERSECTS is evaluated
+        THEN only 1 matched row is returned, not a matched row plus a
+        spurious NULL row
+        """
+        # Same fixture as the LEFT JOIN case; only the join type differs.
+        ctx = self._make_ctx(
+            {
+                "chrom": ["chr1"],
+                "start": [9000],
+                "end": [11000],
+                "name": ["a0"],
+            },
+            {
+                "chrom": ["chr1"],
+                "start": [10500],
+                "end": [10600],
+                "name": ["b0"],
+            },
+        )
+
+        sql = transpile(
+            """
+            SELECT a.name, b.name AS b_name
+            FROM intervals_a a
+            FULL OUTER JOIN intervals_b b ON a.interval INTERSECTS b.interval
+            """,
+            tables=["intervals_a", "intervals_b"],
+        )
+        result = ctx.sql(sql).to_pandas()
+        assert len(result) == 1, (
+            f"Expected 1 matched row, got {len(result)} — "
+            f"spurious NULL row from unmatched bin"
+        )