-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
118 lines (103 loc) · 4.19 KB
/
test.py
File metadata and controls
118 lines (103 loc) · 4.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import torch
from utils import mkdir_p, parse_args,get_trained_loss, create_save_path,crl_utils
from solvers.runners import test, test_CRL
from solvers.loss import loss_dict
from time import localtime, strftime
from models import model_dict
from datasets import dataloader_dict, dataset_nclasses_dict, dataset_classname_dict
import numpy as np
import torch.nn as nn
import logging
import json
args = parse_args()
np.random.seed(args.seed)

if __name__ == "__main__":
    # Evaluate a trained checkpoint on the test split and append the metrics
    # to a per-user JSON results file.
    logging.basicConfig(level=logging.INFO,
                        format="%(levelname)s: %(message)s",
                        handlers=[
                            logging.StreamHandler()
                        ])

    num_classes = dataset_nclasses_dict[args.dataset]
    classes_name_list = dataset_classname_dict[args.dataset]

    # prepare model
    logging.info(f"Using model : {args.model}")
    # Validate inputs explicitly: `assert` is stripped under `python -O`.
    if not args.checkpoint:
        raise ValueError("Please provide a trained model file")
    if not os.path.isfile(args.checkpoint):
        raise FileNotFoundError(f"Checkpoint file not found: {args.checkpoint}")

    logging.info(f'Resuming from saved checkpoint: {args.checkpoint}')
    checkpoint_folder = os.path.dirname(args.checkpoint)
    saved_model_dict = torch.load(args.checkpoint)

    model = model_dict[args.model](num_classes=num_classes, alpha=args.alpha)
    try:
        model.load_state_dict(saved_model_dict['state_dict'])
    except RuntimeError:
        # Checkpoints saved from a DataParallel model prefix keys with
        # "module."; retry with the model wrapped the same way.
        model = nn.DataParallel(model)
        model.load_state_dict(saved_model_dict['state_dict'])
    model.cuda()

    # CRL-style losses need the ranking-aware evaluation loop.
    run_test = test_CRL if "CRL" in args.loss else test

    # set up dataset
    logging.info(f"Using dataset : {args.dataset}")
    trainloader, valloader, testloader = dataloader_dict[args.dataset](args)

    history = crl_utils.History(len(trainloader.dataset))
    criterion = loss_dict[args.loss](gamma=args.gamma, alpha=args.alpha, beta=args.beta,
                                     loss=args.loss, delta=args.delta, history=history,
                                     arguments=args)

    test_loss, top1, top3, top5, sce_score, ece_score, all_metrics = run_test(testloader, model, criterion)

    logging.info("Stats: test_loss : {:.4f} | top1 : {:.4f} | top3 : {:.4f} | top5 : {:.4f} | SCE: {:.5f} | ECE: {:.5f} | AUROC: {:5f} | FPR-AT-95: {:5f} | AUPR-S: {:5f} | AUPR-E: {:5f} | AURC: {:5f} | EAURC: {:5f}".format(
        test_loss,
        top1,
        top3,
        top5,
        sce_score,
        ece_score,
        all_metrics["auroc"],
        all_metrics["fpr-at-95"],
        all_metrics["aupr-success"],
        all_metrics["aupr-error"],
        all_metrics["aurc"],
        all_metrics["eaurc"]
    ))

    # save the tpr and fpr
    if not os.path.isdir(args.aurocfolder):
        mkdir_p(args.aurocfolder)

    trained_loss = get_trained_loss(args.checkpoint)
    # NOTE(review): the [11:] slice assumes a fixed checkpoint-path prefix
    # length — verify against how checkpoints are named upstream.
    auroc_name = args.dataset + "_" + args.checkpoint[11:].split("/")[1]
    print(auroc_name)
    username = os.getlogin()

    # Append this run's metrics to the per-user JSON results file.
    # Start from an empty list when the file is missing or empty: the old
    # bootstrap wrote [{}], which left a spurious empty record in the results.
    jsonfile = args.resultsfile + "_" + username + ".json"
    data = []
    if os.path.isfile(jsonfile) and os.stat(jsonfile).st_size != 0:
        with open(jsonfile) as f:
            data = json.load(f)

    data.append({
        "model": args.model,
        "dataset": args.dataset,
        "loss": args.loss + "_" + args.pairing,
        "alpha": args.alpha,
        "beta": args.beta,
        "gamma": args.gamma,
        "theta": args.theta,
        "scaling": args.scalefactor,
        "total_epochs": args.epochs,
        "scheduler steps": args.schedule_steps,
        "top3": top3,
        "top5": top5,
        "SCE": sce_score,
        "ECE": ece_score,
        "top1": top1,
        "AUROC": all_metrics["auroc"],
        "FPR-AT-95": all_metrics["fpr-at-95"],
        "AUPR-S": all_metrics["aupr-success"],
        "AUPR-E": all_metrics["aupr-error"],
        "AURC": all_metrics["aurc"],
        "EAURC": all_metrics["eaurc"],
        "date": strftime("%d-%b", localtime())
    })

    with open(jsonfile, 'w') as f:
        json.dump(data, f, indent=4)
    logging.info("Saved results to {}".format(jsonfile))