-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
74 lines (58 loc) · 1.74 KB
/
train.py
File metadata and controls
74 lines (58 loc) · 1.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
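# Neural network training script.
# Loads the current network ('latest_network.h5'), or initialises and stores a new one if none exists.
# Reads the most recent games from 'training_data.log' and trains the network on them.
# Writes the trained network back to 'latest_network.h5' and appends the loss history returned by training to 'loss.tsv'.
# Runs a single training pass per invocation; the continuous `while True` loop in main() is currently disabled.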
import config
import model
import numpy as np
import torch
import json
import time
def _to_tensor(rows):
    """Convert (nested) Python lists to a tensor of torch's default float dtype.

    Going through NumPy is much faster than ``torch.Tensor(list)`` on large
    nested lists, and casting to ``torch.get_default_dtype()`` matches the
    dtype the legacy ``torch.Tensor`` constructor produces.
    """
    return torch.as_tensor(np.asarray(rows), dtype=torch.get_default_dtype())


def train():
    """Run one training pass over the most recent training data.

    Loads the current network. If ``net.load()`` raises ``FileNotFoundError``,
    it stores a freshly initialised one instead. It then reads the last
    ``config.last_N_games`` non-blank records from ``training_data.log``,
    trains for ``config.train_epochs`` epochs and stores the updated network.

    Each log line holds tab-separated JSON fields. Field 0 is indexed as
    ``field[0][0]`` to get the network input, field 1 is the value target and
    field 2 the policy target. (These meanings are inferred from variable
    names; confirm against the code that writes the log.)

    Returns:
        Whatever ``net.train_model`` returns (``main`` iterates it as
        ``(total, prob, value)`` triples), or ``[]`` when there is no
        training data.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = model.Model(device)
    try:
        net.load()
    except FileNotFoundError:
        net.store()
        print("Initialised new model")

    print("loading training data")
    try:
        with open("training_data.log", "r") as f:
            lines = f.readlines()
    except FileNotFoundError:
        print("No training data")
        return []

    # Drop blank lines (e.g. a trailing newline): json.loads("") would raise,
    # and they must not count toward the last-N window.
    lines = [line for line in lines if line.strip()]
    lines = lines[-config.last_N_games:]
    if not lines:
        # An empty log is treated like a missing one rather than training
        # on empty tensors.
        print("No training data")
        return []

    records = [[json.loads(field) for field in line.strip().split('\t')] for line in lines]
    print("done loading")

    xs = _to_tensor([rec[0][0] for rec in records])
    values = _to_tensor([rec[1] for rec in records])
    probs = _to_tensor([rec[2] for rec in records])
    print(xs.shape)
    print(values.shape)
    print(probs.shape)

    # NOTE: train_model takes targets in (inputs, probs, values) order.
    hist = net.train_model([xs, probs, values], config.train_epochs)
    net.store()
    return hist
import gc
def main():
    """Run one training pass and append the resulting losses to ``loss.tsv``.

    On first use, ``loss.tsv`` is created with a ``total\\tprob\\tvalue``
    header row. Later runs append to the existing file.
    """
    # Mode "x" creates the file only if it does not already exist. The header
    # is therefore written exactly once, with no check-then-create race
    # between concurrent runs.
    try:
        with open("loss.tsv", "x") as log:
            log.write("total\tprob\tvalue\n")
    except FileExistsError:
        pass  # Header already present from an earlier run.

    # NOTE: this previously ran inside a `while True:` loop (calling
    # gc.collect() and torch.cuda.empty_cache() between passes). It
    # currently performs a single training pass.
    losses = train()
    with open("loss.tsv", "a") as log:
        log.writelines(f"{total}\t{prob}\t{value}\n" for total, prob, value in losses)
if __name__ == "__main__":
    # Script entry point: a single training pass plus loss logging.
    main()