-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
35 lines (33 loc) · 1 KB
/
train.py
File metadata and controls
35 lines (33 loc) · 1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import pickle
from pickle import UnpicklingError
import numpy as np
import tensorflow as tf
if __name__ == "__main__":
xTrain = []
yTrain = []
file = open('X.pickle', 'rb')
while True:
try:
xTrain.append(pickle.load(file))
except (EOFError, UnpicklingError):
break
file.close()
file = open('Y.pickle', 'rb')
while True:
try:
yTrain.append(pickle.load(file))
except (EOFError, UnpicklingError):
break
file.close()
xTrain = np.array(xTrain)
yTrain = np.array(yTrain)
inputs = tf.keras.Input(shape=(9,))
x = tf.keras.layers.Dense(32, activation='relu')(inputs)
x = tf.keras.layers.Dense(32, activation='relu')(x)
predictions = tf.keras.layers.Dense(2)(x)
model = tf.keras.Model(inputs=inputs, outputs=predictions)
model.compile(optimizer='adam',
loss='mse',
metrics=['accuracy'])
model.fit(xTrain, yTrain, epochs=200)
model.save('./my_model')