diff --git a/rust/lance-graph/src/datafusion_planner/expression.rs b/rust/lance-graph/src/datafusion_planner/expression.rs index abd8ab3..954b546 100644 --- a/rust/lance-graph/src/datafusion_planner/expression.rs +++ b/rust/lance-graph/src/datafusion_planner/expression.rs @@ -10,6 +10,7 @@ use crate::datafusion_planner::udf; use datafusion::functions::string::lower; use datafusion::functions::string::upper; use datafusion::logical_expr::{col, lit, BinaryExpr, Expr, Operator}; +use datafusion_functions_aggregate::array_agg::array_agg; use datafusion_functions_aggregate::average::avg; use datafusion_functions_aggregate::count::count; use datafusion_functions_aggregate::min_max::max; @@ -188,6 +189,15 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { lit(0) } } + // COLLECT aggregation - collects values into an array + "collect" => { + if args.len() == 1 { + let arg_expr = to_df_value_expr(&args[0]); + array_agg(arg_expr) + } else { + lit(0) + } + } // String functions "tolower" | "lower" => { if args.len() == 1 { @@ -307,7 +317,7 @@ pub(crate) fn contains_aggregate(expr: &ValueExpression) -> bool { // Check if this is an aggregate function let is_aggregate = matches!( name.to_lowercase().as_str(), - "count" | "sum" | "avg" | "min" | "max" + "count" | "sum" | "avg" | "min" | "max" | "collect" ); // Also check arguments recursively is_aggregate || args.iter().any(contains_aggregate) diff --git a/rust/lance-graph/src/semantic.rs b/rust/lance-graph/src/semantic.rs index a34dace..bc5fdce 100644 --- a/rust/lance-graph/src/semantic.rs +++ b/rust/lance-graph/src/semantic.rs @@ -289,7 +289,7 @@ impl SemanticAnalyzer { ValueExpression::Function { name, args } => { // Validate function-specific arity and signature rules match name.to_lowercase().as_str() { - "count" | "sum" | "avg" | "min" | "max" => { + "count" | "sum" | "avg" | "min" | "max" | "collect" => { if args.len() != 1 { return Err(GraphError::PlanError { message: format!( @@ -302,7 +302,7 @@ impl SemanticAnalyzer { } // Additional validation for SUM, AVG, MIN, MAX: they require properties, not bare variables - // Only COUNT allows bare variables (COUNT(*) or COUNT(p)) + // Only COUNT and COLLECT allow bare variables (COUNT(*), COUNT(p), COLLECT(p)) if matches!(name.to_lowercase().as_str(), "sum" | "avg" | "min" | "max") { if let Some(ValueExpression::Variable(v)) = args.first() { return Err(GraphError::PlanError { diff --git a/rust/lance-graph/tests/test_datafusion_pipeline.rs b/rust/lance-graph/tests/test_datafusion_pipeline.rs index 9c1998a..9e3837b 100644 --- a/rust/lance-graph/tests/test_datafusion_pipeline.rs +++ b/rust/lance-graph/tests/test_datafusion_pipeline.rs @@ -4244,3 +4244,70 @@ async fn test_tolower_with_integer_column_in_return() { assert_eq!(names.value(0), "Charlie"); assert_eq!(ages.value(0), 30); } + +#[tokio::test] +async fn test_collect_property() { + // Test COLLECT aggregation - collects values into an array + let person_batch = create_person_dataset(); + let config = GraphConfig::builder() + .with_node_label("Person", "id") + .build() + .unwrap(); + + let query = CypherQuery::new("MATCH (p:Person) RETURN collect(p.name) AS all_names") + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + + let result = query + .execute(datasets, Some(ExecutionStrategy::DataFusion)) + .await + .unwrap(); + + // COLLECT returns a single row with an array of all values + assert_eq!(result.num_rows(), 1); + // Verify the column exists + assert!(result.column_by_name("all_names").is_some()); +} + +#[tokio::test] +async fn test_collect_with_grouping() { + // Test COLLECT with GROUP BY - collect names grouped by city + let person_batch = create_person_dataset(); + let config = GraphConfig::builder() + .with_node_label("Person", "id") + .build() + .unwrap(); + + let query = CypherQuery::new( + "MATCH (p:Person) WHERE p.city IS NOT NULL RETURN p.city, collect(p.name) AS names ORDER BY p.city", + ) + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + + let result = query + .execute(datasets, Some(ExecutionStrategy::DataFusion)) + .await + .unwrap(); + + // Should have one row per city (4 cities with non-null values) + assert_eq!(result.num_rows(), 4); + + let cities = result + .column_by_name("p.city") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + // Cities should be ordered: Chicago, New York, San Francisco, Seattle + assert_eq!(cities.value(0), "Chicago"); + assert_eq!(cities.value(1), "New York"); + assert_eq!(cities.value(2), "San Francisco"); + assert_eq!(cities.value(3), "Seattle"); +}