Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions PyTorch/build-in/Classification/ECANet/ecanet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# eca_model_factory.py
import model.eca_resnet as eca_net

def Model(num_classes=1000, model_type='eca_resnet50', k_size=None, pretrained=False):
    """Minimal ECA model factory; returns a model instance.

    Args:
        num_classes (int): number of output classes for the classifier head.
        model_type (str): name of a constructor exported by
            ``model.eca_resnet`` (e.g. ``'eca_resnet50'``).
        k_size (list[int] | None): per-stage ECA kernel sizes; defaults to
            ``[3, 3, 3, 3]`` when None (avoids a mutable default argument).
        pretrained (bool): forwarded to the underlying constructor.

    Returns:
        The instantiated ECA model.

    Raises:
        AttributeError: if ``model_type`` does not name a constructor in
            ``model.eca_resnet``.
    """
    if k_size is None:
        k_size = [3, 3, 3, 3]
    model_func = getattr(eca_net, model_type)
    return model_func(k_size=k_size, num_classes=num_classes, pretrained=pretrained)


# Smoke test: build a model with the factory defaults and print its structure.
if __name__ == "__main__":
    m = Model(num_classes=100)
    print(m)
2 changes: 2 additions & 0 deletions PyTorch/build-in/Classification/ECANet/model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .eca_resnet import *
from .eca_mobilenetv2 import *
129 changes: 129 additions & 0 deletions PyTorch/build-in/Classification/ECANet/model/eca_mobilenetv2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from torch import nn
from .eca_module import eca_layer

__all__ = ['ECA_MobileNetV2', 'eca_mobilenet_v2']


model_urls = {
'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
}


