Skip to content

Commit

Permalink
Option to normalize latent interpolation images (#438)
Browse files Browse the repository at this point in the history
* add option to normalize latent interpolation images

* linspace

* update

Co-authored-by: ananyahjha93 <[email protected]>
  • Loading branch information
teddykoker and ananyahjha93 authored Dec 8, 2020
1 parent 5a3f791 commit 9503df5
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions pl_bolts/callbacks/variational.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
import numpy as np

import torch
from pytorch_lightning.callbacks import Callback
Expand Down Expand Up @@ -26,20 +27,30 @@ class LatentDimInterpolator(Callback):
"""

def __init__(
self, interpolate_epoch_interval: int = 20, range_start: int = -5, range_end: int = 5, num_samples: int = 2
self,
interpolate_epoch_interval: int = 20,
range_start: int = -5,
range_end: int = 5,
steps: int = 11,
num_samples: int = 2,
normalize: bool = True,
):
"""
Args:
interpolate_epoch_interval: default 20
range_start: default -5
range_end: default 5
steps: number of step between start and end
num_samples: default 2
normalize: default True (change image to (0, 1) range)
"""
super().__init__()
self.interpolate_epoch_interval = interpolate_epoch_interval
self.range_start = range_start
self.range_end = range_end
self.num_samples = num_samples
self.normalize = normalize
self.steps = steps

def on_epoch_end(self, trainer, pl_module):
if (trainer.current_epoch + 1) % self.interpolate_epoch_interval == 0:
Expand All @@ -48,16 +59,16 @@ def on_epoch_end(self, trainer, pl_module):

num_images = (self.range_end - self.range_start) ** 2
num_rows = int(math.sqrt(num_images))
grid = torchvision.utils.make_grid(images, nrow=num_rows)
grid = torchvision.utils.make_grid(images, nrow=num_rows, normalize=self.normalize)
str_title = f'{pl_module.__class__.__name__}_latent_space'
trainer.logger.experiment.add_image(str_title, grid, global_step=trainer.global_step)

def interpolate_latent_space(self, pl_module, latent_dim):
images = []
with torch.no_grad():
pl_module.eval()
for z1 in range(self.range_start, self.range_end, 1):
for z2 in range(self.range_start, self.range_end, 1):
for z1 in np.linspace(self.range_start, self.range_end, self.steps):
for z2 in np.linspace(self.range_start, self.range_end, self.steps):
# set all dims to zero
z = torch.zeros(self.num_samples, latent_dim, device=pl_module.device)

Expand Down

0 comments on commit 9503df5

Please sign in to comment.