-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
298 lines (227 loc) · 11.6 KB
/
train.py
File metadata and controls
298 lines (227 loc) · 11.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
import torch
from torch.utils.tensorboard import SummaryWriter
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from resources.consts import (
CHECKPOINT_PATH,
TRAINING_STATE_SAVE_PATH_FORMAT,
VAL_SAMPLE_SAVE_PATH,
MODEL_SAVE_PATH_FORMAT
)
import utils.training_utils as training_utils
from datasets import get_dataset_distributed
from models import define_model
from collections import defaultdict, OrderedDict
import shutil
import click
import glob
import yaml
import logging
import os
import os.path as osp
class Trainer:
    """Distributed (DDP) training driver configured from a YAML file.

    One instance runs per process. Rank 0 additionally owns the experiment
    folder, the TensorBoard writer, the logger, validation and checkpointing;
    the other ranks only train.
    """

    def __init__(self, world_size, rank, cfg_file):
        """Load the config, build data/model/optimizer and restore any saved state.

        Args:
            world_size: Total number of training processes (forwarded to the
                distributed dataset builder).
            rank: Rank of this process. Rank 0 performs logging, validation
                and checkpointing.
            cfg_file: Path to the YAML experiment config.
        """
        with open(cfg_file, "r") as f:
            opt = yaml.safe_load(f)
        self.opt = opt
        if opt["exp_name"] is None:
            self.exp_name = opt["dataset"]["dataset_name"] + "_" + opt["model"]["name"]
        else:
            self.exp_name = opt["exp_name"]
        self.proj_dir = opt["proj_dir"]
        if self.opt["train"]["load_iter"] == "auto":
            # Search the same folder initialize_training_folders() creates
            # (NOTE(review): assumes TRAINING_STATE_SAVE_PATH_FORMAT points into it).
            state_folder = osp.join(self.proj_dir, CHECKPOINT_PATH, self.exp_name, "training_states")
            self.load_iter = self._latest_state_iter(state_folder)
        else:
            self.load_iter = self.opt["train"]["load_iter"]
        self.total_iters = opt["train"]["total_iters"]
        self.checkpoint_interval = opt["train"]["checkpoint_interval"]
        self.print_interval = opt["train"]["print_interval"]
        self.val_interval = opt["train"]["val_interval"]
        self.rank = rank
        if rank == 0:
            self.initialize_training_folders(self.load_iter == -1)
        self.train_batch_size = opt["datasets"]["train"]["batch_size"]
        self.val_batch_size = opt["datasets"]["val"]["batch_size"]
        self.train_dataloader, self.val_dataloader = get_dataset_distributed(
            world_size, rank, opt
        )
        self.model = define_model(opt["model"])
        self.model = DDP(self.model, device_ids=[torch.cuda.current_device()], find_unused_parameters=True)
        self.bare_model = self.model.module if hasattr(self.model, "module") else self.model
        self.optimizer = training_utils.get_optimizer(self.bare_model.get_net_parameters(), opt["train"]["optimizer"])
        self.warmup = opt["train"]["optimizer"]["warmup"]
        self.l_factor = self.opt["train"]["optimizer"]["l_factor"]
        self.cycle_every_n_epoch = self.opt["train"]["optimizer"]["cycle_every"]
        self.resume_training()

    @staticmethod
    def _latest_state_iter(state_folder):
        """Return the highest iteration with a ``<iter>.state`` file in *state_folder*, or -1.

        Iterations are compared numerically, so ``10.state`` wins over
        ``9.state`` (a lexicographic sort of the file names would not).
        Files whose stem is not an integer are ignored.
        """
        iters = []
        for path in glob.glob(osp.join(state_folder, "*.state")):
            stem = osp.splitext(osp.basename(path))[0]
            try:
                iters.append(int(stem))
            except ValueError:
                continue
        return max(iters, default=-1)

    def initialize_training_folders(self, from_scratch):
        """Create the experiment folders, TensorBoard writer and logger (called on rank 0).

        When starting from scratch, an existing experiment folder is renamed
        to ``<exp_name>_archived_<timestamp>`` if it holds any ``.state``
        files, and deleted otherwise.
        """
        exp_path = osp.join(self.proj_dir, CHECKPOINT_PATH, self.exp_name)
        model_folder = osp.join(exp_path, "models")
        training_state_folder = osp.join(exp_path, "training_states")
        if from_scratch:
            if osp.isdir(exp_path):
                timestamp = training_utils.get_timestamp()
                if glob.glob(osp.join(training_state_folder, "*.state")):
                    os.rename(exp_path, osp.join(osp.dirname(exp_path), self.exp_name + "_archived_" + timestamp))
                else:
                    shutil.rmtree(exp_path)
            os.makedirs(exp_path)
        os.makedirs(model_folder, exist_ok=True)
        os.makedirs(training_state_folder, exist_ok=True)
        self.writer = SummaryWriter(log_dir=exp_path)
        training_utils.setup_logger("base", exp_path, screen=True, tofile=True)
        self.logger = logging.getLogger("base")

    def _update_learning_rate(self):
        """Set every param group's LR from ``self.cyclic_steps``.

        The LR rises linearly for ``warmup`` steps up to ``l_factor / warmup``
        and afterwards decays as ``l_factor / cyclic_steps``.
        """
        lr = self.l_factor * min(1.0, self.cyclic_steps / self.warmup) / max(self.cyclic_steps, self.warmup)
        for g in self.optimizer.param_groups:
            g["lr"] = lr

    def _log_epoch_summary(self, epoch, epoch_logs):
        """Write per-epoch mean scalars and gradient histograms (rank 0 only).

        ``epoch_logs`` holds per-batch sums, which are divided in place by the
        number of batches before being logged.
        """
        num_batches = len(self.train_dataloader)
        for k in epoch_logs:
            epoch_logs[k] /= num_batches
            self.writer.add_scalar(k, epoch_logs[k], epoch)
        # Gradients here are those of the last batch of the epoch.
        # NOTE(review): this unpacks (name, tensor) pairs, while __init__ passes the
        # same call to get_optimizer -- confirm what get_net_parameters() yields.
        for tag, value in self.bare_model.get_net_parameters():
            if value.grad is not None:
                grad = value.grad.cpu()
                self.logger.info(f"grad: {grad}")
                self.writer.add_histogram("grad/" + tag, grad, epoch)

    def training_loop(self):
        """Train epochs ``self.current_iter .. self.total_iters - 1``, then average weights.

        Rank 0 logs per-step and per-epoch TensorBoard scalars, validates every
        ``val_interval`` epochs and checkpoints every ``checkpoint_interval``
        epochs (epoch 0 is skipped for both).
        """
        timer = training_utils.AvgTimer()
        data_timer = training_utils.AvgTimer()
        if self.rank == 0:
            self.logger.info(f"Number of params: {self.bare_model.count_parameters()}")
        self.validation()
        self.bare_model.train()
        for epoch in range(self.current_iter, self.total_iters):
            # DistributedSampler only reshuffles when told the epoch.
            sampler = getattr(self.train_dataloader, "sampler", None)
            if hasattr(sampler, "set_epoch"):
                sampler.set_epoch(epoch)
            epoch_logs = defaultdict(float)
            data_timer.start()
            for datapoint in self.train_dataloader:
                # Incremented before use, so a saved value is the last step taken.
                self.cyclic_steps += 1
                self._update_learning_rate()
                # record dataloading time
                data_timer.record()
                self.optimizer.zero_grad()
                timer.start()
                # Go through the DDP wrapper: calling self.bare_model directly
                # skips DDP.forward, so gradients would not be all-reduced.
                loss, tb_logs = self.model(datapoint)
                loss.backward()
                self.optimizer.step()
                timer.record()
                if self.rank == 0:
                    for k, v in tb_logs.items():
                        self.writer.add_scalar(k, v, self.global_step)
                        epoch_logs[f"{k}_epoch"] += v
                    if self.global_step % self.print_interval == 0:
                        avg_time = timer.get_avg_time()
                        data_avg_time = data_timer.get_avg_time()
                        assert len(self.optimizer.param_groups) == 1
                        curr_lr = self.optimizer.param_groups[0]["lr"]
                        self.logger.info(
                            f"Epoch {epoch}; Step {self.global_step}; Average training time (Net - Data):"
                            f"{avg_time:.2f} - {data_avg_time:.4f} | LR: {curr_lr:.8f}",
                        )
                self.global_step += 1
                data_timer.start()
            if self.rank == 0:
                self._log_epoch_summary(epoch, epoch_logs)
            if self.rank == 0 and epoch > 0 and epoch % self.val_interval == 0:
                # NOTE(review): other ranks wait in the next gradient all-reduce while
                # rank 0 validates; a very long validation may hit the process-group timeout.
                self.validation()
                self.bare_model.train()
            if self.rank == 0:
                for i, g in enumerate(self.optimizer.param_groups):
                    self.writer.add_scalar(f'param_group{i}/lr', g["lr"], epoch)
            # Restart the LR cycle at its peak: the next step uses cyclic_steps == warmup.
            if epoch % self.cycle_every_n_epoch == 0 and epoch > 0:
                self.cyclic_steps = self.warmup - 1
            if self.rank == 0 and epoch > 0 and epoch % self.checkpoint_interval == 0:
                self.logger.info(f"Saving checkpoints {self.current_iter}")
                self.save_training()
            self.current_iter += 1
        self.average_weights()
        if self.rank == 0:
            self.writer.close()

    def save_training(self):
        """Save model weights and the resumable training state for ``self.current_iter``.

        Model files are written before the ``.state`` file, so auto-resume
        never finds a state file whose model files were not yet written.
        """
        self.logger.info(f"Saving iter {self.current_iter}...")
        checkpoint = self.bare_model.get_checkpoint()
        pretrained_paths = {
            name: MODEL_SAVE_PATH_FORMAT.format(self.proj_dir, self.exp_name, self.current_iter, name)
            for name in checkpoint.keys()
        }
        for name, ckpt_state_dict in checkpoint.items():
            torch.save(ckpt_state_dict, pretrained_paths[name])
        state = {
            "pretrained_paths": pretrained_paths,
            "optimizer": self.optimizer.state_dict(),
            "current_iter": self.current_iter,
            "cyclic_steps": self.cyclic_steps,
            "global_step": self.global_step,
        }
        training_state_save_path = TRAINING_STATE_SAVE_PATH_FORMAT.format(self.proj_dir, self.exp_name, self.current_iter)
        torch.save(state, training_state_save_path)

    def resume_training(self):
        """Initialise weights and counters for a fresh run or from a saved training state.

        Raises:
            ValueError: if the training state for ``self.load_iter`` or any
                model file it references is missing.
        """
        if self.load_iter == -1:
            pretrained_paths = self.opt["train"]["pretrained_paths"]
            if self.rank == 0:
                self.logger.info("Training model from scratch...")
                self.logger.info(f"Loading pretrained models from {pretrained_paths.items()}")
            self.bare_model.load_network(pretrained_paths)
            self.current_iter = 0
            self.cyclic_steps = 0
            self.global_step = 0
            return
        if self.rank == 0:
            self.logger.info(f"Resuming training from iter {self.load_iter + 1}")
        state_path = TRAINING_STATE_SAVE_PATH_FORMAT.format(self.proj_dir, self.exp_name, self.load_iter)
        if not osp.isfile(state_path):
            raise ValueError(f"Training state for iter {self.load_iter} not found")
        state = torch.load(state_path, map_location="cpu")
        self.cyclic_steps = state["cyclic_steps"]  # we increment cyclic_step by 1 when batch begins
        self.current_iter = state["current_iter"] + 1
        # global_step is incremented at the end of each batch, so the saved
        # value is already the next step to log (adding 1 skipped a step).
        self.global_step = state["global_step"]
        pretrained_paths = state["pretrained_paths"]
        for path in pretrained_paths.values():
            if not osp.isfile(path):
                raise ValueError(f"{path} does not exist")
        self.optimizer.load_state_dict(state["optimizer"])
        self.bare_model.load_network(pretrained_paths)

    @staticmethod
    def _average_state_dicts(state_dicts):
        """Return the element-wise mean of same-keyed state dicts.

        Floating-point tensors are stacked along a new leading dim and
        averaged, so every entry keeps its original shape. Other entries
        (e.g. integer buffers, or non-tensors) cannot be meaningfully averaged
        and take the value from the last state dict. The first dict is updated
        in place and returned.
        """
        avg_state_dict = state_dicts[0]
        for k, v in avg_state_dict.items():
            if isinstance(v, torch.Tensor) and v.is_floating_point():
                avg_state_dict[k] = torch.stack([sd[k] for sd in state_dicts]).mean(dim=0)
            else:
                avg_state_dict[k] = state_dicts[-1][k]
        return avg_state_dict

    def average_weights(self):
        """Average the "net" checkpoints listed in ``opt["model"]["weights_to_average"]``.

        Runs on rank 0 only, to avoid every rank writing the same file. The
        result is saved next to the last listed checkpoint as
        ``avg_<iters concatenated>_net.pth``.
        """
        if self.rank != 0:
            return
        weights_to_average = self.opt["model"].get("weights_to_average") or []
        if not weights_to_average:
            return
        model_paths = [
            MODEL_SAVE_PATH_FORMAT.format(self.proj_dir, self.exp_name, i, "net") for i in weights_to_average
        ]
        state_dicts = [torch.load(path, map_location="cpu") for path in model_paths]
        avg_state_dict = self._average_state_dicts(state_dicts)
        # NOTE(review): plain concatenation is ambiguous ([1, 23] and [12, 3] both
        # give "123"); kept unchanged for filename compatibility.
        avg_s = "".join(str(i) for i in weights_to_average)
        # osp.dirname keeps absolute paths absolute.
        out_path = osp.join(osp.dirname(model_paths[-1]), f"avg_{avg_s}_net.pth")
        torch.save(avg_state_dict, out_path)

    def validation(self):
        """Run the model's validation on rank 0 and log the returned scalars.

        Samples are saved under ``VAL_SAMPLE_SAVE_PATH`` for the current iter.
        Leaves the model in eval mode; callers switch it back to train mode.
        """
        if self.rank != 0:
            return
        if self.val_dataloader is None:
            self.logger.warning("No validation dataloader was given. Skipping validation...")
            return
        timer = training_utils.AvgTimer()
        val_save_root = VAL_SAMPLE_SAVE_PATH.format(self.proj_dir, self.exp_name, self.current_iter)
        os.makedirs(val_save_root, exist_ok=True)
        self.logger.info("Evaluating metrics on validation set")
        timer.start()
        self.bare_model.eval()
        tb_logs = self.bare_model.validate(self.val_dataloader, save_root=val_save_root)
        for k, v in tb_logs.items():
            self.writer.add_scalar(k, v, self.current_iter)
        timer.record()
        self.logger.info(f"Evaluation time: {timer.get_avg_time()}")
@click.command()
@click.option("--cfg_file", required=True, type=str, help="Config file path")
def main(cfg_file):
    """Run one DDP training process; expects to be launched by torchrun (reads LOCAL_RANK)."""
    local_rank = int(os.environ["LOCAL_RANK"])
    # Bind this process to its GPU before the NCCL group is created.
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")
    # Use the process group's size and global rank, not this node's GPU count
    # and local rank: they differ on multi-node runs or with fewer processes
    # than GPUs. The device itself stays tied to local_rank.
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    print("World_size:", world_size)
    print("Rank", rank)
    print(30 * "-")
    try:
        Trainer(world_size, rank, cfg_file).training_loop()
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()