Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,18 @@ file(GLOB_RECURSE SOURCES "${onnxinfo_SOURCE_DIR}/src/*.cpp")
list(REMOVE_ITEM SOURCES "${onnxinfo_SOURCE_DIR}/src/main.cpp")

add_library(
${PROJECT_NAME} STATIC
_onnxinfo STATIC

${SOURCES}
${onnxinfo_SOURCE_DIR}/third_party/onnx/onnx.proto3.pb.cc
)
add_dependencies(${PROJECT_NAME} protobuf::protoc)
target_link_libraries(${PROJECT_NAME}
add_dependencies(_onnxinfo protobuf::protoc)
target_link_libraries(_onnxinfo
protobuf::libprotobuf
)

pybind11_add_module(_onnxinfo src/main.cpp)
target_link_libraries(_onnxinfo PRIVATE ${PROJECT_NAME})
pybind11_add_module(${PROJECT_NAME} src/main.cpp)
target_link_libraries(${PROJECT_NAME} PRIVATE _onnxinfo)

# custom cmake target for test
# add_custom_target(test ALL)
Expand Down
62 changes: 31 additions & 31 deletions include/AttrInfo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,43 +4,43 @@
#include "onnx.proto3.pb.h"

struct AttrInfo_Conv {
std::vector<int64_t> kernel_shape; // got from Weight
std::vector<int64_t> strides; // default 1
std::vector<int64_t> pads; // default 0
std::vector<int64_t> dilations; // default 1

int64_t group = 1; // not used now

void set_default_attr(size_t shape_len) {
for (size_t i = 0; i < shape_len - 2; ++i) {
strides.emplace_back(1);
pads.emplace_back(0);
pads.emplace_back(0);
dilations.emplace_back(1);
}
std::vector<int64_t> kernel_shape; // got from Weight
std::vector<int64_t> strides; // default 1
std::vector<int64_t> pads; // default 0
std::vector<int64_t> dilations; // default 1

int64_t group = 1; // not used now

void set_default_attr(size_t shape_len) {
for (size_t i = 0; i < shape_len - 2; ++i) {
strides.emplace_back(1);
pads.emplace_back(0);
pads.emplace_back(0);
dilations.emplace_back(1);
}
}
};

struct AttrInfo_MaxPool {
std::vector<int64_t> kernel_shape;
std::vector<int64_t> strides; // default 1
std::vector<int64_t> pads; // default 0
std::vector<int64_t> dilations; // default 1

int64_t ceil_mode = 0; // not used now
int64_t storage_order = 0; // not used now

void set_default_attr(size_t shape_len) {
for (size_t i = 0; i < shape_len - 2; ++i) {
strides.emplace_back(1);
pads.emplace_back(0);
pads.emplace_back(0);
dilations.emplace_back(1);
}
std::vector<int64_t> kernel_shape;
std::vector<int64_t> strides; // default 1
std::vector<int64_t> pads; // default 0
std::vector<int64_t> dilations; // default 1

int64_t ceil_mode = 0; // not used now
int64_t storage_order = 0; // not used now

void set_default_attr(size_t shape_len) {
for (size_t i = 0; i < shape_len - 2; ++i) {
strides.emplace_back(1);
pads.emplace_back(0);
pads.emplace_back(0);
dilations.emplace_back(1);
}
}
};

struct AttrInfo_Gemm {
bool transA = false;
bool transB = false;
bool transA = false;
bool transB = false;
};
46 changes: 23 additions & 23 deletions include/InferShape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,32 @@

class InferShapeImpl {
public:
InferShapeImpl() = default;
InferShapeImpl(const onnx::GraphProto &in_graph) : graph(in_graph) {}
InferShapeImpl() = default;
InferShapeImpl(const onnx::GraphProto &in_graph) : graph(in_graph) {}

~InferShapeImpl() = default;
~InferShapeImpl() = default;

void set_io_iniz_shape_to_map(bool analyze);
void infer_shapes(bool analyze = true);
void infer_shapes();
void print_summary();
void set_io_iniz_shape_to_map(bool analyze);
void infer_shapes(bool analyze = true);
void infer_shapes();
void print_summary();

const std::unordered_map<std::string, std::vector<int64_t>> get_ndname_to_shape() {
return this->ndname_to_shape;
}
const std::unordered_map<std::string, std::vector<int64_t>> get_ndname_to_shape() {
return this->ndname_to_shape;
}

private:
onnx::GraphProto graph;
std::unordered_map<std::string, std::vector<int64_t>> ndname_to_shape;
std::unordered_map<std::string, struct AnalyzeData> ndname_to_anal_data;
std::unordered_map<std::string, size_t> ndname_to_dtype_size;

// TODO: more op types
void infer_shapes_Conv(onnx::NodeProto &node);
void infer_shapes_Relu(onnx::NodeProto &node);
void infer_shapes_MaxPool(onnx::NodeProto &node);
void infer_shapes_Add(onnx::NodeProto &node);
void infer_shapes_GlobalAveragePool(onnx::NodeProto &node);
void infer_shapes_Flatten(onnx::NodeProto &node);
void infer_shapes_Gemm(onnx::NodeProto &node);
onnx::GraphProto graph;
std::unordered_map<std::string, std::vector<int64_t>> ndname_to_shape;
std::unordered_map<std::string, struct AnalyzeData> ndname_to_anal_data;
std::unordered_map<std::string, size_t> ndname_to_dtype_size;

// TODO: more op types
void infer_shapes_Conv(onnx::NodeProto &node);
void infer_shapes_Relu(onnx::NodeProto &node);
void infer_shapes_MaxPool(onnx::NodeProto &node);
void infer_shapes_Add(onnx::NodeProto &node);
void infer_shapes_GlobalAveragePool(onnx::NodeProto &node);
void infer_shapes_Flatten(onnx::NodeProto &node);
void infer_shapes_Gemm(onnx::NodeProto &node);
};
13 changes: 10 additions & 3 deletions include/ModelAnalysis.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,19 @@
#define DIV_MACS 4

struct AnalyzeData {
int64_t mac = 0;
int64_t param = 0;
int64_t mem = 0;
int64_t mac = 0;
int64_t param = 0;
int64_t mem = 0;
};

struct NodeAnalArgs {
std::vector<std::vector<int64_t>> input_shapes;
std::vector<int64_t> output_shape;
std::unordered_map<std::string, size_t> ndname_to_size;
};

int64_t get_prod(std::vector<int64_t> &vec);
NodeAnalArgs get_anal_args(onnx::NodeProto &node, const std::unordered_map<std::string, std::vector<int64_t>> &ndname_to_shape, const std::unordered_map<std::string, size_t> &ndname_to_dtype_size);

AnalyzeData analyze_node_Conv(onnx::NodeProto &node, std::vector<std::vector<int64_t>> &input_shapes, std::vector<int64_t> &output_shape, std::unordered_map<std::string, size_t> &ndname_to_size);
AnalyzeData analyze_node_Relu(onnx::NodeProto &node, std::vector<std::vector<int64_t>> &input_shapes, std::vector<int64_t> &output_shape, std::unordered_map<std::string, size_t> &ndname_to_size);
Expand Down
Loading
Loading