diff --git a/velox/experimental/stateful/StatefulPlanner.cpp b/velox/experimental/stateful/StatefulPlanner.cpp index 8409a1b3b505..ae83c83f64ca 100644 --- a/velox/experimental/stateful/StatefulPlanner.cpp +++ b/velox/experimental/stateful/StatefulPlanner.cpp @@ -13,26 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "velox/experimental/stateful/StatefulPlanner.h" +#include #include "velox/core/PlanFragment.h" #include "velox/exec/AssignUniqueId.h" -#include "velox/exec/CallbackSink.h" #include "velox/exec/EnforceSingleRow.h" -#include "velox/exec/Exchange.h" #include "velox/exec/Expand.h" #include "velox/exec/FilterProject.h" #include "velox/exec/GroupId.h" #include "velox/exec/HashAggregation.h" -#include "velox/exec/HashBuild.h" #include "velox/exec/HashProbe.h" #include "velox/exec/IndexLookupJoin.h" #include "velox/exec/Limit.h" #include "velox/exec/MarkDistinct.h" -#include "velox/exec/Merge.h" #include "velox/exec/MergeJoin.h" -#include "velox/exec/NestedLoopJoinBuild.h" #include "velox/exec/NestedLoopJoinProbe.h" #include "velox/exec/OrderBy.h" -#include "velox/exec/RoundRobinPartitionFunction.h" #include "velox/exec/RowNumber.h" #include "velox/exec/StreamingAggregation.h" #include "velox/exec/TableScan.h" @@ -45,415 +41,527 @@ #include "velox/exec/Values.h" #include "velox/exec/Window.h" #include "velox/experimental/stateful/EmptyOperator.h" +#include "velox/experimental/stateful/GroupWindowAggregator.h" #include "velox/experimental/stateful/KeySelector.h" #include "velox/experimental/stateful/LocalWindowAggregator.h" -#include "velox/experimental/stateful/StatefulPlanner.h" #include "velox/experimental/stateful/StatefulPlanNode.h" -#include "velox/experimental/stateful/StreamPartition.h" #include "velox/experimental/stateful/StreamJoin.h" +#include "velox/experimental/stateful/StreamKeyedOperator.h" +#include "velox/experimental/stateful/StreamPartition.h" #include "velox/experimental/stateful/WatermarkAssigner.h" #include "velox/experimental/stateful/WindowAggregator.h" #include "velox/experimental/stateful/WindowJoin.h" -#include "velox/experimental/stateful/GroupWindowAggregator.h" -#include "velox/experimental/stateful/window/GroupWindowAggsHandler.h" -#include "velox/experimental/stateful/StreamKeyedOperator.h" -#include "velox/experimental/stateful/rank/RowTimeDeduplicateRanker.h" -#include "velox/experimental/stateful/rank/AppendOnlyTopNRanker.h" #include "velox/experimental/stateful/agg/AggsHandleFunction.h" #include "velox/experimental/stateful/agg/GroupAggregator.h" +#include "velox/experimental/stateful/rank/AppendOnlyTopNRanker.h" +#include "velox/experimental/stateful/rank/RowTimeDeduplicateRanker.h" +#include "velox/experimental/stateful/window/GroupWindowAggsHandler.h" namespace facebook::velox::stateful { -static std::atomic opId = 0; +static int nextOperatorId() { + static std::atomic opId = 0; + return opId.fetch_add(1); +} // static StatefulOperatorPtr StatefulPlanner::plan( const core::PlanFragment& planFragment, exec::DriverCtx* ctx, StateBackend* stateBackend) { - return nodeToStatefulOperator(planFragment.planNode, ctx, stateBackend); + StatefulPlanner planner(ctx, stateBackend); + return planner.transformStatefulOperators(planFragment.planNode); } -//static -StatefulOperatorPtr StatefulPlanner::nodeToStatefulOperator( - const core::PlanNodePtr& planNode, - exec::DriverCtx* ctx, - StateBackend* stateBackend) { +StatefulOperatorPtr StatefulPlanner::transformStatefulOperators( + const core::PlanNodePtr& planNode) { auto statefulNode = std::dynamic_pointer_cast(planNode); VELOX_CHECK(statefulNode, "Not stateful node: {}", planNode->toString()); - std::vector targets; - std::unique_ptr op = std::move(nodeToOperator(statefulNode->node(), ctx)); - for (auto target : statefulNode->targets()) { - targets.push_back(std::move(nodeToStatefulOperator(target, ctx, stateBackend))); + StatefulOperatorPtr result; + if (std::dynamic_pointer_cast( + statefulNode->node()) != nullptr) { + result = transformWatermarkAssignerOperator(*statefulNode); + } else if ( + std::dynamic_pointer_cast( + statefulNode->node()) != nullptr) { + result = transformStreamPartitionOperator(*statefulNode); + } else if ( + std::dynamic_pointer_cast(statefulNode->node()) != + nullptr) { + result = transformStreamJoinOperator(*statefulNode); + } else if ( + std::dynamic_pointer_cast( + statefulNode->node()) != nullptr) { + result = transformStreamWindowJoinOperator(*statefulNode); + } else if ( + std::dynamic_pointer_cast( + statefulNode->node()) != nullptr) { + result = transformStreamWindowAggregationOperator(*statefulNode); + } else if ( + std::dynamic_pointer_cast( + statefulNode->node()) != nullptr) { + result = transformGroupWindowAggregationOperator(*statefulNode); + } else if ( + std::dynamic_pointer_cast(statefulNode->node()) != + nullptr) { + result = transformStreamRankOperator(*statefulNode); + } else if ( + std::dynamic_pointer_cast( + statefulNode->node()) != nullptr) { + result = transformGroupAggregationOperator(*statefulNode); + } else { + result = transformGenericOperator(*statefulNode); } - if (auto watermarkAssignerNode = - std::dynamic_pointer_cast(statefulNode->node())) { - return std::make_unique( - std::move(op), - std::move(targets), - watermarkAssignerNode->idleTimeout(), - watermarkAssignerNode->rowtimeFieldIndex(), - watermarkAssignerNode->watermarkInterval()); - } else if (auto partitionNode = - std::dynamic_pointer_cast(statefulNode->node())) { - VELOX_CHECK(targets.size() == 0, "StreamPartitionNode should have no targets"); - int numPartitions = partitionNode->numPartitions(); - return std::make_unique( - std::move(op), - partitionNode->partition()->partitionFunctionSpec(), - numPartitions); - } else if ( - auto joinNode = - std::dynamic_pointer_cast(statefulNode->node())) { - VELOX_CHECK(joinNode->sources().size() == 2, "StreamJoinNode should have 2 sources"); - std::unique_ptr left = std::move(nodeToOperator(joinNode->sources()[0], ctx)); - std::unique_ptr right = std::move(nodeToOperator(joinNode->sources()[1], ctx)); - std::unique_ptr leftKeySelector = - std::make_unique( - std::move(joinNode->leftPartFuncSpec()->create(INT_MAX, false)), - op->pool(), - joinNode->numPartitions()); - std::unique_ptr rightKeySelector = - std::make_unique( - std::move(joinNode->rightPartFuncSpec()->create(INT_MAX, false)), - op->pool(), - joinNode->numPartitions()); - return std::make_unique( - std::move(left), - std::move(right), - std::move(leftKeySelector), - std::move(rightKeySelector), - std::move(op), - std::move(targets)); - } else if ( - auto joinNode = - std::dynamic_pointer_cast(statefulNode->node())) { - VELOX_CHECK(joinNode->sources().size() == 2, "StreamWindowJoinNode should have 2 sources"); - std::unique_ptr left = std::move(nodeToOperator(joinNode->sources()[0], ctx)); - std::unique_ptr right = std::move(nodeToOperator(joinNode->sources()[1], ctx)); - std::unique_ptr leftKeySelector = - std::make_unique( - std::move(joinNode->leftPartFuncSpec()->create(INT_MAX, false)), - op->pool(), - joinNode->numPartitions()); - std::unique_ptr rightKeySelector = - std::make_unique( - std::move(joinNode->rightPartFuncSpec()->create(INT_MAX, false)), - op->pool(), - joinNode->numPartitions()); - return std::make_unique( - std::move(left), - std::move(right), - std::move(leftKeySelector), - std::move(rightKeySelector), + VELOX_CHECK( + result, "Failed to build operator for node: {}", planNode->toString()); + return result; +} + +std::vector StatefulPlanner::transformStatefulOperators( + const std::vector& targets) { + std::vector operators; + operators.resize(targets.size()); + std::transform( + targets.begin(), + targets.end(), + operators.begin(), + [this](const core::PlanNodePtr& target) { + return transformStatefulOperators(target); + }); + return operators; +} + +StatefulOperatorPtr StatefulPlanner::transformWatermarkAssignerOperator( + const StatefulPlanNode& planNode) { + std::vector targets = + transformStatefulOperators(planNode.targets()); + + auto watermarkAssignerNode = + std::dynamic_pointer_cast(planNode.node()); + + auto op = std::make_unique( + nextOperatorId(), ctx_, nullptr, watermarkAssignerNode->project()); + + return std::make_unique( + std::move(op), + std::move(targets), + watermarkAssignerNode->idleTimeout(), + watermarkAssignerNode->rowtimeFieldIndex(), + watermarkAssignerNode->watermarkInterval()); +} + +StatefulOperatorPtr StatefulPlanner::transformStreamPartitionOperator( + const StatefulPlanNode& planNode) { + std::vector targets = + transformStatefulOperators(planNode.targets()); + VELOX_CHECK(targets.empty(), "StreamPartitionNode should have no targets"); + + auto partitionNode = + std::dynamic_pointer_cast(planNode.node()); + VELOX_CHECK(partitionNode, "Failed to cast to StreamPartitionNode"); + + auto op = std::make_unique( + nextOperatorId(), ctx_, partitionNode->partition()); + + return std::make_unique( + std::move(op), + partitionNode->partition()->partitionFunctionSpec(), + partitionNode->numPartitions()); +} + +StatefulOperatorPtr StatefulPlanner::transformStreamJoinOperator( + const StatefulPlanNode& planNode) { + std::vector targets = + transformStatefulOperators(planNode.targets()); + + auto joinNode = + std::dynamic_pointer_cast(planNode.node()); + VELOX_CHECK(joinNode, "Failed to cast to StreamJoinNode"); + VELOX_CHECK( + joinNode->sources().size() == 2, "StreamJoinNode should have 2 sources"); + + std::unique_ptr left = + transformOperator(joinNode->sources()[0]); + std::unique_ptr right = + transformOperator(joinNode->sources()[1]); + std::unique_ptr probe = transformOperator(joinNode->probe()); + + std::unique_ptr leftKeySelector = std::make_unique( + joinNode->leftPartFuncSpec()->create(INT_MAX, false), + probe->pool(), + joinNode->numPartitions()); + std::unique_ptr rightKeySelector = std::make_unique( + joinNode->rightPartFuncSpec()->create(INT_MAX, false), + probe->pool(), + joinNode->numPartitions()); + + return std::make_unique( + std::move(left), + std::move(right), + std::move(leftKeySelector), + std::move(rightKeySelector), + std::move(probe), + std::move(targets)); +} + +StatefulOperatorPtr StatefulPlanner::transformStreamWindowJoinOperator( + const StatefulPlanNode& planNode) { + std::vector targets = + transformStatefulOperators(planNode.targets()); + + auto joinNode = + std::dynamic_pointer_cast(planNode.node()); + VELOX_CHECK(joinNode, "Failed to cast to StreamWindowJoinNode"); + VELOX_CHECK( + joinNode->sources().size() == 2, + "StreamWindowJoinNode should have 2 sources"); + + std::unique_ptr left = + transformOperator(joinNode->sources()[0]); + std::unique_ptr right = + transformOperator(joinNode->sources()[1]); + std::unique_ptr probe = transformOperator(joinNode->probe()); + + std::unique_ptr leftKeySelector = std::make_unique( + joinNode->leftPartFuncSpec()->create(INT_MAX, false), + probe->pool(), + joinNode->numPartitions()); + std::unique_ptr rightKeySelector = std::make_unique( + joinNode->rightPartFuncSpec()->create(INT_MAX, false), + probe->pool(), + joinNode->numPartitions()); + + return std::make_unique( + std::move(left), + std::move(right), + std::move(leftKeySelector), + std::move(rightKeySelector), + std::move(probe), + std::move(targets), + joinNode->leftWindowEndIndex(), + joinNode->rightWindowEndIndex()); +} + +StatefulOperatorPtr StatefulPlanner::transformStreamWindowAggregationOperator( + const StatefulPlanNode& planNode) { + std::vector targets = + transformStatefulOperators(planNode.targets()); + + auto windowAggNode = + std::dynamic_pointer_cast( + planNode.node()); + VELOX_CHECK(windowAggNode, "Failed to cast to StreamWindowAggregationNode"); + + auto op = transformOperator(windowAggNode->aggregation()); + + std::unique_ptr keySelector = std::make_unique( + windowAggNode->keySelectorSpec()->create(INT_MAX, true), op->pool()); + std::unique_ptr sliceAssigner = std::make_unique( + windowAggNode->sliceAssignerSpec()->create(INT_MAX, true), op->pool()); + + if (windowAggNode->isLocalAgg()) { + return std::make_unique( std::move(op), std::move(targets), - joinNode->leftWindowEndIndex(), - joinNode->rightWindowEndIndex()); - } else if ( - auto windowAggNode = - std::dynamic_pointer_cast(statefulNode->node())) { - std::unique_ptr keySelector = - std::make_unique( - std::move(windowAggNode->keySelectorSpec()->create(INT_MAX, true)), - op->pool()); - std::unique_ptr sliceAssigner = - std::make_unique( - std::move(windowAggNode->sliceAssignerSpec()->create(INT_MAX, true)), - op->pool()); - if (windowAggNode->isLocalAgg()) { - return std::make_unique( - std::move(op), - std::move(targets), - std::move(keySelector), - std::move(sliceAssigner), - windowAggNode->windowInterval(), - windowAggNode->useDayLightSaving(), - windowAggNode->outputType()); - } else { - auto localAggregator = nodeToOperator(windowAggNode->localAgg(), ctx); - std::unique_ptr globalSliceAssigner = - std::make_unique( - std::move(sliceAssigner), - windowAggNode->size(), - windowAggNode->step(), - windowAggNode->offset(), - windowAggNode->windowType(), - windowAggNode->rowtimeIndex()); - return std::make_unique( - std::move(localAggregator), - std::move(op), - std::move(targets), - std::move(keySelector), - std::move(globalSliceAssigner), - windowAggNode->windowInterval(), - windowAggNode->useDayLightSaving()); - } - } else if ( - auto windowAggNode = - std::dynamic_pointer_cast(statefulNode->node())) { - std::unique_ptr keySelector = - std::make_unique( - std::move(windowAggNode->keySelectorSpec()->create(INT_MAX, true)), - op->pool()); - std::unique_ptr sliceAssigner = - std::make_unique( - std::move(windowAggNode->sliceAssignerSpec()->create(INT_MAX, true)), - op->pool()); - std::unique_ptr windowAssigner = + std::move(keySelector), + std::move(sliceAssigner), + windowAggNode->windowInterval(), + windowAggNode->useDayLightSaving(), + windowAggNode->outputType()); + } else { + auto localAggregator = transformOperator(windowAggNode->localAgg()); + std::unique_ptr globalSliceAssigner = std::make_unique( std::move(sliceAssigner), - 0, - 0, - 0, + windowAggNode->size(), + windowAggNode->step(), + windowAggNode->offset(), windowAggNode->windowType(), windowAggNode->rowtimeIndex()); - return std::make_unique( - std::unique_ptr(dynamic_cast(op.release())), - // TODO: support window parameters - std::make_unique(10, windowAggNode->isEventTime()), - std::move(targets), - std::move(keySelector), - std::move(windowAssigner), - windowAggNode->allowedLateness(), - windowAggNode->produceUpdates(), - windowAggNode->rowtimeIndex(), - windowAggNode->isEventTime()); - } else if ( - auto rankNode = - std::dynamic_pointer_cast(statefulNode->node())) { - std::unique_ptr keySelector = - std::make_unique( - std::move(rankNode->keySelectorSpec()->create(INT_MAX, true)), - op->pool()); - return std::make_unique( - std::move(op), - std::move(keySelector), - std::move(targets)); - } else if ( - auto aggNode = - std::dynamic_pointer_cast(statefulNode->node())) { - std::unique_ptr keySelector = - std::make_unique( - std::move(aggNode->keySelectorSpec()->create(INT_MAX, true)), - op->pool()); - return std::make_unique( + return std::make_unique( + std::move(localAggregator), std::move(op), + std::move(targets), std::move(keySelector), - std::move(targets)); + std::move(globalSliceAssigner), + windowAggNode->windowInterval(), + windowAggNode->useDayLightSaving()); } +} + +StatefulOperatorPtr StatefulPlanner::transformGroupWindowAggregationOperator( + const StatefulPlanNode& planNode) { + std::vector targets = + transformStatefulOperators(planNode.targets()); + + auto windowAggNode = + std::dynamic_pointer_cast( + planNode.node()); + VELOX_CHECK(windowAggNode, "Failed to cast to GroupWindowAggregationNode"); + + auto op = transformOperator(windowAggNode->aggregation()); + + std::unique_ptr keySelector = std::make_unique( + windowAggNode->keySelectorSpec()->create(INT_MAX, true), op->pool()); + std::unique_ptr sliceAssigner = std::make_unique( + windowAggNode->sliceAssignerSpec()->create(INT_MAX, true), op->pool()); + std::unique_ptr windowAssigner = + std::make_unique( + std::move(sliceAssigner), + 0, + 0, + 0, + windowAggNode->windowType(), + windowAggNode->rowtimeIndex()); + + return std::make_unique( + std::unique_ptr( + dynamic_cast(op.release())), + // TODO: support window parameters + std::make_unique(10, windowAggNode->isEventTime()), + std::move(targets), + std::move(keySelector), + std::move(windowAssigner), + windowAggNode->allowedLateness(), + windowAggNode->produceUpdates(), + windowAggNode->rowtimeIndex(), + windowAggNode->isEventTime()); +} + +StatefulOperatorPtr StatefulPlanner::transformStreamRankOperator( + const StatefulPlanNode& planNode) { + std::vector targets = + transformStatefulOperators(planNode.targets()); + + auto rankNode = + std::dynamic_pointer_cast(planNode.node()); + VELOX_CHECK(rankNode, "Failed to cast to StreamRankNode"); + + auto op = transformOperator(rankNode->ranker()); + + std::unique_ptr keySelector = std::make_unique( + rankNode->keySelectorSpec()->create(INT_MAX, true), op->pool()); + + return std::make_unique( + std::move(op), std::move(keySelector), std::move(targets)); +} + +StatefulOperatorPtr StatefulPlanner::transformGroupAggregationOperator( + const StatefulPlanNode& planNode) { + std::vector targets = + transformStatefulOperators(planNode.targets()); + + auto aggNode = + std::dynamic_pointer_cast(planNode.node()); + VELOX_CHECK(aggNode, "Failed to cast to GroupAggregationNode"); + + auto op = transformOperator(aggNode->aggregation()); + + std::unique_ptr keySelector = std::make_unique( + aggNode->keySelectorSpec()->create(INT_MAX, true), op->pool()); + return std::make_unique( + std::move(op), std::move(keySelector), std::move(targets)); +} + +StatefulOperatorPtr StatefulPlanner::transformGenericOperator( + const StatefulPlanNode& planNode) { + std::vector targets = + transformStatefulOperators(planNode.targets()); + std::unique_ptr op = transformOperator(planNode.node()); return std::make_unique(std::move(op), std::move(targets)); } -//static -std::unique_ptr StatefulPlanner::nodeToOperator( - const core::PlanNodePtr& planNode, - exec::DriverCtx* ctx) { +std::unique_ptr StatefulPlanner::transformOperator( + const core::PlanNodePtr& planNode) { if (auto filterNode = - std::dynamic_pointer_cast(planNode)) { + std::dynamic_pointer_cast(planNode)) { if (planNode->sources().size() == 1) { auto next = planNode->sources()[0]; if (auto projectNode = - std::dynamic_pointer_cast(next)) { + std::dynamic_pointer_cast(next)) { return std::make_unique( - opId.fetch_add(1), - ctx, - filterNode, - projectNode); + nextOperatorId(), ctx_, filterNode, projectNode); } } - return std::make_unique(opId.fetch_add(1), ctx, filterNode, nullptr); + return std::make_unique( + nextOperatorId(), ctx_, filterNode, nullptr); } else if ( auto projectNode = std::dynamic_pointer_cast(planNode)) { std::shared_ptr filterNode = nullptr; const std::vector& sources = projectNode->sources(); if (sources.size() == 1) { - filterNode = std::dynamic_pointer_cast(sources[0]); + filterNode = + std::dynamic_pointer_cast(sources[0]); } - return std::make_unique(opId.fetch_add(1), ctx, filterNode, projectNode); - } else if ( - auto joinNode = - std::dynamic_pointer_cast(planNode)) { - return nodeToOperator(joinNode->probe(), ctx); - } else if ( - auto partitionNode = - std::dynamic_pointer_cast(planNode)) { - return std::make_unique(opId.fetch_add(1), ctx, partitionNode->partition()); + return std::make_unique( + nextOperatorId(), ctx_, filterNode, projectNode); } else if ( auto valuesNode = std::dynamic_pointer_cast(planNode)) { - return std::make_unique(opId.fetch_add(1), ctx, valuesNode); + return std::make_unique(nextOperatorId(), ctx_, valuesNode); } else if ( auto tableScanNode = std::dynamic_pointer_cast(planNode)) { - return std::make_unique(opId.fetch_add(1), ctx, tableScanNode); + return std::make_unique( + nextOperatorId(), ctx_, tableScanNode); } else if ( auto tableWriteNode = std::dynamic_pointer_cast(planNode)) { - return std::make_unique(opId.fetch_add(1), ctx, tableWriteNode); + return std::make_unique( + nextOperatorId(), ctx_, tableWriteNode); } else if ( auto tableWriteMergeNode = - std::dynamic_pointer_cast(planNode)) { - return std::make_unique(opId.fetch_add(1), ctx, tableWriteMergeNode); + std::dynamic_pointer_cast( + planNode)) { + return std::make_unique( + nextOperatorId(), ctx_, tableWriteMergeNode); } else if ( auto joinNode = std::dynamic_pointer_cast(planNode)) { - return std::make_unique(opId.fetch_add(1), ctx, joinNode); + return std::make_unique(nextOperatorId(), ctx_, joinNode); } else if ( auto joinNode = std::dynamic_pointer_cast(planNode)) { - return std::make_unique(opId.fetch_add(1), ctx, joinNode); + return std::make_unique( + nextOperatorId(), ctx_, joinNode); } else if ( auto joinNode = - std::dynamic_pointer_cast(planNode)) { - return std::make_unique(opId.fetch_add(1), ctx, joinNode); + std::dynamic_pointer_cast( + planNode)) { + return std::make_unique( + nextOperatorId(), ctx_, joinNode); } else if ( auto aggregationNode = std::dynamic_pointer_cast(planNode)) { if (aggregationNode->isPreGrouped()) { - return std::make_unique(opId.fetch_add(1), ctx, aggregationNode); + return std::make_unique( + nextOperatorId(), ctx_, aggregationNode); } else { - return std::make_unique(opId.fetch_add(1), ctx, aggregationNode); + return std::make_unique( + nextOperatorId(), ctx_, aggregationNode); } } else if ( auto expandNode = std::dynamic_pointer_cast(planNode)) { - return std::make_unique(opId.fetch_add(1), ctx, expandNode); + return std::make_unique(nextOperatorId(), ctx_, expandNode); } else if ( auto groupIdNode = std::dynamic_pointer_cast(planNode)) { - return std::make_unique(opId.fetch_add(1), ctx, groupIdNode); + return std::make_unique(nextOperatorId(), ctx_, groupIdNode); } else if ( auto topNNode = std::dynamic_pointer_cast(planNode)) { - return std::make_unique(opId.fetch_add(1), ctx, topNNode); + return std::make_unique(nextOperatorId(), ctx_, topNNode); } else if ( auto limitNode = std::dynamic_pointer_cast(planNode)) { - return std::make_unique(opId.fetch_add(1), ctx, limitNode); + return std::make_unique(nextOperatorId(), ctx_, limitNode); } else if ( auto orderByNode = std::dynamic_pointer_cast(planNode)) { - return std::make_unique(opId.fetch_add(1), ctx, orderByNode); + return std::make_unique(nextOperatorId(), ctx_, orderByNode); } else if ( auto windowNode = std::dynamic_pointer_cast(planNode)) { - return std::make_unique(opId.fetch_add(1), ctx, windowNode); + return std::make_unique(nextOperatorId(), ctx_, windowNode); } else if ( auto rowNumberNode = std::dynamic_pointer_cast(planNode)) { - return std::make_unique(opId.fetch_add(1), ctx, rowNumberNode); + return std::make_unique( + nextOperatorId(), ctx_, rowNumberNode); } else if ( auto topNRowNumberNode = std::dynamic_pointer_cast(planNode)) { - return std::make_unique(opId.fetch_add(1), ctx, topNRowNumberNode); + return std::make_unique( + nextOperatorId(), ctx_, topNRowNumberNode); } else if ( auto markDistinctNode = std::dynamic_pointer_cast(planNode)) { - return std::make_unique(opId.fetch_add(1), ctx, markDistinctNode); + return std::make_unique( + nextOperatorId(), ctx_, markDistinctNode); } else if ( auto mergeJoin = std::dynamic_pointer_cast(planNode)) { - auto mergeJoinOp = std::make_unique(opId.fetch_add(1), ctx, mergeJoin); - ctx->task->createMergeJoinSource(ctx->splitGroupId, mergeJoin->id()); + auto mergeJoinOp = + std::make_unique(nextOperatorId(), ctx_, mergeJoin); + ctx_->task->createMergeJoinSource(ctx_->splitGroupId, mergeJoin->id()); return std::move(mergeJoinOp); } else if ( auto unnest = std::dynamic_pointer_cast(planNode)) { - return std::make_unique(opId.fetch_add(1), ctx, unnest); + return std::make_unique(nextOperatorId(), ctx_, unnest); } else if ( auto enforceSingleRow = - std::dynamic_pointer_cast(planNode)) { - return std::make_unique(opId.fetch_add(1), ctx, enforceSingleRow); + std::dynamic_pointer_cast( + planNode)) { + return std::make_unique( + nextOperatorId(), ctx_, enforceSingleRow); } else if ( auto assignUniqueIdNode = std::dynamic_pointer_cast(planNode)) { return std::make_unique( - opId.fetch_add(1), - ctx, + nextOperatorId(), + ctx_, assignUniqueIdNode, assignUniqueIdNode->taskUniqueId(), assignUniqueIdNode->uniqueIdCounter()); } else if ( - auto watermarkAssignerNode = - std::dynamic_pointer_cast(planNode)) { - return std::make_unique( - opId.fetch_add(1), - ctx, - nullptr, - watermarkAssignerNode->project()); - } else if ( - auto joinNode = - std::dynamic_pointer_cast(planNode)) { - return nodeToOperator(joinNode->probe(), ctx); - } else if ( - auto windowAggNode = - std::dynamic_pointer_cast(planNode)) { - return nodeToOperator(windowAggNode->aggregation(), ctx); - } else if ( - auto windowAggNode = - std::dynamic_pointer_cast(planNode)) { - return nodeToOperator(windowAggNode->aggregation(), ctx); + auto aggsHandlerNode = + std::dynamic_pointer_cast( + planNode)) { + return std::make_unique( + nextOperatorId(), ctx_, aggsHandlerNode); } else if ( auto aggsHandlerNode = - std::dynamic_pointer_cast(planNode)) { - return std::make_unique(opId.fetch_add(1), ctx, aggsHandlerNode); + std::dynamic_pointer_cast(planNode)) { + // FIXME: stateRetentionTime is not handled yet + return std::make_unique( + nextOperatorId(), + ctx_, + aggsHandlerNode, + std::make_unique(), // TODO: not complete yet + 0, + aggsHandlerNode->generateUpdateBefore()); } else if ( auto deduplicateNode = std::dynamic_pointer_cast(planNode)) { return std::make_unique( - opId.fetch_add(1), - ctx, + nextOperatorId(), + ctx_, deduplicateNode, - deduplicateNode->rowtimeIndex(), deduplicateNode->minRetentionTime(), + deduplicateNode->rowtimeIndex(), deduplicateNode->generateUpdateBefore(), deduplicateNode->generateInsert(), deduplicateNode->keepLastRow()); } else if ( auto topNNode = std::dynamic_pointer_cast(planNode)) { - auto op = nodeToOperator(topNNode->topN(), ctx); - std::unique_ptr sortKeySelector = - std::make_unique( - std::move(topNNode->sortKeySelectorSpec()->create(INT_MAX, true)), - op->pool()); + auto op = transformOperator(topNNode->topN()); + std::shared_ptr sortKeySelector = + std::make_shared( + topNNode->sortKeySelectorSpec()->create(INT_MAX, true), op->pool()); return std::make_unique( - opId.fetch_add(1), - ctx, + nextOperatorId(), + ctx_, topNNode, std::move(op), - std::move(sortKeySelector), + sortKeySelector, topNNode->generateUpdateBefore(), topNNode->outputRankNumber(), topNNode->cacheSize()); - } else if ( - auto rankNode = - std::dynamic_pointer_cast(planNode)) { - return nodeToOperator(rankNode->ranker(), ctx); - } else if ( - auto groupAggNode = - std::dynamic_pointer_cast(planNode)) { - return nodeToOperator(groupAggNode->aggregation(), ctx); - } else if ( - auto aggsHandlerNode = - std::dynamic_pointer_cast(planNode)) { - return std::make_unique( - opId.fetch_add(1), - ctx, - aggsHandlerNode, - std::make_unique(), // TODO: not complete yet - aggsHandlerNode->generateUpdateBefore(), - aggsHandlerNode->needRetraction()); - } else { - std::unique_ptr extended; - extended = exec::Operator::fromPlanNode(ctx, opId.fetch_add(1), planNode); - VELOX_CHECK(extended, "Unsupported plan node: {}", planNode->toString()); - return extended; } + std::unique_ptr extended = + exec::Operator::fromPlanNode(ctx_, nextOperatorId(), planNode); + VELOX_CHECK( + extended, + "Unsupported plan node: {}\n{}", + planNode->toString(), + process::StackTrace().toString()); + return extended; } } // namespace facebook::velox::stateful diff --git a/velox/experimental/stateful/StatefulPlanner.h b/velox/experimental/stateful/StatefulPlanner.h index 06f4aea5b271..afebee685452 100644 --- a/velox/experimental/stateful/StatefulPlanner.h +++ b/velox/experimental/stateful/StatefulPlanner.h @@ -17,6 +17,7 @@ #include "velox/exec/Operator.h" #include "velox/experimental/stateful/StatefulOperator.h" +#include "velox/experimental/stateful/StatefulPlanNode.h" #include "velox/experimental/stateful/state/StateBackend.h" namespace facebook::velox::core { @@ -26,7 +27,6 @@ struct PlanFragment; namespace facebook::velox::stateful { class StatefulPlanner { - public: // Create stateful operator chain according to plan. static StatefulOperatorPtr plan( @@ -34,14 +34,37 @@ class StatefulPlanner { exec::DriverCtx* ctx, StateBackend* stateBackend); + protected: + StatefulPlanner(exec::DriverCtx* ctx, StateBackend* stateBackend) + : ctx_(ctx), stateBackend_(stateBackend) {} + private: - static std::unique_ptr nodeToStatefulOperator( - const core::PlanNodePtr& planNode, - exec::DriverCtx* ctx, - StateBackend* stateBackend); + exec::DriverCtx* ctx_ = nullptr; + StateBackend* stateBackend_ = nullptr; - static std::unique_ptr nodeToOperator( - const core::PlanNodePtr& planNode, - exec::DriverCtx* ctx); + StatefulOperatorPtr transformStatefulOperators( + const core::PlanNodePtr& planNode); + std::vector transformStatefulOperators( + const std::vector& targets); + std::unique_ptr transformOperator( + const core::PlanNodePtr& planNode); + StatefulOperatorPtr transformWatermarkAssignerOperator( + const StatefulPlanNode& planNode); + StatefulOperatorPtr transformStreamPartitionOperator( + const StatefulPlanNode& planNode); + StatefulOperatorPtr transformStreamJoinOperator( + const StatefulPlanNode& planNode); + StatefulOperatorPtr transformStreamWindowJoinOperator( + const StatefulPlanNode& planNode); + StatefulOperatorPtr transformStreamWindowAggregationOperator( + const StatefulPlanNode& planNode); + StatefulOperatorPtr transformGroupWindowAggregationOperator( + const StatefulPlanNode& planNode); + StatefulOperatorPtr transformStreamRankOperator( + const StatefulPlanNode& planNode); + StatefulOperatorPtr transformGroupAggregationOperator( + const StatefulPlanNode& planNode); + StatefulOperatorPtr transformGenericOperator( + const StatefulPlanNode& planNode); }; } // namespace facebook::velox::stateful