Lines 195 to 204 in 81c3a9c
    if args.method == 'ENT':
        loss_t = entropy(F1, output, args.lamda)
        loss_t.backward()
        optimizer_f.step()
        optimizer_g.step()
    elif args.method == 'MME':
        loss_t = adentropy(F1, output, args.lamda)
        loss_t.backward()
        optimizer_f.step()
        optimizer_g.step()
Lines 28 to 41 in 81c3a9c
    def entropy(F1, feat, lamda, eta=1.0):
        out_t1 = F1(feat, reverse=True, eta=-eta)
        out_t1 = F.softmax(out_t1)
        loss_ent = -lamda * torch.mean(torch.sum(out_t1 *
                                                 (torch.log(out_t1 + 1e-5)), 1))
        return loss_ent

    def adentropy(F1, feat, lamda, eta=1.0):
        out_t1 = F1(feat, reverse=True, eta=eta)
        out_t1 = F.softmax(out_t1)
        loss_adent = lamda * torch.mean(torch.sum(out_t1 *
                                                  (torch.log(out_t1 + 1e-5)), 1))
        return loss_adent
Thank you for your code.
From the code, it seems that:
The ENT method tries to minimize entropy on the classifier but maximize it on the feature extractor;
The AdENT method (`adentropy`, used for MME) tries to maximize entropy on the classifier but minimize it on the feature extractor, which is what your paper proposes.
However, the paper seems to describe the ENT method as minimizing entropy on both the classifier and the feature extractor, as in Yves Grandvalet and Yoshua Bengio, "Semi-supervised learning by entropy minimization," NIPS 2005.
So I'm confused about which behavior is intended. I look forward to hearing from you.
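To make the question concrete, here is a small sanity check. It is not the repository's code: it is a pure-Python sketch that uses finite differences instead of autograd on a hypothetical toy model (a one-parameter extractor `g` and a two-class linear classifier `w`). The key assumption, which the quoted lines do not show, is that `F1(..., reverse=True, eta=e)` leaves the forward pass unchanged and multiplies the gradient flowing back to the feature extractor by `-e`, as a standard gradient reversal layer would. The names `entropy_of`, `central_diff`, and `step_effect` are made up for illustration.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy_of(g, w, x=1.0):
    """Entropy of softmax(w * f) for a toy extractor f = g * x (hypothetical model)."""
    p = softmax([wk * g * x for wk in w])
    return -sum(pk * math.log(pk + 1e-5) for pk in p)


def central_diff(fn, v, h=1e-6):
    """Central finite-difference derivative of fn at v."""
    return (fn(v + h) - fn(v - h)) / (2.0 * h)


def step_effect(loss_sign, reverse_eta, g=0.7, w=(0.5, -0.3), lamda=0.1):
    """First-order change in entropy caused by one plain SGD step.

    loss_sign:   -1.0 mirrors entropy()   (loss = -lamda * sum p log p)
                 +1.0 mirrors adentropy() (loss = +lamda * sum p log p)
    reverse_eta: the eta handed to the reversal layer (-eta or +eta).
    Returns (classifier_effect, extractor_effect); a negative value means
    the step lowers entropy, a positive value means it raises entropy.
    """
    # sum p log p equals -H, so the loss is loss_sign * lamda * (-H).
    def loss(gv, wv):
        return loss_sign * lamda * -entropy_of(gv, wv)

    # Classifier weights receive the ordinary gradient (no reversal).
    clf_effect = 0.0
    for k in range(len(w)):
        def with_wk(v, k=k):
            return tuple(v if i == k else wi for i, wi in enumerate(w))
        d_loss = central_diff(lambda v: loss(g, with_wk(v)), w[k])
        d_ent = central_diff(lambda v: entropy_of(g, with_wk(v)), w[k])
        clf_effect += d_ent * (-d_loss)  # SGD moves along -gradient

    # ASSUMPTION: the reversal layer scales the extractor gradient by -reverse_eta.
    d_loss_g = central_diff(lambda v: loss(v, w), g)
    d_ent_g = central_diff(lambda v: entropy_of(v, w), g)
    ext_effect = d_ent_g * (-(-reverse_eta) * d_loss_g)
    return clf_effect, ext_effect


ent_effects = step_effect(loss_sign=-1.0, reverse_eta=-1.0)  # ENT call pattern
mme_effects = step_effect(loss_sign=1.0, reverse_eta=1.0)    # MME call pattern
print("ENT (classifier, extractor):", ent_effects)
print("MME (classifier, extractor):", mme_effects)
```

If that assumption holds, the ENT pattern gives two negative values: the `eta=-eta` argument would cancel the reversal, so entropy would go down on both the classifier and the feature extractor, which would match the paper rather than my reading above. The MME pattern gives a positive classifier effect and a negative extractor effect. If the reversal layer behaves differently from what is assumed here, the conclusion would change.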