Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
*__pycache__*
.vscode
.DS_Store

bash
logs
.vscode
dataset/tapvid-davis
dataset/tapvid-davis-sampled-4
dataset/tapvid-davis-sampled-2
dataset/tapvid-davis-sampled-test
dataset/tapvid-davis-sampled-2-test
12 changes: 8 additions & 4 deletions data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def bilinear_sampler(img, coords, mode='bilinear', mask=False):
return img


def load_video(video_folder: str, resize=None, num_frames=None):
def load_video(video_folder: str, resize=None, num_frames=None, to_tensor=True):
"""
Loads video from folder, resizes frames as desired, and outputs video tensor.

Expand All @@ -97,11 +97,15 @@ def load_video(video_folder: str, resize=None, num_frames=None):

for file in input_files:
if resize is not None:
video.append(transforms.ToTensor()(Image.open(str(file)).resize((resw, resh), Image.LANCZOS)))
img = Image.open(str(file)).resize((resw, resh), Image.LANCZOS)
img = transforms.ToTensor()(img) if to_tensor else img
video.append(img)
else:
video.append(transforms.ToTensor()(Image.open(str(file))))
img = Image.open(str(file))
img = transforms.ToTensor()(img) if to_tensor else img
video.append(img)

return torch.stack(video)
return torch.stack(video) if to_tensor else video


def save_video(video, output_path, fps=30):
Expand Down
34 changes: 27 additions & 7 deletions models/model_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,22 +126,23 @@ def compute_trajectory_cos_sims(self, trajectories, query_points) -> torch.Tenso
return trajectories_cosine_similarities


# ----------------- Anchor Trajectories -----------------
def _get_model_preds_at_anchors(self, model, range_normalizer, preds, anchor_indices, batch_size=None):
# ----------------- Anchor Trajectories (slower, but less memory-consuming) -----------------
def _get_model_preds_at_anchors_old(self, model, range_normalizer, preds, anchor_indices, batch_size=None):
""" preds: N"""
batch_size = batch_size if batch_size is not None else preds.shape[0]

cycle_coords = []
# for each anchor frame in anchor_indices, get tracking predictions from all preds to the anchor frame
for vis_frame in anchor_indices:
# iterate over frames_set_t in batches of size batch_size
coords = []
for i in range(0, preds.shape[0], batch_size):
end_idx = min(i + batch_size, preds.shape[0])
frames_set_t = torch.arange(i, end_idx, device=model.device)
frames_set_t = torch.cat([ torch.tensor([vis_frame], device=model.device), frames_set_t ]).int()
for st_idx in range(0, preds.shape[0], batch_size):
end_idx = min(st_idx + batch_size, preds.shape[0])
frames_set_t = torch.arange(st_idx, end_idx, device=model.device) # source frames
frames_set_t = torch.cat([ torch.tensor([vis_frame], device=model.device), frames_set_t ]).int() # add target frame (vis_frame)
source_frame_indices = torch.arange(1, frames_set_t.shape[0], device=model.device)
target_frame_indices = torch.tensor([0]*(frames_set_t.shape[0]-1), device=model.device)
inp = preds[i:end_idx], source_frame_indices, target_frame_indices, frames_set_t
inp = preds[st_idx:end_idx], source_frame_indices, target_frame_indices, frames_set_t
batch_coords = model(inp)
batch_coords = range_normalizer.unnormalize(batch_coords, src=(-1, 1), dims=[0, 1])
coords.append(batch_coords)
Expand All @@ -153,6 +154,25 @@ def _get_model_preds_at_anchors(self, model, range_normalizer, preds, anchor_ind

return cycle_coords

# ----------------- Anchor Trajectories -----------------
def _get_model_preds_at_anchors(self, model, range_normalizer, preds, anchor_indices, batch_size=None):
""" preds: T. anchor_indices (N_anchors).
Returns: cycle_coords, N_anchors x T x 2.
"""
T = preds.shape[0]
batch_size = batch_size if batch_size is not None else T

frames_set_t = torch.arange(0, T).int() # [0, 1, 2, ..., T-1]
source_frames = torch.arange(0, T, device=model.device)
source_frame_indices = source_frames.repeat(anchor_indices.shape[0]) # T*N_anchors, [0, 1, ..., T-1, 0, 1, ..., T-1, ...]
query_points = preds.repeat(anchor_indices.shape[0], 1) # (T*N_anchors) x 3
target_frame_indices = anchor_indices.unsqueeze(1).repeat(1, source_frames.shape[0]).view(-1) # T*N_anchors, [anchor_indices[0], ..., anchor_indices[0], anchor_indices[1], ..., anchor_indices[1], ...]
inp = query_points, source_frame_indices, target_frame_indices, frames_set_t
cycle_coords = model(inp) # (T*N_anchors) x 2
cycle_coords = range_normalizer.unnormalize(cycle_coords, src=(-1, 1), dims=[0, 1])
cycle_coords = cycle_coords.view(anchor_indices.shape[0], T, 2) # N_anchors x T x 2
return cycle_coords # N_anchors x T x 2

def compute_anchor_trajectories(self, trajectories: torch.Tensor, cos_sims: torch.Tensor, batch_size=None) -> torch.Tensor:
N, T = trajectories.shape[:2]
eql_anchor_cyc_predictions = {}
Expand Down
23 changes: 14 additions & 9 deletions models/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,15 @@ def cache_refined_embeddings(self, move_dino_to_cpu=False):
self.refined_features = refined_features
if move_dino_to_cpu:
self.dino_embed_video = self.dino_embed_video.to("cpu")
torch.cuda.empty_cache()
gc.collect()

def uncache_refined_embeddings(self, move_dino_to_gpu=False):
    """Discard the cached refined features and release the memory they held.

    move_dino_to_gpu: if True, move the raw DINO embedding video back onto
        CUDA after clearing the cache.
    """
    # Drop the cached refined embeddings so their tensors become collectible.
    self.refined_features = None
    torch.cuda.empty_cache()
    gc.collect()
    if move_dino_to_gpu:
        self.dino_embed_video = self.dino_embed_video.to("cuda")
    # NOTE(review): the empty_cache/gc.collect pair appears twice in this span;
    # this looks like a diff artifact (removed lines above, added lines below) —
    # confirm the intended single placement against the repository.
    torch.cuda.empty_cache()
    gc.collect()

def save_weights(self, iter):
torch.save(self.tracker_head.state_dict(), Path(self.ckpt_path) / f"tracker_head_{iter}.pt")
Expand Down Expand Up @@ -300,26 +302,29 @@ def get_cycle_consistent_preds(self, frames_set_t, fg_masks):

return cycle_consistency_preds

def forward(self, inp, use_raw_features=False):
def forward(self, inp, use_raw_features=False, cache_raw_features=False):
    """
    Predict target-frame coordinates for a batch of query points.

    inp: tuple of (source_points_unnormalized, source_frame_indices,
         target_frame_indices, frames_set_t), where
        source_points_unnormalized: B x 3. ((x, y, t) in image scale - NOT normalized)
        source_frame_indices: the indices of frames of source points in frames_set_t
        target_frame_indices: the indices of target frames in frames_set_t
        frames_set_t: N, 0 to T-1 (NOT normalized)
    use_raw_features: if True, use raw embeddings from DINO.
    cache_raw_features: if True, cache raw embeddings for future use.
    """
    frames_set_t = inp[-1]

    if use_raw_features:
        # Raw DINO features double as both the frame and raw embeddings.
        frame_embeddings = raw_embeddings = self.get_dino_embed_video(frames_set_t=frames_set_t)
    elif self.refined_features is not None: # load from cache
        frame_embeddings = self.refined_features[frames_set_t]
        raw_embeddings = self.dino_embed_video[frames_set_t.to(self.dino_embed_video.device)]
        if cache_raw_features:
            # Stash the raw embeddings on the instance for later reuse.
            self.raw_embeddings = self.dino_embed_video[frames_set_t.to(self.dino_embed_video.device)]
    else:
        # No cache available: compute refined, residual, and raw embeddings now.
        frame_embeddings, residual_embeddings, raw_embeddings = self.get_refined_embeddings(frames_set_t, return_raw_embeddings=True)
        self.residual_embeddings = residual_embeddings
        self.raw_embeddings = raw_embeddings
    self.frame_embeddings = frame_embeddings
    # NOTE(review): this span interleaves removed and added diff lines; the
    # assignment below duplicates the one in the else-branch and is likely a
    # diff artifact — confirm against the repository before relying on it.
    self.raw_embeddings = raw_embeddings
    coords = self.get_point_predictions(inp, frame_embeddings)

    return coords
2 changes: 1 addition & 1 deletion preprocessing/save_dino_embed_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def save_dino_embed_video(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config", default="./config/preprocessing.yaml", type=str)
parser.add_argument("--data-path", default="./dataset/libby", type=str, required=True)
parser.add_argument("--data-path", default="./dataset/tapvid-davis/davis_480/0", type=str, required=True)
parser.add_argument("--for-mask", action="store_true", default=False)

args = parser.parse_args()
Expand Down