-
Notifications
You must be signed in to change notification settings - Fork 36
Expand file tree
/
Copy pathadvanced_example_experiment.py
More file actions
193 lines (155 loc) · 6.56 KB
/
advanced_example_experiment.py
File metadata and controls
193 lines (155 loc) · 6.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
"""
This is an advanced experiment example, which makes use of sacred's captured functions with prefixes.
We wrap all the experiment-specific functionality inside the "ExperimentWrapper" class, and define methods with sacred's
@ex.capture decorator. This allows a modular design of the configuration, where certain sub-dictionaries (e.g., "data")
are parsed by a specific method. This avoids having one large "main" function which takes all parameters as input.
"""
import numpy as np
from seml import Experiment
# Module-level experiment object. The decorators used throughout this file (named_config, config, capture,
# command, reschedule_hook, automain) all register their functions on this instance.
ex = Experiment()
# Named configs can be used to define subconfigurations in a modular way. They can be composed in the experiment's configuration yaml file.
@ex.named_config
def preprocessing_none():
    """
    Named configuration that sets the "preprocessing" sub-dictionary to mean 0.0 and std 1.0.

    It can be enabled in the configuration yaml file. The local variable below is not dead code:
    sacred collects the locals of a config function as configuration entries.
    """
    preprocessing = {"mean": 0.0, "std": 1.0}
@ex.named_config
def preprocessing_normalize():
    """
    Named configuration that sets the "preprocessing" sub-dictionary to mean 0.377 and std 0.23.

    It can be enabled in the configuration yaml file. The local variable below is not dead code:
    sacred collects the locals of a config function as configuration entries.
    """
    preprocessing = {"mean": 0.377, "std": 0.23}
@ex.named_config
def batchnorm():
    """
    Named configuration that switches batch normalization on via the "model" sub-dictionary.

    It can be enabled in the configuration yaml file. The local variable below is not dead code:
    sacred collects the locals of a config function as configuration entries.
    """
    model = {
        "batchnorm": True,
    }
@ex.named_config
def no_batchnorm():
    """
    Named configuration that switches off both batch normalization and residual connections
    via the "model" sub-dictionary.

    It can be enabled in the configuration yaml file. The local variable below is not dead code:
    sacred collects the locals of a config function as configuration entries.
    """
    model = {
        "batchnorm": False,
        "residual": False,
    }
@ex.config
def config():
    """
    Base configuration of the experiment; defines the single entry "name".

    The value is a template referring to the "model.model_type" and "data.dataset" entries.
    NOTE(review): the "${config...}" placeholders are presumably resolved by seml's config
    interpolation rather than by Python - confirm against the seml documentation.
    """
    name = "${config.model.model_type}_${config.data.dataset}"
class ModelVariant1:
    """
    Placeholder for model variant 1; in a real project this could be a specific model or baseline.

    The constructor only records the hyperparameters it receives as attributes of the same name.
    """

    def __init__(self, hidden_sizes, dropout, batchnorm, residual):
        # No network is built here; the arguments are stored unchanged.
        self.hidden_sizes, self.dropout = hidden_sizes, dropout
        self.batchnorm, self.residual = batchnorm, residual
class ModelVariant2:
    """
    Placeholder for model variant 2; in a real project this could be a specific model or baseline.

    The constructor only records the hyperparameters it receives as attributes of the same name.
    """

    def __init__(self, hidden_sizes, dropout, batchnorm, residual):
        # No network is built here; the arguments are stored unchanged.
        self.hidden_sizes, self.dropout = hidden_sizes, dropout
        self.batchnorm, self.residual = batchnorm, residual
