From 5f6521f91d41eed0bce27bcf94f3984e5515f397 Mon Sep 17 00:00:00 2001 From: Andrew Healey Date: Thu, 6 Jul 2023 12:04:41 -0400 Subject: [PATCH 01/11] Precompute logits during training --- persam_f.py | 17 ++++++++--------- persam_f_multi_obj.py | 19 +++++++++---------- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/persam_f.py b/persam_f.py index 443d8fc..8719f97 100644 --- a/persam_f.py +++ b/persam_f.py @@ -130,18 +130,17 @@ def persam_f(args, obj_name, images_path, masks_path, output_path): optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=args.lr, eps=1e-4) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.train_epoch) - for train_idx in range(args.train_epoch): - - # Run the decoder - masks, scores, logits, logits_high = predictor.predict( - point_coords=topk_xy, - point_labels=topk_label, - multimask_output=True) - logits_high = logits_high.flatten(1) + # Run the decoder + masks, scores, logits, logits_high = predictor.predict( + point_coords=topk_xy, + point_labels=topk_label, + multimask_output=True) + original_logits_high = original_logits_high.flatten(1) + for train_idx in range(args.train_epoch): # Weighted sum three-scale masks weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0) - logits_high = logits_high * weights + logits_high = original_logits_high * weights logits_high = logits_high.sum(0).unsqueeze(0) dice_loss = calculate_dice_loss(logits_high, gt_mask) diff --git a/persam_f_multi_obj.py b/persam_f_multi_obj.py index 22d8a9c..cff69cc 100644 --- a/persam_f_multi_obj.py +++ b/persam_f_multi_obj.py @@ -131,20 +131,19 @@ def persam_f(args, obj_name, images_path, masks_path, output_path): mask_weights.train() optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=args.lr, eps=1e-4) - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.train_epoch_inside) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.train_epoch) - for train_idx in range(args.train_epoch_inside): - - # Run the decoder - masks, scores, logits, logits_high = predictor.predict( - point_coords=topk_xy, - point_labels=topk_label, - multimask_output=True) - logits_high = logits_high.flatten(1) + # Run the decoder + masks, scores, logits, logits_high = predictor.predict( + point_coords=topk_xy, + point_labels=topk_label, + multimask_output=True) + original_logits_high = original_logits_high.flatten(1) + for train_idx in range(args.train_epoch): # Weighted sum three-scale masks weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0) - logits_high = logits_high * weights + logits_high = original_logits_high * weights logits_high = logits_high.sum(0).unsqueeze(0) dice_loss = calculate_dice_loss(logits_high, gt_mask) From 022c4378b3069718ab1f533a7ad9d97efdf43c1b Mon Sep 17 00:00:00 2001 From: Andrew Healey Date: Thu, 6 Jul 2023 12:55:47 -0400 Subject: [PATCH 02/11] Fix typo --- persam_f.py | 2 +- persam_f_multi_obj.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/persam_f.py b/persam_f.py index 8719f97..5a6284b 100644 --- a/persam_f.py +++ b/persam_f.py @@ -131,7 +131,7 @@ def persam_f(args, obj_name, images_path, masks_path, output_path): scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.train_epoch) # Run the decoder - masks, scores, logits, logits_high = predictor.predict( + masks, scores, logits, original_logits_high = predictor.predict( point_coords=topk_xy, point_labels=topk_label, multimask_output=True) diff --git a/persam_f_multi_obj.py b/persam_f_multi_obj.py index cff69cc..5e4786c 100644 --- a/persam_f_multi_obj.py +++ b/persam_f_multi_obj.py @@ -199,7 +199,7 @@ def persam_f(args, obj_name, images_path, masks_path, output_path): topk_xy, topk_label = point_selection(sim, topk=1) # First-step prediction - masks, scores, logits, logits_high = predictor.predict( + masks, scores, logits, original_logits_high = predictor.predict( point_coords=topk_xy, point_labels=topk_label, multimask_output=True) From a89f1e71cdf4fe45fd99c322c2757c9793285759 Mon Sep 17 00:00:00 2001 From: Andrew Healey Date: Thu, 6 Jul 2023 14:02:50 -0400 Subject: [PATCH 03/11] Decrease training time to 1s --- persam_f.py | 9 +++++++-- requirements.txt | 1 + 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/persam_f.py b/persam_f.py index 5a6284b..1e86b7d 100644 --- a/persam_f.py +++ b/persam_f.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn from torch.nn import functional as F +import torchvision.transforms.functional as TVF import os import cv2 @@ -70,9 +71,11 @@ def persam_f(args, obj_name, images_path, masks_path, output_path): ref_mask = cv2.imread(ref_mask_path) ref_mask = cv2.cvtColor(ref_mask, cv2.COLOR_BGR2RGB) - gt_mask = torch.tensor(ref_mask)[:, :, 0] > 0 - gt_mask = gt_mask.float().unsqueeze(0).flatten(1).cuda() + resolution = [256, 256] + gt_mask = torch.tensor(ref_mask)[None,:, :, 0] > 0 + gt_mask = TVF.resize(gt_mask.float(), resolution) + gt_mask = gt_mask.flatten(1).cuda() print("======> Load SAM" ) if args.sam_type == 'vit_h': @@ -135,6 +138,8 @@ def persam_f(args, obj_name, images_path, masks_path, output_path): point_coords=topk_xy, point_labels=topk_label, multimask_output=True) + + original_logits_high = TVF.resize(original_logits_high,resolution) original_logits_high = original_logits_high.flatten(1) for train_idx in range(args.train_epoch): diff --git a/requirements.txt b/requirements.txt index ddb4470..20c0736 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ numpy warnings argparse opencv-python +torchvision From f517cfd56f324468d9f53efc1d6274a7e756db6c Mon Sep 17 00:00:00 2001 From: Andrew Healey Date: Thu, 6 Jul 2023 14:04:05 -0400 Subject: [PATCH 04/11] Use global SAM --- persam_f.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/persam_f.py b/persam_f.py index 1e86b7d..3321ffe 100644 --- a/persam_f.py +++ b/persam_f.py @@ -51,6 +51,7 @@ def main(): if ".DS" not in obj_name: persam_f(args, obj_name, images_path, masks_path, output_path) +sam = None def persam_f(args, obj_name, images_path, masks_path, output_path): @@ -78,14 +79,16 @@ def persam_f(args, obj_name, images_path, masks_path, output_path): gt_mask = gt_mask.flatten(1).cuda() print("======> Load SAM" ) - if args.sam_type == 'vit_h': - sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth' - sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda() - elif args.sam_type == 'vit_t': - sam_type, sam_ckpt = 'vit_t', 'weights/mobile_sam.pt' - device = "cuda" if torch.cuda.is_available() else "cpu" - sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=device) - sam.eval() + global sam + if sam is None: + if args.sam_type == 'vit_h': + sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth' + sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda() + elif args.sam_type == 'vit_t': + sam_type, sam_ckpt = 'vit_t', 'weights/mobile_sam.pt' + device = "cuda" if torch.cuda.is_available() else "cpu" + sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=device) + sam.eval() for name, param in sam.named_parameters(): @@ -125,6 +128,8 @@ def persam_f(args, obj_name, images_path, masks_path, output_path): topk_xy, topk_label = point_selection(sim, topk=1) + + print('======> Start Training') # Learnable mask weights mask_weights = Mask_Weights().cuda() From 4e68d2338146006dfd231a4d70963ca3dee6aab7 Mon Sep 17 00:00:00 2001 From: Andrew Healey Date: Thu, 6 Jul 2023 14:29:19 -0400 Subject: [PATCH 05/11] Apply same changes to multi_obj --- persam_f_multi_obj.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/persam_f_multi_obj.py b/persam_f_multi_obj.py index 5e4786c..4442484 100644 --- a/persam_f_multi_obj.py +++ b/persam_f_multi_obj.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn from torch.nn import functional as F +import torchvision.transforms.functional as TVF import os import cv2 @@ -55,17 +56,20 @@ def main(): for obj_name in os.listdir(images_path): persam_f(args, obj_name, images_path, masks_path, output_path) +sam = None def persam_f(args, obj_name, images_path, masks_path, output_path): print("======> Load SAM" ) - if args.sam_type == 'vit_h': - sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth' - sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda() - elif args.sam_type == 'vit_t': - sam_type, sam_ckpt = 'vit_t', 'weights/mobile_sam.pt' - device = "cuda" if torch.cuda.is_available() else "cpu" - sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=device) - sam.eval() + global sam + if sam is None: + if args.sam_type == 'vit_h': + sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth' + sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda() + elif args.sam_type == 'vit_t': + sam_type, sam_ckpt = 'vit_t', 'weights/mobile_sam.pt' + device = "cuda" if torch.cuda.is_available() else "cpu" + sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=device) + sam.eval() for name, param in sam.named_parameters(): @@ -90,8 +94,11 @@ def persam_f(args, obj_name, images_path, masks_path, output_path): ref_mask = cv2.imread(ref_mask_path) ref_mask = cv2.cvtColor(ref_mask, cv2.COLOR_BGR2RGB) + resolution = [256, 256] + gt_mask = torch.tensor(ref_mask)[:, :, 0] > 0 - gt_mask = gt_mask.float().unsqueeze(0).flatten(1).cuda() + gt_mask = TVF.resize(gt_mask.float(), resolution) + gt_mask = gt_mask.unsqueeze(0).flatten(1).cuda() # print("======> Obtain Self Location Prior" ) # Image features encoding @@ -134,10 +141,11 @@ def persam_f(args, obj_name, images_path, masks_path, output_path): scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.train_epoch) # Run the decoder - masks, scores, logits, logits_high = predictor.predict( + masks, scores, logits, original_logits_high = predictor.predict( point_coords=topk_xy, point_labels=topk_label, multimask_output=True) + original_logits_high = TVF.resize(original_logits_high,resolution) original_logits_high = original_logits_high.flatten(1) for train_idx in range(args.train_epoch): From 8a3900dc9a646a92e6a178f12bf5b97364d225e1 Mon Sep 17 00:00:00 2001 From: Andrew Healey Date: Thu, 6 Jul 2023 14:32:59 -0400 Subject: [PATCH 06/11] Unsqueeze in multi_obj --- persam_f_multi_obj.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/persam_f_multi_obj.py b/persam_f_multi_obj.py index 4442484..5652304 100644 --- a/persam_f_multi_obj.py +++ b/persam_f_multi_obj.py @@ -96,7 +96,7 @@ def persam_f(args, obj_name, images_path, masks_path, output_path): resolution = [256, 256] - gt_mask = torch.tensor(ref_mask)[:, :, 0] > 0 + gt_mask = torch.tensor(ref_mask)[None,:, :, 0] > 0 gt_mask = TVF.resize(gt_mask.float(), resolution) gt_mask = gt_mask.unsqueeze(0).flatten(1).cuda() From 38bc64e2a90a810e2c89141cfd9d46a88587fc91 Mon Sep 17 00:00:00 2001 From: Andrew Healey Date: Thu, 6 Jul 2023 14:35:01 -0400 Subject: [PATCH 07/11] Fix bug in existing repo --- persam_f_multi_obj.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/persam_f_multi_obj.py b/persam_f_multi_obj.py index 5652304..f8ae5bb 100644 --- a/persam_f_multi_obj.py +++ b/persam_f_multi_obj.py @@ -138,7 +138,7 @@ def persam_f(args, obj_name, images_path, masks_path, output_path): mask_weights.train() optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=args.lr, eps=1e-4) - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.train_epoch) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.train_epoch_inside) # Run the decoder masks, scores, logits, original_logits_high = predictor.predict( From ee942c72459bc40e37f391282b9226ca1b43ac1e Mon Sep 17 00:00:00 2001 From: Andrew Healey Date: Thu, 6 Jul 2023 14:35:51 -0400 Subject: [PATCH 08/11] More fixes --- persam_f_multi_obj.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/persam_f_multi_obj.py b/persam_f_multi_obj.py index f8ae5bb..a2bafc0 100644 --- a/persam_f_multi_obj.py +++ b/persam_f_multi_obj.py @@ -148,7 +148,7 @@ def persam_f(args, obj_name, images_path, masks_path, output_path): original_logits_high = TVF.resize(original_logits_high,resolution) original_logits_high = original_logits_high.flatten(1) - for train_idx in range(args.train_epoch): + for train_idx in range(args.train_epoch_inside): # Weighted sum three-scale masks weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0) logits_high = original_logits_high * weights From fca0933c2a8fe12085ceacce7c64510f2e7a5bbb Mon Sep 17 00:00:00 2001 From: Andrew Healey Date: Thu, 6 Jul 2023 14:44:01 -0400 Subject: [PATCH 09/11] Fix my own bug --- persam_f_multi_obj.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/persam_f_multi_obj.py b/persam_f_multi_obj.py index a2bafc0..f6c4bda 100644 --- a/persam_f_multi_obj.py +++ b/persam_f_multi_obj.py @@ -207,7 +207,7 @@ def persam_f(args, obj_name, images_path, masks_path, output_path): topk_xy, topk_label = point_selection(sim, topk=1) # First-step prediction - masks, scores, logits, original_logits_high = predictor.predict( + masks, scores, logits, logits_high = predictor.predict( point_coords=topk_xy, point_labels=topk_label, multimask_output=True) From 558153ea51c9e0c782fc74fa6e9f89434432c15b Mon Sep 17 00:00:00 2001 From: Andrew Healey Date: Thu, 6 Jul 2023 19:36:11 -0400 Subject: [PATCH 10/11] Change output dir --- persam_f_multi_obj.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/persam_f_multi_obj.py b/persam_f_multi_obj.py index f6c4bda..654f618 100644 --- a/persam_f_multi_obj.py +++ b/persam_f_multi_obj.py @@ -22,7 +22,7 @@ def get_arguments(): parser = argparse.ArgumentParser() parser.add_argument('--data', type=str, default='./data') - parser.add_argument('--outdir', type=str, default='persam_f') + parser.add_argument('--outdir', type=str, default='persam_f_multi_obj') parser.add_argument('--ckpt', type=str, default='./sam_vit_h_4b8939.pth') parser.add_argument('--sam_type', type=str, default='vit_h') @@ -268,14 +268,15 @@ def persam_f(args, obj_name, images_path, masks_path, output_path): show_points(topk_xy, topk_label, plt.gca()) history_masks.append(mask_colors) # Save masks - + plt.imshow(test_image_original) vis_mask_output_path = os.path.join(output_path, f'vis_mask_{test_idx}_objects:{len(history_masks)}.jpg') with open(vis_mask_output_path, 'wb') as outfile: plt.savefig(outfile, format='jpg') + for i,mask in history_masks: - mask_output_path = os.path.join(output_path, test_idx + '.png') - cv2.imwrite(mask_output_path, mask_colors) + mask_output_path = os.path.join(output_path, f"{test_idx}_{i}.png") + cv2.imwrite(mask_output_path, mask_colors) From 38d5a807d1392d9f826bc8f5a85dbb6f1341c82c Mon Sep 17 00:00:00 2001 From: Andrew Healey Date: Thu, 6 Jul 2023 21:44:15 -0400 Subject: [PATCH 11/11] Switch out enumerate call --- persam_f_multi_obj.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/persam_f_multi_obj.py b/persam_f_multi_obj.py index 654f618..b38d292 100644 --- a/persam_f_multi_obj.py +++ b/persam_f_multi_obj.py @@ -273,13 +273,11 @@ def persam_f(args, obj_name, images_path, masks_path, output_path): vis_mask_output_path = os.path.join(output_path, f'vis_mask_{test_idx}_objects:{len(history_masks)}.jpg') with open(vis_mask_output_path, 'wb') as outfile: plt.savefig(outfile, format='jpg') - for i,mask in history_masks: + for i,mask in enumerate(history_masks): mask_output_path = os.path.join(output_path, f"{test_idx}_{i}.png") cv2.imwrite(mask_output_path, mask_colors) - - class Mask_Weights(nn.Module): def __init__(self): super().__init__()