-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodify_prototxt.py
More file actions
78 lines (55 loc) · 2.14 KB
/
modify_prototxt.py
File metadata and controls
78 lines (55 loc) · 2.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#!/usr/bin/env python3
import sys
import argparse
import caffe
from caffe import layers as L
from caffe import params as P
from caffe.proto import caffe_pb2
from google.protobuf import text_format
from collections import defaultdict
# Maps a blob name to the name under which its most recent producer's output
# is now exposed; entries are written by add_to_dict().
blob_layer_dict = defaultdict(str)


def get_parent(layer_bottom):
    """Return the name a consumer should reference for blob ``layer_bottom``.

    If a mapping was recorded with ``add_to_dict`` (and is non-empty), the
    mapped name is returned; otherwise the blob name is returned unchanged.
    """
    # .get() rather than [] so that querying an unmapped blob does not insert
    # an empty-string entry into the defaultdict as a side effect.
    return blob_layer_dict.get(layer_bottom) or layer_bottom


def add_to_dict(layer_top, top_value):
    """Record that blob ``layer_top`` is now exposed under name ``top_value``."""
    blob_layer_dict[layer_top] = top_value
def modify_prototxt(args):
    """Rewrite a Caffe prototxt so that layers no longer compute in place.

    Every layer whose type is not ``Input`` and that has exactly one top gets
    that top renamed to the layer's own name. Every bottom is rewired, via
    ``get_parent``, to the name under which the most recent producer of that
    blob is exposed. The result is written to ``<net name>_tmp.prototxt`` in
    the current working directory.

    Args:
        args: Parsed command-line namespace; only ``args.prototxt`` (path to
            the input prototxt) is read.
    """
    prototxt = args.prototxt
    # Reset the module-level mapping so repeated calls never see stale
    # entries left over from a previously processed network.
    blob_layer_dict.clear()
    # NOTE(review): the constructed Net is never used; presumably this only
    # checks that Caffe itself accepts the prototxt — confirm it is needed.
    caffe.Net(prototxt, caffe.TEST)

    net_par = caffe_pb2.NetParameter()
    with open(prototxt, 'r') as prototxtfile:
        text_format.Merge(prototxtfile.read(), net_par)

    mod_net = caffe_pb2.NetParameter()
    mod_net.name = net_par.name
    for l in net_par.layer:
        ltemp = mod_net.layer.add()
        ltemp.CopyFrom(l)

        # Only single-top layers are renamed. Input layers are skipped because
        # their top may legitimately differ from the layer name (e.g. FCN8s).
        renamed = len(l.top) == 1 and l.type != 'Input'
        if renamed:
            ltemp.top[0] = l.name

        # Rewire bottoms BEFORE recording this layer's own top: for an
        # in-place layer (bottom == top) the bottom must resolve to the
        # previous producer, not to this layer.
        if not l.bottom:
            print('Layer:%s has no bottom' % (ltemp.name))
        elif len(l.bottom) > 1:
            print('Inside multiple bottoms')
            print(ltemp.bottom)
            for i, bottom in enumerate(l.bottom):
                ltemp.bottom[i] = get_parent(bottom)
            print('After modification', ltemp.bottom)
        else:
            ltemp.bottom[0] = get_parent(l.bottom[0])

        if renamed:
            add_to_dict(l.top[0], l.name)
        else:
            # These tops keep their original blob names in the output, so any
            # earlier mapping for the same name is stale: later consumers must
            # read this layer's output, not an earlier producer's.
            for top in l.top:
                blob_layer_dict.pop(top, None)

    with open(net_par.name + '_tmp.prototxt', 'w') as f:
        f.write(text_format.MessageToString(mod_net))
def main(args=None):
    """Command-line entry point: parse arguments and rewrite the prototxt.

    Args:
        args: Argument list to parse, without the program name. ``None``
            falls back to argparse's default of ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description='Modify prototxt to remove in-place layers')
    # required=True: without it, omitting -p leaves prototxt as None and the
    # script later fails with an opaque TypeError inside open().
    parser.add_argument('-p', '--prototxt', type=str, required=True, help='Prototxt file')
    args = parser.parse_args(args)
    blob_layer_dict.clear()
    modify_prototxt(args)
if __name__ == '__main__':
    # Hand everything after the script name to the CLI parser in main().
    cli_args = sys.argv[1:]
    main(cli_args)