diff --git a/run_keras_server.py b/run_keras_server.py index d20d79c..53b56ce 100644 --- a/run_keras_server.py +++ b/run_keras_server.py @@ -11,6 +11,7 @@ from keras.preprocessing.image import img_to_array from keras.applications import imagenet_utils from PIL import Image +import tensorflow as tf import numpy as np import flask import io @@ -23,8 +24,9 @@ def load_model(): # load the pre-trained Keras model (here we are using a model # pre-trained on ImageNet and provided by Keras, but you can # substitute in your own networks just as easily) - global model + global model, graph model = ResNet50(weights="imagenet") + graph = tf.get_default_graph() def prepare_image(image, target): # if the image mode is not RGB, convert it @@ -58,7 +60,8 @@ def predict(): # classify the input image and then initialize the list # of predictions to return to the client - preds = model.predict(image) + with graph.as_default(): + preds = model.predict(image) results = imagenet_utils.decode_predictions(preds) data["predictions"] = [] @@ -80,4 +83,4 @@ def predict(): print(("* Loading Keras model and Flask starting server..." "please wait until server has fully started")) load_model() - app.run() \ No newline at end of file + app.run()