-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathResNet_PyTorch.py
More file actions
93 lines (75 loc) · 3.05 KB
/
ResNet_PyTorch.py
File metadata and controls
93 lines (75 loc) · 3.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import sys
# Make the parent directory importable (project-local modules presumably live there).
sys.path.append('..')
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
# Load board size / residual-block count used by every module below.
# NOTE(review): hard-coded absolute Windows path — this breaks on any other
# machine; consider resolving config.json relative to this file instead.
with open('C:\\Users\\brain\\PycharmProjects\\Alpha-Zero-Neural-Network\\config.json') as json_data_file:
    config = json.load(json_data_file)
class ConvBlock(nn.Module):
    """Convolutional stem of the network.

    Reshapes a (possibly flattened) single-channel board tensor to
    (batch, 1, board_x, board_y) and applies 3x3 conv -> batch-norm -> ReLU,
    producing 512 feature planes at the same spatial size.
    """

    def __init__(self):
        super(ConvBlock, self).__init__()
        board = config["boardSize"]
        self.action_size = board * board  # one action per board square
        self.board_x = board
        self.board_y = board
        # 1 input channel, 512 output planes; padding=1 preserves spatial size.
        self.conv1 = nn.Conv2d(1, 512, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(512)

    def forward(self, s):
        # Restore (batch, channel, x, y) layout regardless of input flattening.
        boards = s.view(-1, 1, self.board_x, self.board_y)
        return F.relu(self.bn1(self.conv1(boards)))
class ResBlock(nn.Module):
    """Pre-activation-free residual block (two 3x3 convs, identity shortcut).

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        stride: stride of the first convolution (spatial downsampling).
        downsample: optional module applied to the shortcut so that its
            shape matches the main path (required whenever stride != 1
            or inplanes != planes).

    Bug fixed vs. original: `downsample` was accepted but never stored or
    applied, and conv2 reused `stride`, so any non-default stride/planes
    combination crashed on the residual addition. Defaults (stride=1,
    inplanes == planes, downsample=None) behave exactly as before.
    """

    def __init__(self, inplanes=512, planes=512, stride=1, downsample=None):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # Second conv keeps stride 1 so the output matches the (possibly
        # downsampled) shortcut — downsampling happens only in conv1.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        # Shortcut path: identity, or the user-supplied projection when the
        # main path changes shape.
        residual = x if self.downsample is None else self.downsample(x)
        out = self.conv1(x)
        out = F.relu(self.bn1(out))
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        out = F.relu(out)
        return out
class OutBlock(nn.Module):
    """Dual output head.

    Value head: 1x1 conv (512->3) -> BN -> ReLU -> two linear layers -> tanh,
    yielding a scalar in [-1, 1] per board.
    Policy head: 1x1 conv (512->32) -> BN -> ReLU -> linear -> softmax
    (computed as logsoftmax followed by exp), yielding a probability
    distribution over all board squares.
    Returns (policy, value).
    """

    def __init__(self):
        super(OutBlock, self).__init__()
        board = config["boardSize"]
        self.action_size = board * board
        self.board_x = board
        self.board_y = board
        # Value-head layers.
        self.conv = nn.Conv2d(512, 3, kernel_size=1)
        self.bn = nn.BatchNorm2d(3)
        self.fc1 = nn.Linear(3 * self.board_x * self.board_y, 32)
        self.fc2 = nn.Linear(32, 1)
        # Policy-head layers.
        self.conv1 = nn.Conv2d(512, 32, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.fc = nn.Linear(self.board_x * self.board_y * 32, self.action_size)

    def forward(self, s):
        # Value head: squeeze channels, flatten, regress to a tanh scalar.
        value = F.relu(self.bn(self.conv(s)))
        value = value.view(-1, 3 * self.board_x * self.board_y)
        value = F.relu(self.fc1(value))
        value = torch.tanh(self.fc2(value))
        # Policy head: project to action_size logits, then softmax.
        policy = F.relu(self.bn1(self.conv1(s)))
        policy = policy.view(-1, self.board_x * self.board_y * 32)
        policy = self.logsoftmax(self.fc(policy)).exp()
        return policy, value
class NNet(nn.Module):
    """Full AlphaZero-style network.

    Pipeline: ConvBlock stem -> config["blocks"] ResBlocks -> OutBlock,
    returning the (policy, value) pair produced by OutBlock. Residual
    blocks are registered as attributes named res_0 .. res_{n-1}.
    """

    def __init__(self):
        super(NNet, self).__init__()
        self.conv = ConvBlock()
        # Register each tower block as a named attribute so nn.Module
        # tracks its parameters.
        for idx in range(config["blocks"]):
            setattr(self, f"res_{idx}", ResBlock())
        self.outblock = OutBlock()

    def forward(self, s):
        s = self.conv(s)
        # Re-read the block count from config, matching construction order.
        for idx in range(config["blocks"]):
            s = getattr(self, f"res_{idx}")(s)
        return self.outblock(s)