
Commit

moved the tests back to the existing test files; added them in a new class
anjali411 committed Aug 11, 2020
1 parent c3d9b8e commit 84d954c
Showing 6 changed files with 96 additions and 126 deletions.
36 changes: 36 additions & 0 deletions test/torchaudio_unittest/batch_consistency_test.py
@@ -280,3 +280,39 @@ def test_batch_Vol(self):
# Batch then transform
computed = torchaudio.transforms.Vol(gain=1.1)(waveform.repeat(3, 1, 1))
self.assertEqual(computed, expected)

class TestTransformsWithComplexTensors(common_utils.TorchaudioTestCase):
def test_batch_TimeStretch(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
waveform, _ = common_utils.load_wav(test_filepath) # (2, 278756), 44100

kwargs = {
'n_fft': 2048,
'hop_length': 512,
'win_length': 2048,
'window': torch.hann_window(2048),
'center': True,
'pad_mode': 'reflect',
'normalized': True,
'onesided': True,
}
rate = 2

complex_specgrams = torch.stft(waveform, **kwargs)
complex_specgrams = torch.view_as_complex(complex_specgrams)

# Single then transform then batch
expected = torchaudio.transforms.TimeStretch(
fixed_rate=rate,
n_freq=1025,
hop_length=512,
)(complex_specgrams).repeat(3, 1, 1, 1, 1)

# Batch then transform
computed = torchaudio.transforms.TimeStretch(
fixed_rate=rate,
n_freq=1025,
hop_length=512,
)(complex_specgrams.repeat(3, 1, 1, 1, 1))

self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5)
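In PyTorch releases that support return_complex=True, torch.stft can produce the complex-dtype input directly, replacing the stft followed by view_as_complex pair used in the test above. A minimal sketch, not part of the commit; the random waveform stands in for the loaded asset:

# Not part of this commit: building the same complex spectrogram directly,
# assuming a PyTorch version that supports torch.stft(..., return_complex=True).
import torch

# Stand-in for the loaded wav; shape (channel, time) like the test's waveform.
waveform = torch.randn(2, 44100)

complex_specgrams = torch.stft(
    waveform,
    n_fft=2048,
    hop_length=512,
    win_length=2048,
    window=torch.hann_window(2048),
    center=True,
    pad_mode='reflect',
    normalized=True,
    onesided=True,
    return_complex=True,  # yields a complex-dtype tensor, no view_as_complex needed
)
print(complex_specgrams.dtype)  # torch.complex64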
49 changes: 0 additions & 49 deletions test/torchaudio_unittest/complex_batch_consistency_test.py

This file was deleted.

57 changes: 0 additions & 57 deletions test/torchaudio_unittest/complex_librosa_compatibility_test.py

This file was deleted.

38 changes: 38 additions & 0 deletions test/torchaudio_unittest/librosa_compatibility_test.py
@@ -2,6 +2,7 @@
import os
import unittest
from distutils.version import StrictVersion
from parameterized import parameterized

import torch
import torchaudio
@@ -111,6 +112,43 @@ def test_amplitude_to_DB(self):
self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5)


@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class TestFunctionalWithComplexTensors(common_utils.TorchaudioTestCase):
"""Test suite for functions in `functional` module using as input tensors with complex dtypes."""
@parameterized.expand([
(0.5,), (1.01,), (1.3,)
])
def test_phase_vocoder(self, rate):
torch.random.manual_seed(48)
complex_specgrams = torch.randn(2, 1025, 400, dtype=torch.cdouble)
hop_length = 256

# Due to the cumulative sum, numerical error when using torch.float32
# causes the bottom-right values of the stretched spectrogram to not
# match librosa's output.

phase_advance = torch.linspace(0, np.pi * hop_length,
complex_specgrams.shape[-2], dtype=torch.double)[..., None]

complex_specgrams_stretch = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance)

# == Test shape
expected_size = list(complex_specgrams.size())
expected_size[-1] = int(np.ceil(expected_size[-1] / rate))

assert complex_specgrams.dim() == complex_specgrams_stretch.dim()
assert complex_specgrams_stretch.size() == torch.Size(expected_size)

# == Test values
index = [0] + [slice(None)] * 2
mono_complex_specgram = complex_specgrams[index].numpy()
expected_complex_stretch = librosa.phase_vocoder(mono_complex_specgram,
rate=rate,
hop_length=hop_length)

self.assertEqual(complex_specgrams_stretch[index], torch.from_numpy(expected_complex_stretch))


@pytest.mark.parametrize('complex_specgrams', [
torch.randn(2, 1025, 400, 2)
])
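The comment in test_phase_vocoder above explains that the phase accumulation is a cumulative sum, so float32 rounding error compounds toward the end of the spectrogram. A minimal standalone sketch, not part of the commit, comparing a float32 cumulative sum against a float64 reference:

# Not part of this commit: why the test uses double precision. Rounding error
# in a float32 cumulative sum grows with the number of accumulated terms.
import torch

torch.manual_seed(0)
phase = torch.rand(100_000, dtype=torch.float64) * 3.14159

single = torch.cumsum(phase.to(torch.float32), dim=0)
double = torch.cumsum(phase, dim=0)
max_drift = (single.to(torch.float64) - double).abs().max()
print(max_drift)  # grows with length; far larger than float64 round-off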
26 changes: 14 additions & 12 deletions test/torchaudio_unittest/torchscript_consistency_impl.py
@@ -529,7 +529,7 @@ def func(tensor):
self._assert_consistency(func, waveform)


class TransformsWithComplexDtypes(common_utils.TestBaseMixin):
class TransformsMixin:
"""Implements test for Transforms that are performed for different devices"""
def _assert_consistency(self, transform, tensor):
tensor = tensor.to(device=self.device, dtype=self.dtype)
@@ -540,18 +540,8 @@ def _assert_consistency(self, transform, tensor):
ts_output = ts_transform(tensor)
self.assertEqual(ts_output, output)

def test_TimeStretch(self):
n_freq = 400
hop_length = 512
fixed_rate = 1.3
tensor = torch.rand((10, 2, n_freq, 10), dtype=torch.cdouble)
self._assert_consistency(
T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
tensor,
)


class Transforms(common_utils.TestBaseMixin):
class TransformsWithComplexDtypes(TransformsMixin, common_utils.TestBaseMixin):
"""Implements test for Transforms that are performed for different devices"""
def _assert_consistency(self, transform, tensor):
tensor = tensor.to(device=self.device, dtype=self.dtype)
@@ -562,6 +552,18 @@ def _assert_consistency(self, transform, tensor):
ts_output = ts_transform(tensor)
self.assertEqual(ts_output, output)

def test_TimeStretch(self):
n_freq = 400
hop_length = 512
fixed_rate = 1.3
tensor = torch.rand((10, 2, n_freq, 10))
self._assert_consistency(
T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
tensor,
)


class Transforms(TransformsMixin, common_utils.TestBaseMixin):
def test_Spectrogram(self):
tensor = torch.rand((1, 1000))
self._assert_consistency(T.Spectrogram(), tensor)
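The _impl module above only defines mixins; concrete test cases are produced elsewhere by combining a mixin with a device/dtype base class. A hypothetical sketch of such an instantiation; the concrete class names and the common_utils.PytorchTestCase base are illustrative assumptions, not taken from this commit:

# Hypothetical instantiation sketch: class names are illustrative only.
import torch
from torchaudio_unittest import common_utils
from torchaudio_unittest.torchscript_consistency_impl import (
    Transforms,
    TransformsWithComplexDtypes,
)

class TestTransformsFloat32CPU(Transforms, common_utils.PytorchTestCase):
    dtype = torch.float32
    device = torch.device('cpu')

class TestTransformsComplex128CPU(TransformsWithComplexDtypes, common_utils.PytorchTestCase):
    dtype = torch.cdouble
    device = torch.device('cpu')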
16 changes: 8 additions & 8 deletions torchaudio/functional.py
@@ -468,27 +468,27 @@ def phase_vocoder(
`(..., freq, ceil(time/rate), complex=2)` or
a complex dtype and dimension of `(..., freq, ceil(time/rate))`.
Example - old API
Example - New API (using tensors with complex dtype)
>>> freq, hop_length = 1025, 512
>>> # (channel, freq, time, complex=2)
>>> complex_specgrams = torch.randn(2, freq, 300, 2)
>>> # (channel, freq, time)
>>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat)
>>> rate = 1.3 # Speed up by 30%
>>> phase_advance = torch.linspace(
>>> 0, math.pi * hop_length, freq)[..., None]
>>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
>>> x.shape # with 231 == ceil(300 / 1.3)
torch.Size([2, 1025, 231, 2])
torch.Size([2, 1025, 231])
Example - New API (using tensors with complex dtype)
Example - Old API (using real tensors with shape (..., complex=2))
>>> freq, hop_length = 1025, 512
>>> # (channel, freq, time)
>>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat)
>>> # (channel, freq, time, complex=2)
>>> complex_specgrams = torch.randn(2, freq, 300, 2)
>>> rate = 1.3 # Speed up by 30%
>>> phase_advance = torch.linspace(
>>> 0, math.pi * hop_length, freq)[..., None]
>>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
>>> x.shape # with 231 == ceil(300 / 1.3)
torch.Size([2, 1025, 231])
torch.Size([2, 1025, 231, 2])
"""
use_complex = complex_specgrams.is_complex()
shape = complex_specgrams.size()
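The two docstring examples differ only in how the spectrogram is represented, and the two layouts are interconvertible. A minimal sketch, not part of the commit, using torch.view_as_real and torch.view_as_complex:

# Not part of this commit: converting between the old real-valued
# (..., freq, time, complex=2) layout and the complex-dtype layout.
import torch

freq, time = 1025, 300
new_style = torch.randn(2, freq, time, dtype=torch.cfloat)  # (channel, freq, time)

old_style = torch.view_as_real(new_style)      # (channel, freq, time, 2)
round_trip = torch.view_as_complex(old_style)  # back to the complex dtype

assert old_style.shape == (2, freq, time, 2)
assert torch.equal(round_trip, new_style)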
