-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDeepVigor_analysis.py
More file actions
282 lines (231 loc) · 16 KB
/
DeepVigor_analysis.py
File metadata and controls
282 lines (231 loc) · 16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
import torch
import torch.nn as nn
import torch.nn.functional as F
import DeepVigor_utils
import math
from typing import Dict, Union
class DeepVigor_analysis():
    """Channel-level fault-vulnerability analysis (DeepVigor) for a CNN.

    Holds the model under analysis and the compute device. The analysis
    methods inject per-channel fault values into a layer's output through
    forward hooks and measure how large a fault can grow before the model's
    predictions change.
    """

    def __init__(self,
                 model: nn.Module,
                 device: Union[torch.device, str]) -> None:
        """Remember the trained model and the device the analysis runs on."""
        self.device = device
        self.model = model
def delta_injection_channel(self, channel, value):
    """Build a forward hook that adds per-image fault deltas to one output channel.

    `value` is a 1-D tensor holding one delta per image in the batch; it is
    broadcast across the channel's spatial dimensions and added in place to
    the layer output. The corrupted output tensor has its gradient retained
    and is cached on ``self.activation`` so the caller can read
    ``self.activation.grad`` after a backward pass.
    """
    def hook(module, inputs, out):
        batch, height, width = out.size(0), out.size(2), out.size(3)
        delta_map = value.unsqueeze(1).unsqueeze(2).expand(batch, height, width)
        out[:, channel] += delta_map
        out.retain_grad()
        self.activation = out
    return hook
