From 79366b978e7ede3e6e73168b22af7cac758194fa Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 21 Sep 2023 10:04:00 -0700 Subject: [PATCH 01/15] Add mart.generate_config --- mart/generate_config.py | 43 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 mart/generate_config.py diff --git a/mart/generate_config.py b/mart/generate_config.py new file mode 100644 index 00000000..db7cee61 --- /dev/null +++ b/mart/generate_config.py @@ -0,0 +1,43 @@ +import os + +import fire +from hydra import compose, initialize_config_dir +from omegaconf import OmegaConf + + +def generate( + *overrides, + version_base: str = "1.2", + config_dir: str = "configs", + config_name: str = "lightning.yaml", + export_node: str = None, + export_name: str = "output.yaml", + resolve: bool = False, +): + # An absolute path {config_dir} is added to the search path of configs, preceding those in mart.configs. + if not os.path.isabs(config_dir): + config_dir = os.path.abspath(config_dir) + + with initialize_config_dir(version_base=version_base, config_dir=config_dir): + cfg = compose(config_name=config_name, overrides=overrides) + + # Export a sub-tree. + if export_node is not None: + for key in export_node.split("."): + cfg = cfg[key] + + # Resolve all interpolation in the sub-tree. + if resolve: + OmegaConf.resolve(cfg) + + # Create folders for output if necessary. + folder = os.path.dirname(export_name) + if folder != "" and not os.path.isdir(folder): + os.makedirs(folder) + + OmegaConf.save(config=cfg, f=export_name) + print(f"Config file saved to {export_name}") + + +if __name__ == "__main__": + fire.Fire(generate) From ec13141947c8ba2b6b33d1986f6fa7995bcfeec7 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 21 Sep 2023 10:04:30 -0700 Subject: [PATCH 02/15] Add header. --- mart/generate_config.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mart/generate_config.py b/mart/generate_config.py index db7cee61..383000c6 100644 --- a/mart/generate_config.py +++ b/mart/generate_config.py @@ -1,3 +1,9 @@ +# +# Copyright (C) 2022 Intel Corporation +# +# SPDX-License-Identifier: BSD-3-Clause +# + import os import fire From da814ba0b5f7480c08e4157b5590d724704a36f3 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 21 Sep 2023 10:14:21 -0700 Subject: [PATCH 03/15] Print yaml to stdout. --- mart/generate_config.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/mart/generate_config.py b/mart/generate_config.py index 383000c6..0ca4cdd7 100644 --- a/mart/generate_config.py +++ b/mart/generate_config.py @@ -17,7 +17,6 @@ def generate( config_dir: str = "configs", config_name: str = "lightning.yaml", export_node: str = None, - export_name: str = "output.yaml", resolve: bool = False, ): # An absolute path {config_dir} is added to the search path of configs, preceding those in mart.configs. @@ -36,13 +35,8 @@ def generate( if resolve: OmegaConf.resolve(cfg) - # Create folders for output if necessary. - folder = os.path.dirname(export_name) - if folder != "" and not os.path.isdir(folder): - os.makedirs(folder) - - OmegaConf.save(config=cfg, f=export_name) - print(f"Config file saved to {export_name}") + # OmegaConf.to_yaml() already ends with `\n`. + print(OmegaConf.to_yaml(cfg), end="") if __name__ == "__main__": From a8f484c9d64d9a72658a510d911d3358fff73ca5 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Wed, 27 Sep 2023 06:22:45 -0700 Subject: [PATCH 04/15] Modualize components. --- mart/generate_config.py | 51 ++++++++++++++++++++++++++++++----------- 1 file changed, 38 insertions(+), 13 deletions(-) diff --git a/mart/generate_config.py b/mart/generate_config.py index 0ca4cdd7..17678657 100644 --- a/mart/generate_config.py +++ b/mart/generate_config.py @@ -11,33 +11,58 @@ from omegaconf import OmegaConf -def generate( +def mart_compose( *overrides, version_base: str = "1.2", config_dir: str = "configs", config_name: str = "lightning.yaml", export_node: str = None, - resolve: bool = False, ): - # An absolute path {config_dir} is added to the search path of configs, preceding those in mart.configs. + # Add an absolute path {config_dir} to the search path of configs, preceding those in mart.configs. if not os.path.isabs(config_dir): config_dir = os.path.abspath(config_dir) + # initialize_config_dir() requires an absolute path. with initialize_config_dir(version_base=version_base, config_dir=config_dir): cfg = compose(config_name=config_name, overrides=overrides) - # Export a sub-tree. - if export_node is not None: - for key in export_node.split("."): - cfg = cfg[key] + # Export a sub-tree. + if export_node is not None: + for key in export_node.split("."): + cfg = cfg[key] + + return cfg + + +def get_yaml_cfg(cfg, resolve: bool = False): + # Resolve all interpolation in the sub-tree. + if resolve: + OmegaConf.resolve(cfg) + + return OmegaConf.to_yaml(cfg) + + +def main( + *overrides, + version_base: str = "1.2", + config_dir: str = "configs", + config_name: str = "lightning.yaml", + export_node: str = None, + resolve: bool = False, +): + cfg = mart_compose( + *overrides, + version_base=version_base, + config_dir=config_dir, + config_name=config_name, + export_node=export_node, + ) - # Resolve all interpolation in the sub-tree. - if resolve: - OmegaConf.resolve(cfg) + cfg_yaml = get_yaml_cfg(cfg, resolve=resolve) - # OmegaConf.to_yaml() already ends with `\n`. - print(OmegaConf.to_yaml(cfg), end="") + # OmegaConf.to_yaml() already ends with `\n`. + print(cfg_yaml, end="") if __name__ == "__main__": - fire.Fire(generate) + fire.Fire(main) From b02d4c7db4ba0e2fc9beb6ba963e9807f5aa6128 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Wed, 27 Sep 2023 06:26:18 -0700 Subject: [PATCH 05/15] Comment. --- mart/generate_config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mart/generate_config.py b/mart/generate_config.py index 17678657..2145c037 100644 --- a/mart/generate_config.py +++ b/mart/generate_config.py @@ -22,7 +22,8 @@ def mart_compose( if not os.path.isabs(config_dir): config_dir = os.path.abspath(config_dir) - # initialize_config_dir() requires an absolute path. + # hydra.initialize_config_dir() requires an absolute path, + # while hydra.initialize() searches paths relatively to mart. with initialize_config_dir(version_base=version_base, config_dir=config_dir): cfg = compose(config_name=config_name, overrides=overrides) From d674b26026e450958be8efe130f4830dbbf67584 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Wed, 27 Sep 2023 06:32:24 -0700 Subject: [PATCH 06/15] Reuse default values. --- mart/generate_config.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/mart/generate_config.py b/mart/generate_config.py index 2145c037..c01ec4b2 100644 --- a/mart/generate_config.py +++ b/mart/generate_config.py @@ -10,12 +10,16 @@ from hydra import compose, initialize_config_dir from omegaconf import OmegaConf +DEFAULT_VERSION_BASE = "1.2" +DEFAULT_CONFIG_DIR = "." +DEFAULT_CONFIG_NAME = "lightning.yaml" + def mart_compose( *overrides, - version_base: str = "1.2", - config_dir: str = "configs", - config_name: str = "lightning.yaml", + version_base: str = DEFAULT_VERSION_BASE, + config_dir: str = DEFAULT_CONFIG_DIR, + config_name: str = DEFAULT_CONFIG_NAME, export_node: str = None, ): # Add an absolute path {config_dir} to the search path of configs, preceding those in mart.configs. @@ -45,9 +49,9 @@ def get_yaml_cfg(cfg, resolve: bool = False): def main( *overrides, - version_base: str = "1.2", - config_dir: str = "configs", - config_name: str = "lightning.yaml", + version_base: str = DEFAULT_VERSION_BASE, + config_dir: str = DEFAULT_CONFIG_DIR, + config_name: str = DEFAULT_CONFIG_NAME, export_node: str = None, resolve: bool = False, ): From 541c769d26c9b1fbe19f619ee3101c3742321e1c Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Wed, 27 Sep 2023 06:40:02 -0700 Subject: [PATCH 07/15] Use futuristic annotations. --- mart/generate_config.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mart/generate_config.py b/mart/generate_config.py index c01ec4b2..2fa5c281 100644 --- a/mart/generate_config.py +++ b/mart/generate_config.py @@ -4,6 +4,8 @@ # SPDX-License-Identifier: BSD-3-Clause # +from __future__ import annotations + import os import fire @@ -20,7 +22,7 @@ def mart_compose( version_base: str = DEFAULT_VERSION_BASE, config_dir: str = DEFAULT_CONFIG_DIR, config_name: str = DEFAULT_CONFIG_NAME, - export_node: str = None, + export_node: str | None = None, ): # Add an absolute path {config_dir} to the search path of configs, preceding those in mart.configs. if not os.path.isabs(config_dir): @@ -52,7 +54,7 @@ def main( version_base: str = DEFAULT_VERSION_BASE, config_dir: str = DEFAULT_CONFIG_DIR, config_name: str = DEFAULT_CONFIG_NAME, - export_node: str = None, + export_node: str | None = None, resolve: bool = False, ): cfg = mart_compose( From 22b83d49dca39b286eb0e4f4bceee8e0ac5a128d Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Wed, 27 Sep 2023 07:49:18 -0700 Subject: [PATCH 08/15] Move componenets to mart.utils.config. --- mart/generate_config.py | 48 +++++----------------------- mart/utils/config.py | 69 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 40 deletions(-) create mode 100644 mart/utils/config.py diff --git a/mart/generate_config.py b/mart/generate_config.py index 2fa5c281..2468c851 100644 --- a/mart/generate_config.py +++ b/mart/generate_config.py @@ -6,47 +6,15 @@ from __future__ import annotations -import os - import fire -from hydra import compose, initialize_config_dir -from omegaconf import OmegaConf - -DEFAULT_VERSION_BASE = "1.2" -DEFAULT_CONFIG_DIR = "." -DEFAULT_CONFIG_NAME = "lightning.yaml" - - -def mart_compose( - *overrides, - version_base: str = DEFAULT_VERSION_BASE, - config_dir: str = DEFAULT_CONFIG_DIR, - config_name: str = DEFAULT_CONFIG_NAME, - export_node: str | None = None, -): - # Add an absolute path {config_dir} to the search path of configs, preceding those in mart.configs. - if not os.path.isabs(config_dir): - config_dir = os.path.abspath(config_dir) - - # hydra.initialize_config_dir() requires an absolute path, - # while hydra.initialize() searches paths relatively to mart. - with initialize_config_dir(version_base=version_base, config_dir=config_dir): - cfg = compose(config_name=config_name, overrides=overrides) - - # Export a sub-tree. - if export_node is not None: - for key in export_node.split("."): - cfg = cfg[key] - - return cfg - - -def get_yaml_cfg(cfg, resolve: bool = False): - # Resolve all interpolation in the sub-tree. - if resolve: - OmegaConf.resolve(cfg) - return OmegaConf.to_yaml(cfg) +from .utils.config import ( + DEFAULT_CONFIG_DIR, + DEFAULT_CONFIG_NAME, + DEFAULT_VERSION_BASE, + compose, + get_yaml_cfg, +) def main( @@ -57,7 +25,7 @@ def main( export_node: str | None = None, resolve: bool = False, ): - cfg = mart_compose( + cfg = compose( *overrides, version_base=version_base, config_dir=config_dir, diff --git a/mart/utils/config.py b/mart/utils/config.py new file mode 100644 index 00000000..20204a5c --- /dev/null +++ b/mart/utils/config.py @@ -0,0 +1,69 @@ +# +# Copyright (C) 2022 Intel Corporation +# +# SPDX-License-Identifier: BSD-3-Clause +# + +from __future__ import annotations + +import os + +from hydra import compose as hydra_compose +from hydra import initialize_config_dir +from hydra.utils import instantiate as hydra_instantiate +from omegaconf import OmegaConf + +DEFAULT_VERSION_BASE = "1.2" +DEFAULT_CONFIG_DIR = "." +DEFAULT_CONFIG_NAME = "lightning.yaml" + + +def compose( + *overrides, + version_base: str = DEFAULT_VERSION_BASE, + config_dir: str = DEFAULT_CONFIG_DIR, + config_name: str = DEFAULT_CONFIG_NAME, + export_node: str | None = None, +): + # Add an absolute path {config_dir} to the search path of configs, preceding those in mart.configs. + if not os.path.isabs(config_dir): + config_dir = os.path.abspath(config_dir) + + # hydra.initialize_config_dir() requires an absolute path, + # while hydra.initialize() searches paths relatively to mart. + with initialize_config_dir(version_base=version_base, config_dir=config_dir): + cfg = hydra_compose(config_name=config_name, overrides=overrides) + + # Export a sub-tree. + if export_node is not None: + for key in export_node.split("."): + cfg = cfg[key] + + return cfg + + +def instantiate( + *overrides, + version_base: str = DEFAULT_VERSION_BASE, + config_dir: str = DEFAULT_CONFIG_DIR, + config_name: str = DEFAULT_CONFIG_NAME, + export_node: str | None = None, +): + cfg = compose( + *overrides, + version_base=version_base, + config_dir=config_dir, + config_name=config_name, + export_node=export_node, + ) + + obj = hydra_instantiate(cfg) + return obj + + +def get_yaml_cfg(cfg, resolve: bool = False): + # Resolve all interpolation in the sub-tree. + if resolve: + OmegaConf.resolve(cfg) + + return OmegaConf.to_yaml(cfg) From 7f4b5b8177fb158f2f50a11c8050e1705cf76218 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Wed, 27 Sep 2023 07:52:29 -0700 Subject: [PATCH 09/15] Update mart.utils imports. --- mart/utils/__init__.py | 1 + mart/utils/config.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/mart/utils/__init__.py b/mart/utils/__init__.py index 50e71b3d..b243a608 100644 --- a/mart/utils/__init__.py +++ b/mart/utils/__init__.py @@ -1,4 +1,5 @@ from .adapters import * +from .config import * from .export import * from .monkey_patch import * from .pylogger import * diff --git a/mart/utils/config.py b/mart/utils/config.py index 20204a5c..4dce772d 100644 --- a/mart/utils/config.py +++ b/mart/utils/config.py @@ -17,6 +17,8 @@ DEFAULT_CONFIG_DIR = "." DEFAULT_CONFIG_NAME = "lightning.yaml" +__all__ = ["compose", "instantiate", "get_yaml_cfg"] + def compose( *overrides, From fc6795463e42cf119bab50b8b8d79cc524c944c2 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Wed, 27 Sep 2023 10:38:19 -0700 Subject: [PATCH 10/15] Test composer configs. --- tests/test_composer.py | 43 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/tests/test_composer.py b/tests/test_composer.py index 93db2abd..c60d42c5 100644 --- a/tests/test_composer.py +++ b/tests/test_composer.py @@ -15,6 +15,7 @@ RectanglePad, RectanglePerspectiveTransform, ) +from mart.utils import instantiate def test_additive_composer_forward(input_data, target_data, perturbation): @@ -114,3 +115,45 @@ def test_rect_perspective_transform(): pert_coords_expected[:, 5:, 5:] = 1 # rounding numeric error from the perspective transformation. assert torch.equal(pert_coords.round(), pert_coords_expected) + + +def test_rect_patch_additive_composer(): + overrides = ["+attack/composer=rectangle_patch_additive"] + composer = instantiate(*overrides, export_node="attack.composer") + + input = torch.ones((3, 10, 10)) + perturbation = torch.ones_like(input) * 2 + + input_adv_expected = input.clone() + input_adv_expected[:, -5:, -5:] += 2 + + # A simple square patch on the bottom right. + patch_coords = torch.tensor(((5, 5), (10, 5), (10, 10), (5, 10))) + perturbable_mask = torch.zeros((10, 10)) + perturbable_mask[-5:, -5:] = 1 + + target = {"patch_coords": patch_coords, "perturbable_mask": perturbable_mask} + input_adv = composer(perturbation, input=input, target=target) + + assert torch.allclose(input_adv_expected, input_adv) + + +def test_rect_patch_overlay_composer(): + overrides = ["+attack/composer=rectangle_patch_overlay"] + composer = instantiate(*overrides, export_node="attack.composer") + + input = torch.ones((3, 10, 10)) + perturbation = torch.ones_like(input) * 2 + + input_adv_expected = input.clone() + input_adv_expected[:, -5:, -5:] = 2 + + # A simple square patch on the bottom right. + patch_coords = torch.tensor(((5, 5), (10, 5), (10, 10), (5, 10))) + perturbable_mask = torch.zeros((10, 10)) + perturbable_mask[-5:, -5:] = 1 + + target = {"patch_coords": patch_coords, "perturbable_mask": perturbable_mask} + input_adv = composer(perturbation, input=input, target=target) + + assert torch.allclose(input_adv_expected, input_adv) From 1c54c6ecc3afb6372d63115a228c53f50d8089a3 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Wed, 27 Sep 2023 10:41:02 -0700 Subject: [PATCH 11/15] Update config names. --- tests/test_composer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_composer.py b/tests/test_composer.py index c60d42c5..8fb10de1 100644 --- a/tests/test_composer.py +++ b/tests/test_composer.py @@ -118,7 +118,7 @@ def test_rect_perspective_transform(): def test_rect_patch_additive_composer(): - overrides = ["+attack/composer=rectangle_patch_additive"] + overrides = ["+attack/composer=rect_patch_additive"] composer = instantiate(*overrides, export_node="attack.composer") input = torch.ones((3, 10, 10)) @@ -139,7 +139,7 @@ def test_rect_patch_additive_composer(): def test_rect_patch_overlay_composer(): - overrides = ["+attack/composer=rectangle_patch_overlay"] + overrides = ["+attack/composer=rect_patch_overlay"] composer = instantiate(*overrides, export_node="attack.composer") input = torch.ones((3, 10, 10)) From 694a3b7cbc4338604f043ed6b771152e87646355 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 28 Sep 2023 10:04:57 -0700 Subject: [PATCH 12/15] Comment. --- mart/utils/config.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mart/utils/config.py b/mart/utils/config.py index 4dce772d..4dfa43c0 100644 --- a/mart/utils/config.py +++ b/mart/utils/config.py @@ -51,6 +51,10 @@ def instantiate( config_name: str = DEFAULT_CONFIG_NAME, export_node: str | None = None, ): + """Compose and instantiate an object. + + Should be useful in testing configs. + """ cfg = compose( *overrides, version_base=version_base, From 199a290dea217f713acc360f228dd8e8f9ae36dd Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 28 Sep 2023 15:57:00 -0700 Subject: [PATCH 13/15] Remove mart.utils.config.instantiate() for now. --- mart/utils/config.py | 26 +------------------------- 1 file changed, 1 insertion(+), 25 deletions(-) diff --git a/mart/utils/config.py b/mart/utils/config.py index 4dfa43c0..5851f526 100644 --- a/mart/utils/config.py +++ b/mart/utils/config.py @@ -10,14 +10,13 @@ from hydra import compose as hydra_compose from hydra import initialize_config_dir -from hydra.utils import instantiate as hydra_instantiate from omegaconf import OmegaConf DEFAULT_VERSION_BASE = "1.2" DEFAULT_CONFIG_DIR = "." DEFAULT_CONFIG_NAME = "lightning.yaml" -__all__ = ["compose", "instantiate", "get_yaml_cfg"] +__all__ = ["compose", "get_yaml_cfg"] def compose( @@ -44,29 +43,6 @@ def compose( return cfg -def instantiate( - *overrides, - version_base: str = DEFAULT_VERSION_BASE, - config_dir: str = DEFAULT_CONFIG_DIR, - config_name: str = DEFAULT_CONFIG_NAME, - export_node: str | None = None, -): - """Compose and instantiate an object. - - Should be useful in testing configs. - """ - cfg = compose( - *overrides, - version_base=version_base, - config_dir=config_dir, - config_name=config_name, - export_node=export_node, - ) - - obj = hydra_instantiate(cfg) - return obj - - def get_yaml_cfg(cfg, resolve: bool = False): # Resolve all interpolation in the sub-tree. if resolve: From 86971a1e12c9089e53539f5c19352f0c04b2b7cc Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 28 Sep 2023 16:05:32 -0700 Subject: [PATCH 14/15] Simplify to_yaml(). --- mart/generate_config.py | 5 +++-- mart/utils/config.py | 11 +---------- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/mart/generate_config.py b/mart/generate_config.py index 2468c851..f729c96e 100644 --- a/mart/generate_config.py +++ b/mart/generate_config.py @@ -7,13 +7,13 @@ from __future__ import annotations import fire +from omegaconf import OmegaConf from .utils.config import ( DEFAULT_CONFIG_DIR, DEFAULT_CONFIG_NAME, DEFAULT_VERSION_BASE, compose, - get_yaml_cfg, ) @@ -24,6 +24,7 @@ def main( config_name: str = DEFAULT_CONFIG_NAME, export_node: str | None = None, resolve: bool = False, + sort_keys: bool = True, ): cfg = compose( *overrides, @@ -33,7 +34,7 @@ def main( export_node=export_node, ) - cfg_yaml = get_yaml_cfg(cfg, resolve=resolve) + cfg_yaml = OmegaConf.to_yaml(cfg, resolve=resolve, sort_keys=sort_keys) # OmegaConf.to_yaml() already ends with `\n`. print(cfg_yaml, end="") diff --git a/mart/utils/config.py b/mart/utils/config.py index 5851f526..36dbb6e6 100644 --- a/mart/utils/config.py +++ b/mart/utils/config.py @@ -10,13 +10,12 @@ from hydra import compose as hydra_compose from hydra import initialize_config_dir -from omegaconf import OmegaConf DEFAULT_VERSION_BASE = "1.2" DEFAULT_CONFIG_DIR = "." DEFAULT_CONFIG_NAME = "lightning.yaml" -__all__ = ["compose", "get_yaml_cfg"] +__all__ = ["compose"] def compose( @@ -41,11 +40,3 @@ def compose( cfg = cfg[key] return cfg - - -def get_yaml_cfg(cfg, resolve: bool = False): - # Resolve all interpolation in the sub-tree. - if resolve: - OmegaConf.resolve(cfg) - - return OmegaConf.to_yaml(cfg) From 37415ba7e49399371a2e682580ecfaf504d5e80b Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 28 Sep 2023 16:16:20 -0700 Subject: [PATCH 15/15] Add mart.utils.instantiate for testing configs. --- mart/utils/config.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/mart/utils/config.py b/mart/utils/config.py index 36dbb6e6..384f4a31 100644 --- a/mart/utils/config.py +++ b/mart/utils/config.py @@ -10,12 +10,13 @@ from hydra import compose as hydra_compose from hydra import initialize_config_dir +from hydra.utils import instantiate as hydra_instantiate DEFAULT_VERSION_BASE = "1.2" DEFAULT_CONFIG_DIR = "." DEFAULT_CONFIG_NAME = "lightning.yaml" -__all__ = ["compose"] +__all__ = ["compose", "instantiate"] def compose( @@ -40,3 +41,26 @@ def compose( cfg = cfg[key] return cfg + + +def instantiate( + *overrides, + version_base: str = DEFAULT_VERSION_BASE, + config_dir: str = DEFAULT_CONFIG_DIR, + config_name: str = DEFAULT_CONFIG_NAME, + export_node: str | None = None, +): + """Compose and instantiate an object. + + Should be useful in testing configs. + """ + cfg = compose( + *overrides, + version_base=version_base, + config_dir=config_dir, + config_name=config_name, + export_node=export_node, + ) + + obj = hydra_instantiate(cfg) + return obj