From f4805cd1982539cb62409215198949cf12df4990 Mon Sep 17 00:00:00 2001 From: Ray Huang Date: Tue, 10 Dec 2024 23:02:30 +0800 Subject: [PATCH] feat: static analysis of MACs, params, mem for ONNX --- README.md | 17 +++ include/InferShape.hpp | 10 +- include/ModelAnalysis.hpp | 26 +++++ include/utils.hpp | 2 +- src/InferShape.cpp | 132 +++++++++++++++++----- src/ModelAnalysis.cpp | 228 ++++++++++++++++++++++++++++++++++++++ src/main.cpp | 2 +- src/utils.cpp | 39 +++---- 8 files changed, 398 insertions(+), 58 deletions(-) create mode 100644 include/ModelAnalysis.hpp create mode 100644 src/ModelAnalysis.cpp diff --git a/README.md b/README.md index be6c91f..ee95eb2 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,23 @@ A tool to show ONNX model summary like torchinfo 3. `cmake -S . -B build/` 4. `cmake --build build/ [--parallel ]` to build dependency and onnxinfo +## Usage +### Shape Inference +Supported node types so far: Conv, Relu, MaxPool, Add, GlobalAveragePool, Flatten, Gemm + +### Static Analysis +Runs by default during shape inference. +Use `infer_shapes(analyze = false)` to run shape inference only. + +#### MACs +Non-MAC operations such as `Relu` and `MaxPool` are not counted. + +#### Parameters +Counts the trainable parameters of each node. + +#### Memory +Calculates the memory used by each node's inputs and outputs. 
(Bytes) + ## Test `python3 -m pytest -v` diff --git a/include/InferShape.hpp b/include/InferShape.hpp index 4dd9272..5673387 100644 --- a/include/InferShape.hpp +++ b/include/InferShape.hpp @@ -3,6 +3,7 @@ #include #include "onnx.proto3.pb.h" #include "AttrInfo.hpp" +#include "ModelAnalysis.hpp" #define INDENT 30 @@ -13,19 +14,22 @@ class InferShapeImpl { ~InferShapeImpl() = default; - void set_io_iniz_shape_to_map(); + void set_io_iniz_shape_to_map(bool analyze); + void infer_shapes(bool analyze = true); void infer_shapes(); void print_summary(); - std::unordered_map> get_ndname_to_shape() { + const std::unordered_map> get_ndname_to_shape() { return this->ndname_to_shape; } private: onnx::GraphProto graph; std::unordered_map> ndname_to_shape; + std::unordered_map ndname_to_anal_data; + std::unordered_map ndname_to_dtype_size; - // TODO: check dimensions & attributes needed or default + // TODO: more op types void infer_shapes_Conv(onnx::NodeProto &node); void infer_shapes_Relu(onnx::NodeProto &node); void infer_shapes_MaxPool(onnx::NodeProto &node); diff --git a/include/ModelAnalysis.hpp b/include/ModelAnalysis.hpp new file mode 100644 index 0000000..32a9d16 --- /dev/null +++ b/include/ModelAnalysis.hpp @@ -0,0 +1,26 @@ +#pragma once + +#include "InferShape.hpp" +#include "onnx.proto3.pb.h" + +#define MUL_MACS 1 +#define ADD_MACS 1 +#define DIV_MACS 4 + +struct AnalyzeData { + int64_t mac = 0; + int64_t param = 0; + int64_t mem = 0; +}; + +int64_t get_prod(std::vector &vec); + +AnalyzeData analyze_node_Conv(onnx::NodeProto &node, std::vector> &input_shapes, std::vector &output_shape, std::unordered_map &ndname_to_size); +AnalyzeData analyze_node_Relu(onnx::NodeProto &node, std::vector> &input_shapes, std::vector &output_shape, std::unordered_map &ndname_to_size); +AnalyzeData analyze_node_MaxPool(onnx::NodeProto &node, std::vector> &input_shapes, std::vector &output_shape, std::unordered_map &ndname_to_size); +AnalyzeData analyze_node_Add(onnx::NodeProto &node, 
std::vector> &input_shapes, std::vector &output_shape, std::unordered_map &ndname_to_size); +AnalyzeData analyze_node_GlobalAveragePool(onnx::NodeProto &node, std::vector> &input_shapes, std::vector &output_shape, std::unordered_map &ndname_to_size); +AnalyzeData analyze_node_Flatten(onnx::NodeProto &node, std::vector> &input_shapes, std::vector &output_shape, std::unordered_map &ndname_to_size); +AnalyzeData analyze_node_Gemm(onnx::NodeProto &node, std::vector> &input_shapes, std::vector &output_shape, std::unordered_map &ndname_to_size); + +AnalyzeData analyze_node(onnx::NodeProto &node, const std::unordered_map> &ndname_to_shape, const std::unordered_map &ndname_to_dtype_size); diff --git a/include/utils.hpp b/include/utils.hpp index 02cd30c..2e5a47a 100644 --- a/include/utils.hpp +++ b/include/utils.hpp @@ -9,8 +9,8 @@ onnx::ModelProto read_onnx(const std::string &filename); void print_dim(const ::onnx::TensorShapeProto_Dimension &dim); -void print_val_info(const ::onnx::ValueInfoProto &info); void print_dims_vec(const std::vector &dims); std::string dims_vec_to_str(const std::vector &dims); void set_vec_to_shape(onnx::ValueInfoProto *val_info, const std::vector &dims); std::string string_trimmer(const std::string &inputString, const size_t maxLen); +std::string int64_to_str(int64_t num); diff --git a/src/InferShape.cpp b/src/InferShape.cpp index 19ea8ec..9f9c7b4 100644 --- a/src/InferShape.cpp +++ b/src/InferShape.cpp @@ -2,9 +2,8 @@ #include #include "InferShape.hpp" #include "utils.hpp" -#include "AttrInfo.hpp" -void InferShapeImpl::set_io_iniz_shape_to_map() { +void InferShapeImpl::set_io_iniz_shape_to_map(bool analyze) { for (auto input : this->graph.input()) { auto shape = input.type().tensor_type().shape(); std::vector shape_vec = {}; @@ -15,6 +14,13 @@ void InferShapeImpl::set_io_iniz_shape_to_map() { } } this->ndname_to_shape[input.name()] = shape_vec; + + // get dtype size + if (analyze) { + if (input.type().tensor_type().elem_type() == 1) { // 
float32 + this->ndname_to_dtype_size[input.name()] = 4; + } + } } for (auto initializer : this->graph.initializer()) { @@ -24,6 +30,13 @@ void InferShapeImpl::set_io_iniz_shape_to_map() { shape_vec.emplace_back(shape.Get(i)); } this->ndname_to_shape[initializer.name()] = shape_vec; + + // get dtype size + if (analyze) { + if (initializer.data_type() == 1) { + this->ndname_to_dtype_size[initializer.name()] = 4; + } + } } for (auto output : this->graph.output()) { @@ -36,22 +49,69 @@ void InferShapeImpl::set_io_iniz_shape_to_map() { } } this->ndname_to_shape[output.name()] = shape_vec; + + // get dtype size + if (analyze) { + if (output.type().tensor_type().elem_type() == 1) { + this->ndname_to_dtype_size[output.name()] = 4; + } + } } } void InferShapeImpl::print_summary() { - std::cout << std::left << std::setw(INDENT) << "Name" - << std::left << std::setw(INDENT) << "Type" - << std::left << std::setw(INDENT) << "Input Shape" - << std::left << std::setw(INDENT) << "Output Shape" << '\n'; - std::cout << std::string(INDENT * 4, '-') << '\n'; + std::ios::sync_with_stdio(false); + std::cin.tie(0); - for (auto node : graph.node()) { - std::cout << std::left << std::setw(INDENT) << string_trimmer(node.name(), INDENT-5); + if (this->ndname_to_anal_data.empty()) { + std::cout << std::left << std::setw(INDENT) << "Name" + << std::left << std::setw(INDENT) << "Type" + << std::left << std::setw(INDENT) << "Input Shape" + << std::left << std::setw(INDENT) << "Output Shape" << '\n'; + std::cout << std::string(INDENT * 4, '-') << '\n'; - std::cout << std::left << std::setw(INDENT) << node.op_type(); - std::cout << std::setw(INDENT) << dims_vec_to_str(this->get_ndname_to_shape()[node.input(0)]); - std::cout << std::setw(INDENT) << dims_vec_to_str(this->get_ndname_to_shape()[node.output(0)]); + for (auto node : graph.node()) { + std::cout << std::left << std::setw(INDENT) << string_trimmer(node.name(), INDENT-5); + + std::cout << std::left << std::setw(INDENT) << node.op_type(); + 
std::cout << std::setw(INDENT) << dims_vec_to_str(this->ndname_to_shape[node.input(0)]); + std::cout << std::setw(INDENT) << dims_vec_to_str(this->ndname_to_shape[node.output(0)]); + std::cout << '\n'; + } + } + else { + std::cout << std::left << std::setw(INDENT) << "Name" + << std::left << std::setw(INDENT) << "Type" + << std::left << std::setw(INDENT) << "Input Shape" + << std::left << std::setw(INDENT) << "Output Shape" + << std::left << std::setw(INDENT) << "MACs" + << std::left << std::setw(INDENT) << "Params" + << std::left << std::setw(INDENT) << "Memory" << '\n'; + std::cout << std::string(INDENT * 7, '-') << '\n'; + + AnalyzeData total_data; + for (auto node : graph.node()) { + total_data.mac += this->ndname_to_anal_data[node.name()].mac; + total_data.param += this->ndname_to_anal_data[node.name()].param; + total_data.mem += this->ndname_to_anal_data[node.name()].mem; + + std::cout << std::left << std::setw(INDENT) << string_trimmer(node.name(), INDENT-5); + + std::cout << std::left << std::setw(INDENT) << node.op_type(); + std::cout << std::setw(INDENT) << dims_vec_to_str(this->ndname_to_shape[node.input(0)]); + std::cout << std::setw(INDENT) << dims_vec_to_str(this->ndname_to_shape[node.output(0)]); + std::cout << std::setw(INDENT) << int64_to_str(this->ndname_to_anal_data[node.name()].mac); + std::cout << std::setw(INDENT) << int64_to_str(this->ndname_to_anal_data[node.name()].param); + std::cout << std::setw(INDENT) << int64_to_str(this->ndname_to_anal_data[node.name()].mem); + std::cout << '\n'; + } + std::cout << std::left << std::setw(INDENT) << "Total"; + std::cout << std::left << std::setw(INDENT) << "-"; + std::cout << std::setw(INDENT) << dims_vec_to_str(this->ndname_to_shape[graph.input(0).name()]); + std::cout << std::setw(INDENT) << dims_vec_to_str(this->ndname_to_shape[graph.output(0).name()]); + std::cout << std::setw(INDENT) << int64_to_str(total_data.mac); + std::cout << std::setw(INDENT) << int64_to_str(total_data.param); + std::cout << 
std::setw(INDENT) << int64_to_str(total_data.mem); std::cout << '\n'; } } @@ -63,13 +123,13 @@ void InferShapeImpl::infer_shapes_Conv(onnx::NodeProto &node) { std::vector> input_shapes; for (int i = 0; i < node.input_size(); ++i) { auto ndinput = node.input(i); - if (this->get_ndname_to_shape().find(ndinput) == this->get_ndname_to_shape().end()) { + if (this->ndname_to_shape.find(ndinput) == this->ndname_to_shape.end()) { std::cerr << "Error: " << ndinput << " not found in ndname_to_shape\n"; exit(1); } else { // get shape from ndname_to_shape - std::vector shape = this->get_ndname_to_shape()[ndinput]; + std::vector shape = this->ndname_to_shape[ndinput]; if (i == 0) attr_info.set_default_attr(shape.size()); input_shapes.emplace_back(shape); } @@ -118,19 +178,20 @@ void InferShapeImpl::infer_shapes_Conv(onnx::NodeProto &node) { set_vec_to_shape(val_info, output_shape); this->ndname_to_shape[node.output(0)] = output_shape; + this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)]; } void InferShapeImpl::infer_shapes_Relu(onnx::NodeProto &node) { // get node input shapes std::vector> input_shapes; for (auto ndinput : node.input()) { - if (this->get_ndname_to_shape().find(ndinput) == this->get_ndname_to_shape().end()) { + if (this->ndname_to_shape.find(ndinput) == this->ndname_to_shape.end()) { std::cerr << "Error: " << ndinput << " not found in ndname_to_shape\n"; return; } else { // get shape from ndname_to_shape - input_shapes.emplace_back(this->get_ndname_to_shape()[ndinput]); + input_shapes.emplace_back(this->ndname_to_shape[ndinput]); } } @@ -140,12 +201,13 @@ void InferShapeImpl::infer_shapes_Relu(onnx::NodeProto &node) { set_vec_to_shape(val_info, input_shapes[0]); this->ndname_to_shape[node.output(0)] = input_shapes[0]; + this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)]; } void InferShapeImpl::infer_shapes_MaxPool(onnx::NodeProto &node) { struct AttrInfo_MaxPool attr_info; - auto input_shape 
= this->get_ndname_to_shape()[node.input(0)]; + auto input_shape = this->ndname_to_shape[node.input(0)]; attr_info.set_default_attr(input_shape.size()); for (auto attr : node.attribute()) { @@ -192,19 +254,20 @@ void InferShapeImpl::infer_shapes_MaxPool(onnx::NodeProto &node) { set_vec_to_shape(val_info, output_shape); this->ndname_to_shape[node.output(0)] = output_shape; + this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)]; } void InferShapeImpl::infer_shapes_Add(onnx::NodeProto &node) { // get node input shapes std::vector> input_shapes; for (auto ndinput : node.input()) { - if (this->get_ndname_to_shape().find(ndinput) == this->get_ndname_to_shape().end()) { + if (this->ndname_to_shape.find(ndinput) == this->ndname_to_shape.end()) { std::cerr << "Error: " << ndinput << " not found in ndname_to_shape\n"; return; } else { // get shape from ndname_to_shape - input_shapes.emplace_back(this->get_ndname_to_shape()[ndinput]); + input_shapes.emplace_back(this->ndname_to_shape[ndinput]); } } @@ -214,19 +277,20 @@ void InferShapeImpl::infer_shapes_Add(onnx::NodeProto &node) { set_vec_to_shape(val_info, input_shapes[0]); this->ndname_to_shape[node.output(0)] = input_shapes[0]; + this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)]; } void InferShapeImpl::infer_shapes_GlobalAveragePool(onnx::NodeProto &node) { // get node input shapes std::vector> input_shapes; for (auto ndinput : node.input()) { - if (this->get_ndname_to_shape().find(ndinput) == this->get_ndname_to_shape().end()) { + if (this->ndname_to_shape.find(ndinput) == this->ndname_to_shape.end()) { std::cerr << "Error: " << ndinput << " not found in ndname_to_shape\n"; return; } else { // get shape from ndname_to_shape - input_shapes.emplace_back(this->get_ndname_to_shape()[ndinput]); + input_shapes.emplace_back(this->ndname_to_shape[ndinput]); } } @@ -245,6 +309,7 @@ void InferShapeImpl::infer_shapes_GlobalAveragePool(onnx::NodeProto &node) { 
set_vec_to_shape(val_info, output_shape); this->ndname_to_shape[node.output(0)] = output_shape; + this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)]; } void InferShapeImpl::infer_shapes_Flatten(onnx::NodeProto &node) { @@ -253,13 +318,13 @@ void InferShapeImpl::infer_shapes_Flatten(onnx::NodeProto &node) { // get node input shapes std::vector> input_shapes; for (auto ndinput : node.input()) { - if (this->get_ndname_to_shape().find(ndinput) == this->get_ndname_to_shape().end()) { + if (this->ndname_to_shape.find(ndinput) == this->ndname_to_shape.end()) { std::cerr << "Error: " << ndinput << " not found in ndname_to_shape\n"; return; } else { // get shape from ndname_to_shape - input_shapes.emplace_back(this->get_ndname_to_shape()[ndinput]); + input_shapes.emplace_back(this->ndname_to_shape[ndinput]); } } @@ -290,6 +355,7 @@ void InferShapeImpl::infer_shapes_Flatten(onnx::NodeProto &node) { set_vec_to_shape(val_info, output_shape); this->ndname_to_shape[node.output(0)] = output_shape; + this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)]; } void InferShapeImpl::infer_shapes_Gemm(onnx::NodeProto &node) { @@ -308,13 +374,13 @@ void InferShapeImpl::infer_shapes_Gemm(onnx::NodeProto &node) { std::vector> input_shapes; for (size_t num = 0; num < 2; ++num) { auto ndinput = node.input(num); - if (this->get_ndname_to_shape().find(ndinput) == this->get_ndname_to_shape().end()) { + if (this->ndname_to_shape.find(ndinput) == this->ndname_to_shape.end()) { std::cerr << "Error: " << ndinput << " not found in ndname_to_shape\n"; return; } else { // get shape from ndname_to_shape - input_shapes.emplace_back(this->get_ndname_to_shape()[ndinput]); + input_shapes.emplace_back(this->ndname_to_shape[ndinput]); if (num == 0 && attr_info.transA) { std::reverse(input_shapes[num].begin(), input_shapes[num].end()); } @@ -338,14 +404,14 @@ void InferShapeImpl::infer_shapes_Gemm(onnx::NodeProto &node) { 
set_vec_to_shape(val_info, output_shape); this->ndname_to_shape[node.output(0)] = output_shape; + this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)]; } -void InferShapeImpl::infer_shapes() { - this->set_io_iniz_shape_to_map(); +void InferShapeImpl::infer_shapes(bool analyze) { + this->set_io_iniz_shape_to_map(analyze); // infer shape for each node and store to value_info for (auto node : graph.node()) { - // std::cout << "Node: " << node.name() << '\n'; if (node.op_type() == "Conv") { this->infer_shapes_Conv(node); } @@ -371,5 +437,15 @@ void InferShapeImpl::infer_shapes() { std::cerr << "Error: " << node.op_type() << " not supported now\n"; exit(1); } + + if (analyze) { + // analyze node and store to ndname_to_anal_data + this->ndname_to_anal_data[node.name()] = + analyze_node(node, this->ndname_to_shape, this->ndname_to_dtype_size); + } } } + +void InferShapeImpl::infer_shapes() { + this->infer_shapes(true); +} diff --git a/src/ModelAnalysis.cpp b/src/ModelAnalysis.cpp new file mode 100644 index 0000000..1f0ddd9 --- /dev/null +++ b/src/ModelAnalysis.cpp @@ -0,0 +1,228 @@ +#include "ModelAnalysis.hpp" +#include "utils.hpp" + +int64_t get_prod(std::vector &vec) { + int64_t prod = 1; + for (int64_t v : vec) { + prod *= v; + } + return prod; +} + +AnalyzeData analyze_node_Conv(onnx::NodeProto &node, std::vector> &input_shapes, std::vector &output_shape, std::unordered_map &ndname_to_size) { + AnalyzeData data; + + // MACs + std::vector kernel_shape = input_shapes[1]; + std::vector reduce_shape(kernel_shape.begin() + 1, kernel_shape.end()); + int64_t out_prod = get_prod(output_shape); + data.mac = out_prod * get_prod(reduce_shape) * MUL_MACS; + + if (node.input_size() == 3) { // with bias + data.mac += out_prod * ADD_MACS; + } + + // Parameters & Memory + for (size_t i = 1; i < input_shapes.size(); ++i) { + data.param += get_prod(input_shapes[i]); + data.mem += get_prod(input_shapes[i]) * ndname_to_size[node.input(i)]; + } + 
data.mem += get_prod(input_shapes[0]) * ndname_to_size[node.input(0)]; + data.mem += out_prod * ndname_to_size[node.output(0)]; + + return data; +} + +AnalyzeData analyze_node_Relu(onnx::NodeProto &node, std::vector> &input_shapes, std::vector &output_shape, std::unordered_map &ndname_to_size) { + AnalyzeData data; + + data.param = 0; // no trainable parameters + data.mac = 0; // non MACs operation + + // Memory + for (size_t i = 0; i < input_shapes.size(); ++i) { + data.mem += get_prod(input_shapes[i]) * ndname_to_size[node.input(i)]; + } + data.mem += get_prod(output_shape) * ndname_to_size[node.output(0)]; + + return data; +} + +AnalyzeData analyze_node_MaxPool(onnx::NodeProto &node, std::vector> &input_shapes, std::vector &output_shape, std::unordered_map &ndname_to_size) { + AnalyzeData data; + + // MACs + data.mac = 0; // non MACs operation + + // Parameters & Memory + data.param = 0; // no trainable parameters + for (size_t i = 0; i < input_shapes.size(); ++i) { + data.mem += get_prod(input_shapes[i]) * ndname_to_size[node.input(i)]; + } + data.mem += get_prod(output_shape) * ndname_to_size[node.output(0)]; + + return data; +} + +AnalyzeData analyze_node_Add(onnx::NodeProto &node, std::vector> &input_shapes, std::vector &output_shape, std::unordered_map &ndname_to_size) { + AnalyzeData data; + + // MACs + for (size_t i = 0; i < input_shapes.size(); ++i) { + data.mac += get_prod(input_shapes[i]) * ADD_MACS; + } + + // Parameters & Memory + data.param = 0; // no trainable parameters + for (size_t i = 0; i < input_shapes.size(); ++i) { + data.mem += get_prod(input_shapes[i]) * ndname_to_size[node.input(i)]; + } + data.mem += get_prod(output_shape) * ndname_to_size[node.output(0)]; + + return data; +} + +AnalyzeData analyze_node_GlobalAveragePool(onnx::NodeProto &node, std::vector> &input_shapes, std::vector &output_shape, std::unordered_map &ndname_to_size) { + AnalyzeData data; + + // MACs + for (size_t i = 0; i < input_shapes.size(); ++i) { + data.mac += 
get_prod(input_shapes[i]) * ADD_MACS; + } + data.mac += get_prod(output_shape) * DIV_MACS; + + // Parameters & Memory + data.param = 0; // no trainable parameters + for (size_t i = 0; i < input_shapes.size(); ++i) { + data.mem += get_prod(input_shapes[i]) * ndname_to_size[node.input(i)]; + } + data.mem += get_prod(output_shape) * ndname_to_size[node.output(0)]; + + return data; +} + +AnalyzeData analyze_node_Flatten(onnx::NodeProto &node, std::vector> &input_shapes, std::vector &output_shape, std::unordered_map &ndname_to_size) { + AnalyzeData data; + + // MACs + data.mac = 0; // non MACs operation + + // Parameters & Memory + data.param = 0; // no trainable parameters + for (size_t i = 0; i < input_shapes.size(); ++i) { + data.mem += get_prod(input_shapes[i]) * ndname_to_size[node.input(i)]; + } + data.mem += get_prod(output_shape) * ndname_to_size[node.output(0)]; + + return data; +} + +AnalyzeData analyze_node_Gemm(onnx::NodeProto &node, std::vector> &input_shapes, std::vector &output_shape, std::unordered_map &ndname_to_size) { + AnalyzeData data; + + // MACs + data.mac = get_prod(input_shapes[0]) * get_prod(output_shape) * MUL_MACS; + + // Parameters & Memory + data.param = get_prod(input_shapes[0]) * get_prod(output_shape) + get_prod(output_shape); + for (size_t i = 0; i < input_shapes.size(); ++i) { + data.mem += get_prod(input_shapes[i]) * ndname_to_size[node.input(i)]; + } + data.mem += get_prod(output_shape) * ndname_to_size[node.output(0)]; + + return data; +} + +AnalyzeData analyze_node(onnx::NodeProto &node, const std::unordered_map> &ndname_to_shape, const std::unordered_map &ndname_to_dtype_size) { + AnalyzeData data; + if (node.op_type() == "Conv") { + std::unordered_map ndname_to_size; + std::vector> input_shapes; + for (auto input : node.input()) { + ndname_to_size[input] = ndname_to_dtype_size.at(input); + input_shapes.emplace_back(ndname_to_shape.at(input)); + } + std::vector output_shape = ndname_to_shape.at(node.output(0)); + 
ndname_to_size[node.output(0)] = ndname_to_dtype_size.at(node.output(0)); + + data = analyze_node_Conv(node, input_shapes, output_shape, ndname_to_size); + } + else if (node.op_type() == "Relu") { + std::unordered_map ndname_to_size; + std::vector> input_shapes; + for (auto input : node.input()) { + ndname_to_size[input] = ndname_to_dtype_size.at(input); + input_shapes.emplace_back(ndname_to_shape.at(input)); + } + std::vector output_shape = ndname_to_shape.at(node.output(0)); + ndname_to_size[node.output(0)] = ndname_to_dtype_size.at(node.output(0)); + + data = analyze_node_Relu(node, input_shapes, output_shape, ndname_to_size); + } + else if (node.op_type() == "MaxPool") { + std::unordered_map ndname_to_size; + std::vector> input_shapes; + for (auto input : node.input()) { + ndname_to_size[input] = ndname_to_dtype_size.at(input); + input_shapes.emplace_back(ndname_to_shape.at(input)); + } + std::vector output_shape = ndname_to_shape.at(node.output(0)); + ndname_to_size[node.output(0)] = ndname_to_dtype_size.at(node.output(0)); + + data = analyze_node_MaxPool(node, input_shapes, output_shape, ndname_to_size); + } + else if (node.op_type() == "Add") { + std::unordered_map ndname_to_size; + std::vector> input_shapes; + for (auto input : node.input()) { + ndname_to_size[input] = ndname_to_dtype_size.at(input); + input_shapes.emplace_back(ndname_to_shape.at(input)); + } + std::vector output_shape = ndname_to_shape.at(node.output(0)); + ndname_to_size[node.output(0)] = ndname_to_dtype_size.at(node.output(0)); + + data = analyze_node_Add(node, input_shapes, output_shape, ndname_to_size); + } + else if (node.op_type() == "GlobalAveragePool") { + std::unordered_map ndname_to_size; + std::vector> input_shapes; + for (auto input : node.input()) { + ndname_to_size[input] = ndname_to_dtype_size.at(input); + input_shapes.emplace_back(ndname_to_shape.at(input)); + } + std::vector output_shape = ndname_to_shape.at(node.output(0)); + ndname_to_size[node.output(0)] = 
ndname_to_dtype_size.at(node.output(0)); + + data = analyze_node_GlobalAveragePool(node, input_shapes, output_shape, ndname_to_size); + } + else if (node.op_type() == "Flatten") { + std::unordered_map ndname_to_size; + std::vector> input_shapes; + for (auto input : node.input()) { + ndname_to_size[input] = ndname_to_dtype_size.at(input); + input_shapes.emplace_back(ndname_to_shape.at(input)); + } + std::vector output_shape = ndname_to_shape.at(node.output(0)); + ndname_to_size[node.output(0)] = ndname_to_dtype_size.at(node.output(0)); + + data = analyze_node_Flatten(node, input_shapes, output_shape, ndname_to_size); + } + else if (node.op_type() == "Gemm") { + std::unordered_map ndname_to_size; + std::vector> input_shapes; + for (auto input : node.input()) { + ndname_to_size[input] = ndname_to_dtype_size.at(input); + input_shapes.emplace_back(ndname_to_shape.at(input)); + } + std::vector output_shape = ndname_to_shape.at(node.output(0)); + ndname_to_size[node.output(0)] = ndname_to_dtype_size.at(node.output(0)); + + data = analyze_node_Gemm(node, input_shapes, output_shape, ndname_to_size); + } + else { + std::cerr << "Error: " << node.op_type() << " not supported now\n"; + exit(1); + } + + return data; +} diff --git a/src/main.cpp b/src/main.cpp index 7e02b6d..a78e0be 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -14,7 +14,7 @@ PYBIND11_MODULE(_onnxinfo, m) { py::class_(m, "InferShapeImpl") .def(py::init()) .def("set_io_iniz_shape_to_map", &InferShapeImpl::set_io_iniz_shape_to_map) - .def("infer_shapes", &InferShapeImpl::infer_shapes) + .def("infer_shapes", py::overload_cast<>(&InferShapeImpl::infer_shapes)) .def("print_summary", &InferShapeImpl::print_summary) .def("get_ndname_to_shape", &InferShapeImpl::get_ndname_to_shape); diff --git a/src/utils.cpp b/src/utils.cpp index 7fb2468..3678ca2 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -34,26 +34,6 @@ void print_dim(const ::onnx::TensorShapeProto_Dimension &dim) { } } -void print_val_info(const 
::onnx::ValueInfoProto &info) { - auto shape = info.type().tensor_type().shape(); - std::cout << info.name() << "\t"; - // print input shape - // NOT YET - std::cout << "\t"; - - // print output shape - std::cout << "["; - if (shape.dim_size() != 0) { - int size = shape.dim_size(); - for (int i = 0; i < size - 1; ++i) { - print_dim(shape.dim(i)); - std::cout << ", "; - } - print_dim(shape.dim(size - 1)); - } - std::cout << "]\n"; -} - void print_dims_vec(const std::vector &dims) { std::cout << "["; for (size_t i = 0; i < dims.size() - 1; ++i) { @@ -81,11 +61,20 @@ void set_vec_to_shape(onnx::ValueInfoProto *val_info, const std::vector } std::string string_trimmer(const std::string &inputString, const size_t maxLen) { - std::string trimmedString = inputString; + std::string trimmedString = inputString; - if (trimmedString.length() > maxLen) { - trimmedString = trimmedString.substr(0, maxLen - 3) + "..."; - } + if (trimmedString.length() > maxLen) { + trimmedString = trimmedString.substr(0, maxLen - 3) + "..."; + } - return trimmedString; + return trimmedString; +} + +std::string int64_to_str(int64_t num) { + std::string str = std::to_string(num); + // add comma + for (int i = str.length() - 3; i > 0; i -= 3) { + str.insert(i, ","); + } + return str; }