From 3ecbe19e43c0c6e7cde53905f452627ab6bc24e0 Mon Sep 17 00:00:00 2001
From: Guangda Ji
Date: Tue, 5 Nov 2024 12:13:51 +0100
Subject: [PATCH] add fix for cotracker v2

---
 cotracker/models/build_cotracker.py | 12 ++++-----
 cotracker/predictor.py              | 40 +++++++++++++++++++++--------
 2 files changed, 35 insertions(+), 17 deletions(-)

diff --git a/cotracker/models/build_cotracker.py b/cotracker/models/build_cotracker.py
index b816efa0..8710a836 100644
--- a/cotracker/models/build_cotracker.py
+++ b/cotracker/models/build_cotracker.py
@@ -24,13 +24,13 @@ def build_cotracker(
 def build_cotracker(checkpoint=None, offline=True, window_len=16, v2=False):
-    if offline:
-        cotracker = CoTrackerThreeOffline(
-            stride=4, corr_radius=3, window_len=window_len
-        )
+    if v2:
+        cotracker = CoTracker2(stride=4, window_len=window_len)
     else:
-        if v2:
-            cotracker = CoTracker2(stride=4, window_len=window_len)
+        if offline:
+            cotracker = CoTrackerThreeOffline(
+                stride=4, corr_radius=3, window_len=window_len
+            )
         else:
             cotracker = CoTrackerThreeOnline(
                 stride=4, corr_radius=3, window_len=window_len
diff --git a/cotracker/predictor.py b/cotracker/predictor.py
index 53e2303e..9eceb7b1 100644
--- a/cotracker/predictor.py
+++ b/cotracker/predictor.py
@@ -21,6 +21,7 @@ def __init__(
     ):
         super().__init__()
         self.support_grid_size = 6
+        self.v2 = v2
         model = build_cotracker(
             checkpoint,
             v2=v2,
@@ -153,9 +154,14 @@ def _compute_sparse_tracks(
             grid_pts = grid_pts.repeat(B, 1, 1)
             queries = torch.cat([queries, grid_pts], dim=1)
 
-        tracks, visibilities, __, __ = self.model.forward(
-            video=video, queries=queries, iters=6
-        )
+        if self.v2:
+            tracks, visibilities, __ = self.model.forward(
+                video=video, queries=queries, iters=6
+            )
+        else:
+            tracks, visibilities, __, __ = self.model.forward(
+                video=video, queries=queries, iters=6
+            )
 
         if backward_tracking:
             tracks, visibilities = self._compute_backward_tracks(
@@ -193,9 +199,14 @@ def _compute_backward_tracks(self, video, queries, tracks, visibilities):
         inv_queries = queries.clone()
         inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
 
-        inv_tracks, inv_visibilities, _, _ = self.model(
-            video=inv_video, queries=inv_queries, iters=6
-        )
+        if self.v2:
+            inv_tracks, inv_visibilities, _ = self.model(
+                video=inv_video, queries=inv_queries, iters=6
+            )
+        else:
+            inv_tracks, inv_visibilities, _, _ = self.model(
+                video=inv_video, queries=inv_queries, iters=6
+            )
 
         inv_tracks = inv_tracks.flip(1)
         inv_visibilities = inv_visibilities.flip(1)
@@ -279,15 +290,22 @@ def forward(
                 B, T, 3, self.interp_shape[0], self.interp_shape[1]
             )
 
-        tracks, visibilities, confidence, __ = self.model(
-            video=video_chunk, queries=self.queries, iters=6, is_online=True
-        )
+        if self.v2:
+            tracks, visibilities, __ = self.model(
+                video=video_chunk, queries=self.queries, iters=6, is_online=True
+            )
+        else:
+            tracks, visibilities, confidence, __ = self.model(
+                video=video_chunk, queries=self.queries, iters=6, is_online=True
+            )
         if add_support_grid:
             tracks = tracks[:,:,:self.N]
             visibilities = visibilities[:,:,:self.N]
-            confidence = confidence[:,:,:self.N]
+            if not self.v2:
+                confidence = confidence[:,:,:self.N]
 
-        visibilities = visibilities * confidence
+        if not self.v2:
+            visibilities = visibilities * confidence
         thr = 0.6
         return (
             tracks
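
A minimal usage sketch of the patched v2 path (not part of the diff). The
checkpoint path, video shape, and grid_size below are illustrative
assumptions; only the v2 flag and the three-vs-four output unpacking come
from the patch above.

    import torch

    from cotracker.predictor import CoTrackerPredictor

    # With this patch applied, v2=True makes build_cotracker instantiate
    # CoTracker2 regardless of the offline flag, and the predictor unpacks
    # three model outputs (tracks, visibilities, __) instead of four.
    model = CoTrackerPredictor(
        checkpoint="./checkpoints/cotracker2.pth",  # assumed local v2 checkpoint
        v2=True,
    )

    video = torch.zeros(1, 32, 3, 384, 512)  # dummy (B, T, C, H, W) video tensor
    tracks, visibilities = model(video, grid_size=10)  # grid_size chosen for illustration
    print(tracks.shape, visibilities.shape)  # (B, T, N, 2) and (B, T, N)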