From 9503df5c132541ded58da1f477062e313b711d56 Mon Sep 17 00:00:00 2001 From: Teddy Koker Date: Tue, 8 Dec 2020 16:49:54 -0500 Subject: [PATCH] Option to normalize latent interpolation images (#438) * add option to normalize latent interpolation images * linspace * update Co-authored-by: ananyahjha93 --- pl_bolts/callbacks/variational.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/pl_bolts/callbacks/variational.py b/pl_bolts/callbacks/variational.py index 1b945e9631..21c64aa937 100644 --- a/pl_bolts/callbacks/variational.py +++ b/pl_bolts/callbacks/variational.py @@ -1,4 +1,5 @@ import math +import numpy as np import torch from pytorch_lightning.callbacks import Callback @@ -26,20 +27,30 @@ class LatentDimInterpolator(Callback): """ def __init__( - self, interpolate_epoch_interval: int = 20, range_start: int = -5, range_end: int = 5, num_samples: int = 2 + self, + interpolate_epoch_interval: int = 20, + range_start: int = -5, + range_end: int = 5, + steps: int = 11, + num_samples: int = 2, + normalize: bool = True, ): """ Args: interpolate_epoch_interval: default 20 range_start: default -5 range_end: default 5 + steps: number of step between start and end num_samples: default 2 + normalize: default True (change image to (0, 1) range) """ super().__init__() self.interpolate_epoch_interval = interpolate_epoch_interval self.range_start = range_start self.range_end = range_end self.num_samples = num_samples + self.normalize = normalize + self.steps = steps def on_epoch_end(self, trainer, pl_module): if (trainer.current_epoch + 1) % self.interpolate_epoch_interval == 0: @@ -48,7 +59,7 @@ def on_epoch_end(self, trainer, pl_module): num_images = (self.range_end - self.range_start) ** 2 num_rows = int(math.sqrt(num_images)) - grid = torchvision.utils.make_grid(images, nrow=num_rows) + grid = torchvision.utils.make_grid(images, nrow=num_rows, normalize=self.normalize) str_title = f'{pl_module.__class__.__name__}_latent_space' trainer.logger.experiment.add_image(str_title, grid, global_step=trainer.global_step) @@ -56,8 +67,8 @@ def interpolate_latent_space(self, pl_module, latent_dim): images = [] with torch.no_grad(): pl_module.eval() - for z1 in range(self.range_start, self.range_end, 1): - for z2 in range(self.range_start, self.range_end, 1): + for z1 in np.linspace(self.range_start, self.range_end, self.steps): + for z2 in np.linspace(self.range_start, self.range_end, self.steps): # set all dims to zero z = torch.zeros(self.num_samples, latent_dim, device=pl_module.device)