From f119cb12d687874a5be729384a368ecda34af626 Mon Sep 17 00:00:00 2001
From: Guilherme Pires
Date: Fri, 5 Sep 2025 18:37:30 -0700
Subject: [PATCH] remove unused mask kwarg from Block

---
 src/models/predictor.py          | 9 +--------
 src/models/utils/modules.py      | 6 +++---
 src/models/vision_transformer.py | 3 +--
 3 files changed, 5 insertions(+), 13 deletions(-)

diff --git a/src/models/predictor.py b/src/models/predictor.py
index 2dd9a38b..4adeeb1e 100644
--- a/src/models/predictor.py
+++ b/src/models/predictor.py
@@ -220,16 +220,9 @@ def forward(self, ctxt, tgt, masks_ctxt, masks_tgt, mask_index=1):
             x = x.repeat(len(masks_tgt), 1, 1)
         x = torch.cat([x, pred_tokens], dim=1)
 
-        # FIXME: this implementation currently assumes masks_ctxt and masks_tgt
-        # are alligned 1:1 (ok with MultiMask wrapper on predictor but
-        # otherwise will break)
-        masks_ctxt = torch.cat(masks_ctxt, dim=0)
-        masks_tgt = torch.cat(masks_tgt, dim=0)
-        masks = torch.cat([masks_ctxt, masks_tgt], dim=1)
-
         # Fwd prop
         for blk in self.predictor_blocks:
-            x = blk(x, mask=masks)
+            x = blk(x)
         x = self.predictor_norm(x)
 
         # Return output corresponding to target tokens
diff --git a/src/models/utils/modules.py b/src/models/utils/modules.py
index f27ef2a9..4d84bb23 100644
--- a/src/models/utils/modules.py
+++ b/src/models/utils/modules.py
@@ -58,7 +58,7 @@ def __init__(
         self.proj_drop = nn.Dropout(proj_drop)
         self.use_sdpa = use_sdpa
 
-    def forward(self, x, mask=None):
+    def forward(self, x):
         B, N, C = x.shape
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         q, k, v = qkv[0], qkv[1], qkv[2]  # [B, num_heads, N, D]
@@ -111,8 +111,8 @@ def __init__(
             act_layer=act_layer,
             drop=drop)
 
-    def forward(self, x, return_attention=False, mask=None):
-        y, attn = self.attn(self.norm1(x), mask=mask)
+    def forward(self, x, return_attention=False):
+        y, attn = self.attn(self.norm1(x))
         if return_attention:
             return attn
         x = x + y
diff --git a/src/models/vision_transformer.py b/src/models/vision_transformer.py
index a8748dfd..64929cb4 100644
--- a/src/models/vision_transformer.py
+++ b/src/models/vision_transformer.py
@@ -177,12 +177,11 @@ def forward(self, x, masks=None):
         # Mask away unwanted tokens (if masks provided)
         if masks is not None:
             x = apply_masks(x, masks)
-            masks = torch.cat(masks, dim=0)
 
         # Fwd prop
         outs = []
         for i, blk in enumerate(self.blocks):
-            x = blk(x, mask=masks)
+            x = blk(x)
             if self.out_layers is not None and i in self.out_layers:
                 outs.append(self.norm(x))
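
Illustrative note (not part of the patch): after this change, Block.forward and Attention.forward take only the token tensor, so any call site still passing mask= will raise a TypeError. A minimal sketch of the new call path, assuming Block can be constructed as Block(dim, num_heads); the sizes and constructor arguments below are illustrative assumptions, not taken from the patch:

    # Sketch only; dims and constructor args are assumed for illustration.
    import torch
    from src.models.utils.modules import Block

    blk = Block(dim=384, num_heads=6)   # assumed constructor signature
    x = torch.randn(2, 16, 384)         # [batch, tokens, dim]

    out = blk(x)                        # new call path: no mask kwarg
    # blk(x, mask=None)                 # would raise TypeError after this patch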