Skip to content

Commit

Permalink
Add autograd test for T.GriffinLim (#1421)
Browse files Browse the repository at this point in the history
  • Loading branch information
yoyolicoris authored Apr 6, 2021
1 parent b388d48 commit 6929331
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 2 deletions.
9 changes: 9 additions & 0 deletions test/torchaudio_unittest/transforms/autograd_test_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ def test_melspectrogram(self):
waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2)
self.assert_grad(transform, [waveform], nondet_tol=1e-10)

@parameterized.expand([(0, ), (0.99, )])
def test_griffinlim(self, momentum):
n_fft = 400
n_frames = 5
n_iter = 3
spec = torch.rand(n_fft // 2 + 1, n_frames) * n_fft
transform = T.GriffinLim(n_fft=n_fft, n_iter=n_iter, momentum=momentum, rand_init=False)
self.assert_grad(transform, [spec], nondet_tol=1e-10)

@parameterized.expand([(False, ), (True, )])
def test_mfcc(self, log_mels):
sample_rate = 8000
Expand Down
2 changes: 1 addition & 1 deletion torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def griffinlim(
hop_length=hop_length,
win_length=win_length,
window=window,
length=length).float()
length=length)

# Rebuild the spectrogram
rebuilt = torch.view_as_real(
Expand Down
2 changes: 1 addition & 1 deletion torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def __init__(self,
super(GriffinLim, self).__init__()

assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum)
assert momentum > 0, 'momentum={} < 0'.format(momentum)
assert momentum >= 0, 'momentum={} < 0'.format(momentum)

self.n_fft = n_fft
self.n_iter = n_iter
Expand Down

0 comments on commit 6929331

Please sign in to comment.