-
Notifications
You must be signed in to change notification settings - Fork 0
/
confusionchart.py
66 lines (54 loc) · 2.33 KB
/
confusionchart.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os, sys
# Append the parent of this file's directory to sys.path; presumably that is
# where the project-local `util` package imported below lives. This must run
# before the `from util.plot_struct import DataStruct` line.
current = os.path.dirname(os.path.realpath(__file__))
parent = os.path.dirname(current)
sys.path.append(parent)
import numpy as np
from sklearn.metrics import confusion_matrix
from matplotlib import pyplot as plt
from util.plot_struct import DataStruct
import itertools
def plot_confusion_matrix(dataStruct: DataStruct, cmap=plt.cm.Blues, fontsize=20):
    """Draw a confusion matrix as an annotated heat map and save it to disk.

    The matrix is ``dataStruct.conf_matrix`` when that is set (non-empty);
    otherwise it is computed with
    ``sklearn.metrics.confusion_matrix(dataStruct.labels, dataStruct.pre)``.
    Rows are true labels and columns are predicted labels.

    Attributes read from ``dataStruct``:
        conf_matrix: optional precomputed matrix (nested lists or array).
        labels, pre: true / predicted labels; only used when ``conf_matrix``
            is not given.
        classes: tick labels, one per row/column.
        title: plot title.
        normalize: if true, every row is divided by its own row total (the
            fraction of each true class assigned to each predicted class)
            and rounded to 2 decimals; rows with no samples are shown as 0.
        filePath, dpi: forwarded to ``plt.savefig``.
        show: if true, ``plt.show()`` is called after saving.

    Args:
        dataStruct: container carrying the attributes listed above.
        cmap: matplotlib colormap used for the heat map.
        fontsize: font size for the title, ticks, colorbar and cell text.

    Note:
        Draws on the current pyplot figure; call e.g.
        ``plt.figure(figsize=(8, 7))`` beforehand to control the size.
    """
    classes = dataStruct.classes
    title = dataStruct.title

    provided = dataStruct.conf_matrix
    # `not array` raises ValueError for multi-element numpy arrays, so test
    # arrays by size and keep plain truthiness for lists / None.
    if isinstance(provided, np.ndarray):
        has_matrix = provided.size > 0
    else:
        has_matrix = bool(provided)
    if has_matrix:
        conf_numpy = np.asarray(provided)
    else:
        conf_numpy = confusion_matrix(dataStruct.labels, dataStruct.pre)

    if dataStruct.normalize:
        conf_numpy = conf_numpy.astype('float')
        # keepdims gives shape (n, 1) so each ROW is divided by its own total.
        # A bare sum(axis=1) has shape (n,) and broadcasts over columns,
        # which would divide column j by the total of row j instead.
        row_totals = conf_numpy.sum(axis=1, keepdims=True)
        # Rows with no samples would give 0/0 = nan; leave them at 0.
        conf_numpy = np.divide(conf_numpy, row_totals,
                               out=np.zeros_like(conf_numpy),
                               where=row_totals != 0)
        conf_numpy = np.around(conf_numpy, decimals=2)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(conf_numpy)

    plt.imshow(conf_numpy, interpolation='nearest', cmap=cmap)
    plt.title(title, fontsize=fontsize)
    cbar = plt.colorbar()
    cbar.ax.tick_params(labelsize=fontsize)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45, fontsize=fontsize)
    plt.yticks(tick_marks, classes, fontsize=fontsize)

    # 'd' only accepts integers; a float matrix (normalized, or supplied as
    # floats) must use a float format or format() raises ValueError.
    if dataStruct.normalize or not np.issubdtype(conf_numpy.dtype, np.integer):
        fmt = '.2f'
    else:
        fmt = 'd'
    # Use white text on dark (upper-half) cells so annotations stay readable.
    thresh = conf_numpy.max() / 2.
    for i, j in itertools.product(range(conf_numpy.shape[0]), range(conf_numpy.shape[1])):
        plt.text(j, i, format(conf_numpy[i, j], fmt),
                 horizontalalignment="center",
                 fontsize=fontsize,
                 color="white" if conf_numpy[i, j] > thresh else "black")

    plt.ylabel('True label', fontsize=fontsize)
    plt.xlabel('Predicted label', fontsize=fontsize)
    plt.tight_layout()
    plt.savefig(dataStruct.filePath, dpi=dataStruct.dpi)
    if dataStruct.show:
        plt.show()
if __name__ == '__main__':
    # Demo: plot a fixed 3-class confusion matrix, row-normalized, save it
    # to example/confusion_matrix.png and display it.
    dataStruct = DataStruct()
    # Alternative input: let the matrix be computed from raw labels and
    # predictions instead of supplying conf_matrix.
    # dataStruct.labels = [0,0,0,0,1,1,1,1,2,2,2,2]
    # dataStruct.pre = [0,0,0,0,1,1,1,1,2,2,2,2]
    dataStruct.conf_matrix = [[40,0,0],[1,35,4],[1,5,34]]
    dataStruct.classes = ['Pushing & \nPulling','Beckoning','Rubbing \nFingers']
    dataStruct.filePath = 'example/confusion_matrix.png'
    dataStruct.title = 'Confusion Matrix'
    dataStruct.normalize = True
    dataStruct.show = True
    # plt.savefig does not create missing directories; ensure the output
    # directory (relative to the current working directory) exists.
    out_dir = os.path.dirname(dataStruct.filePath)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plot_confusion_matrix(dataStruct)