Skip to content

Commit

Permalink
migrate to complex dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
anjali411 authored and mthrok committed Mar 30, 2021
1 parent 512c2fa commit 88fba5c
Show file tree
Hide file tree
Showing 10 changed files with 221 additions and 78 deletions.
56 changes: 37 additions & 19 deletions test/torchaudio_unittest/functional/librosa_compatibility_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,44 +131,62 @@ def test_resample(self):

@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class TestPhaseVocoder(common_utils.TorchaudioTestCase):
@parameterized.expand(list(itertools.product(
[(2, 1025, 400, 2)],
[0.5, 1.01, 1.3],
[256]
)))
def test_phase_vocoder(self, shape, rate, hop_length):
# Due to cummulative sum, numerical error in using torch.float32 will
# result in bottom right values of the stretched sectrogram to not
# match with librosa.
torch.random.manual_seed(42)
complex_specgrams = torch.randn(*shape)
complex_specgrams = complex_specgrams.type(torch.float64)
def _librosa_consistency(self, complex_specgrams, rate, hop_length):
is_complex = complex_specgrams.is_complex()
time_axis = -1 if is_complex else -2
n_freq = complex_specgrams.size(-2 if is_complex else -3)

phase_advance = torch.linspace(
0,
np.pi * hop_length,
complex_specgrams.shape[-3],
n_freq,
dtype=torch.float64)[..., None]

complex_specgrams_stretch = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance)

# == Test shape
expected_size = list(complex_specgrams.size())
expected_size[-2] = int(np.ceil(expected_size[-2] / rate))
expected_size[time_axis] = int(np.ceil(expected_size[time_axis] / rate))

assert complex_specgrams.dim() == complex_specgrams_stretch.dim()
assert complex_specgrams_stretch.size() == torch.Size(expected_size)

# == Test values
index = [0] * (complex_specgrams.dim() - 3) + [slice(None)] * 3
mono_complex_specgram = complex_specgrams[index].numpy()
mono_complex_specgram = mono_complex_specgram[..., 0] + \
mono_complex_specgram[..., 1] * 1j
if is_complex:
mono_complex_specgram = complex_specgrams[0].numpy()
else:
mono_complex_specgram = complex_specgrams[0].numpy()
mono_complex_specgram = mono_complex_specgram[..., 0] + \
mono_complex_specgram[..., 1] * 1j
print("mono_complex_specgram:", mono_complex_specgram.shape)
expected_complex_stretch = librosa.phase_vocoder(
mono_complex_specgram,
rate=rate,
hop_length=hop_length)

complex_stretch = complex_specgrams_stretch[index].numpy()
complex_stretch = complex_specgrams_stretch[0].numpy()
complex_stretch = complex_stretch[..., 0] + 1j * complex_stretch[..., 1]

self.assertEqual(complex_stretch, torch.from_numpy(expected_complex_stretch), atol=1e-5, rtol=1e-5)

@parameterized.expand(
[(0.5, ), (1.01, ), (1.3, ), ],
)
def test_complex(self, rate):
# Due to cummulative sum, numerical error in using torch.complex64 will (likely)
# result in bottom right values of the stretched sectrogram to not
# match with librosa.
torch.random.manual_seed(42)
tensor = torch.randn(2, 1025, 400, dtype=torch.complex128)
self._librosa_consistency(tensor, rate, 256)

@parameterized.expand(
[(0.5, ), (1.01, ), (1.3, ), ],
)
def test_pseudo_complex(self, rate):
# Due to cummulative sum, numerical error in using torch.float32 will
# result in bottom right values of the stretched sectrogram to not
# match with librosa.
torch.random.manual_seed(42)
tensor = torch.randn(2, 1025, 400, 2, dtype=torch.float64)
self._librosa_consistency(tensor, rate, 256)
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch

from torchaudio_unittest.common_utils import PytorchTestCase
from .torchscript_consistency_impl import Functional
from .torchscript_consistency_impl import Functional, FunctionalComplex


class TestFunctionalFloat32(Functional, PytorchTestCase):
Expand All @@ -12,3 +12,15 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
class TestFunctionalFloat64(Functional, PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')


class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase):
complex_dtype = torch.complex64
real_dtype = torch.float32
device = torch.device('cpu')


class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase):
complex_dtype = torch.complex128
real_dtype = torch.float64
device = torch.device('cpu')
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch

from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from .torchscript_consistency_impl import Functional
from .torchscript_consistency_impl import Functional, FunctionalComplex


@skipIfNoCuda
Expand All @@ -14,3 +14,17 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
class TestFunctionalFloat64(Functional, PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')


@skipIfNoCuda
class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase):
complex_dtype = torch.complex64
real_dtype = torch.float32
device = torch.device('cuda')


@skipIfNoCuda
class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase):
complex_dtype = torch.complex128
real_dtype = torch.float64
device = torch.device('cuda')
47 changes: 33 additions & 14 deletions test/torchaudio_unittest/functional/torchscript_consistency_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,29 +547,48 @@ def func(tensor):
tensor = common_utils.get_whitenoise(sample_rate=44100)
self._assert_consistency(func, tensor)

@common_utils.skipIfNoKaldi
def test_compute_kaldi_pitch(self):
if self.dtype != torch.float32 or self.device != torch.device('cpu'):
raise unittest.SkipTest("Only float32, cpu is supported.")

def func(tensor):
sample_rate: float = 44100.
return F.compute_kaldi_pitch(tensor, sample_rate)


class FunctionalComplex:
complex_dtype = None
real_dtype = None
device = None

def _assert_consistency(self, func, tensor):
tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
ts_func = torch.jit.script(func)

# on complex dtype
output = func(tensor)
ts_output = ts_func(tensor)
self.assertEqual(ts_output, output)

