Commit 0cd5d52

whenxuan committed: update the readme
1 parent 6418be9 commit 0cd5d52

File tree

3 files changed: +36 −1 lines changed


README.md

Lines changed: 22 additions & 0 deletions
@@ -26,9 +26,24 @@ We only develop and test with PyTorch. Please make sure to install it from [PyTo

 ## Usage <a id="Usage"></a>

+**The core property of the channel attention mechanism is its shape invariance: the output tensor has exactly the same shape as the input.** Therefore, the module can easily be embedded at almost any point in a neural network to further improve the model's performance.
+
+~~~python
+import torch
+from channel_attention import SEAttention
+
+# 1D time-series data with shape (batch_size, channels, seq_len)
+inputs = torch.rand(8, 16, 128)
+attn = SEAttention(n_dims=1, n_channels=16, reduction=4)
+print(attn(inputs).shape)
+
+# 2D image data with shape (batch_size, channels, height, width)
+inputs_2d = torch.rand(8, 16, 64, 64)
+attn_2d = SEAttention(n_dims=2, n_channels=16, reduction=4)
+print(attn_2d(inputs_2d).shape)
+~~~
+
+When the number of input channels is small, the channel attention mechanism is very lightweight and does not significantly increase computational complexity.

 ## Modules <a id="Modules"></a>

@@ -50,5 +65,12 @@ We only develop and test with PyTorch. Please make sure to install it from [PyTo

 <img width="80%" src="images/SpatialAttention.png">
 </div>

+#### 4. [`ConvBlockAttention`](https://github.com/wwhenxuan/Channel-Attention/blob/main/channel_attention/spatial_attention.py): [[paper]]() The Convolutional Block Attention Module (CBAM), which combines Channel Attention and Spatial Attention.
+
+<div align="center">
+<img width="80%" src="images/ConvBlockAttention.png">
+</div>
+

 ## Experiments <a id="Experiments"></a>
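The diff above adds usage examples for the repository's `SEAttention` class. As a rough illustration of the mechanism itself, here is a minimal, hypothetical sketch of a squeeze-and-excitation block for 2D feature maps; the class name `SqueezeExcite` and its internals are assumptions for illustration, not the repository's actual implementation:

```python
import torch
import torch.nn as nn

class SqueezeExcite(nn.Module):
    """Minimal squeeze-and-excitation sketch: re-weights channels by a learned gate."""

    def __init__(self, n_channels: int, reduction: int = 4) -> None:
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(n_channels, n_channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(n_channels // reduction, n_channels),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Squeeze: global average pool over spatial dims -> (B, C)
        scale = x.mean(dim=(-2, -1))
        # Excitation: per-channel gate in (0, 1)
        scale = self.fc(scale)
        # Re-weight channels; output shape equals input shape
        return x * scale[:, :, None, None]

x = torch.rand(8, 16, 64, 64)
se = SqueezeExcite(n_channels=16, reduction=4)
print(se(x).shape)  # torch.Size([8, 16, 64, 64])
```

The shape of the output matches the input, which is exactly the invariance the README relies on when embedding such a module inside an existing network.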

images/ConvBlockAttention.png

83.2 KB
Loading
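The new image illustrates the `ConvBlockAttention` (CBAM) module added to the README. As a rough sketch of what such a module computes, here is a minimal, hypothetical composition of channel attention followed by spatial attention, assuming the standard CBAM design; the class names and details are illustrative, not the repository's exact code:

```python
import torch
import torch.nn as nn

class SpatialAttention(nn.Module):
    """Spatial attention: gate each location using channel-wise avg/max maps."""

    def __init__(self, kernel_size: int = 7) -> None:
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg_map = x.mean(dim=1, keepdim=True)        # (B, 1, H, W)
        max_map = x.max(dim=1, keepdim=True).values  # (B, 1, H, W)
        gate = torch.sigmoid(self.conv(torch.cat([avg_map, max_map], dim=1)))
        return x * gate

class ConvBlockAttentionSketch(nn.Module):
    """CBAM-style module: channel attention followed by spatial attention."""

    def __init__(self, n_channels: int, reduction: int = 4) -> None:
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(n_channels, n_channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(n_channels // reduction, n_channels),
        )
        self.spatial = SpatialAttention()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Channel attention from avg- and max-pooled channel descriptors
        avg = self.mlp(x.mean(dim=(-2, -1)))
        mx = self.mlp(x.amax(dim=(-2, -1)))
        x = x * torch.sigmoid(avg + mx)[:, :, None, None]
        # Spatial attention on the channel-refined features
        return self.spatial(x)

inputs = torch.rand(8, 16, 64, 64)
cbam = ConvBlockAttentionSketch(n_channels=16)
print(cbam(inputs).shape)  # torch.Size([8, 16, 64, 64])
```

As with the pure channel-attention modules, the composed block preserves the input shape, so it can be dropped between convolutional layers unchanged.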

tests.ipynb

Lines changed: 14 additions & 1 deletion
@@ -433,7 +433,20 @@
    "id": "88291fd7",
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "import torch\n",
+    "from channel_attention import SEAttention\n",
+    "\n",
+    "# 1D Time Series Data with (batch_size, channels, seq_len)\n",
+    "inputs = torch.rand(8, 16, 128)\n",
+    "attn = SEAttention(n_dims=1, n_channels=16, reduction=4)\n",
+    "print(attn(inputs).shape)\n",
+    "\n",
+    "# 2D Image Data with (batch_size, channels, height, width)\n",
+    "inputs_2d = torch.rand(8, 16, 64, 64)\n",
+    "attn_2d = SEAttention(n_dims=2, n_channels=16, reduction=4)\n",
+    "print(attn_2d(inputs_2d).shape)"
+   ]
   }
  ],
  "metadata": {
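The notebook cell exercises `SEAttention` standalone. Because the module's output shape matches its input, it can also be inserted between existing layers; the following hedged sketch uses an identity stand-in so it runs without the `channel_attention` package installed (substitute the real `SEAttention` in practice):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for channel_attention.SEAttention: any module whose
# output shape matches its input can be embedded the same way.
class IdentityAttention(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

# Shape invariance means the attention layer slots in without other changes.
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    IdentityAttention(),  # e.g. SEAttention(n_dims=2, n_channels=16, reduction=4)
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, 10),
)
print(model(torch.rand(4, 3, 32, 32)).shape)  # torch.Size([4, 10])
```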

0 commit comments
