Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions models/fastvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ def __init__(
padding=kernel_size // 2,
groups=dim,
use_act=False,
use_bn_conv=True,
)
self.use_layer_scale = use_layer_scale
if use_layer_scale:
Expand Down
44 changes: 41 additions & 3 deletions models/modules/mobileone.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(
use_scale_branch: bool = True,
num_conv_branches: int = 1,
activation: nn.Module = nn.GELU(),
use_bn_conv=False,
) -> None:
"""Construct a MobileOneBlock module.

Expand All @@ -95,6 +96,8 @@ def __init__(
use_act: Whether to use activation. Default: ``True``
use_scale_branch: Whether to use scale branch. Default: ``True``
num_conv_branches: Number of linear conv branches.
activation: Activation function. Default: ``nn.GELU()``
use_bn_conv: Whether to use batchnorm before conv. Default: ``False``
"""
super(MobileOneBlock, self).__init__()
self.inference_mode = inference_mode
Expand All @@ -106,6 +109,7 @@ def __init__(
self.in_channels = in_channels
self.out_channels = out_channels
self.num_conv_branches = num_conv_branches
self.use_bn_conv = use_bn_conv

# Check if SE-ReLU is requested
if use_se:
Expand Down Expand Up @@ -141,17 +145,25 @@ def __init__(
if num_conv_branches > 0:
rbr_conv = list()
for _ in range(self.num_conv_branches):
rbr_conv.append(
if self.use_bn_conv:
rbr_conv.append(
self._bn_conv(kernel_size=kernel_size, padding=padding)
)
else:
rbr_conv.append(
self._conv_bn(kernel_size=kernel_size, padding=padding)
)
)
self.rbr_conv = nn.ModuleList(rbr_conv)
else:
self.rbr_conv = None

# Re-parameterizable scale branch
self.rbr_scale = None
if (kernel_size > 1) and use_scale_branch:
self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)
if self.use_bn_conv:
self.rbr_scale = self._bn_conv(kernel_size=1, padding=0)
else:
self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply forward pass."""
Expand Down Expand Up @@ -313,6 +325,32 @@ def _conv_bn(self, kernel_size: int, padding: int) -> nn.Sequential:
)
mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels))
return mod_list

def _bn_conv(self, kernel_size: int, padding: int) -> nn.Sequential:
"""Helper method to construct batchnorm-conv layers.

Args:
kernel_size: Size of the convolution kernel.
padding: Zero-padding size.

Returns:
BN-Conv module.
"""
mod_list = nn.Sequential()
mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels))
mod_list.add_module(
"conv",
nn.Conv2d(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=kernel_size,
stride=self.stride,
padding=padding,
groups=self.groups,
bias=False,
),
)
return mod_list


def reparameterize_model(model: torch.nn.Module) -> nn.Module:
Expand Down