forked from elleryqueenhomels/fast_neural_style_transfer
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate.py
More file actions
65 lines (44 loc) · 2.51 KB
/
generate.py
File metadata and controls
65 lines (44 loc) · 2.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Use a trained Image Transform Net to generate
# style-transferred images in a specific style
import tensorflow as tf
import image_transform_net as itn
from utils import get_images, save_images
def generate(
    contents_path,
    model_path,
    is_same_size=False,
    resize_height=None,
    resize_width=None,
    save_path=None,
    prefix='stylized-',
    suffix=None,
):
    """Stylize one or more content images with a trained model.

    Args:
        contents_path: A single path string or a collection of paths. A
            single string is wrapped into a one-element list.
        model_path: Checkpoint location handed to the handlers for restoring
            the trained network.
        is_same_size: When true, all images are processed together in one
            batch.
        resize_height: Target height; only takes effect together with
            ``resize_width``.
        resize_width: Target width; only takes effect together with
            ``resize_height``.
        save_path: Forwarded to the handlers; when not ``None`` they save the
            results through ``save_images``.
        prefix: Forwarded to ``save_images`` by the handlers.
        suffix: Forwarded to ``save_images`` by the handlers.

    Returns:
        A list with one stylized result per input image.
    """
    paths = [contents_path] if isinstance(contents_path, str) else contents_path

    # Batch processing needs a common shape: either the caller promises it,
    # or both resize dimensions are supplied.
    has_full_resize = resize_height is not None and resize_width is not None

    if not (is_same_size or has_full_resize):
        # Images may differ in size, so they are run one at a time.
        return _handler2(paths, model_path, save_path=save_path, prefix=prefix, suffix=suffix)

    batched = _handler1(
        paths,
        model_path,
        resize_height=resize_height,
        resize_width=resize_width,
        save_path=save_path,
        prefix=prefix,
        suffix=suffix,
    )
    # Convert the batched result into a list so both branches return a list.
    return list(batched)
def _handler1(
    content_path,
    model_path,
    resize_height=None,
    resize_width=None,
    save_path=None,
    prefix=None,
    suffix=None,
):
    """Stylize all content images in a single batched session run.

    Args:
        content_path: Paths of the content images, passed to ``get_images``.
        model_path: Checkpoint location restored with ``tf.train.Saver``.
        resize_height: Height forwarded to ``get_images``.
        resize_width: Width forwarded to ``get_images``.
        save_path: When not ``None``, results are written via ``save_images``.
        prefix: Forwarded to ``save_images``.
        suffix: Forwarded to ``save_images``.

    Returns:
        The value produced by running the transform network on the whole
        batch.
    """
    # Load the image data up front; expected shape is
    # (num_images, height, width, color_channels).
    batch = get_images(content_path, resize_height, resize_width)

    with tf.Graph().as_default():
        with tf.Session() as sess:
            # Build the dataflow graph with a placeholder fixed to the
            # loaded batch's exact shape.
            batch_input = tf.placeholder(tf.float32, shape=batch.shape, name='content_image')
            stylized_op = itn.transform(batch_input)

            # Restore the trained weights, then run the style transfer.
            tf.train.Saver().restore(sess, model_path)
            feed = {batch_input: batch}
            stylized = sess.run(stylized_op, feed_dict=feed)

    if save_path is None:
        return stylized

    save_images(content_path, stylized, save_path, prefix=prefix, suffix=suffix)
    return stylized
def _handler2(content_path, model_path, save_path=None, prefix=None, suffix=None):
    """Stylize content images one at a time, allowing differing sizes.

    Args:
        content_path: Iterable of content image paths; each is loaded
            separately with ``get_images``.
        model_path: Checkpoint location restored with ``tf.train.Saver``.
        save_path: When not ``None``, results are written via ``save_images``.
        prefix: Forwarded to ``save_images``.
        suffix: Forwarded to ``save_images``.

    Returns:
        A list holding, for each input path, the first element of the
        network's output for that image.
    """
    with tf.Graph().as_default():
        with tf.Session() as sess:
            # Build the dataflow graph: batch of one, unconstrained height
            # and width, three channels.
            single_input = tf.placeholder(tf.float32, shape=(1, None, None, 3), name='content_image')
            stylized_op = itn.transform(single_input)

            # Restore the trained weights once; the graph is reused for
            # every image.
            tf.train.Saver().restore(sess, model_path)

            def _stylize(path):
                # Load one image and keep only the first item of the result.
                image = get_images(path)
                return sess.run(stylized_op, feed_dict={single_input: image})[0]

            stylized = [_stylize(path) for path in content_path]

    if save_path is None:
        return stylized

    save_images(content_path, stylized, save_path, prefix=prefix, suffix=suffix)
    return stylized