-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathSingleObjectTracking.py
More file actions
264 lines (236 loc) · 8.7 KB
/
SingleObjectTracking.py
File metadata and controls
264 lines (236 loc) · 8.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
# -*- coding: utf-8 -*-
import numpy as np
import cv2
import os
from numba import jit
np.set_printoptions(threshold=np.inf)
# Adjustable 数据集名称 [bird、girl_in_the_sea、girl_in_the_alley、girl_in_the_garden]
DATASET_NAME = 'bird'
# Adjustable 采用的直方图类型
HIST_TYPE = 'BIN'
# Adjustable mean-shift最大迭代次数
MAX_ITER_NUM = 10
# Adjustable HSV直方图下,特征提取的通道(0-H, 1-S, 2-V)
HSV_CHANNEL = 2
# -------→ Col/x/width
# |
# |
# |
# ↓
# Row/y/height
# 加载图像数据集
def load_data_set(path):
first_window = np.loadtxt(path + './first_window.txt', delimiter=',', dtype=int)
_frames = []
file_names = os.listdir(path)
file_names.remove('first_window.txt')
file_names.sort(key=lambda x: int(x.replace(' ', '').split('.')[0]))
for i in range(len(file_names)):
filename = file_names[i]
if not (filename.endswith('jpg') or filename.endswith('png')):
continue
file_path = os.path.join(path, filename)
img = cv2.imread(file_path)
_frames.append(img)
return first_window, _frames
# 截取图像
@jit
def crop_image(_window, _image):
[col, row, width, height] = _window
return _image[row:row + height, col:col + width]
# 在指定图像上画框
@jit
def add_rectangle(_window, _image):
[col, row, width, height] = _window
return cv2.rectangle(np.copy(_image), (col, row), (col + width, row + height), 255, 2)
# Epanechnikov kernel
@jit
def e_kernel(x):
if x <= 1:
return 1 - x
else:
return 0
# 获取指定窗口各个像素点的权值(固定的)
# 对于mean_shift的优化,认为距离中心点越近的像素,对应的影响越大
@jit
def get_image_weights(width, height):
# 目标中心点
x0, y0 = int(height / 2), int(width / 2)
# 核函数窗口大小
h = x0 ** 2 + y0 ** 2
# 各个像素点的权值
weight = np.zeros((height, width), dtype=np.float32)
for i in range(height):
for j in range(width):
# 每个像素点的归一化像素位置
pos_normed = ((i - x0) ** 2 + (j - y0) ** 2) / h
weight[i, j] = e_kernel(pos_normed)
return weight
# 根据像素点位置获取直方图下标索引index
@jit
def get_hist_index(i, j, _object):
if HIST_TYPE == 'BIN':
# 注意OpenCv读取图像时,通道顺序为BGR
_R = np.fix(_object[i, j, 2] / 16)
_G = np.fix(_object[i, j, 1] / 16)
_B = np.fix(_object[i, j, 0] / 16)
# RGB像素值结合(范围0-4095)
index = int(_R * 256 + _G * 16 + _B)
elif HIST_TYPE == 'HSV':
cvt = cv2.cvtColor(np.copy(_object), cv2.COLOR_BGR2HSV)
# 当选择HSV直方图时,需要提供参数channel,代表选择的是H、S or V
index = min(cvt[i, j, HSV_CHANNEL], 255)
elif HIST_TYPE == 'GRAY':
cvt = cv2.cvtColor(np.copy(_object), cv2.COLOR_BGR2GRAY)
index = cvt[i, j]
else:
index = 0
return index
# 获取指定图像的概率直方图,直方图类型有三种:
# BIN-将RGB颜色空间量化为16x16x16=4096
# HSV-将RGB颜色空间转化内HSV,取H/S/V通道 256
# GRAY-灰度颜色空间:256
@jit
def img2prob_histogram(_object):
[m, n, _] = np.shape(_object)
# 各个像素点的权重(固定的)
_weight = get_image_weights(n, m)
if HIST_TYPE == 'BIN':
# 直方图大小
hist_size = 16 * 16 * 16
_histogram = np.zeros(hist_size, dtype=np.float32)
for i in range(m):
for j in range(n):
# 直方图下标
hist_index = get_hist_index(i, j, _object)
_histogram[hist_index] += _weight[i, j]
elif HIST_TYPE == 'GRAY':
hist_size = 256
_histogram = np.zeros(hist_size, dtype=np.float32)
for i in range(m):
for j in range(n):
# 直方图下标
hist_index = get_hist_index(i, j, _object)
_histogram[hist_index] += _weight[i, j]
elif HIST_TYPE == 'HSV':
hist_size = 180
_histogram = np.zeros(hist_size, dtype=np.float32)
for i in range(m):
for j in range(n):
# 直方图下标
hist_index = get_hist_index(i, j, _object)
_histogram[hist_index] += _weight[i, j]
else:
_histogram = None
# 直方图归一化
return _histogram / np.sum(_weight)
# 计算直方图各个bin的权重wi
@jit
def get_hist_bin_weights(h1, h2):
hist_size = len(h1)
sim = np.zeros(hist_size, np.float32)
for i in range(hist_size):
if h2[i] != 0:
sim[i] = np.sqrt(h1[i] / h2[i])
return sim
# 计算两个直方图的相似度(巴氏系数B)
# 直方图之间的距离D=sqrt(1-B)
@jit
def get_hist_similarity(h1, h2):
hist_size = len(h1)
dist = 0
for i in range(hist_size):
dist += np.sqrt(h1[i] * h2[i])
return dist
# 运行mean-shift,根据目标在上一帧的位置以及目标直方图,预测在当前帧的位置
@jit
def predict_window(prior_window, target_hist, current_frame):
# 当前帧目标位置
current_window = np.copy(prior_window)
[_, _, width, height] = current_window
# 标记迭代次数
iter_num = 0
x_shift_old, y_shift_old = 0, 0
while True:
# 当前目标区域
img_object = crop_image(current_window, current_frame)
# 计算基于前一帧目标窗口的概率直方图
current_hist = img2prob_histogram(img_object)
# 直方图各个bin的权重
bin_weights = get_hist_bin_weights(target_hist, current_hist)
bin_weights /= np.sum(bin_weights)
sum_weight = 0
# mean-shift偏移向量
x_shift, y_shift = 0, 0
x0 = int(height / 2)
y0 = int(width / 2)
# 计算基于巴氏距离最小化/巴氏系数最大化的mean-shift偏移向量
for i in range(height):
for j in range(width):
hist_index = get_hist_index(i, j, img_object)
sum_weight += bin_weights[hist_index]
x_shift += bin_weights[hist_index] * (j - y0)
y_shift += bin_weights[hist_index] * (i - x0)
x_shift /= sum_weight
y_shift /= sum_weight
# 防止出界,设置起始点的最小值为0
current_window[0] = max(current_window[0] + x_shift, 0)
current_window[1] = max(current_window[1] + y_shift, 0)
if iter_num >= MAX_ITER_NUM:
# 迭代次数达到上限,说明目标极有可能丢失,此时目标窗口不变,等待目标重新出现
# 这里没有用到巴氏距离,原因是为了防止迭代次数过多
# current_window = prior_window
break
if abs(x_shift - x_shift_old) < 1e-6 and abs(y_shift - y_shift_old) < 1e-6:
break
x_shift_old, y_shift_old = x_shift, y_shift
iter_num += 1
# print('window shift vector: ', [x_shift, y_shift])
# print('iterations: ', iter_num)
return current_window
# 执行目标跟踪
@jit
def run_object_detection(first_window, _frames):
# 图像输出目录
output_dir = './output/{0}/'.format(DATASET_NAME)
# 目标图像概率直方图
target_hist = img2prob_histogram(crop_image(first_window, _frames[0]))
# 直接输出第一张图像的检测结果
frame_detect = add_rectangle(first_window, _frames[0])
os.makedirs(output_dir, exist_ok=True)
cv2.imwrite(output_dir + '%04d.jpg' % 0, frame_detect)
cv2.imshow('', frame_detect)
cv2.waitKey(1)
prior_window = first_window
for i in range(1, len(_frames)):
print(i)
# 预测目标在当前帧的位置
prior_window = predict_window(prior_window, target_hist, _frames[i])
print(prior_window)
frame_detect = add_rectangle(prior_window, _frames[i])
cv2.imwrite(output_dir + '%04d.jpg' % i, frame_detect)
cv2.imshow('', frame_detect)
cv2.waitKey(1)
# 显示检测结果
def show_detect_result():
output_dir = './output/{0}/'.format(DATASET_NAME)
file_names = os.listdir(output_dir)
i = 0
ids = np.arange(0, len(file_names))
for filename in file_names:
if i in ids:
file_path = os.path.join(output_dir, filename)
img = cv2.imread(file_path)
cv2.imshow(DATASET_NAME, img)
cv2.waitKey(20)
i += 1
if __name__ == '__main__':
# 加载视频帧
object_first_window, frames = load_data_set('./dataset/' + DATASET_NAME)
# 手动设定ROI(首帧目标区域)
# object_first_window = cv2.selectROI('', frames[0])
print(object_first_window)
# 执行目标检测
run_object_detection(object_first_window, frames)
# 输出跟踪结果
show_detect_result()