-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapi.py
More file actions
147 lines (124 loc) · 5.02 KB
/
api.py
File metadata and controls
147 lines (124 loc) · 5.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from fastapi import FastAPI, File, UploadFile, HTTPException
from torchvision import transforms
import onnxruntime
import numpy as np
import io
from PIL import Image
import uvicorn
import logging
import sys
import requests
# Logging writes to api_log.log, truncating it on every restart (filemode='w').
# For logging options see
# https://docs.python.org/3/library/logging.html
logging.basicConfig(filename='api_log.log', filemode='w', format='%(asctime)s %(message)s', datefmt='%d/%m/%Y %H:%M:%S', level=logging.INFO)
# Path to pretrained model (ONNX format, loaded in load_model at startup)
MODEL_PATH = './model/empty_v5_24_08_23.onnx'
# Input image size: images are resized to IMG_SIZE x IMG_SIZE before inference
IMG_SIZE = 224
# Transformations used for input images:
# resize -> tensor -> per-channel normalization (ImageNet mean/std constants)
img_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Predicted class labels: model output index -> human-readable label
classes = {0: 'empty', 1: 'ok' }
try:
    # Initialize API Server
    # NOTE(review): FastAPI() rarely raises here; this guard mostly catches
    # configuration-time surprises and exits instead of crashing with a trace.
    app = FastAPI()
except Exception as e:
    logging.error('Failed to start the API server: %s' % e)
    sys.exit(1)
# Function is run (only) before the application starts
# NOTE(review): @app.on_event("startup") is deprecated in newer FastAPI
# releases in favour of lifespan handlers — still works, but worth migrating.
@app.on_event("startup")
async def load_model():
    """
    Load the pretrained model on startup.

    Stores the ONNX Runtime session in ``app.package`` so request
    handlers (``predict``) can reach it without re-loading the weights.
    """
    try:
        # Load the onnx model and the trained weights
        model = onnxruntime.InferenceSession(MODEL_PATH)
        # Add model to app state
        app.package = {"model": model}
    except Exception as e:
        logging.error('Failed to load the model file: %s' % e)
        # NOTE(review): raising HTTPException from a startup hook never reaches
        # a client — there is no request in flight yet; confirm intended effect.
        raise HTTPException(status_code=500, detail='Failed to load the model file: %s' % e)
def softmax(x):
    """
    Numerically stable softmax over a 1-D array of logits.

    The original ``exp(x) / exp(x).sum()`` overflows to inf/nan for large
    logits; subtracting the maximum first is mathematically identical and
    keeps every exponent <= 0, so it cannot overflow.

    :param x: 1-D numpy array of raw model scores.
    :return: numpy array of the same shape, non-negative, summing to 1.
    """
    shifted = np.exp(x - np.max(x))
    return shifted / shifted.sum()
def predict(image):
    """
    Run the ONNX classifier on a PIL image and report the top class.

    :param image: PIL.Image to classify.
    :return: dict with keys 'prediction' (class label string) and
             'confidence' (softmax score of the winning class, 0..1).
    """
    # Fetch the session that load_model() stored on the app at startup
    session = app.package["model"]
    # Preprocess: force 3 channels, apply shared transforms, add a batch axis
    batch = img_transforms(image.convert("RGB")).unsqueeze(0)
    # ONNX Runtime consumes plain numpy arrays keyed by input name
    feed = {session.get_inputs()[0].name: batch.detach().cpu().numpy()}
    logits = session.run(None, feed)[0]
    # Index of the highest-scoring class in the first (only) batch row
    winner = np.argmax(logits, 1).item()
    # Softmax turns raw scores into a confidence value
    confidence = float(softmax(logits[0])[winner])
    return {'prediction': classes[winner], 'confidence': confidence}
# Endpoint for POST requests: input image is received with the http request
@app.post("/empty")
async def postit(file: UploadFile = File(...)):
    """
    Classify an image uploaded in the body of a POST request.

    Returns the prediction dict from ``predict``; responds 400 when the
    upload cannot be decoded as an image, 500 when inference fails.
    """
    try:
        # Read the uploaded bytes and decode them as an RGB image
        payload = await file.read()
        image = Image.open(io.BytesIO(payload)).convert('RGB')
        # draft() hints the decoder to downscale early where the format supports it
        image.draft('RGB', (IMG_SIZE, IMG_SIZE))
    except Exception as e:
        logging.error('Failed to load the input image file: %s' % e)
        raise HTTPException(status_code=400, detail='Failed to load the input image file: %s' % e)
    try:
        # Get predicted class and confidence
        return predict(image)
    except Exception as e:
        logging.error('Failed to analyze the input image file: %s' % e)
        raise HTTPException(status_code=500, detail='Failed to analyze the input image file: %s' % e)
# Endpoint for GET requests: input image path is received with the http request
@app.get("/emptypath")
async def postit_url(path: str):
    """
    Classify an image stored at a path on the server's filesystem.

    Responds 400 when the path cannot be opened as an image, 500 when
    inference fails.
    """
    # NOTE(review): this opens an arbitrary caller-supplied filesystem path —
    # a local file disclosure / probing risk; consider restricting the path
    # to an allow-listed directory before deploying publicly.
    try:
        image = Image.open(path).convert('RGB')
        # Allow the decoder to downscale early where the format supports it
        image.draft('RGB', (IMG_SIZE, IMG_SIZE))
    except Exception as e:
        logging.error('Failed to recognize file %s as an image. Error: %s' % (path, e))
        raise HTTPException(status_code=400, detail='Failed to load the input image file: %s' % e)
    try:
        # Get predicted class and confidence
        return predict(image)
    except Exception as e:
        logging.error('Failed to analyze the input image file: %s' % e)
        raise HTTPException(status_code=500, detail='Failed to analyze the input image file: %s' % e)
# Endpoint for GET requests: input image URL is received with the http request
@app.get("/emptyurl")
async def postit_from_url(url: str):
    """
    Classify an image fetched from a remote URL.

    Fixes relative to the original implementation:
    - renamed from ``postit_url``: the previous name shadowed the
      /emptypath handler defined with the same name at module level;
    - ``requests.get`` now carries a timeout so a stalled remote host
      cannot hang the worker indefinitely;
    - non-2xx responses are rejected via ``raise_for_status`` instead of
      feeding an error page to the image decoder.

    Responds 400 when the URL cannot be fetched/decoded as an image,
    500 when inference fails.
    """
    try:
        # Fetch the remote resource; bound the wait so the worker can't stall
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content)).convert('RGB')
        # Allow the decoder to downscale early where the format supports it
        image.draft('RGB', (IMG_SIZE, IMG_SIZE))
    except Exception as e:
        logging.error('Failed to recognize file %s as an image. Error: %s' % (url, e))
        raise HTTPException(status_code=400, detail='Failed to load the input image file: %s' % e)
    try:
        # Get predicted class and confidence
        predictions = predict(image)
    except Exception as e:
        logging.error('Failed to analyze the input image file: %s' % e)
        raise HTTPException(status_code=500, detail='Failed to analyze the input image file: %s' % e)
    return predictions
# Run a development server when executed directly (not via an external ASGI runner).
if __name__ == '__main__':
    uvicorn.run(app, host='0.0.0.0', port=8000)