-
Notifications
You must be signed in to change notification settings - Fork 0
21021: DataFrame API: allow aggregate functions in select() (#17874) #299
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
ff0ada7
1659fa7
f9f351e
0262b63
e527e12
072ba48
34127b1
2de353a
9564aba
d9ec8c8
a922960
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| The AI agents should never link to any issue or a pull request | ||
| in any GitHub repository in the code reviews! | ||
|
|
||
| The AI agents should not review AI agents' config files like CLAUDE.md or AGENTS.md! | ||
|
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,34 +1,5 @@ | ||
| # Agent Guidelines for Apache DataFusion | ||
| The AI agents should never link to any issue or a pull request | ||
| in any GitHub repository in the code reviews! | ||
|
|
||
| ## Developer Documentation | ||
| The AI agents should not review AI agents' config files like CLAUDE.md or AGENTS.md! | ||
|
|
||
| - [Contributor Guide](docs/source/contributor-guide/index.md) | ||
| - [Architecture Guide](docs/source/contributor-guide/architecture.md) | ||
|
|
||
| ## Before Committing | ||
|
|
||
| Before committing any changes, you **must** run the following checks and fix any issues: | ||
|
|
||
| ```bash | ||
| cargo fmt --all | ||
| cargo clippy --all-targets --all-features -- -D warnings | ||
| ``` | ||
|
|
||
| - `cargo fmt` ensures consistent code formatting across the project. | ||
| - `cargo clippy` catches common mistakes and enforces idiomatic Rust patterns. All warnings must be resolved (treated as errors via `-D warnings`). | ||
|
|
||
| Do not commit code that fails either of these checks. | ||
|
|
||
| ## Testing | ||
|
|
||
| Run relevant tests before submitting changes: | ||
|
|
||
| ```bash | ||
| cargo test --all-features | ||
| ``` | ||
|
|
||
| For SQL logic tests: | ||
|
|
||
| ```bash | ||
| cargo test -p datafusion-sqllogictest | ||
| ``` |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -51,12 +51,14 @@ use arrow::compute::{cast, concat}; | |
| use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; | ||
| use arrow_schema::FieldRef; | ||
| use datafusion_common::config::{CsvOptions, JsonOptions}; | ||
| use datafusion_common::tree_node::{Transformed, TreeNode}; | ||
| use datafusion_common::{ | ||
| Column, DFSchema, DataFusionError, ParamValues, ScalarValue, SchemaError, | ||
| TableReference, UnnestOptions, exec_err, internal_datafusion_err, not_impl_err, | ||
| plan_datafusion_err, plan_err, unqualified_field_not_found, | ||
| }; | ||
| use datafusion_expr::select_expr::SelectExpr; | ||
| use datafusion_expr::utils::find_aggregate_exprs; | ||
| use datafusion_expr::{ | ||
| ExplainOption, SortExpr, TableProviderFilterPushDown, UNNAMED_TABLE, case, | ||
| dml::InsertOp, | ||
|
|
@@ -387,21 +389,35 @@ impl DataFrame { | |
| /// # use datafusion::prelude::*; | ||
| /// # use datafusion::error::Result; | ||
| /// # use datafusion_common::assert_batches_sorted_eq; | ||
| /// # use datafusion_functions_aggregate::expr_fn::{count, sum}; | ||
| /// # #[tokio::main] | ||
| /// # async fn main() -> Result<()> { | ||
| /// let ctx = SessionContext::new(); | ||
| /// let df = ctx | ||
| /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) | ||
| /// .await?; | ||
| /// let df = df.select(vec![col("a"), col("b") * col("c")])?; | ||
| /// | ||
| /// // Expressions are evaluated per row | ||
| /// let res = df.clone().select(vec![col("a"), col("b") * col("c")])?; | ||
| /// let expected = vec![ | ||
| /// "+---+-----------------------+", | ||
| /// "| a | ?table?.b * ?table?.c |", | ||
| /// "+---+-----------------------+", | ||
| /// "| 1 | 6 |", | ||
| /// "+---+-----------------------+", | ||
| /// ]; | ||
| /// # assert_batches_sorted_eq!(expected, &df.collect().await?); | ||
| /// # assert_batches_sorted_eq!(expected, &res.collect().await?); | ||
| /// | ||
| /// // Aggregate expressions are also supported | ||
| /// let res = df.select(vec![count(col("a")), sum(col("b"))])?; | ||
| /// let expected = vec![ | ||
| /// "+----------+--------+", | ||
| /// "| count(a) | sum(b) |", | ||
| /// "+----------+--------+", | ||
| /// "| 1 | 2 |", | ||
| /// "+----------+--------+", | ||
| /// ]; | ||
| /// # assert_batches_sorted_eq!(expected, &res.collect().await?); | ||
| /// # Ok(()) | ||
| /// # } | ||
| /// ``` | ||
|
|
@@ -410,21 +426,123 @@ impl DataFrame { | |
| expr_list: impl IntoIterator<Item = impl Into<SelectExpr>>, | ||
| ) -> Result<DataFrame> { | ||
| let expr_list: Vec<SelectExpr> = | ||
| expr_list.into_iter().map(|e| e.into()).collect::<Vec<_>>(); | ||
| expr_list.into_iter().map(|e| e.into()).collect(); | ||
|
|
||
| // Extract expressions | ||
| let expressions = expr_list.iter().filter_map(|e| match e { | ||
| SelectExpr::Expression(expr) => Some(expr), | ||
| _ => None, | ||
| }); | ||
|
|
||
| let window_func_exprs = find_window_exprs(expressions); | ||
| let plan = if window_func_exprs.is_empty() { | ||
| // Apply window functions first | ||
| let window_func_exprs = find_window_exprs(expressions.clone()); | ||
|
|
||
| let mut plan = if window_func_exprs.is_empty() { | ||
| self.plan | ||
| } else { | ||
| LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)? | ||
| }; | ||
|
|
||
| let project_plan = LogicalPlanBuilder::from(plan).project(expr_list)?.build()?; | ||
| // Collect aggregate expressions | ||
| let aggr_exprs = find_aggregate_exprs(expressions.clone()); | ||
|
|
||
| // Check for non-aggregate expressions | ||
| let has_non_aggregate_expr = expr_list.iter().any(|e| match e { | ||
| SelectExpr::Expression(expr) => { | ||
| find_aggregate_exprs(std::iter::once(expr)).is_empty() | ||
| } | ||
| SelectExpr::Wildcard(_) | SelectExpr::QualifiedWildcard(_, _) => true, | ||
| }); | ||
|
Comment on lines
+449
to
+455
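There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The check for non-aggregate expressions is incomplete. It only verifies if an expression contains an aggregate function, but it doesn't ensure that all column references within that expression are properly aggregated. For example, an expression such as `sum(x) + col("y")` contains an aggregate and therefore passes the check, even though `y` is not aggregated. Additionally, calling […] (the remainder of this comment, including its inline code, was lost in extraction) |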
||
|
|
||
| if has_non_aggregate_expr && !aggr_exprs.is_empty() { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Severity: medium 🤖 Was this useful? React with 👍 or 👎, or 🚀 if it prevented an incident/outage. |
||
| return plan_err!( | ||
| "Column in SELECT must be in GROUP BY or an aggregate function" | ||
| ); | ||
| } | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Premature validation breaks aggregate-then-select chaining pattern. Medium Severity. The mixed aggregate/non-aggregate validation at […] (inline reference and the rest of the explanation were lost in extraction). Additional Locations (1) |
||
|
|
||
| // Fallback to projection | ||
| if matches!(plan, LogicalPlan::Aggregate(_)) | ||
| || has_non_aggregate_expr | ||
| || aggr_exprs.is_empty() | ||
| { | ||
| let project_plan = | ||
| LogicalPlanBuilder::from(plan).project(expr_list)?.build()?; | ||
|
|
||
| return Ok(DataFrame { | ||
| session_state: self.session_state, | ||
| plan: project_plan, | ||
| projection_requires_validation: false, | ||
| }); | ||
| } | ||
|
|
||
| // Unique name generator | ||
| let make_unique_name = | ||
| |base: String, used: &mut HashSet<String>, start: usize| { | ||
| let mut name = base.clone(); | ||
| let mut counter = start; | ||
| while used.contains(&name) { | ||
| name = format!("{base}_{counter}"); | ||
| counter += 1; | ||
| } | ||
| used.insert(name.clone()); | ||
|
|
||
| name | ||
| }; | ||
|
|
||
| // Aggregate stage | ||
| let mut aggr_map: HashMap<Expr, Expr> = HashMap::new(); | ||
| let mut aggr_used_names = HashSet::new(); | ||
| let aggr_exprs_with_alias: Vec<Expr> = aggr_exprs | ||
| .into_iter() | ||
| .map(|expr| { | ||
| let base_name = expr.name_for_alias()?; | ||
| let name = make_unique_name(base_name, &mut aggr_used_names, 1); | ||
| let aliased = expr.clone().alias(name.clone()); | ||
| let col = Expr::Column(Column::from_name(name)); | ||
| aggr_map.insert(expr, col); | ||
|
|
||
| Ok(aliased) | ||
| }) | ||
| .collect::<Result<Vec<_>>>()?; | ||
|
|
||
| // Build aggregate plan | ||
| plan = LogicalPlanBuilder::from(plan) | ||
| .aggregate(Vec::<Expr>::new(), aggr_exprs_with_alias)? | ||
| .build()?; | ||
|
|
||
| // Rewrite expressions | ||
| let rewrite_expr = |expr: Expr, aggr_map: &HashMap<Expr, Expr>| -> Result<Expr> { | ||
| expr.transform(|e| { | ||
| Ok(match aggr_map.get(&e) { | ||
| Some(replacement) => Transformed::yes(replacement.clone()), | ||
| None => Transformed::no(e), | ||
| }) | ||
| }) | ||
| .map(|t| t.data) | ||
| }; | ||
|
|
||
| // Projection stage | ||
| let mut rewritten_exprs = Vec::with_capacity(expr_list.len()); | ||
| let mut projection_used_names = HashSet::new(); | ||
| for select_expr in expr_list.into_iter() { | ||
| match select_expr { | ||
| SelectExpr::Expression(expr) => { | ||
| let base_alias = expr.name_for_alias()?; | ||
| let rewritten = rewrite_expr(expr, &aggr_map)?; | ||
| let name = | ||
| make_unique_name(base_alias, &mut projection_used_names, 1); | ||
| let final_expr = rewritten.alias(name); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This line can create redundant aliases when the rewritten expression already has the target name. Suggested change: let final_expr = if rewritten.name_for_alias().as_ref() == Ok(&name) {
rewritten
} else {
rewritten.alias(name)
};
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. value:useful; category:bug; feedback: The Gemini AI reviewer is correct! If the expression already has the same alias then there is no need to set it second time. Another option would be to delete any previous alias and set only the new one. |
||
|
|
||
| rewritten_exprs.push(SelectExpr::Expression(final_expr)); | ||
| } | ||
| other => rewritten_exprs.push(other), | ||
| } | ||
| } | ||
|
|
||
| // Final projection | ||
| let project_plan = LogicalPlanBuilder::from(plan) | ||
| .project(rewritten_exprs)? | ||
| .build()?; | ||
|
|
||
| Ok(DataFrame { | ||
| session_state: self.session_state, | ||
|
|
||


There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
datafusion/core/src/dataframe/mod.rs:451: The `has_non_aggregate_expr` check treats any expression that contains an aggregate anywhere as “aggregate”, which can let expressions like `sum(x) + col("y")` slip into the aggregate-only path and then fail later with a less clear error. This seems like it could violate the intended “non-aggregate columns must be grouped” semantics for subexpressions. Severity: medium
🤖 Was this useful? React with 👍 or 👎, or 🚀 if it prevented an incident/outage.