import torch
import torch.nn as nn
import torch.nn.functional as F


class EncoderHMIVAE(nn.Module):
    """Encoder for the case in which data is merged after initial encoding.

    Each of the four views (mean expression, correlations, morphology,
    spatial context) is passed through its own two-layer ELU encoder; the
    four encodings are concatenated and mapped to the parameters of the
    latent Gaussian.

    input_exp_dim: Dimension for the original mean expression input
    input_corr_dim: Dimension for the original correlations input
    input_morph_dim: Dimension for the original morphology input
    input_spcont_dim: Dimension for the original spatial context input
    E_me: Dimension for the encoded mean expressions input
    E_cr: Dimension for the encoded correlations input
    E_mr: Dimension for the encoded morphology input
    E_sc: Dimension for the encoded spatial context input
    latent_dim: Dimension of the encoded output
    n_hidden: Number of hidden layers, default=1
    """

    def __init__(
        self,
        input_exp_dim,
        input_corr_dim,
        input_morph_dim,
        input_spcont_dim,
        E_me,
        E_cr,
        E_mr,
        E_sc,
        latent_dim,
        n_hidden=1,
    ):
        super().__init__()
        hidden_dim = E_me + E_cr + E_mr + E_sc

        self.input_exp = nn.Linear(input_exp_dim, E_me)
        self.exp_hidden = nn.Linear(E_me, E_me)

        self.input_corr = nn.Linear(input_corr_dim, E_cr)
        self.corr_hidden = nn.Linear(E_cr, E_cr)

        self.input_morph = nn.Linear(input_morph_dim, E_mr)
        self.morph_hidden = nn.Linear(E_mr, E_mr)

        self.input_spatial_context = nn.Linear(input_spcont_dim, E_sc)
        self.spatial_context_hidden = nn.Linear(E_sc, E_sc)

        self.linear = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim) for _ in range(n_hidden)]
        )

        self.mu_z = nn.Linear(hidden_dim, latent_dim)
        self.std_z = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x_mean, x_correlations, x_morphology, x_spatial_context):
        """Encode the four views.

        Returns
        -------
        mu_z: latent mean, shape (batch, latent_dim)
        log_std_z: latent log standard deviation, shape (batch, latent_dim)
        z1: concatenation of the encoded expression/correlation/morphology
            views (spatial context deliberately excluded), shape
            (batch, E_me + E_cr + E_mr)
        """
        h_mean = F.elu(self.input_exp(x_mean))
        h_mean2 = F.elu(self.exp_hidden(h_mean))

        h_correlations = F.elu(self.input_corr(x_correlations))
        h_correlations2 = F.elu(self.corr_hidden(h_correlations))

        h_morphology = F.elu(self.input_morph(x_morphology))
        h_morphology2 = F.elu(self.morph_hidden(h_morphology))

        # Partial merge WITHOUT spatial context, returned alongside the
        # latent parameters for downstream use.
        z1 = torch.cat([h_mean2, h_correlations2, h_morphology2], 1)

        h_spatial_context = F.elu(self.input_spatial_context(x_spatial_context))
        h_spatial_context2 = F.elu(self.spatial_context_hidden(h_spatial_context))

        h = torch.cat([h_mean2, h_correlations2, h_morphology2, h_spatial_context2], 1)

        for net in self.linear:
            h = F.elu(net(h))

        mu_z = self.mu_z(h)
        log_std_z = self.std_z(h)

        return mu_z, log_std_z, z1


class DecoderHMIVAE(nn.Module):
    """Decoder for the case where data is merged after initial encoding.

    The latent code is projected back to the merged hidden width, passed
    through the shared hidden stack, then split back into the four
    view-specific slices, each decoded to Gaussian parameters.

    latent_dim: Dimension of the encoded input
    E_me: Dimension for the encoded mean expressions input
    E_cr: Dimension for the encoded correlations input
    E_mr: Dimension for the encoded morphology input
    E_sc: Dimension for the encoded spatial context input
    input_exp_dim: Dimension for the decoded mean expression output
    input_corr_dim: Dimension for the decoded correlations output
    input_morph_dim: Dimension for the decoded morphology output
    input_spcont_dim: Dimension for the decoded spatial context output
    n_hidden: Number of hidden layers, default=1
    use_weights: If True, per-cell correlation weights are computed from the
        decoded expression means inside forward(); requires a
        ``get_corr_weights_per_cell`` method to be provided. Default=False.
    """

    def __init__(
        self,
        latent_dim,
        E_me,
        E_cr,
        E_mr,
        E_sc,
        input_exp_dim,
        input_corr_dim,
        input_morph_dim,
        input_spcont_dim,
        n_hidden=1,
        use_weights=False,
    ):
        super().__init__()
        hidden_dim = E_me + E_cr + E_mr + E_sc
        # forward() slices the merged hidden vector by these widths, so they
        # must be stored (previously they were never assigned, so forward()
        # raised AttributeError on self.E_me / self.E_cr / self.E_mr).
        self.E_me = E_me
        self.E_cr = E_cr
        self.E_mr = E_mr
        # Previously undefined attribute read by forward(); exposed as a
        # backward-compatible constructor flag.
        self.use_weights = use_weights
        # Latent -> merged-hidden projection. forward() called self.input(z)
        # but no such layer was ever created.
        self.input = nn.Linear(latent_dim, hidden_dim)
        self.linear = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim) for _ in range(n_hidden)]
        )
        # mean expression
        self.exp_hidden = nn.Linear(E_me, E_me)
        self.mu_x_exp = nn.Linear(E_me, input_exp_dim)
        self.std_x_exp = nn.Linear(E_me, input_exp_dim)

        # correlations/co-localizations
        self.corr_hidden = nn.Linear(E_cr, E_cr)
        self.mu_x_corr = nn.Linear(E_cr, input_corr_dim)
        self.std_x_corr = nn.Linear(E_cr, input_corr_dim)

        # morphology
        self.morph_hidden = nn.Linear(E_mr, E_mr)
        self.mu_x_morph = nn.Linear(E_mr, input_morph_dim)
        self.std_x_morph = nn.Linear(E_mr, input_morph_dim)

        # spatial context
        self.spatial_context_hidden = nn.Linear(E_sc, E_sc)
        self.mu_x_spcont = nn.Linear(E_sc, input_spcont_dim)
        self.std_x_spcont = nn.Linear(E_sc, input_spcont_dim)

    def forward(self, z):
        """Decode a latent sample into per-view Gaussian parameters.

        Returns a 9-tuple: (mu, log_std) pairs for expression, correlations,
        morphology, and spatial context, followed by the correlation weights
        (``None`` unless ``use_weights`` is enabled).
        """
        out = F.elu(self.input(z))
        for net in self.linear:
            out = F.elu(net(out))

        # Split the merged hidden vector back into its four view slices.
        h2_mean = F.elu(self.exp_hidden(out[:, 0 : self.E_me]))
        h2_correlations = F.elu(
            self.corr_hidden(out[:, self.E_me : self.E_me + self.E_cr])
        )
        h2_morphology = F.elu(
            self.morph_hidden(
                out[:, self.E_me + self.E_cr : self.E_me + self.E_cr + self.E_mr]
            )
        )
        h2_spatial_context = F.elu(
            self.spatial_context_hidden(out[:, self.E_me + self.E_cr + self.E_mr :])
        )

        mu_x_exp = self.mu_x_exp(h2_mean)
        std_x_exp = self.std_x_exp(h2_mean)

        if self.use_weights:
            # NOTE(review): get_corr_weights_per_cell is not defined in this
            # module; it must be supplied (e.g. by a subclass) before
            # use_weights=True is usable — confirm against the caller.
            with torch.no_grad():
                weights = self.get_corr_weights_per_cell(
                    mu_x_exp.detach()
                )  # calculating correlation weights
        else:
            weights = None

        mu_x_corr = self.mu_x_corr(h2_correlations)
        std_x_corr = self.std_x_corr(h2_correlations)

        mu_x_morph = self.mu_x_morph(h2_morphology)
        std_x_morph = self.std_x_morph(h2_morphology)

        mu_x_spatial_context = self.mu_x_spcont(h2_spatial_context)
        std_x_spatial_context = self.std_x_spcont(h2_spatial_context)

        return (
            mu_x_exp,
            std_x_exp,
            mu_x_corr,
            std_x_corr,
            mu_x_morph,
            std_x_morph,
            mu_x_spatial_context,
            std_x_spatial_context,
            weights,
        )
