Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 62 additions & 39 deletions micron/network/nms.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,80 @@
import numpy as np
import tensorflow as tf

from typing import List

def max_detection(soft_mask, window_size, threshold):
data_format = "NDHWC"

w_depth = window_size[1]
w_height = window_size[2]
w_width = window_size[3]
def max_detection(
soft_mask, window_size: List, threshold: float, reject_empty_threshold: bool = False
):

sm_shape = soft_mask.get_shape().as_list()
sm_depth = sm_shape[1]
sm_height = sm_shape[2]
sm_width = sm_shape[3]
n_dim = len(sm_shape) - 2
sm_dims = [sm_shape[i + 1] for i in range(n_dim)]
w_dims = [window_size[i + 1] for i in range(n_dim)]

max_pool = tf.nn.max_pool3d(soft_mask, window_size, window_size, padding="SAME", data_format=data_format)
if n_dim == 2:
data_format = "NHWC"
pool_func = tf.nn.max_pool2d
conv_transpose = tf.nn.conv2d_transpose
elif n_dim == 3:
data_format = "NDHWC"
pool_func = tf.nn.max_pool3d
conv_transpose = tf.nn.conv3d_transpose

conv_filter = np.ones([w_depth,w_height,w_width,1,1])
max_pool = pool_func(
soft_mask, w_dims, w_dims, padding="SAME", data_format=data_format
)

upsampled = tf.nn.conv3d_transpose(
max_pool,
conv_filter.astype(np.float32),
[1,sm_depth,sm_height,sm_width,1],
window_size,
padding='SAME',
data_format='NDHWC',
name="nms_conv_0"
)
conv_filter = np.ones([*w_dims, 1, 1])

upsampled = conv_transpose(
max_pool,
conv_filter.astype(np.float32),
[1, *sm_dims, 1],
w_dims,
padding="SAME",
data_format=data_format,
name="nms_conv_0",
)


maxima = tf.equal(upsampled, soft_mask)
maxima = tf.logical_and(maxima, soft_mask>=threshold)
threshold_maxima = tf.logical_and(maxima, soft_mask >= threshold)
if reject_empty_threshold:
num_points = tf.count_nonzero(threshold_maxima)
maxima = tf.cond(num_points > 0, threshold_maxima, maxima)

# Fix doubles
# Check the necessary window size and adapt for isotropic vs unisotropic nms:
nms_dims = np.array(window_size) != 1
double_suppresion_window = [3**(dim) for dim in nms_dims]

sm_maxima = tf.add(tf.cast(maxima, tf.float32),soft_mask)
max_pool = tf.nn.max_pool3d(sm_maxima, double_suppresion_window, [1,1,1,1,1], padding="SAME", data_format=data_format)
conv_filter = np.ones([1,1,1,1,1])
upsampled = tf.nn.conv3d_transpose(
max_pool,
conv_filter.astype(np.float32),
[1,sm_depth,sm_height,sm_width,1],
[1,1,1,1,1],
padding='SAME',
data_format=data_format,
name="nms_conv_1"
)
double_suppresion_window = [1 if dim == 1 else 3 for dim in w_dims]
sm_maxima = tf.add(tf.cast(maxima, tf.float32), soft_mask)

# sm_maxima smoothed over large window
max_pool = pool_func(
sm_maxima,
double_suppresion_window,
[1 for _ in range(n_dim)],
padding="SAME",
data_format=data_format,
)

# not sure if this does anything
conv_filter = np.ones([1 for _ in range(n_dim + 2)])
upsampled = conv_transpose(
max_pool,
conv_filter.astype(np.float32),
[1, *sm_dims, 1],
[1 for _ in range(n_dim)],
padding="SAME",
data_format=data_format,
name="nms_conv_1",
)

reduced_maxima = tf.equal(upsampled, sm_maxima)
reduced_maxima = tf.logical_and(reduced_maxima, sm_maxima>1)
reduced_maxima = tf.logical_and(reduced_maxima, sm_maxima > 1)

if n_dim == 2:
return maxima[0, :, :, 0], reduced_maxima[0, :, :, 0]
elif n_dim == 3:
return maxima[0, :, :, :, 0], reduced_maxima[0, :, :, :, 0]

return maxima[0,:,:,:,0], reduced_maxima[0,:,:,:,0]
#return maxima, reduced_maxima