-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
83 lines (65 loc) · 3.46 KB
/
main.py
File metadata and controls
83 lines (65 loc) · 3.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
import time
import sys
import getopt
import os
import dataset_utils
import analysis_utils
import CNNs
if __name__ == "__main__":
arg_list = sys.argv[1:]
short_options = "m:c:b:d:r"
long_options = ["model=", "channel-sampling=", "batch-size=", "dataset=", "run-mode="]
try:
arguments, values = getopt.getopt(arg_list, short_options, long_options)
for arg, val in arguments:
print(arg, val)
if arg in ["-m", "--model"]:
network_name = str(val)
elif arg in ["-d", "--dataset"]:
dataset_name = str(val)
elif arg in ["-b", "--batch-size"]:
batch_size = int(val)
elif arg in ["r", "--run-mode"]: #"sampling-analysis", "full-analysis-weight", "sampling-analysis-act", "sampling-analysis-weight"
run_mode = str(val)
elif arg in ["c", "--channel-sampling"]:
ch_sampling_ratio = float(val)
except:
raise Exception("parameters are not specified correctly!")
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
#load dataset and CNN model
is_train = False
dataloader, classes_count = dataset_utils.load_dataset(dataset_name, is_train, batch_size)
net = CNNs.load_model(network_name, dataset_name, device)
#create log file
log_direction = network_name + "-" + dataset_name + "/" + run_mode
if not os.path.exists(log_direction):
os.makedirs(log_direction)
log_file_name = log_direction + "/log-" + run_mode + "-" + network_name + "-" + dataset_name + ".txt"
log_file = open(log_file_name, 'w')
log_file.close()
#performing based on run_mode
for data in dataloader:
images, labels = data[0].to(device), data[1].to(device)
if run_mode == "full-analysis-weight":
au = analysis_utils.analysis_utils(net, network_name, images, dataset_name, batch_size, classes_count, log_file_name, log_direction, device)
au.full_weights_analysis(net)
elif run_mode == "full-analysis-act":
au = analysis_utils.analysis_utils(net, network_name, images, dataset_name, batch_size, classes_count, log_file_name, log_direction, device)
au.full_activations_analysis(net)
elif run_mode == "sampling-analysis-act":
log_file_name = log_direction + "/log-" + run_mode + "-" + network_name + "-" + dataset_name + "-" + str(ch_sampling_ratio) + ".txt"
log_file = open(log_file_name, 'w')
log_file.close()
au = analysis_utils.analysis_utils(net, network_name, images, dataset_name, batch_size, classes_count, log_file_name, log_direction, device, channel_sampling_ratio=ch_sampling_ratio)
au.sampling_analysis_act(net)
elif run_mode == "sampling-analysis-weight":
log_file_name = log_direction + "/log-" + run_mode + "-" + network_name + "-" + dataset_name + "-" + str(ch_sampling_ratio) + ".txt"
log_file = open(log_file_name, 'w')
log_file.close()
au = analysis_utils.analysis_utils(net, network_name, images, dataset_name, batch_size, classes_count, log_file_name, log_direction, device, channel_sampling_ratio=ch_sampling_ratio)
au.sampling_analysis_weight(net)
break