forked from twosixlabs/armory-example
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpatch_loss_gradient.py
More file actions
38 lines (28 loc) · 1.36 KB
/
patch_loss_gradient.py
File metadata and controls
38 lines (28 loc) · 1.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
from torchvision.transforms import RandomErasing
from torch.autograd import Variable
from art.attacks.evasion import ProjectedGradientDescent
from armory.utils.evaluation import patch_method
from patch_loss_gradient_model import get_art_model
# Select GPU when available, otherwise fall back to CPU.
# NOTE(review): DEVICE is not referenced in the visible code — presumably
# consumed by the companion model module or a later chunk; confirm before removing.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class CustomAttack(ProjectedGradientDescent):
    """PGD attack whose loss gradient is estimated over many randomly-erased
    copies of the input (an Expectation-over-Transformation style gradient).

    The incoming estimator's weights are copied into a fresh model so that
    patching ``loss_gradient_framework`` does not mutate the caller's estimator.
    """

    def __init__(self, estimator, **kwargs):
        """Build the attack around a weight-copied clone of ``estimator``.

        Parameters
        ----------
        estimator : ART PyTorch estimator whose model weights are cloned.
        **kwargs : forwarded to ``ProjectedGradientDescent``. Additionally
            accepts ``num_samples`` (int, default 100): number of
            randomly-erased copies used to estimate the loss gradient.
            Popped before forwarding, so the parent never sees it —
            backward compatible with callers that omit it.
        """
        num_samples = kwargs.pop("num_samples", 100)

        # Create a copy of the model (avoids overwriting loss_gradient_framework
        # of the original estimator passed in by the caller).
        new_estimator = get_art_model(model_kwargs={}, wrapper_kwargs={})
        new_estimator.model.load_state_dict(estimator.model.state_dict())

        # Point the attack at the copy, not the caller's estimator.
        super().__init__(new_estimator, **kwargs)

        @patch_method(new_estimator)
        def loss_gradient_framework(
            self, x: "torch.Tensor", y: "torch.Tensor", **kwargs
        ) -> "torch.Tensor":
            """Return d(loss)/dx averaged over random-erasing transforms."""
            # Modern replacement for the deprecated torch.autograd.Variable;
            # gradients accumulate on x_var.grad exactly as before.
            x_var = x.clone().detach().requires_grad_(True)
            # NOTE(review): assumes y is a single-example one-hot label, so a
            # global argmax yields its class index — confirm against callers.
            y_cat = torch.argmax(y)
            transform = RandomErasing(p=1.0, scale=(0.5, 0.5))
            # Stack num_samples independently-erased copies of the one input.
            x_mod = torch.stack(
                [transform(x_var[0]) for _ in range(num_samples)], dim=0
            )
            # Call the module directly (not .forward) so registered hooks run.
            logits = self.model.net(x_mod)
            loss = self._loss(logits, y_cat.repeat(num_samples))
            self._model.zero_grad()
            loss.backward()
            return x_var.grad