Skip to content

Commit

Permalink
Add ComputeKaldiPitch
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Jan 12, 2021
1 parent 72b7680 commit 044bfc2
Show file tree
Hide file tree
Showing 51 changed files with 9,328 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ jobs:
torchscript_bc_test:
docker:
- image: "pytorch/torchaudio_unittest_base:manylinux"
resource_class: large
resource_class: 2xlarge+
steps:
- checkout
- generate_cache_key
Expand Down
2 changes: 1 addition & 1 deletion .circleci/config.yml.in
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ jobs:
torchscript_bc_test:
docker:
- image: "pytorch/torchaudio_unittest_base:manylinux"
resource_class: large
resource_class: 2xlarge+
steps:
- checkout
- generate_cache_key
Expand Down
3 changes: 3 additions & 0 deletions build_tools/setup_helpers/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def _get_ela(debug):
def _get_srcs():
srcs = [_CSRC_DIR / 'pybind.cpp']
srcs += list(_CSRC_DIR.glob('sox/**/*.cpp'))
srcs += list(_CSRC_DIR.glob('kaldi/**/*.cpp'))
srcs += list(_CSRC_DIR.glob('kaldi/**/*.cc'))
if _BUILD_TRANSDUCER:
srcs += [_CSRC_DIR / 'transducer.cpp']
return [str(path) for path in srcs]
Expand All @@ -72,6 +74,7 @@ def _get_srcs():
def _get_include_dirs():
dirs = [
str(_ROOT_DIR),
str(_CSRC_DIR / 'kaldi'),
]
if _BUILD_SOX or _BUILD_TRANSDUCER:
dirs.append(str(_TP_INSTALL_DIR / 'include'))
Expand Down
5 changes: 5 additions & 0 deletions test/torchaudio_unittest/assets/kaldi_test_pitch_args.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{"sample_frequency": 8000}
{"sample_frequency": 8000, "frames_per_chunk": 200}
{"sample_frequency": 8000, "frames_per_chunk": 200, "simulate_first_pass_online": true}
{"sample_frequency": 16000}
{"sample_frequency": 44100}
7 changes: 6 additions & 1 deletion test/torchaudio_unittest/kaldi_compatibility_cpu_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch

from torchaudio_unittest import common_utils
from .kaldi_compatibility_impl import Kaldi
from .kaldi_compatibility_impl import Kaldi, KaldiCPUOnly


class TestKaldiFloat32(Kaldi, common_utils.PytorchTestCase):
Expand All @@ -12,3 +12,8 @@ class TestKaldiFloat32(Kaldi, common_utils.PytorchTestCase):
class TestKaldiFloat64(Kaldi, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')


class TestKaldiCPUOnly(KaldiCPUOnly, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cpu')
27 changes: 25 additions & 2 deletions test/torchaudio_unittest/kaldi_compatibility_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@

from torchaudio_unittest.common_utils import (
TestBaseMixin,
TempDirMixin,
load_params,
skipIfNoExec,
get_asset_path,
load_wav
load_wav,
save_wav,
get_sinusoid,
)


Expand Down Expand Up @@ -48,11 +51,13 @@ def _run_kaldi(command, input_type, input_value):
return torch.from_numpy(result.copy()) # copy supresses some torch warning


class Kaldi(TestBaseMixin):
class KaldiTestBase(TempDirMixin, TestBaseMixin):
def assert_equal(self, output, *, expected, rtol=None, atol=None):
expected = expected.to(dtype=self.dtype, device=self.device)
self.assertEqual(output, expected, rtol=rtol, atol=atol)


class Kaldi(KaldiTestBase):
@skipIfNoExec('apply-cmvn-sliding')
def test_sliding_window_cmn(self):
"""sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
Expand Down Expand Up @@ -101,3 +106,21 @@ def test_mfcc(self, kwargs):
command = ['compute-mfcc-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-']
kaldi_result = _run_kaldi(command, 'scp', wave_file)
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)


class KaldiCPUOnly(KaldiTestBase):
@parameterized.expand(load_params('kaldi_test_pitch_args.json'))
@skipIfNoExec('compute-kaldi-pitch-feats')
def test_pitch_feats(self, kwargs):
"""compute_kaldi_pitch produces numerically compatible result with compute-kaldi-pitch-feats"""
sample_rate = kwargs['sample_frequency']
waveform = get_sinusoid(dtype='float32', sample_rate=sample_rate)
result = F.compute_kaldi_pitch(waveform[0], **kwargs)

waveform = get_sinusoid(dtype='int16', sample_rate=sample_rate)
wave_file = self.get_temp_path('test.wav')
save_wav(wave_file, waveform, sample_rate)

command = ['compute-kaldi-pitch-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-']
kaldi_result = _run_kaldi(command, 'scp', wave_file)
self.assert_equal(result, expected=kaldi_result)
Loading

0 comments on commit 044bfc2

Please sign in to comment.