diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/metrics/HashJoinMetricsUpdater.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/metrics/HashJoinMetricsUpdater.scala index ca891bac27c6..d654125a32cd 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/metrics/HashJoinMetricsUpdater.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/metrics/HashJoinMetricsUpdater.scala @@ -31,24 +31,6 @@ class HashJoinMetricsUpdater(val metrics: Map[String, SQLMetric]) var currentIdx = operatorMetrics.metricsList.size() - 1 var totalTime = 0L - // build side pre projection - if (joinParams.buildPreProjectionNeeded) { - metrics("buildPreProjectionTime") += - (operatorMetrics.metricsList.get(currentIdx).time / 1000L).toLong - metrics("outputVectors") += operatorMetrics.metricsList.get(currentIdx).outputVectors - totalTime += operatorMetrics.metricsList.get(currentIdx).time - currentIdx -= 1 - } - - // stream side pre projection - if (joinParams.streamPreProjectionNeeded) { - metrics("streamPreProjectionTime") += - (operatorMetrics.metricsList.get(currentIdx).time / 1000L).toLong - metrics("outputVectors") += operatorMetrics.metricsList.get(currentIdx).outputVectors - totalTime += operatorMetrics.metricsList.get(currentIdx).time - currentIdx -= 1 - } - // update fillingRightJoinSideTime MetricsUtil .getAllProcessorList(operatorMetrics.metricsList.get(currentIdx)) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/metrics/SortMergeJoinMetricsUpdater.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/metrics/SortMergeJoinMetricsUpdater.scala index e5833a39bc58..d1a6e651d4d7 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/metrics/SortMergeJoinMetricsUpdater.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/metrics/SortMergeJoinMetricsUpdater.scala @@ -32,24 +32,6 @@ class SortMergeJoinMetricsUpdater(val metrics: Map[String, SQLMetric]) var currentIdx = 
operatorMetrics.metricsList.size() - 1 var totalTime = 0L - // build side pre projection - if (joinParams.buildPreProjectionNeeded) { - metrics("buildPreProjectionTime") += - (operatorMetrics.metricsList.get(currentIdx).time / 1000L).toLong - metrics("outputVectors") += operatorMetrics.metricsList.get(currentIdx).outputVectors - totalTime += operatorMetrics.metricsList.get(currentIdx).time - currentIdx -= 1 - } - - // stream side pre projection - if (joinParams.streamPreProjectionNeeded) { - metrics("streamPreProjectionTime") += - (operatorMetrics.metricsList.get(currentIdx).time / 1000L).toLong - metrics("outputVectors") += operatorMetrics.metricsList.get(currentIdx).outputVectors - totalTime += operatorMetrics.metricsList.get(currentIdx).time - currentIdx -= 1 - } - // update fillingRightJoinSideTime MetricsUtil .getAllProcessorList(operatorMetrics.metricsList.get(currentIdx)) diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHBucketSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHBucketSuite.scala index d7839ef774ef..ff5b7879f29d 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHBucketSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHBucketSuite.scala @@ -70,12 +70,12 @@ class GlutenClickHouseTPCHBucketSuite plans(3) .asInstanceOf[HashJoinLikeExecTransformer] .left - .isInstanceOf[InputIteratorTransformer]) + .isInstanceOf[ProjectExecTransformer]) assert( plans(3) .asInstanceOf[HashJoinLikeExecTransformer] .right - .isInstanceOf[InputIteratorTransformer]) + .isInstanceOf[ProjectExecTransformer]) // Check the bucket join assert( diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCDSMetricsSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCDSMetricsSuite.scala index 
d5b98e1cca5c..59dd838a860b 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCDSMetricsSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCDSMetricsSuite.scala @@ -90,7 +90,7 @@ class GlutenClickHouseTPCDSMetricsSuite extends GlutenClickHouseTPCDSAbstractSui case g: GlutenPlan if !g.isInstanceOf[InputIteratorTransformer] => g } - assert(allGlutenPlans.size == 30) + assert(allGlutenPlans.size == 34) val windowPlan0 = allGlutenPlans(3) assert(windowPlan0.metrics("totalTime").value == 2) diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala index f5a70731d560..15b581c3b468 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala @@ -286,7 +286,7 @@ class GlutenClickHouseTPCHMetricsSuite extends ParquetTPCHSuite { case g: GlutenPlan if !g.isInstanceOf[InputIteratorTransformer] => g } - assert(allGlutenPlans.size == 58) + assert(allGlutenPlans.size == 60) val shjPlan = allGlutenPlans(8) assert(shjPlan.metrics("totalTime").value == 6) diff --git a/backends-velox/src/main/scala/org/apache/gluten/metrics/JoinMetricsUpdater.scala b/backends-velox/src/main/scala/org/apache/gluten/metrics/JoinMetricsUpdater.scala index cf894b9da466..589cda77eee4 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/metrics/JoinMetricsUpdater.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/metrics/JoinMetricsUpdater.scala @@ -101,12 +101,6 @@ class HashJoinMetricsUpdater(override val metrics: Map[String, SQLMetric]) val bloomFilterBlocksByteSize: SQLMetric = metrics("bloomFilterBlocksByteSize") - val 
streamPreProjectionCpuCount: SQLMetric = metrics("streamPreProjectionCpuCount") - val streamPreProjectionWallNanos: SQLMetric = metrics("streamPreProjectionWallNanos") - - val buildPreProjectionCpuCount: SQLMetric = metrics("buildPreProjectionCpuCount") - val buildPreProjectionWallNanos: SQLMetric = metrics("buildPreProjectionWallNanos") - val loadLazyVectorTime: SQLMetric = metrics("loadLazyVectorTime") override protected def updateJoinMetricsInternal( @@ -148,17 +142,6 @@ class HashJoinMetricsUpdater(override val metrics: Map[String, SQLMetric]) hashBuildSpilledFiles += hashBuildMetrics.spilledFiles idx += 1 - if (joinParams.buildPreProjectionNeeded) { - buildPreProjectionCpuCount += joinMetrics.get(idx).cpuCount - buildPreProjectionWallNanos += joinMetrics.get(idx).wallNanos - idx += 1 - } - - if (joinParams.streamPreProjectionNeeded) { - streamPreProjectionCpuCount += joinMetrics.get(idx).cpuCount - streamPreProjectionWallNanos += joinMetrics.get(idx).wallNanos - idx += 1 - } if (TaskResources.inSparkTask()) { SparkMetricsUtil.incMemoryBytesSpilled( TaskResources.getLocalTaskContext().taskMetrics(), @@ -185,11 +168,6 @@ class SortMergeJoinMetricsUpdater(override val metrics: Map[String, SQLMetric]) val peakMemoryBytes: SQLMetric = metrics("peakMemoryBytes") val numMemoryAllocations: SQLMetric = metrics("numMemoryAllocations") - val streamPreProjectionCpuCount: SQLMetric = metrics("streamPreProjectionCpuCount") - val streamPreProjectionWallNanos: SQLMetric = metrics("streamPreProjectionWallNanos") - val bufferPreProjectionCpuCount: SQLMetric = metrics("bufferPreProjectionCpuCount") - val bufferPreProjectionWallNanos: SQLMetric = metrics("bufferPreProjectionWallNanos") - override protected def updateJoinMetricsInternal( joinMetrics: util.ArrayList[OperatorMetrics], joinParams: JoinParams): Unit = { @@ -200,17 +178,5 @@ class SortMergeJoinMetricsUpdater(override val metrics: Map[String, SQLMetric]) peakMemoryBytes += smjMetrics.peakMemoryBytes numMemoryAllocations 
+= smjMetrics.numMemoryAllocations idx += 1 - - if (joinParams.buildPreProjectionNeeded) { - bufferPreProjectionCpuCount += joinMetrics.get(idx).cpuCount - bufferPreProjectionWallNanos += joinMetrics.get(idx).wallNanos - idx += 1 - } - - if (joinParams.streamPreProjectionNeeded) { - streamPreProjectionCpuCount += joinMetrics.get(idx).cpuCount - streamPreProjectionWallNanos += joinMetrics.get(idx).wallNanos - idx += 1 - } } } diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxMetricsSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxMetricsSuite.scala index db26ce298530..be9daf1e4d78 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxMetricsSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxMetricsSuite.scala @@ -94,8 +94,6 @@ class VeloxMetricsSuite extends VeloxWholeStageTransformerSuite with AdaptiveSpa val metrics = smj.get.metrics assert(metrics("numOutputRows").value == 100) assert(metrics("numOutputVectors").value > 0) - assert(metrics("streamPreProjectionCpuCount").value > 0) - assert(metrics("bufferPreProjectionCpuCount").value > 0) } } } @@ -133,8 +131,6 @@ class VeloxMetricsSuite extends VeloxWholeStageTransformerSuite with AdaptiveSpa val metrics = smj.get.metrics assert(metrics("numOutputRows").value == 100) assert(metrics("numOutputVectors").value > 0) - assert(metrics("streamPreProjectionCpuCount").value > 0) - assert(metrics("buildPreProjectionCpuCount").value > 0) } } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/RewriteSparkPlanRulesManager.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/RewriteSparkPlanRulesManager.scala index 24fa7a6fc974..97ca8807055d 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/RewriteSparkPlanRulesManager.scala +++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/RewriteSparkPlanRulesManager.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{LeafExecNode, ProjectExec, SparkPlan} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec} case class RewrittenNodeWall(originalChild: SparkPlan) extends LeafExecNode { override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() @@ -50,7 +51,7 @@ class RewriteSparkPlanRulesManager private ( FallbackTags.maybeOffloadable(plan) && rewriteRules.exists(_.isRewritable(plan)) } - private def getFallbackTagBack(rewrittenPlan: SparkPlan): Option[FallbackTag] = { + private def getRewriteNodeBack(rewrittenPlan: SparkPlan): SparkPlan = { // The rewritten plan may contain more nodes than origin, for now it should only be // `ProjectExec`. // TODO: Find a better approach than checking `p.isInstanceOf[ProjectExec]` which is not @@ -59,7 +60,7 @@ class RewriteSparkPlanRulesManager private ( case p if !p.isInstanceOf[ProjectExec] && !p.isInstanceOf[RewrittenNodeWall] => p } assert(target.size == 1) - FallbackTags.getOption(target.head) + target.head } private def applyRewriteRules(origin: SparkPlan): (SparkPlan, Option[String]) = { @@ -99,10 +100,22 @@ class RewriteSparkPlanRulesManager private ( origin } else { validateRule.apply(rewrittenPlan) - val tag = getFallbackTagBack(rewrittenPlan) - if (tag.isDefined) { + val rewriteNode = getRewriteNodeBack(rewrittenPlan) + val allFallbackTags = rewrittenPlan.collect { + case p if !p.isInstanceOf[RewrittenNodeWall] => FallbackTags.getOption(p) + } + if (FallbackTags.getOption(rewriteNode).isDefined) { // If the rewritten plan is still not transformable, return the original plan. 
- FallbackTags.add(origin, tag.get) + FallbackTags.add(origin, FallbackTags.getOption(rewriteNode).get) + origin + } else if ( + (rewriteNode.isInstanceOf[BroadcastHashJoinExec] || + rewriteNode.isInstanceOf[BroadcastNestedLoopJoinExec]) && + allFallbackTags.exists(_.isDefined) + ) { + // If the inserted projects for join are not transformable, return the original plan. + val reason = allFallbackTags.collect { case Some(s) => s.reason() }.mkString(", ") + FallbackTags.add(origin, FallbackTag.Converter.FromString.from(reason).get) origin } else { rewrittenPlan.transformUp { diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala index 57e7373fc3cf..4b1af2481b85 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala @@ -25,6 +25,7 @@ import org.apache.gluten.substrait.SubstraitContext import org.apache.gluten.substrait.rel.{RelBuilder, RelNode} import org.apache.spark.SparkContextUtils +import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -189,6 +190,10 @@ abstract class ProjectExecTransformerBase(val list: Seq[NamedExpression], val in }() } + override def doExecuteBroadcast[T](): Broadcast[T] = { + child.executeBroadcast[T]() + } + override def metricsUpdater(): MetricsUpdater = BackendsApiManager.getMetricsApiInstance.genProjectTransformerMetricsUpdater(metrics) diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala index e5db3385154d..a6f220fd4879 100644 ---
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{ExpandOutputPartitioningShim, ExplainUtils, SparkPlan} -import org.apache.spark.sql.execution.joins.{BaseJoinExec, HashedRelationBroadcastMode, HashJoin} +import org.apache.spark.sql.execution.joins.{BaseJoinExec, HashedRelationBroadcastMode} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch @@ -135,18 +135,11 @@ trait HashJoinLikeExecTransformer extends BaseJoinExec with TransformSupport { .forall(types => sameType(types._1, types._2)), "Join keys from two sides should have same length and types" ) - // Spark has an improvement which would patch integer joins keys to a Long value. - // But this improvement would cause add extra project before hash join in velox, - // disabling this improvement as below would help reduce the project. 
- val (lkeys, rkeys) = if (BackendsApiManager.getSettings.enableJoinKeysRewrite()) { - (HashJoin.rewriteKeyExpr(leftKeys), HashJoin.rewriteKeyExpr(rightKeys)) - } else { - (leftKeys, rightKeys) - } + if (needSwitchChildren) { - (lkeys, rkeys) + (leftKeys, rightKeys) } else { - (rkeys, lkeys) + (rightKeys, leftKeys) } } @@ -234,13 +227,6 @@ trait HashJoinLikeExecTransformer extends BaseJoinExec with TransformSupport { val operatorId = context.nextOperatorId(this.nodeName) val joinParams = new JoinParams - if (JoinUtils.preProjectionNeeded(streamedKeyExprs)) { - joinParams.streamPreProjectionNeeded = true - } - if (JoinUtils.preProjectionNeeded(buildKeyExprs)) { - joinParams.buildPreProjectionNeeded = true - } - if (condition.isDefined) { joinParams.isWithCondition = true } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala index a7a31cf471c5..1a8266797519 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala @@ -16,7 +16,7 @@ */ package org.apache.gluten.execution -import org.apache.gluten.expression.{AttributeReferenceTransformer, ExpressionConverter} +import org.apache.gluten.expression.ExpressionConverter import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.gluten.substrait.SubstraitContext import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode} @@ -24,9 +24,8 @@ import org.apache.gluten.substrait.extensions.{AdvancedExtensionNode, ExtensionB import org.apache.gluten.substrait.rel.{RelBuilder, RelNode} import org.apache.gluten.utils.SubstraitUtil -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.catalyst.plans._ -import 
org.apache.spark.sql.types.DataType import com.google.protobuf.Any import io.substrait.proto.{CrossRel, JoinRel} @@ -44,83 +43,6 @@ object JoinUtils { } } - def preProjectionNeeded(keyExprs: Seq[Expression]): Boolean = { - !keyExprs.forall(_.isInstanceOf[AttributeReference]) - } - - private def createPreProjectionIfNeeded( - keyExprs: Seq[Expression], - inputNode: RelNode, - inputNodeOutput: Seq[Attribute], - partialConstructedJoinOutput: Seq[Attribute], - substraitContext: SubstraitContext, - operatorId: java.lang.Long, - validation: Boolean): (Seq[(ExpressionNode, DataType)], RelNode, Seq[Attribute]) = { - if (!preProjectionNeeded(keyExprs)) { - // Skip pre-projection if all keys are [AttributeReference]s, - // which can be directly converted into SelectionNode. - val keys = keyExprs.map { - expr => - ( - ExpressionConverter - .replaceWithExpressionTransformer(expr, partialConstructedJoinOutput) - .asInstanceOf[AttributeReferenceTransformer] - .doTransform(substraitContext), - expr.dataType) - } - (keys, inputNode, inputNodeOutput) - } else { - // Pre-projection is constructed from original columns followed by join-key expressions. - val selectOrigins = inputNodeOutput.indices.map(ExpressionBuilder.makeSelection(_)) - val appendedKeys = keyExprs.flatMap { - case _: AttributeReference => None - case expr => - Some( - ( - ExpressionConverter - .replaceWithExpressionTransformer(expr, inputNodeOutput) - .doTransform(substraitContext), - expr.dataType)) - } - val preProjectNode = RelBuilder.makeProjectRel( - inputNode, - new java.util.ArrayList[ExpressionNode]((selectOrigins ++ appendedKeys.map(_._1)).asJava), - createExtensionNode(inputNodeOutput, validation), - substraitContext, - operatorId, - inputNodeOutput.size - ) - - // Compute index for join keys in join outputs. 
- val offset = partialConstructedJoinOutput.size - val appendedKeysAndIndices = appendedKeys.zipWithIndex.iterator - val keys = keyExprs.map { - case a: AttributeReference => - // The selection index for original AttributeReference is unchanged. - ( - ExpressionConverter - .replaceWithExpressionTransformer(a, partialConstructedJoinOutput) - .asInstanceOf[AttributeReferenceTransformer] - .doTransform(substraitContext), - a.dataType) - case _ => - val (key, idx) = appendedKeysAndIndices.next() - (ExpressionBuilder.makeSelection(idx + offset), key._2) - } - ( - keys, - preProjectNode, - inputNodeOutput ++ - appendedKeys.zipWithIndex.map { - case (key, idx) => - // Create output attributes for appended keys. - // This is used as place holder for finding the right column indexes in post-join - // filters. - AttributeReference(s"col_${idx + offset}", key._2)() - }) - } - } - private def createJoinExtensionNode( joinParameters: Any, output: Seq[Attribute]): AdvancedExtensionNode = { @@ -178,32 +100,32 @@ object JoinUtils { exchangeTable: Boolean, joinType: JoinType, joinParameters: Any, - inputStreamedRelNode: RelNode, - inputBuildRelNode: RelNode, - inputStreamedOutput: Seq[Attribute], - inputBuildOutput: Seq[Attribute], + streamedRelNode: RelNode, + buildRelNode: RelNode, + streamedOutput: Seq[Attribute], + buildOutput: Seq[Attribute], substraitContext: SubstraitContext, operatorId: java.lang.Long, validation: Boolean = false): RelNode = { // scalastyle:on argcount // Create pre-projection for build/streamed plan. Append projected keys to each side. 
- val (streamedKeys, streamedRelNode, streamedOutput) = createPreProjectionIfNeeded( - streamedKeyExprs, - inputStreamedRelNode, - inputStreamedOutput, - inputStreamedOutput, - substraitContext, - operatorId, - validation) + val streamedKeys = streamedKeyExprs.map { + expr => + ( + ExpressionConverter + .replaceWithExpressionTransformer(expr, streamedOutput) + .doTransform(substraitContext), + expr.dataType) + } - val (buildKeys, buildRelNode, buildOutput) = createPreProjectionIfNeeded( - buildKeyExprs, - inputBuildRelNode, - inputBuildOutput, - streamedOutput ++ inputBuildOutput, - substraitContext, - operatorId, - validation) + val buildKeys = buildKeyExprs.map { + expr => + ( + ExpressionConverter + .replaceWithExpressionTransformer(expr, streamedOutput ++ buildOutput) + .doTransform(substraitContext), + expr.dataType) + } // Combine join keys to make a single expression. val joinExpressionNode = streamedKeys @@ -240,8 +162,8 @@ object JoinUtils { createProjectRelPostJoinRel( exchangeTable, joinType, - inputStreamedOutput, - inputBuildOutput, + streamedOutput, + buildOutput, substraitContext, operatorId, joinRel, diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/SortMergeJoinExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/SortMergeJoinExecTransformer.scala index bc6a1d3e7b5e..21e000344eed 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/SortMergeJoinExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/SortMergeJoinExecTransformer.scala @@ -198,13 +198,6 @@ abstract class SortMergeJoinExecTransformerBase( val operatorId = context.nextOperatorId(this.nodeName) val joinParams = new JoinParams - if (JoinUtils.preProjectionNeeded(leftKeys)) { - joinParams.streamPreProjectionNeeded = true - } - if (JoinUtils.preProjectionNeeded(rightKeys)) { - joinParams.buildPreProjectionNeeded = true - } - val joinRel = JoinUtils.createJoinRel( streamedKeys, 
bufferedKeys, diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala index af1fe35f7c18..39a156794032 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala @@ -18,7 +18,7 @@ package org.apache.gluten.extension.columnar.enumerated import org.apache.gluten.execution.{GlutenPlan, ValidatablePlan} import org.apache.gluten.extension.columnar.FallbackTags -import org.apache.gluten.extension.columnar.offload.OffloadSingleNode +import org.apache.gluten.extension.columnar.offload.{OffloadOthers, OffloadSingleNode} import org.apache.gluten.extension.columnar.rewrite.RewriteSingleNode import org.apache.gluten.extension.columnar.validator.Validator import org.apache.gluten.ras.path.Pattern @@ -27,7 +27,8 @@ import org.apache.gluten.ras.rule.{RasRule, Shape} import org.apache.gluten.ras.rule.Shapes.pattern import org.apache.spark.internal.Logging -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{ProjectExec, SparkPlan} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec} import scala.reflect.{classTag, ClassTag} @@ -129,7 +130,21 @@ object RasOffload { case t: ValidatablePlan => t } val outComes = offloadedNodes.map(_.doValidate()).filter(!_.ok()) - if (outComes.nonEmpty) { + // 4.1 Validate pre project of broadcast join + val notOffload = from match { + case _: BroadcastHashJoinExec | _: BroadcastNestedLoopJoinExec => + val projectOffload = RasOffload.from[ProjectExec](OffloadOthers()) + from + .collect { + case preProject: ProjectExec => projectOffload.offload(preProject) + } + .exists { + case t: ValidatablePlan => !t.doValidate().ok() + case plan if !plan.isInstanceOf[GlutenPlan] => 
true + } + case _ => false + } + if (outComes.nonEmpty || notOffload) { // 5. If native validation fails on at least one of the offloaded nodes, return // the original one. // diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala index 13c14bffc002..106ea58b480e 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala @@ -17,6 +17,7 @@ package org.apache.gluten.extension.columnar.rewrite import org.apache.gluten.backendsapi.BackendsApiManager +import org.apache.gluten.extension.columnar.heuristic.RewrittenNodeWall import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.gluten.utils.PullOutProjectHelper @@ -24,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Partial} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, TypedAggregateExpression} +import org.apache.spark.sql.execution.joins.{BaseJoinExec, HashJoin} import org.apache.spark.sql.execution.python.ArrowEvalPythonExec import org.apache.spark.sql.execution.window.WindowExec @@ -47,6 +49,7 @@ object PullOutPreProject extends RewriteSingleNode with PullOutProjectHelper { case _: ExpandExec => true case _: GenerateExec => true case _: ArrowEvalPythonExec => true + case _: BaseJoinExec => true case _ => false } } @@ -96,6 +99,14 @@ object PullOutPreProject extends RewriteSingleNode with PullOutProjectHelper { windowGroupLimitExecShim.orderSpec.exists(o => isNotAttribute(o.child)) || windowGroupLimitExecShim.partitionSpec.exists(isNotAttribute) case expand: ExpandExec => 
expand.projections.flatten.exists(isNotAttributeAndLiteral) + case join: BaseJoinExec => + join match { + case _: HashJoin if BackendsApiManager.getSettings.enableJoinKeysRewrite() => + HashJoin.rewriteKeyExpr(join.leftKeys).exists(isNotAttribute) || + HashJoin.rewriteKeyExpr(join.rightKeys).exists(isNotAttribute) + case _ => + join.leftKeys.exists(isNotAttribute) || join.rightKeys.exists(isNotAttribute) + } case _ => false } } @@ -282,6 +293,55 @@ object PullOutPreProject extends RewriteSingleNode with PullOutProjectHelper { BackendsApiManager.getSparkPlanExecApiInstance.genPreProjectForArrowEvalPythonExec( arrowEvalPythonExec) + case join: BaseJoinExec if needsPreProject(join) => + // Spark has an optimization that rewrites integer join keys into a Long value. + // However, it causes an extra project to be added before the hash join in Velox, + // so the rewrite is applied only when the backend enables join keys rewrite. + val (leftKeys, rightKeys) = join match { + case _: HashJoin if BackendsApiManager.getSettings.enableJoinKeysRewrite() => + (HashJoin.rewriteKeyExpr(join.leftKeys), HashJoin.rewriteKeyExpr(join.rightKeys)) + case _ => + (join.leftKeys, join.rightKeys) + } + + def pullOutPreProjectForJoin(joinChild: SparkPlan, joinKeys: Seq[Expression]) + : (SparkPlan, Seq[Expression], mutable.HashMap[Expression, NamedExpression]) = { + val expressionMap = new mutable.HashMap[Expression, NamedExpression]() + if (joinKeys.exists(isNotAttribute)) { + val newJoinKeys = + joinKeys.toIndexedSeq.map(replaceExpressionWithAttribute(_, expressionMap)) + val preProject = ProjectExec( + eliminateProjectList(joinChild.outputSet, expressionMap.values.toSeq), + joinChild) + joinChild match { + case r: RewrittenNodeWall => + r.originalChild.logicalLink.foreach(preProject.setLogicalLink) + case _ => + joinChild.logicalLink.foreach(preProject.setLogicalLink) + } + (preProject, newJoinKeys, expressionMap) + } else { + (joinChild, joinKeys, expressionMap) + } + } + + val (newLeft,
newLeftKeys, leftMap) = pullOutPreProjectForJoin(join.left, leftKeys) + val (newRight, newRightKeys, rightMap) = pullOutPreProjectForJoin(join.right, rightKeys) + val newCondition = if (leftMap.nonEmpty || rightMap.nonEmpty) { + join.condition.map(_.transform { + case p @ Equality(l, r) => + p.makeCopy( + Array(leftMap.getOrElse(l.canonicalized, l), rightMap.getOrElse(r.canonicalized, r))) + }) + } else { + join.condition + } + val newJoin = + copyBaseJoinExec(join)(newLeft, newRight, newLeftKeys, newRightKeys, newCondition) + val newProject = ProjectExec(join.output, newJoin) + newJoin.logicalLink.foreach(newProject.setLogicalLink) + newProject + case _ => plan } } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/substrait/SubstraitContext.scala b/gluten-substrait/src/main/scala/org/apache/gluten/substrait/SubstraitContext.scala index 1ceb2d4155ab..56582825dbb9 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/substrait/SubstraitContext.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/substrait/SubstraitContext.scala @@ -21,12 +21,6 @@ import java.security.InvalidParameterException import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList, Map => JMap} case class JoinParams() { - // Whether preProjection is needed in streamed side. - var streamPreProjectionNeeded = false - - // Whether preProjection is needed in build side. - var buildPreProjectionNeeded = false - // Whether postProjection is needed after Join. 
var postProjectionNeeded = true diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala b/gluten-substrait/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala index a4f8d2cbf4bb..b4e4a95c0770 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala @@ -21,7 +21,9 @@ import org.apache.gluten.exception.{GlutenException, GlutenNotSupportException} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete, Partial} +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.aggregate._ +import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec, ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.types.{ByteType, DateType, IntegerType, LongType, ShortType} @@ -140,6 +142,59 @@ trait PullOutProjectHelper { throw new GlutenNotSupportException(s"Unsupported agg $agg") } + protected def copyBaseJoinExec(join: BaseJoinExec)( + newLeft: SparkPlan = join.left, + newRight: SparkPlan = join.right, + newLeftKeys: Seq[Expression] = join.leftKeys, + newRightKeys: Seq[Expression] = join.rightKeys, + newCondition: Option[Expression] = join.condition): BaseJoinExec = join match { + case bhj: BroadcastHashJoinExec => + val newBhj = bhj.copy( + left = newLeft, + right = newRight, + leftKeys = newLeftKeys, + rightKeys = newRightKeys, + condition = newCondition) + newBhj.copyTagsFrom(bhj) + newBhj + case shj: ShuffledHashJoinExec => + val newShj = shj.copy( + left = newLeft, + right = newRight, + leftKeys = newLeftKeys, + rightKeys = newRightKeys, + condition = newCondition) + newShj.copyTagsFrom(shj) + newShj + case smj: SortMergeJoinExec 
=> + val newSmj = smj.copy( + left = newLeft, + right = newRight, + leftKeys = newLeftKeys, + rightKeys = newRightKeys, + condition = newCondition) + newSmj.copyTagsFrom(smj) + newSmj + case nestedLoopJoin: BroadcastNestedLoopJoinExec => + val newNestedLoopJoin = nestedLoopJoin.copy( + left = newLeft, + right = newRight, + condition = newCondition + ) + newNestedLoopJoin.copyTagsFrom(nestedLoopJoin) + newNestedLoopJoin + case cartesianProduct: CartesianProductExec => + val newCartesianProduct = cartesianProduct.copy( + left = newLeft, + right = newRight, + condition = newCondition + ) + newCartesianProduct.copyTagsFrom(cartesianProduct) + newCartesianProduct + case _ => + throw new UnsupportedOperationException(s"Unsupported join $join") + } + protected def rewriteAggregateExpression( ae: AggregateExpression, expressionMap: mutable.HashMap[Expression, NamedExpression]): AggregateExpression = {