diff --git a/mart/attack/gradient_modifier.py b/mart/attack/gradient_modifier.py
index fcb9b0db..dd680a95 100644
--- a/mart/attack/gradient_modifier.py
+++ b/mart/attack/gradient_modifier.py
@@ -4,8 +4,10 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 
+from __future__ import annotations
+
 import abc
-from typing import Union
+from typing import Iterable
 
 import torch
 
@@ -15,25 +17,40 @@
 class GradientModifier(abc.ABC):
     """Gradient modifier base class."""
 
-    @abc.abstractmethod
-    def __call__(self, grad: torch.Tensor) -> torch.Tensor:
+    def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
+        """Modify the gradients of ``parameters`` in place; the base class does nothing."""
         pass
 
 
 class Sign(GradientModifier):
-    def __call__(self, grad: torch.Tensor) -> torch.Tensor:
-        return grad.sign()
+    """Replace gradients by their element-wise sign."""
+
+    def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
+        if isinstance(parameters, torch.Tensor):
+            parameters = [parameters]
+
+        # Tensors that never received a gradient are left untouched.
+        parameters = [p for p in parameters if p.grad is not None]
+
+        for p in parameters:
+            p.grad.detach().sign_()
 
 
 class LpNormalizer(GradientModifier):
     """Scale gradients by a certain L-p norm."""
 
-    def __init__(self, p: Union[int, float]):
-        super().__init__
-
+    def __init__(self, p: int | float):
         self.p = p
 
-    def __call__(self, grad: torch.Tensor) -> torch.Tensor:
-        grad_norm = grad.norm(p=self.p)
-        grad_normalized = grad / grad_norm
-        return grad_normalized
+    def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
+        if isinstance(parameters, torch.Tensor):
+            parameters = [parameters]
+
+        # Tensors that never received a gradient are left untouched.
+        parameters = [p for p in parameters if p.grad is not None]
+
+        for p in parameters:
+            grad = p.grad.detach()
+            p_norm = torch.norm(grad, p=self.p)
+            # An all-zero gradient has zero norm: clamp so that 0 / 0 stays 0, not NaN.
+            grad.div_(p_norm.clamp_min(torch.finfo(grad.dtype).tiny))
diff --git a/mart/attack/perturber/perturber.py b/mart/attack/perturber/perturber.py
index d7eed81f..a34a5d39 100644
--- a/mart/attack/perturber/perturber.py
+++ b/mart/attack/perturber/perturber.py
@@ -4,6 +4,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 
+from collections import namedtuple
 from typing import Any, Dict, Optional, Union
 
 import torch
@@ -71,7 +72,16 @@ def on_run_start(self, *, adversary, input, target, model, **kwargs):
 
         # A backward hook that will be called when a gradient w.r.t the Tensor is computed.
         if self.gradient_modifier is not None:
-            self.perturbation.register_hook(self.gradient_modifier)
+            # Stand-in for a parameter: gradient modifiers only touch ``.grad``, in place.
+            FakeTensor = namedtuple("FakeTensor", ["grad"])
+
+            def gradient_modifier(grad):
+                # Clone so the in-place modifier never mutates the tensor autograd handed us.
+                param = FakeTensor(grad=grad.clone())
+                self.gradient_modifier([param])
+                return param.grad
+
+            self.perturbation.register_hook(gradient_modifier)
 
         self.initializer(self.perturbation)
 
diff --git a/tests/test_adversary.py b/tests/test_adversary.py
index c7438458..6c3eed87 100644
--- a/tests/test_adversary.py
+++ b/tests/test_adversary.py
@@ -12,6 +12,7 @@
 
 import mart
 from mart.attack import Adversary
+from mart.attack.gradient_modifier import Sign
 from mart.attack.perturber import Perturber
 
 
@@ -145,3 +146,42 @@ def model(input, target, model=None, **kwargs):
     # Simulate a new batch of data of different size.
     new_input_data = torch.cat([input_data, input_data])
     output3 = adversary(new_input_data, target_data, model=model)
+
+
+def test_adversary_gradient(input_data, target_data):
+    composer = mart.attack.composer.Additive()
+    enforcer = Mock()
+    optimizer = partial(SGD, lr=1.0, maximize=True)
+
+    # Force zeros, positive and negative gradients
+    def gain(logits):
+        return (
+            (0 * logits[0, :, :]).mean()
+            + (0.1 * logits[1, :, :]).mean()  # noqa: W503
+            + (-0.1 * logits[2, :, :]).mean()  # noqa: W503
+        )
+
+    # Perturbation initialized as zero.
+    def initializer(x):
+        torch.nn.init.constant_(x, 0)
+
+    perturber = Perturber(initializer, Sign())
+
+    adversary = Adversary(
+        composer=composer,
+        enforcer=enforcer,
+        perturber=perturber,
+        optimizer=optimizer,
+        max_iters=1,
+        gain=gain,
+    )
+
+    def model(input, target, model=None, **kwargs):
+        return {"logits": adversary(input, target)}
+
+    adversary(input_data, target_data, model=model)
+    input_adv = adversary(input_data, target_data)
+
+    perturbation = input_adv - input_data
+
+    torch.testing.assert_close(perturbation.unique(), torch.tensor([-1.0, 0.0, 1.0]))
diff --git a/tests/test_gradient.py b/tests/test_gradient.py
index e71023d4..a4ad49ee 100644
--- a/tests/test_gradient.py
+++ b/tests/test_gradient.py
@@ -11,15 +11,33 @@
 
 
-def test_gradient_sign(input_data):
-    gradient = Sign()
-    output = gradient(input_data)
-    expected_output = input_data.sign()
-    torch.testing.assert_close(output, expected_output)
+def test_gradient_sign():
+    # Don't share input_data with other tests, because the gradient would be changed.
+    input_data = torch.tensor([1.0, 2.0, 3.0])
+    input_data.grad = torch.tensor([-1.0, 3.0, 0.0])
 
+    grad_modifier = Sign()
+    grad_modifier(input_data)
+    expected_grad = torch.tensor([-1.0, 1.0, 0.0])
+    torch.testing.assert_close(input_data.grad, expected_grad)
+
+
+def test_gradient_lp_normalizer():
+    # Don't share input_data with other tests, because the gradient would be changed.
+    input_data = torch.tensor([1.0, 2.0, 3.0])
+    input_data.grad = torch.tensor([-1.0, 3.0, 0.0])
 
-def test_gradient_lp_normalizer(input_data):
     p = 1
-    gradient = LpNormalizer(p)
-    output = gradient(input_data)
-    expected_output = input_data / input_data.norm(p=p)
-    torch.testing.assert_close(output, expected_output)
+    grad_modifier = LpNormalizer(p)
+    grad_modifier(input_data)
+    expected_grad = torch.tensor([-0.25, 0.75, 0.0])
+    torch.testing.assert_close(input_data.grad, expected_grad)
+
+
+def test_gradient_lp_normalizer_zero_gradient():
+    # An all-zero gradient must stay zero instead of turning into NaN (0 / 0).
+    input_data = torch.tensor([1.0, 2.0, 3.0])
+    input_data.grad = torch.zeros(3)
+
+    grad_modifier = LpNormalizer(1)
+    grad_modifier(input_data)
+    torch.testing.assert_close(input_data.grad, torch.zeros(3))