diff --git a/Cell_BLAST/blast.py b/Cell_BLAST/blast.py index 896417e..e8f51ae 100644 --- a/Cell_BLAST/blast.py +++ b/Cell_BLAST/blast.py @@ -23,7 +23,8 @@ MINIMAL = 0 -def _wasserstein_distance_impl(x: np.ndarray, y: np.ndarray): # pragma: no cover +@numba.jit(nopython=True, nogil=True, cache=True) +def wasserstein_distance(x: np.ndarray, y: np.ndarray): # pragma: no cover x_sorter = np.argsort(x) y_sorter = np.argsort(y) xy = np.concatenate((x, y)) @@ -34,37 +35,6 @@ def _wasserstein_distance_impl(x: np.ndarray, y: np.ndarray): # pragma: no cove return np.sum(np.multiply(np.abs(x_cdf - y_cdf), deltas)) -def _energy_distance_impl(x: np.ndarray, y: np.ndarray): # pragma: no cover - x_sorter = np.argsort(x) - y_sorter = np.argsort(y) - xy = np.concatenate((x, y)) - xy.sort() - deltas = np.diff(xy) - x_cdf = np.searchsorted(x[x_sorter], xy[:-1], "right") / x.size - y_cdf = np.searchsorted(y[y_sorter], xy[:-1], "right") / y.size - return np.sqrt(2 * np.sum(np.multiply(np.square(x_cdf - y_cdf), deltas))) - - -@numba.extending.overload( - scipy.stats.wasserstein_distance, jit_options={"nogil": True, "cache": True} -) -def _wasserstein_distance(x: np.ndarray, y: np.ndarray): # pragma: no cover - if ( - x == numba.float32[::1] and y == numba.float32[::1] - ) or ( - x == numba.float64[::1] and y == numba.float64[::1] - ): - return _wasserstein_distance_impl - - -@numba.extending.overload( - scipy.stats.energy_distance, jit_options={"nogil": True, "cache": True} -) -def _energy_distance(x: np.ndarray, y: np.ndarray): # pragma: no cover - if x == numba.float32[::1] and y == numba.float32[::1]: - return _energy_distance_impl - - @numba.jit(nopython=True, nogil=True, cache=True) def ed(x: np.ndarray, y: np.ndarray): # pragma: no cover r""" @@ -230,10 +200,10 @@ def npd_v1( np.std(y_posterior) + np.float32(eps) ) return 0.5 * ( - scipy.stats.wasserstein_distance( + wasserstein_distance( xy_posterior1[: len(x_posterior)], xy_posterior1[-len(y_posterior) :] ) - + scipy.stats.wasserstein_distance( + + wasserstein_distance( xy_posterior2[: len(x_posterior)], xy_posterior2[-len(y_posterior) :] ) ) @@ -939,13 +909,15 @@ def query( hits, dist, pval, - query - if store_dataset - else anndata.AnnData( - X=scipy.sparse.csr_matrix((query.shape[0], 0)), - obs=pd.DataFrame(index=query.obs.index), - var=pd.DataFrame(), - uns={}, + ( + query + if store_dataset + else anndata.AnnData( + X=scipy.sparse.csr_matrix((query.shape[0], 0)), + obs=pd.DataFrame(index=query.obs.index), + var=pd.DataFrame(), + uns={}, + ) ), )