Skip to content

Does your model have a true test dataset? #5

@lygztq

Description

@lygztq

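I think your code actually reports a validation result rather than a test result, right? In the loop below, the checkpoint with the highest accuracy on `test_graphs` is kept. That maximum is then written out as the fold's result, so the test split is effectively used for model selection. Specifically, here: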

# Per-epoch train/evaluate/checkpoint loop (fragment quoted from the repository).
# cmd_args, classifier, loop_dataset, mi_loss, optimizer, device, train_graphs,
# train_idxes, test_graphs, logfile, max_acc, first_timstr and foldidx are all
# defined outside this snippet.
# loop_dataset results are read by index: 0=clsloss, 1=miloss, 2=loss, 3=acc,
# 4=auc (per the print/log format strings below).
for epoch in range(cmd_args.num_epochs):
        # Reshuffle the training order each epoch (train_idxes is shuffled in place).
        random.shuffle(train_idxes)
        classifier.train()
        # Training pass over train_graphs; the optimizer is passed here, so this is
        # presumably the pass that updates parameters (loop_dataset not shown).
        avg_loss = loop_dataset(train_graphs, classifier, mi_loss, train_idxes, epoch, optimizer=optimizer, device=device)
        # Slot 4 is printed as "auc" but is overwritten with 0.0, so the reported
        # training AUC is always 0.00000.
        avg_loss[4] = 0.0
        print('\033[92maverage training of epoch %d: clsloss: %.5f miloss: %.5f loss %.5f acc %.5f auc %.5f\033[0m'
              % (epoch, avg_loss[0], avg_loss[1], avg_loss[2], avg_loss[3], avg_loss[4])) # noqa

        classifier.eval()
        # Evaluation pass over every test graph in fixed order. No optimizer is
        # passed, so presumably no parameter updates happen here (verify in loop_dataset).
        test_loss = loop_dataset(test_graphs, classifier, mi_loss, list(range(len(test_graphs))), epoch, device=device)
        # Same AUC overwrite as above: the printed and logged test AUC is always 0.
        test_loss[4] = 0.0
        print('\033[93maverage test of epoch %d: clsloss: %.5f miloss: %.5f loss %.5f acc %.5f auc %.5f\033[0m'
              % (epoch, test_loss[0], test_loss[1], test_loss[2], test_loss[3], test_loss[4])) # noqa

        # Append this epoch's test metrics to the log file (reopened every epoch).
        with open(logfile, 'a+') as log:
            log.write('test of epoch %d: clsloss: %.5f miloss: %.5f loss %.5f acc %.5f auc %.5f'
                      % (epoch, test_loss[0], test_loss[1], test_loss[2], test_loss[3], test_loss[4]) + '\n')

        # NOTE(review): this is the crux of the issue. The checkpoint is selected
        # by accuracy on test_graphs (test_loss[3]), so the "test" split is used
        # for model selection and behaves like a validation set. max_acc ends up
        # as the best test accuracy over all epochs, which is an optimistically
        # biased estimate. An unbiased setup would select on a separate
        # validation split and evaluate the chosen model on the test split once.
        # The comparison is strict, so on ties the earlier epoch's checkpoint is kept.
        # max_acc must be initialised before this loop (not shown here).
        if test_loss[3] > max_acc:
            max_acc = test_loss[3]
            # Assumes ./checkpoint_<data>/time_<ts>/FOLD<k>/ already exists;
            # TODO confirm it is created elsewhere.
            fname = './checkpoint_%s/time_%s/FOLD%s/model_epoch%s.pt' % (cmd_args.data, first_timstr, foldidx, str(epoch))
            torch.save(classifier.state_dict(), fname)

# Append this fold's recorded accuracy (max_acc) to the per-run result file.
# NOTE(review): max_acc is the maximum per-epoch accuracy on test_graphs, so the
# value written here is a model-selected (validation-style) score, not an
# untouched test score.
# NOTE(review): the paste shows this at top level, but it writes foldidx, so in
# the original script it presumably runs once per fold, after the epoch loop.
# Confirm against the repository. It also assumes the result directory exists.
with open('./result_%s/result_%s/acc_result_%s_%s.txt' % (cmd_args.data, first_timstr, cmd_args.data, first_timstr), 'a+') as f:
        f.write('\n')
        f.write('Fold index: ' + str(foldidx) + '\t')
        f.write(str(max_acc) + '\n')

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions