-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathensemble.py
More file actions
57 lines (40 loc) · 1.94 KB
/
ensemble.py
File metadata and controls
57 lines (40 loc) · 1.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from inference import apply_model
def get_model_results(model_ids, data_path=None):
    """Apply each model to ``data_path`` and regroup the outputs per image.

    Args:
        model_ids: Iterable of model identifiers, each passed to ``apply_model``.
        data_path: Dataset location passed to ``apply_model``. When ``None``
            (the default), the ``DATASET_TEST_PATH`` environment variable is
            read at call time (not once at import time).

    Returns:
        A tuple ``(results, paths, columns)``:

        - ``results``: one entry per row of the result frames (per image);
          each entry is a list holding that row's non-``file`` values from
          every model, in ``model_ids`` order.
        - ``paths``: the ``file`` column values, in row order.
        - ``columns``: the non-``file`` column labels of the result frames,
          or ``None`` if ``model_ids`` is empty.

    Raises:
        ValueError: If a model returns a different ``file`` sequence or
            different output columns than the first model. The regrouping
            below is positional, so such a mismatch would otherwise mix up
            images or classes silently.
    """
    if data_path is None:
        # Resolved here rather than in the signature so that changes to the
        # environment after import are honoured.
        data_path = os.getenv("DATASET_TEST_PATH")

    per_model_rows = []
    paths = []
    columns = None
    for model_id in model_ids:
        _, result_df = apply_model(model_id, data_path)
        model_paths = result_df["file"].tolist()
        model_columns = result_df.columns.drop("file")
        if columns is None:
            paths = model_paths
            columns = model_columns
        else:
            # Rows/columns are combined by position below, so every model
            # must report the same files and output columns in the same order.
            if model_paths != paths:
                raise ValueError(
                    f"Model {model_id!r} returned files in a different order "
                    "than the first model; cannot combine results."
                )
            if not model_columns.equals(columns):
                raise ValueError(
                    f"Model {model_id!r} returned output columns "
                    f"{list(model_columns)} but the first model returned "
                    f"{list(columns)}; cannot combine results."
                )
        per_model_rows.append(result_df.drop(columns=["file"]).values.tolist())

    # Transpose so that each inner list holds one image's results over all models.
    results = [list(image_rows) for image_rows in zip(*per_model_rows)]
    return results, paths, columns
def ensemble_results(model_ids, data_path=None):
    """Ensemble several models by averaging their raw outputs per image.

    Args:
        model_ids: Model identifiers to apply to ``data_path`` and combine.
        data_path: Dataset location. When ``None`` (the default), the
            ``DATASET_TEST_PATH`` environment variable is read at call time
            (not once at import time).

    Returns:
        A DataFrame whose first column ``path`` holds each image's file
        value. It is followed by one column per model output column, each
        containing the arithmetic mean of that output across all models.
    """
    if data_path is None:
        # Resolved here rather than in the signature so that changes to the
        # environment after import are honoured.
        data_path = os.getenv("DATASET_TEST_PATH")

    results, paths, columns = get_model_results(model_ids, data_path)
    # The models have no softmax layer, so the ensemble is built by averaging
    # their raw outputs element-wise: ``row`` holds one image's output list
    # from every model, and zip(*row) groups the same output column together.
    averaged_results = [
        [sum(values) / len(values) for values in zip(*row)] for row in results
    ]
    # Downstream accuracy / confusion-matrix code expects a DataFrame.
    ensemble_results_df = pd.DataFrame(averaged_results, columns=columns)
    ensemble_results_df.insert(0, "path", paths)
    return ensemble_results_df
def save_confusion_matrix_as_heatmap(
    conf_matrix,
    labels,
    filename,
    title="Top-1: 82.16% & Top-3: 95.33%",
):
    """Render ``conf_matrix`` as an annotated heatmap and save it to ``filename``.

    Args:
        conf_matrix: Matrix passed to ``seaborn.heatmap``. Its rows are drawn
            on the y axis (labelled "Actual") and its columns on the x axis
            (labelled "Predicted"). Cell values are annotated with two decimals.
        labels: Tick labels used for both axes.
        filename: Output path passed to ``Figure.savefig``.
        title: Bold caption drawn near the bottom of the figure. The default
            keeps the previously hard-coded metrics; pass the actual scores
            for other matrices, or ``None`` to omit the caption.
    """
    fig, ax = plt.subplots(figsize=(10, 7))
    try:
        sns.heatmap(conf_matrix, annot=True, fmt='.2f', xticklabels=labels, yticklabels=labels, ax=ax)
        # Put the "Predicted" axis on top, as is common for confusion matrices.
        ax.xaxis.tick_top()
        ax.xaxis.set_label_position('top')
        ax.set_xlabel('Predicted')
        ax.set_ylabel('Actual')
        if title is not None:
            # y=0.1 places the caption in the lower part of the figure, since
            # the top is occupied by the x-axis ticks and label.
            fig.suptitle(title, y=0.1, weight='bold')
        # Save this specific figure rather than whatever pyplot considers current.
        fig.savefig(filename)
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # repeated calls accumulate figures and leak memory.
        plt.close(fig)