residual_vq.py

from math import ceil
from functools import partial
from random import randrange

import torch
from torch import nn
import torch.nn.functional as F

from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize

from einops import rearrange, repeat, pack, unpack

def round_up_multiple(num, mult):
    # round num up to the nearest multiple of mult, e.g. round_up_multiple(5, 4) == 8
    return ceil(num / mult) * mult

# main class

class ResidualVQ(nn.Module):
    """ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """

    def __init__(
        self,
        *,
        num_quantizers,
        shared_codebook = False,
        heads = 1,
        quantize_dropout = False,
        quantize_dropout_cutoff_index = 0,
        quantize_dropout_multiple_of = 1,
        accept_image_fmap = False,
        **kwargs
    ):
        super().__init__()
        assert heads == 1, 'residual vq is not compatible with multi-headed codes'

        self.num_quantizers = num_quantizers
        self.accept_image_fmap = accept_image_fmap

        self.layers = nn.ModuleList([VectorQuantize(accept_image_fmap = accept_image_fmap, **kwargs) for _ in range(num_quantizers)])

        self.quantize_dropout = quantize_dropout

        assert quantize_dropout_cutoff_index >= 0
        self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        self.quantize_dropout_multiple_of = quantize_dropout_multiple_of  # encodec paper proposes structured dropout, believe this was set to 4

        # optionally share a single codebook across all quantizer layers

        if not shared_codebook:
            return

        first_vq, *rest_vq = self.layers
        codebook = first_vq._codebook

        for vq in rest_vq:
            vq._codebook = codebook

    @property
    def codebooks(self):
        codebooks = [layer._codebook.embed for layer in self.layers]
        codebooks = torch.stack(codebooks, dim = 0)
        codebooks = rearrange(codebooks, 'q 1 c d -> q c d')
        return codebooks

    def get_codes_from_indices(self, indices):
        batch, quantize_dim = indices.shape[0], indices.shape[-1]

        # may also receive indices in the shape of 'b h w q' (accept_image_fmap)

        indices, ps = pack([indices], 'b * q')

        # because of quantize dropout, one can pass in indices that are coarse
        # and the network should be able to reconstruct

        if quantize_dim < self.num_quantizers:
            assert self.quantize_dropout > 0., 'quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations'
            indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value = -1)

        # get ready for gathering

        codebooks = repeat(self.codebooks, 'q c d -> q b c d', b = batch)
        gather_indices = repeat(indices, 'b n q -> q b n d', d = codebooks.shape[-1])

        # take care of quantizer dropout

        mask = gather_indices == -1.
        gather_indices = gather_indices.masked_fill(mask, 0)  # have it fetch a dummy code to be masked out later

        all_codes = codebooks.gather(2, gather_indices)  # gather all codes

        # mask out any codes that were dropout-ed

        all_codes = all_codes.masked_fill(mask, 0.)

        # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)

        all_codes, = unpack(all_codes, ps, 'q b * d')

        return all_codes

    def forward(
        self,
        x,
        return_all_codes = False
    ):
        num_quant, quant_dropout_multiple_of, device = self.num_quantizers, self.quantize_dropout_multiple_of, x.device

        quantized_out = 0.
        residual = x

        all_losses = []
        all_indices = []

        # quantize dropout is disabled here; the corresponding logic is left commented out for reference
        # should_quantize_dropout = self.training and self.quantize_dropout

        # sample a layer index at which to dropout further residual quantization
        # also prepare null indices and loss

        # if should_quantize_dropout:
        #     rand_quantize_dropout_index = randrange(self.quantize_dropout_cutoff_index, num_quant)

        #     if quant_dropout_multiple_of != 1:
        #         rand_quantize_dropout_index = round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1

        #     null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
        #     null_indices = torch.full(null_indices_shape, -1., device = device, dtype = torch.long)
        #     null_loss = torch.full((1,), 0., device = device, dtype = x.dtype)

        # go through the layers

        for quantizer_index, layer in enumerate(self.layers):
            # if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
            #     all_indices.append(null_indices)
            #     all_losses.append(null_loss)
            #     continue

            quantized, indices, loss = layer(residual)

            # subtract the quantized output from the residual (stopping gradients through the codes)
            # and accumulate the quantized outputs across layers

            residual = residual - quantized.detach()
            quantized_out = quantized_out + quantized

            all_indices.append(indices)
            all_losses.append(loss)

        # stack all losses and indices along the quantizer dimension

        all_losses, all_indices = map(partial(torch.stack, dim = -1), (all_losses, all_indices))

        ret = (quantized_out, all_indices, all_losses)

        if return_all_codes:
            # whether to return all codes from all codebooks across layers
            all_codes = self.get_codes_from_indices(all_indices)

            # will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
            ret = (*ret, all_codes)

        return ret
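

# ---------------------------------------------------------------------------
# minimal usage sketch (not part of the original file). the hyperparameters
# below (dim, codebook_size, num_quantizers) and the input shape are
# illustrative assumptions; kwargs other than the ResidualVQ-specific ones
# are forwarded to each underlying VectorQuantize layer.

if __name__ == '__main__':
    residual_vq = ResidualVQ(
        dim = 256,             # forwarded to VectorQuantize
        codebook_size = 1024,  # forwarded to VectorQuantize
        num_quantizers = 8
    )

    x = torch.randn(1, 1024, 256)  # (batch, sequence length, dim)

    # quantized is the sum of per-layer quantized residuals; indices hold one
    # code per quantizer layer, losses one commitment loss per layer
    quantized, indices, losses = residual_vq(x)

    # per-layer codes can also be recovered from the indices,
    # in shape (quantizer, batch, sequence length, codebook dimension)
    all_codes = residual_vq.get_codes_from_indices(indices)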