From 4138a6e93487d5752115c72333d1d6792b682628 Mon Sep 17 00:00:00 2001 From: Ke Jia Date: Mon, 9 Mar 2026 05:00:17 -0700 Subject: [PATCH 1/2] Resolve comments --- .../gluten/vectorized/HashJoinBuilder.java | 2 +- .../execution/ColumnarBuildSideRelation.scala | 14 +++++------ .../UnsafeColumnarBuildSideRelation.scala | 14 +++++------ cpp/velox/jni/JniHashTable.cc | 19 +++++++-------- cpp/velox/jni/JniHashTable.h | 8 +++---- cpp/velox/jni/VeloxJniWrapper.cc | 24 +++++++++++++------ .../extension/columnar/FallbackRules.scala | 6 ++--- 7 files changed, 45 insertions(+), 42 deletions(-) diff --git a/backends-velox/src/main/java/org/apache/gluten/vectorized/HashJoinBuilder.java b/backends-velox/src/main/java/org/apache/gluten/vectorized/HashJoinBuilder.java index e54909054cea..ebfd47669ce8 100644 --- a/backends-velox/src/main/java/org/apache/gluten/vectorized/HashJoinBuilder.java +++ b/backends-velox/src/main/java/org/apache/gluten/vectorized/HashJoinBuilder.java @@ -42,7 +42,7 @@ public long rtHandle() { public static native long nativeBuild( String buildHashTableId, long[] batchHandlers, - String joinKeys, + String[] joinKeys, int joinType, boolean hasMixedFiltCondition, boolean isExistenceJoin, diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala index 6429f8bb3fc5..b106319e81b8 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala +++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala @@ -197,20 +197,18 @@ case class ColumnarBuildSideRelation( ) } - val joinKey = keys.asScala - .map { - key => - val attr = ConverterUtils.getAttrFromExpr(key) - ConverterUtils.genColumnNameWithExprId(attr) - } - .mkString(",") + val joinKeys = keys.asScala.map { + key => + val attr = ConverterUtils.getAttrFromExpr(key) + ConverterUtils.genColumnNameWithExprId(attr) + }.toArray // Build the hash table hashTableData = HashJoinBuilder .nativeBuild( broadcastContext.buildHashTableId, batchArray.toArray, - joinKey, + joinKeys, broadcastContext.substraitJoinType.ordinal(), broadcastContext.hasMixedFiltCondition, broadcastContext.isExistenceJoin, diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala index fc7516c4b325..01fbb86bee68 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala +++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala @@ -167,20 +167,18 @@ class UnsafeColumnarBuildSideRelation( ) } - val joinKey = keys.asScala - .map { - key => - val attr = ConverterUtils.getAttrFromExpr(key) - ConverterUtils.genColumnNameWithExprId(attr) - } - .mkString(",") + val joinKeys = keys.asScala.map { + key => + val attr = ConverterUtils.getAttrFromExpr(key) + ConverterUtils.genColumnNameWithExprId(attr) + }.toArray // Build the hash table hashTableData = HashJoinBuilder .nativeBuild( broadcastContext.buildHashTableId, batchArray.toArray, - joinKey, + joinKeys, broadcastContext.substraitJoinType.ordinal(), broadcastContext.hasMixedFiltCondition, broadcastContext.isExistenceJoin, diff --git a/cpp/velox/jni/JniHashTable.cc b/cpp/velox/jni/JniHashTable.cc index 77cd78ff6a4c..f845f418a7b5 100644 --- a/cpp/velox/jni/JniHashTable.cc +++ b/cpp/velox/jni/JniHashTable.cc @@ -29,6 +29,8 @@ namespace gluten { +JavaVM* vm = nullptr; + static jclass jniVeloxBroadcastBuildSideCache = nullptr; static jmethodID jniGet = nullptr; @@ -46,7 +48,7 @@ jlong callJavaGet(const std::string& id) { // Return the velox's hash table. std::shared_ptr nativeHashTableBuild( - const std::string& joinKeys, + const std::vector& joinKeys, std::vector names, std::vector veloxTypeList, int joinType, @@ -98,12 +100,9 @@ std::shared_ptr nativeHashTableBuild( VELOX_NYI("Unsupported Join type: {}", std::to_string(sJoin)); } - std::vector joinKeyNames; - folly::split(',', joinKeys, joinKeyNames); - std::vector> joinKeyTypes; - joinKeyTypes.reserve(joinKeyNames.size()); - for (const auto& name : joinKeyNames) { + joinKeyTypes.reserve(joinKeys.size()); + for (const auto& name : joinKeys) { joinKeyTypes.emplace_back( std::make_shared(rowType->findChild(name), name)); } @@ -125,14 +124,12 @@ std::shared_ptr nativeHashTableBuild( return hashTableBuilder; } -long getJoin(std::string hashTableId) { +long getJoin(const std::string& hashTableId) { return callJavaGet(hashTableId); } -void initVeloxJniHashTable(JNIEnv* env) { - if (env->GetJavaVM(&vm) != JNI_OK) { - throw gluten::GlutenException("Unable to get JavaVM instance"); - } +void initVeloxJniHashTable(JNIEnv* env, JavaVM* javaVm) { + vm = javaVm; const char* classSig = "Lorg/apache/gluten/execution/VeloxBroadcastBuildSideCache;"; jniVeloxBroadcastBuildSideCache = createGlobalClassReferenceOrError(env, classSig); jniGet = getStaticMethodId(env, jniVeloxBroadcastBuildSideCache, "get", "(Ljava/lang/String;)J"); diff --git a/cpp/velox/jni/JniHashTable.h b/cpp/velox/jni/JniHashTable.h index c0d9227840d9..6144522bf217 100644 --- a/cpp/velox/jni/JniHashTable.h +++ b/cpp/velox/jni/JniHashTable.h @@ -26,13 +26,13 @@ namespace gluten { -inline static JavaVM* vm = nullptr; +extern JavaVM* vm; inline static std::unique_ptr hashTableObjStore = ObjectStore::create(); // Return the hash table builder address. std::shared_ptr nativeHashTableBuild( - const std::string& joinKeys, + const std::vector& joinKeys, std::vector names, std::vector veloxTypeList, int joinType, @@ -43,9 +43,9 @@ std::shared_ptr nativeHashTableBuild( std::vector>& batches, std::shared_ptr memoryPool); -long getJoin(std::string hashTableId); +long getJoin(const std::string& hashTableId); -void initVeloxJniHashTable(JNIEnv* env); +void initVeloxJniHashTable(JNIEnv* env, JavaVM* javaVm); void finalizeVeloxJniHashTable(JNIEnv* env); diff --git a/cpp/velox/jni/VeloxJniWrapper.cc b/cpp/velox/jni/VeloxJniWrapper.cc index e488274e9718..3add62dfc709 100644 --- a/cpp/velox/jni/VeloxJniWrapper.cc +++ b/cpp/velox/jni/VeloxJniWrapper.cc @@ -80,7 +80,7 @@ jint JNI_OnLoad(JavaVM* vm, void*) { getJniErrorState()->ensureInitialized(env); initVeloxJniFileSystem(env); initVeloxJniUDF(env); - initVeloxJniHashTable(env); + initVeloxJniHashTable(env, vm); infoCls = createGlobalClassReferenceOrError(env, "Lorg/apache/gluten/validate/NativePlanValidationInfo;"); infoClsInitMethod = getMethodIdOrError(env, infoCls, "", "(ILjava/lang/String;)V"); @@ -94,8 +94,6 @@ jint JNI_OnLoad(JavaVM* vm, void*) { DLOG(INFO) << "Loaded Velox backend."; - gluten::vm = vm; - return jniVersion; } @@ -939,7 +937,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native jclass, jstring tableId, jlongArray batchHandles, - jstring joinKey, + jobjectArray joinKeys, jint joinType, jboolean hasMixedJoinCondition, jboolean isExistenceJoin, @@ -949,7 +947,19 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native jint broadcastHashTableBuildThreads) { JNI_METHOD_START const auto hashTableId = jStringToCString(env, tableId); - const auto hashJoinKey = jStringToCString(env, joinKey); + + // Convert Java String array to C++ vector + std::vector hashJoinKeys; + jsize joinKeysCount = env->GetArrayLength(joinKeys); + hashJoinKeys.reserve(joinKeysCount); + for (jsize i = 0; i < joinKeysCount; ++i) { + jstring jkey = (jstring)env->GetObjectArrayElement(joinKeys, i); + const char* keyChars = env->GetStringUTFChars(jkey, nullptr); + hashJoinKeys.emplace_back(keyChars); + env->ReleaseStringUTFChars(jkey, keyChars); + env->DeleteLocalRef(jkey); + } + const auto inputType = gluten::getByteArrayElementsSafe(env, namedStruct); std::string structString{ reinterpret_cast(inputType.elems()), static_cast(inputType.length())}; @@ -988,7 +998,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native if (numThreads <= 1) { auto builder = nativeHashTableBuild( - hashJoinKey, + hashJoinKeys, names, veloxTypeList, joinType, @@ -1027,7 +1037,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native } auto builder = nativeHashTableBuild( - hashJoinKey, + hashJoinKeys, names, veloxTypeList, joinType, diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala index 5e6c77792289..76d8a50ccd76 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala @@ -44,14 +44,14 @@ case class FallbackMultiCodegens(session: SparkSession) extends Rule[SparkPlan] plan match { case plan: CodegenSupport if plan.supportCodegen => if ( - (count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum == outputSize + (count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum >= outputSize ) { return true } plan.children.exists(existsMultiCodegens(_, count + 1)) case plan: ShuffledHashJoinExec => if ( - (count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum == outputSize + (count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum >= outputSize ) { return true } @@ -59,7 +59,7 @@ case class FallbackMultiCodegens(session: SparkSession) extends Rule[SparkPlan] plan.children.exists(existsMultiCodegens(_, count + 1)) case plan: SortMergeJoinExec if GlutenConfig.get.forceShuffledHashJoin => if ( - (count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum == outputSize + (count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum >= outputSize ) { return true } From 6fd8a2fe607973932a720370e022f08787eb5270 Mon Sep 17 00:00:00 2001 From: Ke Jia Date: Tue, 10 Mar 2026 03:49:05 -0700 Subject: [PATCH 2/2] Resolve comments --- .../apache/spark/rpc/GlutenRpcMessages.scala | 16 ----- cpp/velox/compute/VeloxBackend.cc | 1 - cpp/velox/jni/JniHashTable.cc | 35 +++++------ cpp/velox/jni/JniHashTable.h | 59 +++++++++++++++++-- cpp/velox/jni/VeloxJniWrapper.cc | 12 ++-- 5 files changed, 75 insertions(+), 48 deletions(-) diff --git a/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala b/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala index 8127c324b79c..dec67eed7878 100644 --- a/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala +++ b/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala @@ -34,20 +34,4 @@ object GlutenRpcMessages { case class GlutenCleanExecutionResource(executionId: String, broadcastHashIds: util.Set[String]) extends GlutenRpcMessage - - // for mergetree cache - case class GlutenMergeTreeCacheLoad( - mergeTreeTable: String, - columns: util.Set[String], - onlyMetaCache: Boolean) - extends GlutenRpcMessage - - case class GlutenCacheLoadStatus(jobId: String) - - case class CacheJobInfo(status: Boolean, jobId: String, reason: String = "") - extends GlutenRpcMessage - - case class GlutenFilesCacheLoad(files: Array[Byte]) extends GlutenRpcMessage - - case class GlutenFilesCacheLoadStatus(jobId: String) } diff --git a/cpp/velox/compute/VeloxBackend.cc b/cpp/velox/compute/VeloxBackend.cc index 0232da48da14..de9e9385f8f0 100644 --- a/cpp/velox/compute/VeloxBackend.cc +++ b/cpp/velox/compute/VeloxBackend.cc @@ -362,7 +362,6 @@ void VeloxBackend::tearDown() { filesystem->close(); } #endif - gluten::hashTableObjStore.reset(); // Destruct IOThreadPoolExecutor will join all threads. // On threads exit, thread local variables can be constructed with referencing global variables. diff --git a/cpp/velox/jni/JniHashTable.cc b/cpp/velox/jni/JniHashTable.cc index f845f418a7b5..8af60a5534e7 100644 --- a/cpp/velox/jni/JniHashTable.cc +++ b/cpp/velox/jni/JniHashTable.cc @@ -29,20 +29,28 @@ namespace gluten { -JavaVM* vm = nullptr; +void JniHashTableContext::initialize(JNIEnv* env, JavaVM* javaVm) { + vm_ = javaVm; + const char* classSig = "Lorg/apache/gluten/execution/VeloxBroadcastBuildSideCache;"; + jniVeloxBroadcastBuildSideCache_ = createGlobalClassReferenceOrError(env, classSig); + jniGet_ = getStaticMethodId(env, jniVeloxBroadcastBuildSideCache_, "get", "(Ljava/lang/String;)J"); +} -static jclass jniVeloxBroadcastBuildSideCache = nullptr; -static jmethodID jniGet = nullptr; +void JniHashTableContext::finalize(JNIEnv* env) { + if (jniVeloxBroadcastBuildSideCache_ != nullptr) { + env->DeleteGlobalRef(jniVeloxBroadcastBuildSideCache_); + jniVeloxBroadcastBuildSideCache_ = nullptr; + } +} -jlong callJavaGet(const std::string& id) { +jlong JniHashTableContext::callJavaGet(const std::string& id) const { JNIEnv* env; - if (vm->GetEnv(reinterpret_cast(&env), jniVersion) != JNI_OK) { + if (vm_->GetEnv(reinterpret_cast(&env), jniVersion) != JNI_OK) { throw gluten::GlutenException("JNIEnv was not attached to current thread"); } const jstring s = env->NewStringUTF(id.c_str()); - - auto result = env->CallStaticLongMethod(jniVeloxBroadcastBuildSideCache, jniGet, s); + auto result = env->CallStaticLongMethod(jniVeloxBroadcastBuildSideCache_, jniGet_, s); return result; } @@ -125,18 +133,7 @@ std::shared_ptr nativeHashTableBuild( } long getJoin(const std::string& hashTableId) { - return callJavaGet(hashTableId); -} - -void initVeloxJniHashTable(JNIEnv* env, JavaVM* javaVm) { - vm = javaVm; - const char* classSig = "Lorg/apache/gluten/execution/VeloxBroadcastBuildSideCache;"; - jniVeloxBroadcastBuildSideCache = createGlobalClassReferenceOrError(env, classSig); - jniGet = getStaticMethodId(env, jniVeloxBroadcastBuildSideCache, "get", "(Ljava/lang/String;)J"); -} - -void finalizeVeloxJniHashTable(JNIEnv* env) { - env->DeleteGlobalRef(jniVeloxBroadcastBuildSideCache); + return JniHashTableContext::getInstance().callJavaGet(hashTableId); } } // namespace gluten diff --git a/cpp/velox/jni/JniHashTable.h b/cpp/velox/jni/JniHashTable.h index 6144522bf217..27061e1778ab 100644 --- a/cpp/velox/jni/JniHashTable.h +++ b/cpp/velox/jni/JniHashTable.h @@ -26,9 +26,49 @@ namespace gluten { -extern JavaVM* vm; +// Wrapper class to encapsulate JNI-related static objects for hash table operations. +// This avoids exposing global variables in the gluten namespace. +class JniHashTableContext { + public: + static JniHashTableContext& getInstance() { + static JniHashTableContext instance; + return instance; + } -inline static std::unique_ptr hashTableObjStore = ObjectStore::create(); + // Delete copy and move constructors/operators + JniHashTableContext(const JniHashTableContext&) = delete; + JniHashTableContext& operator=(const JniHashTableContext&) = delete; + JniHashTableContext(JniHashTableContext&&) = delete; + JniHashTableContext& operator=(JniHashTableContext&&) = delete; + + void initialize(JNIEnv* env, JavaVM* javaVm); + void finalize(JNIEnv* env); + + JavaVM* getJavaVM() const { + return vm_; + } + + ObjectStore* getHashTableObjStore() const { + return hashTableObjStore_.get(); + } + + jlong callJavaGet(const std::string& id) const; + + private: + JniHashTableContext() : hashTableObjStore_(ObjectStore::create()) {} + + ~JniHashTableContext() { + // Note: The destructor is called at program exit (after main() returns). + // By this time, JNI_OnUnload should have already been called, which invokes + // finalize() to clean up JNI global references while the JVM is still valid. + // The singleton itself (including hashTableObjStore_) will be destroyed here. + } + + JavaVM* vm_{nullptr}; + std::unique_ptr hashTableObjStore_; + jclass jniVeloxBroadcastBuildSideCache_{nullptr}; + jmethodID jniGet_{nullptr}; +}; // Return the hash table builder address. std::shared_ptr nativeHashTableBuild( @@ -45,10 +85,19 @@ std::shared_ptr nativeHashTableBuild( long getJoin(const std::string& hashTableId); -void initVeloxJniHashTable(JNIEnv* env, JavaVM* javaVm); +// Initialize the JNI hash table context +inline void initVeloxJniHashTable(JNIEnv* env, JavaVM* javaVm) { + JniHashTableContext::getInstance().initialize(env, javaVm); +} -void finalizeVeloxJniHashTable(JNIEnv* env); +// Finalize the JNI hash table context +inline void finalizeVeloxJniHashTable(JNIEnv* env) { + JniHashTableContext::getInstance().finalize(env); +} -jlong callJavaGet(const std::string& id); +// Get hash table object store +inline ObjectStore* getHashTableObjStore() { + return JniHashTableContext::getInstance().getHashTableObjStore(); +} } // namespace gluten diff --git a/cpp/velox/jni/VeloxJniWrapper.cc b/cpp/velox/jni/VeloxJniWrapper.cc index 3add62dfc709..ed1cd5e85dc1 100644 --- a/cpp/velox/jni/VeloxJniWrapper.cc +++ b/cpp/velox/jni/VeloxJniWrapper.cc @@ -106,6 +106,7 @@ void JNI_OnUnload(JavaVM* vm, void*) { finalizeVeloxJniUDF(env); finalizeVeloxJniFileSystem(env); + finalizeVeloxJniHashTable(env); getJniErrorState()->close(); getJniCommonState()->close(); google::ShutdownGoogleLogging(); @@ -954,10 +955,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native hashJoinKeys.reserve(joinKeysCount); for (jsize i = 0; i < joinKeysCount; ++i) { jstring jkey = (jstring)env->GetObjectArrayElement(joinKeys, i); - const char* keyChars = env->GetStringUTFChars(jkey, nullptr); - hashJoinKeys.emplace_back(keyChars); - env->ReleaseStringUTFChars(jkey, keyChars); - env->DeleteLocalRef(jkey); + hashJoinKeys.emplace_back(jStringToCString(env, jkey)); } const auto inputType = gluten::getByteArrayElementsSafe(env, namedStruct); @@ -1018,7 +1016,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native nullptr); builder->setHashTable(std::move(mainTable)); - return gluten::hashTableObjStore->save(builder); + return gluten::getHashTableObjStore()->save(builder); } std::vector threads; @@ -1083,7 +1081,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native } hashTableBuilders[0]->setHashTable(std::move(mainTable)); - return gluten::hashTableObjStore->save(hashTableBuilders[0]); + return gluten::getHashTableObjStore()->save(hashTableBuilders[0]); JNI_METHOD_END(kInvalidObjectHandle) } @@ -1093,7 +1091,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_cloneH jlong tableHandler) { JNI_METHOD_START auto hashTableHandler = ObjectStore::retrieve(tableHandler); - return gluten::hashTableObjStore->save(hashTableHandler); + return gluten::getHashTableObjStore()->save(hashTableHandler); JNI_METHOD_END(kInvalidObjectHandle) }