diff --git a/Cargo.lock b/Cargo.lock index 6b87bee..4f40173 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1246,6 +1246,7 @@ dependencies = [ "rpassword", "serde", "serde_json", + "sqlparser", "syntect", "thiserror 1.0.69", "tokio", @@ -1744,6 +1745,15 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "sqlparser" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05a528114c392209b3264855ad491fcce534b94a38771b0a0b97a79379275ce8" +dependencies = [ + "log", +] + [[package]] name = "static_assertions" version = "1.1.0" diff --git a/Cargo.toml b/Cargo.toml index 3654ab6..6196121 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,6 +55,9 @@ rpassword = "5" # Home directory dirs = "5" +# SQL parsing +sqlparser = "0.53" + # Logging (optional for debugging) tracing = "0.1" tracing-subscriber = "0.3" diff --git a/src/ast/adapter.rs b/src/ast/adapter.rs new file mode 100644 index 0000000..426d7ac --- /dev/null +++ b/src/ast/adapter.rs @@ -0,0 +1,186 @@ +/// Language adapter traits for the multi-language execution engine. +/// +/// Each language adapter converts its input format into our unified AST. +/// This enables pgrsql to accept queries from SQL, Python DSLs, Rust DSLs, +/// and future language integrations. +use anyhow::Result; + +use super::types::Query; + +/// Adapter for parsing raw query languages (like SQL dialects). +/// +/// Implementations should handle a specific language's text format +/// and produce our unified AST. +/// +/// # Example +/// +/// ```ignore +/// struct PostgresAdapter; +/// +/// impl QueryLanguageAdapter for PostgresAdapter { +/// fn name(&self) -> &str { "PostgreSQL" } +/// fn parse(&self, input: &str) -> Result> { +/// crate::ast::parser::parse_sql(input) +/// } +/// } +/// ``` +pub trait QueryLanguageAdapter: Send + Sync { + /// Human-readable name of the language this adapter handles. + fn name(&self) -> &str; + + /// Parse input text into one or more AST queries. 
+ fn parse(&self, input: &str) -> Result>; + + /// Check if this adapter can handle the given input. + /// Used for auto-detection of input language. + fn can_handle(&self, input: &str) -> bool; +} + +/// Adapter for Domain-Specific Languages that compile to our AST. +/// +/// Unlike `QueryLanguageAdapter`, DSL adapters may maintain state +/// (e.g., variable bindings, session context) across invocations. +/// +/// # Example +/// +/// ```ignore +/// struct PythonDSLAdapter { runtime: PyRuntime } +/// +/// impl DSLAdapter for PythonDSLAdapter { +/// fn name(&self) -> &str { "Python DataFrame DSL" } +/// fn compile_to_ast(&self, code: &str) -> Result { +/// // Evaluate Python code, extract query builder chain, produce AST +/// } +/// } +/// ``` +pub trait DSLAdapter: Send + Sync { + /// Human-readable name of the DSL. + fn name(&self) -> &str; + + /// Compile DSL code into a single query AST. + fn compile_to_ast(&self, code: &str) -> Result; + + /// Return supported file extensions for this DSL (e.g., `["py", "python"]`). + fn file_extensions(&self) -> Vec<&str> { + vec![] + } +} + +/// Built-in PostgreSQL adapter using our parser. +pub struct PostgresAdapter; + +impl QueryLanguageAdapter for PostgresAdapter { + fn name(&self) -> &str { + "PostgreSQL" + } + + fn parse(&self, input: &str) -> Result> { + super::parser::parse_sql(input) + } + + fn can_handle(&self, input: &str) -> bool { + let trimmed = input.trim().to_uppercase(); + // Basic heuristic: starts with a SQL keyword + trimmed.starts_with("SELECT") + || trimmed.starts_with("INSERT") + || trimmed.starts_with("UPDATE") + || trimmed.starts_with("DELETE") + || trimmed.starts_with("WITH") + || trimmed.starts_with("CREATE") + || trimmed.starts_with("ALTER") + || trimmed.starts_with("DROP") + || trimmed.starts_with("EXPLAIN") + || trimmed.starts_with("SHOW") + } +} + +/// Registry for managing multiple language adapters. 
+pub struct AdapterRegistry { + query_adapters: Vec>, + dsl_adapters: Vec>, +} + +impl Default for AdapterRegistry { + fn default() -> Self { + let mut registry = Self { + query_adapters: Vec::new(), + dsl_adapters: Vec::new(), + }; + // Register the built-in PostgreSQL adapter + registry.register_query_adapter(Box::new(PostgresAdapter)); + registry + } +} + +impl AdapterRegistry { + pub fn new() -> Self { + Self::default() + } + + pub fn register_query_adapter(&mut self, adapter: Box) { + self.query_adapters.push(adapter); + } + + pub fn register_dsl_adapter(&mut self, adapter: Box) { + self.dsl_adapters.push(adapter); + } + + /// Parse input using the first adapter that can handle it. + pub fn parse(&self, input: &str) -> Result> { + for adapter in &self.query_adapters { + if adapter.can_handle(input) { + return adapter.parse(input); + } + } + anyhow::bail!("No adapter found that can handle this input") + } + + /// List all registered adapter names. + pub fn adapter_names(&self) -> Vec<&str> { + let mut names: Vec<&str> = self.query_adapters.iter().map(|a| a.name()).collect(); + names.extend(self.dsl_adapters.iter().map(|a| a.name())); + names + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_postgres_adapter_can_handle() { + let adapter = PostgresAdapter; + assert!(adapter.can_handle("SELECT * FROM users")); + assert!(adapter.can_handle(" select * from users ")); + assert!(adapter.can_handle("INSERT INTO users VALUES (1)")); + assert!(adapter.can_handle("WITH cte AS (SELECT 1) SELECT * FROM cte")); + assert!(!adapter.can_handle("df.filter(col('x') > 1)")); + } + + #[test] + fn test_postgres_adapter_parse() { + let adapter = PostgresAdapter; + let queries = adapter.parse("SELECT * FROM users").unwrap(); + assert_eq!(queries.len(), 1); + } + + #[test] + fn test_registry_default_has_postgres() { + let registry = AdapterRegistry::new(); + assert!(registry.adapter_names().contains(&"PostgreSQL")); + } + + #[test] + fn test_registry_parse_sql() 
{ + let registry = AdapterRegistry::new(); + let queries = registry.parse("SELECT 1").unwrap(); + assert_eq!(queries.len(), 1); + } + + #[test] + fn test_registry_no_adapter_for_unknown() { + let registry = AdapterRegistry::new(); + let result = registry.parse("df.show()"); + assert!(result.is_err()); + } +} diff --git a/src/ast/compiler.rs b/src/ast/compiler.rs new file mode 100644 index 0000000..9d101c6 --- /dev/null +++ b/src/ast/compiler.rs @@ -0,0 +1,693 @@ +/// Unified AST → SQL compiler. +/// +/// Converts our internal AST back into a SQL string targeting PostgreSQL. +/// This enables round-trip parsing: SQL → AST → SQL, and allows DSLs +/// to generate SQL by building AST nodes. +use super::types::*; + +/// Compile a query AST into a PostgreSQL SQL string. +pub fn compile(query: &Query) -> String { + match query { + Query::Select(s) => compile_select(s), + Query::Insert(i) => compile_insert(i), + Query::Update(u) => compile_update(u), + Query::Delete(d) => compile_delete(d), + Query::With(cte) => compile_cte(cte), + Query::Raw(sql) => sql.clone(), + } +} + +fn compile_select(select: &SelectQuery) -> String { + let mut parts = Vec::new(); + + // SELECT [DISTINCT] + let mut select_clause = String::from("SELECT "); + if select.distinct { + select_clause.push_str("DISTINCT "); + } + + if select.projections.is_empty() { + select_clause.push('*'); + } else { + let items: Vec = select.projections.iter().map(compile_select_item).collect(); + select_clause.push_str(&items.join(", ")); + } + parts.push(select_clause); + + // FROM + if !select.from.is_empty() { + let tables: Vec = select.from.iter().map(compile_table_ref).collect(); + parts.push(format!("FROM {}", tables.join(", "))); + } + + // JOINs + for join in &select.joins { + parts.push(compile_join(join)); + } + + // WHERE + if let Some(ref filter) = select.filter { + parts.push(format!("WHERE {}", compile_expr(filter))); + } + + // GROUP BY + if !select.group_by.is_empty() { + let groups: Vec = 
select.group_by.iter().map(compile_expr).collect(); + parts.push(format!("GROUP BY {}", groups.join(", "))); + } + + // HAVING + if let Some(ref having) = select.having { + parts.push(format!("HAVING {}", compile_expr(having))); + } + + // WINDOW + for window in &select.windows { + parts.push(format!( + "WINDOW {} AS ({})", + window.name, + compile_window_spec(&window.spec) + )); + } + + // Set operations (UNION, INTERSECT, EXCEPT) + if let Some(ref set_op) = select.set_op { + let op_str = match set_op.op { + SetOperator::Union => "UNION", + SetOperator::Intersect => "INTERSECT", + SetOperator::Except => "EXCEPT", + }; + let all_str = if set_op.all { " ALL" } else { "" }; + parts.push(format!("{}{} {}", op_str, all_str, compile(&set_op.right))); + } + + // ORDER BY + if !select.order_by.is_empty() { + let orders: Vec = select.order_by.iter().map(compile_order_by).collect(); + parts.push(format!("ORDER BY {}", orders.join(", "))); + } + + // LIMIT + if let Some(ref limit) = select.limit { + parts.push(format!("LIMIT {}", compile_expr(limit))); + } + + // OFFSET + if let Some(ref offset) = select.offset { + parts.push(format!("OFFSET {}", compile_expr(offset))); + } + + parts.join(" ") +} + +fn compile_select_item(item: &SelectItem) -> String { + match item { + SelectItem::Wildcard => "*".to_string(), + SelectItem::QualifiedWildcard(table) => format!("{}.*", table), + SelectItem::Expression { expr, alias } => { + let expr_str = compile_expr(expr); + match alias { + Some(a) => format!("{} AS {}", expr_str, a), + None => expr_str, + } + } + } +} + +fn compile_table_ref(table: &TableRef) -> String { + match table { + TableRef::Table { + schema, + name, + alias, + } => { + let mut s = match schema { + Some(sc) => format!("{}.{}", sc, name), + None => name.clone(), + }; + if let Some(a) = alias { + s.push_str(&format!(" AS {}", a)); + } + s + } + TableRef::Subquery { query, alias } => { + format!("({}) AS {}", compile(query), alias) + } + TableRef::Function { name, args, 
alias } => { + let args_str: Vec = args.iter().map(compile_expr).collect(); + let mut s = format!("{}({})", name, args_str.join(", ")); + if let Some(a) = alias { + s.push_str(&format!(" AS {}", a)); + } + s + } + } +} + +fn compile_join(join: &Join) -> String { + let type_str = match join.join_type { + JoinType::Inner => "JOIN", + JoinType::Left => "LEFT JOIN", + JoinType::Right => "RIGHT JOIN", + JoinType::Full => "FULL JOIN", + JoinType::Cross => "CROSS JOIN", + JoinType::Lateral => "LATERAL JOIN", + }; + + let table_str = compile_table_ref(&join.table); + + let condition_str = match &join.condition { + Some(JoinCondition::On(expr)) => format!(" ON {}", compile_expr(expr)), + Some(JoinCondition::Using(cols)) => format!(" USING ({})", cols.join(", ")), + Some(JoinCondition::Natural) => " NATURAL".to_string(), + None => String::new(), + }; + + format!("{} {}{}", type_str, table_str, condition_str) +} + +fn compile_expr(expr: &Expression) -> String { + match expr { + Expression::Column { table, name } => match table { + Some(t) => format!("{}.{}", t, name), + None => name.clone(), + }, + Expression::Literal(lit) => compile_literal(lit), + Expression::BinaryOp { left, op, right } => { + let op_str = match op { + BinaryOperator::Eq => "=", + BinaryOperator::NotEq => "<>", + BinaryOperator::Lt => "<", + BinaryOperator::LtEq => "<=", + BinaryOperator::Gt => ">", + BinaryOperator::GtEq => ">=", + BinaryOperator::And => "AND", + BinaryOperator::Or => "OR", + BinaryOperator::Plus => "+", + BinaryOperator::Minus => "-", + BinaryOperator::Multiply => "*", + BinaryOperator::Divide => "/", + BinaryOperator::Modulo => "%", + BinaryOperator::Like => "LIKE", + BinaryOperator::ILike => "ILIKE", + BinaryOperator::NotLike => "NOT LIKE", + BinaryOperator::NotILike => "NOT ILIKE", + BinaryOperator::Concat => "||", + }; + format!("{} {} {}", compile_expr(left), op_str, compile_expr(right)) + } + Expression::UnaryOp { op, expr } => { + let op_str = match op { + UnaryOperator::Not => 
"NOT", + UnaryOperator::Minus => "-", + UnaryOperator::Plus => "+", + }; + format!("{} {}", op_str, compile_expr(expr)) + } + Expression::Function { + name, + args, + distinct, + } => { + let distinct_str = if *distinct { "DISTINCT " } else { "" }; + let args_str: Vec = args.iter().map(compile_expr).collect(); + format!("{}({}{})", name, distinct_str, args_str.join(", ")) + } + Expression::Aggregate { + name, + args, + distinct, + filter, + } => { + let distinct_str = if *distinct { "DISTINCT " } else { "" }; + let args_str: Vec = args.iter().map(compile_expr).collect(); + let mut s = format!("{}({}{})", name, distinct_str, args_str.join(", ")); + if let Some(f) = filter { + s.push_str(&format!(" FILTER (WHERE {})", compile_expr(f))); + } + s + } + Expression::WindowFunction { function, window } => { + format!( + "{} OVER ({})", + compile_expr(function), + compile_window_spec(window) + ) + } + Expression::Case { + operand, + when_clauses, + else_clause, + } => { + let mut s = String::from("CASE"); + if let Some(op) = operand { + s.push_str(&format!(" {}", compile_expr(op))); + } + for (when, then) in when_clauses { + s.push_str(&format!( + " WHEN {} THEN {}", + compile_expr(when), + compile_expr(then) + )); + } + if let Some(else_expr) = else_clause { + s.push_str(&format!(" ELSE {}", compile_expr(else_expr))); + } + s.push_str(" END"); + s + } + Expression::Subquery(q) => format!("({})", compile(q)), + Expression::Exists(q) => format!("EXISTS ({})", compile(q)), + Expression::InList { + expr, + list, + negated, + } => { + let not_str = if *negated { "NOT " } else { "" }; + let items: Vec = list.iter().map(compile_expr).collect(); + format!( + "{} {}IN ({})", + compile_expr(expr), + not_str, + items.join(", ") + ) + } + Expression::InSubquery { + expr, + subquery, + negated, + } => { + let not_str = if *negated { "NOT " } else { "" }; + format!( + "{} {}IN ({})", + compile_expr(expr), + not_str, + compile(subquery) + ) + } + Expression::Between { + expr, + low, + 
high, + negated, + } => { + let not_str = if *negated { "NOT " } else { "" }; + format!( + "{} {}BETWEEN {} AND {}", + compile_expr(expr), + not_str, + compile_expr(low), + compile_expr(high) + ) + } + Expression::IsNull { expr, negated } => { + if *negated { + format!("{} IS NOT NULL", compile_expr(expr)) + } else { + format!("{} IS NULL", compile_expr(expr)) + } + } + Expression::Cast { expr, data_type } => { + format!("CAST({} AS {})", compile_expr(expr), data_type) + } + Expression::Wildcard => "*".to_string(), + Expression::Parameter(idx) => format!("${}", idx), + Expression::Array(elems) => { + let items: Vec = elems.iter().map(compile_expr).collect(); + format!("ARRAY[{}]", items.join(", ")) + } + Expression::JsonAccess { + expr, + path, + as_text, + } => { + let op = if *as_text { "->>" } else { "->" }; + format!("{}{}{}", compile_expr(expr), op, compile_expr(path)) + } + Expression::TypeCast { expr, data_type } => { + format!("{}::{}", compile_expr(expr), data_type) + } + Expression::Nested(expr) => format!("({})", compile_expr(expr)), + } +} + +fn compile_literal(lit: &Literal) -> String { + match lit { + Literal::Null => "NULL".to_string(), + Literal::Boolean(b) => { + if *b { + "TRUE".to_string() + } else { + "FALSE".to_string() + } + } + Literal::Integer(i) => i.to_string(), + Literal::Float(f) => format!("{}", f), + Literal::String(s) => format!("'{}'", s.replace('\'', "''")), + } +} + +fn compile_window_spec(spec: &WindowSpec) -> String { + let mut parts = Vec::new(); + + if !spec.partition_by.is_empty() { + let cols: Vec = spec.partition_by.iter().map(compile_expr).collect(); + parts.push(format!("PARTITION BY {}", cols.join(", "))); + } + + if !spec.order_by.is_empty() { + let orders: Vec = spec.order_by.iter().map(compile_order_by).collect(); + parts.push(format!("ORDER BY {}", orders.join(", "))); + } + + if let Some(ref frame) = spec.frame { + parts.push(compile_window_frame(frame)); + } + + parts.join(" ") +} + +fn compile_window_frame(frame: 
&WindowFrame) -> String { + let mode = match frame.mode { + WindowFrameMode::Rows => "ROWS", + WindowFrameMode::Range => "RANGE", + WindowFrameMode::Groups => "GROUPS", + }; + + let start = compile_window_frame_bound(&frame.start); + + match &frame.end { + Some(end) => format!( + "{} BETWEEN {} AND {}", + mode, + start, + compile_window_frame_bound(end) + ), + None => format!("{} {}", mode, start), + } +} + +fn compile_window_frame_bound(bound: &WindowFrameBound) -> String { + match bound { + WindowFrameBound::CurrentRow => "CURRENT ROW".to_string(), + WindowFrameBound::Preceding(None) => "UNBOUNDED PRECEDING".to_string(), + WindowFrameBound::Preceding(Some(n)) => format!("{} PRECEDING", n), + WindowFrameBound::Following(None) => "UNBOUNDED FOLLOWING".to_string(), + WindowFrameBound::Following(Some(n)) => format!("{} FOLLOWING", n), + } +} + +fn compile_order_by(order: &OrderByExpr) -> String { + let mut s = compile_expr(&order.expr); + match order.asc { + Some(true) => s.push_str(" ASC"), + Some(false) => s.push_str(" DESC"), + None => {} + } + match order.nulls_first { + Some(true) => s.push_str(" NULLS FIRST"), + Some(false) => s.push_str(" NULLS LAST"), + None => {} + } + s +} + +fn compile_cte(cte: &CTEQuery) -> String { + let recursive = if cte.recursive { "RECURSIVE " } else { "" }; + + let ctes: Vec = cte + .ctes + .iter() + .map(|c| { + let cols = if c.columns.is_empty() { + String::new() + } else { + format!("({})", c.columns.join(", ")) + }; + format!("{}{} AS ({})", c.name, cols, compile(&c.query)) + }) + .collect(); + + format!( + "WITH {}{} {}", + recursive, + ctes.join(", "), + compile(&cte.body) + ) +} + +fn compile_insert(insert: &InsertQuery) -> String { + let table = compile_table_ref(&insert.table); + let columns = if insert.columns.is_empty() { + String::new() + } else { + format!(" ({})", insert.columns.join(", ")) + }; + + let source = match &insert.source { + InsertSource::Values(rows) => { + let row_strs: Vec = rows + .iter() + .map(|row| { 
+ let vals: Vec = row.iter().map(compile_expr).collect(); + format!("({})", vals.join(", ")) + }) + .collect(); + format!("VALUES {}", row_strs.join(", ")) + } + InsertSource::Query(q) => compile(q), + }; + + let returning = if insert.returning.is_empty() { + String::new() + } else { + let items: Vec = insert.returning.iter().map(compile_select_item).collect(); + format!(" RETURNING {}", items.join(", ")) + }; + + format!("INSERT INTO {}{} {}{}", table, columns, source, returning) +} + +fn compile_update(update: &UpdateQuery) -> String { + let table = compile_table_ref(&update.table); + + let sets: Vec = update + .assignments + .iter() + .map(|a| format!("{} = {}", a.column, compile_expr(&a.value))) + .collect(); + + let filter = match &update.filter { + Some(f) => format!(" WHERE {}", compile_expr(f)), + None => String::new(), + }; + + let returning = if update.returning.is_empty() { + String::new() + } else { + let items: Vec = update.returning.iter().map(compile_select_item).collect(); + format!(" RETURNING {}", items.join(", ")) + }; + + format!( + "UPDATE {} SET {}{}{}", + table, + sets.join(", "), + filter, + returning + ) +} + +fn compile_delete(delete: &DeleteQuery) -> String { + let table = compile_table_ref(&delete.table); + + let filter = match &delete.filter { + Some(f) => format!(" WHERE {}", compile_expr(f)), + None => String::new(), + }; + + let returning = if delete.returning.is_empty() { + String::new() + } else { + let items: Vec = delete.returning.iter().map(compile_select_item).collect(); + format!(" RETURNING {}", items.join(", ")) + }; + + format!("DELETE FROM {}{}{}", table, filter, returning) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ast::parser::parse_single; + + /// Helper: parse SQL, compile back, and verify the result parses again. 
+ fn round_trip(sql: &str) -> String { + let query = parse_single(sql).expect("Failed to parse"); + compile(&query) + } + + #[test] + fn test_compile_simple_select() { + let compiled = round_trip("SELECT * FROM users"); + assert!(compiled.contains("SELECT")); + assert!(compiled.contains("FROM users")); + } + + #[test] + fn test_compile_select_with_where() { + let compiled = round_trip("SELECT id, name FROM users WHERE age > 18"); + assert!(compiled.contains("WHERE")); + assert!(compiled.contains("age > 18")); + } + + #[test] + fn test_compile_select_with_alias() { + let compiled = round_trip("SELECT u.name AS user_name FROM users AS u"); + assert!(compiled.contains("AS user_name")); + assert!(compiled.contains("AS u")); + } + + #[test] + fn test_compile_join() { + let compiled = round_trip("SELECT * FROM users JOIN orders ON users.id = orders.user_id"); + assert!(compiled.contains("JOIN")); + assert!(compiled.contains("ON")); + } + + #[test] + fn test_compile_left_join() { + let compiled = round_trip("SELECT * FROM a LEFT JOIN b ON a.id = b.a_id"); + assert!(compiled.contains("LEFT JOIN")); + } + + #[test] + fn test_compile_group_by_having() { + let compiled = + round_trip("SELECT dept, COUNT(*) FROM emp GROUP BY dept HAVING COUNT(*) > 5"); + assert!(compiled.contains("GROUP BY")); + assert!(compiled.contains("HAVING")); + } + + #[test] + fn test_compile_order_by_limit() { + let compiled = round_trip("SELECT * FROM users ORDER BY name ASC LIMIT 10 OFFSET 5"); + assert!(compiled.contains("ORDER BY")); + assert!(compiled.contains("LIMIT 10")); + assert!(compiled.contains("OFFSET 5")); + } + + #[test] + fn test_compile_cte() { + let compiled = round_trip( + "WITH active AS (SELECT * FROM users WHERE active = TRUE) SELECT * FROM active", + ); + assert!(compiled.contains("WITH ")); + assert!(compiled.contains("active AS")); + } + + #[test] + fn test_compile_insert() { + let compiled = + round_trip("INSERT INTO users (name, email) VALUES ('John', 'john@example.com')"); + 
assert!(compiled.contains("INSERT INTO")); + assert!(compiled.contains("VALUES")); + } + + #[test] + fn test_compile_update() { + let compiled = round_trip("UPDATE users SET name = 'Jane' WHERE id = 1"); + assert!(compiled.contains("UPDATE")); + assert!(compiled.contains("SET")); + assert!(compiled.contains("WHERE")); + } + + #[test] + fn test_compile_delete() { + let compiled = round_trip("DELETE FROM users WHERE id = 1"); + assert!(compiled.contains("DELETE FROM")); + assert!(compiled.contains("WHERE")); + } + + #[test] + fn test_compile_union() { + let compiled = round_trip("SELECT id FROM users UNION ALL SELECT id FROM admins"); + assert!(compiled.contains("UNION ALL")); + } + + #[test] + fn test_compile_between() { + let compiled = round_trip("SELECT * FROM products WHERE price BETWEEN 10 AND 100"); + assert!(compiled.contains("BETWEEN")); + } + + #[test] + fn test_compile_is_null() { + let compiled = round_trip("SELECT * FROM users WHERE email IS NOT NULL"); + assert!(compiled.contains("IS NOT NULL")); + } + + #[test] + fn test_compile_case() { + let compiled = + round_trip("SELECT CASE WHEN status = 'active' THEN 1 ELSE 0 END FROM users"); + assert!(compiled.contains("CASE")); + assert!(compiled.contains("WHEN")); + assert!(compiled.contains("THEN")); + assert!(compiled.contains("ELSE")); + assert!(compiled.contains("END")); + } + + #[test] + fn test_compile_window_function() { + let compiled = round_trip( + "SELECT ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC) FROM employees", + ); + assert!(compiled.contains("OVER")); + assert!(compiled.contains("PARTITION BY")); + } + + #[test] + fn test_compile_subquery_in() { + let compiled = + round_trip("SELECT * FROM users WHERE id IN (SELECT user_id FROM active_users)"); + assert!(compiled.contains("IN (")); + } + + #[test] + fn test_compile_distinct() { + let compiled = round_trip("SELECT DISTINCT name FROM users"); + assert!(compiled.contains("DISTINCT")); + } + + #[test] + fn 
test_compile_aggregate_distinct() { + let compiled = round_trip("SELECT COUNT(DISTINCT status) FROM orders"); + assert!(compiled.contains("COUNT(DISTINCT")); + } + + #[test] + fn test_round_trip_reparseable() { + let test_cases = vec![ + "SELECT * FROM users", + "SELECT id, name FROM users WHERE age > 18", + "SELECT * FROM users ORDER BY name LIMIT 10", + "INSERT INTO users (name) VALUES ('John')", + "UPDATE users SET name = 'Jane' WHERE id = 1", + "DELETE FROM users WHERE id = 1", + ]; + + for sql in test_cases { + let compiled = round_trip(sql); + // The compiled SQL should be parseable again + let reparsed = parse_single(&compiled); + assert!( + reparsed.is_ok(), + "Round-trip failed for: {} -> {} -> {:?}", + sql, + compiled, + reparsed.err() + ); + } + } +} diff --git a/src/ast/mod.rs b/src/ast/mod.rs new file mode 100644 index 0000000..b7aebc4 --- /dev/null +++ b/src/ast/mod.rs @@ -0,0 +1,35 @@ +/// Unified Query AST and Multi-Language Execution Engine. +/// +/// This module provides the foundational architecture for pgrsql's +/// query processing pipeline: +/// +/// ```text +/// Input (SQL / DSL) +/// ↓ +/// Language Adapter Layer (adapter.rs) +/// ↓ +/// Unified Query AST (types.rs) +/// ↓ +/// Analysis / Optimization (optimizer.rs) +/// ↓ +/// SQL Compiler (compiler.rs) +/// ↓ +/// Execution Engine (existing db/ module) +/// ``` +/// +/// The plugin system (plugin.rs) allows external extensions to add +/// new adapters, optimization passes, and more. 
+pub mod adapter; +pub mod compiler; +pub mod optimizer; +pub mod parser; +pub mod plugin; +pub mod types; + +// Re-export key types for convenience +pub use adapter::{AdapterRegistry, DSLAdapter, QueryLanguageAdapter}; +pub use compiler::compile; +pub use optimizer::{analyze_query, OptimizationPass, Optimizer, QueryAnalysis}; +pub use parser::{parse_single, parse_sql}; +pub use plugin::{PluginRegistry, QueryPlugin}; +pub use types::*; diff --git a/src/ast/optimizer.rs b/src/ast/optimizer.rs new file mode 100644 index 0000000..dec29af --- /dev/null +++ b/src/ast/optimizer.rs @@ -0,0 +1,384 @@ +/// AST optimization and transformation infrastructure. +/// +/// Provides a pass-based system for analyzing and transforming query ASTs. +/// Each optimization pass takes an AST, returns a potentially modified AST, +/// and preserves query semantics. Passes can be composed and ordered. +use anyhow::Result; + +use super::types::*; + +/// A single optimization or transformation pass over a query AST. +/// +/// Passes should be pure functions: given the same input, they produce +/// the same output. This makes them composable and testable. +/// +/// # Example +/// +/// ```ignore +/// struct ConstantFolding; +/// +/// impl OptimizationPass for ConstantFolding { +/// fn name(&self) -> &str { "constant_folding" } +/// fn transform(&self, query: Query) -> Result { +/// // Evaluate constant expressions at compile time +/// // e.g., WHERE 1 = 1 → (removed), WHERE 2 + 3 > 4 → WHERE TRUE +/// } +/// } +/// ``` +pub trait OptimizationPass: Send + Sync { + /// Unique name identifying this pass. + fn name(&self) -> &str; + + /// Optional description of what this pass does. + fn description(&self) -> &str { + "" + } + + /// Transform a query, returning the optimized version. + /// Returns the query unchanged if no optimization applies. + fn transform(&self, query: Query) -> Result; +} + +/// Manages and executes a pipeline of optimization passes. 
+#[derive(Default)] +pub struct Optimizer { + passes: Vec>, +} + +impl Optimizer { + pub fn new() -> Self { + Self::default() + } + + /// Create an optimizer with the default set of passes. + pub fn with_defaults() -> Self { + let mut opt = Self::new(); + opt.add_pass(Box::new(RemoveRedundantNesting)); + opt + } + + /// Add an optimization pass to the pipeline. + pub fn add_pass(&mut self, pass: Box) { + self.passes.push(pass); + } + + /// Run all optimization passes on a query in order. + pub fn optimize(&self, query: Query) -> Result { + let mut current = query; + for pass in &self.passes { + current = pass.transform(current)?; + } + Ok(current) + } + + /// List registered pass names. + pub fn pass_names(&self) -> Vec<&str> { + self.passes.iter().map(|p| p.name()).collect() + } +} + +/// Built-in pass: removes unnecessary nested/parenthesized expressions. +/// +/// Transforms `((x))` → `x` where the nesting doesn't affect semantics. +struct RemoveRedundantNesting; + +impl OptimizationPass for RemoveRedundantNesting { + fn name(&self) -> &str { + "remove_redundant_nesting" + } + + fn description(&self) -> &str { + "Removes unnecessary parenthesized expressions" + } + + fn transform(&self, query: Query) -> Result { + match query { + Query::Select(s) => Ok(Query::Select(Box::new(simplify_select(*s)))), + other => Ok(other), + } + } +} + +fn simplify_select(mut select: SelectQuery) -> SelectQuery { + select.filter = select.filter.map(simplify_expr); + select.having = select.having.map(simplify_expr); + select.projections = select + .projections + .into_iter() + .map(|item| match item { + SelectItem::Expression { expr, alias } => SelectItem::Expression { + expr: simplify_expr(expr), + alias, + }, + other => other, + }) + .collect(); + select.group_by = select.group_by.into_iter().map(simplify_expr).collect(); + select.order_by = select + .order_by + .into_iter() + .map(|o| OrderByExpr { + expr: simplify_expr(o.expr), + ..o + }) + .collect(); + select +} + +fn 
simplify_expr(expr: Expression) -> Expression { + match expr { + Expression::Nested(inner) => match *inner { + // Remove double nesting: ((x)) → x + Expression::Nested(_) => simplify_expr(*inner), + // Remove nesting around simple expressions + Expression::Column { .. } + | Expression::Literal(_) + | Expression::Wildcard + | Expression::Parameter(_) => simplify_expr(*inner), + // Keep nesting for complex expressions (may be needed for precedence) + other => Expression::Nested(Box::new(simplify_expr(other))), + }, + Expression::BinaryOp { left, op, right } => Expression::BinaryOp { + left: Box::new(simplify_expr(*left)), + op, + right: Box::new(simplify_expr(*right)), + }, + Expression::UnaryOp { op, expr } => Expression::UnaryOp { + op, + expr: Box::new(simplify_expr(*expr)), + }, + Expression::Function { + name, + args, + distinct, + } => Expression::Function { + name, + args: args.into_iter().map(simplify_expr).collect(), + distinct, + }, + other => other, + } +} + +/// Analyze a query and return metadata about its structure. +pub fn analyze_query(query: &Query) -> QueryAnalysis { + let mut analysis = QueryAnalysis::default(); + analyze_query_inner(query, &mut analysis); + analysis +} + +fn analyze_query_inner(query: &Query, analysis: &mut QueryAnalysis) { + match query { + Query::Select(s) => { + analysis.has_select = true; + if s.distinct { + analysis.has_distinct = true; + } + if !s.joins.is_empty() { + analysis.has_joins = true; + analysis.join_count += s.joins.len(); + } + if !s.group_by.is_empty() { + analysis.has_aggregation = true; + } + if s.set_op.is_some() { + analysis.has_set_operations = true; + } + if !s.windows.is_empty() { + analysis.has_window_functions = true; + } + // Check projections for window functions + for item in &s.projections { + if let SelectItem::Expression { expr, .. 
} = item { + check_expr_features(expr, analysis); + } + } + if let Some(ref filter) = s.filter { + check_expr_features(filter, analysis); + } + } + Query::With(cte) => { + analysis.has_cte = true; + if cte.recursive { + analysis.has_recursive_cte = true; + } + for c in &cte.ctes { + analyze_query_inner(&c.query, analysis); + } + analyze_query_inner(&cte.body, analysis); + } + Query::Insert(_) => analysis.has_insert = true, + Query::Update(_) => analysis.has_update = true, + Query::Delete(_) => analysis.has_delete = true, + Query::Raw(_) => {} + } +} + +fn check_expr_features(expr: &Expression, analysis: &mut QueryAnalysis) { + match expr { + Expression::WindowFunction { .. } => analysis.has_window_functions = true, + Expression::Subquery(q) | Expression::Exists(q) => { + analysis.has_subqueries = true; + analyze_query_inner(q, analysis); + } + Expression::InSubquery { subquery, .. } => { + analysis.has_subqueries = true; + analyze_query_inner(subquery, analysis); + } + Expression::Aggregate { .. } => analysis.has_aggregation = true, + Expression::JsonAccess { .. } => analysis.has_json_operations = true, + Expression::BinaryOp { left, right, .. } => { + check_expr_features(left, analysis); + check_expr_features(right, analysis); + } + Expression::UnaryOp { expr, .. } => check_expr_features(expr, analysis), + Expression::Function { args, .. } => { + for arg in args { + check_expr_features(arg, analysis); + } + } + Expression::Case { + when_clauses, + else_clause, + .. + } => { + for (w, t) in when_clauses { + check_expr_features(w, analysis); + check_expr_features(t, analysis); + } + if let Some(e) = else_clause { + check_expr_features(e, analysis); + } + } + _ => {} + } +} + +/// Structural metadata about a query. 
+#[derive(Debug, Default, Clone)] +pub struct QueryAnalysis { + pub has_select: bool, + pub has_insert: bool, + pub has_update: bool, + pub has_delete: bool, + pub has_distinct: bool, + pub has_joins: bool, + pub join_count: usize, + pub has_aggregation: bool, + pub has_window_functions: bool, + pub has_subqueries: bool, + pub has_cte: bool, + pub has_recursive_cte: bool, + pub has_set_operations: bool, + pub has_json_operations: bool, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ast::parser::parse_single; + + #[test] + fn test_optimizer_empty() { + let opt = Optimizer::new(); + let query = Query::Select(Box::new(SelectQuery::default())); + let result = opt.optimize(query.clone()).unwrap(); + assert_eq!(result, query); + } + + #[test] + fn test_optimizer_with_defaults() { + let opt = Optimizer::with_defaults(); + assert!(!opt.pass_names().is_empty()); + assert!(opt.pass_names().contains(&"remove_redundant_nesting")); + } + + #[test] + fn test_remove_redundant_nesting() { + let pass = RemoveRedundantNesting; + let query = Query::Select(Box::new(SelectQuery { + filter: Some(Expression::Nested(Box::new(Expression::Column { + table: None, + name: "x".into(), + }))), + ..Default::default() + })); + + let optimized = pass.transform(query).unwrap(); + match optimized { + Query::Select(s) => { + // The nesting around a simple column should be removed + assert!(matches!(s.filter, Some(Expression::Column { .. 
}))); + } + _ => panic!("Expected Select"), + } + } + + #[test] + fn test_analyze_simple_select() { + let q = parse_single("SELECT * FROM users").unwrap(); + let analysis = analyze_query(&q); + assert!(analysis.has_select); + assert!(!analysis.has_joins); + assert!(!analysis.has_aggregation); + } + + #[test] + fn test_analyze_join_query() { + let q = + parse_single("SELECT * FROM a JOIN b ON a.id = b.a_id LEFT JOIN c ON b.id = c.b_id") + .unwrap(); + let analysis = analyze_query(&q); + assert!(analysis.has_joins); + assert_eq!(analysis.join_count, 2); + } + + #[test] + fn test_analyze_aggregation() { + let q = parse_single("SELECT dept, COUNT(*) FROM emp GROUP BY dept").unwrap(); + let analysis = analyze_query(&q); + assert!(analysis.has_aggregation); + } + + #[test] + fn test_analyze_window_function() { + let q = parse_single("SELECT ROW_NUMBER() OVER (PARTITION BY dept ORDER BY id) FROM emp") + .unwrap(); + let analysis = analyze_query(&q); + assert!(analysis.has_window_functions); + } + + #[test] + fn test_analyze_cte() { + let q = parse_single("WITH cte AS (SELECT 1) SELECT * FROM cte").unwrap(); + let analysis = analyze_query(&q); + assert!(analysis.has_cte); + assert!(!analysis.has_recursive_cte); + } + + #[test] + fn test_analyze_recursive_cte() { + let q = parse_single( + "WITH RECURSIVE nums AS (SELECT 1 AS n UNION ALL SELECT n + 1 FROM nums WHERE n < 10) SELECT * FROM nums", + ) + .unwrap(); + let analysis = analyze_query(&q); + assert!(analysis.has_cte); + assert!(analysis.has_recursive_cte); + } + + #[test] + fn test_analyze_subquery() { + let q = + parse_single("SELECT * FROM users WHERE id IN (SELECT user_id FROM active)").unwrap(); + let analysis = analyze_query(&q); + assert!(analysis.has_subqueries); + } + + #[test] + fn test_analyze_set_operation() { + let q = parse_single("SELECT id FROM a UNION SELECT id FROM b").unwrap(); + let analysis = analyze_query(&q); + assert!(analysis.has_set_operations); + } +} diff --git a/src/ast/parser.rs 
b/src/ast/parser.rs new file mode 100644 index 0000000..32c0a3c --- /dev/null +++ b/src/ast/parser.rs @@ -0,0 +1,1159 @@ +/// SQL → Unified AST parser. +/// +/// Translates SQL text into our internal AST representation using `sqlparser` +/// as the parsing frontend. This decouples our AST from the sqlparser crate, +/// allowing us to evolve our representation independently. +use anyhow::{anyhow, Result}; +use sqlparser::ast as sp; +use sqlparser::dialect::PostgreSqlDialect; +use sqlparser::parser::Parser as SqlParser; + +use super::types::*; + +/// Parse a SQL string into our unified AST. +pub fn parse_sql(sql: &str) -> Result> { + let dialect = PostgreSqlDialect {}; + let statements = + SqlParser::parse_sql(&dialect, sql).map_err(|e| anyhow!("SQL parse error: {}", e))?; + + statements.into_iter().map(convert_statement).collect() +} + +/// Parse a single SQL statement. Returns an error if the input contains +/// more than one statement. +pub fn parse_single(sql: &str) -> Result { + let mut queries = parse_sql(sql)?; + if queries.len() != 1 { + return Err(anyhow!("Expected 1 statement, found {}", queries.len())); + } + Ok(queries.remove(0)) +} + +fn convert_statement(stmt: sp::Statement) -> Result { + match stmt { + sp::Statement::Query(q) => convert_query(*q), + sp::Statement::Insert(insert) => convert_insert(insert), + sp::Statement::Update { + table, + assignments, + selection, + returning, + .. 
+ } => convert_update(table, assignments, selection, returning), + sp::Statement::Delete(delete) => convert_delete(delete), + _ => Ok(Query::Raw(stmt.to_string())), + } +} + +fn convert_query(query: sp::Query) -> Result { + // Extract order_by exprs from Option + let order_by_exprs: Vec = + query.order_by.map(|ob| ob.exprs).unwrap_or_default(); + + // Handle CTEs + if let Some(with) = query.with { + let recursive = with.recursive; + let ctes = with + .cte_tables + .into_iter() + .map(convert_cte) + .collect::>>()?; + + let body = convert_set_expr(*query.body)?; + + // Apply ORDER BY, LIMIT, OFFSET to the body + let body = apply_query_modifiers(body, &order_by_exprs, &query.limit, &query.offset)?; + + return Ok(Query::With(CTEQuery { + recursive, + ctes, + body: Box::new(body), + })); + } + + let body = convert_set_expr(*query.body)?; + apply_query_modifiers(body, &order_by_exprs, &query.limit, &query.offset) +} + +fn apply_query_modifiers( + query: Query, + order_by: &[sp::OrderByExpr], + limit: &Option, + offset: &Option, +) -> Result { + // Only apply modifiers to Select queries + if let Query::Select(mut select) = query { + if !order_by.is_empty() { + select.order_by = order_by + .iter() + .map(|o| convert_order_by(o.clone())) + .collect::>>()?; + } + if let Some(l) = limit { + select.limit = Some(convert_expr(l.clone())?); + } + if let Some(o) = offset { + select.offset = Some(convert_expr(o.value.clone())?); + } + Ok(Query::Select(select)) + } else { + Ok(query) + } +} + +fn convert_cte(cte: sp::Cte) -> Result { + let columns = match cte.alias.columns.is_empty() { + true => vec![], + false => cte + .alias + .columns + .iter() + .map(|c| c.name.value.clone()) + .collect(), + }; + Ok(CTE { + name: cte.alias.name.value.clone(), + columns, + query: convert_query(*cte.query)?, + }) +} + +fn convert_set_expr(expr: sp::SetExpr) -> Result { + match expr { + sp::SetExpr::Select(select) => convert_select(*select), + sp::SetExpr::Query(query) => convert_query(*query), + 
sp::SetExpr::SetOperation { + op, + set_quantifier, + left, + right, + .. + } => { + let left_query = convert_set_expr(*left)?; + let right_query = convert_set_expr(*right)?; + + let all = matches!( + set_quantifier, + sp::SetQuantifier::All | sp::SetQuantifier::AllByName + ); + + let set_op = SetOperation { + op: match op { + sp::SetOperator::Union => SetOperator::Union, + sp::SetOperator::Intersect => SetOperator::Intersect, + sp::SetOperator::Except => SetOperator::Except, + }, + all, + right: right_query, + }; + + match left_query { + Query::Select(mut s) => { + s.set_op = Some(Box::new(set_op)); + Ok(Query::Select(s)) + } + other => { + // Wrap in a basic select + let s = SelectQuery { + set_op: Some(Box::new(set_op)), + from: vec![TableRef::Subquery { + query: Box::new(other), + alias: "_left".into(), + }], + ..Default::default() + }; + Ok(Query::Select(Box::new(s))) + } + } + } + sp::SetExpr::Values(values) => { + // VALUES as a standalone query - wrap in raw + Ok(Query::Raw(format!("VALUES {}", values))) + } + _ => Ok(Query::Raw(expr.to_string())), + } +} + +fn convert_select(select: sp::Select) -> Result { + let distinct = select.distinct.is_some(); + + let projections = select + .projection + .into_iter() + .map(convert_select_item) + .collect::>>()?; + + let from = select + .from + .into_iter() + .map(convert_table_with_joins) + .collect::>>()?; + + // Flatten: first element is the table, rest are joins + let (tables, join_lists): (Vec<_>, Vec<_>) = from.into_iter().unzip(); + + let joins: Vec = join_lists.into_iter().flatten().collect(); + + let filter = select.selection.map(convert_expr).transpose()?; + + let group_by = match select.group_by { + sp::GroupByExpr::Expressions(exprs, _modifiers) => exprs + .into_iter() + .map(convert_expr) + .collect::>>()?, + sp::GroupByExpr::All(_) => vec![], + }; + + let having = select.having.map(convert_expr).transpose()?; + + let windows = select + .named_window + .into_iter() + .map(|nw| { + let spec = 
convert_window_spec_from_named(&nw.1); + Ok(NamedWindowSpec { + name: nw.0.value.clone(), + spec: spec?, + }) + }) + .collect::>>()?; + + Ok(Query::Select(Box::new(SelectQuery { + distinct, + projections, + from: tables, + joins, + filter, + group_by, + having, + windows, + order_by: vec![], + limit: None, + offset: None, + set_op: None, + }))) +} + +fn convert_table_with_joins(twj: sp::TableWithJoins) -> Result<(TableRef, Vec)> { + let table = convert_table_factor(twj.relation)?; + let joins = twj + .joins + .into_iter() + .map(convert_join) + .collect::>>()?; + Ok((table, joins)) +} + +fn convert_table_factor(tf: sp::TableFactor) -> Result { + match tf { + sp::TableFactor::Table { name, alias, .. } => { + let parts: Vec<&str> = name.0.iter().map(|p| p.value.as_str()).collect(); + let (schema, table_name) = match parts.len() { + 1 => (None, parts[0].to_string()), + 2 => (Some(parts[0].to_string()), parts[1].to_string()), + _ => (None, name.to_string()), + }; + Ok(TableRef::Table { + schema, + name: table_name, + alias: alias.map(|a| a.name.value), + }) + } + sp::TableFactor::Derived { + subquery, alias, .. 
+ } => { + let alias_name = alias + .map(|a| a.name.value) + .unwrap_or_else(|| "_subquery".into()); + Ok(TableRef::Subquery { + query: Box::new(convert_query(*subquery)?), + alias: alias_name, + }) + } + sp::TableFactor::TableFunction { expr, alias } => Ok(TableRef::Function { + name: expr.to_string(), + args: vec![], + alias: alias.map(|a| a.name.value), + }), + _ => Ok(TableRef::Table { + schema: None, + name: tf.to_string(), + alias: None, + }), + } +} + +fn convert_join(join: sp::Join) -> Result { + let join_type = match &join.join_operator { + sp::JoinOperator::Inner(_) => JoinType::Inner, + sp::JoinOperator::LeftOuter(_) => JoinType::Left, + sp::JoinOperator::RightOuter(_) => JoinType::Right, + sp::JoinOperator::FullOuter(_) => JoinType::Full, + sp::JoinOperator::CrossJoin => JoinType::Cross, + _ => JoinType::Inner, + }; + + let condition = match &join.join_operator { + sp::JoinOperator::Inner(c) + | sp::JoinOperator::LeftOuter(c) + | sp::JoinOperator::RightOuter(c) + | sp::JoinOperator::FullOuter(c) => convert_join_constraint(c)?, + _ => None, + }; + + Ok(Join { + join_type, + table: convert_table_factor(join.relation)?, + condition, + }) +} + +fn convert_join_constraint(constraint: &sp::JoinConstraint) -> Result> { + match constraint { + sp::JoinConstraint::On(expr) => Ok(Some(JoinCondition::On(convert_expr(expr.clone())?))), + sp::JoinConstraint::Using(cols) => Ok(Some(JoinCondition::Using( + cols.iter().map(|c| c.value.clone()).collect(), + ))), + sp::JoinConstraint::Natural => Ok(Some(JoinCondition::Natural)), + sp::JoinConstraint::None => Ok(None), + } +} + +fn convert_select_item(item: sp::SelectItem) -> Result { + match item { + sp::SelectItem::UnnamedExpr(expr) => Ok(SelectItem::Expression { + expr: convert_expr(expr)?, + alias: None, + }), + sp::SelectItem::ExprWithAlias { expr, alias } => Ok(SelectItem::Expression { + expr: convert_expr(expr)?, + alias: Some(alias.value), + }), + sp::SelectItem::Wildcard(_) => Ok(SelectItem::Wildcard), + 
sp::SelectItem::QualifiedWildcard(name, _) => { + Ok(SelectItem::QualifiedWildcard(name.to_string())) + } + } +} + +fn convert_expr(expr: sp::Expr) -> Result { + match expr { + sp::Expr::Identifier(ident) => Ok(Expression::Column { + table: None, + name: ident.value, + }), + sp::Expr::CompoundIdentifier(parts) => { + let names: Vec = parts.into_iter().map(|p| p.value).collect(); + match names.len() { + 1 => Ok(Expression::Column { + table: None, + name: names.into_iter().next().unwrap(), + }), + 2 => { + let mut iter = names.into_iter(); + Ok(Expression::Column { + table: Some(iter.next().unwrap()), + name: iter.next().unwrap(), + }) + } + _ => Ok(Expression::Column { + table: None, + name: names.join("."), + }), + } + } + sp::Expr::Value(val) => convert_value(val), + sp::Expr::BinaryOp { left, op, right } => Ok(Expression::BinaryOp { + left: Box::new(convert_expr(*left)?), + op: convert_binary_op(op)?, + right: Box::new(convert_expr(*right)?), + }), + sp::Expr::UnaryOp { op, expr } => Ok(Expression::UnaryOp { + op: convert_unary_op(op)?, + expr: Box::new(convert_expr(*expr)?), + }), + sp::Expr::Function(func) => convert_function(func), + sp::Expr::Case { + operand, + conditions, + results, + else_result, + } => { + let when_clauses = conditions + .into_iter() + .zip(results) + .map(|(c, r)| Ok((convert_expr(c)?, convert_expr(r)?))) + .collect::>>()?; + Ok(Expression::Case { + operand: operand.map(|o| convert_expr(*o)).transpose()?.map(Box::new), + when_clauses, + else_clause: else_result + .map(|e| convert_expr(*e)) + .transpose()? 
+ .map(Box::new), + }) + } + sp::Expr::Subquery(q) => Ok(Expression::Subquery(Box::new(convert_query(*q)?))), + sp::Expr::Exists { subquery, negated } => { + let exists = Expression::Exists(Box::new(convert_query(*subquery)?)); + if negated { + Ok(Expression::UnaryOp { + op: UnaryOperator::Not, + expr: Box::new(exists), + }) + } else { + Ok(exists) + } + } + sp::Expr::InList { + expr, + list, + negated, + } => Ok(Expression::InList { + expr: Box::new(convert_expr(*expr)?), + list: list + .into_iter() + .map(convert_expr) + .collect::>>()?, + negated, + }), + sp::Expr::InSubquery { + expr, + subquery, + negated, + } => Ok(Expression::InSubquery { + expr: Box::new(convert_expr(*expr)?), + subquery: Box::new(convert_query(*subquery)?), + negated, + }), + sp::Expr::Between { + expr, + negated, + low, + high, + } => Ok(Expression::Between { + expr: Box::new(convert_expr(*expr)?), + low: Box::new(convert_expr(*low)?), + high: Box::new(convert_expr(*high)?), + negated, + }), + sp::Expr::IsNull(expr) => Ok(Expression::IsNull { + expr: Box::new(convert_expr(*expr)?), + negated: false, + }), + sp::Expr::IsNotNull(expr) => Ok(Expression::IsNull { + expr: Box::new(convert_expr(*expr)?), + negated: true, + }), + sp::Expr::Cast { + expr, data_type, .. + } => Ok(Expression::Cast { + expr: Box::new(convert_expr(*expr)?), + data_type: data_type.to_string(), + }), + sp::Expr::Nested(expr) => Ok(Expression::Nested(Box::new(convert_expr(*expr)?))), + sp::Expr::Like { + negated, + expr, + pattern, + .. + } => { + let op = if negated { + BinaryOperator::NotLike + } else { + BinaryOperator::Like + }; + Ok(Expression::BinaryOp { + left: Box::new(convert_expr(*expr)?), + op, + right: Box::new(convert_expr(*pattern)?), + }) + } + sp::Expr::ILike { + negated, + expr, + pattern, + .. 
+ } => { + let op = if negated { + BinaryOperator::NotILike + } else { + BinaryOperator::ILike + }; + Ok(Expression::BinaryOp { + left: Box::new(convert_expr(*expr)?), + op, + right: Box::new(convert_expr(*pattern)?), + }) + } + sp::Expr::Array(arr) => { + let elems = arr + .elem + .into_iter() + .map(convert_expr) + .collect::>>()?; + Ok(Expression::Array(elems)) + } + sp::Expr::JsonAccess { value, path } => convert_json_access(*value, path), + _ => { + // Fallback: store as a literal string representation + Ok(Expression::Literal(Literal::String(expr.to_string()))) + } + } +} + +fn convert_json_access(value: sp::Expr, path: sp::JsonPath) -> Result { + let base = convert_expr(value)?; + let mut current = base; + + for element in path.path { + match element { + sp::JsonPathElem::Dot { key, .. } => { + current = Expression::JsonAccess { + expr: Box::new(current), + path: Box::new(Expression::Literal(Literal::String(key))), + as_text: false, + }; + } + sp::JsonPathElem::Bracket { key } => { + current = Expression::JsonAccess { + expr: Box::new(current), + path: Box::new(convert_expr(key)?), + as_text: false, + }; + } + } + } + + Ok(current) +} + +fn convert_value(val: sp::Value) -> Result { + match val { + sp::Value::Null => Ok(Expression::Literal(Literal::Null)), + sp::Value::Boolean(b) => Ok(Expression::Literal(Literal::Boolean(b))), + sp::Value::Number(n, _) => { + if let Ok(i) = n.parse::() { + Ok(Expression::Literal(Literal::Integer(i))) + } else if let Ok(f) = n.parse::() { + Ok(Expression::Literal(Literal::Float(f))) + } else { + Ok(Expression::Literal(Literal::String(n))) + } + } + sp::Value::SingleQuotedString(s) => Ok(Expression::Literal(Literal::String(s))), + sp::Value::DoubleQuotedString(s) => Ok(Expression::Literal(Literal::String(s))), + sp::Value::Placeholder(p) => { + // Parse $1, $2, etc. 
+ if let Some(n) = p.strip_prefix('$') { + if let Ok(idx) = n.parse::() { + return Ok(Expression::Parameter(idx)); + } + } + Ok(Expression::Literal(Literal::String(p))) + } + _ => Ok(Expression::Literal(Literal::String(val.to_string()))), + } +} + +fn convert_binary_op(op: sp::BinaryOperator) -> Result { + match op { + sp::BinaryOperator::Eq => Ok(BinaryOperator::Eq), + sp::BinaryOperator::NotEq => Ok(BinaryOperator::NotEq), + sp::BinaryOperator::Lt => Ok(BinaryOperator::Lt), + sp::BinaryOperator::LtEq => Ok(BinaryOperator::LtEq), + sp::BinaryOperator::Gt => Ok(BinaryOperator::Gt), + sp::BinaryOperator::GtEq => Ok(BinaryOperator::GtEq), + sp::BinaryOperator::And => Ok(BinaryOperator::And), + sp::BinaryOperator::Or => Ok(BinaryOperator::Or), + sp::BinaryOperator::Plus => Ok(BinaryOperator::Plus), + sp::BinaryOperator::Minus => Ok(BinaryOperator::Minus), + sp::BinaryOperator::Multiply => Ok(BinaryOperator::Multiply), + sp::BinaryOperator::Divide => Ok(BinaryOperator::Divide), + sp::BinaryOperator::Modulo => Ok(BinaryOperator::Modulo), + sp::BinaryOperator::StringConcat => Ok(BinaryOperator::Concat), + _ => Err(anyhow!("Unsupported binary operator: {:?}", op)), + } +} + +fn convert_unary_op(op: sp::UnaryOperator) -> Result { + match op { + sp::UnaryOperator::Not => Ok(UnaryOperator::Not), + sp::UnaryOperator::Minus => Ok(UnaryOperator::Minus), + sp::UnaryOperator::Plus => Ok(UnaryOperator::Plus), + _ => Err(anyhow!("Unsupported unary operator: {:?}", op)), + } +} + +fn convert_function(func: sp::Function) -> Result { + let name = func.name.to_string().to_uppercase(); + + let (args, distinct) = match func.args { + sp::FunctionArguments::List(arg_list) => { + let distinct = matches!( + arg_list.duplicate_treatment, + Some(sp::DuplicateTreatment::Distinct) + ); + let args = arg_list + .args + .into_iter() + .filter_map(|a| match a { + sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Expr(e)) => Some(convert_expr(e)), + sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Wildcard) 
=> { + Some(Ok(Expression::Wildcard)) + } + sp::FunctionArg::Named { + arg: sp::FunctionArgExpr::Expr(e), + .. + } => Some(convert_expr(e)), + _ => None, + }) + .collect::>>()?; + (args, distinct) + } + sp::FunctionArguments::None => (vec![], false), + sp::FunctionArguments::Subquery(q) => ( + vec![Expression::Subquery(Box::new(convert_query(*q)?))], + false, + ), + }; + + // Check if this is a window function + if let Some(over) = func.over { + let window = match over { + sp::WindowType::WindowSpec(spec) => convert_window_spec(spec)?, + sp::WindowType::NamedWindow(_name) => { + // Reference to a named window - use empty spec as placeholder + WindowSpec { + partition_by: vec![], + order_by: vec![], + frame: None, + } + } + }; + + let function = Expression::Function { + name, + args, + distinct, + }; + + return Ok(Expression::WindowFunction { + function: Box::new(function), + window, + }); + } + + // Check if it's an aggregate function + let is_aggregate = matches!( + name.as_str(), + "COUNT" + | "SUM" + | "AVG" + | "MIN" + | "MAX" + | "ARRAY_AGG" + | "STRING_AGG" + | "BOOL_AND" + | "BOOL_OR" + ); + + if is_aggregate { + Ok(Expression::Aggregate { + name, + args, + distinct, + filter: None, + }) + } else { + Ok(Expression::Function { + name, + args, + distinct, + }) + } +} + +fn convert_window_spec(spec: sp::WindowSpec) -> Result { + let partition_by = spec + .partition_by + .into_iter() + .map(convert_expr) + .collect::>>()?; + + let order_by = spec + .order_by + .into_iter() + .map(convert_order_by) + .collect::>>()?; + + let frame = spec.window_frame.map(convert_window_frame).transpose()?; + + Ok(WindowSpec { + partition_by, + order_by, + frame, + }) +} + +fn convert_window_spec_from_named(spec: &sp::NamedWindowExpr) -> Result { + match spec { + sp::NamedWindowExpr::NamedWindow(_ident) => Ok(WindowSpec { + partition_by: vec![], + order_by: vec![], + frame: None, + }), + sp::NamedWindowExpr::WindowSpec(spec) => convert_window_spec(spec.clone()), + } +} + +fn 
convert_window_frame(frame: sp::WindowFrame) -> Result<WindowFrame> {
    // Map the frame unit keyword (ROWS / RANGE / GROUPS) onto our own enum.
    let mode = match frame.units {
        sp::WindowFrameUnits::Rows => WindowFrameMode::Rows,
        sp::WindowFrameUnits::Range => WindowFrameMode::Range,
        sp::WindowFrameUnits::Groups => WindowFrameMode::Groups,
    };

    let start = convert_window_frame_bound(frame.start_bound)?;
    let end = frame
        .end_bound
        .map(convert_window_frame_bound)
        .transpose()?;

    Ok(WindowFrame { mode, start, end })
}