@@ -47,7 +47,7 @@ def __init__( n_layers: int = 1, **model_kwargs, ): - super(MyModel, self).__init__(adata) + super(hmiVAE, self).__init__(adata) library_log_means, library_log_vars = _init_library_size( adata, self.summary_stats["n_batch"] @@ -55,7 +55,7 @@ def __init__( # self.summary_stats provides information about anndata dimensions and other tensor info - self.module = MyModule( + self.module = hmiVAE( n_input=self.summary_stats["n_vars"], n_hidden=n_hidden, n_latent=n_latent, @@ -74,6 +74,11 @@ def __init__( @setup_anndata_dsp.dedent def setup_anndata( adata: AnnData, + protein_correlations_obsm_key: str, + cell_morphology_obsm_key: str, + cell_spatial_context_obsm_key: str, + protein_correlations_names_uns_key: Optional[str] = None, + cell_morphology_names_uns_key: Optional[str] = None, batch_key: Optional[str] = None, labels_key: Optional[str] = None, layer: Optional[str] = None, diff --git a/hmivae/_hmivae_module.py b/hmivae/_hmivae_module.py new file mode 100644 index 0000000..f655652 --- /dev/null +++ b/hmivae/_hmivae_module.py @@ -0,0 +1,690 @@ +from typing import Iterable, Optional # Dict, Tuple, Union + +import numpy as np +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from _hmivae_base_components import DecoderHMIVAE, EncoderHMIVAE + +# from scvi import _CONSTANTS +# from scvi.distributions import ZeroInflatedNegativeBinomial +# from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data +# from scvi.nn import one_hot +# from torch.distributions import Normal +# from torch.distributions import kl_divergence as kl + +torch.backends.cudnn.benchmark = True + + +# class HMIVAE(BaseModuleClass): +# """ +# Variational auto-encoder model. + +# Here we implement a basic version of scVI's underlying VAE [Lopez18]_. +# This implementation is for instructional purposes only. 
+ +# Parameters +# ---------- +# n_input +# Number of input genes +# library_log_means +# 1 x n_batch array of means of the log library sizes. Parameterizes prior on library size if +# not using observed library size. +# library_log_vars +# 1 x n_batch array of variances of the log library sizes. Parameterizes prior on library size if +# not using observed library size. +# n_batch +# Number of batches, if 0, no batch correction is performed. +# n_hidden +# Number of nodes per hidden layer +# n_latent +# Dimensionality of the latent space +# n_layers +# Number of hidden layers used for encoder and decoder NNs +# dropout_rate +# Dropout rate for neural networks +# """ + +# def __init__( +# self, +# n_input: int, +# n_batch: int = 0, +# n_hidden: int = 128, +# n_latent: int = 10, +# n_layers: int = 1, +# dropout_rate: float = 0.1, +# ): +# # def __init__( +# # self, +# # input_exp_dim: int, +# # input_corr_dim: int, +# # input_morph_dim: int, +# # input_spcont_dim: int, +# # E_me: int, +# # E_cr: int, +# # E_mr: int, +# # E_sc: int, +# # n_latent: int = 10, +# # n_batch: int = 0, +# # n_hidden: int = 1, +# # ): +# super().__init__() +# self.n_latent = n_latent +# self.n_batch = n_batch +# # this is needed to comply with some requirement of the VAEMixin class +# self.latent_distribution = "normal" + +# self.register_buffer( +# "library_log_means", torch.from_numpy(library_log_means).float() +# ) +# self.register_buffer( +# "library_log_vars", torch.from_numpy(library_log_vars).float() +# ) + +# # setup the parameters of your generative model, as well as your inference model +# self.px_r = torch.nn.Parameter(torch.randn(n_input)) +# # z encoder goes from the n_input-dimensional data to an n_latent-d +# # latent space representation +# self.z_encoder = EncoderHMIVAE( +# n_input, +# n_latent, +# n_layers=n_layers, +# n_hidden=n_hidden, +# dropout_rate=dropout_rate, +# ) +# # l encoder goes from n_input-dimensional data to 1-d library size +# self.l_encoder = 
EncoderHMIVAE( +# n_input, +# 1, +# n_layers=1, +# n_hidden=n_hidden, +# dropout_rate=dropout_rate, +# ) +# # decoder goes from n_latent-dimensional space to n_input-d data +# self.decoder = DecoderHMIVAE( +# n_latent, +# n_input, +# n_layers=n_layers, +# n_hidden=n_hidden, +# ) + +# def _get_inference_input(self, tensors): +# """Parse the dictionary to get appropriate args""" +# x = tensors[_CONSTANTS.X_KEY] + +# input_dict = dict(x=x) +# return input_dict + +# def _get_generative_input(self, tensors, inference_outputs): +# z = inference_outputs["z"] +# library = inference_outputs["library"] + +# input_dict = { +# "z": z, +# "library": library, +# } +# return input_dict + +# @auto_move_data +# def inference(self, x): +# """ +# High level inference method. + +# Runs the inference (encoder) model. +# """ +# # log the input to the variational distribution for numerical stability +# x_ = torch.log(1 + x) +# # get variational parameters via the encoder networks +# qz_m, qz_v, z = self.z_encoder(x_) +# ql_m, ql_v, library = self.l_encoder(x_) + +# outputs = dict(z=z, qz_m=qz_m, qz_v=qz_v, ql_m=ql_m, ql_v=ql_v, library=library) +# return outputs + +# @auto_move_data +# def generative(self, z, library): +# """Runs the generative model.""" + +# # form the parameters of the ZINB likelihood +# px_scale, _, px_rate, px_dropout = self.decoder("gene", z, library) +# px_r = torch.exp(self.px_r) + +# return dict( +# px_scale=px_scale, px_r=px_r, px_rate=px_rate, px_dropout=px_dropout +# ) + +# def loss( +# self, +# tensors, +# inference_outputs, +# generative_outputs, +# kl_weight: float = 1.0, +# ): +# x = tensors[_CONSTANTS.X_KEY] +# qz_m = inference_outputs["qz_m"] +# qz_v = inference_outputs["qz_v"] +# ql_m = inference_outputs["ql_m"] +# ql_v = inference_outputs["ql_v"] +# px_rate = generative_outputs["px_rate"] +# px_r = generative_outputs["px_r"] +# px_dropout = generative_outputs["px_dropout"] + +# mean = torch.zeros_like(qz_m) +# scale = torch.ones_like(qz_v) + +# 
kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum( +# dim=1 +# ) + +# batch_index = tensors[_CONSTANTS.BATCH_KEY] +# n_batch = self.library_log_means.shape[1] +# local_library_log_means = F.linear( +# one_hot(batch_index, n_batch), self.library_log_means +# ) +# local_library_log_vars = F.linear( +# one_hot(batch_index, n_batch), self.library_log_vars +# ) + +# kl_divergence_l = kl( +# Normal(ql_m, torch.sqrt(ql_v)), +# Normal(local_library_log_means, torch.sqrt(local_library_log_vars)), +# ).sum(dim=1) + +# reconst_loss = ( +# -ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout) +# .log_prob(x) +# .sum(dim=-1) +# ) + +# kl_local_for_warmup = kl_divergence_z +# kl_local_no_warmup = kl_divergence_l + +# weighted_kl_local = kl_weight * kl_local_for_warmup + kl_local_no_warmup + +# loss = torch.mean(reconst_loss + weighted_kl_local) + +# kl_local = dict( +# kl_divergence_l=kl_divergence_l, kl_divergence_z=kl_divergence_z +# ) +# kl_global = torch.tensor(0.0) +# return LossRecorder(loss, reconst_loss, kl_local, kl_global) + +# @torch.no_grad() +# def sample( +# self, +# tensors, +# n_samples=1, +# library_size=1, +# ) -> np.ndarray: +# r""" +# Generate observation samples from the posterior predictive distribution. + +# The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`. 
+ +# Parameters +# ---------- +# tensors +# Tensors dict +# n_samples +# Number of required samples for each cell +# library_size +# Library size to scale scamples to + +# Returns +# ------- +# x_new : :py:class:`torch.Tensor` +# tensor with shape (n_cells, n_genes, n_samples) +# """ +# inference_kwargs = dict(n_samples=n_samples) +# _, generative_outputs, = self.forward( +# tensors, +# inference_kwargs=inference_kwargs, +# compute_loss=False, +# ) + +# px_r = generative_outputs["px_r"] +# px_rate = generative_outputs["px_rate"] +# px_dropout = generative_outputs["px_dropout"] + +# dist = ZeroInflatedNegativeBinomial( +# mu=px_rate, theta=px_r, zi_logits=px_dropout +# ) + +# if n_samples > 1: +# exprs = dist.sample().permute( +# [1, 2, 0] +# ) # Shape : (n_cells_batch, n_genes, n_samples) +# else: +# exprs = dist.sample() + +# return exprs.cpu() + +# @torch.no_grad() +# @auto_move_data +# def marginal_ll(self, tensors, n_mc_samples): +# sample_batch = tensors[_CONSTANTS.X_KEY] +# batch_index = tensors[_CONSTANTS.BATCH_KEY] + +# to_sum = torch.zeros(sample_batch.size()[0], n_mc_samples) + +# for i in range(n_mc_samples): +# # Distribution parameters and sampled variables +# inference_outputs, _, losses = self.forward(tensors) +# qz_m = inference_outputs["qz_m"] +# qz_v = inference_outputs["qz_v"] +# z = inference_outputs["z"] +# ql_m = inference_outputs["ql_m"] +# ql_v = inference_outputs["ql_v"] +# library = inference_outputs["library"] + +# # Reconstruction Loss +# reconst_loss = losses.reconstruction_loss + +# # Log-probabilities +# n_batch = self.library_log_means.shape[1] +# local_library_log_means = F.linear( +# one_hot(batch_index, n_batch), self.library_log_means +# ) +# local_library_log_vars = F.linear( +# one_hot(batch_index, n_batch), self.library_log_vars +# ) +# p_l = ( +# Normal(local_library_log_means, local_library_log_vars.sqrt()) +# .log_prob(library) +# .sum(dim=-1) +# ) + +# p_z = ( +# Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v)) +# 
class hmiVAE(pl.LightningModule):
    """
    Variational Autoencoder for hmiVAE based on pytorch-lightning.

    Wraps an EncoderHMIVAE / DecoderHMIVAE pair and implements the
    negative-ELBO objective over the four views (mean expression,
    correlations, morphology, spatial context).
    """

    def __init__(
        self,
        input_exp_dim: int,
        input_corr_dim: int,
        input_morph_dim: int,
        input_spcont_dim: int,
        E_me: int = 32,
        E_cr: int = 32,
        E_mr: int = 32,
        E_sc: int = 32,
        latent_dim: int = 10,
        n_hidden: int = 1,
    ):
        super().__init__()

        self.encoder = EncoderHMIVAE(
            input_exp_dim,
            input_corr_dim,
            input_morph_dim,
            input_spcont_dim,
            E_me,
            E_cr,
            E_mr,
            E_sc,
            latent_dim,
        )

        self.decoder = DecoderHMIVAE(
            latent_dim,
            E_me,
            E_cr,
            E_mr,
            E_sc,
            input_exp_dim,
            input_corr_dim,
            input_morph_dim,
            input_spcont_dim,
        )

    def reparameterization(self, mu, log_std):
        """Sample z = mu + eps * exp(log_std) with eps ~ N(0, I)."""
        std = torch.exp(log_std)
        eps = torch.randn_like(log_std)
        return mu + eps * std

    def KL_div(self, enc_x_mu, enc_x_logstd, z):
        """Takes in the encoded x mu and sigma, and the z sampled from
        q, and outputs the KL-Divergence term in ELBO (Monte Carlo
        estimate log q(z|x) - log p(z), summed over latent dims)."""

        p = torch.distributions.Normal(
            torch.zeros_like(enc_x_mu), torch.ones_like(enc_x_logstd)
        )
        enc_x_std = torch.exp(enc_x_logstd)
        # +1e-6 guards against a degenerate zero-scale Normal
        q = torch.distributions.Normal(enc_x_mu, enc_x_std + 1e-6)

        kl = q.log_prob(z) - p.log_prob(z)
        return kl.sum(-1)

    def em_recon_loss(
        self,
        dec_x_mu_exp,
        dec_x_logstd_exp,
        dec_x_mu_corr,
        dec_x_logstd_corr,
        dec_x_mu_morph,
        dec_x_logstd_morph,
        dec_x_mu_spcont,
        dec_x_logstd_spcont,
        y,
        s,
        m,
        c,
        weights=None,
    ):
        """Takes in the parameters output from the decoder,
        and the original input x, and gives the reconstruction
        loss term in ELBO (per-view Gaussian log-likelihoods).
        dec_x_mu_exp: torch.Tensor, decoded means for protein expression feature
        dec_x_logstd_exp: torch.Tensor, decoded log std for protein expression feature
        dec_x_mu_corr: torch.Tensor, decoded means for correlation feature
        dec_x_logstd_corr: torch.Tensor, decoded log std for correlations feature
        dec_x_mu_morph: torch.Tensor, decoded means for morphology feature
        dec_x_logstd_morph: torch.Tensor, decoded log std for morphology feature
        dec_x_mu_spcont: torch.Tensor, decoded means for spatial context feature
        dec_x_logstd_spcont: torch.Tensor, decoded log std for spatial context feature
        y: torch.Tensor, original mean expression input
        s: torch.Tensor, original correlation input
        m: torch.Tensor, original morphology input
        c: torch.Tensor, original cell context input
        weights: torch.Tensor, weights calculated from decoded means for protein expression feature
        """
        recon_dists = {
            "exp": torch.distributions.Normal(
                dec_x_mu_exp, torch.exp(dec_x_logstd_exp) + 1e-6
            ),
            "corr": torch.distributions.Normal(
                dec_x_mu_corr, torch.exp(dec_x_logstd_corr) + 1e-6
            ),
            "morph": torch.distributions.Normal(
                dec_x_mu_morph, torch.exp(dec_x_logstd_morph) + 1e-6
            ),
            "spcont": torch.distributions.Normal(
                dec_x_mu_spcont, torch.exp(dec_x_logstd_spcont) + 1e-6
            ),
        }

        log_p_xz_exp = recon_dists["exp"].log_prob(y)
        log_p_xz_morph = recon_dists["morph"].log_prob(m)
        log_p_xz_spcont = recon_dists["spcont"].log_prob(c)  # already dense matrix

        if weights is None:
            log_p_xz_corr = recon_dists["corr"].log_prob(s)
        else:
            # element-wise reweighting of the correlation likelihood
            log_p_xz_corr = torch.mul(weights, recon_dists["corr"].log_prob(s))

        return (
            log_p_xz_exp.sum(-1),
            log_p_xz_corr.sum(-1),
            log_p_xz_morph.sum(-1),
            log_p_xz_spcont.sum(-1),
        )

    def neg_ELBO(
        self,
        enc_x_mu,
        enc_x_logstd,
        dec_x_mu_exp,
        dec_x_logstd_exp,
        dec_x_mu_corr,
        dec_x_logstd_corr,
        dec_x_mu_morph,
        dec_x_logstd_morph,
        dec_x_mu_spcont,
        dec_x_logstd_spcont,
        z,
        y,
        s,
        m,
        c,
        weights=None,
    ):
        """Return the KL term plus the four per-view reconstruction
        log-likelihoods that make up the ELBO."""
        kl_div = self.KL_div(enc_x_mu, enc_x_logstd, z)

        recon_lik_me, recon_lik_corr, recon_lik_mor, recon_lik_sc = self.em_recon_loss(
            dec_x_mu_exp,
            dec_x_logstd_exp,
            dec_x_mu_corr,
            dec_x_logstd_corr,
            dec_x_mu_morph,
            dec_x_logstd_morph,
            dec_x_mu_spcont,
            dec_x_logstd_spcont,
            y,
            s,
            m,
            c,
            weights,
        )
        return (
            kl_div,
            recon_lik_me,
            recon_lik_corr,
            recon_lik_mor,
            recon_lik_sc,
        )

    def loss(self, kl_div, recon_loss, beta: float = 1.0):
        """beta-weighted KL minus reconstruction likelihood, averaged over the batch."""
        return beta * kl_div.mean() - recon_loss.mean()

    def training_step(
        self,
        train_batch,
        spatial_context,
        batch_idx,
        categories: Optional[Iterable[int]] = None,
        corr_weights=False,
        recon_weights=np.array([1.0, 1.0, 1.0, 1.0]),
        beta=1.0,
    ):
        """
        Carries out the training step.
        train_batch: torch.Tensor. Training data,
        spatial_context: torch.Tensor. Matrix with old mu_z integrated neighbours information,
        corr_weights: numpy.array. Array with weights for the correlations for each cell.
        recon_weights: numpy.array. Array with weights for each view during loss calculation.
        beta: float. Coefficient for KL-Divergence term in ELBO.

        NOTE(review): the ``spatial_context`` parameter is immediately
        overwritten from ``train_batch[3]`` and the extra positional
        parameters do not match PyTorch Lightning's
        ``training_step(batch, batch_idx)`` hook signature — confirm how
        this is invoked before relying on a Trainer.
        """
        Y = train_batch[0]
        S = train_batch[1]
        M = train_batch[2]
        spatial_context = train_batch[3]

        # Encoder returns THREE values (mu, log_std, partial merge z1);
        # the previous two-value unpack raised ValueError on every step.
        mu_z, log_std_z, z1 = self.encoder(Y, S, M, spatial_context)

        z_samples = self.reparameterization(mu_z, log_std_z)

        # decoding
        (
            mu_x_exp_hat,
            log_std_x_exp_hat,
            mu_x_corr_hat,
            log_std_x_corr_hat,
            mu_x_morph_hat,
            log_std_x_morph_hat,
            mu_x_spcont_hat,
            log_std_x_spcont_hat,
            weights,
        ) = self.decoder(z_samples)

        (
            kl_div,
            recon_lik_me,
            recon_lik_corr,
            recon_lik_mor,
            recon_lik_sc,
        ) = self.neg_ELBO(
            mu_z,
            log_std_z,
            mu_x_exp_hat,
            log_std_x_exp_hat,
            mu_x_corr_hat,
            log_std_x_corr_hat,
            mu_x_morph_hat,
            log_std_x_morph_hat,
            mu_x_spcont_hat,
            log_std_x_spcont_hat,
            z_samples,
            Y,
            S,
            M,
            spatial_context,
            weights,
        )

        recon_loss = (
            recon_weights[0] * recon_lik_me
            + recon_weights[1] * recon_lik_corr
            + recon_weights[2] * recon_lik_mor
            + recon_weights[3] * recon_lik_sc
        )

        loss = self.loss(kl_div, recon_loss, beta=beta)

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)

        return (
            loss,
            kl_div.mean().item(),
            recon_loss.mean().item(),
            recon_lik_me.mean().item(),
            recon_lik_corr.mean().item(),
            recon_lik_mor.mean().item(),
            recon_lik_sc.mean().item(),
        )

    def test_step(
        self,
        test_batch,
        spatial_context,
        batch_idx,
        corr_weights=False,
        recon_weights=np.array([1.0, 1.0, 1.0, 1.0]),
        beta=1.0,
    ):
        """
        Carries out the validation/test step.
        test_batch: torch.Tensor. Validation/test data,
        spatial_context: torch.Tensor. Matrix with old mu_z integrated neighbours information,
        corr_weights: numpy.array. Array with weights for the correlations for each cell.
        recon_weights: numpy.array. Array with weights for each view during loss calculation.
        beta: float. Coefficient for KL-Divergence term in ELBO.
        """
        Y = test_batch[0]
        S = test_batch[1]
        M = test_batch[2]
        spatial_context = test_batch[3]

        # Fixed: previous code called nonexistent self.encode / self.decode;
        # the submodules are named self.encoder / self.decoder.
        mu_z, log_std_z, z1 = self.encoder(Y, S, M, spatial_context)

        z_samples = self.reparameterization(mu_z, log_std_z)

        # decoding
        (
            mu_x_exp_hat,
            log_std_x_exp_hat,
            mu_x_corr_hat,
            log_std_x_corr_hat,
            mu_x_morph_hat,
            log_std_x_morph_hat,
            mu_x_spcont_hat,
            log_std_x_spcont_hat,
            weights,
        ) = self.decoder(z_samples)

        (
            kl_div,
            recon_lik_me,
            recon_lik_corr,
            recon_lik_mor,
            recon_lik_sc,
        ) = self.neg_ELBO(
            mu_z,
            log_std_z,
            mu_x_exp_hat,
            log_std_x_exp_hat,
            mu_x_corr_hat,
            log_std_x_corr_hat,
            mu_x_morph_hat,
            log_std_x_morph_hat,
            mu_x_spcont_hat,
            log_std_x_spcont_hat,
            z_samples,
            Y,
            S,
            M,
            spatial_context,
            weights,
        )

        recon_loss = (
            recon_weights[0] * recon_lik_me
            + recon_weights[1] * recon_lik_corr
            + recon_weights[2] * recon_lik_mor
            + recon_weights[3] * recon_lik_sc
        )

        loss = self.loss(kl_div, recon_loss, beta=beta)

        self.log("test_loss", loss, on_step=True, on_epoch=True, prog_bar=True)

        return (
            loss,
            kl_div.mean().item(),
            recon_loss.mean().item(),
            recon_lik_me.mean().item(),
            recon_lik_corr.mean().item(),
            recon_lik_mor.mean().item(),
            recon_lik_sc.mean().item(),
        )

    def configure_optimizers(self):
        """Optimizer"""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def __get_input_embeddings__(
        self, x_mean, x_correlations, x_morphology, x_spatial_context
    ):
        """
        Returns the view-specific embeddings.

        Fixed: the per-view layers (input_exp, exp_hidden, ...) live on
        self.encoder, not on this LightningModule; the previous direct
        self.input_exp access raised AttributeError.
        """
        h_mean = F.elu(self.encoder.input_exp(x_mean))
        h_mean2 = F.elu(self.encoder.exp_hidden(h_mean))

        h_correlations = F.elu(self.encoder.input_corr(x_correlations))
        h_correlations2 = F.elu(self.encoder.corr_hidden(h_correlations))

        h_morphology = F.elu(self.encoder.input_morph(x_morphology))
        h_morphology2 = F.elu(self.encoder.morph_hidden(h_morphology))

        h_spatial_context = F.elu(self.encoder.input_spatial_context(x_spatial_context))
        h_spatial_context2 = F.elu(
            self.encoder.spatial_context_hidden(h_spatial_context)
        )

        return h_mean2, h_correlations2, h_morphology2, h_spatial_context2
Parameterizes prior on library size if - not using observed library size. - n_batch - Number of batches, if 0, no batch correction is performed. - n_hidden - Number of nodes per hidden layer - n_latent - Dimensionality of the latent space - n_layers - Number of hidden layers used for encoder and decoder NNs - dropout_rate - Dropout rate for neural networks - """ - - def __init__( - self, - n_input: int, - library_log_means: np.ndarray, - library_log_vars: np.ndarray, - n_batch: int = 0, - n_hidden: int = 128, - n_latent: int = 10, - n_layers: int = 1, - dropout_rate: float = 0.1, - ): - super().__init__() - self.n_latent = n_latent - self.n_batch = n_batch - # this is needed to comply with some requirement of the VAEMixin class - self.latent_distribution = "normal" - - self.register_buffer( - "library_log_means", torch.from_numpy(library_log_means).float() - ) - self.register_buffer( - "library_log_vars", torch.from_numpy(library_log_vars).float() - ) - - # setup the parameters of your generative model, as well as your inference model - self.px_r = torch.nn.Parameter(torch.randn(n_input)) - # z encoder goes from the n_input-dimensional data to an n_latent-d - # latent space representation - self.z_encoder = Encoder( - n_input, - n_latent, - n_layers=n_layers, - n_hidden=n_hidden, - dropout_rate=dropout_rate, - ) - # l encoder goes from n_input-dimensional data to 1-d library size - self.l_encoder = Encoder( - n_input, - 1, - n_layers=1, - n_hidden=n_hidden, - dropout_rate=dropout_rate, - ) - # decoder goes from n_latent-dimensional space to n_input-d data - self.decoder = DecoderSCVI( - n_latent, - n_input, - n_layers=n_layers, - n_hidden=n_hidden, - ) - - def _get_inference_input(self, tensors): - """Parse the dictionary to get appropriate args""" - x = tensors[_CONSTANTS.X_KEY] - - input_dict = dict(x=x) - return input_dict - - def _get_generative_input(self, tensors, inference_outputs): - z = inference_outputs["z"] - library = inference_outputs["library"] - - 
input_dict = { - "z": z, - "library": library, - } - return input_dict - - @auto_move_data - def inference(self, x): - """ - High level inference method. - - Runs the inference (encoder) model. - """ - # log the input to the variational distribution for numerical stability - x_ = torch.log(1 + x) - # get variational parameters via the encoder networks - qz_m, qz_v, z = self.z_encoder(x_) - ql_m, ql_v, library = self.l_encoder(x_) - - outputs = dict(z=z, qz_m=qz_m, qz_v=qz_v, ql_m=ql_m, ql_v=ql_v, library=library) - return outputs - - @auto_move_data - def generative(self, z, library): - """Runs the generative model.""" - - # form the parameters of the ZINB likelihood - px_scale, _, px_rate, px_dropout = self.decoder("gene", z, library) - px_r = torch.exp(self.px_r) - - return dict( - px_scale=px_scale, px_r=px_r, px_rate=px_rate, px_dropout=px_dropout - ) - - def loss( - self, - tensors, - inference_outputs, - generative_outputs, - kl_weight: float = 1.0, - ): - x = tensors[_CONSTANTS.X_KEY] - qz_m = inference_outputs["qz_m"] - qz_v = inference_outputs["qz_v"] - ql_m = inference_outputs["ql_m"] - ql_v = inference_outputs["ql_v"] - px_rate = generative_outputs["px_rate"] - px_r = generative_outputs["px_r"] - px_dropout = generative_outputs["px_dropout"] - - mean = torch.zeros_like(qz_m) - scale = torch.ones_like(qz_v) - - kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum( - dim=1 - ) - - batch_index = tensors[_CONSTANTS.BATCH_KEY] - n_batch = self.library_log_means.shape[1] - local_library_log_means = F.linear( - one_hot(batch_index, n_batch), self.library_log_means - ) - local_library_log_vars = F.linear( - one_hot(batch_index, n_batch), self.library_log_vars - ) - - kl_divergence_l = kl( - Normal(ql_m, torch.sqrt(ql_v)), - Normal(local_library_log_means, torch.sqrt(local_library_log_vars)), - ).sum(dim=1) - - reconst_loss = ( - -ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout) - .log_prob(x) - .sum(dim=-1) - ) 
- - kl_local_for_warmup = kl_divergence_z - kl_local_no_warmup = kl_divergence_l - - weighted_kl_local = kl_weight * kl_local_for_warmup + kl_local_no_warmup - - loss = torch.mean(reconst_loss + weighted_kl_local) - - kl_local = dict( - kl_divergence_l=kl_divergence_l, kl_divergence_z=kl_divergence_z - ) - kl_global = torch.tensor(0.0) - return LossRecorder(loss, reconst_loss, kl_local, kl_global) - - @torch.no_grad() - def sample( - self, - tensors, - n_samples=1, - library_size=1, - ) -> np.ndarray: - r""" - Generate observation samples from the posterior predictive distribution. - - The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`. - - Parameters - ---------- - tensors - Tensors dict - n_samples - Number of required samples for each cell - library_size - Library size to scale scamples to - - Returns - ------- - x_new : :py:class:`torch.Tensor` - tensor with shape (n_cells, n_genes, n_samples) - """ - inference_kwargs = dict(n_samples=n_samples) - _, generative_outputs, = self.forward( - tensors, - inference_kwargs=inference_kwargs, - compute_loss=False, - ) - - px_r = generative_outputs["px_r"] - px_rate = generative_outputs["px_rate"] - px_dropout = generative_outputs["px_dropout"] - - dist = ZeroInflatedNegativeBinomial( - mu=px_rate, theta=px_r, zi_logits=px_dropout - ) - - if n_samples > 1: - exprs = dist.sample().permute( - [1, 2, 0] - ) # Shape : (n_cells_batch, n_genes, n_samples) - else: - exprs = dist.sample() - - return exprs.cpu() - - @torch.no_grad() - @auto_move_data - def marginal_ll(self, tensors, n_mc_samples): - sample_batch = tensors[_CONSTANTS.X_KEY] - batch_index = tensors[_CONSTANTS.BATCH_KEY] - - to_sum = torch.zeros(sample_batch.size()[0], n_mc_samples) - - for i in range(n_mc_samples): - # Distribution parameters and sampled variables - inference_outputs, _, losses = self.forward(tensors) - qz_m = inference_outputs["qz_m"] - qz_v = inference_outputs["qz_v"] - z = inference_outputs["z"] - ql_m = 
inference_outputs["ql_m"] - ql_v = inference_outputs["ql_v"] - library = inference_outputs["library"] - - # Reconstruction Loss - reconst_loss = losses.reconstruction_loss - - # Log-probabilities - n_batch = self.library_log_means.shape[1] - local_library_log_means = F.linear( - one_hot(batch_index, n_batch), self.library_log_means - ) - local_library_log_vars = F.linear( - one_hot(batch_index, n_batch), self.library_log_vars - ) - p_l = ( - Normal(local_library_log_means, local_library_log_vars.sqrt()) - .log_prob(library) - .sum(dim=-1) - ) - - p_z = ( - Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v)) - .log_prob(z) - .sum(dim=-1) - ) - p_x_zl = -reconst_loss - q_z_x = Normal(qz_m, qz_v.sqrt()).log_prob(z).sum(dim=-1) - q_l_x = Normal(ql_m, ql_v.sqrt()).log_prob(library).sum(dim=-1) - - to_sum[:, i] = p_z + p_l + p_x_zl - q_z_x - q_l_x - - batch_log_lkl = torch.logsumexp(to_sum, dim=-1) - np.log(n_mc_samples) - log_lkl = torch.sum(batch_log_lkl).item() - return log_lkl diff --git a/pl_vae_scripts_new/HMIDataset.py b/pl_vae_scripts_new/HMIDataset.py new file mode 100644 index 0000000..b463b40 --- /dev/null +++ b/pl_vae_scripts_new/HMIDataset.py @@ -0,0 +1,28 @@ +import os + +import pandas as pd +import scanpy as sc +from torch.utils.data import TensorDataset + + +class HMIDataset(TensorDataset): + def __init__( + self, + h5ad_dir, + ): + """ + Input is a directory with all h5ad files. + h5ad_dir: Directory containing all h5ad files for each image in HMI dataset + transform: Default is None. 
Any transformations to be applied to the h5ad files + """ + self.h5ad_dir = h5ad_dir + self.h5ad_names = pd.DataFrame({"Sample_names": os.listdir(h5ad_dir)}) + + def __len__(self): + return len(os.listdir(self.h5ad_dir)) + + def __getitem__(self, idx): + h5ad_path = os.path.join(self.h5ad_dir, self.h5ad_names.iloc[idx, 0]) + h5ad = sc.read_h5ad(h5ad_path) + + return h5ad diff --git a/pl_vae_scripts_new/ScModeDataloader.py b/pl_vae_scripts_new/ScModeDataloader.py new file mode 100644 index 0000000..329a1d7 --- /dev/null +++ b/pl_vae_scripts_new/ScModeDataloader.py @@ -0,0 +1,93 @@ +import numpy as np +import pandas as pd +import torch +from sklearn.preprocessing import OneHotEncoder, StandardScaler +from torch.utils.data import TensorDataset + + +def sparse_numpy_to_torch(adj_mat): + """Construct sparse torch tensor + Need to do csr -> coo + then follow https://stackoverflow.com/questions/50665141/converting-a-scipy-coo-matrix-to-pytorch-sparse-tensor + """ + adj_mat_coo = adj_mat.tocoo() + + values = adj_mat_coo.data + indices = np.vstack((adj_mat_coo.row, adj_mat_coo.col)) + + i = torch.LongTensor(indices) + v = torch.FloatTensor(values) + shape = adj_mat_coo.shape + + return torch.sparse_coo_tensor(i, v, shape) + + +class ScModeDataloader(TensorDataset): + def __init__(self, adata, scalers=None): + """ + Need to get the following from adata: + Y - NxP mean expression matrix + S - Nx(pC2) correlation matrix + M - Nx7 morphology matrix + scalers: set of data scalers + """ + self.adata = adata + Y = adata.X + S = adata.obsm["correlations"] + M = adata.obsm["morphology"] + + if scalers is None: + self.scalers = {} + self.scalers["Y"] = StandardScaler().fit(Y) + self.scalers["S"] = StandardScaler().fit(S) + self.scalers["M"] = StandardScaler().fit(M) + else: + self.scalers = scalers + + Y = self.scalers["Y"].transform(Y) + S = self.scalers["S"].transform(S) + M = self.scalers["M"].transform(M) + + self.Y = torch.tensor(Y).float() + self.S = torch.tensor(S).float() + 
self.M = torch.tensor(M).float() + + self.samples_onehot = self.one_hot_encoding() + + def __len__(self): + return self.Y.shape[0] + + def one_hot_encoding(self, test=False): + """ + Creates a onehot encoding for samples. + """ + onehotenc = OneHotEncoder() + X = self.adata.obs[["Sample_name"]] + onehot_X = onehotenc.fit_transform(X).toarray() + + df = pd.DataFrame(onehot_X, columns=onehotenc.categories_[0]) + + df = df.reindex(columns=self.adata.obs.Sample_name.unique().tolist()) + + return torch.tensor(df.to_numpy()) + + def get_spatial_context(self): + adj_mat = sparse_numpy_to_torch( + self.adata.obsp["connectivities"] + ) # adjacency matrix + concatenated_features = torch.cat((self.Y, self.S, self.M), 1) + + self.C = torch.smm( # normalize for the number of adjacent cells + adj_mat, concatenated_features + ).to_dense() # spatial context for each cell + + def __getitem__(self, idx): + + return ( + self.Y[idx, :], + self.S[idx, :], + self.M[idx, :], + self.C[idx, :], + self.samples_onehot[idx, :], + idx, + ) diff --git a/pl_vae_scripts_new/pl_vae_run_refact.py b/pl_vae_scripts_new/pl_vae_run_refact.py new file mode 100644 index 0000000..d1265aa --- /dev/null +++ b/pl_vae_scripts_new/pl_vae_run_refact.py @@ -0,0 +1,454 @@ +# import argparse +# import os + +# import time + +# import anndata as ad +# import numpy as np +# import scanpy as sc +# import torch + +# import wandb +# from pl_vae_classes_and_func_refact import * +# from pytorch_lightning import Trainer + +# from scipy.stats.mstats import winsorize +# from ScModeDataloader import ScModeDataloader + +# from sklearn.model_selection import train_test_split + + +# def sparse_numpy_to_torch(adj_mat): +# """Construct sparse torch tensor +# Need to do csr -> coo +# then follow https://stackoverflow.com/questions/50665141/converting-a-scipy-coo-matrix-to-pytorch-sparse-tensor +# """ +# adj_mat_coo = adj_mat.tocoo() + +# values = adj_mat_coo.data +# indices = np.vstack((adj_mat_coo.row, adj_mat_coo.col)) + +# i = 
torch.LongTensor(indices) +# v = torch.FloatTensor(values) +# shape = adj_mat_coo.shape + +# return torch.sparse_coo_tensor(i, v, shape) + + +# parser = argparse.ArgumentParser(description="Run emVAE, em2LVAE, dm2LVAE and dmVAE") + +# parser.add_argument( +# "--input_h5ad", +# type=str, +# required=True, +# help="h5ad file that contains mean expression and correlation information for one or more samples", +# ) + +# parser.add_argument("--lr", type=float, help="Learning rate for VAEs", default=0.001) + +# parser.add_argument( +# "--random_seed", +# type=int, +# required=False, +# help="Random seed for VAE initialization", +# default=1234, +# ) + +# parser.add_argument( +# "--train_ratio", +# type=float, +# help="Ratio of the full dataset to be treated as the training set", +# default=0.75, +# ) + +# parser.add_argument("--subset_to", type=int, help="Data subset size") + +# parser.add_argument( +# "--winsorize", type=int, help="0 or 1 to denote False or True", default=1 +# ) + +# parser.add_argument("--cofactor", type=float, help="Value for cofactor", default=5.0) + +# # parser.add_argument( +# # "--n_proteins", type=int, required=True, help="Number of proteins in the dataset" +# # ) + +# parser.add_argument( +# "--use_weights", type=int, help="0 or 1 to denote False or True", default=0 +# ) + +# parser.add_argument( +# "--apply_arctanh", type=int, help="0 or 1 to denote False or True", default=0 +# ) + +# parser.add_argument("--cohort", type=str, help="Name of cohort", default="None") + +# parser.add_argument("--beta", type=float, help="beta value for B-VAE", default=1.0) + +# parser.add_argument("--n_epochs", type=int, help="number of epochs", default=200) + +# parser.add_argument( +# "--apply_KLwarmup", +# type=int, +# help="0 or 1 as False or True, to apply a KL-warmup scheme, if not, then BETA is used as given", +# default=1, +# ) + +# parser.add_argument( +# "--regress_out_patient", +# type=int, +# help="0 or 1 as False or True, to regress out patient effects, 
default is False", +# default=0, +# ) + +# parser.add_argument( +# "--KL_limit", +# type=float, +# help="Max limit for the coefficient of the KL-Div term", +# default=0.3, +# ) + +# parser.add_argument( +# "--output_dir", type=str, help="Directory to store the outputs", default="." +# ) + +# args = parser.parse_args() + + +# adata = sc.read_h5ad(args.input_h5ad) + +# COFACTOR = args.cofactor + +# RANDOM_SEED = args.random_seed + +# N_EPOCHS = args.n_epochs +# N_HIDDEN = 2 +# HIDDEN_LAYER_SIZE_Eme = 8 +# HIDDEN_LAYER_SIZE_Ecr = 8 +# HIDDEN_LAYER_SIZE_Emr = 8 +# N_SPATIAL_CONTEXT = ( +# HIDDEN_LAYER_SIZE_Eme + HIDDEN_LAYER_SIZE_Ecr + HIDDEN_LAYER_SIZE_Emr +# ) +# HIDDEN_LAYER_SIZE_Esc = 8 # keeping this consistent with the previous Basel analysis + +# LATENT_DIM = 10 +# BATCH_SIZE = 256 +# CELLS_CUTOFF = 500 + +# N_TOTAL_CELLS = adata.shape[0] +# N_PROTEINS = adata.shape[1] +# N_CORRELATIONS = len(adata.uns["names_correlations"]) +# N_MORPHOLOGY = len(adata.uns["names_morphology"]) + +# N_TOTAL_FEATURES = N_PROTEINS + N_CORRELATIONS + N_MORPHOLOGY + +# BETA = args.beta ## beta for beta-vae + +# TRAIN_PROP = args.train_ratio # set the training set ratio + +# lr = args.lr # set the learning rate + +# # log_py = {} +# # elbo_losses = {} + + +# if args.output_dir is not None: +# output_dir = args.output_dir +# if not os.path.exists(output_dir): +# os.makedirs(output_dir) +# else: +# output_dir = "." 
+ + +# # Set up the data +# np.random.seed(RANDOM_SEED) + +# if args.subset_to is not None: +# print("Subsetting samples") +# samples = adata.obs.Sample_name.unique().to_list() +# inds = np.random.choice(samples, args.subset_to) +# adata = adata[adata.obs.Sample_name.isin(inds)] +# else: +# adata = adata + + +# # adata.obs = adata.obs.reset_index() +# # adata.obs.columns = ["index", "Sample_name", "cell_id"] + + +# if adata.X.shape[0] > 705000: +# sample_drop_lst = [] +# for sample in adata.obs[ +# "Sample_name" +# ].unique(): # if sample has less than 500 cells, drop it +# if ( +# adata.obs.query("Sample_name==@sample").shape[0] < CELLS_CUTOFF +# ): # true for <235 samples out of all samples +# sample_drop_lst.append(sample) + +# adata_sub = adata.copy()[ +# ~adata.obs.Sample_name.isin(sample_drop_lst), : +# ] # select all rows except those that belong to samples w cells < CELLS_CUTOFF + +# adata_sub.obs = adata_sub.obs.reset_index() +# if "level_0" in adata_sub.obs.columns: +# adata_sub.obs = adata_sub.obs.drop(columns=["level_0"]) + +# else: +# adata_sub = adata + + +# print("Preprocessing data views") + +# if args.cofactor is not None: +# adata_sub.X = np.arcsinh(adata_sub.X / COFACTOR) + + +# if args.winsorize == 1: +# for i in range(N_PROTEINS): +# adata_sub.X[:, i] = winsorize(adata_sub.X[:, i], limits=[0, 0.01]) +# for i in range(N_MORPHOLOGY): +# adata_sub.obsm["morphology"][:, i] = winsorize( +# adata_sub.obsm["morphology"][:, i], limits=[0, 0.01] +# ) + +# if args.apply_arctanh == 1: +# adata_sub.obsm["correlations"] = np.arctanh(adata_sub.obsm["correlations"]) + + +# adata_sub.obs["Sample_name"] = adata_sub.obs["Sample_name"].astype( +# str +# ) # have to do this otherwise it will contain the ones that were removed + +# if args.regress_out_patient: +# print("Regressing out patient effect") +# sc.pp.regress_out(adata_sub, "Sample_name") + +# samples_list = adata_sub.obs["Sample_name"].unique().tolist() # samples in the adata + + +# device = 
torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +# train_size = int(np.floor(len(samples_list) * TRAIN_PROP)) +# test_size = len(samples_list) - train_size + +# # separate images/samples as train or test *only* (this is different from before, when we separated cells into train/test) + +# print("Setting up train and test data") + +# samples_train, samples_test = train_test_split( +# samples_list, train_size=TRAIN_PROP, random_state=RANDOM_SEED +# ) + +# adata_train = adata_sub.copy()[adata_sub.obs["Sample_name"].isin(samples_train), :] +# adata_test = adata_sub.copy()[adata_sub.obs["Sample_name"].isin(samples_test), :] + + +# data_train = ScModeDataloader(adata_train) +# data_test = ScModeDataloader(adata_test, data_train.scalers) + +# loader_train = DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=True) +# loader_test = DataLoader(data_test, batch_size=BATCH_SIZE, shuffle=True) + + +# model = hmiVAE( +# N_PROTEINS, +# N_CORRELATIONS, +# N_MORPHOLOGY, +# N_SPATIAL_CONTEXT, +# HIDDEN_LAYER_SIZE_Eme, +# HIDDEN_LAYER_SIZE_Ecr, +# HIDDEN_LAYER_SIZE_Emr, +# HIDDEN_LAYER_SIZE_Esc, +# LATENT_DIM, +# ) + +# trainer = Trainer() + +# wandb.init( +# project="vae_new_morphs", +# entity="sayub", +# config={ +# "learning_rate": lr, +# "epochs": N_EPOCHS, +# "batch_size": BATCH_SIZE, +# "use_weights": int(corr_weights), +# "use_arctanh": int(args.apply_arctanh), +# "n_cells": len(data_train), +# "method": "emVAE", +# "cohort": args.cohort, +# "BETA": BETA, +# "cofactor": COFACTOR, +# "regress_out_patient": args.regress_out_patient, +# "apply_KL_warmup": args.apply_KLwarmup, +# "KL_max_limit": args.KL_limit, +# }, +# ) + + +# # start_time_em = time.time() + + +# # ## Reconstruction weights +# # r = np.array([1., 1., 1., 0.]) + +# # for n in range(N_EPOCHS): +# # if args.apply_KLwarmup: +# # if n>5: +# # new_beta = BETA + 0.05 +# # BETA = min(new_beta, args.KL_limit) +# # else: +# # BETA = BETA + +# # if n > 5: +# # spcont_r = r[3]+0.1 +# # r[3] = min(spcont_r, 
1.0) +# # if n < 30: +# # print(r) +# # #r = np.array([1., 1., 1., 1.]) + +# # optimizer_em.zero_grad() + +# # train_losses_em = 0 +# # num_batches_em = 0 + +# for num, batch in enumerate(loader_train): +# data = { +# "Y": batch[0], +# "S": batch[1], +# "M": batch[2], +# "A": spatial_context[batch[-1], :], +# } + +# em_train_loss = train_test_run( +# data=data, +# spatial_context=data["A"], +# method="EM", +# method_enc=enc_em, +# method_dec=dec_em, +# n_proteins=N_PROTEINS, +# latent_dim=LATENT_DIM, +# corr_weights=corr_weights, +# recon_weights=r, +# beta=BETA, +# ) + +# ## Update gradients and weights +# em_train_loss[0].backward() + +# torch.nn.utils.clip_grad_norm_(parameters, 2.0) + +# optimizer_em.step() + +# train_losses_em += em_train_loss[0].detach().item() +# num_batches_em += 1 + + +# # Update old mean embeddings (once per epoch) +# with torch.no_grad(): +# mu_z_old, _, z1 = enc_em(data_train.Y, data_train.S, data_train.M, z1) # mu_z_old) + +# wandb.log( +# { +# "train_neg_elbo": train_losses_em / num_batches_em, +# "kl_div": em_train_loss[1], +# "recon_lik": em_train_loss[2], +# "recon_lik_me": em_train_loss[3], +# "recon_lik_corr": em_train_loss[4], +# "recon_lik_mor": em_train_loss[5], +# "recon_lik_spcont": em_train_loss[6], +# "mu_z_max": em_train_loss[8], +# "log_std_z_max": em_train_loss[9], +# "mu_z_min": em_train_loss[10], +# "log_std_z_min": em_train_loss[11], +# "mu_x_exp_hat_max": em_train_loss[12], +# "log_std_x_exp_hat_max": em_train_loss[13], +# "mu_x_exp_hat_min": em_train_loss[14], +# "log_std_x_exp_hat_min": em_train_loss[15], +# "mu_x_corr_hat_max": em_train_loss[16], +# "log_std_x_corr_hat_max": em_train_loss[17], +# "mu_x_corr_hat_min": em_train_loss[18], +# "log_std_x_corr_hat_min": em_train_loss[19], +# "mu_x_morph_hat_max": em_train_loss[20], +# "log_std_x_morph_hat_max": em_train_loss[21], +# "mu_x_morph_hat_min": em_train_loss[22], +# "log_std_x_morph_hat_min": em_train_loss[23], +# "mu_x_spcont_hat_max": em_train_loss[24], +# 
"log_std_x_spcont_hat_max": em_train_loss[25], +# "mu_x_spcont_hat_min": em_train_loss[26], +# "log_std_x_spcont_hat_min": em_train_loss[27], +# } +# ) + +# # Now compute test metrics + +# with torch.no_grad(): +# spatial_context_test = torch.smm( +# adj_mat_test_tensor, z1_test # mu_z_old_test +# ).to_dense() + +# test_data = { +# "Y": data_test.Y, +# "S": data_test.S, +# "M": data_test.M, +# "A": spatial_context_test, +# } + +# em_test_loss = train_test_run( +# data=test_data, +# spatial_context=test_data["A"], +# method="EM", +# method_enc=enc_em, +# method_dec=dec_em, +# n_proteins=N_PROTEINS, +# latent_dim=LATENT_DIM, +# corr_weights=corr_weights, +# recon_weights=r, +# beta=BETA, +# ) + +# # mu_z_old_test = em_test_loss[7] +# z1_test = em_test_loss[7] + +# wandb.log( +# { +# "test_neg_elbo": em_test_loss[0], +# "test_kl_div": em_test_loss[1], +# "test_recon_lik": em_test_loss[2], +# "test_recon_lik_me": em_test_loss[3], +# "test_recon_lik_corr": em_test_loss[4], +# "test_recon_lik_mor": em_test_loss[5], +# "test_recon_lik_spcont": em_test_loss[6], +# } +# ) + +# stop_time_em = time.time() +# wandb.finish() +# print("emVAE done training") + +# # adata_train.obsm['emVAE_final_spatial_context'] = mu_z_old_train_em.detach().numpy() +# # adata_test.obsm['emVAE_final_spatial_context'] = mu_z_old_test_em.detach().numpy() + +# # eval_em_spatial_context= torch.smm( +# # adj_mat_test_tensor, mu_z_old_test_em +# # ).to_dense() + +# with torch.no_grad(): +# mu_z, _, z1 = enc_em(data_train.Y, data_train.S, data_train.M, z1) +# mu_z_test, _, z1_test = enc_em(data_test.Y, data_test.S, data_test.M, z1_test) + +# adata_train.obsm["VAE"] = mu_z.numpy() +# adata_test.obsm["VAE"] = mu_z_test.numpy() + +# adata_train.obsm["spatial_context"] = z1.numpy() +# adata_test.obsm["spatial_context"] = z1_test.numpy() + +# adata_train.write(os.path.join(output_dir, "adata_train.h5ad")) +# adata_test.write(os.path.join(output_dir, "adata_test.h5ad")) + +# torch.save(enc_em.state_dict(), 
os.path.join(output_dir, "emVAE_encoder.pt")) +# torch.save(dec_em.state_dict(), os.path.join(output_dir, "emVAE_decoder.pt")) + + +# print("All done!") From 5bbb2a230278ea1b1b64b43d1f7d797629d2bbe0 Mon Sep 17 00:00:00 2001 From: shanzaayub Date: Mon, 13 Jun 2022 15:22:57 -0400 Subject: [PATCH 02/18] changed to super().__init__() in hmivaeModel --- .../ScModeDataloader.py | 0 hmivae/_hmivae_model.py | 143 +++++++-- hmivae/_hmivae_module.py | 303 ------------------ 3 files changed, 115 insertions(+), 331 deletions(-) rename {pl_vae_scripts_new => hmivae}/ScModeDataloader.py (100%) diff --git a/pl_vae_scripts_new/ScModeDataloader.py b/hmivae/ScModeDataloader.py similarity index 100% rename from pl_vae_scripts_new/ScModeDataloader.py rename to hmivae/ScModeDataloader.py diff --git a/hmivae/_hmivae_model.py b/hmivae/_hmivae_model.py index 1d52197..b9253c9 100644 --- a/hmivae/_hmivae_model.py +++ b/hmivae/_hmivae_model.py @@ -1,18 +1,27 @@ import logging from typing import List, Optional +import numpy as np +import pytorch_lightning as pl +from _hmivae_module import hmiVAE from anndata import AnnData -from scvi.data import setup_anndata -from scvi.model._utils import _init_library_size -from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin, VAEMixin -from scvi.utils import setup_anndata_dsp +from pytorch_lightning.trainer import Trainer +from scipy.stats.mstats import winsorize +from ScModeDataloader import ScModeDataloader +from sklearn.model_selection import train_test_split +from torch.utils.data import DataLoader + +# from scvi.data import setup_anndata +# from scvi.model._utils import _init_library_size +# from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin, VAEMixin +# from scvi.utils import setup_anndata_dsp -from ._hmivae_module import hmiVAE logger = logging.getLogger(__name__) -class hmivaeModel(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass): +# class hmivaeModel(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass): 
+class hmivaeModel(pl.LightningModule): """ Skeleton for an scvi-tools model. @@ -42,52 +51,96 @@ class hmivaeModel(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass): def __init__( self, adata: AnnData, - n_hidden: int = 128, - n_latent: int = 10, - n_layers: int = 1, + input_exp_dim: int, + input_corr_dim: int, + input_morph_dim: int, + input_spcont_dim: int, + E_me: int = 32, + E_cr: int = 32, + E_mr: int = 32, + E_sc: int = 32, + latent_dim: int = 10, + n_hidden: int = 1, **model_kwargs, ): - super(hmiVAE, self).__init__(adata) + # super(hmivaeModel, self).__init__(adata) + super().__init__() + + # library_log_means, library_log_vars = _init_library_size( + # adata, self.summary_stats["n_batch"] + # ) - library_log_means, library_log_vars = _init_library_size( - adata, self.summary_stats["n_batch"] + self.train_batch, self.test_batch = self.setup_anndata( + adata=adata, + protein_correlations_obsm_key="correlations", + cell_morphology_obsm_key="morphology", ) # self.summary_stats provides information about anndata dimensions and other tensor info self.module = hmiVAE( - n_input=self.summary_stats["n_vars"], + input_exp_dim=input_exp_dim, + input_corr_dim=input_corr_dim, + input_morph_dim=input_morph_dim, + input_spcont_dim=input_spcont_dim, + E_me=E_me, + E_cr=E_cr, + E_mr=E_mr, + E_sc=E_sc, + latent_dim=latent_dim, n_hidden=n_hidden, - n_latent=n_latent, - n_layers=n_layers, - library_log_means=library_log_means, - library_log_vars=library_log_vars, **model_kwargs, ) - self._model_summary_string = "Overwrite this attribute to get an informative representation for your model" + self._model_summary_string = ( + "hmiVAE model with the following parameters: \nn_latent:{}" + "n_protein_expression:{}, n_correlation:{}, n_morphology:{}, n_spatial_context:{}" + ).format( + latent_dim, + input_exp_dim, + input_corr_dim, + input_morph_dim, + input_spcont_dim, + ) # necessary line to get params that will be used for saving/loading self.init_params_ = 
self._get_init_params(locals()) logger.info("The model has been initialized") + def train(self): + + trainer = Trainer() + + trainer.fit(self.module, self.train_batch) # training, add wandb + trainer.test(dataloaders=self.test_batch) # test, add wandb + + return trainer() + + # @setup_anndata_dsp.dedent @staticmethod - @setup_anndata_dsp.dedent def setup_anndata( + # self, adata: AnnData, protein_correlations_obsm_key: str, cell_morphology_obsm_key: str, - cell_spatial_context_obsm_key: str, + # cell_spatial_context_obsm_key: str, protein_correlations_names_uns_key: Optional[str] = None, cell_morphology_names_uns_key: Optional[str] = None, + batch_size: Optional[int] = 128, batch_key: Optional[str] = None, labels_key: Optional[str] = None, layer: Optional[str] = None, categorical_covariate_keys: Optional[List[str]] = None, continuous_covariate_keys: Optional[List[str]] = None, + cofactor: float = 1.0, + train_prop: Optional[float] = 0.75, + apply_winsorize: Optional[bool] = True, + arctanh_corrs: Optional[bool] = False, + random_seed: Optional[int] = 1234, copy: bool = False, ) -> Optional[AnnData]: """ %(summary)s. + Takes in an AnnData object and returns the train and test loaders. 
Parameters ---------- %(param_adata)s @@ -102,12 +155,46 @@ def setup_anndata( ------- %(returns)s """ - return setup_anndata( - adata, - batch_key=batch_key, - labels_key=labels_key, - layer=layer, - categorical_covariate_keys=categorical_covariate_keys, - continuous_covariate_keys=continuous_covariate_keys, - copy=copy, + # N_TOTAL_CELLS = adata.shape[0] + N_PROTEINS = adata.shape[1] + # N_CORRELATIONS = len(adata.uns["names_correlations"]) + N_MORPHOLOGY = len(adata.uns["names_morphology"]) + + # N_TOTAL_FEATURES = N_PROTEINS + N_CORRELATIONS + N_MORPHOLOGY + # if cofactor is not None: + adata.X = np.arcsinh(adata.X / cofactor) + + if apply_winsorize: + for i in range(N_PROTEINS): + adata.X[:, i] = winsorize(adata.X[:, i], limits=[0, 0.01]) + for i in range(N_MORPHOLOGY): + adata.obsm[cell_morphology_obsm_key][:, i] = winsorize( + adata.obsm[cell_morphology_obsm_key][:, i], limits=[0, 0.01] + ) + + if arctanh_corrs: + adata.obsm[protein_correlations_obsm_key] = np.arctanh( + adata.obsm[protein_correlations_obsm_key] + ) + + samples_list = ( + adata.obs["Sample_name"].unique().tolist() + ) # samples in the adata + + # train_size = int(np.floor(len(samples_list) * train_prop)) + # test_size = len(samples_list) - train_size + + samples_train, samples_test = train_test_split( + samples_list, train_size=train_prop, random_state=random_seed ) + + adata_train = adata.copy()[adata.obs["Sample_name"].isin(samples_train), :] + adata_test = adata.copy()[adata.obs["Sample_name"].isin(samples_test), :] + + data_train = ScModeDataloader(adata_train) + data_test = ScModeDataloader(adata_test, data_train.scalers) + + loader_train = DataLoader(data_train, batch_size=batch_size, shuffle=True) + loader_test = DataLoader(data_test, batch_size=batch_size, shuffle=True) + + return loader_train, loader_test diff --git a/hmivae/_hmivae_module.py b/hmivae/_hmivae_module.py index f655652..cf301db 100644 --- a/hmivae/_hmivae_module.py +++ b/hmivae/_hmivae_module.py @@ -6,310 +6,9 @@ 
import torch.nn.functional as F from _hmivae_base_components import DecoderHMIVAE, EncoderHMIVAE -# from scvi import _CONSTANTS -# from scvi.distributions import ZeroInflatedNegativeBinomial -# from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data -# from scvi.nn import one_hot -# from torch.distributions import Normal -# from torch.distributions import kl_divergence as kl - torch.backends.cudnn.benchmark = True -# class HMIVAE(BaseModuleClass): -# """ -# Variational auto-encoder model. - -# Here we implement a basic version of scVI's underlying VAE [Lopez18]_. -# This implementation is for instructional purposes only. - -# Parameters -# ---------- -# n_input -# Number of input genes -# library_log_means -# 1 x n_batch array of means of the log library sizes. Parameterizes prior on library size if -# not using observed library size. -# library_log_vars -# 1 x n_batch array of variances of the log library sizes. Parameterizes prior on library size if -# not using observed library size. -# n_batch -# Number of batches, if 0, no batch correction is performed. 
-# n_hidden -# Number of nodes per hidden layer -# n_latent -# Dimensionality of the latent space -# n_layers -# Number of hidden layers used for encoder and decoder NNs -# dropout_rate -# Dropout rate for neural networks -# """ - -# def __init__( -# self, -# n_input: int, -# n_batch: int = 0, -# n_hidden: int = 128, -# n_latent: int = 10, -# n_layers: int = 1, -# dropout_rate: float = 0.1, -# ): -# # def __init__( -# # self, -# # input_exp_dim: int, -# # input_corr_dim: int, -# # input_morph_dim: int, -# # input_spcont_dim: int, -# # E_me: int, -# # E_cr: int, -# # E_mr: int, -# # E_sc: int, -# # n_latent: int = 10, -# # n_batch: int = 0, -# # n_hidden: int = 1, -# # ): -# super().__init__() -# self.n_latent = n_latent -# self.n_batch = n_batch -# # this is needed to comply with some requirement of the VAEMixin class -# self.latent_distribution = "normal" - -# self.register_buffer( -# "library_log_means", torch.from_numpy(library_log_means).float() -# ) -# self.register_buffer( -# "library_log_vars", torch.from_numpy(library_log_vars).float() -# ) - -# # setup the parameters of your generative model, as well as your inference model -# self.px_r = torch.nn.Parameter(torch.randn(n_input)) -# # z encoder goes from the n_input-dimensional data to an n_latent-d -# # latent space representation -# self.z_encoder = EncoderHMIVAE( -# n_input, -# n_latent, -# n_layers=n_layers, -# n_hidden=n_hidden, -# dropout_rate=dropout_rate, -# ) -# # l encoder goes from n_input-dimensional data to 1-d library size -# self.l_encoder = EncoderHMIVAE( -# n_input, -# 1, -# n_layers=1, -# n_hidden=n_hidden, -# dropout_rate=dropout_rate, -# ) -# # decoder goes from n_latent-dimensional space to n_input-d data -# self.decoder = DecoderHMIVAE( -# n_latent, -# n_input, -# n_layers=n_layers, -# n_hidden=n_hidden, -# ) - -# def _get_inference_input(self, tensors): -# """Parse the dictionary to get appropriate args""" -# x = tensors[_CONSTANTS.X_KEY] - -# input_dict = dict(x=x) -# return 
input_dict - -# def _get_generative_input(self, tensors, inference_outputs): -# z = inference_outputs["z"] -# library = inference_outputs["library"] - -# input_dict = { -# "z": z, -# "library": library, -# } -# return input_dict - -# @auto_move_data -# def inference(self, x): -# """ -# High level inference method. - -# Runs the inference (encoder) model. -# """ -# # log the input to the variational distribution for numerical stability -# x_ = torch.log(1 + x) -# # get variational parameters via the encoder networks -# qz_m, qz_v, z = self.z_encoder(x_) -# ql_m, ql_v, library = self.l_encoder(x_) - -# outputs = dict(z=z, qz_m=qz_m, qz_v=qz_v, ql_m=ql_m, ql_v=ql_v, library=library) -# return outputs - -# @auto_move_data -# def generative(self, z, library): -# """Runs the generative model.""" - -# # form the parameters of the ZINB likelihood -# px_scale, _, px_rate, px_dropout = self.decoder("gene", z, library) -# px_r = torch.exp(self.px_r) - -# return dict( -# px_scale=px_scale, px_r=px_r, px_rate=px_rate, px_dropout=px_dropout -# ) - -# def loss( -# self, -# tensors, -# inference_outputs, -# generative_outputs, -# kl_weight: float = 1.0, -# ): -# x = tensors[_CONSTANTS.X_KEY] -# qz_m = inference_outputs["qz_m"] -# qz_v = inference_outputs["qz_v"] -# ql_m = inference_outputs["ql_m"] -# ql_v = inference_outputs["ql_v"] -# px_rate = generative_outputs["px_rate"] -# px_r = generative_outputs["px_r"] -# px_dropout = generative_outputs["px_dropout"] - -# mean = torch.zeros_like(qz_m) -# scale = torch.ones_like(qz_v) - -# kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum( -# dim=1 -# ) - -# batch_index = tensors[_CONSTANTS.BATCH_KEY] -# n_batch = self.library_log_means.shape[1] -# local_library_log_means = F.linear( -# one_hot(batch_index, n_batch), self.library_log_means -# ) -# local_library_log_vars = F.linear( -# one_hot(batch_index, n_batch), self.library_log_vars -# ) - -# kl_divergence_l = kl( -# Normal(ql_m, torch.sqrt(ql_v)), -# 
Normal(local_library_log_means, torch.sqrt(local_library_log_vars)), -# ).sum(dim=1) - -# reconst_loss = ( -# -ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout) -# .log_prob(x) -# .sum(dim=-1) -# ) - -# kl_local_for_warmup = kl_divergence_z -# kl_local_no_warmup = kl_divergence_l - -# weighted_kl_local = kl_weight * kl_local_for_warmup + kl_local_no_warmup - -# loss = torch.mean(reconst_loss + weighted_kl_local) - -# kl_local = dict( -# kl_divergence_l=kl_divergence_l, kl_divergence_z=kl_divergence_z -# ) -# kl_global = torch.tensor(0.0) -# return LossRecorder(loss, reconst_loss, kl_local, kl_global) - -# @torch.no_grad() -# def sample( -# self, -# tensors, -# n_samples=1, -# library_size=1, -# ) -> np.ndarray: -# r""" -# Generate observation samples from the posterior predictive distribution. - -# The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`. - -# Parameters -# ---------- -# tensors -# Tensors dict -# n_samples -# Number of required samples for each cell -# library_size -# Library size to scale scamples to - -# Returns -# ------- -# x_new : :py:class:`torch.Tensor` -# tensor with shape (n_cells, n_genes, n_samples) -# """ -# inference_kwargs = dict(n_samples=n_samples) -# _, generative_outputs, = self.forward( -# tensors, -# inference_kwargs=inference_kwargs, -# compute_loss=False, -# ) - -# px_r = generative_outputs["px_r"] -# px_rate = generative_outputs["px_rate"] -# px_dropout = generative_outputs["px_dropout"] - -# dist = ZeroInflatedNegativeBinomial( -# mu=px_rate, theta=px_r, zi_logits=px_dropout -# ) - -# if n_samples > 1: -# exprs = dist.sample().permute( -# [1, 2, 0] -# ) # Shape : (n_cells_batch, n_genes, n_samples) -# else: -# exprs = dist.sample() - -# return exprs.cpu() - -# @torch.no_grad() -# @auto_move_data -# def marginal_ll(self, tensors, n_mc_samples): -# sample_batch = tensors[_CONSTANTS.X_KEY] -# batch_index = tensors[_CONSTANTS.BATCH_KEY] - -# to_sum = 
torch.zeros(sample_batch.size()[0], n_mc_samples) - -# for i in range(n_mc_samples): -# # Distribution parameters and sampled variables -# inference_outputs, _, losses = self.forward(tensors) -# qz_m = inference_outputs["qz_m"] -# qz_v = inference_outputs["qz_v"] -# z = inference_outputs["z"] -# ql_m = inference_outputs["ql_m"] -# ql_v = inference_outputs["ql_v"] -# library = inference_outputs["library"] - -# # Reconstruction Loss -# reconst_loss = losses.reconstruction_loss - -# # Log-probabilities -# n_batch = self.library_log_means.shape[1] -# local_library_log_means = F.linear( -# one_hot(batch_index, n_batch), self.library_log_means -# ) -# local_library_log_vars = F.linear( -# one_hot(batch_index, n_batch), self.library_log_vars -# ) -# p_l = ( -# Normal(local_library_log_means, local_library_log_vars.sqrt()) -# .log_prob(library) -# .sum(dim=-1) -# ) - -# p_z = ( -# Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v)) -# .log_prob(z) -# .sum(dim=-1) -# ) -# p_x_zl = -reconst_loss -# q_z_x = Normal(qz_m, qz_v.sqrt()).log_prob(z).sum(dim=-1) -# q_l_x = Normal(ql_m, ql_v.sqrt()).log_prob(library).sum(dim=-1) - -# to_sum[:, i] = p_z + p_l + p_x_zl - q_z_x - q_l_x - -# batch_log_lkl = torch.logsumexp(to_sum, dim=-1) - np.log(n_mc_samples) -# log_lkl = torch.sum(batch_log_lkl).item() -# return log_lkl - - class hmiVAE(pl.LightningModule): """ Variational Autoencoder for hmiVAE based on pytorch-lightning. 
@@ -496,7 +195,6 @@ def loss(self, kl_div, recon_loss, beta: float = 1.0): def training_step( self, train_batch, - spatial_context, batch_idx, categories: Optional[Iterable[int]] = None, corr_weights=False, @@ -582,7 +280,6 @@ def training_step( def test_step( self, test_batch, - spatial_context, batch_idx, corr_weights=False, recon_weights=np.array([1.0, 1.0, 1.0, 1.0]), From 820b212b9467d2862e048657f25f4cdbe0fe9845 Mon Sep 17 00:00:00 2001 From: shanzaayub Date: Mon, 13 Jun 2022 17:25:28 -0400 Subject: [PATCH 03/18] version without one-hot encoding running --- hmivae/ScModeDataloader.py | 4 +- hmivae/_hmivae_base_components.py | 27 +++++++------ hmivae/_hmivae_model.py | 11 ++--- hmivae/_hmivae_module.py | 67 ++++++++++++++++--------------- 4 files changed, 58 insertions(+), 51 deletions(-) diff --git a/hmivae/ScModeDataloader.py b/hmivae/ScModeDataloader.py index 329a1d7..bbb4d69 100644 --- a/hmivae/ScModeDataloader.py +++ b/hmivae/ScModeDataloader.py @@ -51,6 +51,7 @@ def __init__(self, adata, scalers=None): self.Y = torch.tensor(Y).float() self.S = torch.tensor(S).float() self.M = torch.tensor(M).float() + self.C = self.get_spatial_context() self.samples_onehot = self.one_hot_encoding() @@ -77,9 +78,10 @@ def get_spatial_context(self): ) # adjacency matrix concatenated_features = torch.cat((self.Y, self.S, self.M), 1) - self.C = torch.smm( # normalize for the number of adjacent cells + C = torch.smm( # normalize for the number of adjacent cells adj_mat, concatenated_features ).to_dense() # spatial context for each cell + return C def __getitem__(self, idx): diff --git a/hmivae/_hmivae_base_components.py b/hmivae/_hmivae_base_components.py index d19f920..1aa92aa 100644 --- a/hmivae/_hmivae_base_components.py +++ b/hmivae/_hmivae_base_components.py @@ -38,10 +38,8 @@ def __init__( self.input_corr = nn.Linear(input_corr_dim, E_cr) self.corr_hidden = nn.Linear(E_cr, E_cr) - self.input_morph = nn.Linear(input_morph_dim, E_mr) self.morph_hidden = nn.Linear(E_mr, 
E_mr) - self.input_spatial_context = nn.Linear(input_spcont_dim, E_sc) self.spatial_context_hidden = nn.Linear(E_sc, E_sc) @@ -62,7 +60,7 @@ def forward(self, x_mean, x_correlations, x_morphology, x_spatial_context): h_morphology = F.elu(self.input_morph(x_morphology)) h_morphology2 = F.elu(self.morph_hidden(h_morphology)) - z1 = torch.cat([h_mean2, h_correlations2, h_morphology2], 1) + # z1 = torch.cat([h_mean2, h_correlations2, h_morphology2], 1) h_spatial_context = F.elu(self.input_spatial_context(x_spatial_context)) h_spatial_context2 = F.elu(self.spatial_context_hidden(h_spatial_context)) @@ -76,7 +74,7 @@ def forward(self, x_mean, x_correlations, x_morphology, x_spatial_context): log_std_z = self.std_z(h) - return mu_z, log_std_z, z1 + return mu_z, log_std_z class DecoderHMIVAE(nn.Module): @@ -109,6 +107,11 @@ def __init__( ): super().__init__() hidden_dim = E_me + E_cr + E_mr + E_sc + self.E_me = E_me + self.E_cr = E_cr + self.E_mr = E_mr + self.E_sc = E_sc + self.input = nn.Linear(latent_dim, hidden_dim) self.linear = nn.ModuleList( [nn.Linear(hidden_dim, hidden_dim) for i in range(n_hidden)] ) @@ -153,13 +156,13 @@ def forward(self, z): mu_x_exp = self.mu_x_exp(h2_mean) std_x_exp = self.std_x_exp(h2_mean) - if self.use_weights: - with torch.no_grad(): - weights = self.get_corr_weights_per_cell( - mu_x_exp.detach() - ) # calculating correlation weights - else: - weights = None + # if self.use_weights: + # with torch.no_grad(): + # weights = self.get_corr_weights_per_cell( + # mu_x_exp.detach() + # ) # calculating correlation weights + # else: + # weights = None mu_x_corr = self.mu_x_corr(h2_correlations) std_x_corr = self.std_x_corr(h2_correlations) @@ -179,5 +182,5 @@ def forward(self, z): std_x_morph, mu_x_spatial_context, std_x_spatial_context, - weights, + # weights, ) diff --git a/hmivae/_hmivae_model.py b/hmivae/_hmivae_model.py index b9253c9..193c445 100644 --- a/hmivae/_hmivae_model.py +++ b/hmivae/_hmivae_model.py @@ -77,7 +77,6 @@ def __init__( ) 
# self.summary_stats provides information about anndata dimensions and other tensor info - self.module = hmiVAE( input_exp_dim=input_exp_dim, input_corr_dim=input_corr_dim, @@ -102,18 +101,20 @@ def __init__( input_spcont_dim, ) # necessary line to get params that will be used for saving/loading - self.init_params_ = self._get_init_params(locals()) + # self.init_params_ = self._get_init_params(locals()) logger.info("The model has been initialized") - def train(self): + def train( + self, + ): # misnomer, both train and test are here (either rename or separate) - trainer = Trainer() + trainer = Trainer(max_epochs=10) trainer.fit(self.module, self.train_batch) # training, add wandb trainer.test(dataloaders=self.test_batch) # test, add wandb - return trainer() + # return trainer # @setup_anndata_dsp.dedent @staticmethod diff --git a/hmivae/_hmivae_module.py b/hmivae/_hmivae_module.py index cf301db..b0bcc13 100644 --- a/hmivae/_hmivae_module.py +++ b/hmivae/_hmivae_module.py @@ -95,7 +95,7 @@ def em_recon_loss( s, m, c, - weights=None, + # weights=None, ): """Takes in the parameters output from the decoder, and the original input x, and gives the reconstruction @@ -127,15 +127,16 @@ def em_recon_loss( ) log_p_xz_exp = p_rec_exp.log_prob(y) + log_p_xz_corr = p_rec_corr.log_prob(s) log_p_xz_morph = p_rec_morph.log_prob(m) log_p_xz_spcont = p_rec_spcont.log_prob(c) # already dense matrix - if weights is None: - log_p_xz_corr = p_rec_corr.log_prob(s) - else: - log_p_xz_corr = torch.mul( - weights, p_rec_corr.log_prob(s) - ) # does element-wise multiplication + # if weights is None: + # log_p_xz_corr = p_rec_corr.log_prob(s) + # else: + # log_p_xz_corr = torch.mul( + # weights, p_rec_corr.log_prob(s) + # ) # does element-wise multiplication log_p_xz_exp = log_p_xz_exp.sum(-1) log_p_xz_corr = log_p_xz_corr.sum(-1) @@ -161,7 +162,7 @@ def neg_ELBO( s, m, c, - weights=None, + # weights=None, ): kl_div = self.KL_div(enc_x_mu, enc_x_logstd, z) @@ -178,7 +179,7 @@ def neg_ELBO( 
s, m, c, - weights, + # weights, ) return ( kl_div, @@ -228,7 +229,7 @@ def training_step( log_std_x_morph_hat, mu_x_spcont_hat, log_std_x_spcont_hat, - weights, + # weights, ) = self.decoder(z_samples) ( @@ -253,7 +254,7 @@ def training_step( S, M, spatial_context, - weights, + # weights, ) recon_loss = ( @@ -267,15 +268,15 @@ def training_step( self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True) - return ( - loss, - kl_div.mean().item(), - recon_loss.mean().item(), - recon_lik_me.mean().item(), - recon_lik_corr.mean().item(), - recon_lik_mor.mean().item(), - recon_lik_sc.mean().item(), - ) + return { + "loss": loss, + "kl_div": kl_div.mean().item(), + "recon_loss": recon_loss.mean().item(), + "recon_lik_me": recon_lik_me.mean().item(), + "recon_lik_corr": recon_lik_corr.mean().item(), + "recon_lik_mor": recon_lik_mor.mean().item(), + "recon_lik_sc": recon_lik_sc.mean().item(), + } def test_step( self, @@ -298,7 +299,7 @@ def test_step( M = test_batch[2] spatial_context = test_batch[3] - mu_z, log_std_z, z1 = self.encode(Y, S, M, spatial_context) + mu_z, log_std_z = self.encoder(Y, S, M, spatial_context) z_samples = self.reparameterization(mu_z, log_std_z) @@ -312,8 +313,8 @@ def test_step( log_std_x_morph_hat, mu_x_spcont_hat, log_std_x_spcont_hat, - weights, - ) = self.decode(z_samples) + # weights, + ) = self.decoder(z_samples) ( kl_div, @@ -337,7 +338,7 @@ def test_step( S, M, spatial_context, - weights, + # weights, ) recon_loss = ( @@ -351,15 +352,15 @@ def test_step( self.log("test_loss", loss, on_step=True, on_epoch=True, prog_bar=True) - return ( - loss, - kl_div.mean().item(), - recon_loss.mean().item(), - recon_lik_me.mean().item(), - recon_lik_corr.mean().item(), - recon_lik_mor.mean().item(), - recon_lik_sc.mean().item(), - ) + return { + "loss": loss, + "kl_div": kl_div.mean().item(), + "recon_loss": recon_loss.mean().item(), + "recon_lik_me": recon_lik_me.mean().item(), + "recon_lik_corr": recon_lik_corr.mean().item(), + 
"recon_lik_mor": recon_lik_mor.mean().item(), + "recon_lik_sc": recon_lik_sc.mean().item(), + } def configure_optimizers(self): """Optimizer""" From 1711e6f23284351aaf07adc2e5aa3e82816a12f2 Mon Sep 17 00:00:00 2001 From: shanzaayub Date: Mon, 20 Jun 2022 14:24:28 -0400 Subject: [PATCH 04/18] one-hot encoding added to train and test v1 --- hmivae/ScModeDataloader.py | 2 +- hmivae/_hmivae_base_components.py | 95 ++++++++---- hmivae/_hmivae_model.py | 62 ++++++-- hmivae/_hmivae_module.py | 241 ++++++++++++++++++++++-------- 4 files changed, 292 insertions(+), 108 deletions(-) diff --git a/hmivae/ScModeDataloader.py b/hmivae/ScModeDataloader.py index bbb4d69..71db37c 100644 --- a/hmivae/ScModeDataloader.py +++ b/hmivae/ScModeDataloader.py @@ -70,7 +70,7 @@ def one_hot_encoding(self, test=False): df = df.reindex(columns=self.adata.obs.Sample_name.unique().tolist()) - return torch.tensor(df.to_numpy()) + return torch.tensor(df.to_numpy()).float() def get_spatial_context(self): adj_mat = sparse_numpy_to_torch( diff --git a/hmivae/_hmivae_base_components.py b/hmivae/_hmivae_base_components.py index 1aa92aa..2dee753 100644 --- a/hmivae/_hmivae_base_components.py +++ b/hmivae/_hmivae_base_components.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch import torch.nn as nn import torch.nn.functional as F @@ -19,19 +21,22 @@ class EncoderHMIVAE(nn.Module): def __init__( self, - input_exp_dim, - input_corr_dim, - input_morph_dim, - input_spcont_dim, - E_me, - E_cr, - E_mr, - E_sc, - latent_dim, - n_hidden=1, + input_exp_dim: int, + input_corr_dim: int, + input_morph_dim: int, + input_spcont_dim: int, + E_me: int, + E_cr: int, + E_mr: int, + E_sc: int, + latent_dim: int, + n_covariates: Optional[int] = 0, + n_hidden: Optional[int] = 1, ): super().__init__() - hidden_dim = E_me + E_cr + E_mr + E_sc + hidden_dim = E_me + E_cr + E_mr + E_sc + n_covariates + + self.input_cov = nn.Linear(n_covariates, n_covariates) self.input_exp = nn.Linear(input_exp_dim, E_me) 
self.exp_hidden = nn.Linear(E_me, E_me) @@ -50,7 +55,14 @@ def __init__( self.mu_z = nn.Linear(hidden_dim, latent_dim) self.std_z = nn.Linear(hidden_dim, latent_dim) - def forward(self, x_mean, x_correlations, x_morphology, x_spatial_context): + def forward( + self, + x_mean: torch.Tensor, + x_correlations: torch.Tensor, + x_morphology: torch.Tensor, + x_spatial_context: torch.Tensor, + cov_list=torch.Tensor([]), + ): h_mean = F.elu(self.input_exp(x_mean)) h_mean2 = F.elu(self.exp_hidden(h_mean)) @@ -60,12 +72,14 @@ def forward(self, x_mean, x_correlations, x_morphology, x_spatial_context): h_morphology = F.elu(self.input_morph(x_morphology)) h_morphology2 = F.elu(self.morph_hidden(h_morphology)) - # z1 = torch.cat([h_mean2, h_correlations2, h_morphology2], 1) - h_spatial_context = F.elu(self.input_spatial_context(x_spatial_context)) h_spatial_context2 = F.elu(self.spatial_context_hidden(h_spatial_context)) - h = torch.cat([h_mean2, h_correlations2, h_morphology2, h_spatial_context2], 1) + h_cov = F.elu(self.input_cov(cov_list)) + + h = torch.cat( + [h_mean2, h_correlations2, h_morphology2, h_spatial_context2, h_cov], 1 + ) for net in self.linear: h = F.elu(net(h)) @@ -94,19 +108,21 @@ class DecoderHMIVAE(nn.Module): def __init__( self, - latent_dim, - E_me, - E_cr, - E_mr, - E_sc, - input_exp_dim, - input_corr_dim, - input_morph_dim, - input_spcont_dim, - n_hidden=1, + latent_dim: int, + E_me: int, + E_cr: int, + E_mr: int, + E_sc: int, + input_exp_dim: int, + input_corr_dim: int, + input_morph_dim: int, + input_spcont_dim: int, + n_covariates: Optional[int] = 0, + n_hidden: Optional[int] = 1, ): super().__init__() - hidden_dim = E_me + E_cr + E_mr + E_sc + hidden_dim = E_me + E_cr + E_mr + E_sc + n_covariates + latent_dim = latent_dim + n_covariates self.E_me = E_me self.E_cr = E_cr self.E_mr = E_mr @@ -135,8 +151,12 @@ def __init__( self.mu_x_spcont = nn.Linear(E_sc, input_spcont_dim) self.std_x_spcont = nn.Linear(E_sc, input_spcont_dim) - def forward(self, z): 
- out = F.elu(self.input(z)) + self.covariates_out_mu = nn.Linear(n_covariates, n_covariates) + self.covariates_out_std = nn.Linear(n_covariates, n_covariates) + + def forward(self, z, cov_list): + z_s = torch.cat([z, cov_list], 1) + out = F.elu(self.input(z_s)) for net in self.linear: out = F.elu(net(out)) @@ -150,9 +170,21 @@ def forward(self, z): ) ) h2_spatial_context = F.elu( - self.spatial_context_hidden(out[:, self.E_me + self.E_cr + self.E_mr :]) + self.spatial_context_hidden( + out[ + :, + self.E_me + + self.E_cr + + self.E_mr : self.E_me + + self.E_cr + + self.E_mr + + self.E_sc, + ] + ) ) + covariates = out[:, self.E_me + self.E_cr + self.E_mr + self.E_sc :] + mu_x_exp = self.mu_x_exp(h2_mean) std_x_exp = self.std_x_exp(h2_mean) @@ -173,6 +205,9 @@ def forward(self, z): mu_x_spatial_context = self.mu_x_spcont(h2_spatial_context) std_x_spatial_context = self.std_x_spcont(h2_spatial_context) + covariates_mu = self.covariates_out_mu(covariates) + covariates_std = self.covariates_out_std(covariates) + return ( mu_x_exp, std_x_exp, @@ -182,5 +217,7 @@ def forward(self, z): std_x_morph, mu_x_spatial_context, std_x_spatial_context, + covariates_mu, + covariates_std, # weights, ) diff --git a/hmivae/_hmivae_model.py b/hmivae/_hmivae_model.py index 193c445..6b27f1d 100644 --- a/hmivae/_hmivae_model.py +++ b/hmivae/_hmivae_model.py @@ -3,8 +3,10 @@ import numpy as np import pytorch_lightning as pl +import torch from _hmivae_module import hmiVAE from anndata import AnnData +from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.trainer import Trainer from scipy.stats.mstats import winsorize from ScModeDataloader import ScModeDataloader @@ -60,17 +62,14 @@ def __init__( E_mr: int = 32, E_sc: int = 32, latent_dim: int = 10, + n_covariates: int = 0, n_hidden: int = 1, **model_kwargs, ): # super(hmivaeModel, self).__init__(adata) super().__init__() - # library_log_means, library_log_vars = _init_library_size( - # adata, 
self.summary_stats["n_batch"] - # ) - - self.train_batch, self.test_batch = self.setup_anndata( + self.train_batch, self.test_batch, self.n_covariates = self.setup_anndata( adata=adata, protein_correlations_obsm_key="correlations", cell_morphology_obsm_key="morphology", @@ -87,18 +86,21 @@ def __init__( E_mr=E_mr, E_sc=E_sc, latent_dim=latent_dim, + n_covariates=self.n_covariates, n_hidden=n_hidden, **model_kwargs, ) self._model_summary_string = ( - "hmiVAE model with the following parameters: \nn_latent:{}" - "n_protein_expression:{}, n_correlation:{}, n_morphology:{}, n_spatial_context:{}" + "hmiVAE model with the following parameters: \n n_latent:{}, " + "n_protein_expression:{}, n_correlation:{}, n_morphology:{}, n_spatial_context:{}, " + "n_covariates:{} " ).format( latent_dim, input_exp_dim, input_corr_dim, input_morph_dim, input_spcont_dim, + n_covariates, ) # necessary line to get params that will be used for saving/loading # self.init_params_ = self._get_init_params(locals()) @@ -109,13 +111,45 @@ def train( self, ): # misnomer, both train and test are here (either rename or separate) - trainer = Trainer(max_epochs=10) + early_stopping = EarlyStopping( + monitor="test_loss", mode="min", patience=3 + ) # need to add this + + trainer = Trainer(max_epochs=10, callbacks=[early_stopping]) - trainer.fit(self.module, self.train_batch) # training, add wandb - trainer.test(dataloaders=self.test_batch) # test, add wandb + trainer.fit( + self.module, self.train_batch, self.test_batch + ) # training, add wandb + # trainer.test(dataloaders=self.test_batch) # test, add wandb # return trainer + @torch.no_grad() + def get_latent_representation( + self, + adata: AnnData, + protein_correlations_obsm_key: str, + cell_morphology_obsm_key: str, + is_trained_model: Optional[bool] = True, + ) -> np.ndarray: + """ + Gives the latent representation of each cell. 
+ """ + if is_trained_model: + data_train, data_test = self.setup_anndata( + adata, + protein_correlations_obsm_key, + cell_morphology_obsm_key, + is_trained_model=is_trained_model, + ) + train_mu_z = self.module.inference(data_train) + test_mu_z = self.module.inference(data_test) + return train_mu_z, test_mu_z + else: + raise Exception( + "No latent representation to produce! Model is not trained!" + ) + # @setup_anndata_dsp.dedent @staticmethod def setup_anndata( @@ -136,6 +170,7 @@ def setup_anndata( train_prop: Optional[float] = 0.75, apply_winsorize: Optional[bool] = True, arctanh_corrs: Optional[bool] = False, + is_trained_model: Optional[bool] = False, random_seed: Optional[int] = 1234, copy: bool = False, ) -> Optional[AnnData]: @@ -196,6 +231,9 @@ def setup_anndata( data_test = ScModeDataloader(adata_test, data_train.scalers) loader_train = DataLoader(data_train, batch_size=batch_size, shuffle=True) - loader_test = DataLoader(data_test, batch_size=batch_size, shuffle=True) + loader_test = DataLoader(data_test, batch_size=batch_size) # shuffle=True) - return loader_train, loader_test + if is_trained_model: + return data_train, data_test + else: + return loader_train, loader_test, len(samples_train) diff --git a/hmivae/_hmivae_module.py b/hmivae/_hmivae_module.py index b0bcc13..81f61ad 100644 --- a/hmivae/_hmivae_module.py +++ b/hmivae/_hmivae_module.py @@ -1,4 +1,4 @@ -from typing import Iterable, Optional # Dict, Tuple, Union +from typing import List, Optional, Sequence import numpy as np import pytorch_lightning as pl @@ -6,6 +6,8 @@ import torch.nn.functional as F from _hmivae_base_components import DecoderHMIVAE, EncoderHMIVAE +# from anndata import AnnData + torch.backends.cudnn.benchmark = True @@ -25,10 +27,12 @@ def __init__( E_mr: int = 32, E_sc: int = 32, latent_dim: int = 10, + n_covariates: int = 0, n_hidden: int = 1, ): super().__init__() # hidden_dim = E_me + E_cr + E_mr + E_sc + self.n_covariates = n_covariates self.encoder = EncoderHMIVAE( 
input_exp_dim, @@ -40,6 +44,7 @@ def __init__( E_mr, E_sc, latent_dim, + n_covariates=n_covariates, ) self.decoder = DecoderHMIVAE( @@ -52,6 +57,7 @@ def __init__( input_corr_dim, input_morph_dim, input_spcont_dim, + n_covariates=n_covariates, ) def reparameterization(self, mu, log_std): @@ -91,10 +97,13 @@ def em_recon_loss( dec_x_logstd_morph, dec_x_mu_spcont, dec_x_logstd_spcont, + covariates_mu, + covariates_std, y, s, m, c, + cov_list, # weights=None, ): """Takes in the parameters output from the decoder, @@ -119,17 +128,20 @@ def em_recon_loss( dec_x_std_corr = torch.exp(dec_x_logstd_corr) dec_x_std_morph = torch.exp(dec_x_logstd_morph) dec_x_std_spcont = torch.exp(dec_x_logstd_spcont) + cov_std = torch.exp(covariates_std) p_rec_exp = torch.distributions.Normal(dec_x_mu_exp, dec_x_std_exp + 1e-6) p_rec_corr = torch.distributions.Normal(dec_x_mu_corr, dec_x_std_corr + 1e-6) p_rec_morph = torch.distributions.Normal(dec_x_mu_morph, dec_x_std_morph + 1e-6) p_rec_spcont = torch.distributions.Normal( dec_x_mu_spcont, dec_x_std_spcont + 1e-6 ) + p_rec_cov = torch.distributions.Normal(covariates_mu, cov_std + 1e-6) log_p_xz_exp = p_rec_exp.log_prob(y) log_p_xz_corr = p_rec_corr.log_prob(s) log_p_xz_morph = p_rec_morph.log_prob(m) log_p_xz_spcont = p_rec_spcont.log_prob(c) # already dense matrix + log_p_cov = p_rec_cov.log_prob(cov_list) # if weights is None: # log_p_xz_corr = p_rec_corr.log_prob(s) @@ -142,8 +154,9 @@ def em_recon_loss( log_p_xz_corr = log_p_xz_corr.sum(-1) log_p_xz_morph = log_p_xz_morph.sum(-1) log_p_xz_spcont = log_p_xz_spcont.sum(-1) + log_p_cov = log_p_cov.sum(-1) - return log_p_xz_exp, log_p_xz_corr, log_p_xz_morph, log_p_xz_spcont + return log_p_xz_exp, log_p_xz_corr, log_p_xz_morph, log_p_xz_spcont, log_p_cov def neg_ELBO( self, @@ -157,16 +170,25 @@ def neg_ELBO( dec_x_logstd_morph, dec_x_mu_spcont, dec_x_logstd_spcont, + covariates_mu, + covariates_std, z, y, s, m, c, + cov_list, # weights=None, ): kl_div = self.KL_div(enc_x_mu, 
enc_x_logstd, z) - recon_lik_me, recon_lik_corr, recon_lik_mor, recon_lik_sc = self.em_recon_loss( + ( + recon_lik_me, + recon_lik_corr, + recon_lik_mor, + recon_lik_sc, + reconstructed_covs, + ) = self.em_recon_loss( dec_x_mu_exp, dec_x_logstd_exp, dec_x_mu_corr, @@ -175,10 +197,13 @@ def neg_ELBO( dec_x_logstd_morph, dec_x_mu_spcont, dec_x_logstd_spcont, + covariates_mu, + covariates_std, y, s, m, c, + cov_list, # weights, ) return ( @@ -187,6 +212,7 @@ def neg_ELBO( recon_lik_corr, recon_lik_mor, recon_lik_sc, + reconstructed_covs, ) def loss(self, kl_div, recon_loss, beta: float = 1.0): @@ -196,11 +222,10 @@ def loss(self, kl_div, recon_loss, beta: float = 1.0): def training_step( self, train_batch, - batch_idx, - categories: Optional[Iterable[int]] = None, corr_weights=False, recon_weights=np.array([1.0, 1.0, 1.0, 1.0]), beta=1.0, + categories: Optional[List[float]] = None, ): """ Carries out the training step. @@ -214,8 +239,19 @@ def training_step( S = train_batch[1] M = train_batch[2] spatial_context = train_batch[3] - - mu_z, log_std_z = self.encoder(Y, S, M, spatial_context) + one_hot = train_batch[4] + batch_idx = train_batch[-1] + if categories is not None: + if len(categories) > 0: + categories = torch.Tensor(categories)[batch_idx, :] + else: + categories = torch.Tensor(categories) + else: + categories = torch.Tensor([]) + + cov_list = torch.cat([one_hot, categories], 1).float() + # print('train',cov_list.size()) + mu_z, log_std_z = self.encoder(Y, S, M, spatial_context, cov_list) z_samples = self.reparameterization(mu_z, log_std_z) @@ -229,8 +265,10 @@ def training_step( log_std_x_morph_hat, mu_x_spcont_hat, log_std_x_spcont_hat, + covariates_mu, + covariates_std, # weights, - ) = self.decoder(z_samples) + ) = self.decoder(z_samples, cov_list) ( kl_div, @@ -238,6 +276,7 @@ def training_step( recon_lik_corr, recon_lik_mor, recon_lik_sc, + reconstructed_covs, ) = self.neg_ELBO( mu_z, log_std_z, @@ -249,11 +288,14 @@ def training_step( 
log_std_x_morph_hat, mu_x_spcont_hat, log_std_x_spcont_hat, + covariates_mu, + covariates_std, z_samples, Y, S, M, spatial_context, + cov_list, # weights, ) @@ -262,6 +304,7 @@ def training_step( + recon_weights[1] * recon_lik_corr + recon_weights[2] * recon_lik_mor + recon_weights[3] * recon_lik_sc + + reconstructed_covs ) loss = self.loss(kl_div, recon_loss, beta=beta) @@ -278,15 +321,17 @@ def training_step( "recon_lik_sc": recon_lik_sc.mean().item(), } - def test_step( + def validation_step( self, test_batch, - batch_idx, + n_other_cat: int = 0, + L_iter: int = 10, corr_weights=False, recon_weights=np.array([1.0, 1.0, 1.0, 1.0]), beta=1.0, + categories: Optional[List[float]] = None, ): - """ + """---> Add random one-hot encoding Carries out the validation/test step. test_batch: torch.Tensor. Validation/test data, spatial_context: torch.Tensor. Matrix with old mu_z integrated neighbours information, @@ -298,62 +343,99 @@ def test_step( S = test_batch[1] M = test_batch[2] spatial_context = test_batch[3] - - mu_z, log_std_z = self.encoder(Y, S, M, spatial_context) - - z_samples = self.reparameterization(mu_z, log_std_z) - - # decoding - ( - mu_x_exp_hat, - log_std_x_exp_hat, - mu_x_corr_hat, - log_std_x_corr_hat, - mu_x_morph_hat, - log_std_x_morph_hat, - mu_x_spcont_hat, - log_std_x_spcont_hat, - # weights, - ) = self.decoder(z_samples) - - ( - kl_div, - recon_lik_me, - recon_lik_corr, - recon_lik_mor, - recon_lik_sc, - ) = self.neg_ELBO( - mu_z, - log_std_z, - mu_x_exp_hat, - log_std_x_exp_hat, - mu_x_corr_hat, - log_std_x_corr_hat, - mu_x_morph_hat, - log_std_x_morph_hat, - mu_x_spcont_hat, - log_std_x_spcont_hat, - z_samples, - Y, - S, - M, - spatial_context, - # weights, - ) - - recon_loss = ( - recon_weights[0] * recon_lik_me - + recon_weights[1] * recon_lik_corr - + recon_weights[2] * recon_lik_mor - + recon_weights[3] * recon_lik_sc - ) - - loss = self.loss(kl_div, recon_loss, beta=beta) - - self.log("test_loss", loss, on_step=True, on_epoch=True, 
prog_bar=True) + batch_idx = test_batch[-1] + # print(batch_idx) + test_loss = [] + n_classes = self.n_covariates # - n_other_cat + for i in range(L_iter): + # print(n_classes) + # print(len(batch_idx)) + # print(np.eye(n_classes)[np.random.choice(n_classes, len(batch_idx))]) + one_hot = self.random_one_hot(n_classes=n_classes, n_samples=len(batch_idx)) + # print(one_hot.size()) + + if categories is not None: + if len(categories) > 0: + categories = torch.Tensor(categories)[batch_idx, :] + else: + categories = torch.Tensor(categories) + else: + categories = torch.Tensor([]) + + cov_list = torch.cat([one_hot, categories], 1).float() + + mu_z, log_std_z = self.encoder( + Y, S, M, spatial_context, cov_list + ) # valid step + + z_samples = self.reparameterization(mu_z, log_std_z) + + # decoding + ( + mu_x_exp_hat, + log_std_x_exp_hat, + mu_x_corr_hat, + log_std_x_corr_hat, + mu_x_morph_hat, + log_std_x_morph_hat, + mu_x_spcont_hat, + log_std_x_spcont_hat, + covariates_mu, + covariates_std, + # weights, + ) = self.decoder(z_samples, cov_list) + + ( + kl_div, + recon_lik_me, + recon_lik_corr, + recon_lik_mor, + recon_lik_sc, + reconstructed_covs, + ) = self.neg_ELBO( + mu_z, + log_std_z, + mu_x_exp_hat, + log_std_x_exp_hat, + mu_x_corr_hat, + log_std_x_corr_hat, + mu_x_morph_hat, + log_std_x_morph_hat, + mu_x_spcont_hat, + log_std_x_spcont_hat, + covariates_mu, + covariates_std, + z_samples, + Y, + S, + M, + spatial_context, + cov_list, + # weights, + ) + + recon_loss = ( + recon_weights[0] * recon_lik_me + + recon_weights[1] * recon_lik_corr + + recon_weights[2] * recon_lik_mor + + recon_weights[3] * recon_lik_sc + + reconstructed_covs + ) + + loss = self.loss(kl_div, recon_loss, beta=beta) + + test_loss.append(loss) + + self.log( + "test_loss", + sum(test_loss) / L_iter, + on_step=True, + on_epoch=True, + prog_bar=True, + ) # log the average test loss over all the iterations return { - "loss": loss, + "loss": sum(test_loss) / L_iter, "kl_div": kl_div.mean().item(), 
"recon_loss": recon_loss.mean().item(), "recon_lik_me": recon_lik_me.mean().item(), @@ -367,7 +449,8 @@ def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) return optimizer - def __get_input_embeddings__( + @torch.no_grad() + def get_input_embeddings( self, x_mean, x_correlations, x_morphology, x_spatial_context ): """ @@ -386,3 +469,29 @@ def __get_input_embeddings__( h_spatial_context2 = F.elu(self.spatial_context_hidden(h_spatial_context)) return h_mean2, h_correlations2, h_morphology2, h_spatial_context2 + + @torch.no_grad() + def inference( + self, data, indices: Optional[Sequence[int]] = None, give_mean: bool = True + ) -> np.ndarray: + """ + Return the latent representation of each cell. + """ + if give_mean: + mu_z, _ = self.encoder(data.Y, data.S, data.M, data.C) + + return mu_z.numpy() + else: + mu_z, log_std_z = self.encoder(data.Y, data.S, data.M, data.C) + z = self.reparameterization(mu_z, log_std_z) + + return z.numpy() + + @torch.no_grad() + def random_one_hot(self, n_classes: int, n_samples: int): + """ + Generates a random one hot encoded matrix. 
+ From: https://stackoverflow.com/questions/45093615/random-one-hot-matrix-in-numpy + """ + # x = np.eye(n_classes) + return torch.Tensor(np.eye(n_classes)[np.random.choice(n_classes, n_samples)]) From e3f586936c7aadeb4dd211b73f7459e57d3d5893 Mon Sep 17 00:00:00 2001 From: shanzaayub Date: Tue, 21 Jun 2022 15:01:08 -0400 Subject: [PATCH 05/18] corrected decoder and added get_latent_rep function --- hmivae/_hmivae_base_components.py | 18 ++++---- hmivae/_hmivae_model.py | 34 ++++++++++------ hmivae/_hmivae_module.py | 68 ++++++++++++++++--------------- 3 files changed, 66 insertions(+), 54 deletions(-) diff --git a/hmivae/_hmivae_base_components.py b/hmivae/_hmivae_base_components.py index 2dee753..8550862 100644 --- a/hmivae/_hmivae_base_components.py +++ b/hmivae/_hmivae_base_components.py @@ -151,11 +151,13 @@ def __init__( self.mu_x_spcont = nn.Linear(E_sc, input_spcont_dim) self.std_x_spcont = nn.Linear(E_sc, input_spcont_dim) - self.covariates_out_mu = nn.Linear(n_covariates, n_covariates) - self.covariates_out_std = nn.Linear(n_covariates, n_covariates) + # self.covariates_out_mu = nn.Linear(n_covariates, n_covariates) #this is one-hot + # self.covariates_out_std = nn.Linear(n_covariates, n_covariates) def forward(self, z, cov_list): - z_s = torch.cat([z, cov_list], 1) + z_s = torch.cat( + [z, cov_list], 1 + ) # takes in one-hot as input, doesn't need to be symmetric with the encoder, doesn't output it out = F.elu(self.input(z_s)) for net in self.linear: out = F.elu(net(out)) @@ -183,7 +185,7 @@ def forward(self, z, cov_list): ) ) - covariates = out[:, self.E_me + self.E_cr + self.E_mr + self.E_sc :] + # covariates = out[:, self.E_me + self.E_cr + self.E_mr + self.E_sc :] mu_x_exp = self.mu_x_exp(h2_mean) std_x_exp = self.std_x_exp(h2_mean) @@ -205,8 +207,8 @@ def forward(self, z, cov_list): mu_x_spatial_context = self.mu_x_spcont(h2_spatial_context) std_x_spatial_context = self.std_x_spcont(h2_spatial_context) - covariates_mu = 
self.covariates_out_mu(covariates) - covariates_std = self.covariates_out_std(covariates) + # covariates_mu = self.covariates_out_mu(covariates) + # covariates_std = self.covariates_out_std(covariates) return ( mu_x_exp, @@ -217,7 +219,7 @@ def forward(self, z, cov_list): std_x_morph, mu_x_spatial_context, std_x_spatial_context, - covariates_mu, - covariates_std, + # covariates_mu, + # covariates_std, # weights, ) diff --git a/hmivae/_hmivae_model.py b/hmivae/_hmivae_model.py index 6b27f1d..c8115f7 100644 --- a/hmivae/_hmivae_model.py +++ b/hmivae/_hmivae_model.py @@ -1,12 +1,14 @@ import logging from typing import List, Optional +import anndata as ad import numpy as np import pytorch_lightning as pl import torch from _hmivae_module import hmiVAE from anndata import AnnData from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.trainer import Trainer from scipy.stats.mstats import winsorize from ScModeDataloader import ScModeDataloader @@ -69,8 +71,10 @@ def __init__( # super(hmivaeModel, self).__init__(adata) super().__init__() + self.adata = adata + self.train_batch, self.test_batch, self.n_covariates = self.setup_anndata( - adata=adata, + adata=self.adata, protein_correlations_obsm_key="correlations", cell_morphology_obsm_key="morphology", ) @@ -109,13 +113,16 @@ def __init__( def train( self, + max_epochs=100, ): # misnomer, both train and test are here (either rename or separate) - early_stopping = EarlyStopping( - monitor="test_loss", mode="min", patience=3 - ) # need to add this + early_stopping = EarlyStopping(monitor="test_loss", mode="min", patience=3) + + wandb_logger = WandbLogger(project="hmiVAE_init_trial_runs") - trainer = Trainer(max_epochs=10, callbacks=[early_stopping]) + trainer = Trainer( + max_epochs=max_epochs, callbacks=[early_stopping], logger=wandb_logger + ) trainer.fit( self.module, self.train_batch, self.test_batch @@ -127,24 +134,25 @@ def train( 
@torch.no_grad() def get_latent_representation( self, - adata: AnnData, protein_correlations_obsm_key: str, cell_morphology_obsm_key: str, is_trained_model: Optional[bool] = True, - ) -> np.ndarray: + ) -> AnnData: """ Gives the latent representation of each cell. """ if is_trained_model: - data_train, data_test = self.setup_anndata( - adata, + adata_train, adata_test, data_train, data_test = self.setup_anndata( + self.adata, protein_correlations_obsm_key, cell_morphology_obsm_key, is_trained_model=is_trained_model, ) - train_mu_z = self.module.inference(data_train) - test_mu_z = self.module.inference(data_test) - return train_mu_z, test_mu_z + # print(data_train.samples_onehot.size()) + adata_train.obsm["VAE"] = self.module.inference(data_train) + adata_test.obsm["VAE"] = self.module.inference(data_test) + # test_mu_z = self.module.inference(data_test) #leaving it out for now, how to incorporate one-hot encoding here? + return ad.concat([adata_train, adata_test], uns_merge="first") else: raise Exception( "No latent representation to produce! Model is not trained!" 
@@ -234,6 +242,6 @@ def setup_anndata( loader_test = DataLoader(data_test, batch_size=batch_size) # shuffle=True) if is_trained_model: - return data_train, data_test + return adata_train, adata_test, data_train, data_test else: return loader_train, loader_test, len(samples_train) diff --git a/hmivae/_hmivae_module.py b/hmivae/_hmivae_module.py index 81f61ad..487cf47 100644 --- a/hmivae/_hmivae_module.py +++ b/hmivae/_hmivae_module.py @@ -97,13 +97,10 @@ def em_recon_loss( dec_x_logstd_morph, dec_x_mu_spcont, dec_x_logstd_spcont, - covariates_mu, - covariates_std, y, s, m, c, - cov_list, # weights=None, ): """Takes in the parameters output from the decoder, @@ -128,20 +125,17 @@ def em_recon_loss( dec_x_std_corr = torch.exp(dec_x_logstd_corr) dec_x_std_morph = torch.exp(dec_x_logstd_morph) dec_x_std_spcont = torch.exp(dec_x_logstd_spcont) - cov_std = torch.exp(covariates_std) p_rec_exp = torch.distributions.Normal(dec_x_mu_exp, dec_x_std_exp + 1e-6) p_rec_corr = torch.distributions.Normal(dec_x_mu_corr, dec_x_std_corr + 1e-6) p_rec_morph = torch.distributions.Normal(dec_x_mu_morph, dec_x_std_morph + 1e-6) p_rec_spcont = torch.distributions.Normal( dec_x_mu_spcont, dec_x_std_spcont + 1e-6 ) - p_rec_cov = torch.distributions.Normal(covariates_mu, cov_std + 1e-6) log_p_xz_exp = p_rec_exp.log_prob(y) log_p_xz_corr = p_rec_corr.log_prob(s) log_p_xz_morph = p_rec_morph.log_prob(m) log_p_xz_spcont = p_rec_spcont.log_prob(c) # already dense matrix - log_p_cov = p_rec_cov.log_prob(cov_list) # if weights is None: # log_p_xz_corr = p_rec_corr.log_prob(s) @@ -154,9 +148,13 @@ def em_recon_loss( log_p_xz_corr = log_p_xz_corr.sum(-1) log_p_xz_morph = log_p_xz_morph.sum(-1) log_p_xz_spcont = log_p_xz_spcont.sum(-1) - log_p_cov = log_p_cov.sum(-1) - return log_p_xz_exp, log_p_xz_corr, log_p_xz_morph, log_p_xz_spcont, log_p_cov + return ( + log_p_xz_exp, + log_p_xz_corr, + log_p_xz_morph, + log_p_xz_spcont, + ) def neg_ELBO( self, @@ -170,14 +168,11 @@ def neg_ELBO( 
dec_x_logstd_morph, dec_x_mu_spcont, dec_x_logstd_spcont, - covariates_mu, - covariates_std, z, y, s, m, c, - cov_list, # weights=None, ): kl_div = self.KL_div(enc_x_mu, enc_x_logstd, z) @@ -187,7 +182,6 @@ def neg_ELBO( recon_lik_corr, recon_lik_mor, recon_lik_sc, - reconstructed_covs, ) = self.em_recon_loss( dec_x_mu_exp, dec_x_logstd_exp, @@ -197,13 +191,10 @@ def neg_ELBO( dec_x_logstd_morph, dec_x_mu_spcont, dec_x_logstd_spcont, - covariates_mu, - covariates_std, y, s, m, c, - cov_list, # weights, ) return ( @@ -212,7 +203,6 @@ def neg_ELBO( recon_lik_corr, recon_lik_mor, recon_lik_sc, - reconstructed_covs, ) def loss(self, kl_div, recon_loss, beta: float = 1.0): @@ -265,8 +255,6 @@ def training_step( log_std_x_morph_hat, mu_x_spcont_hat, log_std_x_spcont_hat, - covariates_mu, - covariates_std, # weights, ) = self.decoder(z_samples, cov_list) @@ -276,7 +264,6 @@ def training_step( recon_lik_corr, recon_lik_mor, recon_lik_sc, - reconstructed_covs, ) = self.neg_ELBO( mu_z, log_std_z, @@ -288,14 +275,11 @@ def training_step( log_std_x_morph_hat, mu_x_spcont_hat, log_std_x_spcont_hat, - covariates_mu, - covariates_std, z_samples, Y, S, M, spatial_context, - cov_list, # weights, ) @@ -304,7 +288,6 @@ def training_step( + recon_weights[1] * recon_lik_corr + recon_weights[2] * recon_lik_mor + recon_weights[3] * recon_lik_sc - + reconstructed_covs ) loss = self.loss(kl_div, recon_loss, beta=beta) @@ -380,8 +363,6 @@ def validation_step( log_std_x_morph_hat, mu_x_spcont_hat, log_std_x_spcont_hat, - covariates_mu, - covariates_std, # weights, ) = self.decoder(z_samples, cov_list) @@ -391,7 +372,6 @@ def validation_step( recon_lik_corr, recon_lik_mor, recon_lik_sc, - reconstructed_covs, ) = self.neg_ELBO( mu_z, log_std_z, @@ -403,14 +383,11 @@ def validation_step( log_std_x_morph_hat, mu_x_spcont_hat, log_std_x_spcont_hat, - covariates_mu, - covariates_std, z_samples, Y, S, M, spatial_context, - cov_list, # weights, ) @@ -419,7 +396,6 @@ def validation_step( + 
recon_weights[1] * recon_lik_corr + recon_weights[2] * recon_lik_mor + recon_weights[3] * recon_lik_sc - + reconstructed_covs ) loss = self.loss(kl_div, recon_loss, beta=beta) @@ -472,17 +448,43 @@ def get_input_embeddings( @torch.no_grad() def inference( - self, data, indices: Optional[Sequence[int]] = None, give_mean: bool = True + self, + data, + indices: Optional[Sequence[int]] = None, + give_mean: bool = True, + categories: Optional[List[float]] = None, ) -> np.ndarray: """ Return the latent representation of each cell. """ + Y = data.Y + S = data.S + M = data.M + C = data.C + one_hot = data.samples_onehot + if one_hot.shape[1] < self.n_covariates: + zeros_pad = torch.Tensor( + np.zeros([one_hot.shape[0], self.n_covariates - one_hot.shape[1]]) + ) + one_hot = torch.cat([one_hot, zeros_pad], 1) + else: + one_hot = one_hot + batch_idx = data[-1] + if categories is not None: + if len(categories) > 0: + categories = torch.Tensor(categories)[batch_idx, :] + else: + categories = torch.Tensor(categories) + else: + categories = torch.Tensor([]) + + cov_list = torch.cat([one_hot, categories], 1).float() if give_mean: - mu_z, _ = self.encoder(data.Y, data.S, data.M, data.C) + mu_z, _ = self.encoder(Y, S, M, C, cov_list) return mu_z.numpy() else: - mu_z, log_std_z = self.encoder(data.Y, data.S, data.M, data.C) + mu_z, log_std_z = self.encoder(Y, S, M, C, cov_list) z = self.reparameterization(mu_z, log_std_z) return z.numpy() From 358cef33df28a63334fb400f2621be3c909f8fac Mon Sep 17 00:00:00 2001 From: shanzaayub Date: Sun, 24 Jul 2022 20:16:38 -0400 Subject: [PATCH 06/18] one-hot encoding, background stains included as optional --- hmivae/ScModeDataloader.py | 67 ++++++++++-- hmivae/_hmivae_base_components.py | 15 ++- hmivae/_hmivae_model.py | 167 ++++++++++++++++++++++++------ hmivae/_hmivae_module.py | 117 +++++++++++++-------- 4 files changed, 278 insertions(+), 88 deletions(-) diff --git a/hmivae/ScModeDataloader.py b/hmivae/ScModeDataloader.py index 71db37c..02f0545 
100644 --- a/hmivae/ScModeDataloader.py +++ b/hmivae/ScModeDataloader.py @@ -32,15 +32,18 @@ def __init__(self, adata, scalers=None): scalers: set of data scalers """ self.adata = adata - Y = adata.X + Y = adata.X # per cell protein mean expression S = adata.obsm["correlations"] M = adata.obsm["morphology"] + self.n_cells = Y.shape[0] # number of cells + if scalers is None: self.scalers = {} self.scalers["Y"] = StandardScaler().fit(Y) self.scalers["S"] = StandardScaler().fit(S) self.scalers["M"] = StandardScaler().fit(M) + else: self.scalers = scalers @@ -55,6 +58,17 @@ def __init__(self, adata, scalers=None): self.samples_onehot = self.one_hot_encoding() + if "background_covs" in adata.obsm.keys(): # dealing with background covariates + bkg = adata.obsm["background_covs"] + if scalers is None: + self.scalers["BKG"] = StandardScaler().fit(bkg) + else: + BKG = self.scalers["BKG"].transform(bkg) + + self.BKG = torch.tensor(BKG).float() + else: + self.BKG = None + def __len__(self): return self.Y.shape[0] @@ -73,23 +87,54 @@ def one_hot_encoding(self, test=False): return torch.tensor(df.to_numpy()).float() def get_spatial_context(self): + """ + Multiplies the sparse neighbourhood matrix to protein mean expression (self.Y), + protein-protein correlation (self.S) and cell morphology (self.M) matrices. + The product-sum is normalized by the number of neighbours each cell has. + The resulting matrix, self.C, is the spatial context. 
+ """ adj_mat = sparse_numpy_to_torch( self.adata.obsp["connectivities"] ) # adjacency matrix concatenated_features = torch.cat((self.Y, self.S, self.M), 1) - C = torch.smm( # normalize for the number of adjacent cells + n_cell_neighbours = self.adata.obsp[ + "connectivities" + ].toarray() # .sum(1).reshape([self.n_cells,1]) + n_cell_neighbours[np.where(n_cell_neighbours > 0)] = 1.0 + n_cell_neighbours = n_cell_neighbours.sum(1).reshape([self.n_cells, 1]) + n_cell_neighbours[np.where(n_cell_neighbours < 1.0)] = 1.0 + + # print('n_cell_neighbours', n_cell_neighbours) + + unnormalized_C = torch.smm( adj_mat, concatenated_features - ).to_dense() # spatial context for each cell + ).to_dense() # unnormalized spatial context for each cell + + C = torch.div( + unnormalized_C, torch.tensor(n_cell_neighbours) + ) # normalize by number of adjacent cells + # print('sum C', C.sum()) return C def __getitem__(self, idx): - return ( - self.Y[idx, :], - self.S[idx, :], - self.M[idx, :], - self.C[idx, :], - self.samples_onehot[idx, :], - idx, - ) + if self.BKG is None: + return ( + self.Y[idx, :], + self.S[idx, :], + self.M[idx, :], + self.C[idx, :], + self.samples_onehot[idx, :], + idx, + ) + else: + return ( + self.Y[idx, :], + self.S[idx, :], + self.M[idx, :], + self.C[idx, :], + self.samples_onehot[idx, :], + self.BKG[idx, :], + idx, + ) diff --git a/hmivae/_hmivae_base_components.py b/hmivae/_hmivae_base_components.py index 8550862..83a7b5e 100644 --- a/hmivae/_hmivae_base_components.py +++ b/hmivae/_hmivae_base_components.py @@ -66,16 +66,29 @@ def forward( h_mean = F.elu(self.input_exp(x_mean)) h_mean2 = F.elu(self.exp_hidden(h_mean)) + # print("h_mean2", h_mean2) + h_correlations = F.elu(self.input_corr(x_correlations)) h_correlations2 = F.elu(self.corr_hidden(h_correlations)) + # print("h_correlations2", h_correlations2) + h_morphology = F.elu(self.input_morph(x_morphology)) h_morphology2 = F.elu(self.morph_hidden(h_morphology)) + # print("h_morphology2", h_morphology2) + 
h_spatial_context = F.elu(self.input_spatial_context(x_spatial_context)) h_spatial_context2 = F.elu(self.spatial_context_hidden(h_spatial_context)) - h_cov = F.elu(self.input_cov(cov_list)) + # print("h_spatial_context2", h_spatial_context2) + + if cov_list.shape[0] > 1: + h_cov = F.elu(self.input_cov(cov_list)) + + # print('h_cov', h_cov) + else: + h_cov = cov_list h = torch.cat( [h_mean2, h_correlations2, h_morphology2, h_spatial_context2, h_cov], 1 diff --git a/hmivae/_hmivae_model.py b/hmivae/_hmivae_model.py index c8115f7..ee5bc32 100644 --- a/hmivae/_hmivae_model.py +++ b/hmivae/_hmivae_model.py @@ -5,16 +5,18 @@ import numpy as np import pytorch_lightning as pl import torch -from _hmivae_module import hmiVAE from anndata import AnnData from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.trainer import Trainer from scipy.stats.mstats import winsorize -from ScModeDataloader import ScModeDataloader from sklearn.model_selection import train_test_split from torch.utils.data import DataLoader +# import hmivae +import hmivae._hmivae_module as module +import hmivae.ScModeDataloader as ScModeDataloader + # from scvi.data import setup_anndata # from scvi.model._utils import _init_library_size # from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin, VAEMixin @@ -65,7 +67,10 @@ def __init__( E_sc: int = 32, latent_dim: int = 10, n_covariates: int = 0, + use_covs: bool = False, n_hidden: int = 1, + cofactor: float = 1.0, + batch_correct: bool = True, **model_kwargs, ): # super(hmivaeModel, self).__init__(adata) @@ -73,14 +78,55 @@ def __init__( self.adata = adata - self.train_batch, self.test_batch, self.n_covariates = self.setup_anndata( + if n_covariates > 0: + + if not use_covs: + self.use_covs = True + print( + "`use_covs` is automatically set to True when `n_covariates` > 0." 
+ ) + else: + self.use_covs = use_covs + + self.keys = [] + for key in adata.obsm.keys(): + # print(key) + if key not in ["correlations", "morphology", "xy"]: + self.keys.append(key) + + # print("n_keys", len(self.keys)) + else: + self.keys = None + self.use_covs = use_covs + + ( + self.train_batch, + self.test_batch, + self.n_covariates, + # self.cov_list, + ) = self.setup_anndata( adata=self.adata, protein_correlations_obsm_key="correlations", cell_morphology_obsm_key="morphology", + continuous_covariate_keys=self.keys, + cofactor=cofactor, + image_correct=batch_correct, ) + # for batch in self.train_batch: + # print('Y', batch[0]) + # print('S', batch[1]) + # print('M', batch[2]) + # print('C', batch[3]) + # print('one-hot', batch[4]) + # break + + # print("cov_list", self.cov_list.shape) + + # print('n_covs', self.n_covariates) + # self.summary_stats provides information about anndata dimensions and other tensor info - self.module = hmiVAE( + self.module = module.hmiVAE( input_exp_dim=input_exp_dim, input_corr_dim=input_corr_dim, input_morph_dim=input_morph_dim, @@ -92,6 +138,9 @@ def __init__( latent_dim=latent_dim, n_covariates=self.n_covariates, n_hidden=n_hidden, + use_covs=self.use_covs, + # cat_list=self.cov_list, + batch_correct=batch_correct, **model_kwargs, ) self._model_summary_string = ( @@ -114,44 +163,60 @@ def __init__( def train( self, max_epochs=100, + check_val_every_n_epoch=5, ): # misnomer, both train and test are here (either rename or separate) - early_stopping = EarlyStopping(monitor="test_loss", mode="min", patience=3) + early_stopping = EarlyStopping(monitor="test_loss", mode="min", patience=2) wandb_logger = WandbLogger(project="hmiVAE_init_trial_runs") trainer = Trainer( - max_epochs=max_epochs, callbacks=[early_stopping], logger=wandb_logger + max_epochs=max_epochs, + check_val_every_n_epoch=check_val_every_n_epoch, + callbacks=[early_stopping], + logger=wandb_logger, + # gradient_clip_val=2.0, ) - trainer.fit( - self.module, 
self.train_batch, self.test_batch - ) # training, add wandb - # trainer.test(dataloaders=self.test_batch) # test, add wandb - - # return trainer + trainer.fit(self.module, self.train_batch, self.test_batch) @torch.no_grad() def get_latent_representation( self, protein_correlations_obsm_key: str, cell_morphology_obsm_key: str, - is_trained_model: Optional[bool] = True, + continuous_covariate_keys: Optional[List[str]] = None, # default is self.keys + cofactor: float = 1.0, + is_trained_model: Optional[bool] = False, + batch_correct: Optional[bool] = True, ) -> AnnData: """ Gives the latent representation of each cell. """ if is_trained_model: - adata_train, adata_test, data_train, data_test = self.setup_anndata( + ( + adata_train, + adata_test, + data_train, + data_test, + # cat_list, + # train_idx, + # test_idx, + ) = self.setup_anndata( self.adata, protein_correlations_obsm_key, cell_morphology_obsm_key, + continuous_covariate_keys=self.keys, + cofactor=cofactor, is_trained_model=is_trained_model, + image_correct=batch_correct, ) - # print(data_train.samples_onehot.size()) - adata_train.obsm["VAE"] = self.module.inference(data_train) - adata_test.obsm["VAE"] = self.module.inference(data_test) - # test_mu_z = self.module.inference(data_test) #leaving it out for now, how to incorporate one-hot encoding here? 
+ + adata_train.obsm["VAE"] = self.module.inference( + data_train + ) # idx=train_idx) + adata_test.obsm["VAE"] = self.module.inference(data_test) # idx=test_idx) + return ad.concat([adata_train, adata_test], uns_merge="first") else: raise Exception( @@ -168,12 +233,15 @@ def setup_anndata( # cell_spatial_context_obsm_key: str, protein_correlations_names_uns_key: Optional[str] = None, cell_morphology_names_uns_key: Optional[str] = None, + image_correct: bool = True, batch_size: Optional[int] = 128, batch_key: Optional[str] = None, labels_key: Optional[str] = None, layer: Optional[str] = None, categorical_covariate_keys: Optional[List[str]] = None, - continuous_covariate_keys: Optional[List[str]] = None, + continuous_covariate_keys: Optional[ + List[str] + ] = None, # obsm keys for other categories cofactor: float = 1.0, train_prop: Optional[float] = 0.75, apply_winsorize: Optional[bool] = True, @@ -199,13 +267,27 @@ def setup_anndata( ------- %(returns)s """ - # N_TOTAL_CELLS = adata.shape[0] N_PROTEINS = adata.shape[1] - # N_CORRELATIONS = len(adata.uns["names_correlations"]) N_MORPHOLOGY = len(adata.uns["names_morphology"]) - # N_TOTAL_FEATURES = N_PROTEINS + N_CORRELATIONS + N_MORPHOLOGY - # if cofactor is not None: + if continuous_covariate_keys is not None: + cat_list = [] + for cat_key in continuous_covariate_keys: + # print(cat_key) + # print(f"{cat_key} shape:", adata.obsm[cat_key].shape) + category = adata.obsm[cat_key] + cat_list.append(category) + cat_list = np.arcsinh(np.concatenate(cat_list, 1) / cofactor) + n_cats = cat_list.shape[1] + if apply_winsorize: + for i in range(cat_list.shape[1]): + cat_list[:, i] = winsorize(cat_list[:, i], limits=[0, 0.01]) + + adata.obsm["background_covs"] = cat_list + else: + # cat_list = np.array([]) + n_cats = 0 + adata.X = np.arcsinh(adata.X / cofactor) if apply_winsorize: @@ -225,23 +307,46 @@ def setup_anndata( adata.obs["Sample_name"].unique().tolist() ) # samples in the adata - # train_size = 
int(np.floor(len(samples_list) * train_prop)) - # test_size = len(samples_list) - train_size - samples_train, samples_test = train_test_split( samples_list, train_size=train_prop, random_state=random_seed ) - + # samples_df = adata.obs.reset_index() adata_train = adata.copy()[adata.obs["Sample_name"].isin(samples_train), :] + # train_idx = samples_df.loc[samples_df["Sample_name"].isin(samples_train),:].index adata_test = adata.copy()[adata.obs["Sample_name"].isin(samples_test), :] + # test_idx = samples_df.loc[samples_df["Sample_name"].isin(samples_test),:].index - data_train = ScModeDataloader(adata_train) - data_test = ScModeDataloader(adata_test, data_train.scalers) + data_train = ScModeDataloader.ScModeDataloader(adata_train) + data_test = ScModeDataloader.ScModeDataloader(adata_test, data_train.scalers) loader_train = DataLoader(data_train, batch_size=batch_size, shuffle=True) loader_test = DataLoader(data_test, batch_size=batch_size) # shuffle=True) + if image_correct: + n_samples = len(samples_train) + # print("n_samples", n_samples) + # print("cat_list", cat_list.shape) + else: + n_samples = 0 + if is_trained_model: - return adata_train, adata_test, data_train, data_test + # if continuous_covariate_keys is not None: + return ( + adata_train, + adata_test, + data_train, + data_test, + ) # cat_list, train_idx, test_idx + # else: + # cat_list = None + # return adata_train, adata_test, data_train, data_test, cat_list, train_idx, test_idx + else: - return loader_train, loader_test, len(samples_train) + + return ( + loader_train, + loader_test, + n_samples + n_cats, + ) # cat_list + # else: + # return loader_train, loader_test, n_samples, cat_list diff --git a/hmivae/_hmivae_module.py b/hmivae/_hmivae_module.py index 487cf47..ade37b9 100644 --- a/hmivae/_hmivae_module.py +++ b/hmivae/_hmivae_module.py @@ -1,10 +1,12 @@ -from typing import List, Optional, Sequence +from typing import Optional, Sequence # List import numpy as np import pytorch_lightning as pl 
import torch import torch.nn.functional as F -from _hmivae_base_components import DecoderHMIVAE, EncoderHMIVAE + +# import hmivae +from hmivae._hmivae_base_components import DecoderHMIVAE, EncoderHMIVAE # from anndata import AnnData @@ -28,12 +30,21 @@ def __init__( E_sc: int = 32, latent_dim: int = 10, n_covariates: int = 0, + # cat_list: Optional[List[float]] = None, + use_covs: bool = False, n_hidden: int = 1, + batch_correct: bool = True, ): super().__init__() # hidden_dim = E_me + E_cr + E_mr + E_sc self.n_covariates = n_covariates + # self.cat_list = cat_list + + self.batch_correct = batch_correct + + self.use_covs = use_covs + self.encoder = EncoderHMIVAE( input_exp_dim, input_corr_dim, @@ -60,6 +71,8 @@ def __init__( n_covariates=n_covariates, ) + self.save_hyperparameters(ignore=["adata", "cat_list"]) + def reparameterization(self, mu, log_std): std = torch.exp(log_std) eps = torch.randn_like(log_std) @@ -215,7 +228,6 @@ def training_step( corr_weights=False, recon_weights=np.array([1.0, 1.0, 1.0, 1.0]), beta=1.0, - categories: Optional[List[float]] = None, ): """ Carries out the training step. 
@@ -229,18 +241,20 @@ def training_step( S = train_batch[1] M = train_batch[2] spatial_context = train_batch[3] - one_hot = train_batch[4] - batch_idx = train_batch[-1] - if categories is not None: - if len(categories) > 0: - categories = torch.Tensor(categories)[batch_idx, :] - else: - categories = torch.Tensor(categories) + # batch_idx = train_batch[-1] + + if self.use_covs: + categories = train_batch[5] else: categories = torch.Tensor([]) - cov_list = torch.cat([one_hot, categories], 1).float() - # print('train',cov_list.size()) + if self.batch_correct: + one_hot = train_batch[4] + + cov_list = torch.cat([one_hot, categories], 1).float() + else: + cov_list = torch.Tensor([]) + mu_z, log_std_z = self.encoder(Y, S, M, spatial_context, cov_list) z_samples = self.reparameterization(mu_z, log_std_z) @@ -308,11 +322,10 @@ def validation_step( self, test_batch, n_other_cat: int = 0, - L_iter: int = 10, + L_iter: int = 100, corr_weights=False, recon_weights=np.array([1.0, 1.0, 1.0, 1.0]), beta=1.0, - categories: Optional[List[float]] = None, ): """---> Add random one-hot encoding Carries out the validation/test step. 
@@ -327,25 +340,25 @@ def validation_step( M = test_batch[2] spatial_context = test_batch[3] batch_idx = test_batch[-1] - # print(batch_idx) test_loss = [] - n_classes = self.n_covariates # - n_other_cat + + if self.use_covs: + categories = test_batch[5] + n_classes = self.n_covariates - categories.shape[1] + else: + categories = torch.Tensor([]) + n_classes = self.n_covariates + for i in range(L_iter): - # print(n_classes) - # print(len(batch_idx)) - # print(np.eye(n_classes)[np.random.choice(n_classes, len(batch_idx))]) - one_hot = self.random_one_hot(n_classes=n_classes, n_samples=len(batch_idx)) - # print(one_hot.size()) - - if categories is not None: - if len(categories) > 0: - categories = torch.Tensor(categories)[batch_idx, :] - else: - categories = torch.Tensor(categories) - else: - categories = torch.Tensor([]) - cov_list = torch.cat([one_hot, categories], 1).float() + if self.batch_correct: + one_hot = self.random_one_hot( + n_classes=n_classes, n_samples=len(batch_idx) + ) + + cov_list = torch.cat([one_hot, categories], 1).float() + else: + cov_list = torch.Tensor([]) mu_z, log_std_z = self.encoder( Y, S, M, spatial_context, cov_list @@ -400,6 +413,10 @@ def validation_step( loss = self.loss(kl_div, recon_loss, beta=beta) + # batch = [Y,S,M,spatial_context,one_hot,batch_idx] + + # loss = self.training_step(batch)[0] + test_loss.append(loss) self.log( @@ -452,7 +469,7 @@ def inference( data, indices: Optional[Sequence[int]] = None, give_mean: bool = True, - categories: Optional[List[float]] = None, + # idx = None, ) -> np.ndarray: """ Return the latent representation of each cell. 
@@ -461,24 +478,34 @@ def inference( S = data.S M = data.M C = data.C - one_hot = data.samples_onehot - if one_hot.shape[1] < self.n_covariates: - zeros_pad = torch.Tensor( - np.zeros([one_hot.shape[0], self.n_covariates - one_hot.shape[1]]) - ) - one_hot = torch.cat([one_hot, zeros_pad], 1) + # batch_idx = idx + # print(batch_idx) + if self.use_covs: + categories = data.BKG + n_cats = categories.shape[1] else: - one_hot = one_hot - batch_idx = data[-1] - if categories is not None: - if len(categories) > 0: - categories = torch.Tensor(categories)[batch_idx, :] + categories = torch.Tensor([]) + n_cats = 0 + + if self.batch_correct: + one_hot = data.samples_onehot + if one_hot.shape[1] < self.n_covariates - n_cats: + zeros_pad = torch.Tensor( + np.zeros( + [ + one_hot.shape[0], + (self.n_covariates - n_cats) - one_hot.shape[1], + ] + ) + ) + one_hot = torch.cat([one_hot, zeros_pad], 1) else: - categories = torch.Tensor(categories) + one_hot = one_hot + + cov_list = torch.cat([one_hot, categories], 1).float() else: - categories = torch.Tensor([]) + cov_list = torch.Tensor([]) - cov_list = torch.cat([one_hot, categories], 1).float() if give_mean: mu_z, _ = self.encoder(Y, S, M, C, cov_list) From 8f985cbf5052675ef4c0bcfd942ab713a745c37d Mon Sep 17 00:00:00 2001 From: shanzaayub Date: Thu, 28 Jul 2022 14:26:14 -0400 Subject: [PATCH 07/18] latest code, sparse addition function in ScModeDataloader --- hmivae/ScModeDataloader.py | 30 +++++++------ hmivae/_hmivae_model.py | 88 ++++++++------------------------------ hmivae/_hmivae_module.py | 16 +------ 3 files changed, 38 insertions(+), 96 deletions(-) diff --git a/hmivae/ScModeDataloader.py b/hmivae/ScModeDataloader.py index 02f0545..c00a21e 100644 --- a/hmivae/ScModeDataloader.py +++ b/hmivae/ScModeDataloader.py @@ -22,6 +22,17 @@ def sparse_numpy_to_torch(adj_mat): return torch.sparse_coo_tensor(i, v, shape) +def get_n_cell_neighbours(adj_mat): + """Get the sum of a sparse matrix + Need to first replace all non-zero 
elements with 1 + Then add them up to get the number of neighbours + """ + adj_mat[adj_mat.nonzero()] = 1.0 + n_neighbours_sparse = adj_mat.sum(1) + + return np.asarray(n_neighbours_sparse) + + class ScModeDataloader(TensorDataset): def __init__(self, adata, scalers=None): """ @@ -59,11 +70,12 @@ def __init__(self, adata, scalers=None): self.samples_onehot = self.one_hot_encoding() if "background_covs" in adata.obsm.keys(): # dealing with background covariates - bkg = adata.obsm["background_covs"] + BKG = adata.obsm["background_covs"] if scalers is None: - self.scalers["BKG"] = StandardScaler().fit(bkg) + self.scalers["BKG"] = StandardScaler().fit(BKG) + BKG = self.scalers["BKG"].transform(BKG) else: - BKG = self.scalers["BKG"].transform(bkg) + BKG = self.scalers["BKG"].transform(BKG) self.BKG = torch.tensor(BKG).float() else: @@ -98,14 +110,9 @@ def get_spatial_context(self): ) # adjacency matrix concatenated_features = torch.cat((self.Y, self.S, self.M), 1) - n_cell_neighbours = self.adata.obsp[ - "connectivities" - ].toarray() # .sum(1).reshape([self.n_cells,1]) - n_cell_neighbours[np.where(n_cell_neighbours > 0)] = 1.0 - n_cell_neighbours = n_cell_neighbours.sum(1).reshape([self.n_cells, 1]) - n_cell_neighbours[np.where(n_cell_neighbours < 1.0)] = 1.0 - - # print('n_cell_neighbours', n_cell_neighbours) + n_cell_neighbours = get_n_cell_neighbours( + self.adata.copy().obsp["connectivities"] + ) unnormalized_C = torch.smm( adj_mat, concatenated_features @@ -114,7 +121,6 @@ def get_spatial_context(self): C = torch.div( unnormalized_C, torch.tensor(n_cell_neighbours) ) # normalize by number of adjacent cells - # print('sum C', C.sum()) return C def __getitem__(self, idx): diff --git a/hmivae/_hmivae_model.py b/hmivae/_hmivae_model.py index ee5bc32..3b104b4 100644 --- a/hmivae/_hmivae_model.py +++ b/hmivae/_hmivae_model.py @@ -17,16 +17,9 @@ import hmivae._hmivae_module as module import hmivae.ScModeDataloader as ScModeDataloader -# from scvi.data import 
setup_anndata -# from scvi.model._utils import _init_library_size -# from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin, VAEMixin -# from scvi.utils import setup_anndata_dsp - - logger = logging.getLogger(__name__) -# class hmivaeModel(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass): class hmivaeModel(pl.LightningModule): """ Skeleton for an scvi-tools model. @@ -66,7 +59,6 @@ def __init__( E_mr: int = 32, E_sc: int = 32, latent_dim: int = 10, - n_covariates: int = 0, use_covs: bool = False, n_hidden: int = 1, cofactor: float = 1.0, @@ -77,27 +69,17 @@ def __init__( super().__init__() self.adata = adata + self.use_covs = use_covs - if n_covariates > 0: - - if not use_covs: - self.use_covs = True - print( - "`use_covs` is automatically set to True when `n_covariates` > 0." - ) - else: - self.use_covs = use_covs - + if self.use_covs: self.keys = [] for key in adata.obsm.keys(): - # print(key) + if key not in ["correlations", "morphology", "xy"]: self.keys.append(key) - # print("n_keys", len(self.keys)) else: self.keys = None - self.use_covs = use_covs ( self.train_batch, @@ -113,18 +95,6 @@ def __init__( image_correct=batch_correct, ) - # for batch in self.train_batch: - # print('Y', batch[0]) - # print('S', batch[1]) - # print('M', batch[2]) - # print('C', batch[3]) - # print('one-hot', batch[4]) - # break - - # print("cov_list", self.cov_list.shape) - - # print('n_covs', self.n_covariates) - # self.summary_stats provides information about anndata dimensions and other tensor info self.module = module.hmiVAE( input_exp_dim=input_exp_dim, @@ -146,14 +116,14 @@ def __init__( self._model_summary_string = ( "hmiVAE model with the following parameters: \n n_latent:{}, " "n_protein_expression:{}, n_correlation:{}, n_morphology:{}, n_spatial_context:{}, " - "n_covariates:{} " + "use_covariates:{} " ).format( latent_dim, input_exp_dim, input_corr_dim, input_morph_dim, input_spcont_dim, - n_covariates, + use_covs, ) # necessary line to get params that 
will be used for saving/loading # self.init_params_ = self._get_init_params(locals()) @@ -194,15 +164,7 @@ def get_latent_representation( Gives the latent representation of each cell. """ if is_trained_model: - ( - adata_train, - adata_test, - data_train, - data_test, - # cat_list, - # train_idx, - # test_idx, - ) = self.setup_anndata( + (adata_train, adata_test, data_train, data_test,) = self.setup_anndata( self.adata, protein_correlations_obsm_key, cell_morphology_obsm_key, @@ -212,10 +174,8 @@ def get_latent_representation( image_correct=batch_correct, ) - adata_train.obsm["VAE"] = self.module.inference( - data_train - ) # idx=train_idx) - adata_test.obsm["VAE"] = self.module.inference(data_test) # idx=test_idx) + adata_train.obsm["VAE"] = self.module.inference(data_train) + adata_test.obsm["VAE"] = self.module.inference(data_test) return ad.concat([adata_train, adata_test], uns_merge="first") else: @@ -230,7 +190,6 @@ def setup_anndata( adata: AnnData, protein_correlations_obsm_key: str, cell_morphology_obsm_key: str, - # cell_spatial_context_obsm_key: str, protein_correlations_names_uns_key: Optional[str] = None, cell_morphology_names_uns_key: Optional[str] = None, image_correct: bool = True, @@ -273,19 +232,16 @@ def setup_anndata( if continuous_covariate_keys is not None: cat_list = [] for cat_key in continuous_covariate_keys: - # print(cat_key) - # print(f"{cat_key} shape:", adata.obsm[cat_key].shape) category = adata.obsm[cat_key] cat_list.append(category) - cat_list = np.arcsinh(np.concatenate(cat_list, 1) / cofactor) + cat_list = np.concatenate(cat_list, 1) n_cats = cat_list.shape[1] - if apply_winsorize: - for i in range(cat_list.shape[1]): - cat_list[:, i] = winsorize(cat_list[:, i], limits=[0, 0.01]) + # if apply_winsorize: + # for i in range(cat_list.shape[1]): + # cat_list[:, i] = winsorize(cat_list[:, i], limits=[0, 0.01]) adata.obsm["background_covs"] = cat_list else: - # cat_list = np.array([]) n_cats = 0 adata.X = np.arcsinh(adata.X / cofactor) 
@@ -310,36 +266,30 @@ def setup_anndata( samples_train, samples_test = train_test_split( samples_list, train_size=train_prop, random_state=random_seed ) - # samples_df = adata.obs.reset_index() + adata_train = adata.copy()[adata.obs["Sample_name"].isin(samples_train), :] - # train_idx = samples_df.loc[samples_df["Sample_name"].isin(samples_train),:].index + adata_test = adata.copy()[adata.obs["Sample_name"].isin(samples_test), :] - # test_idx = samples_df.loc[samples_df["Sample_name"].isin(samples_test),:].index data_train = ScModeDataloader.ScModeDataloader(adata_train) data_test = ScModeDataloader.ScModeDataloader(adata_test, data_train.scalers) loader_train = DataLoader(data_train, batch_size=batch_size, shuffle=True) - loader_test = DataLoader(data_test, batch_size=batch_size) # shuffle=True) + loader_test = DataLoader(data_test, batch_size=batch_size) if image_correct: n_samples = len(samples_train) - # print("n_samples", n_samples) - # print("cat_list", cat_list.shape) else: n_samples = 0 if is_trained_model: - # if continuous_covariate_keys is not None: + return ( adata_train, adata_test, data_train, data_test, - ) # cat_list, train_idx, test_idx - # else: - # cat_list = None - # return adata_train, adata_test, data_train, data_test, cat_list, train_idx, test_idx + ) else: @@ -347,6 +297,4 @@ def setup_anndata( loader_train, loader_test, n_samples + n_cats, - ) # cat_list - # else: - # return loader_train, loader_test, n_samples, cat_list + ) diff --git a/hmivae/_hmivae_module.py b/hmivae/_hmivae_module.py index ade37b9..c5f3784 100644 --- a/hmivae/_hmivae_module.py +++ b/hmivae/_hmivae_module.py @@ -30,7 +30,6 @@ def __init__( E_sc: int = 32, latent_dim: int = 10, n_covariates: int = 0, - # cat_list: Optional[List[float]] = None, use_covs: bool = False, n_hidden: int = 1, batch_correct: bool = True, @@ -39,8 +38,6 @@ def __init__( # hidden_dim = E_me + E_cr + E_mr + E_sc self.n_covariates = n_covariates - # self.cat_list = cat_list - self.batch_correct = 
batch_correct self.use_covs = use_covs @@ -241,7 +238,6 @@ def training_step( S = train_batch[1] M = train_batch[2] spatial_context = train_batch[3] - # batch_idx = train_batch[-1] if self.use_covs: categories = train_batch[5] @@ -360,9 +356,7 @@ def validation_step( else: cov_list = torch.Tensor([]) - mu_z, log_std_z = self.encoder( - Y, S, M, spatial_context, cov_list - ) # valid step + mu_z, log_std_z = self.encoder(Y, S, M, spatial_context, cov_list) z_samples = self.reparameterization(mu_z, log_std_z) @@ -413,10 +407,6 @@ def validation_step( loss = self.loss(kl_div, recon_loss, beta=beta) - # batch = [Y,S,M,spatial_context,one_hot,batch_idx] - - # loss = self.training_step(batch)[0] - test_loss.append(loss) self.log( @@ -478,8 +468,6 @@ def inference( S = data.S M = data.M C = data.C - # batch_idx = idx - # print(batch_idx) if self.use_covs: categories = data.BKG n_cats = categories.shape[1] @@ -522,5 +510,5 @@ def random_one_hot(self, n_classes: int, n_samples: int): Generates a random one hot encoded matrix. 
From: https://stackoverflow.com/questions/45093615/random-one-hot-matrix-in-numpy """ - # x = np.eye(n_classes) + return torch.Tensor(np.eye(n_classes)[np.random.choice(n_classes, n_samples)]) From 801ab51524391d8b6d4cb3edbb71a23b8ba7bdd7 Mon Sep 17 00:00:00 2001 From: shanzaayub Date: Mon, 24 Oct 2022 17:58:43 -0400 Subject: [PATCH 08/18] edits with beta and ablation tests --- hmivae/_hmivae_base_components.py | 272 +++++++++++++++++++++--------- hmivae/_hmivae_model.py | 78 ++++++++- hmivae/_hmivae_module.py | 248 +++++++++++++++++++++++---- 3 files changed, 470 insertions(+), 128 deletions(-) diff --git a/hmivae/_hmivae_base_components.py b/hmivae/_hmivae_base_components.py index 83a7b5e..ef19816 100644 --- a/hmivae/_hmivae_base_components.py +++ b/hmivae/_hmivae_base_components.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Literal, Optional, Union import torch import torch.nn as nn @@ -30,13 +30,17 @@ def __init__( E_mr: int, E_sc: int, latent_dim: int, + E_cov: Optional[int] = 10, n_covariates: Optional[int] = 0, n_hidden: Optional[int] = 1, + leave_out_view: Optional[ + Union[None, Literal["expression", "correlation", "morphology", "spatial"]] + ] = None, ): super().__init__() - hidden_dim = E_me + E_cr + E_mr + E_sc + n_covariates + hidden_dim = E_me + E_cr + E_mr + E_sc + E_cov - self.input_cov = nn.Linear(n_covariates, n_covariates) + self.input_cov = nn.Linear(n_covariates, E_cov) self.input_exp = nn.Linear(input_exp_dim, E_me) self.exp_hidden = nn.Linear(E_me, E_me) @@ -63,36 +67,76 @@ def forward( x_spatial_context: torch.Tensor, cov_list=torch.Tensor([]), ): - h_mean = F.elu(self.input_exp(x_mean)) - h_mean2 = F.elu(self.exp_hidden(h_mean)) - # print("h_mean2", h_mean2) + included_views = [] - h_correlations = F.elu(self.input_corr(x_correlations)) - h_correlations2 = F.elu(self.corr_hidden(h_correlations)) + if self.leave_out_view is None: - # print("h_correlations2", h_correlations2) + h_mean = F.elu(self.input_exp(x_mean)) + 
h_mean2 = F.elu(self.exp_hidden(h_mean)) - h_morphology = F.elu(self.input_morph(x_morphology)) - h_morphology2 = F.elu(self.morph_hidden(h_morphology)) + # print("h_mean2", h_mean2) - # print("h_morphology2", h_morphology2) + h_correlations = F.elu(self.input_corr(x_correlations)) + h_correlations2 = F.elu(self.corr_hidden(h_correlations)) - h_spatial_context = F.elu(self.input_spatial_context(x_spatial_context)) - h_spatial_context2 = F.elu(self.spatial_context_hidden(h_spatial_context)) + # print("h_correlations2", h_correlations2) - # print("h_spatial_context2", h_spatial_context2) + h_morphology = F.elu(self.input_morph(x_morphology)) + h_morphology2 = F.elu(self.morph_hidden(h_morphology)) - if cov_list.shape[0] > 1: - h_cov = F.elu(self.input_cov(cov_list)) + # print("h_morphology2", h_morphology2) - # print('h_cov', h_cov) + h_spatial_context = F.elu(self.input_spatial_context(x_spatial_context)) + h_spatial_context2 = F.elu(self.spatial_context_hidden(h_spatial_context)) + + # print("h_spatial_context2", h_spatial_context2) + + if cov_list.shape[0] > 1: + h_cov = F.elu(self.input_cov(cov_list)) + + # print('h_cov', h_cov) + else: + h_cov = cov_list + + h = torch.cat( + [h_mean2, h_correlations2, h_morphology2, h_spatial_context2, h_cov], 1 + ).type_as(x_mean) else: - h_cov = cov_list + if self.leave_out_view != "expression": + h_mean = F.elu(self.input_exp(x_mean)) + h_mean2 = F.elu(self.exp_hidden(h_mean)) - h = torch.cat( - [h_mean2, h_correlations2, h_morphology2, h_spatial_context2, h_cov], 1 - ) + included_views.append(h_mean2) + + if self.leave_out_view != "correlation": + h_correlations = F.elu(self.input_corr(x_correlations)) + h_correlations2 = F.elu(self.corr_hidden(h_correlations)) + + included_views.append(h_correlations2) + + if self.leave_out_view != "morphology": + h_morphology = F.elu(self.input_morph(x_morphology)) + h_morphology2 = F.elu(self.morph_hidden(h_morphology)) + + included_views.append(h_morphology2) + + if self.leave_out_view 
!= "spatial": + h_spatial_context = F.elu(self.input_spatial_context(x_spatial_context)) + h_spatial_context2 = F.elu( + self.spatial_context_hidden(h_spatial_context) + ) + + included_views.append(h_spatial_context2) + + if cov_list.shape[0] > 1: + h_cov = F.elu(self.input_cov(cov_list)) + included_views.append(h_cov) + else: + h_cov = cov_list + included_views.append(h_cov) + + h = torch.cat(included_views, 1) # .type_as(x_mean) for net in self.linear: h = F.elu(net(h)) @@ -130,11 +174,15 @@ def __init__( input_corr_dim: int, input_morph_dim: int, input_spcont_dim: int, + E_cov: Optional[int] = 10, n_covariates: Optional[int] = 0, n_hidden: Optional[int] = 1, + leave_out_view: Optional[ + Union[None, Literal["expression", "correlation", "morphology", "spatial"]] + ] = None, ): super().__init__() - hidden_dim = E_me + E_cr + E_mr + E_sc + n_covariates + hidden_dim = E_me + E_cr + E_mr + E_sc + E_cov latent_dim = latent_dim + n_covariates self.E_me = E_me self.E_cr = E_cr @@ -164,9 +212,6 @@ def __init__( self.mu_x_spcont = nn.Linear(E_sc, input_spcont_dim) self.std_x_spcont = nn.Linear(E_sc, input_spcont_dim) - # self.covariates_out_mu = nn.Linear(n_covariates, n_covariates) #this is one-hot - # self.covariates_out_std = nn.Linear(n_covariates, n_covariates) - def forward(self, z, cov_list): z_s = torch.cat( [z, cov_list], 1 @@ -175,64 +220,123 @@ def forward(self, z, cov_list): for net in self.linear: out = F.elu(net(out)) - h2_mean = F.elu(self.exp_hidden(out[:, 0 : self.E_me])) - h2_correlations = F.elu( - self.corr_hidden(out[:, self.E_me : self.E_me + self.E_cr]) - ) - h2_morphology = F.elu( - self.morph_hidden( - out[:, self.E_me + self.E_cr : self.E_me + self.E_cr + self.E_mr] + if self.leave_out_view is None: + + h2_mean = F.elu(self.exp_hidden(out[:, 0 : self.E_me])) + h2_correlations = F.elu( + self.corr_hidden(out[:, self.E_me : self.E_me + self.E_cr]) ) - ) - h2_spatial_context = F.elu( - self.spatial_context_hidden( - out[ - :, - self.E_me - + 
self.E_cr - + self.E_mr : self.E_me - + self.E_cr - + self.E_mr - + self.E_sc, - ] + h2_morphology = F.elu( + self.morph_hidden( + out[:, self.E_me + self.E_cr : self.E_me + self.E_cr + self.E_mr] + ) + ) + h2_spatial_context = F.elu( + self.spatial_context_hidden( + out[ + :, + self.E_me + + self.E_cr + + self.E_mr : self.E_me + + self.E_cr + + self.E_mr + + self.E_sc, + ] + ) ) - ) - # covariates = out[:, self.E_me + self.E_cr + self.E_mr + self.E_sc :] - - mu_x_exp = self.mu_x_exp(h2_mean) - std_x_exp = self.std_x_exp(h2_mean) - - # if self.use_weights: - # with torch.no_grad(): - # weights = self.get_corr_weights_per_cell( - # mu_x_exp.detach() - # ) # calculating correlation weights - # else: - # weights = None - - mu_x_corr = self.mu_x_corr(h2_correlations) - std_x_corr = self.std_x_corr(h2_correlations) - - mu_x_morph = self.mu_x_morph(h2_morphology) - std_x_morph = self.std_x_morph(h2_morphology) - - mu_x_spatial_context = self.mu_x_spcont(h2_spatial_context) - std_x_spatial_context = self.std_x_spcont(h2_spatial_context) - - # covariates_mu = self.covariates_out_mu(covariates) - # covariates_std = self.covariates_out_std(covariates) - - return ( - mu_x_exp, - std_x_exp, - mu_x_corr, - std_x_corr, - mu_x_morph, - std_x_morph, - mu_x_spatial_context, - std_x_spatial_context, - # covariates_mu, - # covariates_std, - # weights, - ) + # covariates = out[:, self.E_me + self.E_cr + self.E_mr + self.E_sc :] + + mu_x_exp = self.mu_x_exp(h2_mean) + std_x_exp = self.std_x_exp(h2_mean) + + # if self.use_weights: + # with torch.no_grad(): + # weights = self.get_corr_weights_per_cell( + # mu_x_exp.detach() + # ) # calculating correlation weights + # else: + # weights = None + + mu_x_corr = self.mu_x_corr(h2_correlations) + std_x_corr = self.std_x_corr(h2_correlations) + + mu_x_morph = self.mu_x_morph(h2_morphology) + std_x_morph = self.std_x_morph(h2_morphology) + + mu_x_spatial_context = self.mu_x_spcont(h2_spatial_context) + std_x_spatial_context = 
self.std_x_spcont(h2_spatial_context) + + # covariates_mu = self.covariates_out_mu(covariates) + # covariates_std = self.covariates_out_std(covariates) + + return ( + mu_x_exp, + std_x_exp, + mu_x_corr, + std_x_corr, + mu_x_morph, + std_x_morph, + mu_x_spatial_context, + std_x_spatial_context, + # covariates_mu, + # covariates_std, + # weights, + ) + + else: + included_views = [] + + if self.leave_out_view != "expression": + h2_mean = F.elu(self.exp_hidden(out[:, 0 : self.E_me])) + mu_x_exp = self.mu_x_exp(h2_mean) + std_x_exp = self.std_x_exp(h2_mean) + + included_views.append(mu_x_exp) + included_views.append(std_x_exp) + + if self.leave_out_view != "correlation": + h2_correlations = F.elu( + self.corr_hidden(out[:, self.E_me : self.E_me + self.E_cr]) + ) + mu_x_corr = self.mu_x_corr(h2_correlations) + std_x_corr = self.std_x_corr(h2_correlations) + + included_views.append(mu_x_corr) + included_views.append(std_x_corr) + + if self.leave_out_view != "morphology": + h2_morphology = F.elu( + self.morph_hidden( + out[ + :, self.E_me + self.E_cr : self.E_me + self.E_cr + self.E_mr + ] + ) + ) + mu_x_morph = self.mu_x_morph(h2_morphology) + std_x_morph = self.std_x_morph(h2_morphology) + + included_views.append(mu_x_morph) + included_views.append(std_x_morph) + + if self.leave_out_view != "spatial": + h2_spatial_context = F.elu( + self.spatial_context_hidden( + out[ + :, + self.E_me + + self.E_cr + + self.E_mr : self.E_me + + self.E_cr + + self.E_mr + + self.E_sc, + ] + ) + ) + mu_x_spatial_context = self.mu_x_spcont(h2_spatial_context) + std_x_spatial_context = self.std_x_spcont(h2_spatial_context) + + included_views.append(mu_x_spatial_context) + included_views.append(std_x_spatial_context) + + return included_views diff --git a/hmivae/_hmivae_model.py b/hmivae/_hmivae_model.py index 3b104b4..0342cc3 100644 --- a/hmivae/_hmivae_model.py +++ b/hmivae/_hmivae_model.py @@ -1,12 +1,13 @@ import logging -from typing import List, Optional +from typing import List, Literal, 
Optional import anndata as ad import numpy as np import pytorch_lightning as pl import torch from anndata import AnnData -from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.callbacks.progress import RichProgressBar from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.trainer import Trainer from scipy.stats.mstats import winsorize @@ -63,19 +64,25 @@ def __init__( n_hidden: int = 1, cofactor: float = 1.0, batch_correct: bool = True, + leave_out_view: Optional[ + Literal["expression", "correlation", "morphology", "spatial"] + ] = None, + output_dir: str = ".", **model_kwargs, ): # super(hmivaeModel, self).__init__(adata) super().__init__() + self.output_dir = output_dir self.adata = adata self.use_covs = use_covs + self.leave_out_view = leave_out_view if self.use_covs: self.keys = [] for key in adata.obsm.keys(): - if key not in ["correlations", "morphology", "xy"]: + if key not in ["correlations", "morphology", "spatial"]: self.keys.append(key) else: @@ -85,6 +92,7 @@ def __init__( self.train_batch, self.test_batch, self.n_covariates, + self.features_config, # self.cov_list, ) = self.setup_anndata( adata=self.adata, @@ -133,19 +141,61 @@ def __init__( def train( self, max_epochs=100, - check_val_every_n_epoch=5, + check_val_every_n_epoch=1, ): # misnomer, both train and test are here (either rename or separate) early_stopping = EarlyStopping(monitor="test_loss", mode="min", patience=2) - wandb_logger = WandbLogger(project="hmiVAE_init_trial_runs") + cb_chkpt = ModelCheckpoint( + dirpath=f"{self.output_dir}", + monitor="test_loss", + mode="min", + save_top_k=1, + filename="{epoch}_{step}_{test_loss:.3f}", + ) + + cb_progress = RichProgressBar() + + if self.leave_out_view is None: + + wandb_logger = WandbLogger( + project="hmiVAE_init_trial_runs", + config={ + "Expression min/max": (self.adata.X.min(), self.adata.X.max()), + 
"Correlation min/max": ( + self.adata.obsm["correlations"].min(), + self.adata.obsm["correlations"].max(), + ), + "Morphology min/max": ( + self.adata.obsm["morphology"].min(), + self.adata.obsm["morphology"].max(), + ), + }, + ) + else: + wandb_logger = WandbLogger( + project="hmivae_ablation", + config={ + "Expression min/max": (self.adata.X.min(), self.adata.X.max()), + "Correlation min/max": ( + self.adata.obsm["correlations"].min(), + self.adata.obsm["correlations"].max(), + ), + "Morphology min/max": ( + self.adata.obsm["morphology"].min(), + self.adata.obsm["morphology"].max(), + ), + }, + ) trainer = Trainer( max_epochs=max_epochs, check_val_every_n_epoch=check_val_every_n_epoch, - callbacks=[early_stopping], + callbacks=[early_stopping, cb_progress, cb_chkpt], logger=wandb_logger, - # gradient_clip_val=2.0, + gradient_clip_val=2.0, + accelerator="auto", + devices="auto", ) trainer.fit(self.module, self.train_batch, self.test_batch) @@ -193,7 +243,7 @@ def setup_anndata( protein_correlations_names_uns_key: Optional[str] = None, cell_morphology_names_uns_key: Optional[str] = None, image_correct: bool = True, - batch_size: Optional[int] = 128, + batch_size: Optional[int] = 32, batch_key: Optional[str] = None, labels_key: Optional[str] = None, layer: Optional[str] = None, @@ -274,6 +324,17 @@ def setup_anndata( data_train = ScModeDataloader.ScModeDataloader(adata_train) data_test = ScModeDataloader.ScModeDataloader(adata_test, data_train.scalers) + features_ranges = { + "Train expression min/max": (data_train.Y.min(), data_train.Y.max()), + "Train correlation min/max": (data_train.S.min(), data_train.S.max()), + "Train morphology min/max": (data_train.M.min(), data_train.M.max()), + "Train spatial context min/max": (data_train.C.min(), data_train.C.max()), + "Test expression min/max": (data_test.Y.min(), data_test.Y.max()), + "Test correlation min/max": (data_test.S.min(), data_test.S.max()), + "Test morphology min/max": (data_test.M.min(), data_test.M.max()), 
+ "Test spatial context min/max": (data_test.C.min(), data_test.C.max()), + } + loader_train = DataLoader(data_train, batch_size=batch_size, shuffle=True) loader_test = DataLoader(data_test, batch_size=batch_size) @@ -297,4 +358,5 @@ def setup_anndata( loader_train, loader_test, n_samples + n_cats, + features_ranges, ) diff --git a/hmivae/_hmivae_module.py b/hmivae/_hmivae_module.py index c5f3784..2b950e4 100644 --- a/hmivae/_hmivae_module.py +++ b/hmivae/_hmivae_module.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence # List +from typing import Literal, Optional, Sequence, Union import numpy as np import pytorch_lightning as pl @@ -29,19 +29,30 @@ def __init__( E_mr: int = 32, E_sc: int = 32, latent_dim: int = 10, + E_cov: int = 10, n_covariates: int = 0, + leave_out_view: Optional[ + Union[None, Literal["expression", "correlation", "morphology", "spatial"]] + ] = None, use_covs: bool = False, n_hidden: int = 1, batch_correct: bool = True, + n_steps_kl_warmup: Union[int, None] = None, + n_epochs_kl_warmup: Union[int, None] = 100, ): super().__init__() # hidden_dim = E_me + E_cr + E_mr + E_sc + + self.n_steps_kl_warmup = n_steps_kl_warmup + self.n_epochs_kl_warmup = n_epochs_kl_warmup self.n_covariates = n_covariates self.batch_correct = batch_correct self.use_covs = use_covs + self.leave_out_view = leave_out_view + self.encoder = EncoderHMIVAE( input_exp_dim, input_corr_dim, @@ -52,7 +63,9 @@ def __init__( E_mr, E_sc, latent_dim, + E_cov, n_covariates=n_covariates, + leave_out_view=leave_out_view, ) self.decoder = DecoderHMIVAE( @@ -65,7 +78,9 @@ def __init__( input_corr_dim, input_morph_dim, input_spcont_dim, + E_cov, n_covariates=n_covariates, + leave_out_view=leave_out_view, ) self.save_hyperparameters(ignore=["adata", "cat_list"]) @@ -97,6 +112,36 @@ def KL_div(self, enc_x_mu, enc_x_logstd, z): return kl + def compute_kl_weight( + self, + epoch: int, + step: Optional[int], + n_epochs_kl_warmup: Optional[int], + n_steps_kl_warmup: Optional[int], + 
max_kl_weight: float = 1.0, + min_kl_weight: float = 0.0, + ) -> float: + """ + Compute the weight for the KL-Div term in loss. + Taken from scVI: + https://github.com/scverse/scvi-tools/blob/2c22bda9bcfb5a89d62c96c4ad39d8a1e297eb08/scvi/train/_trainingplans.py#L31 + """ + slope = max_kl_weight - min_kl_weight + + if min_kl_weight > max_kl_weight: + raise ValueError( + f"min_kl_weight={min_kl_weight} is larger than max_kl_weight={max_kl_weight}" + ) + + if n_epochs_kl_warmup: + if epoch < n_epochs_kl_warmup: + return slope * (epoch / n_epochs_kl_warmup) + min_kl_weight + elif n_steps_kl_warmup: + if step < n_steps_kl_warmup: + return slope * (step / n_steps_kl_warmup) + min_kl_weight + + return max_kl_weight + def em_recon_loss( self, dec_x_mu_exp, @@ -223,8 +268,6 @@ def training_step( self, train_batch, corr_weights=False, - recon_weights=np.array([1.0, 1.0, 1.0, 1.0]), - beta=1.0, ): """ Carries out the training step. @@ -242,14 +285,14 @@ def training_step( if self.use_covs: categories = train_batch[5] else: - categories = torch.Tensor([]) + categories = torch.Tensor([]).type_as(Y) if self.batch_correct: one_hot = train_batch[4] - cov_list = torch.cat([one_hot, categories], 1).float() + cov_list = torch.cat([one_hot, categories], 1).float().type_as(Y) else: - cov_list = torch.Tensor([]) + cov_list = torch.Tensor([]).type_as(Y) mu_z, log_std_z = self.encoder(Y, S, M, spatial_context, cov_list) @@ -293,6 +336,25 @@ def training_step( # weights, ) + beta = self.compute_kl_weight( + self.current_epoch, + self.global_step, + self.n_epochs_kl_warmup, + self.n_steps_kl_warmup, + ) + + if self.leave_out_view is not None: + if self.leave_out_view == "expression": + recon_weights = np.array([0.0, 1.0, 1.0, 1.0]) + if self.leave_out_view == "correlation": + recon_weights = np.array([1.0, 0.0, 1.0, 1.0]) + if self.leave_out_view == "morphology": + recon_weights = np.array([1.0, 1.0, 0.0, 1.0]) + if self.leave_out_view == "spatial": + recon_weights = np.array([1.0, 1.0, 
1.0, 0.0]) + else: + recon_weights = np.array([1.0, 1.0, 1.0, 1.0]) + recon_loss = ( recon_weights[0] * recon_lik_me + recon_weights[1] * recon_lik_corr @@ -303,6 +365,45 @@ def training_step( loss = self.loss(kl_div, recon_loss, beta=beta) self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True) + self.log("beta", beta, on_step=False, on_epoch=True, prog_bar=False) + self.log( + "kl_div", kl_div.mean().item(), on_step=True, on_epoch=True, prog_bar=False + ) + self.log( + "recon_loss", + recon_loss.mean().item(), + on_step=True, + on_epoch=True, + prog_bar=False, + ) + self.log( + "recon_lik_me", + recon_lik_me.mean().item(), + on_step=True, + on_epoch=True, + prog_bar=False, + ) + self.log( + "recon_lik_corr", + recon_lik_corr.mean().item(), + on_step=True, + on_epoch=True, + prog_bar=False, + ) + self.log( + "recon_lik_mor", + recon_lik_mor.mean().item(), + on_step=True, + on_epoch=True, + prog_bar=False, + ) + self.log( + "recon_lik_sc", + recon_lik_sc.mean().item(), + on_step=True, + on_epoch=True, + prog_bar=False, + ) return { "loss": loss, @@ -318,10 +419,8 @@ def validation_step( self, test_batch, n_other_cat: int = 0, - L_iter: int = 100, + L_iter: int = 300, corr_weights=False, - recon_weights=np.array([1.0, 1.0, 1.0, 1.0]), - beta=1.0, ): """---> Add random one-hot encoding Carries out the validation/test step. 
@@ -336,7 +435,7 @@ def validation_step( M = test_batch[2] spatial_context = test_batch[3] batch_idx = test_batch[-1] - test_loss = [] + # test_loss = [] if self.use_covs: categories = test_batch[5] @@ -345,16 +444,27 @@ def validation_step( categories = torch.Tensor([]) n_classes = self.n_covariates - for i in range(L_iter): + test_loss = torch.empty(size=[len(batch_idx), n_classes]) + + # for i in range(L_iter): + for i in range(n_classes): if self.batch_correct: - one_hot = self.random_one_hot( - n_classes=n_classes, n_samples=len(batch_idx) - ) + # one_hot = self.random_one_hot( + # n_classes=n_classes, n_samples=len(batch_idx) + # ).type_as(Y) + + one_hot_zeros = torch.zeros(size=[1, n_classes]) + + one_hot_zeros[0, i] = 1.0 + + one_hot = one_hot_zeros.repeat((len(batch_idx), 1)).type_as(Y) - cov_list = torch.cat([one_hot, categories], 1).float() + cov_list = torch.cat([one_hot, categories], 1).float().type_as(Y) + elif self.use_covs: + cov_list = categories else: - cov_list = torch.Tensor([]) + cov_list = torch.Tensor([]).type_as(Y) mu_z, log_std_z = self.encoder(Y, S, M, spatial_context, cov_list) @@ -398,6 +508,25 @@ def validation_step( # weights, ) + beta = self.compute_kl_weight( + self.current_epoch, + self.global_step, + self.n_epochs_kl_warmup, + self.n_steps_kl_warmup, + ) + + if self.leave_out_view is not None: + if self.leave_out_view == "expression": + recon_weights = np.array([0.0, 1.0, 1.0, 1.0]) + if self.leave_out_view == "correlation": + recon_weights = np.array([1.0, 0.0, 1.0, 1.0]) + if self.leave_out_view == "morphology": + recon_weights = np.array([1.0, 1.0, 0.0, 1.0]) + if self.leave_out_view == "spatial": + recon_weights = np.array([1.0, 1.0, 1.0, 0.0]) + else: + recon_weights = np.array([1.0, 1.0, 1.0, 1.0]) + recon_loss = ( recon_weights[0] * recon_lik_me + recon_weights[1] * recon_lik_corr @@ -407,18 +536,61 @@ def validation_step( loss = self.loss(kl_div, recon_loss, beta=beta) - test_loss.append(loss) + test_loss[:, i] = loss 
self.log( "test_loss", - sum(test_loss) / L_iter, + # sum(test_loss) / L_iter, + test_loss.mean(1).sum(), on_step=True, on_epoch=True, prog_bar=True, ) # log the average test loss over all the iterations + self.log( + "kl_div_test", + kl_div.mean().item(), + on_step=True, + on_epoch=True, + prog_bar=False, + ) + self.log( + "recon_loss_test", + recon_loss.mean().item(), + on_step=True, + on_epoch=True, + prog_bar=False, + ) + self.log( + "recon_lik_me_test", + recon_lik_me.mean().item(), + on_step=True, + on_epoch=True, + prog_bar=False, + ) + self.log( + "recon_lik_corr_test", + recon_lik_corr.mean().item(), + on_step=True, + on_epoch=True, + prog_bar=False, + ) + self.log( + "recon_lik_mor_test", + recon_lik_mor.mean().item(), + on_step=True, + on_epoch=True, + prog_bar=False, + ) + self.log( + "recon_lik_sc_test", + recon_lik_sc.mean().item(), + on_step=True, + on_epoch=True, + prog_bar=False, + ) return { - "loss": sum(test_loss) / L_iter, + "loss": test_loss.mean(1).sum(), "kl_div": kl_div.mean().item(), "recon_loss": recon_loss.mean().item(), "recon_lik_me": recon_lik_me.mean().item(), @@ -429,7 +601,8 @@ def validation_step( def configure_optimizers(self): """Optimizer""" - optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) + parameters = list(self.encoder.parameters()) + list(self.decoder.parameters()) + optimizer = torch.optim.Adam(parameters, lr=1e-3) return optimizer @torch.no_grad() @@ -470,25 +643,28 @@ def inference( C = data.C if self.use_covs: categories = data.BKG - n_cats = categories.shape[1] + n_classes = self.n_covariates - categories.shape[1] else: - categories = torch.Tensor([]) - n_cats = 0 + categories = torch.Tensor([]).type_as(Y) + n_classes = self.n_covariates if self.batch_correct: - one_hot = data.samples_onehot - if one_hot.shape[1] < self.n_covariates - n_cats: - zeros_pad = torch.Tensor( - np.zeros( - [ - one_hot.shape[0], - (self.n_covariates - n_cats) - one_hot.shape[1], - ] - ) - ) - one_hot = torch.cat([one_hot, 
zeros_pad], 1) - else: - one_hot = one_hot + one_hot = self.random_one_hot( + n_classes=n_classes, n_samples=Y.shape[0] + ).type_as(Y) + # one_hot = data.samples_onehot + # if one_hot.shape[1] < self.n_covariates - n_cats: + # zeros_pad = torch.Tensor( + # np.zeros( + # [ + # one_hot.shape[0], + # (self.n_covariates - n_cats) - one_hot.shape[1], + # ] + # ) + # ) + # one_hot = torch.cat([one_hot, zeros_pad], 1) + # else: + # one_hot = one_hot cov_list = torch.cat([one_hot, categories], 1).float() else: From e451de23b5d39bb98d4ce838af686cd88644921d Mon Sep 17 00:00:00 2001 From: shanzaayub Date: Wed, 22 Feb 2023 14:22:06 -0500 Subject: [PATCH 09/18] latest model script --- hmivae/ScModeDataloader.py | 6 ++ hmivae/_hmivae_model.py | 210 +++++++++++++++++++++++++++---------- hmivae/_hmivae_module.py | 50 +++++---- 3 files changed, 190 insertions(+), 76 deletions(-) diff --git a/hmivae/ScModeDataloader.py b/hmivae/ScModeDataloader.py index c00a21e..ddfa9c1 100644 --- a/hmivae/ScModeDataloader.py +++ b/hmivae/ScModeDataloader.py @@ -46,6 +46,7 @@ def __init__(self, adata, scalers=None): Y = adata.X # per cell protein mean expression S = adata.obsm["correlations"] M = adata.obsm["morphology"] + weights = adata.obsm["weights"] self.n_cells = Y.shape[0] # number of cells @@ -66,6 +67,9 @@ def __init__(self, adata, scalers=None): self.S = torch.tensor(S).float() self.M = torch.tensor(M).float() self.C = self.get_spatial_context() + self.weights = torch.tensor( + weights + ).float() # these don't need to be scaled, not a data input self.samples_onehot = self.one_hot_encoding() @@ -132,6 +136,7 @@ def __getitem__(self, idx): self.M[idx, :], self.C[idx, :], self.samples_onehot[idx, :], + self.weights[idx, :], idx, ) else: @@ -141,6 +146,7 @@ def __getitem__(self, idx): self.M[idx, :], self.C[idx, :], self.samples_onehot[idx, :], + self.weights[idx, :], self.BKG[idx, :], idx, ) diff --git a/hmivae/_hmivae_model.py b/hmivae/_hmivae_model.py index 0342cc3..89dc01c 100644 --- 
a/hmivae/_hmivae_model.py +++ b/hmivae/_hmivae_model.py @@ -1,5 +1,6 @@ +import inspect import logging -from typing import List, Literal, Optional +from typing import List, Literal, Optional, Union import anndata as ad import numpy as np @@ -25,8 +26,6 @@ class hmivaeModel(pl.LightningModule): """ Skeleton for an scvi-tools model. - Please use this skeleton to create new models. - Parameters ---------- adata @@ -59,13 +58,21 @@ def __init__( E_cr: int = 32, E_mr: int = 32, E_sc: int = 32, + E_cov: int = 10, latent_dim: int = 10, use_covs: bool = False, + use_weights: bool = True, + n_covariates: Optional[Union[None, int]] = None, + cohort: Optional[Union[None, str]] = None, n_hidden: int = 1, cofactor: float = 1.0, + beta_scheme: Optional[Literal["constant", "warmup"]] = "warmup", batch_correct: bool = True, + is_trained_model: bool = False, + batch_size: Optional[int] = 1234, + random_seed: Optional[int] = 1234, leave_out_view: Optional[ - Literal["expression", "correlation", "morphology", "spatial"] + Union[None, Literal["expression", "correlation", "morphology", "spatial"]] ] = None, output_dir: str = ".", **model_kwargs, @@ -74,35 +81,78 @@ def __init__( super().__init__() self.output_dir = output_dir - self.adata = adata + # self.adata = adata self.use_covs = use_covs + self.use_weights = use_weights self.leave_out_view = leave_out_view + self.is_trained_model = is_trained_model + self.random_seed = random_seed + self.name = f"{cohort}_rs{random_seed}_nh{n_hidden}_bs{batch_size}_hd{E_me}_ls{latent_dim}" if self.use_covs: self.keys = [] for key in adata.obsm.keys(): - - if key not in ["correlations", "morphology", "spatial"]: + # print(key) + if key not in ["correlations", "morphology", "spatial", "xy"]: self.keys.append(key) + if n_covariates is None: + raise ValueError("`n_covariates` cannot be None when `use_covs`==True") + else: + n_covariates = n_covariates + + # print("n_keys", len(self.keys)) else: self.keys = None + if n_covariates is None: + 
n_covariates = 0 + else: + n_covariates = 0 + print("`n_covariates` automatically set to 0 when use_covs == False") ( self.train_batch, self.test_batch, - self.n_covariates, + n_samples, self.features_config, # self.cov_list, ) = self.setup_anndata( - adata=self.adata, + adata=adata, protein_correlations_obsm_key="correlations", cell_morphology_obsm_key="morphology", continuous_covariate_keys=self.keys, cofactor=cofactor, image_correct=batch_correct, + batch_size=batch_size, + random_seed=random_seed, ) + n_covariates += n_samples + + print("n_covs", n_covariates) + + # for batch in self.train_batch: + # print('Y', torch.mean(batch[0],1)) + # print('S', torch.mean(batch[1],1)) + # print('M', torch.mean(batch[2],1)) + # print('C', torch.mean(batch[3],1)) + # print('one-hot', batch[4]) + # print('covariates', torch.mean(batch[5],1)) + # break + + # for batch in self.test_batch: + # print('Y_test', torch.mean(batch[0],1)) + # print('S_test', torch.mean(batch[1],1)) + # print('M_test', torch.mean(batch[2],1)) + # print('C_test', torch.mean(batch[3],1)) + # print('one-hot_test', batch[4]) + # print('covariates_test', torch.mean(batch[5],1)) + # break + + # print("cov_list", self.cov_list.shape) + + # print("self.adata", self.adata.X) + # self.summary_stats provides information about anndata dimensions and other tensor info self.module = module.hmiVAE( input_exp_dim=input_exp_dim, @@ -113,12 +163,15 @@ def __init__( E_cr=E_cr, E_mr=E_mr, E_sc=E_sc, + E_cov=E_cov, latent_dim=latent_dim, - n_covariates=self.n_covariates, + n_covariates=n_covariates, n_hidden=n_hidden, use_covs=self.use_covs, - # cat_list=self.cov_list, + use_weights=self.use_weights, + beta_scheme=beta_scheme, batch_correct=batch_correct, + leave_out_view=leave_out_view, **model_kwargs, ) self._model_summary_string = ( @@ -134,58 +187,67 @@ def __init__( use_covs, ) # necessary line to get params that will be used for saving/loading - # self.init_params_ = self._get_init_params(locals()) + 
self.init_params_ = self._get_init_params(locals()) logger.info("The model has been initialized") + def _get_init_params(self, locals): + """ + Taken from: https://github.com/scverse/scvi-tools/blob/main/scvi/model/base/_base_model.py + """ + + init = self.__init__ + sig = inspect.signature(init) + parameters = sig.parameters.values() + + init_params = [p.name for p in parameters] + all_params = {p: locals[p] for p in locals if p in init_params} + + non_var_params = [p.name for p in parameters if p.kind != p.VAR_KEYWORD] + non_var_params = {k: v for (k, v) in all_params.items() if k in non_var_params} + var_params = [p.name for p in parameters if p.kind == p.VAR_KEYWORD] + var_params = {k: v for (k, v) in all_params.items() if k in var_params} + + user_params = {"kwargs": var_params, "non_kwargs": non_var_params} + + return user_params + def train( self, - max_epochs=100, + max_epochs=15, check_val_every_n_epoch=1, - ): # misnomer, both train and test are here (either rename or separate) + config=None, + ): # misnomer, both train and test/val are here (either rename or separate) + + # with wandb.init(config=config): + # config=wandb.config - early_stopping = EarlyStopping(monitor="test_loss", mode="min", patience=2) + pl.seed_everything(self.random_seed) + + early_stopping = EarlyStopping(monitor="recon_lik_test", mode="max", patience=1) cb_chkpt = ModelCheckpoint( dirpath=f"{self.output_dir}", - monitor="test_loss", - mode="min", + monitor="recon_lik_test", + mode="max", save_top_k=1, - filename="{epoch}_{step}_{test_loss:.3f}", + filename="{epoch}_{step}_{recon_lik_test:.3f}", ) cb_progress = RichProgressBar() + # wandb.finish() + # wandb_logger = WandbLogger(log_model="all") if self.leave_out_view is None: wandb_logger = WandbLogger( - project="hmiVAE_init_trial_runs", - config={ - "Expression min/max": (self.adata.X.min(), self.adata.X.max()), - "Correlation min/max": ( - self.adata.obsm["correlations"].min(), - self.adata.obsm["correlations"].max(), - ), - 
"Morphology min/max": ( - self.adata.obsm["morphology"].min(), - self.adata.obsm["morphology"].max(), - ), - }, + project="hmivae_hyperparameter_runs", + name=self.name, + config=self.features_config, ) else: wandb_logger = WandbLogger( - project="hmivae_ablation", - config={ - "Expression min/max": (self.adata.X.min(), self.adata.X.max()), - "Correlation min/max": ( - self.adata.obsm["correlations"].min(), - self.adata.obsm["correlations"].max(), - ), - "Morphology min/max": ( - self.adata.obsm["morphology"].min(), - self.adata.obsm["morphology"].max(), - ), - }, + project="hmivae_ablation", config=self.features_config ) trainer = Trainer( @@ -193,39 +255,66 @@ def train( check_val_every_n_epoch=check_val_every_n_epoch, callbacks=[early_stopping, cb_progress, cb_chkpt], logger=wandb_logger, + # overfit_batches=0.01, gradient_clip_val=2.0, accelerator="auto", devices="auto", + log_every_n_steps=1, + # limit_train_batches=0.1, + # limit_val_batches=0.1, ) trainer.fit(self.module, self.train_batch, self.test_batch) + # wandb.finish() + @torch.no_grad() def get_latent_representation( self, + adata: AnnData, protein_correlations_obsm_key: str, cell_morphology_obsm_key: str, - continuous_covariate_keys: Optional[List[str]] = None, # default is self.keys + continuous_covariate_keys: Optional[List[str]] = None, cofactor: float = 1.0, is_trained_model: Optional[bool] = False, batch_correct: Optional[bool] = True, + use_covs: Optional[bool] = True, ) -> AnnData: """ Gives the latent representation of each cell. 
""" if is_trained_model: - (adata_train, adata_test, data_train, data_test,) = self.setup_anndata( - self.adata, + ( + adata_train, + adata_test, + data_train, + data_test, + n_covariates, + # cat_list, + # train_idx, + # test_idx, + ) = self.setup_anndata( + adata, protein_correlations_obsm_key, cell_morphology_obsm_key, - continuous_covariate_keys=self.keys, + continuous_covariate_keys=continuous_covariate_keys, cofactor=cofactor, is_trained_model=is_trained_model, image_correct=batch_correct, ) - adata_train.obsm["VAE"] = self.module.inference(data_train) - adata_test.obsm["VAE"] = self.module.inference(data_test) + adata_train.obsm["VAE"] = self.module.inference( + data_train, + n_covariates=n_covariates, + use_covs=use_covs, + batch_correct=batch_correct, + ) # idx=train_idx) + adata_test.obsm["VAE"] = self.module.inference( + data_test, + n_covariates=n_covariates, + use_covs=use_covs, + batch_correct=batch_correct, + ) # idx=test_idx) return ad.concat([adata_train, adata_test], uns_merge="first") else: @@ -240,10 +329,11 @@ def setup_anndata( adata: AnnData, protein_correlations_obsm_key: str, cell_morphology_obsm_key: str, + # cell_spatial_context_obsm_key: str, protein_correlations_names_uns_key: Optional[str] = None, cell_morphology_names_uns_key: Optional[str] = None, image_correct: bool = True, - batch_size: Optional[int] = 32, + batch_size: Optional[int] = 4321, batch_key: Optional[str] = None, labels_key: Optional[str] = None, layer: Optional[str] = None, @@ -279,9 +369,13 @@ def setup_anndata( N_PROTEINS = adata.shape[1] N_MORPHOLOGY = len(adata.uns["names_morphology"]) + # print("adata in setup_adata", adata.X) + if continuous_covariate_keys is not None: cat_list = [] for cat_key in continuous_covariate_keys: + # print(cat_key) + # print(f"{cat_key} shape:", adata.obsm[cat_key].shape) category = adata.obsm[cat_key] cat_list.append(category) cat_list = np.concatenate(cat_list, 1) @@ -316,9 +410,7 @@ def setup_anndata( samples_train, samples_test = 
train_test_split( samples_list, train_size=train_prop, random_state=random_seed ) - adata_train = adata.copy()[adata.obs["Sample_name"].isin(samples_train), :] - adata_test = adata.copy()[adata.obs["Sample_name"].isin(samples_test), :] data_train = ScModeDataloader.ScModeDataloader(adata_train) @@ -335,21 +427,29 @@ def setup_anndata( "Test spatial context min/max": (data_test.C.min(), data_test.C.max()), } - loader_train = DataLoader(data_train, batch_size=batch_size, shuffle=True) - loader_test = DataLoader(data_test, batch_size=batch_size) + loader_train = DataLoader( + data_train, batch_size=batch_size, shuffle=True, num_workers=64 + ) + loader_test = DataLoader( + data_test, batch_size=batch_size, num_workers=64 + ) # shuffle=True) if image_correct: n_samples = len(samples_train) + # print("n_samples", n_samples) + # print("cat_list", cat_list.shape) + print("one-hot+covs", n_samples + n_cats) else: n_samples = 0 + print("n_cats", n_cats) if is_trained_model: - return ( adata_train, adata_test, data_train, data_test, + n_cats + n_samples, ) else: @@ -357,6 +457,6 @@ def setup_anndata( return ( loader_train, loader_test, - n_samples + n_cats, + n_samples, features_ranges, ) diff --git a/hmivae/_hmivae_module.py b/hmivae/_hmivae_module.py index 2b950e4..2a164fb 100644 --- a/hmivae/_hmivae_module.py +++ b/hmivae/_hmivae_module.py @@ -35,6 +35,7 @@ def __init__( Union[None, Literal["expression", "correlation", "morphology", "spatial"]] ] = None, use_covs: bool = False, + use_weights: bool = True, n_hidden: int = 1, batch_correct: bool = True, n_steps_kl_warmup: Union[int, None] = None, @@ -51,6 +52,8 @@ def __init__( self.use_covs = use_covs + self.use_weights = use_weights + self.leave_out_view = leave_out_view self.encoder = EncoderHMIVAE( @@ -156,7 +159,7 @@ def em_recon_loss( s, m, c, - # weights=None, + weights: Optional[Union[None, torch.tensor]] = None, ): """Takes in the parameters output from the decoder, and the original input x, and gives the 
reconstruction @@ -188,16 +191,16 @@ def em_recon_loss( ) log_p_xz_exp = p_rec_exp.log_prob(y) - log_p_xz_corr = p_rec_corr.log_prob(s) + # log_p_xz_corr = p_rec_corr.log_prob(s) log_p_xz_morph = p_rec_morph.log_prob(m) log_p_xz_spcont = p_rec_spcont.log_prob(c) # already dense matrix - # if weights is None: - # log_p_xz_corr = p_rec_corr.log_prob(s) - # else: - # log_p_xz_corr = torch.mul( - # weights, p_rec_corr.log_prob(s) - # ) # does element-wise multiplication + if weights is None: + log_p_xz_corr = p_rec_corr.log_prob(s) + else: + log_p_xz_corr = torch.mul( + weights, p_rec_corr.log_prob(s) + ) # does element-wise multiplication log_p_xz_exp = log_p_xz_exp.sum(-1) log_p_xz_corr = log_p_xz_corr.sum(-1) @@ -228,7 +231,7 @@ def neg_ELBO( s, m, c, - # weights=None, + weights: Optional[Union[None, torch.tensor]] = None, ): kl_div = self.KL_div(enc_x_mu, enc_x_logstd, z) @@ -250,7 +253,7 @@ def neg_ELBO( s, m, c, - # weights, + weights, ) return ( kl_div, @@ -267,7 +270,6 @@ def loss(self, kl_div, recon_loss, beta: float = 1.0): def training_step( self, train_batch, - corr_weights=False, ): """ Carries out the training step. @@ -282,8 +284,13 @@ def training_step( M = train_batch[2] spatial_context = train_batch[3] + if self.use_weights: + weights = train_batch[5] + else: + weights = None + if self.use_covs: - categories = train_batch[5] + categories = train_batch[6] else: categories = torch.Tensor([]).type_as(Y) @@ -308,7 +315,6 @@ def training_step( log_std_x_morph_hat, mu_x_spcont_hat, log_std_x_spcont_hat, - # weights, ) = self.decoder(z_samples, cov_list) ( @@ -333,7 +339,7 @@ def training_step( S, M, spatial_context, - # weights, + weights, ) beta = self.compute_kl_weight( @@ -420,7 +426,6 @@ def validation_step( test_batch, n_other_cat: int = 0, L_iter: int = 300, - corr_weights=False, ): """---> Add random one-hot encoding Carries out the validation/test step. 
@@ -435,10 +440,14 @@ def validation_step( M = test_batch[2] spatial_context = test_batch[3] batch_idx = test_batch[-1] - # test_loss = [] + + if self.use_weights: + weights = test_batch[5] + else: + weights = None if self.use_covs: - categories = test_batch[5] + categories = test_batch[6] n_classes = self.n_covariates - categories.shape[1] else: categories = torch.Tensor([]) @@ -480,7 +489,6 @@ def validation_step( log_std_x_morph_hat, mu_x_spcont_hat, log_std_x_spcont_hat, - # weights, ) = self.decoder(z_samples, cov_list) ( @@ -505,7 +513,7 @@ def validation_step( S, M, spatial_context, - # weights, + weights, ) beta = self.compute_kl_weight( @@ -541,7 +549,7 @@ def validation_step( self.log( "test_loss", # sum(test_loss) / L_iter, - test_loss.mean(1).sum(), + test_loss.mean(1).mean(), # log the mean across cells on_step=True, on_epoch=True, prog_bar=True, @@ -590,7 +598,7 @@ def validation_step( ) return { - "loss": test_loss.mean(1).sum(), + "loss": test_loss.mean(1).mean(), # get the mean across cells "kl_div": kl_div.mean().item(), "recon_loss": recon_loss.mean().item(), "recon_lik_me": recon_lik_me.mean().item(), From d8e2bcd55264d18a1532bbdf27fa0d931e6c1868 Mon Sep 17 00:00:00 2001 From: shanzaayub Date: Wed, 22 Feb 2023 14:28:23 -0500 Subject: [PATCH 10/18] latest module script --- hmivae/_hmivae_module.py | 157 ++++++++++++++++++++++++++------------- 1 file changed, 106 insertions(+), 51 deletions(-) diff --git a/hmivae/_hmivae_module.py b/hmivae/_hmivae_module.py index 2a164fb..28a09db 100644 --- a/hmivae/_hmivae_module.py +++ b/hmivae/_hmivae_module.py @@ -8,6 +8,9 @@ # import hmivae from hmivae._hmivae_base_components import DecoderHMIVAE, EncoderHMIVAE +# from pytorch_lightning.callbacks import Callback + + # from anndata import AnnData torch.backends.cudnn.benchmark = True @@ -28,8 +31,8 @@ def __init__( E_cr: int = 32, E_mr: int = 32, E_sc: int = 32, - latent_dim: int = 10, E_cov: int = 10, + latent_dim: int = 10, n_covariates: int = 0, 
leave_out_view: Optional[ Union[None, Literal["expression", "correlation", "morphology", "spatial"]] @@ -37,17 +40,19 @@ def __init__( use_covs: bool = False, use_weights: bool = True, n_hidden: int = 1, + beta_scheme: Optional[Literal["constant", "warmup"]] = "warmup", batch_correct: bool = True, n_steps_kl_warmup: Union[int, None] = None, - n_epochs_kl_warmup: Union[int, None] = 100, + n_epochs_kl_warmup: Union[int, None] = 10, ): super().__init__() # hidden_dim = E_me + E_cr + E_mr + E_sc - self.n_steps_kl_warmup = n_steps_kl_warmup self.n_epochs_kl_warmup = n_epochs_kl_warmup self.n_covariates = n_covariates + # self.cat_list = cat_list + self.batch_correct = batch_correct self.use_covs = use_covs @@ -56,6 +61,8 @@ def __init__( self.leave_out_view = leave_out_view + self.beta_scheme = beta_scheme + self.encoder = EncoderHMIVAE( input_exp_dim, input_corr_dim, @@ -66,9 +73,10 @@ def __init__( E_mr, E_sc, latent_dim, - E_cov, + E_cov=E_cov, n_covariates=n_covariates, leave_out_view=leave_out_view, + n_hidden=n_hidden, ) self.decoder = DecoderHMIVAE( @@ -81,12 +89,13 @@ def __init__( input_corr_dim, input_morph_dim, input_spcont_dim, - E_cov, + E_cov=E_cov, n_covariates=n_covariates, leave_out_view=leave_out_view, + n_hidden=n_hidden, ) - self.save_hyperparameters(ignore=["adata", "cat_list"]) + self.save_hyperparameters(ignore=["adata"]) def reparameterization(self, mu, log_std): std = torch.exp(log_std) @@ -159,7 +168,7 @@ def em_recon_loss( s, m, c, - weights: Optional[Union[None, torch.tensor]] = None, + weights: Optional[Union[None, torch.Tensor]] = None, ): """Takes in the parameters output from the decoder, and the original input x, and gives the reconstruction @@ -179,32 +188,36 @@ def em_recon_loss( weights: torch.Tensor, weights calculated from decoded means for protein expression feature """ + ## Mean expression dec_x_std_exp = torch.exp(dec_x_logstd_exp) - dec_x_std_corr = torch.exp(dec_x_logstd_corr) - dec_x_std_morph = torch.exp(dec_x_logstd_morph) - 
dec_x_std_spcont = torch.exp(dec_x_logstd_spcont) p_rec_exp = torch.distributions.Normal(dec_x_mu_exp, dec_x_std_exp + 1e-6) - p_rec_corr = torch.distributions.Normal(dec_x_mu_corr, dec_x_std_corr + 1e-6) - p_rec_morph = torch.distributions.Normal(dec_x_mu_morph, dec_x_std_morph + 1e-6) - p_rec_spcont = torch.distributions.Normal( - dec_x_mu_spcont, dec_x_std_spcont + 1e-6 - ) - log_p_xz_exp = p_rec_exp.log_prob(y) - # log_p_xz_corr = p_rec_corr.log_prob(s) - log_p_xz_morph = p_rec_morph.log_prob(m) - log_p_xz_spcont = p_rec_spcont.log_prob(c) # already dense matrix + log_p_xz_exp = log_p_xz_exp.sum(-1) + ## Correlations + dec_x_std_corr = torch.exp(dec_x_logstd_corr) + p_rec_corr = torch.distributions.Normal(dec_x_mu_corr, dec_x_std_corr + 1e-6) + # log_p_xz_corr = p_rec_corr.log_prob(s) if weights is None: log_p_xz_corr = p_rec_corr.log_prob(s) else: log_p_xz_corr = torch.mul( weights, p_rec_corr.log_prob(s) ) # does element-wise multiplication - - log_p_xz_exp = log_p_xz_exp.sum(-1) log_p_xz_corr = log_p_xz_corr.sum(-1) + + ## Morphology + dec_x_std_morph = torch.exp(dec_x_logstd_morph) + p_rec_morph = torch.distributions.Normal(dec_x_mu_morph, dec_x_std_morph + 1e-6) + log_p_xz_morph = p_rec_morph.log_prob(m) log_p_xz_morph = log_p_xz_morph.sum(-1) + + ## Spatial context + dec_x_std_spcont = torch.exp(dec_x_logstd_spcont) + p_rec_spcont = torch.distributions.Normal( + dec_x_mu_spcont, dec_x_std_spcont + 1e-6 + ) + log_p_xz_spcont = p_rec_spcont.log_prob(c) # already dense matrix log_p_xz_spcont = log_p_xz_spcont.sum(-1) return ( @@ -231,7 +244,7 @@ def neg_ELBO( s, m, c, - weights: Optional[Union[None, torch.tensor]] = None, + weights: Optional[Union[None, torch.Tensor]] = None, ): kl_div = self.KL_div(enc_x_mu, enc_x_logstd, z) @@ -270,6 +283,7 @@ def loss(self, kl_div, recon_loss, beta: float = 1.0): def training_step( self, train_batch, + recon_weights=np.array([1.0, 1.0, 1.0, 1.0]), ): """ Carries out the training step. 
@@ -279,11 +293,12 @@ def training_step( recon_weights: numpy.array. Array with weights for each view during loss calculation. beta: float. Coefficient for KL-Divergence term in ELBO. """ + Y = train_batch[0] S = train_batch[1] M = train_batch[2] spatial_context = train_batch[3] - + # batch_idx = train_batch[-1] if self.use_weights: weights = train_batch[5] else: @@ -298,6 +313,8 @@ def training_step( one_hot = train_batch[4] cov_list = torch.cat([one_hot, categories], 1).float().type_as(Y) + elif self.use_covs: + cov_list = categories else: cov_list = torch.Tensor([]).type_as(Y) @@ -342,12 +359,18 @@ def training_step( weights, ) - beta = self.compute_kl_weight( - self.current_epoch, - self.global_step, - self.n_epochs_kl_warmup, - self.n_steps_kl_warmup, - ) + if self.beta_scheme == "warmup": + + beta = self.compute_kl_weight( + self.current_epoch, + self.global_step, + self.n_epochs_kl_warmup, + self.n_steps_kl_warmup, + ) + else: + beta = 1.0 + + # print('beta=', beta) if self.leave_out_view is not None: if self.leave_out_view == "expression": @@ -371,12 +394,12 @@ def training_step( loss = self.loss(kl_div, recon_loss, beta=beta) self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True) - self.log("beta", beta, on_step=False, on_epoch=True, prog_bar=False) + self.log("beta", beta, on_step=True, on_epoch=True, prog_bar=False) self.log( "kl_div", kl_div.mean().item(), on_step=True, on_epoch=True, prog_bar=False ) self.log( - "recon_loss", + "recon_lik", recon_loss.mean().item(), on_step=True, on_epoch=True, @@ -414,7 +437,7 @@ def training_step( return { "loss": loss, "kl_div": kl_div.mean().item(), - "recon_loss": recon_loss.mean().item(), + "recon_lik": recon_loss.mean().item(), "recon_lik_me": recon_lik_me.mean().item(), "recon_lik_corr": recon_lik_corr.mean().item(), "recon_lik_mor": recon_lik_mor.mean().item(), @@ -450,10 +473,11 @@ def validation_step( categories = test_batch[6] n_classes = self.n_covariates - categories.shape[1] else: - 
categories = torch.Tensor([]) + categories = torch.Tensor([]).type_as(Y) n_classes = self.n_covariates test_loss = torch.empty(size=[len(batch_idx), n_classes]) + elbo_full = torch.empty(size=[len(batch_idx), n_classes]) # for i in range(L_iter): for i in range(n_classes): @@ -475,7 +499,9 @@ def validation_step( else: cov_list = torch.Tensor([]).type_as(Y) - mu_z, log_std_z = self.encoder(Y, S, M, spatial_context, cov_list) + mu_z, log_std_z = self.encoder( + Y, S, M, spatial_context, cov_list + ) # valid step z_samples = self.reparameterization(mu_z, log_std_z) @@ -516,12 +542,16 @@ def validation_step( weights, ) - beta = self.compute_kl_weight( - self.current_epoch, - self.global_step, - self.n_epochs_kl_warmup, - self.n_steps_kl_warmup, - ) + if self.beta_scheme == "warmup": + + beta = self.compute_kl_weight( + self.current_epoch, + self.global_step, + self.n_epochs_kl_warmup, + self.n_steps_kl_warmup, + ) + else: + beta = 1.0 if self.leave_out_view is not None: if self.leave_out_view == "expression": @@ -544,16 +574,27 @@ def validation_step( loss = self.loss(kl_div, recon_loss, beta=beta) + full_elbo = recon_loss.mean() - kl_div.mean() + test_loss[:, i] = loss + elbo_full[:, i] = full_elbo + self.log( "test_loss", # sum(test_loss) / L_iter, - test_loss.mean(1).mean(), # log the mean across cells + test_loss.mean(1).mean(), on_step=True, on_epoch=True, prog_bar=True, ) # log the average test loss over all the iterations + self.log( + "test_full_elbo", + elbo_full.mean(1).mean(), + on_step=True, + on_epoch=True, + prog_bar=False, + ) self.log( "kl_div_test", kl_div.mean().item(), @@ -561,8 +602,9 @@ def validation_step( on_epoch=True, prog_bar=False, ) + self.log("beta_test", beta, on_step=True, on_epoch=True, prog_bar=False) self.log( - "recon_loss_test", + "recon_lik_test", recon_loss.mean().item(), on_step=True, on_epoch=True, @@ -598,9 +640,9 @@ def validation_step( ) return { - "loss": test_loss.mean(1).mean(), # get the mean across cells + "loss": 
test_loss.mean(1).mean(), "kl_div": kl_div.mean().item(), - "recon_loss": recon_loss.mean().item(), + "recon_lik": recon_loss.mean().item(), "recon_lik_me": recon_lik_me.mean().item(), "recon_lik_corr": recon_lik_corr.mean().item(), "recon_lik_mor": recon_lik_mor.mean().item(), @@ -638,6 +680,9 @@ def get_input_embeddings( def inference( self, data, + n_covariates: int, + use_covs: bool = True, + batch_correct: bool = True, indices: Optional[Sequence[int]] = None, give_mean: bool = True, # idx = None, @@ -645,18 +690,27 @@ def inference( """ Return the latent representation of each cell. """ + # if self.leave_out_view is None: Y = data.Y S = data.S M = data.M C = data.C - if self.use_covs: + # batch_idx = idx + # print(batch_idx) + # if self.use_covs: + # categories = data.BKG + # n_cats = categories.shape[1] + # else: + # categories = torch.Tensor([]) + # n_cats = 0 + if use_covs: categories = data.BKG - n_classes = self.n_covariates - categories.shape[1] + n_classes = n_covariates - categories.shape[1] else: categories = torch.Tensor([]).type_as(Y) - n_classes = self.n_covariates + n_classes = n_covariates - if self.batch_correct: + if batch_correct: one_hot = self.random_one_hot( n_classes=n_classes, n_samples=Y.shape[0] ).type_as(Y) @@ -684,9 +738,10 @@ def inference( return mu_z.numpy() else: mu_z, log_std_z = self.encoder(Y, S, M, C, cov_list) - z = self.reparameterization(mu_z, log_std_z) - return z.numpy() + z = self.reparameterization(mu_z, log_std_z) + + return z.numpy() @torch.no_grad() def random_one_hot(self, n_classes: int, n_samples: int): @@ -694,5 +749,5 @@ def random_one_hot(self, n_classes: int, n_samples: int): Generates a random one hot encoded matrix. 
From: https://stackoverflow.com/questions/45093615/random-one-hot-matrix-in-numpy """ - + # x = np.eye(n_classes) return torch.Tensor(np.eye(n_classes)[np.random.choice(n_classes, n_samples)]) From f5a67a3fcc60a5e24573ed222e7121398db051ff Mon Sep 17 00:00:00 2001 From: shanzaayub Date: Wed, 22 Feb 2023 14:30:22 -0500 Subject: [PATCH 11/18] latest base_components script --- hmivae/_hmivae_base_components.py | 324 ++++++++++++++---------------- 1 file changed, 156 insertions(+), 168 deletions(-) diff --git a/hmivae/_hmivae_base_components.py b/hmivae/_hmivae_base_components.py index ef19816..8a93ea6 100644 --- a/hmivae/_hmivae_base_components.py +++ b/hmivae/_hmivae_base_components.py @@ -40,6 +40,8 @@ def __init__( super().__init__() hidden_dim = E_me + E_cr + E_mr + E_sc + E_cov + self.leave_out_view = leave_out_view + self.input_cov = nn.Linear(n_covariates, E_cov) self.input_exp = nn.Linear(input_exp_dim, E_me) @@ -68,75 +70,65 @@ def forward( cov_list=torch.Tensor([]), ): - included_views = [] - - if self.leave_out_view is None: - - h_mean = F.elu(self.input_exp(x_mean)) - h_mean2 = F.elu(self.exp_hidden(h_mean)) - - # print("h_mean2", h_mean2) + # included_views = [] - h_correlations = F.elu(self.input_corr(x_correlations)) - h_correlations2 = F.elu(self.corr_hidden(h_correlations)) + # if self.leave_out_view is None: - # print("h_correlations2", h_correlations2) + h_mean = F.elu(self.input_exp(x_mean)) + h_mean2 = F.elu(self.exp_hidden(h_mean)) - h_morphology = F.elu(self.input_morph(x_morphology)) - h_morphology2 = F.elu(self.morph_hidden(h_morphology)) + h_correlations = F.elu(self.input_corr(x_correlations)) + h_correlations2 = F.elu(self.corr_hidden(h_correlations)) - # print("h_morphology2", h_morphology2) + h_morphology = F.elu(self.input_morph(x_morphology)) + h_morphology2 = F.elu(self.morph_hidden(h_morphology)) - h_spatial_context = F.elu(self.input_spatial_context(x_spatial_context)) - h_spatial_context2 = 
F.elu(self.spatial_context_hidden(h_spatial_context)) + h_spatial_context = F.elu(self.input_spatial_context(x_spatial_context)) + h_spatial_context2 = F.elu(self.spatial_context_hidden(h_spatial_context)) - # print("h_spatial_context2", h_spatial_context2) - - if cov_list.shape[0] > 1: - h_cov = F.elu(self.input_cov(cov_list)) - - # print('h_cov', h_cov) - else: - h_cov = cov_list - - h = torch.cat( - [h_mean2, h_correlations2, h_morphology2, h_spatial_context2, h_cov], 1 - ).type_as(x_mean) + if cov_list.shape[0] > 1: + h_cov = F.elu(self.input_cov(cov_list)) else: - if self.leave_out_view != "expression": - h_mean = F.elu(self.input_exp(x_mean)) - h_mean2 = F.elu(self.exp_hidden(h_mean)) + h_cov = cov_list + + h = torch.cat( + [h_mean2, h_correlations2, h_morphology2, h_spatial_context2, h_cov], 1 + ).type_as(x_mean) + # else: + # if self.leave_out_view != "expression": + # h_mean = F.elu(self.input_exp(x_mean)) + # h_mean2 = F.elu(self.exp_hidden(h_mean)) - included_views.append(h_mean2) + # included_views.append(h_mean2) - if self.leave_out_view != "correlation": - h_correlations = F.elu(self.input_corr(x_correlations)) - h_correlations2 = F.elu(self.corr_hidden(h_correlations)) + # if self.leave_out_view != "correlation": + # h_correlations = F.elu(self.input_corr(x_correlations)) + # h_correlations2 = F.elu(self.corr_hidden(h_correlations)) - included_views.append(h_correlations2) + # included_views.append(h_correlations2) - if self.leave_out_view != "morphology": - h_morphology = F.elu(self.input_morph(x_morphology)) - h_morphology2 = F.elu(self.morph_hidden(h_morphology)) + # if self.leave_out_view != "morphology": + # h_morphology = F.elu(self.input_morph(x_morphology)) + # h_morphology2 = F.elu(self.morph_hidden(h_morphology)) - included_views.append(h_morphology2) + # included_views.append(h_morphology2) - if self.leave_out_view != "spatial": - h_spatial_context = F.elu(self.input_spatial_context(x_spatial_context)) - h_spatial_context2 = F.elu( - 
self.spatial_context_hidden(h_spatial_context) - ) + # if self.leave_out_view != "spatial": + # h_spatial_context = F.elu(self.input_spatial_context(x_spatial_context)) + # h_spatial_context2 = F.elu( + # self.spatial_context_hidden(h_spatial_context) + # ) - included_views.append(h_spatial_context2) + # included_views.append(h_spatial_context2) - if cov_list.shape[0] > 1: - h_cov = F.elu(self.input_cov(cov_list)) - included_views.append(h_cov) - else: - h_cov = cov_list - included_views.append(h_cov) + # if cov_list.shape[0] > 1: + # h_cov = F.elu(self.input_cov(cov_list)) + # included_views.append(h_cov) + # else: + # h_cov = cov_list + # included_views.append(h_cov) - h = torch.cat(included_views, 1) # .type_as(x_mean) + # h = torch.cat(included_views, 1) # .type_as(x_mean) for net in self.linear: h = F.elu(net(h)) @@ -184,6 +176,7 @@ def __init__( super().__init__() hidden_dim = E_me + E_cr + E_mr + E_sc + E_cov latent_dim = latent_dim + n_covariates + self.leave_out_view = leave_out_view self.E_me = E_me self.E_cr = E_cr self.E_mr = E_mr @@ -220,123 +213,118 @@ def forward(self, z, cov_list): for net in self.linear: out = F.elu(net(out)) - if self.leave_out_view is None: + # if self.leave_out_view is None: - h2_mean = F.elu(self.exp_hidden(out[:, 0 : self.E_me])) - h2_correlations = F.elu( - self.corr_hidden(out[:, self.E_me : self.E_me + self.E_cr]) - ) - h2_morphology = F.elu( - self.morph_hidden( - out[:, self.E_me + self.E_cr : self.E_me + self.E_cr + self.E_mr] - ) + h2_mean = F.elu(self.exp_hidden(out[:, 0 : self.E_me])) + h2_correlations = F.elu( + self.corr_hidden(out[:, self.E_me : self.E_me + self.E_cr]) + ) + h2_morphology = F.elu( + self.morph_hidden( + out[:, self.E_me + self.E_cr : self.E_me + self.E_cr + self.E_mr] ) - h2_spatial_context = F.elu( - self.spatial_context_hidden( - out[ - :, - self.E_me - + self.E_cr - + self.E_mr : self.E_me - + self.E_cr - + self.E_mr - + self.E_sc, - ] - ) + ) + h2_spatial_context = F.elu( + 
self.spatial_context_hidden( + out[ + :, + self.E_me + + self.E_cr + + self.E_mr : self.E_me + + self.E_cr + + self.E_mr + + self.E_sc, + ] ) + ) - # covariates = out[:, self.E_me + self.E_cr + self.E_mr + self.E_sc :] - - mu_x_exp = self.mu_x_exp(h2_mean) - std_x_exp = self.std_x_exp(h2_mean) - - # if self.use_weights: - # with torch.no_grad(): - # weights = self.get_corr_weights_per_cell( - # mu_x_exp.detach() - # ) # calculating correlation weights - # else: - # weights = None - - mu_x_corr = self.mu_x_corr(h2_correlations) - std_x_corr = self.std_x_corr(h2_correlations) - - mu_x_morph = self.mu_x_morph(h2_morphology) - std_x_morph = self.std_x_morph(h2_morphology) - - mu_x_spatial_context = self.mu_x_spcont(h2_spatial_context) - std_x_spatial_context = self.std_x_spcont(h2_spatial_context) - - # covariates_mu = self.covariates_out_mu(covariates) - # covariates_std = self.covariates_out_std(covariates) - - return ( - mu_x_exp, - std_x_exp, - mu_x_corr, - std_x_corr, - mu_x_morph, - std_x_morph, - mu_x_spatial_context, - std_x_spatial_context, - # covariates_mu, - # covariates_std, - # weights, - ) + # covariates = out[:, self.E_me + self.E_cr + self.E_mr + self.E_sc :] + + mu_x_exp = self.mu_x_exp(h2_mean) + std_x_exp = self.std_x_exp(h2_mean) + + # if self.use_weights: + # with torch.no_grad(): + # weights = self.get_corr_weights_per_cell( + # mu_x_exp.detach() + # ) # calculating correlation weights + # else: + # weights = None + + mu_x_corr = self.mu_x_corr(h2_correlations) + std_x_corr = self.std_x_corr(h2_correlations) + + mu_x_morph = self.mu_x_morph(h2_morphology) + std_x_morph = self.std_x_morph(h2_morphology) + + mu_x_spatial_context = self.mu_x_spcont(h2_spatial_context) + std_x_spatial_context = self.std_x_spcont(h2_spatial_context) + + return ( + mu_x_exp, + std_x_exp, + mu_x_corr, + std_x_corr, + mu_x_morph, + std_x_morph, + mu_x_spatial_context, + std_x_spatial_context, + # weights, + ) - else: - included_views = [] - - if self.leave_out_view != 
"expression": - h2_mean = F.elu(self.exp_hidden(out[:, 0 : self.E_me])) - mu_x_exp = self.mu_x_exp(h2_mean) - std_x_exp = self.std_x_exp(h2_mean) - - included_views.append(mu_x_exp) - included_views.append(std_x_exp) - - if self.leave_out_view != "correlation": - h2_correlations = F.elu( - self.corr_hidden(out[:, self.E_me : self.E_me + self.E_cr]) - ) - mu_x_corr = self.mu_x_corr(h2_correlations) - std_x_corr = self.std_x_corr(h2_correlations) - - included_views.append(mu_x_corr) - included_views.append(std_x_corr) - - if self.leave_out_view != "morphology": - h2_morphology = F.elu( - self.morph_hidden( - out[ - :, self.E_me + self.E_cr : self.E_me + self.E_cr + self.E_mr - ] - ) - ) - mu_x_morph = self.mu_x_morph(h2_morphology) - std_x_morph = self.std_x_morph(h2_morphology) - - included_views.append(mu_x_morph) - included_views.append(std_x_morph) - - if self.leave_out_view != "spatial": - h2_spatial_context = F.elu( - self.spatial_context_hidden( - out[ - :, - self.E_me - + self.E_cr - + self.E_mr : self.E_me - + self.E_cr - + self.E_mr - + self.E_sc, - ] - ) - ) - mu_x_spatial_context = self.mu_x_spcont(h2_spatial_context) - std_x_spatial_context = self.std_x_spcont(h2_spatial_context) - - included_views.append(mu_x_spatial_context) - included_views.append(std_x_spatial_context) - - return included_views + # else: + # included_views = [] + + # if self.leave_out_view != "expression": + # h2_mean = F.elu(self.exp_hidden(out[:, 0 : self.E_me])) + # mu_x_exp = self.mu_x_exp(h2_mean) + # std_x_exp = self.std_x_exp(h2_mean) + + # included_views.append(mu_x_exp) + # included_views.append(std_x_exp) + + # if self.leave_out_view != "correlation": + # h2_correlations = F.elu( + # self.corr_hidden(out[:, self.E_me : self.E_me + self.E_cr]) + # ) + # mu_x_corr = self.mu_x_corr(h2_correlations) + # std_x_corr = self.std_x_corr(h2_correlations) + + # included_views.append(mu_x_corr) + # included_views.append(std_x_corr) + + # if self.leave_out_view != "morphology": + # 
h2_morphology = F.elu( + # self.morph_hidden( + # out[ + # :, self.E_me + self.E_cr : self.E_me + self.E_cr + self.E_mr + # ] + # ) + # ) + # mu_x_morph = self.mu_x_morph(h2_morphology) + # std_x_morph = self.std_x_morph(h2_morphology) + + # included_views.append(mu_x_morph) + # included_views.append(std_x_morph) + + # if self.leave_out_view != "spatial": + # h2_spatial_context = F.elu( + # self.spatial_context_hidden( + # out[ + # :, + # self.E_me + # + self.E_cr + # + self.E_mr : self.E_me + # + self.E_cr + # + self.E_mr + # + self.E_sc, + # ] + # ) + # ) + # mu_x_spatial_context = self.mu_x_spcont(h2_spatial_context) + # std_x_spatial_context = self.std_x_spcont(h2_spatial_context) + + # included_views.append(mu_x_spatial_context) + # included_views.append(std_x_spatial_context) + + # return included_views From 4c28692989d6823e0ec10acef57aed5cdce06aa0 Mon Sep 17 00:00:00 2001 From: shanzaayub Date: Mon, 27 Feb 2023 13:21:32 -0500 Subject: [PATCH 12/18] script for running hmivae added --- hmivae/run_hmivae.py | 563 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 563 insertions(+) create mode 100644 hmivae/run_hmivae.py diff --git a/hmivae/run_hmivae.py b/hmivae/run_hmivae.py new file mode 100644 index 0000000..ba2074c --- /dev/null +++ b/hmivae/run_hmivae.py @@ -0,0 +1,563 @@ +## run with hmivae + +import argparse +import os +import time +from collections import OrderedDict + +import numpy as np +import pandas as pd + +# import phenograph +import scanpy as sc +import squidpy as sq +import torch +import wandb +from rich.progress import ( # track, + BarColumn, + Progress, + TaskProgressColumn, + TextColumn, + TimeElapsedColumn, +) +from scipy.stats.mstats import winsorize +from sklearn.preprocessing import StandardScaler +from statsmodels.api import OLS, add_constant + +# import hmivae +from hmivae._hmivae_model import hmivaeModel + + +def create_cluster_dummy(adata, cluster_col, cluster): + # n_clusters = len(adata.obs[cluster_col].unique().tolist()) + 
x = np.zeros([adata.X.shape[0], 1]) + + for cell in adata.obs.index: + # cell_cluster = int(adata.obs[cluster_col][cell]) + # print(type(cell), type(cluster)) + + if adata.obs[cluster_col][int(cell)] == cluster: + x[int(cell)] = 1 + + return x + + +def get_feature_matrix(adata, scale_values=False, cofactor=1, weights=True): + + correlations = adata.obsm["correlations"] + if weights: + correlations = np.multiply( + correlations, adata.obsm["weights"] + ) # multiply weights with correlations + + if scale_values: + morphology = adata.obsm["morphology"] + for i in range(adata.obsm["morphology"].shape[1]): + morphology[:, i] = winsorize( + adata.obsm["morphology"][:, i], limits=[0, 0.01] + ) + + expression = np.arcsinh(adata.X / cofactor) + for j in range(adata.X.shape[1]): + expression[:, j] = winsorize(expression[:, j], limits=[0, 0.01]) + else: + morphology = adata.obsm["morphology"] + expression = adata.X + + y = StandardScaler().fit_transform( + np.concatenate([expression, correlations, morphology], axis=1) + ) + + var_names = np.concatenate( + [ + adata.var_names, + adata.uns["names_correlations"], + adata.uns["names_morphology"], + ] + ) + + return y, var_names + + +def rank_features_in_groups(adata, group_col, scale_values=False, cofactor=1): + + progress = Progress( + TextColumn(f"[progress.description]Ranking features in {group_col} groups"), + BarColumn(), + TaskProgressColumn(), + TimeElapsedColumn(), + ) + ranked_features_in_groups = {} + dfs = [] + # create the feature matrix for entire adata + y, var_names = get_feature_matrix( + adata, scale_values=scale_values, cofactor=cofactor + ) + y = add_constant(y) # add intercept + + with progress: + + for group in progress.track(adata.obs[group_col].unique()): + ranked_features_in_groups[str(group)] = {} + x = create_cluster_dummy(adata, group_col, group) + mod = OLS(x, y) + res = mod.fit() + + df_values = pd.DataFrame( + res.tvalues[1:], # remove the intercept value + index=var_names, + 
columns=[f"t_value_{group}"], + ).sort_values(by=f"t_value_{group}", ascending=False) + + ranked_features_in_groups[str(group)]["names"] = df_values.index.to_list() + ranked_features_in_groups[str(group)]["t_values"] = df_values[ + f"t_value_{group}" + ].to_list() + + # print('df index:', df_values.index.tolist()) + + dfs.append(df_values) + + fc_df = pd.concat( + dfs, axis=1 + ).sort_index() # index is sorted as alphabetical! (order with original var_names is NOT maintained!) + + fc_df.index = fc_df.index.map(str) + fc_df.columns = fc_df.columns.map(str) + + adata.uns[f"{group_col}_ranked_features_in_groups"] = ranked_features_in_groups + adata.uns[f"{group_col}_feature_scores"] = fc_df + + # return adata + + +def top_common_features(df, top_n_features=10): + + sets_list = [] + + for i in df.columns: + abs_sorted_col = df[i].map(abs).sort_values(ascending=False) + for j in abs_sorted_col.index.to_list()[0:top_n_features]: + sets_list.append(j) + + common_features = list(set(sets_list)) + + common_feat_df = df.loc[common_features] + + return common_feat_df + + +parser = argparse.ArgumentParser(description="Run hmiVAE") + +parser.add_argument( + "--adata", type=str, required=True, help="AnnData file with all the inputs" +) + +parser.add_argument( + "--include_all_views", + type=int, + help="Run model using all views", + default=1, + choices=[0, 1], +) + +parser.add_argument( + "--remove_view", + type=str, + help="Name of view to leave out. One of ['expression', 'correlation', 'morphology', 'spatial']. 
Must be given when `include_all_views` is False", + default=None, + choices=["expression", "correlation", "morphology", "spatial"], +) + +parser.add_argument( + "--use_covs", + type=bool, + help="True/False for using background covariates", + default=True, +) + +parser.add_argument( + "--use_weights", + type=bool, + help="True/False for using correlation weights", + default=True, +) + +parser.add_argument( + "--batch_correct", + type=bool, + help="True/False for using one-hot encoding for batch correction", + default=True, +) + +parser.add_argument( + "--batch_size", + type=int, + help="Batch size for train/test data, default=1234", + default=1234, +) + +parser.add_argument( + "--hidden_dim_size", + type=int, + help="Size for view-specific hidden layers", + default=32, +) + +parser.add_argument( + "--latent_dim", + type=int, + help="Size for the final latent representation layer", + default=10, +) + +parser.add_argument( + "--n_hidden", + type=int, + help="Number of hidden layers", + default=1, +) + +parser.add_argument( + "--beta_scheme", + type=str, + help="Scheme to use for beta vae", + default="warmup", + choices=["constant", "warmup"], +) + +parser.add_argument( + "--cofactor", type=float, help="Cofactor for arcsinh transformation", default=1.0 +) + +parser.add_argument( + "--random_seed", + type=int, + help="Random seed for weights initialization", + default=1234, +) + +parser.add_argument("--cohort", type=str, help="Cohort name", default="cohort") + +parser.add_argument( + "--output_dir", type=str, help="Directory to store the outputs", default="." 
+) + +args = parser.parse_args() + +log_file = open( + os.path.join( + args.output_dir, + f"{args.cohort}_nhidden{args.n_hidden}_hiddendim{args.hidden_dim_size}_latentdim{args.latent_dim}_betascheme{args.beta_scheme}_randomseed{args.random_seed}_run_log.txt", + ), + "w+", +) + +raw_adata = sc.read_h5ad(args.adata) + +# print("connections", adata.obsp["connectivities"]) +# print("raw adata X min,max", raw_adata.X.max(), raw_adata.X.min()) +# print("raw adata corrs min,max", raw_adata.obsm['correlations'].max(), raw_adata.obsm['correlations'].min()) +# print("raw adata morph min,max", raw_adata.obsm['morphology'].max(), raw_adata.obsm['morphology'].min()) + +L = [ + f"raw adata X, max: {raw_adata.X.max()}, min: {raw_adata.X.min()} \n", + f"raw adata correlations, max: {raw_adata.obsm['correlations'].max()}, min: {raw_adata.obsm['correlations'].min()} \n", + f"raw adata morphology, max: {raw_adata.obsm['morphology'].max()}, min: {raw_adata.obsm['morphology'].min()} \n", +] + +log_file.writelines(L) +n_total_features = ( + raw_adata.X.shape[1] + + raw_adata.obsm["correlations"].shape[1] + + raw_adata.obsm["morphology"].shape[1] +) + +log_file.write(f"Total number of features:{n_total_features} \n") +log_file.write(f"Total number of cells:{raw_adata.X.shape[0]} \n") + +print("Set up the model") + +start = time.time() + + +E_me, E_cr, E_mr, E_sc = [ + args.hidden_dim_size, + args.hidden_dim_size, + args.hidden_dim_size, + args.hidden_dim_size, +] +input_exp_dim, input_corr_dim, input_morph_dim, input_spcont_dim = [ + raw_adata.shape[1], + raw_adata.obsm["correlations"].shape[1], + raw_adata.obsm["morphology"].shape[1], + n_total_features, +] +keys = [] +if args.use_covs: + cat_list = [] + + for key in raw_adata.obsm.keys(): + # print(key) + if key not in ["correlations", "morphology", "spatial", "xy"]: + keys.append(key) + for cat_key in keys: + # print(cat_key) + # print(f"{cat_key} shape:", adata.obsm[cat_key].shape) + category = raw_adata.obsm[cat_key] + 
cat_list.append(category) + cat_list = np.concatenate(cat_list, 1) + n_covariates = cat_list.shape[1] + E_cov = args.hidden_dim_size +else: + n_covariates = 0 + E_cov = 0 + +model = hmivaeModel( + adata=raw_adata, + input_exp_dim=input_exp_dim, + input_corr_dim=input_corr_dim, + input_morph_dim=input_morph_dim, + input_spcont_dim=input_spcont_dim, + E_me=E_me, + E_cr=E_cr, + E_mr=E_mr, + E_sc=E_sc, + E_cov=E_cov, + latent_dim=args.latent_dim, + cofactor=args.cofactor, + use_covs=args.use_covs, + cohort=args.cohort, + use_weights=args.use_weights, + beta_scheme=args.beta_scheme, + n_covariates=n_covariates, + batch_correct=args.batch_correct, + batch_size=args.batch_size, + random_seed=args.random_seed, + n_hidden=args.n_hidden, + leave_out_view=args.remove_view, + output_dir=args.output_dir, +) + + +print("Start training") + + +model.train() + +wandb.finish() + +model_checkpoint = [ + i for i in os.listdir(args.output_dir) if ".ckpt" in i +] # should only be 1 -- saved best model + +print("model_checkpoint", model_checkpoint) + +load_chkpt = torch.load(os.path.join(args.output_dir, model_checkpoint[0])) + +state_dict = load_chkpt["state_dict"] +# print(state_dict) +new_state_dict = OrderedDict() +for k, v in state_dict.items(): + # print("key", k) + if "weight" or "bias" in k: + # print("changing", k) + name = "module." 
+ k # add `module.` + # print("new name", name) + else: + # print("staying same", k) + name = k + new_state_dict[name] = v +# load params + +load_chkpt["state_dict"] = new_state_dict + +# torch.save(os.path.join(args.output_dir, model_checkpoint[0])) + +model = hmivaeModel( + adata=raw_adata, + input_exp_dim=input_exp_dim, + input_corr_dim=input_corr_dim, + input_morph_dim=input_morph_dim, + input_spcont_dim=input_spcont_dim, + E_me=E_me, + E_cr=E_cr, + E_mr=E_mr, + E_sc=E_sc, + E_cov=E_cov, + latent_dim=args.latent_dim, + cofactor=args.cofactor, + use_covs=args.use_covs, + use_weights=args.use_weights, + beta_scheme=args.beta_scheme, + n_covariates=n_covariates, + batch_correct=args.batch_correct, + batch_size=args.batch_size, + random_seed=args.random_seed, + n_hidden=args.n_hidden, + leave_out_view=args.remove_view, + output_dir=args.output_dir, +) +model.load_state_dict(new_state_dict, strict=False) + + +# model.load_from_checkpoint(os.path.join(args.output_dir, model_checkpoint[0]), adata=raw_adata) + +print("Best model loaded from checkpoint") + +stop = time.time() + +log_file.write(f"All training done in {(stop-start)/60} minutes \n") + +starta = time.time() + +adata = model.get_latent_representation( # use the best model to get the latent representations + adata=raw_adata, + protein_correlations_obsm_key="correlations", + cell_morphology_obsm_key="morphology", + continuous_covariate_keys=keys, + is_trained_model=True, + batch_correct=args.batch_correct, +) + +print("Doing cluster and neighbourhood enrichment analysis") + +sc.pp.neighbors( + adata, n_neighbors=100, use_rep="VAE", key_added="vae" +) # 100 nearest neighbours, will be used in downstream tests -- keep with PG + +sc.pp.neighbors(adata, n_neighbors=100, use_rep="VAE", key_added="vae_100") + +sc.tl.leiden(adata, neighbors_key="vae") + +# k = 100 # choose k (number of nearest neighbours) -- keep consistent with n_neighbours above +# sc.settings.verbose = 0 +# communities, graph, Q = 
phenograph.cluster(pd.DataFrame(adata.obsm['VAE']),k=k) # run PhenoGraph +# # store the results in adata: +# adata.obs['PhenoGraph_clusters'] = pd.Categorical(communities) +# adata.uns['PhenoGraph_Q'] = Q +# adata.uns['PhenoGraph_k'] = k + +# # sc.tl.tsne(adata_new, use_rep='VAE') +sc.tl.umap(adata, neighbors_key="vae") + +# random_inds = np.random.choice(range(adata.X.shape[0]), 5000) + +# sc.pl.umap(adata[random_inds], color=['leiden'], show=False, save=f"_{args.cohort}_{args.beta_scheme}{args.hidden_dim_size}{args.latent_dim}_batchsize{args.batch_size}") + +# df = pd.DataFrame(adata.obsm['VAE']) + +# df.to_csv( +# os.path.join( +# args.output_dir, +# f"{args.cohort}_nhid{args.n_hidden}_hdim{args.hidden_dim_size}_lspace{args.latent_dim}_batchsize{args.batch_size}_randomseed{args.random_seed}.tsv" +# ), +# sep='\t') + +print("Ranking features across cluster") +start1 = time.time() +if "cell_id" not in adata.obs.columns: + print("Reset index to get cell_id column") + adata.obs = adata.obs.reset_index() + + +# ranked_dict, fc_df = +rank_features_in_groups( + adata, + "leiden", + scale_values=False, + cofactor=args.cofactor, +) # no scaling required because using adata_train and test which have already been normalized and winsorized -- StandardScaler still applied +fc_df = adata.uns["leiden_feature_scores"] + +# #ranked_dict_pg, fc_df_pg = +# rank_features_in_groups( +# adata, "PhenoGraph_clusters", scale_values=True, cofactor=args.cofactor, +# ) +# fc_df_pg = adata.uns["PhenoGraph_clusters_feature_scores"] +stop1 = time.time() + +print(f"\t ===> Finished ranking features across clusters in {stop1-start1} seconds") + +print("Sorting most common features") + +# print('fc_df', fc_df) + +top5_leiden = top_common_features(fc_df) + +if args.include_all_views: + + top5_leiden.to_csv( + os.path.join( + args.output_dir, f"{args.cohort}_top5_features_across_clusters_leiden.tsv" + ), + sep="\t", + ) + +else: + top5_leiden.to_csv( + os.path.join( + args.output_dir, + 
f"{args.cohort}_top5_features_across_clusters_leiden_remove_{args.remove_view}.tsv", + ), + sep="\t", + ) + +# top5_pg = top_common_features(fc_df_pg) + +# if args.include_all_views: + +# top5_pg.to_csv( +# os.path.join(args.output_dir, f"{args.cohort}_top5_features_across_clusters_pg.tsv"), +# sep="\t", +# ) + +# else: +# top5_pg.to_csv( +# os.path.join(args.output_dir, f"{args.cohort}_top5_features_across_clusters_pg_remove_{args.remove_view}.tsv"), +# sep="\t", +# ) + +print("Neighbourhood enrichment analysis") + +# sq.gr.co_occurrence(adata, cluster_key="leiden") # if it works, it works + +sq.gr.spatial_neighbors(adata) +sq.gr.nhood_enrichment(adata, cluster_key="leiden") +sq.gr.nhood_enrichment(adata, cluster_key="PhenoGraph_clusters") + + +# sc.pl.umap(adata, color=['leiden', 'DNA1', 'panCK', 'CD45', 'Vimentin', 'CD3', 'CD20'], show=False, save=f"_{args.cohort}") + +stopa = time.time() + +log_file.write(f"All analysis done in {(stopa-starta)/60} minutes") + +log_file.close() + +# print("old", adata.uns.keys()) + +# new_uns = {str(k):v for k,v in adata.uns.items()} + +# print("new", new_uns.keys()) + +# # adata.uns = new_uns + +if args.include_all_views: + adata.obs.to_csv( + os.path.join(args.output_dir, f"{args.cohort}_clusters.tsv"), sep="\t" + ) + adata.write(os.path.join(args.output_dir, f"{args.cohort}_adata_new.h5ad")) + +else: + adata.obs.to_csv( + os.path.join( + args.output_dir, f"{args.cohort}_remove_{args.remove_view}_clusters.tsv" + ), + sep="\t", + ) + adata.write( + os.path.join( + args.output_dir, f"{args.cohort}_adata_remove_{args.remove_view}.h5ad" + ) + ) From f71ba8aef5944cdbc3ff1610ba0af88d5f9ce0b7 Mon Sep 17 00:00:00 2001 From: shanzaayub Date: Tue, 7 Mar 2023 16:51:52 -0500 Subject: [PATCH 13/18] added linear decoder --- hmivae/_hmivae_base_components.py | 68 ++++++++++++++++++++++++------- hmivae/run_hmivae.py | 2 +- 2 files changed, 54 insertions(+), 16 deletions(-) diff --git a/hmivae/_hmivae_base_components.py 
b/hmivae/_hmivae_base_components.py index 8a93ea6..757cfd4 100644 --- a/hmivae/_hmivae_base_components.py +++ b/hmivae/_hmivae_base_components.py @@ -16,7 +16,10 @@ class EncoderHMIVAE(nn.Module): E_mr: Dimension for the encoded morphology input E_sc: Dimension for the encoded spatial context input latent_dim: Dimension of the encoded output + E_cov: Dimension for the encoded covariates input + n_covariates: Number of covariates n_hidden: Number of hidden layers, default=1 + leave_out_view: For ablation testing. View to leave out, default=None """ def __init__( @@ -152,7 +155,11 @@ class DecoderHMIVAE(nn.Module): input_corr_dim: Dimension for the decoded correlations output input_morph_dim: Dimension for the decoded morphology input input_spcont_dim: Dimension for the decoded spatial context input + E_cov: Dimension for the encoded covariates input + n_covariates: Number of covariates + linear_decoder: True or False for using a linear decoder n_hidden: Number of hidden layers, default=1 + leave_out_view: For ablation testing. 
View to leave out during training, default=None """ def __init__( @@ -168,6 +175,7 @@ def __init__( input_spcont_dim: int, E_cov: Optional[int] = 10, n_covariates: Optional[int] = 0, + linear_decoder: Optional[bool] = False, n_hidden: Optional[int] = 1, leave_out_view: Optional[ Union[None, Literal["expression", "correlation", "morphology", "spatial"]] @@ -182,6 +190,7 @@ def __init__( self.E_mr = E_mr self.E_sc = E_sc self.input = nn.Linear(latent_dim, hidden_dim) + self.linear_decoder = linear_decoder self.linear = nn.ModuleList( [nn.Linear(hidden_dim, hidden_dim) for i in range(n_hidden)] ) @@ -209,23 +218,20 @@ def forward(self, z, cov_list): z_s = torch.cat( [z, cov_list], 1 ) # takes in one-hot as input, doesn't need to be symmetric with the encoder, doesn't output it - out = F.elu(self.input(z_s)) - for net in self.linear: - out = F.elu(net(out)) - # if self.leave_out_view is None: + if ( + self.linear_decoder + ): # linear decoder, no activation functions and single linear layer + out = self.input(z_s) - h2_mean = F.elu(self.exp_hidden(out[:, 0 : self.E_me])) - h2_correlations = F.elu( - self.corr_hidden(out[:, self.E_me : self.E_me + self.E_cr]) - ) - h2_morphology = F.elu( - self.morph_hidden( + h2_mean = self.exp_hidden(out[:, 0 : self.E_me]) + h2_correlations = self.corr_hidden( + out[:, self.E_me : self.E_me + self.E_cr] + ) + h2_morphology = self.morph_hidden( out[:, self.E_me + self.E_cr : self.E_me + self.E_cr + self.E_mr] ) - ) - h2_spatial_context = F.elu( - self.spatial_context_hidden( + h2_spatial_context = self.spatial_context_hidden( out[ :, self.E_me @@ -236,9 +242,41 @@ def forward(self, z, cov_list): + self.E_sc, ] ) - ) - # covariates = out[:, self.E_me + self.E_cr + self.E_mr + self.E_sc :] + # covariates = out[:, self.E_me + self.E_cr + self.E_mr + self.E_sc :] + + else: + # standard decoder with activation functions (non-linear) + out = F.elu(self.input(z_s)) + for net in self.linear: + out = F.elu(net(out)) + + # if self.leave_out_view 
is None: + + h2_mean = F.elu(self.exp_hidden(out[:, 0 : self.E_me])) + h2_correlations = F.elu( + self.corr_hidden(out[:, self.E_me : self.E_me + self.E_cr]) + ) + h2_morphology = F.elu( + self.morph_hidden( + out[:, self.E_me + self.E_cr : self.E_me + self.E_cr + self.E_mr] + ) + ) + h2_spatial_context = F.elu( + self.spatial_context_hidden( + out[ + :, + self.E_me + + self.E_cr + + self.E_mr : self.E_me + + self.E_cr + + self.E_mr + + self.E_sc, + ] + ) + ) + + # covariates = out[:, self.E_me + self.E_cr + self.E_mr + self.E_sc :] mu_x_exp = self.mu_x_exp(h2_mean) std_x_exp = self.std_x_exp(h2_mean) diff --git a/hmivae/run_hmivae.py b/hmivae/run_hmivae.py index ba2074c..9885a43 100644 --- a/hmivae/run_hmivae.py +++ b/hmivae/run_hmivae.py @@ -426,7 +426,7 @@ def top_common_features(df, top_n_features=10): adata, n_neighbors=100, use_rep="VAE", key_added="vae" ) # 100 nearest neighbours, will be used in downstream tests -- keep with PG -sc.pp.neighbors(adata, n_neighbors=100, use_rep="VAE", key_added="vae_100") +# sc.pp.neighbors(adata, n_neighbors=100, use_rep="VAE", key_added="vae_100") sc.tl.leiden(adata, neighbors_key="vae") From 6d8263d87fa4268367f2b38e6b2fab61d92e17df Mon Sep 17 00:00:00 2001 From: shanzaayub Date: Tue, 7 Mar 2023 21:32:41 -0500 Subject: [PATCH 14/18] view-specific embedding clustering --- hmivae/_hmivae_model.py | 53 ++++++++++++---- hmivae/run_hmivae.py | 133 +++++++++++++++++++++++++--------------- 2 files changed, 124 insertions(+), 62 deletions(-) diff --git a/hmivae/_hmivae_model.py b/hmivae/_hmivae_model.py index 89dc01c..0f5f6c7 100644 --- a/hmivae/_hmivae_model.py +++ b/hmivae/_hmivae_model.py @@ -279,6 +279,7 @@ def get_latent_representation( is_trained_model: Optional[bool] = False, batch_correct: Optional[bool] = True, use_covs: Optional[bool] = True, + save_view_specific_embeddings: Optional[bool] = True, ) -> AnnData: """ Gives the latent representation of each cell. 
@@ -303,18 +304,46 @@ def get_latent_representation( image_correct=batch_correct, ) - adata_train.obsm["VAE"] = self.module.inference( - data_train, - n_covariates=n_covariates, - use_covs=use_covs, - batch_correct=batch_correct, - ) # idx=train_idx) - adata_test.obsm["VAE"] = self.module.inference( - data_test, - n_covariates=n_covariates, - use_covs=use_covs, - batch_correct=batch_correct, - ) # idx=test_idx) + if save_view_specific_embeddings: + ( + adata_train.obsm["VAE"], + adata_train.obsm["expression_embedding"], + adata_train.obsm["correlation_embedding"], + adata_train.obsm["morphology_embedding"], + adata_train.obsm["spatial_context_embedding"], + ) = self.module.inference( + data_train, + n_covariates=n_covariates, + use_covs=use_covs, + batch_correct=batch_correct, + ) # idx=train_idx) + + ( + adata_test.obsm["VAE"], + adata_test.obsm["expression_embedding"], + adata_test.obsm["correlation_embedding"], + adata_test.obsm["morphology_embedding"], + adata_test.obsm["spatial_context_embedding"], + ) = self.module.inference( + data_test, + n_covariates=n_covariates, + use_covs=use_covs, + batch_correct=batch_correct, + ) # idx=test_idx) + + else: + adata_train.obsm["VAE"] = self.module.inference( + data_train, + n_covariates=n_covariates, + use_covs=use_covs, + batch_correct=batch_correct, + ) # idx=train_idx) + adata_test.obsm["VAE"] = self.module.inference( + data_test, + n_covariates=n_covariates, + use_covs=use_covs, + batch_correct=batch_correct, + ) # idx=test_idx) return ad.concat([adata_train, adata_test], uns_merge="first") else: diff --git a/hmivae/run_hmivae.py b/hmivae/run_hmivae.py index 9885a43..3aa0155 100644 --- a/hmivae/run_hmivae.py +++ b/hmivae/run_hmivae.py @@ -422,6 +422,8 @@ def top_common_features(df, top_n_features=10): print("Doing cluster and neighbourhood enrichment analysis") +print("Clustering using integrated space") + sc.pp.neighbors( adata, n_neighbors=100, use_rep="VAE", key_added="vae" ) # 100 nearest neighbours, will be 
used in downstream tests -- keep with PG @@ -430,78 +432,109 @@ def top_common_features(df, top_n_features=10): sc.tl.leiden(adata, neighbors_key="vae") -# k = 100 # choose k (number of nearest neighbours) -- keep consistent with n_neighbours above -# sc.settings.verbose = 0 -# communities, graph, Q = phenograph.cluster(pd.DataFrame(adata.obsm['VAE']),k=k) # run PhenoGraph -# # store the results in adata: -# adata.obs['PhenoGraph_clusters'] = pd.Categorical(communities) -# adata.uns['PhenoGraph_Q'] = Q -# adata.uns['PhenoGraph_k'] = k +print("Clustering using specific views") -# # sc.tl.tsne(adata_new, use_rep='VAE') -sc.tl.umap(adata, neighbors_key="vae") +sc.pp.neighbors( + adata, n_neighbors=100, use_rep="expression_embedding", key_added="expression" +) # 100 nearest neighbours, will be used in downstream tests -- keep with PG -# random_inds = np.random.choice(range(adata.X.shape[0]), 5000) +# sc.pp.neighbors(adata, n_neighbors=100, use_rep="VAE", key_added="vae_100") -# sc.pl.umap(adata[random_inds], color=['leiden'], show=False, save=f"_{args.cohort}_{args.beta_scheme}{args.hidden_dim_size}{args.latent_dim}_batchsize{args.batch_size}") +sc.tl.leiden( + adata, + neighbors_key="expression", + key_added="expression_leiden", + random_state=args.random_seed, +) -# df = pd.DataFrame(adata.obsm['VAE']) +sc.pp.neighbors( + adata, n_neighbors=100, use_rep="correlation_embedding", key_added="correlation" +) -# df.to_csv( -# os.path.join( -# args.output_dir, -# f"{args.cohort}_nhid{args.n_hidden}_hdim{args.hidden_dim_size}_lspace{args.latent_dim}_batchsize{args.batch_size}_randomseed{args.random_seed}.tsv" -# ), -# sep='\t') +sc.tl.leiden( + adata, + neighbors_key="correlation", + key_added="correlation_leiden", + random_state=args.random_seed, +) -print("Ranking features across cluster") -start1 = time.time() -if "cell_id" not in adata.obs.columns: - print("Reset index to get cell_id column") - adata.obs = adata.obs.reset_index() +sc.pp.neighbors( + adata, 
n_neighbors=100, use_rep="morphology_embedding", key_added="morphology" +) +sc.tl.leiden( + adata, + neighbors_key="morphology", + key_added="morphology_leiden", + random_state=args.random_seed, +) -# ranked_dict, fc_df = -rank_features_in_groups( +sc.pp.neighbors( adata, - "leiden", - scale_values=False, - cofactor=args.cofactor, -) # no scaling required because using adata_train and test which have already been normalized and winsorized -- StandardScaler still applied -fc_df = adata.uns["leiden_feature_scores"] + n_neighbors=100, + use_rep="spatial_context_embedding", + key_added="spatial_context", +) + +sc.tl.leiden( + adata, + neighbors_key="spatial_context", + key_added="spatial_context_leiden", + random_state=args.random_seed, +) + + +sc.tl.umap(adata, neighbors_key="vae", random_state=args.random_seed) + + +# print("Ranking features across cluster") +# start1 = time.time() +# if "cell_id" not in adata.obs.columns: +# print("Reset index to get cell_id column") +# adata.obs = adata.obs.reset_index() + + +# # ranked_dict, fc_df = +# rank_features_in_groups( +# adata, +# "leiden", +# scale_values=False, +# cofactor=args.cofactor, +# ) # no scaling required because using adata_train and test which have already been normalized and winsorized -- StandardScaler still applied +# fc_df = adata.uns["leiden_feature_scores"] # #ranked_dict_pg, fc_df_pg = # rank_features_in_groups( # adata, "PhenoGraph_clusters", scale_values=True, cofactor=args.cofactor, # ) # fc_df_pg = adata.uns["PhenoGraph_clusters_feature_scores"] -stop1 = time.time() +# stop1 = time.time() -print(f"\t ===> Finished ranking features across clusters in {stop1-start1} seconds") +# print(f"\t ===> Finished ranking features across clusters in {stop1-start1} seconds") -print("Sorting most common features") +# print("Sorting most common features") # print('fc_df', fc_df) -top5_leiden = top_common_features(fc_df) +# top5_leiden = top_common_features(fc_df) -if args.include_all_views: +# if 
args.include_all_views: - top5_leiden.to_csv( - os.path.join( - args.output_dir, f"{args.cohort}_top5_features_across_clusters_leiden.tsv" - ), - sep="\t", - ) +# top5_leiden.to_csv( +# os.path.join( +# args.output_dir, f"{args.cohort}_top5_features_across_clusters_leiden.tsv" +# ), +# sep="\t", +# ) -else: - top5_leiden.to_csv( - os.path.join( - args.output_dir, - f"{args.cohort}_top5_features_across_clusters_leiden_remove_{args.remove_view}.tsv", - ), - sep="\t", - ) +# else: +# top5_leiden.to_csv( +# os.path.join( +# args.output_dir, +# f"{args.cohort}_top5_features_across_clusters_leiden_remove_{args.remove_view}.tsv", +# ), +# sep="\t", +# ) # top5_pg = top_common_features(fc_df_pg) @@ -524,7 +557,7 @@ def top_common_features(df, top_n_features=10): sq.gr.spatial_neighbors(adata) sq.gr.nhood_enrichment(adata, cluster_key="leiden") -sq.gr.nhood_enrichment(adata, cluster_key="PhenoGraph_clusters") +# sq.gr.nhood_enrichment(adata, cluster_key="PhenoGraph_clusters") # sc.pl.umap(adata, color=['leiden', 'DNA1', 'panCK', 'CD45', 'Vimentin', 'CD3', 'CD20'], show=False, save=f"_{args.cohort}") From c97ac9ffa754558ef958cfba2cca4b308c66a33e Mon Sep 17 00:00:00 2001 From: shanzaayub Date: Tue, 7 Mar 2023 21:40:54 -0500 Subject: [PATCH 15/18] added option of linear decoder everywhere --- hmivae/_hmivae_model.py | 2 ++ hmivae/_hmivae_module.py | 2 ++ hmivae/run_hmivae.py | 9 +++++++++ 3 files changed, 13 insertions(+) diff --git a/hmivae/_hmivae_model.py b/hmivae/_hmivae_model.py index 0f5f6c7..273847c 100644 --- a/hmivae/_hmivae_model.py +++ b/hmivae/_hmivae_model.py @@ -67,6 +67,7 @@ def __init__( n_hidden: int = 1, cofactor: float = 1.0, beta_scheme: Optional[Literal["constant", "warmup"]] = "warmup", + linear_decoder: Optional[bool] = False, batch_correct: bool = True, is_trained_model: bool = False, batch_size: Optional[int] = 1234, @@ -169,6 +170,7 @@ def __init__( n_hidden=n_hidden, use_covs=self.use_covs, use_weights=self.use_weights, + 
linear_decoder=linear_decoder, beta_scheme=beta_scheme, batch_correct=batch_correct, leave_out_view=leave_out_view, diff --git a/hmivae/_hmivae_module.py b/hmivae/_hmivae_module.py index 28a09db..34784a5 100644 --- a/hmivae/_hmivae_module.py +++ b/hmivae/_hmivae_module.py @@ -39,6 +39,7 @@ def __init__( ] = None, use_covs: bool = False, use_weights: bool = True, + linear_decoder: Optional[bool] = False, n_hidden: int = 1, beta_scheme: Optional[Literal["constant", "warmup"]] = "warmup", batch_correct: bool = True, @@ -93,6 +94,7 @@ def __init__( n_covariates=n_covariates, leave_out_view=leave_out_view, n_hidden=n_hidden, + linear_decoder=linear_decoder, ) self.save_hyperparameters(ignore=["adata"]) diff --git a/hmivae/run_hmivae.py b/hmivae/run_hmivae.py index 3aa0155..5dd21a6 100644 --- a/hmivae/run_hmivae.py +++ b/hmivae/run_hmivae.py @@ -226,6 +226,13 @@ def top_common_features(df, top_n_features=10): choices=["constant", "warmup"], ) +parser.add_argument( + "--use_linear_decoder", + type=bool, + help="For using a linear decoder: True or False", + default=False, +) + parser.add_argument( "--cofactor", type=float, help="Cofactor for arcsinh transformation", default=1.0 ) @@ -330,6 +337,7 @@ def top_common_features(df, top_n_features=10): cohort=args.cohort, use_weights=args.use_weights, beta_scheme=args.beta_scheme, + linear_decoder=args.use_linear_decoder, n_covariates=n_covariates, batch_correct=args.batch_correct, batch_size=args.batch_size, @@ -390,6 +398,7 @@ def top_common_features(df, top_n_features=10): use_covs=args.use_covs, use_weights=args.use_weights, beta_scheme=args.beta_scheme, + linear_decoder=args.use_linear_decoder, n_covariates=n_covariates, batch_correct=args.batch_correct, batch_size=args.batch_size, From 0ff58b12b658613c1d43caefbb52a7ce715ee6f7 Mon Sep 17 00:00:00 2001 From: shanzaayub Date: Wed, 5 Apr 2023 14:26:52 -0400 Subject: [PATCH 16/18] run_hmivae.py updated with how clustering and feature ranking was done --- hmivae/run_hmivae.py | 
285 ++++++++++++++++++++++++++++++++----------- 1 file changed, 217 insertions(+), 68 deletions(-) diff --git a/hmivae/run_hmivae.py b/hmivae/run_hmivae.py index 5dd21a6..ce5fb68 100644 --- a/hmivae/run_hmivae.py +++ b/hmivae/run_hmivae.py @@ -1,18 +1,20 @@ ## run with hmivae - import argparse import os import time from collections import OrderedDict +import matplotlib.pyplot as plt import numpy as np import pandas as pd # import phenograph import scanpy as sc +import seaborn as sns import squidpy as sq import torch import wandb +from anndata import AnnData from rich.progress import ( # track, BarColumn, Progress, @@ -26,6 +28,31 @@ # import hmivae from hmivae._hmivae_model import hmivaeModel +from hmivae.ScModeDataloader import ScModeDataloader + + +def arrange_features(vars_lst, adata): + + arranged_features = {"E": [], "C": [], "M": [], "S": []} + orig_list = vars_lst + + for i in orig_list: + + if i in adata.var_names: + arranged_features["E"].append(i) + elif i in adata.uns["names_morphology"]: + arranged_features["M"].append(i) + elif i in adata.uns["names_correlations"]: + arranged_features["C"].append(i) + else: + arranged_features["S"].append(i) + arr_list = [ + *np.sort(arranged_features["E"]).tolist(), + *np.sort(arranged_features["C"]).tolist(), + *np.sort(arranged_features["M"]).tolist(), + *np.sort(arranged_features["S"]).tolist(), + ] + return arranged_features, arr_list def create_cluster_dummy(adata, cluster_col, cluster): @@ -431,30 +458,32 @@ def top_common_features(df, top_n_features=10): print("Doing cluster and neighbourhood enrichment analysis") -print("Clustering using integrated space") +print("===> Clustering using integrated space") sc.pp.neighbors( adata, n_neighbors=100, use_rep="VAE", key_added="vae" ) # 100 nearest neighbours, will be used in downstream tests -- keep with PG -# sc.pp.neighbors(adata, n_neighbors=100, use_rep="VAE", key_added="vae_100") sc.tl.leiden(adata, neighbors_key="vae") -print("Clustering using specific views") 
+print("===> Clustering using specific views") + +print("Expression") sc.pp.neighbors( adata, n_neighbors=100, use_rep="expression_embedding", key_added="expression" ) # 100 nearest neighbours, will be used in downstream tests -- keep with PG -# sc.pp.neighbors(adata, n_neighbors=100, use_rep="VAE", key_added="vae_100") - sc.tl.leiden( adata, neighbors_key="expression", key_added="expression_leiden", random_state=args.random_seed, -) + resolution=0.5, +) # expression wasn't too bad + +print("Correlation") sc.pp.neighbors( adata, n_neighbors=100, use_rep="correlation_embedding", key_added="correlation" @@ -465,7 +494,9 @@ def top_common_features(df, top_n_features=10): neighbors_key="correlation", key_added="correlation_leiden", random_state=args.random_seed, -) +) # probably no need to change correlation because there were few anyways + +print("Morphology") sc.pp.neighbors( adata, n_neighbors=100, use_rep="morphology_embedding", key_added="morphology" @@ -476,7 +507,10 @@ def top_common_features(df, top_n_features=10): neighbors_key="morphology", key_added="morphology_leiden", random_state=args.random_seed, -) + resolution=0.1, +) # pull it way down because there were LOTS of clusters + +print("Spatial context") sc.pp.neighbors( adata, @@ -490,106 +524,218 @@ def top_common_features(df, top_n_features=10): neighbors_key="spatial_context", key_added="spatial_context_leiden", random_state=args.random_seed, + resolution=0.5, ) +print("===> Creating UMAPs") + +print("Integrated space") sc.tl.umap(adata, neighbors_key="vae", random_state=args.random_seed) +adata.obsm["X_umap_int"] = adata.obsm["X_umap"].copy() -# print("Ranking features across cluster") -# start1 = time.time() -# if "cell_id" not in adata.obs.columns: -# print("Reset index to get cell_id column") -# adata.obs = adata.obs.reset_index() +print("Expression") +sc.tl.umap(adata, neighbors_key="expression", random_state=args.random_seed) -# # ranked_dict, fc_df = -# rank_features_in_groups( -# adata, -# 
"leiden", -# scale_values=False, -# cofactor=args.cofactor, -# ) # no scaling required because using adata_train and test which have already been normalized and winsorized -- StandardScaler still applied -# fc_df = adata.uns["leiden_feature_scores"] +adata.obsm["X_umap_exp"] = adata.obsm["X_umap"].copy() -# #ranked_dict_pg, fc_df_pg = -# rank_features_in_groups( -# adata, "PhenoGraph_clusters", scale_values=True, cofactor=args.cofactor, -# ) -# fc_df_pg = adata.uns["PhenoGraph_clusters_feature_scores"] -# stop1 = time.time() +print("Correlations") -# print(f"\t ===> Finished ranking features across clusters in {stop1-start1} seconds") +sc.tl.umap(adata, neighbors_key="correlation", random_state=args.random_seed) -# print("Sorting most common features") +adata.obsm["X_umap_corr"] = adata.obsm["X_umap"].copy() -# print('fc_df', fc_df) +print("Morphology") -# top5_leiden = top_common_features(fc_df) +sc.tl.umap(adata, neighbors_key="morphology", random_state=args.random_seed) -# if args.include_all_views: +adata.obsm["X_umap_morph"] = adata.obsm["X_umap"].copy() -# top5_leiden.to_csv( -# os.path.join( -# args.output_dir, f"{args.cohort}_top5_features_across_clusters_leiden.tsv" -# ), -# sep="\t", -# ) +print("Spatial context") -# else: -# top5_leiden.to_csv( -# os.path.join( -# args.output_dir, -# f"{args.cohort}_top5_features_across_clusters_leiden_remove_{args.remove_view}.tsv", -# ), -# sep="\t", -# ) +sc.tl.umap(adata, neighbors_key="spatial_context", random_state=args.random_seed) -# top5_pg = top_common_features(fc_df_pg) +adata.obsm["X_umap_spct"] = adata.obsm["X_umap"].copy() +# ranked_dict, fc_df = +# rank_features_in_groups( +# adata, "leiden", scale_values=False, cofactor=args.cofactor, +# ) # no scaling required because using adata_train and test which have already been normalized and winsorized -- StandardScaler still applied +# fc_df = adata.uns["leiden_feature_scores"] -# if args.include_all_views: +# top5_leiden = top_common_features(fc_df) -# 
top5_pg.to_csv( -# os.path.join(args.output_dir, f"{args.cohort}_top5_features_across_clusters_pg.tsv"), -# sep="\t", -# ) +# if args.include_all_views: -# else: -# top5_pg.to_csv( -# os.path.join(args.output_dir, f"{args.cohort}_top5_features_across_clusters_pg_remove_{args.remove_view}.tsv"), +# top5_leiden.to_csv( +# os.path.join(args.output_dir, f"{args.cohort}_top5_features_across_clusters_leiden.tsv"), # sep="\t", # ) print("Neighbourhood enrichment analysis") -# sq.gr.co_occurrence(adata, cluster_key="leiden") # if it works, it works +# sq.gr.co_occurrence(adata, cluster_key="leiden") # if it works, it works -- didn't work, always NaNs sq.gr.spatial_neighbors(adata) sq.gr.nhood_enrichment(adata, cluster_key="leiden") -# sq.gr.nhood_enrichment(adata, cluster_key="PhenoGraph_clusters") -# sc.pl.umap(adata, color=['leiden', 'DNA1', 'panCK', 'CD45', 'Vimentin', 'CD3', 'CD20'], show=False, save=f"_{args.cohort}") +print("===> Create the neighbourhood features") + +h5 = adata.copy() + +sc.pp.neighbors( + h5, use_rep="spatial", n_neighbors=10 +) # get spatial neighbour connectivities, we lose this when we make the new adata + +data = ScModeDataloader(h5) + +spatial_context = data.C.numpy() + +spatial_context_names = [ + "neighbour_" + i + for i in list(h5.var_names) + + h5.uns["names_correlations"].tolist() + + h5.uns["names_morphology"].tolist() +] + +print("===> Creating new adata and ranking all features") + +clustering = [i for i in h5.obs.columns if "leiden" in i] + +all_features = np.concatenate( + [h5.X, h5.obsm["correlations"], h5.obsm["morphology"], spatial_context], axis=1 +) + +names = np.concatenate( + [ + h5.var_names, + h5.uns["names_correlations"], + h5.uns["names_morphology"], + spatial_context_names, + ] +) + +all_features_df = pd.DataFrame(all_features, columns=names) + + +new_adata = AnnData( + X=all_features_df, + obs=h5.copy().obs, + obsm=h5.copy().obsm, + obsp=h5.copy().obsp, + uns=h5.copy().uns, +) + +for cl in clustering: + print(f"Ranking 
features for clustering: {cl}") + sc.tl.rank_genes_groups(new_adata, groupby=cl, key_added=f"{cl}_rank_gene_groups") + +dfs = [] + +for cl in clustering: + ranked_df = sc.get.rank_genes_groups_df( + new_adata, group=None, key=f"{cl}_rank_gene_groups" + ) + + ranked_df["clustering"] = [cl] * ranked_df.shape[0] + + dfs.append(ranked_df) + +full_ranked_df = pd.concat(dfs) + +## get the top features across all the different clustering + +dfs2 = {} + +for cl in clustering: + print(f"sorting ranked features for {cl}") + fs = [] + features_df = full_ranked_df.copy().query("clustering==@cl") + for group in features_df.group.unique(): + group_df = features_df.query("group==@group") + top10 = group_df.names.tolist()[0:10] # these are sorted by top + + for f in top10: + fs.append(f) + + top_features = list(set(fs)) + + new_df = pd.DataFrame({}) -stopa = time.time() + for group in features_df.group.unique(): + group_df = features_df.query("group==@group") -log_file.write(f"All analysis done in {(stopa-starta)/60} minutes") + # print('df shape', group_df.shape[0]) -log_file.close() + scores = ( + group_df.loc[group_df.names.isin(top_features), ["names", "scores"]] + .set_index("names") + .sort_index() + .scores.tolist() + ) + + new_df[group] = scores + + # print(group, new_df.shape) + + new_df.index = np.sort(top_features) + + arr_features2, arr_list2 = arrange_features(new_df.index.to_list(), adata) + + new_df = new_df.reindex(arr_list2) + + new_df.columns = new_df.columns.map(int) + + new_df = new_df[np.sort(new_df.columns)] + + # new_df['clustering'] = [cl]*new_df.shape[0] + + dfs2[cl] = new_df + +cmap = sns.diverging_palette(220, 20, as_cmap=True) -# print("old", adata.uns.keys()) +for n, cl in enumerate(clustering): + # print(n) + # bx = plt.subplot(6,1,n+1) + sns.clustermap( + dfs2[cl].fillna(0), + row_cluster=False, + center=0.00, + cmap=cmap, + vmin=-100, + vmax=100, + figsize=(25, 25), + linewidth=2, + linecolor="black", + ) + + # plt.title(f"rankings for {cl}") + + 
plt.savefig(f"{args.cohort}_cluster_rankings_for_{cl}.png") -# new_uns = {str(k):v for k,v in adata.uns.items()} +print("old", new_adata.uns.keys()) -# print("new", new_uns.keys()) +new_uns = {str(k): v for k, v in new_adata.uns.items()} -# # adata.uns = new_uns +print("new", new_uns.keys()) + +adata.uns = new_uns if args.include_all_views: - adata.obs.to_csv( + new_adata.obs.to_csv( os.path.join(args.output_dir, f"{args.cohort}_clusters.tsv"), sep="\t" ) - adata.write(os.path.join(args.output_dir, f"{args.cohort}_adata_new.h5ad")) + new_adata.write(os.path.join(args.output_dir, f"{args.cohort}_adata_new.h5ad")) + full_ranked_df.to_csv( + os.path.join(args.output_dir, f"{args.cohort}_clusters_ranked_features.tsv"), + sep="\t", + ) + +# if args.include_all_views: +# adata.obs.to_csv(os.path.join(args.output_dir, f"{args.cohort}_clusters.tsv"), sep="\t") +# adata.write(os.path.join(args.output_dir, f"{args.cohort}_adata_new.h5ad")) else: adata.obs.to_csv( @@ -603,3 +749,6 @@ def top_common_features(df, top_n_features=10): args.output_dir, f"{args.cohort}_adata_remove_{args.remove_view}.h5ad" ) ) + + +# sc.pl.umap(adata[random_inds], color=['leiden'], show From efb27d3d5a67c9f223798324aac1d8e34dfa4fca Mon Sep 17 00:00:00 2001 From: shanzaayub Date: Wed, 5 Apr 2023 17:03:39 -0400 Subject: [PATCH 17/18] clinical associations code added --- ...tions_latent_dim_and_cluster_prevalence.py | 429 ++++++++++++++++++ 1 file changed, 429 insertions(+) create mode 100644 hmivae/clinical_associations_latent_dim_and_cluster_prevalence.py diff --git a/hmivae/clinical_associations_latent_dim_and_cluster_prevalence.py b/hmivae/clinical_associations_latent_dim_and_cluster_prevalence.py new file mode 100644 index 0000000..d1d6209 --- /dev/null +++ b/hmivae/clinical_associations_latent_dim_and_cluster_prevalence.py @@ -0,0 +1,429 @@ +### Clinical associations + +from collections import Counter + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import scanpy as 
sc +import statsmodels.api as sm +import tifffile +from rich.progress import ( + BarColumn, + Progress, + TaskProgressColumn, + TextColumn, + TimeElapsedColumn, +) + +### Load data + +cohort = "basel" + +adata = sc.read_h5ad( + f"../cluster_analysis/{cohort}/best_run_{cohort}_no_dna/results_diff_res/{cohort}_adata_new.h5ad" +) # directory where adata was stored + +patient_data = pd.read_csv( + f"{cohort}/{cohort}_survival_patient_samples.tsv", sep="\t", index_col=0 +) + +clinical_variables = [ + "ERStatus", + "grade", + "PRStatus", + "HER2Status", + "Subtype", + "clinical_type", + "HR", +] # changes for each cohort, here example is basel + +patient_col = "PID" + +cluster_col = "leiden" + +### Visualize the data + +plt.rcParams["figure.figsize"] = [10, 10] + +for n, i in enumerate(clinical_variables): + ax = plt.subplot(4, 2, n + 1) + df = pd.DataFrame(patient_data[i].value_counts()).transpose() + + df.plot.bar(ax=ax) + + ax.set_xticklabels([i], rotation=0) + plt.legend(bbox_to_anchor=[1.0, 1.1]) + +# Patient / Latent Variable associations + +df = pd.DataFrame( + columns=["Sample_name"] + + [f"median_latent_dim_{n}" for n in range(adata.obsm["VAE"].shape[1])] +) + +for n, sample in enumerate(adata.obs.Sample_name.unique()): + sample_adata = adata.copy()[adata.obs.Sample_name.isin([sample]), :] + + df.loc[str(n)] = [sample] + np.median(sample_adata.obsm["VAE"], axis=0).tolist() + +patient_latent = pd.merge(df, patient_data, on="Sample_name") + +## first try + +latent_dim_cols = [i for i in patient_latent.columns if "median" in i] +exception_variables = [] + +dfs = [] + +for cvar in clinical_variables: + cvar_dfs = [] + + for sub_cvar in patient_latent[cvar].unique(): + print(cvar, sub_cvar) + sub_cvar_df = pd.DataFrame({}) + selected_df = patient_latent.copy()[ + ~patient_latent[cvar].isna() + ] # drop nan values for each var + selected_df[cvar] = list(map(int, selected_df[cvar] == sub_cvar)) + + X = selected_df[ + latent_dim_cols + ].to_numpy() # select columns 
corresponding to latent dims and convert to numpy + X = sm.add_constant(X) # add constant + y = selected_df[ + cvar + ].to_numpy() # select the clinical variable column and convert to numpy -- no fillna(0) since all the nans should have been dropped + try: + log_reg = sm.Logit(y, X).fit() # fit the Logistic Regression model + + sub_cvar_df["latent_dim"] = [c.split("_")[-1] for c in latent_dim_cols] + + sub_cvar_df["tvalues"] = log_reg.tvalues[1:] # remove the constant + + sub_cvar_df["clinical_variable"] = [ + f"{cvar}:{sub_cvar}" + ] * sub_cvar_df.shape[0] + + cvar_dfs.append(sub_cvar_df) + except Exception as e: + exception_variables.append((cvar, sub_cvar)) + print(f"{cvar}:{sub_cvar} had an exception occur: {e}") + + full_cvar_dfs = pd.concat(cvar_dfs) + + dfs.append(full_cvar_dfs) + +## Second try, which features caused issues for which clinical variable + +features_to_remove = [] +# cvar_dfs2 = [] +for cvar, sub_cvar in exception_variables: + # print(cvar, sub_cvar) + selected_df = patient_latent[ + ~patient_latent[cvar].isna() + ].copy() # drop nan values for each var + selected_df[cvar] = list(map(int, selected_df[cvar] == sub_cvar)) + y = selected_df[cvar].to_numpy() + X = selected_df[ + latent_dim_cols + ].to_numpy() # select columns corresponding to latent dims and convert to numpy + + perf_sep_features = [] + for i in range(X.shape[1]): + X_1 = X.copy()[:, 0 : i + 1] + X_1 = sm.add_constant(X_1) # add constant + try: + log_reg = sm.Logit(y, X_1).fit() # fit the Logistic Regression model + print( + f"Completed: tvalues for {cvar}:{sub_cvar}, features till {i} -> {log_reg.tvalues}" + ) + # print(log_reg.summary()) + except Exception as e: + print(f"{cvar}:{sub_cvar} for feature {i} has exception: {e}") + perf_sep_features.append(i) + + # if len(perf_sep_features) == 0: + # sub_cvar_df = pd.DataFrame({}) + # sub_cvar_df['latent_dim'] = [c.split('_')[-1] for c in latent_dim_cols] + + # assert len(log_reg.tvalues) == X.shape[1]+1 #for constant -- check this 
is the last one + + # sub_cvar_df['tvalues'] = log_reg.tvalues[1:] # remove the constant -- this should be the last one + + # sub_cvar_df['clinical_variable'] = [f"{cvar}:{sub_cvar}"]*sub_cvar_df.shape[0] + + # cvar_dfs2.append(sub_cvar_df) # this will often turn out to be empty since if it gave issues before, it should give issues now + + # else: + + features_to_remove.append((cvar, sub_cvar, perf_sep_features)) + +## final try, remove the features causing issues and store their t-value as NaN + +sub_cvars = [] + +for cvar, sub_cvar, del_inds in features_to_remove: + selected_df = patient_latent[ + ~patient_latent[cvar].isna() + ].copy() # drop nan values for each var + selected_df[cvar] = list(map(int, selected_df[cvar] == sub_cvar)) + y = selected_df[cvar].to_numpy() + X = selected_df[ + latent_dim_cols + ].to_numpy() # select columns corresponding to latent dims and convert to numpy + del_inds = del_inds + X = np.delete(X, del_inds, axis=1) + print(X.shape) + X = sm.add_constant(X) # add constant + try: + log_reg = sm.Logit(y, X).fit() # fit the Logistic Regression model + print( + f"Completed: tvalues for {cvar}:{sub_cvar}, features till {i} -> {log_reg.tvalues}" + ) + + sub_cvar_df = pd.DataFrame({}) + sub_cvar_df["latent_dim"] = [c.split("_")[-1] for c in latent_dim_cols] + + tvalues = log_reg.tvalues[1:].tolist() # + [np.nan] + + for i in del_inds: + if i > len(tvalues): + tvalues = np.insert(tvalues, i - 1, np.nan) + else: + tvalues = np.insert(tvalues, i, np.nan) + + # tvalues = np.insert(tvalues, del_inds.remove(19), np.nan) + # assert len(log_reg.tvalues) == X.shape[1]+1 #for constant -- check this is the last one + + sub_cvar_df["tvalues"] = tvalues + + sub_cvar_df["clinical_variable"] = [f"{cvar}:{sub_cvar}"] * sub_cvar_df.shape[0] + + sub_cvars.append(sub_cvar_df) + except Exception as e: + print(f"{cvar}:{sub_cvar} for feature {i} has exception: {e}") + +sub_cvar_df1 = pd.concat(sub_cvars) + +full_clin_df = pd.concat(dfs).reset_index(drop=True) + 
+final_full_clin_df = pd.concat([full_clin_df, sub_cvar_df1]).reset_index(drop=True) + +final_full_clin_df = pd.pivot_table( + final_full_clin_df, + index="clinical_variable", + values="tvalues", + columns="latent_dim", +) # df that's plotted + + +# Patient / Cluster associations +# First we need to define cluster prevalance within a patient. Doing this in two ways: +# 1. How we were doing it before -- proportion of cells in patient x that belong to cluster c +# 2. Cells of cluster c per mm^2 of tissue + +clusters_patient = pd.merge( + adata.obs.reset_index()[["Sample_name", "leiden", "cell_id"]], + patient_data.reset_index(), + on="Sample_name", +) + +## Option 1: Proportion of cells in patient x that belong in cluster c + +hi_or_low = clusters_patient[[patient_col, cluster_col]] + +## Proportion of cells belonging to each cluster for each image / patient + +hi_or_low = hi_or_low.groupby([patient_col, cluster_col]).size().unstack(fill_value=0) + + +hi_or_low = hi_or_low.div(hi_or_low.sum(axis=1), axis=0).fillna(0) + + +hi_low_cluster_variables = ( + pd.merge( + hi_or_low.reset_index(), + clusters_patient[clinical_variables + [patient_col]], + on=patient_col, + ) + .drop_duplicates() + .reset_index(drop=True) +) + +prop_cluster_cols = [ + i + for i in hi_low_cluster_variables.columns + if i in clusters_patient[cluster_col].unique() +] +exception_variables = [] + +dfs = [] + +for cvar in clinical_variables: + cvar_dfs = [] + filtered_df = hi_low_cluster_variables[ + ~hi_low_cluster_variables[cvar].isna() + ].copy() # drop nan values for each var + + for sub_cvar in filtered_df[cvar].unique(): + print(cvar, sub_cvar) + selected_df = filtered_df.copy() + selected_df[cvar] = list(map(int, selected_df[cvar] == sub_cvar)) + sub_cvar_df = pd.DataFrame({}) + y = selected_df[ + cvar + ].to_numpy() # select the clinical variable column and convert to numpy -- no fillna(0) since all the nans should have been dropped + X = selected_df[ + prop_cluster_cols + ].to_numpy() # 
select columns corresponding to latent dims and convert to numpy + tvalues = {} + for cluster in range(X.shape[1]): + X1 = X[:, cluster] + X1 = sm.add_constant(X1) + try: + log_reg = sm.Logit(y, X1).fit() # fit the Logistic Regression model + + tvalues[cluster] = log_reg.tvalues[ + 1 + ] # there will be 2 t values, first one belongs to the constant + + except Exception as e: + exception_variables.append((cvar, sub_cvar, cluster, e)) + print( + f"{cvar}:{sub_cvar} had an exception occur for cluster {cluster}: {e}" + ) + + sub_cvar_df["cluster"] = list(tvalues.keys()) + + sub_cvar_df["tvalues"] = list(tvalues.values()) + + sub_cvar_df["clinical_variable"] = [f"{cvar}:{sub_cvar}"] * sub_cvar_df.shape[0] + + cvar_dfs.append(sub_cvar_df) + + full_cvar_dfs = pd.concat(cvar_dfs) + + dfs.append(full_cvar_dfs) + +full_cluster_clin_df = pd.concat(dfs).reset_index(drop=True) + +full_cluster_clin_df = pd.pivot_table( + full_cluster_clin_df, index="clinical_variable", values="tvalues", columns="cluster" +) # df that's plotted + +## Option 2: Number of cells per mm^2 tissue +# We're going to do this per image for now -- mainly because sizes might differ between images that belong to the same patient + +clinical_variables = clinical_variables + [ + "diseasestatus" +] # for basel, since doing per image + +cohort_dirs = { + "basel": ["OMEnMasks/Basel_Zuri_masks", "_a0_full_maks.tiff"], + "metabric": ["METABRIC_IMC/to_public_repository/cell_masks", "_cellmask.tiff"], + "melanoma": [ + "full_data/protein/cpout/", + "_ac_ilastik_s2_Probabilities_equalized_cellmask.tiff", + ], +} # directories with the masks + +adata_df = adata.obs.reset_index()[["cell_id", "Sample_name", "leiden"]] +clusters = adata_df.leiden.unique().tolist() + +sample_dfs = [] + +progress = Progress( + TextColumn(f"[progress.description]Finding cluster prevalances in {cohort}."), + BarColumn(), + TaskProgressColumn(), + TimeElapsedColumn(), +) +with progress: + for sample in 
progress.track(adata.obs.Sample_name.unique()): + s_df = pd.DataFrame({}) + s_cluster_prevs = {} + mask = tifffile.imread( + f"../../../data/{cohort_dirs[cohort][0]}/{sample}{cohort_dirs[cohort][1]}" + ) + sample_df = adata_df.copy().query("Sample_name==@sample") + for cluster in clusters: + num_cells_in_sample = Counter(sample_df.leiden.tolist()) + num_cells_in_clusters = num_cells_in_sample[cluster] + + # print(num_cells_in_clusters) + # print(mask.shape[0] , mask.shape[1]) + + cluster_prevalance_per_mm2 = ( + num_cells_in_clusters / (mask.shape[0] * mask.shape[1]) + ) * 1e6 # scale, 1 pixel == 1 micron + + s_cluster_prevs[cluster] = cluster_prevalance_per_mm2 + + s_df["cluster"] = list(s_cluster_prevs.keys()) + s_df["prevalance_per_mm2_scaled_by_1e6"] = list(s_cluster_prevs.values()) + s_df["Sample_name"] = [sample] * s_df.shape[0] + + sample_dfs.append(s_df) + +full_cohort_df = pd.concat(sample_dfs) + +full_cohort_df["cluster"] = full_cohort_df["cluster"].map(int) + +full_cohort_df = pd.pivot_table( + full_cohort_df, + values="prevalance_per_mm2_scaled_by_1e6", + index="Sample_name", + columns="cluster", +) + +clusters = full_cohort_df.columns.tolist() # to make sure correct order later + +cluster_per_tissue_patient = pd.merge( + full_cohort_df, patient_data[clinical_variables + ["Sample_name"]], on="Sample_name" +) + +# The below is still being run and tested but this is close to what I will be doing + +cluster_cols = clusters +exception_variables = [] + +dfs = [] + +for cvar in clinical_variables: + cvar_dfs = [] + + for sub_cvar in cluster_per_tissue_patient[cvar].dropna().unique().tolist(): + print(cvar, sub_cvar) + sub_cvar_df = pd.DataFrame({}) + selected_df = cluster_per_tissue_patient.copy()[ + ~cluster_per_tissue_patient[cvar].isna() + ] # drop nan values for each var + selected_df[cvar] = list(map(int, selected_df[cvar] == sub_cvar)) + + X = selected_df[ + cluster_cols + ].to_numpy() # select columns corresponding to latent dims and convert to numpy + 
X = sm.add_constant(X) # add constant + y = selected_df[ + cvar + ].to_numpy() # select the clinical variable column and convert to numpy -- no fillna(0) since all the nans should have been dropped + try: + log_reg = sm.Logit(y, X).fit() # fit the Logistic Regression model + + sub_cvar_df["cluster"] = [c for c in cluster_cols] + + sub_cvar_df["tvalues"] = log_reg.tvalues[1:] # remove the constant + + sub_cvar_df["clinical_variable"] = [ + f"{cvar}:{sub_cvar}" + ] * sub_cvar_df.shape[0] + + cvar_dfs.append(sub_cvar_df) + except Exception as e: + exception_variables.append((cvar, sub_cvar)) + print(f"{cvar}:{sub_cvar} had an exception occur: {e}") + + full_cvar_dfs = pd.concat(cvar_dfs) + + dfs.append(full_cvar_dfs) From 36db03a3d4808a3144878e1fbf44c4349829e00a Mon Sep 17 00:00:00 2001 From: shanzaayub <66146596+shanzaayub@users.noreply.github.com> Date: Mon, 5 Jun 2023 14:09:54 -0400 Subject: [PATCH 18/18] Some comments added to clinical_associations_latent_dim_and_cluster_prevalence.py --- ...tions_latent_dim_and_cluster_prevalence.py | 33 ++++++++++--------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/hmivae/clinical_associations_latent_dim_and_cluster_prevalence.py b/hmivae/clinical_associations_latent_dim_and_cluster_prevalence.py index d1d6209..20f749c 100644 --- a/hmivae/clinical_associations_latent_dim_and_cluster_prevalence.py +++ b/hmivae/clinical_associations_latent_dim_and_cluster_prevalence.py @@ -76,7 +76,10 @@ dfs = [] -for cvar in clinical_variables: +# cvar is clinical variable +# sub_cvar are the values the clincal variable can take on e.g. 
for cvar == ERStatus, sub_cvar == pos or sub_cvar == neg + +for cvar in clinical_variables: # using all latent dims for this pass cvar_dfs = [] for sub_cvar in patient_latent[cvar].unique(): @@ -85,13 +88,13 @@ selected_df = patient_latent.copy()[ ~patient_latent[cvar].isna() ] # drop nan values for each var - selected_df[cvar] = list(map(int, selected_df[cvar] == sub_cvar)) + selected_df[cvar] = list(map(int, selected_df[cvar] == sub_cvar)) # map 1 and 0 for entries that belong to the sub_cvar X = selected_df[ latent_dim_cols ].to_numpy() # select columns corresponding to latent dims and convert to numpy X = sm.add_constant(X) # add constant - y = selected_df[ + y = selected_df[ # this is the 0 and 1 col cvar ].to_numpy() # select the clinical variable column and convert to numpy -- no fillna(0) since all the nans should have been dropped try: @@ -130,7 +133,7 @@ ].to_numpy() # select columns corresponding to latent dims and convert to numpy perf_sep_features = [] - for i in range(X.shape[1]): + for i in range(X.shape[1]): # introduce each latent dim one at a time to see which caused issues X_1 = X.copy()[:, 0 : i + 1] X_1 = sm.add_constant(X_1) # add constant try: @@ -141,7 +144,7 @@ # print(log_reg.summary()) except Exception as e: print(f"{cvar}:{sub_cvar} for feature {i} has exception: {e}") - perf_sep_features.append(i) + perf_sep_features.append(i) # store the issue causing latent dim # if len(perf_sep_features) == 0: # sub_cvar_df = pd.DataFrame({}) @@ -173,11 +176,11 @@ latent_dim_cols ].to_numpy() # select columns corresponding to latent dims and convert to numpy del_inds = del_inds - X = np.delete(X, del_inds, axis=1) + X = np.delete(X, del_inds, axis=1) # delete the issue causing latent dims from full set print(X.shape) X = sm.add_constant(X) # add constant try: - log_reg = sm.Logit(y, X).fit() # fit the Logistic Regression model + log_reg = sm.Logit(y, X).fit() # fit the Logistic Regression model on the remaining latent dims print( f"Completed: tvalues 
for {cvar}:{sub_cvar}, features till {i} -> {log_reg.tvalues}" ) @@ -187,7 +190,7 @@ tvalues = log_reg.tvalues[1:].tolist() # + [np.nan] - for i in del_inds: + for i in del_inds: # for latent dims that caused issues, store their tvalues as nan so we know which ones didn't work if i > len(tvalues): tvalues = np.insert(tvalues, i - 1, np.nan) else: @@ -238,7 +241,7 @@ hi_or_low = hi_or_low.groupby([patient_col, cluster_col]).size().unstack(fill_value=0) -hi_or_low = hi_or_low.div(hi_or_low.sum(axis=1), axis=0).fillna(0) +hi_or_low = hi_or_low.div(hi_or_low.sum(axis=1), axis=0).fillna(0) # get proportion of each cluster in each patient (all will sum to 1) hi_low_cluster_variables = ( @@ -278,7 +281,7 @@ prop_cluster_cols ].to_numpy() # select columns corresponding to latent dims and convert to numpy tvalues = {} - for cluster in range(X.shape[1]): + for cluster in range(X.shape[1]): # do each cluster one by one since these add up to 1 and Logit won't work X1 = X[:, cluster] X1 = sm.add_constant(X1) try: @@ -345,18 +348,18 @@ s_cluster_prevs = {} mask = tifffile.imread( f"../../../data/{cohort_dirs[cohort][0]}/{sample}{cohort_dirs[cohort][1]}" - ) + ) # get dims of the image sample_df = adata_df.copy().query("Sample_name==@sample") for cluster in clusters: num_cells_in_sample = Counter(sample_df.leiden.tolist()) - num_cells_in_clusters = num_cells_in_sample[cluster] + num_cells_in_clusters = num_cells_in_sample[cluster] # get number of cells belong to each cluster for each image # print(num_cells_in_clusters) # print(mask.shape[0] , mask.shape[1]) cluster_prevalance_per_mm2 = ( num_cells_in_clusters / (mask.shape[0] * mask.shape[1]) - ) * 1e6 # scale, 1 pixel == 1 micron + ) * 1e6 # scale, 1 pixel == 1 micron, get the prevalence and scale s_cluster_prevs[cluster] = cluster_prevalance_per_mm2 @@ -409,7 +412,7 @@ cvar ].to_numpy() # select the clinical variable column and convert to numpy -- no fillna(0) since all the nans should have been dropped try: - log_reg = 
sm.Logit(y, X).fit() # fit the Logistic Regression model + log_reg = sm.Logit(y, X).fit() # fit the Logistic Regression model, doing them altogether this time, not one by one because don't need to sub_cvar_df["cluster"] = [c for c in cluster_cols] @@ -421,7 +424,7 @@ cvar_dfs.append(sub_cvar_df) except Exception as e: - exception_variables.append((cvar, sub_cvar)) + exception_variables.append((cvar, sub_cvar)) # I keep a track of the exception variables but I don't deal with them print(f"{cvar}:{sub_cvar} had an exception occur: {e}") full_cvar_dfs = pd.concat(cvar_dfs)