diff --git a/solver.py b/solver.py index 3282ec4..62d0c5c 100644 --- a/solver.py +++ b/solver.py @@ -88,7 +88,7 @@ def train(self, querry_dataloader, val_dataloader, task_model, vae, discriminato unlabeled_preds = discriminator(unlab_mu) lab_real_preds = torch.ones(labeled_imgs.size(0)) - unlab_real_preds = torch.ones(unlabeled_imgs.size(0)) + unlab_real_preds = torch.zeros(unlabeled_imgs.size(0)) if self.args.cuda: lab_real_preds = lab_real_preds.cuda()