Skip to content

Commit

Permalink
[BC-Breaking] Ensure integer input frequencies for resample (#1857)
Browse files: browse the repository at this point in the history
  • Loading branch information
Caroline Chen authored Oct 13, 2021
1 parent 483d8fa commit 25a8adf
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 47 deletions.
23 changes: 2 additions & 21 deletions test/torchaudio_unittest/functional/functional_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_fact
new_sample_rate = sample_rate

if up_scale_factor is not None:
new_sample_rate *= up_scale_factor
new_sample_rate = int(new_sample_rate * up_scale_factor)

if down_scale_factor is not None:
new_sample_rate //= down_scale_factor
new_sample_rate = int(new_sample_rate / down_scale_factor)

duration = 5 # seconds
original_timestamps = torch.arange(0, duration, 1.0 / sample_rate)
Expand Down Expand Up @@ -439,25 +439,6 @@ def test_resample_waveform_downsample_accuracy(self, resampling_method, i):
def test_resample_waveform_upsample_accuracy(self, resampling_method, i):
self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0, resampling_method=resampling_method)

def test_resample_no_warning(self):
sample_rate = 44100
waveform = get_whitenoise(sample_rate=sample_rate, duration=0.1)

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
F.resample(waveform, float(sample_rate), sample_rate / 2.)
assert len(w) == 0

def test_resample_warning(self):
"""resample should throw a warning if an input frequency is not of an integer value"""
sample_rate = 44100
waveform = get_whitenoise(sample_rate=sample_rate, duration=0.1)

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
F.resample(waveform, sample_rate, 5512.5)
assert len(w) == 1

@nested_params(
[0.5, 1.01, 1.3],
[True, False],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -659,19 +659,19 @@ def func(tensor):

def test_resample_sinc(self):
def func(tensor):
sr1, sr2 = 16000., 8000.
sr1, sr2 = 16000, 8000
return F.resample(tensor, sr1, sr2, resampling_method="sinc_interpolation")

tensor = common_utils.get_whitenoise(sample_rate=16000)
self._assert_consistency(func, tensor)

def test_resample_kaiser(self):
def func(tensor):
sr1, sr2 = 16000., 8000.
sr1, sr2 = 16000, 8000
return F.resample(tensor, sr1, sr2, resampling_method="kaiser_window")

def func_beta(tensor):
sr1, sr2 = 16000., 8000.
sr1, sr2 = 16000, 8000
beta = 6.
return F.resample(tensor, sr1, sr2, resampling_method="kaiser_window", beta=beta)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_LFCC(self):
def test_Resample(self):
sr1, sr2 = 16000, 8000
tensor = common_utils.get_whitenoise(sample_rate=sr1)
self._assert_consistency(T.Resample(float(sr1), float(sr2)), tensor)
self._assert_consistency(T.Resample(sr1, sr2), tensor)

def test_ComplexNorm(self):
tensor = torch.rand((1, 2, 201, 2))
Expand Down
33 changes: 15 additions & 18 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1471,8 +1471,8 @@ def compute_kaldi_pitch(


def _get_sinc_resample_kernel(
orig_freq: float,
new_freq: float,
orig_freq: int,
new_freq: int,
gcd: int,
lowpass_filter_width: int,
rolloff: float,
Expand All @@ -1482,16 +1482,13 @@ def _get_sinc_resample_kernel(
dtype: Optional[torch.dtype] = None):

if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
warnings.warn(
"Non-integer frequencies are being cast to ints and may result in poor resampling quality "
"because the underlying algorithm requires an integer ratio between `orig_freq` and `new_freq`. "
"Using non-integer valued frequencies will throw an error in release 0.10. "
"To work around this issue, manually convert both frequencies to integer values "
"that maintain their resampling rate ratio before passing them into the function "
raise Exception(
"Frequencies must be of integer type to ensure quality resampling computation. "
"To work around this, manually convert both frequencies to integer values "
"that maintain their resampling rate ratio before passing them into the function. "
"Example: To downsample a 44100 hz waveform by a factor of 8, use "
"`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5` "
"For more information or to leave feedback about this change, please refer to "
"https://github.com/pytorch/audio/issues/1487."
"`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5`. "
"For more information, please refer to https://github.com/pytorch/audio/issues/1487."
)

if resampling_method not in ['sinc_interpolation', 'kaiser_window']:
Expand Down Expand Up @@ -1562,8 +1559,8 @@ def _get_sinc_resample_kernel(

def _apply_sinc_resample_kernel(
waveform: Tensor,
orig_freq: float,
new_freq: float,
orig_freq: int,
new_freq: int,
gcd: int,
kernel: Tensor,
width: int,
Expand All @@ -1589,8 +1586,8 @@ def _apply_sinc_resample_kernel(

def resample(
waveform: Tensor,
orig_freq: float,
new_freq: float,
orig_freq: int,
new_freq: int,
lowpass_filter_width: int = 6,
rolloff: float = 0.99,
resampling_method: str = "sinc_interpolation",
Expand All @@ -1606,8 +1603,8 @@ def resample(
Args:
waveform (Tensor): The input signal of dimension `(..., time)`
orig_freq (float): The original frequency of the signal
new_freq (float): The desired frequency
orig_freq (int): The original frequency of the signal
new_freq (int): The desired frequency
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
but less efficient. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Expand Down Expand Up @@ -1736,7 +1733,7 @@ def pitch_shift(
win_length=win_length,
window=window,
length=len_stretch)
waveform_shift = resample(waveform_stretch, sample_rate // rate, float(sample_rate))
waveform_shift = resample(waveform_stretch, int(sample_rate / rate), sample_rate)
shift_len = waveform_shift.size()[-1]
if shift_len > ori_len:
waveform_shift = waveform_shift[..., :ori_len]
Expand Down
8 changes: 4 additions & 4 deletions torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,8 +815,8 @@ class Resample(torch.nn.Module):
Alternatively, you could rewrite a transform that caches a higher precision kernel.
Args:
orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``)
new_freq (float, optional): The desired frequency. (Default: ``16000``)
orig_freq (int, optional): The original frequency of the signal. (Default: ``16000``)
new_freq (int, optional): The desired frequency. (Default: ``16000``)
resampling_method (str, optional): The resampling method to use.
Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``)
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
Expand All @@ -840,8 +840,8 @@ class Resample(torch.nn.Module):

def __init__(
self,
orig_freq: float = 16000,
new_freq: float = 16000,
orig_freq: int = 16000,
new_freq: int = 16000,
resampling_method: str = 'sinc_interpolation',
lowpass_filter_width: int = 6,
rolloff: float = 0.99,
Expand Down

0 comments on commit 25a8adf

Please sign in to comment.