diff --git a/mart/utils/config.py b/mart/utils/config.py index 36dbb6e6..384f4a31 100644 --- a/mart/utils/config.py +++ b/mart/utils/config.py @@ -10,12 +10,13 @@ from hydra import compose as hydra_compose from hydra import initialize_config_dir +from hydra.utils import instantiate as hydra_instantiate DEFAULT_VERSION_BASE = "1.2" DEFAULT_CONFIG_DIR = "." DEFAULT_CONFIG_NAME = "lightning.yaml" -__all__ = ["compose"] +__all__ = ["compose", "instantiate"] def compose( @@ -40,3 +41,26 @@ def compose( cfg = cfg[key] return cfg + + +def instantiate( + *overrides, + version_base: str = DEFAULT_VERSION_BASE, + config_dir: str = DEFAULT_CONFIG_DIR, + config_name: str = DEFAULT_CONFIG_NAME, + export_node: str | None = None, +): + """Compose and instantiate an object. + + Should be useful in testing configs. + """ + cfg = compose( + *overrides, + version_base=version_base, + config_dir=config_dir, + config_name=config_name, + export_node=export_node, + ) + + obj = hydra_instantiate(cfg) + return obj diff --git a/tests/test_composer.py b/tests/test_composer.py index 3aba543c..0b8e7021 100644 --- a/tests/test_composer.py +++ b/tests/test_composer.py @@ -17,6 +17,7 @@ PerturbationRectanglePad, PerturbationRectanglePerspectiveTransform, ) +from mart.utils import instantiate def test_additive_composer_forward(input_data, target_data, perturbation): @@ -125,3 +126,45 @@ def test_pert_rect_perspective_transform(): pert_coords_expected[:, 5:, 5:] = 1 # rounding numeric error from the perspective transformation. assert torch.equal(pert_coords.round(), pert_coords_expected) + + +def test_rect_patch_additive_composer(): + overrides = ["+attack/composer=rect_patch_additive"] + composer = instantiate(*overrides, export_node="attack.composer") + + input = torch.ones((3, 10, 10)) + perturbation = torch.ones_like(input) * 2 + + input_adv_expected = input.clone() + input_adv_expected[:, -5:, -5:] += 2 + + # A simple square patch on the bottom right. + patch_coords = torch.tensor(((5, 5), (10, 5), (10, 10), (5, 10))) + perturbable_mask = torch.zeros((10, 10)) + perturbable_mask[-5:, -5:] = 1 + + target = {"patch_coords": patch_coords, "perturbable_mask": perturbable_mask} + input_adv = composer(perturbation, input=input, target=target) + + assert torch.allclose(input_adv_expected, input_adv) + + +def test_rect_patch_overlay_composer(): + overrides = ["+attack/composer=rect_patch_overlay"] + composer = instantiate(*overrides, export_node="attack.composer") + + input = torch.ones((3, 10, 10)) + perturbation = torch.ones_like(input) * 2 + + input_adv_expected = input.clone() + input_adv_expected[:, -5:, -5:] = 2 + + # A simple square patch on the bottom right. + patch_coords = torch.tensor(((5, 5), (10, 5), (10, 10), (5, 10))) + perturbable_mask = torch.zeros((10, 10)) + perturbable_mask[-5:, -5:] = 1 + + target = {"patch_coords": patch_coords, "perturbable_mask": perturbable_mask} + input_adv = composer(perturbation, input=input, target=target) + + assert torch.allclose(input_adv_expected, input_adv)