/// Convert one frame bound (`CURRENT ROW`, `<n> PRECEDING`, `<n> FOLLOWING`).
///
/// `None` inside `Preceding`/`Following` means the offset was absent in the
/// parsed bound (the `UNBOUNDED` form).
///
/// # Errors
///
/// Returns an error when an offset is present but cannot be represented
/// (see `parse_frame_offset`). Falling back to `None` or `0` would silently
/// turn the bound into a different frame.
fn convert_window_frame_bound(bound: sp::WindowFrameBound) -> Result<WindowFrameBound> {
    match bound {
        sp::WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
        sp::WindowFrameBound::Preceding(offset) => Ok(WindowFrameBound::Preceding(
            offset.map(|e| parse_frame_offset(*e)).transpose()?,
        )),
        sp::WindowFrameBound::Following(offset) => Ok(WindowFrameBound::Following(
            offset.map(|e| parse_frame_offset(*e)).transpose()?,
        )),
    }
}

/// Parse a frame offset expression into the numeric type used by our
/// `WindowFrameBound`; only plain numeric literals are supported.
fn parse_frame_offset<T: std::str::FromStr>(expr: sp::Expr) -> Result<T> {
    match expr {
        sp::Expr::Value(sp::Value::Number(n, _)) => n
            .parse()
            .map_err(|_| anyhow!("Unsupported window frame offset: {}", n)),
        other => Err(anyhow!("Unsupported window frame offset: {}", other)),
    }
}

