Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
315 changes: 198 additions & 117 deletions native/src/jni/hook_bridge.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <alloca.h>
#include <parallel_hashmap/phmap.h>

#include <lsplant.hpp>
Expand Down Expand Up @@ -273,171 +274,251 @@ VECTOR_DEF_NATIVE_METHOD(jobject, HookBridge, allocateObject, jclass cls) {
}

/**
* @brief A high-performance, low-level implementation of Method.invoke for super.method() calls.
* Core JNI backend for non-virtual method invocation and special object initialization.
*
* This function manually unboxes arguments from a jobject array into a jvalue C-style array,
* calls the appropriate JNI `CallNonvirtual...MethodA` function,
* and then boxes the return value back into a jobject.
* This avoids the overhead of Java reflection.
*
* @warning This is a very sensitive function.
* The `shorty` descriptor must perfectly match the method's actual signature.
* Implementation details:
* 1. Dispatches using JNI CallNonvirtual<Type>MethodA.
* 2. Employs stack allocation (alloca) for JNI argument mapping.
* 3. Safely mirrors standard Java reflection (NPEs on null primitives/receivers).
* 4. Prevents JNI Type Confusion and memory leaks by caching primitive wrappers globally,
* while leveraging java.lang.Number for fast implicit widening/narrowing.
* 5. Accurately catches and wraps target method exceptions into InvocationTargetException.
*/
VECTOR_DEF_NATIVE_METHOD(jobject, HookBridge, invokeSpecialMethod, jobject method,
jcharArray shorty, jclass cls, jobject thiz, jobjectArray args) {
// --- Cache all necessary MethodIDs for boxing/unboxing primitive wrappers
// --- This is a major performance optimization, done only once.
static auto *const get_int =
env->GetMethodID(env->FindClass("java/lang/Integer"), "intValue", "()I");
static auto *const get_double =
env->GetMethodID(env->FindClass("java/lang/Double"), "doubleValue", "()D");
static auto *const get_long =
env->GetMethodID(env->FindClass("java/lang/Long"), "longValue", "()J");
static auto *const get_float =
env->GetMethodID(env->FindClass("java/lang/Float"), "floatValue", "()F");
static auto *const get_short =
env->GetMethodID(env->FindClass("java/lang/Short"), "shortValue", "()S");
static auto *const get_byte =
env->GetMethodID(env->FindClass("java/lang/Byte"), "byteValue", "()B");
static auto *const get_char =
env->GetMethodID(env->FindClass("java/lang/Character"), "charValue", "()C");
static auto *const get_boolean =
env->GetMethodID(env->FindClass("java/lang/Boolean"), "booleanValue", "()Z");
static auto *const set_int = env->GetStaticMethodID(env->FindClass("java/lang/Integer"),
"valueOf", "(I)Ljava/lang/Integer;");
static auto *const set_double = env->GetStaticMethodID(env->FindClass("java/lang/Double"),
"valueOf", "(D)Ljava/lang/Double;");
// --- JNI Global Reference Caching ---
// Cached once per process lifecycle to maintain extreme performance and prevent JNI aborts.
static jclass cls_Number = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Number"));
static jclass cls_Boolean = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Boolean"));
static jclass cls_Character = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Character"));

// Globally cache primitive wrapper classes for safe return value boxing
static jclass cls_Integer = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Integer"));
static jclass cls_Double = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Double"));
static jclass cls_Long = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Long"));
static jclass cls_Float = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Float"));
static jclass cls_Short = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Short"));
static jclass cls_Byte = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Byte"));

static jclass cls_ITE =
(jclass)env->NewGlobalRef(env->FindClass("java/lang/reflect/InvocationTargetException"));

static auto *const ctor_ite = env->GetMethodID(cls_ITE, "<init>", "(Ljava/lang/Throwable;)V");

static auto *const get_int = env->GetMethodID(cls_Number, "intValue", "()I");
static auto *const get_double = env->GetMethodID(cls_Number, "doubleValue", "()D");
static auto *const get_long = env->GetMethodID(cls_Number, "longValue", "()J");
static auto *const get_float = env->GetMethodID(cls_Number, "floatValue", "()F");
static auto *const get_short = env->GetMethodID(cls_Number, "shortValue", "()S");
static auto *const get_byte = env->GetMethodID(cls_Number, "byteValue", "()B");

static auto *const get_char = env->GetMethodID(cls_Character, "charValue", "()C");
static auto *const get_boolean = env->GetMethodID(cls_Boolean, "booleanValue", "()Z");

static auto *const set_int =
env->GetStaticMethodID(cls_Integer, "valueOf", "(I)Ljava/lang/Integer;");
static auto *const set_double =
env->GetStaticMethodID(cls_Double, "valueOf", "(D)Ljava/lang/Double;");
static auto *const set_long =
env->GetStaticMethodID(env->FindClass("java/lang/Long"), "valueOf", "(J)Ljava/lang/Long;");
static auto *const set_float = env->GetStaticMethodID(env->FindClass("java/lang/Float"),
"valueOf", "(F)Ljava/lang/Float;");
static auto *const set_short = env->GetStaticMethodID(env->FindClass("java/lang/Short"),
"valueOf", "(S)Ljava/lang/Short;");
env->GetStaticMethodID(cls_Long, "valueOf", "(J)Ljava/lang/Long;");
static auto *const set_float =
env->GetStaticMethodID(cls_Float, "valueOf", "(F)Ljava/lang/Float;");
static auto *const set_short =
env->GetStaticMethodID(cls_Short, "valueOf", "(S)Ljava/lang/Short;");
static auto *const set_byte =
env->GetStaticMethodID(env->FindClass("java/lang/Byte"), "valueOf", "(B)Ljava/lang/Byte;");
static auto *const set_char = env->GetStaticMethodID(env->FindClass("java/lang/Character"),
"valueOf", "(C)Ljava/lang/Character;");
static auto *const set_boolean = env->GetStaticMethodID(env->FindClass("java/lang/Boolean"),
"valueOf", "(Z)Ljava/lang/Boolean;");
env->GetStaticMethodID(cls_Byte, "valueOf", "(B)Ljava/lang/Byte;");
static auto *const set_char =
env->GetStaticMethodID(cls_Character, "valueOf", "(C)Ljava/lang/Character;");
static auto *const set_boolean =
env->GetStaticMethodID(cls_Boolean, "valueOf", "(Z)Ljava/lang/Boolean;");

auto target = env->FromReflectedMethod(method);
auto param_len = env->GetArrayLength(shorty) - 1; // First char is return type.
auto param_len = env->GetArrayLength(shorty) - 1;

// --- Argument Validation ---
if (env->GetArrayLength(args) != param_len) {
// --- Argument & Receiver Validation ---
auto args_len = args != nullptr ? env->GetArrayLength(args) : 0;
if (args_len != param_len) {
env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"),
"args.length does not match parameter count");
return nullptr;
}

if (thiz == nullptr) {
env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"),
"`this` cannot be null for a non-virtual call");
env->ThrowNew(env->FindClass("java/lang/NullPointerException"), "null receiver");
return nullptr;
}