class ExperimentWrapper:
    """
    A simple wrapper around a sacred experiment, making use of sacred's captured functions with prefixes.
    This allows a modular design of the configuration, where certain sub-dictionaries (e.g., "data") are parsed by
    a specific method. This avoids having one large "main" function which takes all parameters as input.
    """

    def __init__(self, init_all=True):
        """
        Create the wrapper and, unless `init_all` is False, run every sub-initializer via `init_all()`.

        With `init_all=False` no attribute (data, model, optimizer, ...) is set by the constructor.
        """
        if init_all:
            self.init_all()

    # With the prefix option we can "filter" the configuration for the sub-dictionary under "data".
    @ex.capture(prefix="data")
    def init_dataset(self, dataset):
        """
        Perform dataset loading, preprocessing etc.
        Since we set prefix="data", this method only gets passed the respective sub-dictionary, enabling a modular
        experiment design.

        The result is stored in `self.data`; unrecognized dataset names fall back to a placeholder value.
        """
        if dataset == "large_dataset_1":
            self.data = "load_dataset_here"
        elif dataset == "large_dataset_2":
            self.data = "and so on"
        # ...
        else:
            self.data = "..."

    @ex.capture(prefix="model")
    def init_model(
        self,
        model_type: str,
        model_params: dict,
        batchnorm: bool,
        residual: bool = True,
    ):
        """
        Instantiate the model selected by `model_type` and store it in `self.model`.

        All arguments are filled in from the "model" sub-dictionary of the configuration.

        Args:
            model_type: Either "variant_1" or "variant_2".
            model_params: Keyword arguments forwarded unchanged to the model constructor.
            batchnorm: Forwarded to the model constructor.
            residual: Forwarded to the model constructor; defaults to True if not configured.

        Raises:
            ValueError: If `model_type` is not one of the known variants. Without this check the
                method would return silently and leave `self.model` undefined.
        """
        model_classes = {
            "variant_1": ModelVariant1,
            "variant_2": ModelVariant2,
        }
        if model_type not in model_classes:
            raise ValueError(
                f"Unknown model_type {model_type!r}; expected one of {sorted(model_classes)}."
            )
        # Here we can pass the "model_params" dict to the constructor directly, which can be very useful in
        # practice, since we don't have to do any model-specific processing of the config dictionary.
        self.model = model_classes[model_type](
            **model_params, batchnorm=batchnorm, residual=residual
        )

    @ex.capture(prefix="optimization")
    def init_optimizer(self, regularization: dict, optimizer_type: str):
        """
        Set up the optimizer from the "optimization" sub-dictionary of the configuration.

        Args:
            regularization: Dictionary that must contain the key "weight_decay".
            optimizer_type: Identifier of the optimizer; stored as-is in `self.optimizer`.

        Raises:
            KeyError: If `regularization` has no "weight_decay" entry.
        """
        # Keep the value on the instance so that it is available once a real optimizer is constructed.
        self.weight_decay = regularization["weight_decay"]
        self.optimizer = optimizer_type  # initialize optimizer

    @ex.capture(prefix="preprocessing")
    def init_preprocessing(self, mean: float, std: float):
        """Store the `mean` and `std` entries of the "preprocessing" sub-dictionary as a tuple."""
        self.preprocessing_parameters = (mean, std)

    @ex.capture(prefix="augmentation")
    def init_augmentation(self, flip: bool):
        """Store the `flip` entry of the "augmentation" sub-dictionary as a one-element tuple."""
        self.augmentation_parameters = (flip,)

    def init_all(self):
        """
        Sequentially run the sub-initializers of the experiment.
        """
        self.init_dataset()
        self.init_model()
        self.init_optimizer()
        self.init_preprocessing()
        self.init_augmentation()

    @ex.capture(prefix="training")
    def train(self, patience, num_epochs):
        """
        Simulate a training run and return dummy test metrics.

        Both arguments are filled in from the "training" sub-dictionary of the configuration.

        Args:
            patience: Accepted from the configuration but not used by this dummy loop.
            num_epochs: Number of simulated epochs; the reschedule hook is called once per epoch.

        Returns:
            A dictionary with the keys "test_acc" and "test_loss", both holding randomly drawn values.
        """
        # everything is set up
        for epoch in range(num_epochs):
            # simulate training
            # calling reschedule hook
            reschedule_hook(model_weights={}, step=epoch)

        results = {
            "test_acc": 0.5 + 0.3 * np.random.randn(),
            "test_loss": np.random.uniform(0, 10),
            # ...
        }
        return results
# We can call this command, e.g., from a Jupyter notebook with init_all=False to get an "empty" experiment wrapper,
# where we can then for instance load a pretrained model to inspect the performance.
@ex.command(unobserved=True)
def get_experiment(init_all=False):
    """
    Command that builds and returns an ExperimentWrapper.

    By default (`init_all=False`) the wrapper's sub-initializers are not run, so the caller
    receives an "empty" wrapper. Prints "get_experiment" when invoked.
    """
    print("get_experiment")
    return ExperimentWrapper(init_all=init_all)
# This function will be called when the reschedule is triggered.
# It should save the current state of the experiment and return a
# dictionary that may be used to update the configuration upon rescheduling.
# You are responsible for implementing the actual saving of the experiment state
# and for loading it again based on the updated config.
@ex.reschedule_hook
def reschedule_hook(model_weights, step, **kwargs):
    """
    Hook for saving the experiment state when a reschedule is needed.

    A real implementation would save the current state of the experiment here and return any
    necessary configuration updates. This example saves nothing and returns a fixed placeholder path.

    !!! You will need to call this function regularly from within your training loop
    to check if rescheduling is needed. Pass everything you need to store your state to this function.

    Args:
        model_weights: State to be stored; unused in this example.
        step: Current training step; unused in this example.
        **kwargs: Any further state; unused in this example.

    Returns:
        A dictionary with the single key "checkpoint_path".
    """
    config_update = {"checkpoint_path": "path/to/saved/checkpoint"}
    return config_update
# This function will be called by default. Note that we could in principle manually pass an experiment instance,
# e.g., obtained by loading a model from the database or by calling this from a Jupyter notebook.
@ex.automain
def train(experiment=None):
    """
    Default entry point of the experiment.

    Uses the given `experiment` instance if one is passed; otherwise a fully initialized
    ExperimentWrapper is created. Returns whatever the wrapper's `train()` method returns.
    """
    wrapper = ExperimentWrapper() if experiment is None else experiment
    return wrapper.train()