-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_deepfool.py
More file actions
62 lines (48 loc) · 1.77 KB
/
test_deepfool.py
File metadata and controls
62 lines (48 loc) · 1.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import torch.utils.data as data_utils
from torch.autograd import Variable
import math
import torchvision.models as models
from PIL import Image
from deepfool import deepfool
import os
# Directory holding the test image and the label file; edit for your setup.
DATA_DIR = 'C:/vspython/DeepFool-master/Python'

net = models.resnet34(pretrained=True)
# Switch to evaluation mode
net.eval()

# convert('RGB') guarantees three channels, matching the 3-element mean/std
# passed to Normalize below (a grayscale or RGBA file would otherwise fail).
im_orig = Image.open(os.path.join(DATA_DIR, 'test_im2.jpg')).convert('RGB')

# Per-channel mean/std used to normalize the input; reused later to undo it.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Resize, crop to 224x224, convert to a tensor and remove the mean.
# transforms.Resize replaces the deprecated (and since removed) transforms.Scale.
im = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])(im_orig)

r, loop_i, label_orig, label_pert, pert_image = deepfool(im, net)

# One label per line; the context manager closes the file handle.
with open(os.path.join(DATA_DIR, 'synset_words.txt'), 'r') as label_file:
    labels = label_file.read().split('\n')

# np.int was removed in NumPy 1.24; the builtin int() works whether deepfool
# returns a Python or a NumPy integer. Only the text before the first comma
# of the label line is kept.
str_label_orig = labels[int(label_orig)].split(',')[0]
str_label_pert = labels[int(label_pert)].split(',')[0]
print("Original label = ", str_label_orig)
print("Perturbed label = ", str_label_pert)
def clip_tensor(A, minv, maxv):
    """Return a copy of tensor ``A`` with every element clamped to [minv, maxv].

    Uses ``torch.clamp``, so the result keeps ``A``'s dtype and device and no
    full-size bound tensors are allocated. Bounds built with
    ``torch.ones(A.shape)`` would always be CPU float32 and fail against
    CUDA tensors.

    Args:
        A: Input tensor; it is not modified.
        minv: Lower bound.
        maxv: Upper bound. If ``minv > maxv``, every element becomes ``maxv``,
            the same as applying the lower bound first and then the upper.

    Returns:
        A new tensor with the same shape as ``A``.
    """
    return torch.clamp(A, min=minv, max=maxv)
def clip(x):
    """Clamp a de-normalized image tensor to the valid [0, 1] pixel range."""
    # ToTensor() produced values in [0, 1], so that is the range to restore
    # after undoing Normalize. A bound of 255 would leave out-of-range values
    # in place, and they wrap around when ToPILImage scales by 255 to uint8.
    return clip_tensor(x, 0, 1)


# Invert Normalize(mean, std): multiply by std, then add the mean back.
# Plain lists (not map objects) so the values are tensor-convertible and
# are not exhausted after the first call.
inv_std = [1 / s for s in std]
neg_mean = [-m for m in mean]
tf = transforms.Compose([
    transforms.Normalize(mean=[0, 0, 0], std=inv_std),
    transforms.Normalize(mean=neg_mean, std=[1, 1, 1]),
    transforms.Lambda(clip),
    transforms.ToPILImage(),
    transforms.CenterCrop(224),
])

plt.figure()
# detach() is a no-op for plain tensors and keeps ToPILImage's numpy
# conversion safe if a grad-tracking tensor is returned. [0] drops the
# leading axis; this assumes pert_image is (1, C, H, W). TODO: confirm
# against deepfool().
plt.imshow(tf(pert_image.detach().cpu()[0]))
plt.title(str_label_pert)
plt.show()