diff --git a/co-tracker/README.md b/co-tracker/README.md
new file mode 100644
index 0000000..f35a6f7
--- /dev/null
+++ b/co-tracker/README.md
@@ -0,0 +1,9 @@
+# Co-Tracker
+
+This app takes a video, samples a grid of points on the first frame, and propagates those points
+through the rest of the video, tracking each point from frame to frame.
+
+If the "visualize" option is checked, the tracks are drawn onto the video and that visualization is returned.
+Otherwise, the raw arrays of the tracks and their visibility masks are returned instead.
+
+The "grid size" option sets how many points are sampled along each side of the tracking grid (a grid size of 30 gives a 30x30 grid of points), and the "pad size" option sets how many pixels of padding to add around the visualization so that tracks moving off screen stay visible.
diff --git a/co-tracker/main.py b/co-tracker/main.py
new file mode 100644
index 0000000..0e05ea5
--- /dev/null
+++ b/co-tracker/main.py
@@ -0,0 +1,91 @@
+import sieve
+
+metadata = sieve.Metadata(
+    title="Video Co-Tracker",
+    description="Track any point in a video",
+    code_url="https://github.com/sieve-community/examples/tree/main/co-tracker",
+    image=sieve.Image(
+        url="https://github.com/facebookresearch/co-tracker/raw/main/assets/bmx-bumps.gif"
+    ),
+    tags=["Video", "Tracking"],
+    readme=open("README.md", "r").read(),
+)
+
+@sieve.Model(
+    name="co-tracker",
+    gpu=True,
+    python_packages=[
+        "git+https://github.com/facebookresearch/co-tracker",
+        "torchvision==0.15.2",
+        "torch==2.0.1",
+        "einops==0.4.1",
+        "timm==0.6.7",
+        "tqdm==4.64.1",
+        "flow_vis",
+        "matplotlib==3.7.0",
+        "moviepy==1.0.3",
+    ],
+    python_version="3.11",
+    cuda_version="11.8",
+    run_commands=[
+        "mkdir -p /root/.cache/torch/hub/facebookresearch_co-tracker_master",
+        "wget -q https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_8.pth -O /root/.cache/torch/hub/facebookresearch_co-tracker_master/cotracker_stride_4_wind_8.pth",
+    ],
+    metadata=metadata
+)
+class CoTracker:
+    def __setup__(self):
+        import torch
+        import tempfile
+        from cotracker.predictor import CoTrackerPredictor
+
+        # Load the pretrained checkpoint downloaded in run_commands
+        self.model = CoTrackerPredictor(
+            checkpoint='/root/.cache/torch/hub/facebookresearch_co-tracker_master/cotracker_stride_4_wind_8.pth'
+        )
+        self.use_cuda = torch.cuda.is_available()
+        if self.use_cuda:
+            self.model = self.model.cuda()
+        self.videos_dir = tempfile.mkdtemp()
+
+    def __predict__(self, video: sieve.Video, grid_size: int = 30, pad_size: int = 100, visualize: bool = False):
+        from cotracker.utils.visualizer import Visualizer, read_video_from_path
+        import torch
+        import os
+
+        # Load the video as a (frames, channels, height, width) float tensor
+        loaded_video = read_video_from_path(video.path)
+        loaded_video = torch.from_numpy(loaded_video).permute(0, 3, 1, 2).float()
+
+        if self.use_cuda:
+            loaded_video = loaded_video.cuda()
+
+        # Unsqueeze to add a batch dimension
+        loaded_video = loaded_video.unsqueeze(0)
+
+        # Track a regular grid_size x grid_size grid of points across the video
+        pred_tracks, pred_visibility = self.model(loaded_video, grid_size=grid_size)
+
+        # Clear outputs left over from previous calls
+        for file in os.listdir(self.videos_dir):
+            os.remove(os.path.join(self.videos_dir, file))
+
+        if visualize:
+            vis = Visualizer(
+                save_dir=self.videos_dir,
+                pad_value=pad_size
+            )
+
+            vis.visualize(
+                video=loaded_video,
+                tracks=pred_tracks,
+                visibility=pred_visibility,
+                filename="output",
+            )
+
+            return sieve.Video(path=os.path.join(self.videos_dir, "output_pred_track.mp4"))
+
+        sieve_array_pred_tracks = sieve.Array(array=pred_tracks.cpu().numpy())
+        sieve_array_pred_visibility = sieve.Array(array=pred_visibility.cpu().numpy())
+
+        return sieve_array_pred_tracks, sieve_array_pred_visibility
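
For callers that leave `visualize` unchecked, here is a minimal sketch of how the returned track and visibility arrays might be consumed. The shapes are assumed from the upstream CoTracker predictor (tracks as `(batch, frames, points, 2)` in x/y pixel coordinates, visibility as `(batch, frames, points)`), and the dummy arrays below stand in for whatever the returned `sieve.Array` objects unwrap to on the client side:

```python
import numpy as np

# Placeholder arrays standing in for the raw outputs when visualize=False.
# Shapes assumed from the upstream CoTracker predictor:
#   tracks:     (batch, frames, points, 2) with (x, y) pixel coordinates
#   visibility: (batch, frames, points)
num_frames, num_points = 48, 30 * 30                      # e.g. a 30x30 grid over 48 frames
tracks = np.random.rand(1, num_frames, num_points, 2) * 640
visibility = np.random.rand(1, num_frames, num_points) > 0.1

# Pull out the trajectory of a single tracked point and count the frames
# in which the model considers it visible.
point_index = 0
trajectory = tracks[0, :, point_index]                    # (frames, 2): x/y position per frame
visible_frames = np.flatnonzero(visibility[0, :, point_index])

print(f"point {point_index} visible in {len(visible_frames)} of {num_frames} frames")
```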