-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
143 lines (114 loc) · 5.49 KB
/
app.py
File metadata and controls
143 lines (114 loc) · 5.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import streamlit as st
from PIL import Image
import cv2
import numpy as np
import base64
from ultralytics import YOLO
import os
from pathlib import Path
from sam2.sam2_video_predictor import SAM2VideoPredictor
import supervision as sv
# Set up directories
# Working directories live under the process's current directory.
HOME = os.getcwd()
# Frames extracted from the uploaded video are written here.
SOURCE_FRAMES_DIR = Path(HOME) / "source_frames"
SOURCE_FRAMES_DIR.mkdir(parents=True, exist_ok=True)
# Per-frame annotated images are written here.
ANNOTATED_FRAMES_DIR = Path(HOME) / "annotated_frames"
ANNOTATED_FRAMES_DIR.mkdir(parents=True, exist_ok=True)


# Load SAM2 model
@st.cache_resource
def _load_sam2_model():
    """Load the SAM2 video predictor once and reuse it across reruns.

    Streamlit re-executes this script from the top on every widget
    interaction; without caching, the checkpoint would be reloaded each
    time. ``st.cache_resource`` keeps a single shared instance alive.
    """
    return SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")


sam2_model = _load_sam2_model()
# Helper functions
@st.cache_resource
def _load_yolo_model(weights="yolov8n.pt"):
    """Load the YOLO detector once and reuse it across Streamlit reruns."""
    return YOLO(weights)


def detect_objects(image_path):
    """Detect objects in an image and return one unique label per detection.

    Labels have the form ``"<class_name>_<k>"`` where ``k`` counts, starting
    at 1, how many detections of that class have been seen so far (for
    example ``["person_1", "person_2", "dog_1"]``).

    Args:
        image_path: Path of the image handed to the YOLO model.

    Returns:
        list[str]: Labels in the order the detections are reported.
    """
    # The model is cached so it is not re-created on every call; this
    # function runs on every script rerun once both files are uploaded.
    model = _load_yolo_model()
    results = model(image_path)

    detected_objects = []
    object_count = {}
    for result in results:
        for obj in result.boxes.data:
            # Index 5 of each detection row is read as the class id.
            class_name = model.names[int(obj[5].item())]
            object_count[class_name] = object_count.get(class_name, 0) + 1
            detected_objects.append(f"{class_name}_{object_count[class_name]}")
    return detected_objects
