-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathconv_visualisation.py
More file actions
81 lines (72 loc) · 2.47 KB
/
conv_visualisation.py
File metadata and controls
81 lines (72 loc) · 2.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import numpy as np
import matplotlib as mp
#%matplotlib inline
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow.contrib.slim as slim
import math
from utils import *
from PIL import Image
# Model creation.
# The graph must match the one used for training (layer names, filter counts,
# n_classes); otherwise restoring the checkpoint below fails.
# NOTE(review): tf.layers.conv2d defaults to data_format='channels_last', so this
# (None, 1, 7, 7) input is read as height=1, width=7, channels=7 rather than as a
# single-channel 7x7 image. Changing it would change the kernel shapes and break
# the checkpoint restore, so confirm intent before touching it.
x = tf.placeholder('float', shape=(None, 1, 7, 7), name='input_x')
# Not used anywhere in this script; presumably the label placeholder kept from
# the training code.
y = tf.placeholder('float')
n_classes=76
net = x
# First convolutional layer.
hidden_1 = tf.layers.conv2d(inputs=net, name='layer_conv1',
                            filters=49, kernel_size=3, strides=1,
                            padding='same', activation=tf.nn.relu)
# Second convolutional layer.
hidden_2 = tf.layers.conv2d(inputs=hidden_1, name='layer_conv2',
                            filters=49, kernel_size=3, strides=1,
                            padding='same', activation=tf.nn.relu)
# Pooling layer.
# `strides` is not passed, so it defaults to 1: with 'SAME' padding this max-pool
# keeps the spatial size instead of downsampling.
pool_1 = tf.nn.pool(input=hidden_2, name='pool', pooling_type='MAX',
                    padding='SAME', window_shape=(2, 2))
net = tf.contrib.layers.flatten(pool_1)
# training=False makes dropout return its input unchanged (presumably kept so
# the graph mirrors the training graph).
net = tf.layers.dropout(inputs=net, rate=0.4, training=False)
# Linear output layer (no activation), one unit per class.
output = tf.layers.dense(inputs=net, name='o', units=n_classes, activation=None)
# Create the TensorFlow session and restore the trained variables from file.
sess = tf.Session()
# Redundant: saver.restore below assigns every variable the Saver manages
# (all global variables, since no var_list is given).
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
# Path is relative to the working directory; the "-4" suffix is presumably the
# global step the checkpoint was saved at.
saver.restore(sess, "trained_conv_net-4")
print("Model restored.")
def getActivations(layer, stimuli):
    """Feed ``stimuli`` through the graph and display the value of ``layer``.

    Args:
        layer: tensor in the default graph whose output should be visualised.
        stimuli: array fed to the module-level ``x`` input placeholder.
    """
    feed = {x: stimuli}
    activations = sess.run(layer, feed_dict=feed)
    plotNNFilter(activations)
def plotNNFilter(units):
    """Display conv-layer activations as grayscale images.

    First shows one overview image with every spatial position of the first
    sample as a row and every filter as a column, then one image per filter
    showing that filter's (height, width) activation map.

    Args:
        units: activation array as returned by ``sess.run`` for a conv layer,
            indexed here as (batch, height, width, filters). Only the first
            batch element is displayed.
    """
    # ReLU activations are unbounded above; clip to [0, 1] so they map onto
    # the 0-255 grayscale range.
    units = np.clip(units, 0, 1)
    n_filters = units.shape[-1]

    # Overview: (height * width, filters) grid for the first sample. For the
    # (1, 1, 7, 49) activations produced by this script this is the same
    # (7, 49) layout as before.
    overview = units[0].reshape(-1, n_filters)
    # Convert to uint8 so PIL builds an 8-bit grayscale ('L') image rather
    # than a 32-bit float ('F') one.
    Image.fromarray((overview * 255).astype('uint8')).show()

    # One image per filter. Slice the filter's own (height, width) map; the
    # previous reshape to (width, filters) could not hold a single filter's
    # values and raised ValueError. `i` also avoids shadowing the global `x`.
    for i in range(n_filters):
        feature_map = units[0, :, :, i]
        Image.fromarray((feature_map * 255).astype('uint8')).show()
# Build a 7x7 stimulus with utils.generate_target and show how the first
# conv layer responds to it.
# NOTE(review): the meaning of generate_target's third argument (0.5) is not
# visible here; presumably a fill density — confirm in utils.
target = np.reshape(np.array(generate_target(7, 7, 0.5)), (7, 7))

# Hand-made stimuli that can replace the generated one when experimenting:
# target = np.zeros((7, 7))
# target[3:5, 3:5] = 1
# target[3, 3] = 1
# target[5, 3] = 1
# target[4, 3] = 1
# target[4, 4] = 1

# 8-bit grayscale preview of the stimulus itself (display currently disabled).
im = Image.fromarray((target * 255).astype('uint8'))
# im.show()

getActivations(hidden_1, target.reshape(-1, 1, 7, 7))