Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions src/models/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,16 +220,9 @@ def forward(self, ctxt, tgt, masks_ctxt, masks_tgt, mask_index=1):
x = x.repeat(len(masks_tgt), 1, 1)
x = torch.cat([x, pred_tokens], dim=1)

# FIXME: this implementation currently assumes masks_ctxt and masks_tgt
# are alligned 1:1 (ok with MultiMask wrapper on predictor but
# otherwise will break)
masks_ctxt = torch.cat(masks_ctxt, dim=0)
masks_tgt = torch.cat(masks_tgt, dim=0)
masks = torch.cat([masks_ctxt, masks_tgt], dim=1)

# Fwd prop
for blk in self.predictor_blocks:
x = blk(x, mask=masks)
x = blk(x)
x = self.predictor_norm(x)

# Return output corresponding to target tokens
Expand Down
6 changes: 3 additions & 3 deletions src/models/utils/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(
self.proj_drop = nn.Dropout(proj_drop)
self.use_sdpa = use_sdpa

def forward(self, x, mask=None):
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # [B, num_heads, N, D]
Expand Down Expand Up @@ -111,8 +111,8 @@ def __init__(
act_layer=act_layer,
drop=drop)

def forward(self, x, return_attention=False, mask=None):
y, attn = self.attn(self.norm1(x), mask=mask)
def forward(self, x, return_attention=False):
y, attn = self.attn(self.norm1(x))
if return_attention:
return attn
x = x + y
Expand Down
3 changes: 1 addition & 2 deletions src/models/vision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,11 @@ def forward(self, x, masks=None):
# Mask away unwanted tokens (if masks provided)
if masks is not None:
x = apply_masks(x, masks)
masks = torch.cat(masks, dim=0)

# Fwd prop
outs = []
for i, blk in enumerate(self.blocks):
x = blk(x, mask=masks)
x = blk(x)
if self.out_layers is not None and i in self.out_layers:
outs.append(self.norm(x))

Expand Down