diff --git a/test/torchaudio_unittest/transforms/autograd_test_impl.py b/test/torchaudio_unittest/transforms/autograd_test_impl.py index b8fe6c1863..717a7bc87b 100644 --- a/test/torchaudio_unittest/transforms/autograd_test_impl.py +++ b/test/torchaudio_unittest/transforms/autograd_test_impl.py @@ -154,6 +154,23 @@ def test_vol(self, gain, gain_type): waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2) self.assert_grad(transform, [waveform]) + @parameterized.expand([ + ({'cmn_window': 100, 'min_cmn_window': 50, 'center': False, 'norm_vars': False}, ), + ({'cmn_window': 100, 'min_cmn_window': 50, 'center': True, 'norm_vars': False}, ), + ({'cmn_window': 100, 'min_cmn_window': 50, 'center': False, 'norm_vars': True}, ), + ({'cmn_window': 100, 'min_cmn_window': 50, 'center': True, 'norm_vars': True}, ), + ]) + def test_sliding_window_cmn(self, kwargs): + n_fft = 10 + power = 1 + spec = get_spectrogram( + get_whitenoise(sample_rate=200, duration=0.05, n_channels=2), + n_fft=n_fft, power=power) + spec_reshaped = spec.transpose(-1, -2) + + transform = T.SlidingWindowCmn(**kwargs) + self.assert_grad(transform, [spec_reshaped]) + @unittest.expectedFailure def test_timestretch_zeros_fail(self): """Test that ``T.TimeStretch`` fails gradcheck at 0