Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make F.phase_vocoder and T.TimeStretch handle complex dtype #1410

Merged
merged 8 commits into from
Apr 2, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion test/torchaudio_unittest/functional/functional_cpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
skipIfNoSox,
)

from .functional_impl import Lfilter, Spectrogram
from .functional_impl import Lfilter, Spectrogram, FunctionalComplex


class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
Expand All @@ -41,6 +41,18 @@ class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
device = torch.device('cpu')


class TestFunctionalComplex64(FunctionalComplex, common_utils.PytorchTestCase):
    """Run the shared ``FunctionalComplex`` tests on CPU in single precision."""

    device = torch.device('cpu')
    # Complex dtype of the input and the real dtype it is paired with.
    complex_dtype = torch.complex64
    real_dtype = torch.float32


class TestFunctionalComplex128(FunctionalComplex, common_utils.PytorchTestCase):
    """Run the shared ``FunctionalComplex`` tests on CPU in double precision."""

    device = torch.device('cpu')
    # Complex dtype of the input and the real dtype it is paired with.
    complex_dtype = torch.complex128
    real_dtype = torch.float64


class TestCreateFBMatrix(common_utils.TorchaudioTestCase):
def test_no_warning_high_n_freq(self):
with warnings.catch_warnings(record=True) as w:
Expand Down
16 changes: 15 additions & 1 deletion test/torchaudio_unittest/functional/functional_cuda_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import unittest

from torchaudio_unittest import common_utils
from .functional_impl import Lfilter, Spectrogram
from .functional_impl import Lfilter, Spectrogram, FunctionalComplex


@common_utils.skipIfNoCuda
Expand Down Expand Up @@ -31,3 +31,17 @@ class TestSpectrogramFloat32(Spectrogram, common_utils.PytorchTestCase):
class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')


@common_utils.skipIfNoCuda
class TestFunctionalComplex64(FunctionalComplex, common_utils.PytorchTestCase):
    """Run the shared ``FunctionalComplex`` tests on CUDA in single precision."""

    device = torch.device('cuda')
    # Complex dtype of the input and the real dtype it is paired with.
    complex_dtype = torch.complex64
    real_dtype = torch.float32


@common_utils.skipIfNoCuda
class TestFunctionalComplex128(FunctionalComplex, common_utils.PytorchTestCase):
    """Run the shared ``FunctionalComplex`` tests on CUDA in double precision."""

    # These must be the double-precision dtypes: with complex64/float32 this
    # class would only repeat TestFunctionalComplex64 and complex128 would
    # never be exercised on CUDA.
    complex_dtype = torch.complex128
    real_dtype = torch.float64
    device = torch.device('cuda')
38 changes: 38 additions & 0 deletions test/torchaudio_unittest/functional/functional_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import torch
import torchaudio.functional as F
from parameterized import parameterized
import numpy as np
from scipy import signal

from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import nested_params


class Lfilter(common_utils.TestBaseMixin):
Expand Down Expand Up @@ -89,3 +91,39 @@ def test_grad_at_zero(self, power):
)
spec.sum().backward()
assert not x.grad.isnan().sum()


class FunctionalComplex(common_utils.TestBaseMixin):
complex_dtype = None
real_dtype = None
device = None

@nested_params(
[0.5, 1.01, 1.3],
[True, False],
)
def test_phase_vocoder_shape(self, rate, test_pseudo_complex):
"""Verify the output shape of phase vocoder"""
hop_length = 256
num_freq = 1025
num_frames = 400
batch_size = 2

torch.random.manual_seed(42)
spec = torch.randn(
batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device)
if test_pseudo_complex:
spec = torch.view_as_real(spec)

phase_advance = torch.linspace(
0,
np.pi * hop_length,
num_freq,
dtype=self.real_dtype, device=self.device)[..., None]

spec_stretch = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)

assert spec.dim() == spec_stretch.dim()
expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW torch also supports ceil

output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape
assert output_shape == expected_shape
60 changes: 29 additions & 31 deletions test/torchaudio_unittest/functional/librosa_compatibility_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import itertools
import unittest
from distutils.version import StrictVersion

Expand Down Expand Up @@ -130,45 +129,44 @@ def test_resample(self):


@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class TestPhaseVocoder(common_utils.TorchaudioTestCase):
@parameterized.expand(list(itertools.product(
[(2, 1025, 400, 2)],
[0.5, 1.01, 1.3],
[256]
)))
def test_phase_vocoder(self, shape, rate, hop_length):
class TestFunctionalComplex(common_utils.TorchaudioTestCase):
def _test_phase_vocoder(self, rate, test_pseudo_complex=False):
hop_length = 256
num_freq = 1025
num_frames = 400
torch.random.manual_seed(42)

