diff --git a/include/CIRCT2Graph.h b/include/CIRCT2Graph.h index 0c63991..e40f3df 100644 --- a/include/CIRCT2Graph.h +++ b/include/CIRCT2Graph.h @@ -34,8 +34,11 @@ class CIRCT2Graph { hw::HWModuleOp topModule; llvm::DenseMap valueNodeMap; - void processInputPort(); + void processIoPort(); + void processInputPort(hw::PortInfo port); + void processOutputPort(hw::PortInfo port); void processOperations(); + void processOutputOp(hw::OutputOp op); void processConstantOp(hw::ConstantOp constantOp); }; diff --git a/src/CIRCT2Graph.cpp b/src/CIRCT2Graph.cpp index d6d2fd0..87a195b 100644 --- a/src/CIRCT2Graph.cpp +++ b/src/CIRCT2Graph.cpp @@ -6,34 +6,68 @@ #include "llvm/Support/Casting.h" #include "mlir/IR/Block.h" +using llvm::dyn_cast; + graph* CIRCT2Graph::generateGraph() { g = new graph(); - processInputPort(); + processIoPort(); processOperations(); return g; } -void CIRCT2Graph::processInputPort() { - size_t inputPortIdx = 0; - for(size_t i = 0; i < topModule.getNumPorts(); i++) { - auto portInfo = topModule.getPort(i); - if(portInfo.isOutput()) continue; +/// For each Input and Output port of TopModule. Generate NODE_IND and NODE_OUT +void CIRCT2Graph::processIoPort() { + auto portList = topModule.getPortList(); + for (auto& p : portList) { + //std::cout<< p.getName().str() << " "<< std::endl; + if (p.isInput()) { + processInputPort(p); + }else if (p.isOutput()) { + processOutputPort(p); + }else{ //InOut is not supproted! + Panic(); + } + } +} + +void CIRCT2Graph::processInputPort(hw::PortInfo portInfo) { + Assert(portInfo.isInput(),""); // 创建 typeInfo - TypeInfo* typeInfo = new TypeInfo(); - typeInfo->set_sign(portInfo.type.isSignedInteger()); - typeInfo->set_width(portInfo.type.getIntOrFloatBitWidth()); - typeInfo->set_reset(UNCERTAIN); + auto typeInfo = std::make_unique(); + if (isa(portInfo.type)) { + typeInfo->set_sign(false); + typeInfo->set_width(1); + typeInfo->set_clock(true); + typeInfo->set_reset(UNCERTAIN); + } else { + typeInfo->set_sign(portInfo.type.isSignedInteger()); + typeInfo->set_width(portInfo.type.getIntOrFloatBitWidth()); + typeInfo->set_reset(UNCERTAIN); + } // 创建 node Node* node = new Node(NODE_INP); - node->name = topModule.getPortName(i).str(); - node->updateInfo(typeInfo); + node->name = portInfo.getName(); + node->updateInfo(typeInfo.get()); // 将 node 添加到图中 g->input.push_back(node); // 将 value-node 关系添加到 map 中 - mlir::Value portValue = topModule.getBodyRegion().front().getArgument(inputPortIdx); + mlir::Value portValue = topModule.getBodyRegion().front().getArgument(portInfo.argNum); valueNodeMap[portValue] = node; - inputPortIdx++; - } +} + +void CIRCT2Graph::processOutputPort(hw::PortInfo portInfo) { + Assert(portInfo.isOutput(),""); + + auto typeInfo = std::make_unique(); + typeInfo->set_sign(portInfo.type.isSignedInteger()); + typeInfo->set_width(portInfo.type.getIntOrFloatBitWidth()); + typeInfo->set_reset(UNCERTAIN); + + Node* node = new Node(NODE_OUT); + node->name = portInfo.getName(); + node->updateInfo(typeInfo.get()); + + g->output.push_back(node); } void CIRCT2Graph::processOperations() { @@ -43,14 +77,33 @@ void CIRCT2Graph::processOperations() { for (auto& op : body->getOperations()) { // GRH: add more operation in this branch - if (auto constantOp = llvm::dyn_cast(op)) { + if (auto constantOp = dyn_cast(op)) { processConstantOp(constantOp); - } else { + } else if (auto outputOp = dyn_cast(op)) { + processOutputOp(outputOp); + }else { Assert(false, "Unsupported operation: %s", op.getName().getStringRef().str().c_str()); } } } +/// hw.output sig1,sig2 : i1,i8 +/// Connect signals to the output of a module +/// Order is defined by the hw.module's declaration +void CIRCT2Graph::processOutputOp(hw::OutputOp op) { + auto operands = op.getOperands(); + auto output_node_it = g->output.begin(); // C++ 17 doesn't have zip. + for (auto operand : operands) { + auto* node = *output_node_it; + auto* refNode = new ENode(valueNodeMap[operand]); // The node assigned to output + auto* expTree = new ExpTree(refNode, node); + node->assignTree.push_back(expTree); + ++output_node_it; + } +} + + +/// hw.constant 2 : i3 void CIRCT2Graph::processConstantOp(hw::ConstantOp constantOp) { auto intType = llvm::dyn_cast(constantOp.getType()); Assert(intType, "hw.constant expects integer result type");