diff --git a/test/torchaudio_unittest/functional/functional_cpu_test.py b/test/torchaudio_unittest/functional/functional_cpu_test.py
index 7669d96744..9f4b4f1a64 100644
--- a/test/torchaudio_unittest/functional/functional_cpu_test.py
+++ b/test/torchaudio_unittest/functional/functional_cpu_test.py
@@ -14,7 +14,7 @@
     skipIfNoSox,
 )

-from .functional_impl import Lfilter, Spectrogram
+from .functional_impl import Lfilter, Spectrogram, FunctionalComplex


 class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
@@ -41,6 +41,18 @@ class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
     device = torch.device('cpu')


+class TestFunctionalComplex64(FunctionalComplex, common_utils.PytorchTestCase):
+    complex_dtype = torch.complex64
+    real_dtype = torch.float32
+    device = torch.device('cpu')
+
+
+class TestFunctionalComplex128(FunctionalComplex, common_utils.PytorchTestCase):
+    complex_dtype = torch.complex128
+    real_dtype = torch.float64
+    device = torch.device('cpu')
+
+
 class TestCreateFBMatrix(common_utils.TorchaudioTestCase):
     def test_no_warning_high_n_freq(self):
         with warnings.catch_warnings(record=True) as w:
diff --git a/test/torchaudio_unittest/functional/functional_cuda_test.py b/test/torchaudio_unittest/functional/functional_cuda_test.py
index f4db6ca3a0..62d9eff3eb 100644
--- a/test/torchaudio_unittest/functional/functional_cuda_test.py
+++ b/test/torchaudio_unittest/functional/functional_cuda_test.py
@@ -2,7 +2,7 @@
 import unittest

 from torchaudio_unittest import common_utils
-from .functional_impl import Lfilter, Spectrogram
+from .functional_impl import Lfilter, Spectrogram, FunctionalComplex


 @common_utils.skipIfNoCuda
@@ -31,3 +31,17 @@ class TestSpectrogramFloat32(Spectrogram, common_utils.PytorchTestCase):
 class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
     dtype = torch.float64
     device = torch.device('cuda')
+
+
+@common_utils.skipIfNoCuda
+class TestFunctionalComplex64(FunctionalComplex, common_utils.PytorchTestCase):
+    complex_dtype = torch.complex64
+    real_dtype = torch.float32
+    device = torch.device('cuda')
+
+
+@common_utils.skipIfNoCuda
+class TestFunctionalComplex128(FunctionalComplex, common_utils.PytorchTestCase):
+    complex_dtype = torch.complex128
+    real_dtype = torch.float64
+    device = torch.device('cuda')
diff --git a/test/torchaudio_unittest/functional/functional_impl.py b/test/torchaudio_unittest/functional/functional_impl.py
index f7f8c17ecb..c4f863fe45 100644
--- a/test/torchaudio_unittest/functional/functional_impl.py
+++ b/test/torchaudio_unittest/functional/functional_impl.py
@@ -2,9 +2,11 @@
 import torch
 import torchaudio.functional as F
 from parameterized import parameterized
+import numpy as np
 from scipy import signal

 from torchaudio_unittest import common_utils
+from torchaudio_unittest.common_utils import nested_params


 class Lfilter(common_utils.TestBaseMixin):
@@ -89,3 +91,39 @@ def test_grad_at_zero(self, power):
         )
         spec.sum().backward()
         assert not x.grad.isnan().sum()
+
+
+class FunctionalComplex(common_utils.TestBaseMixin):
+    complex_dtype = None
+    real_dtype = None
+    device = None
+
+    @nested_params(
+        [0.5, 1.01, 1.3],
+        [True, False],
+    )
+    def test_phase_vocoder_shape(self, rate, test_pseudo_complex):
+        """Verify the output shape of phase vocoder"""
+        hop_length = 256
+        num_freq = 1025
+        num_frames = 400
+        batch_size = 2
+
+        torch.random.manual_seed(42)
+        spec = torch.randn(
+            batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device)
+        if test_pseudo_complex:
+            spec = torch.view_as_real(spec)
+
+        phase_advance = torch.linspace(
+            0,
+            np.pi * hop_length,
+            num_freq,
+            dtype=self.real_dtype, device=self.device)[..., None]
+
+        spec_stretch = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)
+
+        assert spec.dim() == spec_stretch.dim()
+        expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))])
+        output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape
+        assert output_shape == expected_shape
diff --git a/test/torchaudio_unittest/functional/librosa_compatibility_test.py b/test/torchaudio_unittest/functional/librosa_compatibility_test.py
index 5905c239d3..f3ddb354ae 100644
--- a/test/torchaudio_unittest/functional/librosa_compatibility_test.py
+++ b/test/torchaudio_unittest/functional/librosa_compatibility_test.py
@@ -1,4 +1,3 @@
-import itertools
 import unittest
 from distutils.version import StrictVersion

@@ -15,6 +14,9 @@
     import librosa

 from torchaudio_unittest import common_utils
+from torchaudio_unittest.common_utils import (
+    nested_params,
+)


 @unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
@@ -130,45 +132,36 @@ def test_resample(self):


 @unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
-class TestPhaseVocoder(common_utils.TorchaudioTestCase):
-    @parameterized.expand(list(itertools.product(
-        [(2, 1025, 400, 2)],
+class TestFunctionalComplex(common_utils.TorchaudioTestCase):
+    @nested_params(
         [0.5, 1.01, 1.3],
-        [256]
-    )))
-    def test_phase_vocoder(self, shape, rate, hop_length):
+        [True, False],
+    )
+    def test_phase_vocoder(self, rate, test_pseudo_complex):
+        hop_length = 256
+        num_freq = 1025
+        num_frames = 400
+
+        torch.random.manual_seed(42)
+
-        # Due to cummulative sum, numerical error in using torch.float32 will
-        # result in bottom right values of the stretched sectrogram to not
+        # Due to cumulative sum, numerical error in using torch.float32 will
+        # result in bottom right values of the stretched spectrogram to not
         # match with librosa.
-        torch.random.manual_seed(42)
-        complex_specgrams = torch.randn(*shape)
-        complex_specgrams = complex_specgrams.type(torch.float64)
+        spec = torch.randn(num_freq, num_frames, dtype=torch.complex128)

         phase_advance = torch.linspace(
             0,
             np.pi * hop_length,
-            complex_specgrams.shape[-3],
+            num_freq,
             dtype=torch.float64)[..., None]

-        complex_specgrams_stretch = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance)
+        stretched = F.phase_vocoder(
+            torch.view_as_real(spec) if test_pseudo_complex else spec,
+            rate=rate, phase_advance=phase_advance)

-        # == Test shape
-        expected_size = list(complex_specgrams.size())
-        expected_size[-2] = int(np.ceil(expected_size[-2] / rate))
-
-        assert complex_specgrams.dim() == complex_specgrams_stretch.dim()
-        assert complex_specgrams_stretch.size() == torch.Size(expected_size)
-
-        # == Test values
-        index = [0] * (complex_specgrams.dim() - 3) + [slice(None)] * 3
-        mono_complex_specgram = complex_specgrams[index].numpy()
-        mono_complex_specgram = mono_complex_specgram[..., 0] + \
-            mono_complex_specgram[..., 1] * 1j
-        expected_complex_stretch = librosa.phase_vocoder(
-            mono_complex_specgram,
+        expected_stretched = librosa.phase_vocoder(
+            spec.numpy(),
             rate=rate,
             hop_length=hop_length)

-        complex_stretch = complex_specgrams_stretch[index].numpy()
-        complex_stretch = complex_stretch[..., 0] + 1j * complex_stretch[..., 1]
-
-        self.assertEqual(complex_stretch, torch.from_numpy(expected_complex_stretch), atol=1e-5, rtol=1e-5)
+        self.assertEqual(
+            torch.view_as_complex(stretched) if test_pseudo_complex else stretched,
+            torch.from_numpy(expected_stretched))
diff --git a/test/torchaudio_unittest/functional/torchscript_consistency_cpu_test.py b/test/torchaudio_unittest/functional/torchscript_consistency_cpu_test.py
index 06871f9bde..d9d34eff76 100644
--- a/test/torchaudio_unittest/functional/torchscript_consistency_cpu_test.py
+++ b/test/torchaudio_unittest/functional/torchscript_consistency_cpu_test.py
@@ -1,7 +1,7 @@
 import torch
 from torchaudio_unittest.common_utils import PytorchTestCase

-from .torchscript_consistency_impl import Functional
+from .torchscript_consistency_impl import Functional, FunctionalComplex


 class TestFunctionalFloat32(Functional, PytorchTestCase):
@@ -12,3 +12,15 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
 class TestFunctionalFloat64(Functional, PytorchTestCase):
     dtype = torch.float64
     device = torch.device('cpu')
+
+
+class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase):
+    complex_dtype = torch.complex64
+    real_dtype = torch.float32
+    device = torch.device('cpu')
+
+
+class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase):
+    complex_dtype = torch.complex128
+    real_dtype = torch.float64
+    device = torch.device('cpu')
diff --git a/test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py b/test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py
index 53ccbad3c9..026e09abfa 100644
--- a/test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py
+++ b/test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py
@@ -1,7 +1,7 @@
 import torch
 from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase

-from .torchscript_consistency_impl import Functional
+from .torchscript_consistency_impl import Functional, FunctionalComplex


 @skipIfNoCuda
@@ -14,3 +14,17 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
 class TestFunctionalFloat64(Functional, PytorchTestCase):
     dtype = torch.float64
     device = torch.device('cuda')
+
+
+@skipIfNoCuda
+class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase):
+    complex_dtype = torch.complex64
+    real_dtype = torch.float32
+    device = torch.device('cuda')
+
+
+@skipIfNoCuda
+class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase):
+    complex_dtype = torch.complex128
+    real_dtype = torch.float64
+    device = torch.device('cuda')
diff --git a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py
index b6aec04d9c..46d1a98233 100644
--- a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py
+++ b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py
@@ -3,6 +3,7 @@
 import torch
 import torchaudio.functional as F
+from parameterized import parameterized

 from torchaudio_unittest import common_utils

@@ -547,21 +548,6 @@ def func(tensor):
         tensor = common_utils.get_whitenoise(sample_rate=44100)
         self._assert_consistency(func, tensor)

-    def test_phase_vocoder(self):
-        def func(tensor, device: torch.device = self.device):
-            rate = 0.5
-            hop_length = 256
-            phase_advance = torch.linspace(
-                0,
-                3.14 * hop_length,
-                tensor.shape[-3],
-                dtype=torch.float64,
-            ).to(device)[..., None]
-            return F.phase_vocoder(tensor, rate, phase_advance)
-
-        tensor = torch.randn(2, 1025, 400, 2)
-        self._assert_consistency(func, tensor)
-
     @common_utils.skipIfNoKaldi
     def test_compute_kaldi_pitch(self):
         if self.dtype != torch.float32 or self.device != torch.device('cpu'):
@@ -573,3 +559,40 @@ def func(tensor):

         tensor = common_utils.get_whitenoise(sample_rate=44100)
         self._assert_consistency(func, tensor)
+
+
+class FunctionalComplex:
+    complex_dtype = None
+    real_dtype = None
+    device = None
+
+    def _assert_consistency(self, func, tensor, test_pseudo_complex=False):
+        assert tensor.is_complex()
+        tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
+        ts_func = torch.jit.script(func)
+
+        if test_pseudo_complex:
+            tensor = torch.view_as_real(tensor)
+        output = func(tensor)
+        ts_output = ts_func(tensor)
+        self.assertEqual(ts_output, output)
+
+    @parameterized.expand([(True, ), (False, )])
+    def test_phase_vocoder(self, test_pseudo_complex):
+        def func(tensor):
+            is_complex = tensor.is_complex()
+
+            n_freq = tensor.size(-2 if is_complex else -3)
+            rate = 0.5
+            hop_length = 256
+            phase_advance = torch.linspace(
+                0,
+                3.14 * hop_length,
+                n_freq,
+                dtype=(torch.real(tensor) if is_complex else tensor).dtype,
+                device=tensor.device,
+            )[..., None]
+            return F.phase_vocoder(tensor, rate, phase_advance)
+
+        tensor = torch.view_as_complex(torch.randn(2, 1025, 400, 2))
+        self._assert_consistency(func, tensor, test_pseudo_complex)
diff --git a/test/torchaudio_unittest/transforms/batch_consistency_test.py b/test/torchaudio_unittest/transforms/batch_consistency_test.py
index 745401c328..cab68e0152 100644
--- a/test/torchaudio_unittest/transforms/batch_consistency_test.py
+++ b/test/torchaudio_unittest/transforms/batch_consistency_test.py
@@ -1,6 +1,7 @@
 """Test numerical consistency among single input and batched input."""
 import torch
 import torchaudio
+from parameterized import parameterized

 from torchaudio_unittest import common_utils

@@ -130,40 +131,31 @@ def test_batch_mfcc(self):
         computed = torchaudio.transforms.MFCC()(waveform.repeat(3, 1, 1))
         self.assertEqual(computed, expected, atol=1e-4, rtol=1e-5)

-    def test_batch_TimeStretch(self):
-        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
-        waveform, _ = torchaudio.load(test_filepath)  # (2, 278756), 44100
-
+    @parameterized.expand([(True, ), (False, )])
+    def test_batch_TimeStretch(self, test_pseudo_complex):
         rate = 2
+        num_freq = 1025
+        num_frames = 400

-        complex_specgrams = torch.view_as_real(
-            torch.stft(
-                input=waveform,
-                n_fft=2048,
-                hop_length=512,
-                win_length=2048,
-                window=torch.hann_window(2048),
-                center=True,
-                pad_mode='reflect',
-                normalized=True,
-                onesided=True,
-                return_complex=True,
-            )
-        )
+        spec = torch.randn(num_freq, num_frames, dtype=torch.complex64)
+        pattern = [3, 1, 1, 1]
+        if test_pseudo_complex:
+            spec = torch.view_as_real(spec)
+            pattern += [1]

         # Single then transform then batch
         expected = torchaudio.transforms.TimeStretch(
             fixed_rate=rate,
-            n_freq=1025,
+            n_freq=num_freq,
             hop_length=512,
-        )(complex_specgrams).repeat(3, 1, 1, 1, 1)
+        )(spec).repeat(*pattern)

         # Batch then transform
         computed = torchaudio.transforms.TimeStretch(
             fixed_rate=rate,
-            n_freq=1025,
+            n_freq=num_freq,
             hop_length=512,
-        )(complex_specgrams.repeat(3, 1, 1, 1, 1))
+        )(spec.repeat(*pattern))

         self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5)
diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py b/test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py
index 9092f3a64b..4de32a57bb 100644
--- a/test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py
+++ b/test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py
@@ -1,7 +1,7 @@
 import torch
 from torchaudio_unittest.common_utils import PytorchTestCase

-from .torchscript_consistency_impl import Transforms
+from .torchscript_consistency_impl import Transforms, TransformsComplex


 class TestTransformsFloat32(Transforms, PytorchTestCase):
@@ -12,3 +12,15 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
 class TestTransformsFloat64(Transforms, PytorchTestCase):
     dtype = torch.float64
     device = torch.device('cpu')
+
+
+class TestTransformsComplex64(TransformsComplex, PytorchTestCase):
+    complex_dtype = torch.complex64
+    real_dtype = torch.float32
+    device = torch.device('cpu')
+
+
+class TestTransformsComplex128(TransformsComplex, PytorchTestCase):
+    complex_dtype = torch.complex128
+    real_dtype = torch.float64
+    device = torch.device('cpu')
diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_cuda_test.py b/test/torchaudio_unittest/transforms/torchscript_consistency_cuda_test.py
index 7425647bab..2435b82589 100644
--- a/test/torchaudio_unittest/transforms/torchscript_consistency_cuda_test.py
+++ b/test/torchaudio_unittest/transforms/torchscript_consistency_cuda_test.py
@@ -1,7 +1,7 @@
 import torch
 from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase

-from .torchscript_consistency_impl import Transforms
+from .torchscript_consistency_impl import Transforms, TransformsComplex


 @skipIfNoCuda
@@ -14,3 +14,17 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
 class TestTransformsFloat64(Transforms, PytorchTestCase):
     dtype = torch.float64
     device = torch.device('cuda')
+
+
+@skipIfNoCuda
+class TestTransformsComplex64(TransformsComplex, PytorchTestCase):
+    complex_dtype = torch.complex64
+    real_dtype = torch.float32
+    device = torch.device('cuda')
+
+
+@skipIfNoCuda
+class TestTransformsComplex128(TransformsComplex, PytorchTestCase):
+    complex_dtype = torch.complex128
+    real_dtype = torch.float64
+    device = torch.device('cuda')
diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
index 492e1e4a92..b2117abdb5 100644
--- a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
+++ b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
@@ -2,6 +2,7 @@
 import torch
 import torchaudio.transforms as T
+from parameterized import parameterized

 from torchaudio_unittest import common_utils

@@ -58,16 +59,6 @@ def test_MuLawDecoding(self):
         tensor = torch.rand((1, 10))
         self._assert_consistency(T.MuLawDecoding(), tensor)

-    def test_TimeStretch(self):
-        n_freq = 400
-        hop_length = 512
-        fixed_rate = 1.3
-        tensor = torch.rand((10, 2, n_freq, 10, 2))
-        self._assert_consistency(
-            T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
-            tensor,
-        )
-
     def test_Fade(self):
         waveform = common_utils.get_whitenoise()
         fade_in_len = 3000
@@ -99,3 +90,34 @@ def test_SpectralCentroid(self):
         sample_rate = 44100
         waveform = common_utils.get_whitenoise(sample_rate=sample_rate)
         self._assert_consistency(T.SpectralCentroid(sample_rate=sample_rate), waveform)
+
+
+class TransformsComplex:
+    complex_dtype = None
+    real_dtype = None
+    device = None
+
+    def _assert_consistency(self, transform, tensor, test_pseudo_complex=False):
+        assert tensor.is_complex()
+        tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
+        transform = transform.to(device=self.device, dtype=self.real_dtype)
+        ts_transform = torch.jit.script(transform)
+
+        if test_pseudo_complex:
+            tensor = torch.view_as_real(tensor)
+
+        output = transform(tensor)
+        ts_output = ts_transform(tensor)
+        self.assertEqual(ts_output, output)
+
+    @parameterized.expand([(True, ), (False, )])
+    def test_TimeStretch(self, test_pseudo_complex):
+        n_freq = 400
+        hop_length = 512
+        fixed_rate = 1.3
+        tensor = torch.view_as_complex(torch.rand((10, 2, n_freq, 10, 2)))
+        self._assert_consistency(
+            T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
+            tensor,
+            test_pseudo_complex
+        )
diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py
index 73496bf4cc..a26d7ab7fe 100644
--- a/torchaudio/functional/functional.py
+++ b/torchaudio/functional/functional.py
@@ -565,14 +565,29 @@ def phase_vocoder(
     factor of ``rate``.

     Args:
-        complex_specgrams (Tensor): Dimension of `(..., freq, time, complex=2)`
+        complex_specgrams (Tensor):
+            Either a real tensor of dimension ``(..., freq, num_frame, complex=2)``
+            or a tensor of dimension ``(..., freq, num_frame)`` with complex dtype.
         rate (float): Speed-up factor
         phase_advance (Tensor): Expected phase advance in each bin. Dimension of (freq, 1)

     Returns:
-        Tensor: Complex Specgrams Stretch with dimension of `(..., freq, ceil(time/rate), complex=2)`
+        Tensor:
+            Stretched spectrogram. The resulting tensor is of the same dtype as the input
+            spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``.

-    Example
+    Example - With Tensor of complex dtype
+        >>> freq, hop_length = 1025, 512
+        >>> # (channel, freq, time)
+        >>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat)
+        >>> rate = 1.3 # Speed up by 30%
+        >>> phase_advance = torch.linspace(
+        >>>     0, math.pi * hop_length, freq)[..., None]
+        >>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
+        >>> x.shape # with 231 == ceil(300 / 1.3)
+        torch.Size([2, 1025, 231])
+
+    Example - With Tensor of real dtype and extra dimension for complex field
         >>> freq, hop_length = 1025, 512
         >>> # (channel, freq, time, complex=2)
         >>> complex_specgrams = torch.randn(2, freq, 300, 2)
@@ -583,32 +598,48 @@
         >>> x.shape # with 231 == ceil(300 / 1.3)
         torch.Size([2, 1025, 231, 2])
     """
+    if rate == 1.0:
+        return complex_specgrams
+
+    if not complex_specgrams.is_complex() and complex_specgrams.size(-1) != 2:
+        raise ValueError(
+            "complex_specgrams must be either native complex tensors or "
+            "real valued tensors with shape (..., 2)")
+
+    is_complex = complex_specgrams.is_complex()
+
+    if not is_complex:
+        complex_specgrams = torch.view_as_complex(complex_specgrams)

     # pack batch
     shape = complex_specgrams.size()
-    complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))
-
-    time_steps = torch.arange(0,
-                              complex_specgrams.size(-2),
-                              rate,
-                              device=complex_specgrams.device,
-                              dtype=complex_specgrams.dtype)
+    complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-2:]))
+
+    # Figures out the corresponding real dtype, i.e. complex128 -> float64, complex64 -> float32
+    # Note: torch.real is a view, so it does not incur any memory copy.
+    real_dtype = torch.real(complex_specgrams).dtype
+    time_steps = torch.arange(
+        0,
+        complex_specgrams.size(-1),
+        rate,
+        device=complex_specgrams.device,
+        dtype=real_dtype)

     alphas = time_steps % 1.0
-    phase_0 = angle(complex_specgrams[..., :1, :])
+    phase_0 = complex_specgrams[..., :1].angle()

     # Time Padding
-    complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 0, 0, 2])
+    complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 2])

-    # (new_bins, freq, 2)
-    complex_specgrams_0 = complex_specgrams.index_select(-2, time_steps.long())
-    complex_specgrams_1 = complex_specgrams.index_select(-2, (time_steps + 1).long())
+    # (batch, freq, new_bins)
+    complex_specgrams_0 = complex_specgrams.index_select(-1, time_steps.long())
+    complex_specgrams_1 = complex_specgrams.index_select(-1, (time_steps + 1).long())

-    angle_0 = angle(complex_specgrams_0)
-    angle_1 = angle(complex_specgrams_1)
+    angle_0 = complex_specgrams_0.angle()
+    angle_1 = complex_specgrams_1.angle()

-    norm_0 = torch.norm(complex_specgrams_0, p=2, dim=-1)
-    norm_1 = torch.norm(complex_specgrams_1, p=2, dim=-1)
+    norm_0 = complex_specgrams_0.abs()
+    norm_1 = complex_specgrams_1.abs()

     phase = angle_1 - angle_0 - phase_advance
     phase = phase - 2 * math.pi * torch.round(phase / (2 * math.pi))
@@ -620,14 +651,13 @@

     mag = alphas * norm_1 + (1 - alphas) * norm_0

-    real_stretch = mag * torch.cos(phase_acc)
-    imag_stretch = mag * torch.sin(phase_acc)
-
-    complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1)
+    complex_specgrams_stretch = torch.polar(mag, phase_acc)

     # unpack batch
-    complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-3] + complex_specgrams_stretch.shape[1:])
+    complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-2] + complex_specgrams_stretch.shape[1:])

+    if not is_complex:
+        return torch.view_as_real(complex_specgrams_stretch)
     return complex_specgrams_stretch
diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py
index 6f3ea1b6e1..e5dd5b2210 100644
--- a/torchaudio/transforms.py
+++ b/torchaudio/transforms.py
@@ -729,26 +729,24 @@ def __init__(self,
     def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] = None) -> Tensor:
         r"""
         Args:
-            complex_specgrams (Tensor): complex spectrogram (..., freq, time, complex=2).
+            complex_specgrams (Tensor):
+                Either a real tensor of dimension ``(..., freq, num_frame, complex=2)``
+                or a tensor of dimension ``(..., freq, num_frame)`` with complex dtype.
             overriding_rate (float or None, optional): speed up to apply to this batch.
                 If no rate is passed, use ``self.fixed_rate``. (Default: ``None``)

         Returns:
-            Tensor: Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2).
+            Tensor:
+                Stretched spectrogram. The resulting tensor is of the same dtype as the input
+                spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``.
         """
-        assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)"
-
         if overriding_rate is None:
+            if self.fixed_rate is None:
+                raise ValueError(
+                    "If no fixed_rate is specified, must pass a valid rate to the forward method.")
             rate = self.fixed_rate
-            if rate is None:
-                raise ValueError("If no fixed_rate is specified"
-                                 ", must pass a valid rate to the forward method.")
         else:
             rate = overriding_rate
-
-        if rate == 1.0:
-            return complex_specgrams
-
         return F.phase_vocoder(complex_specgrams, rate, self.phase_advance)
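For completeness, the following is a minimal standalone sketch of the new API surface. It is not part of the patch: it assumes a torchaudio build with this patch applied, the tensor names are illustrative, and it simply exercises the two input layouts accepted by the updated F.phase_vocoder (native complex and pseudo-complex) and checks that they agree, mirroring the docstring examples above.

    # Minimal usage sketch; assumes the patched torchaudio is installed.
    import math
    import torch
    import torchaudio.functional as F

    freq, hop_length, rate = 1025, 512, 1.3
    phase_advance = torch.linspace(0, math.pi * hop_length, freq)[..., None]

    # Native complex path: (channel, freq, time) with a complex dtype.
    spec = torch.randn(2, freq, 300, dtype=torch.cfloat)
    stretched = F.phase_vocoder(spec, rate, phase_advance)
    print(stretched.shape)  # torch.Size([2, 1025, 231]); 231 == ceil(300 / 1.3)

    # Pseudo-complex path: real tensor with a trailing (real, imag) dimension.
    stretched_pseudo = F.phase_vocoder(torch.view_as_real(spec), rate, phase_advance)
    print(stretched_pseudo.shape)  # torch.Size([2, 1025, 231, 2])

    # Both paths agree once the pseudo-complex result is viewed as complex.
    assert torch.allclose(torch.view_as_complex(stretched_pseudo), stretched)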