From 6d4948b434fec40c3d857fa44270af434631fbdc Mon Sep 17 00:00:00 2001 From: Zhi-Jie Cao Date: Tue, 7 Feb 2023 21:34:25 +0800 Subject: [PATCH] Formatting --- Cell_BLAST/__init__.py | 20 +- Cell_BLAST/blast.py | 738 ++++++++++++++++++++++--------------- Cell_BLAST/config.py | 5 +- Cell_BLAST/data.py | 231 +++++++----- Cell_BLAST/directi.py | 583 ++++++++++++++++------------- Cell_BLAST/latent.py | 518 ++++++++++++++++++-------- Cell_BLAST/metrics.py | 263 ++++++++----- Cell_BLAST/prob.py | 534 +++++++++++++++++---------- Cell_BLAST/rebuild.py | 164 +++++---- Cell_BLAST/rmbatch.py | 215 +++++++---- Cell_BLAST/utils.py | 149 ++++---- Cell_BLAST/weighting.py | 416 ++++++++++++--------- test/regression_refresh.py | 69 ---- test/regression_test.py | 82 ----- 14 files changed, 2352 insertions(+), 1635 deletions(-) delete mode 100644 test/regression_refresh.py delete mode 100644 test/regression_test.py diff --git a/Cell_BLAST/__init__.py b/Cell_BLAST/__init__.py index 98ff70b..b346173 100644 --- a/Cell_BLAST/__init__.py +++ b/Cell_BLAST/__init__.py @@ -6,17 +6,29 @@ from importlib.metadata import version except ModuleNotFoundError: from pkg_resources import get_distribution + version = lambda name: get_distribution(name).version from .utils import in_ipynb if not in_ipynb(): import matplotlib - matplotlib.use("agg") -from . import (blast, config, data, directi, latent, metrics, prob, - rebuild, rmbatch, utils, weighting) + matplotlib.use("agg") +from . import ( + blast, + config, + data, + directi, + latent, + metrics, + prob, + rebuild, + rmbatch, + utils, + weighting, +) name = "Cell_BLAST" __copyright__ = "2022, Gao Lab" @@ -34,5 +46,5 @@ "rebuild", "rmbatch", "utils", - "weighting" + "weighting", ] diff --git a/Cell_BLAST/blast.py b/Cell_BLAST/blast.py index 831dabf..c12088e 100644 --- a/Cell_BLAST/blast.py +++ b/Cell_BLAST/blast.py @@ -5,9 +5,10 @@ import collections import os import re -import typing import tempfile +import typing +import anndata import joblib import numba import numpy as np @@ -15,7 +16,6 @@ import scipy.sparse import scipy.stats import sklearn.neighbors -import anndata from . 
import config, data, directi, metrics, utils @@ -29,8 +29,8 @@ def _wasserstein_distance_impl(x: np.ndarray, y: np.ndarray): # pragma: no cove xy = np.concatenate((x, y)) xy.sort() deltas = np.diff(xy) - x_cdf = np.searchsorted(x[x_sorter], xy[:-1], 'right') / x.size - y_cdf = np.searchsorted(y[y_sorter], xy[:-1], 'right') / y.size + x_cdf = np.searchsorted(x[x_sorter], xy[:-1], "right") / x.size + y_cdf = np.searchsorted(y[y_sorter], xy[:-1], "right") / y.size return np.sum(np.multiply(np.abs(x_cdf - y_cdf), deltas)) @@ -46,14 +46,16 @@ def _energy_distance_impl(x: np.ndarray, y: np.ndarray): # pragma: no cover @numba.extending.overload( - scipy.stats.wasserstein_distance, jit_options={"nogil": True, "cache": True}) + scipy.stats.wasserstein_distance, jit_options={"nogil": True, "cache": True} +) def _wasserstein_distance(x: np.ndarray, y: np.ndarray): # pragma: no cover if x == numba.float32[::1] and y == numba.float32[::1]: return _wasserstein_distance_impl @numba.extending.overload( - scipy.stats.energy_distance, jit_options={"nogil": True, "cache": True}) + scipy.stats.energy_distance, jit_options={"nogil": True, "cache": True} +) def _energy_distance(x: np.ndarray, y: np.ndarray): # pragma: no cover if x == numba.float32[::1] and y == numba.float32[::1]: return _energy_distance_impl @@ -70,7 +72,7 @@ def ed(x: np.ndarray, y: np.ndarray): # pragma: no cover @numba.jit(nopython=True, nogil=True, cache=True) def _md( - x: np.ndarray, y: np.ndarray, x_posterior: np.ndarray + x: np.ndarray, y: np.ndarray, x_posterior: np.ndarray ) -> np.ndarray: # pragma: no cover r""" x : latent_dim @@ -87,8 +89,7 @@ def _md( @numba.jit(nopython=True, nogil=True, cache=True) def md( - x: np.ndarray, y: np.ndarray, - x_posterior: np.ndarray, y_posterior: np.ndarray + x: np.ndarray, y: np.ndarray, x_posterior: np.ndarray, y_posterior: np.ndarray ) -> np.ndarray: # pragma: no cover r""" x : latent_dim @@ -103,32 +104,37 @@ def md( @numba.jit(nopython=True, nogil=True, cache=True) def _compute_pcasd( - x: np.ndarray, x_posterior: np.ndarray, eps: float + x: np.ndarray, x_posterior: np.ndarray, eps: float ) -> np.ndarray: # pragma: no cover r""" x : latent_dim x_posterior : n_posterior * latent_dim """ - centered_x_posterior = x_posterior - np.sum(x_posterior, axis=0) / x_posterior.shape[0] + centered_x_posterior = ( + x_posterior - np.sum(x_posterior, axis=0) / x_posterior.shape[0] + ) cov_x = np.dot(centered_x_posterior.T, centered_x_posterior) - v = np.real(np.linalg.eig(cov_x.astype(np.complex64))[1]) # Suppress domain change due to rounding errors + v = np.real( + np.linalg.eig(cov_x.astype(np.complex64))[1] + ) # Suppress domain change due to rounding errors x_posterior = np.dot(x_posterior - x, v) squared_x_posterior = np.square(x_posterior) asd = np.empty((2, x_posterior.shape[1]), dtype=np.float32) for p in range(x_posterior.shape[1]): mask = x_posterior[:, p] < 0 - asd[0, p] = np.sqrt(( - np.sum(squared_x_posterior[mask, p]) - ) / max(np.sum(mask), 1)) + eps - asd[1, p] = np.sqrt(( - np.sum(squared_x_posterior[~mask, p]) - ) / max(np.sum(~mask), 1)) + eps + asd[0, p] = ( + np.sqrt((np.sum(squared_x_posterior[mask, p])) / max(np.sum(mask), 1)) + eps + ) + asd[1, p] = ( + np.sqrt((np.sum(squared_x_posterior[~mask, p])) / max(np.sum(~mask), 1)) + + eps + ) return np.concatenate((v, asd), axis=0) @numba.jit(nopython=True, nogil=True, cache=True) def _compute_pcasd_across_models( - x: np.ndarray, x_posterior: np.ndarray, eps: float = 1e-1 + x: np.ndarray, x_posterior: np.ndarray, eps: float = 1e-1 ) -> 
np.ndarray: # pragma: no cover r""" x : n_models * latent_dim @@ -142,8 +148,11 @@ def _compute_pcasd_across_models( @numba.jit(nopython=True, nogil=True, cache=True) def _amd( - x: np.ndarray, y: np.ndarray, x_posterior: np.ndarray, - eps: float, x_is_pcasd: bool = False + x: np.ndarray, + y: np.ndarray, + x_posterior: np.ndarray, + eps: float, + x_is_pcasd: bool = False, ) -> np.ndarray: # pragma: no cover r""" x : latent_dim @@ -167,9 +176,13 @@ def _amd( @numba.jit(nopython=True, nogil=True, cache=True) def amd( - x: np.ndarray, y: np.ndarray, - x_posterior: np.ndarray, y_posterior: np.ndarray, eps: float = 1e-1, - x_is_pcasd: bool = False, y_is_pcasd: bool = False + x: np.ndarray, + y: np.ndarray, + x_posterior: np.ndarray, + y_posterior: np.ndarray, + eps: float = 1e-1, + x_is_pcasd: bool = False, + y_is_pcasd: bool = False, ) -> np.ndarray: # pragma: no cover r""" x : latent_dim @@ -180,15 +193,18 @@ def amd( if np.all(x == y): return 0.0 return 0.5 * ( - _amd(x, y, x_posterior, eps, x_is_pcasd) + - _amd(y, x, y_posterior, eps, y_is_pcasd) + _amd(x, y, x_posterior, eps, x_is_pcasd) + + _amd(y, x, y_posterior, eps, y_is_pcasd) ) @numba.jit(nopython=True, nogil=True, cache=True) def npd_v1( - x: np.ndarray, y: np.ndarray, - x_posterior: np.ndarray, y_posterior: np.ndarray, eps: float = 0.0 + x: np.ndarray, + y: np.ndarray, + x_posterior: np.ndarray, + y_posterior: np.ndarray, + eps: float = 0.0, ) -> np.ndarray: # pragma: no cover r""" x : latent_dim @@ -203,20 +219,25 @@ def npd_v1( x_posterior = np.sum(x_posterior * projection, axis=1) # n_posterior_samples y_posterior = np.sum(y_posterior * projection, axis=1) # n_posterior_samples xy_posterior = np.concatenate((x_posterior, y_posterior)) - xy_posterior1 = (xy_posterior - np.mean(x_posterior)) / (np.std(x_posterior) + np.float32(eps)) - xy_posterior2 = (xy_posterior - np.mean(y_posterior)) / (np.std(y_posterior) + np.float32(eps)) - return 0.5 * (scipy.stats.wasserstein_distance( - xy_posterior1[:len(x_posterior)], - xy_posterior1[-len(y_posterior):] - ) + scipy.stats.wasserstein_distance( - xy_posterior2[:len(x_posterior)], - xy_posterior2[-len(y_posterior):] - )) + xy_posterior1 = (xy_posterior - np.mean(x_posterior)) / ( + np.std(x_posterior) + np.float32(eps) + ) + xy_posterior2 = (xy_posterior - np.mean(y_posterior)) / ( + np.std(y_posterior) + np.float32(eps) + ) + return 0.5 * ( + scipy.stats.wasserstein_distance( + xy_posterior1[: len(x_posterior)], xy_posterior1[-len(y_posterior) :] + ) + + scipy.stats.wasserstein_distance( + xy_posterior2[: len(x_posterior)], xy_posterior2[-len(y_posterior) :] + ) + ) @numba.jit(nopython=True, nogil=True, cache=True) def _npd_v2( - x: np.ndarray, y: np.ndarray, x_posterior: np.ndarray, eps: float + x: np.ndarray, y: np.ndarray, x_posterior: np.ndarray, eps: float ) -> np.ndarray: # pragma: no cover r""" x : latent_dim @@ -230,17 +251,17 @@ def _npd_v2( projected_noise = np.sum((x_posterior - x) * udev, axis=1) projected_y = np.sum((y - x) * udev) mask = (projected_noise * projected_y) >= 0 - scaler = np.sqrt( - np.sum(np.square(projected_noise[mask])) / - max(np.sum(mask), 1) - ) + scaler = np.sqrt(np.sum(np.square(projected_noise[mask])) / max(np.sum(mask), 1)) return np.abs(projected_y) / (scaler + eps) @numba.jit(nopython=True, nogil=True, cache=True) def npd_v2( - x: np.ndarray, y: np.ndarray, - x_posterior: np.ndarray, y_posterior: np.ndarray, eps: float = 1e-1 + x: np.ndarray, + y: np.ndarray, + x_posterior: np.ndarray, + y_posterior: np.ndarray, + eps: float = 1e-1, ) -> 
np.ndarray: # pragma: no cover r""" x : latent_dim @@ -250,16 +271,12 @@ def npd_v2( """ if np.all(x == y): return 0.0 - return 0.5 * ( - _npd_v2(x, y, x_posterior, eps) + - _npd_v2(y, x, y_posterior, eps) - ) + return 0.5 * (_npd_v2(x, y, x_posterior, eps) + _npd_v2(y, x, y_posterior, eps)) @numba.jit(nopython=True, nogil=True, cache=True) def _hit_ed_across_models( - query_latent: np.ndarray, - ref_latent: np.ndarray + query_latent: np.ndarray, ref_latent: np.ndarray ) -> np.ndarray: # pragma: no cover r""" query_latent : n_models * latent_dim @@ -277,8 +294,10 @@ def _hit_ed_across_models( @numba.jit(nopython=True, nogil=True, cache=True) def _hit_md_across_models( - query_latent: np.ndarray, ref_latent: np.ndarray, - query_posterior: np.ndarray, ref_posterior: np.ndarray + query_latent: np.ndarray, + ref_latent: np.ndarray, + query_posterior: np.ndarray, + ref_posterior: np.ndarray, ) -> np.ndarray: # pragma: no cover r""" query_latent : n_models * latent_dim @@ -300,8 +319,11 @@ def _hit_md_across_models( @numba.jit(nopython=True, nogil=True, cache=True) def _hit_amd_across_models( - query_latent: np.ndarray, ref_latent: np.ndarray, - query_posterior: np.ndarray, ref_posterior: np.ndarray, eps: float = 1e-1 + query_latent: np.ndarray, + ref_latent: np.ndarray, + query_posterior: np.ndarray, + ref_posterior: np.ndarray, + eps: float = 1e-1, ) -> np.ndarray: # pragma: no cover r""" query_latent : n_models * latent_dim @@ -318,17 +340,24 @@ def _hit_amd_across_models( y = ref_latent[j, i, ...] # latent_dim y_posterior = ref_posterior[j, i, ...] # n_posterior * latent_dim dist[j, i] = amd( - x, y, x_posterior, y_posterior, - eps=eps, x_is_pcasd=False, y_is_pcasd=True + x, + y, + x_posterior, + y_posterior, + eps=eps, + x_is_pcasd=False, + y_is_pcasd=True, ) return dist @numba.jit(nopython=True, nogil=True, cache=True) def _hit_npd_v1_across_models( - query_latent: np.ndarray, ref_latent: np.ndarray, - query_posterior: np.ndarray, ref_posterior: np.ndarray, - eps: float = 0.0 + query_latent: np.ndarray, + ref_latent: np.ndarray, + query_posterior: np.ndarray, + ref_posterior: np.ndarray, + eps: float = 0.0, ) -> np.ndarray: # pragma: no cover r""" query_latent : n_models * latent_dim @@ -350,9 +379,11 @@ def _hit_npd_v1_across_models( @numba.jit(nopython=True, nogil=True, cache=True) def _hit_npd_v2_across_models( - query_latent: np.ndarray, ref_latent: np.ndarray, - query_posterior: np.ndarray, ref_posterior: np.ndarray, - eps: float = 1e-1 + query_latent: np.ndarray, + ref_latent: np.ndarray, + query_posterior: np.ndarray, + ref_posterior: np.ndarray, + eps: float = 1e-1, ) -> np.ndarray: # pragma: no cover r""" query_latent : n_models * latent_dim @@ -377,12 +408,11 @@ def _hit_npd_v2_across_models( md: _hit_md_across_models, amd: _hit_amd_across_models, npd_v1: _hit_npd_v1_across_models, - npd_v2: _hit_npd_v2_across_models + npd_v2: _hit_npd_v2_across_models, } class BLAST(object): - r""" Cell BLAST @@ -446,11 +476,16 @@ class BLAST(object): """ def __init__( - self, models: typing.List[directi.DIRECTi], - ref: anndata.AnnData, distance_metric: str = "npd_v1", - n_posterior: int = 50, n_empirical: int = 10000, - cluster_empirical: bool = False, eps: typing.Optional[float] = None, - force_components: bool = True, **kwargs + self, + models: typing.List[directi.DIRECTi], + ref: anndata.AnnData, + distance_metric: str = "npd_v1", + n_posterior: int = 50, + n_empirical: int = 10000, + cluster_empirical: bool = False, + eps: typing.Optional[float] = None, + force_components: bool = True, + 
**kwargs, ) -> None: self.models = models self.ref = anndata.AnnData( @@ -463,8 +498,11 @@ def __init__( self.posterior = np.array([None] * self.ref.shape[0]) self.empirical = None - self.distance_metric = globals()[distance_metric] \ - if isinstance(distance_metric, str) else distance_metric + self.distance_metric = ( + globals()[distance_metric] + if isinstance(distance_metric, str) + else distance_metric + ) self.n_posterior = n_posterior if self.distance_metric is not ed else 0 self.n_empirical = n_empirical self.cluster_empirical = cluster_empirical @@ -479,59 +517,75 @@ def __len__(self) -> int: def __getitem__(self, s) -> "BLAST": blast = BLAST( np.array(self.models)[s].tolist(), - self.ref, self.distance_metric, self.n_posterior, self.n_empirical, - self.cluster_empirical, self.eps, force_components=False + self.ref, + self.distance_metric, + self.n_posterior, + self.n_empirical, + self.cluster_empirical, + self.eps, + force_components=False, ) blast.latent = self.latent[:, s, ...] if self.latent is not None else None blast.cluster = self.cluster[:, s, ...] if self.cluster is not None else None - blast.nearest_neighbors = np.array(self.nearest_neighbors)[s].tolist() \ - if self.nearest_neighbors is not None else None + blast.nearest_neighbors = ( + np.array(self.nearest_neighbors)[s].tolist() + if self.nearest_neighbors is not None + else None + ) if self.posterior is not None: for i in range(self.posterior.size): if self.posterior[i] is not None: blast.posterior[i] = self.posterior[i][s, ...] - blast.empirical = [ - item for item in np.array(self.empirical)[s] - ] if self.empirical is not None else None + blast.empirical = ( + [item for item in np.array(self.empirical)[s]] + if self.empirical is not None + else None + ) return blast def _get_latent(self, n_jobs: int) -> np.ndarray: # n_cells * n_models * latent_dim if self.latent is None: utils.logger.info("Projecting to latent space...") - self.latent = np.stack(joblib.Parallel( - n_jobs=min(n_jobs, len(self)), backend="threading" - )(joblib.delayed(model.inference)( - self.ref - ) for model in self.models), axis=1) + self.latent = np.stack( + joblib.Parallel(n_jobs=min(n_jobs, len(self)), backend="threading")( + joblib.delayed(model.inference)(self.ref) for model in self.models + ), + axis=1, + ) return self.latent def _get_cluster(self, n_jobs: int) -> np.ndarray: # n_cells * n_models if self.cluster is None: utils.logger.info("Obtaining intrinsic clustering...") - self.cluster = np.stack(joblib.Parallel( - n_jobs=min(n_jobs, len(self)), backend="threading" - )(joblib.delayed(model.clustering)( - self.ref - ) for model in self.models), axis=1) + self.cluster = np.stack( + joblib.Parallel(n_jobs=min(n_jobs, len(self)), backend="threading")( + joblib.delayed(model.clustering)(self.ref) for model in self.models + ), + axis=1, + ) return self.cluster def _get_posterior( - self, n_jobs: int, random_seed: int, - idx: typing.Optional[np.ndarray] = None + self, n_jobs: int, random_seed: int, idx: typing.Optional[np.ndarray] = None ) -> np.ndarray: # n_cells * (n_models * n_posterior * latent_dim) if idx is None: idx = np.arange(self.ref.shape[0]) - new_idx = np.intersect1d(np.unique(idx), np.where(np.vectorize( - lambda x: x is None - )(self.posterior))[0]) + new_idx = np.intersect1d( + np.unique(idx), + np.where(np.vectorize(lambda x: x is None)(self.posterior))[0], + ) if new_idx.size: utils.logger.info("Sampling from posteriors...") new_ref = self.ref[new_idx, :] - new_posterior = np.stack(joblib.Parallel( - n_jobs=min(n_jobs, 
len(self)), backend="loky" - )(joblib.delayed(model.inference)( - new_ref, n_posterior=self.n_posterior, random_seed=random_seed - ) for model in self.models), axis=1) # n_cells * n_models * n_posterior * latent_dim + new_posterior = np.stack( + joblib.Parallel(n_jobs=min(n_jobs, len(self)), backend="loky")( + joblib.delayed(model.inference)( + new_ref, n_posterior=self.n_posterior, random_seed=random_seed + ) + for model in self.models + ), + axis=1, + ) # n_cells * n_models * n_posterior * latent_dim # NOTE: Slow discontigous memcopy here, but that's necessary since # we will be caching values by cells. It also makes values more # contiguous and faster to access in later cell-based operations. @@ -540,15 +594,20 @@ def _get_posterior( new_latent = self._get_latent(n_jobs)[new_idx] self.posterior[new_idx] = joblib.Parallel( n_jobs=n_jobs, backend="threading" - )(joblib.delayed(_compute_pcasd_across_models)( - _new_latent, _new_posterior, **dist_kws - ) for _new_latent, _new_posterior in zip(new_latent, new_posterior)) + )( + joblib.delayed(_compute_pcasd_across_models)( + _new_latent, _new_posterior, **dist_kws + ) + for _new_latent, _new_posterior in zip(new_latent, new_posterior) + ) else: - self.posterior[new_idx] = [item for item in new_posterior] # NOTE: No memcopy here + self.posterior[new_idx] = [ + item for item in new_posterior + ] # NOTE: No memcopy here return self.posterior[idx] def _get_nearest_neighbors( - self, n_jobs: int + self, n_jobs: int ) -> typing.List[sklearn.neighbors.NearestNeighbors]: # n_models if self.nearest_neighbors is None: latent = self._get_latent(n_jobs).swapaxes(0, 1) @@ -557,20 +616,19 @@ def _get_nearest_neighbors( utils.logger.info("Fitting nearest neighbor trees...") self.nearest_neighbors = joblib.Parallel( n_jobs=min(n_jobs, len(self)), backend="loky" - )(joblib.delayed(self._fit_nearest_neighbors)( - _latent - ) for _latent in latent) + )( + joblib.delayed(self._fit_nearest_neighbors)(_latent) + for _latent in latent + ) return self.nearest_neighbors def _get_empirical( - self, n_jobs: int, random_seed: int + self, n_jobs: int, random_seed: int ) -> np.ndarray: # n_models * [n_clusters * n_empirical] if self.empirical is None: utils.logger.info("Generating empirical null distributions...") if not self.cluster_empirical: - self.cluster = np.zeros(( - self.ref.shape[0], len(self) - ), dtype=int) + self.cluster = np.zeros((self.ref.shape[0], len(self)), dtype=int) latent = self._get_latent(n_jobs) cluster = self._get_cluster(n_jobs) rs = np.random.RandomState(random_seed) @@ -585,46 +643,54 @@ def _get_empirical( for k in range(len(self)): # model_idx empirical = np.zeros((np.max(cluster[:, k]) + 1, self.n_empirical)) for c in np.unique(cluster[:, k]): # cluster_idx - fg = rs.choice(np.where(cluster[:, k] == c)[0], size=self.n_empirical) + fg = rs.choice( + np.where(cluster[:, k] == c)[0], size=self.n_empirical + ) if self.distance_metric is ed: - empirical[c] = np.sort(joblib.Parallel( - n_jobs=n_jobs, backend="threading" - )(joblib.delayed(self.distance_metric)( - latent[fg[i]], latent[bg[i]] - ) for i in range(self.n_empirical))) + empirical[c] = np.sort( + joblib.Parallel(n_jobs=n_jobs, backend="threading")( + joblib.delayed(self.distance_metric)( + latent[fg[i]], latent[bg[i]] + ) + for i in range(self.n_empirical) + ) + ) else: fg_posterior = self._get_posterior(n_jobs, random_seed, idx=fg) - empirical[c] = np.sort(joblib.Parallel( - n_jobs=n_jobs, backend="threading" - )(joblib.delayed(self.distance_metric)( - latent[fg[i], k], latent[bg[i], 
k], - fg_posterior[i][k], bg_posterior[i][k], - **dist_kws - ) for i in range(self.n_empirical))) + empirical[c] = np.sort( + joblib.Parallel(n_jobs=n_jobs, backend="threading")( + joblib.delayed(self.distance_metric)( + latent[fg[i], k], + latent[bg[i], k], + fg_posterior[i][k], + bg_posterior[i][k], + **dist_kws, + ) + for i in range(self.n_empirical) + ) + ) self.empirical.append(empirical) return self.empirical def _force_components( - self, n_jobs: int = config._USE_GLOBAL, - random_seed: int = config._USE_GLOBAL + self, n_jobs: int = config._USE_GLOBAL, random_seed: int = config._USE_GLOBAL ) -> None: n_jobs = config.N_JOBS if n_jobs == config._USE_GLOBAL else n_jobs - random_seed = config.RANDOM_SEED if random_seed == config._USE_GLOBAL else random_seed + random_seed = ( + config.RANDOM_SEED if random_seed == config._USE_GLOBAL else random_seed + ) self._get_nearest_neighbors(n_jobs) if self.distance_metric is not ed: self._get_posterior(n_jobs, random_seed) self._get_empirical(n_jobs, random_seed) @staticmethod - def _fit_nearest_neighbors( - x: np.ndarray - ) -> sklearn.neighbors.NearestNeighbors: + def _fit_nearest_neighbors(x: np.ndarray) -> sklearn.neighbors.NearestNeighbors: return sklearn.neighbors.NearestNeighbors().fit(x) @staticmethod def _nearest_neighbor_search( - nn: sklearn.neighbors.NearestNeighbors, - query: np.ndarray, n_neighbors: int + nn: sklearn.neighbors.NearestNeighbors, query: np.ndarray, n_neighbors: int ) -> np.ndarray: return nn.kneighbors(query, n_neighbors=n_neighbors)[1] @@ -636,8 +702,7 @@ def _nearest_neighbor_merge(x: np.ndarray) -> np.ndarray: # pragma: no cover @staticmethod @numba.jit(nopython=True, nogil=True, cache=True) def _empirical_pvalue( - hits: np.ndarray, dist: np.ndarray, - cluster: np.ndarray, empirical: np.ndarray + hits: np.ndarray, dist: np.ndarray, cluster: np.ndarray, empirical: np.ndarray ) -> np.ndarray: # pragma: no cover r""" hits : n_hits @@ -648,9 +713,10 @@ def _empirical_pvalue( pval = np.empty(dist.shape) for i in range(dist.shape[1]): # model index for j in range(dist.shape[0]): # hit index - pval[j, i] = np.searchsorted( - empirical[i][cluster[hits[j], i]], dist[j, i] - ) / empirical[i].shape[1] + pval[j, i] = ( + np.searchsorted(empirical[i][cluster[hits[j], i]], dist[j, i]) + / empirical[i].shape[1] + ) return pval def save(self, path: str, only_used_genes: bool = True) -> None: @@ -669,10 +735,13 @@ def save(self, path: str, only_used_genes: bool = True) -> None: if self.ref is not None: if only_used_genes: if "__libsize__" not in self.ref.obs.columns: - data.compute_libsize(self.ref) # So that align will still work properly - ref = data.select_vars(self.ref, np.unique(np.concatenate([ - model.genes for model in self.models - ]))) + data.compute_libsize( + self.ref + ) # So that align will still work properly + ref = data.select_vars( + self.ref, + np.unique(np.concatenate([model.genes for model in self.models])), + ) ref.uns["distance_metric"] = self.distance_metric.__name__ ref.uns["n_posterior"] = self.n_posterior ref.uns["n_empirical"] = self.n_empirical @@ -683,8 +752,14 @@ def save(self, path: str, only_used_genes: bool = True) -> None: if self.latent is not None: ref.uns["cluster"] = self.cluster if self.empirical is not None: - ref.uns["empirical"] = {str(i): item for i, item in enumerate(self.empirical)} - ref.uns["posterior"] = {str(i): item for i, item in enumerate(self.posterior) if item is not None} + ref.uns["empirical"] = { + str(i): item for i, item in enumerate(self.empirical) + } + 
ref.uns["posterior"] = { + str(i): item + for i, item in enumerate(self.posterior) + if item is not None + } ref.write(os.path.join(path, "ref.h5ad")) for i in range(len(self)): self.models[i].save(os.path.join(path, f"model_{i}")) @@ -712,23 +787,29 @@ def load(cls, path: str, mode: int = NORMAL, **kwargs): ref = anndata.read_h5ad(os.path.join(path, "ref.h5ad")) models = [] model_paths = sorted( - os.path.join(path, d) for d in os.listdir(path) - if re.fullmatch(r'model_[0-9]+', d) - and os.path.isdir(os.path.join(path, d)) + os.path.join(path, d) + for d in os.listdir(path) + if re.fullmatch(r"model_[0-9]+", d) and os.path.isdir(os.path.join(path, d)) ) for model_path in model_paths: models.append(directi.DIRECTi.load(model_path, _mode=mode)) blast = cls( - models, ref, ref.uns["distance_metric"], ref.uns["n_posterior"], - ref.uns["n_empirical"], ref.uns["cluster_empirical"], + models, + ref, + ref.uns["distance_metric"], + ref.uns["n_posterior"], + ref.uns["n_empirical"], + ref.uns["cluster_empirical"], None if ref.uns["eps"] == "None" else ref.uns["eps"], - force_components=False + force_components=False, ) blast.latent = blast.ref.uns["latent"] if "latent" in blast.ref.uns else None blast.cluster = blast.ref.uns["cluster"] if "latent" in blast.ref.uns else None - blast.empirical = [ - blast.ref.uns["empirical"][str(i)] for i in range(len(blast)) - ] if "empirical" in blast.ref.uns else None + blast.empirical = ( + [blast.ref.uns["empirical"][str(i)] for i in range(len(blast))] + if "empirical" in blast.ref.uns + else None + ) if "posterior" in blast.ref.uns: for i in range(ref.shape[0]): if str(i) in blast.ref.uns["posterior"]: @@ -737,10 +818,12 @@ def load(cls, path: str, mode: int = NORMAL, **kwargs): return blast def query( - self, query: anndata.AnnData, n_neighbors: int = 5, - store_dataset: bool = False, - n_jobs: int = config._USE_GLOBAL, - random_seed: int = config._USE_GLOBAL + self, + query: anndata.AnnData, + n_neighbors: int = 5, + store_dataset: bool = False, + n_jobs: int = config._USE_GLOBAL, + random_seed: int = config._USE_GLOBAL, ) -> "Hits": r""" BLAST query @@ -770,24 +853,30 @@ def query( Query hits """ n_jobs = config.N_JOBS if n_jobs == config._USE_GLOBAL else n_jobs - random_seed = config.RANDOM_SEED if random_seed == config._USE_GLOBAL else random_seed + random_seed = ( + config.RANDOM_SEED if random_seed == config._USE_GLOBAL else random_seed + ) utils.logger.info("Projecting to latent space...") query_latent = joblib.Parallel( n_jobs=min(n_jobs, len(self)), backend="threading" - )(joblib.delayed(model.inference)( - query - ) for model in self.models) # n_models * [n_cells * latent_dim] + )( + joblib.delayed(model.inference)(query) for model in self.models + ) # n_models * [n_cells * latent_dim] utils.logger.info("Doing nearest neighbor search...") nearest_neighbors = self._get_nearest_neighbors(n_jobs) - nni = np.stack(joblib.Parallel( - n_jobs=min(n_jobs, len(self)), backend="threading" - )(joblib.delayed(self._nearest_neighbor_search)( - _nearest_neighbor, _query_latent, n_neighbors - ) for _nearest_neighbor, _query_latent in zip( - nearest_neighbors, query_latent - )), axis=2) # n_cells * n_neighbors * n_models + nni = np.stack( + joblib.Parallel(n_jobs=min(n_jobs, len(self)), backend="threading")( + joblib.delayed(self._nearest_neighbor_search)( + _nearest_neighbor, _query_latent, n_neighbors + ) + for _nearest_neighbor, _query_latent in zip( + nearest_neighbors, query_latent + ) + ), + axis=2, + ) # n_cells * n_neighbors * n_models 
utils.logger.info("Merging hits across models...") hits = joblib.Parallel(n_jobs=n_jobs, backend="threading")( @@ -803,51 +892,66 @@ def query( dist = joblib.Parallel(n_jobs=n_jobs, backend="threading")( joblib.delayed(_hit_ed_across_models)( query_latent[i], ref_latent[hits[i]] - ) for i in range(len(hits)) + ) + for i in range(len(hits)) ) # list of n_hits * n_models else: utils.logger.info("Computing posterior distribution distances...") - query_posterior = np.stack(joblib.Parallel( - n_jobs=min(n_jobs, len(self)), backend="loky" - )(joblib.delayed(model.inference)( - query, n_posterior=self.n_posterior, random_seed=random_seed - ) for model in self.models), axis=1) # n_cells * n_models * n_posterior_samples * latent_dim - ref_posterior = np.stack(self._get_posterior( - n_jobs, random_seed, idx=hitsu - )) # n_cells * n_models * n_posterior_samples * latent_dim + query_posterior = np.stack( + joblib.Parallel(n_jobs=min(n_jobs, len(self)), backend="loky")( + joblib.delayed(model.inference)( + query, n_posterior=self.n_posterior, random_seed=random_seed + ) + for model in self.models + ), + axis=1, + ) # n_cells * n_models * n_posterior_samples * latent_dim + ref_posterior = np.stack( + self._get_posterior(n_jobs, random_seed, idx=hitsu) + ) # n_cells * n_models * n_posterior_samples * latent_dim distance_metric = DISTANCE_METRIC_ACROSS_MODELS[self.distance_metric] dist_kws = {"eps": self.eps} if self.eps is not None else {} dist = joblib.Parallel(n_jobs=n_jobs, backend="threading")( joblib.delayed(distance_metric)( - query_latent[i], ref_latent[hits[i]], - query_posterior[i], ref_posterior[hitsi[i]], **dist_kws - ) for i in range(len(hits)) + query_latent[i], + ref_latent[hits[i]], + query_posterior[i], + ref_posterior[hitsi[i]], + **dist_kws, + ) + for i in range(len(hits)) ) # list of n_hits * n_models utils.logger.info("Computing empirical p-values...") empirical = self._get_empirical(n_jobs, random_seed) cluster = self._get_cluster(n_jobs) - pval = joblib.Parallel( - n_jobs=n_jobs, backend="threading" - )(joblib.delayed(self._empirical_pvalue)( - _hits, _dist, cluster, empirical - ) for _hits, _dist in zip(hits, dist)) # list of n_hits * n_models + pval = joblib.Parallel(n_jobs=n_jobs, backend="threading")( + joblib.delayed(self._empirical_pvalue)(_hits, _dist, cluster, empirical) + for _hits, _dist in zip(hits, dist) + ) # list of n_hits * n_models return Hits( - self, hits, dist, pval, - query if store_dataset else anndata.AnnData( + self, + hits, + dist, + pval, + query + if store_dataset + else anndata.AnnData( X=scipy.sparse.csr_matrix((query.shape[0], 0)), obs=pd.DataFrame(index=query.obs.index), - var=pd.DataFrame(), uns={} - ) + var=pd.DataFrame(), + uns={}, + ), ) def align( - self, query: typing.Union[ - anndata.AnnData, typing.Mapping[str, anndata.AnnData] - ], n_jobs: int = config._USE_GLOBAL, - random_seed: int = config._USE_GLOBAL, - path: typing.Optional[str] = None, **kwargs + self, + query: typing.Union[anndata.AnnData, typing.Mapping[str, anndata.AnnData]], + n_jobs: int = config._USE_GLOBAL, + random_seed: int = config._USE_GLOBAL, + path: typing.Optional[str] = None, + **kwargs, ) -> "BLAST": r""" Align internal DIRECTi models with query datasets (fine tuning). @@ -876,29 +980,39 @@ def align( blast A new BLAST object with aligned internal models. 
""" - if any(model._mode == directi._TEST for model in self.models): # pragma: no cover + if any( + model._mode == directi._TEST for model in self.models + ): # pragma: no cover raise Exception("Align not available!") n_jobs = config.N_JOBS if n_jobs == config._USE_GLOBAL else n_jobs - random_seed = config.RANDOM_SEED if random_seed == config._USE_GLOBAL else random_seed + random_seed = ( + config.RANDOM_SEED if random_seed == config._USE_GLOBAL else random_seed + ) path = path or tempfile.mkdtemp() - aligned_models = joblib.Parallel( - n_jobs=n_jobs, backend="threading" - )( + aligned_models = joblib.Parallel(n_jobs=n_jobs, backend="threading")( joblib.delayed(directi.align_DIRECTi)( - self.models[i], self.ref, query, random_seed=random_seed, - path=os.path.join(path, f"aligned_model_{i}"), **kwargs - ) for i in range(len(self)) + self.models[i], + self.ref, + query, + random_seed=random_seed, + path=os.path.join(path, f"aligned_model_{i}"), + **kwargs, + ) + for i in range(len(self)) ) return BLAST( - aligned_models, self.ref, distance_metric=self.distance_metric, - n_posterior=self.n_posterior, n_empirical=self.n_empirical, - cluster_empirical=self.cluster_empirical, eps=self.eps + aligned_models, + self.ref, + distance_metric=self.distance_metric, + n_posterior=self.n_posterior, + n_empirical=self.n_empirical, + cluster_empirical=self.cluster_empirical, + eps=self.eps, ) class Hits(object): - r""" BLAST hits @@ -929,33 +1043,39 @@ class Hits(object): FILTER_BY_PVAL = 1 def __init__( - self, blast: BLAST, - hits: typing.List[np.ndarray], - dist: typing.List[np.ndarray], - pval: typing.List[np.ndarray], - query: anndata.AnnData + self, + blast: BLAST, + hits: typing.List[np.ndarray], + dist: typing.List[np.ndarray], + pval: typing.List[np.ndarray], + query: anndata.AnnData, ) -> None: self.blast = blast self.hits = np.asarray(hits, dtype=object) self.dist = np.asarray(dist, dtype=object) self.pval = np.asarray(pval, dtype=object) self.query = query - if not self.hits.shape[0] == self.dist.shape[0] == \ - self.pval.shape[0] == self.query.shape[0]: + if ( + not self.hits.shape[0] + == self.dist.shape[0] + == self.pval.shape[0] + == self.query.shape[0] + ): raise ValueError("Inconsistent shape!") def __len__(self) -> int: return self.query.shape[0] def __iter__(self): - for idx, (_hits, _dist, _pval) in enumerate(zip(self.hits, self.dist, self.pval)): + for idx, (_hits, _dist, _pval) in enumerate( + zip(self.hits, self.dist, self.pval) + ): yield Hits(self.blast, [_hits], [_dist], [_pval], self.query[idx, :]) def __getitem__(self, s): s = [s] if isinstance(s, (int, np.integer)) else s return Hits( - self.blast, self.hits[s], self.dist[s], self.pval[s], - self.query[s, :] + self.blast, self.hits[s], self.dist[s], self.pval[s], self.query[s, :] ) def to_data_frames(self) -> typing.Mapping[str, pd.DataFrame]: @@ -979,7 +1099,7 @@ def to_data_frames(self) -> typing.Mapping[str, pd.DataFrame]: return df_dict def reconcile_models( - self, dist_method: str = "mean", pval_method: str = "gmean" + self, dist_method: str = "mean", pval_method: str = "gmean" ) -> "Hits": r""" Integrate model-specific distances and empirical p-values. 
@@ -1007,8 +1127,12 @@ def reconcile_models( @staticmethod @numba.jit(nopython=True, nogil=True, cache=True) def _filter_hits( - hits: np.ndarray, dist: np.ndarray, pval: np.ndarray, - by: int, cutoff: float, model_tolerance: int + hits: np.ndarray, + dist: np.ndarray, + pval: np.ndarray, + by: int, + cutoff: float, + model_tolerance: int, ) -> typing.Tuple[np.ndarray, np.ndarray, np.ndarray]: # pragma: no cover r""" hits : n_hits @@ -1024,8 +1148,7 @@ def _filter_hits( @staticmethod @numba.jit(nopython=True, nogil=True, cache=True) def _filter_reconciled_hits( - hits: np.ndarray, dist: np.ndarray, pval: np.ndarray, - by: int, cutoff: float + hits: np.ndarray, dist: np.ndarray, pval: np.ndarray, by: int, cutoff: float ) -> typing.Tuple[np.ndarray, np.ndarray, np.ndarray]: # pragma: no cover r""" hits : n_hits @@ -1039,8 +1162,11 @@ def _filter_reconciled_hits( return hits[hit_mask], dist[hit_mask], pval[hit_mask] def filter( - self, by: str = "pval", cutoff: float = 0.05, - model_tolerance: int = 0, n_jobs: int = 1 + self, + by: str = "pval", + cutoff: float = 0.05, + model_tolerance: int = 0, + n_jobs: int = 1, ) -> "Hits": r""" Filter hits by posterior distance or p-value @@ -1070,26 +1196,37 @@ def filter( else: # by == "dist" by = Hits.FILTER_BY_DIST if self.dist[0].ndim == 1: - hits, dist, pval = [_ for _ in zip(*joblib.Parallel( - n_jobs=n_jobs, backend="threading" - )(joblib.delayed(self._filter_reconciled_hits)( - _hits, _dist, _pval, by, cutoff - ) for _hits, _dist, _pval in zip( - self.hits, self.dist, self.pval - )))] + hits, dist, pval = [ + _ + for _ in zip( + *joblib.Parallel(n_jobs=n_jobs, backend="threading")( + joblib.delayed(self._filter_reconciled_hits)( + _hits, _dist, _pval, by, cutoff + ) + for _hits, _dist, _pval in zip(self.hits, self.dist, self.pval) + ) + ) + ] else: - hits, dist, pval = [_ for _ in zip(*joblib.Parallel( - n_jobs=n_jobs, backend="threading" - )(joblib.delayed(self._filter_hits)( - _hits, _dist, _pval, by, cutoff, model_tolerance - ) for _hits, _dist, _pval in zip( - self.hits, self.dist, self.pval - )))] + hits, dist, pval = [ + _ + for _ in zip( + *joblib.Parallel(n_jobs=n_jobs, backend="threading")( + joblib.delayed(self._filter_hits)( + _hits, _dist, _pval, by, cutoff, model_tolerance + ) + for _hits, _dist, _pval in zip(self.hits, self.dist, self.pval) + ) + ) + ] return Hits(self.blast, hits, dist, pval, self.query) def annotate( - self, field: str, min_hits: int = 2, - majority_threshold: float = 0.5, return_evidence: bool = False + self, + field: str, + min_hits: int = 2, + majority_threshold: float = 0.5, + return_evidence: bool = False, ) -> pd.DataFrame: r""" Annotate query cells based on existing annotations of hit cells @@ -1122,8 +1259,9 @@ def annotate( """ ref = self.blast.ref.obs[field].to_numpy().ravel() n_hits = np.repeat(0, len(self.hits)) - if np.issubdtype(ref.dtype.type, np.character) or \ - np.issubdtype(ref.dtype.type, np.object_): + if np.issubdtype(ref.dtype.type, np.character) or np.issubdtype( + ref.dtype.type, np.object_ + ): prediction = np.repeat("rejected", len(self.hits)).astype(object) majority_frac = np.repeat(np.nan, len(self.hits)) for i, _hits in enumerate(self.hits): @@ -1162,9 +1300,12 @@ def annotate( return pd.DataFrame(result, index=self.query.obs_names) def blast2co( - self, cl_dag: utils.CellTypeDAG, - cl_field: str = "cell_ontology_class", - min_hits: int = 2, thresh: float = 0.5, min_path: int = 4 + self, + cl_dag: utils.CellTypeDAG, + cl_field: str = "cell_ontology_class", + min_hits: int = 2, + 
thresh: float = 0.5, + min_path: int = 4, ) -> pd.DataFrame: r""" Annotate query cells based on existing annotations of hit cells @@ -1208,7 +1349,8 @@ def blast2co( cl_dag.value_set(cl, np.sum(1 - _pval[hits == cl]) / np.sum(1 - _pval)) cl_dag.value_update() leaves = cl_dag.best_leaves( - thresh=thresh, min_path=min_path, retrieve=cl_field) + thresh=thresh, min_path=min_path, retrieve=cl_field + ) if len(leaves) == 1: prediction[i] = leaves[0] elif len(leaves) > 1: @@ -1228,8 +1370,11 @@ def _get_reconcile_method(method: str): raise ValueError("Unknown method!") # pragma: no cover def gene_gradient( - self, eval_point: str = "query", normalize_deviation: bool = True, - avg_models: bool = True, n_jobs: int = config._USE_GLOBAL + self, + eval_point: str = "query", + normalize_deviation: bool = True, + avg_models: bool = True, + n_jobs: int = config._USE_GLOBAL, ) -> typing.List[np.ndarray]: r""" Compute gene-wise gradient for each pair of query-hit cells @@ -1264,50 +1409,49 @@ def gene_gradient( n_jobs = config.N_JOBS if n_jobs == config._USE_GLOBAL else n_jobs if self.query.shape[1] == 0: raise RuntimeError( - "No query data available! Please set \"store_dataset\" to True " + 'No query data available! Please set "store_dataset" to True ' "when calling BLAST.query()" ) ref_idx = np.concatenate(self.hits) - query_idx = np.concatenate([ - idx * np.ones_like(_hits) - for idx, _hits in enumerate(self.hits) - ]) + query_idx = np.concatenate( + [idx * np.ones_like(_hits) for idx, _hits in enumerate(self.hits)] + ) ref = self.blast.ref[ref_idx, :] query = self.query[query_idx, :] query_latent = joblib.Parallel( n_jobs=min(n_jobs, len(self)), backend="threading" - )(joblib.delayed(model.inference)( - self.query - ) for model in self.blast.models) - query_latent = np.stack(query_latent)[:, query_idx, :] # n_models * sum(n_hits) * latent_dim + )(joblib.delayed(model.inference)(self.query) for model in self.blast.models) + query_latent = np.stack(query_latent)[ + :, query_idx, : + ] # n_models * sum(n_hits) * latent_dim ref_latent = self.blast._get_latent(n_jobs) # n_cells * n_models * latent_dim - ref_latent = ref_latent[ref_idx].swapaxes(0, 1) # n_models * sum(n_hits) * latent_dim + ref_latent = ref_latent[ref_idx].swapaxes( + 0, 1 + ) # n_models * sum(n_hits) * latent_dim deviation = query_latent - ref_latent # n_models * sum(n_hits) * latent_dim if normalize_deviation: deviation /= np.linalg.norm(deviation, axis=2, keepdims=True) if eval_point in ("ref", "both"): - gene_dev_ref = joblib.Parallel( - n_jobs=n_jobs, backend="threading" - )(joblib.delayed(model.gene_grad)( - ref, latent_grad=_deviation - ) for model, _deviation in zip( - self.blast.models, deviation - )) # n_models * [sum(n_hits) * n_genes] - gene_dev_ref = np.stack(gene_dev_ref, axis=1) # sum(n_hits) * n_models * n_genes + gene_dev_ref = joblib.Parallel(n_jobs=n_jobs, backend="threading")( + joblib.delayed(model.gene_grad)(ref, latent_grad=_deviation) + for model, _deviation in zip(self.blast.models, deviation) + ) # n_models * [sum(n_hits) * n_genes] + gene_dev_ref = np.stack( + gene_dev_ref, axis=1 + ) # sum(n_hits) * n_models * n_genes if eval_point in ("query", "both"): - gene_dev_query = joblib.Parallel( - n_jobs=n_jobs, backend="threading" - )(joblib.delayed(model.gene_grad)( - query, latent_grad=_deviation - ) for model, _deviation in zip( - self.blast.models, deviation - )) # n_models * [sum(n_hits) * n_genes] - gene_dev_query = np.stack(gene_dev_query, axis=1) # sum(n_hits) * n_models * n_genes + gene_dev_query = 
joblib.Parallel(n_jobs=n_jobs, backend="threading")( + joblib.delayed(model.gene_grad)(query, latent_grad=_deviation) + for model, _deviation in zip(self.blast.models, deviation) + ) # n_models * [sum(n_hits) * n_genes] + gene_dev_query = np.stack( + gene_dev_query, axis=1 + ) # sum(n_hits) * n_models * n_genes if eval_point == "ref": gene_dev = gene_dev_ref @@ -1320,15 +1464,22 @@ def gene_gradient( gene_dev = np.mean(gene_dev, axis=1) split_idx = np.cumsum([_hits.size for _hits in self.hits])[:-1] - gene_dev = np.split(gene_dev, split_idx) # n_queries * [n_hits * * n_genes] + gene_dev = np.split( + gene_dev, split_idx + ) # n_queries * [n_hits * * n_genes] return gene_dev def sankey( - query: np.ndarray, ref: np.ndarray, title: str = "Sankey", - width: int = 500, height: int = 500, tint_cutoff: int = 1, - font: str = "Arial", font_size: float = 10.0, - suppress_plot: bool = False + query: np.ndarray, + ref: np.ndarray, + title: str = "Sankey", + width: int = 500, + height: int = 500, + tint_cutoff: int = 1, + font: str = "Arial", + font_size: float = 10.0, + suppress_plot: bool = False, ) -> dict: # pragma: no cover r""" Make a sankey diagram of query-reference mapping (only works in @@ -1373,40 +1524,27 @@ def sankey( node=dict( pad=15, thickness=20, - line=dict( - color="black", - width=0.5 - ), - label=np.concatenate([ - query_c, ref_c - ], axis=0), - color=["#E64B35"] * len(query_c) + - ["#4EBBD5"] * len(ref_c) + line=dict(color="black", width=0.5), + label=np.concatenate([query_c, ref_c], axis=0), + color=["#E64B35"] * len(query_c) + ["#4EBBD5"] * len(ref_c), ), link=dict( source=query_i.tolist(), - target=( - ref_i + len(query_c) - ).tolist(), + target=(ref_i + len(query_c)).tolist(), value=cf["count"].tolist(), - color=np.vectorize( - lambda x: "#F0F0F0" if x <= tint_cutoff else "#CCCCCC" - )(cf["count"]) - ) + color=np.vectorize(lambda x: "#F0F0F0" if x <= tint_cutoff else "#CCCCCC")( + cf["count"] + ), + ), ) sankey_layout = dict( - title=title, - width=width, - height=height, - font=dict( - family=font, - size=font_size - ) + title=title, width=width, height=height, font=dict(family=font, size=font_size) ) fig = dict(data=[sankey_data], layout=sankey_layout) if not suppress_plot: import plotly.offline + plotly.offline.init_notebook_mode() plotly.offline.iplot(fig, validate=False) - return fig \ No newline at end of file + return fig diff --git a/Cell_BLAST/config.py b/Cell_BLAST/config.py index 8fa1ff9..0e961f2 100644 --- a/Cell_BLAST/config.py +++ b/Cell_BLAST/config.py @@ -6,7 +6,6 @@ from .utils import autodevice - RANDOM_SEED = 0 N_JOBS = 1 DEVICE = autodevice() @@ -20,9 +19,7 @@ "compression": "gzip", "compression_opts": 7, } -H5_TRACK_OPTS = { - "track_times": False -} +H5_TRACK_OPTS = {"track_times": False} SUPERVISION = None RESOLUTION = 10.0 diff --git a/Cell_BLAST/data.py b/Cell_BLAST/data.py index 65ddde4..5845c30 100644 --- a/Cell_BLAST/data.py +++ b/Cell_BLAST/data.py @@ -2,24 +2,24 @@ Dataset utilities """ +import os import typing import warnings +from collections import OrderedDict +import anndata as ad +import h5py +import matplotlib.axes +import matplotlib.pyplot as plt import numpy as np import pandas as pd import scipy.sparse import scipy.stats -import sklearn.metrics -import matplotlib.axes -import matplotlib.pyplot as plt import seaborn as sns -import anndata as ad -import h5py +import sklearn.metrics import torch -import os -from collections import OrderedDict -from . import utils, config +from . 
import config, utils def compute_libsize(adata: ad.AnnData) -> None: @@ -51,22 +51,24 @@ def normalize(adata: ad.AnnData, target: float = 10000.0) -> None: if "__libsize__" not in adata.obs.columns: compute_libsize(adata) normalizer = target / np.expand_dims(adata.obs["__libsize__"].to_numpy(), axis=1) - adata.X = adata.X.multiply(normalizer).tocsr() \ - if scipy.sparse.issparse(adata.X) \ + adata.X = ( + adata.X.multiply(normalizer).tocsr() + if scipy.sparse.issparse(adata.X) else adata.X * normalizer + ) def find_variable_genes( - adata: ad.AnnData, - slot: str = "variable_genes", - x_low_cutoff: float = 0.1, - x_high_cutoff: float = 8.0, - y_low_cutoff: float = 1.0, - y_high_cutoff: float = np.inf, - num_bin: int = 20, - binning_method: str = "equal_frequency", - grouping: typing.Optional[str] = None, - min_group_frac: float = 0.5 + adata: ad.AnnData, + slot: str = "variable_genes", + x_low_cutoff: float = 0.1, + x_high_cutoff: float = 8.0, + y_low_cutoff: float = 1.0, + y_high_cutoff: float = np.inf, + num_bin: int = 20, + binning_method: str = "equal_frequency", + grouping: typing.Optional[str] = None, + min_group_frac: float = 0.5, ) -> typing.Union[matplotlib.axes.Axes, typing.Mapping[str, matplotlib.axes.Axes]]: r""" A reimplementation of the Seurat v2 "mean.var.plot" gene selection @@ -113,10 +115,14 @@ def find_variable_genes( for group in groups: tmp_adata = adata[adata.obs[grouping] == group, :].copy() ax_dict[group] = find_variable_genes( - tmp_adata, slot=slot, - x_low_cutoff=x_low_cutoff, x_high_cutoff=x_high_cutoff, - y_low_cutoff=y_low_cutoff, y_high_cutoff=y_high_cutoff, - num_bin=num_bin, binning_method=binning_method + tmp_adata, + slot=slot, + x_low_cutoff=x_low_cutoff, + x_high_cutoff=x_high_cutoff, + y_low_cutoff=y_low_cutoff, + y_high_cutoff=y_high_cutoff, + num_bin=num_bin, + binning_method=binning_method, ) selected_list.append(tmp_adata.var[slot].to_numpy().ravel()) selected_count = np.stack(selected_list, axis=1).sum(axis=1) @@ -128,10 +134,9 @@ def find_variable_genes( normalize(adata) X = adata.X mean = np.asarray(np.mean(X, axis=0)).ravel() - var = np.asarray(np.mean( - X.power(2) if scipy.sparse.issparse(X) else np.square(X), - axis=0 - )).ravel() - np.square(mean) + var = np.asarray( + np.mean(X.power(2) if scipy.sparse.issparse(X) else np.square(X), axis=0) + ).ravel() - np.square(mean) log_mean = np.log1p(mean) with warnings.catch_warnings(): warnings.simplefilter("ignore", category=RuntimeWarning) @@ -141,20 +146,23 @@ def find_variable_genes( log_mean_bin = pd.cut(log_mean, num_bin) elif binning_method == "equal_frequency": log_mean_bin = pd.cut( - log_mean, [-1] + np.percentile( + log_mean, + [-1] + + np.percentile( log_mean[log_mean > 0], np.linspace(0, 100, num_bin) - ).tolist() + ).tolist(), ) else: raise ValueError("Invalid binning method!") - summary_df = pd.DataFrame({ - "log_mean": log_mean, - "log_vmr": log_vmr, - "log_mean_bin": log_mean_bin - }, index=adata.var_names) - summary_df["log_vmr_scaled"] = summary_df.loc[ - :, ["log_vmr", "log_mean_bin"] - ].groupby("log_mean_bin").transform(lambda x: (x - x.mean()) / x.std()) + summary_df = pd.DataFrame( + {"log_mean": log_mean, "log_vmr": log_vmr, "log_mean_bin": log_mean_bin}, + index=adata.var_names, + ) + summary_df["log_vmr_scaled"] = ( + summary_df.loc[:, ["log_vmr", "log_mean_bin"]] + .groupby("log_mean_bin") + .transform(lambda x: (x - x.mean()) / x.std()) + ) summary_df["log_vmr_scaled"].fillna(0, inplace=True) selected = summary_df.query( f"log_mean > {x_low_cutoff} & log_mean < 
{x_high_cutoff} & " @@ -167,13 +175,22 @@ def find_variable_genes( _, ax = plt.subplots(figsize=(7, 7)) ax = sns.scatterplot( - x="log_mean", y="log_vmr_scaled", hue="selected", - data=summary_df, edgecolor=None, s=5, ax=ax + x="log_mean", + y="log_vmr_scaled", + hue="selected", + data=summary_df, + edgecolor=None, + s=5, + ax=ax, ) for _, row in selected.iterrows(): ax.text( - row["log_mean"], row["log_vmr_scaled"], row.name, - size="x-small", ha="center", va="center" + row["log_mean"], + row["log_vmr_scaled"], + row.name, + size="x-small", + ha="center", + va="center", ) ax.set_xlabel("Average expression") ax.set_ylabel("Dispersion") @@ -181,14 +198,18 @@ def find_variable_genes( def _expanded_subset( - mat: typing.Union[scipy.sparse.spmatrix, np.ndarray], idx: np.ndarray, - axis: int = 0, fill: typing.Any = 0 + mat: typing.Union[scipy.sparse.spmatrix, np.ndarray], + idx: np.ndarray, + axis: int = 0, + fill: typing.Any = 0, ) -> typing.Union[scipy.sparse.spmatrix, np.ndarray]: assert axis in (0, 1) expand_size = max(idx.max() - mat.shape[axis] + 1, 0) if axis == 0: if scipy.sparse.issparse(mat): - expand_mat = scipy.sparse.lil_matrix((expand_size, mat.shape[1]), dtype=mat.dtype) + expand_mat = scipy.sparse.lil_matrix( + (expand_size, mat.shape[1]), dtype=mat.dtype + ) if fill != 0: expand_mat[:] = fill expand_mat = scipy.sparse.vstack([mat.tocsr(), expand_mat.tocsr()]) @@ -199,7 +220,9 @@ def _expanded_subset( result_mat = expand_mat[idx, :] else: if scipy.sparse.issparse(mat): - expand_mat = scipy.sparse.lil_matrix((mat.shape[0], expand_size), dtype=mat.dtype) + expand_mat = scipy.sparse.lil_matrix( + (mat.shape[0], expand_size), dtype=mat.dtype + ) if fill != 0: expand_mat[:] = fill expand_mat = scipy.sparse.hstack([mat.tocsc(), expand_mat.tocsc()]) @@ -239,12 +262,11 @@ def select_vars(adata: ad.AnnData, var_names: typing.List[str]) -> ad.AnnData: if new_var_names.size > 0: # pragma: no cover utils.logger.warning( "%d out of %d variables are not found, will be set to zero!", - len(new_var_names), len(var_names) + len(new_var_names), + len(var_names), ) utils.logger.info(str(new_var_names.tolist()).strip("[]")) - idx = np.vectorize( - lambda x: np.where(all_var_names == x)[0][0] - )(var_names) + idx = np.vectorize(lambda x: np.where(all_var_names == x)[0][0])(var_names) new_X = _expanded_subset(adata.X, idx, axis=1, fill=0) new_var = adata.var.reindex(var_names) @@ -254,8 +276,9 @@ def select_vars(adata: ad.AnnData, var_names: typing.List[str]) -> ad.AnnData: def map_vars( - adata: ad.AnnData, mapping: pd.DataFrame, - map_hvg: typing.Optional[typing.List[str]] = None + adata: ad.AnnData, + mapping: pd.DataFrame, + map_hvg: typing.Optional[typing.List[str]] = None, ) -> ad.AnnData: r""" Map variables of input dataset to some other terms, @@ -292,7 +315,7 @@ def map_vars( target_idx = [target_idx_map[val] for val in mapping.iloc[:, 1]] mapping = scipy.sparse.csc_matrix( (np.repeat(1, mapping.shape[0]), (source_idx, target_idx)), - shape=(source.size, target.size) + shape=(source.size, target.size), ) # Sanity check @@ -324,9 +347,11 @@ def map_vars( def annotation_confidence( - adata: ad.AnnData, annotation: typing.Union[str, typing.List[str]], - used_vars: typing.Optional[typing.List[str]] = None, - metric: str = "cosine", return_group_percentile: bool = True + adata: ad.AnnData, + annotation: typing.Union[str, typing.List[str]], + used_vars: typing.Optional[typing.List[str]] = None, + metric: str = "cosine", + return_group_percentile: bool = True, ) -> typing.Tuple[np.ndarray, 
np.ndarray]: r""" Compute annotation confidence of each obs (cell) based on @@ -380,13 +405,15 @@ def annotation_confidence( return confidence -def write_table(adata: ad.AnnData, filename: str, orientation: str = "cg", **kwargs) -> None: +def write_table( + adata: ad.AnnData, filename: str, orientation: str = "cg", **kwargs +) -> None: r""" Write the expression matrix to a plain-text file. Note that ``obs`` (cell) meta table, ``var`` (gene) meta table and data in the ``uns`` slot are discarded, only the expression matrix is written to the file. - + Parameters ---------- adata @@ -404,27 +431,23 @@ def write_table(adata: ad.AnnData, filename: str, orientation: str = "cg", **kwa os.makedirs(os.path.dirname(filename)) if orientation == "cg": df = pd.DataFrame( - utils.densify(adata.X), - index=adata.obs_names, - columns=adata.var_names + utils.densify(adata.X), index=adata.obs_names, columns=adata.var_names ) elif orientation == "gc": df = pd.DataFrame( - utils.densify(adata.X.T), - index=adata.var_names, - columns=adata.obs_names + utils.densify(adata.X.T), index=adata.var_names, columns=adata.obs_names ) else: # pragma: no cover raise ValueError("Invalid orientation!") df.to_csv(filename, **kwargs) + def read_table( - filename: str, orientation: str = "cg", - sparsify: bool = False, **kwargs + filename: str, orientation: str = "cg", sparsify: bool = False, **kwargs ) -> ad.AnnData: r""" Read expression matrix from a plain-text file - + Parameters ---------- filename @@ -448,9 +471,8 @@ def read_table( scipy.sparse.csr_matrix(df.values) if sparsify else df.values, pd.DataFrame(index=df.index), pd.DataFrame(index=df.columns), - {} + {}, ) - class Dataset(torch.utils.data.Dataset): @@ -459,21 +481,26 @@ def __init__(self, data_dict: typing.OrderedDict) -> None: self.data_dict = data_dict for key, value in self.data_dict.items(): self.data_dict[key] = torch.tensor(utils.densify(value)).float() - + def __getitem__(self, idx): - if isinstance(idx, (slice, np.ndarray)): - return OrderedDict([ - #(item, torch.tensor(utils.densify(self.data_dict[item][idx])).float()) for item in self.data_dict - (item, self.data_dict[item][idx]) for item in self.data_dict - ]) + return OrderedDict( + [ + # (item, torch.tensor(utils.densify(self.data_dict[item][idx])).float()) for item in self.data_dict + (item, self.data_dict[item][idx]) + for item in self.data_dict + ] + ) elif isinstance(idx, int): - return OrderedDict([ - #(item, torch.tensor(utils.densify(self.data_dict[item][idx])).squeeze().float()) for item in self.data_dict - (item, self.data_dict[item][idx]) for item in self.data_dict - ]) + return OrderedDict( + [ + # (item, torch.tensor(utils.densify(self.data_dict[item][idx])).squeeze().float()) for item in self.data_dict + (item, self.data_dict[item][idx]) + for item in self.data_dict + ] + ) return self.data_dict[idx] - + def __len__(self): data_size = set([item.shape[0] for item in self.data_dict.values()]) if data_size: @@ -481,29 +508,30 @@ def __len__(self): return data_size.pop() return 0 -def h5_to_h5ad(inputfile: str, outputfile: str): +def h5_to_h5ad(inputfile: str, outputfile: str): with h5py.File(inputfile, "r") as f: - obs = pd.DataFrame( - dict_from_group(f["obs"]), - index=utils.decode(f["obs_names"][...]) - ) - var = pd.DataFrame( - dict_from_group(f["var"]), - index=utils.decode(f["var_names"][...]) + obs = pd.DataFrame( + dict_from_group(f["obs"]), index=utils.decode(f["obs_names"][...]) + ) + var = pd.DataFrame( + dict_from_group(f["var"]), index=utils.decode(f["var_names"][...]) + ) + 
uns = dict_from_group(f["uns"]) + + exprs_handle = f["exprs"] + if isinstance(exprs_handle, h5py.Group): # Sparse matrix + mat = scipy.sparse.csr_matrix( + ( + exprs_handle["data"][...], + exprs_handle["indices"][...], + exprs_handle["indptr"][...], + ), + shape=exprs_handle["shape"][...], ) - uns = dict_from_group(f["uns"]) - - exprs_handle = f["exprs"] - if isinstance(exprs_handle, h5py.Group): # Sparse matrix - mat = scipy.sparse.csr_matrix(( - exprs_handle['data'][...], - exprs_handle['indices'][...], - exprs_handle['indptr'][...] - ), shape=exprs_handle['shape'][...]) - else: # Dense matrix - mat = exprs_handle[...].astype(np.float32) - + else: # Dense matrix + mat = exprs_handle[...].astype(np.float32) + adata = ad.AnnData(X=mat, obs=obs, var=var, uns=dict(uns)) adata.write(outputfile) @@ -520,6 +548,7 @@ def read_clean(data): data = data.flat[0] return data + def dict_from_group(group): assert isinstance(group, h5py.Group) d = {} @@ -529,4 +558,4 @@ def dict_from_group(group): else: value = read_clean(group[key][...]) d[key] = value - return d \ No newline at end of file + return d diff --git a/Cell_BLAST/directi.py b/Cell_BLAST/directi.py index 40be2df..a9a0aa8 100644 --- a/Cell_BLAST/directi.py +++ b/Cell_BLAST/directi.py @@ -4,28 +4,28 @@ """ import os -import typing import tempfile +import time +import typing +from collections import OrderedDict +import anndata as ad import numpy as np import pandas as pd +import scipy import torch -from torch import nn import torch.distributions as D -import time -import anndata as ad -import scipy -from collections import OrderedDict +from torch import nn from torch.utils.tensorboard import SummaryWriter -from . import config, data, utils, latent, prob, rmbatch +from . import config, data, latent, prob, rmbatch, utils from .rebuild import RMSprop _TRAIN = 1 _TEST = 0 -class DIRECTi(nn.Module): +class DIRECTi(nn.Module): r""" DIRECTi model. 
@@ -72,27 +72,28 @@ class DIRECTi(nn.Module): _TEST = 0 def __init__( - self, genes: typing.List[str], - latent_module: "latent.Latent", - prob_module: "prob.ProbModel", - rmbatch_modules: typing.Tuple["rmbatch.RMBatch"], - denoising: bool = True, - learning_rate: float = 1e-3, - path: typing.Optional[str] = None, - random_seed: int = config._USE_GLOBAL, - _mode: int = _TRAIN + self, + genes: typing.List[str], + latent_module: "latent.Latent", + prob_module: "prob.ProbModel", + rmbatch_modules: typing.Tuple["rmbatch.RMBatch"], + denoising: bool = True, + learning_rate: float = 1e-3, + path: typing.Optional[str] = None, + random_seed: int = config._USE_GLOBAL, + _mode: int = _TRAIN, ) -> None: - super().__init__() if path is None: path = tempfile.mkdtemp() else: - os.makedirs(path, exist_ok = True) + os.makedirs(path, exist_ok=True) utils.logger.info("Using model path: %s", path) - random_seed = config.RANDOM_SEED \ - if random_seed == config._USE_GLOBAL else random_seed + random_seed = ( + config.RANDOM_SEED if random_seed == config._USE_GLOBAL else random_seed + ) self.ensure_reproducibility(random_seed) self.genes = genes @@ -105,13 +106,19 @@ def __init__( self.random_seed = random_seed self._mode = _mode - self.opt_latent_reg = RMSprop(self.latent_module.parameters_reg(), lr = learning_rate) - self.opt_latent_fit = RMSprop(self.latent_module.parameters_fit(), lr = learning_rate) - self.opt_prob = RMSprop(self.prob_module.parameters(), lr = learning_rate) + self.opt_latent_reg = RMSprop( + self.latent_module.parameters_reg(), lr=learning_rate + ) + self.opt_latent_fit = RMSprop( + self.latent_module.parameters_fit(), lr=learning_rate + ) + self.opt_prob = RMSprop(self.prob_module.parameters(), lr=learning_rate) self.opts_rmbatch = [ - RMSprop(_rmbatch.parameters(), lr = learning_rate) if _rmbatch._class in ( - "Adversarial", "MNNAdversarial", "AdaptiveMNNAdversarial" - ) else None for _rmbatch in self.rmbatch_modules + RMSprop(_rmbatch.parameters(), lr=learning_rate) + if _rmbatch._class + in ("Adversarial", "MNNAdversarial", "AdaptiveMNNAdversarial") + else None + for _rmbatch in self.rmbatch_modules ] @staticmethod @@ -125,44 +132,57 @@ def get_config(self) -> typing.Mapping: "genes": self.genes, "latent_module": self.latent_module.get_config(), "prob_module": self.prob_module.get_config(), - "rmbatch_modules": [_module.get_config() for _module in self.rmbatch_modules], + "rmbatch_modules": [ + _module.get_config() for _module in self.rmbatch_modules + ], "denoising": self.denoising, "learning_rate": self.learning_rate, "path": self.path, "random_seed": self.random_seed, - "_mode": self._mode + "_mode": self._mode, } @staticmethod - def preprocess(x: torch.Tensor, libs: torch.Tensor, noisy: bool = True) -> torch.Tensor: + def preprocess( + x: torch.Tensor, libs: torch.Tensor, noisy: bool = True + ) -> torch.Tensor: x = x / (libs / 10000) if noisy: - x = D.Poisson(rate = x).sample() + x = D.Poisson(rate=x).sample() x = x.log1p() return x - def fit(self, - dataset: data.Dataset, - batch_size: int = 128, - val_split: float = 0.1, - epoch: int = 1000, - patience: int = 30, - tolerance: float = 0.0, - progress_bar: bool = False): - + def fit( + self, + dataset: data.Dataset, + batch_size: int = 128, + val_split: float = 0.1, + epoch: int = 1000, + patience: int = 30, + tolerance: float = 0.0, + progress_bar: bool = False, + ): val_size = int(len(dataset) * val_split) train_size = len(dataset) - val_size train_dataset, val_dataset = torch.utils.data.random_split( - dataset, [train_size, 
val_size], - generator=torch.Generator().manual_seed(self.random_seed) + dataset, + [train_size, val_size], + generator=torch.Generator().manual_seed(self.random_seed), ) - train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, - shuffle = True, drop_last = True, - generator=torch.Generator().manual_seed(self.random_seed)) - val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size = batch_size, - shuffle = True, - generator=torch.Generator().manual_seed(self.random_seed)) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + drop_last=True, + generator=torch.Generator().manual_seed(self.random_seed), + ) + val_dataloader = torch.utils.data.DataLoader( + val_dataset, + batch_size=batch_size, + shuffle=True, + generator=torch.Generator().manual_seed(self.random_seed), + ) assert self._mode == _TRAIN self.to(config.DEVICE) @@ -172,14 +192,12 @@ def fit(self, self.latent_module.check_fine_tune() self.prob_module.check_fine_tune() - patience_remain = patience best_loss = 1e10 - summarywriter = SummaryWriter(log_dir = os.path.join(self.path, 'summary')) + summarywriter = SummaryWriter(log_dir=os.path.join(self.path, "summary")) for _epoch in range(epoch): - start_time = time.time() if progress_bar: @@ -217,7 +235,6 @@ def fit(self, self.save_weights(self.path) def train_epoch(self, train_dataloader, epoch, summarywriter): - self.train() loss_record = {} @@ -225,12 +242,11 @@ def train_epoch(self, train_dataloader, epoch, summarywriter): self.prob_module.init_loss_record(loss_record) for _rmbatch in self.rmbatch_modules: _rmbatch.init_loss_record(loss_record) - loss_record['early_stop_loss'] = 0 - loss_record['total_loss'] = 0 + loss_record["early_stop_loss"] = 0 + loss_record["total_loss"] = 0 datasize = 0 for feed_dict in train_dataloader: - for key, value in feed_dict.items(): feed_dict[key] = value.to(config.DEVICE) @@ -240,7 +256,9 @@ def train_epoch(self, train_dataloader, epoch, summarywriter): x = self.preprocess(exprs, libs, self.denoising) l, l_components = self.latent_module(x) - latent_d_loss = self.latent_module.d_loss(l_components, feed_dict, loss_record) + latent_d_loss = self.latent_module.d_loss( + l_components, feed_dict, loss_record + ) self.opt_latent_reg.zero_grad() latent_d_loss.backward() self.opt_latent_reg.step() @@ -252,7 +270,9 @@ def train_epoch(self, train_dataloader, epoch, summarywriter): if mask.sum() > 0: for _ in range(_rmbatch.n_steps): pred = _rmbatch(l, mask) - rmbatch_d_loss = _rmbatch.d_loss(pred, feed_dict, mask, loss_record) + rmbatch_d_loss = _rmbatch.d_loss( + pred, feed_dict, mask, loss_record + ) if not _opt is None: _opt.zero_grad() rmbatch_d_loss.backward() @@ -260,7 +280,9 @@ def train_epoch(self, train_dataloader, epoch, summarywriter): x = self.preprocess(exprs, libs, self.denoising) l, l_components = self.latent_module(x) - latent_g_loss = self.latent_module.g_loss(l_components, feed_dict, loss_record) + latent_g_loss = self.latent_module.g_loss( + l_components, feed_dict, loss_record + ) full_l = [l] for _rmbatch in self.rmbatch_modules: full_l.append(feed_dict[_rmbatch.name]) @@ -272,7 +294,9 @@ def train_epoch(self, train_dataloader, epoch, summarywriter): mask = _rmbatch.get_mask(l, feed_dict) if mask.sum() > 0: pred = _rmbatch(l, mask) - rmbatch_g_loss = _rmbatch.g_loss(pred, feed_dict, mask, loss_record) + rmbatch_g_loss = _rmbatch.g_loss( + pred, feed_dict, mask, loss_record + ) loss = loss + rmbatch_g_loss self.opt_latent_fit.zero_grad() @@ 
-281,16 +305,15 @@ def train_epoch(self, train_dataloader, epoch, summarywriter): self.opt_latent_fit.step() self.opt_prob.step() - loss_record['early_stop_loss'] += prob_loss.item() * x.shape[0] - loss_record['total_loss'] += loss.item() * x.shape[0] + loss_record["early_stop_loss"] += prob_loss.item() * x.shape[0] + loss_record["total_loss"] += loss.item() * x.shape[0] for key, value in loss_record.items(): - summarywriter.add_scalar(key + ':0 (train)', value / datasize, epoch) + summarywriter.add_scalar(key + ":0 (train)", value / datasize, epoch) - return loss_record['early_stop_loss'] / datasize + return loss_record["early_stop_loss"] / datasize def val_epoch(self, val_dataloader, epoch, summarywriter): - self.eval() loss_record = {} @@ -298,12 +321,11 @@ def val_epoch(self, val_dataloader, epoch, summarywriter): self.prob_module.init_loss_record(loss_record) for _rmbatch in self.rmbatch_modules: _rmbatch.init_loss_record(loss_record) - loss_record['early_stop_loss'] = 0 - loss_record['total_loss'] = 0 + loss_record["early_stop_loss"] = 0 + loss_record["total_loss"] = 0 datasize = 0 for feed_dict in val_dataloader: - for key, value in feed_dict.items(): feed_dict[key] = value.to(config.DEVICE) @@ -326,7 +348,9 @@ def val_epoch(self, val_dataloader, epoch, summarywriter): x = self.preprocess(exprs, libs, self.denoising) l, l_components = self.latent_module(x) - latent_g_loss = self.latent_module.g_loss(l_components, feed_dict, loss_record) + latent_g_loss = self.latent_module.g_loss( + l_components, feed_dict, loss_record + ) full_l = [l] for _rmbatch in self.rmbatch_modules: full_l.append(feed_dict[_rmbatch.name]) @@ -338,20 +362,21 @@ def val_epoch(self, val_dataloader, epoch, summarywriter): mask = _rmbatch.get_mask(l, feed_dict) if mask.sum() > 0: pred = _rmbatch(l, mask) - rmbatch_g_loss = _rmbatch.g_loss(pred, feed_dict, mask, loss_record) + rmbatch_g_loss = _rmbatch.g_loss( + pred, feed_dict, mask, loss_record + ) loss = loss + rmbatch_g_loss - loss_record['early_stop_loss'] += prob_loss.item() * x.shape[0] - loss_record['total_loss'] += loss.item() * x.shape[0] + loss_record["early_stop_loss"] += prob_loss.item() * x.shape[0] + loss_record["total_loss"] += loss.item() * x.shape[0] for key, value in loss_record.items(): - summarywriter.add_scalar(key + ':0 (val)', value / datasize, epoch) - - return loss_record['early_stop_loss'] / datasize + summarywriter.add_scalar(key + ":0 (val)", value / datasize, epoch) + return loss_record["early_stop_loss"] / datasize def save_weights(self, path: str, checkpoint: str = "checkpoint.pk"): - os.makedirs(path, exist_ok = True) + os.makedirs(path, exist_ok=True) torch.save(self.state_dict(), os.path.join(path, checkpoint)) def load_weights(self, path: str, checkpoint: str = "checkpoint.pk"): @@ -360,28 +385,31 @@ def load_weights(self, path: str, checkpoint: str = "checkpoint.pk"): @classmethod def load_config(cls, configuration: typing.Mapping): + _class = configuration["latent_module"]["_class"] + latent_module = getattr(latent, _class)(**configuration["latent_module"]) - _class = configuration['latent_module']['_class'] - latent_module = getattr(latent, _class)(**configuration['latent_module']) - - _class = configuration['prob_module']['_class'] - prob_module = getattr(prob, _class)(**configuration['prob_module']) + _class = configuration["prob_module"]["_class"] + prob_module = getattr(prob, _class)(**configuration["prob_module"]) rmbatch_modules = nn.ModuleList() - for _conf in configuration['rmbatch_modules']: - _class = _conf['_class'] 
+ for _conf in configuration["rmbatch_modules"]: + _class = _conf["_class"] rmbatch_modules.append(getattr(rmbatch, _class)(**_conf)) - configuration['latent_module'] = latent_module - configuration['prob_module'] = prob_module - configuration['rmbatch_modules'] = rmbatch_modules + configuration["latent_module"] = latent_module + configuration["prob_module"] = prob_module + configuration["rmbatch_modules"] = rmbatch_modules model = cls(**configuration) return model - def save(self, path: typing.Optional[str] = None, - config: str = "config.pk", weights: str = "weights.pk"): + def save( + self, + path: typing.Optional[str] = None, + config: str = "config.pk", + weights: str = "weights.pk", + ): r""" Save model to files @@ -398,17 +426,20 @@ def save(self, path: typing.Optional[str] = None, torch.save(self.get_config(), os.path.join(self.path, config)) torch.save(self.state_dict(), os.path.join(self.path, weights)) else: - os.makedirs(path, exist_ok = True) + os.makedirs(path, exist_ok=True) configuration = self.get_config() - configuration['path'] = path + configuration["path"] = path torch.save(configuration, os.path.join(path, config)) torch.save(self.state_dict(), os.path.join(path, weights)) - @classmethod - def load(cls, path: str, - config: str = "config.pk", weights: str = "weights.pk", - _mode: int = _TRAIN) -> None: + def load( + cls, + path: str, + config: str = "config.pk", + weights: str = "weights.pk", + _mode: int = _TRAIN, + ) -> None: r""" Load model from files @@ -424,24 +455,25 @@ def load(cls, path: str, assert os.path.exists(path) configuration = torch.load(os.path.join(path, config)) - if configuration['_mode'] == _TEST and _mode == _TRAIN: - raise RuntimeError("The model was minimal, please use argument '_mode=Cell_BLAST.blast.MINIMAL'") + if configuration["_mode"] == _TEST and _mode == _TRAIN: + raise RuntimeError( + "The model was minimal, please use argument '_mode=Cell_BLAST.blast.MINIMAL'" + ) model = cls.load_config(configuration) - model.load_state_dict(torch.load(os.path.join(path, weights)), strict = False) + model.load_state_dict(torch.load(os.path.join(path, weights)), strict=False) return model - - def inference(self, - adata: ad.AnnData, - batch_size: int = 4096, - n_posterior: int = 0, - progress_bar: bool = False, - priority: str = "auto", - random_seed: typing.Optional[int] = config._USE_GLOBAL, - ) -> np.ndarray: - + def inference( + self, + adata: ad.AnnData, + batch_size: int = 4096, + n_posterior: int = 0, + progress_bar: bool = False, + priority: str = "auto", + random_seed: typing.Optional[int] = config._USE_GLOBAL, + ) -> np.ndarray: r""" Project expression profiles into the cell embedding space. 
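A minimal sketch of the save/load round trip defined above, assuming a fitted `DIRECTi` instance named `model`; the directory `./model_dir` and the variable names are hypothetical, and the file names keep the defaults `config.pk` / `weights.pk`:

    import Cell_BLAST as cb

    # Persist configuration and weights to a directory (created if missing).
    model.save(path="./model_dir")

    # Reload later.  This assumes `load` keeps its classmethod behaviour from
    # the pre-formatting code; models saved in minimal (_TEST) mode raise the
    # RuntimeError above when reloaded for further training.
    reloaded = cb.directi.DIRECTi.load("./model_dir")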
@@ -480,8 +512,11 @@ def inference(self, self.eval() self.to(config.DEVICE) - random_seed = config.RANDOM_SEED \ - if random_seed is None or random_seed == config._USE_GLOBAL else random_seed + random_seed = ( + config.RANDOM_SEED + if random_seed is None or random_seed == config._USE_GLOBAL + else random_seed + ) x = data.select_vars(adata, self.genes).X if "__libsize__" not in adata.obs.columns: data.compute_libsize(adata) @@ -496,27 +531,44 @@ def inference(self, xrep = np.repeat(x, n_posterior, axis=0) lrep = np.repeat(l, n_posterior, axis=0) data_dict = OrderedDict(exprs=xrep, library_size=lrep) - return self._fetch_latent( - data.Dataset(data_dict), batch_size, True, progress_bar, random_seed - ).astype(np.float32).reshape((x.shape[0], n_posterior, -1)) + return ( + self._fetch_latent( + data.Dataset(data_dict), + batch_size, + True, + progress_bar, + random_seed, + ) + .astype(np.float32) + .reshape((x.shape[0], n_posterior, -1)) + ) else: # priority == "memory": data_dict = OrderedDict(exprs=x, library_size=l) - return np.stack([self._fetch_latent( - data.Dataset(data_dict), batch_size, True, progress_bar, - (random_seed + i) if random_seed is not None else None - ).astype(np.float32) for i in range(n_posterior)], axis=1) + return np.stack( + [ + self._fetch_latent( + data.Dataset(data_dict), + batch_size, + True, + progress_bar, + (random_seed + i) if random_seed is not None else None, + ).astype(np.float32) + for i in range(n_posterior) + ], + axis=1, + ) data_dict = OrderedDict(exprs=x, library_size=l) return self._fetch_latent( data.Dataset(data_dict), batch_size, False, progress_bar, random_seed ).astype(np.float32) - - - def clustering(self, - adata: ad.AnnData, - batch_size: int = 4096, - return_confidence: bool = False, - progress_bar: bool = False) -> typing.Tuple[np.ndarray]: + def clustering( + self, + adata: ad.AnnData, + batch_size: int = 4096, + return_confidence: bool = False, + progress_bar: bool = False, + ) -> typing.Tuple[np.ndarray]: r""" Get model intrinsic clustering of the data. 
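A usage sketch for the `inference` method reformatted above, assuming a fitted `model` and a hypothetical `adata` (an `anndata.AnnData` with raw counts); the output shapes follow the two branches shown:

    # Point estimate of each cell's embedding: (n_cells, latent_dim).
    latent = model.inference(adata)

    # Posterior samples around each embedding: (n_cells, 50, latent_dim).
    # priority="speed" repeats the input in memory, while priority="memory"
    # loops over the samples instead.
    posterior = model.inference(adata, n_posterior=50, priority="speed")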
@@ -551,16 +603,18 @@ def clustering(self, l = adata.obs["__libsize__"].to_numpy().reshape((-1, 1)) data_dict = OrderedDict(exprs=x, library_size=l) cat = self._fetch_cat( - data.Dataset(data_dict), batch_size, False, - progress_bar + data.Dataset(data_dict), batch_size, False, progress_bar ).astype(np.float32) if return_confidence: return cat.argmax(axis=1), cat.max(axis=1) return cat.argmax(axis=1) def gene_grad( - self, adata: ad.AnnData, latent_grad: np.ndarray, - batch_size: int = 4096, progress_bar: bool = False + self, + adata: ad.AnnData, + latent_grad: np.ndarray, + batch_size: int = 4096, + progress_bar: bool = False, ) -> np.ndarray: r""" Fetch gene space gradients with regard to latent space gradients @@ -590,19 +644,24 @@ def gene_grad( if "__libsize__" not in adata.obs.columns: data.compute_libsize(adata) l = adata.obs["__libsize__"].to_numpy().reshape((-1, 1)) - data_dict = OrderedDict( - exprs=x, library_size=l, output_grad=latent_grad) + data_dict = OrderedDict(exprs=x, library_size=l, output_grad=latent_grad) return self._fetch_grad( - data.Dataset(data_dict), - batch_size=batch_size, progress_bar=progress_bar + data.Dataset(data_dict), batch_size=batch_size, progress_bar=progress_bar ) - def _fetch_latent(self, dataset: data.Dataset, batch_size: int, noisy: bool, - progress_bar: bool, random_seed: int) -> np.ndarray: - + def _fetch_latent( + self, + dataset: data.Dataset, + batch_size: int, + noisy: bool, + progress_bar: bool, + random_seed: int, + ) -> np.ndarray: self.ensure_reproducibility(random_seed) - dataloader = torch.utils.data.DataLoader(dataset, batch_size = batch_size, shuffle = False) + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, shuffle=False + ) if progress_bar: dataloader = utils.smart_tqdm()(dataloader) @@ -613,15 +672,19 @@ def _fetch_latent(self, dataset: data.Dataset, batch_size: int, noisy: bool, feed_dict[key] = value.to(config.DEVICE) exprs = feed_dict["exprs"] libs = feed_dict["library_size"] - latents.append(self.latent_module.fetch_latent(self.preprocess(exprs, libs, noisy))) + latents.append( + self.latent_module.fetch_latent(self.preprocess(exprs, libs, noisy)) + ) return torch.cat(latents).cpu().numpy() - def _fetch_cat(self, dataset: data.Dataset, batch_size: int, noisy: bool, - progress_bar: bool) -> typing.Tuple[np.ndarray]: - + def _fetch_cat( + self, dataset: data.Dataset, batch_size: int, noisy: bool, progress_bar: bool + ) -> typing.Tuple[np.ndarray]: self.ensure_reproducibility(self.random_seed) - dataloader = torch.utils.data.DataLoader(dataset, batch_size = batch_size, shuffle = False) + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, shuffle=False + ) if progress_bar: dataloader = utils.smart_tqdm()(dataloader) @@ -632,14 +695,19 @@ def _fetch_cat(self, dataset: data.Dataset, batch_size: int, noisy: bool, feed_dict[key] = value.to(config.DEVICE) exprs = feed_dict["exprs"] libs = feed_dict["library_size"] - cats.append(self.latent_module.fetch_cat(self.preprocess(exprs, libs, noisy))) + cats.append( + self.latent_module.fetch_cat(self.preprocess(exprs, libs, noisy)) + ) return torch.cat(cats).cpu().numpy() - def _fetch_grad(self, dataset: data.Dataset, batch_size: int, progress_bar: bool) -> np.ndarray: - + def _fetch_grad( + self, dataset: data.Dataset, batch_size: int, progress_bar: bool + ) -> np.ndarray: self.ensure_reproducibility(self.random_seed) - dataloader = torch.utils.data.DataLoader(dataset, batch_size = batch_size, shuffle = False) + dataloader = 
torch.utils.data.DataLoader( + dataset, batch_size=batch_size, shuffle=False + ) if progress_bar: dataloader = utils.smart_tqdm()(dataloader) @@ -650,38 +718,41 @@ def _fetch_grad(self, dataset: data.Dataset, batch_size: int, progress_bar: bool exprs = feed_dict["exprs"] libs = feed_dict["library_size"] latent_grad = feed_dict["output_grad"] - grads.append(self.latent_module.fetch_grad(self.preprocess(exprs, libs, self.denoising), latent_grad)) + grads.append( + self.latent_module.fetch_grad( + self.preprocess(exprs, libs, self.denoising), latent_grad + ) + ) return torch.cat(grads).cpu().numpy() def fit_DIRECTi( - adata: ad.AnnData, - genes: typing.Optional[typing.List[str]] = None, - supervision: typing.Optional[str] = None, - batch_effect: typing.Optional[typing.List[str]] = None, - latent_dim: int = 10, - cat_dim: typing.Optional[int] = None, - h_dim: int = 128, - depth: int = 1, - prob_module: str = "NB", - rmbatch_module: typing.Union[str, typing.List[str]] = "Adversarial", - latent_module_kwargs: typing.Optional[typing.Mapping] = None, - prob_module_kwargs: typing.Optional[typing.Mapping] = None, - rmbatch_module_kwargs: typing.Optional[typing.Union[ - typing.Mapping, typing.List[typing.Mapping] - ]] = None, - optimizer: str = "RMSPropOptimizer", - learning_rate: float = 1e-3, - batch_size: int = 128, - val_split: float = 0.1, - epoch: int = 1000, - patience: int = 30, - progress_bar: bool = False, - reuse_weights: typing.Optional[str] = None, - random_seed: int = config._USE_GLOBAL, - path: typing.Optional[str] = None + adata: ad.AnnData, + genes: typing.Optional[typing.List[str]] = None, + supervision: typing.Optional[str] = None, + batch_effect: typing.Optional[typing.List[str]] = None, + latent_dim: int = 10, + cat_dim: typing.Optional[int] = None, + h_dim: int = 128, + depth: int = 1, + prob_module: str = "NB", + rmbatch_module: typing.Union[str, typing.List[str]] = "Adversarial", + latent_module_kwargs: typing.Optional[typing.Mapping] = None, + prob_module_kwargs: typing.Optional[typing.Mapping] = None, + rmbatch_module_kwargs: typing.Optional[ + typing.Union[typing.Mapping, typing.List[typing.Mapping]] + ] = None, + optimizer: str = "RMSPropOptimizer", + learning_rate: float = 1e-3, + batch_size: int = 128, + val_split: float = 0.1, + epoch: int = 1000, + patience: int = 30, + progress_bar: bool = False, + reuse_weights: typing.Optional[str] = None, + random_seed: int = config._USE_GLOBAL, + path: typing.Optional[str] = None, ) -> DIRECTi: - r""" A convenient one-step function to build and fit DIRECTi models. Should work well in most cases. @@ -760,8 +831,11 @@ def fit_DIRECTi( See the DIRECTi ipython notebook (:ref:`vignettes`) for live examples. 
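Following the pointer to the vignettes, a minimal end-to-end sketch; `adata` (raw counts), the `genes` list and the `"dataset_name"` batch column are hypothetical, and all other arguments keep the defaults shown in the signature above:

    import Cell_BLAST as cb

    model = cb.directi.fit_DIRECTi(
        adata,
        genes=genes,                    # e.g. pre-selected variable genes
        batch_effect=["dataset_name"],  # obs column(s) to correct for
        latent_dim=10,
    )
    latent = model.inference(adata)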
""" - random_seed = config.RANDOM_SEED \ - if random_seed is None or random_seed == config._USE_GLOBAL else random_seed + random_seed = ( + config.RANDOM_SEED + if random_seed is None or random_seed == config._USE_GLOBAL + else random_seed + ) DIRECTi.ensure_reproducibility(random_seed) if latent_module_kwargs is None: @@ -782,7 +856,7 @@ def fit_DIRECTi( data.compute_libsize(adata) data_dict = OrderedDict( library_size=adata.obs["__libsize__"].to_numpy().reshape((-1, 1)), - exprs=data.select_vars(adata, genes).X + exprs=data.select_vars(adata, genes).X, ) if batch_effect is None: @@ -805,13 +879,17 @@ def fit_DIRECTi( if cat_dim is None: cat_dim = data_dict[supervision].shape[1] elif cat_dim > data_dict[supervision].shape[1]: - data_dict[supervision] = scipy.sparse.hstack([ - data_dict[supervision].tocsc(), - scipy.sparse.csc_matrix(( - data_dict[supervision].shape[0], - cat_dim - data_dict[supervision].shape[1] - )) - ]).tocsr() + data_dict[supervision] = scipy.sparse.hstack( + [ + data_dict[supervision].tocsc(), + scipy.sparse.csc_matrix( + ( + data_dict[supervision].shape[0], + cat_dim - data_dict[supervision].shape[1], + ) + ), + ] + ).tocsr() elif cat_dim < data_dict[supervision].shape[1]: # pragma: no cover raise ValueError( "`cat_dim` must be greater than or equal to " @@ -842,17 +920,15 @@ def fit_DIRECTi( rmbatch_list = nn.ModuleList() full_latent_dim = [latent_dim] for _batch_effect, _rmbatch_module, _rmbatch_module_kwargs in zip( - batch_effect, rmbatch_module, rmbatch_module_kwargs + batch_effect, rmbatch_module, rmbatch_module_kwargs ): batch_dim = len(adata.obs[_batch_effect].dropna().unique()) full_latent_dim.append(batch_dim) - kwargs = dict( - batch_dim=batch_dim, - latent_dim=latent_dim, - name=_batch_effect - ) + kwargs = dict(batch_dim=batch_dim, latent_dim=latent_dim, name=_batch_effect) if _rmbatch_module in ( - "Adversarial", "MNNAdversarial", "AdaptiveMNNAdversarial" + "Adversarial", + "MNNAdversarial", + "AdaptiveMNNAdversarial", ): kwargs.update(dict(h_dim=h_dim, depth=depth)) kwargs.update(_rmbatch_module_kwargs) @@ -862,7 +938,9 @@ def fit_DIRECTi( kwargs.update(_rmbatch_module_kwargs) rmbatch_list.append(getattr(rmbatch, _rmbatch_module)(**kwargs)) - kwargs = dict(output_dim=len(genes), full_latent_dim=full_latent_dim, h_dim=h_dim, depth=depth) + kwargs = dict( + output_dim=len(genes), full_latent_dim=full_latent_dim, h_dim=h_dim, depth=depth + ) kwargs.update(prob_module_kwargs) prob_module = getattr(prob, prob_module)(**kwargs) @@ -888,33 +966,31 @@ def fit_DIRECTi( val_split=val_split, epoch=epoch, patience=patience, - progress_bar=progress_bar + progress_bar=progress_bar, ) return model def align_DIRECTi( - model: DIRECTi, - original_adata: ad.AnnData, - new_adata: typing.Union[ad.AnnData, typing.Mapping[str, ad.AnnData]], - rmbatch_module: str = "MNNAdversarial", - rmbatch_module_kwargs: typing.Optional[typing.Mapping] = None, - deviation_reg: float = 0.01, - optimizer: str = "RMSPropOptimizer", - learning_rate: float = 1e-3, - batch_size: int = 256, - val_split: float = 0.1, - epoch: int = 100, - patience: int = 100, - tolerance: float = 0.0, - reuse_weights: bool = True, - progress_bar: bool = False, - random_seed: int = config._USE_GLOBAL, - path: typing.Optional[str] = None - + model: DIRECTi, + original_adata: ad.AnnData, + new_adata: typing.Union[ad.AnnData, typing.Mapping[str, ad.AnnData]], + rmbatch_module: str = "MNNAdversarial", + rmbatch_module_kwargs: typing.Optional[typing.Mapping] = None, + deviation_reg: float = 0.01, + optimizer: str = 
"RMSPropOptimizer", + learning_rate: float = 1e-3, + batch_size: int = 256, + val_split: float = 0.1, + epoch: int = 100, + patience: int = 100, + tolerance: float = 0.0, + reuse_weights: bool = True, + progress_bar: bool = False, + random_seed: int = config._USE_GLOBAL, + path: typing.Optional[str] = None, ) -> DIRECTi: - r""" Align datasets starting with an existing DIRECTi model (fine-tuning) @@ -968,14 +1044,15 @@ def align_DIRECTi( Aligned model. """ - random_seed = config.RANDOM_SEED \ - if random_seed == config._USE_GLOBAL else random_seed + random_seed = ( + config.RANDOM_SEED if random_seed == config._USE_GLOBAL else random_seed + ) DIRECTi.ensure_reproducibility(random_seed) if path is None: path = tempfile.mkdtemp() else: - os.makedirs(path, exist_ok = True) + os.makedirs(path, exist_ok=True) if rmbatch_module_kwargs is None: rmbatch_module_kwargs = {} @@ -984,8 +1061,9 @@ def align_DIRECTi( if isinstance(new_adata, ad.AnnData): new_adatas = {"__new__": new_adata} elif isinstance(new_adata, dict): - assert "__original__" not in new_adata, \ - "Key `__original__` is now allowed in new datasets." + assert ( + "__original__" not in new_adata + ), "Key `__original__` is now allowed in new datasets." new_adatas = new_adata.copy() # shallow else: raise TypeError("Invalid type for argument `new_dataset`.") @@ -995,54 +1073,59 @@ def align_DIRECTi( _rmbatch_module["delay"] = 0 kwargs = { "batch_dim": len(new_adatas) + 1, - 'latent_dim': model.latent_module.latent_dim, - "delay": 0, "name": "__align__", - "_class": rmbatch_module + "latent_dim": model.latent_module.latent_dim, + "delay": 0, + "name": "__align__", + "_class": rmbatch_module, } - if rmbatch_module in ( - "Adversarial", "MNNAdversarial", "AdaptiveMNNAdversarial" - ): - kwargs.update(dict( - h_dim=model.latent_module.h_dim, - depth=model.latent_module.depth, - dropout=model.latent_module.dropout, - lambda_reg=0.01 - )) + if rmbatch_module in ("Adversarial", "MNNAdversarial", "AdaptiveMNNAdversarial"): + kwargs.update( + dict( + h_dim=model.latent_module.h_dim, + depth=model.latent_module.depth, + dropout=model.latent_module.dropout, + lambda_reg=0.01, + ) + ) elif rmbatch_module not in ("RMBatch", "MNN"): # pragma: no cover raise ValueError("Unknown rmbatch_module!") # else "RMBatch" or "MNN" kwargs.update(rmbatch_module_kwargs) _config["rmbatch_modules"].append(kwargs) - _config["prob_module"]['full_latent_dim'].append(len(new_adatas) + 1) + _config["prob_module"]["full_latent_dim"].append(len(new_adatas) + 1) _config["prob_module"]["fine_tune"] = True _config["prob_module"]["deviation_reg"] = deviation_reg _config["learning_rate"] = learning_rate aligned_model = DIRECTi.load_config(_config) if reuse_weights: - aligned_model.load_state_dict(model.state_dict(), strict = False) - supervision = aligned_model.latent_module.name if isinstance( - aligned_model.latent_module, latent.SemiSupervisedCatGau - ) else None + aligned_model.load_state_dict(model.state_dict(), strict=False) + supervision = ( + aligned_model.latent_module.name + if isinstance(aligned_model.latent_module, latent.SemiSupervisedCatGau) + else None + ) - assert "__align__" not in original_adata.obs.columns, \ - "Please remove column `__align__` from obs of the original dataset." + assert ( + "__align__" not in original_adata.obs.columns + ), "Please remove column `__align__` from obs of the original dataset." 
original_adata = ad.AnnData( X=original_adata.X, obs=original_adata.obs.copy(deep=False), - var=original_adata.var.copy(deep=False) + var=original_adata.var.copy(deep=False), ) if "__libsize__" not in original_adata.obs.columns: data.compute_libsize(original_adata) original_adata = data.select_vars(original_adata, model.genes) for key in new_adatas.keys(): - assert "__align__" not in new_adatas[key].obs.columns, \ - f"Please remove column `__align__` from new dataset {key}." + assert ( + "__align__" not in new_adatas[key].obs.columns + ), f"Please remove column `__align__` from new dataset {key}." new_adatas[key] = ad.AnnData( X=new_adatas[key].X, obs=new_adatas[key].obs.copy(deep=False), - var=new_adatas[key].var.copy(deep=False) + var=new_adatas[key].var.copy(deep=False), ) new_adatas[key].obs = new_adatas[key].obs.loc[ :, new_adatas[key].obs.columns == "__libsize__" @@ -1060,25 +1143,27 @@ def align_DIRECTi( data_dict = OrderedDict( library_size=adata.obs["__libsize__"].to_numpy().reshape((-1, 1)), - exprs=data.select_vars(adata, model.genes).X # Ensure order + exprs=data.select_vars(adata, model.genes).X, # Ensure order ) for rmbatch_module in aligned_model.rmbatch_modules: data_dict[rmbatch_module.name] = utils.encode_onehot( adata.obs[rmbatch_module.name], sort=True ) if isinstance(aligned_model.latent_module, latent.SemiSupervisedCatGau): - data_dict[supervision] = utils.encode_onehot( - adata.obs[supervision], sort=True - ) + data_dict[supervision] = utils.encode_onehot(adata.obs[supervision], sort=True) cat_dim = aligned_model.latent_module.cat_dim if cat_dim > data_dict[supervision].shape[1]: - data_dict[supervision] = scipy.sparse.hstack([ - data_dict[supervision].tocsc(), - scipy.sparse.csc_matrix(( - data_dict[supervision].shape[0], - cat_dim - data_dict[supervision].shape[1] - )) - ]).tocsr() + data_dict[supervision] = scipy.sparse.hstack( + [ + data_dict[supervision].tocsc(), + scipy.sparse.csc_matrix( + ( + data_dict[supervision].shape[0], + cat_dim - data_dict[supervision].shape[1], + ) + ), + ] + ).tocsr() if optimizer != "RMSPropOptimizer": utils.logger.warning("Argument `optimizer` is not supported!") @@ -1090,6 +1175,6 @@ def align_DIRECTi( epoch=epoch, patience=patience, tolerance=tolerance, - progress_bar=progress_bar + progress_bar=progress_bar, ) - return aligned_model \ No newline at end of file + return aligned_model diff --git a/Cell_BLAST/latent.py b/Cell_BLAST/latent.py index dce3f6e..edf68cf 100644 --- a/Cell_BLAST/latent.py +++ b/Cell_BLAST/latent.py @@ -2,22 +2,28 @@ Latent space / encoder modules for DIRECTi """ +import itertools +import typing + import torch -from torch import nn import torch.distributions as D import torch.nn.functional as F -import typing -import itertools -from .rebuild import Linear -from .rebuild import MLP +from torch import nn + from . import config, utils +from .rebuild import MLP, Linear class Regularizer(nn.Module): def __init__( - self, latent_dim: int, h_dim: int = 128, depth: int = 1, dropout: float = 0.0, - name: str = 'Reg', _class: str = 'Regularizer', - **kwargs + self, + latent_dim: int, + h_dim: int = 128, + depth: int = 1, + dropout: float = 0.0, + name: str = "Reg", + _class: str = "Regularizer", + **kwargs, ) -> None: super().__init__() self.latent_dim = latent_dim @@ -48,22 +54,28 @@ def get_config(self) -> typing.Mapping: "depth": self.depth, "dropout": self.dropout, "name": self.name, - "_class": self._class + "_class": self._class, } - class Latent(nn.Module): r""" Abstract base class for latent variable modules. 
""" + def __init__( - self, input_dim: int, latent_dim: int, h_dim: int = 128, depth: int = 1, - dropout: float = 0.0, lambda_reg: float = 0.0, - fine_tune: bool = False, deviation_reg: float = 0.0, - name: str = "Latent", - _class: str = "Latent", - **kwargs + self, + input_dim: int, + latent_dim: int, + h_dim: int = 128, + depth: int = 1, + dropout: float = 0.0, + lambda_reg: float = 0.0, + fine_tune: bool = False, + deviation_reg: float = 0.0, + name: str = "Latent", + _class: str = "Latent", + **kwargs, ) -> None: super().__init__() self.input_dim = input_dim @@ -76,13 +88,15 @@ def __init__( self.deviation_reg = deviation_reg self.name = name self._class = _class - self.record_prefix = 'discriminator' + self.record_prefix = "discriminator" for key in kwargs.keys(): utils.logger.warning("Argument `%s` is no longer supported!" % key) @staticmethod - def gan_d_loss(y: torch.Tensor, y_hat: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: + def gan_d_loss( + y: torch.Tensor, y_hat: torch.Tensor, eps: float = 1e-8 + ) -> torch.Tensor: return -(torch.log(y_hat + eps) + torch.log(1 - y + eps)).mean() @staticmethod @@ -100,7 +114,7 @@ def get_config(self) -> typing.Mapping: "fine_tune": self.fine_tune, "deviation_reg": self.deviation_reg, "name": self.name, - "_class": self._class + "_class": self._class, } @@ -126,45 +140,59 @@ class Gau(Latent): name Name of the module. """ - def __init__(self, - input_dim: int, - latent_dim: int, - h_dim: int = 128, - depth: int = 1, - dropout: float = 0.0, - lambda_reg: float = 0.001, - fine_tune: bool = False, - deviation_reg: float = 0.0, - name: str = "Gau", - _class: str = "Gau", - **kwargs) -> None: + def __init__( + self, + input_dim: int, + latent_dim: int, + h_dim: int = 128, + depth: int = 1, + dropout: float = 0.0, + lambda_reg: float = 0.001, + fine_tune: bool = False, + deviation_reg: float = 0.0, + name: str = "Gau", + _class: str = "Gau", + **kwargs, + ) -> None: super().__init__( - input_dim, latent_dim, h_dim, depth, dropout, - lambda_reg, fine_tune, deviation_reg, name, _class, **kwargs + input_dim, + latent_dim, + h_dim, + depth, + dropout, + lambda_reg, + fine_tune, + deviation_reg, + name, + _class, + **kwargs, ) self.gau_reg = Regularizer(latent_dim, h_dim, depth, dropout, name="gau") - self.gaup_sampler = D.Normal(loc = torch.tensor(0.0), scale = torch.tensor(1.0)) + self.gaup_sampler = D.Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0)) i_dim = [input_dim] + [h_dim] * (depth - 1) if depth > 0 else [] o_dim = [h_dim] * depth dropout = [dropout] * depth self.mlp = MLP(i_dim, o_dim, dropout, bias=False, batch_normalization=True) - self.gau = Linear(h_dim, latent_dim) if depth > 0 else Linear(input_dim, latent_dim) + self.gau = ( + Linear(h_dim, latent_dim) if depth > 0 else Linear(input_dim, latent_dim) + ) - #fine-tune + # fine-tune def save_origin_state(self) -> None: self.mlp.save_origin_state() self.mlp.first_layer_trainable = False self.gau.save_origin_state() - #fine-tune + # fine-tune def deviation_loss(self) -> torch.Tensor: - return self.deviation_reg * \ - (self.mlp.deviation_loss() + self.gau.deviation_loss()) + return self.deviation_reg * ( + self.mlp.deviation_loss() + self.gau.deviation_loss() + ) - #fine_tune + # fine_tune def check_fine_tune(self) -> None: if self.fine_tune: self.save_origin_state() @@ -188,31 +216,65 @@ def fetch_grad(self, x: torch.Tensor, latent_grad: torch.Tensor) -> torch.Tensor return x_with_grad.grad - def d_loss(self, gau: torch.Tensor, feed_dict: typing.Mapping, loss_record: 
typing.Mapping) -> typing.Tuple[torch.Tensor]: - - gaup = self.gaup_sampler.sample((gau.shape[0], self.latent_dim)).to(config.DEVICE) + def d_loss( + self, gau: torch.Tensor, feed_dict: typing.Mapping, loss_record: typing.Mapping + ) -> typing.Tuple[torch.Tensor]: + gaup = self.gaup_sampler.sample((gau.shape[0], self.latent_dim)).to( + config.DEVICE + ) gau_pred = self.gau_reg(gau) gaup_pred = self.gau_reg(gaup) gau_d_loss = self.gan_d_loss(gau_pred, gaup_pred) - loss_record[self.record_prefix + '/' + self.name + '/' + self.gau_reg.name + '/d_loss/d_loss'] += gau_d_loss.item() * gau.shape[0] + loss_record[ + self.record_prefix + + "/" + + self.name + + "/" + + self.gau_reg.name + + "/d_loss/d_loss" + ] += (gau_d_loss.item() * gau.shape[0]) return self.lambda_reg * gau_d_loss - def g_loss(self, gau: torch.Tensor, feed_dict: typing.Mapping, loss_record: typing.Mapping) -> typing.Tuple[torch.Tensor]: - + def g_loss( + self, gau: torch.Tensor, feed_dict: typing.Mapping, loss_record: typing.Mapping + ) -> typing.Tuple[torch.Tensor]: gau_pred = self.gau_reg(gau) gau_g_loss = self.gan_g_loss(gau_pred) - loss_record[self.record_prefix + '/' + self.name + '/' + self.gau_reg.name + '/g_loss/g_loss'] += gau_g_loss.item() * gau.shape[0] + loss_record[ + self.record_prefix + + "/" + + self.name + + "/" + + self.gau_reg.name + + "/g_loss/g_loss" + ] += (gau_g_loss.item() * gau.shape[0]) if self.fine_tune: - return self.lambda_reg * gau_g_loss + self.deviation_reg * self.deviation_loss() + return ( + self.lambda_reg * gau_g_loss + + self.deviation_reg * self.deviation_loss() + ) else: return self.lambda_reg * gau_g_loss def init_loss_record(self, loss_record: typing.Mapping) -> None: - - loss_record[self.record_prefix + '/' + self.name + '/' + self.gau_reg.name + '/d_loss/d_loss'] = 0 - loss_record[self.record_prefix + '/' + self.name + '/' + self.gau_reg.name + '/g_loss/g_loss'] = 0 + loss_record[ + self.record_prefix + + "/" + + self.name + + "/" + + self.gau_reg.name + + "/d_loss/d_loss" + ] = 0 + loss_record[ + self.record_prefix + + "/" + + self.name + + "/" + + self.gau_reg.name + + "/g_loss/g_loss" + ] = 0 def parameters_reg(self): return self.gau_reg.parameters() @@ -224,10 +286,7 @@ def parameters_fit(self): ) def get_config(self) -> typing.Mapping: - return { - **super().get_config() - } - + return {**super().get_config()} class CatGau(Latent): @@ -256,42 +315,54 @@ class CatGau(Latent): name Name of the module. 
""" - def __init__(self, - input_dim: int, - latent_dim: int, - cat_dim: int, - h_dim: int = 128, - depth: int = 1, - dropout: float = 0.0, - lambda_reg: float = 0.001, - fine_tune: bool = False, - deviation_reg: float = 0.0, - name: str = "CatGau", - _class: str = "CatGau", - **kwargs) -> None: + def __init__( + self, + input_dim: int, + latent_dim: int, + cat_dim: int, + h_dim: int = 128, + depth: int = 1, + dropout: float = 0.0, + lambda_reg: float = 0.001, + fine_tune: bool = False, + deviation_reg: float = 0.0, + name: str = "CatGau", + _class: str = "CatGau", + **kwargs, + ) -> None: super().__init__( - input_dim, latent_dim, h_dim, depth, dropout, - lambda_reg, fine_tune, deviation_reg, name, _class, **kwargs + input_dim, + latent_dim, + h_dim, + depth, + dropout, + lambda_reg, + fine_tune, + deviation_reg, + name, + _class, + **kwargs, ) self.cat_dim = cat_dim self.gau_reg = Regularizer(latent_dim, h_dim, depth, dropout, name="gau") - self.gaup_sampler = D.Normal(loc = torch.tensor(0.0), scale = torch.tensor(1.0)) + self.gaup_sampler = D.Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0)) self.cat_reg = Regularizer(cat_dim, h_dim, depth, dropout, name="cat") - self.catp_sampler = D.OneHotCategorical(probs = torch.ones(cat_dim) / cat_dim) + self.catp_sampler = D.OneHotCategorical(probs=torch.ones(cat_dim) / cat_dim) i_dim = [input_dim] + [h_dim] * (depth - 1) if depth > 0 else [] o_dim = [h_dim] * depth dropout = [dropout] * depth self.mlp = MLP(i_dim, o_dim, dropout, bias=False, batch_normalization=True) - self.gau = Linear(h_dim, latent_dim) if depth > 0 else Linear(input_dim, latent_dim) + self.gau = ( + Linear(h_dim, latent_dim) if depth > 0 else Linear(input_dim, latent_dim) + ) self.cat = Linear(h_dim, cat_dim) if depth > 0 else Linear(input_dim, cat_dim) - self.softmax = nn.Softmax(dim = 1) + self.softmax = nn.Softmax(dim=1) self.mat = Linear(cat_dim, latent_dim, bias=False, init_std=0.1, trunc=False) - - #fine-tune + # fine-tune def save_origin_state(self) -> None: self.mlp.save_origin_state() self.mlp.first_layer_trainable = False @@ -299,13 +370,16 @@ def save_origin_state(self) -> None: self.cat.save_origin_state() self.mat.save_origin_state() - #fine-tune + # fine-tune def deviation_loss(self) -> torch.Tensor: - return self.deviation_reg * \ - (self.mlp.deviation_loss() + self.gau.deviation_loss() + \ - self.cat.deviation_loss() + self.mat.deviation_loss()) + return self.deviation_reg * ( + self.mlp.deviation_loss() + + self.gau.deviation_loss() + + self.cat.deviation_loss() + + self.mat.deviation_loss() + ) - #fine_tune + # fine_tune def check_fine_tune(self) -> None: if self.fine_tune: self.save_origin_state() @@ -340,67 +414,129 @@ def fetch_grad(self, x: torch.Tensor, latent_grad: torch.Tensor) -> torch.Tensor return x_with_grad.grad - def d_loss(self, catgau: typing.Tuple[torch.Tensor], feed_dict: typing.Mapping, loss_record: typing.Mapping) -> torch.Tensor: - + def d_loss( + self, + catgau: typing.Tuple[torch.Tensor], + feed_dict: typing.Mapping, + loss_record: typing.Mapping, + ) -> torch.Tensor: gau, cat = catgau - gaup = self.gaup_sampler.sample((gau.shape[0], self.latent_dim)).to(config.DEVICE) + gaup = self.gaup_sampler.sample((gau.shape[0], self.latent_dim)).to( + config.DEVICE + ) gau_pred = self.gau_reg(gau) gaup_pred = self.gau_reg(gaup) gau_d_loss = self.gan_d_loss(gau_pred, gaup_pred) - loss_record[self.record_prefix + '/' + self.name + '/' + self.gau_reg.name + '/d_loss/d_loss'] += gau_d_loss.item() * gau.shape[0] + loss_record[ + self.record_prefix 
+ + "/" + + self.name + + "/" + + self.gau_reg.name + + "/d_loss/d_loss" + ] += (gau_d_loss.item() * gau.shape[0]) catp = self.catp_sampler.sample((cat.shape[0],)).to(config.DEVICE) cat_pred = self.cat_reg(cat) catp_pred = self.cat_reg(catp) cat_d_loss = self.gan_d_loss(cat_pred, catp_pred) - loss_record[self.record_prefix + '/' + self.name + '/' + self.cat_reg.name + '/d_loss/d_loss'] += cat_d_loss.item() * cat.shape[0] + loss_record[ + self.record_prefix + + "/" + + self.name + + "/" + + self.cat_reg.name + + "/d_loss/d_loss" + ] += (cat_d_loss.item() * cat.shape[0]) return self.lambda_reg * (gau_d_loss + cat_d_loss) - def g_loss(self, catgau: typing.Tuple[torch.Tensor], feed_dict: typing.Mapping, loss_record: typing.Mapping) -> typing.Tuple[torch.Tensor]: - + def g_loss( + self, + catgau: typing.Tuple[torch.Tensor], + feed_dict: typing.Mapping, + loss_record: typing.Mapping, + ) -> typing.Tuple[torch.Tensor]: gau, cat = catgau gau_pred = self.gau_reg(gau) gau_g_loss = self.gan_g_loss(gau_pred) - loss_record[self.record_prefix + '/' + self.name + '/' + self.gau_reg.name + '/g_loss/g_loss'] += gau_g_loss.item() * gau.shape[0] + loss_record[ + self.record_prefix + + "/" + + self.name + + "/" + + self.gau_reg.name + + "/g_loss/g_loss" + ] += (gau_g_loss.item() * gau.shape[0]) cat_pred = self.cat_reg(cat) cat_g_loss = self.gan_g_loss(cat_pred) - loss_record[self.record_prefix + '/' + self.name + '/' + self.cat_reg.name + '/g_loss/g_loss'] += cat_g_loss.item() * cat.shape[0] + loss_record[ + self.record_prefix + + "/" + + self.name + + "/" + + self.cat_reg.name + + "/g_loss/g_loss" + ] += (cat_g_loss.item() * cat.shape[0]) if self.fine_tune: - return self.lambda_reg * (gau_g_loss + cat_g_loss) + self.deviation_reg * self.deviation_loss() + return ( + self.lambda_reg * (gau_g_loss + cat_g_loss) + + self.deviation_reg * self.deviation_loss() + ) else: return self.lambda_reg * (gau_g_loss + cat_g_loss) def init_loss_record(self, loss_record: typing.Mapping) -> None: - - loss_record[self.record_prefix + '/' + self.name + '/' + self.gau_reg.name + '/d_loss/d_loss'] = 0 - loss_record[self.record_prefix + '/' + self.name + '/' + self.gau_reg.name + '/g_loss/g_loss'] = 0 - loss_record[self.record_prefix + '/' + self.name + '/' + self.cat_reg.name + '/d_loss/d_loss'] = 0 - loss_record[self.record_prefix + '/' + self.name + '/' + self.cat_reg.name + '/g_loss/g_loss'] = 0 + loss_record[ + self.record_prefix + + "/" + + self.name + + "/" + + self.gau_reg.name + + "/d_loss/d_loss" + ] = 0 + loss_record[ + self.record_prefix + + "/" + + self.name + + "/" + + self.gau_reg.name + + "/g_loss/g_loss" + ] = 0 + loss_record[ + self.record_prefix + + "/" + + self.name + + "/" + + self.cat_reg.name + + "/d_loss/d_loss" + ] = 0 + loss_record[ + self.record_prefix + + "/" + + self.name + + "/" + + self.cat_reg.name + + "/g_loss/g_loss" + ] = 0 def parameters_reg(self): - return itertools.chain( - self.gau_reg.parameters(), - self.cat_reg.parameters() - ) + return itertools.chain(self.gau_reg.parameters(), self.cat_reg.parameters()) def parameters_fit(self): return itertools.chain( self.mlp.parameters(), self.gau.parameters(), self.cat.parameters(), - self.mat.parameters() + self.mat.parameters(), ) def get_config(self) -> typing.Mapping: - return { - "cat_dim": self.cat_dim, - **super().get_config() - } + return {"cat_dim": self.cat_dim, **super().get_config()} class SemiSupervisedCatGau(CatGau): @@ -439,26 +575,37 @@ class SemiSupervisedCatGau(CatGau): name Name of latent module. 
""" + def __init__( - self, - input_dim: int, - latent_dim: int, - cat_dim: int, - h_dim: int = 128, - depth: int = 1, - dropout: float = 0.0, - lambda_sup: float = 10.0, - background_catp: float = 1e-3, - lambda_reg: float = 0.001, - fine_tune: bool = False, - deviation_reg: float = 0.0, - name: str = "SemiSupervisedCatGau", - _class: str = "SemiSupervisedCatGau", - **kwargs + self, + input_dim: int, + latent_dim: int, + cat_dim: int, + h_dim: int = 128, + depth: int = 1, + dropout: float = 0.0, + lambda_sup: float = 10.0, + background_catp: float = 1e-3, + lambda_reg: float = 0.001, + fine_tune: bool = False, + deviation_reg: float = 0.0, + name: str = "SemiSupervisedCatGau", + _class: str = "SemiSupervisedCatGau", + **kwargs, ) -> None: super().__init__( - input_dim, latent_dim, cat_dim, h_dim, depth, dropout, - lambda_reg, fine_tune, deviation_reg, name, _class, **kwargs + input_dim, + latent_dim, + cat_dim, + h_dim, + depth, + dropout, + lambda_reg, + fine_tune, + deviation_reg, + name, + _class, + **kwargs, ) self.lambda_sup = lambda_sup self.background_catp = background_catp @@ -471,65 +618,138 @@ def forward(self, x: torch.Tensor) -> typing.Tuple[torch.Tensor]: latent = gau + self.mat(cat) return latent, (gau, cat, cat_logit) - def d_loss(self, catgau: typing.Tuple[torch.Tensor], feed_dict: typing.Mapping, loss_record: typing.Mapping) -> torch.Tensor: - + def d_loss( + self, + catgau: typing.Tuple[torch.Tensor], + feed_dict: typing.Mapping, + loss_record: typing.Mapping, + ) -> torch.Tensor: gau, cat, _ = catgau cats = feed_dict[self.name] - gaup = self.gaup_sampler.sample((gau.shape[0], self.latent_dim)).to(config.DEVICE) + gaup = self.gaup_sampler.sample((gau.shape[0], self.latent_dim)).to( + config.DEVICE + ) gau_pred = self.gau_reg(gau) gaup_pred = self.gau_reg(gaup) gau_d_loss = self.gan_d_loss(gau_pred, gaup_pred) - loss_record[self.record_prefix + '/' + self.name + '/' + self.gau_reg.name + '/d_loss/d_loss'] += gau_d_loss.item() * gau.shape[0] - - cat_prob = torch.ones(self.cat_dim) * self.background_catp + cats.cpu().sum(dim = 0) - catp_sampler = D.OneHotCategorical(probs = cat_prob / cat_prob.sum()) + loss_record[ + self.record_prefix + + "/" + + self.name + + "/" + + self.gau_reg.name + + "/d_loss/d_loss" + ] += (gau_d_loss.item() * gau.shape[0]) + + cat_prob = torch.ones(self.cat_dim) * self.background_catp + cats.cpu().sum( + dim=0 + ) + catp_sampler = D.OneHotCategorical(probs=cat_prob / cat_prob.sum()) catp = catp_sampler.sample((cat.shape[0],)).to(config.DEVICE) cat_pred = self.cat_reg(cat) catp_pred = self.cat_reg(catp) cat_d_loss = self.gan_d_loss(cat_pred, catp_pred) - loss_record[self.record_prefix + '/' + self.name + '/' + self.cat_reg.name + '/d_loss/d_loss'] += cat_d_loss.item() * cat.shape[0] + loss_record[ + self.record_prefix + + "/" + + self.name + + "/" + + self.cat_reg.name + + "/d_loss/d_loss" + ] += (cat_d_loss.item() * cat.shape[0]) return self.lambda_reg * (gau_d_loss + cat_d_loss) - def g_loss(self, catgau: typing.Tuple[torch.Tensor], feed_dict: typing.Mapping, loss_record: typing.Mapping) -> typing.Tuple[torch.Tensor]: - + def g_loss( + self, + catgau: typing.Tuple[torch.Tensor], + feed_dict: typing.Mapping, + loss_record: typing.Mapping, + ) -> typing.Tuple[torch.Tensor]: gau, cat, cat_logit = catgau cats = feed_dict[self.name] - mask = cat.sum(dim = 1) > 0 + mask = cat.sum(dim=1) > 0 if mask.sum() > 0: - sup_loss = F.cross_entropy(cat_logit[mask], cats[mask].argmax(dim = 1)) + sup_loss = F.cross_entropy(cat_logit[mask], cats[mask].argmax(dim=1)) 
else: sup_loss = torch.tensor(0) - loss_record['semi_supervision/' + self.name + '/supervised_loss'] += sup_loss.item() * cats.shape[0] + loss_record["semi_supervision/" + self.name + "/supervised_loss"] += ( + sup_loss.item() * cats.shape[0] + ) gau_pred = self.gau_reg(gau) gau_g_loss = self.gan_g_loss(gau_pred) - loss_record[self.record_prefix + '/' + self.name + '/' + self.gau_reg.name + '/g_loss/g_loss'] += gau_g_loss.item() * gau.shape[0] + loss_record[ + self.record_prefix + + "/" + + self.name + + "/" + + self.gau_reg.name + + "/g_loss/g_loss" + ] += (gau_g_loss.item() * gau.shape[0]) cat_pred = self.cat_reg(cat) cat_g_loss = self.gan_g_loss(cat_pred) - loss_record[self.record_prefix + '/' + self.name + '/' + self.cat_reg.name + '/g_loss/g_loss'] += cat_g_loss.item() * cat.shape[0] + loss_record[ + self.record_prefix + + "/" + + self.name + + "/" + + self.cat_reg.name + + "/g_loss/g_loss" + ] += (cat_g_loss.item() * cat.shape[0]) if self.fine_tune: - return self.lambda_sup * sup_loss + \ - self.lambda_reg * (gau_g_loss + cat_g_loss) + \ - self.deviation_reg * self.deviation_loss() + return ( + self.lambda_sup * sup_loss + + self.lambda_reg * (gau_g_loss + cat_g_loss) + + self.deviation_reg * self.deviation_loss() + ) else: - return self.lambda_sup * sup_loss + self.lambda_reg * (gau_g_loss + cat_g_loss) + return self.lambda_sup * sup_loss + self.lambda_reg * ( + gau_g_loss + cat_g_loss + ) def init_loss_record(self, loss_record: typing.Mapping) -> None: - - loss_record[self.record_prefix + '/' + self.name + '/' + self.gau_reg.name + '/d_loss/d_loss'] = 0 - loss_record[self.record_prefix + '/' + self.name + '/' + self.gau_reg.name + '/g_loss/g_loss'] = 0 - loss_record[self.record_prefix + '/' + self.name + '/' + self.cat_reg.name + '/d_loss/d_loss'] = 0 - loss_record[self.record_prefix + '/' + self.name + '/' + self.cat_reg.name + '/g_loss/g_loss'] = 0 - loss_record['semi_supervision/' + self.name + '/supervised_loss'] = 0 + loss_record[ + self.record_prefix + + "/" + + self.name + + "/" + + self.gau_reg.name + + "/d_loss/d_loss" + ] = 0 + loss_record[ + self.record_prefix + + "/" + + self.name + + "/" + + self.gau_reg.name + + "/g_loss/g_loss" + ] = 0 + loss_record[ + self.record_prefix + + "/" + + self.name + + "/" + + self.cat_reg.name + + "/d_loss/d_loss" + ] = 0 + loss_record[ + self.record_prefix + + "/" + + self.name + + "/" + + self.cat_reg.name + + "/g_loss/g_loss" + ] = 0 + loss_record["semi_supervision/" + self.name + "/supervised_loss"] = 0 def get_config(self) -> typing.Mapping: return { "lambda_sup": self.lambda_sup, "background_catp": self.background_catp, - **super().get_config() - } \ No newline at end of file + **super().get_config(), + } diff --git a/Cell_BLAST/metrics.py b/Cell_BLAST/metrics.py index f194309..b66d706 100644 --- a/Cell_BLAST/metrics.py +++ b/Cell_BLAST/metrics.py @@ -4,13 +4,13 @@ import typing +import anndata as ad +import igraph import numpy as np import pandas as pd import scipy.sparse import sklearn.metrics import sklearn.neighbors -import igraph -import anndata as ad from sklearn.metrics.pairwise import cosine_similarity from . 
import blast, utils @@ -18,11 +18,11 @@ _identity = lambda x, y: 1 if x == y else 0 -#=============================================================================== +# =============================================================================== # # Cluster based metrics # -#=============================================================================== +# =============================================================================== def confusion_matrix(x: np.ndarray, y: np.ndarray) -> pd.DataFrame: r""" Reimplemented this because sklearn.metrics.confusion_matrix @@ -39,7 +39,7 @@ def confusion_matrix(x: np.ndarray, y: np.ndarray) -> pd.DataFrame: def class_specific_accuracy( - true: np.ndarray, pred: np.ndarray, expectation: pd.DataFrame + true: np.ndarray, pred: np.ndarray, expectation: pd.DataFrame ) -> pd.DataFrame: df = pd.DataFrame(index=np.unique(true), columns=["number", "accuracy"]) expectation = expectation.astype(np.bool) @@ -47,13 +47,17 @@ def class_specific_accuracy( true_mask = true == c pred_mask = np.in1d(pred, expectation.columns[expectation.loc[c]]) df.loc[c, "number"] = true_mask.sum() - df.loc[c, "accuracy"] = np.logical_and(pred_mask, true_mask).sum() / df.loc[c, "number"] + df.loc[c, "accuracy"] = ( + np.logical_and(pred_mask, true_mask).sum() / df.loc[c, "number"] + ) return df def mean_balanced_accuracy( - true: np.ndarray, pred: np.ndarray, expectation: pd.DataFrame, - population_weighed: bool = False + true: np.ndarray, + pred: np.ndarray, + expectation: pd.DataFrame, + population_weighed: bool = False, ) -> float: df = class_specific_accuracy(true, pred, expectation) if population_weighed: @@ -62,9 +66,10 @@ def mean_balanced_accuracy( def cl_accuracy( - cl_dag: utils.CellTypeDAG, - source: np.ndarray, target: np.ndarray, - ref_cl_list: typing.List[str] # a list of unique cl in ref + cl_dag: utils.CellTypeDAG, + source: np.ndarray, + target: np.ndarray, + ref_cl_list: typing.List[str], # a list of unique cl in ref ) -> pd.DataFrame: if len(source) != len(target): raise ValueError("Invalid input: different cell number.") @@ -93,15 +98,33 @@ def cl_accuracy( cache[_target_cl] = 1 # equal or descendant results as 1 elif cl_dag.is_ancestor_of(_target_cl, query_cl): intermediates = set.intersection( - set(cl_dag.graph.bfsiter(cl_dag.get_vertex(query_cl), mode=igraph.OUT)), - set(cl_dag.graph.bfsiter(cl_dag.get_vertex(_target_cl), mode=igraph.IN)) + set( + cl_dag.graph.bfsiter( + cl_dag.get_vertex(query_cl), mode=igraph.OUT + ) + ), + set( + cl_dag.graph.bfsiter( + cl_dag.get_vertex(_target_cl), mode=igraph.IN + ) + ), + ) + cache[_target_cl] = np.mean( + [ + len(list(cl_dag.graph.bfsiter(intermediate, mode=igraph.IN))) + / len( + list( + cl_dag.graph.bfsiter( + cl_dag.get_vertex(_target_cl), mode=igraph.IN + ) + ) + ) + for intermediate in intermediates + if ref_cl_set.intersection( + set(cl_dag.graph.bfsiter(intermediate, mode=igraph.IN)) + ) + ] ) - cache[_target_cl] = np.mean([ - len(list(cl_dag.graph.bfsiter(intermediate, mode=igraph.IN))) / - len(list(cl_dag.graph.bfsiter(cl_dag.get_vertex(_target_cl), mode=igraph.IN))) - for intermediate in intermediates - if ref_cl_set.intersection(set(cl_dag.graph.bfsiter(intermediate, mode=igraph.IN))) - ]) else: cache[_target_cl] = 0 cl_accuracy[i] = cache[_target_cl] @@ -117,40 +140,53 @@ def cl_accuracy( def cl_mba( - cl_dag: utils.CellTypeDAG, - source: np.ndarray, target: np.ndarray, - ref_cl_list: typing.List[str] + cl_dag: utils.CellTypeDAG, + source: np.ndarray, + target: np.ndarray, + ref_cl_list: 
typing.List[str], ) -> float: - accuracy_df = cl_accuracy(cl_dag=cl_dag, source=source, target=target, ref_cl_list=ref_cl_list) + accuracy_df = cl_accuracy( + cl_dag=cl_dag, source=source, target=target, ref_cl_list=ref_cl_list + ) return accuracy_df["accuracy"].mean() -#=============================================================================== +# =============================================================================== # # Distance based metrics # -#=============================================================================== +# =============================================================================== def nearest_neighbor_accuracy( - x: np.ndarray, y: np.ndarray, metric: str = "minkowski", - similarity: typing.Callable = _identity, n_jobs: int = 1 + x: np.ndarray, + y: np.ndarray, + metric: str = "minkowski", + similarity: typing.Callable = _identity, + n_jobs: int = 1, ) -> float: nearestNeighbors = sklearn.neighbors.NearestNeighbors( - n_neighbors=2, metric=metric, n_jobs=n_jobs) + n_neighbors=2, metric=metric, n_jobs=n_jobs + ) nearestNeighbors.fit(x) nni = nearestNeighbors.kneighbors(x, return_distance=False) return np.vectorize(similarity)(y, y[nni[:, 1].ravel()]).mean() def mean_average_precision_from_latent( - x: np.ndarray, y: np.ndarray, p: typing.Optional[np.ndarray] = None, - k: float = 0.01, metric: str = "minkowski", posterior_metric: str = "npd_v1", - similarity: typing.Callable = _identity, n_jobs: int = 1 + x: np.ndarray, + y: np.ndarray, + p: typing.Optional[np.ndarray] = None, + k: float = 0.01, + metric: str = "minkowski", + posterior_metric: str = "npd_v1", + similarity: typing.Callable = _identity, + n_jobs: int = 1, ) -> float: if k < 1: k = y.shape[0] * k k = np.round(k).astype(np.int) nearestNeighbors = sklearn.neighbors.NearestNeighbors( - n_neighbors=min(y.shape[0], k + 1), metric=metric, n_jobs=n_jobs) + n_neighbors=min(y.shape[0], k + 1), metric=metric, n_jobs=n_jobs + ) nearestNeighbors.fit(x) nni = nearestNeighbors.kneighbors(x, return_distance=False) if p is not None: @@ -158,10 +194,7 @@ def mean_average_precision_from_latent( pnnd = np.empty_like(nni, np.float32) for i in range(pnnd.shape[0]): for j in range(pnnd.shape[1]): - pnnd[i, j] = posterior_metric( - x[i], x[nni[i, j]], - p[i], p[nni[i, j]] - ) + pnnd[i, j] = posterior_metric(x[i], x[nni[i, j]], p[i], p[nni[i, j]]) nni[i] = nni[i][np.argsort(pnnd[i])] return mean_average_precision(y, y[nni[:, 1:]], similarity=similarity) @@ -171,9 +204,13 @@ def average_silhouette_score(x: np.ndarray, y: np.ndarray) -> float: def seurat_alignment_score( - x: np.ndarray, y: np.ndarray, k: float = 0.01, n: int = 1, - metric: str = "minkowski", random_seed: typing.Optional[int] = None, - n_jobs: int = 1 + x: np.ndarray, + y: np.ndarray, + k: float = 0.01, + n: int = 1, + metric: str = "minkowski", + random_seed: typing.Optional[int] = None, + n_jobs: int = 1, ) -> float: random_state = np.random.RandomState(random_seed) idx_list = [np.where(y == _y)[0] for _y in np.unique(y)] @@ -181,40 +218,45 @@ def seurat_alignment_score( subsample_scores = [] for _ in range(n): subsample_idx_list = [ - random_state.choice(idx, subsample_size, replace=False) - for idx in idx_list + random_state.choice(idx, subsample_size, replace=False) for idx in idx_list ] subsample_y = y[np.concatenate(subsample_idx_list)] subsample_x = x[np.concatenate(subsample_idx_list)] _k = subsample_y.shape[0] * k if k < 1 else k _k = np.round(_k).astype(np.int) nearestNeighbors = sklearn.neighbors.NearestNeighbors( - 
n_neighbors=min(subsample_y.shape[0], _k + 1), - metric=metric, n_jobs=n_jobs + n_neighbors=min(subsample_y.shape[0], _k + 1), metric=metric, n_jobs=n_jobs ) nearestNeighbors.fit(subsample_x) nni = nearestNeighbors.kneighbors(subsample_x, return_distance=False) same_y_hits = ( - subsample_y[nni[:, 1:]] == np.expand_dims(subsample_y, axis=1) - ).sum(axis=1).mean() + (subsample_y[nni[:, 1:]] == np.expand_dims(subsample_y, axis=1)) + .sum(axis=1) + .mean() + ) subsample_scores.append( - (_k - same_y_hits) * len(idx_list) / - (_k * (len(idx_list) - 1)) + (_k - same_y_hits) * len(idx_list) / (_k * (len(idx_list) - 1)) ) return np.mean(subsample_scores) def batch_mixing_entropy( - x: np.ndarray, y: np.ndarray, boots: int = 100, - sample_size: int = 100, k: int = 100, metric: str = "minkowski", - random_seed: typing.Optional[int] = None, n_jobs: int = 1 + x: np.ndarray, + y: np.ndarray, + boots: int = 100, + sample_size: int = 100, + k: int = 100, + metric: str = "minkowski", + random_seed: typing.Optional[int] = None, + n_jobs: int = 1, ) -> float: random_state = np.random.RandomState(random_seed) batches = np.unique(y) entropy = 0 for _ in range(boots): bootsamples = random_state.choice( - np.arange(x.shape[0]), sample_size, replace=False) + np.arange(x.shape[0]), sample_size, replace=False + ) subsample_x = x[bootsamples] neighbor = sklearn.neighbors.NearestNeighbors( n_neighbors=k, metric=metric, n_jobs=n_jobs @@ -232,11 +274,11 @@ def batch_mixing_entropy( return entropy -#=============================================================================== +# =============================================================================== # # Ranking based metrics # -#=============================================================================== +# =============================================================================== def _average_precision(r: np.ndarray) -> float: cummean = np.cumsum(r) / (np.arange(r.size) + 1) mask = r > 0 @@ -246,8 +288,9 @@ def _average_precision(r: np.ndarray) -> float: def mean_average_precision( - true: np.ndarray, hits: np.ndarray, - similarity: typing.Callable[[typing.Any, typing.Any], float] = _identity + true: np.ndarray, + hits: np.ndarray, + similarity: typing.Callable[[typing.Any, typing.Any], float] = _identity, ) -> float: r""" Mean average precision @@ -270,15 +313,18 @@ def mean_average_precision( return np.apply_along_axis(_average_precision, 1, r).mean() -#=============================================================================== +# =============================================================================== # # Structure preservation # -#=============================================================================== +# =============================================================================== def avg_neighbor_jacard( - x: np.ndarray, y: np.ndarray, - x_metric: str = "minkowski", y_metric: str = "minkowski", - k: typing.Union[int, float] = 0.01, n_jobs: int = 1 + x: np.ndarray, + y: np.ndarray, + x_metric: str = "minkowski", + y_metric: str = "minkowski", + k: typing.Union[int, float] = 0.01, + n_jobs: int = 1, ) -> float: r""" Average neighborhood Jacard index. 
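A toy check of `mean_average_precision` above, with invented labels: each row of `hits` holds the labels of one cell's retrieved neighbours ordered by distance, and the default identity similarity scores a hit as 1 when the labels match:

    import numpy as np
    from Cell_BLAST import metrics

    true = np.array(["B", "T"])
    hits = np.array([["B", "T", "B"],   # AP = (1/1 + 2/3) / 2 ≈ 0.83
                     ["B", "B", "T"]])  # AP = 1/3 ≈ 0.33
    metrics.mean_average_precision(true, hits)  # ≈ 0.58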
@@ -320,16 +366,16 @@ def avg_neighbor_jacard( n_neighbors=min(n, k + 1), metric=y_metric, n_jobs=n_jobs ).fit(y) nni_y = nn.kneighbors(y, return_distance=False)[:, 1:] - jacard = np.array([ - np.intersect1d(i, j).size / np.union1d(i, j).size - for i, j in zip(nni_x, nni_y) - ]) + jacard = np.array( + [ + np.intersect1d(i, j).size / np.union1d(i, j).size + for i, j in zip(nni_x, nni_y) + ] + ) return jacard.mean() -def jacard_index( - x: scipy.sparse.csr_matrix, y: scipy.sparse.csr_matrix -) -> np.ndarray: +def jacard_index(x: scipy.sparse.csr_matrix, y: scipy.sparse.csr_matrix) -> np.ndarray: r""" Compute Jacard index between two nearest neighbor graphs @@ -346,14 +392,18 @@ def jacard_index( Jacard index for each row """ xy = x + y - return np.asarray((xy == 2).sum(axis=1)).ravel() / \ - np.asarray((xy > 0).sum(axis=1)).ravel() + return ( + np.asarray((xy == 2).sum(axis=1)).ravel() + / np.asarray((xy > 0).sum(axis=1)).ravel() + ) def neighbor_preservation_score( - x: np.ndarray, nng: scipy.sparse.spmatrix, - metric: str = "minkowski", k: typing.Union[int, float] = 0.01, - n_jobs: int = 1 + x: np.ndarray, + nng: scipy.sparse.spmatrix, + metric: str = "minkowski", + k: typing.Union[int, float] = 0.01, + n_jobs: int = 1, ) -> float: if not x.shape[0] == nng.shape[0] == nng.shape[1]: raise ValueError("Inconsistent shape!") @@ -364,24 +414,28 @@ def neighbor_preservation_score( n_neighbors=min(n, k + 1), metric=metric, n_jobs=n_jobs ).fit(x) nni = nn.kneighbors(x, return_distance=False)[:, 1:] - ap = np.array([ - _average_precision(_nng.toarray().ravel()[_nni]) - for _nng, _nni in zip(nng, nni) - ]) - max_ap = np.array([ - _average_precision(np.sort( - _nng.toarray().ravel() - )[::-1][:nni.shape[1]]) for _nng in nng - ]) + ap = np.array( + [ + _average_precision(_nng.toarray().ravel()[_nni]) + for _nng, _nni in zip(nng, nni) + ] + ) + max_ap = np.array( + [ + _average_precision(np.sort(_nng.toarray().ravel())[::-1][: nni.shape[1]]) + for _nng in nng + ] + ) ap /= max_ap return ap.mean() - -def calc_reference_sas(adata: ad.AnnData, - batch_effect: str = 'dataset_name', - cell_ontology: str = 'cell_ontology_class', - similarity: typing.Callable = _identity): +def calc_reference_sas( + adata: ad.AnnData, + batch_effect: str = "dataset_name", + cell_ontology: str = "cell_ontology_class", + similarity: typing.Callable = _identity, +): neighbors_propotion = [] n = len(adata.obs[batch_effect].unique()) for x in adata.obs[batch_effect].unique(): @@ -390,30 +444,43 @@ def calc_reference_sas(adata: ad.AnnData, neighbors = 0 own_neighbors = 0 for j in adata.obs[cell_ontology].unique(): - own_neighbors += similarity(i, j) * \ - ((adata.obs[cell_ontology] == j) & (adata.obs[batch_effect] == x)).sum() - neighbors += similarity(i, j) * \ - (adata.obs[cell_ontology] == j).sum() + own_neighbors += ( + similarity(i, j) + * ( + (adata.obs[cell_ontology] == j) & (adata.obs[batch_effect] == x) + ).sum() + ) + neighbors += similarity(i, j) * (adata.obs[cell_ontology] == j).sum() - propotion += own_neighbors / neighbors * \ - ((adata.obs[cell_ontology] == i) & (adata.obs[batch_effect] == x)).sum() + propotion += ( + own_neighbors + / neighbors + * ( + (adata.obs[cell_ontology] == i) & (adata.obs[batch_effect] == x) + ).sum() + ) propotion = propotion / (adata.obs[batch_effect] == x).sum() neighbors_propotion.append(propotion) neighbors_propotion = np.mean(neighbors_propotion) - return 1 - (neighbors_propotion - 1/n) / (1 - 1/n) + return 1 - (neighbors_propotion - 1 / n) / (1 - 1 / n) + def mean_average_correlation( 
- x: np.ndarray, y: np.ndarray, b: np.ndarray, k: float = 0.001, - metric: str = "minkowski", n_jobs: int = 1 + x: np.ndarray, + y: np.ndarray, + b: np.ndarray, + k: float = 0.001, + metric: str = "minkowski", + n_jobs: int = 1, ) -> float: - if k < 1: k = y.shape[0] * k k = np.round(k).astype(np.int) nearestNeighbors = sklearn.neighbors.NearestNeighbors( - n_neighbors=min(y.shape[0], k + 1), metric=metric, n_jobs=n_jobs) + n_neighbors=min(y.shape[0], k + 1), metric=metric, n_jobs=n_jobs + ) nearestNeighbors.fit(x) nn = nearestNeighbors.kneighbors(x, return_distance=False) correlation = [] @@ -421,4 +488,4 @@ def mean_average_correlation( diff = (b != b[nni[0]])[nni] if diff.sum() > 0: correlation.append(cosine_similarity(y[nni[[0]]], y[nni][diff]).mean()) - return np.float64(np.mean(correlation)) \ No newline at end of file + return np.float64(np.mean(correlation)) diff --git a/Cell_BLAST/prob.py b/Cell_BLAST/prob.py index ec45124..f82792a 100644 --- a/Cell_BLAST/prob.py +++ b/Cell_BLAST/prob.py @@ -3,26 +3,34 @@ """ import math +import typing + import torch -from torch import nn import torch.nn.functional as F -import typing -from .rebuild import Linear -from .rebuild import MLP +from torch import nn + from . import utils +from .rebuild import MLP, Linear + class ProbModel(nn.Module): r""" Abstract base class for generative model modules. """ + def __init__( - self, output_dim: int, full_latent_dim: typing.Tuple[int], - h_dim: int = 128, depth: int = 1, - dropout: float = 0.0, lambda_reg: float = 0.0, - fine_tune: bool = False, deviation_reg: float = 0.0, - name: str = "ProbModel", - _class: str = "ProbModel", - **kwargs + self, + output_dim: int, + full_latent_dim: typing.Tuple[int], + h_dim: int = 128, + depth: int = 1, + dropout: float = 0.0, + lambda_reg: float = 0.0, + fine_tune: bool = False, + deviation_reg: float = 0.0, + name: str = "ProbModel", + _class: str = "ProbModel", + **kwargs, ) -> None: super().__init__() self.output_dim = output_dim @@ -35,7 +43,7 @@ def __init__( self.deviation_reg = deviation_reg self.name = name self._class = _class - self.record_prefix = 'decoder' + self.record_prefix = "decoder" for key in kwargs.keys(): utils.logger.warning("Argument `%s` is no longer supported!" % key) @@ -46,7 +54,7 @@ def __init__( if depth > 0: dropout[0] = 0.0 self.mlp = MLP(i_dim, o_dim, dropout) - + def get_config(self) -> typing.Mapping: return { "output_dim": self.output_dim, @@ -58,13 +66,14 @@ def get_config(self) -> typing.Mapping: "fine_tune": self.fine_tune, "deviation_reg": self.deviation_reg, "name": self.name, - "_class": self._class + "_class": self._class, } + class NB(ProbModel): # Negative binomial r""" Build a Negative Binomial generative module. - + Parameters ---------- output_dim @@ -88,90 +97,124 @@ class NB(ProbModel): # Negative binomial name Name of the module. 
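    Examples
    --------
    A minimal sketch of evaluating the negative binomial likelihood
    parameterised by this module, on hypothetical counts and parameters
    (``mu`` is the NB mean, ``log_theta`` the log dispersion)::

        import torch

        x = torch.tensor([0.0, 3.0, 10.0])        # observed counts (toy values)
        mu = torch.tensor([1.0, 2.0, 8.0])        # predicted NB mean
        log_theta = torch.zeros(3)                # log dispersion
        ll = NB.log_likelihood(x, mu, log_theta)  # element-wise log-likelihood

    The ``loss`` method averages the negated likelihood over cells and adds
    ``lambda_reg`` times the variance of ``log_theta`` as regularization.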
""" + def __init__( - self, - output_dim: int, - full_latent_dim: typing.Tuple[int], - h_dim: int = 128, - depth: int = 1, - dropout: float = 0.0, - lambda_reg: float = 0.0, - fine_tune: bool = False, - deviation_reg: float = 0.0, - name: str = "NB", - _class: str = "NB", - **kwargs + self, + output_dim: int, + full_latent_dim: typing.Tuple[int], + h_dim: int = 128, + depth: int = 1, + dropout: float = 0.0, + lambda_reg: float = 0.0, + fine_tune: bool = False, + deviation_reg: float = 0.0, + name: str = "NB", + _class: str = "NB", + **kwargs, ) -> None: super().__init__( - output_dim, full_latent_dim, h_dim, depth, dropout, lambda_reg, - fine_tune, deviation_reg, name, _class, **kwargs + output_dim, + full_latent_dim, + h_dim, + depth, + dropout, + lambda_reg, + fine_tune, + deviation_reg, + name, + _class, + **kwargs, + ) + + self.mu = ( + Linear(h_dim, output_dim) + if depth > 0 + else Linear(full_latent_dim, output_dim) + ) + self.softmax = nn.Softmax(dim=1) + self.log_theta = ( + Linear(h_dim, output_dim) + if depth > 0 + else Linear(full_latent_dim, output_dim) ) - self.mu = Linear(h_dim, output_dim) if depth > 0 else Linear(full_latent_dim, output_dim) - self.softmax = nn.Softmax(dim = 1) - self.log_theta = Linear(h_dim, output_dim) if depth > 0 else Linear(full_latent_dim, output_dim) - - #fine-tune + # fine-tune def save_origin_state(self) -> None: self.mlp.save_origin_state() self.mu.save_origin_state() self.log_theta.save_origin_state() - - #fine-tune + + # fine-tune def deviation_loss(self) -> torch.Tensor: - return self.deviation_reg * \ - (self.mlp.deviation_loss() + self.mu.deviation_loss() + self.log_theta.deviation_loss()) - - #fine_tune + return self.deviation_reg * ( + self.mlp.deviation_loss() + + self.mu.deviation_loss() + + self.log_theta.deviation_loss() + ) + + # fine_tune def check_fine_tune(self) -> None: if self.fine_tune: self.save_origin_state() - + @staticmethod - def log_likelihood(x: torch.Tensor, mu: torch.Tensor, log_theta: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: + def log_likelihood( + x: torch.Tensor, mu: torch.Tensor, log_theta: torch.Tensor, eps: float = 1e-8 + ) -> torch.Tensor: theta = torch.exp(log_theta) - return theta * log_theta \ - - theta * torch.log(theta + mu + eps) \ - + x * torch.log(mu + eps) - x * torch.log(theta + mu + eps) \ - + torch.lgamma(x + theta) - torch.lgamma(theta) \ + return ( + theta * log_theta + - theta * torch.log(theta + mu + eps) + + x * torch.log(mu + eps) + - x * torch.log(theta + mu + eps) + + torch.lgamma(x + theta) + - torch.lgamma(theta) - torch.lgamma(x + 1) - - def forward(self, full_x: typing.Tuple[torch.Tensor], feed_dict: typing.Mapping) -> torch.Tensor: - - y = feed_dict['exprs'] + ) + + def forward( + self, full_x: typing.Tuple[torch.Tensor], feed_dict: typing.Mapping + ) -> torch.Tensor: + y = feed_dict["exprs"] x = self.mlp(full_x) - + softmax_mu = self.softmax(self.mu(x)) - mu = softmax_mu * y.sum(dim = 1, keepdim = True) + mu = softmax_mu * y.sum(dim=1, keepdim=True) log_theta = self.log_theta(x) return mu, log_theta - - def loss(self, mu_theta: typing.Tuple[torch.Tensor], - feed_dict: typing.Mapping, loss_record: typing.Mapping) -> torch.Tensor: - y = feed_dict['exprs'] + def loss( + self, + mu_theta: typing.Tuple[torch.Tensor], + feed_dict: typing.Mapping, + loss_record: typing.Mapping, + ) -> torch.Tensor: + y = feed_dict["exprs"] mu, log_theta = mu_theta raw_loss = -self.log_likelihood(y, mu, log_theta).mean() - loss_record[self.record_prefix + '/' + self.name + '/raw_loss'] += 
raw_loss.item() * mu.shape[0] + loss_record[self.record_prefix + "/" + self.name + "/raw_loss"] += ( + raw_loss.item() * mu.shape[0] + ) reg_loss = raw_loss + self.lambda_reg * log_theta.var() - loss_record[self.record_prefix + '/' + self.name + '/regularized_loss'] += reg_loss.item() * mu.shape[0] + loss_record[self.record_prefix + "/" + self.name + "/regularized_loss"] += ( + reg_loss.item() * mu.shape[0] + ) if self.fine_tune: return reg_loss + self.deviation_reg * self.deviation_loss() else: return reg_loss - + def init_loss_record(self, loss_record: typing.Mapping) -> None: - - loss_record[self.record_prefix + '/' + self.name + '/raw_loss'] = 0 - loss_record[self.record_prefix + '/' + self.name + '/regularized_loss'] = 0 + loss_record[self.record_prefix + "/" + self.name + "/raw_loss"] = 0 + loss_record[self.record_prefix + "/" + self.name + "/regularized_loss"] = 0 class ZINB(NB): r""" Build a Zero-Inflated Negative Binomial generative module. - + Parameters ---------- output_dim @@ -195,83 +238,119 @@ class ZINB(NB): name Name of the module. """ + def __init__( - self, - output_dim: int, - full_latent_dim: typing.Tuple[int], - h_dim: int = 128, - depth: int = 1, - dropout: float = 0.0, - lambda_reg: float = 0.0, - fine_tune: bool = False, - deviation_reg: float = 0.0, - name: str = "ZINB", - _class: str = "ZINB", - **kwargs + self, + output_dim: int, + full_latent_dim: typing.Tuple[int], + h_dim: int = 128, + depth: int = 1, + dropout: float = 0.0, + lambda_reg: float = 0.0, + fine_tune: bool = False, + deviation_reg: float = 0.0, + name: str = "ZINB", + _class: str = "ZINB", + **kwargs, ) -> None: super().__init__( - output_dim, full_latent_dim, h_dim, depth, dropout, lambda_reg, - fine_tune, deviation_reg, name, _class, **kwargs + output_dim, + full_latent_dim, + h_dim, + depth, + dropout, + lambda_reg, + fine_tune, + deviation_reg, + name, + _class, + **kwargs, + ) + + self.pi = ( + Linear(h_dim, output_dim) + if depth > 0 + else Linear(full_latent_dim, output_dim) ) - self.pi = Linear(h_dim, output_dim) if depth > 0 else Linear(full_latent_dim, output_dim) - - #fine-tune + # fine-tune def save_origin_state(self) -> None: self.mlp.save_origin_state() self.mu.save_origin_state() self.log_theta.save_origin_state() self.pi.save_origin_state() - - #fine-tune + + # fine-tune def deviation_loss(self) -> torch.Tensor: - return self.deviation_reg * \ - (self.mlp.deviation_loss() + self.mu.deviation_loss() + self.log_theta.deviation_loss() + self.pi.deviation_loss()) - - #fine_tune + return self.deviation_reg * ( + self.mlp.deviation_loss() + + self.mu.deviation_loss() + + self.log_theta.deviation_loss() + + self.pi.deviation_loss() + ) + + # fine_tune def check_fine_tune(self) -> None: if self.fine_tune: self.save_origin_state() - + @staticmethod - def log_likelihood(x: torch.Tensor, mu: torch.Tensor, log_theta: torch.Tensor, pi: torch.tensor, eps: float = 1e-8) -> torch.Tensor: + def log_likelihood( + x: torch.Tensor, + mu: torch.Tensor, + log_theta: torch.Tensor, + pi: torch.tensor, + eps: float = 1e-8, + ) -> torch.Tensor: theta = torch.exp(log_theta) case_zero = F.softplus( - - pi + theta * log_theta - - theta * torch.log(theta + mu + eps) - ) - F.softplus(- pi) - case_non_zero = - pi - F.softplus(- pi) \ - + theta * log_theta \ - - theta * torch.log(theta + mu + eps) \ - + x * torch.log(mu + eps) - x * torch.log(theta + mu + eps) \ - + torch.lgamma(x + theta) - torch.lgamma(theta) \ + -pi + theta * log_theta - theta * torch.log(theta + mu + eps) + ) - F.softplus(-pi) + case_non_zero 
= ( + -pi + - F.softplus(-pi) + + theta * log_theta + - theta * torch.log(theta + mu + eps) + + x * torch.log(mu + eps) + - x * torch.log(theta + mu + eps) + + torch.lgamma(x + theta) + - torch.lgamma(theta) - torch.lgamma(x + 1) + ) mask = (x < eps).float() res = mask * case_zero + (1 - mask) * case_non_zero return res - - def forward(self, full_x: typing.Tuple[torch.Tensor], feed_dict: typing.Mapping) -> torch.Tensor: - - y = feed_dict['exprs'] + + def forward( + self, full_x: typing.Tuple[torch.Tensor], feed_dict: typing.Mapping + ) -> torch.Tensor: + y = feed_dict["exprs"] x = self.mlp(full_x) - + softmax_mu = self.softmax(self.mu(x)) - mu = softmax_mu * y.sum(dim = 1, keepdim = True) + mu = softmax_mu * y.sum(dim=1, keepdim=True) log_theta = self.log_theta(x) pi = self.pi(x) return mu, log_theta, pi - - def loss(self, mu_theta_pi: typing.Tuple[torch.Tensor], - feed_dict: typing.Mapping, loss_record: typing.Mapping) -> torch.Tensor: - y = feed_dict['exprs'] + def loss( + self, + mu_theta_pi: typing.Tuple[torch.Tensor], + feed_dict: typing.Mapping, + loss_record: typing.Mapping, + ) -> torch.Tensor: + y = feed_dict["exprs"] mu, log_theta, pi = mu_theta_pi raw_loss = -self.log_likelihood(y, mu, log_theta, pi).mean() - loss_record[self.record_prefix + '/' + self.name + '/raw_loss'] += raw_loss.item() * mu.shape[0] + loss_record[self.record_prefix + "/" + self.name + "/raw_loss"] += ( + raw_loss.item() * mu.shape[0] + ) reg_loss = raw_loss + self.lambda_reg * log_theta.var() - loss_record[self.record_prefix + '/' + self.name + '/regularized_loss'] += reg_loss.item() * mu.shape[0] + loss_record[self.record_prefix + "/" + self.name + "/regularized_loss"] += ( + reg_loss.item() * mu.shape[0] + ) if self.fine_tune: return reg_loss + self.deviation_reg * self.deviation_loss() @@ -282,7 +361,7 @@ def loss(self, mu_theta_pi: typing.Tuple[torch.Tensor], class LN(ProbModel): r""" Build a Log Normal generative module. - + Parameters ---------- output_dim @@ -304,84 +383,113 @@ class LN(ProbModel): name Name of the module. 
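    Examples
    --------
    A minimal sketch of the Gaussian likelihood used here, on hypothetical
    values; as in ``loss``, the observations are ``log1p``-transformed, and
    ``mu`` and ``log_var`` parameterise a Gaussian on the transformed values::

        import torch

        log_y = torch.log1p(torch.tensor([0.0, 3.0, 10.0]))  # log1p-transformed counts
        mu = torch.tensor([0.0, 1.2, 2.3])                    # predicted mean
        log_var = torch.zeros(3)                              # log variance
        ll = LN.log_likelihood(log_y, mu, log_var)            # element-wise log-likelihood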
""" + def __init__( - self, - output_dim: int, - full_latent_dim: typing.Tuple[int], - h_dim: int = 128, - depth: int = 1, - dropout: float = 0.0, - lambda_reg: float = 0.0, - fine_tune: bool = False, - deviation_reg: float = 0.0, - name: str = "LN", - _class: str = "LN", - **kwargs + self, + output_dim: int, + full_latent_dim: typing.Tuple[int], + h_dim: int = 128, + depth: int = 1, + dropout: float = 0.0, + lambda_reg: float = 0.0, + fine_tune: bool = False, + deviation_reg: float = 0.0, + name: str = "LN", + _class: str = "LN", + **kwargs, ) -> None: super().__init__( - output_dim, full_latent_dim, h_dim, depth, dropout, lambda_reg, - fine_tune, deviation_reg, name, _class, **kwargs + output_dim, + full_latent_dim, + h_dim, + depth, + dropout, + lambda_reg, + fine_tune, + deviation_reg, + name, + _class, + **kwargs, + ) + + self.mu = ( + Linear(h_dim, output_dim) + if depth > 0 + else Linear(full_latent_dim, output_dim) + ) + self.log_var = ( + Linear(h_dim, output_dim) + if depth > 0 + else Linear(full_latent_dim, output_dim) ) - self.mu = Linear(h_dim, output_dim) if depth > 0 else Linear(full_latent_dim, output_dim) - self.log_var = Linear(h_dim, output_dim) if depth > 0 else Linear(full_latent_dim, output_dim) - - #fine-tune + # fine-tune def save_origin_state(self) -> None: self.mlp.save_origin_state() self.mu.save_origin_state() self.log_var.save_origin_state() - - #fine-tune + + # fine-tune def deviation_loss(self) -> torch.Tensor: - return self.deviation_reg * \ - (self.mlp.deviation_loss() + self.mu.deviation_loss() + self.log_var.deviation_loss()) - - #fine_tune + return self.deviation_reg * ( + self.mlp.deviation_loss() + + self.mu.deviation_loss() + + self.log_var.deviation_loss() + ) + + # fine_tune def check_fine_tune(self) -> None: if self.fine_tune: self.save_origin_state() - + @staticmethod - def log_likelihood(x: torch.Tensor, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor: - return - 0.5 * ( - torch.square(x - mu) / torch.exp(log_var) - + math.log(2 * math.pi) + log_var - ) - - def forward(self, full_x: typing.Tuple[torch.Tensor], feed_dict: typing.Mapping) -> torch.Tensor: - + def log_likelihood( + x: torch.Tensor, mu: torch.Tensor, log_var: torch.Tensor + ) -> torch.Tensor: + return -0.5 * ( + torch.square(x - mu) / torch.exp(log_var) + math.log(2 * math.pi) + log_var + ) + + def forward( + self, full_x: typing.Tuple[torch.Tensor], feed_dict: typing.Mapping + ) -> torch.Tensor: x = self.mlp(full_x) mu = torch.expm1(self.mu(x)) log_var = self.log_var(x) return mu, log_var - - def loss(self, mu_var: typing.Tuple[torch.Tensor], - feed_dict: typing.Mapping, loss_record: typing.Mapping) -> torch.Tensor: - y = feed_dict['exprs'] + def loss( + self, + mu_var: typing.Tuple[torch.Tensor], + feed_dict: typing.Mapping, + loss_record: typing.Mapping, + ) -> torch.Tensor: + y = feed_dict["exprs"] mu, log_var = mu_var raw_loss = -self.log_likelihood(torch.log1p(y), mu, log_var).mean() - loss_record[self.record_prefix + '/' + self.name + '/raw_loss'] += raw_loss.item() * mu.shape[0] + loss_record[self.record_prefix + "/" + self.name + "/raw_loss"] += ( + raw_loss.item() * mu.shape[0] + ) reg_loss = raw_loss - loss_record[self.record_prefix + '/' + self.name + '/regularized_loss'] += reg_loss.item() * mu.shape[0] + loss_record[self.record_prefix + "/" + self.name + "/regularized_loss"] += ( + reg_loss.item() * mu.shape[0] + ) if self.fine_tune: return reg_loss + self.deviation_reg * self.deviation_loss() else: return reg_loss - + def init_loss_record(self, loss_record: 
typing.Mapping) -> None: - - loss_record[self.record_prefix + '/' + self.name + '/raw_loss'] = 0 - loss_record[self.record_prefix + '/' + self.name + '/regularized_loss'] = 0 + loss_record[self.record_prefix + "/" + self.name + "/raw_loss"] = 0 + loss_record[self.record_prefix + "/" + self.name + "/regularized_loss"] = 0 class ZILN(LN): r""" Build a Zero-Inflated Log Normal generative module. - + Parameters ---------- output_dim @@ -403,73 +511,111 @@ class ZILN(LN): name Name of the module. """ + def __init__( - self, - output_dim: int, - full_latent_dim: typing.Tuple[int], - h_dim: int = 128, - depth: int = 1, - dropout: float = 0.0, - lambda_reg: float = 0.0, - fine_tune: bool = False, - deviation_reg: float = 0.0, - name: str = "ZILN", - _class: str = "ZILN", - **kwargs + self, + output_dim: int, + full_latent_dim: typing.Tuple[int], + h_dim: int = 128, + depth: int = 1, + dropout: float = 0.0, + lambda_reg: float = 0.0, + fine_tune: bool = False, + deviation_reg: float = 0.0, + name: str = "ZILN", + _class: str = "ZILN", + **kwargs, ) -> None: super().__init__( - output_dim, full_latent_dim, h_dim, depth, dropout, lambda_reg, - fine_tune, deviation_reg, name, _class, **kwargs + output_dim, + full_latent_dim, + h_dim, + depth, + dropout, + lambda_reg, + fine_tune, + deviation_reg, + name, + _class, + **kwargs, ) - self.pi = Linear(h_dim, output_dim) if depth > 0 else Linear(full_latent_dim, output_dim) - - #fine-tune + self.pi = ( + Linear(h_dim, output_dim) + if depth > 0 + else Linear(full_latent_dim, output_dim) + ) + + # fine-tune def save_origin_state(self) -> None: self.mlp.save_origin_state() self.mu.save_origin_state() self.log_var.save_origin_state() self.pi.save_origin_state() - - #fine-tune + + # fine-tune def deviation_loss(self) -> torch.Tensor: - return self.deviation_reg * \ - (self.mlp.deviation_loss() + self.mu.deviation_loss() + self.log_var.deviation_loss() + self.pi.deviation_loss()) - - #fine_tune + return self.deviation_reg * ( + self.mlp.deviation_loss() + + self.mu.deviation_loss() + + self.log_var.deviation_loss() + + self.pi.deviation_loss() + ) + + # fine_tune def check_fine_tune(self) -> None: if self.fine_tune: self.save_origin_state() - + @staticmethod - def log_likelihood(x: torch.Tensor, mu: torch.Tensor, log_var: torch.Tensor, pi: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: - case_zero = - F.softplus(- pi) - case_non_zero = - pi - F.softplus(- pi) - 0.5 * ( - torch.square(x - mu) / torch.exp(log_var) - + math.log(2 * math.pi) + log_var + def log_likelihood( + x: torch.Tensor, + mu: torch.Tensor, + log_var: torch.Tensor, + pi: torch.Tensor, + eps: float = 1e-8, + ) -> torch.Tensor: + case_zero = -F.softplus(-pi) + case_non_zero = ( + -pi + - F.softplus(-pi) + - 0.5 + * ( + torch.square(x - mu) / torch.exp(log_var) + + math.log(2 * math.pi) + + log_var + ) ) mask = (x < eps).float() res = mask * case_zero + (1 - mask) * case_non_zero return res - - def forward(self, full_x: typing.Tuple[torch.Tensor], feed_dict: typing.Mapping) -> torch.Tensor: - + + def forward( + self, full_x: typing.Tuple[torch.Tensor], feed_dict: typing.Mapping + ) -> torch.Tensor: x = self.mlp(full_x) mu = torch.expm1(self.mu(x)) log_var = self.log_var(x) pi = self.pi(x) return mu, log_var, pi - - def loss(self, mu_var_pi: typing.Tuple[torch.Tensor], - feed_dict: typing.Mapping, loss_record: typing.Mapping) -> torch.Tensor: - y = feed_dict['exprs'] + def loss( + self, + mu_var_pi: typing.Tuple[torch.Tensor], + feed_dict: typing.Mapping, + loss_record: typing.Mapping, + ) -> 
torch.Tensor: + y = feed_dict["exprs"] mu, log_var, pi = mu_var_pi raw_loss = -self.log_likelihood(torch.log1p(y), mu, log_var, pi).mean() - loss_record[self.record_prefix + '/' + self.name + '/raw_loss'] += raw_loss.item() * mu.shape[0] + loss_record[self.record_prefix + "/" + self.name + "/raw_loss"] += ( + raw_loss.item() * mu.shape[0] + ) reg_loss = raw_loss - loss_record[self.record_prefix + '/' + self.name + '/regularized_loss'] += reg_loss.item() * mu.shape[0] + loss_record[self.record_prefix + "/" + self.name + "/regularized_loss"] += ( + reg_loss.item() * mu.shape[0] + ) if self.fine_tune: return reg_loss + self.deviation_reg * self.deviation_loss() @@ -479,5 +625,7 @@ def loss(self, mu_var_pi: typing.Tuple[torch.Tensor], class MSE(ProbModel): def __init__(self, *args, **kwargs): - utils.logger.warning("Prob module `MSE` is no longer supported, running as `ProbModel`") - super().__init__(*args, **kwargs) \ No newline at end of file + utils.logger.warning( + "Prob module `MSE` is no longer supported, running as `ProbModel`" + ) + super().__init__(*args, **kwargs) diff --git a/Cell_BLAST/rebuild.py b/Cell_BLAST/rebuild.py index f381378..66fc8c1 100644 --- a/Cell_BLAST/rebuild.py +++ b/Cell_BLAST/rebuild.py @@ -2,14 +2,15 @@ Rebuild basic NN components in PyTorch to follows TensorFlow behaviors """ +import typing + import torch -from torch import nn import torch.nn.functional as F -import typing +from torch import nn -class Linear(nn.Module): - __constants__ = ['in_features', 'out_features'] +class Linear(nn.Module): + __constants__ = ["in_features", "out_features"] in_features: int out_features: int weight: torch.Tensor @@ -17,9 +18,14 @@ class Linear(nn.Module): # allow multiple input tensors # add initialize standard deviation # add truncating option - def __init__(self, in_features: typing.Union[typing.Tuple[int], int], out_features: int, bias: bool = True, - init_std: float = 0.01, trunc: bool = True) -> "Linear": - + def __init__( + self, + in_features: typing.Union[typing.Tuple[int], int], + out_features: int, + bias: bool = True, + init_std: float = 0.01, + trunc: bool = True, + ) -> "Linear": if not isinstance(in_features, list) and not isinstance(in_features, tuple): in_features = [in_features] @@ -33,24 +39,31 @@ def __init__(self, in_features: typing.Union[typing.Tuple[int], int], out_featur if bias: self.bias = nn.Parameter(torch.Tensor(out_features)) else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) self.init_std = init_std self.trunc = trunc self.reset_parameters() def reset_parameters(self): - #change initializers + # change initializers if self.trunc: for _weight in self.weight: - nn.init.trunc_normal_(_weight, std = self.init_std, a = -2 * self.init_std, b = 2 * self.init_std) + nn.init.trunc_normal_( + _weight, + std=self.init_std, + a=-2 * self.init_std, + b=2 * self.init_std, + ) else: for _weight in self.weight: - nn.init.normal_(_weight, std = self.init_std) + nn.init.normal_(_weight, std=self.init_std) if self.bias is not None: nn.init.zeros_(self.bias) - def forward(self, input: typing.Union[typing.Tuple[torch.Tensor], torch.Tensor]) -> torch.Tensor: + def forward( + self, input: typing.Union[typing.Tuple[torch.Tensor], torch.Tensor] + ) -> torch.Tensor: if not isinstance(input, list) and not isinstance(input, tuple): input = [input] for i, (_input, _weight) in enumerate(zip(input, self.weight)): @@ -59,23 +72,24 @@ def forward(self, input: typing.Union[typing.Tuple[torch.Tensor], torch.Tensor]) else: result = 
F.linear(_input, _weight, self.bias) return result - #return F.linear(torch.cat(input, dim=-1), torch.cat(list(self.weight), dim=-1), self.bias) + # return F.linear(torch.cat(input, dim=-1), torch.cat(list(self.weight), dim=-1), self.bias) def extra_repr(self) -> str: - return 'in_features={}, out_features={}, bias={}'.format( + return "in_features={}, out_features={}, bias={}".format( self.in_features, self.out_features, self.bias is not None ) - #fine-tune + # fine-tune def save_origin_state(self) -> None: - self.weight_origin = [_weight.detach().clone() \ - for _weight in self.weight] + self.weight_origin = [_weight.detach().clone() for _weight in self.weight] if self.bias is not None: self.bias_origin = self.bias.detach().clone() - #fine-tune + # fine-tune def deviation_loss(self) -> torch.Tensor: - for i, (_weight, _weight_origin) in enumerate(zip(self.weight, self.weight_origin)): + for i, (_weight, _weight_origin) in enumerate( + zip(self.weight, self.weight_origin) + ): if i: result = result + F.mse_loss(_weight, _weight_origin) else: @@ -86,9 +100,18 @@ def deviation_loss(self) -> torch.Tensor: class RMSprop(torch.optim.Optimizer): - - def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False, - decoupled_decay=False, lr_in_momentum=True): + def __init__( + self, + params, + lr=1e-2, + alpha=0.9, + eps=1e-10, + weight_decay=0, + momentum=0.0, + centered=False, + decoupled_decay=False, + lr_in_momentum=True, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -101,16 +124,23 @@ def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, moment raise ValueError("Invalid alpha value: {}".format(alpha)) defaults = dict( - lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay, - decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum) + lr=lr, + momentum=momentum, + alpha=alpha, + eps=eps, + centered=centered, + weight_decay=weight_decay, + decoupled_decay=decoupled_decay, + lr_in_momentum=lr_in_momentum, + ) super(RMSprop, self).__init__(params, defaults) def __setstate__(self, state): super(RMSprop, self).__setstate__(state) for group in self.param_groups: - group.setdefault('momentum', 0) - group.setdefault('centered', False) + group.setdefault("momentum", 0) + group.setdefault("centered", False) @torch.no_grad() def step(self, closure=None): @@ -125,68 +155,77 @@ def step(self, closure=None): loss = closure() for group in self.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is None: continue grad = p.grad if grad.is_sparse: - raise RuntimeError('RMSprop does not support sparse gradients') + raise RuntimeError("RMSprop does not support sparse gradients") state = self.state[p] # State initialization if len(state) == 0: - state['step'] = 0 - state['square_avg'] = torch.ones_like(p) # PyTorch inits to zero - if group['momentum'] > 0: - state['momentum_buffer'] = torch.zeros_like(p) - if group['centered']: - state['grad_avg'] = torch.zeros_like(p) + state["step"] = 0 + state["square_avg"] = torch.ones_like(p) # PyTorch inits to zero + if group["momentum"] > 0: + state["momentum_buffer"] = torch.zeros_like(p) + if group["centered"]: + state["grad_avg"] = torch.zeros_like(p) - square_avg = state['square_avg'] - one_minus_alpha = 1. 
- group['alpha'] + square_avg = state["square_avg"] + one_minus_alpha = 1.0 - group["alpha"] - state['step'] += 1 + state["step"] += 1 - if group['weight_decay'] != 0: - if group['decoupled_decay']: - p.mul_(1. - group['lr'] * group['weight_decay']) + if group["weight_decay"] != 0: + if group["decoupled_decay"]: + p.mul_(1.0 - group["lr"] * group["weight_decay"]) else: - grad = grad.add(p, alpha=group['weight_decay']) + grad = grad.add(p, alpha=group["weight_decay"]) # Tensorflow order of ops for updating squared avg square_avg.add_(grad.pow(2) - square_avg, alpha=one_minus_alpha) # square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha) # PyTorch original - if group['centered']: - grad_avg = state['grad_avg'] + if group["centered"]: + grad_avg = state["grad_avg"] grad_avg.add_(grad - grad_avg, alpha=one_minus_alpha) - avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).add(group['eps']).sqrt_() # eps in sqrt + avg = ( + square_avg.addcmul(grad_avg, grad_avg, value=-1) + .add(group["eps"]) + .sqrt_() + ) # eps in sqrt # grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha) # PyTorch original else: - avg = square_avg.add(group['eps']).sqrt_() # eps moved in sqrt + avg = square_avg.add(group["eps"]).sqrt_() # eps moved in sqrt - if group['momentum'] > 0: - buf = state['momentum_buffer'] + if group["momentum"] > 0: + buf = state["momentum_buffer"] # Tensorflow accumulates the LR scaling in the momentum buffer - if group['lr_in_momentum']: - buf.mul_(group['momentum']).addcdiv_(grad, avg, value=group['lr']) + if group["lr_in_momentum"]: + buf.mul_(group["momentum"]).addcdiv_( + grad, avg, value=group["lr"] + ) p.add_(-buf) else: # PyTorch scales the param update by LR - buf.mul_(group['momentum']).addcdiv_(grad, avg) - p.add_(buf, alpha=-group['lr']) + buf.mul_(group["momentum"]).addcdiv_(grad, avg) + p.add_(buf, alpha=-group["lr"]) else: - p.addcdiv_(grad, avg, value=-group['lr']) + p.addcdiv_(grad, avg, value=-group["lr"]) return loss - - class MLP(nn.Module): def __init__( - self, i_dim: typing.Tuple[int], o_dim: typing.Tuple[int], dropout: typing.Tuple[float], - bias: bool = True, batch_normalization: bool = False, activation: bool = True + self, + i_dim: typing.Tuple[int], + o_dim: typing.Tuple[int], + dropout: typing.Tuple[float], + bias: bool = True, + batch_normalization: bool = False, + activation: bool = True, ) -> None: super().__init__() self.i_dim = i_dim @@ -200,16 +239,15 @@ def __init__( module_seq = [] for _i_dim, _o_dim, _dropout in zip(i_dim, o_dim, dropout): - if _dropout > 0: module_seq.append(nn.Dropout(dropout)) - hidden = Linear(_i_dim, _o_dim, bias = bias) + hidden = Linear(_i_dim, _o_dim, bias=bias) self.hiddens.append(hidden) module_seq.append(hidden) if batch_normalization: - module_seq.append(nn.BatchNorm1d(_o_dim, eps = 0.001, momentum = 0.01)) + module_seq.append(nn.BatchNorm1d(_o_dim, eps=0.001, momentum=0.01)) if activation: module_seq.append(nn.LeakyReLU(0.2)) @@ -221,19 +259,19 @@ def __init__( def forward(self, x: torch.Tensor): return self.layer_seq(x) - #fine-tune + # fine-tune def save_origin_state(self) -> None: for _hidden in self.hiddens: _hidden.save_origin_state() - #fine-tune + # fine-tune def deviation_loss(self) -> torch.Tensor: loss = torch.tensor(0) for _hidden in self.hiddens: loss = loss + _hidden.deviation_loss() return loss - #fine-tune + # fine-tune def requires_grad_(self, requires_grad: bool = True): super().requires_grad_(requires_grad) if not self.first_layer_trainable: @@ -249,4 +287,4 @@ def first_layer_trainable(self): def 
first_layer_trainable(self, flag: bool): self._first_layer_trainable = flag if len(self.hiddens) > 0: - self.hiddens[0].requires_grad_(flag) \ No newline at end of file + self.hiddens[0].requires_grad_(flag) diff --git a/Cell_BLAST/rmbatch.py b/Cell_BLAST/rmbatch.py index dae9a0d..0919797 100644 --- a/Cell_BLAST/rmbatch.py +++ b/Cell_BLAST/rmbatch.py @@ -2,22 +2,29 @@ Batch effect removing modules for DIRECTi """ +import typing + import torch -from torch import nn import torch.nn.functional as F -import typing -from .rebuild import Linear -from .rebuild import MLP +from torch import nn + from . import config, utils +from .rebuild import MLP, Linear + class RMBatch(nn.Module): r""" Parent class for systematical bias / batch effect removal modules. """ + def __init__( - self, batch_dim: int, latent_dim: int, delay: int = 20, - name: str = "RMBatch", _class: str = "RMBatch", - **kwargs + self, + batch_dim: int, + latent_dim: int, + delay: int = 20, + name: str = "RMBatch", + _class: str = "RMBatch", + **kwargs, ) -> None: super().__init__() self.batch_dim = batch_dim @@ -25,7 +32,7 @@ def __init__( self.delay = delay self.name = name self._class = _class - self.record_prefix = 'discriminator' + self.record_prefix = "discriminator" self.n_steps = 0 for key in kwargs.keys(): @@ -33,15 +40,27 @@ def __init__( def get_mask(self, x: torch.Tensor, feed_dict: typing.Mapping) -> torch.Tensor: b = feed_dict[self.name] - return b.sum(dim = 1) > 0 + return b.sum(dim=1) > 0 def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: return x[mask] - def d_loss(self, x: torch.Tensor, feed_dict: typing.Mapping, mask: torch.Tensor, loss_record: typing.Mapping) -> torch.Tensor: + def d_loss( + self, + x: torch.Tensor, + feed_dict: typing.Mapping, + mask: torch.Tensor, + loss_record: typing.Mapping, + ) -> torch.Tensor: return torch.tensor(0) - def g_loss(self, x: torch.Tensor, feed_dict: typing.Mapping, mask: torch.Tensor, loss_record: typing.Mapping) -> torch.Tensor: + def g_loss( + self, + x: torch.Tensor, + feed_dict: typing.Mapping, + mask: torch.Tensor, + loss_record: typing.Mapping, + ) -> torch.Tensor: return torch.tensor(0) def init_loss_record(self, loss_record: typing.Mapping) -> None: @@ -53,9 +72,10 @@ def get_config(self) -> typing.Mapping: "latent_dim": self.latent_dim, "delay": self.delay, "name": self.name, - "_class": self._class + "_class": self._class, } + class Adversarial(RMBatch): r""" Build a batch effect correction module that uses adversarial batch alignment. @@ -81,19 +101,20 @@ class Adversarial(RMBatch): name Name of the module. 
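    Examples
    --------
    A minimal sketch of the two opposing losses computed by this module, on
    hypothetical discriminator logits and batch labels; the discriminator is
    trained with the positive cross-entropy, the encoder with its negation::

        import torch
        import torch.nn.functional as F

        pred = torch.randn(4, 3)                               # discriminator logits, 3 batches
        b = F.one_hot(torch.tensor([0, 1, 2, 0]), 3).float()   # one-hot batch labels
        ce = F.cross_entropy(pred, b.argmax(dim=1))
        d_loss = 0.01 * ce    # lambda_reg * ce: discriminator learns to classify batches
        g_loss = -0.01 * ce   # encoder is rewarded for fooling the discriminator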
""" + def __init__( - self, - batch_dim: int, - latent_dim: int, - h_dim: int = 128, - depth: int = 1, - dropout: float = 0.0, - lambda_reg: float = 0.01, - n_steps: int = 1, - delay: int = 20, - name: str = "AdvBatch", - _class: str = "Adversarial", - **kwargs + self, + batch_dim: int, + latent_dim: int, + h_dim: int = 128, + depth: int = 1, + dropout: float = 0.0, + lambda_reg: float = 0.01, + n_steps: int = 1, + delay: int = 20, + name: str = "AdvBatch", + _class: str = "Adversarial", + **kwargs, ) -> None: super().__init__(batch_dim, latent_dim, delay, name, _class, **kwargs) self.h_dim = h_dim @@ -108,29 +129,42 @@ def __init__( if depth > 0: dropout[0] = 0.0 self.mlp = MLP(i_dim, o_dim, dropout) - self.pred = Linear(h_dim, batch_dim) if depth > 0 else Linear(latent_dim, batch_dim) + self.pred = ( + Linear(h_dim, batch_dim) if depth > 0 else Linear(latent_dim, batch_dim) + ) def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: return self.pred(self.mlp(x[mask])) - def d_loss(self, pred: torch.Tensor, feed_dict: typing.Mapping, mask: torch.Tensor, loss_record: typing.Mapping) -> torch.Tensor: - + def d_loss( + self, + pred: torch.Tensor, + feed_dict: typing.Mapping, + mask: torch.Tensor, + loss_record: typing.Mapping, + ) -> torch.Tensor: b = feed_dict[self.name] - rmbatch_d_loss = F.cross_entropy(pred, b[mask].argmax(dim = 1)) - loss_record[self.record_prefix + '/' + self.name + '/d_loss'] += rmbatch_d_loss.item() * b.shape[0] + rmbatch_d_loss = F.cross_entropy(pred, b[mask].argmax(dim=1)) + loss_record[self.record_prefix + "/" + self.name + "/d_loss"] += ( + rmbatch_d_loss.item() * b.shape[0] + ) return self.lambda_reg * rmbatch_d_loss - def g_loss(self, pred: torch.Tensor, feed_dict: typing.Mapping, mask: torch.Tensor, loss_record: typing.Mapping) -> torch.Tensor: - + def g_loss( + self, + pred: torch.Tensor, + feed_dict: typing.Mapping, + mask: torch.Tensor, + loss_record: typing.Mapping, + ) -> torch.Tensor: b = feed_dict[self.name] - rmbatch_g_loss = F.cross_entropy(pred, b[mask].argmax(dim = 1)) + rmbatch_g_loss = F.cross_entropy(pred, b[mask].argmax(dim=1)) - return - self.lambda_reg * rmbatch_g_loss + return -self.lambda_reg * rmbatch_g_loss def init_loss_record(self, loss_record: typing.Mapping) -> None: - - loss_record[self.record_prefix + '/' + self.name + '/d_loss'] = 0 + loss_record[self.record_prefix + "/" + self.name + "/d_loss"] = 0 def get_config(self) -> typing.Mapping: return { @@ -161,16 +195,17 @@ class MNN(RMBatch): name Name of the module. 
""" + def __init__( - self, - batch_dim: int, - latent_dim: int, - n_neighbors: int = 5, - lambda_reg: float = 1.0, - delay: int = 20, - name: str = "MNN", - _class: str = "MNN", - **kwargs + self, + batch_dim: int, + latent_dim: int, + n_neighbors: int = 5, + lambda_reg: float = 1.0, + delay: int = 20, + name: str = "MNN", + _class: str = "MNN", + **kwargs, ) -> None: super().__init__(batch_dim, latent_dim, delay, name, _class, **kwargs) self.n_neighbors = n_neighbors @@ -179,17 +214,22 @@ def __init__( @staticmethod def _neighbor_mask(d: torch.Tensor, k: int) -> torch.Tensor: n = d.shape[1] - _, idx = d.topk(min(k, n), largest = False) - return F.one_hot(idx, n).sum(dim = 1) > 0 + _, idx = d.topk(min(k, n), largest=False) + return F.one_hot(idx, n).sum(dim=1) > 0 @staticmethod def _mnn_mask(d: torch.Tensor, k: int) -> torch.Tensor: return MNN._neighbor_mask(d, k) & MNN._neighbor_mask(d.T, k).T - def g_loss(self, x: torch.Tensor, feed_dict: typing.Mapping, mask: torch.Tensor, loss_record: typing.Mapping) -> torch.Tensor: - + def g_loss( + self, + x: torch.Tensor, + feed_dict: typing.Mapping, + mask: torch.Tensor, + loss_record: typing.Mapping, + ) -> torch.Tensor: b = feed_dict[self.name] - barg = b[mask].argmax(dim = 1) + barg = b[mask].argmax(dim=1) masked_x = x[mask] x_grouping = [] for i in range(b.shape[1]): @@ -200,11 +240,11 @@ def g_loss(self, x: torch.Tensor, feed_dict: typing.Mapping, mask: torch.Tensor, if x_grouping[i].shape[0] > 0 and x_grouping[j].shape[0] > 0: u = x_grouping[i].unsqueeze(1) v = x_grouping[j].unsqueeze(0) - uv_dist = ((u - v).square()).sum(dim = 2) + uv_dist = ((u - v).square()).sum(dim=2) mnn_idx = self._mnn_mask(uv_dist, self.n_neighbors) penalty = mnn_idx.float() * uv_dist penalties.append(penalty.reshape(-1)) - penalties = torch.cat(penalties, dim = 0) + penalties = torch.cat(penalties, dim=0) return self.lambda_reg * penalties.mean() @@ -245,41 +285,56 @@ class MNNAdversarial(Adversarial): name Name of the module. 
""" + def __init__( - self, - batch_dim: int, - latent_dim: int, - h_dim: int = 128, - depth: int = 1, - dropout: float = 0.0, - lambda_reg: float = 0.01, - n_steps: int = 1, - n_neighbors: int = 5, - delay: int = 20, - name: str = "MNNAdvBatch", - _class: str = "MNNAdversarial", - **kwargs + self, + batch_dim: int, + latent_dim: int, + h_dim: int = 128, + depth: int = 1, + dropout: float = 0.0, + lambda_reg: float = 0.01, + n_steps: int = 1, + n_neighbors: int = 5, + delay: int = 20, + name: str = "MNNAdvBatch", + _class: str = "MNNAdversarial", + **kwargs, ) -> None: - super().__init__(batch_dim, latent_dim, h_dim, depth, dropout, lambda_reg, n_steps, delay, name, _class, **kwargs) + super().__init__( + batch_dim, + latent_dim, + h_dim, + depth, + dropout, + lambda_reg, + n_steps, + delay, + name, + _class, + **kwargs, + ) self.n_neighbors = n_neighbors @staticmethod def _neighbor_mask(d: torch.Tensor, k: int) -> torch.Tensor: n = d.shape[1] - _, idx = d.topk(min(k, n), largest = False) - return F.one_hot(idx, n).sum(dim = 1) > 0 + _, idx = d.topk(min(k, n), largest=False) + return F.one_hot(idx, n).sum(dim=1) > 0 @staticmethod def _mnn_mask(d: torch.Tensor, k: int) -> torch.Tensor: - return MNNAdversarial._neighbor_mask(d, k) & MNNAdversarial._neighbor_mask(d.T, k).T + return ( + MNNAdversarial._neighbor_mask(d, k) + & MNNAdversarial._neighbor_mask(d.T, k).T + ) def get_mask(self, x: torch.Tensor, feed_dict: typing.Mapping) -> torch.Tensor: - b = feed_dict[self.name] - mask = b.sum(dim = 1) > 0 + mask = b.sum(dim=1) > 0 mnn_mask = torch.zeros(b.shape[0], device=config.DEVICE) > 0 masked_mnn_mask = mnn_mask[mask] - barg = b[mask].argmax(dim = 1) + barg = b[mask].argmax(dim=1) x_grouping = [] for i in range(b.shape[1]): x_grouping.append(x[mask][barg == i].detach()) @@ -288,20 +343,20 @@ def get_mask(self, x: torch.Tensor, feed_dict: typing.Mapping) -> torch.Tensor: if x_grouping[i].shape[0] > 0 and x_grouping[j].shape[0] > 0: u = x_grouping[i].unsqueeze(1) v = x_grouping[j].unsqueeze(0) - uv_dist = ((u - v).square()).sum(dim = 2) + uv_dist = ((u - v).square()).sum(dim=2) mnn_idx = self._mnn_mask(uv_dist, self.n_neighbors) - masked_mnn_mask[barg == i] |= (mnn_idx.sum(dim = 1) > 0) - masked_mnn_mask[barg == j] |= (mnn_idx.sum(dim = 0) > 0) + masked_mnn_mask[barg == i] |= mnn_idx.sum(dim=1) > 0 + masked_mnn_mask[barg == j] |= mnn_idx.sum(dim=0) > 0 mnn_mask[mask] = masked_mnn_mask return mnn_mask def get_config(self) -> typing.Mapping: - return { - "n_neighbors": self.n_neighbors, - **super().get_config() - } + return {"n_neighbors": self.n_neighbors, **super().get_config()} + class AdaptiveMNNAdversarial(MNNAdversarial): def __init__(self, *args, **kwargs): - utils.logger.warning("RMBatch module `AdaptiveMNNAdversarial` is no longer supported, running as `MNNAdversarial`") - super().__init__(*args, **kwargs) \ No newline at end of file + utils.logger.warning( + "RMBatch module `AdaptiveMNNAdversarial` is no longer supported, running as `MNNAdversarial`" + ) + super().__init__(*args, **kwargs) diff --git a/Cell_BLAST/utils.py b/Cell_BLAST/utils.py index c8fd9b5..d1ad716 100644 --- a/Cell_BLAST/utils.py +++ b/Cell_BLAST/utils.py @@ -6,22 +6,21 @@ import collections import functools import json +import logging import operator import os -import typing -import logging import re +import typing +import h5py import igraph import numpy as np import pandas as pd +import pynvml import scipy.sparse -import h5py +import torch import tqdm import tqdm.notebook -import torch -import pynvml - log_handler = 
logging.StreamHandler() log_handler.setLevel(logging.INFO) @@ -30,6 +29,7 @@ logger.setLevel(logging.INFO) logger.addHandler(log_handler) + def rand_hex() -> str: return binascii.b2a_hex(os.urandom(15)).decode() @@ -48,13 +48,13 @@ def in_ipynb() -> bool: # pragma: no cover # noinspection PyUnresolvedReferences shell = get_ipython().__class__.__name__ if shell == "ZMQInteractiveShell": - return True # Jupyter notebook or qtconsole + return True # Jupyter notebook or qtconsole elif shell == "TerminalInteractiveShell": return False # Terminal running IPython else: return False # Other type (?) except NameError: - return False # Probably standard Python interpreter + return False # Probably standard Python interpreter def smart_tqdm(): # pragma: no cover @@ -68,38 +68,44 @@ def with_self_graph(fn: typing.Callable) -> typing.Callable: def wrapped(self, *args, **kwargs): with self.graph.as_default(): return fn(self, *args, **kwargs) + return wrapped # Wraps a batch function into minibatch version -def minibatch(batch_size: int, desc: str, use_last: bool = False, progress_bar: bool = True) -> typing.Callable: +def minibatch( + batch_size: int, desc: str, use_last: bool = False, progress_bar: bool = True +) -> typing.Callable: def minibatch_wrapper(func: typing.Callable) -> typing.Callable: @functools.wraps(func) def wrapped_func(*args, **kwargs): total_size = args[0].shape[0] if use_last: - n_batch = np.ceil( - total_size / float(batch_size) - ).astype(np.int) + n_batch = np.ceil(total_size / float(batch_size)).astype(np.int) else: - n_batch = max(1, np.floor( - total_size / float(batch_size) - ).astype(np.int)) + n_batch = max( + 1, np.floor(total_size / float(batch_size)).astype(np.int) + ) for batch_idx in smart_tqdm()( - range(n_batch), desc=desc, unit="batches", - leave=False, disable=not progress_bar + range(n_batch), + desc=desc, + unit="batches", + leave=False, + disable=not progress_bar, ): start = batch_idx * batch_size end = min((batch_idx + 1) * batch_size, total_size) this_args = (item[start:end] for item in args) func(*this_args, **kwargs) + return wrapped_func + return minibatch_wrapper # Avoid sklearn warning def encode_integer( - label: typing.List[typing.Any], sort: bool = False + label: typing.List[typing.Any], sort: bool = False ) -> typing.Tuple[np.ndarray, np.ndarray]: index = pd.Index(label).dropna().drop_duplicates() if sort: @@ -109,7 +115,7 @@ def encode_integer( # Avoid sklearn warning def encode_onehot( - label: typing.List[typing.Any], sort: bool = False + label: typing.List[typing.Any], sort: bool = False ) -> scipy.sparse.csr_matrix: i, c = encode_integer(label, sort) val = np.ones_like(i, dtype=np.int32)[i >= 0] @@ -119,10 +125,10 @@ def encode_onehot( class CellTypeDAG(object): - def __init__( - self, graph: typing.Optional[igraph.Graph] = None, - vdict: typing.Optional[typing.Mapping[str, str]] = None + self, + graph: typing.Optional[igraph.Graph] = None, + vdict: typing.Optional[typing.Mapping[str, str]] = None, ) -> None: self.graph = igraph.Graph(directed=True) if graph is None else graph self.vdict = vdict or {} @@ -145,8 +151,11 @@ def load_json(cls, file: str) -> "CellTypeDAG": return dag @classmethod - def load_obo(cls, file: str) -> "CellTypeDAG": # Only building on "is_a" relation between CL terms + def load_obo( + cls, file: str + ) -> "CellTypeDAG": # Only building on "is_a" relation between CL terms import pronto + ont = pronto.Ontology(file) graph, vdict = igraph.Graph(directed=True), {} for item in ont: @@ -155,11 +164,10 @@ def load_obo(cls, file: 
str) -> "CellTypeDAG": # Only building on "is_a" relati if "is_obsolete" in item.other and item.other["is_obsolete"][0] == "true": continue graph.add_vertex( - name=item.id, cell_ontology_class=item.name, - desc=str(item.desc), synonyms=[ - (f"{syn.desc} ({syn.scope})") - for syn in item.synonyms - ] + name=item.id, + cell_ontology_class=item.name, + desc=str(item.desc), + synonyms=[(f"{syn.desc} ({syn.scope})") for syn in item.synonyms], ) assert item.id not in vdict vdict[item.id] = item.id @@ -177,7 +185,9 @@ def load_obo(cls, file: str) -> "CellTypeDAG": # Only building on "is_a" relati continue graph.add_edge( source["name"], - graph.vs.find(name=target.id.split()[0])["name"] # pylint: disable=no-member + graph.vs.find(name=target.id.split()[0])[ + "name" + ], # pylint: disable=no-member ) # Split because there are many "{is_infered...}" suffix, # falsely joined to the actual id when pronto parses the @@ -185,8 +195,7 @@ def load_obo(cls, file: str) -> "CellTypeDAG": # Only building on "is_a" relati return cls(graph, vdict) def _build_tree( - self, d: typing.Mapping[str, str], - parent: typing.Optional[igraph.Vertex] = None + self, d: typing.Mapping[str, str], parent: typing.Optional[igraph.Vertex] = None ) -> None: # For json loading self.graph.add_vertex(name=d["name"]) v = self.graph.vs.find(d["name"]) @@ -204,8 +213,7 @@ def get_vertex(self, name: str) -> igraph.Vertex: return self.graph.vs.find(self.vdict[name]) def is_related(self, name1: str, name2: str) -> bool: - return self.is_descendant_of(name1, name2) \ - or self.is_ancestor_of(name1, name2) + return self.is_descendant_of(name1, name2) or self.is_ancestor_of(name1, name2) def is_descendant_of(self, name1: str, name2: str) -> bool: if name1 not in self.vdict or name2 not in self.vdict: @@ -227,10 +235,8 @@ def conditional_prob(self, name1: str, name2: str) -> float: # p(name1|name2) if name1 not in self.vdict or name2 not in self.vdict: return 0 self.graph.vs["prob"] = 0 - v2_parents = list(self.graph.bfsiter( - self.get_vertex(name2), mode=igraph.OUT)) - v1_parents = list(self.graph.bfsiter( - self.get_vertex(name1), mode=igraph.OUT)) + v2_parents = list(self.graph.bfsiter(self.get_vertex(name2), mode=igraph.OUT)) + v1_parents = list(self.graph.bfsiter(self.get_vertex(name1), mode=igraph.OUT)) for v in v2_parents: v["prob"] = 1 while True: @@ -238,10 +244,12 @@ def conditional_prob(self, name1: str, name2: str) -> float: # p(name1|name2) for v1_parent in v1_parents[::-1]: # Reverse may be more efficient if v1_parent["prob"] != 0: continue - v1_parent["prob"] = np.prod([ - v["prob"] / v.degree(mode=igraph.IN) - for v in v1_parent.neighbors(mode=igraph.OUT) - ]) + v1_parent["prob"] = np.prod( + [ + v["prob"] / v.degree(mode=igraph.IN) + for v in v1_parent.neighbors(mode=igraph.OUT) + ] + ) if v1_parent["prob"] != 0: changed = True if not changed: @@ -251,8 +259,8 @@ def conditional_prob(self, name1: str, name2: str) -> float: # p(name1|name2) def similarity(self, name1: str, name2: str, method: str = "probability") -> float: if method == "probability": return ( - self.conditional_prob(name1, name2) + - self.conditional_prob(name2, name1) + self.conditional_prob(name1, name2) + + self.conditional_prob(name2, name1) ) / 2 # if method == "distance": # return self.distance_ratio(name1, name2) @@ -275,14 +283,12 @@ def value_update(self) -> None: for v in self.graph.bfsiter(origin, mode=igraph.OUT): if v != origin: # bfsiter includes the vertex self v["prop_value"] += origin["raw_value"] - self.graph.vs["value"] = list(map( - 
operator.add, self.graph.vs["raw_value"], - self.graph.vs["prop_value"] - )) + self.graph.vs["value"] = list( + map(operator.add, self.graph.vs["raw_value"], self.graph.vs["prop_value"]) + ) def best_leaves( - self, thresh: float = 0.5, min_path: int = 4, - retrieve: str = "name" + self, thresh: float = 0.5, min_path: int = 4, retrieve: str = "name" ) -> typing.List[str]: subgraph = self.graph.subgraph(self.graph.vs.select(value_gt=thresh)) leaves, max_value = [], 0 @@ -303,13 +309,18 @@ def cal_longest_paths_to_root(self, weight: float = 1.0) -> None: root["longest_paths_to_root"] = 0 self.graph.es["weight"] = weight for vertex in self.graph.vs[self.graph.topological_sorting(mode=igraph.IN)]: - for neighbor in self.graph.vs[self.graph.neighborhood(vertex, mode=igraph.IN)]: + for neighbor in self.graph.vs[ + self.graph.neighborhood(vertex, mode=igraph.IN) + ]: if neighbor == vertex: continue - if neighbor["longest_paths_to_root"] < vertex["longest_paths_to_root"] + \ - self.graph[neighbor, vertex]: - neighbor["longest_paths_to_root"] = vertex["longest_paths_to_root"] + \ - self.graph[neighbor, vertex] + if ( + neighbor["longest_paths_to_root"] + < vertex["longest_paths_to_root"] + self.graph[neighbor, vertex] + ): + neighbor["longest_paths_to_root"] = ( + vertex["longest_paths_to_root"] + self.graph[neighbor, vertex] + ) def longest_paths_to_root(self, name: str) -> int: if "longest_paths_to_root" not in self.get_vertex(name).attribute_names(): @@ -318,13 +329,15 @@ def longest_paths_to_root(self, name: str) -> int: class DataDict(collections.OrderedDict): - def shuffle(self, random_state: np.random.RandomState = np.random) -> "DataDict": shuffled = DataDict() shuffle_idx = None for item in self: - shuffle_idx = random_state.permutation(self[item].shape[0]) \ - if shuffle_idx is None else shuffle_idx + shuffle_idx = ( + random_state.permutation(self[item].shape[0]) + if shuffle_idx is None + else shuffle_idx + ) shuffled[item] = self[item][shuffle_idx] return shuffled @@ -341,12 +354,10 @@ def shape(self) -> typing.List[int]: # Compatibility with numpy arrays return [self.size] def __getitem__( - self, fetch: typing.Union[str, slice, np.ndarray] + self, fetch: typing.Union[str, slice, np.ndarray] ) -> typing.Union["DataDict", np.ndarray]: if isinstance(fetch, (slice, np.ndarray)): - return DataDict([ - (item, self[item][fetch]) for item in self - ]) + return DataDict([(item, self[item][fetch]) for item in self]) return super(DataDict, self).__getitem__(fetch) @@ -361,6 +372,7 @@ def _fn(x): if x.size: return fn(x) return x.astype(dtype) + return _fn @@ -382,6 +394,7 @@ def isnan(x: typing.Any) -> bool: lower = empty_safe(np.vectorize(lambda x: str(x).lower()), str) tostr = empty_safe(np.vectorize(str), str) + def read_hybrid_path(hybrid_path: str) -> np.ndarray: file_name, h5_path = hybrid_path.split("//") with h5py.File(file_name, "r") as f: @@ -402,6 +415,7 @@ def write_hybrid_path(obj: np.ndarray, hybrid_path: str) -> None: obj = encode(obj) f.create_dataset(h5_path, data=obj) + def dataframe2list(dataframe): identical = [] for i in zip(dataframe.iloc[:, 0], dataframe.iloc[:, 1]): @@ -423,11 +437,14 @@ def autodevice() -> torch.device: used_device = -1 try: pynvml.nvmlInit() - free_mems = np.array([ - pynvml.nvmlDeviceGetMemoryInfo( - pynvml.nvmlDeviceGetHandleByIndex(i) - ).free for i in range(pynvml.nvmlDeviceGetCount()) - ]) + free_mems = np.array( + [ + pynvml.nvmlDeviceGetMemoryInfo( + pynvml.nvmlDeviceGetHandleByIndex(i) + ).free + for i in range(pynvml.nvmlDeviceGetCount()) + ] + 
) if free_mems.size: best_devices = np.where(free_mems == free_mems.max())[0] used_device = np.random.choice(best_devices, 1)[0] diff --git a/Cell_BLAST/weighting.py b/Cell_BLAST/weighting.py index fd4c775..0100c69 100644 --- a/Cell_BLAST/weighting.py +++ b/Cell_BLAST/weighting.py @@ -2,32 +2,35 @@ Weighting strategy for adversarial batch alignment in DIRECTi """ -from sklearn.metrics.pairwise import pairwise_distances -import scanpy as sc -import anndata as ad import time +import typing +from collections import Counter + +import anndata as ad +import matplotlib.pyplot as plt import numpy as np import pandas as pd +import scanpy as sc +import seaborn as sns import torch import torch.nn.functional as F -import seaborn as sns -import matplotlib.pyplot as plt -import typing from matplotlib.pyplot import rc_context -from collections import Counter +from sklearn.metrics.pairwise import pairwise_distances -from . import utils, config, data +from . import config, data, utils _identity = lambda x, y: 1 if x == y else 0 -def calc_weights(adata: ad.AnnData, - genes: typing.Optional[typing.List[str]], - batch_effect: typing.Optional[str], - add_weight: typing.Tuple[bool], - clustering_space: str, - similarity_space: str, - random_seed: int) -> None: +def calc_weights( + adata: ad.AnnData, + genes: typing.Optional[typing.List[str]], + batch_effect: typing.Optional[str], + add_weight: typing.Tuple[bool], + clustering_space: str, + similarity_space: str, + random_seed: int, +) -> None: r""" Calculate the proper weight of each cell for adversarial batch alignment. @@ -54,90 +57,114 @@ def calc_weights(adata: ad.AnnData, """ if any(add_weight): - utils.logger.info('Calculating weights...') + utils.logger.info("Calculating weights...") start_time = time.time() for _add_weight, _batch_effect in zip(add_weight, batch_effect): - if _add_weight: if config.SUPERVISION is None: if clustering_space is None: - clustering_latent = get_default_clustering_space(adata, genes, _batch_effect, random_seed) + clustering_latent = get_default_clustering_space( + adata, genes, _batch_effect, random_seed + ) else: clustering_latent = adata.obsm[clustering_space] if similarity_space is None: - similarity_latent = get_default_similarity_space(adata, genes, _batch_effect, random_seed) + similarity_latent = get_default_similarity_space( + adata, genes, _batch_effect, random_seed + ) else: similarity_latent = adata.obsm[similarity_space] - weight_full = np.zeros(adata.n_obs, dtype = np.float32) + weight_full = np.zeros(adata.n_obs, dtype=np.float32) - batch = utils.densify(utils.encode_onehot( - adata.obs[_batch_effect], sort=True - )) + batch = utils.densify( + utils.encode_onehot(adata.obs[_batch_effect], sort=True) + ) num_batch = batch.shape[1] - mask = batch.sum(axis = 1) > 0 - batch = batch[mask].argmax(axis = 1) - adata.uns[config._WEIGHT_PREFIX_ + _batch_effect + '_mask'] = mask - adata.uns[config._WEIGHT_PREFIX_ + _batch_effect + '_batch'] = batch - adata.uns[config._WEIGHT_PREFIX_ + _batch_effect + '_num_batch'] = num_batch + mask = batch.sum(axis=1) > 0 + batch = batch[mask].argmax(axis=1) + adata.uns[config._WEIGHT_PREFIX_ + _batch_effect + "_mask"] = mask + adata.uns[config._WEIGHT_PREFIX_ + _batch_effect + "_batch"] = batch + adata.uns[config._WEIGHT_PREFIX_ + _batch_effect + "_num_batch"] = num_batch if config.SUPERVISION is not None: - all_labels = adata.obs[config.SUPERVISION][mask] - cluster, sum_cluster, num_clusters = \ - get_supervised_cluster(all_labels, batch, num_batch) - adata.uns[config._WEIGHT_PREFIX_ + 
_batch_effect + '_cluster'] = cluster - adata.uns[config._WEIGHT_PREFIX_ + _batch_effect + '_sum_cluster'] = sum_cluster - adata.uns[config._WEIGHT_PREFIX_ + _batch_effect + '_num_clusters'] = num_clusters - - volume = get_volume(batch, num_batch, cluster, sum_cluster, num_clusters) - adata.uns[config._WEIGHT_PREFIX_ + _batch_effect + '_volume'] = volume - - similarity = \ - get_supervised_similarity(all_labels, cluster, sum_cluster) - adata.uns[config._WEIGHT_PREFIX_ + _batch_effect + '_raw_similarity'] = similarity - - else: #supervision is None - - cluster, sum_cluster, num_clusters = \ - get_unsupervised_cluster(clustering_latent[mask], batch, num_batch, random_seed) - adata.uns[config._WEIGHT_PREFIX_ + _batch_effect + '_cluster'] = cluster - adata.uns[config._WEIGHT_PREFIX_ + _batch_effect + '_sum_cluster'] = sum_cluster - adata.uns[config._WEIGHT_PREFIX_ + _batch_effect + '_num_clusters'] = num_clusters - - volume = get_volume(batch, num_batch, cluster, sum_cluster, num_clusters) - adata.uns[config._WEIGHT_PREFIX_ + _batch_effect + '_volume'] = volume - - similarity, raw_similarity = \ - get_unsupervised_similarity(similarity_latent[mask], num_batch, cluster, sum_cluster, num_clusters, volume) - adata.uns[config._WEIGHT_PREFIX_ + _batch_effect + '_raw_similarity'] = raw_similarity - - - weight = get_weight(batch, num_batch, cluster, sum_cluster, num_clusters, volume, similarity) - adata.uns[config._WEIGHT_PREFIX_ + _batch_effect + '_weight'] = weight + cluster, sum_cluster, num_clusters = get_supervised_cluster( + all_labels, batch, num_batch + ) + adata.uns[config._WEIGHT_PREFIX_ + _batch_effect + "_cluster"] = cluster + adata.uns[ + config._WEIGHT_PREFIX_ + _batch_effect + "_sum_cluster" + ] = sum_cluster + adata.uns[ + config._WEIGHT_PREFIX_ + _batch_effect + "_num_clusters" + ] = num_clusters + + volume = get_volume( + batch, num_batch, cluster, sum_cluster, num_clusters + ) + adata.uns[config._WEIGHT_PREFIX_ + _batch_effect + "_volume"] = volume + + similarity = get_supervised_similarity(all_labels, cluster, sum_cluster) + adata.uns[ + config._WEIGHT_PREFIX_ + _batch_effect + "_raw_similarity" + ] = similarity + + else: # supervision is None + cluster, sum_cluster, num_clusters = get_unsupervised_cluster( + clustering_latent[mask], batch, num_batch, random_seed + ) + adata.uns[config._WEIGHT_PREFIX_ + _batch_effect + "_cluster"] = cluster + adata.uns[ + config._WEIGHT_PREFIX_ + _batch_effect + "_sum_cluster" + ] = sum_cluster + adata.uns[ + config._WEIGHT_PREFIX_ + _batch_effect + "_num_clusters" + ] = num_clusters + + volume = get_volume( + batch, num_batch, cluster, sum_cluster, num_clusters + ) + adata.uns[config._WEIGHT_PREFIX_ + _batch_effect + "_volume"] = volume + + similarity, raw_similarity = get_unsupervised_similarity( + similarity_latent[mask], + num_batch, + cluster, + sum_cluster, + num_clusters, + volume, + ) + adata.uns[ + config._WEIGHT_PREFIX_ + _batch_effect + "_raw_similarity" + ] = raw_similarity + + weight = get_weight( + batch, num_batch, cluster, sum_cluster, num_clusters, volume, similarity + ) + adata.uns[config._WEIGHT_PREFIX_ + _batch_effect + "_weight"] = weight weight_full[mask] = weight - report = f"[batch effect \"{_batch_effect}\"] " + report = f'[batch effect "{_batch_effect}"] ' report += f"time elapsed={time.time() - start_time:.1f}s" print(report) else: - weight_full = np.ones(adata.n_obs, dtype = np.float32) + weight_full = np.ones(adata.n_obs, dtype=np.float32) adata.obs[config._WEIGHT_PREFIX_ + _batch_effect] = weight_full -def 
get_supervised_cluster(all_labels: typing.List[str], - batch: np.ndarray, - num_batch: int) -> typing.Tuple[np.ndarray, int, typing.List[int], np.ndarray]: - +def get_supervised_cluster( + all_labels: typing.List[str], batch: np.ndarray, num_batch: int +) -> typing.Tuple[np.ndarray, int, typing.List[int], np.ndarray]: sum_cluster = 0 num_clusters = [0] - cluster = np.zeros(batch.shape[0], dtype = np.int) + cluster = np.zeros(batch.shape[0], dtype=np.int) for i in range(num_batch): if config.NO_CLUSTER: @@ -155,10 +182,10 @@ def get_supervised_cluster(all_labels: typing.List[str], return cluster, sum_cluster, num_clusters -def get_supervised_similarity(all_labels: typing.List[str], - cluster: np.ndarray, - sum_cluster: int) -> np.ndarray: +def get_supervised_similarity( + all_labels: typing.List[str], cluster: np.ndarray, sum_cluster: int +) -> np.ndarray: similarity = np.zeros((sum_cluster, sum_cluster)) for i in range(sum_cluster): for j in range(sum_cluster): @@ -167,14 +194,13 @@ def get_supervised_similarity(all_labels: typing.List[str], return similarity -def get_unsupervised_cluster(clustering_latent: np.ndarray, - batch: np.ndarray, - num_batch: int, - random_seed: int) -> typing.Tuple[np.ndarray, int, typing.List[int], np.ndarray]: +def get_unsupervised_cluster( + clustering_latent: np.ndarray, batch: np.ndarray, num_batch: int, random_seed: int +) -> typing.Tuple[np.ndarray, int, typing.List[int], np.ndarray]: sum_cluster = 0 num_clusters = [0] - cluster = np.zeros(batch.shape[0], dtype = np.int) + cluster = np.zeros(batch.shape[0], dtype=np.int) for i in range(num_batch): if config.NO_CLUSTER: @@ -184,42 +210,53 @@ def get_unsupervised_cluster(clustering_latent: np.ndarray, num_clusters.append(sum_cluster) else: curr_latent = ad.AnnData(clustering_latent[batch == i]) - sc.pp.neighbors(curr_latent, use_rep = 'X', random_state = random_seed) - sc.tl.leiden(curr_latent, random_state = random_seed, resolution = config.RESOLUTION) - cluster[batch == i] = np.asarray(curr_latent.obs['leiden'].astype(int)) + sc.pp.neighbors(curr_latent, use_rep="X", random_state=random_seed) + sc.tl.leiden( + curr_latent, random_state=random_seed, resolution=config.RESOLUTION + ) + cluster[batch == i] = np.asarray(curr_latent.obs["leiden"].astype(int)) cluster[batch == i] += sum_cluster sum_cluster = cluster[batch == i].max() + 1 num_clusters.append(sum_cluster) return cluster, sum_cluster, num_clusters -def get_volume(batch: np.ndarray, - num_batch: int, - cluster: np.ndarray, - sum_cluster: int, - num_clusters: typing.List[int]): +def get_volume( + batch: np.ndarray, + num_batch: int, + cluster: np.ndarray, + sum_cluster: int, + num_clusters: typing.List[int], +): volume = np.zeros(sum_cluster) for i in range(num_batch): - for j in range(num_clusters[i], num_clusters[i+1]): + for j in range(num_clusters[i], num_clusters[i + 1]): volume[j] = (cluster == j).sum() / (batch == i).sum() return volume -def get_unsupervised_similarity(similarity_latent: np.ndarray, - num_batch: int, - cluster: np.ndarray, - sum_cluster: int, - num_clusters: typing.List[int], - volume: np.ndarray) -> typing.Tuple[np.ndarray]: + +def get_unsupervised_similarity( + similarity_latent: np.ndarray, + num_batch: int, + cluster: np.ndarray, + sum_cluster: int, + num_clusters: typing.List[int], + volume: np.ndarray, +) -> typing.Tuple[np.ndarray]: center = np.zeros((sum_cluster, similarity_latent.shape[1])) for i in range(sum_cluster): - center[i] = similarity_latent[cluster == i].mean(axis = 0) + center[i] = similarity_latent[cluster 
== i].mean(axis=0) - raw_similarity = -pairwise_distances(center, metric = config.METRIC, **config.METRIC_KWARGS) + raw_similarity = -pairwise_distances( + center, metric=config.METRIC, **config.METRIC_KWARGS + ) if config.MNN: - raw_similarity = get_MNN(num_batch, cluster, sum_cluster, num_clusters, raw_similarity) + raw_similarity = get_MNN( + num_batch, cluster, sum_cluster, num_clusters, raw_similarity + ) else: raw_similarity = (raw_similarity - config.THRESHOLD) / (1 - config.THRESHOLD) raw_similarity[raw_similarity < 0] = 0 @@ -228,29 +265,31 @@ def get_unsupervised_similarity(similarity_latent: np.ndarray, for i in range(sum_cluster): similarity[:, i] *= volume for j in range(num_batch): - curr_sum = similarity[num_clusters[j]:num_clusters[j+1], i].sum() + curr_sum = similarity[num_clusters[j] : num_clusters[j + 1], i].sum() if curr_sum > 0: - similarity[num_clusters[j]:num_clusters[j+1], i] /= curr_sum + similarity[num_clusters[j] : num_clusters[j + 1], i] /= curr_sum return similarity, raw_similarity -def get_weight(batch: np.ndarray, - num_batch: int, - cluster: np.ndarray, - sum_cluster: int, - num_clusters: typing.List[int], - volume: np.ndarray, - similarity: np.ndarray) -> np.ndarray: +def get_weight( + batch: np.ndarray, + num_batch: int, + cluster: np.ndarray, + sum_cluster: int, + num_clusters: typing.List[int], + volume: np.ndarray, + similarity: np.ndarray, +) -> np.ndarray: weight = np.ones(batch.shape[0]) for i in range(sum_cluster): vv = volume * (similarity[i, :]) ww = np.zeros(num_batch) for j in range(num_batch): - ww[j] = vv[num_clusters[j]:num_clusters[j+1]].sum() - if (num_clusters[j] <= i) and (i < num_clusters[j+1]): + ww[j] = vv[num_clusters[j] : num_clusters[j + 1]].sum() + if (num_clusters[j] <= i) and (i < num_clusters[j + 1]): ww[j] = volume[i] - ww = ((ww ** 0.5).sum() ** 2 - ww.sum()) / num_batch / (num_batch - 1) + ww = ((ww**0.5).sum() ** 2 - ww.sum()) / num_batch / (num_batch - 1) ww = ww / volume[i] weight[cluster == i] = ww @@ -259,12 +298,12 @@ def get_weight(batch: np.ndarray, return weight -def plot_clustering_confidence(adata: ad.AnnData, - batch_effect: str, - ground_truth: str) -> None: - mask = adata.uns[config._WEIGHT_PREFIX_ + batch_effect + '_mask'] - cluster = adata.uns[config._WEIGHT_PREFIX_ + batch_effect + '_cluster'] - sum_cluster = adata.uns[config._WEIGHT_PREFIX_ + batch_effect + '_sum_cluster'] +def plot_clustering_confidence( + adata: ad.AnnData, batch_effect: str, ground_truth: str +) -> None: + mask = adata.uns[config._WEIGHT_PREFIX_ + batch_effect + "_mask"] + cluster = adata.uns[config._WEIGHT_PREFIX_ + batch_effect + "_cluster"] + sum_cluster = adata.uns[config._WEIGHT_PREFIX_ + batch_effect + "_sum_cluster"] main_cell_types = [] confidence = np.ndarray(sum_cluster) @@ -274,30 +313,37 @@ def plot_clustering_confidence(adata: ad.AnnData, most_common = label_counts.most_common(1) main_cell_types.append(most_common[0][0]) confidence[i] = most_common[0][1] / len(curr_label) - adata.uns[config._WEIGHT_PREFIX_ + batch_effect + '_main_cell_types'] = main_cell_types - adata.uns[config._WEIGHT_PREFIX_ + batch_effect + '_confidence'] = confidence + adata.uns[ + config._WEIGHT_PREFIX_ + batch_effect + "_main_cell_types" + ] = main_cell_types + adata.uns[config._WEIGHT_PREFIX_ + batch_effect + "_confidence"] = confidence - plt.figure(figsize = (7,7)) - sns.displot(confidence, kind = 'ecdf') + plt.figure(figsize=(7, 7)) + sns.displot(confidence, kind="ecdf") plt.show() -def plot_similarity_confidence(adata: ad.AnnData, - batch_effect: 
str, - similarity: typing.Callable = _identity) -> None: - batch = adata.uns[config._WEIGHT_PREFIX_ + batch_effect + '_batch'] - num_batch = adata.uns[config._WEIGHT_PREFIX_ + batch_effect + '_num_batch'] - cluster = adata.uns[config._WEIGHT_PREFIX_ + batch_effect + '_cluster'] - sum_cluster = adata.uns[config._WEIGHT_PREFIX_ + batch_effect + '_sum_cluster'] - num_clusters = adata.uns[config._WEIGHT_PREFIX_ + batch_effect + '_num_clusters'] - main_cell_types = adata.uns[config._WEIGHT_PREFIX_ + batch_effect + '_main_cell_types'] - raw_similarity = adata.uns[config._WEIGHT_PREFIX_ + batch_effect + '_raw_similarity'] + +def plot_similarity_confidence( + adata: ad.AnnData, batch_effect: str, similarity: typing.Callable = _identity +) -> None: + batch = adata.uns[config._WEIGHT_PREFIX_ + batch_effect + "_batch"] + num_batch = adata.uns[config._WEIGHT_PREFIX_ + batch_effect + "_num_batch"] + cluster = adata.uns[config._WEIGHT_PREFIX_ + batch_effect + "_cluster"] + sum_cluster = adata.uns[config._WEIGHT_PREFIX_ + batch_effect + "_sum_cluster"] + num_clusters = adata.uns[config._WEIGHT_PREFIX_ + batch_effect + "_num_clusters"] + main_cell_types = adata.uns[ + config._WEIGHT_PREFIX_ + batch_effect + "_main_cell_types" + ] + raw_similarity = adata.uns[ + config._WEIGHT_PREFIX_ + batch_effect + "_raw_similarity" + ] truth = raw_similarity.copy() mask = raw_similarity.copy() for x in range(num_batch): for y in range(num_batch): - for i in range(num_clusters[x], num_clusters[x+1]): - for j in range(num_clusters[y], num_clusters[y+1]): + for i in range(num_clusters[x], num_clusters[x + 1]): + for j in range(num_clusters[y], num_clusters[y + 1]): if similarity(main_cell_types[i], main_cell_types[j]): truth[i][j] = 1.0 else: @@ -307,93 +353,109 @@ def plot_similarity_confidence(adata: ad.AnnData, else: mask[i][j] = 1.0 - print('truth') - plt.figure(figsize=(7,7)) - sns.heatmap(data = truth * mask) + print("truth") + plt.figure(figsize=(7, 7)) + sns.heatmap(data=truth * mask) plt.show() - print('raw_similarity') - plt.figure(figsize=(7,7)) - sns.heatmap(data = raw_similarity * mask) + print("raw_similarity") + plt.figure(figsize=(7, 7)) + sns.heatmap(data=raw_similarity * mask) plt.show() coef = 1 - (truth - raw_similarity) ** 2 - print("true positive = %d" % ((truth == 1) & (raw_similarity == 1) & (mask == 1)).sum()) - print("true negative = %d" % ((truth == 0) & (raw_similarity == 0) & (mask == 1)).sum()) - print("false positive = %d" % ((truth == 0) & (raw_similarity == 1) & (mask == 1)).sum()) - print("false negative = %d" % ((truth == 1) & (raw_similarity == 0) & (mask == 1)).sum()) - plt.figure(figsize=(7,7)) - sns.heatmap(data = coef * mask) + print( + "true positive = %d" + % ((truth == 1) & (raw_similarity == 1) & (mask == 1)).sum() + ) + print( + "true negative = %d" + % ((truth == 0) & (raw_similarity == 0) & (mask == 1)).sum() + ) + print( + "false positive = %d" + % ((truth == 0) & (raw_similarity == 1) & (mask == 1)).sum() + ) + print( + "false negative = %d" + % ((truth == 1) & (raw_similarity == 0) & (mask == 1)).sum() + ) + plt.figure(figsize=(7, 7)) + sns.heatmap(data=coef * mask) plt.show() -def plot_performance(adata: ad.AnnData, - target: str) -> None: - - with rc_context({'figure.figsize': (7, 7)}): - sc.pl.umap(adata, color = target) - +def plot_performance(adata: ad.AnnData, target: str) -> None: + with rc_context({"figure.figsize": (7, 7)}): + sc.pl.umap(adata, color=target) -def plot_strategy(adata: ad.AnnData, - batch_effect: str, - target: str, - categorical: bool = False, - 
log: bool = False) -> None: - adata.obs[target] = adata.uns[config._WEIGHT_PREFIX_ + batch_effect + '_' + target] +def plot_strategy( + adata: ad.AnnData, + batch_effect: str, + target: str, + categorical: bool = False, + log: bool = False, +) -> None: + adata.obs[target] = adata.uns[config._WEIGHT_PREFIX_ + batch_effect + "_" + target] if categorical: - adata.obs[target] = pd.Series(adata.obs[target], dtype = 'category') + adata.obs[target] = pd.Series(adata.obs[target], dtype="category") if log: adata.obs[target] = np.log1p(adata.obs[target]) - with rc_context({'figure.figsize': (7, 7)}): - sc.pl.umap(adata, color = target) + with rc_context({"figure.figsize": (7, 7)}): + sc.pl.umap(adata, color=target) - -def get_default_clustering_space(adata: ad.AnnData, - genes: typing.List[str], - batch_effect: str, - random_seed: int) -> np.ndarray: - +def get_default_clustering_space( + adata: ad.AnnData, genes: typing.List[str], batch_effect: str, random_seed: int +) -> np.ndarray: adata = data.select_vars(adata, genes) x = np.zeros((adata.obs.shape[0], config.PCA_N_COMPONENTS)) for batch in adata.obs[batch_effect].unique(): sub_adata = adata[adata.obs[batch_effect] == batch].copy() data.normalize(sub_adata) sc.pp.log1p(sub_adata) - sc.pp.pca(sub_adata, n_comps = config.PCA_N_COMPONENTS, random_state = random_seed) - x[adata.obs[batch_effect] == batch] = sub_adata.obsm['X_pca'] + sc.pp.pca(sub_adata, n_comps=config.PCA_N_COMPONENTS, random_state=random_seed) + x[adata.obs[batch_effect] == batch] = sub_adata.obsm["X_pca"] return x -def get_default_similarity_space(adata: ad.AnnData, - genes: typing.List[str], - batch_effect: str, - random_seed: int) -> np.ndarray: - +def get_default_similarity_space( + adata: ad.AnnData, genes: typing.List[str], batch_effect: str, random_seed: int +) -> np.ndarray: adata = data.select_vars(adata, genes) data.normalize(adata) return np.log1p(adata.X) -def get_MNN(num_batch: int, - cluster: np.ndarray, - sum_cluster: int, - num_clusters: typing.List[int], - raw_similarity: np.ndarray) -> np.ndarray: +def get_MNN( + num_batch: int, + cluster: np.ndarray, + sum_cluster: int, + num_clusters: typing.List[int], + raw_similarity: np.ndarray, +) -> np.ndarray: similarity = raw_similarity.copy() for i in range(num_batch): for j in range(num_batch): - xy_dist = torch.tensor(similarity[num_clusters[i]:num_clusters[i+1], num_clusters[j]:num_clusters[j+1]]) + xy_dist = torch.tensor( + similarity[ + num_clusters[i] : num_clusters[i + 1], + num_clusters[j] : num_clusters[j + 1], + ] + ) num_x = xy_dist.shape[0] num_y = xy_dist.shape[1] kx = min(num_x, config.MNN_K) ky = min(num_y, config.MNN_K) - x_topk = F.one_hot(xy_dist.topk(kx, dim = 0)[1], num_x).sum(dim = 0).bool() - y_topk = F.one_hot(xy_dist.topk(ky, dim = 1)[1], num_y).sum(dim = 1).bool() + x_topk = F.one_hot(xy_dist.topk(kx, dim=0)[1], num_x).sum(dim=0).bool() + y_topk = F.one_hot(xy_dist.topk(ky, dim=1)[1], num_y).sum(dim=1).bool() mnn_idx = x_topk.T & y_topk - similarity[num_clusters[i]:num_clusters[i+1], num_clusters[j]:num_clusters[j+1]] = mnn_idx.float().numpy() + similarity[ + num_clusters[i] : num_clusters[i + 1], + num_clusters[j] : num_clusters[j + 1], + ] = mnn_idx.float().numpy() - return similarity \ No newline at end of file + return similarity diff --git a/test/regression_refresh.py b/test/regression_refresh.py deleted file mode 100644 index fcea6bb..0000000 --- a/test/regression_refresh.py +++ /dev/null @@ -1,69 +0,0 @@ -#!/usr/bin/env python - -import sys -import os -import shutil -import unittest 
-import anndata
-
-sys.path.insert(0, "..")
-import Cell_BLAST as cb
-cb.config.RANDOM_SEED = 0
-
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
-os.environ["CUDA_VISIBLE_DEVICES"] = "0"
-
-
-class DirectiTest(unittest.TestCase):
-
-    def setUp(self):
-        self.data = anndata.read_h5ad("pollen.h5ad")
-        cb.data.normalize(self.data)
-
-    def tearDown(self):
-        if os.path.exists("./test_directi"):
-            shutil.rmtree("./test_directi")
-
-    def test_gau(self):
-        model = cb.directi.fit_DIRECTi(
-            self.data, genes=self.data.uns["scmap_genes"],
-            latent_dim=10, epoch=3, path="./test_directi"
-        )
-        latent = model.inference(self.data)
-        cb.utils.write_hybrid_path(latent, "./regression_test/gau.h5//latent")
-
-    def test_catgau(self):
-        model = cb.directi.fit_DIRECTi(
-            self.data, genes=self.data.uns["scmap_genes"],
-            latent_dim=10, cat_dim=10, epoch=3, path="./test_directi"
-        )
-        latent = model.inference(self.data)
-        cb.utils.write_hybrid_path(latent, "./regression_test/catgau.h5//latent")
-
-    def test_semisupervised_catgau(self):
-        '''
-        self.data.obs.loc[
-            cb.data.annotation_confidence(self.data, "cell_type1")[1] <= 0.5,
-            "cell_type1"
-        ] = ""
-        '''
-        model = cb.directi.fit_DIRECTi(
-            self.data, genes=self.data.uns["scmap_genes"],
-            latent_dim=10, supervision="cell_type1",
-            epoch=3, path="./test_directi"
-        )
-        latent = model.inference(self.data)
-        cb.utils.write_hybrid_path(latent, "./regression_test/semisupervised_catgau.h5//latent")
-
-    def test_rmbatch(self):
-        model = cb.directi.fit_DIRECTi(
-            self.data, genes=self.data.uns["scmap_genes"],
-            latent_dim=10, batch_effect="cell_type1", # Just for test
-            epoch=3, path="./test_directi"
-        )
-        latent = model.inference(self.data)
-        cb.utils.write_hybrid_path(latent, "./regression_test/rmbatch.h5//latent")
-
-
-if __name__ == "__main__":
-    unittest.main()
\ No newline at end of file
diff --git a/test/regression_test.py b/test/regression_test.py
deleted file mode 100644
index 66242ce..0000000
--- a/test/regression_test.py
+++ /dev/null
@@ -1,82 +0,0 @@
-#!/usr/bin/env python
-
-import sys
-import os
-import shutil
-import unittest
-import numpy as np
-import anndata
-
-sys.path.insert(0, "..")
-import Cell_BLAST as cb
-cb.config.RANDOM_SEED = 0
-
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
-os.environ["CUDA_VISIBLE_DEVICES"] = "0"
-
-
-class DirectiTest(unittest.TestCase):
-
-    def setUp(self):
-        self.data = anndata.read_h5ad("pollen.h5ad")
-        cb.data.normalize(self.data)
-
-    def tearDown(self):
-        if os.path.exists("./test_directi"):
-            shutil.rmtree("./test_directi")
-
-    def test_gau(self):
-        model = cb.directi.fit_DIRECTi(
-            self.data, genes=self.data.uns["scmap_genes"],
-            latent_dim=10, epoch=3, path="./test_directi"
-        )
-        latent = model.inference(self.data)
-        regression = cb.utils.read_hybrid_path("./regression_test/gau.h5//latent")
-        deviation = np.abs(latent - regression).max()
-        print(deviation)
-        self.assertAlmostEqual(deviation, 0, places=6)
-
-    def test_catgau(self):
-        model = cb.directi.fit_DIRECTi(
-            self.data, genes=self.data.uns["scmap_genes"],
-            latent_dim=10, cat_dim=10, epoch=3, path="./test_directi"
-        )
-        latent = model.inference(self.data)
-        regression = cb.utils.read_hybrid_path("./regression_test/catgau.h5//latent")
-        deviation = np.abs(latent - regression).max()
-        print(deviation)
-        self.assertAlmostEqual(deviation, 0, places=6)
-
-    def test_semisupervised_catgau(self):
-        '''
-        self.data.obs.loc[
-            cb.data.annotation_confidence(self.data, "cell_type1")[1] <= 0.5,
-            "cell_type1"
-        ] = ""
-        '''
-        model = cb.directi.fit_DIRECTi(
-            self.data, genes=self.data.uns["scmap_genes"],
-            latent_dim=10, supervision="cell_type1",
-            epoch=3, path="./test_directi"
-        )
-        latent = model.inference(self.data)
-        regression = cb.utils.read_hybrid_path("./regression_test/semisupervised_catgau.h5//latent")
-        deviation = np.abs(latent - regression).max()
-        print(deviation)
-        self.assertAlmostEqual(deviation, 0, places=6)
-
-    def test_rmbatch(self):
-        model = cb.directi.fit_DIRECTi(
-            self.data, genes=self.data.uns["scmap_genes"],
-            latent_dim=10, batch_effect="cell_type1", # Just for test
-            epoch=3, path="./test_directi"
-        )
-        latent = model.inference(self.data)
-        regression = cb.utils.read_hybrid_path("./regression_test/rmbatch.h5//latent")
-        deviation = np.abs(latent - regression).max()
-        print(deviation)
-        self.assertAlmostEqual(deviation, 0, places=6)
-
-
-if __name__ == "__main__":
-    unittest.main()
\ No newline at end of file
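
The get_unsupervised_cluster() function in the weighting.py hunks above clusters each batch separately with Leiden and then offsets the labels so that cluster ids stay unique across batches. Below is a minimal sketch of that trick on invented random data (not from the patch); it assumes scanpy with the leidenalg backend is available, and it uses the built-in int instead of the dtype=np.int kept by this patch, since np.int has been removed in recent NumPy releases.

```python
import anndata as ad
import numpy as np
import scanpy as sc

rng = np.random.default_rng(0)
latent = rng.normal(size=(200, 10)).astype(np.float32)  # stand-in latent space
batch = np.repeat([0, 1], 100)                          # two toy batches

cluster = np.zeros(latent.shape[0], dtype=int)
sum_cluster = 0
num_clusters = [0]                                      # cluster-id boundaries per batch
for b in np.unique(batch):
    sub = ad.AnnData(latent[batch == b])
    sc.pp.neighbors(sub, use_rep="X", random_state=0)
    sc.tl.leiden(sub, random_state=0, resolution=1.0)
    labels = sub.obs["leiden"].astype(int).to_numpy()
    cluster[batch == b] = labels + sum_cluster          # offset keeps ids globally unique
    sum_cluster = cluster[batch == b].max() + 1
    num_clusters.append(sum_cluster)

print(num_clusters)  # e.g. [0, k0, k0 + k1]
```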
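get_MNN() above keeps only mutually-top-k entries of each cross-batch similarity block by combining torch.topk with F.one_hot. The following self-contained sketch applies the same mutual-top-k masking to a single invented 3 x 4 block (values made up for illustration; larger means more similar).

```python
import torch
import torch.nn.functional as F

sim = torch.tensor([
    [0.9, 0.1, 0.2, 0.8],
    [0.2, 0.7, 0.1, 0.1],
    [0.1, 0.3, 0.6, 0.2],
])
k = 1  # config.MNN_K would normally cap this by the block size

num_x, num_y = sim.shape
kx, ky = min(num_x, k), min(num_y, k)

# x_topk[y, x]: is row x among the top-kx rows for column y?
x_topk = F.one_hot(sim.topk(kx, dim=0)[1], num_x).sum(dim=0).bool()
# y_topk[x, y]: is column y among the top-ky columns for row x?
y_topk = F.one_hot(sim.topk(ky, dim=1)[1], num_y).sum(dim=1).bool()

# Mutual nearest neighbours: both conditions must hold.
mnn = x_topk.T & y_topk
print(mnn.int())
```

Only pairs that pick each other survive; for example entry (0, 3) is dropped because row 0 is column 3's best row, but row 0's own top choice is column 0.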
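The per-cluster weight in get_weight() above reduces, for each cluster i, to the average of sqrt(ww_j * ww_k) over pairs of distinct batches, divided by the cluster's own volume, where ww_j sums the volume-weighted similarity of cluster i to the clusters of batch j (and is simply the cluster's own volume for its home batch). The toy numbers below are invented for illustration; only the loop body mirrors the patch.

```python
import numpy as np

# Invented toy setting: 2 batches; clusters 0-1 come from batch 0, cluster 2 from batch 1.
num_batch = 2
num_clusters = [0, 2, 3]                # cluster-id boundaries per batch
volume = np.array([0.4, 0.6, 1.0])      # within-batch cluster proportions
similarity = np.array([                 # toy cross-batch cluster similarity
    [0.0, 0.0, 1.0],
    [0.0, 0.0, 0.0],
    [0.7, 0.3, 0.0],
])


def cluster_weight(i: int) -> float:
    """Mirror of the per-cluster loop body of get_weight() in this patch."""
    vv = volume * similarity[i, :]
    ww = np.zeros(num_batch)
    for j in range(num_batch):
        ww[j] = vv[num_clusters[j] : num_clusters[j + 1]].sum()
        if num_clusters[j] <= i < num_clusters[j + 1]:
            ww[j] = volume[i]           # a cluster counts as matching itself in its home batch
    # ((sum of square roots)^2 - plain sum) / (B * (B - 1)) equals the mean of
    # sqrt(ww_j * ww_k) over ordered pairs of distinct batches j != k.
    ww = ((ww ** 0.5).sum() ** 2 - ww.sum()) / num_batch / (num_batch - 1)
    return float(ww / volume[i])


print([round(cluster_weight(i), 3) for i in range(3)])  # [1.581, 0.0, 0.678]
```

Clusters 0 and 2, which point at each other across batches, receive large weights, while the unmatched cluster 1 is weighted down to zero, which suggests the strategy down-weights batch-specific cell populations during adversarial alignment.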