forked from LongMa319/SCI
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
62 lines (48 loc) · 1.93 KB
/
test.py
File metadata and controls
62 lines (48 loc) · 1.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import sys
import numpy as np
import torch
import argparse
import torch.utils
import torch.backends.cudnn as cudnn
from PIL import Image
from torch.autograd import Variable
from model import Finetunemodel
from multi_read_data import MemoryFriendlyLoader
# Command-line options for the SCI test script.
parser = argparse.ArgumentParser("SCI")
parser.add_argument('--data_path', type=str, default='./data/medium',
                    help='location of the data corpus')
# The next two help texts used to repeat the --data_path description verbatim,
# which made `--help` misleading; they now describe what each option holds.
parser.add_argument('--save_path', type=str, default='./results/medium',
                    help='directory where the output images are saved')
parser.add_argument('--model', type=str, default='./weights/medium.pt',
                    help='path to the model weights file')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
# NOTE(review): --seed is parsed here, but nothing in this section seeds a
# random number generator with it.
parser.add_argument('--seed', type=int, default=2, help='random seed')

# Arguments are parsed at import time; the rest of the script reads `args`
# and `save_path` as module-level globals.
args = parser.parse_args()
save_path = args.save_path
os.makedirs(save_path, exist_ok=True)

# One image per batch, loaded in the main process (num_workers=0).
TestDataset = MemoryFriendlyLoader(img_dir=args.data_path, task='test')
test_queue = torch.utils.data.DataLoader(
    TestDataset, batch_size=1,
    pin_memory=True, num_workers=0)
def save_images(tensor, path):
    """Write the first image of a batch tensor to ``path`` as a PNG.

    Args:
        tensor: Batched image tensor; only ``tensor[0]`` is saved. That
            element is transposed from axis order (0, 1, 2) to (1, 2, 0),
            so it is assumed to be channel-first (C, H, W). Values are
            multiplied by 255 and clipped to [0, 255], i.e. they are
            presumably normalised to [0, 1] -- confirm against the model.
        path: Destination file path.
    """
    # detach() makes this safe for tensors still attached to an autograd
    # graph; numpy() raises on tensors that require grad.
    image_numpy = tensor[0].detach().cpu().float().numpy()
    image_numpy = np.transpose(image_numpy, (1, 2, 0))
    pixels = np.clip(image_numpy * 255.0, 0, 255.0).astype('uint8')
    if pixels.shape[2] == 1:
        # PIL cannot build an image from an (H, W, 1) array; drop the
        # singleton channel axis so the image is written as grayscale.
        pixels = pixels[:, :, 0]
    Image.fromarray(pixels).save(path, 'png')
def main():
    """Run the model over every test image and save the results as PNGs.

    Exits with status 1 when CUDA is unavailable. Otherwise builds
    ``Finetunemodel`` from ``args.model``, moves it to the GPU selected by
    ``args.gpu``, feeds each batch from ``test_queue`` through it, and writes
    the second value returned by the model to ``save_path`` as
    ``<input file stem>.png``.
    """
    if not torch.cuda.is_available():
        print('no gpu device available')
        sys.exit(1)

    # Apply the --gpu option; it was parsed but previously never used, so the
    # script always ran on the default device.
    torch.cuda.set_device(args.gpu)

    model = Finetunemodel(args.model)
    model = model.cuda()
    model.eval()

    with torch.no_grad():
        for batch, image_name in test_queue:
            # torch.no_grad() already disables autograd; the old
            # Variable(..., volatile=True) wrapper was a deprecated no-op
            # that only produced a warning.
            batch = batch.cuda()

            # Accept both Windows and POSIX separators. The previous
            # split('\\') / split('.')[0] pair returned an empty stem for a
            # POSIX path such as './data/x/1.png', so every output was
            # written to the same '.png' file.
            file_name = image_name[0].replace('\\', '/').rsplit('/', 1)[-1]
            stem = os.path.splitext(file_name)[0]

            # Only the second value returned by the model is saved.
            _, result = model(batch)

            u_name = '%s.png' % stem
            print('processing {}'.format(u_name))
            save_images(result, os.path.join(save_path, u_name))
# Run the test pass only when this file is executed as a script.
if __name__ == "__main__":
    main()