-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathplot_diff_lib3.py
More file actions
58 lines (45 loc) · 2.04 KB
/
plot_diff_lib3.py
File metadata and controls
58 lines (45 loc) · 2.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from scipy import signal
import matplotlib.pyplot as plt
import pandas as pd
from os.path import join
from sys import argv
from collections import OrderedDict
b, a = signal.butter(1, 0.15)
agent_name_list = ['HER', 'HER+Skill Set 1','HER+Skill Set 2','HER+Skill Set 3']
linestyles = OrderedDict(
[('solid', (0, ())),
('loosely dotted', (0, (1, 10))),
('dotted', (0, (1, 5))),
('densely dotted', (0, (1, 1))),
('loosely dashed', (0, (5, 10))),
('dashed', (0, (5, 5))),
('densely dashed', (0, (5, 1))),
('loosely dashdotted', (0, (3, 10, 1, 10))),
('dashdotted', (0, (3, 5, 1, 5))),
('densely dashdotted', (0, (3, 1, 1, 1))),
('loosely dashdotdotted', (0, (3, 10, 1, 10, 1, 10))),
('dashdotdotted', (0, (3, 5, 1, 5, 1, 5))),
('densely dashdotdotted', (0, (3, 1, 1, 1, 1, 1)))])
# putinb
env_name = "putainb-v0" #argv[1]
env_name = env_name.split('-')
dir_list = ['/Users/virtualworld/new_RL3/corl_paper_results/clusters-v1/%s-%s/run1'%(env_name[0], env_name[1]),
'/Users/virtualworld/new_RL3/corl_paper_results/clusters-v1/%sflat-%s/run1'%(env_name[0], env_name[1]),
'/Users/virtualworld/new_RL3/corl_paper_results/clusters-v1/%sflat-%s/run16'%(env_name[0], env_name[1]),
'/Users/virtualworld/new_RL3/corl_paper_results/clusters-v1/%sflat-%s/run41'%(env_name[0], env_name[1])]
for i, dirname in enumerate(dir_list):
if i == 0:
currlinestyle = 'densely dotted'
elif i ==2:
currlinestyle = 'dashdotted'
elif i ==3:
currlinestyle = 'densely dashdotdotted'
else:
currlinestyle = "densely dashed"
data = pd.read_csv(join(dirname , "progress.csv")).fillna(0.0)
filtered_data = signal.filtfilt(b,a, data["eval/success"])
plt.plot(data["total/epochs"], filtered_data, linewidth = 3.0, linestyle = linestyles[currlinestyle])
plt.legend(agent_name_list)
plt.xlabel('Number of epochs')
plt.ylabel('Success %')
plt.show()