Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public long rtHandle() {
public static native long nativeBuild(
String buildHashTableId,
long[] batchHandlers,
String joinKeys,
String[] joinKeys,
int joinType,
boolean hasMixedFiltCondition,
boolean isExistenceJoin,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,4 @@ object GlutenRpcMessages {

case class GlutenCleanExecutionResource(executionId: String, broadcastHashIds: util.Set[String])
extends GlutenRpcMessage

// for mergetree cache
case class GlutenMergeTreeCacheLoad(
mergeTreeTable: String,
columns: util.Set[String],
onlyMetaCache: Boolean)
extends GlutenRpcMessage

case class GlutenCacheLoadStatus(jobId: String)

case class CacheJobInfo(status: Boolean, jobId: String, reason: String = "")
extends GlutenRpcMessage

case class GlutenFilesCacheLoad(files: Array[Byte]) extends GlutenRpcMessage

case class GlutenFilesCacheLoadStatus(jobId: String)
}
Original file line number Diff line number Diff line change
Expand Up @@ -197,20 +197,18 @@ case class ColumnarBuildSideRelation(
)
}

val joinKey = keys.asScala
.map {
key =>
val attr = ConverterUtils.getAttrFromExpr(key)
ConverterUtils.genColumnNameWithExprId(attr)
}
.mkString(",")
val joinKeys = keys.asScala.map {
key =>
val attr = ConverterUtils.getAttrFromExpr(key)
ConverterUtils.genColumnNameWithExprId(attr)
}.toArray

// Build the hash table
hashTableData = HashJoinBuilder
.nativeBuild(
broadcastContext.buildHashTableId,
batchArray.toArray,
joinKey,
joinKeys,
broadcastContext.substraitJoinType.ordinal(),
broadcastContext.hasMixedFiltCondition,
broadcastContext.isExistenceJoin,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,20 +167,18 @@ class UnsafeColumnarBuildSideRelation(
)
}

val joinKey = keys.asScala
.map {
key =>
val attr = ConverterUtils.getAttrFromExpr(key)
ConverterUtils.genColumnNameWithExprId(attr)
}
.mkString(",")
val joinKeys = keys.asScala.map {
key =>
val attr = ConverterUtils.getAttrFromExpr(key)
ConverterUtils.genColumnNameWithExprId(attr)
}.toArray

// Build the hash table
hashTableData = HashJoinBuilder
.nativeBuild(
broadcastContext.buildHashTableId,
batchArray.toArray,
joinKey,
joinKeys,
broadcastContext.substraitJoinType.ordinal(),
broadcastContext.hasMixedFiltCondition,
broadcastContext.isExistenceJoin,
Expand Down
1 change: 0 additions & 1 deletion cpp/velox/compute/VeloxBackend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,6 @@ void VeloxBackend::tearDown() {
filesystem->close();
}
#endif
gluten::hashTableObjStore.reset();

// Destruct IOThreadPoolExecutor will join all threads.
// On threads exit, thread local variables can be constructed with referencing global variables.
Expand Down
48 changes: 21 additions & 27 deletions cpp/velox/jni/JniHashTable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,34 @@

namespace gluten {

static jclass jniVeloxBroadcastBuildSideCache = nullptr;
static jmethodID jniGet = nullptr;
void JniHashTableContext::initialize(JNIEnv* env, JavaVM* javaVm) {
vm_ = javaVm;
const char* classSig = "Lorg/apache/gluten/execution/VeloxBroadcastBuildSideCache;";
jniVeloxBroadcastBuildSideCache_ = createGlobalClassReferenceOrError(env, classSig);
jniGet_ = getStaticMethodId(env, jniVeloxBroadcastBuildSideCache_, "get", "(Ljava/lang/String;)J");
}

jlong callJavaGet(const std::string& id) {
void JniHashTableContext::finalize(JNIEnv* env) {
if (jniVeloxBroadcastBuildSideCache_ != nullptr) {
env->DeleteGlobalRef(jniVeloxBroadcastBuildSideCache_);
jniVeloxBroadcastBuildSideCache_ = nullptr;
}
}

jlong JniHashTableContext::callJavaGet(const std::string& id) const {
JNIEnv* env;
if (vm->GetEnv(reinterpret_cast<void**>(&env), jniVersion) != JNI_OK) {
if (vm_->GetEnv(reinterpret_cast<void**>(&env), jniVersion) != JNI_OK) {
throw gluten::GlutenException("JNIEnv was not attached to current thread");
}

const jstring s = env->NewStringUTF(id.c_str());

auto result = env->CallStaticLongMethod(jniVeloxBroadcastBuildSideCache, jniGet, s);
auto result = env->CallStaticLongMethod(jniVeloxBroadcastBuildSideCache_, jniGet_, s);
return result;
}

// Return the velox's hash table.
std::shared_ptr<HashTableBuilder> nativeHashTableBuild(
const std::string& joinKeys,
const std::vector<std::string>& joinKeys,
std::vector<std::string> names,
std::vector<facebook::velox::TypePtr> veloxTypeList,
int joinType,
Expand Down Expand Up @@ -98,12 +108,9 @@ std::shared_ptr<HashTableBuilder> nativeHashTableBuild(
VELOX_NYI("Unsupported Join type: {}", std::to_string(sJoin));
}

std::vector<std::string> joinKeyNames;
folly::split(',', joinKeys, joinKeyNames);

std::vector<std::shared_ptr<const facebook::velox::core::FieldAccessTypedExpr>> joinKeyTypes;
joinKeyTypes.reserve(joinKeyNames.size());
for (const auto& name : joinKeyNames) {
joinKeyTypes.reserve(joinKeys.size());
for (const auto& name : joinKeys) {
joinKeyTypes.emplace_back(
std::make_shared<facebook::velox::core::FieldAccessTypedExpr>(rowType->findChild(name), name));
}
Expand All @@ -125,21 +132,8 @@ std::shared_ptr<HashTableBuilder> nativeHashTableBuild(
return hashTableBuilder;
}

long getJoin(std::string hashTableId) {
return callJavaGet(hashTableId);
}

void initVeloxJniHashTable(JNIEnv* env) {
if (env->GetJavaVM(&vm) != JNI_OK) {
throw gluten::GlutenException("Unable to get JavaVM instance");
}
const char* classSig = "Lorg/apache/gluten/execution/VeloxBroadcastBuildSideCache;";
jniVeloxBroadcastBuildSideCache = createGlobalClassReferenceOrError(env, classSig);
jniGet = getStaticMethodId(env, jniVeloxBroadcastBuildSideCache, "get", "(Ljava/lang/String;)J");
}

void finalizeVeloxJniHashTable(JNIEnv* env) {
env->DeleteGlobalRef(jniVeloxBroadcastBuildSideCache);
long getJoin(const std::string& hashTableId) {
return JniHashTableContext::getInstance().callJavaGet(hashTableId);
}

} // namespace gluten
63 changes: 56 additions & 7 deletions cpp/velox/jni/JniHashTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,53 @@

namespace gluten {

inline static JavaVM* vm = nullptr;
// Wrapper class to encapsulate JNI-related static objects for hash table operations.
// This avoids exposing global variables in the gluten namespace.
class JniHashTableContext {
public:
static JniHashTableContext& getInstance() {
static JniHashTableContext instance;
return instance;
}

inline static std::unique_ptr<ObjectStore> hashTableObjStore = ObjectStore::create();
// Delete copy and move constructors/operators
JniHashTableContext(const JniHashTableContext&) = delete;
JniHashTableContext& operator=(const JniHashTableContext&) = delete;
JniHashTableContext(JniHashTableContext&&) = delete;
JniHashTableContext& operator=(JniHashTableContext&&) = delete;

void initialize(JNIEnv* env, JavaVM* javaVm);
void finalize(JNIEnv* env);

JavaVM* getJavaVM() const {
return vm_;
}

ObjectStore* getHashTableObjStore() const {
return hashTableObjStore_.get();
}

jlong callJavaGet(const std::string& id) const;

private:
JniHashTableContext() : hashTableObjStore_(ObjectStore::create()) {}

~JniHashTableContext() {
// Note: The destructor is called at program exit (after main() returns).
// By this time, JNI_OnUnload should have already been called, which invokes
// finalize() to clean up JNI global references while the JVM is still valid.
// The singleton itself (including hashTableObjStore_) will be destroyed here.
}

JavaVM* vm_{nullptr};
std::unique_ptr<ObjectStore> hashTableObjStore_;
jclass jniVeloxBroadcastBuildSideCache_{nullptr};
jmethodID jniGet_{nullptr};
};

// Return the hash table builder address.
std::shared_ptr<HashTableBuilder> nativeHashTableBuild(
const std::string& joinKeys,
const std::vector<std::string>& joinKeys,
std::vector<std::string> names,
std::vector<facebook::velox::TypePtr> veloxTypeList,
int joinType,
Expand All @@ -43,12 +83,21 @@ std::shared_ptr<HashTableBuilder> nativeHashTableBuild(
std::vector<std::shared_ptr<ColumnarBatch>>& batches,
std::shared_ptr<facebook::velox::memory::MemoryPool> memoryPool);

long getJoin(std::string hashTableId);
long getJoin(const std::string& hashTableId);

void initVeloxJniHashTable(JNIEnv* env);
// Initialize the JNI hash table context
inline void initVeloxJniHashTable(JNIEnv* env, JavaVM* javaVm) {
JniHashTableContext::getInstance().initialize(env, javaVm);
}

void finalizeVeloxJniHashTable(JNIEnv* env);
// Finalize the JNI hash table context
inline void finalizeVeloxJniHashTable(JNIEnv* env) {
JniHashTableContext::getInstance().finalize(env);
}

jlong callJavaGet(const std::string& id);
// Get hash table object store
inline ObjectStore* getHashTableObjStore() {
return JniHashTableContext::getInstance().getHashTableObjStore();
}

} // namespace gluten
28 changes: 18 additions & 10 deletions cpp/velox/jni/VeloxJniWrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ jint JNI_OnLoad(JavaVM* vm, void*) {
getJniErrorState()->ensureInitialized(env);
initVeloxJniFileSystem(env);
initVeloxJniUDF(env);
initVeloxJniHashTable(env);
initVeloxJniHashTable(env, vm);

infoCls = createGlobalClassReferenceOrError(env, "Lorg/apache/gluten/validate/NativePlanValidationInfo;");
infoClsInitMethod = getMethodIdOrError(env, infoCls, "<init>", "(ILjava/lang/String;)V");
Expand All @@ -94,8 +94,6 @@ jint JNI_OnLoad(JavaVM* vm, void*) {

DLOG(INFO) << "Loaded Velox backend.";

gluten::vm = vm;

return jniVersion;
}

Expand All @@ -108,6 +106,7 @@ void JNI_OnUnload(JavaVM* vm, void*) {

finalizeVeloxJniUDF(env);
finalizeVeloxJniFileSystem(env);
finalizeVeloxJniHashTable(env);
getJniErrorState()->close();
getJniCommonState()->close();
google::ShutdownGoogleLogging();
Expand Down Expand Up @@ -939,7 +938,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native
jclass,
jstring tableId,
jlongArray batchHandles,
jstring joinKey,
jobjectArray joinKeys,
jint joinType,
jboolean hasMixedJoinCondition,
jboolean isExistenceJoin,
Expand All @@ -949,7 +948,16 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native
jint broadcastHashTableBuildThreads) {
JNI_METHOD_START
const auto hashTableId = jStringToCString(env, tableId);
const auto hashJoinKey = jStringToCString(env, joinKey);

// Convert Java String array to C++ vector<string>
std::vector<std::string> hashJoinKeys;
jsize joinKeysCount = env->GetArrayLength(joinKeys);
hashJoinKeys.reserve(joinKeysCount);
for (jsize i = 0; i < joinKeysCount; ++i) {
jstring jkey = (jstring)env->GetObjectArrayElement(joinKeys, i);
hashJoinKeys.emplace_back(jStringToCString(env, jkey));
}

const auto inputType = gluten::getByteArrayElementsSafe(env, namedStruct);
std::string structString{
reinterpret_cast<const char*>(inputType.elems()), static_cast<std::string::size_type>(inputType.length())};
Expand Down Expand Up @@ -988,7 +996,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native

if (numThreads <= 1) {
auto builder = nativeHashTableBuild(
hashJoinKey,
hashJoinKeys,
names,
veloxTypeList,
joinType,
Expand All @@ -1008,7 +1016,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native
nullptr);
builder->setHashTable(std::move(mainTable));

return gluten::hashTableObjStore->save(builder);
return gluten::getHashTableObjStore()->save(builder);
}

std::vector<std::thread> threads;
Expand All @@ -1027,7 +1035,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native
}

auto builder = nativeHashTableBuild(
hashJoinKey,
hashJoinKeys,
names,
veloxTypeList,
joinType,
Expand Down Expand Up @@ -1073,7 +1081,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native
}

hashTableBuilders[0]->setHashTable(std::move(mainTable));
return gluten::hashTableObjStore->save(hashTableBuilders[0]);
return gluten::getHashTableObjStore()->save(hashTableBuilders[0]);
JNI_METHOD_END(kInvalidObjectHandle)
}

Expand All @@ -1083,7 +1091,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_cloneH
jlong tableHandler) {
JNI_METHOD_START
auto hashTableHandler = ObjectStore::retrieve<gluten::HashTableBuilder>(tableHandler);
return gluten::hashTableObjStore->save(hashTableHandler);
return gluten::getHashTableObjStore()->save(hashTableHandler);
JNI_METHOD_END(kInvalidObjectHandle)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,22 +44,22 @@ case class FallbackMultiCodegens(session: SparkSession) extends Rule[SparkPlan]
plan match {
case plan: CodegenSupport if plan.supportCodegen =>
if (
(count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum == outputSize
(count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum >= outputSize
) {
return true
}
plan.children.exists(existsMultiCodegens(_, count + 1))
case plan: ShuffledHashJoinExec =>
if (
(count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum == outputSize
(count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum >= outputSize
) {
return true
}

plan.children.exists(existsMultiCodegens(_, count + 1))
case plan: SortMergeJoinExec if GlutenConfig.get.forceShuffledHashJoin =>
if (
(count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum == outputSize
(count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum >= outputSize
) {
return true
}
Expand Down
Loading