From e35c96539461b2915a90e5826dd946385892ef7a Mon Sep 17 00:00:00 2001 From: Helge Bahmann Date: Tue, 7 Jan 2025 21:09:26 +0100 Subject: [PATCH] Simplify HLS loop dead node elimination After removing theta invariant violation from dead node elimination from llvm/opt, try to make HLS handling similarly regular. --- .../backend/rvsdg2rhls/UnusedStateRemoval.cpp | 33 +++++++------------ .../rvsdg2rhls/distribute-constants.cpp | 6 ++-- .../rvsdg2rhls/UnusedStateRemovalTests.cpp | 7 +--- 3 files changed, 15 insertions(+), 31 deletions(-) diff --git a/jlm/hls/backend/rvsdg2rhls/UnusedStateRemoval.cpp b/jlm/hls/backend/rvsdg2rhls/UnusedStateRemoval.cpp index fd164d1a2..51b8610f0 100644 --- a/jlm/hls/backend/rvsdg2rhls/UnusedStateRemoval.cpp +++ b/jlm/hls/backend/rvsdg2rhls/UnusedStateRemoval.cpp @@ -13,6 +13,12 @@ namespace jlm::hls { +static bool +IsPassthroughLoopVar(const rvsdg::ThetaNode::LoopVar & loopvar) +{ + return loopvar.pre->nusers() == 1 && loopvar.post->origin() == loopvar.pre; +} + static bool IsPassthroughArgument(const rvsdg::output & argument) { @@ -114,22 +120,6 @@ RemoveUnusedStatesFromLambda(llvm::lambda::node & lambdaNode) remove(&lambdaNode); } -static void -RemovePassthroughArgument(const rvsdg::RegionArgument & argument) -{ - auto origin = argument.input()->origin(); - auto result = dynamic_cast(*argument.begin()); - argument.region()->node()->output(result->output()->index())->divert_users(origin); - - auto inputIndex = argument.input()->index(); - auto outputIndex = result->output()->index(); - auto region = argument.region(); - region->RemoveResult(result->index()); - region->RemoveArgument(argument.index()); - region->node()->RemoveInput(inputIndex); - region->node()->RemoveOutput(outputIndex); -} - static void RemoveUnusedStatesFromGammaNode(rvsdg::GammaNode & gammaNode) { @@ -177,15 +167,16 @@ RemoveUnusedStatesFromGammaNode(rvsdg::GammaNode & gammaNode) static void RemoveUnusedStatesFromThetaNode(rvsdg::ThetaNode & thetaNode) { - auto thetaSubregion = thetaNode.subregion(); - for (int i = thetaSubregion->narguments() - 1; i >= 0; --i) + std::vector loopvars; + for (const auto & loopvar : thetaNode.GetLoopVars()) { - auto & argument = *thetaSubregion->argument(i); - if (IsPassthroughArgument(argument)) + if (IsPassthroughLoopVar(loopvar)) { - RemovePassthroughArgument(argument); + loopvar.output->divert_users(loopvar.input->origin()); + loopvars.push_back(loopvar); } } + thetaNode.RemoveLoopVars(std::move(loopvars)); } static void diff --git a/jlm/hls/backend/rvsdg2rhls/distribute-constants.cpp b/jlm/hls/backend/rvsdg2rhls/distribute-constants.cpp index 7fd439520..4aa20ccf9 100644 --- a/jlm/hls/backend/rvsdg2rhls/distribute-constants.cpp +++ b/jlm/hls/backend/rvsdg2rhls/distribute-constants.cpp @@ -36,10 +36,8 @@ distribute_constant(const rvsdg::SimpleOperation & op, rvsdg::simple_output * ou loopvar.output->divert_users( rvsdg::SimpleNode::create_normalized(out->region(), op, {})[0]); distribute_constant(op, arg_replacement); - theta->subregion()->RemoveResult(loopvar.post->index()); - theta->subregion()->RemoveArgument(loopvar.pre->index()); - theta->RemoveInput(loopvar.input->index()); - theta->RemoveOutput(loopvar.output->index()); + loopvar.post->divert_to(loopvar.pre); + theta->RemoveLoopVars({ loopvar }); changed = true; break; } diff --git a/tests/jlm/hls/backend/rvsdg2rhls/UnusedStateRemovalTests.cpp b/tests/jlm/hls/backend/rvsdg2rhls/UnusedStateRemovalTests.cpp index 8320c0d4c..05ef21789 100644 --- a/tests/jlm/hls/backend/rvsdg2rhls/UnusedStateRemovalTests.cpp +++ b/tests/jlm/hls/backend/rvsdg2rhls/UnusedStateRemovalTests.cpp @@ -112,12 +112,7 @@ TestTheta() jlm::hls::RemoveUnusedStates(*rvsdgModule); // Assert - // This assert is only here so that we do not forget this test when we refactor the code - assert(thetaNode->ninputs() == 1); - - // FIXME: This transformation is broken for theta nodes. For the setup above, it - // removes all inputs/outputs, except the predicate. However, the only - // input and output it should remove are input 1 and output 0, respectively. + assert(thetaNode->ninputs() == 3); } static void