Skip to content

Commit

Permalink
Import torchaudio #1412 c0bfb03
Browse files Browse the repository at this point in the history
Summary:
Import latest from github to fbcode

Pass: 951
Skip: 19
Omit: 1
ListingSuccess: 26

Result available at: https://www.internalfb.com/intern/testinfra/testrun/8444249336935844

Reviewed By: mthrok

Differential Revision: D27448988

fbshipit-source-id: 61f63ffa1295a31b4452abaf2c74ebfefb827dcf
  • Loading branch information
parmeet authored and cpuhrsch committed Apr 2, 2021
1 parent 7e37c35 commit ab86b88
Show file tree
Hide file tree
Showing 12 changed files with 153 additions and 103 deletions.
7 changes: 6 additions & 1 deletion docs/source/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ apply_codec
-----------

.. autofunction:: apply_codec

:hidden:`Complex Utility`
~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down Expand Up @@ -230,3 +230,8 @@ vad
---------------------------

.. autofunction:: spectral_centroid

:hidden:`resample`
---------------------------

.. autofunction:: resample
1 change: 0 additions & 1 deletion test/torchaudio_unittest/common_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
skipIfNoModule,
skipIfNoKaldi,
skipIfNoSox,
skipIfNoSoxBackend,
)
from .wav_utils import (
get_wav_data,
Expand Down
3 changes: 0 additions & 3 deletions test/torchaudio_unittest/common_utils/case_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import torch
from torch.testing._internal.common_utils import TestCase as PytorchTestCase
import torchaudio
from torchaudio._internal.module_utils import (
is_module_available,
is_sox_available,
Expand Down Expand Up @@ -96,8 +95,6 @@ def skipIfNoModule(module, display_name=None):
return unittest.skipIf(not is_module_available(module), f'"{display_name}" is not available')


skipIfNoSoxBackend = unittest.skipIf(
'sox' not in torchaudio.list_audio_backends(), 'Sox backend not available')
skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available')
skipIfNoSox = unittest.skipIf(not is_sox_available(), reason='Sox not available')
skipIfNoKaldi = unittest.skipIf(not is_kaldi_available(), reason='Kaldi not available')
3 changes: 1 addition & 2 deletions test/torchaudio_unittest/compliance_kaldi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def first_sample_of_frame(frame, window_size, window_shift, snip_edges):

@common_utils.skipIfNoSox
class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
backend = 'sox_io'

kaldi_output_dir = common_utils.get_asset_path('kaldi')
test_filepath = common_utils.get_asset_path('kaldi_file.wav')
Expand Down Expand Up @@ -113,7 +112,7 @@ def _create_data_set(self):
# clear the last 16 bits because they aren't used anyways
y = ((y >> 16) << 16).float()
torchaudio.save(self.test_filepath, y, sr)
sound, sample_rate = torchaudio.load(self.test_filepath, normalization=False)
sound, sample_rate = common_utils.load_wav(self.test_filepath, normalize=False)
print(y >> 16)
self.assertTrue(sample_rate == sr)
self.assertEqual(y, sound)
Expand Down
19 changes: 19 additions & 0 deletions test/torchaudio_unittest/functional/librosa_compatibility_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,25 @@ def test_amplitude_to_DB(self):

self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5)

def test_resample(self):
input_path = common_utils.get_asset_path('sinewave.wav')
waveform, sample_rate = common_utils.load_wav(input_path)

upsample_rate = sample_rate * 2
downsample_rate = sample_rate // 2

ta_upsampled = F.resample(waveform, sample_rate, upsample_rate)
lr_upsampled = librosa.resample(waveform.squeeze(0).numpy(), sample_rate, upsample_rate)
lr_upsampled = torch.from_numpy(lr_upsampled).unsqueeze(0)

self.assertEqual(ta_upsampled, lr_upsampled, atol=1e-2, rtol=1e-5)

ta_downsampled = F.resample(waveform, sample_rate, downsample_rate)
lr_downsampled = librosa.resample(waveform.squeeze(0).numpy(), sample_rate, downsample_rate)
lr_downsampled = torch.from_numpy(lr_downsampled).unsqueeze(0)

self.assertEqual(ta_downsampled, lr_downsampled, atol=1e-2, rtol=1e-5)


@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class TestPhaseVocoder(common_utils.TorchaudioTestCase):
Expand Down
4 changes: 2 additions & 2 deletions test/torchaudio_unittest/functional/sox_compatibility_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torchaudio.functional as F

from torchaudio_unittest.common_utils import (
skipIfNoSoxBackend,
skipIfNoSox,
skipIfNoExec,
TempDirMixin,
TorchaudioTestCase,
Expand All @@ -14,7 +14,7 @@
)


@skipIfNoSoxBackend
@skipIfNoSox
@skipIfNoExec('sox')
class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
def run_sox_effect(self, input_file, effect):
Expand Down
14 changes: 12 additions & 2 deletions test/torchaudio_unittest/transforms/sox_compatibility_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@
from parameterized import parameterized

from torchaudio_unittest.common_utils import (
skipIfNoSoxBackend,
skipIfNoSox,
skipIfNoExec,
TempDirMixin,
TorchaudioTestCase,
get_asset_path,
sox_utils,
load_wav,
save_wav,
get_whitenoise,
)


@skipIfNoSoxBackend
@skipIfNoSox
@skipIfNoExec('sox')
class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
def run_sox_effect(self, input_file, effect):
Expand All @@ -24,6 +26,14 @@ def assert_sox_effect(self, result, input_path, effects, atol=1e-04, rtol=1e-5):
expected, _ = self.run_sox_effect(input_path, effects)
self.assertEqual(result, expected, atol=atol, rtol=rtol)

def get_whitenoise(self, sample_rate=8000):
noise = get_whitenoise(
sample_rate=sample_rate, duration=3, scale_factor=0.9,
)
path = self.get_temp_path("whitenoise.wav")
save_wav(path, noise, sample_rate)
return noise, path

@parameterized.expand([
('q', 'quarter_sine'),
('h', 'half_sine'),
Expand Down
81 changes: 4 additions & 77 deletions torchaudio/compliance/kaldi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import math
import torch
from torch import Tensor
from torch.nn import functional as F

import torchaudio
import torchaudio._internal.fft
Expand Down Expand Up @@ -753,71 +752,16 @@ def mfcc(
return feature


def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_width: int,
device: torch.device, dtype: torch.dtype):
assert lowpass_filter_width > 0
kernels = []
base_freq = min(orig_freq, new_freq)
# This will perform antialiasing filtering by removing the highest frequencies.
# At first I thought I only needed this when downsampling, but when upsampling
# you will get edge artifacts without this, as the edge is equivalent to zero padding,
# which will add high freq artifacts.
base_freq *= 0.99

# The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
# using the sinc interpolation formula:
# x(t) = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - t))
# We can then sample the function x(t) with a different sample rate:
# y[j] = x(j / new_freq)
# or,
# y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))

# We see here that y[j] is the convolution of x[i] with a specific filter, for which
# we take an FIR approximation, stopping when we see at least `lowpass_filter_width` zeros crossing.
# But y[j+1] is going to have a different set of weights and so on, until y[j + new_freq].
# Indeed:
# y[j + new_freq] = sum_i x[i] sinc(pi * orig_freq * ((i / orig_freq - (j + new_freq) / new_freq))
# = sum_i x[i] sinc(pi * orig_freq * ((i - orig_freq) / orig_freq - j / new_freq))
# = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
# so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`.
# This will explain the F.conv1d after, with a stride of orig_freq.
width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
# If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e.,
# they will have a lot of almost zero values to the left or to the right...
# There is probably a way to evaluate those filters more efficiently, but this is kept for
# future work.
idx = torch.arange(-width, width + orig_freq, device=device, dtype=dtype)

for i in range(new_freq):
t = (-i / new_freq + idx / orig_freq) * base_freq
t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
t *= math.pi
# we do not use torch.hann_window here as we need to evaluate the window
# at specific positions, not over a regular grid.
window = torch.cos(t / lowpass_filter_width / 2)**2
kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t)
kernel.mul_(window)
kernels.append(kernel)

scale = base_freq / orig_freq
return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width


def resample_waveform(waveform: Tensor,
orig_freq: float,
new_freq: float,
lowpass_filter_width: int = 6) -> Tensor:
r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
a signal). LinearResample (LR) means that the output signal is at linearly spaced intervals (i.e
the output signal has a frequency of ``new_freq``). It uses sinc/bandlimited interpolation to
upsample/downsample the signal.
r"""Resamples the waveform at the new frequency.
https://ccrma.stanford.edu/~jos/resample/Theory_Ideal_Bandlimited_Interpolation.html
https://github.com/kaldi-asr/kaldi/blob/master/src/feat/resample.h#L56
This is a wrapper around ``torchaudio.functional.resample``.
Args:
waveform (Tensor): The input signal of size (c, n)
waveform (Tensor): The input signal of size (..., time)
orig_freq (float): The original frequency of the signal
new_freq (float): The desired frequency
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
Expand All @@ -826,21 +770,4 @@ def resample_waveform(waveform: Tensor,
Returns:
Tensor: The waveform at the new frequency
"""
assert waveform.dim() == 2
assert orig_freq > 0.0 and new_freq > 0.0

orig_freq = int(orig_freq)
new_freq = int(new_freq)
gcd = math.gcd(orig_freq, new_freq)
orig_freq = orig_freq // gcd
new_freq = new_freq // gcd

kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width,
waveform.device, waveform.dtype)

