diff --git a/jlm/hls/Makefile.sub b/jlm/hls/Makefile.sub index 3e1e1de41..3e5fc17be 100644 --- a/jlm/hls/Makefile.sub +++ b/jlm/hls/Makefile.sub @@ -33,9 +33,15 @@ libhls_SOURCES = \ jlm/hls/backend/rvsdg2rhls/rvsdg2rhls.cpp \ jlm/hls/backend/rvsdg2rhls/ThetaConversion.cpp \ jlm/hls/backend/rvsdg2rhls/UnusedStateRemoval.cpp \ + \ + jlm/hls/backend/rvsdg2rhls/static/ThetaConversion.cpp \ + jlm/hls/backend/rvsdg2rhls/static/rvsdg2rhls.cpp \ \ jlm/hls/ir/hls.cpp \ - \ + jlm/hls/ir/static-hls.cpp \ + jlm/hls/ir/static/loop.cpp \ + jlm/hls/ir/static/fsm.cpp \ + \ jlm/hls/util/view.cpp \ libhls_HEADERS = \ @@ -70,7 +76,13 @@ libhls_HEADERS = \ jlm/hls/backend/rvsdg2rhls/ThetaConversion.hpp \ jlm/hls/backend/rvsdg2rhls/UnusedStateRemoval.hpp \ \ + jlm/hls/backend/rvsdg2rhls/static/ThetaConversion.hpp \ + jlm/hls/backend/rvsdg2rhls/static/rvsdg2rhls.hpp \ + \ jlm/hls/ir/hls.hpp \ + jlm/hls/ir/static-hls.hpp \ + jlm/hls/ir/static/loop.hpp \ + jlm/hls/ir/static/fsm.hpp \ \ jlm/hls/util/view.hpp \ @@ -81,6 +93,8 @@ libhls_TESTS += \ tests/jlm/hls/backend/rvsdg2rhls/TestTheta \ tests/jlm/hls/backend/rvsdg2rhls/UnusedStateRemovalTests \ tests/jlm/hls/backend/rvsdg2rhls/test-loop-passthrough \ + \ + tests/jlm/hls/backend/rvsdg2rhls/static/TestTheta \ libhls_TEST_LIBS += \ libjlmtest \ diff --git a/jlm/hls/backend/rvsdg2rhls/static/ThetaConversion.cpp b/jlm/hls/backend/rvsdg2rhls/static/ThetaConversion.cpp new file mode 100644 index 000000000..a89f35adc --- /dev/null +++ b/jlm/hls/backend/rvsdg2rhls/static/ThetaConversion.cpp @@ -0,0 +1,123 @@ +/* + * Copyright 2024 Louis Maurin + * See COPYING for terms of redistribution. 
+ */ + +#include +#include +#include + +#include + +namespace jlm::static_hls +{ + +/*! \brief Adds every node of \p region to \p loop, visiting the region top-down. + * Nodes that fail the cast check below abort via JLM_UNREACHABLE. + */ +void +addSimpleNodes(jlm::rvsdg::region & region, jlm::static_hls::loop_node & loop) +{ + for (auto & node : jlm::rvsdg::topdown_traverser(&region)) + { + // FIXME now we only handle simple nodes + if (dynamic_cast(node)) + { + loop.add_node(node); + } + else + { + JLM_UNREACHABLE("Static HLS only support simple nodes in theta node at this point"); + } + } +} + +static void +ConvertThetaNode(jlm::rvsdg::theta_node & theta) +{ + std::cout << "***** Converting theta node *****" << std::endl; + + auto loop = static_hls::loop_node::create(theta.region()); + + // add loopvars and populate the smap + for (size_t i = 0; i < theta.ninputs(); i++) + { + loop->add_loopvar(theta.input(i)); + // divert theta outputs + theta.output(i)->divert_users(loop->output(i)); + } + + // copy contents of theta + addSimpleNodes(*theta.subregion(), *loop); + + for (size_t i = 0; i < theta.ninputs(); i++) + { + loop->add_loopback_arg(theta.input(i)); + } + + loop->print_nodes_registers(); + + std::cout << "**** Printing operations users ****" << std::endl; + for (auto & node : loop->compute_subregion()->nodes) + { + for (size_t i = 0; i < node.ninputs(); i++) + { + std::cout << "node " << node.operation().debug_string(); + std::cout << " input " << i; + std::cout << " users: "; + auto users = loop->get_users(node.input(i)); + for (auto & user : users) + { + if (auto node_out = dynamic_cast(user)) + { + std::cout << node_out->node()->operation().debug_string() << ", "; + } + else if (auto arg = dynamic_cast(user)) + { + std::cout << "control arg " << arg->index() << ", "; + } + } + std::cout << std::endl; + } + } + + loop->finalize(); + + // // copy contents of theta + // theta.subregion()->copy(loop->subregion(), smap, false, false); + remove(&theta); +} + +static void +ConvertThetaNodesInRegion(jlm::rvsdg::region & region); + +static void +ConvertThetaNodesInStructuralNode(jlm::rvsdg::structural_node & 
structuralNode) +{ + for (size_t n = 0; n < structuralNode.nsubregions(); n++) + { + ConvertThetaNodesInRegion(*structuralNode.subregion(n)); + } + + if (auto thetaNode = dynamic_cast(&structuralNode)) + { + ConvertThetaNode(*thetaNode); + } +} + +/*! \brief Visits the nodes of \p region top-down and hands every node that passes the cast + * check to ConvertThetaNodesInStructuralNode(). + */ +static void +ConvertThetaNodesInRegion(jlm::rvsdg::region & region) +{ + for (auto & node : jlm::rvsdg::topdown_traverser(&region)) + { + if (auto structuralNode = dynamic_cast(node)) + { + ConvertThetaNodesInStructuralNode(*structuralNode); + } + } +} + +void +ConvertThetaNodes(jlm::llvm::RvsdgModule & rvsdgModule) +{ + ConvertThetaNodesInRegion(*rvsdgModule.Rvsdg().root()); +} + +} diff --git a/jlm/hls/backend/rvsdg2rhls/static/ThetaConversion.hpp b/jlm/hls/backend/rvsdg2rhls/static/ThetaConversion.hpp new file mode 100644 index 000000000..3cb1efb94 --- /dev/null +++ b/jlm/hls/backend/rvsdg2rhls/static/ThetaConversion.hpp @@ -0,0 +1,21 @@ +/* + * Copyright 2024 Louis Maurin + * See COPYING for terms of redistribution. + */ + +#ifndef JLM_HLS_BACKEND_RVSDG2RHLS_STATIC_THETACONVERSION_HPP +#define JLM_HLS_BACKEND_RVSDG2RHLS_STATIC_THETACONVERSION_HPP + +#include +#include +#include + +namespace jlm::static_hls +{ + +void +ConvertThetaNodes(jlm::llvm::RvsdgModule & rvsdgModule); + +} // namespace jlm::static_hls + +#endif // JLM_HLS_BACKEND_RVSDG2RHLS_STATIC_THETACONVERSION_HPP diff --git a/jlm/hls/backend/rvsdg2rhls/static/rvsdg2rhls.cpp b/jlm/hls/backend/rvsdg2rhls/static/rvsdg2rhls.cpp new file mode 100644 index 000000000..da8001486 --- /dev/null +++ b/jlm/hls/backend/rvsdg2rhls/static/rvsdg2rhls.cpp @@ -0,0 +1,21 @@ +/* + * Copyright 2024 Louis Maurin + * See COPYING for terms of redistribution. 
+ */ + +#include +#include +#include + +namespace jlm::static_hls +{ + +using namespace jlm::static_hls; + +void +rvsdg2rhls(llvm::RvsdgModule & rvsdgModule) +{ + ConvertThetaNodes(rvsdgModule); +} + +} // namespace jlm::static_hls diff --git a/jlm/hls/backend/rvsdg2rhls/static/rvsdg2rhls.hpp b/jlm/hls/backend/rvsdg2rhls/static/rvsdg2rhls.hpp new file mode 100644 index 000000000..9e81a5ac9 --- /dev/null +++ b/jlm/hls/backend/rvsdg2rhls/static/rvsdg2rhls.hpp @@ -0,0 +1,20 @@ +/* + * Copyright 2024 Louis Maurin + * See COPYING for terms of redistribution. + */ + +#ifndef JLM_HLS_BACKEND_RVSDG2RHLS_STATIC_RVSDG2RHLS_HPP +#define JLM_HLS_BACKEND_RVSDG2RHLS_STATIC_RVSDG2RHLS_HPP + +#include +#include + +namespace jlm::static_hls +{ + +void +rvsdg2rhls(llvm::RvsdgModule & rvsdgModule); + +} + +#endif // JLM_HLS_BACKEND_RVSDG2RHLS_STATIC_RVSDG2RHLS_HPP diff --git a/jlm/hls/ir/static-hls.cpp b/jlm/hls/ir/static-hls.cpp new file mode 100644 index 000000000..9d7d28e66 --- /dev/null +++ b/jlm/hls/ir/static-hls.cpp @@ -0,0 +1,52 @@ +/* + * Copyright 2024 Louis Maurin + * See COPYING for terms of redistribution. 
+ */ + +#include + +namespace jlm::static_hls +{ + +jlm::rvsdg::node * +mux_add_input(jlm::rvsdg::node * old_mux, jlm::rvsdg::output * new_input, bool predicate) +{ + JLM_ASSERT(jlm::rvsdg::is(old_mux->operation())); + + // If the input already exists, return the old mux + // FIXME and add to the fsm + for (size_t i = 0; i < old_mux->ninputs(); i++) + { + if (old_mux->input(i)->origin() == new_input) + { + return old_mux; + } + } + + std::vector new_mux_inputs; + for (size_t i = 0; i < old_mux->ninputs(); i++) + { + new_mux_inputs.push_back(old_mux->input(i)->origin()); + } + + jlm::rvsdg::simple_node * new_mux; + if (!predicate) + { + new_mux_inputs.push_back(new_input); + + new_mux = jlm::static_hls::mux_op::create(new_mux_inputs); + } + else + { + new_mux = jlm::static_hls::mux_op::create(*new_input, new_mux_inputs); + } + + old_mux->output(0)->divert_users(new_mux->output(0)); + + remove(old_mux); + + // old_mux = static_cast(new_mux); + return new_mux; +}; + +} // namespace jlm::static_hls diff --git a/jlm/hls/ir/static-hls.hpp b/jlm/hls/ir/static-hls.hpp new file mode 100644 index 000000000..cc06d5295 --- /dev/null +++ b/jlm/hls/ir/static-hls.hpp @@ -0,0 +1,250 @@ +/* + * Copyright 2024 Louis Maurin + * See COPYING for terms of redistribution. 
+ */ + +#ifndef JLM_HLS_IR_STATIC_HLS_HPP +#define JLM_HLS_IR_STATIC_HLS_HPP + +// FIXME check what's needed +#include +#include +#include +#include +#include +#include + +namespace jlm::static_hls +{ + +// class mux_ctltype final : public jlm::rvsdg::statetype +// { +// public: +// virtual ~mux_ctltype() noexcept; + +// virtual std::string +// debug_string() const override; + +// virtual bool +// operator==(const jlm::rvsdg::type & other) const noexcept override; + +// std::shared_ptr +// copy() const override; + +// inline size_t +// nalternatives() const noexcept +// { +// return nalternatives_; +// } + +// /** +// * \brief Instantiates control type +// * +// * \returns Control type instance +// * +// * Creates an instance of a control type capable of representing +// * the specified number of alternatives. The returned instance +// * will usually be a static singleton for the type. +// */ +// static std::shared_ptr +// Create(); + +// private: +// size_t nalternatives_; +// }; + +/*! \brief Multiplexer operation of the static HLS IR. + * Nodes are built through the create() overloads, which also record the number of + * alternatives and whether operand 0 is a select predicate. + */ +class mux_op final : public jlm::rvsdg::simple_op +{ +private: + mux_op( + std::vector> & operands_type, + const std::shared_ptr & result_type) + // : jlm::rvsdg::simple_op(create_portvector(nalternatives, type), { type }) + : jlm::rvsdg::simple_op(operands_type, { result_type }) + {} + +public: + virtual ~mux_op() + {} + + inline size_t + nalternatives() const + { + return nalternatives_; + } + + /*! \brief Whether this mux was created with a predicate operand. + */ + inline bool + has_predicate() const + { + return has_predicate_; + } + + std::string + debug_string() const override + { + return "SHLS_MUX"; + }; + + bool + operator==(const jlm::rvsdg::operation & other) const noexcept override + { + auto ot = dynamic_cast<const mux_op *>(&other); + // The type of operand 0 alone does not tell predicate-less muxes of different widths apart, + // so the predicate flag and the number of alternatives are compared as well. + return ot && ot->has_predicate_ == has_predicate_ && ot->nalternatives_ == nalternatives_ + && ot->argument(0).type() == argument(0).type() + && ot->result(0).type() == result(0).type(); + }; + + std::unique_ptr + copy() const override + { + return std::unique_ptr(new mux_op(*this)); + }; + + static jlm::rvsdg::simple_node * + create(jlm::rvsdg::output & predicate, const 
std::vector & alternatives) + { + if (alternatives.empty()) + throw util::error("Insufficient number of operands."); + auto ctl = dynamic_cast(&predicate.type()); + if (!ctl) + throw util::error("Predicate needs to be a ctltype."); + if (alternatives.size() != ctl->nalternatives()) + throw util::error("Alternatives and predicate do not match."); + + auto region = predicate.region(); + auto operands = std::vector(); + operands.push_back(&predicate); + operands.insert(operands.end(), alternatives.begin(), alternatives.end()); + auto operands_type = std::vector>( + alternatives.size() + 1, + alternatives.front()->Type()); + operands_type.at(0) = jlm::rvsdg::ctltype::Create(alternatives.size()); + mux_op op(operands_type, alternatives.front()->Type()); + op.nalternatives_ = alternatives.size(); + op.has_predicate_ = true; + return jlm::rvsdg::simple_node::create(region, op, operands); + }; + + static jlm::rvsdg::simple_node * + create(const std::vector & alternatives) + { + if (alternatives.empty()) + throw util::error("Insufficient number of operands."); + + auto region = alternatives[0]->region(); + auto operands = std::vector(); + + // operands.push_back(); + operands.insert(operands.end(), alternatives.begin(), alternatives.end()); + auto operands_type = std::vector>( + alternatives.size(), + alternatives.front()->Type()); + mux_op op(operands_type, alternatives.front()->Type()); + op.nalternatives_ = alternatives.size(); + op.has_predicate_ = false; + return jlm::rvsdg::simple_node::create(region, op, operands); + }; + +private: + // static std::vector + // create_portvector(size_t nalternatives, const jlm::rvsdg::type & type) + // { + // auto vec = std::vector(nalternatives + 1, type); + // vec[0] = jlm::rvsdg::ctltype(nalternatives); + // return vec; + // }; + + size_t nalternatives_; + bool has_predicate_; +}; + +// TODO doc +/*! \brief Adds a input to a mux node. + * Internally create a new node and remove the old one !! 
+ * \param old_mux The mux node to add the input to. + * \param new_input The input to add to the mux. + * \return The new mux node with the added new input. + */ +jlm::rvsdg::node * +mux_add_input(jlm::rvsdg::node * old_mux, jlm::rvsdg::output * new_input, bool predicate = false); + +// TODO doc +inline jlm::rvsdg::node * +mux_connect_predicate(jlm::rvsdg::node * old_mux, jlm::rvsdg::output * predicate) +{ + return mux_add_input(old_mux, predicate, true); +}; + +// Counter read and incremented by the reg_op constructor to number register instances. +extern size_t instances_count; + +/*! \brief A register operation with a store predicate, data input and data output. + */ +class reg_op final : public jlm::rvsdg::simple_op +{ +public: + virtual ~reg_op() + {} + + reg_op(const std::shared_ptr & type) + : jlm::rvsdg::simple_op( + std::vector>{ jlm::rvsdg::ctltype::Create(2), + type }, + { type }) + // : jlm::rvsdg::simple_op(std::vector{type}, { type }) + { + id_ = instances_count++; + } + + std::string + debug_string() const override + { + if (origin_debug_string_.empty()) + return jlm::util::strfmt("SHLS_REG", id_); + return jlm::util::strfmt("SHLS_REG", id_, "(", origin_debug_string_, ")"); + } + + bool + operator==(const jlm::rvsdg::operation & other) const noexcept override + { + // Equal only to another register that carries the same data type; returning true + // unconditionally made a register compare equal to every other operation. + auto other_reg = dynamic_cast<const reg_op *>(&other); + return other_reg && other_reg->result(0).type() == result(0).type(); + } + + std::unique_ptr + copy() const override + { + return std::unique_ptr(new reg_op(*this)); + } + + /*! \brief Creates a new register node. + * \param store_input The store predicate of the register. + * \param input The data input of the register. + * \return The newly created register node. 
+ */ + static jlm::rvsdg::node * + create( + jlm::rvsdg::output & store_input, + jlm::rvsdg::output & input, + std::string origin_debug_string) + { + reg_op op(input.Type()); + op.origin_debug_string_ = origin_debug_string; + return jlm::rvsdg::simple_node::create(input.region(), op, { &store_input, &input }); + // return jlm::rvsdg::simple_node::create_normalized(input.region(), op, { &store_input, &input + // })[0]; + } + +private: + // static std::vector + // create_portvector(size_t nalternatives, const jlm::rvsdg::type & type) + // { + // auto vec = std::vector(nalternatives + 1, type); + // vec[0] = jlm::rvsdg::ctltype(nalternatives); + // return vec; + // } + size_t id_ = 0; // TODO delete this, only for debugging + std::string origin_debug_string_ = ""; +}; + +} // namespace jlm::static_hls + +#endif // JLM_HLS_IR_STATIC_HLS_HPP diff --git a/jlm/hls/ir/static/fsm.cpp b/jlm/hls/ir/static/fsm.cpp new file mode 100644 index 000000000..5a9bcafae --- /dev/null +++ b/jlm/hls/ir/static/fsm.cpp @@ -0,0 +1,242 @@ +/* + * Copyright 2024 Louis Maurin + * See COPYING for terms of redistribution. 
+ */ + +#include +#include + +#include + +namespace jlm::static_hls +{ + +fsm_state * +fsm_state::create(fsm_node_temp * parent_fsm_node, size_t index) +{ + // The region is being created with a pointer to the parent node and an index + // But the region is not being added to the parent node's subregions list + auto fs = new fsm_state(parent_fsm_node, index); + return fs; +}; + +void +fsm_state::enable_reg(jlm::rvsdg::node * node) +{ + JLM_ASSERT(jlm::rvsdg::is(node->operation())); + JLM_ASSERT(dynamic_cast(node->input(0)->origin())); + auto fsm_output = static_cast(node->input(0)->origin()); + + for (auto & result : fsm_output->results) + { + if (result.region() != this) + continue; + + auto ctrl_const = jlm::rvsdg::control_constant(this, 2, 1); + + auto old_origin = result.origin(); + result.divert_to(ctrl_const); + jlm::rvsdg::remove(static_cast(old_origin)->node()); + } +}; + +void +fsm_state::add_ctl_result(size_t nalternatives, jlm::rvsdg::structural_output * structural_output) +{ + auto ctrl_const = jlm::rvsdg::control_constant(this, nalternatives, 0); + + // TODO connecting these regions result to a dummy empty structural node like this is not clean + // and may break asumptions made in the jlm code base + //! 
This is just a temporary workaround + // note: calling create() will add the result to the results list of the structural_output + jlm::rvsdg::result::create( + this, + ctrl_const, + structural_output, + jlm::rvsdg::ctltype::Create(nalternatives)); +}; + +void +fsm_state::set_mux_ctl(jlm::rvsdg::input * result, size_t alternatives) +{ + muxes_ctl_[result] = alternatives; +}; + +void +fsm_state::apply_mux_ctl() +{ + for (auto & mux_ctl : muxes_ctl_) + { + auto result = mux_ctl.first; + auto mux_output = result->origin(); + JLM_ASSERT(jlm::rvsdg::is(mux_output)); + + auto mux = static_cast(mux_output)->node(); + auto fsm_output = mux->input(0)->origin(); + JLM_ASSERT(jlm::rvsdg::is(fsm_output)); + + for (auto & result : static_cast(fsm_output)->results) + { + if (result.region() != this) + continue; + + auto mux_op = dynamic_cast(&mux->operation()); + if (!mux_op) + JLM_UNREACHABLE("SHLS: Mux operation not found"); + + auto ctrl_const = jlm::rvsdg::control_constant(this, mux_op->nalternatives(), mux_ctl.second); + auto old_origin = result.origin(); + result.divert_to(ctrl_const); + jlm::rvsdg::remove(static_cast(old_origin)->node()); + } + + // auto ctrl_const = jlm::rvsdg::control_constant(this, mux_ctl.second, 0); + // mux_ctl.first->divert_to(ctrl_const); + } +}; + +void +fsm_node_temp::print_states() const +{ + std::unordered_map region_to_state; + for (size_t i = 0; i < states_.size(); i++) + { + std::cout << "S" << i << ": "; + auto state = states_[i]; + + for (size_t result_id = 0; result_id < state->nresults(); result_id++) + { + auto result = state->result(result_id); + auto node_output = dynamic_cast(*result->output()->begin()); + if (!node_output) + continue; + + auto node = node_output->node(); + + if (dynamic_cast(&node->operation())) + { + auto ctl = static_cast(result->origin())->node(); + if (!jlm::rvsdg::is_ctlconstant_op(ctl->operation())) + continue; + + // If the register store is not on skip it + auto clt_val = + 
static_cast(ctl->operation()).value().alternative(); + if (clt_val == 0) + continue; + + std::cout << "R" << result_id << " "; + } + } + } +}; + +// FIXME: This function is not implemented +fsm_node_temp * +fsm_node_temp::copy(jlm::rvsdg::region * region, jlm::rvsdg::substitution_map & smap) const +{ + JLM_UNREACHABLE("SHLS: fsm_node_temp::copy() is not implemented"); + return nullptr; + // auto ln = new fsm_node_temp(region); + // return ln; +}; + +fsm_node_temp::~fsm_node_temp() +{ + std::cout << "*** Deleting fsm_node_temp ***" << std::endl; + for (auto state : states_) + { + delete state; + } +}; + +fsm_node_builder::~fsm_node_builder() +{ + std::cout << "*** Deleting fsm_node_builder ***" << std::endl; + // // states_.clear(); + // for (auto state : states_) { + // delete state; + // } + // delete fsm_node_temp_; +}; + +fsm_node_builder * +fsm_node_builder::create(jlm::rvsdg::region * parent) +{ + // FSM should be create in the control region of a loop node + JLM_ASSERT(jlm::rvsdg::is(parent->node()->operation())); + JLM_ASSERT(parent->index() == 0); + + auto fn = new fsm_node_builder(parent); + return fn; +}; + +jlm::rvsdg::structural_output * +fsm_node_temp::add_ctl_output(size_t nalternatives) +{ + auto structural_output = + jlm::rvsdg::structural_output::create(this, jlm::rvsdg::ctltype::Create(nalternatives)); + for (auto state : states_) + { + state->add_ctl_result(nalternatives, structural_output); + } + std::cout << "Added register output to FSM node with results size " + << structural_output->results.size() << ", empty " << structural_output->results.empty() + << std::endl; + return structural_output; +}; + +// jlm::rvsdg::structural_output * +// fsm_node_temp::add_mux_ouput() +// { +// auto structural_output = jlm::rvsdg::structural_output::create(this, jlm::rvsdg::ctltype(1)); +// for (auto state : states_) +// { +// state->add_ctl_result(1, structural_output); +// } +// std::cout << "Added mux output to FSM node with results size " << +// 
structural_output->results.size() << ", empty " << structural_output->results.empty() << +// std::endl; return structural_output; +// }; + +/*! \brief Appends a new state and gives it one control result per existing fsm output. + * \return The newly created state. + */ +fsm_state * +fsm_node_temp::add_state() +{ + auto state = fsm_state::create(this, states_.size()); + + for (size_t i = 0; i < noutputs(); i++) + { + // Outputs made by add_ctl_output(n) are not always 2-way, so the arity is read from the + // output's control type; 2 is kept as the fallback for anything that is not a ctltype. + auto ctl_type = dynamic_cast<const jlm::rvsdg::ctltype *>(&output(i)->type()); + state->add_ctl_result(ctl_type ? ctl_type->nalternatives() : 2, output(i)); + } + + states_.push_back(state); + return state; +}; + +/*! \brief Builds the gamma node from the recorded states, redirects the users of the temporary + * fsm node's outputs to it and deletes the temporary node. + * \param predicate The predicate passed to gamma_node::create(). + */ +void +fsm_node_builder::generate_gamma(jlm::rvsdg::output * predicate) +{ + gamma_ = jlm::rvsdg::gamma_node::create(predicate, fsm_node_temp_->states_.size()); + + for (size_t i = 0; i < fsm_node_temp_->noutputs(); i++) + { + std::vector results_origins; + for (auto & result : fsm_node_temp_->output(i)->results) + { + auto region = gamma_->subregion(result.region()->index()); + auto new_ctrl_const = + static_cast(result.origin())->node()->copy(region, {}); + results_origins.push_back(new_ctrl_const->output(0)); + } + gamma_->add_exitvar(results_origins); + } + + for (size_t i = 0; i < fsm_node_temp_->noutputs(); i++) + { + fsm_node_temp_->output(i)->divert_users(gamma_->output(i)); + } + + delete fsm_node_temp_; + // Do not keep a dangling pointer to the node that was just destroyed. + fsm_node_temp_ = nullptr; +}; + +} // namespace jlm::static_hls diff --git a/jlm/hls/ir/static/fsm.hpp b/jlm/hls/ir/static/fsm.hpp new file mode 100644 index 000000000..53c501b30 --- /dev/null +++ b/jlm/hls/ir/static/fsm.hpp @@ -0,0 +1,219 @@ +/* + * Copyright 2024 Louis Maurin + * See COPYING for terms of redistribution. 
+ */ + +#ifndef JLM_HLS_IR_STATIC_FSM_HPP +#define JLM_HLS_IR_STATIC_FSM_HPP + +#include +#include +#include + +namespace jlm::static_hls +{ +class fsm_node_temp; +class fsm_node_builder; + +class fsm_state final : private jlm::rvsdg::region +{ + friend fsm_node_temp; + +private: + inline fsm_state(jlm::rvsdg::structural_node * node, size_t index) + : region(node, index) /*, index_(index)*/ + {} + +public: + using region::copy; + + static fsm_state * + create(fsm_node_temp * parent_fsm_node, size_t index); + + void + enable_reg(jlm::rvsdg::node * node); + + // TODO doc + void + add_ctl_result(size_t nalternatives, jlm::rvsdg::structural_output * structural_output); + + void + set_mux_ctl(jlm::rvsdg::input * result, size_t alternatives); + + void + apply_mux_ctl(); + +private: + // size_t index_; + std::unordered_map muxes_ctl_; +}; + +/*! \class fsm_op + * \brief Operation for the Finite State Machine. + * See fsm_node for more details. + */ +class fsm_op final : public jlm::rvsdg::structural_op +{ +public: + virtual ~fsm_op() noexcept + {} + + std::string + debug_string() const override + { + return "SHLS_FSM"; + } + + std::unique_ptr + copy() const override + { + return std::unique_ptr(new fsm_op(*this)); + } +}; + +class fsm_node_temp final : public jlm::rvsdg::structural_node +{ + friend fsm_node_builder; + +public: + ~fsm_node_temp(); + +private: + inline fsm_node_temp(jlm::rvsdg::region * parent) + : structural_node(fsm_op(), parent, 1) + {} + +public: + /*! \brief Copies the loop node. + * \param region The parent region of the new loop node. + * \param smap The substitution map for nodes in the loop subregions. + * \return The newly created loop node. + */ + virtual fsm_node_temp * + copy(jlm::rvsdg::region * region, jlm::rvsdg::substitution_map & smap) const override; + + void + print_states() const; + + // TODO finish this doc + /*! \brief Connects a register store signal input to the fsm. 
+ */ + jlm::rvsdg::structural_output * + add_register_ouput() + { + return add_ctl_output(2); + }; + + jlm::rvsdg::structural_output * + add_ctl_output(size_t nalternatives); + + // //TODO finish this doc + // /*! \brief Connects a register store signal input to the fsm. + // */ + // jlm::rvsdg::structural_output * + // add_mux_ouput(); + + /*! \brief Adds a new state to the fsm. + * Creates a fsm_state instance that is added to the states_ vector + */ + fsm_state * + add_state(); + +private: + std::vector states_; +}; + +/*! \class fsm_node + \brief Finite State Machine node for HLS. +*/ +class fsm_node_builder final +{ +public: + ~fsm_node_builder(); + +private: + inline fsm_node_builder(jlm::rvsdg::region * parent) + { + fsm_node_temp_ = new fsm_node_temp(parent); + }; + +public: + /*! \brief Creates a new fsm node. + * Simply calls the constructor. + * \param parent The parent region of the fsm. + * \return The newly created fsm node. + */ + static fsm_node_builder * + create(jlm::rvsdg::region * parent); + + // /*! \brief Copies the loop node. + // * \param region The parent region of the new loop node. + // * \param smap The substitution map for nodes in the loop subregions. + // * \return The newly created loop node. + // */ + // virtual fsm_node * + // copy(jlm::rvsdg::region * region, jlm::rvsdg::substitution_map & smap) const override; + + // TODO finish this doc + /*! \brief Connects a register store signal input to the fsm. + */ + jlm::rvsdg::structural_output * + add_register_ouput() + { + return fsm_node_temp_->add_register_ouput(); + }; + + // TODO finish this doc + /*! \brief Connects a register store signal input to the fsm. + */ + jlm::rvsdg::structural_output * + add_ctl_output(size_t nalternatives) + { + return fsm_node_temp_->add_ctl_output(nalternatives); + }; + + // //TODO finish this doc + // /*! \brief Connects a register store signal input to the fsm. 
+ // */ + // jlm::rvsdg::structural_output * + // add_mux_ouput() { + // return fsm_node_temp_->add_mux_ouput(); + // }; + + /*! \brief Adds a new state to the fsm. + * Creates a fsm_state instance that is added to the states_ vector + */ + fsm_state * + add_state() + { + return fsm_node_temp_->add_state(); + }; + + void + apply_mux_ctl() + { + for (auto state : fsm_node_temp_->states_) + state->apply_mux_ctl(); + } + + /*! \brief Generates the gamma node (which is the final implementation) for the fsm. + */ + void + generate_gamma(jlm::rvsdg::output * predicate); + + inline size_t + nalternatives() const + { + if (!gamma_) + return fsm_node_temp_->states_.size(); + return gamma_->nsubregions(); + } + +private: + fsm_node_temp * fsm_node_temp_ = nullptr; + jlm::rvsdg::gamma_node * gamma_ = nullptr; +}; + +} // namespace jlm::static_hls + +#endif // JLM_HLS_IR_STATIC_FSM_HPP diff --git a/jlm/hls/ir/static/loop.cpp b/jlm/hls/ir/static/loop.cpp new file mode 100644 index 000000000..9683c2587 --- /dev/null +++ b/jlm/hls/ir/static/loop.cpp @@ -0,0 +1,358 @@ +/* + * Copyright 2024 Louis Maurin + * See COPYING for terms of redistribution. 
+ */ + +#include + +namespace jlm::static_hls +{ +size_t instances_count = 0; + +std::vector +loop_node::get_users(jlm::rvsdg::node_input * input) const +{ + auto node = get_mux(input); + if (!node) + { + std::cout << "input is a loop input"; + return {}; + } + + std::vector users; + + for (size_t i = 0; i < node->ninputs(); i++) + { + users.push_back(node->input(i)->origin()); + } + return users; +}; + +jlm::rvsdg::node * +loop_node::get_mux(jlm::rvsdg::node_input * node) const +{ + auto control_result = get_origin_result(node); + if (!control_result) + { + return nullptr; + } + + auto node_output = dynamic_cast(control_result->origin()); + if (!node_output) + JLM_UNREACHABLE("SHLS: call of function loop_node::get_mux() with invalid argument"); + JLM_ASSERT(jlm::rvsdg::is(node_output->node()->operation())); + return node_output->node(); +}; + +loop_node * +loop_node::create(jlm::rvsdg::region * parent) +{ + auto ln = new loop_node(parent); + ln->fsm_ = fsm_node_builder::create(ln->control_subregion()); + return ln; +}; + +jlm::rvsdg::structural_output * +loop_node::add_loopvar(jlm::rvsdg::theta_input * theta_input) +{ + auto input = jlm::rvsdg::structural_input::create( + this, + theta_input->origin(), + theta_input->origin()->Type()); + auto output = jlm::rvsdg::structural_output::create(this, theta_input->origin()->Type()); + + auto argument_in = + jlm::rvsdg::argument::create(control_subregion(), input, theta_input->origin()->Type()); + + reg_smap_.insert(theta_input->argument(), argument_in); + + return output; +}; + +// FIXME: This function is not implemented +loop_node * +loop_node::copy(jlm::rvsdg::region * region, jlm::rvsdg::substitution_map & smap) const +{ + JLM_UNREACHABLE("SHLS: loop_node::copy() is not implemented"); + return nullptr; + // auto ln = new loop_node(region); + // return ln; +}; + +void +loop_node::add_node(jlm::rvsdg::node * node) +{ + std::cout << "Adding node: " << node->operation().debug_string() << std::endl; + + jlm::rvsdg::node * 
compute_node; + + auto fsm_state = fsm_->add_state(); + + if (auto implemented_node = is_op_implemented(node->operation())) + { + compute_node = implemented_node; + std::cout << "Node operation " << node->operation().debug_string() << " already implemented" + << std::endl; + + //! *** Add the inputs connection to the muxes of the control region **** + for (size_t i = 0; i < node->ninputs(); i++) + { + // Get the get mux connected to the already implemented node + auto mux = get_mux(compute_node->input(i)); + if (!mux) + JLM_UNREACHABLE("SHLS loop_node::add_node() : mux not found"); + + auto input_new_origin = reg_smap_.lookup(node->input(i)->origin()); + if (!input_new_origin) + JLM_UNREACHABLE("SHLS loop_node::add_node() : node input origin not found in reg_smap_"); + + mux = mux_add_input(mux, input_new_origin); + + auto mux_op = dynamic_cast(&mux->operation()); + if (!mux_op) + JLM_UNREACHABLE("Mux operation not found"); + + fsm_state->set_mux_ctl(*(mux->output(0)->begin()), mux_op->nalternatives() - 1); + } + + // If the node is not implemented yet + } + else + { + //! *** Create args in compute region and result and mux in control region for each node input + //! **** + std::vector inputs_args; + for (size_t i = 0; i < node->ninputs(); i++) + { + //! Create arg in compute region + auto input_arg = backedge_argument::create(compute_subregion(), node->input(i)->Type()); + inputs_args.push_back(input_arg); + + auto input_new_origin = reg_smap_.lookup(node->input(i)->origin()); + if (!input_new_origin) + JLM_UNREACHABLE(util::strfmt( + "SHLS loop_node::add_node() : node input origin not found in reg_smap_ " + "for node with op ", + node->operation().debug_string()) + .c_str()); + + //! This will create a mux without a predicate which is added afterwards + auto mux = jlm::static_hls::mux_op::create({ input_new_origin }); + + //! 
Create corresponding result in control region + auto res = backedge_result::create(mux->output(0)); + res->argument_ = input_arg; + input_arg->result_ = res; + + auto mux_op = dynamic_cast(&mux->operation()); + if (!mux_op) + JLM_UNREACHABLE("Mux operation not found"); + + fsm_state->set_mux_ctl(res, mux_op->nalternatives() - 1); + } + + // Copy the node into the compute subregion of the loop + compute_node = node->copy(compute_subregion(), inputs_args); + } + + //! Create a backedge and register for each of the node outputs + for (size_t i = 0; i < node->noutputs(); i++) + { + auto backedge_result = add_backedge(compute_node->output(i)); + auto reg_store_origin = fsm_->add_register_ouput(); + + auto reg = reg_op::create( + *reg_store_origin, + *backedge_result->argument(), + jlm::util::strfmt(compute_node->operation().debug_string(), ":", i)); + fsm_state->enable_reg(reg); + + reg_smap_.insert(node->output(i), reg->output(0)); + } +}; + +void +loop_node::add_loopback_arg(jlm::rvsdg::theta_input * theta_input) +{ + auto new_arg_out = reg_smap_.lookup(theta_input->argument()); + if (!new_arg_out) + JLM_UNREACHABLE("SHLS: loop_node::add_loopback_arg() cannot find the argument in reg_smap_"); + + auto new_arg = dynamic_cast(new_arg_out); + if (!new_arg) + JLM_UNREACHABLE("SHLS: loop_node::add_loopback_arg() cannot cast to arg"); + + auto fsm_state = fsm_->add_state(); + + std::cout << "new_arg nusers: " << new_arg->nusers() << std::endl; + + //! 
Need to iterate through the old users because the users are modified during the loop + std::vector old_users(new_arg->begin(), new_arg->end()); + + for (auto user : old_users) + { + auto mux_in = dynamic_cast(user); + if (!mux_in) + JLM_UNREACHABLE("SHLS: loop_node::add_loopback_arg() cannot cast to node_input"); + + auto mux = mux_in->node(); + + auto mux_op = dynamic_cast(&mux->operation()); + if (!mux_op) + JLM_UNREACHABLE("SHLS: loop_node::add_loopback_arg() mux operation not found"); + + auto new_result_origin = reg_smap_.lookup(theta_input->result()->origin()); + // Checked like the other reg_smap_ lookups so that no null origin reaches mux_add_input(). + if (!new_result_origin) + JLM_UNREACHABLE("SHLS: loop_node::add_loopback_arg() result origin not found in reg_smap_"); + + mux = mux_add_input(mux, new_result_origin); + + // mux_add_input() may have removed the old node together with its operation, so mux_op can + // be stale here. The selector is taken from the node that was returned, as add_node() does. + auto new_mux_op = dynamic_cast<const jlm::static_hls::mux_op *>(&mux->operation()); + if (!new_mux_op) + JLM_UNREACHABLE("SHLS: loop_node::add_loopback_arg() mux operation not found"); + + fsm_state->set_mux_ctl(*(mux->output(0)->begin()), new_mux_op->nalternatives() - 1); + }; +}; + +// TODO optimize this function by using a set of operations +jlm::rvsdg::node * +loop_node::is_op_implemented(const jlm::rvsdg::operation & op) const noexcept +{ + for (size_t ind_res = 0; ind_res < compute_subregion()->nresults(); ind_res++) + { + if (static_cast(compute_subregion()->result(ind_res)->origin()) + ->node() + ->operation() + == op) + { + return static_cast(compute_subregion()->result(ind_res)->origin()) + ->node(); + } + } + return nullptr; +}; + +backedge_result * +loop_node::add_backedge(jlm::rvsdg::output * origin) +{ + auto result_loop = backedge_result::create(origin); + auto argument_loop = backedge_argument::create(control_subregion(), origin->Type()); + argument_loop->result_ = result_loop; + result_loop->argument_ = argument_loop; + return result_loop; +}; + +jlm::rvsdg::result * +loop_node::get_origin_result(jlm::rvsdg::node_input * input) const +{ + auto arg = dynamic_cast(input->origin()); + if (!arg) + return nullptr; // This is a loop_var + return arg->result(); +}; + +void +loop_node::print_nodes_registers() const +{ + std::cout << "**** Printing nodes and registers ****" << std::endl; + for (size_t i = 0; i < compute_subregion()->nresults(); i++) + { + auto node_ouput = + static_cast(compute_subregion()->result(i)->origin()); + + 
auto node = node_ouput->node(); + std::cout << "node " << node->operation().debug_string() << " | ouput " << node_ouput->index() + << " = "; + + auto backedge_result = + dynamic_cast(compute_subregion()->result(i)); + if (!backedge_result) + { + JLM_UNREACHABLE("SHLS: loop_node::print_nodes_registers() cannot cast to backedge_result"); + } + auto arg_user = backedge_result->argument()->begin(); + + auto node_in = dynamic_cast(*arg_user); + if (!node_in) + { + std::cout << "arg_user is not a node_input" << std::endl; + continue; + } + std::cout << node_in->node()->operation().debug_string() << std::endl; + + // auto reg_out = reg_smap_.lookup(node.output(i)); + // if (!reg_out) + // { + // std::cout << "node ouput " << i << " not in reg_smap_ " << std::endl; + // continue; + // } + // auto node_out = dynamic_cast(reg_out); + // if (!node_out) + // { + // std::cout << "node ouput " << i << " substitute is not a node_ouput" << std::endl; + // continue; + // } + // std::cout << "node ouput " << i << " : " << node_out->node()->operation().debug_string() << + // std::endl; + } +}; + +void +loop_node::remove_single_input_muxes() +{ + for (size_t i = 0; i < control_subregion()->nresults(); i++) + { + auto node_output = + dynamic_cast(control_subregion()->result(i)->origin()); + if (!node_output) + JLM_UNREACHABLE("SHLS: loop_node::remove_single_input_muxes() cannot cast to node_output"); + + auto mux = node_output->node(); + JLM_ASSERT(jlm::rvsdg::is(mux->operation())); + + auto mux_op = dynamic_cast(&mux->operation()); + if (!mux_op) + JLM_UNREACHABLE("Mux operation not found"); + + if (mux_op->nalternatives() == 1) + { + size_t input_id = 0; + if (mux_op->has_predicate()) + input_id = 1; + mux->output(0)->divert_users(mux->input(input_id)->origin()); + remove(mux); + } + } +}; + +void +loop_node::connect_muxes() +{ + for (size_t i = 0; i < control_subregion()->nresults(); i++) + { + auto node_output = + dynamic_cast(control_subregion()->result(i)->origin()); + if 
(!node_output) + JLM_UNREACHABLE("SHLS: loop_node::connect_muxes() cannot cast to node_output"); + + auto mux = node_output->node(); + JLM_ASSERT(jlm::rvsdg::is(mux->operation())); + + auto mux_op = dynamic_cast(&mux->operation()); + if (!mux_op) + JLM_UNREACHABLE("Mux operation not found"); + + auto mux_ctl_origin = fsm_->add_ctl_output(mux_op->nalternatives()); + mux_connect_predicate(mux, mux_ctl_origin); + } +}; + +void +loop_node::finalize() +{ + // FIXME this is a temporary solution with an argument + auto arg = jlm::rvsdg::argument::create( + control_subregion(), + nullptr, + jlm::rvsdg::ctltype::Create(fsm_->nalternatives())); + connect_muxes(); + fsm_->apply_mux_ctl(); + fsm_->generate_gamma(arg); + remove_single_input_muxes(); +}; + +} // namespace jlm::static_hls diff --git a/jlm/hls/ir/static/loop.hpp b/jlm/hls/ir/static/loop.hpp new file mode 100644 index 000000000..ec848fa14 --- /dev/null +++ b/jlm/hls/ir/static/loop.hpp @@ -0,0 +1,262 @@ +/* + * Copyright 2024 Louis Maurin + * See COPYING for terms of redistribution. + */ + +#ifndef JLM_HLS_IR_STATIC_LOOP_HPP +#define JLM_HLS_IR_STATIC_LOOP_HPP + +// FIXME check what's needed +#include +#include +#include +#include // needed +#include +#include + +#include +#include + +namespace jlm::static_hls +{ + +/*! \class loop_op + * \brief Loop operation for static loop node. + * See loop_node for more details. + */ +class loop_op final : public jlm::rvsdg::structural_op +{ +public: + virtual ~loop_op() noexcept + {} + + std::string + debug_string() const override + { + return "SHLS_LOOP"; + } + + std::unique_ptr + copy() const override + { + return std::unique_ptr(new loop_op(*this)); + } +}; + +class backedge_argument; +class backedge_result; + +/*! \class loop_node + \brief Static loop node for HLS. + + Theta node are lowered into this node. 
+    This node has 2 regions :
+    - control region (subregion 0) : contains the control flow of the loop, registers, muxes
+      and FSM
+    - compute region (subregion 1) : contains the operations that can be computed in parallel
+*/
+class loop_node final : public jlm::rvsdg::structural_node
+{
+  // public:
+  // ~loop_node();
+
+private:
+  // Creates the structural node with a loop_op and two subregions.
+  // NOTE(review): fsm_ is not initialized here - presumably create() assigns it; confirm.
+  inline loop_node(jlm::rvsdg::region * parent)
+      : structural_node(loop_op(), parent, 2)
+  {}
+
+public:
+  /*! \brief Debug function that prints, for each result of the compute region, the node output
+   * producing it and the operation of the first user of the paired argument.
+   */
+  void
+  print_nodes_registers() const;
+
+  // TODO doc
+  void
+  print_fsm() const;
+
+  /*! \brief Returns a pointer to the loop's substitution map (reg_smap_).
+   * NOTE(review): the `®` below looks like a mangled `&reg`, i.e. presumably `&reg_smap_`.
+   */
+  inline const jlm::rvsdg::substitution_map *
+  get_reg_smap() const
+  {
+    return ®_smap_;
+  }
+
+  /*! \brief Get the connected control region result from a node input in the compute region.
+   * \param input The node input in the compute region.
+   * \return The connected control region result, or nullptr when the origin of the input is not
+   * an argument that is paired with a result.
+   */
+  jlm::rvsdg::result *
+  get_origin_result(jlm::rvsdg::node_input * input) const;
+
+  /*! \brief Get a list of register nodes connected to a node input in the compute region.
+   * Also prints whether the input is a loop input.
+   * \param node The node input in the compute region.
+   * \return The list of register nodes connected to the node input.
+   */
+  std::vector
+  get_users(jlm::rvsdg::node_input * node) const;
+
+  /*! \brief Get the mux connected to a node input in the compute region.
+   */
+  jlm::rvsdg::node *
+  get_mux(jlm::rvsdg::node_input * node) const;
+
+  /*! \brief Creates a new loop node.
+   * Simply calls the constructor.
+   * \param parent The parent region of the loop.
+   * \return The newly created loop node.
+   */
+  static loop_node *
+  create(jlm::rvsdg::region * parent);
+
+  // FIXME: this doc
+  /*! \brief Adds a loop input to the loop.
+   * \param theta_input The theta input the loop input is created for.
+   * \return The structural output created for the loop variable (TODO confirm).
+   */
+  jlm::rvsdg::structural_output *
+  add_loopvar(jlm::rvsdg::theta_input * theta_input);
+
+  /*! \brief Returns the control region of the loop (subregion 0).
+   */
+  inline jlm::rvsdg::region *
+  control_subregion() const noexcept
+  {
+    return structural_node::subregion(0);
+  }
+
+  /*! \brief Returns the compute region of the loop (subregion 1).
+   */
+  inline jlm::rvsdg::region *
+  compute_subregion() const noexcept
+  {
+    return structural_node::subregion(1);
+  }
+
+  /*! \brief Copies the loop node.
+   * \param region The parent region of the new loop node.
+   * \param smap The substitution map for nodes in the loop subregions.
+   * \return The newly created loop node.
+   */
+  virtual loop_node *
+  copy(jlm::rvsdg::region * region, jlm::rvsdg::substitution_map & smap) const override;
+
+  /*! \brief Adds a node to the loop.
+   * Used when building the loop_node from a theta node.
+   * This method adds a node to the loop_node either by adding an operation to the compute region or
+   * by using an already implemented one. Also adds a register in the control region for each of its
+   * outputs. This will also add inputs to the muxes in the control region to route the node.
+   * \param node The node to add.
+   */
+  void
+  add_node(jlm::rvsdg::node * node);
+
+  /*! \brief Wires the loop-back value of a loop variable (TODO complete this doc).
+   * Presumably, for every mux fed by the loop variable's argument, this adds the register
+   * substitute of theta_input->result()->origin() as an extra mux input and records the matching
+   * mux control value in the FSM state - confirm against the full implementation.
+   * \param theta_input The theta input whose result origin is looped back.
+   */
+  void
+  add_loopback_arg(jlm::rvsdg::theta_input * theta_input);
+
+  /*! \brief Determines if an operation is already added in the compute region.
+   * This goes through every result of the compute region and compares the operation of the node
+   * producing it with the given operation.
+   * \param op The operation to check.
+   * \return The node that implements the operation if it is already implemented, nullptr
+   * otherwise.
+   */
+  jlm::rvsdg::node *
+  is_op_implemented(const jlm::rvsdg::operation & op) const noexcept;
+
+  /*! \brief Adds a backedge to the loop.
+   * Creates a backedge_result in the region of the origin and a paired backedge_argument of the
+   * same type in the control region.
+   * \param origin The origin of the backedge.
+   * \return The newly created backedge.
+   */
+  backedge_result *
+  add_backedge(jlm::rvsdg::output * origin);
+
+  /*! \brief Replace muxes with only one input by the input node.
+   * This is called in the finalize method.
+   */
+  void
+  remove_single_input_muxes();
+
+  /*! \brief Connects every mux that produces a control region result to a control output of the
+   * FSM, requested with the mux's number of alternatives. This is called in the finalize method.
+   */
+  void
+  connect_muxes();
+
+  /*! \brief Finishes the loop building and creates the FSM gamma node
+   * This calls connect_muxes first and remove_single_input_muxes last.
+   */
+  void
+  finalize();
+
+private:
+  // NOTE(review): raw pointer that the constructor above leaves uninitialized; ownership is not
+  // visible from this class - confirm who creates and frees it.
+  jlm::static_hls::fsm_node_builder * fsm_; // The FSM node of the loop
+  // TODO rename
+  jlm::rvsdg::substitution_map reg_smap_; // Maps the original node output with its substitute
+                                          // register output, also maps arguments
+};
+
+// Argument that is paired with a backedge_result; the pairing is set by loop_node::add_backedge.
+// TODO rename those as they are not only used for backedges
+class backedge_argument : public jlm::rvsdg::argument
+{
+  friend loop_node;
+  friend backedge_result;
+
+public:
+  ~backedge_argument() override = default;
+
+  // Returns the paired result; nullptr until a pairing has been set.
+  // NOTE(review): not const, unlike backedge_result::argument().
+  backedge_result *
+  result()
+  {
+    return result_;
+  }
+
+private:
+  backedge_argument(
+      jlm::rvsdg::region * region,
+      const std::shared_ptr & type)
+      : jlm::rvsdg::argument(region, nullptr, type),
+        result_(nullptr)
+  {}
+
+  // Allocates the argument and appends it to the region.
+  // NOTE(review): presumably the region takes ownership of the new object - confirm.
+  static backedge_argument *
+  create(jlm::rvsdg::region * region, const std::shared_ptr & type)
+  {
+    auto argument = new backedge_argument(region, type);
+    region->append_argument(argument);
+    return argument;
+  }
+
+  backedge_result * result_;
+};
+
+// Result that is paired with a backedge_argument; the pairing is set by loop_node::add_backedge.
+class backedge_result : public jlm::rvsdg::result
+{
+  friend loop_node;
+  friend backedge_argument;
+
+public:
+  ~backedge_result() override = default;
+
+  // Returns the paired argument; nullptr until a pairing has been set.
+  backedge_argument *
+  argument() const
+  {
+    return argument_;
+  }
+
+private:
+  // The result is created in the region of its origin and uses the origin's port.
+  backedge_result(jlm::rvsdg::output * origin)
+      : jlm::rvsdg::result(origin->region(), origin, nullptr, origin->port()),
+        argument_(nullptr)
+  {}
+
+  // Allocates the result and appends it to the region of the origin.
+  // NOTE(review): presumably the region takes ownership of the new object - confirm.
+  static backedge_result *
+  create(jlm::rvsdg::output * origin)
+  {
+    auto result = new backedge_result(origin);
+    origin->region()->append_result(result);
+    return result;
+  }
+
+  backedge_argument * argument_;
+};
+
+} // namespace jlm::static_hls
+
+#endif // JLM_HLS_IR_STATIC_LOOP_HPP
diff --git a/tests/jlm/hls/backend/rvsdg2rhls/static/TestTheta.cpp b/tests/jlm/hls/backend/rvsdg2rhls/static/TestTheta.cpp
new file mode 100644
index
000000000..7cd4f5b9f
--- /dev/null
+++ b/tests/jlm/hls/backend/rvsdg2rhls/static/TestTheta.cpp
@@ -0,0 +1,123 @@
+/*
+ * Copyright 2024 Louis Maurin
+ * See COPYING for terms of redistribution.
+ */
+#include "test-registry.hpp"
+
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+#include
+
+// Builds the module under test: one exported function "f" taking and returning a 32-bit value,
+// whose body is a single theta node. Inside the theta the loop variable is incremented by the
+// constant 1, and the loop predicate is a match on the comparison of that sum against the
+// constant 5.
+std::unique_ptr
+CreateTestModule()
+{
+  using namespace jlm::llvm;
+
+  auto module = jlm::llvm::RvsdgModule::Create(jlm::util::filepath(""), "", "");
+  auto graph = &module->Rvsdg();
+
+  auto nf = graph->node_normal_form(typeid(jlm::rvsdg::operation));
+  nf->set_mutable(false);
+
+  // NOTE(review): mt is never used in this function.
+  MemoryStateType mt;
+  auto pointerType = PointerType::Create();
+  auto fcttype = FunctionType::Create(
+      { jlm::rvsdg::bittype::Create(32) },
+      // { jlm::rvsdg::bittype::Create(32), jlm::rvsdg::bittype::Create(32) },
+      { jlm::rvsdg::bittype::Create(32) });
+  auto fct = lambda::node::create(graph->root(), fcttype, "f", linkage::external_linkage);
+
+  auto thetanode = jlm::rvsdg::theta_node::create(fct->subregion());
+
+  auto sum_loop_var = thetanode->add_loopvar(fct->fctargument(0));
+  // auto loop_var2 = thetanode->add_loopvar(fct->fctargument(1));
+
+  auto one = jlm::rvsdg::create_bitconstant(thetanode->subregion(), 32, 1);
+  auto five = jlm::rvsdg::create_bitconstant(thetanode->subregion(), 32, 5);
+  auto sum = jlm::rvsdg::bitadd_op::create(32, sum_loop_var->argument(), one);
+  auto cmp = jlm::rvsdg::bitult_op::create(32, sum, five);
+  auto predicate = jlm::rvsdg::match(1, { { 1, 1 } }, 0, 2, cmp);
+
+  // auto sum2 = jlm::rvsdg::bitadd_op::create(32, sum_loop_var->argument(), loop_var2->argument());
+
+  // change the loop_var result origin to the output of sum
+  // (by default the loop_var result origin is connected to the loop_var argument (to itself))
+  sum_loop_var->result()->divert_to(sum);
+
+  // loop_var2->result()->divert_to(sum2);
+
+  thetanode->set_predicate(predicate);
+
+  fct->finalize({ sum_loop_var });
+
+  // Export the function from the graph under the name "f"
+  graph->add_export(fct->output(), jlm::rvsdg::expport(pointerType, "f"));
+
+  return module;
+}
+
+// Runs jlm::static_hls::rvsdg2rhls on the test module and prints the graph before and after.
+// It then builds a second, unconverted copy of the module and, for every node of the original
+// theta subregion, asserts that the converted loop node implements its operation. For each input
+// of the implementing node that passes the dynamic_cast filter below, it also asserts that the
+// reg_smap substitute of the input's origin is among the users returned by get_users().
+static int
+TestTheta()
+{
+  auto rvsdgModule = CreateTestModule();
+
+  std::cout << "********** Original graph **********" << std::endl;
+  jlm::rvsdg::view(*rvsdgModule->Rvsdg().root()->graph(), stdout);
+
+  std::cout << "********** Running static rvsdg2rhls **********" << std::endl;
+  jlm::static_hls::rvsdg2rhls(*rvsdgModule);
+
+  std::cout << "********** Converted graph **********" << std::endl;
+  jlm::rvsdg::view(*rvsdgModule->Rvsdg().root()->graph(), stdout);
+
+  // The first node of the root region is taken to be the lambda, and the first node of the
+  // lambda subregion to be the converted loop; both casts are asserted.
+  auto lambda = &*rvsdgModule->Rvsdg().root()->begin();
+  auto lambda_node = jlm::util::AssertedCast(lambda);
+
+  auto loop = &*lambda_node->subregion()->begin();
+  auto loop_node = jlm::util::AssertedCast(loop);
+
+  // Fresh, unconverted copy used as the reference for the operations the loop must implement.
+  auto orig_module = CreateTestModule();
+
+  // NOTE(review): these two casts are unchecked static_casts, unlike the AssertedCasts above.
+  auto orig_lambda = &*orig_module->Rvsdg().root()->begin();
+  auto orig_lambda_node = static_cast(orig_lambda);
+
+  auto orig_theta = &*orig_lambda_node->subregion()->begin();
+  auto orig_theta_node = static_cast(orig_theta);
+
+  for (auto & node : jlm::rvsdg::topdown_traverser(orig_theta_node->subregion()))
+  {
+    auto imp_node = loop_node->is_op_implemented(node->operation());
+    JLM_ASSERT(imp_node);
+
+    for (size_t i = 0; i < imp_node->ninputs(); i++)
+    {
+      // Inputs whose origin does not pass this cast are skipped.
+      if (!dynamic_cast(imp_node->input(i)->origin()))
+        continue;
+
+      bool origin_found_in_users = false;
+      for (auto user : loop_node->get_users(imp_node->input(i)))
+      {
+        // NOTE(review): this lookup does not depend on user, so it is repeated for every user;
+        // the message below also lacks a space before "input".
+        auto reg_smap = loop_node->get_reg_smap();
+        auto output_origin = reg_smap->lookup(imp_node->input(i)->origin());
+        if (!output_origin)
+        {
+          std::cout << "output_origin not in reg_smap for " << imp_node->operation().debug_string()
+                    << "input " << i << std::endl;
+        }
+        if (output_origin == user)
+          origin_found_in_users = true;
+      }
+      // Also fails when get_users() returned no user at all.
+      JLM_ASSERT(origin_found_in_users);
+    }
+  }
+
+  return 0;
+}
+
+JLM_UNIT_TEST_REGISTER("jlm/hls/backend/rvsdg2rhls/static/TestTheta", TestTheta)