From b67b2c3dc034018889daaef2ffc3b1e9f1e9e52b Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Wed, 3 Mar 2021 20:56:36 +0000 Subject: [PATCH] Address most of feedbacks --- .../transforms/autograd_cpu_test.py | 6 ++---- .../transforms/autograd_cuda_test.py | 6 ++---- .../transforms/autograd_test_impl.py | 13 +++++++------ 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/test/torchaudio_unittest/transforms/autograd_cpu_test.py b/test/torchaudio_unittest/transforms/autograd_cpu_test.py index 66905cddc0d..e4abfb237d1 100644 --- a/test/torchaudio_unittest/transforms/autograd_cpu_test.py +++ b/test/torchaudio_unittest/transforms/autograd_cpu_test.py @@ -1,8 +1,6 @@ -import torch from torchaudio_unittest.common_utils import PytorchTestCase -from .autograd_test_impl import AutogradTestCase +from .autograd_test_impl import AutogradTestMixin -class AutogradCPUTest(AutogradTestCase, PytorchTestCase): +class AutogradCPUTest(AutogradTestMixin, PytorchTestCase): device = 'cpu' - dtype = torch.float64 diff --git a/test/torchaudio_unittest/transforms/autograd_cuda_test.py b/test/torchaudio_unittest/transforms/autograd_cuda_test.py index 458f728288c..ecccb9d897c 100644 --- a/test/torchaudio_unittest/transforms/autograd_cuda_test.py +++ b/test/torchaudio_unittest/transforms/autograd_cuda_test.py @@ -1,12 +1,10 @@ -import torch from torchaudio_unittest.common_utils import ( PytorchTestCase, skipIfNoCuda, ) -from .autograd_test_impl import AutogradTestCase +from .autograd_test_impl import AutogradTestMixin @skipIfNoCuda -class AutogradCUDATest(AutogradTestCase, PytorchTestCase): +class AutogradCUDATest(AutogradTestMixin, PytorchTestCase): device = 'cuda' - dtype = torch.float64 diff --git a/test/torchaudio_unittest/transforms/autograd_test_impl.py b/test/torchaudio_unittest/transforms/autograd_test_impl.py index f4913ca3128..7a1a7b6f84d 100644 --- a/test/torchaudio_unittest/transforms/autograd_test_impl.py +++ b/test/torchaudio_unittest/transforms/autograd_test_impl.py @@ -1,4 +1,5 @@ from parameterized import parameterized +import torch from torch.autograd import gradcheck, gradgradcheck import torchaudio.transforms as T @@ -8,16 +9,16 @@ ) -class AutogradTestCase(TestBaseMixin): - def assert_grad(self, transform, *inputs, eps=1e-06, atol=1e-05, rtol=0.001): - transform = transform.to(self.device, self.dtype) +class AutogradTestMixin(TestBaseMixin): + def assert_grad(self, transform, *inputs): + transform = transform.to(dtype=torch.float64, device=self.device) inputs_ = [] for i in inputs: i.requires_grad = True - inputs_.append(i.to(dtype=self.dtype, device=self.device)) - assert gradcheck(transform, inputs_, eps=eps, atol=atol, rtol=rtol) - assert gradgradcheck(transform, inputs_, eps=eps, atol=atol, rtol=rtol) + inputs_.append(i.to(dtype=torch.float64, device=self.device)) + assert gradcheck(transform, inputs_) + assert gradgradcheck(transform, inputs_) @parameterized.expand([ ({'pad': 0, 'normalized': False, 'power': None}, ),