diff --git a/jlm/hls/backend/rvsdg2rhls/UnusedStateRemoval.cpp b/jlm/hls/backend/rvsdg2rhls/UnusedStateRemoval.cpp index 550fb5da1..035ccb162 100644 --- a/jlm/hls/backend/rvsdg2rhls/UnusedStateRemoval.cpp +++ b/jlm/hls/backend/rvsdg2rhls/UnusedStateRemoval.cpp @@ -14,6 +14,12 @@ namespace jlm::hls { +static bool +IsPassthroughLoopVar(const rvsdg::ThetaNode::LoopVar & loopvar) +{ + return loopvar.pre->nusers() == 1 && loopvar.post->origin() == loopvar.pre; +} + static bool IsPassthroughArgument(const rvsdg::output & argument) { @@ -113,22 +119,6 @@ RemoveUnusedStatesFromLambda(rvsdg::LambdaNode & lambdaNode) remove(&lambdaNode); } -static void -RemovePassthroughArgument(const rvsdg::RegionArgument & argument) -{ - auto origin = argument.input()->origin(); - auto result = dynamic_cast(*argument.begin()); - argument.region()->node()->output(result->output()->index())->divert_users(origin); - - auto inputIndex = argument.input()->index(); - auto outputIndex = result->output()->index(); - auto region = argument.region(); - region->RemoveResult(result->index()); - region->RemoveArgument(argument.index()); - region->node()->RemoveInput(inputIndex); - region->node()->RemoveOutput(outputIndex); -} - static void RemoveUnusedStatesFromGammaNode(rvsdg::GammaNode & gammaNode) { @@ -176,15 +166,16 @@ RemoveUnusedStatesFromGammaNode(rvsdg::GammaNode & gammaNode) static void RemoveUnusedStatesFromThetaNode(rvsdg::ThetaNode & thetaNode) { - auto thetaSubregion = thetaNode.subregion(); - for (int i = thetaSubregion->narguments() - 1; i >= 0; --i) + std::vector loopvars; + for (const auto & loopvar : thetaNode.GetLoopVars()) { - auto & argument = *thetaSubregion->argument(i); - if (IsPassthroughArgument(argument)) + if (IsPassthroughLoopVar(loopvar)) { - RemovePassthroughArgument(argument); + loopvar.output->divert_users(loopvar.input->origin()); + loopvars.push_back(loopvar); } } + thetaNode.RemoveLoopVars(std::move(loopvars)); } static void diff --git a/jlm/hls/backend/rvsdg2rhls/distribute-constants.cpp b/jlm/hls/backend/rvsdg2rhls/distribute-constants.cpp index 966662fe8..5121f5e17 100644 --- a/jlm/hls/backend/rvsdg2rhls/distribute-constants.cpp +++ b/jlm/hls/backend/rvsdg2rhls/distribute-constants.cpp @@ -34,10 +34,8 @@ distribute_constant(const rvsdg::SimpleOperation & op, rvsdg::SimpleOutput * out loopvar.pre->divert_users(arg_replacement); loopvar.output->divert_users(out); distribute_constant(op, arg_replacement); - theta->subregion()->RemoveResult(loopvar.post->index()); - theta->subregion()->RemoveArgument(loopvar.pre->index()); - theta->RemoveInput(loopvar.input->index()); - theta->RemoveOutput(loopvar.output->index()); + loopvar.post->divert_to(loopvar.pre); + theta->RemoveLoopVars({ loopvar }); changed = true; break; } diff --git a/tests/jlm/hls/backend/rvsdg2rhls/UnusedStateRemovalTests.cpp b/tests/jlm/hls/backend/rvsdg2rhls/UnusedStateRemovalTests.cpp index 5ee08c2c2..3445182f2 100644 --- a/tests/jlm/hls/backend/rvsdg2rhls/UnusedStateRemovalTests.cpp +++ b/tests/jlm/hls/backend/rvsdg2rhls/UnusedStateRemovalTests.cpp @@ -115,12 +115,7 @@ TestTheta() jlm::hls::RemoveUnusedStates(*rvsdgModule); // Assert - // This assert is only here so that we do not forget this test when we refactor the code - assert(thetaNode->ninputs() == 1); - - // FIXME: This transformation is broken for theta nodes. For the setup above, it - // removes all inputs/outputs, except the predicate. However, the only - // input and output it should remove are input 1 and output 0, respectively. + assert(thetaNode->ninputs() == 3); } static void