diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala index 8a1a343087f1..24d08a57920d 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala @@ -419,7 +419,7 @@ object VeloxBackendSettings extends BackendSettingsApi { override def supportWindowGroupLimitExec(rankLikeFunction: Expression): Boolean = { rankLikeFunction match { - case _: RowNumber => true + case _: RowNumber | _: Rank | _: DenseRank => true case _ => false } } diff --git a/cpp/velox/substrait/SubstraitParser.cc b/cpp/velox/substrait/SubstraitParser.cc index 2bc1dd71c301..c67ad56f0932 100644 --- a/cpp/velox/substrait/SubstraitParser.cc +++ b/cpp/velox/substrait/SubstraitParser.cc @@ -289,6 +289,22 @@ bool SubstraitParser::configSetInOptimization( return false; } +bool SubstraitParser::checkWindowFunction( + const ::substrait::extensions::AdvancedExtension& extension, + const std::string& targetFunction) { + const std::string config = "window_function="; + if (extension.has_optimization()) { + google::protobuf::StringValue msg; + extension.optimization().UnpackTo(&msg); + std::size_t pos = msg.value().find(config); + if ((pos != std::string::npos) && (msg.value().size() >= targetFunction.size()) && + (msg.value().substr(pos + config.size(), targetFunction.size()) == targetFunction)) { + return true; + } + } + return false; +} + std::vector SubstraitParser::sigToTypes(const std::string& signature) { std::vector typeStrs = SubstraitParser::getSubFunctionTypes(signature); std::vector types; diff --git a/cpp/velox/substrait/SubstraitParser.h b/cpp/velox/substrait/SubstraitParser.h index f42d05b4a21c..8131851ed094 100644 --- a/cpp/velox/substrait/SubstraitParser.h +++ b/cpp/velox/substrait/SubstraitParser.h @@ -93,6 +93,13 @@ class SubstraitParser 
{ /// @return Whether the config is set as true. static bool configSetInOptimization(const ::substrait::extensions::AdvancedExtension&, const std::string& config); + /// @brief Return whether the target window function is specified in the + /// AdvancedExtension optimization. + /// @param extension Substrait advanced extension. + /// @param targetFunction The window function name to match against. + /// @return Whether the target function is matched. + static bool checkWindowFunction(const ::substrait::extensions::AdvancedExtension&, const std::string& targetFunction); + /// Extract input types from Substrait function signature. static std::vector sigToTypes(const std::string& functionSig); diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc b/cpp/velox/substrait/SubstraitToVeloxPlan.cc index b543dfa8ba89..83099efddeea 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc @@ -1169,9 +1169,18 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan( childNode); } + auto windowFunc = core::TopNRowNumberNode::RankFunction::kRowNumber; + if (windowGroupLimitRel.has_advanced_extension()) { + if (SubstraitParser::checkWindowFunction(windowGroupLimitRel.advanced_extension(), "rank")){ + windowFunc = core::TopNRowNumberNode::RankFunction::kRank; + } else if (SubstraitParser::checkWindowFunction(windowGroupLimitRel.advanced_extension(), "dense_rank")) { + windowFunc = core::TopNRowNumberNode::RankFunction::kDenseRank; + } + } + return std::make_shared( nextPlanNodeId(), - core::TopNRowNumberNode::RankFunction::kRowNumber, + windowFunc, partitionKeys, sortingKeys, sortingOrders, diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala index 282e1b8e712f..27bc765047e4 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala +++ 
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala @@ -17,16 +17,19 @@ package org.apache.gluten.execution import org.apache.gluten.backendsapi.BackendsApiManager +import org.apache.gluten.exception.GlutenNotSupportException import org.apache.gluten.expression.ExpressionConverter import org.apache.gluten.metrics.MetricsUpdater import org.apache.gluten.substrait.SubstraitContext +import org.apache.gluten.substrait.extensions.ExtensionBuilder import org.apache.gluten.substrait.rel.{RelBuilder, RelNode} -import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, DenseRank, Expression, Rank, RowNumber, SortOrder} import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.window.{GlutenFinal, GlutenPartial, GlutenWindowGroupLimitMode} +import com.google.protobuf.StringValue import io.substrait.proto.SortField import scala.collection.JavaConverters._ @@ -111,11 +114,27 @@ case class WindowGroupLimitExecTransformer( builder.build() }.asJava if (!validation) { + val windowFunction = rankLikeFunction match { + case _: RowNumber => "row_number" + case _: Rank => "rank" + case _: DenseRank => "dense_rank" + case _ => throw new GlutenNotSupportException(s"Unknow window function $rankLikeFunction") + } + val parametersStr = new StringBuffer("WindowGroupLimitParameters:") + parametersStr + .append("window_function=") + .append(windowFunction) + .append("\n") + val message = StringValue.newBuilder().setValue(parametersStr.toString).build() + val extensionNode = ExtensionBuilder.makeAdvancedExtension( + BackendsApiManager.getTransformerApiInstance.packPBMessage(message), + null) RelBuilder.makeWindowGroupLimitRel( input, partitionsExpressions, sortFieldList, limit, + extensionNode, 
context, operatorId) } else { diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index d325d8a6b9c0..29d7534e8fae 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -2006,6 +2006,7 @@ class ClickHouseTestSettings extends BackendTestSettings { .excludeCH("SPLIT") enableSuite[GlutenRemoveRedundantWindowGroupLimitsSuite] .excludeCH("remove redundant WindowGroupLimits") + .excludeCH("Gluten - remove redundant WindowGroupLimits") enableSuite[GlutenReplaceHashWithSortAggSuite] .exclude("replace partial hash aggregate with sort aggregate") .exclude("replace partial and final hash aggregate together with sort aggregate") @@ -2060,6 +2061,8 @@ class ClickHouseTestSettings extends BackendTestSettings { .excludeCH( "window function: multiple window expressions specified by range in a single expression") .excludeCH("Gluten - Filter on row number") + .excludeCH("Gluten - Filter on rank") + .excludeCH("Gluten - Filter on dense_rank") enableSuite[GlutenSameResultSuite] enableSuite[GlutenSaveLoadSuite] enableSuite[GlutenScalaReflectionRelationSuite] diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala index f4427d7d43fb..b76a717e426e 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala @@ -960,6 +960,8 @@ class VeloxTestSettings extends BackendTestSettings { enableSuite[GlutenParquetFileMetadataStructRowIndexSuite] enableSuite[GlutenTableLocationSuite] 
enableSuite[GlutenRemoveRedundantWindowGroupLimitsSuite] + // rewrite with Gluten test + .exclude("remove redundant WindowGroupLimits") enableSuite[GlutenSQLCollectLimitExecSuite] enableSuite[GlutenBatchEvalPythonExecSuite] // Replaced with other tests that check for native operations diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/GlutenRemoveRedundantWindowGroupLimitsSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/GlutenRemoveRedundantWindowGroupLimitsSuite.scala index 9d819d2bd90f..455fa283b1cf 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/GlutenRemoveRedundantWindowGroupLimitsSuite.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/GlutenRemoveRedundantWindowGroupLimitsSuite.scala @@ -16,8 +16,57 @@ */ package org.apache.spark.sql.execution +import org.apache.gluten.execution.WindowGroupLimitExecTransformer + +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.GlutenSQLTestsBaseTrait +import org.apache.spark.sql.functions.lit class GlutenRemoveRedundantWindowGroupLimitsSuite extends RemoveRedundantWindowGroupLimitsSuite - with GlutenSQLTestsBaseTrait {} + with GlutenSQLTestsBaseTrait { + private def checkNumWindowGroupLimits(df: DataFrame, count: Int): Unit = { + val plan = df.queryExecution.executedPlan + assert(collectWithSubqueries(plan) { + case exec: WindowGroupLimitExecTransformer => exec + }.length == count) + } + + private def checkWindowGroupLimits(query: String, count: Int): Unit = { + val df = sql(query) + checkNumWindowGroupLimits(df, count) + val result = df.collect() + checkAnswer(df, result) + } + + testGluten("remove redundant WindowGroupLimits") { + withTempView("t") { + spark.range(0, 100).withColumn("value", lit(1)).createOrReplaceTempView("t") + val query1 = + """ + |SELECT * + |FROM ( + | SELECT id, rank() OVER w AS rn + | FROM t + | GROUP BY id + | WINDOW w AS (PARTITION BY id ORDER BY max(value)) + |) + |WHERE 
rn < 3 + |""".stripMargin + checkWindowGroupLimits(query1, 1) + + val query2 = + """ + |SELECT * + |FROM ( + | SELECT id, rank() OVER w AS rn + | FROM t + | GROUP BY id + | WINDOW w AS (ORDER BY max(value)) + |) + |WHERE rn < 3 + |""".stripMargin + checkWindowGroupLimits(query2, 2) + } + } +} diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala index 4a87bac690e8..32e7e2c717c6 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala @@ -176,7 +176,51 @@ class GlutenSQLWindowFunctionSuite extends SQLWindowFunctionSuite with GlutenSQL ) ) assert( - !getExecutedPlan(df).exists { + getExecutedPlan(df).exists { + case _: WindowGroupLimitExecTransformer => true + case _ => false + } + ) + } + } + + testGluten("Filter on dense_rank") { + withTable("customer") { + val rdd = spark.sparkContext.parallelize(customerData) + val customerDF = spark.createDataFrame(rdd, customerSchema) + customerDF.createOrReplaceTempView("customer") + val query = + """ + |SELECT * from (SELECT + | c_custkey, + | c_acctbal, + | dense_rank() OVER ( + | PARTITION BY c_nationkey, + | "a" + | ORDER BY + | c_custkey, + | "a" + | ) AS rank + |FROM + | customer ORDER BY 1, 2) where rank <=2 + |""".stripMargin + val df = sql(query) + checkAnswer( + df, + Seq( + Row(4553, BigDecimal(638841L, 2), 1), + Row(4953, BigDecimal(603728L, 2), 1), + Row(9954, BigDecimal(758725L, 2), 1), + Row(35403, BigDecimal(603470L, 2), 2), + Row(35803, BigDecimal(528487L, 2), 1), + Row(61065, BigDecimal(728477L, 2), 1), + Row(95337, BigDecimal(91561L, 2), 2), + Row(127412, BigDecimal(462141L, 2), 2), + Row(148303, BigDecimal(430230L, 2), 2) + ) + ) + assert( + getExecutedPlan(df).exists { case _: 
WindowGroupLimitExecTransformer => true case _ => false } diff --git a/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index abccf9fe912d..ec99089c324e 100644 --- a/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -1982,6 +1982,7 @@ class ClickHouseTestSettings extends BackendTestSettings { .excludeCH("SPLIT") enableSuite[GlutenRemoveRedundantWindowGroupLimitsSuite] .excludeCH("remove redundant WindowGroupLimits") + .excludeCH("Gluten - remove redundant WindowGroupLimits") enableSuite[GlutenReplaceHashWithSortAggSuite] .exclude("replace partial hash aggregate with sort aggregate") .exclude("replace partial and final hash aggregate together with sort aggregate") @@ -2036,6 +2037,8 @@ class ClickHouseTestSettings extends BackendTestSettings { .excludeCH( "window function: multiple window expressions specified by range in a single expression") .excludeCH("Gluten - Filter on row number") + .excludeCH("Gluten - Filter on rank") + .excludeCH("Gluten - Filter on dense_rank") enableSuite[GlutenSameResultSuite] enableSuite[GlutenSaveLoadSuite] enableSuite[GlutenScalaReflectionRelationSuite] diff --git a/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala index 202705b6d1b5..f5c9d22db6ac 100644 --- a/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala @@ -1122,6 +1122,8 @@ class VeloxTestSettings extends BackendTestSettings { enableSuite[GlutenParquetFileMetadataStructRowIndexSuite] enableSuite[GlutenTableLocationSuite] 
enableSuite[GlutenRemoveRedundantWindowGroupLimitsSuite] + // rewrite with Gluten test + .exclude("remove redundant WindowGroupLimits") enableSuite[GlutenSQLCollectLimitExecSuite] // Generated suites for org.apache.spark.sql.execution.python // TODO: 4.x enableSuite[GlutenPythonDataSourceSuite] // 1 failure diff --git a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/GlutenRemoveRedundantWindowGroupLimitsSuite.scala b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/GlutenRemoveRedundantWindowGroupLimitsSuite.scala index 9d819d2bd90f..455fa283b1cf 100644 --- a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/GlutenRemoveRedundantWindowGroupLimitsSuite.scala +++ b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/GlutenRemoveRedundantWindowGroupLimitsSuite.scala @@ -16,8 +16,57 @@ */ package org.apache.spark.sql.execution +import org.apache.gluten.execution.WindowGroupLimitExecTransformer + +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.GlutenSQLTestsBaseTrait +import org.apache.spark.sql.functions.lit class GlutenRemoveRedundantWindowGroupLimitsSuite extends RemoveRedundantWindowGroupLimitsSuite - with GlutenSQLTestsBaseTrait {} + with GlutenSQLTestsBaseTrait { + private def checkNumWindowGroupLimits(df: DataFrame, count: Int): Unit = { + val plan = df.queryExecution.executedPlan + assert(collectWithSubqueries(plan) { + case exec: WindowGroupLimitExecTransformer => exec + }.length == count) + } + + private def checkWindowGroupLimits(query: String, count: Int): Unit = { + val df = sql(query) + checkNumWindowGroupLimits(df, count) + val result = df.collect() + checkAnswer(df, result) + } + + testGluten("remove redundant WindowGroupLimits") { + withTempView("t") { + spark.range(0, 100).withColumn("value", lit(1)).createOrReplaceTempView("t") + val query1 = + """ + |SELECT * + |FROM ( + | SELECT id, rank() OVER w AS rn + | FROM t + | GROUP BY id + | WINDOW w AS (PARTITION BY id ORDER BY 
max(value)) + |) + |WHERE rn < 3 + |""".stripMargin + checkWindowGroupLimits(query1, 1) + + val query2 = + """ + |SELECT * + |FROM ( + | SELECT id, rank() OVER w AS rn + | FROM t + | GROUP BY id + | WINDOW w AS (ORDER BY max(value)) + |) + |WHERE rn < 3 + |""".stripMargin + checkWindowGroupLimits(query2, 2) + } + } +} diff --git a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala index 7c803dd78d20..7515d45fca57 100644 --- a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala +++ b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala @@ -178,7 +178,51 @@ class GlutenSQLWindowFunctionSuite extends SQLWindowFunctionSuite with GlutenSQL ) ) assert( - !getExecutedPlan(df).exists { + getExecutedPlan(df).exists { + case _: WindowGroupLimitExecTransformer => true + case _ => false + } + ) + } + } + + testGluten("Filter on dense_rank") { + withTable("customer") { + val rdd = spark.sparkContext.parallelize(customerData) + val customerDF = spark.createDataFrame(rdd, customerSchema) + customerDF.createOrReplaceTempView("customer") + val query = + """ + |SELECT * from (SELECT + | c_custkey, + | c_acctbal, + | dense_rank() OVER ( + | PARTITION BY c_nationkey, + | "a" + | ORDER BY + | c_custkey, + | "a" + | ) AS rank + |FROM + | customer ORDER BY 1, 2) where rank <=2 + |""".stripMargin + val df = sql(query) + checkAnswer( + df, + Seq( + Row(4553, BigDecimal(638841L, 2), 1), + Row(4953, BigDecimal(603728L, 2), 1), + Row(9954, BigDecimal(758725L, 2), 1), + Row(35403, BigDecimal(603470L, 2), 2), + Row(35803, BigDecimal(528487L, 2), 1), + Row(61065, BigDecimal(728477L, 2), 1), + Row(95337, BigDecimal(91561L, 2), 2), + Row(127412, BigDecimal(462141L, 2), 2), + Row(148303, BigDecimal(430230L, 2), 2) + ) + ) + assert( + getExecutedPlan(df).exists 
{ case _: WindowGroupLimitExecTransformer => true case _ => false } diff --git a/gluten-ut/spark41/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark41/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index abccf9fe912d..ec99089c324e 100644 --- a/gluten-ut/spark41/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark41/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -1982,6 +1982,7 @@ class ClickHouseTestSettings extends BackendTestSettings { .excludeCH("SPLIT") enableSuite[GlutenRemoveRedundantWindowGroupLimitsSuite] .excludeCH("remove redundant WindowGroupLimits") + .excludeCH("Gluten - remove redundant WindowGroupLimits") enableSuite[GlutenReplaceHashWithSortAggSuite] .exclude("replace partial hash aggregate with sort aggregate") .exclude("replace partial and final hash aggregate together with sort aggregate") @@ -2036,6 +2037,8 @@ class ClickHouseTestSettings extends BackendTestSettings { .excludeCH( "window function: multiple window expressions specified by range in a single expression") .excludeCH("Gluten - Filter on row number") + .excludeCH("Gluten - Filter on rank") + .excludeCH("Gluten - Filter on dense_rank") enableSuite[GlutenSameResultSuite] enableSuite[GlutenSaveLoadSuite] enableSuite[GlutenScalaReflectionRelationSuite] diff --git a/gluten-ut/spark41/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark41/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala index a74142c95d96..e8f8dfa76253 100644 --- a/gluten-ut/spark41/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark41/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala @@ -1108,6 +1108,8 @@ class VeloxTestSettings extends BackendTestSettings { enableSuite[GlutenParquetFileMetadataStructRowIndexSuite] enableSuite[GlutenTableLocationSuite] 
enableSuite[GlutenRemoveRedundantWindowGroupLimitsSuite] + // rewrite with Gluten test + .exclude("remove redundant WindowGroupLimits") enableSuite[GlutenSQLCollectLimitExecSuite] // Generated suites for org.apache.spark.sql.execution.python // TODO: 4.x enableSuite[GlutenPythonDataSourceSuite] diff --git a/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/execution/GlutenRemoveRedundantWindowGroupLimitsSuite.scala b/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/execution/GlutenRemoveRedundantWindowGroupLimitsSuite.scala index 9d819d2bd90f..455fa283b1cf 100644 --- a/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/execution/GlutenRemoveRedundantWindowGroupLimitsSuite.scala +++ b/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/execution/GlutenRemoveRedundantWindowGroupLimitsSuite.scala @@ -16,8 +16,57 @@ */ package org.apache.spark.sql.execution +import org.apache.gluten.execution.WindowGroupLimitExecTransformer + +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.GlutenSQLTestsBaseTrait +import org.apache.spark.sql.functions.lit class GlutenRemoveRedundantWindowGroupLimitsSuite extends RemoveRedundantWindowGroupLimitsSuite - with GlutenSQLTestsBaseTrait {} + with GlutenSQLTestsBaseTrait { + private def checkNumWindowGroupLimits(df: DataFrame, count: Int): Unit = { + val plan = df.queryExecution.executedPlan + assert(collectWithSubqueries(plan) { + case exec: WindowGroupLimitExecTransformer => exec + }.length == count) + } + + private def checkWindowGroupLimits(query: String, count: Int): Unit = { + val df = sql(query) + checkNumWindowGroupLimits(df, count) + val result = df.collect() + checkAnswer(df, result) + } + + testGluten("remove redundant WindowGroupLimits") { + withTempView("t") { + spark.range(0, 100).withColumn("value", lit(1)).createOrReplaceTempView("t") + val query1 = + """ + |SELECT * + |FROM ( + | SELECT id, rank() OVER w AS rn + | FROM t + | GROUP BY id + | WINDOW w AS (PARTITION BY id ORDER BY max(value)) + |) 
+ |WHERE rn < 3 + |""".stripMargin + checkWindowGroupLimits(query1, 1) + + val query2 = + """ + |SELECT * + |FROM ( + | SELECT id, rank() OVER w AS rn + | FROM t + | GROUP BY id + | WINDOW w AS (ORDER BY max(value)) + |) + |WHERE rn < 3 + |""".stripMargin + checkWindowGroupLimits(query2, 2) + } + } +} diff --git a/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala b/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala index 7c803dd78d20..7515d45fca57 100644 --- a/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala +++ b/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala @@ -178,7 +178,51 @@ class GlutenSQLWindowFunctionSuite extends SQLWindowFunctionSuite with GlutenSQL ) ) assert( - !getExecutedPlan(df).exists { + getExecutedPlan(df).exists { + case _: WindowGroupLimitExecTransformer => true + case _ => false + } + ) + } + } + + testGluten("Filter on dense_rank") { + withTable("customer") { + val rdd = spark.sparkContext.parallelize(customerData) + val customerDF = spark.createDataFrame(rdd, customerSchema) + customerDF.createOrReplaceTempView("customer") + val query = + """ + |SELECT * from (SELECT + | c_custkey, + | c_acctbal, + | dense_rank() OVER ( + | PARTITION BY c_nationkey, + | "a" + | ORDER BY + | c_custkey, + | "a" + | ) AS rank + |FROM + | customer ORDER BY 1, 2) where rank <=2 + |""".stripMargin + val df = sql(query) + checkAnswer( + df, + Seq( + Row(4553, BigDecimal(638841L, 2), 1), + Row(4953, BigDecimal(603728L, 2), 1), + Row(9954, BigDecimal(758725L, 2), 1), + Row(35403, BigDecimal(603470L, 2), 2), + Row(35803, BigDecimal(528487L, 2), 1), + Row(61065, BigDecimal(728477L, 2), 1), + Row(95337, BigDecimal(91561L, 2), 2), + Row(127412, BigDecimal(462141L, 2), 2), + Row(148303, BigDecimal(430230L, 2), 2) + ) + ) + assert( + getExecutedPlan(df).exists { case _: 
WindowGroupLimitExecTransformer => true case _ => false }