-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathtrain.py
More file actions
45 lines (36 loc) · 1.42 KB
/
train.py
File metadata and controls
45 lines (36 loc) · 1.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import time
from data.data_loader import get_data_loader
from models.models import create_model
from option_parser import TrainingOptionParser
from utils.visualizer import Visualizer
# Training entry point: parse options, build the data loader, model and
# visualizer, then run `opt.epoch` epochs of optimization with periodic
# console logging, result display/plotting and a checkpoint after each epoch.
parser = TrainingOptionParser()
opt = parser.parse_args()

data_loader = get_data_loader(opt)
print("[INFO] batch size : {}".format(opt.batch_size))
print("[INFO] training batches : {}".format(len(data_loader)))

model = create_model(opt)
visualizer = Visualizer(opt)

# Both counters advance by opt.batch_size per batch, so they count samples
# (assuming every batch is full), not iterations. print_freq / plot_freq are
# therefore sample thresholds, and they fire only on batches where the running
# count is exactly divisible by them.
total_steps = 0  # samples seen across all epochs
epoch_count = 0  # epochs completed so far; used to tag checkpoints

for epoch in range(opt.epoch):
    iter_count = 0  # samples seen in the current epoch

    for data in data_loader:
        batch_start_time = time.time()
        total_steps += opt.batch_size
        iter_count += opt.batch_size

        # data : list -- only its first element is fed to the model.
        model.set_input(data[0])
        model.optimize_parameters()
        batch_end_time = time.time()

        if iter_count % opt.print_freq == 0:
            errors = model.get_losses()
            visualizer.print_current_errors(
                epoch, iter_count, errors, (batch_end_time - batch_start_time))

        if total_steps % opt.plot_freq == 0:
            # The old "save" test (total_steps % plot_freq == 0) is exactly the
            # condition of this branch, so results are always saved when shown.
            save_result = True
            visualizer.display_current_results(
                model.get_visuals(), total_steps // opt.plot_freq, save_result)
            if opt.display_id > 0:
                # Query the losses here rather than reusing `errors` from the
                # print branch: that name is undefined (NameError) until
                # print_freq has fired once, and stale if it fired earlier.
                visualizer.plot_current_errors(
                    epoch, total_steps, model.get_losses())

    # NOTE(review): presumably removes the previous epoch's checkpoint before
    # saving the new one under the incremented index -- confirm against the
    # model implementation.
    model.remove(epoch_count)
    epoch_count += 1
    model.save(epoch_count)