forked from elleryqueenhomels/fast_neural_style_transfer
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
69 lines (44 loc) · 2.08 KB
/
main.py
File metadata and controls
69 lines (44 loc) · 2.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Demo - train the style transfer network & use it to generate an image
from __future__ import print_function
from train import train
from generate import generate
from utils import list_images
# Mode switch read by main(): True -> train a network per style,
# False -> generate stylized images with already-trained models.
IS_TRAINING = False

# Weights file handed to train() as `vgg_path`
# (presumably pre-trained VGG-19 weights -- confirm against train.py).
VGG_PATH = './imagenet-vgg-19-weights.npz'

# Per-style loss weights, keyed by style name.
# format: {'style': [content_weight, style_weight, tv_weight]}
# The style name is also used by main() to build the style image path
# ('images/style/<name>.jpg') and the checkpoint path ('models/<name>.ckpt-done').
STYLES = {
    'wave':            [1.0,  7.0, 1e-2],
    'udnie':           [1.0, 12.0, 1e-2],
    'escher_sphere':   [1.0, 60.0, 1e-2],
    'flower':          [1.0, 10.0, 1e-2],
    'scream':          [1.0, 60.0, 1e-2],
    'denoised_starry': [1.0, 16.0, 1e-2],
    'starry_bright':   [1.0,  6.0, 1e-2],
    'rain_princess':   [1.0,  8.0, 1e-2],
    'woman_matisse':   [1.0, 20.0, 1e-2],
    'mosaic':          [1.0,  5.0, 0.0],
}
def _train_all_styles():
    """Train one network per entry in STYLES.

    For each style, the style image path is 'images/style/<style>.jpg' and
    the checkpoint path passed to train() is 'models/<style>.ckpt-done'.
    """
    # The training set is listed once and shared by every style.
    content_targets = list_images('./MS_COCO')  # path to training dataset

    for style, (content_weight, style_weight, tv_weight) in STYLES.items():
        print('\nBegin to train the network with the style "%s"...\n' % style)

        style_target = 'images/style/' + style + '.jpg'
        model_save_path = 'models/' + style + '.ckpt-done'

        train(content_targets, style_target, content_weight, style_weight, tv_weight,
              vgg_path=VGG_PATH, save_path=model_save_path, debug=True)

        print('\nSuccessfully! Done training style "%s"...\n' % style)

    print('Successfully finish all the training...\n')


def _generate_all_styles():
    """Run generate() on the content images once per entry in STYLES.

    Each call receives the checkpoint 'models/<style>.ckpt-done',
    save_path='outputs' and prefix='<style>-'.
    """
    # Both values are identical for every style, so compute them once
    # (mirrors how the training path lists its dataset a single time).
    # NOTE(review): assumes generate() does not mutate the list it is
    # given -- confirm against generate.py.
    output_save_path = 'outputs'
    content_targets = list_images('images/content')

    for style in STYLES:
        print('\nBegin to generate pictures with the style "%s"...\n' % style)

        model_path = 'models/' + style + '.ckpt-done'

        generated_images = generate(content_targets, model_path,
                                    save_path=output_save_path,
                                    prefix=style + '-')

        # Diagnostic output describing what generate() returned.
        print('\ntype(generated_images):', type(generated_images))
        print('\nlen(generated_images):', len(generated_images), '\n')


def main():
    """Run the demo: train every style if IS_TRAINING, otherwise generate."""
    if IS_TRAINING:
        _train_all_styles()
    else:
        _generate_all_styles()
# Script entry point: run the demo only when this file is executed directly.
if __name__ == '__main__':
    main()