From d88c4587914243c00dcd0b3774c9ffed74d822e2 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Tue, 28 Mar 2023 13:13:39 -0700 Subject: [PATCH 1/2] Make Projector batch aware --- mart/attack/perturber/perturber.py | 2 +- mart/attack/projector.py | 121 +++++++++++------------------ 2 files changed, 47 insertions(+), 76 deletions(-) diff --git a/mart/attack/perturber/perturber.py b/mart/attack/perturber/perturber.py index d7eed81f..e8394f4f 100644 --- a/mart/attack/perturber/perturber.py +++ b/mart/attack/perturber/perturber.py @@ -55,7 +55,7 @@ def projector_wrapper(perturber_module, args): raise ValueError("Perturbation must be initialized") input, target = args - return projector(perturber_module.perturbation, input, target) + return projector(perturber_module.perturbation, input=input, target=target) # Will be called before forward() is called. if projector is not None: diff --git a/mart/attack/projector.py b/mart/attack/projector.py index 4c360688..92391c67 100644 --- a/mart/attack/projector.py +++ b/mart/attack/projector.py @@ -4,23 +4,37 @@ # SPDX-License-Identifier: BSD-3-Clause # -import abc -from typing import Any, Dict, List, Optional, Union +from __future__ import annotations -import torch +from typing import Any -__all__ = ["Projector"] +import torch -class Projector(abc.ABC): +class Projector: """A projector modifies nn.Parameter's data.""" @torch.no_grad() def __call__( self, - tensor: torch.Tensor, - input: torch.Tensor, - target: Union[torch.Tensor, Dict[str, Any]], + perturbation: torch.Tensor | tuple, + *, + input: torch.Tensor | tuple, + target: torch.Tensor | dict[str, Any] | tuple, + **kwargs, + ) -> None: + if isinstance(perturbation, tuple): + for perturbation_i, input_i, target_i in zip(perturbation, input, target): + self.project(perturbation_i, input=input_i, target=target_i) + else: + self.project(perturbation, input=input, target=target) + + def project( + self, + perturbation: torch.Tensor | tuple, + *, + input: torch.Tensor | tuple, + target: torch.Tensor | dict[str, Any] | tuple, ) -> None: pass @@ -28,18 +42,20 @@ def __call__( class Compose(Projector): """Apply a list of perturbation modifier.""" - def __init__(self, projectors: List[Projector]): + def __init__(self, projectors: list[Projector]): self.projectors = projectors @torch.no_grad() def __call__( self, - tensor: torch.Tensor, - input: torch.Tensor, - target: Union[torch.Tensor, Dict[str, Any]], + perturbation: torch.Tensor | tuple, + *, + input: torch.Tensor | tuple, + target: torch.Tensor | dict[str, Any] | tuple, + **kwargs, ) -> None: for projector in self.projectors: - projector(tensor, input, target) + projector(perturbation, input=input, target=target) def __repr__(self): projector_names = [repr(p) for p in self.projectors] @@ -49,26 +65,15 @@ def __repr__(self): class Range(Projector): """Clamp the perturbation so that the output is range-constrained.""" - def __init__( - self, - quantize: Optional[bool] = False, - min: Optional[Union[int, float]] = 0, - max: Optional[Union[int, float]] = 255, - ): + def __init__(self, quantize: bool = False, min: int | float = 0, max: int | float = 255): self.quantize = quantize self.min = min self.max = max - @torch.no_grad() - def __call__( - self, - tensor: torch.Tensor, - input: torch.Tensor, - target: Union[torch.Tensor, Dict[str, Any]], - ) -> None: + def project(self, perturbation, *, input, target): if self.quantize: - tensor.round_() - tensor.clamp_(self.min, self.max) + perturbation.round_() + perturbation.clamp_(self.min, self.max) def __repr__(self): return ( @@ -82,26 +87,15 @@ class RangeAdditive(Projector): The projector assumes an additive perturbation threat model. """ - def __init__( - self, - quantize: Optional[bool] = False, - min: Optional[Union[int, float]] = 0, - max: Optional[Union[int, float]] = 255, - ): + def __init__(self, quantize: bool = False, min: int | float = 0, max: int | float = 255): self.quantize = quantize self.min = min self.max = max - @torch.no_grad() - def __call__( - self, - tensor: torch.Tensor, - input: torch.Tensor, - target: Union[torch.Tensor, Dict[str, Any]], - ) -> None: + def project(self, perturbation, *, input, target): if self.quantize: - tensor.round_() - tensor.clamp_(self.min - input, self.max - input) + perturbation.round_() + perturbation.clamp_(self.min - input, self.max - input) def __repr__(self): return ( @@ -112,7 +106,7 @@ def __repr__(self): class Lp(Projector): """Project perturbations to Lp norm, only if the Lp norm is larger than eps.""" - def __init__(self, eps: float, p: Optional[Union[int, float]] = torch.inf): + def __init__(self, eps: int | float, p: int | float = torch.inf): """_summary_ Args: @@ -123,55 +117,32 @@ def __init__(self, eps: float, p: Optional[Union[int, float]] = torch.inf): self.p = p self.eps = eps - @torch.no_grad() - def __call__( - self, - tensor: torch.Tensor, - input: torch.Tensor, - target: Union[torch.Tensor, Dict[str, Any]], - ) -> None: - pert_norm = tensor.norm(p=self.p) + def project(self, perturbation, *, input, target): + pert_norm = perturbation.norm(p=self.p) if pert_norm > self.eps: # We only upper-bound the norm. - tensor.mul_(self.eps / pert_norm) + perturbation.mul_(self.eps / pert_norm) class LinfAdditiveRange(Projector): """Make sure the perturbation is within the Linf norm ball, and "input + perturbation" is within the [min, max] range.""" - def __init__( - self, - eps: float, - min: Optional[Union[int, float]] = 0, - max: Optional[Union[int, float]] = 255, - ): + def __init__(self, eps: int | float, min: int | float = 0, max: int | float = 255): self.eps = eps self.min = min self.max = max - @torch.no_grad() - def __call__( - self, - tensor: torch.Tensor, - input: torch.Tensor, - target: Union[torch.Tensor, Dict[str, Any]], - ) -> None: + def project(self, perturbation, *, input, target): eps_min = (input - self.eps).clamp(self.min, self.max) - input eps_max = (input + self.eps).clamp(self.min, self.max) - input - tensor.clamp_(eps_min, eps_max) + perturbation.clamp_(eps_min, eps_max) class Mask(Projector): - @torch.no_grad() - def __call__( - self, - tensor: torch.Tensor, - input: torch.Tensor, - target: Union[torch.Tensor, Dict[str, Any]], - ) -> None: - tensor.mul_(target["perturbable_mask"]) + def project(self, perturbation, *, input, target): + perturbation.mul_(target["perturbable_mask"]) def __repr__(self): return f"{self.__class__.__name__}()" From 13b8e32b5804a059b6ba6bb74ed65fdaf39dec97 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Tue, 28 Mar 2023 13:19:04 -0700 Subject: [PATCH 2/2] Update projector tests --- tests/test_projector.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_projector.py b/tests/test_projector.py index 9983fa4a..a397a98c 100644 --- a/tests/test_projector.py +++ b/tests/test_projector.py @@ -36,7 +36,7 @@ def test_range_projector_repr(): @pytest.mark.parametrize("max", [10, 100, 110]) def test_range_projector(quantize, min, max, input_data, target_data, perturbation): projector = Range(quantize, min, max) - projector(perturbation, input_data, target_data) + projector(perturbation, input=input_data, target=target_data) assert torch.max(perturbation) <= max assert torch.min(perturbation) >= min @@ -61,7 +61,7 @@ def test_range_additive_projector(quantize, min, max, input_data, target_data, p expected_perturbation = torch.clone(perturbation) projector = RangeAdditive(quantize, min, max) - projector(perturbation, input_data, target_data) + projector(perturbation, input=input_data, target=target_data) # modify expected_perturbation if quantize: @@ -78,7 +78,7 @@ def test_lp_projector(eps, p, input_data, target_data, perturbation): expected_perturbation = torch.clone(perturbation) projector = Lp(eps, p) - projector(perturbation, input_data, target_data) + projector(perturbation, input=input_data, target=target_data) # modify expected_perturbation pert_norm = expected_perturbation.norm(p=p) @@ -95,7 +95,7 @@ def test_linf_additive_range_projector(min, max, eps, input_data, target_data, p expected_perturbation = torch.clone(perturbation) projector = LinfAdditiveRange(eps, min, max) - projector(perturbation, input_data, target_data) + projector(perturbation, input=input_data, target=target_data) # get expected result eps_min = (input_data - eps).clamp(min, max) - input_data @@ -117,7 +117,7 @@ def test_mask_projector(input_data, target_data, perturbation): expected_perturbation = torch.clone(perturbation) projector = Mask() - projector(perturbation, input_data, target_data) + projector(perturbation, input=input_data, target=target_data) # get expected output expected_perturbation.mul_(target_data["perturbable_mask"]) @@ -156,7 +156,7 @@ def test_compose(input_data, target_data): compose = Compose(projectors) tensor = Mock() tensor.norm.return_value = 10 - compose(tensor, input_data, target_data) + compose(tensor, input=input_data, target=target_data) # RangeProjector, RangeAdditiveProjector, and LinfAdditiveRangeProjector calls `clamp_` assert tensor.clamp_.call_count == 3