bug in models > ali.py #1

@jaks19

Hello, I believe there is a bug in models > ali.py, line 81, in the function set_z:

def set_z(self, var=None, volatile=False):
    if var is None:
        self.normal_z = var
    else:
        if self.gpu_ids:
            self.normal_z = Variable(torch.randn((self.opt.batch_size, self.encoder.k)).cuda(), volatile=volatile)
        else:
            self.normal_z = Variable(torch.randn((self.opt.batch_size, self.encoder.k)), volatile=volatile)

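Shouldn't the first condition be something like if var is not None? As written, the branches are swapped: when a z is passed in, it is discarded and replaced with fresh Gaussian noise, and when no z is passed, normal_z is set to None. So when GibbsNet samples a z from the unclamped chain and passes it to ALI, as in the forward method below, that z is lost and ALI simply starts again from a new random z, which effectively throws away the unclamped chain. Is that right?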
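A minimal sketch of the proposed fix, assuming PyTorch 0.4 or later, where plain tensors replace Variable and the volatile flag has no effect (torch.no_grad() is its replacement). ALISketch and its batch_size, k and device attributes are hypothetical stand-ins for the repository's model, opt.batch_size, encoder.k and gpu_ids handling; only the flipped condition comes from the question above.

```python
import torch


class ALISketch:
    """Hypothetical stand-in for the ALI model, keeping only what set_z needs."""

    def __init__(self, batch_size, k, device="cpu"):
        self.batch_size = batch_size        # plays the role of opt.batch_size
        self.k = k                          # plays the role of encoder.k
        self.device = torch.device(device)  # replaces the gpu_ids / .cuda() branch
        self.normal_z = None

    def set_z(self, var=None):
        if var is not None:
            # A z was handed in (e.g. from GibbsNet's unclamped chain): keep it.
            self.normal_z = var
        else:
            # No z given: draw fresh standard-normal noise of shape (batch_size, k).
            self.normal_z = torch.randn(self.batch_size, self.k, device=self.device)


# Usage: a z passed in now survives, and omitting it still yields fresh noise.
model = ALISketch(batch_size=4, k=8)
z = torch.randn(4, 8)
model.set_z(var=z)
print(model.normal_z is z)  # True
model.set_z()
print(tuple(model.normal_z.shape))  # (4, 8)
```

Moving the device choice into a single torch.device attribute is only a convenience of the sketch; the original's separate GPU and CPU branches would work the same way once the condition is flipped.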

def forward(self, volatile=False):
    self.sampling()

    # clamped chain : ALI model
    self.ali_model.set_z(var=self.z)
    self.ali_model.set_input(self.x.data, is_z_given=True)

    self.ali_model.forward()
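The handoff in forward can be checked in isolation. This is a minimal reproduction, assuming PyTorch is installed; set_z_original and z_from_unclamped are hypothetical names, and the function copies only the branch order of the quoted set_z, leaving out the Variable wrapper, the volatile flag and the GPU branch.

```python
import torch


def set_z_original(var, batch_size, k):
    # Same branch order as the quoted set_z (Variable, volatile and GPU branch omitted).
    if var is None:
        return var
    else:
        return torch.randn(batch_size, k)


# Stand-in for the z that GibbsNet's unclamped chain passes via set_z(var=self.z).
z_from_unclamped = torch.randn(4, 8)

normal_z = set_z_original(z_from_unclamped, batch_size=4, k=8)
print(normal_z is z_from_unclamped)  # False: the passed z was replaced by new noise
print(set_z_original(None, batch_size=4, k=8))  # None: omitting z leaves normal_z as None
```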

Any help is appreciated. Thanks.
