Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions include/ModelAnalysis.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
using str_sz_map_t = std::unordered_map<std::string, size_t>;
using str_shape_map_t = std::unordered_map<std::string, std::vector<int64_t>>;

constexpr int64_t MUL_MACS = 1;
constexpr int64_t ADD_MACS = 1;
// refers to onnx-tool/node.py
constexpr int64_t DIV_MACS = 4;
constexpr int64_t MUL_FLOPS = 1;
constexpr int64_t ADD_FLOPS = 1;
constexpr int64_t CMP_FLOPS = 1;
constexpr int64_t DIV_FLOPS = 1;

struct AnalyzeData {
int64_t mac = 0;
int64_t flop = 0;
int64_t param = 0;
int64_t mem = 0;
};
Expand Down
8 changes: 4 additions & 4 deletions src/InferShape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,14 @@ void InferShapeImpl::print_summary() {
<< std::left << std::setw(TP_IND) << "Type"
<< std::left << std::setw(SP_IND) << "Input Shape"
<< std::left << std::setw(SP_IND) << "Output Shape"
<< std::left << std::setw(DT_IND) << "MACs"
<< std::left << std::setw(DT_IND) << "FLOPs"
<< std::left << std::setw(DT_IND) << "Params"
<< std::left << std::setw(DT_IND) << "Memory" << '\n';
std::cout << std::string(TOTAL_IND, '-') << '\n';

AnalyzeData total_data;
for (auto node : m_graph.node()) {
total_data.mac += this->m_name_to_anal_data[node.name()].mac;
total_data.flop += this->m_name_to_anal_data[node.name()].flop;
total_data.param += this->m_name_to_anal_data[node.name()].param;
total_data.mem += this->m_name_to_anal_data[node.name()].mem;

Expand All @@ -103,7 +103,7 @@ void InferShapeImpl::print_summary() {
std::cout << std::left << std::setw(TP_IND) << node.op_type();
std::cout << std::setw(SP_IND) << dims_vec_to_str(this->m_name_to_shape[node.input(0)]);
std::cout << std::setw(SP_IND) << dims_vec_to_str(this->m_name_to_shape[node.output(0)]);
std::cout << std::setw(DT_IND) << int64_to_str(this->m_name_to_anal_data[node.name()].mac);
std::cout << std::setw(DT_IND) << int64_to_str(this->m_name_to_anal_data[node.name()].flop);
std::cout << std::setw(DT_IND) << int64_to_str(this->m_name_to_anal_data[node.name()].param);
std::cout << std::setw(DT_IND) << int64_to_str(this->m_name_to_anal_data[node.name()].mem);
std::cout << '\n';
Expand All @@ -112,7 +112,7 @@ void InferShapeImpl::print_summary() {
std::cout << std::left << std::setw(TP_IND) << "-";
std::cout << std::setw(SP_IND) << dims_vec_to_str(this->m_name_to_shape[m_graph.input(0).name()]);
std::cout << std::setw(SP_IND) << dims_vec_to_str(this->m_name_to_shape[m_graph.output(0).name()]);
std::cout << std::setw(DT_IND) << int64_to_str(total_data.mac);
std::cout << std::setw(DT_IND) << int64_to_str(total_data.flop);
std::cout << std::setw(DT_IND) << int64_to_str(total_data.param);
std::cout << std::setw(DT_IND) << int64_to_str(total_data.mem);
std::cout << '\n';
Expand Down
36 changes: 25 additions & 11 deletions src/ModelAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ AnalyzeData analyze_node_Conv(onnx::NodeProto &node, NodeAnalArgs &anal_args) {
std::vector<int64_t> kernel_shape = input_shapes[1];
std::vector<int64_t> reduce_shape(kernel_shape.begin() + 1, kernel_shape.end());
int64_t out_prod = get_prod(output_shape);
data.mac = out_prod * get_prod(reduce_shape) * MUL_MACS;
data.flop = out_prod * get_prod(reduce_shape) * MUL_FLOPS;

if (node.input_size() == 3) { // with bias
data.mac += out_prod * ADD_MACS;
data.flop += out_prod * ADD_FLOPS;
}

// Parameters & Memory
Expand All @@ -64,7 +64,11 @@ AnalyzeData analyze_node_Relu(onnx::NodeProto &node, NodeAnalArgs &anal_args) {
str_sz_map_t ndname_to_size = anal_args.ndname_to_size;

data.param = 0; // no trainable parameters
data.mac = 0; // non MACs operation

// MACs
for (size_t i = 0; i < input_shapes.size(); ++i) {
data.flop += get_prod(input_shapes[i]) * CMP_FLOPS;
}

// Memory
for (size_t i = 0; i < input_shapes.size(); ++i) {
Expand All @@ -82,7 +86,15 @@ AnalyzeData analyze_node_MaxPool(onnx::NodeProto &node, NodeAnalArgs &anal_args)
str_sz_map_t ndname_to_size = anal_args.ndname_to_size;

// MACs
data.mac = 0; // non MACs operation
std::vector<int64_t> kernel_shape;
for (auto attr : node.attribute()) {
if (attr.name() == "kernel_shape") {
for (int i = 0; i < attr.ints_size(); ++i) {
kernel_shape.emplace_back(attr.ints(i));
}
}
}
data.flop = get_prod(output_shape) * get_prod(kernel_shape) * CMP_FLOPS;

// Parameters & Memory
data.param = 0; // no trainable parameters
Expand All @@ -101,9 +113,8 @@ AnalyzeData analyze_node_Add(onnx::NodeProto &node, NodeAnalArgs &anal_args) {
str_sz_map_t ndname_to_size = anal_args.ndname_to_size;

// MACs
for (size_t i = 0; i < input_shapes.size(); ++i) {
data.mac += get_prod(input_shapes[i]) * ADD_MACS;
}
data.flop += get_prod(output_shape) * ADD_FLOPS;


// Parameters & Memory
data.param = 0; // no trainable parameters
Expand All @@ -123,9 +134,9 @@ AnalyzeData analyze_node_GlobalAveragePool(onnx::NodeProto &node, NodeAnalArgs &

// MACs
for (size_t i = 0; i < input_shapes.size(); ++i) {
data.mac += get_prod(input_shapes[i]) * ADD_MACS;
data.flop += get_prod(input_shapes[i]) * ADD_FLOPS;
}
data.mac += get_prod(output_shape) * DIV_MACS;
data.flop += get_prod(output_shape) * DIV_FLOPS;

// Parameters & Memory
data.param = 0; // no trainable parameters
Expand All @@ -144,7 +155,7 @@ AnalyzeData analyze_node_Flatten(onnx::NodeProto &node, NodeAnalArgs &anal_args)
str_sz_map_t ndname_to_size = anal_args.ndname_to_size;

// MACs
data.mac = 0; // non MACs operation
data.flop = 0; // non MACs operation

// Parameters & Memory
data.param = 0; // no trainable parameters
Expand All @@ -163,7 +174,10 @@ AnalyzeData analyze_node_Gemm(onnx::NodeProto &node, NodeAnalArgs &anal_args) {
str_sz_map_t ndname_to_size = anal_args.ndname_to_size;

// MACs
data.mac = get_prod(input_shapes[0]) * get_prod(output_shape) * MUL_MACS;
data.flop = get_prod(input_shapes[0]) * get_prod(output_shape) * MUL_FLOPS;
if (node.input_size() == 3) { // with bias
data.flop += get_prod(output_shape) * ADD_FLOPS;
}

// Parameters & Memory
data.param = get_prod(input_shapes[0]) * get_prod(output_shape) + get_prod(output_shape);
Expand Down
Loading