-
Notifications
You must be signed in to change notification settings - Fork 27.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #10856 from akx/untamed
Remove taming_transformers dependency
- Loading branch information
Showing
5 changed files
with
148 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
# Vendored from https://raw.githubusercontent.com/CompVis/taming-transformers/24268930bf1dce879235a7fddd0b2355b84d7ea6/taming/modules/vqvae/quantize.py, | ||
# where the license is as follows: | ||
# | ||
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer | ||
# | ||
# Permission is hereby granted, free of charge, to any person obtaining a copy | ||
# of this software and associated documentation files (the "Software"), to deal | ||
# in the Software without restriction, including without limitation the rights | ||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
# copies of the Software, and to permit persons to whom the Software is | ||
# furnished to do so, subject to the following conditions: | ||
# | ||
# The above copyright notice and this permission notice shall be included in all | ||
# copies or substantial portions of the Software. | ||
# | ||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, | ||
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF | ||
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. | ||
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, | ||
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR | ||
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE | ||
# OR OTHER DEALINGS IN THE SOFTWARE. | ||
|
||
import torch | ||
import torch.nn as nn | ||
import numpy as np | ||
from einops import rearrange | ||
|
||
|
||
class VectorQuantizer2(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.

    The module holds a codebook of ``n_e`` vectors of size ``e_dim`` and, in
    ``forward``, replaces every ``e_dim``-sized vector of the input by the
    codebook entry with the smallest squared distance.

    Args:
        n_e: number of codebook entries (rows of the embedding table).
        e_dim: size of each codebook entry.
        beta: weight applied to one of the two loss terms (which one depends
            on ``legacy``).
        remap: optional path handed to ``np.load``; the loaded array lists the
            codebook indices that are "used". When given, indices returned by
            ``forward`` are positions in that array instead of raw codebook
            indices.
        unknown_index: what to emit for a codebook index that is not in the
            "used" array: ``"random"`` (random remapped index), ``"extra"``
            (one additional index appended after the used ones) or an integer.
            Only read when ``remap`` is given.
        sane_index_shape: if true, ``forward`` returns the indices reshaped to
            (batch, height, width).
        legacy: selects the loss weighting, see the NOTE on ``__init__``.
    """

    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
                 sane_index_shape=False, legacy=True):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            # `used` holds the raw codebook indices that remain addressable.
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                # Reserve one more slot right after the used indices.
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds):
        """Translate raw codebook indices into positions in ``self.used``.

        Args:
            inds: integer tensor with at least two dimensions; the first one is
                treated as the batch axis.

        Returns:
            Tensor of the same shape as ``inds``. Indices that do not occur in
            ``self.used`` are replaced according to ``self.unknown_index``.
        """
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        # match[b, i, k] == 1 iff inds[b, i] equals the k-th used index.
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        """Translate positions in ``self.used`` back into raw codebook indices.

        The input tensor is left untouched.

        Args:
            inds: integer tensor with at least two dimensions; the first one is
                treated as the batch axis.

        Returns:
            Tensor of the same shape as ``inds`` holding raw codebook indices.
        """
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        n_used = self.used.shape[0]
        if self.re_embed > n_used:  # extra token
            # Out-of-place on purpose: `reshape` may return a view of the
            # caller's tensor, and an in-place write would corrupt it.
            inds = inds.masked_fill(inds >= n_used, 0)  # simply set to zero
        back = used[inds]
        return back.reshape(ishape)

    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
        """Quantize ``z`` against the codebook.

        Args:
            z: tensor laid out as (batch, channel, height, width); the channel
                axis is viewed as vectors of size ``e_dim``.
            temp, rescale_logits, return_logits: accepted only for interface
                compatibility; anything but the defaults (or ``temp == 1.0``)
                trips an assertion.

        Returns:
            ``(z_q, loss, (perplexity, min_encodings, min_encoding_indices))``
            where ``z_q`` has the layout of ``z``, ``loss`` is a scalar tensor,
            and ``perplexity`` / ``min_encodings`` are always ``None``.
        """
        assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
        assert rescale_logits is False, "Only for interface compatible with Gumbel"
        assert return_logits is False, "Only for interface compatible with Gumbel"
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.e_dim)
        codebook = self.embedding.weight

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(codebook ** 2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, codebook.t())

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \
                   torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \
                   torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients: the value is the quantized vector, while the
        # gradient w.r.t. z_q is passed on to z unchanged
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1)  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(
                z_q.shape[0], z_q.shape[2], z_q.shape[3])

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape):
        """Look up the codebook vectors for ``indices``.

        Args:
            indices: integer tensor of indices (remapped indices when ``remap``
                was given to the constructor).
            shape: ``(batch, height, width, channel)`` of the desired output
                before it is moved to channel-first layout, or ``None`` to get
                the raw lookup result.

        Returns:
            With ``shape``: tensor laid out as (batch, channel, height, width).
            Without: the embedding lookup, one vector per index.
        """
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            # The batch axis only satisfies unmap_to_all's rank requirement,
            # so a single row is enough when no shape is supplied.
            batch = shape[0] if shape is not None else 1
            indices = indices.reshape(batch, -1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters