-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathevaluation_graphs_basic.py
113 lines (88 loc) · 3.7 KB
/
evaluation_graphs_basic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import joblib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
import umap
from mpl_toolkits import mplot3d
from train_ganomaly import load_data_file
##############
# Parameters #
##############
model_dir = 'experiments/outliers/if'
# Only departments in this dict will be processed, the rest will be dropped
depa_string_mapdict = {
'Ob/gyn':'Ob/gyn',
'Adolescent': 'Adolescent medicine',
'Chirurgie':'Surgery',
'Réadaptation':'Long term care',
'Oncologie':'Oncology',
'Néonatologie':'NICU',
'Soins intensifs':'PICU',
'Pédiatrie':'General pediatrics',
'Psychiatrie':'Psychiatry',
'Pouponnière':'Nursery',
'Multispécialité':'Specialized pediatrics',
'Overall':'Overall'
}
#############
# Functions #
#############
def make_projection_graph(depas, predictions, coords):
indices = np.random.choice(np.arange(len(depas)), 2000)
sampled_depa = np.take(depas, indices)
sampled_predictions = np.take(predictions, indices)
sampled_coords = np.take(coords, indices, axis=0)
umap_vectors = umap.UMAP(n_components=3).fit_transform(sampled_coords)
graph_marker_dict = {1:'o', -1:'^'}
graph_markers = [graph_marker_dict[i] for i in sampled_predictions.tolist()]
graph_color_dict = {depa:n for depa, n in zip(list(set(sampled_depa)), range(len(list(set(sampled_depa)))))}
graph_colors = [graph_color_dict[depa] for depa in sampled_depa]
graph_data = []
for umapcoords, color, marker in zip(umap_vectors, graph_colors, graph_markers):
graph_data.append([umapcoords[0], umapcoords[1], umapcoords[2], color, marker])
graph_data_df = pd.DataFrame(data=graph_data, columns=['x', 'y', 'z', 'color', 'marker'])
ax = plt.figure(figsize=(16,10)).gca(projection='3d')
for marker, group in graph_data_df.groupby('marker'):
ax.scatter(
xs=group['x'] ,
ys=group['y'] ,
zs=group['z'],
c=group['color'],
marker=marker,
cmap='prism'
)
plt.savefig(os.path.join(model_dir, 'umap_projection.png'))
plt.gcf().clear()
def make_score_graph(depas, scores):
graph_data = []
for score, depa in zip(scores, depas):
if depa in depa_string_mapdict.keys():
graph_data.append([score, depa_string_mapdict[depa]])
graph_data_df = pd.DataFrame(data=graph_data, columns=['Anomaly score', 'Department'])
sns.set(font_scale=1.5, style="whitegrid")
sns.catplot(orient='h', y='Department', x='Anomaly score', data=graph_data_df, kind='box', showfliers=False, aspect=1.4, order=graph_data_df.groupby('Department').agg('median').sort_values(by='Anomaly score').index.values)
plt.savefig(os.path.join(model_dir, 'anomaly_score_by_depa.png'), dpi=600)
plt.gcf().clear()
def make_ratio_graph(ratios):
ratio_df = pd.DataFrame.from_dict(ratios, orient='index', columns=['Anomaly ratio'])
ratio_df.reset_index(inplace=True)
ratio_df.rename(inplace=True, index=str, columns={'index':'Department'})
sns.set(font_scale=1.5, style="whitegrid")
sns.barplot(orient='h', data=ratio_df, y='Department', x='Anomaly ratio', order=ratio_df.sort_values(by='Anomaly ratio')['Department'])
plt.subplots_adjust(left=0.45, bottom=0.15)
plt.savefig(os.path.join(model_dir, 'anomaly_ratio_by_depa.png'), dpi=600)
plt.gcf().clear()
###########
# Execute #
###########
if __name__ == '__main__':
cat_depa_list = load_data_file(os.path.join(model_dir, 'cat_depa_list.pkl'))
anomaly_ratios = load_data_file(os.path.join(model_dir, 'anomaly_ratios.pkl'))
coords = joblib.load(os.path.join(model_dir, 'coords.joblib'))
predictions = joblib.load(os.path.join(model_dir, 'predictions.joblib'))
scores = joblib.load(os.path.join(model_dir, 'scores.joblib'))
make_projection_graph(cat_depa_list, predictions, coords)
make_score_graph(cat_depa_list, scores)
make_ratio_graph(anomaly_ratios)