From 5e067d85257adf3f280f0963d5f7b59e7e084f6e Mon Sep 17 00:00:00 2001
From: Yao-Yuan Yang
Date: Thu, 24 Jun 2021 21:15:54 +0000
Subject: [PATCH 1/5] Add pretrained wavernn

---
 torchaudio/models/__init__.py |  3 ++-
 torchaudio/models/_utils.py   |  8 ++++++
 torchaudio/models/wavernn.py  | 46 ++++++++++++++++++++++++++++++++++-
 3 files changed, 55 insertions(+), 2 deletions(-)
 create mode 100644 torchaudio/models/_utils.py

diff --git a/torchaudio/models/__init__.py b/torchaudio/models/__init__.py
index 843f6d15f5..ae4a2993cd 100644
--- a/torchaudio/models/__init__.py
+++ b/torchaudio/models/__init__.py
@@ -1,5 +1,5 @@
 from .wav2letter import Wav2Letter
-from .wavernn import WaveRNN
+from .wavernn import WaveRNN, wavernn_10k_epochs_8bits_ljspeech
 from .conv_tasnet import ConvTasNet
 from .deepspeech import DeepSpeech
 from .wav2vec2 import (
@@ -13,6 +13,7 @@ __all__ = [
     'Wav2Letter',
     'WaveRNN',
+    'wavernn_10k_epochs_8bits_ljspeech',
     'ConvTasNet',
     'DeepSpeech',
     'Wav2Vec2Model',
diff --git a/torchaudio/models/_utils.py b/torchaudio/models/_utils.py
new file mode 100644
index 0000000000..3f4800dc75
--- /dev/null
+++ b/torchaudio/models/_utils.py
@@ -0,0 +1,8 @@
+try:
+    from torch.hub import load_state_dict_from_url
+except ImportError:
+    from torch.utils.model_zoo import load_url as load_state_dict_from_url
+
+__all__ = [
+    'load_state_dict_from_url',
+]
diff --git a/torchaudio/models/wavernn.py b/torchaudio/models/wavernn.py
index 89c1e9d430..065657b3c5 100644
--- a/torchaudio/models/wavernn.py
+++ b/torchaudio/models/wavernn.py
@@ -1,18 +1,28 @@
-from typing import List, Tuple
+from typing import List, Tuple, Any
 
 import torch
 from torch import Tensor
 from torch import nn
 
+from ._utils import load_state_dict_from_url
+
+
 __all__ = [
     "ResBlock",
     "MelResNet",
     "Stretch2d",
     "UpsampleNetwork",
     "WaveRNN",
+    "wavernn_10k_epochs_8bits_ljspeech",
 ]
 
+
+model_urls = {
+    'wavernn_10k_epochs_8bits_ljspeech': 'https://download.pytorch.org/models/'
+                                         'audio/wavernn_10k_epochs_8bits_ljspeech.pth',
+}
+
+
 class ResBlock(nn.Module):
     r"""ResNet block based on *Efficient Neural Audio Synthesis* [:footcite:`kalchbrenner2018efficient`].
@@ -324,3 +334,37 @@ def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:

        # bring back channel dimension
        return x.unsqueeze(1)


def _wavernn(arch: str, pretrained: bool, progress: bool, **kwargs: Any) -> WaveRNN:
    model = WaveRNN(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def wavernn_10k_epochs_8bits_ljspeech(pretrained: bool = True, progress: bool = True, **kwargs: Any) -> WaveRNN:
    r"""WaveRNN model trained for 10k epochs on 8-bit depth waveforms from the LJSpeech dataset.
    The model is trained using the default parameters and code of examples/pipeline_wavernn/main.py.

    Args:
        pretrained (bool): If True, returns a model pre-trained on LJSpeech.
        progress (bool): If True, displays a progress bar of the download to stderr.
    """
    n_bits = 8
    configs = {
        'upsample_scales': [5, 5, 11],
        'n_classes': 2 ** n_bits,
        'hop_length': 275,
        'n_res_block': 10,
        'n_rnn': 512,
        'n_fc': 512,
        'kernel_size': 5,
        'n_freq': 80,
        'n_hidden': 128,
        'n_output': 128
    }
    configs.update(kwargs)
    return _wavernn("wavernn_10k_epochs_8bits_ljspeech", pretrained=pretrained, progress=progress, **configs)
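
With patch 1 applied, the new factory function can be smoke-tested. A minimal sketch (not part of the patch; it assumes the checkpoint URL registered above is reachable, and the tensor shapes follow the WaveRNN forward docstring):

    import torch
    from torchaudio.models import wavernn_10k_epochs_8bits_ljspeech

    model = wavernn_10k_epochs_8bits_ljspeech(pretrained=True).eval()

    # With kernel_size=5 and hop_length=275, n_time=15 mel frames pair with
    # (15 - 5 + 1) * 275 = 3025 waveform samples.
    waveform = torch.randn(1, 1, 3025)
    specgram = torch.randn(1, 1, 80, 15)
    with torch.no_grad():
        out = model(waveform, specgram)  # (1, 1, 3025, n_classes=256)
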
From 330e3295b8473714e1ca68cca7f828022a4b871f Mon Sep 17 00:00:00 2001
From: Yao-Yuan Yang
Date: Thu, 8 Jul 2021 18:33:14 +0000
Subject: [PATCH 2/5] Refactor the pretrained wavernn interface

---
 torchaudio/models/__init__.py |  4 +--
 torchaudio/models/wavernn.py  | 67 ++++++++++++++++++-----------------
 2 files changed, 37 insertions(+), 34 deletions(-)

diff --git a/torchaudio/models/__init__.py b/torchaudio/models/__init__.py
index ae4a2993cd..1a1a85d874 100644
--- a/torchaudio/models/__init__.py
+++ b/torchaudio/models/__init__.py
@@ -1,5 +1,5 @@
 from .wav2letter import Wav2Letter
-from .wavernn import WaveRNN, wavernn_10k_epochs_8bits_ljspeech
+from .wavernn import WaveRNN, get_pretrained_wavernn
 from .conv_tasnet import ConvTasNet
 from .deepspeech import DeepSpeech
 from .wav2vec2 import (
@@ -13,7 +13,7 @@ __all__ = [
     'Wav2Letter',
     'WaveRNN',
-    'wavernn_10k_epochs_8bits_ljspeech',
+    'get_pretrained_wavernn',
     'ConvTasNet',
     'DeepSpeech',
     'Wav2Vec2Model',
diff --git a/torchaudio/models/wavernn.py b/torchaudio/models/wavernn.py
index 065657b3c5..78c93fa0ce 100644
--- a/torchaudio/models/wavernn.py
+++ b/torchaudio/models/wavernn.py
@@ -1,4 +1,4 @@
-from typing import List, Tuple, Any
+from typing import List, Tuple, Dict, Any

import torch
from torch import Tensor
from torch import nn
@@ -13,13 +13,26 @@
    "Stretch2d",
    "UpsampleNetwork",
    "WaveRNN",
-    "wavernn_10k_epochs_8bits_ljspeech",
+    "get_pretrained_wavernn",
]


model_config_and_urls: Dict[str, Tuple[str, Dict[str, Any]]] = {
    'wavernn_10k_epochs_8bits_ljspeech': (
        'https://download.pytorch.org/models/audio/wavernn_10k_epochs_8bits_ljspeech.pth',
        {
            'upsample_scales': [5, 5, 11],
            'n_classes': 2 ** 8,  # n_bits = 8
            'hop_length': 275,
            'n_res_block': 10,
            'n_rnn': 512,
            'n_fc': 512,
            'kernel_size': 5,
            'n_freq': 80,
            'n_hidden': 128,
            'n_output': 128
        }
    )
}
@@ -336,35 +349,25 @@ def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:

        # bring back channel dimension
        return x.unsqueeze(1)


def get_pretrained_wavernn(checkpoint_name: str, progress: bool = True) -> WaveRNN:
    r"""Get pretrained WaveRNN model.

    Here are the available checkpoints:

    - wavernn_10k_epochs_8bits_ljspeech

      WaveRNN model trained for 10k epochs on 8-bit depth waveforms from the LJSpeech dataset.
      The model is trained using the default parameters and code of examples/pipeline_wavernn/main.py.

    Args:
        checkpoint_name (str): The name of the checkpoint to load.
        progress (bool): If True, displays a progress bar of the download to stderr.
    """
    if checkpoint_name in model_config_and_urls:
        url, configs = model_config_and_urls[checkpoint_name]
        model = WaveRNN(**configs)
        state_dict = load_state_dict_from_url(url, progress=progress)
        model.load_state_dict(state_dict)
        return model
    else:
        raise ValueError("The model_name `{}` is not supported.".format(checkpoint_name))
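
After this refactor, every pretrained checkpoint goes through one name-based entry point, which keeps the public surface stable as checkpoints are added. A usage sketch (not part of the patch):

    from torchaudio.models import get_pretrained_wavernn

    model = get_pretrained_wavernn("wavernn_10k_epochs_8bits_ljspeech").eval()

    # Unknown names fail fast, before any download is attempted.
    try:
        get_pretrained_wavernn("no_such_checkpoint")
    except ValueError as err:
        print(err)  # The model_name `no_such_checkpoint` is not supported.
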
From 8f0466d5ed96617b5d3fe571650ae9f427449922 Mon Sep 17 00:00:00 2001
From: Yao-Yuan Yang
Date: Mon, 12 Jul 2021 18:25:08 +0000
Subject: [PATCH 3/5] Fix a few coding style issues

---
 torchaudio/models/_utils.py  |  8 --------
 torchaudio/models/wavernn.py | 24 +++++++++++++-----------
 2 files changed, 13 insertions(+), 19 deletions(-)
 delete mode 100644 torchaudio/models/_utils.py

diff --git a/torchaudio/models/_utils.py b/torchaudio/models/_utils.py
deleted file mode 100644
index 3f4800dc75..0000000000
--- a/torchaudio/models/_utils.py
+++ /dev/null
@@ -1,8 +0,0 @@
-try:
-    from torch.hub import load_state_dict_from_url
-except ImportError:
-    from torch.utils.model_zoo import load_url as load_state_dict_from_url
-
-__all__ = [
-    'load_state_dict_from_url',
-]
diff --git a/torchaudio/models/wavernn.py b/torchaudio/models/wavernn.py
index 78c93fa0ce..f5142382b6 100644
--- a/torchaudio/models/wavernn.py
+++ b/torchaudio/models/wavernn.py
@@ -3,8 +3,10 @@
import torch
from torch import Tensor
from torch import nn
try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url


__all__ = [
@@ -17,7 +19,7 @@
    "WaveRNN",
    "get_pretrained_wavernn",
]

-model_config_and_urls: Dict[str, Tuple[str, Dict[str, Any]]] = {
+_MODEL_CONFIG_AND_URLS: Dict[str, Tuple[str, Dict[str, Any]]] = {
    'wavernn_10k_epochs_8bits_ljspeech': (
@@ -363,11 +365,11 @@ def get_pretrained_wavernn(checkpoint_name: str, progress: bool = True) -> WaveR
        checkpoint_name (str): The name of the checkpoint to load.
        progress (bool): If True, displays a progress bar of the download to stderr.
    """
    if checkpoint_name not in _MODEL_CONFIG_AND_URLS:
        raise ValueError("The checkpoint_name `{}` is not supported.".format(checkpoint_name))

    url, configs = _MODEL_CONFIG_AND_URLS[checkpoint_name]
    model = WaveRNN(**configs)
    state_dict = load_state_dict_from_url(url, progress=progress)
    model.load_state_dict(state_dict)
    return model
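
With the registry renamed to a private constant, adding a future checkpoint stays a one-entry change and `get_pretrained_wavernn` needs no modification. A sketch of what a second entry could look like (the name and URL below are purely hypothetical, shown only to illustrate the registry shape):

    _MODEL_CONFIG_AND_URLS: Dict[str, Tuple[str, Dict[str, Any]]] = {
        'wavernn_10k_epochs_8bits_ljspeech': (
            'https://download.pytorch.org/models/audio/wavernn_10k_epochs_8bits_ljspeech.pth',
            {...},  # config as above
        ),
        # Hypothetical future entry:
        'wavernn_hypothetical_checkpoint': (
            'https://download.pytorch.org/models/audio/wavernn_hypothetical_checkpoint.pth',
            {'upsample_scales': [5, 5, 11], 'n_classes': 2 ** 8, 'hop_length': 275,
             'n_res_block': 10, 'n_rnn': 512, 'n_fc': 512, 'kernel_size': 5,
             'n_freq': 80, 'n_hidden': 128, 'n_output': 128},
        ),
    }
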
""" - if checkpoint_name in model_config_and_urls: - url, configs = model_config_and_urls[checkpoint_name] - model = WaveRNN(**configs) - state_dict = load_state_dict_from_url(url, progress=progress) - model.load_state_dict(state_dict) - return model - else: - raise ValueError("The model_name `{}` is not supported.".format(checkpoint_name)) + if checkpoint_name not in _MODEL_CONFIG_AND_URLS: + raise ValueError("The checkpoint_name `{}` is not supported.".format(checkpoint_name)) + + url, configs = _MODEL_CONFIG_AND_URLS[checkpoint_name] + model = WaveRNN(**configs) + state_dict = load_state_dict_from_url(url, progress=progress) + model.load_state_dict(state_dict) + return model From aacbf9c6a93a3f0df0657610ae51eee088783a94 Mon Sep 17 00:00:00 2001 From: Yao-Yuan Yang Date: Mon, 12 Jul 2021 21:51:58 +0000 Subject: [PATCH 4/5] add pretrained WaveRNN to docs --- docs/source/models.rst | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index 39e162baa0..4c6cc0c698 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -88,8 +88,15 @@ WaveRNN .. automethod:: forward +Factory Functions +----------------- + +get_pretrained_wavernn +---------------------- + +.. autofunction:: get_pretrained_wavernn + References ~~~~~~~~~~ .. footbibliography:: - From d17fc54b9a19d8e2944b2dac1d8de80258594104 Mon Sep 17 00:00:00 2001 From: Yao-Yuan Yang Date: Tue, 13 Jul 2021 19:53:45 +0000 Subject: [PATCH 5/5] Add wavernn inference function --- examples/pipeline_wavernn/inference.py | 272 +++++++++++++++++++++++++ 1 file changed, 272 insertions(+) create mode 100644 examples/pipeline_wavernn/inference.py diff --git a/examples/pipeline_wavernn/inference.py b/examples/pipeline_wavernn/inference.py new file mode 100644 index 0000000000..ae0874afbb --- /dev/null +++ b/examples/pipeline_wavernn/inference.py @@ -0,0 +1,272 @@ +import random +from typing import List + +import torch +from torch import Tensor +import torch.nn.functional as F +import torchaudio +from torchaudio.transforms import MelSpectrogram +from torchaudio.models import get_pretrained_wavernn +from torchaudio.datasets import LJSPEECH +import numpy as np +from tqdm import tqdm + +from processing import NormalizeDB + + +def fold_with_overlap(x, target, overlap): + r'''Fold the tensor with overlap for quick batched inference. + Overlap will be used for crossfading in xfade_and_unfold() + + Args: + x (tensor): Upsampled conditioning features. + shape=(1, timesteps, features) + target (int): Target timesteps for each index of batch + overlap (int): Timesteps for both xfade and rnn warmup + Return: + (tensor) : shape=(num_folds, target + 2 * overlap, features) + Details: + x = [[h1, h2, ... 


def xfade_and_unfold(y: Tensor, target: int, overlap: int) -> Tensor:
    r'''Applies a crossfade and unfolds into a 1d array.

    Args:
        y (Tensor): Batched sequences of audio samples
            shape=(num_folds, target + 2 * overlap)
        target (int): Target timesteps for each fold (recomputed from the shape of ``y``)
        overlap (int): Timesteps for both xfade and rnn warmup

    Returns:
        (Tensor): audio samples in a 1d array
            shape=(total_len,)

    Details:
        y = [[seq1],
             [seq2],
             [seq3]]

        Apply a gain envelope at both ends of the sequences

        y = [[seq1_in, seq1_target, seq1_out],
             [seq2_in, seq2_target, seq2_out],
             [seq3_in, seq3_target, seq3_out]]

        Stagger and add up the groups of samples:

        [seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...]
    '''

    num_folds, length = y.shape
    target = length - 2 * overlap
    total_len = num_folds * (target + overlap) + overlap

    # Need some silence for the rnn warmup
    silence_len = overlap // 2
    fade_len = overlap - silence_len
    silence = torch.zeros((silence_len,), dtype=y.dtype, device=y.device)
    linear = torch.ones((silence_len,), dtype=y.dtype, device=y.device)

    # Equal power crossfade: fade_in ** 2 + fade_out ** 2 == 1 at every point
    t = torch.linspace(-1, 1, fade_len, dtype=y.dtype, device=y.device)
    fade_in = torch.sqrt(0.5 * (1 + t))
    fade_out = torch.sqrt(0.5 * (1 - t))

    # Concat the silence to the fades
    fade_in = torch.cat([silence, fade_in])
    fade_out = torch.cat([linear, fade_out])

    # Apply the gain to the overlap samples
    y[:, :overlap] *= fade_in
    y[:, -overlap:] *= fade_out

    unfolded = torch.zeros((total_len,), dtype=y.dtype, device=y.device)

    # Loop to add up all the samples
    for i in range(num_folds):
        start = i * (target + overlap)
        end = start + target + 2 * overlap
        unfolded[start:end] += y[i]

    return unfolded


def get_gru_cell(gru):
    # Copy the weights of a single-layer GRU into an equivalent GRUCell,
    # so the autoregressive loop can step one sample at a time.
    gru_cell = torch.nn.GRUCell(gru.input_size, gru.hidden_size)
    gru_cell.weight_hh.data = gru.weight_hh_l0.data
    gru_cell.weight_ih.data = gru.weight_ih_l0.data
    gru_cell.bias_hh.data = gru.bias_hh_l0.data
    gru_cell.bias_ih.data = gru.bias_ih_l0.data
    return gru_cell


def pad_tensor(x, pad, side='both'):
    # NB: quick helper for this example; it does not generalise to other shapes/dims
    b, t, c = x.size()
    total = t + 2 * pad if side == 'both' else t + pad
    padded = torch.zeros(b, total, c, device=x.device)
    if side == 'before' or side == 'both':
        padded[:, pad:pad + t, :] = x
    elif side == 'after':
        padded[:, :t, :] = x
    return padded


def infer(model, mel_specgram: Tensor, batched: bool = True,
          target: int = 11000, overlap: int = 550) -> Tensor:
    r"""Run WaveRNN inference on a mel spectrogram.

    Args:
        model (WaveRNN): The WaveRNN model.
        mel_specgram (Tensor): mel spectrogram with shape (n_mels, n_time)
        batched (bool): If True, fold the input and run inference in parallel batches
        target (int): Target timesteps for each fold (Default: ``11000``)
        overlap (int): Timesteps for crossfade and rnn warmup (Default: ``550``)

    Returns:
        waveform (Tensor): Reconstructed waveform with shape (n_time, ),
            still mu-law compressed and scaled to [-1, 1].
    """
    device = mel_specgram.device
    dtype = mel_specgram.dtype

    output: List[Tensor] = []
    rnn1 = get_gru_cell(model.rnn1)
    rnn2 = get_gru_cell(model.rnn2)

    mel_specgram = mel_specgram.unsqueeze(0)
    mel_specgram = pad_tensor(mel_specgram.transpose(1, 2), pad=model.pad, side='both')
    mel_specgram, aux = model.upsample(mel_specgram.transpose(1, 2))

    mel_specgram, aux = mel_specgram.transpose(1, 2), aux.transpose(1, 2)

    if batched:
        mel_specgram = fold_with_overlap(mel_specgram, target, overlap)
        aux = fold_with_overlap(aux, target, overlap)

    b_size, seq_len, _ = mel_specgram.size()

    h1 = torch.zeros((b_size, model.n_rnn), device=device, dtype=dtype)
    h2 = torch.zeros((b_size, model.n_rnn), device=device, dtype=dtype)
    x = torch.zeros((b_size, 1), device=device, dtype=dtype)

    d = model.n_aux
    aux_split = [aux[:, :, d * i:d * (i + 1)] for i in range(4)]

    for i in tqdm(range(seq_len)):

        m_t = mel_specgram[:, i, :]

        a1_t, a2_t, a3_t, a4_t = \
            (a[:, i, :] for a in aux_split)

        x = torch.cat([x, m_t, a1_t], dim=1)
        x = model.fc(x)
        h1 = rnn1(x, h1)

        x = x + h1
        inp = torch.cat([x, a2_t], dim=1)
        h2 = rnn2(inp, h2)

        x = x + h2
        x = torch.cat([x, a3_t], dim=1)
        x = F.relu(model.fc1(x))

        x = torch.cat([x, a4_t], dim=1)
        x = F.relu(model.fc2(x))

        logits = model.fc3(x)

        posterior = F.softmax(logits, dim=1)
        distrib = torch.distributions.Categorical(posterior)

        # Sample a class label and rescale it to the mu-law domain [-1, 1]
        sample = 2 * distrib.sample().float() / (model.n_classes - 1.) - 1.
        output.append(sample)
        x = sample.unsqueeze(-1)

    output = torch.stack(output).transpose(0, 1).cpu()

    if batched:
        output = xfade_and_unfold(output, target, overlap)
    else:
        output = output[0]

    return output
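

# decode_mu_law() below inverts the mu-law companding used during training
# (annotation restating the code that follows). For mu = 2 ** n_bits and a
# compressed value y in [-1, 1]:
#   x = sign(y) / mu * ((1 + mu) ** |y| - 1)
# e.g. mu = 256 and y = 1.0 give (257 - 1) / 256 = 1.0, so full scale maps to
# full scale. With from_labels=True, integer class labels in [0, mu - 1] are
# first rescaled into [-1, 1] and mu is replaced by mu - 1.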
def decode_mu_law(y: Tensor, mu: int, from_labels: bool = True) -> Tensor:
    if from_labels:
        y = 2 * y / (mu - 1.) - 1.
        mu = mu - 1
    x = torch.sign(y) / mu * ((1 + mu) ** torch.abs(y) - 1)
    return x


def main():
    torch.use_deterministic_algorithms(True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dset = LJSPEECH("./", download=True)
    waveform, sample_rate, _, _ = dset[0]
    torchaudio.save("original.wav", waveform, sample_rate=sample_rate)

    n_bits = 8
    mel_kwargs = {
        'sample_rate': sample_rate,
        'n_fft': 2048,
        'f_min': 40.,
        'n_mels': 80,
        'win_length': 1100,
        'hop_length': 275,
        'mel_scale': 'slaney',
        'norm': 'slaney',
        'power': 1,
    }
    transforms = torch.nn.Sequential(
        MelSpectrogram(**mel_kwargs),
        NormalizeDB(min_level_db=-100, normalization=True),
    )
    mel_specgram = transforms(waveform)

    wavernn_model = get_pretrained_wavernn("wavernn_10k_epochs_8bits_ljspeech",
                                           progress=True).eval().to(device)
    wavernn_model.pad = (wavernn_model.kernel_size - 1) // 2

    with torch.no_grad():
        output = infer(wavernn_model, mel_specgram.to(device))

    # Expand the mu-law compressed samples back to a linear-amplitude waveform
    output = decode_mu_law(output, 2 ** n_bits, from_labels=False)

    torchaudio.save("result.wav", output.reshape(1, -1), sample_rate=sample_rate)


if __name__ == "__main__":
    main()
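
Taken together, the series leaves a runnable demo. A minimal session (sketch; it assumes the LJSpeech download and the checkpoint fetch both succeed):

    cd examples/pipeline_wavernn
    python inference.py
    # writes original.wav (ground truth) and result.wav (the WaveRNN reconstruction)

The same helpers can also be pointed at other inputs: any mel spectrogram computed with the mel parameters above can be passed to infer(), e.g., reusing the transforms and model built in main() (my_clip.wav is a hypothetical input file):

    waveform, sr = torchaudio.load("my_clip.wav")
    mel = transforms(waveform)
    wave = decode_mu_law(infer(wavernn_model, mel.to(device)), 2 ** n_bits, from_labels=False)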