-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathagent.py
More file actions
executable file
·44 lines (32 loc) · 1.34 KB
/
agent.py
File metadata and controls
executable file
·44 lines (32 loc) · 1.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import numpy as np
class Agent:
    """Softmax (Boltzmann) action-selection bandit agent.

    Maintains one Q-value per action and converts the Q-values into a
    selection probability distribution (`beta_table`) with a
    temperature-scaled softmax.
    """

    def __init__(self, action_space, args):
        """
        :param action_space: sequence of action identifiers, e.g. [0, 1, 2, ...]
        :param args: namespace providing ``temperature`` (softmax scaling)
            and ``lr`` (learning rate).
        """
        self.args = args
        self.endeavor = action_space  # [0, 1, 2, ...]
        self.action = 0  # index of the most recently chosen action

        # BUG FIX: np.zeros_like(action_space) inherits an integer dtype when
        # the action space is a list of ints, so the fractional updates in
        # learn() were silently truncated to 0 on assignment. Allocate the
        # Q-table as float64 explicitly.
        self.q_table = np.zeros(len(self.endeavor), dtype=np.float64)
        self.beta_table = self.softmax(self.q_table)

    def softmax(self, x):
        """Return the temperature-scaled softmax of ``x`` as a probability vector."""
        if not isinstance(x, np.ndarray):
            x = np.array(x)
        x = x / self.args.temperature  # temperature scaling
        e_x = np.exp(x - np.max(x))  # subtract max to prevent overflow
        return e_x / np.sum(e_x)

    def get_action(self, deterministic=False):
        """
        :param deterministic: if True, pick the action with the highest
            probability (ties broken uniformly at random); otherwise sample
            an action according to ``beta_table``.
        :return: the chosen action index (int).
        """
        if deterministic:
            b = np.array(self.beta_table)
            # flatnonzero of the max mask handles ties uniformly at random
            action = np.random.choice(np.flatnonzero(b == b.max()))
        else:
            action = np.random.choice(self.endeavor, 1, p=self.beta_table)
        action = int(action)
        self.action = action
        return action

    def learn(self, action, reward):
        """Update the Q-value of ``action`` toward ``reward``, then refresh
        the softmax policy.

        The TD error is divided by the action's current selection probability
        — presumably an importance-style correction so rarely-chosen actions
        receive larger updates; confirm against the training scheme.
        """
        q1 = self.q_table[action]
        q2 = reward
        self.q_table[action] += self.args.lr * (q2 - q1) / self.beta_table[action]
        self.beta_table = self.softmax(self.q_table)