-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathdata_provider.py
More file actions
38 lines (31 loc) · 1.27 KB
/
data_provider.py
File metadata and controls
38 lines (31 loc) · 1.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import os
import tensorflow as tf
# Number of label classes; the integer label is one-hot encoded to this width.
num_classes = 2

# Directory containing the TFRecord files and the per-split filename template
# (split "test" -> 'FACE_test.tfrecord').
dataset_dir = 'data'
_FILE_PATTERN = 'FACE_%s.tfrecord'

# Feature schema used to parse each serialized tf.train.Example record.
keys_to_features = {
    'image/encoded': tf.FixedLenFeature(
        shape=(), dtype=tf.string, default_value=''),
    'image/format': tf.FixedLenFeature(
        shape=(), dtype=tf.string, default_value='raw'),
    'image/class/label': tf.FixedLenFeature(
        shape=[1], dtype=tf.int64),
}

# TFRecordReader op, created once in the default graph when this module is
# imported.
reader = tf.TFRecordReader()
def get_data(split_name, image_size=64):
    """Build TF1 queue-based input ops that yield one (image, label) example.

    The records are read from ``os.path.join(dataset_dir,
    _FILE_PATTERN % split_name)``, e.g. ``data/FACE_test.tfrecord`` for
    ``split_name="test"``.

    Args:
        split_name: Split name substituted into ``_FILE_PATTERN``.
        image_size: Side length of the square decoded image. The image is
            flattened to ``image_size * image_size`` values. Defaults to 64,
            the size that was previously hard-coded.

    Returns:
        A tuple ``(image, label)`` of tensors:
        ``image`` is a float32 vector of length ``image_size * image_size``
        with pixel values rescaled from [0, 1] to [-1, 1];
        ``label`` is a float32 one-hot vector of length ``num_classes``.
    """
    file_pattern = os.path.join(dataset_dir, _FILE_PATTERN % split_name)
    filename_queue = tf.train.string_input_producer([file_pattern])

    # A TFRecordReader is stateful: it keeps reading its current file and only
    # dequeues a new filename once that file is exhausted. Sharing the single
    # module-level reader across calls would let, e.g., the "test" read op
    # return records from the "train" file. Each call therefore gets its own
    # reader.
    split_reader = tf.TFRecordReader()
    _, serialized_example = split_reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
                                       features=keys_to_features)

    image = tf.image.decode_png(features['image/encoded'])

    # The label is stored with shape [1]; one_hot yields shape
    # [1, num_classes], which is flattened to a plain one-hot vector.
    label = tf.one_hot(features['image/class/label'], num_classes)
    label = tf.reshape(label, shape=(num_classes,))
    print("label:", label)

    # uint8 -> float32 in [0, 1], then shifted/scaled to [-1, 1].
    image = tf.image.convert_image_dtype(image, tf.float32)
    image -= 0.5
    image *= 2
    # NOTE(review): assumes each stored PNG holds exactly
    # image_size * image_size values (single-channel, square); the reshape
    # fails at run time otherwise.
    image = tf.reshape(image, shape=(image_size * image_size,))
    print(image, label)
    return (image, label)
#test_image, test_label = get_data("test")
#print (test_image)
#print (test_label)