From d6eabfb7f9adc731f271a2fa23ce80625f5140cf Mon Sep 17 00:00:00 2001
From: dheeraj
Date: Wed, 6 Sep 2023 14:55:14 +0200
Subject: [PATCH] Change the order of BN and Conv in the RepMixer block to
 match the FastViT paper

In the FastViT paper the RepMixer block applies BatchNorm before the
depthwise conv (Y = X + DWConv(BN(X))), while MobileOneBlock builds its
re-parameterizable branches in Conv-BN order. Add a use_bn_conv flag to
MobileOneBlock so RepMixer can build BN-Conv branches instead.
---
 models/fastvit.py           |  1 +
 models/modules/mobileone.py | 46 +++++++++++++++++++++++++++++++++++---
 2 files changed, 43 insertions(+), 4 deletions(-)

diff --git a/models/fastvit.py b/models/fastvit.py
index 6f0363e..c0c9fae 100644
--- a/models/fastvit.py
+++ b/models/fastvit.py
@@ -282,6 +282,7 @@ def __init__(
             padding=kernel_size // 2,
             groups=dim,
             use_act=False,
+            use_bn_conv=True,
         )
         self.use_layer_scale = use_layer_scale
         if use_layer_scale:
diff --git a/models/modules/mobileone.py b/models/modules/mobileone.py
index 6cc4a36..ae82afa 100644
--- a/models/modules/mobileone.py
+++ b/models/modules/mobileone.py
@@ -79,6 +79,7 @@ def __init__(
         use_scale_branch: bool = True,
         num_conv_branches: int = 1,
         activation: nn.Module = nn.GELU(),
+        use_bn_conv: bool = False,
     ) -> None:
         """Construct a MobileOneBlock module.
 
@@ -95,6 +96,8 @@ def __init__(
             use_act: Whether to use activation. Default: ``True``
             use_scale_branch: Whether to use scale branch. Default: ``True``
             num_conv_branches: Number of linear conv branches.
+            activation: Activation function. Default: ``nn.GELU()``
+            use_bn_conv: Whether to place batchnorm before conv. Default: ``False``
         """
         super(MobileOneBlock, self).__init__()
         self.inference_mode = inference_mode
@@ -106,6 +109,7 @@ def __init__(
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.num_conv_branches = num_conv_branches
+        self.use_bn_conv = use_bn_conv
 
         # Check if SE-ReLU is requested
         if use_se:
@@ -141,9 +145,14 @@ def __init__(
         if num_conv_branches > 0:
             rbr_conv = list()
             for _ in range(self.num_conv_branches):
-                rbr_conv.append(
-                    self._conv_bn(kernel_size=kernel_size, padding=padding)
-                )
+                if self.use_bn_conv:
+                    rbr_conv.append(
+                        self._bn_conv(kernel_size=kernel_size, padding=padding)
+                    )
+                else:
+                    rbr_conv.append(
+                        self._conv_bn(kernel_size=kernel_size, padding=padding)
+                    )
             self.rbr_conv = nn.ModuleList(rbr_conv)
         else:
             self.rbr_conv = None
@@ -151,7 +160,10 @@ def __init__(
         # Re-parameterizable scale branch
         self.rbr_scale = None
         if (kernel_size > 1) and use_scale_branch:
-            self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)
+            if self.use_bn_conv:
+                self.rbr_scale = self._bn_conv(kernel_size=1, padding=0)
+            else:
+                self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Apply forward pass."""
@@ -313,6 +325,32 @@ def _conv_bn(self, kernel_size: int, padding: int) -> nn.Sequential:
         )
         mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels))
         return mod_list
+
+    def _bn_conv(self, kernel_size: int, padding: int) -> nn.Sequential:
+        """Helper method to construct batchnorm-conv layers.
+
+        Args:
+            kernel_size: Size of the convolution kernel.
+            padding: Zero-padding size.
+
+        Returns:
+            BN-Conv module.
+        """
+        mod_list = nn.Sequential()
+        mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.in_channels))
+        mod_list.add_module(
+            "conv",
+            nn.Conv2d(
+                in_channels=self.in_channels,
+                out_channels=self.out_channels,
+                kernel_size=kernel_size,
+                stride=self.stride,
+                padding=padding,
+                groups=self.groups,
+                bias=False,
+            ),
+        )
+        return mod_list
 
 
 def reparameterize_model(model: torch.nn.Module) -> nn.Module:
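
A quick usage sketch, not part of the patch: this assumes the repo layout
above (models/modules/mobileone.py importable as models.modules.mobileone)
and the upstream MobileOneBlock defaults; dim, the input shape, and the
printed values are illustrative only.

    import torch
    from models.modules.mobileone import MobileOneBlock

    dim = 64
    # Build the depthwise mixer the way RepMixer now does: with
    # use_bn_conv=True every re-parameterizable branch is
    # Sequential(bn, conv) instead of Sequential(conv, bn).
    block = MobileOneBlock(
        in_channels=dim,
        out_channels=dim,
        kernel_size=3,
        stride=1,
        padding=3 // 2,
        groups=dim,       # depthwise, as in RepMixer
        use_act=False,
        use_bn_conv=True,
    )

    x = torch.randn(2, dim, 56, 56)
    y = block(x)
    print(y.shape)            # torch.Size([2, 64, 56, 56])
    print(block.rbr_conv[0])  # Sequential with bn before conv

One caveat worth flagging: as far as I can tell, the reparameterization
helpers in mobileone.py still fold BatchNorm as if it followed the conv, so
BN-Conv branches will need a matching change to the fusion logic before
inference-time reparameterization gives the same outputs.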