Skip to content

Commit

Permalink
add adaptive weight from vqgan paper
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 3, 2022
1 parent ff1ee18 commit 0d206e9
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 13 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ video = nuwa.generate(text = text, text_mask = mask) # (1, 5, 3, 256, 256)
- [x] complete 3dna causal attention in decoder
- [x] write up easy generation functions
- [x] make sure GAN portion of VQGan is correct, reread paper
- [x] make sure adaptive weight in vqgan is correctly built
- [ ] flesh out VAE resnet blocks, offer some choices
- [ ] offer new vqvae improvements (orthogonal reg and smaller codebook dimensions)
- [ ] offer vqvae training script
Expand All @@ -105,7 +106,6 @@ video = nuwa.generate(text = text, text_mask = mask) # (1, 5, 3, 256, 256)
- [ ] add audio transformer, and build audio / video nearby cross attention
- [ ] investigate custom attention layouts in microsoft deepspeed sparse attention (using triton)
- [ ] batch video tokens -> vae during video generation, to prevent oom
- [ ] make sure adaptive weight in vqgan is correctly built
- [ ] add all stability tricks from cogview paper by default, as well as cosine sim attention from swinv2 as an option

## Citations
Expand Down
47 changes: 36 additions & 11 deletions nuwa_pytorch/nuwa_pytorch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
from torch import nn, einsum
from torch.autograd import grad
import torch.nn.functional as F

from einops import rearrange, reduce, repeat
Expand Down Expand Up @@ -49,6 +50,9 @@ def gumbel_noise(t):
def gumbel_sample(t, temperature = 1., dim = -1):
    """Pick indices by taking the argmax of noise-perturbed, temperature-scaled `t`.

    `t` is divided by `temperature`, the result of `gumbel_noise(t)` is added,
    and the index of the maximum along `dim` is returned.

    Args:
        t: input values; presumably a tensor of logits - confirm against callers.
        temperature: divisor applied to `t` before the noise is added.
        dim: dimension over which the argmax is taken.

    Returns:
        The result of `.argmax(dim = dim)` on the perturbed values.
    """
    # gumbel_noise is defined earlier in this module; its body is outside this view
    perturbed = (t / temperature) + gumbel_noise(t)
    return perturbed.argmax(dim = dim)

def safe_div(numer, denom, eps = 1e-6):
    """Divide `numer` by `denom` with `eps` added to the denominator.

    Adding `eps` keeps the result finite when `denom` is exactly zero.
    Note that the result is `numer / (denom + eps)`, not `numer / denom`,
    so it is slightly biased whenever `eps` is non-zero.

    Args:
        numer: numerator.
        denom: denominator.
        eps: constant added to `denom` before dividing.

    Returns:
        `numer / (denom + eps)`.
    """
    return numer / (denom + eps)

# gan losses

def hinge_discr_loss(fake, real):
Expand All @@ -63,6 +67,14 @@ def bce_discr_loss(fake, real):
def bce_gen_loss(fake):
    """Return the mean of `-log(sigmoid(fake))`.

    Args:
        fake: values passed through `sigmoid`; presumably discriminator
            logits for generated samples - confirm against callers.

    Returns:
        The `.mean()` of the negated log-sigmoid values.
    """
    # `log` and `sigmoid` are helpers defined elsewhere in this module
    return -log(sigmoid(fake)).mean()

def grad_layer_wrt_loss(loss, layer):
    """Return the detached gradient of `loss` with respect to `layer`.

    Args:
        loss: tensor to differentiate. It need not be a scalar, because the
            backward pass is seeded with a tensor of ones of the same shape.
        layer: tensor (requiring grad) to differentiate with respect to.

    Returns:
        The gradient tensor for `layer`, detached from the autograd graph.
    """
    gradients = grad(
        outputs = loss,
        inputs = layer,
        grad_outputs = torch.ones_like(loss),
        # keep the graph alive so the same loss can be differentiated again later
        retain_graph = True
    )

    # autograd.grad returns one gradient per input; only one input is passed
    return gradients[0].detach()

# vqgan vae

class Discriminator(nn.Module):
Expand Down Expand Up @@ -113,7 +125,7 @@ def __init__(
self.encoders = MList([])
self.decoders = MList([])

dims = (channels, *((dim,) * num_layers))
dims = (dim,) * num_layers
reversed_dims = tuple(reversed(dims))
enc_dim_pairs = zip(dims[:-1], dims[1:])
dec_dim_pairs = zip(reversed_dims[:-1], reversed_dims[1:])
Expand All @@ -122,6 +134,9 @@ def __init__(
self.encoders.append(nn.Conv2d(enc_dim_in, enc_dim_out, 4, stride = 2, padding = 1))
self.decoders.append(nn.ConvTranspose2d(dec_dim_in, dec_dim_out, 4, stride = 2, padding = 1))

self.encoders.insert(0, nn.Conv2d(channels, dim, 3, padding = 1))
self.decoders.append(nn.Conv2d(dim, channels, 1))

self.vq = VQ(
dim = dim,
codebook_size = vq_codebook_size,
Expand All @@ -132,7 +147,7 @@ def __init__(

# reconstruction loss

self.l2_recon_loss = l2_recon_loss
self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss

# perceptual loss

Expand Down Expand Up @@ -187,29 +202,39 @@ def forward(

assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'

# lpips

img_vgg_feats = self.vgg(img)
recon_vgg_feats = self.vgg(fmap)
perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)

# whether to return discriminator loss

if return_discr_loss:
fmap.detach_()
fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
return self.discr_loss(fmap_discr_logits, img_discr_logits)
discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
return discr_loss

# generator loss

gen_loss = self.gen_loss(fmap)

# reconstruction loss
# calculate adaptive weight

recon_loss_fn = F.mse_loss if self.l2_recon_loss else F.l1_loss
recon_loss = recon_loss_fn(fmap, img)
last_dec_layer = self.decoders[-1].weight

# lpips
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)

img_vgg_feats = self.vgg(img)
recon_vgg_feats = self.vgg(fmap)
perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)
# reconstruction loss

recon_loss = self.recon_loss_fn(fmap, img)

# combine losses

loss = recon_loss + commit_loss + gen_loss
loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss
return loss

# normalizations
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'nuwa-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.4',
version = '0.0.5',
license='MIT',
description = 'NÜWA - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 0d206e9

Please sign in to comment.