diff --git a/.cursor/rules.md b/.cursor/rules.md new file mode 100644 index 0000000000000..00385e6f5f759 --- /dev/null +++ b/.cursor/rules.md @@ -0,0 +1,5 @@ +The AI agents should never link to any issue or a pull request +in any GitHub repository in the code reviews! + +The AI agents should not review AI agents' config files like CLAUDE.md or AGENTS.md! + diff --git a/AGENTS.md b/AGENTS.md index eeedbd8bc45ec..00385e6f5f759 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,34 +1,5 @@ -# Agent Guidelines for Apache DataFusion +The AI agents should never link to any issue or a pull request +in any GitHub repository in the code reviews! -## Developer Documentation +The AI agents should not review AI agents' config files like CLAUDE.md or AGENTS.md! -- [Contributor Guide](docs/source/contributor-guide/index.md) -- [Architecture Guide](docs/source/contributor-guide/architecture.md) - -## Before Committing - -Before committing any changes, you **must** run the following checks and fix any issues: - -```bash -cargo fmt --all -cargo clippy --all-targets --all-features -- -D warnings -``` - -- `cargo fmt` ensures consistent code formatting across the project. -- `cargo clippy` catches common mistakes and enforces idiomatic Rust patterns. All warnings must be resolved (treated as errors via `-D warnings`). - -Do not commit code that fails either of these checks. - -## Testing - -Run relevant tests before submitting changes: - -```bash -cargo test --all-features -``` - -For SQL logic tests: - -```bash -cargo test -p datafusion-sqllogictest -``` diff --git a/datafusion-examples/examples/dataframe/dataframe.rs b/datafusion-examples/examples/dataframe/dataframe.rs index dde19cb476f14..07253e98c8cbf 100644 --- a/datafusion-examples/examples/dataframe/dataframe.rs +++ b/datafusion-examples/examples/dataframe/dataframe.rs @@ -29,10 +29,12 @@ use datafusion::common::config::CsvOptions; use datafusion::common::parsers::CompressionTypeVariant; use datafusion::dataframe::DataFrameWriteOptions; use datafusion::error::Result; -use datafusion::functions_aggregate::average::avg; -use datafusion::functions_aggregate::min_max::max; +use datafusion::functions_aggregate::average::{self, avg}; +use datafusion::functions_aggregate::min_max::{self, max}; use datafusion::prelude::*; use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; +use datafusion_expr::expr::WindowFunction; +use datafusion_expr::{WindowFrame, WindowFunctionDefinition}; use tempfile::{TempDir, tempdir}; use tokio::fs::create_dir_all; @@ -53,8 +55,10 @@ use tokio::fs::create_dir_all; /// /// * [write_out]: write out a DataFrame to a table, parquet file, csv file, or json file /// -/// # Executing subqueries +/// # Querying data /// +/// * [aggregate_global_and_grouped]: global vs grouped aggregation (`select` vs `aggregate`) +/// * [window_vs_grouped_aggregation]: GROUP BY vs window functions (`aggregate`, `window`, `select`) /// * [where_scalar_subquery]: execute a scalar subquery /// * [where_in_subquery]: execute a subquery with an IN clause /// * [where_exist_subquery]: execute a subquery with an EXISTS clause @@ -69,6 +73,8 @@ pub async fn dataframe_example() -> Result<()> { write_out(&ctx).await?; register_cars_test_data("t1", &ctx).await?; register_cars_test_data("t2", &ctx).await?; + aggregate_global_and_grouped(&ctx).await?; + window_vs_grouped_aggregation(&ctx).await?; where_scalar_subquery(&ctx).await?; where_in_subquery(&ctx).await?; where_exist_subquery(&ctx).await?; @@ -269,6 +275,94 @@ async fn write_out(ctx: &SessionContext) -> Result<()> { Ok(()) } +/// Global vs grouped aggregation using `select` and `aggregate` +async fn aggregate_global_and_grouped(ctx: &SessionContext) -> Result<()> { + let df = ctx.table("t1").await?; + + // SELECT AVG(speed) FROM t1 + df.clone() + .aggregate(vec![], vec![avg(col("speed"))])? + .show() + .await?; + + // SELECT AVG(speed) FROM t1 (same result via `select`) + df.clone().select(vec![avg(col("speed"))])?.show().await?; + + // SELECT car, AVG(speed) FROM t1 GROUP BY car + df.aggregate(vec![col("car")], vec![avg(col("speed"))])? + .show() + .await?; + + Ok(()) +} + +/// GROUP BY vs window functions using `aggregate`, `window`, and `select` +async fn window_vs_grouped_aggregation(ctx: &SessionContext) -> Result<()> { + let df = ctx.table("t1").await?; + + // SELECT car, + // AVG(speed), + // MAX(speed) + // FROM t1 + // GROUP BY car + df.clone() + .aggregate( + vec![col("car")], + vec![ + avg(col("speed")).alias("avg_speed"), + max(col("speed")).alias("max_speed"), + ], + )? + .show() + .await?; + + // SELECT car, speed, + // AVG(speed) OVER (PARTITION BY car), + // MAX(speed) OVER (PARTITION BY car) + // FROM t1 + + // Window expressions: + let avg_win = Expr::WindowFunction(Box::new(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(average::avg_udaf()), + vec![col("speed")], + ))) + .partition_by(vec![col("car")]) + .order_by(vec![]) + .window_frame(WindowFrame::new(None)) + .build()?; + + let max_win = Expr::WindowFunction(Box::new(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(min_max::max_udaf()), + vec![col("speed")], + ))) + .partition_by(vec![col("car")]) + .order_by(vec![]) + .window_frame(WindowFrame::new(None)) + .build()?; + + // Two equivalent ways to compute window expressions: + // Using `window` then selecting columns + let res = df + .clone() + .window(vec![ + avg_win.clone().alias("avg_speed"), + max_win.clone().alias("max_speed"), + ])? + .select_columns(&["car", "speed", "avg_speed", "max_speed"])?; + res.show().await?; + + // Using window expressions directly in `select` + let res = df.select(vec![ + col("car"), + col("speed"), + avg_win.alias("avg_speed"), + max_win.alias("max_speed"), + ])?; + res.show().await?; + + Ok(()) +} + /// Use the DataFrame API to execute the following subquery: /// select car, speed from t1 where (select avg(t2.speed) from t2 where t1.car = t2.car) > 0 limit 3; async fn where_scalar_subquery(ctx: &SessionContext) -> Result<()> { diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 2292f5855bfde..a22b9322c6bbb 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -51,12 +51,14 @@ use arrow::compute::{cast, concat}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow_schema::FieldRef; use datafusion_common::config::{CsvOptions, JsonOptions}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{ Column, DFSchema, DataFusionError, ParamValues, ScalarValue, SchemaError, TableReference, UnnestOptions, exec_err, internal_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, unqualified_field_not_found, }; use datafusion_expr::select_expr::SelectExpr; +use datafusion_expr::utils::find_aggregate_exprs; use datafusion_expr::{ ExplainOption, SortExpr, TableProviderFilterPushDown, UNNAMED_TABLE, case, dml::InsertOp, @@ -387,13 +389,16 @@ impl DataFrame { /// # use datafusion::prelude::*; /// # use datafusion::error::Result; /// # use datafusion_common::assert_batches_sorted_eq; + /// # use datafusion_functions_aggregate::expr_fn::{count, sum}; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); /// let df = ctx /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) /// .await?; - /// let df = df.select(vec![col("a"), col("b") * col("c")])?; + /// + /// // Expressions are evaluated per row + /// let res = df.clone().select(vec![col("a"), col("b") * col("c")])?; /// let expected = vec![ /// "+---+-----------------------+", /// "| a | ?table?.b * ?table?.c |", @@ -401,7 +406,18 @@ impl DataFrame { /// "| 1 | 6 |", /// "+---+-----------------------+", /// ]; - /// # assert_batches_sorted_eq!(expected, &df.collect().await?); + /// # assert_batches_sorted_eq!(expected, &res.collect().await?); + /// + /// // Aggregate expressions are also supported + /// let res = df.select(vec![count(col("a")), sum(col("b"))])?; + /// let expected = vec![ + /// "+----------+--------+", + /// "| count(a) | sum(b) |", + /// "+----------+--------+", + /// "| 1 | 2 |", + /// "+----------+--------+", + /// ]; + /// # assert_batches_sorted_eq!(expected, &res.collect().await?); /// # Ok(()) /// # } /// ``` @@ -410,21 +426,123 @@ impl DataFrame { expr_list: impl IntoIterator>, ) -> Result { let expr_list: Vec = - expr_list.into_iter().map(|e| e.into()).collect::>(); + expr_list.into_iter().map(|e| e.into()).collect(); + // Extract expressions let expressions = expr_list.iter().filter_map(|e| match e { SelectExpr::Expression(expr) => Some(expr), _ => None, }); - let window_func_exprs = find_window_exprs(expressions); - let plan = if window_func_exprs.is_empty() { + // Apply window functions first + let window_func_exprs = find_window_exprs(expressions.clone()); + + let mut plan = if window_func_exprs.is_empty() { self.plan } else { LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)? }; - let project_plan = LogicalPlanBuilder::from(plan).project(expr_list)?.build()?; + // Collect aggregate expressions + let aggr_exprs = find_aggregate_exprs(expressions.clone()); + + // Check for non-aggregate expressions + let has_non_aggregate_expr = expr_list.iter().any(|e| match e { + SelectExpr::Expression(expr) => { + find_aggregate_exprs(std::iter::once(expr)).is_empty() + } + SelectExpr::Wildcard(_) | SelectExpr::QualifiedWildcard(_, _) => true, + }); + + if has_non_aggregate_expr && !aggr_exprs.is_empty() { + return plan_err!( + "Column in SELECT must be in GROUP BY or an aggregate function" + ); + } + + // Fallback to projection + if matches!(plan, LogicalPlan::Aggregate(_)) + || has_non_aggregate_expr + || aggr_exprs.is_empty() + { + let project_plan = + LogicalPlanBuilder::from(plan).project(expr_list)?.build()?; + + return Ok(DataFrame { + session_state: self.session_state, + plan: project_plan, + projection_requires_validation: false, + }); + } + + // Unique name generator + let make_unique_name = + |base: String, used: &mut HashSet, start: usize| { + let mut name = base.clone(); + let mut counter = start; + while used.contains(&name) { + name = format!("{base}_{counter}"); + counter += 1; + } + used.insert(name.clone()); + + name + }; + + // Aggregate stage + let mut aggr_map: HashMap = HashMap::new(); + let mut aggr_used_names = HashSet::new(); + let aggr_exprs_with_alias: Vec = aggr_exprs + .into_iter() + .map(|expr| { + let base_name = expr.name_for_alias()?; + let name = make_unique_name(base_name, &mut aggr_used_names, 1); + let aliased = expr.clone().alias(name.clone()); + let col = Expr::Column(Column::from_name(name)); + aggr_map.insert(expr, col); + + Ok(aliased) + }) + .collect::>>()?; + + // Build aggregate plan + plan = LogicalPlanBuilder::from(plan) + .aggregate(Vec::::new(), aggr_exprs_with_alias)? + .build()?; + + // Rewrite expressions + let rewrite_expr = |expr: Expr, aggr_map: &HashMap| -> Result { + expr.transform(|e| { + Ok(match aggr_map.get(&e) { + Some(replacement) => Transformed::yes(replacement.clone()), + None => Transformed::no(e), + }) + }) + .map(|t| t.data) + }; + + // Projection stage + let mut rewritten_exprs = Vec::with_capacity(expr_list.len()); + let mut projection_used_names = HashSet::new(); + for select_expr in expr_list.into_iter() { + match select_expr { + SelectExpr::Expression(expr) => { + let base_alias = expr.name_for_alias()?; + let rewritten = rewrite_expr(expr, &aggr_map)?; + let name = + make_unique_name(base_alias, &mut projection_used_names, 1); + let final_expr = rewritten.alias(name); + + rewritten_exprs.push(SelectExpr::Expression(final_expr)); + } + other => rewritten_exprs.push(other), + } + } + + // Final projection + let project_plan = LogicalPlanBuilder::from(plan) + .project(rewritten_exprs)? + .build()?; Ok(DataFrame { session_state: self.session_state, diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 80bbde1f6ba14..02e40cdc29025 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -32,8 +32,9 @@ use arrow::datatypes::{ use arrow::error::ArrowError; use arrow::util::pretty::pretty_format_batches; use arrow_schema::{SortOptions, TimeUnit}; -use datafusion::{assert_batches_eq, dataframe}; +use datafusion::{assert_batches_eq, dataframe, functions_window}; use datafusion_common::metadata::FieldMetadata; +use datafusion_expr::select_expr::SelectExpr; use datafusion_functions_aggregate::count::{count_all, count_all_window}; use datafusion_functions_aggregate::expr_fn::{ array_agg, avg, avg_distinct, count, count_distinct, max, median, min, sum, @@ -72,7 +73,9 @@ use datafusion_common_runtime::SpawnedTask; use datafusion_datasource::file_format::format_as_file_type; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; -use datafusion_expr::expr::{GroupingSet, NullTreatment, Sort, WindowFunction}; +use datafusion_expr::expr::{ + GroupingSet, NullTreatment, Sort, WildcardOptions, WindowFunction, +}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ Expr, ExprFunctionExt, ExprSchemable, LogicalPlan, LogicalPlanBuilder, @@ -1046,26 +1049,28 @@ async fn test_aggregate_with_union() -> Result<()> { let df1 = df .clone() - // GROUP BY `c1` - .aggregate(vec![col("c1")], vec![min(col("c2"))])? - // SELECT `c1` , min(c2) as `result` - .select(vec![col("c1"), min(col("c2")).alias("result")])?; + // GROUP BY c1, compute min(c2) as result + .aggregate(vec![col("c1")], vec![min(col("c2")).alias("result")])? + // SELECT c1, result + .select(vec![col("c1"), col("result")])?; + let df2 = df .clone() - // GROUP BY `c1` - .aggregate(vec![col("c1")], vec![max(col("c3"))])? - // SELECT `c1` , max(c3) as `result` - .select(vec![col("c1"), max(col("c3")).alias("result")])?; + // GROUP BY c1, compute max(c3) as result + .aggregate(vec![col("c1")], vec![max(col("c3")).alias("result")])? + // SELECT c1, result + .select(vec![col("c1"), col("result")])?; let df_union = df1.union(df2)?; + let df = df_union - // GROUP BY `c1` + // GROUP BY c1, sum(result) as sum_result .aggregate( vec![col("c1")], vec![sum(col("result")).alias("sum_result")], )? - // SELECT `c1`, sum(result) as `sum_result` - .select(vec![(col("c1")), col("sum_result")])?; + // SELECT c1, sum_result + .select(vec![col("c1"), col("sum_result")])?; let df_results = df.collect().await?; @@ -1095,28 +1100,28 @@ async fn test_aggregate_subexpr() -> Result<()> { let df = df // GROUP BY `c2 + 1` - .aggregate(vec![group_expr.clone()], vec![aggr_expr.clone()])? + .aggregate( + vec![group_expr.clone().alias("g")], + vec![aggr_expr.clone().alias("a")], + )? // SELECT `c2 + 1` as c2 + 10, sum(c3 + 2) + 20 // SELECT expressions contain aggr_expr and group_expr as subexpressions - .select(vec![ - group_expr.alias("c2") + lit(10), - (aggr_expr + lit(20)).alias("sum"), - ])?; + .select(vec![col("g") + lit(10), (col("a") + lit(20)).alias("sum")])?; let df_results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&df_results), @r" - +----------------+------+ - | c2 + Int32(10) | sum | - +----------------+------+ - | 12 | 431 | - | 13 | 248 | - | 14 | 453 | - | 15 | 95 | - | 16 | -146 | - +----------------+------+ + +---------------+------+ + | g + Int32(10) | sum | + +---------------+------+ + | 12 | 431 | + | 13 | 248 | + | 14 | 453 | + | 15 | 95 | + | 16 | -146 | + +---------------+------+ " ); @@ -6854,3 +6859,445 @@ async fn test_duplicate_state_fields_for_dfschema_construct() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_dataframe_api_select_semantics() -> Result<()> { + let df = test_table().await?; + + // ---------------------------------------------------------------------- + // Aggregate functions in SELECT (no GROUP BY) + // ---------------------------------------------------------------------- + let res = df.clone().select(vec![ + count(col("c9")).alias("count_c9"), + count(cast(col("c9"), DataType::Utf8View)).alias("count_c9_str"), + sum(col("c9")).alias("sum_c9"), + count(col("c8")).alias("count_c8"), + (sum(col("c9")) + count(col("c8"))).alias("total1"), + ((count(col("c9")) + lit(1)) * lit(2)).alias("total2"), + (count(col("c9")) + lit(1)).alias("count_c9_add_1"), + ])?; + + assert_batches_eq!( + &[ + "+----------+--------------+--------------+----------+--------------+--------+----------------+", + "| count_c9 | count_c9_str | sum_c9 | count_c8 | total1 | total2 | count_c9_add_1 |", + "+----------+--------------+--------------+----------+--------------+--------+----------------+", + "| 100 | 100 | 222089770060 | 100 | 222089770160 | 202 | 101 |", + "+----------+--------------+--------------+----------+--------------+--------+----------------+", + ], + &res.collect().await? + ); + + // ---------------------------------------------------------------------- + // Alias deduplication + // ---------------------------------------------------------------------- + let res = df.clone().select(vec![ + count(col("c9")).alias("count_c9"), + count(col("c9")).alias("count_c9"), + ])?; + + assert_batches_eq!( + &[ + "+----------+------------+", + "| count_c9 | count_c9_1 |", + "+----------+------------+", + "| 100 | 100 |", + "+----------+------------+", + ], + &res.collect().await? + ); + + // ---------------------------------------------------------------------- + // Wildcard handling + // ---------------------------------------------------------------------- + // SQL: + // SELECT *, 42 FROM t + let res = df + .clone() + .select(vec![ + SelectExpr::Wildcard(WildcardOptions::default()), + lit(42).into(), + ])? + .limit(0, None)?; + + let batches = res.collect().await?; + assert!(!batches.is_empty()); + assert_eq!(batches.iter().map(|b| b.num_rows()).sum::(), 100); + assert!(batches.iter().all(|b| b.num_columns() == 14)); + + let res = df.clone().select(vec![ + SelectExpr::QualifiedWildcard( + "aggregate_test_100".into(), + WildcardOptions::default(), + ), + lit(42).into(), + ])?; + + let batches = res.collect().await?; + assert!(!batches.is_empty()); + assert_eq!(batches.iter().map(|b| b.num_rows()).sum::(), 100); + assert!(batches.iter().all(|b| b.num_columns() == 14)); + + // ---------------------------------------------------------------------- + // Window functions + // ---------------------------------------------------------------------- + // SQL: + // SELECT + // c1, + // COUNT(c9) OVER (PARTITION BY c1) AS cnt, + // SUM(c9) OVER (PARTITION BY c1) AS sum_c9, + // AVG(c9) OVER (PARTITION BY c1) AS avg_c9 + // FROM t + // ORDER BY c1 + let count_window_function = Expr::WindowFunction(Box::new(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF( + datafusion_functions_aggregate::count::count_udaf(), + ), + vec![col("c9")], + ))) + .partition_by(vec![col("c1")]) + .order_by(vec![]) + .window_frame(WindowFrame::new(None)) + .build()?; + + let sum_window_function = Expr::WindowFunction(Box::new(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF( + datafusion_functions_aggregate::sum::sum_udaf(), + ), + vec![col("c9")], + ))) + .partition_by(vec![col("c1")]) + .order_by(vec![]) + .window_frame(WindowFrame::new(None)) + .build()?; + + let avg_window_function = Expr::WindowFunction(Box::new(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF( + datafusion_functions_aggregate::average::avg_udaf(), + ), + vec![col("c9")], + ))) + .partition_by(vec![col("c1")]) + .order_by(vec![]) + .window_frame(WindowFrame::new(None)) + .build()?; + + let res = df + .clone() + .select(vec![ + col("c1"), + count_window_function.alias("cnt"), + sum_window_function.alias("sum_c9"), + avg_window_function.alias("avg_c9"), + ])? + .sort(vec![col("c1").sort(true, true)])?; + + let batches = res.collect().await?; + assert!(!batches.is_empty()); + assert_eq!(batches.iter().map(|b| b.num_rows()).sum::(), 100); + assert!(batches.iter().all(|b| b.num_columns() == 4)); + let batch = &batches[0]; + assert_batches_eq!( + &[ + "+----+-----+-------------+--------------------+", + "| c1 | cnt | sum_c9 | avg_c9 |", + "+----+-----+-------------+--------------------+", + "| a | 21 | 42619217323 | 2029486539.1904762 |", + "| b | 19 | 42365566310 | 2229766647.894737 |", + "| c | 21 | 46381998762 | 2208666607.714286 |", + "| d | 18 | 39910269981 | 2217237221.1666665 |", + "| e | 21 | 50812717684 | 2419653223.047619 |", + "+----+-----+-------------+--------------------+", + ], + &[ + batch.slice(0, 1), // a + batch.slice(21, 1), // b + batch.slice(40, 1), // c + batch.slice(61, 1), // d + batch.slice(79, 1) // e + ] + ); + + // Window with ORDER BY + // SQL: + // SELECT + // c1, + // ROW_NUMBER() OVER (PARTITION BY c1 ORDER BY c9) AS rn + // FROM t + // ORDER BY c1 + let row_number_window_function = Expr::WindowFunction(Box::new(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + functions_window::row_number::row_number_udwf(), + ), + vec![], + ))) + .partition_by(vec![col("c1")]) + .order_by(vec![col("c9").sort(true, true)]) + .window_frame(WindowFrame::new(None)) + .build()?; + + let res = df + .clone() + .select(vec![col("c1"), row_number_window_function.alias("rn")])? + .sort(vec![col("c1").sort(true, true)])?; + + let batches = res.collect().await?; + assert!(!batches.is_empty()); + assert_eq!(batches.iter().map(|b| b.num_rows()).sum::(), 100); + assert!(batches.iter().all(|b| b.num_columns() == 2)); + assert_batches_eq!( + &[ + "+----+----+", + "| c1 | rn |", + "+----+----+", + "| a | 1 |", + "| a | 2 |", + "| a | 3 |", + "| a | 4 |", + "| a | 5 |", + "| a | 6 |", + "| a | 7 |", + "| a | 8 |", + "| a | 9 |", + "| a | 10 |", + "| a | 11 |", + "| a | 12 |", + "| a | 13 |", + "| a | 14 |", + "| a | 15 |", + "| a | 16 |", + "| a | 17 |", + "| a | 18 |", + "| a | 19 |", + "| a | 20 |", + "| a | 21 |", + "| b | 1 |", + "| b | 2 |", + "| b | 3 |", + "| b | 4 |", + "| b | 5 |", + "| b | 6 |", + "| b | 7 |", + "| b | 8 |", + "| b | 9 |", + "| b | 10 |", + "| b | 11 |", + "| b | 12 |", + "| b | 13 |", + "| b | 14 |", + "| b | 15 |", + "| b | 16 |", + "| b | 17 |", + "| b | 18 |", + "| b | 19 |", + "| c | 1 |", + "| c | 2 |", + "| c | 3 |", + "| c | 4 |", + "| c | 5 |", + "| c | 6 |", + "| c | 7 |", + "| c | 8 |", + "| c | 9 |", + "| c | 10 |", + "| c | 11 |", + "| c | 12 |", + "| c | 13 |", + "| c | 14 |", + "| c | 15 |", + "| c | 16 |", + "| c | 17 |", + "| c | 18 |", + "| c | 19 |", + "| c | 20 |", + "| c | 21 |", + "| d | 1 |", + "| d | 2 |", + "| d | 3 |", + "| d | 4 |", + "| d | 5 |", + "| d | 6 |", + "| d | 7 |", + "| d | 8 |", + "| d | 9 |", + "| d | 10 |", + "| d | 11 |", + "| d | 12 |", + "| d | 13 |", + "| d | 14 |", + "| d | 15 |", + "| d | 16 |", + "| d | 17 |", + "| d | 18 |", + "| e | 1 |", + "| e | 2 |", + "| e | 3 |", + "| e | 4 |", + "| e | 5 |", + "| e | 6 |", + "| e | 7 |", + "| e | 8 |", + "| e | 9 |", + "| e | 10 |", + "| e | 11 |", + "| e | 12 |", + "| e | 13 |", + "| e | 14 |", + "| e | 15 |", + "| e | 16 |", + "| e | 17 |", + "| e | 18 |", + "| e | 19 |", + "| e | 20 |", + "| e | 21 |", + "+----+----+", + ], + &batches + ); + + // Window inside expression + // SQL: + // SELECT + // c1, + // COUNT(c9) OVER (PARTITION BY c1) + 1 AS cnt_plus + // FROM t + // ORDER BY c1 + let count_window_function = Expr::WindowFunction(Box::new(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF( + datafusion_functions_aggregate::count::count_udaf(), + ), + vec![col("c9")], + ))) + .partition_by(vec![col("c1")]) + .order_by(vec![]) + .window_frame(WindowFrame::new(None)) + .build()?; + + let cnt_plus_expr = count_window_function + lit(1); + + let res = df + .clone() + .select(vec![col("c1"), cnt_plus_expr.alias("cnt_plus")])? + .sort(vec![col("c1").sort(true, true)])?; + + let batches = res.collect().await?; + assert!(!batches.is_empty()); + assert_eq!(batches.iter().map(|b| b.num_rows()).sum::(), 100); + assert!(batches.iter().all(|b| b.num_columns() == 2)); + let batch = &batches[0]; + assert_batches_eq!( + &[ + "+----+----------+", + "| c1 | cnt_plus |", + "+----+----------+", + "| a | 22 |", + "| b | 20 |", + "| c | 22 |", + "| d | 19 |", + "| e | 22 |", + "+----+----------+", + ], + &[ + batch.slice(0, 1), // a + batch.slice(21, 1), // b + batch.slice(40, 1), // c + batch.slice(61, 1), // d + batch.slice(79, 1), // e + ] + ); + + // Mixed aggregate + window + // SQL: + // SELECT + // c1, + // SUM(c9) OVER () AS total_sum, + // COUNT(c9) OVER (PARTITION BY c1) AS cnt + // FROM t + let sum_window_function = Expr::WindowFunction(Box::new(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF( + datafusion_functions_aggregate::sum::sum_udaf(), + ), + vec![col("c9")], + ))) + .partition_by(vec![]) + .order_by(vec![]) + .window_frame(WindowFrame::new(None)) + .build()?; + + let count_window_function = Expr::WindowFunction(Box::new(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF( + datafusion_functions_aggregate::count::count_udaf(), + ), + vec![col("c9")], + ))) + .partition_by(vec![col("c1")]) + .order_by(vec![]) + .window_frame(WindowFrame::new(None)) + .build()?; + + let res = df + .clone() + .select(vec![ + col("c1"), + sum_window_function.alias("total_sum"), + count_window_function.alias("cnt"), + ])? + .sort(vec![col("c1").sort(true, true)])?; + + let batches = res.collect().await?; + assert!(!batches.is_empty()); + assert_eq!(batches.iter().map(|b| b.num_rows()).sum::(), 100); + assert!(batches.iter().all(|b| b.num_columns() == 3)); + let batch = &batches[0]; + + assert_batches_eq!( + &[ + "+----+--------------+-----+", + "| c1 | total_sum | cnt |", + "+----+--------------+-----+", + "| a | 222089770060 | 21 |", + "+----+--------------+-----+" + ], + &[batch.slice(0, 1)] + ); + + Ok(()) +} + +#[tokio::test] +async fn test_dataframe_api_select_semantics_err() -> Result<()> { + let df = test_table().await?; + + // Wildcard + aggregate without GROUP BY + // Equivalent SQL: SELECT *, COUNT(c9) FROM t + // Should fail: non-aggregate columns require GROUP BY + let res = df.clone().select(vec![ + SelectExpr::Wildcard(WildcardOptions::default()), + count(col("c9")).alias("cnt").into(), + ]); + let err = res.expect_err("query should fail"); + let msg = err.to_string(); + assert!(msg.contains("must be in GROUP BY"), "actual error: {msg}"); + + // Aggregate with subexpressions in SELECT without proper GROUP BY + // Equivalent SQL: SELECT c2 + 10, SUM(c3 + 2) + 20 FROM t GROUP BY c2 + 1 + // Should fail: subexpressions not allowed outside GROUP BY + // Hint: see `test_aggregate_subexpr` for a working example + let group_expr = col("c2") + lit(1); + let aggr_expr = sum(col("c3") + lit(2)); + let res = df + .clone() + .aggregate(vec![group_expr.clone()], vec![aggr_expr.clone()])? + .select(vec![ + group_expr.alias("c2") + lit(10), + (aggr_expr + lit(20)).alias("sum"), + ]); + let err = res.expect_err("query should fail for subexpr outside GROUP BY"); + let msg = err.to_string(); + assert!( + msg.contains("must be in GROUP BY"), + "unexpected error for subexpr + aggregate: {msg}" + ); + + Ok(()) +} diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 63ad00c92e6a9..f6e7e7c1f7e79 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -33,6 +33,7 @@ use datafusion::optimizer::Optimizer; use datafusion::optimizer::optimize_unions::OptimizeUnions; use datafusion_common::parquet_config::DFParquetWriterVersion; use datafusion_common::parsers::CompressionTypeVariant; +use datafusion_expr::utils::find_aggregate_exprs; use datafusion_functions_aggregate::sum::sum_distinct; use prost::Message; use std::any::Any; @@ -53,8 +54,8 @@ use datafusion::execution::session_state::SessionStateBuilder; use datafusion::functions_aggregate::count::count_udaf; use datafusion::functions_aggregate::expr_fn::{ approx_median, approx_percentile_cont, approx_percentile_cont_with_weight, count, - count_distinct, covar_pop, covar_samp, first_value, grouping, max, median, min, - stddev, stddev_pop, sum, var_pop, var_sample, + count_distinct, covar_pop, covar_samp, first_value, max, median, min, stddev, + stddev_pop, sum, var_pop, var_sample, }; use datafusion::functions_aggregate::min_max::max_udaf; use datafusion::functions_nested::map::map; @@ -1193,7 +1194,7 @@ async fn roundtrip_expr_api() -> Result<()> { lit(0.5), Some(lit(50)), ), - grouping(lit(1)), + // grouping(lit(1)), // #TODO Error: Context("resolve_grouping_function", Plan("Argument Int32(1) to grouping function is not in grouping columns ") bit_and(lit(2)), bit_or(lit(2)), bit_xor(lit(2)), @@ -1232,11 +1233,32 @@ async fn roundtrip_expr_api() -> Result<()> { ), ]; - // ensure expressions created with the expr api can be round tripped - let plan = table.select(expr_list)?.into_optimized_plan()?; - let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx.task_ctx())?; - assert_eq!(format!("{plan}"), format!("{logical_round_trip}")); + let agg_exprs: Vec = expr_list + .iter() + .filter(|e| !find_aggregate_exprs(vec![*e]).is_empty()) + .cloned() + .collect(); + + let non_agg_exprs: Vec = expr_list + .iter() + .filter(|e| find_aggregate_exprs(vec![*e]).is_empty()) + .cloned() + .collect(); + + let plan_non_agg = table.clone().select(non_agg_exprs)?.into_optimized_plan()?; + let bytes_non_agg = logical_plan_to_bytes(&plan_non_agg)?; + let logical_round_trip_non_agg = + logical_plan_from_bytes(&bytes_non_agg, &ctx.task_ctx())?; + assert_eq!( + format!("{plan_non_agg}"), + format!("{logical_round_trip_non_agg}") + ); + + let plan_agg = table.aggregate(vec![], agg_exprs)?.into_optimized_plan()?; + let bytes_agg = logical_plan_to_bytes(&plan_agg)?; + let logical_round_trip_agg = logical_plan_from_bytes(&bytes_agg, &ctx.task_ctx())?; + assert_eq!(format!("{plan_agg}"), format!("{logical_round_trip_agg}")); + Ok(()) }