num_wavs, length = waveform.shape
waveform = F.pad(waveform, (width, width + orig_freq))
resampled = F.conv1d(waveform[:, None], kernel, stride=orig_freq)
resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
target_length = int(math.ceil(new_freq * length / orig_freq))
return resampled[..., :target_length]
return torchaudio.functional.resample(waveform, orig_freq, new_freq, lowpass_filter_width)
4 changes: 2 additions & 2 deletions torchaudio/datasets/yesno.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"release1": {
"folder_in_archive": "waves_yesno",
"url": "http://www.openslr.org/resources/1/waves_yesno.tar.gz",
"checksum": "30301975fd8c5cac4040c261c0852f57cfa8adbbad2ce78e77e4986957445f27",
"checksum": "c3f49e0cca421f96b75b41640749167b52118f232498667ca7a5f9416aef8e73",
}
}

Expand Down Expand Up @@ -54,7 +54,7 @@ def _parse_filesystem(self, root: str, url: str, folder_in_archive: str, downloa
if not os.path.isdir(self._path):
if not os.path.isfile(archive):
checksum = _RELEASE_CONFIGS["release1"]["checksum"]
download_url(url, root, hash_value=checksum, hash_type="md5")
download_url(url, root, hash_value=checksum)
extract_archive(archive)

if not os.path.isdir(self._path):
Expand Down
4 changes: 3 additions & 1 deletion torchaudio/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
spectrogram,
spectral_centroid,
apply_codec,
resample,
)
from .filtering import (
allpass_biquad,
Expand Down Expand Up @@ -85,5 +86,6 @@
'riaa_biquad',
'treble_biquad',
'vad',
'apply_codec'
'apply_codec',
'resample',
]
103 changes: 103 additions & 0 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
'sliding_window_cmn',
"spectral_centroid",
"apply_codec",
"resample",
]


