Skip to content

Commit

Permalink
handle inplace optimizer correctly for shared codebook under residual…
Browse files Browse the repository at this point in the history
… vq setting
  • Loading branch information
lucidrains committed Sep 25, 2024
1 parent b5cb143 commit 25683f2
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "vector-quantize-pytorch"
version = "1.17.5"
version = "1.17.6"
description = "Vector Quantization - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down
4 changes: 3 additions & 1 deletion vector_quantize_pytorch/residual_vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ def __init__(

if shared_codebook:
vq_kwargs.update(
manual_ema_update = True
manual_ema_update = True,
manual_in_place_optimizer_update = True
)

self.layers = ModuleList([VectorQuantize(dim = codebook_dim, codebook_dim = codebook_dim, accept_image_fmap = accept_image_fmap, **vq_kwargs) for _ in range(num_quantizers)])
Expand Down Expand Up @@ -360,6 +361,7 @@ def forward(

if self.shared_codebook:
first(self.layers)._codebook.update_ema()
first(self.layers).update_in_place_optimizer()

# project out, if needed

Expand Down
14 changes: 12 additions & 2 deletions vector_quantize_pytorch/vector_quantize_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,7 @@ def __init__(
manual_ema_update = False,
learnable_codebook = False,
in_place_codebook_optimizer: Callable[..., Optimizer] = None, # Optimizer used to update the codebook embedding if using learnable_codebook
manual_in_place_optimizer_update = False,
affine_param = False,
affine_param_batch_decay = 0.99,
affine_param_codebook_decay = 0.9,
Expand Down Expand Up @@ -913,6 +914,7 @@ def __init__(
self._codebook = codebook_class(**codebook_kwargs)

self.in_place_codebook_optimizer = in_place_codebook_optimizer(self._codebook.parameters()) if exists(in_place_codebook_optimizer) else None
self.manual_in_place_optimizer_update = manual_in_place_optimizer_update

self.codebook_size = codebook_size

Expand Down Expand Up @@ -966,6 +968,13 @@ def get_output_from_indices(self, indices):
codes = self.get_codes_from_indices(indices)
return self.project_out(codes)

def update_in_place_optimizer(self):
    """Step the in-place codebook optimizer and clear its gradients.

    Does nothing when no in-place codebook optimizer was configured
    (i.e. ``self.in_place_codebook_optimizer`` is ``None``). Exposed so
    callers (e.g. a residual VQ with a shared codebook) can defer the
    optimizer update and trigger it exactly once per forward pass.
    """
    optimizer = self.in_place_codebook_optimizer

    if exists(optimizer):
        optimizer.step()
        optimizer.zero_grad()

def forward(
self,
x,
Expand Down Expand Up @@ -1057,8 +1066,9 @@ def forward(
loss = F.mse_loss(quantize, x.detach())

loss.backward()
self.in_place_codebook_optimizer.step()
self.in_place_codebook_optimizer.zero_grad()

if not self.manual_in_place_optimizer_update:
self.update_in_place_optimizer()

inplace_optimize_loss = loss

Expand Down

0 comments on commit 25683f2

Please sign in to comment.