diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala index 341a10cb94b7..c8d6da2b666a 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala @@ -3012,6 +3012,17 @@ class GlutenClickHouseTPCHSaltNullParquetSuite compareResult = true, checkWindowGroupLimit ) + + compareResultsAgainstVanillaSpark( + """ + |select * from( + |select a, b, c, row_number() over (partition by a order by b, c, a) as r + |from test_win_top) + |where r <= 1 + |""".stripMargin, + compareResult = true, + checkWindowGroupLimit + ) spark.sql("drop table if exists test_win_top") } diff --git a/cpp-ch/local-engine/Common/GlutenConfig.h b/cpp-ch/local-engine/Common/GlutenConfig.h index 7b2a33a9b6da..0f68151c03c7 100644 --- a/cpp-ch/local-engine/Common/GlutenConfig.h +++ b/cpp-ch/local-engine/Common/GlutenConfig.h @@ -164,8 +164,8 @@ struct WindowConfig public: inline static const String WINDOW_AGGREGATE_TOPK_SAMPLE_ROWS = "window.aggregate_topk_sample_rows"; inline static const String WINDOW_AGGREGATE_TOPK_HIGH_CARDINALITY_THRESHOLD = "window.aggregate_topk_high_cardinality_threshold"; - size_t aggregate_topk_sample_rows = 5000; - double aggregate_topk_high_cardinality_threshold = 0.6; + size_t aggregate_topk_sample_rows = 50000; + double aggregate_topk_high_cardinality_threshold = 0.4; static WindowConfig loadFromContext(const DB::ContextPtr & context); }; diff --git a/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.cpp b/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.cpp index 01c778446051..7dcdbe6430e8 100644 --- a/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.cpp +++ b/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.cpp @@ -16,6 +16,7 @@ */ #include "GroupLimitRelParser.h" +#include #include #include #include @@ -46,12 +47,14 @@ #include #include #include +#include "Common/Logger.h" #include #include #include #include #include #include +#include "cctz/civil_time_detail.h" namespace DB::ErrorCodes { @@ -226,6 +229,7 @@ DB::QueryPlanPtr AggregateGroupLimitRelParser::parse( // If all partition keys are low cardinality keys, use aggregattion to get topk of each partition auto aggregation_plan = BranchStepHelper::createSubPlan(branch_in_header, 1); + collectPartitionAndSortFields(); prePrejectionForAggregateArguments(*aggregation_plan); addGroupLmitAggregationStep(*aggregation_plan); postProjectionForExplodingArrays(*aggregation_plan); @@ -262,15 +266,40 @@ String AggregateGroupLimitRelParser::getAggregateFunctionName(const String & win throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unsupported window function: {}", window_function_name); } +void AggregateGroupLimitRelParser::collectPartitionAndSortFields() +{ + partition_fields = parsePartitionFields(win_rel_def->partition_expressions()); + auto full_sort_fields = parseSortFields(win_rel_def->sorts()); + + std::set partition_fields_set(partition_fields.begin(), partition_fields.end()); + std::set full_sort_fields_set(full_sort_fields.begin(), full_sort_fields.end()); + std::set selected_sort_fields_set; + // Remove partition keys from sort keys + std::set_difference( + full_sort_fields_set.begin(), + full_sort_fields_set.end(), + partition_fields_set.begin(), + partition_fields_set.end(), + std::inserter(selected_sort_fields_set, selected_sort_fields_set.begin())); + if (selected_sort_fields_set.empty()) + { + // FIXME: support empty sort keys. + sort_fields.push_back(*partition_fields_set.begin()); + } + else + { + sort_fields = std::vector(selected_sort_fields_set.begin(), selected_sort_fields_set.end()); + } +} + // Build one tuple column as the aggregate function's arguments void AggregateGroupLimitRelParser::prePrejectionForAggregateArguments(DB::QueryPlan & plan) { auto projection_actions = std::make_shared(input_header->getColumnsWithTypeAndName()); - auto partition_fields = parsePartitionFields(win_rel_def->partition_expressions()); - auto sort_fields = parseSortFields(win_rel_def->sorts()); std::set unique_partition_fields(partition_fields.begin(), partition_fields.end()); std::set unique_sort_fields(sort_fields.begin(), sort_fields.end()); + DB::NameSet required_column_names; auto build_tuple = [&](const DB::DataTypes & data_types, const Strings & names, @@ -296,12 +325,13 @@ void AggregateGroupLimitRelParser::prePrejectionForAggregateArguments(DB::QueryP for (size_t i = 0; i < input_header->columns(); ++i) { const auto & col = input_header->getByPosition(i); - if (unique_partition_fields.count(i) && !unique_sort_fields.count(i)) + if (unique_partition_fields.count(i)) { required_column_names.insert(col.name); aggregate_grouping_keys.push_back(col.name); } - else + + if (!unique_partition_fields.count(i) || unique_sort_fields.count(i)) { aggregate_data_tuple_types.push_back(col.type); aggregate_data_tuple_names.push_back(col.name); @@ -333,7 +363,15 @@ DB::AggregateDescription AggregateGroupLimitRelParser::buildAggregateDescription agg_desc.argument_names = {aggregate_tuple_column_name}; auto & parameters = agg_desc.parameters; parameters.push_back(static_cast(limit)); - auto sort_directions = buildSQLLikeSortDescription(*input_header, win_rel_def->sorts()); + std::set sort_field_names; + for (auto i : sort_fields) + sort_field_names.insert(input_header->getByPosition(i).name); + auto full_sort_desc = parseSortFields(*input_header, win_rel_def->sorts()); + DB::SortDescription sort_desc; + for (const auto & sort_column : full_sort_desc) + if (sort_field_names.count(sort_column.column_name)) + sort_desc.push_back(sort_column); + auto sort_directions = buildSQLLikeSortDescription(sort_desc); parameters.push_back(sort_directions); const auto & header = *plan.getCurrentHeader(); @@ -348,6 +386,7 @@ DB::AggregateDescription AggregateGroupLimitRelParser::buildAggregateDescription void AggregateGroupLimitRelParser::addGroupLmitAggregationStep(DB::QueryPlan & plan) { const auto & settings = getContext()->getSettingsRef(); + DB::AggregateDescriptions agg_descs = {buildAggregateDescription(plan)}; auto params = AggregatorParamsHelper::buildParams( getContext(), aggregate_grouping_keys, agg_descs, AggregatorParamsHelper::Mode::INIT_TO_COMPLETED); diff --git a/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.h b/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.h index 44159b01903d..f0643421cf9c 100644 --- a/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.h +++ b/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.h @@ -66,14 +66,17 @@ class AggregateGroupLimitRelParser : public RelParser String aggregate_function_name; size_t limit = 0; DB::SharedHeader input_header; - // DB::Block output_header; + // Field indexes at the input header which are used as partition keys + std::vector partition_fields; + // Field indexes at the input header which are used as sort keys + std::vector sort_fields; DB::Names aggregate_grouping_keys; String aggregate_tuple_column_name; String getAggregateFunctionName(const String & window_function_name); + void collectPartitionAndSortFields(); void prePrejectionForAggregateArguments(DB::QueryPlan & plan); - void addGroupLmitAggregationStep(DB::QueryPlan & plan); String parseSortDirections(const google::protobuf::RepeatedPtrField & sort_fields); DB::AggregateDescription buildAggregateDescription(DB::QueryPlan & plan); diff --git a/cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.cpp b/cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.cpp index c45849d97221..39c722fb0927 100644 --- a/cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.cpp +++ b/cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.cpp @@ -76,32 +76,18 @@ DB::SortDescription parseSortFields(const DB::Block & header, const google::prot return sort_descr; } -std::string -buildSQLLikeSortDescription(const DB::Block & header, const google::protobuf::RepeatedPtrField & sort_fields) +std::string buildSQLLikeSortDescription(const DB::SortDescription & sort_description) { - static const std::unordered_map order_directions - = {{1, " asc nulls first"}, {2, " asc nulls last"}, {3, " desc nulls first"}, {4, " desc nulls last"}}; - size_t n = 0; DB::WriteBufferFromOwnString ostr; - for (const auto & sort_field : sort_fields) + size_t n = 0; + for (const auto & sort_column : sort_description) { - auto it = order_directions.find(sort_field.direction()); - if (it == order_directions.end()) - throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknow sort direction: {}", sort_field.direction()); - auto field_index = SubstraitParserUtils::getStructFieldIndex(sort_field.expr()); - if (!field_index) - { - throw DB::Exception( - DB::ErrorCodes::BAD_ARGUMENTS, "Sort field must be a column reference. but got {}", sort_field.DebugString()); - } - const auto & col_name = header.getByPosition(*field_index).name; if (n) - ostr << String(","); - // the col_name may contain '#' which can may ch fail to parse. - ostr << "`" << col_name << "`" << it->second; + ostr << String(", "); + const auto & col_name = sort_column.column_name; + ostr << "`" << col_name << "` " << (sort_column.direction == 1 ? "ASC" : "DESC") << " NULLS " << (sort_column.nulls_direction != sort_column.direction ? "FIRST" : "LAST"); n += 1; } - LOG_DEBUG(getLogger("AggregateGroupLimitRelParser"), "Order by clasue: {}", ostr.str()); return ostr.str(); } } diff --git a/cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.h b/cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.h index c460fa758b6d..4f20675e5ae6 100644 --- a/cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.h +++ b/cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.h @@ -28,6 +28,6 @@ DB::SortDescription parseSortFields(const DB::Block & header, const google::protobuf::RepeatedPtrField & expressions); DB::SortDescription parseSortFields(const DB::Block & header, const google::protobuf::RepeatedPtrField & sort_fields); -std::string -buildSQLLikeSortDescription(const DB::Block & header, const google::protobuf::RepeatedPtrField & sort_fields); + +std::string buildSQLLikeSortDescription(const DB::SortDescription & sort_description); }