-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
62 lines (58 loc) · 2.63 KB
/
predict.py
File metadata and controls
62 lines (58 loc) · 2.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import argparse
import glob
import numpy as np
import torch
from PIL import Image
from models.segmentor import build_model
from data.transforms import get_test_transforms
from utils.checkpoint import load_model_weights
from utils.visualization import colorize_mask, save_prediction, denormalize
def main():
    """Run segmentation inference over a directory of images.

    Builds the segmentor from CLI hyperparameters, loads trained weights
    from --checkpoint, predicts a class-index mask for every
    ``*.{image_ext}`` file in --image_dir, and writes the masks (and
    optional color overlays) into --output_dir.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_dir', type=str, required=True)
    parser.add_argument('--output_dir', type=str, required=True)
    parser.add_argument('--checkpoint', type=str, required=True)
    parser.add_argument('--num_classes', type=int, default=21)
    parser.add_argument('--image_size', type=int, default=512)
    parser.add_argument('--backbone_variant', type=str, default='small')
    parser.add_argument('--fpn_channels', type=int, default=256)
    parser.add_argument('--seg_inner_channels', type=int, default=128)
    parser.add_argument('--use_light_head', action='store_true')
    parser.add_argument('--image_ext', type=str, default='png')
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--save_overlay', action='store_true')
    args = parser.parse_args()
    # Fall back to CPU only when a CUDA device was requested but CUDA is
    # unavailable.  The original fell back for ANY device string whenever
    # CUDA was absent, which silently overrode explicit non-CUDA choices
    # such as 'cpu' or 'mps'.
    if args.device.startswith('cuda') and not torch.cuda.is_available():
        device = torch.device('cpu')
    else:
        device = torch.device(args.device)
    os.makedirs(args.output_dir, exist_ok=True)
    transform = get_test_transforms(args.image_size)
    model = build_model(
        num_classes=args.num_classes,
        backbone_variant=args.backbone_variant,
        fpn_channels=args.fpn_channels,
        seg_inner_channels=args.seg_inner_channels,
        use_light_head=args.use_light_head
    )
    model = load_model_weights(args.checkpoint, model)
    model = model.to(device)
    model.eval()
    image_paths = sorted(glob.glob(os.path.join(args.image_dir, f'*.{args.image_ext}')))
    # Fail loudly on an empty match instead of printing a misleading
    # "Predictions saved" message with no outputs written.
    if not image_paths:
        print(f'No *.{args.image_ext} images found in {args.image_dir}')
        return
    for img_path in image_paths:
        image_np = np.array(Image.open(img_path).convert('RGB'))
        original_h, original_w = image_np.shape[:2]
        transformed = transform(image=image_np)
        image_tensor = transformed['image'].unsqueeze(0).to(device)
        with torch.no_grad():
            # model.predict is expected to upsample back to the original
            # resolution given target_size — TODO confirm against segmentor.
            pred = model.predict(image_tensor, target_size=(original_h, original_w))
        pred = pred.squeeze(0).cpu().numpy()
        basename = os.path.splitext(os.path.basename(img_path))[0]
        # Save the raw class-index mask as an 8-bit PNG.
        mask_image = Image.fromarray(pred.astype(np.uint8))
        mask_image.save(os.path.join(args.output_dir, f'{basename}_pred.png'))
        if args.save_overlay:
            save_prediction(
                image_np, pred,
                os.path.join(args.output_dir, f'{basename}_overlay.png')
            )
    print(f'Predictions saved to {args.output_dir}')
# Script entry point: run inference only when executed directly, not on import.
if __name__ == '__main__':
    main()