
Commit

Import torchaudio (#1513) 08f2bde
Summary: Import from github

Reviewed By: mthrok

Differential Revision: D28606124

fbshipit-source-id: 05dcb07efc5537d928bec682a68e6ccee7cc325e
parmeet authored and facebook-github-bot committed May 21, 2021
1 parent 7f6ac05 commit 81db19b
Showing 28 changed files with 682 additions and 352 deletions.
10 changes: 2 additions & 8 deletions .github/workflows/codeql.yml
@@ -20,19 +20,13 @@ jobs:
with:
languages: python, cpp

- name: Install Ninja
run: |
sudo apt-get update -y
sudo apt-get install -y ninja-build
- name: Update submodules
run: git submodule update --init --recursive

- name: Install Torch
run: |
python -m pip install cmake
python -m pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
sudo ln -s /usr/bin/ninja /usr/bin/ninja-build
python -m pip install cmake ninja
python -m pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
- name: Build TorchAudio
run: BUILD_SOX=1 USE_CUDA=0 python setup.py develop --user
8 changes: 8 additions & 0 deletions docs/source/models.rst
@@ -17,6 +17,14 @@ The models subpackage contains definitions of models for addressing common audio
.. automethod:: forward


:hidden:`DeepSpeech`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: DeepSpeech

.. automethod:: forward


:hidden:`Wav2Letter`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

5 changes: 3 additions & 2 deletions test/torchaudio_unittest/common_utils/case_utils.py
@@ -59,8 +59,9 @@ def setUpClass(cls):
super().setUpClass()
cls._proc = subprocess.Popen(
['python', '-m', 'http.server', f'{cls._port}'],
cwd=cls.get_base_temp_dir())
time.sleep(1.0)
cwd=cls.get_base_temp_dir(),
stderr=subprocess.DEVNULL) # Disable server-side error log because it is confusing
time.sleep(2.0)

@classmethod
def tearDownClass(cls):
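The fixture change above silences the test HTTP server's request log and gives it more time to start. A minimal sketch of the same pattern outside the suite, assuming an arbitrary free port (8000) and the current directory in place of `cls.get_base_temp_dir()`:

```python
import subprocess
import time
import urllib.request

# Serve a scratch directory over HTTP for the duration of a test run.
# The port and directory below are placeholders, not values from the torchaudio suite.
proc = subprocess.Popen(
    ["python", "-m", "http.server", "8000"],
    cwd=".",
    stderr=subprocess.DEVNULL,  # drop the per-request error log, as in the patch above
)
time.sleep(2.0)  # give the server time to bind before the first request
try:
    with urllib.request.urlopen("http://localhost:8000/") as response:
        print(response.status)  # 200 once the directory listing is reachable
finally:
    proc.kill()
    proc.wait()
```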
43 changes: 28 additions & 15 deletions test/torchaudio_unittest/compliance_kaldi_test.py
@@ -7,6 +7,7 @@

from torchaudio_unittest import common_utils
from .compliance import utils as compliance_utils
from parameterized import parameterized


def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
@@ -182,20 +183,26 @@ def get_output_fn(sound, args):

self._compliance_test_helper(self.test2_filepath, 'resample', 32, 3, get_output_fn, atol=1e-2, rtol=1e-5)

def test_resample_waveform_upsample_size(self):
upsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr * 2)
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_upsample_size(self, resampling_method):
upsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr * 2,
resampling_method=resampling_method)
self.assertTrue(upsample_sound.size(-1) == self.test1_signal.size(-1) * 2)

def test_resample_waveform_downsample_size(self):
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr // 2)
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_downsample_size(self, resampling_method):
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr // 2,
resampling_method=resampling_method)
self.assertTrue(downsample_sound.size(-1) == self.test1_signal.size(-1) // 2)

def test_resample_waveform_identity_size(self):
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr)
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_identity_size(self, resampling_method):
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr,
resampling_method=resampling_method)
self.assertTrue(downsample_sound.size(-1) == self.test1_signal.size(-1))

def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_factor=None,
atol=1e-1, rtol=1e-4):
resampling_method="sinc_interpolation", atol=1e-1, rtol=1e-4):
# resample the signal and compare it to the ground truth
n_to_trim = 20
sample_rate = 1000
@@ -211,7 +218,8 @@ def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_fact
original_timestamps = torch.arange(0, duration, 1.0 / sample_rate)

sound = 123 * torch.cos(2 * math.pi * 3 * original_timestamps).unsqueeze(0)
estimate = kaldi.resample_waveform(sound, sample_rate, new_sample_rate).squeeze()
estimate = kaldi.resample_waveform(sound, sample_rate, new_sample_rate,
resampling_method=resampling_method).squeeze()

new_timestamps = torch.arange(0, duration, 1.0 / new_sample_rate)[:estimate.size(0)]
ground_truth = 123 * torch.cos(2 * math.pi * 3 * new_timestamps)
@@ -222,27 +230,32 @@ def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_fact

self.assertEqual(estimate, ground_truth, atol=atol, rtol=rtol)

def test_resample_waveform_downsample_accuracy(self):
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_downsample_accuracy(self, resampling_method):
for i in range(1, 20):
self._test_resample_waveform_accuracy(down_scale_factor=i * 2)
self._test_resample_waveform_accuracy(down_scale_factor=i * 2, resampling_method=resampling_method)

def test_resample_waveform_upsample_accuracy(self):
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_upsample_accuracy(self, resampling_method):
for i in range(1, 20):
self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0)
self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0, resampling_method=resampling_method)

def test_resample_waveform_multi_channel(self):
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_multi_channel(self, resampling_method):
num_channels = 3

multi_sound = self.test1_signal.repeat(num_channels, 1) # (num_channels, 8000 smp)

for i in range(num_channels):
multi_sound[i, :] *= (i + 1) * 1.5

multi_sound_sampled = kaldi.resample_waveform(multi_sound, self.test1_signal_sr, self.test1_signal_sr // 2)
multi_sound_sampled = kaldi.resample_waveform(multi_sound, self.test1_signal_sr, self.test1_signal_sr // 2,
resampling_method=resampling_method)

# check that sampling is same whether using separately or in a tensor of size (c, n)
for i in range(num_channels):
single_channel = self.test1_signal * (i + 1) * 1.5
single_channel_sampled = kaldi.resample_waveform(single_channel, self.test1_signal_sr,
self.test1_signal_sr // 2)
self.test1_signal_sr // 2,
resampling_method=resampling_method)
self.assertEqual(multi_sound_sampled[i, :], single_channel_sampled[0], rtol=1e-4, atol=1e-7)
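The resampling tests above are now parameterized over both `resampling_method` values. A short sketch of the call pattern being exercised, with a synthetic sine wave standing in for the suite's test signal:

```python
import math
import torch
import torchaudio.compliance.kaldi as kaldi

sample_rate = 16000
t = torch.arange(0, 1.0, 1.0 / sample_rate)
waveform = torch.sin(2 * math.pi * 440.0 * t).unsqueeze(0)  # shape (1, time)

# Halve the sample rate with each of the two methods covered by the parameterized tests.
for method in ("sinc_interpolation", "kaiser_window"):
    resampled = kaldi.resample_waveform(
        waveform, sample_rate, sample_rate // 2, resampling_method=method)
    assert resampled.size(-1) == waveform.size(-1) // 2  # same size relation the tests assert
```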
21 changes: 20 additions & 1 deletion test/torchaudio_unittest/functional/functional_impl.py
@@ -9,7 +9,7 @@
from parameterized import parameterized
from scipy import signal

from torchaudio_unittest.common_utils import TestBaseMixin, get_sinusoid, nested_params
from torchaudio_unittest.common_utils import TestBaseMixin, get_sinusoid, nested_params, get_whitenoise


class Functional(TestBaseMixin):
@@ -259,6 +259,25 @@ def test_mask_along_axis_iid_preserve(self, mask_param, mask_value, axis):

self.assertEqual(specgrams, specgrams_copy)

def test_resample_no_warning(self):
sample_rate = 44100
waveform = get_whitenoise(sample_rate=sample_rate, duration=0.1)

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
F.resample(waveform, float(sample_rate), sample_rate / 2.)
assert len(w) == 0

def test_resample_warning(self):
"""resample should throw a warning if an input frequency is not of an integer value"""
sample_rate = 44100
waveform = get_whitenoise(sample_rate=sample_rate, duration=0.1)

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
F.resample(waveform, sample_rate, 5512.5)
assert len(w) == 1

@nested_params(
[0.5, 1.01, 1.3],
[True, False],
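The two new tests pin down when `F.resample` warns: not at all when the frequencies have integer values, and once when a frequency such as 5512.5 does not. A hedged sketch of the same check, with random noise standing in for `get_whitenoise`:

```python
import warnings
import torch
import torchaudio.functional as F

sample_rate = 44100
waveform = torch.randn(1, 4410)  # ~0.1 s of noise, stand-in for get_whitenoise

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    F.resample(waveform, float(sample_rate), sample_rate / 2.0)  # integer-valued rates: no warning
    F.resample(waveform, sample_rate, 5512.5)                    # non-integer target rate: warns
print(len(caught))  # the tests expect exactly one warning, from the second call
```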
@@ -591,6 +591,28 @@ def func(tensor):
tensor = common_utils.get_whitenoise(sample_rate=44100)
self._assert_consistency(func, tensor)

def test_resample_sinc(self):
def func(tensor):
sr1, sr2 = 16000., 8000.
return F.resample(tensor, sr1, sr2, resampling_method="sinc_interpolation")

tensor = common_utils.get_whitenoise(sample_rate=16000)
self._assert_consistency(func, tensor)

def test_resample_kaiser(self):
def func(tensor):
sr1, sr2 = 16000., 8000.
return F.resample(tensor, sr1, sr2, resampling_method="kaiser_window")

def func_beta(tensor):
sr1, sr2 = 16000., 8000.
beta = 6.
return F.resample(tensor, sr1, sr2, resampling_method="kaiser_window", beta=beta)

tensor = common_utils.get_whitenoise(sample_rate=16000)
self._assert_consistency(func, tensor)
self._assert_consistency(func_beta, tensor)

@parameterized.expand([(True, ), (False, )])
def test_phase_vocoder(self, test_paseudo_complex):
def func(tensor):
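These additions check that `F.resample` stays TorchScript-able with the new `kaiser_window` method and `beta` argument. A minimal scripting round trip under the same arguments (the random input replaces `common_utils.get_whitenoise`):

```python
import torch
import torchaudio.functional as F

def resample_kaiser(tensor: torch.Tensor) -> torch.Tensor:
    # Same arguments as the consistency test: 16 kHz -> 8 kHz, Kaiser window, beta=6.
    return F.resample(tensor, 16000., 8000., resampling_method="kaiser_window", beta=6.)

scripted = torch.jit.script(resample_kaiser)
waveform = torch.randn(1, 16000)  # stand-in for common_utils.get_whitenoise(sample_rate=16000)
assert torch.allclose(scripted(waveform), resample_kaiser(waveform))
```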
19 changes: 18 additions & 1 deletion test/torchaudio_unittest/models_test.py
@@ -3,7 +3,7 @@

import torch
from parameterized import parameterized
from torchaudio.models import ConvTasNet, Wav2Letter, WaveRNN
from torchaudio.models import ConvTasNet, DeepSpeech, Wav2Letter, WaveRNN
from torchaudio.models.wavernn import MelResNet, UpsampleNetwork
from torchaudio_unittest import common_utils

@@ -174,3 +174,20 @@ def test_paper_configuration(self, num_sources, model_params):
output = model(tensor)

assert output.shape == (batch_size, num_sources, num_frames)


class TestDeepSpeech(common_utils.TorchaudioTestCase):

def test_deepspeech(self):
n_batch = 2
n_feature = 1
n_channel = 1
n_class = 40
n_time = 320

model = DeepSpeech(n_feature=n_feature, n_class=n_class)

x = torch.rand(n_batch, n_channel, n_time, n_feature)
out = model(x)

assert out.size() == (n_batch, n_time, n_class)
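The new test fixes the DeepSpeech I/O contract: input of shape `(n_batch, n_channel, n_time, n_feature)`, output of shape `(n_batch, n_time, n_class)`. A short usage sketch with the same shapes:

```python
import torch
from torchaudio.models import DeepSpeech

# Shapes mirror TestDeepSpeech: batch of 2, one channel, 320 frames, 1 feature, 40 classes.
model = DeepSpeech(n_feature=1, n_class=40)
x = torch.rand(2, 1, 320, 1)  # (n_batch, n_channel, n_time, n_feature)
out = model(x)
print(out.shape)  # expected: torch.Size([2, 320, 40])
```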
10 changes: 10 additions & 0 deletions test/torchaudio_unittest/rnnt/torchscript_consistency_cpu_test.py
@@ -0,0 +1,10 @@
import torch

from torchaudio_unittest.common_utils import PytorchTestCase
from .utils import skipIfNoTransducer
from .torchscript_consistency_impl import RNNTLossTorchscript


@skipIfNoTransducer
class TestRNNTLoss(RNNTLossTorchscript, PytorchTestCase):
device = torch.device('cpu')
11 changes: 11 additions & 0 deletions test/torchaudio_unittest/rnnt/torchscript_consistency_cuda_test.py
@@ -0,0 +1,11 @@
import torch

from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .utils import skipIfNoTransducer
from .torchscript_consistency_impl import RNNTLossTorchscript


@skipIfNoTransducer
@skipIfNoCuda
class TestRNNTLoss(RNNTLossTorchscript, PytorchTestCase):
device = torch.device('cuda')
70 changes: 70 additions & 0 deletions test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py
@@ -0,0 +1,70 @@
import torch
from torchaudio_unittest.common_utils import TempDirMixin, TestBaseMixin
from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss


class RNNTLossTorchscript(TempDirMixin, TestBaseMixin):
"""Implements test for RNNT Loss that are performed for different devices"""
def _assert_consistency(self, func, tensor, shape_only=False):
tensor = tensor.to(device=self.device, dtype=self.dtype)

path = self.get_temp_path('func.zip')
torch.jit.script(func).save(path)
ts_func = torch.jit.load(path)

torch.random.manual_seed(40)
input_tensor = tensor.clone().detach().requires_grad_(True)
output = func(input_tensor)

torch.random.manual_seed(40)
input_tensor = tensor.clone().detach().requires_grad_(True)
ts_output = ts_func(input_tensor)

self.assertEqual(ts_output, output)

def test_rnnt_loss(self):
def func(
logits,
):
targets = torch.tensor([[1, 2]], device=logits.device, dtype=torch.int32)
logit_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32)
target_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32)
return rnnt_loss(logits, targets, logit_lengths, target_lengths)

logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.6, 0.1, 0.1],
[0.1, 0.1, 0.2, 0.8, 0.1]],
[[0.1, 0.6, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.2, 0.1, 0.1],
[0.7, 0.1, 0.2, 0.1, 0.1]]]])

self._assert_consistency(func, logits)

def test_RNNTLoss(self):
func = RNNTLoss()

logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.6, 0.1, 0.1],
[0.1, 0.1, 0.2, 0.8, 0.1]],
[[0.1, 0.6, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.2, 0.1, 0.1],
[0.7, 0.1, 0.2, 0.1, 0.1]]]])
targets = torch.tensor([[1, 2]], device=self.device, dtype=torch.int32)
logit_lengths = torch.tensor([2], device=self.device, dtype=torch.int32)
target_lengths = torch.tensor([2], device=self.device, dtype=torch.int32)

tensor = logits.to(device=self.device, dtype=self.dtype)

path = self.get_temp_path('func.zip')
torch.jit.script(func).save(path)
ts_func = torch.jit.load(path)

torch.random.manual_seed(40)
input_tensor = tensor.clone().detach().requires_grad_(True)
output = func(input_tensor, targets, logit_lengths, target_lengths)

torch.random.manual_seed(40)
input_tensor = tensor.clone().detach().requires_grad_(True)
ts_output = ts_func(input_tensor, targets, logit_lengths, target_lengths)

self.assertEqual(ts_output, output)
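Both tests script the prototype RNN-T loss (functional `rnnt_loss` and module `RNNTLoss`) and compare against eager execution. A condensed eager-mode sketch with the same toy inputs; it assumes torchaudio was built with the transducer extension, and `torchaudio.prototype.rnnt_loss` is the prototype import path used by these tests, not a stable API:

```python
import torch
from torchaudio.prototype.rnnt_loss import rnnt_loss

# Toy inputs from the test: one sequence, 2 time steps, 2 target symbols, a 5-symbol vocabulary.
logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
                         [0.1, 0.1, 0.6, 0.1, 0.1],
                         [0.1, 0.1, 0.2, 0.8, 0.1]],
                        [[0.1, 0.6, 0.1, 0.1, 0.1],
                         [0.1, 0.1, 0.2, 0.1, 0.1],
                         [0.7, 0.1, 0.2, 0.1, 0.1]]]], requires_grad=True)
targets = torch.tensor([[1, 2]], dtype=torch.int32)
logit_lengths = torch.tensor([2], dtype=torch.int32)
target_lengths = torch.tensor([2], dtype=torch.int32)

loss = rnnt_loss(logits, targets, logit_lengths, target_lengths)
loss.backward()  # gradients flow back into logits
```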
8 changes: 4 additions & 4 deletions test/torchaudio_unittest/rnnt/utils.py
@@ -405,10 +405,10 @@ def get_numpy_random_data(


def numpy_to_torch(data, device, requires_grad=True):
logits = torch.from_numpy(data["logits"])
targets = torch.from_numpy(data["targets"])
logit_lengths = torch.from_numpy(data["logit_lengths"])
target_lengths = torch.from_numpy(data["target_lengths"])
logits = torch.from_numpy(data["logits"]).to(device=device)
targets = torch.from_numpy(data["targets"]).to(device=device)
logit_lengths = torch.from_numpy(data["logit_lengths"]).to(device=device)
target_lengths = torch.from_numpy(data["target_lengths"]).to(device=device)

if "nbest_wers" in data:
data["nbest_wers"] = torch.from_numpy(data["nbest_wers"]).to(device=device)
25 changes: 25 additions & 0 deletions test/torchaudio_unittest/transforms/autograd_test_impl.py
@@ -125,6 +125,31 @@ def test_fade(self, fade_shape):
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
self.assert_grad(transform, [waveform], nondet_tol=1e-10)

@parameterized.expand([(T.TimeMasking,), (T.FrequencyMasking,)])
def test_masking(self, masking_transform):
sample_rate = 8000
n_fft = 400
spectrogram = get_spectrogram(
get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2),
n_fft=n_fft, power=1)
deterministic_transform = _DeterministicWrapper(masking_transform(400))
self.assert_grad(deterministic_transform, [spectrogram])

@parameterized.expand([(T.TimeMasking,), (T.FrequencyMasking,)])
def test_masking_iid(self, masking_transform):
sample_rate = 8000
n_fft = 400
specs = [get_spectrogram(
get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2, seed=i),
n_fft=n_fft, power=1)
for i in range(3)
]

batch = torch.stack(specs)
assert batch.ndim == 4
deterministic_transform = _DeterministicWrapper(masking_transform(400, True))
self.assert_grad(deterministic_transform, [batch])

def test_spectral_centroid(self):
sample_rate = 8000
transform = T.SpectralCentroid(sample_rate=sample_rate)
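The added autograd tests cover `T.TimeMasking` and `T.FrequencyMasking`, both in their default form and with `iid_masks=True` on a 4D batch of spectrograms. A brief sketch of the two transforms on random spectrograms (standing in for the suite's `get_spectrogram(get_whitenoise(...))` inputs):

```python
import torch
import torchaudio.transforms as T

# One spectrogram of shape (channel, freq, time) and a batch of shape (batch, channel, freq, time).
spec = torch.rand(2, 201, 50)
batch = torch.stack([torch.rand(2, 201, 50) for _ in range(3)])

time_mask = T.TimeMasking(time_mask_param=30)
freq_mask_iid = T.FrequencyMasking(freq_mask_param=30, iid_masks=True)  # one mask per batch element

masked = time_mask(spec)              # zeroes a random span along the time axis
masked_batch = freq_mask_iid(batch)   # each example gets its own random frequency span
print(masked.shape, masked_batch.shape)  # shapes are unchanged by masking
```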
