From 50b4419dce248f6c5a2f118a27d8de3170dac178 Mon Sep 17 00:00:00 2001 From: tnarek Date: Mon, 20 May 2024 16:45:44 +0300 Subject: [PATCH 1/2] add efficient anchor trajectory computation --- .gitignore | 9 +++++++- data/data_utils.py | 12 ++++++---- models/model_inference.py | 32 +++++++++++++++++++++----- models/tracker.py | 23 ++++++++++-------- preprocessing/save_dino_embed_video.py | 2 +- 5 files changed, 57 insertions(+), 21 deletions(-) diff --git a/.gitignore b/.gitignore index cc79092..db4c0fc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,11 @@ *__pycache__* .vscode .DS_Store - +bash +logs +.vscode +dataset/tapvid-davis +dataset/tapvid-davis-sampled-4 +dataset/tapvid-davis-sampled-2 +dataset/tapvid-davis-sampled-test +dataset/tapvid-davis-sampled-2-test \ No newline at end of file diff --git a/data/data_utils.py b/data/data_utils.py index 1f69c6c..65854d9 100644 --- a/data/data_utils.py +++ b/data/data_utils.py @@ -76,7 +76,7 @@ def bilinear_sampler(img, coords, mode='bilinear', mask=False): return img -def load_video(video_folder: str, resize=None, num_frames=None): +def load_video(video_folder: str, resize=None, num_frames=None, to_tensor=True): """ Loads video from folder, resizes frames as desired, and outputs video tensor. @@ -97,11 +97,15 @@ def load_video(video_folder: str, resize=None, num_frames=None): for file in input_files: if resize is not None: - video.append(transforms.ToTensor()(Image.open(str(file)).resize((resw, resh), Image.LANCZOS))) + img = Image.open(str(file)).resize((resw, resh), Image.LANCZOS) + img = transforms.ToTensor()(img) if to_tensor else img + video.append(img) else: - video.append(transforms.ToTensor()(Image.open(str(file)))) + img = Image.open(str(file)) + img = transforms.ToTensor()(img) if to_tensor else img + video.append(img) - return torch.stack(video) + return torch.stack(video) if to_tensor else video def save_video(video, output_path, fps=30): diff --git a/models/model_inference.py b/models/model_inference.py index eb62e86..772effb 100644 --- a/models/model_inference.py +++ b/models/model_inference.py @@ -126,22 +126,23 @@ def compute_trajectory_cos_sims(self, trajectories, query_points) -> torch.Tenso return trajectories_cosine_similarities - # ----------------- Anchor Trajectories ----------------- + # ----------------- Anchor Trajectories (slower, but less memory-consuming) ----------------- def _get_model_preds_at_anchors(self, model, range_normalizer, preds, anchor_indices, batch_size=None): """ preds: N""" batch_size = batch_size if batch_size is not None else preds.shape[0] cycle_coords = [] + # for each anchor frame in anchor_indices, get tracking predictions from all preds to the anchor frame for vis_frame in anchor_indices: # iterate over frames_set_t in batches of size batch_size coords = [] - for i in range(0, preds.shape[0], batch_size): - end_idx = min(i + batch_size, preds.shape[0]) - frames_set_t = torch.arange(i, end_idx, device=model.device) - frames_set_t = torch.cat([ torch.tensor([vis_frame], device=model.device), frames_set_t ]).int() + for st_idx in range(0, preds.shape[0], batch_size): + end_idx = min(st_idx + batch_size, preds.shape[0]) + frames_set_t = torch.arange(st_idx, end_idx, device=model.device) # source frames + frames_set_t = torch.cat([ torch.tensor([vis_frame], device=model.device), frames_set_t ]).int() # add target frame (vis_frame) source_frame_indices = torch.arange(1, frames_set_t.shape[0], device=model.device) target_frame_indices = torch.tensor([0]*(frames_set_t.shape[0]-1), device=model.device) - inp = preds[i:end_idx], source_frame_indices, target_frame_indices, frames_set_t + inp = preds[st_idx:end_idx], source_frame_indices, target_frame_indices, frames_set_t batch_coords = model(inp) batch_coords = range_normalizer.unnormalize(batch_coords, src=(-1, 1), dims=[0, 1]) coords.append(batch_coords) @@ -153,6 +154,25 @@ def _get_model_preds_at_anchors(self, model, range_normalizer, preds, anchor_ind return cycle_coords + # ----------------- Anchor Trajectories ----------------- + def _get_model_preds_at_anchors_old(self, model, range_normalizer, preds, anchor_indices, batch_size=None): + """ preds: T. anchor_indices (N_anchors). + Returns: cycle_coords, N_anchors x T x 2. + """ + T = preds.shape[0] + batch_size = batch_size if batch_size is not None else T + + frames_set_t = torch.arange(0, T).int() # [0, 1, 2, ..., T-1] + source_frames = torch.arange(0, T, device=model.device) + source_frame_indices = source_frames.repeat(anchor_indices.shape[0]) # T*N_anchors, [0, 1, ..., T-1, 0, 1, ..., T-1, ...] + query_points = preds.repeat(anchor_indices.shape[0], 1) # (T*N_anchors) x 3 + target_frame_indices = anchor_indices.unsqueeze(1).repeat(1, source_frames.shape[0]).view(-1) # T*N_anchors, [anchor_indices[0], ..., anchor_indices[0], anchor_indices[1], ..., anchor_indices[1], ...] + inp = query_points, source_frame_indices, target_frame_indices, frames_set_t + cycle_coords = model(inp) # (T*N_anchors) x 2 + cycle_coords = range_normalizer.unnormalize(cycle_coords, src=(-1, 1), dims=[0, 1]) + cycle_coords = cycle_coords.view(anchor_indices.shape[0], T, 2) # N_anchors x T x 2 + return cycle_coords # N_anchors x T x 2 + def compute_anchor_trajectories(self, trajectories: torch.Tensor, cos_sims: torch.Tensor, batch_size=None) -> torch.Tensor: N, T = trajectories.shape[:2] eql_anchor_cyc_predictions = {} diff --git a/models/tracker.py b/models/tracker.py index 9961ba1..92cf00f 100644 --- a/models/tracker.py +++ b/models/tracker.py @@ -133,13 +133,15 @@ def cache_refined_embeddings(self, move_dino_to_cpu=False): self.refined_features = refined_features if move_dino_to_cpu: self.dino_embed_video = self.dino_embed_video.to("cpu") + torch.cuda.empty_cache() + gc.collect() def uncache_refined_embeddings(self, move_dino_to_gpu=False): self.refined_features = None - torch.cuda.empty_cache() - gc.collect() if move_dino_to_gpu: self.dino_embed_video = self.dino_embed_video.to("cuda") + torch.cuda.empty_cache() + gc.collect() def save_weights(self, iter): torch.save(self.tracker_head.state_dict(), Path(self.ckpt_path) / f"tracker_head_{iter}.pt") @@ -300,13 +302,15 @@ def get_cycle_consistent_preds(self, frames_set_t, fg_masks): return cycle_consistency_preds - def forward(self, inp, use_raw_features=False): + def forward(self, inp, use_raw_features=False, cache_raw_features=False): """ inp: source_points_unnormalized, source_frame_indices, target_frame_indices, frames_set_t; where - source_points_unnormalized: B x 3. ((x, y, t) in image scale - NOT normalized) - source_frame_indices: the indices of frames of source points in frames_set_t - target_frame_indices: the indices of target frames in frames_set_t - frames_set_t: N, 0 to T-1 (NOT normalized) + source_points_unnormalized: B x 3. ((x, y, t) in image scale - NOT normalized) + source_frame_indices: the indices of frames of source points in frames_set_t + target_frame_indices: the indices of target frames in frames_set_t + frames_set_t: N, 0 to T-1 (NOT normalized) + use_raw_features: if True, use raw embeddings from DINO. + cache_raw_features: if True, cache raw embeddings for future use. """ frames_set_t = inp[-1] @@ -314,12 +318,13 @@ def forward(self, inp, use_raw_features=False): frame_embeddings = raw_embeddings = self.get_dino_embed_video(frames_set_t=frames_set_t) elif self.refined_features is not None: # load from cache frame_embeddings = self.refined_features[frames_set_t] - raw_embeddings = self.dino_embed_video[frames_set_t.to(self.dino_embed_video.device)] + if cache_raw_features: + self.raw_embeddings = self.dino_embed_video[frames_set_t.to(self.dino_embed_video.device)] else: frame_embeddings, residual_embeddings, raw_embeddings = self.get_refined_embeddings(frames_set_t, return_raw_embeddings=True) self.residual_embeddings = residual_embeddings + self.raw_embeddings = raw_embeddings self.frame_embeddings = frame_embeddings - self.raw_embeddings = raw_embeddings coords = self.get_point_predictions(inp, frame_embeddings) return coords diff --git a/preprocessing/save_dino_embed_video.py b/preprocessing/save_dino_embed_video.py index 37dd4ca..12fd63f 100644 --- a/preprocessing/save_dino_embed_video.py +++ b/preprocessing/save_dino_embed_video.py @@ -28,7 +28,7 @@ def save_dino_embed_video(args): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--config", default="./config/preprocessing.yaml", type=str) - parser.add_argument("--data-path", default="./dataset/libby", type=str, required=True) + parser.add_argument("--data-path", default="./dataset/tapvid-davis/davis_480/0", type=str, required=True) parser.add_argument("--for-mask", action="store_true", default=False) args = parser.parse_args() From 5ee75a1c3dab2f9aa6ab66754a643e690fdffd96 Mon Sep 17 00:00:00 2001 From: tnarek Date: Tue, 21 May 2024 00:33:55 +0300 Subject: [PATCH 2/2] fix pred anchor traj naming --- models/model_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/model_inference.py b/models/model_inference.py index 772effb..9423cf6 100644 --- a/models/model_inference.py +++ b/models/model_inference.py @@ -127,7 +127,7 @@ def compute_trajectory_cos_sims(self, trajectories, query_points) -> torch.Tenso # ----------------- Anchor Trajectories (slower, but less memory-consuming) ----------------- - def _get_model_preds_at_anchors(self, model, range_normalizer, preds, anchor_indices, batch_size=None): + def _get_model_preds_at_anchors_old(self, model, range_normalizer, preds, anchor_indices, batch_size=None): """ preds: N""" batch_size = batch_size if batch_size is not None else preds.shape[0] @@ -155,7 +155,7 @@ def _get_model_preds_at_anchors(self, model, range_normalizer, preds, anchor_ind return cycle_coords # ----------------- Anchor Trajectories ----------------- - def _get_model_preds_at_anchors_old(self, model, range_normalizer, preds, anchor_indices, batch_size=None): + def _get_model_preds_at_anchors(self, model, range_normalizer, preds, anchor_indices, batch_size=None): """ preds: T. anchor_indices (N_anchors). Returns: cycle_coords, N_anchors x T x 2. """