diff --git a/docs/source/functional.rst b/docs/source/functional.rst index 12ee165060..89e992bb89 100644 --- a/docs/source/functional.rst +++ b/docs/source/functional.rst @@ -8,11 +8,6 @@ torchaudio.functional Functions to perform common audio operations. -:hidden:`istft` -~~~~~~~~~~~~~~~ - -.. autofunction:: istft - :hidden:`spectrogram` ~~~~~~~~~~~~~~~~~~~~~ diff --git a/test/functional_cpu_test.py b/test/functional_cpu_test.py index b162df2ebc..7fa2520b24 100644 --- a/test/functional_cpu_test.py +++ b/test/functional_cpu_test.py @@ -11,31 +11,6 @@ from .functional_impl import Lfilter -def random_float_tensor(seed, size, a=22695477, c=1, m=2 ** 32): - """ Generates random tensors given a seed and size - https://en.wikipedia.org/wiki/Linear_congruential_generator - X_{n + 1} = (a * X_n + c) % m - Using Borland C/C++ values - - The tensor will have values between [0,1) - Inputs: - seed (int): an int - size (Tuple[int]): the size of the output tensor - a (int): the multiplier constant to the generator - c (int): the additive constant to the generator - m (int): the modulus constant to the generator - """ - num_elements = 1 - for s in size: - num_elements *= s - - arr = [(a * seed + c) % m] - for i in range(num_elements - 1): - arr.append((a * arr[i] + c) % m) - - return torch.tensor(arr).float().view(size) / m - - class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase): dtype = torch.float32 device = torch.device('cpu') @@ -63,242 +38,6 @@ def test_two_channels(self): torch.testing.assert_allclose(computed, expected) -def _compare_estimate(sound, estimate, atol=1e-6, rtol=1e-8): - # trim sound for case when constructed signal is shorter than original - sound = sound[..., :estimate.size(-1)] - torch.testing.assert_allclose(estimate, sound, atol=atol, rtol=rtol) - - -def _test_istft_is_inverse_of_stft(kwargs): - # generates a random sound signal for each tril and then does the stft/istft - # operation to check whether we can reconstruct signal - for data_size in [(2, 20), (3, 15), (4, 10)]: - for i in range(100): - - sound = random_float_tensor(i, data_size) - - stft = torch.stft(sound, **kwargs) - estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs) - - _compare_estimate(sound, estimate) - - -class TestIstft(common_utils.TorchaudioTestCase): - """Test suite for correctness of istft with various input""" - number_of_trials = 100 - - def test_istft_is_inverse_of_stft1(self): - # hann_window, centered, normalized, onesided - kwargs1 = { - 'n_fft': 12, - 'hop_length': 4, - 'win_length': 12, - 'window': torch.hann_window(12), - 'center': True, - 'pad_mode': 'reflect', - 'normalized': True, - 'onesided': True, - } - _test_istft_is_inverse_of_stft(kwargs1) - - def test_istft_is_inverse_of_stft2(self): - # hann_window, centered, not normalized, not onesided - kwargs2 = { - 'n_fft': 12, - 'hop_length': 2, - 'win_length': 8, - 'window': torch.hann_window(8), - 'center': True, - 'pad_mode': 'reflect', - 'normalized': False, - 'onesided': False, - } - _test_istft_is_inverse_of_stft(kwargs2) - - def test_istft_is_inverse_of_stft3(self): - # hamming_window, centered, normalized, not onesided - kwargs3 = { - 'n_fft': 15, - 'hop_length': 3, - 'win_length': 11, - 'window': torch.hamming_window(11), - 'center': True, - 'pad_mode': 'constant', - 'normalized': True, - 'onesided': False, - } - _test_istft_is_inverse_of_stft(kwargs3) - - def test_istft_is_inverse_of_stft4(self): - # hamming_window, not centered, not normalized, onesided - # window same size as n_fft - kwargs4 = { - 'n_fft': 5, - 'hop_length': 2, - 'win_length': 5, - 'window': torch.hamming_window(5), - 'center': False, - 'pad_mode': 'constant', - 'normalized': False, - 'onesided': True, - } - _test_istft_is_inverse_of_stft(kwargs4) - - def test_istft_is_inverse_of_stft5(self): - # hamming_window, not centered, not normalized, not onesided - # window same size as n_fft - kwargs5 = { - 'n_fft': 3, - 'hop_length': 2, - 'win_length': 3, - 'window': torch.hamming_window(3), - 'center': False, - 'pad_mode': 'reflect', - 'normalized': False, - 'onesided': False, - } - _test_istft_is_inverse_of_stft(kwargs5) - - def test_istft_of_ones(self): - # stft = torch.stft(torch.ones(4), 4) - stft = torch.tensor([ - [[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]], - [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]], - [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]] - ]) - - estimate = torchaudio.functional.istft(stft, n_fft=4, length=4) - _compare_estimate(torch.ones(4), estimate) - - def test_istft_of_zeros(self): - # stft = torch.stft(torch.zeros(4), 4) - stft = torch.zeros((3, 5, 2)) - - estimate = torchaudio.functional.istft(stft, n_fft=4, length=4) - _compare_estimate(torch.zeros(4), estimate) - - def test_istft_requires_overlap_windows(self): - # the window is size 1 but it hops 20 so there is a gap which throw an error - stft = torch.zeros((3, 5, 2)) - self.assertRaises(RuntimeError, torchaudio.functional.istft, stft, n_fft=4, - hop_length=20, win_length=1, window=torch.ones(1)) - - def test_istft_requires_nola(self): - stft = torch.zeros((3, 5, 2)) - kwargs_ok = { - 'n_fft': 4, - 'win_length': 4, - 'window': torch.ones(4), - } - - kwargs_not_ok = { - 'n_fft': 4, - 'win_length': 4, - 'window': torch.zeros(4), - } - - # A window of ones meets NOLA but a window of zeros does not. This should - # throw an error. - torchaudio.functional.istft(stft, **kwargs_ok) - self.assertRaises(RuntimeError, torchaudio.functional.istft, stft, **kwargs_not_ok) - - def test_istft_requires_non_empty(self): - self.assertRaises(RuntimeError, torchaudio.functional.istft, torch.zeros((3, 0, 2)), 2) - self.assertRaises(RuntimeError, torchaudio.functional.istft, torch.zeros((0, 3, 2)), 2) - - def _test_istft_of_sine(self, amplitude, L, n): - # stft of amplitude*sin(2*pi/L*n*x) with the hop length and window size equaling L - x = torch.arange(2 * L + 1, dtype=torch.get_default_dtype()) - sound = amplitude * torch.sin(2 * math.pi / L * x * n) - # stft = torch.stft(sound, L, hop_length=L, win_length=L, - # window=torch.ones(L), center=False, normalized=False) - stft = torch.zeros((L // 2 + 1, 2, 2)) - stft_largest_val = (amplitude * L) / 2.0 - if n < stft.size(0): - stft[n, :, 1] = -stft_largest_val - - if 0 <= L - n < stft.size(0): - # symmetric about L // 2 - stft[L - n, :, 1] = stft_largest_val - - estimate = torchaudio.functional.istft(stft, L, hop_length=L, win_length=L, - window=torch.ones(L), center=False, normalized=False) - # There is a larger error due to the scaling of amplitude - _compare_estimate(sound, estimate, atol=1e-3) - - def test_istft_of_sine(self): - self._test_istft_of_sine(amplitude=123, L=5, n=1) - self._test_istft_of_sine(amplitude=150, L=5, n=2) - self._test_istft_of_sine(amplitude=111, L=5, n=3) - self._test_istft_of_sine(amplitude=160, L=7, n=4) - self._test_istft_of_sine(amplitude=145, L=8, n=5) - self._test_istft_of_sine(amplitude=80, L=9, n=6) - self._test_istft_of_sine(amplitude=99, L=10, n=7) - - def _test_linearity_of_istft(self, data_size, kwargs, atol=1e-6, rtol=1e-8): - for i in range(self.number_of_trials): - tensor1 = random_float_tensor(i, data_size) - tensor2 = random_float_tensor(i * 2, data_size) - a, b = torch.rand(2) - istft1 = torchaudio.functional.istft(tensor1, **kwargs) - istft2 = torchaudio.functional.istft(tensor2, **kwargs) - istft = a * istft1 + b * istft2 - estimate = torchaudio.functional.istft(a * tensor1 + b * tensor2, **kwargs) - _compare_estimate(istft, estimate, atol, rtol) - - def test_linearity_of_istft1(self): - # hann_window, centered, normalized, onesided - kwargs1 = { - 'n_fft': 12, - 'window': torch.hann_window(12), - 'center': True, - 'pad_mode': 'reflect', - 'normalized': True, - 'onesided': True, - } - data_size = (2, 7, 7, 2) - self._test_linearity_of_istft(data_size, kwargs1) - - def test_linearity_of_istft2(self): - # hann_window, centered, not normalized, not onesided - kwargs2 = { - 'n_fft': 12, - 'window': torch.hann_window(12), - 'center': True, - 'pad_mode': 'reflect', - 'normalized': False, - 'onesided': False, - } - data_size = (2, 12, 7, 2) - self._test_linearity_of_istft(data_size, kwargs2) - - def test_linearity_of_istft3(self): - # hamming_window, centered, normalized, not onesided - kwargs3 = { - 'n_fft': 12, - 'window': torch.hamming_window(12), - 'center': True, - 'pad_mode': 'constant', - 'normalized': True, - 'onesided': False, - } - data_size = (2, 12, 7, 2) - self._test_linearity_of_istft(data_size, kwargs3) - - def test_linearity_of_istft4(self): - # hamming_window, not centered, not normalized, onesided - kwargs4 = { - 'n_fft': 12, - 'window': torch.hamming_window(12), - 'center': False, - 'pad_mode': 'constant', - 'normalized': False, - 'onesided': True, - } - data_size = (2, 7, 3, 2) - self._test_linearity_of_istft(data_size, kwargs4, atol=1e-5, rtol=1e-8) - - class TestDetectPitchFrequency(common_utils.TorchaudioTestCase): @parameterized.expand([(100,), (440,)]) def test_pitch(self, frequency): diff --git a/test/test_batch_consistency.py b/test/test_batch_consistency.py index e7040fdc0b..6e89bbb235 100644 --- a/test/test_batch_consistency.py +++ b/test/test_batch_consistency.py @@ -59,14 +59,6 @@ def test_detect_pitch_frequency(self, frequency, sample_rate, n_channels): n_channels=n_channels, duration=5) self.assert_batch_consistencies(F.detect_pitch_frequency, waveform, sample_rate) - def test_istft(self): - stft = torch.tensor([ - [[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]], - [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]], - [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]] - ]) - self.assert_batch_consistencies(F.istft, stft, n_fft=4, length=4) - def test_contrast(self): waveform = torch.rand(2, 100) - 0.5 self.assert_batch_consistencies(F.contrast, waveform, enhancement_amount=80.) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 713544cc74..3b70c8efe2 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -8,7 +8,6 @@ from torch import Tensor __all__ = [ - "istft", "spectrogram", "griffinlim", "amplitude_to_DB", @@ -45,79 +44,6 @@ ] -def istft( - stft_matrix: Tensor, - n_fft: int, - hop_length: Optional[int] = None, - win_length: Optional[int] = None, - window: Optional[Tensor] = None, - center: bool = True, - pad_mode: Optional[str] = None, - normalized: bool = False, - onesided: bool = True, - length: Optional[int] = None, -) -> Tensor: - r"""Inverse short time Fourier Transform. This is expected to be the inverse of torch.stft. - It has the same parameters (+ additional optional parameter of ``length``) and it should return the - least squares estimation of the original signal. The algorithm will check using the NOLA condition ( - nonzero overlap). - - Important consideration in the parameters ``window`` and ``center`` so that the envelop - created by the summation of all the windows is never zero at certain point in time. Specifically, - :math:`\sum_{t=-\infty}^{\infty} w^2[n-t\times hop\_length] \cancel{=} 0`. - - Since stft discards elements at the end of the signal if they do not fit in a frame, the - istft may return a shorter signal than the original signal (can occur if ``center`` is False - since the signal isn't padded). - - If ``center`` is True, then there will be padding e.g. 'constant', 'reflect', etc. Left padding - can be trimmed off exactly because they can be calculated but right padding cannot be calculated - without additional information. - - Example: Suppose the last window is: - [17, 18, 0, 0, 0] vs [18, 0, 0, 0, 0] - - The n_frame, hop_length, win_length are all the same which prevents the calculation of right padding. - These additional values could be zeros or a reflection of the signal so providing ``length`` - could be useful. If ``length`` is ``None`` then padding will be aggressively removed - (some loss of signal). - - [1] D. W. Griffin and J. S. Lim, "Signal estimation from modified short-time Fourier transform," - IEEE Trans. ASSP, vol.32, no.2, pp.236-243, Apr. 1984. - - Args: - stft_matrix (Tensor): Output of stft where each row of a channel is a frequency and each - column is a window. It has a size of either (..., fft_size, n_frame, 2) - n_fft (int): Size of Fourier transform - hop_length (int or None, optional): The distance between neighboring sliding window frames. - (Default: ``win_length // 4``) - win_length (int or None, optional): The size of window frame and STFT filter. (Default: ``n_fft``) - window (Tensor or None, optional): The optional window function. - (Default: ``torch.ones(win_length)``) - center (bool, optional): Whether ``input`` was padded on both sides so - that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. - (Default: ``True``) - pad_mode: This argument was ignored and to be removed. - normalized (bool, optional): Whether the STFT was normalized. (Default: ``False``) - onesided (bool, optional): Whether the STFT is onesided. (Default: ``True``) - length (int or None, optional): The amount to trim the signal by (i.e. the - original signal length). (Default: whole signal) - - Returns: - Tensor: Least squares estimation of the original signal of size (..., signal_length) - """ - warnings.warn( - 'istft has been moved to PyTorch and will be removed from torchaudio, ' - 'please use torch.istft instead.') - if pad_mode is not None: - warnings.warn( - 'The parameter `pad_mode` was ignored in isftft, and is thus being deprecated. ' - 'Please set `pad_mode` to None to suppress this warning.') - return torch.istft( - input=stft_matrix, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, - center=center, normalized=normalized, onesided=onesided, length=length) - - def spectrogram( waveform: Tensor, pad: int, @@ -250,12 +176,12 @@ def griffinlim( tprev = rebuilt # Invert with our current estimate of the phases - inverse = istft(specgram * angles, - n_fft=n_fft, - hop_length=hop_length, - win_length=win_length, - window=window, - length=length).float() + inverse = torch.istft(specgram * angles, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + length=length).float() # Rebuild the spectrogram rebuilt = torch.stft(inverse, n_fft, hop_length, win_length, window, @@ -268,12 +194,12 @@ def griffinlim( angles = angles.div(complex_norm(angles).add(1e-16).unsqueeze(-1).expand_as(angles)) # Return the final phase estimates - waveform = istft(specgram * angles, - n_fft=n_fft, - hop_length=hop_length, - win_length=win_length, - window=window, - length=length) + waveform = torch.istft(specgram * angles, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + length=length) # unpack batch waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:])