from typing import List, Tuple

import numpy as np
import tensorflow as tf


def max_detection(
    soft_mask,
    window_size: List[int],
    threshold: float,
    reject_empty_threshold: bool = False,
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Detect local maxima of ``soft_mask`` by window-wise non-max suppression.

    The mask is max-pooled with non-overlapping windows, each window maximum
    is broadcast back to full resolution with a transposed convolution, and
    positions whose value equals their window maximum are marked as maxima.
    A second pass with a 3-wide, stride-1 window then removes "doubles",
    i.e. maxima of neighbouring windows that sit next to each other.

    Args:
        soft_mask: Channels-last tensor of rank 4 (``NHWC``) or rank 5
            (``NDHWC``). Batch and channel size are both fixed to 1 by the
            output shape handed to the transposed convolution, and the
            spatial sizes are read from the static shape.
        window_size: NMS window in the same layout as ``soft_mask``
            (batch, spatial..., channel). Only the spatial entries
            ``window_size[1 : n_dim + 1]`` are read.
        threshold: Minimum ``soft_mask`` value for a maximum to be kept.
        reject_empty_threshold: If ``True`` and no maximum reaches
            ``threshold``, the threshold is dropped and the unthresholded
            maxima are returned instead of an empty result. If ``False``
            the threshold is always applied.

    Returns:
        ``(maxima, reduced_maxima)``: two boolean tensors with the batch and
        channel axes stripped. ``maxima`` holds the per-window maxima,
        ``reduced_maxima`` the maxima that survive double suppression.

    Raises:
        ValueError: If ``soft_mask`` is not rank 4 or rank 5.
    """
    sm_shape = soft_mask.get_shape().as_list()
    n_dim = len(sm_shape) - 2

    # Attribute lookups stay inside the branches so that only the ops that
    # are actually needed have to exist in the installed TensorFlow.
    if n_dim == 2:
        data_format = "NHWC"
        pool_func = tf.nn.max_pool2d
        conv_transpose = tf.nn.conv2d_transpose
    elif n_dim == 3:
        data_format = "NDHWC"
        pool_func = tf.nn.max_pool3d
        conv_transpose = tf.nn.conv3d_transpose
    else:
        raise ValueError(
            "max_detection expects a rank 4 (NHWC) or rank 5 (NDHWC) "
            "soft_mask, got shape {}".format(sm_shape)
        )

    sm_dims = [sm_shape[i + 1] for i in range(n_dim)]
    w_dims = [window_size[i + 1] for i in range(n_dim)]
    unit_strides = [1] * n_dim
    output_shape = [1, *sm_dims, 1]

    # One maximum per non-overlapping window ...
    max_pool = pool_func(
        soft_mask, w_dims, w_dims, padding="SAME", data_format=data_format
    )

    # ... copied back onto every position of that window: a transposed
    # convolution with an all-ones kernel whose size equals its stride.
    conv_filter = np.ones([*w_dims, 1, 1], dtype=np.float32)
    upsampled = conv_transpose(
        max_pool,
        conv_filter,
        output_shape,
        w_dims,
        padding="SAME",
        data_format=data_format,
        name="nms_conv_0",
    )

    maxima = tf.equal(upsampled, soft_mask)
    threshold_maxima = tf.logical_and(maxima, soft_mask >= threshold)
    if reject_empty_threshold:
        # Fall back to the unthresholded maxima only when thresholding would
        # leave nothing. tf.cond requires callables for both branches.
        num_points = tf.count_nonzero(threshold_maxima)
        maxima = tf.cond(
            num_points > 0,
            lambda: threshold_maxima,
            lambda: maxima,
        )
    else:
        maxima = threshold_maxima

    # Fix doubles.
    # Suppress only along axes that take part in the NMS (window != 1), so
    # that isotropic and anisotropic windows are both handled.
    double_suppression_window = [1 if dim == 1 else 3 for dim in w_dims]

    # Detected maxima are lifted by 1 above the raw mask values.
    sm_maxima = tf.add(tf.cast(maxima, tf.float32), soft_mask)

    # Maximum of sm_maxima over each position's 3-wide neighbourhood.
    max_pool = pool_func(
        sm_maxima,
        double_suppression_window,
        unit_strides,
        padding="SAME",
        data_format=data_format,
    )

    # A transposed convolution with a size-1 all-ones kernel and stride 1 is
    # an identity mapping; it is kept so the "nms_conv_1" node stays in the
    # graph.
    conv_filter = np.ones([1] * (n_dim + 2), dtype=np.float32)
    upsampled = conv_transpose(
        max_pool,
        conv_filter,
        output_shape,
        unit_strides,
        padding="SAME",
        data_format=data_format,
        name="nms_conv_1",
    )

    # Keep positions that are the largest in their neighbourhood and were
    # flagged as maxima above (sm_maxima > 1 -- assumes soft_mask values lie
    # in [0, 1]; TODO confirm against callers).
    reduced_maxima = tf.equal(upsampled, sm_maxima)
    reduced_maxima = tf.logical_and(reduced_maxima, sm_maxima > 1)

    # Strip the batch and channel axes for both the 2D and the 3D case.
    return maxima[0, ..., 0], reduced_maxima[0, ..., 0]