forked from microsoft/TimeCraft
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_inference.py
More file actions
53 lines (39 loc) · 1.96 KB
/
train_inference.py
File metadata and controls
53 lines (39 loc) · 1.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import sys
import traceback
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'diffusion'))
from pytorch_lightning.trainer import Trainer
from diffusion.utils.cli_utils import get_parser
from diffusion.utils.init_utils import init_model_data_trainer
from diffusion.utils.test_utils import test_model_with_dp, test_model_uncond, test_model_unseen, test_model_guidance
def select_eval_mode(opt):
    """Decide which post-training evaluation routine to run.

    The three routines are mutually exclusive; exactly one is selected:

    * ``"guidance"`` when ``opt.use_guidance`` is truthy;
    * ``"uncond"`` when ``opt.uncond`` is truthy and guidance is off;
    * ``"dp"`` otherwise (runs ``test_model_with_dp`` followed by
      ``test_model_unseen``).

    Args:
        opt: Parsed command-line options; only the ``uncond`` and
            ``use_guidance`` attributes are read.

    Returns:
        One of ``"guidance"``, ``"uncond"`` or ``"dp"``.
    """
    if opt.use_guidance:
        return "guidance"
    if opt.uncond:
        return "uncond"
    return "dp"


if __name__ == "__main__":
    # NOTE: DATA_ROOT validation is currently disabled; kept for reference.
    # data_root = os.environ.get('DATA_ROOT', None)
    # if not data_root or not os.path.exists(data_root):
    #     raise ValueError("DATA_ROOT is not defined or does not exist!")
    parser = get_parser()
    # Expose the Lightning Trainer's own options on the same CLI parser.
    parser = Trainer.add_argparse_args(parser)

    model, data, trainer, opt, logdir, melk = init_model_data_trainer(parser)

    if opt.train:
        try:
            # Record the CLI options in the experiment config. Assumes the
            # logger's experiment exposes a wandb-style ``config.update``
            # (TODO confirm for other loggers).
            trainer.logger.experiment.config.update(opt)
            trainer.fit(model, data)
        except Exception:
            print("Exception occurred during training!")
            print(traceback.format_exc())
            # melk() is used to save a checkpoint on failure, which is only
            # attempted once a LightningModule is attached to the trainer.
            if trainer is not None and trainer.lightning_module is not None:
                print("Attempting to save checkpoint in exception handler via melk() ...")
                melk()
            else:
                print("Skipped calling melk() because trainer.lightning_module is None")
            # Bare raise re-raises the active exception with its traceback.
            raise

    if not opt.no_test and not getattr(trainer, "interrupted", False):
        # Exactly one evaluation routine runs; previously the unconditional
        # case also fell through into the dp/unseen tests.
        mode = select_eval_mode(opt)
        if mode == "guidance":
            test_model_guidance(model, data, trainer, opt, logdir)
        elif mode == "uncond":
            test_model_uncond(model, data, trainer, opt, logdir)
        else:
            test_model_with_dp(
                model, data, trainer, opt, logdir,
                use_pam=opt.use_pam, use_text=opt.use_text,
            )
            test_model_unseen(
                model, data, trainer, opt, logdir,
                use_pam=opt.use_pam, use_text=opt.use_text,
            )