Use complex tensors in phase_vocoder #758

Closed · wants to merge 14 commits
37 changes: 37 additions & 0 deletions test/torchaudio_unittest/batch_consistency_test.py
@@ -280,3 +280,40 @@ def test_batch_Vol(self):
# Batch then transform
computed = torchaudio.transforms.Vol(gain=1.1)(waveform.repeat(3, 1, 1))
self.assertEqual(computed, expected)


class TestTransformsWithComplexTensors(common_utils.TorchaudioTestCase):
def test_batch_TimeStretch(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
waveform, _ = common_utils.load_wav(test_filepath) # (2, 278756), 44100

kwargs = {
'n_fft': 2048,
'hop_length': 512,
'win_length': 2048,
'window': torch.hann_window(2048),
'center': True,
'pad_mode': 'reflect',
'normalized': True,
'onesided': True,
}
rate = 2

complex_specgrams = torch.stft(waveform, **kwargs)
complex_specgrams = torch.view_as_complex(complex_specgrams)

# Single then transform then batch
expected = torchaudio.transforms.TimeStretch(
fixed_rate=rate,
n_freq=1025,
hop_length=512,
)(complex_specgrams).repeat(3, 1, 1, 1, 1)

# Batch then transform
computed = torchaudio.transforms.TimeStretch(
fixed_rate=rate,
n_freq=1025,
hop_length=512,
)(complex_specgrams.repeat(3, 1, 1, 1, 1))

self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5)
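
The test above converts the torch.stft output with torch.view_as_complex before calling TimeStretch. As a minimal sketch (not part of the diff), the legacy `(..., complex=2)` layout and the complex dtype are two views of the same data:

    import torch

    real_format = torch.randn(2, 1025, 400, 2)            # legacy layout: (..., freq, time, complex=2)
    complex_format = torch.view_as_complex(real_format)   # complex dtype: (..., freq, time), torch.cfloat
    assert torch.equal(torch.view_as_real(complex_format), real_format)  # same underlying values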
38 changes: 38 additions & 0 deletions test/torchaudio_unittest/librosa_compatibility_test.py
@@ -2,6 +2,7 @@
import os
import unittest
from distutils.version import StrictVersion
from parameterized import parameterized

import torch
import torchaudio
@@ -111,6 +112,43 @@ def test_amplitude_to_DB(self):
self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5)


@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class TestFunctionalWithComplexTensors(common_utils.TorchaudioTestCase):
"""Test suite for functions in `functional` module using as input tensors with complex dtypes."""
@parameterized.expand([
(0.5,), (1.01,), (1.3,)
])
def test_phase_vocoder(self, rate):
torch.random.manual_seed(48)
complex_specgrams = torch.randn(2, 1025, 400, dtype=torch.cdouble)
hop_length = 256

# Due to the cumulative sum, numerical error when using torch.float32 would
# cause the bottom-right values of the stretched spectrogram to not
# match librosa's output.

phase_advance = torch.linspace(0, np.pi * hop_length,
complex_specgrams.shape[-2], dtype=torch.double)[..., None]

complex_specgrams_stretch = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance)

# == Test shape
expected_size = list(complex_specgrams.size())
expected_size[-1] = int(np.ceil(expected_size[-1] / rate))

assert complex_specgrams.dim() == complex_specgrams_stretch.dim()
assert complex_specgrams_stretch.size() == torch.Size(expected_size)

# == Test values
index = [0] + [slice(None)] * 2
mono_complex_specgram = complex_specgrams[index].numpy()
expected_complex_stretch = librosa.phase_vocoder(mono_complex_specgram,
rate=rate,
hop_length=hop_length)

self.assertEqual(complex_specgrams_stretch[index], torch.from_numpy(expected_complex_stretch))
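
As a hedged aside (not part of the diff), the float32 concern raised in the comment above can be reproduced with a short sketch; the size of the drift depends on the sequence length:

    import torch

    # Cumulative sums accumulate rounding error; float32 drifts away from float64
    # over long sequences, which is why phase_advance above uses torch.double.
    phase = torch.full((100_000,), 0.1)
    drift = (torch.cumsum(phase, 0)[-1].double() - torch.cumsum(phase.double(), 0)[-1]).abs()
    print(drift)  # non-negligible in float32; shrinks as the sequence gets shorter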


@pytest.mark.parametrize('complex_specgrams', [
torch.randn(2, 1025, 400, 2)
])
14 changes: 13 additions & 1 deletion test/torchaudio_unittest/torchscript_consistency_cuda_test.py
@@ -1,7 +1,7 @@
import torch

from torchaudio_unittest import common_utils
from .torchscript_consistency_impl import Functional, Transforms
from .torchscript_consistency_impl import Functional, Transforms, TransformsWithComplexDtypes


@common_utils.skipIfNoCuda
@@ -26,3 +26,15 @@ class TestTransformsFloat32(Transforms, common_utils.PytorchTestCase):
class TestTransformsFloat64(Transforms, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')


@common_utils.skipIfNoCuda
class TestTransformsCFloat(TransformsWithComplexDtypes, common_utils.PytorchTestCase):
dtype = torch.cfloat
device = torch.device('cuda')


@common_utils.skipIfNoCuda
class TestTransformsCDouble(TransformsWithComplexDtypes, common_utils.PytorchTestCase):
dtype = torch.cdouble
device = torch.device('cuda')
28 changes: 26 additions & 2 deletions test/torchaudio_unittest/torchscript_consistency_impl.py
@@ -529,7 +529,7 @@ def func(tensor):
self._assert_consistency(func, waveform)


class Transforms(common_utils.TestBaseMixin):
class TransformsMixin:
"""Implements test for Transforms that are performed for different devices"""
def _assert_consistency(self, transform, tensor):
tensor = tensor.to(device=self.device, dtype=self.dtype)
@@ -540,6 +540,30 @@ def _assert_consistency(self, transform, tensor):
ts_output = ts_transform(tensor)
self.assertEqual(ts_output, output)


class TransformsWithComplexDtypes(TransformsMixin, common_utils.TestBaseMixin):
"""Implements test for Transforms that are performed for different devices"""
def _assert_consistency(self, transform, tensor):
tensor = tensor.to(device=self.device, dtype=self.dtype)
transform = transform.to(device=self.device, dtype=self.dtype)

ts_transform = torch.jit.script(transform)
output = transform(tensor)
ts_output = ts_transform(tensor)
self.assertEqual(ts_output, output)

def test_TimeStretch(self):
n_freq = 400
hop_length = 512
fixed_rate = 1.3
tensor = torch.rand((10, 2, n_freq, 10))
self._assert_consistency(
T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
tensor,
)


class Transforms(TransformsMixin, common_utils.TestBaseMixin):
def test_Spectrogram(self):
tensor = torch.rand((1, 1000))
self._assert_consistency(T.Spectrogram(), tensor)
@@ -585,7 +609,7 @@ def test_TimeStretch(self):
n_freq = 400
hop_length = 512
fixed_rate = 1.3
tensor = torch.rand((10, 2, n_freq, 10, 2))
tensor = torch.rand((10, 2, n_freq, 10, 2), dtype=torch.double)
Collaborator:
Is this change related or necessary?
My understanding is that the _assert_consistency method will change the dtype/device to the appropriate ones, so this change has no effect.

Author:
oh okay removed it!

Collaborator:
I still see dtype=torch.double ...
self._assert_consistency(
T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
tensor,
)
97 changes: 66 additions & 31 deletions torchaudio/functional.py
@@ -458,14 +458,28 @@ def phase_vocoder(
factor of ``rate``.

Args:
complex_specgrams (Tensor): Dimension of `(..., freq, time, complex=2)`
Contributor:
How about something like this?

        complex_specgrams (Tensor): Either a real tensor of dimension `(..., freq, time, complex=2)`
            or a tensor of dimension `(..., freq, time)` with complex dtype.

We were using "complex tensor" to mean `(..., complex=2)`. This is now ambiguous. What expression do you recommend to refer to a tensor of complex dtype? "tensor with a complex dtype"?

Author (@anjali411, Aug 6, 2020):
yeah, "tensor with a complex dtype" sounds good. However, this "or" way of documenting could be problematic in cases where the function takes more than one complex tensor. Perhaps in those cases we can add a note stating that either all inputs should be real tensors or all inputs should be of complex dtype.

I think it might be nicer to add a separate example with complex-dtype tensors so that it's also clear that the returned output would also be complex (if applicable), especially since we are planning to switch to using complex-dtype tensors in the release after the upcoming one.

Contributor:
I do find this "or" way of discussing it a little cumbersome, and I agree this will get long if many tensors are involved. We could add a note in each, and just define the args/returns with complex dtype. We can still keep the example for clarity.

"""
    We are migrating to complex-dtype tensors. For backward compatibility reasons,
    this function still supports the legacy convention of ending with a dimension of 2
    to represent a complex tensor.

    Args:
        complex_specgrams (Tensor): A tensor of dimension `(..., freq, time)` with complex dtype.
        rate (float): Speed-up factor
        phase_advance (Tensor): Expected phase advance in each bin. Dimension of (freq, 1)
    Returns:
        Tensor: Complex Specgrams Stretch with dimension of `(..., freq, ceil(time/rate))`
            with a complex dtype.

    Example

    Example - Legacy
"""

Thoughts?

Contributor:
P.S. Good suggestion below for example naming.
complex_specgrams (Tensor): Either a real tensor of dimension `(..., freq, time, complex=2)`
or a tensor of dimension `(..., freq, time)` with complex dtype.
rate (float): Speed-up factor
phase_advance (Tensor): Expected phase advance in each bin. Dimension of (freq, 1)

Returns:
Tensor: Complex Specgrams Stretch with dimension of `(..., freq, ceil(time/rate), complex=2)`
Tensor: Complex Specgrams Stretch with either a real dtype and dimension of
`(..., freq, ceil(time/rate), complex=2)` or
a complex dtype and dimension of `(..., freq, ceil(time/rate))`.

Example
Example - New API (using tensors with complex dtype)
>>> freq, hop_length = 1025, 512
>>> # (channel, freq, time)
>>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat)
Contributor:
this is neat!
>>> rate = 1.3 # Speed up by 30%
>>> phase_advance = torch.linspace(
>>> 0, math.pi * hop_length, freq)[..., None]
>>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
>>> x.shape # with 231 == ceil(300 / 1.3)
torch.Size([2, 1025, 231])
Contributor (@vincentqb, Aug 6, 2020):
we might not need to change the example. we could add a second example, or a comment next to each `, 2`. other ideas?

Author (@anjali411, Aug 6, 2020):
I think we should have an example with tensors of complex dtype so that users know how to deal with complex tensors. This is an option:

    (Old API) Example:
    ....

    (New API) Example:
    ...

what do you think?

Contributor (@vincentqb, Aug 7, 2020):
Good suggestion :)

How about standardizing on this?

    Example - New API (using tensors with complex dtype)
    Example - Old API (using tensors with (..., complex=2))
Example - Old API (using real tensors with shape (..., complex=2))
>>> freq, hop_length = 1025, 512
>>> # (channel, freq, time, complex=2)
>>> complex_specgrams = torch.randn(2, freq, 300, 2)
Expand All @@ -476,50 +490,71 @@ def phase_vocoder(
>>> x.shape # with 231 == ceil(300 / 1.3)
torch.Size([2, 1025, 231, 2])
"""

# pack batch
use_complex = complex_specgrams.is_complex()
shape = complex_specgrams.size()
complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))

time_steps = torch.arange(0,
complex_specgrams.size(-2),
rate,
device=complex_specgrams.device,
dtype=complex_specgrams.dtype)

alphas = time_steps % 1.0
phase_0 = angle(complex_specgrams[..., :1, :])

# Time Padding
complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 0, 0, 2])

# (new_bins, freq, 2)
complex_specgrams_0 = complex_specgrams.index_select(-2, time_steps.long())
complex_specgrams_1 = complex_specgrams.index_select(-2, (time_steps + 1).long())

angle_0 = angle(complex_specgrams_0)
angle_1 = angle(complex_specgrams_1)

norm_0 = torch.norm(complex_specgrams_0, p=2, dim=-1)
norm_1 = torch.norm(complex_specgrams_1, p=2, dim=-1)
if use_complex:
Contributor:
As an alternative to duplicating all the logic here, you could have instead taken the real tensor, viewed it as complex, and then used the complex codepath (viewing it back as real in the end). Something to consider?

Author:
I think that's a possibility, however the goal is to be able to remove the code in the `if not use_complex` branch after a deprecation cycle and just use the code in the other branch (which has similar logic, however there are some substantial differences, e.g., padding logic).

Contributor:
well, in my suggestion, you'd delete the real code immediately :) Anyway, this is NBD

Author (@anjali411, Aug 12, 2020):
oh sorry I misread your comment and thought you meant the other way round! yeah that sounds reasonable to me. cc. @mthrok thoughts?

Author:
on more thought, the current lack of autograd support and testing for complex might introduce bc-breaking changes.

Collaborator:
> oh sorry I misread your comment and thought you meant the other way round! yeah that sounds reasonable to me. cc. @mthrok thoughts?
> on more thought, the current lack of autograd support and testing for complex might introduce bc-breaking changes.

I am in favor of not duplicating the logic, however if that introduces BC breakage on real-valued tensor input, then I think we can wait until the autograd support arrives.

Contributor:
If it's R2R with complex insides, the choice of JAX/TF convention doesn't matter, you'll always get the same gradients in the end.
# pack batch
complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-2:]))
time_steps = torch.arange(0,
complex_specgrams.size(-1),
rate,
device=complex_specgrams.device,
dtype=torch.real(complex_specgrams).dtype)
phase_0 = complex_specgrams[..., :1].angle()
alphas = time_steps % 1.0
# Time Padding
complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 2])
# (batch, freq, new_bins)
complex_specgrams_0 = complex_specgrams.index_select(-1, time_steps.long())
complex_specgrams_1 = complex_specgrams.index_select(-1, (time_steps + 1).long())

angle_0 = complex_specgrams_0.angle()
angle_1 = complex_specgrams_1.angle()
norm_0 = complex_specgrams_0.abs()
norm_1 = complex_specgrams_1.abs()
else:
complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))
time_steps = torch.arange(0,
complex_specgrams.size(-2),
rate,
device=complex_specgrams.device,
dtype=complex_specgrams.dtype)
alphas = time_steps % 1.0
phase_0 = angle(complex_specgrams[..., :1, :])
# Time Padding
complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 0, 0, 2])
# (new_bins, freq, 2)
complex_specgrams_0 = complex_specgrams.index_select(-2, time_steps.long())
complex_specgrams_1 = complex_specgrams.index_select(-2, (time_steps + 1).long())

angle_0 = angle(complex_specgrams_0)
angle_1 = angle(complex_specgrams_1)
norm_0 = torch.norm(complex_specgrams_0, p=2, dim=-1)
norm_1 = torch.norm(complex_specgrams_1, p=2, dim=-1)

phase = angle_1 - angle_0 - phase_advance
phase = phase - 2 * math.pi * torch.round(phase / (2 * math.pi))

# Compute Phase Accum
phase = phase + phase_advance
phase = torch.cat([phase_0, phase[..., :-1]], dim=-1)

phase_acc = torch.cumsum(phase, -1)

mag = alphas * norm_1 + (1 - alphas) * norm_0

real_stretch = mag * torch.cos(phase_acc)
imag_stretch = mag * torch.sin(phase_acc)

complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1)
if use_complex:
complex_specgrams_stretch = torch.view_as_complex(torch.stack([real_stretch, imag_stretch], dim=-1))

# unpack batch
complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-3] + complex_specgrams_stretch.shape[1:])
# unpack batch
complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-2] + complex_specgrams_stretch.shape[1:])
else:
complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1)
# unpack batch
complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-3] + complex_specgrams_stretch.shape[1:])

return complex_specgrams_stretch

5 changes: 4 additions & 1 deletion torchaudio/transforms.py
@@ -693,7 +693,10 @@ def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] =
Returns:
Tensor: Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2).
"""
assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)"
use_complex = complex_specgrams.is_complex()
if not use_complex:
assert complex_specgrams.size(-1) == 2, \
"complex_specgrams should be a complex tensor, shape (..., complex=2)"

if overriding_rate is None:
rate = self.fixed_rate
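
Finally, a hedged usage sketch of the transform with a complex-dtype input, assuming the PR as shown above (shapes follow the phase_vocoder docstring):

    import torch
    import torchaudio

    n_freq, hop_length, rate = 1025, 512, 1.3
    specgram = torch.randn(2, n_freq, 300, dtype=torch.cfloat)  # (channel, freq, time)
    stretch = torchaudio.transforms.TimeStretch(hop_length=hop_length, n_freq=n_freq, fixed_rate=rate)
    stretched = stretch(specgram)
    print(stretched.shape, stretched.dtype)  # expected: torch.Size([2, 1025, 231]) torch.cfloat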