-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplotting.py
98 lines (83 loc) · 2.67 KB
/
plotting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import ListedColormap
def plot_tsne(embeddings, sentences, classes, legend_info = None):
    '''
    Scatter-plot 2-D embeddings coloured by class, with hover labels.

    Input:
        embeddings: np.array of shape num_sentences x 2
        sentences: list of sentences of length num_sentences
        classes: list of integer class labels of size num_sentences.
            Without legend_info these are typically binary (1 when the
            sentence belongs to class 1 and 0 otherwise). With legend_info,
            label i is drawn in the colour of legend_info[i].
        legend_info: optional sequence of (name, color) pairs, one per class
            index. It is used both to build the colormap and the legend.
    Output:
        Shows the plot (hovering a point displays its sentence) and returns
        the matplotlib Figure.
    ** The hover annotations are based on https://stackoverflow.com/a/47166787/1434041 **
    '''
    x = embeddings[:, 0]
    y = embeddings[:, 1]
    fig, ax = plt.subplots()
    if legend_info is None:
        sc = ax.scatter(x, y, c = classes, alpha = 0.5)
    else:
        # Materialise once: the pairs are consumed twice below.
        legend_info = list(legend_info)
        class_mpatches = [mpatches.Patch(color = color, label = name)
                          for name, color in legend_info]
        colormap = ListedColormap([color for _, color in legend_info])
        # Pin the colour range to the legend's class indices so that label i
        # always maps to colour i. Otherwise matplotlib rescales to the
        # min/max of `classes`, and colours shift whenever a class is absent.
        sc = ax.scatter(x, y, c = classes, cmap = colormap,
                        vmin = -0.5, vmax = len(legend_info) - 0.5,
                        alpha = 0.5)
        ax.legend(handles = class_mpatches)

    annot = ax.annotate('', xy = (0, 0),
                        xytext = (-10, 10),
                        textcoords = 'offset points',
                        bbox = dict(boxstyle = 'round', fc = 'w'),
                        arrowprops = dict(arrowstyle = '->'))
    annot.set_visible(False)

    def update_annot(ind):
        # Anchor the box at the first hovered point, but list the sentences
        # of every point under the cursor (overlapping points).
        hit = ind['ind']
        annot.xy = sc.get_offsets()[hit[0]]
        annot.set_text('\n'.join(sentences[n] for n in hit))
        annot.get_bbox_patch().set_alpha(0.4)

    def hover(event):
        # Only hit-test when the cursor is inside our axes. Anywhere else,
        # treat it as "not over a point" so a stale label gets hidden.
        if event.inaxes == ax:
            over_point, ind = sc.contains(event)
        else:
            over_point, ind = False, None
        if over_point:
            update_annot(ind)
            annot.set_visible(True)
            fig.canvas.draw_idle()
        elif annot.get_visible():
            annot.set_visible(False)
            fig.canvas.draw_idle()

    fig.canvas.mpl_connect('motion_notify_event', hover)
    plt.show()
    return fig
def plot_precision_recall_f1_curve(precisions, recalls, thresholds, f1_scores, xlim = (0.2, 1.0)):
    '''
    Plot a precision-recall curve with an F1-score overlay on a new figure.

    Input:
        precisions, recalls, thresholds, f1_scores: arrays with one entry per
            confidence threshold (all the same length).
        xlim: (min, max) recall range shown on the x axis. Defaults to
            (0.2, 1.0), the range this function always used before.
    Output:
        Shows the plot and returns the matplotlib Figure.
    '''
    # Use a dedicated figure so this plot never lands on top of a figure
    # that is still current from an earlier call (e.g. in interactive mode).
    fig, ax = plt.subplots()
    # Each (recall, precision) point is coloured by its confidence threshold.
    ax.scatter(recalls, precisions, c = thresholds)
    ax.plot(recalls, precisions, label = 'precision-recall')
    # The F1 curve shares the 0-1 y axis with precision.
    ax.plot(recalls, f1_scores, c = 'r', label = 'f1 score')
    ax.set_xlabel('recall')
    ax.set_ylabel('precision')
    ax.set_ylim(0.0, 1.0)
    ax.set_xlim(*xlim)
    ax.legend()
    plt.show()
    return fig
def plot_precision_by_conf(bins_, precision_by_bin):
    '''
    Plot precision per confidence bin (a calibration-style chart) on a new figure.

    Input:
        bins_: x positions, one per confidence bin
        precision_by_bin: precision value for each bin (same length as bins_)
    Output:
        Shows the plot (line plus markers, y axis fixed to 0-1) and returns
        the matplotlib Figure.
    '''
    # Use a dedicated figure instead of drawing on whatever figure is current.
    fig, ax = plt.subplots()
    ax.plot(bins_, precision_by_bin)
    ax.scatter(bins_, precision_by_bin)
    ax.set_ylim(0.0, 1.0)
    ax.set_xlabel('Confidence Bins')
    ax.set_ylabel('Precision')
    ax.set_title('Confidence Score Calibration')
    plt.show()
    return fig