diff --git a/mart/__init__.py b/mart/__init__.py
index 85181105..82d5d16b 100644
--- a/mart/__init__.py
+++ b/mart/__init__.py
@@ -1,11 +1,14 @@
-import importlib
+import importlib.metadata

 from mart import attack as attack
-from mart import datamodules as datamodules
-from mart import models as models
 from mart import nn as nn
 from mart import optim as optim
 from mart import transforms as transforms
 from mart import utils as utils
+from mart.utils.imports import _HAS_LIGHTNING
+
+if _HAS_LIGHTNING:
+    from mart import datamodules as datamodules
+    from mart import models as models

 __version__ = importlib.metadata.version(__package__ or __name__)
diff --git a/mart/attack/__init__.py b/mart/attack/__init__.py
index 843ce9bd..6e62c9a3 100644
--- a/mart/attack/__init__.py
+++ b/mart/attack/__init__.py
@@ -1,4 +1,4 @@
-from .adversary import *
+from ..utils.imports import _HAS_LIGHTNING
 from .adversary_wrapper import *
 from .composer import *
 from .enforcer import *
@@ -8,3 +8,6 @@
 from .objective import *
 from .perturber import *
 from .projector import *
+
+if _HAS_LIGHTNING:
+    from .adversary import *
diff --git a/mart/attack/initializer/base.py b/mart/attack/initializer/base.py
index 49a5b173..9fd91091 100644
--- a/mart/attack/initializer/base.py
+++ b/mart/attack/initializer/base.py
@@ -10,10 +10,6 @@

 import torch

-from mart.utils import pylogger
-
-logger = pylogger.get_pylogger(__name__)
-

 class Initializer:
     """Initializer base class."""
diff --git a/mart/attack/initializer/vision.py b/mart/attack/initializer/vision.py
index 363141ef..0824fe26 100644
--- a/mart/attack/initializer/vision.py
+++ b/mart/attack/initializer/vision.py
@@ -4,14 +4,15 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #

+import logging
+
 import torch
 import torchvision
 import torchvision.transforms.functional as F

-from ...utils import pylogger
 from .base import Initializer

-logger = pylogger.get_pylogger(__name__)
+logger = logging.getLogger(__name__)


 class Image(Initializer):
diff --git a/mart/attack/perturber.py b/mart/attack/perturber.py
index 76ff1f42..493cf2e4 100644
--- a/mart/attack/perturber.py
+++ b/mart/attack/perturber.py
@@ -9,7 +9,6 @@
 from typing import TYPE_CHECKING, Callable, Iterable, Sequence

 import torch
-from lightning.pytorch.utilities.exceptions import MisconfigurationException

 from .projector import Projector

@@ -85,21 +84,19 @@ def create_from_tensor(tensor):

     def named_parameters(self, *args, **kwargs):
         if self.perturbation is None:
-            raise MisconfigurationException("You need to call configure_perturbation before fit.")
+            raise RuntimeError("You need to call configure_perturbation before fit.")

         return super().named_parameters(*args, **kwargs)

     def parameters(self, *args, **kwargs):
         if self.perturbation is None:
-            raise MisconfigurationException("You need to call configure_perturbation before fit.")
+            raise RuntimeError("You need to call configure_perturbation before fit.")

         return super().parameters(*args, **kwargs)

     def forward(self, **batch):
         if self.perturbation is None:
-            raise MisconfigurationException(
-                "You need to call the configure_perturbation before forward."
-            )
+            raise RuntimeError("You need to call the configure_perturbation before forward.")

         self.projector_(self.perturbation, **batch)
         # We need to register the hook at every forward pass.
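The guarded import in mart/__init__.py above is the pattern the rest of this patch builds on: Lightning-backed modules load only when `_HAS_LIGHTNING` is true. A minimal sketch of how downstream code can reuse the same guard (the `None` fallback is illustrative, not part of the patch):

from mart.utils.imports import _HAS_LIGHTNING

if _HAS_LIGHTNING:
    # Available only when lightning is installed, mirroring mart/__init__.py above.
    from mart import datamodules, models
else:
    # Illustrative fallback: callers should consult _HAS_LIGHTNING before using these.
    datamodules = models = None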
diff --git a/mart/callbacks/__init__.py b/mart/callbacks/__init__.py
index 5e648370..d04a6d26 100644
--- a/mart/callbacks/__init__.py
+++ b/mart/callbacks/__init__.py
@@ -1,3 +1,4 @@
+# All callbacks here depend on Lightning, so we don't import mart.callbacks by default.
 from ..utils.imports import _HAS_TORCHVISION
 from .adversary_connector import *
 from .eval_mode import *
diff --git a/mart/utils/__init__.py b/mart/utils/__init__.py
index 30e41c0a..6e05a092 100644
--- a/mart/utils/__init__.py
+++ b/mart/utils/__init__.py
@@ -1,11 +1,12 @@
+# Only import components without external dependencies.
 from .adapters import *
-from .config import *
-from .imports import _HAS_TORCHVISION
+from .imports import _HAS_LIGHTNING
 from .monkey_patch import *
-from .pylogger import *
-from .rich_utils import *
+from .optimization import *
 from .silent import *
 from .utils import *

-if _HAS_TORCHVISION:
-    from .export import *
+if _HAS_LIGHTNING:
+    from .lightning import *
+    from .pylogger import *
+    from .rich_utils import *
diff --git a/mart/utils/imports.py b/mart/utils/imports.py
index d053ce5e..01a82a8c 100644
--- a/mart/utils/imports.py
+++ b/mart/utils/imports.py
@@ -4,17 +4,19 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #

+import logging
 from importlib.util import find_spec

-from .pylogger import get_pylogger
-
-logger = get_pylogger(__name__)
+# Avoid importing .pylogger here, because import checks run before any other code.
+logger = logging.getLogger(__name__)


 def has(module_name):
     module = find_spec(module_name)
     if module is None:
-        logger.warn(f"{module_name} is not installed, so some features in MART are unavailable.")
+        logger.warning(
+            f"{module_name} is not installed, so some features in MART are unavailable."
+        )
         return False
     else:
         return True
@@ -25,3 +27,4 @@
 _HAS_TORCHVISION = has("torchvision")
 _HAS_TIMM = has("timm")
 _HAS_PYCOCOTOOLS = has("pycocotools")
+_HAS_LIGHTNING = has("lightning")
diff --git a/mart/utils/lightning.py b/mart/utils/lightning.py
new file mode 100644
index 00000000..bc3ef69a
--- /dev/null
+++ b/mart/utils/lightning.py
@@ -0,0 +1,275 @@
+import os
+import time
+import warnings
+from collections import OrderedDict
+from glob import glob
+from importlib.util import find_spec
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import hydra
+from hydra.core.hydra_config import HydraConfig
+from lightning.pytorch import Callback
+from lightning.pytorch.loggers import Logger
+from lightning.pytorch.utilities import rank_zero_only
+from lightning.pytorch.utilities.model_summary import summarize
+from omegaconf import DictConfig, OmegaConf
+
+from mart.utils import pylogger, rich_utils
+
+__all__ = [
+    "close_loggers",
+    "extras",
+    "get_metric_value",
+    "get_resume_checkpoint",
+    "instantiate_callbacks",
+    "instantiate_loggers",
+    "log_hyperparameters",
+    "save_file",
+    "task_wrapper",
+]
+
+log = pylogger.get_pylogger(__name__)
+
+
+def task_wrapper(task_func: Callable) -> Callable:
+    """Optional decorator that wraps the task function in extra utilities.
+
+    Makes multirun more resistant to failure.
+
+    Utilities:
+    - Calling the `utils.extras()` before the task is started
+    - Calling the `utils.close_loggers()` after the task is finished
+    - Logging the exception if occurs
+    - Logging the task total execution time
+    - Logging the output dir
+    """
+
+    def wrap(cfg: DictConfig):
+
+        # apply extra utilities
+        extras(cfg)
+
+        # execute the task
+        try:
+            start_time = time.time()
+            metric_dict, object_dict = task_func(cfg=cfg)
+        except Exception as ex:
+            log.exception("")  # save exception to `.log` file
+            raise ex
+        finally:
+            path = Path(cfg.paths.output_dir, "exec_time.log")
+            content = f"'{cfg.task_name}' execution time: {time.time() - start_time} (s)"
+            save_file(path, content)  # save task execution time (even if exception occurs)
+            close_loggers()  # close loggers (even if exception occurs so multirun won't fail)
+
+        log.info(f"Output dir: {cfg.paths.output_dir}")
+
+        return metric_dict, object_dict
+
+    return wrap
+
+
+def extras(cfg: DictConfig) -> None:
+    """Applies optional utilities before the task is started.
+
+    Utilities:
+    - Ignoring python warnings
+    - Setting tags from command line
+    - Rich config printing
+    """
+
+    # return if no `extras` config
+    if not cfg.get("extras"):
+        log.warning("Extras config not found! ")
+        return
+
+    # disable python warnings
+    if cfg.extras.get("ignore_warnings"):
+        log.info("Disabling python warnings! ")
+        warnings.filterwarnings("ignore")
+
+    # prompt user to input tags from command line if none are provided in the config
+    if cfg.extras.get("enforce_tags"):
+        log.info("Enforcing tags! ")
+        rich_utils.enforce_tags(cfg, save_to_file=True)
+
+    # pretty print config tree using Rich library
+    if cfg.extras.get("print_config"):
+        log.info("Printing config tree with Rich! ")
+        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
+
+
+@rank_zero_only
+def save_file(path, content) -> None:
+    """Save file in rank zero mode (only on one process in multi-GPU setup)."""
+    with open(path, "w+") as file:
+        file.write(content)
+
+
+def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
+    """Instantiates callbacks from config."""
+    callbacks: List[Callback] = []
+
+    if not callbacks_cfg:
+        log.warning("Callbacks config is empty.")
+        return callbacks
+
+    if not isinstance(callbacks_cfg, DictConfig):
+        raise TypeError("Callbacks config must be a DictConfig!")
+
+    for _, cb_conf in callbacks_cfg.items():
+        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
+            log.info(f"Instantiating callback <{cb_conf._target_}>")
+            callbacks.append(hydra.utils.instantiate(cb_conf))
+
+    return callbacks
+
+
+def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
+    """Instantiates loggers from config."""
+    logger: List[Logger] = []
+
+    if not logger_cfg:
+        log.warning("Logger config is empty.")
+        return logger
+
+    if not isinstance(logger_cfg, DictConfig):
+        raise TypeError("Logger config must be a DictConfig!")
+
+    for _, lg_conf in logger_cfg.items():
+        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
+            log.info(f"Instantiating logger <{lg_conf._target_}>")
+            logger.append(hydra.utils.instantiate(lg_conf))
+
+    return logger
+
+
+@rank_zero_only
+def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
+    """Controls which config parts are saved by lightning loggers.
+
+    Additionally saves:
+    - Number of model parameters
+    """
+
+    hparams = {}
+
+    cfg = object_dict["cfg"]
+    model = object_dict["model"]
+    trainer = object_dict["trainer"]
+
+    if not trainer.logger:
+        log.warning("Logger not found! Skipping hyperparameter logging...")
+        return
+
+    hparams["model"] = cfg["model"]
+
+    # save number of model parameters
+    summary = summarize(model)
+
+    hparams["model/params/total"] = summary.total_parameters
+    hparams["model/params/trainable"] = summary.trainable_parameters
+    hparams["model/params/non_trainable"] = summary.total_parameters - summary.trainable_parameters
+
+    hparams["datamodule"] = cfg["datamodule"]
+    hparams["trainer"] = cfg["trainer"]
+
+    hparams["callbacks"] = cfg.get("callbacks")
+    hparams["extras"] = cfg.get("extras")
+
+    hparams["task_name"] = cfg.get("task_name")
+    hparams["tags"] = cfg.get("tags")
+    hparams["ckpt_path"] = cfg.get("ckpt_path")
+    hparams["seed"] = cfg.get("seed")
+
+    # send hparams to all loggers
+    trainer.logger.log_hyperparams(hparams)
+
+
+def get_metric_value(metric_dict: dict, metric_name: str) -> float:
+    """Safely retrieves value of the metric logged in LightningModule."""
+
+    if not metric_name:
+        log.info("Metric name is None! Skipping metric value retrieval...")
+        return None
+
+    if metric_name not in metric_dict:
+        raise Exception(
+            f"Metric value not found! \n"
+            "Make sure metric name logged in LightningModule is correct!\n"
+            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
+        )
+
+    metric_value = metric_dict[metric_name].item()
+    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
+
+    return metric_value
+
+
+def close_loggers() -> None:
+    """Makes sure all loggers closed properly (prevents logging failure during multirun)."""
+
+    log.info("Closing loggers...")
+
+    if find_spec("wandb"):  # if wandb is installed
+        import wandb
+
+        if wandb.run:
+            log.info("Closing wandb!")
+            wandb.finish()
+
+
+def get_resume_checkpoint(config: DictConfig) -> Tuple[DictConfig]:
+    """Resume a task from an existing checkpoint along with the config."""
+
+    resume_checkpoint = None
+
+    if config.get("resume"):
+        resume_filename = os.path.join("checkpoints", "*.ckpt")
+        resume_dir = hydra.utils.to_absolute_path(config.resume)
+
+        # If we pass an explicit path, parse it to get base directory and checkpoint filename
+        if os.path.isfile(resume_dir):
+            resume_path = Path(resume_dir)
+
+            # Resume path looks something like:
+            # /path/to/checkpoints/checkpoint_name.ckpt
+            # So we pass in to base directory and checkpoint filename and "checkpoints" directory
+            resume_dir = os.path.join(*resume_path.parts[:-2])
+            resume_filename = os.path.join(*resume_path.parts[-2:])
+
+        # Get old and new overrides and combine them
+        current_overrides = HydraConfig.get().overrides.task
+        overrides_config = OmegaConf.load(os.path.join(resume_dir, ".hydra", "overrides.yaml"))
+        overrides = overrides_config + current_overrides
+
+        # Find checkpoint and set PL trainer to resume
+        resume_checkpoint = glob(os.path.join(resume_dir, resume_filename))
+
+        if len(resume_checkpoint) == 0:
+            msg = f"No checkpoint found in {os.path.join(resume_dir, resume_filename)}!"
+            log.error(msg)
+            raise Exception(msg)
+
+        # If we find more than 1 checkpoint, tell the user to be more explicit about their choice
+        # of checkpoint to resume from!
+        if len(resume_checkpoint) > 1:
+            msg = f"Found more than 1 checkpoint in {resume_dir} so you must pass a checkpoint path to resume:"
+            log.error(msg)
+            for path in resume_checkpoint:
+                log.error(f"  {path}")
+            raise Exception(msg)
+
+        resume_checkpoint = resume_checkpoint[0]
+        log.info(f"Resuming from {resume_checkpoint}")
+        # Save the ckpt_path in cfg for fit() and test().
+        # This override won't be written to disk .hydra/overrides.yaml
+        overrides += [f"+ckpt_path={resume_checkpoint}"]
+
+        # Load hydra.conf and use job config name to load original config with overrides
+        hydra_config = OmegaConf.load(os.path.join(resume_dir, ".hydra", "hydra.yaml"))
+        config_name = hydra_config.hydra.job.config_name
+        config = hydra.compose(config_name, overrides=overrides)
+
+    return config
diff --git a/mart/utils/utils.py b/mart/utils/utils.py
index f4a0a4ec..b19e87b0 100644
--- a/mart/utils/utils.py
+++ b/mart/utils/utils.py
@@ -1,280 +1,9 @@
-import os
-import time
-import warnings
-from collections import OrderedDict
-from glob import glob
-from importlib.util import find_spec
-from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple
-
-import hydra
-from hydra.core.hydra_config import HydraConfig
-from lightning.pytorch import Callback
-from lightning.pytorch.loggers import Logger
-from lightning.pytorch.utilities import rank_zero_only
-from lightning.pytorch.utilities.model_summary import summarize
-from omegaconf import DictConfig, OmegaConf
-
-from mart.utils import pylogger, rich_utils
+from typing import Optional

 __all__ = [
-    "close_loggers",
-    "extras",
-    "get_metric_value",
-    "get_resume_checkpoint",
-    "instantiate_callbacks",
-    "instantiate_loggers",
-    "log_hyperparameters",
-    "save_file",
-    "task_wrapper",
     "flatten_dict",
 ]

-log = pylogger.get_pylogger(__name__)
-
-
-def task_wrapper(task_func: Callable) -> Callable:
-    """Optional decorator that wraps the task function in extra utilities.
-
-    Makes multirun more resistant to failure.
-
-    Utilities:
-    - Calling the `utils.extras()` before the task is started
-    - Calling the `utils.close_loggers()` after the task is finished
-    - Logging the exception if occurs
-    - Logging the task total execution time
-    - Logging the output dir
-    """
-
-    def wrap(cfg: DictConfig):
-
-        # apply extra utilities
-        extras(cfg)
-
-        # execute the task
-        try:
-            start_time = time.time()
-            metric_dict, object_dict = task_func(cfg=cfg)
-        except Exception as ex:
-            log.exception("")  # save exception to `.log` file
-            raise ex
-        finally:
-            path = Path(cfg.paths.output_dir, "exec_time.log")
-            content = f"'{cfg.task_name}' execution time: {time.time() - start_time} (s)"
-            save_file(path, content)  # save task execution time (even if exception occurs)
-            close_loggers()  # close loggers (even if exception occurs so multirun won't fail)
-
-        log.info(f"Output dir: {cfg.paths.output_dir}")
-
-        return metric_dict, object_dict
-
-    return wrap
-
-
-def extras(cfg: DictConfig) -> None:
-    """Applies optional utilities before the task is started.
-
-    Utilities:
-    - Ignoring python warnings
-    - Setting tags from command line
-    - Rich config printing
-    """
-
-    # return if no `extras` config
-    if not cfg.get("extras"):
-        log.warning("Extras config not found! ")
-        return
-
-    # disable python warnings
-    if cfg.extras.get("ignore_warnings"):
-        log.info("Disabling python warnings! ")
-        warnings.filterwarnings("ignore")
-
-    # prompt user to input tags from command line if none are provided in the config
-    if cfg.extras.get("enforce_tags"):
-        log.info("Enforcing tags! ")
-        rich_utils.enforce_tags(cfg, save_to_file=True)
-
-    # pretty print config tree using Rich library
-    if cfg.extras.get("print_config"):
-        log.info("Printing config tree with Rich! ")
") - rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) - - -@rank_zero_only -def save_file(path, content) -> None: - """Save file in rank zero mode (only on one process in multi-GPU setup).""" - with open(path, "w+") as file: - file.write(content) - - -def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: - """Instantiates callbacks from config.""" - callbacks: List[Callback] = [] - - if not callbacks_cfg: - log.warning("Callbacks config is empty.") - return callbacks - - if not isinstance(callbacks_cfg, DictConfig): - raise TypeError("Callbacks config must be a DictConfig!") - - for _, cb_conf in callbacks_cfg.items(): - if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: - log.info(f"Instantiating callback <{cb_conf._target_}>") - callbacks.append(hydra.utils.instantiate(cb_conf)) - - return callbacks - - -def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: - """Instantiates loggers from config.""" - logger: List[Logger] = [] - - if not logger_cfg: - log.warning("Logger config is empty.") - return logger - - if not isinstance(logger_cfg, DictConfig): - raise TypeError("Logger config must be a DictConfig!") - - for _, lg_conf in logger_cfg.items(): - if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: - log.info(f"Instantiating logger <{lg_conf._target_}>") - logger.append(hydra.utils.instantiate(lg_conf)) - - return logger - - -@rank_zero_only -def log_hyperparameters(object_dict: Dict[str, Any]) -> None: - """Controls which config parts are saved by lightning loggers. - - Additionally saves: - - Number of model parameters - """ - - hparams = {} - - cfg = object_dict["cfg"] - model = object_dict["model"] - trainer = object_dict["trainer"] - - if not trainer.logger: - log.warning("Logger not found! Skipping hyperparameter logging...") - return - - hparams["model"] = cfg["model"] - - # save number of model parameters - summary = summarize(model) - - hparams["model/params/total"] = summary.total_parameters - hparams["model/params/trainable"] = summary.trainable_parameters - hparams["model/params/non_trainable"] = summary.total_parameters - summary.trainable_parameters - - hparams["datamodule"] = cfg["datamodule"] - hparams["trainer"] = cfg["trainer"] - - hparams["callbacks"] = cfg.get("callbacks") - hparams["extras"] = cfg.get("extras") - - hparams["task_name"] = cfg.get("task_name") - hparams["tags"] = cfg.get("tags") - hparams["ckpt_path"] = cfg.get("ckpt_path") - hparams["seed"] = cfg.get("seed") - - # send hparams to all loggers - trainer.logger.log_hyperparams(hparams) - - -def get_metric_value(metric_dict: dict, metric_name: str) -> float: - """Safely retrieves value of the metric logged in LightningModule.""" - - if not metric_name: - log.info("Metric name is None! Skipping metric value retrieval...") - return None - - if metric_name not in metric_dict: - raise Exception( - f"Metric value not found! \n" - "Make sure metric name logged in LightningModule is correct!\n" - "Make sure `optimized_metric` name in `hparams_search` config is correct!" - ) - - metric_value = metric_dict[metric_name].item() - log.info(f"Retrieved metric value! 
-
-    return metric_value
-
-
-def close_loggers() -> None:
-    """Makes sure all loggers closed properly (prevents logging failure during multirun)."""
-
-    log.info("Closing loggers...")
-
-    if find_spec("wandb"):  # if wandb is installed
-        import wandb
-
-        if wandb.run:
-            log.info("Closing wandb!")
-            wandb.finish()
-
-
-def get_resume_checkpoint(config: DictConfig) -> Tuple[DictConfig]:
-    """Resume a task from an existing checkpoint along with the config."""
-
-    resume_checkpoint = None
-
-    if config.get("resume"):
-        resume_filename = os.path.join("checkpoints", "*.ckpt")
-        resume_dir = hydra.utils.to_absolute_path(config.resume)
-
-        # If we pass an explicit path, parse it to get base directory and checkpoint filename
-        if os.path.isfile(resume_dir):
-            resume_path = Path(resume_dir)
-
-            # Resume path looks something like:
-            # /path/to/checkpoints/checkpoint_name.ckpt
-            # So we pass in to base directory and checkpoint filename and "checkpoints" directory
-            resume_dir = os.path.join(*resume_path.parts[:-2])
-            resume_filename = os.path.join(*resume_path.parts[-2:])
-
-        # Get old and new overrides and combine them
-        current_overrides = HydraConfig.get().overrides.task
-        overrides_config = OmegaConf.load(os.path.join(resume_dir, ".hydra", "overrides.yaml"))
-        overrides = overrides_config + current_overrides
-
-        # Find checkpoint and set PL trainer to resume
-        resume_checkpoint = glob(os.path.join(resume_dir, resume_filename))
-
-        if len(resume_checkpoint) == 0:
-            msg = f"No checkpoint found in {os.path.join(resume_dir, resume_filename)}!"
-            log.error(msg)
-            raise Exception(msg)
-
-        # If we find more than 1 checkpoint, tell the user to be more explicit about their choice
-        # of checkpoint to resume from!
-        if len(resume_checkpoint) > 1:
-            msg = f"Found more than 1 checkpoint in {resume_dir} so you must pass a checkpoint path to resume:"
-            log.error(msg)
-            for path in resume_checkpoint:
-                log.error(f"  {path}")
-            raise Exception(msg)
-
-        resume_checkpoint = resume_checkpoint[0]
-        log.info(f"Resuming from {resume_checkpoint}")
-        # Save the ckpt_path in cfg for fit() and test().
-        # This override won't be written to disk .hydra/overrides.yaml
-        overrides += [f"+ckpt_path={resume_checkpoint}"]
-
-        # Load hydra.conf and use job config name to load original config with overrides
-        hydra_config = OmegaConf.load(os.path.join(resume_dir, ".hydra", "hydra.yaml"))
-        config_name = hydra_config.hydra.job.config_name
-        config = hydra.compose(config_name, overrides=overrides)
-
-    return config
-

 def flatten_dict(d, delimiter="."):
     def get_dottedpath_items(d: dict, parent: Optional[str] = None):
diff --git a/tests/test_dependency.py b/tests/test_dependency.py
index 5ee5ed24..4b7961b9 100644
--- a/tests/test_dependency.py
+++ b/tests/test_dependency.py
@@ -8,6 +8,7 @@

 from mart.utils.imports import (
     _HAS_FIFTYONE,
+    _HAS_LIGHTNING,
     _HAS_PYCOCOTOOLS,
     _HAS_TIMM,
     _HAS_TORCHVISION,
@@ -17,5 +18,9 @@
 def test_dependency_on_ci():
     if os.getenv("CI") == "true":
         assert (
-            _HAS_FIFTYONE and _HAS_TIMM and _HAS_PYCOCOTOOLS and _HAS_TORCHVISION is True
+            _HAS_FIFTYONE
+            and _HAS_TIMM
+            and _HAS_PYCOCOTOOLS
+            and _HAS_TORCHVISION
+            and _HAS_LIGHTNING is True
         ), "The dependency is not complete on CI, thus some tests are skipped."
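These flags also drive test skipping: when a dependency is missing locally, the corresponding tests are skipped rather than failed, and the CI assertion above only guarantees nothing is skipped on CI. A sketch of the usual pytest guard, with a hypothetical test name; the marker itself is an assumption, not code from this patch:

import pytest

from mart.utils.imports import _HAS_LIGHTNING

requires_lightning = pytest.mark.skipif(not _HAS_LIGHTNING, reason="lightning is not installed")


@requires_lightning
def test_lightning_only_feature():
    # Only importable when lightning is installed; see mart/attack/__init__.py above.
    from mart.attack import adversary  # noqa: F401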
diff --git a/tests/test_perturber.py b/tests/test_perturber.py
index af627359..cc7cdf64 100644
--- a/tests/test_perturber.py
+++ b/tests/test_perturber.py
@@ -4,15 +4,10 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #

-from functools import partial
 from unittest.mock import Mock

 import pytest
-import torch
-from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from torch.optim import SGD

-import mart
 from mart.attack import Perturber


@@ -36,10 +31,10 @@

     perturber = Perturber(initializer=initializer, projector=projector)

-    with pytest.raises(MisconfigurationException):
+    with pytest.raises(RuntimeError):
         perturber(input=input_data, target=target_data)

-    with pytest.raises(MisconfigurationException):
+    with pytest.raises(RuntimeError):
         perturber.parameters()
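For completeness, a minimal sketch of the behaviour the updated test checks: with Lightning removed from the error path, using a Perturber before configure_perturbation now raises a plain RuntimeError (the Mock arguments mirror the fixtures in tests/test_perturber.py):

from unittest.mock import Mock

import pytest

from mart.attack import Perturber

perturber = Perturber(initializer=Mock(), projector=Mock())

# No perturbation has been configured yet, so parameter access fails loudly.
with pytest.raises(RuntimeError):
    perturber.parameters()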