From 862c56f660afa12e22a87560e43fbddee155f29b Mon Sep 17 00:00:00 2001 From: Steffen-Wolf Date: Tue, 2 Nov 2021 10:21:55 -0400 Subject: [PATCH] Add batch_norm parameter for the Unet architecture --- funlib/learn/torch/models/unet.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/funlib/learn/torch/models/unet.py b/funlib/learn/torch/models/unet.py index d904261..1488462 100644 --- a/funlib/learn/torch/models/unet.py +++ b/funlib/learn/torch/models/unet.py @@ -12,6 +12,7 @@ def __init__( out_channels, kernel_sizes, activation, + batch_norm=False, padding='valid'): super(ConvPass, self).__init__() @@ -44,7 +45,21 @@ def __init__( kernel_size, padding=pad)) except KeyError: - raise RuntimeError("%dD convolution not implemented" % self.dims) + raise RuntimeError( + "%dD convolution not implemented" % self.dims) + + if batch_norm: + try: + bn = { + 2: torch.nn.BatchNorm2d, + 3: torch.nn.BatchNorm3d + }[self.dims] + + layers.append( + bn(out_channels)) + except KeyError: + raise RuntimeError( + "BatchNorm%dD not implemented" % self.dims) in_channels = out_channels @@ -128,7 +143,6 @@ def __init__( stride=scale_factor) else: - self.up = torch.nn.Upsample( scale_factor=scale_factor, mode=mode) @@ -234,6 +248,7 @@ def __init__( kernel_size_down=None, kernel_size_up=None, activation='ReLU', + batch_norm=False, fov=(1, 1, 1), voxel_size=(1, 1, 1), num_fmaps_out=None, @@ -373,6 +388,7 @@ def __init__( num_fmaps*fmap_inc_factor**level, kernel_size_down[level], activation=activation, + batch_norm=batch_norm, padding=padding) for level in range(self.num_levels) ]) @@ -410,6 +426,7 @@ def __init__( else num_fmaps_out, kernel_size_up[level], activation=activation, + batch_norm=batch_norm, padding=padding) for level in range(self.num_levels - 1) ])