-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmodel.py
More file actions
79 lines (67 loc) · 3.06 KB
/
model.py
File metadata and controls
79 lines (67 loc) · 3.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import config
class RNNAgent(nn.Module):
    """Recurrent agent network: linear encoder -> GRU cell -> linear head.

    All layer sizes are read from the module-level ``config`` object
    (``input_shape``, ``rnn_hidden_dim`` and ``n_actions``).
    """

    def __init__(self):
        super(RNNAgent, self).__init__()
        hidden_dim = config.rnn_hidden_dim
        self.fc1 = nn.Linear(config.input_shape, hidden_dim)
        self.rnn = nn.GRUCell(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, config.n_actions)

    def init_hidden(self):
        """Return an all-zero hidden state of shape ``(1, rnn_hidden_dim)``.

        The tensor is derived from ``fc1.weight`` so that it shares the
        dtype and device of the model parameters.
        """
        return self.fc1.weight.new_zeros((1, config.rnn_hidden_dim))

    def forward(self, inputs, hidden_state):
        """Run one recurrent step.

        Args:
            inputs: tensor fed to ``fc1``; its last dimension must equal
                ``config.input_shape``. Presumably ``(batch, input_shape)``,
                to be confirmed against the caller.
            hidden_state: previous hidden state; flattened to
                ``(-1, rnn_hidden_dim)`` before being given to the GRU cell.

        Returns:
            A ``(q, h)`` tuple: the output of the final linear layer
            (last dimension ``config.n_actions``) and the new hidden state.
        """
        features = F.relu(self.fc1(inputs))
        prev_hidden = hidden_state.reshape(-1, config.rnn_hidden_dim)
        next_hidden = self.rnn(features, prev_hidden)
        return self.fc2(next_hidden), next_hidden
class QMixer(nn.Module):
    """Mixing network that combines per-agent values into one total value.

    The mixing weights are produced by hypernetworks conditioned on the
    state and passed through ``abs``, so they are non-negative; together
    with the monotonic ELU activation this makes the output non-decreasing
    in every agent value.

    Sizes are read from the module-level ``config`` object: ``n_agents``,
    ``state_shape``, ``mixing_embed_dim``, ``hypernet_layers`` and (for
    two-layer hypernetworks) ``hypernet_embed``.

    Raises:
        NotImplementedError: if ``config.hypernet_layers`` is greater than 2.
        ValueError: if ``config.hypernet_layers`` is any other unsupported
            value (e.g. 0 or negative).
    """

    def __init__(self):
        super(QMixer, self).__init__()
        self.n_agents = config.n_agents
        self.state_dim = int(np.prod(config.state_shape))
        self.embed_dim = config.mixing_embed_dim

        # The branches must form a single if/elif chain: with two separate
        # ``if`` statements a value of 1 fell through to the final ``else``
        # and raised, making the one-layer hypernetwork unusable.
        if config.hypernet_layers == 1:
            self.hyper_w_1 = nn.Linear(self.state_dim, self.embed_dim * self.n_agents)
            self.hyper_w_final = nn.Linear(self.state_dim, self.embed_dim)
        elif config.hypernet_layers == 2:
            hypernet_embed = config.hypernet_embed
            self.hyper_w_1 = nn.Sequential(
                nn.Linear(self.state_dim, hypernet_embed),
                nn.ReLU(),
                nn.Linear(hypernet_embed, self.embed_dim * self.n_agents),
            )
            self.hyper_w_final = nn.Sequential(
                nn.Linear(self.state_dim, hypernet_embed),
                nn.ReLU(),
                nn.Linear(hypernet_embed, self.embed_dim),
            )
        elif config.hypernet_layers > 2:
            raise NotImplementedError("Sorry >2 hypernet layers is not implemented!")
        else:
            raise ValueError("Error setting number of hypernet layers.")

        # State dependent bias for hidden layer
        self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim)

        # V(s) instead of a bias for the last layers
        self.V = nn.Sequential(
            nn.Linear(self.state_dim, self.embed_dim),
            nn.ReLU(),
            nn.Linear(self.embed_dim, 1),
        )

    def forward(self, agent_qs, states):
        """Mix per-agent values into a total value.

        Args:
            agent_qs: tensor whose first dimension is the batch size and
                whose elements regroup into rows of ``n_agents`` values.
                Presumably ``(batch, time, n_agents)``, to be confirmed
                against the caller.
            states: tensor that flattens to rows of ``state_dim`` values,
                one row per row of ``agent_qs``.

        Returns:
            Tensor of shape ``(batch, -1, 1)`` holding the mixed value for
            every row.
        """
        bs = agent_qs.size(0)
        states = states.reshape(-1, self.state_dim)
        # reshape (not view) so non-contiguous inputs are accepted too.
        agent_qs = agent_qs.reshape(-1, 1, self.n_agents)

        # First layer: non-negative state-conditioned weights plus bias.
        w1 = th.abs(self.hyper_w_1(states)).view(-1, self.n_agents, self.embed_dim)
        b1 = self.hyper_b_1(states).view(-1, 1, self.embed_dim)
        hidden = F.elu(th.bmm(agent_qs, w1) + b1)

        # Second layer: non-negative weights mapping the embedding to a scalar.
        w_final = th.abs(self.hyper_w_final(states)).view(-1, self.embed_dim, 1)

        # State-dependent bias
        v = self.V(states).view(-1, 1, 1)

        # Compute final output, then restore the leading batch dimension.
        y = th.bmm(hidden, w_final) + v
        q_tot = y.view(bs, -1, 1)
        return q_tot