Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3012,6 +3012,17 @@ class GlutenClickHouseTPCHSaltNullParquetSuite
compareResult = true,
checkWindowGroupLimit
)

compareResultsAgainstVanillaSpark(
"""
|select * from(
|select a, b, c, row_number() over (partition by a order by b, c, a) as r
|from test_win_top)
|where r <= 1
|""".stripMargin,
compareResult = true,
checkWindowGroupLimit
)
spark.sql("drop table if exists test_win_top")
}

Expand Down
4 changes: 2 additions & 2 deletions cpp-ch/local-engine/Common/GlutenConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ struct WindowConfig
public:
inline static const String WINDOW_AGGREGATE_TOPK_SAMPLE_ROWS = "window.aggregate_topk_sample_rows";
inline static const String WINDOW_AGGREGATE_TOPK_HIGH_CARDINALITY_THRESHOLD = "window.aggregate_topk_high_cardinality_threshold";
size_t aggregate_topk_sample_rows = 5000;
double aggregate_topk_high_cardinality_threshold = 0.6;
size_t aggregate_topk_sample_rows = 50000;
double aggregate_topk_high_cardinality_threshold = 0.4;
static WindowConfig loadFromContext(const DB::ContextPtr & context);
};

Expand Down
49 changes: 44 additions & 5 deletions cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/

#include "GroupLimitRelParser.h"
#include <algorithm>
#include <memory>
#include <unordered_set>
#include <utility>
Expand Down Expand Up @@ -46,12 +47,14 @@
#include <QueryPipeline/QueryPipelineBuilder.h>
#include <google/protobuf/repeated_field.h>
#include <google/protobuf/wrappers.pb.h>
#include "Common/Logger.h"
#include <Common/AggregateUtil.h>
#include <Common/ArrayJoinHelper.h>
#include <Common/GlutenConfig.h>
#include <Common/PlanUtil.h>
#include <Common/QueryContext.h>
#include <Common/logger_useful.h>
#include "cctz/civil_time_detail.h"

