Skip to content

Commit 0e4bd20

Browse files
committed
whenxuan: black .
1 parent d68e67d commit 0e4bd20

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

channel_attention/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def create_conv_layer(
1414
) -> nn.Module:
1515
"""
1616
Create a convolutional layer based on the number of dimensions.
17-
17+
1818
:param n_dims: The number of dimensions (1 or 2).
1919
:param in_channels: Number of input channels.
2020
:param out_channels: Number of output channels.
@@ -23,7 +23,7 @@ def create_conv_layer(
2323
:param padding: Padding added to both sides of the input. Default is 0.
2424
:param bias: If True, adds a learnable bias to the output. Default is True.
2525
:param groups: Number of blocked connections from input channels to output channels. Default is 1.
26-
26+
2727
:return: A convolutional layer (nn.Conv1d or nn.Conv2d).
2828
"""
2929
if n_dims == 1:

tests.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,15 +166,19 @@ def test_MultiSEAttention(self) -> None:
166166

167167
# Test MultiSEAttention for time series (1D)
168168
mse_1d = MultiSEAttention(
169-
n_dims=1, n_channels=self.n_channels, n_branches=3,
169+
n_dims=1,
170+
n_channels=self.n_channels,
171+
n_branches=3,
170172
)
171173
for x in time_series_inputs:
172174
output = mse_1d(x)
173175
self.assertEqual(output.shape, x.shape)
174176

175177
# Test MultiSEAttention for images (2D)
176178
mse_2d = MultiSEAttention(
177-
n_dims=2, n_channels=self.n_channels, n_branches=4,
179+
n_dims=2,
180+
n_channels=self.n_channels,
181+
n_branches=4,
178182
)
179183
for x in image_inputs:
180184
output = mse_2d(x)

0 commit comments

Comments
 (0)