// --- Unbox Arguments ---
std::vector<jvalue> a(param_len);
// Allocate jvalue array on the stack
jvalue *a = param_len > 0 ? static_cast<jvalue *>(alloca(param_len * sizeof(jvalue))) : nullptr;

auto *const shorty_char = env->GetCharArrayElements(shorty, nullptr);
if (shorty_char == nullptr) {
return nullptr; // JVM already threw OutOfMemoryError
}

// RAII/Helper for clean JNI array exits
auto abort_and_return = [&]() {
env->ReleaseCharArrayElements(shorty, shorty_char, JNI_ABORT);
return nullptr;
};

// --- Safe Unboxing ---
for (jint i = 0; i != param_len; ++i) {
jobject element = env->GetObjectArrayElement(args, i);
if (env->ExceptionCheck()) {
env->ReleaseCharArrayElements(shorty, shorty_char, JNI_ABORT);
return nullptr;
}
if (env->ExceptionCheck()) return abort_and_return();

// The shorty string at index i+1 describes the type of the i-th parameter.
switch (shorty_char[i + 1]) {
case 'I':
a[i].i = env->CallIntMethod(element, get_int);
break;
case 'D':
a[i].d = env->CallDoubleMethod(element, get_double);
break;
case 'J':
a[i].j = env->CallLongMethod(element, get_long);
break;
case 'F':
a[i].f = env->CallFloatMethod(element, get_float);
break;
case 'S':
a[i].s = env->CallShortMethod(element, get_short);
break;
case 'B':
a[i].b = env->CallByteMethod(element, get_byte);
break;
case 'C':
a[i].c = env->CallCharMethod(element, get_char);
break;
case 'Z':
a[i].z = env->CallBooleanMethod(element, get_boolean);
break;
default: // Assumes 'L' or '[' for object types
a[i].l = element;
// Set element to null so we don't delete the local ref twice.
// The reference is stored in the jvalue array and is still valid.
element = nullptr;
break;
char type = shorty_char[i + 1];

if (element == nullptr) {
if (type != 'L' && type != '[') {
env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"),
"null primitive argument");
return abort_and_return();
}
a[i].l = nullptr;
} else {
if (type == 'Z') {
if (!env->IsInstanceOf(element, cls_Boolean)) {
env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"),
"Expected Boolean");
return abort_and_return();
}
a[i].z = env->CallBooleanMethod(element, get_boolean);
} else if (type == 'C') {
if (!env->IsInstanceOf(element, cls_Character)) {
env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"),
"Expected Character");
return abort_and_return();
}
a[i].c = env->CallCharMethod(element, get_char);
} else if (type != 'L' && type != '[') {
bool is_number = env->IsInstanceOf(element, cls_Number) == JNI_TRUE;
bool is_character =
!is_number && (env->IsInstanceOf(element, cls_Character) == JNI_TRUE);

if (!is_number && !is_character) {
env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"),
"Expected Number or Character");
return abort_and_return();
}

