-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathpolicy.py
More file actions
120 lines (90 loc) · 3.85 KB
/
policy.py
File metadata and controls
120 lines (90 loc) · 3.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import numpy as np
from mushroom.policy.td_policy import TDPolicy
from mushroom.utils.parameters import Parameter
class BootPolicy(TDPolicy):
    """Action-selection policy for a bootstrapped ensemble of Q-approximators.

    During training, actions are drawn greedily from a single ensemble head
    (selected via ``set_idx``); during evaluation, the ensemble votes and the
    most frequently chosen greedy action wins (ties broken at random).
    An epsilon-greedy exploration step wraps both modes.
    """

    def __init__(self, n_approximators, epsilon=None):
        """
        Args:
            n_approximators (int): number of heads in the ensemble.
            epsilon (Parameter, optional): exploration coefficient; defaults
                to a constant 0 (fully greedy).
        """
        if epsilon is None:
            epsilon = Parameter(0.)

        super(BootPolicy, self).__init__()

        self._n_approximators = n_approximators
        self._epsilon = epsilon
        self._evaluation = False
        self._idx = None

    def draw_action(self, state):
        """Return the action to play in ``state`` as a 1-element array."""
        # Exploration branch: uniform random action with probability epsilon.
        if np.random.uniform() < self._epsilon(state):
            return np.array([np.random.choice(self._approximator.n_actions)])

        if self._evaluation:
            # Majority vote across all ensemble heads.
            model = self._approximator.model
            if isinstance(model, list):
                predictions = [m.predict(state) for m in model]
            else:
                predictions = self._approximator.predict(state)

            votes, counts = np.unique(np.argmax(predictions, axis=1),
                                      return_counts=True)
            # Break ties among the most-voted actions uniformly at random.
            winner = np.random.choice(
                np.argwhere(counts == np.max(counts)).ravel())

            return np.array([votes[winner]])

        # Training: act greedily w.r.t. the currently selected head.
        q = self._approximator.predict(state, idx=self._idx)
        greedy = np.argwhere(q == np.max(q)).ravel()
        if len(greedy) > 1:
            return np.array([np.random.choice(greedy)])

        return greedy

    def set_epsilon(self, epsilon):
        """Replace the exploration coefficient."""
        self._epsilon = epsilon

    def set_eval(self, eval):
        """Toggle evaluation mode (ensemble vote vs. single head)."""
        self._evaluation = eval

    def set_idx(self, idx):
        """Select which ensemble head drives greedy action selection."""
        self._idx = idx

    def update_epsilon(self, state):
        """Advance the epsilon schedule by evaluating it on ``state``."""
        self._epsilon(state)
class WeightedPolicy(TDPolicy):
    """Weighted (Thompson-style) policy over an ensemble of Q-approximators.

    During training, each action's value is sampled from a randomly chosen
    ensemble head after per-action standardization across heads, and the
    argmax of the sampled values is played. During evaluation, the ensemble
    votes and the most frequent greedy action wins (ties broken at random).
    An epsilon-greedy exploration step wraps both modes.
    """

    def __init__(self, n_approximators, epsilon=None):
        """
        Args:
            n_approximators (int): number of heads in the ensemble.
            epsilon (Parameter, optional): exploration coefficient; defaults
                to a constant 0 (fully greedy).
        """
        if epsilon is None:
            epsilon = Parameter(0.)

        super(WeightedPolicy, self).__init__()

        self._n_approximators = n_approximators
        self._epsilon = epsilon
        self._evaluation = False

    def draw_action(self, state):
        """Return the action to play in ``state`` as a 1-element array."""
        # Exploration branch: uniform random action with probability epsilon.
        if np.random.uniform() < self._epsilon(state):
            return np.array([np.random.choice(
                self._approximator.n_actions)])

        if self._evaluation:
            # Majority vote across all ensemble heads.
            model = self._approximator.model
            if isinstance(model, list):
                predictions = [m.predict(state) for m in model]
            else:
                predictions = self._approximator.predict(state)

            votes, counts = np.unique(np.argmax(predictions, axis=1),
                                      return_counts=True)
            # Break ties among the most-voted actions uniformly at random.
            winner = np.random.choice(
                np.argwhere(counts == np.max(counts)).ravel())

            return np.array([votes[winner]])

        # Training: sample each action's value from a random head.
        if isinstance(self._approximator.model, list):
            predictions = [self._approximator.predict(state, idx=i)
                           for i in range(self._n_approximators)]
        else:
            predictions = self._approximator.predict(state)

        qs = np.array(predictions)
        # Standardize each head's Q-values (zero mean, unit std per head).
        # NOTE(review): a constant row makes qs.std(1) zero and yields
        # NaN/inf here — presumably never hit in practice; confirm upstream.
        qs = ((qs.T - qs.mean(1)) / qs.std(1)).T  # Q STANDARDIZATION

        samples = np.ones(self._approximator.n_actions)
        for action in range(self._approximator.n_actions):
            head = np.random.randint(self._n_approximators)
            samples[action] = qs[head, action]

        return np.array([np.argmax(samples)])

    def set_epsilon(self, epsilon):
        """Replace the exploration coefficient."""
        self._epsilon = epsilon

    def set_eval(self, eval):
        """Toggle evaluation mode (ensemble vote vs. sampled values)."""
        self._evaluation = eval

    def set_idx(self, idx):
        """No-op: this policy never fixes a single head."""
        pass

    def update_epsilon(self, state):
        """Advance the epsilon schedule by evaluating it on ``state``."""
        self._epsilon(state)