33import torch
44from torch import nn
55
6+ from channel_attention .utils import create_conv_layer
7+
68
79class SEAttention (nn .Module ):
810 """
@@ -75,3 +77,95 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
7577
7678 # Scale the input tensor with the recalibrated weights
7779 return x * y .expand_as (x )
80+
81+
82+ class MultiSEAttention (nn .Module ):
83+ """
84+ Multi-Branch Squeeze-and-Excitation Attention Module for Time Series (1D) or Image (2D) Analysis.
85+ This module enhances the representational power of the standard SE block by incorporating multiple branches and adaptive style assignment.
86+ """
87+
88+ def __init__ (self , n_dims : int , n_channels : int , reduction : int = 4 , n_branches : int = 3 ) -> None :
89+ """
90+ Multi-Branch Squeeze-and-Excitation Attention Module for Time Series (1D) or Image (2D) Analysis.
91+
92+ :param n_dims: (int) The dimension of input data, either 1 (time series) or 2 (image).
93+ :param n_channels: (int) The number of input channels of time series data.
94+ :param reduction: (int) The reduction ratio for the intermediate layer in the SE block.
95+ :param n_branches: (int) The number of branches in the multi-branch SE module.
96+ """
97+ super (MultiSEAttention , self ).__init__ ()
98+
99+ # Dimension assertion
100+ assert n_dims in [1 , 2 ], "The dimension of input data must be either 1 or 2."
101+ self .n_dims = n_dims
102+
103+ # Create the average pooling layer and activation function
104+ self .avg_pool = nn .AdaptiveAvgPool2d (1 ) if n_dims == 2 else nn .AdaptiveAvgPool1d (1 )
105+ self .activation = nn .Sigmoid ()
106+
107+ # Store the reduction ratio, number of branches, and number of channels
108+ self .reduction = reduction
109+ self .n_branches = n_branches
110+ self .n_channels = n_channels
111+ new_channels = n_channels * n_branches
112+
113+ # Layers for multi-branch excitation
114+ self .fc = nn .Sequential (create_conv_layer (n_dims = n_dims , in_channels = new_channels , out_channels = new_channels // self .reduction , kernel_size = 1 , bias = True , groups = n_branches ),
115+ nn .ReLU (inplace = True ),
116+ create_conv_layer (n_dims = n_dims , in_channels = new_channels // self .reduction , out_channels = new_channels , kernel_size = 1 , bias = True , groups = n_branches ))
117+
118+ # Style assignment layer
119+ self .style_assigner = nn .Linear (n_channels , n_branches , bias = False )
120+
121+ # Repeat size for reshaping the output
122+ self .repeat_size = (1 , 1 ) if n_dims == 2 else (1 ,)
123+
124+ def _style_assignment (self , channel_mean : torch .Tensor , batch_size : int ) -> torch .Tensor :
125+ """
126+ Assign styles to each channel based on the channel mean.
127+
128+ :param channel_mean: (torch.Tensor) The mean values of each channel, shape (batch_size, n_channels, 1, 1).
129+ :param batch_size: (int) The batch size of the input tensor.
130+
131+ :return: (torch.Tensor) Style assignment probabilities for each branch, shape (batch_size, n_branches).
132+ """
133+ style_assignment = self .style_assigner (channel_mean .view (batch_size , - 1 ))
134+ style_assignment = nn .functional .softmax (style_assignment , dim = 1 )
135+ return style_assignment
136+
137+ def forward (self , x : torch .Tensor ) -> torch .Tensor :
138+ """
139+ Forward pass of the MultiSEAttention module.
140+
141+ :param x: (torch.Tensor)
142+ 1D Time Series: Input tensor of shape (batch_size, channels, seq_len);
143+ 2D Image: Input tensor of shape (batch_size, channels, height, width).
144+
145+ :return: (torch.Tensor) Output tensor of the same shape as input.
146+ """
147+ # Apply global average pooling
148+ avg_y = self .avg_pool (x )
149+ batch_size , n_channels = avg_y .shape [:2 ]
150+
151+ # Perform style assignment
152+ style_assignment = self ._style_assignment (avg_y , batch_size = batch_size ) # B x N
153+
154+ # Multi-branch excitation
155+ avg_y = avg_y .repeat (1 , self .n_branches , * self .repeat_size )
156+
157+ # [batch_size, n_branches * n_channels, 1, 1]
158+ z = self .fc (avg_y )
159+
160+ # Apply style assignment
161+ style_assignment = style_assignment .repeat_interleave (n_channels , dim = 1 )
162+ if self .n_dims == 1 :
163+ z = z * style_assignment [:, :, None ]
164+ else :
165+ z = z * style_assignment [:, :, None , None ]
166+
167+ # [batch_size, n_channels, 1, 1]
168+ z = torch .sum (z .view (batch_size , self .n_branches , n_channels , * self .repeat_size ), dim = 1 ) # B x C x 1 x 1
169+ z = self .activation (z )
170+
171+ return x * z
0 commit comments