Commit 1a6de4a (0 parents), committed by lep on Nov 3, 2023.
Showing 30 changed files with 7,201 additions and 0 deletions.

@@ -0,0 +1,3 @@
.vscode
__pycache__/
README.md

@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Erpai Luo

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@@ -0,0 +1,202 @@
import numpy as np
import torch
from torch import nn


class NBLoss(torch.nn.Module):
    def __init__(self):
        super(NBLoss, self).__init__()

    def forward(self, mu, y, theta, eps=1e-8):
        """Negative binomial negative log-likelihood.

        Assumes targets `y` and estimated means `mu` with n rows and d
        columns, and a per-gene inverse dispersion `theta`. This module
        assumes that the estimated mean and inverse dispersion are positive;
        for numerical stability, it is recommended that they are kept above
        a small constant (e.g. 1e-3).

        Parameters
        ----------
        mu: Tensor
            Torch Tensor of estimated means (reconstructed data).
        y: Tensor
            Torch Tensor of ground truth data.
        theta: Tensor
            Torch Tensor of inverse dispersions.
        eps: Float
            numerical stability constant.
        """
        if theta.ndimension() == 1:
            # In this case, we reshape theta for broadcasting
            theta = theta.view(1, theta.size(0))
        log_theta_mu_eps = torch.log(theta + mu + eps)
        res = (
            theta * (torch.log(theta + eps) - log_theta_mu_eps)
            + y * (torch.log(mu + eps) - log_theta_mu_eps)
            + torch.lgamma(y + theta)
            - torch.lgamma(theta)
            - torch.lgamma(y + 1)
        )
        res = _nan2inf(res)
        return -torch.mean(res)


def _nan2inf(x):
    # Replace NaNs with +inf so they surface as a large loss instead of
    # silently propagating through the mean.
    return torch.where(torch.isnan(x), torch.zeros_like(x) + np.inf, x)
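
# Usage sketch: NBLoss compares a predicted mean `mu` and per-gene inverse
# dispersion `theta` against count targets `y`; all estimates should stay
# strictly positive. Shapes and values below are illustrative placeholders.
def _example_nb_loss():
    nb_loss = NBLoss()
    mu = torch.rand(8, 100) + 1e-3         # estimated means, kept positive
    theta = torch.rand(100) + 1e-3         # per-gene inverse dispersion
    y = torch.poisson(torch.ones(8, 100))  # integer-valued count targets
    return nb_loss(mu, y, theta)           # scalar mean negative log-likelihood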

class MLP(torch.nn.Module):
    """
    A multilayer perceptron with ReLU activations. When `batch_norm` is set,
    hidden layers are followed by LayerNorm (despite the argument name).
    """

    def __init__(self, sizes, batch_norm=True, last_layer_act="linear"):
        super(MLP, self).__init__()
        layers = []
        for s in range(len(sizes) - 1):
            layers += [
                torch.nn.Linear(sizes[s], sizes[s + 1]),
                torch.nn.LayerNorm(sizes[s + 1])
                if batch_norm and s < len(sizes) - 2
                else None,
                torch.nn.ReLU(),
            ]

        # Drop the placeholders and the trailing ReLU so the final layer is
        # linear by default.
        layers = [l for l in layers if l is not None][:-1]
        self.activation = last_layer_act
        if self.activation == "linear":
            pass
        elif self.activation == "ReLU":
            self.relu = torch.nn.ReLU()
        else:
            raise ValueError("last_layer_act must be one of 'linear' or 'ReLU'")

        self.network = torch.nn.Sequential(*layers)

    def forward(self, x):
        if self.activation == "ReLU":
            x = self.network(x)
            return self.relu(x)
        return self.network(x)
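
# Usage sketch: `sizes` lists the layer widths from input to output. The
# numbers below are arbitrary placeholders.
def _example_mlp():
    mlp = MLP([10, 32, 5], batch_norm=True, last_layer_act="linear")
    x = torch.randn(4, 10)  # batch of 4 samples with 10 features each
    return mlp(x)           # shape (4, 5); the output layer is linear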

class VAE(torch.nn.Module):
    """
    Autoencoder for gene expression. Despite the class name, this is a
    deterministic autoencoder: no variational sampling is performed.
    """

    def __init__(
        self,
        num_genes,
        device="cuda",
        seed=0,
        decoder_activation="linear",
        hparams="",
    ):
        super(VAE, self).__init__()
        # set generic attributes
        self.num_genes = num_genes
        self.device = device
        self.seed = seed
        # early-stopping
        self.best_score = -1e3
        self.patience_trials = 0

        # set hyperparameters
        self.set_hparams_(hparams)

        # set models
        self.encoder = MLP(
            [num_genes]
            + [6000]
            + [self.hparams["dim"]]
        )

        self.decoder = MLP(
            [self.hparams["dim"]]
            + [6000] + [12000]
            + [num_genes],
            last_layer_act=decoder_activation,
        )

        # losses
        self.loss_autoencoder = nn.MSELoss(reduction='mean')

        self.iteration = 0

        self.to(self.device)

        # optimizers
        get_params = lambda model, cond: list(model.parameters()) if cond else []
        _parameters = (
            get_params(self.encoder, True)
            + get_params(self.decoder, True)
        )
        self.optimizer_autoencoder = torch.optim.Adam(
            _parameters,
            lr=self.hparams["autoencoder_lr"],
            weight_decay=self.hparams["autoencoder_wd"],
        )

        self.normalize_total = Normalize_total()

    def forward(self, genes, return_latent=False, return_decoded=False):
        """
        If return_latent=True, act as an encoder only. If return_decoded=True,
        `genes` should be the latent representation and this acts as a decoder
        only.
        """
        if return_decoded:
            gene_reconstructions = self.decoder(genes)
            return gene_reconstructions

        latent = self.encoder(genes)
        if return_latent:
            return latent

        gene_reconstructions = self.decoder(latent)

        return gene_reconstructions

    def set_hparams_(self, hparams):
        """
        Set hyper-parameters to default values. The `hparams` argument is
        currently unused; the defaults below are always applied.
        """

        self.hparams = {
            "dim": 1000,
            "autoencoder_width": 5000,
            "autoencoder_depth": 3,
            "adversary_lr": 3e-4,
            "autoencoder_wd": 4e-7,  # 4e-7
            "autoencoder_lr": 1e-5,  # 1e-5
        }

        return self.hparams

    def train(self, genes):
        """
        Run one training step on a minibatch of gene expression vectors.
        Note: this overrides nn.Module.train().
        """
        genes = genes.to(self.device)
        gene_reconstructions = self.forward(genes)

        reconstruction_loss = self.loss_autoencoder(gene_reconstructions, genes)

        self.optimizer_autoencoder.zero_grad()
        reconstruction_loss.backward()
        self.optimizer_autoencoder.step()

        self.iteration += 1

        return {
            "loss_reconstruction": reconstruction_loss.item(),
        }
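
# Usage sketch: the same module can encode, decode, or run one optimization
# step. The tensor shapes and gene count below are placeholders.
def _example_vae_usage():
    model = VAE(num_genes=2000, device="cpu")
    genes = torch.rand(16, 2000)                  # minibatch of expression vectors
    latent = model(genes, return_latent=True)     # (16, hparams["dim"])
    decoded = model(latent, return_decoded=True)  # (16, 2000)
    stats = model.train(genes)                    # one gradient step on the MSE loss
    return latent, decoded, stats["loss_reconstruction"]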

class Normalize_total(nn.Module):
    """
    Scale each cell (row) so its counts sum to `target_sum`, similar to
    total-count normalization in single-cell preprocessing.
    """

    def __init__(self, target_sum=1e4):
        super(Normalize_total, self).__init__()
        self.target_sum = target_sum

    def forward(self, adata):
        counts_per_cell = adata.sum(axis=1)
        scale_factor = self.target_sum / counts_per_cell
        norm_adata = adata * scale_factor[:, np.newaxis]

        return norm_adata
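
# Usage sketch: rows are cells and columns are genes; each row is rescaled to
# sum to target_sum. The toy array below is a placeholder.
def _example_normalize_total():
    normalize = Normalize_total(target_sum=1e4)
    counts = np.array([[1.0, 3.0], [2.0, 2.0]])  # 2 cells x 2 genes
    return normalize(counts)                     # each row now sums to 1e4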

@@ -0,0 +1,125 @@
import argparse
import os
import time

import numpy as np
import torch
from VAE_model import VAE, MLP
import sys
sys.path.append("..")
from guided_diffusion.cell_datasets_muris import load_data

torch.autograd.set_detect_anomaly(True)
import random


def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def prepare_vae(args, state_dict=None):
    """
    Instantiates the autoencoder and dataset to run an experiment.
    """

    device = "cuda" if torch.cuda.is_available() else "cpu"

    datasets = load_data(
        data_dir=args["data_dir"],
        batch_size=args["batch_size"],
        vae=True,
        ae_dir=args["save_dir"],
        num_gene=args["num_genes"],
    )

    autoencoder = VAE(
        num_genes=args["num_genes"],
        device=device,
        seed=args["seed"],
        hparams="",
        decoder_activation=args["decoder_activation"],
    )
    if state_dict is not None:
        autoencoder.load_state_dict(state_dict)

    return autoencoder, datasets
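
# Usage sketch: a saved checkpoint can be restored by loading its state_dict
# and passing it to prepare_vae. The filename below follows the
# "model_seed={seed}_step={step}.pt" pattern used when saving, with
# placeholder values.
def _example_resume(args):
    ckpt = os.path.join(args["save_dir"], "model_seed=0_step=0.pt")
    state_dict = torch.load(ckpt, map_location="cpu")
    autoencoder, datasets = prepare_vae(args, state_dict=state_dict)
    return autoencoder, datasets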

def train_vae(args, return_model=False):
    """
    Trains an autoencoder.
    """

    autoencoder, datasets = prepare_vae(args)

    args["hparams"] = autoencoder.hparams

    start_time = time.time()
    for step in range(args["max_steps"]):

        genes, _ = next(datasets)

        minibatch_training_stats = autoencoder.train(genes)

        if step % 1000 == 0:
            for key, val in minibatch_training_stats.items():
                print('step ', step, 'loss ', val)

        elapsed_minutes = (time.time() - start_time) / 60

        stop = elapsed_minutes > args["max_minutes"] or (
            step == args["max_steps"] - 1
        )

        if ((step % args["checkpoint_freq"]) == 0 or stop):

            os.makedirs(args["save_dir"], exist_ok=True)
            torch.save(
                autoencoder.state_dict(),
                os.path.join(
                    args["save_dir"],
                    "model_seed={}_step={}.pt".format(args["seed"], step),
                ),
            )

        if stop:
            break

    if return_model:
        return autoencoder, datasets

def parse_arguments():
    """
    Read arguments if this script is called from a terminal.
    """
    parser = argparse.ArgumentParser(description="Autoencoder for gene expression")
    # dataset arguments
    parser.add_argument("--data_dir", type=str, default='/data1/lep/Workspace/guided-diffusion/data/tabula_muris/all.h5ad')
    parser.add_argument("--loss_ae", type=str, default="mse")
    parser.add_argument("--decoder_activation", type=str, default="ReLU")

    # CPA arguments (see set_hparams_() in cpa.model.CPA)
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--split_seed", type=int, default=1234)
    parser.add_argument("--num_genes", type=int, default=18996)  # number of genes after quality control
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--hparams", type=str, default="")

    # training arguments
    parser.add_argument("--max_steps", type=int, default=1000000)
    parser.add_argument("--max_minutes", type=int, default=3000)
    parser.add_argument("--checkpoint_freq", type=int, default=200000)
    parser.add_argument("--batch_size", type=int, default=128)

    parser.add_argument("--save_dir", type=str, default='../checkpoint/AE/my_AE')
    parser.add_argument("--sweep_seeds", type=int, default=200)
    return dict(vars(parser.parse_args()))
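
# Usage sketch: the script is normally driven by parse_arguments(), but
# train_vae can also be called with an argument dict directly. The paths and
# step counts below are placeholders.
def _example_direct_call():
    args = {
        "data_dir": "/path/to/all.h5ad",
        "batch_size": 128,
        "num_genes": 18996,
        "seed": 0,
        "decoder_activation": "ReLU",
        "save_dir": "../checkpoint/AE/my_AE",
        "max_steps": 2000,
        "max_minutes": 60,
        "checkpoint_freq": 1000,
        "hparams": "",
    }
    return train_vae(args, return_model=True)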

if __name__ == "__main__":
    seed_everything(1234)
    train_vae(parse_arguments())