diff --git a/CMakeLists.txt b/CMakeLists.txt index c0ec2a7..735b828 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -58,18 +58,18 @@ file(GLOB_RECURSE SOURCES "${onnxinfo_SOURCE_DIR}/src/*.cpp") list(REMOVE_ITEM SOURCES "${onnxinfo_SOURCE_DIR}/src/main.cpp") add_library( - ${PROJECT_NAME} STATIC + _onnxinfo STATIC ${SOURCES} ${onnxinfo_SOURCE_DIR}/third_party/onnx/onnx.proto3.pb.cc ) -add_dependencies(${PROJECT_NAME} protobuf::protoc) -target_link_libraries(${PROJECT_NAME} +add_dependencies(_onnxinfo protobuf::protoc) +target_link_libraries(_onnxinfo protobuf::libprotobuf ) -pybind11_add_module(_onnxinfo src/main.cpp) -target_link_libraries(_onnxinfo PRIVATE ${PROJECT_NAME}) +pybind11_add_module(${PROJECT_NAME} src/main.cpp) +target_link_libraries(${PROJECT_NAME} PRIVATE _onnxinfo) # custom cmake target for test # add_custom_target(test ALL) diff --git a/include/AttrInfo.hpp b/include/AttrInfo.hpp index 2478be9..c938896 100644 --- a/include/AttrInfo.hpp +++ b/include/AttrInfo.hpp @@ -4,43 +4,43 @@ #include "onnx.proto3.pb.h" struct AttrInfo_Conv { - std::vector kernel_shape; // got from Weight - std::vector strides; // default 1 - std::vector pads; // default 0 - std::vector dilations; // default 1 - - int64_t group = 1; // not used now - - void set_default_attr(size_t shape_len) { - for (size_t i = 0; i < shape_len - 2; ++i) { - strides.emplace_back(1); - pads.emplace_back(0); - pads.emplace_back(0); - dilations.emplace_back(1); - } + std::vector kernel_shape; // got from Weight + std::vector strides; // default 1 + std::vector pads; // default 0 + std::vector dilations; // default 1 + + int64_t group = 1; // not used now + + void set_default_attr(size_t shape_len) { + for (size_t i = 0; i < shape_len - 2; ++i) { + strides.emplace_back(1); + pads.emplace_back(0); + pads.emplace_back(0); + dilations.emplace_back(1); } + } }; struct AttrInfo_MaxPool { - std::vector kernel_shape; - std::vector strides; // default 1 - std::vector pads; // default 0 - std::vector dilations; // default 1 - - int64_t ceil_mode = 0; // not used now - int64_t storage_order = 0; // not used now - - void set_default_attr(size_t shape_len) { - for (size_t i = 0; i < shape_len - 2; ++i) { - strides.emplace_back(1); - pads.emplace_back(0); - pads.emplace_back(0); - dilations.emplace_back(1); - } + std::vector kernel_shape; + std::vector strides; // default 1 + std::vector pads; // default 0 + std::vector dilations; // default 1 + + int64_t ceil_mode = 0; // not used now + int64_t storage_order = 0; // not used now + + void set_default_attr(size_t shape_len) { + for (size_t i = 0; i < shape_len - 2; ++i) { + strides.emplace_back(1); + pads.emplace_back(0); + pads.emplace_back(0); + dilations.emplace_back(1); } + } }; struct AttrInfo_Gemm { - bool transA = false; - bool transB = false; + bool transA = false; + bool transB = false; }; diff --git a/include/InferShape.hpp b/include/InferShape.hpp index 5673387..44f70ab 100644 --- a/include/InferShape.hpp +++ b/include/InferShape.hpp @@ -9,32 +9,32 @@ class InferShapeImpl { public: - InferShapeImpl() = default; - InferShapeImpl(const onnx::GraphProto &in_graph) : graph(in_graph) {} + InferShapeImpl() = default; + InferShapeImpl(const onnx::GraphProto &in_graph) : graph(in_graph) {} - ~InferShapeImpl() = default; + ~InferShapeImpl() = default; - void set_io_iniz_shape_to_map(bool analyze); - void infer_shapes(bool analyze = true); - void infer_shapes(); - void print_summary(); + void set_io_iniz_shape_to_map(bool analyze); + void infer_shapes(bool analyze = true); + void infer_shapes(); + void print_summary(); - const std::unordered_map> get_ndname_to_shape() { - return this->ndname_to_shape; - } + const std::unordered_map> get_ndname_to_shape() { + return this->ndname_to_shape; + } private: - onnx::GraphProto graph; - std::unordered_map> ndname_to_shape; - std::unordered_map ndname_to_anal_data; - std::unordered_map ndname_to_dtype_size; - - // TODO: more op types - void infer_shapes_Conv(onnx::NodeProto &node); - void infer_shapes_Relu(onnx::NodeProto &node); - void infer_shapes_MaxPool(onnx::NodeProto &node); - void infer_shapes_Add(onnx::NodeProto &node); - void infer_shapes_GlobalAveragePool(onnx::NodeProto &node); - void infer_shapes_Flatten(onnx::NodeProto &node); - void infer_shapes_Gemm(onnx::NodeProto &node); + onnx::GraphProto graph; + std::unordered_map> ndname_to_shape; + std::unordered_map ndname_to_anal_data; + std::unordered_map ndname_to_dtype_size; + + // TODO: more op types + void infer_shapes_Conv(onnx::NodeProto &node); + void infer_shapes_Relu(onnx::NodeProto &node); + void infer_shapes_MaxPool(onnx::NodeProto &node); + void infer_shapes_Add(onnx::NodeProto &node); + void infer_shapes_GlobalAveragePool(onnx::NodeProto &node); + void infer_shapes_Flatten(onnx::NodeProto &node); + void infer_shapes_Gemm(onnx::NodeProto &node); }; diff --git a/include/ModelAnalysis.hpp b/include/ModelAnalysis.hpp index 32a9d16..6f3f440 100644 --- a/include/ModelAnalysis.hpp +++ b/include/ModelAnalysis.hpp @@ -8,12 +8,19 @@ #define DIV_MACS 4 struct AnalyzeData { - int64_t mac = 0; - int64_t param = 0; - int64_t mem = 0; + int64_t mac = 0; + int64_t param = 0; + int64_t mem = 0; +}; + +struct NodeAnalArgs { + std::vector> input_shapes; + std::vector output_shape; + std::unordered_map ndname_to_size; }; int64_t get_prod(std::vector &vec); +NodeAnalArgs get_anal_args(onnx::NodeProto &node, const std::unordered_map> &ndname_to_shape, const std::unordered_map &ndname_to_dtype_size); AnalyzeData analyze_node_Conv(onnx::NodeProto &node, std::vector> &input_shapes, std::vector &output_shape, std::unordered_map &ndname_to_size); AnalyzeData analyze_node_Relu(onnx::NodeProto &node, std::vector> &input_shapes, std::vector &output_shape, std::unordered_map &ndname_to_size); diff --git a/src/InferShape.cpp b/src/InferShape.cpp index 9f9c7b4..7c65236 100644 --- a/src/InferShape.cpp +++ b/src/InferShape.cpp @@ -4,448 +4,448 @@ #include "utils.hpp" void InferShapeImpl::set_io_iniz_shape_to_map(bool analyze) { - for (auto input : this->graph.input()) { - auto shape = input.type().tensor_type().shape(); - std::vector shape_vec = {}; - for (int i = 0; i < shape.dim_size(); ++i) { - auto dim = shape.dim(i); - if (dim.has_dim_value()) { - shape_vec.emplace_back(dim.dim_value()); - } - } - this->ndname_to_shape[input.name()] = shape_vec; - - // get dtype size - if (analyze) { - if (input.type().tensor_type().elem_type() == 1) { // float32 - this->ndname_to_dtype_size[input.name()] = 4; - } - } - } - - for (auto initializer : this->graph.initializer()) { - auto shape = initializer.dims(); - std::vector shape_vec = {}; - for (int i = 0; i < shape.size(); ++i) { - shape_vec.emplace_back(shape.Get(i)); - } - this->ndname_to_shape[initializer.name()] = shape_vec; - - // get dtype size - if (analyze) { - if (initializer.data_type() == 1) { - this->ndname_to_dtype_size[initializer.name()] = 4; - } - } - } - - for (auto output : this->graph.output()) { - auto shape = output.type().tensor_type().shape(); - std::vector shape_vec = {}; - for (int i = 0; i < shape.dim_size(); ++i) { - auto dim = shape.dim(i); - if (dim.has_dim_value()) { - shape_vec.emplace_back(dim.dim_value()); - } - } - this->ndname_to_shape[output.name()] = shape_vec; - - // get dtype size - if (analyze) { - if (output.type().tensor_type().elem_type() == 1) { - this->ndname_to_dtype_size[output.name()] = 4; - } - } + for (auto input : this->graph.input()) { + auto shape = input.type().tensor_type().shape(); + std::vector shape_vec = {}; + for (int i = 0; i < shape.dim_size(); ++i) { + auto dim = shape.dim(i); + if (dim.has_dim_value()) { + shape_vec.emplace_back(dim.dim_value()); + } } + this->ndname_to_shape[input.name()] = shape_vec; + + // get dtype size + if (analyze) { + if (input.type().tensor_type().elem_type() == 1) { // float32 + this->ndname_to_dtype_size[input.name()] = 4; + } + } + } + + for (auto initializer : this->graph.initializer()) { + auto shape = initializer.dims(); + std::vector shape_vec = {}; + for (int i = 0; i < shape.size(); ++i) { + shape_vec.emplace_back(shape.Get(i)); + } + this->ndname_to_shape[initializer.name()] = shape_vec; + + // get dtype size + if (analyze) { + if (initializer.data_type() == 1) { + this->ndname_to_dtype_size[initializer.name()] = 4; + } + } + } + + for (auto output : this->graph.output()) { + auto shape = output.type().tensor_type().shape(); + std::vector shape_vec = {}; + for (int i = 0; i < shape.dim_size(); ++i) { + auto dim = shape.dim(i); + if (dim.has_dim_value()) { + shape_vec.emplace_back(dim.dim_value()); + } + } + this->ndname_to_shape[output.name()] = shape_vec; + + // get dtype size + if (analyze) { + if (output.type().tensor_type().elem_type() == 1) { + this->ndname_to_dtype_size[output.name()] = 4; + } + } + } } void InferShapeImpl::print_summary() { - std::ios::sync_with_stdio(false); - std::cin.tie(0); + std::ios::sync_with_stdio(false); + std::cin.tie(0); - if (this->ndname_to_anal_data.empty()) { - std::cout << std::left << std::setw(INDENT) << "Name" - << std::left << std::setw(INDENT) << "Type" - << std::left << std::setw(INDENT) << "Input Shape" - << std::left << std::setw(INDENT) << "Output Shape" << '\n'; - std::cout << std::string(INDENT * 4, '-') << '\n'; + if (this->ndname_to_anal_data.empty()) { + std::cout << std::left << std::setw(INDENT) << "Name" + << std::left << std::setw(INDENT) << "Type" + << std::left << std::setw(INDENT) << "Input Shape" + << std::left << std::setw(INDENT) << "Output Shape" << '\n'; + std::cout << std::string(INDENT * 4, '-') << '\n'; - for (auto node : graph.node()) { - std::cout << std::left << std::setw(INDENT) << string_trimmer(node.name(), INDENT-5); + for (auto node : graph.node()) { + std::cout << std::left << std::setw(INDENT) << string_trimmer(node.name(), INDENT-5); - std::cout << std::left << std::setw(INDENT) << node.op_type(); - std::cout << std::setw(INDENT) << dims_vec_to_str(this->ndname_to_shape[node.input(0)]); - std::cout << std::setw(INDENT) << dims_vec_to_str(this->ndname_to_shape[node.output(0)]); - std::cout << '\n'; - } + std::cout << std::left << std::setw(INDENT) << node.op_type(); + std::cout << std::setw(INDENT) << dims_vec_to_str(this->ndname_to_shape[node.input(0)]); + std::cout << std::setw(INDENT) << dims_vec_to_str(this->ndname_to_shape[node.output(0)]); + std::cout << '\n'; } - else { - std::cout << std::left << std::setw(INDENT) << "Name" - << std::left << std::setw(INDENT) << "Type" - << std::left << std::setw(INDENT) << "Input Shape" - << std::left << std::setw(INDENT) << "Output Shape" - << std::left << std::setw(INDENT) << "MACs" - << std::left << std::setw(INDENT) << "Params" - << std::left << std::setw(INDENT) << "Memory" << '\n'; - std::cout << std::string(INDENT * 7, '-') << '\n'; - - AnalyzeData total_data; - for (auto node : graph.node()) { - total_data.mac += this->ndname_to_anal_data[node.name()].mac; - total_data.param += this->ndname_to_anal_data[node.name()].param; - total_data.mem += this->ndname_to_anal_data[node.name()].mem; - - std::cout << std::left << std::setw(INDENT) << string_trimmer(node.name(), INDENT-5); - - std::cout << std::left << std::setw(INDENT) << node.op_type(); - std::cout << std::setw(INDENT) << dims_vec_to_str(this->ndname_to_shape[node.input(0)]); - std::cout << std::setw(INDENT) << dims_vec_to_str(this->ndname_to_shape[node.output(0)]); - std::cout << std::setw(INDENT) << int64_to_str(this->ndname_to_anal_data[node.name()].mac); - std::cout << std::setw(INDENT) << int64_to_str(this->ndname_to_anal_data[node.name()].param); - std::cout << std::setw(INDENT) << int64_to_str(this->ndname_to_anal_data[node.name()].mem); - std::cout << '\n'; - } - std::cout << std::left << std::setw(INDENT) << "Total"; - std::cout << std::left << std::setw(INDENT) << "-"; - std::cout << std::setw(INDENT) << dims_vec_to_str(this->ndname_to_shape[graph.input(0).name()]); - std::cout << std::setw(INDENT) << dims_vec_to_str(this->ndname_to_shape[graph.output(0).name()]); - std::cout << std::setw(INDENT) << int64_to_str(total_data.mac); - std::cout << std::setw(INDENT) << int64_to_str(total_data.param); - std::cout << std::setw(INDENT) << int64_to_str(total_data.mem); - std::cout << '\n'; + } + else { + std::cout << std::left << std::setw(INDENT) << "Name" + << std::left << std::setw(INDENT) << "Type" + << std::left << std::setw(INDENT) << "Input Shape" + << std::left << std::setw(INDENT) << "Output Shape" + << std::left << std::setw(INDENT) << "MACs" + << std::left << std::setw(INDENT) << "Params" + << std::left << std::setw(INDENT) << "Memory" << '\n'; + std::cout << std::string(INDENT * 7, '-') << '\n'; + + AnalyzeData total_data; + for (auto node : graph.node()) { + total_data.mac += this->ndname_to_anal_data[node.name()].mac; + total_data.param += this->ndname_to_anal_data[node.name()].param; + total_data.mem += this->ndname_to_anal_data[node.name()].mem; + + std::cout << std::left << std::setw(INDENT) << string_trimmer(node.name(), INDENT-5); + + std::cout << std::left << std::setw(INDENT) << node.op_type(); + std::cout << std::setw(INDENT) << dims_vec_to_str(this->ndname_to_shape[node.input(0)]); + std::cout << std::setw(INDENT) << dims_vec_to_str(this->ndname_to_shape[node.output(0)]); + std::cout << std::setw(INDENT) << int64_to_str(this->ndname_to_anal_data[node.name()].mac); + std::cout << std::setw(INDENT) << int64_to_str(this->ndname_to_anal_data[node.name()].param); + std::cout << std::setw(INDENT) << int64_to_str(this->ndname_to_anal_data[node.name()].mem); + std::cout << '\n'; } + std::cout << std::left << std::setw(INDENT) << "Total"; + std::cout << std::left << std::setw(INDENT) << "-"; + std::cout << std::setw(INDENT) << dims_vec_to_str(this->ndname_to_shape[graph.input(0).name()]); + std::cout << std::setw(INDENT) << dims_vec_to_str(this->ndname_to_shape[graph.output(0).name()]); + std::cout << std::setw(INDENT) << int64_to_str(total_data.mac); + std::cout << std::setw(INDENT) << int64_to_str(total_data.param); + std::cout << std::setw(INDENT) << int64_to_str(total_data.mem); + std::cout << '\n'; + } } void InferShapeImpl::infer_shapes_Conv(onnx::NodeProto &node) { - struct AttrInfo_Conv attr_info; - - // get node input shapes - std::vector> input_shapes; - for (int i = 0; i < node.input_size(); ++i) { - auto ndinput = node.input(i); - if (this->ndname_to_shape.find(ndinput) == this->ndname_to_shape.end()) { - std::cerr << "Error: " << ndinput << " not found in ndname_to_shape\n"; - exit(1); - } - else { - // get shape from ndname_to_shape - std::vector shape = this->ndname_to_shape[ndinput]; - if (i == 0) attr_info.set_default_attr(shape.size()); - input_shapes.emplace_back(shape); - } - } - - // get attributes (kernel_shape, strides, pads, dilations, group) - for (auto attr : node.attribute()) { - if (attr.name() == "kernel_shape") { - for (int i = 0; i < attr.ints_size(); ++i) { - attr_info.kernel_shape.emplace_back(attr.ints(i)); - } - } - else if (attr.name() == "strides") { - attr_info.strides.clear(); - for (int i = 0; i < attr.ints_size(); ++i) { - attr_info.strides.emplace_back(attr.ints(i)); - } - } - else if (attr.name() == "pads") { - attr_info.pads.clear(); - for (int i = 0; i < attr.ints_size(); ++i) { - attr_info.pads.emplace_back(attr.ints(i)); - } - } - else if (attr.name() == "dilations") { - attr_info.dilations.clear(); - for (int i = 0; i < attr.ints_size(); ++i) { - attr_info.dilations.emplace_back(attr.ints(i)); - } - } - } - - // calculate output shape after convolution - auto input_shape = input_shapes[0]; - auto weight_shape = input_shapes[1]; - std::vector output_shape; - output_shape.emplace_back(input_shape[0]); // batch size - output_shape.emplace_back(weight_shape[0]); // number of channels - for (size_t i = 0; i < attr_info.kernel_shape.size(); ++i) { - output_shape.emplace_back((input_shape[i + 2] + 2 * attr_info.pads[i] - attr_info.dilations[i] * (attr_info.kernel_shape[i] - 1) - 1) / attr_info.strides[i] + 1); - } - - // set value_info and update ndname_to_shape - onnx::ValueInfoProto *val_info = this->graph.add_value_info(); - val_info->set_name(node.name()); - set_vec_to_shape(val_info, output_shape); - - this->ndname_to_shape[node.output(0)] = output_shape; - this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)]; + struct AttrInfo_Conv attr_info; + + // get node input shapes + std::vector> input_shapes; + for (int i = 0; i < node.input_size(); ++i) { + auto ndinput = node.input(i); + if (this->ndname_to_shape.find(ndinput) == this->ndname_to_shape.end()) { + std::cerr << "Error: " << ndinput << " not found in ndname_to_shape\n"; + exit(1); + } + else { + // get shape from ndname_to_shape + std::vector shape = this->ndname_to_shape[ndinput]; + if (i == 0) attr_info.set_default_attr(shape.size()); + input_shapes.emplace_back(shape); + } + } + + // get attributes (kernel_shape, strides, pads, dilations, group) + for (auto attr : node.attribute()) { + if (attr.name() == "kernel_shape") { + for (int i = 0; i < attr.ints_size(); ++i) { + attr_info.kernel_shape.emplace_back(attr.ints(i)); + } + } + else if (attr.name() == "strides") { + attr_info.strides.clear(); + for (int i = 0; i < attr.ints_size(); ++i) { + attr_info.strides.emplace_back(attr.ints(i)); + } + } + else if (attr.name() == "pads") { + attr_info.pads.clear(); + for (int i = 0; i < attr.ints_size(); ++i) { + attr_info.pads.emplace_back(attr.ints(i)); + } + } + else if (attr.name() == "dilations") { + attr_info.dilations.clear(); + for (int i = 0; i < attr.ints_size(); ++i) { + attr_info.dilations.emplace_back(attr.ints(i)); + } + } + } + + // calculate output shape after convolution + auto input_shape = input_shapes[0]; + auto weight_shape = input_shapes[1]; + std::vector output_shape; + output_shape.emplace_back(input_shape[0]); // batch size + output_shape.emplace_back(weight_shape[0]); // number of channels + for (size_t i = 0; i < attr_info.kernel_shape.size(); ++i) { + output_shape.emplace_back((input_shape[i + 2] + 2 * attr_info.pads[i] - attr_info.dilations[i] * (attr_info.kernel_shape[i] - 1) - 1) / attr_info.strides[i] + 1); + } + + // set value_info and update ndname_to_shape + onnx::ValueInfoProto *val_info = this->graph.add_value_info(); + val_info->set_name(node.name()); + set_vec_to_shape(val_info, output_shape); + + this->ndname_to_shape[node.output(0)] = output_shape; + this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)]; } void InferShapeImpl::infer_shapes_Relu(onnx::NodeProto &node) { - // get node input shapes - std::vector> input_shapes; - for (auto ndinput : node.input()) { - if (this->ndname_to_shape.find(ndinput) == this->ndname_to_shape.end()) { - std::cerr << "Error: " << ndinput << " not found in ndname_to_shape\n"; - return; - } - else { - // get shape from ndname_to_shape - input_shapes.emplace_back(this->ndname_to_shape[ndinput]); - } - } - - // set value_info and update ndname_to_shape - onnx::ValueInfoProto *val_info = this->graph.add_value_info(); - val_info->set_name(node.name()); - set_vec_to_shape(val_info, input_shapes[0]); - - this->ndname_to_shape[node.output(0)] = input_shapes[0]; - this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)]; + // get node input shapes + std::vector> input_shapes; + for (auto ndinput : node.input()) { + if (this->ndname_to_shape.find(ndinput) == this->ndname_to_shape.end()) { + std::cerr << "Error: " << ndinput << " not found in ndname_to_shape\n"; + return; + } + else { + // get shape from ndname_to_shape + input_shapes.emplace_back(this->ndname_to_shape[ndinput]); + } + } + + // set value_info and update ndname_to_shape + onnx::ValueInfoProto *val_info = this->graph.add_value_info(); + val_info->set_name(node.name()); + set_vec_to_shape(val_info, input_shapes[0]); + + this->ndname_to_shape[node.output(0)] = input_shapes[0]; + this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)]; } void InferShapeImpl::infer_shapes_MaxPool(onnx::NodeProto &node) { - struct AttrInfo_MaxPool attr_info; - - auto input_shape = this->ndname_to_shape[node.input(0)]; - - attr_info.set_default_attr(input_shape.size()); - for (auto attr : node.attribute()) { - // std::cout << "attr name: " << attr.name() << '\n'; - if (attr.name() == "kernel_shape") { // required - for (int i = 0; i < attr.ints_size(); ++i) { - attr_info.kernel_shape.emplace_back(attr.ints(i)); - } - } - else if (attr.name() == "strides") { - attr_info.strides.clear(); - for (int i = 0; i < attr.ints_size(); ++i) { - attr_info.strides.emplace_back(attr.ints(i)); - } - } - else if (attr.name() == "pads") { - attr_info.pads.clear(); - for (int i = 0; i < attr.ints_size(); ++i) { - attr_info.pads.emplace_back(attr.ints(i)); - } - } - else if (attr.name() == "ceil_mode") { - attr_info.ceil_mode = attr.i(); - } - else if (attr.name() == "dilations") { - attr_info.dilations.clear(); - for (int i = 0; i < attr.ints_size(); ++i) { - attr_info.dilations.emplace_back(attr.ints(i)); - } - } - } - - // calculate output shape after maxpool - std::vector output_shape; - output_shape.emplace_back(input_shape[0]); // batch size - output_shape.emplace_back(input_shape[1]); // number of channels - for (size_t i = 0; i < attr_info.kernel_shape.size(); ++i) { - output_shape.emplace_back((input_shape[i + 2] + 2 * attr_info.pads[i] - attr_info.dilations[i] * (attr_info.kernel_shape[i] - 1) - 1) / attr_info.strides[i] + 1); - } - - // set value_info and update ndname_to_shape - onnx::ValueInfoProto *val_info = this->graph.add_value_info(); - val_info->set_name(node.name()); - set_vec_to_shape(val_info, output_shape); - - this->ndname_to_shape[node.output(0)] = output_shape; - this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)]; + struct AttrInfo_MaxPool attr_info; + + auto input_shape = this->ndname_to_shape[node.input(0)]; + + attr_info.set_default_attr(input_shape.size()); + for (auto attr : node.attribute()) { + // std::cout << "attr name: " << attr.name() << '\n'; + if (attr.name() == "kernel_shape") { // required + for (int i = 0; i < attr.ints_size(); ++i) { + attr_info.kernel_shape.emplace_back(attr.ints(i)); + } + } + else if (attr.name() == "strides") { + attr_info.strides.clear(); + for (int i = 0; i < attr.ints_size(); ++i) { + attr_info.strides.emplace_back(attr.ints(i)); + } + } + else if (attr.name() == "pads") { + attr_info.pads.clear(); + for (int i = 0; i < attr.ints_size(); ++i) { + attr_info.pads.emplace_back(attr.ints(i)); + } + } + else if (attr.name() == "ceil_mode") { + attr_info.ceil_mode = attr.i(); + } + else if (attr.name() == "dilations") { + attr_info.dilations.clear(); + for (int i = 0; i < attr.ints_size(); ++i) { + attr_info.dilations.emplace_back(attr.ints(i)); + } + } + } + + // calculate output shape after maxpool + std::vector output_shape; + output_shape.emplace_back(input_shape[0]); // batch size + output_shape.emplace_back(input_shape[1]); // number of channels + for (size_t i = 0; i < attr_info.kernel_shape.size(); ++i) { + output_shape.emplace_back((input_shape[i + 2] + 2 * attr_info.pads[i] - attr_info.dilations[i] * (attr_info.kernel_shape[i] - 1) - 1) / attr_info.strides[i] + 1); + } + + // set value_info and update ndname_to_shape + onnx::ValueInfoProto *val_info = this->graph.add_value_info(); + val_info->set_name(node.name()); + set_vec_to_shape(val_info, output_shape); + + this->ndname_to_shape[node.output(0)] = output_shape; + this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)]; } void InferShapeImpl::infer_shapes_Add(onnx::NodeProto &node) { - // get node input shapes - std::vector> input_shapes; - for (auto ndinput : node.input()) { - if (this->ndname_to_shape.find(ndinput) == this->ndname_to_shape.end()) { - std::cerr << "Error: " << ndinput << " not found in ndname_to_shape\n"; - return; - } - else { - // get shape from ndname_to_shape - input_shapes.emplace_back(this->ndname_to_shape[ndinput]); - } - } - - // set value_info and update ndname_to_shape - onnx::ValueInfoProto *val_info = this->graph.add_value_info(); - val_info->set_name(node.name()); - set_vec_to_shape(val_info, input_shapes[0]); - - this->ndname_to_shape[node.output(0)] = input_shapes[0]; - this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)]; + // get node input shapes + std::vector> input_shapes; + for (auto ndinput : node.input()) { + if (this->ndname_to_shape.find(ndinput) == this->ndname_to_shape.end()) { + std::cerr << "Error: " << ndinput << " not found in ndname_to_shape\n"; + return; + } + else { + // get shape from ndname_to_shape + input_shapes.emplace_back(this->ndname_to_shape[ndinput]); + } + } + + // set value_info and update ndname_to_shape + onnx::ValueInfoProto *val_info = this->graph.add_value_info(); + val_info->set_name(node.name()); + set_vec_to_shape(val_info, input_shapes[0]); + + this->ndname_to_shape[node.output(0)] = input_shapes[0]; + this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)]; } void InferShapeImpl::infer_shapes_GlobalAveragePool(onnx::NodeProto &node) { - // get node input shapes - std::vector> input_shapes; - for (auto ndinput : node.input()) { - if (this->ndname_to_shape.find(ndinput) == this->ndname_to_shape.end()) { - std::cerr << "Error: " << ndinput << " not found in ndname_to_shape\n"; - return; - } - else { - // get shape from ndname_to_shape - input_shapes.emplace_back(this->ndname_to_shape[ndinput]); - } - } - - // calculate output shape after globalaveragepool - std::vector output_shape = { - input_shapes[0][0], // batch size - input_shapes[0][1] // number of channels - }; - for (size_t i = 2; i < input_shapes[0].size(); ++i) { - output_shape.emplace_back(1); + // get node input shapes + std::vector> input_shapes; + for (auto ndinput : node.input()) { + if (this->ndname_to_shape.find(ndinput) == this->ndname_to_shape.end()) { + std::cerr << "Error: " << ndinput << " not found in ndname_to_shape\n"; + return; } - - // set value_info and update ndname_to_shape - onnx::ValueInfoProto *val_info = this->graph.add_value_info(); - val_info->set_name(node.name()); - set_vec_to_shape(val_info, output_shape); - - this->ndname_to_shape[node.output(0)] = output_shape; - this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)]; + else { + // get shape from ndname_to_shape + input_shapes.emplace_back(this->ndname_to_shape[ndinput]); + } + } + + // calculate output shape after globalaveragepool + std::vector output_shape = { + input_shapes[0][0], // batch size + input_shapes[0][1] // number of channels + }; + for (size_t i = 2; i < input_shapes[0].size(); ++i) { + output_shape.emplace_back(1); + } + + // set value_info and update ndname_to_shape + onnx::ValueInfoProto *val_info = this->graph.add_value_info(); + val_info->set_name(node.name()); + set_vec_to_shape(val_info, output_shape); + + this->ndname_to_shape[node.output(0)] = output_shape; + this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)]; } void InferShapeImpl::infer_shapes_Flatten(onnx::NodeProto &node) { - // attribute axis is not used now - - // get node input shapes - std::vector> input_shapes; - for (auto ndinput : node.input()) { - if (this->ndname_to_shape.find(ndinput) == this->ndname_to_shape.end()) { - std::cerr << "Error: " << ndinput << " not found in ndname_to_shape\n"; - return; - } - else { - // get shape from ndname_to_shape - input_shapes.emplace_back(this->ndname_to_shape[ndinput]); - } - } - - // calculate output shape after flatten - std::vector output_shape; - int64_t flatten_dim = 1; - for (size_t i = 2; i < input_shapes[0].size(); ++i) { - flatten_dim *= input_shapes[0][i]; - } - - if (flatten_dim == 1) { - output_shape = { - input_shapes[0][0], // batch size - input_shapes[0][1] // number of channels - }; + // attribute axis is not used now + + // get node input shapes + std::vector> input_shapes; + for (auto ndinput : node.input()) { + if (this->ndname_to_shape.find(ndinput) == this->ndname_to_shape.end()) { + std::cerr << "Error: " << ndinput << " not found in ndname_to_shape\n"; + return; } else { - output_shape = { - input_shapes[0][0], // batch size - input_shapes[0][1], // number of channels - flatten_dim // flatten_dim - }; + // get shape from ndname_to_shape + input_shapes.emplace_back(this->ndname_to_shape[ndinput]); } + } - // set value_info and update ndname_to_shape - onnx::ValueInfoProto *val_info = this->graph.add_value_info(); - val_info->set_name(node.name()); - set_vec_to_shape(val_info, output_shape); + // calculate output shape after flatten + std::vector output_shape; + int64_t flatten_dim = 1; + for (size_t i = 2; i < input_shapes[0].size(); ++i) { + flatten_dim *= input_shapes[0][i]; + } - this->ndname_to_shape[node.output(0)] = output_shape; - this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)]; -} - -void InferShapeImpl::infer_shapes_Gemm(onnx::NodeProto &node) { - // get attributes - AttrInfo_Gemm attr_info; - for (auto attr : node.attribute()) { - if (attr.name() == "transA") { - if (attr.i() == 1) attr_info.transA = true; - } - else if (attr.name() == "transB") { - if (attr.i() == 1) attr_info.transB = true; - } - } - - // get node input shapes (A and B) - std::vector> input_shapes; - for (size_t num = 0; num < 2; ++num) { - auto ndinput = node.input(num); - if (this->ndname_to_shape.find(ndinput) == this->ndname_to_shape.end()) { - std::cerr << "Error: " << ndinput << " not found in ndname_to_shape\n"; - return; - } - else { - // get shape from ndname_to_shape - input_shapes.emplace_back(this->ndname_to_shape[ndinput]); - if (num == 0 && attr_info.transA) { - std::reverse(input_shapes[num].begin(), input_shapes[num].end()); - } - else if (num == 1 && attr_info.transB) { - std::reverse(input_shapes[num].begin(), input_shapes[num].end()); - } - } - } - - // calculate output shape after gemm - // [M, K] * [K, N] = [M, N] - std::vector output_shape; + if (flatten_dim == 1) { + output_shape = { + input_shapes[0][0], // batch size + input_shapes[0][1] // number of channels + }; + } + else { output_shape = { - input_shapes[0][0], - input_shapes[1][1] + input_shapes[0][0], // batch size + input_shapes[0][1], // number of channels + flatten_dim // flatten_dim }; + } + + // set value_info and update ndname_to_shape + onnx::ValueInfoProto *val_info = this->graph.add_value_info(); + val_info->set_name(node.name()); + set_vec_to_shape(val_info, output_shape); - // set value_info and update ndname_to_shape - onnx::ValueInfoProto *val_info = this->graph.add_value_info(); - val_info->set_name(node.name()); - set_vec_to_shape(val_info, output_shape); + this->ndname_to_shape[node.output(0)] = output_shape; + this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)]; +} - this->ndname_to_shape[node.output(0)] = output_shape; - this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)]; +void InferShapeImpl::infer_shapes_Gemm(onnx::NodeProto &node) { + // get attributes + AttrInfo_Gemm attr_info; + for (auto attr : node.attribute()) { + if (attr.name() == "transA") { + if (attr.i() == 1) attr_info.transA = true; + } + else if (attr.name() == "transB") { + if (attr.i() == 1) attr_info.transB = true; + } + } + + // get node input shapes (A and B) + std::vector> input_shapes; + for (size_t num = 0; num < 2; ++num) { + auto ndinput = node.input(num); + if (this->ndname_to_shape.find(ndinput) == this->ndname_to_shape.end()) { + std::cerr << "Error: " << ndinput << " not found in ndname_to_shape\n"; + return; + } + else { + // get shape from ndname_to_shape + input_shapes.emplace_back(this->ndname_to_shape[ndinput]); + if (num == 0 && attr_info.transA) { + std::reverse(input_shapes[num].begin(), input_shapes[num].end()); + } + else if (num == 1 && attr_info.transB) { + std::reverse(input_shapes[num].begin(), input_shapes[num].end()); + } + } + } + + // calculate output shape after gemm + // [M, K] * [K, N] = [M, N] + std::vector output_shape; + output_shape = { + input_shapes[0][0], + input_shapes[1][1] + }; + + // set value_info and update ndname_to_shape + onnx::ValueInfoProto *val_info = this->graph.add_value_info(); + val_info->set_name(node.name()); + set_vec_to_shape(val_info, output_shape); + + this->ndname_to_shape[node.output(0)] = output_shape; + this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)]; } void InferShapeImpl::infer_shapes(bool analyze) { - this->set_io_iniz_shape_to_map(analyze); + this->set_io_iniz_shape_to_map(analyze); - // infer shape for each node and store to value_info - for (auto node : graph.node()) { - if (node.op_type() == "Conv") { - this->infer_shapes_Conv(node); - } - else if (node.op_type() == "Relu") { - this->infer_shapes_Relu(node); - } - else if (node.op_type() == "MaxPool") { - this->infer_shapes_MaxPool(node); - } - else if (node.op_type() == "Add") { - this->infer_shapes_Add(node); - } - else if (node.op_type() == "GlobalAveragePool") { - this->infer_shapes_GlobalAveragePool(node); - } - else if (node.op_type() == "Flatten") { - this->infer_shapes_Flatten(node); - } - else if (node.op_type() == "Gemm") { - this->infer_shapes_Gemm(node); - } - else { - std::cerr << "Error: " << node.op_type() << " not supported now\n"; - exit(1); - } - - if (analyze) { - // analyze node and store to ndname_to_anal_data - this->ndname_to_anal_data[node.name()] = - analyze_node(node, this->ndname_to_shape, this->ndname_to_dtype_size); - } + // infer shape for each node and store to value_info + for (auto node : graph.node()) { + if (node.op_type() == "Conv") { + this->infer_shapes_Conv(node); + } + else if (node.op_type() == "Relu") { + this->infer_shapes_Relu(node); + } + else if (node.op_type() == "MaxPool") { + this->infer_shapes_MaxPool(node); + } + else if (node.op_type() == "Add") { + this->infer_shapes_Add(node); + } + else if (node.op_type() == "GlobalAveragePool") { + this->infer_shapes_GlobalAveragePool(node); + } + else if (node.op_type() == "Flatten") { + this->infer_shapes_Flatten(node); + } + else if (node.op_type() == "Gemm") { + this->infer_shapes_Gemm(node); + } + else { + std::cerr << "Error: " << node.op_type() << " not supported now\n"; + exit(1); + } + + if (analyze) { + // analyze node and store to ndname_to_anal_data + this->ndname_to_anal_data[node.name()] = + analyze_node(node, this->ndname_to_shape, this->ndname_to_dtype_size); } + } } void InferShapeImpl::infer_shapes() { - this->infer_shapes(true); + this->infer_shapes(true); } diff --git a/src/ModelAnalysis.cpp b/src/ModelAnalysis.cpp index 1f0ddd9..49a5d51 100644 --- a/src/ModelAnalysis.cpp +++ b/src/ModelAnalysis.cpp @@ -2,227 +2,183 @@ #include "utils.hpp" int64_t get_prod(std::vector &vec) { - int64_t prod = 1; - for (int64_t v : vec) { - prod *= v; - } - return prod; + int64_t prod = 1; + for (int64_t v : vec) { + prod *= v; + } + return prod; +} + +NodeAnalArgs get_anal_args(onnx::NodeProto &node, const std::unordered_map> &ndname_to_shape, const std::unordered_map &ndname_to_dtype_size) { + NodeAnalArgs anal_args; + std::vector> input_shapes; + std::unordered_map ndname_to_size; + for (auto input : node.input()) { + input_shapes.emplace_back(ndname_to_shape.at(input)); + ndname_to_size[input] = ndname_to_dtype_size.at(input); + } + std::vector output_shape = ndname_to_shape.at(node.output(0)); + ndname_to_size[node.output(0)] = ndname_to_dtype_size.at(node.output(0)); + + anal_args.input_shapes = input_shapes; + anal_args.output_shape = output_shape; + anal_args.ndname_to_size = ndname_to_size; + + return anal_args; } AnalyzeData analyze_node_Conv(onnx::NodeProto &node, std::vector> &input_shapes, std::vector &output_shape, std::unordered_map &ndname_to_size) { - AnalyzeData data; - - // MACs - std::vector kernel_shape = input_shapes[1]; - std::vector reduce_shape(kernel_shape.begin() + 1, kernel_shape.end()); - int64_t out_prod = get_prod(output_shape); - data.mac = out_prod * get_prod(reduce_shape) * MUL_MACS; - - if (node.input_size() == 3) { // with bias - data.mac += out_prod * ADD_MACS; - } - - // Parameters & Memory - for (size_t i = 1; i < input_shapes.size(); ++i) { - data.param += get_prod(input_shapes[i]); - data.mem += get_prod(input_shapes[i]) * ndname_to_size[node.input(i)]; - } - data.mem += get_prod(input_shapes[0]) * ndname_to_size[node.input(0)]; - data.mem += out_prod * ndname_to_size[node.output(0)]; - - return data; + AnalyzeData data; + + // MACs + std::vector kernel_shape = input_shapes[1]; + std::vector reduce_shape(kernel_shape.begin() + 1, kernel_shape.end()); + int64_t out_prod = get_prod(output_shape); + data.mac = out_prod * get_prod(reduce_shape) * MUL_MACS; + + if (node.input_size() == 3) { // with bias + data.mac += out_prod * ADD_MACS; + } + + // Parameters & Memory + for (size_t i = 1; i < input_shapes.size(); ++i) { + data.param += get_prod(input_shapes[i]); + data.mem += get_prod(input_shapes[i]) * ndname_to_size[node.input(i)]; + } + data.mem += get_prod(input_shapes[0]) * ndname_to_size[node.input(0)]; + data.mem += out_prod * ndname_to_size[node.output(0)]; + + return data; } AnalyzeData analyze_node_Relu(onnx::NodeProto &node, std::vector> &input_shapes, std::vector &output_shape, std::unordered_map &ndname_to_size) { - AnalyzeData data; + AnalyzeData data; - data.param = 0; // no trainable parameters - data.mac = 0; // non MACs operation + data.param = 0; // no trainable parameters + data.mac = 0; // non MACs operation - // Memory - for (size_t i = 0; i < input_shapes.size(); ++i) { - data.mem += get_prod(input_shapes[i]) * ndname_to_size[node.input(i)]; - } - data.mem += get_prod(output_shape) * ndname_to_size[node.output(0)]; + // Memory + for (size_t i = 0; i < input_shapes.size(); ++i) { + data.mem += get_prod(input_shapes[i]) * ndname_to_size[node.input(i)]; + } + data.mem += get_prod(output_shape) * ndname_to_size[node.output(0)]; - return data; + return data; } AnalyzeData analyze_node_MaxPool(onnx::NodeProto &node, std::vector> &input_shapes, std::vector &output_shape, std::unordered_map &ndname_to_size) { - AnalyzeData data; + AnalyzeData data; - // MACs - data.mac = 0; // non MACs operation + // MACs + data.mac = 0; // non MACs operation - // Parameters & Memory - data.param = 0; // no trainable parameters - for (size_t i = 0; i < input_shapes.size(); ++i) { - data.mem += get_prod(input_shapes[i]) * ndname_to_size[node.input(i)]; - } - data.mem += get_prod(output_shape) * ndname_to_size[node.output(0)]; + // Parameters & Memory + data.param = 0; // no trainable parameters + for (size_t i = 0; i < input_shapes.size(); ++i) { + data.mem += get_prod(input_shapes[i]) * ndname_to_size[node.input(i)]; + } + data.mem += get_prod(output_shape) * ndname_to_size[node.output(0)]; - return data; + return data; } AnalyzeData analyze_node_Add(onnx::NodeProto &node, std::vector> &input_shapes, std::vector &output_shape, std::unordered_map &ndname_to_size) { - AnalyzeData data; + AnalyzeData data; - // MACs - for (size_t i = 0; i < input_shapes.size(); ++i) { - data.mac += get_prod(input_shapes[i]) * ADD_MACS; - } + // MACs + for (size_t i = 0; i < input_shapes.size(); ++i) { + data.mac += get_prod(input_shapes[i]) * ADD_MACS; + } - // Parameters & Memory - data.param = 0; // no trainable parameters - for (size_t i = 0; i < input_shapes.size(); ++i) { - data.mem += get_prod(input_shapes[i]) * ndname_to_size[node.input(i)]; - } - data.mem += get_prod(output_shape) * ndname_to_size[node.output(0)]; + // Parameters & Memory + data.param = 0; // no trainable parameters + for (size_t i = 0; i < input_shapes.size(); ++i) { + data.mem += get_prod(input_shapes[i]) * ndname_to_size[node.input(i)]; + } + data.mem += get_prod(output_shape) * ndname_to_size[node.output(0)]; - return data; + return data; } AnalyzeData analyze_node_GlobalAveragePool(onnx::NodeProto &node, std::vector> &input_shapes, std::vector &output_shape, std::unordered_map &ndname_to_size) { - AnalyzeData data; - - // MACs - for (size_t i = 0; i < input_shapes.size(); ++i) { - data.mac += get_prod(input_shapes[i]) * ADD_MACS; - } - data.mac += get_prod(output_shape) * DIV_MACS; - - // Parameters & Memory - data.param = 0; // no trainable parameters - for (size_t i = 0; i < input_shapes.size(); ++i) { - data.mem += get_prod(input_shapes[i]) * ndname_to_size[node.input(i)]; - } - data.mem += get_prod(output_shape) * ndname_to_size[node.output(0)]; - - return data; + AnalyzeData data; + + // MACs + for (size_t i = 0; i < input_shapes.size(); ++i) { + data.mac += get_prod(input_shapes[i]) * ADD_MACS; + } + data.mac += get_prod(output_shape) * DIV_MACS; + + // Parameters & Memory + data.param = 0; // no trainable parameters + for (size_t i = 0; i < input_shapes.size(); ++i) { + data.mem += get_prod(input_shapes[i]) * ndname_to_size[node.input(i)]; + } + data.mem += get_prod(output_shape) * ndname_to_size[node.output(0)]; + + return data; } AnalyzeData analyze_node_Flatten(onnx::NodeProto &node, std::vector> &input_shapes, std::vector &output_shape, std::unordered_map &ndname_to_size) { - AnalyzeData data; + AnalyzeData data; - // MACs - data.mac = 0; // non MACs operation + // MACs + data.mac = 0; // non MACs operation - // Parameters & Memory - data.param = 0; // no trainable parameters - for (size_t i = 0; i < input_shapes.size(); ++i) { - data.mem += get_prod(input_shapes[i]) * ndname_to_size[node.input(i)]; - } - data.mem += get_prod(output_shape) * ndname_to_size[node.output(0)]; + // Parameters & Memory + data.param = 0; // no trainable parameters + for (size_t i = 0; i < input_shapes.size(); ++i) { + data.mem += get_prod(input_shapes[i]) * ndname_to_size[node.input(i)]; + } + data.mem += get_prod(output_shape) * ndname_to_size[node.output(0)]; - return data; + return data; } AnalyzeData analyze_node_Gemm(onnx::NodeProto &node, std::vector> &input_shapes, std::vector &output_shape, std::unordered_map &ndname_to_size) { - AnalyzeData data; + AnalyzeData data; - // MACs - data.mac = get_prod(input_shapes[0]) * get_prod(output_shape) * MUL_MACS; + // MACs + data.mac = get_prod(input_shapes[0]) * get_prod(output_shape) * MUL_MACS; - // Parameters & Memory - data.param = get_prod(input_shapes[0]) * get_prod(output_shape) + get_prod(output_shape); - for (size_t i = 0; i < input_shapes.size(); ++i) { - data.mem += get_prod(input_shapes[i]) * ndname_to_size[node.input(i)]; - } - data.mem += get_prod(output_shape) * ndname_to_size[node.output(0)]; + // Parameters & Memory + data.param = get_prod(input_shapes[0]) * get_prod(output_shape) + get_prod(output_shape); + for (size_t i = 0; i < input_shapes.size(); ++i) { + data.mem += get_prod(input_shapes[i]) * ndname_to_size[node.input(i)]; + } + data.mem += get_prod(output_shape) * ndname_to_size[node.output(0)]; - return data; + return data; } AnalyzeData analyze_node(onnx::NodeProto &node, const std::unordered_map> &ndname_to_shape, const std::unordered_map &ndname_to_dtype_size) { - AnalyzeData data; - if (node.op_type() == "Conv") { - std::unordered_map ndname_to_size; - std::vector> input_shapes; - for (auto input : node.input()) { - ndname_to_size[input] = ndname_to_dtype_size.at(input); - input_shapes.emplace_back(ndname_to_shape.at(input)); - } - std::vector output_shape = ndname_to_shape.at(node.output(0)); - ndname_to_size[node.output(0)] = ndname_to_dtype_size.at(node.output(0)); - - data = analyze_node_Conv(node, input_shapes, output_shape, ndname_to_size); - } - else if (node.op_type() == "Relu") { - std::unordered_map ndname_to_size; - std::vector> input_shapes; - for (auto input : node.input()) { - ndname_to_size[input] = ndname_to_dtype_size.at(input); - input_shapes.emplace_back(ndname_to_shape.at(input)); - } - std::vector output_shape = ndname_to_shape.at(node.output(0)); - ndname_to_size[node.output(0)] = ndname_to_dtype_size.at(node.output(0)); - - data = analyze_node_Relu(node, input_shapes, output_shape, ndname_to_size); - } - else if (node.op_type() == "MaxPool") { - std::unordered_map ndname_to_size; - std::vector> input_shapes; - for (auto input : node.input()) { - ndname_to_size[input] = ndname_to_dtype_size.at(input); - input_shapes.emplace_back(ndname_to_shape.at(input)); - } - std::vector output_shape = ndname_to_shape.at(node.output(0)); - ndname_to_size[node.output(0)] = ndname_to_dtype_size.at(node.output(0)); - - data = analyze_node_MaxPool(node, input_shapes, output_shape, ndname_to_size); - } - else if (node.op_type() == "Add") { - std::unordered_map ndname_to_size; - std::vector> input_shapes; - for (auto input : node.input()) { - ndname_to_size[input] = ndname_to_dtype_size.at(input); - input_shapes.emplace_back(ndname_to_shape.at(input)); - } - std::vector output_shape = ndname_to_shape.at(node.output(0)); - ndname_to_size[node.output(0)] = ndname_to_dtype_size.at(node.output(0)); - - data = analyze_node_Add(node, input_shapes, output_shape, ndname_to_size); - } - else if (node.op_type() == "GlobalAveragePool") { - std::unordered_map ndname_to_size; - std::vector> input_shapes; - for (auto input : node.input()) { - ndname_to_size[input] = ndname_to_dtype_size.at(input); - input_shapes.emplace_back(ndname_to_shape.at(input)); - } - std::vector output_shape = ndname_to_shape.at(node.output(0)); - ndname_to_size[node.output(0)] = ndname_to_dtype_size.at(node.output(0)); - - data = analyze_node_GlobalAveragePool(node, input_shapes, output_shape, ndname_to_size); - } - else if (node.op_type() == "Flatten") { - std::unordered_map ndname_to_size; - std::vector> input_shapes; - for (auto input : node.input()) { - ndname_to_size[input] = ndname_to_dtype_size.at(input); - input_shapes.emplace_back(ndname_to_shape.at(input)); - } - std::vector output_shape = ndname_to_shape.at(node.output(0)); - ndname_to_size[node.output(0)] = ndname_to_dtype_size.at(node.output(0)); - - data = analyze_node_Flatten(node, input_shapes, output_shape, ndname_to_size); - } - else if (node.op_type() == "Gemm") { - std::unordered_map ndname_to_size; - std::vector> input_shapes; - for (auto input : node.input()) { - ndname_to_size[input] = ndname_to_dtype_size.at(input); - input_shapes.emplace_back(ndname_to_shape.at(input)); - } - std::vector output_shape = ndname_to_shape.at(node.output(0)); - ndname_to_size[node.output(0)] = ndname_to_dtype_size.at(node.output(0)); - - data = analyze_node_Gemm(node, input_shapes, output_shape, ndname_to_size); - } - else { - std::cerr << "Error: " << node.op_type() << " not supported now\n"; - exit(1); - } - - return data; + AnalyzeData data; + NodeAnalArgs anal_args = get_anal_args(node, ndname_to_shape, ndname_to_dtype_size); + if (node.op_type() == "Conv") { + data = analyze_node_Conv(node, anal_args.input_shapes, anal_args.output_shape, anal_args.ndname_to_size); + } + else if (node.op_type() == "Relu") { + data = analyze_node_Relu(node, anal_args.input_shapes, anal_args.output_shape, anal_args.ndname_to_size); + } + else if (node.op_type() == "MaxPool") { + data = analyze_node_MaxPool(node, anal_args.input_shapes, anal_args.output_shape, anal_args.ndname_to_size); + } + else if (node.op_type() == "Add") { + data = analyze_node_Add(node, anal_args.input_shapes, anal_args.output_shape, anal_args.ndname_to_size); + } + else if (node.op_type() == "GlobalAveragePool") { + data = analyze_node_GlobalAveragePool(node, anal_args.input_shapes, anal_args.output_shape, anal_args.ndname_to_size); + } + else if (node.op_type() == "Flatten") { + data = analyze_node_Flatten(node, anal_args.input_shapes, anal_args.output_shape, anal_args.ndname_to_size); + } + else if (node.op_type() == "Gemm") { + data = analyze_node_Gemm(node, anal_args.input_shapes, anal_args.output_shape, anal_args.ndname_to_size); + } + else { + std::cerr << "Error: " << node.op_type() << " not supported now\n"; + exit(1); + } + + return data; } diff --git a/src/main.cpp b/src/main.cpp index a78e0be..3df9478 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -5,7 +5,7 @@ namespace py = pybind11; -PYBIND11_MODULE(_onnxinfo, m) { +PYBIND11_MODULE(onnxinfo, m) { m.doc() = "pybind11 onnxinfo module"; // optional module docstring m.def("read_onnx", &read_onnx, "A C++ function that read ONNX model"); diff --git a/src/utils.cpp b/src/utils.cpp index 3678ca2..68bec7a 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -35,21 +35,21 @@ void print_dim(const ::onnx::TensorShapeProto_Dimension &dim) { } void print_dims_vec(const std::vector &dims) { - std::cout << "["; - for (size_t i = 0; i < dims.size() - 1; ++i) { - std::cout << dims[i] << ", "; - } - std::cout << dims.back() << "]"; + std::cout << "["; + for (size_t i = 0; i < dims.size() - 1; ++i) { + std::cout << dims[i] << ", "; + } + std::cout << dims.back() << "]"; } std::string dims_vec_to_str(const std::vector &dims) { - std::string str = "["; - for (size_t i = 0; i < dims.size() - 1; ++i) { - str += std::to_string(dims[i]) + ", "; - } - str += std::to_string(dims.back()) + "]"; + std::string str = "["; + for (size_t i = 0; i < dims.size() - 1; ++i) { + str += std::to_string(dims[i]) + ", "; + } + str += std::to_string(dims.back()) + "]"; - return str; + return str; } void set_vec_to_shape(onnx::ValueInfoProto *val_info, const std::vector &dims) { @@ -64,7 +64,7 @@ std::string string_trimmer(const std::string &inputString, const size_t maxLen) std::string trimmedString = inputString; if (trimmedString.length() > maxLen) { - trimmedString = trimmedString.substr(0, maxLen - 3) + "..."; + trimmedString = trimmedString.substr(0, maxLen - 3) + "..."; } return trimmedString; diff --git a/tests/test_onnxinfo.py b/tests/test_onnxinfo.py index 261bbd7..22f80e4 100644 --- a/tests/test_onnxinfo.py +++ b/tests/test_onnxinfo.py @@ -1,17 +1,17 @@ -import _onnxinfo +import onnxinfo import pytest def test_read_onnx_nofile(): with pytest.raises(ValueError): - _onnxinfo.read_onnx('non-existing.onnx') + onnxinfo.read_onnx('non-existing.onnx') def test_read_onnx(): - info = _onnxinfo.read_onnx('models/resnet18_Opset16.onnx') + info = onnxinfo.read_onnx('models/resnet18_Opset16.onnx') assert info is not None def test_infer_shape(): - model = _onnxinfo.read_onnx('models/resnet18_Opset16.onnx') - target = _onnxinfo.InferShapeImpl(model.graph()) + model = onnxinfo.read_onnx('models/resnet18_Opset16.onnx') + target = onnxinfo.InferShapeImpl(model.graph()) old_size = len(target.get_ndname_to_shape()) target.infer_shapes() new_size = len(target.get_ndname_to_shape()) @@ -19,8 +19,8 @@ def test_infer_shape(): def test_print_summary(): try: - model = _onnxinfo.read_onnx('models/resnet18_Opset16.onnx') - target = _onnxinfo.InferShapeImpl(model.graph()) + model = onnxinfo.read_onnx('models/resnet18_Opset16.onnx') + target = onnxinfo.InferShapeImpl(model.graph()) target.infer_shapes() target.print_summary() except: