-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict_server.py
More file actions
67 lines (51 loc) · 2.07 KB
/
predict_server.py
File metadata and controls
67 lines (51 loc) · 2.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import functools
import json
import os
import time
from io import BytesIO
from urllib import request as urlrequest
from urllib.error import HTTPError

import numpy as np
import pandas as pd
from flask import Flask, request, Response
from keras.preprocessing import image
from PIL import Image
from tensorflow import keras
from waitress import serve
# Cloudinary folder that the cover paths sent by clients are appended to.
cloudinary_root = (
    'https://res.cloudinary.com/dl7hskxab/image/upload/'
    'v1623338718/inducks-covers/'
)

# Datasets this server can answer for; each needs files under input/<dataset>/.
datasets = ['published-fr-recent']


def _load_model(dataset):
    """Load the saved Keras model at ``input/<dataset>/model.keras``."""
    return keras.models.load_model(f'input/{dataset}/model.keras')


# Every model is loaded once, at import time, keyed by dataset name.
models = dict(zip(datasets, map(_load_model, datasets)))

app = Flask(__name__)
@app.route('/', methods=['GET'])
def alive():
    """Liveness endpoint: always answer 200 with a short fixed text body."""
    # The original literal "I''m alive" rendered two apostrophes; one is intended.
    return Response("I'm alive", 200, mimetype='text/plain')
# Artists need at least this many drawings in artists_popular.csv to be kept as a
# class. Prediction indices are mapped onto the filtered list, so this presumably
# must match the threshold used when the model was trained -- TODO confirm.
_MIN_DRAWINGS = 200
# (width, height) every cover is resized to before inference.
_IMAGE_SIZE = (224, 224)
# Seconds to wait on Cloudinary before giving up, so a stalled download cannot
# hold a server thread indefinitely.
_FETCH_TIMEOUT_S = 10


@functools.lru_cache(maxsize=None)
def _top_artist_names(dataset):
    """Return the person codes of the artists the *dataset* model can predict.

    The CSV is read once per dataset and then cached. Callers only pass names
    already validated against ``datasets``, so the cache stays bounded. The
    returned array is shared between requests and must not be mutated.
    """
    artists = pd.read_csv(f'input/{dataset}/artists_popular.csv')
    return artists[artists['drawings'] >= _MIN_DRAWINGS]['personcode'].values


def _to_model_input(raw_bytes):
    """Decode image bytes into a single-image float batch scaled to [0, 1].

    The image is converted to RGB first, so grayscale, palette or RGBA files
    also yield three channels. With Keras' default channels_last format the
    result has shape (1, 224, 224, 3).

    Raises:
        OSError: if Pillow cannot identify or decode the bytes.
    """
    picture = Image.open(BytesIO(raw_bytes)).convert('RGB').resize(_IMAGE_SIZE)
    pixels = image.img_to_array(picture)
    pixels /= 255.
    return np.expand_dims(pixels, axis=0)


@app.route('/predict', methods=['POST'])
def predict():
    """Predict which artist drew the cover at a Cloudinary path.

    Expects a JSON object ``{"url": <path under cloudinary_root>, "dataset": <name>}``.
    On success, replies with JSON ``{"url", "predicted", "predictionProbability"}``.
    ``url`` is the full fetched URL, and ``predictionProbability`` is the largest
    model output times 100 (a percentage if the model emits probabilities).

    Error responses:
        400 -- the body is not a JSON object, the dataset is unknown, the url is
               missing or invalid, Cloudinary answered 4xx, the image cannot be
               decoded, or the predicted index falls outside the artist list.
        502 -- Cloudinary answered 5xx, was unreachable, or timed out.
    """
    start_time = time.perf_counter()

    # silent=True: a malformed or non-JSON body yields None instead of raising.
    request_data = request.get_json(silent=True)
    if not isinstance(request_data, dict):
        return Response('Request body must be a JSON object', 400)
    dataset = request_data.get('dataset')
    url = request_data.get('url')
    if dataset not in datasets:
        return Response(f'Invalid dataset: {dataset}', 400)
    if not isinstance(url, str) or not url:
        return Response('Missing or invalid "url"', 400)

    artists_top_name = _top_artist_names(dataset)

    url = f"{cloudinary_root}{url}"
    try:
        with urlrequest.urlopen(url, timeout=_FETCH_TIMEOUT_S) as upstream:
            res = upstream.read()
    except HTTPError as err:
        # A 4xx means the client sent a bad path; a 5xx is Cloudinary's fault.
        status = 400 if 400 <= err.code < 500 else 502
        return Response(f'Could not fetch {url}: HTTP {err.code}', status)
    except ValueError:
        # http.client.InvalidURL (a ValueError) -- e.g. control characters in url.
        return Response(f'Invalid url: {url}', 400)
    except OSError as err:
        # URLError, connection errors and timeouts all derive from OSError.
        return Response(f'Could not fetch {url}: {err}', 502)

    try:
        test_image = _to_model_input(res)
    except OSError:
        return Response(f'Could not decode image at {url}', 400)

    # Predict artist
    prediction = models[dataset].predict(test_image)
    # Cast NumPy scalars to Python types: json.dumps rejects np.float32.
    prediction_probability = float(np.amax(prediction))
    prediction_idx = int(np.argmax(prediction))
    if prediction_idx >= len(artists_top_name):
        # NOTE(review): this signals a model/CSV mismatch rather than a client
        # error; 500 may be more accurate, but 400 is kept for existing callers.
        message = f'Index {prediction_idx} is not in {artists_top_name}'
        print(message)
        return Response(message, 400)
    # Assumes personcode values are strings that use '_' in place of spaces.
    predicted_artist = artists_top_name[prediction_idx].replace('_', ' ')
    print(f"Predicted {predicted_artist} in {time.perf_counter() - start_time}s")
    return Response(json.dumps({
        "url": url,
        "predicted": predicted_artist,
        "predictionProbability": prediction_probability * 100,
    }), mimetype='application/json')
if __name__ == "__main__":
    # Serve on all interfaces. The hosting platform may choose the port via the
    # PORT environment variable; without it the historical 8080 is used.
    serve(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8080")))