namespace DB::ErrorCodes
{
Expand Down Expand Up @@ -226,6 +229,7 @@ DB::QueryPlanPtr AggregateGroupLimitRelParser::parse(

// If all partition keys are low cardinality keys, use aggregattion to get topk of each partition
auto aggregation_plan = BranchStepHelper::createSubPlan(branch_in_header, 1);
collectPartitionAndSortFields();
prePrejectionForAggregateArguments(*aggregation_plan);
addGroupLmitAggregationStep(*aggregation_plan);
postProjectionForExplodingArrays(*aggregation_plan);
Expand Down Expand Up @@ -262,15 +266,40 @@ String AggregateGroupLimitRelParser::getAggregateFunctionName(const String & win
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unsupported window function: {}", window_function_name);
}

void AggregateGroupLimitRelParser::collectPartitionAndSortFields()
{
partition_fields = parsePartitionFields(win_rel_def->partition_expressions());
auto full_sort_fields = parseSortFields(win_rel_def->sorts());

std::set<size_t> partition_fields_set(partition_fields.begin(), partition_fields.end());
std::set<size_t> full_sort_fields_set(full_sort_fields.begin(), full_sort_fields.end());
std::set<size_t> selected_sort_fields_set;
// Remove partition keys from sort keys
std::set_difference(
full_sort_fields_set.begin(),
full_sort_fields_set.end(),
partition_fields_set.begin(),
partition_fields_set.end(),
std::inserter(selected_sort_fields_set, selected_sort_fields_set.begin()));
if (selected_sort_fields_set.empty())
{
// FIXME: support empty sort keys.
sort_fields.push_back(*partition_fields_set.begin());
}
else
{
sort_fields = std::vector<size_t>(selected_sort_fields_set.begin(), selected_sort_fields_set.end());
}
}

// Build one tuple column as the aggregate function's arguments
void AggregateGroupLimitRelParser::prePrejectionForAggregateArguments(DB::QueryPlan & plan)
{
auto projection_actions = std::make_shared<DB::ActionsDAG>(input_header->getColumnsWithTypeAndName());

auto partition_fields = parsePartitionFields(win_rel_def->partition_expressions());
auto sort_fields = parseSortFields(win_rel_def->sorts());
std::set<size_t> unique_partition_fields(partition_fields.begin(), partition_fields.end());
std::set<size_t> unique_sort_fields(sort_fields.begin(), sort_fields.end());

DB::NameSet required_column_names;
auto build_tuple = [&](const DB::DataTypes & data_types,
const Strings & names,
Expand All @@ -296,12 +325,13 @@ void AggregateGroupLimitRelParser::prePrejectionForAggregateArguments(DB::QueryP
for (size_t i = 0; i < input_header->columns(); ++i)
{
const auto & col = input_header->getByPosition(i);
if (unique_partition_fields.count(i) && !unique_sort_fields.count(i))
if (unique_partition_fields.count(i))
{
required_column_names.insert(col.name);
aggregate_grouping_keys.push_back(col.name);
}
else

if (!unique_partition_fields.count(i) || unique_sort_fields.count(i))
{
aggregate_data_tuple_types.push_back(col.type);
aggregate_data_tuple_names.push_back(col.name);
Expand Down Expand Up @@ -333,7 +363,15 @@ DB::AggregateDescription AggregateGroupLimitRelParser::buildAggregateDescription
agg_desc.argument_names = {aggregate_tuple_column_name};
auto & parameters = agg_desc.parameters;
parameters.push_back(static_cast<UInt32>(limit));
auto sort_directions = buildSQLLikeSortDescription(*input_header, win_rel_def->sorts());
std::set<String> sort_field_names;
for (auto i : sort_fields)
sort_field_names.insert(input_header->getByPosition(i).name);
auto full_sort_desc = parseSortFields(*input_header, win_rel_def->sorts());
DB::SortDescription sort_desc;
for (const auto & sort_column : full_sort_desc)
if (sort_field_names.count(sort_column.column_name))
sort_desc.push_back(sort_column);
auto sort_directions = buildSQLLikeSortDescription(sort_desc);
parameters.push_back(sort_directions);

const auto & header = *plan.getCurrentHeader();
Expand All @@ -348,6 +386,7 @@ DB::AggregateDescription AggregateGroupLimitRelParser::buildAggregateDescription
void AggregateGroupLimitRelParser::addGroupLmitAggregationStep(DB::QueryPlan & plan)
{
const auto & settings = getContext()->getSettingsRef();

DB::AggregateDescriptions agg_descs = {buildAggregateDescription(plan)};
auto params = AggregatorParamsHelper::buildParams(
getContext(), aggregate_grouping_keys, agg_descs, AggregatorParamsHelper::Mode::INIT_TO_COMPLETED);
Expand Down
7 changes: 5 additions & 2 deletions cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,17 @@ class AggregateGroupLimitRelParser : public RelParser
String aggregate_function_name;
size_t limit = 0;
DB::SharedHeader input_header;
// DB::Block output_header;
// Field indexes at the input header which are used as partition keys
std::vector<size_t> partition_fields;
// Field indexes at the input header which are used as sort keys
std::vector<size_t> sort_fields;
DB::Names aggregate_grouping_keys;
String aggregate_tuple_column_name;

String getAggregateFunctionName(const String & window_function_name);

void collectPartitionAndSortFields();
void prePrejectionForAggregateArguments(DB::QueryPlan & plan);

void addGroupLmitAggregationStep(DB::QueryPlan & plan);
String parseSortDirections(const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields);
DB::AggregateDescription buildAggregateDescription(DB::QueryPlan & plan);
Expand Down
26 changes: 6 additions & 20 deletions cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,32 +76,18 @@ DB::SortDescription parseSortFields(const DB::Block & header, const google::prot
return sort_descr;
}

std::string
buildSQLLikeSortDescription(const DB::Block & header, const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields)
std::string buildSQLLikeSortDescription(const DB::SortDescription & sort_description)
{
static const std::unordered_map<int, std::string> order_directions
= {{1, " asc nulls first"}, {2, " asc nulls last"}, {3, " desc nulls first"}, {4, " desc nulls last"}};
size_t n = 0;
DB::WriteBufferFromOwnString ostr;
for (const auto & sort_field : sort_fields)
size_t n = 0;
for (const auto & sort_column : sort_description)
{
auto it = order_directions.find(sort_field.direction());
if (it == order_directions.end())
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknow sort direction: {}", sort_field.direction());
auto field_index = SubstraitParserUtils::getStructFieldIndex(sort_field.expr());
if (!field_index)
{
throw DB::Exception(
DB::ErrorCodes::BAD_ARGUMENTS, "Sort field must be a column reference. but got {}", sort_field.DebugString());
}
const auto & col_name = header.getByPosition(*field_index).name;
if (n)
ostr << String(",");
// the col_name may contain '#' which can may ch fail to parse.
ostr << "`" << col_name << "`" << it->second;
ostr << String(", ");
const auto & col_name = sort_column.column_name;
ostr << "`" << col_name << "` " << (sort_column.direction == 1 ? "ASC" : "DESC") << " NULLS " << (sort_column.nulls_direction != sort_column.direction ? "FIRST" : "LAST");
n += 1;
}
LOG_DEBUG(getLogger("AggregateGroupLimitRelParser"), "Order by clasue: {}", ostr.str());
return ostr.str();
}
}
4 changes: 2 additions & 2 deletions cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@ DB::SortDescription
parseSortFields(const DB::Block & header, const google::protobuf::RepeatedPtrField<substrait::Expression> & expressions);
DB::SortDescription parseSortFields(const DB::Block & header, const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields);

std::string
buildSQLLikeSortDescription(const DB::Block & header, const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields);

std::string buildSQLLikeSortDescription(const DB::SortDescription & sort_description);
}