-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbackground-removal.py
More file actions
65 lines (43 loc) · 1.89 KB
/
background-removal.py
File metadata and controls
65 lines (43 loc) · 1.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import argparse
import warnings
warnings.filterwarnings('ignore')
from PIL import Image
from lang_sam import LangSAM
import numpy as np
# Command-line background removal: segment the first "fish" that LangSAM
# (called "GroundedSAM" in the messages below) finds in one image, then write
# the mask and a copy of the image with everything outside the mask painted
# white.

TEXT_PROMPT = "fish"
BOX_THRESHOLD = 0.30


def _parse_args(argv=None):
    """Parse the command-line options (input dir, image file name, output dir)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_dir", "-d", type=str, default='input/',
                        help="path to image directory")
    parser.add_argument("--image_name", "-i", type=str, default='INHS_FISH_005052.jpg',
                        help="image name with extension")
    parser.add_argument("--out_dir", "-o", type=str, default='output/',
                        help="path to output")
    return parser.parse_args(argv)


def _remove_background(image_rgb, mask):
    """Return a copy of *image_rgb* with every pixel where *mask* == 0 set to white.

    image_rgb: an RGB PIL image.
    mask: 2-D array matching the image's height and width; pixels whose mask
        value is non-zero are kept unchanged.
    """
    img_array = np.array(image_rgb)  # independent copy, shape (H, W, 3) for RGB
    # A 2-D boolean index on an (H, W, 3) array selects whole pixels, so the
    # scalar 255 whitens all three channels at once.
    img_array[mask == 0] = 255
    return Image.fromarray(img_array.astype(np.uint8))


def main(argv=None):
    """Run background removal for a single image.

    When a fish is found, writes:
      <out_dir>/mask/<image_name>        -- mask scaled to 0/255
      <out_dir>/bg_removed/<image_name>  -- image with the background whitened
    and prints the mask's share of the image area. Otherwise prints a notice
    and writes nothing.
    """
    args = _parse_args(argv)
    img_name = args.image_name
    input_dir = args.image_dir
    out_dir = args.out_dir

    # Convert to RGB once and use that image for both the model and the
    # background fill. The fill assumes exactly three channels, so working on
    # the unconverted image broke grayscale / RGBA / palette inputs. The `with`
    # closes the file handle; convert() returns an independent, loaded image.
    with Image.open(os.path.join(input_dir, img_name)) as img:
        image_pil = img.convert("RGB")

    model = LangSAM()
    masks, _boxes, _phrases, _logits = model.predict(
        image_pil, TEXT_PROMPT, box_threshold=BOX_THRESHOLD
    )

    # Check whether GroundedSAM found a fish mask at all.
    if len(masks) == 0:
        print('GroundedSAM is not able to find a fish in the image.')
        return

    # Only the first segmentation map is considered.
    mask = masks[0].detach().cpu().numpy()
    mod_img = _remove_background(image_pil, mask)
    mask_pil = Image.fromarray((mask * 255).astype(np.uint8))

    mask_dir = os.path.join(out_dir, 'mask')
    os.makedirs(mask_dir, exist_ok=True)
    bg_dir = os.path.join(out_dir, 'bg_removed')
    os.makedirs(bg_dir, exist_ok=True)
    # NOTE(review): the mask reuses the input file name, so a .jpg input yields
    # a lossy JPEG mask with edge artifacts; consider a lossless format (PNG)
    # if downstream consumers allow a different file name.
    mask_pil.save(os.path.join(mask_dir, img_name))
    mod_img.save(os.path.join(bg_dir, img_name))
    print('Mask area: {:.2f}% of the entire image'.format(
        100 * mask.sum() / (mask.shape[0] * mask.shape[1])))


if __name__ == "__main__":
    main()