Does your model have a true test dataset? #5
Open
Description
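I think your code actually reports the validation result, right? More specifically, in the loop below, both the saved checkpoint and the reported max_acc are chosen by the best accuracy on test_graphs, so the test set is also being used for model selection: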
for epoch in range(cmd_args.num_epochs):
    random.shuffle(train_idxes)
    classifier.train()
    avg_loss = loop_dataset(train_graphs, classifier, mi_loss, train_idxes, epoch, optimizer=optimizer, device=device)
    avg_loss[4] = 0.0
    print('\033[92maverage training of epoch %d: clsloss: %.5f miloss: %.5f loss %.5f acc %.5f auc %.5f\033[0m'
          % (epoch, avg_loss[0], avg_loss[1], avg_loss[2], avg_loss[3], avg_loss[4]))  # noqa
    classifier.eval()
    test_loss = loop_dataset(test_graphs, classifier, mi_loss, list(range(len(test_graphs))), epoch, device=device)
    test_loss[4] = 0.0
    print('\033[93maverage test of epoch %d: clsloss: %.5f miloss: %.5f loss %.5f acc %.5f auc %.5f\033[0m'
          % (epoch, test_loss[0], test_loss[1], test_loss[2], test_loss[3], test_loss[4]))  # noqa
    with open(logfile, 'a+') as log:
        log.write('test of epoch %d: clsloss: %.5f miloss: %.5f loss %.5f acc %.5f auc %.5f'
                  % (epoch, test_loss[0], test_loss[1], test_loss[2], test_loss[3], test_loss[4]) + '\n')
    if test_loss[3] > max_acc:
        max_acc = test_loss[3]
        fname = './checkpoint_%s/time_%s/FOLD%s/model_epoch%s.pt' % (cmd_args.data, first_timstr, foldidx, str(epoch))
        torch.save(classifier.state_dict(), fname)
with open('./result_%s/result_%s/acc_result_%s_%s.txt' % (cmd_args.data, first_timstr, cmd_args.data, first_timstr), 'a+') as f:
    f.write('\n')
    f.write('Fold index: ' + str(foldidx) + '\t')
    f.write(str(max_acc) + '\n')
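To make the concern concrete, here is a minimal sketch of the protocol I would have expected. It assumes a separate validation split exists per fold, which the snippet above does not show. The function name report_with_validation and the accuracy values are hypothetical and made up for illustration; they are not from this repository. The idea is to pick the epoch by validation accuracy and report the test accuracy at that epoch, rather than the maximum test accuracy over all epochs.

```python
from typing import Sequence, Tuple


def report_with_validation(val_accs: Sequence[float],
                           test_accs: Sequence[float]) -> Tuple[int, float]:
    """Pick the epoch with the best validation accuracy and return
    (epoch, test accuracy at that epoch). Hypothetical helper."""
    if len(val_accs) != len(test_accs) or not val_accs:
        raise ValueError("need one validation and one test accuracy per epoch")
    # Model selection only looks at the validation split.
    best_epoch = max(range(len(val_accs)), key=lambda e: val_accs[e])
    return best_epoch, test_accs[best_epoch]


# Made-up per-epoch accuracies, for illustration only.
val_accs = [0.70, 0.78, 0.75]
test_accs = [0.68, 0.74, 0.77]

best_epoch, held_out_acc = report_with_validation(val_accs, test_accs)
# What the loop above effectively reports: the best test accuracy over epochs.
optimistic_acc = max(test_accs)
print(best_epoch, held_out_acc, optimistic_acc)
```

With these made-up numbers the two protocols disagree: selecting on the test set reports the higher figure, which is why I would call the logged number a validation result.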