diff --git a/mart/attack/perturber.py b/mart/attack/perturber.py index 17f7beb8..76ff1f42 100644 --- a/mart/attack/perturber.py +++ b/mart/attack/perturber.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Iterable, Sequence +from typing import TYPE_CHECKING, Callable, Iterable, Sequence import torch from lightning.pytorch.utilities.exceptions import MisconfigurationException @@ -25,17 +25,22 @@ def __init__( *, initializer: Initializer, projector: Projector | None = None, + grad_modifier: Callable = None, ): """_summary_ Args: initializer (Initializer): To initialize the perturbation. projector (Projector): To project the perturbation into some space. + grad_modifier (Callable): A non in-place operation being hooked by Tensor.register_hook(). """ super().__init__() self.initializer_ = initializer self.projector_ = projector or Projector() + self.grad_modifier = grad_modifier + # We can call the handler to remove the hook later by self._grad_modifier_hook_handler.remove() + self._grad_modifier_hook_handler = None self.perturbation = None @@ -97,6 +102,9 @@ def forward(self, **batch): ) self.projector_(self.perturbation, **batch) + # We need to register the hook at every forward pass. + if self.grad_modifier: + self._grad_modifier_hook_handler = self.perturbation.register_hook(self.grad_modifier) return self.perturbation @@ -111,8 +119,9 @@ def __init__( shape: Sequence[int], initializer: Initializer, projector: Projector | None = None, + grad_modifier: Callable = None, ): - super().__init__(initializer=initializer, projector=projector) + super().__init__(initializer=initializer, projector=projector, grad_modifier=grad_modifier) # We just configure the perturbation here. No need to invoke configure_perturbation() externally. self._configure_perturbation(shape)