class ConvBNReLU(nn.Sequential):
    """Conv2d -> BatchNorm2d -> ReLU6 block with 'same'-style padding.

    The convolution carries no bias (the following BatchNorm makes it
    redundant), and padding is chosen so that stride=1 preserves the
    spatial size for odd kernel sizes.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        pad = (kernel_size - 1) // 2
        conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad,
                         groups=groups, bias=False)
        super(ConvBNReLU, self).__init__(
            conv,
            nn.BatchNorm2d(out_planes),
            nn.ReLU6(inplace=True),
        )


class InvertedResidual(nn.Module):
    """MobileNetV2 inverted-residual block with an ECA layer appended.

    The residual shortcut is used only when the block preserves both the
    spatial resolution (stride == 1) and the channel count (inp == oup).
    """

    def __init__(self, inp, oup, stride, expand_ratio, k_size):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        expanded = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        ops = []
        if expand_ratio != 1:
            # 1x1 pointwise expansion (skipped when there is nothing to expand)
            ops.append(ConvBNReLU(inp, expanded, kernel_size=1))
        # 3x3 depthwise convolution
        ops.append(ConvBNReLU(expanded, expanded, stride=stride, groups=expanded))
        # 1x1 pointwise linear projection (no activation)
        ops.append(nn.Conv2d(expanded, oup, 1, 1, 0, bias=False))
        ops.append(nn.BatchNorm2d(oup))
        # channel attention on the projected output
        ops.append(eca_layer(oup, k_size))
        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        out = self.conv(x)
        return x + out if self.use_res_connect else out


class ECA_MobileNetV2(nn.Module):
    """MobileNetV2 backbone where every inverted-residual block ends with an
    ECA attention layer (see ``InvertedResidual``).

    Args:
        num_classes (int): size of the final classification layer.
        width_mult (float): multiplier applied to every layer's channel count.
    """

    def __init__(self, num_classes=1000, width_mult=1.0):
        super(ECA_MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        # Each row configures one stage:
        # t = expansion ratio, c = output channels (before width_mult),
        # n = number of blocks in the stage, s = stride of the first block.
        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # building first layer
        # NOTE(review): channels are scaled with plain int() truncation, not
        # rounded to a multiple of 8 as torchvision's MobileNetV2 does with
        # _make_divisible — confirm intended for non-default width_mult.
        input_channel = int(input_channel * width_mult)
        # width_mult < 1.0 never shrinks the final 1280-channel layer.
        self.last_channel = int(last_channel * max(1.0, width_mult))
        features = [ConvBNReLU(3, input_channel, stride=2)]  # stem: 3 -> 32, /2
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                # NOTE(review): ECA kernel size is 1 for stages with fewer
                # than 96 (unscaled) channels and 3 otherwise — presumably a
                # cost/accuracy heuristic; confirm against the reference.
                if c < 96:
                    ksize = 1
                else:
                    ksize = 3
                # Only the first block of a stage downsamples.
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t, k_size=ksize))
                input_channel = output_channel
        # building last several layers
        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
        # make it nn.Sequential
        self.features = nn.Sequential(*features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.25),
            nn.Linear(self.last_channel, num_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """Run the backbone, global-average-pool, then classify.

        Args:
            x: input batch of images, shape (b, 3, h, w).

        Returns:
            Class logits of shape (b, num_classes).
        """
        x = self.features(x)
        # Global average pooling over both spatial dimensions.
        x = x.mean(-1).mean(-1)
        x = self.classifier(x)
        return x


def eca_mobilenet_v2(pretrained=False, progress=True, **kwargs):
    """Construct an ECA_MobileNetV2 model.

    Args:
        pretrained (bool): if True, load ImageNet weights. Weight loading is
            disabled in this copy (the download code is commented out), so
            True raises instead of silently returning random weights.
        progress (bool): if True, display a download progress bar (unused
            until pretrained loading is implemented).
        **kwargs: forwarded to :class:`ECA_MobileNetV2`
            (e.g. ``num_classes``, ``width_mult``).

    Returns:
        ECA_MobileNetV2: the constructed model.

    Raises:
        NotImplementedError: if ``pretrained`` is True.
    """
    if pretrained:
        # Fail loudly rather than hand back random weights; the checkpoint
        # URL is kept in ``model_urls`` for when loading is re-enabled.
        raise NotImplementedError(
            "Loading pretrained weights is not implemented for eca_mobilenet_v2")
    return ECA_MobileNetV2(**kwargs)
29 changes: 29 additions & 0 deletions PyTorch/build-in/Classification/ECANet/model/eca_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import torch
from torch import nn
from torch.nn.parameter import Parameter

class eca_layer(nn.Module):
    """Efficient Channel Attention (ECA) module.

    Squeezes each channel to a scalar by global average pooling, applies a
    shared 1-D convolution across the channel dimension to capture local
    cross-channel interaction, and rescales the input by the resulting
    per-channel sigmoid gates.

    Args:
        channel: number of channels of the input feature map (not used by
            the layer itself; kept for interface compatibility).
        k_size: kernel size of the channel-wise 1-D convolution.
    """

    def __init__(self, channel, k_size=3):
        super(eca_layer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size,
                              padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # (b, c, h, w) -> (b, c, 1, 1): one descriptor per channel.
        pooled = self.avg_pool(x)
        # Treat the channel axis as a 1-D sequence: (b, c, 1, 1) -> (b, 1, c).
        seq = pooled.squeeze(-1).transpose(-1, -2)
        # Shared conv over channels, then back to (b, c, 1, 1) logits.
        logits = self.conv(seq).transpose(-1, -2).unsqueeze(-1)
        gates = self.sigmoid(logits)
        # Broadcast each channel's gate over the spatial dimensions.
        return x * gates.expand_as(x)

22 changes: 22 additions & 0 deletions PyTorch/build-in/Classification/ECANet/model/eca_ns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import torch
import time
from torch import nn


class eca_layer(nn.Module):
    """ECA variant built from ``unfold`` plus a grouped (depthwise) Conv1d.

    ``unfold`` gathers, for every channel, a zero-padded window of k
    neighbouring channel descriptors; the grouped Conv1d then applies one
    independent k-tap filter per channel to produce the attention gate.

    Args:
        channel: number of input channels.
        k_size: neighbourhood size along the channel axis.
    """

    def __init__(self, channel, k_size):
        super(eca_layer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.k_size = k_size
        # groups=channel: one independent k-tap filter per channel.
        self.conv = nn.Conv1d(channel, channel, kernel_size=k_size,
                              bias=False, groups=channel)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # (b, c, h, w) -> (b, c, 1, 1) channel descriptors.
        pooled = self.avg_pool(x)
        pad = (self.k_size - 1) // 2
        # (b, 1, 1, c) -> (b, k, c): each column is a channel's k-window.
        windows = nn.functional.unfold(pooled.transpose(-1, -3),
                                       kernel_size=(1, self.k_size),
                                       padding=(0, pad))
        # (b, c, k) -> (b, c, 1) -> (b, c, 1, 1) gate logits, then sigmoid.
        gate = self.sigmoid(self.conv(windows.transpose(-1, -2)).unsqueeze(-1))
        return x * gate.expand_as(x)
Loading