/// Convert an `ORDER BY` item; `asc` and `nulls_first` are copied through as parsed.
fn convert_order_by(order: sp::OrderByExpr) -> Result<OrderByExpr> {
    Ok(OrderByExpr {
        expr: convert_expr(order.expr)?,
        asc: order.asc,
        nulls_first: order.nulls_first,
    })
}

/// Convert an `INSERT` statement.
///
/// * A `VALUES` source becomes `InsertSource::Values`, one inner `Vec` per row.
/// * Any other source goes through `convert_query` as a whole, so its
///   `WITH`, `ORDER BY`, `LIMIT` and `OFFSET` are kept.
/// * A statement without a source yields an empty `InsertSource::Values`.
///
/// The full (possibly schema-qualified) table name is stored in `name`;
/// `schema` is left as `None`.
fn convert_insert(insert: sp::Insert) -> Result<Query> {
    let table = TableRef::Table {
        schema: None,
        name: insert.table_name.to_string(),
        alias: None,
    };

    let columns: Vec<String> = insert.columns.iter().map(|c| c.value.clone()).collect();

    let source = match insert.source {
        None => InsertSource::Values(vec![]),
        Some(src) => {
            let src = *src;
            if matches!(*src.body, sp::SetExpr::Values(_)) {
                match *src.body {
                    sp::SetExpr::Values(values) => {
                        let rows = values
                            .rows
                            .into_iter()
                            .map(|row| {
                                row.into_iter()
                                    .map(convert_expr)
                                    .collect::<Result<Vec<_>>>()
                            })
                            .collect::<Result<Vec<_>>>()?;
                        InsertSource::Values(rows)
                    }
                    _ => unreachable!("body was just checked to be VALUES"),
                }
            } else {
                // Convert the whole query, not just its body: the body alone
                // has no WITH clause and no ORDER BY / LIMIT / OFFSET.
                InsertSource::Query(Box::new(convert_query(src)?))
            }
        }
    };

    let returning = insert
        .returning
        .unwrap_or_default()
        .into_iter()
        .map(convert_select_item)
        .collect::<Result<Vec<_>>>()?;

    Ok(Query::Insert(InsertQuery {
        table,
        columns,
        source,
        returning,
    }))
}

