-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
123 lines (107 loc) · 5.46 KB
/
train.py
File metadata and controls
123 lines (107 loc) · 5.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""
1.构建模型
2.构建数据集
3.构建优化器
4.构建logger和checkpoints saver
"""
import os
import shutil
import time
import torch
import argparse
from torch.utils.data import DataLoader
from utils import read_yml, setup_seed
from road_extractor import build_road_extractor
from dataset import build_dataloader
from schedule import build_schedule
from utils.logger import build_logger
def parse_args():
    """Build the CLI parser for road-extractor training and parse sys.argv."""
    parser = argparse.ArgumentParser(description='Train a road extractor')
    parser.add_argument('--config', default='configs/LRDNet_RNBD.yml',
                        help='train config file path')
    parser.add_argument('--work-dir', default=None,
                        help='the dir to save logs and models')
    # integer-valued knobs share type=int; registered in a single pass
    for flag, default in (("--log_interval", 50), ("--save_every", 1),
                          ("--device", 0), ("--seed", 0)):
        parser.add_argument(flag, type=int, default=default)
    parser.add_argument("--resume_from", type=str, default=None)
    return parser.parse_args()
def print_log(logger, epoch, epochs, iters, total_iters, lr, loss_sum, loss_item, t_begin, interval):
    """Log training progress (including the edge loss) and return a fresh timestamp.

    Args:
        logger: logger with an ``info(msg)`` method.
        epoch, epochs: current epoch index and total epoch count.
        iters, total_iters: global iteration index and total iteration count.
        lr: current learning rate.
        loss_sum: scalar total loss for the formatted message.
        loss_item: mapping with keys 'loss_edge', 'loss_surf', 'loss_cent'.
        t_begin: wall-clock time when the current logging window started.
        interval: number of iterations covered by this logging window.

    Returns:
        ``time.time()`` — the start of the next logging window.
    """
    # ETA (hours) = mean seconds/iter over the last `interval` iters * remaining iters
    eta = ((time.time() - t_begin) / interval) * (total_iters - iters) / 3600
    logger.info('epoch:[{}/{}] iter:[{}/{}] eta:{:.3f}h lr:{:.5f} loss_sum:{:.4f} loss_edge:{:.4f} loss_surf:{:.4f} '
                'loss_cent:{:.4f} '.format(epoch, epochs, iters, total_iters, eta, lr, loss_sum,
                                           loss_item['loss_edge'], loss_item['loss_surf'], loss_item['loss_cent']))
    return time.time()
def print_log_no_edge(logger, epoch, epochs, iters, total_iters, lr, loss_sum, loss_item, t_begin, interval):
    """Log training progress for models without an edge head; return a fresh timestamp.

    Same contract as ``print_log`` except ``loss_item`` only needs the
    'loss_surf' and 'loss_cent' keys.

    Returns:
        ``time.time()`` — the start of the next logging window.
    """
    # ETA (hours) = mean seconds/iter over the last `interval` iters * remaining iters
    eta = ((time.time() - t_begin) / interval) * (total_iters - iters) / 3600
    logger.info('epoch:[{}/{}] iter:[{}/{}] eta:{:.3f}h lr:{:.5f} loss_sum:{:.4f} loss_surf:{:.4f} '
                'loss_cent:{:.4f} '.format(epoch, epochs, iters, total_iters, eta, lr, loss_sum,
                                           loss_item['loss_surf'], loss_item['loss_cent']))
    return time.time()
def train(args, model, dataloader, optimizer, schedule, epochs, work_dir, iters=0, start_epoch=0):
    """Run the training loop and periodically save resumable checkpoints.

    Args:
        args: parsed CLI args; uses ``log_interval`` and ``save_every``.
        model: model exposing ``forward_train(meta) -> (loss_sum, loss_item)``.
        dataloader: iterable of batches; each batch is a sequence of tensors.
        optimizer: torch optimizer.
        schedule: LR scheduler exposing ``get_lr()`` and ``step()`` (per epoch).
        epochs: total number of epochs to train.
        work_dir: directory receiving 'train.log' and checkpoint files.
        iters: global iteration counter to resume from (default 0).
        start_epoch: epoch index to resume from (default 0).
    """
    total_iters = len(dataloader) * epochs
    model.cuda()
    logger = build_logger(work_dir + '/train.log', print_log=True)
    for epoch in range(start_epoch, epochs):
        model.train()
        t_begin = time.time()
        for i, meta in enumerate(dataloader):
            optimizer.zero_grad()
            meta = [data.cuda() for data in meta]
            loss_sum, loss_item = model.forward_train(meta)
            loss_sum.backward()
            optimizer.step()
            iters += 1
            if iters % args.log_interval == 0:
                # print_log_no_edge returns time.time(), resetting the ETA window
                t_begin = print_log_no_edge(logger, epoch, epochs, iters, total_iters, schedule.get_lr()[0],
                                            loss_sum, loss_item, t_begin, args.log_interval)
        schedule.step()
        # Save a resumable checkpoint.
        # BUG FIX: the original tested `epochs % args.save_every`, which is
        # constant for the whole run (all-or-nothing saving); the intended
        # condition is on the current epoch index.
        if epoch % args.save_every == 0:
            state = {
                'iters': iters,
                'epoch': epoch,
                'work_dir': work_dir,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, work_dir + '/model{}_resume.pth'.format(epoch))
def main():
    """Entry point: build model/data/schedule from the YAML config and train."""
    args = parse_args()
    cfg = read_yml(args.config)
    setup_seed(args.seed)  # fix random seeds for reproducibility
    # NOTE(review): the default dir name contains ':' — not portable to
    # Windows file systems; confirm target platforms before changing.
    work_dir = args.work_dir if args.work_dir is not None else './work_dir/' + time.strftime("%Y-%m-%d-%H:%M:%S")
    if not os.path.exists(work_dir):
        os.makedirs(work_dir)
    # keep a copy of the training config next to the logs/checkpoints
    shutil.copy(args.config, work_dir)
    # select the GPU for all subsequent .cuda() calls
    torch.cuda.set_device(args.device)
    model = build_road_extractor(cfg_model=cfg['model'])
    dataset, dataloader = build_dataloader(cfg_dataset=cfg['dataset'])
    optimizer, schedule = build_schedule(cfg['schedule'], model)
    if args.resume_from is not None:
        ckpt = torch.load(args.resume_from)
        iters = ckpt['iters']
        start_epoch = ckpt['epoch'] + 1
        work_dir = ckpt['work_dir']  # continue writing into the original run dir
        model.load_state_dict(ckpt['state_dict'])
        # rebuild the schedule so the LR curve continues from the right epoch
        optimizer, schedule = build_schedule(cfg['schedule'], model, last_epoch=start_epoch - 1)
        optimizer.load_state_dict(ckpt['optimizer'])
        # load_state_dict restores the optimizer's tensors on CPU; move them to
        # GPU, otherwise optimizer.step() fails on a device mismatch.
        # (FIX: loop variable renamed from `state`, which shadowed the
        # checkpoint dict of the same name.)
        for opt_state in optimizer.state.values():
            for k, v in opt_state.items():
                if torch.is_tensor(v):
                    opt_state[k] = v.cuda()
        train(args, model, dataloader, optimizer, schedule, cfg['epochs'], work_dir,
              iters=iters, start_epoch=start_epoch)
    else:
        train(args, model, dataloader, optimizer, schedule, cfg['epochs'], work_dir)


if __name__ == '__main__':
    main()