forked from Brankonymous/MelanomaClassification
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
113 lines (93 loc) · 4.67 KB
/
test.py
File metadata and controls
113 lines (93 loc) · 4.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
from utils import loadDataset
from utils.constants import *
import numpy as np
from sklearn.metrics import accuracy_score, classification_report
import pickle
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
class TestNeuralNetwork():
    """Evaluate a saved VGG or XGBoost classifier on the test split.

    ``config`` is a dict. Keys read here: ``'model_name'`` (``'VGG'`` or
    ``'XGBoost'``), ``'dataset_name'``, and ``'save_plot'`` / ``'show_plot'``
    (the two plot flags are treated as False when absent).
    """

    # Label ids handed to the sklearn metrics. Shared so the classification
    # report and the confusion matrix always agree on the (binary) class set.
    _LABELS = [0, 1]
    _SUPPORTED_MODELS = ('VGG', 'XGBoost')

    def __init__(self, config):
        self.config = config

    def startTest(self):
        """Load the test split and the saved model, then evaluate it.

        Returns:
            The ``(accuracy, precision, recall, f1_score)`` tuple from ``testModel``.

        Raises:
            ValueError: if ``config['model_name']`` is neither VGG nor XGBoost.
        """
        model_name = self.config['model_name']
        # Validate before touching the dataset so a bad config fails fast.
        if model_name not in self._SUPPORTED_MODELS:
            raise ValueError("Please choose either VGG or XGBoost")

        test_dataset, _ = loadDataset(isTrain=False, modelName=model_name, datasetName=self.config['dataset_name'])
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

        model = self._loadModel(model_name)
        return self.testModel(model, test_loader)

    def _loadModel(self, model_name):
        """Deserialize the saved model for ``model_name`` from SAVED_MODEL_PATH."""
        if model_name == 'VGG':
            path = SAVED_MODEL_PATH + 'VGG_model.pth'
            # The file holds a whole pickled nn.Module (it is called directly
            # later), not a state_dict. torch>=2.6 defaults to weights_only=True,
            # which refuses such files, so opt out explicitly. Older torch
            # without the kwarg raises TypeError, so fall back for it.
            # map_location lets a GPU-saved checkpoint load on a CPU-only host.
            # NOTE: this unpickles arbitrary objects - only load trusted files.
            try:
                model = torch.load(path, map_location=DEVICE, weights_only=False)
            except TypeError:
                model = torch.load(path, map_location=DEVICE)
            return model.to(DEVICE)
        if model_name == 'XGBoost':
            # NOTE: pickle can execute arbitrary code on load - trusted files only.
            with open(SAVED_MODEL_PATH + 'XGBoost_model', 'rb') as model_file:
                return pickle.load(model_file)
        raise ValueError("Please choose either VGG or XGBoost")

    def testModel(self, model, DataLoader):
        """Run ``model`` over ``DataLoader``, print metrics and optionally plot.

        Args:
            model: a torch module (VGG) or an object with ``predict`` (XGBoost).
            DataLoader: iterable of ``(images, labels)`` batches.

        Returns:
            ``(accuracy, precision, recall, f1_score)``. Precision, recall and F1
            are the 'weighted avg' entries of sklearn's classification_report.

        Raises:
            ValueError: if ``config['model_name']`` is neither VGG nor XGBoost.
        """
        all_predictions, all_labels = self._collectPredictions(model, DataLoader)

        accuracy = accuracy_score(all_labels, all_predictions)
        report_kwargs = dict(target_names=CLASS_NAMES, labels=self._LABELS)
        # Build the dict report once instead of once per metric.
        weighted = classification_report(all_labels, all_predictions, output_dict=True, **report_kwargs)['weighted avg']
        precision = weighted['precision']
        recall = weighted['recall']
        f1_score = weighted['f1-score']
        report = classification_report(all_labels, all_predictions, **report_kwargs)

        print(
            f'Accuracy: {accuracy:.2f}\n'
            f'Precision: {precision:.2f}\n'
            f'Recall: {recall:.2f}\n'
            f'F1 Score: {f1_score:.2f}\n'
            f'Report:\n{report}'
        )

        if self.config.get('save_plot') or self.config.get('show_plot'):
            self.plotResults(accuracy, precision, recall, f1_score, all_labels, all_predictions)
        return accuracy, precision, recall, f1_score

    def _collectPredictions(self, model, loader):
        """Return ``(predictions, labels)`` lists gathered over every batch of ``loader``."""
        all_predictions = []
        all_labels = []
        model_name = self.config['model_name']
        if model_name == 'VGG':
            model.eval()
            with torch.no_grad():
                for images, labels in loader:
                    outputs = model(images.to(DEVICE))
                    # Predicted class = index of the largest output along dim 1.
                    all_predictions.extend(outputs.argmax(1).cpu().numpy())
                    # Labels are only needed on the host; no round trip to DEVICE.
                    all_labels.extend(labels.cpu().numpy())
        elif model_name == 'XGBoost':
            for images, labels in loader:
                # Flatten everything after the batch dimension so predict()
                # receives a 2-D (batch, features) matrix.
                features_2d = np.asarray(images).reshape(images.shape[0], -1)
                all_predictions.extend(model.predict(features_2d))
                all_labels.extend(labels.numpy())
        else:
            raise ValueError("Please choose either VGG or XGBoost")
        return all_predictions, all_labels

    def plotResults(self, accuracy, precision, recall, f1_score, all_labels, all_predictions):
        """Draw a metric bar chart and a confusion-matrix heatmap.

        Each figure is saved under SAVED_PLOT_PATH when ``config['save_plot']``
        is truthy, shown when ``config['show_plot']`` is truthy, and closed
        otherwise. Axis/title strings are Serbian UI text.
        """
        model_name = self.config['model_name']

        # Bar chart of the four summary metrics.
        fig, ax = plt.subplots(dpi=150)
        for bar_label, value in (
            ('Tačnost', accuracy),
            ('Preciznost', precision),
            ('Odziv', recall),
            ('F1 Mera', f1_score),
        ):
            ax.bar(bar_label, value, label=bar_label)
        ax.set_ylim(0, 1)
        ax.set_ylabel('Ocene')
        ax.set_title('Performanse modela ' + model_name)
        self._finishPlot(fig, '_rezultati.png')

        # Pin labels so the matrix stays 2x2 even if a class never appears,
        # matching the two tick labels applied below.
        cm = confusion_matrix(all_labels, all_predictions, labels=self._LABELS)
        fig, ax = plt.subplots(dpi=150)
        sns.heatmap(cm, annot=True, ax=ax, cmap='Blues', fmt='g', cbar=False)
        ax.set_xlabel('Predviđene klase')
        ax.set_ylabel('Stvarne klase')
        ax.set_title('Matrica konfuzije za model ' + model_name)
        ax.xaxis.set_ticklabels(CLASS_NAMES_SERBIAN)
        ax.yaxis.set_ticklabels(CLASS_NAMES_SERBIAN)
        self._finishPlot(fig, '_matrica_konfuzije.png')

    def _finishPlot(self, fig, suffix):
        """Save and/or show ``fig`` according to config; close it when not shown.

        The saved file is SAVED_PLOT_PATH + model_name + '_' + dataset_name + ``suffix``.
        """
        if self.config.get('save_plot'):
            fig.savefig(SAVED_PLOT_PATH + self.config['model_name'] + '_' + self.config['dataset_name'] + suffix)
        if self.config.get('show_plot'):
            plt.show()
        else:
            plt.close(fig)