Skip to content

Commit 5831d84

Browse files
committed
fix some bugs on 2d data
1 parent e2b69f9 commit 5831d84

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

micron/network/nms.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def max_detection(
6060

6161
# not sure if this does anything
6262
conv_filter = np.ones([1 for _ in range(n_dim + 2)])
63-
upsampled = tf.nn.conv3d_transpose(
63+
upsampled = conv_transpose(
6464
max_pool,
6565
conv_filter.astype(np.float16),
6666
[1, *sm_dims, 1],
@@ -73,5 +73,8 @@ def max_detection(
7373
reduced_maxima = tf.equal(upsampled, sm_maxima)
7474
reduced_maxima = tf.logical_and(reduced_maxima, sm_maxima > 1)
7575

76-
return maxima[0, :, :, :, 0], reduced_maxima[0, :, :, :, 0]
76+
if n_dim == 2:
77+
return maxima[0, :, :, 0], reduced_maxima[0, :, :, 0]
78+
elif n_dim == 3:
79+
return maxima[0, :, :, :, 0], reduced_maxima[0, :, :, :, 0]
7780

0 commit comments

Comments
 (0)