Mask shape is not correct #11
Open
Description
In the Attention forward function, when the input mask is not None, the code is incorrect. See this code:

```python
if mask is not None:
    mask = F.pad(mask.flatten(1), (1, 0), value=True)
    assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
    mask = mask[:, None, :] * mask[:, :, None]
    dots.masked_fill_(~mask, float('-inf'))
    del mask
```
Suppose we input x with shape (2, 3, 5), where 2 is the batch size, 3 is the number of regions, and 5 is the feature size. Then we should input mask with shape (2, 2), where the first 2 is the batch size and the second 2 is the number of regions. x has one more region than mask because the cls_token is concatenated to x. The line mask = F.pad(mask.flatten(1), (1, 0), value=True) then turns the mask shape into (2, 3).
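The pad step above can be checked in isolation; this is a minimal sketch using the shapes from the report (batch = 2, two regions before the cls_token):

```python
import torch
import torch.nn.functional as F

# Hypothetical mask with the shapes from the report: (batch=2, regions=2).
mask = torch.ones(2, 2, dtype=torch.bool)

# Pad one True column on the left, accounting for the prepended cls_token.
mask = F.pad(mask.flatten(1), (1, 0), value=True)
print(mask.shape)  # torch.Size([2, 3])
```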
However, when the code dots.masked_fill_(~mask, float('-inf')) runs, the shapes of dots and mask are incompatible: dots has shape (2, 5, 3, 3) (with heads = 5) while mask has shape (2, 3, 3).
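The mismatch can be reproduced directly: a (2, 3, 3) mask cannot broadcast against a (2, 5, 3, 3) dots tensor because the batch dimension 2 lines up against the head dimension 5. A possible fix (my own sketch, not the repository's code) is to insert a head axis before masking:

```python
import torch
import torch.nn.functional as F

batch, heads, regions = 2, 5, 2
dots = torch.randn(batch, heads, regions + 1, regions + 1)  # (2, 5, 3, 3)
mask = torch.ones(batch, regions, dtype=torch.bool)         # (2, 2)

mask = F.pad(mask.flatten(1), (1, 0), value=True)           # (2, 3)
mask = mask[:, None, :] * mask[:, :, None]                  # (2, 3, 3)

# (2, 3, 3) does not broadcast against (2, 5, 3, 3).
# Adding a singleton head axis makes the shapes compatible:
mask = mask[:, None, :, :]                                  # (2, 1, 3, 3)
dots.masked_fill_(~mask, float('-inf'))
print(dots.shape)  # torch.Size([2, 5, 3, 3])
```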