diff --git a/docs/dialect/aggregation-operators.rst b/docs/dialect/aggregation-operators.rst index 9887b87..a13f129 100644 --- a/docs/dialect/aggregation-operators.rst +++ b/docs/dialect/aggregation-operators.rst @@ -328,4 +328,129 @@ Related Operators ~~~~~~~~~~~~~~~~~ - :ref:`CLUSTER ` - Assign cluster IDs without merging +- :ref:`COVERAGE ` - Compute binned genome coverage - :ref:`INTERSECTS ` - Test for overlap between specific pairs + +---- + +.. _coverage-operator: + +COVERAGE +-------- + +Compute binned genome coverage by tiling the genome into fixed-width bins. + +Description +~~~~~~~~~~~ + +The ``COVERAGE`` operator tiles the genome into fixed-width bins and aggregates overlapping intervals per bin. It generates a bin grid using ``generate_series`` and joins it against the source table to count (or otherwise aggregate) overlapping features in each bin. + +This is useful for: + +- Computing read depth or signal coverage across the genome +- Creating fixed-resolution coverage tracks from interval data +- Summarising feature density at a user-defined resolution + +The operator works as an aggregate function, returning one row per bin with the bin coordinates and the computed statistic. + +Syntax +~~~~~~ + +.. code-block:: sql + + -- Basic coverage (count overlapping intervals per bin) + SELECT COVERAGE(interval, resolution) FROM features + + -- With a named statistic (either := or => syntax) + SELECT COVERAGE(interval, 1000, stat := 'mean') FROM features + SELECT COVERAGE(interval, 1000, stat => 'mean') FROM features + + -- Aggregate a specific column instead of interval length + SELECT COVERAGE(interval, 1000, stat := 'mean', target := 'score') FROM features + + -- Named resolution parameter + SELECT COVERAGE(interval, resolution := 500) FROM features + +Parameters +~~~~~~~~~~ + +**interval** + A genomic column. + +**resolution** + Bin width in base pairs. Can be given as a positional or named parameter. 
+ +**stat** *(optional)* + Aggregation function applied to overlapping intervals per bin. One of: + + - ``'count'`` — number of overlapping intervals (default) + - ``'mean'`` — average interval length of overlapping intervals + - ``'sum'`` — total interval length of overlapping intervals + - ``'min'`` — minimum interval length of overlapping intervals + - ``'max'`` — maximum interval length of overlapping intervals + + When ``target`` is specified, the stat is applied to that column instead of interval length. + +**target** *(optional)* + Column name to aggregate. When omitted, non-count stats aggregate interval length (``end - start``). When specified, the stat is applied to the named column. For ``'count'``, specifying a target counts non-NULL values of that column instead of ``COUNT(*)``. + +Return Value +~~~~~~~~~~~~ + +Returns one row per genomic bin: + +- ``chrom`` — Chromosome of the bin +- ``start`` — Start position of the bin +- ``end`` — End position of the bin +- ``value`` — The computed aggregate (default alias; use ``AS`` to rename) + +Examples +~~~~~~~~ + +**Basic Coverage:** + +Count the number of features overlapping each 1 kb bin: + +.. code-block:: sql + + SELECT COVERAGE(interval, 1000) + FROM features + +**Mean Coverage:** + +Compute the average interval length per 500 bp bin: + +.. code-block:: sql + + SELECT COVERAGE(interval, 500, stat := 'mean') + FROM features + +**Named Alias:** + +.. code-block:: sql + + SELECT COVERAGE(interval, 1000) AS depth + FROM reads + +**With WHERE Filter:** + +Assuming the source table includes a ``score`` column, compute coverage of high-scoring features only: + +.. 
code-block:: sql + + SELECT COVERAGE(interval, 1000) AS depth + FROM features + WHERE score > 10 + +Performance Notes +~~~~~~~~~~~~~~~~~ + +- The operator creates one bin per chromosome per step, so smaller resolutions produce more rows +- A ``LEFT JOIN`` ensures bins with zero coverage are included in the output +- For very large genomes, consider restricting the query with a ``WHERE`` clause on chromosome + +Related Operators +~~~~~~~~~~~~~~~~~ + +- :ref:`MERGE ` - Combine overlapping intervals into single regions +- :ref:`CLUSTER ` - Assign cluster IDs to overlapping intervals diff --git a/docs/recipes/coverage.rst b/docs/recipes/coverage.rst new file mode 100644 index 0000000..19d5f54 --- /dev/null +++ b/docs/recipes/coverage.rst @@ -0,0 +1,173 @@ +Coverage +======== + +This section covers patterns for computing genome-wide coverage and signal +summaries using GIQL's ``COVERAGE`` operator. + +Basic Coverage +-------------- + +Count Overlapping Features +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Count the number of features overlapping each 1 kb bin across the genome: + +.. code-block:: sql + + SELECT COVERAGE(interval, 1000) AS depth + FROM features + +**Sample output:** + +.. code-block:: text + + ┌────────┬────────┬────────┬───────┐ + │ chrom │ start │ end │ depth │ + ├────────┼────────┼────────┼───────┤ + │ chr1 │ 0 │ 1000 │ 3 │ + │ chr1 │ 1000 │ 2000 │ 1 │ + │ chr1 │ 2000 │ 3000 │ 0 │ + │ ... │ ... │ ... │ ... │ + └────────┴────────┴────────┴───────┘ + +Each row represents one genomic bin. Bins with no overlapping features appear with a count of zero. + +**Use case:** Compute read depth or feature density at a fixed resolution. + +Custom Bin Size +~~~~~~~~~~~~~~~ + +Use a finer resolution of 100 bp: + +.. code-block:: sql + + SELECT COVERAGE(interval, 100) AS depth + FROM reads + +**Use case:** High-resolution coverage tracks for visualisation. 
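The recipes above all rely on the same binning semantics, which are easy to cross-check outside the database. Below is an illustrative pure-Python sketch of the default ``'count'`` statistic — not the actual implementation, which transpiles to a ``generate_series`` join — using the same half-open overlap predicate as the generated SQL (``source.start < bins.end AND source.end > bins.start``) and keeping zero-coverage bins, mirroring the ``LEFT JOIN`` behaviour:

```python
def binned_coverage(intervals, resolution):
    """Count intervals overlapping each fixed-width bin (half-open coords).

    intervals: list of (chrom, start, end) tuples.
    Returns (chrom, bin_start, bin_end, count) rows, one per bin,
    including bins with zero overlapping intervals.
    """
    # Bin grid runs from 0 through the largest end seen per chromosome,
    # like the __giql_chroms / generate_series(0, __max_end, resolution) step.
    max_end = {}
    for chrom, start, end in intervals:
        max_end[chrom] = max(max_end.get(chrom, 0), end)

    rows = []
    for chrom in sorted(max_end):
        # range(..., max_end + 1, ...) matches generate_series's inclusive stop.
        for bin_start in range(0, max_end[chrom] + 1, resolution):
            bin_end = bin_start + resolution
            count = sum(
                1
                for c, s, e in intervals
                # Same overlap predicate as the generated JOIN:
                # source.start < bins."end" AND source."end" > bins.start
                if c == chrom and s < bin_end and e > bin_start
            )
            rows.append((chrom, bin_start, bin_end, count))
    return rows
```

Running this against a handful of intervals reproduces the shape of the sample output above: one row per bin, with empty bins reported at zero rather than dropped.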
+ +Coverage Statistics +------------------- + +Mean Interval Length per Bin +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Compute the average length of intervals overlapping each bin: + +.. code-block:: sql + + SELECT COVERAGE(interval, 1000, stat := 'mean') AS avg_len + FROM features + +Sum of Interval Lengths per Bin +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Compute the total interval length in each bin: + +.. code-block:: sql + + SELECT COVERAGE(interval, 1000, stat := 'sum') AS total_len + FROM features + +Maximum Interval Length per Bin +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Find the longest interval overlapping each bin: + +.. code-block:: sql + + SELECT COVERAGE(interval, 1000, stat := 'max') AS max_len + FROM features + +Aggregating a Specific Column +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Compute the mean score of overlapping features per bin instead of summarising interval length: + +.. code-block:: sql + + SELECT COVERAGE(interval, 1000, stat := 'mean', target := 'score') AS avg_score + FROM features + +**Use case:** Signal tracks from a numeric column (e.g. ChIP-seq score, p-value). + +Filtered Coverage +----------------- + +Strand-Specific Coverage +~~~~~~~~~~~~~~~~~~~~~~~~ + +Compute coverage for each strand separately by filtering: + +.. code-block:: sql + + -- Plus strand + SELECT COVERAGE(interval, 1000) AS depth + FROM features + WHERE strand = '+' + +.. code-block:: sql + + -- Minus strand + SELECT COVERAGE(interval, 1000) AS depth + FROM features + WHERE strand = '-' + +**Use case:** Strand-specific signal tracks for RNA-seq or stranded assays. + +Coverage of High-Scoring Features +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Restrict coverage to features above a quality threshold: + +.. code-block:: sql + + SELECT COVERAGE(interval, 1000) AS depth + FROM features + WHERE score > 10 + +5' End Counting +~~~~~~~~~~~~~~~ + +To count only the 5' ends of features (e.g. 
TSS or read starts), first +create a view or CTE that trims each interval to its 5' end, then apply +``COVERAGE``: + +.. code-block:: sql + + WITH five_prime AS ( + SELECT chrom, start, start + 1 AS end + FROM features + WHERE strand = '+' + UNION ALL + SELECT chrom, end - 1 AS start, end + FROM features + WHERE strand = '-' + ) + SELECT COVERAGE(interval, 1000) AS tss_count + FROM five_prime + +Normalised Coverage +------------------- + +RPM Normalisation +~~~~~~~~~~~~~~~~~ + +Normalise bin counts to reads per million (RPM) by dividing by the total +number of reads: + +.. code-block:: sql + + WITH bins AS ( + SELECT COVERAGE(interval, 1000) AS depth + FROM reads + ), + total AS ( + SELECT COUNT(*) AS n FROM reads + ) + SELECT + bins.chrom, + bins.start, + bins.end, + bins.depth * 1000000.0 / total.n AS rpm + FROM bins, total diff --git a/docs/recipes/index.rst b/docs/recipes/index.rst index cc97e47..546c02d 100644 --- a/docs/recipes/index.rst +++ b/docs/recipes/index.rst @@ -19,6 +19,10 @@ Recipe Categories Clustering overlapping intervals, distance-based clustering, merging intervals, and aggregating cluster statistics. +:doc:`coverage` + Binned genome coverage, coverage statistics, strand-specific coverage, + normalisation, and 5' end counting. + :doc:`advanced` Multi-range matching, complex filtering with joins, aggregate statistics, window expansions, and multi-table queries. 
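The RPM normalisation in the recipe above is plain arithmetic, so it can be sanity-checked in a few lines of Python. A minimal sketch, assuming per-bin counts taken from the ``depth`` column and the library size from ``COUNT(*)`` over the reads table:

```python
def rpm_normalise(bin_depths, total_reads):
    """Scale raw per-bin read counts to reads per million (RPM)."""
    if total_reads <= 0:
        raise ValueError("total_reads must be positive")
    # Same arithmetic as the SQL recipe: depth * 1000000.0 / n
    return [depth * 1_000_000.0 / total_reads for depth in bin_depths]
```

With a library of two million reads, a bin with a depth of 3 normalises to 1.5 RPM; zero-coverage bins stay at 0.0.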
diff --git a/src/giql/dialect.py b/src/giql/dialect.py index 6c70104..6327e43 100644 --- a/src/giql/dialect.py +++ b/src/giql/dialect.py @@ -14,6 +14,7 @@ from giql.expressions import Contains from giql.expressions import GIQLCluster from giql.expressions import GIQLDistance +from giql.expressions import GIQLCoverage from giql.expressions import GIQLMerge from giql.expressions import GIQLNearest from giql.expressions import Intersects @@ -54,6 +55,7 @@ class Parser(Parser): FUNCTIONS = { **Parser.FUNCTIONS, "CLUSTER": GIQLCluster.from_arg_list, + "COVERAGE": GIQLCoverage.from_arg_list, "MERGE": GIQLMerge.from_arg_list, "DISTANCE": GIQLDistance.from_arg_list, "NEAREST": GIQLNearest.from_arg_list, diff --git a/src/giql/expressions.py b/src/giql/expressions.py index 857a223..d874868 100644 --- a/src/giql/expressions.py +++ b/src/giql/expressions.py @@ -142,6 +142,55 @@ def from_arg_list(cls, args): return cls(**kwargs) +class GIQLCoverage(exp.Func): + """COVERAGE aggregate function for binned genome coverage. + + Tiles the genome into fixed-width bins and aggregates overlapping + intervals per bin using generate_series and JOIN + GROUP BY. + + Examples: + COVERAGE(interval, 1000) + COVERAGE(interval, 500, stat := 'mean') + COVERAGE(interval, resolution := 1000) + COVERAGE(interval, 1000, stat := 'mean', target := 'score') + """ + + arg_types = { + "this": True, # genomic column + "resolution": True, # bin width (positional or named) + "stat": False, # aggregation: 'count', 'mean', 'sum', 'min', 'max' + "target": False, # column to aggregate (default: interval length) + } + + @classmethod + def from_arg_list(cls, args): + """Parse argument list, handling named parameters. 
+ + :param args: List of arguments from parser + :return: GIQLCoverage instance with properly mapped arguments + """ + kwargs = {} + positional_args = [] + + # Separate named (PropertyEQ for :=, Kwarg for =>) and positional arguments + for arg in args: + if isinstance(arg, (exp.PropertyEQ, exp.Kwarg)): + param_name = ( + arg.this.name if hasattr(arg.this, "name") else str(arg.this) + ) + kwargs[param_name.lower()] = arg.expression + else: + positional_args.append(arg) + + # Map positional arguments + if len(positional_args) > 0: + kwargs["this"] = positional_args[0] + if len(positional_args) > 1: + kwargs["resolution"] = positional_args[1] + + return cls(**kwargs) + + class GIQLDistance(exp.Func): """DISTANCE function for calculating genomic distances between intervals. diff --git a/src/giql/transformer.py b/src/giql/transformer.py index de1e70f..6571554 100644 --- a/src/giql/transformer.py +++ b/src/giql/transformer.py @@ -11,9 +11,19 @@ from giql.constants import DEFAULT_START_COL from giql.constants import DEFAULT_STRAND_COL from giql.expressions import GIQLCluster +from giql.expressions import GIQLCoverage from giql.expressions import GIQLMerge from giql.table import Tables +# Mapping from COVERAGE stat parameter to SQL aggregate function +COVERAGE_STAT_MAP = { + "count": "COUNT", + "mean": "AVG", + "sum": "SUM", + "min": "MIN", + "max": "MAX", +} + class ClusterTransformer: """Transforms queries containing CLUSTER into CTE-based queries. @@ -572,4 +582,474 @@ def _transform_for_merge( exp.Ordered(this=exp.column(start_col, quoted=True)), append=True, copy=False ) + # Preserve any existing CTEs from the original query + if query.args.get("with_"): + final_query.set("with_", query.args["with_"].copy()) + + return final_query + + +class CoverageTransformer: + """Transforms queries containing COVERAGE into binned coverage queries. 
+ + COVERAGE tiles the genome into fixed-width bins and aggregates overlapping + intervals per bin: + + SELECT COVERAGE(interval, 1000) FROM features + + Into: + + WITH __giql_bins AS ( + SELECT chrom, bin_start AS start, bin_start + 1000 AS "end" + FROM ( + SELECT DISTINCT chrom, MAX("end") AS __max_end + FROM features GROUP BY chrom + ) AS __giql_chroms, + LATERAL generate_series(0, __max_end, 1000) AS t(bin_start) + ) + SELECT bins.chrom, bins.start, bins."end", COUNT(source.*) + FROM __giql_bins AS bins + LEFT JOIN features AS source + ON source.start < bins."end" + AND source."end" > bins.start + AND source.chrom = bins.chrom + GROUP BY bins.chrom, bins.start, bins."end" + ORDER BY bins.chrom, bins.start + """ + + def __init__(self, tables: Tables): + """Initialize transformer. + + :param tables: + Table configurations for column mapping + """ + self.tables = tables + + def _get_table_name(self, query: exp.Select) -> str | None: + """Extract table name from query's FROM clause. + + :param query: + Query to extract table name from + :return: + Table name if FROM contains a simple table, None otherwise + """ + from_clause = query.args.get("from_") + if not from_clause: + return None + if isinstance(from_clause.this, exp.Table): + return from_clause.this.name + return None + + def _get_table_alias(self, query: exp.Select) -> str | None: + """Extract table alias from query's FROM clause. + + :param query: + Query to extract alias from + :return: + Table alias if present, None otherwise + """ + from_clause = query.args.get("from_") + if not from_clause: + return None + if isinstance(from_clause.this, exp.Table): + return from_clause.this.alias + return None + + def _get_genomic_columns(self, query: exp.Select) -> tuple[str, str, str]: + """Get genomic column names from table config or defaults. 
+ + :param query: + Query to extract table and column info from + :return: + Tuple of (chrom_col, start_col, end_col) + """ + table_name = self._get_table_name(query) + + chrom_col = DEFAULT_CHROM_COL + start_col = DEFAULT_START_COL + end_col = DEFAULT_END_COL + + if table_name: + table = self.tables.get(table_name) + if table: + chrom_col = table.chrom_col + start_col = table.start_col + end_col = table.end_col + + return chrom_col, start_col, end_col + + def transform(self, query: exp.Expression) -> exp.Expression: + """Transform query if it contains COVERAGE expressions. + + :param query: + Parsed query AST + :return: + Transformed query AST + """ + if not isinstance(query, exp.Select): + return query + + # Recursively transform CTEs + if query.args.get("with_"): + cte = query.args["with_"] + for cte_expr in cte.expressions: + if isinstance(cte_expr, exp.CTE): + cte_expr.set("this", self.transform(cte_expr.this)) + + # Recursively transform subqueries in FROM/JOIN/WHERE + for key in ("from_", "where"): + if query.args.get(key): + self._transform_subqueries_in_node(query.args[key]) + if query.args.get("joins"): + for join in query.args["joins"]: + self._transform_subqueries_in_node(join) + + # Find COVERAGE expressions in SELECT + coverage_exprs = self._find_coverage_expressions(query) + if not coverage_exprs: + return query + + if len(coverage_exprs) > 1: + raise ValueError("Multiple COVERAGE expressions not yet supported") + + return self._transform_for_coverage(query, coverage_exprs[0]) + + def _transform_subqueries_in_node(self, node: exp.Expression): + """Recursively transform subqueries within an expression node. 
+ + :param node: + Expression node to search for subqueries + """ + for subquery in node.find_all(exp.Subquery): + if isinstance(subquery.this, exp.Select): + transformed = self.transform(subquery.this) + subquery.set("this", transformed) + + def _find_coverage_expressions(self, query: exp.Select) -> list[GIQLCoverage]: + """Find all COVERAGE expressions in query. + + :param query: + Query to search + :return: + List of COVERAGE expressions + """ + coverage_exprs = [] + for expression in query.expressions: + if isinstance(expression, GIQLCoverage): + coverage_exprs.append(expression) + elif isinstance(expression, exp.Alias): + if isinstance(expression.this, GIQLCoverage): + coverage_exprs.append(expression.this) + return coverage_exprs + + def _transform_for_coverage( + self, query: exp.Select, coverage_expr: GIQLCoverage + ) -> exp.Select: + """Transform query to compute COVERAGE using bins CTE + JOIN + GROUP BY. + + :param query: + Original query + :param coverage_expr: + COVERAGE expression to transform + :return: + Transformed query + """ + # Extract parameters + resolution_expr = coverage_expr.args.get("resolution") + if isinstance(resolution_expr, exp.Literal): + resolution = int(resolution_expr.this) + else: + try: + resolution = int(str(resolution_expr.this)) + except (ValueError, AttributeError): + raise ValueError("COVERAGE resolution must be an integer literal") + + stat_expr = coverage_expr.args.get("stat") + if stat_expr: + if isinstance(stat_expr, exp.Literal): + stat = stat_expr.this.strip("'\"").lower() + else: + stat = str(stat_expr).strip("'\"").lower() + else: + stat = "count" + + if stat not in COVERAGE_STAT_MAP: + raise ValueError( + f"Unknown COVERAGE stat '{stat}'. 
" + f"Must be one of: {', '.join(COVERAGE_STAT_MAP)}" + ) + + sql_agg = COVERAGE_STAT_MAP[stat] + + # Extract target parameter + target_expr = coverage_expr.args.get("target") + if target_expr: + if isinstance(target_expr, exp.Literal): + target_col = target_expr.this.strip("'\"") + else: + target_col = str(target_expr).strip("'\"") + else: + target_col = None + + # Get column names and table info + chrom_col, start_col, end_col = self._get_genomic_columns(query) + table_name = self._get_table_name(query) + table_alias = self._get_table_alias(query) + source_ref = table_alias or table_name or "source" + + # Build __giql_chroms subquery: + # SELECT DISTINCT chrom, MAX("end") AS __max_end FROM GROUP BY chrom + chroms_select = exp.Select() + chroms_select.select( + exp.column(chrom_col, quoted=True), + copy=False, + ) + chroms_select.select( + exp.alias_( + exp.Max(this=exp.column(end_col, quoted=True)), + "__max_end", + quoted=False, + ), + append=True, + copy=False, + ) + + if table_name: + chroms_select.from_(exp.to_table(table_name), copy=False) + + # Apply WHERE from original query to the chroms subquery too, + # qualifying unqualified column references with the table name + if query.args.get("where"): + chroms_where = query.args["where"].copy() + if table_name: + for col in chroms_where.find_all(exp.Column): + if not col.table: + col.set("table", exp.Identifier(this=table_name)) + chroms_select.set("where", chroms_where) + + chroms_select.group_by(exp.column(chrom_col, quoted=True), copy=False) + + chroms_subquery = exp.Subquery( + this=chroms_select, + alias=exp.TableAlias(this=exp.Identifier(this="__giql_chroms")), + ) + + # Build bins CTE using raw SQL for generate_series + LATERAL + # since SQLGlot doesn't natively support generate_series + bins_select = exp.Select() + bins_select.select( + exp.column(chrom_col, table="__giql_chroms", quoted=True), + copy=False, + ) + bins_select.select( + exp.alias_( + exp.column("bin_start"), + start_col, + quoted=True, + 
), + append=True, + copy=False, + ) + bins_select.select( + exp.alias_( + exp.Add( + this=exp.column("bin_start"), + expression=exp.Literal.number(resolution), + ), + end_col, + quoted=True, + ), + append=True, + copy=False, + ) + + # FROM __giql_chroms subquery + bins_select.from_(chroms_subquery, copy=False) + + # CROSS JOIN LATERAL generate_series(0, __max_end, resolution) AS t(bin_start) + lateral_join = exp.Join( + this=exp.Lateral( + this=exp.Anonymous( + this="generate_series", + expressions=[ + exp.Literal.number(0), + exp.column("__max_end"), + exp.Literal.number(resolution), + ], + ), + alias=exp.TableAlias( + this=exp.Identifier(this="t"), + columns=[exp.Identifier(this="bin_start")], + ), + ), + kind="CROSS", + ) + bins_select.append("joins", lateral_join) + + # Wrap bins_select as a CTE named __giql_bins + bins_cte = exp.CTE( + this=bins_select, + alias=exp.TableAlias(this=exp.Identifier(this="__giql_bins")), + ) + with_clause = exp.With(expressions=[bins_cte]) + + # Build the aggregate expression + if stat == "count": + if target_col: + agg_expr = exp.Anonymous( + this="COUNT", + expressions=[ + exp.column(target_col, table=source_ref, quoted=True), + ], + ) + else: + agg_expr = exp.Anonymous( + this="COUNT", + expressions=[ + exp.Column( + this=exp.Star(), + table=exp.Identifier(this=source_ref), + ) + ], + ) + else: + if target_col: + agg_expr = exp.Anonymous( + this=sql_agg, + expressions=[ + exp.column(target_col, table=source_ref, quoted=True), + ], + ) + else: + agg_expr = exp.Anonymous( + this=sql_agg, + expressions=[ + exp.Sub( + this=exp.column(end_col, table=source_ref, quoted=True), + expression=exp.column( + start_col, table=source_ref, quoted=True + ), + ) + ], + ) + + # Build main SELECT + final_query = exp.Select() + + # Add bin coordinate columns + final_query.select( + exp.column(chrom_col, table="bins", quoted=True), + copy=False, + ) + final_query.select( + exp.column(start_col, table="bins", quoted=True), + append=True, + 
copy=False, + ) + final_query.select( + exp.column(end_col, table="bins", quoted=True), + append=True, + copy=False, + ) + + # Replace COVERAGE(...) in select list with aggregate, and add other columns + for expression in query.expressions: + if isinstance(expression, GIQLCoverage): + final_query.select( + exp.alias_(agg_expr, "value", quoted=False), + append=True, + copy=False, + ) + elif isinstance(expression, exp.Alias) and isinstance( + expression.this, GIQLCoverage + ): + final_query.select( + exp.alias_(agg_expr, expression.alias, quoted=False), + append=True, + copy=False, + ) + else: + final_query.select(expression, append=True, copy=False) + + # FROM __giql_bins AS bins + final_query.from_( + exp.Table( + this=exp.Identifier(this="__giql_bins"), + alias=exp.TableAlias(this=exp.Identifier(this="bins")), + ), + copy=False, + ) + + # LEFT JOIN source ON overlap conditions + source_table = exp.to_table(table_name) if table_name else exp.to_table("source") + source_table.set("alias", exp.TableAlias(this=exp.Identifier(this=source_ref))) + + join_condition = exp.And( + this=exp.And( + this=exp.LT( + this=exp.column(start_col, table=source_ref, quoted=True), + expression=exp.column(end_col, table="bins", quoted=True), + ), + expression=exp.GT( + this=exp.column(end_col, table=source_ref, quoted=True), + expression=exp.column(start_col, table="bins", quoted=True), + ), + ), + expression=exp.EQ( + this=exp.column(chrom_col, table=source_ref, quoted=True), + expression=exp.column(chrom_col, table="bins", quoted=True), + ), + ) + + # Merge original WHERE into the JOIN ON condition so that + # LEFT JOIN still produces zero-coverage bins (WHERE would filter + # them out because source columns are NULL for non-matching bins) + if query.args.get("where"): + where_condition = query.args["where"].this.copy() + # Qualify unqualified column references with source_ref + for col in where_condition.find_all(exp.Column): + if not col.table: + col.set("table", 
exp.Identifier(this=source_ref)) + join_condition = exp.And( + this=join_condition, + expression=where_condition, + ) + + left_join = exp.Join( + this=source_table, + on=join_condition, + kind="LEFT", + ) + final_query.append("joins", left_join) + + # GROUP BY bins.chrom, bins.start, bins.end + final_query.group_by( + exp.column(chrom_col, table="bins", quoted=True), + copy=False, + ) + final_query.group_by( + exp.column(start_col, table="bins", quoted=True), + append=True, + copy=False, + ) + final_query.group_by( + exp.column(end_col, table="bins", quoted=True), + append=True, + copy=False, + ) + + # ORDER BY bins.chrom, bins.start + final_query.order_by( + exp.Ordered(this=exp.column(chrom_col, table="bins", quoted=True)), + copy=False, + ) + final_query.order_by( + exp.Ordered(this=exp.column(start_col, table="bins", quoted=True)), + append=True, + copy=False, + ) + + # Attach the WITH clause + final_query.set("with_", with_clause) + return final_query diff --git a/src/giql/transpile.py b/src/giql/transpile.py index 2b29c3d..c5c165b 100644 --- a/src/giql/transpile.py +++ b/src/giql/transpile.py @@ -11,6 +11,7 @@ from giql.table import Table from giql.table import Tables from giql.transformer import ClusterTransformer +from giql.transformer import CoverageTransformer from giql.transformer import MergeTransformer @@ -99,6 +100,7 @@ def transpile( tables_container = _build_tables(tables) # Initialize transformers with table configurations + coverage_transformer = CoverageTransformer(tables_container) merge_transformer = MergeTransformer(tables_container) cluster_transformer = ClusterTransformer(tables_container) @@ -111,8 +113,10 @@ def transpile( except Exception as e: raise ValueError(f"Parse error: {e}\nQuery: {giql}") from e - # Apply transformations (MERGE first, then CLUSTER) + # Apply transformations (COVERAGE first, then MERGE, then CLUSTER) try: + # COVERAGE transformation (independent, applied first) + ast = coverage_transformer.transform(ast) # MERGE 
transformation (which may internally use CLUSTER) ast = merge_transformer.transform(ast) # CLUSTER transformation for any standalone CLUSTER expressions diff --git a/tests/integration/bedtools/test_correctness_intersect.py b/tests/integration/bedtools/test_correctness_intersect.py new file mode 100644 index 0000000..d0d64da --- /dev/null +++ b/tests/integration/bedtools/test_correctness_intersect.py @@ -0,0 +1,235 @@ +"""Extended correctness tests for GIQL INTERSECTS operator vs bedtools intersect. + +These tests cover boundary cases, scale, and edge scenarios beyond the basic +tests in test_intersect.py, ensuring comprehensive GIQL/bedtools equivalence. +""" + +from giql import transpile + +from .utils.bedtools_wrapper import intersect +from .utils.comparison import compare_results +from .utils.data_models import GenomicInterval +from .utils.duckdb_loader import load_intervals + + +def _run_intersect_comparison( + duckdb_connection, + intervals_a, + intervals_b, + strand_filter="", +): + """Run GIQL INTERSECTS and bedtools intersect, return ComparisonResult.""" + load_intervals( + duckdb_connection, + "intervals_a", + [i.to_tuple() for i in intervals_a], + ) + load_intervals( + duckdb_connection, + "intervals_b", + [i.to_tuple() for i in intervals_b], + ) + + strand_mode = None + if "a.strand = b.strand" in strand_filter: + strand_mode = "same" + elif "a.strand != b.strand" in strand_filter: + strand_mode = "opposite" + + bedtools_result = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + strand_mode=strand_mode, + ) + + where_clause = "WHERE a.interval INTERSECTS b.interval" + if strand_filter: + where_clause += f" AND {strand_filter}" + + sql = transpile( + f""" + SELECT DISTINCT a.* + FROM intervals_a a, intervals_b b + {where_clause} + """, + tables=["intervals_a", "intervals_b"], + ) + giql_result = duckdb_connection.execute(sql).fetchall() + + return compare_results(giql_result, bedtools_result) + + +def 
test_intersect_single_bp_overlap(duckdb_connection): + """ + GIVEN two intervals overlapping by exactly 1bp + WHEN GIQL INTERSECTS is compared to bedtools intersect + THEN both detect the 1bp overlap + """ + a = [GenomicInterval("chr1", 100, 200, "a1", 0, "+")] + b = [GenomicInterval("chr1", 199, 300, "b1", 0, "+")] + comparison = _run_intersect_comparison(duckdb_connection, a, b) + assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 1 + + +def test_intersect_containment_a_contains_b(duckdb_connection): + """ + GIVEN interval A fully contains interval B + WHEN GIQL INTERSECTS is compared to bedtools intersect + THEN A is reported as intersecting + """ + a = [GenomicInterval("chr1", 100, 500, "a1", 0, "+")] + b = [GenomicInterval("chr1", 200, 300, "b1", 0, "+")] + comparison = _run_intersect_comparison(duckdb_connection, a, b) + assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 1 + + +def test_intersect_containment_b_contains_a(duckdb_connection): + """ + GIVEN interval B fully contains interval A + WHEN GIQL INTERSECTS is compared to bedtools intersect + THEN A is reported as intersecting + """ + a = [GenomicInterval("chr1", 200, 300, "a1", 0, "+")] + b = [GenomicInterval("chr1", 100, 500, "b1", 0, "+")] + comparison = _run_intersect_comparison(duckdb_connection, a, b) + assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 1 + + +def test_intersect_deduplication(duckdb_connection): + """ + GIVEN one interval in A overlapping multiple intervals in B + WHEN GIQL INTERSECTS with DISTINCT is compared to bedtools intersect -u + THEN A interval reported once + """ + a = [GenomicInterval("chr1", 100, 300, "a1", 0, "+")] + b = [ + GenomicInterval("chr1", 150, 200, "b1", 0, "+"), + GenomicInterval("chr1", 200, 250, "b2", 0, "+"), + GenomicInterval("chr1", 250, 350, "b3", 0, "+"), + ] + comparison = _run_intersect_comparison(duckdb_connection, a, b) + 
assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 1 + + +def test_intersect_non_standard_chroms(duckdb_connection): + """ + GIVEN intervals on non-standard chromosome names (chrM, chrUn) + WHEN GIQL INTERSECTS is compared to bedtools intersect + THEN results match regardless of chromosome naming + """ + a = [ + GenomicInterval("chrM", 100, 200, "a1", 0, "+"), + GenomicInterval("chrUn", 100, 200, "a2", 0, "+"), + ] + b = [ + GenomicInterval("chrM", 150, 250, "b1", 0, "+"), + GenomicInterval("chrUn", 150, 250, "b2", 0, "+"), + ] + comparison = _run_intersect_comparison(duckdb_connection, a, b) + assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 2 + + +def test_intersect_large_intervals(duckdb_connection): + """ + GIVEN very large genomic intervals (spanning millions of bases) + WHEN GIQL INTERSECTS is compared to bedtools intersect + THEN results match correctly + """ + a = [GenomicInterval("chr1", 0, 10_000_000, "a1", 0, "+")] + b = [GenomicInterval("chr1", 5_000_000, 15_000_000, "b1", 0, "+")] + comparison = _run_intersect_comparison(duckdb_connection, a, b) + assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 1 + + +def test_intersect_many_intervals_scale(duckdb_connection): + """ + GIVEN a generated dataset with 100 intervals per chromosome on 3 chromosomes + WHEN GIQL INTERSECTS is compared to bedtools intersect + THEN results match on the full dataset + """ + import random + + rng = random.Random(42) + intervals_a = [] + intervals_b = [] + + for chrom_num in range(1, 4): + chrom = f"chr{chrom_num}" + for i in range(100): + start = rng.randint(0, 900_000) + size = rng.randint(100, 1000) + strand = rng.choice(["+", "-"]) + intervals_a.append( + GenomicInterval( + chrom, + start, + start + size, + f"a_{chrom_num}_{i}", + 0, + strand, + ) + ) + start = rng.randint(0, 900_000) + size = rng.randint(100, 1000) + strand = rng.choice(["+", "-"]) + 
intervals_b.append( + GenomicInterval( + chrom, + start, + start + size, + f"b_{chrom_num}_{i}", + 0, + strand, + ) + ) + + comparison = _run_intersect_comparison(duckdb_connection, intervals_a, intervals_b) + assert comparison.match, comparison.failure_message() + + +def test_intersect_same_strand_correctness(duckdb_connection): + """ + GIVEN overlapping intervals with mixed strands + WHEN GIQL INTERSECTS with same-strand filter is compared to bedtools -s + THEN only same-strand overlaps match + """ + a = [ + GenomicInterval("chr1", 100, 200, "a_plus", 0, "+"), + GenomicInterval("chr1", 100, 200, "a_minus", 0, "-"), + ] + b = [GenomicInterval("chr1", 150, 250, "b_plus", 0, "+")] + comparison = _run_intersect_comparison( + duckdb_connection, + a, + b, + strand_filter="a.strand = b.strand", + ) + assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 1 + + +def test_intersect_opposite_strand_correctness(duckdb_connection): + """ + GIVEN overlapping intervals with mixed strands + WHEN GIQL INTERSECTS with opposite-strand filter is compared to bedtools -S + THEN only opposite-strand overlaps match + """ + a = [ + GenomicInterval("chr1", 100, 200, "a_plus", 0, "+"), + GenomicInterval("chr1", 100, 200, "a_minus", 0, "-"), + ] + b = [GenomicInterval("chr1", 150, 250, "b_plus", 0, "+")] + comparison = _run_intersect_comparison( + duckdb_connection, + a, + b, + strand_filter="a.strand != b.strand", + ) + assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 1 diff --git a/tests/integration/bedtools/test_correctness_merge.py b/tests/integration/bedtools/test_correctness_merge.py new file mode 100644 index 0000000..9cdb987 --- /dev/null +++ b/tests/integration/bedtools/test_correctness_merge.py @@ -0,0 +1,207 @@ +"""Extended correctness tests for GIQL MERGE operator vs bedtools merge. 
+ +These tests cover transitive chains, topology variations, and scale scenarios +to ensure comprehensive GIQL/bedtools equivalence for merge operations. +""" + +from giql import transpile + +from .utils.bedtools_wrapper import merge +from .utils.comparison import compare_results +from .utils.data_models import GenomicInterval +from .utils.duckdb_loader import load_intervals + + +def _run_merge_comparison(duckdb_connection, intervals, strand_mode=None): + """Run GIQL MERGE and bedtools merge, return ComparisonResult.""" + load_intervals( + duckdb_connection, + "intervals", + [i.to_tuple() for i in intervals], + ) + + bedtools_result = merge( + [i.to_tuple() for i in intervals], + strand_mode=strand_mode, + ) + + if strand_mode == "same": + giql_sql = "SELECT MERGE(interval, stranded := true) FROM intervals" + else: + giql_sql = "SELECT MERGE(interval) FROM intervals" + + sql = transpile(giql_sql, tables=["intervals"]) + giql_result = duckdb_connection.execute(sql).fetchall() + + return compare_results(giql_result, bedtools_result) + + +def test_merge_transitive_chain(duckdb_connection): + """ + GIVEN a chain A overlaps B, B overlaps C (but A doesn't overlap C directly) + WHEN GIQL MERGE is compared to bedtools merge + THEN entire chain merged into single interval + """ + intervals = [ + GenomicInterval("chr1", 100, 200, "i1", 0, "+"), + GenomicInterval("chr1", 180, 300, "i2", 0, "+"), + GenomicInterval("chr1", 280, 400, "i3", 0, "+"), + ] + comparison = _run_merge_comparison(duckdb_connection, intervals) + assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 1 + + +def test_merge_single_interval(duckdb_connection): + """ + GIVEN a single interval + WHEN GIQL MERGE is compared to bedtools merge + THEN single interval returned unchanged + """ + intervals = [GenomicInterval("chr1", 100, 200, "i1", 0, "+")] + comparison = _run_merge_comparison(duckdb_connection, intervals) + assert comparison.match, comparison.failure_message() + 
assert comparison.giql_row_count == 1 + + +def test_merge_complete_overlap(duckdb_connection): + """ + GIVEN all intervals on chromosome overlap (one big region) + WHEN GIQL MERGE is compared to bedtools merge + THEN single merged interval + """ + intervals = [ + GenomicInterval("chr1", 100, 500, "i1", 0, "+"), + GenomicInterval("chr1", 200, 400, "i2", 0, "+"), + GenomicInterval("chr1", 300, 600, "i3", 0, "+"), + GenomicInterval("chr1", 150, 550, "i4", 0, "+"), + ] + comparison = _run_merge_comparison(duckdb_connection, intervals) + assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 1 + + +def test_merge_mixed_topology(duckdb_connection): + """ + GIVEN a mix of overlapping clusters and isolated intervals + WHEN GIQL MERGE is compared to bedtools merge + THEN correct number of merged regions + """ + intervals = [ + # Cluster 1: overlapping + GenomicInterval("chr1", 100, 200, "c1a", 0, "+"), + GenomicInterval("chr1", 150, 300, "c1b", 0, "+"), + # Isolated + GenomicInterval("chr1", 500, 600, "iso", 0, "+"), + # Cluster 2: overlapping + GenomicInterval("chr1", 800, 900, "c2a", 0, "+"), + GenomicInterval("chr1", 850, 1000, "c2b", 0, "+"), + ] + comparison = _run_merge_comparison(duckdb_connection, intervals) + assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 3 + + +def test_merge_minimal_overlap(duckdb_connection): + """ + GIVEN intervals with exactly 1bp overlap + WHEN GIQL MERGE is compared to bedtools merge + THEN 1bp overlap triggers merge + """ + intervals = [ + GenomicInterval("chr1", 100, 200, "i1", 0, "+"), + GenomicInterval("chr1", 199, 300, "i2", 0, "+"), + ] + comparison = _run_merge_comparison(duckdb_connection, intervals) + assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 1 + + +def test_merge_unsorted_input(duckdb_connection): + """ + GIVEN intervals inserted in non-sorted order + WHEN GIQL MERGE is compared to bedtools merge + 
THEN results match regardless of input order + """ + intervals = [ + GenomicInterval("chr1", 400, 500, "i3", 0, "+"), + GenomicInterval("chr1", 100, 200, "i1", 0, "+"), + GenomicInterval("chr1", 150, 250, "i2", 0, "+"), + ] + comparison = _run_merge_comparison(duckdb_connection, intervals) + assert comparison.match, comparison.failure_message() + + +def test_merge_per_chromosome(duckdb_connection): + """ + GIVEN overlapping intervals on separate chromosomes + WHEN GIQL MERGE is compared to bedtools merge + THEN merging occurs per-chromosome independently + """ + intervals = [ + GenomicInterval("chr1", 100, 200, "c1a", 0, "+"), + GenomicInterval("chr1", 150, 300, "c1b", 0, "+"), + GenomicInterval("chr2", 100, 200, "c2a", 0, "+"), + GenomicInterval("chr2", 150, 300, "c2b", 0, "+"), + GenomicInterval("chr3", 100, 200, "c3", 0, "+"), # no overlap + ] + comparison = _run_merge_comparison(duckdb_connection, intervals) + assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 3 # 1 per chrom + + +def test_merge_strand_specific_correctness(duckdb_connection): + """ + GIVEN overlapping intervals on different strands + WHEN GIQL MERGE(stranded=true) is compared to bedtools merge -s + THEN per-strand merge count matches + """ + intervals = [ + GenomicInterval("chr1", 100, 200, "i1", 0, "+"), + GenomicInterval("chr1", 150, 250, "i2", 0, "+"), + GenomicInterval("chr1", 120, 220, "i3", 0, "-"), + GenomicInterval("chr1", 180, 280, "i4", 0, "-"), + ] + load_intervals( + duckdb_connection, + "intervals", + [i.to_tuple() for i in intervals], + ) + + bedtools_result = merge( + [i.to_tuple() for i in intervals], + strand_mode="same", + ) + + sql = transpile( + "SELECT MERGE(interval, stranded := true) FROM intervals", + tables=["intervals"], + ) + giql_result = duckdb_connection.execute(sql).fetchall() + + # Both should have 2 merged intervals (one per strand) + assert len(giql_result) == len(bedtools_result) + + +def 
test_merge_large_scale(duckdb_connection): + """ + GIVEN 300 generated intervals (100 per chromosome across 3 chromosomes) + WHEN GIQL MERGE is compared to bedtools merge + THEN results match on the full dataset + """ + import random + + rng = random.Random(42) + intervals = [] + + for chrom_num in range(1, 4): + chrom = f"chr{chrom_num}" + for i in range(100): + start = rng.randint(0, 500_000) + size = rng.randint(100, 2000) + intervals.append( + GenomicInterval(chrom, start, start + size, f"{chrom}_{i}", 0, "+") + ) + + comparison = _run_merge_comparison(duckdb_connection, intervals) + assert comparison.match, comparison.failure_message() diff --git a/tests/integration/bedtools/test_correctness_nearest.py b/tests/integration/bedtools/test_correctness_nearest.py new file mode 100644 index 0000000..80bb552 --- /dev/null +++ b/tests/integration/bedtools/test_correctness_nearest.py @@ -0,0 +1,288 @@ +"""Extended correctness tests for GIQL NEAREST operator vs bedtools closest. + +These tests cover distance calculations, multi-query scenarios, and scale +to ensure comprehensive GIQL/bedtools equivalence for nearest operations.
+""" + +from giql import transpile + +from .utils.bedtools_wrapper import closest +from .utils.data_models import GenomicInterval +from .utils.duckdb_loader import load_intervals + + +def _load_and_query_nearest( + duckdb_connection, + intervals_a, + intervals_b, + *, + k=1, + stranded=False, +): + """Load intervals, run GIQL NEAREST and bedtools closest, return both results.""" + load_intervals( + duckdb_connection, + "intervals_a", + [i.to_tuple() for i in intervals_a], + ) + load_intervals( + duckdb_connection, + "intervals_b", + [i.to_tuple() for i in intervals_b], + ) + + strand_mode = "same" if stranded else None + bedtools_result = closest( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + strand_mode=strand_mode, + k=k, + ) + + stranded_arg = ", stranded := true" if stranded else "" + sql = transpile( + f""" + SELECT a.*, b.* + FROM intervals_a a + CROSS JOIN LATERAL NEAREST( + intervals_b, + reference := a.interval, + k := {k}{stranded_arg} + ) b + ORDER BY a.chrom, a.start + """, + tables=["intervals_a", "intervals_b"], + ) + giql_result = duckdb_connection.execute(sql).fetchall() + + return giql_result, bedtools_result + + +def test_nearest_overlapping_distance_zero(duckdb_connection): + """ + GIVEN overlapping intervals in A and B + WHEN GIQL NEAREST is compared to bedtools closest + THEN overlapping intervals report distance=0 in bedtools + """ + a = [GenomicInterval("chr1", 100, 300, "a1", 0, "+")] + b = [GenomicInterval("chr1", 200, 400, "b1", 0, "+")] + giql_result, bedtools_result = _load_and_query_nearest(duckdb_connection, a, b) + + assert len(giql_result) == len(bedtools_result) == 1 + # bedtools closest -d reports 0 for overlapping + assert bedtools_result[0][-1] == 0 + + +def test_nearest_adjacent_distance_zero(duckdb_connection): + """ + GIVEN adjacent intervals (touching, half-open coords) + WHEN GIQL NEAREST is compared to bedtools closest + THEN adjacent intervals report distance=0 + """ + a = 
[GenomicInterval("chr1", 100, 200, "a1", 0, "+")] + b = [GenomicInterval("chr1", 200, 300, "b1", 0, "+")] + giql_result, bedtools_result = _load_and_query_nearest(duckdb_connection, a, b) + + assert len(giql_result) == len(bedtools_result) == 1 + # bedtools 2.31+ reports 1 for adjacent non-overlapping intervals + # in half-open coordinates (distance includes the gap base) + assert bedtools_result[0][-1] <= 1 + assert giql_result[0][9] == "b1" + + +def test_nearest_upstream_distance(duckdb_connection): + """ + GIVEN B interval far upstream of A + WHEN GIQL NEAREST is compared to bedtools closest + THEN distance calculated correctly + """ + a = [GenomicInterval("chr1", 500, 600, "a1", 0, "+")] + b = [GenomicInterval("chr1", 100, 200, "b1", 0, "+")] + giql_result, bedtools_result = _load_and_query_nearest(duckdb_connection, a, b) + + assert len(giql_result) == len(bedtools_result) == 1 + # Distance: 500 - 200 = 300 (half-open), bedtools may report 301 + assert bedtools_result[0][-1] in (300, 301) + assert giql_result[0][9] == "b1" + + +def test_nearest_downstream_distance(duckdb_connection): + """ + GIVEN B interval far downstream of A + WHEN GIQL NEAREST is compared to bedtools closest + THEN distance calculated correctly + """ + a = [GenomicInterval("chr1", 100, 200, "a1", 0, "+")] + b = [GenomicInterval("chr1", 500, 600, "b1", 0, "+")] + giql_result, bedtools_result = _load_and_query_nearest(duckdb_connection, a, b) + + assert len(giql_result) == len(bedtools_result) == 1 + # Distance: 500 - 200 = 300 (half-open), bedtools may report 301 + assert bedtools_result[0][-1] in (300, 301) + assert giql_result[0][9] == "b1" + + +def test_nearest_multi_query_correctness(duckdb_connection): + """ + GIVEN multiple query intervals and multiple candidates + WHEN GIQL NEAREST is compared to bedtools closest + THEN correct pairing for each query interval + """ + a = [ + GenomicInterval("chr1", 100, 200, "a1", 0, "+"), + GenomicInterval("chr1", 500, 600, "a2", 0, "+"), + 
GenomicInterval("chr1", 900, 1000, "a3", 0, "+"), + ] + b = [ + GenomicInterval("chr1", 250, 300, "b1", 0, "+"), + GenomicInterval("chr1", 700, 800, "b2", 0, "+"), + ] + giql_result, bedtools_result = _load_and_query_nearest(duckdb_connection, a, b) + + assert len(giql_result) == len(bedtools_result) == 3 + + giql_sorted = sorted(giql_result, key=lambda r: (r[0], r[1])) + bt_sorted = sorted(bedtools_result, key=lambda r: (r[0], r[1])) + + for giql_row, bt_row in zip(giql_sorted, bt_sorted): + assert giql_row[3] == bt_row[3] # a.name matches + assert giql_row[9] == bt_row[9] # b.name matches + + +def test_nearest_k3_correctness(duckdb_connection): + """ + GIVEN one query interval and 4 database intervals + WHEN GIQL NEAREST(k=3) is compared to bedtools closest -k 3 + THEN both return 3 nearest intervals + """ + a = [GenomicInterval("chr1", 400, 500, "a1", 0, "+")] + b = [ + GenomicInterval("chr1", 100, 150, "b_far", 0, "+"), + GenomicInterval("chr1", 350, 390, "b_near", 0, "+"), + GenomicInterval("chr1", 550, 600, "b_close", 0, "+"), + GenomicInterval("chr1", 900, 1000, "b_farther", 0, "+"), + ] + giql_result, bedtools_result = _load_and_query_nearest(duckdb_connection, a, b, k=3) + + assert len(giql_result) == 3 + assert len(bedtools_result) == 3 + + giql_names = {r[9] for r in giql_result} + bt_names = {r[9] for r in bedtools_result} + assert giql_names == bt_names + + +def test_nearest_k_exceeds_available_correctness(duckdb_connection): + """ + GIVEN one query and only 2 database intervals, k=5 + WHEN GIQL NEAREST is compared to bedtools closest + THEN both return only 2 (available) results + """ + a = [GenomicInterval("chr1", 200, 300, "a1", 0, "+")] + b = [ + GenomicInterval("chr1", 100, 150, "b1", 0, "+"), + GenomicInterval("chr1", 400, 500, "b2", 0, "+"), + ] + giql_result, bedtools_result = _load_and_query_nearest(duckdb_connection, a, b, k=5) + + assert len(giql_result) == len(bedtools_result) == 2 + + +def 
test_nearest_same_strand_correctness(duckdb_connection): + """ + GIVEN intervals with candidates on same and opposite strands + WHEN GIQL NEAREST(stranded=true) is compared to bedtools closest -s + THEN only same-strand matches + """ + a = [GenomicInterval("chr1", 100, 200, "a1", 0, "+")] + b = [ + GenomicInterval("chr1", 220, 240, "b_opp", 0, "-"), # closer, opposite + GenomicInterval("chr1", 300, 400, "b_same", 0, "+"), # farther, same + ] + giql_result, bedtools_result = _load_and_query_nearest( + duckdb_connection, + a, + b, + stranded=True, + ) + + assert len(giql_result) == len(bedtools_result) == 1 + assert giql_result[0][9] == "b_same" + assert bedtools_result[0][9] == "b_same" + + +def test_nearest_strand_ignorant_correctness(duckdb_connection): + """ + GIVEN intervals on different strands + WHEN GIQL NEAREST (default) is compared to bedtools closest (default) + THEN nearest found regardless of strand + """ + a = [GenomicInterval("chr1", 100, 200, "a1", 0, "+")] + b = [ + GenomicInterval("chr1", 250, 300, "b_far", 0, "+"), + GenomicInterval("chr1", 210, 230, "b_near", 0, "-"), + ] + giql_result, bedtools_result = _load_and_query_nearest(duckdb_connection, a, b) + + assert len(giql_result) == len(bedtools_result) == 1 + assert giql_result[0][9] == "b_near" + assert bedtools_result[0][9] == "b_near" + + +def test_nearest_cross_chromosome_isolation(duckdb_connection): + """ + GIVEN intervals on multiple chromosomes + WHEN GIQL NEAREST is compared to bedtools closest + THEN nearest found per-chromosome only + """ + a = [ + GenomicInterval("chr1", 100, 200, "a1", 0, "+"), + GenomicInterval("chr2", 100, 200, "a2", 0, "+"), + ] + b = [ + GenomicInterval("chr1", 500, 600, "b1", 0, "+"), + GenomicInterval("chr2", 300, 400, "b2", 0, "+"), + ] + giql_result, bedtools_result = _load_and_query_nearest(duckdb_connection, a, b) + + assert len(giql_result) == len(bedtools_result) == 2 + + for giql_row in giql_result: + assert giql_row[0] == giql_row[6], "A and B should be 
on same chromosome" + + +def test_nearest_large_scale(duckdb_connection): + """ + GIVEN 50+ intervals per table across 3 chromosomes + WHEN GIQL NEAREST is compared to bedtools closest + THEN row counts match on the full dataset + """ + import random + + rng = random.Random(42) + intervals_a = [] + intervals_b = [] + + for chrom_num in range(1, 4): + chrom = f"chr{chrom_num}" + for i in range(50): + start = rng.randint(0, 900_000) + size = rng.randint(100, 1000) + intervals_a.append( + GenomicInterval(chrom, start, start + size, f"a_{chrom_num}_{i}", 0, "+") + ) + start = rng.randint(0, 900_000) + size = rng.randint(100, 1000) + intervals_b.append( + GenomicInterval(chrom, start, start + size, f"b_{chrom_num}_{i}", 0, "+") + ) + + giql_result, bedtools_result = _load_and_query_nearest( + duckdb_connection, + intervals_a, + intervals_b, + ) + + assert len(giql_result) == len(bedtools_result), ( + f"Row count mismatch: GIQL={len(giql_result)}, bedtools={len(bedtools_result)}" + ) diff --git a/tests/integration/bedtools/test_correctness_workflows.py b/tests/integration/bedtools/test_correctness_workflows.py new file mode 100644 index 0000000..26316fe --- /dev/null +++ b/tests/integration/bedtools/test_correctness_workflows.py @@ -0,0 +1,342 @@ +"""Integration correctness tests for multi-operation GIQL workflows. + +These tests validate that chained GIQL operations produce results matching +equivalent bedtools command pipelines. Corresponds to User Story 4 (P3) +from the bedtools integration test spec. 
+""" + +from giql import transpile + +from .utils.bedtools_wrapper import closest +from .utils.bedtools_wrapper import intersect +from .utils.bedtools_wrapper import merge +from .utils.comparison import compare_results +from .utils.data_models import GenomicInterval +from .utils.duckdb_loader import load_intervals + + +def test_workflow_intersect_then_merge(duckdb_connection): + """ + GIVEN two interval sets with overlaps + WHEN GIQL: intersect then merge (via subquery) + vs bedtools: intersect | sort | merge + THEN final merged intervals match + """ + intervals_a = [ + GenomicInterval("chr1", 100, 200, "a1", 0, "+"), + GenomicInterval("chr1", 150, 300, "a2", 0, "+"), + GenomicInterval("chr1", 500, 600, "a3", 0, "+"), + ] + intervals_b = [ + GenomicInterval("chr1", 180, 250, "b1", 0, "+"), + GenomicInterval("chr1", 520, 580, "b2", 0, "+"), + ] + + load_intervals(duckdb_connection, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(duckdb_connection, "intervals_b", [i.to_tuple() for i in intervals_b]) + + # bedtools pipeline: intersect then merge + intersect_result = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + ) + bedtools_final = merge(intersect_result) + + # GIQL: use CTE to intersect, then merge + sql = transpile( + """ + WITH hits AS ( + SELECT DISTINCT a.* + FROM intervals_a a, intervals_b b + WHERE a.interval INTERSECTS b.interval + ) + SELECT MERGE(interval) + FROM hits + """, + tables=["intervals_a", "intervals_b", "hits"], + ) + giql_result = duckdb_connection.execute(sql).fetchall() + + comparison = compare_results(giql_result, bedtools_final) + assert comparison.match, comparison.failure_message() + + +def test_workflow_nearest_then_filter_distance(duckdb_connection): + """ + GIVEN two interval sets + WHEN GIQL: NEAREST with max_distance filter + vs bedtools: closest -d then filter by distance + THEN filtered nearest results match + """ + intervals_a = [ + GenomicInterval("chr1", 100, 200, 
"a1", 0, "+"), + GenomicInterval("chr1", 500, 600, "a2", 0, "+"), + ] + intervals_b = [ + GenomicInterval("chr1", 220, 250, "b_near", 0, "+"), # 20bp from a1 + GenomicInterval("chr1", 900, 1000, "b_far", 0, "+"), # 300bp from a2 + ] + + load_intervals(duckdb_connection, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(duckdb_connection, "intervals_b", [i.to_tuple() for i in intervals_b]) + + # bedtools: closest -d, then filter distance <= 50 + bt_result = closest( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + ) + bedtools_filtered = [row for row in bt_result if row[-1] <= 50] + + # GIQL: NEAREST with max_distance + sql = transpile( + """ + SELECT a.name, b.name + FROM intervals_a a + CROSS JOIN LATERAL NEAREST( + intervals_b, + reference := a.interval, + k := 1, + max_distance := 50 + ) b + """, + tables=["intervals_a", "intervals_b"], + ) + giql_result = duckdb_connection.execute(sql).fetchall() + + # Both should return only a1->b_near (distance 20 <= 50) + # a2->b_far (distance 300 > 50) should be excluded + assert len(giql_result) == len(bedtools_filtered) + if len(giql_result) > 0: + giql_names = {r[0] for r in giql_result} + assert "a1" in giql_names + + +def test_workflow_merge_then_intersect(duckdb_connection): + """ + GIVEN intervals with overlaps and a second interval set + WHEN GIQL: merge intervals then intersect with second set + vs bedtools: merge | intersect + THEN results match + """ + intervals_a = [ + GenomicInterval("chr1", 100, 200, "a1", 0, "+"), + GenomicInterval("chr1", 180, 300, "a2", 0, "+"), + GenomicInterval("chr1", 500, 600, "a3", 0, "+"), + ] + intervals_b = [ + GenomicInterval("chr1", 250, 350, "b1", 0, "+"), + GenomicInterval("chr1", 550, 650, "b2", 0, "+"), + ] + + load_intervals(duckdb_connection, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(duckdb_connection, "intervals_b", [i.to_tuple() for i in intervals_b]) + + # bedtools pipeline: merge a, then 
intersect with b + merged_a = merge([i.to_tuple() for i in intervals_a]) + bedtools_final = intersect(merged_a, [i.to_tuple() for i in intervals_b]) + + # GIQL: CTE to merge, then intersect + sql = transpile( + """ + WITH merged AS ( + SELECT MERGE(interval) AS interval + FROM intervals_a + ) + SELECT DISTINCT m.* + FROM merged m, intervals_b b + WHERE m.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b", "merged"], + ) + giql_result = duckdb_connection.execute(sql).fetchall() + + # MERGE outputs BED3 (chrom, start, end); compare only coordinates + bedtools_coords = [row[:3] for row in bedtools_final] + comparison = compare_results(giql_result, bedtools_coords) + assert comparison.match, comparison.failure_message() + + +def test_workflow_stranded_intersect_merge(duckdb_connection): + """ + GIVEN intervals with strand info + WHEN GIQL: strand-specific intersect then merge + vs bedtools: intersect -s | sort | merge + THEN strand-aware pipeline results match + """ + intervals_a = [ + GenomicInterval("chr1", 100, 200, "a1", 0, "+"), + GenomicInterval("chr1", 150, 300, "a2", 0, "+"), + GenomicInterval("chr1", 120, 250, "a3", 0, "-"), + ] + intervals_b = [ + GenomicInterval("chr1", 180, 250, "b1", 0, "+"), + GenomicInterval("chr1", 130, 220, "b2", 0, "-"), + ] + + load_intervals(duckdb_connection, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(duckdb_connection, "intervals_b", [i.to_tuple() for i in intervals_b]) + + # bedtools pipeline: intersect -s then merge + intersect_result = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + strand_mode="same", + ) + bedtools_final = merge(intersect_result) + + # GIQL: same-strand intersect via CTE then merge + sql = transpile( + """ + WITH hits AS ( + SELECT DISTINCT a.* + FROM intervals_a a, intervals_b b + WHERE a.interval INTERSECTS b.interval + AND a.strand = b.strand + ) + SELECT MERGE(interval) + FROM hits + """, + 
tables=["intervals_a", "intervals_b", "hits"], + ) + giql_result = duckdb_connection.execute(sql).fetchall() + + comparison = compare_results(giql_result, bedtools_final) + assert comparison.match, comparison.failure_message() + + +def test_workflow_intersect_filter_chrom_merge(duckdb_connection): + """ + GIVEN two interval sets on multiple chromosomes + WHEN GIQL: intersect, keep only chr1, then merge + vs bedtools: intersect | grep chr1 | sort | merge + THEN filtered-chromosome workflow matches + """ + intervals_a = [ + GenomicInterval("chr1", 100, 200, "a1", 0, "+"), + GenomicInterval("chr1", 150, 300, "a2", 0, "+"), + GenomicInterval("chr2", 100, 200, "a3", 0, "+"), + ] + intervals_b = [ + GenomicInterval("chr1", 180, 250, "b1", 0, "+"), + GenomicInterval("chr2", 150, 250, "b2", 0, "+"), + ] + + load_intervals(duckdb_connection, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(duckdb_connection, "intervals_b", [i.to_tuple() for i in intervals_b]) + + # bedtools pipeline: intersect, filter chr1, merge + intersect_result = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + ) + chr1_only = [r for r in intersect_result if r[0] == "chr1"] + bedtools_final = merge(chr1_only) if chr1_only else [] + + # GIQL: CTE intersect with chr1 filter, then merge + sql = transpile( + """ + WITH chr1_hits AS ( + SELECT DISTINCT a.* + FROM intervals_a a, intervals_b b + WHERE a.interval INTERSECTS b.interval + AND a.chrom = 'chr1' + ) + SELECT MERGE(interval) + FROM chr1_hits + """, + tables=["intervals_a", "intervals_b", "chr1_hits"], + ) + giql_result = duckdb_connection.execute(sql).fetchall() + + comparison = compare_results(giql_result, bedtools_final) + assert comparison.match, comparison.failure_message() + + +def test_workflow_full_pipeline_step_by_step(duckdb_connection): + """ + GIVEN a generated dataset across 3 chromosomes + WHEN full pipeline (intersect -> merge -> nearest) is run + THEN each intermediate step 
matches bedtools + """ + import random + + rng = random.Random(99) + intervals_a = [] + intervals_b = [] + intervals_c = [] + + for chrom_num in range(1, 4): + chrom = f"chr{chrom_num}" + for i in range(30): + start = rng.randint(0, 100_000) + size = rng.randint(100, 1000) + intervals_a.append( + GenomicInterval(chrom, start, start + size, f"a_{chrom_num}_{i}", 0, "+") + ) + for i in range(30): + start = rng.randint(0, 100_000) + size = rng.randint(100, 1000) + intervals_b.append( + GenomicInterval(chrom, start, start + size, f"b_{chrom_num}_{i}", 0, "+") + ) + for i in range(10): + start = rng.randint(0, 100_000) + size = rng.randint(100, 1000) + intervals_c.append( + GenomicInterval(chrom, start, start + size, f"c_{chrom_num}_{i}", 0, "+") + ) + + load_intervals(duckdb_connection, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(duckdb_connection, "intervals_b", [i.to_tuple() for i in intervals_b]) + load_intervals(duckdb_connection, "intervals_c", [i.to_tuple() for i in intervals_c]) + + # Step 1: Intersect A with B + bt_intersected = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + ) + + sql_step1 = transpile( + """ + SELECT DISTINCT a.* + FROM intervals_a a, intervals_b b + WHERE a.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b"], + ) + giql_step1 = duckdb_connection.execute(sql_step1).fetchall() + + comparison1 = compare_results(giql_step1, bt_intersected) + assert comparison1.match, ( + f"Step 1 (intersect) failed: {comparison1.failure_message()}" + ) + + # Step 2: Merge the intersected results + if bt_intersected: + bt_merged = merge(bt_intersected) + else: + bt_merged = [] + + if giql_step1: + # Create temp table from step 1 results for step 2 + duckdb_connection.execute(""" + CREATE TABLE step1_results AS + SELECT * FROM ( + SELECT DISTINCT a.* + FROM intervals_a a, intervals_b b + WHERE a.chrom = b.chrom + AND a."start" < b."end" + AND a."end" > b."start" + ) + 
""") + + sql_step2 = transpile( + "SELECT MERGE(interval) FROM step1_results", + tables=["step1_results"], + ) + giql_step2 = duckdb_connection.execute(sql_step2).fetchall() + + comparison2 = compare_results(giql_step2, bt_merged) + assert comparison2.match, ( + f"Step 2 (merge) failed: {comparison2.failure_message()}" + ) diff --git a/tests/test_coverage.py b/tests/test_coverage.py new file mode 100644 index 0000000..d0dfc85 --- /dev/null +++ b/tests/test_coverage.py @@ -0,0 +1,1001 @@ +"""Tests for the COVERAGE operator. + +Test specification: specs/test_coverage.md +""" + +import duckdb +import pytest +from hypothesis import HealthCheck +from hypothesis import given +from hypothesis import settings +from hypothesis import strategies as st +from sqlglot import exp +from sqlglot import parse_one + +from giql import Table +from giql import transpile +from giql.dialect import GIQLDialect +from giql.expressions import GIQLCoverage +from giql.table import Tables +from giql.transformer import CoverageTransformer + +VALID_STATS = ["count", "mean", "sum", "min", "max"] + + +class TestGIQLCoverage: + """Tests for GIQLCoverage expression node parsing.""" + + # ------------------------------------------------------------------ + # Example-based parsing (COV-001 to COV-007) + # ------------------------------------------------------------------ + + def test_from_arg_list_with_positional_args(self): + """Test positional interval and resolution mapping. 
+ + Given: + A COVERAGE expression with positional interval and resolution + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLCoverage node with resolution set and + stat/target both None + """ + # Act + ast = parse_one( + "SELECT COVERAGE(interval, 1000) FROM features", + dialect=GIQLDialect, + ) + + # Assert + coverage = list(ast.find_all(GIQLCoverage)) + assert len(coverage) == 1 + assert coverage[0].args["resolution"].this == "1000" + assert coverage[0].args.get("stat") is None + assert coverage[0].args.get("target") is None + + def test_from_arg_list_with_walrus_named_stat(self): + """Test named stat parameter via := syntax. + + Given: + A COVERAGE expression with := named stat parameter + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLCoverage node with stat set to the given value + """ + # Act + ast = parse_one( + "SELECT COVERAGE(interval, 500, stat := 'mean') FROM features", + dialect=GIQLDialect, + ) + + # Assert + coverage = list(ast.find_all(GIQLCoverage)) + assert len(coverage) == 1 + assert coverage[0].args["stat"].this == "mean" + + def test_from_arg_list_with_arrow_named_stat(self): + """Test named stat parameter via => syntax. + + Given: + A COVERAGE expression with => named stat parameter + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLCoverage node with stat set to the given value + """ + # Act + ast = parse_one( + "SELECT COVERAGE(interval, 500, stat => 'mean') FROM features", + dialect=GIQLDialect, + ) + + # Assert + coverage = list(ast.find_all(GIQLCoverage)) + assert len(coverage) == 1 + assert coverage[0].args["stat"].this == "mean" + + def test_from_arg_list_with_named_resolution(self): + """Test named resolution parameter. 
+ + Given: + A COVERAGE expression with named resolution parameter + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLCoverage node with resolution set via named param + """ + # Act + ast = parse_one( + "SELECT COVERAGE(interval, resolution := 1000) FROM features", + dialect=GIQLDialect, + ) + + # Assert + coverage = list(ast.find_all(GIQLCoverage)) + assert len(coverage) == 1 + assert coverage[0].args["resolution"].this == "1000" + + def test_from_arg_list_with_walrus_named_target(self): + """Test target parameter via := syntax. + + Given: + A COVERAGE expression with := named target parameter + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLCoverage node with target set + """ + # Act + ast = parse_one( + "SELECT COVERAGE(interval, 1000, target := 'score') FROM features", + dialect=GIQLDialect, + ) + + # Assert + coverage = list(ast.find_all(GIQLCoverage)) + assert len(coverage) == 1 + assert coverage[0].args["target"].this == "score" + + def test_from_arg_list_with_arrow_named_target(self): + """Test target parameter via => syntax. + + Given: + A COVERAGE expression with => named target parameter + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLCoverage node with target set + """ + # Act + ast = parse_one( + "SELECT COVERAGE(interval, 1000, target => 'score') FROM features", + dialect=GIQLDialect, + ) + + # Assert + coverage = list(ast.find_all(GIQLCoverage)) + assert len(coverage) == 1 + assert coverage[0].args["target"].this == "score" + + def test_from_arg_list_with_all_named_params(self): + """Test all parameters provided as named arguments. 
+ + Given: + A COVERAGE expression with stat, target, and resolution all named + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLCoverage node with all three params set + """ + # Act + ast = parse_one( + "SELECT COVERAGE(interval, resolution := 500, " + "stat := 'mean', target := 'score') FROM features", + dialect=GIQLDialect, + ) + + # Assert + coverage = list(ast.find_all(GIQLCoverage)) + assert len(coverage) == 1 + assert coverage[0].args["resolution"].this == "500" + assert coverage[0].args["stat"].this == "mean" + assert coverage[0].args["target"].this == "score" + + # ------------------------------------------------------------------ + # Property-based parsing (PBT-001 to PBT-003) + # ------------------------------------------------------------------ + + @given( + resolution=st.integers(min_value=1, max_value=10_000_000), + stat=st.sampled_from(VALID_STATS), + syntax=st.sampled_from([":=", "=>"]), + ) + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + def test_from_arg_list_with_varying_stat_and_resolution( + self, resolution, stat, syntax + ): + """Test stat and resolution parse correctly across input space. 
+ + Given: + Any valid resolution (1-10M), stat (sampled from valid values), + and syntax (:= or =>) + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLCoverage node with correct resolution and stat + """ + # Act + sql = ( + f"SELECT COVERAGE(interval, {resolution}, " + f"stat {syntax} '{stat}') FROM features" + ) + ast = parse_one(sql, dialect=GIQLDialect) + + # Assert + coverage = list(ast.find_all(GIQLCoverage)) + assert len(coverage) == 1 + assert coverage[0].args["resolution"].this == str(resolution) + assert coverage[0].args["stat"].this == stat + + @given(resolution=st.integers(min_value=1, max_value=10_000_000)) + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + def test_from_arg_list_with_varying_positional_only(self, resolution): + """Test positional-only parsing across resolution range. + + Given: + Any valid resolution (1-10M) with no stat or target + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLCoverage node with resolution set and + stat/target None + """ + # Act + ast = parse_one( + f"SELECT COVERAGE(interval, {resolution}) FROM features", + dialect=GIQLDialect, + ) + + # Assert + coverage = list(ast.find_all(GIQLCoverage)) + assert len(coverage) == 1 + assert coverage[0].args["resolution"].this == str(resolution) + assert coverage[0].args.get("stat") is None + assert coverage[0].args.get("target") is None + + @given(syntax=st.sampled_from([":=", "=>"])) + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + def test_from_arg_list_with_varying_target_syntax(self, syntax): + """Test target parameter parsing across syntax variants. 
+ + Given: + Either := or => syntax for target parameter + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLCoverage node with target set + """ + # Act + ast = parse_one( + f"SELECT COVERAGE(interval, 1000, target {syntax} 'score') FROM features", + dialect=GIQLDialect, + ) + + # Assert + coverage = list(ast.find_all(GIQLCoverage)) + assert len(coverage) == 1 + assert coverage[0].args["target"].this == "score" + + +class TestCoverageTransformer: + """Tests for CoverageTransformer.transform via transpile().""" + + # ------------------------------------------------------------------ + # Instantiation (CT-001) + # ------------------------------------------------------------------ + + def test___init___with_tables(self): + """Test CoverageTransformer stores its tables reference. + + Given: + A Tables container with registered tables + When: + CoverageTransformer is instantiated + Then: + It should store the tables reference + """ + # Arrange + tables = Tables() + tables.register("features", Table("features")) + + # Act + transformer = CoverageTransformer(tables) + + # Assert + assert transformer.tables is tables + + # ------------------------------------------------------------------ + # Basic transpilation (CT-002, CT-003) + # ------------------------------------------------------------------ + + def test_transform_with_basic_count(self): + """Test basic COVERAGE produces correct SQL structure. 
+ + Given: + A basic COVERAGE query with count (default stat) + When: + Transpiled + Then: + It should produce SQL with __giql_bins CTE, GENERATE_SERIES, + LEFT JOIN, GROUP BY, COUNT, and ORDER BY + """ + # Act + sql = transpile( + "SELECT COVERAGE(interval, 1000) FROM features", + tables=["features"], + ) + + # Assert + upper = sql.upper() + assert "__GIQL_BINS" in upper + assert "GENERATE_SERIES" in upper + assert "LEFT JOIN" in upper + assert "GROUP BY" in upper + assert "COUNT" in upper + assert "ORDER BY" in upper + + def test_transform_without_coverage_expression(self): + """Test non-COVERAGE query passes through unchanged. + + Given: + A query with no COVERAGE expression + When: + Transformed by CoverageTransformer + Then: + It should return the query unchanged + """ + # Arrange + tables = Tables() + tables.register("features", Table("features")) + transformer = CoverageTransformer(tables) + ast = parse_one("SELECT * FROM features", dialect=GIQLDialect) + + # Act + result = transformer.transform(ast) + + # Assert + assert result is ast + + # ------------------------------------------------------------------ + # Stat parameter (CT-004 to CT-007) + # ------------------------------------------------------------------ + + def test_transform_with_stat_mean(self): + """Test stat='mean' maps to AVG aggregate. + + Given: + A COVERAGE query with stat := 'mean' + When: + Transpiled + Then: + It should use AVG aggregate, not COUNT + """ + # Act + sql = transpile( + "SELECT COVERAGE(interval, 1000, stat := 'mean') FROM features", + tables=["features"], + ) + + # Assert + upper = sql.upper() + assert "AVG" in upper + assert "COUNT" not in upper + + def test_transform_with_stat_sum(self): + """Test stat='sum' maps to SUM aggregate. 
+ + Given: + A COVERAGE query with stat := 'sum' + When: + Transpiled + Then: + It should use SUM aggregate + """ + # Act + sql = transpile( + "SELECT COVERAGE(interval, 1000, stat := 'sum') FROM features", + tables=["features"], + ) + + # Assert + assert "SUM" in sql.upper() + + def test_transform_with_stat_min(self): + """Test stat='min' maps to MIN aggregate. + + Given: + A COVERAGE query with stat := 'min' + When: + Transpiled + Then: + It should use MIN aggregate + """ + # Act + sql = transpile( + "SELECT COVERAGE(interval, 1000, stat := 'min') FROM features", + tables=["features"], + ) + + # Assert + assert "MIN(" in sql.upper() + + def test_transform_with_stat_max(self): + """Test stat='max' maps to MAX aggregate. + + Given: + A COVERAGE query with stat := 'max' + When: + Transpiled + Then: + It should use MAX aggregate + """ + # Act + sql = transpile( + "SELECT COVERAGE(interval, 1000, stat := 'max') FROM features", + tables=["features"], + ) + + # Assert + assert "MAX(" in sql.upper() + + # ------------------------------------------------------------------ + # Target parameter (CT-008, CT-009) + # ------------------------------------------------------------------ + + def test_transform_with_target_and_mean(self): + """Test target column used with mean stat. + + Given: + A COVERAGE query with stat := 'mean' and target := 'score' + When: + Transpiled + Then: + It should use AVG on the score column + """ + # Act + sql = transpile( + "SELECT COVERAGE(interval, 1000, stat := 'mean', " + "target := 'score') FROM features", + tables=["features"], + ) + + # Assert + upper = sql.upper() + assert "AVG" in upper + assert "SCORE" in upper + + def test_transform_with_target_and_count(self): + """Test target column used with default count stat. 
+
+        Given:
+            A COVERAGE query with target := 'score' (default count)
+        When:
+            Transpiled
+        Then:
+            It should use COUNT on the score column, not COUNT(*)
+        """
+        # Act
+        sql = transpile(
+            "SELECT COVERAGE(interval, 1000, target := 'score') FROM features",
+            tables=["features"],
+        )
+
+        # Assert
+        upper = sql.upper()
+        assert "COUNT" in upper
+        assert "SCORE" in upper
+        # Neither a bare COUNT(*) nor a qualified star (alias.*) should appear
+        assert "(*)" not in sql
+        assert ".*)" not in sql
+
+    # ------------------------------------------------------------------
+    # Default alias (CT-010, CT-011)
+    # ------------------------------------------------------------------
+
+    def test_transform_with_default_alias(self):
+        """Test bare COVERAGE gets default 'value' alias.
+
+        Given:
+            A COVERAGE query without an explicit AS alias
+        When:
+            Transpiled
+        Then:
+            It should alias the aggregate as "value"
+        """
+        # Act
+        sql = transpile(
+            "SELECT COVERAGE(interval, 1000) FROM features",
+            tables=["features"],
+        )
+
+        # Assert
+        assert "AS value" in sql
+
+    def test_transform_with_explicit_alias(self):
+        """Test explicit AS alias overrides default.
+
+        Given:
+            A COVERAGE query with explicit AS alias
+        When:
+            Transpiled
+        Then:
+            It should use the explicit alias, not "value"
+        """
+        # Act
+        sql = transpile(
+            "SELECT COVERAGE(interval, 1000) AS depth FROM features",
+            tables=["features"],
+        )
+
+        # Assert
+        assert "AS depth" in sql
+        assert "AS value" not in sql
+
+    # ------------------------------------------------------------------
+    # WHERE clause semantics (CT-012, CT-013, CT-014)
+    # ------------------------------------------------------------------
+
+    def test_transform_where_moves_to_join_on(self):
+        """Test WHERE migrates into LEFT JOIN ON clause.
+ + Given: + A COVERAGE query with a WHERE clause + When: + Transpiled + Then: + It should move the WHERE condition into the LEFT JOIN ON clause, + not the outer WHERE + """ + # Act + sql = transpile( + "SELECT COVERAGE(interval, 1000) FROM features WHERE score > 10", + tables=["features"], + ) + + # Assert + upper = sql.upper() + assert "ON" in upper + assert "SCORE > 10" in upper + # The condition should be in the ON clause (between LEFT JOIN and GROUP BY) + after_join = sql.split("LEFT JOIN")[1] + on_clause = after_join.split("GROUP BY")[0] + assert "score > 10" in on_clause + + def test_transform_where_qualifies_columns_in_on(self): + """Test WHERE column references are qualified with source table in ON. + + Given: + A COVERAGE query with a WHERE clause + When: + Transpiled + Then: + It should qualify unqualified column references in the JOIN ON + with the source table + """ + # Act + sql = transpile( + "SELECT COVERAGE(interval, 1000) FROM features WHERE score > 10", + tables=["features"], + ) + + # Assert + after_join = sql.split("LEFT JOIN")[1] + on_clause = after_join.split("GROUP BY")[0] + assert "features.score" in on_clause + + def test_transform_where_applied_to_chroms_subquery(self): + """Test WHERE is also applied to the chroms subquery. 
+ + Given: + A COVERAGE query with a WHERE clause + When: + Transpiled + Then: + It should also apply the WHERE to the chroms subquery with + table-qualified columns + """ + # Act + sql = transpile( + "SELECT COVERAGE(interval, 1000) FROM features WHERE score > 10", + tables=["features"], + ) + + # Assert + # The chroms subquery is inside the CTE, before the outer SELECT + cte_part = sql.split(") SELECT")[0] + assert "features.score > 10" in cte_part + + # ------------------------------------------------------------------ + # Column mapping (CT-015) + # ------------------------------------------------------------------ + + def test_transform_with_custom_column_mapping(self): + """Test custom column names are used throughout. + + Given: + A COVERAGE query with custom column mappings + (chromosome, start_pos, end_pos) + When: + Transpiled + Then: + It should use the mapped column names throughout + """ + # Act + sql = transpile( + "SELECT COVERAGE(interval, 1000) FROM peaks", + tables=[ + Table( + "peaks", + genomic_col="interval", + chrom_col="chromosome", + start_col="start_pos", + end_col="end_pos", + ) + ], + ) + + # Assert + assert "chromosome" in sql + assert "start_pos" in sql + assert "end_pos" in sql + + # ------------------------------------------------------------------ + # Additional SELECT columns (CT-016) + # ------------------------------------------------------------------ + + def test_transform_with_additional_select_columns(self): + """Test extra SELECT columns pass through alongside COVERAGE. 
+ + Given: + A COVERAGE query with additional columns alongside COVERAGE + When: + Transpiled + Then: + It should include the extra columns in the output + """ + # Act + sql = transpile( + "SELECT COVERAGE(interval, 500) AS cov, name FROM features", + tables=["features"], + ) + + # Assert + upper = sql.upper() + assert "COV" in upper + assert "NAME" in upper + assert "COUNT" in upper + + # ------------------------------------------------------------------ + # Table alias (CT-017) + # ------------------------------------------------------------------ + + def test_transform_with_table_alias(self): + """Test table alias is used as source reference in JOIN. + + Given: + A COVERAGE query with a table alias (FROM features f) + When: + Transpiled + Then: + It should use the alias as the source reference in JOIN + """ + # Act + sql = transpile( + "SELECT COVERAGE(interval, 1000) FROM features f", + tables=["features"], + ) + + # Assert + upper = sql.upper() + assert "GENERATE_SERIES" in upper + assert "LEFT JOIN" in upper + + # ------------------------------------------------------------------ + # Resolution (CT-018) + # ------------------------------------------------------------------ + + def test_transform_with_resolution_propagation(self): + """Test resolution value propagates to generate_series and bin width. + + Given: + A COVERAGE query with resolution=500 + When: + Transpiled + Then: + It should use 500 as the step in generate_series and bin width + """ + # Act + sql = transpile( + "SELECT COVERAGE(interval, 500) FROM features", + tables=["features"], + ) + + # Assert + assert "500" in sql + + # ------------------------------------------------------------------ + # CTE nesting (CT-019) + # ------------------------------------------------------------------ + + def test_transform_with_coverage_in_cte(self): + """Test COVERAGE inside a WITH clause is transformed correctly. 
+ + Given: + A COVERAGE expression inside a WITH clause + When: + Transpiled + Then: + It should correctly transform the CTE containing COVERAGE + """ + # Act + sql = transpile( + "WITH cov AS (SELECT COVERAGE(interval, 1000) FROM features) " + "SELECT * FROM cov", + tables=["features"], + ) + + # Assert + upper = sql.upper() + assert "GENERATE_SERIES" in upper + assert "LEFT JOIN" in upper + assert "COUNT" in upper + + # ------------------------------------------------------------------ + # Error handling (CT-020, CT-021) + # ------------------------------------------------------------------ + + def test_transform_with_invalid_stat(self): + """Test invalid stat raises descriptive error. + + Given: + A COVERAGE query with an invalid stat value + When: + Transpiled + Then: + It should raise ValueError matching "Unknown COVERAGE stat" + """ + # Act & Assert + with pytest.raises(ValueError, match="Unknown COVERAGE stat"): + transpile( + "SELECT COVERAGE(interval, 1000, stat := 'median') FROM features", + tables=["features"], + ) + + def test_transform_with_multiple_coverage(self): + """Test multiple COVERAGE expressions raise error. + + Given: + A query with two COVERAGE expressions + When: + Transpiled + Then: + It should raise ValueError matching "Multiple COVERAGE" + """ + # Act & Assert + with pytest.raises(ValueError, match="Multiple COVERAGE"): + transpile( + "SELECT COVERAGE(interval, 1000), COVERAGE(interval, 500) FROM features", + tables=["features"], + ) + + # ------------------------------------------------------------------ + # Functional / DuckDB end-to-end (CT-022 to CT-026) + # ------------------------------------------------------------------ + + def test_transform_end_to_end_basic_count(self, to_df): + """Test count correctness with two intervals in one bin. 
+
+        Given:
+            A DuckDB table with two intervals in the same 1000bp bin
+        When:
+            COVERAGE count is transpiled and executed
+        Then:
+            It should return count=2 for that bin
+        """
+        # Arrange
+        giql_sql = transpile(
+            "SELECT COVERAGE(interval, 1000) FROM features",
+            tables=["features"],
+        )
+        conn = duckdb.connect(":memory:")
+        conn.execute(
+            "CREATE TABLE features AS "
+            "SELECT 'chr1' AS chrom, 100 AS start, 200 AS \"end\" "
+            "UNION ALL SELECT 'chr1', 300, 400"
+        )
+
+        # Act
+        df = to_df(conn.execute(giql_sql))
+        conn.close()
+
+        # Assert
+        row = df[df["start"] == 0].iloc[0]
+        assert row["value"] == 2
+
+    def test_transform_end_to_end_zero_coverage_bins(self, to_df):
+        """Test zero-coverage bins are present via LEFT JOIN.
+
+        Given:
+            A DuckDB table whose intervals leave the middle bin uncovered
+        When:
+            COVERAGE count is transpiled and executed
+        Then:
+            The uncovered bin should appear with count=0
+        """
+        # Arrange
+        giql_sql = transpile(
+            "SELECT COVERAGE(interval, 1000) FROM features",
+            tables=["features"],
+        )
+        conn = duckdb.connect(":memory:")
+        conn.execute(
+            "CREATE TABLE features AS "
+            "SELECT 'chr1' AS chrom, 100 AS start, 200 AS \"end\" "
+            "UNION ALL SELECT 'chr1', 2500, 2600"
+        )
+
+        # Act
+        df = to_df(conn.execute(giql_sql))
+        conn.close()
+
+        # Assert
+        assert len(df) >= 3
+        assert df[df["start"] == 0].iloc[0]["value"] == 1
+        # The LEFT JOIN keeps the empty middle bin with a zero count
+        assert df[df["start"] == 1000].iloc[0]["value"] == 0
+
+    def test_transform_end_to_end_where_preserves_zero_bins(self, to_df):
+        """Test WHERE in ON preserves bins without matching intervals.
+ + Given: + A DuckDB table with high-scoring intervals in bin [0,1000) and + bin [2000,3000), plus a low-scoring interval in bin [1000,2000) + When: + COVERAGE count with WHERE score > 50 is transpiled and executed + Then: + All three bins should be present (the WHERE is in the ON clause + so bins are not dropped even when no source rows match) + """ + # Arrange + giql_sql = transpile( + "SELECT COVERAGE(interval, 1000) FROM features WHERE score > 50", + tables=["features"], + ) + conn = duckdb.connect(":memory:") + conn.execute( + "CREATE TABLE features AS " + "SELECT 'chr1' AS chrom, 100 AS start, 200 AS \"end\", 100 AS score " + "UNION ALL SELECT 'chr1', 1500, 1600, 10 " + "UNION ALL SELECT 'chr1', 2100, 2200, 80" + ) + + # Act + df = to_df(conn.execute(giql_sql)) + conn.close() + + # Assert — all three bins are present (not filtered by WHERE) + assert len(df) == 3 + assert set(df["start"].tolist()) == {0, 1000, 2000} + + def test_transform_end_to_end_mean_with_target(self, to_df): + """Test mean stat with target column produces correct average. + + Given: + A DuckDB table with a score column and two intervals in one bin + When: + COVERAGE with stat='mean' and target='score' is transpiled + and executed + Then: + It should return the average of the score values + """ + # Arrange + giql_sql = transpile( + "SELECT COVERAGE(interval, 1000, stat := 'mean', " + "target := 'score') FROM features", + tables=["features"], + ) + conn = duckdb.connect(":memory:") + conn.execute( + "CREATE TABLE features AS " + "SELECT 'chr1' AS chrom, 100 AS start, 200 AS \"end\", " + "10.0 AS score " + "UNION ALL SELECT 'chr1', 300, 400, 20.0" + ) + + # Act + df = to_df(conn.execute(giql_sql)) + conn.close() + + # Assert + row = df[df["start"] == 0].iloc[0] + assert row["value"] == pytest.approx(15.0) + + def test_transform_end_to_end_min_stat(self, to_df): + """Test min stat returns minimum interval length. 
+ + Given: + A DuckDB table with intervals of different lengths in one bin + When: + COVERAGE with stat='min' is transpiled and executed + Then: + It should return the minimum interval length + """ + # Arrange + giql_sql = transpile( + "SELECT COVERAGE(interval, 1000, stat := 'min') FROM features", + tables=["features"], + ) + conn = duckdb.connect(":memory:") + conn.execute( + "CREATE TABLE features AS " + "SELECT 'chr1' AS chrom, 100 AS start, 200 AS \"end\" " + "UNION ALL SELECT 'chr1', 300, 600" + ) + + # Act + df = to_df(conn.execute(giql_sql)) + conn.close() + + # Assert + row = df[df["start"] == 0].iloc[0] + assert row["value"] == 100 + + # ------------------------------------------------------------------ + # Property-based transpilation (PBT-T001, PBT-T002) + # ------------------------------------------------------------------ + + @given( + resolution=st.integers(min_value=1, max_value=10_000_000), + stat=st.sampled_from(VALID_STATS), + ) + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + def test_transform_with_varying_stat_and_resolution(self, resolution, stat): + """Test stat parameter maps to correct SQL aggregate across input space. 
+ + Given: + Any valid stat (count/mean/sum/min/max) and resolution (1-10M) + When: + Transpiled via transpile() + Then: + The output SQL should contain the corresponding SQL aggregate + function name and the resolution value + """ + # Arrange + stat_to_sql = { + "count": "COUNT", + "mean": "AVG", + "sum": "SUM(", + "min": "MIN(", + "max": "MAX(", + } + expected_agg = stat_to_sql[stat] + + # Act + sql = transpile( + f"SELECT COVERAGE(interval, {resolution}, stat := '{stat}') FROM features", + tables=["features"], + ) + + # Assert + upper = sql.upper() + assert expected_agg in upper + assert str(resolution) in sql + + @given( + resolution=st.integers(min_value=1, max_value=10_000_000), + stat=st.sampled_from(VALID_STATS), + ) + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + def test_transform_structural_invariants_with_varying_stat_and_resolution( + self, resolution, stat + ): + """Test transpiled SQL always contains required structural elements. + + Given: + Any valid stat (count/mean/sum/min/max) and resolution (1-10M) + When: + Transpiled via transpile() + Then: + The output SQL should always contain __GIQL_BINS, + GENERATE_SERIES, LEFT JOIN, GROUP BY, and ORDER BY + """ + # Act + sql = transpile( + f"SELECT COVERAGE(interval, {resolution}, stat := '{stat}') FROM features", + tables=["features"], + ) + + # Assert + upper = sql.upper() + assert "__GIQL_BINS" in upper + assert "GENERATE_SERIES" in upper + assert "LEFT JOIN" in upper + assert "GROUP BY" in upper + assert "ORDER BY" in upper diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..bc36148 --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1 @@ +"""Unit tests for bedtools integration test utilities.""" diff --git a/tests/unit/test_bedtools_wrapper.py b/tests/unit/test_bedtools_wrapper.py new file mode 100644 index 0000000..f950243 --- /dev/null +++ b/tests/unit/test_bedtools_wrapper.py @@ -0,0 +1,386 @@ +"""Unit tests for pybedtools 
wrapper functions.""" + +import shutil + +import pytest + +pybedtools = pytest.importorskip("pybedtools") + +if not shutil.which("bedtools"): + pytest.skip( + "bedtools binary not found in PATH", + allow_module_level=True, + ) + +from tests.integration.bedtools.utils.bedtools_wrapper import BedtoolsError # noqa: E402 +from tests.integration.bedtools.utils.bedtools_wrapper import ( # noqa: E402 + bedtool_to_tuples, +) +from tests.integration.bedtools.utils.bedtools_wrapper import closest # noqa: E402 +from tests.integration.bedtools.utils.bedtools_wrapper import ( # noqa: E402 + create_bedtool, +) +from tests.integration.bedtools.utils.bedtools_wrapper import intersect # noqa: E402 +from tests.integration.bedtools.utils.bedtools_wrapper import merge # noqa: E402 + + +class TestCreateBedtool: + def test_bed3_format(self): + """ + GIVEN a list of BED3 tuples + WHEN create_bedtool() is called + THEN returns a BedTool with correct intervals + """ + bt = create_bedtool([("chr1", 100, 200)]) + intervals = list(bt) + assert len(intervals) == 1 + assert intervals[0].chrom == "chr1" + assert intervals[0].start == 100 + assert intervals[0].end == 200 + + def test_bed6_format(self): + """ + GIVEN a list of BED6 tuples + WHEN create_bedtool() is called + THEN returns a BedTool with all 6 fields + """ + bt = create_bedtool([("chr1", 100, 200, "a1", 50, "+")]) + intervals = list(bt) + assert len(intervals) == 1 + assert intervals[0].fields == ["chr1", "100", "200", "a1", "50", "+"] + + def test_none_values_replaced(self): + """ + GIVEN BED6 tuples with None values + WHEN create_bedtool() is called + THEN None values replaced with defaults + """ + bt = create_bedtool([("chr1", 100, 200, None, None, None)]) + fields = list(bt)[0].fields + assert fields[3] == "." # name + assert fields[4] == "0" # score + assert fields[5] == "." 
# strand + + def test_invalid_tuple_length_raises(self): + """ + GIVEN a tuple with invalid length + WHEN create_bedtool() is called + THEN ValueError is raised + """ + with pytest.raises(ValueError, match="Invalid interval format"): + create_bedtool([("chr1", 100)]) + + def test_multiple_intervals(self): + """ + GIVEN multiple intervals across chromosomes + WHEN create_bedtool() is called + THEN BedTool contains all intervals + """ + bt = create_bedtool( + [ + ("chr1", 100, 200, "a", 0, "+"), + ("chr2", 300, 400, "b", 0, "-"), + ] + ) + intervals = list(bt) + assert len(intervals) == 2 + + +class TestIntersect: + def test_basic_overlap(self): + """ + GIVEN two sets of overlapping intervals + WHEN intersect() is called + THEN returns intervals from A that overlap B + """ + a = [("chr1", 100, 200, "a1", 100, "+")] + b = [("chr1", 150, 250, "b1", 100, "+")] + result = intersect(a, b) + assert len(result) == 1 + assert result[0][0] == "chr1" + + def test_no_overlap(self): + """ + GIVEN non-overlapping intervals + WHEN intersect() is called + THEN returns empty list + """ + a = [("chr1", 100, 200, "a1", 100, "+")] + b = [("chr1", 300, 400, "b1", 100, "+")] + result = intersect(a, b) + assert result == [] + + def test_same_strand_mode(self): + """ + GIVEN intervals on same and opposite strands + WHEN intersect() is called with strand_mode="same" + THEN only same-strand overlaps returned + """ + a = [ + ("chr1", 100, 200, "a1", 0, "+"), + ("chr1", 100, 200, "a2", 0, "-"), + ] + b = [("chr1", 150, 250, "b1", 0, "+")] + result = intersect(a, b, strand_mode="same") + names = [r[3] for r in result] + assert "a1" in names + assert "a2" not in names + + def test_opposite_strand_mode(self): + """ + GIVEN intervals on same and opposite strands + WHEN intersect() is called with strand_mode="opposite" + THEN only opposite-strand overlaps returned + """ + a = [ + ("chr1", 100, 200, "a1", 0, "+"), + ("chr1", 100, 200, "a2", 0, "-"), + ] + b = [("chr1", 150, 250, "b1", 0, "+")] + 
result = intersect(a, b, strand_mode="opposite") + names = [r[3] for r in result] + assert "a2" in names + assert "a1" not in names + + def test_no_strand_mode(self): + """ + GIVEN overlapping intervals on different strands + WHEN intersect() is called with strand_mode=None + THEN all overlaps returned regardless of strand + """ + a = [("chr1", 100, 200, "a1", 0, "+")] + b = [("chr1", 150, 250, "b1", 0, "-")] + result = intersect(a, b) + assert len(result) == 1 + + +class TestMerge: + def test_overlapping(self): + """ + GIVEN overlapping intervals + WHEN merge() is called + THEN returns merged BED3 intervals + """ + intervals = [ + ("chr1", 100, 200, "i1", 0, "+"), + ("chr1", 150, 250, "i2", 0, "+"), + ] + result = merge(intervals) + assert len(result) == 1 + assert result[0] == ("chr1", 100, 250) + + def test_separated(self): + """ + GIVEN separated intervals + WHEN merge() is called + THEN each interval returned separately (BED3) + """ + intervals = [ + ("chr1", 100, 200, "i1", 0, "+"), + ("chr1", 300, 400, "i2", 0, "+"), + ] + result = merge(intervals) + assert len(result) == 2 + + def test_strand_specific(self): + """ + GIVEN overlapping intervals on different strands + WHEN merge() is called with strand_mode="same" + THEN merges per-strand separately + """ + intervals = [ + ("chr1", 100, 200, "i1", 0, "+"), + ("chr1", 150, 250, "i2", 0, "+"), + ("chr1", 120, 220, "i3", 0, "-"), + ] + result = merge(intervals, strand_mode="same") + # Should have 2: one merged + strand, one - strand + assert len(result) == 2 + + def test_adjacent(self): + """ + GIVEN adjacent intervals (end == start of next) + WHEN merge() is called + THEN adjacent intervals are merged + """ + intervals = [ + ("chr1", 100, 200, "i1", 0, "+"), + ("chr1", 200, 300, "i2", 0, "+"), + ] + result = merge(intervals) + assert len(result) == 1 + assert result[0] == ("chr1", 100, 300) + + +class TestClosest: + def test_basic(self): + """ + GIVEN non-overlapping intervals + WHEN closest() is called + THEN 
returns each A paired with nearest B plus distance + """ + a = [("chr1", 100, 200, "a1", 100, "+")] + b = [("chr1", 300, 400, "b1", 100, "+")] + result = closest(a, b) + assert len(result) == 1 + # Last field is distance + # bedtools 2.31+ may report 101 (1-based gap) vs 100 (0-based) + assert result[0][-1] in (100, 101) + + def test_cross_chromosome(self): + """ + GIVEN intervals on different chromosomes + WHEN closest() is called + THEN finds nearest per-chromosome + """ + a = [ + ("chr1", 100, 200, "a1", 0, "+"), + ("chr2", 100, 200, "a2", 0, "+"), + ] + b = [ + ("chr1", 300, 400, "b1", 0, "+"), + ("chr2", 500, 600, "b2", 0, "+"), + ] + result = closest(a, b) + assert len(result) == 2 + # Each A should match B on same chromosome + for row in result: + assert row[0] == row[6] # a.chrom == b.chrom + + def test_same_strand_mode(self): + """ + GIVEN intervals with mixed strands + WHEN closest() is called with strand_mode="same" + THEN returns nearest same-strand interval + """ + a = [("chr1", 100, 200, "a1", 0, "+")] + b = [ + ("chr1", 220, 240, "b_opp", 0, "-"), # closer but opposite + ("chr1", 300, 400, "b_same", 0, "+"), # farther but same + ] + result = closest(a, b, strand_mode="same") + assert len(result) == 1 + assert result[0][9] == "b_same" + + def test_k_greater_than_one(self): + """ + GIVEN one query and three database intervals + WHEN closest() is called with k=3 + THEN returns up to 3 nearest + """ + a = [("chr1", 200, 300, "a1", 0, "+")] + b = [ + ("chr1", 100, 150, "b1", 0, "+"), + ("chr1", 350, 400, "b2", 0, "+"), + ("chr1", 500, 600, "b3", 0, "+"), + ] + result = closest(a, b, k=3) + # bedtools returns up to k nearest; exact count may vary by version + assert len(result) >= 2 + + +class TestBedtoolToTuples: + def test_bed3_conversion(self): + """ + GIVEN a BedTool with BED3 intervals + WHEN bedtool_to_tuples() is called with bed_format="bed3" + THEN returns list of (chrom, start, end) tuples with int positions + """ + bt = 
pybedtools.BedTool("chr1\t100\t200\n", from_string=True)
+        result = bedtool_to_tuples(bt, bed_format="bed3")
+        assert result == [("chr1", 100, 200)]
+
+    def test_bed6_conversion(self):
+        """
+        GIVEN a BedTool with BED6 intervals
+        WHEN bedtool_to_tuples() is called with bed_format="bed6"
+        THEN returns list of 6-tuples with correct types
+        """
+        bt = pybedtools.BedTool("chr1\t100\t200\tgene1\t500\t+\n", from_string=True)
+        result = bedtool_to_tuples(bt, bed_format="bed6")
+        assert result == [("chr1", 100, 200, "gene1", 500, "+")]
+
+    def test_bed6_dot_to_none(self):
+        """
+        GIVEN a BedTool with "." for name and strand
+        WHEN bedtool_to_tuples() is called with bed_format="bed6"
+        THEN "." values converted to None
+        """
+        bt = pybedtools.BedTool("chr1\t100\t200\t.\t0\t.\n", from_string=True)
+        result = bedtool_to_tuples(bt, bed_format="bed6")
+        assert result[0][3] is None  # name
+        assert result[0][5] is None  # strand
+
+    def test_bed6_padding(self):
+        """
+        GIVEN a BedTool with fewer than 6 fields
+        WHEN bedtool_to_tuples() is called with bed_format="bed6"
+        THEN missing fields padded with defaults
+        """
+        bt = pybedtools.BedTool("chr1\t100\t200\n", from_string=True)
+        result = bedtool_to_tuples(bt, bed_format="bed6")
+        assert len(result) == 1
+        assert len(result[0]) == 6
+
+    def test_closest_format(self):
+        """
+        GIVEN a BedTool from closest operation (13 fields)
+        WHEN bedtool_to_tuples() is called with bed_format="closest"
+        THEN returns tuples with A fields, B fields, and distance
+        """
+        line = "chr1\t100\t200\ta1\t50\t+\tchr1\t300\t400\tb1\t75\t+\t100\n"
+        bt = pybedtools.BedTool(line, from_string=True)
+        result = bedtool_to_tuples(bt, bed_format="closest")
+        assert len(result) == 1
+        row = result[0]
+        assert row[0] == "chr1"  # a.chrom
+        assert row[1] == 100  # a.start (int)
+        assert row[6] == "chr1"  # b.chrom
+        assert row[7] == 300  # b.start (int)
+        assert row[12] == 100  # distance (int)
+
+    def test_closest_dot_values(self):
+        """
+        GIVEN a BedTool from closest with "." scores/names
+        WHEN bedtool_to_tuples() is called with bed_format="closest"
+        THEN "." values converted to None
+        """
+        line = "chr1\t100\t200\t.\t.\t.\tchr1\t300\t400\t.\t.\t.\t50\n"
+        bt = pybedtools.BedTool(line, from_string=True)
+        result = bedtool_to_tuples(bt, bed_format="closest")
+        row = result[0]
+        assert row[3] is None  # a.name
+        assert row[4] is None  # a.score
+        assert row[5] is None  # a.strand
+        assert row[9] is None  # b.name
+
+    def test_invalid_format_raises(self):
+        """
+        GIVEN any BedTool
+        WHEN bedtool_to_tuples() is called with invalid format
+        THEN ValueError is raised
+        """
+        bt = pybedtools.BedTool("chr1\t100\t200\n", from_string=True)
+        with pytest.raises(ValueError, match="Unsupported format"):
+            bedtool_to_tuples(bt, bed_format="invalid")
+
+    def test_closest_insufficient_fields_raises(self):
+        """
+        GIVEN a BedTool with fewer than 13 fields
+        WHEN bedtool_to_tuples() is called with bed_format="closest"
+        THEN ValueError is raised
+        """
+        bt = pybedtools.BedTool("chr1\t100\t200\ta1\t0\t+\n", from_string=True)
+        with pytest.raises(ValueError, match="Unexpected number of fields"):
+            bedtool_to_tuples(bt, bed_format="closest")
+
+
+class TestBedtoolsError:
+    def test_is_exception_subclass(self):
+        """
+        GIVEN a message string
+        WHEN BedtoolsError is raised
+        THEN it is an instance of Exception with correct message
+        """
+        with pytest.raises(BedtoolsError, match="test error"):
+            raise BedtoolsError("test error")
diff --git a/tests/unit/test_comparison.py b/tests/unit/test_comparison.py
new file mode 100644
index 0000000..831ccb7
--- /dev/null
+++ b/tests/unit/test_comparison.py
@@ -0,0 +1,212 @@
+"""Unit tests for result comparison logic."""
+
+from hypothesis import given
+from hypothesis import strategies as st
+
+from tests.integration.bedtools.utils.comparison import compare_results
+
+
+class TestCompareResults:
+    def test_exact_match(self):
+        """
+        GIVEN two identical lists of tuples
+        WHEN compare_results() is called
+        THEN returns match=True with no differences
+        """
+        rows = [("chr1", 100, 200), ("chr1", 300, 400)]
+        result = compare_results(rows, rows)
+        assert result.match is True
+        assert result.differences == []
+
+    def test_order_independent(self):
+        """
+        GIVEN same tuples in different order
+        WHEN compare_results() is called
+        THEN returns match=True
+        """
+        a = [("chr1", 300, 400), ("chr1", 100, 200)]
+        b = [("chr1", 100, 200), ("chr1", 300, 400)]
+        result = compare_results(a, b)
+        assert result.match is True
+
+    def test_row_count_mismatch(self):
+        """
+        GIVEN lists with different row counts
+        WHEN compare_results() is called
+        THEN returns match=False with row count difference
+        """
+        a = [("chr1", 100, 200)]
+        b = [("chr1", 100, 200), ("chr1", 300, 400)]
+        result = compare_results(a, b)
+        assert result.match is False
+        assert any("Row count" in d for d in result.differences)
+
+    def test_integer_exact_match(self):
+        """
+        GIVEN rows with identical integer values
+        WHEN compare_results() is called
+        THEN returns match=True
+        """
+        a = [("chr1", 100, 200, 50)]
+        b = [("chr1", 100, 200, 50)]
+        result = compare_results(a, b)
+        assert result.match is True
+
+    def test_float_within_epsilon(self):
+        """
+        GIVEN rows with floats differing by less than epsilon
+        WHEN compare_results() is called
+        THEN returns match=True
+        """
+        a = [(1.0000000001,)]
+        b = [(1.0,)]
+        result = compare_results(a, b)
+        assert result.match is True
+
+    def test_float_beyond_epsilon(self):
+        """
+        GIVEN rows with floats differing by more than epsilon
+        WHEN compare_results() is called
+        THEN returns match=False
+        """
+        a = [(1.5,)]
+        b = [(1.0,)]
+        result = compare_results(a, b)
+        assert result.match is False
+
+    def test_custom_epsilon(self):
+        """
+        GIVEN rows with floats differing by 0.05
+        WHEN compare_results() is called with epsilon=0.1
+        THEN returns match=True
+        """
+        a = [(1.05,)]
+        b = [(1.0,)]
+        result = compare_results(a, b, epsilon=0.1)
+        assert result.match is True
+
+    def test_none_none_match(self):
+        """
+        GIVEN rows with None in the same positions
+        WHEN compare_results() is called
+        THEN returns match=True
+        """
+        a = [("chr1", None, 200)]
+        b = [("chr1", None, 200)]
+        result = compare_results(a, b)
+        assert result.match is True
+
+    def test_none_vs_value_mismatch(self):
+        """
+        GIVEN rows where one has None and other has a value
+        WHEN compare_results() is called
+        THEN returns match=False
+        """
+        a = [("chr1", None, 200)]
+        b = [("chr1", 100, 200)]
+        result = compare_results(a, b)
+        assert result.match is False
+
+    def test_column_count_mismatch(self):
+        """
+        GIVEN rows with different column counts
+        WHEN compare_results() is called
+        THEN returns match=False with column count difference
+        """
+        a = [("chr1", 100, 200)]
+        b = [("chr1", 100)]
+        result = compare_results(a, b)
+        assert result.match is False
+        assert any("Column count" in d for d in result.differences)
+
+    def test_extra_giql_rows(self):
+        """
+        GIVEN GIQL has extra rows not in bedtools
+        WHEN compare_results() is called
+        THEN differences list the extra rows
+        """
+        a = [("chr1", 100, 200), ("chr1", 300, 400)]
+        b = [("chr1", 100, 200)]
+        result = compare_results(a, b)
+        assert result.match is False
+        assert any(
+            "missing in bedtools" in d.lower() or "Present in GIQL" in d
+            for d in result.differences
+        )
+
+    def test_extra_bedtools_rows(self):
+        """
+        GIVEN bedtools has extra rows not in GIQL
+        WHEN compare_results() is called
+        THEN differences list the missing rows
+        """
+        a = [("chr1", 100, 200)]
+        b = [("chr1", 100, 200), ("chr1", 300, 400)]
+        result = compare_results(a, b)
+        assert result.match is False
+        assert any("Missing in GIQL" in d for d in result.differences)
+
+    def test_empty_comparison(self):
+        """
+        GIVEN both lists empty
+        WHEN compare_results() is called
+        THEN returns match=True with zero row counts
+        """
+        result = compare_results([], [])
+        assert result.match is True
+        assert result.giql_row_count == 0
+        assert result.bedtools_row_count == 0
+
+    def test_metadata_populated(self):
+        """
+        GIVEN any comparison
+        WHEN compare_results() is called
+        THEN comparison_metadata contains epsilon and sorted keys
+        """
+        result = compare_results([], [])
+        assert "epsilon" in result.comparison_metadata
+        assert "sorted" in result.comparison_metadata
+
+    def test_row_counts_set(self):
+        """
+        GIVEN lists of different sizes
+        WHEN compare_results() is called
+        THEN giql_row_count and bedtools_row_count are set correctly
+        """
+        result = compare_results(
+            [("a",), ("b",)],
+            [("a",), ("b",), ("c",)],
+        )
+        assert result.giql_row_count == 2
+        assert result.bedtools_row_count == 3
+
+    def test_sorting_with_none_values(self):
+        """
+        GIVEN rows containing None values in different positions
+        WHEN compare_results() is called
+        THEN sorting handles None deterministically without errors
+        """
+        a = [("chr1", None, 200), ("chr1", 100, 200)]
+        b = [("chr1", 100, 200), ("chr1", None, 200)]
+        result = compare_results(a, b)
+        assert result.match is True
+
+    @given(
+        rows=st.lists(
+            st.tuples(
+                st.sampled_from(["chr1", "chr2"]),
+                st.integers(min_value=0, max_value=10000),
+                st.integers(min_value=0, max_value=10000),
+            ),
+            min_size=0,
+            max_size=20,
+        )
+    )
+    def test_self_comparison_always_matches(self, rows):
+        """
+        GIVEN any list of tuples
+        WHEN compare_results(rows, rows) is called
+        THEN always returns match=True
+        """
+        result = compare_results(rows, rows)
+        assert result.match is True
diff --git a/tests/unit/test_data_models.py b/tests/unit/test_data_models.py
new file mode 100644
index 0000000..8086165
--- /dev/null
+++ b/tests/unit/test_data_models.py
@@ -0,0 +1,258 @@
+"""Unit tests for bedtools integration test data models."""
+
+import pytest
+from hypothesis import given
+from hypothesis import strategies as st
+
+from tests.integration.bedtools.utils.data_models import ComparisonResult
+from tests.integration.bedtools.utils.data_models import GenomicInterval
+
+
+class TestGenomicInterval:
+    def test_basic_instantiation(self):
+        """
+        GIVEN valid chrom, start, end values
+        WHEN GenomicInterval is instantiated
+        THEN object is created with correct attributes
+        """
+        gi = GenomicInterval("chr1", 100, 200)
+        assert gi.chrom == "chr1"
+        assert gi.start == 100
+        assert gi.end == 200
+        assert gi.name is None
+        assert gi.score is None
+        assert gi.strand is None
+
+    def test_full_instantiation(self):
+        """
+        GIVEN all fields provided
+        WHEN GenomicInterval is instantiated
+        THEN all attributes are set correctly
+        """
+        gi = GenomicInterval("chrX", 500, 1000, "gene1", 800, "+")
+        assert gi.chrom == "chrX"
+        assert gi.start == 500
+        assert gi.end == 1000
+        assert gi.name == "gene1"
+        assert gi.score == 800
+        assert gi.strand == "+"
+
+    def test_start_equals_end_raises(self):
+        """
+        GIVEN start equals end
+        WHEN GenomicInterval is instantiated
+        THEN ValueError is raised
+        """
+        with pytest.raises(ValueError, match="start .* >= end"):
+            GenomicInterval("chr1", 200, 200)
+
+    def test_start_greater_than_end_raises(self):
+        """
+        GIVEN start > end
+        WHEN GenomicInterval is instantiated
+        THEN ValueError is raised
+        """
+        with pytest.raises(ValueError, match="start .* >= end"):
+            GenomicInterval("chr1", 300, 200)
+
+    def test_negative_start_raises(self):
+        """
+        GIVEN start < 0
+        WHEN GenomicInterval is instantiated
+        THEN ValueError is raised
+        """
+        with pytest.raises(ValueError, match="start .* < 0"):
+            GenomicInterval("chr1", -1, 200)
+
+    def test_invalid_strand_raises(self):
+        """
+        GIVEN an invalid strand value
+        WHEN GenomicInterval is instantiated
+        THEN ValueError is raised
+        """
+        with pytest.raises(ValueError, match="Invalid strand"):
+            GenomicInterval("chr1", 100, 200, strand="X")
+
+    def test_score_below_range_raises(self):
+        """
+        GIVEN score < 0
+        WHEN GenomicInterval is instantiated
+        THEN ValueError is raised
+        """
+        with pytest.raises(ValueError, match="Invalid score"):
+            GenomicInterval("chr1", 100, 200, score=-1)
+
+    def test_score_above_range_raises(self):
+        """
+        GIVEN score > 1000
+        WHEN GenomicInterval is instantiated
+        THEN ValueError is raised
+        """
+        with pytest.raises(ValueError, match="Invalid score"):
+            GenomicInterval("chr1", 100, 200, score=1001)
+
+    @pytest.mark.parametrize("strand", ["+", "-", "."])
+    def test_valid_strand_values(self, strand):
+        """
+        GIVEN a valid strand value
+        WHEN GenomicInterval is instantiated
+        THEN object is created successfully
+        """
+        gi = GenomicInterval("chr1", 100, 200, strand=strand)
+        assert gi.strand == strand
+
+    def test_score_boundary_zero(self):
+        """
+        GIVEN score = 0
+        WHEN GenomicInterval is instantiated
+        THEN object is created successfully
+        """
+        gi = GenomicInterval("chr1", 100, 200, score=0)
+        assert gi.score == 0
+
+    def test_score_boundary_thousand(self):
+        """
+        GIVEN score = 1000
+        WHEN GenomicInterval is instantiated
+        THEN object is created successfully
+        """
+        gi = GenomicInterval("chr1", 100, 200, score=1000)
+        assert gi.score == 1000
+
+    def test_to_tuple(self):
+        """
+        GIVEN a GenomicInterval with all fields
+        WHEN to_tuple() is called
+        THEN returns 6-element tuple with all field values
+        """
+        gi = GenomicInterval("chr1", 100, 200, "a1", 500, "+")
+        assert gi.to_tuple() == ("chr1", 100, 200, "a1", 500, "+")
+
+    def test_to_tuple_with_nones(self):
+        """
+        GIVEN a GenomicInterval with optional fields as None
+        WHEN to_tuple() is called
+        THEN tuple contains None for optional fields
+        """
+        gi = GenomicInterval("chr1", 100, 200)
+        assert gi.to_tuple() == ("chr1", 100, 200, None, None, None)
+
+    @given(
+        chrom=st.sampled_from(["chr1", "chr2", "chrX", "chrM"]),
+        start=st.integers(min_value=0, max_value=999_999),
+        size=st.integers(min_value=1, max_value=10_000),
+        strand=st.sampled_from(["+", "-", "."]),
+        score=st.integers(min_value=0, max_value=1000),
+    )
+    def test_to_tuple_roundtrip(self, chrom, start, size, strand, score):
+        """
+        GIVEN any valid GenomicInterval
+        WHEN to_tuple() is called
+        THEN the tuple can be used to reconstruct the interval's key fields
+        """
+        end = start + size
+        gi = GenomicInterval(chrom, start, end, "name", score, strand)
+        t = gi.to_tuple()
+        assert t == (chrom, start, end, "name", score, strand)
+
+
+class TestComparisonResult:
+    def test_matching_result(self):
+        """
+        GIVEN match=True with equal row counts
+        WHEN ComparisonResult is instantiated
+        THEN attributes are set correctly
+        """
+        cr = ComparisonResult(match=True, giql_row_count=5, bedtools_row_count=5)
+        assert cr.match is True
+        assert cr.giql_row_count == 5
+        assert cr.bedtools_row_count == 5
+        assert cr.differences == []
+
+    def test_mismatching_result(self):
+        """
+        GIVEN match=False with differences
+        WHEN ComparisonResult is instantiated
+        THEN attributes are set correctly
+        """
+        diffs = ["Row 0: mismatch"]
+        cr = ComparisonResult(
+            match=False,
+            giql_row_count=3,
+            bedtools_row_count=4,
+            differences=diffs,
+        )
+        assert cr.match is False
+        assert cr.differences == diffs
+
+    def test_bool_true(self):
+        """
+        GIVEN a matching ComparisonResult
+        WHEN used in boolean context
+        THEN evaluates to True
+        """
+        cr = ComparisonResult(match=True, giql_row_count=1, bedtools_row_count=1)
+        assert cr
+
+    def test_bool_false(self):
+        """
+        GIVEN a non-matching ComparisonResult
+        WHEN used in boolean context
+        THEN evaluates to False
+        """
+        cr = ComparisonResult(match=False, giql_row_count=1, bedtools_row_count=2)
+        assert not cr
+
+    def test_failure_message_match(self):
+        """
+        GIVEN a matching ComparisonResult
+        WHEN failure_message() is called
+        THEN returns success message
+        """
+        cr = ComparisonResult(match=True, giql_row_count=1, bedtools_row_count=1)
+        assert "match" in cr.failure_message().lower()
+
+    def test_failure_message_mismatch(self):
+        """
+        GIVEN a non-matching ComparisonResult with differences
+        WHEN failure_message() is called
+        THEN returns formatted message with row counts and differences
+        """
+        cr = ComparisonResult(
+            match=False,
+            giql_row_count=3,
+            bedtools_row_count=5,
+            differences=["Row 0: val mismatch", "Row 1: missing"],
+        )
+        msg = cr.failure_message()
+        assert "3" in msg
+        assert "5" in msg
+        assert "Row 0: val mismatch" in msg
+        assert "Row 1: missing" in msg
+
+    def test_failure_message_truncates_at_ten(self):
+        """
+        GIVEN a ComparisonResult with more than 10 differences
+        WHEN failure_message() is called
+        THEN only first 10 are shown with a count of remaining
+        """
+        diffs = [f"diff_{i}" for i in range(15)]
+        cr = ComparisonResult(
+            match=False,
+            giql_row_count=0,
+            bedtools_row_count=15,
+            differences=diffs,
+        )
+        msg = cr.failure_message()
+        assert "diff_9" in msg
+        assert "diff_10" not in msg
+        assert "5 more" in msg
+
+    def test_default_metadata(self):
+        """
+        GIVEN no comparison_metadata provided
+        WHEN ComparisonResult is instantiated
+        THEN metadata defaults to empty dict
+        """
+        cr = ComparisonResult(match=True, giql_row_count=0, bedtools_row_count=0)
+        assert cr.comparison_metadata == {}
diff --git a/tests/unit/test_dialect.py b/tests/unit/test_dialect.py
new file mode 100644
index 0000000..2307c4d
--- /dev/null
+++ b/tests/unit/test_dialect.py
@@ -0,0 +1,250 @@
+"""Tests for giql.dialect module."""
+
+from sqlglot import exp
+from sqlglot import parse_one
+from sqlglot.tokens import TokenType
+
+from giql.dialect import CONTAINS
+from giql.dialect import INTERSECTS
+from giql.dialect import WITHIN
+from giql.dialect import GIQLDialect
+from giql.expressions import Contains
+from giql.expressions import GIQLCluster
+from giql.expressions import GIQLCoverage
+from giql.expressions import GIQLDistance
+from giql.expressions import GIQLMerge
+from giql.expressions import GIQLNearest
+from giql.expressions import Intersects
+from giql.expressions import SpatialPredicate
+from giql.expressions import SpatialSetPredicate
+from giql.expressions import Within
+
+
+class TestDialectConstants:
+    """Tests for module-level constants and token registration."""
+
+    def test_dc_001_constant_values(self):
+        """GIVEN the module is imported
+        WHEN INTERSECTS, CONTAINS, WITHIN constants are accessed
+        THEN they equal "INTERSECTS", "CONTAINS", "WITHIN" respectively.
+        """
+        assert INTERSECTS == "INTERSECTS"
+        assert CONTAINS == "CONTAINS"
+        assert WITHIN == "WITHIN"
+
+    def test_dc_002_token_type_attributes(self):
+        """GIVEN the module is imported
+        WHEN TokenType attributes are checked
+        THEN TokenType has INTERSECTS, CONTAINS, WITHIN attributes.
+        """
+        assert hasattr(TokenType, "INTERSECTS")
+        assert hasattr(TokenType, "CONTAINS")
+        assert hasattr(TokenType, "WITHIN")
+
+
+class TestGIQLDialect:
+    """Tests for GIQLDialect parsing of spatial predicates and GIQL functions."""
+
+    def test_gd_001_intersects_predicate(self):
+        """GIVEN a query string with `column INTERSECTS 'chr1:1000-2000'`
+        WHEN the query is parsed with GIQLDialect
+        THEN the AST contains an Intersects node with correct left and right expressions.
+        """
+        ast = parse_one(
+            "SELECT * FROM t WHERE column INTERSECTS 'chr1:1000-2000'",
+            dialect=GIQLDialect,
+        )
+        nodes = list(ast.find_all(Intersects))
+        assert len(nodes) == 1
+        node = nodes[0]
+        assert node.this.name == "column"
+        assert node.expression.this == "chr1:1000-2000"
+
+    def test_gd_002_contains_predicate(self):
+        """GIVEN a query string with `column CONTAINS 'chr1:1500'`
+        WHEN the query is parsed with GIQLDialect
+        THEN the AST contains a Contains node.
+        """
+        ast = parse_one(
+            "SELECT * FROM t WHERE column CONTAINS 'chr1:1500'",
+            dialect=GIQLDialect,
+        )
+        nodes = list(ast.find_all(Contains))
+        assert len(nodes) == 1
+
+    def test_gd_003_within_predicate(self):
+        """GIVEN a query string with `column WITHIN 'chr1:1000-5000'`
+        WHEN the query is parsed with GIQLDialect
+        THEN the AST contains a Within node.
+        """
+        ast = parse_one(
+            "SELECT * FROM t WHERE column WITHIN 'chr1:1000-5000'",
+            dialect=GIQLDialect,
+        )
+        nodes = list(ast.find_all(Within))
+        assert len(nodes) == 1
+
+    def test_gd_004_intersects_any(self):
+        """GIVEN a query with `column INTERSECTS ANY('chr1:1000-2000', 'chr1:5000-6000')`
+        WHEN the query is parsed
+        THEN the AST contains a SpatialSetPredicate with quantifier=ANY.
+        """
+        ast = parse_one(
+            "SELECT * FROM t WHERE column INTERSECTS ANY('chr1:1000-2000', 'chr1:5000-6000')",
+            dialect=GIQLDialect,
+        )
+        nodes = list(ast.find_all(SpatialSetPredicate))
+        assert len(nodes) == 1
+        node = nodes[0]
+        assert node.args["quantifier"] == "ANY"
+
+    def test_gd_005_intersects_all(self):
+        """GIVEN a query with `column INTERSECTS ALL('chr1:1000-2000', 'chr1:5000-6000')`
+        WHEN the query is parsed
+        THEN the AST contains a SpatialSetPredicate with quantifier=ALL.
+        """
+        ast = parse_one(
+            "SELECT * FROM t WHERE column INTERSECTS ALL('chr1:1000-2000', 'chr1:5000-6000')",
+            dialect=GIQLDialect,
+        )
+        nodes = list(ast.find_all(SpatialSetPredicate))
+        assert len(nodes) == 1
+        node = nodes[0]
+        assert node.args["quantifier"] == "ALL"
+
+    def test_gd_006_plain_sql_fallback(self):
+        """GIVEN a query with no spatial operators (plain SQL)
+        WHEN the query is parsed with GIQLDialect
+        THEN the AST is a standard SELECT without spatial nodes.
+        """
+        ast = parse_one(
+            "SELECT id, name FROM t WHERE id = 1",
+            dialect=GIQLDialect,
+        )
+        spatial_nodes = list(ast.find_all(SpatialPredicate, SpatialSetPredicate))
+        assert len(spatial_nodes) == 0
+        assert ast.find(exp.Select) is not None
+
+    def test_gd_007_cluster_basic(self):
+        """GIVEN a query with `CLUSTER(interval)`
+        WHEN the query is parsed
+        THEN the AST contains a GIQLCluster node.
+        """
+        ast = parse_one(
+            "SELECT CLUSTER(interval) FROM t",
+            dialect=GIQLDialect,
+        )
+        nodes = list(ast.find_all(GIQLCluster))
+        assert len(nodes) == 1
+
+    def test_gd_008_cluster_with_distance(self):
+        """GIVEN a query with `CLUSTER(interval, 1000)`
+        WHEN the query is parsed
+        THEN the GIQLCluster node has distance arg set.
+        """
+        ast = parse_one(
+            "SELECT CLUSTER(interval, 1000) FROM t",
+            dialect=GIQLDialect,
+        )
+        nodes = list(ast.find_all(GIQLCluster))
+        assert len(nodes) == 1
+        node = nodes[0]
+        assert node.args.get("distance") is not None
+
+    def test_gd_009_merge_basic(self):
+        """GIVEN a query with `MERGE(interval)`
+        WHEN the query is parsed
+        THEN the AST contains a GIQLMerge node.
+        """
+        ast = parse_one(
+            "SELECT MERGE(interval) FROM t",
+            dialect=GIQLDialect,
+        )
+        nodes = list(ast.find_all(GIQLMerge))
+        assert len(nodes) == 1
+
+    def test_gd_010_coverage_with_resolution(self):
+        """GIVEN a query with `COVERAGE(interval, 1000)`
+        WHEN the query is parsed
+        THEN the AST contains a GIQLCoverage node with resolution set.
+        """
+        ast = parse_one(
+            "SELECT COVERAGE(interval, 1000) FROM t",
+            dialect=GIQLDialect,
+        )
+        nodes = list(ast.find_all(GIQLCoverage))
+        assert len(nodes) == 1
+        node = nodes[0]
+        assert node.args.get("resolution") is not None
+
+    def test_gd_011_coverage_with_stat(self):
+        """GIVEN a query with `COVERAGE(interval, 500, stat := 'mean')`
+        WHEN the query is parsed
+        THEN the GIQLCoverage node has stat arg set.
+        """
+        ast = parse_one(
+            "SELECT COVERAGE(interval, 500, stat := 'mean') FROM t",
+            dialect=GIQLDialect,
+        )
+        nodes = list(ast.find_all(GIQLCoverage))
+        assert len(nodes) == 1
+        node = nodes[0]
+        assert node.args.get("stat") is not None
+        assert node.args["stat"].this == "mean"
+
+    def test_gd_012_coverage_with_kwarg_resolution(self):
+        """GIVEN a query with `COVERAGE(interval, resolution => 1000)`
+        WHEN the query is parsed
+        THEN the GIQLCoverage node has resolution set via Kwarg.
+        """
+        ast = parse_one(
+            "SELECT COVERAGE(interval, resolution => 1000) FROM t",
+            dialect=GIQLDialect,
+        )
+        nodes = list(ast.find_all(GIQLCoverage))
+        assert len(nodes) == 1
+        node = nodes[0]
+        assert node.args.get("resolution") is not None
+
+    def test_gd_013_coverage_with_stat_and_target(self):
+        """GIVEN a query with `COVERAGE(interval, 1000, stat := 'mean', target := 'score')`
+        WHEN the query is parsed
+        THEN the GIQLCoverage node has stat and target args set.
+        """
+        ast = parse_one(
+            "SELECT COVERAGE(interval, 1000, stat := 'mean', target := 'score') FROM t",
+            dialect=GIQLDialect,
+        )
+        nodes = list(ast.find_all(GIQLCoverage))
+        assert len(nodes) == 1
+        node = nodes[0]
+        assert node.args.get("stat") is not None
+        assert node.args["stat"].this == "mean"
+        assert node.args.get("target") is not None
+        assert node.args["target"].this == "score"
+
+    def test_gd_014_distance_function(self):
+        """GIVEN a query with `DISTANCE(a.interval, b.interval)`
+        WHEN the query is parsed
+        THEN the AST contains a GIQLDistance node.
+        """
+        ast = parse_one(
+            "SELECT DISTANCE(a.interval, b.interval) FROM t",
+            dialect=GIQLDialect,
+        )
+        nodes = list(ast.find_all(GIQLDistance))
+        assert len(nodes) == 1
+
+    def test_gd_015_nearest_with_k(self):
+        """GIVEN a query with `NEAREST(genes, k := 3)`
+        WHEN the query is parsed
+        THEN the AST contains a GIQLNearest node with k arg set.
+        """
+        ast = parse_one(
+            "SELECT NEAREST(genes, k := 3) FROM t",
+            dialect=GIQLDialect,
+        )
+        nodes = list(ast.find_all(GIQLNearest))
+        assert len(nodes) == 1
+        node = nodes[0]
+        assert node.args.get("k") is not None
diff --git a/tests/unit/test_duckdb_loader.py b/tests/unit/test_duckdb_loader.py
new file mode 100644
index 0000000..b3b7a0c
--- /dev/null
+++ b/tests/unit/test_duckdb_loader.py
@@ -0,0 +1,79 @@
+"""Unit tests for DuckDB interval loading utility."""
+
+import duckdb
+import pytest
+
+from tests.integration.bedtools.utils.duckdb_loader import load_intervals
+
+
+@pytest.fixture()
+def conn():
+    c = duckdb.connect(":memory:")
+    yield c
+    c.close()
+
+
+class TestLoadIntervals:
+    def test_creates_table_with_correct_schema(self, conn):
+        """
+        GIVEN a DuckDB connection and interval tuples
+        WHEN load_intervals() is called
+        THEN table is created with columns: chrom, start, end, name, score, strand
+        """
+        load_intervals(conn, "test_table", [("chr1", 100, 200, "a1", 50, "+")])
+        cols = conn.execute(
+            "SELECT column_name FROM information_schema.columns "
+            "WHERE table_name = 'test_table' ORDER BY ordinal_position"
+        ).fetchall()
+        col_names = [c[0] for c in cols]
+        assert col_names == ["chrom", "start", "end", "name", "score", "strand"]
+
+    def test_inserts_all_rows(self, conn):
+        """
+        GIVEN multiple interval tuples
+        WHEN load_intervals() is called and table is queried
+        THEN all rows are present with correct values
+        """
+        intervals = [
+            ("chr1", 100, 200, "a1", 50, "+"),
+            ("chr2", 300, 400, "a2", 75, "-"),
+        ]
+        load_intervals(conn, "t", intervals)
+        rows = conn.execute("SELECT * FROM t ORDER BY chrom").fetchall()
+        assert len(rows) == 2
+        assert rows[0] == ("chr1", 100, 200, "a1", 50, "+")
+        assert rows[1] == ("chr2", 300, 400, "a2", 75, "-")
+
+    def test_null_handling(self, conn):
+        """
+        GIVEN tuples with None values for optional fields
+        WHEN load_intervals() is called
+        THEN NULL values stored correctly in DuckDB
+        """
+        load_intervals(conn, "t", [("chr1", 100, 200, None, None, None)])
+        row = conn.execute("SELECT * FROM t").fetchone()
+        assert row == ("chr1", 100, 200, None, None, None)
+
+    def test_multi_chromosome(self, conn):
+        """
+        GIVEN intervals across multiple chromosomes
+        WHEN load_intervals() is called
+        THEN all intervals inserted regardless of chromosome
+        """
+        intervals = [
+            ("chr1", 100, 200, "a", 0, "+"),
+            ("chr2", 100, 200, "b", 0, "+"),
+            ("chrX", 100, 200, "c", 0, "+"),
+        ]
+        load_intervals(conn, "t", intervals)
+        count = conn.execute("SELECT COUNT(*) FROM t").fetchone()[0]
+        assert count == 3
+
+    def test_empty_dataset(self, conn):
+        """
+        GIVEN an empty list of intervals
+        WHEN load_intervals() is called
+        THEN DuckDB raises an error (executemany requires non-empty list)
+        """
+        with pytest.raises(duckdb.InvalidInputException):
+            load_intervals(conn, "t", [])
diff --git a/tests/unit/test_expressions.py b/tests/unit/test_expressions.py
new file mode 100644
index 0000000..b4b8af0
--- /dev/null
+++ b/tests/unit/test_expressions.py
@@ -0,0 +1,655 @@
+"""Tests for custom AST expression nodes.
+
+Test specification: specs/test_expressions.md
+"""
+
+from sqlglot import exp
+from sqlglot import parse_one
+
+from giql.dialect import GIQLDialect
+from giql.expressions import Contains
+from giql.expressions import GenomicRange
+from giql.expressions import GIQLCluster
+from giql.expressions import GIQLCoverage
+from giql.expressions import GIQLDistance
+from giql.expressions import GIQLMerge
+from giql.expressions import GIQLNearest
+from giql.expressions import Intersects
+from giql.expressions import SpatialPredicate
+from giql.expressions import SpatialSetPredicate
+from giql.expressions import Within
+
+
+class TestGenomicRange:
+    """Tests for GenomicRange expression node."""
+
+    def test_instantiate_with_required_args(self):
+        """GR-001: Instantiate with required args.
+
+        Given:
+            All required args (chromosome, start, end)
+        When:
+            GenomicRange is instantiated
+        Then:
+            Instance has correct chromosome, start, and end args
+        """
+        chrom = exp.Literal.string("chr1")
+        start = exp.Literal.number(1000)
+        end = exp.Literal.number(2000)
+
+        gr = GenomicRange(chromosome=chrom, start=start, end=end)
+
+        assert gr.args["chromosome"] is chrom
+        assert gr.args["start"] is start
+        assert gr.args["end"] is end
+
+    def test_instantiate_with_all_args(self):
+        """GR-002: Instantiate with all args including optional strand and coord_system.
+
+        Given:
+            Required args plus optional strand and coord_system
+        When:
+            GenomicRange is instantiated
+        Then:
+            Instance has all five args accessible
+        """
+        chrom = exp.Literal.string("chr1")
+        start = exp.Literal.number(1000)
+        end = exp.Literal.number(2000)
+        strand = exp.Literal.string("+")
+        coord_system = exp.Literal.string("0-based")
+
+        gr = GenomicRange(
+            chromosome=chrom,
+            start=start,
+            end=end,
+            strand=strand,
+            coord_system=coord_system,
+        )
+
+        assert gr.args["chromosome"] is chrom
+        assert gr.args["start"] is start
+        assert gr.args["end"] is end
+        assert gr.args["strand"] is strand
+        assert gr.args["coord_system"] is coord_system
+
+    def test_optional_args_default_to_none(self):
+        """GR-003: Optional args default to None.
+
+        Given:
+            Only required args provided
+        When:
+            GenomicRange is instantiated
+        Then:
+            strand and coord_system args are None
+        """
+        gr = GenomicRange(
+            chromosome=exp.Literal.string("chr1"),
+            start=exp.Literal.number(1000),
+            end=exp.Literal.number(2000),
+        )
+
+        assert gr.args.get("strand") is None
+        assert gr.args.get("coord_system") is None
+
+
+class TestSpatialPredicate:
+    """Tests for SpatialPredicate subclasses."""
+
+    def test_intersects_is_spatial_predicate_and_binary(self):
+        """SP-001: Intersects inheritance.
+
+        Given:
+            Two expression nodes (this, expression)
+        When:
+            Intersects is instantiated
+        Then:
+            Instance is a SpatialPredicate and exp.Binary
+        """
+        left = exp.Column(this=exp.Identifier(this="a"))
+        right = exp.Column(this=exp.Identifier(this="b"))
+
+        node = Intersects(this=left, expression=right)
+
+        assert isinstance(node, SpatialPredicate)
+        assert isinstance(node, exp.Binary)
+
+    def test_contains_is_spatial_predicate_and_binary(self):
+        """SP-002: Contains inheritance.
+
+        Given:
+            Two expression nodes
+        When:
+            Contains is instantiated
+        Then:
+            Instance is a SpatialPredicate and exp.Binary
+        """
+        left = exp.Column(this=exp.Identifier(this="a"))
+        right = exp.Column(this=exp.Identifier(this="b"))
+
+        node = Contains(this=left, expression=right)
+
+        assert isinstance(node, SpatialPredicate)
+        assert isinstance(node, exp.Binary)
+
+    def test_within_is_spatial_predicate_and_binary(self):
+        """SP-003: Within inheritance.
+
+        Given:
+            Two expression nodes
+        When:
+            Within is instantiated
+        Then:
+            Instance is a SpatialPredicate and exp.Binary
+        """
+        left = exp.Column(this=exp.Identifier(this="a"))
+        right = exp.Column(this=exp.Identifier(this="b"))
+
+        node = Within(this=left, expression=right)
+
+        assert isinstance(node, SpatialPredicate)
+        assert isinstance(node, exp.Binary)
+
+
+class TestSpatialSetPredicate:
+    """Tests for SpatialSetPredicate expression node."""
+
+    def test_instantiate_with_all_required_args(self):
+        """SSP-001: Instantiate with all required args.
+
+        Given:
+            All required args (this, operator, quantifier, ranges)
+        When:
+            SpatialSetPredicate is instantiated
+        Then:
+            Instance has all four args accessible
+        """
+        this = exp.Column(this=exp.Identifier(this="interval"))
+        operator = exp.Literal.string("INTERSECTS")
+        quantifier = exp.Literal.string("ANY")
+        ranges = exp.Array(
+            expressions=[
+                exp.Literal.string("chr1:1000-2000"),
+                exp.Literal.string("chr1:5000-6000"),
+            ]
+        )
+
+        node = SpatialSetPredicate(
+            this=this,
+            operator=operator,
+            quantifier=quantifier,
+            ranges=ranges,
+        )
+
+        assert node.args["this"] is this
+        assert node.args["operator"] is operator
+        assert node.args["quantifier"] is quantifier
+        assert node.args["ranges"] is ranges
+
+
+class TestGIQLCluster:
+    """Tests for GIQLCluster expression node parsing."""
+
+    def test_parse_cluster_with_one_arg(self):
+        """CL-001: Parse CLUSTER with one positional arg.
+
+        Given:
+            A CLUSTER expression with one positional arg (column)
+        When:
+            Parsed with GIQLDialect
+        Then:
+            GIQLCluster instance has `this` set
+        """
+        ast = parse_one(
+            "SELECT CLUSTER(interval) FROM features",
+            dialect=GIQLDialect,
+        )
+
+        nodes = list(ast.find_all(GIQLCluster))
+        assert len(nodes) == 1
+        assert nodes[0].args["this"] is not None
+
+    def test_parse_cluster_with_distance(self):
+        """CL-002: Parse CLUSTER with distance.
+
+        Given:
+            A CLUSTER expression with two positional args (column, distance)
+        When:
+            Parsed with GIQLDialect
+        Then:
+            GIQLCluster instance has `this` and `distance` set
+        """
+        ast = parse_one(
+            "SELECT CLUSTER(interval, 1000) FROM features",
+            dialect=GIQLDialect,
+        )
+
+        nodes = list(ast.find_all(GIQLCluster))
+        assert len(nodes) == 1
+        assert nodes[0].args["this"] is not None
+        assert nodes[0].args["distance"].this == "1000"
+
+    def test_parse_cluster_with_stranded(self):
+        """CL-003: Parse CLUSTER with stranded parameter.
+
+        Given:
+            A CLUSTER expression with one positional and stranded := true
+        When:
+            Parsed with GIQLDialect
+        Then:
+            GIQLCluster instance has `this` and `stranded` set
+        """
+        ast = parse_one(
+            "SELECT CLUSTER(interval, stranded := true) FROM features",
+            dialect=GIQLDialect,
+        )
+
+        nodes = list(ast.find_all(GIQLCluster))
+        assert len(nodes) == 1
+        assert nodes[0].args["this"] is not None
+        assert nodes[0].args["stranded"] is not None
+
+    def test_parse_cluster_with_distance_and_stranded(self):
+        """CL-004: Parse CLUSTER with distance and stranded.
+
+        Given:
+            A CLUSTER expression with two positionals and stranded := true
+        When:
+            Parsed with GIQLDialect
+        Then:
+            GIQLCluster instance has `this`, `distance`, and `stranded` set
+        """
+        ast = parse_one(
+            "SELECT CLUSTER(interval, 1000, stranded := true) FROM features",
+            dialect=GIQLDialect,
+        )
+
+        nodes = list(ast.find_all(GIQLCluster))
+        assert len(nodes) == 1
+        assert nodes[0].args["this"] is not None
+        assert nodes[0].args["distance"].this == "1000"
+        assert nodes[0].args["stranded"] is not None
+
+    def test_direct_instantiation_minimal(self):
+        """CL-005: Direct instantiation with just `this`.
+
+        Given:
+            Required arg `this` only
+        When:
+            GIQLCluster is instantiated directly
+        Then:
+            Instance has `this` set; `distance` and `stranded` are absent
+        """
+        col = exp.Column(this=exp.Identifier(this="interval"))
+
+        node = GIQLCluster(this=col)
+
+        assert node.args["this"] is col
+        assert node.args.get("distance") is None
+        assert node.args.get("stranded") is None
+
+
+class TestGIQLMerge:
+    """Tests for GIQLMerge expression node parsing."""
+
+    def test_parse_merge_with_one_arg(self):
+        """MG-001: Parse MERGE with one positional arg.
+
+        Given:
+            A MERGE expression with one positional arg (column)
+        When:
+            Parsed with GIQLDialect
+        Then:
+            GIQLMerge instance has `this` set
+        """
+        ast = parse_one(
+            "SELECT MERGE(interval) FROM features",
+            dialect=GIQLDialect,
+        )
+
+        nodes = list(ast.find_all(GIQLMerge))
+        assert len(nodes) == 1
+        assert nodes[0].args["this"] is not None
+
+    def test_parse_merge_with_distance(self):
+        """MG-002: Parse MERGE with distance.
+
+        Given:
+            A MERGE expression with two positional args (column, distance)
+        When:
+            Parsed with GIQLDialect
+        Then:
+            GIQLMerge instance has `this` and `distance` set
+        """
+        ast = parse_one(
+            "SELECT MERGE(interval, 1000) FROM features",
+            dialect=GIQLDialect,
+        )
+
+        nodes = list(ast.find_all(GIQLMerge))
+        assert len(nodes) == 1
+        assert nodes[0].args["this"] is not None
+        assert nodes[0].args["distance"].this == "1000"
+
+    def test_parse_merge_with_stranded(self):
+        """MG-003: Parse MERGE with stranded parameter.
+
+        Given:
+            A MERGE expression with one positional and stranded := true
+        When:
+            Parsed with GIQLDialect
+        Then:
+            GIQLMerge instance has `this` and `stranded` set
+        """
+        ast = parse_one(
+            "SELECT MERGE(interval, stranded := true) FROM features",
+            dialect=GIQLDialect,
+        )
+
+        nodes = list(ast.find_all(GIQLMerge))
+        assert len(nodes) == 1
+        assert nodes[0].args["this"] is not None
+        assert nodes[0].args["stranded"] is not None
+
+    def test_parse_merge_with_distance_and_stranded(self):
+        """MG-004: Parse MERGE with distance and stranded.
+
+        Given:
+            A MERGE expression with two positionals and stranded := true
+        When:
+            Parsed with GIQLDialect
+        Then:
+            GIQLMerge instance has `this`, `distance`, and `stranded` set
+        """
+        ast = parse_one(
+            "SELECT MERGE(interval, 1000, stranded := true) FROM features",
+            dialect=GIQLDialect,
+        )
+
+        nodes = list(ast.find_all(GIQLMerge))
+        assert len(nodes) == 1
+        assert nodes[0].args["this"] is not None
+        assert nodes[0].args["distance"].this == "1000"
+        assert nodes[0].args["stranded"] is not None
+
+
+class TestGIQLCoverage:
+    """Tests for GIQLCoverage expression node parsing."""
+
+    def test_parse_coverage_with_positional_args(self):
+        """COV-001: Parse COVERAGE with positional args.
+
+        Given:
+            A COVERAGE expression with two positional args (column, resolution)
+        When:
+            Parsed with GIQLDialect
+        Then:
+            GIQLCoverage instance has `this` and `resolution` set
+        """
+        ast = parse_one(
+            "SELECT COVERAGE(interval, 1000) FROM features",
+            dialect=GIQLDialect,
+        )
+
+        nodes = list(ast.find_all(GIQLCoverage))
+        assert len(nodes) == 1
+        assert nodes[0].args["this"] is not None
+        assert nodes[0].args["resolution"].this == "1000"
+        assert nodes[0].args.get("stat") is None
+        assert nodes[0].args.get("target") is None
+
+    def test_parse_coverage_with_walrus_named_resolution(self):
+        """COV-002: Parse COVERAGE with := named resolution.
+
+        Given:
+            A COVERAGE expression with one positional and resolution := 1000
+        When:
+            Parsed with GIQLDialect
+        Then:
+            GIQLCoverage instance has `this` and `resolution` set
+        """
+        ast = parse_one(
+            "SELECT COVERAGE(interval, resolution := 1000) FROM features",
+            dialect=GIQLDialect,
+        )
+
+        nodes = list(ast.find_all(GIQLCoverage))
+        assert len(nodes) == 1
+        assert nodes[0].args["this"] is not None
+        assert nodes[0].args["resolution"].this == "1000"
+
+    def test_parse_coverage_with_stat(self):
+        """COV-003: Parse COVERAGE with stat parameter.
+
+        Given:
+            A COVERAGE expression with two positionals and stat := 'mean'
+        When:
+            Parsed with GIQLDialect
+        Then:
+            GIQLCoverage instance has `this`, `resolution`, and `stat` set
+        """
+        ast = parse_one(
+            "SELECT COVERAGE(interval, 500, stat := 'mean') FROM features",
+            dialect=GIQLDialect,
+        )
+
+        nodes = list(ast.find_all(GIQLCoverage))
+        assert len(nodes) == 1
+        assert nodes[0].args["resolution"].this == "500"
+        assert nodes[0].args["stat"].this == "mean"
+
+    def test_parse_coverage_with_stat_and_target(self):
+        """COV-004: Parse COVERAGE with stat and target.
+
+        Given:
+            A COVERAGE expression with two positionals, stat := 'mean', and target := 'score'
+        When:
+            Parsed with GIQLDialect
+        Then:
+            GIQLCoverage instance has `this`, `resolution`, `stat`, and `target` set
+        """
+        ast = parse_one(
+            "SELECT COVERAGE(interval, 1000, stat := 'mean', target := 'score') FROM features",
+            dialect=GIQLDialect,
+        )
+
+        nodes = list(ast.find_all(GIQLCoverage))
+        assert len(nodes) == 1
+        assert nodes[0].args["resolution"].this == "1000"
+        assert nodes[0].args["stat"].this == "mean"
+        assert nodes[0].args["target"].this == "score"
+
+    def test_parse_coverage_with_arrow_named_resolution(self):
+        """COV-005: Parse COVERAGE with => named resolution.
+
+        Given:
+            A COVERAGE expression with one positional and resolution => 1000
+        When:
+            Parsed with GIQLDialect
+        Then:
+            GIQLCoverage instance has `this` and `resolution` set
+        """
+        ast = parse_one(
+            "SELECT COVERAGE(interval, resolution => 1000) FROM features",
+            dialect=GIQLDialect,
+        )
+
+        nodes = list(ast.find_all(GIQLCoverage))
+        assert len(nodes) == 1
+        assert nodes[0].args["this"] is not None
+        assert nodes[0].args["resolution"].this == "1000"
+
+    def test_parse_coverage_with_target_no_stat(self):
+        """COV-006: Parse COVERAGE with target but no stat.
+ + Given: + A COVERAGE expression with two positionals and target := 'score' only + When: + Parsed with GIQLDialect + Then: + GIQLCoverage instance has `this`, `resolution`, and `target` set; `stat` is absent + """ + ast = parse_one( + "SELECT COVERAGE(interval, 1000, target := 'score') FROM features", + dialect=GIQLDialect, + ) + + nodes = list(ast.find_all(GIQLCoverage)) + assert len(nodes) == 1 + assert nodes[0].args["resolution"].this == "1000" + assert nodes[0].args["target"].this == "score" + assert nodes[0].args.get("stat") is None + + def test_direct_instantiation_minimal(self): + """COV-007: Direct instantiation with required args only. + + Given: + Required args `this` and `resolution` only + When: + GIQLCoverage is instantiated directly + Then: + Instance has `this` and `resolution` set; `stat` and `target` are absent + """ + col = exp.Column(this=exp.Identifier(this="interval")) + resolution = exp.Literal.number(1000) + + node = GIQLCoverage(this=col, resolution=resolution) + + assert node.args["this"] is col + assert node.args["resolution"] is resolution + assert node.args.get("stat") is None + assert node.args.get("target") is None + + +class TestGIQLDistance: + """Tests for GIQLDistance expression node parsing.""" + + def test_parse_distance_with_two_positional_args(self): + """DI-001: Parse DISTANCE with two positional args. + + Given: + A DISTANCE expression with two positional args (interval_a, interval_b) + When: + Parsed with GIQLDialect + Then: + GIQLDistance instance has `this` and `expression` set + """ + ast = parse_one( + "SELECT DISTANCE(a.interval, b.interval) FROM a, b", + dialect=GIQLDialect, + ) + + nodes = list(ast.find_all(GIQLDistance)) + assert len(nodes) == 1 + assert nodes[0].args["this"] is not None + assert nodes[0].args["expression"] is not None + + def test_parse_distance_with_stranded_and_signed(self): + """DI-002: Parse DISTANCE with stranded and signed. 
+ + Given: + A DISTANCE expression with two positionals and stranded := true, signed := true + When: + Parsed with GIQLDialect + Then: + GIQLDistance instance has `this`, `expression`, `stranded`, and `signed` set + """ + ast = parse_one( + "SELECT DISTANCE(a.interval, b.interval, stranded := true, signed := true) FROM a, b", + dialect=GIQLDialect, + ) + + nodes = list(ast.find_all(GIQLDistance)) + assert len(nodes) == 1 + assert nodes[0].args["this"] is not None + assert nodes[0].args["expression"] is not None + assert nodes[0].args["stranded"] is not None + assert nodes[0].args["signed"] is not None + + def test_parse_distance_with_stranded_only(self): + """DI-003: Parse DISTANCE with only stranded. + + Given: + A DISTANCE expression with two positionals and only stranded := true + When: + Parsed with GIQLDialect + Then: + GIQLDistance instance has `this`, `expression`, and `stranded` set; `signed` absent + """ + ast = parse_one( + "SELECT DISTANCE(a.interval, b.interval, stranded := true) FROM a, b", + dialect=GIQLDialect, + ) + + nodes = list(ast.find_all(GIQLDistance)) + assert len(nodes) == 1 + assert nodes[0].args["this"] is not None + assert nodes[0].args["expression"] is not None + assert nodes[0].args["stranded"] is not None + assert nodes[0].args.get("signed") is None + + +class TestGIQLNearest: + """Tests for GIQLNearest expression node parsing.""" + + def test_parse_nearest_with_one_positional(self): + """NR-001: Parse NEAREST with one positional arg. + + Given: + A NEAREST expression with one positional arg (table) + When: + Parsed with GIQLDialect + Then: + GIQLNearest instance has `this` set + """ + ast = parse_one( + "SELECT NEAREST(genes) FROM peaks", + dialect=GIQLDialect, + ) + + nodes = list(ast.find_all(GIQLNearest)) + assert len(nodes) == 1 + assert nodes[0].args["this"] is not None + + def test_parse_nearest_with_k(self): + """NR-002: Parse NEAREST with k parameter. 
+ + Given: + A NEAREST expression with one positional and k := 3 + When: + Parsed with GIQLDialect + Then: + GIQLNearest instance has `this` and `k` set + """ + ast = parse_one( + "SELECT NEAREST(genes, k := 3) FROM peaks", + dialect=GIQLDialect, + ) + + nodes = list(ast.find_all(GIQLNearest)) + assert len(nodes) == 1 + assert nodes[0].args["this"] is not None + assert nodes[0].args["k"].this == "3" + + def test_parse_nearest_with_multiple_named_params(self): + """NR-003: Parse NEAREST with multiple named params. + + Given: + A NEAREST expression with one positional and multiple named params + When: + Parsed with GIQLDialect + Then: + GIQLNearest instance has all provided args set + """ + ast = parse_one( + "SELECT NEAREST(genes, k := 5, max_distance := 100000, stranded := true, signed := true) FROM peaks", + dialect=GIQLDialect, + ) + + nodes = list(ast.find_all(GIQLNearest)) + assert len(nodes) == 1 + assert nodes[0].args["this"] is not None + assert nodes[0].args["k"].this == "5" + assert nodes[0].args["max_distance"].this == "100000" + assert nodes[0].args["stranded"] is not None + assert nodes[0].args["signed"] is not None diff --git a/tests/unit/test_generators_base.py b/tests/unit/test_generators_base.py new file mode 100644 index 0000000..e31f907 --- /dev/null +++ b/tests/unit/test_generators_base.py @@ -0,0 +1,460 @@ +"""Tests for BaseGIQLGenerator. 
+ +Test specification: specs/test_generators_base.md +Test IDs: BG-001 through BG-020 +""" + +import pytest +from sqlglot import parse_one + +from giql.dialect import GIQLDialect +from giql.generators import BaseGIQLGenerator +from giql.table import Table +from giql.table import Tables + + +@pytest.fixture +def tables_two(): + """Tables with two tables for column-to-column tests.""" + tables = Tables() + tables.register("features_a", Table("features_a")) + tables.register("features_b", Table("features_b")) + return tables + + +@pytest.fixture +def tables_peaks_and_genes(): + """Tables with peaks and genes for NEAREST/DISTANCE tests.""" + tables = Tables() + tables.register("peaks", Table("peaks")) + tables.register("genes", Table("genes")) + return tables + + +def _normalize(sql: str) -> str: + """Collapse whitespace for easier assertion.""" + return " ".join(sql.split()) + + +class TestBaseGIQLGenerator: + """Tests for BaseGIQLGenerator class (BG-001 to BG-020).""" + + # ------------------------------------------------------------------ + # Instantiation + # ------------------------------------------------------------------ + + def test_bg_001_no_args_defaults(self): + """ + GIVEN no arguments + WHEN BaseGIQLGenerator is instantiated + THEN instance has empty Tables and SUPPORTS_LATERAL is True. + """ + generator = BaseGIQLGenerator() + + assert generator.tables is not None + assert generator.SUPPORTS_LATERAL is True + # Empty tables: looking up any name returns None + assert generator.tables.get("anything") is None + + def test_bg_002_with_tables(self): + """ + GIVEN a Tables instance with a registered table + WHEN BaseGIQLGenerator is instantiated with tables= + THEN the instance uses the provided tables for column resolution. 
+ """ + tables = Tables() + tables.register("peaks", Table("peaks")) + generator = BaseGIQLGenerator(tables=tables) + + assert generator.tables is tables + assert "peaks" in generator.tables + + # ------------------------------------------------------------------ + # Spatial predicates + # ------------------------------------------------------------------ + + def test_bg_003_intersects_literal(self): + """ + GIVEN an Intersects AST node with a literal range 'chr1:1000-2000' + WHEN generate is called + THEN output contains chrom = 'chr1' AND start < 2000 AND end > 1000. + """ + tables = Tables() + tables.register("peaks", Table("peaks")) + generator = BaseGIQLGenerator(tables=tables) + + ast = parse_one( + "SELECT * FROM peaks WHERE interval INTERSECTS 'chr1:1000-2000'", + dialect=GIQLDialect, + ) + sql = generator.generate(ast) + + assert "\"chrom\" = 'chr1'" in sql + assert '"start" < 2000' in sql + assert '"end" > 1000' in sql + + def test_bg_004_intersects_column_to_column(self, tables_two): + """ + GIVEN an Intersects AST node with column-to-column (a.interval INTERSECTS b.interval) + WHEN generate is called + THEN output contains chrom equality and overlap conditions using both table prefixes. + """ + generator = BaseGIQLGenerator(tables=tables_two) + + ast = parse_one( + "SELECT * FROM features_a AS a CROSS JOIN features_b AS b " + "WHERE a.interval INTERSECTS b.interval", + dialect=GIQLDialect, + ) + sql = generator.generate(ast) + + assert 'a."chrom" = b."chrom"' in sql + assert 'a."start" < b."end"' in sql + assert 'a."end" > b."start"' in sql + + def test_bg_005_contains_point(self): + """ + GIVEN a Contains AST node with a point range 'chr1:1500' + WHEN generate is called + THEN output contains point containment predicate. 
+ """ + generator = BaseGIQLGenerator() + + ast = parse_one( + "SELECT * FROM peaks WHERE interval CONTAINS 'chr1:1500'", + dialect=GIQLDialect, + ) + sql = generator.generate(ast) + + assert "\"chrom\" = 'chr1'" in sql + assert '"start" <= 1500' in sql + assert '"end" > 1500' in sql + + def test_bg_006_contains_range(self): + """ + GIVEN a Contains AST node with a range 'chr1:1000-2000' + WHEN generate is called + THEN output contains range containment predicate. + """ + generator = BaseGIQLGenerator() + + ast = parse_one( + "SELECT * FROM peaks WHERE interval CONTAINS 'chr1:1000-2000'", + dialect=GIQLDialect, + ) + sql = generator.generate(ast) + + assert "\"chrom\" = 'chr1'" in sql + assert '"start" <= 1000' in sql + assert '"end" >= 2000' in sql + + def test_bg_007_within_range(self): + """ + GIVEN a Within AST node with a range 'chr1:1000-5000' + WHEN generate is called + THEN output contains within predicate. + """ + generator = BaseGIQLGenerator() + + ast = parse_one( + "SELECT * FROM peaks WHERE interval WITHIN 'chr1:1000-5000'", + dialect=GIQLDialect, + ) + sql = generator.generate(ast) + + assert "\"chrom\" = 'chr1'" in sql + assert '"start" >= 1000' in sql + assert '"end" <= 5000' in sql + + # ------------------------------------------------------------------ + # Spatial set predicates + # ------------------------------------------------------------------ + + def test_bg_008_intersects_any(self): + """ + GIVEN a SpatialSetPredicate with INTERSECTS ANY and two ranges + WHEN generate is called + THEN output contains two conditions joined by OR. 
+ """ + generator = BaseGIQLGenerator() + + ast = parse_one( + "SELECT * FROM peaks " + "WHERE interval INTERSECTS ANY('chr1:1000-2000', 'chr1:5000-6000')", + dialect=GIQLDialect, + ) + sql = generator.generate(ast) + + assert " OR " in sql + assert '"end" > 1000' in sql + assert '"end" > 5000' in sql + + def test_bg_009_intersects_all(self): + """ + GIVEN a SpatialSetPredicate with INTERSECTS ALL and two ranges + WHEN generate is called + THEN output contains two conditions joined by AND. + """ + generator = BaseGIQLGenerator() + + ast = parse_one( + "SELECT * FROM peaks " + "WHERE interval INTERSECTS ALL('chr1:1000-2000', 'chr1:1500-1800')", + dialect=GIQLDialect, + ) + sql = generator.generate(ast) + + # The outer WHERE already has AND, but the set predicate wraps + # its conditions in parens joined by AND. + norm = _normalize(sql) + # Both range predicates should appear + assert '"start" < 2000' in sql + assert '"start" < 1800' in sql + # They are joined by AND (inside the set predicate parentheses) + # Check the pattern: one condition AND another condition + idx_first = norm.index('"start" < 2000') + idx_second = norm.index('"start" < 1800') + between = norm[idx_first:idx_second] + assert "AND" in between + + # ------------------------------------------------------------------ + # DISTANCE + # ------------------------------------------------------------------ + + def test_bg_010_distance_basic(self, tables_two): + """ + GIVEN a GIQLDistance node with two column references + WHEN generate is called + THEN output contains CASE WHEN with chromosome check, overlap check, and distance calculations. 
+ """ + generator = BaseGIQLGenerator(tables=tables_two) + + ast = parse_one( + "SELECT DISTANCE(a.interval, b.interval) AS dist " + "FROM features_a a CROSS JOIN features_b b", + dialect=GIQLDialect, + ) + sql = generator.generate(ast) + + assert 'a."chrom" != b."chrom" THEN NULL' in sql + assert "THEN 0" in sql + assert 'b."start" - a."end"' in sql + assert 'a."start" - b."end"' in sql + assert sql.startswith("SELECT CASE WHEN") + + def test_bg_011_distance_stranded(self, tables_two): + """ + GIVEN a GIQLDistance node with stranded := true + WHEN generate is called + THEN output contains strand NULL checks and strand flip logic. + """ + generator = BaseGIQLGenerator(tables=tables_two) + + ast = parse_one( + "SELECT DISTANCE(a.interval, b.interval, stranded := true) AS dist " + "FROM features_a a CROSS JOIN features_b b", + dialect=GIQLDialect, + ) + sql = generator.generate(ast) + + assert 'a."strand" IS NULL' in sql + assert 'b."strand" IS NULL' in sql + assert "a.\"strand\" = '.'" in sql + assert "a.\"strand\" = '?'" in sql + assert "a.\"strand\" = '-'" in sql + + def test_bg_012_distance_signed(self, tables_two): + """ + GIVEN a GIQLDistance node with signed := true + WHEN generate is called + THEN output contains signed distance (negative for upstream). + """ + generator = BaseGIQLGenerator(tables=tables_two) + + ast = parse_one( + "SELECT DISTANCE(a.interval, b.interval, signed := true) AS dist " + "FROM features_a a CROSS JOIN features_b b", + dialect=GIQLDialect, + ) + sql = generator.generate(ast) + + # Signed: ELSE branch has negative sign + assert "-(" in sql + # Unsigned ELSE would be (a."start" - b."end") without negation + # Signed ELSE is -(a."start" - b."end") + assert '-(a."start" - b."end")' in sql + + def test_bg_013_distance_stranded_and_signed(self, tables_two): + """ + GIVEN a GIQLDistance node with stranded := true and signed := true + WHEN generate is called + THEN output contains both strand flip and signed distance. 
+ """ + generator = BaseGIQLGenerator(tables=tables_two) + + ast = parse_one( + "SELECT DISTANCE(a.interval, b.interval, stranded := true, signed := true) AS dist " + "FROM features_a a CROSS JOIN features_b b", + dialect=GIQLDialect, + ) + sql = generator.generate(ast) + + # Should have strand NULL checks + assert 'a."strand" IS NULL' in sql + # Should have strand flip + assert "a.\"strand\" = '-'" in sql + # Stranded+signed: the ELSE for '-' strand flips sign differently + # from stranded-only + # In stranded+signed: ELSE WHEN strand='-' THEN (a.start - b.end) + # In stranded-only: ELSE WHEN strand='-' THEN -(a.start - b.end) + assert '(a."start" - b."end")' in sql + assert '-(a."start" - b."end")' in sql + + def test_bg_014_distance_closed_intervals(self): + """ + GIVEN tables with interval_type="closed" for one table + WHEN generate is called for a DISTANCE expression + THEN output contains '+ 1' gap adjustment. + """ + tables = Tables() + tables.register("bed_a", Table("bed_a", interval_type="closed")) + tables.register("bed_b", Table("bed_b", interval_type="closed")) + generator = BaseGIQLGenerator(tables=tables) + + ast = parse_one( + "SELECT DISTANCE(a.interval, b.interval) AS dist " + "FROM bed_a a CROSS JOIN bed_b b", + dialect=GIQLDialect, + ) + sql = generator.generate(ast) + + assert "+ 1)" in sql + + # ------------------------------------------------------------------ + # NEAREST + # ------------------------------------------------------------------ + + def test_bg_015_nearest_standalone(self, tables_peaks_and_genes): + """ + GIVEN a GIQLNearest node with explicit reference (standalone mode) + WHEN generate is called + THEN output is a subquery with WHERE, ORDER BY ABS(distance), LIMIT. 
+ """ + generator = BaseGIQLGenerator(tables=tables_peaks_and_genes) + + ast = parse_one( + "SELECT * FROM NEAREST(genes, reference := 'chr1:1000-2000')", + dialect=GIQLDialect, + ) + sql = generator.generate(ast) + norm = _normalize(sql) + + assert "WHERE" in norm + assert "ORDER BY ABS(" in norm + assert "LIMIT 1" in norm + assert "'chr1' = genes.\"chrom\"" in sql + assert "AS distance" in sql + + def test_bg_016_nearest_k5(self, tables_peaks_and_genes): + """ + GIVEN a GIQLNearest node with k := 5 + WHEN generate is called + THEN output has LIMIT 5. + """ + generator = BaseGIQLGenerator(tables=tables_peaks_and_genes) + + ast = parse_one( + "SELECT * FROM NEAREST(genes, reference := 'chr1:1000-2000', k := 5)", + dialect=GIQLDialect, + ) + sql = generator.generate(ast) + + assert "LIMIT 5" in sql + + def test_bg_017_nearest_max_distance(self, tables_peaks_and_genes): + """ + GIVEN a GIQLNearest node with max_distance := 100000 + WHEN generate is called + THEN the distance threshold appears in the WHERE clause. + """ + generator = BaseGIQLGenerator(tables=tables_peaks_and_genes) + + ast = parse_one( + "SELECT * FROM NEAREST(genes, reference := 'chr1:1000-2000', max_distance := 100000)", + dialect=GIQLDialect, + ) + sql = generator.generate(ast) + norm = _normalize(sql) + + assert "100000" in norm + assert "<= 100000" in norm + + def test_bg_018_nearest_correlated_lateral(self, tables_peaks_and_genes): + """ + GIVEN a GIQLNearest node in correlated mode (no standalone reference, in LATERAL context) + WHEN generate is called + THEN output is a LATERAL-compatible subquery referencing the outer table columns. 
+ """ + generator = BaseGIQLGenerator(tables=tables_peaks_and_genes) + + ast = parse_one( + "SELECT * FROM peaks " + "CROSS JOIN LATERAL NEAREST(genes, reference := peaks.interval, k := 3)", + dialect=GIQLDialect, + ) + sql = generator.generate(ast) + norm = _normalize(sql) + + assert "LATERAL" in norm + assert 'peaks."chrom"' in sql + assert 'genes."chrom"' in sql + assert "LIMIT 3" in sql + + def test_bg_019_nearest_stranded(self, tables_peaks_and_genes): + """ + GIVEN a GIQLNearest node with stranded := true + WHEN generate is called + THEN output includes strand matching in WHERE clause. + """ + generator = BaseGIQLGenerator(tables=tables_peaks_and_genes) + + ast = parse_one( + "SELECT * FROM peaks " + "CROSS JOIN LATERAL NEAREST(genes, reference := peaks.interval, k := 3, stranded := true)", + dialect=GIQLDialect, + ) + sql = generator.generate(ast) + + assert 'peaks."strand"' in sql + assert 'genes."strand"' in sql + # Strand matching in WHERE + assert 'peaks."strand" = genes."strand"' in sql + + # ------------------------------------------------------------------ + # SELECT override + # ------------------------------------------------------------------ + + def test_bg_020_select_alias_mapping(self): + """ + GIVEN a SELECT with aliased FROM and JOIN tables + WHEN generate is called + THEN alias-to-table mapping is built correctly, verified through correct column resolution in a spatial op. 
+ """ + tables = Tables() + tables.register("features_a", Table("features_a")) + tables.register("features_b", Table("features_b")) + generator = BaseGIQLGenerator(tables=tables) + + ast = parse_one( + "SELECT * FROM features_a AS a " + "JOIN features_b AS b ON a.id = b.id " + "WHERE a.interval INTERSECTS b.interval", + dialect=GIQLDialect, + ) + sql = generator.generate(ast) + + # The aliases 'a' and 'b' should resolve to the registered tables + # and produce correctly qualified column references + assert 'a."chrom" = b."chrom"' in sql + assert 'a."start" < b."end"' in sql + assert 'a."end" > b."start"' in sql diff --git a/tests/unit/test_table.py b/tests/unit/test_table.py new file mode 100644 index 0000000..55bc30d --- /dev/null +++ b/tests/unit/test_table.py @@ -0,0 +1,225 @@ +"""Tests for giql.table module.""" + +import pytest +from hypothesis import given +from hypothesis import settings +from hypothesis import strategies as st + +from giql.table import Table +from giql.table import Tables + + +class TestTable: + """Tests for the Table dataclass.""" + + def test_default_values(self): + """ + GIVEN only the required arg `name` + WHEN Table is instantiated + THEN all fields have their default values. + """ + table = Table(name="peaks") + + assert table.name == "peaks" + assert table.genomic_col == "interval" + assert table.chrom_col == "chrom" + assert table.start_col == "start" + assert table.end_col == "end" + assert table.strand_col == "strand" + assert table.coordinate_system == "0based" + assert table.interval_type == "half_open" + + def test_all_custom_values(self): + """ + GIVEN all fields provided with custom values + WHEN Table is instantiated + THEN all fields reflect the custom values. 
+ """ + table = Table( + name="variants", + genomic_col="position", + chrom_col="chr", + start_col="pos_start", + end_col="pos_end", + strand_col="direction", + coordinate_system="1based", + interval_type="closed", + ) + + assert table.name == "variants" + assert table.genomic_col == "position" + assert table.chrom_col == "chr" + assert table.start_col == "pos_start" + assert table.end_col == "pos_end" + assert table.strand_col == "direction" + assert table.coordinate_system == "1based" + assert table.interval_type == "closed" + + def test_strand_col_none(self): + """ + GIVEN strand_col=None + WHEN Table is instantiated + THEN strand_col is None. + """ + table = Table(name="peaks", strand_col=None) + + assert table.strand_col is None + + def test_coordinate_system_1based(self): + """ + GIVEN coordinate_system="1based" + WHEN Table is instantiated + THEN coordinate_system is "1based". + """ + table = Table(name="peaks", coordinate_system="1based") + + assert table.coordinate_system == "1based" + + def test_interval_type_closed(self): + """ + GIVEN interval_type="closed" + WHEN Table is instantiated + THEN interval_type is "closed". + """ + table = Table(name="peaks", interval_type="closed") + + assert table.interval_type == "closed" + + def test_invalid_coordinate_system(self): + """ + GIVEN coordinate_system="invalid" + WHEN Table is instantiated + THEN raises ValueError with message about valid options. + """ + with pytest.raises(ValueError, match="coordinate_system"): + Table(name="peaks", coordinate_system="invalid") + + def test_invalid_interval_type(self): + """ + GIVEN interval_type="invalid" + WHEN Table is instantiated + THEN raises ValueError with message about valid options. 
+ """ + with pytest.raises(ValueError, match="interval_type"): + Table(name="peaks", interval_type="invalid") + + @given( + coordinate_system=st.sampled_from(["0based", "1based"]), + interval_type=st.sampled_from(["half_open", "closed"]), + ) + @settings(max_examples=20) + def test_valid_params_never_raise(self, coordinate_system, interval_type): + """ + GIVEN any Table with valid coordinate_system and interval_type + WHEN Table is instantiated + THEN no exception is raised and all fields are accessible. + """ + table = Table( + name="test", + coordinate_system=coordinate_system, + interval_type=interval_type, + ) + + assert table.coordinate_system == coordinate_system + assert table.interval_type == interval_type + + +class TestTables: + """Tests for the Tables container class.""" + + def test_get_missing_key(self): + """ + GIVEN a fresh Tables instance + WHEN get is called with an unregistered name + THEN returns None. + """ + tables = Tables() + + assert tables.get("unknown") is None + + def test_get_existing_key(self): + """ + GIVEN a Tables instance with one registered table + WHEN get is called with the registered name + THEN returns the Table object. + """ + tables = Tables() + table = Table(name="peaks") + tables.register("peaks", table) + + assert tables.get("peaks") is table + + def test_register_multiple_tables(self): + """ + GIVEN a Tables instance with one registered table + WHEN register is called with a new name and Table + THEN both tables are retrievable via get. + """ + tables = Tables() + peaks = Table(name="peaks") + variants = Table(name="variants") + tables.register("peaks", peaks) + tables.register("variants", variants) + + assert tables.get("peaks") is peaks + assert tables.get("variants") is variants + + def test_register_overwrites(self): + """ + GIVEN a Tables instance with a registered table + WHEN register is called with the same name and a different Table + THEN get returns the new Table (overwrite). 
+ """ + tables = Tables() + old_table = Table(name="peaks") + new_table = Table(name="peaks", chrom_col="chr") + tables.register("peaks", old_table) + tables.register("peaks", new_table) + + assert tables.get("peaks") is new_table + + def test_contains(self): + """ + GIVEN a Tables instance with registered tables + WHEN the in operator is used + THEN returns True for registered names, False for others. + """ + tables = Tables() + tables.register("peaks", Table(name="peaks")) + + assert "peaks" in tables + assert "unknown" not in tables + + def test_iter(self): + """ + GIVEN a Tables instance with registered tables + WHEN iterated with a for loop + THEN yields all registered Table objects. + """ + tables = Tables() + peaks = Table(name="peaks") + variants = Table(name="variants") + tables.register("peaks", peaks) + tables.register("variants", variants) + + result = [] + for table in tables: + result.append(table) + + assert len(result) == 2 + assert peaks in result + assert variants in result + + def test_iter_empty(self): + """ + GIVEN a fresh Tables instance with no tables + WHEN iterated with a for loop + THEN yields nothing (empty iteration). + """ + tables = Tables() + + result = [] + for table in tables: + result.append(table) + + assert result == [] diff --git a/tests/unit/test_transformer.py b/tests/unit/test_transformer.py new file mode 100644 index 0000000..656b3d8 --- /dev/null +++ b/tests/unit/test_transformer.py @@ -0,0 +1,494 @@ +"""Tests for the transformer module. 
+ +Test specification: specs/test_transformer.md +""" + +import pytest +from sqlglot import exp +from sqlglot import parse_one + +from giql import transpile +from giql.dialect import GIQLDialect +from giql.generators import BaseGIQLGenerator +from giql.table import Table +from giql.table import Tables +from giql.transformer import COVERAGE_STAT_MAP +from giql.transformer import ClusterTransformer +from giql.transformer import CoverageTransformer +from giql.transformer import MergeTransformer + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_tables(*names: str, **custom: Table) -> Tables: + tables = Tables() + for name in names: + tables.register(name, Table(name)) + for name, table in custom.items(): + tables.register(name, table) + return tables + + +def _transform_and_sql(query: str, transformer_cls, tables: Tables | None = None) -> str: + tables = tables or _make_tables("features") + ast = parse_one(query, dialect=GIQLDialect) + transformer = transformer_cls(tables) + result = transformer.transform(ast) + generator = BaseGIQLGenerator(tables=tables) + return generator.generate(result) + + +# =========================================================================== +# TestCoverageStatMap +# =========================================================================== + + +class TestCoverageStatMap: + """Tests for the COVERAGE_STAT_MAP module-level constant.""" + + def test_csm_001_coverage_stat_map_has_correct_mappings(self): + """GIVEN the module is imported WHEN COVERAGE_STAT_MAP is accessed THEN it maps count->COUNT, mean->AVG, sum->SUM, min->MIN, max->MAX.""" + assert COVERAGE_STAT_MAP == { + "count": "COUNT", + "mean": "AVG", + "sum": "SUM", + "min": "MIN", + "max": "MAX", + } + + +# =========================================================================== +# TestClusterTransformer +# 
=========================================================================== + + +class TestClusterTransformer: + """Tests for ClusterTransformer.transform.""" + + def test_ct_001_basic_cluster_has_lag_and_sum_windows(self): + """GIVEN a Tables instance and a parsed SELECT with CLUSTER(interval) WHEN transform is called THEN the result contains LAG and SUM window expressions.""" + sql = _transform_and_sql( + "SELECT *, CLUSTER(interval) FROM features", ClusterTransformer + ) + upper = sql.upper() + assert "LAG" in upper + assert "SUM" in upper + + def test_ct_002_cluster_alias_preserved(self): + """GIVEN a parsed SELECT with CLUSTER(interval) AS cluster_id WHEN transform is called THEN the alias is preserved on the SUM window expression.""" + sql = _transform_and_sql( + "SELECT *, CLUSTER(interval) AS cluster_id FROM features", + ClusterTransformer, + ) + assert "cluster_id" in sql + + def test_ct_003_cluster_with_distance(self): + """GIVEN a parsed SELECT with CLUSTER(interval, 1000) WHEN transform is called THEN the LAG result has distance 1000 added.""" + sql = _transform_and_sql( + "SELECT *, CLUSTER(interval, 1000) FROM features", + ClusterTransformer, + ) + upper = sql.upper() + assert "LAG" in upper + assert "1000" in sql + + def test_ct_004_cluster_stranded_partitions_by_strand(self): + """GIVEN a parsed SELECT with CLUSTER(interval, stranded := true) WHEN transform is called THEN the result partitions by chrom AND strand.""" + sql = _transform_and_sql( + "SELECT *, CLUSTER(interval, stranded := true) FROM features", + ClusterTransformer, + ) + upper = sql.upper() + assert "STRAND" in upper + # Both chrom and strand should appear in partition + assert "CHROM" in upper + + def test_ct_005_non_select_returns_unchanged(self): + """GIVEN a non-SELECT expression WHEN transform is called THEN the expression is returned unchanged.""" + tables = _make_tables("features") + transformer = ClusterTransformer(tables) + insert = exp.Insert(this=exp.to_table("features")) + 
result = transformer.transform(insert) + assert result is insert + + def test_ct_006_no_cluster_returns_unchanged(self): + """GIVEN a SELECT with no CLUSTER expressions WHEN transform is called THEN the query is returned unchanged.""" + tables = _make_tables("features") + transformer = ClusterTransformer(tables) + ast = parse_one("SELECT * FROM features", dialect=GIQLDialect) + result = transformer.transform(ast) + assert result is ast + + def test_ct_007_custom_column_names_via_tables(self): + """GIVEN a Tables instance with custom column names WHEN transform is called on a CLUSTER query THEN the generated query uses custom column names.""" + custom = Table( + "features", + chrom_col="chromosome", + start_col="start_pos", + end_col="end_pos", + ) + tables = _make_tables(features=custom) + sql = _transform_and_sql( + "SELECT *, CLUSTER(interval) FROM features", + ClusterTransformer, + tables=tables, + ) + assert "chromosome" in sql + assert "start_pos" in sql + assert "end_pos" in sql + + def test_ct_008_cluster_inside_cte_recursive_transformation(self): + """GIVEN a SELECT with CLUSTER inside a CTE subquery WHEN transform is called THEN the CTE subquery is recursively transformed.""" + sql = _transform_and_sql( + "WITH c AS (SELECT *, CLUSTER(interval) AS cid FROM features) " + "SELECT * FROM c", + ClusterTransformer, + ) + upper = sql.upper() + assert "LAG" in upper + assert "SUM" in upper + + def test_ct_009_cluster_with_where_preserved(self): + """GIVEN a SELECT with CLUSTER and a WHERE clause WHEN transform is called THEN the WHERE clause is preserved.""" + sql = _transform_and_sql( + "SELECT *, CLUSTER(interval) FROM features WHERE score > 10", + ClusterTransformer, + ) + assert "score > 10" in sql + + def test_ct_010_specific_columns_with_cluster_adds_required_cols(self): + """GIVEN a SELECT with specific columns (not *) and CLUSTER WHEN transform is called THEN missing required genomic columns are added to the CTE select list.""" + sql = _transform_and_sql( 
+ "SELECT name, CLUSTER(interval) AS cid FROM features", + ClusterTransformer, + ) + upper = sql.upper() + # Required genomic cols should be in the output + assert "CHROM" in upper + assert "START" in upper + assert "END" in upper + + +# =========================================================================== +# TestMergeTransformer +# =========================================================================== + + +class TestMergeTransformer: + """Tests for MergeTransformer.transform.""" + + def test_mt_001_basic_merge_has_group_by_min_max(self): + """GIVEN a Tables instance and a parsed SELECT with MERGE(interval) WHEN transform is called THEN the result has GROUP BY, MIN(start), MAX(end).""" + sql = _transform_and_sql( + "SELECT MERGE(interval) FROM features", MergeTransformer + ) + upper = sql.upper() + assert "GROUP BY" in upper + assert "MIN(" in upper + assert "MAX(" in upper + + def test_mt_002_merge_alias_dropped_output_fixed(self): + """GIVEN a parsed SELECT with MERGE(interval) AS merged WHEN transform is called THEN the query still produces valid output with fixed columns.""" + sql = _transform_and_sql( + "SELECT MERGE(interval) AS merged FROM features", + MergeTransformer, + ) + upper = sql.upper() + assert "GROUP BY" in upper + assert "MIN(" in upper + assert "MAX(" in upper + + def test_mt_003_merge_with_distance(self): + """GIVEN a parsed SELECT with MERGE(interval, 1000) WHEN transform is called THEN the distance is passed through to CLUSTER.""" + sql = _transform_and_sql( + "SELECT MERGE(interval, 1000) FROM features", + MergeTransformer, + ) + assert "1000" in sql + + def test_mt_004_merge_stranded_adds_strand_to_group_by(self): + """GIVEN a parsed SELECT with MERGE(interval, stranded := true) WHEN transform is called THEN strand appears in GROUP BY and partition.""" + sql = _transform_and_sql( + "SELECT MERGE(interval, stranded := true) FROM features", + MergeTransformer, + ) + upper = sql.upper() + assert "STRAND" in upper + assert "GROUP BY" 
in upper + + def test_mt_005_non_select_returns_unchanged(self): + """GIVEN a non-SELECT expression WHEN transform is called THEN the expression is returned unchanged.""" + tables = _make_tables("features") + transformer = MergeTransformer(tables) + insert = exp.Insert(this=exp.to_table("features")) + result = transformer.transform(insert) + assert result is insert + + def test_mt_006_no_merge_returns_unchanged(self): + """GIVEN a SELECT with no MERGE expressions WHEN transform is called THEN the query is returned unchanged.""" + tables = _make_tables("features") + transformer = MergeTransformer(tables) + ast = parse_one("SELECT * FROM features", dialect=GIQLDialect) + result = transformer.transform(ast) + assert result is ast + + def test_mt_007_two_merge_expressions_raises_value_error(self): + """GIVEN a SELECT with two MERGE expressions WHEN transform is called THEN it raises ValueError.""" + tables = _make_tables("features") + transformer = MergeTransformer(tables) + ast = parse_one( + "SELECT MERGE(interval), MERGE(interval) FROM features", + dialect=GIQLDialect, + ) + with pytest.raises(ValueError, match="Multiple MERGE"): + transformer.transform(ast) + + def test_mt_008_merge_with_where_preserved(self): + """GIVEN a SELECT with MERGE and a WHERE clause WHEN transform is called THEN the WHERE clause is preserved in the clustered subquery.""" + sql = _transform_and_sql( + "SELECT MERGE(interval) FROM features WHERE score > 10", + MergeTransformer, + ) + assert "score > 10" in sql + + def test_mt_009_merge_inside_cte_recursive_transformation(self): + """GIVEN a SELECT with MERGE inside a CTE subquery WHEN transform is called THEN the CTE subquery is recursively transformed.""" + sql = _transform_and_sql( + "WITH m AS (SELECT MERGE(interval) FROM features) SELECT * FROM m", + MergeTransformer, + ) + upper = sql.upper() + assert "GROUP BY" in upper + assert "MIN(" in upper + assert "MAX(" in upper + + +# 
=========================================================================== +# TestCoverageTransformer +# =========================================================================== + + +class TestCoverageTransformer: + """Tests for CoverageTransformer.transform.""" + + def test_cvt_001_basic_coverage_structure(self): + """GIVEN a Tables instance and a parsed SELECT with COVERAGE(interval, 1000) WHEN transform is called THEN the result has __giql_bins CTE, LEFT JOIN, COUNT, and GROUP BY.""" + sql = _transform_and_sql( + "SELECT COVERAGE(interval, 1000) FROM features", + CoverageTransformer, + ) + upper = sql.upper() + assert "__GIQL_BINS" in upper + assert "LEFT JOIN" in upper + assert "COUNT" in upper + assert "GROUP BY" in upper + + def test_cvt_002_stat_mean_uses_avg(self): + """GIVEN a parsed SELECT with COVERAGE(interval, 500, stat := 'mean') WHEN transform is called THEN the result uses AVG over (end - start).""" + sql = _transform_and_sql( + "SELECT COVERAGE(interval, 500, stat := 'mean') FROM features", + CoverageTransformer, + ) + upper = sql.upper() + assert "AVG" in upper + assert "COUNT" not in upper + + def test_cvt_003_stat_sum(self): + """GIVEN a parsed SELECT with COVERAGE(interval, 500, stat := 'sum') WHEN transform is called THEN the result uses SUM.""" + sql = _transform_and_sql( + "SELECT COVERAGE(interval, 500, stat := 'sum') FROM features", + CoverageTransformer, + ) + assert "SUM" in sql.upper() + + def test_cvt_004_stat_min(self): + """GIVEN a parsed SELECT with COVERAGE(interval, 500, stat := 'min') WHEN transform is called THEN the result uses MIN.""" + sql = _transform_and_sql( + "SELECT COVERAGE(interval, 500, stat := 'min') FROM features", + CoverageTransformer, + ) + assert "MIN(" in sql.upper() + + def test_cvt_005_stat_max(self): + """GIVEN a parsed SELECT with COVERAGE(interval, 500, stat := 'max') WHEN transform is called THEN the result uses MAX.""" + sql = _transform_and_sql( + "SELECT COVERAGE(interval, 500, stat := 'max') FROM 
features", + CoverageTransformer, + ) + assert "MAX(" in sql.upper() + + def test_cvt_006_stat_mean_with_target_score(self): + """GIVEN a parsed SELECT with COVERAGE(interval, 1000, stat := 'mean', target := 'score') WHEN transform is called THEN the result uses AVG over the score column.""" + sql = _transform_and_sql( + "SELECT COVERAGE(interval, 1000, stat := 'mean', target := 'score') FROM features", + CoverageTransformer, + ) + upper = sql.upper() + assert "AVG" in upper + assert "SCORE" in upper + + def test_cvt_007_target_score_with_default_count(self): + """GIVEN a parsed SELECT with COVERAGE(interval, 1000, target := 'score') and default count stat WHEN transform is called THEN the result uses COUNT over the score column.""" + sql = _transform_and_sql( + "SELECT COVERAGE(interval, 1000, target := 'score') FROM features", + CoverageTransformer, + ) + upper = sql.upper() + assert "COUNT" in upper + assert "SCORE" in upper + # Should NOT have COUNT(source.*) + assert ".*)" not in sql + + def test_cvt_008_coverage_alias_preserved(self): + """GIVEN a parsed SELECT with COVERAGE(interval, 1000) AS cov WHEN transform is called THEN the aggregate column uses the alias 'cov'.""" + sql = _transform_and_sql( + "SELECT COVERAGE(interval, 1000) AS cov FROM features", + CoverageTransformer, + ) + assert "AS cov" in sql + assert "AS value" not in sql + + def test_cvt_009_bare_coverage_default_alias_value(self): + """GIVEN a parsed SELECT with bare COVERAGE(interval, 1000) (no alias) WHEN transform is called THEN the aggregate column is aliased as 'value'.""" + sql = _transform_and_sql( + "SELECT COVERAGE(interval, 1000) FROM features", + CoverageTransformer, + ) + assert "AS value" in sql + + def test_cvt_010_non_select_returns_unchanged(self): + """GIVEN a non-SELECT expression WHEN transform is called THEN the expression is returned unchanged.""" + tables = _make_tables("features") + transformer = CoverageTransformer(tables) + insert = 
exp.Insert(this=exp.to_table("features")) + result = transformer.transform(insert) + assert result is insert + + def test_cvt_011_no_coverage_returns_unchanged(self): + """GIVEN a SELECT with no COVERAGE expressions WHEN transform is called THEN the query is returned unchanged.""" + tables = _make_tables("features") + transformer = CoverageTransformer(tables) + ast = parse_one("SELECT * FROM features", dialect=GIQLDialect) + result = transformer.transform(ast) + assert result is ast + + def test_cvt_012_two_coverage_raises_value_error(self): + """GIVEN a SELECT with two COVERAGE expressions WHEN transform is called THEN it raises ValueError.""" + tables = _make_tables("features") + transformer = CoverageTransformer(tables) + ast = parse_one( + "SELECT COVERAGE(interval, 1000), COVERAGE(interval, 500) FROM features", + dialect=GIQLDialect, + ) + with pytest.raises(ValueError, match="Multiple COVERAGE"): + transformer.transform(ast) + + def test_cvt_013_where_in_join_on_and_chroms_subquery(self): + """GIVEN a parsed SELECT with COVERAGE and a WHERE clause WHEN transform is called THEN the WHERE is merged into the LEFT JOIN ON condition AND applied to the chroms subquery.""" + sql = _transform_and_sql( + "SELECT COVERAGE(interval, 1000) FROM features WHERE score > 10", + CoverageTransformer, + ) + upper = sql.upper() + # WHERE should be in the ON clause + after_join = sql.split("LEFT JOIN")[1] + on_clause = after_join.split("GROUP BY")[0] + assert "score > 10" in on_clause + # WHERE should also be in the chroms subquery (the CTE part) + cte_part = sql.split(") SELECT")[0] + assert "score > 10" in cte_part + + def test_cvt_014_custom_column_names(self): + """GIVEN a Tables instance with custom column names WHEN transform is called on a COVERAGE query THEN the generated query uses custom column names.""" + custom = Table( + "peaks", + chrom_col="chromosome", + start_col="start_pos", + end_col="end_pos", + ) + tables = _make_tables(peaks=custom) + sql = 
_transform_and_sql( + "SELECT COVERAGE(interval, 1000) FROM peaks", + CoverageTransformer, + tables=tables, + ) + assert "chromosome" in sql + assert "start_pos" in sql + assert "end_pos" in sql + + def test_cvt_015_non_integer_resolution_raises_value_error(self): + """GIVEN a parsed SELECT with COVERAGE where resolution is not an integer literal WHEN transform is called THEN it raises ValueError about resolution.""" + tables = _make_tables("features") + transformer = CoverageTransformer(tables) + # Construct an AST manually with a non-integer resolution + from giql.expressions import GIQLCoverage + + coverage = GIQLCoverage( + this=exp.column("interval"), + resolution=exp.column("some_col"), + ) + ast = exp.Select().select(coverage).from_("features") + with pytest.raises(ValueError, match="resolution"): + transformer.transform(ast) + + def test_cvt_016_invalid_stat_raises_value_error(self): + """GIVEN a parsed SELECT with COVERAGE(interval, 1000, stat := 'invalid') WHEN transform is called THEN it raises ValueError about unknown stat.""" + tables = _make_tables("features") + transformer = CoverageTransformer(tables) + ast = parse_one( + "SELECT COVERAGE(interval, 1000, stat := 'invalid') FROM features", + dialect=GIQLDialect, + ) + with pytest.raises(ValueError, match="Unknown COVERAGE stat"): + transformer.transform(ast) + + def test_cvt_017_coverage_inside_cte_recursive_transformation(self): + """GIVEN a parsed SELECT with COVERAGE inside a CTE subquery WHEN transform is called THEN the CTE subquery is recursively transformed.""" + sql = _transform_and_sql( + "WITH cov AS (SELECT COVERAGE(interval, 1000) FROM features) " + "SELECT * FROM cov", + CoverageTransformer, + ) + upper = sql.upper() + assert "__GIQL_BINS" in upper + assert "LEFT JOIN" in upper + assert "COUNT" in upper + + def test_cvt_018_table_alias_used_as_source_ref(self): + """GIVEN a query FROM a table with an alias (FROM features AS f) WHEN transform is called THEN the source_ref in the generated 
SQL uses the alias.""" + sql = _transform_and_sql( + "SELECT COVERAGE(interval, 1000) FROM features AS f", + CoverageTransformer, + ) + upper = sql.upper() + assert "LEFT JOIN" in upper + # The alias 'f' should appear as the source reference in the join + assert "f." in sql or "AS f" in sql + + def test_cvt_019_bins_cte_has_generate_series_with_cross_join_lateral(self): + """GIVEN the bins CTE in a basic COVERAGE transformation WHEN the SQL is inspected THEN it contains generate_series with CROSS JOIN LATERAL.""" + sql = _transform_and_sql( + "SELECT COVERAGE(interval, 1000) FROM features", + CoverageTransformer, + ) + upper = sql.upper() + assert "GENERATE_SERIES" in upper + assert "CROSS JOIN" in upper + assert "LATERAL" in upper + + def test_cvt_020_output_ordered_by_bins_chrom_bins_start(self): + """GIVEN a COVERAGE transformation output WHEN the ORDER BY clause is inspected THEN the output is ordered by bins.chrom, bins.start.""" + sql = _transform_and_sql( + "SELECT COVERAGE(interval, 1000) FROM features", + CoverageTransformer, + ) + upper = sql.upper() + assert "ORDER BY" in upper + # Extract ORDER BY clause + order_by_part = sql.split("ORDER BY")[1] + order_upper = order_by_part.upper() + assert "BINS" in order_upper + assert "CHROM" in order_upper + assert "START" in order_upper diff --git a/tests/unit/test_transpile.py b/tests/unit/test_transpile.py new file mode 100644 index 0000000..30be66f --- /dev/null +++ b/tests/unit/test_transpile.py @@ -0,0 +1,339 @@ +"""Unit tests for the transpile() function. + +Tests TR-001 through TR-021 covering all public API behavior of +giql.transpile as a black box: GIQL string in, SQL string out. 
+""" + +import pytest + +from giql import Table +from giql import transpile + + +class TestTranspile: + """Tests for transpile() public API (TR-001 to TR-021).""" + + # ── Basic transpilation ────────────────────────────────────────── + + def test_plain_sql_passthrough(self): + """ + GIVEN a plain SQL query with no GIQL extensions + WHEN transpile is called + THEN it returns an equivalent SQL string unchanged. + """ + sql = transpile("SELECT id, name FROM features") + upper = sql.upper() + assert "SELECT" in upper + assert "FEATURES" in upper + assert "ID" in upper + + def test_intersects_predicate(self): + """ + GIVEN a query with an INTERSECTS predicate and a tables list + WHEN transpile is called + THEN the returned SQL contains expanded range comparison predicates. + """ + sql = transpile( + "SELECT * FROM features WHERE interval INTERSECTS 'chr1:1000-2000'", + tables=["features"], + ) + upper = sql.upper() + assert "CHR1" in upper + assert "1000" in sql + assert "2000" in sql + # Range overlap requires both start/end comparisons + assert "START" in upper or "END" in upper + + def test_contains_predicate(self): + """ + GIVEN a query with a CONTAINS predicate + WHEN transpile is called + THEN the returned SQL contains containment predicates. + """ + sql = transpile( + "SELECT * FROM features WHERE interval CONTAINS 'chr1:1500'", + tables=["features"], + ) + upper = sql.upper() + assert "SELECT" in upper + assert "1500" in sql + + def test_within_predicate(self): + """ + GIVEN a query with a WITHIN predicate + WHEN transpile is called + THEN the returned SQL contains within predicates. 
+ """ + sql = transpile( + "SELECT * FROM features WHERE interval WITHIN 'chr1:1000-2000'", + tables=["features"], + ) + upper = sql.upper() + assert "SELECT" in upper + assert "1000" in sql + assert "2000" in sql + + # ── CLUSTER transpilation ──────────────────────────────────────── + + def test_cluster_basic(self): + """ + GIVEN a query with CLUSTER(interval) and tables=["features"] + WHEN transpile is called + THEN the returned SQL contains LAG and SUM window functions in a subquery. + """ + sql = transpile( + "SELECT *, CLUSTER(interval) AS cluster_id FROM features", + tables=["features"], + ) + upper = sql.upper() + assert "LAG" in upper + assert "SUM" in upper + + def test_cluster_with_distance(self): + """ + GIVEN a query with CLUSTER(interval, 1000) + WHEN transpile is called + THEN the returned SQL includes a distance offset in the LAG expression. + """ + sql = transpile( + "SELECT *, CLUSTER(interval, 1000) AS cluster_id FROM features", + tables=["features"], + ) + upper = sql.upper() + assert "LAG" in upper + assert "1000" in sql + + # ── MERGE transpilation ────────────────────────────────────────── + + def test_merge_basic(self): + """ + GIVEN a query with MERGE(interval) and tables=["features"] + WHEN transpile is called + THEN the returned SQL contains a CLUSTER CTE with GROUP BY and MIN/MAX aggregation. + """ + sql = transpile( + "SELECT MERGE(interval) FROM features", + tables=["features"], + ) + upper = sql.upper() + assert "MIN" in upper + assert "MAX" in upper + assert "GROUP BY" in upper + + # ── COVERAGE transpilation ─────────────────────────────────────── + + def test_coverage_basic(self): + """ + GIVEN a query with COVERAGE(interval, 1000) and tables=["features"] + WHEN transpile is called + THEN the returned SQL contains a bins CTE, LEFT JOIN, COUNT, GROUP BY, and ORDER BY. 
+ """ + sql = transpile( + "SELECT COVERAGE(interval, 1000) FROM features", + tables=["features"], + ) + upper = sql.upper() + assert "LEFT JOIN" in upper or "LEFT OUTER JOIN" in upper + assert "COUNT" in upper + assert "GROUP BY" in upper + assert "ORDER BY" in upper + assert "1000" in sql + + def test_coverage_mean_stat(self): + """ + GIVEN a query with COVERAGE(interval, 500, stat := 'mean') + WHEN transpile is called + THEN the returned SQL contains an AVG aggregate. + """ + sql = transpile( + "SELECT COVERAGE(interval, 500, stat := 'mean') FROM features", + tables=["features"], + ) + upper = sql.upper() + assert "AVG" in upper + + def test_coverage_mean_with_target(self): + """ + GIVEN a query with COVERAGE(interval, 1000, stat := 'mean', target := 'score') + WHEN transpile is called + THEN the returned SQL contains AVG applied to the score column. + """ + sql = transpile( + "SELECT COVERAGE(interval, 1000, stat := 'mean', target := 'score') FROM features", + tables=["features"], + ) + upper = sql.upper() + assert "AVG" in upper + assert "SCORE" in upper + + def test_coverage_custom_alias(self): + """ + GIVEN a query with COVERAGE(interval, 1000) AS cov + WHEN transpile is called + THEN the aggregate column in the returned SQL is aliased as "cov". + """ + sql = transpile( + "SELECT COVERAGE(interval, 1000) AS cov FROM features", + tables=["features"], + ) + assert "cov" in sql.lower() + + def test_coverage_default_alias(self): + """ + GIVEN a query with bare COVERAGE(interval, 1000) (no alias) + WHEN transpile is called + THEN the aggregate column in the returned SQL is aliased as "value". + """ + sql = transpile( + "SELECT COVERAGE(interval, 1000) FROM features", + tables=["features"], + ) + assert "value" in sql.lower() + + def test_coverage_where_in_join_on(self): + """ + GIVEN a query with COVERAGE and a WHERE clause + WHEN transpile is called + THEN the WHERE condition appears in the JOIN ON condition rather than as a standalone WHERE. 
+ """ + sql = transpile( + "SELECT COVERAGE(interval, 1000) FROM features WHERE chrom = 'chr1'", + tables=["features"], + ) + upper = sql.upper() + # The WHERE should be folded into the JOIN ON condition + assert "JOIN" in upper + assert "CHR1" in upper + + # ── DISTANCE transpilation ─────────────────────────────────────── + + def test_distance_case_expression(self): + """ + GIVEN a query with DISTANCE(a.interval, b.interval) and two tables + WHEN transpile is called + THEN the returned SQL contains a CASE expression for computing distance. + """ + sql = transpile( + "SELECT DISTANCE(a.interval, b.interval) FROM features a, genes b", + tables=["features", "genes"], + ) + upper = sql.upper() + assert "CASE" in upper + + # ── NEAREST transpilation ──────────────────────────────────────── + + def test_nearest_lateral_join(self): + """ + GIVEN a query with NEAREST in a LATERAL join and two tables + WHEN transpile is called + THEN the returned SQL contains a LATERAL subquery with a LIMIT clause. + """ + sql = transpile( + """ + SELECT * + FROM peaks + CROSS JOIN LATERAL NEAREST(genes, reference=peaks.interval, k=3) + """, + tables=["peaks", "genes"], + ) + upper = sql.upper() + assert "LATERAL" in upper + assert "LIMIT" in upper + + # ── Table configuration ────────────────────────────────────────── + + def test_tables_string_list(self): + """ + GIVEN tables parameter as a list of strings + WHEN transpile is called + THEN tables are registered with default column mappings (chrom, start, end). 
+ """ + sql = transpile( + "SELECT * FROM features WHERE interval INTERSECTS 'chr1:100-200'", + tables=["features"], + ) + upper = sql.upper() + assert '"CHROM"' in upper or "CHROM" in upper + assert '"START"' in upper or "START" in upper + assert '"END"' in upper or "END" in upper + + def test_tables_custom_table_objects(self): + """ + GIVEN tables parameter as a list of Table objects with custom column names + WHEN transpile is called + THEN the generated SQL uses those custom column names. + """ + sql = transpile( + "SELECT * FROM features WHERE interval INTERSECTS 'chr1:100-200'", + tables=[ + Table( + "features", + genomic_col="interval", + chrom_col="chromosome", + start_col="start_pos", + end_col="end_pos", + ) + ], + ) + assert "chromosome" in sql or "CHROMOSOME" in sql.upper() + assert "start_pos" in sql or "START_POS" in sql.upper() + assert "end_pos" in sql or "END_POS" in sql.upper() + + def test_tables_none(self): + """ + GIVEN tables parameter is None + WHEN transpile is called + THEN default column names (chrom, start, end) are still used. + """ + sql = transpile( + "SELECT * FROM features WHERE interval INTERSECTS 'chr1:100-200'", + tables=None, + ) + upper = sql.upper() + assert "SELECT" in upper + assert "CHROM" in upper + + def test_tables_mixed_strings_and_objects(self): + """ + GIVEN tables parameter mixes strings and Table objects + WHEN transpile is called + THEN both are correctly registered and the SQL is valid. 
+ """ + sql = transpile( + """ + SELECT a.*, b.* + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.region + """, + tables=[ + "peaks", + Table("genes", genomic_col="region", chrom_col="seqname"), + ], + ) + upper = sql.upper() + assert "PEAKS" in upper + assert "GENES" in upper + assert "SEQNAME" in upper + + # ── Error handling ─────────────────────────────────────────────── + + def test_invalid_query_raises_parse_error(self): + """ + GIVEN an invalid/unparseable query string + WHEN transpile is called + THEN a ValueError is raised with a message containing "Parse error". + """ + with pytest.raises(ValueError, match="Parse error"): + transpile("SELECT * FORM features") + + def test_coverage_invalid_stat_raises(self): + """ + GIVEN a query with COVERAGE using an invalid stat name + WHEN transpile is called + THEN a ValueError is raised with a message containing "Unknown COVERAGE stat". + """ + with pytest.raises(ValueError, match="Unknown COVERAGE stat"): + transpile( + "SELECT COVERAGE(interval, 1000, stat := 'invalid_stat') FROM features", + tables=["features"], + )