migrate to complex dtypes
anjali411 authored and mthrok committed Mar 30, 2021
1 parent 512c2fa commit d27a820
Showing 13 changed files with 296 additions and 99 deletions.
14 changes: 13 additions & 1 deletion test/torchaudio_unittest/functional/functional_cpu_test.py
@@ -14,7 +14,7 @@
skipIfNoSox,
)

from .functional_impl import Lfilter, Spectrogram
from .functional_impl import Lfilter, Spectrogram, FunctionalComplex


class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
@@ -41,6 +41,18 @@ class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
device = torch.device('cpu')


class TestFunctionalComplex64(FunctionalComplex, common_utils.PytorchTestCase):
complex_dtype = torch.complex64
real_dtype = torch.float32
device = torch.device('cpu')


class TestFunctionalComplex128(FunctionalComplex, common_utils.PytorchTestCase):
complex_dtype = torch.complex128
real_dtype = torch.float64
device = torch.device('cpu')


class TestCreateFBMatrix(common_utils.TorchaudioTestCase):
def test_no_warning_high_n_freq(self):
with warnings.catch_warnings(record=True) as w:
16 changes: 15 additions & 1 deletion test/torchaudio_unittest/functional/functional_cuda_test.py
@@ -2,7 +2,7 @@
import unittest

from torchaudio_unittest import common_utils
from .functional_impl import Lfilter, Spectrogram
from .functional_impl import Lfilter, Spectrogram, FunctionalComplex


@common_utils.skipIfNoCuda
@@ -31,3 +31,17 @@ class TestSpectrogramFloat32(Spectrogram, common_utils.PytorchTestCase):
class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')


@common_utils.skipIfNoCuda
class TestFunctionalComplex64(FunctionalComplex, common_utils.PytorchTestCase):
complex_dtype = torch.complex64
real_dtype = torch.float32
device = torch.device('cuda')


@common_utils.skipIfNoCuda
class TestFunctionalComplex128(FunctionalComplex, common_utils.PytorchTestCase):
complex_dtype = torch.complex128
real_dtype = torch.float64
device = torch.device('cuda')
37 changes: 37 additions & 0 deletions test/torchaudio_unittest/functional/functional_impl.py
@@ -2,9 +2,11 @@
import torch
import torchaudio.functional as F
from parameterized import parameterized
import numpy as np
from scipy import signal

from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import nested_params


class Lfilter(common_utils.TestBaseMixin):
@@ -89,3 +91,38 @@ def test_grad_at_zero(self, power):
)
spec.sum().backward()
assert not x.grad.isnan().sum()


class FunctionalComplex(common_utils.TestBaseMixin):
complex_dtype = None
real_dtype = None
device = None

@nested_params(
[0.5, 1.01, 1.3],
[True, False],
)
def test_phase_vocoder_shape(self, rate, test_pseudo_complex):
hop_length = 256
num_freq = 1025
num_frames = 400
batch_size = 2

torch.random.manual_seed(42)
spec = torch.randn(
batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device)
if test_pseudo_complex:
spec = torch.view_as_real(spec)

phase_advance = torch.linspace(
0,
np.pi * hop_length,
num_freq,
dtype=self.real_dtype, device=self.device)[..., None]

spec_stretch = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)

assert spec.dim() == spec_stretch.dim()
expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))])
output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape
assert output_shape == expected_shape
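
The shape check above relies on `F.phase_vocoder` accepting either a native complex spectrogram or its pseudo-complex `(..., 2)` real view. A minimal standalone sketch of that duality (illustrative values, assuming the torchaudio API at this commit):

```python
import math
import torch
import torchaudio.functional as F

hop_length, num_freq, num_frames, rate = 256, 1025, 400, 1.3

spec = torch.randn(2, num_freq, num_frames, dtype=torch.complex64)
phase_advance = torch.linspace(0, math.pi * hop_length, num_freq)[..., None]

# Native complex input -> native complex output.
stretched = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)
assert stretched.shape == (2, num_freq, int(math.ceil(num_frames / rate)))

# Pseudo-complex input (trailing dim of size 2) -> pseudo-complex output.
stretched_pc = F.phase_vocoder(
    torch.view_as_real(spec), rate=rate, phase_advance=phase_advance)
assert torch.view_as_complex(stretched_pc).shape == stretched.shape
```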
60 changes: 29 additions & 31 deletions test/torchaudio_unittest/functional/librosa_compatibility_test.py
@@ -1,4 +1,3 @@
import itertools
import unittest
from distutils.version import StrictVersion

@@ -130,45 +129,44 @@ def test_resample(self):


@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class TestPhaseVocoder(common_utils.TorchaudioTestCase):
@parameterized.expand(list(itertools.product(
[(2, 1025, 400, 2)],
[0.5, 1.01, 1.3],
[256]
)))
def test_phase_vocoder(self, shape, rate, hop_length):
class TestFunctionalComplex(common_utils.TorchaudioTestCase):
def _test_phase_vocoder(self, rate, test_pseudo_complex=False):
hop_length = 256
num_freq = 1025
num_frames = 400
torch.random.manual_seed(42)

# Due to the cumulative sum, numerical error when using torch.float32 would
# cause the bottom-right values of the stretched spectrogram to not
# match librosa's.
torch.random.manual_seed(42)
complex_specgrams = torch.randn(*shape)
complex_specgrams = complex_specgrams.type(torch.float64)
spec = torch.randn(num_freq, num_frames, dtype=torch.complex128)
phase_advance = torch.linspace(
0,
np.pi * hop_length,
complex_specgrams.shape[-3],
num_freq,
dtype=torch.float64)[..., None]

complex_specgrams_stretch = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance)
stretched = F.phase_vocoder(
torch.view_as_real(spec) if test_pseudo_complex else spec,
rate=rate, phase_advance=phase_advance)

# == Test shape
expected_size = list(complex_specgrams.size())
expected_size[-2] = int(np.ceil(expected_size[-2] / rate))

assert complex_specgrams.dim() == complex_specgrams_stretch.dim()
assert complex_specgrams_stretch.size() == torch.Size(expected_size)

# == Test values
index = [0] * (complex_specgrams.dim() - 3) + [slice(None)] * 3
mono_complex_specgram = complex_specgrams[index].numpy()
mono_complex_specgram = mono_complex_specgram[..., 0] + \
mono_complex_specgram[..., 1] * 1j
expected_complex_stretch = librosa.phase_vocoder(
mono_complex_specgram,
expected_stretched = librosa.phase_vocoder(
spec.numpy(),
rate=rate,
hop_length=hop_length)

complex_stretch = complex_specgrams_stretch[index].numpy()
complex_stretch = complex_stretch[..., 0] + 1j * complex_stretch[..., 1]

self.assertEqual(complex_stretch, torch.from_numpy(expected_complex_stretch), atol=1e-5, rtol=1e-5)
self.assertEqual(
torch.view_as_complex(stretched) if test_pseudo_complex else stretched,
torch.from_numpy(expected_stretched))

@parameterized.expand(
[(0.5, ), (1.01, ), (1.3, ), ],
)
def test_phase_vocoder(self, rate):
self._test_phase_vocoder(rate)

@parameterized.expand(
[(0.5, ), (1.01, ), (1.3, ), ],
)
def test_phase_vocoder_pseudo_complex(self, rate):
self._test_phase_vocoder(rate, test_pseudo_complex=True)
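
In isolation, the equivalence this test asserts looks roughly like the following (a hedged sketch; it assumes librosa is installed and uses complex128/float64 to avoid the cumulative-sum error noted in the comment above):

```python
import numpy as np
import torch
import torchaudio.functional as F
import librosa  # assumed available, as in the test's skip condition

hop_length, num_freq, num_frames, rate = 256, 1025, 400, 1.3

spec = torch.randn(num_freq, num_frames, dtype=torch.complex128)
phase_advance = torch.linspace(
    0, np.pi * hop_length, num_freq, dtype=torch.float64)[..., None]

ours = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)
theirs = librosa.phase_vocoder(spec.numpy(), rate=rate, hop_length=hop_length)
print(np.allclose(ours.numpy(), theirs, atol=1e-5))  # expected: True
```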
test/torchaudio_unittest/functional/torchscript_consistency_cpu_test.py
@@ -1,7 +1,7 @@
import torch

from torchaudio_unittest.common_utils import PytorchTestCase
from .torchscript_consistency_impl import Functional
from .torchscript_consistency_impl import Functional, FunctionalComplex


class TestFunctionalFloat32(Functional, PytorchTestCase):
@@ -12,3 +12,15 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
class TestFunctionalFloat64(Functional, PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')


class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase):
complex_dtype = torch.complex64
real_dtype = torch.float32
device = torch.device('cpu')


class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase):
complex_dtype = torch.complex128
real_dtype = torch.float64
device = torch.device('cpu')
test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py
@@ -1,7 +1,7 @@
import torch

from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from .torchscript_consistency_impl import Functional
from .torchscript_consistency_impl import Functional, FunctionalComplex


@skipIfNoCuda
@@ -14,3 +14,17 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
class TestFunctionalFloat64(Functional, PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')


@skipIfNoCuda
class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase):
complex_dtype = torch.complex64
real_dtype = torch.float32
device = torch.device('cuda')


@skipIfNoCuda
class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase):
complex_dtype = torch.complex128
real_dtype = torch.float64
device = torch.device('cuda')
50 changes: 36 additions & 14 deletions test/torchaudio_unittest/functional/torchscript_consistency_impl.py
@@ -547,29 +547,51 @@ def func(tensor):
tensor = common_utils.get_whitenoise(sample_rate=44100)
self._assert_consistency(func, tensor)

@common_utils.skipIfNoKaldi
def test_compute_kaldi_pitch(self):
if self.dtype != torch.float32 or self.device != torch.device('cpu'):
raise unittest.SkipTest("Only float32, cpu is supported.")

def func(tensor):
sample_rate: float = 44100.
return F.compute_kaldi_pitch(tensor, sample_rate)

tensor = common_utils.get_whitenoise(sample_rate=44100)
self._assert_consistency(func, tensor)


class FunctionalComplex:
complex_dtype = None
real_dtype = None
device = None

def _assert_consistency(self, func, tensor):
tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
ts_func = torch.jit.script(func)

# on complex dtype
output = func(tensor)
ts_output = ts_func(tensor)
self.assertEqual(ts_output, output)

# on pseudo complex dtype
tensor = torch.view_as_real(tensor)
output = func(tensor)
ts_output = ts_func(tensor)
self.assertEqual(ts_output, output)

def test_phase_vocoder(self):
def func(tensor, device: torch.device = self.device):
n_freq = tensor.size(-2 if tensor.is_complex() else -3)
rate = 0.5
hop_length = 256
phase_advance = torch.linspace(
0,
3.14 * hop_length,
tensor.shape[-3],
n_freq,
dtype=torch.float64,
).to(device)[..., None]
return F.phase_vocoder(tensor, rate, phase_advance)

tensor = torch.randn(2, 1025, 400, 2)
self._assert_consistency(func, tensor)

@common_utils.skipIfNoKaldi
def test_compute_kaldi_pitch(self):
if self.dtype != torch.float32 or self.device != torch.device('cpu'):
raise unittest.SkipTest("Only float32, cpu is supported.")

def func(tensor):
sample_rate: float = 44100.
return F.compute_kaldi_pitch(tensor, sample_rate)

tensor = common_utils.get_whitenoise(sample_rate=44100)
tensor = torch.randn(2, 1025, 400)
self._assert_consistency(func, tensor)
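
What `_assert_consistency` exercises, condensed into a standalone sketch (assumes the torchaudio API at this commit; the scripted and eager paths must agree on both complex and pseudo-complex inputs):

```python
import torch
import torchaudio.functional as F

def stretch(tensor: torch.Tensor) -> torch.Tensor:
    # The frequency axis moves when the pseudo-complex (..., 2) layout is used.
    n_freq = tensor.size(-2 if tensor.is_complex() else -3)
    phase_advance = torch.linspace(
        0.0, 3.14 * 256, n_freq, dtype=torch.float64)[..., None]
    return F.phase_vocoder(tensor, 0.5, phase_advance)

scripted = torch.jit.script(stretch)
spec = torch.randn(2, 1025, 400, dtype=torch.complex64)

for t in (spec, torch.view_as_real(spec)):
    assert torch.allclose(stretch(t), scripted(t))
```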
35 changes: 21 additions & 14 deletions test/torchaudio_unittest/transforms/batch_consistency_test.py
@@ -130,27 +130,28 @@ def test_batch_mfcc(self):
computed = torchaudio.transforms.MFCC()(waveform.repeat(3, 1, 1))
self.assertEqual(computed, expected, atol=1e-4, rtol=1e-5)

def test_batch_TimeStretch(self):
def _assert_batch_TimeStretch(self, complex):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100

rate = 2

complex_specgrams = torch.view_as_real(
torch.stft(
input=waveform,
n_fft=2048,
hop_length=512,
win_length=2048,
window=torch.hann_window(2048),
center=True,
pad_mode='reflect',
normalized=True,
onesided=True,
return_complex=True,
)
complex_specgrams = torch.stft(
input=waveform,
n_fft=2048,
hop_length=512,
win_length=2048,
window=torch.hann_window(2048),
center=True,
pad_mode='reflect',
normalized=True,
onesided=True,
return_complex=True,
)

if not complex:
complex_specgrams = torch.view_as_real(complex_specgrams)

# Single then transform then batch
expected = torchaudio.transforms.TimeStretch(
fixed_rate=rate,
@@ -167,6 +168,12 @@ def test_batch_TimeStretch(self):

self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5)

def test_batch_TimeStretch_complex(self):
self._assert_batch_TimeStretch(complex=True)

def test_batch_TimeStretch_pseudo_complex(self):
self._assert_batch_TimeStretch(complex=False)
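
The batch test follows the usual pattern here: transform each item, stack, and compare against transforming the whole batch at once. A condensed sketch on random data (illustrative; the test itself uses a real waveform's STFT):

```python
import torch
import torchaudio

stretch = torchaudio.transforms.TimeStretch(
    hop_length=512, n_freq=1025, fixed_rate=2)

# Batch of 3 complex spectrograms (n_freq=1025, 400 frames each).
specs = torch.randn(3, 1025, 400, dtype=torch.complex64)

expected = torch.stack([stretch(s) for s in specs])  # single, then batch
computed = stretch(specs)                            # batch at once
assert torch.allclose(expected, computed, atol=1e-5)
```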

def test_batch_Fade(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100
test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py
@@ -1,7 +1,7 @@
import torch

from torchaudio_unittest.common_utils import PytorchTestCase
from .torchscript_consistency_impl import Transforms
from .torchscript_consistency_impl import Transforms, TransformsComplex


class TestTransformsFloat32(Transforms, PytorchTestCase):
@@ -12,3 +12,15 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
class TestTransformsFloat64(Transforms, PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')


class TestTransformsComplex64(TransformsComplex, PytorchTestCase):
complex_dtype = torch.complex64
real_dtype = torch.float32
device = torch.device('cpu')


class TestTransformsComplex128(TransformsComplex, PytorchTestCase):
complex_dtype = torch.complex128
real_dtype = torch.float64
device = torch.device('cpu')