Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
Binary file not shown.
4 changes: 3 additions & 1 deletion django_project/akangatu/core/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from .preprocessor import Preprocessor

preprocessor = Preprocessor()

class Predictor:
_model = None

Expand All @@ -15,7 +17,7 @@ def predict(self, pixel_array):
if not self._model:
self._load_model()

pixels = Preprocessor.reshape_data(pixel_array)
pixels = preprocessor.image_to_mnist(pixel_array)

probs = self._model.predict(pixels)
return probs.argmax(axis=1)
39 changes: 36 additions & 3 deletions django_project/akangatu/core/preprocessor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,39 @@
import numpy as np


MNIST_SIZE = 28
COLOR_SCALE = 255

class Preprocessor:
def reshape_data(request):
pixel_array = np.array(request)
return pixel_array.reshape(-1, 1, 28, 28) / 255.0
def image_to_mnist(self, image):
pixel_array = self._aggregate_to_mnist_array(image)

return self._reshape_data(pixel_array) / COLOR_SCALE

def _reshape_data(self, pixel_array):
return pixel_array.reshape(1, 1, MNIST_SIZE, MNIST_SIZE)

def _get_group_sizes(self, image):
image_width = len(image[0])
image_height = len(image)

return image_height // MNIST_SIZE, image_width // MNIST_SIZE

def _aggregate_to_mnist_array(self, image):
mnist_array = np.zeros([MNIST_SIZE, MNIST_SIZE])

for i in range(MNIST_SIZE):
for j in range(MNIST_SIZE):
mnist_array[i][j] = self._calculate_mnist_pixel(image,i,j)

return mnist_array

def _calculate_mnist_pixel(self, image,line,column):
group_height, group_width = self._get_group_sizes(image)

rgb_sum = 0
for i in range(group_height):
for j in range(group_width):
rgb_sum += sum(image[line*group_height+i][column*group_width+j])//len(image[line][column])

return rgb_sum / (group_width * group_height)
7 changes: 4 additions & 3 deletions django_project/akangatu/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from django_quicky import routing
from django.views.decorators.csrf import csrf_exempt
import json
import matplotlib.image as img

from .models import Predictor

Expand All @@ -25,11 +26,11 @@ def predict(request):
if request.method.upper() != 'POST':
return HttpResponseNotAllowed(['POST']) # List of allowed ones

json_data = request.POST.dict() or json.loads(request.body.decode('utf-8'))
image_file = next(iter(request.FILES.values()))

print(json_data)
image = img.imread(image_file)

pred = predictor_acessor.predict(json_data['pixels'])
pred = predictor_acessor.predict(image)

answer = {'label': int(pred[0]) }

Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ Keras==2.2.4
Keras-Applications==1.0.6
Keras-Preprocessing==1.0.5
Markdown==3.0.1
matplotlib==3.0.2
numpy==1.15.4
Pillow==5.3.0
protobuf==3.6.1
pytz==2018.7
PyYAML==3.13
Expand Down