Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions velox/experimental/cudf/exec/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ add_library(
CudfOrderBy.cpp
DebugUtil.cpp
ExpressionEvaluator.cpp
PrestoAggregates.cpp
ToCudf.cpp
Utilities.cpp
VeloxCudfInterop.cpp)
Expand Down
182 changes: 152 additions & 30 deletions velox/experimental/cudf/exec/CudfHashAggregation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
#include <cudf/stream_compaction.hpp>
#include <cudf/unary.hpp>

#include <mutex>

namespace {

using namespace facebook::velox;
Expand Down Expand Up @@ -386,32 +388,6 @@ struct MeanAggregator : cudf_velox::CudfHashAggregation::Aggregator {
uint32_t countIdx_;
};

std::unique_ptr<cudf_velox::CudfHashAggregation::Aggregator> createAggregator(
core::AggregationNode::Step step,
std::string const& kind,
uint32_t inputIndex,
VectorPtr constant,
bool isGlobal) {
if (kind.rfind("sum", 0) == 0) {
return std::make_unique<SumAggregator>(
step, inputIndex, constant, isGlobal);
} else if (kind.rfind("count", 0) == 0) {
return std::make_unique<CountAggregator>(
step, inputIndex, constant, isGlobal);
} else if (kind.rfind("min", 0) == 0) {
return std::make_unique<MinAggregator>(
step, inputIndex, constant, isGlobal);
} else if (kind.rfind("max", 0) == 0) {
return std::make_unique<MaxAggregator>(
step, inputIndex, constant, isGlobal);
} else if (kind.rfind("avg", 0) == 0) {
return std::make_unique<MeanAggregator>(
step, inputIndex, constant, isGlobal);
} else {
VELOX_NYI("Aggregation not yet supported");
}
}

static const std::unordered_map<std::string, core::AggregationNode::Step>
companionStep = {
{"_partial", core::AggregationNode::Step::kPartial},
Expand Down Expand Up @@ -485,8 +461,8 @@ auto toAggregators(
auto const inputIndex = aggInputs[0];
auto const constant = aggConstants.empty() ? nullptr : aggConstants[0];
auto const companionStep = getCompanionStep(kind, step);
aggregators.push_back(
createAggregator(companionStep, kind, inputIndex, constant, isGlobal));
aggregators.push_back(facebook::velox::cudf_velox::createAggregator(
kind, companionStep, inputIndex, constant, isGlobal));
}
return aggregators;
}
Expand All @@ -507,8 +483,8 @@ auto toIntermediateAggregators(
auto const inputIndex = aggregationNode.groupingKeys().size() + i;
auto const kind = aggregate.call->name();
auto const constant = nullptr;
aggregators.push_back(
createAggregator(step, kind, inputIndex, constant, isGlobal));
aggregators.push_back(facebook::velox::cudf_velox::createAggregator(
kind, step, inputIndex, constant, isGlobal));
}
return aggregators;
}
Expand Down Expand Up @@ -889,4 +865,150 @@ bool CudfHashAggregation::isFinished() {
return finished_;
}

std::unique_ptr<cudf_velox::CudfHashAggregation::Aggregator> createAggregator(
const std::string& kind,
core::AggregationNode::Step step,
uint32_t inputIndex,
VectorPtr constant,
bool isGlobal) {
// Ensure basic cudf aggregators are registered
static std::once_flag registrationFlag;
std::call_once(registrationFlag, []() {
facebook::velox::cudf_velox::registerCudfAggregators(
false /* withCompanionFunctions */, false /* overwrite */);
});

if (auto entry = facebook::velox::cudf_velox::getAggregatorEntry(kind)) {
return entry->factory(step, inputIndex, constant, isGlobal);
}

VELOX_NYI("Aggregation not yet supported: {}", kind);
}

AggregatorMap& aggregators() {
static AggregatorMap aggregators;
return aggregators;
}

const AggregatorEntry* FOLLY_NULLABLE
getAggregatorEntry(const std::string& name) {
return aggregators().withRLock(
[&](const auto& aggregatorsMap) -> const AggregatorEntry* {
auto it = aggregatorsMap.find(name);
if (it != aggregatorsMap.end()) {
return &it->second;
}
return nullptr;
});
}

