-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval_utils.py
194 lines (158 loc) · 5.73 KB
/
eval_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
#%%
"""
Evaluation utility functions.
This module contains util functions for computing evaluation scores.
"""
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
def summary_scores(all_scores):
    """Summarize group scores.

    Args:
        all_scores (dict{str,list}):
            {group name: score list (or 1-D array) of individual cells}.

    Returns:
        dict{str,float}:
            Group-wise mean scores. Groups with no scores are omitted.
        float:
            Unweighted mean of the group-wise scores (every group counts
            once, regardless of how many cells it has). NaN when no group
            has any score.
    """
    # len() rather than truthiness so numpy arrays of scores work as well
    # as plain lists.
    sep_scores = {k: np.mean(s) for k, s in all_scores.items() if len(s) > 0}
    if not sep_scores:
        # Same value np.mean([]) would give, without the RuntimeWarning.
        return sep_scores, np.nan
    # Average every group mean, including groups whose mean is exactly 0.0
    # (a truthiness filter on the float means would drop those).
    overal_agg = np.mean(list(sep_scores.values()))
    return sep_scores, overal_agg
def keep_type(adata, nodes, target, k_cluster):
    """Select cells of targeted type.

    Args:
        adata (Anndata):
            Anndata object.
        nodes (list or np.ndarray of int):
            Positional (0-based row) indexes of cells, e.g. one row of
            ``adata.uns['neighbors']['indices']``.
        target (str):
            Cluster name.
        k_cluster (str):
            Cluster key in adata.obs dataframe.

    Returns:
        np.ndarray:
            The entries of ``nodes`` whose cluster label equals ``target``,
            in their original order.
    """
    # Accept plain lists too: boolean-mask indexing below needs an ndarray.
    nodes = np.asarray(nodes)
    # `nodes` are row positions, not obs_names labels, so index positionally.
    # Series[int_array] on a label index relies on a deprecated fallback.
    labels = adata.obs[k_cluster].iloc[nodes].to_numpy()
    return nodes[labels == target]
def cross_boundary_correctness(
    adata,
    k_cluster,
    k_velocity,
    cluster_edges,
    return_raw=False,
    x_emb="X_umap"
):
    """Cross-Boundary Direction Correctness Score (A->B).

    For every cell of cluster A that has neighbors in cluster B, the mean
    cosine similarity between its embedded velocity and the displacement
    vectors towards those B neighbors.

    Args:
        adata (Anndata):
            Anndata object.
        k_cluster (str):
            key to the cluster column in adata.obs DataFrame.
        k_velocity (str):
            prefix of the embedded velocity key in adata.obsm. For
            ``x_emb="X_<basis>"`` the key ``"<k_velocity>_<basis>"`` is used.
            For non-UMAP bases it falls back to the first obsm key starting
            with ``k_velocity`` when that key is absent.
        cluster_edges (list of tuples("A", "B")):
            pairs of clusters has transition direction A->B
        return_raw (bool):
            return aggregated or raw scores.
        x_emb (str):
            key to x embedding for visualization.

    Returns:
        dict:
            all_scores (per-cell score lists) indexed by cluster_edges if
            return_raw, otherwise mean scores indexed by cluster_edges
            (NaN for an edge with no A cell bordering B).
        float:
            averaged score over all edges (only when not return_raw).

    Raises:
        KeyError: if the embedding or a matching velocity embedding is missing.
    """
    # Keep the key *name* separate from the loaded coordinates: comparing the
    # coordinate array itself with "X_umap" never selects the right branch.
    x_emb_key = x_emb
    x_emb = adata.obsm[x_emb_key]

    basis = x_emb_key[2:] if x_emb_key.startswith("X_") else x_emb_key
    v_key = "{}_{}".format(k_velocity, basis)
    if x_emb_key != "X_umap" and v_key not in adata.obsm:
        # Fallback: first velocity-like embedding available.
        candidates = [key for key in adata.obsm if key.startswith(k_velocity)]
        if not candidates:
            raise KeyError(
                "no velocity embedding starting with {!r} in adata.obsm".format(k_velocity)
            )
        v_key = candidates[0]
    v_emb = adata.obsm[v_key]

    # Hoisted out of the loops: cluster labels as a plain array, so per-cell
    # neighbor filtering is numpy indexing rather than pandas lookups.
    labels = adata.obs[k_cluster].to_numpy()
    neighbor_idx = adata.uns['neighbors']['indices']  # one row of neighbor positions per cell

    scores = {}
    all_scores = {}
    for u, v in cluster_edges:
        sel = labels == u
        type_score = []
        for x_pos, x_vel, nodes in zip(x_emb[sel], v_emb[sel], neighbor_idx[sel]):
            boundary_nodes = nodes[labels[nodes] == v]
            if len(boundary_nodes) == 0:
                continue
            position_dif = x_emb[boundary_nodes] - x_pos
            dir_scores = cosine_similarity(position_dif, x_vel.reshape(1, -1)).flatten()
            type_score.append(np.mean(dir_scores))
        # NaN (the value np.mean([]) yields) without the RuntimeWarning.
        scores[(u, v)] = np.mean(type_score) if type_score else np.nan
        all_scores[(u, v)] = type_score

    if return_raw:
        return all_scores
    return scores, (np.mean(list(scores.values())) if scores else np.nan)
