-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
354 lines (299 loc) · 19.1 KB
/
main.py
File metadata and controls
354 lines (299 loc) · 19.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
import torch
import argparse
import dataset_utils
import model_utils
import tiling_utils
import sparsity_utils
import checksum_utils
import gpu_config
import copy
import torch.nn as nn
import torch.nn.utils.prune as prune
import log
import fault_simulation
import random
# import psutil, os
import gc
import time
from typing import cast
from zeus.monitor import ZeusMonitor
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Model-level CNN hardening")
_ = parser.add_argument("--model", type=str, required=True, help="name of the CNN")
_ = parser.add_argument("--dataset", type=str, choices=["cifar10", "cifar100", "imagenet"], required=True, help="name of the dataset")
_ = parser.add_argument("--batch-size", type=int, default=128, help="an integer value for bach-size")
_ = parser.add_argument("--is-tiling", action='store_true', help="set the flag, to apply tiling to the model")
_ = parser.add_argument("--policy", type=str, choices=["TwoLevelRoundRobin", "GlobalRoundRobin", "Greedy", "DistributedBlock", "DistributedCTA"], help="scheduling policy")
_ = parser.add_argument("--is-sparsity", action='store_true', help="set the flag, to apply sparsity to the model")
_ = parser.add_argument("--is-finetune", action='store_true', help="set the flag, to apply fine-tuning for sparse model")
_ = parser.add_argument("--epoch", type=int, help="in the case of fine-tune, specify the epoch")
_ = parser.add_argument("--is-checksum", action='store_true', help="set the flag, to apply checksum to sparse model")
_ = parser.add_argument("--is-thresholds-optimization", action='store_true', help="set the flag, to optimize thresholds for fault detection in checksum for sparse model")
_ = parser.add_argument("--is-FI", action='store_true', help="set the flag, to apply fault injection")
_ = parser.add_argument("--FI-location", type=str, choices=["weight", "metadata"], help="set the fault location")
_ = parser.add_argument("--FI-iteration", type=int, help="in the case of FI, specify the number of iterations")
_ = parser.add_argument("--FI-study", type=str, choices=["random", "per-layer"], help="random is for FI into random layers, per-layer is for targeting each layer consecutively")
_ = parser.add_argument("--is-overhead", action='store_true', help="set the flag, to evaluate energy and performance overhead")
_ = parser.add_argument("--case", type=int, help="1: sparsity, 2: checksum")
# process = psutil.Process(os.getpid())
# setting up the arguments values
args = parser.parse_args()
model_name = cast(str, args.model)
dataset_name = cast(str, args.dataset)
batch_size = cast(int, args.batch_size)
is_tiling = cast(bool, args.is_tiling)
policy = cast(str, args.policy)
is_sparsity = cast(bool, args.is_sparsity)
is_finetune = cast(bool, args.is_finetune)
finetune_epochs = cast(int, args.epoch)
is_checksum = cast(bool, args.is_checksum)
is_thresholds_optimization = cast(bool, args.is_thresholds_optimization)
is_FI = cast(bool, args.is_FI)
FI_location = cast(str, args.FI_location)
FI_study = cast(str, args.FI_study)
FI_iteration = cast(int, args.FI_iteration)
is_overhead = cast(bool, args.is_overhead)
case = cast(int, args.case)
run_mode = "exp"
run_mode += "".join([part for part, condition in [("_sparse", is_sparsity),
("_finetune", is_finetune),
(f"_tiling_{policy}", is_tiling),
("_FI", is_FI),
("_overhead", is_overhead),
("_checksum", is_checksum)]
if condition])
# create logger
setup_logger = log.LogHandler(run_mode, model_name, dataset_name)
logger = setup_logger.getLogger()
setup_logger_info = ""
setup_logger_info += "".join(f"{i}: {args.__dict__[i]}, " for i in args.__dict__ if \
(type(args.__dict__[i]) is bool and args.__dict__[i] is True) or \
(type(args.__dict__[i]) is not bool and args.__dict__[i] is not None))
logger.info(f"args: {setup_logger_info}")
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# load dataset and CNN model
trainloader, classes_count, dummy_input = dataset_utils.load_dataset(dataset_name, batch_size, is_train=True)
testloader, classes_count, dummy_input = dataset_utils.load_dataset(dataset_name, batch_size, is_train=False)
model = model_utils.load_model(model_name, dataset_name, device)
if is_tiling and not is_FI:
if not is_sparsity:
gpu = gpu_config.gpu_config(name="dummy", sm=4, cluster=2, CTAs_buffer=2)
for data in testloader:
print("hooking the function ...")
inputs, labels = data[0].to(device), data[1].to(device)
tu = tiling_utils.tiling_utils(policy, gpu, 16, 8, 16)
print("scheduling in progress...")
tu.CNN_to_tiles(model, inputs)
print("scheduling done.")
for layer in tu.scheduler_info_dict:
for img_ind in tu.scheduler_info_dict[layer]:
file = open(f"schduling_{layer}_{img_ind}.json", "w")
file.write(str(tu.scheduler_info_dict[layer][img_ind]) + "\n")
file.close()
break
print("tiling is done and all the restuls are saved")
elif is_sparsity:
for data in testloader:
print("hooking the function ...")
inputs = data[0].to(device)
model_utils.load_params_sparse(model, f"./pruned-models/pruned_model_{model_name}_{dataset_name}.pth", device)
sparsity_utils.sparse_CNN_to_tiles(model, inputs, setup_logger.log_direction)
break
elif is_sparsity and not is_FI:
if is_finetune:
trainloader, classes_count, dummy_input = dataset_utils.load_dataset(dataset_name, batch_size, is_train=True)
accuracy = model_utils.evaluate(model, testloader, device)
logger.info(f"original accuracy = {accuracy}")
for name, module in model.named_modules():
if isinstance(module, torch.nn.Conv2d):
prune_method = sparsity_utils.SparsePrune()
prune_method.apply(module, "weight")
accuracy = model_utils.evaluate(model, testloader, device)
logger.info(f"sparsed initial accuracy = {accuracy}")
saved_model_name = f"./pruned-models/pruned_model_{model_name}_{dataset_name}.pth"
if dataset_name == "cifar10":
lr = 0.01
elif dataset_name == "cifar100" or dataset_name == "imagenet":
lr = 0.001
_ = model_utils.fine_tune_sparse(model, trainloader, testloader, finetune_epochs, lr, device, saved_model_name, logger)
for name, module in model.named_modules():
if isinstance(module, torch.nn.Conv2d):
if hasattr(module, 'weight_mask'):
prune.remove(module, 'weight')
# model_utils.load_params(model, saved_model_name, device)
# accuracy = model_utils.evaluate(model, testloader, device)
# logger.info(f"sparsed refined accuracy = {accuracy}")
new_model = copy.deepcopy(model)
sparsity_utils.conv_replacement(new_model)
accuracy = model_utils.evaluate(new_model, testloader, device)
logger.info(f"structured sparsed accuracy = {accuracy}")
elif not is_finetune:
if not is_overhead:
accuracy = model_utils.evaluate(model, testloader, device)
logger.info(f"original accuracy = {accuracy}")
model_utils.load_params_sparse(model, f"./pruned-models/pruned_model_{model_name}_{dataset_name}.pth", device)
accuracy = model_utils.evaluate(model, testloader, device)
logger.info(f"sparsed accuracy = {accuracy}")
new_model = copy.deepcopy(model)
new_model = sparsity_utils.conv_replacement(new_model)
model_utils.load_params_sparse(new_model, f"./pruned-models/pruned_model_{model_name}_{dataset_name}.pth", device)
accuracy = model_utils.evaluate(new_model, testloader, device)
logger.info(f"structured sparsed accuracy = {accuracy}")
elif is_sparsity and is_FI:
conv_layer_names = fault_simulation.get_layers_names(model)
sum_accuracy = 0
worst_accuracy = 100
best_accuracy = 0
sparsity_utils.conv_replacement(model)
fault_simulation.FI_conv_replacement(model, FI_location)
if FI_study == "random":
with torch.no_grad():
for FI_iter in range(FI_iteration):
random.shuffle(conv_layer_names)
target_layer_name = conv_layer_names[0]
fault_simulation.activate_FI_layer(model, target_layer_name)
faulty_accuracy = model_utils.evaluate(model, testloader, device)
if faulty_accuracy < worst_accuracy:
worst_accuracy = faulty_accuracy
if faulty_accuracy > best_accuracy:
best_accuracy = faulty_accuracy
if sdc < worst_sdc:
worst_sdc = sdc
if sdc > best_sdc:
best_sdc = sdc
if sdc_crit < worst_sdc_crit:
worst_sdc_crit = sdc_crit
if sdc_crit > best_sdc_crit:
best_sdc_crit = sdc_crit
sum_accuracy = sum_accuracy + faulty_accuracy
sum_sdc = sum_sdc + sdc
sum_sdc_crit = sum_sdc_crit + sdc_crit
del faulty_accuracy
torch.cuda.empty_cache()
total_detection = 0
total_correction = 0
true_correction = 0
total_naninf_faults = 0
total_no_faults = 0
total_undetected_faults = 0
for name1, layer in model.named_modules():
if isinstance(layer, checksum_utils.ChecksumConv):
logger.info(f"{name1}, {layer.detected_faults}, {layer.corrected_faults}, {layer.true_correction}")
total_detection += layer.detected_faults
total_correction += layer.corrected_faults
true_correction += layer.true_correction
total_naninf_faults += layer.naninf_faults
total_undetected_faults += layer.undetected_faults
for r in layer.wrong_corrections:
logger.info(f"{r}")
logger.info(f"Worst accuracy: {worst_accuracy}")
logger.info(f"Best accuracy: {best_accuracy}")
logger.info(f"Average accuracy: {sum_accuracy / FI_iteration}")
logger.info(f"Lowest SDC: {worst_sdc}")
logger.info(f"Highest SDC: {best_sdc}")
logger.info(f"Average SDC: {sum_sdc / FI_iteration}")
logger.info(f"Lowest critical SDC: {worst_sdc_crit}")
logger.info(f"Highest critical SDC: {best_sdc_crit}")
logger.info(f"Average critical SDC: {sum_sdc_crit / FI_iteration}")
logger.info(f"Total undetected faults: {total_undetected_faults}")
logger.info(f"nan or inf results: {total_naninf_faults}")
logger.info(f"Total detected faults: {total_detection}")
logger.info(f"Total corrected faults: {total_correction}")
logger.info(f"True corrections: {true_correction}")
logger.info("============================\n")
elif FI_study == "layer-wise":
with torch.no_grad():
for target_layer_name in conv_layer_names:
for FI_iter in range(FI_iteration):
faulty_accuracy = model_utils.evaluate(model, testloader, device)
if faulty_accuracy < worst_accuracy:
worst_accuracy = faulty_accuracy
if faulty_accuracy > best_accuracy:
best_accuracy = faulty_accuracy
sum_accuracy = sum_accuracy + faulty_accuracy
gc.collect()
torch.cuda.empty_cache()
logger.info(f"Worst accuracy: {worst_accuracy}")
logger.info(f"Best accuracy: {best_accuracy}")
logger.info(f"Average accuracy over {FI_iteration} iterations for layer {target_layer_name}: {sum_accuracy / FI_iteration}")
else:
conv_layer_names = fault_simulation.get_layers_names(model)
sparsity_utils.conv_replacement(model)
fault_simulation.FI_spconv_replacement(model, FI_location)
model_utils.load_params_sparse(model, f"./pruned-models/pruned_model_{model_name}_{dataset_name}.pth", device)
original_model = copy.deepcopy(model)
if FI_study == "random":
with torch.no_grad():
for FI_iter in range(FI_iteration):
random.shuffle(conv_layer_names)
target_layer_name = conv_layer_names[0]
fault_simulation.activate_FI_layer(model, target_layer_name)
faulty_accuracy, sdc, sdc_crit = model_utils.faulty_evaluate(original_model, model, testloader, device, classes_count)
if faulty_accuracy < worst_accuracy:
worst_accuracy = faulty_accuracy
if faulty_accuracy > best_accuracy:
best_accuracy = faulty_accuracy
if sdc < worst_sdc:
worst_sdc = sdc
if sdc > best_sdc:
best_sdc = sdc
if sdc_crit < worst_sdc_crit:
worst_sdc_crit = sdc_crit
if sdc_crit > best_sdc_crit:
best_sdc_crit = sdc_crit
sum_accuracy = sum_accuracy + faulty_accuracy
sum_sdc = sum_sdc + sdc
sum_sdc_crit = sum_sdc_crit + sdc_crit
del faulty_accuracy
torch.cuda.empty_cache()
logger.info(f"Worst accuracy: {worst_accuracy}")
logger.info(f"Best accuracy: {best_accuracy}")
logger.info(f"Average accuracy: {sum_accuracy / FI_iteration}")
logger.info(f"Lowest SDC: {worst_sdc}")
logger.info(f"Highest SDC: {best_sdc}")
logger.info(f"Average SDC: {sum_sdc / FI_iteration}")
logger.info(f"Lowest critical SDC: {worst_sdc_crit}")
logger.info(f"Highest critical SDC: {best_sdc_crit}")
logger.info(f"Average critical SDC: {sum_sdc_crit / FI_iteration}")
logger.info("============================\n")
elif not is_sparsity:
sum_accuracy = 0
worst_accuracy = 100
best_accuracy = 0
original_model = copy.deepcopy(model)
fault_simulation.FI_conv_replacement(model)
if FI_study == "random":
conv_layer_names = fault_simulation.get_layers_names(model)
with torch.no_grad():
for FI_iter in range(FI_iteration):
random.shuffle(conv_layer_names)
target_layer_name = conv_layer_names[0]
fault_simulation.activate_FI_layer(model, target_layer_name)
faulty_accuracy, sdc, sdc_crit = model_utils.faulty_evaluate(original_model, model, testloader, device, classes_count)
if faulty_accuracy < worst_accuracy:
worst_accuracy = faulty_accuracy
if faulty_accuracy > best_accuracy:
best_accuracy = faulty_accuracy
if sdc < worst_sdc:
worst_sdc = sdc
if sdc > best_sdc:
best_sdc = sdc
if sdc_crit < worst_sdc_crit:
worst_sdc_crit = sdc_crit
if sdc_crit > best_sdc_crit:
best_sdc_crit = sdc_crit
sum_accuracy = sum_accuracy + faulty_accuracy
sum_sdc = sum_sdc + sdc
sum_sdc_crit = sum_sdc_crit + sdc_crit
del faulty_accuracy
model = model_utils.load_model(model_name, dataset_name, device)
torch.cuda.empty_cache()
logger.info(f"Worst accuracy: {worst_accuracy}")
logger.info(f"Best accuracy: {best_accuracy}")
logger.info(f"Average accuracy over {FI_iteration} iterations: {sum_accuracy / FI_iteration}")
logger.info(f"Lowest SDC: {worst_sdc}")
logger.info(f"Highest SDC: {best_sdc}")
logger.info(f"Average SDC: {sum_sdc / FI_iteration}")
logger.info(f"Lowest critical SDC: {worst_sdc_crit}")
logger.info(f"Highest critical SDC: {best_sdc_crit}")
logger.info(f"Average critical SDC: {sum_sdc_crit / FI_iteration}")