-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathcapsule_layers.py
More file actions
300 lines (248 loc) · 11.9 KB
/
capsule_layers.py
File metadata and controls
300 lines (248 loc) · 11.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
'''
Diagnosing Colorectal Polyps in the Wild with Capsule Networks (D-Caps)
Original Paper by Rodney LaLonde, Pujan Kandel, Concetto Spampinato, Michael B. Wallace, and Ulas Bagci
Paper published at ISBI 2020: arXiv version (https://arxiv.org/abs/2001.03305)
Code written by: Rodney LaLonde
If you use significant portions of this code or the ideas from our paper, please cite it :)
If you have any questions, please email me at lalonde@knights.ucf.edu.
These are all the capsule layers needed for D-Caps.
'''
import keras.backend as K
import tensorflow as tf
from keras import initializers, layers
from keras.utils.conv_utils import conv_output_length
import numpy as np
import math
class ExpandDim(layers.Layer):
    """Insert a length-1 axis immediately before the last axis of the input.

    A tensor of shape (..., d) comes out with shape (..., 1, d).
    """

    def call(self, inputs, **kwargs):
        # axis=-2 puts the new singleton axis in front of the trailing axis.
        expanded = K.expand_dims(inputs, axis=-2)
        return expanded

    def compute_output_shape(self, input_shape):
        leading, trailing = input_shape[:-1], input_shape[-1:]
        return leading + (1,) + trailing

    def get_config(self):
        # No hyper-parameters of its own: a copy of the base-layer config is all there is.
        return dict(super(ExpandDim, self).get_config())
class RemoveDim(layers.Layer):
    """Squeeze out the length-1 axis that sits just before the last axis of the input.

    Counterpart of ``ExpandDim``: a tensor of shape (..., 1, d) comes out with shape (..., d).
    """

    def call(self, inputs, **kwargs):
        squeezed = K.squeeze(inputs, axis=-2)
        return squeezed

    def compute_output_shape(self, input_shape):
        # Keep every entry except the second-to-last one.
        return input_shape[:-2] + input_shape[-1:]

    def get_config(self):
        # No hyper-parameters of its own: a copy of the base-layer config is all there is.
        return dict(super(RemoveDim, self).get_config())
class Length(layers.Layer):
    """
    Compute the Euclidean length (norm over the last axis) of each capsule vector.

    The result has the same shape as ``y_true`` in margin_loss, so this layer can serve as a model output and labels
    can be read off with ``y_pred = np.argmax(model.predict(x), 1)``.

    inputs: shape=[None, num_vectors, dim_vector] or shape=[None, height, width, 1, dim_vector]
    output: the input shape without its last axis, e.g. shape=[None, num_vectors]

    Args:
        num_classes: number of classes. A two-class problem is stored as 1, i.e. a single output capsule.
        seg: accepted for backward compatibility; not used by this layer.
    """

    def __init__(self, num_classes, seg=True, **kwargs):
        super(Length, self).__init__(**kwargs)
        # Binary problems collapse to a single capsule.
        self.num_classes = 1 if num_classes == 2 else num_classes

    def call(self, inputs, **kwargs):
        static_shape = inputs.get_shape()
        # Either the capsule axis (-2) holds exactly num_classes capsules, or the input must be a plain
        # (batch, atoms) matrix.
        if static_shape[-2].value != self.num_classes:
            assert static_shape.ndims == 2, 'Error: Must have num_capsules = num_classes going into Length else have dimensions (batch size, atoms)'
        return tf.norm(inputs, axis=-1)

    def compute_output_shape(self, input_shape):
        # The norm removes the last (atom) axis.
        return input_shape[:-1]

    def get_config(self):
        config = dict(super(Length, self).get_config())
        config['num_classes'] = self.num_classes
        return config
