diff --git a/test/torchaudio_unittest/functional/functional_cpu_test.py b/test/torchaudio_unittest/functional/functional_cpu_test.py
index 4fbe4d871e..348d5bad60 100644
--- a/test/torchaudio_unittest/functional/functional_cpu_test.py
+++ b/test/torchaudio_unittest/functional/functional_cpu_test.py
@@ -1,5 +1,4 @@
 import math
-import unittest
 
 import torch
 import torchaudio
@@ -8,7 +7,7 @@
 import pytest
 
 from torchaudio_unittest import common_utils
-from .functional_impl import Lfilter
+from .functional_impl import Lfilter, Spectrogram
 
 
 class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
@@ -21,6 +20,16 @@ class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase):
     device = torch.device('cpu')
 
 
+class TestSpectrogramFloat32(Spectrogram, common_utils.PytorchTestCase):
+    dtype = torch.float32
+    device = torch.device('cpu')
+
+
+class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
+    dtype = torch.float64
+    device = torch.device('cpu')
+
+
 class TestCreateFBMatrix(common_utils.TorchaudioTestCase):
     def test_no_warning_high_n_freq(self):
         with pytest.warns(None) as w:
diff --git a/test/torchaudio_unittest/functional/functional_cuda_test.py b/test/torchaudio_unittest/functional/functional_cuda_test.py
index 80ac5fa0c2..c89795be01 100644
--- a/test/torchaudio_unittest/functional/functional_cuda_test.py
+++ b/test/torchaudio_unittest/functional/functional_cuda_test.py
@@ -1,7 +1,7 @@
 import torch
 
 from torchaudio_unittest import common_utils
-from .functional_impl import Lfilter
+from .functional_impl import Lfilter, Spectrogram
 
 
 @common_utils.skipIfNoCuda
@@ -14,3 +14,15 @@ class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
 class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase):
     dtype = torch.float64
     device = torch.device('cuda')
+
+
+@common_utils.skipIfNoCuda
+class TestSpectrogramFloat32(Spectrogram, common_utils.PytorchTestCase):
+    dtype = torch.float32
+    device = torch.device('cuda')
+
+
+@common_utils.skipIfNoCuda
+class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
+    dtype = torch.float64
+    device = torch.device('cuda')
diff --git a/test/torchaudio_unittest/functional/functional_impl.py b/test/torchaudio_unittest/functional/functional_impl.py
index c63001a68e..1ba5e21fbb 100644
--- a/test/torchaudio_unittest/functional/functional_impl.py
+++ b/test/torchaudio_unittest/functional/functional_impl.py
@@ -1,6 +1,7 @@
 """Test defintion common to CPU and CUDA"""
 import torch
 import torchaudio.functional as F
+from parameterized import parameterized
 
 from torchaudio_unittest import common_utils
 
@@ -29,3 +30,25 @@ def test_clamp(self):
         assert output_signal.max() <= 1
         output_signal = F.lfilter(input_signal, a_coeffs, b_coeffs, clamp=False)
         assert output_signal.max() > 1
+
+
+class Spectrogram(common_utils.TestBaseMixin):
+    @parameterized.expand([(0., ), (1., ), (2., ), (3., )])
+    def test_grad_at_zero(self, power):
+        """The gradient of the power spectrogram should be zero, not nan, near x=0
+
+        https://github.com/pytorch/audio/issues/993
+        """
+        x = torch.zeros(1, 22050, dtype=self.dtype, device=self.device, requires_grad=True)
+        spec = F.spectrogram(
+            x,
+            pad=0,
+            window=None,
+            n_fft=2048,
+            hop_length=None,
+            win_length=None,
+            power=power,
+            normalized=False,
+        )
+        spec.sum().backward()
+        assert not x.grad.isnan().sum()
diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py
index 3fedf93d09..0c07ec0847 100644
--- a/torchaudio/functional/functional.py
+++ b/torchaudio/functional/functional.py
@@ -70,30 +70,29 @@ def spectrogram(
     waveform = waveform.reshape(-1, shape[-1])
 
     # default values are consistent with librosa.core.spectrum._spectrogram
-    spec_f = torch.view_as_real(
-        torch.stft(
-            input=waveform,
-            n_fft=n_fft,
-            hop_length=hop_length,
-            win_length=win_length,
-            window=window,
-            center=True,
-            pad_mode="reflect",
-            normalized=False,
-            onesided=True,
-            return_complex=True,
-        )
+    spec_f = torch.stft(
+        input=waveform,
+        n_fft=n_fft,
+        hop_length=hop_length,
+        win_length=win_length,
+        window=window,
+        center=True,
+        pad_mode="reflect",
+        normalized=False,
+        onesided=True,
+        return_complex=True,
     )
 
     # unpack batch
-    spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-3:])
+    spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-2:])
 
     if normalized:
         spec_f /= window.pow(2.).sum().sqrt()
     if power is not None:
-        spec_f = complex_norm(spec_f, power=power)
-
-    return spec_f
+        if power == 1.0:
+            return spec_f.abs()
+        return spec_f.abs().pow(power)
+    return torch.view_as_real(spec_f)
 
 
 def griffinlim(
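
Note on why the functional.py change fixes the nan gradient from pytorch/audio#993: the removed complex_norm call reduces to (re ** 2 + im ** 2).pow(0.5 * power) on the real view, and for power=1 that is a bare square root, whose derivative diverges at the origin, so autograd produces 0 * inf = nan there. Complex Tensor.abs() instead backpropagates through torch.sgn, which is defined to be zero at zero. A minimal standalone sketch (not part of the patch, and assuming a PyTorch build with complex autograd support, i.e. 1.8+):

    import torch

    # Old path: the norm is computed from the (..., 2) real view.
    # sqrt backward yields 1 / (2 * sqrt(u)) -> inf as u -> 0, and the
    # chain rule multiplies it by d(re^2)/d(re) = 0, giving 0 * inf = nan.
    z = torch.zeros(3, 2, requires_grad=True)  # real view of three complex zeros
    z.pow(2).sum(-1).sqrt().sum().backward()
    print(z.grad)  # tensor of nan

    # New path: abs() on a complex tensor backpropagates through sgn(),
    # and torch.sgn(0) == 0, so the gradient at the origin is exactly zero.
    w = torch.zeros(3, dtype=torch.cfloat, requires_grad=True)
    w.abs().sum().backward()
    print(w.grad)  # tensor of zeros

The dedicated power == 1.0 branch is presumably just an optimization that skips a redundant pow(1.0) node in the autograd graph; the two paths agree numerically.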