From d48089a05fef44aad03a0a9bc428d3c93ffac63d Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Wed, 31 Mar 2021 00:03:31 -0700 Subject: [PATCH] Import torchaudio #1412 c0bfb03 Summary: Import latest from github to fbcode Pass: 951 Skip: 19 Omit: 1 ListingSuccess: 26 Result available at: https://www.internalfb.com/intern/testinfra/testrun/8444249336935844 Reviewed By: mthrok Differential Revision: D27448988 fbshipit-source-id: 61f63ffa1295a31b4452abaf2c74ebfefb827dcf --- docs/source/functional.rst | 7 +- .../common_utils/__init__.py | 1 - .../common_utils/case_utils.py | 3 - .../compliance_kaldi_test.py | 3 +- .../functional/librosa_compatibility_test.py | 19 ++++ .../functional/sox_compatibility_test.py | 4 +- .../transforms/sox_compatibility_test.py | 14 ++- torchaudio/compliance/kaldi.py | 81 +------------- torchaudio/datasets/yesno.py | 4 +- torchaudio/functional/__init__.py | 4 +- torchaudio/functional/functional.py | 103 ++++++++++++++++++ torchaudio/transforms.py | 13 +-- 12 files changed, 153 insertions(+), 103 deletions(-) diff --git a/docs/source/functional.rst b/docs/source/functional.rst index e52e83ca42..98517929ea 100644 --- a/docs/source/functional.rst +++ b/docs/source/functional.rst @@ -55,7 +55,7 @@ apply_codec ----------- .. autofunction:: apply_codec - + :hidden:`Complex Utility` ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -230,3 +230,8 @@ vad --------------------------- .. autofunction:: spectral_centroid + +:hidden:`resample` +--------------------------- + +.. autofunction:: resample diff --git a/test/torchaudio_unittest/common_utils/__init__.py b/test/torchaudio_unittest/common_utils/__init__.py index 22cb18a08e..736ec82855 100644 --- a/test/torchaudio_unittest/common_utils/__init__.py +++ b/test/torchaudio_unittest/common_utils/__init__.py @@ -17,7 +17,6 @@ skipIfNoModule, skipIfNoKaldi, skipIfNoSox, - skipIfNoSoxBackend, ) from .wav_utils import ( get_wav_data, diff --git a/test/torchaudio_unittest/common_utils/case_utils.py b/test/torchaudio_unittest/common_utils/case_utils.py index 5d476072bb..69b9a795cd 100644 --- a/test/torchaudio_unittest/common_utils/case_utils.py +++ b/test/torchaudio_unittest/common_utils/case_utils.py @@ -7,7 +7,6 @@ import torch from torch.testing._internal.common_utils import TestCase as PytorchTestCase -import torchaudio from torchaudio._internal.module_utils import ( is_module_available, is_sox_available, @@ -96,8 +95,6 @@ def skipIfNoModule(module, display_name=None): return unittest.skipIf(not is_module_available(module), f'"{display_name}" is not available') -skipIfNoSoxBackend = unittest.skipIf( - 'sox' not in torchaudio.list_audio_backends(), 'Sox backend not available') skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available') skipIfNoSox = unittest.skipIf(not is_sox_available(), reason='Sox not available') skipIfNoKaldi = unittest.skipIf(not is_kaldi_available(), reason='Kaldi not available') diff --git a/test/torchaudio_unittest/compliance_kaldi_test.py b/test/torchaudio_unittest/compliance_kaldi_test.py index a98240a9b8..17d73fbf2f 100644 --- a/test/torchaudio_unittest/compliance_kaldi_test.py +++ b/test/torchaudio_unittest/compliance_kaldi_test.py @@ -47,7 +47,6 @@ def first_sample_of_frame(frame, window_size, window_shift, snip_edges): @common_utils.skipIfNoSox class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase): - backend = 'sox_io' kaldi_output_dir = common_utils.get_asset_path('kaldi') test_filepath = common_utils.get_asset_path('kaldi_file.wav') @@ -113,7 +112,7 @@ def _create_data_set(self): # clear the last 16 bits because they aren't used anyways y = ((y >> 16) << 16).float() torchaudio.save(self.test_filepath, y, sr) - sound, sample_rate = torchaudio.load(self.test_filepath, normalization=False) + sound, sample_rate = common_utils.load_wav(self.test_filepath, normalize=False) print(y >> 16) self.assertTrue(sample_rate == sr) self.assertEqual(y, sound) diff --git a/test/torchaudio_unittest/functional/librosa_compatibility_test.py b/test/torchaudio_unittest/functional/librosa_compatibility_test.py index c88d9f3ccd..5905c239d3 100644 --- a/test/torchaudio_unittest/functional/librosa_compatibility_test.py +++ b/test/torchaudio_unittest/functional/librosa_compatibility_test.py @@ -109,6 +109,25 @@ def test_amplitude_to_DB(self): self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5) + def test_resample(self): + input_path = common_utils.get_asset_path('sinewave.wav') + waveform, sample_rate = common_utils.load_wav(input_path) + + upsample_rate = sample_rate * 2 + downsample_rate = sample_rate // 2 + + ta_upsampled = F.resample(waveform, sample_rate, upsample_rate) + lr_upsampled = librosa.resample(waveform.squeeze(0).numpy(), sample_rate, upsample_rate) + lr_upsampled = torch.from_numpy(lr_upsampled).unsqueeze(0) + + self.assertEqual(ta_upsampled, lr_upsampled, atol=1e-2, rtol=1e-5) + + ta_downsampled = F.resample(waveform, sample_rate, downsample_rate) + lr_downsampled = librosa.resample(waveform.squeeze(0).numpy(), sample_rate, downsample_rate) + lr_downsampled = torch.from_numpy(lr_downsampled).unsqueeze(0) + + self.assertEqual(ta_downsampled, lr_downsampled, atol=1e-2, rtol=1e-5) + @unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available") class TestPhaseVocoder(common_utils.TorchaudioTestCase): diff --git a/test/torchaudio_unittest/functional/sox_compatibility_test.py b/test/torchaudio_unittest/functional/sox_compatibility_test.py index 21f082683c..fe9744f22a 100644 --- a/test/torchaudio_unittest/functional/sox_compatibility_test.py +++ b/test/torchaudio_unittest/functional/sox_compatibility_test.py @@ -2,7 +2,7 @@ import torchaudio.functional as F from torchaudio_unittest.common_utils import ( - skipIfNoSoxBackend, + skipIfNoSox, skipIfNoExec, TempDirMixin, TorchaudioTestCase, @@ -14,7 +14,7 @@ ) -@skipIfNoSoxBackend +@skipIfNoSox @skipIfNoExec('sox') class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): def run_sox_effect(self, input_file, effect): diff --git a/test/torchaudio_unittest/transforms/sox_compatibility_test.py b/test/torchaudio_unittest/transforms/sox_compatibility_test.py index 4bf6bf7884..81582c8393 100644 --- a/test/torchaudio_unittest/transforms/sox_compatibility_test.py +++ b/test/torchaudio_unittest/transforms/sox_compatibility_test.py @@ -2,17 +2,19 @@ from parameterized import parameterized from torchaudio_unittest.common_utils import ( - skipIfNoSoxBackend, + skipIfNoSox, skipIfNoExec, TempDirMixin, TorchaudioTestCase, get_asset_path, sox_utils, load_wav, + save_wav, + get_whitenoise, ) -@skipIfNoSoxBackend +@skipIfNoSox @skipIfNoExec('sox') class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): def run_sox_effect(self, input_file, effect): @@ -24,6 +26,14 @@ def assert_sox_effect(self, result, input_path, effects, atol=1e-04, rtol=1e-5): expected, _ = self.run_sox_effect(input_path, effects) self.assertEqual(result, expected, atol=atol, rtol=rtol) + def get_whitenoise(self, sample_rate=8000): + noise = get_whitenoise( + sample_rate=sample_rate, duration=3, scale_factor=0.9, + ) + path = self.get_temp_path("whitenoise.wav") + save_wav(path, noise, sample_rate) + return noise, path + @parameterized.expand([ ('q', 'quarter_sine'), ('h', 'half_sine'), diff --git a/torchaudio/compliance/kaldi.py b/torchaudio/compliance/kaldi.py index f1e1ddf5b5..a5c3354c9b 100644 --- a/torchaudio/compliance/kaldi.py +++ b/torchaudio/compliance/kaldi.py @@ -3,7 +3,6 @@ import math import torch from torch import Tensor -from torch.nn import functional as F import torchaudio import torchaudio._internal.fft @@ -753,71 +752,16 @@ def mfcc( return feature -def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_width: int, - device: torch.device, dtype: torch.dtype): - assert lowpass_filter_width > 0 - kernels = [] - base_freq = min(orig_freq, new_freq) - # This will perform antialiasing filtering by removing the highest frequencies. - # At first I thought I only needed this when downsampling, but when upsampling - # you will get edge artifacts without this, as the edge is equivalent to zero padding, - # which will add high freq artifacts. - base_freq *= 0.99 - - # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor) - # using the sinc interpolation formula: - # x(t) = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - t)) - # We can then sample the function x(t) with a different sample rate: - # y[j] = x(j / new_freq) - # or, - # y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq)) - - # We see here that y[j] is the convolution of x[i] with a specific filter, for which - # we take an FIR approximation, stopping when we see at least `lowpass_filter_width` zeros crossing. - # But y[j+1] is going to have a different set of weights and so on, until y[j + new_freq]. - # Indeed: - # y[j + new_freq] = sum_i x[i] sinc(pi * orig_freq * ((i / orig_freq - (j + new_freq) / new_freq)) - # = sum_i x[i] sinc(pi * orig_freq * ((i - orig_freq) / orig_freq - j / new_freq)) - # = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq)) - # so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`. - # This will explain the F.conv1d after, with a stride of orig_freq. - width = math.ceil(lowpass_filter_width * orig_freq / base_freq) - # If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e., - # they will have a lot of almost zero values to the left or to the right... - # There is probably a way to evaluate those filters more efficiently, but this is kept for - # future work. - idx = torch.arange(-width, width + orig_freq, device=device, dtype=dtype) - - for i in range(new_freq): - t = (-i / new_freq + idx / orig_freq) * base_freq - t = t.clamp_(-lowpass_filter_width, lowpass_filter_width) - t *= math.pi - # we do not use torch.hann_window here as we need to evaluate the window - # at specific positions, not over a regular grid. - window = torch.cos(t / lowpass_filter_width / 2)**2 - kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t) - kernel.mul_(window) - kernels.append(kernel) - - scale = base_freq / orig_freq - return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width - - def resample_waveform(waveform: Tensor, orig_freq: float, new_freq: float, lowpass_filter_width: int = 6) -> Tensor: - r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform - which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample - a signal). LinearResample (LR) means that the output signal is at linearly spaced intervals (i.e - the output signal has a frequency of ``new_freq``). It uses sinc/bandlimited interpolation to - upsample/downsample the signal. + r"""Resamples the waveform at the new frequency. - https://ccrma.stanford.edu/~jos/resample/Theory_Ideal_Bandlimited_Interpolation.html - https://github.com/kaldi-asr/kaldi/blob/master/src/feat/resample.h#L56 + This is a wrapper around ``torchaudio.functional.resample``. Args: - waveform (Tensor): The input signal of size (c, n) + waveform (Tensor): The input signal of size (..., time) orig_freq (float): The original frequency of the signal new_freq (float): The desired frequency lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper @@ -826,21 +770,4 @@ def resample_waveform(waveform: Tensor, Returns: Tensor: The waveform at the new frequency """ - assert waveform.dim() == 2 - assert orig_freq > 0.0 and new_freq > 0.0 - - orig_freq = int(orig_freq) - new_freq = int(new_freq) - gcd = math.gcd(orig_freq, new_freq) - orig_freq = orig_freq // gcd - new_freq = new_freq // gcd - - kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width, - waveform.device, waveform.dtype) - - num_wavs, length = waveform.shape - waveform = F.pad(waveform, (width, width + orig_freq)) - resampled = F.conv1d(waveform[:, None], kernel, stride=orig_freq) - resampled = resampled.transpose(1, 2).reshape(num_wavs, -1) - target_length = int(math.ceil(new_freq * length / orig_freq)) - return resampled[..., :target_length] + return torchaudio.functional.resample(waveform, orig_freq, new_freq, lowpass_filter_width) diff --git a/torchaudio/datasets/yesno.py b/torchaudio/datasets/yesno.py index 6d0d2c0a5f..10b38a1260 100644 --- a/torchaudio/datasets/yesno.py +++ b/torchaudio/datasets/yesno.py @@ -16,7 +16,7 @@ "release1": { "folder_in_archive": "waves_yesno", "url": "http://www.openslr.org/resources/1/waves_yesno.tar.gz", - "checksum": "30301975fd8c5cac4040c261c0852f57cfa8adbbad2ce78e77e4986957445f27", + "checksum": "c3f49e0cca421f96b75b41640749167b52118f232498667ca7a5f9416aef8e73", } } @@ -54,7 +54,7 @@ def _parse_filesystem(self, root: str, url: str, folder_in_archive: str, downloa if not os.path.isdir(self._path): if not os.path.isfile(archive): checksum = _RELEASE_CONFIGS["release1"]["checksum"] - download_url(url, root, hash_value=checksum, hash_type="md5") + download_url(url, root, hash_value=checksum) extract_archive(archive) if not os.path.isdir(self._path): diff --git a/torchaudio/functional/__init__.py b/torchaudio/functional/__init__.py index 88fc4d545f..39f2fb8f53 100644 --- a/torchaudio/functional/__init__.py +++ b/torchaudio/functional/__init__.py @@ -19,6 +19,7 @@ spectrogram, spectral_centroid, apply_codec, + resample, ) from .filtering import ( allpass_biquad, @@ -85,5 +86,6 @@ 'riaa_biquad', 'treble_biquad', 'vad', - 'apply_codec' + 'apply_codec', + 'resample', ] diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index b114b9aacf..73496bf4cc 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -33,6 +33,7 @@ 'sliding_window_cmn', "spectral_centroid", "apply_codec", + "resample", ] @@ -1209,3 +1210,105 @@ def compute_kaldi_pitch( ) result = result.reshape(shape[:-1] + result.shape[-2:]) return result + + +def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_width: int, + device: torch.device, dtype: torch.dtype): + assert lowpass_filter_width > 0 + kernels = [] + base_freq = min(orig_freq, new_freq) + # This will perform antialiasing filtering by removing the highest frequencies. + # At first I thought I only needed this when downsampling, but when upsampling + # you will get edge artifacts without this, as the edge is equivalent to zero padding, + # which will add high freq artifacts. + base_freq *= 0.99 + + # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor) + # using the sinc interpolation formula: + # x(t) = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - t)) + # We can then sample the function x(t) with a different sample rate: + # y[j] = x(j / new_freq) + # or, + # y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq)) + + # We see here that y[j] is the convolution of x[i] with a specific filter, for which + # we take an FIR approximation, stopping when we see at least `lowpass_filter_width` zeros crossing. + # But y[j+1] is going to have a different set of weights and so on, until y[j + new_freq]. + # Indeed: + # y[j + new_freq] = sum_i x[i] sinc(pi * orig_freq * ((i / orig_freq - (j + new_freq) / new_freq)) + # = sum_i x[i] sinc(pi * orig_freq * ((i - orig_freq) / orig_freq - j / new_freq)) + # = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq)) + # so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`. + # This will explain the F.conv1d after, with a stride of orig_freq. + width = math.ceil(lowpass_filter_width * orig_freq / base_freq) + # If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e., + # they will have a lot of almost zero values to the left or to the right... + # There is probably a way to evaluate those filters more efficiently, but this is kept for + # future work. + idx = torch.arange(-width, width + orig_freq, device=device, dtype=dtype) + + for i in range(new_freq): + t = (-i / new_freq + idx / orig_freq) * base_freq + t = t.clamp_(-lowpass_filter_width, lowpass_filter_width) + t *= math.pi + # we do not use torch.hann_window here as we need to evaluate the window + # at specific positions, not over a regular grid. + window = torch.cos(t / lowpass_filter_width / 2)**2 + kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t) + kernel.mul_(window) + kernels.append(kernel) + + scale = base_freq / orig_freq + return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width + + +def resample( + waveform: Tensor, + orig_freq: float, + new_freq: float, + lowpass_filter_width: int = 6 +) -> Tensor: + r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform + which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample + a signal). LinearResample (LR) means that the output signal is at linearly spaced intervals (i.e + the output signal has a frequency of ``new_freq``). It uses sinc/bandlimited interpolation to + upsample/downsample the signal. + + https://ccrma.stanford.edu/~jos/resample/Theory_Ideal_Bandlimited_Interpolation.html + https://github.com/kaldi-asr/kaldi/blob/master/src/feat/resample.h#L56 + + Args: + waveform (Tensor): The input signal of dimension (..., time) + orig_freq (float): The original frequency of the signal + new_freq (float): The desired frequency + lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper + but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``) + + Returns: + Tensor: The waveform at the new frequency of dimension (..., time). + """ + # pack batch + shape = waveform.size() + waveform = waveform.view(-1, shape[-1]) + + assert orig_freq > 0.0 and new_freq > 0.0 + + orig_freq = int(orig_freq) + new_freq = int(new_freq) + gcd = math.gcd(orig_freq, new_freq) + orig_freq = orig_freq // gcd + new_freq = new_freq // gcd + + kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width, + waveform.device, waveform.dtype) + + num_wavs, length = waveform.shape + waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq)) + resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq) + resampled = resampled.transpose(1, 2).reshape(num_wavs, -1) + target_length = int(math.ceil(new_freq * length / orig_freq)) + resampled = resampled[..., :target_length] + + # unpack batch + resampled = resampled.view(shape[:-1] + resampled.shape[-1:]) + return resampled diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index d57b3eb35c..6f3ea1b6e1 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -6,7 +6,6 @@ import torch from torch import Tensor from torchaudio import functional as F -from torchaudio.compliance import kaldi __all__ = [ @@ -649,17 +648,7 @@ def forward(self, waveform: Tensor) -> Tensor: Tensor: Output signal of dimension (..., time). """ if self.resampling_method == 'sinc_interpolation': - - # pack batch - shape = waveform.size() - waveform = waveform.view(-1, shape[-1]) - - waveform = kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq) - - # unpack batch - waveform = waveform.view(shape[:-1] + waveform.shape[-1:]) - - return waveform + return F.resample(waveform, self.orig_freq, self.new_freq) raise ValueError('Invalid resampling method: {}'.format(self.resampling_method))