Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -170,4 +170,76 @@ class GlutenClickHouseJoinSuite extends GlutenClickHouseWholeStageTransformerSui
}
}

test("GLUTEN-10961 cross join with empty join clause") {
val crossSql1 =
"""
|select a, b from (select id as a from range(1) )
|cross join (
| select id as b from range(2)
|);
|""".stripMargin
compareResultsAgainstVanillaSpark(crossSql1, true, { _ => })

val crossSql2 =
"""
|select a, b from (select id as a from range(1) where id > 1 )
|cross join (
| select id as b from range(2)
|);
|""".stripMargin
compareResultsAgainstVanillaSpark(crossSql2, true, { _ => })

val fullSql1 =
"""
|select a, b from (select id as a from range(1) where id > 1)
|full join (
| select id as b from range(2)
|)
|""".stripMargin
compareResultsAgainstVanillaSpark(fullSql1, true, { _ => })

val fullSql2 =
"""
|select a, b from (select id as a from range(1) )
|full join (
| select id as b from range(2)
|)
|""".stripMargin
compareResultsAgainstVanillaSpark(fullSql2, true, { _ => })

val innerSql1 =
"""
|select a, b from (select id as a from range(1) where id > 1)
|inner join (
| select id as b from range(2)
|)
|""".stripMargin
compareResultsAgainstVanillaSpark(innerSql1, true, { _ => })
val innerSql2 =
"""
|select a, b from (select id as a from range(1) )
|inner join (
| select id as b from range(2)
|)
|""".stripMargin
compareResultsAgainstVanillaSpark(innerSql2, true, { _ => })

val leftSql1 =
"""
|select a, b from (select id as a from range(1) where id > 1)
|left join (
| select id as b from range(2)
|)
|""".stripMargin
compareResultsAgainstVanillaSpark(leftSql1, true, { _ => })
val leftSql2 =
"""
|select a, b from (select id as a from range(1) )
|left join (
| select id as b from range(2)
|)
|""".stripMargin
compareResultsAgainstVanillaSpark(leftSql2, true, { _ => })
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -464,11 +464,24 @@ class GlutenClickHouseTPCHSuite extends MergeTreeSuite {
| insert into cross_join_t
| select id as a, cast(id as string) as b,
| concat('1231231232323232322', cast(id as string)) as c
| from range(0, 100000)
| from range(0, 10000)
|""".stripMargin
spark.sql(sql)
sql = """
| select * from cross_join_t as t1 full join cross_join_t as t2 limit 10
| insert into cross_join_t
| select id as a, cast(id as string) as b,
| concat('1231231232323232322', cast(id as string)) as c
| from range(10000, 20000)
|""".stripMargin
spark.sql(sql)
sql = """
|select * from (
| select a as a1, b as b1, c as c1 from cross_join_t
|) as t1 full join (
| select a as a2, b as b2, c as c2 from cross_join_t
|) as t2
|order by a1, b1, c1, a2, b2, c2
|limit 10
|""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => })
spark.sql("drop table cross_join_t")
Expand Down
25 changes: 18 additions & 7 deletions cpp-ch/local-engine/Common/CHUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1048,16 +1048,25 @@ UInt64 MemoryUtil::getMemoryRSS()
return rss * sysconf(_SC_PAGESIZE);
}


