Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions mart/attack/gradient_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
# SPDX-License-Identifier: BSD-3-Clause
#

from __future__ import annotations

import abc
from typing import Union
from typing import Iterable

import torch

Expand All @@ -15,25 +17,33 @@
class GradientModifier(abc.ABC):
    """Gradient modifier base class.

    Implementations mutate ``p.grad`` of the given parameter(s) in place
    and return ``None`` (they do not return a modified gradient tensor).
    """

    @abc.abstractmethod
    def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
        """Modify gradients in place.

        Args:
            parameters: A single tensor or an iterable of tensors whose
                ``.grad`` attributes will be modified in place.
        """
        pass


class Sign(GradientModifier):
    """Replace each parameter's gradient with its elementwise sign, in place."""

    def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
        """Apply ``sign()`` to ``p.grad`` of every parameter.

        Args:
            parameters: A single tensor or an iterable of tensors whose
                ``.grad`` attributes are modified in place.
        """
        # Accept a bare tensor as a convenience and normalize to a list.
        if isinstance(parameters, torch.Tensor):
            parameters = [parameters]

        # Parameters without a gradient (e.g. nothing backpropagated yet)
        # are left untouched.
        parameters = [p for p in parameters if p.grad is not None]

        for p in parameters:
            # detach() returns a view sharing storage with p.grad, so the
            # in-place sign_() updates the gradient itself while keeping the
            # operation out of the autograd graph.
            p.grad.detach().sign_()


class LpNormalizer(GradientModifier):
    """Scale each parameter's gradient by its L-p norm, in place."""

    def __init__(self, p: int | float):
        """
        Args:
            p: Order of the norm, forwarded to ``torch.norm``.
        """
        self.p = p

    def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
        """Divide ``p.grad`` of every parameter by its L-p norm.

        Args:
            parameters: A single tensor or an iterable of tensors whose
                ``.grad`` attributes are modified in place.
        """
        # Accept a bare tensor as a convenience and normalize to a list.
        if isinstance(parameters, torch.Tensor):
            parameters = [parameters]

        # Skip parameters that have no gradient.
        parameters = [p for p in parameters if p.grad is not None]

        for p in parameters:
            p_norm = torch.norm(p.grad.detach(), p=self.p)
            # NOTE(review): an all-zero gradient makes p_norm == 0, so div_
            # produces NaNs — confirm callers never rely on that case.
            p.grad.detach().div_(p_norm)
11 changes: 10 additions & 1 deletion mart/attack/perturber/perturber.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# SPDX-License-Identifier: BSD-3-Clause
#

from collections import namedtuple
from typing import Any, Dict, Optional, Union

import torch
Expand Down Expand Up @@ -71,7 +72,15 @@ def on_run_start(self, *, adversary, input, target, model, **kwargs):

# A backward hook that will be called when a gradient w.r.t the Tensor is computed.
if self.gradient_modifier is not None:
self.perturbation.register_hook(self.gradient_modifier)

def gradient_modifier(grad):
# Create fake tensor with cloned grad so we can use in-place operations
FakeTensor = namedtuple("FakeTensor", ["grad"])
param = FakeTensor(grad=grad.clone())
self.gradient_modifier([param])
return param.grad

self.perturbation.register_hook(gradient_modifier)

self.initializer(self.perturbation)

Expand Down
40 changes: 40 additions & 0 deletions tests/test_adversary.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import mart
from mart.attack import Adversary
from mart.attack.gradient_modifier import Sign
from mart.attack.perturber import Perturber


Expand Down Expand Up @@ -145,3 +146,42 @@ def model(input, target, model=None, **kwargs):
# Simulate a new batch of data of different size.
new_input_data = torch.cat([input_data, input_data])
output3 = adversary(new_input_data, target_data, model=model)


def test_adversary_gradient(input_data, target_data):
    """End-to-end check that a Sign modifier yields a perturbation of exactly {-1, 0, 1}."""
    composer = mart.attack.composer.Additive()
    enforcer = Mock()
    optimizer = partial(SGD, lr=1.0, maximize=True)

    # A gain with zero, positive and negative components forces gradients of
    # all three signs across the batch.
    def gain(logits):
        zero_term = (0 * logits[0, :, :]).mean()
        positive_term = (0.1 * logits[1, :, :]).mean()
        negative_term = (-0.1 * logits[2, :, :]).mean()
        return zero_term + positive_term + negative_term

    # Start from an all-zero perturbation.
    def initializer(x):
        torch.nn.init.constant_(x, 0)

    perturber = Perturber(initializer, Sign())

    adversary = Adversary(
        composer=composer,
        enforcer=enforcer,
        perturber=perturber,
        optimizer=optimizer,
        max_iters=1,
        gain=gain,
    )

    def model(input, target, model=None, **kwargs):
        return {"logits": adversary(input, target)}

    # First call runs the attack; second call applies the learned perturbation.
    adversary(input_data, target_data, model=model)
    input_adv = adversary(input_data, target_data)

    perturbation = input_data - input_adv

    # One SGD step with lr=1 on sign-modified gradients moves each pixel by
    # exactly -1, 0 or +1.
    torch.testing.assert_close(perturbation.unique(), torch.Tensor([-1, 0, 1]))
26 changes: 17 additions & 9 deletions tests/test_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,23 @@


def test_gradient_sign(input_data):
    """Sign() must replace a tensor's gradient with its elementwise sign, in place."""
    # Don't share input_data with other tests, because the gradient would be changed.
    # NOTE(review): the fixture parameter is immediately shadowed and unused;
    # it could be dropped from the signature like test_gradient_lp_normalizer.
    input_data = torch.tensor([1.0, 2.0, 3.0])
    input_data.grad = torch.tensor([-1.0, 3.0, 0.0])

    grad_modifier = Sign()
    grad_modifier(input_data)
    expected_grad = torch.tensor([-1.0, 1.0, 0.0])
    torch.testing.assert_close(input_data.grad, expected_grad)


def test_gradient_lp_normalizer():
    """LpNormalizer(p=1) must divide a tensor's gradient by its L1 norm, in place."""
    # Don't share input_data with other tests, because the gradient would be changed.
    input_data = torch.tensor([1.0, 2.0, 3.0])
    input_data.grad = torch.tensor([-1.0, 3.0, 0.0])

    p = 1
    grad_modifier = LpNormalizer(p)
    grad_modifier(input_data)
    # |−1| + |3| + |0| = 4, so the gradient is divided by 4.
    expected_grad = torch.tensor([-0.25, 0.75, 0.0])
    torch.testing.assert_close(input_data.grad, expected_grad)