Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions utils/local-engine/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ add_headers_and_sources(common Common)
add_headers_and_sources(external External)
add_headers_and_sources(shuffle Shuffle)
add_headers_and_sources(operator Operator)
add_headers_and_sources(jni jni)

include_directories(
${JNI_INCLUDE_DIRS}
Expand All @@ -38,6 +39,7 @@ add_library(${LOCALENGINE_SHARED_LIB} SHARED
${external_sources}
${shuffle_sources}
${operator_sources}
${jni_sources}
local_engine_jni.cpp)


Expand Down
42 changes: 34 additions & 8 deletions utils/local-engine/Operator/PartitionColumnFillingTransform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
#include <IO/ReadBufferFromString.h>
#include <IO/ReadHelpers.h>
#include <Common/StringUtils.h>
#include <Columns/IColumn.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/Serializations/ISerialization.h>
#include <Interpreters/DatabaseAndTableWithAlias.h>
#include <base/DayNum.h>


using namespace DB;

Expand Down Expand Up @@ -53,6 +59,18 @@ PartitionColumnFillingTransform::PartitionColumnFillingTransform(
partition_column = createPartitionColumn();
}

/// Restores the Nullable wrapper on a freshly built partition column.
/// The partition value is materialized against the nested (non-nullable) type, so when the declared
/// partition type is Nullable the column is wrapped into a ColumnNullable to keep the data type the
/// same as the input. Any other declared type is returned unchanged.
/// NOTE: only Nullable is restored here; a LowCardinality wrapper is not handled by this function.
///
/// @param nested_col          column holding the partition value, built for the nested type.
/// @param original_data_type  declared type of the partition column.
/// @return nested_col itself, or a ColumnNullable around it whose rows are all marked as non-NULL.
ColumnPtr PartitionColumnFillingTransform::tryWrapPartitionColumn(const ColumnPtr & nested_col, DataTypePtr original_data_type)
{
    if (original_data_type->getTypeId() != TypeIndex::Nullable)
        return nested_col;

    /// The null map must contain exactly one entry per nested row; an empty map would leave the
    /// ColumnNullable in an inconsistent state. Zero means "this row is not NULL".
    return ColumnNullable::create(nested_col, ColumnUInt8::create(nested_col->size(), 0));
}

ColumnPtr PartitionColumnFillingTransform::createPartitionColumn()
{
ColumnPtr result;
Expand All @@ -68,43 +86,51 @@ ColumnPtr PartitionColumnFillingTransform::createPartitionColumn()
WhichDataType which(nested_type);
if (which.isInt8())
{
result = createIntPartitionColumn<Int8>(partition_col_type, partition_col_value);
result = createIntPartitionColumn<Int8>(nested_type, partition_col_value);
}
else if (which.isInt16())
{
result = createIntPartitionColumn<Int16>(partition_col_type, partition_col_value);
result = createIntPartitionColumn<Int16>(nested_type, partition_col_value);
}
else if (which.isInt32())
{
result = createIntPartitionColumn<Int32>(partition_col_type, partition_col_value);
result = createIntPartitionColumn<Int32>(nested_type, partition_col_value);
}
else if (which.isInt64())
{
result = createIntPartitionColumn<Int64>(partition_col_type, partition_col_value);
result = createIntPartitionColumn<Int64>(nested_type, partition_col_value);
}
else if (which.isFloat32())
{
result = createFloatPartitionColumn<Float32>(partition_col_type, partition_col_value);
result = createFloatPartitionColumn<Float32>(nested_type, partition_col_value);
}
else if (which.isFloat64())
{
result = createFloatPartitionColumn<Float64>(partition_col_type, partition_col_value);
result = createFloatPartitionColumn<Float64>(nested_type, partition_col_value);
}
else if (which.isDate())
{
DayNum value;
auto value_buffer = ReadBufferFromString(partition_col_value);
readDateText(value, value_buffer);
result = partition_col_type->createColumnConst(1, value);
result = nested_type->createColumnConst(1, value);
}
else if (which.isDate32())
{
ExtendedDayNum value;
auto value_buffer = ReadBufferFromString(partition_col_value);
readDateText(value, value_buffer);
result = nested_type->createColumnConst(1, value.toUnderType());
}
else if (which.isString())
{
result = partition_col_type->createColumnConst(1, partition_col_value);
result = nested_type->createColumnConst(1, partition_col_value);
}
else
{
throw Exception(ErrorCodes::UNKNOWN_TYPE, "unsupported datatype {}", partition_col_type->getFamilyName());
}
result = tryWrapPartitionColumn(result, partition_col_type);
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class PartitionColumnFillingTransform : public DB::ISimpleTransform

private:
DB::ColumnPtr createPartitionColumn();
static DB::ColumnPtr tryWrapPartitionColumn(const DB::ColumnPtr & nested_col, DB::DataTypePtr original_data_type);

DB::DataTypePtr partition_col_type;
String partition_col_name;
Expand Down
9 changes: 8 additions & 1 deletion utils/local-engine/Parser/CHColumnToSparkRow.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
#include "CHColumnToSparkRow.h"
#include <cstdint>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnString.h>
#include <Columns/IColumn.h>
#include <DataTypes/DataTypeNullable.h>
#include <Core/Types.h>
#include <DataTypes/DataTypesDecimal.h>
#include "DataTypes/Serializations/ISerialization.h"
#include "base/types.h"
#include <Functions/FunctionHelpers.h>


namespace DB
{
Expand Down Expand Up @@ -106,12 +111,14 @@ void writeValue(
std::vector<int64_t> & buffer_cursor)
{
ColumnPtr nested_col = col.column;

const auto * nullable_column = checkAndGetColumn<ColumnNullable>(*col.column);
if (nullable_column)
{
nested_col = nullable_column->getNestedColumnPtr();
}
nested_col = nested_col->convertToFullColumnIfConst();

WhichDataType which(nested_col->getDataType());
if (which.isUInt8())
{
Expand Down Expand Up @@ -181,7 +188,7 @@ void writeValue(
}
else
{
throw Exception(ErrorCodes::UNKNOWN_TYPE, "doesn't support type {} convert from ch to spark" ,magic_enum::enum_name(nested_col->getDataType()));
throw Exception(ErrorCodes::UNKNOWN_TYPE, "doesn't support type {} convert from ch to spark", col.type->getName());
}
}

Expand Down
14 changes: 13 additions & 1 deletion utils/local-engine/Parser/SerializedPlanParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include <Common/MergeTreeTool.h>
#include <Common/StringUtils.h>

#include <google/protobuf/util/json_util.h>
#include "SerializedPlanParser.h"

namespace DB
Expand Down Expand Up @@ -1281,7 +1282,18 @@ QueryPlanPtr SerializedPlanParser::parse(std::string & plan)
{
auto plan_ptr = std::make_unique<substrait::Plan>();
plan_ptr->ParseFromString(plan);
LOG_DEBUG(&Poco::Logger::get("SerializedPlanParser"), "parse plan \n{}", plan_ptr->DebugString());

auto printPlan = [](const std::string & plan_raw){
substrait::Plan plan;
plan.ParseFromString(plan_raw);
std::string json_ret;
google::protobuf::util::JsonPrintOptions json_opt;
json_opt.add_whitespace = true;
google::protobuf::util::MessageToJsonString(plan, &json_ret, json_opt);
return json_ret;
};

LOG_DEBUG(&Poco::Logger::get("SerializedPlanParser"), "parse plan \n{}", printPlan(plan));
return parse(std::move(plan_ptr));
}
void SerializedPlanParser::initFunctionEnv()
Expand Down
69 changes: 69 additions & 0 deletions utils/local-engine/jni/jni_common.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#include <exception>
#include <jni/jni_common.h>
#include <stdexcept>
#include <string>
#include <exception>
#include <jni/jni_error.h>

namespace local_engine
{
/// Looks up `class_name` with FindClass and returns a JNI global reference to it.
/// Failure is reported as a C++ exception rather than a Java one (in this code base the function is
/// used to load the Java exception classes themselves).
///
/// @param env         JNI environment of the calling thread.
/// @param class_name  class name passed verbatim to FindClass.
/// @return a global reference the caller must eventually release with DeleteGlobalRef.
/// @throws std::runtime_error if the class cannot be found or the global reference cannot be created.
///         Any Java exception raised by FindClass is left pending.
jclass CreateGlobalExceptionClassReference(JNIEnv * env, const char * class_name)
{
    jclass local_class = env->FindClass(class_name);
    jclass global_class = nullptr;

    /// FindClass leaves a Java exception pending on failure and NewGlobalRef must not be called in
    /// that state, so the global reference is only requested for a successfully resolved class.
    if (local_class != nullptr)
    {
        global_class = static_cast<jclass>(env->NewGlobalRef(local_class));
        env->DeleteLocalRef(local_class);
    }

    if (global_class == nullptr)
        throw std::runtime_error("Unable to createGlobalClassReference for " + std::string(class_name));

    return global_class;
}

/// Looks up `class_name` with FindClass and returns a JNI global reference to it.
///
/// @param env         JNI environment of the calling thread.
/// @param class_name  class name passed verbatim to FindClass.
/// @return a global reference the caller must eventually release with DeleteGlobalRef, or nullptr
///         on failure. On failure a Java IllegalAccessException is raised in `env`, so the caller
///         has to check the result before using it.
jclass CreateGlobalClassReference(JNIEnv * env, const char * class_name)
{
    jclass local_class = env->FindClass(class_name);
    jclass global_class = nullptr;

    /// NewGlobalRef must not be called while the exception raised by a failed FindClass is pending.
    if (local_class != nullptr)
    {
        global_class = static_cast<jclass>(env->NewGlobalRef(local_class));
        env->DeleteLocalRef(local_class);
    }

    if (global_class == nullptr)
    {
        /// ThrowNew may not be invoked with an exception already pending; drop the one left by
        /// FindClass so that the caller sees the IllegalAccessException raised below.
        if (env->ExceptionCheck())
            env->ExceptionClear();

        const std::string error_message = "Unable to createGlobalClassReference for " + std::string(class_name);
        env->ThrowNew(JniErrorsGlobalState::instance().getIllegalAccessExceptionClass(), error_message.c_str());
    }
    return global_class;
}

/// Resolves an instance method of `this_class`.
///
/// @param env         JNI environment of the calling thread.
/// @param this_class  class that declares (or inherits) the method.
/// @param name        method name.
/// @param sig         JNI method signature, e.g. "(I)V".
/// @return the method id, or nullptr on failure. On failure a Java IllegalAccessException describing
///         the missing method is raised in `env`.
jmethodID GetMethodID(JNIEnv * env, jclass this_class, const char * name, const char * sig)
{
    jmethodID ret = env->GetMethodID(this_class, name, sig);
    if (ret == nullptr)
    {
        /// A failed lookup leaves its own exception pending and ThrowNew may not be called in that
        /// state, so clear it before raising the exception callers of this helper expect.
        if (env->ExceptionCheck())
            env->ExceptionClear();

        const std::string error_message
            = "Unable to find method " + std::string(name) + " within signature " + std::string(sig);
        env->ThrowNew(JniErrorsGlobalState::instance().getIllegalAccessExceptionClass(), error_message.c_str());
    }

    return ret;
}

/// Resolves a static method of `this_class`.
///
/// @param env         JNI environment of the calling thread.
/// @param this_class  class that declares the static method.
/// @param name        method name.
/// @param sig         JNI method signature, e.g. "(I)V".
/// @return the method id, or nullptr on failure. On failure a Java IllegalAccessException describing
///         the missing method is raised in `env`.
jmethodID GetStaticMethodID(JNIEnv * env, jclass this_class, const char * name, const char * sig)
{
    jmethodID ret = env->GetStaticMethodID(this_class, name, sig);
    if (ret == nullptr)
    {
        /// A failed lookup leaves its own exception pending and ThrowNew may not be called in that
        /// state, so clear it before raising the exception callers of this helper expect.
        if (env->ExceptionCheck())
            env->ExceptionClear();

        const std::string error_message
            = "Unable to find static method " + std::string(name) + " within signature " + std::string(sig);
        env->ThrowNew(JniErrorsGlobalState::instance().getIllegalAccessExceptionClass(), error_message.c_str());
    }
    return ret;
}

/// Converts a NUL-terminated byte string into a java.lang.String by calling the
/// String(byte[], String charsetName) constructor with the "UTF-8" charset.
///
/// @param env  JNI environment of the calling thread.
/// @param pat  NUL-terminated input bytes; may be nullptr.
/// @return a local reference to the new string, or nullptr if `pat` is nullptr or any JNI call
///         fails (in the latter case the Java exception raised by that call is left pending).
jstring charTojstring(JNIEnv * env, const char * pat)
{
    if (pat == nullptr)
        return nullptr;

    /// FindClass expects the plain class name, not the "L...;" field descriptor form.
    jclass str_class = env->FindClass("java/lang/String");
    if (str_class == nullptr)
        return nullptr;

    jstring result = nullptr;
    jbyteArray bytes = nullptr;
    jstring encoding = nullptr;

    jmethodID ctor_id = env->GetMethodID(str_class, "<init>", "([BLjava/lang/String;)V");
    if (ctor_id != nullptr)
    {
        const jsize length = static_cast<jsize>(strlen(pat));
        bytes = env->NewByteArray(length);
        if (bytes != nullptr)
        {
            env->SetByteArrayRegion(bytes, 0, length, reinterpret_cast<const jbyte *>(pat));
            encoding = env->NewStringUTF("UTF-8");
        }
        if (encoding != nullptr)
            result = static_cast<jstring>(env->NewObject(str_class, ctor_id, bytes, encoding));
    }

    /// Release every temporary local reference, including the class, on all paths.
    if (encoding != nullptr)
        env->DeleteLocalRef(encoding);
    if (bytes != nullptr)
        env->DeleteLocalRef(bytes);
    env->DeleteLocalRef(str_class);
    return result;
}
}
17 changes: 17 additions & 0 deletions utils/local-engine/jni/jni_common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#pragma once
#include <jni.h>

/// Small helpers around raw JNI lookups used by the local engine.
namespace local_engine
{
/// Finds `class_name` and returns a JNI global reference to it.
/// Throws std::runtime_error when the global reference cannot be obtained.
jclass CreateGlobalExceptionClassReference(JNIEnv *env, const char *class_name);

/// Finds `class_name` and returns a JNI global reference to it.
/// When the reference cannot be obtained, a Java IllegalAccessException is raised in `env`
/// and nullptr is returned.
jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name);

/// Resolves the instance method `name` with JNI signature `sig` on `this_class`.
/// When the method is not found, a Java IllegalAccessException is raised in `env` and nullptr is returned.
jmethodID GetMethodID(JNIEnv* env, jclass this_class, const char* name, const char* sig);

/// Resolves the static method `name` with JNI signature `sig` on `this_class`.
/// When the method is not found, a Java IllegalAccessException is raised in `env` and nullptr is returned.
jmethodID GetStaticMethodID(JNIEnv * env, jclass this_class, const char * name, const char * sig);

/// Builds a java.lang.String from the NUL-terminated bytes in `pat`, decoding them as UTF-8.
jstring charTojstring(JNIEnv* env, const char* pat);

}

85 changes: 85 additions & 0 deletions utils/local-engine/jni/jni_error.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@

#include <stdexcept>
#include <jni/jni_error.h>
#include <jni/jni_common.h>
#include <jni.h>
#include <Poco/Logger.h>
#include "Common/Exception.h"
#include <base/logger_useful.h>

namespace local_engine
{
/// Returns the single shared JniErrorsGlobalState object; it is constructed on the first call.
JniErrorsGlobalState & JniErrorsGlobalState::instance()
{
    static JniErrorsGlobalState global_state;
    return global_state;
}

/// Releases the global references to the cached Java exception classes.
/// Does nothing when `env` is null. Each member is reset to nullptr after its reference is deleted,
/// so calling destroy() again, or throwing after destroy(), never touches a deleted reference.
///
/// @param env  JNI environment used to delete the global references; may be nullptr.
void JniErrorsGlobalState::destroy(JNIEnv * env)
{
    if (!env)
        return;

    /// Deletes one cached global reference (if any) and forgets it.
    auto release = [env](auto & cached_class)
    {
        if (cached_class)
        {
            env->DeleteGlobalRef(cached_class);
            cached_class = nullptr;
        }
    };

    release(io_exception_class);
    release(runtime_exception_class);
    release(unsupportedoperation_exception_class);
    release(illegal_access_exception_class);
    release(illegal_argument_exception_class);
}

/// Caches global references to the Java exception classes that native code throws.
/// Class names use the "java/lang/Foo" form required by JNI FindClass; the "Ljava/lang/Foo;" field
/// descriptor form is not a valid FindClass argument for a non-array class.
///
/// @param env_  JNI environment of the calling thread.
/// @throws std::runtime_error (from CreateGlobalExceptionClassReference) if a class cannot be loaded.
void JniErrorsGlobalState::initialize(JNIEnv * env_)
{
    io_exception_class = CreateGlobalExceptionClassReference(env_, "java/io/IOException");
    runtime_exception_class = CreateGlobalExceptionClassReference(env_, "java/lang/RuntimeException");
    unsupportedoperation_exception_class = CreateGlobalExceptionClassReference(env_, "java/lang/UnsupportedOperationException");
    illegal_access_exception_class = CreateGlobalExceptionClassReference(env_, "java/lang/IllegalAccessException");
    illegal_argument_exception_class = CreateGlobalExceptionClassReference(env_, "java/lang/IllegalArgumentException");
}

/// Raises a Java runtime exception carrying the message and stack trace of a DB::Exception.
void JniErrorsGlobalState::throwException(JNIEnv * env, const DB::Exception & e)
{
    const std::string & text = e.message();
    const std::string & trace = e.getStackTraceString();
    throwRuntimeException(env, text, trace);
}

/// Raises a Java runtime exception carrying what() of a standard exception together with the
/// stack trace string that DB::getExceptionStackTraceString produces for it.
void JniErrorsGlobalState::throwException(JNIEnv * env, const std::exception & e)
{
    const std::string text = e.what();
    const std::string trace = DB::getExceptionStackTraceString(e);
    throwRuntimeException(env, text, trace);
}

/// Raises a Java exception of `exception_class` in `env`.
/// The Java message is `message`, a newline, then `stack_trace`.
///
/// @throws std::runtime_error if `exception_class` is null.
void JniErrorsGlobalState::throwException(JNIEnv * env, jclass exception_class, const std::string & message, const std::string & stack_trace)
{
    if (!exception_class)
    {
        // Per the original author's note: this will cause a coredump
        throw std::runtime_error("Not found java runtime exception class");
    }

    const std::string full_text = message + "\n" + stack_trace;
    env->ThrowNew(exception_class, full_text.c_str());
}

/// Raises the cached runtime exception class in `env` with the given message and stack trace.
void JniErrorsGlobalState::throwRuntimeException(JNIEnv * env, const std::string & message, const std::string & stack_trace)
{
    jclass target_class = runtime_exception_class;
    throwException(env, target_class, message, stack_trace);
}


}
Loading