Thanks for sharing your work!
`PSA/semantic-segmentation/network/PSA.py`, lines 64 to 95 in 588b370:
```python
def spatial_pool(self, x):
    input_x = self.conv_v_right(x)
    batch, channel, height, width = input_x.size()
    # [N, IC, H*W]
    input_x = input_x.view(batch, channel, height * width)
    # [N, 1, H, W]
    context_mask = self.conv_q_right(x)
    # [N, 1, H*W]
    context_mask = context_mask.view(batch, 1, height * width)
    # [N, 1, H*W]
    context_mask = self.softmax_right(context_mask)
    # [N, IC, 1]
    # context = torch.einsum('ndw,new->nde', input_x, context_mask)
    context = torch.matmul(input_x, context_mask.transpose(1,2))
    # [N, IC, 1, 1]
    context = context.unsqueeze(-1)
    # [N, OC, 1, 1]
    context = self.conv_up(context)
    # [N, OC, 1, 1]
    mask_ch = self.sigmoid(context)
    out = x * mask_ch
    return out
```
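It seems that the `spatial_pool` function is the same as the channel-only self-attention module: despite its name, it pools away the spatial positions (softmax over `H*W`) and produces a per-channel mask `mask_ch` of shape `[N, OC, 1, 1]`.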
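To make the shape argument concrete, here is a minimal NumPy sketch (not the repository's code) that re-expresses `spatial_pool` with each 1x1 convolution written as a channel-mixing matrix. The names `spatial_pool_np`, `w_v`, `w_q` and `w_up` are hypothetical, biases are omitted, and the output channel count is assumed to equal the input channel count so that `x * mask_ch` broadcasts. What it shows: the softmax runs over the `H*W` positions, so the spatial dimension is pooled away, and the resulting mask holds one value per channel that is shared by every pixel.

```python
import numpy as np


def softmax(x, axis):
    # Numerically stable softmax along the given axis.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def spatial_pool_np(x, w_v, w_q, w_up):
    """Hypothetical NumPy re-expression of spatial_pool.

    Each 1x1 conv is modeled as a channel-mixing matrix; biases are
    omitted (an assumption, not taken from PSA.py).
    x:    (N, C, H, W) input
    w_v:  (IC, C)  stand-in for conv_v_right
    w_q:  (1, C)   stand-in for conv_q_right
    w_up: (OC, IC) stand-in for conv_up, with OC == C assumed
    """
    n, c, h, w = x.shape
    flat = x.reshape(n, c, h * w)                         # [N, C, H*W]
    input_x = np.einsum('ic,ncp->nip', w_v, flat)         # [N, IC, H*W]
    context_mask = np.einsum('qc,ncp->nqp', w_q, flat)    # [N, 1, H*W]
    context_mask = softmax(context_mask, axis=2)           # weights over H*W
    context = input_x @ context_mask.transpose(0, 2, 1)   # [N, IC, 1]
    context = context[..., None]                           # [N, IC, 1, 1]
    context = np.einsum('oi,nixy->noxy', w_up, context)   # [N, OC, 1, 1]
    mask_ch = sigmoid(context)                             # [N, OC, 1, 1]
    return x * mask_ch, mask_ch                            # broadcast per channel


# Small demo with random stand-in weights.
rng = np.random.default_rng(0)
N, C, H, W, IC = 2, 8, 5, 6, 4
x = rng.standard_normal((N, C, H, W))
out, mask = spatial_pool_np(
    x,
    rng.standard_normal((IC, C)),
    rng.standard_normal((1, C)),
    rng.standard_normal((C, IC)),
)
print(mask.shape)  # (2, 8, 1, 1)
```

Because `out / x` is constant over `H` and `W` within each channel, this computation only reweights channels, which is consistent with reading `spatial_pool` as the channel-only branch rather than a spatial mask.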