diff --git a/audiolm_pytorch/soundstream.py b/audiolm_pytorch/soundstream.py index 55c6cd5..7de2907 100644 --- a/audiolm_pytorch/soundstream.py +++ b/audiolm_pytorch/soundstream.py @@ -19,7 +19,7 @@ from vector_quantize_pytorch import ( GroupedResidualVQ, - ResidualLFQ + GroupedResidualLFQ ) from local_attention import LocalMHA @@ -518,12 +518,11 @@ def __init__( self.rq_groups = rq_groups if use_lookup_free_quantizer: - assert rq_groups == 1, 'grouped residual LFQ not implemented yet' - - self.rq = ResidualLFQ( + self.rq = GroupedResidualLFQ( dim = codebook_dim, num_quantizers = rq_num_quantizers, codebook_size = codebook_size, + groups = rq_groups, quantize_dropout = True, quantize_dropout_cutoff_index = quantize_dropout_cutoff_index, **rq_kwargs diff --git a/audiolm_pytorch/version.py b/audiolm_pytorch/version.py index bcd8d54..bb64aa4 100644 --- a/audiolm_pytorch/version.py +++ b/audiolm_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.6.0' +__version__ = '1.6.1' diff --git a/setup.py b/setup.py index d8b9bd9..18f8cf9 100644 --- a/setup.py +++ b/setup.py @@ -34,7 +34,7 @@ 'torchaudio', 'transformers', 'tqdm', - 'vector-quantize-pytorch>=1.10.2' + 'vector-quantize-pytorch>=1.10.4' ], classifiers=[ 'Development Status :: 4 - Beta',