From 81db19be529233ab7a91daa4c664089fbb654973 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Fri, 21 May 2021 16:54:14 -0700 Subject: [PATCH] Import torchaudio #1513 08f2bde Summary: Import from github Reviewed By: mthrok Differential Revision: D28606124 fbshipit-source-id: 05dcb07efc5537d928bec682a68e6ccee7cc325e --- .github/workflows/codeql.yml | 10 +- docs/source/models.rst | 8 + .../common_utils/case_utils.py | 5 +- .../compliance_kaldi_test.py | 43 +++-- .../functional/functional_impl.py | 21 ++- .../torchscript_consistency_impl.py | 22 +++ test/torchaudio_unittest/models_test.py | 19 +- .../rnnt/torchscript_consistency_cpu_test.py | 10 + .../rnnt/torchscript_consistency_cuda_test.py | 11 ++ .../rnnt/torchscript_consistency_impl.py | 70 +++++++ test/torchaudio_unittest/rnnt/utils.py | 8 +- .../transforms/autograd_test_impl.py | 25 +++ .../transforms/sox_compatibility_test.py | 25 +++ .../torchscript_consistency_impl.py | 4 + .../transforms/transforms_test.py | 6 +- .../transforms/transforms_test_impl.py | 20 +- torchaudio/compliance/kaldi.py | 11 +- torchaudio/csrc/CMakeLists.txt | 1 + torchaudio/csrc/rnnt/autograd.cpp | 74 ++++++++ torchaudio/csrc/rnnt/compute.cpp | 24 +++ torchaudio/csrc/rnnt/compute.h | 13 ++ torchaudio/csrc/rnnt/transducer.h | 121 ------------ torchaudio/functional/filtering.py | 15 +- torchaudio/functional/functional.py | 129 +++++++++---- torchaudio/models/__init__.py | 2 + torchaudio/models/deepspeech.py | 92 +++++++++ torchaudio/prototype/rnnt_loss.py | 176 +++--------------- torchaudio/transforms.py | 69 ++++++- 28 files changed, 682 insertions(+), 352 deletions(-) create mode 100644 test/torchaudio_unittest/rnnt/torchscript_consistency_cpu_test.py create mode 100644 test/torchaudio_unittest/rnnt/torchscript_consistency_cuda_test.py create mode 100644 test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py create mode 100644 torchaudio/csrc/rnnt/autograd.cpp create mode 100644 torchaudio/csrc/rnnt/compute.h delete mode 100644 torchaudio/csrc/rnnt/transducer.h create mode 100644 torchaudio/models/deepspeech.py diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index f0ea58f309..7d38c2c0e2 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -20,19 +20,13 @@ jobs: with: languages: python, cpp - - name: Install Ninja - run: | - sudo apt-get update -y - sudo apt-get install -y ninja-build - - name: Update submodules run: git submodule update --init --recursive - name: Install Torch run: | - python -m pip install cmake - python -m pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html - sudo ln -s /usr/bin/ninja /usr/bin/ninja-build + python -m pip install cmake ninja + python -m pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html - name: Build TorchAudio run: BUILD_SOX=1 USE_CUDA=0 python setup.py develop --user diff --git a/docs/source/models.rst b/docs/source/models.rst index ea86d8b73b..2030eefd28 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -17,6 +17,14 @@ The models subpackage contains definitions of models for addressing common audio .. automethod:: forward +:hidden:`DeepSpeech` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: DeepSpeech + + .. 
automethod:: forward + + :hidden:`Wav2Letter` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/test/torchaudio_unittest/common_utils/case_utils.py b/test/torchaudio_unittest/common_utils/case_utils.py index 773f1915b9..fa3d5077bb 100644 --- a/test/torchaudio_unittest/common_utils/case_utils.py +++ b/test/torchaudio_unittest/common_utils/case_utils.py @@ -59,8 +59,9 @@ def setUpClass(cls): super().setUpClass() cls._proc = subprocess.Popen( ['python', '-m', 'http.server', f'{cls._port}'], - cwd=cls.get_base_temp_dir()) - time.sleep(1.0) + cwd=cls.get_base_temp_dir(), + stderr=subprocess.DEVNULL) # Disable server-side error log because it is confusing + time.sleep(2.0) @classmethod def tearDownClass(cls): diff --git a/test/torchaudio_unittest/compliance_kaldi_test.py b/test/torchaudio_unittest/compliance_kaldi_test.py index 17d73fbf2f..af093cea55 100644 --- a/test/torchaudio_unittest/compliance_kaldi_test.py +++ b/test/torchaudio_unittest/compliance_kaldi_test.py @@ -7,6 +7,7 @@ from torchaudio_unittest import common_utils from .compliance import utils as compliance_utils +from parameterized import parameterized def extract_window(window, wave, f, frame_length, frame_shift, snip_edges): @@ -182,20 +183,26 @@ def get_output_fn(sound, args): self._compliance_test_helper(self.test2_filepath, 'resample', 32, 3, get_output_fn, atol=1e-2, rtol=1e-5) - def test_resample_waveform_upsample_size(self): - upsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr * 2) + @parameterized.expand([("sinc_interpolation"), ("kaiser_window")]) + def test_resample_waveform_upsample_size(self, resampling_method): + upsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr * 2, + resampling_method=resampling_method) self.assertTrue(upsample_sound.size(-1) == self.test1_signal.size(-1) * 2) - def test_resample_waveform_downsample_size(self): - downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr // 2) + @parameterized.expand([("sinc_interpolation"), ("kaiser_window")]) + def test_resample_waveform_downsample_size(self, resampling_method): + downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr // 2, + resampling_method=resampling_method) self.assertTrue(downsample_sound.size(-1) == self.test1_signal.size(-1) // 2) - def test_resample_waveform_identity_size(self): - downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr) + @parameterized.expand([("sinc_interpolation"), ("kaiser_window")]) + def test_resample_waveform_identity_size(self, resampling_method): + downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr, + resampling_method=resampling_method) self.assertTrue(downsample_sound.size(-1) == self.test1_signal.size(-1)) def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_factor=None, - atol=1e-1, rtol=1e-4): + resampling_method="sinc_interpolation", atol=1e-1, rtol=1e-4): # resample the signal and compare it to the ground truth n_to_trim = 20 sample_rate = 1000 @@ -211,7 +218,8 @@ def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_fact original_timestamps = torch.arange(0, duration, 1.0 / sample_rate) sound = 123 * torch.cos(2 * math.pi * 3 * original_timestamps).unsqueeze(0) - estimate = kaldi.resample_waveform(sound, sample_rate, new_sample_rate).squeeze() + estimate = 
kaldi.resample_waveform(sound, sample_rate, new_sample_rate, + resampling_method=resampling_method).squeeze() new_timestamps = torch.arange(0, duration, 1.0 / new_sample_rate)[:estimate.size(0)] ground_truth = 123 * torch.cos(2 * math.pi * 3 * new_timestamps) @@ -222,15 +230,18 @@ def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_fact self.assertEqual(estimate, ground_truth, atol=atol, rtol=rtol) - def test_resample_waveform_downsample_accuracy(self): + @parameterized.expand([("sinc_interpolation"), ("kaiser_window")]) + def test_resample_waveform_downsample_accuracy(self, resampling_method): for i in range(1, 20): - self._test_resample_waveform_accuracy(down_scale_factor=i * 2) + self._test_resample_waveform_accuracy(down_scale_factor=i * 2, resampling_method=resampling_method) - def test_resample_waveform_upsample_accuracy(self): + @parameterized.expand([("sinc_interpolation"), ("kaiser_window")]) + def test_resample_waveform_upsample_accuracy(self, resampling_method): for i in range(1, 20): - self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0) + self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0, resampling_method=resampling_method) - def test_resample_waveform_multi_channel(self): + @parameterized.expand([("sinc_interpolation"), ("kaiser_window")]) + def test_resample_waveform_multi_channel(self, resampling_method): num_channels = 3 multi_sound = self.test1_signal.repeat(num_channels, 1) # (num_channels, 8000 smp) @@ -238,11 +249,13 @@ def test_resample_waveform_multi_channel(self): for i in range(num_channels): multi_sound[i, :] *= (i + 1) * 1.5 - multi_sound_sampled = kaldi.resample_waveform(multi_sound, self.test1_signal_sr, self.test1_signal_sr // 2) + multi_sound_sampled = kaldi.resample_waveform(multi_sound, self.test1_signal_sr, self.test1_signal_sr // 2, + resampling_method=resampling_method) # check that sampling is same whether using separately or in a tensor of size (c, n) for i in range(num_channels): single_channel = self.test1_signal * (i + 1) * 1.5 single_channel_sampled = kaldi.resample_waveform(single_channel, self.test1_signal_sr, - self.test1_signal_sr // 2) + self.test1_signal_sr // 2, + resampling_method=resampling_method) self.assertEqual(multi_sound_sampled[i, :], single_channel_sampled[0], rtol=1e-4, atol=1e-7) diff --git a/test/torchaudio_unittest/functional/functional_impl.py b/test/torchaudio_unittest/functional/functional_impl.py index 424ead2ca9..1a2026233d 100644 --- a/test/torchaudio_unittest/functional/functional_impl.py +++ b/test/torchaudio_unittest/functional/functional_impl.py @@ -9,7 +9,7 @@ from parameterized import parameterized from scipy import signal -from torchaudio_unittest.common_utils import TestBaseMixin, get_sinusoid, nested_params +from torchaudio_unittest.common_utils import TestBaseMixin, get_sinusoid, nested_params, get_whitenoise class Functional(TestBaseMixin): @@ -259,6 +259,25 @@ def test_mask_along_axis_iid_preserve(self, mask_param, mask_value, axis): self.assertEqual(specgrams, specgrams_copy) + def test_resample_no_warning(self): + sample_rate = 44100 + waveform = get_whitenoise(sample_rate=sample_rate, duration=0.1) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + F.resample(waveform, float(sample_rate), sample_rate / 2.) 
+ assert len(w) == 0 + + def test_resample_warning(self): + """resample should throw a warning if an input frequency is not of an integer value""" + sample_rate = 44100 + waveform = get_whitenoise(sample_rate=sample_rate, duration=0.1) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + F.resample(waveform, sample_rate, 5512.5) + assert len(w) == 1 + @nested_params( [0.5, 1.01, 1.3], [True, False], diff --git a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py index a39973f867..3eb869c876 100644 --- a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py @@ -591,6 +591,28 @@ def func(tensor): tensor = common_utils.get_whitenoise(sample_rate=44100) self._assert_consistency(func, tensor) + def test_resample_sinc(self): + def func(tensor): + sr1, sr2 = 16000., 8000. + return F.resample(tensor, sr1, sr2, resampling_method="sinc_interpolation") + + tensor = common_utils.get_whitenoise(sample_rate=16000) + self._assert_consistency(func, tensor) + + def test_resample_kaiser(self): + def func(tensor): + sr1, sr2 = 16000., 8000. + return F.resample(tensor, sr1, sr2, resampling_method="kaiser_window") + + def func_beta(tensor): + sr1, sr2 = 16000., 8000. + beta = 6. + return F.resample(tensor, sr1, sr2, resampling_method="kaiser_window", beta=beta) + + tensor = common_utils.get_whitenoise(sample_rate=16000) + self._assert_consistency(func, tensor) + self._assert_consistency(func_beta, tensor) + @parameterized.expand([(True, ), (False, )]) def test_phase_vocoder(self, test_paseudo_complex): def func(tensor): diff --git a/test/torchaudio_unittest/models_test.py b/test/torchaudio_unittest/models_test.py index 484ec6c10c..4db4895b8f 100644 --- a/test/torchaudio_unittest/models_test.py +++ b/test/torchaudio_unittest/models_test.py @@ -3,7 +3,7 @@ import torch from parameterized import parameterized -from torchaudio.models import ConvTasNet, Wav2Letter, WaveRNN +from torchaudio.models import ConvTasNet, DeepSpeech, Wav2Letter, WaveRNN from torchaudio.models.wavernn import MelResNet, UpsampleNetwork from torchaudio_unittest import common_utils @@ -174,3 +174,20 @@ def test_paper_configuration(self, num_sources, model_params): output = model(tensor) assert output.shape == (batch_size, num_sources, num_frames) + + +class TestDeepSpeech(common_utils.TorchaudioTestCase): + + def test_deepspeech(self): + n_batch = 2 + n_feature = 1 + n_channel = 1 + n_class = 40 + n_time = 320 + + model = DeepSpeech(n_feature=n_feature, n_class=n_class) + + x = torch.rand(n_batch, n_channel, n_time, n_feature) + out = model(x) + + assert out.size() == (n_batch, n_time, n_class) diff --git a/test/torchaudio_unittest/rnnt/torchscript_consistency_cpu_test.py b/test/torchaudio_unittest/rnnt/torchscript_consistency_cpu_test.py new file mode 100644 index 0000000000..06b6baf5a1 --- /dev/null +++ b/test/torchaudio_unittest/rnnt/torchscript_consistency_cpu_test.py @@ -0,0 +1,10 @@ +import torch + +from torchaudio_unittest.common_utils import PytorchTestCase +from .utils import skipIfNoTransducer +from .torchscript_consistency_impl import RNNTLossTorchscript + + +@skipIfNoTransducer +class TestRNNTLoss(RNNTLossTorchscript, PytorchTestCase): + device = torch.device('cpu') diff --git a/test/torchaudio_unittest/rnnt/torchscript_consistency_cuda_test.py b/test/torchaudio_unittest/rnnt/torchscript_consistency_cuda_test.py new file 
mode 100644 index 0000000000..22b1713582 --- /dev/null +++ b/test/torchaudio_unittest/rnnt/torchscript_consistency_cuda_test.py @@ -0,0 +1,11 @@ +import torch + +from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda +from .utils import skipIfNoTransducer +from .torchscript_consistency_impl import RNNTLossTorchscript + + +@skipIfNoTransducer +@skipIfNoCuda +class TestRNNTLoss(RNNTLossTorchscript, PytorchTestCase): + device = torch.device('cuda') diff --git a/test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py b/test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py new file mode 100644 index 0000000000..aeba7e3ae4 --- /dev/null +++ b/test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py @@ -0,0 +1,70 @@ +import torch +from torchaudio_unittest.common_utils import TempDirMixin, TestBaseMixin +from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss + + +class RNNTLossTorchscript(TempDirMixin, TestBaseMixin): + """Implements test for RNNT Loss that are performed for different devices""" + def _assert_consistency(self, func, tensor, shape_only=False): + tensor = tensor.to(device=self.device, dtype=self.dtype) + + path = self.get_temp_path('func.zip') + torch.jit.script(func).save(path) + ts_func = torch.jit.load(path) + + torch.random.manual_seed(40) + input_tensor = tensor.clone().detach().requires_grad_(True) + output = func(input_tensor) + + torch.random.manual_seed(40) + input_tensor = tensor.clone().detach().requires_grad_(True) + ts_output = ts_func(input_tensor) + + self.assertEqual(ts_output, output) + + def test_rnnt_loss(self): + def func( + logits, + ): + targets = torch.tensor([[1, 2]], device=logits.device, dtype=torch.int32) + logit_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32) + target_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32) + return rnnt_loss(logits, targets, logit_lengths, target_lengths) + + logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1], + [0.1, 0.1, 0.6, 0.1, 0.1], + [0.1, 0.1, 0.2, 0.8, 0.1]], + [[0.1, 0.6, 0.1, 0.1, 0.1], + [0.1, 0.1, 0.2, 0.1, 0.1], + [0.7, 0.1, 0.2, 0.1, 0.1]]]]) + + self._assert_consistency(func, logits) + + def test_RNNTLoss(self): + func = RNNTLoss() + + logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1], + [0.1, 0.1, 0.6, 0.1, 0.1], + [0.1, 0.1, 0.2, 0.8, 0.1]], + [[0.1, 0.6, 0.1, 0.1, 0.1], + [0.1, 0.1, 0.2, 0.1, 0.1], + [0.7, 0.1, 0.2, 0.1, 0.1]]]]) + targets = torch.tensor([[1, 2]], device=self.device, dtype=torch.int32) + logit_lengths = torch.tensor([2], device=self.device, dtype=torch.int32) + target_lengths = torch.tensor([2], device=self.device, dtype=torch.int32) + + tensor = logits.to(device=self.device, dtype=self.dtype) + + path = self.get_temp_path('func.zip') + torch.jit.script(func).save(path) + ts_func = torch.jit.load(path) + + torch.random.manual_seed(40) + input_tensor = tensor.clone().detach().requires_grad_(True) + output = func(input_tensor, targets, logit_lengths, target_lengths) + + torch.random.manual_seed(40) + input_tensor = tensor.clone().detach().requires_grad_(True) + ts_output = ts_func(input_tensor, targets, logit_lengths, target_lengths) + + self.assertEqual(ts_output, output) diff --git a/test/torchaudio_unittest/rnnt/utils.py b/test/torchaudio_unittest/rnnt/utils.py index 8e93d28032..5f4c40379e 100644 --- a/test/torchaudio_unittest/rnnt/utils.py +++ b/test/torchaudio_unittest/rnnt/utils.py @@ -405,10 +405,10 @@ def get_numpy_random_data( def numpy_to_torch(data, device, requires_grad=True): - logits = 
torch.from_numpy(data["logits"]) - targets = torch.from_numpy(data["targets"]) - logit_lengths = torch.from_numpy(data["logit_lengths"]) - target_lengths = torch.from_numpy(data["target_lengths"]) + logits = torch.from_numpy(data["logits"]).to(device=device) + targets = torch.from_numpy(data["targets"]).to(device=device) + logit_lengths = torch.from_numpy(data["logit_lengths"]).to(device=device) + target_lengths = torch.from_numpy(data["target_lengths"]).to(device=device) if "nbest_wers" in data: data["nbest_wers"] = torch.from_numpy(data["nbest_wers"]).to(device=device) diff --git a/test/torchaudio_unittest/transforms/autograd_test_impl.py b/test/torchaudio_unittest/transforms/autograd_test_impl.py index 717a7bc87b..0f7b612482 100644 --- a/test/torchaudio_unittest/transforms/autograd_test_impl.py +++ b/test/torchaudio_unittest/transforms/autograd_test_impl.py @@ -125,6 +125,31 @@ def test_fade(self, fade_shape): waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2) self.assert_grad(transform, [waveform], nondet_tol=1e-10) + @parameterized.expand([(T.TimeMasking,), (T.FrequencyMasking,)]) + def test_masking(self, masking_transform): + sample_rate = 8000 + n_fft = 400 + spectrogram = get_spectrogram( + get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2), + n_fft=n_fft, power=1) + deterministic_transform = _DeterministicWrapper(masking_transform(400)) + self.assert_grad(deterministic_transform, [spectrogram]) + + @parameterized.expand([(T.TimeMasking,), (T.FrequencyMasking,)]) + def test_masking_iid(self, masking_transform): + sample_rate = 8000 + n_fft = 400 + specs = [get_spectrogram( + get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2, seed=i), + n_fft=n_fft, power=1) + for i in range(3) + ] + + batch = torch.stack(specs) + assert batch.ndim == 4 + deterministic_transform = _DeterministicWrapper(masking_transform(400, True)) + self.assert_grad(deterministic_transform, [batch]) + def test_spectral_centroid(self): sample_rate = 8000 transform = T.SpectralCentroid(sample_rate=sample_rate) diff --git a/test/torchaudio_unittest/transforms/sox_compatibility_test.py b/test/torchaudio_unittest/transforms/sox_compatibility_test.py index 81582c8393..be6c9020ab 100644 --- a/test/torchaudio_unittest/transforms/sox_compatibility_test.py +++ b/test/torchaudio_unittest/transforms/sox_compatibility_test.py @@ -1,3 +1,6 @@ +import warnings + +import torch import torchaudio.transforms as T from parameterized import parameterized @@ -61,3 +64,25 @@ def test_vad(self, filename): data, sample_rate = load_wav(path) result = T.Vad(sample_rate)(data) self.assert_sox_effect(result, path, ['vad']) + + def test_vad_warning(self): + """vad should throw a warning if input dimension is greater than 2""" + sample_rate = 41100 + + data = torch.rand(5, 5, sample_rate) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + T.Vad(sample_rate)(data) + assert len(w) == 1 + + data = torch.rand(5, sample_rate) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + T.Vad(sample_rate)(data) + assert len(w) == 0 + + data = torch.rand(sample_rate) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + T.Vad(sample_rate)(data) + assert len(w) == 0 diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py index 506ca9af6c..5258343181 100644 --- 
a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
+++ b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
@@ -59,6 +59,10 @@ def test_AmplitudeToDB(self):
         spec = torch.rand((6, 201))
         self._assert_consistency(T.AmplitudeToDB(), spec)
 
+    def test_MelScale_invalid(self):
+        with self.assertRaises(ValueError):
+            torch.jit.script(T.MelScale())
+
     def test_MelScale(self):
         spec_f = torch.rand((1, 201, 6))
         self._assert_consistency(T.MelScale(n_stft=201), spec_f)
diff --git a/test/torchaudio_unittest/transforms/transforms_test.py b/test/torchaudio_unittest/transforms/transforms_test.py
index ab7b2ed981..5aa88ead90 100644
--- a/test/torchaudio_unittest/transforms/transforms_test.py
+++ b/test/torchaudio_unittest/transforms/transforms_test.py
@@ -169,9 +169,11 @@ def test_resample_size(self):
         upsample_rate = sample_rate * 2
         downsample_rate = sample_rate // 2
 
-        invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo')
+        invalid_resampling_method = 'foo'
 
-        self.assertRaises(ValueError, invalid_resample, waveform)
+        with self.assertRaises(ValueError):
+            torchaudio.transforms.Resample(sample_rate, upsample_rate,
+                                           resampling_method=invalid_resampling_method)
 
         upsample_resample = torchaudio.transforms.Resample(
             sample_rate, upsample_rate, resampling_method='sinc_interpolation')
diff --git a/test/torchaudio_unittest/transforms/transforms_test_impl.py b/test/torchaudio_unittest/transforms/transforms_test_impl.py
index 8a92662200..44f254e540 100644
--- a/test/torchaudio_unittest/transforms/transforms_test_impl.py
+++ b/test/torchaudio_unittest/transforms/transforms_test_impl.py
@@ -1,3 +1,5 @@
+import warnings
+
 import torch
 
 import torchaudio.transforms as T
@@ -39,7 +41,7 @@ def test_InverseMelScale(self):
             get_whitenoise(sample_rate=sample_rate, duration=1, n_channels=2),
             n_fft=n_fft, power=power).to(self.device, self.dtype)
         input = T.MelScale(
-            n_mels=n_mels, sample_rate=sample_rate
+            n_mels=n_mels, sample_rate=sample_rate, n_stft=n_stft
         ).to(self.device, self.dtype)(expected)
 
         # Run transform
@@ -59,3 +61,19 @@
         assert _get_ratio(relative_diff < 1e-1) > 0.2
         assert _get_ratio(relative_diff < 1e-3) > 5e-3
         assert _get_ratio(relative_diff < 1e-5) > 1e-5
+
+    def test_melscale_unset_weight_warning(self):
+        """Issue a warning if MelScale is initialized without a weight
+
+        As part of the deprecation of the lazy initialization behavior (#1510),
+        issue a warning if `n_stft` is not set.
+        """
+        with warnings.catch_warnings(record=True) as caught_warnings:
+            warnings.simplefilter("always")
+            T.MelScale(n_mels=64, sample_rate=8000)
+        assert len(caught_warnings) == 1
+
+        with warnings.catch_warnings(record=True) as caught_warnings:
+            warnings.simplefilter("always")
+            T.MelScale(n_mels=64, sample_rate=8000, n_stft=201)
+        assert len(caught_warnings) == 0
diff --git a/torchaudio/compliance/kaldi.py b/torchaudio/compliance/kaldi.py
index a5c3354c9b..22616038f5 100644
--- a/torchaudio/compliance/kaldi.py
+++ b/torchaudio/compliance/kaldi.py
@@ -755,7 +755,9 @@ def mfcc(
 def resample_waveform(waveform: Tensor,
                       orig_freq: float,
                       new_freq: float,
-                      lowpass_filter_width: int = 6) -> Tensor:
+                      lowpass_filter_width: int = 6,
+                      rolloff: float = 0.99,
+                      resampling_method: str = "sinc_interpolation") -> Tensor:
     r"""Resamples the waveform at the new frequency.
 
     This is a wrapper around ``torchaudio.functional.resample``.
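The hunk above widens the Kaldi-compliance wrapper to accept the new `rolloff` and `resampling_method` arguments. A minimal usage sketch; the 16 kHz rate and the sine test tone are illustrative, not taken from the patch:

    import math
    import torch
    from torchaudio.compliance import kaldi

    sample_rate = 16000  # illustrative source rate
    t = torch.arange(0, 1.0, 1.0 / sample_rate)
    waveform = torch.sin(2 * math.pi * 440.0 * t).unsqueeze(0)  # (1, time)

    # Same wrapper call as before; a Kaiser-windowed kernel can now be requested.
    resampled = kaldi.resample_waveform(
        waveform, sample_rate, sample_rate // 2,
        resampling_method="kaiser_window")
    print(resampled.shape)  # torch.Size([1, 8000])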
@@ -766,8 +768,13 @@ def resample_waveform(waveform: Tensor,
         new_freq (float): The desired frequency
         lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
             but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
+        rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
+            Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
+        resampling_method (str, optional): The resampling method to use.
+            Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``)
 
     Returns:
         Tensor: The waveform at the new frequency
     """
-    return torchaudio.functional.resample(waveform, orig_freq, new_freq, lowpass_filter_width)
+    return torchaudio.functional.resample(waveform, orig_freq, new_freq, lowpass_filter_width,
+                                          rolloff, resampling_method)
diff --git a/torchaudio/csrc/CMakeLists.txt b/torchaudio/csrc/CMakeLists.txt
index ebf577eb64..64661f96a5 100644
--- a/torchaudio/csrc/CMakeLists.txt
+++ b/torchaudio/csrc/CMakeLists.txt
@@ -19,6 +19,7 @@ if(BUILD_TRANSDUCER)
     rnnt/compute_alphas.cpp
     rnnt/compute_betas.cpp
     rnnt/compute.cpp
+    rnnt/autograd.cpp
     )
 
   if (USE_CUDA)
diff --git a/torchaudio/csrc/rnnt/autograd.cpp b/torchaudio/csrc/rnnt/autograd.cpp
new file mode 100644
index 0000000000..73ad9f9b3c
--- /dev/null
+++ b/torchaudio/csrc/rnnt/autograd.cpp
@@ -0,0 +1,74 @@
+#include <torch/script.h>
+#include <torchaudio/csrc/rnnt/compute.h>
+
+namespace torchaudio {
+namespace rnnt {
+
+class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
+ public:
+  static torch::autograd::tensor_list forward(
+      torch::autograd::AutogradContext* ctx,
+      torch::Tensor& logits,
+      const torch::Tensor& targets,
+      const torch::Tensor& src_lengths,
+      const torch::Tensor& tgt_lengths,
+      int64_t blank,
+      double clamp,
+      bool fused_log_smax = true,
+      bool reuse_logits_for_grads = true) {
+    at::AutoNonVariableTypeMode g;
+    torch::Tensor undef;
+    auto result = rnnt_loss(
+        logits,
+        targets,
+        src_lengths,
+        tgt_lengths,
+        blank,
+        clamp,
+        fused_log_smax,
+        reuse_logits_for_grads);
+    auto costs = std::get<0>(result);
+    auto grads = std::get<1>(result).value_or(undef);
+    ctx->save_for_backward({grads});
+    return {costs, grads};
+  }
+
+  static torch::autograd::tensor_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::tensor_list grad_outputs) {
+    auto saved = ctx->get_saved_variables();
+    auto grad = saved[0];
+    auto grad_out = grad_outputs[0].view({-1, 1, 1, 1});
+    auto result = grad * grad_out;
+    torch::Tensor undef;
+    return {result, undef, undef, undef, undef, undef, undef, undef};
+  }
+};
+
+std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
+    torch::Tensor& logits,
+    const torch::Tensor& targets,
+    const torch::Tensor& src_lengths,
+    const torch::Tensor& tgt_lengths,
+    int64_t blank,
+    double clamp,
+    bool fused_log_smax = true,
+    bool reuse_logits_for_grads = true) {
+  auto results = RNNTLossFunction::apply(
+      logits,
+      targets,
+      src_lengths,
+      tgt_lengths,
+      blank,
+      clamp,
+      fused_log_smax,
+      reuse_logits_for_grads);
+  return std::make_tuple(results[0], results[1]);
+}
+
+TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) {
+  m.impl("rnnt_loss", rnnt_loss_autograd);
+}
+
+} // namespace rnnt
+} // namespace torchaudio
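Registering `rnnt_loss_autograd` under the `Autograd` dispatch key is what lets this patch delete the Python-side `_RNNT` autograd function from `torchaudio/prototype/rnnt_loss.py` further down: gradients now come from the C++ operator itself. A minimal sketch of the resulting behavior, reusing the toy tensors from the new TorchScript consistency tests and assuming torchaudio was built with `BUILD_TRANSDUCER=1`:

    import torch
    from torchaudio.prototype.rnnt_loss import rnnt_loss

    logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
                             [0.1, 0.1, 0.6, 0.1, 0.1],
                             [0.1, 0.1, 0.2, 0.8, 0.1]],
                            [[0.1, 0.6, 0.1, 0.1, 0.1],
                             [0.1, 0.1, 0.2, 0.1, 0.1],
                             [0.7, 0.1, 0.2, 0.1, 0.1]]]],
                          requires_grad=True)
    targets = torch.tensor([[1, 2]], dtype=torch.int32)
    logit_lengths = torch.tensor([2], dtype=torch.int32)
    target_lengths = torch.tensor([2], dtype=torch.int32)

    # reuse_logits_for_grads=False keeps the logits buffer intact for this demo.
    costs = rnnt_loss(logits, targets, logit_lengths, target_lengths,
                      reuse_logits_for_grads=False)
    costs.backward()          # single-element cost, so no explicit grad needed
    print(logits.grad.shape)  # torch.Size([1, 2, 3, 5])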
diff --git a/torchaudio/csrc/rnnt/compute.cpp b/torchaudio/csrc/rnnt/compute.cpp
index bce803fffa..f47f0f505d 100644
--- a/torchaudio/csrc/rnnt/compute.cpp
+++ b/torchaudio/csrc/rnnt/compute.cpp
@@ -1,4 +1,28 @@
 #include <torch/script.h>
+#include <torchaudio/csrc/rnnt/compute.h>
+
+std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
+    torch::Tensor& logits,
+    const torch::Tensor& targets,
+    const torch::Tensor& src_lengths,
+    const torch::Tensor& tgt_lengths,
+    int64_t blank,
+    double clamp,
+    bool fused_log_smax = true,
+    bool reuse_logits_for_grads = true) {
+  static auto op = torch::Dispatcher::singleton()
+                       .findSchemaOrThrow("torchaudio::rnnt_loss", "")
+                       .typed<decltype(rnnt_loss)>();
+  return op.call(
+      logits,
+      targets,
+      src_lengths,
+      tgt_lengths,
+      blank,
+      clamp,
+      fused_log_smax,
+      reuse_logits_for_grads);
+}
 
 TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
   m.def(
diff --git a/torchaudio/csrc/rnnt/compute.h b/torchaudio/csrc/rnnt/compute.h
new file mode 100644
index 0000000000..9616d45fc3
--- /dev/null
+++ b/torchaudio/csrc/rnnt/compute.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include <torch/script.h>
+
+std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
+    torch::Tensor& logits,
+    const torch::Tensor& targets,
+    const torch::Tensor& src_lengths,
+    const torch::Tensor& tgt_lengths,
+    int64_t blank,
+    double clamp,
+    bool fused_log_smax,
+    bool reuse_logits_for_grads);
diff --git a/torchaudio/csrc/rnnt/transducer.h b/torchaudio/csrc/rnnt/transducer.h
deleted file mode 100644
index 0553616114..0000000000
--- a/torchaudio/csrc/rnnt/transducer.h
+++ /dev/null
@@ -1,121 +0,0 @@
-#pragma once
-
-#include <torchaudio/csrc/rnnt/cpu/cpu_transducer.h>
-#include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h>
-
-namespace torchaudio {
-namespace rnnt {
-
-template <typename DTYPE>
-status_t Compute(
-    const Workspace<DTYPE>& workspace,
-    const DTYPE* logits,
-    const int* targets,
-    const int* srcLengths,
-    const int* tgtLengths,
-    DTYPE* costs,
-    DTYPE* gradients = nullptr) {
-  switch (workspace.GetOptions().device_) {
-    case CPU: {
-      status_t status = cpu::Compute(
-          /*workspace=*/workspace,
-          /*logits=*/logits,
-          /*targets=*/targets,
-          /*srcLengths=*/srcLengths,
-          /*tgtLengths=*/tgtLengths,
-          /*costs=*/costs,
-          /*gradients=*/gradients);
-      return status;
-    }
-    case GPU: {
-      status_t status = gpu::Compute(
-          /*workspace=*/workspace,
-          /*logits=*/logits,
-          /*targets=*/targets,
-          /*srcLengths=*/srcLengths,
-          /*tgtLengths=*/tgtLengths,
-          /*costs=*/costs,
-          /*gradients=*/gradients);
-      return status;
-    }
-    default: {
-      return FAILURE;
-    }
-  };
-}
-
-template <typename DTYPE>
-status_t ComputeAlphas(
-    const Workspace<DTYPE>& workspace,
-    const DTYPE* logits,
-    const int* targets,
-    const int* srcLengths,
-    const int* tgtLengths,
-    DTYPE* alphas) {
-  switch (workspace.GetOptions().device_) {
-    case CPU: {
-      status_t status = cpu::ComputeAlphas(
-          /*workspace=*/workspace,
-          /*logits=*/logits,
-          /*targets=*/targets,
-          /*srcLengths=*/srcLengths,
-          /*tgtLengths=*/tgtLengths,
-          /*alphas=*/alphas);
-      return status;
-    }
-    case GPU: {
-      status_t status = gpu::ComputeAlphas(
-          /*workspace=*/workspace,
-          /*logits=*/logits,
-          /*targets=*/targets,
-          /*srcLengths=*/srcLengths,
-          /*tgtLengths=*/tgtLengths,
-          /*costs=*/alphas);
-      return status;
-    }
-    default: {
-      return FAILURE;
-    }
-  };
-}
-
-template <typename DTYPE>
-status_t ComputeBetas(
-    const Workspace<DTYPE>& workspace,
-    const DTYPE* logits,
-    const int* targets,
-    const int* srcLengths,
-    const int* tgtLengths,
-    DTYPE* costs,
-    DTYPE* betas) {
-  switch (workspace.GetOptions().device_) {
-    case CPU: {
-      status_t status = cpu::ComputeBetas(
-          /*workspace=*/workspace,
-          /*logits=*/logits,
-          /*targets=*/targets,
-          /*srcLengths=*/srcLengths,
-          /*tgtLengths=*/tgtLengths,
-          /*costs=*/costs,
-          /*betas=*/betas);
-      return status;
-    }
-    case GPU: {
-      status_t status = gpu::ComputeBetas(
-          /*workspace=*/workspace,
-          /*logits=*/logits,
-          /*targets=*/targets,
-          /*srcLengths=*/srcLengths,
-          /*tgtLengths=*/tgtLengths,
-          /*costs=*/costs,
-          /*betas=*/betas);
-      return status;
-    }
-    default: {
-      return FAILURE;
-    }
-  };
-}
-
-} // namespace rnnt
-} // namespace torchaudio
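With `transducer.h` gone, the loss is reached only through the `torchaudio::rnnt_loss` schema defined in `compute.cpp`, and routing the call through the dispatcher is also what makes the operator visible to TorchScript — the behavior the new `rnnt/torchscript_consistency_*` tests lock in. A condensed sketch of what those tests verify, under the same build assumption as above:

    import torch
    from torchaudio.prototype.rnnt_loss import RNNTLoss

    loss = RNNTLoss()
    scripted = torch.jit.script(loss)  # scriptable now that the op goes through the dispatcher

    logits = torch.rand(1, 2, 3, 5)  # (batch, time, target_len + 1, class), toy shape
    targets = torch.tensor([[1, 2]], dtype=torch.int32)
    logit_lengths = torch.tensor([2], dtype=torch.int32)
    target_lengths = torch.tensor([2], dtype=torch.int32)

    assert torch.allclose(
        loss(logits, targets, logit_lengths, target_lengths),
        scripted(logits, targets, logit_lengths, target_lengths))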
diff --git a/torchaudio/functional/filtering.py b/torchaudio/functional/filtering.py
index 93da27493b..85abe81339 100644
--- a/torchaudio/functional/filtering.py
+++ b/torchaudio/functional/filtering.py
@@ -1,4 +1,5 @@
 import math
+import warnings
 from typing import Optional
 
 import torch
@@ -1374,7 +1375,10 @@ def vad(
     so in order to trim from the back, the reverse effect must also be used.
 
     Args:
-        waveform (Tensor): Tensor of audio of dimension `(..., time)`
+        waveform (Tensor): Tensor of audio of dimension `(channels, time)` or `(time)`
+            Tensor of shape `(channels, time)` is treated as a multi-channel recording
+            of the same event and the resulting output will be trimmed to the earliest
+            voice activity in any channel.
         sample_rate (int): Sample rate of audio signal.
         trigger_level (float, optional): The measurement level used to trigger activity detection.
             This may need to be changed depending on the noise level, signal level,
@@ -1420,6 +1424,15 @@ def vad(
         http://sox.sourceforge.net/sox.html
     """
+    if waveform.ndim > 2:
+        warnings.warn(
+            "Expected input tensor dimension of 1 for single channel"
+            f" or 2 for multi-channel. Got {waveform.ndim} instead. "
+            "Batch semantics is not supported. "
+            "Please refer to https://github.com/pytorch/audio/issues/1348"
+            " and https://github.com/pytorch/audio/issues/1468."
+        )
+
     measure_duration: float = (
         2.0 / measure_freq if measure_duration is None else measure_duration
     )
diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py
index b7d9b112b8..74271798eb 100644
--- a/torchaudio/functional/functional.py
+++ b/torchaudio/functional/functional.py
@@ -1298,8 +1298,36 @@ def compute_kaldi_pitch(
     return result
 
-def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_width: int,
-                              device: torch.device, dtype: torch.dtype):
+def _get_sinc_resample_kernel(
+        orig_freq: float,
+        new_freq: float,
+        gcd: int,
+        lowpass_filter_width: int,
+        rolloff: float,
+        resampling_method: str,
+        beta: Optional[float],
+        device: torch.device = torch.device("cpu"),
+        dtype: Optional[torch.dtype] = None):
+
+    if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
+        warnings.warn(
+            "Non-integer frequencies are being cast to ints and may result in poor resampling quality "
+            "because the underlying algorithm requires an integer ratio between `orig_freq` and `new_freq`. "
+            "Using non-integer valued frequencies will throw an error in the next release. "
+            "To work around this issue, manually convert both frequencies to integer values "
+            "that maintain their resampling rate ratio before passing them into the function. "
+            "Example: To downsample a 44100 Hz waveform by a factor of 8, use "
+            "`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5`. "
+            "For more information or to leave feedback about this change, please refer to "
+            "https://github.com/pytorch/audio/issues/1487."
+ ) + + if resampling_method not in ['sinc_interpolation', 'kaiser_window']: + raise ValueError('Invalid resampling method: {}'.format(resampling_method)) + + orig_freq = int(orig_freq) // gcd + new_freq = int(new_freq) // gcd + assert lowpass_filter_width > 0 kernels = [] base_freq = min(orig_freq, new_freq) @@ -1307,7 +1335,7 @@ def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_widt # At first I thought I only needed this when downsampling, but when upsampling # you will get edge artifacts without this, as the edge is equivalent to zero padding, # which will add high freq artifacts. - base_freq *= 0.99 + base_freq *= rolloff # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor) # using the sinc interpolation formula: @@ -1331,70 +1359,99 @@ def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_widt # they will have a lot of almost zero values to the left or to the right... # There is probably a way to evaluate those filters more efficiently, but this is kept for # future work. - idx = torch.arange(-width, width + orig_freq, device=device, dtype=dtype) + idx_dtype = dtype if dtype is not None else torch.float64 + idx = torch.arange(-width, width + orig_freq, device=device, dtype=idx_dtype) for i in range(new_freq): t = (-i / new_freq + idx / orig_freq) * base_freq t = t.clamp_(-lowpass_filter_width, lowpass_filter_width) - t *= math.pi - # we do not use torch.hann_window here as we need to evaluate the window + + # we do not use built in torch windows here as we need to evaluate the window # at specific positions, not over a regular grid. - window = torch.cos(t / lowpass_filter_width / 2)**2 + if resampling_method == "sinc_interpolation": + window = torch.cos(t * math.pi / lowpass_filter_width / 2)**2 + else: + # kaiser_window + if beta is None: + beta = 14.769656459379492 + beta_tensor = torch.tensor(float(beta)) + window = torch.i0(beta_tensor * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta_tensor) + t *= math.pi kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t) kernel.mul_(window) kernels.append(kernel) scale = base_freq / orig_freq - return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width + kernels = torch.stack(kernels).view(new_freq, 1, -1).mul_(scale) + if dtype is None: + kernels = kernels.to(dtype=torch.float32) + return kernels, width + + +def _apply_sinc_resample_kernel( + waveform: Tensor, + orig_freq: float, + new_freq: float, + gcd: int, + kernel: Tensor, + width: int, +): + orig_freq = int(orig_freq) // gcd + new_freq = int(new_freq) // gcd + + # pack batch + shape = waveform.size() + waveform = waveform.view(-1, shape[-1]) + + num_wavs, length = waveform.shape + waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq)) + resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq) + resampled = resampled.transpose(1, 2).reshape(num_wavs, -1) + target_length = int(math.ceil(new_freq * length / orig_freq)) + resampled = resampled[..., :target_length] + + # unpack batch + resampled = resampled.view(shape[:-1] + resampled.shape[-1:]) + return resampled def resample( waveform: Tensor, orig_freq: float, new_freq: float, - lowpass_filter_width: int = 6 + lowpass_filter_width: int = 6, + rolloff: float = 0.99, + resampling_method: str = "sinc_interpolation", + beta: Optional[float] = None, ) -> Tensor: - r"""Resamples the waveform at the new frequency. 
This matches Kaldi's OfflineFeatureTpl ResampleWaveform - which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample - a signal). LinearResample (LR) means that the output signal is at linearly spaced intervals (i.e - the output signal has a frequency of ``new_freq``). It uses sinc/bandlimited interpolation to - upsample/downsample the signal. + r"""Resamples the waveform at the new frequency using bandlimited interpolation. https://ccrma.stanford.edu/~jos/resample/Theory_Ideal_Bandlimited_Interpolation.html - https://github.com/kaldi-asr/kaldi/blob/master/src/feat/resample.h#L56 Args: waveform (Tensor): The input signal of dimension (..., time) orig_freq (float): The original frequency of the signal new_freq (float): The desired frequency lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper - but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``) + but less efficient. (Default: ``6``) + rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist. + Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``) + resampling_method (str, optional): The resampling method to use. + Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``) + beta (float or None): The shape parameter used for kaiser window. Returns: Tensor: The waveform at the new frequency of dimension (..., time). + + Note: ``transforms.Resample`` precomputes and reuses the resampling kernel, so using it will result in + more efficient computation if resampling multiple waveforms with the same resampling parameters. """ - # pack batch - shape = waveform.size() - waveform = waveform.view(-1, shape[-1]) assert orig_freq > 0.0 and new_freq > 0.0 - orig_freq = int(orig_freq) - new_freq = int(new_freq) - gcd = math.gcd(orig_freq, new_freq) - orig_freq = orig_freq // gcd - new_freq = new_freq // gcd - - kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width, - waveform.device, waveform.dtype) + gcd = math.gcd(int(orig_freq), int(new_freq)) - num_wavs, length = waveform.shape - waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq)) - resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq) - resampled = resampled.transpose(1, 2).reshape(num_wavs, -1) - target_length = int(math.ceil(new_freq * length / orig_freq)) - resampled = resampled[..., :target_length] - - # unpack batch - resampled = resampled.view(shape[:-1] + resampled.shape[-1:]) + kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, gcd, lowpass_filter_width, rolloff, + resampling_method, beta, waveform.device, waveform.dtype) + resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width) return resampled diff --git a/torchaudio/models/__init__.py b/torchaudio/models/__init__.py index 5b134345af..6696d8ded2 100644 --- a/torchaudio/models/__init__.py +++ b/torchaudio/models/__init__.py @@ -1,9 +1,11 @@ from .wav2letter import Wav2Letter from .wavernn import WaveRNN from .conv_tasnet import ConvTasNet +from .deepspeech import DeepSpeech __all__ = [ 'Wav2Letter', 'WaveRNN', 'ConvTasNet', + 'DeepSpeech', ] diff --git a/torchaudio/models/deepspeech.py b/torchaudio/models/deepspeech.py new file mode 100644 index 0000000000..477993e411 --- /dev/null +++ b/torchaudio/models/deepspeech.py @@ -0,0 +1,92 @@ +import torch + +__all__ = ["DeepSpeech"] + + +class 
FullyConnected(torch.nn.Module): + """ + Args: + n_feature: Number of input features + n_hidden: Internal hidden unit size. + """ + + def __init__(self, + n_feature: int, + n_hidden: int, + dropout: float, + relu_max_clip: int = 20) -> None: + super(FullyConnected, self).__init__() + self.fc = torch.nn.Linear(n_feature, n_hidden, bias=True) + self.relu_max_clip = relu_max_clip + self.dropout = dropout + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc(x) + x = torch.nn.functional.relu(x) + x = torch.nn.functional.hardtanh(x, 0, self.relu_max_clip) + if self.dropout: + x = torch.nn.functional.dropout(x, self.dropout, self.training) + return x + + +class DeepSpeech(torch.nn.Module): + """ + DeepSpeech model architecture from + `"Deep Speech: Scaling up end-to-end speech recognition"` + paper. + + Args: + n_feature: Number of input features + n_hidden: Internal hidden unit size. + n_class: Number of output classes + """ + + def __init__( + self, + n_feature: int, + n_hidden: int = 2048, + n_class: int = 40, + dropout: float = 0.0, + ) -> None: + super(DeepSpeech, self).__init__() + self.n_hidden = n_hidden + self.fc1 = FullyConnected(n_feature, n_hidden, dropout) + self.fc2 = FullyConnected(n_hidden, n_hidden, dropout) + self.fc3 = FullyConnected(n_hidden, n_hidden, dropout) + self.bi_rnn = torch.nn.RNN( + n_hidden, n_hidden, num_layers=1, nonlinearity="relu", bidirectional=True + ) + self.fc4 = FullyConnected(n_hidden, n_hidden, dropout) + self.out = torch.nn.Linear(n_hidden, n_class) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Tensor of dimension (batch, channel, time, feature). + Returns: + Tensor: Predictor tensor of dimension (batch, time, class). + """ + # N x C x T x F + x = self.fc1(x) + # N x C x T x H + x = self.fc2(x) + # N x C x T x H + x = self.fc3(x) + # N x C x T x H + x = x.squeeze(1) + # N x T x H + x = x.transpose(0, 1) + # T x N x H + x, _ = self.bi_rnn(x) + # The fifth (non-recurrent) layer takes both the forward and backward units as inputs + x = x[:, :, :self.n_hidden] + x[:, :, self.n_hidden:] + # T x N x H + x = self.fc4(x) + # T x N x H + x = self.out(x) + # T x N x n_class + x = x.permute(1, 0, 2) + # N x T x n_class + x = torch.nn.functional.log_softmax(x, dim=2) + # N x T x n_class + return x diff --git a/torchaudio/prototype/rnnt_loss.py b/torchaudio/prototype/rnnt_loss.py index 08c96ab0f2..2bce60835b 100644 --- a/torchaudio/prototype/rnnt_loss.py +++ b/torchaudio/prototype/rnnt_loss.py @@ -1,4 +1,5 @@ import torch +from torch import Tensor __all__ = [ "RNNTLoss", @@ -6,141 +7,15 @@ ] -def _rnnt_loss_alphas( - logits, - targets, - logit_lengths, - target_lengths, - blank=-1, - clamp=-1, -): - """ - Compute alphas for RNN transducer loss. - - See documentation for RNNTLoss - """ - targets = targets.to(device=logits.device) - logit_lengths = logit_lengths.to(device=logits.device) - target_lengths = target_lengths.to(device=logits.device) - - # make sure all int tensors are of type int32. 
- targets = targets.int() - logit_lengths = logit_lengths.int() - target_lengths = target_lengths.int() - - return torch.ops.torchaudio.rnnt_loss_alphas( - logits, - targets, - logit_lengths, - target_lengths, - blank, - clamp, - ) - - -def _rnnt_loss_betas( - logits, - targets, - logit_lengths, - target_lengths, - blank=-1, - clamp=-1, -): - """ - Compute betas for RNN transducer loss - - See documentation for RNNTLoss - """ - targets = targets.to(device=logits.device) - logit_lengths = logit_lengths.to(device=logits.device) - target_lengths = target_lengths.to(device=logits.device) - - # make sure all int tensors are of type int32. - targets = targets.int() - logit_lengths = logit_lengths.int() - target_lengths = target_lengths.int() - - return torch.ops.torchaudio.rnnt_loss_betas( - logits, - targets, - logit_lengths, - target_lengths, - blank, - clamp, - ) - - -class _RNNT(torch.autograd.Function): - @staticmethod - def forward( - ctx, - logits, - targets, - logit_lengths, - target_lengths, - blank=-1, - clamp=-1, - fused_log_softmax=True, - reuse_logits_for_grads=True, - ): - """ - See documentation for RNNTLoss - """ - - # move everything to the same device. - targets = targets.to(device=logits.device) - logit_lengths = logit_lengths.to(device=logits.device) - target_lengths = target_lengths.to(device=logits.device) - - # make sure all int tensors are of type int32. - targets = targets.int() - logit_lengths = logit_lengths.int() - target_lengths = target_lengths.int() - - if blank < 0: # reinterpret blank index if blank < 0. - blank = logits.shape[-1] + blank - - costs, gradients = torch.ops.torchaudio.rnnt_loss( - logits=logits, - targets=targets, - src_lengths=logit_lengths, - tgt_lengths=target_lengths, - blank=blank, - clamp=clamp, - fused_log_smax=fused_log_softmax, - reuse_logits_for_grads=reuse_logits_for_grads, - ) - - ctx.grads = gradients - - return costs - - @staticmethod - def backward(ctx, output_gradients): - output_gradients = output_gradients.view(-1, 1, 1, 1).to(ctx.grads) - ctx.grads.mul_(output_gradients).to(ctx.grads) - - return ( - ctx.grads, # logits - None, # targets - None, # logit_lengths - None, # target_lengths - None, # blank - None, # clamp - None, # fused_log_softmax - None, # reuse_logits_for_grads - ) - - def rnnt_loss( - logits, - targets, - logit_lengths, - target_lengths, - blank=-1, - clamp=-1, - fused_log_softmax=True, - reuse_logits_for_grads=True, + logits: Tensor, + targets: Tensor, + logit_lengths: Tensor, + target_lengths: Tensor, + blank: int = -1, + clamp: float = -1, + fused_log_softmax: bool = True, + reuse_logits_for_grads: bool = True, ): """ Compute the RNN Transducer Loss. @@ -166,17 +41,20 @@ def rnnt_loss( False # softmax needs the original logits value ) - cost = _RNNT.apply( - logits, - targets, - logit_lengths, - target_lengths, - blank, - clamp, - fused_log_softmax, - reuse_logits_for_grads, - ) - return cost + if blank < 0: # reinterpret blank index if blank < 0. 
+ blank = logits.shape[-1] + blank + + costs, gradients = torch.ops.torchaudio.rnnt_loss( + logits=logits, + targets=targets, + src_lengths=logit_lengths, + tgt_lengths=target_lengths, + blank=blank, + clamp=clamp, + fused_log_smax=fused_log_softmax, + reuse_logits_for_grads=reuse_logits_for_grads,) + + return costs class RNNTLoss(torch.nn.Module): @@ -196,10 +74,10 @@ class RNNTLoss(torch.nn.Module): def __init__( self, - blank=-1, - clamp=-1, - fused_log_softmax=True, - reuse_logits_for_grads=True, + blank: int = -1, + clamp: float = -1., + fused_log_softmax: bool = True, + reuse_logits_for_grads: bool = True, ): super().__init__() self.blank = blank diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index fa38118043..49d3ff599b 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -8,6 +8,10 @@ from torch import Tensor from torchaudio import functional as F +from .functional.functional import ( + _get_sinc_resample_kernel, + _apply_sinc_resample_kernel, +) __all__ = [ 'Spectrogram', @@ -279,11 +283,34 @@ def __init__(self, assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max) + if n_stft is None or n_stft == 0: + warnings.warn( + 'Initialization of torchaudio.transforms.MelScale with an unset weight ' + '`n_stft=None` is deprecated and will be removed from a future release. ' + 'Please set a proper `n_stft` value. Typically this is `n_fft // 2 + 1`. ' + 'Refer to https://github.com/pytorch/audio/issues/1510 ' + 'for more details.' + ) + fb = torch.empty(0) if n_stft is None else F.create_fb_matrix( n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, self.mel_scale) self.register_buffer('fb', fb) + def __prepare_scriptable__(self): + r"""If `self.fb` is empty, the `forward` method will try to resize the parameter, + which does not work once the transform is scripted. However, this error does not happen + until the transform is executed. This is inconvenient especially if the resulting + TorchScript object is executed in other environments. Therefore, we check the + validity of `self.fb` here and fail if the resulting TS does not work. + + Returns: + MelScale: self + """ + if self.fb.numel() == 0: + raise ValueError("n_stft must be provided at construction") + return self + def forward(self, specgram: Tensor) -> Tensor: r""" Args: @@ -639,17 +666,40 @@ class Resample(torch.nn.Module): Args: orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``) new_freq (float, optional): The desired frequency. (Default: ``16000``) - resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``) + resampling_method (str, optional): The resampling method to use. + Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``) + lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper + but less efficient. (Default: ``6``) + rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist. + Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``) + beta (float or None): The shape parameter used for kaiser window. + + Note: If resampling on waveforms of higher precision than float32, there may be a small loss of precision + because the kernel is cached once as float32. 
If high precision resampling is important for your application, + the functional form will retain higher precision, but run slower because it does not cache the kernel. + Alternatively, you could rewrite a transform that caches a higher precision kernel. """ def __init__(self, - orig_freq: int = 16000, - new_freq: int = 16000, - resampling_method: str = 'sinc_interpolation') -> None: + orig_freq: float = 16000, + new_freq: float = 16000, + resampling_method: str = 'sinc_interpolation', + lowpass_filter_width: int = 6, + rolloff: float = 0.99, + beta: Optional[float] = None) -> None: super(Resample, self).__init__() + self.orig_freq = orig_freq self.new_freq = new_freq + self.gcd = math.gcd(int(self.orig_freq), int(self.new_freq)) self.resampling_method = resampling_method + self.lowpass_filter_width = lowpass_filter_width + self.rolloff = rolloff + + kernel, self.width = _get_sinc_resample_kernel(self.orig_freq, self.new_freq, self.gcd, + self.lowpass_filter_width, self.rolloff, + self.resampling_method, beta) + self.register_buffer('kernel', kernel) def forward(self, waveform: Tensor) -> Tensor: r""" @@ -659,10 +709,8 @@ def forward(self, waveform: Tensor) -> Tensor: Returns: Tensor: Output signal of dimension (..., time). """ - if self.resampling_method == 'sinc_interpolation': - return F.resample(waveform, self.orig_freq, self.new_freq) - - raise ValueError('Invalid resampling method: {}'.format(self.resampling_method)) + return _apply_sinc_resample_kernel(waveform, self.orig_freq, self.new_freq, self.gcd, + self.kernel, self.width) class ComplexNorm(torch.nn.Module): @@ -1078,7 +1126,10 @@ def __init__(self, def forward(self, waveform: Tensor) -> Tensor: r""" Args: - waveform (Tensor): Tensor of audio of dimension `(..., time)` + waveform (Tensor): Tensor of audio of dimension `(channels, time)` or `(time)` + Tensor of shape `(channels, time)` is treated as a multi-channel recording + of the same event and the resulting output will be trimmed to the earliest + voice activity in any channel. """ return F.vad( waveform=waveform,
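Taken together, the resampling changes split kernel construction (`_get_sinc_resample_kernel`) from kernel application (`_apply_sinc_resample_kernel`), so `transforms.Resample` can build the kernel once in its constructor and reuse it on every call, while `functional.resample` recomputes it per call. A usage sketch of the new parameters; the sample rates and `beta` value are illustrative:

    import torch
    import torchaudio.transforms as T

    waveform = torch.rand(1, 44100)  # illustrative one-second mono clip

    # The kernel is precomputed here and registered as a buffer, so repeated
    # calls (and .to(device) moves) reuse it instead of rebuilding it.
    resampler = T.Resample(
        orig_freq=44100, new_freq=22050,
        resampling_method="kaiser_window",
        lowpass_filter_width=6, rolloff=0.99, beta=8.0)

    for _ in range(3):
        resampled = resampler(waveform)
    print(resampled.shape)  # torch.Size([1, 22050])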