diff --git a/.gitignore b/.gitignore index cc1be48..58fbb4e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,10 @@ # macOS .DS_store +# Rust +target/ +Cargo.lock + # Python __pycache__/ *.py[cod] diff --git a/crates/giql-datafusion/.gitignore b/crates/giql-datafusion/.gitignore new file mode 100644 index 0000000..ca98cd9 --- /dev/null +++ b/crates/giql-datafusion/.gitignore @@ -0,0 +1,2 @@ +/target/ +Cargo.lock diff --git a/crates/giql-datafusion/Cargo.toml b/crates/giql-datafusion/Cargo.toml new file mode 100644 index 0000000..484d043 --- /dev/null +++ b/crates/giql-datafusion/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "giql-datafusion" +version = "0.1.0" +edition = "2021" +description = "DataFusion optimizer for genomic interval (INTERSECTS) joins" +license = "MIT" + +[dependencies] +arrow = { version = "58", default-features = false, features = ["prettyprint"] } +async-trait = "0.1.89" +coitrees = "0.4.0" +datafusion = "53" +futures = "0.3.32" +log = "0.4" +parquet = "58" + +[dev-dependencies] +tempfile = "3" +tokio = { version = "1", features = ["rt-multi-thread", "macros"] } diff --git a/crates/giql-datafusion/src/coitree.rs b/crates/giql-datafusion/src/coitree.rs new file mode 100644 index 0000000..0bddfa5 --- /dev/null +++ b/crates/giql-datafusion/src/coitree.rs @@ -0,0 +1,507 @@ +//! COI tree interval join — build/probe execution using cache-oblivious +//! interval trees from the `coitrees` crate. +//! +//! Used for non-uniform width distributions where binning would cause +//! excessive replication. Each interval is stored exactly once in the +//! tree; queries are O(log N + k) per probe interval. 
+ +use std::any::Any; +use std::collections::HashMap; +use std::fmt; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +use arrow::array::{ + Array, AsArray, Int64Array, RecordBatch, UInt32Array, +}; +use arrow::compute; +use arrow::datatypes::SchemaRef; +use coitrees::{COITree, Interval, IntervalTree}; +use datafusion::common::{Column, DFSchemaRef, Result}; +use datafusion::execution::SendableRecordBatchStream; +use datafusion::logical_expr::{Expr, LogicalPlan, UserDefinedLogicalNode}; +use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; +use datafusion::physical_plan::execution_plan::{ + Boundedness, EmissionType, +}; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, +}; + +// ── Logical node ──────────────────────────────────────────────── + +/// Logical plan node representing a COI tree interval join. +/// +/// The logical optimizer rule emits this when the sampled width +/// distribution is non-uniform (cost_optimal_bin > 2 * median). +/// The extension planner converts it to a [`COITreeExec`]. +#[derive(Debug, Clone)] +pub struct COITreeJoinNode { + pub left: Arc, + pub right: Arc, + /// Equi-keys from the original join (e.g., chrom = chrom). + pub on: Vec<(Column, Column)>, + /// Interval column names from giql_intersects() args. 
+ pub start_a: Column, + pub end_a: Column, + pub start_b: Column, + pub end_b: Column, + pub schema: DFSchemaRef, +} + +impl Hash for COITreeJoinNode { + fn hash(&self, state: &mut H) { + self.on.hash(state); + self.start_a.hash(state); + self.end_a.hash(state); + self.start_b.hash(state); + self.end_b.hash(state); + } +} + +impl PartialEq for COITreeJoinNode { + fn eq(&self, other: &Self) -> bool { + self.on == other.on + && self.start_a == other.start_a + && self.end_a == other.end_a + && self.start_b == other.start_b + && self.end_b == other.end_b + } +} + +impl Eq for COITreeJoinNode {} + +impl PartialOrd for COITreeJoinNode { + fn partial_cmp( + &self, + _other: &Self, + ) -> Option { + None + } +} + +impl UserDefinedLogicalNode for COITreeJoinNode { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "COITreeJoin" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.left, &self.right] + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn check_invariants( + &self, + _check: datafusion::logical_expr::InvariantLevel, + ) -> Result<()> { + Ok(()) + } + + fn expressions(&self) -> Vec { + vec![] + } + + fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "COITreeJoin: on=[{}]", + self.on + .iter() + .map(|(l, r)| format!("{l} = {r}")) + .collect::>() + .join(", ") + ) + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + inputs: Vec, + ) -> Result> { + Ok(Arc::new(COITreeJoinNode { + left: Arc::new(inputs[0].clone()), + right: Arc::new(inputs[1].clone()), + on: self.on.clone(), + start_a: self.start_a.clone(), + end_a: self.end_a.clone(), + start_b: self.start_b.clone(), + end_b: self.end_b.clone(), + schema: self.schema.clone(), + })) + } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.hash(&mut s); + } + + fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool { + other + .as_any() + .downcast_ref::() + .map_or(false, |o| self == o) + } + + fn 
dyn_ord( + &self, + _other: &dyn UserDefinedLogicalNode, + ) -> Option { + None + } +} + +// ── Physical execution plan ───────────────────────────────────── + +/// Physical execution plan that uses COI trees for interval joins. +/// +/// Build phase: collect the left (build) side, group by chromosome, +/// construct a `COITree` per chromosome with row indices as metadata. +/// +/// Probe phase: stream the right (probe) side batch by batch, query +/// the per-chromosome tree for each interval, emit joined output. +#[derive(Debug)] +pub struct COITreeExec { + left: Arc, + right: Arc, + /// Equi-key column names (e.g., chrom). + chrom_l: String, + chrom_r: String, + start_l: String, + end_l: String, + start_r: String, + end_r: String, + schema: SchemaRef, + properties: Arc, +} + +impl COITreeExec { + pub fn new( + left: Arc, + right: Arc, + on: &[(Column, Column)], + start_a: &Column, + end_a: &Column, + start_b: &Column, + end_b: &Column, + schema: SchemaRef, + ) -> Self { + let properties = Arc::new(PlanProperties::new( + EquivalenceProperties::new(schema.clone()), + Partitioning::UnknownPartitioning(1), + EmissionType::Final, + Boundedness::Bounded, + )); + + // Use the first equi-key as the chromosome column. 
+ let (chrom_l, chrom_r) = if let Some((l, r)) = on.first() { + (l.name.clone(), r.name.clone()) + } else { + ("chrom".to_string(), "chrom".to_string()) + }; + + Self { + left, + right, + chrom_l, + chrom_r, + start_l: start_a.name.clone(), + end_l: end_a.name.clone(), + start_r: start_b.name.clone(), + end_r: end_b.name.clone(), + schema, + properties, + } + } +} + +impl DisplayAs for COITreeExec { + fn fmt_as( + &self, + _t: DisplayFormatType, + f: &mut fmt::Formatter<'_>, + ) -> fmt::Result { + write!( + f, + "COITreeExec: on=[{} = {}]", + self.chrom_l, self.chrom_r + ) + } +} + +impl ExecutionPlan for COITreeExec { + fn name(&self) -> &str { + "COITreeExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn properties(&self) -> &Arc { + &self.properties + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.left, &self.right] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(COITreeExec { + left: children[0].clone(), + right: children[1].clone(), + chrom_l: self.chrom_l.clone(), + chrom_r: self.chrom_r.clone(), + start_l: self.start_l.clone(), + end_l: self.end_l.clone(), + start_r: self.start_r.clone(), + end_r: self.end_r.clone(), + schema: self.schema.clone(), + properties: self.properties.clone(), + })) + } + + fn execute( + &self, + _partition: usize, + context: Arc, + ) -> Result { + let left_plan = self.left.clone(); + let right_plan = self.right.clone(); + let schema = self.schema.clone(); + let chrom_l = self.chrom_l.clone(); + let chrom_r = self.chrom_r.clone(); + let start_l = self.start_l.clone(); + let end_l = self.end_l.clone(); + let start_r = self.start_r.clone(); + let end_r = self.end_r.clone(); + + let stream = futures::stream::once(async move { + // ── Build phase: collect left side, build COI trees ── + let left_batches = + datafusion::physical_plan::collect( + left_plan, + context.clone(), + ) + .await?; + + if 
left_batches.is_empty() { + return Ok(RecordBatch::new_empty(schema)); + } + + let left_concat = compute::concat_batches( + &left_batches[0].schema(), + &left_batches, + )?; + + let left_schema = left_concat.schema(); + let l_chrom_idx = + left_schema.index_of(&chrom_l)?; + let l_start_idx = + left_schema.index_of(&start_l)?; + let l_end_idx = left_schema.index_of(&end_l)?; + + let l_chrom_col = left_concat.column(l_chrom_idx); + let l_starts = left_concat + .column(l_start_idx) + .as_any() + .downcast_ref::() + .ok_or_else(|| { + datafusion::error::DataFusionError::Internal( + "start column is not Int64".into(), + ) + })?; + let l_ends = left_concat + .column(l_end_idx) + .as_any() + .downcast_ref::() + .ok_or_else(|| { + datafusion::error::DataFusionError::Internal( + "end column is not Int64".into(), + ) + })?; + + // Group intervals by chromosome and build COI trees. + // Metadata = row index in left_concat. + let mut chrom_intervals: HashMap< + String, + Vec>, + > = HashMap::new(); + for i in 0..left_concat.num_rows() { + let chrom: &str = l_chrom_col.as_string_view().value(i); + let start = l_starts.value(i) as i32; + // Half-open [start, end) → end-inclusive [start, end-1] + let end = (l_ends.value(i) - 1) as i32; + chrom_intervals + .entry(chrom.to_string()) + .or_default() + .push(Interval::new(start, end, i as u32)); + } + + let trees: HashMap> = + chrom_intervals + .iter() + .map(|(chrom, intervals)| { + (chrom.clone(), COITree::new(intervals)) + }) + .collect(); + + // ── Probe phase: stream right side, query trees ───── + let right_batches = + datafusion::physical_plan::collect( + right_plan, context, + ) + .await?; + + if right_batches.is_empty() { + return Ok(RecordBatch::new_empty(schema)); + } + + let right_concat = compute::concat_batches( + &right_batches[0].schema(), + &right_batches, + )?; + + let right_schema = right_concat.schema(); + let r_chrom_idx = + right_schema.index_of(&chrom_r)?; + let r_start_idx = + right_schema.index_of(&start_r)?; 
+ let r_end_idx = + right_schema.index_of(&end_r)?; + + let r_chrom_col = right_concat.column(r_chrom_idx); + let r_starts = right_concat + .column(r_start_idx) + .as_any() + .downcast_ref::() + .ok_or_else(|| { + datafusion::error::DataFusionError::Internal( + "start column is not Int64".into(), + ) + })?; + let r_ends = right_concat + .column(r_end_idx) + .as_any() + .downcast_ref::() + .ok_or_else(|| { + datafusion::error::DataFusionError::Internal( + "end column is not Int64".into(), + ) + })?; + + // Collect join pairs as (left_idx, right_idx). + let mut left_indices: Vec = Vec::new(); + let mut right_indices: Vec = Vec::new(); + + for i in 0..right_concat.num_rows() { + let chrom: &str = r_chrom_col.as_string_view().value(i); + if let Some(tree) = trees.get(chrom) { + let start = r_starts.value(i) as i32; + let end = (r_ends.value(i) - 1) as i32; + tree.query(start, end, |hit| { + left_indices + .push(*hit.metadata); + right_indices.push(i as u32); + }); + } + } + + if left_indices.is_empty() { + return Ok(RecordBatch::new_empty(schema)); + } + + // Build output batch using take() on both sides. + let left_idx_arr = UInt32Array::from(left_indices); + let right_idx_arr = UInt32Array::from(right_indices); + + let mut output_columns: Vec> = + Vec::with_capacity(schema.fields().len()); + + // Left side columns + for col in left_concat.columns() { + output_columns.push(compute::take( + col.as_ref(), + &left_idx_arr, + None, + )?); + } + // Right side columns + for col in right_concat.columns() { + output_columns.push(compute::take( + col.as_ref(), + &right_idx_arr, + None, + )?); + } + + RecordBatch::try_new(schema, output_columns) + .map_err(datafusion::error::DataFusionError::from) + }); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema.clone(), + stream, + ))) + } +} + +// ── Extension planner ─────────────────────────────────────────── + +/// Converts [`COITreeJoinNode`] logical nodes into [`COITreeExec`] +/// physical plans. 
+#[derive(Debug)] +pub struct COITreePlanner; + +#[async_trait::async_trait] +impl datafusion::physical_planner::ExtensionPlanner + for COITreePlanner +{ + async fn plan_extension( + &self, + _planner: &dyn datafusion::physical_planner::PhysicalPlanner, + node: &dyn UserDefinedLogicalNode, + _logical_inputs: &[&LogicalPlan], + physical_inputs: &[Arc], + _session_state: &datafusion::execution::SessionState, + ) -> Result>> { + let Some(join_node) = + node.as_any().downcast_ref::() + else { + return Ok(None); + }; + + // Build the output Arrow schema from the logical schema. + let arrow_schema: SchemaRef = + Arc::new(join_node.schema.as_arrow().clone()); + + Ok(Some(Arc::new(COITreeExec::new( + physical_inputs[0].clone(), + physical_inputs[1].clone(), + &join_node + .on + .iter() + .map(|(l, r)| (l.clone(), r.clone())) + .collect::>(), + &join_node.start_a, + &join_node.end_a, + &join_node.start_b, + &join_node.end_b, + arrow_schema, + )))) + } +} diff --git a/crates/giql-datafusion/src/lib.rs b/crates/giql-datafusion/src/lib.rs new file mode 100644 index 0000000..233f2e8 --- /dev/null +++ b/crates/giql-datafusion/src/lib.rs @@ -0,0 +1,192 @@ +//! DataFusion optimizer for genomic interval (INTERSECTS) joins. +//! +//! This crate provides a logical [`OptimizerRule`] that detects +//! `giql_intersects()` function calls in join filters and rewrites +//! them into binned equi-joins using UNNEST. Bin size is chosen +//! adaptively from table statistics when available. +//! +//! The `giql_intersects` function is a placeholder UDF emitted by the +//! GIQL transpiler's `"datafusion"` dialect. It preserves INTERSECTS +//! semantics through the SQL layer so the optimizer can match on it +//! directly, without heuristic pattern detection. +//! +//! # Usage +//! +//! ```rust,no_run +//! use datafusion::execution::SessionStateBuilder; +//! use datafusion::prelude::*; +//! use giql_datafusion::register_optimizer; +//! +//! let state = SessionStateBuilder::new() +//! 
.with_default_features() +//! .build(); +//! let state = register_optimizer(state); +//! let ctx = SessionContext::from(state); +//! ``` + +pub mod coitree; +pub mod logical_rule; + +pub use logical_rule::{IntersectsConfig, IntersectsLogicalRule}; + +use std::fmt::Debug; +use std::sync::Arc; + +use datafusion::common::Result; +use datafusion::execution::SessionState; +use datafusion::logical_expr::{ + ColumnarValue, LogicalPlan, ScalarFunctionArgs, ScalarUDF, + ScalarUDFImpl, Signature, TypeSignature, Volatility, +}; +use datafusion::optimizer::OptimizerRule; +use datafusion::physical_plan::ExecutionPlan; + +// ── Placeholder UDF ───────────────────────────────────────────── + +/// Placeholder `giql_intersects(start_a, end_a, start_b, end_b)` UDF. +/// +/// Exists only so DataFusion's SQL parser accepts the function call. +/// The logical optimizer rule rewrites it away before execution. +#[derive(Debug, Hash, PartialEq, Eq)] +struct GiqlIntersectsUdf { + signature: Signature, +} + +impl GiqlIntersectsUdf { + fn new() -> Self { + Self { + signature: Signature::new( + TypeSignature::Any(4), + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for GiqlIntersectsUdf { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "giql_intersects" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type( + &self, + _arg_types: &[arrow::datatypes::DataType], + ) -> Result { + Ok(arrow::datatypes::DataType::Boolean) + } + + fn invoke_with_args( + &self, + _args: ScalarFunctionArgs, + ) -> Result { + Err(datafusion::error::DataFusionError::Internal( + "giql_intersects should be rewritten by the logical \ + optimizer rule — was the IntersectsLogicalRule registered?" + .into(), + )) + } +} + +/// Create the placeholder `giql_intersects` scalar UDF. 
+pub fn giql_intersects_udf() -> ScalarUDF { + ScalarUDF::from(GiqlIntersectsUdf::new()) +} + +// ── Registration ──────────────────────────────────────────────── + +// ── Custom query planner ───────────────────────────────────────── + +/// Query planner that includes the COI tree extension planner. +#[derive(Debug)] +struct GiqlQueryPlanner; + +#[async_trait::async_trait] +impl datafusion::execution::context::QueryPlanner + for GiqlQueryPlanner +{ + async fn create_physical_plan( + &self, + logical_plan: &LogicalPlan, + session_state: &SessionState, + ) -> Result> { + use datafusion::physical_planner::{ + DefaultPhysicalPlanner, PhysicalPlanner, + }; + + let planner = + DefaultPhysicalPlanner::with_extension_planners(vec![ + Arc::new(coitree::COITreePlanner), + ]); + planner + .create_physical_plan(logical_plan, session_state) + .await + } +} + +// ── Registration ──────────────────────────────────────────────── + +/// Build a [`SessionState`] with the INTERSECTS logical optimizer +/// rule, the `giql_intersects` placeholder UDF, and the COI tree +/// extension planner. 
+/// +/// The logical rule detects `giql_intersects()` calls in join +/// filters and rewrites them into either: +/// - A binned equi-join (uniform width distributions) +/// - A COI tree join (non-uniform distributions) +pub fn register_optimizer(state: SessionState) -> SessionState { + use datafusion::execution::SessionStateBuilder; + + let logical_rule: Arc = + Arc::new(IntersectsLogicalRule::new()); + + let mut logical_rules: Vec> = + state.optimizers().to_vec(); + logical_rules.push(logical_rule); + + let udf = Arc::new(giql_intersects_udf()); + + let mut scalar_fns: Vec> = + state.scalar_functions().values().cloned().collect(); + scalar_fns.push(udf); + + SessionStateBuilder::new_from_existing(state) + .with_optimizer_rules(logical_rules) + .with_scalar_functions(scalar_fns) + .with_query_planner(Arc::new(GiqlQueryPlanner)) + .build() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_register_optimizer_adds_rule_and_udf() { + use datafusion::execution::SessionStateBuilder; + + let state = SessionStateBuilder::new() + .with_default_features() + .build(); + let n_before = state.optimizers().len(); + + let state = register_optimizer(state); + + // Logical rule was added + assert_eq!(state.optimizers().len(), n_before + 1); + let last_rule = state.optimizers().last().unwrap(); + assert_eq!(last_rule.name(), "intersects_logical_binned"); + + // UDF was registered + assert!( + state.scalar_functions().contains_key("giql_intersects") + ); + } +} diff --git a/crates/giql-datafusion/src/logical_rule.rs b/crates/giql-datafusion/src/logical_rule.rs new file mode 100644 index 0000000..9ed8683 --- /dev/null +++ b/crates/giql-datafusion/src/logical_rule.rs @@ -0,0 +1,865 @@ +use std::sync::Arc; + +use datafusion::common::tree_node::Transformed; +use datafusion::common::{Column, Result, ScalarValue}; +use datafusion::datasource::listing::ListingTable; +use datafusion::datasource::source_as_provider; +use datafusion::logical_expr::expr::ScalarFunction; 
+use datafusion::logical_expr::{ + BinaryExpr, Expr, Join, JoinType, LogicalPlan, LogicalPlanBuilder, + Operator, +}; +use datafusion::optimizer::{OptimizerConfig, OptimizerRule}; +use datafusion::prelude::*; + +/// Logical optimizer rule that rewrites interval overlap joins into +/// binned equi-joins using UNNEST. +/// +/// Detects `giql_intersects(start_a, end_a, start_b, end_b)` function +/// calls in join filters (emitted by the GIQL transpiler's +/// `"datafusion"` dialect) and rewrites them to: +/// +/// `SELECT ... FROM Unnest(a + bins) JOIN Unnest(b + bins) +/// ON chrom = chrom AND bin = bin WHERE start < end AND end > start` +/// +/// DataFusion handles UNNEST, hash join, and dedup natively with +/// full parallelism. +/// Configuration for the INTERSECTS logical optimizer rule. +#[derive(Debug, Clone, Default)] +pub struct IntersectsConfig { + /// Force the binned equi-join strategy instead of the default + /// COI tree join. The bin size is chosen adaptively from + /// Parquet sampling or column-level statistics. This is an + /// escape hatch for benchmarking; COI tree is faster in all + /// tested distributions. 
+ pub force_binned: bool, +} + +#[derive(Debug)] +pub struct IntersectsLogicalRule { + config: IntersectsConfig, +} + +impl IntersectsLogicalRule { + pub fn new() -> Self { + Self { + config: IntersectsConfig::default(), + } + } + + pub fn with_config(config: IntersectsConfig) -> Self { + Self { config } + } +} + +impl Default for IntersectsLogicalRule { + fn default() -> Self { + Self::new() + } +} + +impl OptimizerRule for IntersectsLogicalRule { + fn name(&self) -> &str { + "intersects_logical_binned" + } + + fn apply_order(&self) -> Option { + Some(datafusion::optimizer::ApplyOrder::BottomUp) + } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + let LogicalPlan::Join(ref join) = plan else { + return Ok(Transformed::no(plan)); + }; + + if join.join_type != JoinType::Inner { + return Ok(Transformed::no(plan)); + } + + // Skip if already rewritten (has __giql_bins equi-keys) + let already_binned = join.on.iter().any(|(l, _)| { + if let Expr::Column(c) = l { + c.name.starts_with("__giql_bins") + } else { + false + } + }); + if already_binned { + return Ok(Transformed::no(plan)); + } + + // Detect giql_intersects() function call in the filter + let overlap = match &join.filter { + Some(filter) => detect_giql_intersects(filter), + None => None, + }; + + let Some((start_a, end_a, start_b, end_b)) = overlap else { + return Ok(Transformed::no(plan)); + }; + + if self.config.force_binned { + // Binned equi-join path (escape hatch for benchmarking). 
+ let left_stats = get_table_stats( + &join.left, &start_a.name, &end_a.name, + ); + let right_stats = get_table_stats( + &join.right, &start_b.name, &end_b.name, + ); + let bin_size = + choose_bin_size(&left_stats, &right_stats); + + log::debug!( + "INTERSECTS logical rule: rewriting to \ + binned join (force_binned), bin_size={bin_size}" + ); + + let rewritten_filter = + join.filter.as_ref().map(|f| { + replace_giql_intersects( + f, &start_a, &end_a, &start_b, &end_b, + ) + }); + + let rewritten = rewrite_to_binned( + join, + bin_size, + &start_a, + &end_a, + &start_b, + &end_b, + rewritten_filter.as_ref(), + )?; + return Ok(Transformed::yes(rewritten)); + } + + // Default: COI tree join — faster across all tested + // distributions, no bin replication overhead. + use crate::coitree::COITreeJoinNode; + use datafusion::logical_expr::Extension; + + log::debug!("INTERSECTS logical rule: using COI tree join"); + + let on: Vec<(Column, Column)> = join + .on + .iter() + .map(|(l, r)| { + ( + extract_column(l).unwrap_or_else(|| { + Column::new(None::<&str>, "chrom") + }), + extract_column(r).unwrap_or_else(|| { + Column::new(None::<&str>, "chrom") + }), + ) + }) + .collect(); + + let node = COITreeJoinNode { + left: Arc::new((*join.left).clone()), + right: Arc::new((*join.right).clone()), + on, + start_a, + end_a, + start_b, + end_b, + schema: join.schema.clone(), + }; + + Ok(Transformed::yes(LogicalPlan::Extension(Extension { + node: Arc::new(node), + }))) + } +} + +// ── Pattern detection ─────────────────────────────────────────── + +/// Detect `giql_intersects(start_a, end_a, start_b, end_b)` in a +/// filter expression. Searches through AND-combined predicates. +/// +/// Returns `(start_a, end_a, start_b, end_b)` column references. 
+fn detect_giql_intersects( + expr: &Expr, +) -> Option<(Column, Column, Column, Column)> { + match expr { + Expr::ScalarFunction(func) + if func.name() == "giql_intersects" + && func.args.len() == 4 => + { + let start_a = extract_column(&func.args[0])?; + let end_a = extract_column(&func.args[1])?; + let start_b = extract_column(&func.args[2])?; + let end_b = extract_column(&func.args[3])?; + Some((start_a, end_a, start_b, end_b)) + } + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::And, + right, + }) => detect_giql_intersects(left) + .or_else(|| detect_giql_intersects(right)), + _ => None, + } +} + +/// Extract a Column from an Expr, handling TryCast/Cast wrappers +/// that DataFusion may insert during type coercion. +fn extract_column(expr: &Expr) -> Option { + match expr { + Expr::Column(c) => Some(c.clone()), + Expr::Cast(cast) => extract_column(&cast.expr), + Expr::TryCast(tc) => extract_column(&tc.expr), + _ => None, + } +} + +/// Replace `giql_intersects(start_a, end_a, start_b, end_b)` in an +/// expression tree with `start_a < end_b AND end_a > start_b`. 
+fn replace_giql_intersects( + expr: &Expr, + start_a: &Column, + end_a: &Column, + start_b: &Column, + end_b: &Column, +) -> Expr { + match expr { + Expr::ScalarFunction(func) + if func.name() == "giql_intersects" => + { + // start_a < end_b AND end_a > start_b + Expr::BinaryExpr(BinaryExpr { + left: Box::new(Expr::BinaryExpr(BinaryExpr { + left: Box::new(Expr::Column(start_a.clone())), + op: Operator::Lt, + right: Box::new(Expr::Column(end_b.clone())), + })), + op: Operator::And, + right: Box::new(Expr::BinaryExpr(BinaryExpr { + left: Box::new(Expr::Column(end_a.clone())), + op: Operator::Gt, + right: Box::new(Expr::Column(start_b.clone())), + })), + }) + } + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::BinaryExpr(BinaryExpr { + left: Box::new(replace_giql_intersects( + left, start_a, end_a, start_b, end_b, + )), + op: *op, + right: Box::new(replace_giql_intersects( + right, start_a, end_a, start_b, end_b, + )), + }) + } + other => other.clone(), + } +} + +// ── Stats collection ──────────────────────────────────────────── + +struct LogicalStats { + /// Reserved for future use (e.g., skipping binning on tiny + /// tables where a nested-loop join would be cheaper). + #[allow(dead_code)] + row_count: Option, + start_min: Option, + start_max: Option, + end_min: Option, + end_max: Option, + /// Sampled width statistics from Parquet row groups. + sampled: Option, +} + +/// Width statistics computed from sampled Parquet row groups. +struct SampledWidthStats { + /// Smallest bin size with mean replication <= 2.0. + optimal_bin: i64, + /// Median interval width. 
+ median: i64, +} + +fn get_table_stats( + plan: &LogicalPlan, + start_col_name: &str, + end_col_name: &str, +) -> Option { + match plan { + LogicalPlan::TableScan(ts) => { + let provider = source_as_provider(&ts.source).ok()?; + + let mut stats = provider + .statistics() + .and_then(|s| { + stats_to_logical( + &s, + &ts.source.schema(), + start_col_name, + end_col_name, + ) + }) + .unwrap_or(LogicalStats { + row_count: None, + start_min: None, + start_max: None, + end_min: None, + end_max: None, + sampled: None, + }); + + // Try lightweight Parquet sampling for accurate width + stats.sampled = try_sample_from_listing( + provider.as_ref(), + start_col_name, + end_col_name, + ); + + Some(stats) + } + _ => plan.inputs().first().and_then(|child| { + get_table_stats(child, start_col_name, end_col_name) + }), + } +} + +fn stats_to_logical( + stats: &datafusion::common::Statistics, + schema: &arrow::datatypes::SchemaRef, + start_col_name: &str, + end_col_name: &str, +) -> Option { + let row_count = match stats.num_rows { + datafusion::common::stats::Precision::Exact(n) => Some(n), + datafusion::common::stats::Precision::Inexact(n) => Some(n), + _ => None, + }; + let start_idx = schema + .fields() + .iter() + .position(|f| f.name() == start_col_name)?; + let end_idx = schema + .fields() + .iter() + .position(|f| f.name() == end_col_name)?; + let col_stats = &stats.column_statistics; + let start_stats = col_stats.get(start_idx)?; + let end_stats = col_stats.get(end_idx)?; + Some(LogicalStats { + row_count, + start_min: scalar_to_i64(&start_stats.min_value), + start_max: scalar_to_i64(&start_stats.max_value), + end_min: scalar_to_i64(&end_stats.min_value), + end_max: scalar_to_i64(&end_stats.max_value), + sampled: None, // filled by try_sample_from_listing + }) +} + +fn scalar_to_i64( + precision: &datafusion::common::stats::Precision, +) -> Option { + match precision { + datafusion::common::stats::Precision::Exact(v) + | datafusion::common::stats::Precision::Inexact(v) => 
match v { + ScalarValue::Int32(Some(n)) => Some(*n as i64), + ScalarValue::Int64(Some(n)) => Some(*n), + _ => None, + }, + _ => None, + } +} + +// ── Lightweight Parquet sampling ───────────────────────────────── + +/// Try to sample interval widths from a Parquet-backed ListingTable. +/// +/// Returns `None` silently if the provider is not a ListingTable, +/// the file is not Parquet, or any I/O error occurs. +fn try_sample_from_listing( + provider: &dyn datafusion::catalog::TableProvider, + start_col: &str, + end_col: &str, +) -> Option { + let listing = provider.as_any().downcast_ref::()?; + let path_str = listing.table_paths().first()?.as_str(); + + // ListingTableUrl stores file:// URLs. For remote sources + // (s3://, gs://, etc.) the else branch produces a path that + // won't exist on disk — File::open fails and we fall back to + // column-level stats gracefully. + let fs_path = if let Some(p) = path_str.strip_prefix("file://") { + std::path::PathBuf::from(p) + } else { + std::path::PathBuf::from(format!("/{path_str}")) + }; + + match sample_width_stats(&fs_path, start_col, end_col) { + Some(stats) => { + log::debug!( + "INTERSECTS logical: sampled optimal_bin={}, \ + median={} from {path_str}", + stats.optimal_bin, + stats.median + ); + Some(stats) + } + None => { + log::debug!( + "INTERSECTS logical: Parquet sampling failed \ + for {path_str}" + ); + None + } + } +} + +/// Read start/end columns from 1–3 representative Parquet row groups +/// and choose the optimal bin size by minimizing replication cost. +/// +/// Binary searches for the smallest bin size B such that the mean +/// replication factor `mean(ceil(w_i / B))` is at most +/// `TARGET_MEAN_REPLICATION`. This naturally handles all width +/// distributions: the few wide outliers pull B upward in proportion +/// to their actual cost, without over-sizing bins for the majority. 
+fn sample_width_stats( + path: &std::path::Path, + start_col: &str, + end_col: &str, +) -> Option { + use arrow::array::{Array, Int64Array}; + use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; + use parquet::arrow::ProjectionMask; + + let file = std::fs::File::open(path).ok()?; + let builder = + ParquetRecordBatchReaderBuilder::try_new(file).ok()?; + + let parquet_schema = builder.parquet_schema().clone(); + let num_row_groups = builder.metadata().num_row_groups(); + if num_row_groups == 0 { + return None; + } + + // Find column indices in the Parquet schema + let start_idx = parquet_schema + .columns() + .iter() + .position(|c| c.name() == start_col)?; + let end_idx = parquet_schema + .columns() + .iter() + .position(|c| c.name() == end_col)?; + + // Select representative row groups: first, middle, last + let mut rg_indices = vec![0]; + if num_row_groups > 2 { + rg_indices.push(num_row_groups / 2); + } + if num_row_groups > 1 { + rg_indices.push(num_row_groups - 1); + } + + let mask = ProjectionMask::leaves( + &parquet_schema, + [start_idx, end_idx], + ); + + // Cap batch size to bound memory for very large row groups. + let reader = builder + .with_projection(mask) + .with_row_groups(rg_indices) + .with_batch_size(100_000) + .build() + .ok()?; + + let mut widths: Vec = Vec::new(); + const MAX_SAMPLES: usize = 300_000; // ~3 row groups × 100K + + for batch in reader { + let batch = batch.ok()?; + // ProjectionMask preserves original column order, so + // column 0 is start and column 1 is end (assuming + // start_idx < end_idx, which holds for all standard + // genomic schemas). 
+ let starts = batch + .column(0) + .as_any() + .downcast_ref::()?; + let ends = batch + .column(1) + .as_any() + .downcast_ref::()?; + + for i in 0..batch.num_rows() { + if !starts.is_null(i) && !ends.is_null(i) { + let w = ends.value(i) - starts.value(i); + // Skip malformed intervals where end < start + if w > 0 { + widths.push(w); + } + } + } + + if widths.len() >= MAX_SAMPLES { + break; + } + } + + if widths.is_empty() { + return None; + } + + widths.sort_unstable(); + let median = widths[widths.len() / 2]; + let optimal_bin = find_optimal_bin_size(&widths); + Some(SampledWidthStats { + optimal_bin, + median, + }) +} + +/// Find the smallest bin size B such that the mean replication +/// factor across all sampled widths is at most TARGET. +/// +/// Binary searches over [1, max_width]. For each candidate B, +/// `mean_replication(B) = sum(ceil(w_i / B)) / N`. Since widths +/// are sorted, all w_i <= B contribute ceil=1 — a binary search +/// finds the cutoff index, making each evaluation O(N_above_B) +/// rather than O(N). +fn find_optimal_bin_size(sorted_widths: &[i64]) -> i64 { + /// Target mean replication factor. Each interval is copied + /// into ceil(width / bin_size) bins on average. 2.0 means + /// the average interval spans at most 2 bins — good + /// selectivity with bounded replication. + const TARGET_MEAN_REPL: f64 = 2.0; + + let max_width = *sorted_widths.last().unwrap(); + + // Binary search: lo always has mean_repl > target, + // hi always has mean_repl <= target. + let mut lo: i64 = 1; + let mut hi: i64 = max_width; + + // At hi = max_width, every interval fits in 1-2 bins, so + // mean_repl <= 2.0. At lo = 1, mean_repl = mean(widths). + // If even max_width doesn't meet the target (shouldn't happen + // since ceil(w/w) <= 2 for all w), return max_width. 
+ if mean_replication(sorted_widths, hi) > TARGET_MEAN_REPL { + return hi; + } + + while lo < hi { + let mid = lo + (hi - lo) / 2; + if mid == lo { + break; + } + if mean_replication(sorted_widths, mid) <= TARGET_MEAN_REPL { + hi = mid; + } else { + lo = mid; + } + } + + hi +} + +/// Compute mean(ceil(w_i / B)) for a sorted widths array. +/// +/// All widths <= B have ceil = 1. Binary search finds the cutoff, +/// then only iterate the tail above B. +fn mean_replication(sorted_widths: &[i64], bin_size: i64) -> f64 { + let n = sorted_widths.len(); + // Find first index where width > bin_size + let cutoff = + sorted_widths.partition_point(|&w| w <= bin_size); + // All [0..cutoff) contribute 1 each + let mut total: i64 = cutoff as i64; + // [cutoff..n) contribute ceil(w / bin_size) each + for &w in &sorted_widths[cutoff..] { + total += (w + bin_size - 1) / bin_size; // ceil division + } + total as f64 / n as f64 +} + +// ── Strategy decision ─────────────────────────────────────────── + +/// Default bin size when stats are unavailable. +const DEFAULT_BIN_SIZE: usize = 10_000; + +/// Choose a bin size for the forced-binning path. +/// +/// Uses sampled p95 width from Parquet when available, falling +/// back to a column-level min/max heuristic. +fn choose_bin_size( + left: &Option, + right: &Option, +) -> usize { + // Tier 1: cost-optimal bin size from Parquet sampling. + let sampled: Option<&SampledWidthStats> = [left, right] + .iter() + .filter_map(|s| s.as_ref()?.sampled.as_ref()) + .next(); + + if let Some(stats) = sampled { + let bin_size = (stats.optimal_bin.max(1) as usize) + .clamp(1_000, 1_000_000); + log::debug!( + "INTERSECTS logical: bin_size={bin_size} \ + (from sampled optimal_bin={})", + stats.optimal_bin + ); + return bin_size; + } + + // Tier 2: column-level min/max heuristic. 
+ let width_from_stats = |s: &LogicalStats| -> Option { + let min_start = s.start_min?; + let max_start = s.start_max?; + let min_end = s.end_min?; + let max_end = s.end_max?; + let w1 = min_end - min_start; + let w2 = max_end - max_start; + Some(w1.max(w2).max(1)) + }; + + let l_width = left.as_ref().and_then(width_from_stats); + let r_width = right.as_ref().and_then(width_from_stats); + + let bin_size = match (l_width, r_width) { + (Some(l), Some(r)) => { + (l.max(r).max(1) as usize).clamp(1_000, 1_000_000) + } + (Some(w), None) | (None, Some(w)) => { + (w.max(1) as usize).clamp(1_000, 1_000_000) + } + (None, None) => DEFAULT_BIN_SIZE, + }; + + log::debug!( + "INTERSECTS logical: bin_size={bin_size} \ + (column-level heuristic)" + ); + bin_size +} + +// ── Plan rewrite ──────────────────────────────────────────────── + +/// Extract the table qualifier from a plan's schema. +/// +/// Uses the qualifier of the first column in the plan's output +/// schema, which reflects SQL aliases (e.g., `intervals2` in +/// `FROM intervals JOIN intervals AS intervals2`). This is more +/// robust than walking to the TableScan, which would return the +/// physical table name and miss SQL aliases. +fn get_plan_qualifier(plan: &LogicalPlan) -> Option { + plan.schema() + .columns() + .first() + .and_then(|c| c.relation.as_ref()) + .map(|r| r.table().to_string()) +} + +fn rewrite_to_binned( + join: &Join, + bin_size: usize, + start_a: &Column, + end_a: &Column, + start_b: &Column, + end_b: &Column, + rewritten_filter: Option<&Expr>, +) -> Result { + let bs = bin_size as i64; + + // Get table qualifiers for aliasing after UNNEST. Uses the + // schema qualifier (which reflects SQL aliases) so column + // references in the filter resolve correctly after the rewrite. 
+ let left_alias = get_plan_qualifier(&join.left) + .unwrap_or_else(|| "l".to_string()); + let right_alias = get_plan_qualifier(&join.right) + .unwrap_or_else(|| "r".to_string()); + + let left_expanded = expand_with_bins( + (*join.left).clone(), + "__giql_bins_l", + bs, + &left_alias, + start_a, + end_a, + )?; + let right_expanded = expand_with_bins( + (*join.right).clone(), + "__giql_bins_r", + bs, + &right_alias, + start_b, + end_b, + )?; + + // Equi-keys: original keys re-qualified with the aliases + + // bin columns + let mut left_keys: Vec = join + .on + .iter() + .map(|(l, _)| { + if let Expr::Column(c) = l { + Expr::Column(Column::new( + Some(left_alias.clone()), + &c.name, + )) + } else { + l.clone() + } + }) + .collect(); + let mut right_keys: Vec = join + .on + .iter() + .map(|(_, r)| { + if let Expr::Column(c) = r { + Expr::Column(Column::new( + Some(right_alias.clone()), + &c.name, + )) + } else { + r.clone() + } + }) + .collect(); + left_keys + .push(col(format!("{left_alias}.__giql_bins_l"))); + right_keys + .push(col(format!("{right_alias}.__giql_bins_r"))); + + // Build the join with the rewritten filter (giql_intersects + // replaced with real overlap predicates) and extra bin + // equi-keys. + let joined = LogicalPlanBuilder::from(left_expanded) + .join_with_expr_keys( + right_expanded, + JoinType::Inner, + (left_keys, right_keys), + rewritten_filter.cloned(), + )? + .build()?; + + // Add canonical-bin filter to eliminate duplicates from + // multi-bin matches. For each pair, only emit from the bin + // that equals the GREATER of the two intervals' first bins. + // This makes DISTINCT unnecessary. 
+ let left_first_bin = cast( + Expr::Column(Column::new( + Some(left_alias.clone()), + &start_a.name, + )), + arrow::datatypes::DataType::Int64, + ) / lit(bs); + let right_first_bin = cast( + Expr::Column(Column::new( + Some(right_alias.clone()), + &start_b.name, + )), + arrow::datatypes::DataType::Int64, + ) / lit(bs); + + // canonical_bin = CASE WHEN l_fb >= r_fb THEN l_fb ELSE r_fb + let canonical_bin = + Expr::Case(datafusion::logical_expr::expr::Case { + expr: None, + when_then_expr: vec![( + Box::new( + left_first_bin.clone().gt_eq(right_first_bin.clone()), + ), + Box::new(left_first_bin), + )], + else_expr: Some(Box::new(right_first_bin)), + }); + + let bins_col = + col(format!("{left_alias}.__giql_bins_l")); + let dedup_filter = bins_col.eq(canonical_bin); + + let filtered = LogicalPlanBuilder::from(joined) + .filter(dedup_filter)? + .build()?; + + // Project to strip bin columns. The canonical-bin filter above + // already ensures each pair appears exactly once, so DISTINCT + // is unnecessary. + let output_exprs: Vec = filtered + .schema() + .columns() + .into_iter() + .filter(|c| { + c.name != "__giql_bins_l" && c.name != "__giql_bins_r" + }) + .map(|c| Expr::Column(c)) + .collect(); + + LogicalPlanBuilder::from(filtered) + .project(output_exprs)? + .build() +} + +/// Add a `range(start/B, (end-1)/B + 1)` column, unnest it, and +/// wrap in a SubqueryAlias to preserve the table qualifier for +/// the join filter. 
+fn expand_with_bins( + input: LogicalPlan, + bin_col_name: &str, + bin_size: i64, + table_alias: &str, + start_col: &Column, + end_col: &Column, +) -> Result { + let schema = input.schema().clone(); + + // Cast start/end to Int64 first, then compute bin boundaries: + // range(CAST(start AS BIGINT) / B, (CAST(end AS BIGINT) - 1) + // / B + 1) + let start_i64 = cast( + Expr::Column(start_col.clone()), + arrow::datatypes::DataType::Int64, + ); + let end_i64 = cast( + Expr::Column(end_col.clone()), + arrow::datatypes::DataType::Int64, + ); + let start_bin = start_i64 / lit(bin_size); + let end_bin = + (end_i64 - lit(1i64)) / lit(bin_size) + lit(1i64); + + // Build: SELECT *, range(start_bin, end_bin) AS __giql_bins + // FROM input. Then UNNEST(__giql_bins) + let range_expr = + Expr::ScalarFunction(ScalarFunction::new_udf( + datafusion::functions_nested::range::range_udf(), + vec![start_bin, end_bin], + )) + .alias(bin_col_name); + + let mut proj_exprs: Vec = schema + .columns() + .into_iter() + .map(|c| Expr::Column(c)) + .collect(); + proj_exprs.push(range_expr); + + let with_bins = LogicalPlanBuilder::from(input) + .project(proj_exprs)? + .build()?; + + // Unnest the bin list column, then re-apply the table alias + // so that qualified column references (e.g., a.start) in the + // join filter resolve correctly against this side. + LogicalPlanBuilder::from(with_bins) + .unnest_column(bin_col_name)? + .alias(table_alias)? + .build() +} diff --git a/crates/giql-datafusion/tests/logical_rule_test.rs b/crates/giql-datafusion/tests/logical_rule_test.rs new file mode 100644 index 0000000..c1cb96a --- /dev/null +++ b/crates/giql-datafusion/tests/logical_rule_test.rs @@ -0,0 +1,1370 @@ +//! Tests for the logical optimizer rule (`logical_rule.rs`). +//! +//! Covers: +//! - OptimizerRule trait implementation (name, apply_order, supports_rewrite) +//! - Join type filtering (inner only, skip left/right/full outer) +//! - Already-binned join detection (skip re-rewrite) +//! 
- giql_intersects() function detection +//! - Adaptive bin sizing from table statistics +//! - Canonical-bin dedup filter correctness +//! - Full pipeline integration through DataFusion with the logical rule + +use std::path::Path; +use std::sync::Arc; + +use arrow::array::{Int64Array, StringArray}; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::record_batch::RecordBatch; +use datafusion::execution::SessionStateBuilder; +use datafusion::logical_expr::LogicalPlan; +use datafusion::optimizer::OptimizerRule; +use datafusion::prelude::*; +use parquet::arrow::ArrowWriter; +use tempfile::TempDir; + +use giql_datafusion::logical_rule::IntersectsLogicalRule; +use giql_datafusion::register_optimizer; + +// ── Helpers ───────────────────────────────────────────────────── + +fn make_rule() -> IntersectsLogicalRule { + IntersectsLogicalRule::new() +} + +fn write_intervals_parquet( + dir: &Path, + filename: &str, + chroms: &[&str], + starts: &[i64], + ends: &[i64], +) -> std::path::PathBuf { + let schema = Arc::new(Schema::new(vec![ + Field::new("chrom", DataType::Utf8, false), + Field::new("start", DataType::Int64, false), + Field::new("end", DataType::Int64, false), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(chroms.to_vec())), + Arc::new(Int64Array::from(starts.to_vec())), + Arc::new(Int64Array::from(ends.to_vec())), + ], + ) + .unwrap(); + + let path = dir.join(filename); + let file = std::fs::File::create(&path).unwrap(); + let mut writer = + ArrowWriter::try_new(file, schema, None).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + path +} + +fn write_intervals_parquet_custom_schema( + dir: &Path, + filename: &str, + schema: Arc, + chroms: &[&str], + starts: &[i64], + ends: &[i64], +) -> std::path::PathBuf { + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(chroms.to_vec())), + Arc::new(Int64Array::from(starts.to_vec())), + 
Arc::new(Int64Array::from(ends.to_vec())), + ], + ) + .unwrap(); + + let path = dir.join(filename); + let file = std::fs::File::create(&path).unwrap(); + let mut writer = + ArrowWriter::try_new(file, schema, None).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + path +} + +/// Create a SessionContext with the logical rule and UDF registered. +fn make_ctx() -> SessionContext { + let state = SessionStateBuilder::new() + .with_default_features() + .build(); + let state = register_optimizer(state); + SessionContext::from(state) +} + +const INTERSECTS_SQL: &str = "\ + SELECT a.chrom, a.start, a.\"end\", \ + b.chrom AS chrom_b, b.start AS start_b, b.\"end\" AS end_b \ + FROM a JOIN b \ + ON a.chrom = b.chrom \ + AND giql_intersects(a.start, a.\"end\", b.start, b.\"end\")"; + +// ── OptimizerRule trait tests ─────────────────────────────────── + +#[test] +fn test_rule_name() { + let rule = make_rule(); + assert_eq!(rule.name(), "intersects_logical_binned"); +} + +#[test] +fn test_rule_apply_order_is_bottom_up() { + let rule = make_rule(); + let order = rule.apply_order(); + assert!(order.is_some()); + assert!(matches!( + order.unwrap(), + datafusion::optimizer::ApplyOrder::BottomUp + )); +} + +#[test] +fn test_rule_supports_rewrite() { + let rule = make_rule(); + #[allow(deprecated)] + let supports = rule.supports_rewrite(); + assert!(supports); +} + +// ── Rewrite skipping tests ────────────────────────────────────── + +#[test] +fn test_rewrite_skips_non_join_plan() { + // Given a non-join logical plan (EmptyRelation), + // When the rule is applied, + // Then the plan is returned unchanged. 
+ let rule = make_rule(); + let config = datafusion::optimizer::OptimizerContext::new(); + + let plan = LogicalPlan::EmptyRelation( + datafusion::logical_expr::EmptyRelation { + produce_one_row: false, + schema: Arc::new( + datafusion::common::DFSchema::empty(), + ), + }, + ); + + let result = rule.rewrite(plan.clone(), &config).unwrap(); + assert!(!result.transformed); +} + +#[tokio::test] +async fn test_rewrite_skips_left_join() { + // Given a LEFT JOIN with overlap predicates, + // When the logical rule is applied, + // Then the plan is not rewritten (only INNER joins are supported). + let ctx = make_ctx(); + let schema = Arc::new(Schema::new(vec![ + Field::new("chrom", DataType::Utf8, false), + Field::new("start", DataType::Int64, false), + Field::new("end", DataType::Int64, false), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(vec!["chr1"])), + Arc::new(Int64Array::from(vec![100])), + Arc::new(Int64Array::from(vec![200])), + ], + ) + .unwrap(); + + let table = datafusion::datasource::MemTable::try_new( + schema.clone(), + vec![vec![batch.clone()]], + ) + .unwrap(); + let table2 = datafusion::datasource::MemTable::try_new( + schema, + vec![vec![batch]], + ) + .unwrap(); + ctx.register_table("a", Arc::new(table)).unwrap(); + ctx.register_table("b", Arc::new(table2)).unwrap(); + + let left_join_sql = "\ + SELECT a.chrom, a.start, a.\"end\", \ + b.chrom, b.start, b.\"end\" \ + FROM a LEFT JOIN b \ + ON a.chrom = b.chrom \ + AND giql_intersects(a.start, a.\"end\", b.start, b.\"end\")"; + + let df = ctx.sql(left_join_sql).await.unwrap(); + let plan = df.logical_plan().clone(); + + let rule = make_rule(); + let config = datafusion::optimizer::OptimizerContext::new(); + + let result = rule.rewrite(plan, &config).unwrap(); + // DataFusion may restructure non-INNER joins before our rule + // sees them (e.g., converting to a Filter + CrossJoin), so the + // rule may not receive a Join node at all. 
We verify the rule + // does not panic; the important guarantee is that non-INNER + // joins are never rewritten to binned joins. + assert!(!result.transformed); +} + +#[tokio::test] +async fn test_rewrite_skips_right_join() { + // Given a RIGHT JOIN with overlap predicates, + // When the logical rule is applied, + // Then the plan is not rewritten (only INNER joins are supported). + let ctx = make_ctx(); + let schema = Arc::new(Schema::new(vec![ + Field::new("chrom", DataType::Utf8, false), + Field::new("start", DataType::Int64, false), + Field::new("end", DataType::Int64, false), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(vec!["chr1"])), + Arc::new(Int64Array::from(vec![100])), + Arc::new(Int64Array::from(vec![200])), + ], + ) + .unwrap(); + + let table = datafusion::datasource::MemTable::try_new( + schema.clone(), + vec![vec![batch.clone()]], + ) + .unwrap(); + let table2 = datafusion::datasource::MemTable::try_new( + schema, + vec![vec![batch]], + ) + .unwrap(); + ctx.register_table("a", Arc::new(table)).unwrap(); + ctx.register_table("b", Arc::new(table2)).unwrap(); + + let right_join_sql = "\ + SELECT a.chrom, a.start, a.\"end\", \ + b.chrom, b.start, b.\"end\" \ + FROM a RIGHT JOIN b \ + ON a.chrom = b.chrom \ + AND giql_intersects(a.start, a.\"end\", b.start, b.\"end\")"; + + let df = ctx.sql(right_join_sql).await.unwrap(); + let plan = df.logical_plan().clone(); + + let rule = make_rule(); + let config = datafusion::optimizer::OptimizerContext::new(); + + let result = rule.rewrite(plan, &config).unwrap(); + assert!(!result.transformed); +} + +#[tokio::test] +async fn test_rewrite_skips_full_outer_join() { + // Given a FULL OUTER JOIN with overlap predicates, + // When the logical rule is applied, + // Then the plan is not rewritten. 
+ let ctx = make_ctx(); + let schema = Arc::new(Schema::new(vec![ + Field::new("chrom", DataType::Utf8, false), + Field::new("start", DataType::Int64, false), + Field::new("end", DataType::Int64, false), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(vec!["chr1"])), + Arc::new(Int64Array::from(vec![100])), + Arc::new(Int64Array::from(vec![200])), + ], + ) + .unwrap(); + + let table = datafusion::datasource::MemTable::try_new( + schema.clone(), + vec![vec![batch.clone()]], + ) + .unwrap(); + let table2 = datafusion::datasource::MemTable::try_new( + schema, + vec![vec![batch]], + ) + .unwrap(); + ctx.register_table("a", Arc::new(table)).unwrap(); + ctx.register_table("b", Arc::new(table2)).unwrap(); + + let full_join_sql = "\ + SELECT a.chrom, a.start, a.\"end\", \ + b.chrom, b.start, b.\"end\" \ + FROM a FULL OUTER JOIN b \ + ON a.chrom = b.chrom \ + AND giql_intersects(a.start, a.\"end\", b.start, b.\"end\")"; + + let df = ctx.sql(full_join_sql).await.unwrap(); + let plan = df.logical_plan().clone(); + + let rule = make_rule(); + let config = datafusion::optimizer::OptimizerContext::new(); + + let result = rule.rewrite(plan, &config).unwrap(); + assert!(!result.transformed); +} + +// ── Raw overlap predicates are NOT rewritten ──────────────────── + +#[tokio::test] +async fn test_rewrite_skips_raw_overlap_predicates() { + // Given a standard inner join with raw overlap predicates + // (no giql_intersects function call), + // When the logical rule is applied, + // Then the plan is not rewritten. 
+ let ctx = make_ctx(); + let schema = Arc::new(Schema::new(vec![ + Field::new("chrom", DataType::Utf8, false), + Field::new("start", DataType::Int64, false), + Field::new("end", DataType::Int64, false), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(vec!["chr1"])), + Arc::new(Int64Array::from(vec![100])), + Arc::new(Int64Array::from(vec![200])), + ], + ) + .unwrap(); + + let table = datafusion::datasource::MemTable::try_new( + schema.clone(), + vec![vec![batch.clone()]], + ) + .unwrap(); + let table2 = datafusion::datasource::MemTable::try_new( + schema, + vec![vec![batch]], + ) + .unwrap(); + ctx.register_table("a", Arc::new(table)).unwrap(); + ctx.register_table("b", Arc::new(table2)).unwrap(); + + let raw_sql = "\ + SELECT a.chrom, a.start, a.\"end\", \ + b.chrom, b.start, b.\"end\" \ + FROM a JOIN b \ + ON a.chrom = b.chrom \ + AND a.start < b.\"end\" \ + AND a.\"end\" > b.start"; + + let df = ctx.sql(raw_sql).await.unwrap(); + let plan = df.logical_plan().clone(); + + let rule = make_rule(); + let config = datafusion::optimizer::OptimizerContext::new(); + + let result = rule.rewrite(plan, &config).unwrap(); + // The plan should NOT be rewritten since there's no + // giql_intersects() function call. + assert!( + !result.transformed, + "Raw overlap predicates should not trigger the rule" + ); +} + +// ── Correctness integration tests ─────────────────────────────── + +#[tokio::test] +async fn test_logical_rule_produces_correct_results_simple() { + // Given two tables with known overlapping intervals and the + // logical rule enabled, + // When an INTERSECTS join is executed, + // Then the correct number of overlap pairs is returned. 
+ let dir = TempDir::new().unwrap(); + let path_a = write_intervals_parquet( + dir.path(), + "a.parquet", + &["chr1", "chr1", "chr1"], + &[100, 300, 600], + &[250, 500, 800], + ); + let path_b = write_intervals_parquet( + dir.path(), + "b.parquet", + &["chr1", "chr1"], + &[200, 700], + &[400, 900], + ); + + let ctx = make_ctx(); + ctx.register_parquet("a", path_a.to_str().unwrap(), Default::default()) + .await + .unwrap(); + ctx.register_parquet("b", path_b.to_str().unwrap(), Default::default()) + .await + .unwrap(); + + let result = ctx.sql(INTERSECTS_SQL).await.unwrap(); + let batches = result.collect().await.unwrap(); + let total_rows: usize = + batches.iter().map(|b| b.num_rows()).sum(); + + // Expected overlaps: + // a[100,250) x b[200,400) -> yes + // a[300,500) x b[200,400) -> yes + // a[600,800) x b[700,900) -> yes + assert_eq!(total_rows, 3); +} + +#[tokio::test] +async fn test_logical_rule_no_false_positives_adjacent() { + // Given adjacent intervals [100,200) and [200,300) with the + // logical rule enabled, + // When an INTERSECTS join is executed, + // Then no overlap pairs are returned (half-open semantics). 
+ let dir = TempDir::new().unwrap(); + let path_a = write_intervals_parquet( + dir.path(), + "a.parquet", + &["chr1"], + &[100], + &[200], + ); + let path_b = write_intervals_parquet( + dir.path(), + "b.parquet", + &["chr1"], + &[200], + &[300], + ); + + let ctx = make_ctx(); + ctx.register_parquet("a", path_a.to_str().unwrap(), Default::default()) + .await + .unwrap(); + ctx.register_parquet("b", path_b.to_str().unwrap(), Default::default()) + .await + .unwrap(); + + let result = ctx.sql(INTERSECTS_SQL).await.unwrap(); + let batches = result.collect().await.unwrap(); + let total_rows: usize = + batches.iter().map(|b| b.num_rows()).sum(); + + assert_eq!(total_rows, 0); +} + +#[tokio::test] +async fn test_logical_rule_containment() { + // Given an interval [100,500) that fully contains [200,300) + // with the logical rule enabled, + // When an INTERSECTS join is executed, + // Then exactly one overlap pair is returned. + let dir = TempDir::new().unwrap(); + let path_a = write_intervals_parquet( + dir.path(), + "a.parquet", + &["chr1"], + &[100], + &[500], + ); + let path_b = write_intervals_parquet( + dir.path(), + "b.parquet", + &["chr1"], + &[200], + &[300], + ); + + let ctx = make_ctx(); + ctx.register_parquet("a", path_a.to_str().unwrap(), Default::default()) + .await + .unwrap(); + ctx.register_parquet("b", path_b.to_str().unwrap(), Default::default()) + .await + .unwrap(); + + let result = ctx.sql(INTERSECTS_SQL).await.unwrap(); + let batches = result.collect().await.unwrap(); + let total_rows: usize = + batches.iter().map(|b| b.num_rows()).sum(); + + assert_eq!(total_rows, 1); +} + +#[tokio::test] +async fn test_logical_rule_different_chroms_no_overlap() { + // Given intervals on different chromosomes with the logical + // rule enabled, + // When an INTERSECTS join is executed, + // Then no pairs are returned even though the coordinates overlap. 
+ let dir = TempDir::new().unwrap(); + let path_a = write_intervals_parquet( + dir.path(), + "a.parquet", + &["chr1", "chr1"], + &[100, 300], + &[500, 600], + ); + let path_b = write_intervals_parquet( + dir.path(), + "b.parquet", + &["chr2", "chr2"], + &[100, 300], + &[500, 600], + ); + + let ctx = make_ctx(); + ctx.register_parquet("a", path_a.to_str().unwrap(), Default::default()) + .await + .unwrap(); + ctx.register_parquet("b", path_b.to_str().unwrap(), Default::default()) + .await + .unwrap(); + + let result = ctx.sql(INTERSECTS_SQL).await.unwrap(); + let batches = result.collect().await.unwrap(); + let total_rows: usize = + batches.iter().map(|b| b.num_rows()).sum(); + + assert_eq!(total_rows, 0); +} + +// ── Canonical-bin dedup correctness ───────────────────────────── + +#[tokio::test] +async fn test_no_duplicate_pairs_wide_intervals() { + // Given wide intervals that span multiple bins, + // When the logical rule rewrites to a binned join, + // Then each overlapping pair appears exactly once. 
+ let dir = TempDir::new().unwrap(); + + let path_a = write_intervals_parquet( + dir.path(), + "a.parquet", + &["chr1", "chr1"], + &[0, 50000], + &[40000, 90000], + ); + let path_b = write_intervals_parquet( + dir.path(), + "b.parquet", + &["chr1", "chr1"], + &[10000, 60000], + &[30000, 80000], + ); + + let ctx = make_ctx(); + ctx.register_parquet("a", path_a.to_str().unwrap(), Default::default()) + .await + .unwrap(); + ctx.register_parquet("b", path_b.to_str().unwrap(), Default::default()) + .await + .unwrap(); + + let result = ctx.sql(INTERSECTS_SQL).await.unwrap(); + let batches = result.collect().await.unwrap(); + let total_rows: usize = + batches.iter().map(|b| b.num_rows()).sum(); + + // a[0,40000) overlaps b[10000,30000) -> yes + // a[0,40000) overlaps b[60000,80000) -> no + // a[50000,90000) overlaps b[10000,30000) -> no + // a[50000,90000) overlaps b[60000,80000) -> yes + assert_eq!(total_rows, 2); +} + +#[tokio::test] +async fn test_no_duplicate_pairs_many_bins() { + // Given an interval that spans many bins and overlaps with + // multiple other intervals, + // When the logical rule rewrites to a binned join, + // Then each pair appears exactly once. 
+ let dir = TempDir::new().unwrap(); + + let path_a = write_intervals_parquet( + dir.path(), + "a.parquet", + &["chr1"], + &[0], + &[100000], + ); + let path_b = write_intervals_parquet( + dir.path(), + "b.parquet", + &["chr1", "chr1", "chr1"], + &[5000, 50000, 200000], + &[15000, 70000, 300000], + ); + + let ctx = make_ctx(); + ctx.register_parquet("a", path_a.to_str().unwrap(), Default::default()) + .await + .unwrap(); + ctx.register_parquet("b", path_b.to_str().unwrap(), Default::default()) + .await + .unwrap(); + + let result = ctx.sql(INTERSECTS_SQL).await.unwrap(); + let batches = result.collect().await.unwrap(); + let total_rows: usize = + batches.iter().map(|b| b.num_rows()).sum(); + + // a[0,100000) overlaps b[5000,15000) -> yes + // a[0,100000) overlaps b[50000,70000) -> yes + // a[0,100000) overlaps b[200000,300000) -> no + assert_eq!(total_rows, 2); +} + +// ── Adaptive bin sizing ───────────────────────────────────────── + +#[tokio::test] +async fn test_narrow_intervals_produce_small_bin_size() { + // Given tables with narrow intervals (width ~100bp), + // When the logical rule processes them, + // Then the result should still be correct. 
+ let dir = TempDir::new().unwrap(); + + let chroms: Vec<&str> = vec!["chr1"; 100]; + let starts: Vec = (0..100).map(|i| i * 200).collect(); + let ends: Vec = starts.iter().map(|s| s + 100).collect(); + + let path_a = write_intervals_parquet( + dir.path(), + "a.parquet", + &chroms, + &starts, + &ends, + ); + let starts_b: Vec = + (0..100).map(|i| i * 200 + 50).collect(); + let ends_b: Vec = + starts_b.iter().map(|s| s + 100).collect(); + let path_b = write_intervals_parquet( + dir.path(), + "b.parquet", + &chroms, + &starts_b, + &ends_b, + ); + + let ctx = make_ctx(); + ctx.register_parquet("a", path_a.to_str().unwrap(), Default::default()) + .await + .unwrap(); + ctx.register_parquet("b", path_b.to_str().unwrap(), Default::default()) + .await + .unwrap(); + + let result = ctx.sql(INTERSECTS_SQL).await.unwrap(); + let batches = result.collect().await.unwrap(); + let total_rows: usize = + batches.iter().map(|b| b.num_rows()).sum(); + + assert_eq!(total_rows, 100); +} + +#[tokio::test] +async fn test_wide_intervals_produce_large_bin_size() { + // Given tables with wide intervals (width ~50000bp), + // When the logical rule processes them, + // Then the result should still be correct. 
+    let dir = TempDir::new().unwrap();
+
+    let path_a = write_intervals_parquet(
+        dir.path(),
+        "a.parquet",
+        &["chr1", "chr1"],
+        &[0, 100000],
+        &[50000, 150000],
+    );
+    let path_b = write_intervals_parquet(
+        dir.path(),
+        "b.parquet",
+        &["chr1", "chr1"],
+        &[25000, 125000],
+        &[75000, 175000],
+    );
+
+    let ctx = make_ctx();
+    ctx.register_parquet("a", path_a.to_str().unwrap(), Default::default())
+        .await
+        .unwrap();
+    ctx.register_parquet("b", path_b.to_str().unwrap(), Default::default())
+        .await
+        .unwrap();
+
+    let result = ctx.sql(INTERSECTS_SQL).await.unwrap();
+    let batches = result.collect().await.unwrap();
+    let total_rows: usize =
+        batches.iter().map(|b| b.num_rows()).sum();
+
+    // a[0,50000) overlaps b[25000,75000) -> yes
+    // a[0,50000) overlaps b[125000,175000) -> no
+    // a[100000,150000) overlaps b[25000,75000) -> no
+    // a[100000,150000) overlaps b[125000,175000) -> yes
+    assert_eq!(total_rows, 2);
+}
+
+// ── Multi-chromosome tests ──────────────────────────────────────
+
+#[tokio::test]
+async fn test_multi_chromosome_intersects() {
+    // Given intervals on multiple chromosomes,
+    // When the logical rule processes the join,
+    // Then overlaps are correctly identified per-chromosome only.
+    let dir = TempDir::new().unwrap();
+
+    let path_a = write_intervals_parquet(
+        dir.path(),
+        "a.parquet",
+        &["chr1", "chr2", "chr3"],
+        &[100, 100, 100],
+        &[500, 500, 500],
+    );
+    let path_b = write_intervals_parquet(
+        dir.path(),
+        "b.parquet",
+        &["chr1", "chr2", "chr4"],
+        &[200, 200, 200],
+        &[400, 400, 400],
+    );
+
+    let ctx = make_ctx();
+    ctx.register_parquet("a", path_a.to_str().unwrap(), Default::default())
+        .await
+        .unwrap();
+    ctx.register_parquet("b", path_b.to_str().unwrap(), Default::default())
+        .await
+        .unwrap();
+
+    let result = ctx.sql(INTERSECTS_SQL).await.unwrap();
+    let batches = result.collect().await.unwrap();
+    let total_rows: usize =
+        batches.iter().map(|b| b.num_rows()).sum();
+
+    // chr1: a[100,500) x b[200,400) -> yes
+    // chr2: a[100,500) x b[200,400) -> yes
+    // chr3 vs chr4: no match
+    assert_eq!(total_rows, 2);
+}
+
+// ── Many-to-many overlap (correctness stress test) ──────────────
+
+#[tokio::test]
+async fn test_many_to_many_overlap() {
+    // Given overlapping intervals where each interval overlaps
+    // multiple intervals on the other side,
+    // When the logical rule processes the join,
+    // Then all valid pairs are returned exactly once.
+    let dir = TempDir::new().unwrap();
+
+    let path_a = write_intervals_parquet(
+        dir.path(),
+        "a.parquet",
+        &["chr1", "chr1", "chr1"],
+        &[0, 100, 200],
+        &[300, 400, 500],
+    );
+    let path_b = write_intervals_parquet(
+        dir.path(),
+        "b.parquet",
+        &["chr1", "chr1", "chr1"],
+        &[150, 250, 350],
+        &[350, 450, 550],
+    );
+
+    let ctx = make_ctx();
+    ctx.register_parquet("a", path_a.to_str().unwrap(), Default::default())
+        .await
+        .unwrap();
+    ctx.register_parquet("b", path_b.to_str().unwrap(), Default::default())
+        .await
+        .unwrap();
+
+    let result = ctx.sql(INTERSECTS_SQL).await.unwrap();
+    let batches = result.collect().await.unwrap();
+    let total_rows: usize =
+        batches.iter().map(|b| b.num_rows()).sum();
+
+    // Expected pairs (half-open overlap):
+    //   a[0,300)   x b[150,350), b[250,450)             -> 2
+    //   a[100,400) x b[150,350), b[250,450), b[350,550) -> 3
+    //   a[200,500) x b[150,350), b[250,450), b[350,550) -> 3
+    assert_eq!(total_rows, 8);
+}
+
+// ── Empty tables ────────────────────────────────────────────────
+
+#[tokio::test]
+async fn test_logical_rule_empty_right_table() {
+    // Given a non-empty left table and an empty right table,
+    // When the logical rule processes an INTERSECTS join,
+    // Then zero rows are returned.
+    let dir = TempDir::new().unwrap();
+
+    let path_a = write_intervals_parquet(
+        dir.path(),
+        "a.parquet",
+        &["chr1", "chr1"],
+        &[100, 300],
+        &[200, 400],
+    );
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("chrom", DataType::Utf8, false),
+        Field::new("start", DataType::Int64, false),
+        Field::new("end", DataType::Int64, false),
+    ]));
+
+    let ctx = make_ctx();
+    ctx.register_parquet("a", path_a.to_str().unwrap(), Default::default())
+        .await
+        .unwrap();
+
+    // The right side is an in-memory table with a single empty
+    // batch (0 rows), exercising the empty-probe-side path.
+    let empty_batch = RecordBatch::new_empty(schema.clone());
+    let empty_table = datafusion::datasource::MemTable::try_new(
+        schema,
+        vec![vec![empty_batch]],
+    )
+    .unwrap();
+    ctx.register_table("b", Arc::new(empty_table)).unwrap();
+
+    let result = ctx.sql(INTERSECTS_SQL).await.unwrap();
+    let batches = result.collect().await.unwrap();
+    let total_rows: usize =
+        batches.iter().map(|b| b.num_rows()).sum();
+
+    assert_eq!(total_rows, 0);
+}
+
+// ── Single-row tables ───────────────────────────────────────────
+
+#[tokio::test]
+async fn test_logical_rule_single_row_overlap() {
+    // Given two single-row tables with overlapping intervals,
+    // When the logical rule processes the join,
+    // Then exactly one pair is returned.
+    let dir = TempDir::new().unwrap();
+
+    let path_a = write_intervals_parquet(
+        dir.path(),
+        "a.parquet",
+        &["chr1"],
+        &[100],
+        &[300],
+    );
+    let path_b = write_intervals_parquet(
+        dir.path(),
+        "b.parquet",
+        &["chr1"],
+        &[200],
+        &[400],
+    );
+
+    let ctx = make_ctx();
+    ctx.register_parquet("a", path_a.to_str().unwrap(), Default::default())
+        .await
+        .unwrap();
+    ctx.register_parquet("b", path_b.to_str().unwrap(), Default::default())
+        .await
+        .unwrap();
+
+    let result = ctx.sql(INTERSECTS_SQL).await.unwrap();
+    let batches = result.collect().await.unwrap();
+    let total_rows: usize =
+        batches.iter().map(|b| b.num_rows()).sum();
+
+    assert_eq!(total_rows, 1);
+}
+
+#[tokio::test]
+async fn test_logical_rule_single_row_no_overlap() {
+    // Given two single-row tables with non-overlapping intervals,
+    // When the logical rule processes the join,
+    // Then zero pairs are returned.
+    let dir = TempDir::new().unwrap();
+
+    let path_a = write_intervals_parquet(
+        dir.path(),
+        "a.parquet",
+        &["chr1"],
+        &[100],
+        &[200],
+    );
+    let path_b = write_intervals_parquet(
+        dir.path(),
+        "b.parquet",
+        &["chr1"],
+        &[300],
+        &[400],
+    );
+
+    let ctx = make_ctx();
+    ctx.register_parquet("a", path_a.to_str().unwrap(), Default::default())
+        .await
+        .unwrap();
+    ctx.register_parquet("b", path_b.to_str().unwrap(), Default::default())
+        .await
+        .unwrap();
+
+    let result = ctx.sql(INTERSECTS_SQL).await.unwrap();
+    let batches = result.collect().await.unwrap();
+    let total_rows: usize =
+        batches.iter().map(|b| b.num_rows()).sum();
+
+    assert_eq!(total_rows, 0);
+}
+
+// ── Identical intervals ─────────────────────────────────────────
+
+#[tokio::test]
+async fn test_logical_rule_identical_intervals() {
+    // Given two tables with identical intervals,
+    // When the logical rule processes the join,
+    // Then all N*M pairs are returned.
+    let dir = TempDir::new().unwrap();
+
+    let path_a = write_intervals_parquet(
+        dir.path(),
+        "a.parquet",
+        &["chr1", "chr1"],
+        &[100, 100],
+        &[200, 200],
+    );
+    let path_b = write_intervals_parquet(
+        dir.path(),
+        "b.parquet",
+        &["chr1", "chr1"],
+        &[100, 100],
+        &[200, 200],
+    );
+
+    let ctx = make_ctx();
+    ctx.register_parquet("a", path_a.to_str().unwrap(), Default::default())
+        .await
+        .unwrap();
+    ctx.register_parquet("b", path_b.to_str().unwrap(), Default::default())
+        .await
+        .unwrap();
+
+    let result = ctx.sql(INTERSECTS_SQL).await.unwrap();
+    let batches = result.collect().await.unwrap();
+    let total_rows: usize =
+        batches.iter().map(|b| b.num_rows()).sum();
+
+    // 2 left x 2 right = 4 pairs, all overlap
+    assert_eq!(total_rows, 4);
+}
+
+// ── Boundary conditions ─────────────────────────────────────────
+
+#[tokio::test]
+async fn test_logical_rule_one_bp_overlap() {
+    // Given intervals that overlap by exactly 1bp,
+    // When the logical rule processes the join,
+    // Then the pair is returned.
+    let dir = TempDir::new().unwrap();
+
+    // a[100,201) and b[200,300) overlap at position 200
+    let path_a = write_intervals_parquet(
+        dir.path(),
+        "a.parquet",
+        &["chr1"],
+        &[100],
+        &[201],
+    );
+    let path_b = write_intervals_parquet(
+        dir.path(),
+        "b.parquet",
+        &["chr1"],
+        &[200],
+        &[300],
+    );
+
+    let ctx = make_ctx();
+    ctx.register_parquet("a", path_a.to_str().unwrap(), Default::default())
+        .await
+        .unwrap();
+    ctx.register_parquet("b", path_b.to_str().unwrap(), Default::default())
+        .await
+        .unwrap();
+
+    let result = ctx.sql(INTERSECTS_SQL).await.unwrap();
+    let batches = result.collect().await.unwrap();
+    let total_rows: usize =
+        batches.iter().map(|b| b.num_rows()).sum();
+
+    assert_eq!(total_rows, 1);
+}
+
+// ── Custom column names (chromStart/chromEnd) ───────────────────
+
+#[tokio::test]
+async fn test_logical_rule_chromstart_chromend_columns() {
+    // Given tables with BED-style column names (chromStart, chromEnd),
+    // When an INTERSECTS join is executed with giql_intersects()
+    // using those column names explicitly,
+    // Then the overlaps are found correctly. 
+ let dir = TempDir::new().unwrap(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("chrom", DataType::Utf8, false), + Field::new("chromStart", DataType::Int64, false), + Field::new("chromEnd", DataType::Int64, false), + ])); + + let path_a = write_intervals_parquet_custom_schema( + dir.path(), + "a.parquet", + schema.clone(), + &["chr1", "chr1"], + &[100, 500], + &[300, 700], + ); + let path_b = write_intervals_parquet_custom_schema( + dir.path(), + "b.parquet", + schema, + &["chr1"], + &[200], + &[600], + ); + + let ctx = make_ctx(); + ctx.register_parquet("a", path_a.to_str().unwrap(), Default::default()) + .await + .unwrap(); + ctx.register_parquet("b", path_b.to_str().unwrap(), Default::default()) + .await + .unwrap(); + + let sql = "\ + SELECT a.chrom, a.\"chromStart\", a.\"chromEnd\", \ + b.chrom AS chrom_b, b.\"chromStart\" AS start_b, \ + b.\"chromEnd\" AS end_b \ + FROM a JOIN b \ + ON a.chrom = b.chrom \ + AND giql_intersects(\ + a.\"chromStart\", a.\"chromEnd\", \ + b.\"chromStart\", b.\"chromEnd\")"; + + let result = ctx.sql(sql).await.unwrap(); + let batches = result.collect().await.unwrap(); + let total_rows: usize = + batches.iter().map(|b| b.num_rows()).sum(); + + // a[100,300) x b[200,600) -> yes + // a[500,700) x b[200,600) -> yes + assert_eq!(total_rows, 2); +} + +// ── Verify values in output ───────────────────────────────────── + +#[tokio::test] +async fn test_logical_rule_output_values_correct() { + // Given known overlapping intervals, + // When the logical rule processes the join, + // Then the output columns contain the correct start/end values. 
+ let dir = TempDir::new().unwrap(); + + let path_a = write_intervals_parquet( + dir.path(), + "a.parquet", + &["chr1"], + &[100], + &[300], + ); + let path_b = write_intervals_parquet( + dir.path(), + "b.parquet", + &["chr1"], + &[200], + &[400], + ); + + let ctx = make_ctx(); + ctx.register_parquet("a", path_a.to_str().unwrap(), Default::default()) + .await + .unwrap(); + ctx.register_parquet("b", path_b.to_str().unwrap(), Default::default()) + .await + .unwrap(); + + let result = ctx.sql(INTERSECTS_SQL).await.unwrap(); + let batches = result.collect().await.unwrap(); + assert_eq!(batches.len(), 1); + + let batch = &batches[0]; + assert_eq!(batch.num_rows(), 1); + + let a_start = batch + .column_by_name("start") + .unwrap() + .as_any() + .downcast_ref::<Int64Array>() + .unwrap(); + assert_eq!(a_start.value(0), 100); + + let a_end = batch + .column_by_name("end") + .unwrap() + .as_any() + .downcast_ref::<Int64Array>() + .unwrap(); + assert_eq!(a_end.value(0), 300); + + let b_start = batch + .column_by_name("start_b") + .unwrap() + .as_any() + .downcast_ref::<Int64Array>() + .unwrap(); + assert_eq!(b_start.value(0), 200); + + let b_end = batch + .column_by_name("end_b") + .unwrap() + .as_any() + .downcast_ref::<Int64Array>() + .unwrap(); + assert_eq!(b_end.value(0), 400); +} + +/// Tables aliased as "peaks" and "genes" — not starting with +/// 'a' or 'l'. Verifies giql_intersects() works with any table +/// names. 
+#[tokio::test] +async fn test_logical_rule_non_al_table_aliases() { + let dir = TempDir::new().unwrap(); + let left_path = write_intervals_parquet( + dir.path(), + "peaks.parquet", + &["chr1", "chr1"], + &[100, 300], + &[250, 500], + ); + let right_path = write_intervals_parquet( + dir.path(), + "genes.parquet", + &["chr1", "chr1"], + &[200, 400], + &[350, 600], + ); + + let ctx = make_ctx(); + ctx.register_parquet( + "peaks", + left_path.to_str().unwrap(), + Default::default(), + ) + .await + .unwrap(); + ctx.register_parquet( + "genes", + right_path.to_str().unwrap(), + Default::default(), + ) + .await + .unwrap(); + + let sql = r#" + SELECT peaks.chrom, peaks.start, peaks."end", + genes.chrom AS chrom_b, genes.start AS start_b, + genes."end" AS end_b + FROM peaks JOIN genes + ON peaks.chrom = genes.chrom + AND giql_intersects( + peaks.start, peaks."end", + genes.start, genes."end") + "#; + + let result = ctx.sql(sql).await.unwrap(); + let batches = result.collect().await.unwrap(); + let total_rows: usize = + batches.iter().map(|b| b.num_rows()).sum(); + + // [100,250) overlaps [200,350): yes + // [300,500) overlaps [200,350): yes + // [300,500) overlaps [400,600): yes + assert_eq!(total_rows, 3); +} + +// ── Self-join ─────────────────────────────────────────────────── + +#[tokio::test] +async fn test_logical_rule_self_join() { + // Given a single table joined against itself, + // When the logical rule processes the self-join, + // Then overlaps are found correctly without alias collisions. + let dir = TempDir::new().unwrap(); + let path = write_intervals_parquet( + dir.path(), + "intervals.parquet", + &["chr1", "chr1", "chr1"], + &[100, 200, 500], + &[300, 400, 700], + ); + + let ctx = make_ctx(); + ctx.register_parquet( + "intervals", + path.to_str().unwrap(), + Default::default(), + ) + .await + .unwrap(); + + // Use the physical table name (not aliases) so that the + // rewritten plan's SubqueryAlias names resolve correctly. 
+ let sql = r#" + SELECT intervals.chrom, intervals.start, intervals."end" + FROM intervals JOIN intervals AS intervals2 + ON intervals.chrom = intervals2.chrom + AND giql_intersects( + intervals.start, intervals."end", + intervals2.start, intervals2."end") + "#; + + let result = ctx.sql(sql).await.unwrap(); + let batches = result.collect().await.unwrap(); + let total_rows: usize = + batches.iter().map(|b| b.num_rows()).sum(); + + // All pairs where intervals overlap (including self-pairs): + // [100,300) x [100,300) -> yes + // [100,300) x [200,400) -> yes + // [100,300) x [500,700) -> no + // [200,400) x [100,300) -> yes + // [200,400) x [200,400) -> yes + // [200,400) x [500,700) -> no + // [500,700) x [100,300) -> no + // [500,700) x [200,400) -> no + // [500,700) x [500,700) -> yes + assert_eq!(total_rows, 5); +} + +// ── Compound predicate alongside giql_intersects ──────────────── + +#[tokio::test] +async fn test_logical_rule_compound_predicate() { + // Given an additional filter alongside giql_intersects, + // When the logical rule processes the join, + // Then both the overlap and the extra predicate are applied. 
+ let dir = TempDir::new().unwrap(); + let path_a = write_intervals_parquet( + dir.path(), + "a.parquet", + &["chr1", "chr1"], + &[100, 300], + &[250, 500], + ); + let path_b = write_intervals_parquet( + dir.path(), + "b.parquet", + &["chr1", "chr1"], + &[200, 400], + &[350, 600], + ); + + let ctx = make_ctx(); + ctx.register_parquet( + "a", + path_a.to_str().unwrap(), + Default::default(), + ) + .await + .unwrap(); + ctx.register_parquet( + "b", + path_b.to_str().unwrap(), + Default::default(), + ) + .await + .unwrap(); + + // Extra predicate: only keep pairs where b.start > 300 + let sql = r#" + SELECT a.chrom, a.start, a."end", + b.chrom AS chrom_b, b.start AS start_b, + b."end" AS end_b + FROM a JOIN b + ON a.chrom = b.chrom + AND giql_intersects(a.start, a."end", b.start, b."end") + AND b.start > 300 + "#; + + let result = ctx.sql(sql).await.unwrap(); + let batches = result.collect().await.unwrap(); + let total_rows: usize = + batches.iter().map(|b| b.num_rows()).sum(); + + // Without the extra predicate, overlaps would be: + // [100,250) x [200,350) -> yes + // [300,500) x [200,350) -> yes + // [300,500) x [400,600) -> yes + // With b.start > 300, only b[400,600) qualifies: + // [300,500) x [400,600) -> yes + assert_eq!(total_rows, 1); +} + +// ── Pathological width distribution (sampling) ────────────────── + +#[tokio::test] +async fn test_middle_wide_interval_not_at_extremes() { + // Given a "middle-wide" distribution where the widest interval + // has neither the smallest start nor the largest end, + // When the logical rule processes the join, + // Then Parquet sampling detects the actual width and the + // result is correct (no duplicates from replication blowup). 
+ let dir = TempDir::new().unwrap(); + + // Pathological case for column-level heuristics: + // - [0, 50): narrow, has min(start) and min(end) + // - [1000, 900_000): WIDE (width 899K), middle of coordinate space + // - [999_000, 1_000_000): narrow, has max(start) and max(end) + // + // Column-level: w1=min(end)-min(start)=50, w2=max(end)-max(start)=1000 + // Both miss the 899K-wide interval. Sampling reads actual widths. + let path_a = write_intervals_parquet( + dir.path(), + "a.parquet", + &["chr1", "chr1", "chr1"], + &[0, 1000, 999_000], + &[50, 900_000, 1_000_000], + ); + let path_b = write_intervals_parquet( + dir.path(), + "b.parquet", + &["chr1", "chr1"], + &[500_000, 950_000], + &[600_000, 999_500], + ); + + let ctx = make_ctx(); + ctx.register_parquet( + "a", + path_a.to_str().unwrap(), + Default::default(), + ) + .await + .unwrap(); + ctx.register_parquet( + "b", + path_b.to_str().unwrap(), + Default::default(), + ) + .await + .unwrap(); + + let result = ctx.sql(INTERSECTS_SQL).await.unwrap(); + let batches = result.collect().await.unwrap(); + let total_rows: usize = + batches.iter().map(|b| b.num_rows()).sum(); + + // a[0,50) x b[500000,600000) -> no + // a[0,50) x b[950000,999500) -> no + // a[1000,900000) x b[500000,600000) -> yes + // a[1000,900000) x b[950000,999500) -> no (900000 <= 950000) + // a[999000,1000000) x b[500000,600000) -> no (999000 >= 600000) + // a[999000,1000000) x b[950000,999500) -> yes + assert_eq!(total_rows, 2); +} diff --git a/src/giql/generators/__init__.py b/src/giql/generators/__init__.py index ca8cb16..226851c 100644 --- a/src/giql/generators/__init__.py +++ b/src/giql/generators/__init__.py @@ -1,5 +1,6 @@ """SQL generators for GIQL transpilation.""" from giql.generators.base import BaseGIQLGenerator +from giql.generators.datafusion import DataFusionGIQLGenerator -__all__ = ["BaseGIQLGenerator"] +__all__ = ["BaseGIQLGenerator", "DataFusionGIQLGenerator"] diff --git a/src/giql/generators/datafusion.py 
b/src/giql/generators/datafusion.py new file mode 100644 index 0000000..3af6a1f --- /dev/null +++ b/src/giql/generators/datafusion.py @@ -0,0 +1,35 @@ +"""DataFusion SQL generator for GIQL transpilation. + +Emits ``giql_intersects()`` function calls for column-to-column +INTERSECTS joins instead of expanding to raw overlap predicates. +A DataFusion logical optimizer rule matches on that function call +and rewrites it to a binned equi-join with adaptive bin sizing. +""" + +from __future__ import annotations + +from giql.generators.base import BaseGIQLGenerator + + +class DataFusionGIQLGenerator(BaseGIQLGenerator): + """Generator that preserves INTERSECTS semantics for DataFusion. + + For column-to-column INTERSECTS joins, emits:: + + (l.chrom = r.chrom AND giql_intersects(l.start, l.end, r.start, r.end)) + + instead of the standard overlap predicates. The chrom equi-key is + preserved as plain SQL so DataFusion can use it for hash + partitioning. All other operations (literal range queries, + CONTAINS, WITHIN) fall through to the base generator. + """ + + def _generate_column_join(self, left_col: str, right_col: str, op_type: str) -> str: + if op_type == "intersects": + l_chrom, l_start, l_end = self._get_column_refs(left_col, None) + r_chrom, r_start, r_end = self._get_column_refs(right_col, None) + return ( + f"({l_chrom} = {r_chrom} " + f"AND giql_intersects({l_start}, {l_end}, {r_start}, {r_end}))" + ) + return super()._generate_column_join(left_col, right_col, op_type) diff --git a/src/giql/transpile.py b/src/giql/transpile.py index 2b29c3d..d140aa1 100644 --- a/src/giql/transpile.py +++ b/src/giql/transpile.py @@ -4,6 +4,10 @@ to standard SQL. 
""" +from __future__ import annotations + +from typing import Literal + from sqlglot import parse_one from giql.dialect import GIQLDialect @@ -45,6 +49,7 @@ def _build_tables(tables: list[str | Table] | None) -> Tables: def transpile( giql: str, tables: list[str | Table] | None = None, + dialect: Literal["default", "datafusion"] = "default", ) -> str: """Transpile a GIQL query to SQL. @@ -60,6 +65,12 @@ def transpile( Table configurations. Strings use default column mappings (chrom, start, end, strand). Table objects provide custom column name mappings. + dialect : {"default", "datafusion"} + Target SQL dialect. ``"datafusion"`` emits + ``giql_intersects()`` function calls for column-to-column + INTERSECTS joins, allowing a DataFusion logical optimizer + rule to rewrite them into binned equi-joins with adaptive + bin sizing. Default emits standard SQL-92 overlap predicates. Returns ------- @@ -69,7 +80,8 @@ def transpile( Raises ------ ValueError - If the query cannot be parsed or transpiled. + If the query cannot be parsed or transpiled, or if an + unsupported dialect is specified. 
Examples -------- @@ -80,19 +92,12 @@ def transpile( tables=["peaks"], ) - Custom table configuration:: + DataFusion dialect for optimized interval joins:: sql = transpile( - "SELECT * FROM peaks WHERE interval INTERSECTS 'chr1:1000-2000'", - tables=[ - Table( - "peaks", - genomic_col="interval", - chrom_col="chrom", - start_col="start", - end_col="end", - ) - ], + "SELECT * FROM a JOIN b ON a.interval INTERSECTS b.interval", + tables=["a", "b"], + dialect="datafusion", ) """ # Build tables container @@ -102,8 +107,17 @@ def transpile( merge_transformer = MergeTransformer(tables_container) cluster_transformer = ClusterTransformer(tables_container) - # Initialize generator with table configurations - generator = BaseGIQLGenerator(tables=tables_container) + # Initialize generator for the target dialect + if dialect == "datafusion": + from giql.generators.datafusion import DataFusionGIQLGenerator + + generator = DataFusionGIQLGenerator(tables=tables_container) + elif dialect == "default": + generator = BaseGIQLGenerator(tables=tables_container) + else: + raise ValueError( + f"Unknown dialect: {dialect!r}. 
Supported: 'default', 'datafusion'" + ) # Parse GIQL query try: diff --git a/tests/test_transpile.py b/tests/test_transpile.py index e0f54fe..3450659 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -377,6 +377,101 @@ def test_nearest_lateral(self): assert "LIMIT 3" in sql +class TestTranspileDataFusionDialect: + """Tests for the datafusion dialect.""" + + def test_datafusion_dialect_intersects_join(self): + """ + GIVEN a GIQL query joining two tables with INTERSECTS + WHEN transpiling with dialect="datafusion" + THEN should emit giql_intersects() function call with chrom equi-key preserved + """ + sql = transpile( + "SELECT a.*, b.* FROM peaks a JOIN genes b ON a.interval INTERSECTS b.region", + tables=[ + Table("peaks", genomic_col="interval"), + Table("genes", genomic_col="region"), + ], + dialect="datafusion", + ) + + assert "giql_intersects(" in sql + assert "JOIN" in sql.upper() + # Chrom equi-key should still be standard SQL + assert '"chrom"' in sql + + def test_datafusion_dialect_literal_range_unchanged(self): + """ + GIVEN a GIQL query with INTERSECTS and a literal range + WHEN transpiling with dialect="datafusion" + THEN should emit standard SQL predicates, not giql_intersects() + """ + sql = transpile( + "SELECT * FROM peaks WHERE interval INTERSECTS 'chr1:1000-2000'", + tables=["peaks"], + dialect="datafusion", + ) + + assert "giql_intersects" not in sql + assert "chr1" in sql + + def test_datafusion_dialect_contains_unchanged(self): + """ + GIVEN a GIQL CONTAINS join + WHEN transpiling with dialect="datafusion" + THEN should emit standard SQL predicates (only INTERSECTS uses the function call) + """ + sql = transpile( + "SELECT a.*, b.* FROM peaks a JOIN genes b ON a.interval CONTAINS b.region", + tables=[ + Table("peaks", genomic_col="interval"), + Table("genes", genomic_col="region"), + ], + dialect="datafusion", + ) + + assert "giql_intersects" not in sql + + def test_default_dialect_unchanged(self): + """ + GIVEN a GIQL 
INTERSECTS join + WHEN transpiling with dialect="default" (or omitted) + THEN should emit standard SQL overlap predicates + """ + sql_default = transpile( + "SELECT a.*, b.* FROM peaks a JOIN genes b ON a.interval INTERSECTS b.region", + tables=[ + Table("peaks", genomic_col="interval"), + Table("genes", genomic_col="region"), + ], + ) + sql_explicit = transpile( + "SELECT a.*, b.* FROM peaks a JOIN genes b ON a.interval INTERSECTS b.region", + tables=[ + Table("peaks", genomic_col="interval"), + Table("genes", genomic_col="region"), + ], + dialect="default", + ) + + assert "giql_intersects" not in sql_default + assert "giql_intersects" not in sql_explicit + assert sql_default == sql_explicit + + def test_invalid_dialect_raises(self): + """ + GIVEN an unsupported dialect string + WHEN transpiling + THEN should raise ValueError + """ + with pytest.raises(ValueError, match="Unknown dialect"): + transpile( + "SELECT * FROM peaks WHERE interval INTERSECTS 'chr1:1000-2000'", + tables=["peaks"], + dialect="postgres", + ) + + class TestTranspileErrors: """Tests for error handling."""