# on pseudo complex dtype
tensor = torch.view_as_real(tensor)
output = func(tensor)
ts_output = ts_func(tensor)
self.assertEqual(ts_output, output)

def test_phase_vocoder(self):
def func(tensor, device: torch.device = self.device):
n_freq = tensor.size(-2 if tensor.is_complex() else -3)
rate = 0.5
hop_length = 256
phase_advance = torch.linspace(
0,
3.14 * hop_length,
tensor.shape[-3],
n_freq,
dtype=torch.float64,
).to(device)[..., None]
return F.phase_vocoder(tensor, rate, phase_advance)

tensor = torch.randn(2, 1025, 400, 2)
self._assert_consistency(func, tensor)

@common_utils.skipIfNoKaldi
def test_compute_kaldi_pitch(self):
if self.dtype != torch.float32 or self.device != torch.device('cpu'):
raise unittest.SkipTest("Only float32, cpu is supported.")

def func(tensor):
sample_rate: float = 44100.
return F.compute_kaldi_pitch(tensor, sample_rate)

tensor = common_utils.get_whitenoise(sample_rate=44100)
tensor = torch.randn(2, 1025, 400)
self._assert_consistency(func, tensor)
35 changes: 21 additions & 14 deletions test/torchaudio_unittest/transforms/batch_consistency_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,27 +130,28 @@ def test_batch_mfcc(self):
computed = torchaudio.transforms.MFCC()(waveform.repeat(3, 1, 1))
self.assertEqual(computed, expected, atol=1e-4, rtol=1e-5)

def test_batch_TimeStretch(self):
def _assert_batch_TimeStretch(self, complex):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100

rate = 2

complex_specgrams = torch.view_as_real(
torch.stft(
input=waveform,
n_fft=2048,
hop_length=512,
win_length=2048,
window=torch.hann_window(2048),
center=True,
pad_mode='reflect',
normalized=True,
onesided=True,
return_complex=True,
)
complex_specgrams = torch.stft(
input=waveform,
n_fft=2048,
hop_length=512,
win_length=2048,
window=torch.hann_window(2048),
center=True,
pad_mode='reflect',
normalized=True,
onesided=True,
return_complex=True,
)

if not complex:
complex_specgrams = torch.view_as_real(complex_specgrams)

# Single then transform then batch
expected = torchaudio.transforms.TimeStretch(
fixed_rate=rate,
Expand All @@ -167,6 +168,12 @@ def test_batch_TimeStretch(self):

self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5)

def test_batch_TimeStretch_complex(self):
self._assert_batch_TimeStretch(complex=True)

def test_batch_TimeStretch_paseudo_complex(self):
self._assert_batch_TimeStretch(complex=False)

def test_batch_Fade(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch

from torchaudio_unittest.common_utils import PytorchTestCase
from .torchscript_consistency_impl import Transforms
from .torchscript_consistency_impl import Transforms, TransformsComplex


class TestTransformsFloat32(Transforms, PytorchTestCase):
Expand All @@ -12,3 +12,15 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
class TestTransformsFloat64(Transforms, PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')


class TestTransformsComplex64(TransformsComplex, PytorchTestCase):
complex_dtype = torch.complex64
real_dtype = torch.float32
device = torch.device('cpu')


class TestTransformsComplex128(TransformsComplex, PytorchTestCase):
complex_dtype = torch.complex128
real_dtype = torch.float64
device = torch.device('cpu')
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch

from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from .torchscript_consistency_impl import Transforms
from .torchscript_consistency_impl import Transforms, TransformsComplex


@skipIfNoCuda
Expand All @@ -14,3 +14,17 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
class TestTransformsFloat64(Transforms, PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')


@skipIfNoCuda
class TestTransformsComplex64(TransformsComplex, PytorchTestCase):
complex_dtype = torch.complex64
real_dtype = torch.float32
device = torch.device('cuda')


@skipIfNoCuda
class TestTransformsComplex128(TransformsComplex, PytorchTestCase):
complex_dtype = torch.complex128
real_dtype = torch.float64
device = torch.device('cuda')
42 changes: 32 additions & 10 deletions test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,6 @@ def test_MuLawDecoding(self):
tensor = torch.rand((1, 10))
self._assert_consistency(T.MuLawDecoding(), tensor)

def test_TimeStretch(self):
n_freq = 400
hop_length = 512
fixed_rate = 1.3
tensor = torch.rand((10, 2, n_freq, 10, 2))
self._assert_consistency(
T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
tensor,
)

def test_Fade(self):
waveform = common_utils.get_whitenoise()
fade_in_len = 3000
Expand Down Expand Up @@ -99,3 +89,35 @@ def test_SpectralCentroid(self):
sample_rate = 44100
waveform = common_utils.get_whitenoise(sample_rate=sample_rate)
self._assert_consistency(T.SpectralCentroid(sample_rate=sample_rate), waveform)


class TransformsComplex:
complex_dtype = None
real_dtype = None
device = None

def _assert_consistency(self, transform, tensor):
tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
transform = transform.to(device=self.device, dtype=self.real_dtype)
ts_transform = torch.jit.script(transform)

# on complex dtype
output = transform(tensor)
ts_output = ts_transform(tensor)
self.assertEqual(ts_output, output)

# on pseudo complex dtype
tensor = torch.view_as_real(tensor)
output = transform(tensor)
ts_output = ts_transform(tensor)
self.assertEqual(ts_output, output)

def test_TimeStretch(self):
n_freq = 400
hop_length = 512
fixed_rate = 1.3
tensor = torch.view_as_complex(torch.rand((10, 2, n_freq, 10, 2)))
self._assert_consistency(
T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
tensor,
)
Loading

0 comments on commit 88fba5c

Please sign in to comment.