// If a Character is passed to a numeric parameter, extract its value for widening
jchar c_val = 0;
if (is_character) {
c_val = env->CallCharMethod(element, get_char);
if (env->ExceptionCheck()) return abort_and_return();
}

switch (type) {
case 'I':
a[i].i = env->CallIntMethod(element, get_int);
break;
case 'D':
a[i].d = env->CallDoubleMethod(element, get_double);
break;
case 'J':
a[i].j = env->CallLongMethod(element, get_long);
break;
case 'F':
a[i].f = env->CallFloatMethod(element, get_float);
break;
case 'S':
a[i].s = env->CallShortMethod(element, get_short);
break;
case 'B':
a[i].b = env->CallByteMethod(element, get_byte);
break;
}
} else {
a[i].l = element;
element =
nullptr; // Transferred ownership to jvalue array; will be freed on return
}
}

// Clean up the local reference for the wrapper object if it was created.
if (element) env->DeleteLocalRef(element);
if (env->ExceptionCheck()) return abort_and_return();
}

// Check for exceptions during the unboxing call (e.g.,
// NullPointerException).
if (env->ExceptionCheck()) {
env->ReleaseCharArrayElements(shorty, shorty_char, JNI_ABORT);
return nullptr;
// --- Non-virtual Invocation ---
jvalue ret_val;
switch (shorty_char[0]) {
case 'I':
ret_val.i = env->CallNonvirtualIntMethodA(thiz, cls, target, a);
break;
case 'D':
ret_val.d = env->CallNonvirtualDoubleMethodA(thiz, cls, target, a);
break;
case 'J':
ret_val.j = env->CallNonvirtualLongMethodA(thiz, cls, target, a);
break;
case 'F':
ret_val.f = env->CallNonvirtualFloatMethodA(thiz, cls, target, a);
break;
case 'S':
ret_val.s = env->CallNonvirtualShortMethodA(thiz, cls, target, a);
break;
case 'B':
ret_val.b = env->CallNonvirtualByteMethodA(thiz, cls, target, a);
break;
case 'C':
ret_val.c = env->CallNonvirtualCharMethodA(thiz, cls, target, a);
break;
case 'Z':
ret_val.z = env->CallNonvirtualBooleanMethodA(thiz, cls, target, a);
break;
case 'L':
ret_val.l = env->CallNonvirtualObjectMethodA(thiz, cls, target, a);
break;
default:
env->CallNonvirtualVoidMethodA(thiz, cls, target, a);
break;
}

// --- Exception Wrapping ---
jthrowable target_exception = env->ExceptionOccurred();
if (target_exception) {
env->ExceptionClear();
jobject ite = env->NewObject(cls_ITE, ctor_ite, target_exception);
// Ensure NewObject didn't fail due to OOM before throwing
if (ite) {
env->Throw(static_cast<jthrowable>(ite));
}
return abort_and_return();
}

// --- Call Non-virtual Method and Box Return Value ---
// --- Box Return Value ---
jobject value = nullptr;
// The shorty string at index 0 describes the return type.
switch (shorty_char[0]) {
case 'I':
value =
env->CallStaticObjectMethod(jclass{nullptr},
set_int, // Use Integer.valueOf() to box
env->CallNonvirtualIntMethodA(thiz, cls, target, a.data()));
value = env->CallStaticObjectMethod(cls_Integer, set_int, ret_val.i);
break;
case 'D':
value = env->CallStaticObjectMethod(
jclass{nullptr}, set_double,
env->CallNonvirtualDoubleMethodA(thiz, cls, target, a.data()));
value = env->CallStaticObjectMethod(cls_Double, set_double, ret_val.d);
break;
case 'J':
value = env->CallStaticObjectMethod(
jclass{nullptr}, set_long, env->CallNonvirtualLongMethodA(thiz, cls, target, a.data()));
value = env->CallStaticObjectMethod(cls_Long, set_long, ret_val.j);
break;
case 'F':
value = env->CallStaticObjectMethod(
jclass{nullptr}, set_float,
env->CallNonvirtualFloatMethodA(thiz, cls, target, a.data()));
value = env->CallStaticObjectMethod(cls_Float, set_float, ret_val.f);
break;
case 'S':
value = env->CallStaticObjectMethod(
jclass{nullptr}, set_short,
env->CallNonvirtualShortMethodA(thiz, cls, target, a.data()));
value = env->CallStaticObjectMethod(cls_Short, set_short, ret_val.s);
break;
case 'B':
value = env->CallStaticObjectMethod(
jclass{nullptr}, set_byte, env->CallNonvirtualByteMethodA(thiz, cls, target, a.data()));
value = env->CallStaticObjectMethod(cls_Byte, set_byte, ret_val.b);
break;
case 'C':
value = env->CallStaticObjectMethod(
jclass{nullptr}, set_char, env->CallNonvirtualCharMethodA(thiz, cls, target, a.data()));
value = env->CallStaticObjectMethod(cls_Character, set_char, ret_val.c);
break;
case 'Z':
value = env->CallStaticObjectMethod(
jclass{nullptr}, set_boolean,
env->CallNonvirtualBooleanMethodA(thiz, cls, target, a.data()));
value = env->CallStaticObjectMethod(cls_Boolean, set_boolean, ret_val.z);
break;
case 'L': // Return type is an object, no boxing needed.
value = env->CallNonvirtualObjectMethodA(thiz, cls, target, a.data());
case 'L':
value = ret_val.l;
break;
default: // Assumes 'V' for void return type.
case 'V':
env->CallNonvirtualVoidMethodA(thiz, cls, target, a.data());
value = nullptr;
break;
}

Expand Down
Loading