def channels_vulnerability_factor(self,
                                  images: torch.Tensor,
                                  layer: nn.Module,
                                  layer_info_set: Dict,
                                  out_no: int) -> torch.Tensor:
    """Estimate a per-output-channel vulnerability factor for `layer`.

    For each output channel: (1) sample a handful of neurons and average
    their fault-value error distributions (via DeepVigor_utils); (2) run a
    batched binary search over candidate injected deltas, separately toward
    negative and positive values, to find the smallest fault magnitude that
    flips the model's fault-free predictions; (3) map the found deltas back
    into the error distribution to get the probability mass of NON-critical
    fault values. Returns 1 - non-criticality per channel (higher = more
    vulnerable).

    Args:
        images: input batch fed to the model (shape/dtype as the model
            expects — presumably (batch, C, H, W); TODO confirm with caller).
        layer: the convolutional layer whose output channels are analyzed;
            forward hooks are registered on it.
        layer_info_set: precomputed description of `layer`. Keys read here:
            "neurons_in_layer", "stride", "kernel_size", "layer_inputs"
            (padded layer inputs), "layer_weights", "batch_size",
            "out_channel".
        out_no: number of classes (used for one-hot encoding).

    Returns:
        1-D tensor of length out_channel with each channel's vulnerability.

    Side effects: registers/removes forward hooks on `layer`, overwrites
    `self.activation`, runs backward passes, and calls
    `torch.cuda.empty_cache()` repeatedly to bound GPU memory.
    """
    # Unpack the precomputed layer description.
    features_count = layer_info_set["neurons_in_layer"]
    stride = layer_info_set["stride"]
    kernel_size = layer_info_set["kernel_size"]
    x_pad = layer_info_set["layer_inputs"]
    cnv_weights = layer_info_set["layer_weights"]
    batch_size = layer_info_set["batch_size"]
    out_channel_count = layer_info_set["out_channel"]
    resolution = 10                  # log2 granularity of the fault-value space
    inf_represent = 2 ** resolution  # finite stand-in for "unbounded" delta
    neurons_in_channel = features_count // out_channel_count
    # assumes square feature maps — TODO confirm
    fmap_width = int(math.sqrt(neurons_in_channel))
    self.activation = torch.tensor([])
    non_crit_channels = torch.zeros(out_channel_count, device=self.device)
    # Fault-free forward pass: its argmax labels are the reference
    # classifications that faults are compared against.
    last_layer_out = self.model.forward(images)
    _, detected_labels = torch.max(last_layer_out, 1)
    one_hots = torch.unsqueeze(F.one_hot(detected_labels, num_classes=out_no), 1)
    # Sample about log2(neurons_in_channel) neurons per channel, at least 1.
    neurons_samples = torch.max(torch.tensor([int(torch.log2(torch.tensor([neurons_in_channel]))), 1])).item()
    del last_layer_out
    for channel in range(out_channel_count):
        # Accumulator over the discretized fault-value space (4*res+3 bins).
        errors_dist_weight_channel = torch.zeros(4 * resolution + 3, device=self.device)
        neurons_set = torch.tensor([])
        nrn_counter = 0
        # Draw distinct random neurons of this channel and accumulate the
        # error distribution of each one's receptive field.
        while nrn_counter < neurons_samples:
            rand_neuron = torch.randint(neurons_in_channel, (1,)).item()
            if rand_neuron not in neurons_set:
                nrn_counter += 1
                neurons_set = torch.cat((neurons_set, torch.tensor([rand_neuron])))
                neuron_weights = cnv_weights[channel].unsqueeze(0)
                # Map the flat neuron index to its output (row, col), then to
                # the top-left corner of its input window.
                output_ind_row = rand_neuron // fmap_width
                output_ind_col = rand_neuron % fmap_width
                input_ind_row = output_ind_row * stride
                input_ind_col = output_ind_col * stride
                sliced_inputs = x_pad[:, :,
                                      input_ind_row : input_ind_row + kernel_size,
                                      input_ind_col : input_ind_col + kernel_size]
                errors_dist_weight_neuron = DeepVigor_utils.vulnerability_values_space_weight(sliced_inputs, neuron_weights, self.device)
                errors_dist_weight_channel += errors_dist_weight_neuron
                del errors_dist_weight_neuron
                del sliced_inputs
        del neurons_set
        # analysis for faulty weights: average over the sampled neurons
        errors_dist_weight_channel = errors_dist_weight_channel / nrn_counter
        VVSS_dict_weights = DeepVigor_utils.creating_VVSS_dict(errors_dist_weight_channel, resolution, self.device)
        # ---- finding deltas in negative numbers ----
        # Probe with delta = -1 and use gradients of the DeepVigor loss to
        # tell which images can be affected by this channel at all.
        dlt_search_l = -torch.ones(batch_size, device=self.device)
        self.activation = torch.tensor([])
        handle = layer.register_forward_hook(self.delta_injection_channel(channel, dlt_search_l))
        corrupted_out = self.model(images)
        _, corrupted_labels = torch.max(corrupted_out, 1)
        handle.remove()
        # Margin-style loss: sigmoid of (score of reference class - scores).
        loss_deepvigor = (torch.sum(torch.sigmoid(torch.unsqueeze(torch.sum(corrupted_out * one_hots, 1), 1) - corrupted_out))) / batch_size
        loss_deepvigor.backward()
        # Reduce the activation gradient over spatial dims; nonzero means
        # this channel influences the output for that image.
        channel_grad = torch.sum(torch.sum(self.activation.grad.data, 3), 2).detach()
        channel_grad[channel_grad != 0] = 1
        # True where the channel has zero gradient (fault cannot matter).
        grad_bool_map = torch.eq(channel_grad[:, channel], torch.zeros_like(channel_grad[:, channel], device=self.device))
        del self.activation.grad
        del loss_deepvigor
        del corrupted_out
        torch.cuda.empty_cache()
        if torch.sum(channel_grad, 0)[channel] != 0: #there is some images misclassified by faults
            true_classified = torch.eq(corrupted_labels, detected_labels)
            if torch.sum(true_classified) == batch_size:
                # No image flipped at delta=-1: search magnitudes beyond -1.
                if VVSS_dict_weights['neg_inf'].size(0) <= 1: #all images are misclassified by vulnerability_values < -1
                    dlt_search_l = torch.ones_like(dlt_search_l, device=self.device) * (-inf_represent)
                else:
                    # Batched binary search over the 'neg_inf' candidate set.
                    vector_len = VVSS_dict_weights['neg_inf'].size(0)
                    iteration_count = vector_len // 2
                    index_tensor = torch.ones(batch_size, dtype=torch.int, device=self.device) * iteration_count
                    for _ in range(iteration_count, 0, -1):
                        dlt_search_l = VVSS_dict_weights['neg_inf'][index_tensor]
                        handle = layer.register_forward_hook(self.delta_injection_channel(channel, dlt_search_l))
                        corrupted_out = self.model(images)
                        _, corrupted_labels = torch.max(corrupted_out, 1)
                        true_classified = torch.eq(corrupted_labels, detected_labels)
                        # Misclassified -> step index up; still correct -> step down.
                        index_tensor = torch.logical_not(true_classified) * (index_tensor + 1) + true_classified * (index_tensor - 1)
                        index_tensor[index_tensor >= vector_len] = vector_len - 1
                        index_tensor[index_tensor < 0] = 0
                        handle.remove()
                        del corrupted_out
                        del corrupted_labels
                        del true_classified
                    dlt_search_l = VVSS_dict_weights['neg_inf'][index_tensor]
                    # Images this channel cannot influence get the "infinite" delta.
                    dlt_search_l[grad_bool_map == 1] = -inf_represent
                    torch.cuda.empty_cache()
            else: #images misclassified by vulnerability_values < 0
                # Mixed case: per image, search 'neg_1' (flipped already at -1)
                # or 'neg_inf' (not flipped yet) in one batched loop.
                true_classified_init = torch.clone(torch.logical_not(true_classified))
                vector_len_inf = VVSS_dict_weights['neg_inf'].size(0)
                iteration_count_inf = vector_len_inf // 2
                index_tensor_inf = torch.ones(batch_size, dtype=torch.int, device=self.device) * iteration_count_inf
                vector_len_1 = VVSS_dict_weights['neg_1'].size(0)
                iteration_count_1 = vector_len_1 // 2 #will be bigger value
                index_tensor_1 = torch.ones(batch_size, dtype=torch.int, device=self.device) * iteration_count_1
                for _ in range(iteration_count_1, 0, -1):
                    dlt_search_l = torch.logical_not(true_classified_init) * VVSS_dict_weights['neg_1'][index_tensor_1] + true_classified_init * VVSS_dict_weights['neg_inf'][index_tensor_inf]
                    handle = layer.register_forward_hook(self.delta_injection_channel(channel, dlt_search_l))
                    corrupted_out = self.model(images)
                    _, corrupted_labels = torch.max(corrupted_out, 1)
                    true_classified = torch.eq(corrupted_labels, detected_labels)
                    # Each index tensor only advances for the images assigned to it.
                    index_tensor_inf = true_classified_init * (torch.logical_not(true_classified) * (index_tensor_inf + 1) + true_classified * (index_tensor_inf - 1))
                    index_tensor_1 = torch.logical_not(true_classified_init) * (torch.logical_not(true_classified) * (index_tensor_1 + 1) + true_classified * (index_tensor_1 - 1))
                    index_tensor_inf[index_tensor_inf >= vector_len_inf] = vector_len_inf - 1
                    index_tensor_inf[index_tensor_inf < 0] = 0
                    index_tensor_1[index_tensor_1 >= vector_len_1] = vector_len_1 - 1
                    index_tensor_1[index_tensor_1 < 0] = 0
                    handle.remove()
                    del corrupted_out
                    del corrupted_labels
                    del true_classified
                dlt_search_l = torch.logical_not(true_classified_init) * VVSS_dict_weights['neg_1'][index_tensor_1] + true_classified_init * VVSS_dict_weights['neg_inf'][index_tensor_inf]
                dlt_search_l[grad_bool_map == 1] = -inf_represent
                torch.cuda.empty_cache()
        else:
            # Channel has no gradient for any image: no negative fault is critical.
            dlt_search_l = torch.ones_like(dlt_search_l, device=self.device) * -inf_represent
        # free memory before the positive-direction search
        del channel_grad
        del grad_bool_map
        torch.cuda.empty_cache()
        # ---- finding deltas in positive numbers (mirror of the above) ----
        dlt_search_h = torch.ones(batch_size, device=self.device)
        self.activation = torch.tensor([])
        handle = layer.register_forward_hook(self.delta_injection_channel(channel, dlt_search_h))
        corrupted_out = self.model(images)
        _, corrupted_labels = torch.max(corrupted_out, 1)
        handle.remove()
        loss_deepvigor = (torch.sum(torch.sigmoid(torch.unsqueeze(torch.sum(corrupted_out * one_hots, 1), 1) - corrupted_out))) / batch_size
        loss_deepvigor.backward()
        channel_grad = torch.sum(torch.sum(self.activation.grad.data, 3), 2).detach()
        channel_grad[channel_grad != 0] = 1
        grad_bool_map = torch.eq(channel_grad[:, channel], torch.zeros_like(channel_grad[:, channel], device=self.device))
        del self.activation.grad
        del loss_deepvigor
        del corrupted_out
        torch.cuda.empty_cache()
        if torch.sum(channel_grad, 0)[channel] != 0: #there is some images misclassified by faults
            true_classified = torch.eq(corrupted_labels, detected_labels)
            if torch.sum(true_classified) == batch_size: #all images are misclassified by vulnerability_values > 1
                if VVSS_dict_weights['pos_inf'].size(0) <= 1:
                    dlt_search_h = torch.ones_like(dlt_search_h, device=self.device) * inf_represent
                else:
                    vector_len = VVSS_dict_weights['pos_inf'].size(0)
                    iteration_count = vector_len // 2
                    index_tensor = torch.ones(batch_size, dtype=torch.int, device=self.device) * iteration_count
                    for _ in range(iteration_count, 0, -1):
                        dlt_search_h = VVSS_dict_weights['pos_inf'][index_tensor]
                        handle = layer.register_forward_hook(self.delta_injection_channel(channel, dlt_search_h))
                        corrupted_out = self.model(images)
                        _, corrupted_labels = torch.max(corrupted_out, 1)
                        true_classified = torch.eq(corrupted_labels, detected_labels)
                        # NOTE: step direction is inverted w.r.t. the negative search.
                        index_tensor = torch.logical_not(true_classified) * (index_tensor - 1) + true_classified * (index_tensor + 1)
                        index_tensor[index_tensor >= vector_len] = vector_len - 1
                        index_tensor[index_tensor < 0] = 0
                        handle.remove()
                        del true_classified
                        del corrupted_labels
                        del corrupted_out
                    dlt_search_h = VVSS_dict_weights['pos_inf'][index_tensor]
                    dlt_search_h[grad_bool_map == 1] = inf_represent
                    torch.cuda.empty_cache()
            else:
                vector_len_inf = VVSS_dict_weights['pos_inf'].size(0)
                vector_len_1 = VVSS_dict_weights['pos_1'].size(0)
                iteration_count_inf = vector_len_inf // 2
                iteration_count_1 = vector_len_1 // 2 #will be bigger value
                index_tensor_inf = torch.ones(batch_size, dtype=torch.int, device=self.device) * iteration_count_inf
                index_tensor_1 = torch.ones(batch_size, dtype=torch.int, device=self.device) * iteration_count_1
                # NOTE(review): here the init mask is NOT negated, unlike the
                # negative branch (torch.clone(torch.logical_not(...))) — verify
                # this asymmetry is intended.
                true_classified_init = torch.clone(true_classified)
                for _ in range(iteration_count_1, 0, -1):
                    dlt_search_h = torch.logical_not(true_classified_init) * VVSS_dict_weights['pos_1'][index_tensor_1] + true_classified_init * VVSS_dict_weights['pos_inf'][index_tensor_inf]
                    handle = layer.register_forward_hook(self.delta_injection_channel(channel, dlt_search_h))
                    corrupted_out = self.model(images)
                    _, corrupted_labels = torch.max(corrupted_out, 1)
                    true_classified = torch.eq(corrupted_labels, detected_labels)
                    index_tensor_inf = true_classified_init * (torch.logical_not(true_classified) * (index_tensor_inf - 1) + true_classified * (index_tensor_inf + 1))
                    index_tensor_1 = torch.logical_not(true_classified_init) * (torch.logical_not(true_classified) * (index_tensor_1 - 1) + true_classified * (index_tensor_1 + 1))
                    index_tensor_inf[index_tensor_inf >= vector_len_inf] = vector_len_inf - 1
                    index_tensor_inf[index_tensor_inf < 0] = 0
                    index_tensor_1[index_tensor_1 >= vector_len_1] = vector_len_1 - 1
                    index_tensor_1[index_tensor_1 < 0] = 0
                    handle.remove()
                    del corrupted_labels
                    del true_classified
                    del corrupted_out
                dlt_search_h = torch.logical_not(true_classified_init) * VVSS_dict_weights['pos_1'][index_tensor_1] + true_classified_init * VVSS_dict_weights['pos_inf'][index_tensor_inf]
                dlt_search_h[grad_bool_map == 1] = inf_represent
                torch.cuda.empty_cache()
        else:
            dlt_search_h = torch.ones_like(dlt_search_h, device=self.device) * inf_represent
        # ---- NVF calculation ----
        # Convert the found deltas to log2 powers, then index the error
        # distribution to read off the non-critical probability mass.
        # NOTE(review): both zeros_like assignments below are dead stores,
        # immediately overwritten by the following lines.
        negative_vulnerability_powers = torch.zeros_like(dlt_search_l, device=self.device)
        negative_vulnerability_powers = torch.floor(torch.log2(torch.abs(dlt_search_l))).int()
        positive_vulnerability_powers = torch.zeros_like(dlt_search_h, device=self.device)
        positive_vulnerability_powers = torch.floor(torch.log2(dlt_search_h)).int()
        # Clamp to the distribution's top bin (resolution = 10).
        positive_vulnerability_powers[positive_vulnerability_powers > 10] = 10
        lower_bound_criticality = errors_dist_weight_channel[10 - negative_vulnerability_powers]
        upper_bound_criticality = errors_dist_weight_channel[22 + 10 + positive_vulnerability_powers]
        noncriticality_channel = torch.sum(upper_bound_criticality - lower_bound_criticality)
        non_crit_channels[channel] = noncriticality_channel / batch_size
        # free memory
        del channel_grad
        del grad_bool_map
        torch.cuda.empty_cache()
    del images
    del x_pad
    torch.cuda.empty_cache()
    # Vulnerability = 1 - non-criticality, per channel.
    return 1 - non_crit_channels