-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate.py
More file actions
53 lines (49 loc) · 2.2 KB
/
evaluate.py
File metadata and controls
53 lines (49 loc) · 2.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import argparse
import sys

import torch
from torch.utils.data import DataLoader

from models.segmentor import build_model
from data.dataset import SegmentationDataset
from data.transforms import get_val_transforms
from engine.evaluator import Evaluator
from utils.checkpoint import load_model_weights
def _build_arg_parser():
    """Build the command-line parser for the segmentation evaluation script."""
    parser = argparse.ArgumentParser(
        description='Evaluate a segmentation model checkpoint on an image/mask dataset.'
    )
    parser.add_argument('--image_dir', type=str, required=True,
                        help='Directory containing the input images.')
    parser.add_argument('--mask_dir', type=str, required=True,
                        help='Directory containing the ground-truth masks.')
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path of the checkpoint passed to load_model_weights.')
    parser.add_argument('--num_classes', type=int, default=21,
                        help='Number of segmentation classes (default: 21).')
    parser.add_argument('--image_size', type=int, default=512,
                        help='Size passed to get_val_transforms (default: 512).')
    parser.add_argument('--backbone_variant', type=str, default='small',
                        help="Backbone variant passed to build_model (default: 'small').")
    parser.add_argument('--fpn_channels', type=int, default=256,
                        help='fpn_channels passed to build_model (default: 256).')
    parser.add_argument('--seg_inner_channels', type=int, default=128,
                        help='seg_inner_channels passed to build_model (default: 128).')
    parser.add_argument('--use_light_head', action='store_true',
                        help='Pass use_light_head=True to build_model.')
    parser.add_argument('--batch_size', type=int, default=8,
                        help='Evaluation batch size (default: 8).')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='Number of DataLoader worker processes (default: 4).')
    parser.add_argument('--image_ext', type=str, default='png',
                        help="Image file extension (default: 'png').")
    parser.add_argument('--mask_ext', type=str, default='png',
                        help="Mask file extension (default: 'png').")
    parser.add_argument('--device', type=str, default='cuda',
                        help="Torch device to use when CUDA is available; "
                             "otherwise the CPU is used (default: 'cuda').")
    return parser


def _resolve_device(requested):
    """Return the torch device to evaluate on.

    The requested device is honoured only when CUDA is available; otherwise
    evaluation runs on the CPU, and a warning is printed to stderr unless the
    CPU was what the caller asked for.
    """
    if torch.cuda.is_available():
        return torch.device(requested)
    if requested != 'cpu':
        print(f"Warning: CUDA is not available; evaluating on CPU instead of '{requested}'.",
              file=sys.stderr)
    return torch.device('cpu')


def main(argv=None):
    """Evaluate a segmentation checkpoint from the command line.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). ``None`` (the default) makes argparse read ``sys.argv[1:]``,
            so running the script behaves as before; passing a list lets the
            entry point be driven from tests or other Python code.

    Returns:
        The object returned by ``Evaluator.evaluate`` for the dataset (the
        same object handed to ``Evaluator.print_results``).

    Raises:
        SystemExit: If argument parsing fails or ``--help`` is requested.
    """
    args = _build_arg_parser().parse_args(argv)
    device = _resolve_device(args.device)

    transform = get_val_transforms(args.image_size)
    dataset = SegmentationDataset(
        args.image_dir, args.mask_dir, args.num_classes,
        transform=transform, image_ext=args.image_ext, mask_ext=args.mask_ext,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,  # evaluation needs no shuffling; keep a stable order
        num_workers=args.num_workers,
        # Pinned host memory only speeds up host-to-CUDA copies; on the CPU
        # fallback path it is useless and PyTorch may emit a warning.
        pin_memory=device.type == 'cuda',
    )

    model = build_model(
        num_classes=args.num_classes,
        backbone_variant=args.backbone_variant,
        fpn_channels=args.fpn_channels,
        seg_inner_channels=args.seg_inner_channels,
        use_light_head=args.use_light_head,
    )
    model = load_model_weights(args.checkpoint, model)
    model = model.to(device)
    # Switch to inference mode explicitly rather than relying on Evaluator to
    # do it (it may already; eval() is idempotent). Assumes build_model returns
    # a torch nn.Module, as the .to(device) call above suggests.
    model.eval()

    evaluator = Evaluator(model, device, args.num_classes)
    results = evaluator.evaluate(dataloader)
    evaluator.print_results(results)
    return results
if __name__ == '__main__':
    # Only start an evaluation run when this file is executed as a script;
    # importing the module (e.g. from tests) must not trigger one.
    main()