Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions mart/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import importlib
import importlib.metadata

from mart import attack as attack
from mart import datamodules as datamodules
from mart import models as models
from mart import nn as nn
from mart import optim as optim
from mart import transforms as transforms
from mart import utils as utils
from mart.utils.imports import _HAS_LIGHTNING

if _HAS_LIGHTNING:
from mart import datamodules as datamodules
from mart import models as models

__version__ = importlib.metadata.version(__package__ or __name__)
5 changes: 4 additions & 1 deletion mart/attack/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .adversary import *
from ..utils.imports import _HAS_LIGHTNING
from .adversary_wrapper import *
from .composer import *
from .enforcer import *
Expand All @@ -8,3 +8,6 @@
from .objective import *
from .perturber import *
from .projector import *

if _HAS_LIGHTNING:
from .adversary import *
4 changes: 0 additions & 4 deletions mart/attack/initializer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@

import torch

from mart.utils import pylogger

logger = pylogger.get_pylogger(__name__)


class Initializer:
"""Initializer base class."""
Expand Down
5 changes: 3 additions & 2 deletions mart/attack/initializer/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
# SPDX-License-Identifier: BSD-3-Clause
#

import logging

import torch
import torchvision
import torchvision.transforms.functional as F

from ...utils import pylogger
from .base import Initializer

logger = pylogger.get_pylogger(__name__)
logger = logging.getLogger(__name__)


class Image(Initializer):
Expand Down
9 changes: 3 additions & 6 deletions mart/attack/perturber.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from typing import TYPE_CHECKING, Callable, Iterable, Sequence

import torch
from lightning.pytorch.utilities.exceptions import MisconfigurationException

from .projector import Projector

Expand Down Expand Up @@ -85,21 +84,19 @@ def create_from_tensor(tensor):

def named_parameters(self, *args, **kwargs):
if self.perturbation is None:
raise MisconfigurationException("You need to call configure_perturbation before fit.")
raise RuntimeError("You need to call configure_perturbation before fit.")

return super().named_parameters(*args, **kwargs)

def parameters(self, *args, **kwargs):
if self.perturbation is None:
raise MisconfigurationException("You need to call configure_perturbation before fit.")
raise RuntimeError("You need to call configure_perturbation before fit.")

return super().parameters(*args, **kwargs)

def forward(self, **batch):
if self.perturbation is None:
raise MisconfigurationException(
"You need to call the configure_perturbation before forward."
)
raise RuntimeError("You need to call the configure_perturbation before forward.")

self.projector_(self.perturbation, **batch)
# We need to register the hook at every forward pass.
Expand Down
1 change: 1 addition & 0 deletions mart/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# All Lightning callbacks dependent on lightning, so we don't import mart.callbacks by default.
from ..utils.imports import _HAS_TORCHVISION
from .adversary_connector import *
from .eval_mode import *
Expand Down
13 changes: 7 additions & 6 deletions mart/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# Only import components without external dependency.
from .adapters import *
from .config import *
from .imports import _HAS_TORCHVISION
from .imports import _HAS_LIGHTNING
from .monkey_patch import *
from .pylogger import *
from .rich_utils import *
from .optimization import *
from .silent import *
from .utils import *

if _HAS_TORCHVISION:
from .export import *
if _HAS_LIGHTNING:
from .lightning import *
from .pylogger import *
from .rich_utils import *
11 changes: 7 additions & 4 deletions mart/utils/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@
# SPDX-License-Identifier: BSD-3-Clause
#

import logging
from importlib.util import find_spec

from .pylogger import get_pylogger

logger = get_pylogger(__name__)
# Avoid importing .pylogger when checking imports before running other code.
logger = logging.getLogger(__name__)


def has(module_name):
module = find_spec(module_name)
if module is None:
logger.warn(f"{module_name} is not installed, so some features in MART are unavailable.")
logger.warning(
f"{module_name} is not installed, so some features in MART are unavailable."
)
return False
else:
return True
Expand All @@ -25,3 +27,4 @@ def has(module_name):
_HAS_TORCHVISION = has("torchvision")
_HAS_TIMM = has("timm")
_HAS_PYCOCOTOOLS = has("pycocotools")
_HAS_LIGHTNING = has("lightning")
Loading
Loading