Skip to content

Commit

Permalink
make use_complex not caps
Browse files Browse the repository at this point in the history
  • Loading branch information
anjali411 committed Aug 7, 2020
1 parent d75f6fc commit c3d9b8e
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
import scipy

import pytest

from torchaudio_unittest import common_utils


@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class TestFunctional(common_utils.TorchaudioTestCase):
"""Test suite for functions in `functional` module."""
Expand All @@ -35,7 +35,8 @@ def test_phase_vocoder(self, rate):
# result in bottom right values of the stretched sectrogram to not
# match with librosa.

phase_advance = torch.linspace(0, np.pi * hop_length, complex_specgrams.shape[-2], dtype=torch.double)[..., None]
phase_advance = torch.linspace(0, np.pi * hop_length,
complex_specgrams.shape[-2], dtype=torch.double)[..., None]

complex_specgrams_stretch = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance)

Expand All @@ -50,7 +51,7 @@ def test_phase_vocoder(self, rate):
index = [0] + [slice(None)] * 2
mono_complex_specgram = complex_specgrams[index].numpy()
expected_complex_stretch = librosa.phase_vocoder(mono_complex_specgram,
rate=rate,
hop_length=hop_length)
rate=rate,
hop_length=hop_length)

self.assertEqual(complex_specgrams_stretch[index], torch.from_numpy(expected_complex_stretch))
6 changes: 3 additions & 3 deletions torchaudio/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,9 +490,9 @@ def phase_vocoder(
>>> x.shape # with 231 == ceil(300 / 1.3)
torch.Size([2, 1025, 231])
"""
USE_COMPLEX = complex_specgrams.is_complex()
use_complex = complex_specgrams.is_complex()
shape = complex_specgrams.size()
if USE_COMPLEX:
if use_complex:
# pack batch
complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-2:]))
time_steps = torch.arange(0,
Expand Down Expand Up @@ -546,7 +546,7 @@ def phase_vocoder(
real_stretch = mag * torch.cos(phase_acc)
imag_stretch = mag * torch.sin(phase_acc)

if USE_COMPLEX:
if use_complex:
complex_specgrams_stretch = torch.view_as_complex(torch.stack([real_stretch, imag_stretch], dim=-1))

# unpack batch
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,8 +693,8 @@ def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] =
Returns:
Tensor: Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2).
"""
USE_COMPLEX = complex_specgrams.is_complex()
if not USE_COMPLEX:
use_complex = complex_specgrams.is_complex()
if not use_complex:
assert complex_specgrams.size(-1) == 2, "complex_specgrams \
should be a complex tensor, shape (..., complex=2)"

Expand Down

0 comments on commit c3d9b8e

Please sign in to comment.