Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 75 additions & 31 deletions utils/local-engine/Operator/PartitionColumnFillingTransform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@
#include <IO/ReadBufferFromString.h>
#include <IO/ReadHelpers.h>
#include <Common/StringUtils.h>
#include "Processors/Chunk.h"
#include <Columns/IColumn.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/Serializations/ISerialization.h>
#include <Interpreters/DatabaseAndTableWithAlias.h>
#include <base/DayNum.h>

#include <Poco/Logger.h>
#include <base/logger_useful.h>


using namespace DB;

Expand All @@ -20,42 +30,52 @@ namespace ErrorCodes
namespace local_engine
{
/// Builds a constant column of `rows` rows holding the integer parsed from `partition_value`.
///
/// @tparam Type             integer type the text is parsed into (restricted by the requires-clause).
/// @param column_type      data type used to create the constant column.
/// @param partition_value  textual partition value, parsed with readIntText.
/// @param rows             number of rows of the resulting constant column.
/// @return the column produced by `column_type->createColumnConst(rows, value)`.
template <typename Type>
requires(std::is_same_v<Type, Int8> || std::is_same_v<Type, UInt16> || std::is_same_v<Type, Int16> || std::is_same_v<Type, Int32> || std::is_same_v<Type, Int64>)
ColumnPtr createIntPartitionColumn(DataTypePtr column_type, std::string partition_value, size_t rows)
{
    /// Value-initialised so that `value` is never read uninitialised.
    Type value{};
    ReadBufferFromString value_buffer(partition_value);
    readIntText(value, value_buffer);
    /// The column must be sized to the chunk (`rows`), not to a single row.
    return column_type->createColumnConst(rows, value);
}

/// Builds a constant column of `rows` rows holding the floating point number parsed from `partition_value`.
///
/// @tparam Type             Float32 or Float64 (restricted by the requires-clause).
/// @param column_type      data type used to create the constant column.
/// @param partition_value  textual partition value, parsed with readFloatText.
/// @param rows             number of rows of the resulting constant column.
/// @return the column produced by `column_type->createColumnConst(rows, value)`.
template <typename Type>
requires(std::is_same_v<Type, Float32> || std::is_same_v<Type, Float64>)
ColumnPtr createFloatPartitionColumn(DataTypePtr column_type, std::string partition_value, size_t rows)
{
    /// Value-initialised so that `value` is never read uninitialised.
    Type value{};
    ReadBufferFromString value_buffer(partition_value);
    readFloatText(value, value_buffer);
    /// The column must be sized to the chunk (`rows`), not to a single row.
    return column_type->createColumnConst(rows, value);
}

//template <>
//ColumnPtr createFloatPartitionColumn<Float32>(DataTypePtr column_type, std::string partition_value);
//template <>
//ColumnPtr createFloatPartitionColumn<Float64>(DataTypePtr column_type, std::string partition_value);

/// Creates the transform.
///
/// @param input_              header of the incoming chunks.
/// @param output_             header of the produced chunks.
/// @param partition_columns_  (name, value) pairs of the partition columns; they are kept as-is in
///                            `partition_column_values` and additionally indexed by name in `partition_columns`.
PartitionColumnFillingTransform::PartitionColumnFillingTransform(
    const DB::Block & input_, const DB::Block & output_, const PartitionValues & partition_columns_)
    : ISimpleTransform(input_, output_, true), partition_column_values(partition_columns_)
{
    /// Index the values by column name for lookup in transform().
    /// operator[] is used on purpose: if a name occurs twice, the last value wins.
    for (const auto & [name, value] : partition_column_values)
        partition_columns[name] = value;
}

ColumnPtr PartitionColumnFillingTransform::createPartitionColumn()
/// In the case that a partition column is wrapped by Nullable, we need to keep the data type the same
/// as the input.
/// NOTE(review): only Nullable is handled here; a LowCardinality wrapper is not re-applied — confirm
/// whether that case can occur.
///
/// @param nested_col          column holding the partition value for every row.
/// @param original_data_type  declared type of the partition column.
/// @return `nested_col` wrapped into a ColumnNullable when the declared type is Nullable,
///         otherwise `nested_col` itself.
ColumnPtr PartitionColumnFillingTransform::tryWrapPartitionColumn(const ColumnPtr & nested_col, DataTypePtr original_data_type)
{
    if (original_data_type->getTypeId() != TypeIndex::Nullable)
        return nested_col;

    /// The null map needs one entry per row of the nested column; an empty map would leave the
    /// nullable column with mismatching sizes. All entries are 0, i.e. no row is NULL.
    return ColumnNullable::create(nested_col, ColumnUInt8::create(nested_col->size(), 0));
}

ColumnPtr PartitionColumnFillingTransform::createPartitionColumn(const String & parition_col, const String & partition_col_value, size_t rows)
{
ColumnPtr result;
auto partition_col_type = output.getHeader().getByName(parition_col).type;
DataTypePtr nested_type = partition_col_type;
if (const DataTypeNullable * nullable_type = checkAndGetDataType<DataTypeNullable>(partition_col_type.get()))
{
Expand All @@ -68,56 +88,80 @@ ColumnPtr PartitionColumnFillingTransform::createPartitionColumn()
WhichDataType which(nested_type);
if (which.isInt8())
{
result = createIntPartitionColumn<Int8>(partition_col_type, partition_col_value);
result = createIntPartitionColumn<Int8>(nested_type, partition_col_value, rows);
}
else if (which.isInt16())
{
result = createIntPartitionColumn<Int16>(partition_col_type, partition_col_value);
result = createIntPartitionColumn<Int16>(nested_type, partition_col_value, rows);
}
else if (which.isInt32())
{
result = createIntPartitionColumn<Int32>(partition_col_type, partition_col_value);
result = createIntPartitionColumn<Int32>(nested_type, partition_col_value, rows);
}
else if (which.isInt64())
{
result = createIntPartitionColumn<Int64>(partition_col_type, partition_col_value);
result = createIntPartitionColumn<Int64>(nested_type, partition_col_value, rows);
}
else if (which.isFloat32())
{
result = createFloatPartitionColumn<Float32>(partition_col_type, partition_col_value);
result = createFloatPartitionColumn<Float32>(nested_type, partition_col_value, rows);
}
else if (which.isFloat64())
{
result = createFloatPartitionColumn<Float64>(partition_col_type, partition_col_value);
result = createFloatPartitionColumn<Float64>(nested_type, partition_col_value, rows);
}
else if (which.isDate())
{
DayNum value;
auto value_buffer = ReadBufferFromString(partition_col_value);
readDateText(value, value_buffer);
result = partition_col_type->createColumnConst(1, value);
result = nested_type->createColumnConst(rows, value);
}
else if (which.isDate32())
{
ExtendedDayNum value;
auto value_buffer = ReadBufferFromString(partition_col_value);
readDateText(value, value_buffer);
result = nested_type->createColumnConst(rows, value.toUnderType());
}
else if (which.isString())
{
result = partition_col_type->createColumnConst(1, partition_col_value);
result = nested_type->createColumnConst(rows, partition_col_value);
}
else
{
throw Exception(ErrorCodes::UNKNOWN_TYPE, "unsupported datatype {}", partition_col_type->getFamilyName());
}
result = tryWrapPartitionColumn(result, partition_col_type);
return result;
}

/// Rebuilds `chunk` so that its columns follow the order of the output header.
///
/// Columns whose name exists in the input header are taken from the incoming chunk; every other
/// output column is treated as a partition column and is generated from the stored partition value.
///
/// @param chunk  chunk to rewrite in place; its row count is preserved.
/// @throws Exception (LOGICAL_ERROR) when an output column is neither in the input header
///         nor among the known partition columns.
void PartitionColumnFillingTransform::transform(DB::Chunk & chunk)
{
    const auto rows = chunk.getNumRows();
    const auto input_cols = chunk.detachColumns();
    /// Bind the headers by reference: copying a Block for every chunk is unnecessary.
    const auto & input_header = input.getHeader();
    const auto & output_header = output.getHeader();

    Columns result_cols;
    result_cols.reserve(output_header.columns());
    for (const auto & output_col : output_header)
    {
        if (input_header.has(output_col.name))
        {
            const size_t pos = input_header.getPositionByName(output_col.name);
            result_cols.push_back(input_cols[pos]);
            continue;
        }

        /// Not part of the input, so it has to be a partition column.
        const auto it = partition_columns.find(output_col.name);
        if (it == partition_columns.end())
            throw Exception(ErrorCodes::LOGICAL_ERROR, "Not found column({}) in parition columns", output_col.name);
        result_cols.emplace_back(createPartitionColumn(it->first, it->second, rows));
    }
    chunk = DB::Chunk(std::move(result_cols), rows);
}
}
16 changes: 9 additions & 7 deletions utils/local-engine/Operator/PartitionColumnFillingTransform.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#pragma once

#include <Processors/ISimpleTransform.h>
#include "Common/StringUtils.h"
#include "Columns/IColumn.h"
#include "Core/Block.h"
#include "DataTypes/Serializations/ISerialization.h"

namespace local_engine
{
Expand All @@ -10,21 +14,19 @@ class PartitionColumnFillingTransform : public DB::ISimpleTransform
PartitionColumnFillingTransform(
const DB::Block & input_,
const DB::Block & output_,
const String & partition_col_name_,
const String & partition_col_value_);
const PartitionValues & partition_columns_);
void transform(DB::Chunk & chunk) override;
String getName() const override
{
return "PartitionColumnFillingTransform";
}

private:
DB::ColumnPtr createPartitionColumn();
DB::ColumnPtr createPartitionColumn(const String & parition_col, const String & partition_col_value, size_t row);
static DB::ColumnPtr tryWrapPartitionColumn(const DB::ColumnPtr & nested_col, DB::DataTypePtr original_data_type);

DB::DataTypePtr partition_col_type;
String partition_col_name;
String partition_col_value;
DB::ColumnPtr partition_column;
PartitionValues partition_column_values;
std::map<String, String> partition_columns;
};

}
Expand Down
9 changes: 8 additions & 1 deletion utils/local-engine/Parser/CHColumnToSparkRow.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
#include "CHColumnToSparkRow.h"
#include <cstdint>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnString.h>
#include <Columns/IColumn.h>
#include <DataTypes/DataTypeNullable.h>
#include <Core/Types.h>
#include <DataTypes/DataTypesDecimal.h>
#include "DataTypes/Serializations/ISerialization.h"
#include "base/types.h"
#include <Functions/FunctionHelpers.h>


