Skip to content

Commit

Permalink
updated torchscript tests
Browse files Browse the repository at this point in the history
  • Loading branch information
anjali411 committed Aug 6, 2020
1 parent 47e2b28 commit 7f98428
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 9 deletions.
11 changes: 11 additions & 0 deletions test/torchaudio_unittest/torchscript_consistency_cuda_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,14 @@ class TestTransformsFloat32(Transforms, common_utils.PytorchTestCase):
class TestTransformsFloat64(Transforms, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')

@common_utils.skipIfNoCuda
class TestTransformsCFloat(TransformsWithComplexDtypes, common_utils.PytorchTestCase):
dtype = torch.cfloat
device = torch.device('cuda')


@common_utils.skipIfNoCuda
class TestTransformsCDouble(TransformsWithComplexDtypes, common_utils.PytorchTestCase):
dtype = torch.cdouble
device = torch.device('cuda')
39 changes: 30 additions & 9 deletions test/torchaudio_unittest/torchscript_consistency_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,27 @@ def func(tensor):
self._assert_consistency(func, waveform)


class TransformsWithComplexDtypes(common_utils.TestBaseMixin):
"""Implements test for Transforms that are performed for different devices"""
def _assert_consistency(self, transform, tensor):
tensor = tensor.to(device=self.device, dtype=self.dtype)
transform = transform.to(device=self.device, dtype=self.dtype)

ts_transform = torch.jit.script(transform)
output = transform(tensor)
ts_output = ts_transform(tensor)
self.assertEqual(ts_output, output)

def test_TimeStretch(self):
n_freq = 400
hop_length = 512
fixed_rate = 1.3
tensor = torch.rand((10, 2, n_freq, 10), dtype=torch.cdouble)
self._assert_consistency(
T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
tensor,
)

class Transforms(common_utils.TestBaseMixin):
"""Implements test for Transforms that are performed for different devices"""
def _assert_consistency(self, transform, tensor):
Expand Down Expand Up @@ -581,15 +602,15 @@ def test_MuLawDecoding(self):
tensor = torch.rand((1, 10))
self._assert_consistency(T.MuLawDecoding(), tensor)

def test_TimeStretch(self):
n_freq = 400
hop_length = 512
fixed_rate = 1.3
tensor = torch.rand((10, 2, n_freq, 10), dtype=torch.cdouble)
self._assert_consistency(
T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
tensor,
)
# def test_TimeStretch(self):
# n_freq = 400
# hop_length = 512
# fixed_rate = 1.3
# tensor = torch.rand((10, 2, n_freq, 10), dtype=torch.cdouble)
# self._assert_consistency(
# T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
# tensor,
# )

def test_Fade(self):
waveform = common_utils.get_whitenoise()
Expand Down

0 comments on commit 7f98428

Please sign in to comment.