diff --git a/persam_f.py b/persam_f.py
index 443d8fc..3321ffe 100644
--- a/persam_f.py
+++ b/persam_f.py
@@ -2,6 +2,7 @@
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
+import torchvision.transforms.functional as TVF
 
 import os
 import cv2
@@ -50,6 +51,7 @@ def main():
         if ".DS" not in obj_name:
             persam_f(args, obj_name, images_path, masks_path, output_path)
 
+sam = None
 
 def persam_f(args, obj_name, images_path, masks_path, output_path):
 
@@ -70,19 +72,23 @@ def persam_f(args, obj_name, images_path, masks_path, output_path):
 
     ref_mask = cv2.imread(ref_mask_path)
     ref_mask = cv2.cvtColor(ref_mask, cv2.COLOR_BGR2RGB)
 
-    gt_mask = torch.tensor(ref_mask)[:, :, 0] > 0
-    gt_mask = gt_mask.float().unsqueeze(0).flatten(1).cuda()
+    resolution = [256, 256]
+    gt_mask = torch.tensor(ref_mask)[None, :, :, 0] > 0
+    gt_mask = TVF.resize(gt_mask.float(), resolution)
+    gt_mask = gt_mask.flatten(1).cuda()
 
     print("======> Load SAM" )
-    if args.sam_type == 'vit_h':
-        sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
-        sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
-    elif args.sam_type == 'vit_t':
-        sam_type, sam_ckpt = 'vit_t', 'weights/mobile_sam.pt'
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=device)
-        sam.eval()
+    global sam
+    if sam is None:
+        if args.sam_type == 'vit_h':
+            sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
+            sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
+        elif args.sam_type == 'vit_t':
+            sam_type, sam_ckpt = 'vit_t', 'weights/mobile_sam.pt'
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+            sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=device)
+            sam.eval()
 
     for name, param in sam.named_parameters():
         param.requires_grad = False
@@ -122,6 +128,8 @@ def persam_f(args, obj_name, images_path, masks_path, output_path):
     topk_xy, topk_label = point_selection(sim, topk=1)
 
+
+
     print('======> Start Training')
     # Learnable mask weights
     mask_weights = Mask_Weights().cuda()
     mask_weights.train()
@@ -130,18 +138,19 @@ def persam_f(args, obj_name, images_path, masks_path, output_path):
     optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=args.lr, eps=1e-4)
     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.train_epoch)
 
-    for train_idx in range(args.train_epoch):
-
-        # Run the decoder
-        masks, scores, logits, logits_high = predictor.predict(
-            point_coords=topk_xy,
-            point_labels=topk_label,
-            multimask_output=True)
-        logits_high = logits_high.flatten(1)
+    # Run the decoder
+    masks, scores, logits, original_logits_high = predictor.predict(
+        point_coords=topk_xy,
+        point_labels=topk_label,
+        multimask_output=True)
+
+    original_logits_high = TVF.resize(original_logits_high, resolution)
+    original_logits_high = original_logits_high.flatten(1)
 
+    for train_idx in range(args.train_epoch):
         # Weighted sum three-scale masks
         weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0)
-        logits_high = logits_high * weights
+        logits_high = original_logits_high * weights
         logits_high = logits_high.sum(0).unsqueeze(0)
 
         dice_loss = calculate_dice_loss(logits_high, gt_mask)
diff --git a/persam_f_multi_obj.py b/persam_f_multi_obj.py
index 22d8a9c..b38d292 100644
--- a/persam_f_multi_obj.py
+++ b/persam_f_multi_obj.py
@@ -2,6 +2,7 @@
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
+import torchvision.transforms.functional as TVF
 
 import os
 import cv2
@@ -21,7 +22,7 @@ def get_arguments():
 
     parser = argparse.ArgumentParser()
 
     parser.add_argument('--data', type=str, default='./data')
-    parser.add_argument('--outdir', type=str, default='persam_f')
+    parser.add_argument('--outdir', type=str, default='persam_f_multi_obj')
     parser.add_argument('--ckpt', type=str, default='./sam_vit_h_4b8939.pth')
     parser.add_argument('--sam_type', type=str, default='vit_h')
@@ -55,17 +56,20 @@ def main():
     for obj_name in os.listdir(images_path):
         persam_f(args, obj_name, images_path, masks_path, output_path)
 
+sam = None
 
 def persam_f(args, obj_name, images_path, masks_path, output_path):
 
     print("======> Load SAM" )
-    if args.sam_type == 'vit_h':
-        sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
-        sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
-    elif args.sam_type == 'vit_t':
-        sam_type, sam_ckpt = 'vit_t', 'weights/mobile_sam.pt'
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=device)
-        sam.eval()
+    global sam
+    if sam is None:
+        if args.sam_type == 'vit_h':
+            sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
+            sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
+        elif args.sam_type == 'vit_t':
+            sam_type, sam_ckpt = 'vit_t', 'weights/mobile_sam.pt'
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+            sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=device)
+            sam.eval()
 
     for name, param in sam.named_parameters():
@@ -90,8 +94,11 @@ def persam_f(args, obj_name, images_path, masks_path, output_path):
     ref_mask = cv2.imread(ref_mask_path)
     ref_mask = cv2.cvtColor(ref_mask, cv2.COLOR_BGR2RGB)
 
-    gt_mask = torch.tensor(ref_mask)[:, :, 0] > 0
-    gt_mask = gt_mask.float().unsqueeze(0).flatten(1).cuda()
+    resolution = [256, 256]
+
+    gt_mask = torch.tensor(ref_mask)[None, :, :, 0] > 0
+    gt_mask = TVF.resize(gt_mask.float(), resolution)
+    gt_mask = gt_mask.unsqueeze(0).flatten(1).cuda()
 
     # print("======> Obtain Self Location Prior" )
     # Image features encoding
@@ -133,18 +140,18 @@ def persam_f(args, obj_name, images_path, masks_path, output_path):
     optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=args.lr, eps=1e-4)
     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.train_epoch_inside)
 
-    for train_idx in range(args.train_epoch_inside):
-
-        # Run the decoder
-        masks, scores, logits, logits_high = predictor.predict(
-            point_coords=topk_xy,
-            point_labels=topk_label,
-            multimask_output=True)
-        logits_high = logits_high.flatten(1)
+    # Run the decoder
+    masks, scores, logits, original_logits_high = predictor.predict(
+        point_coords=topk_xy,
+        point_labels=topk_label,
+        multimask_output=True)
+    original_logits_high = TVF.resize(original_logits_high, resolution)
+    original_logits_high = original_logits_high.flatten(1)
 
+    for train_idx in range(args.train_epoch_inside):
         # Weighted sum three-scale masks
         weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0)
-        logits_high = logits_high * weights
+        logits_high = original_logits_high * weights
         logits_high = logits_high.sum(0).unsqueeze(0)
 
         dice_loss = calculate_dice_loss(logits_high, gt_mask)
@@ -261,16 +268,15 @@ def persam_f(args, obj_name, images_path, masks_path, output_path):
         show_points(topk_xy, topk_label, plt.gca())
         history_masks.append(mask_colors)
         # Save masks
-
+        plt.imshow(test_image_original)
         vis_mask_output_path = os.path.join(output_path, f'vis_mask_{test_idx}_objects:{len(history_masks)}.jpg')
         with open(vis_mask_output_path, 'wb') as outfile:
             plt.savefig(outfile, format='jpg')
 
+        for i, mask in enumerate(history_masks):
-        mask_output_path = os.path.join(output_path, test_idx + '.png')
-        cv2.imwrite(mask_output_path, mask_colors)
-
-
+            mask_output_path = os.path.join(output_path, f"{test_idx}_{i}.png")
+            cv2.imwrite(mask_output_path, mask)
 
 
 class Mask_Weights(nn.Module):
     def __init__(self):
diff --git a/requirements.txt b/requirements.txt
index ddb4470..20c0736 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,4 @@ numpy
 warnings
 argparse
 opencv-python
+torchvision
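
A note on the training-loop refactor that appears in both files: only the Mask_Weights parameters are trainable, and the SAM decoder is frozen with fixed point prompts, so predictor.predict is hoisted out of the epoch loop, its three-scale logits are resized once to the fixed 256x256 resolution, and each epoch merely re-weights the cached original_logits_high. Below is a minimal self-contained sketch of that structure (the fit_mask_weights helper, the standalone input tensors, and the placeholder loss are hypothetical stand-ins for predictor.predict and calculate_dice_loss/calculate_sigmoid_focal_loss; only the caching pattern matches the diff):

    import torch
    import torchvision.transforms.functional as TVF

    def fit_mask_weights(logits_high, gt_mask, epochs=1000, lr=1e-3):
        # logits_high: (3, H, W) three-scale mask logits from a frozen decoder.
        # gt_mask:     (1, H, W) binary reference mask.
        resolution = [256, 256]

        # Frozen decoder + fixed prompts: compute and resize once, reuse every epoch.
        logits = TVF.resize(logits_high, resolution).flatten(1)      # (3, 65536)
        target = TVF.resize(gt_mask.float(), resolution).flatten(1)  # (1, 65536)

        weights = torch.nn.Parameter(torch.ones(2, 1) / 3)  # two learnable scales, as in Mask_Weights
        optim = torch.optim.AdamW([weights], lr=lr, eps=1e-4)

        for _ in range(epochs):
            # The first scale gets 1 - sum(others), so the three weights sum to 1.
            w = torch.cat((1 - weights.sum(0).unsqueeze(0), weights), dim=0)
            fused = (logits * w).sum(0).unsqueeze(0)              # (1, 65536)
            loss = (torch.sigmoid(fused) - target).pow(2).mean()  # placeholder loss
            optim.zero_grad()
            loss.backward()
            optim.step()
        return weights

The same caching rationale motivates the module-level sam = None guard in both files: persam_f() is called once per object, and the checkpoint is now loaded only on the first call instead of being reloaded every time.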