-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmodel_training.py
More file actions
31 lines (26 loc) · 1.17 KB
/
model_training.py
File metadata and controls
31 lines (26 loc) · 1.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import os

import numpy as np
import tensorflow as tf
import tflearn
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.estimator import regression

import network_module
# Identity of this training run: MODEL_NAME + str(VERSION) is used both as
# the tflearn run_id and as the path the trained model is saved to.
MODEL_NAME: str = 'basic'
VERSION: int = 3

# Hyper-parameters: EPOCHS is passed to model.fit as n_epoch; LEARNING_RATE
# is passed to network_module.neural_network_model as lr.
EPOCHS: int = 3
LEARNING_RATE: float = 0.001

# First dataset to train on, looked up inside the training_data directory.
TRAINING_DATA_FILE: str = 'basic1_training_data-17-34.npy'
def train_model(training_data, model=None):
    """Train (or continue training) a tflearn network on ``training_data``.

    Each element of ``training_data`` is indexed as ``sample[0]`` (input
    features) and ``sample[1]`` (target).

    Every feature array is flattened to a column vector of shape
    ``(n_features, 1)``, in C order. The resulting stack is fed to the
    ``'input'`` layer. The targets are passed unchanged to ``'targets'``.

    Args:
        training_data: Non-empty sequence of ``(features, target)`` pairs,
            for example the array returned by ``np.load(..., allow_pickle=True)``.
        model: Existing model to keep training. When ``None``, a new model is
            built with ``network_module.neural_network_model``. It is sized
            from the flattened feature count and the target length.

    Returns:
        The trained model. This is the same object as ``model`` when one was
        passed in.

    Raises:
        ValueError: If ``training_data`` is empty, or if a sample's feature
            count differs from that of the first sample.
    """
    # len() rather than truthiness: training_data may be a numpy array,
    # whose truth value is ambiguous.
    if len(training_data) == 0:
        raise ValueError("training_data is empty; nothing to train on")

    # Size the input from the first sample's total element count. Multiplying
    # only its first two dimensions dropped any further dimensions, and
    # np.resize silently tiled or truncated samples of a different size.
    n_features = np.asarray(training_data[0][0]).size

    inputs = []
    for index, sample in enumerate(training_data):
        features = np.asarray(sample[0])
        if features.size != n_features:
            raise ValueError(
                f"sample {index} has {features.size} input values, "
                f"expected {n_features} (from sample 0)"
            )
        inputs.append(features.reshape(n_features, 1))
    X = np.array(inputs)
    Y = [sample[1] for sample in training_data]

    print(f"Input Length: {len(X[0])}")
    print(f"Output Length: {len(Y[0])}")

    if model is None:
        model = network_module.neural_network_model(
            input_size=len(X[0]), output_size=len(Y[0]), lr=LEARNING_RATE
        )
    # NOTE(review): when an existing model is passed, its input size
    # presumably must equal n_features; this is not checked here.
    model.fit({'input': X}, {'targets': Y}, n_epoch=EPOCHS, snapshot_step=500,
              show_metric=True, run_id=MODEL_NAME + str(VERSION))
    return model
# Load both datasets from the training_data directory. os.path.join builds a
# path that works on both Windows and POSIX; a hard-coded backslash separator
# does not.
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects,
# so only load .npy files from trusted sources.
training_data = np.load(os.path.join('training_data', TRAINING_DATA_FILE),
                        allow_pickle=True)
training_data2 = np.load(os.path.join('training_data', 'cf_training_data-17-36.npy'),
                         allow_pickle=True)

# Train a fresh model on the first dataset, continue training the same model
# on the second, then save it under MODEL_NAME + str(VERSION).
model = train_model(training_data)
model = train_model(training_data2, model)
model.save(MODEL_NAME + str(VERSION))