# Due to the cumulative sum, numerical error when using torch.float32
# causes the bottom-right values of the stretched spectrogram to not
# match librosa.
torch.random.manual_seed(42)
complex_specgrams = torch.randn(*shape)
complex_specgrams = complex_specgrams.type(torch.float64)
spec = torch.randn(num_freq, num_frames, dtype=torch.complex128)
phase_advance = torch.linspace(
0,
np.pi * hop_length,
complex_specgrams.shape[-3],
num_freq,
dtype=torch.float64)[..., None]

complex_specgrams_stretch = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance)
stretched = F.phase_vocoder(
torch.view_as_real(spec) if test_pseudo_complex else spec,
rate=rate, phase_advance=phase_advance)

# == Test shape
expected_size = list(complex_specgrams.size())
expected_size[-2] = int(np.ceil(expected_size[-2] / rate))

assert complex_specgrams.dim() == complex_specgrams_stretch.dim()
assert complex_specgrams_stretch.size() == torch.Size(expected_size)

# == Test values
index = [0] * (complex_specgrams.dim() - 3) + [slice(None)] * 3
mono_complex_specgram = complex_specgrams[index].numpy()
mono_complex_specgram = mono_complex_specgram[..., 0] + \
mono_complex_specgram[..., 1] * 1j
expected_complex_stretch = librosa.phase_vocoder(
mono_complex_specgram,
expected_stretched = librosa.phase_vocoder(
spec.numpy(),
rate=rate,
hop_length=hop_length)

complex_stretch = complex_specgrams_stretch[index].numpy()
complex_stretch = complex_stretch[..., 0] + 1j * complex_stretch[..., 1]

self.assertEqual(complex_stretch, torch.from_numpy(expected_complex_stretch), atol=1e-5, rtol=1e-5)
self.assertEqual(
torch.view_as_complex(stretched) if test_pseudo_complex else stretched,
torch.from_numpy(expected_stretched))

@parameterized.expand(
[(0.5, ), (1.01, ), (1.3, ), ],
)
def test_phase_vocoder(self, rate):
self._test_phase_vocoder(rate)

@parameterized.expand(
[(0.5, ), (1.01, ), (1.3, ), ],
)
def test_phase_vocoder_pseudo_complex(self, rate):
self._test_phase_vocoder(rate, test_pseudo_complex=True)
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch

from torchaudio_unittest.common_utils import PytorchTestCase
from .torchscript_consistency_impl import Functional
from .torchscript_consistency_impl import Functional, FunctionalComplex


class TestFunctionalFloat32(Functional, PytorchTestCase):
Expand All @@ -12,3 +12,15 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
class TestFunctionalFloat64(Functional, PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')


class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase):
    """Run the TorchScript ``FunctionalComplex`` tests on CPU in single precision."""

    device = torch.device('cpu')
    # Complex dtype of the input and the real dtype it is paired with.
    complex_dtype = torch.complex64
    real_dtype = torch.float32


class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase):
    """Run the TorchScript ``FunctionalComplex`` tests on CPU in double precision."""

    device = torch.device('cpu')
    # Complex dtype of the input and the real dtype it is paired with.
    complex_dtype = torch.complex128
    real_dtype = torch.float64
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch

from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from .torchscript_consistency_impl import Functional
from .torchscript_consistency_impl import Functional, FunctionalComplex


@skipIfNoCuda
Expand All @@ -14,3 +14,17 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
class TestFunctionalFloat64(Functional, PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')


@skipIfNoCuda
class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase):
    """Run the TorchScript ``FunctionalComplex`` tests on CUDA in single precision."""

    device = torch.device('cuda')
    # Complex dtype of the input and the real dtype it is paired with.
    complex_dtype = torch.complex64
    real_dtype = torch.float32


@skipIfNoCuda
class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase):
    """Run the TorchScript ``FunctionalComplex`` tests on CUDA in double precision."""

    device = torch.device('cuda')
    # Complex dtype of the input and the real dtype it is paired with.
    complex_dtype = torch.complex128
    real_dtype = torch.float64
52 changes: 37 additions & 15 deletions test/torchaudio_unittest/functional/torchscript_consistency_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch
import torchaudio.functional as F
from parameterized import parameterized

from torchaudio_unittest import common_utils

Expand Down Expand Up @@ -547,21 +548,6 @@ def func(tensor):
tensor = common_utils.get_whitenoise(sample_rate=44100)
self._assert_consistency(func, tensor)

def test_phase_vocoder(self):
def func(tensor, device: torch.device = self.device):
rate = 0.5
hop_length = 256
phase_advance = torch.linspace(
0,
3.14 * hop_length,
tensor.shape[-3],
dtype=torch.float64,
).to(device)[..., None]
return F.phase_vocoder(tensor, rate, phase_advance)

tensor = torch.randn(2, 1025, 400, 2)
self._assert_consistency(func, tensor)

@common_utils.skipIfNoKaldi
def test_compute_kaldi_pitch(self):
if self.dtype != torch.float32 or self.device != torch.device('cpu'):
Expand All @@ -573,3 +559,39 @@ def func(tensor):

