-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference.py
More file actions
90 lines (78 loc) · 3.36 KB
/
inference.py
File metadata and controls
90 lines (78 loc) · 3.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# -*- coding:utf-8 -*-
import cv2
import time
import argparse
import numpy as np
from PIL import Image
from utils.anchor_generator import generate_anchors
from utils.anchor_decode import decode_bbox
from utils.nms import single_class_non_max_suppression
from load_model.tensorflow_loader import load_tf_model, tf_inference
sess, graph = load_tf_model('models/face_mask_detection.pb')
# anchor configuration
feature_map_sizes = [[33, 33], [17, 17], [9, 9], [5, 5], [3, 3]]
anchor_sizes = [[0.04, 0.056], [0.08, 0.11], [0.16, 0.22], [0.32, 0.45], [0.64, 0.72]]
anchor_ratios = [[1, 0.62, 0.42]] * 5
# generate anchors
anchors = generate_anchors(feature_map_sizes, anchor_sizes, anchor_ratios)
# for inference , the batch size is 1, the model output shape is [1, N, 4],
# so we expand dim for anchors to [1, anchor_num, 4]
anchors_exp = np.expand_dims(anchors, axis=0)
id2class = {0: 'Mask', 1: 'No Face Mask'}
def inference(image,
conf_thresh=0.5,
iou_thresh=0.4,
target_shape=(160, 160),
draw_result=True,
show_result=True
):
'''
Main function of detection inference
:param image: 3D numpy array of image
:param conf_thresh: the min threshold of classification probabity.
:param iou_thresh: the IOU threshold of NMS
:param target_shape: the model input size.
:param draw_result: whether to daw bounding box to the image.
:param show_result: whether to display the image.
:return:
'''
# image = np.copy(image)
output_info = []
height, width, _ = image.shape
image_resized = cv2.resize(image, target_shape)
image_np = image_resized / 255.0
image_exp = np.expand_dims(image_np, axis=0)
y_bboxes_output, y_cls_output = tf_inference(sess, graph, image_exp)
# remove the batch dimension, for batch is always 1 for inference.
y_bboxes = decode_bbox(anchors_exp, y_bboxes_output)[0]
y_cls = y_cls_output[0]
# To speed up, do single class NMS, not multiple classes NMS.
bbox_max_scores = np.max(y_cls, axis=1)
bbox_max_score_classes = np.argmax(y_cls, axis=1)
# keep_idx is the alive bounding box after nms.
keep_idxs = single_class_non_max_suppression(y_bboxes,
bbox_max_scores,
conf_thresh=conf_thresh,
iou_thresh=iou_thresh,
)
for idx in keep_idxs:
conf = float(bbox_max_scores[idx])
class_id = bbox_max_score_classes[idx]
bbox = y_bboxes[idx]
# clip the coordinate, avoid the value exceed the image boundary.
xmin = max(0, int(bbox[0] * width))
ymin = max(0, int(bbox[1] * height))
xmax = min(int(bbox[2] * width), width)
ymax = min(int(bbox[3] * height), height)
if draw_result:
if class_id == 0:
color = (0, 255, 0)
else:
color = (255, 0, 0)
cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, 2)
cv2.putText(image, "%s: %.2f" % (id2class[class_id], conf), (xmin + 1, ymin - 1),
cv2.FONT_HERSHEY_SIMPLEX, 0.4, color)
output_info.append([class_id, conf, xmin, ymin, xmax, ymax])
if show_result:
Image.fromarray(image).show()
return output_info