
Commit

received emails from confused researchers re: pytest-examples this morning. get rid of it

lucidrains committed May 10, 2024
1 parent 85f03c5 commit d9967be
Showing 5 changed files with 333 additions and 87 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -16,4 +16,4 @@ jobs:
run: |
rye sync
- name: Run pytest
run: rye run pytest --cov=. tests/test_examples_readme.py
run: rye run pytest --cov=. tests/
110 changes: 52 additions & 58 deletions README.md
@@ -27,8 +27,7 @@ vq = VectorQuantize(

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x) # (1, 1024, 256), (1, 1024), (1)
print(quantized.shape, indices.shape, commit_loss.shape)
#> torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])

```

## Residual VQ
@@ -49,13 +48,13 @@ x = torch.randn(1, 1024, 256)

quantized, indices, commit_loss = residual_vq(x)
print(quantized.shape, indices.shape, commit_loss.shape)
#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8]) torch.Size([1, 8])
# (1, 1024, 256), (1, 1024, 8), (1, 8)

# if you need all the codes across the quantization layers, just pass return_all_codes = True

quantized, indices, commit_loss, all_codes = residual_vq(x, return_all_codes = True)
print(all_codes.shape)
#> torch.Size([8, 1, 1024, 256])

# (8, 1, 1024, 256)
```

Furthermore, <a href="https://arxiv.org/abs/2203.01941">this paper</a> uses Residual-VQ to construct the RQ-VAE, for generating high resolution images with more compressed codes.
@@ -97,8 +96,8 @@ residual_vq = GroupedResidualVQ(
x = torch.randn(1, 1024, 256)

quantized, indices, commit_loss = residual_vq(x)
print(quantized.shape, indices.shape, commit_loss.shape)
#> torch.Size([1, 1024, 256]) torch.Size([2, 1, 1024, 8]) torch.Size([2, 1, 8])

# (1, 1024, 256), (2, 1, 1024, 8), (2, 1, 8)

```

@@ -120,8 +119,8 @@ residual_vq = ResidualVQ(

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)
print(quantized.shape, indices.shape, commit_loss.shape)
#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 4]) torch.Size([1, 4])

# (1, 1024, 256), (1, 1024, 4), (1, 4)
```

## Increasing codebook usage
@@ -144,8 +143,8 @@ vq = VectorQuantize(

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
print(quantized.shape, indices.shape, commit_loss.shape)
#> torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])

# (1, 1024, 256), (1, 1024), (1,)
```

### Cosine similarity
@@ -164,8 +163,8 @@ vq = VectorQuantize(

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
print(quantized.shape, indices.shape, commit_loss.shape)
#> torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])

# (1, 1024, 256), (1, 1024), (1,)
```

### Expiring stale codes
@@ -184,8 +183,8 @@ vq = VectorQuantize(

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
print(quantized.shape, indices.shape, commit_loss.shape)
#> torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])

# (1, 1024, 256), (1, 1024), (1,)
```

### Orthogonal regularization loss
@@ -209,9 +208,8 @@ vq = VectorQuantize(

img_fmap = torch.randn(1, 256, 32, 32)
quantized, indices, loss = vq(img_fmap) # (1, 256, 32, 32), (1, 32, 32), (1,)

# loss now contains the orthogonal regularization loss with the weight as assigned
print(quantized.shape, indices.shape, loss.shape)
#> torch.Size([1, 256, 32, 32]) torch.Size([1, 32, 32]) torch.Size([1])
```

### Multi-headed VQ
@@ -235,8 +233,8 @@ vq = VectorQuantize(

img_fmap = torch.randn(1, 256, 32, 32)
quantized, indices, loss = vq(img_fmap)
print(quantized.shape, indices.shape, loss.shape)
#> torch.Size([1, 256, 32, 32]) torch.Size([1, 32, 32, 8]) torch.Size([1])

# (1, 256, 32, 32), (1, 32, 32, 8), (1,)

```

@@ -259,8 +257,8 @@ quantizer = RandomProjectionQuantizer(

x = torch.randn(1, 1024, 512)
indices = quantizer(x)
print(indices.shape)
#> torch.Size([1, 1024, 16])

# (1, 1024, 16)
```

This repository should also automatically synchronize the codebooks in a multi-process setting. If somehow it isn't, please open an issue. You can override whether to synchronize codebooks or not by setting `sync_codebook = True | False`.
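
A minimal sketch of overriding this, assuming `sync_codebook` is accepted as a constructor keyword of `VectorQuantize` (the constructor signature is not shown in this diff):

```python
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    sync_codebook = False   # assumed keyword, per the note above; skip cross-process codebook synchronization
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x) # (1, 1024, 256), (1, 1024), (1,)
```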
@@ -285,16 +283,14 @@ Thanks goes out to [@sekstini](https://github.com/sekstini) for porting over thi
import torch
from vector_quantize_pytorch import FSQ

levels = [8,5,5,5] # see 4.1 and A.4.1 in the paper
quantizer = FSQ(levels)
quantizer = FSQ(
levels = [8, 5, 5, 5]
)

x = torch.randn(1, 1024, 4) # 4 since there are 4 levels
xhat, indices = quantizer(x)

print(xhat.shape)
#> torch.Size([1, 1024, 4])
print(indices.shape)
#> torch.Size([1, 1024])
# (1, 1024, 4), (1, 1024)

assert torch.all(xhat == quantizer.indices_to_codes(indices))
```
@@ -318,12 +314,12 @@ x = torch.randn(1, 1024, 256)
residual_fsq.eval()

quantized, indices = residual_fsq(x)
print(quantized.shape, indices.shape)
#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8])

# (1, 1024, 256), (1, 1024, 8)

quantized_out = residual_fsq.get_output_from_indices(indices)
print(quantized_out.shape)
#> torch.Size([1, 1024, 256])

# (1, 1024, 256)

assert torch.all(quantized == quantized_out)
```
@@ -357,8 +353,8 @@ quantizer = LFQ(
image_feats = torch.randn(1, 16, 32, 32)

quantized, indices, entropy_aux_loss = quantizer(image_feats, inv_temperature=100.) # you may want to experiment with temperature
print(quantized.shape, indices.shape, entropy_aux_loss.shape)
#> torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32]) torch.Size([])

# (1, 16, 32, 32), (1, 32, 32), ()

assert (quantized == quantizer.indices_to_codes(indices)).all()
```
@@ -379,13 +375,12 @@ quantizer = LFQ(
seq = torch.randn(1, 32, 16)
quantized, *_ = quantizer(seq)

# assert seq.shape == quantized.shape
assert seq.shape == quantized.shape

# video_feats = torch.randn(1, 16, 10, 32, 32)
# quantized, *_ = quantizer(video_feats)

# assert video_feats.shape == quantized.shape
video_feats = torch.randn(1, 16, 10, 32, 32)
quantized, *_ = quantizer(video_feats)

assert video_feats.shape == quantized.shape
```

Or support multiple codebooks
@@ -403,8 +398,8 @@ quantizer = LFQ(
image_feats = torch.randn(1, 16, 32, 32)

quantized, indices, entropy_aux_loss = quantizer(image_feats)
print(quantized.shape, indices.shape, entropy_aux_loss.shape)
#> torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32, 4]) torch.Size([])

# (1, 16, 32, 32), (1, 32, 32, 4), ()

assert image_feats.shape == quantized.shape
assert (quantized == quantizer.indices_to_codes(indices)).all()
@@ -427,12 +422,12 @@ x = torch.randn(1, 1024, 256)
residual_lfq.eval()

quantized, indices, commit_loss = residual_lfq(x)
print(quantized.shape, indices.shape, commit_loss.shape)
#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8]) torch.Size([8])

# (1, 1024, 256), (1, 1024, 8), (8)

quantized_out = residual_lfq.get_output_from_indices(indices)
print(quantized_out.shape)
#> torch.Size([1, 1024, 256])

# (1, 1024, 256)

assert torch.all(quantized == quantized_out)
```
@@ -460,8 +455,8 @@ quantizer = LatentQuantize(
image_feats = torch.randn(1, 16, 32, 32)

quantized, indices, loss = quantizer(image_feats)
print(quantized.shape, indices.shape, loss.shape)
#> torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32]) torch.Size([])

# (1, 16, 32, 32), (1, 32, 32), ()

assert image_feats.shape == quantized.shape
assert (quantized == quantizer.indices_to_codes(indices)).all()
@@ -483,13 +478,13 @@ quantizer = LatentQuantize(

seq = torch.randn(1, 32, 16)
quantized, *_ = quantizer(seq)
print(quantized.shape)
#> torch.Size([1, 32, 16])

# (1, 32, 16)

video_feats = torch.randn(1, 16, 10, 32, 32)
quantized, *_ = quantizer(video_feats)
print(quantized.shape)
#> torch.Size([1, 16, 10, 32, 32])

# (1, 16, 10, 32, 32)

```

@@ -499,23 +494,22 @@ Or support multiple codebooks
import torch
from vector_quantize_pytorch import LatentQuantize

levels = [4, 8, 16]
dim = 9
num_codebooks = 3

model = LatentQuantize(levels, dim, num_codebooks=num_codebooks)
model = LatentQuantize(
levels = [4, 8, 16],
dim = 9,
num_codebooks = 3
)

input_tensor = torch.randn(2, 3, dim)
output_tensor, indices, loss = model(input_tensor)
print(output_tensor.shape, indices.shape, loss.shape)
#> torch.Size([2, 3, 9]) torch.Size([2, 3, 3]) torch.Size([])

# (2, 3, 9), (2, 3, 3), ()

assert output_tensor.shape == input_tensor.shape
assert indices.shape == (2, 3, num_codebooks)
assert loss.item() >= 0
```


## Citations

```bibtex
1 change: 0 additions & 1 deletion pyproject.toml
@@ -44,7 +44,6 @@ managed = true
dev-dependencies = [
"ruff>=0.4.2",
"pytest>=8.2.0",
"pytest-examples>=0.0.10",
"pytest-cov>=5.0.0",
]

27 changes: 0 additions & 27 deletions tests/test_examples_readme.py

This file was deleted.
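
For context, a file with this name typically uses the pytest-examples plugin to execute the README's fenced code blocks and compare their printed output against the `#>` lines, which is the convention this commit removes from the README. A hypothetical sketch of that pattern (not the contents of the deleted file) might look like:

```python
# hypothetical sketch of the usual pytest-examples pattern, not the deleted file itself
import pytest
from pytest_examples import find_examples, CodeExample, EvalExample


@pytest.mark.parametrize('example', find_examples('README.md'), ids=str)
def test_readme_examples(example: CodeExample, eval_example: EvalExample):
    # run the README code block and check printed output against its `#>` lines
    eval_example.run_print_check(example)
```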
