-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmetrics.py
More file actions
47 lines (40 loc) · 1.23 KB
/
metrics.py
File metadata and controls
47 lines (40 loc) · 1.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import collections
import re

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
# Command-line flags for this script; FLAGS is the flag namespace main() reads.
flags = tf.app.flags
FLAGS = flags.FLAGS

# --log: path of the training log that main() parses and plots (default None).
flags.DEFINE_string('log', None, 'Log file')
def main(_):
    """Parse losses from a training log and plot the mean loss per epoch.

    Each relevant log line is expected to contain a fragment of the form
    ``<phase>:<sep>epoch<sep><N><sep>loss<sep><value>``, where every
    ``<sep>`` is a single non-word character (e.g. ``train: epoch 3 loss
    0.25``). Lines without such a fragment are skipped. Losses are grouped by
    phase and epoch, averaged per epoch, and drawn as one line per phase
    against the epoch number.

    Args:
      _: argv forwarded by ``tf.app.run``; unused.

    Raises:
      ValueError: if ``--log`` was not supplied.
    """
    if not FLAGS.log:
        raise ValueError('--log must point to a training log file')

    # Compiled once rather than per line. Epoch requires at least one digit so
    # int() never sees an empty string; the loss accepts an optional sign and
    # exponent so values such as 1e-05 are not silently truncated to 1.
    pattern = re.compile(
        r'([a-z_]*):\Wepoch\W(\d+)\Wloss\W(-?\d*\.?\d+(?:[eE][-+]?\d+)?)')

    # phase -> {epoch -> [loss, ...]}. The three known phases are seeded so
    # they always appear, in a fixed legend order; any other phase found in
    # the log is appended on first sight instead of raising KeyError.
    metrics = collections.OrderedDict(
        (phase, {}) for phase in ('train', 'train_valid', 'valid'))

    with open(FLAGS.log, 'r') as log_file:
        for line in log_file:
            match = pattern.search(line)
            if match is None:
                # Not a loss line (blank, header, other metric, ...).
                continue
            phase = match.group(1)
            epoch = int(match.group(2))
            loss = float(match.group(3))
            metrics.setdefault(phase, {}).setdefault(epoch, []).append(loss)

    plt.title(FLAGS.log)
    for phase, losses_by_epoch in metrics.items():
        epochs = sorted(losses_by_epoch)
        mean_losses = [np.mean(losses_by_epoch[epoch]) for epoch in epochs]
        # Plot against the real epoch numbers, not list positions, so phases
        # logged at different epochs line up on a shared x-axis.
        plt.plot(epochs, mean_losses, label=phase)
    plt.xlabel('epoch')
    plt.ylabel('mean loss')
    plt.legend()
    plt.show()
if __name__ == '__main__':
    # Pass main() explicitly; it is the same function tf.app.run() would
    # otherwise look up on the __main__ module by default.
    tf.app.run(main=main)