-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_neck_data.py
More file actions
executable file
·115 lines (81 loc) · 3.67 KB
/
train_neck_data.py
File metadata and controls
executable file
·115 lines (81 loc) · 3.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#!/usr/bin/python
import tensorflow as tf
import tensorflow.contrib.slim.nets as nets
import numpy as np
from tensorflow.python.tools import inspect_checkpoint as chkp
from sklearn.metrics import average_precision_score
# Fine-tune VGG-16 (ImageNet-pretrained conv layers, fresh fc head) on
# neck-ultrasound images for a 2-class problem, then report average
# precision on a held-out validation split.

# Fraction of the dataset held out for validation.
VAL_SIZE = 0.05

# Pre-extracted 224x224 grayscale images and their binary label vector.
images = np.load('/home/yasaman/HN/neck_images.npy')
labels = np.load('/home/yasaman/HN/neck_labels_cor.npy')

slim = tf.contrib.slim
vgg = nets.vgg

# Expand the binary label vector into two-column one-hot rows [y, 1-y].
labels = np.transpose(np.asarray([labels, 1 - labels]))
images = np.reshape(images, [-1, 224, 224, 1])
# images are black and white but vgg16 needs 3 channels
images = np.repeat(images, 3, axis=3)

# Sample a validation split without replacement and remove those rows from
# the training arrays (np.delete accepts the indices in any order, so no
# sorting is needed).
test_idx = np.random.choice(images.shape[0], int(VAL_SIZE * images.shape[0]),
                            replace=False)
test_images = images[test_idx]
test_labels = labels[test_idx]
images = np.delete(images, test_idx, axis=0)
labels = np.delete(labels, test_idx, axis=0)
# removed validation set from training set
print(sum(labels[:, 0]), images.shape, "val set", test_images.shape)

# Feed the large arrays through placeholders rather than baking them into
# the graph as constants.
in_images = tf.placeholder(images.dtype, images.shape)
in_labels = tf.placeholder(labels.dtype, labels.shape)
val_images = tf.placeholder(test_images.dtype, test_images.shape)
val_labels = tf.placeholder(test_labels.dtype, test_labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((in_images, in_labels))
dataset = dataset.shuffle(6000)
dataset = dataset.batch(64)
dataset = dataset.repeat()

# The whole validation split is evaluated as a single batch.
val_set = tf.data.Dataset.from_tensor_slices((val_images, val_labels))
val_set = val_set.batch(test_images.shape[0])

# A reinitializable iterator lets the same graph switch between the
# training and validation datasets via the two init ops below.
iterator = tf.data.Iterator.from_structure(dataset.output_types,
                                           dataset.output_shapes)
next_images, next_labels = iterator.get_next()

# NOTE(review): vgg_16 defaults to is_training=True, so dropout in the fc
# layers is also active when the validation probabilities are computed at
# the end — confirm whether that is intended.
logits, intermed = vgg.vgg_16(next_images, num_classes=2)
prob = tf.nn.softmax(logits)

train_init_op = iterator.make_initializer(dataset)
val_init_op = iterator.make_initializer(val_set)

loss = tf.losses.softmax_cross_entropy(next_labels, logits)
total_loss = tf.losses.get_total_loss()
tf.summary.scalar('xentropy loss', total_loss)
train = tf.train.GradientDescentOptimizer(0.001).minimize(total_loss)

img_net_path = '/home/yasaman/HN/image_net_trained/vgg_16.ckpt'
us_path = '/home/yasaman/HN/neck_us_trained'

# Inspect the pretrained checkpoint (tensor names only, no values).
chkp.print_tensors_in_checkpoint_file(img_net_path, tensor_name='',
                                      all_tensors=False,
                                      all_tensor_names=True)

# Restore only the convolutional layers from the ImageNet checkpoint; the
# fully connected head (fc6-fc8) is retrained from scratch because the
# checkpoint's head has 1000 classes, not 2.
scratch_variables = ['vgg_16/fc6', 'vgg_16/fc7', 'vgg_16/fc8']
restored_variables = tf.contrib.framework.get_variables_to_restore(
    exclude=scratch_variables)
print("restored variables ....", restored_variables)
restorer = tf.train.Saver(restored_variables)
saver = tf.train.Saver()

# Collect the fc-layer variables so they can be freshly initialized.
detailed_scratch_vars = []
for layer in scratch_variables:
    detailed_scratch_vars.extend(tf.contrib.framework.get_variables(scope=layer))
print("fc layers.....", detailed_scratch_vars)
init_scratch = tf.variables_initializer(detailed_scratch_vars)

with tf.Session() as sess:
    restorer.restore(sess, img_net_path)
    sess.run(init_scratch)
    sess.run(train_init_op, feed_dict={in_images: images, in_labels: labels})
    merged_summaries = tf.summary.merge_all()
    writer = tf.summary.FileWriter('/home/yasaman/HN/run_log/', sess.graph)
    for step in range(5000):
        if step % 100 == 0:
            # Fetch the loss and summaries in the SAME sess.run as the
            # train op: separate sess.run calls each pull a fresh batch
            # from the iterator, so the logged loss would describe a
            # different batch than the one trained on (and extra batches
            # would be silently consumed at every logging step).
            _, loss_val, summ = sess.run([train, total_loss, merged_summaries])
            writer.add_summary(summ, step)
            print(loss_val)
        else:
            sess.run(train)
    saver.save(sess, us_path)

    # Switch the iterator to the validation split and score it.
    sess.run(val_init_op, feed_dict={val_images: test_images,
                                     val_labels: test_labels})
    probabilities = sess.run(prob)
    writer.close()
    avg_prec_score = average_precision_score(test_labels, probabilities)
    print("average precision score on validation set", avg_prec_score)