From d27a820b66b90f183fbf314477d3c198d8594680 Mon Sep 17 00:00:00 2001 From: anjali411 Date: Wed, 1 Jul 2020 13:00:26 -0700 Subject: [PATCH 1/8] migrate to complex dtypes --- .../functional/functional_cpu_test.py | 14 +++- .../functional/functional_cuda_test.py | 16 ++++- .../functional/functional_impl.py | 37 ++++++++++ .../functional/librosa_compatibility_test.py | 60 ++++++++-------- .../torchscript_consistency_cpu_test.py | 14 +++- .../torchscript_consistency_cuda_test.py | 16 ++++- .../torchscript_consistency_impl.py | 50 ++++++++++---- .../transforms/batch_consistency_test.py | 35 ++++++---- .../torchscript_consistency_cpu_test.py | 14 +++- .../torchscript_consistency_cuda_test.py | 16 ++++- .../torchscript_consistency_impl.py | 49 ++++++++++--- torchaudio/functional/functional.py | 69 ++++++++++++------- torchaudio/transforms.py | 5 +- 13 files changed, 296 insertions(+), 99 deletions(-) diff --git a/test/torchaudio_unittest/functional/functional_cpu_test.py b/test/torchaudio_unittest/functional/functional_cpu_test.py index 7669d96744..9f4b4f1a64 100644 --- a/test/torchaudio_unittest/functional/functional_cpu_test.py +++ b/test/torchaudio_unittest/functional/functional_cpu_test.py @@ -14,7 +14,7 @@ skipIfNoSox, ) -from .functional_impl import Lfilter, Spectrogram +from .functional_impl import Lfilter, Spectrogram, FunctionalComplex class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase): @@ -41,6 +41,18 @@ class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase): device = torch.device('cpu') +class TestFunctionalComplex64(FunctionalComplex, common_utils.PytorchTestCase): + complex_dtype = torch.complex64 + real_dtype = torch.float32 + device = torch.device('cpu') + + +class TestFunctionalComplex128(FunctionalComplex, common_utils.PytorchTestCase): + complex_dtype = torch.complex128 + real_dtype = torch.float64 + device = torch.device('cpu') + + class TestCreateFBMatrix(common_utils.TorchaudioTestCase): def test_no_warning_high_n_freq(self): with warnings.catch_warnings(record=True) as w: diff --git a/test/torchaudio_unittest/functional/functional_cuda_test.py b/test/torchaudio_unittest/functional/functional_cuda_test.py index f4db6ca3a0..62d9eff3eb 100644 --- a/test/torchaudio_unittest/functional/functional_cuda_test.py +++ b/test/torchaudio_unittest/functional/functional_cuda_test.py @@ -2,7 +2,7 @@ import unittest from torchaudio_unittest import common_utils -from .functional_impl import Lfilter, Spectrogram +from .functional_impl import Lfilter, Spectrogram, FunctionalComplex @common_utils.skipIfNoCuda @@ -31,3 +31,17 @@ class TestSpectrogramFloat32(Spectrogram, common_utils.PytorchTestCase): class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase): dtype = torch.float64 device = torch.device('cuda') + + +@common_utils.skipIfNoCuda +class TestFunctionalComplex64(FunctionalComplex, common_utils.PytorchTestCase): + complex_dtype = torch.complex64 + real_dtype = torch.float32 + device = torch.device('cuda') + + +@common_utils.skipIfNoCuda +class TestFunctionalComplex128(FunctionalComplex, common_utils.PytorchTestCase): + complex_dtype = torch.complex64 + real_dtype = torch.float32 + device = torch.device('cuda') diff --git a/test/torchaudio_unittest/functional/functional_impl.py b/test/torchaudio_unittest/functional/functional_impl.py index f7f8c17ecb..3cef76103b 100644 --- a/test/torchaudio_unittest/functional/functional_impl.py +++ b/test/torchaudio_unittest/functional/functional_impl.py @@ -2,9 +2,11 @@ import torch import torchaudio.functional as F from parameterized import parameterized +import numpy as np from scipy import signal from torchaudio_unittest import common_utils +from torchaudio_unittest.common_utils import nested_params class Lfilter(common_utils.TestBaseMixin): @@ -89,3 +91,38 @@ def test_grad_at_zero(self, power): ) spec.sum().backward() assert not x.grad.isnan().sum() + + +class FunctionalComplex(common_utils.TestBaseMixin): + complex_dtype = None + real_dtype = None + device = None + + @nested_params( + [0.5, 1.01, 1.3], + [True, False], + ) + def test_phase_vocoder_shape(self, rate, test_pseudo_complex): + hop_length = 256 + num_freq = 1025 + num_frames = 400 + batch_size = 2 + + torch.random.manual_seed(42) + spec = torch.randn( + batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device) + if test_pseudo_complex: + spec = torch.view_as_real(spec) + + phase_advance = torch.linspace( + 0, + np.pi * hop_length, + num_freq, + dtype=self.real_dtype, device=self.device)[..., None] + + spec_stretch = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance) + + assert spec.dim() == spec_stretch.dim() + expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))]) + output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape + assert output_shape == expected_shape diff --git a/test/torchaudio_unittest/functional/librosa_compatibility_test.py b/test/torchaudio_unittest/functional/librosa_compatibility_test.py index 5905c239d3..04555830d6 100644 --- a/test/torchaudio_unittest/functional/librosa_compatibility_test.py +++ b/test/torchaudio_unittest/functional/librosa_compatibility_test.py @@ -1,4 +1,3 @@ -import itertools import unittest from distutils.version import StrictVersion @@ -130,45 +129,44 @@ def test_resample(self): @unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available") -class TestPhaseVocoder(common_utils.TorchaudioTestCase): - @parameterized.expand(list(itertools.product( - [(2, 1025, 400, 2)], - [0.5, 1.01, 1.3], - [256] - ))) - def test_phase_vocoder(self, shape, rate, hop_length): +class TestFunctionalComplex(common_utils.TorchaudioTestCase): + def _test_phase_vocoder(self, rate, test_pseudo_complex=False): + hop_length = 256 + num_freq = 1025 + num_frames = 400 + torch.random.manual_seed(42) + # Due to cummulative sum, numerical error in using torch.float32 will # result in bottom right values of the stretched sectrogram to not # match with librosa. - torch.random.manual_seed(42) - complex_specgrams = torch.randn(*shape) - complex_specgrams = complex_specgrams.type(torch.float64) + spec = torch.randn(num_freq, num_frames, dtype=torch.complex128) phase_advance = torch.linspace( 0, np.pi * hop_length, - complex_specgrams.shape[-3], + num_freq, dtype=torch.float64)[..., None] - complex_specgrams_stretch = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance) + stretched = F.phase_vocoder( + torch.view_as_real(spec) if test_pseudo_complex else spec, + rate=rate, phase_advance=phase_advance) - # == Test shape - expected_size = list(complex_specgrams.size()) - expected_size[-2] = int(np.ceil(expected_size[-2] / rate)) - - assert complex_specgrams.dim() == complex_specgrams_stretch.dim() - assert complex_specgrams_stretch.size() == torch.Size(expected_size) - - # == Test values - index = [0] * (complex_specgrams.dim() - 3) + [slice(None)] * 3 - mono_complex_specgram = complex_specgrams[index].numpy() - mono_complex_specgram = mono_complex_specgram[..., 0] + \ - mono_complex_specgram[..., 1] * 1j - expected_complex_stretch = librosa.phase_vocoder( - mono_complex_specgram, + expected_stretched = librosa.phase_vocoder( + spec.numpy(), rate=rate, hop_length=hop_length) - complex_stretch = complex_specgrams_stretch[index].numpy() - complex_stretch = complex_stretch[..., 0] + 1j * complex_stretch[..., 1] - - self.assertEqual(complex_stretch, torch.from_numpy(expected_complex_stretch), atol=1e-5, rtol=1e-5) + self.assertEqual( + torch.view_as_complex(stretched) if test_pseudo_complex else stretched, + torch.from_numpy(expected_stretched)) + + @parameterized.expand( + [(0.5, ), (1.01, ), (1.3, ), ], + ) + def test_phase_vocoder(self, rate): + self._test_phase_vocoder(rate) + + @parameterized.expand( + [(0.5, ), (1.01, ), (1.3, ), ], + ) + def test_phase_vocoder_pseudo_complex(self, rate): + self._test_phase_vocoder(rate, test_pseudo_complex=True) diff --git a/test/torchaudio_unittest/functional/torchscript_consistency_cpu_test.py b/test/torchaudio_unittest/functional/torchscript_consistency_cpu_test.py index 06871f9bde..d9d34eff76 100644 --- a/test/torchaudio_unittest/functional/torchscript_consistency_cpu_test.py +++ b/test/torchaudio_unittest/functional/torchscript_consistency_cpu_test.py @@ -1,7 +1,7 @@ import torch from torchaudio_unittest.common_utils import PytorchTestCase -from .torchscript_consistency_impl import Functional +from .torchscript_consistency_impl import Functional, FunctionalComplex class TestFunctionalFloat32(Functional, PytorchTestCase): @@ -12,3 +12,15 @@ class TestFunctionalFloat32(Functional, PytorchTestCase): class TestFunctionalFloat64(Functional, PytorchTestCase): dtype = torch.float64 device = torch.device('cpu') + + +class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase): + complex_dtype = torch.complex64 + real_dtype = torch.float32 + device = torch.device('cpu') + + +class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase): + complex_dtype = torch.complex128 + real_dtype = torch.float64 + device = torch.device('cpu') diff --git a/test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py b/test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py index 53ccbad3c9..026e09abfa 100644 --- a/test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py +++ b/test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py @@ -1,7 +1,7 @@ import torch from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase -from .torchscript_consistency_impl import Functional +from .torchscript_consistency_impl import Functional, FunctionalComplex @skipIfNoCuda @@ -14,3 +14,17 @@ class TestFunctionalFloat32(Functional, PytorchTestCase): class TestFunctionalFloat64(Functional, PytorchTestCase): dtype = torch.float64 device = torch.device('cuda') + + +@skipIfNoCuda +class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase): + complex_dtype = torch.complex64 + real_dtype = torch.float32 + device = torch.device('cuda') + + +@skipIfNoCuda +class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase): + complex_dtype = torch.complex128 + real_dtype = torch.float64 + device = torch.device('cuda') diff --git a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py index b6aec04d9c..6495cd7b75 100644 --- a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py @@ -547,29 +547,51 @@ def func(tensor): tensor = common_utils.get_whitenoise(sample_rate=44100) self._assert_consistency(func, tensor) + @common_utils.skipIfNoKaldi + def test_compute_kaldi_pitch(self): + if self.dtype != torch.float32 or self.device != torch.device('cpu'): + raise unittest.SkipTest("Only float32, cpu is supported.") + + def func(tensor): + sample_rate: float = 44100. + return F.compute_kaldi_pitch(tensor, sample_rate) + + tensor = common_utils.get_whitenoise(sample_rate=44100) + self._assert_consistency(func, tensor) + + +class FunctionalComplex: + complex_dtype = None + real_dtype = None + device = None + + def _assert_consistency(self, func, tensor): + tensor = tensor.to(device=self.device, dtype=self.complex_dtype) + ts_func = torch.jit.script(func) + + # on complex dtype + output = func(tensor) + ts_output = ts_func(tensor) + self.assertEqual(ts_output, output) + + # on pseudo complex dtype + tensor = torch.view_as_real(tensor) + output = func(tensor) + ts_output = ts_func(tensor) + self.assertEqual(ts_output, output) + def test_phase_vocoder(self): def func(tensor, device: torch.device = self.device): + n_freq = tensor.size(-2 if tensor.is_complex() else -3) rate = 0.5 hop_length = 256 phase_advance = torch.linspace( 0, 3.14 * hop_length, - tensor.shape[-3], + n_freq, dtype=torch.float64, ).to(device)[..., None] return F.phase_vocoder(tensor, rate, phase_advance) - tensor = torch.randn(2, 1025, 400, 2) - self._assert_consistency(func, tensor) - - @common_utils.skipIfNoKaldi - def test_compute_kaldi_pitch(self): - if self.dtype != torch.float32 or self.device != torch.device('cpu'): - raise unittest.SkipTest("Only float32, cpu is supported.") - - def func(tensor): - sample_rate: float = 44100. - return F.compute_kaldi_pitch(tensor, sample_rate) - - tensor = common_utils.get_whitenoise(sample_rate=44100) + tensor = torch.randn(2, 1025, 400) self._assert_consistency(func, tensor) diff --git a/test/torchaudio_unittest/transforms/batch_consistency_test.py b/test/torchaudio_unittest/transforms/batch_consistency_test.py index 745401c328..b2d0a2aa31 100644 --- a/test/torchaudio_unittest/transforms/batch_consistency_test.py +++ b/test/torchaudio_unittest/transforms/batch_consistency_test.py @@ -130,27 +130,28 @@ def test_batch_mfcc(self): computed = torchaudio.transforms.MFCC()(waveform.repeat(3, 1, 1)) self.assertEqual(computed, expected, atol=1e-4, rtol=1e-5) - def test_batch_TimeStretch(self): + def _assert_batch_TimeStretch(self, complex): test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav') waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100 rate = 2 - complex_specgrams = torch.view_as_real( - torch.stft( - input=waveform, - n_fft=2048, - hop_length=512, - win_length=2048, - window=torch.hann_window(2048), - center=True, - pad_mode='reflect', - normalized=True, - onesided=True, - return_complex=True, - ) + complex_specgrams = torch.stft( + input=waveform, + n_fft=2048, + hop_length=512, + win_length=2048, + window=torch.hann_window(2048), + center=True, + pad_mode='reflect', + normalized=True, + onesided=True, + return_complex=True, ) + if not complex: + complex_specgrams = torch.view_as_real(complex_specgrams) + # Single then transform then batch expected = torchaudio.transforms.TimeStretch( fixed_rate=rate, @@ -167,6 +168,12 @@ def test_batch_TimeStretch(self): self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5) + def test_batch_TimeStretch_complex(self): + self._assert_batch_TimeStretch(complex=True) + + def test_batch_TimeStretch_paseudo_complex(self): + self._assert_batch_TimeStretch(complex=False) + def test_batch_Fade(self): test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav') waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100 diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py b/test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py index 9092f3a64b..4de32a57bb 100644 --- a/test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py +++ b/test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py @@ -1,7 +1,7 @@ import torch from torchaudio_unittest.common_utils import PytorchTestCase -from .torchscript_consistency_impl import Transforms +from .torchscript_consistency_impl import Transforms, TransformsComplex class TestTransformsFloat32(Transforms, PytorchTestCase): @@ -12,3 +12,15 @@ class TestTransformsFloat32(Transforms, PytorchTestCase): class TestTransformsFloat64(Transforms, PytorchTestCase): dtype = torch.float64 device = torch.device('cpu') + + +class TestTransformsComplex64(TransformsComplex, PytorchTestCase): + complex_dtype = torch.complex64 + real_dtype = torch.float32 + device = torch.device('cpu') + + +class TestTransformsComplex128(TransformsComplex, PytorchTestCase): + complex_dtype = torch.complex128 + real_dtype = torch.float64 + device = torch.device('cpu') diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_cuda_test.py b/test/torchaudio_unittest/transforms/torchscript_consistency_cuda_test.py index 7425647bab..2435b82589 100644 --- a/test/torchaudio_unittest/transforms/torchscript_consistency_cuda_test.py +++ b/test/torchaudio_unittest/transforms/torchscript_consistency_cuda_test.py @@ -1,7 +1,7 @@ import torch from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase -from .torchscript_consistency_impl import Transforms +from .torchscript_consistency_impl import Transforms, TransformsComplex @skipIfNoCuda @@ -14,3 +14,17 @@ class TestTransformsFloat32(Transforms, PytorchTestCase): class TestTransformsFloat64(Transforms, PytorchTestCase): dtype = torch.float64 device = torch.device('cuda') + + +@skipIfNoCuda +class TestTransformsComplex64(TransformsComplex, PytorchTestCase): + complex_dtype = torch.complex64 + real_dtype = torch.float32 + device = torch.device('cuda') + + +@skipIfNoCuda +class TestTransformsComplex128(TransformsComplex, PytorchTestCase): + complex_dtype = torch.complex128 + real_dtype = torch.float64 + device = torch.device('cuda') diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py index 492e1e4a92..9908a07d72 100644 --- a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py @@ -58,16 +58,6 @@ def test_MuLawDecoding(self): tensor = torch.rand((1, 10)) self._assert_consistency(T.MuLawDecoding(), tensor) - def test_TimeStretch(self): - n_freq = 400 - hop_length = 512 - fixed_rate = 1.3 - tensor = torch.rand((10, 2, n_freq, 10, 2)) - self._assert_consistency( - T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate), - tensor, - ) - def test_Fade(self): waveform = common_utils.get_whitenoise() fade_in_len = 3000 @@ -99,3 +89,42 @@ def test_SpectralCentroid(self): sample_rate = 44100 waveform = common_utils.get_whitenoise(sample_rate=sample_rate) self._assert_consistency(T.SpectralCentroid(sample_rate=sample_rate), waveform) + + +class TransformsComplex: + complex_dtype = None + real_dtype = None + device = None + + def _assert_consistency(self, transform, tensor, test_pseudo_complex=False): + tensor = tensor.to(device=self.device, dtype=self.complex_dtype) + transform = transform.to(device=self.device, dtype=self.real_dtype) + ts_transform = torch.jit.script(transform) + + if test_pseudo_complex: + tensor = torch.view_as_real(tensor) + + output = transform(tensor) + ts_output = ts_transform(tensor) + self.assertEqual(ts_output, output) + + def test_TimeStretch(self): + n_freq = 400 + hop_length = 512 + fixed_rate = 1.3 + tensor = torch.view_as_complex(torch.rand((10, 2, n_freq, 10, 2))) + self._assert_consistency( + T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate), + tensor, + ) + + def test_TimeStretch_paseudo_complex(self): + n_freq = 400 + hop_length = 512 + fixed_rate = 1.3 + tensor = torch.view_as_complex(torch.rand((10, 2, n_freq, 10, 2))) + self._assert_consistency( + T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate), + tensor, + test_pseudo_complex=True + ) diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index 73496bf4cc..bed739897a 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -565,14 +565,30 @@ def phase_vocoder( factor of ``rate``. Args: - complex_specgrams (Tensor): Dimension of `(..., freq, time, complex=2)` + complex_specgrams (Tensor): + Either a real tensor of dimension of `(..., freq, time, complex=2)` + or a tensor of dimension `(..., freq, time)` with complex dtype. rate (float): Speed-up factor phase_advance (Tensor): Expected phase advance in each bin. Dimension of (freq, 1) Returns: - Tensor: Complex Specgrams Stretch with dimension of `(..., freq, ceil(time/rate), complex=2)` + Tensor: + Complex Specgrams Stretch with either a real dtype and dimension of + `(..., freq, ceil(time/rate), complex=2)` or + a complex dtype and dimension of `(..., freq, ceil(time/rate))`. - Example + Example - With Tensor of complex dtype + >>> freq, hop_length = 1025, 512 + >>> # (channel, freq, time) + >>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat) + >>> rate = 1.3 # Speed up by 30% + >>> phase_advance = torch.linspace( + >>> 0, math.pi * hop_length, freq)[..., None] + >>> x = phase_vocoder(complex_specgrams, rate, phase_advance) + >>> x.shape # with 231 == ceil(300 / 1.3) + torch.Size([2, 1025, 231]) + + Example - With Tensor of real dtype and extra dimension for complex field >>> freq, hop_length = 1025, 512 >>> # (channel, freq, time, complex=2) >>> complex_specgrams = torch.randn(2, freq, 300, 2) @@ -583,32 +599,40 @@ def phase_vocoder( >>> x.shape # with 231 == ceil(300 / 1.3) torch.Size([2, 1025, 231, 2]) """ + is_complex = complex_specgrams.is_complex() + + if not is_complex: + complex_specgrams = torch.view_as_complex(complex_specgrams) # pack batch shape = complex_specgrams.size() - complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:])) - - time_steps = torch.arange(0, - complex_specgrams.size(-2), - rate, - device=complex_specgrams.device, - dtype=complex_specgrams.dtype) + complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-2:])) + + # Figures out the corresponding real dtype, i.e. complex128 -> float64, complex64 -> float32 + # Note torch.real is a view so it does not incur any memory copy. + real_dtype = torch.real(complex_specgrams).dtype + time_steps = torch.arange( + 0, + complex_specgrams.size(-1), + rate, + device=complex_specgrams.device, + dtype=real_dtype) alphas = time_steps % 1.0 - phase_0 = angle(complex_specgrams[..., :1, :]) + phase_0 = complex_specgrams[..., :1].angle() # Time Padding - complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 0, 0, 2]) + complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 2]) # (new_bins, freq, 2) - complex_specgrams_0 = complex_specgrams.index_select(-2, time_steps.long()) - complex_specgrams_1 = complex_specgrams.index_select(-2, (time_steps + 1).long()) + complex_specgrams_0 = complex_specgrams.index_select(-1, time_steps.long()) + complex_specgrams_1 = complex_specgrams.index_select(-1, (time_steps + 1).long()) - angle_0 = angle(complex_specgrams_0) - angle_1 = angle(complex_specgrams_1) + angle_0 = complex_specgrams_0.angle() + angle_1 = complex_specgrams_1.angle() - norm_0 = torch.norm(complex_specgrams_0, p=2, dim=-1) - norm_1 = torch.norm(complex_specgrams_1, p=2, dim=-1) + norm_0 = complex_specgrams_0.abs() + norm_1 = complex_specgrams_1.abs() phase = angle_1 - angle_0 - phase_advance phase = phase - 2 * math.pi * torch.round(phase / (2 * math.pi)) @@ -620,14 +644,13 @@ def phase_vocoder( mag = alphas * norm_1 + (1 - alphas) * norm_0 - real_stretch = mag * torch.cos(phase_acc) - imag_stretch = mag * torch.sin(phase_acc) - - complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1) + complex_specgrams_stretch = torch.polar(mag, phase_acc) # unpack batch - complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-3] + complex_specgrams_stretch.shape[1:]) + complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-2] + complex_specgrams_stretch.shape[1:]) + if not is_complex: + return torch.view_as_real(complex_specgrams_stretch) return complex_specgrams_stretch diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 6f3ea1b6e1..7731594870 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -736,7 +736,10 @@ def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] = Returns: Tensor: Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2). """ - assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)" + if not complex_specgrams.is_complex() and complex_specgrams.size(-1) != 2: + raise ValueError( + "complex_specgrams must be either complex dtype or " + "real dtype with the last dimension being 2, e.g. shape==(..., complex=2)") if overriding_rate is None: rate = self.fixed_rate From 5653327d8b6ad55da7b5628356d67dc65540a30d Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Tue, 30 Mar 2021 16:29:05 +0000 Subject: [PATCH 2/8] update --- .../functional/functional_impl.py | 1 + .../torchscript_consistency_impl.py | 26 +++++++++---------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/test/torchaudio_unittest/functional/functional_impl.py b/test/torchaudio_unittest/functional/functional_impl.py index 3cef76103b..c4f863fe45 100644 --- a/test/torchaudio_unittest/functional/functional_impl.py +++ b/test/torchaudio_unittest/functional/functional_impl.py @@ -103,6 +103,7 @@ class FunctionalComplex(common_utils.TestBaseMixin): [True, False], ) def test_phase_vocoder_shape(self, rate, test_pseudo_complex): + """Verify the output shape of phase vocoder""" hop_length = 256 num_freq = 1025 num_frames = 400 diff --git a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py index 6495cd7b75..d3eee145d1 100644 --- a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py @@ -3,6 +3,7 @@ import torch import torchaudio.functional as F +from parameterized import parameterized from torchaudio_unittest import common_utils @@ -565,33 +566,32 @@ class FunctionalComplex: real_dtype = None device = None - def _assert_consistency(self, func, tensor): + def _assert_consistency(self, func, tensor, test_pseudo_complex=False): tensor = tensor.to(device=self.device, dtype=self.complex_dtype) ts_func = torch.jit.script(func) - # on complex dtype + if test_pseudo_complex: + tensor = torch.view_as_real(tensor) output = func(tensor) ts_output = ts_func(tensor) self.assertEqual(ts_output, output) - # on pseudo complex dtype - tensor = torch.view_as_real(tensor) - output = func(tensor) - ts_output = ts_func(tensor) - self.assertEqual(ts_output, output) + @parameterized.expand([(True, ), (False, )]) + def test_phase_vocoder(self, test_paseudo_complex): + def func(tensor): + is_complex = tensor.is_complex() - def test_phase_vocoder(self): - def func(tensor, device: torch.device = self.device): - n_freq = tensor.size(-2 if tensor.is_complex() else -3) + n_freq = tensor.size(-2 if is_complex else -3) rate = 0.5 hop_length = 256 phase_advance = torch.linspace( 0, 3.14 * hop_length, n_freq, - dtype=torch.float64, - ).to(device)[..., None] + dtype=(torch.real(tensor) if is_complex else tensor).dtype, + device=tensor.device, + )[..., None] return F.phase_vocoder(tensor, rate, phase_advance) tensor = torch.randn(2, 1025, 400) - self._assert_consistency(func, tensor) + self._assert_consistency(func, tensor, test_paseudo_complex) From b6b4be0e006e7e25cb2dcea93d98600a2c244525 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Tue, 30 Mar 2021 16:43:57 +0000 Subject: [PATCH 3/8] fix test --- .../functional/torchscript_consistency_impl.py | 3 ++- .../transforms/torchscript_consistency_impl.py | 17 +++++------------ 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py index d3eee145d1..46d1a98233 100644 --- a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py @@ -567,6 +567,7 @@ class FunctionalComplex: device = None def _assert_consistency(self, func, tensor, test_pseudo_complex=False): + assert tensor.is_complex() tensor = tensor.to(device=self.device, dtype=self.complex_dtype) ts_func = torch.jit.script(func) @@ -593,5 +594,5 @@ def func(tensor): )[..., None] return F.phase_vocoder(tensor, rate, phase_advance) - tensor = torch.randn(2, 1025, 400) + tensor = torch.view_as_complex(torch.randn(2, 1025, 400, 2)) self._assert_consistency(func, tensor, test_paseudo_complex) diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py index 9908a07d72..b2117abdb5 100644 --- a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py @@ -2,6 +2,7 @@ import torch import torchaudio.transforms as T +from parameterized import parameterized from torchaudio_unittest import common_utils @@ -97,6 +98,7 @@ class TransformsComplex: device = None def _assert_consistency(self, transform, tensor, test_pseudo_complex=False): + assert tensor.is_complex() tensor = tensor.to(device=self.device, dtype=self.complex_dtype) transform = transform.to(device=self.device, dtype=self.real_dtype) ts_transform = torch.jit.script(transform) @@ -108,7 +110,8 @@ def _assert_consistency(self, transform, tensor, test_pseudo_complex=False): ts_output = ts_transform(tensor) self.assertEqual(ts_output, output) - def test_TimeStretch(self): + @parameterized.expand([(True, ), (False, )]) + def test_TimeStretch(self, test_pseudo_complex): n_freq = 400 hop_length = 512 fixed_rate = 1.3 @@ -116,15 +119,5 @@ def test_TimeStretch(self): self._assert_consistency( T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate), tensor, - ) - - def test_TimeStretch_paseudo_complex(self): - n_freq = 400 - hop_length = 512 - fixed_rate = 1.3 - tensor = torch.view_as_complex(torch.rand((10, 2, n_freq, 10, 2))) - self._assert_consistency( - T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate), - tensor, - test_pseudo_complex=True + test_pseudo_complex ) From 550dc90c2831346517212f3b3f13cfab986f51d8 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Tue, 30 Mar 2021 16:46:45 +0000 Subject: [PATCH 4/8] update test --- .../functional/librosa_compatibility_test.py | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/test/torchaudio_unittest/functional/librosa_compatibility_test.py b/test/torchaudio_unittest/functional/librosa_compatibility_test.py index 04555830d6..9439e605aa 100644 --- a/test/torchaudio_unittest/functional/librosa_compatibility_test.py +++ b/test/torchaudio_unittest/functional/librosa_compatibility_test.py @@ -14,6 +14,9 @@ import librosa from torchaudio_unittest import common_utils +from torcahduio_unittest.common_utils import ( + nested_params, +) @unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available") @@ -130,7 +133,11 @@ def test_resample(self): @unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available") class TestFunctionalComplex(common_utils.TorchaudioTestCase): - def _test_phase_vocoder(self, rate, test_pseudo_complex=False): + @nested_params( + [0.5, 1.01, 1.3], + [True, False], + ) + def test_phase_vocoder(self, rate, test_pseudo_complex): hop_length = 256 num_freq = 1025 num_frames = 400 @@ -158,15 +165,3 @@ def _test_phase_vocoder(self, rate, test_pseudo_complex=False): self.assertEqual( torch.view_as_complex(stretched) if test_pseudo_complex else stretched, torch.from_numpy(expected_stretched)) - - @parameterized.expand( - [(0.5, ), (1.01, ), (1.3, ), ], - ) - def test_phase_vocoder(self, rate): - self._test_phase_vocoder(rate) - - @parameterized.expand( - [(0.5, ), (1.01, ), (1.3, ), ], - ) - def test_phase_vocoder_pseudo_complex(self, rate): - self._test_phase_vocoder(rate, test_pseudo_complex=True) From 7bcb7ed3094b6be074594e32a7e187eb4c02107f Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Tue, 30 Mar 2021 17:25:48 +0000 Subject: [PATCH 5/8] simplify the test --- .../transforms/batch_consistency_test.py | 43 ++++++------------- 1 file changed, 14 insertions(+), 29 deletions(-) diff --git a/test/torchaudio_unittest/transforms/batch_consistency_test.py b/test/torchaudio_unittest/transforms/batch_consistency_test.py index b2d0a2aa31..cab68e0152 100644 --- a/test/torchaudio_unittest/transforms/batch_consistency_test.py +++ b/test/torchaudio_unittest/transforms/batch_consistency_test.py @@ -1,6 +1,7 @@ """Test numerical consistency among single input and batched input.""" import torch import torchaudio +from parameterized import parameterized from torchaudio_unittest import common_utils @@ -130,50 +131,34 @@ def test_batch_mfcc(self): computed = torchaudio.transforms.MFCC()(waveform.repeat(3, 1, 1)) self.assertEqual(computed, expected, atol=1e-4, rtol=1e-5) - def _assert_batch_TimeStretch(self, complex): - test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav') - waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100 - + @parameterized.expand([(True, ), (False, )]) + def test_batch_TimeStretch(self, test_pseudo_complex): rate = 2 + num_freq = 1025 + num_frames = 400 - complex_specgrams = torch.stft( - input=waveform, - n_fft=2048, - hop_length=512, - win_length=2048, - window=torch.hann_window(2048), - center=True, - pad_mode='reflect', - normalized=True, - onesided=True, - return_complex=True, - ) - - if not complex: - complex_specgrams = torch.view_as_real(complex_specgrams) + spec = torch.randn(num_freq, num_frames, dtype=torch.complex64) + pattern = [3, 1, 1, 1] + if test_pseudo_complex: + spec = torch.view_as_real(spec) + pattern += [1] # Single then transform then batch expected = torchaudio.transforms.TimeStretch( fixed_rate=rate, - n_freq=1025, + n_freq=num_freq, hop_length=512, - )(complex_specgrams).repeat(3, 1, 1, 1, 1) + )(spec).repeat(*pattern) # Batch then transform computed = torchaudio.transforms.TimeStretch( fixed_rate=rate, - n_freq=1025, + n_freq=num_freq, hop_length=512, - )(complex_specgrams.repeat(3, 1, 1, 1, 1)) + )(spec.repeat(*pattern)) self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5) - def test_batch_TimeStretch_complex(self): - self._assert_batch_TimeStretch(complex=True) - - def test_batch_TimeStretch_paseudo_complex(self): - self._assert_batch_TimeStretch(complex=False) - def test_batch_Fade(self): test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav') waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100 From a7fe7473134d3ef382a568cb73c2c4bbcb6530e0 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Tue, 30 Mar 2021 17:52:06 +0000 Subject: [PATCH 6/8] fix test --- .../functional/librosa_compatibility_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/torchaudio_unittest/functional/librosa_compatibility_test.py b/test/torchaudio_unittest/functional/librosa_compatibility_test.py index 9439e605aa..f3ddb354ae 100644 --- a/test/torchaudio_unittest/functional/librosa_compatibility_test.py +++ b/test/torchaudio_unittest/functional/librosa_compatibility_test.py @@ -14,7 +14,7 @@ import librosa from torchaudio_unittest import common_utils -from torcahduio_unittest.common_utils import ( +from torchaudio_unittest.common_utils import ( nested_params, ) From 6b40b5a396c1627f3cb1dff20e12bbb6f6138da9 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Thu, 1 Apr 2021 13:59:36 +0000 Subject: [PATCH 7/8] clean up error handling --- torchaudio/functional/functional.py | 8 ++++++++ torchaudio/transforms.py | 11 +++-------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index bed739897a..2866b0f70f 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -599,6 +599,14 @@ def phase_vocoder( >>> x.shape # with 231 == ceil(300 / 1.3) torch.Size([2, 1025, 231, 2]) """ + if rate == 1.0: + return complex_specgrams + + if not complex_specgrams.is_complex() and complex_specgrams.size(-1) != 2: + raise ValueError( + "complex_specgrams must be either native complex tensors or " + "real valued tensors with shape (..., 2)") + is_complex = complex_specgrams.is_complex() if not is_complex: diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 7731594870..85d56b8258 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -736,16 +736,11 @@ def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] = Returns: Tensor: Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2). """ - if not complex_specgrams.is_complex() and complex_specgrams.size(-1) != 2: - raise ValueError( - "complex_specgrams must be either complex dtype or " - "real dtype with the last dimension being 2, e.g. shape==(..., complex=2)") - if overriding_rate is None: + if self.fixed_rate is None: + raise ValueError( + "If no fixed_rate is specified, must pass a valid rate to the forward method.") rate = self.fixed_rate - if rate is None: - raise ValueError("If no fixed_rate is specified" - ", must pass a valid rate to the forward method.") else: rate = overriding_rate From bce0f6bf117d5a700608939050c2278cddff5b75 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Thu, 1 Apr 2021 14:17:03 +0000 Subject: [PATCH 8/8] Update docstring --- torchaudio/functional/functional.py | 9 ++++----- torchaudio/transforms.py | 12 ++++++------ 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index 2866b0f70f..a26d7ab7fe 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -566,16 +566,15 @@ def phase_vocoder( Args: complex_specgrams (Tensor): - Either a real tensor of dimension of `(..., freq, time, complex=2)` - or a tensor of dimension `(..., freq, time)` with complex dtype. + Either a real tensor of dimension of ``(..., freq, num_frame, complex=2)`` + or a tensor of dimension ``(..., freq, num_frame)`` with complex dtype. rate (float): Speed-up factor phase_advance (Tensor): Expected phase advance in each bin. Dimension of (freq, 1) Returns: Tensor: - Complex Specgrams Stretch with either a real dtype and dimension of - `(..., freq, ceil(time/rate), complex=2)` or - a complex dtype and dimension of `(..., freq, ceil(time/rate))`. + Stretched spectrogram. The resulting tensor is of the same dtype as the input + spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``. Example - With Tensor of complex dtype >>> freq, hop_length = 1025, 512 diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 85d56b8258..e5dd5b2210 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -729,12 +729,16 @@ def __init__(self, def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] = None) -> Tensor: r""" Args: - complex_specgrams (Tensor): complex spectrogram (..., freq, time, complex=2). + complex_specgrams (Tensor): + Either a real tensor of dimension of ``(..., freq, num_frame, complex=2)`` + or a tensor of dimension ``(..., freq, num_frame)`` with complex dtype. overriding_rate (float or None, optional): speed up to apply to this batch. If no rate is passed, use ``self.fixed_rate``. (Default: ``None``) Returns: - Tensor: Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2). + Tensor: + Stretched spectrogram. The resulting tensor is of the same dtype as the input + spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``. """ if overriding_rate is None: if self.fixed_rate is None: @@ -743,10 +747,6 @@ def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] = rate = self.fixed_rate else: rate = overriding_rate - - if rate == 1.0: - return complex_specgrams - return F.phase_vocoder(complex_specgrams, rate, self.phase_advance)