-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtest.py
More file actions
81 lines (62 loc) · 2.26 KB
/
test.py
File metadata and controls
81 lines (62 loc) · 2.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import numpy as np
import cv2
import torch
import torch.nn as nn
# Run on the GPU when one is available and fall back to the CPU otherwise, so the
# module (and the helpers below) can be imported on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Six 3x3 convolutions with padding 1 (spatial size preserved) and LeakyReLU(0.1),
# then a kernel-4 / stride-2 / padding-1 transposed convolution, which doubles
# height and width: out = (in - 1) * 2 - 2 + 4 = 2 * in. A final sigmoid squashes
# the 3 output channels into [0, 1].
model = nn.Sequential(
    nn.Conv2d(3, 16, (3, 3), padding=(1, 1)),
    nn.LeakyReLU(0.1),
    nn.Conv2d(16, 32, (3, 3), padding=(1, 1)),
    nn.LeakyReLU(0.1),
    nn.Conv2d(32, 64, (3, 3), padding=(1, 1)),
    nn.LeakyReLU(0.1),
    nn.Conv2d(64, 128, (3, 3), padding=(1, 1)),
    nn.LeakyReLU(0.1),
    nn.Conv2d(128, 128, (3, 3), padding=(1, 1)),
    nn.LeakyReLU(0.1),
    nn.Conv2d(128, 256, (3, 3), padding=(1, 1)),
    nn.LeakyReLU(0.1),
    nn.ConvTranspose2d(256, 3, (4, 4), (2, 2), (1, 1), (0, 0), bias=False),
    nn.Sigmoid(),
).to(device)

# map_location lets a checkpoint that was saved from GPU tensors load on a
# CPU-only host (and vice versa).
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
weights = torch.load("weights.pth", map_location=device)
model.load_state_dict(weights)
def img_to_input(img, device="cuda"):
    """Convert a channels-last image into a (1, C, H, W) float32 tensor in [0, 1].

    Args:
        img: array-like image laid out as H x W x C, e.g. from ``cv2.imread``.
            Pixel values are divided by 255, so 8-bit (0-255) data is assumed.
        device: device the tensor is moved to. Defaults to ``"cuda"`` (the
            previous hard-coded behavior); pass ``"cpu"`` or a ``torch.device``
            to run without a GPU.

    Returns:
        A float32 tensor of shape (1, C, H, W) on ``device``.
    """
    scaled = np.asarray(img, dtype=np.float64) / 255.0
    # HWC -> CHW, then prepend a batch dimension of size 1.
    batched = np.transpose(scaled, (2, 0, 1))[np.newaxis]
    # Contiguous buffer: torch.from_numpy cannot wrap arbitrary strided views.
    return torch.from_numpy(np.ascontiguousarray(batched)).float().to(device)
def output_to_img(output):
    """Turn a (1, C, H, W) model output tensor into an H x W x C float64 array.

    The tensor is detached from autograd and moved to the CPU. Values are
    returned unscaled. A leading dimension other than 1 raises ``ValueError``.
    """
    host = output.detach().cpu().numpy().astype(np.float64)
    chw = np.squeeze(host, axis=0)
    # CHW -> HWC, the channels-last layout OpenCV works with.
    return np.transpose(chw, (1, 2, 0))
def _read_image(path):
    """Load a BGR image with OpenCV, raising if it cannot be read.

    ``cv2.imread`` returns ``None`` instead of raising on a missing or
    unreadable file, which would otherwise fail later with an obscure error.
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError("could not read image: {}".format(path))
    return img


def _run_model_bgr(img_bgr):
    """Run the model on a BGR image; return (net input, net output, BGR result).

    Channels are reversed before inference and reversed back afterwards, the
    convention the training-pair viewer used originally.
    NOTE(review): assumes the model was trained on RGB images -- confirm.
    """
    net_in = img_to_input(img_bgr[:, :, ::-1])
    net_out = model(net_in)
    return net_in, net_out, output_to_img(net_out)[:, :, ::-1]


def _show(*named_images):
    """Show (window name, image) pairs 500 px apart, wait for a key, then close."""
    for slot, (name, image) in enumerate(named_images):
        cv2.imshow(name, image)
        cv2.moveWindow(name, 500 * slot, 0)
    cv2.waitKey()
    cv2.destroyAllWindows()


if __name__ == "__main__":
    model.eval()
    with torch.no_grad():
        # Sanity check on a single sample image. Its channel order now matches
        # the training-pair loop below (previously BGR was fed in unflipped).
        img_in = _read_image("images/28.jpg")
        _, _, img_out = _run_model_bgr(img_in)
        _show(("input", img_in), ("output", img_out))

        # Step through the training pairs: input, model output, expected output.
        for i in range(1, 2579):
            img_in = _read_image("train/inputs/input_{:04d}.bmp".format(i))
            net_in, net_out, img_out = _run_model_bgr(img_in)
            print(img_in.shape)
            print(net_in.shape)
            print(net_out.shape)
            print(img_out.shape)
            target = _read_image("train/outputs/output_{:04d}.png".format(i))
            _show(("input", img_in), ("output", img_out), ("target", target))