Make F.phase_vocoder and T.TimeStretch handle complex dtype (#1410)
1. `F.phase_vocoder` now accepts Tensors with complex dtype.
    * The implementation from #758 has been updated so that complex and pseudo-complex inputs share the same code path: the input Tensor is internally converted to complex dtype and all the operations are performed on it.
    * Adopted `torch.polar` for simpler generation of a complex Tensor from magnitude and angle (see the usage sketch after this list).
2. Updated tests
    * librosa compatibility tests for complex and pseudo-complex dtypes.
        * Extracted the output-shape check and moved it to the functional test suite so that it runs on every combination of `{CPU | CUDA} x {complex64 | complex128}`.
    * TorchScript compatibility tests for `F.phase_vocoder` and `T.TimeStretch`.
    * Batch consistency test for `T.TimeStretch`.
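For illustration, a minimal usage sketch of the new complex path (not part of this commit; the `F.phase_vocoder(spec, rate=..., phase_advance=...)` call and the `torch.polar` construction follow the tests in the diff below):

```python
import math

import torch
import torchaudio.functional as F

rate = 1.3
hop_length = 256
num_freq, num_frames = 1025, 400

# Build a complex spectrogram from magnitude and angle via torch.polar,
# mirroring the simplification mentioned in item 1 above.
magnitude = torch.rand(num_freq, num_frames)
angle = torch.rand(num_freq, num_frames) * 2 * math.pi
spec = torch.polar(magnitude, angle)  # complex64

# Expected per-bin phase advance for the given hop length.
phase_advance = torch.linspace(0, math.pi * hop_length, num_freq)[..., None]

# A complex Tensor is accepted directly; a pseudo-complex Tensor
# (real dtype with a trailing dimension of 2) still works as before.
stretched = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)
print(stretched.shape)  # torch.Size([1025, 308]) == (num_freq, ceil(num_frames / rate))
```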
mthrok authored Apr 2, 2021
1 parent a6cdd6c commit 0433b7a
Showing 13 changed files with 291 additions and 117 deletions.
14 changes: 13 additions & 1 deletion test/torchaudio_unittest/functional/functional_cpu_test.py
@@ -14,7 +14,7 @@
skipIfNoSox,
)

from .functional_impl import Lfilter, Spectrogram
from .functional_impl import Lfilter, Spectrogram, FunctionalComplex


class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
@@ -41,6 +41,18 @@ class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
device = torch.device('cpu')


class TestFunctionalComplex64(FunctionalComplex, common_utils.PytorchTestCase):
complex_dtype = torch.complex64
real_dtype = torch.float32
device = torch.device('cpu')


class TestFunctionalComplex128(FunctionalComplex, common_utils.PytorchTestCase):
complex_dtype = torch.complex128
real_dtype = torch.float64
device = torch.device('cpu')


class TestCreateFBMatrix(common_utils.TorchaudioTestCase):
def test_no_warning_high_n_freq(self):
with warnings.catch_warnings(record=True) as w:
16 changes: 15 additions & 1 deletion test/torchaudio_unittest/functional/functional_cuda_test.py
@@ -2,7 +2,7 @@
import unittest

from torchaudio_unittest import common_utils
from .functional_impl import Lfilter, Spectrogram
from .functional_impl import Lfilter, Spectrogram, FunctionalComplex


@common_utils.skipIfNoCuda
@@ -31,3 +31,17 @@ class TestSpectrogramFloat32(Spectrogram, common_utils.PytorchTestCase):
class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')


@common_utils.skipIfNoCuda
class TestFunctionalComplex64(FunctionalComplex, common_utils.PytorchTestCase):
complex_dtype = torch.complex64
real_dtype = torch.float32
device = torch.device('cuda')


@common_utils.skipIfNoCuda
class TestFunctionalComplex128(FunctionalComplex, common_utils.PytorchTestCase):
complex_dtype = torch.complex128
real_dtype = torch.float64
device = torch.device('cuda')
38 changes: 38 additions & 0 deletions test/torchaudio_unittest/functional/functional_impl.py
@@ -2,9 +2,11 @@
import torch
import torchaudio.functional as F
from parameterized import parameterized
import numpy as np
from scipy import signal

from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import nested_params


class Lfilter(common_utils.TestBaseMixin):
@@ -89,3 +91,39 @@ def test_grad_at_zero(self, power):
)
spec.sum().backward()
assert not x.grad.isnan().sum()


class FunctionalComplex(common_utils.TestBaseMixin):
complex_dtype = None
real_dtype = None
device = None

@nested_params(
[0.5, 1.01, 1.3],
[True, False],
)
def test_phase_vocoder_shape(self, rate, test_pseudo_complex):
"""Verify the output shape of phase vocoder"""
hop_length = 256
num_freq = 1025
num_frames = 400
batch_size = 2

torch.random.manual_seed(42)
spec = torch.randn(
batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device)
if test_pseudo_complex:
spec = torch.view_as_real(spec)

phase_advance = torch.linspace(
0,
np.pi * hop_length,
num_freq,
dtype=self.real_dtype, device=self.device)[..., None]

spec_stretch = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)

assert spec.dim() == spec_stretch.dim()
expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))])
output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape
assert output_shape == expected_shape
53 changes: 23 additions & 30 deletions test/torchaudio_unittest/functional/librosa_compatibility_test.py
@@ -1,4 +1,3 @@
import itertools
import unittest
from distutils.version import StrictVersion

@@ -15,6 +14,9 @@
import librosa

from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import (
nested_params,
)


@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
@@ -130,45 +132,36 @@ def test_resample(self):


@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
-class TestPhaseVocoder(common_utils.TorchaudioTestCase):
-    @parameterized.expand(list(itertools.product(
-        [(2, 1025, 400, 2)],
+class TestFunctionalComplex(common_utils.TorchaudioTestCase):
+    @nested_params(
         [0.5, 1.01, 1.3],
-        [256]
-    )))
-    def test_phase_vocoder(self, shape, rate, hop_length):
+        [True, False],
+    )
+    def test_phase_vocoder(self, rate, test_pseudo_complex):
+        hop_length = 256
+        num_freq = 1025
+        num_frames = 400
+        torch.random.manual_seed(42)
+
         # Due to cumulative sum, numerical error in using torch.float32 will
         # result in the bottom-right values of the stretched spectrogram not
         # matching librosa.
-        torch.random.manual_seed(42)
-        complex_specgrams = torch.randn(*shape)
-        complex_specgrams = complex_specgrams.type(torch.float64)
+        spec = torch.randn(num_freq, num_frames, dtype=torch.complex128)
         phase_advance = torch.linspace(
             0,
             np.pi * hop_length,
-            complex_specgrams.shape[-3],
+            num_freq,
             dtype=torch.float64)[..., None]

-        complex_specgrams_stretch = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance)
+        stretched = F.phase_vocoder(
+            torch.view_as_real(spec) if test_pseudo_complex else spec,
+            rate=rate, phase_advance=phase_advance)

-        # == Test shape
-        expected_size = list(complex_specgrams.size())
-        expected_size[-2] = int(np.ceil(expected_size[-2] / rate))
-
-        assert complex_specgrams.dim() == complex_specgrams_stretch.dim()
-        assert complex_specgrams_stretch.size() == torch.Size(expected_size)
-
-        # == Test values
-        index = [0] * (complex_specgrams.dim() - 3) + [slice(None)] * 3
-        mono_complex_specgram = complex_specgrams[index].numpy()
-        mono_complex_specgram = mono_complex_specgram[..., 0] + \
-            mono_complex_specgram[..., 1] * 1j
-        expected_complex_stretch = librosa.phase_vocoder(
-            mono_complex_specgram,
+        expected_stretched = librosa.phase_vocoder(
+            spec.numpy(),
             rate=rate,
             hop_length=hop_length)

-        complex_stretch = complex_specgrams_stretch[index].numpy()
-        complex_stretch = complex_stretch[..., 0] + 1j * complex_stretch[..., 1]
-
-        self.assertEqual(complex_stretch, torch.from_numpy(expected_complex_stretch), atol=1e-5, rtol=1e-5)
+        self.assertEqual(
+            torch.view_as_complex(stretched) if test_pseudo_complex else stretched,
+            torch.from_numpy(expected_stretched))
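Side note (not part of the diff): the "pseudo-complex" layout exercised by `test_pseudo_complex` above is simply a real Tensor with a trailing dimension of 2 holding the real and imaginary parts, convertible both ways:

```python
import torch

spec = torch.randn(1025, 400, dtype=torch.complex128)
pseudo = torch.view_as_real(spec)  # shape (1025, 400, 2), dtype float64
assert torch.equal(torch.view_as_complex(pseudo), spec)
```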
@@ -1,7 +1,7 @@
import torch

from torchaudio_unittest.common_utils import PytorchTestCase
from .torchscript_consistency_impl import Functional
from .torchscript_consistency_impl import Functional, FunctionalComplex


class TestFunctionalFloat32(Functional, PytorchTestCase):
@@ -12,3 +12,15 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
class TestFunctionalFloat64(Functional, PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')


class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase):
complex_dtype = torch.complex64
real_dtype = torch.float32
device = torch.device('cpu')


class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase):
complex_dtype = torch.complex128
real_dtype = torch.float64
device = torch.device('cpu')
@@ -1,7 +1,7 @@
import torch

from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from .torchscript_consistency_impl import Functional
from .torchscript_consistency_impl import Functional, FunctionalComplex


@skipIfNoCuda
@@ -14,3 +14,17 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
class TestFunctionalFloat64(Functional, PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')


@skipIfNoCuda
class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase):
complex_dtype = torch.complex64
real_dtype = torch.float32
device = torch.device('cuda')


@skipIfNoCuda
class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase):
complex_dtype = torch.complex128
real_dtype = torch.float64
device = torch.device('cuda')
53 changes: 38 additions & 15 deletions test/torchaudio_unittest/functional/torchscript_consistency_impl.py
@@ -3,6 +3,7 @@

import torch
import torchaudio.functional as F
from parameterized import parameterized

from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import (
@@ -551,21 +552,6 @@ def func(tensor):
tensor = common_utils.get_whitenoise(sample_rate=44100)
self._assert_consistency(func, tensor)

def test_phase_vocoder(self):
def func(tensor, device: torch.device = self.device):
rate = 0.5
hop_length = 256
phase_advance = torch.linspace(
0,
3.14 * hop_length,
tensor.shape[-3],
dtype=torch.float64,
).to(device)[..., None]
return F.phase_vocoder(tensor, rate, phase_advance)

tensor = torch.randn(2, 1025, 400, 2)
self._assert_consistency(func, tensor)

@common_utils.skipIfNoKaldi
def test_compute_kaldi_pitch(self):
if self.dtype != torch.float32 or self.device != torch.device('cpu'):
@@ -577,3 +563,40 @@ def func(tensor):

tensor = common_utils.get_whitenoise(sample_rate=44100)
self._assert_consistency(func, tensor)


class FunctionalComplex:
complex_dtype = None
real_dtype = None
device = None

def _assert_consistency(self, func, tensor, test_pseudo_complex=False):
assert tensor.is_complex()
tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
ts_func = torch.jit.script(func)

if test_pseudo_complex:
tensor = torch.view_as_real(tensor)
output = func(tensor)
ts_output = ts_func(tensor)
self.assertEqual(ts_output, output)

@parameterized.expand([(True, ), (False, )])
def test_phase_vocoder(self, test_pseudo_complex):
def func(tensor):
is_complex = tensor.is_complex()

n_freq = tensor.size(-2 if is_complex else -3)
rate = 0.5
hop_length = 256
phase_advance = torch.linspace(
0,
3.14 * hop_length,
n_freq,
dtype=(torch.real(tensor) if is_complex else tensor).dtype,
device=tensor.device,
)[..., None]
return F.phase_vocoder(tensor, rate, phase_advance)

tensor = torch.view_as_complex(torch.randn(2, 1025, 400, 2))
self._assert_consistency(func, tensor, test_pseudo_complex)
36 changes: 14 additions & 22 deletions test/torchaudio_unittest/transforms/batch_consistency_test.py
@@ -1,6 +1,7 @@
"""Test numerical consistency among single input and batched input."""
import torch
import torchaudio
from parameterized import parameterized

from torchaudio_unittest import common_utils

@@ -130,40 +131,31 @@ def test_batch_mfcc(self):
computed = torchaudio.transforms.MFCC()(waveform.repeat(3, 1, 1))
self.assertEqual(computed, expected, atol=1e-4, rtol=1e-5)

-    def test_batch_TimeStretch(self):
-        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
-        waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100
-
+    @parameterized.expand([(True, ), (False, )])
+    def test_batch_TimeStretch(self, test_pseudo_complex):
         rate = 2
+        num_freq = 1025
+        num_frames = 400

-        complex_specgrams = torch.view_as_real(
-            torch.stft(
-                input=waveform,
-                n_fft=2048,
-                hop_length=512,
-                win_length=2048,
-                window=torch.hann_window(2048),
-                center=True,
-                pad_mode='reflect',
-                normalized=True,
-                onesided=True,
-                return_complex=True,
-            )
-        )
+        spec = torch.randn(num_freq, num_frames, dtype=torch.complex64)
+        pattern = [3, 1, 1, 1]
+        if test_pseudo_complex:
+            spec = torch.view_as_real(spec)
+            pattern += [1]

         # Single then transform then batch
         expected = torchaudio.transforms.TimeStretch(
             fixed_rate=rate,
-            n_freq=1025,
+            n_freq=num_freq,
             hop_length=512,
-        )(complex_specgrams).repeat(3, 1, 1, 1, 1)
+        )(spec).repeat(*pattern)

         # Batch then transform
         computed = torchaudio.transforms.TimeStretch(
             fixed_rate=rate,
-            n_freq=1025,
+            n_freq=num_freq,
             hop_length=512,
-        )(complex_specgrams.repeat(3, 1, 1, 1, 1))
+        )(spec.repeat(*pattern))

self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5)

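For reference, a short sketch (not part of this commit) of the transform-level API exercised above: `T.TimeStretch` applied directly to a complex spectrogram, with the same constructor arguments as the batch consistency test.

```python
import torch
import torchaudio

num_freq, num_frames = 1025, 400
spec = torch.randn(num_freq, num_frames, dtype=torch.complex64)

stretch = torchaudio.transforms.TimeStretch(fixed_rate=2.0, n_freq=num_freq, hop_length=512)
stretched = stretch(spec)  # complex output with ceil(num_frames / 2.0) frames
print(stretched.shape)     # torch.Size([1025, 200])
```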
@@ -1,7 +1,7 @@
import torch

from torchaudio_unittest.common_utils import PytorchTestCase
from .torchscript_consistency_impl import Transforms
from .torchscript_consistency_impl import Transforms, TransformsComplex


class TestTransformsFloat32(Transforms, PytorchTestCase):
@@ -12,3 +12,15 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
class TestTransformsFloat64(Transforms, PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')


class TestTransformsComplex64(TransformsComplex, PytorchTestCase):
complex_dtype = torch.complex64
real_dtype = torch.float32
device = torch.device('cpu')


class TestTransformsComplex128(TransformsComplex, PytorchTestCase):
complex_dtype = torch.complex128
real_dtype = torch.float64
device = torch.device('cpu')