-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstatistics.py
More file actions
executable file
·49 lines (36 loc) · 1.44 KB
/
statistics.py
File metadata and controls
executable file
·49 lines (36 loc) · 1.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import numpy as np
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
class Statistics():
    """
    Save statistics during training/testing and draw curves.

    Four parallel lists are kept, one entry per recorded episode: episode
    numbers, total rewards, exploration probabilities and losses. Whenever
    the recorded episode number is a positive multiple of ``SAVE_INTERVAL``
    the three curves are written as PNG files under ``path``.
    """

    # Plots are refreshed when the episode number is a positive multiple
    # of this value.
    SAVE_INTERVAL = 5

    def __init__(self, path="plots/"):
        """
        Create an empty statistics recorder.

        Args:
            path: Directory the PNG files are written to. It is created on
                first save if it does not exist. Defaults to ``"plots/"``.
        """
        self.path = path
        self.episodes = []
        self.rewards = []
        self.explore_probabilities = []
        self.losses = []

    def add_episode_stats(self, episode, total_reward, explore_probability, loss):
        """
        Record the statistics of one episode and periodically save the plots.

        Args:
            episode: Episode number, used as the x value of every curve.
            total_reward: Reward value appended to ``rewards``.
            explore_probability: Value appended to ``explore_probabilities``.
            loss: Value appended to ``losses``.
        """
        self.episodes.append(episode)
        self.rewards.append(total_reward)
        self.explore_probabilities.append(explore_probability)
        self.losses.append(loss)

        # Episode 0 is skipped on purpose: only positive multiples trigger
        # a save.
        if episode > 0 and episode % self.SAVE_INTERVAL == 0:
            self.save_plots()

    def _plot_file(self, ylabel):
        """Ensure the output directory exists and return the PNG path for *ylabel*."""
        import os  # local import: keeps this class independent of file-level imports

        # An empty path means "current directory"; os.makedirs("") would raise.
        if self.path:
            os.makedirs(self.path, exist_ok=True)
        return os.path.join(self.path, ylabel + str(self.episodes[-1]) + '.png')

    def save_plot(self, title, xlabel, ylabel, values):
        """
        Plot *values* against the recorded episodes and save the figure.

        The file is named ``<ylabel><last episode>.png`` inside ``path``.
        Nothing is drawn or written when no episode has been recorded yet.

        Args:
            title: Figure title.
            xlabel: Label of the x axis.
            ylabel: Label of the y axis; also the file-name prefix.
            values: Sequence of y values, one per recorded episode.
        """
        if not self.episodes:
            # No data yet: the file name depends on the last episode number,
            # so there is nothing meaningful to save.
            return

        target = self._plot_file(ylabel)

        # Use an explicit figure instead of pyplot's implicit current figure
        # so unrelated plotting state cannot leak into the saved image.
        fig, ax = plt.subplots()
        try:
            ax.set_title(title)
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
            ax.plot(np.array(self.episodes), np.array(values))
            fig.savefig(target)
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)

    def save_plots(self):
        """Save the reward, loss and exploration-probability curves."""
        self.save_plot("Evolution of reward over episodes",
                       "Episodes", "Rewards", self.rewards)
        self.save_plot("Evolution of loss over episodes",
                       "Episodes", "Loss", self.losses)
        self.save_plot("Exploration probability over episodes",
                       "Episodes", "exploration_probability",
                       self.explore_probabilities)