diff --git a/pyproject.toml b/pyproject.toml
index 2dc4c83..c53f493 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.18.2"
+version = "1.18.3"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/vector_quantize_pytorch/lookup_free_quantization.py b/vector_quantize_pytorch/lookup_free_quantization.py
index 0df700a..c8ff2e8 100644
--- a/vector_quantize_pytorch/lookup_free_quantization.py
+++ b/vector_quantize_pytorch/lookup_free_quantization.py
@@ -103,6 +103,7 @@ def __init__(
         commitment_loss_weight = 0.,
         diversity_gamma = 1.,
         straight_through_activation = nn.Identity(),
+        scale_trick = False, # @cfifty Fifty et al. https://arxiv.org/abs/2410.06424
         num_codebooks = 1,
         keep_num_codebooks_dim = None,
         codebook_scale = 1., # for residual LFQ, codebook scaled down by 2x at each layer
@@ -160,6 +161,9 @@ def __init__(
 
         self.activation = straight_through_activation
 
+        assert not (scale_trick and spherical)
+        self.scale_trick = scale_trick
+
         # whether to use BSQ (binary spherical quantization)
 
         self.spherical = spherical
@@ -322,7 +326,12 @@ def forward(
 
         if self.training:
             x = self.activation(x)
-            x = x + (quantized - x).detach()
+
+            if self.scale_trick:
+                x = x * (quantized / x).detach()
+            else:
+                x = x + (quantized - x).detach()
+
         else:
             x = quantized
 
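
A minimal sketch of what the new `scale_trick` flag changes, assuming LFQ-style sign quantization (the tensors, values, and prints below are illustrative, not part of the diff or the library). Both estimators return the quantized value in the forward pass; they differ in the gradient that reaches `x`: the default additive straight-through passes it through unchanged, while the scale trick rescales it by the detached ratio `quantized / x`.

```python
import torch

x = torch.tensor([0.5, -2.0, 1.5], requires_grad=True)
quantized = torch.sign(x)  # LFQ-style binary codes in {-1, +1}

# default additive straight-through: d(out)/dx = 1
out_add = x + (quantized - x).detach()
out_add.sum().backward()
print(x.grad)  # tensor([1., 1., 1.])

x.grad = None

# scale trick: same forward value, but d(out)/dx = (quantized / x).detach()
out_mul = x * (quantized / x).detach()
assert torch.allclose(out_mul, out_add)
out_mul.sum().backward()
print(x.grad)  # tensor([2.0000, 0.5000, 0.6667])
```

Note the `assert not (scale_trick and spherical)` in `__init__`: the flag is not allowed together with the spherical (BSQ) code path.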