void JoinUtil::reorderJoinOutput(DB::QueryPlan & plan, DB::Names cols)
void JoinUtil::adjustJoinOutput(DB::QueryPlan & plan, DB::Names cols)
{
ActionsDAG project{plan.getCurrentHeader()->getNamesAndTypesList()};
NamesWithAliases project_cols;
auto header = plan.getCurrentHeader();
std::unordered_map<String, const DB::ActionsDAG::Node *> name_to_node;
ActionsDAG project;
for (const auto & col : header->getColumnsWithTypeAndName())
{
const auto * node = &(project.addInput(col));
name_to_node[col.name] = node;
}
for (const auto & col : cols)
{
project_cols.emplace_back(NameWithAlias(col, col));
const auto it = name_to_node.find(col);
if (it == name_to_node.end())
{
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Column {} not found in header", col);
}
project.addOrReplaceInOutputs(*(it->second));
}
project.project(project_cols);
QueryPlanStepPtr project_step = std::make_unique<ExpressionStep>(plan.getCurrentHeader(), std::move(project));
project_step->setStepDescription("Reorder Join Output");
plan.addStep(std::move(project_step));
Expand Down Expand Up @@ -1097,9 +1106,11 @@ std::pair<DB::JoinKind, DB::JoinStrictness> JoinUtil::getCrossJoinKindAndStrictn
switch (join_type)
{
case substrait::CrossRel_JoinType_JOIN_TYPE_INNER:
return {DB::JoinKind::Cross, DB::JoinStrictness::All};
case substrait::CrossRel_JoinType_JOIN_TYPE_LEFT:
return {DB::JoinKind::Left, DB::JoinStrictness::All};
case substrait::CrossRel_JoinType_JOIN_TYPE_OUTER:
return {DB::JoinKind::Cross, DB::JoinStrictness::All};
return {DB::JoinKind::Full, DB::JoinStrictness::All};
default:
throw Exception(ErrorCodes::UNKNOWN_TYPE, "unsupported join type {}.", magic_enum::enum_name(join_type));
}
Expand Down
8 changes: 6 additions & 2 deletions cpp-ch/local-engine/Common/CHUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class BlockUtil
{
public:
static constexpr auto VIRTUAL_ROW_COUNT_COLUMN = "__VIRTUAL_ROW_COUNT_COLUMN__";
static constexpr auto RIHGT_COLUMN_PREFIX = "broadcast_right_";
static constexpr auto RIGHT_COLUMN_PREFIX = "broadcast_right_";

// Build a header block with a virtual column which will be
// use to indicate the number of rows in a block.
Expand Down Expand Up @@ -249,7 +249,11 @@ class MemoryUtil
class JoinUtil
{
public:
static void reorderJoinOutput(DB::QueryPlan & plan, DB::Names cols);
static constexpr auto CROSS_REL_LEFT_CONST_KEY_COLUMN = "__CROSS_REL_LEFT_CONST_KEY_COLUMN__";
static constexpr auto CROSS_REL_RIGHT_CONST_KEY_COLUMN = "__CROSS_REL_RIGHT_CONST_KEY_COLUMN__";

// Keep necessarily columns and reorder them according to cols
static void adjustJoinOutput(DB::QueryPlan & plan, DB::Names cols);
static std::pair<DB::JoinKind, DB::JoinStrictness>
getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type, bool is_existence_join);
static std::pair<DB::JoinKind, DB::JoinStrictness> getCrossJoinKindAndStrictness(substrait::CrossRel_JoinType join_type);
Expand Down
107 changes: 78 additions & 29 deletions cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <Common/CHUtil.h>
#include <Common/JNIUtils.h>
#include <Common/logger_useful.h>
#include <DataTypes/DataTypesNumber.h>

namespace DB
{
Expand Down Expand Up @@ -67,12 +68,12 @@ DB::Block resetBuildTableBlockName(Block & block, bool only_one = false)
// add a sequence to avoid duplicate name in some rare cases
if (names.find(col.name) == names.end())
{
new_name << BlockUtil::RIHGT_COLUMN_PREFIX << col.name;
new_name << BlockUtil::RIGHT_COLUMN_PREFIX << col.name;
names.insert(col.name);
}
else
{
new_name << BlockUtil::RIHGT_COLUMN_PREFIX << (seq++) << "_" << col.name;
new_name << BlockUtil::RIGHT_COLUMN_PREFIX << (seq++) << "_" << col.name;
}
new_cols.emplace_back(col.column, col.type, new_name.str());

Expand Down Expand Up @@ -108,6 +109,51 @@ std::shared_ptr<StorageJoinFromReadBuffer> getJoin(const std::string & key)
return wrapper;
}

// A join in cross rel.
static bool isCrossRelJoin(const std::string & key)
{
return key.starts_with("BuiltBNLJBroadcastTable-");
}

static void collectBlocksForCountingRows(NativeReader & block_stream, Block & header, Blocks & result)
{
ProfileInfo profile;
Block block = block_stream.read();
while (!block.empty())
{
const auto & col = block.getByPosition(0);
auto counting_col = BlockUtil::buildRowCountBlock(col.column->size()).getColumnsWithTypeAndName()[0];
DB::ColumnsWithTypeAndName columns;
columns.emplace_back(counting_col.column->convertToFullColumnIfConst(), counting_col.type, counting_col.name);
DB::Block new_block(columns);
profile.update(new_block);
result.emplace_back(std::move(new_block));
block = block_stream.read();
}
header = BlockUtil::buildRowCountHeader();
}

static void collectBlocksForJoinRel(NativeReader & reader, Block & header, Blocks & result)
{
ProfileInfo profile;
Block block = reader.read();
while (!block.empty())
{
DB::ColumnsWithTypeAndName columns;
for (size_t i = 0; i < block.columns(); ++i)
{
const auto & column = block.getByPosition(i);
columns.emplace_back(BlockUtil::convertColumnAsNecessary(column, header.getByPosition(i)));
}

DB::Block final_block(columns);
profile.update(final_block);
result.emplace_back(std::move(final_block));

block = reader.read();
}
}

std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
const std::string & key,
DB::ReadBuffer & input,
Expand All @@ -123,12 +169,14 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
auto join_key_list = Poco::StringTokenizer(join_keys, ",");
Names key_names;
for (const auto & key_name : join_key_list)
key_names.emplace_back(BlockUtil::RIHGT_COLUMN_PREFIX + key_name);
key_names.emplace_back(BlockUtil::RIGHT_COLUMN_PREFIX + key_name);

DB::JoinKind kind;
DB::JoinStrictness strictness;
bool is_cross_rel_join = isCrossRelJoin(key);
assert(is_cross_rel_join && key_names.empty()); // cross rel join should not have join keys

if (key.starts_with("BuiltBNLJBroadcastTable-"))
if (is_cross_rel_join)
std::tie(kind, strictness) = JoinUtil::getCrossJoinKindAndStrictness(static_cast<substrait::CrossRel_JoinType>(join_type));
else
std::tie(kind, strictness) = JoinUtil::getJoinKindAndStrictness(static_cast<substrait::JoinRel_JoinType>(join_type), is_existence_join);
Expand All @@ -139,40 +187,41 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
Block header = TypeParser::buildBlockFromNamedStruct(substrait_struct);
header = resetBuildTableBlockName(header);

bool only_one_column = header.getNamesAndTypesList().empty();
if (only_one_column)
header = BlockUtil::buildRowCountBlock(0).getColumnsWithTypeAndName();

Blocks data;
auto collect_data = [&]
auto collect_data = [&]()
{
bool only_one_column = header.getNamesAndTypesList().empty();
NativeReader block_stream(input);
if (only_one_column)
header = BlockUtil::buildRowCountBlock(0).getColumnsWithTypeAndName();
collectBlocksForCountingRows(block_stream, header, data);
else
collectBlocksForJoinRel(block_stream, header, data);

NativeReader block_stream(input);
ProfileInfo info;
Block block = block_stream.read();
while (!block.empty())
// For not cross join, we need to add a constant join key column
// to make it behavior like a normal join.
if (is_cross_rel_join && kind != JoinKind::Cross)
{
DB::ColumnsWithTypeAndName columns;
for (size_t i = 0; i < block.columns(); ++i)
auto data_type_u8 = std::make_shared<DataTypeUInt8>();
UInt8 const_key_val = 0;
String const_key_name = JoinUtil::CROSS_REL_RIGHT_CONST_KEY_COLUMN;
Blocks new_data;
for (const auto & block : data)
{
const auto & column = block.getByPosition(i);
if (only_one_column)
{
auto virtual_block = BlockUtil::buildRowCountBlock(column.column->size()).getColumnsWithTypeAndName();
header = virtual_block;
columns.emplace_back(virtual_block.back());
break;
}

columns.emplace_back(BlockUtil::convertColumnAsNecessary(column, header.getByPosition(i)));
auto cols = block.getColumnsWithTypeAndName();
cols.emplace_back(data_type_u8->createColumnConst(block.rows(), const_key_val), data_type_u8, const_key_name);
new_data.emplace_back(Block(cols));
}

DB::Block final_block(columns);
info.update(final_block);
data.emplace_back(std::move(final_block));

block = block_stream.read();
data.swap(new_data);
key_names.emplace_back(const_key_name);
auto cols = header.getColumnsWithTypeAndName();
cols.emplace_back(data_type_u8->createColumnConst(0, const_key_val), data_type_u8, const_key_name);
header = Block(cols);
}
};

/// Record memory usage in Total Memory Tracker
ThreadFromGlobalPoolNoTracingContextPropagation thread(collect_data);
thread.join();
Expand Down
1 change: 0 additions & 1 deletion cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ void StorageJoinFromReadBuffer::buildJoinLazily(const DB::SharedHeader & header,
thread.join();
}


/// The column names of 'right_header' could be different from the ones in `input_blocks`, and we must
/// use 'right_header' to build the HashJoin. Otherwise, it will cause exceptions with name mismatches.
///
Expand Down
Loading