-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
29 lines (25 loc) · 1.09 KB
/
train.py
File metadata and controls
29 lines (25 loc) · 1.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # or any {'0', '1', '2'}
from sys import exit
from tools.prepro import *
from tools.model import *
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.layers.experimental import preprocessing
from tensorflow.keras import models
from tensorflow.keras import layers
def main(
    normal_glob="data/pump/*/normal/*.wav",
    abnormal_glob="data/pump/*/abnormal/*.wav",
    epochs=200,
    patience=3,
    model_dir="saved_model",
):
    """Load the pump recordings, train the classifier and save it.

    Running this file as a script calls ``main()`` with the defaults below,
    which are the values the script originally hard-coded.

    Args:
        normal_glob: Glob pattern passed to ``load_data`` as its first
            argument (files under ``normal`` directories).
        abnormal_glob: Glob pattern passed to ``load_data`` as its second
            argument (files under ``abnormal`` directories).
        epochs: Maximum number of training epochs; early stopping may end
            training sooner.
        patience: Number of epochs without ``val_loss`` improvement that
            ``EarlyStopping`` tolerates before stopping.
        model_dir: Path handed to ``model.save``.

    Returns:
        The ``History`` object returned by ``model.fit``. Its ``history``
        dict (e.g. ``"loss"``, ``"val_loss"``) can be exported if training
        curves are needed.
    """
    # NOTE(review): `pprint`, `load_data` and `build_model` are not defined in
    # this file; presumably they come from the `tools.*` star imports — confirm.
    pprint("Loading Data")
    # NOTE(review): the `test` split is loaded but never used here; consider
    # `model.evaluate(test, verbose=1)` once its format is confirmed.
    train, test, val_ds = load_data(normal_glob, abnormal_glob)

    pprint("Building Model")
    model = build_model(train)
    # summary() prints the table itself and returns None, so wrapping it in
    # print() only emitted an extra "None" line.
    model.summary()

    # from_logits=True assumes build_model's final layer outputs raw logits
    # (no softmax) — TODO confirm against tools.model.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )

    # Stop after `patience` epochs without val_loss improvement. With
    # restore_best_weights=True, Keras reloads the best-val_loss weights when
    # it stops early, so model.save below writes the best model rather than
    # the weights from the final epoch.
    stopper = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=patience,
        restore_best_weights=True,
    )
    history = model.fit(
        train,
        validation_data=val_ds,
        epochs=epochs,
        callbacks=[stopper],  # fit() documents `callbacks` as a list
    )
    model.save(model_dir)
    return history


if __name__ == "__main__":
    main()