-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTrainer.py
More file actions
115 lines (97 loc) · 4.63 KB
/
Trainer.py
File metadata and controls
115 lines (97 loc) · 4.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from MCTS import MCTS
import numpy as np
# Hyperparameters for the self-play training loop (see Trainer below).
training_params = {
    # Number of outer train/validate iterations.
    'num_iterations': 100,
    # Self-play games generated per iteration.
    'num_episodes': 100,
    # Minimum win rate (wins / (wins + losses)) the newly trained network
    # must reach against the previous one to be accepted.
    'update_threshold': 0.55,
    # Cap on the number of training examples kept within one iteration.
    'max_training_examples_per_iteration': 100000,
    # Head-to-head games played when validating new vs. old network.
    'num_of_validation_games': 50,
    # How many iterations' worth of example batches to keep in history.
    'max_complete_examples': 50,
}
class Trainer:
    """AlphaZero-style self-play trainer.

    Repeatedly (1) generates self-play games with MCTS-guided move
    selection, (2) trains the network on the accumulated examples, and
    (3) keeps the new network only if it beats the previous generation
    at a win rate of at least ``update_threshold``.
    """

    def __init__(self, game, nnet):
        self.game = game
        self.nnet = nnet
        # Previous-generation network, used as the opponent in validation.
        self.pnet = nnet.__class__(self.game)
        self.args = training_params
        self.mcts = MCTS(self.game, self.nnet, self.args)
        # Per-iteration example batches, bounded by 'max_complete_examples'.
        self.complete_training_history = []

    def executeEpisode(self):
        """Play one self-play game and return its training examples.

        Returns:
            list of (board, pi, z) tuples, where z is the final game
            result from the perspective of the player to move in that
            position.
        """
        training_data = []
        board = self.game.getInitBoard()
        current_player = 1
        while True:
            # Use MCTS to get the move probabilities for the canonical position.
            canonical_board = self.game.getCanonicalForm(board, current_player)
            pi = self.mcts.getProbabilityDist(canonical_board)
            # Augment with board symmetries so each position yields several examples.
            for sym_board, sym_pi in self.game.getSymmetries(canonical_board, pi):
                training_data.append([sym_board, current_player, sym_pi])
            # Sample a move from the MCTS policy.
            move = np.random.choice(len(pi), p=pi)
            board, current_player = self.game.getNextState(board, current_player, move)
            # Non-zero result means the game is over; it becomes the reward.
            r = self.game.getGameEnded(board, current_player)
            if r != 0:
                # Flip the result's sign for positions where the opponent
                # of 'current_player' was the one to move.
                return [(x[0], x[2], r * ((-1) ** (current_player != x[1])))
                        for x in training_data]

    def train(self):
        """Run the full training loop for 'num_iterations' iterations."""
        print("Starting the training")
        for i in range(self.args['num_iterations']):
            print("Training outer iteration: " + str(i))
            training_examples = []
            for j in range(self.args['num_episodes']):
                print("Training inner iteration: " + str(j))
                # Fresh search tree for every self-play game.
                self.mcts = MCTS(self.game, self.nnet, self.args)
                # BUG FIX: extend (not append) so the cap below counts
                # individual examples, as the parameter name implies.
                training_examples.extend(self.executeEpisode())
                # BUG FIX: Python lists have no pop_front(); drop the
                # oldest examples with pop(0) until under the cap.
                while len(training_examples) > self.args['max_training_examples_per_iteration']:
                    training_examples.pop(0)
            self.complete_training_history.append(training_examples)
            if len(self.complete_training_history) > self.args['max_complete_examples']:
                self.complete_training_history.pop(0)
            # Snapshot the pre-training weights so pnet holds the old network.
            self.nnet.save("nnetsave")
            self.pnet.load("nnetsave")
            training_examples = []
            for example_sequence in self.complete_training_history:
                training_examples.extend(example_sequence)
            # BUG FIX: 'shuffle' was never imported; use numpy's in-place shuffle.
            np.random.shuffle(training_examples)
            print("Training nnet")
            self.nnet.train(training_examples)
            # BUG FIX: validate is a method of this class and must be
            # called via self.
            wins, draws, losses = self.validate(self.nnet, self.pnet)
            print("Wins " + str(wins))
            print("Losses " + str(losses))
            print("Draws " + str(draws))
            # Only keep the new network if its win rate reaches the threshold.
            # BUG FIX: check wins + losses BEFORE dividing, so an all-draw
            # validation run no longer raises ZeroDivisionError (it counts
            # as a failed promotion).
            if wins + losses == 0 or float(wins) / float(wins + losses) < self.args['update_threshold']:
                # Revert to the pre-training weights. BUG FIX: call load()
                # for its side effect (consistent with pnet.load above)
                # instead of assigning its return value to self.nnet.
                self.nnet.load("nnetsave")
            else:
                self.nnet.save("nnetbest")

    def validate(self, net1, net2):
        """Play head-to-head games between net1 and net2.

        Returns:
            (wins, draws, losses) from net1's perspective. The starting
            player alternates each game to cancel first-move advantage.
        """
        # Make a fresh MCTS for each network.
        net1MCTS = MCTS(self.game, net1, self.args)
        net2MCTS = MCTS(self.game, net2, self.args)
        # Greedy (temp=0) move selection for each side; player 1 is net1.
        # BUG FIX: use the net1/net2 parameters, not undefined self.net1/self.net2.
        playerFunctions = {
            1: lambda x: np.argmax(net1MCTS.getProbabilityDist(x, temp=0)),
            -1: lambda x: np.argmax(net2MCTS.getProbabilityDist(x, temp=0)),
        }
        # Outcome tally keyed by result: 1 -> net1 win, -1 -> loss, 0 -> draw.
        wins = {-1: 0, 0: 0, 1: 0}
        current_starting_player = 1
        current_opposing_player = -1
        for _ in range(self.args['num_of_validation_games']):
            current_player = current_starting_player
            current_board = self.game.getInitBoard()
            while self.game.getGameEnded(current_board, current_player) == 0:
                canonical = self.game.getCanonicalForm(current_board, current_player)
                # BUG FIX: dict was misspelled 'playerFunction' (NameError).
                action = playerFunctions[current_player](canonical)
                # An invalid greedy move ends the game immediately
                # (behavior kept from the original).
                if self.game.getValidMoves(canonical, 1)[action] == 0:
                    break
                current_board, current_player = self.game.getNextState(
                    current_board, current_player, action)
            # Score the finished game from player 1's (net1's) perspective.
            wins[self.game.getGameEnded(current_board, 1)] += 1
            # Alternate who starts the next game.
            current_starting_player, current_opposing_player = \
                current_opposing_player, current_starting_player
        return wins[1], wins[0], wins[-1]