-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
97 lines (77 loc) · 3.19 KB
/
main.py
File metadata and controls
97 lines (77 loc) · 3.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# Copyright (c) 2025-present, Royal Bank of Canada.
# Copyright (c) 2025-present, Kim et al.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
##########################################################################################
# Code is originally from the TAFAS (https://arxiv.org/pdf/2501.04970.pdf) implementation
# from https://github.com/kimanki/TAFAS by Kim et al. which is licensed under
# Modified MIT License (Non-Commercial with Permission).
# You may obtain a copy of the License at
#
# https://github.com/kimanki/TAFAS/blob/master/LICENSE
#
###########################################################################################
import os
from models.build import build_model, load_best_model, build_norm_module
from utils.parser import parse_args, load_config
from datasets.build import update_cfg_from_dataset
from trainer import build_trainer
from predictor import Predictor
from utils.misc import set_seeds, set_devices
from tta.tafas import build_adapter
import tta.cosa as cosa
import tta.petsa as petsa
import tta.dynatta as dynatta
from config import get_norm_module_cfg
def main():
    """Run the pipeline stages enabled in the configuration.

    The configuration is parsed from the command line, updated from the
    dataset named by ``cfg.DATA.NAME``, and dumped to ``config.yaml`` inside
    the result directory. Depending on the ``ENABLE`` flags, the function
    then trains the model (``cfg.TRAIN``), runs a test-time adapter
    (``cfg.TTA``), and runs the predictor (``cfg.TEST``).

    The adapter is selected by the first tag that occurs as a substring of
    ``cfg.RESULT_DIR``, checked in this order: ``TAFAS``, ``PETSA``,
    ``SIMPLE``, ``DYNATTA``. If TTA is enabled but no tag matches, a warning
    is emitted and adaptation is skipped.
    """
    import warnings

    args = parse_args()
    cfg = load_config(args)
    update_cfg_from_dataset(cfg, cfg.DATA.NAME)

    # Results are stored under RESULT_DIR, in a sub-path derived from the
    # checkpoint directory with any leading './checkpoints/' removed.
    cfg.RESULT_DIR = os.path.join(cfg.RESULT_DIR, cfg.TRAIN.CHECKPOINT_DIR.split('./checkpoints/')[-1])
    # exist_ok avoids the check-then-create race when several runs start at once.
    os.makedirs(cfg.RESULT_DIR, exist_ok=True)

    # select cuda devices
    set_devices(cfg.VISIBLE_DEVICES)

    # Keep a copy of the effective configuration next to the results.
    with open(os.path.join(cfg.RESULT_DIR, 'config.yaml'), 'w') as f:
        f.write(cfg.dump())

    # set random seed
    set_seeds(cfg.SEED)

    # build model
    model = build_model(cfg)
    norm_module = build_norm_module(cfg) if cfg.NORM_MODULE.ENABLE else None

    if cfg.TRAIN.ENABLE:
        # build trainer
        trainer = build_trainer(cfg, model, norm_module=norm_module)
        trainer.train()

    if cfg.TTA.ENABLE or cfg.TEST.ENABLE:
        # Both adaptation and prediction start from the best saved weights.
        model = load_best_model(cfg, model)
        if cfg.NORM_MODULE.ENABLE:
            norm_module = load_best_model(get_norm_module_cfg(cfg), norm_module)

    if cfg.TTA.ENABLE:
        # Ordered (tag, builder) pairs; order matters because the first tag
        # found in RESULT_DIR wins, exactly as in an if/elif chain.
        adapter_builders = (
            ('TAFAS', build_adapter),
            ('PETSA', petsa.build_adapter),
            ('SIMPLE', cosa.build_adapter),
            ('DYNATTA', dynatta.build_adapter),
        )
        build_fn = next(
            (builder for tag, builder in adapter_builders if tag in cfg.RESULT_DIR),
            None,
        )
        if build_fn is None:
            known_tags = ', '.join(tag for tag, _ in adapter_builders)
            warnings.warn(
                f"TTA is enabled but RESULT_DIR {cfg.RESULT_DIR!r} contains none of "
                f"the adapter tags ({known_tags}); skipping adaptation."
            )
        else:
            adapter = build_fn(cfg, model, norm_module=norm_module)
            adapter.adapt()
            adapter.count_parameters()

    if cfg.TEST.ENABLE:
        predictor = Predictor(cfg, model, norm_module=norm_module)
        predictor.predict()
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()