forked from arthurkosmala/EwaldMP
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexperiment_oe62.py
More file actions
101 lines (83 loc) · 2.77 KB
/
experiment_oe62.py
File metadata and controls
101 lines (83 loc) · 2.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import logging
import os
import seml
import torch
from sacred import Experiment
from ocpmodels import models
from ocpmodels.common import logger
from ocpmodels.common.registry import registry
from ocpmodels.common.utils import setup_logging
from ocpmodels.trainers import EnergyTrainer
# Create the Sacred experiment object and route its log output through
# seml's logging setup; the decorators below (@ex.post_run_hook, @ex.config,
# @ex.automain) all attach to this instance.
ex = Experiment()
seml.setup_logger(ex)
@ex.post_run_hook
def collect_stats(_run):
    """Collect seml's experiment statistics (runtime, resources) after the run.

    Sacred injects the active Run object via the specially named ``_run``
    parameter — the name is significant and must not be changed.
    """
    seml.collect_exp_stats(_run)
@ex.config
def config():
    """Sacred config scope: default config values plus observer setup.

    NOTE: Sacred captures the *local variable names* defined in this function
    (``overwrite``, ``db_collection``) as config entries that seml overrides
    from its YAML config — do not rename them.
    """
    overwrite = None  # seml: id of an existing DB document to overwrite
    db_collection = None  # MongoDB collection name; None disables the observer
    if db_collection is not None:
        # Only attach the MongoDB observer when run under seml with a collection.
        ex.observers.append(
            seml.create_mongodb_observer(db_collection, overwrite=overwrite)
        )
@ex.automain
def run(
    dataset,
    task,
    model,
    optimizer,
    logger,
    name,
):
    """Train an ``EnergyTrainer`` model, then validate its best checkpoint.

    All parameters are injected by Sacred from the experiment config
    (their names must match the config keys). Returns a dict holding the
    validation metrics and the peak CUDA memory usage in MB; seml writes
    this dict into the results database.
    """
    setup_logging()

    # ---- Training phase ----
    # checkpoint_path_train = [checkpoint path if resuming from previous run]
    train_trainer = EnergyTrainer(
        task=task,
        model=model,
        dataset=dataset,
        optimizer=optimizer,
        identifier=name,
        run_dir="./",
        # directory to save results if is_debug=False. Prediction files are saved here so be careful not to override!
        is_debug=False,  # if True, do not save checkpoint, logs, or results
        print_every=5000,
        seed=0,  # random seed to use
        logger=logger,  # logger of choice (tensorboard and wandb supported)
        local_rank=0,
        amp=False,  # use PyTorch Automatic Mixed Precision (faster training and less memory usage)
    )
    # train_trainer.load_checkpoint(checkpoint_path=checkpoint_path_train)
    train_trainer.train()

    # Record peak CUDA memory (in MB) observed during training.
    mb = 1024 * 1024
    memory_stats = {
        "max_allocated": torch.cuda.max_memory_allocated() / mb,
        "max_reserved": torch.cuda.max_memory_reserved() / mb,
    }

    # ---- Validation phase ----
    # Evaluate the best checkpoint produced by the training run above with a
    # fresh trainer (is_debug=True so nothing from validation is persisted).
    best_checkpoint = os.path.join(
        train_trainer.config["cmd"]["checkpoint_dir"], "best_checkpoint.pt"
    )
    val_trainer = EnergyTrainer(
        task=task,
        model=model,
        dataset=dataset,
        optimizer=optimizer,
        identifier="val",
        run_dir="./",
        # directory to save results if is_debug=False. Prediction files are saved here so be careful not to override!
        is_debug=True,  # if True, do not save checkpoint, logs, or results
        print_every=5000,
        seed=0,  # random seed to use
        logger=logger,  # logger of choice (tensorboard and wandb supported)
        local_rank=0,
        amp=False,  # use PyTorch Automatic Mixed Precision (faster training and less memory usage)
    )
    val_trainer.load_checkpoint(checkpoint_path=best_checkpoint)
    val_metrics = val_trainer.validate()

    # the returned result will be written into the database
    return {
        "performance": {key: entry["metric"] for key, entry in val_metrics.items()},
        "memory": memory_stats,
    }