-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
64 lines (44 loc) · 2.38 KB
/
test.py
File metadata and controls
64 lines (44 loc) · 2.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import argparse
import os
import pprint
from pathlib import Path
import torch
from plt2pix import plt2pix
def parse_args(argv=None):
    """Parse command-line options for plt2pix inference.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line (``sys.argv[1:]``). ``None`` (the default) keeps the
            original behavior of reading the process arguments; passing a
            list makes the function usable from code and tests.

    Returns:
        argparse.Namespace holding the parsed options.

    Raises:
        SystemExit: raised by argparse on invalid options or when the
            required ``--load`` option is missing.
    """
    desc = "plt2pix: Line Art Colorization using palettes"
    parser = argparse.ArgumentParser(description=desc)

    # --- Network architecture and weights ---
    parser.add_argument('--model', type=str, default='tag2pix',
                        choices=['tag2pix', 'senet', 'resnext', 'catconv',
                                 'catall', 'adain', 'seadain'],
                        help='Model Types. (default: tag2pix == SECat)')
    parser.add_argument('--cpu', action='store_true', help='If set, use cpu only')
    # A single integer sets both width and height, so the input is square.
    parser.add_argument('--input_size', type=int, default=256,
                        help='Width / Height of input image (must be square)')
    # required=True means a default value could never be used, so none is given.
    parser.add_argument('--load', type=str, required=True,
                        help='Path to load network weights')
    parser.add_argument('--load_crn', type=str, default="",
                        help='Path to load CRN network weights')
    parser.add_argument('--color_space', type=str, default='rgb',
                        choices=['lab', 'rgb', 'hsv'], help='color space of images')
    parser.add_argument('--layers', type=int, nargs='+', default=[12, 8, 5, 5],
                        help='Block counts of each U-Net Decoder blocks of generator. '
                             'The first argument is count of bottom block.')
    parser.add_argument('--use_relu', action='store_true', help='Apply ReLU to colorFC')
    parser.add_argument('--no_bn', action='store_true',
                        help='Remove every BN Layer from Generator')
    parser.add_argument('--no_guide', action='store_true',
                        help='Remove guide decoder from Generator. If set, Generator '
                             'will return same G_f: like (G_f, G_f)')

    # --- Palette based colorization ---
    parser.add_argument('--palette_num', type=int, default=5,
                        help='Number of palette colors')

    return parser.parse_args(argv)
def main(input_path='testset/000099.png', output_path='test_result.png'):
    """Colorize one example line-art image with a fixed palette and save it.

    Builds a plt2pix model from the command-line options, colorizes
    ``input_path`` with a hard-coded palette, and writes the result image
    to ``output_path``.

    Args:
        input_path: Path of the line-art image to colorize.
        output_path: Path where the colorized image is written.
    """
    args = parse_args()
    model = plt2pix(args)

    # Example colorization; these imports are only needed for this demo run.
    import numpy as np
    from PIL import Image

    # Flattened palette: (R, G, B, R, G, B, ..., R, G, B), one triple per color.
    # These 15 values give 5 colors, matching the --palette_num default.
    # NOTE(review): presumably must hold 3 * args.palette_num values — confirm
    # against plt2pix.colorize.
    # The original file notes this as the "Default palette":
    #   palette = np.array([-255] * 15)
    palette = np.array([201, 207, 216, 83, 127, 164, 124, 156, 189,
                        116, 132, 146, 124, 140, 148])

    # Image.open opens the file lazily and keeps the handle open; the context
    # manager closes it once colorization has consumed the image.
    with Image.open(input_path) as img:
        output = model.colorize(img, palette)

    Image.fromarray(output).save(output_path)

    # Alternative: let the model suggest a palette instead of a fixed one.
    #   palette = model.recommend_color(img)
    #   output = model.colorize(img, palette)
    #   Image.fromarray(output).save('recommend_result.png')


if __name__ == '__main__':
    main()