-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcheck_model_predictions.py
More file actions
88 lines (68 loc) · 2.76 KB
/
check_model_predictions.py
File metadata and controls
88 lines (68 loc) · 2.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import cv2
from tracker import Sort
from helpers.location_predictor import LocationPredictor
from helpers.dataset_helpers import create_image_processor
from helpers.evaluation_functions import plot_results
from helpers.model_helpers import load_model_from_checkpoint
# --- Configuration -----------------------------------------------------------
# Folder scanned for *.mp4 videos; one <video stem>.csv per video is written here.
train_folder = 'training/leaderboard_data/'
use_dummy = False
# NOTE(review): validate_result is not read anywhere in this script.
validate_result = False
# When True, the resized frames shown in the 'Tracking' window are also saved as .mp4.
save_video = False

if not use_dummy:
    model = load_model_from_checkpoint(True)
    image_processor = create_image_processor()
    # NOTE(review): third argument 0.0 presumably is a detection score threshold — TODO confirm.
    location_predictor = LocationPredictor(model, image_processor, 0.0)
# NOTE(review): with use_dummy = True, `model` and `location_predictor` are never
# defined, yet the per-frame loop calls both, so that setting currently fails
# with a NameError once a video is processed.

# os.listdir returns entries in arbitrary order; sort so runs are reproducible.
videos = sorted(f for f in os.listdir(train_folder) if f.endswith('.mp4'))
# To debug a single clip, override the list here, e.g. videos = ['training072.mp4'].

total_frames = 0     # frames read across all videos
correct_frames = 0   # frames judged correct; not incremented anywhere in this script
wait_for_key = True  # poll the window with cv2.waitKey each frame until 'q' is pressed
# Run the location predictor over every frame of every video, show the frames,
# optionally save them, and create one CSV (header only for now) per video.
for video_name in videos:
    # splitext keeps names like "clip.v2.mp4" intact (split('.')[0] would give "clip").
    video_stem = os.path.splitext(video_name)[0]
    csv_filename = os.path.join(train_folder, video_stem + '.csv')
    # Mode 'w' truncates any CSV left over from a previous run, so no separate
    # os.remove() is needed.
    with open(csv_filename, 'w') as f:
        f.write(",t,hexbug,x,y\n")
        idx_counter = 0  # intended running row index for the CSV
        video_path = os.path.join(train_folder, video_name)
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"Could not open video: {video_path}")
        out = None  # VideoWriter, opened lazily once the output frame size is known
        # NOTE(review): the tracker is initialised on frame 0 but never updated
        # afterwards in this loop.
        tracker = Sort(max_age=200, min_hits=1, threshold=9000, mode="centroid")
        frame_number = 0
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                orig_size = frame.shape[:2]
                locations = location_predictor.predict(frame, orig_size)
                boxes = locations['boxes']
                labels = locations['labels']
                scores = locations['scores']
                print("---")
                print(locations)
                if frame_number == 0:
                    tracker.initialize_trackers(boxes)
                plot_results(model, frame, scores, labels, boxes)
                # TODO: save data in format ,t,hexbug,x,y in the CSV file —
                # currently only the header is written.
                idx_counter += 1
                resized = cv2.resize(frame, None, fx=0.9, fy=0.9, interpolation=cv2.INTER_AREA)
                cv2.imshow('Tracking', resized)
                if save_video:
                    if out is None:
                        # VideoWriter requires every frame to match the size it was
                        # opened with, so open it with the *resized* frame size.
                        height, width = resized.shape[:2]
                        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                        fps = cap.get(cv2.CAP_PROP_FPS)
                        # One output file per video so earlier videos are not overwritten.
                        out = cv2.VideoWriter(f"output_video_{video_stem}.mp4", fourcc, fps, (width, height))
                    out.write(resized)
                if wait_for_key:
                    # waitKey also pumps the HighGUI event loop that refreshes imshow.
                    key = cv2.waitKey(1)
                    if key == ord('q'):
                        wait_for_key = False
                frame_number += 1
                total_frames += 1
        finally:
            # Always free the capture and finalize the output video file.
            cap.release()
            if out is not None:
                out.release()
    print(f"----------------------------Finished video: {video_name} with {frame_number} frames")
cv2.destroyAllWindows()
# Report aggregate statistics over all processed videos.
# NOTE(review): correct_frames starts at 0 and is not incremented anywhere in this
# script, so the success rate reads 0% until per-frame validation is added.
if total_frames:
    print(f"Total frames: {total_frames}, Correct frames: {correct_frames}, Success Rate: {correct_frames / total_frames:.2%}")
else:
    # No frames were read (no .mp4 files, or none could be decoded): avoid ZeroDivisionError.
    print(f"Total frames: {total_frames}, Correct frames: {correct_frames}, Success Rate: n/a")