class EMA(L.Callback):
    """Lightning callback that keeps an exponential moving average of model weights.

    After every training batch the parameters of ``pl_module.ema_model`` are
    updated from ``pl_module.model``. For the first ``step_start_ema`` calls the
    EMA model is simply overwritten with the current weights (warm-up); after
    that each EMA parameter becomes ``beta * old + (1 - beta) * new``.

    Args:
        beta: Decay factor of the moving average.
        step_start_ema: Number of warm-up steps during which the EMA model
            is a plain copy of the live model.
    """

    # The constructor was previously spelled ``init``; Python never called it,
    # so ``beta``, ``step`` and ``step_start_ema`` were missing at runtime.
    def __init__(
        self,
        beta: float = 0.995,
        step_start_ema: int = 100,
    ) -> None:
        super().__init__()
        self.beta = beta
        self.step = 0
        self.step_start_ema = step_start_ema

    def on_train_batch_end(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
        outputs: Any = None,
        batch: Any = None,
        batch_idx: int = 0,
    ) -> None:
        """Advance the EMA model by one step once a training batch has finished.

        ``outputs``, ``batch`` and ``batch_idx`` are unused but must be accepted
        because Lightning passes them to this hook; they default to harmless
        values so existing two-argument calls keep working.
        """
        self.step_ema(pl_module.ema_model, pl_module.model)

    def update_model_average(
        self,
        ma_model: L.LightningModule,
        current_model: L.LightningModule,
    ) -> None:
        """Blend every parameter of ``current_model`` into ``ma_model``."""
        for current_params, ma_params in zip(
            current_model.parameters(), ma_model.parameters()
        ):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old: torch.Tensor, new: torch.Tensor) -> torch.Tensor:
        """Return ``beta * old + (1 - beta) * new``, or ``new`` if ``old`` is None."""
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

    def step_ema(
        self,
        ema_model: L.LightningModule,
        model: L.LightningModule,
    ) -> None:
        """Run one EMA step: copy weights during warm-up, average afterwards."""
        if self.step < self.step_start_ema:
            self.reset_parameters(ema_model, model)
        else:
            self.update_model_average(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model) -> None:
        """Overwrite the EMA model's state with the live model's state."""
        ema_model.load_state_dict(model.state_dict())
# hydra-zen configuration for the DNA diffusion training run.
# Every ``builds(...)`` call below creates a structured config for the named
# target; nothing is instantiated at import time.

# Custom Builds Function
# ``sbuilds`` mirrors the target's full signature; ``pbuilds`` additionally
# produces a partial, so remaining arguments can be supplied later.
sbuilds = make_custom_builds_fn(populate_full_signature=True)
pbuilds = make_custom_builds_fn(zen_partial=True, populate_full_signature=True)

# Transforms config if we need to add more
# transforms = builds(T.Compose, [builds(T.ToTensor)])
transforms = builds(T.ToTensor)

# Loading data
# Data module config: the input table plus the four components (cell types)
# to keep.
LoadingData = builds(
    LoadingDataModule,
    input_csv="./dnadiffusion/data/K562_hESCT0_HepG2_GM12878_12k_sequences_per_group.txt",
    subset_components=[
        "GM12878_ENCLB441ZZZ",
        "hESCT0_ENCLB449ZZZ",
        "K562_ENCLB843GMH",
        "HepG2_ENCLB029COU",
    ],
    load_saved_data=False,
    change_component_index=True,
    number_of_sequences_to_motif_creation=1000,
    transform=transforms,
    populate_full_signature=True,
)

# Diffusion Model
Unet = builds(
    UnetBase,
    dim=200,
    init_dim=None,
    dim_mults=(1, 2, 4),
    channels=1,
    resnet_block_groups=4,
    learned_sinusoidal_dim=10,
    num_classes=10,
    output_attention=False,
)

# Optimizers
# Partial config: the parameters to optimize are passed by whoever calls it.
Adam = pbuilds(torch.optim.Adam, lr=1e-3)

# Lightning Module
UnetConfig = builds(
    UnetDiffusion,
    model=Unet,
    lr=1e-3,
    timesteps=50,
    beta=0.995,
    optimizer=Adam,
)

# Callbacks
# ``data_module`` is left MISSING on purpose; it has to be provided when the
# config is instantiated.
sample = builds(
    Sample,
    data_module=MISSING,
    image_size=200,
    num_sampling_to_compare_cells=1000,
)

