Skip to content

Commit 7a0c343

Browse files
committed
whenxuan: update the multi-SE channel attention
1 parent 0a7b419 commit 7a0c343

File tree

2 files changed

+96
-1
lines changed

2 files changed

+96
-1
lines changed

channel_attention/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
__all__ = [
22
"SEAttention",
3+
"MultiSEAttention",
34
"SpatialAttention",
45
"ChannelAttention",
56
"ConvBlockAttention",
@@ -9,7 +10,7 @@
910

1011
__version__ = "0.0.1"
1112

12-
from .squeeze_excitation import SEAttention
13+
from .squeeze_excitation import SEAttention, MultiSEAttention
1314

1415
from .spatial_attention import SpatialAttention
1516

channel_attention/squeeze_excitation.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import torch
44
from torch import nn
55

6+
from channel_attention.utils import create_conv_layer
7+
68

79
class SEAttention(nn.Module):
810
"""
@@ -75,3 +77,95 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
7577

7678
# Scale the input tensor with the recalibrated weights
7779
return x * y.expand_as(x)
80+
81+
82+
class MultiSEAttention(nn.Module):
    """
    Multi-Branch Squeeze-and-Excitation Attention Module for Time Series (1D) or Image (2D) Analysis.

    Extends the plain SE block with several parallel excitation branches and a
    learned per-sample "style" gate that softly weights each branch before the
    branch outputs are summed into a single channel-attention map.
    """

    def __init__(self, n_dims: int, n_channels: int, reduction: int = 4, n_branches: int = 3) -> None:
        """
        Multi-Branch Squeeze-and-Excitation Attention Module for Time Series (1D) or Image (2D) Analysis.

        :param n_dims: (int) The dimension of input data, either 1 (time series) or 2 (image).
        :param n_channels: (int) The number of input channels of time series data.
        :param reduction: (int) The reduction ratio for the intermediate layer in the SE block.
        :param n_branches: (int) The number of branches in the multi-branch SE module.
        """
        super().__init__()

        # Only 1D (series) and 2D (image) inputs are supported.
        assert n_dims in [1, 2], "The dimension of input data must be either 1 or 2."
        self.n_dims = n_dims

        # Pooling layer and trailing singleton dims depend on the dimensionality.
        if n_dims == 2:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
            self.repeat_size = (1, 1)
        else:
            self.avg_pool = nn.AdaptiveAvgPool1d(1)
            self.repeat_size = (1,)
        self.activation = nn.Sigmoid()

        # Keep the configuration around for forward-time bookkeeping.
        self.reduction = reduction
        self.n_branches = n_branches
        self.n_channels = n_channels

        # All branches are processed in one grouped excitation MLP:
        # groups=n_branches keeps each branch's weights independent.
        grouped_channels = n_channels * n_branches
        bottleneck = grouped_channels // self.reduction
        squeeze = create_conv_layer(n_dims=n_dims, in_channels=grouped_channels,
                                    out_channels=bottleneck, kernel_size=1,
                                    bias=True, groups=n_branches)
        excite = create_conv_layer(n_dims=n_dims, in_channels=bottleneck,
                                   out_channels=grouped_channels, kernel_size=1,
                                   bias=True, groups=n_branches)
        self.fc = nn.Sequential(squeeze, nn.ReLU(inplace=True), excite)

        # Linear head that maps pooled channel statistics to branch logits.
        self.style_assigner = nn.Linear(n_channels, n_branches, bias=False)

    def _style_assignment(self, channel_mean: torch.Tensor, batch_size: int) -> torch.Tensor:
        """
        Assign styles to each channel based on the channel mean.

        :param channel_mean: (torch.Tensor) The mean values of each channel, shape (batch_size, n_channels, 1, 1).
        :param batch_size: (int) The batch size of the input tensor.

        :return: (torch.Tensor) Style assignment probabilities for each branch, shape (batch_size, n_branches).
        """
        flat = channel_mean.view(batch_size, -1)
        logits = self.style_assigner(flat)
        # Softmax over branches gives a convex mixture per sample.
        return nn.functional.softmax(logits, dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the MultiSEAttention module.

        :param x: (torch.Tensor)
                  1D Time Series: Input tensor of shape (batch_size, channels, seq_len);
                  2D Image: Input tensor of shape (batch_size, channels, height, width).

        :return: (torch.Tensor) Output tensor of the same shape as input.
        """
        # Squeeze: global average pooling to one value per channel.
        pooled = self.avg_pool(x)
        batch_size, n_channels = pooled.shape[:2]

        # Per-sample soft weights over the branches: (batch_size, n_branches).
        branch_weights = self._style_assignment(pooled, batch_size=batch_size)

        # Stack the pooled descriptor once per branch (branch-major layout)
        # and excite all branches in a single grouped-conv pass.
        stacked = pooled.repeat(1, self.n_branches, *self.repeat_size)
        z = self.fc(stacked)  # (batch_size, n_branches * n_channels, 1[, 1])

        # Expand each branch weight across that branch's contiguous channel
        # slice (matches the branch-major layout of `stacked` above), then
        # reshape so it broadcasts against z's trailing singleton dims.
        branch_weights = branch_weights.repeat_interleave(n_channels, dim=1)
        z = z * branch_weights.view(batch_size, -1, *self.repeat_size)

        # Collapse the branch axis by summation: (batch, branches, channels, ...)
        # -> (batch, channels, 1[, 1]), then squash into (0, 1) gates.
        z = z.view(batch_size, self.n_branches, n_channels, *self.repeat_size).sum(dim=1)
        gate = self.activation(z)

        # Scale the input tensor with the recalibrated weights.
        return x * gate

0 commit comments

Comments
 (0)