From c9059a7bb538209c8e17231a85fcc288502ffd86 Mon Sep 17 00:00:00 2001 From: YUKINA-3252 Date: Sun, 14 Jan 2024 20:26:07 +0900 Subject: [PATCH] add the description of segmentation mask for online mode --- cotracker/predictor.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/cotracker/predictor.py b/cotracker/predictor.py index 575095bf..b53c7815 100644 --- a/cotracker/predictor.py +++ b/cotracker/predictor.py @@ -197,6 +197,7 @@ def forward( video_chunk, is_first_step: bool = False, queries: torch.Tensor = None, + segm_mask: torch.Tensor = None, grid_size: int = 10, grid_query_frame: int = 0, add_support_grid=False, @@ -220,6 +221,14 @@ def forward( grid_pts = get_points_on_a_grid( grid_size, self.interp_shape, device=video_chunk.device ) + if segm_mask is not None: + segm_mask = F.interpolate(segm_mask, tuple(self.interp_shape), mode="nearest") + point_mask = segm_mask[0, 0][ + (grid_pts[0, :, 1]).round().long().cpu(), + (grid_pts[0, :, 0]).round().long().cpu(), + ].bool() + grid_pts = grid_pts[:, point_mask] + queries = torch.cat( [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts], dim=2,