bool registerAggregator(
const std::string& name,
const AggregatorFactory& factory,
bool overwrite) {
if (overwrite) {
aggregators().withWLock(
[&](auto& aggregatorsMap) { aggregatorsMap[name] = {factory}; });
return true;
} else {
return aggregators().withWLock([&](auto& aggregatorsMap) {
auto [_, inserted] = aggregatorsMap.insert({name, {factory}});
return inserted;
});
}
}

// Registration functions for CUDF aggregators
template <typename AggregatorType>
void registerAggregatorImpl(
const std::string& name,
bool withCompanionFunctions,
bool overwrite) {
registerAggregator(
name,
[](core::AggregationNode::Step step,
uint32_t inputIndex,
VectorPtr constant,
bool isGlobal) -> std::unique_ptr<CudfHashAggregation::Aggregator> {
return std::make_unique<AggregatorType>(
step, inputIndex, constant, isGlobal);
},
overwrite);
if (withCompanionFunctions) {
registerAggregator(
name + "_partial",
[](core::AggregationNode::Step,
uint32_t inputIndex,
VectorPtr constant,
bool isGlobal) -> std::unique_ptr<CudfHashAggregation::Aggregator> {
return std::make_unique<AggregatorType>(
core::AggregationNode::Step::kPartial,
inputIndex,
constant,
isGlobal);
},
overwrite);
registerAggregator(
name + "_merge",
[](core::AggregationNode::Step,
uint32_t inputIndex,
VectorPtr constant,
bool isGlobal) -> std::unique_ptr<CudfHashAggregation::Aggregator> {
return std::make_unique<AggregatorType>(
core::AggregationNode::Step::kIntermediate,
inputIndex,
constant,
isGlobal);
},
overwrite);
registerAggregator(
name + "_merge_extract",
[](core::AggregationNode::Step,
uint32_t inputIndex,
VectorPtr constant,
bool isGlobal) -> std::unique_ptr<CudfHashAggregation::Aggregator> {
return std::make_unique<AggregatorType>(
core::AggregationNode::Step::kFinal,
inputIndex,
constant,
isGlobal);
},
overwrite);
}
}

void registerSumAggregator(bool withCompanionFunctions, bool overwrite) {
registerAggregatorImpl<SumAggregator>(
"sum", withCompanionFunctions, overwrite);
}

void registerCountAggregator(bool withCompanionFunctions, bool overwrite) {
registerAggregatorImpl<CountAggregator>(
"count", withCompanionFunctions, overwrite);
}

void registerMinAggregator(bool withCompanionFunctions, bool overwrite) {
registerAggregatorImpl<MinAggregator>(
"min", withCompanionFunctions, overwrite);
}

void registerMaxAggregator(bool withCompanionFunctions, bool overwrite) {
registerAggregatorImpl<MaxAggregator>(
"max", withCompanionFunctions, overwrite);
}

void registerAvgAggregator(bool withCompanionFunctions, bool overwrite) {
registerAggregatorImpl<MeanAggregator>(
"avg", withCompanionFunctions, overwrite);
}

// Register all CUDF aggregators
void registerCudfAggregators(bool withCompanionFunctions, bool overwrite) {
registerSumAggregator(withCompanionFunctions, overwrite);
registerCountAggregator(withCompanionFunctions, overwrite);
registerMinAggregator(withCompanionFunctions, overwrite);
registerMaxAggregator(withCompanionFunctions, overwrite);
registerAvgAggregator(withCompanionFunctions, overwrite);
}

} // namespace facebook::velox::cudf_velox
50 changes: 50 additions & 0 deletions velox/experimental/cudf/exec/CudfHashAggregation.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,54 @@ class CudfHashAggregation : public exec::Operator, public NvtxHelper {
CudfVectorPtr partialOutput_;
};

