-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathattacks.py
More file actions
42 lines (36 loc) · 1.45 KB
/
attacks.py
File metadata and controls
42 lines (36 loc) · 1.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import ipdb
import torch
import torch.nn as nn
from utils_sphere import normalize, get_norm
class pgd_sphere(object):
    """Projected gradient ascent constrained to the sphere through ``x``, with a random start.

    Each step moves the perturbation ``delta`` by ``eps`` along the normalized
    gradient of ``loss_fn`` (i.e. it *increases* the loss). It then rescales
    ``x + delta`` so that its norm equals ``get_norm(x)``.

    There is no projection back onto an eps-ball around ``x``. ``eps`` is the
    length of each step and of the random initial perturbation; it is not a
    bound on the total perturbation.

    NOTE(review): assumes ``utils_sphere.get_norm`` / ``normalize`` compute the
    norm / unit vector (presumably per sample) with shapes that broadcast
    against ``x`` -- confirm against utils_sphere.

    Attack parameters (keyword arguments to the constructor; unknown keys are
    ignored):
        eps: step length and initial perturbation length (default 0.01).
        num_iter: number of ascent steps (default 50).
        loss_fn: loss to maximize, called as ``loss_fn(model(x + delta), y)``
            (default ``nn.BCEWithLogitsLoss()``).
    """

    def __init__(self, **kwargs):
        # Default attack parameters. A fresh loss module is built per instance.
        self.param = {'eps': 0.01,
                      'num_iter': 50,
                      'loss_fn': nn.BCEWithLogitsLoss()}
        # Override the defaults with any matching user-supplied parameters.
        self.parse_param(**kwargs)

    def generate(self, model, x, y):
        """Return a perturbation ``delta`` with ``x + delta`` on the sphere of radius ``get_norm(x)``.

        Args:
            model: module whose output is fed to ``loss_fn``. It must support
                ``zero_grad()``, which is called every step.
            x: input tensor. It is not modified.
            y: targets passed as the second argument of ``loss_fn``.

        Returns:
            The perturbation tensor, detached from the autograd graph.

        Side effects:
            Each step calls ``model.zero_grad()`` and then back-propagates
            through ``model``. On return, the parameters' ``.grad`` therefore
            hold the gradients of the final step's loss.
        """
        eps = self.param['eps']
        num_iter = self.param['num_iter']
        loss_fn = self.param['loss_fn']
        r = get_norm(x)

        # The random start and all projections are pure tensor updates. Run
        # them without autograd so no throwaway graph is built for them.
        with torch.no_grad():
            # Uniform noise in [-1, 1), rescaled to length eps.
            init = eps * normalize(2 * torch.rand_like(x) - 1)
            # Project x + init onto the sphere of radius r.
            delta = r * normalize(x + init) - x
        # delta was created under no_grad, so it is a leaf and can track gradients.
        delta.requires_grad_(True)

        for _ in range(num_iter):
            model.zero_grad()
            loss = loss_fn(model(x + delta), y)
            loss.backward()
            with torch.no_grad():
                # Ascent step of fixed length eps along the gradient direction.
                delta += eps * normalize(delta.grad)
                # Re-project x + delta onto the sphere of radius r.
                delta.copy_(r * normalize(x + delta) - x)
            delta.grad.zero_()
        return delta.detach()

    def parse_param(self, **kwargs):
        """Override entries of ``self.param`` with matching keyword arguments.

        Keys that are not already in ``self.param`` are silently ignored.
        """
        for key, value in kwargs.items():
            if key in self.param:
                self.param[key] = value