diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseJoinSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseJoinSuite.scala index 844b12def2b3..1404e83214f8 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseJoinSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseJoinSuite.scala @@ -170,4 +170,76 @@ class GlutenClickHouseJoinSuite extends GlutenClickHouseWholeStageTransformerSui } } + test("GLUTEN-10961 cross join with empty join clause") { + val crossSql1 = + """ + |select a, b from (select id as a from range(1) ) + |cross join ( + | select id as b from range(2) + |); + |""".stripMargin + compareResultsAgainstVanillaSpark(crossSql1, true, { _ => }) + + val crossSql2 = + """ + |select a, b from (select id as a from range(1) where id > 1 ) + |cross join ( + | select id as b from range(2) + |); + |""".stripMargin + compareResultsAgainstVanillaSpark(crossSql2, true, { _ => }) + + val fullSql1 = + """ + |select a, b from (select id as a from range(1) where id > 1) + |full join ( + | select id as b from range(2) + |) + |""".stripMargin + compareResultsAgainstVanillaSpark(fullSql1, true, { _ => }) + + val fullSql2 = + """ + |select a, b from (select id as a from range(1) ) + |full join ( + | select id as b from range(2) + |) + |""".stripMargin + compareResultsAgainstVanillaSpark(fullSql2, true, { _ => }) + + val innerSql1 = + """ + |select a, b from (select id as a from range(1) where id > 1) + |inner join ( + | select id as b from range(2) + |) + |""".stripMargin + compareResultsAgainstVanillaSpark(innerSql1, true, { _ => }) + val innerSql2 = + """ + |select a, b from (select id as a from range(1) ) + |inner join ( + | select id as b from range(2) + |) + |""".stripMargin + compareResultsAgainstVanillaSpark(innerSql2, true, { _ => }) + + val leftSql1 = + """ + |select a, b from (select id as a from range(1) where id > 1) + |left join ( + | select id as b from range(2) + |) + |""".stripMargin + compareResultsAgainstVanillaSpark(leftSql1, true, { _ => }) + val leftSql2 = + """ + |select a, b from (select id as a from range(1) ) + |left join ( + | select id as b from range(2) + |) + |""".stripMargin + compareResultsAgainstVanillaSpark(leftSql2, true, { _ => }) + } + } diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala index 9947609c8e8c..87cbc8749068 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala @@ -464,11 +464,24 @@ class GlutenClickHouseTPCHSuite extends MergeTreeSuite { | insert into cross_join_t | select id as a, cast(id as string) as b, | concat('1231231232323232322', cast(id as string)) as c - | from range(0, 100000) + | from range(0, 10000) |""".stripMargin spark.sql(sql) sql = """ - | select * from cross_join_t as t1 full join cross_join_t as t2 limit 10 + | insert into cross_join_t + | select id as a, cast(id as string) as b, + | concat('1231231232323232322', cast(id as string)) as c + | from range(10000, 20000) + |""".stripMargin + spark.sql(sql) + sql = """ + |select * from ( + | select a as a1, b as b1, c as c1 from cross_join_t + |) as t1 full join ( + | select a as a2, b as b2, c as c2 from cross_join_t + |) as t2 + |order by a1, b1, c1, a2, b2, c2 + |limit 10 |""".stripMargin compareResultsAgainstVanillaSpark(sql, true, { _ => }) spark.sql("drop table cross_join_t") diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index 30c872eda306..a7c0df4777f4 100644 --- a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -1048,16 +1048,25 @@ UInt64 MemoryUtil::getMemoryRSS() return rss * sysconf(_SC_PAGESIZE); } - -void JoinUtil::reorderJoinOutput(DB::QueryPlan & plan, DB::Names cols) +void JoinUtil::adjustJoinOutput(DB::QueryPlan & plan, DB::Names cols) { - ActionsDAG project{plan.getCurrentHeader()->getNamesAndTypesList()}; - NamesWithAliases project_cols; + auto header = plan.getCurrentHeader(); + std::unordered_map name_to_node; + ActionsDAG project; + for (const auto & col : header->getColumnsWithTypeAndName()) + { + const auto * node = &(project.addInput(col)); + name_to_node[col.name] = node; + } for (const auto & col : cols) { - project_cols.emplace_back(NameWithAlias(col, col)); + const auto it = name_to_node.find(col); + if (it == name_to_node.end()) + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Column {} not found in header", col); + } + project.addOrReplaceInOutputs(*(it->second)); } - project.project(project_cols); QueryPlanStepPtr project_step = std::make_unique(plan.getCurrentHeader(), std::move(project)); project_step->setStepDescription("Reorder Join Output"); plan.addStep(std::move(project_step)); @@ -1097,9 +1106,11 @@ std::pair JoinUtil::getCrossJoinKindAndStrictn switch (join_type) { case substrait::CrossRel_JoinType_JOIN_TYPE_INNER: + return {DB::JoinKind::Cross, DB::JoinStrictness::All}; case substrait::CrossRel_JoinType_JOIN_TYPE_LEFT: + return {DB::JoinKind::Left, DB::JoinStrictness::All}; case substrait::CrossRel_JoinType_JOIN_TYPE_OUTER: - return {DB::JoinKind::Cross, DB::JoinStrictness::All}; + return {DB::JoinKind::Full, DB::JoinStrictness::All}; default: throw Exception(ErrorCodes::UNKNOWN_TYPE, "unsupported join type {}.", magic_enum::enum_name(join_type)); } diff --git a/cpp-ch/local-engine/Common/CHUtil.h b/cpp-ch/local-engine/Common/CHUtil.h index b7cd75524b3b..709af367ea9c 100644 --- a/cpp-ch/local-engine/Common/CHUtil.h +++ b/cpp-ch/local-engine/Common/CHUtil.h @@ -52,7 +52,7 @@ class BlockUtil { public: static constexpr auto VIRTUAL_ROW_COUNT_COLUMN = "__VIRTUAL_ROW_COUNT_COLUMN__"; - static constexpr auto RIHGT_COLUMN_PREFIX = "broadcast_right_"; + static constexpr auto RIGHT_COLUMN_PREFIX = "broadcast_right_"; // Build a header block with a virtual column which will be // use to indicate the number of rows in a block. @@ -249,7 +249,11 @@ class MemoryUtil class JoinUtil { public: - static void reorderJoinOutput(DB::QueryPlan & plan, DB::Names cols); + static constexpr auto CROSS_REL_LEFT_CONST_KEY_COLUMN = "__CROSS_REL_LEFT_CONST_KEY_COLUMN__"; + static constexpr auto CROSS_REL_RIGHT_CONST_KEY_COLUMN = "__CROSS_REL_RIGHT_CONST_KEY_COLUMN__"; + + // Keep necessarily columns and reorder them according to cols + static void adjustJoinOutput(DB::QueryPlan & plan, DB::Names cols); static std::pair getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type, bool is_existence_join); static std::pair getCrossJoinKindAndStrictness(substrait::CrossRel_JoinType join_type); diff --git a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp index 9a83f08b5438..8e46556e3d68 100644 --- a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp +++ b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp @@ -29,6 +29,7 @@ #include #include #include +#include namespace DB { @@ -67,12 +68,12 @@ DB::Block resetBuildTableBlockName(Block & block, bool only_one = false) // add a sequence to avoid duplicate name in some rare cases if (names.find(col.name) == names.end()) { - new_name << BlockUtil::RIHGT_COLUMN_PREFIX << col.name; + new_name << BlockUtil::RIGHT_COLUMN_PREFIX << col.name; names.insert(col.name); } else { - new_name << BlockUtil::RIHGT_COLUMN_PREFIX << (seq++) << "_" << col.name; + new_name << BlockUtil::RIGHT_COLUMN_PREFIX << (seq++) << "_" << col.name; } new_cols.emplace_back(col.column, col.type, new_name.str()); @@ -108,6 +109,51 @@ std::shared_ptr getJoin(const std::string & key) return wrapper; } +// A join in cross rel. +static bool isCrossRelJoin(const std::string & key) +{ + return key.starts_with("BuiltBNLJBroadcastTable-"); +} + +static void collectBlocksForCountingRows(NativeReader & block_stream, Block & header, Blocks & result) +{ + ProfileInfo profile; + Block block = block_stream.read(); + while (!block.empty()) + { + const auto & col = block.getByPosition(0); + auto counting_col = BlockUtil::buildRowCountBlock(col.column->size()).getColumnsWithTypeAndName()[0]; + DB::ColumnsWithTypeAndName columns; + columns.emplace_back(counting_col.column->convertToFullColumnIfConst(), counting_col.type, counting_col.name); + DB::Block new_block(columns); + profile.update(new_block); + result.emplace_back(std::move(new_block)); + block = block_stream.read(); + } + header = BlockUtil::buildRowCountHeader(); +} + +static void collectBlocksForJoinRel(NativeReader & reader, Block & header, Blocks & result) +{ + ProfileInfo profile; + Block block = reader.read(); + while (!block.empty()) + { + DB::ColumnsWithTypeAndName columns; + for (size_t i = 0; i < block.columns(); ++i) + { + const auto & column = block.getByPosition(i); + columns.emplace_back(BlockUtil::convertColumnAsNecessary(column, header.getByPosition(i))); + } + + DB::Block final_block(columns); + profile.update(final_block); + result.emplace_back(std::move(final_block)); + + block = reader.read(); + } +} + std::shared_ptr buildJoin( const std::string & key, DB::ReadBuffer & input, @@ -123,12 +169,14 @@ std::shared_ptr buildJoin( auto join_key_list = Poco::StringTokenizer(join_keys, ","); Names key_names; for (const auto & key_name : join_key_list) - key_names.emplace_back(BlockUtil::RIHGT_COLUMN_PREFIX + key_name); + key_names.emplace_back(BlockUtil::RIGHT_COLUMN_PREFIX + key_name); DB::JoinKind kind; DB::JoinStrictness strictness; + bool is_cross_rel_join = isCrossRelJoin(key); + assert(is_cross_rel_join && key_names.empty()); // cross rel join should not have join keys - if (key.starts_with("BuiltBNLJBroadcastTable-")) + if (is_cross_rel_join) std::tie(kind, strictness) = JoinUtil::getCrossJoinKindAndStrictness(static_cast(join_type)); else std::tie(kind, strictness) = JoinUtil::getJoinKindAndStrictness(static_cast(join_type), is_existence_join); @@ -139,40 +187,41 @@ std::shared_ptr buildJoin( Block header = TypeParser::buildBlockFromNamedStruct(substrait_struct); header = resetBuildTableBlockName(header); + bool only_one_column = header.getNamesAndTypesList().empty(); + if (only_one_column) + header = BlockUtil::buildRowCountBlock(0).getColumnsWithTypeAndName(); + Blocks data; - auto collect_data = [&] + auto collect_data = [&]() { - bool only_one_column = header.getNamesAndTypesList().empty(); + NativeReader block_stream(input); if (only_one_column) - header = BlockUtil::buildRowCountBlock(0).getColumnsWithTypeAndName(); + collectBlocksForCountingRows(block_stream, header, data); + else + collectBlocksForJoinRel(block_stream, header, data); - NativeReader block_stream(input); - ProfileInfo info; - Block block = block_stream.read(); - while (!block.empty()) + // For not cross join, we need to add a constant join key column + // to make it behavior like a normal join. + if (is_cross_rel_join && kind != JoinKind::Cross) { - DB::ColumnsWithTypeAndName columns; - for (size_t i = 0; i < block.columns(); ++i) + auto data_type_u8 = std::make_shared(); + UInt8 const_key_val = 0; + String const_key_name = JoinUtil::CROSS_REL_RIGHT_CONST_KEY_COLUMN; + Blocks new_data; + for (const auto & block : data) { - const auto & column = block.getByPosition(i); - if (only_one_column) - { - auto virtual_block = BlockUtil::buildRowCountBlock(column.column->size()).getColumnsWithTypeAndName(); - header = virtual_block; - columns.emplace_back(virtual_block.back()); - break; - } - - columns.emplace_back(BlockUtil::convertColumnAsNecessary(column, header.getByPosition(i))); + auto cols = block.getColumnsWithTypeAndName(); + cols.emplace_back(data_type_u8->createColumnConst(block.rows(), const_key_val), data_type_u8, const_key_name); + new_data.emplace_back(Block(cols)); } - - DB::Block final_block(columns); - info.update(final_block); - data.emplace_back(std::move(final_block)); - - block = block_stream.read(); + data.swap(new_data); + key_names.emplace_back(const_key_name); + auto cols = header.getColumnsWithTypeAndName(); + cols.emplace_back(data_type_u8->createColumnConst(0, const_key_val), data_type_u8, const_key_name); + header = Block(cols); } }; + /// Record memory usage in Total Memory Tracker ThreadFromGlobalPoolNoTracingContextPropagation thread(collect_data); thread.join(); diff --git a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp index eb7e5a9a0ad1..889e3a6e87fc 100644 --- a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp +++ b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp @@ -154,7 +154,6 @@ void StorageJoinFromReadBuffer::buildJoinLazily(const DB::SharedHeader & header, thread.join(); } - /// The column names of 'right_header' could be different from the ones in `input_blocks`, and we must /// use 'right_header' to build the HashJoin. Otherwise, it will cause exceptions with name mismatches. /// diff --git a/cpp-ch/local-engine/Parser/RelParsers/CrossRelParser.cpp b/cpp-ch/local-engine/Parser/RelParsers/CrossRelParser.cpp index 3f05526cbfa5..ce38f7c43a9f 100644 --- a/cpp-ch/local-engine/Parser/RelParsers/CrossRelParser.cpp +++ b/cpp-ch/local-engine/Parser/RelParsers/CrossRelParser.cpp @@ -35,6 +35,7 @@ #include #include #include +#include namespace DB { @@ -93,11 +94,39 @@ std::optional CrossRelParser::getSingleInput(const subst throw Exception(ErrorCodes::LOGICAL_ERROR, "join node has 2 inputs, can't call getSingleInput()."); } +// For non-cross join, CH uses constant join keys. We keep the same implementation here. +void CrossRelParser::addConstJoinKeys(DB::QueryPlan & left, DB::QueryPlan & right) +{ + auto data_type_u8 = std::make_shared(); + auto const_key_col = data_type_u8->createColumnConst(1, UInt8(0)); + + String left_key = JoinUtil::CROSS_REL_LEFT_CONST_KEY_COLUMN; + auto left_columns = left.getCurrentHeader()->getColumnsWithTypeAndName(); + DB::ActionsDAG left_project_actions(left_columns); + const auto & left_key_node = left_project_actions.addColumn({const_key_col, data_type_u8, left_key}); + left_project_actions.addOrReplaceInOutputs(left_key_node); + auto left_project_step = std::make_unique(left.getCurrentHeader(), std::move(left_project_actions)); + left_project_step->setStepDescription("Add const join key for cross rel left"); + left.addStep(std::move(left_project_step)); + + String right_key = JoinUtil::CROSS_REL_RIGHT_CONST_KEY_COLUMN; + auto right_columns = right.getCurrentHeader()->getColumnsWithTypeAndName(); + DB::ActionsDAG right_project_actions(right_columns); + const auto & right_key_node = right_project_actions.addColumn({const_key_col, data_type_u8, right_key}); + right_project_actions.addOrReplaceInOutputs(right_key_node); + auto right_project_step = std::make_unique(right.getCurrentHeader(), std::move(right_project_actions)); + right_project_step->setStepDescription("Add const join key for cross rel right"); + right.addStep(std::move(right_project_step)); +} + DB::QueryPlanPtr CrossRelParser::parse(std::vector & input_plans_, const substrait::Rel & rel, std::list &) { assert(input_plans_.size() == 2); const auto & join = rel.cross(); + std::pair kind_and_strictness = JoinUtil::getCrossJoinKindAndStrictness(join.type()); + if (kind_and_strictness.first != JoinKind::Cross) + addConstJoinKeys(*input_plans_[0], *input_plans_[1]); return parseJoin(join, std::move(input_plans_[0]), std::move(input_plans_[1])); } @@ -160,14 +189,16 @@ DB::QueryPlanPtr CrossRelParser::parseJoin(const substrait::CrossRel & join, DB: right->getCurrentHeader()->dumpNames()); } - Names after_join_names; - auto left_names = left->getCurrentHeader()->getNames(); - after_join_names.insert(after_join_names.end(), left_names.begin(), left_names.end()); - auto right_name = table_join->columnsFromJoinedTable().getNames(); - after_join_names.insert(after_join_names.end(), right_name.begin(), right_name.end()); + Names after_join_names = collectOutputColumnsName(*left, *right); - auto left_header = left->getCurrentHeader(); - auto right_header = right->getCurrentHeader(); + if (table_join->kind() != JoinKind::Cross) + { + table_join->addDisjunct(); + auto & join_clause = table_join->getClauses().back(); + String left_key = JoinUtil::CROSS_REL_LEFT_CONST_KEY_COLUMN; + String right_key = JoinUtil::CROSS_REL_RIGHT_CONST_KEY_COLUMN; + join_clause.addKey(left_key, right_key, false); + } QueryPlanPtr query_plan; if (storage_join) @@ -184,15 +215,7 @@ DB::QueryPlanPtr CrossRelParser::parseJoin(const substrait::CrossRel & join, DB: extra_plan_holder.emplace_back(std::move(right)); addPostFilter(*query_plan, join); - Names cols; - for (auto after_join_name : after_join_names) - { - if (BlockUtil::VIRTUAL_ROW_COUNT_COLUMN == after_join_name) - continue; - - cols.emplace_back(after_join_name); - } - JoinUtil::reorderJoinOutput(*query_plan, cols); + JoinUtil::adjustJoinOutput(*query_plan, after_join_names); } else { @@ -216,7 +239,7 @@ DB::QueryPlanPtr CrossRelParser::parseJoin(const substrait::CrossRel & join, DB: query_plan = std::make_unique(); query_plan->unitePlans(std::move(join_step), {std::move(plans)}); - JoinUtil::reorderJoinOutput(*query_plan, after_join_names); + JoinUtil::adjustJoinOutput(*query_plan, after_join_names); } return query_plan; @@ -318,6 +341,26 @@ void CrossRelParser::addConvertStep(TableJoin & table_join, DB::QueryPlan & left } } +DB::Names CrossRelParser::collectOutputColumnsName(const DB::QueryPlan & left, const DB::QueryPlan & right) +{ + Names join_result_names; + auto is_unused_column = [](const String & name) + { + return name == JoinUtil::CROSS_REL_LEFT_CONST_KEY_COLUMN || name == JoinUtil::CROSS_REL_RIGHT_CONST_KEY_COLUMN + || name == BlockUtil::VIRTUAL_ROW_COUNT_COLUMN; + }; + for (auto & col : left.getCurrentHeader()->getColumnsWithTypeAndName()) + { + if (!is_unused_column(col.name)) + join_result_names.emplace_back(col.name); + } + for (auto & col : right.getCurrentHeader()->getColumnsWithTypeAndName()) + { + if (!is_unused_column(col.name)) + join_result_names.emplace_back(col.name); + } + return join_result_names; +} void registerCrossRelParser(RelParserFactory & factory) { diff --git a/cpp-ch/local-engine/Parser/RelParsers/CrossRelParser.h b/cpp-ch/local-engine/Parser/RelParsers/CrossRelParser.h index d7ccc487cce9..0d5990569989 100644 --- a/cpp-ch/local-engine/Parser/RelParsers/CrossRelParser.h +++ b/cpp-ch/local-engine/Parser/RelParsers/CrossRelParser.h @@ -20,6 +20,7 @@ #include #include #include +#include namespace DB { @@ -32,6 +33,8 @@ namespace local_engine class StorageJoinFromReadBuffer; +/// Cross rel is for joins without joining keys. For example, +/// SELECT * FROM t1 LEFT JOIN t2 class CrossRelParser : public RelParser { public: @@ -62,6 +65,9 @@ class CrossRelParser : public RelParser DB::QueryPlan & left, DB::QueryPlan & right, bool allow_mixed_condition); + + void addConstJoinKeys(DB::QueryPlan & left, DB::QueryPlan & right); + DB::Names collectOutputColumnsName(const DB::QueryPlan & left, const DB::QueryPlan & right); }; } diff --git a/cpp-ch/local-engine/Parser/RelParsers/JoinRelParser.cpp b/cpp-ch/local-engine/Parser/RelParsers/JoinRelParser.cpp index 6dbd28b95f74..db96600632c6 100644 --- a/cpp-ch/local-engine/Parser/RelParsers/JoinRelParser.cpp +++ b/cpp-ch/local-engine/Parser/RelParsers/JoinRelParser.cpp @@ -358,7 +358,7 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q } } - JoinUtil::reorderJoinOutput(*query_plan, after_join_names); + JoinUtil::adjustJoinOutput(*query_plan, after_join_names); /// Need to project the right table column into boolean type if (join_opt_info.is_existence_join) existenceJoinPostProject(*query_plan, left_names);