Calling the .get_c() method on the embedder instance created in the example,

embedder = CPCProtEmbedding(model)

throws a CUDA out-of-memory error, because the forward pass tracks gradients:
def get_c(self, data, return_mask = False):
    z, mask = self.get_z(data, return_mask=True)

    if self.parallel:
        # workaround for accessing model attributes when DataParallel
        c = self.cpc(data, return_early='c')
    else:
        c = self.cpc.get_c(z)
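For context, here is a minimal sketch of the failure mode, using a toy nn.Linear model in place of CPCProt (the behavior is general to PyTorch): outputs produced without torch.no_grad() keep their autograd graphs alive, so embedding many batches accumulates GPU memory until it runs out.

import torch

# Toy model standing in for CPCProt; any nn.Module shows the same behavior.
model = torch.nn.Linear(128, 64).cuda()

embeddings = []
for _ in range(10_000):
    x = torch.randn(32, 128, device="cuda")
    y = model(x)           # y.requires_grad is True; the autograd graph is retained
    embeddings.append(y)   # holding y keeps each graph alive in GPU memory -> OOM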
Wrapping the forward pass in "with torch.no_grad():" worked for me:
def get_c(self, data, return_mask = False):
    z, mask = self.get_z(data, return_mask=True)

    with torch.no_grad():
        if self.parallel:
            # workaround for accessing model attributes when DataParallel
            c = self.cpc(data, return_early='c')
        else:
            c = self.cpc.get_c(z)
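A call-site workaround that avoids patching the library is also possible; this is a sketch assuming the embedder and data objects from the example above:

import torch

# Disable gradient tracking only for this call, leaving the library untouched.
with torch.no_grad():
    c = embedder.get_c(data)

On PyTorch 1.9+, torch.inference_mode() is a stricter alternative to torch.no_grad() for pure inference.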