Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions gradio_demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import List, Optional, Sequence, Tuple

import numpy as np
import csv


# Generate random colormaps for visualizing different points.
Expand Down Expand Up @@ -400,7 +401,18 @@ def track(

mediapy.write_video(video_file_path, painted_video, fps=video_fps)

return video_file_path
tracks_file_name = uuid.uuid4().hex + "_tracks.csv"
tracks_file_path = os.path.join(video_path, tracks_file_name)

# Save `tracks` as CSV
with open(tracks_file_path, mode='w', newline='') as csv_file:
csv_writer = csv.writer(csv_file)
csv_writer.writerow(["Point_Index", "Frame", "X", "Y"]) # Header
for frame_index, frame_tracks in enumerate(tracks):
for point_index, (x, y) in enumerate(frame_tracks):
csv_writer.writerow([frame_index, point_index, x, y])

return (video_file_path, tracks_file_path)


with gr.Blocks() as demo:
Expand Down Expand Up @@ -593,6 +605,8 @@ def track(
queue = False
)

download_csv = gr.File(label="Download Tracked Points (CSV)")


track_button.click(
fn = track,
Expand All @@ -605,10 +619,10 @@ def track(
query_count,
],
outputs = [
output_video,
output_video,download_csv
],
queue = True,
)


demo.launch(show_api=False, show_error=True, debug=True, share=True)
demo.launch(show_api=False, show_error=True, debug=True, share=True)