Skip to content

Commit fa62d2c

Browse files
committed
whenxuan: init the simple attention module
1 parent 5c63497 commit fa62d2c

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from typing import Optional
2+
3+
import torch
4+
import torch.nn as nn
5+
6+
7+
class SimpleAttentionModule(torch.nn.Module):
    """
    A Simple, Parameter-Free Attention Module (SimAM) for convolutional
    networks, usable on time-series (1D) and image (2D) feature maps.

    References: "SimAM: A Simple, Parameter-Free Attention Module for Convolutional Neural Networks" by Lingxiao Yang, Ru-Yuan Zhang, et al.

    URL: https://proceedings.mlr.press/v139/yang21o.html
    """

    def __init__(self, n_dims: int, in_channels: Optional[int] = None, e_lambda: float = 1e-4) -> None:
        """
        Args:
            n_dims: Number of spatial dimensions of the expected input
                (1 for time series, 2 for images).
            in_channels: Kept for API compatibility; the module is
                parameter-free and does not use it.
            e_lambda: Small positive constant added to the variance term
                for numerical stability (the paper's lambda).
        """
        super().__init__()

        # Bug fix: the original wrote `n_dims = n_dims`, a no-op that
        # discarded the argument instead of storing it on the instance.
        self.n_dims = n_dims
        self.in_channels = in_channels

        # NOTE: attribute name deliberately keeps the original misspelling
        # ("activaton") so existing state_dict checkpoints still load.
        self.activaton = nn.Sigmoid()
        self.e_lambda = e_lambda

    def __repr__(self) -> str:
        """Return e.g. 'SimpleAttentionModule(lambda=0.000100)'."""
        s = self.__class__.__name__ + "("
        s += "lambda=%f)" % self.e_lambda
        return s

    @staticmethod
    def get_module_name():
        """Short registry name for this attention module."""
        return "simam"

    def forward(self, x):
        """
        Apply SimAM attention: scale each element by the sigmoid of its
        energy, computed from its squared deviation from the per-channel
        spatial mean.

        Args:
            x: Tensor of shape (batch, channels, *spatial), e.g.
               (B, C, L) for 1D data or (B, C, H, W) for 2D data.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Generalized from the original hard-coded 4-D (B, C, H, W) case:
        # reduce over every dimension after batch and channel, so 1D
        # inputs (as promised by the class docstring) also work.
        spatial_dims = list(range(2, x.dim()))

        # n = (number of spatial elements) - 1, per the paper's energy
        # function. torch.Size.numel() gives the product of the sizes.
        n = x.shape[2:].numel() - 1

        x_minus_mu_square = (x - x.mean(dim=spatial_dims, keepdim=True)).pow(2)
        y = (
            x_minus_mu_square
            / (
                4
                * (x_minus_mu_square.sum(dim=spatial_dims, keepdim=True) / n + self.e_lambda)
            )
            + 0.5
        )
        return x * self.activaton(y)

0 commit comments

Comments
 (0)