From 8b149f45d57c6197147d4a4419b3d35dfc7ae28b Mon Sep 17 00:00:00 2001 From: Matei Bejan <24592776+mateibejan1@users.noreply.github.com> Date: Mon, 19 Dec 2022 10:25:32 +0200 Subject: [PATCH 01/16] Add files via upload --- src/models/networks/unet_lucas_cond.py | 344 +++++++++++++++++++++++++ 1 file changed, 344 insertions(+) create mode 100644 src/models/networks/unet_lucas_cond.py diff --git a/src/models/networks/unet_lucas_cond.py b/src/models/networks/unet_lucas_cond.py new file mode 100644 index 00000000..246dddb0 --- /dev/null +++ b/src/models/networks/unet_lucas_cond.py @@ -0,0 +1,344 @@ +import math +from einops import rearrange +from functools import partial +from typing import Optional, List + +import torch +from torch import nn, einsum + +from utils.misc import default, exists +from utils.network import l2norm, Upsample, Downsample + +# Building blocks of UNET + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, *args, **kwargs): + return self.fn(x, *args, **kwargs) + x + + +def Upsample(dim, dim_out = None): + return nn.Sequential( + nn.Upsample(scale_factor = 2, mode = 'nearest'), + nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1) + ) + + +def Downsample(dim, dim_out = None): + return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1) + + +class LayerNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) + + def forward(self, x): + eps = 1e-5 if x.dtype == torch.float32 else 1e-3 + var = torch.var(x, dim = 1, unbiased = False, keepdim = True) + mean = torch.mean(x, dim = 1, keepdim = True) + return (x - mean) * (var + eps).rsqrt() * self.g + + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.fn = fn + self.norm = LayerNorm(dim) + + def forward(self, x): + x = self.norm(x) + return self.fn(x) + + +# Building blocks of UNET, positional embeddings + +class LearnedSinusoidalPosEmb(nn.Module): + """ following @crowsonkb 's lead with learned sinusoidal pos emb """ + """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ + + def __init__(self, dim): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, x): + x = rearrange(x, 'b -> b 1') + freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1) + fouriered = torch.cat((x, fouriered), dim = -1) + return fouriered + + +class EmbedFC(nn.Module): + def __init__(self, input_dim, emb_dim): + super(EmbedFC, self).__init__() + ''' + generic one layer FC NN for embedding things + ''' + self.input_dim = input_dim + layers = [ + nn.Linear(input_dim, emb_dim), + nn.GELU(), + nn.Linear(emb_dim, emb_dim) + ] + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) + + +# Building blocks of UNET, convolution + group norm blocks + +class Block(nn.Module): + def __init__(self, dim, dim_out, groups = 8): + super().__init__() + self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1) + self.norm = nn.GroupNorm(groups, dim_out) + self.act = nn.SiLU() + + def forward(self, x, scale_shift = None): + x = self.proj(x) + x = self.norm(x) + + if exists(scale_shift): + scale, shift = scale_shift + x = x * (scale + 1) + shift + + x = self.act(x) + return x + + +# Building blocks of UNET, residual blocks + +class ResnetBlock(nn.Module): + def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8): + super().__init__() + self.mlp = nn.Sequential( + nn.SiLU(), + nn.Linear(time_emb_dim, dim_out * 2) + ) if exists(time_emb_dim) else None + + self.block1 = Block(dim, dim_out, groups = groups) + self.block2 = Block(dim_out, dim_out, groups = groups) + self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() + + def forward(self, x, time_emb = None): + + scale_shift = None + if exists(self.mlp) and exists(time_emb): + time_emb = self.mlp(time_emb) + time_emb = rearrange(time_emb, 'b c -> b c 1 1') + scale_shift = time_emb.chunk(2, dim = 1) + + h = self.block1(x, scale_shift = scale_shift) + + h = self.block2(h) + + return h + self.res_conv(x) + + +# Additional code to the https://github.com/lucidrains/bit-diffusion/blob/main/bit_diffusion/bit_diffusion.py + +class ResnetBlockClassConditioned(ResnetBlock): + def __init__(self, dim, dim_out, *, num_classes, class_embed_dim, time_emb_dim = None, groups = 8): + super().__init__(dim=dim+class_embed_dim, dim_out=dim_out, time_emb_dim=time_emb_dim, groups=groups) + self.class_mlp = EmbedFC(num_classes, class_embed_dim) + + + def forward(self, x, time_emb=None, c=None): + emb_c = self.class_mlp(c) + emb_c = emb_c.view(*emb_c.shape, 1, 1) + emb_c = emb_c.expand(-1, -1, x.shape[-2], x.shape[-1]) + x = torch.cat([x, emb_c], axis=1) + + return super().forward(x, time_emb) + + +# Building blocks of UNET, attention modules + +class LinearAttention(nn.Module): + def __init__(self, dim, heads = 4, dim_head = 32): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Sequential( + nn.Conv2d(hidden_dim, dim, 1), + LayerNorm(dim) + ) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x).chunk(3, dim = 1) + q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv) + + q = q.softmax(dim = -2) + k = k.softmax(dim = -1) + + q = q * self.scale + v = v / (h * w) + + context = torch.einsum('b h d n, b h e n -> b h d e', k, v) + + out = torch.einsum('b h d e, b h d n -> b h e n', context, q) + out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w) + return self.to_out(out) + + +class Attention(nn.Module): + def __init__(self, dim, heads = 4, dim_head = 32, scale = 10): + super().__init__() + self.scale = scale + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x).chunk(3, dim = 1) + q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv) + + q, k = map(l2norm, (q, k)) + + sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale + attn = sim.softmax(dim = -1) + out = einsum('b h i j, b h d j -> b h i d', attn, v) + out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w) + return self.to_out(out) + + +# Core part of UNET + +class UNet(nn.Module): + """ + Refer to the main paper for the architecture details https://arxiv.org/pdf/2208.04202.pdf + """ + def __init__( + self, + dim: int, + init_dim: int = 200, + dim_mults: Optional[list] = [1, 2, 4], + channels = 1, + resnet_block_groups: int = 8, + learned_sinusoidal_dim: int = 18, + num_classes: int = 10, + class_embed_dim: bool =3, + ): + super().__init__() + + self.channels = channels + # if you want to do self conditioning uncomment this + #input_channels = channels * 2 + input_channels = channels + + + init_dim = default(init_dim, dim) + self.init_conv = nn.Conv2d(input_channels, init_dim, (7,7), padding = 3) + dims = [init_dim, *map(lambda m: dim * m, dim_mults)] + in_out = list(zip(dims[:-1], dims[1:])) + + block_klass = partial(ResnetBlock, groups=resnet_block_groups) + + time_dim = dim * 4 + + sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim) + fourier_dim = learned_sinusoidal_dim + 1 + + self.time_mlp = nn.Sequential( + sinu_pos_emb, + nn.Linear(fourier_dim, time_dim), + nn.GELU(), + nn.Linear(time_dim, time_dim) + ) + + if num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_dim) + + # layers + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + num_resolutions = len(in_out) + + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (num_resolutions - 1) + + self.downs.append(nn.ModuleList([ + block_klass(dim_in, dim_in, time_emb_dim = time_dim), + block_klass(dim_in, dim_in, time_emb_dim = time_dim), + Residual(PreNorm(dim_in, LinearAttention(dim_in))), + Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1) + ])) + + mid_dim = dims[-1] + self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) + self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) + self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) + + for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): + is_last = ind == (len(in_out) - 1) + + self.ups.append(nn.ModuleList([ + block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), + block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), + Residual(PreNorm(dim_out, LinearAttention(dim_out))), + Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1) + ])) + + self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim) + self.final_conv = nn.Conv2d(dim, 1, 1) + print('final',dim, channels, self.final_conv) + +# Additional code to the https://github.com/lucidrains/bit-diffusion/blob/main/bit_diffusion/bit_diffusion.py mostly in forward method. + + def forward(self, x, time, classes, x_self_cond = None): + x = self.init_conv(x) + r = x.clone() + + t_start = self.time_mlp(time) + t_mid = t_start.clone() + t_end = t_start.clone() + + if classes is not None: + t_start += self.label_emb(classes) + t_mid += self.label_emb(classes) + t_end += self.label_emb(classes) + + h = [] + + for block1, block2, attn, downsample in self.downs: + x = block1(x, t_start) + h.append(x) + + x = block2(x, t_start) + x = attn(x) + h.append(x) + + x = downsample(x) + + x = self.mid_block1(x, t_mid) + x = self.mid_attn(x) + x = self.mid_block2(x, t_mid) + + for block1, block2, attn, upsample in self.ups: + x = torch.cat((x, h.pop()), dim = 1) + x = block1(x, t_mid) + + x = torch.cat((x, h.pop()), dim = 1) + x = block2(x, t_mid) + x = attn(x) + + x = upsample(x) + + + x = torch.cat((x, r), dim = 1) + x = self.final_res_block(x, t_end) + + x = self.final_conv(x) + return x \ No newline at end of file From a3a0f11b976a34a53a15b5c6886bd9617547d379 Mon Sep 17 00:00:00 2001 From: Matei Bejan <24592776+mateibejan1@users.noreply.github.com> Date: Mon, 19 Dec 2022 10:26:26 +0200 Subject: [PATCH 02/16] Add files via upload --- src/utils/ema.py | 5 +- src/utils/metrics.py | 155 +++++++++++++++++++++++++++++++++++++++++ src/utils/misc.py | 34 ++++++++- src/utils/network.py | 17 +---- src/utils/schedules.py | 19 +++-- 5 files changed, 199 insertions(+), 31 deletions(-) create mode 100644 src/utils/metrics.py diff --git a/src/utils/ema.py b/src/utils/ema.py index 0855c92c..026c64c2 100644 --- a/src/utils/ema.py +++ b/src/utils/ema.py @@ -1,5 +1,4 @@ -#https://github.com/dome272/Diffusion-Models-pytorch/blob/main/modules.py -class EMA: +class EMA: def __init__(self, beta): super().__init__() self.beta = beta @@ -24,4 +23,4 @@ def step_ema(self, ema_model, model, step_start_ema=2000): self.step += 1 def reset_parameters(self, ema_model, model): - ema_model.load_state_dict(model.state_dict()) + ema_model.load_state_dict(model.state_dict()) \ No newline at end of file diff --git a/src/utils/metrics.py b/src/utils/metrics.py new file mode 100644 index 00000000..f8307082 --- /dev/null +++ b/src/utils/metrics.py @@ -0,0 +1,155 @@ +import torch +import pandas as pd +import numpy as np +from scipy.special import rel_entr +import seaborn as sns +import matplotlib.pyplot as plt +from tqdm.auto import tqdm +import os + + +def motif_scoring_KL_divergence(original: pd.Series, + generated: pd.Series) -> torch.Tensor: + + """ + This function encapsulates the logic of evaluating the KL divergence metric + between two sequences. + Returns + ------- + kl_divergence: Float + The KL divergence between the input and output (generated) + sequences' distribution + """ + + kl_pq = rel_entr(original, generated) + return np.sum(kl_pq) + + +def compare_motif_list(df_motifs_a, df_motifs_b, motif_scoring_metric=motif_scoring_KL_divergence, plot_motif_probs=False): + """ + This function encapsulates the logic of evaluating the difference between the distribution + of frequencies between generated (diffusion/df_motifs_a) and the input (training/df_motifs_b) for an arbitrary metric ("motif_scoring_metric") + + Please note that some metrics, like KL_divergence, are not metrics in official sense. Reason + for that is that they dont satisfy certain properties, such as in KL case, the simmetry property. + Hence it makes a big difference what are the positions of input. + """ + set_all_mot = set(df_motifs_a.index.values.tolist() + df_motifs_b.index.values.tolist()) + create_new_matrix = [] + for x in set_all_mot: + list_in = [] + list_in.append(x) # adding the name + if x in df_motifs_a.index: + list_in.append(df_motifs_a.loc[x][0]) + else: + list_in.append(1) + + if x in df_motifs_b.index: + list_in.append(df_motifs_b.loc[x][0]) + else: + list_in.append(1) + + create_new_matrix.append(list_in) + + + df_motifs = pd.DataFrame(create_new_matrix, columns=['motif', 'motif_a', 'motif_b']) + + df_motifs['Diffusion_seqs'] = df_motifs['motif_a'] / df_motifs['motif_a'].sum() + df_motifs['Training_seqs'] = df_motifs['motif_b'] / df_motifs['motif_b'].sum() + if plot_motif_probs: + plt.rcParams["figure.figsize"] = (3,3) + sns.regplot(x='Diffusion_seqs', y='Training_seqs',data=df_motifs) + plt.xlabel('Diffusion Seqs') + plt.ylabel('Training Seqs') + plt.title('Motifs Probs') + plt.show() + + return motif_scoring_metric(df_motifs['Diffusion_seqs'].values, df_motifs['Training_seqs'].values) + + +def sampling_to_metric(model, cell_types, image_size, nucleotides, number_of_samples=20, specific_group=False, + group_number=None, cond_weight_to_metric=0): + """ + Might need to add to the DDPM class since if we can't call the sample() method outside PyTorch Lightning. + + This function encapsulates the logic of sampling from the trained model in order to generate counts of the motifs. + The reasoning is that we are interested only in calculating the evaluation metric + for the count of occurances and not the nucleic acids themselves. + """ + final_sequences=[] + for n_a in tqdm(range(number_of_samples)): + sample_bs = 10 + if specific_group: + sampled = torch.from_numpy(np.array([group_number] * sample_bs) ) + print ('specific') + else: + sampled = torch.from_numpy(np.random.choice(cell_types, sample_bs)) + + random_classes = sampled.float().cuda() + sampled_images = model.sample(classes=random_classes, image_size=image_size, batch_size=sample_bs, channels=1, cond_weight=cond_weight_to_metric) + for n_b, x in enumerate(sampled_images[-1]): + seq_final = f'>seq_test_{n_a}_{n_b}\n' +''.join([nucleotides[s] for s in np.argmax(x.reshape(4,200), axis=0)]) + final_sequences.append(seq_final) + + save_motifs_syn = open('synthetic_motifs.fasta', 'w') + + save_motifs_syn.write('\n'.join(final_sequences)) + save_motifs_syn.close() + + #Scan for motifs + os.system('gimme scan synthetic_motifs.fasta -p JASPAR2020_vertebrates -g hg38 > syn_results_motifs.bed') + df_results_syn = pd.read_csv('syn_results_motifs.bed', sep='\t', skiprows=5, header=None) + df_results_syn['motifs'] = df_results_syn[8].apply(lambda x: x.split( 'motif_name "')[1].split('"')[0] ) + df_results_syn[0] = df_results_syn[0].apply(lambda x : '_'.join( x.split('_')[:-1]) ) + df_motifs_count_syn = df_results_syn[[0,'motifs']].drop_duplicates().groupby('motifs').count() + plt.rcParams["figure.figsize"] = (30,2) + df_motifs_count_syn.sort_values(0, ascending=False).head(50)[0].plot.bar() + plt.show() + + return df_motifs_count_syn + + +def metric_comparison_between_components(original_data, generated_data, x_label_plot, y_label_plot, cell_components): + """ + This functions takes as inputs dictionaries, which contain as keys different components (cell types) + and as values the distribution of occurances of different motifs. These two dictionaries represent two different datasets, i.e. + generated dataset and the input (train) dataset. + + The goal is to then plot a the main evaluation metric (KL or otherwise) across all different types of cell types + in a heatmap fashion. + """ + ENUMARATED_CELL_NAME = '''7 Trophoblasts + 5 CD8_cells + 15 CD34_cells + 9 Fetal_heart + 12 Fetal_muscle + 14 HMVEC(vascular) + 3 hESC(Embryionic) + 8 Fetal(Neural) + 13 Intestine + 2 Skin(stromalA) + 4 Fibroblast(stromalB) + 6 Renal(Cancer) + 16 Esophageal(Cancer) + 11 Fetal_Lung + 10 Fetal_kidney + 1 Tissue_Invariant'''.split('\n') + CELL_NAMES = {int(x.split(' ')[0]): x.split(' ')[1] for x in ENUMARATED_CELL_NAME} + + final_comparison_all_components = [] + for components_1, motif_occurance_frequency in original_data.items(): + comparisons_single_component = [] + for components_2 in generated_data.keys(): + compared_motifs_occurances = compare_motif_list(motif_occurance_frequency, generated_data[components_2]) + comparisons_single_component.append(compared_motifs_occurances) + + final_comparison_all_components.append(comparisons_single_component) + + plt.rcParams["figure.figsize"] = (10,10) + df_plot = pd.DataFrame(final_comparison_all_components) + df_plot.columns = [CELL_NAMES[x] for x in cell_components] + df_plot.index = df_plot.columns + sns.heatmap(df_plot, cmap='Blues_r', annot=True, lw=0.1, vmax=1, vmin=0 ) + plt.title(f'Kl divergence \n {x_label_plot} sequences x {y_label_plot} sequences \n MOTIFS probabilities') + plt.xlabel(f'{x_label_plot} Sequences \n(motifs dist)') + plt.ylabel(f'{y_label_plot} \n (motifs dist)') \ No newline at end of file diff --git a/src/utils/misc.py b/src/utils/misc.py index 4ba9eda1..d8eaa3c0 100644 --- a/src/utils/misc.py +++ b/src/utils/misc.py @@ -1,7 +1,23 @@ import math import importlib - import torch +import random +import os +import numpy as np + + +def seed_everything(seed): + """" + Seed everything. + """ + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + def exists(x): return x is not None @@ -44,7 +60,21 @@ def convert_image_to(img_type, image): return image +def one_hot_encode(seq, nucleotides, max_seq_len): + """ + One-hot encode a sequence of nucleotides. + """ + seq_len = len(seq) + seq_array = np.zeros((max_seq_len, len(nucleotides))) + for i in range(seq_len): + seq_array[i, nucleotides.index(seq[i])] = 1 + return seq_array + + def log(t, eps = 1e-20): + """ + Toch log for the purporses of diffusion time steps t. + """ return torch.log(t.clamp(min = eps)) @@ -73,4 +103,4 @@ def mean_flat(tensor): Take the mean over all non-batch dimensions. From Perception Prioritized Training of Diffusion Models: https://arxiv.org/abs/2204.00227. """ - return tensor.mean(dim=list(range(1, len(tensor.shape)))) + return tensor.mean(dim=list(range(1, len(tensor.shape)))) \ No newline at end of file diff --git a/src/utils/network.py b/src/utils/network.py index e0bf3499..b9d91d65 100644 --- a/src/utils/network.py +++ b/src/utils/network.py @@ -1,20 +1,5 @@ -from torch import nn import torch.nn.functional as F -from utils.misc import default - def l2norm(t): - return F.normalize(t, dim = -1) - - -def Upsample(dim, dim_out = None): - return nn.Sequential( - nn.Upsample(scale_factor = 2, mode = 'nearest'), - nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1) - ) - - -def Downsample(dim, dim_out = None): - return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1) - + return F.normalize(t, dim = -1) \ No newline at end of file diff --git a/src/utils/schedules.py b/src/utils/schedules.py index fec3238b..82f1312b 100644 --- a/src/utils/schedules.py +++ b/src/utils/schedules.py @@ -1,23 +1,22 @@ import math -from math import exp +from math import log, exp import torch def beta_linear_log_snr(t): - return -torch.log(exp(1e-4 + 10 * (t**2))) + return -torch.log(exp(1e-4 + 10 * (t ** 2))) def alpha_cosine_log_snr(t, s: float = 0.008): - return -torch.log( - (torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps=1e-5 - ) # not sure if this accounts for beta being clipped to 0.999 in discrete version + # not sure if this accounts for beta being clipped to 0.999 in discrete version + return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5) def log_snr_to_alpha_sigma(log_snr): return torch.sqrt(torch.sigmoid(log_snr)), torch.sqrt(torch.sigmoid(-log_snr)) -def cosine_beta_schedule(timesteps, s=0.008) -> torch.Tensor: +def cosine_beta_schedule(timesteps, s=0.008): """ cosine schedule as proposed in https://arxiv.org/abs/2102.09672 """ @@ -29,19 +28,19 @@ def cosine_beta_schedule(timesteps, s=0.008) -> torch.Tensor: return torch.clip(betas, 0.0001, 0.9999) -def linear_beta_schedule(timesteps, beta_end=0.005) -> torch.Tensor: +def linear_beta_schedule(timesteps, beta_end=0.005): beta_start = 0.0001 return torch.linspace(beta_start, beta_end, timesteps) -def quadratic_beta_schedule(timesteps) -> torch.Tensor: +def quadratic_beta_schedule(timesteps): beta_start = 0.0001 beta_end = 0.02 return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2 -def sigmoid_beta_schedule(timesteps) -> torch.Tensor: +def sigmoid_beta_schedule(timesteps): beta_start = 0.001 beta_end = 0.02 betas = torch.linspace(-6, 6, timesteps) - return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start + return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start \ No newline at end of file From 6e8273ee56ca7c2a1fd23cc422c3f18c9cd3ac8d Mon Sep 17 00:00:00 2001 From: Matei Bejan <24592776+mateibejan1@users.noreply.github.com> Date: Mon, 19 Dec 2022 10:27:36 +0200 Subject: [PATCH 03/16] Add files via upload Updated DDPM with the Noah's refactored notebook version. Preemptively added p2_weighting, need to figure out if/how it works on bit sequences. --- src/models/diffusion/ddpm.py | 291 +++++++++++++++++++++-------------- 1 file changed, 179 insertions(+), 112 deletions(-) diff --git a/src/models/diffusion/ddpm.py b/src/models/diffusion/ddpm.py index a4d2eaa4..d304c561 100644 --- a/src/models/diffusion/ddpm.py +++ b/src/models/diffusion/ddpm.py @@ -2,205 +2,272 @@ import torch from torch import nn from torch.nn.functional import F - -from typing import Optional, List +from functools import partial from models.diffusion.diffusion import DiffusionModel -from utils.schedules import ( - beta_linear_log_snr, - alpha_cosine_log_snr, - linear_beta_schedule, -) -from utils.misc import extract, mean_flat - +from utils.schedules import beta_linear_log_snr, alpha_cosine_log_snr, linear_beta_schedule +from utils.misc import extract, extract_data_from_batch, mean_flat class DDPM(DiffusionModel): def __init__( self, - model, *, image_size, - timesteps=1000, - noise_schedule="cosine", + timesteps = 50, + noise_schedule = 'cosine', + time_difference = 0., unet_config: dict, is_conditional: bool, + p_uncond: float = 0.1, + use_fp16: bool, logdir: str, - img_size: int, optimizer_config: dict, lr_scheduler_config: dict = None, criterion: nn.Module, use_ema: bool = True, ema_decay: float = 0.9999, lr_warmup=0, - use_p2_weigthing=False, + use_p2_weigthing: bool = False, p2_gamma: float = 0.5, - p2_k: float = 1, + p2_k: float = 1 ): super().__init__( unet_config, is_conditional, + use_fp16, logdir, - img_size, optimizer_config, lr_scheduler_config, criterion, use_ema, ema_decay, - lr_warmup, + lr_warmup ) self.image_size = image_size - self.timesteps = timesteps if noise_schedule == "linear": self.log_snr = beta_linear_log_snr - self.betas = linear_beta_schedule(timesteps=timesteps, beta_end=0.05) elif noise_schedule == "cosine": - # Refer to Section 3.2 of https://arxiv.org/abs/2102.09672 for details self.log_snr = alpha_cosine_log_snr - self.betas = cosine_beta_schedule(timesteps=timesteps, s=0.0001) else: - raise ValueError(f"invalid noise schedule {noise_schedule}") + raise ValueError(f'invalid noise schedule {noise_schedule}') + + self.timesteps = timesteps + self.p_uncond = p_uncond + + #self.betas = cosine_beta_schedule(timesteps=timesteps, s=0.0001) + self.set_noise_schedule(self.betas, self.timesteps) + + # proposed in the paper, summed to time_next + # as a way to fix a deficiency in self-conditioning and lower FID when the number of sampling timesteps is < 400 + + self.time_difference = time_difference - # Define Beta Schedule - self.set_noise_schedule(self.betas) - - if self.use_p2_weighting: - self.p2_gamma = p2_gamma - self.p2_k = p2_k - self.snr = 1.0 / (1 - self.alphas_cumprod) - 1 - - def set_noise_schedule(self, betas: torch.Tensor) -> None: - # define alphas - alphas = 1.0 - betas - self.alphas_cumprod = torch.cumprod(alphas, axis=0) - alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) + def set_noise_schedule(self, betas, timesteps): + # define beta schedule + self.betas = linear_beta_schedule(timesteps=timesteps, beta_end=0.05) + + # define alphas + alphas = 1. - self.betas + alphas_cumprod = torch.cumprod(alphas, axis=0) + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) + self.sqrt_recip_alphas = torch.sqrt(1.0 / alphas) # calculations for diffusion q(x_t | x_{t-1}) and others self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) - self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) + + #sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod) # calculations for posterior q(x_{t-1} | x_t, x_0) - self.posterior_variance = ( - betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) - ) + self.posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) - def q_sample( - self, - x_start: torch.Tensor, - t: torch.Tensor, - noise: Optional[None, torch.Tensor] = None, - ) -> torch.Tensor: + + def q_sample(self, x_start, t, noise=None): + """ + Forward pass with noise. + """ if noise is None: noise = torch.randn_like(x_start) sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape) sqrt_one_minus_alphas_cumprod_t = extract( self.sqrt_one_minus_alphas_cumprod, t, x_start.shape - ) + ) return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise + @torch.no_grad() - def p_sample(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + def p_sample(self, x, t, t_index): betas_t = extract(self.betas, t, x.shape) sqrt_one_minus_alphas_cumprod_t = extract( self.sqrt_one_minus_alphas_cumprod, t, x.shape ) - + #print (x.shape, 'x_shape') sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape) - - # Eqn 11 from https://arxiv.org/abs/2006.11239v2 + + # Equation 11 in the paper + # Use our model (noise predictor) to predict the mean model_mean = sqrt_recip_alphas_t * ( - x - betas_t * self.model(x, t) / sqrt_one_minus_alphas_cumprod_t + x - betas_t * self.model(x, time=t) / sqrt_one_minus_alphas_cumprod_t + ) + + if t_index == 0: + return model_mean + else: + posterior_variance_t = extract(self.posterior_variance, t, x.shape) + noise = torch.randn_like(x) + # Algorithm 2 line 4: + return model_mean + torch.sqrt(posterior_variance_t) * noise + + + @torch.no_grad() + def p_sample_guided(self, x, classes, t, t_index, context_mask, cond_weight=0.0): + # adapted from: https://openreview.net/pdf?id=qw8AKxfYbI + #print (classes[0]) + batch_size = x.shape[0] + # double to do guidance with + t_double = t.repeat(2) + x_double = x.repeat(2, 1, 1, 1) + betas_t = extract(self.betas, t_double, x_double.shape) + sqrt_one_minus_alphas_cumprod_t = extract( + self.sqrt_one_minus_alphas_cumprod, t_double, x_double.shape + ) + sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t_double, x_double.shape) + + # classifier free sampling interpolates between guided and non guided using `cond_weight` + classes_masked = classes * context_mask + classes_masked = classes_masked.type(torch.long) + #print ('class masked', classes_masked) + preds = self.model(x_double, time=t_double, classes=classes_masked) + eps1 = (1 + cond_weight) * preds[:batch_size] + eps2 = cond_weight * preds[batch_size:] + x_t = eps1 - eps2 + + # Equation 11 in the paper + # Use our model (noise predictor) to predict the mean + model_mean = sqrt_recip_alphas_t[:batch_size] * ( + x - betas_t[:batch_size] * x_t / sqrt_one_minus_alphas_cumprod_t[:batch_size] ) - posterior_variance_t = extract(self.posterior_variance, t, x.shape) - noise = torch.randn_like(x) + if t_index == 0: + return model_mean + else: + posterior_variance_t = extract(self.posterior_variance, t, x.shape) + noise = torch.randn_like(x) + # Algorithm 2 line 4: + return model_mean + torch.sqrt(posterior_variance_t) * noise - return model_mean + torch.sqrt(posterior_variance_t) * noise + # Algorithm 2 but save all images: @torch.no_grad() - def p_sample_loop(self, shape: torch.Size) -> List: + def p_sample_loop(self, classes, shape, cond_weight): device = next(self.model.parameters()).device b = shape[0] # start from pure noise (for each example in the batch) - img = torch.randn(shape, device=device) - imgs = [] - - for i in tqdm( - reversed(range(0, self.timesteps)), - desc="sampling loop time step", - total=self.timesteps, - ): - img = self.p_sample( - x=img, t=torch.full((b,), i, device=device, dtype=torch.long) - ) - imgs.append(img.cpu().numpy()) - return imgs + image = torch.randn(shape, device=device) + images = [] + + if classes is not None: + n_sample = classes.shape[0] + context_mask = torch.ones_like(classes).to(device) + # make 0 index unconditional + # double the batch + classes = classes.repeat(2) + context_mask = context_mask.repeat(2) + context_mask[n_sample:] = 0. # makes second half of batch context free + sampling_fn = partial(self.p_sample_guided, classes=classes, cond_weight=cond_weight, context_mask=context_mask) + else: + sampling_fn = partial(self.p_sample) + + for i in tqdm(reversed(range(0, self.timesteps)), desc='sampling loop time step', total=self.timesteps): + image = sampling_fn(self.model, x=image, t=torch.full((b,), i, device=device, dtype=torch.long), t_index=i) + images.append(image.cpu().numpy()) + return images + @torch.no_grad() - def sample( - self, batch: torch.Tensor, channels: int = 3, nucleotides: int = 4 - ) -> List: - return self.p_sample_loop( - shape=(batch.shape[0], channels, nucleotides, self.image_size) - ) - - def p2_weighting(self, x_t, ts, target, prediction): - """ - From Perception Prioritized Training of Diffusion Models: https://arxiv.org/abs/2204.00227. - """ - weight = ( - 1 / (self.p2_k + self.snr) ** self.p2_gamma, ts, x_t.shape - ) - loss_batch = mean_flat(weight * (target - prediction) ** 2) - loss = torch.mean(loss_batch) - return loss + def sample(self, image_size, classes=None, batch_size=16, channels=3, cond_weight=0): + return self.p_sample_loop(self.model, classes=classes, shape=(batch_size, channels, 4, image_size), cond_weight=cond_weight) - def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor: - if self.training_custom_noise is None: - self.training_custom_noise = torch.randn_like(batch) - x_noisy = self.q_sample( - x_start=batch, t=self.timesteps, noise=self.training_custom_noise - ) - # calculating generic loss function, we'll add it to the class constructor once we have the code - # we should log more metrics at train and validation e.g. l1, l2 and other suggestions - predicted_noise = self.model(x_noisy, self.timesteps) + def training_step(self, batch: torch.Tensor, batch_idx: int): + x_start, condition = extract_data_from_batch(batch) + + if noise is None: + noise = torch.randn_like(x_start) + x_noisy =self.q_sample(x_start=x_start, t=self.timesteps, noise=noise) - if self.use_p2_weighting: - loss = 0 # Need to figure out how to use p2 weighting in DNA context + #calculating generic loss function, we'll add it to the class constructor once we have the code + #we should log more metrics at train and validation e.g. l1, l2 and other suggestions + if self.use_fp16: + with torch.cuda.amp.autocast(): + if self.is_conditional: + predicted_noise = self.model(x_noisy, self.timesteps, condition) + else: + predicted_noise = self.model(x_noisy, self.timesteps) else: - loss = self.criterion(predicted_noise, self.training_custom_noise) - - self.log("train", loss, batch_size=batch.shape[0]) + if self.is_conditional: + predicted_noise = self.model(x_noisy, self.timesteps, condition) + else: + predicted_noise = self.model(x_noisy, self.timesteps) + + loss = self.criterion(predicted_noise, noise) + self.log('train', loss, batch_size=batch.shape[0]) return loss + def validation_step(self, batch: torch.Tensor, batch_idx: int): - if self.is_conditional: - return None - else: - return self.inference_step(batch, batch_idx, "validation") + return self.inference_step(batch, batch_idx, 'validation') + def test_step(self, batch: torch.Tensor, batch_idx: int): - if self.is_conditional: - return None - else: - return self.inference_step(batch, batch_idx, "test") + return self.inference_step(batch, batch_idx, 'test') + - def inference_step(self, batch: torch.Tensor, batch_idx: int, phase="validation"): - predictions = self.sample(batch) + def inference_step(self, batch: torch.Tensor, batch_idx: int, phase='validation', noise=None): + x_start, condition = extract_data_from_batch(batch) + device = x_start.device + batch_size = batch.shape[0] + + t = torch.randint(0, self.timesteps, (batch_size,), device=device).long() # sampling a t to generate t and t+1 + + if noise is None: + noise = torch.randn_like(x_start) # gauss noise + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) #this is the auto generated noise given t and Noise + + context_mask = torch.bernoulli(torch.zeros(classes.shape[0]) + (1 - self.p_uncond)).to(device) + + # mask for unconditinal guidance + classes = classes * context_mask + classes = classes.type(torch.long) + + predictions = self.model(x_noisy, t, condition) loss = self.criterion(predictions, batch) - self.log("val_loss", loss) if phase == "validation" else self.log( - "test_loss", loss - ) + self.log('validation_loss', loss) if phase == 'validation' else self.log('test_loss', loss) + + ''' + Log multiple losses at validation/test time according to internal discussions. + ''' return predictions + + def p2_weighting(self, x_t, ts, target, prediction): + """ + From Perception Prioritized Training of Diffusion Models: https://arxiv.org/abs/2204.00227. + """ + weight = ( + 1 / (self.p2_k + self.snr) ** self.p2_gamma, ts, x_t.shape + ) + loss_batch = mean_flat(weight * (target - prediction) ** 2) + loss = torch.mean(loss_batch) + return loss From 642d67c7f4664b83f0ad0b6c1d438483258d7244 Mon Sep 17 00:00:00 2001 From: Matei Bejan <24592776+mateibejan1@users.noreply.github.com> Date: Mon, 19 Dec 2022 10:27:59 +0200 Subject: [PATCH 04/16] Add files via upload --- src/models/diffusion/diffusion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/models/diffusion/diffusion.py b/src/models/diffusion/diffusion.py index ba46f1c0..9770da7d 100644 --- a/src/models/diffusion/diffusion.py +++ b/src/models/diffusion/diffusion.py @@ -3,7 +3,6 @@ import pytorch_lightning as pl from utils.ema import EMA -from utils.misc import instantiate_from_config class DiffusionModel(pl.LightningModule): def __init__( @@ -11,6 +10,7 @@ def __init__( unet_config: dict, timesteps: int, is_conditional: bool, + use_fp16: bool, logdir: str, image_size: int, optimizer_config: dict, @@ -32,6 +32,7 @@ def __init__( if self.use_ema: self.eps_model_ema = EMA(self.model, decay=ema_decay) self.is_conditional = is_conditional + self.use_fp16 = use_fp16 self.image_size = image_size self.lr_scheduler_config = lr_scheduler_config self.optimizer_config = optimizer_config From fdf3512e5b20f066c1353bf99f71d77f65618f6b Mon Sep 17 00:00:00 2001 From: Matei Bejan <24592776+mateibejan1@users.noreply.github.com> Date: Mon, 19 Dec 2022 10:29:02 +0200 Subject: [PATCH 05/16] Add files via upload --- src/train.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/train.py b/src/train.py index 1da73cbf..0ffe9bc3 100644 --- a/src/train.py +++ b/src/train.py @@ -5,15 +5,31 @@ from omegaconf import DictConfig, OmegaConf from hydra.utils import instantiate +import argparse logger = logging.getLogger() +def get_parser(**parser_kwargs): + parser = argparse.ArgumentParser(**parser_kwargs) + parser.add_argument( + "--logdir", type=str, default="logs", help="where to save logs and ckpts" + ) + parser.add_argument("--name", type=str, default="dummy", help="postfix for logdir") + parser.add_argument( + "--resume", + type=str, + default="", + help="resume training from given folder or checkpoint", + ) + return parser @hydra.main(config_path="configs", config_name="train") def train(cfg: DictConfig): + parser = get_parser() + # Keeping track of current config settings in logger logger.info(f"Training with config:\n{OmegaConf.to_yaml(cfg)}") - run = wandb.init(project=cfg.logger.wandb.project, config=cfg) + run = wandb.init(name=parser.logdir, save_dir=parser.logdir, project=cfg.logger.wandb.project, config=cfg) # Placeholder for what loss or metric values we plan to track with wandb wandb.log({"loss": loss}) @@ -37,6 +53,5 @@ def train(cfg: DictConfig): trainer.fit(model, train_dl, val_dl) - if __name__ == "__main__": - train() + train() \ No newline at end of file From d616041d058e9a282116b227ddd3199bad3545a9 Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Tue, 20 Dec 2022 14:21:36 +0000 Subject: [PATCH 06/16] style: run black --- src/data/sequence_dataloader.py | 75 +++++--- src/models/diffusion/ddim.py | 70 ++++---- src/models/diffusion/ddpm.py | 153 +++++++++------- src/models/diffusion/diffusion.py | 1 + src/models/networks/unet_bitdiffusion.py | 102 ++++++----- src/models/networks/unet_lucas.py | 170 +++++++++--------- src/models/networks/unet_lucas_cond.py | 211 +++++++++++++---------- src/train.py | 16 +- src/utils/ema.py | 6 +- src/utils/metrics.py | 153 ++++++++++------ src/utils/misc.py | 13 +- src/utils/network.py | 2 +- src/utils/schedules.py | 6 +- tests/data/test_sequence_dataloader.py | 194 +++++++++++++++------ 14 files changed, 707 insertions(+), 465 deletions(-) diff --git a/src/data/sequence_dataloader.py b/src/data/sequence_dataloader.py index 438c31d9..cab51bb2 100644 --- a/src/data/sequence_dataloader.py +++ b/src/data/sequence_dataloader.py @@ -4,11 +4,19 @@ import torch import torchvision.transforms as T import torch.nn.functional as F -import pytorch_lightning as pl -from torch.utils.data import Dataset, DataLoader +import pytorch_lightning as pl +from torch.utils.data import Dataset, DataLoader + class SequenceDatasetBase(Dataset): - def __init__(self, data_path, sequence_length=200, sequence_encoding="polar", sequence_transform=None, cell_type_transform=None): + def __init__( + self, + data_path, + sequence_length=200, + sequence_encoding="polar", + sequence_transform=None, + cell_type_transform=None, + ): super().__init__() self.data = pd.read_csv(data_path, sep="\t") self.sequence_length = sequence_length @@ -20,16 +28,16 @@ def __init__(self, data_path, sequence_length=200, sequence_encoding="polar", se def __len__(self): return len(self.data) - + def __getitem__(self, index): # Iterating through DNA sequences from dataset and one-hot encoding all nucleotides current_seq = self.data["raw_sequence"][index] - if 'N' not in current_seq: + if "N" not in current_seq: X_seq = self.encode_sequence(current_seq, encoding=self.sequence_encoding) - + # Reading cell component at current index X_cell_type = self.data["component"][index] - + if self.sequence_transform is not None: X_seq = self.sequence_transform(X_seq) if self.cell_type_transform is not None: @@ -79,10 +87,12 @@ class SequenceDatasetTrain(SequenceDatasetBase): def __init__(self, data_path="", **kwargs): super().__init__(data_path=data_path, **kwargs) + class SequenceDatasetValidation(SequenceDatasetBase): def __init__(self, data_path="", **kwargs): super().__init__(data_path=data_path, **kwargs) + class SequenceDatasetTest(SequenceDatasetBase): def __init__(self, data_path="", **kwargs): super().__init__(data_path=data_path, **kwargs) @@ -99,11 +109,15 @@ def __init__( sequence_transform=None, cell_type_transform=None, batch_size=None, - num_workers=1 + num_workers=1, ): super().__init__() self.datasets = dict() - self.train_dataloader, self.val_dataloader, self.test_dataloader = None, None, None + self.train_dataloader, self.val_dataloader, self.test_dataloader = ( + None, + None, + None, + ) if train_path: self.datasets["train"] = train_path @@ -131,7 +145,7 @@ def setup(self): sequence_length=self.sequence_length, sequence_encoding=self.sequence_encoding, sequence_transform=self.sequence_transform, - cell_type_transform=self.cell_type_transform + cell_type_transform=self.cell_type_transform, ) if "validation" in self.datasets: self.val_data = SequenceDatasetValidation( @@ -139,7 +153,7 @@ def setup(self): sequence_length=self.sequence_length, sequence_encoding=self.sequence_encoding, sequence_transform=self.sequence_transform, - cell_type_transform=self.cell_type_transform + cell_type_transform=self.cell_type_transform, ) if "test" in self.datasets: self.test_data = SequenceDatasetTest( @@ -147,27 +161,32 @@ def setup(self): sequence_length=self.sequence_length, sequence_encoding=self.sequence_encoding, sequence_transform=self.sequence_transform, - cell_type_transform=self.cell_type_transform + cell_type_transform=self.cell_type_transform, ) def _train_dataloader(self): - return DataLoader(self.train_data, - self.batch_size, - shuffle=True, - num_workers=self.num_workers, - pin_memory=True) + return DataLoader( + self.train_data, + self.batch_size, + shuffle=True, + num_workers=self.num_workers, + pin_memory=True, + ) def _val_dataloader(self): - return DataLoader(self.val_data, - self.batch_size, - shuffle=True, - num_workers=self.num_workers, - pin_memory=True) + return DataLoader( + self.val_data, + self.batch_size, + shuffle=True, + num_workers=self.num_workers, + pin_memory=True, + ) def _test_dataloader(self): - return DataLoader(self.test_data, - self.batch_size, - shuffle=True, - num_workers=self.num_workers, - pin_memory=True) - + return DataLoader( + self.test_data, + self.batch_size, + shuffle=True, + num_workers=self.num_workers, + pin_memory=True, + ) diff --git a/src/models/diffusion/ddim.py b/src/models/diffusion/ddim.py index 046514ad..4dc76522 100644 --- a/src/models/diffusion/ddim.py +++ b/src/models/diffusion/ddim.py @@ -5,17 +5,18 @@ from models.diffusion.diffusion import DiffusionModel + class DDIM(DiffusionModel): def __init__( self, model, *, image_size, - timesteps = 1000, - use_ddim = False, - noise_schedule = 'cosine', - time_difference = 0., - bit_scale = 1. + timesteps=1000, + use_ddim=False, + noise_schedule="cosine", + time_difference=0.0, + bit_scale=1.0, ): super().__init__() self.model = model @@ -28,7 +29,7 @@ def __init__( elif noise_schedule == "cosine": self.log_snr = alpha_cosine_log_snr else: - raise ValueError(f'invalid noise schedule {noise_schedule}') + raise ValueError(f"invalid noise schedule {noise_schedule}") self.bit_scale = bit_scale @@ -45,38 +46,40 @@ def device(self): return next(self.model.parameters()).device def get_sampling_timesteps(self, batch, *, device): - times = torch.linspace(1., 0., self.timesteps + 1, device = device) - times = repeat(times, 't -> b t', b = batch) - times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0) - times = times.unbind(dim = -1) + times = torch.linspace(1.0, 0.0, self.timesteps + 1, device=device) + times = repeat(times, "t -> b t", b=batch) + times = torch.stack((times[:, :-1], times[:, 1:]), dim=0) + times = times.unbind(dim=-1) return times @torch.no_grad() - def ddim_sample(self, shape, classes, time_difference = None): + def ddim_sample(self, shape, classes, time_difference=None): batch, device = shape[0], self.device time_difference = default(time_difference, self.time_difference) - time_pairs = self.get_sampling_timesteps(batch, device = device) - img = torch.randn(shape, device = device) + time_pairs = self.get_sampling_timesteps(batch, device=device) + img = torch.randn(shape, device=device) x_start = None - for times, times_next in tqdm(time_pairs, desc = 'sampling loop time step'): + for times, times_next in tqdm(time_pairs, desc="sampling loop time step"): # get times and noise levels log_snr = self.log_snr(times) log_snr_next = self.log_snr(times_next) - padded_log_snr, padded_log_snr_next = map(partial(right_pad_dims_to, img), (log_snr, log_snr_next)) + padded_log_snr, padded_log_snr_next = map( + partial(right_pad_dims_to, img), (log_snr, log_snr_next) + ) alpha, sigma = log_snr_to_alpha_sigma(padded_log_snr) alpha_next, sigma_next = log_snr_to_alpha_sigma(padded_log_snr_next) # add the time delay - times_next = (times_next - time_difference).clamp(min = 0.) + times_next = (times_next - time_difference).clamp(min=0.0) # predict x0 @@ -88,38 +91,42 @@ def ddim_sample(self, shape, classes, time_difference = None): # get predicted noise - pred_noise = (img - alpha * x_start) / sigma.clamp(min = 1e-8) + pred_noise = (img - alpha * x_start) / sigma.clamp(min=1e-8) # calculate x next img = x_start * alpha_next + pred_noise * sigma_next return bits_to_decimal(img) - + # TODO: Need to add class conditioned weight @torch.no_grad() def sample(self, batch_size=16, classes=None): image_size, channels = self.image_size, self.channels sample_fn = self.ddpm_sample if not self.use_ddim else self.ddim_sample - return sample_fn((batch_size, 8, 4, image_size), classes=classes) # Lucas - + return sample_fn((batch_size, 8, 4, image_size), classes=classes) # Lucas + def forward(self, img, class_enc, *args, **kwargs): - batch, c, h, w, device, img_size, = *img.shape, img.device, self.image_size - - times = torch.zeros((batch,), device = device).float().uniform_(0, 0.999) + batch, c, h, w, device, img_size, = ( + *img.shape, + img.device, + self.image_size, + ) + + times = torch.zeros((batch,), device=device).float().uniform_(0, 0.999) # convert image to bit representation - + img = decimal_to_bits(img) * self.bit_scale - + noise = torch.randn_like(img) noise_level = self.log_snr(times) padded_noise_level = right_pad_dims_to(img, noise_level) - alpha, sigma = log_snr_to_alpha_sigma(padded_noise_level) + alpha, sigma = log_snr_to_alpha_sigma(padded_noise_level) noised_img = alpha * img + sigma * noise - + # if doing self-conditioning, 50% of the time, predict x_start from current set of times # and condition with unet with that # this technique will slow down training by 25%, but seems to lower FID significantly @@ -130,8 +137,9 @@ def forward(self, img, class_enc, *args, **kwargs): # with torch.no_grad(): # self_cond = self.model(noised_img, noise_level, class_enc).detach_() + pred = self.model( + noised_img, noise_level, class_enc, self_cond + ) # BACK TO NOISE_LEVEL - pred = self.model(noised_img, noise_level, class_enc, self_cond) # BACK TO NOISE_LEVEL - - #return F.mse_loss(pred, img) # LUCAS - return F.smooth_l1_loss(pred, img) # LUCAS ADDED \ No newline at end of file + # return F.mse_loss(pred, img) # LUCAS + return F.smooth_l1_loss(pred, img) # LUCAS ADDED diff --git a/src/models/diffusion/ddpm.py b/src/models/diffusion/ddpm.py index d304c561..bb5867b8 100644 --- a/src/models/diffusion/ddpm.py +++ b/src/models/diffusion/ddpm.py @@ -6,21 +6,26 @@ from models.diffusion.diffusion import DiffusionModel -from utils.schedules import beta_linear_log_snr, alpha_cosine_log_snr, linear_beta_schedule +from utils.schedules import ( + beta_linear_log_snr, + alpha_cosine_log_snr, + linear_beta_schedule, +) from utils.misc import extract, extract_data_from_batch, mean_flat + class DDPM(DiffusionModel): def __init__( self, *, image_size, - timesteps = 50, - noise_schedule = 'cosine', - time_difference = 0., + timesteps=50, + noise_schedule="cosine", + time_difference=0.0, unet_config: dict, is_conditional: bool, p_uncond: float = 0.1, - use_fp16: bool, + use_fp16: bool, logdir: str, optimizer_config: dict, lr_scheduler_config: dict = None, @@ -30,7 +35,7 @@ def __init__( lr_warmup=0, use_p2_weigthing: bool = False, p2_gamma: float = 0.5, - p2_k: float = 1 + p2_k: float = 1, ): super().__init__( unet_config, @@ -42,7 +47,7 @@ def __init__( criterion, use_ema, ema_decay, - lr_warmup + lr_warmup, ) self.image_size = image_size @@ -52,12 +57,12 @@ def __init__( elif noise_schedule == "cosine": self.log_snr = alpha_cosine_log_snr else: - raise ValueError(f'invalid noise schedule {noise_schedule}') + raise ValueError(f"invalid noise schedule {noise_schedule}") self.timesteps = timesteps self.p_uncond = p_uncond - #self.betas = cosine_beta_schedule(timesteps=timesteps, s=0.0001) + # self.betas = cosine_beta_schedule(timesteps=timesteps, s=0.0001) self.set_noise_schedule(self.betas, self.timesteps) # proposed in the paper, summed to time_next @@ -65,27 +70,27 @@ def __init__( self.time_difference = time_difference - def set_noise_schedule(self, betas, timesteps): # define beta schedule self.betas = linear_beta_schedule(timesteps=timesteps, beta_end=0.05) - # define alphas - alphas = 1. - self.betas + # define alphas + alphas = 1.0 - self.betas alphas_cumprod = torch.cumprod(alphas, axis=0) alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) - + self.sqrt_recip_alphas = torch.sqrt(1.0 / alphas) # calculations for diffusion q(x_t | x_{t-1}) and others self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) - #sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod) - self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod) + # sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) # calculations for posterior q(x_{t-1} | x_t, x_0) - self.posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) - + self.posterior_variance = ( + betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) + ) def q_sample(self, x_start, t, noise=None): """ @@ -97,20 +102,19 @@ def q_sample(self, x_start, t, noise=None): sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape) sqrt_one_minus_alphas_cumprod_t = extract( self.sqrt_one_minus_alphas_cumprod, t, x_start.shape - ) + ) return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise - @torch.no_grad() def p_sample(self, x, t, t_index): betas_t = extract(self.betas, t, x.shape) sqrt_one_minus_alphas_cumprod_t = extract( self.sqrt_one_minus_alphas_cumprod, t, x.shape ) - #print (x.shape, 'x_shape') + # print (x.shape, 'x_shape') sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape) - + # Equation 11 in the paper # Use our model (noise predictor) to predict the mean model_mean = sqrt_recip_alphas_t * ( @@ -123,13 +127,12 @@ def p_sample(self, x, t, t_index): posterior_variance_t = extract(self.posterior_variance, t, x.shape) noise = torch.randn_like(x) # Algorithm 2 line 4: - return model_mean + torch.sqrt(posterior_variance_t) * noise - + return model_mean + torch.sqrt(posterior_variance_t) * noise @torch.no_grad() def p_sample_guided(self, x, classes, t, t_index, context_mask, cond_weight=0.0): # adapted from: https://openreview.net/pdf?id=qw8AKxfYbI - #print (classes[0]) + # print (classes[0]) batch_size = x.shape[0] # double to do guidance with t_double = t.repeat(2) @@ -139,20 +142,21 @@ def p_sample_guided(self, x, classes, t, t_index, context_mask, cond_weight=0.0) self.sqrt_one_minus_alphas_cumprod, t_double, x_double.shape ) sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t_double, x_double.shape) - + # classifier free sampling interpolates between guided and non guided using `cond_weight` classes_masked = classes * context_mask classes_masked = classes_masked.type(torch.long) - #print ('class masked', classes_masked) + # print ('class masked', classes_masked) preds = self.model(x_double, time=t_double, classes=classes_masked) eps1 = (1 + cond_weight) * preds[:batch_size] eps2 = cond_weight * preds[batch_size:] x_t = eps1 - eps2 - + # Equation 11 in the paper # Use our model (noise predictor) to predict the mean model_mean = sqrt_recip_alphas_t[:batch_size] * ( - x - betas_t[:batch_size] * x_t / sqrt_one_minus_alphas_cumprod_t[:batch_size] + x + - betas_t[:batch_size] * x_t / sqrt_one_minus_alphas_cumprod_t[:batch_size] ) if t_index == 0: @@ -161,8 +165,7 @@ def p_sample_guided(self, x, classes, t, t_index, context_mask, cond_weight=0.0) posterior_variance_t = extract(self.posterior_variance, t, x.shape) noise = torch.randn_like(x) # Algorithm 2 line 4: - return model_mean + torch.sqrt(posterior_variance_t) * noise - + return model_mean + torch.sqrt(posterior_variance_t) * noise # Algorithm 2 but save all images: @torch.no_grad() @@ -173,7 +176,7 @@ def p_sample_loop(self, classes, shape, cond_weight): # start from pure noise (for each example in the batch) image = torch.randn(shape, device=device) images = [] - + if classes is not None: n_sample = classes.shape[0] context_mask = torch.ones_like(classes).to(device) @@ -181,31 +184,50 @@ def p_sample_loop(self, classes, shape, cond_weight): # double the batch classes = classes.repeat(2) context_mask = context_mask.repeat(2) - context_mask[n_sample:] = 0. # makes second half of batch context free - sampling_fn = partial(self.p_sample_guided, classes=classes, cond_weight=cond_weight, context_mask=context_mask) + context_mask[n_sample:] = 0.0 # makes second half of batch context free + sampling_fn = partial( + self.p_sample_guided, + classes=classes, + cond_weight=cond_weight, + context_mask=context_mask, + ) else: sampling_fn = partial(self.p_sample) - - for i in tqdm(reversed(range(0, self.timesteps)), desc='sampling loop time step', total=self.timesteps): - image = sampling_fn(self.model, x=image, t=torch.full((b,), i, device=device, dtype=torch.long), t_index=i) + + for i in tqdm( + reversed(range(0, self.timesteps)), + desc="sampling loop time step", + total=self.timesteps, + ): + image = sampling_fn( + self.model, + x=image, + t=torch.full((b,), i, device=device, dtype=torch.long), + t_index=i, + ) images.append(image.cpu().numpy()) return images - @torch.no_grad() - def sample(self, image_size, classes=None, batch_size=16, channels=3, cond_weight=0): - return self.p_sample_loop(self.model, classes=classes, shape=(batch_size, channels, 4, image_size), cond_weight=cond_weight) - + def sample( + self, image_size, classes=None, batch_size=16, channels=3, cond_weight=0 + ): + return self.p_sample_loop( + self.model, + classes=classes, + shape=(batch_size, channels, 4, image_size), + cond_weight=cond_weight, + ) def training_step(self, batch: torch.Tensor, batch_idx: int): x_start, condition = extract_data_from_batch(batch) if noise is None: noise = torch.randn_like(x_start) - x_noisy =self.q_sample(x_start=x_start, t=self.timesteps, noise=noise) - - #calculating generic loss function, we'll add it to the class constructor once we have the code - #we should log more metrics at train and validation e.g. l1, l2 and other suggestions + x_noisy = self.q_sample(x_start=x_start, t=self.timesteps, noise=noise) + + # calculating generic loss function, we'll add it to the class constructor once we have the code + # we should log more metrics at train and validation e.g. l1, l2 and other suggestions if self.use_fp16: with torch.cuda.amp.autocast(): if self.is_conditional: @@ -219,32 +241,37 @@ def training_step(self, batch: torch.Tensor, batch_idx: int): predicted_noise = self.model(x_noisy, self.timesteps) loss = self.criterion(predicted_noise, noise) - self.log('train', loss, batch_size=batch.shape[0]) + self.log("train", loss, batch_size=batch.shape[0]) return loss - def validation_step(self, batch: torch.Tensor, batch_idx: int): - return self.inference_step(batch, batch_idx, 'validation') - + return self.inference_step(batch, batch_idx, "validation") def test_step(self, batch: torch.Tensor, batch_idx: int): - return self.inference_step(batch, batch_idx, 'test') - + return self.inference_step(batch, batch_idx, "test") - def inference_step(self, batch: torch.Tensor, batch_idx: int, phase='validation', noise=None): + def inference_step( + self, batch: torch.Tensor, batch_idx: int, phase="validation", noise=None + ): x_start, condition = extract_data_from_batch(batch) device = x_start.device batch_size = batch.shape[0] - t = torch.randint(0, self.timesteps, (batch_size,), device=device).long() # sampling a t to generate t and t+1 + t = torch.randint( + 0, self.timesteps, (batch_size,), device=device + ).long() # sampling a t to generate t and t+1 if noise is None: - noise = torch.randn_like(x_start) # gauss noise - x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) #this is the auto generated noise given t and Noise - - context_mask = torch.bernoulli(torch.zeros(classes.shape[0]) + (1 - self.p_uncond)).to(device) - + noise = torch.randn_like(x_start) # gauss noise + x_noisy = self.q_sample( + x_start=x_start, t=t, noise=noise + ) # this is the auto generated noise given t and Noise + + context_mask = torch.bernoulli( + torch.zeros(classes.shape[0]) + (1 - self.p_uncond) + ).to(device) + # mask for unconditinal guidance classes = classes * context_mask classes = classes.type(torch.long) @@ -253,11 +280,13 @@ def inference_step(self, batch: torch.Tensor, batch_idx: int, phase='validation' loss = self.criterion(predictions, batch) - self.log('validation_loss', loss) if phase == 'validation' else self.log('test_loss', loss) + self.log("validation_loss", loss) if phase == "validation" else self.log( + "test_loss", loss + ) - ''' + """ Log multiple losses at validation/test time according to internal discussions. - ''' + """ return predictions @@ -265,9 +294,7 @@ def p2_weighting(self, x_t, ts, target, prediction): """ From Perception Prioritized Training of Diffusion Models: https://arxiv.org/abs/2204.00227. """ - weight = ( - 1 / (self.p2_k + self.snr) ** self.p2_gamma, ts, x_t.shape - ) + weight = (1 / (self.p2_k + self.snr) ** self.p2_gamma, ts, x_t.shape) loss_batch = mean_flat(weight * (target - prediction) ** 2) loss = torch.mean(loss_batch) return loss diff --git a/src/models/diffusion/diffusion.py b/src/models/diffusion/diffusion.py index 9770da7d..dc508cce 100644 --- a/src/models/diffusion/diffusion.py +++ b/src/models/diffusion/diffusion.py @@ -4,6 +4,7 @@ from utils.ema import EMA + class DiffusionModel(pl.LightningModule): def __init__( self, diff --git a/src/models/networks/unet_bitdiffusion.py b/src/models/networks/unet_bitdiffusion.py index b45b6491..79b96cd2 100644 --- a/src/models/networks/unet_bitdiffusion.py +++ b/src/models/networks/unet_bitdiffusion.py @@ -1,15 +1,16 @@ # each layer has a time embedding AND class conditioned embedding + class UNet(nn.Module): def __init__( self, dim, - init_dim = None, + init_dim=None, dim_mults=(1, 2, 4, 8), - channels = 3, - bits = BITS, - resnet_block_groups = 8, - learned_sinusoidal_dim = 16, + channels=3, + bits=BITS, + resnet_block_groups=8, + learned_sinusoidal_dim=16, num_classes=10, class_embed_dim=3, ): @@ -17,23 +18,25 @@ def __init__( # determine dimensions - channels *= bits #lucas - self.channels = channels *2 + channels *= bits # lucas + self.channels = channels * 2 input_channels = channels * 2 - #input_channels =16 + # input_channels =16 - init_dim = default(init_dim, dim) - #self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3) # original TODO for zach: is there a difference? - self.init_conv = nn.Conv2d(input_channels, init_dim, (7,7), padding = 3) + # self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3) # original TODO for zach: is there a difference? + self.init_conv = nn.Conv2d(input_channels, init_dim, (7, 7), padding=3) dims = [init_dim, *map(lambda m: dim * m, dim_mults)] - in_out = list(zip(dims[:-1], dims[1:])) - block_klass = partial(ResnetBlockClassConditioned, groups=resnet_block_groups, - num_classes=num_classes, class_embed_dim=class_embed_dim) + block_klass = partial( + ResnetBlockClassConditioned, + groups=resnet_block_groups, + num_classes=num_classes, + class_embed_dim=class_embed_dim, + ) # time embeddings @@ -46,7 +49,7 @@ def __init__( sinu_pos_emb, nn.Linear(fourier_dim, time_dim), nn.GELU(), - nn.Linear(time_dim, time_dim) + nn.Linear(time_dim, time_dim), ) # layers @@ -55,46 +58,57 @@ def __init__( num_resolutions = len(in_out) for ind, (dim_in, dim_out) in enumerate(in_out): is_last = ind >= (num_resolutions - 1) - self.downs.append(nn.ModuleList([ - block_klass(dim_in, dim_in, time_emb_dim = time_dim), - block_klass(dim_in, dim_in, time_emb_dim = time_dim), - Residual(PreNorm(dim_in, LinearAttention(dim_in))), - Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1) - ])) + self.downs.append( + nn.ModuleList( + [ + block_klass(dim_in, dim_in, time_emb_dim=time_dim), + block_klass(dim_in, dim_in, time_emb_dim=time_dim), + Residual(PreNorm(dim_in, LinearAttention(dim_in))), + Downsample(dim_in, dim_out) + if not is_last + else nn.Conv2d(dim_in, dim_out, 3, padding=1), + ] + ) + ) mid_dim = dims[-1] - self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) + self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) - self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) + self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): is_last = ind == (len(in_out) - 1) - self.ups.append(nn.ModuleList([ - block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), - block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), - Residual(PreNorm(dim_out, LinearAttention(dim_out))), - Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1) - ])) - - self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim) + self.ups.append( + nn.ModuleList( + [ + block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim), + block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim), + Residual(PreNorm(dim_out, LinearAttention(dim_out))), + Upsample(dim_out, dim_in) + if not is_last + else nn.Conv2d(dim_out, dim_in, 3, padding=1), + ] + ) + ) - #self.final_conv = nn.Conv2d(dim, 1, 1) #lucas - self.final_conv = nn.Conv2d(dim,8, 1) + self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim) + # self.final_conv = nn.Conv2d(dim, 1, 1) #lucas + self.final_conv = nn.Conv2d(dim, 8, 1) - def forward(self, x, time, c, x_self_cond = None): - #print(x.shape) - #c = torch.zeros_like(c) # removing the conditioning LUCAS + def forward(self, x, time, c, x_self_cond=None): + # print(x.shape) + # c = torch.zeros_like(c) # removing the conditioning LUCAS x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x)) - + x = torch.cat((x_self_cond, x), dim=1) x = self.init_conv(x) r = x.clone() t = self.time_mlp(time) - + # todo class mask h = [] @@ -113,19 +127,19 @@ def forward(self, x, time, c, x_self_cond = None): x = self.mid_block2(x, t, c) for block1, block2, attn, upsample in self.ups: - x = torch.cat((x, h.pop()), dim = 1) + x = torch.cat((x, h.pop()), dim=1) x = block1(x, t, c) - x = torch.cat((x, h.pop()), dim = 1) + x = torch.cat((x, h.pop()), dim=1) x = block2(x, t, c) x = attn(x) x = upsample(x) - x = torch.cat((x, r), dim = 1) + x = torch.cat((x, r), dim=1) x = self.final_res_block(x, t, c) - + x = self.final_conv(x) - #print(x.shape, 'final') - return x \ No newline at end of file + # print(x.shape, 'final') + return x diff --git a/src/models/networks/unet_lucas.py b/src/models/networks/unet_lucas.py index ddef16a1..3e6a34e8 100644 --- a/src/models/networks/unet_lucas.py +++ b/src/models/networks/unet_lucas.py @@ -17,7 +17,7 @@ def __init__(self, fn): def forward(self, x, *args, **kwargs): return self.fn(x, *args, **kwargs) + x - + class LayerNorm(nn.Module): def __init__(self, dim): @@ -26,11 +26,11 @@ def __init__(self, dim): def forward(self, x): eps = 1e-5 if x.dtype == torch.float32 else 1e-3 - var = torch.var(x, dim = 1, unbiased = False, keepdim = True) - mean = torch.mean(x, dim = 1, keepdim = True) + var = torch.var(x, dim=1, unbiased=False, keepdim=True) + mean = torch.mean(x, dim=1, keepdim=True) return (x - mean) * (var + eps).rsqrt() * self.g - + class PreNorm(nn.Module): def __init__(self, dim, fn): super().__init__() @@ -41,10 +41,11 @@ def forward(self, x): x = self.norm(x) return self.fn(x) - + # positional embeds class LearnedSinusoidalPositionalEmbedding(nn.Module): - """ following @crowsonkb 's lead with learned sinusoidal pos emb """ + """following @crowsonkb 's lead with learned sinusoidal pos emb""" + """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ def __init__(self, dim): @@ -54,22 +55,22 @@ def __init__(self, dim): self.weights = nn.Parameter(torch.randn(half_dim)) def forward(self, x): - x = rearrange(x, 'b -> b 1') - freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi - fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1) - fouriered = torch.cat((x, fouriered), dim = -1) + x = rearrange(x, "b -> b 1") + freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * math.pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((x, fouriered), dim=-1) return fouriered - + # building block modules class Block(nn.Module): - def __init__(self, dim, dim_out, groups = 8): + def __init__(self, dim, dim_out, groups=8): super().__init__() - self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1) + self.proj = nn.Conv2d(dim, dim_out, 3, padding=1) self.norm = nn.GroupNorm(groups, dim_out) self.act = nn.SiLU() - def forward(self, x, scale_shift = None): + def forward(self, x, scale_shift=None): x = self.proj(x) x = self.norm(x) @@ -80,84 +81,86 @@ def forward(self, x, scale_shift = None): x = self.act(x) return x - + class ResnetBlock(nn.Module): - def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8): + def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): super().__init__() - self.mlp = nn.Sequential( - nn.SiLU(), - nn.Linear(time_emb_dim, dim_out * 2) - ) if exists(time_emb_dim) else None + self.mlp = ( + nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) + if exists(time_emb_dim) + else None + ) - self.block1 = Block(dim, dim_out, groups = groups) - self.block2 = Block(dim_out, dim_out, groups = groups) + self.block1 = Block(dim, dim_out, groups=groups) + self.block2 = Block(dim_out, dim_out, groups=groups) self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() - def forward(self, x, time_emb = None): + def forward(self, x, time_emb=None): scale_shift = None if exists(self.mlp) and exists(time_emb): time_emb = self.mlp(time_emb) - time_emb = rearrange(time_emb, 'b c -> b c 1 1') - scale_shift = time_emb.chunk(2, dim = 1) + time_emb = rearrange(time_emb, "b c -> b c 1 1") + scale_shift = time_emb.chunk(2, dim=1) - h = self.block1(x, scale_shift = scale_shift) + h = self.block1(x, scale_shift=scale_shift) h = self.block2(h) return h + self.res_conv(x) - + class LinearAttention(nn.Module): - def __init__(self, dim, heads = 4, dim_head = 32): + def __init__(self, dim, heads=4, dim_head=32): super().__init__() - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 self.heads = heads hidden_dim = dim_head * heads - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) - self.to_out = nn.Sequential( - nn.Conv2d(hidden_dim, dim, 1), - LayerNorm(dim) - ) + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), LayerNorm(dim)) def forward(self, x): b, c, h, w = x.shape - qkv = self.to_qkv(x).chunk(3, dim = 1) - q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv) + qkv = self.to_qkv(x).chunk(3, dim=1) + q, k, v = map( + lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv + ) - q = q.softmax(dim = -2) - k = k.softmax(dim = -1) + q = q.softmax(dim=-2) + k = k.softmax(dim=-1) q = q * self.scale v = v / (h * w) - context = torch.einsum('b h d n, b h e n -> b h d e', k, v) + context = torch.einsum("b h d n, b h e n -> b h d e", k, v) - out = torch.einsum('b h d e, b h d n -> b h e n', context, q) - out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w) + out = torch.einsum("b h d e, b h d n -> b h e n", context, q) + out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w) return self.to_out(out) - + class Attention(nn.Module): - def __init__(self, dim, heads = 4, dim_head = 32, scale = 10): + def __init__(self, dim, heads=4, dim_head=32, scale=10): super().__init__() self.scale = scale self.heads = heads hidden_dim = dim_head * heads - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) self.to_out = nn.Conv2d(hidden_dim, dim, 1) def forward(self, x): b, c, h, w = x.shape - qkv = self.to_qkv(x).chunk(3, dim = 1) - q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv) + qkv = self.to_qkv(x).chunk(3, dim=1) + q, k, v = map( + lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv + ) q, k = map(l2norm, (q, k)) - sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale - attn = sim.softmax(dim = -1) - out = einsum('b h i j, b h d j -> b h i d', attn, v) - out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w) + sim = einsum("b h d i, b h d j -> b h i j", q, k) * self.scale + attn = sim.softmax(dim=-1) + out = einsum("b h i j, b h d j -> b h i d", attn, v) + out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) return self.to_out(out) @@ -167,7 +170,7 @@ def __init__( dim: int, init_dim: int = None, dim_mults: Optional[int, list] = (1, 2, 4), - channels: int = 1, + channels: int = 1, resnet_block_groups: int = 8, learned_sinusoidal_dim: int = 18, num_classes: int = 10, @@ -183,7 +186,7 @@ def __init__( input_channels = channels * 2 init_dim = default(init_dim, dim) - self.init_conv = nn.Conv2d(input_channels, init_dim, (7,7), padding=3) + self.init_conv = nn.Conv2d(input_channels, init_dim, (7, 7), padding=3) dims = [init_dim, *map(lambda m: dim * m, dim_mults)] @@ -195,14 +198,14 @@ def __init__( sinu_pos_emb = LearnedSinusoidalPositionalEmbedding(learned_sinusoidal_dim) fourier_dim = learned_sinusoidal_dim + 1 - + self.time_mlp = nn.Sequential( sinu_pos_emb, nn.Linear(fourier_dim, time_dim), nn.GELU(), - nn.Linear(time_dim, time_dim) + nn.Linear(time_dim, time_dim), ) - + if num_classes is not None: self.label_emb = nn.Embedding(num_classes, time_dim) @@ -213,40 +216,51 @@ def __init__( for index, (dim_in, dim_out) in enumerate(in_out): is_last = index >= (num_resolutions - 1) - self.downs.append(nn.ModuleList([ - resnet_block(dim_in, dim_in, time_emb_dim = time_dim), - resnet_block(dim_in, dim_in, time_emb_dim = time_dim), - - (PreNorm(dim_in, LinearAttention(dim_in))), - Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding=1) - ])) + self.downs.append( + nn.ModuleList( + [ + resnet_block(dim_in, dim_in, time_emb_dim=time_dim), + resnet_block(dim_in, dim_in, time_emb_dim=time_dim), + (PreNorm(dim_in, LinearAttention(dim_in))), + Downsample(dim_in, dim_out) + if not is_last + else nn.Conv2d(dim_in, dim_out, 3, padding=1), + ] + ) + ) mid_dim = dims[-1] - self.mid_block1 = resnet_block(mid_dim, mid_dim, time_emb_dim = time_dim) + self.mid_block1 = resnet_block(mid_dim, mid_dim, time_emb_dim=time_dim) self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) - self.mid_block2 = resnet_block(mid_dim, mid_dim, time_emb_dim = time_dim) + self.mid_block2 = resnet_block(mid_dim, mid_dim, time_emb_dim=time_dim) for index, (dim_in, dim_out) in enumerate(reversed(in_out)): is_last = index == (len(in_out) - 1) - self.ups.append(nn.ModuleList([ - resnet_block(dim_out + dim_in, dim_out, time_emb_dim = time_dim), - resnet_block(dim_out + dim_in, dim_out, time_emb_dim = time_dim), - Residual(PreNorm(dim_out, LinearAttention(dim_out))), - Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding=1) - ])) - - self.final_res_block = resnet_block(dim * 2, dim, time_emb_dim = time_dim) + self.ups.append( + nn.ModuleList( + [ + resnet_block(dim_out + dim_in, dim_out, time_emb_dim=time_dim), + resnet_block(dim_out + dim_in, dim_out, time_emb_dim=time_dim), + Residual(PreNorm(dim_out, LinearAttention(dim_out))), + Upsample(dim_out, dim_in) + if not is_last + else nn.Conv2d(dim_out, dim_in, 3, padding=1), + ] + ) + ) + + self.final_res_block = resnet_block(dim * 2, dim, time_emb_dim=time_dim) self.final_conv = nn.Conv2d(dim, 1, 1) - - def forward(self, x, time, classes, x_self_cond = None): + + def forward(self, x, time, classes, x_self_cond=None): x = self.init_conv(x) r = x.clone() t_start = self.time_mlp(time) t_mid = t_start.clone() t_end = t_start.clone() - + if classes is not None: t_start += self.label_emb(classes) t_mid += self.label_emb(classes) @@ -269,16 +283,16 @@ def forward(self, x, time, classes, x_self_cond = None): x = self.mid_block2(x, t_mid) for block1, block2, attn, upsample in self.ups: - x = torch.cat((x, h.pop()), dim = 1) + x = torch.cat((x, h.pop()), dim=1) x = block1(x, t_mid) - x = torch.cat((x, h.pop()), dim = 1) + x = torch.cat((x, h.pop()), dim=1) x = block2(x, t_mid) x = attn(x) x = upsample(x) - x = torch.cat((x, r), dim = 1) + x = torch.cat((x, r), dim=1) x = self.final_res_block(x, t_end) x = self.final_conv(x) diff --git a/src/models/networks/unet_lucas_cond.py b/src/models/networks/unet_lucas_cond.py index 246dddb0..366e3f91 100644 --- a/src/models/networks/unet_lucas_cond.py +++ b/src/models/networks/unet_lucas_cond.py @@ -11,6 +11,7 @@ # Building blocks of UNET + class Residual(nn.Module): def __init__(self, fn): super().__init__() @@ -20,14 +21,14 @@ def forward(self, x, *args, **kwargs): return self.fn(x, *args, **kwargs) + x -def Upsample(dim, dim_out = None): +def Upsample(dim, dim_out=None): return nn.Sequential( - nn.Upsample(scale_factor = 2, mode = 'nearest'), - nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1) + nn.Upsample(scale_factor=2, mode="nearest"), + nn.Conv2d(dim, default(dim_out, dim), 3, padding=1), ) -def Downsample(dim, dim_out = None): +def Downsample(dim, dim_out=None): return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1) @@ -38,8 +39,8 @@ def __init__(self, dim): def forward(self, x): eps = 1e-5 if x.dtype == torch.float32 else 1e-3 - var = torch.var(x, dim = 1, unbiased = False, keepdim = True) - mean = torch.mean(x, dim = 1, keepdim = True) + var = torch.var(x, dim=1, unbiased=False, keepdim=True) + mean = torch.mean(x, dim=1, keepdim=True) return (x - mean) * (var + eps).rsqrt() * self.g @@ -56,8 +57,10 @@ def forward(self, x): # Building blocks of UNET, positional embeddings + class LearnedSinusoidalPosEmb(nn.Module): - """ following @crowsonkb 's lead with learned sinusoidal pos emb """ + """following @crowsonkb 's lead with learned sinusoidal pos emb""" + """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ def __init__(self, dim): @@ -67,41 +70,38 @@ def __init__(self, dim): self.weights = nn.Parameter(torch.randn(half_dim)) def forward(self, x): - x = rearrange(x, 'b -> b 1') - freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi - fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1) - fouriered = torch.cat((x, fouriered), dim = -1) + x = rearrange(x, "b -> b 1") + freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * math.pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((x, fouriered), dim=-1) return fouriered class EmbedFC(nn.Module): def __init__(self, input_dim, emb_dim): super(EmbedFC, self).__init__() - ''' + """ generic one layer FC NN for embedding things - ''' + """ self.input_dim = input_dim - layers = [ - nn.Linear(input_dim, emb_dim), - nn.GELU(), - nn.Linear(emb_dim, emb_dim) - ] + layers = [nn.Linear(input_dim, emb_dim), nn.GELU(), nn.Linear(emb_dim, emb_dim)] self.model = nn.Sequential(*layers) def forward(self, x): return self.model(x) - + # Building blocks of UNET, convolution + group norm blocks + class Block(nn.Module): - def __init__(self, dim, dim_out, groups = 8): + def __init__(self, dim, dim_out, groups=8): super().__init__() - self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1) + self.proj = nn.Conv2d(dim, dim_out, 3, padding=1) self.norm = nn.GroupNorm(groups, dim_out) self.act = nn.SiLU() - def forward(self, x, scale_shift = None): + def forward(self, x, scale_shift=None): x = self.proj(x) x = self.norm(x) @@ -115,27 +115,29 @@ def forward(self, x, scale_shift = None): # Building blocks of UNET, residual blocks + class ResnetBlock(nn.Module): - def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8): + def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): super().__init__() - self.mlp = nn.Sequential( - nn.SiLU(), - nn.Linear(time_emb_dim, dim_out * 2) - ) if exists(time_emb_dim) else None + self.mlp = ( + nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) + if exists(time_emb_dim) + else None + ) - self.block1 = Block(dim, dim_out, groups = groups) - self.block2 = Block(dim_out, dim_out, groups = groups) + self.block1 = Block(dim, dim_out, groups=groups) + self.block2 = Block(dim_out, dim_out, groups=groups) self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() - def forward(self, x, time_emb = None): + def forward(self, x, time_emb=None): scale_shift = None if exists(self.mlp) and exists(time_emb): time_emb = self.mlp(time_emb) - time_emb = rearrange(time_emb, 'b c -> b c 1 1') - scale_shift = time_emb.chunk(2, dim = 1) + time_emb = rearrange(time_emb, "b c -> b c 1 1") + scale_shift = time_emb.chunk(2, dim=1) - h = self.block1(x, scale_shift = scale_shift) + h = self.block1(x, scale_shift=scale_shift) h = self.block2(h) @@ -144,12 +146,19 @@ def forward(self, x, time_emb = None): # Additional code to the https://github.com/lucidrains/bit-diffusion/blob/main/bit_diffusion/bit_diffusion.py + class ResnetBlockClassConditioned(ResnetBlock): - def __init__(self, dim, dim_out, *, num_classes, class_embed_dim, time_emb_dim = None, groups = 8): - super().__init__(dim=dim+class_embed_dim, dim_out=dim_out, time_emb_dim=time_emb_dim, groups=groups) + def __init__( + self, dim, dim_out, *, num_classes, class_embed_dim, time_emb_dim=None, groups=8 + ): + super().__init__( + dim=dim + class_embed_dim, + dim_out=dim_out, + time_emb_dim=time_emb_dim, + groups=groups, + ) self.class_mlp = EmbedFC(num_classes, class_embed_dim) - - + def forward(self, x, time_emb=None, c=None): emb_c = self.class_mlp(c) emb_c = emb_c.view(*emb_c.shape, 1, 1) @@ -161,86 +170,89 @@ def forward(self, x, time_emb=None, c=None): # Building blocks of UNET, attention modules + class LinearAttention(nn.Module): - def __init__(self, dim, heads = 4, dim_head = 32): + def __init__(self, dim, heads=4, dim_head=32): super().__init__() - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 self.heads = heads hidden_dim = dim_head * heads - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) - self.to_out = nn.Sequential( - nn.Conv2d(hidden_dim, dim, 1), - LayerNorm(dim) - ) + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), LayerNorm(dim)) def forward(self, x): b, c, h, w = x.shape - qkv = self.to_qkv(x).chunk(3, dim = 1) - q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv) + qkv = self.to_qkv(x).chunk(3, dim=1) + q, k, v = map( + lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv + ) - q = q.softmax(dim = -2) - k = k.softmax(dim = -1) + q = q.softmax(dim=-2) + k = k.softmax(dim=-1) q = q * self.scale v = v / (h * w) - context = torch.einsum('b h d n, b h e n -> b h d e', k, v) + context = torch.einsum("b h d n, b h e n -> b h d e", k, v) - out = torch.einsum('b h d e, b h d n -> b h e n', context, q) - out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w) + out = torch.einsum("b h d e, b h d n -> b h e n", context, q) + out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w) return self.to_out(out) class Attention(nn.Module): - def __init__(self, dim, heads = 4, dim_head = 32, scale = 10): + def __init__(self, dim, heads=4, dim_head=32, scale=10): super().__init__() self.scale = scale self.heads = heads hidden_dim = dim_head * heads - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) self.to_out = nn.Conv2d(hidden_dim, dim, 1) def forward(self, x): b, c, h, w = x.shape - qkv = self.to_qkv(x).chunk(3, dim = 1) - q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv) + qkv = self.to_qkv(x).chunk(3, dim=1) + q, k, v = map( + lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv + ) q, k = map(l2norm, (q, k)) - sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale - attn = sim.softmax(dim = -1) - out = einsum('b h i j, b h d j -> b h i d', attn, v) - out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w) + sim = einsum("b h d i, b h d j -> b h i j", q, k) * self.scale + attn = sim.softmax(dim=-1) + out = einsum("b h i j, b h d j -> b h i d", attn, v) + out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) return self.to_out(out) # Core part of UNET + class UNet(nn.Module): - """ + """ Refer to the main paper for the architecture details https://arxiv.org/pdf/2208.04202.pdf """ + def __init__( self, dim: int, init_dim: int = 200, dim_mults: Optional[list] = [1, 2, 4], - channels = 1, - resnet_block_groups: int = 8, + channels=1, + resnet_block_groups: int = 8, learned_sinusoidal_dim: int = 18, num_classes: int = 10, - class_embed_dim: bool =3, + class_embed_dim: bool = 3, ): super().__init__() self.channels = channels # if you want to do self conditioning uncomment this - #input_channels = channels * 2 + # input_channels = channels * 2 input_channels = channels - init_dim = default(init_dim, dim) - self.init_conv = nn.Conv2d(input_channels, init_dim, (7,7), padding = 3) + self.init_conv = nn.Conv2d(input_channels, init_dim, (7, 7), padding=3) dims = [init_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) @@ -250,14 +262,14 @@ def __init__( sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim) fourier_dim = learned_sinusoidal_dim + 1 - + self.time_mlp = nn.Sequential( sinu_pos_emb, nn.Linear(fourier_dim, time_dim), nn.GELU(), - nn.Linear(time_dim, time_dim) + nn.Linear(time_dim, time_dim), ) - + if num_classes is not None: self.label_emb = nn.Embedding(num_classes, time_dim) @@ -269,42 +281,54 @@ def __init__( for ind, (dim_in, dim_out) in enumerate(in_out): is_last = ind >= (num_resolutions - 1) - self.downs.append(nn.ModuleList([ - block_klass(dim_in, dim_in, time_emb_dim = time_dim), - block_klass(dim_in, dim_in, time_emb_dim = time_dim), - Residual(PreNorm(dim_in, LinearAttention(dim_in))), - Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1) - ])) + self.downs.append( + nn.ModuleList( + [ + block_klass(dim_in, dim_in, time_emb_dim=time_dim), + block_klass(dim_in, dim_in, time_emb_dim=time_dim), + Residual(PreNorm(dim_in, LinearAttention(dim_in))), + Downsample(dim_in, dim_out) + if not is_last + else nn.Conv2d(dim_in, dim_out, 3, padding=1), + ] + ) + ) mid_dim = dims[-1] - self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) + self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) - self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) + self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): is_last = ind == (len(in_out) - 1) - self.ups.append(nn.ModuleList([ - block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), - block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), - Residual(PreNorm(dim_out, LinearAttention(dim_out))), - Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1) - ])) - - self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim) + self.ups.append( + nn.ModuleList( + [ + block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim), + block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim), + Residual(PreNorm(dim_out, LinearAttention(dim_out))), + Upsample(dim_out, dim_in) + if not is_last + else nn.Conv2d(dim_out, dim_in, 3, padding=1), + ] + ) + ) + + self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim) self.final_conv = nn.Conv2d(dim, 1, 1) - print('final',dim, channels, self.final_conv) + print("final", dim, channels, self.final_conv) -# Additional code to the https://github.com/lucidrains/bit-diffusion/blob/main/bit_diffusion/bit_diffusion.py mostly in forward method. + # Additional code to the https://github.com/lucidrains/bit-diffusion/blob/main/bit_diffusion/bit_diffusion.py mostly in forward method. - def forward(self, x, time, classes, x_self_cond = None): + def forward(self, x, time, classes, x_self_cond=None): x = self.init_conv(x) r = x.clone() t_start = self.time_mlp(time) t_mid = t_start.clone() t_end = t_start.clone() - + if classes is not None: t_start += self.label_emb(classes) t_mid += self.label_emb(classes) @@ -327,18 +351,17 @@ def forward(self, x, time, classes, x_self_cond = None): x = self.mid_block2(x, t_mid) for block1, block2, attn, upsample in self.ups: - x = torch.cat((x, h.pop()), dim = 1) + x = torch.cat((x, h.pop()), dim=1) x = block1(x, t_mid) - x = torch.cat((x, h.pop()), dim = 1) + x = torch.cat((x, h.pop()), dim=1) x = block2(x, t_mid) x = attn(x) x = upsample(x) - - x = torch.cat((x, r), dim = 1) + x = torch.cat((x, r), dim=1) x = self.final_res_block(x, t_end) x = self.final_conv(x) - return x \ No newline at end of file + return x diff --git a/src/train.py b/src/train.py index 0ffe9bc3..e88cc029 100644 --- a/src/train.py +++ b/src/train.py @@ -9,6 +9,7 @@ logger = logging.getLogger() + def get_parser(**parser_kwargs): parser = argparse.ArgumentParser(**parser_kwargs) parser.add_argument( @@ -23,13 +24,19 @@ def get_parser(**parser_kwargs): ) return parser + @hydra.main(config_path="configs", config_name="train") def train(cfg: DictConfig): parser = get_parser() - + # Keeping track of current config settings in logger logger.info(f"Training with config:\n{OmegaConf.to_yaml(cfg)}") - run = wandb.init(name=parser.logdir, save_dir=parser.logdir, project=cfg.logger.wandb.project, config=cfg) + run = wandb.init( + name=parser.logdir, + save_dir=parser.logdir, + project=cfg.logger.wandb.project, + config=cfg, + ) # Placeholder for what loss or metric values we plan to track with wandb wandb.log({"loss": loss}) @@ -48,10 +55,11 @@ def train(cfg: DictConfig): callbacks=cfg.callbacks, accelerator=cfg.accelerator, devices=cfg.devices, - logger=cfg.logger.wandb + logger=cfg.logger.wandb, ) trainer.fit(model, train_dl, val_dl) + if __name__ == "__main__": - train() \ No newline at end of file + train() diff --git a/src/utils/ema.py b/src/utils/ema.py index 026c64c2..b5642e83 100644 --- a/src/utils/ema.py +++ b/src/utils/ema.py @@ -5,7 +5,9 @@ def __init__(self, beta): self.step = 0 def update_model_average(self, ma_model, current_model): - for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): + for current_params, ma_params in zip( + current_model.parameters(), ma_model.parameters() + ): old_weight, up_weight = ma_params.data, current_params.data ma_params.data = self.update_average(old_weight, up_weight) @@ -23,4 +25,4 @@ def step_ema(self, ema_model, model, step_start_ema=2000): self.step += 1 def reset_parameters(self, ema_model, model): - ema_model.load_state_dict(model.state_dict()) \ No newline at end of file + ema_model.load_state_dict(model.state_dict()) diff --git a/src/utils/metrics.py b/src/utils/metrics.py index f8307082..d8a3496c 100644 --- a/src/utils/metrics.py +++ b/src/utils/metrics.py @@ -8,9 +8,10 @@ import os -def motif_scoring_KL_divergence(original: pd.Series, - generated: pd.Series) -> torch.Tensor: - +def motif_scoring_KL_divergence( + original: pd.Series, generated: pd.Series +) -> torch.Tensor: + """ This function encapsulates the logic of evaluating the KL divergence metric between two sequences. @@ -25,91 +26,125 @@ def motif_scoring_KL_divergence(original: pd.Series, return np.sum(kl_pq) -def compare_motif_list(df_motifs_a, df_motifs_b, motif_scoring_metric=motif_scoring_KL_divergence, plot_motif_probs=False): - """ +def compare_motif_list( + df_motifs_a, + df_motifs_b, + motif_scoring_metric=motif_scoring_KL_divergence, + plot_motif_probs=False, +): + """ This function encapsulates the logic of evaluating the difference between the distribution of frequencies between generated (diffusion/df_motifs_a) and the input (training/df_motifs_b) for an arbitrary metric ("motif_scoring_metric") - Please note that some metrics, like KL_divergence, are not metrics in official sense. Reason + Please note that some metrics, like KL_divergence, are not metrics in official sense. Reason for that is that they dont satisfy certain properties, such as in KL case, the simmetry property. Hence it makes a big difference what are the positions of input. """ - set_all_mot = set(df_motifs_a.index.values.tolist() + df_motifs_b.index.values.tolist()) + set_all_mot = set( + df_motifs_a.index.values.tolist() + df_motifs_b.index.values.tolist() + ) create_new_matrix = [] for x in set_all_mot: list_in = [] - list_in.append(x) # adding the name + list_in.append(x) # adding the name if x in df_motifs_a.index: list_in.append(df_motifs_a.loc[x][0]) else: - list_in.append(1) - + list_in.append(1) + if x in df_motifs_b.index: list_in.append(df_motifs_b.loc[x][0]) else: - list_in.append(1) - - create_new_matrix.append(list_in) - - - df_motifs = pd.DataFrame(create_new_matrix, columns=['motif', 'motif_a', 'motif_b']) - - df_motifs['Diffusion_seqs'] = df_motifs['motif_a'] / df_motifs['motif_a'].sum() - df_motifs['Training_seqs'] = df_motifs['motif_b'] / df_motifs['motif_b'].sum() - if plot_motif_probs: - plt.rcParams["figure.figsize"] = (3,3) - sns.regplot(x='Diffusion_seqs', y='Training_seqs',data=df_motifs) - plt.xlabel('Diffusion Seqs') - plt.ylabel('Training Seqs') - plt.title('Motifs Probs') - plt.show() + list_in.append(1) - return motif_scoring_metric(df_motifs['Diffusion_seqs'].values, df_motifs['Training_seqs'].values) + create_new_matrix.append(list_in) + df_motifs = pd.DataFrame(create_new_matrix, columns=["motif", "motif_a", "motif_b"]) -def sampling_to_metric(model, cell_types, image_size, nucleotides, number_of_samples=20, specific_group=False, - group_number=None, cond_weight_to_metric=0): - """ + df_motifs["Diffusion_seqs"] = df_motifs["motif_a"] / df_motifs["motif_a"].sum() + df_motifs["Training_seqs"] = df_motifs["motif_b"] / df_motifs["motif_b"].sum() + if plot_motif_probs: + plt.rcParams["figure.figsize"] = (3, 3) + sns.regplot(x="Diffusion_seqs", y="Training_seqs", data=df_motifs) + plt.xlabel("Diffusion Seqs") + plt.ylabel("Training Seqs") + plt.title("Motifs Probs") + plt.show() + + return motif_scoring_metric( + df_motifs["Diffusion_seqs"].values, df_motifs["Training_seqs"].values + ) + + +def sampling_to_metric( + model, + cell_types, + image_size, + nucleotides, + number_of_samples=20, + specific_group=False, + group_number=None, + cond_weight_to_metric=0, +): + """ Might need to add to the DDPM class since if we can't call the sample() method outside PyTorch Lightning. This function encapsulates the logic of sampling from the trained model in order to generate counts of the motifs. - The reasoning is that we are interested only in calculating the evaluation metric + The reasoning is that we are interested only in calculating the evaluation metric for the count of occurances and not the nucleic acids themselves. """ - final_sequences=[] + final_sequences = [] for n_a in tqdm(range(number_of_samples)): sample_bs = 10 if specific_group: - sampled = torch.from_numpy(np.array([group_number] * sample_bs) ) - print ('specific') + sampled = torch.from_numpy(np.array([group_number] * sample_bs)) + print("specific") else: sampled = torch.from_numpy(np.random.choice(cell_types, sample_bs)) - + random_classes = sampled.float().cuda() - sampled_images = model.sample(classes=random_classes, image_size=image_size, batch_size=sample_bs, channels=1, cond_weight=cond_weight_to_metric) + sampled_images = model.sample( + classes=random_classes, + image_size=image_size, + batch_size=sample_bs, + channels=1, + cond_weight=cond_weight_to_metric, + ) for n_b, x in enumerate(sampled_images[-1]): - seq_final = f'>seq_test_{n_a}_{n_b}\n' +''.join([nucleotides[s] for s in np.argmax(x.reshape(4,200), axis=0)]) + seq_final = f">seq_test_{n_a}_{n_b}\n" + "".join( + [nucleotides[s] for s in np.argmax(x.reshape(4, 200), axis=0)] + ) final_sequences.append(seq_final) - save_motifs_syn = open('synthetic_motifs.fasta', 'w') + save_motifs_syn = open("synthetic_motifs.fasta", "w") - save_motifs_syn.write('\n'.join(final_sequences)) + save_motifs_syn.write("\n".join(final_sequences)) save_motifs_syn.close() - #Scan for motifs - os.system('gimme scan synthetic_motifs.fasta -p JASPAR2020_vertebrates -g hg38 > syn_results_motifs.bed') - df_results_syn = pd.read_csv('syn_results_motifs.bed', sep='\t', skiprows=5, header=None) - df_results_syn['motifs'] = df_results_syn[8].apply(lambda x: x.split( 'motif_name "')[1].split('"')[0] ) - df_results_syn[0] = df_results_syn[0].apply(lambda x : '_'.join( x.split('_')[:-1]) ) - df_motifs_count_syn = df_results_syn[[0,'motifs']].drop_duplicates().groupby('motifs').count() - plt.rcParams["figure.figsize"] = (30,2) + # Scan for motifs + os.system( + "gimme scan synthetic_motifs.fasta -p JASPAR2020_vertebrates -g hg38 > syn_results_motifs.bed" + ) + df_results_syn = pd.read_csv( + "syn_results_motifs.bed", sep="\t", skiprows=5, header=None + ) + df_results_syn["motifs"] = df_results_syn[8].apply( + lambda x: x.split('motif_name "')[1].split('"')[0] + ) + df_results_syn[0] = df_results_syn[0].apply(lambda x: "_".join(x.split("_")[:-1])) + df_motifs_count_syn = ( + df_results_syn[[0, "motifs"]].drop_duplicates().groupby("motifs").count() + ) + plt.rcParams["figure.figsize"] = (30, 2) df_motifs_count_syn.sort_values(0, ascending=False).head(50)[0].plot.bar() plt.show() return df_motifs_count_syn -def metric_comparison_between_components(original_data, generated_data, x_label_plot, y_label_plot, cell_components): +def metric_comparison_between_components( + original_data, generated_data, x_label_plot, y_label_plot, cell_components +): """ This functions takes as inputs dictionaries, which contain as keys different components (cell types) and as values the distribution of occurances of different motifs. These two dictionaries represent two different datasets, i.e. @@ -118,7 +153,7 @@ def metric_comparison_between_components(original_data, generated_data, x_label_ The goal is to then plot a the main evaluation metric (KL or otherwise) across all different types of cell types in a heatmap fashion. """ - ENUMARATED_CELL_NAME = '''7 Trophoblasts + ENUMARATED_CELL_NAME = """7 Trophoblasts 5 CD8_cells 15 CD34_cells 9 Fetal_heart @@ -133,23 +168,29 @@ def metric_comparison_between_components(original_data, generated_data, x_label_ 16 Esophageal(Cancer) 11 Fetal_Lung 10 Fetal_kidney - 1 Tissue_Invariant'''.split('\n') - CELL_NAMES = {int(x.split(' ')[0]): x.split(' ')[1] for x in ENUMARATED_CELL_NAME} + 1 Tissue_Invariant""".split( + "\n" + ) + CELL_NAMES = {int(x.split(" ")[0]): x.split(" ")[1] for x in ENUMARATED_CELL_NAME} final_comparison_all_components = [] for components_1, motif_occurance_frequency in original_data.items(): comparisons_single_component = [] for components_2 in generated_data.keys(): - compared_motifs_occurances = compare_motif_list(motif_occurance_frequency, generated_data[components_2]) + compared_motifs_occurances = compare_motif_list( + motif_occurance_frequency, generated_data[components_2] + ) comparisons_single_component.append(compared_motifs_occurances) final_comparison_all_components.append(comparisons_single_component) - plt.rcParams["figure.figsize"] = (10,10) + plt.rcParams["figure.figsize"] = (10, 10) df_plot = pd.DataFrame(final_comparison_all_components) df_plot.columns = [CELL_NAMES[x] for x in cell_components] df_plot.index = df_plot.columns - sns.heatmap(df_plot, cmap='Blues_r', annot=True, lw=0.1, vmax=1, vmin=0 ) - plt.title(f'Kl divergence \n {x_label_plot} sequences x {y_label_plot} sequences \n MOTIFS probabilities') - plt.xlabel(f'{x_label_plot} Sequences \n(motifs dist)') - plt.ylabel(f'{y_label_plot} \n (motifs dist)') \ No newline at end of file + sns.heatmap(df_plot, cmap="Blues_r", annot=True, lw=0.1, vmax=1, vmin=0) + plt.title( + f"Kl divergence \n {x_label_plot} sequences x {y_label_plot} sequences \n MOTIFS probabilities" + ) + plt.xlabel(f"{x_label_plot} Sequences \n(motifs dist)") + plt.ylabel(f"{y_label_plot} \n (motifs dist)") diff --git a/src/utils/misc.py b/src/utils/misc.py index d8eaa3c0..1215e2a6 100644 --- a/src/utils/misc.py +++ b/src/utils/misc.py @@ -7,11 +7,11 @@ def seed_everything(seed): - """" + """ " Seed everything. - """ + """ random.seed(seed) - os.environ['PYTHONHASHSEED'] = str(seed) + os.environ["PYTHONHASHSEED"] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) @@ -71,11 +71,11 @@ def one_hot_encode(seq, nucleotides, max_seq_len): return seq_array -def log(t, eps = 1e-20): +def log(t, eps=1e-20): """ Toch log for the purporses of diffusion time steps t. """ - return torch.log(t.clamp(min = eps)) + return torch.log(t.clamp(min=eps)) def right_pad_dims_to(x, t): @@ -98,9 +98,10 @@ def get_obj_from_str(string, reload=False): importlib.reload(module_to_reload) return getattr(importlib.import_module(module, package=None), class_) + def mean_flat(tensor): """ Take the mean over all non-batch dimensions. From Perception Prioritized Training of Diffusion Models: https://arxiv.org/abs/2204.00227. """ - return tensor.mean(dim=list(range(1, len(tensor.shape)))) \ No newline at end of file + return tensor.mean(dim=list(range(1, len(tensor.shape)))) diff --git a/src/utils/network.py b/src/utils/network.py index b9d91d65..b8fdaadc 100644 --- a/src/utils/network.py +++ b/src/utils/network.py @@ -2,4 +2,4 @@ def l2norm(t): - return F.normalize(t, dim = -1) \ No newline at end of file + return F.normalize(t, dim=-1) diff --git a/src/utils/schedules.py b/src/utils/schedules.py index 82f1312b..5826af5f 100644 --- a/src/utils/schedules.py +++ b/src/utils/schedules.py @@ -4,12 +4,12 @@ def beta_linear_log_snr(t): - return -torch.log(exp(1e-4 + 10 * (t ** 2))) + return -torch.log(exp(1e-4 + 10 * (t**2))) def alpha_cosine_log_snr(t, s: float = 0.008): # not sure if this accounts for beta being clipped to 0.999 in discrete version - return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5) + return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps=1e-5) def log_snr_to_alpha_sigma(log_snr): @@ -43,4 +43,4 @@ def sigmoid_beta_schedule(timesteps): beta_start = 0.001 beta_end = 0.02 betas = torch.linspace(-6, 6, timesteps) - return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start \ No newline at end of file + return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start diff --git a/tests/data/test_sequence_dataloader.py b/tests/data/test_sequence_dataloader.py index b7b1ddc7..689c1e57 100644 --- a/tests/data/test_sequence_dataloader.py +++ b/tests/data/test_sequence_dataloader.py @@ -7,32 +7,54 @@ def prepare_default_data(path): """Prepares dummy data for testing.""" - pd.DataFrame({ - "raw_sequence": ["ATCGATCGATCG", "GGTGAACGATTA", "AATCGTATCGCG", "CTTATCGATCCG"], - "component": [1, 2, 1, 10], - }).to_csv(path, index=False, sep="\t") + pd.DataFrame( + { + "raw_sequence": [ + "ATCGATCGATCG", + "GGTGAACGATTA", + "AATCGTATCGCG", + "CTTATCGATCCG", + ], + "component": [1, 2, 1, 10], + } + ).to_csv(path, index=False, sep="\t") + def prepare_high_diversity_datasets(train_data_path, val_data_path, test_data_path): - pd.DataFrame({ - "raw_sequence": ["AAAAAAAAAA", "AAAAAAAAAA", "AAAAAAAAAA", "AAAAAAAAAA"], - "component": [0, 0, 0, 0], - }).to_csv(train_data_path, index=False, sep="\t") - pd.DataFrame({ - "raw_sequence": ["CCCCCCCCCC", "CCCCCCCCCC", "CCCCCCCCCC", "CCCCCCCCCC"], - "component": [1, 1, 1, 1], - }).to_csv(val_data_path, index=False, sep="\t") - pd.DataFrame({ - "raw_sequence": ["TTTTTTTTTT", "TTTTTTTTTT", "TTTTTTTTTT", "TTTTTTTTTT"], - "component": [2, 2, 2, 2], - }).to_csv(test_data_path, index=False, sep="\t") + pd.DataFrame( + { + "raw_sequence": ["AAAAAAAAAA", "AAAAAAAAAA", "AAAAAAAAAA", "AAAAAAAAAA"], + "component": [0, 0, 0, 0], + } + ).to_csv(train_data_path, index=False, sep="\t") + pd.DataFrame( + { + "raw_sequence": ["CCCCCCCCCC", "CCCCCCCCCC", "CCCCCCCCCC", "CCCCCCCCCC"], + "component": [1, 1, 1, 1], + } + ).to_csv(val_data_path, index=False, sep="\t") + pd.DataFrame( + { + "raw_sequence": ["TTTTTTTTTT", "TTTTTTTTTT", "TTTTTTTTTT", "TTTTTTTTTT"], + "component": [2, 2, 2, 2], + } + ).to_csv(test_data_path, index=False, sep="\t") + def test_invalid_sequence_letters(): # prepare invalid data dummy_data_path = "_tmp_seq_dataloader_data.csv" - pd.DataFrame({ - "raw_sequence": ["ZCCCACTGACTG", "ACTGACTGACTG", "AAAACCCCTTTT", "ABCDEFGHIJKL"], - "component": [1, 2, 1, 10], - }).to_csv(dummy_data_path, index=False, sep="\t") + pd.DataFrame( + { + "raw_sequence": [ + "ZCCCACTGACTG", + "ACTGACTGACTG", + "AAAACCCCTTTT", + "ABCDEFGHIJKL", + ], + "component": [1, 2, 1, 10], + } + ).to_csv(dummy_data_path, index=False, sep="\t") datamodule = SequenceDataModule( train_path=dummy_data_path, @@ -53,14 +75,17 @@ def test_invalid_sequence_letters(): # remove dummy data os.remove(dummy_data_path) + def test_invalid_sequence_lengths(): # prepare dummy data dummy_data_path = "_tmp_seq_dataloader_data.csv" # second sequence too short - pd.DataFrame({ - "raw_sequence": ["ATCG", "GGT", "AATC", "CTTA"], - "component": [1, 2, 1, 10], - }).to_csv(dummy_data_path, index=False, sep="\t") + pd.DataFrame( + { + "raw_sequence": ["ATCG", "GGT", "AATC", "CTTA"], + "component": [1, 2, 1, 10], + } + ).to_csv(dummy_data_path, index=False, sep="\t") # prepare data module datamodule = SequenceDataModule( @@ -82,10 +107,12 @@ def test_invalid_sequence_lengths(): os.remove(dummy_data_path) # fourth sequence too long - pd.DataFrame({ - "raw_sequence": ["ATCG", "GGT", "AATC", "CTTAT"], - "component": [1, 2, 1, 10], - }).to_csv(dummy_data_path, index=False, sep="\t") + pd.DataFrame( + { + "raw_sequence": ["ATCG", "GGT", "AATC", "CTTAT"], + "component": [1, 2, 1, 10], + } + ).to_csv(dummy_data_path, index=False, sep="\t") # prepare data module datamodule = SequenceDataModule( @@ -106,19 +133,22 @@ def test_invalid_sequence_lengths(): # remove dummy data os.remove(dummy_data_path) + def test_train_val_test_data_split(): # prepare dummy data dummy_train_data_path = "_tmp_seq_dataloader_train_data.csv" dummy_val_data_path = "_tmp_seq_dataloader_val_data.csv" dummy_test_data_path = "_tmp_seq_dataloader_test_data.csv" - prepare_high_diversity_datasets(dummy_train_data_path, dummy_val_data_path, dummy_test_data_path) + prepare_high_diversity_datasets( + dummy_train_data_path, dummy_val_data_path, dummy_test_data_path + ) # check loading of only a single data set datamodule = SequenceDataModule( train_path=None, val_path=dummy_val_data_path, test_path=None, - sequence_length=10 + sequence_length=10, ) datamodule.setup() assert datamodule.train_dataloader is None @@ -141,7 +171,13 @@ def test_train_val_test_data_split(): assert len(datamodule.val_dataloader()) == 2 assert len(datamodule.test_dataloader()) == 2 seen_nucleotide_idxs = set() - for dl_idx, dataloader in enumerate([datamodule.train_dataloader(), datamodule.val_dataloader(), datamodule.test_dataloader()]): + for dl_idx, dataloader in enumerate( + [ + datamodule.train_dataloader(), + datamodule.val_dataloader(), + datamodule.test_dataloader(), + ] + ): dataloader_iter = iter(dataloader) # first batch @@ -170,6 +206,7 @@ def test_train_val_test_data_split(): for path in [dummy_train_data_path, dummy_val_data_path, dummy_test_data_path]: os.remove(path) + def test_polar_encoding(): # prepare dummy data dummy_data_path = "_tmp_seq_dataloader_data.csv" @@ -191,7 +228,11 @@ def test_polar_encoding(): assert len(datamodule.train_data) == 4 assert len(datamodule.val_data) == 4 assert len(datamodule.test_data) == 4 - for dataloader in [datamodule.train_dataloader(), datamodule.val_dataloader(), datamodule.test_dataloader()]: + for dataloader in [ + datamodule.train_dataloader(), + datamodule.val_dataloader(), + datamodule.test_dataloader(), + ]: for batch in dataloader: assert len(batch) == 2 assert isinstance(batch[0], torch.Tensor) @@ -205,6 +246,7 @@ def test_polar_encoding(): # remove dummy data os.remove(dummy_data_path) + def test_onehot_encoding(): # prepare dummy data dummy_data_path = "_tmp_seq_dataloader_data.csv" @@ -226,7 +268,11 @@ def test_onehot_encoding(): assert len(datamodule.train_data) == 4 assert len(datamodule.val_data) == 4 assert len(datamodule.test_data) == 4 - for dataloader in [datamodule.train_dataloader(), datamodule.val_dataloader(), datamodule.test_dataloader()]: + for dataloader in [ + datamodule.train_dataloader(), + datamodule.val_dataloader(), + datamodule.test_dataloader(), + ]: for batch in dataloader: assert len(batch) == 2 assert isinstance(batch[0], torch.Tensor) @@ -240,6 +286,7 @@ def test_onehot_encoding(): # remove dummy data os.remove(dummy_data_path) + def test_ordinal_encoding(): # prepare dummy data dummy_data_path = "_tmp_seq_dataloader_data.csv" @@ -261,7 +308,11 @@ def test_ordinal_encoding(): assert len(datamodule.train_data) == 4 assert len(datamodule.val_data) == 4 assert len(datamodule.test_data) == 4 - for dataloader in [datamodule.train_dataloader(), datamodule.val_dataloader(), datamodule.test_dataloader()]: + for dataloader in [ + datamodule.train_dataloader(), + datamodule.val_dataloader(), + datamodule.test_dataloader(), + ]: for batch in dataloader: assert len(batch) == 2 assert isinstance(batch[0], torch.Tensor) @@ -275,6 +326,7 @@ def test_ordinal_encoding(): # remove dummy data os.remove(dummy_data_path) + def test_polar_transforms(): # prepare dummy data dummy_data_path = "_tmp_seq_dataloader_data.csv" @@ -283,20 +335,26 @@ def test_polar_transforms(): # prepare data module def seg_transform(seq): return seq + 1 + def cell_type_transform(cell_type): return cell_type + 20 + datamodule = SequenceDataModule( train_path=dummy_data_path, val_path=dummy_data_path, test_path=dummy_data_path, sequence_length=12, sequence_encoding="polar", - sequence_transform=transforms.Compose([ - transforms.Lambda(seg_transform), - ]), - cell_type_transform=transforms.Compose([ - transforms.Lambda(cell_type_transform), - ]), + sequence_transform=transforms.Compose( + [ + transforms.Lambda(seg_transform), + ] + ), + cell_type_transform=transforms.Compose( + [ + transforms.Lambda(cell_type_transform), + ] + ), batch_size=2, num_workers=0, ) @@ -306,7 +364,11 @@ def cell_type_transform(cell_type): assert len(datamodule.train_data) == 4 assert len(datamodule.val_data) == 4 assert len(datamodule.test_data) == 4 - for dataloader in [datamodule.train_dataloader, datamodule.val_dataloader, datamodule.test_dataloader]: + for dataloader in [ + datamodule.train_dataloader, + datamodule.val_dataloader, + datamodule.test_dataloader, + ]: seen_cell_type_ids = set() for batch in dataloader(): assert len(batch) == 2 @@ -325,6 +387,7 @@ def cell_type_transform(cell_type): # remove dummy data os.remove(dummy_data_path) + def test_onehot_transforms(): # prepare dummy data dummy_data_path = "_tmp_seq_dataloader_data.csv" @@ -333,20 +396,26 @@ def test_onehot_transforms(): # prepare data module def seg_transform(seq): return seq + 1 + def cell_type_transform(cell_type): return cell_type + 20 + datamodule = SequenceDataModule( train_path=dummy_data_path, val_path=dummy_data_path, test_path=dummy_data_path, sequence_length=12, sequence_encoding="onehot", - sequence_transform=transforms.Compose([ - transforms.Lambda(seg_transform), - ]), - cell_type_transform=transforms.Compose([ - transforms.Lambda(cell_type_transform), - ]), + sequence_transform=transforms.Compose( + [ + transforms.Lambda(seg_transform), + ] + ), + cell_type_transform=transforms.Compose( + [ + transforms.Lambda(cell_type_transform), + ] + ), batch_size=2, num_workers=0, ) @@ -356,7 +425,11 @@ def cell_type_transform(cell_type): assert len(datamodule.train_data) == 4 assert len(datamodule.val_data) == 4 assert len(datamodule.test_data) == 4 - for dataloader in [datamodule.train_dataloader, datamodule.val_dataloader, datamodule.test_dataloader]: + for dataloader in [ + datamodule.train_dataloader, + datamodule.val_dataloader, + datamodule.test_dataloader, + ]: seen_cell_type_ids = set() for batch in dataloader(): assert len(batch) == 2 @@ -375,6 +448,7 @@ def cell_type_transform(cell_type): # remove dummy data os.remove(dummy_data_path) + def test_ordinal_transforms(): # prepare dummy data dummy_data_path = "_tmp_seq_dataloader_data.csv" @@ -383,20 +457,26 @@ def test_ordinal_transforms(): # prepare data module def seg_transform(seq): return seq + 1 + def cell_type_transform(cell_type): return cell_type + 20 + datamodule = SequenceDataModule( train_path=dummy_data_path, val_path=dummy_data_path, test_path=dummy_data_path, sequence_length=12, sequence_encoding="ordinal", - sequence_transform=transforms.Compose([ - transforms.Lambda(seg_transform), - ]), - cell_type_transform=transforms.Compose([ - transforms.Lambda(cell_type_transform), - ]), + sequence_transform=transforms.Compose( + [ + transforms.Lambda(seg_transform), + ] + ), + cell_type_transform=transforms.Compose( + [ + transforms.Lambda(cell_type_transform), + ] + ), batch_size=2, num_workers=0, ) @@ -406,7 +486,11 @@ def cell_type_transform(cell_type): assert len(datamodule.train_data) == 4 assert len(datamodule.val_data) == 4 assert len(datamodule.test_data) == 4 - for dataloader in [datamodule.train_dataloader, datamodule.val_dataloader, datamodule.test_dataloader]: + for dataloader in [ + datamodule.train_dataloader, + datamodule.val_dataloader, + datamodule.test_dataloader, + ]: seen_cell_type_ids = set() for batch in dataloader(): assert len(batch) == 2 From b09d85f289a1665743cf28e12ec2458cc0e36e67 Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Tue, 20 Dec 2022 20:00:13 +0530 Subject: [PATCH 07/16] feat: add type hints to `utils/misc.py` --- src/utils/misc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/utils/misc.py b/src/utils/misc.py index 1215e2a6..6ede566e 100644 --- a/src/utils/misc.py +++ b/src/utils/misc.py @@ -6,7 +6,7 @@ import numpy as np -def seed_everything(seed): +def seed_everything(seed: int) -> None: """ " Seed everything. """ @@ -60,7 +60,7 @@ def convert_image_to(img_type, image): return image -def one_hot_encode(seq, nucleotides, max_seq_len): +def one_hot_encode(seq, nucleotides, max_seq_len: int) -> np.ndarray: """ One-hot encode a sequence of nucleotides. """ @@ -71,7 +71,7 @@ def one_hot_encode(seq, nucleotides, max_seq_len): return seq_array -def log(t, eps=1e-20): +def log(t: torch.Tensor, eps=1e-20) -> torch.Tensor: """ Toch log for the purporses of diffusion time steps t. """ From 9bb49e8ee538afbcab83d90f9896184f0f5803e2 Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Tue, 20 Dec 2022 14:52:16 +0000 Subject: [PATCH 08/16] feat: add type hints to utils/metrics --- src/utils/metrics.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/utils/metrics.py b/src/utils/metrics.py index d8a3496c..5c6c83e5 100644 --- a/src/utils/metrics.py +++ b/src/utils/metrics.py @@ -6,6 +6,7 @@ import matplotlib.pyplot as plt from tqdm.auto import tqdm import os +from typing import Callable, Dict def motif_scoring_KL_divergence( @@ -27,11 +28,11 @@ def motif_scoring_KL_divergence( def compare_motif_list( - df_motifs_a, - df_motifs_b, - motif_scoring_metric=motif_scoring_KL_divergence, - plot_motif_probs=False, -): + df_motifs_a: pd.DataFrame, + df_motifs_b: pd.DataFrame, + motif_scoring_metric: Callable = motif_scoring_KL_divergence, + plot_motif_probs: bool = False, +) -> torch.Tensor: """ This function encapsulates the logic of evaluating the difference between the distribution of frequencies between generated (diffusion/df_motifs_a) and the input (training/df_motifs_b) for an arbitrary metric ("motif_scoring_metric") @@ -143,8 +144,12 @@ def sampling_to_metric( def metric_comparison_between_components( - original_data, generated_data, x_label_plot, y_label_plot, cell_components -): + original_data: Dict, + generated_data: Dict, + x_label_plot: str, + y_label_plot: str, + cell_components, +) -> None: """ This functions takes as inputs dictionaries, which contain as keys different components (cell types) and as values the distribution of occurances of different motifs. These two dictionaries represent two different datasets, i.e. From 104c96d2f28c411204d913c94ac66b2abcb8d478 Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Tue, 20 Dec 2022 14:54:22 +0000 Subject: [PATCH 09/16] feat: add type hints to utils/schedules --- src/utils/schedules.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/utils/schedules.py b/src/utils/schedules.py index 5826af5f..abedce5a 100644 --- a/src/utils/schedules.py +++ b/src/utils/schedules.py @@ -3,16 +3,16 @@ import torch -def beta_linear_log_snr(t): +def beta_linear_log_snr(t: torch.Tensor) -> torch.Tensor: return -torch.log(exp(1e-4 + 10 * (t**2))) -def alpha_cosine_log_snr(t, s: float = 0.008): +def alpha_cosine_log_snr(t: torch.Tensor, s: float = 0.008) -> torch.Tensor: # not sure if this accounts for beta being clipped to 0.999 in discrete version return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps=1e-5) -def log_snr_to_alpha_sigma(log_snr): +def log_snr_to_alpha_sigma(log_snr) -> torch.Tensor: return torch.sqrt(torch.sigmoid(log_snr)), torch.sqrt(torch.sigmoid(-log_snr)) @@ -28,18 +28,18 @@ def cosine_beta_schedule(timesteps, s=0.008): return torch.clip(betas, 0.0001, 0.9999) -def linear_beta_schedule(timesteps, beta_end=0.005): +def linear_beta_schedule(timesteps, beta_end=0.005) -> torch.Tensor: beta_start = 0.0001 return torch.linspace(beta_start, beta_end, timesteps) -def quadratic_beta_schedule(timesteps): +def quadratic_beta_schedule(timesteps) -> torch.Tensor: beta_start = 0.0001 beta_end = 0.02 return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2 -def sigmoid_beta_schedule(timesteps): +def sigmoid_beta_schedule(timesteps) -> torch.Tensor: beta_start = 0.001 beta_end = 0.02 betas = torch.linspace(-6, 6, timesteps) From b3243e1f11ebd7d72a7cada9a7ddfc9d1002f338 Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Tue, 20 Dec 2022 14:59:51 +0000 Subject: [PATCH 10/16] feat: add type hints to unet_bitdiffusion --- src/models/networks/unet_bitdiffusion.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/models/networks/unet_bitdiffusion.py b/src/models/networks/unet_bitdiffusion.py index 79b96cd2..31f263f6 100644 --- a/src/models/networks/unet_bitdiffusion.py +++ b/src/models/networks/unet_bitdiffusion.py @@ -7,13 +7,13 @@ def __init__( dim, init_dim=None, dim_mults=(1, 2, 4, 8), - channels=3, + channels: int = 3, bits=BITS, - resnet_block_groups=8, - learned_sinusoidal_dim=16, - num_classes=10, - class_embed_dim=3, - ): + resnet_block_groups: int = 8, + learned_sinusoidal_dim: int = 16, + num_classes: int = 10, + class_embed_dim: int = 3, + ) -> None: super().__init__() # determine dimensions @@ -97,7 +97,7 @@ def __init__( # self.final_conv = nn.Conv2d(dim, 1, 1) #lucas self.final_conv = nn.Conv2d(dim, 8, 1) - def forward(self, x, time, c, x_self_cond=None): + def forward(self, x, time, c, x_self_cond=None) -> torch.Tensor: # print(x.shape) # c = torch.zeros_like(c) # removing the conditioning LUCAS From ef93521a0060ea43756f3b62201a819cadf51474 Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Tue, 20 Dec 2022 15:11:06 +0000 Subject: [PATCH 11/16] feat: add type hints to unet_lucas --- src/models/networks/unet_lucas.py | 47 ++++++++++++++++++------------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/src/models/networks/unet_lucas.py b/src/models/networks/unet_lucas.py index 3e6a34e8..27a4e16d 100644 --- a/src/models/networks/unet_lucas.py +++ b/src/models/networks/unet_lucas.py @@ -1,7 +1,7 @@ import math from functools import partial from einops import rearrange -from typing import Optional, List +from typing import Optional, List, Callable import torch from torch import nn, einsum @@ -11,20 +11,20 @@ class Residual(nn.Module): - def __init__(self, fn): + def __init__(self, fn: Callable) -> None: super().__init__() self.fn = fn - def forward(self, x, *args, **kwargs): + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: return self.fn(x, *args, **kwargs) + x class LayerNorm(nn.Module): - def __init__(self, dim): + def __init__(self, dim: int) -> None: super().__init__() self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: eps = 1e-5 if x.dtype == torch.float32 else 1e-3 var = torch.var(x, dim=1, unbiased=False, keepdim=True) mean = torch.mean(x, dim=1, keepdim=True) @@ -32,12 +32,12 @@ def forward(self, x): class PreNorm(nn.Module): - def __init__(self, dim, fn): + def __init__(self, dim: int, fn) -> None: super().__init__() self.fn = fn self.norm = LayerNorm(dim) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.norm(x) return self.fn(x) @@ -48,13 +48,13 @@ class LearnedSinusoidalPositionalEmbedding(nn.Module): """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ - def __init__(self, dim): + def __init__(self, dim: int) -> None: super().__init__() assert (dim % 2) == 0 half_dim = dim // 2 self.weights = nn.Parameter(torch.randn(half_dim)) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = rearrange(x, "b -> b 1") freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * math.pi fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) @@ -64,13 +64,13 @@ def forward(self, x): # building block modules class Block(nn.Module): - def __init__(self, dim, dim_out, groups=8): + def __init__(self, dim: int, dim_out: int, groups: int = 8) -> None: super().__init__() self.proj = nn.Conv2d(dim, dim_out, 3, padding=1) self.norm = nn.GroupNorm(groups, dim_out) self.act = nn.SiLU() - def forward(self, x, scale_shift=None): + def forward(self, x: torch.Tensor, scale_shift=None) -> torch.Tensor: x = self.proj(x) x = self.norm(x) @@ -83,7 +83,14 @@ def forward(self, x, scale_shift=None): class ResnetBlock(nn.Module): - def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): + def __init__( + self, + dim: int, + dim_out: int, + *, + time_emb_dim: Optional[int] = None, + groups: int = 8 + ) -> None: super().__init__() self.mlp = ( nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) @@ -95,7 +102,7 @@ def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): self.block2 = Block(dim_out, dim_out, groups=groups) self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() - def forward(self, x, time_emb=None): + def forward(self, x: torch.Tensor, time_emb=None) -> torch.Tensor: scale_shift = None if exists(self.mlp) and exists(time_emb): @@ -111,7 +118,7 @@ def forward(self, x, time_emb=None): class LinearAttention(nn.Module): - def __init__(self, dim, heads=4, dim_head=32): + def __init__(self, dim: int, heads: int = 4, dim_head: int = 32) -> None: super().__init__() self.scale = dim_head**-0.5 self.heads = heads @@ -119,7 +126,7 @@ def __init__(self, dim, heads=4, dim_head=32): self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), LayerNorm(dim)) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: b, c, h, w = x.shape qkv = self.to_qkv(x).chunk(3, dim=1) q, k, v = map( @@ -140,7 +147,9 @@ def forward(self, x): class Attention(nn.Module): - def __init__(self, dim, heads=4, dim_head=32, scale=10): + def __init__( + self, dim: int, heads: int = 4, dim_head: int = 32, scale: int = 10 + ) -> None: super().__init__() self.scale = scale self.heads = heads @@ -148,7 +157,7 @@ def __init__(self, dim, heads=4, dim_head=32, scale=10): self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) self.to_out = nn.Conv2d(hidden_dim, dim, 1) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: b, c, h, w = x.shape qkv = self.to_qkv(x).chunk(3, dim=1) q, k, v = map( @@ -175,7 +184,7 @@ def __init__( learned_sinusoidal_dim: int = 18, num_classes: int = 10, self_conditioned: bool = False, - ): + ) -> None: super().__init__() channels = 1 @@ -253,7 +262,7 @@ def __init__( self.final_res_block = resnet_block(dim * 2, dim, time_emb_dim=time_dim) self.final_conv = nn.Conv2d(dim, 1, 1) - def forward(self, x, time, classes, x_self_cond=None): + def forward(self, x: torch.Tensor, time, classes, x_self_cond=None) -> torch.Tensor: x = self.init_conv(x) r = x.clone() From f53a2e71d39bd941a8405f761280ee2e9bac82dd Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Tue, 20 Dec 2022 15:16:41 +0000 Subject: [PATCH 12/16] feat: add type hints to ddim --- src/models/diffusion/ddim.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/models/diffusion/ddim.py b/src/models/diffusion/ddim.py index 4dc76522..34c65c33 100644 --- a/src/models/diffusion/ddim.py +++ b/src/models/diffusion/ddim.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +import torch.nn.functional as F import pytorch_lightning as pl @@ -12,12 +13,12 @@ def __init__( model, *, image_size, - timesteps=1000, - use_ddim=False, - noise_schedule="cosine", - time_difference=0.0, - bit_scale=1.0, - ): + timesteps: int = 1000, + use_ddim: bool = False, + noise_schedule: str = "cosine", + time_difference: float = 0.0, + bit_scale: float = 1.0, + ) -> None: super().__init__() self.model = model self.channels = self.model.channels @@ -106,7 +107,7 @@ def sample(self, batch_size=16, classes=None): sample_fn = self.ddpm_sample if not self.use_ddim else self.ddim_sample return sample_fn((batch_size, 8, 4, image_size), classes=classes) # Lucas - def forward(self, img, class_enc, *args, **kwargs): + def forward(self, img, class_enc, *args, **kwargs) -> torch.Tensor: batch, c, h, w, device, img_size, = ( *img.shape, img.device, From fb6bc8618ea0226033f38e0bbfa4c0ec6db95e6a Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Tue, 20 Dec 2022 15:21:45 +0000 Subject: [PATCH 13/16] feat: add type hints to seq dataloader --- src/data/sequence_dataloader.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/data/sequence_dataloader.py b/src/data/sequence_dataloader.py index cab51bb2..31ecbdf2 100644 --- a/src/data/sequence_dataloader.py +++ b/src/data/sequence_dataloader.py @@ -12,11 +12,11 @@ class SequenceDatasetBase(Dataset): def __init__( self, data_path, - sequence_length=200, - sequence_encoding="polar", + sequence_length: int = 200, + sequence_encoding: str = "polar", sequence_transform=None, cell_type_transform=None, - ): + ) -> None: super().__init__() self.data = pd.read_csv(data_path, sep="\t") self.sequence_length = sequence_length @@ -26,7 +26,7 @@ def __init__( self.alphabet = ["A", "C", "T", "G"] self.check_data_validity() - def __len__(self): + def __len__(self) -> int: return len(self.data) def __getitem__(self, index): @@ -45,7 +45,7 @@ def __getitem__(self, index): return X_seq, X_cell_type - def check_data_validity(self): + def check_data_validity(self) -> None: """ Checks if the data is valid. """ @@ -72,7 +72,7 @@ def encode_sequence(self, seq, encoding): return seq # Function for one hot encoding each line of the sequence dataset - def one_hot_encode(self, seq): + def one_hot_encode(self, seq) -> np.ndarray: """ One-hot encoding a sequence """ @@ -84,17 +84,17 @@ def one_hot_encode(self, seq): class SequenceDatasetTrain(SequenceDatasetBase): - def __init__(self, data_path="", **kwargs): + def __init__(self, data_path="", **kwargs) -> None: super().__init__(data_path=data_path, **kwargs) class SequenceDatasetValidation(SequenceDatasetBase): - def __init__(self, data_path="", **kwargs): + def __init__(self, data_path="", **kwargs) -> None: super().__init__(data_path=data_path, **kwargs) class SequenceDatasetTest(SequenceDatasetBase): - def __init__(self, data_path="", **kwargs): + def __init__(self, data_path="", **kwargs) -> None: super().__init__(data_path=data_path, **kwargs) @@ -104,13 +104,13 @@ def __init__( train_path=None, val_path=None, test_path=None, - sequence_length=200, - sequence_encoding="polar", + sequence_length: int = 200, + sequence_encoding: str = "polar", sequence_transform=None, cell_type_transform=None, batch_size=None, - num_workers=1, - ): + num_workers: int = 1, + ) -> None: super().__init__() self.datasets = dict() self.train_dataloader, self.val_dataloader, self.test_dataloader = ( From a3a879301b89bda98b53ba35bee566ae47f2672a Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Tue, 20 Dec 2022 15:29:26 +0000 Subject: [PATCH 14/16] feat: add type hints to unet_lucas_cond --- src/models/networks/unet_lucas_cond.py | 63 +++++++++++++++----------- 1 file changed, 37 insertions(+), 26 deletions(-) diff --git a/src/models/networks/unet_lucas_cond.py b/src/models/networks/unet_lucas_cond.py index 366e3f91..03ef819b 100644 --- a/src/models/networks/unet_lucas_cond.py +++ b/src/models/networks/unet_lucas_cond.py @@ -1,7 +1,7 @@ import math from einops import rearrange from functools import partial -from typing import Optional, List +from typing import Optional, List, Callable import torch from torch import nn, einsum @@ -13,31 +13,31 @@ class Residual(nn.Module): - def __init__(self, fn): + def __init__(self, fn: Callable) -> None: super().__init__() self.fn = fn - def forward(self, x, *args, **kwargs): + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: return self.fn(x, *args, **kwargs) + x -def Upsample(dim, dim_out=None): +def Upsample(dim: int, dim_out: Optional[int] = None): return nn.Sequential( nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(dim, default(dim_out, dim), 3, padding=1), ) -def Downsample(dim, dim_out=None): +def Downsample(dim: int, dim_out: Optional[int] = None): return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1) class LayerNorm(nn.Module): - def __init__(self, dim): + def __init__(self, dim: int) -> None: super().__init__() self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: eps = 1e-5 if x.dtype == torch.float32 else 1e-3 var = torch.var(x, dim=1, unbiased=False, keepdim=True) mean = torch.mean(x, dim=1, keepdim=True) @@ -45,12 +45,12 @@ def forward(self, x): class PreNorm(nn.Module): - def __init__(self, dim, fn): + def __init__(self, dim: int, fn: Callable) -> None: super().__init__() self.fn = fn self.norm = LayerNorm(dim) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.norm(x) return self.fn(x) @@ -63,13 +63,13 @@ class LearnedSinusoidalPosEmb(nn.Module): """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ - def __init__(self, dim): + def __init__(self, dim: int) -> None: super().__init__() assert (dim % 2) == 0 half_dim = dim // 2 self.weights = nn.Parameter(torch.randn(half_dim)) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = rearrange(x, "b -> b 1") freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * math.pi fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) @@ -78,7 +78,7 @@ def forward(self, x): class EmbedFC(nn.Module): - def __init__(self, input_dim, emb_dim): + def __init__(self, input_dim: int, emb_dim: int) -> None: super(EmbedFC, self).__init__() """ generic one layer FC NN for embedding things @@ -87,7 +87,7 @@ def __init__(self, input_dim, emb_dim): layers = [nn.Linear(input_dim, emb_dim), nn.GELU(), nn.Linear(emb_dim, emb_dim)] self.model = nn.Sequential(*layers) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.model(x) @@ -95,13 +95,13 @@ def forward(self, x): class Block(nn.Module): - def __init__(self, dim, dim_out, groups=8): + def __init__(self, dim: int, dim_out: int, groups: int = 8) -> None: super().__init__() self.proj = nn.Conv2d(dim, dim_out, 3, padding=1) self.norm = nn.GroupNorm(groups, dim_out) self.act = nn.SiLU() - def forward(self, x, scale_shift=None): + def forward(self, x: torch.Tensor, scale_shift=None) -> torch.Tensor: x = self.proj(x) x = self.norm(x) @@ -117,7 +117,9 @@ def forward(self, x, scale_shift=None): class ResnetBlock(nn.Module): - def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): + def __init__( + self, dim: int, dim_out: int, *, time_emb_dim=None, groups: int = 8 + ) -> None: super().__init__() self.mlp = ( nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) @@ -129,7 +131,7 @@ def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): self.block2 = Block(dim_out, dim_out, groups=groups) self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() - def forward(self, x, time_emb=None): + def forward(self, x: torch.Tensor, time_emb=None) -> torch.Tensor: scale_shift = None if exists(self.mlp) and exists(time_emb): @@ -149,8 +151,15 @@ def forward(self, x, time_emb=None): class ResnetBlockClassConditioned(ResnetBlock): def __init__( - self, dim, dim_out, *, num_classes, class_embed_dim, time_emb_dim=None, groups=8 - ): + self, + dim: int, + dim_out: int, + *, + num_classes: int, + class_embed_dim: int, + time_emb_dim=None, + groups: int = 8 + ) -> None: super().__init__( dim=dim + class_embed_dim, dim_out=dim_out, @@ -159,7 +168,7 @@ def __init__( ) self.class_mlp = EmbedFC(num_classes, class_embed_dim) - def forward(self, x, time_emb=None, c=None): + def forward(self, x: torch.Tensor, time_emb=None, c=None) -> torch.Tensor: emb_c = self.class_mlp(c) emb_c = emb_c.view(*emb_c.shape, 1, 1) emb_c = emb_c.expand(-1, -1, x.shape[-2], x.shape[-1]) @@ -172,7 +181,7 @@ def forward(self, x, time_emb=None, c=None): class LinearAttention(nn.Module): - def __init__(self, dim, heads=4, dim_head=32): + def __init__(self, dim: int, heads: int = 4, dim_head: int = 32) -> None: super().__init__() self.scale = dim_head**-0.5 self.heads = heads @@ -180,7 +189,7 @@ def __init__(self, dim, heads=4, dim_head=32): self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), LayerNorm(dim)) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: b, c, h, w = x.shape qkv = self.to_qkv(x).chunk(3, dim=1) q, k, v = map( @@ -201,7 +210,9 @@ def forward(self, x): class Attention(nn.Module): - def __init__(self, dim, heads=4, dim_head=32, scale=10): + def __init__( + self, dim: int, heads: int = 4, dim_head: int = 32, scale: int = 10 + ) -> None: super().__init__() self.scale = scale self.heads = heads @@ -209,7 +220,7 @@ def __init__(self, dim, heads=4, dim_head=32, scale=10): self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) self.to_out = nn.Conv2d(hidden_dim, dim, 1) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: b, c, h, w = x.shape qkv = self.to_qkv(x).chunk(3, dim=1) q, k, v = map( @@ -243,7 +254,7 @@ def __init__( learned_sinusoidal_dim: int = 18, num_classes: int = 10, class_embed_dim: bool = 3, - ): + ) -> None: super().__init__() self.channels = channels @@ -321,7 +332,7 @@ def __init__( # Additional code to the https://github.com/lucidrains/bit-diffusion/blob/main/bit_diffusion/bit_diffusion.py mostly in forward method. - def forward(self, x, time, classes, x_self_cond=None): + def forward(self, x: torch.Tensor, time, classes, x_self_cond=None) -> torch.Tensor: x = self.init_conv(x) r = x.clone() From 74193a20f915904d6779598d7502ce44225ac33d Mon Sep 17 00:00:00 2001 From: Matei Bejan <24592776+mateibejan1@users.noreply.github.com> Date: Wed, 21 Dec 2022 10:42:44 +0200 Subject: [PATCH 15/16] Delete ddim.py Deprecated. --- src/models/diffusion/ddim.py | 146 ----------------------------------- 1 file changed, 146 deletions(-) delete mode 100644 src/models/diffusion/ddim.py diff --git a/src/models/diffusion/ddim.py b/src/models/diffusion/ddim.py deleted file mode 100644 index 34c65c33..00000000 --- a/src/models/diffusion/ddim.py +++ /dev/null @@ -1,146 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -import pytorch_lightning as pl - -from models.diffusion.diffusion import DiffusionModel - - -class DDIM(DiffusionModel): - def __init__( - self, - model, - *, - image_size, - timesteps: int = 1000, - use_ddim: bool = False, - noise_schedule: str = "cosine", - time_difference: float = 0.0, - bit_scale: float = 1.0, - ) -> None: - super().__init__() - self.model = model - self.channels = self.model.channels - - self.image_size = image_size - - if noise_schedule == "linear": - self.log_snr = beta_linear_log_snr - elif noise_schedule == "cosine": - self.log_snr = alpha_cosine_log_snr - else: - raise ValueError(f"invalid noise schedule {noise_schedule}") - - self.bit_scale = bit_scale - - self.timesteps = timesteps - self.use_ddim = use_ddim - - # proposed in the paper, summed to time_next - # as a way to fix a deficiency in self-conditioning and lower FID when the number of sampling timesteps is < 400 - - self.time_difference = time_difference - - @property - def device(self): - return next(self.model.parameters()).device - - def get_sampling_timesteps(self, batch, *, device): - times = torch.linspace(1.0, 0.0, self.timesteps + 1, device=device) - times = repeat(times, "t -> b t", b=batch) - times = torch.stack((times[:, :-1], times[:, 1:]), dim=0) - times = times.unbind(dim=-1) - return times - - @torch.no_grad() - def ddim_sample(self, shape, classes, time_difference=None): - batch, device = shape[0], self.device - - time_difference = default(time_difference, self.time_difference) - - time_pairs = self.get_sampling_timesteps(batch, device=device) - img = torch.randn(shape, device=device) - - x_start = None - - for times, times_next in tqdm(time_pairs, desc="sampling loop time step"): - - # get times and noise levels - - log_snr = self.log_snr(times) - log_snr_next = self.log_snr(times_next) - - padded_log_snr, padded_log_snr_next = map( - partial(right_pad_dims_to, img), (log_snr, log_snr_next) - ) - - alpha, sigma = log_snr_to_alpha_sigma(padded_log_snr) - alpha_next, sigma_next = log_snr_to_alpha_sigma(padded_log_snr_next) - - # add the time delay - - times_next = (times_next - time_difference).clamp(min=0.0) - - # predict x0 - - x_start = self.model(img, log_snr, classes, x_start) - - # clip x0 - - x_start.clamp_(-self.bit_scale, self.bit_scale) - - # get predicted noise - - pred_noise = (img - alpha * x_start) / sigma.clamp(min=1e-8) - - # calculate x next - - img = x_start * alpha_next + pred_noise * sigma_next - - return bits_to_decimal(img) - - # TODO: Need to add class conditioned weight - @torch.no_grad() - def sample(self, batch_size=16, classes=None): - image_size, channels = self.image_size, self.channels - sample_fn = self.ddpm_sample if not self.use_ddim else self.ddim_sample - return sample_fn((batch_size, 8, 4, image_size), classes=classes) # Lucas - - def forward(self, img, class_enc, *args, **kwargs) -> torch.Tensor: - batch, c, h, w, device, img_size, = ( - *img.shape, - img.device, - self.image_size, - ) - - times = torch.zeros((batch,), device=device).float().uniform_(0, 0.999) - - # convert image to bit representation - - img = decimal_to_bits(img) * self.bit_scale - - noise = torch.randn_like(img) - - noise_level = self.log_snr(times) - padded_noise_level = right_pad_dims_to(img, noise_level) - alpha, sigma = log_snr_to_alpha_sigma(padded_noise_level) - - noised_img = alpha * img + sigma * noise - - # if doing self-conditioning, 50% of the time, predict x_start from current set of times - # and condition with unet with that - # this technique will slow down training by 25%, but seems to lower FID significantly - - self_cond = None - # #TODO: Does it make sense to self condition with a class? - # if random() < 0.5: - # with torch.no_grad(): - # self_cond = self.model(noised_img, noise_level, class_enc).detach_() - - pred = self.model( - noised_img, noise_level, class_enc, self_cond - ) # BACK TO NOISE_LEVEL - - # return F.mse_loss(pred, img) # LUCAS - return F.smooth_l1_loss(pred, img) # LUCAS ADDED From 3b530727ab93d7175cc32340691007a3c35bac86 Mon Sep 17 00:00:00 2001 From: Matei Bejan <24592776+mateibejan1@users.noreply.github.com> Date: Wed, 21 Dec 2022 10:42:59 +0200 Subject: [PATCH 16/16] Delete unet_bitdiffusion.py Deprecated. --- src/models/networks/unet_bitdiffusion.py | 145 ----------------------- 1 file changed, 145 deletions(-) delete mode 100644 src/models/networks/unet_bitdiffusion.py diff --git a/src/models/networks/unet_bitdiffusion.py b/src/models/networks/unet_bitdiffusion.py deleted file mode 100644 index 31f263f6..00000000 --- a/src/models/networks/unet_bitdiffusion.py +++ /dev/null @@ -1,145 +0,0 @@ -# each layer has a time embedding AND class conditioned embedding - - -class UNet(nn.Module): - def __init__( - self, - dim, - init_dim=None, - dim_mults=(1, 2, 4, 8), - channels: int = 3, - bits=BITS, - resnet_block_groups: int = 8, - learned_sinusoidal_dim: int = 16, - num_classes: int = 10, - class_embed_dim: int = 3, - ) -> None: - super().__init__() - - # determine dimensions - - channels *= bits # lucas - self.channels = channels * 2 - - input_channels = channels * 2 - # input_channels =16 - - init_dim = default(init_dim, dim) - # self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3) # original TODO for zach: is there a difference? - self.init_conv = nn.Conv2d(input_channels, init_dim, (7, 7), padding=3) - - dims = [init_dim, *map(lambda m: dim * m, dim_mults)] - - in_out = list(zip(dims[:-1], dims[1:])) - block_klass = partial( - ResnetBlockClassConditioned, - groups=resnet_block_groups, - num_classes=num_classes, - class_embed_dim=class_embed_dim, - ) - - # time embeddings - - time_dim = dim * 4 - - sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim) - fourier_dim = learned_sinusoidal_dim + 1 - - self.time_mlp = nn.Sequential( - sinu_pos_emb, - nn.Linear(fourier_dim, time_dim), - nn.GELU(), - nn.Linear(time_dim, time_dim), - ) - # layers - - self.downs = nn.ModuleList([]) - self.ups = nn.ModuleList([]) - num_resolutions = len(in_out) - for ind, (dim_in, dim_out) in enumerate(in_out): - is_last = ind >= (num_resolutions - 1) - self.downs.append( - nn.ModuleList( - [ - block_klass(dim_in, dim_in, time_emb_dim=time_dim), - block_klass(dim_in, dim_in, time_emb_dim=time_dim), - Residual(PreNorm(dim_in, LinearAttention(dim_in))), - Downsample(dim_in, dim_out) - if not is_last - else nn.Conv2d(dim_in, dim_out, 3, padding=1), - ] - ) - ) - - mid_dim = dims[-1] - self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) - self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) - self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) - - for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): - is_last = ind == (len(in_out) - 1) - - self.ups.append( - nn.ModuleList( - [ - block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim), - block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim), - Residual(PreNorm(dim_out, LinearAttention(dim_out))), - Upsample(dim_out, dim_in) - if not is_last - else nn.Conv2d(dim_out, dim_in, 3, padding=1), - ] - ) - ) - - self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim) - - # self.final_conv = nn.Conv2d(dim, 1, 1) #lucas - self.final_conv = nn.Conv2d(dim, 8, 1) - - def forward(self, x, time, c, x_self_cond=None) -> torch.Tensor: - # print(x.shape) - # c = torch.zeros_like(c) # removing the conditioning LUCAS - - x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x)) - - x = torch.cat((x_self_cond, x), dim=1) - x = self.init_conv(x) - r = x.clone() - - t = self.time_mlp(time) - - # todo class mask - - h = [] - for i, (block1, block2, attn, downsample) in enumerate(self.downs): - x = block1(x, t, c) - h.append(x) - - x = block2(x, t, c) - - x = attn(x) - h.append(x) - x = downsample(x) - - x = self.mid_block1(x, t, c) - x = self.mid_attn(x) - x = self.mid_block2(x, t, c) - - for block1, block2, attn, upsample in self.ups: - x = torch.cat((x, h.pop()), dim=1) - x = block1(x, t, c) - - x = torch.cat((x, h.pop()), dim=1) - x = block2(x, t, c) - x = attn(x) - - x = upsample(x) - - x = torch.cat((x, r), dim=1) - - x = self.final_res_block(x, t, c) - - x = self.final_conv(x) - # print(x.shape, 'final') - return x