-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
150 lines (124 loc) · 7.88 KB
/
main.py
File metadata and controls
150 lines (124 loc) · 7.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import torch
import argparse
import dataset_utils
import models_utils
import utils
import run_mode_utils
import handlers
import clipping
if __name__ == "__main__":
    # Command-line entry point: parse arguments, load the dataset and CNN,
    # optionally restore a pruned/hardened checkpoint, then dispatch to the
    # requested run mode (plain test, channel ranking, pruning, hardening,
    # weight fault injection, or performance measurement).
    parser = argparse.ArgumentParser(description="Model-level CNN hardening")
    parser.add_argument("--model", type=str, required=True, help="name of the CNN")
    parser.add_argument("--dataset", type=str, choices=["cifar10", "cifar100", "imagenet"], required=True, help="name of the dataset")
    parser.add_argument("--batch-size", type=int, default=128, help="an integer value for batch-size")
    parser.add_argument("--is-ranking", action='store_true', help="set the flag, if you want to only rank the channels in conv layers")
    parser.add_argument("--is-pruning", action='store_true', help="set the flag, if you want to prune the CNN")
    parser.add_argument("--is-pruned", action='store_true', help="set the flag, if you want to load a pruned CNN")
    parser.add_argument("--pruning-method", type=str, default="hm", help="pruning method, homogeneous or heterogeneous")
    parser.add_argument("--pruning-ratio", type=float, default=None, help="a float value for pruning conv layers")
    parser.add_argument("--pruned-checkpoint", type=str, default=None, help="directory to the pruned model checkpoint")
    parser.add_argument("--is-hardening", action='store_true', help="set the flag, if you want to harden the CNN")
    parser.add_argument("--is-hardened", action='store_true', help="set the flag, if you want to load a hardened CNN")
    parser.add_argument("--hardening-ratio", type=float, default=None, help="a float value for hardening conv layers")
    parser.add_argument("--hardened-checkpoint", type=str, default=None, help="directory to the hardened model checkpoint")
    parser.add_argument("--importance", type=str, choices=["l1-norm", "vul-gain", "salience", "deepvigor", "channel-FI"], default=None, help="method for importance analysis either in pruning or hardening")
    parser.add_argument("--clipping", type=str, choices=["ranger"], default=None, help="method for clipping ReLU in hardening")
    parser.add_argument("--is-FI", action="store_true", help="set the flag, for performing fault simulation in weights")
    parser.add_argument("--BER", type=float, default=None, help="a float value for Bit Error Rate")
    parser.add_argument("--repeat", type=int, default=None, help="number of fault simulation experiments")
    parser.add_argument("--is-performance", action='store_true', help="set the flag, if you want to test the performance of a CNN")

    # Unpack parsed arguments into local names used throughout the script.
    args = parser.parse_args()
    model_name = args.model
    dataset_name = args.dataset
    batch_size = args.batch_size
    is_ranking = args.is_ranking
    is_pruning = args.is_pruning
    is_pruned = args.is_pruned
    pruning_method = args.pruning_method
    pruning_ratio = args.pruning_ratio
    pruned_checkpoint = args.pruned_checkpoint
    is_hardening = args.is_hardening
    is_hardened = args.is_hardened
    hardening_ratio = args.hardening_ratio
    hardened_checkpoint = args.hardened_checkpoint
    importance_command = args.importance
    clipping_command = args.clipping
    is_FI = args.is_FI
    BER = args.BER
    repetition_count = args.repeat
    is_performance = args.is_performance

    # Pruning and hardening are mutually exclusive run modes.
    assert not (is_hardening and is_pruning), \
        "--is-hardening and --is-pruning must not be set at the same time"

    # Build the run-mode name from the active flags (e.g. "test_pruning");
    # it doubles as the log-file tag and the handler-dispatch key below.
    run_mode = "test"
    run_mode += "".join([part for part, condition in [("_channel_ranking", is_ranking),
                                                      ("_pruning", is_pruning),
                                                      ("_hardening", is_hardening),
                                                      ("_FI", is_FI),
                                                      ("_performance", is_performance)]
                         if condition])
    setup_logger = handlers.LogHandler(run_mode, model_name, dataset_name)
    logger = setup_logger.getLogger()

    # Log every argument that is a set boolean flag or a non-None value.
    setup_logger_info = "".join(
        f"{name}: {value}, " for name, value in vars(args).items()
        if (isinstance(value, bool) and value) or
           (not isinstance(value, bool) and value is not None))
    logger.info(f"args: {setup_logger_info}")

    # Prefer the first CUDA device when available, otherwise fall back to CPU.
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    # Load train/test splits and the base (unmodified) CNN.
    trainloader, classes_count, dummy_input = dataset_utils.load_dataset(dataset_name, batch_size, is_train=True)
    testloader, classes_count, dummy_input = dataset_utils.load_dataset(dataset_name, batch_size, is_train=False)
    model = models_utils.load_model(model_name, dataset_name, device)

    if is_pruned:
        # Rebuild the pruned architecture before loading its checkpoint,
        # otherwise the parameter shapes would not match.
        assert pruning_ratio is not None
        assert pruned_checkpoint is not None
        # NOTE(review): pruning_method is passed both as the 2nd and 4th
        # argument here — verify against utils.prune_utils' signature;
        # one of them may be intended to be something else.
        pu = utils.prune_utils(model, pruning_method, classes_count, pruning_method)
        pu.set_pruning_ratios(pruning_ratio)
        model = pu.homogeneous_prune(model)
        models_utils.load_params(model, pruned_checkpoint, device)

    if is_hardened:
        # Rebuild the hardened architecture (clipped ReLUs + replaced convs)
        # before loading its checkpoint.
        assert hardening_ratio is not None
        assert hardened_checkpoint is not None
        assert clipping_command is not None
        clippingHandler = handlers.ClippingHandler(logger)
        clippingHandler.register("ranger", clipping.Ranger_thresholds)
        hr = utils.hardening_utils(hardening_ratio, clipping_command)
        hr.thresholds_extraction(model, clippingHandler, clipping_command, trainloader, device, logger)
        # NOTE(review): the return value of relu_replacement is bound to an
        # unused name while the next line rebinds `model` from
        # conv_replacement. If relu_replacement does NOT mutate `model` in
        # place, its result is silently discarded here — confirm intent.
        hardened_model = hr.relu_replacement(model)
        model = hr.conv_replacement(model)
        models_utils.load_params(model, hardened_checkpoint, device)

    dummy_input = dummy_input.to(device)
    model = model.to(device)

    # Register one callable per run mode; the key must match the run_mode
    # string assembled from the flags above.
    runModeHandler = handlers.RunModeHandler(logger)
    runModeHandler.register("test", run_mode_utils.test_func)
    runModeHandler.register("test_pruning", run_mode_utils.pruning_func)
    runModeHandler.register("test_hardening", run_mode_utils.hardening_func)
    runModeHandler.register("test_FI", run_mode_utils.weights_FI_simulation)
    runModeHandler.register("test_channel_ranking", run_mode_utils.channel_ranking_func)
    runModeHandler.register("test_performance", run_mode_utils.performance_func)

    # Dispatch; each mode first asserts the extra arguments it requires.
    if run_mode == "test":
        runModeHandler.execute(run_mode, model, testloader, device, dummy_input, logger)
    elif run_mode == "test_pruning":
        assert importance_command is not None
        assert pruning_ratio is not None
        runModeHandler.execute(run_mode, model, trainloader, testloader, classes_count, dummy_input,
                               pruning_method, device, pruning_ratio, importance_command, logger)
    elif run_mode == "test_hardening":
        assert importance_command is not None
        assert hardening_ratio is not None
        assert clipping_command is not None
        runModeHandler.execute(run_mode, model, trainloader, testloader, dummy_input, classes_count,
                               pruning_method, pruning_ratio, hardening_ratio, importance_command,
                               clipping_command, device, logger)
    elif run_mode == "test_FI":
        assert BER is not None
        assert repetition_count is not None
        runModeHandler.execute(run_mode, model, testloader, repetition_count, BER, classes_count, device, logger)
    elif run_mode == "test_channel_ranking":
        assert importance_command is not None
        runModeHandler.execute(run_mode, model, trainloader, importance_command, classes_count, logger, device)
    elif run_mode == "test_performance":
        runModeHandler.execute(run_mode, model, dummy_input, logger)

    #TODO: iterative pruning + refining