def encode_image(filepath, mime_type="image/jpg"):
    """Return the contents of *filepath* as a base64 ``data:`` URI.

    Args:
        filepath: Path of the file to read (opened in binary mode).
        mime_type: Media type placed in the URI header. The default,
            ``"image/jpg"``, is kept for backward compatibility with
            existing callers; pass ``"image/jpeg"`` (the registered JPEG
            media type) or ``"image/png"`` when the exact type matters.

    Returns:
        str: ``"data:<mime_type>;base64,<payload>"``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    with open(filepath, "rb") as f:
        image_bytes = f.read()
    encoded = base64.b64encode(image_bytes).decode("utf-8")
    return f"data:{mime_type};base64,{encoded}"
# Streamlit App
# Page flow: upload a video and a reference image, detect objects in the
# image, let the user pick objects and a time window, then run SAM2 over the
# extracted frames and write an annotated video plus per-frame images.
st.title("Video Object Summarization")
uploaded_video = st.file_uploader("Upload a Video", type=["mp4", "mov", "avi"])
uploaded_image = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])

if uploaded_video and uploaded_image:
    # Save uploaded files: the detector and the video utilities below are
    # given file paths, not in-memory buffers.
    video_path = os.path.join(HOME, "uploaded_video.mp4")
    with open(video_path, "wb") as f:
        f.write(uploaded_video.getbuffer())
    image_path = os.path.join(HOME, "uploaded_image.jpg")
    with open(image_path, "wb") as f:
        f.write(uploaded_image.getbuffer())

    # Detect objects in the image
    st.write("Detecting objects in the uploaded image...")
    detected_objects = detect_objects(image_path)
    st.write("Detected objects:", detected_objects)

    # User selects objects to summarize
    selected_objects = st.multiselect("Select Objects to Summarize", detected_objects)
    start_time = st.number_input("Start Time (seconds)", min_value=0, step=1)
    end_time = st.number_input("End Time (seconds)", min_value=1, step=1)

    if st.button("Process Video"):
        # Validate the request up front instead of failing deep inside the
        # frame extraction / propagation steps.
        if not selected_objects:
            st.error("Select at least one object before processing.")
            st.stop()
        if end_time <= start_time:
            st.error("End time must be greater than start time.")
            st.stop()

        st.write("Processing video...")

        # Video summarization process
        video_info = sv.VideoInfo.from_video_path(video_path)
        SCALE_FACTOR = 0.5
        frames_generator = sv.get_video_frames_generator(
            video_path,
            start=int(start_time * video_info.fps),
            end=int(end_time * video_info.fps),
        )
        images_sink = sv.ImageSink(
            target_dir_path=SOURCE_FRAMES_DIR.as_posix(),
            overwrite=True,
            image_name_pattern="{:05d}.jpeg",
        )
        with images_sink:
            for frame in frames_generator:
                frame = sv.scale_image(frame, SCALE_FACTOR)
                images_sink.save_image(frame)

        SOURCE_FRAME_PATHS = sorted(
            sv.list_files_with_extensions(SOURCE_FRAMES_DIR.as_posix(), extensions=["jpeg"])
        )
        if not SOURCE_FRAME_PATHS:
            st.error("No frames were extracted; check the selected time range.")
            st.stop()

        # The saved frames are scaled, so the output video must be opened
        # with the scaled size rather than the source video's size. Take
        # the size from an actual saved frame to avoid rounding mismatches.
        first_frame = cv2.imread(str(SOURCE_FRAME_PATHS[0]))
        video_info.height, video_info.width = first_frame.shape[:2]

        inference_state = sam2_model.init_state(video_path=SOURCE_FRAMES_DIR.as_posix())
        sam2_model.reset_state(inference_state)

        FRAME_IDX = 0
        # NOTE(review): every selected object is prompted with the same
        # hard-coded point (100, 100) on the first frame; the boxes found by
        # the detector are not used. Confirm whether this is a placeholder.
        widget_boxes = [
            {'x': 100, 'y': 100, 'width': 0, 'height': 0, 'label': obj}
            for obj in selected_objects
        ]
        for object_id, label in enumerate(selected_objects, start=1):
            points = np.array(
                [[box['x'], box['y']] for box in widget_boxes if box['label'] == label],
                dtype=np.float32,
            )
            labels = np.ones(len(points))
            _, object_ids, mask_logits = sam2_model.add_new_points(
                inference_state=inference_state,
                frame_idx=FRAME_IDX,
                obj_id=object_id,
                points=points,
                labels=labels,
            )

        TARGET_VIDEO = Path(HOME) / "final_annotated_video.mp4"
        mask_annotator = sv.MaskAnnotator()  # built once, reused for every frame
        with sv.VideoSink(TARGET_VIDEO.as_posix(), video_info=video_info) as sink:
            for frame_idx, object_ids, mask_logits in sam2_model.propagate_in_video(inference_state):
                frame = cv2.imread(str(SOURCE_FRAME_PATHS[frame_idx]))

                masks = (mask_logits > 0.0).cpu().numpy()
                # Flatten to (num_masks, H, W) while keeping the leading
                # axis even when only one object is tracked; a plain
                # squeeze would drop it in that case.
                # Assumes the last two axes are height and width - TODO confirm.
                masks = masks.reshape(-1, *masks.shape[-2:]).astype(bool)

                # Default to the untouched frame so every frame is written
                # even when no mask has any positive pixels.
                annotated_frame = frame
                if np.any(masks):
                    detections = sv.Detections(
                        xyxy=sv.mask_to_xyxy(masks=masks),
                        mask=masks,
                        class_id=np.array(object_ids),
                    )
                    annotated_frame = mask_annotator.annotate(
                        scene=frame.copy(), detections=detections
                    )

                sink.write_frame(annotated_frame)
                annotated_frame_path = ANNOTATED_FRAMES_DIR / f"{frame_idx:05d}.jpeg"
                cv2.imwrite(str(annotated_frame_path), annotated_frame)

        st.video(str(TARGET_VIDEO))
        st.write("Video processing complete. You can download the summarized video.")

        # Provide download link for the video
        with open(TARGET_VIDEO, "rb") as video_file:
            video_bytes = video_file.read()
        b64_video = base64.b64encode(video_bytes).decode('utf-8')
        download_link = f'<a href="data:video/mp4;base64,{b64_video}" download="summarized_video.mp4">Download Summarized Video</a>'
        st.markdown(download_link, unsafe_allow_html=True)