Skip to content

Commit

Permalink
[BC-Breaking] Default to native complex type when returning raw spect…
Browse files Browse the repository at this point in the history
…rogram

Part of pytorch#1337 .

- This code changes the return type of spectrogram to be native complex dtype,
when (and only when) returning raw (complex-valued) spectrogram.
- Change `return_complex=False` to `return_complex=True` in spectrogram ops.
- `return_complex` is only effective when `power` is `None`. It is ignored for
cases where `power` is not `None`. Because the returned Tensor is power spectrogram,
which is real-valued Tensors.
  • Loading branch information
mthrok committed Jun 3, 2021
1 parent 6882342 commit 57c7f7b
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,34 @@ def _assert_consistency_complex(self, func, tensor, test_pseudo_complex=False):

self.assertEqual(ts_output, output)

def test_spectrogram(self):
def test_spectrogram_complex(self):
def func(tensor):
n_fft = 400
ws = 400
hop = 200
pad = 0
window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
power = 2.
power = None
normalize = False
return F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize)

tensor = common_utils.get_whitenoise()
self._assert_consistency(func, tensor)

def test_spectrogram_real(self):
def func(tensor):
n_fft = 400
ws = 400
hop = 200
pad = 0
window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
power = 2.
normalize = False
return F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize, return_complex=False)

tensor = common_utils.get_whitenoise()
self._assert_consistency(func, tensor)

@skipIfRocm
def test_griffinlim(self):
def func(tensor):
Expand Down
2 changes: 1 addition & 1 deletion torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def spectrogram(
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
return_complex: bool = False,
return_complex: bool = True,
) -> Tensor:
r"""Create a spectrogram or a batch of spectrograms from a raw audio signal.
The spectrogram can be either magnitude-only or complex.
Expand Down
2 changes: 1 addition & 1 deletion torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(self,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
return_complex: bool = False) -> None:
return_complex: bool = True) -> None:
super(Spectrogram, self).__init__()
self.n_fft = n_fft
# number of FFT bins. the returned STFT result will have n_fft // 2 + 1
Expand Down

0 comments on commit 57c7f7b

Please sign in to comment.