diff --git a/src/data/sequence_dataloader.py b/src/data/sequence_dataloader.py index 438c31d9..31ecbdf2 100644 --- a/src/data/sequence_dataloader.py +++ b/src/data/sequence_dataloader.py @@ -4,11 +4,19 @@ import torch import torchvision.transforms as T import torch.nn.functional as F -import pytorch_lightning as pl -from torch.utils.data import Dataset, DataLoader +import pytorch_lightning as pl +from torch.utils.data import Dataset, DataLoader + class SequenceDatasetBase(Dataset): - def __init__(self, data_path, sequence_length=200, sequence_encoding="polar", sequence_transform=None, cell_type_transform=None): + def __init__( + self, + data_path, + sequence_length: int = 200, + sequence_encoding: str = "polar", + sequence_transform=None, + cell_type_transform=None, + ) -> None: super().__init__() self.data = pd.read_csv(data_path, sep="\t") self.sequence_length = sequence_length @@ -18,18 +26,18 @@ def __init__(self, data_path, sequence_length=200, sequence_encoding="polar", se self.alphabet = ["A", "C", "T", "G"] self.check_data_validity() - def __len__(self): + def __len__(self) -> int: return len(self.data) - + def __getitem__(self, index): # Iterating through DNA sequences from dataset and one-hot encoding all nucleotides current_seq = self.data["raw_sequence"][index] - if 'N' not in current_seq: + if "N" not in current_seq: X_seq = self.encode_sequence(current_seq, encoding=self.sequence_encoding) - + # Reading cell component at current index X_cell_type = self.data["component"][index] - + if self.sequence_transform is not None: X_seq = self.sequence_transform(X_seq) if self.cell_type_transform is not None: @@ -37,7 +45,7 @@ def __getitem__(self, index): return X_seq, X_cell_type - def check_data_validity(self): + def check_data_validity(self) -> None: """ Checks if the data is valid. 
""" @@ -64,7 +72,7 @@ def encode_sequence(self, seq, encoding): return seq # Function for one hot encoding each line of the sequence dataset - def one_hot_encode(self, seq): + def one_hot_encode(self, seq) -> np.ndarray: """ One-hot encoding a sequence """ @@ -76,15 +84,17 @@ def one_hot_encode(self, seq): class SequenceDatasetTrain(SequenceDatasetBase): - def __init__(self, data_path="", **kwargs): + def __init__(self, data_path="", **kwargs) -> None: super().__init__(data_path=data_path, **kwargs) + class SequenceDatasetValidation(SequenceDatasetBase): - def __init__(self, data_path="", **kwargs): + def __init__(self, data_path="", **kwargs) -> None: super().__init__(data_path=data_path, **kwargs) + class SequenceDatasetTest(SequenceDatasetBase): - def __init__(self, data_path="", **kwargs): + def __init__(self, data_path="", **kwargs) -> None: super().__init__(data_path=data_path, **kwargs) @@ -94,16 +104,20 @@ def __init__( train_path=None, val_path=None, test_path=None, - sequence_length=200, - sequence_encoding="polar", + sequence_length: int = 200, + sequence_encoding: str = "polar", sequence_transform=None, cell_type_transform=None, batch_size=None, - num_workers=1 - ): + num_workers: int = 1, + ) -> None: super().__init__() self.datasets = dict() - self.train_dataloader, self.val_dataloader, self.test_dataloader = None, None, None + self.train_dataloader, self.val_dataloader, self.test_dataloader = ( + None, + None, + None, + ) if train_path: self.datasets["train"] = train_path @@ -131,7 +145,7 @@ def setup(self): sequence_length=self.sequence_length, sequence_encoding=self.sequence_encoding, sequence_transform=self.sequence_transform, - cell_type_transform=self.cell_type_transform + cell_type_transform=self.cell_type_transform, ) if "validation" in self.datasets: self.val_data = SequenceDatasetValidation( @@ -139,7 +153,7 @@ def setup(self): sequence_length=self.sequence_length, sequence_encoding=self.sequence_encoding, sequence_transform=self.sequence_transform, - cell_type_transform=self.cell_type_transform + cell_type_transform=self.cell_type_transform, ) if "test" in self.datasets: self.test_data = SequenceDatasetTest( @@ -147,27 +161,32 @@ def setup(self): sequence_length=self.sequence_length, sequence_encoding=self.sequence_encoding, sequence_transform=self.sequence_transform, - cell_type_transform=self.cell_type_transform + cell_type_transform=self.cell_type_transform, ) def _train_dataloader(self): - return DataLoader(self.train_data, - self.batch_size, - shuffle=True, - num_workers=self.num_workers, - pin_memory=True) + return DataLoader( + self.train_data, + self.batch_size, + shuffle=True, + num_workers=self.num_workers, + pin_memory=True, + ) def _val_dataloader(self): - return DataLoader(self.val_data, - self.batch_size, - shuffle=True, - num_workers=self.num_workers, - pin_memory=True) + return DataLoader( + self.val_data, + self.batch_size, + shuffle=True, + num_workers=self.num_workers, + pin_memory=True, + ) def _test_dataloader(self): - return DataLoader(self.test_data, - self.batch_size, - shuffle=True, - num_workers=self.num_workers, - pin_memory=True) - + return DataLoader( + self.test_data, + self.batch_size, + shuffle=True, + num_workers=self.num_workers, + pin_memory=True, + ) diff --git a/src/models/diffusion/ddim.py b/src/models/diffusion/ddim.py deleted file mode 100644 index 046514ad..00000000 --- a/src/models/diffusion/ddim.py +++ /dev/null @@ -1,137 +0,0 @@ -import torch -import torch.nn as nn - -import pytorch_lightning as pl - -from 
models.diffusion.diffusion import DiffusionModel - -class DDIM(DiffusionModel): - def __init__( - self, - model, - *, - image_size, - timesteps = 1000, - use_ddim = False, - noise_schedule = 'cosine', - time_difference = 0., - bit_scale = 1. - ): - super().__init__() - self.model = model - self.channels = self.model.channels - - self.image_size = image_size - - if noise_schedule == "linear": - self.log_snr = beta_linear_log_snr - elif noise_schedule == "cosine": - self.log_snr = alpha_cosine_log_snr - else: - raise ValueError(f'invalid noise schedule {noise_schedule}') - - self.bit_scale = bit_scale - - self.timesteps = timesteps - self.use_ddim = use_ddim - - # proposed in the paper, summed to time_next - # as a way to fix a deficiency in self-conditioning and lower FID when the number of sampling timesteps is < 400 - - self.time_difference = time_difference - - @property - def device(self): - return next(self.model.parameters()).device - - def get_sampling_timesteps(self, batch, *, device): - times = torch.linspace(1., 0., self.timesteps + 1, device = device) - times = repeat(times, 't -> b t', b = batch) - times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0) - times = times.unbind(dim = -1) - return times - - @torch.no_grad() - def ddim_sample(self, shape, classes, time_difference = None): - batch, device = shape[0], self.device - - time_difference = default(time_difference, self.time_difference) - - time_pairs = self.get_sampling_timesteps(batch, device = device) - img = torch.randn(shape, device = device) - - x_start = None - - for times, times_next in tqdm(time_pairs, desc = 'sampling loop time step'): - - # get times and noise levels - - log_snr = self.log_snr(times) - log_snr_next = self.log_snr(times_next) - - padded_log_snr, padded_log_snr_next = map(partial(right_pad_dims_to, img), (log_snr, log_snr_next)) - - alpha, sigma = log_snr_to_alpha_sigma(padded_log_snr) - alpha_next, sigma_next = log_snr_to_alpha_sigma(padded_log_snr_next) - - # add the time delay - - times_next = (times_next - time_difference).clamp(min = 0.) 
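# --- illustrative sketch (not part of the patch): alpha/sigma from the log-SNR ---
# The sampling loop below relies on the continuous-time parameterisation in which
# a noised sample is x_t = alpha * x0 + sigma * eps.  log_snr_to_alpha_sigma is
# defined in src/utils/schedules.py further down in this diff; a minimal, runnable check:
import torch

def log_snr_to_alpha_sigma(log_snr: torch.Tensor):
    # alpha**2 + sigma**2 == 1 for every log-SNR value
    return torch.sqrt(torch.sigmoid(log_snr)), torch.sqrt(torch.sigmoid(-log_snr))

lam = torch.tensor([4.0, 0.0, -4.0])        # high, medium and low signal-to-noise
alpha, sigma = log_snr_to_alpha_sigma(lam)
assert torch.allclose(alpha ** 2 + sigma ** 2, torch.ones_like(lam))
# The DDIM step in the loop predicts x0, recovers eps = (x_t - alpha * x0) / sigma,
# and re-noises with the next (alpha_next, sigma_next) pair.
# -------------------------------------------------------------------------------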
- - # predict x0 - - x_start = self.model(img, log_snr, classes, x_start) - - # clip x0 - - x_start.clamp_(-self.bit_scale, self.bit_scale) - - # get predicted noise - - pred_noise = (img - alpha * x_start) / sigma.clamp(min = 1e-8) - - # calculate x next - - img = x_start * alpha_next + pred_noise * sigma_next - - return bits_to_decimal(img) - - # TODO: Need to add class conditioned weight - @torch.no_grad() - def sample(self, batch_size=16, classes=None): - image_size, channels = self.image_size, self.channels - sample_fn = self.ddpm_sample if not self.use_ddim else self.ddim_sample - return sample_fn((batch_size, 8, 4, image_size), classes=classes) # Lucas - - def forward(self, img, class_enc, *args, **kwargs): - batch, c, h, w, device, img_size, = *img.shape, img.device, self.image_size - - times = torch.zeros((batch,), device = device).float().uniform_(0, 0.999) - - # convert image to bit representation - - img = decimal_to_bits(img) * self.bit_scale - - noise = torch.randn_like(img) - - noise_level = self.log_snr(times) - padded_noise_level = right_pad_dims_to(img, noise_level) - alpha, sigma = log_snr_to_alpha_sigma(padded_noise_level) - - noised_img = alpha * img + sigma * noise - - # if doing self-conditioning, 50% of the time, predict x_start from current set of times - # and condition with unet with that - # this technique will slow down training by 25%, but seems to lower FID significantly - - self_cond = None - # #TODO: Does it make sense to self condition with a class? - # if random() < 0.5: - # with torch.no_grad(): - # self_cond = self.model(noised_img, noise_level, class_enc).detach_() - - - pred = self.model(noised_img, noise_level, class_enc, self_cond) # BACK TO NOISE_LEVEL - - #return F.mse_loss(pred, img) # LUCAS - return F.smooth_l1_loss(pred, img) # LUCAS ADDED \ No newline at end of file diff --git a/src/models/diffusion/ddpm.py b/src/models/diffusion/ddpm.py index a4d2eaa4..bb5867b8 100644 --- a/src/models/diffusion/ddpm.py +++ b/src/models/diffusion/ddpm.py @@ -2,8 +2,7 @@ import torch from torch import nn from torch.nn.functional import F - -from typing import Optional, List +from functools import partial from models.diffusion.diffusion import DiffusionModel @@ -12,36 +11,37 @@ alpha_cosine_log_snr, linear_beta_schedule, ) -from utils.misc import extract, mean_flat +from utils.misc import extract, extract_data_from_batch, mean_flat class DDPM(DiffusionModel): def __init__( self, - model, *, image_size, - timesteps=1000, + timesteps=50, noise_schedule="cosine", + time_difference=0.0, unet_config: dict, is_conditional: bool, + p_uncond: float = 0.1, + use_fp16: bool, logdir: str, - img_size: int, optimizer_config: dict, lr_scheduler_config: dict = None, criterion: nn.Module, use_ema: bool = True, ema_decay: float = 0.9999, lr_warmup=0, - use_p2_weigthing=False, + use_p2_weigthing: bool = False, p2_gamma: float = 0.5, p2_k: float = 1, ): super().__init__( unet_config, is_conditional, + use_fp16, logdir, - img_size, optimizer_config, lr_scheduler_config, criterion, @@ -51,36 +51,40 @@ def __init__( ) self.image_size = image_size - self.timesteps = timesteps if noise_schedule == "linear": self.log_snr = beta_linear_log_snr - self.betas = linear_beta_schedule(timesteps=timesteps, beta_end=0.05) elif noise_schedule == "cosine": - # Refer to Section 3.2 of https://arxiv.org/abs/2102.09672 for details self.log_snr = alpha_cosine_log_snr - self.betas = cosine_beta_schedule(timesteps=timesteps, s=0.0001) else: raise ValueError(f"invalid noise schedule {noise_schedule}") - 
# Define Beta Schedule - self.set_noise_schedule(self.betas) - - if self.use_p2_weighting: - self.p2_gamma = p2_gamma - self.p2_k = p2_k - self.snr = 1.0 / (1 - self.alphas_cumprod) - 1 + self.timesteps = timesteps + self.p_uncond = p_uncond + + # self.betas = cosine_beta_schedule(timesteps=timesteps, s=0.0001) + self.set_noise_schedule(self.betas, self.timesteps) + + # proposed in the paper, summed to time_next + # as a way to fix a deficiency in self-conditioning and lower FID when the number of sampling timesteps is < 400 + + self.time_difference = time_difference + + def set_noise_schedule(self, betas, timesteps): + # define beta schedule + self.betas = linear_beta_schedule(timesteps=timesteps, beta_end=0.05) - def set_noise_schedule(self, betas: torch.Tensor) -> None: # define alphas - alphas = 1.0 - betas - self.alphas_cumprod = torch.cumprod(alphas, axis=0) + alphas = 1.0 - self.betas + alphas_cumprod = torch.cumprod(alphas, axis=0) alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) self.sqrt_recip_alphas = torch.sqrt(1.0 / alphas) # calculations for diffusion q(x_t | x_{t-1}) and others self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) + + # sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod) self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) # calculations for posterior q(x_{t-1} | x_t, x_0) @@ -88,12 +92,10 @@ def set_noise_schedule(self, betas: torch.Tensor) -> None: betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) ) - def q_sample( - self, - x_start: torch.Tensor, - t: torch.Tensor, - noise: Optional[None, torch.Tensor] = None, - ) -> torch.Tensor: + def q_sample(self, x_start, t, noise=None): + """ + Forward pass with noise. + """ if noise is None: noise = torch.randn_like(x_start) @@ -105,102 +107,194 @@ def q_sample( return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise @torch.no_grad() - def p_sample(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + def p_sample(self, x, t, t_index): betas_t = extract(self.betas, t, x.shape) sqrt_one_minus_alphas_cumprod_t = extract( self.sqrt_one_minus_alphas_cumprod, t, x.shape ) - + # print (x.shape, 'x_shape') sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape) - # Eqn 11 from https://arxiv.org/abs/2006.11239v2 + # Equation 11 in the paper + # Use our model (noise predictor) to predict the mean model_mean = sqrt_recip_alphas_t * ( - x - betas_t * self.model(x, t) / sqrt_one_minus_alphas_cumprod_t + x - betas_t * self.model(x, time=t) / sqrt_one_minus_alphas_cumprod_t ) - posterior_variance_t = extract(self.posterior_variance, t, x.shape) - noise = torch.randn_like(x) + if t_index == 0: + return model_mean + else: + posterior_variance_t = extract(self.posterior_variance, t, x.shape) + noise = torch.randn_like(x) + # Algorithm 2 line 4: + return model_mean + torch.sqrt(posterior_variance_t) * noise - return model_mean + torch.sqrt(posterior_variance_t) * noise + @torch.no_grad() + def p_sample_guided(self, x, classes, t, t_index, context_mask, cond_weight=0.0): + # adapted from: https://openreview.net/pdf?id=qw8AKxfYbI + # print (classes[0]) + batch_size = x.shape[0] + # double to do guidance with + t_double = t.repeat(2) + x_double = x.repeat(2, 1, 1, 1) + betas_t = extract(self.betas, t_double, x_double.shape) + sqrt_one_minus_alphas_cumprod_t = extract( + self.sqrt_one_minus_alphas_cumprod, t_double, x_double.shape + ) + sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t_double, x_double.shape) + + # classifier free 
sampling interpolates between guided and non guided using `cond_weight` + classes_masked = classes * context_mask + classes_masked = classes_masked.type(torch.long) + # print ('class masked', classes_masked) + preds = self.model(x_double, time=t_double, classes=classes_masked) + eps1 = (1 + cond_weight) * preds[:batch_size] + eps2 = cond_weight * preds[batch_size:] + x_t = eps1 - eps2 + + # Equation 11 in the paper + # Use our model (noise predictor) to predict the mean + model_mean = sqrt_recip_alphas_t[:batch_size] * ( + x + - betas_t[:batch_size] * x_t / sqrt_one_minus_alphas_cumprod_t[:batch_size] + ) + + if t_index == 0: + return model_mean + else: + posterior_variance_t = extract(self.posterior_variance, t, x.shape) + noise = torch.randn_like(x) + # Algorithm 2 line 4: + return model_mean + torch.sqrt(posterior_variance_t) * noise + # Algorithm 2 but save all images: @torch.no_grad() - def p_sample_loop(self, shape: torch.Size) -> List: + def p_sample_loop(self, classes, shape, cond_weight): device = next(self.model.parameters()).device b = shape[0] # start from pure noise (for each example in the batch) - img = torch.randn(shape, device=device) - imgs = [] + image = torch.randn(shape, device=device) + images = [] + + if classes is not None: + n_sample = classes.shape[0] + context_mask = torch.ones_like(classes).to(device) + # make 0 index unconditional + # double the batch + classes = classes.repeat(2) + context_mask = context_mask.repeat(2) + context_mask[n_sample:] = 0.0 # makes second half of batch context free + sampling_fn = partial( + self.p_sample_guided, + classes=classes, + cond_weight=cond_weight, + context_mask=context_mask, + ) + else: + sampling_fn = partial(self.p_sample) for i in tqdm( reversed(range(0, self.timesteps)), desc="sampling loop time step", total=self.timesteps, ): - img = self.p_sample( - x=img, t=torch.full((b,), i, device=device, dtype=torch.long) + image = sampling_fn( + self.model, + x=image, + t=torch.full((b,), i, device=device, dtype=torch.long), + t_index=i, ) - imgs.append(img.cpu().numpy()) - return imgs + images.append(image.cpu().numpy()) + return images @torch.no_grad() def sample( - self, batch: torch.Tensor, channels: int = 3, nucleotides: int = 4 - ) -> List: + self, image_size, classes=None, batch_size=16, channels=3, cond_weight=0 + ): return self.p_sample_loop( - shape=(batch.shape[0], channels, nucleotides, self.image_size) + self.model, + classes=classes, + shape=(batch_size, channels, 4, image_size), + cond_weight=cond_weight, ) - - def p2_weighting(self, x_t, ts, target, prediction): - """ - From Perception Prioritized Training of Diffusion Models: https://arxiv.org/abs/2204.00227. 
- """ - weight = ( - 1 / (self.p2_k + self.snr) ** self.p2_gamma, ts, x_t.shape - ) - loss_batch = mean_flat(weight * (target - prediction) ** 2) - loss = torch.mean(loss_batch) - return loss - def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor: - if self.training_custom_noise is None: - self.training_custom_noise = torch.randn_like(batch) - x_noisy = self.q_sample( - x_start=batch, t=self.timesteps, noise=self.training_custom_noise - ) + def training_step(self, batch: torch.Tensor, batch_idx: int): + x_start, condition = extract_data_from_batch(batch) + + if noise is None: + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=self.timesteps, noise=noise) # calculating generic loss function, we'll add it to the class constructor once we have the code # we should log more metrics at train and validation e.g. l1, l2 and other suggestions - predicted_noise = self.model(x_noisy, self.timesteps) - - if self.use_p2_weighting: - loss = 0 # Need to figure out how to use p2 weighting in DNA context + if self.use_fp16: + with torch.cuda.amp.autocast(): + if self.is_conditional: + predicted_noise = self.model(x_noisy, self.timesteps, condition) + else: + predicted_noise = self.model(x_noisy, self.timesteps) else: - loss = self.criterion(predicted_noise, self.training_custom_noise) - + if self.is_conditional: + predicted_noise = self.model(x_noisy, self.timesteps, condition) + else: + predicted_noise = self.model(x_noisy, self.timesteps) + + loss = self.criterion(predicted_noise, noise) self.log("train", loss, batch_size=batch.shape[0]) return loss def validation_step(self, batch: torch.Tensor, batch_idx: int): - if self.is_conditional: - return None - else: - return self.inference_step(batch, batch_idx, "validation") + return self.inference_step(batch, batch_idx, "validation") def test_step(self, batch: torch.Tensor, batch_idx: int): - if self.is_conditional: - return None - else: - return self.inference_step(batch, batch_idx, "test") + return self.inference_step(batch, batch_idx, "test") + + def inference_step( + self, batch: torch.Tensor, batch_idx: int, phase="validation", noise=None + ): + x_start, condition = extract_data_from_batch(batch) + device = x_start.device + batch_size = batch.shape[0] + + t = torch.randint( + 0, self.timesteps, (batch_size,), device=device + ).long() # sampling a t to generate t and t+1 + + if noise is None: + noise = torch.randn_like(x_start) # gauss noise + x_noisy = self.q_sample( + x_start=x_start, t=t, noise=noise + ) # this is the auto generated noise given t and Noise + + context_mask = torch.bernoulli( + torch.zeros(classes.shape[0]) + (1 - self.p_uncond) + ).to(device) - def inference_step(self, batch: torch.Tensor, batch_idx: int, phase="validation"): - predictions = self.sample(batch) + # mask for unconditinal guidance + classes = classes * context_mask + classes = classes.type(torch.long) + + predictions = self.model(x_noisy, t, condition) loss = self.criterion(predictions, batch) - self.log("val_loss", loss) if phase == "validation" else self.log( + self.log("validation_loss", loss) if phase == "validation" else self.log( "test_loss", loss ) + """ + Log multiple losses at validation/test time according to internal discussions. + """ + return predictions + + def p2_weighting(self, x_t, ts, target, prediction): + """ + From Perception Prioritized Training of Diffusion Models: https://arxiv.org/abs/2204.00227. 
+ """ + weight = (1 / (self.p2_k + self.snr) ** self.p2_gamma, ts, x_t.shape) + loss_batch = mean_flat(weight * (target - prediction) ** 2) + loss = torch.mean(loss_batch) + return loss diff --git a/src/models/diffusion/diffusion.py b/src/models/diffusion/diffusion.py index ba46f1c0..dc508cce 100644 --- a/src/models/diffusion/diffusion.py +++ b/src/models/diffusion/diffusion.py @@ -3,7 +3,7 @@ import pytorch_lightning as pl from utils.ema import EMA -from utils.misc import instantiate_from_config + class DiffusionModel(pl.LightningModule): def __init__( @@ -11,6 +11,7 @@ def __init__( unet_config: dict, timesteps: int, is_conditional: bool, + use_fp16: bool, logdir: str, image_size: int, optimizer_config: dict, @@ -32,6 +33,7 @@ def __init__( if self.use_ema: self.eps_model_ema = EMA(self.model, decay=ema_decay) self.is_conditional = is_conditional + self.use_fp16 = use_fp16 self.image_size = image_size self.lr_scheduler_config = lr_scheduler_config self.optimizer_config = optimizer_config diff --git a/src/models/networks/unet_bitdiffusion.py b/src/models/networks/unet_bitdiffusion.py deleted file mode 100644 index b45b6491..00000000 --- a/src/models/networks/unet_bitdiffusion.py +++ /dev/null @@ -1,131 +0,0 @@ -# each layer has a time embedding AND class conditioned embedding - -class UNet(nn.Module): - def __init__( - self, - dim, - init_dim = None, - dim_mults=(1, 2, 4, 8), - channels = 3, - bits = BITS, - resnet_block_groups = 8, - learned_sinusoidal_dim = 16, - num_classes=10, - class_embed_dim=3, - ): - super().__init__() - - # determine dimensions - - channels *= bits #lucas - self.channels = channels *2 - - input_channels = channels * 2 - #input_channels =16 - - - init_dim = default(init_dim, dim) - #self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3) # original TODO for zach: is there a difference? 
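            # (Answer to the TODO above: there is no functional difference;
            # nn.Conv2d expands an int kernel_size such as 7 to the tuple (7, 7)
            # internally, so the two lines are equivalent.)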
- self.init_conv = nn.Conv2d(input_channels, init_dim, (7,7), padding = 3) - - dims = [init_dim, *map(lambda m: dim * m, dim_mults)] - - - in_out = list(zip(dims[:-1], dims[1:])) - block_klass = partial(ResnetBlockClassConditioned, groups=resnet_block_groups, - num_classes=num_classes, class_embed_dim=class_embed_dim) - - # time embeddings - - time_dim = dim * 4 - - sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim) - fourier_dim = learned_sinusoidal_dim + 1 - - self.time_mlp = nn.Sequential( - sinu_pos_emb, - nn.Linear(fourier_dim, time_dim), - nn.GELU(), - nn.Linear(time_dim, time_dim) - ) - # layers - - self.downs = nn.ModuleList([]) - self.ups = nn.ModuleList([]) - num_resolutions = len(in_out) - for ind, (dim_in, dim_out) in enumerate(in_out): - is_last = ind >= (num_resolutions - 1) - self.downs.append(nn.ModuleList([ - block_klass(dim_in, dim_in, time_emb_dim = time_dim), - block_klass(dim_in, dim_in, time_emb_dim = time_dim), - Residual(PreNorm(dim_in, LinearAttention(dim_in))), - Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1) - ])) - - mid_dim = dims[-1] - self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) - self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) - self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) - - for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): - is_last = ind == (len(in_out) - 1) - - self.ups.append(nn.ModuleList([ - block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), - block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), - Residual(PreNorm(dim_out, LinearAttention(dim_out))), - Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1) - ])) - - self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim) - - #self.final_conv = nn.Conv2d(dim, 1, 1) #lucas - self.final_conv = nn.Conv2d(dim,8, 1) - - - def forward(self, x, time, c, x_self_cond = None): - #print(x.shape) - #c = torch.zeros_like(c) # removing the conditioning LUCAS - - x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x)) - - x = torch.cat((x_self_cond, x), dim=1) - x = self.init_conv(x) - r = x.clone() - - t = self.time_mlp(time) - - # todo class mask - - h = [] - for i, (block1, block2, attn, downsample) in enumerate(self.downs): - x = block1(x, t, c) - h.append(x) - - x = block2(x, t, c) - - x = attn(x) - h.append(x) - x = downsample(x) - - x = self.mid_block1(x, t, c) - x = self.mid_attn(x) - x = self.mid_block2(x, t, c) - - for block1, block2, attn, upsample in self.ups: - x = torch.cat((x, h.pop()), dim = 1) - x = block1(x, t, c) - - x = torch.cat((x, h.pop()), dim = 1) - x = block2(x, t, c) - x = attn(x) - - x = upsample(x) - - x = torch.cat((x, r), dim = 1) - - x = self.final_res_block(x, t, c) - - x = self.final_conv(x) - #print(x.shape, 'final') - return x \ No newline at end of file diff --git a/src/models/networks/unet_lucas.py b/src/models/networks/unet_lucas.py index ddef16a1..27a4e16d 100644 --- a/src/models/networks/unet_lucas.py +++ b/src/models/networks/unet_lucas.py @@ -1,7 +1,7 @@ import math from functools import partial from einops import rearrange -from typing import Optional, List +from typing import Optional, List, Callable import torch from torch import nn, einsum @@ -11,65 +11,66 @@ class Residual(nn.Module): - def __init__(self, fn): + def __init__(self, fn: Callable) -> None: super().__init__() self.fn = fn - def forward(self, x, *args, **kwargs): + def forward(self, x: 
torch.Tensor, *args, **kwargs) -> torch.Tensor: return self.fn(x, *args, **kwargs) + x - + class LayerNorm(nn.Module): - def __init__(self, dim): + def __init__(self, dim: int) -> None: super().__init__() self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: eps = 1e-5 if x.dtype == torch.float32 else 1e-3 - var = torch.var(x, dim = 1, unbiased = False, keepdim = True) - mean = torch.mean(x, dim = 1, keepdim = True) + var = torch.var(x, dim=1, unbiased=False, keepdim=True) + mean = torch.mean(x, dim=1, keepdim=True) return (x - mean) * (var + eps).rsqrt() * self.g - + class PreNorm(nn.Module): - def __init__(self, dim, fn): + def __init__(self, dim: int, fn) -> None: super().__init__() self.fn = fn self.norm = LayerNorm(dim) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.norm(x) return self.fn(x) - + # positional embeds class LearnedSinusoidalPositionalEmbedding(nn.Module): - """ following @crowsonkb 's lead with learned sinusoidal pos emb """ + """following @crowsonkb 's lead with learned sinusoidal pos emb""" + """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ - def __init__(self, dim): + def __init__(self, dim: int) -> None: super().__init__() assert (dim % 2) == 0 half_dim = dim // 2 self.weights = nn.Parameter(torch.randn(half_dim)) - def forward(self, x): - x = rearrange(x, 'b -> b 1') - freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi - fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1) - fouriered = torch.cat((x, fouriered), dim = -1) + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = rearrange(x, "b -> b 1") + freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * math.pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((x, fouriered), dim=-1) return fouriered - + # building block modules class Block(nn.Module): - def __init__(self, dim, dim_out, groups = 8): + def __init__(self, dim: int, dim_out: int, groups: int = 8) -> None: super().__init__() - self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1) + self.proj = nn.Conv2d(dim, dim_out, 3, padding=1) self.norm = nn.GroupNorm(groups, dim_out) self.act = nn.SiLU() - def forward(self, x, scale_shift = None): + def forward(self, x: torch.Tensor, scale_shift=None) -> torch.Tensor: x = self.proj(x) x = self.norm(x) @@ -80,84 +81,95 @@ def forward(self, x, scale_shift = None): x = self.act(x) return x - + class ResnetBlock(nn.Module): - def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8): + def __init__( + self, + dim: int, + dim_out: int, + *, + time_emb_dim: Optional[int] = None, + groups: int = 8 + ) -> None: super().__init__() - self.mlp = nn.Sequential( - nn.SiLU(), - nn.Linear(time_emb_dim, dim_out * 2) - ) if exists(time_emb_dim) else None + self.mlp = ( + nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) + if exists(time_emb_dim) + else None + ) - self.block1 = Block(dim, dim_out, groups = groups) - self.block2 = Block(dim_out, dim_out, groups = groups) + self.block1 = Block(dim, dim_out, groups=groups) + self.block2 = Block(dim_out, dim_out, groups=groups) self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() - def forward(self, x, time_emb = None): + def forward(self, x: torch.Tensor, time_emb=None) -> torch.Tensor: scale_shift = None if exists(self.mlp) and exists(time_emb): time_emb = self.mlp(time_emb) - time_emb = rearrange(time_emb, 'b c -> b 
c 1 1') - scale_shift = time_emb.chunk(2, dim = 1) + time_emb = rearrange(time_emb, "b c -> b c 1 1") + scale_shift = time_emb.chunk(2, dim=1) - h = self.block1(x, scale_shift = scale_shift) + h = self.block1(x, scale_shift=scale_shift) h = self.block2(h) return h + self.res_conv(x) - + class LinearAttention(nn.Module): - def __init__(self, dim, heads = 4, dim_head = 32): + def __init__(self, dim: int, heads: int = 4, dim_head: int = 32) -> None: super().__init__() - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 self.heads = heads hidden_dim = dim_head * heads - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) - self.to_out = nn.Sequential( - nn.Conv2d(hidden_dim, dim, 1), - LayerNorm(dim) - ) + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), LayerNorm(dim)) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: b, c, h, w = x.shape - qkv = self.to_qkv(x).chunk(3, dim = 1) - q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv) + qkv = self.to_qkv(x).chunk(3, dim=1) + q, k, v = map( + lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv + ) - q = q.softmax(dim = -2) - k = k.softmax(dim = -1) + q = q.softmax(dim=-2) + k = k.softmax(dim=-1) q = q * self.scale v = v / (h * w) - context = torch.einsum('b h d n, b h e n -> b h d e', k, v) + context = torch.einsum("b h d n, b h e n -> b h d e", k, v) - out = torch.einsum('b h d e, b h d n -> b h e n', context, q) - out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w) + out = torch.einsum("b h d e, b h d n -> b h e n", context, q) + out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w) return self.to_out(out) - + class Attention(nn.Module): - def __init__(self, dim, heads = 4, dim_head = 32, scale = 10): + def __init__( + self, dim: int, heads: int = 4, dim_head: int = 32, scale: int = 10 + ) -> None: super().__init__() self.scale = scale self.heads = heads hidden_dim = dim_head * heads - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) self.to_out = nn.Conv2d(hidden_dim, dim, 1) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: b, c, h, w = x.shape - qkv = self.to_qkv(x).chunk(3, dim = 1) - q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv) + qkv = self.to_qkv(x).chunk(3, dim=1) + q, k, v = map( + lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv + ) q, k = map(l2norm, (q, k)) - sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale - attn = sim.softmax(dim = -1) - out = einsum('b h i j, b h d j -> b h i d', attn, v) - out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w) + sim = einsum("b h d i, b h d j -> b h i j", q, k) * self.scale + attn = sim.softmax(dim=-1) + out = einsum("b h i j, b h d j -> b h i d", attn, v) + out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) return self.to_out(out) @@ -167,12 +179,12 @@ def __init__( dim: int, init_dim: int = None, dim_mults: Optional[int, list] = (1, 2, 4), - channels: int = 1, + channels: int = 1, resnet_block_groups: int = 8, learned_sinusoidal_dim: int = 18, num_classes: int = 10, self_conditioned: bool = False, - ): + ) -> None: super().__init__() channels = 1 @@ -183,7 +195,7 @@ def __init__( input_channels = channels * 2 init_dim = default(init_dim, dim) - self.init_conv = 
nn.Conv2d(input_channels, init_dim, (7,7), padding=3) + self.init_conv = nn.Conv2d(input_channels, init_dim, (7, 7), padding=3) dims = [init_dim, *map(lambda m: dim * m, dim_mults)] @@ -195,14 +207,14 @@ def __init__( sinu_pos_emb = LearnedSinusoidalPositionalEmbedding(learned_sinusoidal_dim) fourier_dim = learned_sinusoidal_dim + 1 - + self.time_mlp = nn.Sequential( sinu_pos_emb, nn.Linear(fourier_dim, time_dim), nn.GELU(), - nn.Linear(time_dim, time_dim) + nn.Linear(time_dim, time_dim), ) - + if num_classes is not None: self.label_emb = nn.Embedding(num_classes, time_dim) @@ -213,40 +225,51 @@ def __init__( for index, (dim_in, dim_out) in enumerate(in_out): is_last = index >= (num_resolutions - 1) - self.downs.append(nn.ModuleList([ - resnet_block(dim_in, dim_in, time_emb_dim = time_dim), - resnet_block(dim_in, dim_in, time_emb_dim = time_dim), - - (PreNorm(dim_in, LinearAttention(dim_in))), - Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding=1) - ])) + self.downs.append( + nn.ModuleList( + [ + resnet_block(dim_in, dim_in, time_emb_dim=time_dim), + resnet_block(dim_in, dim_in, time_emb_dim=time_dim), + (PreNorm(dim_in, LinearAttention(dim_in))), + Downsample(dim_in, dim_out) + if not is_last + else nn.Conv2d(dim_in, dim_out, 3, padding=1), + ] + ) + ) mid_dim = dims[-1] - self.mid_block1 = resnet_block(mid_dim, mid_dim, time_emb_dim = time_dim) + self.mid_block1 = resnet_block(mid_dim, mid_dim, time_emb_dim=time_dim) self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) - self.mid_block2 = resnet_block(mid_dim, mid_dim, time_emb_dim = time_dim) + self.mid_block2 = resnet_block(mid_dim, mid_dim, time_emb_dim=time_dim) for index, (dim_in, dim_out) in enumerate(reversed(in_out)): is_last = index == (len(in_out) - 1) - self.ups.append(nn.ModuleList([ - resnet_block(dim_out + dim_in, dim_out, time_emb_dim = time_dim), - resnet_block(dim_out + dim_in, dim_out, time_emb_dim = time_dim), - Residual(PreNorm(dim_out, LinearAttention(dim_out))), - Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding=1) - ])) - - self.final_res_block = resnet_block(dim * 2, dim, time_emb_dim = time_dim) + self.ups.append( + nn.ModuleList( + [ + resnet_block(dim_out + dim_in, dim_out, time_emb_dim=time_dim), + resnet_block(dim_out + dim_in, dim_out, time_emb_dim=time_dim), + Residual(PreNorm(dim_out, LinearAttention(dim_out))), + Upsample(dim_out, dim_in) + if not is_last + else nn.Conv2d(dim_out, dim_in, 3, padding=1), + ] + ) + ) + + self.final_res_block = resnet_block(dim * 2, dim, time_emb_dim=time_dim) self.final_conv = nn.Conv2d(dim, 1, 1) - - def forward(self, x, time, classes, x_self_cond = None): + + def forward(self, x: torch.Tensor, time, classes, x_self_cond=None) -> torch.Tensor: x = self.init_conv(x) r = x.clone() t_start = self.time_mlp(time) t_mid = t_start.clone() t_end = t_start.clone() - + if classes is not None: t_start += self.label_emb(classes) t_mid += self.label_emb(classes) @@ -269,16 +292,16 @@ def forward(self, x, time, classes, x_self_cond = None): x = self.mid_block2(x, t_mid) for block1, block2, attn, upsample in self.ups: - x = torch.cat((x, h.pop()), dim = 1) + x = torch.cat((x, h.pop()), dim=1) x = block1(x, t_mid) - x = torch.cat((x, h.pop()), dim = 1) + x = torch.cat((x, h.pop()), dim=1) x = block2(x, t_mid) x = attn(x) x = upsample(x) - x = torch.cat((x, r), dim = 1) + x = torch.cat((x, r), dim=1) x = self.final_res_block(x, t_end) x = self.final_conv(x) diff --git a/src/models/networks/unet_lucas_cond.py 
b/src/models/networks/unet_lucas_cond.py new file mode 100644 index 00000000..03ef819b --- /dev/null +++ b/src/models/networks/unet_lucas_cond.py @@ -0,0 +1,378 @@ +import math +from einops import rearrange +from functools import partial +from typing import Optional, List, Callable + +import torch +from torch import nn, einsum + +from utils.misc import default, exists +from utils.network import l2norm, Upsample, Downsample + +# Building blocks of UNET + + +class Residual(nn.Module): + def __init__(self, fn: Callable) -> None: + super().__init__() + self.fn = fn + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + return self.fn(x, *args, **kwargs) + x + + +def Upsample(dim: int, dim_out: Optional[int] = None): + return nn.Sequential( + nn.Upsample(scale_factor=2, mode="nearest"), + nn.Conv2d(dim, default(dim_out, dim), 3, padding=1), + ) + + +def Downsample(dim: int, dim_out: Optional[int] = None): + return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1) + + +class LayerNorm(nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + eps = 1e-5 if x.dtype == torch.float32 else 1e-3 + var = torch.var(x, dim=1, unbiased=False, keepdim=True) + mean = torch.mean(x, dim=1, keepdim=True) + return (x - mean) * (var + eps).rsqrt() * self.g + + +class PreNorm(nn.Module): + def __init__(self, dim: int, fn: Callable) -> None: + super().__init__() + self.fn = fn + self.norm = LayerNorm(dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm(x) + return self.fn(x) + + +# Building blocks of UNET, positional embeddings + + +class LearnedSinusoidalPosEmb(nn.Module): + """following @crowsonkb 's lead with learned sinusoidal pos emb""" + + """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ + + def __init__(self, dim: int) -> None: + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = rearrange(x, "b -> b 1") + freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * math.pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((x, fouriered), dim=-1) + return fouriered + + +class EmbedFC(nn.Module): + def __init__(self, input_dim: int, emb_dim: int) -> None: + super(EmbedFC, self).__init__() + """ + generic one layer FC NN for embedding things + """ + self.input_dim = input_dim + layers = [nn.Linear(input_dim, emb_dim), nn.GELU(), nn.Linear(emb_dim, emb_dim)] + self.model = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.model(x) + + +# Building blocks of UNET, convolution + group norm blocks + + +class Block(nn.Module): + def __init__(self, dim: int, dim_out: int, groups: int = 8) -> None: + super().__init__() + self.proj = nn.Conv2d(dim, dim_out, 3, padding=1) + self.norm = nn.GroupNorm(groups, dim_out) + self.act = nn.SiLU() + + def forward(self, x: torch.Tensor, scale_shift=None) -> torch.Tensor: + x = self.proj(x) + x = self.norm(x) + + if exists(scale_shift): + scale, shift = scale_shift + x = x * (scale + 1) + shift + + x = self.act(x) + return x + + +# Building blocks of UNET, residual blocks + + +class ResnetBlock(nn.Module): + def __init__( + self, dim: int, dim_out: int, *, time_emb_dim=None, groups: int = 8 + ) -> None: + super().__init__() + self.mlp = ( + nn.Sequential(nn.SiLU(), 
nn.Linear(time_emb_dim, dim_out * 2)) + if exists(time_emb_dim) + else None + ) + + self.block1 = Block(dim, dim_out, groups=groups) + self.block2 = Block(dim_out, dim_out, groups=groups) + self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() + + def forward(self, x: torch.Tensor, time_emb=None) -> torch.Tensor: + + scale_shift = None + if exists(self.mlp) and exists(time_emb): + time_emb = self.mlp(time_emb) + time_emb = rearrange(time_emb, "b c -> b c 1 1") + scale_shift = time_emb.chunk(2, dim=1) + + h = self.block1(x, scale_shift=scale_shift) + + h = self.block2(h) + + return h + self.res_conv(x) + + +# Additional code to the https://github.com/lucidrains/bit-diffusion/blob/main/bit_diffusion/bit_diffusion.py + + +class ResnetBlockClassConditioned(ResnetBlock): + def __init__( + self, + dim: int, + dim_out: int, + *, + num_classes: int, + class_embed_dim: int, + time_emb_dim=None, + groups: int = 8 + ) -> None: + super().__init__( + dim=dim + class_embed_dim, + dim_out=dim_out, + time_emb_dim=time_emb_dim, + groups=groups, + ) + self.class_mlp = EmbedFC(num_classes, class_embed_dim) + + def forward(self, x: torch.Tensor, time_emb=None, c=None) -> torch.Tensor: + emb_c = self.class_mlp(c) + emb_c = emb_c.view(*emb_c.shape, 1, 1) + emb_c = emb_c.expand(-1, -1, x.shape[-2], x.shape[-1]) + x = torch.cat([x, emb_c], axis=1) + + return super().forward(x, time_emb) + + +# Building blocks of UNET, attention modules + + +class LinearAttention(nn.Module): + def __init__(self, dim: int, heads: int = 4, dim_head: int = 32) -> None: + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), LayerNorm(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, c, h, w = x.shape + qkv = self.to_qkv(x).chunk(3, dim=1) + q, k, v = map( + lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv + ) + + q = q.softmax(dim=-2) + k = k.softmax(dim=-1) + + q = q * self.scale + v = v / (h * w) + + context = torch.einsum("b h d n, b h e n -> b h d e", k, v) + + out = torch.einsum("b h d e, b h d n -> b h e n", context, q) + out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w) + return self.to_out(out) + + +class Attention(nn.Module): + def __init__( + self, dim: int, heads: int = 4, dim_head: int = 32, scale: int = 10 + ) -> None: + super().__init__() + self.scale = scale + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, c, h, w = x.shape + qkv = self.to_qkv(x).chunk(3, dim=1) + q, k, v = map( + lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv + ) + + q, k = map(l2norm, (q, k)) + + sim = einsum("b h d i, b h d j -> b h i j", q, k) * self.scale + attn = sim.softmax(dim=-1) + out = einsum("b h i j, b h d j -> b h i d", attn, v) + out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) + return self.to_out(out) + + +# Core part of UNET + + +class UNet(nn.Module): + """ + Refer to the main paper for the architecture details https://arxiv.org/pdf/2208.04202.pdf + """ + + def __init__( + self, + dim: int, + init_dim: int = 200, + dim_mults: Optional[list] = [1, 2, 4], + channels=1, + resnet_block_groups: int = 8, + learned_sinusoidal_dim: int = 18, + num_classes: int = 10, + 
class_embed_dim: bool = 3, + ) -> None: + super().__init__() + + self.channels = channels + # if you want to do self conditioning uncomment this + # input_channels = channels * 2 + input_channels = channels + + init_dim = default(init_dim, dim) + self.init_conv = nn.Conv2d(input_channels, init_dim, (7, 7), padding=3) + dims = [init_dim, *map(lambda m: dim * m, dim_mults)] + in_out = list(zip(dims[:-1], dims[1:])) + + block_klass = partial(ResnetBlock, groups=resnet_block_groups) + + time_dim = dim * 4 + + sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim) + fourier_dim = learned_sinusoidal_dim + 1 + + self.time_mlp = nn.Sequential( + sinu_pos_emb, + nn.Linear(fourier_dim, time_dim), + nn.GELU(), + nn.Linear(time_dim, time_dim), + ) + + if num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_dim) + + # layers + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + num_resolutions = len(in_out) + + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (num_resolutions - 1) + + self.downs.append( + nn.ModuleList( + [ + block_klass(dim_in, dim_in, time_emb_dim=time_dim), + block_klass(dim_in, dim_in, time_emb_dim=time_dim), + Residual(PreNorm(dim_in, LinearAttention(dim_in))), + Downsample(dim_in, dim_out) + if not is_last + else nn.Conv2d(dim_in, dim_out, 3, padding=1), + ] + ) + ) + + mid_dim = dims[-1] + self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) + self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) + self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) + + for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): + is_last = ind == (len(in_out) - 1) + + self.ups.append( + nn.ModuleList( + [ + block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim), + block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim), + Residual(PreNorm(dim_out, LinearAttention(dim_out))), + Upsample(dim_out, dim_in) + if not is_last + else nn.Conv2d(dim_out, dim_in, 3, padding=1), + ] + ) + ) + + self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim) + self.final_conv = nn.Conv2d(dim, 1, 1) + print("final", dim, channels, self.final_conv) + + # Additional code to the https://github.com/lucidrains/bit-diffusion/blob/main/bit_diffusion/bit_diffusion.py mostly in forward method. 
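# --- illustrative usage sketch (not part of the patch; the shapes here are
#     assumptions drawn from elsewhere in this diff: 1 channel, 4 nucleotides x
#     200 positions, 50 timesteps, 16 cell-type classes) ---
#   model   = UNet(dim=200, channels=1, num_classes=16)
#   x       = torch.randn(8, 1, 4, 200)              # noised one-hot sequence "image"
#   t       = torch.randint(0, 50, (8,)).float()     # diffusion timestep per sample
#   classes = torch.randint(0, 16, (8,))             # cell-type index per sample
#   eps_hat = model(x, time=t, classes=classes)      # predicted noise, same shape as x
# The forward pass below injects the condition by adding label_emb(classes) to the
# learned sinusoidal time embedding at the down, middle and up stages.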
+ + def forward(self, x: torch.Tensor, time, classes, x_self_cond=None) -> torch.Tensor: + x = self.init_conv(x) + r = x.clone() + + t_start = self.time_mlp(time) + t_mid = t_start.clone() + t_end = t_start.clone() + + if classes is not None: + t_start += self.label_emb(classes) + t_mid += self.label_emb(classes) + t_end += self.label_emb(classes) + + h = [] + + for block1, block2, attn, downsample in self.downs: + x = block1(x, t_start) + h.append(x) + + x = block2(x, t_start) + x = attn(x) + h.append(x) + + x = downsample(x) + + x = self.mid_block1(x, t_mid) + x = self.mid_attn(x) + x = self.mid_block2(x, t_mid) + + for block1, block2, attn, upsample in self.ups: + x = torch.cat((x, h.pop()), dim=1) + x = block1(x, t_mid) + + x = torch.cat((x, h.pop()), dim=1) + x = block2(x, t_mid) + x = attn(x) + + x = upsample(x) + + x = torch.cat((x, r), dim=1) + x = self.final_res_block(x, t_end) + + x = self.final_conv(x) + return x diff --git a/src/train.py b/src/train.py index 1da73cbf..e88cc029 100644 --- a/src/train.py +++ b/src/train.py @@ -5,15 +5,38 @@ from omegaconf import DictConfig, OmegaConf from hydra.utils import instantiate +import argparse logger = logging.getLogger() +def get_parser(**parser_kwargs): + parser = argparse.ArgumentParser(**parser_kwargs) + parser.add_argument( + "--logdir", type=str, default="logs", help="where to save logs and ckpts" + ) + parser.add_argument("--name", type=str, default="dummy", help="postfix for logdir") + parser.add_argument( + "--resume", + type=str, + default="", + help="resume training from given folder or checkpoint", + ) + return parser + + @hydra.main(config_path="configs", config_name="train") def train(cfg: DictConfig): + parser = get_parser() + # Keeping track of current config settings in logger logger.info(f"Training with config:\n{OmegaConf.to_yaml(cfg)}") - run = wandb.init(project=cfg.logger.wandb.project, config=cfg) + run = wandb.init( + name=parser.logdir, + save_dir=parser.logdir, + project=cfg.logger.wandb.project, + config=cfg, + ) # Placeholder for what loss or metric values we plan to track with wandb wandb.log({"loss": loss}) @@ -32,7 +55,7 @@ def train(cfg: DictConfig): callbacks=cfg.callbacks, accelerator=cfg.accelerator, devices=cfg.devices, - logger=cfg.logger.wandb + logger=cfg.logger.wandb, ) trainer.fit(model, train_dl, val_dl) diff --git a/src/utils/ema.py b/src/utils/ema.py index 0855c92c..b5642e83 100644 --- a/src/utils/ema.py +++ b/src/utils/ema.py @@ -1,12 +1,13 @@ -#https://github.com/dome272/Diffusion-Models-pytorch/blob/main/modules.py -class EMA: +class EMA: def __init__(self, beta): super().__init__() self.beta = beta self.step = 0 def update_model_average(self, ma_model, current_model): - for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): + for current_params, ma_params in zip( + current_model.parameters(), ma_model.parameters() + ): old_weight, up_weight = ma_params.data, current_params.data ma_params.data = self.update_average(old_weight, up_weight) diff --git a/src/utils/metrics.py b/src/utils/metrics.py new file mode 100644 index 00000000..5c6c83e5 --- /dev/null +++ b/src/utils/metrics.py @@ -0,0 +1,201 @@ +import torch +import pandas as pd +import numpy as np +from scipy.special import rel_entr +import seaborn as sns +import matplotlib.pyplot as plt +from tqdm.auto import tqdm +import os +from typing import Callable, Dict + + +def motif_scoring_KL_divergence( + original: pd.Series, generated: pd.Series +) -> torch.Tensor: + + """ + This function encapsulates the 
logic of evaluating the KL divergence metric + between two sequences. + Returns + ------- + kl_divergence: Float + The KL divergence between the input and output (generated) + sequences' distribution + """ + + kl_pq = rel_entr(original, generated) + return np.sum(kl_pq) + + +def compare_motif_list( + df_motifs_a: pd.DataFrame, + df_motifs_b: pd.DataFrame, + motif_scoring_metric: Callable = motif_scoring_KL_divergence, + plot_motif_probs: bool = False, +) -> torch.Tensor: + """ + This function encapsulates the logic of evaluating the difference between the distribution + of frequencies between generated (diffusion/df_motifs_a) and the input (training/df_motifs_b) for an arbitrary metric ("motif_scoring_metric") + + Please note that some metrics, like KL_divergence, are not metrics in official sense. Reason + for that is that they dont satisfy certain properties, such as in KL case, the simmetry property. + Hence it makes a big difference what are the positions of input. + """ + set_all_mot = set( + df_motifs_a.index.values.tolist() + df_motifs_b.index.values.tolist() + ) + create_new_matrix = [] + for x in set_all_mot: + list_in = [] + list_in.append(x) # adding the name + if x in df_motifs_a.index: + list_in.append(df_motifs_a.loc[x][0]) + else: + list_in.append(1) + + if x in df_motifs_b.index: + list_in.append(df_motifs_b.loc[x][0]) + else: + list_in.append(1) + + create_new_matrix.append(list_in) + + df_motifs = pd.DataFrame(create_new_matrix, columns=["motif", "motif_a", "motif_b"]) + + df_motifs["Diffusion_seqs"] = df_motifs["motif_a"] / df_motifs["motif_a"].sum() + df_motifs["Training_seqs"] = df_motifs["motif_b"] / df_motifs["motif_b"].sum() + if plot_motif_probs: + plt.rcParams["figure.figsize"] = (3, 3) + sns.regplot(x="Diffusion_seqs", y="Training_seqs", data=df_motifs) + plt.xlabel("Diffusion Seqs") + plt.ylabel("Training Seqs") + plt.title("Motifs Probs") + plt.show() + + return motif_scoring_metric( + df_motifs["Diffusion_seqs"].values, df_motifs["Training_seqs"].values + ) + + +def sampling_to_metric( + model, + cell_types, + image_size, + nucleotides, + number_of_samples=20, + specific_group=False, + group_number=None, + cond_weight_to_metric=0, +): + """ + Might need to add to the DDPM class since if we can't call the sample() method outside PyTorch Lightning. + + This function encapsulates the logic of sampling from the trained model in order to generate counts of the motifs. + The reasoning is that we are interested only in calculating the evaluation metric + for the count of occurances and not the nucleic acids themselves. 
+ """ + final_sequences = [] + for n_a in tqdm(range(number_of_samples)): + sample_bs = 10 + if specific_group: + sampled = torch.from_numpy(np.array([group_number] * sample_bs)) + print("specific") + else: + sampled = torch.from_numpy(np.random.choice(cell_types, sample_bs)) + + random_classes = sampled.float().cuda() + sampled_images = model.sample( + classes=random_classes, + image_size=image_size, + batch_size=sample_bs, + channels=1, + cond_weight=cond_weight_to_metric, + ) + for n_b, x in enumerate(sampled_images[-1]): + seq_final = f">seq_test_{n_a}_{n_b}\n" + "".join( + [nucleotides[s] for s in np.argmax(x.reshape(4, 200), axis=0)] + ) + final_sequences.append(seq_final) + + save_motifs_syn = open("synthetic_motifs.fasta", "w") + + save_motifs_syn.write("\n".join(final_sequences)) + save_motifs_syn.close() + + # Scan for motifs + os.system( + "gimme scan synthetic_motifs.fasta -p JASPAR2020_vertebrates -g hg38 > syn_results_motifs.bed" + ) + df_results_syn = pd.read_csv( + "syn_results_motifs.bed", sep="\t", skiprows=5, header=None + ) + df_results_syn["motifs"] = df_results_syn[8].apply( + lambda x: x.split('motif_name "')[1].split('"')[0] + ) + df_results_syn[0] = df_results_syn[0].apply(lambda x: "_".join(x.split("_")[:-1])) + df_motifs_count_syn = ( + df_results_syn[[0, "motifs"]].drop_duplicates().groupby("motifs").count() + ) + plt.rcParams["figure.figsize"] = (30, 2) + df_motifs_count_syn.sort_values(0, ascending=False).head(50)[0].plot.bar() + plt.show() + + return df_motifs_count_syn + + +def metric_comparison_between_components( + original_data: Dict, + generated_data: Dict, + x_label_plot: str, + y_label_plot: str, + cell_components, +) -> None: + """ + This functions takes as inputs dictionaries, which contain as keys different components (cell types) + and as values the distribution of occurances of different motifs. These two dictionaries represent two different datasets, i.e. + generated dataset and the input (train) dataset. + + The goal is to then plot a the main evaluation metric (KL or otherwise) across all different types of cell types + in a heatmap fashion. 
+ """ + ENUMARATED_CELL_NAME = """7 Trophoblasts + 5 CD8_cells + 15 CD34_cells + 9 Fetal_heart + 12 Fetal_muscle + 14 HMVEC(vascular) + 3 hESC(Embryionic) + 8 Fetal(Neural) + 13 Intestine + 2 Skin(stromalA) + 4 Fibroblast(stromalB) + 6 Renal(Cancer) + 16 Esophageal(Cancer) + 11 Fetal_Lung + 10 Fetal_kidney + 1 Tissue_Invariant""".split( + "\n" + ) + CELL_NAMES = {int(x.split(" ")[0]): x.split(" ")[1] for x in ENUMARATED_CELL_NAME} + + final_comparison_all_components = [] + for components_1, motif_occurance_frequency in original_data.items(): + comparisons_single_component = [] + for components_2 in generated_data.keys(): + compared_motifs_occurances = compare_motif_list( + motif_occurance_frequency, generated_data[components_2] + ) + comparisons_single_component.append(compared_motifs_occurances) + + final_comparison_all_components.append(comparisons_single_component) + + plt.rcParams["figure.figsize"] = (10, 10) + df_plot = pd.DataFrame(final_comparison_all_components) + df_plot.columns = [CELL_NAMES[x] for x in cell_components] + df_plot.index = df_plot.columns + sns.heatmap(df_plot, cmap="Blues_r", annot=True, lw=0.1, vmax=1, vmin=0) + plt.title( + f"Kl divergence \n {x_label_plot} sequences x {y_label_plot} sequences \n MOTIFS probabilities" + ) + plt.xlabel(f"{x_label_plot} Sequences \n(motifs dist)") + plt.ylabel(f"{y_label_plot} \n (motifs dist)") diff --git a/src/utils/misc.py b/src/utils/misc.py index 4ba9eda1..6ede566e 100644 --- a/src/utils/misc.py +++ b/src/utils/misc.py @@ -1,7 +1,23 @@ import math import importlib - import torch +import random +import os +import numpy as np + + +def seed_everything(seed: int) -> None: + """ " + Seed everything. + """ + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + def exists(x): return x is not None @@ -44,8 +60,22 @@ def convert_image_to(img_type, image): return image -def log(t, eps = 1e-20): - return torch.log(t.clamp(min = eps)) +def one_hot_encode(seq, nucleotides, max_seq_len: int) -> np.ndarray: + """ + One-hot encode a sequence of nucleotides. + """ + seq_len = len(seq) + seq_array = np.zeros((max_seq_len, len(nucleotides))) + for i in range(seq_len): + seq_array[i, nucleotides.index(seq[i])] = 1 + return seq_array + + +def log(t: torch.Tensor, eps=1e-20) -> torch.Tensor: + """ + Toch log for the purporses of diffusion time steps t. + """ + return torch.log(t.clamp(min=eps)) def right_pad_dims_to(x, t): @@ -68,6 +98,7 @@ def get_obj_from_str(string, reload=False): importlib.reload(module_to_reload) return getattr(importlib.import_module(module, package=None), class_) + def mean_flat(tensor): """ Take the mean over all non-batch dimensions. 
diff --git a/src/utils/network.py b/src/utils/network.py
index e0bf3499..b8fdaadc 100644
--- a/src/utils/network.py
+++ b/src/utils/network.py
@@ -1,20 +1,5 @@
-from torch import nn
 import torch.nn.functional as F
-from utils.misc import default
-
 
 def l2norm(t):
-    return F.normalize(t, dim = -1)
-
-
-def Upsample(dim, dim_out = None):
-    return nn.Sequential(
-        nn.Upsample(scale_factor = 2, mode = 'nearest'),
-        nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1)
-    )
-
-
-def Downsample(dim, dim_out = None):
-    return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1)
-
+    return F.normalize(t, dim=-1)
diff --git a/src/utils/schedules.py b/src/utils/schedules.py
index fec3238b..abedce5a 100644
--- a/src/utils/schedules.py
+++ b/src/utils/schedules.py
@@ -1,23 +1,25 @@
 import math
 from math import exp
 
 import torch
+
+from utils.misc import log
 
 
-def beta_linear_log_snr(t):
+def beta_linear_log_snr(t: torch.Tensor) -> torch.Tensor:
     return -torch.log(exp(1e-4 + 10 * (t**2)))
 
 
-def alpha_cosine_log_snr(t, s: float = 0.008):
-    return -torch.log(
-        (torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps=1e-5
-    )  # not sure if this accounts for beta being clipped to 0.999 in discrete version
+def alpha_cosine_log_snr(t: torch.Tensor, s: float = 0.008) -> torch.Tensor:
+    # not sure if this accounts for beta being clipped to 0.999 in discrete version
+    # log here is the eps-clamped log from utils.misc, so the eps keyword is valid
+    return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps=1e-5)
 
 
-def log_snr_to_alpha_sigma(log_snr):
+def log_snr_to_alpha_sigma(log_snr) -> tuple:
     return torch.sqrt(torch.sigmoid(log_snr)), torch.sqrt(torch.sigmoid(-log_snr))
 
 
-def cosine_beta_schedule(timesteps, s=0.008) -> torch.Tensor:
+def cosine_beta_schedule(timesteps, s=0.008):
     """
     cosine schedule as proposed in https://arxiv.org/abs/2102.09672
     """
diff --git a/tests/data/test_sequence_dataloader.py b/tests/data/test_sequence_dataloader.py
index b7b1ddc7..689c1e57 100644
--- a/tests/data/test_sequence_dataloader.py
+++ b/tests/data/test_sequence_dataloader.py
@@ -7,32 +7,54 @@
 
 def prepare_default_data(path):
     """Prepares dummy data for testing."""
-    pd.DataFrame({
-        "raw_sequence": ["ATCGATCGATCG", "GGTGAACGATTA", "AATCGTATCGCG", "CTTATCGATCCG"],
-        "component": [1, 2, 1, 10],
-    }).to_csv(path, index=False, sep="\t")
+    pd.DataFrame(
+        {
+            "raw_sequence": [
+                "ATCGATCGATCG",
+                "GGTGAACGATTA",
+                "AATCGTATCGCG",
+                "CTTATCGATCCG",
+            ],
+            "component": [1, 2, 1, 10],
+        }
+    ).to_csv(path, index=False, sep="\t")
+
 
 def prepare_high_diversity_datasets(train_data_path, val_data_path, test_data_path):
-    pd.DataFrame({
-        "raw_sequence": ["AAAAAAAAAA", "AAAAAAAAAA", "AAAAAAAAAA", "AAAAAAAAAA"],
-        "component": [0, 0, 0, 0],
-    }).to_csv(train_data_path, index=False, sep="\t")
-    pd.DataFrame({
-        "raw_sequence": ["CCCCCCCCCC", "CCCCCCCCCC", "CCCCCCCCCC", "CCCCCCCCCC"],
-        "component": [1, 1, 1, 1],
-    }).to_csv(val_data_path, index=False, sep="\t")
-    pd.DataFrame({
-        "raw_sequence": ["TTTTTTTTTT", "TTTTTTTTTT", "TTTTTTTTTT", "TTTTTTTTTT"],
-        "component": [2, 2, 2, 2],
-    }).to_csv(test_data_path, index=False, sep="\t")
+    pd.DataFrame(
+        {
+            "raw_sequence": ["AAAAAAAAAA", "AAAAAAAAAA", "AAAAAAAAAA", "AAAAAAAAAA"],
+            "component": [0, 0, 0, 0],
+        }
+    ).to_csv(train_data_path, index=False, sep="\t")
+    pd.DataFrame(
+        {
+            "raw_sequence": ["CCCCCCCCCC", "CCCCCCCCCC", "CCCCCCCCCC", "CCCCCCCCCC"],
+            "component": [1, 1, 1, 1],
+        }
+    ).to_csv(val_data_path, index=False, sep="\t")
+    pd.DataFrame(
+        {
+            "raw_sequence": ["TTTTTTTTTT", "TTTTTTTTTT", "TTTTTTTTTT", "TTTTTTTTTT"],
+            "component": [2, 2, 2, 2],
+        }
+ ).to_csv(test_data_path, index=False, sep="\t") + def test_invalid_sequence_letters(): # prepare invalid data dummy_data_path = "_tmp_seq_dataloader_data.csv" - pd.DataFrame({ - "raw_sequence": ["ZCCCACTGACTG", "ACTGACTGACTG", "AAAACCCCTTTT", "ABCDEFGHIJKL"], - "component": [1, 2, 1, 10], - }).to_csv(dummy_data_path, index=False, sep="\t") + pd.DataFrame( + { + "raw_sequence": [ + "ZCCCACTGACTG", + "ACTGACTGACTG", + "AAAACCCCTTTT", + "ABCDEFGHIJKL", + ], + "component": [1, 2, 1, 10], + } + ).to_csv(dummy_data_path, index=False, sep="\t") datamodule = SequenceDataModule( train_path=dummy_data_path, @@ -53,14 +75,17 @@ def test_invalid_sequence_letters(): # remove dummy data os.remove(dummy_data_path) + def test_invalid_sequence_lengths(): # prepare dummy data dummy_data_path = "_tmp_seq_dataloader_data.csv" # second sequence too short - pd.DataFrame({ - "raw_sequence": ["ATCG", "GGT", "AATC", "CTTA"], - "component": [1, 2, 1, 10], - }).to_csv(dummy_data_path, index=False, sep="\t") + pd.DataFrame( + { + "raw_sequence": ["ATCG", "GGT", "AATC", "CTTA"], + "component": [1, 2, 1, 10], + } + ).to_csv(dummy_data_path, index=False, sep="\t") # prepare data module datamodule = SequenceDataModule( @@ -82,10 +107,12 @@ def test_invalid_sequence_lengths(): os.remove(dummy_data_path) # fourth sequence too long - pd.DataFrame({ - "raw_sequence": ["ATCG", "GGT", "AATC", "CTTAT"], - "component": [1, 2, 1, 10], - }).to_csv(dummy_data_path, index=False, sep="\t") + pd.DataFrame( + { + "raw_sequence": ["ATCG", "GGT", "AATC", "CTTAT"], + "component": [1, 2, 1, 10], + } + ).to_csv(dummy_data_path, index=False, sep="\t") # prepare data module datamodule = SequenceDataModule( @@ -106,19 +133,22 @@ def test_invalid_sequence_lengths(): # remove dummy data os.remove(dummy_data_path) + def test_train_val_test_data_split(): # prepare dummy data dummy_train_data_path = "_tmp_seq_dataloader_train_data.csv" dummy_val_data_path = "_tmp_seq_dataloader_val_data.csv" dummy_test_data_path = "_tmp_seq_dataloader_test_data.csv" - prepare_high_diversity_datasets(dummy_train_data_path, dummy_val_data_path, dummy_test_data_path) + prepare_high_diversity_datasets( + dummy_train_data_path, dummy_val_data_path, dummy_test_data_path + ) # check loading of only a single data set datamodule = SequenceDataModule( train_path=None, val_path=dummy_val_data_path, test_path=None, - sequence_length=10 + sequence_length=10, ) datamodule.setup() assert datamodule.train_dataloader is None @@ -141,7 +171,13 @@ def test_train_val_test_data_split(): assert len(datamodule.val_dataloader()) == 2 assert len(datamodule.test_dataloader()) == 2 seen_nucleotide_idxs = set() - for dl_idx, dataloader in enumerate([datamodule.train_dataloader(), datamodule.val_dataloader(), datamodule.test_dataloader()]): + for dl_idx, dataloader in enumerate( + [ + datamodule.train_dataloader(), + datamodule.val_dataloader(), + datamodule.test_dataloader(), + ] + ): dataloader_iter = iter(dataloader) # first batch @@ -170,6 +206,7 @@ def test_train_val_test_data_split(): for path in [dummy_train_data_path, dummy_val_data_path, dummy_test_data_path]: os.remove(path) + def test_polar_encoding(): # prepare dummy data dummy_data_path = "_tmp_seq_dataloader_data.csv" @@ -191,7 +228,11 @@ def test_polar_encoding(): assert len(datamodule.train_data) == 4 assert len(datamodule.val_data) == 4 assert len(datamodule.test_data) == 4 - for dataloader in [datamodule.train_dataloader(), datamodule.val_dataloader(), datamodule.test_dataloader()]: + for dataloader in [ + 
datamodule.train_dataloader(), + datamodule.val_dataloader(), + datamodule.test_dataloader(), + ]: for batch in dataloader: assert len(batch) == 2 assert isinstance(batch[0], torch.Tensor) @@ -205,6 +246,7 @@ def test_polar_encoding(): # remove dummy data os.remove(dummy_data_path) + def test_onehot_encoding(): # prepare dummy data dummy_data_path = "_tmp_seq_dataloader_data.csv" @@ -226,7 +268,11 @@ def test_onehot_encoding(): assert len(datamodule.train_data) == 4 assert len(datamodule.val_data) == 4 assert len(datamodule.test_data) == 4 - for dataloader in [datamodule.train_dataloader(), datamodule.val_dataloader(), datamodule.test_dataloader()]: + for dataloader in [ + datamodule.train_dataloader(), + datamodule.val_dataloader(), + datamodule.test_dataloader(), + ]: for batch in dataloader: assert len(batch) == 2 assert isinstance(batch[0], torch.Tensor) @@ -240,6 +286,7 @@ def test_onehot_encoding(): # remove dummy data os.remove(dummy_data_path) + def test_ordinal_encoding(): # prepare dummy data dummy_data_path = "_tmp_seq_dataloader_data.csv" @@ -261,7 +308,11 @@ def test_ordinal_encoding(): assert len(datamodule.train_data) == 4 assert len(datamodule.val_data) == 4 assert len(datamodule.test_data) == 4 - for dataloader in [datamodule.train_dataloader(), datamodule.val_dataloader(), datamodule.test_dataloader()]: + for dataloader in [ + datamodule.train_dataloader(), + datamodule.val_dataloader(), + datamodule.test_dataloader(), + ]: for batch in dataloader: assert len(batch) == 2 assert isinstance(batch[0], torch.Tensor) @@ -275,6 +326,7 @@ def test_ordinal_encoding(): # remove dummy data os.remove(dummy_data_path) + def test_polar_transforms(): # prepare dummy data dummy_data_path = "_tmp_seq_dataloader_data.csv" @@ -283,20 +335,26 @@ def test_polar_transforms(): # prepare data module def seg_transform(seq): return seq + 1 + def cell_type_transform(cell_type): return cell_type + 20 + datamodule = SequenceDataModule( train_path=dummy_data_path, val_path=dummy_data_path, test_path=dummy_data_path, sequence_length=12, sequence_encoding="polar", - sequence_transform=transforms.Compose([ - transforms.Lambda(seg_transform), - ]), - cell_type_transform=transforms.Compose([ - transforms.Lambda(cell_type_transform), - ]), + sequence_transform=transforms.Compose( + [ + transforms.Lambda(seg_transform), + ] + ), + cell_type_transform=transforms.Compose( + [ + transforms.Lambda(cell_type_transform), + ] + ), batch_size=2, num_workers=0, ) @@ -306,7 +364,11 @@ def cell_type_transform(cell_type): assert len(datamodule.train_data) == 4 assert len(datamodule.val_data) == 4 assert len(datamodule.test_data) == 4 - for dataloader in [datamodule.train_dataloader, datamodule.val_dataloader, datamodule.test_dataloader]: + for dataloader in [ + datamodule.train_dataloader, + datamodule.val_dataloader, + datamodule.test_dataloader, + ]: seen_cell_type_ids = set() for batch in dataloader(): assert len(batch) == 2 @@ -325,6 +387,7 @@ def cell_type_transform(cell_type): # remove dummy data os.remove(dummy_data_path) + def test_onehot_transforms(): # prepare dummy data dummy_data_path = "_tmp_seq_dataloader_data.csv" @@ -333,20 +396,26 @@ def test_onehot_transforms(): # prepare data module def seg_transform(seq): return seq + 1 + def cell_type_transform(cell_type): return cell_type + 20 + datamodule = SequenceDataModule( train_path=dummy_data_path, val_path=dummy_data_path, test_path=dummy_data_path, sequence_length=12, sequence_encoding="onehot", - sequence_transform=transforms.Compose([ - 
transforms.Lambda(seg_transform), - ]), - cell_type_transform=transforms.Compose([ - transforms.Lambda(cell_type_transform), - ]), + sequence_transform=transforms.Compose( + [ + transforms.Lambda(seg_transform), + ] + ), + cell_type_transform=transforms.Compose( + [ + transforms.Lambda(cell_type_transform), + ] + ), batch_size=2, num_workers=0, ) @@ -356,7 +425,11 @@ def cell_type_transform(cell_type): assert len(datamodule.train_data) == 4 assert len(datamodule.val_data) == 4 assert len(datamodule.test_data) == 4 - for dataloader in [datamodule.train_dataloader, datamodule.val_dataloader, datamodule.test_dataloader]: + for dataloader in [ + datamodule.train_dataloader, + datamodule.val_dataloader, + datamodule.test_dataloader, + ]: seen_cell_type_ids = set() for batch in dataloader(): assert len(batch) == 2 @@ -375,6 +448,7 @@ def cell_type_transform(cell_type): # remove dummy data os.remove(dummy_data_path) + def test_ordinal_transforms(): # prepare dummy data dummy_data_path = "_tmp_seq_dataloader_data.csv" @@ -383,20 +457,26 @@ def test_ordinal_transforms(): # prepare data module def seg_transform(seq): return seq + 1 + def cell_type_transform(cell_type): return cell_type + 20 + datamodule = SequenceDataModule( train_path=dummy_data_path, val_path=dummy_data_path, test_path=dummy_data_path, sequence_length=12, sequence_encoding="ordinal", - sequence_transform=transforms.Compose([ - transforms.Lambda(seg_transform), - ]), - cell_type_transform=transforms.Compose([ - transforms.Lambda(cell_type_transform), - ]), + sequence_transform=transforms.Compose( + [ + transforms.Lambda(seg_transform), + ] + ), + cell_type_transform=transforms.Compose( + [ + transforms.Lambda(cell_type_transform), + ] + ), batch_size=2, num_workers=0, ) @@ -406,7 +486,11 @@ def cell_type_transform(cell_type): assert len(datamodule.train_data) == 4 assert len(datamodule.val_data) == 4 assert len(datamodule.test_data) == 4 - for dataloader in [datamodule.train_dataloader, datamodule.val_dataloader, datamodule.test_dataloader]: + for dataloader in [ + datamodule.train_dataloader, + datamodule.val_dataloader, + datamodule.test_dataloader, + ]: seen_cell_type_ids = set() for batch in dataloader(): assert len(batch) == 2