diff --git a/models/locate.py b/models/locate.py index 284a0b6..f7300bf 100644 --- a/models/locate.py +++ b/models/locate.py @@ -195,7 +195,7 @@ def forward(self, exo, ego, aff_label, epoch): return masks, logits, loss_proto, loss_con - @torch.no_grad() + # @torch.no_grad() def test_forward(self, ego, aff_label): _, ego_key, ego_attn = self.vit_model.get_last_key(ego) # attn: b x 6 x (1+hw) x (1+hw) ego_desc = ego_key.permute(0, 2, 3, 1).flatten(-2, -1)