diff --git a/run_keras_server.py b/run_keras_server.py index d20d79c..44bfc15 100644 --- a/run_keras_server.py +++ b/run_keras_server.py @@ -10,6 +10,7 @@ from keras.applications import ResNet50 from keras.preprocessing.image import img_to_array from keras.applications import imagenet_utils +import tensorflow as tf from PIL import Image import numpy as np import flask @@ -25,6 +26,9 @@ def load_model(): # substitute in your own networks just as easily) global model model = ResNet50(weights="imagenet") + global graph + # save graph after loading ResNet50 weights + graph = tf.get_default_graph() def prepare_image(image, target): # if the image mode is not RGB, convert it @@ -58,7 +62,9 @@ def predict(): # classify the input image and then initialize the list # of predictions to return to the client - preds = model.predict(image) + # use the same graph saved after loading the model + with graph.as_default(): + preds = model.predict(image) results = imagenet_utils.decode_predictions(preds) data["predictions"] = [] @@ -80,4 +86,4 @@ def predict(): print(("* Loading Keras model and Flask starting server..." "please wait until server has fully started")) load_model() - app.run() \ No newline at end of file + app.run()