Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ object VeloxBackendSettings extends BackendSettingsApi {

/**
 * Velox executes WindowGroupLimit via TopNRowNumberNode, which supports row_number,
 * rank and dense_rank as the rank-like function — report support for exactly those.
 *
 * @param rankLikeFunction the window function the group limit ranks by
 * @return true when Velox can execute the group limit natively
 */
override def supportWindowGroupLimitExec(rankLikeFunction: Expression): Boolean = {
  rankLikeFunction match {
    // A single alternative pattern; the flattened diff had left a redundant
    // `case _: RowNumber => true` arm that the merged pattern fully shadows.
    case _: RowNumber | _: Rank | _: DenseRank => true
    case _ => false
  }
}
Expand Down
16 changes: 16 additions & 0 deletions cpp/velox/substrait/SubstraitParser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,22 @@ bool SubstraitParser::configSetInOptimization(
return false;
}

// Returns true when the AdvancedExtension's optimization payload carries a
// "window_function=<name>" entry whose value starts with `targetFunction`.
// The payload is a google.protobuf.StringValue packed into an Any (written by
// WindowGroupLimitExecTransformer as "window_function=<name>\n"). Matching is
// a prefix match anchored right after "window_function="; callers that probe
// several names (e.g. "rank" vs "dense_rank") rely on the anchor to avoid
// cross-matches.
bool SubstraitParser::checkWindowFunction(
    const ::substrait::extensions::AdvancedExtension& extension,
    const std::string& targetFunction) {
  static const std::string kConfig = "window_function=";
  if (!extension.has_optimization()) {
    return false;
  }
  google::protobuf::StringValue msg;
  // If the Any does not actually hold a StringValue, treat the config as absent
  // instead of silently scanning an empty message.
  if (!extension.optimization().UnpackTo(&msg)) {
    return false;
  }
  const std::string& value = msg.value();
  const std::size_t pos = value.find(kConfig);
  if (pos == std::string::npos) {
    return false;
  }
  // Compare in place instead of allocating a substring. compare() clamps the
  // count at the end of `value`, so a target longer than the remaining payload
  // simply compares unequal — no out-of-range access.
  return value.compare(pos + kConfig.size(), targetFunction.size(), targetFunction) == 0;
}

std::vector<TypePtr> SubstraitParser::sigToTypes(const std::string& signature) {
std::vector<std::string> typeStrs = SubstraitParser::getSubFunctionTypes(signature);
std::vector<TypePtr> types;
Expand Down
7 changes: 7 additions & 0 deletions cpp/velox/substrait/SubstraitParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ class SubstraitParser {
/// @return Whether the config is set as true.
static bool configSetInOptimization(const ::substrait::extensions::AdvancedExtension&, const std::string& config);

/// @brief Check whether the window function recorded in the AdvancedExtension
/// optimization matches the given target function name.
/// @param extension Substrait advanced extension.
/// @param targetFunction The window function name to match, e.g. "rank".
/// @return Whether the target function matches the one set in the extension.
static bool checkWindowFunction(const ::substrait::extensions::AdvancedExtension&, const std::string& targetFunction);

/// Extract input types from Substrait function signature.
static std::vector<facebook::velox::TypePtr> sigToTypes(const std::string& functionSig);

Expand Down
11 changes: 10 additions & 1 deletion cpp/velox/substrait/SubstraitToVeloxPlan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1169,9 +1169,18 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(
childNode);
}

auto windowFunc = core::TopNRowNumberNode::RankFunction::kRowNumber;
if (windowGroupLimitRel.has_advanced_extension()) {
if (SubstraitParser::checkWindowFunction(windowGroupLimitRel.advanced_extension(), "rank")){
windowFunc = core::TopNRowNumberNode::RankFunction::kRank;
} else if (SubstraitParser::checkWindowFunction(windowGroupLimitRel.advanced_extension(), "dense_rank")) {
windowFunc = core::TopNRowNumberNode::RankFunction::kDenseRank;
}
}

return std::make_shared<core::TopNRowNumberNode>(
nextPlanNodeId(),
core::TopNRowNumberNode::RankFunction::kRowNumber,
windowFunc,
partitionKeys,
sortingKeys,
sortingOrders,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,19 @@
package org.apache.gluten.execution

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.expression.ExpressionConverter
import org.apache.gluten.metrics.MetricsUpdater
import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.extensions.ExtensionBuilder
import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}

import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, DenseRank, Expression, Rank, RowNumber, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.window.{GlutenFinal, GlutenPartial, GlutenWindowGroupLimitMode}

import com.google.protobuf.StringValue
import io.substrait.proto.SortField

import scala.collection.JavaConverters._
Expand Down Expand Up @@ -111,11 +114,27 @@ case class WindowGroupLimitExecTransformer(
builder.build()
}.asJava
if (!validation) {
val windowFunction = rankLikeFunction match {
case _: RowNumber => "row_number"
case _: Rank => "rank"
case _: DenseRank => "dense_rank"
case _ => throw new GlutenNotSupportException(s"Unknow window function $rankLikeFunction")
}
val parametersStr = new StringBuffer("WindowGroupLimitParameters:")
parametersStr
.append("window_function=")
.append(windowFunction)
.append("\n")
val message = StringValue.newBuilder().setValue(parametersStr.toString).build()
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
BackendsApiManager.getTransformerApiInstance.packPBMessage(message),
null)
RelBuilder.makeWindowGroupLimitRel(
input,
partitionsExpressions,
sortFieldList,
limit,
extensionNode,
context,
operatorId)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2006,6 +2006,7 @@ class ClickHouseTestSettings extends BackendTestSettings {
.excludeCH("SPLIT")
enableSuite[GlutenRemoveRedundantWindowGroupLimitsSuite]
.excludeCH("remove redundant WindowGroupLimits")
.excludeCH("Gluten - remove redundant WindowGroupLimits")
enableSuite[GlutenReplaceHashWithSortAggSuite]
.exclude("replace partial hash aggregate with sort aggregate")
.exclude("replace partial and final hash aggregate together with sort aggregate")
Expand Down Expand Up @@ -2060,6 +2061,8 @@ class ClickHouseTestSettings extends BackendTestSettings {
.excludeCH(
"window function: multiple window expressions specified by range in a single expression")
.excludeCH("Gluten - Filter on row number")
.excludeCH("Gluten - Filter on rank")
.excludeCH("Gluten - Filter on dense_rank")
enableSuite[GlutenSameResultSuite]
enableSuite[GlutenSaveLoadSuite]
enableSuite[GlutenScalaReflectionRelationSuite]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -960,6 +960,8 @@ class VeloxTestSettings extends BackendTestSettings {
enableSuite[GlutenParquetFileMetadataStructRowIndexSuite]
enableSuite[GlutenTableLocationSuite]
enableSuite[GlutenRemoveRedundantWindowGroupLimitsSuite]
// rewrite with Gluten test
.exclude("remove redundant WindowGroupLimits")
enableSuite[GlutenSQLCollectLimitExecSuite]
enableSuite[GlutenBatchEvalPythonExecSuite]
// Replaced with other tests that check for native operations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,57 @@
*/
package org.apache.spark.sql.execution

import org.apache.gluten.execution.WindowGroupLimitExecTransformer

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.GlutenSQLTestsBaseTrait
import org.apache.spark.sql.functions.lit

/**
 * Gluten rewrite of RemoveRedundantWindowGroupLimitsSuite: asserts the expected
 * number of native WindowGroupLimitExecTransformer nodes instead of vanilla
 * Spark's WindowGroupLimitExec nodes.
 */
class GlutenRemoveRedundantWindowGroupLimitsSuite
  extends RemoveRedundantWindowGroupLimitsSuite
  with GlutenSQLTestsBaseTrait {
  // NOTE: the flattened diff had retained the superseded one-line class body
  // `with GlutenSQLTestsBaseTrait {}` immediately before this one, which is
  // invalid Scala; only the expanded body is kept.

  /** Asserts the executed plan contains exactly `count` transformer nodes (subqueries included). */
  private def checkNumWindowGroupLimits(df: DataFrame, count: Int): Unit = {
    val plan = df.queryExecution.executedPlan
    assert(collectWithSubqueries(plan) {
      case exec: WindowGroupLimitExecTransformer => exec
    }.length == count)
  }

  /** Runs `query`, checks the node count, then re-checks the answer for stability. */
  private def checkWindowGroupLimits(query: String, count: Int): Unit = {
    val df = sql(query)
    checkNumWindowGroupLimits(df, count)
    val result = df.collect()
    checkAnswer(df, result)
  }

  testGluten("remove redundant WindowGroupLimits") {
    withTempView("t") {
      spark.range(0, 100).withColumn("value", lit(1)).createOrReplaceTempView("t")
      // Partitioned window: the optimizer should leave a single group limit.
      val query1 =
        """
|SELECT *
|FROM (
| SELECT id, rank() OVER w AS rn
| FROM t
| GROUP BY id
| WINDOW w AS (PARTITION BY id ORDER BY max(value))
|)
|WHERE rn < 3
|""".stripMargin
      checkWindowGroupLimits(query1, 1)

      // Unpartitioned window: partial + final group limits remain (two nodes).
      val query2 =
        """
|SELECT *
|FROM (
| SELECT id, rank() OVER w AS rn
| FROM t
| GROUP BY id
| WINDOW w AS (ORDER BY max(value))
|)
|WHERE rn < 3
|""".stripMargin
      checkWindowGroupLimits(query2, 2)
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,51 @@ class GlutenSQLWindowFunctionSuite extends SQLWindowFunctionSuite with GlutenSQL
)
)
assert(
!getExecutedPlan(df).exists {
getExecutedPlan(df).exists {
case _: WindowGroupLimitExecTransformer => true
case _ => false
}
)
}
}

testGluten("Filter on dense_rank") {
withTable("customer") {
val rdd = spark.sparkContext.parallelize(customerData)
val customerDF = spark.createDataFrame(rdd, customerSchema)
customerDF.createOrReplaceTempView("customer")
val query =
"""
|SELECT * from (SELECT
| c_custkey,
| c_acctbal,
| dense_rank() OVER (
| PARTITION BY c_nationkey,
| "a"
| ORDER BY
| c_custkey,
| "a"
| ) AS rank
|FROM
| customer ORDER BY 1, 2) where rank <=2
|""".stripMargin
val df = sql(query)
checkAnswer(
df,
Seq(
Row(4553, BigDecimal(638841L, 2), 1),
Row(4953, BigDecimal(603728L, 2), 1),
Row(9954, BigDecimal(758725L, 2), 1),
Row(35403, BigDecimal(603470L, 2), 2),
Row(35803, BigDecimal(528487L, 2), 1),
Row(61065, BigDecimal(728477L, 2), 1),
Row(95337, BigDecimal(91561L, 2), 2),
Row(127412, BigDecimal(462141L, 2), 2),
Row(148303, BigDecimal(430230L, 2), 2)
)
)
assert(
getExecutedPlan(df).exists {
case _: WindowGroupLimitExecTransformer => true
case _ => false
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1982,6 +1982,7 @@ class ClickHouseTestSettings extends BackendTestSettings {
.excludeCH("SPLIT")
enableSuite[GlutenRemoveRedundantWindowGroupLimitsSuite]
.excludeCH("remove redundant WindowGroupLimits")
.excludeCH("Gluten - remove redundant WindowGroupLimits")
enableSuite[GlutenReplaceHashWithSortAggSuite]
.exclude("replace partial hash aggregate with sort aggregate")
.exclude("replace partial and final hash aggregate together with sort aggregate")
Expand Down Expand Up @@ -2036,6 +2037,8 @@ class ClickHouseTestSettings extends BackendTestSettings {
.excludeCH(
"window function: multiple window expressions specified by range in a single expression")
.excludeCH("Gluten - Filter on row number")
.excludeCH("Gluten - Filter on rank")
.excludeCH("Gluten - Filter on dense_rank")
enableSuite[GlutenSameResultSuite]
enableSuite[GlutenSaveLoadSuite]
enableSuite[GlutenScalaReflectionRelationSuite]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,8 @@ class VeloxTestSettings extends BackendTestSettings {
enableSuite[GlutenParquetFileMetadataStructRowIndexSuite]
enableSuite[GlutenTableLocationSuite]
enableSuite[GlutenRemoveRedundantWindowGroupLimitsSuite]
// rewrite with Gluten test
.exclude("remove redundant WindowGroupLimits")
enableSuite[GlutenSQLCollectLimitExecSuite]
// Generated suites for org.apache.spark.sql.execution.python
// TODO: 4.x enableSuite[GlutenPythonDataSourceSuite] // 1 failure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,57 @@
*/
package org.apache.spark.sql.execution

import org.apache.gluten.execution.WindowGroupLimitExecTransformer

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.GlutenSQLTestsBaseTrait
import org.apache.spark.sql.functions.lit

/**
 * Gluten rewrite of RemoveRedundantWindowGroupLimitsSuite: asserts the expected
 * number of native WindowGroupLimitExecTransformer nodes instead of vanilla
 * Spark's WindowGroupLimitExec nodes.
 */
class GlutenRemoveRedundantWindowGroupLimitsSuite
  extends RemoveRedundantWindowGroupLimitsSuite
  with GlutenSQLTestsBaseTrait {
  // NOTE: the flattened diff had retained the superseded one-line class body
  // `with GlutenSQLTestsBaseTrait {}` immediately before this one, which is
  // invalid Scala; only the expanded body is kept.

  /** Asserts the executed plan contains exactly `count` transformer nodes (subqueries included). */
  private def checkNumWindowGroupLimits(df: DataFrame, count: Int): Unit = {
    val plan = df.queryExecution.executedPlan
    assert(collectWithSubqueries(plan) {
      case exec: WindowGroupLimitExecTransformer => exec
    }.length == count)
  }

  /** Runs `query`, checks the node count, then re-checks the answer for stability. */
  private def checkWindowGroupLimits(query: String, count: Int): Unit = {
    val df = sql(query)
    checkNumWindowGroupLimits(df, count)
    val result = df.collect()
    checkAnswer(df, result)
  }

  testGluten("remove redundant WindowGroupLimits") {
    withTempView("t") {
      spark.range(0, 100).withColumn("value", lit(1)).createOrReplaceTempView("t")
      // Partitioned window: the optimizer should leave a single group limit.
      val query1 =
        """
|SELECT *
|FROM (
| SELECT id, rank() OVER w AS rn
| FROM t
| GROUP BY id
| WINDOW w AS (PARTITION BY id ORDER BY max(value))
|)
|WHERE rn < 3
|""".stripMargin
      checkWindowGroupLimits(query1, 1)

      // Unpartitioned window: partial + final group limits remain (two nodes).
      val query2 =
        """
|SELECT *
|FROM (
| SELECT id, rank() OVER w AS rn
| FROM t
| GROUP BY id
| WINDOW w AS (ORDER BY max(value))
|)
|WHERE rn < 3
|""".stripMargin
      checkWindowGroupLimits(query2, 2)
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,51 @@ class GlutenSQLWindowFunctionSuite extends SQLWindowFunctionSuite with GlutenSQL
)
)
assert(
!getExecutedPlan(df).exists {
getExecutedPlan(df).exists {
case _: WindowGroupLimitExecTransformer => true
case _ => false
}
)
}
}

testGluten("Filter on dense_rank") {
withTable("customer") {
val rdd = spark.sparkContext.parallelize(customerData)
val customerDF = spark.createDataFrame(rdd, customerSchema)
customerDF.createOrReplaceTempView("customer")
val query =
"""
|SELECT * from (SELECT
| c_custkey,
| c_acctbal,
| dense_rank() OVER (
| PARTITION BY c_nationkey,
| "a"
| ORDER BY
| c_custkey,
| "a"
| ) AS rank
|FROM
| customer ORDER BY 1, 2) where rank <=2
|""".stripMargin
val df = sql(query)
checkAnswer(
df,
Seq(
Row(4553, BigDecimal(638841L, 2), 1),
Row(4953, BigDecimal(603728L, 2), 1),
Row(9954, BigDecimal(758725L, 2), 1),
Row(35403, BigDecimal(603470L, 2), 2),
Row(35803, BigDecimal(528487L, 2), 1),
Row(61065, BigDecimal(728477L, 2), 1),
Row(95337, BigDecimal(91561L, 2), 2),
Row(127412, BigDecimal(462141L, 2), 2),
Row(148303, BigDecimal(430230L, 2), 2)
)
)
assert(
getExecutedPlan(df).exists {
case _: WindowGroupLimitExecTransformer => true
case _ => false
}
Expand Down
Loading