diff --git a/backends-velox/pom.xml b/backends-velox/pom.xml
index cd7d795861d2..ddf49166339a 100644
--- a/backends-velox/pom.xml
+++ b/backends-velox/pom.xml
@@ -86,6 +86,10 @@
${project.version}
compile
+
+ com.github.ben-manes.caffeine
+ caffeine
+
org.scalacheck
scalacheck_${scala.binary.version}
diff --git a/backends-velox/src/main/java/org/apache/gluten/vectorized/HashJoinBuilder.java b/backends-velox/src/main/java/org/apache/gluten/vectorized/HashJoinBuilder.java
new file mode 100644
index 000000000000..e54909054cea
--- /dev/null
+++ b/backends-velox/src/main/java/org/apache/gluten/vectorized/HashJoinBuilder.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.vectorized;
+
+import org.apache.gluten.runtime.Runtime;
+import org.apache.gluten.runtime.RuntimeAware;
+
+public class HashJoinBuilder implements RuntimeAware {
+ private final Runtime runtime;
+
+ private HashJoinBuilder(Runtime runtime) {
+ this.runtime = runtime;
+ }
+
+ public static HashJoinBuilder create(Runtime runtime) {
+ return new HashJoinBuilder(runtime);
+ }
+
+ @Override
+ public long rtHandle() {
+ return runtime.getHandle();
+ }
+
+ public static native void clearHashTable(long hashTableData);
+
+ public static native long cloneHashTable(long hashTableData);
+
+ public static native long nativeBuild(
+ String buildHashTableId,
+ long[] batchHandlers,
+ String joinKeys,
+ int joinType,
+ boolean hasMixedFiltCondition,
+ boolean isExistenceJoin,
+ byte[] namedStruct,
+ boolean isNullAwareAntiJoin,
+ long bloomFilterPushdownSize,
+ int broadcastHashTableBuildThreads);
+}
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
index 24d08a57920d..2ab3af7ceaad 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
@@ -97,6 +97,11 @@ object VeloxBackendSettings extends BackendSettingsApi {
val GLUTEN_VELOX_INTERNAL_UDF_LIB_PATHS = VeloxBackend.CONF_PREFIX + ".internal.udfLibraryPaths"
val GLUTEN_VELOX_UDF_ALLOW_TYPE_CONVERSION = VeloxBackend.CONF_PREFIX + ".udfAllowTypeConversion"
+ val GLUTEN_VELOX_BROADCAST_CACHE_EXPIRED_TIME: String =
+ VeloxBackend.CONF_PREFIX + ".broadcast.cache.expired.time"
+ // unit: SECONDS, default 1 day
+ val GLUTEN_VELOX_BROADCAST_CACHE_EXPIRED_TIME_DEFAULT: Int = 86400
+
override def primaryBatchType: Convention.BatchType = VeloxBatchType
override def validateScanExec(
@@ -495,13 +500,17 @@ object VeloxBackendSettings extends BackendSettingsApi {
allSupported
}
+ override def enableJoinKeysRewrite(): Boolean = false
+
override def supportColumnarShuffleExec(): Boolean = {
val conf = GlutenConfig.get
conf.enableColumnarShuffle &&
(conf.isUseGlutenShuffleManager || conf.shuffleManagerSupportsColumnarShuffle)
}
- override def enableJoinKeysRewrite(): Boolean = false
+ override def enableHashTableBuildOncePerExecutor(): Boolean = {
+ VeloxConfig.get.enableBroadcastBuildOncePerExecutor
+ }
override def supportHashBuildJoinTypeOnLeft: JoinType => Boolean = {
t =>
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
index 585f6d736db0..8722ae8616b8 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
@@ -20,6 +20,7 @@ import org.apache.gluten.backendsapi.ListenerApi
import org.apache.gluten.backendsapi.arrow.ArrowBatchTypes.{ArrowJavaBatchType, ArrowNativeBatchType}
import org.apache.gluten.config.{GlutenConfig, GlutenCoreConfig, VeloxConfig}
import org.apache.gluten.config.VeloxConfig._
+import org.apache.gluten.execution.VeloxBroadcastBuildSideCache
import org.apache.gluten.execution.datasource.GlutenFormatFactory
import org.apache.gluten.expression.UDFMappings
import org.apache.gluten.extension.columnar.transition.Convention
@@ -35,8 +36,10 @@ import org.apache.gluten.utils._
import org.apache.spark.{HdfsConfGenerator, ShuffleDependency, SparkConf, SparkContext}
import org.apache.spark.api.plugin.PluginContext
import org.apache.spark.internal.Logging
+import org.apache.spark.listener.VeloxGlutenSQLAppStatusListener
import org.apache.spark.memory.GlobalOffHeapMemory
import org.apache.spark.network.util.ByteUnit
+import org.apache.spark.rpc.{GlutenDriverEndpoint, GlutenExecutorEndpoint}
import org.apache.spark.shuffle.{ColumnarShuffleDependency, LookupKey, ShuffleManagerRegistry}
import org.apache.spark.shuffle.sort.ColumnarShuffleManager
import org.apache.spark.sql.execution.ColumnarCachedBatchSerializer
@@ -56,6 +59,8 @@ class VeloxListenerApi extends ListenerApi with Logging {
import VeloxListenerApi._
override def onDriverStart(sc: SparkContext, pc: PluginContext): Unit = {
+ GlutenDriverEndpoint.glutenDriverEndpointRef = (new GlutenDriverEndpoint).self
+ VeloxGlutenSQLAppStatusListener.registerListener(sc)
val conf = pc.conf()
// When the Velox cache is enabled, the Velox file handle cache should also be enabled.
@@ -138,6 +143,8 @@ class VeloxListenerApi extends ListenerApi with Logging {
override def onDriverShutdown(): Unit = shutdown()
override def onExecutorStart(pc: PluginContext): Unit = {
+ GlutenExecutorEndpoint.executorEndpoint = new GlutenExecutorEndpoint(pc.executorID, pc.conf)
+
val conf = pc.conf()
// Static initializers for executor.
@@ -250,6 +257,11 @@ class VeloxListenerApi extends ListenerApi with Logging {
private def shutdown(): Unit = {
// TODO shutdown implementation in velox to release resources
+ VeloxBroadcastBuildSideCache.cleanAll()
+ val executorEndpoint = GlutenExecutorEndpoint.executorEndpoint
+ if (executorEndpoint != null) {
+ executorEndpoint.stop()
+ }
}
}
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
index 773868b0c450..1d805362903a 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
@@ -66,6 +66,11 @@ object VeloxRuleApi {
injector.injectOptimizerRule(CollapseGetJsonObjectExpressionRule.apply)
injector.injectOptimizerRule(RewriteCastFromArray.apply)
injector.injectOptimizerRule(RewriteUnboundedWindow.apply)
+
+ if (!BackendsApiManager.getSettings.enableJoinKeysRewrite()) {
+ injector.injectPlannerStrategy(_ => org.apache.gluten.extension.GlutenJoinKeysCapture())
+ }
+
if (BackendsApiManager.getSettings.supportAppendDataExec()) {
injector.injectPlannerStrategy(SparkShimLoader.getSparkShims.getRewriteCreateTableAsSelect(_))
}
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 69419deb1a2a..338bef20dfe5 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -22,6 +22,7 @@ import org.apache.gluten.exception.{GlutenExceptionUtil, GlutenNotSupportExcepti
import org.apache.gluten.execution._
import org.apache.gluten.expression._
import org.apache.gluten.expression.aggregate.{HLLAdapter, VeloxBloomFilterAggregate, VeloxCollectList, VeloxCollectSet}
+import org.apache.gluten.extension.JoinKeysTag
import org.apache.gluten.extension.columnar.FallbackTags
import org.apache.gluten.shuffle.NeedCustomColumnarBatchSerializer
import org.apache.gluten.sql.shims.SparkShimLoader
@@ -29,6 +30,7 @@ import org.apache.gluten.vectorized.{ColumnarBatchSerializer, ColumnarBatchSeria
import org.apache.spark.{ShuffleDependency, SparkEnv, SparkException}
import org.apache.spark.api.python.{ColumnarArrowEvalPythonExec, PullOutArrowEvalPythonPreProjectHelper}
+import org.apache.spark.internal.Logging
import org.apache.spark.memory.SparkMemoryUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
@@ -43,9 +45,10 @@ import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec
import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
-import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelationBroadcastMode}
+import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelationBroadcastMode, SparkHashJoinUtils}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
import org.apache.spark.sql.execution.unsafe.UnsafeColumnarBuildSideRelation
@@ -64,8 +67,9 @@ import javax.ws.rs.core.UriBuilder
import java.util.Locale
import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
-class VeloxSparkPlanExecApi extends SparkPlanExecApi {
+class VeloxSparkPlanExecApi extends SparkPlanExecApi with Logging {
/** Transform GetArrayItem to Substrait. */
override def genGetArrayItemTransformer(
@@ -678,9 +682,136 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
child: SparkPlan,
numOutputRows: SQLMetric,
dataSize: SQLMetric): BuildSideRelation = {
+
+ val buildKeys = mode match {
+ case mode1: HashedRelationBroadcastMode =>
+ mode1.key
+ case _ =>
+ // IdentityBroadcastMode
+ Seq.empty
+ }
+ var offload = true
+ val (newChild, newOutput, newBuildKeys) =
+ if (VeloxConfig.get.enableBroadcastBuildOncePerExecutor) {
+
+ // Try to lookup from TreeNodeTag using child's logical plan
+ // Need to recursively find logicalLink in case of AQE or other wrappers
+ @scala.annotation.tailrec
+ def findLogicalLink(
+ plan: SparkPlan): Option[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan] = {
+ plan.logicalLink match {
+ case some @ Some(_) => some
+ case None =>
+ plan.children match {
+ case Seq(child) => findLogicalLink(child)
+ case _ => None
+ }
+ }
+ }
+
+ val newBuildKeys = findLogicalLink(child)
+ .flatMap(_.getTagValue(JoinKeysTag.ORIGINAL_JOIN_KEYS))
+ .getOrElse {
+ if (SparkHashJoinUtils.canRewriteAsLongType(buildKeys) && buildKeys.nonEmpty) {
+ SparkHashJoinUtils.getOriginalKeysFromPacked(buildKeys.head)
+ } else {
+ buildKeys
+ }
+ }
+
+ val noNeedPreOp = newBuildKeys.forall {
+ case _: AttributeReference | _: BoundReference => true
+ case _ => false
+ }
+
+ if (noNeedPreOp) {
+ (child, child.output, Seq.empty[Expression])
+ } else {
+ // pre projection in case of expression join keys
+ val appendedProjections = new ArrayBuffer[NamedExpression]()
+ val preProjectionBuildKeys = newBuildKeys.zipWithIndex.map {
+ case (e, idx) =>
+ e match {
+ case b: BoundReference => child.output(b.ordinal)
+ case a: AttributeReference => a
+ case o: Expression =>
+ val newExpr = Alias(o, "col_" + idx)()
+ appendedProjections += newExpr
+ newExpr
+ }
+ }
+
+ def wrapChild(child: SparkPlan): SparkPlan = {
+ val childWithAdapter =
+ ColumnarCollapseTransformStages.wrapInputIteratorTransformer(child)
+ val projectExecTransformer =
+ ProjectExecTransformer(child.output ++ appendedProjections, childWithAdapter)
+ val validationResult = projectExecTransformer.doValidate()
+ if (validationResult.ok()) {
+ WholeStageTransformer(
+ ProjectExecTransformer(child.output ++ appendedProjections, childWithAdapter))(
+ ColumnarCollapseTransformStages
+ .getTransformStageCounter(childWithAdapter)
+ .incrementAndGet()
+ )
+ } else {
+ offload = false
+ child
+ }
+ }
+
+ val newChild = child match {
+ case wt: WholeStageTransformer =>
+ val projectTransformer =
+ ProjectExecTransformer(child.output ++ appendedProjections, wt.child)
+ if (projectTransformer.doValidate().ok()) {
+ wt.withNewChildren(
+ Seq(ProjectExecTransformer(child.output ++ appendedProjections, wt.child)))
+
+ } else {
+ offload = false
+ child
+ }
+ case w: WholeStageCodegenExec =>
+ w.withNewChildren(Seq(ProjectExec(child.output ++ appendedProjections, w.child)))
+ case r: AQEShuffleReadExec if r.supportsColumnar =>
+ // when aqe is open
+ // TODO: remove this after pushdowning preprojection
+ wrapChild(r)
+ case r2c: RowToVeloxColumnarExec =>
+ wrapChild(r2c)
+ case union: ColumnarUnionExec =>
+ wrapChild(union)
+ case ordered: TakeOrderedAndProjectExecTransformer =>
+ wrapChild(ordered)
+ case a2v: ArrowColumnarToVeloxColumnarExec =>
+ wrapChild(a2v)
+ case other =>
+ offload = false
+ logWarning(
+ "Not supported operator " + other.nodeName +
+ " for BroadcastRelation and fallback to shuffle hash join")
+ child
+ }
+
+ if (offload) {
+ (
+ newChild,
+ (child.output ++ appendedProjections).map(_.toAttribute),
+ preProjectionBuildKeys)
+ } else {
+ (child, child.output, Seq.empty[Expression])
+ }
+ }
+ } else {
+ offload = false
+ (child, child.output, buildKeys)
+ }
+
val useOffheapBroadcastBuildRelation =
VeloxConfig.get.enableBroadcastBuildRelationInOffheap
- val serialized: Seq[ColumnarBatchSerializeResult] = child
+
+ val serialized: Seq[ColumnarBatchSerializeResult] = newChild
.executeColumnar()
.mapPartitions(itr => Iterator(BroadcastUtils.serializeStream(itr)))
.filter(_.numRows != 0)
@@ -694,18 +825,23 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
}
numOutputRows += serialized.map(_.numRows).sum
dataSize += rawSize
+
if (useOffheapBroadcastBuildRelation) {
TaskResources.runUnsafe {
UnsafeColumnarBuildSideRelation(
- child.output,
+ newOutput,
serialized.flatMap(_.offHeapData().asScala),
- mode)
+ mode,
+ newBuildKeys,
+ offload)
}
} else {
ColumnarBuildSideRelation(
- child.output,
+ newOutput,
serialized.flatMap(_.onHeapData().asScala).toArray,
- mode)
+ mode,
+ newBuildKeys,
+ offload)
}
}
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala
index a40e9ca6e4ea..3a1d53154fe2 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala
@@ -30,6 +30,7 @@ import org.apache.gluten.vectorized.PlanEvaluatorJniWrapper
import org.apache.spark.Partition
import org.apache.spark.internal.Logging
+import org.apache.spark.rpc.GlutenDriverEndpoint
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, PartitionDirectory}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
@@ -120,6 +121,10 @@ class VeloxTransformerApi extends TransformerApi with Logging {
override def packPBMessage(message: Message): Any = Any.pack(message, "")
+ override def invalidateSQLExecutionResource(executionId: String): Unit = {
+ GlutenDriverEndpoint.invalidateResourceRelation(executionId)
+ }
+
override def genWriteParameters(write: WriteFilesExecTransformer): Any = {
write.fileFormat match {
case _ @(_: ParquetFileFormat | _: HiveFileFormat) =>
diff --git a/backends-velox/src/main/scala/org/apache/gluten/config/VeloxConfig.scala b/backends-velox/src/main/scala/org/apache/gluten/config/VeloxConfig.scala
index ee0866391ce0..071d75d6cf00 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/config/VeloxConfig.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/config/VeloxConfig.scala
@@ -61,6 +61,12 @@ class VeloxConfig(conf: SQLConf) extends GlutenConfig(conf) {
def enableBroadcastBuildRelationInOffheap: Boolean =
getConf(VELOX_BROADCAST_BUILD_RELATION_USE_OFFHEAP)
+ def enableBroadcastBuildOncePerExecutor: Boolean =
+ getConf(VELOX_BROADCAST_BUILD_HASHTABLE_ONCE_PER_EXECUTOR)
+
+ def veloxBroadcastHashTableBuildThreads: Int =
+ getConf(COLUMNAR_VELOX_BROADCAST_HASH_TABLE_BUILD_THREADS)
+
def veloxOrcScanEnabled: Boolean =
getConf(VELOX_ORC_SCAN_ENABLED)
@@ -195,6 +201,13 @@ object VeloxConfig extends ConfigRegistry {
.intConf
.createOptional
+ val COLUMNAR_VELOX_BROADCAST_HASH_TABLE_BUILD_THREADS =
+ buildStaticConf("spark.gluten.sql.columnar.backend.velox.broadcastHashTableBuildThreads")
+ .doc("The number of threads used to build the broadcast hash table. " +
+ "If set to 0, the number of available processors is used; defaults to 1.")
+ .intConf
+ .createWithDefault(1)
+
val COLUMNAR_VELOX_ASYNC_TIMEOUT =
buildStaticConf("spark.gluten.sql.columnar.backend.velox.asyncTimeoutOnTaskStopping")
.doc(
@@ -586,6 +599,16 @@ object VeloxConfig extends ConfigRegistry {
.intConf
.createWithDefault(0)
+ val VELOX_BROADCAST_BUILD_HASHTABLE_ONCE_PER_EXECUTOR =
+ buildConf("spark.gluten.velox.buildHashTableOncePerExecutor.enabled")
+ .internal()
+ .doc(
+ "When enabled, the hash table is " +
+ "constructed once per executor. If not enabled, " +
+ "the hash table is rebuilt for each task.")
+ .booleanConf
+ .createWithDefault(true)
+
val QUERY_TRACE_ENABLED = buildConf("spark.gluten.sql.columnar.backend.velox.queryTraceEnabled")
.doc("Enable query tracing flag.")
.booleanConf
diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/HashJoinExecTransformer.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/HashJoinExecTransformer.scala
index e3c93848dc2b..d79a3cae042d 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/execution/HashJoinExecTransformer.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/execution/HashJoinExecTransformer.scala
@@ -16,11 +16,14 @@
*/
package org.apache.gluten.execution
+import org.apache.gluten.config.VeloxConfig
+
import org.apache.spark.rdd.RDD
+import org.apache.spark.rpc.GlutenDriverEndpoint
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.optimizer.BuildSide
+import org.apache.spark.sql.catalyst.optimizer.{BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -99,6 +102,9 @@ case class BroadcastHashJoinExecTransformer(
right,
isNullAwareAntiJoin) {
+ // Unique ID for built table
+ lazy val buildBroadcastTableId: String = buildPlan.id.toString
+
override protected lazy val substraitJoinType: JoinRel.JoinType = joinType match {
case _: InnerLike =>
JoinRel.JoinType.JOIN_TYPE_INNER
@@ -125,9 +131,52 @@ case class BroadcastHashJoinExecTransformer(
override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = {
val streamedRDD = getColumnarInputRDDs(streamedPlan)
+ val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+ if (executionId != null) {
+ logWarning(
+ s"Trace broadcast table data $buildBroadcastTableId" + " " +
+ "and the execution id is " + executionId)
+ GlutenDriverEndpoint.collectResources(executionId, buildBroadcastTableId)
+ } else {
+ logWarning(
+ s"Can not trace broadcast table data $buildBroadcastTableId" +
+ s" because execution id is null." +
+ s" Will clean up until expire time.")
+ }
+
val broadcast = buildPlan.executeBroadcast[BuildSideRelation]()
- val broadcastRDD = VeloxBroadcastBuildSideRDD(sparkContext, broadcast)
+ val bloomFilterPushdownSize = if (VeloxConfig.get.hashProbeDynamicFilterPushdownEnabled) {
+ VeloxConfig.get.hashProbeBloomFilterPushdownMaxSize
+ } else {
+ -1
+ }
+ val context =
+ BroadcastHashJoinContext(
+ buildKeyExprs,
+ substraitJoinType,
+ buildSide == BuildRight,
+ condition.isDefined,
+ joinType.isInstanceOf[ExistenceJoin],
+ buildPlan.output,
+ buildBroadcastTableId,
+ isNullAwareAntiJoin,
+ bloomFilterPushdownSize,
+ VeloxConfig.get.veloxBroadcastHashTableBuildThreads
+ )
+ val broadcastRDD = VeloxBroadcastBuildSideRDD(sparkContext, broadcast, context)
// FIXME: Do we have to make build side a RDD?
streamedRDD :+ broadcastRDD
}
}
+
+case class BroadcastHashJoinContext(
+ buildSideJoinKeys: Seq[Expression],
+ substraitJoinType: JoinRel.JoinType,
+ buildRight: Boolean,
+ hasMixedFiltCondition: Boolean,
+ isExistenceJoin: Boolean,
+ buildSideStructure: Seq[Attribute],
+ buildHashTableId: String,
+ isNullAwareAntiJoin: Boolean = false,
+ bloomFilterPushdownSize: Long,
+ broadcastHashTableBuildThreads: Int)
diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastBuildSideCache.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastBuildSideCache.scala
new file mode 100644
index 000000000000..2705f3b34cbf
--- /dev/null
+++ b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastBuildSideCache.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.execution
+
+import org.apache.gluten.backendsapi.velox.VeloxBackendSettings
+import org.apache.gluten.vectorized.HashJoinBuilder
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.execution.ColumnarBuildSideRelation
+import org.apache.spark.sql.execution.joins.BuildSideRelation
+import org.apache.spark.sql.execution.unsafe.UnsafeColumnarBuildSideRelation
+
+import com.github.benmanes.caffeine.cache.{Cache, Caffeine, RemovalCause, RemovalListener}
+
+import java.util.concurrent.TimeUnit
+
+case class BroadcastHashTable(pointer: Long, relation: BuildSideRelation)
+
+/**
+ * `VeloxBroadcastBuildSideCache` ensures the BHJ hash table is built only once per executor.
+ *
+ * The complicated part is due to reuse exchange, where multiple BHJ IDs correspond to a
+ * `BuildSideRelation`.
+ */
+object VeloxBroadcastBuildSideCache
+ extends Logging
+ with RemovalListener[String, BroadcastHashTable] {
+
+ private lazy val expiredTime = SparkEnv.get.conf.getLong(
+ VeloxBackendSettings.GLUTEN_VELOX_BROADCAST_CACHE_EXPIRED_TIME,
+ VeloxBackendSettings.GLUTEN_VELOX_BROADCAST_CACHE_EXPIRED_TIME_DEFAULT
+ )
+
+ // Ensures the BHJ hash table is built only once per executor.
+ // Key: hash table id; value: the native hash table pointer plus its build-side relation.
+ private val buildSideRelationCache: Cache[String, BroadcastHashTable] =
+ Caffeine.newBuilder
+ .expireAfterAccess(expiredTime, TimeUnit.SECONDS)
+ .removalListener(this)
+ .build[String, BroadcastHashTable]()
+
+ def getOrBuildBroadcastHashTable(
+ broadcast: Broadcast[BuildSideRelation],
+ broadcastContext: BroadcastHashJoinContext): BroadcastHashTable = synchronized {
+
+ buildSideRelationCache
+ .get(
+ broadcastContext.buildHashTableId,
+ (broadcast_id: String) => {
+ val (pointer, relation) = broadcast.value match {
+ case columnar: ColumnarBuildSideRelation =>
+ columnar.buildHashTable(broadcastContext)
+ case unsafe: UnsafeColumnarBuildSideRelation =>
+ unsafe.buildHashTable(broadcastContext)
+ }
+
+ logWarning(s"Create bhj $broadcast_id = $pointer")
+ BroadcastHashTable(pointer, relation)
+ }
+ )
+ }
+
+ /** This is callback from c++ backend. */
+ def get(broadcastHashtableId: String): Long =
+ synchronized {
+ Option(buildSideRelationCache.getIfPresent(broadcastHashtableId))
+ .map(_.pointer)
+ .getOrElse(0)
+ }
+
+ def invalidateBroadcastHashtable(broadcastHashtableId: String): Unit = synchronized {
+ // Cleanup operations on the backend are idempotent.
+ buildSideRelationCache.invalidate(broadcastHashtableId)
+ }
+
+ /** Only used in UT. */
+ def size(): Long = buildSideRelationCache.estimatedSize()
+
+ def cleanAll(): Unit = buildSideRelationCache.invalidateAll()
+
+ override def onRemoval(key: String, value: BroadcastHashTable, cause: RemovalCause): Unit = {
+ synchronized {
+ logWarning(s"Remove bhj $key = ${value.pointer}")
+ if (value.relation != null) {
+ value.relation match {
+ case columnar: ColumnarBuildSideRelation =>
+ columnar.reset()
+ case unsafe: UnsafeColumnarBuildSideRelation =>
+ unsafe.reset()
+ }
+ }
+
+ HashJoinBuilder.clearHashTable(value.pointer)
+ }
+ }
+}
diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastBuildSideRDD.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastBuildSideRDD.scala
index 0163178e59f4..2d4b15705654 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastBuildSideRDD.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastBuildSideRDD.scala
@@ -19,19 +19,36 @@ package org.apache.gluten.execution
import org.apache.gluten.iterator.Iterators
import org.apache.spark.{broadcast, SparkContext}
+import org.apache.spark.sql.execution.ColumnarBuildSideRelation
import org.apache.spark.sql.execution.joins.BuildSideRelation
+import org.apache.spark.sql.execution.unsafe.UnsafeColumnarBuildSideRelation
import org.apache.spark.sql.vectorized.ColumnarBatch
case class VeloxBroadcastBuildSideRDD(
@transient private val sc: SparkContext,
- broadcasted: broadcast.Broadcast[BuildSideRelation])
+ broadcasted: broadcast.Broadcast[BuildSideRelation],
+ broadcastContext: BroadcastHashJoinContext,
+ isBNL: Boolean = false)
extends BroadcastBuildSideRDD(sc, broadcasted) {
override def genBroadcastBuildSideIterator(): Iterator[ColumnarBatch] = {
- val relation = broadcasted.value.asReadOnlyCopy()
- Iterators
- .wrap(relation.deserialized)
- .recyclePayload(batch => batch.close())
- .create()
+ val offload = broadcasted.value.asReadOnlyCopy() match {
+ case columnar: ColumnarBuildSideRelation =>
+ columnar.offload
+ case unsafe: UnsafeColumnarBuildSideRelation =>
+ unsafe.isOffload
+ }
+ val output = if (isBNL || !offload) {
+ val relation = broadcasted.value.asReadOnlyCopy()
+ Iterators
+ .wrap(relation.deserialized)
+ .recyclePayload(batch => batch.close())
+ .create()
+ } else {
+ VeloxBroadcastBuildSideCache.getOrBuildBroadcastHashTable(broadcasted, broadcastContext)
+ Iterator.empty
+ }
+
+ output
}
}
diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastNestedLoopJoinExecTransformer.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastNestedLoopJoinExecTransformer.scala
index 2a920c3ab931..6e0aaa27c6d2 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastNestedLoopJoinExecTransformer.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastNestedLoopJoinExecTransformer.scala
@@ -45,7 +45,7 @@ case class VeloxBroadcastNestedLoopJoinExecTransformer(
override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = {
val streamedRDD = getColumnarInputRDDs(streamedPlan)
val broadcast = buildPlan.executeBroadcast[BuildSideRelation]()
- val broadcastRDD = VeloxBroadcastBuildSideRDD(sparkContext, broadcast)
+ val broadcastRDD = VeloxBroadcastBuildSideRDD(sparkContext, broadcast, null, true)
// FIXME: Do we have to make build side a RDD?
streamedRDD :+ broadcastRDD
}
diff --git a/backends-velox/src/main/scala/org/apache/spark/listener/VeloxGlutenSQLAppStatusListener.scala b/backends-velox/src/main/scala/org/apache/spark/listener/VeloxGlutenSQLAppStatusListener.scala
new file mode 100644
index 000000000000..7e4ecc9a842c
--- /dev/null
+++ b/backends-velox/src/main/scala/org/apache/spark/listener/VeloxGlutenSQLAppStatusListener.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.listener
+
+import org.apache.spark.SparkContext
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc.{GlutenDriverEndpoint, RpcEndpointRef}
+import org.apache.spark.rpc.GlutenRpcMessages._
+import org.apache.spark.scheduler._
+import org.apache.spark.sql.execution.ui._
+
+/** Gluten SQL listener. Monitors SQL over its whole life cycle to create and release resources. */
+class VeloxGlutenSQLAppStatusListener(val driverEndpointRef: RpcEndpointRef)
+ extends SparkListener
+ with Logging {
+
+ /**
+ * If an executor was removed, the driver endpoint needs to remove that executor's endpoint
+ * ref. Once the executor is removed, its endpoint ref must not be called again.
+ * @param executorRemoved
+ * executor removed event
+ */
+ override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = {
+ driverEndpointRef.send(GlutenExecutorRemoved(executorRemoved.executorId))
+ logTrace(s"Execution ${executorRemoved.executorId} Removed.")
+ }
+
+ override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
+ case e: SparkListenerSQLExecutionStart => onExecutionStart(e)
+ case e: SparkListenerSQLExecutionEnd => onExecutionEnd(e)
+ case _ => // Ignore
+ }
+
+ /**
+ * When an execution starts, notify the Gluten executors so they can prepare for the execution.
+ *
+ * @param event
+ * execution start event
+ */
+ private def onExecutionStart(event: SparkListenerSQLExecutionStart): Unit = {
+ val executionId = event.executionId.toString
+ driverEndpointRef.send(GlutenOnExecutionStart(executionId))
+ logTrace(s"Execution $executionId start.")
+ }
+
+ /**
+ * When an execution ends, some backends (e.g. CH) need to clean up resources related to this
+ * execution.
+ * @param event
+ * execution end event
+ */
+ private def onExecutionEnd(event: SparkListenerSQLExecutionEnd): Unit = {
+ // val stackTraceElements = Thread.currentThread().getStackTrace()
+
+ // for (element <- stackTraceElements) {
+ // logWarning(element.toString);
+ // }
+ val executionId = event.executionId.toString
+ driverEndpointRef.send(GlutenOnExecutionEnd(executionId))
+ logTrace(s"Execution $executionId end.")
+ }
+}
+object VeloxGlutenSQLAppStatusListener {
+ def registerListener(sc: SparkContext): Unit = {
+ sc.listenerBus.addToStatusQueue(
+ new VeloxGlutenSQLAppStatusListener(GlutenDriverEndpoint.glutenDriverEndpointRef))
+ }
+}
diff --git a/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenDriverEndpoint.scala b/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenDriverEndpoint.scala
new file mode 100644
index 000000000000..af635addf3b3
--- /dev/null
+++ b/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenDriverEndpoint.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.rpc
+
+import org.apache.gluten.config.GlutenConfig
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc.GlutenRpcMessages._
+
+import com.github.benmanes.caffeine.cache.{Cache, Caffeine, RemovalCause, RemovalListener}
+
+import java.util
+import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
+import java.util.concurrent.atomic.AtomicInteger
+
+/**
+ * The gluten driver endpoint is responsible for communicating with the executor. Executor will
+ * register with the driver when it starts.
+ */
+class GlutenDriverEndpoint extends IsolatedRpcEndpoint with Logging {
+ override val rpcEnv: RpcEnv = SparkEnv.get.rpcEnv
+
+ protected val totalRegisteredExecutors = new AtomicInteger(0)
+
+ private val driverEndpoint: RpcEndpointRef =
+ rpcEnv.setupEndpoint(GlutenRpcConstants.GLUTEN_DRIVER_ENDPOINT_NAME, this)
+
+ // TODO(yuan): get thread cnt from spark context
+ override def threadCount(): Int = 1
+ override def receive: PartialFunction[Any, Unit] = {
+ case GlutenOnExecutionStart(executionId) =>
+ if (executionId == null) {
+ logWarning(s"Execution Id is null. Resources maybe not clean after execution end.")
+ }
+
+ case GlutenOnExecutionEnd(executionId) =>
+ logWarning(s"Execution Id is $executionId end.")
+
+ GlutenDriverEndpoint.executionResourceRelation.invalidate(executionId)
+
+ case GlutenExecutorRemoved(executorId) =>
+ GlutenDriverEndpoint.executorDataMap.remove(executorId)
+ totalRegisteredExecutors.addAndGet(-1)
+ logTrace(s"Executor endpoint ref $executorId is removed.")
+
+ case e =>
+ logError(s"Received unexpected message. $e")
+ }
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+
+ case GlutenRegisterExecutor(executorId, executorRef) =>
+ if (GlutenDriverEndpoint.executorDataMap.contains(executorId)) {
+ context.sendFailure(new IllegalStateException(s"Duplicate executor ID: $executorId"))
+ } else {
+ // If the executor's rpc env is not listening for incoming connections, `hostPort`
+ // will be null, and the client connection should be used to contact the executor.
+ val executorAddress = if (executorRef.address != null) {
+ executorRef.address
+ } else {
+ context.senderAddress
+ }
+ logInfo(s"Registered executor $executorRef ($executorAddress) with ID $executorId")
+
+ totalRegisteredExecutors.addAndGet(1)
+ val data = new ExecutorData(executorRef)
+ // This must be synchronized because variables mutated
+ // in this block are read when requesting executors
+ GlutenDriverEndpoint.this.synchronized {
+ GlutenDriverEndpoint.executorDataMap.put(executorId, data)
+ }
+ logTrace(s"Executor size ${GlutenDriverEndpoint.executorDataMap.size()}")
+ // Note: some tests expect the reply to come after we put the executor in the map
+ context.reply(true)
+ }
+
+ }
+
+ override def onStart(): Unit = {
+ logInfo(s"Initialized GlutenDriverEndpoint, address: ${driverEndpoint.address.toString()}.")
+ }
+}
+
+object GlutenDriverEndpoint extends Logging with RemovalListener[String, util.Set[String]] {
+ private lazy val executionResourceExpiredTime = SparkEnv.get.conf.getLong(
+ GlutenConfig.GLUTEN_RESOURCE_RELATION_EXPIRED_TIME.key,
+ GlutenConfig.GLUTEN_RESOURCE_RELATION_EXPIRED_TIME.defaultValue.get
+ )
+
+ var glutenDriverEndpointRef: RpcEndpointRef = _
+
+ // keep executorRef in memory
+ val executorDataMap = new ConcurrentHashMap[String, ExecutorData]
+
+ // If spark.scheduler.listenerbus.eventqueue.capacity is set too small,
+ // the listener may lose messages.
+ // We set a maximum expiration time of 1 day by default
+ // key: executionId, value: resourceIds
+ private val executionResourceRelation: Cache[String, util.Set[String]] =
+ Caffeine.newBuilder
+ .expireAfterAccess(executionResourceExpiredTime, TimeUnit.SECONDS)
+ .removalListener(this)
+ .build[String, util.Set[String]]()
+
+ def collectResources(executionId: String, resourceId: String): Unit = {
+ val resources = executionResourceRelation
+ .get(executionId, (_: String) => new util.HashSet[String]())
+ resources.add(resourceId)
+ }
+
+ def invalidateResourceRelation(executionId: String): Unit = {
+ executionResourceRelation.invalidate(executionId)
+ }
+
+ override def onRemoval(key: String, value: util.Set[String], cause: RemovalCause): Unit = {
+ executorDataMap.forEach(
+ (_, executor) => executor.executorEndpointRef.send(GlutenCleanExecutionResource(key, value)))
+ }
+}
+
+class ExecutorData(val executorEndpointRef: RpcEndpointRef) {}
diff --git a/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenExecutorEndpoint.scala b/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenExecutorEndpoint.scala
new file mode 100644
index 000000000000..49ecef20b3b3
--- /dev/null
+++ b/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenExecutorEndpoint.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.rpc
+
+import org.apache.gluten.execution.VeloxBroadcastBuildSideCache
+
+import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.internal.{config, Logging}
+import org.apache.spark.rpc.GlutenRpcMessages._
+import org.apache.spark.util.ThreadUtils
+
+import scala.util.{Failure, Success}
+
+/** Gluten executor endpoint. */
+class GlutenExecutorEndpoint(val executorId: String, val conf: SparkConf)
+ extends IsolatedRpcEndpoint
+ with Logging {
+ override val rpcEnv: RpcEnv = SparkEnv.get.rpcEnv
+
+ private val driverHost = conf.get(config.DRIVER_HOST_ADDRESS.key, "localhost")
+ private val driverPort = conf.getInt(config.DRIVER_PORT.key, 7077)
+ private val rpcAddress = RpcAddress(driverHost, driverPort)
+ private val driverUrl =
+ RpcEndpointAddress(rpcAddress, GlutenRpcConstants.GLUTEN_DRIVER_ENDPOINT_NAME).toString
+
+ @volatile var driverEndpointRef: RpcEndpointRef = null
+
+ rpcEnv.setupEndpoint(GlutenRpcConstants.GLUTEN_EXECUTOR_ENDPOINT_NAME, this)
+ // TODO(yuan): get thread cnt from spark context
+ override def threadCount(): Int = 1
+ override def onStart(): Unit = {
+ rpcEnv
+ .asyncSetupEndpointRefByURI(driverUrl)
+ .flatMap {
+ ref =>
+ // This is a very fast action so we can use "ThreadUtils.sameThread"
+ driverEndpointRef = ref
+ ref.ask[Boolean](GlutenRegisterExecutor(executorId, self))
+ }(ThreadUtils.sameThread)
+ .onComplete {
+ case Success(_) => logTrace("Register GlutenExecutor listener success.")
+ case Failure(e) => logError("Register GlutenExecutor listener error.", e)
+ }(ThreadUtils.sameThread)
+ logInfo("Initialized GlutenExecutorEndpoint.")
+ }
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case GlutenCleanExecutionResource(executionId, hashIds) =>
+ if (executionId != null) {
+ hashIds.forEach(
+ resource_id => VeloxBroadcastBuildSideCache.invalidateBroadcastHashtable(resource_id))
+ }
+
+ case e =>
+ logError(s"Received unexpected message. $e")
+ }
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case e =>
+ logInfo(s"Received message. $e")
+ }
+}
+object GlutenExecutorEndpoint {
+ var executorEndpoint: GlutenExecutorEndpoint = _
+}
diff --git a/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcConstants.scala b/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcConstants.scala
new file mode 100644
index 000000000000..4fbb0722a26a
--- /dev/null
+++ b/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcConstants.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.rpc
+
+object GlutenRpcConstants {
+
+ val GLUTEN_DRIVER_ENDPOINT_NAME = "GlutenDriverEndpoint"
+
+ val GLUTEN_EXECUTOR_ENDPOINT_NAME = "GlutenExecutorEndpoint"
+}
diff --git a/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala b/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala
new file mode 100644
index 000000000000..8127c324b79c
--- /dev/null
+++ b/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.rpc
+
+import java.util
+
+trait GlutenRpcMessage extends Serializable
+
+object GlutenRpcMessages {
+ case class GlutenRegisterExecutor(
+ executorId: String,
+ executorRef: RpcEndpointRef
+ ) extends GlutenRpcMessage
+
+ case class GlutenOnExecutionStart(executionId: String) extends GlutenRpcMessage
+
+ case class GlutenOnExecutionEnd(executionId: String) extends GlutenRpcMessage
+
+ case class GlutenExecutorRemoved(executorId: String) extends GlutenRpcMessage
+
+ case class GlutenCleanExecutionResource(executionId: String, broadcastHashIds: util.Set[String])
+ extends GlutenRpcMessage
+
+ // for mergetree cache
+ case class GlutenMergeTreeCacheLoad(
+ mergeTreeTable: String,
+ columns: util.Set[String],
+ onlyMetaCache: Boolean)
+ extends GlutenRpcMessage
+
+ case class GlutenCacheLoadStatus(jobId: String)
+
+ case class CacheJobInfo(status: Boolean, jobId: String, reason: String = "")
+ extends GlutenRpcMessage
+
+ case class GlutenFilesCacheLoad(files: Array[Byte]) extends GlutenRpcMessage
+
+ case class GlutenFilesCacheLoadStatus(jobId: String)
+}
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
index ad066d47f9e0..cf3f9ccca460 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
@@ -108,7 +108,9 @@ object BroadcastUtils {
UnsafeColumnarBuildSideRelation(
SparkShimLoader.getSparkShims.attributesFromStruct(schema),
result.offHeapData().asScala.toSeq,
- mode)
+ mode,
+ Seq.empty,
+ result.isOffHeap)
} else {
ColumnarBuildSideRelation(
SparkShimLoader.getSparkShims.attributesFromStruct(schema),
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
index d542fd92b92c..6429f8bb3fc5 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
@@ -18,13 +18,16 @@ package org.apache.spark.sql.execution
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.columnarbatch.ColumnarBatches
+import org.apache.gluten.execution.BroadcastHashJoinContext
+import org.apache.gluten.expression.ConverterUtils
import org.apache.gluten.iterator.Iterators
import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
import org.apache.gluten.runtime.Runtimes
import org.apache.gluten.sql.shims.SparkShimLoader
-import org.apache.gluten.utils.ArrowAbiUtil
-import org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper, NativeColumnarToRowInfo, NativeColumnarToRowJniWrapper}
+import org.apache.gluten.utils.{ArrowAbiUtil, SubstraitUtil}
+import org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper, HashJoinBuilder, NativeColumnarToRowInfo, NativeColumnarToRowJniWrapper}
+import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSeq, BindReferences, BoundReference, Expression, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
@@ -37,14 +40,18 @@ import org.apache.spark.util.KnownSizeEstimation
import org.apache.arrow.c.ArrowSchema
+import scala.collection.JavaConverters._
import scala.collection.JavaConverters.asScalaIteratorConverter
+import scala.collection.mutable.ArrayBuffer
object ColumnarBuildSideRelation {
// Keep constructor with BroadcastMode for compatibility
def apply(
output: Seq[Attribute],
batches: Array[Array[Byte]],
- mode: BroadcastMode): ColumnarBuildSideRelation = {
+ mode: BroadcastMode,
+ newBuildKeys: Seq[Expression] = Seq.empty,
+ offload: Boolean = false): ColumnarBuildSideRelation = {
val boundMode = mode match {
case HashedRelationBroadcastMode(keys, isNullAware) =>
// Bind each key to the build-side output so simple cols become BoundReference
@@ -54,15 +61,23 @@ object ColumnarBuildSideRelation {
case m =>
m // IdentityBroadcastMode, etc.
}
- new ColumnarBuildSideRelation(output, batches, BroadcastModeUtils.toSafe(boundMode))
+ new ColumnarBuildSideRelation(
+ output,
+ batches,
+ BroadcastModeUtils.toSafe(boundMode),
+ newBuildKeys,
+ offload)
}
}
case class ColumnarBuildSideRelation(
output: Seq[Attribute],
batches: Array[Array[Byte]],
- safeBroadcastMode: SafeBroadcastMode)
+ safeBroadcastMode: SafeBroadcastMode,
+ newBuildKeys: Seq[Expression],
+ offload: Boolean)
extends BuildSideRelation
+ with Logging
with KnownSizeEstimation {
// Rebuild the real BroadcastMode on demand; never serialize it.
@@ -135,6 +150,87 @@ case class ColumnarBuildSideRelation(
override def asReadOnlyCopy(): ColumnarBuildSideRelation = this
+ private var hashTableData: Long = 0L
+
+ def buildHashTable(
+ broadcastContext: BroadcastHashJoinContext): (Long, ColumnarBuildSideRelation) =
+ synchronized {
+ if (hashTableData == 0) {
+ val runtime = Runtimes.contextInstance(
+ BackendsApiManager.getBackendName,
+ "ColumnarBuildSideRelation#buildHashTable")
+ val jniWrapper = ColumnarBatchSerializerJniWrapper.create(runtime)
+ val serializeHandle: Long = {
+ val allocator = ArrowBufferAllocators.contextInstance()
+ val cSchema = ArrowSchema.allocateNew(allocator)
+ val arrowSchema = SparkArrowUtil.toArrowSchema(
+ SparkShimLoader.getSparkShims.structFromAttributes(output),
+ SQLConf.get.sessionLocalTimeZone)
+ ArrowAbiUtil.exportSchema(allocator, arrowSchema, cSchema)
+ val handle = jniWrapper
+ .init(cSchema.memoryAddress())
+ cSchema.close()
+ handle
+ }
+
+ val batchArray = new ArrayBuffer[Long]
+
+ var batchId = 0
+ while (batchId < batches.size) {
+ batchArray.append(jniWrapper.deserialize(serializeHandle, batches(batchId)))
+ batchId += 1
+ }
+
+ logDebug(
+ s"BHJ value size: " +
+ s"${broadcastContext.buildHashTableId} = ${batches.length}")
+
+ val (keys, newOutput) = if (newBuildKeys.isEmpty) {
+ (
+ broadcastContext.buildSideJoinKeys.asJava,
+ broadcastContext.buildSideStructure.asJava
+ )
+ } else {
+ (
+ newBuildKeys.asJava,
+ output.asJava
+ )
+ }
+
+ val joinKey = keys.asScala
+ .map {
+ key =>
+ val attr = ConverterUtils.getAttrFromExpr(key)
+ ConverterUtils.genColumnNameWithExprId(attr)
+ }
+ .mkString(",")
+
+ // Build the hash table
+ hashTableData = HashJoinBuilder
+ .nativeBuild(
+ broadcastContext.buildHashTableId,
+ batchArray.toArray,
+ joinKey,
+ broadcastContext.substraitJoinType.ordinal(),
+ broadcastContext.hasMixedFiltCondition,
+ broadcastContext.isExistenceJoin,
+ SubstraitUtil.toNameStruct(newOutput).toByteArray,
+ broadcastContext.isNullAwareAntiJoin,
+ broadcastContext.bloomFilterPushdownSize,
+ broadcastContext.broadcastHashTableBuildThreads
+ )
+
+ jniWrapper.close(serializeHandle)
+ (hashTableData, this)
+ } else {
+ (HashJoinBuilder.cloneHashTable(hashTableData), null)
+ }
+ }
+
+ def reset(): Unit = synchronized {
+ hashTableData = 0
+ }
+
/**
* Transform columnar broadcast value to Array[InternalRow] by key.
*
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/joins/SparkHashJoinUtils.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/joins/SparkHashJoinUtils.scala
new file mode 100644
index 000000000000..1e6b677253f5
--- /dev/null
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/joins/SparkHashJoinUtils.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, BitwiseAnd, BitwiseOr, Cast, Expression, ShiftLeft}
+import org.apache.spark.sql.types.IntegralType
+
+object SparkHashJoinUtils {
+
+ // Copy from org.apache.spark.sql.execution.joins.HashJoin#canRewriteAsLongType
+ // we should keep consistent with it to identify the LongHashRelation.
+ def canRewriteAsLongType(keys: Seq[Expression]): Boolean = {
+ // TODO: support BooleanType, DateType and TimestampType
+ keys.forall(_.dataType.isInstanceOf[IntegralType]) &&
+ keys.map(_.dataType.defaultSize).sum <= 8
+ }
+
+ def getOriginalKeysFromPacked(expr: Expression): Seq[Expression] = {
+
+ def unwrap(e: Expression): Expression = e match {
+ case Cast(child, _, _, _) => unwrap(child)
+ case Alias(child, _) => unwrap(child)
+ case BitwiseAnd(child, _) => unwrap(child)
+ case other => other
+ }
+
+ expr match {
+ case BitwiseOr(ShiftLeft(left, _), rightPart) =>
+ getOriginalKeysFromPacked(left) :+ unwrap(rightPart)
+ case BitwiseOr(left, rightPart) =>
+ getOriginalKeysFromPacked(left) :+ unwrap(rightPart)
+ case other =>
+ Seq(unwrap(other))
+ }
+ }
+
+}
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
index ba307415c501..fc7516c4b325 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
@@ -18,12 +18,14 @@ package org.apache.spark.sql.execution.unsafe
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.columnarbatch.ColumnarBatches
+import org.apache.gluten.execution.BroadcastHashJoinContext
+import org.apache.gluten.expression.ConverterUtils
import org.apache.gluten.iterator.Iterators
import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
import org.apache.gluten.runtime.Runtimes
import org.apache.gluten.sql.shims.SparkShimLoader
-import org.apache.gluten.utils.ArrowAbiUtil
-import org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper, NativeColumnarToRowInfo, NativeColumnarToRowJniWrapper}
+import org.apache.gluten.utils.{ArrowAbiUtil, SubstraitUtil}
+import org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper, HashJoinBuilder, NativeColumnarToRowInfo, NativeColumnarToRowJniWrapper}
import org.apache.spark.annotation.Experimental
import org.apache.spark.internal.Logging
@@ -44,13 +46,17 @@ import org.apache.arrow.c.ArrowSchema
import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import scala.collection.JavaConverters._
import scala.collection.JavaConverters.asScalaIteratorConverter
+import scala.collection.mutable.ArrayBuffer
object UnsafeColumnarBuildSideRelation {
def apply(
output: Seq[Attribute],
batches: Seq[UnsafeByteArray],
- mode: BroadcastMode): UnsafeColumnarBuildSideRelation = {
+ mode: BroadcastMode,
+ newBuildKeys: Seq[Expression] = Seq.empty,
+ offload: Boolean = false): UnsafeColumnarBuildSideRelation = {
val boundMode = mode match {
case HashedRelationBroadcastMode(keys, isNullAware) =>
// Bind each key to the build-side output so simple cols become BoundReference
@@ -60,7 +66,12 @@ object UnsafeColumnarBuildSideRelation {
case m =>
m // IdentityBroadcastMode, etc.
}
- new UnsafeColumnarBuildSideRelation(output, batches, BroadcastModeUtils.toSafe(boundMode))
+ new UnsafeColumnarBuildSideRelation(
+ output,
+ batches,
+ BroadcastModeUtils.toSafe(boundMode),
+ newBuildKeys,
+ offload)
}
}
@@ -78,7 +89,9 @@ object UnsafeColumnarBuildSideRelation {
class UnsafeColumnarBuildSideRelation(
private var output: Seq[Attribute],
private var batches: Seq[UnsafeByteArray],
- private var safeBroadcastMode: SafeBroadcastMode)
+ private var safeBroadcastMode: SafeBroadcastMode,
+ private var newBuildKeys: Seq[Expression],
+ private var offload: Boolean)
extends BuildSideRelation
with Externalizable
with Logging
@@ -96,37 +109,128 @@ class UnsafeColumnarBuildSideRelation(
case _ => None
}
+ def isOffload: Boolean = offload
+
/** needed for serialization. */
def this() = {
- this(null, null, null)
+ this(null, null, null, Seq.empty, false)
}
private[unsafe] def getBatches(): Seq[UnsafeByteArray] = {
batches
}
+ private var hashTableData: Long = 0L
+
+ def buildHashTable(broadcastContext: BroadcastHashJoinContext): (Long, BuildSideRelation) =
+ synchronized {
+ if (hashTableData == 0) {
+ val runtime = Runtimes.contextInstance(
+ BackendsApiManager.getBackendName,
+ "UnsafeColumnarBuildSideRelation#buildHashTable")
+ val jniWrapper = ColumnarBatchSerializerJniWrapper.create(runtime)
+ val serializeHandle: Long = {
+ val allocator = ArrowBufferAllocators.contextInstance()
+ val cSchema = ArrowSchema.allocateNew(allocator)
+ val arrowSchema = SparkArrowUtil.toArrowSchema(
+ SparkShimLoader.getSparkShims.structFromAttributes(output),
+ SQLConf.get.sessionLocalTimeZone)
+ ArrowAbiUtil.exportSchema(allocator, arrowSchema, cSchema)
+ val handle = jniWrapper
+ .init(cSchema.memoryAddress())
+ cSchema.close()
+ handle
+ }
+
+ val batchArray = new ArrayBuffer[Long]
+
+ var batchId = 0
+ while (batchId < batches.size) {
+ val (offset, length) = (batches(batchId).address(), batches(batchId).size())
+ batchArray.append(jniWrapper.deserializeDirect(serializeHandle, offset, length.toInt))
+ batchId += 1
+ }
+
+ logDebug(
+ s"BHJ value size: " +
+ s"${broadcastContext.buildHashTableId} = ${batches.size}")
+
+ val (keys, newOutput) = if (newBuildKeys.isEmpty) {
+ (
+ broadcastContext.buildSideJoinKeys.asJava,
+ broadcastContext.buildSideStructure.asJava
+ )
+ } else {
+ (
+ newBuildKeys.asJava,
+ output.asJava
+ )
+ }
+
+ val joinKey = keys.asScala
+ .map {
+ key =>
+ val attr = ConverterUtils.getAttrFromExpr(key)
+ ConverterUtils.genColumnNameWithExprId(attr)
+ }
+ .mkString(",")
+
+ // Build the hash table
+ hashTableData = HashJoinBuilder
+ .nativeBuild(
+ broadcastContext.buildHashTableId,
+ batchArray.toArray,
+ joinKey,
+ broadcastContext.substraitJoinType.ordinal(),
+ broadcastContext.hasMixedFiltCondition,
+ broadcastContext.isExistenceJoin,
+ SubstraitUtil.toNameStruct(newOutput).toByteArray,
+ broadcastContext.isNullAwareAntiJoin,
+ broadcastContext.bloomFilterPushdownSize,
+ broadcastContext.broadcastHashTableBuildThreads
+ )
+
+ jniWrapper.close(serializeHandle)
+ (hashTableData, this)
+ } else {
+ (HashJoinBuilder.cloneHashTable(hashTableData), null)
+ }
+ }
+
+ def reset(): Unit = synchronized {
+ hashTableData = 0
+ }
+
override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
out.writeObject(output)
out.writeObject(safeBroadcastMode)
out.writeObject(batches.toArray)
+ out.writeObject(newBuildKeys)
+ out.writeBoolean(offload)
}
override def write(kryo: Kryo, out: Output): Unit = Utils.tryOrIOException {
kryo.writeObject(out, output.toList)
kryo.writeClassAndObject(out, safeBroadcastMode)
kryo.writeClassAndObject(out, batches.toArray)
+ kryo.writeClassAndObject(out, newBuildKeys)
+ out.writeBoolean(offload)
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
output = in.readObject().asInstanceOf[Seq[Attribute]]
safeBroadcastMode = in.readObject().asInstanceOf[SafeBroadcastMode]
batches = in.readObject().asInstanceOf[Array[UnsafeByteArray]].toSeq
+ newBuildKeys = in.readObject().asInstanceOf[Seq[Expression]]
+ offload = in.readBoolean()
}
override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException {
output = kryo.readObject(in, classOf[List[_]]).asInstanceOf[Seq[Attribute]]
safeBroadcastMode = kryo.readClassAndObject(in).asInstanceOf[SafeBroadcastMode]
batches = kryo.readClassAndObject(in).asInstanceOf[Array[UnsafeByteArray]].toSeq
+ newBuildKeys = kryo.readClassAndObject(in).asInstanceOf[Seq[Expression]]
+ offload = in.readBoolean()
}
private def transformProjection: UnsafeProjection = safeBroadcastMode match {
diff --git a/backends-velox/src/test/java/org/apache/gluten/test/MockVeloxBackend.java b/backends-velox/src/test/java/org/apache/gluten/test/MockVeloxBackend.java
index 06fe3d28caff..2c4b813f30cb 100644
--- a/backends-velox/src/test/java/org/apache/gluten/test/MockVeloxBackend.java
+++ b/backends-velox/src/test/java/org/apache/gluten/test/MockVeloxBackend.java
@@ -43,7 +43,7 @@ public SparkConf conf() {
@Override
public String executorID() {
- throw new UnsupportedOperationException();
+ return "MockVeloxBackend ID";
}
@Override
diff --git a/backends-velox/src/test/java/org/apache/gluten/test/VeloxBackendTestBase.java b/backends-velox/src/test/java/org/apache/gluten/test/VeloxBackendTestBase.java
index 27596137931c..c66c67fe9ebb 100644
--- a/backends-velox/src/test/java/org/apache/gluten/test/VeloxBackendTestBase.java
+++ b/backends-velox/src/test/java/org/apache/gluten/test/VeloxBackendTestBase.java
@@ -19,19 +19,36 @@
import org.apache.gluten.backendsapi.ListenerApi;
import org.apache.gluten.backendsapi.velox.VeloxListenerApi;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.test.TestSparkSession;
import org.junit.AfterClass;
import org.junit.BeforeClass;
public abstract class VeloxBackendTestBase {
private static final ListenerApi API = new VeloxListenerApi();
+ private static SparkSession sparkSession = null;
@BeforeClass
public static void setup() {
+ if (sparkSession == null) {
+ sparkSession =
+ TestSparkSession.builder()
+ .appName("VeloxBackendTest")
+ .master("local[1]")
+ .config(MockVeloxBackend.mockPluginContext().conf())
+ .getOrCreate();
+ }
+
API.onExecutorStart(MockVeloxBackend.mockPluginContext());
}
@AfterClass
public static void tearDown() {
API.onExecutorShutdown();
+
+ if (sparkSession != null) {
+ sparkSession.stop();
+ sparkSession = null;
+ }
}
}
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/DynamicOffHeapSizingSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/DynamicOffHeapSizingSuite.scala
index 0afbc2fa19c5..ddd76f917db9 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/DynamicOffHeapSizingSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/DynamicOffHeapSizingSuite.scala
@@ -35,6 +35,10 @@ class DynamicOffHeapSizingSuite extends VeloxWholeStageTransformerSuite {
.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
.set("spark.executor.memory", "2GB")
.set("spark.memory.offHeap.enabled", "false")
+ .set(
+ "spark.gluten.velox.buildHashTableOncePerExecutor.enabled",
+ "false"
+ ) // building the native hash table needs to use off-heap memory.
.set(GlutenCoreConfig.DYNAMIC_OFFHEAP_SIZING_MEMORY_FRACTION.key, "0.95")
.set(GlutenCoreConfig.DYNAMIC_OFFHEAP_SIZING_ENABLED.key, "true")
}
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala
index 4fca03fa857b..86565aa42b9e 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala
@@ -22,8 +22,7 @@ import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.spark.SparkConf
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.AttributeReference
-import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, ColumnarSubqueryBroadcastExec, InputIteratorTransformer}
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec}
+import org.apache.spark.sql.execution.{ColumnarSubqueryBroadcastExec, InputIteratorTransformer}
class VeloxHashJoinSuite extends VeloxWholeStageTransformerSuite {
override protected val resourcePath: String = "/tpch-data-parquet"
@@ -114,85 +113,12 @@ class VeloxHashJoinSuite extends VeloxWholeStageTransformerSuite {
}
}
- test("Reuse broadcast exchange for different build keys with same table") {
- Seq("true", "false").foreach(
- enabledOffheapBroadcast =>
- withSQLConf(
- VeloxConfig.VELOX_BROADCAST_BUILD_RELATION_USE_OFFHEAP.key -> enabledOffheapBroadcast) {
- withTable("t1", "t2") {
- spark.sql("""
- |CREATE TABLE t1 USING PARQUET
- |AS SELECT id as c1, id as c2 FROM range(10)
- |""".stripMargin)
-
- spark.sql("""
- |CREATE TABLE t2 USING PARQUET
- |AS SELECT id as c1, id as c2 FROM range(3)
- |""".stripMargin)
-
- val df = spark.sql("""
- |SELECT * FROM t1
- |JOIN t2 as tmp1 ON t1.c1 = tmp1.c1 and tmp1.c1 = tmp1.c2
- |JOIN t2 as tmp2 on t1.c2 = tmp2.c2 and tmp2.c1 = tmp2.c2
- |""".stripMargin)
-
- assert(collect(df.queryExecution.executedPlan) {
- case b: BroadcastExchangeExec => b
- }.size == 2)
-
- checkAnswer(
- df,
- Row(2, 2, 2, 2, 2, 2) :: Row(1, 1, 1, 1, 1, 1) :: Row(0, 0, 0, 0, 0, 0) :: Nil)
-
- assert(collect(df.queryExecution.executedPlan) {
- case b: ColumnarBroadcastExchangeExec => b
- }.size == 1)
- assert(collect(df.queryExecution.executedPlan) {
- case r @ ReusedExchangeExec(_, _: ColumnarBroadcastExchangeExec) => r
- }.size == 1)
- }
- })
- }
-
- test("ColumnarBuildSideRelation with small columnar to row memory") {
- Seq("true", "false").foreach(
- enabledOffheapBroadcast =>
- withSQLConf(
- GlutenConfig.GLUTEN_COLUMNAR_TO_ROW_MEM_THRESHOLD.key -> "16",
- VeloxConfig.VELOX_BROADCAST_BUILD_RELATION_USE_OFFHEAP.key -> enabledOffheapBroadcast) {
- withTable("t1", "t2") {
- spark.sql("""
- |CREATE TABLE t1 USING PARQUET
- |AS SELECT id as c1, id as c2 FROM range(10)
- |""".stripMargin)
-
- spark.sql("""
- |CREATE TABLE t2 USING PARQUET PARTITIONED BY (c1)
- |AS SELECT id as c1, id as c2 FROM range(30)
- |""".stripMargin)
-
- val df = spark.sql("""
- |SELECT t1.c2
- |FROM t1, t2
- |WHERE t1.c1 = t2.c1
- |AND t1.c2 < 4
- |""".stripMargin)
-
- checkAnswer(df, Row(0) :: Row(1) :: Row(2) :: Row(3) :: Nil)
-
- val subqueryBroadcastExecs = collectWithSubqueries(df.queryExecution.executedPlan) {
- case subqueryBroadcast: ColumnarSubqueryBroadcastExec => subqueryBroadcast
- }
- assert(subqueryBroadcastExecs.size == 1)
- }
- })
- }
-
test("ColumnarBuildSideRelation transform support multiple key columns") {
Seq("true", "false").foreach(
enabledOffheapBroadcast =>
withSQLConf(
- VeloxConfig.VELOX_BROADCAST_BUILD_RELATION_USE_OFFHEAP.key -> enabledOffheapBroadcast) {
+ VeloxConfig.VELOX_BROADCAST_BUILD_RELATION_USE_OFFHEAP.key ->
+ enabledOffheapBroadcast) {
withTable("t1", "t2") {
val df1 =
(0 until 50)
@@ -317,4 +243,83 @@ class VeloxHashJoinSuite extends VeloxWholeStageTransformerSuite {
}
}
}
+
+ test("Broadcast join preserves original cast expression in join keys") {
+ withSQLConf(
+ ("spark.sql.autoBroadcastJoinThreshold", "10MB"),
+ ("spark.sql.adaptive.enabled", "false")
+ ) {
+ withTable("t1_int", "t2_long") {
+ // Create table with INT column
+ spark
+ .range(100)
+ .selectExpr("cast(id as int) as key", "id as value")
+ .write
+ .saveAsTable("t1_int")
+
+ // Create table with LONG column
+ spark.range(50).selectExpr("id as key", "id * 2 as value").write.saveAsTable("t2_long")
+
+ // Join INT with LONG - Spark will insert cast(int to long) in join keys
+ val query = """
+ SELECT t1.key, t1.value, t2.value as value2
+ FROM t1_int t1
+ JOIN t2_long t2 ON t1.key = t2.key
+ ORDER BY t1.key
+ """
+
+ runQueryAndCompare(query) {
+ df =>
+ // Check that broadcast join is used in Gluten execution
+ val plan = df.queryExecution.executedPlan
+ val broadcastJoins = plan.collect { case bhj: BroadcastHashJoinExecTransformer => bhj }
+ assert(broadcastJoins.nonEmpty, "Should use broadcast hash join")
+ }
+ }
+ }
+ }
+
+ test("Broadcast join with multiple cast expressions in join keys") {
+ withSQLConf(
+ ("spark.sql.autoBroadcastJoinThreshold", "10MB"),
+ ("spark.sql.adaptive.enabled", "false")
+ ) {
+ withTable("t1_mixed", "t2_mixed") {
+ // Create table with mixed types
+ spark
+ .range(100)
+ .selectExpr("cast(id as int) as key1", "cast(id as short) as key2", "id as value")
+ .write
+ .saveAsTable("t1_mixed")
+
+ // Create table with different types requiring casts
+ spark
+ .range(50)
+ .selectExpr("id as key1", "cast(id as int) as key2", "id * 2 as value")
+ .write
+ .saveAsTable("t2_mixed")
+
+ // Join with multiple keys requiring casts
+ // key1: cast(int to long), key2: cast(short to int)
+ val query = """
+ SELECT t1.key1, t1.key2, t1.value, t2.value as value2
+ FROM t1_mixed t1
+ JOIN t2_mixed t2 ON t1.key1 = t2.key1 AND t1.key2 = t2.key2
+ ORDER BY t1.key1, t1.key2
+ """
+
+ runQueryAndCompare(query) {
+ df =>
+ // Check that broadcast join is used in Gluten execution
+ val plan = df.queryExecution.executedPlan
+ val broadcastJoins = plan.collect { case bhj: BroadcastHashJoinExecTransformer => bhj }
+ assert(broadcastJoins.nonEmpty, "Should use broadcast hash join")
+
+ // Verify multiple join keys are handled correctly
+ assert(broadcastJoins.head.leftKeys.length == 2)
+ assert(broadcastJoins.head.rightKeys.length == 2)
+ }
+ }
+ }
+ }
}
diff --git a/backends-velox/src/test/scala/org/apache/spark/sql/execution/benchmark/VeloxBroadcastBuildOnceBenchmark.scala b/backends-velox/src/test/scala/org/apache/spark/sql/execution/benchmark/VeloxBroadcastBuildOnceBenchmark.scala
new file mode 100644
index 000000000000..6e06cc35a74b
--- /dev/null
+++ b/backends-velox/src/test/scala/org/apache/spark/sql/execution/benchmark/VeloxBroadcastBuildOnceBenchmark.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.benchmark
+
+import org.apache.gluten.config.VeloxConfig
+
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.internal.SQLConf
+
+/** Benchmark to measure performance for BHJ build once per executor. */
+object VeloxBroadcastBuildOnceBenchmark extends SqlBasedBenchmark {
+
+ override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+ val numRows = 5 * 1000 * 1000
+ val broadcastRows = 1000 * 1000
+
+ withTempPath {
+ f =>
+ val path = f.getCanonicalPath
+ val probePath = s"$path/probe"
+ val buildPath = s"$path/build"
+
+ // Generate probe table with many partitions to simulate many tasks
+ spark
+ .range(numRows)
+ .repartition(100)
+ .selectExpr("id as k1", "id as v1")
+ .write
+ .parquet(probePath)
+
+ // Generate build table
+ spark
+ .range(broadcastRows)
+ .selectExpr("id as k2", "id as v2")
+ .write
+ .parquet(buildPath)
+
+ spark.read.parquet(probePath).createOrReplaceTempView("probe")
+ spark.read.parquet(buildPath).createOrReplaceTempView("build")
+
+ val query = "SELECT /*+ BROADCAST(build) */ count(*) FROM probe JOIN build ON k1 = k2"
+
+ val benchmark = new Benchmark("BHJ Build Once Benchmark", numRows, output = output)
+
+ // Warm up
+ spark.sql(query).collect()
+
+ benchmark.addCase("Build once per executor enabled=false", 3) {
+ _ =>
+ withSQLConf(
+ VeloxConfig.VELOX_BROADCAST_BUILD_HASHTABLE_ONCE_PER_EXECUTOR.key -> "false",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "200MB"
+ ) {
+ spark.sql(query).collect()
+ }
+ }
+
+ benchmark.addCase("Build once per executor enabled=true", 3) {
+ _ =>
+ withSQLConf(
+ VeloxConfig.VELOX_BROADCAST_BUILD_HASHTABLE_ONCE_PER_EXECUTOR.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "200MB"
+ ) {
+ spark.sql(query).collect()
+ }
+ }
+
+ benchmark.run()
+ }
+ }
+}
diff --git a/backends-velox/src/test/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelationTest.scala b/backends-velox/src/test/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelationTest.scala
index 41400f613f59..c881d77ed105 100644
--- a/backends-velox/src/test/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelationTest.scala
+++ b/backends-velox/src/test/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelationTest.scala
@@ -188,4 +188,30 @@ class UnsafeColumnarBuildSideRelationTest extends SharedSparkSession {
newUnsafeRelationWithHashMode(ByteUnit.MiB.toKiB(50).toInt)
}
}
+
+ test("Verify offload field serialization") {
+ val relation = UnsafeColumnarBuildSideRelation(
+ output,
+ Seq(sampleUnsafeByteArrayInKb(1)),
+ IdentityBroadcastMode,
+ Seq.empty,
+ offload = true
+ )
+
+ // Java Serialization
+ val javaSerializer = new JavaSerializer(SparkEnv.get.conf).newInstance()
+ val javaBuffer = javaSerializer.serialize(relation)
+ val javaObj = javaSerializer.deserialize[UnsafeColumnarBuildSideRelation](javaBuffer)
+ assert(javaObj.isOffload, "Java deserialization failed to restore offload=true")
+
+ // Kryo Serialization
+ val kryoSerializer = new KryoSerializer(SparkEnv.get.conf).newInstance()
+ val kryoBuffer = kryoSerializer.serialize(relation)
+ val kryoObj = kryoSerializer.deserialize[UnsafeColumnarBuildSideRelation](kryoBuffer)
+ assert(kryoObj.isOffload, "Kryo deserialization failed to restore offload=true")
+
+ // Creating another relation with offload=false to compare serialized sizes is impractical:
+ // a boolean occupies only one byte and is hard to distinguish from serialization metadata,
+ // so we rely on the assertions above instead.
+ }
}
diff --git a/cpp/velox/CMakeLists.txt b/cpp/velox/CMakeLists.txt
index be31f18206b3..fc6391b7f3c0 100644
--- a/cpp/velox/CMakeLists.txt
+++ b/cpp/velox/CMakeLists.txt
@@ -157,6 +157,7 @@ set(VELOX_SRCS
jni/JniFileSystem.cc
jni/JniUdf.cc
jni/VeloxJniWrapper.cc
+ jni/JniHashTable.cc
memory/BufferOutputStream.cc
memory/VeloxColumnarBatch.cc
memory/VeloxMemoryManager.cc
@@ -164,6 +165,7 @@ set(VELOX_SRCS
operators/functions/RowConstructorWithNull.cc
operators/functions/SparkExprToSubfieldFilterParser.cc
operators/plannodes/RowVectorStream.cc
+ operators/hashjoin/HashTableBuilder.cc
operators/reader/FileReaderIterator.cc
operators/reader/ParquetReaderIterator.cc
operators/serializer/VeloxColumnarBatchSerializer.cc
diff --git a/cpp/velox/compute/VeloxBackend.cc b/cpp/velox/compute/VeloxBackend.cc
index de9e9385f8f0..0232da48da14 100644
--- a/cpp/velox/compute/VeloxBackend.cc
+++ b/cpp/velox/compute/VeloxBackend.cc
@@ -362,6 +362,7 @@ void VeloxBackend::tearDown() {
filesystem->close();
}
#endif
+ gluten::hashTableObjStore.reset();
// Destruct IOThreadPoolExecutor will join all threads.
// On threads exit, thread local variables can be constructed with referencing global variables.
diff --git a/cpp/velox/compute/VeloxBackend.h b/cpp/velox/compute/VeloxBackend.h
index 94e7ec93fba0..d73787063f54 100644
--- a/cpp/velox/compute/VeloxBackend.h
+++ b/cpp/velox/compute/VeloxBackend.h
@@ -28,6 +28,7 @@
#include "velox/common/config/Config.h"
#include "velox/common/memory/MmapAllocator.h"
+#include "jni/JniHashTable.h"
#include "memory/VeloxMemoryManager.h"
namespace gluten {
@@ -56,6 +57,10 @@ class VeloxBackend {
return globalMemoryManager_.get();
}
+ folly::Executor* executor() const {
+ return ioExecutor_.get();
+ }
+
void tearDown();
private:
diff --git a/cpp/velox/jni/JniHashTable.cc b/cpp/velox/jni/JniHashTable.cc
new file mode 100644
index 000000000000..77cd78ff6a4c
--- /dev/null
+++ b/cpp/velox/jni/JniHashTable.cc
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <jni.h>
+
+#include
+#include "JniHashTable.h"
+#include "folly/String.h"
+#include "memory/ColumnarBatch.h"
+#include "memory/VeloxColumnarBatch.h"
+#include "substrait/algebra.pb.h"
+#include "substrait/type.pb.h"
+#include "velox/core/PlanNode.h"
+#include "velox/type/Type.h"
+
+namespace gluten {
+
+static jclass jniVeloxBroadcastBuildSideCache = nullptr;
+static jmethodID jniGet = nullptr;
+
+jlong callJavaGet(const std::string& id) {
+ JNIEnv* env;
+ if (vm->GetEnv(reinterpret_cast<void**>(&env), jniVersion) != JNI_OK) {
+ throw gluten::GlutenException("JNIEnv was not attached to current thread");
+ }
+
+ const jstring s = env->NewStringUTF(id.c_str());
+
+ auto result = env->CallStaticLongMethod(jniVeloxBroadcastBuildSideCache, jniGet, s);
+ return result;
+}
+
+// Build and return the hash table builder holding the velox hash table.
+std::shared_ptr<HashTableBuilder> nativeHashTableBuild(
+ const std::string& joinKeys,
+ std::vector<std::string> names,
+ std::vector<facebook::velox::TypePtr> veloxTypeList,
+ int joinType,
+ bool hasMixedJoinCondition,
+ bool isExistenceJoin,
+ bool isNullAwareAntiJoin,
+ int64_t bloomFilterPushdownSize,
+ std::vector<std::shared_ptr<ColumnarBatch>>& batches,
+ std::shared_ptr<facebook::velox::memory::MemoryPool> memoryPool) {
+ auto rowType = std::make_shared<facebook::velox::RowType>(std::move(names), std::move(veloxTypeList));
+
+ auto sJoin = static_cast<::substrait::JoinRel_JoinType>(joinType);
+ facebook::velox::core::JoinType vJoin;
+ switch (sJoin) {
+ case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_INNER:
+ vJoin = facebook::velox::core::JoinType::kInner;
+ break;
+ case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_OUTER:
+ vJoin = facebook::velox::core::JoinType::kFull;
+ break;
+ case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_LEFT:
+ vJoin = facebook::velox::core::JoinType::kLeft;
+ break;
+ case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_RIGHT:
+ vJoin = facebook::velox::core::JoinType::kRight;
+ break;
+ case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI:
+ // Determine the semi join type based on extracted information.
+ if (isExistenceJoin) {
+ vJoin = facebook::velox::core::JoinType::kLeftSemiProject;
+ } else {
+ vJoin = facebook::velox::core::JoinType::kLeftSemiFilter;
+ }
+ break;
+ case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_RIGHT_SEMI:
+ // Determine the semi join type based on extracted information.
+ if (isExistenceJoin) {
+ vJoin = facebook::velox::core::JoinType::kRightSemiProject;
+ } else {
+ vJoin = facebook::velox::core::JoinType::kRightSemiFilter;
+ }
+ break;
+ case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_LEFT_ANTI: {
+ // Determine the anti join type based on extracted information.
+ vJoin = facebook::velox::core::JoinType::kAnti;
+ break;
+ }
+ default:
+ VELOX_NYI("Unsupported Join type: {}", std::to_string(sJoin));
+ }
+
+ std::vector joinKeyNames;
+ folly::split(',', joinKeys, joinKeyNames);
+
+ std::vector<std::shared_ptr<const facebook::velox::core::FieldAccessTypedExpr>> joinKeyTypes;
+ joinKeyTypes.reserve(joinKeyNames.size());
+ for (const auto& name : joinKeyNames) {
+ joinKeyTypes.emplace_back(
+ std::make_shared<facebook::velox::core::FieldAccessTypedExpr>(rowType->findChild(name), name));
+ }
+
+ auto hashTableBuilder = std::make_shared<HashTableBuilder>(
+ vJoin,
+ isNullAwareAntiJoin,
+ hasMixedJoinCondition,
+ bloomFilterPushdownSize,
+ joinKeyTypes,
+ rowType,
+ memoryPool.get());
+
+ for (auto i = 0; i < batches.size(); i++) {
+ auto rowVector = VeloxColumnarBatch::from(memoryPool.get(), batches[i])->getRowVector();
+ hashTableBuilder->addInput(rowVector);
+ }
+
+ return hashTableBuilder;
+}
+
+long getJoin(std::string hashTableId) {
+ return callJavaGet(hashTableId);
+}
+
+void initVeloxJniHashTable(JNIEnv* env) {
+ if (env->GetJavaVM(&vm) != JNI_OK) {
+ throw gluten::GlutenException("Unable to get JavaVM instance");
+ }
+ const char* classSig = "Lorg/apache/gluten/execution/VeloxBroadcastBuildSideCache;";
+ jniVeloxBroadcastBuildSideCache = createGlobalClassReferenceOrError(env, classSig);
+ jniGet = getStaticMethodId(env, jniVeloxBroadcastBuildSideCache, "get", "(Ljava/lang/String;)J");
+}
+
+void finalizeVeloxJniHashTable(JNIEnv* env) {
+ env->DeleteGlobalRef(jniVeloxBroadcastBuildSideCache);
+}
+
+} // namespace gluten
diff --git a/cpp/velox/jni/JniHashTable.h b/cpp/velox/jni/JniHashTable.h
new file mode 100644
index 000000000000..c0d9227840d9
--- /dev/null
+++ b/cpp/velox/jni/JniHashTable.h
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <jni.h>
+#include "memory/ColumnarBatch.h"
+#include "memory/VeloxMemoryManager.h"
+#include "operators/hashjoin/HashTableBuilder.h"
+#include "utils/ObjectStore.h"
+#include "velox/exec/HashTable.h"
+
+namespace gluten {
+
+inline static JavaVM* vm = nullptr;
+
+inline static std::unique_ptr<ObjectStore> hashTableObjStore = ObjectStore::create();
+
+// Build and return the hash table builder.
+std::shared_ptr<HashTableBuilder> nativeHashTableBuild(
+ const std::string& joinKeys,
+ std::vector<std::string> names,
+ std::vector<facebook::velox::TypePtr> veloxTypeList,
+ int joinType,
+ bool hasMixedJoinCondition,
+ bool isExistenceJoin,
+ bool isNullAwareAntiJoin,
+ int64_t bloomFilterPushdownSize,
+ std::vector<std::shared_ptr<ColumnarBatch>>& batches,
+ std::shared_ptr<facebook::velox::memory::MemoryPool> memoryPool);
+
+long getJoin(std::string hashTableId);
+
+void initVeloxJniHashTable(JNIEnv* env);
+
+void finalizeVeloxJniHashTable(JNIEnv* env);
+
+jlong callJavaGet(const std::string& id);
+
+} // namespace gluten
diff --git a/cpp/velox/jni/VeloxJniWrapper.cc b/cpp/velox/jni/VeloxJniWrapper.cc
index ad6f8947eb28..e488274e9718 100644
--- a/cpp/velox/jni/VeloxJniWrapper.cc
+++ b/cpp/velox/jni/VeloxJniWrapper.cc
@@ -30,14 +30,18 @@
#include "config/GlutenConfig.h"
#include "jni/JniError.h"
#include "jni/JniFileSystem.h"
+#include "jni/JniHashTable.h"
+#include "memory/AllocationListener.h"
#include "memory/VeloxColumnarBatch.h"
#include "memory/VeloxMemoryManager.h"
+#include "operators/hashjoin/HashTableBuilder.h"
#include "shuffle/rss/RssPartitionWriter.h"
#include "substrait/SubstraitToVeloxPlanValidator.h"
#include "utils/ObjectStore.h"
#include "utils/VeloxBatchResizer.h"
#include "velox/common/base/BloomFilter.h"
#include "velox/common/file/FileSystems.h"
+#include "velox/exec/HashTable.h"
#ifdef GLUTEN_ENABLE_GPU
#include "cudf/CudfPlanValidator.h"
@@ -76,6 +80,7 @@ jint JNI_OnLoad(JavaVM* vm, void*) {
getJniErrorState()->ensureInitialized(env);
initVeloxJniFileSystem(env);
initVeloxJniUDF(env);
+ initVeloxJniHashTable(env);
infoCls = createGlobalClassReferenceOrError(env, "Lorg/apache/gluten/validate/NativePlanValidationInfo;");
infoClsInitMethod = getMethodIdOrError(env, infoCls, "", "(ILjava/lang/String;)V");
@@ -84,12 +89,13 @@ jint JNI_OnLoad(JavaVM* vm, void*) {
createGlobalClassReferenceOrError(env, "Lorg/apache/spark/sql/execution/datasources/BlockStripes;");
blockStripesConstructor = getMethodIdOrError(env, blockStripesClass, "", "(J[J[II[[B)V");
- batchWriteMetricsClass =
- createGlobalClassReferenceOrError(env, "Lorg/apache/gluten/metrics/BatchWriteMetrics;");
+ batchWriteMetricsClass = createGlobalClassReferenceOrError(env, "Lorg/apache/gluten/metrics/BatchWriteMetrics;");
batchWriteMetricsConstructor = getMethodIdOrError(env, batchWriteMetricsClass, "", "(JIJJ)V");
DLOG(INFO) << "Loaded Velox backend.";
+ gluten::vm = vm;
+
return jniVersion;
}
@@ -183,8 +189,7 @@ Java_org_apache_gluten_vectorized_PlanEvaluatorJniWrapper_nativeValidateWithFail
JNI_METHOD_END(nullptr)
}
-JNIEXPORT jboolean JNICALL
-Java_org_apache_gluten_vectorized_PlanEvaluatorJniWrapper_nativeValidateExpression( // NOLINT
+JNIEXPORT jboolean JNICALL Java_org_apache_gluten_vectorized_PlanEvaluatorJniWrapper_nativeValidateExpression( // NOLINT
JNIEnv* env,
jobject wrapper,
jbyteArray exprArray,
@@ -439,8 +444,8 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_utils_VeloxBatchResizerJniWrapper
auto ctx = getRuntime(env, wrapper);
auto pool = dynamic_cast(ctx->memoryManager())->getLeafMemoryPool();
auto iter = makeJniColumnarBatchIterator(env, jIter, ctx);
- auto appender = std::make_shared(
- std::make_unique(pool.get(), minOutputBatchSize, maxOutputBatchSize, preferredBatchBytes, std::move(iter)));
+ auto appender = std::make_shared(std::make_unique(
+ pool.get(), minOutputBatchSize, maxOutputBatchSize, preferredBatchBytes, std::move(iter)));
return ctx->saveObject(appender);
JNI_METHOD_END(kInvalidObjectHandle)
}
@@ -583,12 +588,15 @@ Java_org_apache_gluten_datasource_VeloxDataSourceJniWrapper_splitBlockByPartitio
const auto numRows = inputRowVector->size();
connector::hive::PartitionIdGenerator idGen(
- asRowType(inputRowVector->type()), partitionColIndicesVec, 65536, pool.get()
+ asRowType(inputRowVector->type()),
+ partitionColIndicesVec,
+ 65536,
+ pool.get()
#ifdef GLUTEN_ENABLE_ENHANCED_FEATURES
- ,
+ ,
true
-#endif
- );
+#endif
+ );
raw_vector partitionIds{};
idGen.run(inputRowVector, partitionIds);
GLUTEN_CHECK(partitionIds.size() == numRows, "Mismatched number of partition ids");
@@ -914,18 +922,181 @@ JNIEXPORT jobject JNICALL Java_org_apache_gluten_execution_IcebergWriteJniWrappe
auto writer = ObjectStore::retrieve(writerHandle);
auto writeStats = writer->writeStats();
jobject writeMetrics = env->NewObject(
- batchWriteMetricsClass,
- batchWriteMetricsConstructor,
- writeStats.numWrittenBytes,
- writeStats.numWrittenFiles,
- writeStats.writeIOTimeNs,
- writeStats.writeWallNs);
+ batchWriteMetricsClass,
+ batchWriteMetricsConstructor,
+ writeStats.numWrittenBytes,
+ writeStats.numWrittenFiles,
+ writeStats.writeIOTimeNs,
+ writeStats.writeWallNs);
return writeMetrics;
JNI_METHOD_END(nullptr)
}
#endif
+JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_nativeBuild( // NOLINT
+ JNIEnv* env,
+ jclass,
+ jstring tableId,
+ jlongArray batchHandles,
+ jstring joinKey,
+ jint joinType,
+ jboolean hasMixedJoinCondition,
+ jboolean isExistenceJoin,
+ jbyteArray namedStruct,
+ jboolean isNullAwareAntiJoin,
+ jlong bloomFilterPushdownSize,
+ jint broadcastHashTableBuildThreads) {
+ JNI_METHOD_START
+ const auto hashTableId = jStringToCString(env, tableId);
+ const auto hashJoinKey = jStringToCString(env, joinKey);
+ const auto inputType = gluten::getByteArrayElementsSafe(env, namedStruct);
+ std::string structString{
+ reinterpret_cast<const char*>(inputType.elems()), static_cast<size_t>(inputType.length())};
+
+ substrait::NamedStruct substraitStruct;
+ substraitStruct.ParseFromString(structString);
+
+ std::vector<facebook::velox::TypePtr> veloxTypeList;
+ veloxTypeList = SubstraitParser::parseNamedStruct(substraitStruct);
+
+ const auto& substraitNames = substraitStruct.names();
+
+ std::vector<std::string> names;
+ names.reserve(substraitNames.size());
+ for (const auto& name : substraitNames) {
+ names.emplace_back(name);
+ }
+
+ std::vector<std::shared_ptr<ColumnarBatch>> cb;
+ int handleCount = env->GetArrayLength(batchHandles);
+ auto safeArray = getLongArrayElementsSafe(env, batchHandles);
+ for (int i = 0; i < handleCount; ++i) {
+ int64_t handle = safeArray.elems()[i];
+ cb.push_back(ObjectStore::retrieve<ColumnarBatch>(handle));
+ }
+
+ size_t maxThreads = broadcastHashTableBuildThreads > 0
+ ? std::min((size_t)broadcastHashTableBuildThreads, (size_t)32)
+ : std::min((size_t)std::thread::hardware_concurrency(), (size_t)32);
+
+ // Heuristic: Each thread should process at least a certain number of batches to justify parallelism overhead.
+ // 32 batches is roughly 128k rows, which is a reasonable granularity for a single thread.
+ constexpr size_t kMinBatchesPerThread = 32;
+ size_t numThreads = std::min(maxThreads, (handleCount + kMinBatchesPerThread - 1) / kMinBatchesPerThread);
+ numThreads = std::max((size_t)1, numThreads);
+
+ if (numThreads <= 1) {
+ auto builder = nativeHashTableBuild(
+ hashJoinKey,
+ names,
+ veloxTypeList,
+ joinType,
+ hasMixedJoinCondition,
+ isExistenceJoin,
+ isNullAwareAntiJoin,
+ bloomFilterPushdownSize,
+ cb,
+ defaultLeafVeloxMemoryPool());
+
+ auto mainTable = builder->uniqueTable();
+ mainTable->prepareJoinTable(
+ {},
+ facebook::velox::exec::BaseHashTable::kNoSpillInputStartPartitionBit,
+ 1'000'000,
+ builder->dropDuplicates(),
+ nullptr);
+ builder->setHashTable(std::move(mainTable));
+
+ return gluten::hashTableObjStore->save(builder);
+ }
+
+ std::vector<std::thread> threads;
+
+ std::vector<std::shared_ptr<HashTableBuilder>> hashTableBuilders(numThreads);
+ std::vector<std::unique_ptr<facebook::velox::exec::BaseHashTable>> otherTables(numThreads);
+
+ for (size_t t = 0; t < numThreads; ++t) {
+ size_t start = (handleCount * t) / numThreads;
+ size_t end = (handleCount * (t + 1)) / numThreads;
+
+ threads.emplace_back([&, t, start, end]() {
+ std::vector<std::shared_ptr<ColumnarBatch>> threadBatches;
+ for (size_t i = start; i < end; ++i) {
+ threadBatches.push_back(cb[i]);
+ }
+
+ auto builder = nativeHashTableBuild(
+ hashJoinKey,
+ names,
+ veloxTypeList,
+ joinType,
+ hasMixedJoinCondition,
+ isExistenceJoin,
+ isNullAwareAntiJoin,
+ bloomFilterPushdownSize,
+ threadBatches,
+ defaultLeafVeloxMemoryPool());
+
+ hashTableBuilders[t] = std::move(builder);
+ otherTables[t] = std::move(hashTableBuilders[t]->uniqueTable());
+ });
+ }
+
+ for (auto& thread : threads) {
+ thread.join();
+ }
+
+ auto mainTable = std::move(otherTables[0]);
+ std::vector<std::unique_ptr<facebook::velox::exec::BaseHashTable>> tables;
+ for (int i = 1; i < numThreads; ++i) {
+ tables.push_back(std::move(otherTables[i]));
+ }
+
+ // TODO: Get accurate signal if parallel join build is going to be applied
+ // from hash table. Currently there is still a chance inside hash table that
+ // it might decide it is not going to trigger parallel join build.
+ const bool allowParallelJoinBuild = !tables.empty();
+
+ mainTable->prepareJoinTable(
+ std::move(tables),
+ facebook::velox::exec::BaseHashTable::kNoSpillInputStartPartitionBit,
+ 1'000'000,
+ hashTableBuilders[0]->dropDuplicates(),
+ allowParallelJoinBuild ? VeloxBackend::get()->executor() : nullptr);
+
+ for (int i = 1; i < numThreads; ++i) {
+ if (hashTableBuilders[i]->joinHasNullKeys()) {
+ hashTableBuilders[0]->setJoinHasNullKeys(true);
+ break;
+ }
+ }
+
+ hashTableBuilders[0]->setHashTable(std::move(mainTable));
+ return gluten::hashTableObjStore->save(hashTableBuilders[0]);
+ JNI_METHOD_END(kInvalidObjectHandle)
+}
+
+JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_cloneHashTable( // NOLINT
+ JNIEnv* env,
+ jclass,
+ jlong tableHandler) {
+ JNI_METHOD_START
+ auto hashTableHandler = ObjectStore::retrieve<HashTableBuilder>(tableHandler);
+ return gluten::hashTableObjStore->save(hashTableHandler);
+ JNI_METHOD_END(kInvalidObjectHandle)
+}
+
+JNIEXPORT void JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_clearHashTable( // NOLINT
+ JNIEnv* env,
+ jclass,
+ jlong tableHandler) {
+ JNI_METHOD_START
+ auto hashTableHandler = ObjectStore::retrieve<HashTableBuilder>(tableHandler);
+ hashTableHandler->hashTable()->clear(true);
+ ObjectStore::release(tableHandler);
+ JNI_METHOD_END()
+}
#ifdef __cplusplus
}
#endif
diff --git a/cpp/velox/operators/hashjoin/HashTableBuilder.cc b/cpp/velox/operators/hashjoin/HashTableBuilder.cc
new file mode 100644
index 000000000000..7c42cf5b499a
--- /dev/null
+++ b/cpp/velox/operators/hashjoin/HashTableBuilder.cc
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "operators/hashjoin/HashTableBuilder.h"
+#include "velox/exec/OperatorUtils.h"
+
+namespace gluten {
+namespace {
+// Builds the row type used for the hash table layout: join-key columns first
+// (in key order), followed by the remaining (dependent) build-side columns.
+// NOTE(review): template arguments below were reconstructed from usage —
+// confirm joinKeys element type against the caller.
+facebook::velox::RowTypePtr hashJoinTableType(
+    const std::vector<facebook::velox::core::FieldAccessTypedExprPtr>& joinKeys,
+    const facebook::velox::RowTypePtr& inputType) {
+  const auto numKeys = joinKeys.size();
+
+  std::vector<std::string> names;
+  names.reserve(inputType->size());
+  std::vector<facebook::velox::TypePtr> types;
+  types.reserve(inputType->size());
+  std::unordered_set<column_index_t> keyChannelSet;
+  keyChannelSet.reserve(inputType->size());
+
+  // Key columns come first, in join-key order.
+  for (int i = 0; i < numKeys; ++i) {
+    auto& key = joinKeys[i];
+    auto channel = facebook::velox::exec::exprToChannel(key.get(), inputType);
+    keyChannelSet.insert(channel);
+    names.emplace_back(inputType->nameOf(channel));
+    types.emplace_back(inputType->childAt(channel));
+  }
+
+  // Then every non-key (dependent) column, in input order.
+  for (auto i = 0; i < inputType->size(); ++i) {
+    if (keyChannelSet.find(i) == keyChannelSet.end()) {
+      names.emplace_back(inputType->nameOf(i));
+      types.emplace_back(inputType->childAt(i));
+    }
+  }
+
+  return ROW(std::move(names), std::move(types));
+}
+
+// True when this is a null-aware anti / left-semi (project or filter) join
+// that also carries an extra join filter; such joins must retain null-key
+// rows on the build side.
+bool isLeftNullAwareJoinWithFilter(facebook::velox::core::JoinType joinType, bool nullAware, bool withFilter) {
+  if (!nullAware || !withFilter) {
+    return false;
+  }
+  return isAntiJoin(joinType) || isLeftSemiProjectJoin(joinType) || isLeftSemiFilterJoin(joinType);
+}
+} // namespace
+
+// Prepares key/dependent channel mappings and decoders for the build-side
+// input, derives the table row type, and creates the empty hash table.
+HashTableBuilder::HashTableBuilder(
+    facebook::velox::core::JoinType joinType,
+    bool nullAware,
+    bool withFilter,
+    int64_t bloomFilterPushdownSize,
+    const std::vector<facebook::velox::core::FieldAccessTypedExprPtr>& joinKeys,
+    const facebook::velox::RowTypePtr& inputType,
+    facebook::velox::memory::MemoryPool* pool)
+    : joinType_{joinType},
+      nullAware_{nullAware},
+      withFilter_(withFilter),
+      keyChannelMap_(joinKeys.size()),
+      inputType_(inputType),
+      bloomFilterPushdownSize_(bloomFilterPushdownSize),
+      pool_(pool) {
+  const auto numKeys = joinKeys.size();
+  keyChannels_.reserve(numKeys);
+
+  // Map each join key expression to its input channel. When the same build
+  // column backs multiple keys, the last key's index wins in keyChannelMap_.
+  for (int i = 0; i < numKeys; ++i) {
+    auto& key = joinKeys[i];
+    auto channel = facebook::velox::exec::exprToChannel(key.get(), inputType_);
+    keyChannelMap_[channel] = i;
+    keyChannels_.emplace_back(channel);
+  }
+
+  // Identify the non-key build side columns and make a decoder for each.
+  const int32_t numDependents = inputType_->size() - numKeys;
+  if (numDependents > 0) {
+    // Number of join keys (numKeys) may be greater than the number of input
+    // columns (inputType->size()). In that case numDependents is negative and
+    // cannot be used to call 'reserve'. This happens when we join different
+    // probe side keys with the same build side key: SELECT * FROM t LEFT JOIN
+    // u ON t.k1 = u.k AND t.k2 = u.k.
+    dependentChannels_.reserve(numDependents);
+    decoders_.reserve(numDependents);
+  }
+  for (auto i = 0; i < inputType->size(); ++i) {
+    if (keyChannelMap_.find(i) == keyChannelMap_.end()) {
+      dependentChannels_.emplace_back(i);
+      decoders_.emplace_back(std::make_unique<facebook::velox::DecodedVector>());
+    }
+  }
+
+  tableType_ = hashJoinTableType(joinKeys, inputType);
+  setupTable();
+}
+
+// Invoked once from the constructor to create the empty hash table whose
+// null-key handling, duplicate handling, and probed-flag tracking match the
+// join type. HashTable<false> keeps null keys; HashTable<true> ignores them.
+void HashTableBuilder::setupTable() {
+  VELOX_CHECK_NULL(uniqueTable_);
+
+  const auto numKeys = keyChannels_.size();
+  std::vector<std::unique_ptr<facebook::velox::exec::VectorHasher>> keyHashers;
+  keyHashers.reserve(numKeys);
+  for (vector_size_t i = 0; i < numKeys; ++i) {
+    keyHashers.emplace_back(facebook::velox::exec::VectorHasher::create(tableType_->childAt(i), keyChannels_[i]));
+  }
+
+  // Dependent (non-key) columns follow the keys in tableType_.
+  const auto numDependents = tableType_->size() - numKeys;
+  std::vector<facebook::velox::TypePtr> dependentTypes;
+  dependentTypes.reserve(numDependents);
+  for (int i = numKeys; i < tableType_->size(); ++i) {
+    dependentTypes.emplace_back(tableType_->childAt(i));
+  }
+  // NOTE(review): the trailing argument after 'pool_' in createForJoin below
+  // appears to be a Gluten-specific extension of the Velox API — confirm its
+  // meaning against the patched Velox headers.
+  if (isRightJoin(joinType_) || isFullJoin(joinType_) || isRightSemiProjectJoin(joinType_)) {
+    // Do not ignore null keys.
+    uniqueTable_ = facebook::velox::exec::HashTable<false>::createForJoin(
+        std::move(keyHashers),
+        dependentTypes,
+        true, // allowDuplicates
+        true, // hasProbedFlag
+        1'000, // operatorCtx_->driverCtx()->queryConfig().minTableRowsForParallelJoinBuild()
+        pool_,
+        true);
+  } else {
+    // (Left) semi and anti join with no extra filter only needs to know whether
+    // there is a match. Hence, no need to store entries with duplicate keys.
+    dropDuplicates_ =
+        !withFilter_ && (isLeftSemiFilterJoin(joinType_) || isLeftSemiProjectJoin(joinType_) || isAntiJoin(joinType_));
+    // Right semi join needs to tag build rows that were probed.
+    const bool needProbedFlag = isRightSemiFilterJoin(joinType_);
+    if (isLeftNullAwareJoinWithFilter(joinType_, nullAware_, withFilter_)) {
+      // We need to check null key rows in build side in case of null-aware anti
+      // or left semi project join with filter set.
+      uniqueTable_ = facebook::velox::exec::HashTable<false>::createForJoin(
+          std::move(keyHashers),
+          dependentTypes,
+          !dropDuplicates_, // allowDuplicates
+          needProbedFlag, // hasProbedFlag
+          1'000, // operatorCtx_->driverCtx()->queryConfig().minTableRowsForParallelJoinBuild()
+          pool_,
+          true);
+    } else {
+      // Ignore null keys
+      uniqueTable_ = facebook::velox::exec::HashTable<true>::createForJoin(
+          std::move(keyHashers),
+          dependentTypes,
+          !dropDuplicates_, // allowDuplicates
+          needProbedFlag, // hasProbedFlag
+          1'000, // operatorCtx_->driverCtx()->queryConfig().minTableRowsForParallelJoinBuild()
+          pool_,
+          bloomFilterPushdownSize_);
+    }
+  }
+  // Keep analyzing keys only while the table may still use array / normalized
+  // key modes.
+  analyzeKeys_ = uniqueTable_->hashMode() != facebook::velox::exec::BaseHashTable::HashMode::kHash;
+}
+
+// Decodes one build-side batch, tracks null join keys per the join type, and
+// appends the selected rows to the table's RowContainer.
+void HashTableBuilder::addInput(facebook::velox::RowVectorPtr input) {
+  activeRows_.resize(input->size());
+  activeRows_.setAll();
+
+  auto& hashers = uniqueTable_->hashers();
+
+  // Decode every key column over the full row set.
+  for (auto i = 0; i < hashers.size(); ++i) {
+    auto key = input->childAt(hashers[i]->channel())->loadedVector();
+    hashers[i]->decode(*key, activeRows_);
+  }
+
+  // Removed a redundant deselectRowsWithNulls()/setAll() pair here: setAll()
+  // immediately undid the deselection, so the pair was dead work.
+  if (!isRightJoin(joinType_) && !isFullJoin(joinType_) && !isRightSemiProjectJoin(joinType_) &&
+      !isLeftNullAwareJoinWithFilter(joinType_, nullAware_, withFilter_)) {
+    // These join types drop null-key rows from the table.
+    deselectRowsWithNulls(hashers, activeRows_);
+    if (nullAware_ && !joinHasNullKeys_ && activeRows_.countSelected() < input->size()) {
+      joinHasNullKeys_ = true;
+    }
+  } else if (nullAware_ && !joinHasNullKeys_) {
+    // Null-key rows are kept, but we still must record whether any exist.
+    for (auto& hasher : hashers) {
+      auto& decoded = hasher->decodedVector();
+      if (decoded.mayHaveNulls()) {
+        auto* nulls = decoded.nulls(&activeRows_);
+        if (nulls && facebook::velox::bits::countNulls(nulls, 0, activeRows_.end()) > 0) {
+          joinHasNullKeys_ = true;
+          break;
+        }
+      }
+    }
+  }
+
+  for (auto i = 0; i < dependentChannels_.size(); ++i) {
+    decoders_[i]->decode(*input->childAt(dependentChannels_[i])->loadedVector(), activeRows_);
+  }
+
+  if (!activeRows_.hasSelections()) {
+    return;
+  }
+
+  if (analyzeKeys_ && hashes_.size() < activeRows_.end()) {
+    hashes_.resize(activeRows_.end());
+  }
+
+  // As long as analyzeKeys is true, we keep running the keys through
+  // the Vectorhashers so that we get a possible mapping of the keys
+  // to small ints for array or normalized key. When mayUseValueIds is
+  // false for the first time we stop. We do not retain the value ids
+  // since the final ones will only be known after all data is
+  // received.
+  for (auto& hasher : hashers) {
+    // TODO: Load only for active rows, except if right/full outer join.
+    if (analyzeKeys_) {
+      hasher->computeValueIds(activeRows_, hashes_);
+      analyzeKeys_ = hasher->mayUseValueIds();
+    }
+  }
+  auto rows = uniqueTable_->rows();
+  auto nextOffset = rows->nextOffset();
+
+  activeRows_.applyToSelected([&](auto rowIndex) {
+    char* newRow = rows->newRow();
+    if (nextOffset) {
+      // Clear the duplicate-row link slot for the new row.
+      *reinterpret_cast<char**>(newRow + nextOffset) = nullptr;
+    }
+    // Store the columns for each row in sequence. At probe time
+    // strings of the row will probably be in consecutive places, so
+    // reading one will prime the cache for the next.
+    for (auto i = 0; i < hashers.size(); ++i) {
+      rows->store(hashers[i]->decodedVector(), rowIndex, newRow, i);
+    }
+    for (auto i = 0; i < dependentChannels_.size(); ++i) {
+      rows->store(*decoders_[i], rowIndex, newRow, i + hashers.size());
+    }
+  });
+}
+
+} // namespace gluten
diff --git a/cpp/velox/operators/hashjoin/HashTableBuilder.h b/cpp/velox/operators/hashjoin/HashTableBuilder.h
new file mode 100644
index 000000000000..83c90b411009
--- /dev/null
+++ b/cpp/velox/operators/hashjoin/HashTableBuilder.h
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include "velox/exec/HashJoinBridge.h"
+#include "velox/exec/HashTable.h"
+#include "velox/exec/RowContainer.h"
+#include "velox/exec/VectorHasher.h"
+
+namespace gluten {
+using column_index_t = uint32_t;
+using vector_size_t = int32_t;
+
+/// Accumulates build-side input batches into a velox::exec::BaseHashTable for
+/// broadcast hash join reuse. Not thread-safe; one builder per build thread.
+class HashTableBuilder {
+ public:
+  HashTableBuilder(
+      facebook::velox::core::JoinType joinType,
+      bool nullAware,
+      bool withFilter,
+      int64_t bloomFilterPushdownSize,
+      const std::vector<facebook::velox::core::FieldAccessTypedExprPtr>& joinKeys,
+      const facebook::velox::RowTypePtr& inputType,
+      facebook::velox::memory::MemoryPool* pool);
+
+  /// Decodes 'input' and appends its rows to the table under construction.
+  void addInput(facebook::velox::RowVectorPtr input);
+
+  /// Installs the finished table; ownership becomes shared via hashTable().
+  void setHashTable(std::unique_ptr<facebook::velox::exec::BaseHashTable> uniqueHashTable) {
+    table_ = std::move(uniqueHashTable);
+  }
+
+  /// Transfers ownership of the table being built; leaves uniqueTable_ null.
+  std::unique_ptr<facebook::velox::exec::BaseHashTable> uniqueTable() {
+    return std::move(uniqueTable_);
+  }
+
+  std::shared_ptr<facebook::velox::exec::BaseHashTable> hashTable() {
+    return table_;
+  }
+
+  void setJoinHasNullKeys(bool joinHasNullKeys) {
+    joinHasNullKeys_ = joinHasNullKeys;
+  }
+
+  bool joinHasNullKeys() {
+    return joinHasNullKeys_;
+  }
+
+  /// True when duplicate-key rows are not stored (semi/anti join, no filter).
+  bool dropDuplicates() {
+    return dropDuplicates_;
+  }
+
+ private:
+  // Invoked to set up hash table to build.
+  void setupTable();
+
+  const facebook::velox::core::JoinType joinType_;
+
+  const bool nullAware_;
+  const bool withFilter_;
+
+  // The row type used for hash table build and disk spilling.
+  facebook::velox::RowTypePtr tableType_;
+
+  // Finished table, shared with consumers after setHashTable().
+  std::shared_ptr<facebook::velox::exec::BaseHashTable> table_;
+
+  // Table under construction, consumed via uniqueTable().
+  std::unique_ptr<facebook::velox::exec::BaseHashTable> uniqueTable_;
+
+  // Key channels in 'input_'
+  std::vector<column_index_t> keyChannels_;
+
+  // Non-key channels in 'input_'.
+  std::vector<column_index_t> dependentChannels_;
+
+  // Corresponds 1:1 to 'dependentChannels_'.
+  std::vector<std::unique_ptr<facebook::velox::DecodedVector>> decoders_;
+
+  // True if we are considering use of normalized keys or array hash tables.
+  // Set to false when the dataset is no longer suitable.
+  bool analyzeKeys_;
+
+  // Temporary space for hash numbers.
+  facebook::velox::raw_vector<uint64_t> hashes_;
+
+  // Set of active rows during addInput().
+  facebook::velox::SelectivityVector activeRows_;
+
+  // True if this is a build side of an anti or left semi project join and has
+  // at least one entry with null join keys.
+  bool joinHasNullKeys_{false};
+
+  // Indices of key columns used by the filter in build side table.
+  // NOTE(review): not referenced in this class — confirm before removing.
+  std::vector<column_index_t> keyFilterChannels_;
+  // Indices of dependent columns used by the filter in 'decoders_'.
+  // NOTE(review): not referenced in this class — confirm before removing.
+  std::vector<column_index_t> dependentFilterChannels_;
+
+  // Maps key channel in 'input_' to channel in key.
+  folly::F14FastMap<column_index_t, column_index_t> keyChannelMap_;
+
+  // Held by value (shared_ptr copy). Storing a const reference here dangles
+  // whenever the caller's RowTypePtr outlives the constructor call only.
+  const facebook::velox::RowTypePtr inputType_;
+
+  int64_t bloomFilterPushdownSize_;
+
+  facebook::velox::memory::MemoryPool* pool_;
+
+  bool dropDuplicates_{false};
+};
+
+} // namespace gluten
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
index d71ab12528dd..834127e20cc1 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
@@ -19,12 +19,15 @@
#include "TypeUtils.h"
#include "VariantToVectorConverter.h"
+#include "jni/JniHashTable.h"
+#include "operators/hashjoin/HashTableBuilder.h"
#include "operators/plannodes/RowVectorStream.h"
#include "velox/connectors/hive/HiveDataSink.h"
#include "velox/exec/TableWriter.h"
#include "velox/type/Type.h"
#include "utils/ConfigExtractor.h"
+#include "utils/ObjectStore.h"
#include "utils/VeloxWriterUtils.h"
#include "config.pb.h"
@@ -393,6 +396,43 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
rightNode,
getJoinOutputType(leftNode, rightNode, joinType));
+ } else if (
+ sJoin.has_advanced_extension() &&
+ SubstraitParser::configSetInOptimization(sJoin.advanced_extension(), "isBHJ=")) {
+ std::string hashTableId = sJoin.hashtableid();
+
+ std::shared_ptr opaqueSharedHashTable = nullptr;
+ bool joinHasNullKeys = false;
+
+ try {
+ auto hashTableBuilder = ObjectStore::retrieve(getJoin(hashTableId));
+ joinHasNullKeys = hashTableBuilder->joinHasNullKeys();
+ auto originalShared = hashTableBuilder->hashTable();
+ opaqueSharedHashTable = std::shared_ptr(
+ originalShared, reinterpret_cast(originalShared.get()));
+
+ LOG(INFO) << "Successfully retrieved and aliased HashTable for reuse. ID: " << hashTableId;
+ } catch (const std::exception& e) {
+ LOG(WARNING)
+ << "Error retrieving HashTable from ObjectStore: " << e.what()
+ << ". Falling back to building new table. To ensure correct results, please verify that spark.gluten.velox.buildHashTableOncePerExecutor.enabled is set to false.";
+ opaqueSharedHashTable = nullptr;
+ }
+
+ // Create HashJoinNode node
+ return std::make_shared(
+ nextPlanNodeId(),
+ joinType,
+ isNullAwareAntiJoin,
+ leftKeys,
+ rightKeys,
+ filter,
+ leftNode,
+ rightNode,
+ getJoinOutputType(leftNode, rightNode, joinType),
+ false,
+ joinHasNullKeys,
+ opaqueSharedHashTable);
} else {
// Create HashJoinNode node
return std::make_shared(
diff --git a/docs/Configuration.md b/docs/Configuration.md
index 1372d982430c..066d66443602 100644
--- a/docs/Configuration.md
+++ b/docs/Configuration.md
@@ -79,6 +79,7 @@ nav_order: 15
| spark.gluten.sql.columnar.partial.generate | true | Evaluates the non-offload-able HiveUDTF using vanilla Spark generator |
| spark.gluten.sql.columnar.partial.project | true | Break up one project node into 2 phases when some of the expressions are non offload-able. Phase one is a regular offloaded project transformer that evaluates the offload-able expressions in native, phase two preserves the output from phase one and evaluates the remaining non-offload-able expressions using vanilla Spark projections |
| spark.gluten.sql.columnar.physicalJoinOptimizationLevel | 12 | Fallback to row operators if there are several continuous joins. |
+| spark.gluten.sql.columnar.physicalJoinOptimizationOutputSize | 52 | Fall back to row operators when there are several consecutive joins and the plan's output size matches this value. |
| spark.gluten.sql.columnar.physicalJoinOptimizeEnable | false | Enable or disable columnar physicalJoinOptimize. |
| spark.gluten.sql.columnar.preferStreamingAggregate | true | Velox backend supports `StreamingAggregate`. `StreamingAggregate` uses the less memory as it does not need to hold all groups in memory, so it could avoid spill. When true and the child output ordering satisfies the grouping key then Gluten will choose `StreamingAggregate` as the native operator. |
| spark.gluten.sql.columnar.project | true | Enable or disable columnar project. |
diff --git a/docs/velox-configuration.md b/docs/velox-configuration.md
index f4a79c465211..1a4a1fb7e65a 100644
--- a/docs/velox-configuration.md
+++ b/docs/velox-configuration.md
@@ -19,6 +19,7 @@ nav_order: 16
| spark.gluten.sql.columnar.backend.velox.bloomFilter.expectedNumItems | 1000000 | The default number of expected items for the velox bloomfilter: 'spark.bloom_filter.expected_num_items' |
| spark.gluten.sql.columnar.backend.velox.bloomFilter.maxNumBits | 4194304 | The max number of bits to use for the velox bloom filter: 'spark.bloom_filter.max_num_bits' |
| spark.gluten.sql.columnar.backend.velox.bloomFilter.numBits | 8388608 | The default number of bits to use for the velox bloom filter: 'spark.bloom_filter.num_bits' |
+| spark.gluten.sql.columnar.backend.velox.broadcastHashTableBuildThreads | 1 | The number of threads used to build the broadcast hash table. If not set or set to 0, it will use the default number of threads (available processors). |
| spark.gluten.sql.columnar.backend.velox.cacheEnabled | false | Enable Velox cache, default off. It's recommended to enablesoft-affinity as well when enable velox cache. |
| spark.gluten.sql.columnar.backend.velox.cachePrefetchMinPct | 0 | Set prefetch cache min pct for velox file scan |
| spark.gluten.sql.columnar.backend.velox.checkUsageLeak | true | Enable check memory usage leak. |
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenJoinKeysCapture.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenJoinKeysCapture.scala
new file mode 100644
index 000000000000..5d1cb8d90a8c
--- /dev/null
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenJoinKeysCapture.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension
+
+import org.apache.spark.sql.catalyst.planning.{ExtractEquiJoinKeys, ExtractSingleColumnNullAwareAntiJoin}
+import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan}
+import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy}
+
+/**
+ * Strategy that captures join keys from the logical plan before Spark's JoinSelection transforms
+ * them (e.g. via rewriteKeyExpr), tagging each equi-join child with its original keys. It always
+ * returns Nil so planning falls through to the regular strategies.
+ */
+case class GlutenJoinKeysCapture() extends SparkStrategy {
+
+  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+    // Both extractors only match Join nodes, so no explicit type test is needed.
+    case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, left, right, _) =>
+      if (leftKeys.nonEmpty) {
+        left.setTagValue(JoinKeysTag.ORIGINAL_JOIN_KEYS, leftKeys)
+      }
+      if (rightKeys.nonEmpty) {
+        right.setTagValue(JoinKeysTag.ORIGINAL_JOIN_KEYS, rightKeys)
+      }
+      Nil
+
+    case j @ ExtractSingleColumnNullAwareAntiJoin(leftKeys, rightKeys) =>
+      if (leftKeys.nonEmpty) {
+        j.left.setTagValue(JoinKeysTag.ORIGINAL_JOIN_KEYS, leftKeys)
+      }
+      if (rightKeys.nonEmpty) {
+        j.right.setTagValue(JoinKeysTag.ORIGINAL_JOIN_KEYS, rightKeys)
+      }
+      Nil
+
+    // Non-join nodes and non-equi joins: nothing to capture.
+    case _ => Nil
+  }
+}
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/JoinKeysTag.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/JoinKeysTag.scala
new file mode 100644
index 000000000000..646b0df7d0ed
--- /dev/null
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/JoinKeysTag.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
+
+/** TreeNodeTag for storing original join keys before Spark's transformations. */
+object JoinKeysTag {
+
+  /**
+   * Tag to store original join keys on logical plan nodes. Set by `GlutenJoinKeysCapture` on the
+   * children of equi-joins before JoinSelection rewrites the keys; consumers read it back via
+   * `getTagValue` — confirm the reader side outside this file.
+   */
+  val ORIGINAL_JOIN_KEYS: TreeNodeTag[Seq[Expression]] =
+    TreeNodeTag[Seq[Expression]]("gluten.originalJoinKeys")
+}
diff --git a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/JoinRelNode.java b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/JoinRelNode.java
index 714340cdf670..2bd98500feef 100644
--- a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/JoinRelNode.java
+++ b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/JoinRelNode.java
@@ -32,6 +32,7 @@ public class JoinRelNode implements RelNode, Serializable {
private final ExpressionNode expression;
private final ExpressionNode postJoinFilter;
private final AdvancedExtensionNode extensionNode;
+ private final String hashTableId;
JoinRelNode(
RelNode left,
@@ -39,12 +40,14 @@ public class JoinRelNode implements RelNode, Serializable {
JoinRel.JoinType joinType,
ExpressionNode expression,
ExpressionNode postJoinFilter,
+ String hashTableId,
AdvancedExtensionNode extensionNode) {
this.left = left;
this.right = right;
this.joinType = joinType;
this.expression = expression;
this.postJoinFilter = postJoinFilter;
+ this.hashTableId = hashTableId;
this.extensionNode = extensionNode;
}
@@ -72,6 +75,8 @@ public Rel toProtobuf() {
joinBuilder.setAdvancedExtension(extensionNode.toProtobuf());
}
+ joinBuilder.setHashTableId(hashTableId);
+
return Rel.newBuilder().setJoin(joinBuilder.build()).build();
}
}
diff --git a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java
index 20ca9d36f1e0..40723946241d 100644
--- a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java
+++ b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java
@@ -184,11 +184,12 @@ public static RelNode makeJoinRel(
JoinRel.JoinType joinType,
ExpressionNode expression,
ExpressionNode postJoinFilter,
+ String hashTableId,
SubstraitContext context,
Long operatorId) {
context.registerRelToOperator(operatorId);
return makeJoinRel(
- left, right, joinType, expression, postJoinFilter, null, context, operatorId);
+ left, right, joinType, expression, postJoinFilter, null, hashTableId, context, operatorId);
}
public static RelNode makeJoinRel(
@@ -198,10 +199,12 @@ public static RelNode makeJoinRel(
ExpressionNode expression,
ExpressionNode postJoinFilter,
AdvancedExtensionNode extensionNode,
+ String hashTableId,
SubstraitContext context,
Long operatorId) {
context.registerRelToOperator(operatorId);
- return new JoinRelNode(left, right, joinType, expression, postJoinFilter, extensionNode);
+ return new JoinRelNode(
+ left, right, joinType, expression, postJoinFilter, hashTableId, extensionNode);
}
public static RelNode makeCrossRel(
diff --git a/gluten-substrait/src/main/resources/substrait/proto/substrait/algebra.proto b/gluten-substrait/src/main/resources/substrait/proto/substrait/algebra.proto
index 7d72332baa88..2bfb68e09790 100644
--- a/gluten-substrait/src/main/resources/substrait/proto/substrait/algebra.proto
+++ b/gluten-substrait/src/main/resources/substrait/proto/substrait/algebra.proto
@@ -258,6 +258,8 @@ message JoinRel {
JoinType type = 6;
+ string hashTableId = 7;
+
enum JoinType {
JOIN_TYPE_UNSPECIFIED = 0;
JOIN_TYPE_INNER = 1;
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
index 671a29709e9a..8dd3156099e9 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
@@ -85,6 +85,8 @@ trait BackendSettingsApi {
def enableJoinKeysRewrite(): Boolean = true
+ def enableHashTableBuildOncePerExecutor(): Boolean = true
+
def supportHashBuildJoinTypeOnLeft: JoinType => Boolean = {
case _: InnerLike | RightOuter | FullOuter => true
case _ => false
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/config/GlutenConfig.scala b/gluten-substrait/src/main/scala/org/apache/gluten/config/GlutenConfig.scala
index ed2d54936655..8d71e15964ea 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/config/GlutenConfig.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/config/GlutenConfig.scala
@@ -202,6 +202,9 @@ class GlutenConfig(conf: SQLConf) extends GlutenCoreConfig(conf) {
def physicalJoinOptimizationThrottle: Integer =
getConf(COLUMNAR_PHYSICAL_JOIN_OPTIMIZATION_THROTTLE)
+ def physicalJoinOptimizationOutputSize: Integer =
+ getConf(COLUMNAR_PHYSICAL_JOIN_OPTIMIZATION_OUTPUT_SIZE)
+
def enablePhysicalJoinOptimize: Boolean =
getConf(COLUMNAR_PHYSICAL_JOIN_OPTIMIZATION_ENABLED)
@@ -998,6 +1001,13 @@ object GlutenConfig extends ConfigRegistry {
.intConf
.createWithDefault(12)
+ val COLUMNAR_PHYSICAL_JOIN_OPTIMIZATION_OUTPUT_SIZE =
+ buildConf("spark.gluten.sql.columnar.physicalJoinOptimizationOutputSize")
+ .doc(
+ "Fallback to row operators if there are several continuous joins and matched output size.")
+ .intConf
+ .createWithDefault(52)
+
val COLUMNAR_PHYSICAL_JOIN_OPTIMIZATION_ENABLED =
buildConf("spark.gluten.sql.columnar.physicalJoinOptimizeEnable")
.doc("Enable or disable columnar physicalJoinOptimize.")
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala
index e5db3385154d..b4fa188f44e6 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala
@@ -186,9 +186,14 @@ trait HashJoinLikeExecTransformer extends BaseJoinExec with TransformSupport {
// https://issues.apache.org/jira/browse/SPARK-31869
private def expandPartitioning(partitioning: Partitioning): Partitioning = {
val expandLimit = conf.broadcastHashJoinOutputPartitioningExpandLimit
+ val (buildKeys, streamedKeys) = if (needSwitchChildren) {
+ (leftKeys, rightKeys)
+ } else {
+ (rightKeys, leftKeys)
+ }
joinType match {
case _: InnerLike if expandLimit > 0 =>
- new ExpandOutputPartitioningShim(streamedKeyExprs, buildKeyExprs, expandLimit)
+ new ExpandOutputPartitioningShim(streamedKeys, buildKeys, expandLimit)
.expandPartitioning(partitioning)
case _ => partitioning
}
@@ -262,7 +267,8 @@ trait HashJoinLikeExecTransformer extends BaseJoinExec with TransformSupport {
inputStreamedOutput,
inputBuildOutput,
context,
- operatorId
+ operatorId,
+ buildPlan.id.toString
)
context.registerJoinParam(operatorId, joinParams)
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala
index a7a31cf471c5..eeb60698902b 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala
@@ -184,6 +184,7 @@ object JoinUtils {
inputBuildOutput: Seq[Attribute],
substraitContext: SubstraitContext,
operatorId: java.lang.Long,
+ hashTableId: String = "",
validation: Boolean = false): RelNode = {
// scalastyle:on argcount
// Create pre-projection for build/streamed plan. Append projected keys to each side.
@@ -233,6 +234,7 @@ object JoinUtils {
joinExpressionNode,
postJoinFilter.orNull,
createJoinExtensionNode(joinParameters, streamedOutput ++ buildOutput),
+ hashTableId,
substraitContext,
operatorId
)
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
index 926708ee334a..5e6c77792289 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
@@ -38,17 +38,32 @@ case class FallbackMultiCodegens(session: SparkSession) extends Rule[SparkPlan]
lazy val glutenConf: GlutenConfig = GlutenConfig.get
lazy val physicalJoinOptimize = glutenConf.enablePhysicalJoinOptimize
lazy val optimizeLevel: Integer = glutenConf.physicalJoinOptimizationThrottle
+ lazy val outputSize: Integer = glutenConf.physicalJoinOptimizationOutputSize
def existsMultiCodegens(plan: SparkPlan, count: Int = 0): Boolean =
plan match {
case plan: CodegenSupport if plan.supportCodegen =>
- if ((count + 1) >= optimizeLevel) return true
+ if (
+ (count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum == outputSize
+ ) {
+ return true
+ }
plan.children.exists(existsMultiCodegens(_, count + 1))
case plan: ShuffledHashJoinExec =>
- if ((count + 1) >= optimizeLevel) return true
+ if (
+ (count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum == outputSize
+ ) {
+ return true
+ }
+
plan.children.exists(existsMultiCodegens(_, count + 1))
case plan: SortMergeJoinExec if GlutenConfig.get.forceShuffledHashJoin =>
- if ((count + 1) >= optimizeLevel) return true
+ if (
+ (count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum == outputSize
+ ) {
+ return true
+ }
+
plan.children.exists(existsMultiCodegens(_, count + 1))
case _ => false
}
diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala
index 1de490ad6165..371f9948b730 100644
--- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala
+++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala
@@ -131,9 +131,7 @@ case class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan)
override def rowType0(): Convention.RowType = Convention.RowType.None
override def doCanonicalize(): SparkPlan = {
- val canonicalized =
- BackendsApiManager.getSparkPlanExecApiInstance.doCanonicalizeForBroadcastMode(mode)
- ColumnarBroadcastExchangeExec(canonicalized, child.canonicalized)
+ ColumnarBroadcastExchangeExec(mode.canonicalized, child.canonicalized)
}
override def doPrepare(): Unit = {