from typing import List
import unittest

from parameterized import parameterized
import torch
from torch.autograd import gradcheck, gradgradcheck
import torchaudio.transforms as T

from torchaudio_unittest.common_utils import (
    TestBaseMixin,
    get_whitenoise,
    get_spectrogram,
    nested_params,
)


class _DeterministicWrapper(torch.nn.Module):
    """Helper transform wrapper to make the given transform deterministic"""
    def __init__(self, transform, seed=0):
        super().__init__()
        self.seed = seed
        self.transform = transform

    def forward(self, input: torch.Tensor):
        torch.random.manual_seed(self.seed)
        return self.transform(input)
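

# A minimal usage sketch (illustrative only, not called by the test suite):
# gradcheck perturbs and re-evaluates a transform many times, so a transform
# that draws random numbers must be reseeded on every call to behave as a
# fixed function. `T.TimeMasking` is used here merely as an example of a
# random transform; any randomized transform would do.
def _example_deterministic_wrapper():
    masking = _DeterministicWrapper(T.TimeMasking(time_mask_param=10))
    spec = torch.rand(1, 201, 100)
    # The same seed is applied on each forward call, so the outputs match.
    assert torch.equal(masking(spec), masking(spec))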


class AutogradTestMixin(TestBaseMixin):
    def assert_grad(
            self,
            transform: torch.nn.Module,
            inputs: List[torch.Tensor],
            *,
            nondet_tol: float = 0.0,
    ):
        transform = transform.to(dtype=torch.float64, device=self.device)

        # With the default eps and tolerance values, gradcheck and
        # gradgradcheck only pass if the input tensors are of dtype
        # `torch.double` or `torch.cdouble`.
        inputs_ = []
        for i in inputs:
            if torch.is_tensor(i):
                i = i.to(
                    dtype=torch.cdouble if i.is_complex() else torch.double,
                    device=self.device)
                i.requires_grad = True
            inputs_.append(i)
        assert gradcheck(transform, inputs_)
        assert gradgradcheck(transform, inputs_, nondet_tol=nondet_tol)
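
    # A minimal sketch (illustrative only, not called by the tests) of why
    # inputs are promoted to double precision above: gradcheck compares
    # analytical gradients against central finite differences, and its
    # default eps and tolerances assume float64 accuracy.
    @staticmethod
    def _gradcheck_double_precision_sketch():
        x = torch.rand(3, dtype=torch.double, requires_grad=True)
        # Passes with the default settings because x is float64.
        assert gradcheck(torch.sin, (x,))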

    @parameterized.expand([
        ({'pad': 0, 'normalized': False, 'power': None, 'return_complex': True}, ),
        ({'pad': 3, 'normalized': False, 'power': None, 'return_complex': True}, ),
        ({'pad': 0, 'normalized': True, 'power': None, 'return_complex': True}, ),
        ({'pad': 3, 'normalized': True, 'power': None, 'return_complex': True}, ),
        ({'pad': 0, 'normalized': False, 'power': None}, ),
        ({'pad': 3, 'normalized': False, 'power': None}, ),
        ({'pad': 0, 'normalized': True, 'power': None}, ),
        ({'pad': 3, 'normalized': True, 'power': None}, ),
        ({'pad': 0, 'normalized': False, 'power': 1.0}, ),
        ({'pad': 3, 'normalized': False, 'power': 1.0}, ),
        ({'pad': 0, 'normalized': True, 'power': 1.0}, ),
        ({'pad': 3, 'normalized': True, 'power': 1.0}, ),
        ({'pad': 0, 'normalized': False, 'power': 2.0}, ),
        ({'pad': 3, 'normalized': False, 'power': 2.0}, ),
        ({'pad': 0, 'normalized': True, 'power': 2.0}, ),
        ({'pad': 3, 'normalized': True, 'power': 2.0}, ),
    ])
    def test_spectrogram(self, kwargs):
        # replication_pad1d_backward_cuda is not deterministic and
        # gives a very small (~2.7756e-17) difference across runs,
        # hence the nonzero `nondet_tol` below.
        #
        # See https://github.com/pytorch/pytorch/issues/54093
        transform = T.Spectrogram(**kwargs)
        waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
        self.assert_grad(transform, [waveform], nondet_tol=1e-10)

    def test_melspectrogram(self):
        # replication_pad1d_backward_cuda is not deterministic and
        # gives a very small (~2.7756e-17) difference across runs,
        # hence the nonzero `nondet_tol` below.
        #
        # See https://github.com/pytorch/pytorch/issues/54093
        sample_rate = 8000
        transform = T.MelSpectrogram(sample_rate=sample_rate)
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2)
        self.assert_grad(transform, [waveform], nondet_tol=1e-10)

    @nested_params(
        [0, 0.99],
        [False, True],
    )
    def test_griffinlim(self, momentum, rand_init):
        n_fft = 400
        power = 1
        n_iter = 3
        spec = get_spectrogram(
            get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2),
            n_fft=n_fft, power=power)
        transform = _DeterministicWrapper(
            T.GriffinLim(n_fft=n_fft, n_iter=n_iter, momentum=momentum, rand_init=rand_init, power=power))
        self.assert_grad(transform, [spec])

    @parameterized.expand([(False, ), (True, )])
    def test_mfcc(self, log_mels):
        sample_rate = 8000
        transform = T.MFCC(sample_rate=sample_rate, log_mels=log_mels)
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2)
        self.assert_grad(transform, [waveform])

    def test_compute_deltas(self):
        transform = T.ComputeDeltas()
        spec = torch.rand(10, 20)
        self.assert_grad(transform, [spec])

    @parameterized.expand([(8000, 8000), (8000, 4000), (4000, 8000)])
    def test_resample(self, orig_freq, new_freq):
        transform = T.Resample(orig_freq=orig_freq, new_freq=new_freq)
        waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
        self.assert_grad(transform, [waveform])
@parameterized.expand([("linear", ), ("exponential", ), ("logarithmic", ), ("quarter_sine", ), ("half_sine", )])
def test_fade(self, fade_shape):
transform = T.Fade(fade_shape=fade_shape)
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
self.assert_grad(transform, [waveform], nondet_tol=1e-10)

    def test_spectral_centroid(self):
        sample_rate = 8000
        transform = T.SpectralCentroid(sample_rate=sample_rate)
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2)
        self.assert_grad(transform, [waveform], nondet_tol=1e-10)

    def test_amplitude_to_db(self):
        sample_rate = 8000
        transform = T.AmplitudeToDB()
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2)
        self.assert_grad(transform, [waveform])

    def test_melscale(self):
        sample_rate = 8000
        n_fft = 400
        n_mels = n_fft // 2 + 1
        transform = T.MelScale(sample_rate=sample_rate, n_mels=n_mels)
        spec = get_spectrogram(
            get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2),
            n_fft=n_fft, power=1)
        self.assert_grad(transform, [spec])

    @parameterized.expand([(1.5, "amplitude"), (2, "power"), (10, "db")])
    def test_vol(self, gain, gain_type):
        sample_rate = 8000
        transform = T.Vol(gain=gain, gain_type=gain_type)
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2)
        self.assert_grad(transform, [waveform])

    @parameterized.expand([
        ({'cmn_window': 100, 'min_cmn_window': 50, 'center': False, 'norm_vars': False}, ),
        ({'cmn_window': 100, 'min_cmn_window': 50, 'center': True, 'norm_vars': False}, ),
        ({'cmn_window': 100, 'min_cmn_window': 50, 'center': False, 'norm_vars': True}, ),
        ({'cmn_window': 100, 'min_cmn_window': 50, 'center': True, 'norm_vars': True}, ),
    ])
    def test_sliding_window_cmn(self, kwargs):
        n_fft = 10
        power = 1
        spec = get_spectrogram(
            get_whitenoise(sample_rate=200, duration=0.05, n_channels=2),
            n_fft=n_fft, power=power)
        spec_reshaped = spec.transpose(-1, -2)
        transform = T.SlidingWindowCmn(**kwargs)
        self.assert_grad(transform, [spec_reshaped])

    @unittest.expectedFailure
    def test_timestretch_zeros_fail(self):
        """Test that ``T.TimeStretch`` fails gradcheck at 0

        This is because ``F.phase_vocoder`` converts data from Cartesian to
        polar coordinates, which involves ``atan2(imag, real)``, whose
        gradient is not defined at 0.
        """
        n_fft = 16
        transform = T.TimeStretch(n_freq=n_fft // 2 + 1, fixed_rate=0.99)
        waveform = torch.zeros(2, 40)
        spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None)
        self.assert_grad(transform, [spectrogram])
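
    # A minimal sketch (illustrative only, not called by the tests) of the
    # failure mode above: both partial derivatives of atan2(y, x) are ratios
    # with denominator x**2 + y**2, so at the origin autograd evaluates
    # 0 / 0 and yields NaN.
    @staticmethod
    def _atan2_origin_sketch():
        x = torch.zeros((), dtype=torch.double, requires_grad=True)
        y = torch.zeros((), dtype=torch.double, requires_grad=True)
        torch.atan2(y, x).backward()
        # Both gradients are NaN at (0, 0), so gradcheck cannot pass.
        assert torch.isnan(x.grad).all() and torch.isnan(y.grad).all()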

    @nested_params(
        [0.7, 0.8, 0.9, 1.0, 1.3],
        [False, True],
    )
    def test_timestretch_non_zero(self, rate, test_pseudo_complex):
        """Verify that ``T.TimeStretch`` passes gradcheck when the input is not close to 0

        ``T.TimeStretch`` is not differentiable around 0, so this test checks
        differentiability for inputs that are away from zero.

        As tested above, when the spectrogram contains values close to zero,
        the gradients are unstable and gradcheck fails.

        In this test, we generate a spectrogram from a random signal, then
        push the points around zero away from the origin.

        This process does not reflect real use cases and is not practical for
        users, but it helps us understand to what degree the function is
        differentiable and where it is not.
        """
        n_fft = 16
        transform = T.TimeStretch(n_freq=n_fft // 2 + 1, fixed_rate=rate)
        waveform = get_whitenoise(sample_rate=40, duration=1, n_channels=2)
        spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None)

        # epsilon = 1e-3 is too small for gradcheck to pass (on CPU).
        epsilon = 1e-2
        too_close = spectrogram.abs() < epsilon
        # Floor the magnitude of near-zero bins at `epsilon` while keeping their phase.
        spectrogram[too_close] = epsilon * spectrogram[too_close] / spectrogram[too_close].abs()
        if test_pseudo_complex:
            spectrogram = torch.view_as_real(spectrogram)
        self.assert_grad(transform, [spectrogram])
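
    # A minimal sketch (illustrative only, not called by the tests) of the
    # magnitude flooring used above: `epsilon * z / z.abs()` keeps each bin's
    # phase and raises its magnitude to exactly `epsilon`, moving it out of
    # the undefined-gradient region probed in test_timestretch_zeros_fail.
    @staticmethod
    def _magnitude_floor_sketch():
        epsilon = 1e-2
        z = torch.tensor([1e-4 + 1e-4j], dtype=torch.cdouble)
        z_floored = epsilon * z / z.abs()
        # Magnitude is now exactly epsilon; the phase (pi / 4 here) is unchanged.
        assert torch.allclose(z_floored.abs(), torch.tensor([epsilon], dtype=torch.double))
        assert torch.allclose(z_floored.angle(), z.angle())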