diff --git a/mart/attack/composer.py b/mart/attack/composer.py
index 524ef9d5..5d609f16 100644
--- a/mart/attack/composer.py
+++ b/mart/attack/composer.py
@@ -91,3 +91,118 @@ def compose(self, perturbation, *, input, target):
         perturbation = perturbation * mask
 
         return input * (1 - mask) + perturbation
+
+
+# FIXME: It would be really nice if we could compose composers just like we can compose everything else...
+class WarpComposite(Composite):
+    def __init__(
+        self,
+        warp,
+        *args,
+        clamp=(0, 255),
+        premultiplied_alpha=True,
+        **kwargs,
+    ):
+        super().__init__(*args, premultiplied_alpha=premultiplied_alpha, **kwargs)
+
+        self._warp = warp
+        self.clamp = clamp
+
+    # FIXME: This looks an awful lot like warp below. We should be able to get rid of this function.
+    def fixed_warp(self, perturbation, *, input, target):
+        # Use gs_coords to do a fixed perspective warp
+        assert "gs_coords" in target
+
+        if len(input.shape) == 4 and len(perturbation.shape) == 3:
+            return torch.stack(
+                [
+                    self.warp(perturbation, input=inp, target={"gs_coords": endpoints})
+                    for inp, endpoints in zip(input, target["gs_coords"])
+                ]
+            )
+        else:
+            # coordinates are [[left, top], [right, top], [right, bottom], [left, bottom]]
+            # perturbation is CHW
+            startpoints = [
+                [0, 0],
+                [perturbation.shape[2], 0],
+                [perturbation.shape[2], perturbation.shape[1]],
+                [0, perturbation.shape[1]],
+            ]
+            endpoints = target["gs_coords"]
+
+            pert_w, pert_h = F.get_image_size(perturbation)
+            image_w, image_h = F.get_image_size(input)
+
+            # Pad perturbation to image size
+            if pert_w < image_w or pert_h < image_h:
+                # left, top, right and bottom
+                padding = [0, 0, max(image_w - pert_w, 0), max(image_h - pert_h, 0)]
+                perturbation = F.pad(perturbation, padding)
+
+            perturbation = F.perspective(perturbation, startpoints, endpoints)
+
+            # Crop perturbation to image size
+            if pert_w != image_w or pert_h != image_h:
+                perturbation = F.crop(perturbation, 0, 0, image_h, image_w)
+            return perturbation
+
+    def warp(self, perturbation, *, input, target):
+        # Always use gs_coords if present in target
+        if "gs_coords" in target:
+            return self.fixed_warp(perturbation, input=input, target=target)
+
+        # Otherwise, warp the perturbation onto the input
+        if len(input.shape) == 4 and len(perturbation.shape) == 3:  # support for batch warping
+            return torch.stack(
+                [self.warp(perturbation, input=inp, target=target) for inp in input]
+            )
+        else:
+            pert_w, pert_h = F.get_image_size(perturbation)
+            image_w, image_h = F.get_image_size(input)
+
+            # Pad perturbation to image size
+            if pert_w < image_w or pert_h < image_h:
+                # left, top, right and bottom
+                padding = [0, 0, max(image_w - pert_w, 0), max(image_h - pert_h, 0)]
+                perturbation = F.pad(perturbation, padding)
+
+            perturbation = self._warp(perturbation)
+
+            # Crop perturbation to image size
+            if pert_w != image_w or pert_h != image_h:
+                perturbation = F.crop(perturbation, 0, 0, image_h, image_w)
+            return perturbation
+
+    def compose(self, perturbation, *, input, target):
+        # Create a mask of ones to keep track of which pixels the warped perturbation fills in
+        mask = torch.ones_like(perturbation[:1])
+
+        # Append the mask to the perturbation as an extra channel so it is warped along with it.
+        perturbation = torch.cat((perturbation, mask))
+
+        # Apply warp transform
+        perturbation = self.warp(perturbation, input=input, target=target)
+
+        # Extract the mask from the perturbation. The channels-first layout forces this hack.
+        if len(perturbation.shape) == 4:
+            mask = perturbation[:, 3:, ...]
+            perturbation = perturbation[:, :3, ...]
+        else:
+            mask = perturbation[3:, ...]
+            perturbation = perturbation[:3, ...]
+
+        # Set/update the perturbable mask
+        perturbable_mask = 1
+        if "perturbable_mask" in target:
+            perturbable_mask = target["perturbable_mask"]
+        perturbable_mask = perturbable_mask * mask
+
+        # Pre-multiply the perturbation by the mask and clamp it to the input min/max
+        perturbation = perturbation * perturbable_mask
+        perturbation.clamp_(*self.clamp)
+
+        # Set the mask for super().compose
+        target["perturbable_mask"] = perturbable_mask
+
+        return super().compose(perturbation, input=input, target=target)
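The mask trick in WarpComposite.compose can be illustrated standalone: warp an RGB patch together with an appended ones-channel, then reuse the warped ones-channel as the mask that records where the patch landed. A minimal sketch using plain torch/torchvision; the patch size, warp parameters, and the final paste line are illustrative and not MART API:

    import torch
    import torchvision.transforms as T

    patch = torch.full((3, 64, 64), 128.0)              # flat gray adversarial patch (CHW)
    mask = torch.ones_like(patch[:1])                    # 1 wherever the patch has pixels
    warp = T.RandomAffine(degrees=15, translate=(0.2, 0.2), scale=(0.5, 0.7))

    warped = warp(torch.cat((patch, mask)))              # warp the RGB and mask channels jointly
    patch_w, mask_w = warped[:3], warped[3:]             # split them back apart

    image = torch.rand(3, 64, 64) * 255
    composed = image * (1 - mask_w) + patch_w * mask_w   # paste the warped patch onto the image
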
diff --git a/mart/attack/perturber.py b/mart/attack/perturber.py
index 29df3059..72ae69d9 100644
--- a/mart/attack/perturber.py
+++ b/mart/attack/perturber.py
@@ -25,6 +25,7 @@ def __init__(
         *,
         initializer: Initializer,
         projector: Projector | None = None,
+        size: Iterable[int] | None = None,
     ):
         """_summary_
 
@@ -39,6 +40,10 @@ def __init__(
 
         self.perturbation = None
 
+        # FIXME: Should this be in UniversalAdversary?
+        if size is not None:
+            self.configure_perturbation(torch.empty(size))
+
     def configure_perturbation(self, input: torch.Tensor | Iterable[torch.Tensor]):
         def matches(input, perturbation):
             if perturbation is None:
diff --git a/mart/attack/projector.py b/mart/attack/projector.py
index f9887354..619f7b57 100644
--- a/mart/attack/projector.py
+++ b/mart/attack/projector.py
@@ -26,6 +26,14 @@ def __call__(
         if isinstance(perturbation, torch.Tensor) and isinstance(input, torch.Tensor):
             self.project_(perturbation, input=input, target=target)
 
+        elif (
+            isinstance(perturbation, torch.Tensor)
+            and isinstance(input, Iterable)  # noqa: W503
+            and isinstance(target, Iterable)  # noqa: W503
+        ):
+            for input_i, target_i in zip(input, target):
+                self.project_(perturbation, input=input_i, target=target_i)
+
         elif (
             isinstance(perturbation, Iterable)
             and isinstance(input, Iterable)  # noqa: W503
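The new elif branch in the projector handles a single shared perturbation tensor paired with a list of inputs/targets, e.g. the fixed-size perturbation created by the new Perturber(size=...) path. A toy sketch of that dispatch, with a hypothetical project_ that keeps input + perturbation inside the pixel range (MART's real projectors implement project_ differently):

    import torch

    def project_(perturbation, *, input, target):
        # Hypothetical projection: keep (input + perturbation) within [0, 255] in place.
        perturbation.copy_((input + perturbation).clamp(0, 255) - input)

    perturbation = torch.zeros(3, 416, 416)                     # one shared patch
    inputs = [torch.rand(3, 416, 416) * 255 for _ in range(2)]  # batch given as a list
    targets = [{}, {}]

    # Same control flow as the new branch: a single tensor projected against every item.
    for input_i, target_i in zip(inputs, targets):
        project_(perturbation, input=input_i, target=target_i)
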
diff --git a/mart/callbacks/eval_mode.py b/mart/callbacks/eval_mode.py
index be3b6397..639444c9 100644
--- a/mart/callbacks/eval_mode.py
+++ b/mart/callbacks/eval_mode.py
@@ -4,23 +4,47 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 
+from __future__ import annotations
+
 from pytorch_lightning.callbacks import Callback
 
+from mart import utils
+
+logger = utils.get_pylogger(__name__)
+
 __all__ = ["AttackInEvalMode"]
 
 
 class AttackInEvalMode(Callback):
     """Switch the model into eval mode during attack."""
 
-    def __init__(self):
-        self.training_mode_status = None
-
-    def on_train_start(self, trainer, model):
-        self.training_mode_status = model.training
-        model.train(False)
-
-    def on_train_end(self, trainer, model):
-        assert self.training_mode_status is not None
-
-        # Resume the previous training status of the model.
-        model.train(self.training_mode_status)
+    def __init__(self, module_classes: type | list[type]):
+        # FIXME: Convert strings to classes with hydra.utils.get_class? That would reduce configuration verbosity but would require importing hydra in this callback.
+        if isinstance(module_classes, type):
+            module_classes = [module_classes]
+
+        self.module_classes = tuple(module_classes)
+
+    def setup(self, trainer, pl_module, stage):
+        if stage != "fit":
+            return
+
+        # Log to the console so the user can visually see which modules will be in eval mode during training.
+        for name, module in pl_module.named_modules():
+            if isinstance(module, self.module_classes):
+                logger.info(
+                    f"Setting eval mode for {name} ({module.__class__.__module__}.{module.__class__.__name__})"
+                )
+
+    def on_train_epoch_start(self, trainer, pl_module):
+        # We must use on_train_epoch_start because PL sets pl_module back to train mode right before this hook runs.
+        # See: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#hooks
+        for name, module in pl_module.named_modules():
+            if isinstance(module, self.module_classes):
+                module.eval()
+
+    def on_train_epoch_end(self, trainer, pl_module):
+        # FIXME: Why is this necessary?
+        for name, module in pl_module.named_modules():
+            if isinstance(module, self.module_classes):
+                module.train()
diff --git a/mart/callbacks/no_grad_mode.py b/mart/callbacks/no_grad_mode.py
index cfb90ead..4a86d985 100644
--- a/mart/callbacks/no_grad_mode.py
+++ b/mart/callbacks/no_grad_mode.py
@@ -4,8 +4,15 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 
+from __future__ import annotations
+
+import torch
 from pytorch_lightning.callbacks import Callback
 
+from mart import utils
+
+logger = utils.get_pylogger(__name__)
+
 __all__ = ["ModelParamsNoGrad"]
 
 
@@ -15,10 +22,25 @@ class ModelParamsNoGrad(Callback):
     This callback should not change the result. Don't use unless an attack runs faster.
     """
 
-    def on_train_start(self, trainer, model):
-        for param in model.parameters():
-            param.requires_grad_(False)
+    def __init__(self, module_names: str | list[str] | None = None):
+        if isinstance(module_names, str):
+            module_names = [module_names]
+
+        self.module_names = module_names
+
+    def setup(self, trainer, pl_module, stage):
+        if stage != "fit":
+            return
+
+        # We use setup rather than on_train_start so that mart.optim.OptimizerFactory can ignore parameters with no gradients.
+        # See: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#hooks
+        for name, param in pl_module.named_parameters():
+            if any(name.startswith(module_name) for module_name in self.module_names):
+                logger.info(f"Disabling gradient for {name}")
+                param.requires_grad_(False)
 
-    def on_train_end(self, trainer, model):
-        for param in model.parameters():
-            param.requires_grad_(True)
+    def teardown(self, trainer, pl_module, stage):
+        for name, param in pl_module.named_parameters():
+            if any(name.startswith(module_name) for module_name in self.module_names):
+                # FIXME: Why is this necessary?
+                param.requires_grad_(True)
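For reference, the two reworked callbacks might be wired up in Python roughly as follows. This is a sketch: it assumes mart.callbacks re-exports both classes, and the module-name prefix "model.losses_and_detections" is only an illustration of how the prefix matching works:

    import torch
    from pytorch_lightning import Trainer
    from mart.callbacks import AttackInEvalMode, ModelParamsNoGrad

    callbacks = [
        # Keep normalization and dropout layers in eval mode while optimizing the patch.
        AttackInEvalMode(module_classes=[torch.nn.BatchNorm2d, torch.nn.Dropout]),
        # Freeze the victim model's parameters; only the perturbation should receive gradients.
        ModelParamsNoGrad(module_names="model.losses_and_detections"),
    ]

    trainer = Trainer(callbacks=callbacks, max_steps=10)
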
diff --git a/mart/configs/attack/composer/warp_composite.yaml b/mart/configs/attack/composer/warp_composite.yaml
new file mode 100644
index 00000000..dcc71680
--- /dev/null
+++ b/mart/configs/attack/composer/warp_composite.yaml
@@ -0,0 +1,2 @@
+_target_: mart.attack.composer.WarpComposite
+warp: ???
diff --git a/mart/configs/attack/perturber/default.yaml b/mart/configs/attack/perturber/default.yaml
index 7f4e1a8b..9b094218 100644
--- a/mart/configs/attack/perturber/default.yaml
+++ b/mart/configs/attack/perturber/default.yaml
@@ -1,3 +1,4 @@
 _target_: mart.attack.Perturber
 initializer: ???
 projector: null
+size: null
diff --git a/mart/configs/callbacks/attack_in_eval_mode.yaml b/mart/configs/callbacks/attack_in_eval_mode.yaml
index 2acdc953..4ca096b0 100644
--- a/mart/configs/callbacks/attack_in_eval_mode.yaml
+++ b/mart/configs/callbacks/attack_in_eval_mode.yaml
@@ -1,2 +1,11 @@
 attack_in_eval_mode:
   _target_: mart.callbacks.AttackInEvalMode
+  module_classes: ???
+  # - _target_: hydra.utils.get_class
+  #   path: mart.models.LitModular
+  # - _target_: hydra.utils.get_class
+  #   path: torch.nn.BatchNorm2d
+  # - _target_: hydra.utils.get_class
+  #   path: torch.nn.Dropout
+  # - _target_: hydra.utils.get_class
+  #   path: torch.nn.SyncBatchNorm
diff --git a/mart/configs/callbacks/no_grad_mode.yaml b/mart/configs/callbacks/no_grad_mode.yaml
index 6b4312fd..d12d18e9 100644
--- a/mart/configs/callbacks/no_grad_mode.yaml
+++ b/mart/configs/callbacks/no_grad_mode.yaml
@@ -1,2 +1,3 @@
-attack_in_eval_mode:
+no_grad_mode:
   _target_: mart.callbacks.ModelParamsNoGrad
+  module_names: ???
diff --git a/mart/configs/experiment/COCO_TorchvisionFasterRCNN_ShapeShifter.yaml b/mart/configs/experiment/COCO_TorchvisionFasterRCNN_ShapeShifter.yaml
new file mode 100644
index 00000000..22a81ce5
--- /dev/null
+++ b/mart/configs/experiment/COCO_TorchvisionFasterRCNN_ShapeShifter.yaml
@@ -0,0 +1,86 @@
+# @package _global_
+
+defaults:
+  - /attack/perturber@model.modules.perturbation: default
+  - /attack/perturber/initializer@model.modules.perturbation.initializer: uniform
+  - /attack/perturber/projector@model.modules.perturbation.projector: range
+  - /attack/composer@model.modules.input_adv: warp_composite
+  - /attack/gradient_modifier@model.gradient_modifier: lp_normalizer
+  - override /datamodule: coco
+  - override /model: torchvision_faster_rcnn
+  - override /metric: average_precision
+  - override /optimization: super_convergence
+  - override /callbacks:
+      [model_checkpoint, lr_monitor, perturbation_visualizer, gradient_monitor]
+
+task_name: "COCO_TorchvisionFasterRCNN_ShapeShifter"
+tags: ["adv"]
+
+optimized_metric: "test_metrics/map"
+
+callbacks:
+  model_checkpoint:
+    monitor: "validation_metrics/map"
+    mode: "min"
+
+  perturbation_visualizer:
+    perturbation: "model.perturbation.perturbation"
+
+trainer:
+  # 117,266 training images, 6 epochs, batch_size=2 -> 351,798 steps
+  max_steps: 351798
+  # FIXME: "nms_kernel" not implemented for 'BFloat16', torch.ops.torchvision.nms().
+  precision: 32
+
+datamodule:
+  num_workers: 8
+  ims_per_batch: 2
+
+model:
+  modules:
+    perturbation:
+      size: [3, 416, 416]
+
+      initializer:
+        min: 127
+        max: 129
+
+    input_adv:
+      warp:
+        _target_: torchvision.transforms.Compose
+        transforms:
+          - _target_: mart.transforms.ColorJitter
+            brightness: [0.5, 1.5]
+            contrast: [0.5, 1.5]
+            saturation: [0.5, 1.0]
+            hue: [-0.05, 0.05]
+          - _target_: torchvision.transforms.RandomAffine
+            degrees: [-5, 5]
+            translate: [0.1, 0.25]
+            scale: [0.4, 0.6]
+            shear: [-3, 3, -3, 3]
+            interpolation: 2 # BILINEAR
+      clamp: [0, 255]
+
+    losses_and_detections:
+      model:
+        num_classes: null # inferred by torchvision
+        weights: COCO_V1
+
+  optimizer:
+    lr: 25.5
+    momentum: 0.9
+    maximize: True
+
+  gradient_modifier:
+    p: inf
+
+  training_sequence:
+    seq005: perturbation
+    seq006: input_adv
+    seq010:
+      preprocessor: ["input_adv"]
+
+  validation_sequence: ${.training_sequence}
+
+  test_sequence: ${.validation_sequence}
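For orientation, the input_adv.warp block above asks Hydra to build roughly the following Python objects. This is a sketch, not the instantiated config; it assumes mart.transforms re-exports the alpha-aware ColorJitter added later in this diff:

    import torchvision.transforms as T
    from mart.transforms import ColorJitter

    # Python equivalent of model.modules.input_adv.warp above.
    warp = T.Compose(
        [
            ColorJitter(
                brightness=(0.5, 1.5), contrast=(0.5, 1.5), saturation=(0.5, 1.0), hue=(-0.05, 0.05)
            ),
            T.RandomAffine(
                degrees=(-5, 5),
                translate=(0.1, 0.25),
                scale=(0.4, 0.6),
                shear=(-3, 3, -3, 3),
                interpolation=T.InterpolationMode.BILINEAR,  # the YAML uses the enum value 2
            ),
        ]
    )
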
diff --git a/mart/configs/metric/average_precision.yaml b/mart/configs/metric/average_precision.yaml
index d41f9743..3438a090 100644
--- a/mart/configs/metric/average_precision.yaml
+++ b/mart/configs/metric/average_precision.yaml
@@ -9,13 +9,5 @@ validation_metrics:
     compute_on_step: false
 
 test_metrics:
-  _target_: torchmetrics.collections.MetricCollection
-  _convert_: partial
-  metrics:
-    map:
-      _target_: torchmetrics.detection.MAP
-      compute_on_step: false
-    json:
-      _target_: mart.utils.export.CocoPredictionJSON
-      prediction_file_name: ${paths.output_dir}/test_prediction.json
-      groundtruth_file_name: ${paths.output_dir}/test_groundtruth.json
+  _target_: torchmetrics.detection.MAP
+  compute_on_step: false
diff --git a/mart/models/modular.py b/mart/models/modular.py
index a27c6867..755c10ed 100644
--- a/mart/models/modular.py
+++ b/mart/models/modular.py
@@ -35,6 +35,7 @@ def __init__(
         test_sequence=None,
         test_step_log=None,
         test_metrics=None,
+        gradient_modifier=None,
         load_state_dict=None,
         output_loss_key="loss",
         output_preds_key="preds",
@@ -91,6 +92,8 @@ def __init__(
         self.test_step_log = test_step_log or {}
         self.test_metrics = test_metrics
 
+        self.gradient_modifier = gradient_modifier
+
         # Load state dict for specified modules. We flatten it because Hydra
         # commandlines converts dotted paths to nested dictionaries.
         load_state_dict = flatten_dict(load_state_dict or {})
@@ -120,6 +123,18 @@ def configure_optimizers(self):
 
         return config
 
+    def configure_gradient_clipping(
+        self, optimizer, optimizer_idx, gradient_clip_val=None, gradient_clip_algorithm=None
+    ):
+        # Configuring gradient clipping in pl.Trainer is still useful, so use it.
+        super().configure_gradient_clipping(
+            optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm
+        )
+
+        if self.gradient_modifier is not None:
+            for group in optimizer.param_groups:
+                self.gradient_modifier(group["params"])
+
     def forward(self, **kwargs):
         return self.model(**kwargs)
 
diff --git a/mart/nn/nn.py b/mart/nn/nn.py
index 02113899..8ce147bb 100644
--- a/mart/nn/nn.py
+++ b/mart/nn/nn.py
@@ -295,8 +295,16 @@ def __init__(self, *args, **kwargs):
 
 # FIXME: This must exist already?!
 class Sum(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, weights=None):
         super().__init__()
 
-    def forward(self, *args):
-        return sum(args)
+        self.weights = weights
+
+    def forward(self, *values, weights=None):
+        weights = weights or self.weights
+
+        if weights is None:
+            weights = [1 for _ in values]
+
+        assert len(weights) == len(values)
+        return sum(value * weight for value, weight in zip(values, weights))
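A quick usage sketch of the weighted Sum (assuming mart.nn re-exports it). Per-call weights override the constructor weights, and omitting both falls back to a plain sum; the loss names are illustrative:

    import torch
    from mart.nn import Sum

    rpn_loss, roi_loss = torch.tensor(2.0), torch.tensor(4.0)

    Sum(weights=[1.0, 0.5])(rpn_loss, roi_loss)    # 1.0 * 2.0 + 0.5 * 4.0 -> tensor(4.)
    Sum()(rpn_loss, roi_loss)                      # plain sum -> tensor(6.)
    Sum()(rpn_loss, roi_loss, weights=[0.0, 1.0])  # per-call weights -> tensor(4.)
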
diff --git a/mart/transforms/transforms.py b/mart/transforms/transforms.py
index 4c7f29f7..4524f6b7 100644
--- a/mart/transforms/transforms.py
+++ b/mart/transforms/transforms.py
@@ -16,6 +16,7 @@
     "Chunk",
     "TupleTransforms",
     "GetItems",
+    "ColorJitter",
 ]
 
 
@@ -101,3 +102,21 @@ def __init__(self, keys):
     def __call__(self, x):
         x_list = [x[key] for key in self.keys]
         return x_list
+
+
+class ColorJitter(T.ColorJitter):
+    def forward(self, img):
+        # Assume the final channel is alpha
+        if len(img.shape) == 3:
+            alpha = img[3:4, ...]
+            rgb = img[:3, ...]
+            dim = 0
+        elif len(img.shape) == 4:
+            alpha = img[:, 3:4, ...]
+            rgb = img[:, :3, ...]
+            dim = 1
+        else:
+            raise NotImplementedError
+
+        rgb = super().forward(rgb)
+        return torch.cat([rgb, alpha], dim=dim)
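Finally, a small sketch of the alpha-aware ColorJitter (assuming mart.transforms re-exports it): only the RGB channels are jittered, while the trailing alpha/mask channel passes through untouched, which is what keeps the placement mask valid inside WarpComposite's warp pipeline:

    import torch
    from mart.transforms import ColorJitter

    jitter = ColorJitter(
        brightness=(0.5, 1.5), contrast=(0.5, 1.5), saturation=(0.5, 1.0), hue=(-0.05, 0.05)
    )

    rgba = torch.rand(4, 64, 64)            # CHW patch with a trailing alpha channel
    out = jitter(rgba)

    assert out.shape == (4, 64, 64)
    assert torch.equal(out[3:], rgba[3:])   # alpha channel is untouched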