2 changes: 1 addition & 1 deletion mart/attack/perturber/perturber.py
@@ -55,7 +55,7 @@ def projector_wrapper(perturber_module, args):
raise ValueError("Perturbation must be initialized")

input, target = args
return projector(perturber_module.perturbation, input, target)
return projector(perturber_module.perturbation, input=input, target=target)

# Will be called before forward() is called.
if projector is not None:
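
For context, the signature above matches PyTorch's forward pre-hook convention: the hook receives (module, args) and unpacks the positional arguments into input and target. The following is an illustrative sketch only, not code from this PR, showing how such a wrapper is typically attached to a hypothetical Perturber module with a perturbation attribute:

# Illustrative sketch, not part of this PR: wiring projector_wrapper up as a
# forward pre-hook so the projector runs on perturbation before every forward().
import torch


class Perturber(torch.nn.Module):
    def __init__(self, projector=None):
        super().__init__()
        self.perturbation = None  # assumed to be initialized elsewhere

        def projector_wrapper(perturber_module, args):
            if perturber_module.perturbation is None:
                raise ValueError("Perturbation must be initialized")

            input, target = args
            return projector(perturber_module.perturbation, input=input, target=target)

        # Will be called before forward() is called.
        if projector is not None:
            self.register_forward_pre_hook(projector_wrapper)

    def forward(self, input, target):
        return input + self.perturbation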
121 changes: 46 additions & 75 deletions mart/attack/projector.py
@@ -4,42 +4,58 @@
# SPDX-License-Identifier: BSD-3-Clause
#

import abc
from typing import Any, Dict, List, Optional, Union
from __future__ import annotations

import torch
from typing import Any

__all__ = ["Projector"]
import torch


class Projector(abc.ABC):
class Projector:
"""A projector modifies nn.Parameter's data."""

@torch.no_grad()
def __call__(
self,
tensor: torch.Tensor,
input: torch.Tensor,
target: Union[torch.Tensor, Dict[str, Any]],
perturbation: torch.Tensor | tuple,
*,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
**kwargs,
) -> None:
if isinstance(perturbation, tuple):
for perturbation_i, input_i, target_i in zip(perturbation, input, target):
self.project(perturbation_i, input=input_i, target=target_i)
else:
self.project(perturbation, input=input, target=target)

def project(
self,
perturbation: torch.Tensor | tuple,
*,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
) -> None:
pass


class Compose(Projector):
"""Apply a list of perturbation modifier."""

def __init__(self, projectors: List[Projector]):
def __init__(self, projectors: list[Projector]):
@mzweilin (Contributor), Mar 29, 2023:
Do you think it's better to make it a dictionary so that it's more easily configurable in Hydra? Like that in Enforcer:

def __init__(self, constraints: dict[str, Constraint] | None = None) -> None:

Member Author:
I think that should live in another PR.

Member Author:
My read of the Projectors class is that they need a wholesale redesign because there's a lot of mixing of functionality instead of focusing on composition.
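
For illustration only, and not part of this PR: a dictionary-based Compose along the lines of the Enforcer signature quoted above might look like the sketch below. It assumes the Projector base class defined in this file and keeps the new keyword-only call convention; named entries would make individual projectors addressable from Hydra configs.

# Hypothetical sketch of the dictionary-based Compose discussed above; not part of this PR.
from __future__ import annotations

import torch


class Compose(Projector):
    """Apply a dictionary of named perturbation projectors."""

    def __init__(self, projectors: dict[str, Projector] | None = None):
        self.projectors = projectors or {}

    @torch.no_grad()
    def __call__(self, perturbation, *, input, target, **kwargs) -> None:
        # Delegate to each named projector; tuple handling stays in Projector.__call__.
        for projector in self.projectors.values():
            projector(perturbation, input=input, target=target)

    def __repr__(self):
        names = [f"{name}={projector!r}" for name, projector in self.projectors.items()]
        return f"{self.__class__.__name__}({', '.join(names)})"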

self.projectors = projectors

@torch.no_grad()
def __call__(
self,
tensor: torch.Tensor,
input: torch.Tensor,
target: Union[torch.Tensor, Dict[str, Any]],
perturbation: torch.Tensor | tuple,
*,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
**kwargs,
) -> None:
for projector in self.projectors:
projector(tensor, input, target)
projector(perturbation, input=input, target=target)

def __repr__(self):
projector_names = [repr(p) for p in self.projectors]
@@ -49,26 +65,15 @@ def __repr__(self):
class Range(Projector):
"""Clamp the perturbation so that the output is range-constrained."""

def __init__(
self,
quantize: Optional[bool] = False,
min: Optional[Union[int, float]] = 0,
max: Optional[Union[int, float]] = 255,
):
def __init__(self, quantize: bool = False, min: int | float = 0, max: int | float = 255):
self.quantize = quantize
self.min = min
self.max = max

@torch.no_grad()
def __call__(
self,
tensor: torch.Tensor,
input: torch.Tensor,
target: Union[torch.Tensor, Dict[str, Any]],
) -> None:
def project(self, perturbation, *, input, target):
if self.quantize:
tensor.round_()
tensor.clamp_(self.min, self.max)
perturbation.round_()
perturbation.clamp_(self.min, self.max)

def __repr__(self):
return (
@@ -82,26 +87,15 @@ class RangeAdditive(Projector):
The projector assumes an additive perturbation threat model.
"""

def __init__(
self,
quantize: Optional[bool] = False,
min: Optional[Union[int, float]] = 0,
max: Optional[Union[int, float]] = 255,
):
def __init__(self, quantize: bool = False, min: int | float = 0, max: int | float = 255):
self.quantize = quantize
self.min = min
self.max = max

@torch.no_grad()
def __call__(
self,
tensor: torch.Tensor,
input: torch.Tensor,
target: Union[torch.Tensor, Dict[str, Any]],
) -> None:
def project(self, perturbation, *, input, target):
if self.quantize:
tensor.round_()
tensor.clamp_(self.min - input, self.max - input)
perturbation.round_()
perturbation.clamp_(self.min - input, self.max - input)

def __repr__(self):
return (
@@ -112,7 +106,7 @@ def __repr__(self):
class Lp(Projector):
"""Project perturbations to Lp norm, only if the Lp norm is larger than eps."""

def __init__(self, eps: float, p: Optional[Union[int, float]] = torch.inf):
def __init__(self, eps: int | float, p: int | float = torch.inf):
"""_summary_

Args:
@@ -123,55 +117,32 @@ def __init__(self, eps: float, p: Optional[Union[int, float]] = torch.inf):
self.p = p
self.eps = eps

@torch.no_grad()
def __call__(
self,
tensor: torch.Tensor,
input: torch.Tensor,
target: Union[torch.Tensor, Dict[str, Any]],
) -> None:
pert_norm = tensor.norm(p=self.p)
def project(self, perturbation, *, input, target):
pert_norm = perturbation.norm(p=self.p)
if pert_norm > self.eps:
# We only upper-bound the norm.
tensor.mul_(self.eps / pert_norm)
perturbation.mul_(self.eps / pert_norm)


class LinfAdditiveRange(Projector):
"""Make sure the perturbation is within the Linf norm ball, and "input + perturbation" is
within the [min, max] range."""

def __init__(
self,
eps: float,
min: Optional[Union[int, float]] = 0,
max: Optional[Union[int, float]] = 255,
):
def __init__(self, eps: int | float, min: int | float = 0, max: int | float = 255):
self.eps = eps
self.min = min
self.max = max

@torch.no_grad()
def __call__(
self,
tensor: torch.Tensor,
input: torch.Tensor,
target: Union[torch.Tensor, Dict[str, Any]],
) -> None:
def project(self, perturbation, *, input, target):
eps_min = (input - self.eps).clamp(self.min, self.max) - input
eps_max = (input + self.eps).clamp(self.min, self.max) - input

tensor.clamp_(eps_min, eps_max)
perturbation.clamp_(eps_min, eps_max)


class Mask(Projector):
@torch.no_grad()
def __call__(
self,
tensor: torch.Tensor,
input: torch.Tensor,
target: Union[torch.Tensor, Dict[str, Any]],
) -> None:
tensor.mul_(target["perturbable_mask"])
def project(self, perturbation, *, input, target):
perturbation.mul_(target["perturbable_mask"])

def __repr__(self):
return f"{self.__class__.__name__}()"
12 changes: 6 additions & 6 deletions tests/test_projector.py
@@ -36,7 +36,7 @@ def test_range_projector_repr():
@pytest.mark.parametrize("max", [10, 100, 110])
def test_range_projector(quantize, min, max, input_data, target_data, perturbation):
projector = Range(quantize, min, max)
projector(perturbation, input_data, target_data)
projector(perturbation, input=input_data, target=target_data)

assert torch.max(perturbation) <= max
assert torch.min(perturbation) >= min
@@ -61,7 +61,7 @@ def test_range_additive_projector(quantize, min, max, input_data, target_data, p
expected_perturbation = torch.clone(perturbation)

projector = RangeAdditive(quantize, min, max)
projector(perturbation, input_data, target_data)
projector(perturbation, input=input_data, target=target_data)

# modify expected_perturbation
if quantize:
@@ -78,7 +78,7 @@ def test_lp_projector(eps, p, input_data, target_data, perturbation):
expected_perturbation = torch.clone(perturbation)

projector = Lp(eps, p)
projector(perturbation, input_data, target_data)
projector(perturbation, input=input_data, target=target_data)

# modify expected_perturbation
pert_norm = expected_perturbation.norm(p=p)
@@ -95,7 +95,7 @@ def test_linf_additive_range_projector(min, max, eps, input_data, target_data, p
expected_perturbation = torch.clone(perturbation)

projector = LinfAdditiveRange(eps, min, max)
projector(perturbation, input_data, target_data)
projector(perturbation, input=input_data, target=target_data)

# get expected result
eps_min = (input_data - eps).clamp(min, max) - input_data
@@ -117,7 +117,7 @@ def test_mask_projector(input_data, target_data, perturbation):
expected_perturbation = torch.clone(perturbation)

projector = Mask()
projector(perturbation, input_data, target_data)
projector(perturbation, input=input_data, target=target_data)

# get expected output
expected_perturbation.mul_(target_data["perturbable_mask"])
@@ -156,7 +156,7 @@ def test_compose(input_data, target_data):
compose = Compose(projectors)
tensor = Mock()
tensor.norm.return_value = 10
compose(tensor, input_data, target_data)
compose(tensor, input=input_data, target=target_data)

# RangeProjector, RangeAdditiveProjector, and LinfAdditiveRangeProjector calls `clamp_`
assert tensor.clamp_.call_count == 3
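
The updated tests still pass a single tensor; a hypothetical test for the new tuple path in Projector.__call__ (not part of this PR, reusing the module's existing fixtures) might look like:

# Hypothetical test, not part of this PR: exercises the tuple path added to
# Projector.__call__, reusing the input_data/target_data/perturbation fixtures.
def test_mask_projector_tuple(input_data, target_data, perturbation):
    expected_perturbation = torch.clone(perturbation).mul_(target_data["perturbable_mask"])
    perturbations = (torch.clone(perturbation), torch.clone(perturbation))

    projector = Mask()
    projector(perturbations, input=(input_data, input_data), target=(target_data, target_data))

    torch.testing.assert_close(perturbations[0], expected_perturbation)
    torch.testing.assert_close(perturbations[1], expected_perturbation)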