Add WaveRNN inference function #3

Open · wants to merge 5 commits into base: master

9 changes: 8 additions & 1 deletion docs/source/models.rst
@@ -88,8 +88,15 @@ WaveRNN

.. automethod:: forward

Factory Functions
-----------------

get_pretrained_wavernn
----------------------

.. autofunction:: get_pretrained_wavernn

References
~~~~~~~~~~

.. footbibliography::

272 changes: 272 additions & 0 deletions examples/pipeline_wavernn/inference.py
@@ -0,0 +1,272 @@
import random
from typing import List

import torch
from torch import Tensor
import torch.nn.functional as F
import torchaudio
from torchaudio.transforms import MelSpectrogram
from torchaudio.models import get_pretrained_wavernn
from torchaudio.datasets import LJSPEECH
from tqdm import tqdm

from processing import NormalizeDB


def fold_with_overlap(x, target, overlap):
r'''Fold the tensor with overlap for quick batched inference.
Overlap will be used for crossfading in xfade_and_unfold()

Args:
x (tensor): Upsampled conditioning features.
shape=(1, timesteps, features)
target (int): Target timesteps for each index of batch
overlap (int): Timesteps for both xfade and rnn warmup
Return:
(tensor) : shape=(num_folds, target + 2 * overlap, features)
Details:
x = [[h1, h2, ... hn]]
Where each h is a vector of conditioning features
Eg: target=2, overlap=1 with x.size(1)=10
folded = [[h1, h2, h3, h4],
[h4, h5, h6, h7],
[h7, h8, h9, h10]]
'''

_, total_len, features = x.size()

# Calculate variables needed
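    # Each fold yields `target` fresh timesteps; consecutive folds share `overlap`
    # timesteps, so each fold advances by (target + overlap) and the sequence needs
    # one extra trailing `overlap` of context at the end.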
num_folds = (total_len - overlap) // (target + overlap)
extended_len = num_folds * (overlap + target) + overlap
remaining = total_len - extended_len

# Pad if some time steps poking out
if remaining != 0:
num_folds += 1
padding = target + 2 * overlap - remaining
x = pad_tensor(x, padding, side='after')

folded = torch.zeros(num_folds, target + 2 * overlap, features, device=x.device)

# Get the values for the folded tensor
for i in range(num_folds):
start = i * (target + overlap)
end = start + target + 2 * overlap
folded[i] = x[:, start:end, :]

return folded

def xfade_and_unfold(y: Tensor, target: int, overlap: int) -> Tensor:
''' Applies a crossfade and unfolds into a 1d array.

Args:
y (Tensor): Batched sequences of audio samples
shape=(num_folds, target + 2 * overlap)
        target (int): Target timesteps for each fold (recomputed from the shape of ``y``)
overlap (int): Timesteps for both xfade and rnn warmup

Returns:
(Tensor) : audio samples in a 1d array
shape=(total_len)
                dtype matches the dtype of ``y``
Details:
y = [[seq1],
[seq2],
[seq3]]
Apply a gain envelope at both ends of the sequences
y = [[seq1_in, seq1_target, seq1_out],
[seq2_in, seq2_target, seq2_out],
[seq3_in, seq3_target, seq3_out]]
Stagger and add up the groups of samples:
[seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...]
'''

num_folds, length = y.shape
target = length - 2 * overlap
total_len = num_folds * (target + overlap) + overlap

# Need some silence for the rnn warmup
silence_len = overlap // 2
fade_len = overlap - silence_len
silence = torch.zeros((silence_len), dtype=y.dtype, device=y.device)
linear = torch.ones((silence_len), dtype=y.dtype, device=y.device)

# Equal power crossfade
t = torch.linspace(-1, 1, fade_len, dtype=y.dtype, device=y.device)
    fade_in = torch.sqrt(0.5 * (1 + t))
    fade_out = torch.sqrt(0.5 * (1 - t))
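    # fade_in ** 2 + fade_out ** 2 == 1 over the fade region, so the crossfaded
    # overlap keeps approximately constant power when the folds are summed.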

# Concat the silence to the fades
fade_in = torch.cat([silence, fade_in])
fade_out = torch.cat([linear, fade_out])

# Apply the gain to the overlap samples
y[:, :overlap] *= fade_in
y[:, -overlap:] *= fade_out

unfolded = torch.zeros((total_len), dtype=y.dtype, device=y.device)

# Loop to add up all the samples
for i in range(num_folds):
start = i * (target + overlap)
end = start + target + 2 * overlap
unfolded[start:end] += y[i]

return unfolded

def get_gru_cell(gru):
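    # Copy the trained single-layer GRU weights into a GRUCell so that inference
    # can run one timestep at a time in the autoregressive loop.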
gru_cell = torch.nn.GRUCell(gru.input_size, gru.hidden_size)
gru_cell.weight_hh.data = gru.weight_hh_l0.data
gru_cell.weight_ih.data = gru.weight_ih_l0.data
gru_cell.bias_hh.data = gru.bias_hh_l0.data
gru_cell.bias_ih.data = gru.bias_ih_l0.data
return gru_cell

def pad_tensor(x, pad, side='both'):
    # NB: quick padding helper for this script; it does not generalise to other shapes/dims.
b, t, c = x.size()
total = t + 2 * pad if side == 'both' else t + pad
padded = torch.zeros(b, total, c, device=x.device)
if side == 'before' or side == 'both':
padded[:, pad:pad + t, :] = x
elif side == 'after':
padded[:, :t, :] = x
return padded

def infer(model, mel_specgram: Tensor, mulaw_decode: bool = True,
batched: bool = True, target: int = 11000, overlap: int = 550) -> Tensor:
r"""Inference

Args:
model (WaveRNN): The WaveRNN model.
mel_specgram (Tensor): mel spectrogram with shape (n_mels, n_time)
        mulaw_decode (bool): Unused by this function; mu-law decoding is left to the caller. (Default: ``True``)
        batched (bool): If ``True``, fold the conditioning features with overlap and run batched inference. (Default: ``True``)
        target (int): Target timesteps for each fold when ``batched`` is ``True``. (Default: ``11000``)
        overlap (int): Timesteps used for crossfading and RNN warmup between folds. (Default: ``550``)

Returns:
        waveform (Tensor): Reconstructed waveform with shape (n_time, ).

"""
device = mel_specgram.device
dtype = mel_specgram.dtype

output: List[Tensor] = []
rnn1 = get_gru_cell(model.rnn1)
rnn2 = get_gru_cell(model.rnn2)

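    # Pad the spectrogram, then upsample it so there is one conditioning vector per
    # audio sample; the upsample network also returns the auxiliary features `aux`.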
mel_specgram = mel_specgram.unsqueeze(0)
mel_specgram = pad_tensor(mel_specgram.transpose(1, 2), pad=model.pad, side='both')
mel_specgram, aux = model.upsample(mel_specgram.transpose(1, 2))

mel_specgram, aux = mel_specgram.transpose(1, 2), aux.transpose(1, 2)

if batched:
mel_specgram = fold_with_overlap(mel_specgram, target, overlap)
aux = fold_with_overlap(aux, target, overlap)

b_size, seq_len, _ = mel_specgram.size()

h1 = torch.zeros((b_size, model.n_rnn), device=device, dtype=dtype)
h2 = torch.zeros((b_size, model.n_rnn), device=device, dtype=dtype)
x = torch.zeros((b_size, 1), device=device, dtype=dtype)

d = model.n_aux
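    # The upsample network produces 4 * n_aux auxiliary feature channels per
    # timestep; split them into the four groups fed in at different stages below.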
aux_split = [aux[:, :, d * i:d * (i + 1)] for i in range(4)]

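    # Autoregressive sampling loop: generate one audio sample per timestep, feeding
    # the previous sample back in together with the conditioning features.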
for i in tqdm(range(seq_len)):

m_t = mel_specgram[:, i, :]

a1_t, a2_t, a3_t, a4_t = \
(a[:, i, :] for a in aux_split)

x = torch.cat([x, m_t, a1_t], dim=1)
x = model.fc(x)
h1 = rnn1(x, h1)

x = x + h1
inp = torch.cat([x, a2_t], dim=1)
h2 = rnn2(inp, h2)

x = x + h2
x = torch.cat([x, a3_t], dim=1)
x = F.relu(model.fc1(x))

x = torch.cat([x, a4_t], dim=1)
x = F.relu(model.fc2(x))

logits = model.fc3(x)

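        # Sample the next value from the categorical posterior over n_classes bins
        # and rescale the class index to the [-1, 1] waveform range.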
posterior = F.softmax(logits, dim=1)
distrib = torch.distributions.Categorical(posterior)

sample = 2 * distrib.sample().float() / (model.n_classes - 1.) - 1.
#sample = distrib.sample().float()
output.append(sample)
x = sample.unsqueeze(-1)

output = torch.stack(output).transpose(0, 1).cpu()

if batched:
output = xfade_and_unfold(output, target, overlap)
else:
output = output[0]

return output


def decode_mu_law(y: Tensor, mu: int, from_labels: bool = True) -> Tensor:
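    # Mu-law expansion (the inverse of mu-law companding); with `from_labels` the
    # integer class labels are first rescaled to the [-1, 1] range.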
if from_labels:
y = 2 * y / (mu - 1.) - 1.
mu = mu - 1
x = torch.sign(y) / mu * ((1 + mu) ** torch.abs(y) - 1)
return x


def main():
torch.use_deterministic_algorithms(True)
device = "cuda" if torch.cuda.is_available() else "cpu"
dset = LJSPEECH("./", download=True)
waveform, sample_rate, _, _ = dset[0]
torchaudio.save("original.wav", waveform, sample_rate=sample_rate)

n_bits = 8
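    # Mel-spectrogram settings; these are assumed to match the ones used to train
    # the pretrained checkpoint (see examples/pipeline_wavernn/main.py).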
mel_kwargs = {
'sample_rate': sample_rate,
'n_fft': 2048,
'f_min': 40.,
'n_mels': 80,
'win_length': 1100,
'hop_length': 275,
'mel_scale': 'slaney',
'norm': 'slaney',
'power': 1,
}
transforms = torch.nn.Sequential(
MelSpectrogram(**mel_kwargs),
NormalizeDB(min_level_db=-100, normalization=True),
)
mel_specgram = transforms(waveform)

wavernn_model = get_pretrained_wavernn("wavernn_10k_epochs_8bits_ljspeech",
progress=True).eval().to(device)
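    # infer() pads the spectrogram by `model.pad` frames on each side before
    # upsampling; (kernel_size - 1) // 2 compensates for the frames trimmed by the
    # MelResNet convolution in the conditioning network.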
wavernn_model.pad = (wavernn_model.kernel_size - 1) // 2

with torch.no_grad():
output = infer(wavernn_model, mel_specgram.to(device))

output = torchaudio.functional.mu_law_decoding(output, n_bits)
#output = decode_mu_law(output, 2**n_bits, False)

torch.save(output, "output.pkl")

torchaudio.save("result.wav", output.reshape(1, -1), sample_rate=sample_rate)
import ipdb; ipdb.set_trace()


if __name__ == "__main__":
main()
3 changes: 2 additions & 1 deletion torchaudio/models/__init__.py
@@ -1,5 +1,5 @@
from .wav2letter import Wav2Letter
from .wavernn import WaveRNN
from .wavernn import WaveRNN, get_pretrained_wavernn
from .conv_tasnet import ConvTasNet
from .deepspeech import DeepSpeech
from .wav2vec2 import (
@@ -13,6 +13,7 @@
__all__ = [
'Wav2Letter',
'WaveRNN',
'get_pretrained_wavernn',
'ConvTasNet',
'DeepSpeech',
'Wav2Vec2Model',
51 changes: 50 additions & 1 deletion torchaudio/models/wavernn.py
@@ -1,18 +1,43 @@
from typing import List, Tuple
from typing import List, Tuple, Dict, Any

import torch
from torch import Tensor
from torch import nn
try:
from torch.hub import load_state_dict_from_url
except ImportError:
from torch.utils.model_zoo import load_url as load_state_dict_from_url


__all__ = [
"ResBlock",
"MelResNet",
"Stretch2d",
"UpsampleNetwork",
"WaveRNN",
"get_pretrained_wavernn",
]


_MODEL_CONFIG_AND_URLS: Dict[str, Tuple[str, Dict[str, Any]]] = {
'wavernn_10k_epochs_8bits_ljspeech': (
'https://download.pytorch.org/models/audio/wavernn_10k_epochs_8bits_ljspeech.pth',
{
'upsample_scales': [5, 5, 11],
'n_classes': 2 ** 8, # n_bits = 8
'hop_length': 275,
'n_res_block': 10,
'n_rnn': 512,
'n_fc': 512,
'kernel_size': 5,
'n_freq': 80,
'n_hidden': 128,
'n_output': 128
}
)
}


class ResBlock(nn.Module):
r"""ResNet block based on *Efficient Neural Audio Synthesis* [:footcite:`kalchbrenner2018efficient`].

@@ -324,3 +349,27 @@ def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:

# bring back channel dimension
return x.unsqueeze(1)


def get_pretrained_wavernn(checkpoint_name: str, progress: bool = True) -> WaveRNN:
r"""Get pretrained WaveRNN model.

Here are the available checkpoints:

- wavernn_10k_epochs_8bits_ljspeech

        WaveRNN model trained for 10k epochs on 8-bit depth waveforms from the LJSpeech dataset.
        The model is trained using the default parameters and code of ``examples/pipeline_wavernn/main.py``.

Args:
checkpoint_name (str): The name of the checkpoint to load.
progress (bool): If True, displays a progress bar of the download to stderr.
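
    Example
        >>> wavernn = get_pretrained_wavernn("wavernn_10k_epochs_8bits_ljspeech")
        >>> wavernn = wavernn.eval()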
"""
if checkpoint_name not in _MODEL_CONFIG_AND_URLS:
raise ValueError("The checkpoint_name `{}` is not supported.".format(checkpoint_name))

url, configs = _MODEL_CONFIG_AND_URLS[checkpoint_name]
model = WaveRNN(**configs)
state_dict = load_state_dict_from_url(url, progress=progress)
model.load_state_dict(state_dict)
return model