diff --git a/test/torchscript_consistency_cuda_test.py b/test/torchscript_consistency_cuda_test.py index b317334ce42..49ac5a64ddf 100644 --- a/test/torchscript_consistency_cuda_test.py +++ b/test/torchscript_consistency_cuda_test.py @@ -26,3 +26,14 @@ class TestTransformsFloat32(Transforms, common_utils.PytorchTestCase): class TestTransformsFloat64(Transforms, common_utils.PytorchTestCase): dtype = torch.float64 device = torch.device('cuda') + +@common_utils.skipIfNoCuda +class TestTransformsCFloat(TransformsWithComplexDtypes, common_utils.PytorchTestCase): + dtype = torch.cfloat + device = torch.device('cuda') + + +@common_utils.skipIfNoCuda +class TestTransformsCDouble(TransformsWithComplexDtypes, common_utils.PytorchTestCase): + dtype = torch.cdouble + device = torch.device('cuda') diff --git a/test/torchscript_consistency_impl.py b/test/torchscript_consistency_impl.py index c0d8cada65f..779aa3cec7a 100644 --- a/test/torchscript_consistency_impl.py +++ b/test/torchscript_consistency_impl.py @@ -530,6 +530,27 @@ def func(tensor): self._assert_consistency(func, waveform) +class TransformsWithComplexDtypes(common_utils.TestBaseMixin): + """Implements test for Transforms that are performed for different devices""" + def _assert_consistency(self, transform, tensor): + tensor = tensor.to(device=self.device, dtype=self.dtype) + transform = transform.to(device=self.device, dtype=self.dtype) + + ts_transform = torch.jit.script(transform) + output = transform(tensor) + ts_output = ts_transform(tensor) + self.assertEqual(ts_output, output) + + def test_TimeStretch(self): + n_freq = 400 + hop_length = 512 + fixed_rate = 1.3 + tensor = torch.rand((10, 2, n_freq, 10), dtype=torch.cdouble) + self._assert_consistency( + T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate), + tensor, + ) + class Transforms(common_utils.TestBaseMixin): """Implements test for Transforms that are performed for different devices""" def _assert_consistency(self, transform, tensor): @@ -582,15 +603,15 @@ def test_MuLawDecoding(self): tensor = torch.rand((1, 10)) self._assert_consistency(T.MuLawDecoding(), tensor) - def test_TimeStretch(self): - n_freq = 400 - hop_length = 512 - fixed_rate = 1.3 - tensor = torch.rand((10, 2, n_freq, 10), dtype=torch.cdouble) - self._assert_consistency( - T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate), - tensor, - ) + # def test_TimeStretch(self): + # n_freq = 400 + # hop_length = 512 + # fixed_rate = 1.3 + # tensor = torch.rand((10, 2, n_freq, 10), dtype=torch.cdouble) + # self._assert_consistency( + # T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate), + # tensor, + # ) def test_Fade(self): waveform = common_utils.get_whitenoise()