class Mask(layers.Layer):
    """
    Mask capsule vectors and flatten the result.

    With a single input of shape [None, num_capsule, dim_vector], every capsule except the one with the largest
    length is zeroed. With a list ``[inputs, mask]`` the supplied mask (e.g. one-hot true labels of shape
    [None, num_capsule]) is used instead. Either way the masked tensor is flattened so that every non-batch axis is
    merged into one, e.g. [None, num_capsule * dim_vector]; ``compute_output_shape`` reports that flattened shape.

    For example:
    ```
    x = keras.layers.Input(shape=[3, 2])  # each sample holds 3 capsules with dim_vector=2
    y = keras.layers.Input(shape=[3])     # true labels: 3 classes, one-hot coded
    out = Mask()(x)                       # out.shape=[None, 6]
    # or
    out2 = Mask()([x, y])                 # out2.shape=[None, 6], masked with the true labels y
    ```

    Args:
        resize_masks: stored in the layer config but not used by this implementation.
    """

    def __init__(self, resize_masks=False, **kwargs):
        super(Mask, self).__init__(**kwargs)
        self.resize_masks = resize_masks

    def call(self, inputs, **kwargs):
        if isinstance(inputs, (list, tuple)):
            # Explicit mask supplied alongside the capsules.
            if len(inputs) != 2:
                raise ValueError(
                    'Mask expects [inputs, mask] when given a list, got {} tensors.'.format(len(inputs)))
            inputs, mask = inputs
        else:
            # Keep only the longest capsule along axis 1; needs a static size for that axis.
            x = K.sqrt(K.sum(K.square(inputs), -1))
            mask = K.one_hot(indices=K.argmax(x, 1), num_classes=x.get_shape().as_list()[1])
        # The mask lacks the atom axis, so add it back to broadcast over the atoms.
        masked = K.batch_flatten(inputs * K.expand_dims(mask, -1))
        return masked

    def compute_output_shape(self, input_shape):
        if isinstance(input_shape[0], (tuple, list)):  # [inputs_shape, mask_shape] provided
            input_shape = input_shape[0]
        # call() always batch-flattens, so every non-batch axis collapses into one. That size is unknown
        # (None) if any of those axes is unknown.
        non_batch = tuple(input_shape[1:])
        flat_dim = None if None in non_batch else int(np.prod(non_batch))
        return (input_shape[0], flat_dim)

    def get_config(self):
        config = {'resize_masks': self.resize_masks}
        base_config = super(Mask, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
class ConvCapsuleLayer(layers.Layer):
    """Convolutional capsule layer with dynamic routing.

    Each input capsule type is convolved with the same kernel ``W`` (input capsule types are folded into the batch
    axis before a single ``K.conv2d``). This produces votes for ``num_capsule`` output capsule types of ``num_atoms``
    atoms each, which ``_update_routing`` then combines over ``routings`` iterations.

    Input:  5-D tensor [batch, input_height, input_width, input_num_capsule, input_num_atoms].
    Output: [batch, new_height, new_width, num_capsule, num_atoms], where the spatial size follows the usual
            convolution arithmetic for ``kernel_size`` / ``strides`` / ``padding``.

    Args:
        kernel_size: side length of the square convolution kernel.
        num_capsule: number of output capsule types.
        num_atoms: number of atoms (vector length) of each output capsule.
        strides: convolution stride, used for both spatial axes.
        padding: padding mode passed to ``K.conv2d`` (e.g. 'same' or 'valid').
        routings: number of routing iterations.
        leaky_routing: if True, routing uses ``_leaky_routing`` instead of a plain softmax.
        kernel_initializer: initializer for ``W``; resolved with ``keras.initializers.get``.
    """
    def __init__(self, kernel_size, num_capsule, num_atoms, strides=1, padding='same', routings=3, leaky_routing=False,
                 kernel_initializer='he_normal', **kwargs):
        super(ConvCapsuleLayer, self).__init__(**kwargs)
        self.kernel_size = kernel_size
        self.num_capsule = num_capsule
        self.num_atoms = num_atoms
        self.strides = strides
        self.padding = padding
        self.routings = routings
        self.leaky_routing = leaky_routing
        self.kernel_initializer = initializers.get(kernel_initializer)

    def build(self, input_shape):
        """Create the shared kernel ``W`` and the per-output-capsule bias ``b``."""
        assert len(input_shape) == 5, "The input Tensor should have shape=[None, input_height, input_width," \
                                      " input_num_capsule, input_num_atoms]"
        self.input_height = input_shape[1]
        self.input_width = input_shape[2]
        self.input_num_capsule = input_shape[3]
        self.input_num_atoms = input_shape[4]

        # Transform matrix: maps input_num_atoms channels to num_capsule * num_atoms vote channels.
        self.W = self.add_weight(shape=[self.kernel_size, self.kernel_size,
                                        self.input_num_atoms, self.num_capsule * self.num_atoms],
                                 initializer=self.kernel_initializer,
                                 name='W')

        # One bias vector per output capsule type; tiled over the output spatial grid in call().
        self.b = self.add_weight(shape=[1, 1, self.num_capsule, self.num_atoms],
                                 initializer=initializers.constant(0.1),
                                 name='b')

        self.built = True

    def call(self, input_tensor, training=None):
        """Compute votes by convolution and route them to the output capsules.

        ``training`` is accepted but not used.
        """
        # Dynamic shape (for reshapes) and static height/width (for set_shape below).
        input_shape = K.shape(input_tensor)
        _, in_height, in_width, _, _ = input_tensor.get_shape()

        # [batch, h, w, in_caps, in_atoms] -> [batch, in_caps, h, w, in_atoms] -> [batch * in_caps, h, w, in_atoms],
        # so every input capsule type is convolved with the same kernel W.
        input_transposed = tf.transpose(input_tensor, [0, 3, 1, 2, 4])
        input_tensor_reshaped = K.reshape(input_transposed, [
            input_shape[0] * input_shape[3], input_shape[1], input_shape[2], self.input_num_atoms])
        # Re-attach the static spatial size; the reshape above is built from dynamic dims.
        input_tensor_reshaped.set_shape((None, in_height.value, in_width.value, self.input_num_atoms))

        conv = K.conv2d(input_tensor_reshaped, self.W, (self.strides, self.strides),
                        padding=self.padding, data_format='channels_last')

        votes_shape = K.shape(conv)
        _, conv_height, conv_width, _ = conv.get_shape()
        # Reshape back to 6D by splitting the first dimension into batch and input_dim
        # and splitting the last dimension into output_dim and output_atoms.
        votes = K.reshape(conv, [input_shape[0], input_shape[3], votes_shape[1], votes_shape[2],
                                 self.num_capsule, self.num_atoms])
        votes.set_shape((None, self.input_num_capsule, conv_height.value, conv_width.value,
                         self.num_capsule, self.num_atoms))

        # Routing logits: one per (batch, input capsule, output position, output capsule), i.e. votes minus atoms.
        logit_shape = K.stack([
            input_shape[0], input_shape[3], votes_shape[1], votes_shape[2], self.num_capsule])
        # Bias [1, 1, caps, atoms] -> [out_h, out_w, caps, atoms].
        biases_replicated = K.tile(self.b, [votes_shape[1], votes_shape[2], 1, 1])

        activations = _update_routing(
            votes=votes,
            biases=biases_replicated,
            logit_shape=logit_shape,
            num_dims=6,
            input_dim=self.input_num_capsule,
            output_dim=self.num_capsule,
            num_routing=self.routings,
            leaky=self.leaky_routing)

        return activations

    def compute_output_shape(self, input_shape):
        """Return [batch, new_height, new_width, num_capsule, num_atoms]."""
        # Spatial axes only (drop batch, capsule and atom axes).
        space = input_shape[1:-2]
        new_space = []
        for i in range(len(space)):
            new_dim = conv_output_length(
                space[i],
                self.kernel_size,
                padding=self.padding,
                stride=self.strides,
                dilation=1)
            new_space.append(new_dim)

        return (input_shape[0],) + tuple(new_space) + (self.num_capsule, self.num_atoms)

    def get_config(self):
        """Serialize constructor arguments so the layer can be rebuilt with ``from_config``."""
        config = {
            'kernel_size': self.kernel_size,
            'num_capsule': self.num_capsule,
            'num_atoms': self.num_atoms,
            'strides': self.strides,
            'padding': self.padding,
            'routings': self.routings,
            'leaky_routing': self.leaky_routing,
            'kernel_initializer': initializers.serialize(self.kernel_initializer)
        }
        base_config = super(ConvCapsuleLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
def _update_routing(votes, biases, logit_shape, num_dims, input_dim, output_dim,
                    num_routing, leaky):
    """Run ``num_routing`` iterations of routing-by-agreement and return the final output capsules.

    Args:
        votes: vote tensor of rank ``num_dims`` (6 or 4). Axis 1 indexes input capsules (the routed sum is
            taken over it) and the last axis holds the atoms of each vote.
        biases: added to the routed sum (``votes`` summed over axis 1) before squashing; must broadcast against it.
        logit_shape: shape of the routing logits. These start at zero and are incremented by agreement values
            shaped like ``votes`` without its last axis, so the two must match.
        num_dims: rank of ``votes``; only 4 and 6 are supported.
        input_dim: number of input capsules (size of axis 1 of ``votes``); used to tile the activations.
        output_dim: number of output capsules; passed to ``_leaky_routing``.
        num_routing: number of routing iterations.
        leaky: if True, routing weights come from ``_leaky_routing`` instead of a plain softmax.

    Returns:
        The squashed activations produced by the last iteration, cast to float32.

    Raises:
        NotImplementedError: if ``num_dims`` is neither 4 nor 6.
    """
    # votes_t_shape moves the atom axis to the front; r_t_shape moves it back to the end.
    if num_dims == 6:
        votes_t_shape = [5, 0, 1, 2, 3, 4]
        r_t_shape = [1, 2, 3, 4, 5, 0]
    elif num_dims == 4:
        votes_t_shape = [3, 0, 1, 2]
        r_t_shape = [1, 2, 3, 0]
    else:
        raise NotImplementedError('Not implemented')

    # With atoms leading, the routing weights (which have no atom axis) broadcast over the trailing axes.
    votes_trans = tf.transpose(votes, votes_t_shape)

    def _body(i, logits, activations):
        """One routing iteration: weight the votes, squash, record, and update logits by agreement."""
        # route has the shape of logits; the softmax runs over the last (output-capsule) axis.
        if leaky:
            route = _leaky_routing(logits, output_dim)
        else:
            route = tf.nn.softmax(logits, axis=-1)
        preactivate_unrolled = route * votes_trans
        preact_trans = tf.transpose(preactivate_unrolled, r_t_shape)
        # Sum the weighted votes over input capsules (axis 1), then add the bias.
        preactivate = tf.reduce_sum(preact_trans, axis=1) + biases
        activation = _squash(preactivate)
        activations = activations.write(i, activation)
        # Re-insert and tile the input-capsule axis so each vote can be compared with its output capsule.
        act_3d = K.expand_dims(activation, 1)
        tile_shape = np.ones(num_dims, dtype=np.int32).tolist()
        tile_shape[1] = input_dim
        act_replicated = tf.tile(act_3d, tile_shape)
        # Agreement = dot product of each vote with the resulting output capsule.
        distances = tf.reduce_sum(votes * act_replicated, axis=-1)
        logits += distances
        return (i + 1, logits, activations)

    # One slot per iteration; only the last one is read at the end.
    activations = tf.TensorArray(
        dtype=tf.float32, size=num_routing, clear_after_read=False)
    logits = tf.fill(logit_shape, 0.0)

    i = tf.constant(0, dtype=tf.int32)
    _, logits, activations = tf.while_loop(
        lambda i, logits, activations: i < num_routing,
        _body,
        loop_vars=[i, logits, activations],
        swap_memory=True)

    return K.cast(activations.read(num_routing - 1), dtype='float32')
def _squash(input_tensor, epsilon=1e-7):
    """Apply the capsule "squash" non-linearity along the last axis.

    Each vector v is rescaled to ``v * |v| / (1 + |v|^2)``. Its direction is kept and its length is mapped to
    ``|v|^2 / (1 + |v|^2)``, which lies in [0, 1).

    Args:
        input_tensor: tensor whose last axis holds capsule vectors.
        epsilon: small constant added under the square root. An all-zero vector then maps to zero instead of
            producing NaN, both in the forward pass and in the gradient.

    Returns:
        Tensor with the same shape as ``input_tensor``.
    """
    norm_squared = tf.reduce_sum(tf.square(input_tensor), axis=-1, keepdims=True)
    norm = tf.sqrt(norm_squared + epsilon)
    # Algebraically equal to (v / |v|) * (|v|^2 / (1 + |v|^2)), but never divides by |v|.
    return input_tensor * (norm / (1.0 + norm_squared))
def _leaky_routing(logits, output_dim):
    """Softmax routing with an extra always-zero "leak" logit.

    A zero logit is prepended along the last (output-capsule) axis before the softmax, and its probability is then
    dropped. The returned routing weights therefore sum to less than one, so part of each input's vote is routed
    nowhere.

    Args:
        logits: routing logits whose last axis has size ``output_dim``.
        output_dim: number of output capsules.

    Returns:
        Routing weights with the same shape as ``logits``.
    """
    # One zero "leak" logit per routing distribution: shape [..., 1].
    leak = tf.zeros_like(logits[..., :1])
    leaky_logits = tf.concat([leak, logits], axis=-1)
    # ``axis`` rather than the deprecated ``dim`` alias (removed in TensorFlow 2).
    leaky_routing = tf.nn.softmax(leaky_logits, axis=-1)
    # Discard the leak column; the split also checks that the last axis has exactly output_dim entries.
    return tf.split(leaky_routing, [1, output_dim], -1)[1]
def combine_images(generated_images, height=None, width=None):
    """Tile a batch of images into a single 2-D grid image.

    Images are placed row by row: left to right, then top to bottom. Only the first channel of each image is used.
    A batch of single-channel images without a channel axis is also accepted.

    Args:
        generated_images: numpy array of shape (num, h, w, channels) or (num, h, w).
        height: number of grid rows. Derived from ``width`` (or chosen automatically) when omitted.
        width: number of grid columns. Derived from ``height`` (or chosen automatically) when omitted.

    Returns:
        Array of shape (height * h, width * w) with the dtype of ``generated_images``; unused cells are zero.

    Raises:
        ValueError: if the grid has fewer cells than there are images.
    """
    num = generated_images.shape[0]
    if generated_images.ndim == 3:
        # Treat (num, h, w) as single-channel images.
        generated_images = generated_images[..., np.newaxis]

    if width is None and height is None:
        # Roughly square grid; max(1, ...) avoids dividing by zero for an empty batch.
        width = max(1, int(math.sqrt(num)))
        height = int(math.ceil(float(num) / width))
    elif height is None:  # only width given
        height = int(math.ceil(float(num) / width))
    elif width is None:  # only height given
        width = int(math.ceil(float(num) / height))

    if height * width < num:
        raise ValueError(
            'A {}x{} grid cannot hold {} images.'.format(height, width, num))

    tile_h, tile_w = generated_images.shape[1:3]
    image = np.zeros((height * tile_h, width * tile_w),
                     dtype=generated_images.dtype)
    for index, img in enumerate(generated_images):
        row, col = divmod(index, width)
        image[row * tile_h:(row + 1) * tile_h, col * tile_w:(col + 1) * tile_w] = img[:, :, 0]
    return image