Commit

migrate to complex dtypes
anjali411 authored and mthrok committed Mar 29, 2021
1 parent 512c2fa commit 92f7755
Showing 8 changed files with 275 additions and 49 deletions.
60 changes: 40 additions & 20 deletions test/torchaudio_unittest/functional/librosa_compatibility_test.py
@@ -131,44 +131,64 @@ def test_resample(self):

@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class TestPhaseVocoder(common_utils.TorchaudioTestCase):
@parameterized.expand(list(itertools.product(
[(2, 1025, 400, 2)],
[0.5, 1.01, 1.3],
[256]
)))
def test_phase_vocoder(self, shape, rate, hop_length):
# Due to cumulative sum, numerical error when using torch.float32 will
# result in the bottom-right values of the stretched spectrogram not
# matching librosa's.
torch.random.manual_seed(42)
complex_specgrams = torch.randn(*shape)
complex_specgrams = complex_specgrams.type(torch.float64)
def _librosa_consistency(self, complex_specgrams, rate, hop_length):
is_complex = complex_specgrams.is_complex()
time_axis = -1 if is_complex else -2
n_freq = complex_specgrams.size(-2 if is_complex else -3)

phase_advance = torch.linspace(
0,
np.pi * hop_length,
complex_specgrams.shape[-3],
n_freq,
dtype=torch.float64)[..., None]

complex_specgrams_stretch = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance)
complex_specgrams_stretch = F.phase_vocoder(
complex_specgrams, rate=rate, phase_advance=phase_advance)

# == Test shape
expected_size = list(complex_specgrams.size())
expected_size[-2] = int(np.ceil(expected_size[-2] / rate))
expected_size[time_axis] = int(np.ceil(expected_size[time_axis] / rate))

assert complex_specgrams.dim() == complex_specgrams_stretch.dim()
assert complex_specgrams_stretch.size() == torch.Size(expected_size)

# == Test values
index = [0] * (complex_specgrams.dim() - 3) + [slice(None)] * 3
mono_complex_specgram = complex_specgrams[index].numpy()
mono_complex_specgram = mono_complex_specgram[..., 0] + \
mono_complex_specgram[..., 1] * 1j
if is_complex:
mono_complex_specgram = complex_specgrams[0].numpy()
else:
mono_complex_specgram = complex_specgrams[0].numpy()
mono_complex_specgram = mono_complex_specgram[..., 0] + \
mono_complex_specgram[..., 1] * 1j
print("mono_complex_specgram:", mono_complex_specgram.shape)
expected_complex_stretch = librosa.phase_vocoder(
mono_complex_specgram,
rate=rate,
hop_length=hop_length)

complex_stretch = complex_specgrams_stretch[index].numpy()
complex_stretch = complex_specgrams_stretch[0].numpy()
complex_stretch = complex_stretch[..., 0] + 1j * complex_stretch[..., 1]

self.assertEqual(complex_stretch, torch.from_numpy(expected_complex_stretch), atol=1e-5, rtol=1e-5)


@parameterized.expand(
[(0.5, ), (1.01, ), (1.3, ), ],
)
def test_complex(self, rate):
# Due to cumulative sum, numerical error when using torch.complex64 will likely
# result in the bottom-right values of the stretched spectrogram not
# matching librosa's.
torch.random.manual_seed(42)
tensor = torch.randn(2, 1025, 400, dtype=torch.complex128)
self._librosa_consistency(tensor, rate, 256)

@parameterized.expand(
[(0.5, ), (1.01, ), (1.3, ), ],
)
def test_pseudo_complex(self, rate):
# Due to cumulative sum, numerical error when using torch.float32 will
# result in the bottom-right values of the stretched spectrogram not
# matching librosa's.
torch.random.manual_seed(42)
tensor = torch.randn(2, 1025, 400, 2, dtype=torch.float64)
self._librosa_consistency(tensor, rate, 256)
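
For reference, the natively complex path that the new _librosa_consistency helper exercises can be reproduced with the following minimal sketch (shapes, rate, and hop length mirror the test above; the snippet is illustrative and not part of the diff):

import math
import torch
import torchaudio.functional as F

torch.random.manual_seed(42)
hop_length = 256
rate = 1.3
n_freq = 1025

# Natively complex spectrogram: (channel, freq, time), no trailing real/imag dim.
complex_specgrams = torch.randn(2, n_freq, 400, dtype=torch.complex128)

phase_advance = torch.linspace(
    0, math.pi * hop_length, n_freq, dtype=torch.float64)[..., None]

stretched = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance)

# For complex input the time axis is the last one; it shrinks by the rate factor.
print(stretched.shape)  # torch.Size([2, 1025, 308]), since ceil(400 / 1.3) == 308
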
@@ -535,6 +535,7 @@ def func(tensor):

def test_spectral_centroid(self):

<<<<<<< HEAD:test/torchaudio_unittest/functional/torchscript_consistency_impl.py
def func(tensor):
sample_rate = 44100
n_fft = 400
@@ -543,6 +544,99 @@ def func(tensor):
pad = 0
window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
return F.spectral_centroid(tensor, sample_rate, pad, window, n_fft, hop, ws)
=======
class TransformsMixin:
"""Implements test for Transforms that are performed for different devices"""
def _assert_consistency(self, transform, tensor):
tensor = tensor.to(device=self.device, dtype=self.dtype)
transform = transform.to(device=self.device, dtype=self.dtype)

ts_transform = torch.jit.script(transform)
output = transform(tensor)
ts_output = ts_transform(tensor)
self.assertEqual(ts_output, output)


class TransformsWithComplexDtypes(TransformsMixin, common_utils.TestBaseMixin):
"""Implements test for Transforms that are performed for different devices"""
def _assert_consistency(self, transform, tensor):
tensor = tensor.to(device=self.device, dtype=self.dtype)
transform = transform.to(device=self.device, dtype=self.dtype)

ts_transform = torch.jit.script(transform)
output = transform(tensor)
ts_output = ts_transform(tensor)
self.assertEqual(ts_output, output)

def test_TimeStretch(self):
n_freq = 400
hop_length = 512
fixed_rate = 1.3
tensor = torch.rand((10, 2, n_freq, 10))
self._assert_consistency(
T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
tensor,
)


class Transforms(TransformsMixin, common_utils.TestBaseMixin):
def test_Spectrogram(self):
tensor = torch.rand((1, 1000))
self._assert_consistency(T.Spectrogram(), tensor)

def test_GriffinLim(self):
tensor = torch.rand((1, 201, 6))
self._assert_consistency(T.GriffinLim(length=1000, rand_init=False), tensor)

def test_AmplitudeToDB(self):
spec = torch.rand((6, 201))
self._assert_consistency(T.AmplitudeToDB(), spec)

def test_MelScale(self):
spec_f = torch.rand((1, 6, 201))
self._assert_consistency(T.MelScale(), spec_f)

def test_MelSpectrogram(self):
tensor = torch.rand((1, 1000))
self._assert_consistency(T.MelSpectrogram(), tensor)

def test_MFCC(self):
tensor = torch.rand((1, 1000))
self._assert_consistency(T.MFCC(), tensor)

def test_Resample(self):
sr1, sr2 = 16000, 8000
tensor = common_utils.get_whitenoise(sample_rate=sr1)
self._assert_consistency(T.Resample(float(sr1), float(sr2)), tensor)

def test_ComplexNorm(self):
tensor = torch.rand((1, 2, 201, 2))
self._assert_consistency(T.ComplexNorm(), tensor)

def test_MuLawEncoding(self):
tensor = common_utils.get_whitenoise()
self._assert_consistency(T.MuLawEncoding(), tensor)

def test_MuLawDecoding(self):
tensor = torch.rand((1, 10))
self._assert_consistency(T.MuLawDecoding(), tensor)

def test_TimeStretch(self):
n_freq = 400
hop_length = 512
fixed_rate = 1.3
tensor = torch.rand((10, 2, n_freq, 10, 2), dtype=torch.double)
self._assert_consistency(
T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
tensor,
)

def test_Fade(self):
waveform = common_utils.get_whitenoise()
fade_in_len = 3000
fade_out_len = 3000
self._assert_consistency(T.Fade(fade_in_len, fade_out_len), waveform)
>>>>>>> 17f9987 (migrate to complex dtypes):test/torchaudio_unittest/torchscript_consistency_impl.py

tensor = common_utils.get_whitenoise(sample_rate=44100)
self._assert_consistency(func, tensor)
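
All of the _assert_consistency helpers above follow the same TorchScript-consistency pattern; a minimal standalone sketch of that pattern, using T.Spectrogram purely as an example transform, looks like this:

import torch
import torchaudio.transforms as T

tensor = torch.rand(1, 1000)
transform = T.Spectrogram()

# Script the transform and check that it matches eager execution.
ts_transform = torch.jit.script(transform)
output = transform(tensor)
ts_output = ts_transform(tensor)
assert torch.allclose(ts_output, output)
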
39 changes: 39 additions & 0 deletions test/torchaudio_unittest/transforms/batch_consistency_test.py
@@ -191,6 +191,7 @@ def test_batch_Vol(self):
computed = torchaudio.transforms.Vol(gain=1.1)(waveform.repeat(3, 1, 1))
self.assertEqual(computed, expected)

<<<<<<< HEAD:test/torchaudio_unittest/transforms/batch_consistency_test.py
def test_batch_spectral_centroid(self):
sample_rate = 44100
waveform = common_utils.get_whitenoise(sample_rate=sample_rate)
@@ -201,3 +202,41 @@ def test_batch_spectral_centroid(self):
# Batch then transform
computed = torchaudio.transforms.SpectralCentroid(sample_rate)(waveform.repeat(3, 1, 1))
self.assertEqual(computed, expected)
=======

class TestTransformsWithComplexTensors(common_utils.TorchaudioTestCase):
def test_batch_TimeStretch(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
waveform, _ = common_utils.load_wav(test_filepath) # (2, 278756), 44100

kwargs = {
'n_fft': 2048,
'hop_length': 512,
'win_length': 2048,
'window': torch.hann_window(2048),
'center': True,
'pad_mode': 'reflect',
'normalized': True,
'onesided': True,
}
rate = 2

complex_specgrams = torch.stft(waveform, **kwargs)
complex_specgrams = torch.view_as_complex(complex_specgrams)

# Single then transform then batch
expected = torchaudio.transforms.TimeStretch(
fixed_rate=rate,
n_freq=1025,
hop_length=512,
)(complex_specgrams).repeat(3, 1, 1, 1, 1)

# Batch then transform
computed = torchaudio.transforms.TimeStretch(
fixed_rate=rate,
n_freq=1025,
hop_length=512,
)(complex_specgrams.repeat(3, 1, 1, 1, 1))

self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5)
>>>>>>> 17f9987 (migrate to complex dtypes):test/torchaudio_unittest/batch_consistency_test.py
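
A minimal sketch of the batch-versus-single equivalence that test_batch_TimeStretch asserts, using a random waveform in place of the bundled test file and return_complex=True to obtain a natively complex spectrogram (names and shapes are illustrative):

import torch
import torchaudio.transforms as T

waveform = torch.randn(2, 44100)  # stand-in for the loaded .wav tensor

complex_specgrams = torch.stft(
    waveform,
    n_fft=2048,
    hop_length=512,
    win_length=2048,
    window=torch.hann_window(2048),
    return_complex=True,
)  # complex64 tensor of shape (2, 1025, time)

stretch = T.TimeStretch(fixed_rate=2, n_freq=1025, hop_length=512)

# Transform a single item and then batch it ...
expected = stretch(complex_specgrams).repeat(3, 1, 1, 1)
# ... versus batching first and transforming the batch.
computed = stretch(complex_specgrams.repeat(3, 1, 1, 1))

assert torch.allclose(computed, expected, atol=1e-5, rtol=1e-5)
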
@@ -1,7 +1,7 @@
import torch

from torchaudio_unittest.common_utils import PytorchTestCase
from .torchscript_consistency_impl import Transforms
from .torchscript_consistency_impl import Transforms, TransformsComplex


class TestTransformsFloat32(Transforms, PytorchTestCase):
@@ -12,3 +12,15 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
class TestTransformsFloat64(Transforms, PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')


class TestTransformsComplex64(TransformsComplex, PytorchTestCase):
complex_dtype = torch.complex64
real_dtype = torch.float32
device = torch.device('cpu')


class TestTransformsComplex128(TransformsComplex, PytorchTestCase):
complex_dtype = torch.complex128
real_dtype = torch.float64
device = torch.device('cpu')
@@ -1,7 +1,7 @@
import torch

from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from .torchscript_consistency_impl import Transforms
from .torchscript_consistency_impl import Transforms, TransformsComplex


@skipIfNoCuda
@@ -14,3 +14,17 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
class TestTransformsFloat64(Transforms, PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')


@skipIfNoCuda
class TestTransformsComplex64(TransformsComplex, PytorchTestCase):
complex_dtype = torch.complex64
real_dtype = torch.float32
device = torch.device('cuda')


@skipIfNoCuda
class TestTransformsComplex128(TransformsComplex, PytorchTestCase):
complex_dtype = torch.complex128
real_dtype = torch.float64
device = torch.device('cuda')
42 changes: 32 additions & 10 deletions test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
@@ -58,16 +58,6 @@ def test_MuLawDecoding(self):
tensor = torch.rand((1, 10))
self._assert_consistency(T.MuLawDecoding(), tensor)

def test_TimeStretch(self):
n_freq = 400
hop_length = 512
fixed_rate = 1.3
tensor = torch.rand((10, 2, n_freq, 10, 2))
self._assert_consistency(
T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
tensor,
)

def test_Fade(self):
waveform = common_utils.get_whitenoise()
fade_in_len = 3000
@@ -99,3 +89,35 @@ def test_SpectralCentroid(self):
sample_rate = 44100
waveform = common_utils.get_whitenoise(sample_rate=sample_rate)
self._assert_consistency(T.SpectralCentroid(sample_rate=sample_rate), waveform)


class TransformsComplex:
complex_dtype = None
real_dtype = None
device = None

def _assert_consistency(self, transform, tensor):
tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
transform = transform.to(device=self.device, dtype=self.real_dtype)
ts_transform = torch.jit.script(transform)

# on complex dtype
output = transform(tensor)
ts_output = ts_transform(tensor)
self.assertEqual(ts_output, output)

# on pseudo complex dtype
tensor = torch.view_as_real(tensor)
output = transform(tensor)
ts_output = ts_transform(tensor)
self.assertEqual(ts_output, output)

def test_TimeStretch(self):
n_freq = 400
hop_length = 512
fixed_rate = 1.3
tensor = torch.view_as_complex(torch.rand((10, 2, n_freq, 10, 2)))
self._assert_consistency(
T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
tensor,
)
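
The two spectrogram layouts that TransformsComplex._assert_consistency feeds to each transform (a native complex tensor and its "pseudo complex" view_as_real counterpart with a trailing dimension of size 2) are related as in this small illustrative sketch:

import torch

complex_spec = torch.view_as_complex(torch.rand(10, 2, 400, 10, 2))
print(complex_spec.shape, complex_spec.dtype)   # (10, 2, 400, 10) torch.complex64

pseudo_spec = torch.view_as_real(complex_spec)  # shares storage with complex_spec
print(pseudo_spec.shape, pseudo_spec.dtype)     # (10, 2, 400, 10, 2) torch.float32

# Real and imaginary parts line up between the two layouts.
assert torch.equal(pseudo_spec[..., 0], complex_spec.real)
assert torch.equal(pseudo_spec[..., 1], complex_spec.imag)
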