From 7f984283f8e344cd31b2445048abc3e2e977c65d Mon Sep 17 00:00:00 2001 From: anjali411 Date: Wed, 1 Jul 2020 16:45:17 -0700 Subject: [PATCH] updated torchscript tests --- .../torchscript_consistency_cuda_test.py | 11 ++++++ .../torchscript_consistency_impl.py | 39 ++++++++++++++----- 2 files changed, 41 insertions(+), 9 deletions(-) diff --git a/test/torchaudio_unittest/torchscript_consistency_cuda_test.py b/test/torchaudio_unittest/torchscript_consistency_cuda_test.py index c205e192ed..50ce398e9a 100644 --- a/test/torchaudio_unittest/torchscript_consistency_cuda_test.py +++ b/test/torchaudio_unittest/torchscript_consistency_cuda_test.py @@ -26,3 +26,14 @@ class TestTransformsFloat32(Transforms, common_utils.PytorchTestCase): class TestTransformsFloat64(Transforms, common_utils.PytorchTestCase): dtype = torch.float64 device = torch.device('cuda') + +@common_utils.skipIfNoCuda +class TestTransformsCFloat(TransformsWithComplexDtypes, common_utils.PytorchTestCase): + dtype = torch.cfloat + device = torch.device('cuda') + + +@common_utils.skipIfNoCuda +class TestTransformsCDouble(TransformsWithComplexDtypes, common_utils.PytorchTestCase): + dtype = torch.cdouble + device = torch.device('cuda') diff --git a/test/torchaudio_unittest/torchscript_consistency_impl.py b/test/torchaudio_unittest/torchscript_consistency_impl.py index 40b4858998..e7d5849b72 100644 --- a/test/torchaudio_unittest/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/torchscript_consistency_impl.py @@ -529,6 +529,27 @@ def func(tensor): self._assert_consistency(func, waveform) +class TransformsWithComplexDtypes(common_utils.TestBaseMixin): + """Implements test for Transforms that are performed for different devices""" + def _assert_consistency(self, transform, tensor): + tensor = tensor.to(device=self.device, dtype=self.dtype) + transform = transform.to(device=self.device, dtype=self.dtype) + + ts_transform = torch.jit.script(transform) + output = transform(tensor) + ts_output = ts_transform(tensor) + self.assertEqual(ts_output, output) + + def test_TimeStretch(self): + n_freq = 400 + hop_length = 512 + fixed_rate = 1.3 + tensor = torch.rand((10, 2, n_freq, 10), dtype=torch.cdouble) + self._assert_consistency( + T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate), + tensor, + ) + class Transforms(common_utils.TestBaseMixin): """Implements test for Transforms that are performed for different devices""" def _assert_consistency(self, transform, tensor): @@ -581,15 +602,15 @@ def test_MuLawDecoding(self): tensor = torch.rand((1, 10)) self._assert_consistency(T.MuLawDecoding(), tensor) - def test_TimeStretch(self): - n_freq = 400 - hop_length = 512 - fixed_rate = 1.3 - tensor = torch.rand((10, 2, n_freq, 10), dtype=torch.cdouble) - self._assert_consistency( - T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate), - tensor, - ) + # def test_TimeStretch(self): + # n_freq = 400 + # hop_length = 512 + # fixed_rate = 1.3 + # tensor = torch.rand((10, 2, n_freq, 10), dtype=torch.cdouble) + # self._assert_consistency( + # T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate), + # tensor, + # ) def test_Fade(self): waveform = common_utils.get_whitenoise()