Revert "Make F.phase_vocoder and T.TimeStretch handle complex dty…
Browse files Browse the repository at this point in the history
…pe (pytorch#1410)"

This reverts commit 0433b7a.
mthrok committed Apr 5, 2021
1 parent 8ef832f commit 554383e
Showing 13 changed files with 117 additions and 291 deletions.
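The change being reverted (pytorch#1410) had taught F.phase_vocoder and T.TimeStretch to accept native complex tensors; with this revert, the ops again take the pseudo-complex layout, where a trailing dimension of size 2 carries the real and imaginary parts. A minimal sketch of the call pattern the restored tests below exercise (shapes and parameter values come from the diffs; math.pi stands in for the np.pi / 3.14 constants the tests use):

import math

import torch
import torchaudio.functional as F

rate, hop_length, n_freq = 0.5, 256, 1025

# Pseudo-complex spectrogram: (batch, freq, time, real/imag)
spec = torch.randn(2, n_freq, 400, 2)

# Per-bin expected phase advance per hop, shape (freq, 1)
phase_advance = torch.linspace(0, math.pi * hop_length, n_freq)[..., None]

stretched = F.phase_vocoder(spec, rate, phase_advance)

# The time axis is rescaled by 1/rate: ceil(400 / 0.5) == 800 frames
assert stretched.shape == (2, n_freq, 800, 2)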
14 changes: 1 addition & 13 deletions test/torchaudio_unittest/functional/functional_cpu_test.py
@@ -14,7 +14,7 @@
     skipIfNoSox,
 )
 
-from .functional_impl import Lfilter, Spectrogram, FunctionalComplex
+from .functional_impl import Lfilter, Spectrogram
 
 
 class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
@@ -41,18 +41,6 @@ class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
     device = torch.device('cpu')
 
 
-class TestFunctionalComplex64(FunctionalComplex, common_utils.PytorchTestCase):
-    complex_dtype = torch.complex64
-    real_dtype = torch.float32
-    device = torch.device('cpu')
-
-
-class TestFunctionalComplex128(FunctionalComplex, common_utils.PytorchTestCase):
-    complex_dtype = torch.complex128
-    real_dtype = torch.float64
-    device = torch.device('cpu')
-
-
 class TestCreateFBMatrix(common_utils.TorchaudioTestCase):
     def test_no_warning_high_n_freq(self):
         with warnings.catch_warnings(record=True) as w:
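The classes removed above follow the suite's mixin parameterization: one implementation class defines the tests, and small subclasses pin the dtype and device combination to run. A generic sketch of the pattern, with illustrative names rather than the suite's actual base classes:

import unittest

import torch

class MixinExample:
    dtype = torch.float32
    device = torch.device('cpu')

    def test_zeros(self):
        # Runs once per concrete subclass, on that subclass's dtype/device
        x = torch.zeros(4, dtype=self.dtype, device=self.device)
        assert x.dtype == self.dtype

class TestFloat64Cpu(MixinExample, unittest.TestCase):
    dtype = torch.float64
    device = torch.device('cpu')

if __name__ == '__main__':
    unittest.main()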
16 changes: 1 addition & 15 deletions test/torchaudio_unittest/functional/functional_cuda_test.py
@@ -2,7 +2,7 @@
 import unittest
 
 from torchaudio_unittest import common_utils
-from .functional_impl import Lfilter, Spectrogram, FunctionalComplex
+from .functional_impl import Lfilter, Spectrogram
 
 
 @common_utils.skipIfNoCuda
@@ -31,17 +31,3 @@ class TestSpectrogramFloat32(Spectrogram, common_utils.PytorchTestCase):
 class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
     dtype = torch.float64
     device = torch.device('cuda')
-
-
-@common_utils.skipIfNoCuda
-class TestFunctionalComplex64(FunctionalComplex, common_utils.PytorchTestCase):
-    complex_dtype = torch.complex64
-    real_dtype = torch.float32
-    device = torch.device('cuda')
-
-
-@common_utils.skipIfNoCuda
-class TestFunctionalComplex128(FunctionalComplex, common_utils.PytorchTestCase):
-    complex_dtype = torch.complex64
-    real_dtype = torch.float32
-    device = torch.device('cuda')
38 changes: 0 additions & 38 deletions test/torchaudio_unittest/functional/functional_impl.py
@@ -2,11 +2,9 @@
 import torch
 import torchaudio.functional as F
 from parameterized import parameterized
-import numpy as np
 from scipy import signal
 
 from torchaudio_unittest import common_utils
-from torchaudio_unittest.common_utils import nested_params
 
 
 class Lfilter(common_utils.TestBaseMixin):
@@ -91,39 +89,3 @@ def test_grad_at_zero(self, power):
         )
         spec.sum().backward()
         assert not x.grad.isnan().sum()
-
-
-class FunctionalComplex(common_utils.TestBaseMixin):
-    complex_dtype = None
-    real_dtype = None
-    device = None
-
-    @nested_params(
-        [0.5, 1.01, 1.3],
-        [True, False],
-    )
-    def test_phase_vocoder_shape(self, rate, test_pseudo_complex):
-        """Verify the output shape of phase vocoder"""
-        hop_length = 256
-        num_freq = 1025
-        num_frames = 400
-        batch_size = 2
-
-        torch.random.manual_seed(42)
-        spec = torch.randn(
-            batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device)
-        if test_pseudo_complex:
-            spec = torch.view_as_real(spec)
-
-        phase_advance = torch.linspace(
-            0,
-            np.pi * hop_length,
-            num_freq,
-            dtype=self.real_dtype, device=self.device)[..., None]
-
-        spec_stretch = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)
-
-        assert spec.dim() == spec_stretch.dim()
-        expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))])
-        output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape
-        assert output_shape == expected_shape
53 changes: 30 additions & 23 deletions test/torchaudio_unittest/functional/librosa_compatibility_test.py
@@ -1,3 +1,4 @@
+import itertools
 import unittest
 from distutils.version import StrictVersion

@@ -14,9 +15,6 @@
     import librosa
 
 from torchaudio_unittest import common_utils
-from torchaudio_unittest.common_utils import (
-    nested_params,
-)
 
 
 @unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
@@ -113,36 +111,45 @@ def test_amplitude_to_DB(self):
 
 
 @unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
-class TestFunctionalComplex(common_utils.TorchaudioTestCase):
-    @nested_params(
+class TestPhaseVocoder(common_utils.TorchaudioTestCase):
+    @parameterized.expand(list(itertools.product(
+        [(2, 1025, 400, 2)],
         [0.5, 1.01, 1.3],
-        [True, False],
-    )
-    def test_phase_vocoder(self, rate, test_pseudo_complex):
-        hop_length = 256
-        num_freq = 1025
-        num_frames = 400
-        torch.random.manual_seed(42)
-
+        [256]
+    )))
+    def test_phase_vocoder(self, shape, rate, hop_length):
         # Due to cummulative sum, numerical error in using torch.float32 will
         # result in bottom right values of the stretched sectrogram to not
         # match with librosa.
-        spec = torch.randn(num_freq, num_frames, dtype=torch.complex128)
+        torch.random.manual_seed(42)
+        complex_specgrams = torch.randn(*shape)
+        complex_specgrams = complex_specgrams.type(torch.float64)
         phase_advance = torch.linspace(
             0,
             np.pi * hop_length,
-            num_freq,
+            complex_specgrams.shape[-3],
             dtype=torch.float64)[..., None]
 
-        stretched = F.phase_vocoder(
-            torch.view_as_real(spec) if test_pseudo_complex else spec,
-            rate=rate, phase_advance=phase_advance)
+        complex_specgrams_stretch = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance)
 
-        expected_stretched = librosa.phase_vocoder(
-            spec.numpy(),
+        # == Test shape
+        expected_size = list(complex_specgrams.size())
+        expected_size[-2] = int(np.ceil(expected_size[-2] / rate))
+
+        assert complex_specgrams.dim() == complex_specgrams_stretch.dim()
+        assert complex_specgrams_stretch.size() == torch.Size(expected_size)
+
+        # == Test values
+        index = [0] * (complex_specgrams.dim() - 3) + [slice(None)] * 3
+        mono_complex_specgram = complex_specgrams[index].numpy()
+        mono_complex_specgram = mono_complex_specgram[..., 0] + \
+            mono_complex_specgram[..., 1] * 1j
+        expected_complex_stretch = librosa.phase_vocoder(
+            mono_complex_specgram,
             rate=rate,
             hop_length=hop_length)
 
-        self.assertEqual(
-            torch.view_as_complex(stretched) if test_pseudo_complex else stretched,
-            torch.from_numpy(expected_stretched))
+        complex_stretch = complex_specgrams_stretch[index].numpy()
+        complex_stretch = complex_stretch[..., 0] + 1j * complex_stretch[..., 1]
+
+        self.assertEqual(complex_stretch, torch.from_numpy(expected_complex_stretch), atol=1e-5, rtol=1e-5)
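The restored test moves between layouts by hand on the numpy side (x[..., 0] + 1j * x[..., 1]) before calling librosa. On the torch side, torch.view_as_real / torch.view_as_complex perform the same reinterpretation without copying; a small sketch of the round trip, with shapes chosen to match the test above:

import torch

# (freq, time, real/imag) pseudo-complex block, as in the restored test
x = torch.randn(1025, 400, 2, dtype=torch.float64)

c = torch.view_as_complex(x)  # shape (1025, 400), complex128; shares storage with x
assert torch.equal(torch.view_as_real(c), x)

# numpy-side equivalent used when handing data to librosa
n = x.numpy()
assert (n[..., 0] + 1j * n[..., 1] == c.numpy()).all()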
test/torchaudio_unittest/functional/torchscript_consistency_cpu_test.py
@@ -1,7 +1,7 @@
 import torch
 
 from torchaudio_unittest.common_utils import PytorchTestCase
-from .torchscript_consistency_impl import Functional, FunctionalComplex
+from .torchscript_consistency_impl import Functional
 
 
 class TestFunctionalFloat32(Functional, PytorchTestCase):
@@ -12,15 +12,3 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
 class TestFunctionalFloat64(Functional, PytorchTestCase):
     dtype = torch.float64
     device = torch.device('cpu')
-
-
-class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase):
-    complex_dtype = torch.complex64
-    real_dtype = torch.float32
-    device = torch.device('cpu')
-
-
-class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase):
-    complex_dtype = torch.complex128
-    real_dtype = torch.float64
-    device = torch.device('cpu')
test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py
@@ -1,7 +1,7 @@
 import torch
 
 from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
-from .torchscript_consistency_impl import Functional, FunctionalComplex
+from .torchscript_consistency_impl import Functional
 
 
 @skipIfNoCuda
@@ -14,17 +14,3 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
 class TestFunctionalFloat64(Functional, PytorchTestCase):
     dtype = torch.float64
     device = torch.device('cuda')
-
-
-@skipIfNoCuda
-class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase):
-    complex_dtype = torch.complex64
-    real_dtype = torch.float32
-    device = torch.device('cuda')
-
-
-@skipIfNoCuda
-class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase):
-    complex_dtype = torch.complex128
-    real_dtype = torch.float64
-    device = torch.device('cuda')
53 changes: 15 additions & 38 deletions test/torchaudio_unittest/functional/torchscript_consistency_impl.py
@@ -3,7 +3,6 @@
 
 import torch
 import torchaudio.functional as F
-from parameterized import parameterized
 
 from torchaudio_unittest import common_utils
 from torchaudio_unittest.common_utils import (
@@ -552,6 +551,21 @@ def func(tensor):
         tensor = common_utils.get_whitenoise(sample_rate=44100)
         self._assert_consistency(func, tensor)
 
+    def test_phase_vocoder(self):
+        def func(tensor, device: torch.device = self.device):
+            rate = 0.5
+            hop_length = 256
+            phase_advance = torch.linspace(
+                0,
+                3.14 * hop_length,
+                tensor.shape[-3],
+                dtype=torch.float64,
+            ).to(device)[..., None]
+            return F.phase_vocoder(tensor, rate, phase_advance)
+
+        tensor = torch.randn(2, 1025, 400, 2)
+        self._assert_consistency(func, tensor)
+
     @common_utils.skipIfNoKaldi
     def test_compute_kaldi_pitch(self):
         if self.dtype != torch.float32 or self.device != torch.device('cpu'):
@@ -563,40 +577,3 @@ def func(tensor):
 
         tensor = common_utils.get_whitenoise(sample_rate=44100)
         self._assert_consistency(func, tensor)
-
-
-class FunctionalComplex:
-    complex_dtype = None
-    real_dtype = None
-    device = None
-
-    def _assert_consistency(self, func, tensor, test_pseudo_complex=False):
-        assert tensor.is_complex()
-        tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
-        ts_func = torch.jit.script(func)
-
-        if test_pseudo_complex:
-            tensor = torch.view_as_real(tensor)
-        output = func(tensor)
-        ts_output = ts_func(tensor)
-        self.assertEqual(ts_output, output)
-
-    @parameterized.expand([(True, ), (False, )])
-    def test_phase_vocoder(self, test_paseudo_complex):
-        def func(tensor):
-            is_complex = tensor.is_complex()
-
-            n_freq = tensor.size(-2 if is_complex else -3)
-            rate = 0.5
-            hop_length = 256
-            phase_advance = torch.linspace(
-                0,
-                3.14 * hop_length,
-                n_freq,
-                dtype=(torch.real(tensor) if is_complex else tensor).dtype,
-                device=tensor.device,
-            )[..., None]
-            return F.phase_vocoder(tensor, rate, phase_advance)
-
-        tensor = torch.view_as_complex(torch.randn(2, 1025, 400, 2))
-        self._assert_consistency(func, tensor, test_paseudo_complex)
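Both the removed class above and the restored Functional.test_phase_vocoder rely on an _assert_consistency helper; the suite's pattern is to script the function with torch.jit and check that eager and scripted execution agree. Roughly, as a sketch rather than the suite's exact helper:

import torch
import torchaudio.functional as F

def func(tensor: torch.Tensor) -> torch.Tensor:
    rate = 0.5
    hop_length = 256
    phase_advance = torch.linspace(
        0., 3.14 * hop_length, tensor.shape[-3], dtype=torch.float64)[..., None]
    return F.phase_vocoder(tensor, rate, phase_advance)

ts_func = torch.jit.script(func)  # the function must compile under TorchScript

tensor = torch.randn(2, 1025, 400, 2)  # pseudo-complex input
assert torch.allclose(func(tensor), ts_func(tensor))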
36 changes: 22 additions & 14 deletions test/torchaudio_unittest/transforms/batch_consistency_test.py
@@ -1,7 +1,6 @@
 """Test numerical consistency among single input and batched input."""
 import torch
 import torchaudio
-from parameterized import parameterized
 
 from torchaudio_unittest import common_utils

@@ -131,31 +130,40 @@ def test_batch_mfcc(self):
         computed = torchaudio.transforms.MFCC()(waveform.repeat(3, 1, 1))
         self.assertEqual(computed, expected, atol=1e-4, rtol=1e-5)
 
-    @parameterized.expand([(True, ), (False, )])
-    def test_batch_TimeStretch(self, test_pseudo_complex):
+    def test_batch_TimeStretch(self):
+        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
+        waveform, _ = torchaudio.load(test_filepath)  # (2, 278756), 44100
+
         rate = 2
-        num_freq = 1025
-        num_frames = 400
 
-        spec = torch.randn(num_freq, num_frames, dtype=torch.complex64)
-        pattern = [3, 1, 1, 1]
-        if test_pseudo_complex:
-            spec = torch.view_as_real(spec)
-            pattern += [1]
+        complex_specgrams = torch.view_as_real(
+            torch.stft(
+                input=waveform,
+                n_fft=2048,
+                hop_length=512,
+                win_length=2048,
+                window=torch.hann_window(2048),
+                center=True,
+                pad_mode='reflect',
+                normalized=True,
+                onesided=True,
+                return_complex=True,
+            )
+        )
 
         # Single then transform then batch
         expected = torchaudio.transforms.TimeStretch(
             fixed_rate=rate,
-            n_freq=num_freq,
+            n_freq=1025,
             hop_length=512,
-        )(spec).repeat(*pattern)
+        )(complex_specgrams).repeat(3, 1, 1, 1, 1)
 
         # Batch then transform
         computed = torchaudio.transforms.TimeStretch(
             fixed_rate=rate,
-            n_freq=num_freq,
+            n_freq=1025,
             hop_length=512,
-        )(spec.repeat(*pattern))
+        )(complex_specgrams.repeat(3, 1, 1, 1, 1))
 
         self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5)

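The restored batch test asserts that T.TimeStretch commutes with batching: stretching one item and tiling the result matches tiling first and stretching the batch. A condensed sketch under the post-revert API, where the transform consumes the real-view layout (shapes follow the test above, with random data standing in for the loaded waveform):

import torch
import torchaudio

stretch = torchaudio.transforms.TimeStretch(fixed_rate=2, n_freq=1025, hop_length=512)

# Real-view spectrogram: (channel, freq, time, real/imag)
spec = torch.view_as_real(torch.randn(2, 1025, 400, dtype=torch.complex64))

single_then_batch = stretch(spec).repeat(3, 1, 1, 1, 1)
batch_then_single = stretch(spec.repeat(3, 1, 1, 1, 1))

assert torch.allclose(single_then_batch, batch_then_single, atol=1e-5)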
test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py
@@ -1,7 +1,7 @@
 import torch
 
 from torchaudio_unittest.common_utils import PytorchTestCase
-from .torchscript_consistency_impl import Transforms, TransformsComplex
+from .torchscript_consistency_impl import Transforms
 
 
 class TestTransformsFloat32(Transforms, PytorchTestCase):
@@ -12,15 +12,3 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
 class TestTransformsFloat64(Transforms, PytorchTestCase):
     dtype = torch.float64
     device = torch.device('cpu')
-
-
-class TestTransformsComplex64(TransformsComplex, PytorchTestCase):
-    complex_dtype = torch.complex64
-    real_dtype = torch.float32
-    device = torch.device('cpu')
-
-
-class TestTransformsComplex128(TransformsComplex, PytorchTestCase):
-    complex_dtype = torch.complex128
-    real_dtype = torch.float64
-    device = torch.device('cpu')
(4 more changed files not shown)
