-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathagent.py
More file actions
176 lines (116 loc) · 4.99 KB
/
agent.py
File metadata and controls
176 lines (116 loc) · 4.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import os
import random
from collections import deque
from random import sample

import numpy as np
import torch
import torch.nn as nn

import cfg
# Seed every RNG this module draws from so training runs are reproducible:
# torch (weight initialisation), numpy (epsilon-greedy exploration) and the
# stdlib ``random`` module (``sample`` in Agent.replay), which is not covered
# by the torch or numpy seeds.
torch.manual_seed(0)
np.random.seed(1)
random.seed(0)
class Agent:
    """Epsilon-greedy agent that learns a single rating for each Tetris state.

    A small fully connected network maps a state vector (length
    ``env.state_size``) to one scalar rating. To act, the agent rates every
    candidate next state and moves to the best one, or picks a random one
    with probability ``epsilon``. Transitions are stored in a bounded replay
    memory and learned from in random minibatches.
    """

    # Width of each of the two hidden layers (chosen arbitrarily).
    HIDDEN_SIZE = 42

    def __init__(self, env, greedy_epsilon=cfg.EPSILON) -> None:
        """Create the agent, build its network and restore a checkpoint if present.

        Args:
            env: environment object; only its ``state_size`` attribute is read here.
            greedy_epsilon: initial probability of taking a random action.
        """
        self.env = env
        self.gamma = cfg.GAMMA
        self.epsilon = greedy_epsilon
        self.epsilon_min = cfg.EPSILON_MIN
        self.epsilon_decay = cfg.EPSILON_DECAY
        self.learning_rate = cfg.SGD_LEARNING_RATE
        # Replay memory: once full, appending drops the oldest transition.
        self.memory = deque(maxlen=cfg.MAX_MEMORY_SIZE)
        # Size of the Tetris state vector fed to the network.
        self.state_size = env.state_size
        # Each state gets one overall rating.
        self.size_of_state_rating = 1
        self.build_NN()
        # Resume from cfg.CHECKPOINT_FILE_PATH when that file exists.
        self.load()

    def build_NN(self):
        """Build ``self.model`` (MLP), ``self.loss_fn`` (MSE) and ``self.optimizer`` (Adam)."""

        def init_weights(m):
            # He (Kaiming) initialisation suits the ReLU activations; biases start at 0.
            if isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
                nn.init.constant_(m.bias, 0)

        self.model = nn.Sequential(
            # Input layer -> hidden layer 1. Each hidden layer computes
            #   z^(L) = W^(L) a^(L-1) + b^(L),   a^(L) = ReLU(z^(L))
            nn.Linear(self.state_size, self.HIDDEN_SIZE),
            nn.ReLU(),
            # Hidden layer 1 -> hidden layer 2.
            nn.Linear(self.HIDDEN_SIZE, self.HIDDEN_SIZE),
            nn.ReLU(),
            # Hidden layer 2 -> output layer: one rating, no activation.
            nn.Linear(self.HIDDEN_SIZE, self.size_of_state_rating),
        )
        self.model.apply(init_weights)
        self.loss_fn = nn.MSELoss()
        # Adam adapts per-parameter step sizes; cfg.SGD_LEARNING_RATE is its base rate.
        self.optimizer = torch.optim.Adam(self.model.parameters(), self.learning_rate)

    @staticmethod
    def _to_tensor(states) -> torch.Tensor:
        """Stack a sequence of state vectors into one float32 tensor, one row per state."""
        # Going through a single NumPy array avoids torch.tensor's slow
        # per-element path when the states are individual ndarrays.
        return torch.as_tensor(np.asarray(states, dtype=np.float32))

    def choose_action(self, next_states: dict) -> list:
        """Choose the next action epsilon-greedily.

        Args:
            next_states: maps each candidate action to the state it leads to.

        Returns:
            ``[action, resulting_state]`` for the chosen action.

        Raises:
            ValueError: if ``next_states`` is empty.
        """
        if not next_states:
            raise ValueError("choose_action() needs at least one candidate next state")
        next_actions, corresponding_states = zip(*next_states.items())

        # Explore. Strict '<' so that epsilon == 0 never explores (rand() may return 0.0).
        if np.random.rand() < self.epsilon:
            ind = np.random.choice(len(next_actions))
            print("random")
            print("-------------------------------")
            return [next_actions[ind], corresponding_states[ind]]

        # Exploit: rate all candidate states in one forward pass. No gradients
        # are needed because nothing is backpropagated here.
        self.model.eval()
        try:
            with torch.no_grad():
                q_vals = torch.flatten(self.model(self._to_tensor(corresponding_states)))
        finally:
            # Always return to training mode, even if the forward pass fails.
            self.model.train()
        ind = torch.argmax(q_vals).item()
        print(q_vals)
        print(ind)
        print("---------------------")
        return [next_actions[ind], corresponding_states[ind]]

    def store_in_memory(self, transition):
        """Append one transition to replay memory (the oldest is dropped when full)."""
        self.memory.append(transition)

    def lower_greedy_epsilon(self):
        """Decay epsilon multiplicatively, never letting it drop below ``epsilon_min``."""
        # NOTE(review): assumes cfg.PRECISION is a small non-negative tolerance.
        # Once epsilon is within that tolerance of the floor, it snaps to the floor.
        if self.epsilon - self.epsilon_min > cfg.PRECISION:
            self.epsilon = max(self.epsilon * self.epsilon_decay, self.epsilon_min)
        else:
            self.epsilon = self.epsilon_min

    def learn(self, batch) -> float:
        """Run one gradient step on a batch of transitions.

        Each transition is ``(state, action, reward, next_state, done)``. The
        regression target for ``state`` is ``reward`` when ``done`` is true and
        ``reward + gamma * rating(next_state)`` otherwise. Epsilon is decayed
        once per call.

        Returns:
            The batch MSE loss as a Python float.

        Raises:
            ValueError: if ``batch`` is empty.
        """
        if not batch:
            raise ValueError("learn() needs a non-empty batch")
        states, _actions, rewards, next_states, dones = zip(*batch)

        targets = torch.tensor(rewards, dtype=torch.float32)
        # Only non-terminal transitions bootstrap from the next state's rating.
        # Terminal next states are never fed through the network.
        live = [i for i, done in enumerate(dones) if not done]
        if live:
            self.model.eval()
            try:
                with torch.no_grad():
                    next_vals = torch.flatten(
                        self.model(self._to_tensor([next_states[i] for i in live]))
                    )
            finally:
                self.model.train()
            targets[live] = targets[live] + self.gamma * next_vals

        self.lower_greedy_epsilon()

        self.optimizer.zero_grad()
        output = self.model(self._to_tensor(states))
        loss = self.loss_fn(output, targets.unsqueeze(1))
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def replay(self) -> float:
        """Train on a random minibatch drawn from replay memory.

        Draws ``cfg.BATCH_SIZE`` transitions, or every stored transition while
        memory holds fewer than that. Raises ``ValueError`` (from ``learn``)
        if memory is empty.
        """
        batch = sample(self.memory, min(cfg.BATCH_SIZE, len(self.memory)))
        return self.learn(batch)

    def save(self):
        """Write model and optimizer state to ``cfg.CHECKPOINT_FILE_PATH``."""
        directory = os.path.dirname(cfg.CHECKPOINT_FILE_PATH)
        if directory:
            # torch.save does not create missing parent directories.
            os.makedirs(directory, exist_ok=True)
        torch.save({"model_state_dict": self.model.state_dict(),
                    "optim_state_dict": self.optimizer.state_dict(),
                    }, cfg.CHECKPOINT_FILE_PATH)

    def load(self):
        """Restore model and optimizer state from ``cfg.CHECKPOINT_FILE_PATH`` if it exists."""
        if os.path.isfile(cfg.CHECKPOINT_FILE_PATH):
            # This class never moves the model off the CPU. Map tensors saved on
            # another device back to the CPU; load_state_dict copies them onto
            # whatever device the parameters are actually on.
            checkpoint = torch.load(cfg.CHECKPOINT_FILE_PATH, map_location="cpu")
            self.model.load_state_dict(checkpoint["model_state_dict"])
            self.optimizer.load_state_dict(checkpoint["optim_state_dict"])