Fix nan gradient by using native complex abs op
mthrok committed Jan 6, 2021
1 parent 6b07bcf commit ed798ad
Showing 4 changed files with 63 additions and 20 deletions.
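Background for the change (pytorch/audio#993): `F.spectrogram` used to take the magnitude through torchaudio's `complex_norm`, which operates on the real `(real, imag)` view and computes roughly `values.pow(2.).sum(-1).pow(0.5 * power)`. The 0.5-power has an unbounded derivative at zero, so for an all-zero input autograd evaluates `inf * 0` and the gradient comes out NaN. A minimal sketch of the failure mode in plain PyTorch (illustrative only; the tensor shape is arbitrary):

```python
import torch

# complex_norm-style magnitude: sqrt of the squared (real, imag) pairs.
# d/du sqrt(u) = 0.5 / sqrt(u) blows up at u = 0, and autograd's chain
# rule then multiplies that inf by du/dx = 0, producing NaN.
x = torch.zeros(4, 2, requires_grad=True)  # illustrative (real, imag) pairs
mag = x.pow(2.).sum(-1).pow(0.5)
mag.sum().backward()
print(x.grad)  # all NaN
```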
13 changes: 11 additions & 2 deletions test/torchaudio_unittest/functional/functional_cpu_test.py
@@ -1,5 +1,4 @@
 import math
-import unittest
 
 import torch
 import torchaudio
@@ -8,7 +7,7 @@
 import pytest
 
 from torchaudio_unittest import common_utils
-from .functional_impl import Lfilter
+from .functional_impl import Lfilter, Spectrogram
 
 
 class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
@@ -21,6 +20,16 @@ class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase):
     device = torch.device('cpu')
 
 
+class TestSpectrogramFloat32(Spectrogram, common_utils.PytorchTestCase):
+    dtype = torch.float32
+    device = torch.device('cpu')
+
+
+class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
+    dtype = torch.float64
+    device = torch.device('cpu')
+
+
 class TestCreateFBMatrix(common_utils.TorchaudioTestCase):
     def test_no_warning_high_n_freq(self):
         with pytest.warns(None) as w:
14 changes: 13 additions & 1 deletion test/torchaudio_unittest/functional/functional_cuda_test.py
@@ -1,7 +1,7 @@
 import torch
 
 from torchaudio_unittest import common_utils
-from .functional_impl import Lfilter
+from .functional_impl import Lfilter, Spectrogram
 
 
 @common_utils.skipIfNoCuda
@@ -14,3 +14,15 @@ class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
 class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase):
     dtype = torch.float64
     device = torch.device('cuda')
+
+
+@common_utils.skipIfNoCuda
+class TestSpectrogramFloat32(Spectrogram, common_utils.PytorchTestCase):
+    dtype = torch.float32
+    device = torch.device('cuda')
+
+
+@common_utils.skipIfNoCuda
+class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
+    dtype = torch.float64
+    device = torch.device('cuda')
23 changes: 23 additions & 0 deletions test/torchaudio_unittest/functional/functional_impl.py
@@ -1,6 +1,7 @@
 """Test defintion common to CPU and CUDA"""
 import torch
 import torchaudio.functional as F
+from parameterized import parameterized
 
 from torchaudio_unittest import common_utils
 
@@ -29,3 +30,25 @@ def test_clamp(self):
         assert output_signal.max() <= 1
         output_signal = F.lfilter(input_signal, a_coeffs, b_coeffs, clamp=False)
         assert output_signal.max() > 1
+
+
+class Spectrogram(common_utils.TestBaseMixin):
+    @parameterized.expand([(0., ), (1., ), (2., ), (3., )])
+    def test_grad_at_zero(self, power):
+        """The gradient of power spectrogram should not be nan but zero near x=0
+
+        https://github.com/pytorch/audio/issues/993
+        """
+        x = torch.zeros(1, 22050, requires_grad=True)
+        spec = F.spectrogram(
+            x,
+            pad=0,
+            window=None,
+            n_fft=2048,
+            hop_length=None,
+            win_length=None,
+            power=power,
+            normalized=False,
+        )
+        spec.sum().backward()
+        assert not x.grad.isnan().sum()
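Why the docstring can promise a zero gradient rather than merely a finite one: with `return_complex=True` the magnitude goes through the native `torch.abs`, and PyTorch's backward for `abs` is built on `sgn`, with `torch.sgn(0) == 0` supplying a zero subgradient at the origin. A hedged standalone sketch of that behavior (assumes PyTorch's complex autograd support, available around the 1.7/1.8 releases this commit targets):

```python
import torch

# Native complex abs: |z| is not differentiable at z = 0, but PyTorch
# takes the subgradient sgn(0) = 0, so the backward pass stays finite
# (and in fact exactly zero) instead of producing NaN.
z = torch.zeros(3, dtype=torch.cfloat, requires_grad=True)
z.abs().sum().backward()
print(z.grad)  # zeros, not NaN
```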
33 changes: 16 additions & 17 deletions torchaudio/functional/functional.py
@@ -70,30 +70,29 @@ def spectrogram(
     waveform = waveform.reshape(-1, shape[-1])
 
     # default values are consistent with librosa.core.spectrum._spectrogram
-    spec_f = torch.view_as_real(
-        torch.stft(
-            input=waveform,
-            n_fft=n_fft,
-            hop_length=hop_length,
-            win_length=win_length,
-            window=window,
-            center=True,
-            pad_mode="reflect",
-            normalized=False,
-            onesided=True,
-            return_complex=True,
-        )
+    spec_f = torch.stft(
+        input=waveform,
+        n_fft=n_fft,
+        hop_length=hop_length,
+        win_length=win_length,
+        window=window,
+        center=True,
+        pad_mode="reflect",
+        normalized=False,
+        onesided=True,
+        return_complex=True,
     )
 
     # unpack batch
-    spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-3:])
+    spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-2:])
 
     if normalized:
         spec_f /= window.pow(2.).sum().sqrt()
     if power is not None:
-        spec_f = complex_norm(spec_f, power=power)
-
-    return spec_f
+        if power == 1.0:
+            return spec_f.abs()
+        return spec_f.abs().pow(power)
+    return torch.view_as_real(spec_f)
 
 
 def griffinlim(
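A note on the reworked control flow above: the STFT is now kept complex throughout, the batch-unpacking reshape accordingly uses the last two dimensions instead of three, and `torch.view_as_real` is applied only on the `power is None` return path, so callers still receive the familiar real tensor with a trailing `(real, imag)` dimension of size 2. A hedged usage sketch (argument values mirror the new test; the exact frame count depends on `torch.stft` defaults):

```python
import torch
import torchaudio.functional as F

x = torch.randn(1, 22050)

# power=None: raw STFT, returned as a real view -> (..., freq, time, 2)
raw = F.spectrogram(x, pad=0, window=None, n_fft=2048, hop_length=None,
                    win_length=None, power=None, normalized=False)

# power=2.0: power spectrogram via the native abs -> (..., freq, time)
pow2 = F.spectrogram(x, pad=0, window=None, n_fft=2048, hop_length=None,
                     win_length=None, power=2.0, normalized=False)

print(raw.shape)   # e.g. torch.Size([1, 1025, 44, 2])
print(pow2.shape)  # e.g. torch.Size([1, 1025, 44])
```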
