throw in something experimental based on @cfifty paper
lucidrains committed Oct 15, 2024
1 parent a72e251 commit a3c051b
Showing 2 changed files with 11 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "vector-quantize-pytorch"
version = "1.18.2"
version = "1.18.3"
description = "Vector Quantization - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
11 changes: 10 additions & 1 deletion vector_quantize_pytorch/lookup_free_quantization.py
@@ -103,6 +103,7 @@ def __init__(
commitment_loss_weight = 0.,
diversity_gamma = 1.,
straight_through_activation = nn.Identity(),
+scale_trick = False, # @cfifty Fifty et al. https://arxiv.org/abs/2410.06424
num_codebooks = 1,
keep_num_codebooks_dim = None,
codebook_scale = 1., # for residual LFQ, codebook scaled down by 2x at each layer
@@ -160,6 +161,9 @@ def __init__(

self.activation = straight_through_activation

+assert not (scale_trick and spherical)
+self.scale_trick = scale_trick

# whether to use BSQ (binary spherical quantization)

self.spherical = spherical
@@ -322,7 +326,12 @@ def forward(

if self.training:
x = self.activation(x)
-x = x + (quantized - x).detach()

+if self.scale_trick:
+    x = x * (quantized / x).detach()
+else:
+    x = x + (quantized - x).detach()

else:
x = quantized

