diff --git a/backends-velox/src/main/java/org/apache/gluten/vectorized/HashJoinBuilder.java b/backends-velox/src/main/java/org/apache/gluten/vectorized/HashJoinBuilder.java index e54909054cea..ebfd47669ce8 100644 --- a/backends-velox/src/main/java/org/apache/gluten/vectorized/HashJoinBuilder.java +++ b/backends-velox/src/main/java/org/apache/gluten/vectorized/HashJoinBuilder.java @@ -42,7 +42,7 @@ public long rtHandle() { public static native long nativeBuild( String buildHashTableId, long[] batchHandlers, - String joinKeys, + String[] joinKeys, int joinType, boolean hasMixedFiltCondition, boolean isExistenceJoin, diff --git a/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala b/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala index 8127c324b79c..dec67eed7878 100644 --- a/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala +++ b/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala @@ -34,20 +34,4 @@ object GlutenRpcMessages { case class GlutenCleanExecutionResource(executionId: String, broadcastHashIds: util.Set[String]) extends GlutenRpcMessage - - // for mergetree cache - case class GlutenMergeTreeCacheLoad( - mergeTreeTable: String, - columns: util.Set[String], - onlyMetaCache: Boolean) - extends GlutenRpcMessage - - case class GlutenCacheLoadStatus(jobId: String) - - case class CacheJobInfo(status: Boolean, jobId: String, reason: String = "") - extends GlutenRpcMessage - - case class GlutenFilesCacheLoad(files: Array[Byte]) extends GlutenRpcMessage - - case class GlutenFilesCacheLoadStatus(jobId: String) } diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala index 6429f8bb3fc5..b106319e81b8 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala +++ 
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala @@ -197,20 +197,18 @@ case class ColumnarBuildSideRelation( ) } - val joinKey = keys.asScala - .map { - key => - val attr = ConverterUtils.getAttrFromExpr(key) - ConverterUtils.genColumnNameWithExprId(attr) - } - .mkString(",") + val joinKeys = keys.asScala.map { + key => + val attr = ConverterUtils.getAttrFromExpr(key) + ConverterUtils.genColumnNameWithExprId(attr) + }.toArray // Build the hash table hashTableData = HashJoinBuilder .nativeBuild( broadcastContext.buildHashTableId, batchArray.toArray, - joinKey, + joinKeys, broadcastContext.substraitJoinType.ordinal(), broadcastContext.hasMixedFiltCondition, broadcastContext.isExistenceJoin, diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala index fc7516c4b325..01fbb86bee68 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala +++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala @@ -167,20 +167,18 @@ class UnsafeColumnarBuildSideRelation( ) } - val joinKey = keys.asScala - .map { - key => - val attr = ConverterUtils.getAttrFromExpr(key) - ConverterUtils.genColumnNameWithExprId(attr) - } - .mkString(",") + val joinKeys = keys.asScala.map { + key => + val attr = ConverterUtils.getAttrFromExpr(key) + ConverterUtils.genColumnNameWithExprId(attr) + }.toArray // Build the hash table hashTableData = HashJoinBuilder .nativeBuild( broadcastContext.buildHashTableId, batchArray.toArray, - joinKey, + joinKeys, broadcastContext.substraitJoinType.ordinal(), broadcastContext.hasMixedFiltCondition, broadcastContext.isExistenceJoin, diff --git a/cpp/velox/compute/VeloxBackend.cc b/cpp/velox/compute/VeloxBackend.cc index 
0232da48da14..de9e9385f8f0 100644 --- a/cpp/velox/compute/VeloxBackend.cc +++ b/cpp/velox/compute/VeloxBackend.cc @@ -362,7 +362,6 @@ void VeloxBackend::tearDown() { filesystem->close(); } #endif - gluten::hashTableObjStore.reset(); // Destruct IOThreadPoolExecutor will join all threads. // On threads exit, thread local variables can be constructed with referencing global variables. diff --git a/cpp/velox/jni/JniHashTable.cc b/cpp/velox/jni/JniHashTable.cc index 77cd78ff6a4c..8af60a5534e7 100644 --- a/cpp/velox/jni/JniHashTable.cc +++ b/cpp/velox/jni/JniHashTable.cc @@ -29,24 +29,36 @@ namespace gluten { -static jclass jniVeloxBroadcastBuildSideCache = nullptr; -static jmethodID jniGet = nullptr; +void JniHashTableContext::initialize(JNIEnv* env, JavaVM* javaVm) { + vm_ = javaVm; + const char* classSig = "Lorg/apache/gluten/execution/VeloxBroadcastBuildSideCache;"; + jniVeloxBroadcastBuildSideCache_ = createGlobalClassReferenceOrError(env, classSig); + jniGet_ = getStaticMethodId(env, jniVeloxBroadcastBuildSideCache_, "get", "(Ljava/lang/String;)J"); +} -jlong callJavaGet(const std::string& id) { +void JniHashTableContext::finalize(JNIEnv* env) { + if (jniVeloxBroadcastBuildSideCache_ != nullptr) { + env->DeleteGlobalRef(jniVeloxBroadcastBuildSideCache_); + jniVeloxBroadcastBuildSideCache_ = nullptr; + } +} + +jlong JniHashTableContext::callJavaGet(const std::string& id) const { JNIEnv* env; - if (vm->GetEnv(reinterpret_cast(&env), jniVersion) != JNI_OK) { + if (vm_->GetEnv(reinterpret_cast(&env), jniVersion) != JNI_OK) { throw gluten::GlutenException("JNIEnv was not attached to current thread"); } const jstring s = env->NewStringUTF(id.c_str()); - - auto result = env->CallStaticLongMethod(jniVeloxBroadcastBuildSideCache, jniGet, s); + auto result = env->CallStaticLongMethod(jniVeloxBroadcastBuildSideCache_, jniGet_, s); + // Release the temporary jstring here instead of relying on an enclosing native frame to free it. + env->DeleteLocalRef(s); return result; } // Return the velox's hash table.
std::shared_ptr nativeHashTableBuild( - const std::string& joinKeys, + const std::vector& joinKeys, std::vector names, std::vector veloxTypeList, int joinType, @@ -98,12 +110,9 @@ std::shared_ptr nativeHashTableBuild( VELOX_NYI("Unsupported Join type: {}", std::to_string(sJoin)); } - std::vector joinKeyNames; - folly::split(',', joinKeys, joinKeyNames); - std::vector> joinKeyTypes; - joinKeyTypes.reserve(joinKeyNames.size()); - for (const auto& name : joinKeyNames) { + joinKeyTypes.reserve(joinKeys.size()); + for (const auto& name : joinKeys) { joinKeyTypes.emplace_back( std::make_shared(rowType->findChild(name), name)); } @@ -125,21 +134,8 @@ std::shared_ptr nativeHashTableBuild( return hashTableBuilder; } -long getJoin(std::string hashTableId) { - return callJavaGet(hashTableId); -} - -void initVeloxJniHashTable(JNIEnv* env) { - if (env->GetJavaVM(&vm) != JNI_OK) { - throw gluten::GlutenException("Unable to get JavaVM instance"); - } - const char* classSig = "Lorg/apache/gluten/execution/VeloxBroadcastBuildSideCache;"; - jniVeloxBroadcastBuildSideCache = createGlobalClassReferenceOrError(env, classSig); - jniGet = getStaticMethodId(env, jniVeloxBroadcastBuildSideCache, "get", "(Ljava/lang/String;)J"); -} - -void finalizeVeloxJniHashTable(JNIEnv* env) { - env->DeleteGlobalRef(jniVeloxBroadcastBuildSideCache); +long getJoin(const std::string& hashTableId) { + return JniHashTableContext::getInstance().callJavaGet(hashTableId); } } // namespace gluten diff --git a/cpp/velox/jni/JniHashTable.h b/cpp/velox/jni/JniHashTable.h index c0d9227840d9..27061e1778ab 100644 --- a/cpp/velox/jni/JniHashTable.h +++ b/cpp/velox/jni/JniHashTable.h @@ -26,13 +26,53 @@ namespace gluten { -inline static JavaVM* vm = nullptr; +// Wrapper class to encapsulate JNI-related static objects for hash table operations. +// This avoids exposing global variables in the gluten namespace.
+class JniHashTableContext { + public: + static JniHashTableContext& getInstance() { + static JniHashTableContext instance; + return instance; + } -inline static std::unique_ptr hashTableObjStore = ObjectStore::create(); + // Delete copy and move constructors/operators + JniHashTableContext(const JniHashTableContext&) = delete; + JniHashTableContext& operator=(const JniHashTableContext&) = delete; + JniHashTableContext(JniHashTableContext&&) = delete; + JniHashTableContext& operator=(JniHashTableContext&&) = delete; + + void initialize(JNIEnv* env, JavaVM* javaVm); + void finalize(JNIEnv* env); + + JavaVM* getJavaVM() const { + return vm_; + } + + ObjectStore* getHashTableObjStore() const { + return hashTableObjStore_.get(); + } + + jlong callJavaGet(const std::string& id) const; + + private: + JniHashTableContext() : hashTableObjStore_(ObjectStore::create()) {} + + ~JniHashTableContext() { + // Note: The destructor is called at program exit (after main() returns). + // By this time, JNI_OnUnload should have already been called, which invokes + // finalize() to clean up JNI global references while the JVM is still valid. + // The singleton itself (including hashTableObjStore_) will be destroyed here. + } + + JavaVM* vm_{nullptr}; + std::unique_ptr hashTableObjStore_; + jclass jniVeloxBroadcastBuildSideCache_{nullptr}; + jmethodID jniGet_{nullptr}; +}; // Return the hash table builder address. 
std::shared_ptr nativeHashTableBuild( - const std::string& joinKeys, + const std::vector& joinKeys, std::vector names, std::vector veloxTypeList, int joinType, @@ -43,12 +83,21 @@ std::shared_ptr nativeHashTableBuild( std::vector>& batches, std::shared_ptr memoryPool); -long getJoin(std::string hashTableId); +long getJoin(const std::string& hashTableId); -void initVeloxJniHashTable(JNIEnv* env); +// Initialize the JNI hash table context +inline void initVeloxJniHashTable(JNIEnv* env, JavaVM* javaVm) { + JniHashTableContext::getInstance().initialize(env, javaVm); +} -void finalizeVeloxJniHashTable(JNIEnv* env); +// Finalize the JNI hash table context +inline void finalizeVeloxJniHashTable(JNIEnv* env) { + JniHashTableContext::getInstance().finalize(env); +} -jlong callJavaGet(const std::string& id); +// Get hash table object store +inline ObjectStore* getHashTableObjStore() { + return JniHashTableContext::getInstance().getHashTableObjStore(); +} } // namespace gluten diff --git a/cpp/velox/jni/VeloxJniWrapper.cc b/cpp/velox/jni/VeloxJniWrapper.cc index e488274e9718..ed1cd5e85dc1 100644 --- a/cpp/velox/jni/VeloxJniWrapper.cc +++ b/cpp/velox/jni/VeloxJniWrapper.cc @@ -80,7 +80,7 @@ jint JNI_OnLoad(JavaVM* vm, void*) { getJniErrorState()->ensureInitialized(env); initVeloxJniFileSystem(env); initVeloxJniUDF(env); - initVeloxJniHashTable(env); + initVeloxJniHashTable(env, vm); infoCls = createGlobalClassReferenceOrError(env, "Lorg/apache/gluten/validate/NativePlanValidationInfo;"); infoClsInitMethod = getMethodIdOrError(env, infoCls, "", "(ILjava/lang/String;)V"); @@ -94,8 +94,6 @@ jint JNI_OnLoad(JavaVM* vm, void*) { DLOG(INFO) << "Loaded Velox backend."; - gluten::vm = vm; - return jniVersion; } @@ -108,6 +106,7 @@ void JNI_OnUnload(JavaVM* vm, void*) { finalizeVeloxJniUDF(env); finalizeVeloxJniFileSystem(env); + finalizeVeloxJniHashTable(env); getJniErrorState()->close(); getJniCommonState()->close(); google::ShutdownGoogleLogging(); @@ -939,7 +938,7 @@ JNIEXPORT 
jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native jclass, jstring tableId, jlongArray batchHandles, - jstring joinKey, + jobjectArray joinKeys, jint joinType, jboolean hasMixedJoinCondition, jboolean isExistenceJoin, @@ -949,7 +948,18 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native jint broadcastHashTableBuildThreads) { JNI_METHOD_START const auto hashTableId = jStringToCString(env, tableId); - const auto hashJoinKey = jStringToCString(env, joinKey); + + // Convert Java String array to C++ vector + std::vector hashJoinKeys; + jsize joinKeysCount = env->GetArrayLength(joinKeys); + hashJoinKeys.reserve(joinKeysCount); + for (jsize i = 0; i < joinKeysCount; ++i) { + jstring jkey = (jstring)env->GetObjectArrayElement(joinKeys, i); + hashJoinKeys.emplace_back(jStringToCString(env, jkey)); + // Each element fetched is a new local reference; release it per iteration so the count stays bounded. + env->DeleteLocalRef(jkey); + } + const auto inputType = gluten::getByteArrayElementsSafe(env, namedStruct); std::string structString{ reinterpret_cast(inputType.elems()), static_cast(inputType.length())}; @@ -988,7 +998,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native if (numThreads <= 1) { auto builder = nativeHashTableBuild( - hashJoinKey, + hashJoinKeys, names, veloxTypeList, joinType, @@ -1008,7 +1018,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native nullptr); builder->setHashTable(std::move(mainTable)); - return gluten::hashTableObjStore->save(builder); + return gluten::getHashTableObjStore()->save(builder); } std::vector threads; @@ -1027,7 +1037,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native } auto builder = nativeHashTableBuild( - hashJoinKey, + hashJoinKeys, names, veloxTypeList, joinType, @@ -1073,7 +1083,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native } hashTableBuilders[0]->setHashTable(std::move(mainTable)); - return gluten::hashTableObjStore->save(hashTableBuilders[0]); + return
gluten::getHashTableObjStore()->save(hashTableBuilders[0]); JNI_METHOD_END(kInvalidObjectHandle) } @@ -1083,7 +1093,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_cloneH jlong tableHandler) { JNI_METHOD_START auto hashTableHandler = ObjectStore::retrieve(tableHandler); - return gluten::hashTableObjStore->save(hashTableHandler); + return gluten::getHashTableObjStore()->save(hashTableHandler); JNI_METHOD_END(kInvalidObjectHandle) } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala index 5e6c77792289..76d8a50ccd76 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala @@ -44,14 +44,14 @@ case class FallbackMultiCodegens(session: SparkSession) extends Rule[SparkPlan] plan match { case plan: CodegenSupport if plan.supportCodegen => if ( - (count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum == outputSize + (count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum >= outputSize ) { return true } plan.children.exists(existsMultiCodegens(_, count + 1)) case plan: ShuffledHashJoinExec => if ( - (count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum == outputSize + (count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum >= outputSize ) { return true } @@ -59,7 +59,7 @@ case class FallbackMultiCodegens(session: SparkSession) extends Rule[SparkPlan] plan.children.exists(existsMultiCodegens(_, count + 1)) case plan: SortMergeJoinExec if GlutenConfig.get.forceShuffledHashJoin => if ( - (count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum == outputSize + (count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum >= outputSize ) { return true }