Expand Down Expand Up @@ -1209,3 +1210,105 @@ def compute_kaldi_pitch(
)
result = result.reshape(shape[:-1] + result.shape[-2:])
return result


def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_width: int,
device: torch.device, dtype: torch.dtype):
assert lowpass_filter_width > 0
kernels = []
base_freq = min(orig_freq, new_freq)
# This will perform antialiasing filtering by removing the highest frequencies.
# At first I thought I only needed this when downsampling, but when upsampling
# you will get edge artifacts without this, as the edge is equivalent to zero padding,
# which will add high freq artifacts.
base_freq *= 0.99

# The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
# using the sinc interpolation formula:
# x(t) = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - t))
# We can then sample the function x(t) with a different sample rate:
# y[j] = x(j / new_freq)
# or,
# y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))

# We see here that y[j] is the convolution of x[i] with a specific filter, for which
# we take an FIR approximation, stopping when we see at least `lowpass_filter_width` zeros crossing.
# But y[j+1] is going to have a different set of weights and so on, until y[j + new_freq].
# Indeed:
# y[j + new_freq] = sum_i x[i] sinc(pi * orig_freq * ((i / orig_freq - (j + new_freq) / new_freq))
# = sum_i x[i] sinc(pi * orig_freq * ((i - orig_freq) / orig_freq - j / new_freq))
# = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
# so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`.
# This will explain the F.conv1d after, with a stride of orig_freq.
width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
# If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e.,
# they will have a lot of almost zero values to the left or to the right...
# There is probably a way to evaluate those filters more efficiently, but this is kept for
# future work.
idx = torch.arange(-width, width + orig_freq, device=device, dtype=dtype)

for i in range(new_freq):
t = (-i / new_freq + idx / orig_freq) * base_freq
t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
t *= math.pi
# we do not use torch.hann_window here as we need to evaluate the window
# at specific positions, not over a regular grid.
window = torch.cos(t / lowpass_filter_width / 2)**2
kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t)
kernel.mul_(window)
kernels.append(kernel)

scale = base_freq / orig_freq
return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width


def resample(
waveform: Tensor,
orig_freq: float,
new_freq: float,
lowpass_filter_width: int = 6
) -> Tensor:
r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
a signal). LinearResample (LR) means that the output signal is at linearly spaced intervals (i.e
the output signal has a frequency of ``new_freq``). It uses sinc/bandlimited interpolation to
upsample/downsample the signal.
https://ccrma.stanford.edu/~jos/resample/Theory_Ideal_Bandlimited_Interpolation.html
https://github.com/kaldi-asr/kaldi/blob/master/src/feat/resample.h#L56
Args:
waveform (Tensor): The input signal of dimension (..., time)
orig_freq (float): The original frequency of the signal
new_freq (float): The desired frequency
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
Returns:
Tensor: The waveform at the new frequency of dimension (..., time).
"""
# pack batch
shape = waveform.size()
waveform = waveform.view(-1, shape[-1])

assert orig_freq > 0.0 and new_freq > 0.0

orig_freq = int(orig_freq)
new_freq = int(new_freq)
gcd = math.gcd(orig_freq, new_freq)
orig_freq = orig_freq // gcd
new_freq = new_freq // gcd

kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width,
waveform.device, waveform.dtype)

num_wavs, length = waveform.shape
waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
target_length = int(math.ceil(new_freq * length / orig_freq))
resampled = resampled[..., :target_length]

# unpack batch
resampled = resampled.view(shape[:-1] + resampled.shape[-1:])
return resampled
Loading

0 comments on commit ab86b88

Please sign in to comment.