Mask shape is not correct #11

@HaoYang0123

Description

In the `Attention.forward` function, the masking logic is incorrect when the input `mask` is not `None`. See this code:

```python
if mask is not None:
    mask = F.pad(mask.flatten(1), (1, 0), value=True)
    assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
    mask = mask[:, None, :] * mask[:, :, None]
    dots.masked_fill_(~mask, float('-inf'))
    del mask
```

Suppose we input `x` with shape (2, 3, 5), where 2 is the batch size, 3 is the number of regions, and 5 is the feature size. Then the `mask` we pass in should have shape (2, 2), where the first 2 is the batch size and the second 2 is the number of regions. The mask has one region fewer than `x` because the `cls_token` is prepended to `x` inside the model. The line `mask = F.pad(mask.flatten(1), (1, 0), value=True)` then pads the mask to shape (2, 3) to account for the `cls_token`.
However, when `dots.masked_fill_(~mask, float('-inf'))` runs, the shapes of `dots` and `mask` are not broadcast-compatible: `dots` has shape (2, 5, 3, 3) (with `heads=5`), while the outer-product mask has shape (2, 3, 3), which is missing the head dimension.
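A minimal sketch reproducing the mismatch with the shapes above (batch 2, heads 5, 3 regions including the `cls_token`); the `unsqueeze(1)` fix at the end is my assumption about a possible correction, not the author's patch:

```python
import torch
import torch.nn.functional as F

b, h, n = 2, 5, 3                  # batch, heads, regions (incl. cls_token)
dots = torch.randn(b, h, n, n)     # attention scores

# mask covers only the original 2 regions; cls_token slot is padded in below
mask = torch.tensor([[True, False],
                     [True, True]])                  # shape (2, 2)
mask = F.pad(mask.flatten(1), (1, 0), value=True)    # shape (2, 3)
mask = mask[:, None, :] * mask[:, :, None]           # shape (2, 3, 3)

# The original line fails: (2, 3, 3) cannot broadcast against (2, 5, 3, 3),
# since 3 != 5 at the head dimension.
try:
    dots.masked_fill_(~mask, float('-inf'))
except RuntimeError as e:
    print('broadcast error:', e)

# Possible fix (assumption): insert a head dimension so the mask
# broadcasts as (2, 1, 3, 3) over all 5 heads.
dots.masked_fill_(~mask.unsqueeze(1), float('-inf'))
print(dots.shape)  # torch.Size([2, 5, 3, 3])
```

Note that the bug can go unnoticed when the batch size happens to equal the number of heads, because the mask then broadcasts silently along the wrong dimension.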