/// Convert an `UPDATE` statement from its already-destructured parts.
///
/// Only `table.relation` is converted; joins attached to the target table are
/// not carried over. Each assignment target is stored via its `to_string()` form.
fn convert_update(
    table: sp::TableWithJoins,
    assignments: Vec<sp::Assignment>,
    selection: Option<sp::Expr>,
    returning: Option<Vec<sp::SelectItem>>,
) -> Result<Query> {
    let table_ref = convert_table_factor(table.relation)?;

    let assigns = assignments
        .into_iter()
        .map(|a| {
            let column = a.target.to_string();
            Ok(Assignment {
                column,
                value: convert_expr(a.value)?,
            })
        })
        .collect::<Result<Vec<_>>>()?;

    let filter = selection.map(convert_expr).transpose()?;

    let ret = returning
        .unwrap_or_default()
        .into_iter()
        .map(convert_select_item)
        .collect::<Result<Vec<_>>>()?;

    Ok(Query::Update(UpdateQuery {
        table: table_ref,
        assignments: assigns,
        filter,
        returning: ret,
    }))
}

fn convert_delete(delete: sp::Delete) -> Result<Query> {
    // Extract tables from FromTable enum
    let from_tables = match delete.from {
        sp::FromTable::WithFromKeyword(tables) => tables,
        sp::FromTable::WithoutKeyword(tables) => tables,
    };

    // Only the first listed table is used as the delete target.
    let table_ref = if let Some(twj) = from_tables.into_iter().next() {
        convert_table_factor(twj.relation)?
+ } else { + return Err(anyhow!("DELETE without table reference")); + }; + + let filter = delete.selection.map(convert_expr).transpose()?; + + let returning = delete + .returning + .unwrap_or_default() + .into_iter() + .map(convert_select_item) + .collect::>>()?; + + Ok(Query::Delete(DeleteQuery { + table: table_ref, + filter, + returning, + })) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_simple_select() { + let q = parse_single("SELECT * FROM users").unwrap(); + match q { + Query::Select(s) => { + assert_eq!(s.projections.len(), 1); + assert!(matches!(s.projections[0], SelectItem::Wildcard)); + assert_eq!(s.from.len(), 1); + } + _ => panic!("Expected Select query"), + } + } + + #[test] + fn test_parse_select_with_where() { + let q = parse_single("SELECT id, name FROM users WHERE age > 18").unwrap(); + match q { + Query::Select(s) => { + assert_eq!(s.projections.len(), 2); + assert!(s.filter.is_some()); + } + _ => panic!("Expected Select query"), + } + } + + #[test] + fn test_parse_select_with_join() { + let q = + parse_single("SELECT u.name, o.total FROM users u JOIN orders o ON u.id = o.user_id") + .unwrap(); + match q { + Query::Select(s) => { + assert_eq!(s.joins.len(), 1); + assert!(matches!(s.joins[0].join_type, JoinType::Inner)); + } + _ => panic!("Expected Select query"), + } + } + + #[test] + fn test_parse_select_with_group_by() { + let q = parse_single( + "SELECT department, COUNT(*) FROM employees GROUP BY department HAVING COUNT(*) > 5", + ) + .unwrap(); + match q { + Query::Select(s) => { + assert_eq!(s.group_by.len(), 1); + assert!(s.having.is_some()); + } + _ => panic!("Expected Select query"), + } + } + + #[test] + fn test_parse_cte() { + let q = parse_single( + "WITH active AS (SELECT * FROM users WHERE active = true) SELECT * FROM active", + ) + .unwrap(); + match q { + Query::With(cte) => { + assert!(!cte.recursive); + assert_eq!(cte.ctes.len(), 1); + assert_eq!(cte.ctes[0].name, "active"); + } + _ => 
panic!("Expected CTE query"), + } + } + + #[test] + fn test_parse_recursive_cte() { + let q = parse_single( + "WITH RECURSIVE nums AS (SELECT 1 AS n UNION ALL SELECT n + 1 FROM nums WHERE n < 10) SELECT * FROM nums", + ) + .unwrap(); + match q { + Query::With(cte) => { + assert!(cte.recursive); + } + _ => panic!("Expected CTE query"), + } + } + + #[test] + fn test_parse_window_function() { + let q = parse_single( + "SELECT name, ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC) FROM employees", + ) + .unwrap(); + match q { + Query::Select(s) => { + assert_eq!(s.projections.len(), 2); + match &s.projections[1] { + SelectItem::Expression { expr, .. } => { + assert!(matches!(expr, Expression::WindowFunction { .. })); + } + _ => panic!("Expected window function expression"), + } + } + _ => panic!("Expected Select query"), + } + } + + #[test] + fn test_parse_insert() { + let q = parse_single("INSERT INTO users (name, email) VALUES ('John', 'john@example.com')") + .unwrap(); + match q { + Query::Insert(i) => { + assert_eq!(i.columns.len(), 2); + match &i.source { + InsertSource::Values(rows) => assert_eq!(rows.len(), 1), + _ => panic!("Expected values source"), + } + } + _ => panic!("Expected Insert query"), + } + } + + #[test] + fn test_parse_update() { + let q = parse_single("UPDATE users SET name = 'Jane' WHERE id = 1").unwrap(); + match q { + Query::Update(u) => { + assert_eq!(u.assignments.len(), 1); + assert_eq!(u.assignments[0].column, "name"); + assert!(u.filter.is_some()); + } + _ => panic!("Expected Update query"), + } + } + + #[test] + fn test_parse_delete() { + let q = parse_single("DELETE FROM users WHERE id = 1").unwrap(); + match q { + Query::Delete(d) => { + assert!(d.filter.is_some()); + } + _ => panic!("Expected Delete query"), + } + } + + #[test] + fn test_parse_subquery() { + let q = parse_single("SELECT * FROM users WHERE id IN (SELECT user_id FROM active_users)") + .unwrap(); + match q { + Query::Select(s) => { + 
assert!(s.filter.is_some()); + match s.filter.unwrap() { + Expression::InSubquery { negated, .. } => { + assert!(!negated); + } + _ => panic!("Expected InSubquery expression"), + } + } + _ => panic!("Expected Select query"), + } + } + + #[test] + fn test_parse_multiple_statements() { + let queries = parse_sql("SELECT 1; SELECT 2").unwrap(); + assert_eq!(queries.len(), 2); + } + + #[test] + fn test_parse_invalid_sql() { + assert!(parse_single("SELCT * FORM users").is_err()); + } + + #[test] + fn test_parse_union() { + let q = parse_single("SELECT id FROM users UNION ALL SELECT id FROM admins").unwrap(); + match q { + Query::Select(s) => { + assert!(s.set_op.is_some()); + let set_op = s.set_op.unwrap(); + assert!(matches!(set_op.op, SetOperator::Union)); + assert!(set_op.all); + } + _ => panic!("Expected Select query"), + } + } + + #[test] + fn test_parse_order_by_limit() { + let q = parse_single("SELECT * FROM users ORDER BY name ASC LIMIT 10 OFFSET 5").unwrap(); + match q { + Query::Select(s) => { + assert_eq!(s.order_by.len(), 1); + assert_eq!(s.order_by[0].asc, Some(true)); + assert!(s.limit.is_some()); + assert!(s.offset.is_some()); + } + _ => panic!("Expected Select query"), + } + } + + #[test] + fn test_parse_between() { + let q = parse_single("SELECT * FROM products WHERE price BETWEEN 10 AND 100").unwrap(); + match q { + Query::Select(s) => { + assert!(matches!( + s.filter, + Some(Expression::Between { negated: false, .. }) + )); + } + _ => panic!("Expected Select query"), + } + } + + #[test] + fn test_parse_case_expression() { + let q = parse_single("SELECT CASE WHEN status = 'active' THEN 1 ELSE 0 END FROM users") + .unwrap(); + match q { + Query::Select(s) => match &s.projections[0] { + SelectItem::Expression { expr, .. } => { + assert!(matches!(expr, Expression::Case { .. 
})); + } + _ => panic!("Expected expression"), + }, + _ => panic!("Expected Select query"), + } + } + + #[test] + fn test_parse_is_null() { + let q = parse_single("SELECT * FROM users WHERE email IS NOT NULL").unwrap(); + match q { + Query::Select(s) => { + assert!(matches!( + s.filter, + Some(Expression::IsNull { negated: true, .. }) + )); + } + _ => panic!("Expected Select query"), + } + } + + #[test] + fn test_parse_cast() { + let q = parse_single("SELECT CAST(price AS INTEGER) FROM products").unwrap(); + match q { + Query::Select(s) => match &s.projections[0] { + SelectItem::Expression { expr, .. } => { + assert!(matches!(expr, Expression::Cast { .. })); + } + _ => panic!("Expected expression"), + }, + _ => panic!("Expected Select query"), + } + } + + #[test] + fn test_parse_aggregate_distinct() { + let q = parse_single("SELECT COUNT(DISTINCT status) FROM orders").unwrap(); + match q { + Query::Select(s) => match &s.projections[0] { + SelectItem::Expression { expr, .. } => match expr { + Expression::Aggregate { distinct, name, .. } => { + assert!(distinct); + assert_eq!(name, "COUNT"); + } + _ => panic!("Expected aggregate"), + }, + _ => panic!("Expected expression"), + }, + _ => panic!("Expected Select query"), + } + } + + #[test] + fn test_parse_left_join() { + let q = parse_single("SELECT * FROM a LEFT JOIN b ON a.id = b.a_id").unwrap(); + match q { + Query::Select(s) => { + assert_eq!(s.joins.len(), 1); + assert!(matches!(s.joins[0].join_type, JoinType::Left)); + } + _ => panic!("Expected Select query"), + } + } +} diff --git a/src/ast/plugin.rs b/src/ast/plugin.rs new file mode 100644 index 0000000..8a0552e --- /dev/null +++ b/src/ast/plugin.rs @@ -0,0 +1,164 @@ +/// Plugin architecture for extending pgrsql. +/// +/// Plugins can register new query languages, optimization passes, +/// custom SQL functions, and execution strategies. This provides +/// a clean extension point without modifying core code. 
+use anyhow::Result; + +use super::adapter::{DSLAdapter, QueryLanguageAdapter}; +use super::optimizer::OptimizationPass; + +/// Trait that all pgrsql plugins must implement. +/// +/// A plugin registers its capabilities with the `PluginRegistry` +/// during initialization. Plugins are loaded and initialized once +/// at startup. +/// +/// # Example +/// +/// ```ignore +/// struct MyPlugin; +/// +/// impl QueryPlugin for MyPlugin { +/// fn name(&self) -> &str { "my-plugin" } +/// fn version(&self) -> &str { "0.1.0" } +/// fn register(&self, registry: &mut PluginRegistry) -> Result<()> { +/// registry.add_optimization_pass(Box::new(MyOptPass)); +/// Ok(()) +/// } +/// } +/// ``` +pub trait QueryPlugin: Send + Sync { + /// Unique plugin identifier. + fn name(&self) -> &str; + + /// Plugin version string. + fn version(&self) -> &str; + + /// Optional description. + fn description(&self) -> &str { + "" + } + + /// Register plugin capabilities with the registry. + fn register(&self, registry: &mut PluginRegistry) -> Result<()>; +} + +/// Central registry for all plugin-provided capabilities. +#[derive(Default)] +pub struct PluginRegistry { + query_adapters: Vec>, + dsl_adapters: Vec>, + optimization_passes: Vec>, + loaded_plugins: Vec, +} + +#[derive(Debug, Clone)] +pub struct PluginInfo { + pub name: String, + pub version: String, + pub description: String, +} + +impl PluginRegistry { + pub fn new() -> Self { + Self::default() + } + + /// Register a query language adapter. + pub fn add_query_adapter(&mut self, adapter: Box) { + self.query_adapters.push(adapter); + } + + /// Register a DSL adapter. + pub fn add_dsl_adapter(&mut self, adapter: Box) { + self.dsl_adapters.push(adapter); + } + + /// Register an optimization pass. + pub fn add_optimization_pass(&mut self, pass: Box) { + self.optimization_passes.push(pass); + } + + /// Load and initialize a plugin. 
+ pub fn load_plugin(&mut self, plugin: Box) -> Result<()> { + let info = PluginInfo { + name: plugin.name().to_string(), + version: plugin.version().to_string(), + description: plugin.description().to_string(), + }; + + plugin.register(self)?; + self.loaded_plugins.push(info); + Ok(()) + } + + /// Get all registered query adapters. + pub fn query_adapters(&self) -> &[Box] { + &self.query_adapters + } + + /// Get all registered DSL adapters. + pub fn dsl_adapters(&self) -> &[Box] { + &self.dsl_adapters + } + + /// Take ownership of all optimization passes (for building an optimizer). + pub fn take_optimization_passes(&mut self) -> Vec> { + std::mem::take(&mut self.optimization_passes) + } + + /// List loaded plugins. + pub fn loaded_plugins(&self) -> &[PluginInfo] { + &self.loaded_plugins + } +} + +#[cfg(test)] +mod tests { + use super::*; + + struct TestPlugin; + + impl QueryPlugin for TestPlugin { + fn name(&self) -> &str { + "test-plugin" + } + + fn version(&self) -> &str { + "0.1.0" + } + + fn description(&self) -> &str { + "A test plugin" + } + + fn register(&self, _registry: &mut PluginRegistry) -> Result<()> { + Ok(()) + } + } + + #[test] + fn test_registry_empty() { + let registry = PluginRegistry::new(); + assert!(registry.loaded_plugins().is_empty()); + assert!(registry.query_adapters().is_empty()); + } + + #[test] + fn test_load_plugin() { + let mut registry = PluginRegistry::new(); + registry.load_plugin(Box::new(TestPlugin)).unwrap(); + assert_eq!(registry.loaded_plugins().len(), 1); + assert_eq!(registry.loaded_plugins()[0].name, "test-plugin"); + assert_eq!(registry.loaded_plugins()[0].version, "0.1.0"); + } + + #[test] + fn test_take_optimization_passes() { + let mut registry = PluginRegistry::new(); + // Initially empty + let passes = registry.take_optimization_passes(); + assert!(passes.is_empty()); + } +} diff --git a/src/ast/types.rs b/src/ast/types.rs new file mode 100644 index 0000000..7901efe --- /dev/null +++ b/src/ast/types.rs @@ -0,0 
+1,428 @@ +//! Unified Query AST types for pgrsql. +//! +//! This module defines the internal representation used by all language adapters, +//! optimization passes, and the SQL compiler. The AST is designed to be: +//! - Language-agnostic (any DSL can compile to it) +//! - Immutable-friendly (clone-based transformations) +//! - Extensible (new node types can be added without breaking existing passes) + +/// Top-level query representation. +#[derive(Debug, Clone, PartialEq)] +pub enum Query { + Select(Box), + Insert(InsertQuery), + Update(UpdateQuery), + Delete(DeleteQuery), + /// Common Table Expressions wrapping an inner query. + With(CTEQuery), + /// Raw SQL passthrough for unsupported or complex statements. + Raw(String), +} + +/// A SELECT query with all standard SQL clauses. +#[derive(Debug, Clone, PartialEq, Default)] +pub struct SelectQuery { + pub distinct: bool, + pub projections: Vec, + pub from: Vec, + pub joins: Vec, + pub filter: Option, + pub group_by: Vec, + pub having: Option, + pub windows: Vec, + pub order_by: Vec, + pub limit: Option, + pub offset: Option, + /// Set operations (UNION, INTERSECT, EXCEPT). + pub set_op: Option>, +} + +/// A single item in the SELECT projection list. +#[derive(Debug, Clone, PartialEq)] +pub enum SelectItem { + /// `*` + Wildcard, + /// `table.*` + QualifiedWildcard(String), + /// An expression, optionally aliased: `expr AS alias`. + Expression { + expr: Expression, + alias: Option, + }, +} + +/// Table reference in FROM clause. +#[derive(Debug, Clone, PartialEq)] +pub enum TableRef { + /// Simple table: `schema.table AS alias` + Table { + schema: Option, + name: String, + alias: Option, + }, + /// Subquery: `(SELECT ...) AS alias` + Subquery { query: Box, alias: String }, + /// Table-valued function: `generate_series(1, 10) AS alias` + Function { + name: String, + args: Vec, + alias: Option, + }, +} + +/// JOIN clause representation. 
+#[derive(Debug, Clone, PartialEq)] +pub struct Join { + pub join_type: JoinType, + pub table: TableRef, + pub condition: Option, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum JoinType { + Inner, + Left, + Right, + Full, + Cross, + Lateral, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum JoinCondition { + On(Expression), + Using(Vec), + Natural, +} + +/// Core expression type. Recursive to support arbitrary nesting. +#[derive(Debug, Clone, PartialEq)] +pub enum Expression { + /// Column reference: `table.column` or just `column`. + Column { table: Option, name: String }, + /// Literal value. + Literal(Literal), + /// Binary operation: `left op right`. + BinaryOp { + left: Box, + op: BinaryOperator, + right: Box, + }, + /// Unary operation: `op expr` (e.g., NOT, -). + UnaryOp { + op: UnaryOperator, + expr: Box, + }, + /// Function call: `name(args)`. + Function { + name: String, + args: Vec, + distinct: bool, + }, + /// Aggregate function with optional filter. + Aggregate { + name: String, + args: Vec, + distinct: bool, + filter: Option>, + }, + /// Window function: `expr OVER (...)`. + WindowFunction { + function: Box, + window: WindowSpec, + }, + /// CASE expression. + Case { + operand: Option>, + when_clauses: Vec<(Expression, Expression)>, + else_clause: Option>, + }, + /// Subquery expression: `(SELECT ...)`. + Subquery(Box), + /// EXISTS (SELECT ...). + Exists(Box), + /// expr IN (values or subquery). + InList { + expr: Box, + list: Vec, + negated: bool, + }, + InSubquery { + expr: Box, + subquery: Box, + negated: bool, + }, + /// expr BETWEEN low AND high. + Between { + expr: Box, + low: Box, + high: Box, + negated: bool, + }, + /// expr IS NULL / IS NOT NULL. + IsNull { + expr: Box, + negated: bool, + }, + /// CAST(expr AS type). + Cast { + expr: Box, + data_type: String, + }, + /// Wildcard `*` (used in COUNT(*)). + Wildcard, + /// Parameter placeholder: `$1`, `$2`, etc. + Parameter(usize), + /// Array expression: `ARRAY[...]`. 
+ Array(Vec), + /// JSON access: `expr->key`, `expr->>key`. + JsonAccess { + expr: Box, + path: Box, + as_text: bool, + }, + /// Type-cast using `::` operator (PostgreSQL specific). + TypeCast { + expr: Box, + data_type: String, + }, + /// Nested expression (parenthesized). + Nested(Box), +} + +/// Literal values in SQL. +#[derive(Debug, Clone, PartialEq)] +pub enum Literal { + Null, + Boolean(bool), + Integer(i64), + Float(f64), + String(String), +} + +/// Binary operators. +#[derive(Debug, Clone, PartialEq)] +pub enum BinaryOperator { + // Comparison + Eq, + NotEq, + Lt, + LtEq, + Gt, + GtEq, + // Logical + And, + Or, + // Arithmetic + Plus, + Minus, + Multiply, + Divide, + Modulo, + // String + Like, + ILike, + NotLike, + NotILike, + // Other + Concat, +} + +/// Unary operators. +#[derive(Debug, Clone, PartialEq)] +pub enum UnaryOperator { + Not, + Minus, + Plus, +} + +/// Window specification for window functions. +#[derive(Debug, Clone, PartialEq)] +pub struct WindowSpec { + pub partition_by: Vec, + pub order_by: Vec, + pub frame: Option, +} + +/// Named window definition for WINDOW clause. +#[derive(Debug, Clone, PartialEq)] +pub struct NamedWindowSpec { + pub name: String, + pub spec: WindowSpec, +} + +/// Window frame specification. +#[derive(Debug, Clone, PartialEq)] +pub struct WindowFrame { + pub mode: WindowFrameMode, + pub start: WindowFrameBound, + pub end: Option, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum WindowFrameMode { + Rows, + Range, + Groups, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum WindowFrameBound { + CurrentRow, + Preceding(Option), + Following(Option), +} + +/// ORDER BY expression. +#[derive(Debug, Clone, PartialEq)] +pub struct OrderByExpr { + pub expr: Expression, + pub asc: Option, + pub nulls_first: Option, +} + +/// Set operations (UNION, INTERSECT, EXCEPT). 
+#[derive(Debug, Clone, PartialEq)] +pub struct SetOperation { + pub op: SetOperator, + pub all: bool, + pub right: Query, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum SetOperator { + Union, + Intersect, + Except, +} + +/// Common Table Expression (WITH clause). +#[derive(Debug, Clone, PartialEq)] +pub struct CTEQuery { + pub recursive: bool, + pub ctes: Vec, + pub body: Box, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct CTE { + pub name: String, + pub columns: Vec, + pub query: Query, +} + +/// INSERT statement. +#[derive(Debug, Clone, PartialEq)] +pub struct InsertQuery { + pub table: TableRef, + pub columns: Vec, + pub source: InsertSource, + pub returning: Vec, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum InsertSource { + Values(Vec>), + Query(Box), +} + +/// UPDATE statement. +#[derive(Debug, Clone, PartialEq)] +pub struct UpdateQuery { + pub table: TableRef, + pub assignments: Vec, + pub filter: Option, + pub returning: Vec, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Assignment { + pub column: String, + pub value: Expression, +} + +/// DELETE statement. 
+#[derive(Debug, Clone, PartialEq)] +pub struct DeleteQuery { + pub table: TableRef, + pub filter: Option, + pub returning: Vec, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_select_query() { + let q = SelectQuery::default(); + assert!(!q.distinct); + assert!(q.projections.is_empty()); + assert!(q.from.is_empty()); + assert!(q.filter.is_none()); + assert!(q.limit.is_none()); + } + + #[test] + fn test_query_clone() { + let q = Query::Select(Box::new(SelectQuery { + distinct: true, + projections: vec![SelectItem::Wildcard], + from: vec![TableRef::Table { + schema: None, + name: "users".into(), + alias: None, + }], + ..Default::default() + })); + let q2 = q.clone(); + assert_eq!(q, q2); + } + + #[test] + fn test_expression_nesting() { + let expr = Expression::BinaryOp { + left: Box::new(Expression::Column { + table: None, + name: "age".into(), + }), + op: BinaryOperator::Gt, + right: Box::new(Expression::Literal(Literal::Integer(18))), + }; + // Verify we can clone deeply nested expressions + let _ = expr.clone(); + } + + #[test] + fn test_literal_equality() { + assert_eq!(Literal::Null, Literal::Null); + assert_eq!(Literal::Boolean(true), Literal::Boolean(true)); + assert_ne!(Literal::Integer(1), Literal::Integer(2)); + assert_eq!( + Literal::String("hello".into()), + Literal::String("hello".into()) + ); + } + + #[test] + fn test_cte_query_structure() { + let cte = CTEQuery { + recursive: true, + ctes: vec![CTE { + name: "recursive_cte".into(), + columns: vec!["n".into()], + query: Query::Select(Box::new(SelectQuery { + projections: vec![SelectItem::Expression { + expr: Expression::Literal(Literal::Integer(1)), + alias: Some("n".into()), + }], + ..Default::default() + })), + }], + body: Box::new(Query::Select(Box::new(SelectQuery::default()))), + }; + assert!(cte.recursive); + assert_eq!(cte.ctes.len(), 1); + assert_eq!(cte.ctes[0].name, "recursive_cte"); + } +} diff --git a/src/main.rs b/src/main.rs index 16bc4db..690ab37 100644 --- 
a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,4 @@ +pub mod ast; mod db; mod editor; mod ui;