def inner_cluster_coh(adata, k_cluster, k_velocity, return_raw=False):
    """In-cluster Coherence Score.

    For every cell, the mean cosine similarity between its velocity vector
    and the velocities of its neighbors that belong to the same cluster.

    Args:
        adata (Anndata):
            Anndata object.
        k_cluster (str):
            key to the cluster column in adata.obs DataFrame.
        k_velocity (str):
            key to the velocity matrix in adata.layers.
        return_raw (bool):
            return aggregated or raw scores.

    Returns:
        dict:
            all_scores (per-cell score lists) indexed by cluster name if
            return_raw, otherwise mean scores indexed by cluster name
            (NaN for a cluster with no scored cell).
        float:
            averaged score over all clusters (only when not return_raw).
    """
    clusters = np.unique(adata.obs[k_cluster])
    # Loop-invariant inputs, read once.
    labels = adata.obs[k_cluster].to_numpy()
    neighbor_idx = adata.uns['neighbors']['indices']
    velocities = adata.layers[k_velocity]
    if isinstance(velocities, np.ndarray) and np.issubdtype(velocities.dtype, np.floating):
        # Genes with NaN velocity (e.g. genes that were not fitted) would make
        # cosine_similarity reject the input; score on the remaining genes.
        finite_genes = ~np.isnan(velocities).any(axis=0)
        if not finite_genes.all():
            velocities = velocities[:, finite_genes]

    scores = {}
    all_scores = {}
    for cat in clusters:
        sel = labels == cat
        cat_vels = velocities[sel]
        cat_score = []
        for ith, nodes in enumerate(neighbor_idx[sel]):
            same_cat_nodes = nodes[labels[nodes] == cat]
            if len(same_cat_nodes) == 0:
                continue
            # cat_vels[[ith]] keeps the 2-D (1, n_genes) shape sklearn expects.
            cat_score.append(
                cosine_similarity(cat_vels[[ith]], velocities[same_cat_nodes]).mean()
            )
        all_scores[cat] = cat_score
        # NaN (the value np.mean([]) yields) without the RuntimeWarning.
        scores[cat] = np.mean(cat_score) if cat_score else np.nan

    if return_raw:
        return all_scores
    return scores, (np.mean(list(scores.values())) if scores else np.nan)
def evaluate(
    adata,
    cluster_edges,
    k_cluster,
    k_velocity="velocity",
    x_emb="X_umap",
    verbose=True
):
    """Evaluate velocity estimation results using 2 metrics.

    Computes Cross-Boundary Direction Correctness (A->B) and In-cluster
    Coherence.

    Args:
        adata (Anndata):
            Anndata object.
        cluster_edges (list of tuples("A", "B")):
            pairs of clusters has transition direction A->B
        k_cluster (str):
            key to the cluster column in adata.obs DataFrame.
        k_velocity (str):
            velocity key: ``adata.layers[k_velocity]`` for coherence and the
            ``"<k_velocity>_<basis>"`` embedding in adata.obsm for
            cross-boundary correctness.
        x_emb (str):
            key to x embedding for visualization.
        verbose (bool):
            if True, print group-wise and overall mean scores of each metric.

    Returns:
        dict:
            {metric name: raw per-cell scores}. The cross-boundary entry is
            keyed by (A, B) edge and the coherence entry by cluster name;
            each value is a list of per-cell scores (aggregate them with
            ``summary_scores``).
    """
    # Both metric functions are globals of this module and are called
    # directly; no self-import is needed (a relative self-import would fail
    # whenever this file is not loaded as part of a package).
    crs_bdr_crc = cross_boundary_correctness(
        adata, k_cluster, k_velocity, cluster_edges, True, x_emb
    )
    ic_coh = inner_cluster_coh(adata, k_cluster, k_velocity, True)
    if verbose:
        print("# Cross-Boundary Direction Correctness (A->B)\n{}\nTotal Mean: {}".format(*summary_scores(crs_bdr_crc)))
        print("# In-cluster Coherence\n{}\nTotal Mean: {}".format(*summary_scores(ic_coh)))
    return {
        "Cross-Boundary Direction Correctness (A->B)": crs_bdr_crc,
        "In-cluster Coherence": ic_coh,
    }