Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,23 @@ A tool to show ONNX model summary like torchinfo
3. `cmake -S . -B build/`
4. `cmake --build build/ [--parallel <thread number>]` to build dependency and onnxinfo

## Usage
### Shape Inference
Supported node types so far: Conv, Relu, MaxPool, Add, GlobalAveragePool, Flatten, Gemm

### Static Analysis
Static analysis runs by default during shape inference.
Use `infer_shapes(analyze = false)` to run shape inference only.

#### MACs
Non-MAC operations such as `Relu` and `MaxPool` are not counted.

#### Parameters
Calculates the number of trainable parameters for each node.

#### Memory
Calculates the memory usage of each node's inputs and outputs, in bytes.

## Test
`python3 -m pytest -v`

Expand Down
10 changes: 7 additions & 3 deletions include/InferShape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <unordered_map>
#include "onnx.proto3.pb.h"
#include "AttrInfo.hpp"
#include "ModelAnalysis.hpp"

#define INDENT 30

Expand All @@ -13,19 +14,22 @@ class InferShapeImpl {

~InferShapeImpl() = default;

void set_io_iniz_shape_to_map();
void set_io_iniz_shape_to_map(bool analyze);
void infer_shapes(bool analyze = true);
void infer_shapes();
void print_summary();

std::unordered_map<std::string, std::vector<int64_t>> get_ndname_to_shape() {
const std::unordered_map<std::string, std::vector<int64_t>> get_ndname_to_shape() {
return this->ndname_to_shape;
}

private:
onnx::GraphProto graph;
std::unordered_map<std::string, std::vector<int64_t>> ndname_to_shape;
std::unordered_map<std::string, struct AnalyzeData> ndname_to_anal_data;
std::unordered_map<std::string, size_t> ndname_to_dtype_size;

// TODO: check dimensions & attributes needed or default
// TODO: more op types
void infer_shapes_Conv(onnx::NodeProto &node);
void infer_shapes_Relu(onnx::NodeProto &node);
void infer_shapes_MaxPool(onnx::NodeProto &node);
Expand Down
26 changes: 26 additions & 0 deletions include/ModelAnalysis.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#pragma once

// Static-analysis helpers: per-node MACs, trainable-parameter, and memory
// accounting for the ONNX graph walked by InferShapeImpl.
//
// NOTE(review): this header previously included "InferShape.hpp", which in
// turn includes this file — a circular dependency only masked by
// `#pragma once` (depending on inclusion order, InferShape.hpp could see
// `AnalyzeData` as an incomplete type). Nothing declared here needs
// InferShapeImpl, so that include is dropped and the standard headers this
// file actually uses are included directly. Any .cpp that relied on the
// transitive include must include "InferShape.hpp" itself.

#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

#include "onnx.proto3.pb.h"

// Relative MAC cost of elementary arithmetic (division modelled as the most
// expensive). Typed constexpr constants instead of macros: scoped, typed,
// and visible to the debugger; usable everywhere the old macros were.
constexpr int64_t MUL_MACS = 1;
constexpr int64_t ADD_MACS = 1;
constexpr int64_t DIV_MACS = 4;

// Analysis results for a single node, accumulated during shape inference.
struct AnalyzeData {
    int64_t mac = 0;    // multiply-accumulate operation count
    int64_t param = 0;  // trainable parameter count
    int64_t mem = 0;    // input + output memory footprint, in bytes
};

// Product of all elements of `vec` (e.g. the element count of a shape).
int64_t get_prod(std::vector<int64_t> &vec);

// One analyzer per supported op type. Each receives the node, its already
// inferred input shapes and output shape, and a map from tensor name to
// dtype size in bytes, and returns the node's AnalyzeData.
// (Signatures intentionally unchanged to keep matching the out-of-line
// definitions; a const-correctness pass would need to touch those too.)
AnalyzeData analyze_node_Conv(onnx::NodeProto &node, std::vector<std::vector<int64_t>> &input_shapes, std::vector<int64_t> &output_shape, std::unordered_map<std::string, size_t> &ndname_to_size);
AnalyzeData analyze_node_Relu(onnx::NodeProto &node, std::vector<std::vector<int64_t>> &input_shapes, std::vector<int64_t> &output_shape, std::unordered_map<std::string, size_t> &ndname_to_size);
AnalyzeData analyze_node_MaxPool(onnx::NodeProto &node, std::vector<std::vector<int64_t>> &input_shapes, std::vector<int64_t> &output_shape, std::unordered_map<std::string, size_t> &ndname_to_size);
AnalyzeData analyze_node_Add(onnx::NodeProto &node, std::vector<std::vector<int64_t>> &input_shapes, std::vector<int64_t> &output_shape, std::unordered_map<std::string, size_t> &ndname_to_size);
AnalyzeData analyze_node_GlobalAveragePool(onnx::NodeProto &node, std::vector<std::vector<int64_t>> &input_shapes, std::vector<int64_t> &output_shape, std::unordered_map<std::string, size_t> &ndname_to_size);
AnalyzeData analyze_node_Flatten(onnx::NodeProto &node, std::vector<std::vector<int64_t>> &input_shapes, std::vector<int64_t> &output_shape, std::unordered_map<std::string, size_t> &ndname_to_size);
AnalyzeData analyze_node_Gemm(onnx::NodeProto &node, std::vector<std::vector<int64_t>> &input_shapes, std::vector<int64_t> &output_shape, std::unordered_map<std::string, size_t> &ndname_to_size);

// Dispatches on node.op_type() to the matching analyze_node_* above.
AnalyzeData analyze_node(onnx::NodeProto &node, const std::unordered_map<std::string, std::vector<int64_t>> &ndname_to_shape, const std::unordered_map<std::string, size_t> &ndname_to_dtype_size);
2 changes: 1 addition & 1 deletion include/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

onnx::ModelProto read_onnx(const std::string &filename);
void print_dim(const ::onnx::TensorShapeProto_Dimension &dim);
void print_val_info(const ::onnx::ValueInfoProto &info);
void print_dims_vec(const std::vector<int64_t> &dims);
std::string dims_vec_to_str(const std::vector<int64_t> &dims);
void set_vec_to_shape(onnx::ValueInfoProto *val_info, const std::vector<int64_t> &dims);
std::string string_trimmer(const std::string &inputString, const size_t maxLen);
std::string int64_to_str(int64_t num);
132 changes: 104 additions & 28 deletions src/InferShape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
#include <iomanip>
#include "InferShape.hpp"
#include "utils.hpp"
#include "AttrInfo.hpp"

void InferShapeImpl::set_io_iniz_shape_to_map() {
void InferShapeImpl::set_io_iniz_shape_to_map(bool analyze) {
for (auto input : this->graph.input()) {
auto shape = input.type().tensor_type().shape();
std::vector<int64_t> shape_vec = {};
Expand All @@ -15,6 +14,13 @@ void InferShapeImpl::set_io_iniz_shape_to_map() {
}
}
this->ndname_to_shape[input.name()] = shape_vec;

// get dtype size
if (analyze) {
if (input.type().tensor_type().elem_type() == 1) { // float32
this->ndname_to_dtype_size[input.name()] = 4;
}
}
}

for (auto initializer : this->graph.initializer()) {
Expand All @@ -24,6 +30,13 @@ void InferShapeImpl::set_io_iniz_shape_to_map() {
shape_vec.emplace_back(shape.Get(i));
}
this->ndname_to_shape[initializer.name()] = shape_vec;

// get dtype size
if (analyze) {
if (initializer.data_type() == 1) {
this->ndname_to_dtype_size[initializer.name()] = 4;
}
}
}

for (auto output : this->graph.output()) {
Expand All @@ -36,22 +49,69 @@ void InferShapeImpl::set_io_iniz_shape_to_map() {
}
}
this->ndname_to_shape[output.name()] = shape_vec;

// get dtype size
if (analyze) {
if (output.type().tensor_type().elem_type() == 1) {
this->ndname_to_dtype_size[output.name()] = 4;
}
}
}
}

void InferShapeImpl::print_summary() {
std::cout << std::left << std::setw(INDENT) << "Name"
<< std::left << std::setw(INDENT) << "Type"
<< std::left << std::setw(INDENT) << "Input Shape"
<< std::left << std::setw(INDENT) << "Output Shape" << '\n';
std::cout << std::string(INDENT * 4, '-') << '\n';
std::ios::sync_with_stdio(false);
std::cin.tie(0);

for (auto node : graph.node()) {
std::cout << std::left << std::setw(INDENT) << string_trimmer(node.name(), INDENT-5);
if (this->ndname_to_anal_data.empty()) {
std::cout << std::left << std::setw(INDENT) << "Name"
<< std::left << std::setw(INDENT) << "Type"
<< std::left << std::setw(INDENT) << "Input Shape"
<< std::left << std::setw(INDENT) << "Output Shape" << '\n';
std::cout << std::string(INDENT * 4, '-') << '\n';

std::cout << std::left << std::setw(INDENT) << node.op_type();
std::cout << std::setw(INDENT) << dims_vec_to_str(this->get_ndname_to_shape()[node.input(0)]);
std::cout << std::setw(INDENT) << dims_vec_to_str(this->get_ndname_to_shape()[node.output(0)]);
for (auto node : graph.node()) {
std::cout << std::left << std::setw(INDENT) << string_trimmer(node.name(), INDENT-5);

std::cout << std::left << std::setw(INDENT) << node.op_type();
std::cout << std::setw(INDENT) << dims_vec_to_str(this->ndname_to_shape[node.input(0)]);
std::cout << std::setw(INDENT) << dims_vec_to_str(this->ndname_to_shape[node.output(0)]);
std::cout << '\n';
}
}
else {
std::cout << std::left << std::setw(INDENT) << "Name"
<< std::left << std::setw(INDENT) << "Type"
<< std::left << std::setw(INDENT) << "Input Shape"
<< std::left << std::setw(INDENT) << "Output Shape"
<< std::left << std::setw(INDENT) << "MACs"
<< std::left << std::setw(INDENT) << "Params"
<< std::left << std::setw(INDENT) << "Memory" << '\n';
std::cout << std::string(INDENT * 7, '-') << '\n';

AnalyzeData total_data;
for (auto node : graph.node()) {
total_data.mac += this->ndname_to_anal_data[node.name()].mac;
total_data.param += this->ndname_to_anal_data[node.name()].param;
total_data.mem += this->ndname_to_anal_data[node.name()].mem;

std::cout << std::left << std::setw(INDENT) << string_trimmer(node.name(), INDENT-5);

std::cout << std::left << std::setw(INDENT) << node.op_type();
std::cout << std::setw(INDENT) << dims_vec_to_str(this->ndname_to_shape[node.input(0)]);
std::cout << std::setw(INDENT) << dims_vec_to_str(this->ndname_to_shape[node.output(0)]);
std::cout << std::setw(INDENT) << int64_to_str(this->ndname_to_anal_data[node.name()].mac);
std::cout << std::setw(INDENT) << int64_to_str(this->ndname_to_anal_data[node.name()].param);
std::cout << std::setw(INDENT) << int64_to_str(this->ndname_to_anal_data[node.name()].mem);
std::cout << '\n';
}
std::cout << std::left << std::setw(INDENT) << "Total";
std::cout << std::left << std::setw(INDENT) << "-";
std::cout << std::setw(INDENT) << dims_vec_to_str(this->ndname_to_shape[graph.input(0).name()]);
std::cout << std::setw(INDENT) << dims_vec_to_str(this->ndname_to_shape[graph.output(0).name()]);
std::cout << std::setw(INDENT) << int64_to_str(total_data.mac);
std::cout << std::setw(INDENT) << int64_to_str(total_data.param);
std::cout << std::setw(INDENT) << int64_to_str(total_data.mem);
std::cout << '\n';
}
}
Expand All @@ -63,13 +123,13 @@ void InferShapeImpl::infer_shapes_Conv(onnx::NodeProto &node) {
std::vector<std::vector<int64_t>> input_shapes;
for (int i = 0; i < node.input_size(); ++i) {
auto ndinput = node.input(i);
if (this->get_ndname_to_shape().find(ndinput) == this->get_ndname_to_shape().end()) {
if (this->ndname_to_shape.find(ndinput) == this->ndname_to_shape.end()) {
std::cerr << "Error: " << ndinput << " not found in ndname_to_shape\n";
exit(1);
}
else {
// get shape from ndname_to_shape
std::vector<int64_t> shape = this->get_ndname_to_shape()[ndinput];
std::vector<int64_t> shape = this->ndname_to_shape[ndinput];
if (i == 0) attr_info.set_default_attr(shape.size());
input_shapes.emplace_back(shape);
}
Expand Down Expand Up @@ -118,19 +178,20 @@ void InferShapeImpl::infer_shapes_Conv(onnx::NodeProto &node) {
set_vec_to_shape(val_info, output_shape);

this->ndname_to_shape[node.output(0)] = output_shape;
this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)];
}

void InferShapeImpl::infer_shapes_Relu(onnx::NodeProto &node) {
// get node input shapes
std::vector<std::vector<int64_t>> input_shapes;
for (auto ndinput : node.input()) {
if (this->get_ndname_to_shape().find(ndinput) == this->get_ndname_to_shape().end()) {
if (this->ndname_to_shape.find(ndinput) == this->ndname_to_shape.end()) {
std::cerr << "Error: " << ndinput << " not found in ndname_to_shape\n";
return;
}
else {
// get shape from ndname_to_shape
input_shapes.emplace_back(this->get_ndname_to_shape()[ndinput]);
input_shapes.emplace_back(this->ndname_to_shape[ndinput]);
}
}

Expand All @@ -140,12 +201,13 @@ void InferShapeImpl::infer_shapes_Relu(onnx::NodeProto &node) {
set_vec_to_shape(val_info, input_shapes[0]);

this->ndname_to_shape[node.output(0)] = input_shapes[0];
this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)];
}

void InferShapeImpl::infer_shapes_MaxPool(onnx::NodeProto &node) {
struct AttrInfo_MaxPool attr_info;

auto input_shape = this->get_ndname_to_shape()[node.input(0)];
auto input_shape = this->ndname_to_shape[node.input(0)];

attr_info.set_default_attr(input_shape.size());
for (auto attr : node.attribute()) {
Expand Down Expand Up @@ -192,19 +254,20 @@ void InferShapeImpl::infer_shapes_MaxPool(onnx::NodeProto &node) {
set_vec_to_shape(val_info, output_shape);

this->ndname_to_shape[node.output(0)] = output_shape;
this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)];
}

void InferShapeImpl::infer_shapes_Add(onnx::NodeProto &node) {
// get node input shapes
std::vector<std::vector<int64_t>> input_shapes;
for (auto ndinput : node.input()) {
if (this->get_ndname_to_shape().find(ndinput) == this->get_ndname_to_shape().end()) {
if (this->ndname_to_shape.find(ndinput) == this->ndname_to_shape.end()) {
std::cerr << "Error: " << ndinput << " not found in ndname_to_shape\n";
return;
}
else {
// get shape from ndname_to_shape
input_shapes.emplace_back(this->get_ndname_to_shape()[ndinput]);
input_shapes.emplace_back(this->ndname_to_shape[ndinput]);
}
}

Expand All @@ -214,19 +277,20 @@ void InferShapeImpl::infer_shapes_Add(onnx::NodeProto &node) {
set_vec_to_shape(val_info, input_shapes[0]);

this->ndname_to_shape[node.output(0)] = input_shapes[0];
this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)];
}

void InferShapeImpl::infer_shapes_GlobalAveragePool(onnx::NodeProto &node) {
// get node input shapes
std::vector<std::vector<int64_t>> input_shapes;
for (auto ndinput : node.input()) {
if (this->get_ndname_to_shape().find(ndinput) == this->get_ndname_to_shape().end()) {
if (this->ndname_to_shape.find(ndinput) == this->ndname_to_shape.end()) {
std::cerr << "Error: " << ndinput << " not found in ndname_to_shape\n";
return;
}
else {
// get shape from ndname_to_shape
input_shapes.emplace_back(this->get_ndname_to_shape()[ndinput]);
input_shapes.emplace_back(this->ndname_to_shape[ndinput]);
}
}

Expand All @@ -245,6 +309,7 @@ void InferShapeImpl::infer_shapes_GlobalAveragePool(onnx::NodeProto &node) {
set_vec_to_shape(val_info, output_shape);

this->ndname_to_shape[node.output(0)] = output_shape;
this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)];
}

void InferShapeImpl::infer_shapes_Flatten(onnx::NodeProto &node) {
Expand All @@ -253,13 +318,13 @@ void InferShapeImpl::infer_shapes_Flatten(onnx::NodeProto &node) {
// get node input shapes
std::vector<std::vector<int64_t>> input_shapes;
for (auto ndinput : node.input()) {
if (this->get_ndname_to_shape().find(ndinput) == this->get_ndname_to_shape().end()) {
if (this->ndname_to_shape.find(ndinput) == this->ndname_to_shape.end()) {
std::cerr << "Error: " << ndinput << " not found in ndname_to_shape\n";
return;
}
else {
// get shape from ndname_to_shape
input_shapes.emplace_back(this->get_ndname_to_shape()[ndinput]);
input_shapes.emplace_back(this->ndname_to_shape[ndinput]);
}
}

Expand Down Expand Up @@ -290,6 +355,7 @@ void InferShapeImpl::infer_shapes_Flatten(onnx::NodeProto &node) {
set_vec_to_shape(val_info, output_shape);

this->ndname_to_shape[node.output(0)] = output_shape;
this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)];
}

void InferShapeImpl::infer_shapes_Gemm(onnx::NodeProto &node) {
Expand All @@ -308,13 +374,13 @@ void InferShapeImpl::infer_shapes_Gemm(onnx::NodeProto &node) {
std::vector<std::vector<int64_t>> input_shapes;
for (size_t num = 0; num < 2; ++num) {
auto ndinput = node.input(num);
if (this->get_ndname_to_shape().find(ndinput) == this->get_ndname_to_shape().end()) {
if (this->ndname_to_shape.find(ndinput) == this->ndname_to_shape.end()) {
std::cerr << "Error: " << ndinput << " not found in ndname_to_shape\n";
return;
}
else {
// get shape from ndname_to_shape
input_shapes.emplace_back(this->get_ndname_to_shape()[ndinput]);
input_shapes.emplace_back(this->ndname_to_shape[ndinput]);
if (num == 0 && attr_info.transA) {
std::reverse(input_shapes[num].begin(), input_shapes[num].end());
}
Expand All @@ -338,14 +404,14 @@ void InferShapeImpl::infer_shapes_Gemm(onnx::NodeProto &node) {
set_vec_to_shape(val_info, output_shape);

this->ndname_to_shape[node.output(0)] = output_shape;
this->ndname_to_dtype_size[node.output(0)] = this->ndname_to_dtype_size[node.input(0)];
}

void InferShapeImpl::infer_shapes() {
this->set_io_iniz_shape_to_map();
void InferShapeImpl::infer_shapes(bool analyze) {
this->set_io_iniz_shape_to_map(analyze);

// infer shape for each node and store to value_info
for (auto node : graph.node()) {
// std::cout << "Node: " << node.name() << '\n';
if (node.op_type() == "Conv") {
this->infer_shapes_Conv(node);
}
Expand All @@ -371,5 +437,15 @@ void InferShapeImpl::infer_shapes() {
std::cerr << "Error: " << node.op_type() << " not supported now\n";
exit(1);
}

if (analyze) {
// analyze node and store to ndname_to_anal_data
this->ndname_to_anal_data[node.name()] =
analyze_node(node, this->ndname_to_shape, this->ndname_to_dtype_size);
}
}
}

void InferShapeImpl::infer_shapes() {
this->infer_shapes(true);
}
Loading
Loading