From a45a64b0ab6816ab0ec402557e02b50757e16089 Mon Sep 17 00:00:00 2001 From: Teddy Koker Date: Tue, 8 Dec 2020 11:39:35 -0500 Subject: [PATCH 1/3] add option to normalize latent interpolation images --- pl_bolts/callbacks/variational.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/pl_bolts/callbacks/variational.py b/pl_bolts/callbacks/variational.py index 1b945e9631..ee7b8ebb87 100644 --- a/pl_bolts/callbacks/variational.py +++ b/pl_bolts/callbacks/variational.py @@ -26,7 +26,12 @@ class LatentDimInterpolator(Callback): """ def __init__( - self, interpolate_epoch_interval: int = 20, range_start: int = -5, range_end: int = 5, num_samples: int = 2 + self, + interpolate_epoch_interval: int = 20, + range_start: int = -5, + range_end: int = 5, + num_samples: int = 2, + normalize=False, ): """ Args: @@ -34,12 +39,14 @@ def __init__( range_start: default -5 range_end: default 5 num_samples: default 2 + normalize: default False """ super().__init__() self.interpolate_epoch_interval = interpolate_epoch_interval self.range_start = range_start self.range_end = range_end self.num_samples = num_samples + self.normalize = normalize def on_epoch_end(self, trainer, pl_module): if (trainer.current_epoch + 1) % self.interpolate_epoch_interval == 0: @@ -48,7 +55,7 @@ def on_epoch_end(self, trainer, pl_module): num_images = (self.range_end - self.range_start) ** 2 num_rows = int(math.sqrt(num_images)) - grid = torchvision.utils.make_grid(images, nrow=num_rows) + grid = torchvision.utils.make_grid(images, nrow=num_rows, normalize=self.normalize) str_title = f'{pl_module.__class__.__name__}_latent_space' trainer.logger.experiment.add_image(str_title, grid, global_step=trainer.global_step) From 4074654f589646beb0dddbc8a820b089d6e6e153 Mon Sep 17 00:00:00 2001 From: Teddy Koker Date: Tue, 8 Dec 2020 15:48:42 -0500 Subject: [PATCH 2/3] linspace --- pl_bolts/callbacks/variational.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pl_bolts/callbacks/variational.py b/pl_bolts/callbacks/variational.py index ee7b8ebb87..f6524840e1 100644 --- a/pl_bolts/callbacks/variational.py +++ b/pl_bolts/callbacks/variational.py @@ -1,4 +1,5 @@ import math +import numpy as np import torch from pytorch_lightning.callbacks import Callback @@ -30,6 +31,7 @@ def __init__( interpolate_epoch_interval: int = 20, range_start: int = -5, range_end: int = 5, + steps: int = 11, num_samples: int = 2, normalize=False, ): @@ -47,6 +49,7 @@ def __init__( self.range_end = range_end self.num_samples = num_samples self.normalize = normalize + self.steps = steps def on_epoch_end(self, trainer, pl_module): if (trainer.current_epoch + 1) % self.interpolate_epoch_interval == 0: @@ -63,8 +66,8 @@ def interpolate_latent_space(self, pl_module, latent_dim): images = [] with torch.no_grad(): pl_module.eval() - for z1 in range(self.range_start, self.range_end, 1): - for z2 in range(self.range_start, self.range_end, 1): + for z1 in np.linspace(self.range_start, self.range_end, self.steps): + for z2 in np.linspace(self.range_start, self.range_end, self.steps): # set all dims to zero z = torch.zeros(self.num_samples, latent_dim, device=pl_module.device) From 2e3199b8027c970ce3cdbeb336b603341ed40d28 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Tue, 8 Dec 2020 16:36:46 -0500 Subject: [PATCH 3/3] update --- pl_bolts/callbacks/variational.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pl_bolts/callbacks/variational.py b/pl_bolts/callbacks/variational.py index f6524840e1..21c64aa937 100644 --- a/pl_bolts/callbacks/variational.py +++ b/pl_bolts/callbacks/variational.py @@ -33,15 +33,16 @@ def __init__( range_end: int = 5, steps: int = 11, num_samples: int = 2, - normalize=False, + normalize: bool = True, ): """ Args: interpolate_epoch_interval: default 20 range_start: default -5 range_end: default 5 + steps: number of step between start and end num_samples: default 2 - normalize: default False + normalize: default True (change image to (0, 1) range) """ super().__init__() self.interpolate_epoch_interval = interpolate_epoch_interval