From d9967bee981dd6017b63a3594183eba516827f2d Mon Sep 17 00:00:00 2001 From: lucidrains Date: Fri, 10 May 2024 07:52:09 -0700 Subject: [PATCH] received emails from confused researchers re: pytest-examples this morning. get rid of it --- .github/workflows/test.yml | 2 +- README.md | 110 +++++++------ pyproject.toml | 1 - tests/test_examples_readme.py | 27 ---- tests/test_readme.py | 280 ++++++++++++++++++++++++++++++++++ 5 files changed, 333 insertions(+), 87 deletions(-) delete mode 100644 tests/test_examples_readme.py create mode 100644 tests/test_readme.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4b792b9..3a4a009 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,4 +16,4 @@ jobs: run: | rye sync - name: Run pytest - run: rye run pytest --cov=. tests/test_examples_readme.py + run: rye run pytest --cov=. tests/ diff --git a/README.md b/README.md index c3580fd..0464a67 100644 --- a/README.md +++ b/README.md @@ -27,8 +27,7 @@ vq = VectorQuantize( x = torch.randn(1, 1024, 256) quantized, indices, commit_loss = vq(x) # (1, 1024, 256), (1, 1024), (1) -print(quantized.shape, indices.shape, commit_loss.shape) -#> torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1]) + ``` ## Residual VQ @@ -49,13 +48,13 @@ x = torch.randn(1, 1024, 256) quantized, indices, commit_loss = residual_vq(x) print(quantized.shape, indices.shape, commit_loss.shape) -#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8]) torch.Size([1, 8]) +# (1, 1024, 256), (1, 1024, 8), (1, 8) # if you need all the codes across the quantization layers, just pass return_all_codes = True quantized, indices, commit_loss, all_codes = residual_vq(x, return_all_codes = True) -print(all_codes.shape) -#> torch.Size([8, 1, 1024, 256]) + +# (8, 1, 1024, 256) ``` Furthermore, this paper uses Residual-VQ to construct the RQ-VAE, for generating high resolution images with more compressed codes. @@ -97,8 +96,8 @@ residual_vq = GroupedResidualVQ( x = torch.randn(1, 1024, 256) quantized, indices, commit_loss = residual_vq(x) -print(quantized.shape, indices.shape, commit_loss.shape) -#> torch.Size([1, 1024, 256]) torch.Size([2, 1, 1024, 8]) torch.Size([2, 1, 8]) + +# (1, 1024, 256), (2, 1, 1024, 8), (2, 1, 8) ``` @@ -120,8 +119,8 @@ residual_vq = ResidualVQ( x = torch.randn(1, 1024, 256) quantized, indices, commit_loss = residual_vq(x) -print(quantized.shape, indices.shape, commit_loss.shape) -#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 4]) torch.Size([1, 4]) + +# (1, 1024, 256), (1, 1024, 4), (1, 4) ``` ## Increasing codebook usage @@ -144,8 +143,8 @@ vq = VectorQuantize( x = torch.randn(1, 1024, 256) quantized, indices, commit_loss = vq(x) -print(quantized.shape, indices.shape, commit_loss.shape) -#> torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1]) + +# (1, 1024, 256), (1, 1024), (1,) ``` ### Cosine similarity @@ -164,8 +163,8 @@ vq = VectorQuantize( x = torch.randn(1, 1024, 256) quantized, indices, commit_loss = vq(x) -print(quantized.shape, indices.shape, commit_loss.shape) -#> torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1]) + +# (1, 1024, 256), (1, 1024), (1,) ``` ### Expiring stale codes @@ -184,8 +183,8 @@ vq = VectorQuantize( x = torch.randn(1, 1024, 256) quantized, indices, commit_loss = vq(x) -print(quantized.shape, indices.shape, commit_loss.shape) -#> torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1]) + +# (1, 1024, 256), (1, 1024), (1,) ``` ### Orthogonal regularization loss @@ -209,9 +208,8 @@ vq = VectorQuantize( img_fmap = torch.randn(1, 256, 32, 32) quantized, indices, loss = vq(img_fmap) # (1, 256, 32, 32), (1, 32, 32), (1,) + # loss now contains the orthogonal regularization loss with the weight as assigned -print(quantized.shape, indices.shape, loss.shape) -#> torch.Size([1, 256, 32, 32]) torch.Size([1, 32, 32]) torch.Size([1]) ``` ### Multi-headed VQ @@ -235,8 +233,8 @@ vq = VectorQuantize( img_fmap = torch.randn(1, 256, 32, 32) quantized, indices, loss = vq(img_fmap) -print(quantized.shape, indices.shape, loss.shape) -#> torch.Size([1, 256, 32, 32]) torch.Size([1, 32, 32, 8]) torch.Size([1]) + +# (1, 256, 32, 32), (1, 32, 32, 8), (1,) ``` @@ -259,8 +257,8 @@ quantizer = RandomProjectionQuantizer( x = torch.randn(1, 1024, 512) indices = quantizer(x) -print(indices.shape) -#> torch.Size([1, 1024, 16]) + +# (1, 1024, 16) ``` This repository should also automatically synchronizing the codebooks in a multi-process setting. If somehow it isn't, please open an issue. You can override whether to synchronize codebooks or not by setting `sync_codebook = True | False` @@ -285,16 +283,14 @@ Thanks goes out to [@sekstini](https://github.com/sekstini) for porting over thi import torch from vector_quantize_pytorch import FSQ -levels = [8,5,5,5] # see 4.1 and A.4.1 in the paper -quantizer = FSQ(levels) +quantizer = FSQ( + levels = [8, 5, 5, 5] +) x = torch.randn(1, 1024, 4) # 4 since there are 4 levels xhat, indices = quantizer(x) -print(xhat.shape) -#> torch.Size([1, 1024, 4]) -print(indices.shape) -#> torch.Size([1, 1024]) +# (1, 1024, 4), (1, 1024) assert torch.all(xhat == quantizer.indices_to_codes(indices)) ``` @@ -318,12 +314,12 @@ x = torch.randn(1, 1024, 256) residual_fsq.eval() quantized, indices = residual_fsq(x) -print(quantized.shape, indices.shape) -#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8]) + +# (1, 1024, 256), (1, 1024, 8) quantized_out = residual_fsq.get_output_from_indices(indices) -print(quantized_out.shape) -#> torch.Size([1, 1024, 256]) + +# (1, 1024, 256) assert torch.all(quantized == quantized_out) ``` @@ -357,8 +353,8 @@ quantizer = LFQ( image_feats = torch.randn(1, 16, 32, 32) quantized, indices, entropy_aux_loss = quantizer(image_feats, inv_temperature=100.) # you may want to experiment with temperature -print(quantized.shape, indices.shape, entropy_aux_loss.shape) -#> torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32]) torch.Size([]) + +# (1, 16, 32, 32), (1, 32, 32), () assert (quantized == quantizer.indices_to_codes(indices)).all() ``` @@ -379,13 +375,12 @@ quantizer = LFQ( seq = torch.randn(1, 32, 16) quantized, *_ = quantizer(seq) -# assert seq.shape == quantized.shape +assert seq.shape == quantized.shape -# video_feats = torch.randn(1, 16, 10, 32, 32) -# quantized, *_ = quantizer(video_feats) - -# assert video_feats.shape == quantized.shape +video_feats = torch.randn(1, 16, 10, 32, 32) +quantized, *_ = quantizer(video_feats) +assert video_feats.shape == quantized.shape ``` Or support multiple codebooks @@ -403,8 +398,8 @@ quantizer = LFQ( image_feats = torch.randn(1, 16, 32, 32) quantized, indices, entropy_aux_loss = quantizer(image_feats) -print(quantized.shape, indices.shape, entropy_aux_loss.shape) -#> torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32, 4]) torch.Size([]) + +# (1, 16, 32, 32), (1, 32, 32, 4), () assert image_feats.shape == quantized.shape assert (quantized == quantizer.indices_to_codes(indices)).all() @@ -427,12 +422,12 @@ x = torch.randn(1, 1024, 256) residual_lfq.eval() quantized, indices, commit_loss = residual_lfq(x) -print(quantized.shape, indices.shape, commit_loss.shape) -#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8]) torch.Size([8]) + +# (1, 1024, 256), (1, 1024, 8), (8) quantized_out = residual_lfq.get_output_from_indices(indices) -print(quantized_out.shape) -#> torch.Size([1, 1024, 256]) + +# (1, 1024, 256) assert torch.all(quantized == quantized_out) ``` @@ -460,8 +455,8 @@ quantizer = LatentQuantize( image_feats = torch.randn(1, 16, 32, 32) quantized, indices, loss = quantizer(image_feats) -print(quantized.shape, indices.shape, loss.shape) -#> torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32]) torch.Size([]) + +# (1, 16, 32, 32), (1, 32, 32), () assert image_feats.shape == quantized.shape assert (quantized == quantizer.indices_to_codes(indices)).all() @@ -483,13 +478,13 @@ quantizer = LatentQuantize( seq = torch.randn(1, 32, 16) quantized, *_ = quantizer(seq) -print(quantized.shape) -#> torch.Size([1, 32, 16]) + +# (1, 32, 16) video_feats = torch.randn(1, 16, 10, 32, 32) quantized, *_ = quantizer(video_feats) -print(quantized.shape) -#> torch.Size([1, 16, 10, 32, 32]) + +# (1, 16, 10, 32, 32) ``` @@ -499,23 +494,22 @@ Or support multiple codebooks import torch from vector_quantize_pytorch import LatentQuantize -levels = [4, 8, 16] -dim = 9 -num_codebooks = 3 - -model = LatentQuantize(levels, dim, num_codebooks=num_codebooks) +model = LatentQuantize( + levels = [4, 8, 16], + dim = 9, + num_codebooks = 3 +) input_tensor = torch.randn(2, 3, dim) output_tensor, indices, loss = model(input_tensor) -print(output_tensor.shape, indices.shape, loss.shape) -#> torch.Size([2, 3, 9]) torch.Size([2, 3, 3]) torch.Size([]) + +# (2, 3, 9), (2, 3, 3), () assert output_tensor.shape == input_tensor.shape assert indices.shape == (2, 3, num_codebooks) assert loss.item() >= 0 ``` - ## Citations ```bibtex diff --git a/pyproject.toml b/pyproject.toml index 4de811d..707742f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,6 @@ managed = true dev-dependencies = [ "ruff>=0.4.2", "pytest>=8.2.0", - "pytest-examples>=0.0.10", "pytest-cov>=5.0.0", ] diff --git a/tests/test_examples_readme.py b/tests/test_examples_readme.py deleted file mode 100644 index 6c034bc..0000000 --- a/tests/test_examples_readme.py +++ /dev/null @@ -1,27 +0,0 @@ -import pytest -from pytest_examples import find_examples, CodeExample, EvalExample - - -@pytest.mark.parametrize('example', find_examples('README.md'), ids=str) -def test_docstrings(example: CodeExample, eval_example: EvalExample): - """Test all examples (automatically) found in README. - - Usage, in an activated virtual env: - ```py - (.venv) pytest tests/test_examples_readme.py - ``` - - for a simple check on running the examples, and - ```py - (.venv) pytest tests/test_examples_readme.py --update-examples - ``` - - to lint and format the code in the README. - - """ - if eval_example.update_examples: - eval_example.format(example) - eval_example.lint(example) - eval_example.run_print_check(example) - else: - eval_example.run_print_check(example) diff --git a/tests/test_readme.py b/tests/test_readme.py new file mode 100644 index 0000000..15e0b94 --- /dev/null +++ b/tests/test_readme.py @@ -0,0 +1,280 @@ +import torch +import pytest + +def test_vq(): + from vector_quantize_pytorch import VectorQuantize + + vq = VectorQuantize( + dim = 256, + codebook_size = 512, # codebook size + decay = 0.8, # the exponential moving average decay, lower means the dictionary will change faster + commitment_weight = 1. # the weight on the commitment loss + ) + + x = torch.randn(1, 1024, 256) + quantized, indices, commit_loss = vq(x) + + +def test_residual_vq(): + import torch + from vector_quantize_pytorch import ResidualVQ + + residual_vq = ResidualVQ( + dim = 256, + num_quantizers = 8, # specify number of quantizers + codebook_size = 1024, # codebook size + ) + + x = torch.randn(1, 1024, 256) + + quantized, indices, commit_loss = residual_vq(x) + quantized, indices, commit_loss, all_codes = residual_vq(x, return_all_codes = True) + +def test_residual_vq2(): + import torch + from vector_quantize_pytorch import ResidualVQ + + residual_vq = ResidualVQ( + dim = 256, + num_quantizers = 8, + codebook_size = 1024, + stochastic_sample_codes = True, + sample_codebook_temp = 0.1, # temperature for stochastically sampling codes, 0 would be equivalent to non-stochastic + shared_codebook = True # whether to share the codebooks for all quantizers or not + ) + + x = torch.randn(1, 1024, 256) + quantized, indices, commit_loss = residual_vq(x) + + +def test_grouped_residual_vq(): + import torch + from vector_quantize_pytorch import GroupedResidualVQ + + residual_vq = GroupedResidualVQ( + dim = 256, + num_quantizers = 8, # specify number of quantizers + groups = 2, + codebook_size = 1024, # codebook size + ) + + x = torch.randn(1, 1024, 256) + + quantized, indices, commit_loss = residual_vq(x) + +def test_residual_vq3(): + import torch + from vector_quantize_pytorch import ResidualVQ + + residual_vq = ResidualVQ( + dim = 256, + codebook_size = 256, + num_quantizers = 4, + kmeans_init = True, # set to True + kmeans_iters = 10 # number of kmeans iterations to calculate the centroids for the codebook on init + ) + + x = torch.randn(1, 1024, 256) + quantized, indices, commit_loss = residual_vq(x) + +def test_vq_lower_codebook(): + import torch + from vector_quantize_pytorch import VectorQuantize + + vq = VectorQuantize( + dim = 256, + codebook_size = 256, + codebook_dim = 16 # paper proposes setting this to 32 or as low as 8 to increase codebook usage + ) + + x = torch.randn(1, 1024, 256) + quantized, indices, commit_loss = vq(x) + +def test_vq_cosine_sim(): + import torch + from vector_quantize_pytorch import VectorQuantize + + vq = VectorQuantize( + dim = 256, + codebook_size = 256, + use_cosine_sim = True # set this to True + ) + + x = torch.randn(1, 1024, 256) + quantized, indices, commit_loss = vq(x) + +def test_vq_expire_code(): + import torch + from vector_quantize_pytorch import VectorQuantize + + vq = VectorQuantize( + dim = 256, + codebook_size = 512, + threshold_ema_dead_code = 2 # should actively replace any codes that have an exponential moving average cluster size less than 2 + ) + + x = torch.randn(1, 1024, 256) + quantized, indices, commit_loss = vq(x) + +def test_vq_multiheaded(): + import torch + from vector_quantize_pytorch import VectorQuantize + + vq = VectorQuantize( + dim = 256, + codebook_dim = 32, # a number of papers have shown smaller codebook dimension to be acceptable + heads = 8, # number of heads to vector quantize, codebook shared across all heads + separate_codebook_per_head = True, # whether to have a separate codebook per head. False would mean 1 shared codebook + codebook_size = 8196, + accept_image_fmap = True + ) + + img_fmap = torch.randn(1, 256, 32, 32) + quantized, indices, loss = vq(img_fmap) + +def test_rq(): + import torch + from vector_quantize_pytorch import RandomProjectionQuantizer + + quantizer = RandomProjectionQuantizer( + dim = 512, # input dimensions + num_codebooks = 16, # in USM, they used up to 16 for 5% gain + codebook_dim = 256, # codebook dimension + codebook_size = 1024 # codebook size + ) + + x = torch.randn(1, 1024, 512) + indices = quantizer(x) + +def test_fsq(): + import torch + from vector_quantize_pytorch import FSQ + + levels = [8,5,5,5] # see 4.1 and A.4.1 in the paper + quantizer = FSQ(levels) + + x = torch.randn(1, 1024, 4) # 4 since there are 4 levels + xhat, indices = quantizer(x) + + assert torch.all(xhat == quantizer.indices_to_codes(indices)) + +def test_rfsq(): + import torch + from vector_quantize_pytorch import ResidualFSQ + + residual_fsq = ResidualFSQ( + dim = 256, + levels = [8, 5, 5, 3], + num_quantizers = 8 + ) + + x = torch.randn(1, 1024, 256) + + residual_fsq.eval() + + quantized, indices = residual_fsq(x) + + quantized_out = residual_fsq.get_output_from_indices(indices) + + assert torch.all(quantized == quantized_out) + +def test_lfq(): + import torch + from vector_quantize_pytorch import LFQ + + # you can specify either dim or codebook_size + # if both specified, will be validated against each other + + quantizer = LFQ( + codebook_size = 65536, # codebook size, must be a power of 2 + dim = 16, # this is the input feature dimension, defaults to log2(codebook_size) if not defined + entropy_loss_weight = 0.1, # how much weight to place on entropy loss + diversity_gamma = 1. # within entropy loss, how much weight to give to diversity of codes, taken from https://arxiv.org/abs/1911.05894 + ) + + image_feats = torch.randn(1, 16, 32, 32) + + quantized, indices, entropy_aux_loss = quantizer(image_feats, inv_temperature=100.) # you may want to experiment with temperature + + assert (quantized == quantizer.indices_to_codes(indices)).all() + + +def test_lfq_video(): + import torch + from vector_quantize_pytorch import LFQ + + quantizer = LFQ( + codebook_size = 65536, + dim = 16, + entropy_loss_weight = 0.1, + diversity_gamma = 1. + ) + + seq = torch.randn(1, 32, 16) + quantized, *_ = quantizer(seq) + + assert seq.shape == quantized.shape + + video_feats = torch.randn(1, 16, 10, 32, 32) + quantized, *_ = quantizer(video_feats) + + assert video_feats.shape == quantized.shape + + +def test_lfq2(): + import torch + from vector_quantize_pytorch import LFQ + + quantizer = LFQ( + codebook_size = 4096, + dim = 16, + num_codebooks = 4 # 4 codebooks, total codebook dimension is log2(4096) * 4 + ) + + image_feats = torch.randn(1, 16, 32, 32) + + quantized, indices, entropy_aux_loss = quantizer(image_feats) + + assert image_feats.shape == quantized.shape + assert (quantized == quantizer.indices_to_codes(indices)).all() + +def test_rflq(): + import torch + from vector_quantize_pytorch import ResidualLFQ + + residual_lfq = ResidualLFQ( + dim = 256, + codebook_size = 256, + num_quantizers = 8 + ) + + x = torch.randn(1, 1024, 256) + + residual_lfq.eval() + + quantized, indices, commit_loss = residual_lfq(x) + + quantized_out = residual_lfq.get_output_from_indices(indices) + + assert torch.all(quantized == quantized_out) + +def test_latent_q(): + import torch + from vector_quantize_pytorch import LatentQuantize + + # you can specify either dim or codebook_size + # if both specified, will be validated against each other + + quantizer = LatentQuantize( + levels = [5, 5, 8], # number of levels per codebook dimension + dim = 16, # input dim + commitment_loss_weight=0.1, + quantization_loss_weight=0.1, + ) + + image_feats = torch.randn(1, 16, 32, 32) + + quantized, indices, loss = quantizer(image_feats) + + assert image_feats.shape == quantized.shape + assert (quantized == quantizer.indices_to_codes(indices)).all()