tensor = common_utils.get_whitenoise(sample_rate=44100)
self._assert_consistency(func, tensor)


class FunctionalComplex:
    """TorchScript consistency tests for functionals that take complex input.

    Concrete subclasses set ``complex_dtype``, ``real_dtype`` and ``device``,
    and combine this mixin with a test-case base class that supplies
    ``assertEqual``.
    """
    complex_dtype = None
    real_dtype = None
    device = None

    def _assert_consistency(self, func, tensor, test_pseudo_complex=False):
        """Assert that ``func`` and its scripted version return equal outputs.

        Args:
            func: Callable taking a single tensor; it is compiled with
                ``torch.jit.script``.
            tensor: Input tensor. It is moved to ``self.device`` and cast to
                ``self.complex_dtype`` before use.
            test_pseudo_complex: If ``True``, the complex tensor is converted
                with ``torch.view_as_real`` before being passed to ``func``.
        """
        tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
        ts_func = torch.jit.script(func)

        if test_pseudo_complex:
            tensor = torch.view_as_real(tensor)
        output = func(tensor)
        ts_output = ts_func(tensor)
        self.assertEqual(ts_output, output)

    @parameterized.expand([(True, ), (False, )])
    def test_phase_vocoder(self, test_pseudo_complex):
        """Check ``F.phase_vocoder`` eager/scripted consistency.

        Runs once with the pseudo-complex (``view_as_real``) representation
        and once with a native complex tensor.
        """
        def func(tensor):
            is_complex = tensor.is_complex()

            # The frequency axis is second-to-last for native complex input
            # and third-to-last when a trailing real/imag axis is present.
            n_freq = tensor.size(-2 if is_complex else -3)
            rate = 0.5
            hop_length = 256
            phase_advance = torch.linspace(
                0,
                3.14 * hop_length,
                n_freq,
                # Use the real dtype that corresponds to the input tensor.
                dtype=(torch.real(tensor) if is_complex else tensor).dtype,
                device=tensor.device,
            )[..., None]
            return F.phase_vocoder(tensor, rate, phase_advance)

        tensor = torch.randn(2, 1025, 400)
        self._assert_consistency(func, tensor, test_pseudo_complex)
35 changes: 21 additions & 14 deletions test/torchaudio_unittest/transforms/batch_consistency_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,27 +130,28 @@ def test_batch_mfcc(self):
computed = torchaudio.transforms.MFCC()(waveform.repeat(3, 1, 1))
self.assertEqual(computed, expected, atol=1e-4, rtol=1e-5)

def test_batch_TimeStretch(self):
def _assert_batch_TimeStretch(self, complex):
mthrok marked this conversation as resolved.
Show resolved Hide resolved
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100

rate = 2

complex_specgrams = torch.view_as_real(
torch.stft(
input=waveform,
n_fft=2048,
hop_length=512,
win_length=2048,
window=torch.hann_window(2048),
center=True,
pad_mode='reflect',
normalized=True,
onesided=True,
return_complex=True,
)
complex_specgrams = torch.stft(
input=waveform,
n_fft=2048,
hop_length=512,
win_length=2048,
window=torch.hann_window(2048),
center=True,
pad_mode='reflect',
normalized=True,
onesided=True,
return_complex=True,
)

if not complex:
complex_specgrams = torch.view_as_real(complex_specgrams)

# Single then transform then batch
expected = torchaudio.transforms.TimeStretch(
fixed_rate=rate,
Expand All @@ -167,6 +168,12 @@ def test_batch_TimeStretch(self):

self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5)

def test_batch_TimeStretch_complex(self):
self._assert_batch_TimeStretch(complex=True)

def test_batch_TimeStretch_paseudo_complex(self):
self._assert_batch_TimeStretch(complex=False)

def test_batch_Fade(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch

from torchaudio_unittest.common_utils import PytorchTestCase
from .torchscript_consistency_impl import Transforms
from .torchscript_consistency_impl import Transforms, TransformsComplex


class TestTransformsFloat32(Transforms, PytorchTestCase):
Expand All @@ -12,3 +12,15 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
class TestTransformsFloat64(Transforms, PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')


class TestTransformsComplex64(TransformsComplex, PytorchTestCase):
    """Run the TorchScript ``TransformsComplex`` tests on CPU in single precision."""

    device = torch.device('cpu')
    # Complex dtype of the input and the real dtype it is paired with.
    complex_dtype = torch.complex64
    real_dtype = torch.float32


class TestTransformsComplex128(TransformsComplex, PytorchTestCase):
    """Run the TorchScript ``TransformsComplex`` tests on CPU in double precision."""

    device = torch.device('cpu')
    # Complex dtype of the input and the real dtype it is paired with.
    complex_dtype = torch.complex128
    real_dtype = torch.float64
Loading