diff --git a/velox/experimental/cudf/exec/CMakeLists.txt b/velox/experimental/cudf/exec/CMakeLists.txt index 0ff0a127cdc2..7d1ca8b7557b 100644 --- a/velox/experimental/cudf/exec/CMakeLists.txt +++ b/velox/experimental/cudf/exec/CMakeLists.txt @@ -23,6 +23,7 @@ add_library( CudfOrderBy.cpp DebugUtil.cpp ExpressionEvaluator.cpp + PrestoAggregates.cpp ToCudf.cpp Utilities.cpp VeloxCudfInterop.cpp) diff --git a/velox/experimental/cudf/exec/CudfHashAggregation.cpp b/velox/experimental/cudf/exec/CudfHashAggregation.cpp index b0feb7820b88..db3aadd88b1f 100644 --- a/velox/experimental/cudf/exec/CudfHashAggregation.cpp +++ b/velox/experimental/cudf/exec/CudfHashAggregation.cpp @@ -30,6 +30,8 @@ #include #include +#include + namespace { using namespace facebook::velox; @@ -386,32 +388,6 @@ struct MeanAggregator : cudf_velox::CudfHashAggregation::Aggregator { uint32_t countIdx_; }; -std::unique_ptr createAggregator( - core::AggregationNode::Step step, - std::string const& kind, - uint32_t inputIndex, - VectorPtr constant, - bool isGlobal) { - if (kind.rfind("sum", 0) == 0) { - return std::make_unique( - step, inputIndex, constant, isGlobal); - } else if (kind.rfind("count", 0) == 0) { - return std::make_unique( - step, inputIndex, constant, isGlobal); - } else if (kind.rfind("min", 0) == 0) { - return std::make_unique( - step, inputIndex, constant, isGlobal); - } else if (kind.rfind("max", 0) == 0) { - return std::make_unique( - step, inputIndex, constant, isGlobal); - } else if (kind.rfind("avg", 0) == 0) { - return std::make_unique( - step, inputIndex, constant, isGlobal); - } else { - VELOX_NYI("Aggregation not yet supported"); - } -} - static const std::unordered_map companionStep = { {"_partial", core::AggregationNode::Step::kPartial}, @@ -485,8 +461,8 @@ auto toAggregators( auto const inputIndex = aggInputs[0]; auto const constant = aggConstants.empty() ? nullptr : aggConstants[0]; auto const companionStep = getCompanionStep(kind, step); - aggregators.push_back( - createAggregator(companionStep, kind, inputIndex, constant, isGlobal)); + aggregators.push_back(facebook::velox::cudf_velox::createAggregator( + kind, companionStep, inputIndex, constant, isGlobal)); } return aggregators; } @@ -507,8 +483,8 @@ auto toIntermediateAggregators( auto const inputIndex = aggregationNode.groupingKeys().size() + i; auto const kind = aggregate.call->name(); auto const constant = nullptr; - aggregators.push_back( - createAggregator(step, kind, inputIndex, constant, isGlobal)); + aggregators.push_back(facebook::velox::cudf_velox::createAggregator( + kind, step, inputIndex, constant, isGlobal)); } return aggregators; } @@ -889,4 +865,150 @@ bool CudfHashAggregation::isFinished() { return finished_; } +std::unique_ptr createAggregator( + const std::string& kind, + core::AggregationNode::Step step, + uint32_t inputIndex, + VectorPtr constant, + bool isGlobal) { + // Ensure basic cudf aggregators are registered + static std::once_flag registrationFlag; + std::call_once(registrationFlag, []() { + facebook::velox::cudf_velox::registerCudfAggregators( + false /* withCompanionFunctions */, false /* overwrite */); + }); + + if (auto entry = facebook::velox::cudf_velox::getAggregatorEntry(kind)) { + return entry->factory(step, inputIndex, constant, isGlobal); + } + + VELOX_NYI("Aggregation not yet supported: {}", kind); +} + +AggregatorMap& aggregators() { + static AggregatorMap aggregators; + return aggregators; +} + +const AggregatorEntry* FOLLY_NULLABLE +getAggregatorEntry(const std::string& name) { + return aggregators().withRLock( + [&](const auto& aggregatorsMap) -> const AggregatorEntry* { + auto it = aggregatorsMap.find(name); + if (it != aggregatorsMap.end()) { + return &it->second; + } + return nullptr; + }); +} + +bool registerAggregator( + const std::string& name, + const AggregatorFactory& factory, + bool overwrite) { + if (overwrite) { + aggregators().withWLock( + [&](auto& aggregatorsMap) { aggregatorsMap[name] = {factory}; }); + return true; + } else { + return aggregators().withWLock([&](auto& aggregatorsMap) { + auto [_, inserted] = aggregatorsMap.insert({name, {factory}}); + return inserted; + }); + } +} + +// Registration functions for CUDF aggregators +template +void registerAggregatorImpl( + const std::string& name, + bool withCompanionFunctions, + bool overwrite) { + registerAggregator( + name, + [](core::AggregationNode::Step step, + uint32_t inputIndex, + VectorPtr constant, + bool isGlobal) -> std::unique_ptr { + return std::make_unique( + step, inputIndex, constant, isGlobal); + }, + overwrite); + if (withCompanionFunctions) { + registerAggregator( + name + "_partial", + [](core::AggregationNode::Step, + uint32_t inputIndex, + VectorPtr constant, + bool isGlobal) -> std::unique_ptr { + return std::make_unique( + core::AggregationNode::Step::kPartial, + inputIndex, + constant, + isGlobal); + }, + overwrite); + registerAggregator( + name + "_merge", + [](core::AggregationNode::Step, + uint32_t inputIndex, + VectorPtr constant, + bool isGlobal) -> std::unique_ptr { + return std::make_unique( + core::AggregationNode::Step::kIntermediate, + inputIndex, + constant, + isGlobal); + }, + overwrite); + registerAggregator( + name + "_merge_extract", + [](core::AggregationNode::Step, + uint32_t inputIndex, + VectorPtr constant, + bool isGlobal) -> std::unique_ptr { + return std::make_unique( + core::AggregationNode::Step::kFinal, + inputIndex, + constant, + isGlobal); + }, + overwrite); + } +} + +void registerSumAggregator(bool withCompanionFunctions, bool overwrite) { + registerAggregatorImpl( + "sum", withCompanionFunctions, overwrite); +} + +void registerCountAggregator(bool withCompanionFunctions, bool overwrite) { + registerAggregatorImpl( + "count", withCompanionFunctions, overwrite); +} + +void registerMinAggregator(bool withCompanionFunctions, bool overwrite) { + registerAggregatorImpl( + "min", withCompanionFunctions, overwrite); +} + +void registerMaxAggregator(bool withCompanionFunctions, bool overwrite) { + registerAggregatorImpl( + "max", withCompanionFunctions, overwrite); +} + +void registerAvgAggregator(bool withCompanionFunctions, bool overwrite) { + registerAggregatorImpl( + "avg", withCompanionFunctions, overwrite); +} + +// Register all CUDF aggregators +void registerCudfAggregators(bool withCompanionFunctions, bool overwrite) { + registerSumAggregator(withCompanionFunctions, overwrite); + registerCountAggregator(withCompanionFunctions, overwrite); + registerMinAggregator(withCompanionFunctions, overwrite); + registerMaxAggregator(withCompanionFunctions, overwrite); + registerAvgAggregator(withCompanionFunctions, overwrite); +} + } // namespace facebook::velox::cudf_velox diff --git a/velox/experimental/cudf/exec/CudfHashAggregation.h b/velox/experimental/cudf/exec/CudfHashAggregation.h index 3b828e18e3cb..f1129f38c54b 100644 --- a/velox/experimental/cudf/exec/CudfHashAggregation.h +++ b/velox/experimental/cudf/exec/CudfHashAggregation.h @@ -147,4 +147,54 @@ class CudfHashAggregation : public exec::Operator, public NvtxHelper { CudfVectorPtr partialOutput_; }; +using AggregatorFactory = + std::function( + core::AggregationNode::Step step, + uint32_t inputIndex, + VectorPtr constant, + bool isGlobal)>; + +struct AggregatorEntry { + AggregatorFactory factory; +}; + +using AggregatorMap = + folly::Synchronized>; + +AggregatorMap& aggregators(); + +const AggregatorEntry* FOLLY_NULLABLE +getAggregatorEntry(const std::string& name); + +/// Register an aggregator function with the specified name and factory. +/// When function with `name` already exists, if overwrite is true, existing +/// registration will be replaced. Otherwise, return false without overwriting. +bool registerAggregator( + const std::string& name, + const AggregatorFactory& factory, + bool overwrite = false); + +/// Creates an aggregator instance using the registered factory. +/// Returns nullptr if no factory is registered for the given name. +std::unique_ptr createAggregator( + const std::string& kind, + core::AggregationNode::Step step, + uint32_t inputIndex, + VectorPtr constant, + bool isGlobal); + +/// Registration functions for CUDF aggregators +void registerSumAggregator(bool withCompanionFunctions, bool overwrite = false); +void registerCountAggregator( + bool withCompanionFunctions, + bool overwrite = false); +void registerMinAggregator(bool withCompanionFunctions, bool overwrite = false); +void registerMaxAggregator(bool withCompanionFunctions, bool overwrite = false); +void registerAvgAggregator(bool withCompanionFunctions, bool overwrite = false); + +/// Register all CUDF aggregators +void registerCudfAggregators( + bool withCompanionFunctions, + bool overwrite = false); + } // namespace facebook::velox::cudf_velox diff --git a/velox/experimental/cudf/exec/PrestoAggregates.cpp b/velox/experimental/cudf/exec/PrestoAggregates.cpp new file mode 100644 index 000000000000..81092a29b3f7 --- /dev/null +++ b/velox/experimental/cudf/exec/PrestoAggregates.cpp @@ -0,0 +1,79 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/experimental/cudf/exec/CudfHashAggregation.h" +#include "velox/experimental/cudf/exec/PrestoAggregates.h" + +#include "velox/functions/prestosql/aggregates/AggregateNames.h" + +namespace facebook::velox::cudf_velox::presto { + +namespace { +// Use constants from AggregateNames.h +using facebook::velox::aggregate::kAvg; +using facebook::velox::aggregate::kCount; +using facebook::velox::aggregate::kMax; +using facebook::velox::aggregate::kMin; +using facebook::velox::aggregate::kSum; +} // namespace + +void registerPrestoAggregate( + const std::string& prefix, + const std::string& aggregateName, + bool overwrite) { + auto name = prefix + aggregateName; + registerAggregator( + name, + [aggregateName]( + core::AggregationNode::Step step, + uint32_t inputIndex, + VectorPtr constant, + bool isGlobal) -> std::unique_ptr { + return facebook::velox::cudf_velox::createAggregator( + aggregateName, step, inputIndex, constant, isGlobal); + }, + overwrite); +} + +void registerPrestoSumAggregate(const std::string& prefix, bool overwrite) { + registerPrestoAggregate(prefix, kSum, overwrite); +} + +void registerPrestoCountAggregate(const std::string& prefix, bool overwrite) { + registerPrestoAggregate(prefix, kCount, overwrite); +} + +void registerPrestoMinAggregate(const std::string& prefix, bool overwrite) { + registerPrestoAggregate(prefix, kMin, overwrite); +} + +void registerPrestoMaxAggregate(const std::string& prefix, bool overwrite) { + registerPrestoAggregate(prefix, kMax, overwrite); +} + +void registerPrestoAvgAggregate(const std::string& prefix, bool overwrite) { + registerPrestoAggregate(prefix, kAvg, overwrite); +} + +void registerAllPrestoAggregates(const std::string& prefix, bool overwrite) { + registerPrestoSumAggregate(prefix, overwrite); + registerPrestoCountAggregate(prefix, overwrite); + registerPrestoMinAggregate(prefix, overwrite); + registerPrestoMaxAggregate(prefix, overwrite); + registerPrestoAvgAggregate(prefix, overwrite); +} + +} // namespace facebook::velox::cudf_velox::presto diff --git a/velox/experimental/cudf/exec/PrestoAggregates.h b/velox/experimental/cudf/exec/PrestoAggregates.h new file mode 100644 index 000000000000..dbbd6fde0488 --- /dev/null +++ b/velox/experimental/cudf/exec/PrestoAggregates.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace facebook::velox::cudf_velox::presto { + +/// Register presto-style aggregate functions that use CUDF aggregators +/// with the specified prefix. +void registerPrestoSumAggregate( + const std::string& prefix, + bool overwrite = false); + +void registerPrestoCountAggregate( + const std::string& prefix, + bool overwrite = false); + +void registerPrestoMinAggregate( + const std::string& prefix, + bool overwrite = false); + +void registerPrestoMaxAggregate( + const std::string& prefix, + bool overwrite = false); + +void registerPrestoAvgAggregate( + const std::string& prefix, + bool overwrite = false); + +/// Register all presto-style CUDF aggregators with the specified prefix +void registerAllPrestoAggregates( + const std::string& prefix, + bool overwrite = false); + +} // namespace facebook::velox::cudf_velox::presto diff --git a/velox/experimental/cudf/tests/AggregationTest.cpp b/velox/experimental/cudf/tests/AggregationTest.cpp index a55b7256aa2b..5d4489ce969e 100644 --- a/velox/experimental/cudf/tests/AggregationTest.cpp +++ b/velox/experimental/cudf/tests/AggregationTest.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include "velox/experimental/cudf/exec/CudfHashAggregation.h" #include "velox/experimental/cudf/exec/ToCudf.h" #include "velox/dwio/common/tests/utils/BatchMaker.h" @@ -493,6 +494,9 @@ TEST_F(AggregationTest, CompanionAggs) { createDuckDbTable({rowVector}); + facebook::velox::cudf_velox::registerCudfAggregators( + true /* withCompanionFunctions */, false /* overwrite */); + auto op = PlanBuilder() .values({rowVector})