From 6d2cb69de802f53ad6527e84e710db7bbcbcccef Mon Sep 17 00:00:00 2001 From: Allen Cheng Date: Tue, 13 Jan 2026 20:27:29 -0800 Subject: [PATCH] feat(graph): add WITH clause with query chaining support Adds WITH clause for intermediate projections, aggregations, and query chaining in Cypher queries. Supported syntax: - WITH projection: WITH p.name AS name - WITH aggregation: WITH city, count(*) AS total - WITH ORDER BY/LIMIT: WITH p ORDER BY p.age DESC LIMIT 10 - Post-WITH WHERE: WITH ... WHERE total > 1 - Post-WITH MATCH: WITH ... MATCH (p2:Person) ... Changes: - Add WithClause to AST with items, order_by, limit fields - Add with_clause, post_with_match_clauses, post_with_where_clause to CypherQuery - Parse WITH clause and post-WITH MATCH/WHERE in parser - Add semantic analysis for WITH scope - Add plan_with_clause in logical planner - Chain post-WITH MATCH using plan_match_clause_with_base - Add 5 comprehensive tests for WITH functionality Note: WITH p (passing whole node) then MATCH (p)-[]->(f) requires explicit property projection. Use WITH p.id AS id ... instead. --- rust/lance-graph/src/ast.rs | 22 +- rust/lance-graph/src/logical_plan.rs | 67 +++++- rust/lance-graph/src/parser.rs | 39 +++- rust/lance-graph/src/query.rs | 3 + rust/lance-graph/src/semantic.rs | 73 ++++++- .../tests/test_datafusion_pipeline.rs | 195 ++++++++++++++++++ 6 files changed, 390 insertions(+), 9 deletions(-) diff --git a/rust/lance-graph/src/ast.rs b/rust/lance-graph/src/ast.rs index 94c1438..2e770ab 100644 --- a/rust/lance-graph/src/ast.rs +++ b/rust/lance-graph/src/ast.rs @@ -15,8 +15,14 @@ use std::collections::HashMap; pub struct CypherQuery { /// MATCH clauses pub match_clauses: Vec, - /// WHERE clause (optional) + /// WHERE clause (optional, before WITH if present) pub where_clause: Option, + /// WITH clause (optional) - intermediate projection/aggregation + pub with_clause: Option, + /// MATCH clauses after WITH (optional) - query chaining + pub post_with_match_clauses: Vec, + /// WHERE clause after WITH (optional) - filters the WITH results + pub post_with_where_clause: Option, /// RETURN clause pub return_clause: ReturnClause, /// LIMIT clause (optional) @@ -323,6 +329,20 @@ pub enum ArithmeticOperator { Modulo, } +/// WITH clause for intermediate projections/aggregations +/// +/// WITH acts as a query stage boundary, projecting results that become +/// the input for subsequent clauses. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct WithClause { + /// Items to project (similar to RETURN) + pub items: Vec, + /// Optional ORDER BY within WITH + pub order_by: Option, + /// Optional LIMIT within WITH + pub limit: Option, +} + /// RETURN clause specifying what to return #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct ReturnClause { diff --git a/rust/lance-graph/src/logical_plan.rs b/rust/lance-graph/src/logical_plan.rs index 2f45eed..392feb6 100644 --- a/rust/lance-graph/src/logical_plan.rs +++ b/rust/lance-graph/src/logical_plan.rs @@ -156,7 +156,7 @@ impl LogicalPlanner { // Start with the MATCH clause(s) let mut plan = self.plan_match_clauses(&query.match_clauses)?; - // Apply WHERE clause if present + // Apply WHERE clause if present (before WITH) if let Some(where_clause) = &query.where_clause { plan = LogicalOperator::Filter { input: Box::new(plan), @@ -164,6 +164,24 @@ impl LogicalPlanner { }; } + // Apply WITH clause if present (intermediate projection/aggregation) + if let Some(with_clause) = &query.with_clause { + plan = self.plan_with_clause(with_clause, plan)?; + } + + // Apply post-WITH MATCH clauses if present (query chaining) + for match_clause in &query.post_with_match_clauses { + plan = self.plan_match_clause_with_base(Some(plan), match_clause)?; + } + + // Apply post-WITH WHERE clause if present + if let Some(post_where) = &query.post_with_where_clause { + plan = LogicalOperator::Filter { + input: Box::new(plan), + predicate: post_where.expression.clone(), + }; + } + // Apply RETURN clause plan = self.plan_return_clause(&query.return_clause, plan)?; @@ -429,6 +447,53 @@ impl LogicalPlanner { Ok(plan) } + + /// Plan WITH clause - intermediate projection/aggregation with optional ORDER BY and LIMIT + fn plan_with_clause( + &self, + with_clause: &WithClause, + input: LogicalOperator, + ) -> Result { + // WITH creates a projection (like RETURN) + let projections = with_clause + .items + .iter() + .map(|item| ProjectionItem { + expression: item.expression.clone(), + alias: item.alias.clone(), + }) + .collect(); + + let mut plan = LogicalOperator::Project { + input: Box::new(input), + projections, + }; + + // Apply ORDER BY within WITH if present + if let Some(order_by) = &with_clause.order_by { + plan = LogicalOperator::Sort { + input: Box::new(plan), + sort_items: order_by + .items + .iter() + .map(|item| SortItem { + expression: item.expression.clone(), + direction: item.direction.clone(), + }) + .collect(), + }; + } + + // Apply LIMIT within WITH if present + if let Some(limit) = with_clause.limit { + plan = LogicalOperator::Limit { + input: Box::new(plan), + count: limit, + }; + } + + Ok(plan) + } } impl Default for LogicalPlanner { diff --git a/rust/lance-graph/src/parser.rs b/rust/lance-graph/src/parser.rs index 7427437..b458dd1 100644 --- a/rust/lance-graph/src/parser.rs +++ b/rust/lance-graph/src/parser.rs @@ -42,7 +42,20 @@ pub fn parse_cypher_query(input: &str) -> Result { fn cypher_query(input: &str) -> IResult<&str, CypherQuery> { let (input, _) = multispace0(input)?; let (input, match_clauses) = many0(match_clause)(input)?; - let (input, where_clause) = opt(where_clause)(input)?; + let (input, pre_with_where) = opt(where_clause)(input)?; + + // Optional WITH clause with optional post-WITH MATCH and WHERE + let (input, with_result) = opt(with_clause)(input)?; + // Only try to parse post-WITH clauses if we have a WITH clause + let (input, post_with_matches, post_with_where) = match with_result { + Some(_) => { + let (input, matches) = many0(match_clause)(input)?; + let (input, where_cl) = opt(where_clause)(input)?; + (input, matches, where_cl) + } + None => (input, vec![], None), + }; + let (input, return_clause) = return_clause(input)?; let (input, order_by) = opt(order_by_clause)(input)?; let (input, (skip, limit)) = pagination_clauses(input)?; @@ -52,7 +65,10 @@ fn cypher_query(input: &str) -> IResult<&str, CypherQuery> { input, CypherQuery { match_clauses, - where_clause, + where_clause: pre_with_where, + with_clause: with_result, + post_with_match_clauses: post_with_matches, + post_with_where_clause: post_with_where, return_clause, limit, order_by, @@ -657,6 +673,25 @@ fn property_reference(input: &str) -> IResult<&str, PropertyRef> { )) } +// Parse a WITH clause (intermediate projection/aggregation) +fn with_clause(input: &str) -> IResult<&str, WithClause> { + let (input, _) = multispace0(input)?; + let (input, _) = tag_no_case("WITH")(input)?; + let (input, _) = multispace1(input)?; + let (input, items) = separated_list0(comma_ws, return_item)(input)?; + let (input, order_by) = opt(order_by_clause)(input)?; + let (input, limit) = opt(limit_clause)(input)?; + + Ok(( + input, + WithClause { + items, + order_by, + limit, + }, + )) +} + // Parse a RETURN clause fn return_clause(input: &str) -> IResult<&str, ReturnClause> { let (input, _) = multispace0(input)?; diff --git a/rust/lance-graph/src/query.rs b/rust/lance-graph/src/query.rs index c453547..8028eb0 100644 --- a/rust/lance-graph/src/query.rs +++ b/rust/lance-graph/src/query.rs @@ -1178,6 +1178,9 @@ impl CypherQueryBuilder { where_clause: self .where_expression .map(|expr| crate::ast::WhereClause { expression: expr }), + with_clause: None, // WITH not supported via builder yet + post_with_match_clauses: vec![], + post_with_where_clause: None, return_clause: crate::ast::ReturnClause { distinct: self.distinct, items: self.return_items, diff --git a/rust/lance-graph/src/semantic.rs b/rust/lance-graph/src/semantic.rs index bc5fdce..dfe4453 100644 --- a/rust/lance-graph/src/semantic.rs +++ b/rust/lance-graph/src/semantic.rs @@ -44,6 +44,8 @@ pub enum VariableType { pub enum ScopeType { Match, Where, + With, + PostWithWhere, Return, OrderBy, } @@ -79,7 +81,7 @@ impl SemanticAnalyzer { } } - // Phase 2: Validate WHERE clause + // Phase 2: Validate WHERE clause (before WITH) if let Some(where_clause) = &query.where_clause { self.current_scope = ScopeType::Where; if let Err(e) = self.analyze_where_clause(where_clause) { @@ -87,13 +89,29 @@ impl SemanticAnalyzer { } } - // Phase 3: Validate RETURN clause + // Phase 3: Validate WITH clause if present + if let Some(with_clause) = &query.with_clause { + self.current_scope = ScopeType::With; + if let Err(e) = self.analyze_with_clause(with_clause) { + errors.push(format!("WITH clause error: {}", e)); + } + } + + // Phase 4: Validate post-WITH WHERE clause if present + if let Some(post_where) = &query.post_with_where_clause { + self.current_scope = ScopeType::PostWithWhere; + if let Err(e) = self.analyze_where_clause(post_where) { + errors.push(format!("Post-WITH WHERE clause error: {}", e)); + } + } + + // Phase 5: Validate RETURN clause self.current_scope = ScopeType::Return; if let Err(e) = self.analyze_return_clause(&query.return_clause) { errors.push(format!("RETURN clause error: {}", e)); } - // Phase 4: Validate ORDER BY clause + // Phase 6: Validate ORDER BY clause if let Some(order_by) = &query.order_by { self.current_scope = ScopeType::OrderBy; if let Err(e) = self.analyze_order_by_clause(order_by) { @@ -101,10 +119,10 @@ impl SemanticAnalyzer { } } - // Phase 5: Schema validation + // Phase 7: Schema validation self.validate_schema(&mut warnings); - // Phase 6: Type checking + // Phase 8: Type checking self.validate_types(&mut errors); Ok(SemanticResult { @@ -416,6 +434,21 @@ impl SemanticAnalyzer { Ok(()) } + /// Analyze WITH clause + fn analyze_with_clause(&mut self, with_clause: &WithClause) -> Result<()> { + // Validate WITH item expressions (similar to RETURN) + for item in &with_clause.items { + self.analyze_value_expression(&item.expression)?; + } + // Validate ORDER BY within WITH if present + if let Some(order_by) = &with_clause.order_by { + for item in &order_by.items { + self.analyze_value_expression(&item.expression)?; + } + } + Ok(()) + } + /// Analyze ORDER BY clause fn analyze_order_by_clause(&mut self, order_by: &OrderByClause) -> Result<()> { for item in &order_by.items { @@ -558,6 +591,9 @@ mod tests { let query = CypherQuery { match_clauses: vec![], where_clause: None, + with_clause: None, + post_with_match_clauses: vec![], + post_with_where_clause: None, return_clause: ReturnClause { distinct: false, items: vec![ReturnItem { @@ -585,6 +621,9 @@ mod tests { patterns: vec![GraphPattern::Node(node)], }], where_clause: None, + with_clause: None, + post_with_match_clauses: vec![], + post_with_where_clause: None, return_clause: ReturnClause { distinct: false, items: vec![ReturnItem { @@ -615,6 +654,9 @@ mod tests { patterns: vec![GraphPattern::Node(node1), GraphPattern::Node(node2)], }], where_clause: None, + with_clause: None, + post_with_match_clauses: vec![], + post_with_where_clause: None, return_clause: ReturnClause { distinct: false, items: vec![], @@ -661,6 +703,9 @@ mod tests { patterns: vec![GraphPattern::Path(path)], }], where_clause: None, + with_clause: None, + post_with_match_clauses: vec![], + post_with_where_clause: None, return_clause: ReturnClause { distinct: false, items: vec![], @@ -690,6 +735,9 @@ mod tests { patterns: vec![GraphPattern::Node(node)], }], where_clause: Some(where_clause), + with_clause: None, + post_with_match_clauses: vec![], + post_with_where_clause: None, return_clause: ReturnClause { distinct: false, items: vec![], @@ -729,6 +777,9 @@ mod tests { patterns: vec![GraphPattern::Path(path)], }], where_clause: None, + with_clause: None, + post_with_match_clauses: vec![], + post_with_where_clause: None, return_clause: ReturnClause { distinct: false, items: vec![], @@ -755,6 +806,9 @@ mod tests { patterns: vec![GraphPattern::Node(node)], }], where_clause: None, + with_clause: None, + post_with_match_clauses: vec![], + post_with_where_clause: None, return_clause: ReturnClause { distinct: false, items: vec![], @@ -792,6 +846,9 @@ mod tests { patterns: vec![GraphPattern::Node(node)], }], where_clause: None, + with_clause: None, + post_with_match_clauses: vec![], + post_with_where_clause: None, return_clause: ReturnClause { distinct: false, items: vec![], @@ -834,6 +891,9 @@ mod tests { patterns: vec![GraphPattern::Path(path)], }], where_clause: None, + with_clause: None, + post_with_match_clauses: vec![], + post_with_where_clause: None, return_clause: ReturnClause { distinct: false, items: vec![], @@ -898,6 +958,9 @@ mod tests { patterns: vec![GraphPattern::Path(path)], }], where_clause: None, + with_clause: None, + post_with_match_clauses: vec![], + post_with_where_clause: None, return_clause: ReturnClause { distinct: false, items: vec![], diff --git a/rust/lance-graph/tests/test_datafusion_pipeline.rs b/rust/lance-graph/tests/test_datafusion_pipeline.rs index 9e3837b..7f7d4ec 100644 --- a/rust/lance-graph/tests/test_datafusion_pipeline.rs +++ b/rust/lance-graph/tests/test_datafusion_pipeline.rs @@ -4311,3 +4311,198 @@ async fn test_collect_with_grouping() { assert_eq!(cities.value(2), "San Francisco"); assert_eq!(cities.value(3), "Seattle"); } + +#[tokio::test] +async fn test_collect_with_null_values() { + // Test COLLECT handles NULL values correctly + // David has NULL city, so collecting cities should include the null + let person_batch = create_person_dataset(); + let config = GraphConfig::builder() + .with_node_label("Person", "id") + .build() + .unwrap(); + + // Collect all cities (including NULL from David) + let query = CypherQuery::new("MATCH (p:Person) RETURN collect(p.city) AS all_cities") + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + + let result = query + .execute(datasets, Some(ExecutionStrategy::DataFusion)) + .await + .unwrap(); + + // COLLECT returns a single row with an array + assert_eq!(result.num_rows(), 1); + + // Verify the column exists and has 5 elements (including the NULL) + let all_cities_col = result.column_by_name("all_cities").unwrap(); + // The array should have been created successfully + assert!(!all_cities_col.is_empty()); +} + +// ============================================================================ +// WITH Clause Tests +// ============================================================================ + +#[tokio::test] +async fn test_with_simple_projection() { + // Test WITH as a simple projection pass-through + let person_batch = create_person_dataset(); + let config = GraphConfig::builder() + .with_node_label("Person", "id") + .build() + .unwrap(); + + let query = CypherQuery::new( + "MATCH (p:Person) WITH p.name AS name, p.age AS age RETURN name, age ORDER BY age", + ) + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + + let result = query + .execute(datasets, Some(ExecutionStrategy::DataFusion)) + .await + .unwrap(); + + // Should have all 5 people + assert_eq!(result.num_rows(), 5); + + // Verify columns exist + assert!(result.column_by_name("name").is_some()); + assert!(result.column_by_name("age").is_some()); + + // Check ordering by age (Alice=25, Eve=28, Charlie=30, Bob=35, David=40) + let ages = result + .column_by_name("age") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(ages.value(0), 25); + assert_eq!(ages.value(4), 40); +} + +#[tokio::test] +async fn test_with_aggregation() { + // Test WITH for aggregation: count people by city + let person_batch = create_person_dataset(); + let config = GraphConfig::builder() + .with_node_label("Person", "id") + .build() + .unwrap(); + + let query = CypherQuery::new( + "MATCH (p:Person) WHERE p.city IS NOT NULL WITH p.city AS city, count(*) AS total RETURN city, total ORDER BY city", + ) + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + + let result = query + .execute(datasets, Some(ExecutionStrategy::DataFusion)) + .await + .unwrap(); + + // Should have 4 cities (David has NULL city) + assert_eq!(result.num_rows(), 4); + + // Verify columns exist + assert!(result.column_by_name("city").is_some()); + assert!(result.column_by_name("total").is_some()); +} + +#[tokio::test] +async fn test_with_order_by_limit_and_where() { + // Test WITH with ORDER BY, LIMIT, and post-WITH WHERE filter + let person_batch = create_person_dataset(); + let config = GraphConfig::builder() + .with_node_label("Person", "id") + .build() + .unwrap(); + + // Get top 4 oldest people, then filter to those older than 30 + // Data: Alice=25, Eve=28, Charlie=30, Bob=35, David=40 + // After ORDER BY DESC LIMIT 4: David=40, Bob=35, Charlie=30, Eve=28 + // After WHERE age > 30: David=40, Bob=35 + let query = CypherQuery::new( + "MATCH (p:Person) WITH p.name AS name, p.age AS age ORDER BY age DESC LIMIT 4 WHERE age > 30 RETURN name, age ORDER BY age", + ) + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + + let result = query + .execute(datasets, Some(ExecutionStrategy::DataFusion)) + .await + .unwrap(); + + // Should have 2 people (Bob=35, David=40) after LIMIT 4 then filter age > 30 + assert_eq!(result.num_rows(), 2); + + let names = result + .column_by_name("name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(names.value(0), "Bob"); + assert_eq!(names.value(1), "David"); + + let ages = result + .column_by_name("age") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(ages.value(0), 35); + assert_eq!(ages.value(1), 40); +} + +#[tokio::test] +async fn test_with_post_match_chaining() { + // Test WITH with post-WITH MATCH (query chaining) + // First get people, then find additional patterns + let person_batch = create_person_dataset(); + let knows_batch = create_knows_dataset(); + let config = GraphConfig::builder() + .with_node_label("Person", "id") + .with_relationship("KNOWS", "src_person_id", "dst_person_id") + .build() + .unwrap(); + + // Simpler chaining: WITH aggregation, then MATCH for additional data + // Get count of people per city, then find people in those cities + let query = CypherQuery::new( + "MATCH (p:Person) WHERE p.city IS NOT NULL \ + WITH p.city AS city, count(*) AS cnt \ + MATCH (p2:Person) WHERE p2.city = city \ + RETURN city, cnt, p2.name ORDER BY city", + ) + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + datasets.insert("KNOWS".to_string(), knows_batch); + + let result = query + .execute(datasets, Some(ExecutionStrategy::DataFusion)) + .await + .unwrap(); + + // Should have results (city + count + person names) + assert!(result.num_rows() > 0); + assert!(result.column_by_name("city").is_some()); + assert!(result.column_by_name("cnt").is_some()); +}