Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ A tool to show ONNX model summary like torchinfo
## Test
`python3 -m pytest -v`

Use model from [ONNX Model Zoo](https://github.com/onnx/models/tree/main) to test.
Use model(resnet18_Opset16.onnx) from [ONNX Model Zoo](https://github.com/onnx/models/tree/main) to test.

## Docker
* Run `docker build -t onnxinfo -f docker/Dockerfile .` first.
Expand Down
46 changes: 46 additions & 0 deletions include/AttrInfo.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#pragma once

#include <vector>
#include "onnx.proto3.pb.h"

struct AttrInfo_Conv {
std::vector<int64_t> kernel_shape; // got from Weight
std::vector<int64_t> strides; // default 1
std::vector<int64_t> pads; // default 0
std::vector<int64_t> dilations; // default 1

int64_t group = 1; // not used now

void set_default_attr(size_t shape_len) {
for (size_t i = 0; i < shape_len - 2; ++i) {
strides.emplace_back(1);
pads.emplace_back(0);
pads.emplace_back(0);
dilations.emplace_back(1);
}
}
};

struct AttrInfo_MaxPool {
std::vector<int64_t> kernel_shape;
std::vector<int64_t> strides; // default 1
std::vector<int64_t> pads; // default 0
std::vector<int64_t> dilations; // default 1

int64_t ceil_mode = 0; // not used now
int64_t storage_order = 0; // not used now

void set_default_attr(size_t shape_len) {
for (size_t i = 0; i < shape_len - 2; ++i) {
strides.emplace_back(1);
pads.emplace_back(0);
pads.emplace_back(0);
dilations.emplace_back(1);
}
}
};

struct AttrInfo_Gemm {
bool transA = false;
bool transB = false;
};
36 changes: 36 additions & 0 deletions include/InferShape.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#pragma once

#include <unordered_map>
#include "onnx.proto3.pb.h"
#include "AttrInfo.hpp"

#define INDENT 30

class InferShapeImpl {
public:
InferShapeImpl() = default;
InferShapeImpl(const onnx::GraphProto &in_graph) : graph(in_graph) {}

~InferShapeImpl() = default;

void set_io_iniz_shape_to_map();
void infer_shapes();
void print_summary();

std::unordered_map<std::string, std::vector<int64_t>> get_ndname_to_shape() {
return this->ndname_to_shape;
}

private:
onnx::GraphProto graph;
std::unordered_map<std::string, std::vector<int64_t>> ndname_to_shape;

// TODO: check dimensions & attributes needed or default
void infer_shapes_Conv(onnx::NodeProto &node);
void infer_shapes_Relu(onnx::NodeProto &node);
void infer_shapes_MaxPool(onnx::NodeProto &node);
void infer_shapes_Add(onnx::NodeProto &node);
void infer_shapes_GlobalAveragePool(onnx::NodeProto &node);
void infer_shapes_Flatten(onnx::NodeProto &node);
void infer_shapes_Gemm(onnx::NodeProto &node);
};
10 changes: 7 additions & 3 deletions include/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
#include <iostream>
#include "onnx.proto3.pb.h"

onnx::GraphProto read_onnx(const std::string &filename);

void iterate_graph(const ::onnx::GraphProto &graph);
onnx::ModelProto read_onnx(const std::string &filename);
void print_dim(const ::onnx::TensorShapeProto_Dimension &dim);
void print_val_info(const ::onnx::ValueInfoProto &info);
void print_dims_vec(const std::vector<int64_t> &dims);
std::string dims_vec_to_str(const std::vector<int64_t> &dims);
void set_vec_to_shape(onnx::ValueInfoProto *val_info, const std::vector<int64_t> &dims);
std::string string_trimmer(const std::string &inputString, const size_t maxLen);
Loading
Loading