diff --git a/test/torchaudio_unittest/transforms/autograd_test_impl.py b/test/torchaudio_unittest/transforms/autograd_test_impl.py index 57e9cacf96..fc9201d291 100644 --- a/test/torchaudio_unittest/transforms/autograd_test_impl.py +++ b/test/torchaudio_unittest/transforms/autograd_test_impl.py @@ -8,9 +8,20 @@ from torchaudio_unittest.common_utils import ( TestBaseMixin, get_whitenoise, + nested_params, ) +# TODO: +# - replace T.Spectrogram +# - generalize it +# - move to common_utils +def get_spectrogram(return_complex): + spectrogram = T.Spectrogram(return_complex=return_complex, power=None) + waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2) + return spectrogram(waveform) + + class AutogradTestMixin(TestBaseMixin): def assert_grad( self, @@ -23,8 +34,12 @@ def assert_grad( inputs_ = [] for i in inputs: - i.requires_grad = True - inputs_.append(i.to(dtype=torch.float64, device=self.device)) + if torch.is_tensor(i): + i = i.to( + dtype=torch.cdouble if i.is_complex() else torch.double, + device=self.device) + i.requires_grad = True + inputs_.append(i) assert gradcheck(transform, inputs_) assert gradgradcheck(transform, inputs_, nondet_tol=nondet_tol) @@ -88,3 +103,21 @@ def test_fade(self, fade_shape): transform = T.Fade(fade_shape=fade_shape) waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2) self.assert_grad(transform, [waveform], nondet_tol=1e-10) + + @nested_params( + [0.7, 0.8, 0.9, 1.0, 1.3], + [True, False], + ) + def test_timestretch(self, rate, test_complex): + transform = T.TimeStretch(fixed_rate=rate) + spectrogram = get_spectrogram(return_complex=test_complex) + self.assert_grad(transform, [spectrogram]) + + @nested_params( + [0.7, 0.8, 0.9, 1.0, 1.3], + [True, False], + ) + def test_timestretch_override(self, rate, test_complex): + transform = T.TimeStretch() + spectrogram = get_spectrogram(return_complex=test_complex) + self.assert_grad(transform, [spectrogram, rate])