wandb = builds(
    WandbLogger,
    project="dnadiffusion",
    notes="lightning",
)

checkpoint = builds(
    ModelCheckpoint,
    dirpath="dnadiffusion/checkpoints/",
    every_n_epochs=500,
)

# Lightning Trainer
# NOTE(review): accelerator, device count and DDP strategy are hard-coded for
# a single 8-GPU node; confirm this matches the target machine.
LightningTrainer = builds(
    L.Trainer,
    accelerator="cuda",
    strategy="ddp_find_unused_parameters_true",
    num_nodes=1,
    devices=8,
    max_epochs=10000,
    logger=wandb,
    callbacks=[checkpoint],
)

# Registering the builds
# Stored under the "data" and "model" config groups so they can be selected
# by name from a Hydra defaults list.
cs = ConfigStore.instance()

cs.store(group="data", name="LoadingData", node=LoadingData)
cs.store(group="model", name="Unet", node=UnetConfig)
self.change_comp_index = change_component_index - self.data = self.read_csv() - self.df_generate = self.experiment() - ( - self.df_train_in, - self.df_test_in, - self.df_train_shuffled_in, - ) = self.create_train_groups() self.number_of_sequences_to_motif_creation = ( number_of_sequences_to_motif_creation ) - self.train = None - self.test = None - self.train_shuffle = None - self.get_motif() - def read_csv(self) -> pd.DataFrame: - df = pd.read_csv(self.csv, sep="\t") + def __call__(self) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + df = self.read_csv(self.csv) + subset_df = self.experiment(df) + df_train, df_test, df_train_shuffled = self.create_train_groups(subset_df) + train, test, train_shuffle = self.get_motif( + df_train, df_test, df_train_shuffled + ) + return train, test, train_shuffle + + def read_csv(self, input_csv: str) -> pd.DataFrame: + df = pd.read_csv(input_csv, sep="\t") if self.change_comp_index: df["component"] = df["component"] + 1 @@ -74,12 +80,10 @@ def read_csv(self) -> pd.DataFrame: print(f"Limiting total sequences {self.limit_total_sequences}") df = df.sample(self.limit_total_sequences) - # change this in simon original table - df.columns = [c.replace("seqname", "chr") for c in df.columns.values] return df - def experiment(self) -> pd.DataFrame: - df_generate = self.data.copy() + def experiment(self, df: pd.DataFrame) -> pd.DataFrame: + df_generate = df if self.subset_components is not None and type(self.subset_components) == list: print(" or ".join([f"TAG == {c}" for c in self.subset_components])) df_generate = df_generate.query( @@ -89,24 +93,32 @@ def experiment(self) -> pd.DataFrame: return df_generate - def create_train_groups(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: - # solve it inside the simon dataloader - df_sampled = self.df_generate.query('chr != "chr1" ') - df_train = df_sampled.query('chr != "chr2" ') - df_test = self.df_generate.query('chr == "chr1" ') - df_train_shuffled = df_sampled.query('chr 
== "chr2" ') + def create_train_groups( + self, df: pd.DataFrame + ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + df_test = df[df["chr"] == "chr1"].reset_index(drop=True) + df_train_shuffled = df[df["chr"] == "chr2"].reset_index(drop=True) + df_train = df_train = df[ + (df["chr"] != "chr1") & (df["chr"] != "chr2") + ].reset_index(drop=True) df_train_shuffled["sequence"] = df_train_shuffled["sequence"].apply( lambda x: "".join(random.sample(list(x), len(x))) ) return df_train, df_test, df_train_shuffled - def get_motif(self) -> None: - self.train = self.generate_motifs_and_fastas(self.df_train_in, "train") - self.test = self.generate_motifs_and_fastas(self.df_test_in, "test") - self.train_shuffle = self.generate_motifs_and_fastas( - self.df_train_shuffled_in, "train_shuffle" + def get_motif( + self, + df_train: pd.DataFrame, + df_test: pd.DataFrame, + df_train_shuffled: pd.DataFrame, + ) -> None: + train = self.generate_motifs_and_fastas(df_train, "train") + test = self.generate_motifs_and_fastas(df_test, "test") + train_shuffle = self.generate_motifs_and_fastas( + df_train_shuffled, "train_shuffle" ) + return train, test, train_shuffle def generate_motifs_and_fastas(self, df: pd.DataFrame, name: str) -> Dict[str, Any]: """return fasta anem , and dict with components motifs""" @@ -189,3 +201,108 @@ def __getitem__(self, index): y = self.c[index] return x, y + + +class LoadingDataModule(L.LightningDataModule): + def __init__( + self, + input_csv: str, + subset_components: list, + load_saved_data: bool = False, + sample_number: int = 0, + change_component_index: bool = True, + limit_total_sequences: Optional[int] = None, + number_of_sequences_to_motif_creation: Optional[int] = None, + transform: Optional[T.Compose] = None, + batch_size: int = 30, + shuffle: bool = True, + num_workers: int = 48, + pin_memory: bool = True, + ) -> None: + super().__init__() + self.input_csv = input_csv + self.subset_components = subset_components + self.load_saved_data = 
load_saved_data + self.sample_number = sample_number + self.change_component_index = change_component_index + self.limit_total_sequences = limit_total_sequences + self.number_of_sequences_to_motif_creation = ( + number_of_sequences_to_motif_creation + ) + self.transform = transform + self.batch_size = batch_size + self.shuffle = shuffle + self.num_workers = num_workers + self.pin_memory = pin_memory + + def prepare_data(self) -> None: + if not self.load_saved_data: + print("Loading data") + encode_data = LoadingData( + self.input_csv, + self.subset_components, + self.sample_number, + self.change_component_index, + self.limit_total_sequences, + self.number_of_sequences_to_motif_creation, + ) + train, test, train_shuffle = encode_data() + combined_dict = { + "train": train, + "test": test, + "train_shuffle": train_shuffle, + } + with open("dnadiffusion/data/encode_data.pkl", "wb") as f: + pickle.dump(combined_dict, f) + + def setup(self, stage: Optional[str] = None) -> None: + with open("dnadiffusion/data/encode_data.pkl", "rb") as f: + encode_data = pickle.load(f) + train = encode_data["train"] + test = encode_data["test"] + train_shuffle = encode_data["train_shuffle"] + + # Getting motif related data from encode_data + self.train_motifs = train["motifs"] + self.test_motifs = test["motifs"] + self.shuffle_motifs = train_shuffle["motifs"] + + self.train_motifs_per_components_dict = train["motifs_per_components_dict"] + self.test_motifs_per_components_dict = test["motifs_per_components_dict"] + self.shuffle_motifs_per_components_dict = train_shuffle[ + "motifs_per_components_dict" + ] + + # Sequence related data + df = train["dataset"] + self.cell_components = df.sort_values("TAG")["TAG"].unique().tolist() + self.tag_to_numeric = {x: n + 1 for n, x in enumerate(df.TAG.unique())} + self.numeric_to_tag = {n + 1: x for n, x in enumerate(df.TAG.unique())} + self.cell_types = sorted(list(self.numeric_to_tag.keys())) + self.x_train_cell_type = torch.from_numpy( + 
df["TAG"].apply(lambda x: self.tag_to_numeric[x]).to_numpy() + ) + nucleotides = ["A", "C", "G", "T"] + X_train = np.array( + [ + one_hot_encode(x, nucleotides, 200) + for x in (df["sequence"]) + if "N" not in x + ] + ) + X_train = np.array([x.T.tolist() for x in X_train]) + X_train[X_train == 0] = -1 + self.X_train = X_train + + def train_dataloader(self) -> DataLoader: + return DataLoader( + SequenceDataset( + self.X_train, + self.x_train_cell_type, + self.transform, + ), + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + ) diff --git a/dnadiffusion/data/encode_data.pkl b/dnadiffusion/data/encode_data.pkl new file mode 100644 index 00000000..051b7da1 Binary files /dev/null and b/dnadiffusion/data/encode_data.pkl differ diff --git a/dnadiffusion/metrics/metrics.py b/dnadiffusion/metrics/sampling_metrics.py similarity index 100% rename from dnadiffusion/metrics/metrics.py rename to dnadiffusion/metrics/sampling_metrics.py diff --git a/dnadiffusion/metrics/validation_metrics.py b/dnadiffusion/metrics/validation_metrics.py new file mode 100644 index 00000000..ba9a4411 --- /dev/null +++ b/dnadiffusion/metrics/validation_metrics.py @@ -0,0 +1,59 @@ +from collections import defaultdict + +import numpy as np +import pandas as pd +from Bio import SeqIO +from sourmash import MinHash + + +def _create_pandas_series_from_a_fasta_file(fasta_file: str) -> pd.Series: + """Create a pandas series from a fasta file. Input is a fasta file""" + data = defaultdict(list) + for record in SeqIO.parse(fasta_file, "fasta"): + seq = record.seq + data["sequence"].append(seq) + df = pd.DataFrame.from_dict(data) + + +def _create_mini_hash_of_a_sequence(seq: str, minihash: MinHash) -> MinHash: + """Create a minihash of a sequence. 
def average_jaccard_similarity(
    seq: "pd.Series | str",
    seq2: "pd.Series | str",
    number_of_hashes: int = 20000,
    k_sizes: "list | tuple" = (3, 7, 20),
    is_fasta: bool = False,
) -> float:
    """Return the MinHash similarity of two sequence sets, averaged over k-mer sizes.

    Args:
        seq: First set of sequences (anything with ``.tolist()``), or the path
            of a fasta file when ``is_fasta`` is true.
        seq2: Second set of sequences, same convention as ``seq``.
        number_of_hashes: Number of hashes used for every MinHash sketch.
        k_sizes: K-mer sizes to evaluate; one similarity is computed per size.
        is_fasta: Treat ``seq`` and ``seq2`` as fasta file paths.

    Returns:
        Mean of the per-k similarities, rounded to three decimals.

    Raises:
        ValueError: If ``k_sizes`` is empty (the mean would be NaN).
    """
    # The default used to be a mutable list shared across calls; a tuple is
    # iterated the same way and cannot be modified by accident.
    if len(k_sizes) == 0:
        raise ValueError("k_sizes must contain at least one k-mer size")

    if is_fasta:
        seq = _create_pandas_series_from_a_fasta_file(seq)
        seq2 = _create_pandas_series_from_a_fasta_file(seq2)

    sequence_1 = seq.tolist()
    sequence_2 = seq2.tolist()
    similarities = np.array(
        [
            _compare_two_sequences_and_return_similarity(
                sequence_1, sequence_2, k, number_of_hashes
            )
            for k in k_sizes
        ]
    )
    return round(similarities.mean(), 3)
class SinusoidalPositionEmbeddings(nn.Module):
    """Fixed sinusoidal embedding of a batch of scalar timesteps.

    Args:
        dim: Embedding size; the output carries ``dim // 2`` sine features
            followed by ``dim // 2`` cosine features.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, time: torch.Tensor):
        """Embed ``time`` (indexed as a 1-D batch) into ``(batch, 2 * (dim // 2))``."""
        half_dim = self.dim // 2
        # max(..., 1) avoids the ZeroDivisionError the original raised for
        # dim in (2, 3); for larger dims the value is unchanged.
        step = math.log(10000) / max(half_dim - 1, 1)
        # Build the frequencies on the input's device; a CPU-only arange
        # fails with a device mismatch as soon as ``time`` lives on a GPU.
        freqs = torch.exp(torch.arange(half_dim, device=time.device) * -step)

        angles = time[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
class LayerNorm(nn.Module):
    """Normalization over dimension 1 with a learnable per-channel gain.

    The gain ``g`` has shape ``(1, dim, 1, 1)`` and starts at one; there is
    no learnable bias.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        """Return ``x`` standardized along dim 1 and scaled by ``g``."""
        # A larger epsilon is used for every dtype other than float32.
        if x.dtype == torch.float32:
            epsilon = 1e-5
        else:
            epsilon = 1e-3

        channel_var = x.var(dim=1, unbiased=False, keepdim=True)
        channel_mean = x.mean(dim=1, keepdim=True)

        centered = x - channel_mean
        inv_std = torch.rsqrt(channel_var + epsilon)
        return centered * inv_std * self.g
class Block(nn.Module):
    """3x3 convolution, then GroupNorm, then SiLU, with optional scale/shift.

    Args:
        dim: Number of input channels.
        dim_out: Number of output channels.
        groups: Number of groups used by the GroupNorm layer.
    """

    def __init__(self, dim: int, dim_out: int, groups: int = 8) -> None:
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor, scale_shift: Optional[torch.Tensor] = None):
        """Project and normalize ``x``; modulate it if ``scale_shift`` is given.

        ``scale_shift`` is unpacked as a ``(scale, shift)`` pair and applied
        as ``x * (scale + 1) + shift`` before the activation.
        """
        normed = self.norm(self.proj(x))

        if not exists(scale_shift):
            return self.act(normed)

        scale, shift = scale_shift
        modulated = normed * (scale + 1) + shift
        return self.act(modulated)
class Attention(nn.Module):
    """Multi-head self-attention over the spatial positions of a 2-D feature map.

    Queries and keys are passed through ``l2norm`` and their similarity is
    multiplied by a fixed ``scale`` before the softmax.

    Args:
        dim: Number of input and output channels.
        heads: Number of attention heads.
        dim_head: Channels per head.
        scale: Multiplier applied to the query-key similarity.
    """

    def __init__(
        self, dim: int, heads: int = 4, dim_head: int = 32, scale: int = 10
    ) -> None:
        super().__init__()
        self.scale = scale
        self.heads = heads
        inner_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def forward(self, x: torch.Tensor):
        """Attend over all positions of ``x`` and project back to ``dim`` channels."""
        _, _, height, width = x.shape

        # Split the fused projection into q, k, v and flatten the spatial
        # axes into one position axis per head.
        queries, keys, values = (
            rearrange(part, "b (h c) x y -> b h c (x y)", h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=1)
        )
        queries = l2norm(queries)
        keys = l2norm(keys)

        similarity = einsum("b h d i, b h d j -> b h i j", queries, keys) * self.scale
        weights = similarity.softmax(dim=-1)

        attended = einsum("b h i j, b h d j -> b h i d", weights, values)
        attended = rearrange(attended, "b h (x y) d -> b (h d) x y", x=height, y=width)
        return self.to_out(attended)
class UnetDiffusion(L.LightningModule):
    """Lightning module that trains a U-Net denoiser with a linear beta schedule.

    The constructor precomputes the diffusion constants (betas, alphas and
    derived terms) and keeps a frozen copy of the model that is updated as an
    exponential moving average after every optimizer step.

    Args:
        model: The denoising network.
        lr: Learning rate handed to the optimizer factory.
        timesteps: Number of diffusion steps.
        beta: EMA decay factor.
        optimizer: Optimizer factory, called as
            ``optimizer(parameters, lr=lr)``. It may be an optimizer class or
            a partial; for a partial that already binds ``lr`` the value given
            here takes precedence.
    """

    def __init__(
        self,
        model: Unet,
        lr: float = 1e-3,
        timesteps: int = 50,
        beta: float = 0.995,
        optimizer: torch.optim.Optimizer = Adam,
    ) -> None:
        super().__init__()
        self.model = model
        self.lr = lr
        self.timesteps = timesteps
        # Previously the argument was dropped and Adam was always used.
        self.optimizer_factory = optimizer

        # NOTE(review): these tensors are plain attributes, not registered
        # buffers, so they may not follow the module across devices; confirm
        # the consumers move them explicitly.
        self.betas = linear_beta_schedule(timesteps=self.timesteps, beta_end=0.2)
        # define alphas
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.ema = EMA(beta)
        self.ema_model = copy.deepcopy(self.model).eval().requires_grad_(False)

    def forward(self, x: torch.Tensor, time: torch.Tensor, classes: torch.Tensor):
        """Run the wrapped model on ``x`` for timesteps ``time`` and labels ``classes``."""
        return self.model(x, time, classes)

    def training_step(self, batch, batch_idx: int) -> torch.Tensor:
        """Compute and log the Huber denoising loss for one batch.

        A random timestep is drawn per sample; the loss itself is computed by
        ``p_losses``.
        """
        x, y = batch
        x = x.type(torch.float32)
        y = y.type(torch.long)
        batch_size = x.shape[0]
        t = torch.randint(0, self.timesteps, (batch_size,)).long()
        loss = p_losses(
            self.model,
            x,
            t,
            y,
            loss_type="huber",
            sqrt_alphas_cumprod_in=self.sqrt_alphas_cumprod,
            sqrt_one_minus_alphas_cumprod_in=self.sqrt_one_minus_alphas_cumprod,
            device=self.device,
        )
        self.log("loss", loss, on_step=False, on_epoch=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        """Build the optimizer for the denoiser's parameters from the factory."""
        return self.optimizer_factory(self.model.parameters(), lr=self.lr)

    def optimizer_step(self, *args, **kwargs):
        """Perform the regular optimizer step, then advance the EMA model."""
        super().optimizer_step(*args, **kwargs)

        self.ema.step_ema(self.ema_model, self.model)

    def configure_callbacks(self):
        """Return no module-level callbacks."""
        return []
this + input_channels = channels + self.output_attention = output_attention + + init_dim = default(init_dim, dim) + self.init_conv = nn.Conv2d(input_channels, init_dim, (7, 7), padding=3) + dims = [init_dim, *map(lambda m: dim * m, dim_mults)] + + in_out = list(zip(dims[:-1], dims[1:])) + block_klass = partial(ResnetBlock, groups=resnet_block_groups) + + # time embeddings + time_dim = dim * 4 + + sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim) + fourier_dim = learned_sinusoidal_dim + 1 + + self.time_mlp = nn.Sequential( + sinu_pos_emb, + nn.Linear(fourier_dim, time_dim), + nn.GELU(), + nn.Linear(time_dim, time_dim), + ) + + if num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_dim) + + # layers + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + num_resolutions = len(in_out) + + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (num_resolutions - 1) + + self.downs.append( + nn.ModuleList( + [ + block_klass(dim_in, dim_in, time_emb_dim=time_dim), + block_klass(dim_in, dim_in, time_emb_dim=time_dim), + Residual(PreNorm(dim_in, LinearAttention(dim_in))), + Downsample(dim_in, dim_out) + if not is_last + else nn.Conv2d(dim_in, dim_out, 3, padding=1), + ] + ) + ) + + mid_dim = dims[-1] + self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) + self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) + self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) + + for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): + is_last = ind == (len(in_out) - 1) + self.ups.append( + nn.ModuleList( + [ + block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim), + block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim), + Residual(PreNorm(dim_out, LinearAttention(dim_out))), + Upsample(dim_out, dim_in) + if not is_last + else nn.Conv2d(dim_out, dim_in, 3, padding=1), + ] + ) + ) + + self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim) + 
self.final_conv = nn.Conv2d(dim, 1, 1) + self.cross_attn = EfficientAttention( + dim=200, + dim_head=64, + heads=1, + memory_efficient=True, + q_bucket_size=1024, + k_bucket_size=2048, + ) + self.norm_to_cross = nn.LayerNorm(dim * 4) + + def forward(self, x: torch.Tensor, time: torch.Tensor, classes: torch.Tensor): + x = self.init_conv(x) + r = x.clone() + + t_start = self.time_mlp(time) + t_mid = t_start.clone() + t_end = t_start.clone() + t_cross = t_start.clone() + + if classes is not None: + t_start += self.label_emb(classes) + t_mid += self.label_emb(classes) + t_end += self.label_emb(classes) + t_cross += self.label_emb(classes) + + h = [] + + for block1, block2, attn, downsample in self.downs: + x = block1(x, t_start) + h.append(x) + + x = block2(x, t_start) + x = attn(x) + h.append(x) + + x = downsample(x) + + x = self.mid_block1(x, t_mid) + x = self.mid_attn(x) + x = self.mid_block2(x, t_mid) + + for block1, block2, attn, upsample in self.ups: + x = torch.cat((x, h.pop()), dim=1) + x = block1(x, t_mid) + + x = torch.cat((x, h.pop()), dim=1) + x = block2(x, t_mid) + x = attn(x) + + x = upsample(x) + + x = torch.cat((x, r), dim=1) + + x = self.final_res_block(x, t_end) + x = self.final_conv(x) + x_reshaped = x.reshape(-1, 4, 200) + t_cross_reshaped = t_cross.reshape(-1, 4, 200) + + crossattention_out = self.cross_attn( + self.norm_to_cross(x_reshaped.reshape(-1, 800)).reshape(-1, 4, 200), + context=t_cross_reshaped, + ) # (-1,1, 4, 200) + crossattention_out = x.view(-1, 1, 4, 200) + x = x + crossattention_out + if self.output_attention: + return x, crossattention_out + return x diff --git a/dnadiffusion/sample.py b/dnadiffusion/sample.py index 12094fd4..e1f6ebf7 100644 --- a/dnadiffusion/sample.py +++ b/dnadiffusion/sample.py @@ -6,7 +6,6 @@ import pandas as pd import torch import torch.nn as nn -from accelerate import Accelerator from dnadiffusion.utils.utils import extract @@ -69,12 +68,19 @@ def sampling_to_metric( "gimme scan synthetic_motifs.fasta -p 
JASPAR2020_vertebrates -g hg38 > syn_results_motifs.bed" ) df_results_syn = pd.read_csv( - "new_syn_results_motifs.bed", sep="\t", skiprows=5, header=None + "syn_results_motifs.bed", sep="\t", skiprows=5, header=None ) - df_results_syn["motifs"] = df_results_syn[8].apply( + """df_results_syn["motifs"] = df_results_syn[8].apply( lambda x: x.split('motif_name "')[1].split('"')[0] ) + """ + df_results_syn["motifs"] = ( + df_results_syn[8] + .dropna() + .apply(lambda x: x.split(" ")[1].strip('"')) + .reset_index(drop=True) + ) df_results_syn[0] = df_results_syn[0].apply(lambda x: "_".join(x.split("_")[:-1])) df_motifs_count_syn = ( df_results_syn[[0, "motifs"]].drop_duplicates().groupby("motifs").count() @@ -126,7 +132,6 @@ def p_sample_guided( sqrt_recip_alphas: torch.Tensor, posterior_variance: torch.Tensor, device: str, - accelerator: Accelerator, cond_weight: float = 0.0, ): # adapted from: https://openreview.net/pdf?id=qw8AKxfYbI @@ -147,8 +152,8 @@ def p_sample_guided( # classifier free sampling interpolates between guided and non guided using `cond_weight` classes_masked = classes * context_mask classes_masked = classes_masked.type(torch.long) - if accelerator: - model = accelerator.unwrap_model(model) + # if accelerator: + # model = accelerator.unwrap_model(model) model.output_attention = True preds, cross_map_full = model( x_double, time=t_double, classes=classes_masked @@ -187,7 +192,6 @@ def p_sample_loop( sqrt_one_minus_alphas_cumprod: torch.Tensor, sqrt_recip_alphas: torch.Tensor, posterior_variance: torch.Tensor, - accelerator: Accelerator, get_cross_map: bool = False, ): # to accelerate add timesteps b = shape[0] @@ -214,7 +218,6 @@ def p_sample_loop( sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, sqrt_recip_alphas=sqrt_recip_alphas, posterior_variance=posterior_variance, - accelerator=accelerator, ) # to accelerate betas else: sampling_fn = partial(p_sample) @@ -246,7 +249,6 @@ def sample( sqrt_one_minus_alphas_cumprod: torch.Tensor, 
# Root hydra-zen config: `data` and `model` are filled from the config groups
# named in `hydra_defaults`; `trainer` and `sample` come from
# `dnadiffusion.configs`.
Config = make_config(
    hydra_defaults=[
        "_self_",
        {"data": "LoadingData"},
        {"model": "Unet"},
    ],
    data=MISSING,
    model=MISSING,
    trainer=LightningTrainer,
    sample=sample,
    # Constants
    # NOTE(review): the package directory is `dnadiffusion`; confirm that
    # "dna_diffusion/data" is the intended path.
    data_dir="dna_diffusion/data",
    random_seed=42,
    ckpt_path=None,
)

cs = ConfigStore.instance()
cs.store(name="config", node=Config)


def train(config):
    """Instantiate data, model, trainer and sampling callback, then fit.

    Args:
        config: Config exposing `data`, `sample`, `model` and `trainer`
            nodes. `random_seed` and `ckpt_path` are honoured when present.

    Returns:
        The model instance that was passed to `trainer.fit`.
    """
    # `random_seed` was declared in the config but never applied; seed before
    # anything is instantiated so that initialisation is covered too.
    seed = getattr(config, "random_seed", None)
    if seed is not None:
        L.seed_everything(seed)

    data = instantiate(config.data)
    # Named `sample_callback` so it does not shadow the imported `sample` config.
    sample_callback = instantiate(config.sample, data_module=data)
    model = instantiate(config.model)
    trainer = instantiate(config.trainer)

    # Adding custom callbacks
    trainer.callbacks.append(sample_callback)

    # `ckpt_path` was declared in the config but ignored; `None` (the default)
    # keeps the previous start-from-scratch behaviour.
    trainer.fit(model, data, ckpt_path=getattr(config, "ckpt_path", None))

    return model


@hydra.main(config_path=None, config_name="config", version_base="1.3")
def main(cfg: DictConfig):
    """Hydra entry point: run `train` with the composed config."""
    return train(cfg)


if __name__ == "__main__":
    main()