-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathabstraction_runner.py
More file actions
157 lines (132 loc) · 6.99 KB
/
abstraction_runner.py
File metadata and controls
157 lines (132 loc) · 6.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os
import pickle
from deepstellar.Abstraction.StateAbstraction import StateAbstraction
from deepstellar.Abstraction.GraphWrapper import GraphWrapper
import argparse
import sys
sys.path.append("")
def get_abst_model(profile_save_path, abst_save_path, name_prefix, lstm_classifier, model,
                   *, comp_num=64, k=3, m=10, bits=8, n_step=0):
    """Profile a hidden-state model if needed, build its abstract graph model and pickle it.

    Profiling is skipped when ``profile_save_path`` already exists. A
    ``StateAbstraction`` is built from the profile, wrapped in a ``GraphWrapper``,
    ``build_model()`` is called on the wrapper, and the wrapper is pickled to
    ``<abst_save_path>/wrapper_<name_prefix>_<k>_<m>.pkl``.

    Args:
        profile_save_path: Location passed to ``lstm_classifier.profile_train_data``;
            its existence is used as the "already profiled" marker.
        abst_save_path: Directory that receives the pickled wrapper (created if missing).
        name_prefix: Prefix embedded in the output file name.
        lstm_classifier: Object providing ``profile_train_data(model, path)``
            (any RNN classifier, not only LSTMs).
        model: Hidden-state model handed to ``profile_train_data``.
        comp_num, bits, n_step: Forwarded unchanged to ``StateAbstraction``.
        k, m: A list of ``k`` copies of ``m`` is forwarded to ``StateAbstraction``;
            both values also appear in the output file name.

    Raises:
        ValueError: If ``k`` is less than 1 (previously this surfaced as an
            ``IndexError`` only after the abstraction had been built).
    """
    if k < 1:
        raise ValueError("k must be >= 1, got %r" % (k,))

    if not os.path.exists(profile_save_path):
        # NOTE(review): presumably profile_train_data creates profile_save_path,
        # so later runs take the "already done" branch — confirm in the classifier.
        lstm_classifier.profile_train_data(model, profile_save_path)
        print("profiling done...")
    else:
        print("profiling is already done...")

    # NOTE(review): [m] * k is assumed to mean "m intervals for each of the first
    # k components" in deepstellar's StateAbstraction — confirm there.
    state_abst = StateAbstraction(profile_save_path, comp_num, bits, [m] * k, n_step)
    wrapper = GraphWrapper(state_abst)
    wrapper.build_model()

    os.makedirs(abst_save_path, exist_ok=True)
    # k and m equal len([m] * k) and its first element, so the name is unchanged.
    save_file = os.path.join(abst_save_path, 'wrapper_%s_%s_%s.pkl' % (name_prefix, k, m))
    with open(save_file, 'wb') as f:
        pickle.dump(wrapper, f)
    print('finish')
def _abstract_demo_model(demo, arch, classifier, uses_text_data=False):
    """Load the ``<demo>``/``<arch>`` hidden-state model and build its abstraction.

    Every demo follows one directory layout under ``./RNNModels/<demo>_demo``:
    the model is ``models/<demo>_<arch>.h5``, profiling output goes to
    ``output/<arch>/profile_save``, the abstraction to ``output/<arch>/abst_model``,
    and the saved-wrapper prefix is ``<arch>_<demo>``.

    Args:
        demo: Dataset name, e.g. ``"mnist"`` or ``"agnews"``.
        arch: Architecture name, e.g. ``"lstm"``, ``"blstm"`` or ``"gru"``.
        classifier: Classifier instance providing ``load_hidden_state_model``.
        uses_text_data: When true, ``embedding_path`` and ``data_path`` are set on
            the classifier (pointing into ``<root>/save/``) before the model is loaded.
    """
    root = "./RNNModels/%s_demo" % demo
    if uses_text_data:
        classifier.embedding_path = "%s/save/embedding_matrix.npy" % root
        classifier.data_path = "%s/save/standard_data.npz" % root
    model = classifier.load_hidden_state_model("%s/models/%s_%s.h5" % (root, demo, arch))
    get_abst_model("%s/output/%s/profile_save" % (root, arch),
                   "%s/output/%s/abst_model" % (root, arch),
                   "%s_%s" % (arch, demo),
                   classifier,
                   model)


# Each entry point imports its classifier lazily so that only the module for the
# requested target is loaded.
def mnist_lstm_abst():
    """Build the abstraction for the MNIST LSTM model."""
    from RNNModels.mnist_demo.mnist_lstm import MnistLSTMClassifier
    _abstract_demo_model("mnist", "lstm", MnistLSTMClassifier())


def mnist_blstm_abst():
    """Build the abstraction for the MNIST bidirectional-LSTM model."""
    from RNNModels.mnist_demo.mnist_blstm import MnistBLSTMClassifier
    _abstract_demo_model("mnist", "blstm", MnistBLSTMClassifier())


def snips_blstm_abst():
    """Build the abstraction for the SNIPS bidirectional-LSTM model."""
    from RNNModels.snips_demo.snips_blstm import SnipsBLSTMClassifier
    _abstract_demo_model("snips", "blstm", SnipsBLSTMClassifier(), uses_text_data=True)


def snips_gru_abst():
    """Build the abstraction for the SNIPS GRU model."""
    from RNNModels.snips_demo.snips_gru import SnipsGRUClassifier
    _abstract_demo_model("snips", "gru", SnipsGRUClassifier(), uses_text_data=True)


def fashion_lstm_abst():
    """Build the abstraction for the Fashion LSTM model."""
    from RNNModels.fashion_demo.fashion_lstm import FashionLSTMClassifier
    _abstract_demo_model("fashion", "lstm", FashionLSTMClassifier())


def fashion_gru_abst():
    """Build the abstraction for the Fashion GRU model."""
    from RNNModels.fashion_demo.fashion_gru import FashionGRUClassifier
    _abstract_demo_model("fashion", "gru", FashionGRUClassifier())


def agnews_lstm_abst():
    """Build the abstraction for the AG News LSTM model."""
    from RNNModels.agnews_demo.agnews_lstm import AGNewsLSTMClassifier
    _abstract_demo_model("agnews", "lstm", AGNewsLSTMClassifier(), uses_text_data=True)


def agnews_blstm_abst():
    """Build the abstraction for the AG News bidirectional-LSTM model."""
    from RNNModels.agnews_demo.agnews_blstm import AgnewsBLSTMClassifier
    _abstract_demo_model("agnews", "blstm", AgnewsBLSTMClassifier(), uses_text_data=True)
if __name__ == '__main__':
    # Single source of truth: the CLI choices and the dispatch come from this dict,
    # so adding a target cannot leave the two out of sync.
    targets = {
        "mnist_lstm": mnist_lstm_abst,
        "mnist_blstm": mnist_blstm_abst,
        "snips_blstm": snips_blstm_abst,
        "snips_gru": snips_gru_abst,
        "fashion_lstm": fashion_lstm_abst,
        "fashion_gru": fashion_gru_abst,
        "agnews_lstm": agnews_lstm_abst,
        "agnews_blstm": agnews_blstm_abst,
    }
    # The text must go to `description`; the first positional argument of
    # ArgumentParser is `prog`, which would replace the program name in usage output.
    parser = argparse.ArgumentParser(
        description="Generate the abstract model for DeepStellar-cov.")
    # Single-dash long option kept as-is for compatibility with existing invocations.
    parser.add_argument('-test_obj', required=True, choices=list(targets))
    args = parser.parse_args()
    targets[args.test_obj]()