diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala index 773868b0c450..426f806fb982 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala @@ -113,8 +113,7 @@ object VeloxRuleApi { injector.injectPostTransform(_ => AppendBatchResizeForShuffleInputAndOutput()) injector.injectPostTransform(_ => GpuBufferBatchResizeForShuffleInputOutput()) injector.injectPostTransform(_ => UnionTransformerRule()) - injector.injectPostTransform(c => PartialProjectRule.apply(c.session)) - injector.injectPostTransform(_ => PartialGenerateRule()) + injector.injectPostTransform(_ => PartialFallbackRules()) injector.injectPostTransform(_ => RemoveNativeWriteFilesSortAndProject()) injector.injectPostTransform(_ => PushDownFilterToScan) injector.injectPostTransform(_ => PushDownInputFileExpression.PostOffload) @@ -218,8 +217,7 @@ object VeloxRuleApi { injector.injectPostTransform(_ => GpuBufferBatchResizeForShuffleInputOutput()) injector.injectPostTransform(_ => RemoveTransitions) injector.injectPostTransform(_ => UnionTransformerRule()) - injector.injectPostTransform(c => PartialProjectRule.apply(c.session)) - injector.injectPostTransform(_ => PartialGenerateRule()) + injector.injectPostTransform(_ => PartialFallbackRules()) injector.injectPostTransform(_ => RemoveNativeWriteFilesSortAndProject()) injector.injectPostTransform(_ => PushDownFilterToScan) injector.injectPostTransform(_ => PushDownInputFileExpression.PostOffload) diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/PartialFallback.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/PartialFallback.scala new file mode 100644 index 000000000000..8a7da72b3e92 --- /dev/null +++ b/backends-velox/src/main/scala/org/apache/gluten/extension/PartialFallback.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.extension + +import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} +import org.apache.spark.sql.execution.{GenerateExec, ProjectExec, SparkPlan} +import org.apache.spark.sql.internal.SQLConf + +case class PartialFallbackRules() extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = { + new PartialFallbackRuleExecutor().execute(plan) + } + + private class PartialFallbackRuleExecutor extends RuleExecutor[SparkPlan] { + private def fixedPoint = + FixedPoint( + SQLConf.get.optimizerMaxIterations, + maxIterationsSetting = SQLConf.OPTIMIZER_MAX_ITERATIONS.key) + + override protected def batches: Seq[Batch] = Seq( + Batch("PartialFallback", fixedPoint, PartialProjectRule(), PartialGenerateRule())) + } +} + +object PartialFallback { + def supportPartialFallback(plan: SparkPlan): Boolean = { + plan.isInstanceOf[ProjectExec] || + plan.isInstanceOf[GenerateExec] + } +} diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/PartialGenerateRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/PartialGenerateRule.scala index 5e641862e4cd..b47bc23fb71c 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/extension/PartialGenerateRule.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/extension/PartialGenerateRule.scala @@ -17,7 +17,7 @@ package org.apache.gluten.extension import org.apache.gluten.config.GlutenConfig -import org.apache.gluten.execution.{ColumnarPartialGenerateExec, GenerateExecTransformer} +import org.apache.gluten.execution.{ColumnarPartialGenerateExec, GenerateExecTransformer, WholeStageTransformer} import org.apache.gluten.utils.PlanUtil import org.apache.spark.sql.catalyst.expressions.UserDefinedExpression @@ -29,23 +29,23 @@ case class PartialGenerateRule() extends Rule[SparkPlan] { if (!GlutenConfig.get.enableColumnarPartialGenerate) { return plan } - val newPlan = plan match { - // If the root node of the plan is a GenerateExec and its child is a gluten columnar op, - // we try to add a ColumnarPartialGenerateExec - case plan: GenerateExec if PlanUtil.isGlutenColumnarOp(plan.child) => - tryAddColumnarPartialGenerateExec(plan) - case _ => plan - } - newPlan.transformUp { - case parent: SparkPlan - if parent.children.exists(_.isInstanceOf[GenerateExec]) && - PlanUtil.isGlutenColumnarOp(parent) => - parent.mapChildren { - case plan: GenerateExec if PlanUtil.isGlutenColumnarOp(plan.child) => - tryAddColumnarPartialGenerateExec(plan) - case other => other - } - } + // Wrap a WholeStageTransformer to check if the top node supports partial fallback. + // It will be removed afterward. + val wrapped = WholeStageTransformer(plan)(-1) + wrapped + .transformUp { + case parent: SparkPlan + if parent.children.exists(_.isInstanceOf[GenerateExec]) && + (PlanUtil.isGlutenColumnarOp(parent) || PartialFallback.supportPartialFallback( + parent)) => + parent.mapChildren { + case plan: GenerateExec if PlanUtil.isGlutenColumnarOp(plan.child) => + tryAddColumnarPartialGenerateExec(plan) + case other => other + } + } + .children + .head } private def tryAddColumnarPartialGenerateExec(plan: GenerateExec): SparkPlan = { diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/PartialProjectRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/PartialProjectRule.scala index f60bf1174677..7a07ec0bc971 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/extension/PartialProjectRule.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/extension/PartialProjectRule.scala @@ -17,37 +17,36 @@ package org.apache.gluten.extension import org.apache.gluten.config.GlutenConfig -import org.apache.gluten.execution.ColumnarPartialProjectExec +import org.apache.gluten.execution.{ColumnarPartialProjectExec, WholeStageTransformer} import org.apache.gluten.utils.PlanUtil -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{ProjectExec, SparkPlan} -case class PartialProjectRule(spark: SparkSession) extends Rule[SparkPlan] { +case class PartialProjectRule() extends Rule[SparkPlan] { override def apply(plan: SparkPlan): SparkPlan = { if (!GlutenConfig.get.enableColumnarPartialProject) { return plan } - val newPlan = plan match { - // If the root node of the plan is a ProjectExec and its child is a gluten columnar op, - // we try to add a ColumnarPartialProjectExec - case p: ProjectExec if PlanUtil.isGlutenColumnarOp(p.child) => - tryAddColumnarPartialProjectExec(p) - case _ => plan - } + // Wrap a WholeStageTransformer to check if the top node supports partial fallback. + // It will be removed afterward. + val wrapped = WholeStageTransformer(plan)(-1) - newPlan.transformUp { - case parent: SparkPlan - if parent.children.exists(_.isInstanceOf[ProjectExec]) && - PlanUtil.isGlutenColumnarOp(parent) => - parent.mapChildren { - case p: ProjectExec if PlanUtil.isGlutenColumnarOp(p.child) => - tryAddColumnarPartialProjectExec(p) - case other => other - } - } + wrapped + .transformUp { + case parent: SparkPlan + if parent.children.exists(_.isInstanceOf[ProjectExec]) && + (PlanUtil.isGlutenColumnarOp(parent) || PartialFallback.supportPartialFallback( + parent)) => + parent.mapChildren { + case p: ProjectExec if PlanUtil.isGlutenColumnarOp(p.child) => + tryAddColumnarPartialProjectExec(p) + case other => other + } + } + .children + .head } private def tryAddColumnarPartialProjectExec(plan: ProjectExec): SparkPlan = { diff --git a/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala index f5a1bd454b20..404477deba0c 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala @@ -118,6 +118,21 @@ abstract class UDFPartialProjectSuite extends WholeStageTransformerSuite { } } + testWithMinSparkVersion("test plus_one in nested project lists", "3.4") { + val sql = """ + |select plus_one(col1) as col2, l_partkey from ( + | select plus_one(l_orderkey) as col1, l_partkey from lineitem + |)""".stripMargin + runQueryAndCompare(sql) { + checkGlutenPlan[ColumnarPartialProjectExec] + } + + val df = spark.sql(sql) + assert(df.queryExecution.executedPlan.collect { + case p: ColumnarPartialProjectExec => p + }.size == 2) + } + test("test plus_one with many columns in project") { runQueryAndCompare("SELECT plus_one(cast(l_orderkey as long)), hash(l_partkey) from lineitem") { checkGlutenPlan[ColumnarPartialProjectExec] diff --git a/backends-velox/src/test/scala/org/apache/spark/sql/execution/GlutenHiveUDFSuite.scala b/backends-velox/src/test/scala/org/apache/spark/sql/execution/GlutenHiveUDFSuite.scala index 6919a303903a..296de381cd66 100644 --- a/backends-velox/src/test/scala/org/apache/spark/sql/execution/GlutenHiveUDFSuite.scala +++ b/backends-velox/src/test/scala/org/apache/spark/sql/execution/GlutenHiveUDFSuite.scala @@ -157,6 +157,24 @@ class GlutenHiveUDFSuite extends GlutenQueryComparisonTest with SQLTestUtils { } } + test("nested partial fallback") { + withTempFunction("noInputUDTF") { + val plusOne = udf((x: Long) => x + 1) + spark.udf.register("plus_one", plusOne) + sql(s"CREATE TEMPORARY FUNCTION noInputUDTF AS '${classOf[NoInputUDTF].getName}'") + runQueryAndCompare(""" + |select plus_one(col1) as col2, l_partkey from ( + | select col1, l_partkey from lineitem lateral view noInputUDTF() as col1 + |)""".stripMargin) { + df => + { + checkOperatorMatch[ColumnarPartialProjectExec](df) + checkOperatorMatch[ColumnarPartialGenerateExec](df) + } + } + } + } + test("lateral view outer udtf") { withTempFunction("conditionalOutputUDTF") { sql(