Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.PullOutNondeterministic
import org.apache.spark.sql.catalyst.expressions.{AliasHelper, AttributeSet, ExpressionSet}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
Expand All @@ -32,22 +31,8 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsPattern(AGGREGATE), ruleId) {
case upper @ Aggregate(_, _, lower: Aggregate) if isLowerRedundant(upper, lower) =>
val aliasMap = getAliasMap(lower)

val newAggregate = upper.copy(
child = lower.child,
groupingExpressions = upper.groupingExpressions.map(replaceAlias(_, aliasMap)),
aggregateExpressions = upper.aggregateExpressions.map(
replaceAliasButKeepName(_, aliasMap))
)

// We might have introduces non-deterministic grouping expression
if (newAggregate.groupingExpressions.exists(!_.deterministic)) {
PullOutNondeterministic.applyLocally.applyOrElse(newAggregate, identity[LogicalPlan])
} else {
newAggregate
}

val projectList = lower.aggregateExpressions.filter(upper.references.contains(_))
upper.copy(child = Project(projectList, lower.child))
case agg @ Aggregate(groupingExps, _, child)
if agg.groupOnly && child.distinctKeys.exists(_.subsetOf(ExpressionSet(groupingExps))) =>
Project(agg.aggregateExpressions, child)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ class RemoveRedundantAggregatesSuite extends PlanTest {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("RemoveRedundantAggregates", FixedPoint(10),
RemoveRedundantAggregates) :: Nil
RemoveRedundantAggregates,
RemoveNoopOperators) :: Nil
}

private val relation = LocalRelation('a.int, 'b.int)
Expand All @@ -53,6 +54,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
.groupBy('a)('a)
.analyze
val expected = relation
.select('a)
.groupBy('a)('a)
.analyze
val optimized = Optimize.execute(query)
Expand All @@ -68,6 +70,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
.groupBy('a)('a)
.analyze
val expected = relation
.select('a)
.groupBy('a)('a)
.analyze
val optimized = Optimize.execute(query)
Expand All @@ -81,6 +84,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
.groupBy('a)('a)
.analyze
val expected = relation
.select('a)
.groupBy('a)('a)
.analyze
val optimized = Optimize.execute(query)
Expand All @@ -94,7 +98,8 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
.groupBy('c)('c)
.analyze
val expected = relation
.groupBy('a + 'b)(('a + 'b) as 'c)
.select(('a + 'b) as 'c)
.groupBy('c)('c')
.analyze
val optimized = Optimize.execute(query)
comparePlans(optimized, expected)
Expand All @@ -107,6 +112,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
.groupBy('a)('a, rand(0) as 'c)
.analyze
val expected = relation
.select('a)
.groupBy('a)('a, rand(0) as 'c)
.analyze
val optimized = Optimize.execute(query)
Expand All @@ -119,8 +125,9 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
.groupBy('a, 'c)('a, 'c)
.analyze
val expected = relation
.groupBy('a, 'c)('a, rand(0) as 'c)
.analyze
.select('a, 'b, rand(0) as '_nondeterministic)
.select('a, '_nondeterministic as 'c)
.groupBy('a, 'c)('a, 'c)
val optimized = Optimize.execute(query)
comparePlans(optimized, expected)
}
Expand Down Expand Up @@ -152,7 +159,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest {

test("Remove redundant aggregate - upper has contains foldable expressions") {
val originalQuery = x.groupBy('a, 'b)('a, 'b).groupBy('a)('a, TrueLiteral).analyze
val correctAnswer = x.groupBy('a)('a, TrueLiteral).analyze
val correctAnswer = x.select('a).groupBy('a)('a, TrueLiteral).analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}
Expand Down Expand Up @@ -188,7 +195,6 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
.groupBy("x.a".attr, "x.b".attr)("x.a".attr, "x.b".attr)
val correctAnswer = x.groupBy('a, 'b)('a, 'b)
.join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr))
.select("x.a".attr, "x.b".attr)

val optimized = Optimize.execute(originalQuery.analyze)
comparePlans(optimized, correctAnswer.analyze)
Expand All @@ -202,7 +208,6 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
.groupBy("x.a".attr, "d".attr)("x.a".attr, "d".attr)
val correctAnswer = x.groupBy('a, 'b)('a, 'b.as("d"))
.join(y, joinType, Some("x.a".attr === "y.a".attr && "d".attr === "y.b".attr))
.select("x.a".attr, "d".attr)

val optimized = Optimize.execute(originalQuery.analyze)
comparePlans(optimized, correctAnswer.analyze)
Expand Down Expand Up @@ -232,7 +237,6 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
.groupBy("x.a".attr, "x.b".attr)("x.a".attr)
val correctAnswer = x.groupBy('a, 'b)('a, 'b)
.join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr))
.select("x.a".attr, "x.b".attr)
.join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr))
.select("x.a".attr)

Expand Down
13 changes: 13 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/group-by.sql
Original file line number Diff line number Diff line change
Expand Up @@ -295,3 +295,16 @@ SELECT
FROM aggr
GROUP BY k
ORDER BY k;


-- SPARK-44846: PushFoldableIntoBranches in complex grouping expressions cause bindReference error
SELECT c * 2 AS d
FROM (
SELECT if(b > 1, 1, b) AS c
FROM (
SELECT if(a < 0, 0, a) AS b
FROM VALUES (-1), (1), (2) AS t1(a)
) t2
GROUP BY b
) t3
GROUP BY c;
18 changes: 18 additions & 0 deletions sql/core/src/test/resources/sql-tests/results/group-by.sql.out
Original file line number Diff line number Diff line change
Expand Up @@ -1028,3 +1028,21 @@ struct<k:int,percentile_disc(0.25) WITHIN GROUP (ORDER BY v):double,percentile_d
2 10.0 30.0
3 60.0 60.0
4 NULL NULL


-- !query
SELECT c * 2 AS d
FROM (
SELECT if(b > 1, 1, b) AS c
FROM (
SELECT if(a < 0, 0, a) AS b
FROM VALUES (-1), (1), (2) AS t1(a)
) t2
GROUP BY b
) t3
GROUP BY c
-- !query schema
struct<d:int>
-- !query output
0
2
Loading