Skip to content

Commit

Permalink
Add gradgradcheck and move tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Mar 3, 2021
1 parent 85dec8a commit 7454333
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 4 deletions.
Empty file.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from torchaudio_unittest.autograd_test_impl import AutogradTestCase
from .autograd_test_impl import AutogradTestCase


class AutogradCPUTest(AutogradTestCase, PytorchTestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
PytorchTestCase,
skipIfNoCuda,
)
from torchaudio_unittest.autograd_test_impl import AutogradTestCase
from .autograd_test_impl import AutogradTestCase


@skipIfNoCuda
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from parameterized import parameterized
from torch.autograd import gradcheck
from torch.autograd import gradcheck, gradgradcheck
import torchaudio.transforms as T

from torchaudio_unittest.common_utils import (
Expand All @@ -17,6 +17,7 @@ def assert_grad(self, transform, *inputs, eps=1e-06, atol=1e-05, rtol=0.001):
i.requires_grad = True
inputs_.append(i.to(dtype=self.dtype, device=self.device))
assert gradcheck(transform, inputs_, eps=eps, atol=atol, rtol=rtol)
assert gradgradcheck(transform, inputs_, eps=eps, atol=atol, rtol=rtol)

@parameterized.expand([
({'pad': 0, 'normalized': False, 'power': None}, ),
Expand All @@ -34,5 +35,5 @@ def assert_grad(self, transform, *inputs, eps=1e-06, atol=1e-05, rtol=0.001):
])
def test_spectrogram(self, kwargs):
transform = T.Spectrogram(**kwargs)
waveform = get_whitenoise(sample_rate=16000, duration=0.05, n_channels=2)
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
self.assert_grad(transform, waveform)

0 comments on commit 7454333

Please sign in to comment.