Skip to content

Commit

Permalink
sanity check, clearly demonstrate the ability to tokenize audio using…
Browse files Browse the repository at this point in the history
… residual LFQ
  • Loading branch information
lucidrains committed Oct 28, 2023
1 parent c4f2ef5 commit 0124c4b
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 5 deletions.
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,34 @@ trainer.train()

# after a lot of training, you can test the autoencoding as so

soundstream.eval() # your soundstream must be in eval mode, to disable the residual VQ's quantizer dropout, which is only needed during training

audio = torch.randn(10080).cuda()
recons = soundstream(audio, return_recons_only = True) # (1, 10080) - 1 channel
```

Your trained `SoundStream` can be used as a generic tokenizer for audio

```python

soundstream.eval()

audio = torch.randn(1, 512 * 320)

codes = soundstream(audio, return_codes_only = True)

# you can now train anything with the codebook ids

recon_audio_from_codes = soundstream.decode_from_codebook_indices(codes)

# sanity check

assert torch.allclose(
recon_audio_from_codes,
soundstream(audio, return_recons_only = True)
)
```

You can also use soundstreams that are specific to `AudioLM` and `MusicLM` by importing `AudioLMSoundStream` and `MusicLMSoundStream` respectively

```python
Expand Down
19 changes: 15 additions & 4 deletions audiolm_pytorch/soundstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,8 @@ def __init__(

self.rq_groups = rq_groups

self.use_lookup_free_quantizer = use_lookup_free_quantizer

if use_lookup_free_quantizer:
self.rq = GroupedResidualLFQ(
dim = codebook_dim,
Expand Down Expand Up @@ -630,18 +632,22 @@ def configs(self):
return pickle.loads(self._configs)

def decode_from_codebook_indices(self, quantized_indices):
    """Decode a waveform directly from quantizer codebook indices.

    quantized_indices: LongTensor of code ids, either already grouped as
        (g, b, n, q), or in the flat (b, n, g * q) layout produced by
        ``forward(..., return_encoded = True)`` — the flat form is split
        back into per-group indices using ``self.rq_groups``.

    Returns the reconstructed audio from ``self.decode``.
    """
    # codebook indices must be integer ids, not embedding vectors
    assert quantized_indices.dtype == torch.long

    # accept the flat layout and regroup it for the grouped residual quantizer
    if quantized_indices.ndim == 3:
        quantized_indices = rearrange(quantized_indices, 'b n (g q) -> g b n q', g = self.rq_groups)

    # let the quantizer itself map indices -> codes and sum the residual
    # levels / concat the groups (works for both LFQ and VQ backends)
    x = self.rq.get_output_from_indices(quantized_indices)

    return self.decode(x)

def decode(self, x, quantize = False):
    """Decode latents of shape (b, n, d) back into a waveform.

    quantize: if True, pass the latents through the residual quantizer
        first (useful when decoding raw encoder output rather than
        already-quantized codes).
    """
    if quantize:
        # self.rq returns (quantized, indices, commit_loss); keep only the latents
        x, *_ = self.rq(x)

    # decoder attention is optional — only apply it when configured,
    # otherwise a None attribute would be called
    if exists(self.decoder_attn):
        x = self.decoder_attn(x)

    # decoder is convolutional and expects channels-first (b, c, n)
    x = rearrange(x, 'b n c -> b c n')
    return self.decoder(x)

Expand All @@ -666,6 +672,7 @@ def init_and_load_from(cls, path, strict = True):
config = pickle.loads(pkg['config'])
soundstream = cls(**config)
soundstream.load(path, strict = strict)
soundstream.eval()
return soundstream

def load(self, path, strict = True):
Expand Down Expand Up @@ -735,6 +742,7 @@ def forward(
target = None,
is_denoising = None, # if you want to learn film conditioners that teach the soundstream to denoise - target would need to be passed in above
return_encoded = False,
return_codes_only = False,
return_discr_loss = False,
return_discr_losses_separately = False,
return_loss_breakdown = False,
Expand Down Expand Up @@ -767,6 +775,9 @@ def forward(

x, indices, commit_loss = self.rq(x)

if return_codes_only:
return indices

if return_encoded:
indices = rearrange(indices, 'g b n q -> b n (g q)')
return x, indices, commit_loss
Expand Down
2 changes: 1 addition & 1 deletion audiolm_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.6.4'
__version__ = '1.6.5'

0 comments on commit 0124c4b

Please sign in to comment.