-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
65 lines (45 loc) · 1.63 KB
/
test.py
File metadata and controls
65 lines (45 loc) · 1.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import logging

import torch
from ruamel.yaml import YAML
from sklearn.metrics import accuracy_score, confusion_matrix

from load_data import get_dataloader
from models import VGG16
from utils import mylogger
from utils.plot_confusion_matrix import plot_confusion_matrix

# Apply the project's logging configuration (read from configs/logger.yaml by
# the project helper `mylogger.setup` — presumably handlers/levels; confirm)
# before fetching the logger this script writes to.
mylogger.setup("configs/logger.yaml")
logger = logging.getLogger("root")

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def test(model, test_loader, plot_cm: bool, num_classes: int = 12):
    """Evaluate a classifier on a test set and log accuracy and confusion matrix.

    Puts ``model`` in eval mode and runs it over every ``(img, label)`` batch
    from ``test_loader`` with gradients disabled. The predicted class is
    ``argmax(dim=1)`` of the model output. Overall accuracy and the confusion
    matrix are logged through the module-level ``logger``.

    Args:
        model: The network to evaluate. Its output is reduced with
            ``argmax(dim=1)``, so it is assumed to be (batch, num_classes)
            scores — TODO confirm against ``VGG16``.
        test_loader: Iterable of ``(img, label)`` tensor pairs. Assumed
            non-empty: ``torch.cat`` raises on an empty list.
        plot_cm: If true, also draw the confusion matrix with
            ``plot_confusion_matrix``.
        num_classes: Number of classes. Class ids ``0 .. num_classes - 1``
            fix the confusion-matrix rows/columns and the plotted class names.
            Defaults to 12, the value previously hard-coded here.
    """
    logger.info("------------ Test ------------")
    all_labels = []
    all_preds = []
    model.eval()
    with torch.no_grad():
        for img, label in test_loader:
            # Only the images need to be on the compute device. Labels are
            # consumed solely by the CPU-side sklearn metrics below.
            pred = model(img.to(device)).argmax(dim=1)
            all_labels.append(label)
            all_preds.append(pred)
    y_true = torch.cat(all_labels).cpu().numpy()
    y_pred = torch.cat(all_preds).cpu().numpy()

    class_ids = list(range(num_classes))
    acc = accuracy_score(y_true, y_pred)
    # Pass the full label set explicitly. Without it, sklearn sizes the matrix
    # from the classes that happen to occur in y_true/y_pred. A class missing
    # from the test split would then shrink the matrix and misalign it with
    # the class names plotted below.
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    logger.info(f"Accuracy on test set: {acc}")
    logger.info(f"Confusion matrix:\n{cm}")
    if plot_cm:
        plot_confusion_matrix(cm, class_names=[str(i) for i in class_ids])
if __name__ == "__main__":
    # Evaluation settings live under the "test" section of the training config.
    with open("configs/train.yaml", encoding="utf-8") as f:
        config = YAML().load(f)["test"]

    # num_classes=12 must stay in sync with the 12 class ids that test()
    # uses by default for the confusion matrix.
    model = VGG16(
        in_channels=3,
        num_classes=12,
        use_cbam=config["use_cbam"],
        use_residual=config["use_residual"],
    )
    # map_location remaps stored tensors onto `device`. Without it, a
    # checkpoint saved from a GPU run fails to load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file — only load trusted
    # checkpoints, or pass weights_only=True if the installed torch supports it.
    state_dict = torch.load(config["model_path"], map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)

    # get_dataloader returns three loaders; only the last (test) one is used here.
    _, _, test_loader = get_dataloader(config["batch_size"])
    test(model, test_loader, config["plot_cm"])