
Need to insert a "with torch.no_grad():" in heads.py #4


Description

@wjs20

Calling the .get_c() method on the embedder instance created in the example,

embedder = CPCProtEmbedding(model)

ends up throwing a CUDA out-of-memory error, because gradients are still being tracked during the forward pass:

    def get_c(self, data, return_mask=False):
        z, mask = self.get_z(data, return_mask=True)
        if self.parallel:
            # workaround for accessing model attributes when DataParallel
            c = self.cpc(data, return_early='c')
        else:
            c = self.cpc.get_c(z)
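For context, gradient tracking means every forward pass saves intermediate activations for a backward pass that never happens when only extracting embeddings. A plain extraction loop like the sketch below (the batch iterable and results list are placeholders, not from the report) therefore keeps growing GPU memory until it runs out:

    # Hypothetical extraction loop; `model`, `batches` and `embeddings` are placeholders.
    embedder = CPCProtEmbedding(model)

    embeddings = []
    for batch in batches:
        c = embedder.get_c(batch)      # forward pass records an autograd graph
        embeddings.append(c.cpu())     # c still carries grad_fn, so the GPU graph
                                       # is kept alive and memory keeps growing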

Wrapping the forward pass in a with torch.no_grad(): block worked for me:

    def get_c(self, data, return_mask=False):
        z, mask = self.get_z(data, return_mask=True)
        with torch.no_grad():
            if self.parallel:
                # workaround for accessing model attributes when DataParallel
                c = self.cpc(data, return_early='c')
            else:
                c = self.cpc.get_c(z)
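If patching heads.py is not an option, the same effect can be had from the call site; wrapping the call in torch.no_grad() also covers the self.get_z(...) call, which the in-method wrap above still leaves tracked. A minimal sketch (embedder and data as in the example above):

    import torch

    # Call-site alternative: no changes to heads.py required.
    with torch.no_grad():
        c = embedder.get_c(data)    # no autograd graph is built here

torch.no_grad() can also be applied as a decorator, i.e. @torch.no_grad() on get_c itself, which keeps the method body unchanged.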