using AggregatorFactory =
std::function<std::unique_ptr<CudfHashAggregation::Aggregator>(
core::AggregationNode::Step step,
uint32_t inputIndex,
VectorPtr constant,
bool isGlobal)>;

struct AggregatorEntry {
AggregatorFactory factory;
};

using AggregatorMap =
folly::Synchronized<std::unordered_map<std::string, AggregatorEntry>>;

AggregatorMap& aggregators();

const AggregatorEntry* FOLLY_NULLABLE
getAggregatorEntry(const std::string& name);

/// Register an aggregator function with the specified name and factory.
/// When function with `name` already exists, if overwrite is true, existing
/// registration will be replaced. Otherwise, return false without overwriting.
bool registerAggregator(
const std::string& name,
const AggregatorFactory& factory,
bool overwrite = false);

/// Creates an aggregator instance using the registered factory.
/// Returns nullptr if no factory is registered for the given name.
std::unique_ptr<CudfHashAggregation::Aggregator> createAggregator(
const std::string& kind,
core::AggregationNode::Step step,
uint32_t inputIndex,
VectorPtr constant,
bool isGlobal);

/// Registration functions for CUDF aggregators
void registerSumAggregator(bool withCompanionFunctions, bool overwrite = false);
void registerCountAggregator(
bool withCompanionFunctions,
bool overwrite = false);
void registerMinAggregator(bool withCompanionFunctions, bool overwrite = false);
void registerMaxAggregator(bool withCompanionFunctions, bool overwrite = false);
void registerAvgAggregator(bool withCompanionFunctions, bool overwrite = false);

/// Register all CUDF aggregators
void registerCudfAggregators(
bool withCompanionFunctions,
bool overwrite = false);

} // namespace facebook::velox::cudf_velox
79 changes: 79 additions & 0 deletions velox/experimental/cudf/exec/PrestoAggregates.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/experimental/cudf/exec/CudfHashAggregation.h"
#include "velox/experimental/cudf/exec/PrestoAggregates.h"

#include "velox/functions/prestosql/aggregates/AggregateNames.h"

namespace facebook::velox::cudf_velox::presto {

namespace {
// Use constants from AggregateNames.h
using facebook::velox::aggregate::kAvg;
using facebook::velox::aggregate::kCount;
using facebook::velox::aggregate::kMax;
using facebook::velox::aggregate::kMin;
using facebook::velox::aggregate::kSum;
} // namespace

void registerPrestoAggregate(
const std::string& prefix,
const std::string& aggregateName,
bool overwrite) {
auto name = prefix + aggregateName;
registerAggregator(
name,
[aggregateName](
core::AggregationNode::Step step,
uint32_t inputIndex,
VectorPtr constant,
bool isGlobal) -> std::unique_ptr<CudfHashAggregation::Aggregator> {
return facebook::velox::cudf_velox::createAggregator(
aggregateName, step, inputIndex, constant, isGlobal);
},
overwrite);
}

void registerPrestoSumAggregate(const std::string& prefix, bool overwrite) {
registerPrestoAggregate(prefix, kSum, overwrite);
}

void registerPrestoCountAggregate(const std::string& prefix, bool overwrite) {
registerPrestoAggregate(prefix, kCount, overwrite);
}

void registerPrestoMinAggregate(const std::string& prefix, bool overwrite) {
registerPrestoAggregate(prefix, kMin, overwrite);
}

void registerPrestoMaxAggregate(const std::string& prefix, bool overwrite) {
registerPrestoAggregate(prefix, kMax, overwrite);
}

void registerPrestoAvgAggregate(const std::string& prefix, bool overwrite) {
registerPrestoAggregate(prefix, kAvg, overwrite);
}

void registerAllPrestoAggregates(const std::string& prefix, bool overwrite) {
registerPrestoSumAggregate(prefix, overwrite);
registerPrestoCountAggregate(prefix, overwrite);
registerPrestoMinAggregate(prefix, overwrite);
registerPrestoMaxAggregate(prefix, overwrite);
registerPrestoAvgAggregate(prefix, overwrite);
}

} // namespace facebook::velox::cudf_velox::presto
Loading
Loading