namespace DB
{
Expand Down Expand Up @@ -106,12 +111,14 @@ void writeValue(
std::vector<int64_t> & buffer_cursor)
{
ColumnPtr nested_col = col.column;

const auto * nullable_column = checkAndGetColumn<ColumnNullable>(*col.column);
if (nullable_column)
{
nested_col = nullable_column->getNestedColumnPtr();
}
nested_col = nested_col->convertToFullColumnIfConst();

WhichDataType which(nested_col->getDataType());
if (which.isUInt8())
{
Expand Down Expand Up @@ -181,7 +188,7 @@ void writeValue(
}
else
{
throw Exception(ErrorCodes::UNKNOWN_TYPE, "doesn't support type {} convert from ch to spark" ,magic_enum::enum_name(nested_col->getDataType()));
throw Exception(ErrorCodes::UNKNOWN_TYPE, "doesn't support type {} convert from ch to spark", col.type->getName());
}
}

Expand Down
33 changes: 20 additions & 13 deletions utils/local-engine/Parser/SerializedPlanParser.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <base/logger_useful.h>
#include "SerializedPlanParser.h"
#include <memory>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/registerAggregateFunctions.h>
#include <Builder/BroadCastJoinBuilder.h>
Expand Down Expand Up @@ -41,7 +42,7 @@
#include <Common/MergeTreeTool.h>
#include <Common/StringUtils.h>

#include "SerializedPlanParser.h"
#include <google/protobuf/util/json_util.h>

namespace DB
{
Expand Down Expand Up @@ -197,19 +198,14 @@ QueryPlanPtr SerializedPlanParser::parseReadRealWithLocalFile(const substrait::R
}
auto header = parseNameStruct(rel.base_schema());
PartitionValues partition_values = StringUtils::parsePartitionTablePath(files_info->files[0]);
if (partition_values.size() > 1)
{
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "doesn't support multiple level partition.");
}
ProcessorPtr partition_transform;
if (!partition_values.empty())

auto origin_header = header.cloneEmpty();
for (const auto & partition_value : partition_values)
{
auto origin_header = header.cloneEmpty();
PartitionValue partition_value = partition_values[0];
header.erase(partition_value.first);
partition_transform
= std::make_shared<PartitionColumnFillingTransform>(header, origin_header, partition_value.first, partition_value.second);
}
ProcessorPtr partition_transform = std::make_shared<PartitionColumnFillingTransform>(header, origin_header, partition_values);

auto query_plan = std::make_unique<QueryPlan>();
std::shared_ptr<IProcessor> source = std::make_shared<BatchParquetFileSource>(files_info, header, context);
auto source_pipe = Pipe(source);
Expand Down Expand Up @@ -1281,7 +1277,18 @@ QueryPlanPtr SerializedPlanParser::parse(std::string & plan)
{
auto plan_ptr = std::make_unique<substrait::Plan>();
plan_ptr->ParseFromString(plan);
LOG_DEBUG(&Poco::Logger::get("SerializedPlanParser"), "parse plan \n{}", plan_ptr->DebugString());

auto printPlan = [](const std::string & plan_raw){
substrait::Plan plan;
plan.ParseFromString(plan_raw);
std::string json_ret;
google::protobuf::util::JsonPrintOptions json_opt;
json_opt.add_whitespace = true;
google::protobuf::util::MessageToJsonString(plan, &json_ret, json_opt);
return json_ret;
};

LOG_DEBUG(&Poco::Logger::get("SerializedPlanParser"), "parse plan \n{}", printPlan(plan));
return parse(std::move(plan_ptr));
}
void SerializedPlanParser::initFunctionEnv()
Expand Down
Loading